diff --git a/.claude/rules/loc-exceptions.txt b/.claude/rules/loc-exceptions.txt index c3cced7..d557b94 100644 --- a/.claude/rules/loc-exceptions.txt +++ b/.claude/rules/loc-exceptions.txt @@ -39,6 +39,9 @@ **/lib/sdk/vvt-baseline-catalog.ts | owner=admin-lehrer | reason=Pure data catalog (630 LOC, BaselineTemplate[] literals) | review=2027-01-01 **/lib/sdk/loeschfristen-baseline-catalog.ts | owner=admin-lehrer | reason=Pure data catalog (578 LOC, retention period templates) | review=2027-01-01 +# Single SSE generator orchestrating 6 pipeline steps — cannot split generator context +**/ocr_pipeline_auto_steps.py | owner=klausur | reason=run_auto is a single async generator yielding SSE events across 6 steps (528 LOC) | review=2026-10-01 + # Legacy — TEMPORAER bis Refactoring abgeschlossen # Dateien hier werden Phase fuer Phase abgearbeitet und entfernt. # KEINE neuen Ausnahmen ohne [guardrail-change] Commit-Marker! diff --git a/backend-lehrer/ai_processing/print_cloze.py b/backend-lehrer/ai_processing/print_cloze.py new file mode 100644 index 0000000..5b12447 --- /dev/null +++ b/backend-lehrer/ai_processing/print_cloze.py @@ -0,0 +1,193 @@ +""" +AI Processing - Print Version Generator: Cloze (Lueckentext). + +Generates printable HTML for cloze/fill-in-the-blank worksheets. +""" + +from pathlib import Path +import json +import random +import logging + +from .core import BEREINIGT_DIR + +logger = logging.getLogger(__name__) + + +def generate_print_version_cloze(cloze_path: Path, include_answers: bool = False) -> Path: + """ + Generiert eine druckbare HTML-Version der Lueckentexte. + + Args: + cloze_path: Pfad zur *_cloze.json Datei + include_answers: True fuer Loesungsblatt (fuer Eltern) + + Returns: + Pfad zur generierten HTML-Datei + """ + if not cloze_path.exists(): + raise FileNotFoundError(f"Cloze-Datei nicht gefunden: {cloze_path}") + + cloze_data = json.loads(cloze_path.read_text(encoding="utf-8")) + items = cloze_data.get("cloze_items", []) + metadata = cloze_data.get("metadata", {}) + + title = metadata.get("source_title", "Arbeitsblatt") + subject = metadata.get("subject", "") + grade = metadata.get("grade_level", "") + total_gaps = metadata.get("total_gaps", 0) + + html_parts = [] + html_parts.append(""" + + + +""" + title + """ - Lueckentext + + + +""") + + # Header + version_text = "Loesungsblatt" if include_answers else "Lueckentext" + html_parts.append(f"

{title} - {version_text}

") + meta_parts = [] + if subject: + meta_parts.append(f"Fach: {subject}") + if grade: + meta_parts.append(f"Klasse: {grade}") + meta_parts.append(f"Luecken gesamt: {total_gaps}") + html_parts.append(f"
{' | '.join(meta_parts)}
") + + # Sammle alle Lueckenwoerter fuer Wortbank + all_words = [] + + # Lueckentexte + for idx, item in enumerate(items, 1): + html_parts.append("
") + html_parts.append(f"
{idx}.
") + + gaps = item.get("gaps", []) + sentence = item.get("sentence_with_gaps", "") + + if include_answers: + # Loesungsblatt: Luecken mit Antworten fuellen + for gap in gaps: + word = gap.get("word", "") + sentence = sentence.replace("___", f"{word}", 1) + else: + # Fragenblatt: Luecken als Linien + sentence = sentence.replace("___", " ") + # Woerter fuer Wortbank sammeln + for gap in gaps: + all_words.append(gap.get("word", "")) + + html_parts.append(f"
{sentence}
") + + # Uebersetzung anzeigen + translation = item.get("translation", {}) + if translation: + lang_name = translation.get("language_name", "Uebersetzung") + full_sentence = translation.get("full_sentence", "") + if full_sentence: + html_parts.append("
") + html_parts.append(f"
{lang_name}:
") + html_parts.append(full_sentence) + html_parts.append("
") + + html_parts.append("
") + + # Wortbank (nur fuer Fragenblatt) + if not include_answers and all_words: + random.shuffle(all_words) # Mische die Woerter + html_parts.append("
") + html_parts.append("
Wortbank (diese Woerter fehlen):
") + for word in all_words: + html_parts.append(f"{word}") + html_parts.append("
") + + html_parts.append("") + + # Speichern + suffix = "_cloze_solutions.html" if include_answers else "_cloze_print.html" + out_name = cloze_path.stem.replace("_cloze", "") + suffix + out_path = BEREINIGT_DIR / out_name + out_path.write_text("\n".join(html_parts), encoding="utf-8") + + logger.info(f"Cloze Print-Version gespeichert: {out_path.name}") + return out_path diff --git a/backend-lehrer/ai_processing/print_generator.py b/backend-lehrer/ai_processing/print_generator.py index df6fc9f..5dd3cb1 100644 --- a/backend-lehrer/ai_processing/print_generator.py +++ b/backend-lehrer/ai_processing/print_generator.py @@ -1,824 +1,22 @@ """ -AI Processing - Print Version Generator. +AI Processing - Print Version Generator — Barrel Re-export. -Generiert druckbare HTML-Versionen für verschiedene Arbeitsblatt-Typen. +Generiert druckbare HTML-Versionen fuer verschiedene Arbeitsblatt-Typen. +Split into: + - print_qa.py: Q&A print generation + - print_cloze.py: Cloze/Lueckentext print generation + - print_mc.py: Multiple Choice print generation + - print_worksheet.py: General worksheet print generation """ -from pathlib import Path -import json -import random -import logging - -from .core import BEREINIGT_DIR - -logger = logging.getLogger(__name__) - - -def generate_print_version_qa(qa_path: Path, include_answers: bool = False) -> Path: - """ - Generiert eine druckbare HTML-Version der Frage-Antwort-Paare. - - Args: - qa_path: Pfad zur *_qa.json Datei - include_answers: True für Lösungsblatt (für Eltern) - - Returns: - Pfad zur generierten HTML-Datei - """ - if not qa_path.exists(): - raise FileNotFoundError(f"Q&A-Datei nicht gefunden: {qa_path}") - - qa_data = json.loads(qa_path.read_text(encoding="utf-8")) - items = qa_data.get("qa_items", []) - metadata = qa_data.get("metadata", {}) - - title = metadata.get("source_title", "Arbeitsblatt") - subject = metadata.get("subject", "") - grade = metadata.get("grade_level", "") - - html_parts = [] - html_parts.append(""" - - - -""" + title + """ - Fragen - - - -""") - - # Header - version_text = "Lösungsblatt" if include_answers else "Fragenblatt" - html_parts.append(f"

{title} - {version_text}

") - meta_parts = [] - if subject: - meta_parts.append(f"Fach: {subject}") - if grade: - meta_parts.append(f"Klasse: {grade}") - meta_parts.append(f"Anzahl Fragen: {len(items)}") - html_parts.append(f"
{' | '.join(meta_parts)}
") - - # Fragen - for idx, item in enumerate(items, 1): - html_parts.append("
") - html_parts.append(f"
Frage {idx}
") - html_parts.append(f"
{item.get('question', '')}
") - - if include_answers: - # Lösungsblatt: Antwort anzeigen - html_parts.append(f"
Antwort: {item.get('answer', '')}
") - # Schlüsselbegriffe - key_terms = item.get("key_terms", []) - if key_terms: - terms_html = " ".join([f"{term}" for term in key_terms]) - html_parts.append(f"
Wichtige Begriffe: {terms_html}
") - else: - # Fragenblatt: Antwortlinien - html_parts.append("
") - for _ in range(3): - html_parts.append("
") - html_parts.append("
") - - html_parts.append("
") - - html_parts.append("") - - # Speichern - suffix = "_qa_solutions.html" if include_answers else "_qa_print.html" - out_name = qa_path.stem.replace("_qa", "") + suffix - out_path = BEREINIGT_DIR / out_name - out_path.write_text("\n".join(html_parts), encoding="utf-8") - - logger.info(f"Print-Version gespeichert: {out_path.name}") - return out_path - - -def generate_print_version_cloze(cloze_path: Path, include_answers: bool = False) -> Path: - """ - Generiert eine druckbare HTML-Version der Lückentexte. - - Args: - cloze_path: Pfad zur *_cloze.json Datei - include_answers: True für Lösungsblatt (für Eltern) - - Returns: - Pfad zur generierten HTML-Datei - """ - if not cloze_path.exists(): - raise FileNotFoundError(f"Cloze-Datei nicht gefunden: {cloze_path}") - - cloze_data = json.loads(cloze_path.read_text(encoding="utf-8")) - items = cloze_data.get("cloze_items", []) - metadata = cloze_data.get("metadata", {}) - - title = metadata.get("source_title", "Arbeitsblatt") - subject = metadata.get("subject", "") - grade = metadata.get("grade_level", "") - total_gaps = metadata.get("total_gaps", 0) - - html_parts = [] - html_parts.append(""" - - - -""" + title + """ - Lückentext - - - -""") - - # Header - version_text = "Lösungsblatt" if include_answers else "Lückentext" - html_parts.append(f"

{title} - {version_text}

") - meta_parts = [] - if subject: - meta_parts.append(f"Fach: {subject}") - if grade: - meta_parts.append(f"Klasse: {grade}") - meta_parts.append(f"Lücken gesamt: {total_gaps}") - html_parts.append(f"
{' | '.join(meta_parts)}
") - - # Sammle alle Lückenwörter für Wortbank - all_words = [] - - # Lückentexte - for idx, item in enumerate(items, 1): - html_parts.append("
") - html_parts.append(f"
{idx}.
") - - gaps = item.get("gaps", []) - sentence = item.get("sentence_with_gaps", "") - - if include_answers: - # Lösungsblatt: Lücken mit Antworten füllen - for gap in gaps: - word = gap.get("word", "") - sentence = sentence.replace("___", f"{word}", 1) - else: - # Fragenblatt: Lücken als Linien - sentence = sentence.replace("___", " ") - # Wörter für Wortbank sammeln - for gap in gaps: - all_words.append(gap.get("word", "")) - - html_parts.append(f"
{sentence}
") - - # Übersetzung anzeigen - translation = item.get("translation", {}) - if translation: - lang_name = translation.get("language_name", "Übersetzung") - full_sentence = translation.get("full_sentence", "") - if full_sentence: - html_parts.append("
") - html_parts.append(f"
{lang_name}:
") - html_parts.append(full_sentence) - html_parts.append("
") - - html_parts.append("
") - - # Wortbank (nur für Fragenblatt) - if not include_answers and all_words: - random.shuffle(all_words) # Mische die Wörter - html_parts.append("
") - html_parts.append("
Wortbank (diese Wörter fehlen):
") - for word in all_words: - html_parts.append(f"{word}") - html_parts.append("
") - - html_parts.append("") - - # Speichern - suffix = "_cloze_solutions.html" if include_answers else "_cloze_print.html" - out_name = cloze_path.stem.replace("_cloze", "") + suffix - out_path = BEREINIGT_DIR / out_name - out_path.write_text("\n".join(html_parts), encoding="utf-8") - - logger.info(f"Cloze Print-Version gespeichert: {out_path.name}") - return out_path - - -def generate_print_version_mc(mc_path: Path, include_answers: bool = False) -> str: - """ - Generiert eine druckbare HTML-Version der Multiple-Choice-Fragen. - - Args: - mc_path: Pfad zur *_mc.json Datei - include_answers: True für Lösungsblatt mit markierten richtigen Antworten - - Returns: - HTML-String (zum direkten Ausliefern) - """ - if not mc_path.exists(): - raise FileNotFoundError(f"MC-Datei nicht gefunden: {mc_path}") - - mc_data = json.loads(mc_path.read_text(encoding="utf-8")) - questions = mc_data.get("questions", []) - metadata = mc_data.get("metadata", {}) - - title = metadata.get("source_title", "Arbeitsblatt") - subject = metadata.get("subject", "") - grade = metadata.get("grade_level", "") - - html_parts = [] - html_parts.append(""" - - - -""" + title + """ - Multiple Choice - - - -""") - - # Header - version_text = "Lösungsblatt" if include_answers else "Multiple Choice Test" - html_parts.append(f"

{title}

") - html_parts.append(f"
{version_text}") - if subject: - html_parts.append(f" | Fach: {subject}") - if grade: - html_parts.append(f" | Klasse: {grade}") - html_parts.append(f" | Anzahl Fragen: {len(questions)}
") - - if not include_answers: - html_parts.append("
") - html_parts.append("Anleitung: Kreuze bei jeder Frage die richtige Antwort an. ") - html_parts.append("Es ist immer nur eine Antwort richtig.") - html_parts.append("
") - - # Fragen - for idx, q in enumerate(questions, 1): - html_parts.append("
") - html_parts.append(f"
Frage {idx}
") - html_parts.append(f"
{q.get('question', '')}
") - - html_parts.append("
") - correct_answer = q.get("correct_answer", "") - - for opt in q.get("options", []): - opt_id = opt.get("id", "") - is_correct = opt_id == correct_answer - - opt_class = "option" - checkbox_class = "option-checkbox" - if include_answers and is_correct: - opt_class += " option-correct" - checkbox_class += " checked" - - html_parts.append(f"
") - html_parts.append(f"
") - html_parts.append(f"{opt_id})") - html_parts.append(f"{opt.get('text', '')}") - html_parts.append("
") - - html_parts.append("
") - - # Erklärung nur bei Lösungsblatt - if include_answers and q.get("explanation"): - html_parts.append(f"
Erklärung: {q.get('explanation')}
") - - html_parts.append("
") - - # Lösungsschlüssel (kompakt) - nur bei Lösungsblatt - if include_answers: - html_parts.append("
") - html_parts.append("
Lösungsschlüssel
") - html_parts.append("
") - for idx, q in enumerate(questions, 1): - html_parts.append("
") - html_parts.append(f"{idx}. ") - html_parts.append(f"{q.get('correct_answer', '')}") - html_parts.append("
") - html_parts.append("
") - html_parts.append("
") - - html_parts.append("") - - return "\n".join(html_parts) - - -def generate_print_version_worksheet(analysis_path: Path) -> str: - """ - Generiert eine druckoptimierte HTML-Version des Arbeitsblatts. - - Eigenschaften: - - Große, gut lesbare Schrift (16pt) - - Schwarz-weiß / Graustufen-tauglich - - Klare Struktur für Druck - - Keine interaktiven Elemente - - Args: - analysis_path: Pfad zur *_analyse.json Datei - - Returns: - HTML-String zum direkten Ausliefern - """ - if not analysis_path.exists(): - raise FileNotFoundError(f"Analysedatei nicht gefunden: {analysis_path}") - - try: - data = json.loads(analysis_path.read_text(encoding="utf-8")) - except json.JSONDecodeError as e: - raise RuntimeError(f"Analyse-Datei enthält kein gültiges JSON: {analysis_path}\n{e}") from e - - title = data.get("title") or "Arbeitsblatt" - subject = data.get("subject") or "" - grade_level = data.get("grade_level") or "" - instructions = data.get("instructions") or "" - tasks = data.get("tasks", []) or [] - canonical_text = data.get("canonical_text") or "" - printed_blocks = data.get("printed_blocks") or [] - - html_parts = [] - html_parts.append(""" - - - -""" + title + """ - - - - -""") - - # Titel - html_parts.append(f"

{title}

") - - # Meta-Informationen - meta_parts = [] - if subject: - meta_parts.append(f"Fach: {subject}") - if grade_level: - meta_parts.append(f"Klasse: {grade_level}") - if meta_parts: - html_parts.append(f"
{''.join(meta_parts)}
") - - # Arbeitsanweisung - if instructions: - html_parts.append("
") - html_parts.append("
Arbeitsanweisung:
") - html_parts.append(f"
{instructions}
") - html_parts.append("
") - - # Haupttext / gedruckte Blöcke - if printed_blocks: - html_parts.append("
") - for block in printed_blocks: - role = (block.get("role") or "body").lower() - text = (block.get("text") or "").strip() - if not text: - continue - if role == "title": - html_parts.append(f"
{text}
") - else: - html_parts.append(f"
{text}
") - html_parts.append("
") - elif canonical_text: - html_parts.append("
") - paragraphs = [ - p.strip() - for p in canonical_text.replace("\r\n", "\n").split("\n\n") - if p.strip() - ] - for p in paragraphs: - html_parts.append(f"
{p}
") - html_parts.append("
") - - # Aufgaben - if tasks: - html_parts.append("
") - html_parts.append("

Aufgaben

") - - for idx, task in enumerate(tasks, start=1): - t_type = task.get("type") or "Aufgabe" - desc = task.get("description") or "" - text_with_gaps = task.get("text_with_gaps") - - html_parts.append("
") - - # Task-Header - type_label = { - "fill_in_blank": "Lückentext", - "multiple_choice": "Multiple Choice", - "free_text": "Freitext", - "matching": "Zuordnung", - "labeling": "Beschriftung", - "calculation": "Rechnung", - "other": "Aufgabe" - }.get(t_type, t_type) - - html_parts.append(f"
Aufgabe {idx}: {type_label}
") - - if desc: - html_parts.append(f"
{desc}
") - - if text_with_gaps: - rendered = text_with_gaps.replace("___", " ") - html_parts.append(f"
{rendered}
") - - # Antwortlinien für Freitext-Aufgaben - if t_type in ["free_text", "other"] or (not text_with_gaps and not desc): - html_parts.append("
") - for _ in range(3): - html_parts.append("
") - html_parts.append("
") - - html_parts.append("
") - - html_parts.append("
") - - # Fußzeile - html_parts.append("") - - html_parts.append("") - - return "\n".join(html_parts) +from .print_qa import generate_print_version_qa +from .print_cloze import generate_print_version_cloze +from .print_mc import generate_print_version_mc +from .print_worksheet import generate_print_version_worksheet + +__all__ = [ + "generate_print_version_qa", + "generate_print_version_cloze", + "generate_print_version_mc", + "generate_print_version_worksheet", +] diff --git a/backend-lehrer/ai_processing/print_mc.py b/backend-lehrer/ai_processing/print_mc.py new file mode 100644 index 0000000..57cf2b1 --- /dev/null +++ b/backend-lehrer/ai_processing/print_mc.py @@ -0,0 +1,240 @@ +""" +AI Processing - Print Version Generator: Multiple Choice. + +Generates printable HTML for multiple-choice worksheets. +""" + +from pathlib import Path +import json +import logging + +logger = logging.getLogger(__name__) + + +def generate_print_version_mc(mc_path: Path, include_answers: bool = False) -> str: + """ + Generiert eine druckbare HTML-Version der Multiple-Choice-Fragen. + + Args: + mc_path: Pfad zur *_mc.json Datei + include_answers: True fuer Loesungsblatt mit markierten richtigen Antworten + + Returns: + HTML-String (zum direkten Ausliefern) + """ + if not mc_path.exists(): + raise FileNotFoundError(f"MC-Datei nicht gefunden: {mc_path}") + + mc_data = json.loads(mc_path.read_text(encoding="utf-8")) + questions = mc_data.get("questions", []) + metadata = mc_data.get("metadata", {}) + + title = metadata.get("source_title", "Arbeitsblatt") + subject = metadata.get("subject", "") + grade = metadata.get("grade_level", "") + + html_parts = [] + html_parts.append(""" + + + +""" + title + """ - Multiple Choice + + + +""") + + # Header + version_text = "Loesungsblatt" if include_answers else "Multiple Choice Test" + html_parts.append(f"

{title}

") + html_parts.append(f"
{version_text}") + if subject: + html_parts.append(f" | Fach: {subject}") + if grade: + html_parts.append(f" | Klasse: {grade}") + html_parts.append(f" | Anzahl Fragen: {len(questions)}
") + + if not include_answers: + html_parts.append("
") + html_parts.append("Anleitung: Kreuze bei jeder Frage die richtige Antwort an. ") + html_parts.append("Es ist immer nur eine Antwort richtig.") + html_parts.append("
") + + # Fragen + for idx, q in enumerate(questions, 1): + html_parts.append("
") + html_parts.append(f"
Frage {idx}
") + html_parts.append(f"
{q.get('question', '')}
") + + html_parts.append("
") + correct_answer = q.get("correct_answer", "") + + for opt in q.get("options", []): + opt_id = opt.get("id", "") + is_correct = opt_id == correct_answer + + opt_class = "option" + checkbox_class = "option-checkbox" + if include_answers and is_correct: + opt_class += " option-correct" + checkbox_class += " checked" + + html_parts.append(f"
") + html_parts.append(f"
") + html_parts.append(f"{opt_id})") + html_parts.append(f"{opt.get('text', '')}") + html_parts.append("
") + + html_parts.append("
") + + # Erklaerung nur bei Loesungsblatt + if include_answers and q.get("explanation"): + html_parts.append(f"
Erklaerung: {q.get('explanation')}
") + + html_parts.append("
") + + # Loesungsschluessel (kompakt) - nur bei Loesungsblatt + if include_answers: + html_parts.append("
") + html_parts.append("
Loesungsschluessel
") + html_parts.append("
") + for idx, q in enumerate(questions, 1): + html_parts.append("
") + html_parts.append(f"{idx}. ") + html_parts.append(f"{q.get('correct_answer', '')}") + html_parts.append("
") + html_parts.append("
") + html_parts.append("
") + + html_parts.append("") + + return "\n".join(html_parts) diff --git a/backend-lehrer/ai_processing/print_qa.py b/backend-lehrer/ai_processing/print_qa.py new file mode 100644 index 0000000..0e75c9f --- /dev/null +++ b/backend-lehrer/ai_processing/print_qa.py @@ -0,0 +1,149 @@ +""" +AI Processing - Print Version Generator: Q&A. + +Generates printable HTML for question-answer worksheets. +""" + +from pathlib import Path +import json +import logging + +from .core import BEREINIGT_DIR + +logger = logging.getLogger(__name__) + + +def generate_print_version_qa(qa_path: Path, include_answers: bool = False) -> Path: + """ + Generiert eine druckbare HTML-Version der Frage-Antwort-Paare. + + Args: + qa_path: Pfad zur *_qa.json Datei + include_answers: True fuer Loesungsblatt (fuer Eltern) + + Returns: + Pfad zur generierten HTML-Datei + """ + if not qa_path.exists(): + raise FileNotFoundError(f"Q&A-Datei nicht gefunden: {qa_path}") + + qa_data = json.loads(qa_path.read_text(encoding="utf-8")) + items = qa_data.get("qa_items", []) + metadata = qa_data.get("metadata", {}) + + title = metadata.get("source_title", "Arbeitsblatt") + subject = metadata.get("subject", "") + grade = metadata.get("grade_level", "") + + html_parts = [] + html_parts.append(""" + + + +""" + title + """ - Fragen + + + +""") + + # Header + version_text = "Loesungsblatt" if include_answers else "Fragenblatt" + html_parts.append(f"

{title} - {version_text}

") + meta_parts = [] + if subject: + meta_parts.append(f"Fach: {subject}") + if grade: + meta_parts.append(f"Klasse: {grade}") + meta_parts.append(f"Anzahl Fragen: {len(items)}") + html_parts.append(f"
{' | '.join(meta_parts)}
") + + # Fragen + for idx, item in enumerate(items, 1): + html_parts.append("
") + html_parts.append(f"
Frage {idx}
") + html_parts.append(f"
{item.get('question', '')}
") + + if include_answers: + # Loesungsblatt: Antwort anzeigen + html_parts.append(f"
Antwort: {item.get('answer', '')}
") + # Schluesselbegriffe + key_terms = item.get("key_terms", []) + if key_terms: + terms_html = " ".join([f"{term}" for term in key_terms]) + html_parts.append(f"
Wichtige Begriffe: {terms_html}
") + else: + # Fragenblatt: Antwortlinien + html_parts.append("
") + for _ in range(3): + html_parts.append("
") + html_parts.append("
") + + html_parts.append("
") + + html_parts.append("") + + # Speichern + suffix = "_qa_solutions.html" if include_answers else "_qa_print.html" + out_name = qa_path.stem.replace("_qa", "") + suffix + out_path = BEREINIGT_DIR / out_name + out_path.write_text("\n".join(html_parts), encoding="utf-8") + + logger.info(f"Print-Version gespeichert: {out_path.name}") + return out_path diff --git a/backend-lehrer/ai_processing/print_worksheet.py b/backend-lehrer/ai_processing/print_worksheet.py new file mode 100644 index 0000000..131e84d --- /dev/null +++ b/backend-lehrer/ai_processing/print_worksheet.py @@ -0,0 +1,294 @@ +""" +AI Processing - Print Version Generator: Worksheet. + +Generates print-optimized HTML for general worksheets from analysis data. +""" + +from pathlib import Path +import json +import logging + +logger = logging.getLogger(__name__) + + +def generate_print_version_worksheet(analysis_path: Path) -> str: + """ + Generiert eine druckoptimierte HTML-Version des Arbeitsblatts. + + Eigenschaften: + - Grosse, gut lesbare Schrift (16pt) + - Schwarz-weiss / Graustufen-tauglich + - Klare Struktur fuer Druck + - Keine interaktiven Elemente + + Args: + analysis_path: Pfad zur *_analyse.json Datei + + Returns: + HTML-String zum direkten Ausliefern + """ + if not analysis_path.exists(): + raise FileNotFoundError(f"Analysedatei nicht gefunden: {analysis_path}") + + try: + data = json.loads(analysis_path.read_text(encoding="utf-8")) + except json.JSONDecodeError as e: + raise RuntimeError(f"Analyse-Datei enthaelt kein gueltiges JSON: {analysis_path}\n{e}") from e + + title = data.get("title") or "Arbeitsblatt" + subject = data.get("subject") or "" + grade_level = data.get("grade_level") or "" + instructions = data.get("instructions") or "" + tasks = data.get("tasks", []) or [] + canonical_text = data.get("canonical_text") or "" + printed_blocks = data.get("printed_blocks") or [] + + html_parts = [] + html_parts.append(_build_html_head(title)) + + # Titel + html_parts.append(f"

{title}

") + + # Meta-Informationen + meta_parts = [] + if subject: + meta_parts.append(f"Fach: {subject}") + if grade_level: + meta_parts.append(f"Klasse: {grade_level}") + if meta_parts: + html_parts.append(f"
{''.join(meta_parts)}
") + + # Arbeitsanweisung + if instructions: + html_parts.append("
") + html_parts.append("
Arbeitsanweisung:
") + html_parts.append(f"
{instructions}
") + html_parts.append("
") + + # Haupttext / gedruckte Bloecke + _build_text_section(html_parts, printed_blocks, canonical_text) + + # Aufgaben + _build_tasks_section(html_parts, tasks) + + # Fusszeile + html_parts.append("") + + html_parts.append("") + + return "\n".join(html_parts) + + +def _build_html_head(title: str) -> str: + """Build the HTML head with print-optimized styles.""" + return """ + + + +""" + title + """ + + + + +""" + + +def _build_text_section(html_parts: list, printed_blocks: list, canonical_text: str): + """Build the text section from printed blocks or canonical text.""" + if printed_blocks: + html_parts.append("
") + for block in printed_blocks: + role = (block.get("role") or "body").lower() + text = (block.get("text") or "").strip() + if not text: + continue + if role == "title": + html_parts.append(f"
{text}
") + else: + html_parts.append(f"
{text}
") + html_parts.append("
") + elif canonical_text: + html_parts.append("
") + paragraphs = [ + p.strip() + for p in canonical_text.replace("\r\n", "\n").split("\n\n") + if p.strip() + ] + for p in paragraphs: + html_parts.append(f"
{p}
") + html_parts.append("
") + + +def _build_tasks_section(html_parts: list, tasks: list): + """Build the tasks section.""" + if not tasks: + return + + html_parts.append("
") + html_parts.append("

Aufgaben

") + + type_labels = { + "fill_in_blank": "Lueckentext", + "multiple_choice": "Multiple Choice", + "free_text": "Freitext", + "matching": "Zuordnung", + "labeling": "Beschriftung", + "calculation": "Rechnung", + "other": "Aufgabe" + } + + for idx, task in enumerate(tasks, start=1): + t_type = task.get("type") or "Aufgabe" + desc = task.get("description") or "" + text_with_gaps = task.get("text_with_gaps") + + html_parts.append("
") + + type_label = type_labels.get(t_type, t_type) + html_parts.append(f"
Aufgabe {idx}: {type_label}
") + + if desc: + html_parts.append(f"
{desc}
") + + if text_with_gaps: + rendered = text_with_gaps.replace("___", " ") + html_parts.append(f"
{rendered}
") + + # Antwortlinien fuer Freitext-Aufgaben + if t_type in ["free_text", "other"] or (not text_with_gaps and not desc): + html_parts.append("
") + for _ in range(3): + html_parts.append("
") + html_parts.append("
") + + html_parts.append("
") + + html_parts.append("
") diff --git a/backend-lehrer/classroom/routes/context.py b/backend-lehrer/classroom/routes/context.py index 39a6428..2f130d0 100644 --- a/backend-lehrer/classroom/routes/context.py +++ b/backend-lehrer/classroom/routes/context.py @@ -1,726 +1,25 @@ """ -Classroom API - Context Routes +Classroom API - Context Routes — Barrel Re-export. + +Split into submodules: +- context_core.py — Teacher context, onboarding endpoints +- context_events.py — Events & routines CRUD +- context_static.py — Static data, suggestions, sidebar, school year path School year context, events, routines, and suggestions endpoints (Phase 8). """ -from typing import Dict, Any, Optional -from datetime import datetime -import logging +from fastapi import APIRouter -from fastapi import APIRouter, HTTPException, Query, Depends - -from classroom_engine import ( - FEDERAL_STATES, - SCHOOL_TYPES, - MacroPhaseEnum, -) - -from ..models import ( - TeacherContextResponse, - SchoolInfo, - SchoolYearInfo, - MacroPhaseInfo, - CoreCounts, - ContextFlags, - UpdateContextRequest, - CreateEventRequest, - EventResponse, - CreateRoutineRequest, - RoutineResponse, -) -from ..services.persistence import ( - init_db_if_needed, - DB_ENABLED, - SessionLocal, -) - -logger = logging.getLogger(__name__) +from .context_core import router as _core_router +from .context_events import router as _events_router +from .context_static import router as _static_router +# Combine all sub-routers into a single router for backwards compatibility. +# The consumer imports `from .routes.context import router as context_router`. router = APIRouter(tags=["Context"]) +router.include_router(_core_router) +router.include_router(_events_router) +router.include_router(_static_router) - -def get_db(): - """Database session dependency.""" - if DB_ENABLED and SessionLocal: - db = SessionLocal() - try: - yield db - finally: - db.close() - else: - yield None - - -def _get_macro_phase_label(phase) -> str: - """Gibt den Anzeigenamen einer Makro-Phase zurueck.""" - labels = { - "onboarding": "Einrichtung", - "schuljahresstart": "Schuljahresstart", - "unterrichtsaufbau": "Unterrichtsaufbau", - "leistungsphase_1": "Leistungsphase 1", - "halbjahresabschluss": "Halbjahresabschluss", - "leistungsphase_2": "Leistungsphase 2", - "jahresabschluss": "Jahresabschluss", - } - phase_value = phase.value if hasattr(phase, 'value') else str(phase) - return labels.get(phase_value, phase_value) - - -# === Context Endpoints === - -@router.get("/v1/context", response_model=TeacherContextResponse) -async def get_teacher_context( - teacher_id: str = Query(..., description="Teacher ID"), - db=Depends(get_db) -): - """ - Liefert den aktuellen Makro-Kontext eines Lehrers. - - Der Kontext beinhaltet: - - Schul-Informationen (Bundesland, Schulart) - - Schuljahr-Daten (aktuelles Jahr, Woche) - - Makro-Phase (ONBOARDING bis JAHRESABSCHLUSS) - - Zaehler (Klassen, geplante Klausuren, etc.) - - Status-Flags (Onboarding abgeschlossen, etc.) - """ - if DB_ENABLED and db: - try: - from classroom_engine.repository import TeacherContextRepository, SchoolyearEventRepository - repo = TeacherContextRepository(db) - context = repo.get_or_create(teacher_id) - - # Zaehler berechnen - event_repo = SchoolyearEventRepository(db) - upcoming_exams = event_repo.get_upcoming(teacher_id, days=30) - exams_count = len([e for e in upcoming_exams if e.event_type.value == "exam"]) - - return TeacherContextResponse( - schema_version="1.0", - teacher_id=teacher_id, - school=SchoolInfo( - federal_state=context.federal_state or "BY", - federal_state_name=FEDERAL_STATES.get(context.federal_state, ""), - school_type=context.school_type or "gymnasium", - school_type_name=SCHOOL_TYPES.get(context.school_type, ""), - ), - school_year=SchoolYearInfo( - id=context.schoolyear or "2024-2025", - start=context.schoolyear_start.isoformat() if context.schoolyear_start else None, - current_week=context.current_week or 1, - ), - macro_phase=MacroPhaseInfo( - id=context.macro_phase.value, - label=_get_macro_phase_label(context.macro_phase), - confidence=1.0, - ), - core_counts=CoreCounts( - classes=1 if context.has_classes else 0, - exams_scheduled=exams_count, - corrections_pending=0, - ), - flags=ContextFlags( - onboarding_completed=context.onboarding_completed, - has_classes=context.has_classes, - has_schedule=context.has_schedule, - is_exam_period=context.is_exam_period, - is_before_holidays=context.is_before_holidays, - ), - ) - except Exception as e: - logger.error(f"Failed to get teacher context: {e}") - raise HTTPException(status_code=500, detail=f"Fehler beim Laden des Kontexts: {e}") - - # Fallback ohne DB - return TeacherContextResponse( - schema_version="1.0", - teacher_id=teacher_id, - school=SchoolInfo( - federal_state="BY", - federal_state_name="Bayern", - school_type="gymnasium", - school_type_name="Gymnasium", - ), - school_year=SchoolYearInfo( - id="2024-2025", - start=None, - current_week=1, - ), - macro_phase=MacroPhaseInfo( - id="onboarding", - label="Einrichtung", - confidence=1.0, - ), - core_counts=CoreCounts(), - flags=ContextFlags(), - ) - - -@router.put("/v1/context", response_model=TeacherContextResponse) -async def update_teacher_context( - teacher_id: str, - request: UpdateContextRequest, - db=Depends(get_db) -): - """ - Aktualisiert den Kontext eines Lehrers. - """ - if not DB_ENABLED or not db: - raise HTTPException(status_code=503, detail="Datenbank nicht verfuegbar") - - try: - from classroom_engine.repository import TeacherContextRepository - repo = TeacherContextRepository(db) - - # Validierung - if request.federal_state and request.federal_state not in FEDERAL_STATES: - raise HTTPException(status_code=400, detail=f"Ungueltiges Bundesland: {request.federal_state}") - if request.school_type and request.school_type not in SCHOOL_TYPES: - raise HTTPException(status_code=400, detail=f"Ungueltige Schulart: {request.school_type}") - - # Parse datetime if provided - schoolyear_start = None - if request.schoolyear_start: - schoolyear_start = datetime.fromisoformat(request.schoolyear_start.replace('Z', '+00:00')) - - repo.update_context( - teacher_id=teacher_id, - federal_state=request.federal_state, - school_type=request.school_type, - schoolyear=request.schoolyear, - schoolyear_start=schoolyear_start, - macro_phase=request.macro_phase, - current_week=request.current_week, - ) - - return await get_teacher_context(teacher_id, db) - except HTTPException: - raise - except Exception as e: - logger.error(f"Failed to update teacher context: {e}") - raise HTTPException(status_code=500, detail=f"Fehler beim Aktualisieren: {e}") - - -@router.post("/v1/context/complete-onboarding") -async def complete_onboarding( - teacher_id: str = Query(...), - db=Depends(get_db) -): - """Markiert das Onboarding als abgeschlossen.""" - if not DB_ENABLED or not db: - return {"success": True, "macro_phase": "schuljahresstart", "note": "DB not available"} - - try: - from classroom_engine.repository import TeacherContextRepository - repo = TeacherContextRepository(db) - context = repo.complete_onboarding(teacher_id) - return { - "success": True, - "macro_phase": context.macro_phase.value, - "teacher_id": teacher_id, - } - except Exception as e: - logger.error(f"Failed to complete onboarding: {e}") - raise HTTPException(status_code=500, detail=f"Fehler: {e}") - - -@router.post("/v1/context/reset-onboarding") -async def reset_onboarding( - teacher_id: str = Query(...), - db=Depends(get_db) -): - """Setzt das Onboarding zurueck (fuer Tests).""" - if not DB_ENABLED or not db: - return {"success": True, "macro_phase": "onboarding", "note": "DB not available"} - - try: - from classroom_engine.repository import TeacherContextRepository - repo = TeacherContextRepository(db) - context = repo.get_or_create(teacher_id) - context.onboarding_completed = False - context.macro_phase = MacroPhaseEnum.ONBOARDING - db.commit() - db.refresh(context) - return { - "success": True, - "macro_phase": "onboarding", - "teacher_id": teacher_id, - } - except Exception as e: - logger.error(f"Failed to reset onboarding: {e}") - raise HTTPException(status_code=500, detail=f"Fehler: {e}") - - -# === Events Endpoints === - -@router.get("/v1/events") -async def get_events( - teacher_id: str = Query(...), - status: Optional[str] = None, - event_type: Optional[str] = None, - limit: int = 50, - db=Depends(get_db) -): - """Holt Events eines Lehrers.""" - if not DB_ENABLED or not db: - return {"events": [], "count": 0} - - try: - from classroom_engine.repository import SchoolyearEventRepository - repo = SchoolyearEventRepository(db) - events = repo.get_by_teacher(teacher_id, status=status, event_type=event_type, limit=limit) - return { - "events": [repo.to_dict(e) for e in events], - "count": len(events), - } - except Exception as e: - logger.error(f"Failed to get events: {e}") - raise HTTPException(status_code=500, detail=f"Fehler: {e}") - - -@router.get("/v1/events/upcoming") -async def get_upcoming_events( - teacher_id: str = Query(...), - days: int = 30, - limit: int = 10, - db=Depends(get_db) -): - """Holt anstehende Events der naechsten X Tage.""" - if not DB_ENABLED or not db: - return {"events": [], "count": 0} - - try: - from classroom_engine.repository import SchoolyearEventRepository - repo = SchoolyearEventRepository(db) - events = repo.get_upcoming(teacher_id, days=days, limit=limit) - return { - "events": [repo.to_dict(e) for e in events], - "count": len(events), - } - except Exception as e: - logger.error(f"Failed to get upcoming events: {e}") - raise HTTPException(status_code=500, detail=f"Fehler: {e}") - - -@router.post("/v1/events", response_model=EventResponse) -async def create_event( - teacher_id: str, - request: CreateEventRequest, - db=Depends(get_db) -): - """Erstellt ein neues Schuljahr-Event.""" - if not DB_ENABLED or not db: - raise HTTPException(status_code=503, detail="Datenbank nicht verfuegbar") - - try: - from classroom_engine.repository import SchoolyearEventRepository - repo = SchoolyearEventRepository(db) - start_date = datetime.fromisoformat(request.start_date.replace('Z', '+00:00')) - end_date = None - if request.end_date: - end_date = datetime.fromisoformat(request.end_date.replace('Z', '+00:00')) - - event = repo.create( - teacher_id=teacher_id, - title=request.title, - event_type=request.event_type, - start_date=start_date, - end_date=end_date, - class_id=request.class_id, - subject=request.subject, - description=request.description, - needs_preparation=request.needs_preparation, - reminder_days_before=request.reminder_days_before, - ) - - return EventResponse( - id=event.id, - teacher_id=event.teacher_id, - event_type=event.event_type.value, - title=event.title, - description=event.description, - start_date=event.start_date.isoformat(), - end_date=event.end_date.isoformat() if event.end_date else None, - class_id=event.class_id, - subject=event.subject, - status=event.status.value, - needs_preparation=event.needs_preparation, - preparation_done=event.preparation_done, - reminder_days_before=event.reminder_days_before, - ) - except Exception as e: - logger.error(f"Failed to create event: {e}") - raise HTTPException(status_code=500, detail=f"Fehler: {e}") - - -@router.delete("/v1/events/{event_id}") -async def delete_event(event_id: str, db=Depends(get_db)): - """Loescht ein Event.""" - if not DB_ENABLED or not db: - raise HTTPException(status_code=503, detail="Datenbank nicht verfuegbar") - - try: - from classroom_engine.repository import SchoolyearEventRepository - repo = SchoolyearEventRepository(db) - if repo.delete(event_id): - return {"success": True, "deleted_id": event_id} - raise HTTPException(status_code=404, detail="Event nicht gefunden") - except HTTPException: - raise - except Exception as e: - logger.error(f"Failed to delete event: {e}") - raise HTTPException(status_code=500, detail=f"Fehler: {e}") - - -# === Routines Endpoints === - -@router.get("/v1/routines") -async def get_routines( - teacher_id: str = Query(...), - is_active: bool = True, - routine_type: Optional[str] = None, - db=Depends(get_db) -): - """Holt Routinen eines Lehrers.""" - if not DB_ENABLED or not db: - return {"routines": [], "count": 0} - - try: - from classroom_engine.repository import RecurringRoutineRepository - repo = RecurringRoutineRepository(db) - routines = repo.get_by_teacher(teacher_id, is_active=is_active, routine_type=routine_type) - return { - "routines": [repo.to_dict(r) for r in routines], - "count": len(routines), - } - except Exception as e: - logger.error(f"Failed to get routines: {e}") - raise HTTPException(status_code=500, detail=f"Fehler: {e}") - - -@router.get("/v1/routines/today") -async def get_today_routines(teacher_id: str = Query(...), db=Depends(get_db)): - """Holt Routinen die heute stattfinden.""" - if not DB_ENABLED or not db: - return {"routines": [], "count": 0} - - try: - from classroom_engine.repository import RecurringRoutineRepository - repo = RecurringRoutineRepository(db) - routines = repo.get_today(teacher_id) - return { - "routines": [repo.to_dict(r) for r in routines], - "count": len(routines), - } - except Exception as e: - logger.error(f"Failed to get today's routines: {e}") - raise HTTPException(status_code=500, detail=f"Fehler: {e}") - - -@router.post("/v1/routines", response_model=RoutineResponse) -async def create_routine( - teacher_id: str, - request: CreateRoutineRequest, - db=Depends(get_db) -): - """Erstellt eine neue wiederkehrende Routine.""" - if not DB_ENABLED or not db: - raise HTTPException(status_code=503, detail="Datenbank nicht verfuegbar") - - try: - from classroom_engine.repository import RecurringRoutineRepository - repo = RecurringRoutineRepository(db) - routine = repo.create( - teacher_id=teacher_id, - title=request.title, - routine_type=request.routine_type, - recurrence_pattern=request.recurrence_pattern, - day_of_week=request.day_of_week, - day_of_month=request.day_of_month, - time_of_day=request.time_of_day, - duration_minutes=request.duration_minutes, - description=request.description, - ) - - return RoutineResponse( - id=routine.id, - teacher_id=routine.teacher_id, - routine_type=routine.routine_type.value, - title=routine.title, - description=routine.description, - recurrence_pattern=routine.recurrence_pattern.value, - day_of_week=routine.day_of_week, - day_of_month=routine.day_of_month, - time_of_day=routine.time_of_day.isoformat() if routine.time_of_day else None, - duration_minutes=routine.duration_minutes, - is_active=routine.is_active, - ) - except Exception as e: - logger.error(f"Failed to create routine: {e}") - raise HTTPException(status_code=500, detail=f"Fehler: {e}") - - -@router.delete("/v1/routines/{routine_id}") -async def delete_routine(routine_id: str, db=Depends(get_db)): - """Loescht eine Routine.""" - if not DB_ENABLED or not db: - raise HTTPException(status_code=503, detail="Datenbank nicht verfuegbar") - - try: - from classroom_engine.repository import RecurringRoutineRepository - repo = RecurringRoutineRepository(db) - if repo.delete(routine_id): - return {"success": True, "deleted_id": routine_id} - raise HTTPException(status_code=404, detail="Routine nicht gefunden") - except HTTPException: - raise - except Exception as e: - logger.error(f"Failed to delete routine: {e}") - raise HTTPException(status_code=500, detail=f"Fehler: {e}") - - -# === Static Data Endpoints === - -@router.get("/v1/federal-states") -async def get_federal_states(): - """Gibt alle Bundeslaender zurueck.""" - return { - "federal_states": [{"id": k, "name": v} for k, v in FEDERAL_STATES.items()] - } - - -@router.get("/v1/school-types") -async def get_school_types(): - """Gibt alle Schularten zurueck.""" - return { - "school_types": [{"id": k, "name": v} for k, v in SCHOOL_TYPES.items()] - } - - -@router.get("/v1/macro-phases") -async def get_macro_phases(): - """Gibt alle Makro-Phasen mit Beschreibungen zurueck.""" - phases = [ - {"id": "onboarding", "label": "Einrichtung", "description": "Ersteinrichtung (Klassen, Stundenplan)", "order": 1}, - {"id": "schuljahresstart", "label": "Schuljahresstart", "description": "Erste 2-3 Wochen des Schuljahres", "order": 2}, - {"id": "unterrichtsaufbau", "label": "Unterrichtsaufbau", "description": "Routinen etablieren, erste Bewertungen", "order": 3}, - {"id": "leistungsphase_1", "label": "Leistungsphase 1", "description": "Erste Klassenarbeiten und Klausuren", "order": 4}, - {"id": "halbjahresabschluss", "label": "Halbjahresabschluss", "description": "Notenschluss, Zeugnisse, Konferenzen", "order": 5}, - {"id": "leistungsphase_2", "label": "Leistungsphase 2", "description": "Zweites Halbjahr, Pruefungsvorbereitung", "order": 6}, - {"id": "jahresabschluss", "label": "Jahresabschluss", "description": "Finale Noten, Versetzung, Schuljahresende", "order": 7}, - ] - return {"macro_phases": phases} - - -@router.get("/v1/event-types") -async def get_event_types(): - """Gibt alle Event-Typen zurueck.""" - types = [ - {"id": "exam", "label": "Klassenarbeit/Klausur"}, - {"id": "parent_evening", "label": "Elternabend"}, - {"id": "trip", "label": "Klassenfahrt/Ausflug"}, - {"id": "project", "label": "Projektwoche"}, - {"id": "internship", "label": "Praktikum"}, - {"id": "presentation", "label": "Referate/Praesentationen"}, - {"id": "sports_day", "label": "Sporttag"}, - {"id": "school_festival", "label": "Schulfest"}, - {"id": "parent_consultation", "label": "Elternsprechtag"}, - {"id": "grade_deadline", "label": "Notenschluss"}, - {"id": "report_cards", "label": "Zeugnisausgabe"}, - {"id": "holiday_start", "label": "Ferienbeginn"}, - {"id": "holiday_end", "label": "Ferienende"}, - {"id": "other", "label": "Sonstiges"}, - ] - return {"event_types": types} - - -@router.get("/v1/routine-types") -async def get_routine_types(): - """Gibt alle Routine-Typen zurueck.""" - types = [ - {"id": "teacher_conference", "label": "Lehrerkonferenz"}, - {"id": "subject_conference", "label": "Fachkonferenz"}, - {"id": "office_hours", "label": "Sprechstunde"}, - {"id": "team_meeting", "label": "Teamsitzung"}, - {"id": "supervision", "label": "Pausenaufsicht"}, - {"id": "correction_time", "label": "Korrekturzeit"}, - {"id": "prep_time", "label": "Vorbereitungszeit"}, - {"id": "other", "label": "Sonstiges"}, - ] - return {"routine_types": types} - - -# === Suggestions & Sidebar === - -@router.get("/v1/suggestions") -async def get_suggestions( - teacher_id: str = Query(...), - limit: int = Query(5, ge=1, le=20), - db=Depends(get_db) -): - """Generiert kontextbasierte Vorschlaege fuer einen Lehrer.""" - if DB_ENABLED and db: - try: - from classroom_engine.suggestions import SuggestionGenerator - generator = SuggestionGenerator(db) - result = generator.generate(teacher_id, limit=limit) - return result - except Exception as e: - logger.error(f"Failed to generate suggestions: {e}") - raise HTTPException(status_code=500, detail=f"Fehler: {e}") - - return { - "active_contexts": [], - "suggestions": [], - "signals_summary": { - "macro_phase": "onboarding", - "current_week": 1, - "has_classes": False, - "exams_soon": 0, - "routines_today": 0, - }, - "total_suggestions": 0, - } - - -@router.get("/v1/sidebar") -async def get_sidebar( - teacher_id: str = Query(...), - mode: str = Query("companion"), - db=Depends(get_db) -): - """Generiert das dynamische Sidebar-Model.""" - if mode == "companion": - now_relevant = [] - if DB_ENABLED and db: - try: - from classroom_engine.suggestions import SuggestionGenerator - generator = SuggestionGenerator(db) - result = generator.generate(teacher_id, limit=5) - now_relevant = [ - { - "id": s["id"], - "label": s["title"], - "state": "recommended" if s["priority"] > 70 else "default", - "badge": s.get("badge"), - "icon": s.get("icon", "lightbulb"), - "action_url": s.get("action_url"), - } - for s in result.get("suggestions", []) - ] - except Exception as e: - logger.warning(f"Failed to get suggestions for sidebar: {e}") - - return { - "mode": "companion", - "sections": [ - {"id": "SEARCH", "type": "search_bar", "placeholder": "Suchen..."}, - { - "id": "NOW_RELEVANT", - "type": "list", - "title": "Jetzt relevant", - "items": now_relevant if now_relevant else [ - {"id": "no_suggestions", "label": "Keine Vorschlaege", "state": "default", "icon": "check_circle"} - ], - }, - { - "id": "ALL_MODULES", - "type": "folder", - "label": "Alle Module", - "icon": "folder", - "collapsed": True, - "items": [ - {"id": "lesson", "label": "Stundenmodus", "icon": "timer"}, - {"id": "classes", "label": "Klassen", "icon": "groups"}, - {"id": "exams", "label": "Klausuren", "icon": "quiz"}, - {"id": "grades", "label": "Noten", "icon": "calculate"}, - {"id": "calendar", "label": "Kalender", "icon": "calendar_month"}, - {"id": "materials", "label": "Materialien", "icon": "folder_open"}, - ], - }, - { - "id": "QUICK_ACTIONS", - "type": "actions", - "title": "Kurzaktionen", - "items": [ - {"id": "scan", "label": "Scan hochladen", "icon": "upload_file"}, - {"id": "note", "label": "Notiz erstellen", "icon": "note_add"}, - ], - }, - ], - } - else: - return { - "mode": "classic", - "sections": [ - { - "id": "NAVIGATION", - "type": "tree", - "items": [ - {"id": "dashboard", "label": "Dashboard", "icon": "dashboard", "url": "/dashboard"}, - {"id": "lesson", "label": "Stundenmodus", "icon": "timer", "url": "/lesson"}, - {"id": "classes", "label": "Klassen", "icon": "groups", "url": "/classes"}, - {"id": "exams", "label": "Klausuren", "icon": "quiz", "url": "/exams"}, - {"id": "grades", "label": "Noten", "icon": "calculate", "url": "/grades"}, - {"id": "calendar", "label": "Kalender", "icon": "calendar_month", "url": "/calendar"}, - {"id": "materials", "label": "Materialien", "icon": "folder_open", "url": "/materials"}, - {"id": "settings", "label": "Einstellungen", "icon": "settings", "url": "/settings"}, - ], - }, - ], - } - - -@router.get("/v1/path") -async def get_schoolyear_path(teacher_id: str = Query(...), db=Depends(get_db)): - """Generiert den Schuljahres-Pfad mit Meilensteinen.""" - current_phase = "onboarding" - if DB_ENABLED and db: - try: - from classroom_engine.repository import TeacherContextRepository - repo = TeacherContextRepository(db) - context = repo.get_or_create(teacher_id) - current_phase = context.macro_phase.value - except Exception as e: - logger.warning(f"Failed to get context for path: {e}") - - phase_order = [ - "onboarding", "schuljahresstart", "unterrichtsaufbau", - "leistungsphase_1", "halbjahresabschluss", "leistungsphase_2", "jahresabschluss", - ] - - current_index = phase_order.index(current_phase) if current_phase in phase_order else 0 - - milestones = [ - {"id": "MS_START", "label": "Start", "phase": "onboarding", "icon": "flag"}, - {"id": "MS_SETUP", "label": "Einrichtung", "phase": "schuljahresstart", "icon": "tune"}, - {"id": "MS_ROUTINE", "label": "Routinen", "phase": "unterrichtsaufbau", "icon": "repeat"}, - {"id": "MS_EXAM_1", "label": "Klausuren", "phase": "leistungsphase_1", "icon": "quiz"}, - {"id": "MS_HALFYEAR", "label": "Halbjahr", "phase": "halbjahresabschluss", "icon": "event"}, - {"id": "MS_EXAM_2", "label": "Pruefungen", "phase": "leistungsphase_2", "icon": "school"}, - {"id": "MS_END", "label": "Abschluss", "phase": "jahresabschluss", "icon": "celebration"}, - ] - - for milestone in milestones: - phase = milestone["phase"] - phase_index = phase_order.index(phase) if phase in phase_order else 999 - if phase_index < current_index: - milestone["status"] = "done" - elif phase_index == current_index: - milestone["status"] = "current" - else: - milestone["status"] = "upcoming" - - current_milestone_id = next( - (m["id"] for m in milestones if m["status"] == "current"), - milestones[0]["id"] - ) - - progress = int((current_index / (len(phase_order) - 1)) * 100) if len(phase_order) > 1 else 0 - - return { - "milestones": milestones, - "current_milestone_id": current_milestone_id, - "progress_percent": progress, - "current_phase": current_phase, - } +__all__ = ["router"] diff --git a/backend-lehrer/classroom/routes/context_core.py b/backend-lehrer/classroom/routes/context_core.py new file mode 100644 index 0000000..b4d7ba4 --- /dev/null +++ b/backend-lehrer/classroom/routes/context_core.py @@ -0,0 +1,247 @@ +""" +Classroom API - Context Core Routes + +Teacher context, onboarding endpoints (Phase 8). +""" + +from typing import Dict, Any +from datetime import datetime +import logging + +from fastapi import APIRouter, HTTPException, Query, Depends + +from classroom_engine import ( + FEDERAL_STATES, + SCHOOL_TYPES, + MacroPhaseEnum, +) + +from ..models import ( + TeacherContextResponse, + SchoolInfo, + SchoolYearInfo, + MacroPhaseInfo, + CoreCounts, + ContextFlags, + UpdateContextRequest, +) +from ..services.persistence import ( + init_db_if_needed, + DB_ENABLED, + SessionLocal, +) + +logger = logging.getLogger(__name__) + +router = APIRouter(tags=["Context"]) + + +def get_db(): + """Database session dependency.""" + if DB_ENABLED and SessionLocal: + db = SessionLocal() + try: + yield db + finally: + db.close() + else: + yield None + + +def _get_macro_phase_label(phase) -> str: + """Gibt den Anzeigenamen einer Makro-Phase zurueck.""" + labels = { + "onboarding": "Einrichtung", + "schuljahresstart": "Schuljahresstart", + "unterrichtsaufbau": "Unterrichtsaufbau", + "leistungsphase_1": "Leistungsphase 1", + "halbjahresabschluss": "Halbjahresabschluss", + "leistungsphase_2": "Leistungsphase 2", + "jahresabschluss": "Jahresabschluss", + } + phase_value = phase.value if hasattr(phase, 'value') else str(phase) + return labels.get(phase_value, phase_value) + + +# === Context Endpoints === + +@router.get("/v1/context", response_model=TeacherContextResponse) +async def get_teacher_context( + teacher_id: str = Query(..., description="Teacher ID"), + db=Depends(get_db) +): + """ + Liefert den aktuellen Makro-Kontext eines Lehrers. + + Der Kontext beinhaltet: + - Schul-Informationen (Bundesland, Schulart) + - Schuljahr-Daten (aktuelles Jahr, Woche) + - Makro-Phase (ONBOARDING bis JAHRESABSCHLUSS) + - Zaehler (Klassen, geplante Klausuren, etc.) + - Status-Flags (Onboarding abgeschlossen, etc.) + """ + if DB_ENABLED and db: + try: + from classroom_engine.repository import TeacherContextRepository, SchoolyearEventRepository + repo = TeacherContextRepository(db) + context = repo.get_or_create(teacher_id) + + # Zaehler berechnen + event_repo = SchoolyearEventRepository(db) + upcoming_exams = event_repo.get_upcoming(teacher_id, days=30) + exams_count = len([e for e in upcoming_exams if e.event_type.value == "exam"]) + + return TeacherContextResponse( + schema_version="1.0", + teacher_id=teacher_id, + school=SchoolInfo( + federal_state=context.federal_state or "BY", + federal_state_name=FEDERAL_STATES.get(context.federal_state, ""), + school_type=context.school_type or "gymnasium", + school_type_name=SCHOOL_TYPES.get(context.school_type, ""), + ), + school_year=SchoolYearInfo( + id=context.schoolyear or "2024-2025", + start=context.schoolyear_start.isoformat() if context.schoolyear_start else None, + current_week=context.current_week or 1, + ), + macro_phase=MacroPhaseInfo( + id=context.macro_phase.value, + label=_get_macro_phase_label(context.macro_phase), + confidence=1.0, + ), + core_counts=CoreCounts( + classes=1 if context.has_classes else 0, + exams_scheduled=exams_count, + corrections_pending=0, + ), + flags=ContextFlags( + onboarding_completed=context.onboarding_completed, + has_classes=context.has_classes, + has_schedule=context.has_schedule, + is_exam_period=context.is_exam_period, + is_before_holidays=context.is_before_holidays, + ), + ) + except Exception as e: + logger.error(f"Failed to get teacher context: {e}") + raise HTTPException(status_code=500, detail=f"Fehler beim Laden des Kontexts: {e}") + + # Fallback ohne DB + return TeacherContextResponse( + schema_version="1.0", + teacher_id=teacher_id, + school=SchoolInfo( + federal_state="BY", + federal_state_name="Bayern", + school_type="gymnasium", + school_type_name="Gymnasium", + ), + school_year=SchoolYearInfo( + id="2024-2025", + start=None, + current_week=1, + ), + macro_phase=MacroPhaseInfo( + id="onboarding", + label="Einrichtung", + confidence=1.0, + ), + core_counts=CoreCounts(), + flags=ContextFlags(), + ) + + +@router.put("/v1/context", response_model=TeacherContextResponse) +async def update_teacher_context( + teacher_id: str, + request: UpdateContextRequest, + db=Depends(get_db) +): + """ + Aktualisiert den Kontext eines Lehrers. + """ + if not DB_ENABLED or not db: + raise HTTPException(status_code=503, detail="Datenbank nicht verfuegbar") + + try: + from classroom_engine.repository import TeacherContextRepository + repo = TeacherContextRepository(db) + + # Validierung + if request.federal_state and request.federal_state not in FEDERAL_STATES: + raise HTTPException(status_code=400, detail=f"Ungueltiges Bundesland: {request.federal_state}") + if request.school_type and request.school_type not in SCHOOL_TYPES: + raise HTTPException(status_code=400, detail=f"Ungueltige Schulart: {request.school_type}") + + # Parse datetime if provided + schoolyear_start = None + if request.schoolyear_start: + schoolyear_start = datetime.fromisoformat(request.schoolyear_start.replace('Z', '+00:00')) + + repo.update_context( + teacher_id=teacher_id, + federal_state=request.federal_state, + school_type=request.school_type, + schoolyear=request.schoolyear, + schoolyear_start=schoolyear_start, + macro_phase=request.macro_phase, + current_week=request.current_week, + ) + + return await get_teacher_context(teacher_id, db) + except HTTPException: + raise + except Exception as e: + logger.error(f"Failed to update teacher context: {e}") + raise HTTPException(status_code=500, detail=f"Fehler beim Aktualisieren: {e}") + + +@router.post("/v1/context/complete-onboarding") +async def complete_onboarding( + teacher_id: str = Query(...), + db=Depends(get_db) +): + """Markiert das Onboarding als abgeschlossen.""" + if not DB_ENABLED or not db: + return {"success": True, "macro_phase": "schuljahresstart", "note": "DB not available"} + + try: + from classroom_engine.repository import TeacherContextRepository + repo = TeacherContextRepository(db) + context = repo.complete_onboarding(teacher_id) + return { + "success": True, + "macro_phase": context.macro_phase.value, + "teacher_id": teacher_id, + } + except Exception as e: + logger.error(f"Failed to complete onboarding: {e}") + raise HTTPException(status_code=500, detail=f"Fehler: {e}") + + +@router.post("/v1/context/reset-onboarding") +async def reset_onboarding( + teacher_id: str = Query(...), + db=Depends(get_db) +): + """Setzt das Onboarding zurueck (fuer Tests).""" + if not DB_ENABLED or not db: + return {"success": True, "macro_phase": "onboarding", "note": "DB not available"} + + try: + from classroom_engine.repository import TeacherContextRepository + repo = TeacherContextRepository(db) + context = repo.get_or_create(teacher_id) + context.onboarding_completed = False + context.macro_phase = MacroPhaseEnum.ONBOARDING + db.commit() + db.refresh(context) + return { + "success": True, + "macro_phase": "onboarding", + "teacher_id": teacher_id, + } + except Exception as e: + logger.error(f"Failed to reset onboarding: {e}") + raise HTTPException(status_code=500, detail=f"Fehler: {e}") diff --git a/backend-lehrer/classroom/routes/context_events.py b/backend-lehrer/classroom/routes/context_events.py new file mode 100644 index 0000000..54a19b2 --- /dev/null +++ b/backend-lehrer/classroom/routes/context_events.py @@ -0,0 +1,266 @@ +""" +Classroom API - Events & Routines Routes + +School year events, recurring routines endpoints (Phase 8). +""" + +from typing import Optional +from datetime import datetime +import logging + +from fastapi import APIRouter, HTTPException, Query, Depends + +from ..models import ( + CreateEventRequest, + EventResponse, + CreateRoutineRequest, + RoutineResponse, +) +from ..services.persistence import ( + DB_ENABLED, + SessionLocal, +) + +logger = logging.getLogger(__name__) + +router = APIRouter(tags=["Context"]) + + +def get_db(): + """Database session dependency.""" + if DB_ENABLED and SessionLocal: + db = SessionLocal() + try: + yield db + finally: + db.close() + else: + yield None + + +# === Events Endpoints === + +@router.get("/v1/events") +async def get_events( + teacher_id: str = Query(...), + status: Optional[str] = None, + event_type: Optional[str] = None, + limit: int = 50, + db=Depends(get_db) +): + """Holt Events eines Lehrers.""" + if not DB_ENABLED or not db: + return {"events": [], "count": 0} + + try: + from classroom_engine.repository import SchoolyearEventRepository + repo = SchoolyearEventRepository(db) + events = repo.get_by_teacher(teacher_id, status=status, event_type=event_type, limit=limit) + return { + "events": [repo.to_dict(e) for e in events], + "count": len(events), + } + except Exception as e: + logger.error(f"Failed to get events: {e}") + raise HTTPException(status_code=500, detail=f"Fehler: {e}") + + +@router.get("/v1/events/upcoming") +async def get_upcoming_events( + teacher_id: str = Query(...), + days: int = 30, + limit: int = 10, + db=Depends(get_db) +): + """Holt anstehende Events der naechsten X Tage.""" + if not DB_ENABLED or not db: + return {"events": [], "count": 0} + + try: + from classroom_engine.repository import SchoolyearEventRepository + repo = SchoolyearEventRepository(db) + events = repo.get_upcoming(teacher_id, days=days, limit=limit) + return { + "events": [repo.to_dict(e) for e in events], + "count": len(events), + } + except Exception as e: + logger.error(f"Failed to get upcoming events: {e}") + raise HTTPException(status_code=500, detail=f"Fehler: {e}") + + +@router.post("/v1/events", response_model=EventResponse) +async def create_event( + teacher_id: str, + request: CreateEventRequest, + db=Depends(get_db) +): + """Erstellt ein neues Schuljahr-Event.""" + if not DB_ENABLED or not db: + raise HTTPException(status_code=503, detail="Datenbank nicht verfuegbar") + + try: + from classroom_engine.repository import SchoolyearEventRepository + repo = SchoolyearEventRepository(db) + start_date = datetime.fromisoformat(request.start_date.replace('Z', '+00:00')) + end_date = None + if request.end_date: + end_date = datetime.fromisoformat(request.end_date.replace('Z', '+00:00')) + + event = repo.create( + teacher_id=teacher_id, + title=request.title, + event_type=request.event_type, + start_date=start_date, + end_date=end_date, + class_id=request.class_id, + subject=request.subject, + description=request.description, + needs_preparation=request.needs_preparation, + reminder_days_before=request.reminder_days_before, + ) + + return EventResponse( + id=event.id, + teacher_id=event.teacher_id, + event_type=event.event_type.value, + title=event.title, + description=event.description, + start_date=event.start_date.isoformat(), + end_date=event.end_date.isoformat() if event.end_date else None, + class_id=event.class_id, + subject=event.subject, + status=event.status.value, + needs_preparation=event.needs_preparation, + preparation_done=event.preparation_done, + reminder_days_before=event.reminder_days_before, + ) + except Exception as e: + logger.error(f"Failed to create event: {e}") + raise HTTPException(status_code=500, detail=f"Fehler: {e}") + + +@router.delete("/v1/events/{event_id}") +async def delete_event(event_id: str, db=Depends(get_db)): + """Loescht ein Event.""" + if not DB_ENABLED or not db: + raise HTTPException(status_code=503, detail="Datenbank nicht verfuegbar") + + try: + from classroom_engine.repository import SchoolyearEventRepository + repo = SchoolyearEventRepository(db) + if repo.delete(event_id): + return {"success": True, "deleted_id": event_id} + raise HTTPException(status_code=404, detail="Event nicht gefunden") + except HTTPException: + raise + except Exception as e: + logger.error(f"Failed to delete event: {e}") + raise HTTPException(status_code=500, detail=f"Fehler: {e}") + + +# === Routines Endpoints === + +@router.get("/v1/routines") +async def get_routines( + teacher_id: str = Query(...), + is_active: bool = True, + routine_type: Optional[str] = None, + db=Depends(get_db) +): + """Holt Routinen eines Lehrers.""" + if not DB_ENABLED or not db: + return {"routines": [], "count": 0} + + try: + from classroom_engine.repository import RecurringRoutineRepository + repo = RecurringRoutineRepository(db) + routines = repo.get_by_teacher(teacher_id, is_active=is_active, routine_type=routine_type) + return { + "routines": [repo.to_dict(r) for r in routines], + "count": len(routines), + } + except Exception as e: + logger.error(f"Failed to get routines: {e}") + raise HTTPException(status_code=500, detail=f"Fehler: {e}") + + +@router.get("/v1/routines/today") +async def get_today_routines(teacher_id: str = Query(...), db=Depends(get_db)): + """Holt Routinen die heute stattfinden.""" + if not DB_ENABLED or not db: + return {"routines": [], "count": 0} + + try: + from classroom_engine.repository import RecurringRoutineRepository + repo = RecurringRoutineRepository(db) + routines = repo.get_today(teacher_id) + return { + "routines": [repo.to_dict(r) for r in routines], + "count": len(routines), + } + except Exception as e: + logger.error(f"Failed to get today's routines: {e}") + raise HTTPException(status_code=500, detail=f"Fehler: {e}") + + +@router.post("/v1/routines", response_model=RoutineResponse) +async def create_routine( + teacher_id: str, + request: CreateRoutineRequest, + db=Depends(get_db) +): + """Erstellt eine neue wiederkehrende Routine.""" + if not DB_ENABLED or not db: + raise HTTPException(status_code=503, detail="Datenbank nicht verfuegbar") + + try: + from classroom_engine.repository import RecurringRoutineRepository + repo = RecurringRoutineRepository(db) + routine = repo.create( + teacher_id=teacher_id, + title=request.title, + routine_type=request.routine_type, + recurrence_pattern=request.recurrence_pattern, + day_of_week=request.day_of_week, + day_of_month=request.day_of_month, + time_of_day=request.time_of_day, + duration_minutes=request.duration_minutes, + description=request.description, + ) + + return RoutineResponse( + id=routine.id, + teacher_id=routine.teacher_id, + routine_type=routine.routine_type.value, + title=routine.title, + description=routine.description, + recurrence_pattern=routine.recurrence_pattern.value, + day_of_week=routine.day_of_week, + day_of_month=routine.day_of_month, + time_of_day=routine.time_of_day.isoformat() if routine.time_of_day else None, + duration_minutes=routine.duration_minutes, + is_active=routine.is_active, + ) + except Exception as e: + logger.error(f"Failed to create routine: {e}") + raise HTTPException(status_code=500, detail=f"Fehler: {e}") + + +@router.delete("/v1/routines/{routine_id}") +async def delete_routine(routine_id: str, db=Depends(get_db)): + """Loescht eine Routine.""" + if not DB_ENABLED or not db: + raise HTTPException(status_code=503, detail="Datenbank nicht verfuegbar") + + try: + from classroom_engine.repository import RecurringRoutineRepository + repo = RecurringRoutineRepository(db) + if repo.delete(routine_id): + return {"success": True, "deleted_id": routine_id} + raise HTTPException(status_code=404, detail="Routine nicht gefunden") + except HTTPException: + raise + except Exception as e: + logger.error(f"Failed to delete routine: {e}") + raise HTTPException(status_code=500, detail=f"Fehler: {e}") diff --git a/backend-lehrer/classroom/routes/context_static.py b/backend-lehrer/classroom/routes/context_static.py new file mode 100644 index 0000000..e11ec71 --- /dev/null +++ b/backend-lehrer/classroom/routes/context_static.py @@ -0,0 +1,281 @@ +""" +Classroom API - Static Data, Suggestions & Sidebar Routes + +Federal states, school types, macro phases, event/routine types, +suggestions, sidebar model, and school year path. +""" + +from typing import Optional +import logging + +from fastapi import APIRouter, HTTPException, Query, Depends + +from classroom_engine import FEDERAL_STATES, SCHOOL_TYPES + +from ..services.persistence import ( + DB_ENABLED, + SessionLocal, +) + +logger = logging.getLogger(__name__) + +router = APIRouter(tags=["Context"]) + + +def get_db(): + """Database session dependency.""" + if DB_ENABLED and SessionLocal: + db = SessionLocal() + try: + yield db + finally: + db.close() + else: + yield None + + +# === Static Data Endpoints === + +@router.get("/v1/federal-states") +async def get_federal_states(): + """Gibt alle Bundeslaender zurueck.""" + return { + "federal_states": [{"id": k, "name": v} for k, v in FEDERAL_STATES.items()] + } + + +@router.get("/v1/school-types") +async def get_school_types(): + """Gibt alle Schularten zurueck.""" + return { + "school_types": [{"id": k, "name": v} for k, v in SCHOOL_TYPES.items()] + } + + +@router.get("/v1/macro-phases") +async def get_macro_phases(): + """Gibt alle Makro-Phasen mit Beschreibungen zurueck.""" + phases = [ + {"id": "onboarding", "label": "Einrichtung", "description": "Ersteinrichtung (Klassen, Stundenplan)", "order": 1}, + {"id": "schuljahresstart", "label": "Schuljahresstart", "description": "Erste 2-3 Wochen des Schuljahres", "order": 2}, + {"id": "unterrichtsaufbau", "label": "Unterrichtsaufbau", "description": "Routinen etablieren, erste Bewertungen", "order": 3}, + {"id": "leistungsphase_1", "label": "Leistungsphase 1", "description": "Erste Klassenarbeiten und Klausuren", "order": 4}, + {"id": "halbjahresabschluss", "label": "Halbjahresabschluss", "description": "Notenschluss, Zeugnisse, Konferenzen", "order": 5}, + {"id": "leistungsphase_2", "label": "Leistungsphase 2", "description": "Zweites Halbjahr, Pruefungsvorbereitung", "order": 6}, + {"id": "jahresabschluss", "label": "Jahresabschluss", "description": "Finale Noten, Versetzung, Schuljahresende", "order": 7}, + ] + return {"macro_phases": phases} + + +@router.get("/v1/event-types") +async def get_event_types(): + """Gibt alle Event-Typen zurueck.""" + types = [ + {"id": "exam", "label": "Klassenarbeit/Klausur"}, + {"id": "parent_evening", "label": "Elternabend"}, + {"id": "trip", "label": "Klassenfahrt/Ausflug"}, + {"id": "project", "label": "Projektwoche"}, + {"id": "internship", "label": "Praktikum"}, + {"id": "presentation", "label": "Referate/Praesentationen"}, + {"id": "sports_day", "label": "Sporttag"}, + {"id": "school_festival", "label": "Schulfest"}, + {"id": "parent_consultation", "label": "Elternsprechtag"}, + {"id": "grade_deadline", "label": "Notenschluss"}, + {"id": "report_cards", "label": "Zeugnisausgabe"}, + {"id": "holiday_start", "label": "Ferienbeginn"}, + {"id": "holiday_end", "label": "Ferienende"}, + {"id": "other", "label": "Sonstiges"}, + ] + return {"event_types": types} + + +@router.get("/v1/routine-types") +async def get_routine_types(): + """Gibt alle Routine-Typen zurueck.""" + types = [ + {"id": "teacher_conference", "label": "Lehrerkonferenz"}, + {"id": "subject_conference", "label": "Fachkonferenz"}, + {"id": "office_hours", "label": "Sprechstunde"}, + {"id": "team_meeting", "label": "Teamsitzung"}, + {"id": "supervision", "label": "Pausenaufsicht"}, + {"id": "correction_time", "label": "Korrekturzeit"}, + {"id": "prep_time", "label": "Vorbereitungszeit"}, + {"id": "other", "label": "Sonstiges"}, + ] + return {"routine_types": types} + + +# === Suggestions & Sidebar === + +@router.get("/v1/suggestions") +async def get_suggestions( + teacher_id: str = Query(...), + limit: int = Query(5, ge=1, le=20), + db=Depends(get_db) +): + """Generiert kontextbasierte Vorschlaege fuer einen Lehrer.""" + if DB_ENABLED and db: + try: + from classroom_engine.suggestions import SuggestionGenerator + generator = SuggestionGenerator(db) + result = generator.generate(teacher_id, limit=limit) + return result + except Exception as e: + logger.error(f"Failed to generate suggestions: {e}") + raise HTTPException(status_code=500, detail=f"Fehler: {e}") + + return { + "active_contexts": [], + "suggestions": [], + "signals_summary": { + "macro_phase": "onboarding", + "current_week": 1, + "has_classes": False, + "exams_soon": 0, + "routines_today": 0, + }, + "total_suggestions": 0, + } + + +@router.get("/v1/sidebar") +async def get_sidebar( + teacher_id: str = Query(...), + mode: str = Query("companion"), + db=Depends(get_db) +): + """Generiert das dynamische Sidebar-Model.""" + if mode == "companion": + now_relevant = [] + if DB_ENABLED and db: + try: + from classroom_engine.suggestions import SuggestionGenerator + generator = SuggestionGenerator(db) + result = generator.generate(teacher_id, limit=5) + now_relevant = [ + { + "id": s["id"], + "label": s["title"], + "state": "recommended" if s["priority"] > 70 else "default", + "badge": s.get("badge"), + "icon": s.get("icon", "lightbulb"), + "action_url": s.get("action_url"), + } + for s in result.get("suggestions", []) + ] + except Exception as e: + logger.warning(f"Failed to get suggestions for sidebar: {e}") + + return { + "mode": "companion", + "sections": [ + {"id": "SEARCH", "type": "search_bar", "placeholder": "Suchen..."}, + { + "id": "NOW_RELEVANT", + "type": "list", + "title": "Jetzt relevant", + "items": now_relevant if now_relevant else [ + {"id": "no_suggestions", "label": "Keine Vorschlaege", "state": "default", "icon": "check_circle"} + ], + }, + { + "id": "ALL_MODULES", + "type": "folder", + "label": "Alle Module", + "icon": "folder", + "collapsed": True, + "items": [ + {"id": "lesson", "label": "Stundenmodus", "icon": "timer"}, + {"id": "classes", "label": "Klassen", "icon": "groups"}, + {"id": "exams", "label": "Klausuren", "icon": "quiz"}, + {"id": "grades", "label": "Noten", "icon": "calculate"}, + {"id": "calendar", "label": "Kalender", "icon": "calendar_month"}, + {"id": "materials", "label": "Materialien", "icon": "folder_open"}, + ], + }, + { + "id": "QUICK_ACTIONS", + "type": "actions", + "title": "Kurzaktionen", + "items": [ + {"id": "scan", "label": "Scan hochladen", "icon": "upload_file"}, + {"id": "note", "label": "Notiz erstellen", "icon": "note_add"}, + ], + }, + ], + } + else: + return { + "mode": "classic", + "sections": [ + { + "id": "NAVIGATION", + "type": "tree", + "items": [ + {"id": "dashboard", "label": "Dashboard", "icon": "dashboard", "url": "/dashboard"}, + {"id": "lesson", "label": "Stundenmodus", "icon": "timer", "url": "/lesson"}, + {"id": "classes", "label": "Klassen", "icon": "groups", "url": "/classes"}, + {"id": "exams", "label": "Klausuren", "icon": "quiz", "url": "/exams"}, + {"id": "grades", "label": "Noten", "icon": "calculate", "url": "/grades"}, + {"id": "calendar", "label": "Kalender", "icon": "calendar_month", "url": "/calendar"}, + {"id": "materials", "label": "Materialien", "icon": "folder_open", "url": "/materials"}, + {"id": "settings", "label": "Einstellungen", "icon": "settings", "url": "/settings"}, + ], + }, + ], + } + + +@router.get("/v1/path") +async def get_schoolyear_path(teacher_id: str = Query(...), db=Depends(get_db)): + """Generiert den Schuljahres-Pfad mit Meilensteinen.""" + current_phase = "onboarding" + if DB_ENABLED and db: + try: + from classroom_engine.repository import TeacherContextRepository + repo = TeacherContextRepository(db) + context = repo.get_or_create(teacher_id) + current_phase = context.macro_phase.value + except Exception as e: + logger.warning(f"Failed to get context for path: {e}") + + phase_order = [ + "onboarding", "schuljahresstart", "unterrichtsaufbau", + "leistungsphase_1", "halbjahresabschluss", "leistungsphase_2", "jahresabschluss", + ] + + current_index = phase_order.index(current_phase) if current_phase in phase_order else 0 + + milestones = [ + {"id": "MS_START", "label": "Start", "phase": "onboarding", "icon": "flag"}, + {"id": "MS_SETUP", "label": "Einrichtung", "phase": "schuljahresstart", "icon": "tune"}, + {"id": "MS_ROUTINE", "label": "Routinen", "phase": "unterrichtsaufbau", "icon": "repeat"}, + {"id": "MS_EXAM_1", "label": "Klausuren", "phase": "leistungsphase_1", "icon": "quiz"}, + {"id": "MS_HALFYEAR", "label": "Halbjahr", "phase": "halbjahresabschluss", "icon": "event"}, + {"id": "MS_EXAM_2", "label": "Pruefungen", "phase": "leistungsphase_2", "icon": "school"}, + {"id": "MS_END", "label": "Abschluss", "phase": "jahresabschluss", "icon": "celebration"}, + ] + + for milestone in milestones: + phase = milestone["phase"] + phase_index = phase_order.index(phase) if phase in phase_order else 999 + if phase_index < current_index: + milestone["status"] = "done" + elif phase_index == current_index: + milestone["status"] = "current" + else: + milestone["status"] = "upcoming" + + current_milestone_id = next( + (m["id"] for m in milestones if m["status"] == "current"), + milestones[0]["id"] + ) + + progress = int((current_index / (len(phase_order) - 1)) * 100) if len(phase_order) > 1 else 0 + + return { + "milestones": milestones, + "current_milestone_id": current_milestone_id, + "progress_percent": progress, + "current_phase": current_phase, + } diff --git a/backend-lehrer/llm_gateway/routes/edu_search_crud.py b/backend-lehrer/llm_gateway/routes/edu_search_crud.py new file mode 100644 index 0000000..653ce1a --- /dev/null +++ b/backend-lehrer/llm_gateway/routes/edu_search_crud.py @@ -0,0 +1,386 @@ +""" +EduSearch Seeds CRUD Routes. + +List, get, create, update, delete, and bulk import for seed URLs. +""" + +import os +import logging +from typing import Optional, List +from datetime import datetime + +from fastapi import APIRouter, HTTPException, Query +import asyncpg + +from .edu_search_models import ( + CategoryResponse, + SeedCreate, + SeedUpdate, + SeedResponse, + SeedsListResponse, + BulkImportRequest, + BulkImportResponse, +) + +logger = logging.getLogger(__name__) + +router = APIRouter(tags=["edu-search"]) + +# Database connection pool +_pool: Optional[asyncpg.Pool] = None + + +async def get_db_pool() -> asyncpg.Pool: + """Get or create database connection pool.""" + global _pool + if _pool is None: + database_url = os.environ.get("DATABASE_URL") + if not database_url: + raise RuntimeError("DATABASE_URL nicht konfiguriert - bitte via Vault oder Umgebungsvariable setzen") + _pool = await asyncpg.create_pool(database_url, min_size=2, max_size=10) + return _pool + + +@router.get("/categories", response_model=List[CategoryResponse]) +async def list_categories(): + """List all seed categories.""" + pool = await get_db_pool() + async with pool.acquire() as conn: + rows = await conn.fetch(""" + SELECT id, name, display_name, description, icon, sort_order, is_active + FROM edu_search_categories + WHERE is_active = TRUE + ORDER BY sort_order + """) + return [ + CategoryResponse( + id=str(row["id"]), + name=row["name"], + display_name=row["display_name"], + description=row["description"], + icon=row["icon"], + sort_order=row["sort_order"], + is_active=row["is_active"], + ) + for row in rows + ] + + +@router.get("/seeds", response_model=SeedsListResponse) +async def list_seeds( + category: Optional[str] = Query(None, description="Filter by category name"), + state: Optional[str] = Query(None, description="Filter by state code"), + enabled: Optional[bool] = Query(None, description="Filter by enabled status"), + search: Optional[str] = Query(None, description="Search in name/url"), + page: int = Query(1, ge=1), + page_size: int = Query(50, ge=1, le=200), +): + """List seeds with optional filtering and pagination.""" + pool = await get_db_pool() + async with pool.acquire() as conn: + # Build WHERE clause + conditions = [] + params = [] + param_idx = 1 + + if category: + conditions.append(f"c.name = ${param_idx}") + params.append(category) + param_idx += 1 + + if state: + conditions.append(f"s.state = ${param_idx}") + params.append(state) + param_idx += 1 + + if enabled is not None: + conditions.append(f"s.enabled = ${param_idx}") + params.append(enabled) + param_idx += 1 + + if search: + conditions.append(f"(s.name ILIKE ${param_idx} OR s.url ILIKE ${param_idx})") + params.append(f"%{search}%") + param_idx += 1 + + where_clause = " AND ".join(conditions) if conditions else "TRUE" + + # Count total + count_query = f""" + SELECT COUNT(*) FROM edu_search_seeds s + LEFT JOIN edu_search_categories c ON s.category_id = c.id + WHERE {where_clause} + """ + total = await conn.fetchval(count_query, *params) + + # Get paginated results + offset = (page - 1) * page_size + params.extend([page_size, offset]) + + query = f""" + SELECT + s.id, s.url, s.name, s.description, + c.name as category, c.display_name as category_display_name, + s.source_type, s.scope, s.state, s.trust_boost, s.enabled, + s.crawl_depth, s.crawl_frequency, s.last_crawled_at, + s.last_crawl_status, s.last_crawl_docs, s.total_documents, + s.created_at, s.updated_at + FROM edu_search_seeds s + LEFT JOIN edu_search_categories c ON s.category_id = c.id + WHERE {where_clause} + ORDER BY c.sort_order, s.name + LIMIT ${param_idx} OFFSET ${param_idx + 1} + """ + + rows = await conn.fetch(query, *params) + + seeds = [_row_to_seed_response(row) for row in rows] + + return SeedsListResponse( + seeds=seeds, + total=total, + page=page, + page_size=page_size, + ) + + +@router.get("/seeds/{seed_id}", response_model=SeedResponse) +async def get_seed(seed_id: str): + """Get a single seed by ID.""" + pool = await get_db_pool() + async with pool.acquire() as conn: + row = await conn.fetchrow(""" + SELECT + s.id, s.url, s.name, s.description, + c.name as category, c.display_name as category_display_name, + s.source_type, s.scope, s.state, s.trust_boost, s.enabled, + s.crawl_depth, s.crawl_frequency, s.last_crawled_at, + s.last_crawl_status, s.last_crawl_docs, s.total_documents, + s.created_at, s.updated_at + FROM edu_search_seeds s + LEFT JOIN edu_search_categories c ON s.category_id = c.id + WHERE s.id = $1 + """, seed_id) + + if not row: + raise HTTPException(status_code=404, detail="Seed nicht gefunden") + + return _row_to_seed_response(row) + + +@router.post("/seeds", response_model=SeedResponse, status_code=201) +async def create_seed(seed: SeedCreate): + """Create a new seed URL.""" + pool = await get_db_pool() + async with pool.acquire() as conn: + category_id = None + if seed.category_name: + category_id = await conn.fetchval( + "SELECT id FROM edu_search_categories WHERE name = $1", + seed.category_name + ) + + try: + row = await conn.fetchrow(""" + INSERT INTO edu_search_seeds ( + url, name, description, category_id, source_type, scope, + state, trust_boost, enabled, crawl_depth, crawl_frequency + ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11) + RETURNING id, created_at, updated_at + """, + seed.url, seed.name, seed.description, category_id, + seed.source_type, seed.scope, seed.state, seed.trust_boost, + seed.enabled, seed.crawl_depth, seed.crawl_frequency + ) + except asyncpg.UniqueViolationError: + raise HTTPException(status_code=409, detail="URL existiert bereits") + + return SeedResponse( + id=str(row["id"]), + url=seed.url, + name=seed.name, + description=seed.description, + category=seed.category_name, + category_display_name=None, + source_type=seed.source_type, + scope=seed.scope, + state=seed.state, + trust_boost=seed.trust_boost, + enabled=seed.enabled, + crawl_depth=seed.crawl_depth, + crawl_frequency=seed.crawl_frequency, + last_crawled_at=None, + last_crawl_status=None, + last_crawl_docs=0, + total_documents=0, + created_at=row["created_at"], + updated_at=row["updated_at"], + ) + + +@router.put("/seeds/{seed_id}", response_model=SeedResponse) +async def update_seed(seed_id: str, seed: SeedUpdate): + """Update an existing seed.""" + pool = await get_db_pool() + async with pool.acquire() as conn: + updates = [] + params = [] + param_idx = 1 + + if seed.url is not None: + updates.append(f"url = ${param_idx}") + params.append(seed.url) + param_idx += 1 + + if seed.name is not None: + updates.append(f"name = ${param_idx}") + params.append(seed.name) + param_idx += 1 + + if seed.description is not None: + updates.append(f"description = ${param_idx}") + params.append(seed.description) + param_idx += 1 + + if seed.category_name is not None: + category_id = await conn.fetchval( + "SELECT id FROM edu_search_categories WHERE name = $1", + seed.category_name + ) + updates.append(f"category_id = ${param_idx}") + params.append(category_id) + param_idx += 1 + + if seed.source_type is not None: + updates.append(f"source_type = ${param_idx}") + params.append(seed.source_type) + param_idx += 1 + + if seed.scope is not None: + updates.append(f"scope = ${param_idx}") + params.append(seed.scope) + param_idx += 1 + + if seed.state is not None: + updates.append(f"state = ${param_idx}") + params.append(seed.state) + param_idx += 1 + + if seed.trust_boost is not None: + updates.append(f"trust_boost = ${param_idx}") + params.append(seed.trust_boost) + param_idx += 1 + + if seed.enabled is not None: + updates.append(f"enabled = ${param_idx}") + params.append(seed.enabled) + param_idx += 1 + + if seed.crawl_depth is not None: + updates.append(f"crawl_depth = ${param_idx}") + params.append(seed.crawl_depth) + param_idx += 1 + + if seed.crawl_frequency is not None: + updates.append(f"crawl_frequency = ${param_idx}") + params.append(seed.crawl_frequency) + param_idx += 1 + + if not updates: + raise HTTPException(status_code=400, detail="Keine Felder zum Aktualisieren") + + updates.append("updated_at = NOW()") + params.append(seed_id) + + query = f""" + UPDATE edu_search_seeds + SET {", ".join(updates)} + WHERE id = ${param_idx} + RETURNING id + """ + + result = await conn.fetchrow(query, *params) + if not result: + raise HTTPException(status_code=404, detail="Seed nicht gefunden") + + # Return updated seed + return await get_seed(seed_id) + + +@router.delete("/seeds/{seed_id}") +async def delete_seed(seed_id: str): + """Delete a seed.""" + pool = await get_db_pool() + async with pool.acquire() as conn: + result = await conn.execute( + "DELETE FROM edu_search_seeds WHERE id = $1", + seed_id + ) + if result == "DELETE 0": + raise HTTPException(status_code=404, detail="Seed nicht gefunden") + + return {"status": "deleted", "id": seed_id} + + +@router.post("/seeds/bulk-import", response_model=BulkImportResponse) +async def bulk_import_seeds(request: BulkImportRequest): + """Bulk import seeds (skip duplicates).""" + pool = await get_db_pool() + imported = 0 + skipped = 0 + errors = [] + + async with pool.acquire() as conn: + # Pre-fetch all category IDs + categories = {} + rows = await conn.fetch("SELECT id, name FROM edu_search_categories") + for row in rows: + categories[row["name"]] = row["id"] + + for seed in request.seeds: + try: + category_id = categories.get(seed.category_name) if seed.category_name else None + + await conn.execute(""" + INSERT INTO edu_search_seeds ( + url, name, description, category_id, source_type, scope, + state, trust_boost, enabled, crawl_depth, crawl_frequency + ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11) + ON CONFLICT (url) DO NOTHING + """, + seed.url, seed.name, seed.description, category_id, + seed.source_type, seed.scope, seed.state, seed.trust_boost, + seed.enabled, seed.crawl_depth, seed.crawl_frequency + ) + imported += 1 + except asyncpg.UniqueViolationError: + skipped += 1 + except Exception as e: + errors.append(f"{seed.url}: {str(e)}") + + return BulkImportResponse(imported=imported, skipped=skipped, errors=errors) + + +def _row_to_seed_response(row) -> SeedResponse: + """Convert a database row to SeedResponse.""" + return SeedResponse( + id=str(row["id"]), + url=row["url"], + name=row["name"], + description=row["description"], + category=row["category"], + category_display_name=row["category_display_name"], + source_type=row["source_type"], + scope=row["scope"], + state=row["state"], + trust_boost=float(row["trust_boost"]), + enabled=row["enabled"], + crawl_depth=row["crawl_depth"], + crawl_frequency=row["crawl_frequency"], + last_crawled_at=row["last_crawled_at"], + last_crawl_status=row["last_crawl_status"], + last_crawl_docs=row["last_crawl_docs"] or 0, + total_documents=row["total_documents"] or 0, + created_at=row["created_at"], + updated_at=row["updated_at"], + ) diff --git a/backend-lehrer/llm_gateway/routes/edu_search_models.py b/backend-lehrer/llm_gateway/routes/edu_search_models.py new file mode 100644 index 0000000..2a48a7e --- /dev/null +++ b/backend-lehrer/llm_gateway/routes/edu_search_models.py @@ -0,0 +1,137 @@ +""" +EduSearch Seeds Pydantic Models. + +Request/Response models for the education search seed URL API. +""" + +from typing import Optional, List +from datetime import datetime + +from pydantic import BaseModel, Field + + +class CategoryResponse(BaseModel): + """Category response model.""" + id: str + name: str + display_name: str + description: Optional[str] = None + icon: Optional[str] = None + sort_order: int + is_active: bool + + +class SeedBase(BaseModel): + """Base seed model for creation/update.""" + url: str = Field(..., max_length=500) + name: str = Field(..., max_length=255) + description: Optional[str] = None + category_name: Optional[str] = Field(None, description="Category name (federal, states, etc.)") + source_type: str = Field("GOV", description="GOV, EDU, UNI, etc.") + scope: str = Field("FEDERAL", description="FEDERAL, STATE, etc.") + state: Optional[str] = Field(None, max_length=5, description="State code (BW, BY, etc.)") + trust_boost: float = Field(0.50, ge=0.0, le=1.0) + enabled: bool = True + crawl_depth: int = Field(2, ge=1, le=5) + crawl_frequency: str = Field("weekly", description="hourly, daily, weekly, monthly") + + +class SeedCreate(SeedBase): + """Seed creation model.""" + pass + + +class SeedUpdate(BaseModel): + """Seed update model (all fields optional).""" + url: Optional[str] = Field(None, max_length=500) + name: Optional[str] = Field(None, max_length=255) + description: Optional[str] = None + category_name: Optional[str] = None + source_type: Optional[str] = None + scope: Optional[str] = None + state: Optional[str] = Field(None, max_length=5) + trust_boost: Optional[float] = Field(None, ge=0.0, le=1.0) + enabled: Optional[bool] = None + crawl_depth: Optional[int] = Field(None, ge=1, le=5) + crawl_frequency: Optional[str] = None + + +class SeedResponse(BaseModel): + """Seed response model.""" + id: str + url: str + name: str + description: Optional[str] = None + category: Optional[str] = None + category_display_name: Optional[str] = None + source_type: str + scope: str + state: Optional[str] = None + trust_boost: float + enabled: bool + crawl_depth: int + crawl_frequency: str + last_crawled_at: Optional[datetime] = None + last_crawl_status: Optional[str] = None + last_crawl_docs: int = 0 + total_documents: int = 0 + created_at: datetime + updated_at: datetime + + +class SeedsListResponse(BaseModel): + """List response with pagination info.""" + seeds: List[SeedResponse] + total: int + page: int + page_size: int + + +class StatsResponse(BaseModel): + """Crawl statistics response.""" + total_seeds: int + enabled_seeds: int + total_documents: int + seeds_by_category: dict + seeds_by_state: dict + last_crawl_time: Optional[datetime] = None + + +class BulkImportRequest(BaseModel): + """Bulk import request.""" + seeds: List[SeedCreate] + + +class BulkImportResponse(BaseModel): + """Bulk import response.""" + imported: int + skipped: int + errors: List[str] + + +class CrawlStatusUpdate(BaseModel): + """Crawl status update from edu-search-service.""" + seed_url: str = Field(..., description="The seed URL that was crawled") + status: str = Field(..., description="Crawl status: success, error, partial") + documents_crawled: int = Field(0, ge=0, description="Number of documents crawled") + error_message: Optional[str] = Field(None, description="Error message if status is error") + crawl_duration_seconds: float = Field(0.0, ge=0.0, description="Duration of the crawl in seconds") + + +class CrawlStatusResponse(BaseModel): + """Response for crawl status update.""" + success: bool + seed_url: str + message: str + + +class BulkCrawlStatusUpdate(BaseModel): + """Bulk crawl status update.""" + updates: List[CrawlStatusUpdate] + + +class BulkCrawlStatusResponse(BaseModel): + """Response for bulk crawl status update.""" + updated: int + failed: int + errors: List[str] diff --git a/backend-lehrer/llm_gateway/routes/edu_search_seeds.py b/backend-lehrer/llm_gateway/routes/edu_search_seeds.py index 8c3134a..2a42324 100644 --- a/backend-lehrer/llm_gateway/routes/edu_search_seeds.py +++ b/backend-lehrer/llm_gateway/routes/edu_search_seeds.py @@ -1,710 +1,58 @@ """ -EduSearch Seeds API Routes. +EduSearch Seeds API Routes — Barrel Re-export. + +Split into submodules: +- edu_search_models.py — Pydantic request/response models +- edu_search_crud.py — CRUD endpoints (list, get, create, update, delete, bulk import) +- edu_search_status.py — Stats, export for crawler, crawl status feedback CRUD operations for managing education search crawler seed URLs. Direct database access to PostgreSQL. """ -import os -import logging -from typing import Optional, List -from datetime import datetime -from uuid import UUID +from fastapi import APIRouter -from fastapi import APIRouter, HTTPException, Depends, Query -from pydantic import BaseModel, Field, HttpUrl -import asyncpg +from .edu_search_crud import router as _crud_router, get_db_pool +from .edu_search_status import router as _status_router -logger = logging.getLogger(__name__) +# Re-export models for consumers that import types from this module +from .edu_search_models import ( + CategoryResponse, + SeedBase, + SeedCreate, + SeedUpdate, + SeedResponse, + SeedsListResponse, + StatsResponse, + BulkImportRequest, + BulkImportResponse, + CrawlStatusUpdate, + CrawlStatusResponse, + BulkCrawlStatusUpdate, + BulkCrawlStatusResponse, +) +# Combine both sub-routers into a single router for backwards compatibility. +# The consumer imports `from .edu_search_seeds import router as edu_search_seeds_router`. router = APIRouter(prefix="/edu-search", tags=["edu-search"]) - -# Database connection pool -_pool: Optional[asyncpg.Pool] = None - - -async def get_db_pool() -> asyncpg.Pool: - """Get or create database connection pool.""" - global _pool - if _pool is None: - database_url = os.environ.get("DATABASE_URL") - if not database_url: - raise RuntimeError("DATABASE_URL nicht konfiguriert - bitte via Vault oder Umgebungsvariable setzen") - _pool = await asyncpg.create_pool(database_url, min_size=2, max_size=10) - return _pool - - -# ============================================================================= -# Pydantic Models -# ============================================================================= - - -class CategoryResponse(BaseModel): - """Category response model.""" - id: str - name: str - display_name: str - description: Optional[str] = None - icon: Optional[str] = None - sort_order: int - is_active: bool - - -class SeedBase(BaseModel): - """Base seed model for creation/update.""" - url: str = Field(..., max_length=500) - name: str = Field(..., max_length=255) - description: Optional[str] = None - category_name: Optional[str] = Field(None, description="Category name (federal, states, etc.)") - source_type: str = Field("GOV", description="GOV, EDU, UNI, etc.") - scope: str = Field("FEDERAL", description="FEDERAL, STATE, etc.") - state: Optional[str] = Field(None, max_length=5, description="State code (BW, BY, etc.)") - trust_boost: float = Field(0.50, ge=0.0, le=1.0) - enabled: bool = True - crawl_depth: int = Field(2, ge=1, le=5) - crawl_frequency: str = Field("weekly", description="hourly, daily, weekly, monthly") - - -class SeedCreate(SeedBase): - """Seed creation model.""" - pass - - -class SeedUpdate(BaseModel): - """Seed update model (all fields optional).""" - url: Optional[str] = Field(None, max_length=500) - name: Optional[str] = Field(None, max_length=255) - description: Optional[str] = None - category_name: Optional[str] = None - source_type: Optional[str] = None - scope: Optional[str] = None - state: Optional[str] = Field(None, max_length=5) - trust_boost: Optional[float] = Field(None, ge=0.0, le=1.0) - enabled: Optional[bool] = None - crawl_depth: Optional[int] = Field(None, ge=1, le=5) - crawl_frequency: Optional[str] = None - - -class SeedResponse(BaseModel): - """Seed response model.""" - id: str - url: str - name: str - description: Optional[str] = None - category: Optional[str] = None - category_display_name: Optional[str] = None - source_type: str - scope: str - state: Optional[str] = None - trust_boost: float - enabled: bool - crawl_depth: int - crawl_frequency: str - last_crawled_at: Optional[datetime] = None - last_crawl_status: Optional[str] = None - last_crawl_docs: int = 0 - total_documents: int = 0 - created_at: datetime - updated_at: datetime - - -class SeedsListResponse(BaseModel): - """List response with pagination info.""" - seeds: List[SeedResponse] - total: int - page: int - page_size: int - - -class StatsResponse(BaseModel): - """Crawl statistics response.""" - total_seeds: int - enabled_seeds: int - total_documents: int - seeds_by_category: dict - seeds_by_state: dict - last_crawl_time: Optional[datetime] = None - - -class BulkImportRequest(BaseModel): - """Bulk import request.""" - seeds: List[SeedCreate] - - -class BulkImportResponse(BaseModel): - """Bulk import response.""" - imported: int - skipped: int - errors: List[str] - - -# ============================================================================= -# API Endpoints -# ============================================================================= - - -@router.get("/categories", response_model=List[CategoryResponse]) -async def list_categories(): - """List all seed categories.""" - pool = await get_db_pool() - async with pool.acquire() as conn: - rows = await conn.fetch(""" - SELECT id, name, display_name, description, icon, sort_order, is_active - FROM edu_search_categories - WHERE is_active = TRUE - ORDER BY sort_order - """) - return [ - CategoryResponse( - id=str(row["id"]), - name=row["name"], - display_name=row["display_name"], - description=row["description"], - icon=row["icon"], - sort_order=row["sort_order"], - is_active=row["is_active"], - ) - for row in rows - ] - - -@router.get("/seeds", response_model=SeedsListResponse) -async def list_seeds( - category: Optional[str] = Query(None, description="Filter by category name"), - state: Optional[str] = Query(None, description="Filter by state code"), - enabled: Optional[bool] = Query(None, description="Filter by enabled status"), - search: Optional[str] = Query(None, description="Search in name/url"), - page: int = Query(1, ge=1), - page_size: int = Query(50, ge=1, le=200), -): - """List seeds with optional filtering and pagination.""" - pool = await get_db_pool() - async with pool.acquire() as conn: - # Build WHERE clause - conditions = [] - params = [] - param_idx = 1 - - if category: - conditions.append(f"c.name = ${param_idx}") - params.append(category) - param_idx += 1 - - if state: - conditions.append(f"s.state = ${param_idx}") - params.append(state) - param_idx += 1 - - if enabled is not None: - conditions.append(f"s.enabled = ${param_idx}") - params.append(enabled) - param_idx += 1 - - if search: - conditions.append(f"(s.name ILIKE ${param_idx} OR s.url ILIKE ${param_idx})") - params.append(f"%{search}%") - param_idx += 1 - - where_clause = " AND ".join(conditions) if conditions else "TRUE" - - # Count total - count_query = f""" - SELECT COUNT(*) FROM edu_search_seeds s - LEFT JOIN edu_search_categories c ON s.category_id = c.id - WHERE {where_clause} - """ - total = await conn.fetchval(count_query, *params) - - # Get paginated results - offset = (page - 1) * page_size - params.extend([page_size, offset]) - - query = f""" - SELECT - s.id, s.url, s.name, s.description, - c.name as category, c.display_name as category_display_name, - s.source_type, s.scope, s.state, s.trust_boost, s.enabled, - s.crawl_depth, s.crawl_frequency, s.last_crawled_at, - s.last_crawl_status, s.last_crawl_docs, s.total_documents, - s.created_at, s.updated_at - FROM edu_search_seeds s - LEFT JOIN edu_search_categories c ON s.category_id = c.id - WHERE {where_clause} - ORDER BY c.sort_order, s.name - LIMIT ${param_idx} OFFSET ${param_idx + 1} - """ - - rows = await conn.fetch(query, *params) - - seeds = [ - SeedResponse( - id=str(row["id"]), - url=row["url"], - name=row["name"], - description=row["description"], - category=row["category"], - category_display_name=row["category_display_name"], - source_type=row["source_type"], - scope=row["scope"], - state=row["state"], - trust_boost=float(row["trust_boost"]), - enabled=row["enabled"], - crawl_depth=row["crawl_depth"], - crawl_frequency=row["crawl_frequency"], - last_crawled_at=row["last_crawled_at"], - last_crawl_status=row["last_crawl_status"], - last_crawl_docs=row["last_crawl_docs"] or 0, - total_documents=row["total_documents"] or 0, - created_at=row["created_at"], - updated_at=row["updated_at"], - ) - for row in rows - ] - - return SeedsListResponse( - seeds=seeds, - total=total, - page=page, - page_size=page_size, - ) - - -@router.get("/seeds/{seed_id}", response_model=SeedResponse) -async def get_seed(seed_id: str): - """Get a single seed by ID.""" - pool = await get_db_pool() - async with pool.acquire() as conn: - row = await conn.fetchrow(""" - SELECT - s.id, s.url, s.name, s.description, - c.name as category, c.display_name as category_display_name, - s.source_type, s.scope, s.state, s.trust_boost, s.enabled, - s.crawl_depth, s.crawl_frequency, s.last_crawled_at, - s.last_crawl_status, s.last_crawl_docs, s.total_documents, - s.created_at, s.updated_at - FROM edu_search_seeds s - LEFT JOIN edu_search_categories c ON s.category_id = c.id - WHERE s.id = $1 - """, seed_id) - - if not row: - raise HTTPException(status_code=404, detail="Seed nicht gefunden") - - return SeedResponse( - id=str(row["id"]), - url=row["url"], - name=row["name"], - description=row["description"], - category=row["category"], - category_display_name=row["category_display_name"], - source_type=row["source_type"], - scope=row["scope"], - state=row["state"], - trust_boost=float(row["trust_boost"]), - enabled=row["enabled"], - crawl_depth=row["crawl_depth"], - crawl_frequency=row["crawl_frequency"], - last_crawled_at=row["last_crawled_at"], - last_crawl_status=row["last_crawl_status"], - last_crawl_docs=row["last_crawl_docs"] or 0, - total_documents=row["total_documents"] or 0, - created_at=row["created_at"], - updated_at=row["updated_at"], - ) - - -@router.post("/seeds", response_model=SeedResponse, status_code=201) -async def create_seed(seed: SeedCreate): - """Create a new seed URL.""" - pool = await get_db_pool() - async with pool.acquire() as conn: - # Get category ID if provided - category_id = None - if seed.category_name: - category_id = await conn.fetchval( - "SELECT id FROM edu_search_categories WHERE name = $1", - seed.category_name - ) - - try: - row = await conn.fetchrow(""" - INSERT INTO edu_search_seeds ( - url, name, description, category_id, source_type, scope, - state, trust_boost, enabled, crawl_depth, crawl_frequency - ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11) - RETURNING id, created_at, updated_at - """, - seed.url, seed.name, seed.description, category_id, - seed.source_type, seed.scope, seed.state, seed.trust_boost, - seed.enabled, seed.crawl_depth, seed.crawl_frequency - ) - except asyncpg.UniqueViolationError: - raise HTTPException(status_code=409, detail="URL existiert bereits") - - return SeedResponse( - id=str(row["id"]), - url=seed.url, - name=seed.name, - description=seed.description, - category=seed.category_name, - category_display_name=None, - source_type=seed.source_type, - scope=seed.scope, - state=seed.state, - trust_boost=seed.trust_boost, - enabled=seed.enabled, - crawl_depth=seed.crawl_depth, - crawl_frequency=seed.crawl_frequency, - last_crawled_at=None, - last_crawl_status=None, - last_crawl_docs=0, - total_documents=0, - created_at=row["created_at"], - updated_at=row["updated_at"], - ) - - -@router.put("/seeds/{seed_id}", response_model=SeedResponse) -async def update_seed(seed_id: str, seed: SeedUpdate): - """Update an existing seed.""" - pool = await get_db_pool() - async with pool.acquire() as conn: - # Build update statement dynamically - updates = [] - params = [] - param_idx = 1 - - if seed.url is not None: - updates.append(f"url = ${param_idx}") - params.append(seed.url) - param_idx += 1 - - if seed.name is not None: - updates.append(f"name = ${param_idx}") - params.append(seed.name) - param_idx += 1 - - if seed.description is not None: - updates.append(f"description = ${param_idx}") - params.append(seed.description) - param_idx += 1 - - if seed.category_name is not None: - category_id = await conn.fetchval( - "SELECT id FROM edu_search_categories WHERE name = $1", - seed.category_name - ) - updates.append(f"category_id = ${param_idx}") - params.append(category_id) - param_idx += 1 - - if seed.source_type is not None: - updates.append(f"source_type = ${param_idx}") - params.append(seed.source_type) - param_idx += 1 - - if seed.scope is not None: - updates.append(f"scope = ${param_idx}") - params.append(seed.scope) - param_idx += 1 - - if seed.state is not None: - updates.append(f"state = ${param_idx}") - params.append(seed.state) - param_idx += 1 - - if seed.trust_boost is not None: - updates.append(f"trust_boost = ${param_idx}") - params.append(seed.trust_boost) - param_idx += 1 - - if seed.enabled is not None: - updates.append(f"enabled = ${param_idx}") - params.append(seed.enabled) - param_idx += 1 - - if seed.crawl_depth is not None: - updates.append(f"crawl_depth = ${param_idx}") - params.append(seed.crawl_depth) - param_idx += 1 - - if seed.crawl_frequency is not None: - updates.append(f"crawl_frequency = ${param_idx}") - params.append(seed.crawl_frequency) - param_idx += 1 - - if not updates: - raise HTTPException(status_code=400, detail="Keine Felder zum Aktualisieren") - - updates.append("updated_at = NOW()") - params.append(seed_id) - - query = f""" - UPDATE edu_search_seeds - SET {", ".join(updates)} - WHERE id = ${param_idx} - RETURNING id - """ - - result = await conn.fetchrow(query, *params) - if not result: - raise HTTPException(status_code=404, detail="Seed nicht gefunden") - - # Return updated seed - return await get_seed(seed_id) - - -@router.delete("/seeds/{seed_id}") -async def delete_seed(seed_id: str): - """Delete a seed.""" - pool = await get_db_pool() - async with pool.acquire() as conn: - result = await conn.execute( - "DELETE FROM edu_search_seeds WHERE id = $1", - seed_id - ) - if result == "DELETE 0": - raise HTTPException(status_code=404, detail="Seed nicht gefunden") - - return {"status": "deleted", "id": seed_id} - - -@router.post("/seeds/bulk-import", response_model=BulkImportResponse) -async def bulk_import_seeds(request: BulkImportRequest): - """Bulk import seeds (skip duplicates).""" - pool = await get_db_pool() - imported = 0 - skipped = 0 - errors = [] - - async with pool.acquire() as conn: - # Pre-fetch all category IDs - categories = {} - rows = await conn.fetch("SELECT id, name FROM edu_search_categories") - for row in rows: - categories[row["name"]] = row["id"] - - for seed in request.seeds: - try: - category_id = categories.get(seed.category_name) if seed.category_name else None - - await conn.execute(""" - INSERT INTO edu_search_seeds ( - url, name, description, category_id, source_type, scope, - state, trust_boost, enabled, crawl_depth, crawl_frequency - ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11) - ON CONFLICT (url) DO NOTHING - """, - seed.url, seed.name, seed.description, category_id, - seed.source_type, seed.scope, seed.state, seed.trust_boost, - seed.enabled, seed.crawl_depth, seed.crawl_frequency - ) - imported += 1 - except asyncpg.UniqueViolationError: - skipped += 1 - except Exception as e: - errors.append(f"{seed.url}: {str(e)}") - - return BulkImportResponse(imported=imported, skipped=skipped, errors=errors) - - -@router.get("/stats", response_model=StatsResponse) -async def get_stats(): - """Get crawl statistics.""" - pool = await get_db_pool() - async with pool.acquire() as conn: - # Basic counts - total = await conn.fetchval("SELECT COUNT(*) FROM edu_search_seeds") - enabled = await conn.fetchval("SELECT COUNT(*) FROM edu_search_seeds WHERE enabled = TRUE") - total_docs = await conn.fetchval("SELECT COALESCE(SUM(total_documents), 0) FROM edu_search_seeds") - - # By category - cat_rows = await conn.fetch(""" - SELECT c.name, COUNT(s.id) as count - FROM edu_search_categories c - LEFT JOIN edu_search_seeds s ON c.id = s.category_id - GROUP BY c.name - """) - by_category = {row["name"]: row["count"] for row in cat_rows} - - # By state - state_rows = await conn.fetch(""" - SELECT COALESCE(state, 'federal') as state, COUNT(*) as count - FROM edu_search_seeds - GROUP BY state - """) - by_state = {row["state"]: row["count"] for row in state_rows} - - # Last crawl time - last_crawl = await conn.fetchval( - "SELECT MAX(last_crawled_at) FROM edu_search_seeds" - ) - - return StatsResponse( - total_seeds=total, - enabled_seeds=enabled, - total_documents=total_docs, - seeds_by_category=by_category, - seeds_by_state=by_state, - last_crawl_time=last_crawl, - ) - - -# Export for external use (edu-search-service) -@router.get("/seeds/export/for-crawler") -async def export_seeds_for_crawler(): - """Export enabled seeds in format suitable for crawler.""" - pool = await get_db_pool() - async with pool.acquire() as conn: - rows = await conn.fetch(""" - SELECT - s.url, s.trust_boost, s.source_type, s.scope, s.state, - s.crawl_depth, c.name as category - FROM edu_search_seeds s - LEFT JOIN edu_search_categories c ON s.category_id = c.id - WHERE s.enabled = TRUE - ORDER BY s.trust_boost DESC - """) - - return { - "seeds": [ - { - "url": row["url"], - "trust": float(row["trust_boost"]), - "source": row["source_type"], - "scope": row["scope"], - "state": row["state"], - "depth": row["crawl_depth"], - "category": row["category"], - } - for row in rows - ], - "total": len(rows), - "exported_at": datetime.utcnow().isoformat(), - } - - -# ============================================================================= -# Crawl Status Feedback (from edu-search-service) -# ============================================================================= - - -class CrawlStatusUpdate(BaseModel): - """Crawl status update from edu-search-service.""" - seed_url: str = Field(..., description="The seed URL that was crawled") - status: str = Field(..., description="Crawl status: success, error, partial") - documents_crawled: int = Field(0, ge=0, description="Number of documents crawled") - error_message: Optional[str] = Field(None, description="Error message if status is error") - crawl_duration_seconds: float = Field(0.0, ge=0.0, description="Duration of the crawl in seconds") - - -class CrawlStatusResponse(BaseModel): - """Response for crawl status update.""" - success: bool - seed_url: str - message: str - - -@router.post("/seeds/crawl-status", response_model=CrawlStatusResponse) -async def update_crawl_status(update: CrawlStatusUpdate): - """Update crawl status for a seed URL (called by edu-search-service).""" - pool = await get_db_pool() - async with pool.acquire() as conn: - # Find the seed by URL - seed = await conn.fetchrow( - "SELECT id, total_documents FROM edu_search_seeds WHERE url = $1", - update.seed_url - ) - - if not seed: - raise HTTPException( - status_code=404, - detail=f"Seed nicht gefunden: {update.seed_url}" - ) - - # Update the seed with crawl status - new_total = (seed["total_documents"] or 0) + update.documents_crawled - - await conn.execute(""" - UPDATE edu_search_seeds - SET - last_crawled_at = NOW(), - last_crawl_status = $2, - last_crawl_docs = $3, - total_documents = $4, - updated_at = NOW() - WHERE id = $1 - """, seed["id"], update.status, update.documents_crawled, new_total) - - logger.info( - f"Crawl status updated: {update.seed_url} - " - f"status={update.status}, docs={update.documents_crawled}, " - f"duration={update.crawl_duration_seconds:.1f}s" - ) - - return CrawlStatusResponse( - success=True, - seed_url=update.seed_url, - message=f"Status aktualisiert: {update.documents_crawled} Dokumente gecrawlt" - ) - - -class BulkCrawlStatusUpdate(BaseModel): - """Bulk crawl status update.""" - updates: List[CrawlStatusUpdate] - - -class BulkCrawlStatusResponse(BaseModel): - """Response for bulk crawl status update.""" - updated: int - failed: int - errors: List[str] - - -@router.post("/seeds/crawl-status/bulk", response_model=BulkCrawlStatusResponse) -async def bulk_update_crawl_status(request: BulkCrawlStatusUpdate): - """Bulk update crawl status for multiple seeds.""" - pool = await get_db_pool() - updated = 0 - failed = 0 - errors = [] - - async with pool.acquire() as conn: - for update in request.updates: - try: - seed = await conn.fetchrow( - "SELECT id, total_documents FROM edu_search_seeds WHERE url = $1", - update.seed_url - ) - - if not seed: - failed += 1 - errors.append(f"Seed nicht gefunden: {update.seed_url}") - continue - - new_total = (seed["total_documents"] or 0) + update.documents_crawled - - await conn.execute(""" - UPDATE edu_search_seeds - SET - last_crawled_at = NOW(), - last_crawl_status = $2, - last_crawl_docs = $3, - total_documents = $4, - updated_at = NOW() - WHERE id = $1 - """, seed["id"], update.status, update.documents_crawled, new_total) - - updated += 1 - - except Exception as e: - failed += 1 - errors.append(f"{update.seed_url}: {str(e)}") - - logger.info(f"Bulk crawl status update: {updated} updated, {failed} failed") - - return BulkCrawlStatusResponse( - updated=updated, - failed=failed, - errors=errors - ) +router.include_router(_crud_router) +router.include_router(_status_router) + +__all__ = [ + "router", + "get_db_pool", + # Models + "CategoryResponse", + "SeedBase", + "SeedCreate", + "SeedUpdate", + "SeedResponse", + "SeedsListResponse", + "StatsResponse", + "BulkImportRequest", + "BulkImportResponse", + "CrawlStatusUpdate", + "CrawlStatusResponse", + "BulkCrawlStatusUpdate", + "BulkCrawlStatusResponse", +] diff --git a/backend-lehrer/llm_gateway/routes/edu_search_status.py b/backend-lehrer/llm_gateway/routes/edu_search_status.py new file mode 100644 index 0000000..30d2a23 --- /dev/null +++ b/backend-lehrer/llm_gateway/routes/edu_search_status.py @@ -0,0 +1,198 @@ +""" +EduSearch Seeds Stats & Crawl Status Routes. + +Statistics, export for crawler, and crawl status feedback endpoints. +""" + +import logging +from typing import List +from datetime import datetime + +from fastapi import APIRouter, HTTPException +import asyncpg + +from .edu_search_models import ( + StatsResponse, + CrawlStatusUpdate, + CrawlStatusResponse, + BulkCrawlStatusUpdate, + BulkCrawlStatusResponse, +) +from .edu_search_crud import get_db_pool + +logger = logging.getLogger(__name__) + +router = APIRouter(tags=["edu-search"]) + + +@router.get("/stats", response_model=StatsResponse) +async def get_stats(): + """Get crawl statistics.""" + pool = await get_db_pool() + async with pool.acquire() as conn: + # Basic counts + total = await conn.fetchval("SELECT COUNT(*) FROM edu_search_seeds") + enabled = await conn.fetchval("SELECT COUNT(*) FROM edu_search_seeds WHERE enabled = TRUE") + total_docs = await conn.fetchval("SELECT COALESCE(SUM(total_documents), 0) FROM edu_search_seeds") + + # By category + cat_rows = await conn.fetch(""" + SELECT c.name, COUNT(s.id) as count + FROM edu_search_categories c + LEFT JOIN edu_search_seeds s ON c.id = s.category_id + GROUP BY c.name + """) + by_category = {row["name"]: row["count"] for row in cat_rows} + + # By state + state_rows = await conn.fetch(""" + SELECT COALESCE(state, 'federal') as state, COUNT(*) as count + FROM edu_search_seeds + GROUP BY state + """) + by_state = {row["state"]: row["count"] for row in state_rows} + + # Last crawl time + last_crawl = await conn.fetchval( + "SELECT MAX(last_crawled_at) FROM edu_search_seeds" + ) + + return StatsResponse( + total_seeds=total, + enabled_seeds=enabled, + total_documents=total_docs, + seeds_by_category=by_category, + seeds_by_state=by_state, + last_crawl_time=last_crawl, + ) + + +# Export for external use (edu-search-service) +@router.get("/seeds/export/for-crawler") +async def export_seeds_for_crawler(): + """Export enabled seeds in format suitable for crawler.""" + pool = await get_db_pool() + async with pool.acquire() as conn: + rows = await conn.fetch(""" + SELECT + s.url, s.trust_boost, s.source_type, s.scope, s.state, + s.crawl_depth, c.name as category + FROM edu_search_seeds s + LEFT JOIN edu_search_categories c ON s.category_id = c.id + WHERE s.enabled = TRUE + ORDER BY s.trust_boost DESC + """) + + return { + "seeds": [ + { + "url": row["url"], + "trust": float(row["trust_boost"]), + "source": row["source_type"], + "scope": row["scope"], + "state": row["state"], + "depth": row["crawl_depth"], + "category": row["category"], + } + for row in rows + ], + "total": len(rows), + "exported_at": datetime.utcnow().isoformat(), + } + + +# ============================================================================= +# Crawl Status Feedback (from edu-search-service) +# ============================================================================= + +@router.post("/seeds/crawl-status", response_model=CrawlStatusResponse) +async def update_crawl_status(update: CrawlStatusUpdate): + """Update crawl status for a seed URL (called by edu-search-service).""" + pool = await get_db_pool() + async with pool.acquire() as conn: + # Find the seed by URL + seed = await conn.fetchrow( + "SELECT id, total_documents FROM edu_search_seeds WHERE url = $1", + update.seed_url + ) + + if not seed: + raise HTTPException( + status_code=404, + detail=f"Seed nicht gefunden: {update.seed_url}" + ) + + # Update the seed with crawl status + new_total = (seed["total_documents"] or 0) + update.documents_crawled + + await conn.execute(""" + UPDATE edu_search_seeds + SET + last_crawled_at = NOW(), + last_crawl_status = $2, + last_crawl_docs = $3, + total_documents = $4, + updated_at = NOW() + WHERE id = $1 + """, seed["id"], update.status, update.documents_crawled, new_total) + + logger.info( + f"Crawl status updated: {update.seed_url} - " + f"status={update.status}, docs={update.documents_crawled}, " + f"duration={update.crawl_duration_seconds:.1f}s" + ) + + return CrawlStatusResponse( + success=True, + seed_url=update.seed_url, + message=f"Status aktualisiert: {update.documents_crawled} Dokumente gecrawlt" + ) + + +@router.post("/seeds/crawl-status/bulk", response_model=BulkCrawlStatusResponse) +async def bulk_update_crawl_status(request: BulkCrawlStatusUpdate): + """Bulk update crawl status for multiple seeds.""" + pool = await get_db_pool() + updated = 0 + failed = 0 + errors = [] + + async with pool.acquire() as conn: + for update in request.updates: + try: + seed = await conn.fetchrow( + "SELECT id, total_documents FROM edu_search_seeds WHERE url = $1", + update.seed_url + ) + + if not seed: + failed += 1 + errors.append(f"Seed nicht gefunden: {update.seed_url}") + continue + + new_total = (seed["total_documents"] or 0) + update.documents_crawled + + await conn.execute(""" + UPDATE edu_search_seeds + SET + last_crawled_at = NOW(), + last_crawl_status = $2, + last_crawl_docs = $3, + total_documents = $4, + updated_at = NOW() + WHERE id = $1 + """, seed["id"], update.status, update.documents_crawled, new_total) + + updated += 1 + + except Exception as e: + failed += 1 + errors.append(f"{update.seed_url}: {str(e)}") + + logger.info(f"Bulk crawl status update: {updated} updated, {failed} failed") + + return BulkCrawlStatusResponse( + updated=updated, + failed=failed, + errors=errors + ) diff --git a/backend-lehrer/llm_gateway/routes/schools.py b/backend-lehrer/llm_gateway/routes/schools.py index ee76f0e..34d5bf6 100644 --- a/backend-lehrer/llm_gateway/routes/schools.py +++ b/backend-lehrer/llm_gateway/routes/schools.py @@ -1,867 +1,38 @@ """ -Schools API Routes. +Schools API Routes — Barrel Re-export. CRUD operations for managing German schools (~40,000 schools). -Direct database access to PostgreSQL. +Split into: + - schools_models.py: Pydantic models + - schools_db.py: Database connection pool + - schools_crud.py: School CRUD & stats routes + - schools_staff.py: Staff CRUD & search routes """ -import os -import logging -from typing import Optional, List -from datetime import datetime -from uuid import UUID +from fastapi import APIRouter -from fastapi import APIRouter, HTTPException, Query -from pydantic import BaseModel, Field -import asyncpg - -logger = logging.getLogger(__name__) +from .schools_crud import router as _crud_router +from .schools_staff import router as _staff_router +# Single router that merges both sub-module routers router = APIRouter(prefix="/schools", tags=["schools"]) - -# Database connection pool -_pool: Optional[asyncpg.Pool] = None - - -async def get_db_pool() -> asyncpg.Pool: - """Get or create database connection pool.""" - global _pool - if _pool is None: - database_url = os.environ.get( - "DATABASE_URL", - "postgresql://breakpilot:breakpilot123@postgres:5432/breakpilot_db" - ) - _pool = await asyncpg.create_pool(database_url, min_size=2, max_size=10) - return _pool - - -# ============================================================================= -# Pydantic Models -# ============================================================================= - - -class SchoolTypeResponse(BaseModel): - """School type response model.""" - id: str - name: str - name_short: Optional[str] = None - category: Optional[str] = None - description: Optional[str] = None - - -class SchoolBase(BaseModel): - """Base school model for creation/update.""" - name: str = Field(..., max_length=255) - school_number: Optional[str] = Field(None, max_length=20) - school_type_id: Optional[str] = None - school_type_raw: Optional[str] = None - state: str = Field(..., max_length=10) - district: Optional[str] = None - city: Optional[str] = None - postal_code: Optional[str] = None - street: Optional[str] = None - address_full: Optional[str] = None - latitude: Optional[float] = None - longitude: Optional[float] = None - website: Optional[str] = None - email: Optional[str] = None - phone: Optional[str] = None - fax: Optional[str] = None - principal_name: Optional[str] = None - principal_title: Optional[str] = None - principal_email: Optional[str] = None - principal_phone: Optional[str] = None - secretary_name: Optional[str] = None - secretary_email: Optional[str] = None - secretary_phone: Optional[str] = None - student_count: Optional[int] = None - teacher_count: Optional[int] = None - class_count: Optional[int] = None - founded_year: Optional[int] = None - is_public: bool = True - is_all_day: Optional[bool] = None - has_inclusion: Optional[bool] = None - languages: Optional[List[str]] = None - specializations: Optional[List[str]] = None - source: Optional[str] = None - source_url: Optional[str] = None - - -class SchoolCreate(SchoolBase): - """School creation model.""" - pass - - -class SchoolUpdate(BaseModel): - """School update model (all fields optional).""" - name: Optional[str] = Field(None, max_length=255) - school_number: Optional[str] = None - school_type_id: Optional[str] = None - state: Optional[str] = None - district: Optional[str] = None - city: Optional[str] = None - postal_code: Optional[str] = None - street: Optional[str] = None - website: Optional[str] = None - email: Optional[str] = None - phone: Optional[str] = None - principal_name: Optional[str] = None - student_count: Optional[int] = None - teacher_count: Optional[int] = None - is_active: Optional[bool] = None - - -class SchoolResponse(BaseModel): - """School response model.""" - id: str - name: str - school_number: Optional[str] = None - school_type: Optional[str] = None - school_type_short: Optional[str] = None - school_category: Optional[str] = None - state: str - district: Optional[str] = None - city: Optional[str] = None - postal_code: Optional[str] = None - street: Optional[str] = None - address_full: Optional[str] = None - latitude: Optional[float] = None - longitude: Optional[float] = None - website: Optional[str] = None - email: Optional[str] = None - phone: Optional[str] = None - fax: Optional[str] = None - principal_name: Optional[str] = None - principal_email: Optional[str] = None - student_count: Optional[int] = None - teacher_count: Optional[int] = None - is_public: bool = True - is_all_day: Optional[bool] = None - staff_count: int = 0 - source: Optional[str] = None - crawled_at: Optional[datetime] = None - is_active: bool = True - created_at: datetime - updated_at: datetime - - -class SchoolsListResponse(BaseModel): - """List response with pagination info.""" - schools: List[SchoolResponse] - total: int - page: int - page_size: int - - -class SchoolStaffBase(BaseModel): - """Base school staff model.""" - first_name: Optional[str] = None - last_name: str - full_name: Optional[str] = None - title: Optional[str] = None - position: Optional[str] = None - position_type: Optional[str] = None - subjects: Optional[List[str]] = None - email: Optional[str] = None - phone: Optional[str] = None - - -class SchoolStaffCreate(SchoolStaffBase): - """School staff creation model.""" - school_id: str - - -class SchoolStaffResponse(SchoolStaffBase): - """School staff response model.""" - id: str - school_id: str - school_name: Optional[str] = None - profile_url: Optional[str] = None - photo_url: Optional[str] = None - is_active: bool = True - created_at: datetime - - -class SchoolStaffListResponse(BaseModel): - """Staff list response.""" - staff: List[SchoolStaffResponse] - total: int - - -class SchoolStatsResponse(BaseModel): - """School statistics response.""" - total_schools: int - total_staff: int - schools_by_state: dict - schools_by_type: dict - schools_with_website: int - schools_with_email: int - schools_with_principal: int - total_students: int - total_teachers: int - last_crawl_time: Optional[datetime] = None - - -class BulkImportRequest(BaseModel): - """Bulk import request.""" - schools: List[SchoolCreate] - - -class BulkImportResponse(BaseModel): - """Bulk import response.""" - imported: int - updated: int - skipped: int - errors: List[str] - - -# ============================================================================= -# School Type Endpoints -# ============================================================================= - - -@router.get("/types", response_model=List[SchoolTypeResponse]) -async def list_school_types(): - """List all school types.""" - pool = await get_db_pool() - async with pool.acquire() as conn: - rows = await conn.fetch(""" - SELECT id, name, name_short, category, description - FROM school_types - ORDER BY category, name - """) - return [ - SchoolTypeResponse( - id=str(row["id"]), - name=row["name"], - name_short=row["name_short"], - category=row["category"], - description=row["description"], - ) - for row in rows - ] - - -# ============================================================================= -# School Endpoints -# ============================================================================= - - -@router.get("", response_model=SchoolsListResponse) -async def list_schools( - state: Optional[str] = Query(None, description="Filter by state code (BW, BY, etc.)"), - school_type: Optional[str] = Query(None, description="Filter by school type name"), - city: Optional[str] = Query(None, description="Filter by city"), - district: Optional[str] = Query(None, description="Filter by district"), - postal_code: Optional[str] = Query(None, description="Filter by postal code prefix"), - search: Optional[str] = Query(None, description="Search in name, city"), - has_email: Optional[bool] = Query(None, description="Filter schools with email"), - has_website: Optional[bool] = Query(None, description="Filter schools with website"), - is_public: Optional[bool] = Query(None, description="Filter public/private schools"), - page: int = Query(1, ge=1), - page_size: int = Query(50, ge=1, le=200), -): - """List schools with optional filtering and pagination.""" - pool = await get_db_pool() - async with pool.acquire() as conn: - # Build WHERE clause - conditions = ["s.is_active = TRUE"] - params = [] - param_idx = 1 - - if state: - conditions.append(f"s.state = ${param_idx}") - params.append(state.upper()) - param_idx += 1 - - if school_type: - conditions.append(f"st.name = ${param_idx}") - params.append(school_type) - param_idx += 1 - - if city: - conditions.append(f"LOWER(s.city) = LOWER(${param_idx})") - params.append(city) - param_idx += 1 - - if district: - conditions.append(f"LOWER(s.district) LIKE LOWER(${param_idx})") - params.append(f"%{district}%") - param_idx += 1 - - if postal_code: - conditions.append(f"s.postal_code LIKE ${param_idx}") - params.append(f"{postal_code}%") - param_idx += 1 - - if search: - conditions.append(f""" - (LOWER(s.name) LIKE LOWER(${param_idx}) - OR LOWER(s.city) LIKE LOWER(${param_idx}) - OR LOWER(s.district) LIKE LOWER(${param_idx})) - """) - params.append(f"%{search}%") - param_idx += 1 - - if has_email is not None: - if has_email: - conditions.append("s.email IS NOT NULL") - else: - conditions.append("s.email IS NULL") - - if has_website is not None: - if has_website: - conditions.append("s.website IS NOT NULL") - else: - conditions.append("s.website IS NULL") - - if is_public is not None: - conditions.append(f"s.is_public = ${param_idx}") - params.append(is_public) - param_idx += 1 - - where_clause = " AND ".join(conditions) - - # Count total - count_query = f""" - SELECT COUNT(*) FROM schools s - LEFT JOIN school_types st ON s.school_type_id = st.id - WHERE {where_clause} - """ - total = await conn.fetchval(count_query, *params) - - # Fetch schools - offset = (page - 1) * page_size - query = f""" - SELECT - s.id, s.name, s.school_number, s.state, s.district, s.city, - s.postal_code, s.street, s.address_full, s.latitude, s.longitude, - s.website, s.email, s.phone, s.fax, - s.principal_name, s.principal_email, - s.student_count, s.teacher_count, - s.is_public, s.is_all_day, s.source, s.crawled_at, - s.is_active, s.created_at, s.updated_at, - st.name as school_type, st.name_short as school_type_short, st.category as school_category, - (SELECT COUNT(*) FROM school_staff ss WHERE ss.school_id = s.id AND ss.is_active = TRUE) as staff_count - FROM schools s - LEFT JOIN school_types st ON s.school_type_id = st.id - WHERE {where_clause} - ORDER BY s.state, s.city, s.name - LIMIT ${param_idx} OFFSET ${param_idx + 1} - """ - params.extend([page_size, offset]) - rows = await conn.fetch(query, *params) - - schools = [ - SchoolResponse( - id=str(row["id"]), - name=row["name"], - school_number=row["school_number"], - school_type=row["school_type"], - school_type_short=row["school_type_short"], - school_category=row["school_category"], - state=row["state"], - district=row["district"], - city=row["city"], - postal_code=row["postal_code"], - street=row["street"], - address_full=row["address_full"], - latitude=row["latitude"], - longitude=row["longitude"], - website=row["website"], - email=row["email"], - phone=row["phone"], - fax=row["fax"], - principal_name=row["principal_name"], - principal_email=row["principal_email"], - student_count=row["student_count"], - teacher_count=row["teacher_count"], - is_public=row["is_public"], - is_all_day=row["is_all_day"], - staff_count=row["staff_count"], - source=row["source"], - crawled_at=row["crawled_at"], - is_active=row["is_active"], - created_at=row["created_at"], - updated_at=row["updated_at"], - ) - for row in rows - ] - - return SchoolsListResponse( - schools=schools, - total=total, - page=page, - page_size=page_size, - ) - - -@router.get("/stats", response_model=SchoolStatsResponse) -async def get_school_stats(): - """Get school statistics.""" - pool = await get_db_pool() - async with pool.acquire() as conn: - # Total schools and staff - totals = await conn.fetchrow(""" - SELECT - (SELECT COUNT(*) FROM schools WHERE is_active = TRUE) as total_schools, - (SELECT COUNT(*) FROM school_staff WHERE is_active = TRUE) as total_staff, - (SELECT COUNT(*) FROM schools WHERE is_active = TRUE AND website IS NOT NULL) as with_website, - (SELECT COUNT(*) FROM schools WHERE is_active = TRUE AND email IS NOT NULL) as with_email, - (SELECT COUNT(*) FROM schools WHERE is_active = TRUE AND principal_name IS NOT NULL) as with_principal, - (SELECT COALESCE(SUM(student_count), 0) FROM schools WHERE is_active = TRUE) as total_students, - (SELECT COALESCE(SUM(teacher_count), 0) FROM schools WHERE is_active = TRUE) as total_teachers, - (SELECT MAX(crawled_at) FROM schools) as last_crawl - """) - - # By state - state_rows = await conn.fetch(""" - SELECT state, COUNT(*) as count - FROM schools - WHERE is_active = TRUE - GROUP BY state - ORDER BY state - """) - schools_by_state = {row["state"]: row["count"] for row in state_rows} - - # By type - type_rows = await conn.fetch(""" - SELECT COALESCE(st.name, 'Unbekannt') as type_name, COUNT(*) as count - FROM schools s - LEFT JOIN school_types st ON s.school_type_id = st.id - WHERE s.is_active = TRUE - GROUP BY st.name - ORDER BY count DESC - """) - schools_by_type = {row["type_name"]: row["count"] for row in type_rows} - - return SchoolStatsResponse( - total_schools=totals["total_schools"], - total_staff=totals["total_staff"], - schools_by_state=schools_by_state, - schools_by_type=schools_by_type, - schools_with_website=totals["with_website"], - schools_with_email=totals["with_email"], - schools_with_principal=totals["with_principal"], - total_students=totals["total_students"], - total_teachers=totals["total_teachers"], - last_crawl_time=totals["last_crawl"], - ) - - -@router.get("/{school_id}", response_model=SchoolResponse) -async def get_school(school_id: str): - """Get a single school by ID.""" - pool = await get_db_pool() - async with pool.acquire() as conn: - row = await conn.fetchrow(""" - SELECT - s.id, s.name, s.school_number, s.state, s.district, s.city, - s.postal_code, s.street, s.address_full, s.latitude, s.longitude, - s.website, s.email, s.phone, s.fax, - s.principal_name, s.principal_email, - s.student_count, s.teacher_count, - s.is_public, s.is_all_day, s.source, s.crawled_at, - s.is_active, s.created_at, s.updated_at, - st.name as school_type, st.name_short as school_type_short, st.category as school_category, - (SELECT COUNT(*) FROM school_staff ss WHERE ss.school_id = s.id AND ss.is_active = TRUE) as staff_count - FROM schools s - LEFT JOIN school_types st ON s.school_type_id = st.id - WHERE s.id = $1 - """, school_id) - - if not row: - raise HTTPException(status_code=404, detail="School not found") - - return SchoolResponse( - id=str(row["id"]), - name=row["name"], - school_number=row["school_number"], - school_type=row["school_type"], - school_type_short=row["school_type_short"], - school_category=row["school_category"], - state=row["state"], - district=row["district"], - city=row["city"], - postal_code=row["postal_code"], - street=row["street"], - address_full=row["address_full"], - latitude=row["latitude"], - longitude=row["longitude"], - website=row["website"], - email=row["email"], - phone=row["phone"], - fax=row["fax"], - principal_name=row["principal_name"], - principal_email=row["principal_email"], - student_count=row["student_count"], - teacher_count=row["teacher_count"], - is_public=row["is_public"], - is_all_day=row["is_all_day"], - staff_count=row["staff_count"], - source=row["source"], - crawled_at=row["crawled_at"], - is_active=row["is_active"], - created_at=row["created_at"], - updated_at=row["updated_at"], - ) - - -@router.post("/bulk-import", response_model=BulkImportResponse) -async def bulk_import_schools(request: BulkImportRequest): - """Bulk import schools. Updates existing schools based on school_number + state.""" - pool = await get_db_pool() - imported = 0 - updated = 0 - skipped = 0 - errors = [] - - async with pool.acquire() as conn: - # Get school type mapping - type_rows = await conn.fetch("SELECT id, name FROM school_types") - type_map = {row["name"].lower(): str(row["id"]) for row in type_rows} - - for school in request.schools: - try: - # Find school type ID - school_type_id = None - if school.school_type_raw: - school_type_id = type_map.get(school.school_type_raw.lower()) - - # Check if school exists (by school_number + state, or by name + city + state) - existing = None - if school.school_number: - existing = await conn.fetchrow( - "SELECT id FROM schools WHERE school_number = $1 AND state = $2", - school.school_number, school.state - ) - if not existing and school.city: - existing = await conn.fetchrow( - "SELECT id FROM schools WHERE LOWER(name) = LOWER($1) AND LOWER(city) = LOWER($2) AND state = $3", - school.name, school.city, school.state - ) - - if existing: - # Update existing school - await conn.execute(""" - UPDATE schools SET - name = $2, - school_type_id = COALESCE($3, school_type_id), - school_type_raw = COALESCE($4, school_type_raw), - district = COALESCE($5, district), - city = COALESCE($6, city), - postal_code = COALESCE($7, postal_code), - street = COALESCE($8, street), - address_full = COALESCE($9, address_full), - latitude = COALESCE($10, latitude), - longitude = COALESCE($11, longitude), - website = COALESCE($12, website), - email = COALESCE($13, email), - phone = COALESCE($14, phone), - fax = COALESCE($15, fax), - principal_name = COALESCE($16, principal_name), - principal_title = COALESCE($17, principal_title), - principal_email = COALESCE($18, principal_email), - principal_phone = COALESCE($19, principal_phone), - student_count = COALESCE($20, student_count), - teacher_count = COALESCE($21, teacher_count), - is_public = $22, - source = COALESCE($23, source), - source_url = COALESCE($24, source_url), - updated_at = NOW() - WHERE id = $1 - """, - existing["id"], - school.name, - school_type_id, - school.school_type_raw, - school.district, - school.city, - school.postal_code, - school.street, - school.address_full, - school.latitude, - school.longitude, - school.website, - school.email, - school.phone, - school.fax, - school.principal_name, - school.principal_title, - school.principal_email, - school.principal_phone, - school.student_count, - school.teacher_count, - school.is_public, - school.source, - school.source_url, - ) - updated += 1 - else: - # Insert new school - await conn.execute(""" - INSERT INTO schools ( - name, school_number, school_type_id, school_type_raw, - state, district, city, postal_code, street, address_full, - latitude, longitude, website, email, phone, fax, - principal_name, principal_title, principal_email, principal_phone, - student_count, teacher_count, is_public, - source, source_url, crawled_at - ) VALUES ( - $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, - $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, - $21, $22, $23, $24, $25, NOW() - ) - """, - school.name, - school.school_number, - school_type_id, - school.school_type_raw, - school.state, - school.district, - school.city, - school.postal_code, - school.street, - school.address_full, - school.latitude, - school.longitude, - school.website, - school.email, - school.phone, - school.fax, - school.principal_name, - school.principal_title, - school.principal_email, - school.principal_phone, - school.student_count, - school.teacher_count, - school.is_public, - school.source, - school.source_url, - ) - imported += 1 - - except Exception as e: - errors.append(f"Error importing {school.name}: {str(e)}") - if len(errors) > 100: - errors.append("... (more errors truncated)") - break - - return BulkImportResponse( - imported=imported, - updated=updated, - skipped=skipped, - errors=errors[:100], - ) - - -# ============================================================================= -# School Staff Endpoints -# ============================================================================= - - -@router.get("/{school_id}/staff", response_model=SchoolStaffListResponse) -async def get_school_staff(school_id: str): - """Get staff members for a school.""" - pool = await get_db_pool() - async with pool.acquire() as conn: - rows = await conn.fetch(""" - SELECT - ss.id, ss.school_id, ss.first_name, ss.last_name, ss.full_name, - ss.title, ss.position, ss.position_type, ss.subjects, - ss.email, ss.phone, ss.profile_url, ss.photo_url, - ss.is_active, ss.created_at, - s.name as school_name - FROM school_staff ss - JOIN schools s ON ss.school_id = s.id - WHERE ss.school_id = $1 AND ss.is_active = TRUE - ORDER BY - CASE ss.position_type - WHEN 'principal' THEN 1 - WHEN 'vice_principal' THEN 2 - WHEN 'secretary' THEN 3 - ELSE 4 - END, - ss.last_name - """, school_id) - - staff = [ - SchoolStaffResponse( - id=str(row["id"]), - school_id=str(row["school_id"]), - school_name=row["school_name"], - first_name=row["first_name"], - last_name=row["last_name"], - full_name=row["full_name"], - title=row["title"], - position=row["position"], - position_type=row["position_type"], - subjects=row["subjects"], - email=row["email"], - phone=row["phone"], - profile_url=row["profile_url"], - photo_url=row["photo_url"], - is_active=row["is_active"], - created_at=row["created_at"], - ) - for row in rows - ] - - return SchoolStaffListResponse( - staff=staff, - total=len(staff), - ) - - -@router.post("/{school_id}/staff", response_model=SchoolStaffResponse) -async def create_school_staff(school_id: str, staff: SchoolStaffBase): - """Add a staff member to a school.""" - pool = await get_db_pool() - async with pool.acquire() as conn: - # Verify school exists - school = await conn.fetchrow("SELECT name FROM schools WHERE id = $1", school_id) - if not school: - raise HTTPException(status_code=404, detail="School not found") - - # Create full name - full_name = staff.full_name - if not full_name: - parts = [] - if staff.title: - parts.append(staff.title) - if staff.first_name: - parts.append(staff.first_name) - parts.append(staff.last_name) - full_name = " ".join(parts) - - row = await conn.fetchrow(""" - INSERT INTO school_staff ( - school_id, first_name, last_name, full_name, title, - position, position_type, subjects, email, phone - ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) - RETURNING id, created_at - """, - school_id, - staff.first_name, - staff.last_name, - full_name, - staff.title, - staff.position, - staff.position_type, - staff.subjects, - staff.email, - staff.phone, - ) - - return SchoolStaffResponse( - id=str(row["id"]), - school_id=school_id, - school_name=school["name"], - first_name=staff.first_name, - last_name=staff.last_name, - full_name=full_name, - title=staff.title, - position=staff.position, - position_type=staff.position_type, - subjects=staff.subjects, - email=staff.email, - phone=staff.phone, - is_active=True, - created_at=row["created_at"], - ) - - -# ============================================================================= -# Search Endpoints -# ============================================================================= - - -@router.get("/search/staff", response_model=SchoolStaffListResponse) -async def search_school_staff( - q: Optional[str] = Query(None, description="Search query"), - state: Optional[str] = Query(None, description="Filter by state"), - position_type: Optional[str] = Query(None, description="Filter by position type"), - has_email: Optional[bool] = Query(None, description="Only staff with email"), - page: int = Query(1, ge=1), - page_size: int = Query(50, ge=1, le=200), -): - """Search school staff across all schools.""" - pool = await get_db_pool() - async with pool.acquire() as conn: - conditions = ["ss.is_active = TRUE", "s.is_active = TRUE"] - params = [] - param_idx = 1 - - if q: - conditions.append(f""" - (LOWER(ss.full_name) LIKE LOWER(${param_idx}) - OR LOWER(ss.last_name) LIKE LOWER(${param_idx}) - OR LOWER(s.name) LIKE LOWER(${param_idx})) - """) - params.append(f"%{q}%") - param_idx += 1 - - if state: - conditions.append(f"s.state = ${param_idx}") - params.append(state.upper()) - param_idx += 1 - - if position_type: - conditions.append(f"ss.position_type = ${param_idx}") - params.append(position_type) - param_idx += 1 - - if has_email is not None and has_email: - conditions.append("ss.email IS NOT NULL") - - where_clause = " AND ".join(conditions) - - # Count total - total = await conn.fetchval(f""" - SELECT COUNT(*) FROM school_staff ss - JOIN schools s ON ss.school_id = s.id - WHERE {where_clause} - """, *params) - - # Fetch staff - offset = (page - 1) * page_size - rows = await conn.fetch(f""" - SELECT - ss.id, ss.school_id, ss.first_name, ss.last_name, ss.full_name, - ss.title, ss.position, ss.position_type, ss.subjects, - ss.email, ss.phone, ss.profile_url, ss.photo_url, - ss.is_active, ss.created_at, - s.name as school_name - FROM school_staff ss - JOIN schools s ON ss.school_id = s.id - WHERE {where_clause} - ORDER BY ss.last_name, ss.first_name - LIMIT ${param_idx} OFFSET ${param_idx + 1} - """, *params, page_size, offset) - - staff = [ - SchoolStaffResponse( - id=str(row["id"]), - school_id=str(row["school_id"]), - school_name=row["school_name"], - first_name=row["first_name"], - last_name=row["last_name"], - full_name=row["full_name"], - title=row["title"], - position=row["position"], - position_type=row["position_type"], - subjects=row["subjects"], - email=row["email"], - phone=row["phone"], - profile_url=row["profile_url"], - photo_url=row["photo_url"], - is_active=row["is_active"], - created_at=row["created_at"], - ) - for row in rows - ] - - return SchoolStaffListResponse( - staff=staff, - total=total, - ) +router.include_router(_crud_router) +router.include_router(_staff_router) + +# Re-export models for any external consumers +from .schools_models import ( # noqa: E402, F401 + SchoolTypeResponse, + SchoolBase, + SchoolCreate, + SchoolUpdate, + SchoolResponse, + SchoolsListResponse, + SchoolStaffBase, + SchoolStaffCreate, + SchoolStaffResponse, + SchoolStaffListResponse, + SchoolStatsResponse, + BulkImportRequest, + BulkImportResponse, +) +from .schools_db import get_db_pool # noqa: E402, F401 diff --git a/backend-lehrer/llm_gateway/routes/schools_crud.py b/backend-lehrer/llm_gateway/routes/schools_crud.py new file mode 100644 index 0000000..29b0bc8 --- /dev/null +++ b/backend-lehrer/llm_gateway/routes/schools_crud.py @@ -0,0 +1,464 @@ +""" +Schools API - School CRUD & Stats Routes. + +List, get, stats, and bulk-import endpoints for schools. +""" + +import logging +from typing import Optional + +from fastapi import APIRouter, HTTPException, Query + +from .schools_db import get_db_pool +from .schools_models import ( + SchoolResponse, + SchoolsListResponse, + SchoolStatsResponse, + SchoolTypeResponse, + BulkImportRequest, + BulkImportResponse, +) + +logger = logging.getLogger(__name__) + +router = APIRouter(tags=["schools"]) + + +# ============================================================================= +# School Type Endpoints +# ============================================================================= + + +@router.get("/types", response_model=list[SchoolTypeResponse]) +async def list_school_types(): + """List all school types.""" + pool = await get_db_pool() + async with pool.acquire() as conn: + rows = await conn.fetch(""" + SELECT id, name, name_short, category, description + FROM school_types + ORDER BY category, name + """) + return [ + SchoolTypeResponse( + id=str(row["id"]), + name=row["name"], + name_short=row["name_short"], + category=row["category"], + description=row["description"], + ) + for row in rows + ] + + +# ============================================================================= +# School Endpoints +# ============================================================================= + + +@router.get("", response_model=SchoolsListResponse) +async def list_schools( + state: Optional[str] = Query(None, description="Filter by state code (BW, BY, etc.)"), + school_type: Optional[str] = Query(None, description="Filter by school type name"), + city: Optional[str] = Query(None, description="Filter by city"), + district: Optional[str] = Query(None, description="Filter by district"), + postal_code: Optional[str] = Query(None, description="Filter by postal code prefix"), + search: Optional[str] = Query(None, description="Search in name, city"), + has_email: Optional[bool] = Query(None, description="Filter schools with email"), + has_website: Optional[bool] = Query(None, description="Filter schools with website"), + is_public: Optional[bool] = Query(None, description="Filter public/private schools"), + page: int = Query(1, ge=1), + page_size: int = Query(50, ge=1, le=200), +): + """List schools with optional filtering and pagination.""" + pool = await get_db_pool() + async with pool.acquire() as conn: + # Build WHERE clause + conditions = ["s.is_active = TRUE"] + params = [] + param_idx = 1 + + if state: + conditions.append(f"s.state = ${param_idx}") + params.append(state.upper()) + param_idx += 1 + + if school_type: + conditions.append(f"st.name = ${param_idx}") + params.append(school_type) + param_idx += 1 + + if city: + conditions.append(f"LOWER(s.city) = LOWER(${param_idx})") + params.append(city) + param_idx += 1 + + if district: + conditions.append(f"LOWER(s.district) LIKE LOWER(${param_idx})") + params.append(f"%{district}%") + param_idx += 1 + + if postal_code: + conditions.append(f"s.postal_code LIKE ${param_idx}") + params.append(f"{postal_code}%") + param_idx += 1 + + if search: + conditions.append(f""" + (LOWER(s.name) LIKE LOWER(${param_idx}) + OR LOWER(s.city) LIKE LOWER(${param_idx}) + OR LOWER(s.district) LIKE LOWER(${param_idx})) + """) + params.append(f"%{search}%") + param_idx += 1 + + if has_email is not None: + if has_email: + conditions.append("s.email IS NOT NULL") + else: + conditions.append("s.email IS NULL") + + if has_website is not None: + if has_website: + conditions.append("s.website IS NOT NULL") + else: + conditions.append("s.website IS NULL") + + if is_public is not None: + conditions.append(f"s.is_public = ${param_idx}") + params.append(is_public) + param_idx += 1 + + where_clause = " AND ".join(conditions) + + # Count total + count_query = f""" + SELECT COUNT(*) FROM schools s + LEFT JOIN school_types st ON s.school_type_id = st.id + WHERE {where_clause} + """ + total = await conn.fetchval(count_query, *params) + + # Fetch schools + offset = (page - 1) * page_size + query = f""" + SELECT + s.id, s.name, s.school_number, s.state, s.district, s.city, + s.postal_code, s.street, s.address_full, s.latitude, s.longitude, + s.website, s.email, s.phone, s.fax, + s.principal_name, s.principal_email, + s.student_count, s.teacher_count, + s.is_public, s.is_all_day, s.source, s.crawled_at, + s.is_active, s.created_at, s.updated_at, + st.name as school_type, st.name_short as school_type_short, st.category as school_category, + (SELECT COUNT(*) FROM school_staff ss WHERE ss.school_id = s.id AND ss.is_active = TRUE) as staff_count + FROM schools s + LEFT JOIN school_types st ON s.school_type_id = st.id + WHERE {where_clause} + ORDER BY s.state, s.city, s.name + LIMIT ${param_idx} OFFSET ${param_idx + 1} + """ + params.extend([page_size, offset]) + rows = await conn.fetch(query, *params) + + schools = [ + SchoolResponse( + id=str(row["id"]), + name=row["name"], + school_number=row["school_number"], + school_type=row["school_type"], + school_type_short=row["school_type_short"], + school_category=row["school_category"], + state=row["state"], + district=row["district"], + city=row["city"], + postal_code=row["postal_code"], + street=row["street"], + address_full=row["address_full"], + latitude=row["latitude"], + longitude=row["longitude"], + website=row["website"], + email=row["email"], + phone=row["phone"], + fax=row["fax"], + principal_name=row["principal_name"], + principal_email=row["principal_email"], + student_count=row["student_count"], + teacher_count=row["teacher_count"], + is_public=row["is_public"], + is_all_day=row["is_all_day"], + staff_count=row["staff_count"], + source=row["source"], + crawled_at=row["crawled_at"], + is_active=row["is_active"], + created_at=row["created_at"], + updated_at=row["updated_at"], + ) + for row in rows + ] + + return SchoolsListResponse( + schools=schools, + total=total, + page=page, + page_size=page_size, + ) + + +@router.get("/stats", response_model=SchoolStatsResponse) +async def get_school_stats(): + """Get school statistics.""" + pool = await get_db_pool() + async with pool.acquire() as conn: + # Total schools and staff + totals = await conn.fetchrow(""" + SELECT + (SELECT COUNT(*) FROM schools WHERE is_active = TRUE) as total_schools, + (SELECT COUNT(*) FROM school_staff WHERE is_active = TRUE) as total_staff, + (SELECT COUNT(*) FROM schools WHERE is_active = TRUE AND website IS NOT NULL) as with_website, + (SELECT COUNT(*) FROM schools WHERE is_active = TRUE AND email IS NOT NULL) as with_email, + (SELECT COUNT(*) FROM schools WHERE is_active = TRUE AND principal_name IS NOT NULL) as with_principal, + (SELECT COALESCE(SUM(student_count), 0) FROM schools WHERE is_active = TRUE) as total_students, + (SELECT COALESCE(SUM(teacher_count), 0) FROM schools WHERE is_active = TRUE) as total_teachers, + (SELECT MAX(crawled_at) FROM schools) as last_crawl + """) + + # By state + state_rows = await conn.fetch(""" + SELECT state, COUNT(*) as count + FROM schools + WHERE is_active = TRUE + GROUP BY state + ORDER BY state + """) + schools_by_state = {row["state"]: row["count"] for row in state_rows} + + # By type + type_rows = await conn.fetch(""" + SELECT COALESCE(st.name, 'Unbekannt') as type_name, COUNT(*) as count + FROM schools s + LEFT JOIN school_types st ON s.school_type_id = st.id + WHERE s.is_active = TRUE + GROUP BY st.name + ORDER BY count DESC + """) + schools_by_type = {row["type_name"]: row["count"] for row in type_rows} + + return SchoolStatsResponse( + total_schools=totals["total_schools"], + total_staff=totals["total_staff"], + schools_by_state=schools_by_state, + schools_by_type=schools_by_type, + schools_with_website=totals["with_website"], + schools_with_email=totals["with_email"], + schools_with_principal=totals["with_principal"], + total_students=totals["total_students"], + total_teachers=totals["total_teachers"], + last_crawl_time=totals["last_crawl"], + ) + + +@router.get("/{school_id}", response_model=SchoolResponse) +async def get_school(school_id: str): + """Get a single school by ID.""" + pool = await get_db_pool() + async with pool.acquire() as conn: + row = await conn.fetchrow(""" + SELECT + s.id, s.name, s.school_number, s.state, s.district, s.city, + s.postal_code, s.street, s.address_full, s.latitude, s.longitude, + s.website, s.email, s.phone, s.fax, + s.principal_name, s.principal_email, + s.student_count, s.teacher_count, + s.is_public, s.is_all_day, s.source, s.crawled_at, + s.is_active, s.created_at, s.updated_at, + st.name as school_type, st.name_short as school_type_short, st.category as school_category, + (SELECT COUNT(*) FROM school_staff ss WHERE ss.school_id = s.id AND ss.is_active = TRUE) as staff_count + FROM schools s + LEFT JOIN school_types st ON s.school_type_id = st.id + WHERE s.id = $1 + """, school_id) + + if not row: + raise HTTPException(status_code=404, detail="School not found") + + return SchoolResponse( + id=str(row["id"]), + name=row["name"], + school_number=row["school_number"], + school_type=row["school_type"], + school_type_short=row["school_type_short"], + school_category=row["school_category"], + state=row["state"], + district=row["district"], + city=row["city"], + postal_code=row["postal_code"], + street=row["street"], + address_full=row["address_full"], + latitude=row["latitude"], + longitude=row["longitude"], + website=row["website"], + email=row["email"], + phone=row["phone"], + fax=row["fax"], + principal_name=row["principal_name"], + principal_email=row["principal_email"], + student_count=row["student_count"], + teacher_count=row["teacher_count"], + is_public=row["is_public"], + is_all_day=row["is_all_day"], + staff_count=row["staff_count"], + source=row["source"], + crawled_at=row["crawled_at"], + is_active=row["is_active"], + created_at=row["created_at"], + updated_at=row["updated_at"], + ) + + +@router.post("/bulk-import", response_model=BulkImportResponse) +async def bulk_import_schools(request: BulkImportRequest): + """Bulk import schools. Updates existing schools based on school_number + state.""" + pool = await get_db_pool() + imported = 0 + updated = 0 + skipped = 0 + errors = [] + + async with pool.acquire() as conn: + # Get school type mapping + type_rows = await conn.fetch("SELECT id, name FROM school_types") + type_map = {row["name"].lower(): str(row["id"]) for row in type_rows} + + for school in request.schools: + try: + # Find school type ID + school_type_id = None + if school.school_type_raw: + school_type_id = type_map.get(school.school_type_raw.lower()) + + # Check if school exists (by school_number + state, or by name + city + state) + existing = None + if school.school_number: + existing = await conn.fetchrow( + "SELECT id FROM schools WHERE school_number = $1 AND state = $2", + school.school_number, school.state + ) + if not existing and school.city: + existing = await conn.fetchrow( + "SELECT id FROM schools WHERE LOWER(name) = LOWER($1) AND LOWER(city) = LOWER($2) AND state = $3", + school.name, school.city, school.state + ) + + if existing: + # Update existing school + await conn.execute(""" + UPDATE schools SET + name = $2, + school_type_id = COALESCE($3, school_type_id), + school_type_raw = COALESCE($4, school_type_raw), + district = COALESCE($5, district), + city = COALESCE($6, city), + postal_code = COALESCE($7, postal_code), + street = COALESCE($8, street), + address_full = COALESCE($9, address_full), + latitude = COALESCE($10, latitude), + longitude = COALESCE($11, longitude), + website = COALESCE($12, website), + email = COALESCE($13, email), + phone = COALESCE($14, phone), + fax = COALESCE($15, fax), + principal_name = COALESCE($16, principal_name), + principal_title = COALESCE($17, principal_title), + principal_email = COALESCE($18, principal_email), + principal_phone = COALESCE($19, principal_phone), + student_count = COALESCE($20, student_count), + teacher_count = COALESCE($21, teacher_count), + is_public = $22, + source = COALESCE($23, source), + source_url = COALESCE($24, source_url), + updated_at = NOW() + WHERE id = $1 + """, + existing["id"], + school.name, + school_type_id, + school.school_type_raw, + school.district, + school.city, + school.postal_code, + school.street, + school.address_full, + school.latitude, + school.longitude, + school.website, + school.email, + school.phone, + school.fax, + school.principal_name, + school.principal_title, + school.principal_email, + school.principal_phone, + school.student_count, + school.teacher_count, + school.is_public, + school.source, + school.source_url, + ) + updated += 1 + else: + # Insert new school + await conn.execute(""" + INSERT INTO schools ( + name, school_number, school_type_id, school_type_raw, + state, district, city, postal_code, street, address_full, + latitude, longitude, website, email, phone, fax, + principal_name, principal_title, principal_email, principal_phone, + student_count, teacher_count, is_public, + source, source_url, crawled_at + ) VALUES ( + $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, + $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, + $21, $22, $23, $24, $25, NOW() + ) + """, + school.name, + school.school_number, + school_type_id, + school.school_type_raw, + school.state, + school.district, + school.city, + school.postal_code, + school.street, + school.address_full, + school.latitude, + school.longitude, + school.website, + school.email, + school.phone, + school.fax, + school.principal_name, + school.principal_title, + school.principal_email, + school.principal_phone, + school.student_count, + school.teacher_count, + school.is_public, + school.source, + school.source_url, + ) + imported += 1 + + except Exception as e: + errors.append(f"Error importing {school.name}: {str(e)}") + if len(errors) > 100: + errors.append("... (more errors truncated)") + break + + return BulkImportResponse( + imported=imported, + updated=updated, + skipped=skipped, + errors=errors[:100], + ) diff --git a/backend-lehrer/llm_gateway/routes/schools_db.py b/backend-lehrer/llm_gateway/routes/schools_db.py new file mode 100644 index 0000000..ec75b93 --- /dev/null +++ b/backend-lehrer/llm_gateway/routes/schools_db.py @@ -0,0 +1,25 @@ +""" +Schools API - Database Connection. + +Shared database pool for school endpoints. +""" + +import os +from typing import Optional + +import asyncpg + +# Database connection pool +_pool: Optional[asyncpg.Pool] = None + + +async def get_db_pool() -> asyncpg.Pool: + """Get or create database connection pool.""" + global _pool + if _pool is None: + database_url = os.environ.get( + "DATABASE_URL", + "postgresql://breakpilot:breakpilot123@postgres:5432/breakpilot_db" + ) + _pool = await asyncpg.create_pool(database_url, min_size=2, max_size=10) + return _pool diff --git a/backend-lehrer/llm_gateway/routes/schools_models.py b/backend-lehrer/llm_gateway/routes/schools_models.py new file mode 100644 index 0000000..88a11a7 --- /dev/null +++ b/backend-lehrer/llm_gateway/routes/schools_models.py @@ -0,0 +1,200 @@ +""" +Schools API - Pydantic Models. + +Data models for school and school staff endpoints. +""" + +from typing import Optional, List +from datetime import datetime + +from pydantic import BaseModel, Field + + +# ============================================================================= +# School Type Models +# ============================================================================= + + +class SchoolTypeResponse(BaseModel): + """School type response model.""" + id: str + name: str + name_short: Optional[str] = None + category: Optional[str] = None + description: Optional[str] = None + + +# ============================================================================= +# School Models +# ============================================================================= + + +class SchoolBase(BaseModel): + """Base school model for creation/update.""" + name: str = Field(..., max_length=255) + school_number: Optional[str] = Field(None, max_length=20) + school_type_id: Optional[str] = None + school_type_raw: Optional[str] = None + state: str = Field(..., max_length=10) + district: Optional[str] = None + city: Optional[str] = None + postal_code: Optional[str] = None + street: Optional[str] = None + address_full: Optional[str] = None + latitude: Optional[float] = None + longitude: Optional[float] = None + website: Optional[str] = None + email: Optional[str] = None + phone: Optional[str] = None + fax: Optional[str] = None + principal_name: Optional[str] = None + principal_title: Optional[str] = None + principal_email: Optional[str] = None + principal_phone: Optional[str] = None + secretary_name: Optional[str] = None + secretary_email: Optional[str] = None + secretary_phone: Optional[str] = None + student_count: Optional[int] = None + teacher_count: Optional[int] = None + class_count: Optional[int] = None + founded_year: Optional[int] = None + is_public: bool = True + is_all_day: Optional[bool] = None + has_inclusion: Optional[bool] = None + languages: Optional[List[str]] = None + specializations: Optional[List[str]] = None + source: Optional[str] = None + source_url: Optional[str] = None + + +class SchoolCreate(SchoolBase): + """School creation model.""" + pass + + +class SchoolUpdate(BaseModel): + """School update model (all fields optional).""" + name: Optional[str] = Field(None, max_length=255) + school_number: Optional[str] = None + school_type_id: Optional[str] = None + state: Optional[str] = None + district: Optional[str] = None + city: Optional[str] = None + postal_code: Optional[str] = None + street: Optional[str] = None + website: Optional[str] = None + email: Optional[str] = None + phone: Optional[str] = None + principal_name: Optional[str] = None + student_count: Optional[int] = None + teacher_count: Optional[int] = None + is_active: Optional[bool] = None + + +class SchoolResponse(BaseModel): + """School response model.""" + id: str + name: str + school_number: Optional[str] = None + school_type: Optional[str] = None + school_type_short: Optional[str] = None + school_category: Optional[str] = None + state: str + district: Optional[str] = None + city: Optional[str] = None + postal_code: Optional[str] = None + street: Optional[str] = None + address_full: Optional[str] = None + latitude: Optional[float] = None + longitude: Optional[float] = None + website: Optional[str] = None + email: Optional[str] = None + phone: Optional[str] = None + fax: Optional[str] = None + principal_name: Optional[str] = None + principal_email: Optional[str] = None + student_count: Optional[int] = None + teacher_count: Optional[int] = None + is_public: bool = True + is_all_day: Optional[bool] = None + staff_count: int = 0 + source: Optional[str] = None + crawled_at: Optional[datetime] = None + is_active: bool = True + created_at: datetime + updated_at: datetime + + +class SchoolsListResponse(BaseModel): + """List response with pagination info.""" + schools: List[SchoolResponse] + total: int + page: int + page_size: int + + +class SchoolStatsResponse(BaseModel): + """School statistics response.""" + total_schools: int + total_staff: int + schools_by_state: dict + schools_by_type: dict + schools_with_website: int + schools_with_email: int + schools_with_principal: int + total_students: int + total_teachers: int + last_crawl_time: Optional[datetime] = None + + +class BulkImportRequest(BaseModel): + """Bulk import request.""" + schools: List[SchoolCreate] + + +class BulkImportResponse(BaseModel): + """Bulk import response.""" + imported: int + updated: int + skipped: int + errors: List[str] + + +# ============================================================================= +# School Staff Models +# ============================================================================= + + +class SchoolStaffBase(BaseModel): + """Base school staff model.""" + first_name: Optional[str] = None + last_name: str + full_name: Optional[str] = None + title: Optional[str] = None + position: Optional[str] = None + position_type: Optional[str] = None + subjects: Optional[List[str]] = None + email: Optional[str] = None + phone: Optional[str] = None + + +class SchoolStaffCreate(SchoolStaffBase): + """School staff creation model.""" + school_id: str + + +class SchoolStaffResponse(SchoolStaffBase): + """School staff response model.""" + id: str + school_id: str + school_name: Optional[str] = None + profile_url: Optional[str] = None + photo_url: Optional[str] = None + is_active: bool = True + created_at: datetime + + +class SchoolStaffListResponse(BaseModel): + """Staff list response.""" + staff: List[SchoolStaffResponse] + total: int diff --git a/backend-lehrer/llm_gateway/routes/schools_staff.py b/backend-lehrer/llm_gateway/routes/schools_staff.py new file mode 100644 index 0000000..c26dc9f --- /dev/null +++ b/backend-lehrer/llm_gateway/routes/schools_staff.py @@ -0,0 +1,233 @@ +""" +Schools API - Staff Routes. + +CRUD and search endpoints for school staff members. +""" + +import logging +from typing import Optional + +from fastapi import APIRouter, HTTPException, Query + +from .schools_db import get_db_pool +from .schools_models import ( + SchoolStaffBase, + SchoolStaffResponse, + SchoolStaffListResponse, +) + +logger = logging.getLogger(__name__) + +router = APIRouter(tags=["schools"]) + + +# ============================================================================= +# School Staff Endpoints +# ============================================================================= + + +@router.get("/{school_id}/staff", response_model=SchoolStaffListResponse) +async def get_school_staff(school_id: str): + """Get staff members for a school.""" + pool = await get_db_pool() + async with pool.acquire() as conn: + rows = await conn.fetch(""" + SELECT + ss.id, ss.school_id, ss.first_name, ss.last_name, ss.full_name, + ss.title, ss.position, ss.position_type, ss.subjects, + ss.email, ss.phone, ss.profile_url, ss.photo_url, + ss.is_active, ss.created_at, + s.name as school_name + FROM school_staff ss + JOIN schools s ON ss.school_id = s.id + WHERE ss.school_id = $1 AND ss.is_active = TRUE + ORDER BY + CASE ss.position_type + WHEN 'principal' THEN 1 + WHEN 'vice_principal' THEN 2 + WHEN 'secretary' THEN 3 + ELSE 4 + END, + ss.last_name + """, school_id) + + staff = [ + SchoolStaffResponse( + id=str(row["id"]), + school_id=str(row["school_id"]), + school_name=row["school_name"], + first_name=row["first_name"], + last_name=row["last_name"], + full_name=row["full_name"], + title=row["title"], + position=row["position"], + position_type=row["position_type"], + subjects=row["subjects"], + email=row["email"], + phone=row["phone"], + profile_url=row["profile_url"], + photo_url=row["photo_url"], + is_active=row["is_active"], + created_at=row["created_at"], + ) + for row in rows + ] + + return SchoolStaffListResponse( + staff=staff, + total=len(staff), + ) + + +@router.post("/{school_id}/staff", response_model=SchoolStaffResponse) +async def create_school_staff(school_id: str, staff: SchoolStaffBase): + """Add a staff member to a school.""" + pool = await get_db_pool() + async with pool.acquire() as conn: + # Verify school exists + school = await conn.fetchrow("SELECT name FROM schools WHERE id = $1", school_id) + if not school: + raise HTTPException(status_code=404, detail="School not found") + + # Create full name + full_name = staff.full_name + if not full_name: + parts = [] + if staff.title: + parts.append(staff.title) + if staff.first_name: + parts.append(staff.first_name) + parts.append(staff.last_name) + full_name = " ".join(parts) + + row = await conn.fetchrow(""" + INSERT INTO school_staff ( + school_id, first_name, last_name, full_name, title, + position, position_type, subjects, email, phone + ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) + RETURNING id, created_at + """, + school_id, + staff.first_name, + staff.last_name, + full_name, + staff.title, + staff.position, + staff.position_type, + staff.subjects, + staff.email, + staff.phone, + ) + + return SchoolStaffResponse( + id=str(row["id"]), + school_id=school_id, + school_name=school["name"], + first_name=staff.first_name, + last_name=staff.last_name, + full_name=full_name, + title=staff.title, + position=staff.position, + position_type=staff.position_type, + subjects=staff.subjects, + email=staff.email, + phone=staff.phone, + is_active=True, + created_at=row["created_at"], + ) + + +# ============================================================================= +# Search Endpoints +# ============================================================================= + + +@router.get("/search/staff", response_model=SchoolStaffListResponse) +async def search_school_staff( + q: Optional[str] = Query(None, description="Search query"), + state: Optional[str] = Query(None, description="Filter by state"), + position_type: Optional[str] = Query(None, description="Filter by position type"), + has_email: Optional[bool] = Query(None, description="Only staff with email"), + page: int = Query(1, ge=1), + page_size: int = Query(50, ge=1, le=200), +): + """Search school staff across all schools.""" + pool = await get_db_pool() + async with pool.acquire() as conn: + conditions = ["ss.is_active = TRUE", "s.is_active = TRUE"] + params = [] + param_idx = 1 + + if q: + conditions.append(f""" + (LOWER(ss.full_name) LIKE LOWER(${param_idx}) + OR LOWER(ss.last_name) LIKE LOWER(${param_idx}) + OR LOWER(s.name) LIKE LOWER(${param_idx})) + """) + params.append(f"%{q}%") + param_idx += 1 + + if state: + conditions.append(f"s.state = ${param_idx}") + params.append(state.upper()) + param_idx += 1 + + if position_type: + conditions.append(f"ss.position_type = ${param_idx}") + params.append(position_type) + param_idx += 1 + + if has_email is not None and has_email: + conditions.append("ss.email IS NOT NULL") + + where_clause = " AND ".join(conditions) + + # Count total + total = await conn.fetchval(f""" + SELECT COUNT(*) FROM school_staff ss + JOIN schools s ON ss.school_id = s.id + WHERE {where_clause} + """, *params) + + # Fetch staff + offset = (page - 1) * page_size + rows = await conn.fetch(f""" + SELECT + ss.id, ss.school_id, ss.first_name, ss.last_name, ss.full_name, + ss.title, ss.position, ss.position_type, ss.subjects, + ss.email, ss.phone, ss.profile_url, ss.photo_url, + ss.is_active, ss.created_at, + s.name as school_name + FROM school_staff ss + JOIN schools s ON ss.school_id = s.id + WHERE {where_clause} + ORDER BY ss.last_name, ss.first_name + LIMIT ${param_idx} OFFSET ${param_idx + 1} + """, *params, page_size, offset) + + staff = [ + SchoolStaffResponse( + id=str(row["id"]), + school_id=str(row["school_id"]), + school_name=row["school_name"], + first_name=row["first_name"], + last_name=row["last_name"], + full_name=row["full_name"], + title=row["title"], + position=row["position"], + position_type=row["position_type"], + subjects=row["subjects"], + email=row["email"], + phone=row["phone"], + profile_url=row["profile_url"], + photo_url=row["photo_url"], + is_active=row["is_active"], + created_at=row["created_at"], + ) + for row in rows + ] + + return SchoolStaffListResponse( + staff=staff, + total=total, + ) diff --git a/backend-lehrer/messenger_api.py b/backend-lehrer/messenger_api.py index 3540496..ca5962f 100644 --- a/backend-lehrer/messenger_api.py +++ b/backend-lehrer/messenger_api.py @@ -1,840 +1,21 @@ """ -BreakPilot Messenger API +BreakPilot Messenger API — Barrel Re-export. -Stellt Endpoints fuer: -- Kontaktverwaltung (CRUD) -- Konversationen -- Nachrichten -- CSV-Import fuer Kontakte -- Gruppenmanagement +Stellt Endpoints fuer Kontakte, Konversationen, Nachrichten, +CSV-Import, Gruppenmanagement und Templates bereit. -DSGVO-konform: Alle Daten werden lokal gespeichert. +Split into: + - messenger_models.py: Pydantic models + - messenger_helpers.py: JSON file storage & default templates + - messenger_contacts.py: Contact CRUD & CSV import/export + - messenger_conversations.py: Conversations, messages, groups, templates, stats """ -import os -import csv -import uuid -import json -from io import StringIO -from datetime import datetime -from typing import List, Optional, Dict, Any -from pathlib import Path +from fastapi import APIRouter -from fastapi import APIRouter, HTTPException, UploadFile, File, Query -from pydantic import BaseModel, Field +from messenger_contacts import router as _contacts_router +from messenger_conversations import router as _conversations_router router = APIRouter(prefix="/api/messenger", tags=["Messenger"]) - -# Datenspeicherung (JSON-basiert fuer einfache Persistenz) -DATA_DIR = Path(__file__).parent / "data" / "messenger" -DATA_DIR.mkdir(parents=True, exist_ok=True) - -CONTACTS_FILE = DATA_DIR / "contacts.json" -CONVERSATIONS_FILE = DATA_DIR / "conversations.json" -MESSAGES_FILE = DATA_DIR / "messages.json" -GROUPS_FILE = DATA_DIR / "groups.json" - - -# ========================================== -# PYDANTIC MODELS -# ========================================== - -class ContactBase(BaseModel): - """Basis-Modell fuer Kontakte.""" - name: str = Field(..., min_length=1, max_length=200) - email: Optional[str] = None - phone: Optional[str] = None - role: str = Field(default="parent", description="parent, teacher, staff, student") - student_name: Optional[str] = Field(None, description="Name des zugehoerigen Schuelers") - class_name: Optional[str] = Field(None, description="Klasse z.B. 10a") - notes: Optional[str] = None - tags: List[str] = Field(default_factory=list) - matrix_id: Optional[str] = Field(None, description="Matrix-ID z.B. @user:matrix.org") - preferred_channel: str = Field(default="email", description="email, matrix, pwa") - - -class ContactCreate(ContactBase): - """Model fuer neuen Kontakt.""" - pass - - -class Contact(ContactBase): - """Vollstaendiger Kontakt mit ID.""" - id: str - created_at: str - updated_at: str - online: bool = False - last_seen: Optional[str] = None - - -class ContactUpdate(BaseModel): - """Update-Model fuer Kontakte.""" - name: Optional[str] = None - email: Optional[str] = None - phone: Optional[str] = None - role: Optional[str] = None - student_name: Optional[str] = None - class_name: Optional[str] = None - notes: Optional[str] = None - tags: Optional[List[str]] = None - matrix_id: Optional[str] = None - preferred_channel: Optional[str] = None - - -class GroupBase(BaseModel): - """Basis-Modell fuer Gruppen.""" - name: str = Field(..., min_length=1, max_length=100) - description: Optional[str] = None - group_type: str = Field(default="class", description="class, department, custom") - - -class GroupCreate(GroupBase): - """Model fuer neue Gruppe.""" - member_ids: List[str] = Field(default_factory=list) - - -class Group(GroupBase): - """Vollstaendige Gruppe mit ID.""" - id: str - member_ids: List[str] = [] - created_at: str - updated_at: str - - -class MessageBase(BaseModel): - """Basis-Modell fuer Nachrichten.""" - content: str = Field(..., min_length=1) - content_type: str = Field(default="text", description="text, file, image") - file_url: Optional[str] = None - send_email: bool = Field(default=False, description="Nachricht auch per Email senden") - - -class MessageCreate(MessageBase): - """Model fuer neue Nachricht.""" - conversation_id: str - - -class Message(MessageBase): - """Vollstaendige Nachricht mit ID.""" - id: str - conversation_id: str - sender_id: str # "self" fuer eigene Nachrichten - timestamp: str - read: bool = False - read_at: Optional[str] = None - email_sent: bool = False - email_sent_at: Optional[str] = None - email_error: Optional[str] = None - - -class ConversationBase(BaseModel): - """Basis-Modell fuer Konversationen.""" - name: Optional[str] = None - is_group: bool = False - - -class Conversation(ConversationBase): - """Vollstaendige Konversation mit ID.""" - id: str - participant_ids: List[str] = [] - group_id: Optional[str] = None - created_at: str - updated_at: str - last_message: Optional[str] = None - last_message_time: Optional[str] = None - unread_count: int = 0 - - -class CSVImportResult(BaseModel): - """Ergebnis eines CSV-Imports.""" - imported: int - skipped: int - errors: List[str] - contacts: List[Contact] - - -# ========================================== -# DATA HELPERS -# ========================================== - -def load_json(filepath: Path) -> List[Dict]: - """Laedt JSON-Daten aus Datei.""" - if not filepath.exists(): - return [] - try: - with open(filepath, "r", encoding="utf-8") as f: - return json.load(f) - except Exception: - return [] - - -def save_json(filepath: Path, data: List[Dict]): - """Speichert Daten in JSON-Datei.""" - with open(filepath, "w", encoding="utf-8") as f: - json.dump(data, f, ensure_ascii=False, indent=2) - - -def get_contacts() -> List[Dict]: - return load_json(CONTACTS_FILE) - - -def save_contacts(contacts: List[Dict]): - save_json(CONTACTS_FILE, contacts) - - -def get_conversations() -> List[Dict]: - return load_json(CONVERSATIONS_FILE) - - -def save_conversations(conversations: List[Dict]): - save_json(CONVERSATIONS_FILE, conversations) - - -def get_messages() -> List[Dict]: - return load_json(MESSAGES_FILE) - - -def save_messages(messages: List[Dict]): - save_json(MESSAGES_FILE, messages) - - -def get_groups() -> List[Dict]: - return load_json(GROUPS_FILE) - - -def save_groups(groups: List[Dict]): - save_json(GROUPS_FILE, groups) - - -# ========================================== -# CONTACTS ENDPOINTS -# ========================================== - -@router.get("/contacts", response_model=List[Contact]) -async def list_contacts( - role: Optional[str] = Query(None, description="Filter by role"), - class_name: Optional[str] = Query(None, description="Filter by class"), - search: Optional[str] = Query(None, description="Search in name/email") -): - """Listet alle Kontakte auf.""" - contacts = get_contacts() - - # Filter anwenden - if role: - contacts = [c for c in contacts if c.get("role") == role] - if class_name: - contacts = [c for c in contacts if c.get("class_name") == class_name] - if search: - search_lower = search.lower() - contacts = [c for c in contacts if - search_lower in c.get("name", "").lower() or - search_lower in (c.get("email") or "").lower() or - search_lower in (c.get("student_name") or "").lower()] - - return contacts - - -@router.post("/contacts", response_model=Contact) -async def create_contact(contact: ContactCreate): - """Erstellt einen neuen Kontakt.""" - contacts = get_contacts() - - # Pruefen ob Email bereits existiert - if contact.email: - existing = [c for c in contacts if c.get("email") == contact.email] - if existing: - raise HTTPException(status_code=400, detail="Kontakt mit dieser Email existiert bereits") - - now = datetime.utcnow().isoformat() - new_contact = { - "id": str(uuid.uuid4()), - "created_at": now, - "updated_at": now, - "online": False, - "last_seen": None, - **contact.dict() - } - - contacts.append(new_contact) - save_contacts(contacts) - - return new_contact - - -@router.get("/contacts/{contact_id}", response_model=Contact) -async def get_contact(contact_id: str): - """Ruft einen einzelnen Kontakt ab.""" - contacts = get_contacts() - contact = next((c for c in contacts if c["id"] == contact_id), None) - - if not contact: - raise HTTPException(status_code=404, detail="Kontakt nicht gefunden") - - return contact - - -@router.put("/contacts/{contact_id}", response_model=Contact) -async def update_contact(contact_id: str, update: ContactUpdate): - """Aktualisiert einen Kontakt.""" - contacts = get_contacts() - contact_idx = next((i for i, c in enumerate(contacts) if c["id"] == contact_id), None) - - if contact_idx is None: - raise HTTPException(status_code=404, detail="Kontakt nicht gefunden") - - update_data = update.dict(exclude_unset=True) - contacts[contact_idx].update(update_data) - contacts[contact_idx]["updated_at"] = datetime.utcnow().isoformat() - - save_contacts(contacts) - return contacts[contact_idx] - - -@router.delete("/contacts/{contact_id}") -async def delete_contact(contact_id: str): - """Loescht einen Kontakt.""" - contacts = get_contacts() - contacts = [c for c in contacts if c["id"] != contact_id] - save_contacts(contacts) - - return {"status": "deleted", "id": contact_id} - - -@router.post("/contacts/import", response_model=CSVImportResult) -async def import_contacts_csv(file: UploadFile = File(...)): - """ - Importiert Kontakte aus einer CSV-Datei. - - Erwartete Spalten: - - name (required) - - email - - phone - - role (parent/teacher/staff/student) - - student_name - - class_name - - notes - - tags (komma-separiert) - """ - if not file.filename.endswith('.csv'): - raise HTTPException(status_code=400, detail="Nur CSV-Dateien werden unterstuetzt") - - content = await file.read() - try: - text = content.decode('utf-8') - except UnicodeDecodeError: - text = content.decode('latin-1') - - contacts = get_contacts() - existing_emails = {c.get("email") for c in contacts if c.get("email")} - - imported = [] - skipped = 0 - errors = [] - - reader = csv.DictReader(StringIO(text), delimiter=';') # Deutsche CSV meist mit Semikolon - if not reader.fieldnames or 'name' not in [f.lower() for f in reader.fieldnames]: - # Versuche mit Komma - reader = csv.DictReader(StringIO(text), delimiter=',') - - for row_num, row in enumerate(reader, start=2): - try: - # Normalisiere Spaltennamen - row = {k.lower().strip(): v.strip() if v else "" for k, v in row.items()} - - name = row.get('name') or row.get('kontakt') or row.get('elternname') - if not name: - errors.append(f"Zeile {row_num}: Name fehlt") - skipped += 1 - continue - - email = row.get('email') or row.get('e-mail') or row.get('mail') - if email and email in existing_emails: - errors.append(f"Zeile {row_num}: Email {email} existiert bereits") - skipped += 1 - continue - - now = datetime.utcnow().isoformat() - tags_str = row.get('tags') or row.get('kategorien') or "" - tags = [t.strip() for t in tags_str.split(',') if t.strip()] - - # Matrix-ID und preferred_channel auslesen - matrix_id = row.get('matrix_id') or row.get('matrix') or None - preferred_channel = row.get('preferred_channel') or row.get('kanal') or "email" - if preferred_channel not in ["email", "matrix", "pwa"]: - preferred_channel = "email" - - new_contact = { - "id": str(uuid.uuid4()), - "name": name, - "email": email if email else None, - "phone": row.get('phone') or row.get('telefon') or row.get('tel'), - "role": row.get('role') or row.get('rolle') or "parent", - "student_name": row.get('student_name') or row.get('schueler') or row.get('kind'), - "class_name": row.get('class_name') or row.get('klasse'), - "notes": row.get('notes') or row.get('notizen') or row.get('bemerkungen'), - "tags": tags, - "matrix_id": matrix_id if matrix_id else None, - "preferred_channel": preferred_channel, - "created_at": now, - "updated_at": now, - "online": False, - "last_seen": None - } - - contacts.append(new_contact) - imported.append(new_contact) - if email: - existing_emails.add(email) - - except Exception as e: - errors.append(f"Zeile {row_num}: {str(e)}") - skipped += 1 - - save_contacts(contacts) - - return CSVImportResult( - imported=len(imported), - skipped=skipped, - errors=errors[:20], # Maximal 20 Fehler zurueckgeben - contacts=imported - ) - - -@router.get("/contacts/export/csv") -async def export_contacts_csv(): - """Exportiert alle Kontakte als CSV.""" - from fastapi.responses import StreamingResponse - - contacts = get_contacts() - - output = StringIO() - fieldnames = ['name', 'email', 'phone', 'role', 'student_name', 'class_name', 'notes', 'tags', 'matrix_id', 'preferred_channel'] - writer = csv.DictWriter(output, fieldnames=fieldnames, delimiter=';') - writer.writeheader() - - for contact in contacts: - writer.writerow({ - 'name': contact.get('name', ''), - 'email': contact.get('email', ''), - 'phone': contact.get('phone', ''), - 'role': contact.get('role', ''), - 'student_name': contact.get('student_name', ''), - 'class_name': contact.get('class_name', ''), - 'notes': contact.get('notes', ''), - 'tags': ','.join(contact.get('tags', [])), - 'matrix_id': contact.get('matrix_id', ''), - 'preferred_channel': contact.get('preferred_channel', 'email') - }) - - output.seek(0) - - return StreamingResponse( - iter([output.getvalue()]), - media_type="text/csv", - headers={"Content-Disposition": "attachment; filename=kontakte.csv"} - ) - - -# ========================================== -# GROUPS ENDPOINTS -# ========================================== - -@router.get("/groups", response_model=List[Group]) -async def list_groups(): - """Listet alle Gruppen auf.""" - return get_groups() - - -@router.post("/groups", response_model=Group) -async def create_group(group: GroupCreate): - """Erstellt eine neue Gruppe.""" - groups = get_groups() - - now = datetime.utcnow().isoformat() - new_group = { - "id": str(uuid.uuid4()), - "created_at": now, - "updated_at": now, - **group.dict() - } - - groups.append(new_group) - save_groups(groups) - - return new_group - - -@router.put("/groups/{group_id}/members") -async def update_group_members(group_id: str, member_ids: List[str]): - """Aktualisiert die Mitglieder einer Gruppe.""" - groups = get_groups() - group_idx = next((i for i, g in enumerate(groups) if g["id"] == group_id), None) - - if group_idx is None: - raise HTTPException(status_code=404, detail="Gruppe nicht gefunden") - - groups[group_idx]["member_ids"] = member_ids - groups[group_idx]["updated_at"] = datetime.utcnow().isoformat() - - save_groups(groups) - return groups[group_idx] - - -@router.delete("/groups/{group_id}") -async def delete_group(group_id: str): - """Loescht eine Gruppe.""" - groups = get_groups() - groups = [g for g in groups if g["id"] != group_id] - save_groups(groups) - - return {"status": "deleted", "id": group_id} - - -# ========================================== -# CONVERSATIONS ENDPOINTS -# ========================================== - -@router.get("/conversations", response_model=List[Conversation]) -async def list_conversations(): - """Listet alle Konversationen auf.""" - conversations = get_conversations() - messages = get_messages() - - # Unread count und letzte Nachricht hinzufuegen - for conv in conversations: - conv_messages = [m for m in messages if m.get("conversation_id") == conv["id"]] - conv["unread_count"] = len([m for m in conv_messages if not m.get("read") and m.get("sender_id") != "self"]) - - if conv_messages: - last_msg = max(conv_messages, key=lambda m: m.get("timestamp", "")) - conv["last_message"] = last_msg.get("content", "")[:50] - conv["last_message_time"] = last_msg.get("timestamp") - - # Nach letzter Nachricht sortieren - conversations.sort(key=lambda c: c.get("last_message_time") or "", reverse=True) - - return conversations - - -@router.post("/conversations", response_model=Conversation) -async def create_conversation(contact_id: Optional[str] = None, group_id: Optional[str] = None): - """ - Erstellt eine neue Konversation. - Entweder mit einem Kontakt (1:1) oder einer Gruppe. - """ - conversations = get_conversations() - - if not contact_id and not group_id: - raise HTTPException(status_code=400, detail="Entweder contact_id oder group_id erforderlich") - - # Pruefen ob Konversation bereits existiert - if contact_id: - existing = next((c for c in conversations - if not c.get("is_group") and contact_id in c.get("participant_ids", [])), None) - if existing: - return existing - - now = datetime.utcnow().isoformat() - - if group_id: - groups = get_groups() - group = next((g for g in groups if g["id"] == group_id), None) - if not group: - raise HTTPException(status_code=404, detail="Gruppe nicht gefunden") - - new_conv = { - "id": str(uuid.uuid4()), - "name": group.get("name"), - "is_group": True, - "participant_ids": group.get("member_ids", []), - "group_id": group_id, - "created_at": now, - "updated_at": now, - "last_message": None, - "last_message_time": None, - "unread_count": 0 - } - else: - contacts = get_contacts() - contact = next((c for c in contacts if c["id"] == contact_id), None) - if not contact: - raise HTTPException(status_code=404, detail="Kontakt nicht gefunden") - - new_conv = { - "id": str(uuid.uuid4()), - "name": contact.get("name"), - "is_group": False, - "participant_ids": [contact_id], - "group_id": None, - "created_at": now, - "updated_at": now, - "last_message": None, - "last_message_time": None, - "unread_count": 0 - } - - conversations.append(new_conv) - save_conversations(conversations) - - return new_conv - - -@router.get("/conversations/{conversation_id}", response_model=Conversation) -async def get_conversation(conversation_id: str): - """Ruft eine Konversation ab.""" - conversations = get_conversations() - conv = next((c for c in conversations if c["id"] == conversation_id), None) - - if not conv: - raise HTTPException(status_code=404, detail="Konversation nicht gefunden") - - return conv - - -@router.delete("/conversations/{conversation_id}") -async def delete_conversation(conversation_id: str): - """Loescht eine Konversation und alle zugehoerigen Nachrichten.""" - conversations = get_conversations() - conversations = [c for c in conversations if c["id"] != conversation_id] - save_conversations(conversations) - - messages = get_messages() - messages = [m for m in messages if m.get("conversation_id") != conversation_id] - save_messages(messages) - - return {"status": "deleted", "id": conversation_id} - - -# ========================================== -# MESSAGES ENDPOINTS -# ========================================== - -@router.get("/conversations/{conversation_id}/messages", response_model=List[Message]) -async def list_messages( - conversation_id: str, - limit: int = Query(50, ge=1, le=200), - before: Optional[str] = Query(None, description="Load messages before this timestamp") -): - """Ruft Nachrichten einer Konversation ab.""" - messages = get_messages() - conv_messages = [m for m in messages if m.get("conversation_id") == conversation_id] - - if before: - conv_messages = [m for m in conv_messages if m.get("timestamp", "") < before] - - # Nach Zeit sortieren (neueste zuletzt) - conv_messages.sort(key=lambda m: m.get("timestamp", "")) - - return conv_messages[-limit:] - - -@router.post("/conversations/{conversation_id}/messages", response_model=Message) -async def send_message(conversation_id: str, message: MessageBase): - """ - Sendet eine Nachricht in einer Konversation. - - Wenn send_email=True und der Kontakt eine Email-Adresse hat, - wird die Nachricht auch per Email versendet. - """ - conversations = get_conversations() - conv = next((c for c in conversations if c["id"] == conversation_id), None) - - if not conv: - raise HTTPException(status_code=404, detail="Konversation nicht gefunden") - - now = datetime.utcnow().isoformat() - - new_message = { - "id": str(uuid.uuid4()), - "conversation_id": conversation_id, - "sender_id": "self", - "timestamp": now, - "read": True, - "read_at": now, - "email_sent": False, - "email_sent_at": None, - "email_error": None, - **message.dict() - } - - # Email-Versand wenn gewuenscht - if message.send_email and not conv.get("is_group"): - # Kontakt laden - participant_ids = conv.get("participant_ids", []) - if participant_ids: - contacts = get_contacts() - contact = next((c for c in contacts if c["id"] == participant_ids[0]), None) - - if contact and contact.get("email"): - try: - from email_service import email_service - - result = email_service.send_messenger_notification( - to_email=contact["email"], - to_name=contact.get("name", ""), - sender_name="BreakPilot Lehrer", # TODO: Aktuellen User-Namen verwenden - message_content=message.content - ) - - if result.success: - new_message["email_sent"] = True - new_message["email_sent_at"] = result.sent_at - else: - new_message["email_error"] = result.error - - except Exception as e: - new_message["email_error"] = str(e) - - messages = get_messages() - messages.append(new_message) - save_messages(messages) - - # Konversation aktualisieren - conv_idx = next(i for i, c in enumerate(conversations) if c["id"] == conversation_id) - conversations[conv_idx]["last_message"] = message.content[:50] - conversations[conv_idx]["last_message_time"] = now - conversations[conv_idx]["updated_at"] = now - save_conversations(conversations) - - return new_message - - -@router.put("/messages/{message_id}/read") -async def mark_message_read(message_id: str): - """Markiert eine Nachricht als gelesen.""" - messages = get_messages() - msg_idx = next((i for i, m in enumerate(messages) if m["id"] == message_id), None) - - if msg_idx is None: - raise HTTPException(status_code=404, detail="Nachricht nicht gefunden") - - messages[msg_idx]["read"] = True - messages[msg_idx]["read_at"] = datetime.utcnow().isoformat() - save_messages(messages) - - return {"status": "read", "id": message_id} - - -@router.put("/conversations/{conversation_id}/read-all") -async def mark_all_messages_read(conversation_id: str): - """Markiert alle Nachrichten einer Konversation als gelesen.""" - messages = get_messages() - now = datetime.utcnow().isoformat() - - for msg in messages: - if msg.get("conversation_id") == conversation_id and not msg.get("read"): - msg["read"] = True - msg["read_at"] = now - - save_messages(messages) - - return {"status": "all_read", "conversation_id": conversation_id} - - -# ========================================== -# TEMPLATES ENDPOINTS -# ========================================== - -DEFAULT_TEMPLATES = [ - { - "id": "1", - "name": "Terminbestaetigung", - "content": "Vielen Dank fuer Ihre Terminanfrage. Ich bestaetige den Termin am [DATUM] um [UHRZEIT]. Bitte geben Sie mir Bescheid, falls sich etwas aendern sollte.", - "category": "termin" - }, - { - "id": "2", - "name": "Hausaufgaben-Info", - "content": "Zur Information: Die Hausaufgaben fuer diese Woche umfassen [THEMA]. Abgabetermin ist [DATUM]. Bei Fragen stehe ich gerne zur Verfuegung.", - "category": "hausaufgaben" - }, - { - "id": "3", - "name": "Entschuldigung bestaetigen", - "content": "Ich bestaetige den Erhalt der Entschuldigung fuer [NAME] am [DATUM]. Die Fehlzeiten wurden entsprechend vermerkt.", - "category": "entschuldigung" - }, - { - "id": "4", - "name": "Gespraechsanfrage", - "content": "Ich wuerde gerne einen Termin fuer ein Gespraech mit Ihnen vereinbaren, um [THEMA] zu besprechen. Waeren Sie am [DATUM] um [UHRZEIT] verfuegbar?", - "category": "gespraech" - }, - { - "id": "5", - "name": "Krankmeldung bestaetigen", - "content": "Vielen Dank fuer Ihre Krankmeldung fuer [NAME]. Ich wuensche gute Besserung. Bitte reichen Sie eine schriftliche Entschuldigung nach, sobald Ihr Kind wieder gesund ist.", - "category": "krankmeldung" - } -] - - -@router.get("/templates") -async def list_templates(): - """Listet alle Nachrichtenvorlagen auf.""" - templates_file = DATA_DIR / "templates.json" - if templates_file.exists(): - templates = load_json(templates_file) - else: - templates = DEFAULT_TEMPLATES - save_json(templates_file, templates) - - return templates - - -@router.post("/templates") -async def create_template(name: str, content: str, category: str = "custom"): - """Erstellt eine neue Vorlage.""" - templates_file = DATA_DIR / "templates.json" - templates = load_json(templates_file) if templates_file.exists() else DEFAULT_TEMPLATES.copy() - - new_template = { - "id": str(uuid.uuid4()), - "name": name, - "content": content, - "category": category - } - - templates.append(new_template) - save_json(templates_file, templates) - - return new_template - - -@router.delete("/templates/{template_id}") -async def delete_template(template_id: str): - """Loescht eine Vorlage.""" - templates_file = DATA_DIR / "templates.json" - templates = load_json(templates_file) if templates_file.exists() else DEFAULT_TEMPLATES.copy() - - templates = [t for t in templates if t["id"] != template_id] - save_json(templates_file, templates) - - return {"status": "deleted", "id": template_id} - - -# ========================================== -# STATS ENDPOINT -# ========================================== - -@router.get("/stats") -async def get_messenger_stats(): - """Gibt Statistiken zum Messenger zurueck.""" - contacts = get_contacts() - conversations = get_conversations() - messages = get_messages() - groups = get_groups() - - unread_total = sum(1 for m in messages if not m.get("read") and m.get("sender_id") != "self") - - return { - "total_contacts": len(contacts), - "total_groups": len(groups), - "total_conversations": len(conversations), - "total_messages": len(messages), - "unread_messages": unread_total, - "contacts_by_role": { - role: len([c for c in contacts if c.get("role") == role]) - for role in set(c.get("role", "parent") for c in contacts) - } - } +router.include_router(_contacts_router) +router.include_router(_conversations_router) diff --git a/backend-lehrer/messenger_contacts.py b/backend-lehrer/messenger_contacts.py new file mode 100644 index 0000000..ae9625c --- /dev/null +++ b/backend-lehrer/messenger_contacts.py @@ -0,0 +1,251 @@ +""" +Messenger API - Contact Routes. + +CRUD, CSV import/export for contacts. +""" + +import csv +import uuid +from io import StringIO +from datetime import datetime +from typing import List, Optional + +from fastapi import APIRouter, HTTPException, UploadFile, File, Query +from fastapi.responses import StreamingResponse + +from messenger_models import ( + Contact, + ContactCreate, + ContactUpdate, + CSVImportResult, +) +from messenger_helpers import get_contacts, save_contacts + +router = APIRouter(tags=["Messenger"]) + + +# ========================================== +# CONTACTS ENDPOINTS +# ========================================== + +@router.get("/contacts", response_model=List[Contact]) +async def list_contacts( + role: Optional[str] = Query(None, description="Filter by role"), + class_name: Optional[str] = Query(None, description="Filter by class"), + search: Optional[str] = Query(None, description="Search in name/email") +): + """Listet alle Kontakte auf.""" + contacts = get_contacts() + + # Filter anwenden + if role: + contacts = [c for c in contacts if c.get("role") == role] + if class_name: + contacts = [c for c in contacts if c.get("class_name") == class_name] + if search: + search_lower = search.lower() + contacts = [c for c in contacts if + search_lower in c.get("name", "").lower() or + search_lower in (c.get("email") or "").lower() or + search_lower in (c.get("student_name") or "").lower()] + + return contacts + + +@router.post("/contacts", response_model=Contact) +async def create_contact(contact: ContactCreate): + """Erstellt einen neuen Kontakt.""" + contacts = get_contacts() + + # Pruefen ob Email bereits existiert + if contact.email: + existing = [c for c in contacts if c.get("email") == contact.email] + if existing: + raise HTTPException(status_code=400, detail="Kontakt mit dieser Email existiert bereits") + + now = datetime.utcnow().isoformat() + new_contact = { + "id": str(uuid.uuid4()), + "created_at": now, + "updated_at": now, + "online": False, + "last_seen": None, + **contact.dict() + } + + contacts.append(new_contact) + save_contacts(contacts) + + return new_contact + + +@router.get("/contacts/{contact_id}", response_model=Contact) +async def get_contact(contact_id: str): + """Ruft einen einzelnen Kontakt ab.""" + contacts = get_contacts() + contact = next((c for c in contacts if c["id"] == contact_id), None) + + if not contact: + raise HTTPException(status_code=404, detail="Kontakt nicht gefunden") + + return contact + + +@router.put("/contacts/{contact_id}", response_model=Contact) +async def update_contact(contact_id: str, update: ContactUpdate): + """Aktualisiert einen Kontakt.""" + contacts = get_contacts() + contact_idx = next((i for i, c in enumerate(contacts) if c["id"] == contact_id), None) + + if contact_idx is None: + raise HTTPException(status_code=404, detail="Kontakt nicht gefunden") + + update_data = update.dict(exclude_unset=True) + contacts[contact_idx].update(update_data) + contacts[contact_idx]["updated_at"] = datetime.utcnow().isoformat() + + save_contacts(contacts) + return contacts[contact_idx] + + +@router.delete("/contacts/{contact_id}") +async def delete_contact(contact_id: str): + """Loescht einen Kontakt.""" + contacts = get_contacts() + contacts = [c for c in contacts if c["id"] != contact_id] + save_contacts(contacts) + + return {"status": "deleted", "id": contact_id} + + +@router.post("/contacts/import", response_model=CSVImportResult) +async def import_contacts_csv(file: UploadFile = File(...)): + """ + Importiert Kontakte aus einer CSV-Datei. + + Erwartete Spalten: + - name (required) + - email + - phone + - role (parent/teacher/staff/student) + - student_name + - class_name + - notes + - tags (komma-separiert) + """ + if not file.filename.endswith('.csv'): + raise HTTPException(status_code=400, detail="Nur CSV-Dateien werden unterstuetzt") + + content = await file.read() + try: + text = content.decode('utf-8') + except UnicodeDecodeError: + text = content.decode('latin-1') + + contacts = get_contacts() + existing_emails = {c.get("email") for c in contacts if c.get("email")} + + imported = [] + skipped = 0 + errors = [] + + reader = csv.DictReader(StringIO(text), delimiter=';') # Deutsche CSV meist mit Semikolon + if not reader.fieldnames or 'name' not in [f.lower() for f in reader.fieldnames]: + # Versuche mit Komma + reader = csv.DictReader(StringIO(text), delimiter=',') + + for row_num, row in enumerate(reader, start=2): + try: + # Normalisiere Spaltennamen + row = {k.lower().strip(): v.strip() if v else "" for k, v in row.items()} + + name = row.get('name') or row.get('kontakt') or row.get('elternname') + if not name: + errors.append(f"Zeile {row_num}: Name fehlt") + skipped += 1 + continue + + email = row.get('email') or row.get('e-mail') or row.get('mail') + if email and email in existing_emails: + errors.append(f"Zeile {row_num}: Email {email} existiert bereits") + skipped += 1 + continue + + now = datetime.utcnow().isoformat() + tags_str = row.get('tags') or row.get('kategorien') or "" + tags = [t.strip() for t in tags_str.split(',') if t.strip()] + + # Matrix-ID und preferred_channel auslesen + matrix_id = row.get('matrix_id') or row.get('matrix') or None + preferred_channel = row.get('preferred_channel') or row.get('kanal') or "email" + if preferred_channel not in ["email", "matrix", "pwa"]: + preferred_channel = "email" + + new_contact = { + "id": str(uuid.uuid4()), + "name": name, + "email": email if email else None, + "phone": row.get('phone') or row.get('telefon') or row.get('tel'), + "role": row.get('role') or row.get('rolle') or "parent", + "student_name": row.get('student_name') or row.get('schueler') or row.get('kind'), + "class_name": row.get('class_name') or row.get('klasse'), + "notes": row.get('notes') or row.get('notizen') or row.get('bemerkungen'), + "tags": tags, + "matrix_id": matrix_id if matrix_id else None, + "preferred_channel": preferred_channel, + "created_at": now, + "updated_at": now, + "online": False, + "last_seen": None + } + + contacts.append(new_contact) + imported.append(new_contact) + if email: + existing_emails.add(email) + + except Exception as e: + errors.append(f"Zeile {row_num}: {str(e)}") + skipped += 1 + + save_contacts(contacts) + + return CSVImportResult( + imported=len(imported), + skipped=skipped, + errors=errors[:20], # Maximal 20 Fehler zurueckgeben + contacts=imported + ) + + +@router.get("/contacts/export/csv") +async def export_contacts_csv(): + """Exportiert alle Kontakte als CSV.""" + contacts = get_contacts() + + output = StringIO() + fieldnames = ['name', 'email', 'phone', 'role', 'student_name', 'class_name', 'notes', 'tags', 'matrix_id', 'preferred_channel'] + writer = csv.DictWriter(output, fieldnames=fieldnames, delimiter=';') + writer.writeheader() + + for contact in contacts: + writer.writerow({ + 'name': contact.get('name', ''), + 'email': contact.get('email', ''), + 'phone': contact.get('phone', ''), + 'role': contact.get('role', ''), + 'student_name': contact.get('student_name', ''), + 'class_name': contact.get('class_name', ''), + 'notes': contact.get('notes', ''), + 'tags': ','.join(contact.get('tags', [])), + 'matrix_id': contact.get('matrix_id', ''), + 'preferred_channel': contact.get('preferred_channel', 'email') + }) + + output.seek(0) + + return StreamingResponse( + iter([output.getvalue()]), + media_type="text/csv", + headers={"Content-Disposition": "attachment; filename=kontakte.csv"} + ) diff --git a/backend-lehrer/messenger_conversations.py b/backend-lehrer/messenger_conversations.py new file mode 100644 index 0000000..f7da2ed --- /dev/null +++ b/backend-lehrer/messenger_conversations.py @@ -0,0 +1,405 @@ +""" +Messenger API - Conversation, Message, Group, Template & Stats Routes. + +Conversations CRUD, message send/read, groups, templates, stats. +""" + +import uuid +from datetime import datetime +from typing import List, Optional + +from fastapi import APIRouter, HTTPException, Query + +from messenger_models import ( + Conversation, + Group, + GroupCreate, + Message, + MessageBase, +) +from messenger_helpers import ( + DATA_DIR, + DEFAULT_TEMPLATES, + get_contacts, + get_conversations, + save_conversations, + get_messages, + save_messages, + get_groups, + save_groups, + load_json, + save_json, +) + +router = APIRouter(tags=["Messenger"]) + + +# ========================================== +# GROUPS ENDPOINTS +# ========================================== + +@router.get("/groups", response_model=List[Group]) +async def list_groups(): + """Listet alle Gruppen auf.""" + return get_groups() + + +@router.post("/groups", response_model=Group) +async def create_group(group: GroupCreate): + """Erstellt eine neue Gruppe.""" + groups = get_groups() + + now = datetime.utcnow().isoformat() + new_group = { + "id": str(uuid.uuid4()), + "created_at": now, + "updated_at": now, + **group.dict() + } + + groups.append(new_group) + save_groups(groups) + + return new_group + + +@router.put("/groups/{group_id}/members") +async def update_group_members(group_id: str, member_ids: List[str]): + """Aktualisiert die Mitglieder einer Gruppe.""" + groups = get_groups() + group_idx = next((i for i, g in enumerate(groups) if g["id"] == group_id), None) + + if group_idx is None: + raise HTTPException(status_code=404, detail="Gruppe nicht gefunden") + + groups[group_idx]["member_ids"] = member_ids + groups[group_idx]["updated_at"] = datetime.utcnow().isoformat() + + save_groups(groups) + return groups[group_idx] + + +@router.delete("/groups/{group_id}") +async def delete_group(group_id: str): + """Loescht eine Gruppe.""" + groups = get_groups() + groups = [g for g in groups if g["id"] != group_id] + save_groups(groups) + + return {"status": "deleted", "id": group_id} + + +# ========================================== +# CONVERSATIONS ENDPOINTS +# ========================================== + +@router.get("/conversations", response_model=List[Conversation]) +async def list_conversations(): + """Listet alle Konversationen auf.""" + conversations = get_conversations() + messages = get_messages() + + # Unread count und letzte Nachricht hinzufuegen + for conv in conversations: + conv_messages = [m for m in messages if m.get("conversation_id") == conv["id"]] + conv["unread_count"] = len([m for m in conv_messages if not m.get("read") and m.get("sender_id") != "self"]) + + if conv_messages: + last_msg = max(conv_messages, key=lambda m: m.get("timestamp", "")) + conv["last_message"] = last_msg.get("content", "")[:50] + conv["last_message_time"] = last_msg.get("timestamp") + + # Nach letzter Nachricht sortieren + conversations.sort(key=lambda c: c.get("last_message_time") or "", reverse=True) + + return conversations + + +@router.post("/conversations", response_model=Conversation) +async def create_conversation(contact_id: Optional[str] = None, group_id: Optional[str] = None): + """ + Erstellt eine neue Konversation. + Entweder mit einem Kontakt (1:1) oder einer Gruppe. + """ + conversations = get_conversations() + + if not contact_id and not group_id: + raise HTTPException(status_code=400, detail="Entweder contact_id oder group_id erforderlich") + + # Pruefen ob Konversation bereits existiert + if contact_id: + existing = next((c for c in conversations + if not c.get("is_group") and contact_id in c.get("participant_ids", [])), None) + if existing: + return existing + + now = datetime.utcnow().isoformat() + + if group_id: + groups = get_groups() + group = next((g for g in groups if g["id"] == group_id), None) + if not group: + raise HTTPException(status_code=404, detail="Gruppe nicht gefunden") + + new_conv = { + "id": str(uuid.uuid4()), + "name": group.get("name"), + "is_group": True, + "participant_ids": group.get("member_ids", []), + "group_id": group_id, + "created_at": now, + "updated_at": now, + "last_message": None, + "last_message_time": None, + "unread_count": 0 + } + else: + contacts = get_contacts() + contact = next((c for c in contacts if c["id"] == contact_id), None) + if not contact: + raise HTTPException(status_code=404, detail="Kontakt nicht gefunden") + + new_conv = { + "id": str(uuid.uuid4()), + "name": contact.get("name"), + "is_group": False, + "participant_ids": [contact_id], + "group_id": None, + "created_at": now, + "updated_at": now, + "last_message": None, + "last_message_time": None, + "unread_count": 0 + } + + conversations.append(new_conv) + save_conversations(conversations) + + return new_conv + + +@router.get("/conversations/{conversation_id}", response_model=Conversation) +async def get_conversation(conversation_id: str): + """Ruft eine Konversation ab.""" + conversations = get_conversations() + conv = next((c for c in conversations if c["id"] == conversation_id), None) + + if not conv: + raise HTTPException(status_code=404, detail="Konversation nicht gefunden") + + return conv + + +@router.delete("/conversations/{conversation_id}") +async def delete_conversation(conversation_id: str): + """Loescht eine Konversation und alle zugehoerigen Nachrichten.""" + conversations = get_conversations() + conversations = [c for c in conversations if c["id"] != conversation_id] + save_conversations(conversations) + + messages = get_messages() + messages = [m for m in messages if m.get("conversation_id") != conversation_id] + save_messages(messages) + + return {"status": "deleted", "id": conversation_id} + + +# ========================================== +# MESSAGES ENDPOINTS +# ========================================== + +@router.get("/conversations/{conversation_id}/messages", response_model=List[Message]) +async def list_messages( + conversation_id: str, + limit: int = Query(50, ge=1, le=200), + before: Optional[str] = Query(None, description="Load messages before this timestamp") +): + """Ruft Nachrichten einer Konversation ab.""" + messages = get_messages() + conv_messages = [m for m in messages if m.get("conversation_id") == conversation_id] + + if before: + conv_messages = [m for m in conv_messages if m.get("timestamp", "") < before] + + # Nach Zeit sortieren (neueste zuletzt) + conv_messages.sort(key=lambda m: m.get("timestamp", "")) + + return conv_messages[-limit:] + + +@router.post("/conversations/{conversation_id}/messages", response_model=Message) +async def send_message(conversation_id: str, message: MessageBase): + """ + Sendet eine Nachricht in einer Konversation. + + Wenn send_email=True und der Kontakt eine Email-Adresse hat, + wird die Nachricht auch per Email versendet. + """ + conversations = get_conversations() + conv = next((c for c in conversations if c["id"] == conversation_id), None) + + if not conv: + raise HTTPException(status_code=404, detail="Konversation nicht gefunden") + + now = datetime.utcnow().isoformat() + + new_message = { + "id": str(uuid.uuid4()), + "conversation_id": conversation_id, + "sender_id": "self", + "timestamp": now, + "read": True, + "read_at": now, + "email_sent": False, + "email_sent_at": None, + "email_error": None, + **message.dict() + } + + # Email-Versand wenn gewuenscht + if message.send_email and not conv.get("is_group"): + # Kontakt laden + participant_ids = conv.get("participant_ids", []) + if participant_ids: + contacts = get_contacts() + contact = next((c for c in contacts if c["id"] == participant_ids[0]), None) + + if contact and contact.get("email"): + try: + from email_service import email_service + + result = email_service.send_messenger_notification( + to_email=contact["email"], + to_name=contact.get("name", ""), + sender_name="BreakPilot Lehrer", + message_content=message.content + ) + + if result.success: + new_message["email_sent"] = True + new_message["email_sent_at"] = result.sent_at + else: + new_message["email_error"] = result.error + + except Exception as e: + new_message["email_error"] = str(e) + + messages = get_messages() + messages.append(new_message) + save_messages(messages) + + # Konversation aktualisieren + conv_idx = next(i for i, c in enumerate(conversations) if c["id"] == conversation_id) + conversations[conv_idx]["last_message"] = message.content[:50] + conversations[conv_idx]["last_message_time"] = now + conversations[conv_idx]["updated_at"] = now + save_conversations(conversations) + + return new_message + + +@router.put("/messages/{message_id}/read") +async def mark_message_read(message_id: str): + """Markiert eine Nachricht als gelesen.""" + messages = get_messages() + msg_idx = next((i for i, m in enumerate(messages) if m["id"] == message_id), None) + + if msg_idx is None: + raise HTTPException(status_code=404, detail="Nachricht nicht gefunden") + + messages[msg_idx]["read"] = True + messages[msg_idx]["read_at"] = datetime.utcnow().isoformat() + save_messages(messages) + + return {"status": "read", "id": message_id} + + +@router.put("/conversations/{conversation_id}/read-all") +async def mark_all_messages_read(conversation_id: str): + """Markiert alle Nachrichten einer Konversation als gelesen.""" + messages = get_messages() + now = datetime.utcnow().isoformat() + + for msg in messages: + if msg.get("conversation_id") == conversation_id and not msg.get("read"): + msg["read"] = True + msg["read_at"] = now + + save_messages(messages) + + return {"status": "all_read", "conversation_id": conversation_id} + + +# ========================================== +# TEMPLATES ENDPOINTS +# ========================================== + +@router.get("/templates") +async def list_templates(): + """Listet alle Nachrichtenvorlagen auf.""" + templates_file = DATA_DIR / "templates.json" + if templates_file.exists(): + templates = load_json(templates_file) + else: + templates = DEFAULT_TEMPLATES + save_json(templates_file, templates) + + return templates + + +@router.post("/templates") +async def create_template(name: str, content: str, category: str = "custom"): + """Erstellt eine neue Vorlage.""" + templates_file = DATA_DIR / "templates.json" + templates = load_json(templates_file) if templates_file.exists() else DEFAULT_TEMPLATES.copy() + + new_template = { + "id": str(uuid.uuid4()), + "name": name, + "content": content, + "category": category + } + + templates.append(new_template) + save_json(templates_file, templates) + + return new_template + + +@router.delete("/templates/{template_id}") +async def delete_template(template_id: str): + """Loescht eine Vorlage.""" + templates_file = DATA_DIR / "templates.json" + templates = load_json(templates_file) if templates_file.exists() else DEFAULT_TEMPLATES.copy() + + templates = [t for t in templates if t["id"] != template_id] + save_json(templates_file, templates) + + return {"status": "deleted", "id": template_id} + + +# ========================================== +# STATS ENDPOINT +# ========================================== + +@router.get("/stats") +async def get_messenger_stats(): + """Gibt Statistiken zum Messenger zurueck.""" + contacts = get_contacts() + conversations = get_conversations() + messages = get_messages() + groups = get_groups() + + unread_total = sum(1 for m in messages if not m.get("read") and m.get("sender_id") != "self") + + return { + "total_contacts": len(contacts), + "total_groups": len(groups), + "total_conversations": len(conversations), + "total_messages": len(messages), + "unread_messages": unread_total, + "contacts_by_role": { + role: len([c for c in contacts if c.get("role") == role]) + for role in set(c.get("role", "parent") for c in contacts) + } + } diff --git a/backend-lehrer/messenger_helpers.py b/backend-lehrer/messenger_helpers.py new file mode 100644 index 0000000..1e4725f --- /dev/null +++ b/backend-lehrer/messenger_helpers.py @@ -0,0 +1,105 @@ +""" +Messenger API - Data Helpers. + +JSON-based file storage for contacts, conversations, messages, and groups. +""" + +import json +from typing import List, Dict +from pathlib import Path + +# Datenspeicherung (JSON-basiert fuer einfache Persistenz) +DATA_DIR = Path(__file__).parent / "data" / "messenger" +DATA_DIR.mkdir(parents=True, exist_ok=True) + +CONTACTS_FILE = DATA_DIR / "contacts.json" +CONVERSATIONS_FILE = DATA_DIR / "conversations.json" +MESSAGES_FILE = DATA_DIR / "messages.json" +GROUPS_FILE = DATA_DIR / "groups.json" + + +def load_json(filepath: Path) -> List[Dict]: + """Laedt JSON-Daten aus Datei.""" + if not filepath.exists(): + return [] + try: + with open(filepath, "r", encoding="utf-8") as f: + return json.load(f) + except Exception: + return [] + + +def save_json(filepath: Path, data: List[Dict]): + """Speichert Daten in JSON-Datei.""" + with open(filepath, "w", encoding="utf-8") as f: + json.dump(data, f, ensure_ascii=False, indent=2) + + +def get_contacts() -> List[Dict]: + return load_json(CONTACTS_FILE) + + +def save_contacts(contacts: List[Dict]): + save_json(CONTACTS_FILE, contacts) + + +def get_conversations() -> List[Dict]: + return load_json(CONVERSATIONS_FILE) + + +def save_conversations(conversations: List[Dict]): + save_json(CONVERSATIONS_FILE, conversations) + + +def get_messages() -> List[Dict]: + return load_json(MESSAGES_FILE) + + +def save_messages(messages: List[Dict]): + save_json(MESSAGES_FILE, messages) + + +def get_groups() -> List[Dict]: + return load_json(GROUPS_FILE) + + +def save_groups(groups: List[Dict]): + save_json(GROUPS_FILE, groups) + + +# ========================================== +# DEFAULT TEMPLATES +# ========================================== + +DEFAULT_TEMPLATES = [ + { + "id": "1", + "name": "Terminbestaetigung", + "content": "Vielen Dank fuer Ihre Terminanfrage. Ich bestaetige den Termin am [DATUM] um [UHRZEIT]. Bitte geben Sie mir Bescheid, falls sich etwas aendern sollte.", + "category": "termin" + }, + { + "id": "2", + "name": "Hausaufgaben-Info", + "content": "Zur Information: Die Hausaufgaben fuer diese Woche umfassen [THEMA]. Abgabetermin ist [DATUM]. Bei Fragen stehe ich gerne zur Verfuegung.", + "category": "hausaufgaben" + }, + { + "id": "3", + "name": "Entschuldigung bestaetigen", + "content": "Ich bestaetige den Erhalt der Entschuldigung fuer [NAME] am [DATUM]. Die Fehlzeiten wurden entsprechend vermerkt.", + "category": "entschuldigung" + }, + { + "id": "4", + "name": "Gespraechsanfrage", + "content": "Ich wuerde gerne einen Termin fuer ein Gespraech mit Ihnen vereinbaren, um [THEMA] zu besprechen. Waeren Sie am [DATUM] um [UHRZEIT] verfuegbar?", + "category": "gespraech" + }, + { + "id": "5", + "name": "Krankmeldung bestaetigen", + "content": "Vielen Dank fuer Ihre Krankmeldung fuer [NAME]. Ich wuensche gute Besserung. Bitte reichen Sie eine schriftliche Entschuldigung nach, sobald Ihr Kind wieder gesund ist.", + "category": "krankmeldung" + } +] diff --git a/backend-lehrer/messenger_models.py b/backend-lehrer/messenger_models.py new file mode 100644 index 0000000..275c700 --- /dev/null +++ b/backend-lehrer/messenger_models.py @@ -0,0 +1,139 @@ +""" +Messenger API - Pydantic Models. + +Data models for contacts, conversations, messages, and groups. +""" + +from typing import List, Optional + +from pydantic import BaseModel, Field + + +# ========================================== +# CONTACT MODELS +# ========================================== + +class ContactBase(BaseModel): + """Basis-Modell fuer Kontakte.""" + name: str = Field(..., min_length=1, max_length=200) + email: Optional[str] = None + phone: Optional[str] = None + role: str = Field(default="parent", description="parent, teacher, staff, student") + student_name: Optional[str] = Field(None, description="Name des zugehoerigen Schuelers") + class_name: Optional[str] = Field(None, description="Klasse z.B. 10a") + notes: Optional[str] = None + tags: List[str] = Field(default_factory=list) + matrix_id: Optional[str] = Field(None, description="Matrix-ID z.B. @user:matrix.org") + preferred_channel: str = Field(default="email", description="email, matrix, pwa") + + +class ContactCreate(ContactBase): + """Model fuer neuen Kontakt.""" + pass + + +class Contact(ContactBase): + """Vollstaendiger Kontakt mit ID.""" + id: str + created_at: str + updated_at: str + online: bool = False + last_seen: Optional[str] = None + + +class ContactUpdate(BaseModel): + """Update-Model fuer Kontakte.""" + name: Optional[str] = None + email: Optional[str] = None + phone: Optional[str] = None + role: Optional[str] = None + student_name: Optional[str] = None + class_name: Optional[str] = None + notes: Optional[str] = None + tags: Optional[List[str]] = None + matrix_id: Optional[str] = None + preferred_channel: Optional[str] = None + + +# ========================================== +# GROUP MODELS +# ========================================== + +class GroupBase(BaseModel): + """Basis-Modell fuer Gruppen.""" + name: str = Field(..., min_length=1, max_length=100) + description: Optional[str] = None + group_type: str = Field(default="class", description="class, department, custom") + + +class GroupCreate(GroupBase): + """Model fuer neue Gruppe.""" + member_ids: List[str] = Field(default_factory=list) + + +class Group(GroupBase): + """Vollstaendige Gruppe mit ID.""" + id: str + member_ids: List[str] = [] + created_at: str + updated_at: str + + +# ========================================== +# MESSAGE MODELS +# ========================================== + +class MessageBase(BaseModel): + """Basis-Modell fuer Nachrichten.""" + content: str = Field(..., min_length=1) + content_type: str = Field(default="text", description="text, file, image") + file_url: Optional[str] = None + send_email: bool = Field(default=False, description="Nachricht auch per Email senden") + + +class MessageCreate(MessageBase): + """Model fuer neue Nachricht.""" + conversation_id: str + + +class Message(MessageBase): + """Vollstaendige Nachricht mit ID.""" + id: str + conversation_id: str + sender_id: str # "self" fuer eigene Nachrichten + timestamp: str + read: bool = False + read_at: Optional[str] = None + email_sent: bool = False + email_sent_at: Optional[str] = None + email_error: Optional[str] = None + + +# ========================================== +# CONVERSATION MODELS +# ========================================== + +class ConversationBase(BaseModel): + """Basis-Modell fuer Konversationen.""" + name: Optional[str] = None + is_group: bool = False + + +class Conversation(ConversationBase): + """Vollstaendige Konversation mit ID.""" + id: str + participant_ids: List[str] = [] + group_id: Optional[str] = None + created_at: str + updated_at: str + last_message: Optional[str] = None + last_message_time: Optional[str] = None + unread_count: int = 0 + + +class CSVImportResult(BaseModel): + """Ergebnis eines CSV-Imports.""" + imported: int + skipped: int + errors: List[str] + contacts: List[Contact] diff --git a/backend-lehrer/recording_api.py b/backend-lehrer/recording_api.py index 22f0d4c..cbaf3ed 100644 --- a/backend-lehrer/recording_api.py +++ b/backend-lehrer/recording_api.py @@ -1,848 +1,22 @@ """ -BreakPilot Recording API +BreakPilot Recording API — Barrel Re-export. Verwaltet Jibri Meeting-Aufzeichnungen und deren Metadaten. -Empfaengt Webhooks von Jibri nach Upload zu MinIO. +Split into: + - recording_models.py: Pydantic models & config + - recording_helpers.py: In-memory storage & utilities + - recording_routes.py: Core recording CRUD routes + - recording_transcription.py: Transcription routes + - recording_minutes.py: Meeting minutes routes """ -import os -import uuid -from datetime import datetime, timedelta -from typing import Optional, List -from pydantic import BaseModel, Field +from fastapi import APIRouter -from fastapi import APIRouter, HTTPException, Query, Depends, Request -from fastapi.responses import JSONResponse +from recording_routes import router as _routes_router +from recording_transcription import router as _transcription_router +from recording_minutes import router as _minutes_router router = APIRouter(prefix="/api/recordings", tags=["Recordings"]) - -# ========================================== -# ENVIRONMENT CONFIGURATION -# ========================================== - -MINIO_ENDPOINT = os.getenv("MINIO_ENDPOINT", "minio:9000") -MINIO_ACCESS_KEY = os.getenv("MINIO_ACCESS_KEY", "breakpilot") -MINIO_SECRET_KEY = os.getenv("MINIO_SECRET_KEY", "breakpilot123") -MINIO_BUCKET = os.getenv("MINIO_BUCKET", "breakpilot-recordings") -MINIO_SECURE = os.getenv("MINIO_SECURE", "false").lower() == "true" - -# Default retention period in days (DSGVO compliance) -DEFAULT_RETENTION_DAYS = int(os.getenv("RECORDING_RETENTION_DAYS", "365")) - - -# ========================================== -# PYDANTIC MODELS -# ========================================== - -class JibriWebhookPayload(BaseModel): - """Webhook payload from Jibri finalize.sh script.""" - event: str = Field(..., description="Event type: recording_completed") - recording_name: str = Field(..., description="Unique recording identifier") - storage_path: str = Field(..., description="Path in MinIO bucket") - audio_path: Optional[str] = Field(None, description="Extracted audio path") - file_size_bytes: int = Field(..., description="Video file size in bytes") - timestamp: str = Field(..., description="ISO timestamp of upload") - - -class RecordingCreate(BaseModel): - """Manual recording creation (for testing).""" - meeting_id: str - title: Optional[str] = None - storage_path: str - audio_path: Optional[str] = None - duration_seconds: Optional[int] = None - participant_count: Optional[int] = 0 - retention_days: Optional[int] = DEFAULT_RETENTION_DAYS - - -class RecordingResponse(BaseModel): - """Recording details response.""" - id: str - meeting_id: str - title: Optional[str] - storage_path: str - audio_path: Optional[str] - file_size_bytes: Optional[int] - duration_seconds: Optional[int] - participant_count: int - status: str - recorded_at: datetime - retention_days: int - retention_expires_at: datetime - transcription_status: Optional[str] = None - transcription_id: Optional[str] = None - - -class RecordingListResponse(BaseModel): - """Paginated list of recordings.""" - recordings: List[RecordingResponse] - total: int - page: int - page_size: int - - -class TranscriptionRequest(BaseModel): - """Request to start transcription.""" - language: str = Field(default="de", description="Language code: de, en, etc.") - model: str = Field(default="large-v3", description="Whisper model to use") - priority: int = Field(default=0, description="Queue priority (higher = sooner)") - - -class TranscriptionStatusResponse(BaseModel): - """Transcription status and progress.""" - id: str - recording_id: str - status: str - language: str - model: str - word_count: Optional[int] - confidence_score: Optional[float] - processing_duration_seconds: Optional[int] - error_message: Optional[str] - created_at: datetime - completed_at: Optional[datetime] - - -# ========================================== -# IN-MEMORY STORAGE (Dev Mode) -# ========================================== -# In production, these would be database queries - -_recordings_store: dict = {} -_transcriptions_store: dict = {} -_audit_log: list = [] - - -def log_audit( - action: str, - recording_id: Optional[str] = None, - transcription_id: Optional[str] = None, - user_id: Optional[str] = None, - metadata: Optional[dict] = None -): - """Log audit event for DSGVO compliance.""" - _audit_log.append({ - "id": str(uuid.uuid4()), - "recording_id": recording_id, - "transcription_id": transcription_id, - "user_id": user_id, - "action": action, - "metadata": metadata or {}, - "created_at": datetime.utcnow().isoformat() - }) - - -# ========================================== -# WEBHOOK ENDPOINT (Jibri) -# ========================================== - -@router.post("/webhook") -async def jibri_webhook(payload: JibriWebhookPayload, request: Request): - """ - Webhook endpoint called by Jibri finalize.sh after upload. - - This creates a new recording entry and optionally triggers transcription. - """ - if payload.event != "recording_completed": - return JSONResponse( - status_code=400, - content={"error": f"Unknown event type: {payload.event}"} - ) - - # Extract meeting_id from recording_name (format: meetingId_timestamp) - parts = payload.recording_name.split("_") - meeting_id = parts[0] if parts else payload.recording_name - - # Create recording entry - recording_id = str(uuid.uuid4()) - recorded_at = datetime.utcnow() - - recording = { - "id": recording_id, - "meeting_id": meeting_id, - "jibri_session_id": payload.recording_name, - "title": f"Recording {meeting_id}", - "storage_path": payload.storage_path, - "audio_path": payload.audio_path, - "file_size_bytes": payload.file_size_bytes, - "duration_seconds": None, # Will be updated after analysis - "participant_count": 0, - "status": "uploaded", - "recorded_at": recorded_at.isoformat(), - "retention_days": DEFAULT_RETENTION_DAYS, - "created_at": datetime.utcnow().isoformat(), - "updated_at": datetime.utcnow().isoformat() - } - - _recordings_store[recording_id] = recording - - # Log the creation - log_audit( - action="created", - recording_id=recording_id, - metadata={ - "source": "jibri_webhook", - "storage_path": payload.storage_path, - "file_size_bytes": payload.file_size_bytes - } - ) - - return { - "success": True, - "recording_id": recording_id, - "meeting_id": meeting_id, - "status": "uploaded", - "message": "Recording registered successfully" - } - - -# ========================================== -# HEALTH & AUDIT ENDPOINTS (must be before parameterized routes) -# ========================================== - -@router.get("/health") -async def recordings_health(): - """Health check for recording service.""" - return { - "status": "healthy", - "recordings_count": len(_recordings_store), - "transcriptions_count": len(_transcriptions_store), - "minio_endpoint": MINIO_ENDPOINT, - "bucket": MINIO_BUCKET - } - - -@router.get("/audit/log") -async def get_audit_log( - recording_id: Optional[str] = Query(None), - action: Optional[str] = Query(None), - limit: int = Query(100, ge=1, le=1000) -): - """ - Get audit log entries (DSGVO compliance). - - Admin-only endpoint for reviewing recording access history. - """ - logs = _audit_log.copy() - - if recording_id: - logs = [l for l in logs if l.get("recording_id") == recording_id] - if action: - logs = [l for l in logs if l.get("action") == action] - - # Sort by created_at descending - logs.sort(key=lambda x: x["created_at"], reverse=True) - - return { - "entries": logs[:limit], - "total": len(logs) - } - - -# ========================================== -# RECORDING MANAGEMENT ENDPOINTS -# ========================================== - -@router.get("", response_model=RecordingListResponse) -async def list_recordings( - status: Optional[str] = Query(None, description="Filter by status"), - meeting_id: Optional[str] = Query(None, description="Filter by meeting ID"), - page: int = Query(1, ge=1, description="Page number"), - page_size: int = Query(20, ge=1, le=100, description="Items per page") -): - """ - List all recordings with optional filtering. - - Supports pagination and filtering by status or meeting ID. - """ - # Filter recordings - recordings = list(_recordings_store.values()) - - if status: - recordings = [r for r in recordings if r["status"] == status] - if meeting_id: - recordings = [r for r in recordings if r["meeting_id"] == meeting_id] - - # Sort by recorded_at descending - recordings.sort(key=lambda x: x["recorded_at"], reverse=True) - - # Paginate - total = len(recordings) - start = (page - 1) * page_size - end = start + page_size - page_recordings = recordings[start:end] - - # Convert to response format - result = [] - for rec in page_recordings: - recorded_at = datetime.fromisoformat(rec["recorded_at"]) - retention_expires = recorded_at + timedelta(days=rec["retention_days"]) - - # Check for transcription - trans = next( - (t for t in _transcriptions_store.values() if t["recording_id"] == rec["id"]), - None - ) - - result.append(RecordingResponse( - id=rec["id"], - meeting_id=rec["meeting_id"], - title=rec.get("title"), - storage_path=rec["storage_path"], - audio_path=rec.get("audio_path"), - file_size_bytes=rec.get("file_size_bytes"), - duration_seconds=rec.get("duration_seconds"), - participant_count=rec.get("participant_count", 0), - status=rec["status"], - recorded_at=recorded_at, - retention_days=rec["retention_days"], - retention_expires_at=retention_expires, - transcription_status=trans["status"] if trans else None, - transcription_id=trans["id"] if trans else None - )) - - return RecordingListResponse( - recordings=result, - total=total, - page=page, - page_size=page_size - ) - - -@router.get("/{recording_id}", response_model=RecordingResponse) -async def get_recording(recording_id: str): - """ - Get details for a specific recording. - """ - recording = _recordings_store.get(recording_id) - if not recording: - raise HTTPException(status_code=404, detail="Recording not found") - - # Log view action - log_audit(action="viewed", recording_id=recording_id) - - recorded_at = datetime.fromisoformat(recording["recorded_at"]) - retention_expires = recorded_at + timedelta(days=recording["retention_days"]) - - # Check for transcription - trans = next( - (t for t in _transcriptions_store.values() if t["recording_id"] == recording_id), - None - ) - - return RecordingResponse( - id=recording["id"], - meeting_id=recording["meeting_id"], - title=recording.get("title"), - storage_path=recording["storage_path"], - audio_path=recording.get("audio_path"), - file_size_bytes=recording.get("file_size_bytes"), - duration_seconds=recording.get("duration_seconds"), - participant_count=recording.get("participant_count", 0), - status=recording["status"], - recorded_at=recorded_at, - retention_days=recording["retention_days"], - retention_expires_at=retention_expires, - transcription_status=trans["status"] if trans else None, - transcription_id=trans["id"] if trans else None - ) - - -@router.delete("/{recording_id}") -async def delete_recording( - recording_id: str, - reason: str = Query(..., description="Reason for deletion (DSGVO audit)") -): - """ - Soft-delete a recording (DSGVO compliance). - - The recording is marked as deleted but retained for audit purposes. - Actual file deletion happens after the audit retention period. - """ - recording = _recordings_store.get(recording_id) - if not recording: - raise HTTPException(status_code=404, detail="Recording not found") - - # Soft delete - recording["status"] = "deleted" - recording["deleted_at"] = datetime.utcnow().isoformat() - recording["updated_at"] = datetime.utcnow().isoformat() - - # Log deletion with reason - log_audit( - action="deleted", - recording_id=recording_id, - metadata={"reason": reason} - ) - - return { - "success": True, - "recording_id": recording_id, - "status": "deleted", - "message": "Recording marked for deletion" - } - - -# ========================================== -# TRANSCRIPTION ENDPOINTS -# ========================================== - -@router.post("/{recording_id}/transcribe", response_model=TranscriptionStatusResponse) -async def start_transcription(recording_id: str, request: TranscriptionRequest): - """ - Start transcription for a recording. - - Queues the recording for processing by the transcription worker. - """ - recording = _recordings_store.get(recording_id) - if not recording: - raise HTTPException(status_code=404, detail="Recording not found") - - if recording["status"] == "deleted": - raise HTTPException(status_code=400, detail="Cannot transcribe deleted recording") - - # Check if transcription already exists - existing = next( - (t for t in _transcriptions_store.values() - if t["recording_id"] == recording_id and t["status"] != "failed"), - None - ) - if existing: - raise HTTPException( - status_code=409, - detail=f"Transcription already exists with status: {existing['status']}" - ) - - # Create transcription entry - transcription_id = str(uuid.uuid4()) - now = datetime.utcnow() - - transcription = { - "id": transcription_id, - "recording_id": recording_id, - "language": request.language, - "model": request.model, - "status": "pending", - "full_text": None, - "word_count": None, - "confidence_score": None, - "vtt_path": None, - "srt_path": None, - "json_path": None, - "error_message": None, - "processing_started_at": None, - "processing_completed_at": None, - "processing_duration_seconds": None, - "created_at": now.isoformat(), - "updated_at": now.isoformat() - } - - _transcriptions_store[transcription_id] = transcription - - # Update recording status - recording["status"] = "processing" - recording["updated_at"] = now.isoformat() - - # Log transcription start - log_audit( - action="transcription_started", - recording_id=recording_id, - transcription_id=transcription_id, - metadata={"language": request.language, "model": request.model} - ) - - # TODO: Queue job to Redis/Valkey for transcription worker - # from redis import Redis - # from rq import Queue - # q = Queue(connection=Redis.from_url(os.getenv("REDIS_URL"))) - # q.enqueue('transcription_worker.tasks.transcribe', transcription_id, ...) - - return TranscriptionStatusResponse( - id=transcription_id, - recording_id=recording_id, - status="pending", - language=request.language, - model=request.model, - word_count=None, - confidence_score=None, - processing_duration_seconds=None, - error_message=None, - created_at=now, - completed_at=None - ) - - -@router.get("/{recording_id}/transcription", response_model=TranscriptionStatusResponse) -async def get_transcription_status(recording_id: str): - """ - Get transcription status for a recording. - """ - transcription = next( - (t for t in _transcriptions_store.values() if t["recording_id"] == recording_id), - None - ) - if not transcription: - raise HTTPException(status_code=404, detail="No transcription found for this recording") - - return TranscriptionStatusResponse( - id=transcription["id"], - recording_id=transcription["recording_id"], - status=transcription["status"], - language=transcription["language"], - model=transcription["model"], - word_count=transcription.get("word_count"), - confidence_score=transcription.get("confidence_score"), - processing_duration_seconds=transcription.get("processing_duration_seconds"), - error_message=transcription.get("error_message"), - created_at=datetime.fromisoformat(transcription["created_at"]), - completed_at=( - datetime.fromisoformat(transcription["processing_completed_at"]) - if transcription.get("processing_completed_at") else None - ) - ) - - -@router.get("/{recording_id}/transcription/text") -async def get_transcription_text(recording_id: str): - """ - Get the full transcription text. - """ - transcription = next( - (t for t in _transcriptions_store.values() if t["recording_id"] == recording_id), - None - ) - if not transcription: - raise HTTPException(status_code=404, detail="No transcription found for this recording") - - if transcription["status"] != "completed": - raise HTTPException( - status_code=400, - detail=f"Transcription not ready. Status: {transcription['status']}" - ) - - return { - "transcription_id": transcription["id"], - "recording_id": recording_id, - "language": transcription["language"], - "text": transcription.get("full_text", ""), - "word_count": transcription.get("word_count", 0) - } - - -@router.get("/{recording_id}/transcription/vtt") -async def get_transcription_vtt(recording_id: str): - """ - Download transcription as WebVTT subtitle file. - """ - from fastapi.responses import PlainTextResponse - - transcription = next( - (t for t in _transcriptions_store.values() if t["recording_id"] == recording_id), - None - ) - if not transcription: - raise HTTPException(status_code=404, detail="No transcription found for this recording") - - if transcription["status"] != "completed": - raise HTTPException( - status_code=400, - detail=f"Transcription not ready. Status: {transcription['status']}" - ) - - # Generate VTT content - # In production, this would read from the stored VTT file - vtt_content = "WEBVTT\n\n" - text = transcription.get("full_text", "") - - if text: - # Simple VTT generation - split into sentences - sentences = text.replace(".", ".\n").split("\n") - time_offset = 0 - for sentence in sentences: - sentence = sentence.strip() - if sentence: - start = format_vtt_time(time_offset) - # Estimate ~3 seconds per sentence - time_offset += 3000 - end = format_vtt_time(time_offset) - vtt_content += f"{start} --> {end}\n{sentence}\n\n" - - return PlainTextResponse( - content=vtt_content, - media_type="text/vtt", - headers={"Content-Disposition": f"attachment; filename=transcript_{recording_id}.vtt"} - ) - - -@router.get("/{recording_id}/transcription/srt") -async def get_transcription_srt(recording_id: str): - """ - Download transcription as SRT subtitle file. - """ - from fastapi.responses import PlainTextResponse - - transcription = next( - (t for t in _transcriptions_store.values() if t["recording_id"] == recording_id), - None - ) - if not transcription: - raise HTTPException(status_code=404, detail="No transcription found for this recording") - - if transcription["status"] != "completed": - raise HTTPException( - status_code=400, - detail=f"Transcription not ready. Status: {transcription['status']}" - ) - - # Generate SRT content - srt_content = "" - text = transcription.get("full_text", "") - - if text: - sentences = text.replace(".", ".\n").split("\n") - time_offset = 0 - index = 1 - for sentence in sentences: - sentence = sentence.strip() - if sentence: - start = format_srt_time(time_offset) - time_offset += 3000 - end = format_srt_time(time_offset) - srt_content += f"{index}\n{start} --> {end}\n{sentence}\n\n" - index += 1 - - return PlainTextResponse( - content=srt_content, - media_type="text/plain", - headers={"Content-Disposition": f"attachment; filename=transcript_{recording_id}.srt"} - ) - - -def format_vtt_time(ms: int) -> str: - """Format milliseconds to VTT timestamp (HH:MM:SS.mmm).""" - hours = ms // 3600000 - minutes = (ms % 3600000) // 60000 - seconds = (ms % 60000) // 1000 - millis = ms % 1000 - return f"{hours:02d}:{minutes:02d}:{seconds:02d}.{millis:03d}" - - -def format_srt_time(ms: int) -> str: - """Format milliseconds to SRT timestamp (HH:MM:SS,mmm).""" - hours = ms // 3600000 - minutes = (ms % 3600000) // 60000 - seconds = (ms % 60000) // 1000 - millis = ms % 1000 - return f"{hours:02d}:{minutes:02d}:{seconds:02d},{millis:03d}" - - -@router.get("/{recording_id}/download") -async def download_recording(recording_id: str): - """ - Download the recording file. - - In production, this would generate a presigned URL to MinIO. - """ - recording = _recordings_store.get(recording_id) - if not recording: - raise HTTPException(status_code=404, detail="Recording not found") - - if recording["status"] == "deleted": - raise HTTPException(status_code=410, detail="Recording has been deleted") - - # Log download action - log_audit(action="downloaded", recording_id=recording_id) - - # In production, generate presigned URL to MinIO - # For now, return info about where the file is - return { - "recording_id": recording_id, - "storage_path": recording["storage_path"], - "file_size_bytes": recording.get("file_size_bytes"), - "message": "In production, this would redirect to a presigned MinIO URL" - } - - -# ========================================== -# MEETING MINUTES ENDPOINTS -# ========================================== - -# In-memory store for meeting minutes (dev mode) -_minutes_store: dict = {} - - -@router.post("/{recording_id}/minutes") -async def generate_meeting_minutes( - recording_id: str, - title: Optional[str] = Query(None, description="Meeting-Titel"), - model: str = Query("breakpilot-teacher-8b", description="LLM Modell") -): - """ - Generiert KI-basierte Meeting Minutes aus der Transkription. - - Nutzt das LLM Gateway (Ollama/vLLM) fuer lokale Verarbeitung. - """ - from meeting_minutes_generator import get_minutes_generator, MeetingMinutes - - # Check recording exists - recording = _recordings_store.get(recording_id) - if not recording: - raise HTTPException(status_code=404, detail="Recording not found") - - # Check transcription exists and is completed - transcription = next( - (t for t in _transcriptions_store.values() if t["recording_id"] == recording_id), - None - ) - if not transcription: - raise HTTPException(status_code=400, detail="No transcription found. Please transcribe first.") - - if transcription["status"] != "completed": - raise HTTPException( - status_code=400, - detail=f"Transcription not ready. Status: {transcription['status']}" - ) - - # Check if minutes already exist - existing = _minutes_store.get(recording_id) - if existing and existing.get("status") == "completed": - # Return existing minutes - return existing - - # Get transcript text - transcript_text = transcription.get("full_text", "") - if not transcript_text: - raise HTTPException(status_code=400, detail="Transcription has no text content") - - # Generate meeting minutes - generator = get_minutes_generator() - - try: - minutes = await generator.generate( - transcript=transcript_text, - recording_id=recording_id, - transcription_id=transcription["id"], - title=title, - date=recording.get("recorded_at", "")[:10] if recording.get("recorded_at") else None, - duration_minutes=recording.get("duration_seconds", 0) // 60 if recording.get("duration_seconds") else None, - participant_count=recording.get("participant_count", 0), - model=model - ) - - # Store minutes - minutes_dict = minutes.model_dump() - minutes_dict["generated_at"] = minutes.generated_at.isoformat() - _minutes_store[recording_id] = minutes_dict - - # Log action - log_audit( - action="minutes_generated", - recording_id=recording_id, - metadata={"model": model, "generation_time": minutes.generation_time_seconds} - ) - - return minutes_dict - - except Exception as e: - raise HTTPException(status_code=500, detail=f"Minutes generation failed: {str(e)}") - - -@router.get("/{recording_id}/minutes") -async def get_meeting_minutes(recording_id: str): - """ - Ruft generierte Meeting Minutes ab. - """ - minutes = _minutes_store.get(recording_id) - if not minutes: - raise HTTPException(status_code=404, detail="No meeting minutes found. Generate them first with POST.") - - return minutes - - -@router.get("/{recording_id}/minutes/markdown") -async def get_minutes_markdown(recording_id: str): - """ - Exportiert Meeting Minutes als Markdown. - """ - from fastapi.responses import PlainTextResponse - from meeting_minutes_generator import minutes_to_markdown, MeetingMinutes - - minutes_dict = _minutes_store.get(recording_id) - if not minutes_dict: - raise HTTPException(status_code=404, detail="No meeting minutes found") - - # Convert dict back to MeetingMinutes - minutes_dict_copy = minutes_dict.copy() - if isinstance(minutes_dict_copy.get("generated_at"), str): - minutes_dict_copy["generated_at"] = datetime.fromisoformat(minutes_dict_copy["generated_at"]) - - minutes = MeetingMinutes(**minutes_dict_copy) - markdown = minutes_to_markdown(minutes) - - return PlainTextResponse( - content=markdown, - media_type="text/markdown", - headers={"Content-Disposition": f"attachment; filename=protokoll_{recording_id}.md"} - ) - - -@router.get("/{recording_id}/minutes/html") -async def get_minutes_html(recording_id: str): - """ - Exportiert Meeting Minutes als HTML. - """ - from fastapi.responses import HTMLResponse - from meeting_minutes_generator import minutes_to_html, MeetingMinutes - - minutes_dict = _minutes_store.get(recording_id) - if not minutes_dict: - raise HTTPException(status_code=404, detail="No meeting minutes found") - - # Convert dict back to MeetingMinutes - minutes_dict_copy = minutes_dict.copy() - if isinstance(minutes_dict_copy.get("generated_at"), str): - minutes_dict_copy["generated_at"] = datetime.fromisoformat(minutes_dict_copy["generated_at"]) - - minutes = MeetingMinutes(**minutes_dict_copy) - html = minutes_to_html(minutes) - - return HTMLResponse(content=html) - - -@router.get("/{recording_id}/minutes/pdf") -async def get_minutes_pdf(recording_id: str): - """ - Exportiert Meeting Minutes als PDF. - - Benoetigt WeasyPrint (pip install weasyprint). - """ - from meeting_minutes_generator import minutes_to_html, MeetingMinutes - - minutes_dict = _minutes_store.get(recording_id) - if not minutes_dict: - raise HTTPException(status_code=404, detail="No meeting minutes found") - - # Convert dict back to MeetingMinutes - minutes_dict_copy = minutes_dict.copy() - if isinstance(minutes_dict_copy.get("generated_at"), str): - minutes_dict_copy["generated_at"] = datetime.fromisoformat(minutes_dict_copy["generated_at"]) - - minutes = MeetingMinutes(**minutes_dict_copy) - html = minutes_to_html(minutes) - - try: - from weasyprint import HTML - from fastapi.responses import Response - - pdf_bytes = HTML(string=html).write_pdf() - - return Response( - content=pdf_bytes, - media_type="application/pdf", - headers={"Content-Disposition": f"attachment; filename=protokoll_{recording_id}.pdf"} - ) - except ImportError: - raise HTTPException( - status_code=501, - detail="PDF export not available. Install weasyprint: pip install weasyprint" - ) +router.include_router(_routes_router) +router.include_router(_transcription_router) +router.include_router(_minutes_router) diff --git a/backend-lehrer/recording_helpers.py b/backend-lehrer/recording_helpers.py new file mode 100644 index 0000000..116add6 --- /dev/null +++ b/backend-lehrer/recording_helpers.py @@ -0,0 +1,57 @@ +""" +Recording API - In-Memory Storage & Helpers. + +Shared state and utility functions for recording endpoints. +""" + +import uuid +from datetime import datetime +from typing import Optional + + +# ========================================== +# IN-MEMORY STORAGE (Dev Mode) +# ========================================== +# In production, these would be database queries + +_recordings_store: dict = {} +_transcriptions_store: dict = {} +_audit_log: list = [] +_minutes_store: dict = {} + + +def log_audit( + action: str, + recording_id: Optional[str] = None, + transcription_id: Optional[str] = None, + user_id: Optional[str] = None, + metadata: Optional[dict] = None +): + """Log audit event for DSGVO compliance.""" + _audit_log.append({ + "id": str(uuid.uuid4()), + "recording_id": recording_id, + "transcription_id": transcription_id, + "user_id": user_id, + "action": action, + "metadata": metadata or {}, + "created_at": datetime.utcnow().isoformat() + }) + + +def format_vtt_time(ms: int) -> str: + """Format milliseconds to VTT timestamp (HH:MM:SS.mmm).""" + hours = ms // 3600000 + minutes = (ms % 3600000) // 60000 + seconds = (ms % 60000) // 1000 + millis = ms % 1000 + return f"{hours:02d}:{minutes:02d}:{seconds:02d}.{millis:03d}" + + +def format_srt_time(ms: int) -> str: + """Format milliseconds to SRT timestamp (HH:MM:SS,mmm).""" + hours = ms // 3600000 + minutes = (ms % 3600000) // 60000 + seconds = (ms % 60000) // 1000 + millis = ms % 1000 + return f"{hours:02d}:{minutes:02d}:{seconds:02d},{millis:03d}" diff --git a/backend-lehrer/recording_minutes.py b/backend-lehrer/recording_minutes.py new file mode 100644 index 0000000..3cb6d97 --- /dev/null +++ b/backend-lehrer/recording_minutes.py @@ -0,0 +1,187 @@ +""" +Recording API - Meeting Minutes Routes. + +Generate, retrieve, and export KI-based meeting minutes. +""" + +from datetime import datetime +from typing import Optional + +from fastapi import APIRouter, HTTPException, Query +from fastapi.responses import PlainTextResponse, HTMLResponse + +from recording_helpers import ( + _recordings_store, + _transcriptions_store, + _minutes_store, + log_audit, +) + +router = APIRouter(tags=["Recordings"]) + + +# ========================================== +# MEETING MINUTES ENDPOINTS +# ========================================== + +@router.post("/{recording_id}/minutes") +async def generate_meeting_minutes( + recording_id: str, + title: Optional[str] = Query(None, description="Meeting-Titel"), + model: str = Query("breakpilot-teacher-8b", description="LLM Modell") +): + """ + Generiert KI-basierte Meeting Minutes aus der Transkription. + + Nutzt das LLM Gateway (Ollama/vLLM) fuer lokale Verarbeitung. + """ + from meeting_minutes_generator import get_minutes_generator, MeetingMinutes + + # Check recording exists + recording = _recordings_store.get(recording_id) + if not recording: + raise HTTPException(status_code=404, detail="Recording not found") + + # Check transcription exists and is completed + transcription = next( + (t for t in _transcriptions_store.values() if t["recording_id"] == recording_id), + None + ) + if not transcription: + raise HTTPException(status_code=400, detail="No transcription found. Please transcribe first.") + + if transcription["status"] != "completed": + raise HTTPException( + status_code=400, + detail=f"Transcription not ready. Status: {transcription['status']}" + ) + + # Check if minutes already exist + existing = _minutes_store.get(recording_id) + if existing and existing.get("status") == "completed": + # Return existing minutes + return existing + + # Get transcript text + transcript_text = transcription.get("full_text", "") + if not transcript_text: + raise HTTPException(status_code=400, detail="Transcription has no text content") + + # Generate meeting minutes + generator = get_minutes_generator() + + try: + minutes = await generator.generate( + transcript=transcript_text, + recording_id=recording_id, + transcription_id=transcription["id"], + title=title, + date=recording.get("recorded_at", "")[:10] if recording.get("recorded_at") else None, + duration_minutes=recording.get("duration_seconds", 0) // 60 if recording.get("duration_seconds") else None, + participant_count=recording.get("participant_count", 0), + model=model + ) + + # Store minutes + minutes_dict = minutes.model_dump() + minutes_dict["generated_at"] = minutes.generated_at.isoformat() + _minutes_store[recording_id] = minutes_dict + + # Log action + log_audit( + action="minutes_generated", + recording_id=recording_id, + metadata={"model": model, "generation_time": minutes.generation_time_seconds} + ) + + return minutes_dict + + except Exception as e: + raise HTTPException(status_code=500, detail=f"Minutes generation failed: {str(e)}") + + +@router.get("/{recording_id}/minutes") +async def get_meeting_minutes(recording_id: str): + """ + Ruft generierte Meeting Minutes ab. + """ + minutes = _minutes_store.get(recording_id) + if not minutes: + raise HTTPException(status_code=404, detail="No meeting minutes found. Generate them first with POST.") + + return minutes + + +def _load_minutes(recording_id: str): + """Load and convert stored minutes dict back to MeetingMinutes.""" + from meeting_minutes_generator import MeetingMinutes + + minutes_dict = _minutes_store.get(recording_id) + if not minutes_dict: + raise HTTPException(status_code=404, detail="No meeting minutes found") + + minutes_dict_copy = minutes_dict.copy() + if isinstance(minutes_dict_copy.get("generated_at"), str): + minutes_dict_copy["generated_at"] = datetime.fromisoformat(minutes_dict_copy["generated_at"]) + + return MeetingMinutes(**minutes_dict_copy) + + +@router.get("/{recording_id}/minutes/markdown") +async def get_minutes_markdown(recording_id: str): + """ + Exportiert Meeting Minutes als Markdown. + """ + from meeting_minutes_generator import minutes_to_markdown + + minutes = _load_minutes(recording_id) + markdown = minutes_to_markdown(minutes) + + return PlainTextResponse( + content=markdown, + media_type="text/markdown", + headers={"Content-Disposition": f"attachment; filename=protokoll_{recording_id}.md"} + ) + + +@router.get("/{recording_id}/minutes/html") +async def get_minutes_html(recording_id: str): + """ + Exportiert Meeting Minutes als HTML. + """ + from meeting_minutes_generator import minutes_to_html + + minutes = _load_minutes(recording_id) + html = minutes_to_html(minutes) + + return HTMLResponse(content=html) + + +@router.get("/{recording_id}/minutes/pdf") +async def get_minutes_pdf(recording_id: str): + """ + Exportiert Meeting Minutes als PDF. + + Benoetigt WeasyPrint (pip install weasyprint). + """ + from meeting_minutes_generator import minutes_to_html + + minutes = _load_minutes(recording_id) + html = minutes_to_html(minutes) + + try: + from weasyprint import HTML + from fastapi.responses import Response + + pdf_bytes = HTML(string=html).write_pdf() + + return Response( + content=pdf_bytes, + media_type="application/pdf", + headers={"Content-Disposition": f"attachment; filename=protokoll_{recording_id}.pdf"} + ) + except ImportError: + raise HTTPException( + status_code=501, + detail="PDF export not available. Install weasyprint: pip install weasyprint" + ) diff --git a/backend-lehrer/recording_models.py b/backend-lehrer/recording_models.py new file mode 100644 index 0000000..4275ae8 --- /dev/null +++ b/backend-lehrer/recording_models.py @@ -0,0 +1,98 @@ +""" +Recording API - Pydantic Models & Configuration. + +Data models for recording, transcription, and webhook endpoints. +""" + +import os +from datetime import datetime +from typing import Optional, List + +from pydantic import BaseModel, Field + + +# ========================================== +# ENVIRONMENT CONFIGURATION +# ========================================== + +MINIO_ENDPOINT = os.getenv("MINIO_ENDPOINT", "minio:9000") +MINIO_ACCESS_KEY = os.getenv("MINIO_ACCESS_KEY", "breakpilot") +MINIO_SECRET_KEY = os.getenv("MINIO_SECRET_KEY", "breakpilot123") +MINIO_BUCKET = os.getenv("MINIO_BUCKET", "breakpilot-recordings") +MINIO_SECURE = os.getenv("MINIO_SECURE", "false").lower() == "true" + +# Default retention period in days (DSGVO compliance) +DEFAULT_RETENTION_DAYS = int(os.getenv("RECORDING_RETENTION_DAYS", "365")) + + +# ========================================== +# PYDANTIC MODELS +# ========================================== + +class JibriWebhookPayload(BaseModel): + """Webhook payload from Jibri finalize.sh script.""" + event: str = Field(..., description="Event type: recording_completed") + recording_name: str = Field(..., description="Unique recording identifier") + storage_path: str = Field(..., description="Path in MinIO bucket") + audio_path: Optional[str] = Field(None, description="Extracted audio path") + file_size_bytes: int = Field(..., description="Video file size in bytes") + timestamp: str = Field(..., description="ISO timestamp of upload") + + +class RecordingCreate(BaseModel): + """Manual recording creation (for testing).""" + meeting_id: str + title: Optional[str] = None + storage_path: str + audio_path: Optional[str] = None + duration_seconds: Optional[int] = None + participant_count: Optional[int] = 0 + retention_days: Optional[int] = DEFAULT_RETENTION_DAYS + + +class RecordingResponse(BaseModel): + """Recording details response.""" + id: str + meeting_id: str + title: Optional[str] + storage_path: str + audio_path: Optional[str] + file_size_bytes: Optional[int] + duration_seconds: Optional[int] + participant_count: int + status: str + recorded_at: datetime + retention_days: int + retention_expires_at: datetime + transcription_status: Optional[str] = None + transcription_id: Optional[str] = None + + +class RecordingListResponse(BaseModel): + """Paginated list of recordings.""" + recordings: List[RecordingResponse] + total: int + page: int + page_size: int + + +class TranscriptionRequest(BaseModel): + """Request to start transcription.""" + language: str = Field(default="de", description="Language code: de, en, etc.") + model: str = Field(default="large-v3", description="Whisper model to use") + priority: int = Field(default=0, description="Queue priority (higher = sooner)") + + +class TranscriptionStatusResponse(BaseModel): + """Transcription status and progress.""" + id: str + recording_id: str + status: str + language: str + model: str + word_count: Optional[int] + confidence_score: Optional[float] + processing_duration_seconds: Optional[int] + error_message: Optional[str] + created_at: datetime + completed_at: Optional[datetime] diff --git a/backend-lehrer/recording_routes.py b/backend-lehrer/recording_routes.py new file mode 100644 index 0000000..6df46dd --- /dev/null +++ b/backend-lehrer/recording_routes.py @@ -0,0 +1,307 @@ +""" +Recording API - Core Recording Routes. + +Webhook, CRUD, health, audit, and download endpoints. +""" + +import uuid +from datetime import datetime, timedelta +from typing import Optional + +from fastapi import APIRouter, HTTPException, Query, Request +from fastapi.responses import JSONResponse + +from recording_models import ( + JibriWebhookPayload, + RecordingResponse, + RecordingListResponse, + MINIO_ENDPOINT, + MINIO_BUCKET, + DEFAULT_RETENTION_DAYS, +) +from recording_helpers import ( + _recordings_store, + _transcriptions_store, + _audit_log, + log_audit, +) + +router = APIRouter(tags=["Recordings"]) + + +# ========================================== +# WEBHOOK ENDPOINT (Jibri) +# ========================================== + +@router.post("/webhook") +async def jibri_webhook(payload: JibriWebhookPayload, request: Request): + """ + Webhook endpoint called by Jibri finalize.sh after upload. + + This creates a new recording entry and optionally triggers transcription. + """ + if payload.event != "recording_completed": + return JSONResponse( + status_code=400, + content={"error": f"Unknown event type: {payload.event}"} + ) + + # Extract meeting_id from recording_name (format: meetingId_timestamp) + parts = payload.recording_name.split("_") + meeting_id = parts[0] if parts else payload.recording_name + + # Create recording entry + recording_id = str(uuid.uuid4()) + recorded_at = datetime.utcnow() + + recording = { + "id": recording_id, + "meeting_id": meeting_id, + "jibri_session_id": payload.recording_name, + "title": f"Recording {meeting_id}", + "storage_path": payload.storage_path, + "audio_path": payload.audio_path, + "file_size_bytes": payload.file_size_bytes, + "duration_seconds": None, # Will be updated after analysis + "participant_count": 0, + "status": "uploaded", + "recorded_at": recorded_at.isoformat(), + "retention_days": DEFAULT_RETENTION_DAYS, + "created_at": datetime.utcnow().isoformat(), + "updated_at": datetime.utcnow().isoformat() + } + + _recordings_store[recording_id] = recording + + # Log the creation + log_audit( + action="created", + recording_id=recording_id, + metadata={ + "source": "jibri_webhook", + "storage_path": payload.storage_path, + "file_size_bytes": payload.file_size_bytes + } + ) + + return { + "success": True, + "recording_id": recording_id, + "meeting_id": meeting_id, + "status": "uploaded", + "message": "Recording registered successfully" + } + + +# ========================================== +# HEALTH & AUDIT ENDPOINTS (must be before parameterized routes) +# ========================================== + +@router.get("/health") +async def recordings_health(): + """Health check for recording service.""" + return { + "status": "healthy", + "recordings_count": len(_recordings_store), + "transcriptions_count": len(_transcriptions_store), + "minio_endpoint": MINIO_ENDPOINT, + "bucket": MINIO_BUCKET + } + + +@router.get("/audit/log") +async def get_audit_log( + recording_id: Optional[str] = Query(None), + action: Optional[str] = Query(None), + limit: int = Query(100, ge=1, le=1000) +): + """ + Get audit log entries (DSGVO compliance). + + Admin-only endpoint for reviewing recording access history. + """ + logs = _audit_log.copy() + + if recording_id: + logs = [l for l in logs if l.get("recording_id") == recording_id] + if action: + logs = [l for l in logs if l.get("action") == action] + + # Sort by created_at descending + logs.sort(key=lambda x: x["created_at"], reverse=True) + + return { + "entries": logs[:limit], + "total": len(logs) + } + + +# ========================================== +# RECORDING MANAGEMENT ENDPOINTS +# ========================================== + +@router.get("", response_model=RecordingListResponse) +async def list_recordings( + status: Optional[str] = Query(None, description="Filter by status"), + meeting_id: Optional[str] = Query(None, description="Filter by meeting ID"), + page: int = Query(1, ge=1, description="Page number"), + page_size: int = Query(20, ge=1, le=100, description="Items per page") +): + """ + List all recordings with optional filtering. + + Supports pagination and filtering by status or meeting ID. + """ + # Filter recordings + recordings = list(_recordings_store.values()) + + if status: + recordings = [r for r in recordings if r["status"] == status] + if meeting_id: + recordings = [r for r in recordings if r["meeting_id"] == meeting_id] + + # Sort by recorded_at descending + recordings.sort(key=lambda x: x["recorded_at"], reverse=True) + + # Paginate + total = len(recordings) + start = (page - 1) * page_size + end = start + page_size + page_recordings = recordings[start:end] + + # Convert to response format + result = [] + for rec in page_recordings: + recorded_at = datetime.fromisoformat(rec["recorded_at"]) + retention_expires = recorded_at + timedelta(days=rec["retention_days"]) + + # Check for transcription + trans = next( + (t for t in _transcriptions_store.values() if t["recording_id"] == rec["id"]), + None + ) + + result.append(RecordingResponse( + id=rec["id"], + meeting_id=rec["meeting_id"], + title=rec.get("title"), + storage_path=rec["storage_path"], + audio_path=rec.get("audio_path"), + file_size_bytes=rec.get("file_size_bytes"), + duration_seconds=rec.get("duration_seconds"), + participant_count=rec.get("participant_count", 0), + status=rec["status"], + recorded_at=recorded_at, + retention_days=rec["retention_days"], + retention_expires_at=retention_expires, + transcription_status=trans["status"] if trans else None, + transcription_id=trans["id"] if trans else None + )) + + return RecordingListResponse( + recordings=result, + total=total, + page=page, + page_size=page_size + ) + + +@router.get("/{recording_id}", response_model=RecordingResponse) +async def get_recording(recording_id: str): + """ + Get details for a specific recording. + """ + recording = _recordings_store.get(recording_id) + if not recording: + raise HTTPException(status_code=404, detail="Recording not found") + + # Log view action + log_audit(action="viewed", recording_id=recording_id) + + recorded_at = datetime.fromisoformat(recording["recorded_at"]) + retention_expires = recorded_at + timedelta(days=recording["retention_days"]) + + # Check for transcription + trans = next( + (t for t in _transcriptions_store.values() if t["recording_id"] == recording_id), + None + ) + + return RecordingResponse( + id=recording["id"], + meeting_id=recording["meeting_id"], + title=recording.get("title"), + storage_path=recording["storage_path"], + audio_path=recording.get("audio_path"), + file_size_bytes=recording.get("file_size_bytes"), + duration_seconds=recording.get("duration_seconds"), + participant_count=recording.get("participant_count", 0), + status=recording["status"], + recorded_at=recorded_at, + retention_days=recording["retention_days"], + retention_expires_at=retention_expires, + transcription_status=trans["status"] if trans else None, + transcription_id=trans["id"] if trans else None + ) + + +@router.delete("/{recording_id}") +async def delete_recording( + recording_id: str, + reason: str = Query(..., description="Reason for deletion (DSGVO audit)") +): + """ + Soft-delete a recording (DSGVO compliance). + + The recording is marked as deleted but retained for audit purposes. + Actual file deletion happens after the audit retention period. + """ + recording = _recordings_store.get(recording_id) + if not recording: + raise HTTPException(status_code=404, detail="Recording not found") + + # Soft delete + recording["status"] = "deleted" + recording["deleted_at"] = datetime.utcnow().isoformat() + recording["updated_at"] = datetime.utcnow().isoformat() + + # Log deletion with reason + log_audit( + action="deleted", + recording_id=recording_id, + metadata={"reason": reason} + ) + + return { + "success": True, + "recording_id": recording_id, + "status": "deleted", + "message": "Recording marked for deletion" + } + + +@router.get("/{recording_id}/download") +async def download_recording(recording_id: str): + """ + Download the recording file. + + In production, this would generate a presigned URL to MinIO. + """ + recording = _recordings_store.get(recording_id) + if not recording: + raise HTTPException(status_code=404, detail="Recording not found") + + if recording["status"] == "deleted": + raise HTTPException(status_code=410, detail="Recording has been deleted") + + # Log download action + log_audit(action="downloaded", recording_id=recording_id) + + # In production, generate presigned URL to MinIO + # For now, return info about where the file is + return { + "recording_id": recording_id, + "storage_path": recording["storage_path"], + "file_size_bytes": recording.get("file_size_bytes"), + "message": "In production, this would redirect to a presigned MinIO URL" + } diff --git a/backend-lehrer/recording_transcription.py b/backend-lehrer/recording_transcription.py new file mode 100644 index 0000000..0f8e84c --- /dev/null +++ b/backend-lehrer/recording_transcription.py @@ -0,0 +1,250 @@ +""" +Recording API - Transcription Routes. + +Start transcription, get status, download VTT/SRT subtitle files. +""" + +import uuid +from datetime import datetime +from typing import Optional + +from fastapi import APIRouter, HTTPException +from fastapi.responses import PlainTextResponse + +from recording_models import ( + TranscriptionRequest, + TranscriptionStatusResponse, +) +from recording_helpers import ( + _recordings_store, + _transcriptions_store, + log_audit, + format_vtt_time, + format_srt_time, +) + +router = APIRouter(tags=["Recordings"]) + + +# ========================================== +# TRANSCRIPTION ENDPOINTS +# ========================================== + +@router.post("/{recording_id}/transcribe", response_model=TranscriptionStatusResponse) +async def start_transcription(recording_id: str, request: TranscriptionRequest): + """ + Start transcription for a recording. + + Queues the recording for processing by the transcription worker. + """ + recording = _recordings_store.get(recording_id) + if not recording: + raise HTTPException(status_code=404, detail="Recording not found") + + if recording["status"] == "deleted": + raise HTTPException(status_code=400, detail="Cannot transcribe deleted recording") + + # Check if transcription already exists + existing = next( + (t for t in _transcriptions_store.values() + if t["recording_id"] == recording_id and t["status"] != "failed"), + None + ) + if existing: + raise HTTPException( + status_code=409, + detail=f"Transcription already exists with status: {existing['status']}" + ) + + # Create transcription entry + transcription_id = str(uuid.uuid4()) + now = datetime.utcnow() + + transcription = { + "id": transcription_id, + "recording_id": recording_id, + "language": request.language, + "model": request.model, + "status": "pending", + "full_text": None, + "word_count": None, + "confidence_score": None, + "vtt_path": None, + "srt_path": None, + "json_path": None, + "error_message": None, + "processing_started_at": None, + "processing_completed_at": None, + "processing_duration_seconds": None, + "created_at": now.isoformat(), + "updated_at": now.isoformat() + } + + _transcriptions_store[transcription_id] = transcription + + # Update recording status + recording["status"] = "processing" + recording["updated_at"] = now.isoformat() + + # Log transcription start + log_audit( + action="transcription_started", + recording_id=recording_id, + transcription_id=transcription_id, + metadata={"language": request.language, "model": request.model} + ) + + # TODO: Queue job to Redis/Valkey for transcription worker + + return TranscriptionStatusResponse( + id=transcription_id, + recording_id=recording_id, + status="pending", + language=request.language, + model=request.model, + word_count=None, + confidence_score=None, + processing_duration_seconds=None, + error_message=None, + created_at=now, + completed_at=None + ) + + +@router.get("/{recording_id}/transcription", response_model=TranscriptionStatusResponse) +async def get_transcription_status(recording_id: str): + """ + Get transcription status for a recording. + """ + transcription = next( + (t for t in _transcriptions_store.values() if t["recording_id"] == recording_id), + None + ) + if not transcription: + raise HTTPException(status_code=404, detail="No transcription found for this recording") + + return TranscriptionStatusResponse( + id=transcription["id"], + recording_id=transcription["recording_id"], + status=transcription["status"], + language=transcription["language"], + model=transcription["model"], + word_count=transcription.get("word_count"), + confidence_score=transcription.get("confidence_score"), + processing_duration_seconds=transcription.get("processing_duration_seconds"), + error_message=transcription.get("error_message"), + created_at=datetime.fromisoformat(transcription["created_at"]), + completed_at=( + datetime.fromisoformat(transcription["processing_completed_at"]) + if transcription.get("processing_completed_at") else None + ) + ) + + +@router.get("/{recording_id}/transcription/text") +async def get_transcription_text(recording_id: str): + """ + Get the full transcription text. + """ + transcription = next( + (t for t in _transcriptions_store.values() if t["recording_id"] == recording_id), + None + ) + if not transcription: + raise HTTPException(status_code=404, detail="No transcription found for this recording") + + if transcription["status"] != "completed": + raise HTTPException( + status_code=400, + detail=f"Transcription not ready. Status: {transcription['status']}" + ) + + return { + "transcription_id": transcription["id"], + "recording_id": recording_id, + "language": transcription["language"], + "text": transcription.get("full_text", ""), + "word_count": transcription.get("word_count", 0) + } + + +@router.get("/{recording_id}/transcription/vtt") +async def get_transcription_vtt(recording_id: str): + """ + Download transcription as WebVTT subtitle file. + """ + transcription = next( + (t for t in _transcriptions_store.values() if t["recording_id"] == recording_id), + None + ) + if not transcription: + raise HTTPException(status_code=404, detail="No transcription found for this recording") + + if transcription["status"] != "completed": + raise HTTPException( + status_code=400, + detail=f"Transcription not ready. Status: {transcription['status']}" + ) + + # Generate VTT content + vtt_content = "WEBVTT\n\n" + text = transcription.get("full_text", "") + + if text: + sentences = text.replace(".", ".\n").split("\n") + time_offset = 0 + for sentence in sentences: + sentence = sentence.strip() + if sentence: + start = format_vtt_time(time_offset) + time_offset += 3000 + end = format_vtt_time(time_offset) + vtt_content += f"{start} --> {end}\n{sentence}\n\n" + + return PlainTextResponse( + content=vtt_content, + media_type="text/vtt", + headers={"Content-Disposition": f"attachment; filename=transcript_{recording_id}.vtt"} + ) + + +@router.get("/{recording_id}/transcription/srt") +async def get_transcription_srt(recording_id: str): + """ + Download transcription as SRT subtitle file. + """ + transcription = next( + (t for t in _transcriptions_store.values() if t["recording_id"] == recording_id), + None + ) + if not transcription: + raise HTTPException(status_code=404, detail="No transcription found for this recording") + + if transcription["status"] != "completed": + raise HTTPException( + status_code=400, + detail=f"Transcription not ready. Status: {transcription['status']}" + ) + + # Generate SRT content + srt_content = "" + text = transcription.get("full_text", "") + + if text: + sentences = text.replace(".", ".\n").split("\n") + time_offset = 0 + index = 1 + for sentence in sentences: + sentence = sentence.strip() + if sentence: + start = format_srt_time(time_offset) + time_offset += 3000 + end = format_srt_time(time_offset) + srt_content += f"{index}\n{start} --> {end}\n{sentence}\n\n" + index += 1 + + return PlainTextResponse( + content=srt_content, + media_type="text/plain", + headers={"Content-Disposition": f"attachment; filename=transcript_{recording_id}.srt"} + ) diff --git a/backend-lehrer/unit_analytics_api.py b/backend-lehrer/unit_analytics_api.py index d919910..2f9856f 100644 --- a/backend-lehrer/unit_analytics_api.py +++ b/backend-lehrer/unit_analytics_api.py @@ -1,751 +1,25 @@ -# ============================================== -# Breakpilot Drive - Unit Analytics API -# ============================================== -# Erweiterte Analytics fuer Lernfortschritt: -# - Pre/Post Gain Visualisierung -# - Misconception-Tracking -# - Stop-Level Analytics -# - Aggregierte Klassen-Statistiken -# - Export-Funktionen +""" +Breakpilot Drive - Unit Analytics API — Barrel Re-export. -from fastapi import APIRouter, HTTPException, Query, Depends, Request -from pydantic import BaseModel, Field -from typing import List, Optional, Dict, Any -from datetime import datetime, timedelta -from enum import Enum -import os -import logging -import statistics +Erweiterte Analytics fuer Lernfortschritt: +- Pre/Post Gain Visualisierung +- Misconception-Tracking +- Stop-Level Analytics +- Aggregierte Klassen-Statistiken +- Export-Funktionen -logger = logging.getLogger(__name__) +Split into: + - unit_analytics_models.py: Pydantic models & enums + - unit_analytics_helpers.py: Database access & computation helpers + - unit_analytics_routes.py: Core analytics endpoint handlers + - unit_analytics_export.py: Export & dashboard endpoints +""" -# Feature flags -USE_DATABASE = os.getenv("GAME_USE_DATABASE", "true").lower() == "true" +from fastapi import APIRouter + +from unit_analytics_routes import router as _routes_router +from unit_analytics_export import router as _export_router router = APIRouter(prefix="/api/analytics", tags=["Unit Analytics"]) - - -# ============================================== -# Pydantic Models -# ============================================== - -class TimeRange(str, Enum): - """Time range for analytics queries""" - WEEK = "week" - MONTH = "month" - QUARTER = "quarter" - ALL = "all" - - -class LearningGainData(BaseModel): - """Pre/Post learning gain data point""" - student_id: str - student_name: str - unit_id: str - precheck_score: float - postcheck_score: float - learning_gain: float - percentile: Optional[float] = None - - -class LearningGainSummary(BaseModel): - """Aggregated learning gain statistics""" - unit_id: str - unit_title: str - total_students: int - avg_precheck: float - avg_postcheck: float - avg_gain: float - median_gain: float - std_deviation: float - positive_gain_count: int - negative_gain_count: int - no_change_count: int - gain_distribution: Dict[str, int] # "-20+", "-10-0", "0-10", "10-20", "20+" - individual_gains: List[LearningGainData] - - -class StopPerformance(BaseModel): - """Performance data for a single stop""" - stop_id: str - stop_label: str - attempts_total: int - success_rate: float - avg_time_seconds: float - avg_attempts_before_success: float - common_errors: List[str] - difficulty_rating: float # 1-5 based on performance - - -class UnitPerformanceDetail(BaseModel): - """Detailed unit performance breakdown""" - unit_id: str - unit_title: str - template: str - total_sessions: int - completed_sessions: int - completion_rate: float - avg_duration_minutes: float - stops: List[StopPerformance] - bottleneck_stops: List[str] # Stops where students struggle most - - -class MisconceptionEntry(BaseModel): - """Individual misconception tracking""" - concept_id: str - concept_label: str - misconception_text: str - frequency: int - affected_student_ids: List[str] - unit_id: str - stop_id: str - detected_via: str # "precheck", "postcheck", "interaction" - first_detected: datetime - last_detected: datetime - - -class MisconceptionReport(BaseModel): - """Comprehensive misconception report""" - class_id: Optional[str] - time_range: str - total_misconceptions: int - unique_concepts: int - most_common: List[MisconceptionEntry] - by_unit: Dict[str, List[MisconceptionEntry]] - trending_up: List[MisconceptionEntry] # Getting more frequent - resolved: List[MisconceptionEntry] # No longer appearing - - -class StudentProgressTimeline(BaseModel): - """Timeline of student progress""" - student_id: str - student_name: str - units_completed: int - total_time_minutes: int - avg_score: float - trend: str # "improving", "stable", "declining" - timeline: List[Dict[str, Any]] # List of session events - - -class ClassComparisonData(BaseModel): - """Data for comparing class performance""" - class_id: str - class_name: str - student_count: int - units_assigned: int - avg_completion_rate: float - avg_learning_gain: float - avg_time_per_unit: float - - -class ExportFormat(str, Enum): - """Export format options""" - JSON = "json" - CSV = "csv" - - -# ============================================== -# Database Integration -# ============================================== - -_analytics_db = None - -async def get_analytics_database(): - """Get analytics database instance.""" - global _analytics_db - if not USE_DATABASE: - return None - if _analytics_db is None: - try: - from unit.database import get_analytics_db - _analytics_db = await get_analytics_db() - logger.info("Analytics database initialized") - except ImportError: - logger.warning("Analytics database module not available") - except Exception as e: - logger.warning(f"Analytics database not available: {e}") - return _analytics_db - - -# ============================================== -# Helper Functions -# ============================================== - -def calculate_gain_distribution(gains: List[float]) -> Dict[str, int]: - """Calculate distribution of learning gains into buckets.""" - distribution = { - "< -20%": 0, - "-20% to -10%": 0, - "-10% to 0%": 0, - "0% to 10%": 0, - "10% to 20%": 0, - "> 20%": 0, - } - - for gain in gains: - gain_percent = gain * 100 - if gain_percent < -20: - distribution["< -20%"] += 1 - elif gain_percent < -10: - distribution["-20% to -10%"] += 1 - elif gain_percent < 0: - distribution["-10% to 0%"] += 1 - elif gain_percent < 10: - distribution["0% to 10%"] += 1 - elif gain_percent < 20: - distribution["10% to 20%"] += 1 - else: - distribution["> 20%"] += 1 - - return distribution - - -def calculate_trend(scores: List[float]) -> str: - """Calculate trend from a series of scores.""" - if len(scores) < 3: - return "insufficient_data" - - # Simple linear regression - n = len(scores) - x_mean = (n - 1) / 2 - y_mean = sum(scores) / n - - numerator = sum((i - x_mean) * (scores[i] - y_mean) for i in range(n)) - denominator = sum((i - x_mean) ** 2 for i in range(n)) - - if denominator == 0: - return "stable" - - slope = numerator / denominator - - if slope > 0.05: - return "improving" - elif slope < -0.05: - return "declining" - else: - return "stable" - - -def calculate_difficulty_rating(success_rate: float, avg_attempts: float) -> float: - """Calculate difficulty rating 1-5 based on success metrics.""" - # Lower success rate and higher attempts = higher difficulty - base_difficulty = (1 - success_rate) * 3 + 1 # 1-4 range - attempt_modifier = min(avg_attempts - 1, 1) # 0-1 range - return min(5.0, base_difficulty + attempt_modifier) - - -# ============================================== -# API Endpoints - Learning Gain -# ============================================== - -# NOTE: Static routes must come BEFORE dynamic routes like /{unit_id} -@router.get("/learning-gain/compare") -async def compare_learning_gains( - unit_ids: str = Query(..., description="Comma-separated unit IDs"), - class_id: Optional[str] = Query(None), - time_range: TimeRange = Query(TimeRange.MONTH), -) -> Dict[str, Any]: - """ - Compare learning gains across multiple units. - """ - unit_list = [u.strip() for u in unit_ids.split(",")] - comparisons = [] - - for unit_id in unit_list: - try: - summary = await get_learning_gain_analysis(unit_id, class_id, time_range) - comparisons.append({ - "unit_id": unit_id, - "avg_gain": summary.avg_gain, - "median_gain": summary.median_gain, - "total_students": summary.total_students, - "positive_rate": summary.positive_gain_count / max(summary.total_students, 1), - }) - except Exception as e: - logger.error(f"Failed to get comparison for {unit_id}: {e}") - - return { - "time_range": time_range.value, - "class_id": class_id, - "comparisons": sorted(comparisons, key=lambda x: x["avg_gain"], reverse=True), - } - - -@router.get("/learning-gain/{unit_id}", response_model=LearningGainSummary) -async def get_learning_gain_analysis( - unit_id: str, - class_id: Optional[str] = Query(None, description="Filter by class"), - time_range: TimeRange = Query(TimeRange.MONTH, description="Time range for analysis"), -) -> LearningGainSummary: - """ - Get detailed pre/post learning gain analysis for a unit. - - Shows individual gains, aggregated statistics, and distribution. - """ - db = await get_analytics_database() - individual_gains = [] - - if db: - try: - # Get all sessions with pre/post scores for this unit - sessions = await db.get_unit_sessions_with_scores( - unit_id=unit_id, - class_id=class_id, - time_range=time_range.value - ) - - for session in sessions: - if session.get("precheck_score") is not None and session.get("postcheck_score") is not None: - gain = session["postcheck_score"] - session["precheck_score"] - individual_gains.append(LearningGainData( - student_id=session["student_id"], - student_name=session.get("student_name", session["student_id"][:8]), - unit_id=unit_id, - precheck_score=session["precheck_score"], - postcheck_score=session["postcheck_score"], - learning_gain=gain, - )) - except Exception as e: - logger.error(f"Failed to get learning gain data: {e}") - - # Calculate statistics - if not individual_gains: - # Return empty summary - return LearningGainSummary( - unit_id=unit_id, - unit_title=f"Unit {unit_id}", - total_students=0, - avg_precheck=0.0, - avg_postcheck=0.0, - avg_gain=0.0, - median_gain=0.0, - std_deviation=0.0, - positive_gain_count=0, - negative_gain_count=0, - no_change_count=0, - gain_distribution={}, - individual_gains=[], - ) - - gains = [g.learning_gain for g in individual_gains] - prechecks = [g.precheck_score for g in individual_gains] - postchecks = [g.postcheck_score for g in individual_gains] - - avg_gain = statistics.mean(gains) - median_gain = statistics.median(gains) - std_dev = statistics.stdev(gains) if len(gains) > 1 else 0.0 - - # Calculate percentiles - sorted_gains = sorted(gains) - for data in individual_gains: - rank = sorted_gains.index(data.learning_gain) + 1 - data.percentile = rank / len(sorted_gains) * 100 - - return LearningGainSummary( - unit_id=unit_id, - unit_title=f"Unit {unit_id}", - total_students=len(individual_gains), - avg_precheck=statistics.mean(prechecks), - avg_postcheck=statistics.mean(postchecks), - avg_gain=avg_gain, - median_gain=median_gain, - std_deviation=std_dev, - positive_gain_count=sum(1 for g in gains if g > 0.01), - negative_gain_count=sum(1 for g in gains if g < -0.01), - no_change_count=sum(1 for g in gains if -0.01 <= g <= 0.01), - gain_distribution=calculate_gain_distribution(gains), - individual_gains=sorted(individual_gains, key=lambda x: x.learning_gain, reverse=True), - ) - - -# ============================================== -# API Endpoints - Stop-Level Analytics -# ============================================== - -@router.get("/unit/{unit_id}/stops", response_model=UnitPerformanceDetail) -async def get_unit_stop_analytics( - unit_id: str, - class_id: Optional[str] = Query(None), - time_range: TimeRange = Query(TimeRange.MONTH), -) -> UnitPerformanceDetail: - """ - Get detailed stop-level performance analytics. - - Identifies bottleneck stops where students struggle most. - """ - db = await get_analytics_database() - stops_data = [] - - if db: - try: - # Get stop-level telemetry - stop_stats = await db.get_stop_performance( - unit_id=unit_id, - class_id=class_id, - time_range=time_range.value - ) - - for stop in stop_stats: - difficulty = calculate_difficulty_rating( - stop.get("success_rate", 0.5), - stop.get("avg_attempts", 1.0) - ) - stops_data.append(StopPerformance( - stop_id=stop["stop_id"], - stop_label=stop.get("stop_label", stop["stop_id"]), - attempts_total=stop.get("total_attempts", 0), - success_rate=stop.get("success_rate", 0.0), - avg_time_seconds=stop.get("avg_time_seconds", 0.0), - avg_attempts_before_success=stop.get("avg_attempts", 1.0), - common_errors=stop.get("common_errors", []), - difficulty_rating=difficulty, - )) - - # Get overall unit stats - unit_stats = await db.get_unit_overall_stats(unit_id, class_id, time_range.value) - except Exception as e: - logger.error(f"Failed to get stop analytics: {e}") - unit_stats = {} - else: - unit_stats = {} - - # Identify bottleneck stops (difficulty > 3.5 or success rate < 0.6) - bottlenecks = [ - s.stop_id for s in stops_data - if s.difficulty_rating > 3.5 or s.success_rate < 0.6 - ] - - return UnitPerformanceDetail( - unit_id=unit_id, - unit_title=f"Unit {unit_id}", - template=unit_stats.get("template", "unknown"), - total_sessions=unit_stats.get("total_sessions", 0), - completed_sessions=unit_stats.get("completed_sessions", 0), - completion_rate=unit_stats.get("completion_rate", 0.0), - avg_duration_minutes=unit_stats.get("avg_duration_minutes", 0.0), - stops=stops_data, - bottleneck_stops=bottlenecks, - ) - - -# ============================================== -# API Endpoints - Misconception Tracking -# ============================================== - -@router.get("/misconceptions", response_model=MisconceptionReport) -async def get_misconception_report( - class_id: Optional[str] = Query(None), - unit_id: Optional[str] = Query(None), - time_range: TimeRange = Query(TimeRange.MONTH), - limit: int = Query(20, ge=1, le=100), -) -> MisconceptionReport: - """ - Get comprehensive misconception report. - - Shows most common misconceptions and their frequency. - """ - db = await get_analytics_database() - misconceptions = [] - - if db: - try: - raw_misconceptions = await db.get_misconceptions( - class_id=class_id, - unit_id=unit_id, - time_range=time_range.value, - limit=limit - ) - - for m in raw_misconceptions: - misconceptions.append(MisconceptionEntry( - concept_id=m["concept_id"], - concept_label=m["concept_label"], - misconception_text=m["misconception_text"], - frequency=m["frequency"], - affected_student_ids=m.get("student_ids", []), - unit_id=m["unit_id"], - stop_id=m["stop_id"], - detected_via=m.get("detected_via", "unknown"), - first_detected=m.get("first_detected", datetime.utcnow()), - last_detected=m.get("last_detected", datetime.utcnow()), - )) - except Exception as e: - logger.error(f"Failed to get misconceptions: {e}") - - # Group by unit - by_unit = {} - for m in misconceptions: - if m.unit_id not in by_unit: - by_unit[m.unit_id] = [] - by_unit[m.unit_id].append(m) - - # Identify trending misconceptions (would need historical comparison in production) - trending_up = misconceptions[:3] if misconceptions else [] - resolved = [] # Would identify from historical data - - return MisconceptionReport( - class_id=class_id, - time_range=time_range.value, - total_misconceptions=sum(m.frequency for m in misconceptions), - unique_concepts=len(set(m.concept_id for m in misconceptions)), - most_common=sorted(misconceptions, key=lambda x: x.frequency, reverse=True)[:10], - by_unit=by_unit, - trending_up=trending_up, - resolved=resolved, - ) - - -@router.get("/misconceptions/student/{student_id}") -async def get_student_misconceptions( - student_id: str, - time_range: TimeRange = Query(TimeRange.ALL), -) -> Dict[str, Any]: - """ - Get misconceptions for a specific student. - - Useful for personalized remediation. - """ - db = await get_analytics_database() - - if db: - try: - misconceptions = await db.get_student_misconceptions( - student_id=student_id, - time_range=time_range.value - ) - return { - "student_id": student_id, - "misconceptions": misconceptions, - "recommended_remediation": [ - {"concept": m["concept_label"], "activity": f"Review {m['unit_id']}/{m['stop_id']}"} - for m in misconceptions[:5] - ] - } - except Exception as e: - logger.error(f"Failed to get student misconceptions: {e}") - - return { - "student_id": student_id, - "misconceptions": [], - "recommended_remediation": [], - } - - -# ============================================== -# API Endpoints - Student Progress Timeline -# ============================================== - -@router.get("/student/{student_id}/timeline", response_model=StudentProgressTimeline) -async def get_student_timeline( - student_id: str, - time_range: TimeRange = Query(TimeRange.ALL), -) -> StudentProgressTimeline: - """ - Get detailed progress timeline for a student. - - Shows all unit sessions and performance trend. - """ - db = await get_analytics_database() - timeline = [] - scores = [] - - if db: - try: - sessions = await db.get_student_sessions( - student_id=student_id, - time_range=time_range.value - ) - - for session in sessions: - timeline.append({ - "date": session.get("started_at"), - "unit_id": session.get("unit_id"), - "completed": session.get("completed_at") is not None, - "precheck": session.get("precheck_score"), - "postcheck": session.get("postcheck_score"), - "duration_minutes": session.get("duration_seconds", 0) // 60, - }) - if session.get("postcheck_score") is not None: - scores.append(session["postcheck_score"]) - except Exception as e: - logger.error(f"Failed to get student timeline: {e}") - - trend = calculate_trend(scores) if scores else "insufficient_data" - - return StudentProgressTimeline( - student_id=student_id, - student_name=f"Student {student_id[:8]}", # Would load actual name - units_completed=sum(1 for t in timeline if t["completed"]), - total_time_minutes=sum(t["duration_minutes"] for t in timeline), - avg_score=statistics.mean(scores) if scores else 0.0, - trend=trend, - timeline=timeline, - ) - - -# ============================================== -# API Endpoints - Class Comparison -# ============================================== - -@router.get("/compare/classes", response_model=List[ClassComparisonData]) -async def compare_classes( - class_ids: str = Query(..., description="Comma-separated class IDs"), - time_range: TimeRange = Query(TimeRange.MONTH), -) -> List[ClassComparisonData]: - """ - Compare performance across multiple classes. - """ - class_list = [c.strip() for c in class_ids.split(",")] - comparisons = [] - - db = await get_analytics_database() - if db: - for class_id in class_list: - try: - stats = await db.get_class_aggregate_stats(class_id, time_range.value) - comparisons.append(ClassComparisonData( - class_id=class_id, - class_name=stats.get("class_name", f"Klasse {class_id[:8]}"), - student_count=stats.get("student_count", 0), - units_assigned=stats.get("units_assigned", 0), - avg_completion_rate=stats.get("avg_completion_rate", 0.0), - avg_learning_gain=stats.get("avg_learning_gain", 0.0), - avg_time_per_unit=stats.get("avg_time_per_unit", 0.0), - )) - except Exception as e: - logger.error(f"Failed to get stats for class {class_id}: {e}") - - return sorted(comparisons, key=lambda x: x.avg_learning_gain, reverse=True) - - -# ============================================== -# API Endpoints - Export -# ============================================== - -@router.get("/export/learning-gains") -async def export_learning_gains( - unit_id: Optional[str] = Query(None), - class_id: Optional[str] = Query(None), - time_range: TimeRange = Query(TimeRange.ALL), - format: ExportFormat = Query(ExportFormat.JSON), -) -> Any: - """ - Export learning gain data. - """ - from fastapi.responses import Response - - db = await get_analytics_database() - data = [] - - if db: - try: - data = await db.export_learning_gains( - unit_id=unit_id, - class_id=class_id, - time_range=time_range.value - ) - except Exception as e: - logger.error(f"Failed to export data: {e}") - - if format == ExportFormat.CSV: - # Convert to CSV - if not data: - csv_content = "student_id,unit_id,precheck,postcheck,gain\n" - else: - csv_content = "student_id,unit_id,precheck,postcheck,gain\n" - for row in data: - csv_content += f"{row['student_id']},{row['unit_id']},{row.get('precheck', '')},{row.get('postcheck', '')},{row.get('gain', '')}\n" - - return Response( - content=csv_content, - media_type="text/csv", - headers={"Content-Disposition": "attachment; filename=learning_gains.csv"} - ) - - return { - "export_date": datetime.utcnow().isoformat(), - "filters": { - "unit_id": unit_id, - "class_id": class_id, - "time_range": time_range.value, - }, - "data": data, - } - - -@router.get("/export/misconceptions") -async def export_misconceptions( - class_id: Optional[str] = Query(None), - format: ExportFormat = Query(ExportFormat.JSON), -) -> Any: - """ - Export misconception data for further analysis. - """ - report = await get_misconception_report( - class_id=class_id, - unit_id=None, - time_range=TimeRange.MONTH, - limit=100 - ) - - if format == ExportFormat.CSV: - from fastapi.responses import Response - csv_content = "concept_id,concept_label,misconception,frequency,unit_id,stop_id\n" - for m in report.most_common: - csv_content += f'"{m.concept_id}","{m.concept_label}","{m.misconception_text}",{m.frequency},"{m.unit_id}","{m.stop_id}"\n' - - return Response( - content=csv_content, - media_type="text/csv", - headers={"Content-Disposition": "attachment; filename=misconceptions.csv"} - ) - - return { - "export_date": datetime.utcnow().isoformat(), - "class_id": class_id, - "total_entries": len(report.most_common), - "data": [m.model_dump() for m in report.most_common], - } - - -# ============================================== -# API Endpoints - Dashboard Aggregates -# ============================================== - -@router.get("/dashboard/overview") -async def get_analytics_overview( - time_range: TimeRange = Query(TimeRange.MONTH), -) -> Dict[str, Any]: - """ - Get high-level analytics overview for dashboard. - """ - db = await get_analytics_database() - - if db: - try: - overview = await db.get_analytics_overview(time_range.value) - return overview - except Exception as e: - logger.error(f"Failed to get analytics overview: {e}") - - return { - "time_range": time_range.value, - "total_sessions": 0, - "unique_students": 0, - "avg_completion_rate": 0.0, - "avg_learning_gain": 0.0, - "most_played_units": [], - "struggling_concepts": [], - "active_classes": 0, - } - - -@router.get("/health") -async def health_check() -> Dict[str, Any]: - """Health check for analytics API.""" - db = await get_analytics_database() - return { - "status": "healthy", - "service": "unit-analytics", - "database": "connected" if db else "disconnected", - } +router.include_router(_routes_router) +router.include_router(_export_router) diff --git a/backend-lehrer/unit_analytics_export.py b/backend-lehrer/unit_analytics_export.py new file mode 100644 index 0000000..add5382 --- /dev/null +++ b/backend-lehrer/unit_analytics_export.py @@ -0,0 +1,145 @@ +""" +Unit Analytics API - Export & Dashboard Routes. + +Export endpoints for learning gains and misconceptions, plus dashboard overview. +""" + +import logging +from datetime import datetime +from typing import Optional, Dict, Any + +from fastapi import APIRouter, Query +from fastapi.responses import Response + +from unit_analytics_models import TimeRange, ExportFormat +from unit_analytics_helpers import get_analytics_database + +logger = logging.getLogger(__name__) + +router = APIRouter(tags=["Unit Analytics"]) + + +# ============================================== +# API Endpoints - Export +# ============================================== + +@router.get("/export/learning-gains") +async def export_learning_gains( + unit_id: Optional[str] = Query(None), + class_id: Optional[str] = Query(None), + time_range: TimeRange = Query(TimeRange.ALL), + format: ExportFormat = Query(ExportFormat.JSON), +) -> Any: + """ + Export learning gain data. + """ + db = await get_analytics_database() + data = [] + + if db: + try: + data = await db.export_learning_gains( + unit_id=unit_id, class_id=class_id, time_range=time_range.value + ) + except Exception as e: + logger.error(f"Failed to export data: {e}") + + if format == ExportFormat.CSV: + if not data: + csv_content = "student_id,unit_id,precheck,postcheck,gain\n" + else: + csv_content = "student_id,unit_id,precheck,postcheck,gain\n" + for row in data: + csv_content += f"{row['student_id']},{row['unit_id']},{row.get('precheck', '')},{row.get('postcheck', '')},{row.get('gain', '')}\n" + + return Response( + content=csv_content, + media_type="text/csv", + headers={"Content-Disposition": "attachment; filename=learning_gains.csv"} + ) + + return { + "export_date": datetime.utcnow().isoformat(), + "filters": { + "unit_id": unit_id, "class_id": class_id, "time_range": time_range.value, + }, + "data": data, + } + + +@router.get("/export/misconceptions") +async def export_misconceptions( + class_id: Optional[str] = Query(None), + format: ExportFormat = Query(ExportFormat.JSON), +) -> Any: + """ + Export misconception data for further analysis. + """ + # Import here to avoid circular dependency + from unit_analytics_routes import get_misconception_report + + report = await get_misconception_report( + class_id=class_id, unit_id=None, + time_range=TimeRange.MONTH, limit=100 + ) + + if format == ExportFormat.CSV: + csv_content = "concept_id,concept_label,misconception,frequency,unit_id,stop_id\n" + for m in report.most_common: + csv_content += f'"{m.concept_id}","{m.concept_label}","{m.misconception_text}",{m.frequency},"{m.unit_id}","{m.stop_id}"\n' + + return Response( + content=csv_content, + media_type="text/csv", + headers={"Content-Disposition": "attachment; filename=misconceptions.csv"} + ) + + return { + "export_date": datetime.utcnow().isoformat(), + "class_id": class_id, + "total_entries": len(report.most_common), + "data": [m.model_dump() for m in report.most_common], + } + + +# ============================================== +# API Endpoints - Dashboard Aggregates +# ============================================== + +@router.get("/dashboard/overview") +async def get_analytics_overview( + time_range: TimeRange = Query(TimeRange.MONTH), +) -> Dict[str, Any]: + """ + Get high-level analytics overview for dashboard. + """ + db = await get_analytics_database() + + if db: + try: + overview = await db.get_analytics_overview(time_range.value) + return overview + except Exception as e: + logger.error(f"Failed to get analytics overview: {e}") + + return { + "time_range": time_range.value, + "total_sessions": 0, + "unique_students": 0, + "avg_completion_rate": 0.0, + "avg_learning_gain": 0.0, + "most_played_units": [], + "struggling_concepts": [], + "active_classes": 0, + } + + +@router.get("/health") +async def health_check() -> Dict[str, Any]: + """Health check for analytics API.""" + db = await get_analytics_database() + return { + "status": "healthy", + "service": "unit-analytics", + "database": "connected" if db else "disconnected", + } diff --git a/backend-lehrer/unit_analytics_helpers.py b/backend-lehrer/unit_analytics_helpers.py new file mode 100644 index 0000000..91202c1 --- /dev/null +++ b/backend-lehrer/unit_analytics_helpers.py @@ -0,0 +1,97 @@ +""" +Unit Analytics API - Helpers. + +Database access, statistical computation, and utility functions. +""" + +import os +import logging +from typing import List, Dict, Optional + +logger = logging.getLogger(__name__) + +# Feature flags +USE_DATABASE = os.getenv("GAME_USE_DATABASE", "true").lower() == "true" + +# Database singleton +_analytics_db = None + + +async def get_analytics_database(): + """Get analytics database instance.""" + global _analytics_db + if not USE_DATABASE: + return None + if _analytics_db is None: + try: + from unit.database import get_analytics_db + _analytics_db = await get_analytics_db() + logger.info("Analytics database initialized") + except ImportError: + logger.warning("Analytics database module not available") + except Exception as e: + logger.warning(f"Analytics database not available: {e}") + return _analytics_db + + +def calculate_gain_distribution(gains: List[float]) -> Dict[str, int]: + """Calculate distribution of learning gains into buckets.""" + distribution = { + "< -20%": 0, + "-20% to -10%": 0, + "-10% to 0%": 0, + "0% to 10%": 0, + "10% to 20%": 0, + "> 20%": 0, + } + + for gain in gains: + gain_percent = gain * 100 + if gain_percent < -20: + distribution["< -20%"] += 1 + elif gain_percent < -10: + distribution["-20% to -10%"] += 1 + elif gain_percent < 0: + distribution["-10% to 0%"] += 1 + elif gain_percent < 10: + distribution["0% to 10%"] += 1 + elif gain_percent < 20: + distribution["10% to 20%"] += 1 + else: + distribution["> 20%"] += 1 + + return distribution + + +def calculate_trend(scores: List[float]) -> str: + """Calculate trend from a series of scores.""" + if len(scores) < 3: + return "insufficient_data" + + # Simple linear regression + n = len(scores) + x_mean = (n - 1) / 2 + y_mean = sum(scores) / n + + numerator = sum((i - x_mean) * (scores[i] - y_mean) for i in range(n)) + denominator = sum((i - x_mean) ** 2 for i in range(n)) + + if denominator == 0: + return "stable" + + slope = numerator / denominator + + if slope > 0.05: + return "improving" + elif slope < -0.05: + return "declining" + else: + return "stable" + + +def calculate_difficulty_rating(success_rate: float, avg_attempts: float) -> float: + """Calculate difficulty rating 1-5 based on success metrics.""" + # Lower success rate and higher attempts = higher difficulty + base_difficulty = (1 - success_rate) * 3 + 1 # 1-4 range + attempt_modifier = min(avg_attempts - 1, 1) # 0-1 range + return min(5.0, base_difficulty + attempt_modifier) diff --git a/backend-lehrer/unit_analytics_models.py b/backend-lehrer/unit_analytics_models.py new file mode 100644 index 0000000..5a63938 --- /dev/null +++ b/backend-lehrer/unit_analytics_models.py @@ -0,0 +1,127 @@ +""" +Unit Analytics API - Pydantic Models. + +Data models for learning gains, stop performance, misconceptions, +student progress, class comparison, and export. +""" + +from typing import List, Optional, Dict, Any +from datetime import datetime +from enum import Enum + +from pydantic import BaseModel, Field + + +class TimeRange(str, Enum): + """Time range for analytics queries""" + WEEK = "week" + MONTH = "month" + QUARTER = "quarter" + ALL = "all" + + +class LearningGainData(BaseModel): + """Pre/Post learning gain data point""" + student_id: str + student_name: str + unit_id: str + precheck_score: float + postcheck_score: float + learning_gain: float + percentile: Optional[float] = None + + +class LearningGainSummary(BaseModel): + """Aggregated learning gain statistics""" + unit_id: str + unit_title: str + total_students: int + avg_precheck: float + avg_postcheck: float + avg_gain: float + median_gain: float + std_deviation: float + positive_gain_count: int + negative_gain_count: int + no_change_count: int + gain_distribution: Dict[str, int] + individual_gains: List[LearningGainData] + + +class StopPerformance(BaseModel): + """Performance data for a single stop""" + stop_id: str + stop_label: str + attempts_total: int + success_rate: float + avg_time_seconds: float + avg_attempts_before_success: float + common_errors: List[str] + difficulty_rating: float # 1-5 based on performance + + +class UnitPerformanceDetail(BaseModel): + """Detailed unit performance breakdown""" + unit_id: str + unit_title: str + template: str + total_sessions: int + completed_sessions: int + completion_rate: float + avg_duration_minutes: float + stops: List[StopPerformance] + bottleneck_stops: List[str] # Stops where students struggle most + + +class MisconceptionEntry(BaseModel): + """Individual misconception tracking""" + concept_id: str + concept_label: str + misconception_text: str + frequency: int + affected_student_ids: List[str] + unit_id: str + stop_id: str + detected_via: str # "precheck", "postcheck", "interaction" + first_detected: datetime + last_detected: datetime + + +class MisconceptionReport(BaseModel): + """Comprehensive misconception report""" + class_id: Optional[str] + time_range: str + total_misconceptions: int + unique_concepts: int + most_common: List[MisconceptionEntry] + by_unit: Dict[str, List[MisconceptionEntry]] + trending_up: List[MisconceptionEntry] # Getting more frequent + resolved: List[MisconceptionEntry] # No longer appearing + + +class StudentProgressTimeline(BaseModel): + """Timeline of student progress""" + student_id: str + student_name: str + units_completed: int + total_time_minutes: int + avg_score: float + trend: str # "improving", "stable", "declining" + timeline: List[Dict[str, Any]] # List of session events + + +class ClassComparisonData(BaseModel): + """Data for comparing class performance""" + class_id: str + class_name: str + student_count: int + units_assigned: int + avg_completion_rate: float + avg_learning_gain: float + avg_time_per_unit: float + + +class ExportFormat(str, Enum): + """Export format options""" + JSON = "json" + CSV = "csv" diff --git a/backend-lehrer/unit_analytics_routes.py b/backend-lehrer/unit_analytics_routes.py new file mode 100644 index 0000000..8a11e6d --- /dev/null +++ b/backend-lehrer/unit_analytics_routes.py @@ -0,0 +1,394 @@ +""" +Unit Analytics API - Routes. + +All API endpoints for learning gain, stop-level, misconception, +student timeline, class comparison, export, and dashboard analytics. +""" + +import logging +import statistics +from datetime import datetime +from typing import Optional, Dict, Any, List + +from fastapi import APIRouter, Query + +from unit_analytics_models import ( + TimeRange, + LearningGainData, + LearningGainSummary, + StopPerformance, + UnitPerformanceDetail, + MisconceptionEntry, + MisconceptionReport, + StudentProgressTimeline, + ClassComparisonData, +) +from unit_analytics_helpers import ( + get_analytics_database, + calculate_gain_distribution, + calculate_trend, + calculate_difficulty_rating, +) + +logger = logging.getLogger(__name__) + +router = APIRouter(tags=["Unit Analytics"]) + + +# ============================================== +# API Endpoints - Learning Gain +# ============================================== + +# NOTE: Static routes must come BEFORE dynamic routes like /{unit_id} +@router.get("/learning-gain/compare") +async def compare_learning_gains( + unit_ids: str = Query(..., description="Comma-separated unit IDs"), + class_id: Optional[str] = Query(None), + time_range: TimeRange = Query(TimeRange.MONTH), +) -> Dict[str, Any]: + """ + Compare learning gains across multiple units. + """ + unit_list = [u.strip() for u in unit_ids.split(",")] + comparisons = [] + + for unit_id in unit_list: + try: + summary = await get_learning_gain_analysis(unit_id, class_id, time_range) + comparisons.append({ + "unit_id": unit_id, + "avg_gain": summary.avg_gain, + "median_gain": summary.median_gain, + "total_students": summary.total_students, + "positive_rate": summary.positive_gain_count / max(summary.total_students, 1), + }) + except Exception as e: + logger.error(f"Failed to get comparison for {unit_id}: {e}") + + return { + "time_range": time_range.value, + "class_id": class_id, + "comparisons": sorted(comparisons, key=lambda x: x["avg_gain"], reverse=True), + } + + +@router.get("/learning-gain/{unit_id}", response_model=LearningGainSummary) +async def get_learning_gain_analysis( + unit_id: str, + class_id: Optional[str] = Query(None, description="Filter by class"), + time_range: TimeRange = Query(TimeRange.MONTH, description="Time range for analysis"), +) -> LearningGainSummary: + """ + Get detailed pre/post learning gain analysis for a unit. + """ + db = await get_analytics_database() + individual_gains = [] + + if db: + try: + sessions = await db.get_unit_sessions_with_scores( + unit_id=unit_id, + class_id=class_id, + time_range=time_range.value + ) + + for session in sessions: + if session.get("precheck_score") is not None and session.get("postcheck_score") is not None: + gain = session["postcheck_score"] - session["precheck_score"] + individual_gains.append(LearningGainData( + student_id=session["student_id"], + student_name=session.get("student_name", session["student_id"][:8]), + unit_id=unit_id, + precheck_score=session["precheck_score"], + postcheck_score=session["postcheck_score"], + learning_gain=gain, + )) + except Exception as e: + logger.error(f"Failed to get learning gain data: {e}") + + # Calculate statistics + if not individual_gains: + return LearningGainSummary( + unit_id=unit_id, + unit_title=f"Unit {unit_id}", + total_students=0, + avg_precheck=0.0, avg_postcheck=0.0, + avg_gain=0.0, median_gain=0.0, std_deviation=0.0, + positive_gain_count=0, negative_gain_count=0, no_change_count=0, + gain_distribution={}, individual_gains=[], + ) + + gains = [g.learning_gain for g in individual_gains] + prechecks = [g.precheck_score for g in individual_gains] + postchecks = [g.postcheck_score for g in individual_gains] + + avg_gain = statistics.mean(gains) + median_gain = statistics.median(gains) + std_dev = statistics.stdev(gains) if len(gains) > 1 else 0.0 + + # Calculate percentiles + sorted_gains = sorted(gains) + for data in individual_gains: + rank = sorted_gains.index(data.learning_gain) + 1 + data.percentile = rank / len(sorted_gains) * 100 + + return LearningGainSummary( + unit_id=unit_id, + unit_title=f"Unit {unit_id}", + total_students=len(individual_gains), + avg_precheck=statistics.mean(prechecks), + avg_postcheck=statistics.mean(postchecks), + avg_gain=avg_gain, + median_gain=median_gain, + std_deviation=std_dev, + positive_gain_count=sum(1 for g in gains if g > 0.01), + negative_gain_count=sum(1 for g in gains if g < -0.01), + no_change_count=sum(1 for g in gains if -0.01 <= g <= 0.01), + gain_distribution=calculate_gain_distribution(gains), + individual_gains=sorted(individual_gains, key=lambda x: x.learning_gain, reverse=True), + ) + + +# ============================================== +# API Endpoints - Stop-Level Analytics +# ============================================== + +@router.get("/unit/{unit_id}/stops", response_model=UnitPerformanceDetail) +async def get_unit_stop_analytics( + unit_id: str, + class_id: Optional[str] = Query(None), + time_range: TimeRange = Query(TimeRange.MONTH), +) -> UnitPerformanceDetail: + """ + Get detailed stop-level performance analytics. + """ + db = await get_analytics_database() + stops_data = [] + + if db: + try: + stop_stats = await db.get_stop_performance( + unit_id=unit_id, class_id=class_id, time_range=time_range.value + ) + + for stop in stop_stats: + difficulty = calculate_difficulty_rating( + stop.get("success_rate", 0.5), + stop.get("avg_attempts", 1.0) + ) + stops_data.append(StopPerformance( + stop_id=stop["stop_id"], + stop_label=stop.get("stop_label", stop["stop_id"]), + attempts_total=stop.get("total_attempts", 0), + success_rate=stop.get("success_rate", 0.0), + avg_time_seconds=stop.get("avg_time_seconds", 0.0), + avg_attempts_before_success=stop.get("avg_attempts", 1.0), + common_errors=stop.get("common_errors", []), + difficulty_rating=difficulty, + )) + + unit_stats = await db.get_unit_overall_stats(unit_id, class_id, time_range.value) + except Exception as e: + logger.error(f"Failed to get stop analytics: {e}") + unit_stats = {} + else: + unit_stats = {} + + # Identify bottleneck stops + bottlenecks = [ + s.stop_id for s in stops_data + if s.difficulty_rating > 3.5 or s.success_rate < 0.6 + ] + + return UnitPerformanceDetail( + unit_id=unit_id, + unit_title=f"Unit {unit_id}", + template=unit_stats.get("template", "unknown"), + total_sessions=unit_stats.get("total_sessions", 0), + completed_sessions=unit_stats.get("completed_sessions", 0), + completion_rate=unit_stats.get("completion_rate", 0.0), + avg_duration_minutes=unit_stats.get("avg_duration_minutes", 0.0), + stops=stops_data, + bottleneck_stops=bottlenecks, + ) + + +# ============================================== +# API Endpoints - Misconception Tracking +# ============================================== + +@router.get("/misconceptions", response_model=MisconceptionReport) +async def get_misconception_report( + class_id: Optional[str] = Query(None), + unit_id: Optional[str] = Query(None), + time_range: TimeRange = Query(TimeRange.MONTH), + limit: int = Query(20, ge=1, le=100), +) -> MisconceptionReport: + """ + Get comprehensive misconception report. + """ + db = await get_analytics_database() + misconceptions = [] + + if db: + try: + raw_misconceptions = await db.get_misconceptions( + class_id=class_id, unit_id=unit_id, + time_range=time_range.value, limit=limit + ) + + for m in raw_misconceptions: + misconceptions.append(MisconceptionEntry( + concept_id=m["concept_id"], + concept_label=m["concept_label"], + misconception_text=m["misconception_text"], + frequency=m["frequency"], + affected_student_ids=m.get("student_ids", []), + unit_id=m["unit_id"], + stop_id=m["stop_id"], + detected_via=m.get("detected_via", "unknown"), + first_detected=m.get("first_detected", datetime.utcnow()), + last_detected=m.get("last_detected", datetime.utcnow()), + )) + except Exception as e: + logger.error(f"Failed to get misconceptions: {e}") + + # Group by unit + by_unit = {} + for m in misconceptions: + if m.unit_id not in by_unit: + by_unit[m.unit_id] = [] + by_unit[m.unit_id].append(m) + + trending_up = misconceptions[:3] if misconceptions else [] + resolved = [] + + return MisconceptionReport( + class_id=class_id, + time_range=time_range.value, + total_misconceptions=sum(m.frequency for m in misconceptions), + unique_concepts=len(set(m.concept_id for m in misconceptions)), + most_common=sorted(misconceptions, key=lambda x: x.frequency, reverse=True)[:10], + by_unit=by_unit, + trending_up=trending_up, + resolved=resolved, + ) + + +@router.get("/misconceptions/student/{student_id}") +async def get_student_misconceptions( + student_id: str, + time_range: TimeRange = Query(TimeRange.ALL), +) -> Dict[str, Any]: + """ + Get misconceptions for a specific student. + """ + db = await get_analytics_database() + + if db: + try: + misconceptions = await db.get_student_misconceptions( + student_id=student_id, time_range=time_range.value + ) + return { + "student_id": student_id, + "misconceptions": misconceptions, + "recommended_remediation": [ + {"concept": m["concept_label"], "activity": f"Review {m['unit_id']}/{m['stop_id']}"} + for m in misconceptions[:5] + ] + } + except Exception as e: + logger.error(f"Failed to get student misconceptions: {e}") + + return { + "student_id": student_id, + "misconceptions": [], + "recommended_remediation": [], + } + + +# ============================================== +# API Endpoints - Student Progress Timeline +# ============================================== + +@router.get("/student/{student_id}/timeline", response_model=StudentProgressTimeline) +async def get_student_timeline( + student_id: str, + time_range: TimeRange = Query(TimeRange.ALL), +) -> StudentProgressTimeline: + """ + Get detailed progress timeline for a student. + """ + db = await get_analytics_database() + timeline = [] + scores = [] + + if db: + try: + sessions = await db.get_student_sessions( + student_id=student_id, time_range=time_range.value + ) + + for session in sessions: + timeline.append({ + "date": session.get("started_at"), + "unit_id": session.get("unit_id"), + "completed": session.get("completed_at") is not None, + "precheck": session.get("precheck_score"), + "postcheck": session.get("postcheck_score"), + "duration_minutes": session.get("duration_seconds", 0) // 60, + }) + if session.get("postcheck_score") is not None: + scores.append(session["postcheck_score"]) + except Exception as e: + logger.error(f"Failed to get student timeline: {e}") + + trend = calculate_trend(scores) if scores else "insufficient_data" + + return StudentProgressTimeline( + student_id=student_id, + student_name=f"Student {student_id[:8]}", + units_completed=sum(1 for t in timeline if t["completed"]), + total_time_minutes=sum(t["duration_minutes"] for t in timeline), + avg_score=statistics.mean(scores) if scores else 0.0, + trend=trend, + timeline=timeline, + ) + + +# ============================================== +# API Endpoints - Class Comparison +# ============================================== + +@router.get("/compare/classes", response_model=List[ClassComparisonData]) +async def compare_classes( + class_ids: str = Query(..., description="Comma-separated class IDs"), + time_range: TimeRange = Query(TimeRange.MONTH), +) -> List[ClassComparisonData]: + """ + Compare performance across multiple classes. + """ + class_list = [c.strip() for c in class_ids.split(",")] + comparisons = [] + + db = await get_analytics_database() + if db: + for class_id in class_list: + try: + stats = await db.get_class_aggregate_stats(class_id, time_range.value) + comparisons.append(ClassComparisonData( + class_id=class_id, + class_name=stats.get("class_name", f"Klasse {class_id[:8]}"), + student_count=stats.get("student_count", 0), + units_assigned=stats.get("units_assigned", 0), + avg_completion_rate=stats.get("avg_completion_rate", 0.0), + avg_learning_gain=stats.get("avg_learning_gain", 0.0), + avg_time_per_unit=stats.get("avg_time_per_unit", 0.0), + )) + except Exception as e: + logger.error(f"Failed to get stats for class {class_id}: {e}") + + return sorted(comparisons, key=lambda x: x.avg_learning_gain, reverse=True) + + diff --git a/klausur-service/backend/compliance_extraction.py b/klausur-service/backend/compliance_extraction.py new file mode 100644 index 0000000..d73cba6 --- /dev/null +++ b/klausur-service/backend/compliance_extraction.py @@ -0,0 +1,200 @@ +""" +Compliance Extraction & Generation. + +Functions for extracting checkpoints from legal text chunks, +generating controls, and creating remediation measures. +""" + +import re +import hashlib +import logging +from typing import Dict, List, Optional + +from compliance_models import Checkpoint, Control, Measure + +logger = logging.getLogger(__name__) + + +def extract_checkpoints_from_chunk(chunk_text: str, payload: Dict) -> List[Checkpoint]: + """ + Extract checkpoints/requirements from a chunk of text. + + Uses pattern matching to find requirement-like statements. + """ + checkpoints = [] + regulation_code = payload.get("regulation_code", "UNKNOWN") + regulation_name = payload.get("regulation_name", "Unknown") + source_url = payload.get("source_url", "") + chunk_id = hashlib.md5(chunk_text[:100].encode()).hexdigest()[:8] + + # Patterns for different requirement types + patterns = [ + # BSI-TR patterns + (r'([OT]\.[A-Za-z_]+\d*)[:\s]+(.+?)(?=\n[OT]\.|$)', 'bsi_requirement'), + # Article patterns (GDPR, AI Act, etc.) + (r'(?:Artikel|Art\.?)\s+(\d+)(?:\s+Abs(?:atz)?\.?\s*(\d+))?\s*[-\u2013:]\s*(.+?)(?=\n|$)', 'article'), + # Numbered requirements + (r'\((\d+)\)\s+(.+?)(?=\n\(\d+\)|$)', 'numbered'), + # "Der Verantwortliche muss" patterns + (r'(?:Der Verantwortliche|Die Aufsichtsbeh\u00f6rde|Der Auftragsverarbeiter)\s+(muss|hat|soll)\s+(.+?)(?=\.\s|$)', 'obligation'), + # "Es ist erforderlich" patterns + (r'(?:Es ist erforderlich|Es muss gew\u00e4hrleistet|Es sind geeignete)\s+(.+?)(?=\.\s|$)', 'requirement'), + ] + + for pattern, pattern_type in patterns: + matches = re.finditer(pattern, chunk_text, re.MULTILINE | re.DOTALL) + for match in matches: + if pattern_type == 'bsi_requirement': + req_id = match.group(1) + description = match.group(2).strip() + title = req_id + elif pattern_type == 'article': + article_num = match.group(1) + paragraph = match.group(2) or "" + title_text = match.group(3).strip() + req_id = f"{regulation_code}-Art{article_num}" + if paragraph: + req_id += f"-{paragraph}" + title = f"Art. {article_num}" + (f" Abs. {paragraph}" if paragraph else "") + description = title_text + elif pattern_type == 'numbered': + num = match.group(1) + description = match.group(2).strip() + req_id = f"{regulation_code}-{num}" + title = f"Anforderung {num}" + else: + # Generic requirement + description = match.group(0).strip() + req_id = f"{regulation_code}-{chunk_id}-{len(checkpoints)}" + title = description[:50] + "..." if len(description) > 50 else description + + # Skip very short matches + if len(description) < 20: + continue + + checkpoint = Checkpoint( + id=req_id, + regulation_code=regulation_code, + regulation_name=regulation_name, + article=title if 'Art' in title else None, + title=title, + description=description[:500], + original_text=description, + chunk_id=chunk_id, + source_url=source_url + ) + checkpoints.append(checkpoint) + + return checkpoints + + +def generate_control_for_checkpoints( + checkpoints: List[Checkpoint], + domain_counts: Dict[str, int], +) -> Optional[Control]: + """ + Generate a control that covers the given checkpoints. + + This is a simplified version - in production this would use the AI assistant. + """ + if not checkpoints: + return None + + # Group by regulation + regulation = checkpoints[0].regulation_code + + # Determine domain based on content + all_text = " ".join([cp.description for cp in checkpoints]).lower() + + domain = "gov" # Default + if any(kw in all_text for kw in ["verschl\u00fcssel", "krypto", "encrypt", "hash"]): + domain = "crypto" + elif any(kw in all_text for kw in ["zugang", "access", "authentif", "login", "benutzer"]): + domain = "iam" + elif any(kw in all_text for kw in ["datenschutz", "personenbezogen", "privacy", "einwilligung"]): + domain = "priv" + elif any(kw in all_text for kw in ["entwicklung", "test", "code", "software"]): + domain = "sdlc" + elif any(kw in all_text for kw in ["\u00fcberwach", "monitor", "log", "audit"]): + domain = "aud" + elif any(kw in all_text for kw in ["ki", "k\u00fcnstlich", "ai", "machine learning", "model"]): + domain = "ai" + elif any(kw in all_text for kw in ["betrieb", "operation", "verf\u00fcgbar", "backup"]): + domain = "ops" + elif any(kw in all_text for kw in ["cyber", "resilience", "sbom", "vulnerab"]): + domain = "cra" + + # Generate control ID + domain_count = domain_counts.get(domain, 0) + 1 + control_id = f"{domain.upper()}-{domain_count:03d}" + + # Create title from first checkpoint + title = checkpoints[0].title + if len(title) > 100: + title = title[:97] + "..." + + # Create description + description = f"Control f\u00fcr {regulation}: " + checkpoints[0].description[:200] + + # Pass criteria + pass_criteria = f"Alle {len(checkpoints)} zugeh\u00f6rigen Anforderungen sind erf\u00fcllt und dokumentiert." + + # Implementation guidance + guidance = f"Implementiere Ma\u00dfnahmen zur Erf\u00fcllung der Anforderungen aus {regulation}. " + guidance += f"Dokumentiere die Umsetzung und f\u00fchre regelm\u00e4\u00dfige Reviews durch." + + # Determine if automated + is_automated = any(kw in all_text for kw in ["automat", "tool", "scan", "test"]) + + control = Control( + id=control_id, + domain=domain, + title=title, + description=description, + checkpoints=[cp.id for cp in checkpoints], + pass_criteria=pass_criteria, + implementation_guidance=guidance, + is_automated=is_automated, + automation_tool="CI/CD Pipeline" if is_automated else None, + priority="high" if "muss" in all_text or "erforderlich" in all_text else "medium" + ) + + return control + + +def generate_measure_for_control(control: Control) -> Measure: + """Generate a remediation measure for a control.""" + measure_id = f"M-{control.id}" + + # Determine deadline based on priority + deadline_days = { + "critical": 30, + "high": 60, + "medium": 90, + "low": 180 + }.get(control.priority, 90) + + # Determine responsible team + responsible = { + "priv": "Datenschutzbeauftragter", + "iam": "IT-Security Team", + "sdlc": "Entwicklungsteam", + "crypto": "IT-Security Team", + "ops": "Operations Team", + "aud": "Compliance Team", + "ai": "AI/ML Team", + "cra": "IT-Security Team", + "gov": "Management" + }.get(control.domain, "Compliance Team") + + measure = Measure( + id=measure_id, + control_id=control.id, + title=f"Umsetzung: {control.title[:50]}", + description=f"Implementierung und Dokumentation von {control.id}: {control.description[:100]}", + responsible=responsible, + deadline_days=deadline_days, + status="pending" + ) + + return measure diff --git a/klausur-service/backend/compliance_models.py b/klausur-service/backend/compliance_models.py new file mode 100644 index 0000000..4161d72 --- /dev/null +++ b/klausur-service/backend/compliance_models.py @@ -0,0 +1,49 @@ +""" +Compliance Pipeline Data Models. + +Dataclasses for checkpoints, controls, and measures. +""" + +from typing import Optional, List +from dataclasses import dataclass + + +@dataclass +class Checkpoint: + """A requirement/checkpoint extracted from legal text.""" + id: str + regulation_code: str + regulation_name: str + article: Optional[str] + title: str + description: str + original_text: str + chunk_id: str + source_url: str + + +@dataclass +class Control: + """A control derived from checkpoints.""" + id: str + domain: str + title: str + description: str + checkpoints: List[str] # List of checkpoint IDs + pass_criteria: str + implementation_guidance: str + is_automated: bool + automation_tool: Optional[str] + priority: str + + +@dataclass +class Measure: + """A remediation measure for a control.""" + id: str + control_id: str + title: str + description: str + responsible: str + deadline_days: int + status: str diff --git a/klausur-service/backend/compliance_pipeline.py b/klausur-service/backend/compliance_pipeline.py new file mode 100644 index 0000000..8598d06 --- /dev/null +++ b/klausur-service/backend/compliance_pipeline.py @@ -0,0 +1,441 @@ +""" +Compliance Pipeline Execution. + +Pipeline phases (ingestion, extraction, control generation, measures) +and orchestration logic. +""" + +import asyncio +import json +import logging +import os +import sys +import time +from datetime import datetime +from typing import Dict, List, Any +from dataclasses import asdict + +from compliance_models import Checkpoint, Control, Measure +from compliance_extraction import ( + extract_checkpoints_from_chunk, + generate_control_for_checkpoints, + generate_measure_for_control, +) + +logger = logging.getLogger(__name__) + +# Import checkpoint manager +try: + from pipeline_checkpoints import CheckpointManager, EXPECTED_VALUES, ValidationStatus +except ImportError: + logger.warning("Checkpoint manager not available, running without checkpoints") + CheckpointManager = None + EXPECTED_VALUES = {} + ValidationStatus = None + +# Set environment variables for Docker network +if not os.getenv("QDRANT_URL") and not os.getenv("QDRANT_HOST"): + os.environ["QDRANT_HOST"] = "qdrant" +os.environ.setdefault("EMBEDDING_SERVICE_URL", "http://embedding-service:8087") + +# Try to import from klausur-service +try: + from legal_corpus_ingestion import LegalCorpusIngestion, REGULATIONS, LEGAL_CORPUS_COLLECTION + from qdrant_client import QdrantClient + from qdrant_client.models import Filter, FieldCondition, MatchValue +except ImportError: + logger.error("Could not import required modules. Make sure you're in the klausur-service container.") + sys.exit(1) + + +class CompliancePipeline: + """Handles the full compliance pipeline.""" + + def __init__(self): + # Support both QDRANT_URL and QDRANT_HOST/PORT + qdrant_url = os.getenv("QDRANT_URL", "") + if qdrant_url: + from urllib.parse import urlparse + parsed = urlparse(qdrant_url) + qdrant_host = parsed.hostname or "qdrant" + qdrant_port = parsed.port or 6333 + else: + qdrant_host = os.getenv("QDRANT_HOST", "qdrant") + qdrant_port = 6333 + self.qdrant = QdrantClient(host=qdrant_host, port=qdrant_port) + self.checkpoints: List[Checkpoint] = [] + self.controls: List[Control] = [] + self.measures: List[Measure] = [] + self.stats = { + "chunks_processed": 0, + "checkpoints_extracted": 0, + "controls_created": 0, + "measures_defined": 0, + "by_regulation": {}, + "by_domain": {}, + } + # Initialize checkpoint manager + self.checkpoint_mgr = CheckpointManager() if CheckpointManager else None + + async def run_ingestion_phase(self, force_reindex: bool = False) -> int: + """Phase 1: Ingest documents (incremental - only missing ones).""" + logger.info("\n" + "=" * 60) + logger.info("PHASE 1: DOCUMENT INGESTION (INCREMENTAL)") + logger.info("=" * 60) + + if self.checkpoint_mgr: + self.checkpoint_mgr.start_checkpoint("ingestion", "Document Ingestion") + + ingestion = LegalCorpusIngestion() + + try: + # Check existing chunks per regulation + existing_chunks = {} + try: + for regulation in REGULATIONS: + count_result = self.qdrant.count( + collection_name=LEGAL_CORPUS_COLLECTION, + count_filter=Filter( + must=[FieldCondition(key="regulation_code", match=MatchValue(value=regulation.code))] + ) + ) + existing_chunks[regulation.code] = count_result.count + logger.info(f" {regulation.code}: {count_result.count} existing chunks") + except Exception as e: + logger.warning(f"Could not check existing chunks: {e}") + + # Determine which regulations need ingestion + regulations_to_ingest = [] + for regulation in REGULATIONS: + existing = existing_chunks.get(regulation.code, 0) + if force_reindex or existing == 0: + regulations_to_ingest.append(regulation) + logger.info(f" -> Will ingest: {regulation.code} (existing: {existing}, force: {force_reindex})") + else: + logger.info(f" -> Skipping: {regulation.code} (already has {existing} chunks)") + self.stats["by_regulation"][regulation.code] = existing + + if not regulations_to_ingest: + logger.info("All regulations already indexed. Skipping ingestion phase.") + total_chunks = sum(existing_chunks.values()) + self.stats["chunks_processed"] = total_chunks + if self.checkpoint_mgr: + self.checkpoint_mgr.add_metric("total_chunks", total_chunks) + self.checkpoint_mgr.add_metric("skipped", True) + self.checkpoint_mgr.complete_checkpoint(success=True) + return total_chunks + + # Ingest only missing regulations + total_chunks = sum(existing_chunks.values()) + for i, regulation in enumerate(regulations_to_ingest, 1): + logger.info(f"[{i}/{len(regulations_to_ingest)}] Ingesting {regulation.code}...") + try: + count = await ingestion.ingest_regulation(regulation) + total_chunks += count + self.stats["by_regulation"][regulation.code] = count + logger.info(f" -> {count} chunks") + + if self.checkpoint_mgr: + self.checkpoint_mgr.add_metric(f"chunks_{regulation.code}", count) + + except Exception as e: + logger.error(f" -> FAILED: {e}") + self.stats["by_regulation"][regulation.code] = 0 + + self.stats["chunks_processed"] = total_chunks + logger.info(f"\nTotal chunks in collection: {total_chunks}") + + # Validate ingestion results + if self.checkpoint_mgr: + self.checkpoint_mgr.add_metric("total_chunks", total_chunks) + self.checkpoint_mgr.add_metric("regulations_count", len(REGULATIONS)) + + expected = EXPECTED_VALUES.get("ingestion", {}) + self.checkpoint_mgr.validate( + "total_chunks", + expected=expected.get("total_chunks", 8000), + actual=total_chunks, + min_value=expected.get("min_chunks", 7000) + ) + + reg_expected = expected.get("regulations", {}) + for reg_code, reg_exp in reg_expected.items(): + actual = self.stats["by_regulation"].get(reg_code, 0) + self.checkpoint_mgr.validate( + f"chunks_{reg_code}", + expected=reg_exp.get("expected", 0), + actual=actual, + min_value=reg_exp.get("min", 0) + ) + + self.checkpoint_mgr.complete_checkpoint(success=True) + + return total_chunks + + except Exception as e: + if self.checkpoint_mgr: + self.checkpoint_mgr.fail_checkpoint(str(e)) + raise + + finally: + await ingestion.close() + + async def run_extraction_phase(self) -> int: + """Phase 2: Extract checkpoints from chunks.""" + logger.info("\n" + "=" * 60) + logger.info("PHASE 2: CHECKPOINT EXTRACTION") + logger.info("=" * 60) + + if self.checkpoint_mgr: + self.checkpoint_mgr.start_checkpoint("extraction", "Checkpoint Extraction") + + try: + offset = None + total_checkpoints = 0 + + while True: + result = self.qdrant.scroll( + collection_name=LEGAL_CORPUS_COLLECTION, + limit=100, + offset=offset, + with_payload=True, + with_vectors=False + ) + + points, next_offset = result + + if not points: + break + + for point in points: + payload = point.payload + text = payload.get("text", "") + + cps = extract_checkpoints_from_chunk(text, payload) + self.checkpoints.extend(cps) + total_checkpoints += len(cps) + + logger.info(f"Processed {len(points)} chunks, extracted {total_checkpoints} checkpoints so far...") + + if next_offset is None: + break + offset = next_offset + + self.stats["checkpoints_extracted"] = len(self.checkpoints) + logger.info(f"\nTotal checkpoints extracted: {len(self.checkpoints)}") + + by_reg = {} + for cp in self.checkpoints: + by_reg[cp.regulation_code] = by_reg.get(cp.regulation_code, 0) + 1 + for reg, count in sorted(by_reg.items()): + logger.info(f" {reg}: {count} checkpoints") + + if self.checkpoint_mgr: + self.checkpoint_mgr.add_metric("total_checkpoints", len(self.checkpoints)) + self.checkpoint_mgr.add_metric("checkpoints_by_regulation", by_reg) + + expected = EXPECTED_VALUES.get("extraction", {}) + self.checkpoint_mgr.validate( + "total_checkpoints", + expected=expected.get("total_checkpoints", 3500), + actual=len(self.checkpoints), + min_value=expected.get("min_checkpoints", 3000) + ) + + self.checkpoint_mgr.complete_checkpoint(success=True) + + return len(self.checkpoints) + + except Exception as e: + if self.checkpoint_mgr: + self.checkpoint_mgr.fail_checkpoint(str(e)) + raise + + async def run_control_generation_phase(self) -> int: + """Phase 3: Generate controls from checkpoints.""" + logger.info("\n" + "=" * 60) + logger.info("PHASE 3: CONTROL GENERATION") + logger.info("=" * 60) + + if self.checkpoint_mgr: + self.checkpoint_mgr.start_checkpoint("controls", "Control Generation") + + try: + # Group checkpoints by regulation + by_regulation: Dict[str, List[Checkpoint]] = {} + for cp in self.checkpoints: + reg = cp.regulation_code + if reg not in by_regulation: + by_regulation[reg] = [] + by_regulation[reg].append(cp) + + # Generate controls per regulation (group every 3-5 checkpoints) + for regulation, checkpoints in by_regulation.items(): + logger.info(f"Generating controls for {regulation} ({len(checkpoints)} checkpoints)...") + + batch_size = 4 + for i in range(0, len(checkpoints), batch_size): + batch = checkpoints[i:i + batch_size] + control = generate_control_for_checkpoints(batch, self.stats.get("by_domain", {})) + + if control: + self.controls.append(control) + self.stats["by_domain"][control.domain] = self.stats["by_domain"].get(control.domain, 0) + 1 + + self.stats["controls_created"] = len(self.controls) + logger.info(f"\nTotal controls created: {len(self.controls)}") + + for domain, count in sorted(self.stats["by_domain"].items()): + logger.info(f" {domain}: {count} controls") + + if self.checkpoint_mgr: + self.checkpoint_mgr.add_metric("total_controls", len(self.controls)) + self.checkpoint_mgr.add_metric("controls_by_domain", dict(self.stats["by_domain"])) + + expected = EXPECTED_VALUES.get("controls", {}) + self.checkpoint_mgr.validate( + "total_controls", + expected=expected.get("total_controls", 900), + actual=len(self.controls), + min_value=expected.get("min_controls", 800) + ) + + self.checkpoint_mgr.complete_checkpoint(success=True) + + return len(self.controls) + + except Exception as e: + if self.checkpoint_mgr: + self.checkpoint_mgr.fail_checkpoint(str(e)) + raise + + async def run_measure_generation_phase(self) -> int: + """Phase 4: Generate measures for controls.""" + logger.info("\n" + "=" * 60) + logger.info("PHASE 4: MEASURE GENERATION") + logger.info("=" * 60) + + if self.checkpoint_mgr: + self.checkpoint_mgr.start_checkpoint("measures", "Measure Generation") + + try: + for control in self.controls: + measure = generate_measure_for_control(control) + self.measures.append(measure) + + self.stats["measures_defined"] = len(self.measures) + logger.info(f"\nTotal measures defined: {len(self.measures)}") + + if self.checkpoint_mgr: + self.checkpoint_mgr.add_metric("total_measures", len(self.measures)) + + expected = EXPECTED_VALUES.get("measures", {}) + self.checkpoint_mgr.validate( + "total_measures", + expected=expected.get("total_measures", 900), + actual=len(self.measures), + min_value=expected.get("min_measures", 800) + ) + + self.checkpoint_mgr.complete_checkpoint(success=True) + + return len(self.measures) + + except Exception as e: + if self.checkpoint_mgr: + self.checkpoint_mgr.fail_checkpoint(str(e)) + raise + + def save_results(self, output_dir: str = "/tmp/compliance_output"): + """Save results to JSON files.""" + logger.info("\n" + "=" * 60) + logger.info("SAVING RESULTS") + logger.info("=" * 60) + + os.makedirs(output_dir, exist_ok=True) + + checkpoints_file = os.path.join(output_dir, "checkpoints.json") + with open(checkpoints_file, "w") as f: + json.dump([asdict(cp) for cp in self.checkpoints], f, indent=2, ensure_ascii=False) + logger.info(f"Saved {len(self.checkpoints)} checkpoints to {checkpoints_file}") + + controls_file = os.path.join(output_dir, "controls.json") + with open(controls_file, "w") as f: + json.dump([asdict(c) for c in self.controls], f, indent=2, ensure_ascii=False) + logger.info(f"Saved {len(self.controls)} controls to {controls_file}") + + measures_file = os.path.join(output_dir, "measures.json") + with open(measures_file, "w") as f: + json.dump([asdict(m) for m in self.measures], f, indent=2, ensure_ascii=False) + logger.info(f"Saved {len(self.measures)} measures to {measures_file}") + + stats_file = os.path.join(output_dir, "statistics.json") + self.stats["generated_at"] = datetime.now().isoformat() + with open(stats_file, "w") as f: + json.dump(self.stats, f, indent=2, ensure_ascii=False) + logger.info(f"Saved statistics to {stats_file}") + + async def run_full_pipeline(self, force_reindex: bool = False, skip_ingestion: bool = False): + """Run the complete pipeline. + + Args: + force_reindex: If True, re-ingest all documents even if they exist + skip_ingestion: If True, skip ingestion phase entirely (use existing chunks) + """ + start_time = time.time() + + logger.info("=" * 60) + logger.info("FULL COMPLIANCE PIPELINE (INCREMENTAL)") + logger.info(f"Started at: {datetime.now().isoformat()}") + logger.info(f"Force reindex: {force_reindex}") + logger.info(f"Skip ingestion: {skip_ingestion}") + if self.checkpoint_mgr: + logger.info(f"Pipeline ID: {self.checkpoint_mgr.pipeline_id}") + logger.info("=" * 60) + + try: + if skip_ingestion: + logger.info("Skipping ingestion phase as requested...") + try: + collection_info = self.qdrant.get_collection(LEGAL_CORPUS_COLLECTION) + self.stats["chunks_processed"] = collection_info.points_count + except Exception: + self.stats["chunks_processed"] = 0 + else: + await self.run_ingestion_phase(force_reindex=force_reindex) + + await self.run_extraction_phase() + await self.run_control_generation_phase() + await self.run_measure_generation_phase() + self.save_results() + + elapsed = time.time() - start_time + logger.info("\n" + "=" * 60) + logger.info("PIPELINE COMPLETE") + logger.info("=" * 60) + logger.info(f"Duration: {elapsed:.1f} seconds") + logger.info(f"Chunks processed: {self.stats['chunks_processed']}") + logger.info(f"Checkpoints extracted: {self.stats['checkpoints_extracted']}") + logger.info(f"Controls created: {self.stats['controls_created']}") + logger.info(f"Measures defined: {self.stats['measures_defined']}") + logger.info(f"\nResults saved to: /tmp/compliance_output/") + logger.info("Checkpoint status: /tmp/pipeline_checkpoints.json") + logger.info("=" * 60) + + if self.checkpoint_mgr: + self.checkpoint_mgr.complete_pipeline({ + "duration_seconds": elapsed, + "chunks_processed": self.stats['chunks_processed'], + "checkpoints_extracted": self.stats['checkpoints_extracted'], + "controls_created": self.stats['controls_created'], + "measures_defined": self.stats['measures_defined'], + "by_regulation": self.stats['by_regulation'], + "by_domain": self.stats['by_domain'], + }) + + except Exception as e: + logger.error(f"Pipeline failed: {e}") + if self.checkpoint_mgr: + self.checkpoint_mgr.state.status = "failed" + self.checkpoint_mgr._save() + raise diff --git a/klausur-service/backend/dsfa_rag_api.py b/klausur-service/backend/dsfa_rag_api.py index 554634f..0e6b63c 100644 --- a/klausur-service/backend/dsfa_rag_api.py +++ b/klausur-service/backend/dsfa_rag_api.py @@ -1,7 +1,10 @@ """ -DSFA RAG API Endpoints. +DSFA RAG API Endpoints — Barrel Re-export. -Provides REST API for searching DSFA corpus with full source attribution. +Split into submodules: +- dsfa_rag_models.py — Pydantic request/response models +- dsfa_rag_embedding.py — Embedding service integration & text extraction +- dsfa_rag_routes.py — Route handlers (search, sources, ingest, stats) Endpoints: - GET /api/v1/dsfa-rag/search - Semantic search with attribution @@ -11,705 +14,54 @@ Endpoints: - GET /api/v1/dsfa-rag/stats - Get corpus statistics """ -import os -import uuid -import logging -from typing import List, Optional -from dataclasses import dataclass, asdict - -import httpx -from fastapi import APIRouter, HTTPException, Query, Depends -from pydantic import BaseModel, Field - -logger = logging.getLogger(__name__) - -# Embedding service configuration -EMBEDDING_SERVICE_URL = os.getenv("EMBEDDING_SERVICE_URL", "http://172.18.0.13:8087") - -# Import from ingestion module -from dsfa_corpus_ingestion import ( - DSFACorpusStore, - DSFAQdrantService, - DSFASearchResult, - LICENSE_REGISTRY, - DSFA_SOURCES, - generate_attribution_notice, - get_license_label, - DSFA_COLLECTION, - chunk_document +# Models +from dsfa_rag_models import ( + DSFASourceResponse, + DSFAChunkResponse, + DSFASearchResultResponse, + DSFASearchResponse, + DSFASourceStatsResponse, + DSFACorpusStatsResponse, + IngestRequest, + IngestResponse, + LicenseInfo, ) -router = APIRouter(prefix="/api/v1/dsfa-rag", tags=["DSFA RAG"]) - - -# ============================================================================= -# Pydantic Models -# ============================================================================= - -class DSFASourceResponse(BaseModel): - """Response model for DSFA source.""" - id: str - source_code: str - name: str - full_name: Optional[str] = None - organization: Optional[str] = None - source_url: Optional[str] = None - license_code: str - license_name: str - license_url: Optional[str] = None - attribution_required: bool - attribution_text: str - document_type: Optional[str] = None - language: str = "de" - - -class DSFAChunkResponse(BaseModel): - """Response model for a single chunk with attribution.""" - chunk_id: str - content: str - section_title: Optional[str] = None - page_number: Optional[int] = None - category: Optional[str] = None - - # Document info - document_id: str - document_title: Optional[str] = None - - # Attribution (always included) - source_id: str - source_code: str - source_name: str - attribution_text: str - license_code: str - license_name: str - license_url: Optional[str] = None - attribution_required: bool - source_url: Optional[str] = None - document_type: Optional[str] = None - - -class DSFASearchResultResponse(BaseModel): - """Response model for search result.""" - chunk_id: str - content: str - score: float - - # Attribution - source_code: str - source_name: str - attribution_text: str - license_code: str - license_name: str - license_url: Optional[str] = None - attribution_required: bool - source_url: Optional[str] = None - - # Metadata - document_type: Optional[str] = None - category: Optional[str] = None - section_title: Optional[str] = None - page_number: Optional[int] = None - - -class DSFASearchResponse(BaseModel): - """Response model for search endpoint.""" - query: str - results: List[DSFASearchResultResponse] - total_results: int - - # Aggregated licenses for footer - licenses_used: List[str] - attribution_notice: str - - -class DSFASourceStatsResponse(BaseModel): - """Response model for source statistics.""" - source_id: str - source_code: str - name: str - organization: Optional[str] = None - license_code: str - document_type: Optional[str] = None - document_count: int - chunk_count: int - last_indexed_at: Optional[str] = None - - -class DSFACorpusStatsResponse(BaseModel): - """Response model for corpus statistics.""" - sources: List[DSFASourceStatsResponse] - total_sources: int - total_documents: int - total_chunks: int - qdrant_collection: str - qdrant_points_count: int - qdrant_status: str - - -class IngestRequest(BaseModel): - """Request model for ingestion.""" - document_url: Optional[str] = None - document_text: Optional[str] = None - title: Optional[str] = None - - -class IngestResponse(BaseModel): - """Response model for ingestion.""" - source_code: str - document_id: Optional[str] = None - chunks_created: int - message: str - - -class LicenseInfo(BaseModel): - """License information.""" - code: str - name: str - url: Optional[str] = None - attribution_required: bool - modification_allowed: bool - commercial_use: bool - - -# ============================================================================= -# Dependency Injection -# ============================================================================= - -# Database pool (will be set from main.py) -_db_pool = None - - -def set_db_pool(pool): - """Set the database pool for API endpoints.""" - global _db_pool - _db_pool = pool - - -async def get_store() -> DSFACorpusStore: - """Get DSFA corpus store.""" - if _db_pool is None: - raise HTTPException(status_code=503, detail="Database not initialized") - return DSFACorpusStore(_db_pool) - - -async def get_qdrant() -> DSFAQdrantService: - """Get Qdrant service.""" - return DSFAQdrantService() - - -# ============================================================================= -# Embedding Service Integration -# ============================================================================= - -async def get_embedding(text: str) -> List[float]: - """ - Get embedding for text using the embedding-service. - - Uses BGE-M3 model which produces 1024-dimensional vectors. - """ - async with httpx.AsyncClient(timeout=60.0) as client: - try: - response = await client.post( - f"{EMBEDDING_SERVICE_URL}/embed-single", - json={"text": text} - ) - response.raise_for_status() - data = response.json() - return data.get("embedding", []) - except httpx.HTTPError as e: - logger.error(f"Embedding service error: {e}") - # Fallback to hash-based pseudo-embedding for development - return _generate_fallback_embedding(text) - - -async def get_embeddings_batch(texts: List[str]) -> List[List[float]]: - """ - Get embeddings for multiple texts in batch. - """ - async with httpx.AsyncClient(timeout=120.0) as client: - try: - response = await client.post( - f"{EMBEDDING_SERVICE_URL}/embed", - json={"texts": texts} - ) - response.raise_for_status() - data = response.json() - return data.get("embeddings", []) - except httpx.HTTPError as e: - logger.error(f"Embedding service batch error: {e}") - # Fallback - return [_generate_fallback_embedding(t) for t in texts] - - -async def extract_text_from_url(url: str) -> str: - """ - Extract text from a document URL (PDF, HTML, etc.). - """ - async with httpx.AsyncClient(timeout=120.0) as client: - try: - # First try to use the embedding-service's extract-pdf endpoint - response = await client.post( - f"{EMBEDDING_SERVICE_URL}/extract-pdf", - json={"url": url} - ) - response.raise_for_status() - data = response.json() - return data.get("text", "") - except httpx.HTTPError as e: - logger.error(f"PDF extraction error for {url}: {e}") - # Fallback: try to fetch HTML content directly - try: - response = await client.get(url, follow_redirects=True) - response.raise_for_status() - content_type = response.headers.get("content-type", "") - if "html" in content_type: - # Simple HTML text extraction - import re - html = response.text - # Remove scripts and styles - html = re.sub(r']*>.*?', '', html, flags=re.DOTALL | re.IGNORECASE) - html = re.sub(r']*>.*?', '', html, flags=re.DOTALL | re.IGNORECASE) - # Remove tags - text = re.sub(r'<[^>]+>', ' ', html) - # Clean whitespace - text = re.sub(r'\s+', ' ', text).strip() - return text - else: - return "" - except Exception as fetch_err: - logger.error(f"Fallback fetch error for {url}: {fetch_err}") - return "" - - -def _generate_fallback_embedding(text: str) -> List[float]: - """ - Generate deterministic pseudo-embedding from text hash. - Used as fallback when embedding service is unavailable. - """ - import hashlib - import struct - - hash_bytes = hashlib.sha256(text.encode()).digest() - embedding = [] - for i in range(0, min(len(hash_bytes), 128), 4): - val = struct.unpack('f', hash_bytes[i:i+4])[0] - embedding.append(val % 1.0) - - # Pad to 1024 dimensions - while len(embedding) < 1024: - embedding.extend(embedding[:min(len(embedding), 1024 - len(embedding))]) - - return embedding[:1024] - - -# ============================================================================= -# API Endpoints -# ============================================================================= - -@router.get("/search", response_model=DSFASearchResponse) -async def search_dsfa_corpus( - query: str = Query(..., min_length=3, description="Search query"), - source_codes: Optional[List[str]] = Query(None, description="Filter by source codes"), - document_types: Optional[List[str]] = Query(None, description="Filter by document types (guideline, checklist, regulation)"), - categories: Optional[List[str]] = Query(None, description="Filter by categories (threshold_analysis, risk_assessment, mitigation)"), - limit: int = Query(10, ge=1, le=50, description="Maximum results"), - include_attribution: bool = Query(True, description="Include attribution in results"), - store: DSFACorpusStore = Depends(get_store), - qdrant: DSFAQdrantService = Depends(get_qdrant) -): - """ - Search DSFA corpus with full attribution. - - Returns matching chunks with source/license information for compliance. - """ - # Get query embedding - query_embedding = await get_embedding(query) - - # Search Qdrant - raw_results = await qdrant.search( - query_embedding=query_embedding, - source_codes=source_codes, - document_types=document_types, - categories=categories, - limit=limit - ) - - # Transform results - results = [] - licenses_used = set() - - for r in raw_results: - license_code = r.get("license_code", "") - license_info = LICENSE_REGISTRY.get(license_code, {}) - - result = DSFASearchResultResponse( - chunk_id=r.get("chunk_id", ""), - content=r.get("content", ""), - score=r.get("score", 0.0), - source_code=r.get("source_code", ""), - source_name=r.get("source_name", ""), - attribution_text=r.get("attribution_text", ""), - license_code=license_code, - license_name=license_info.get("name", license_code), - license_url=license_info.get("url"), - attribution_required=r.get("attribution_required", True), - source_url=r.get("source_url"), - document_type=r.get("document_type"), - category=r.get("category"), - section_title=r.get("section_title"), - page_number=r.get("page_number") - ) - results.append(result) - licenses_used.add(license_code) - - # Generate attribution notice - search_results = [ - DSFASearchResult( - chunk_id=r.chunk_id, - content=r.content, - score=r.score, - source_code=r.source_code, - source_name=r.source_name, - attribution_text=r.attribution_text, - license_code=r.license_code, - license_url=r.license_url, - attribution_required=r.attribution_required, - source_url=r.source_url, - document_type=r.document_type or "", - category=r.category or "", - section_title=r.section_title, - page_number=r.page_number - ) - for r in results - ] - attribution_notice = generate_attribution_notice(search_results) if include_attribution else "" - - return DSFASearchResponse( - query=query, - results=results, - total_results=len(results), - licenses_used=list(licenses_used), - attribution_notice=attribution_notice - ) - - -@router.get("/sources", response_model=List[DSFASourceResponse]) -async def list_dsfa_sources( - document_type: Optional[str] = Query(None, description="Filter by document type"), - license_code: Optional[str] = Query(None, description="Filter by license"), - store: DSFACorpusStore = Depends(get_store) -): - """List all registered DSFA sources with license info.""" - sources = await store.list_sources() - - result = [] - for s in sources: - # Apply filters - if document_type and s.get("document_type") != document_type: - continue - if license_code and s.get("license_code") != license_code: - continue - - license_info = LICENSE_REGISTRY.get(s.get("license_code", ""), {}) - - result.append(DSFASourceResponse( - id=str(s["id"]), - source_code=s["source_code"], - name=s["name"], - full_name=s.get("full_name"), - organization=s.get("organization"), - source_url=s.get("source_url"), - license_code=s.get("license_code", ""), - license_name=license_info.get("name", s.get("license_code", "")), - license_url=license_info.get("url"), - attribution_required=s.get("attribution_required", True), - attribution_text=s.get("attribution_text", ""), - document_type=s.get("document_type"), - language=s.get("language", "de") - )) - - return result - - -@router.get("/sources/available") -async def list_available_sources(): - """List all available source definitions (from DSFA_SOURCES constant).""" - return [ - { - "source_code": s["source_code"], - "name": s["name"], - "organization": s.get("organization"), - "license_code": s["license_code"], - "document_type": s.get("document_type") - } - for s in DSFA_SOURCES - ] - - -@router.get("/sources/{source_code}", response_model=DSFASourceResponse) -async def get_dsfa_source( - source_code: str, - store: DSFACorpusStore = Depends(get_store) -): - """Get details for a specific source.""" - source = await store.get_source_by_code(source_code) - - if not source: - raise HTTPException(status_code=404, detail=f"Source not found: {source_code}") - - license_info = LICENSE_REGISTRY.get(source.get("license_code", ""), {}) - - return DSFASourceResponse( - id=str(source["id"]), - source_code=source["source_code"], - name=source["name"], - full_name=source.get("full_name"), - organization=source.get("organization"), - source_url=source.get("source_url"), - license_code=source.get("license_code", ""), - license_name=license_info.get("name", source.get("license_code", "")), - license_url=license_info.get("url"), - attribution_required=source.get("attribution_required", True), - attribution_text=source.get("attribution_text", ""), - document_type=source.get("document_type"), - language=source.get("language", "de") - ) - - -@router.post("/sources/{source_code}/ingest", response_model=IngestResponse) -async def ingest_dsfa_source( - source_code: str, - request: IngestRequest, - store: DSFACorpusStore = Depends(get_store), - qdrant: DSFAQdrantService = Depends(get_qdrant) -): - """ - Trigger ingestion for a specific source. - - Can provide document via URL or direct text. - """ - # Get source - source = await store.get_source_by_code(source_code) - if not source: - raise HTTPException(status_code=404, detail=f"Source not found: {source_code}") - - # Need either URL or text - if not request.document_text and not request.document_url: - raise HTTPException( - status_code=400, - detail="Either document_text or document_url must be provided" - ) - - # Ensure Qdrant collection exists - await qdrant.ensure_collection() - - # Get text content - text_content = request.document_text - if request.document_url and not text_content: - # Download and extract text from URL - logger.info(f"Extracting text from URL: {request.document_url}") - text_content = await extract_text_from_url(request.document_url) - if not text_content: - raise HTTPException( - status_code=400, - detail=f"Could not extract text from URL: {request.document_url}" - ) - - if not text_content or len(text_content.strip()) < 50: - raise HTTPException(status_code=400, detail="Document text too short (min 50 chars)") - - # Create document record - doc_title = request.title or f"Document for {source_code}" - document_id = await store.create_document( - source_id=str(source["id"]), - title=doc_title, - file_type="text", - metadata={"ingested_via": "api", "source_code": source_code} - ) - - # Chunk the document - chunks = chunk_document(text_content, source_code) - - if not chunks: - return IngestResponse( - source_code=source_code, - document_id=document_id, - chunks_created=0, - message="Document created but no chunks generated" - ) - - # Generate embeddings in batch for efficiency - chunk_texts = [chunk["content"] for chunk in chunks] - logger.info(f"Generating embeddings for {len(chunk_texts)} chunks...") - embeddings = await get_embeddings_batch(chunk_texts) - - # Create chunk records in PostgreSQL and prepare for Qdrant - chunk_records = [] - for i, (chunk, embedding) in enumerate(zip(chunks, embeddings)): - # Create chunk in PostgreSQL - chunk_id = await store.create_chunk( - document_id=document_id, - source_id=str(source["id"]), - content=chunk["content"], - chunk_index=i, - section_title=chunk.get("section_title"), - page_number=chunk.get("page_number"), - category=chunk.get("category") - ) - - chunk_records.append({ - "chunk_id": chunk_id, - "document_id": document_id, - "source_id": str(source["id"]), - "content": chunk["content"], - "section_title": chunk.get("section_title"), - "source_code": source_code, - "source_name": source["name"], - "attribution_text": source["attribution_text"], - "license_code": source["license_code"], - "attribution_required": source.get("attribution_required", True), - "document_type": source.get("document_type", ""), - "category": chunk.get("category", ""), - "language": source.get("language", "de"), - "page_number": chunk.get("page_number") - }) - - # Index in Qdrant - indexed_count = await qdrant.index_chunks(chunk_records, embeddings) - - # Update document record - await store.update_document_indexed(document_id, len(chunks)) - - return IngestResponse( - source_code=source_code, - document_id=document_id, - chunks_created=indexed_count, - message=f"Successfully ingested {indexed_count} chunks from document" - ) - - -@router.get("/chunks/{chunk_id}", response_model=DSFAChunkResponse) -async def get_chunk_with_attribution( - chunk_id: str, - store: DSFACorpusStore = Depends(get_store) -): - """Get single chunk with full source attribution.""" - chunk = await store.get_chunk_with_attribution(chunk_id) - - if not chunk: - raise HTTPException(status_code=404, detail=f"Chunk not found: {chunk_id}") - - license_info = LICENSE_REGISTRY.get(chunk.get("license_code", ""), {}) - - return DSFAChunkResponse( - chunk_id=str(chunk["chunk_id"]), - content=chunk.get("content", ""), - section_title=chunk.get("section_title"), - page_number=chunk.get("page_number"), - category=chunk.get("category"), - document_id=str(chunk.get("document_id", "")), - document_title=chunk.get("document_title"), - source_id=str(chunk.get("source_id", "")), - source_code=chunk.get("source_code", ""), - source_name=chunk.get("source_name", ""), - attribution_text=chunk.get("attribution_text", ""), - license_code=chunk.get("license_code", ""), - license_name=license_info.get("name", chunk.get("license_code", "")), - license_url=license_info.get("url"), - attribution_required=chunk.get("attribution_required", True), - source_url=chunk.get("source_url"), - document_type=chunk.get("document_type") - ) - - -@router.get("/stats", response_model=DSFACorpusStatsResponse) -async def get_corpus_stats( - store: DSFACorpusStore = Depends(get_store), - qdrant: DSFAQdrantService = Depends(get_qdrant) -): - """Get corpus statistics for dashboard.""" - # Get PostgreSQL stats - source_stats = await store.get_source_stats() - - total_docs = 0 - total_chunks = 0 - stats_response = [] - - for s in source_stats: - doc_count = s.get("document_count", 0) or 0 - chunk_count = s.get("chunk_count", 0) or 0 - total_docs += doc_count - total_chunks += chunk_count - - last_indexed = s.get("last_indexed_at") - - stats_response.append(DSFASourceStatsResponse( - source_id=str(s.get("source_id", "")), - source_code=s.get("source_code", ""), - name=s.get("name", ""), - organization=s.get("organization"), - license_code=s.get("license_code", ""), - document_type=s.get("document_type"), - document_count=doc_count, - chunk_count=chunk_count, - last_indexed_at=last_indexed.isoformat() if last_indexed else None - )) - - # Get Qdrant stats - qdrant_stats = await qdrant.get_stats() - - return DSFACorpusStatsResponse( - sources=stats_response, - total_sources=len(source_stats), - total_documents=total_docs, - total_chunks=total_chunks, - qdrant_collection=DSFA_COLLECTION, - qdrant_points_count=qdrant_stats.get("points_count", 0), - qdrant_status=qdrant_stats.get("status", "unknown") - ) - - -@router.get("/licenses") -async def list_licenses(): - """List all supported licenses with their terms.""" - return [ - LicenseInfo( - code=code, - name=info.get("name", code), - url=info.get("url"), - attribution_required=info.get("attribution_required", True), - modification_allowed=info.get("modification_allowed", True), - commercial_use=info.get("commercial_use", True) - ) - for code, info in LICENSE_REGISTRY.items() - ] - - -@router.post("/init") -async def initialize_dsfa_corpus( - store: DSFACorpusStore = Depends(get_store), - qdrant: DSFAQdrantService = Depends(get_qdrant) -): - """ - Initialize DSFA corpus. - - - Creates Qdrant collection - - Registers all predefined sources - """ - # Ensure Qdrant collection exists - qdrant_ok = await qdrant.ensure_collection() - - # Register all sources - registered = 0 - for source in DSFA_SOURCES: - try: - await store.register_source(source) - registered += 1 - except Exception as e: - print(f"Error registering source {source['source_code']}: {e}") - - return { - "qdrant_collection_created": qdrant_ok, - "sources_registered": registered, - "total_sources": len(DSFA_SOURCES) - } +# Embedding utilities +from dsfa_rag_embedding import ( + get_embedding, + get_embeddings_batch, + extract_text_from_url, + EMBEDDING_SERVICE_URL, +) + +# Routes (router + set_db_pool) +from dsfa_rag_routes import ( + router, + set_db_pool, + get_store, + get_qdrant, +) + +__all__ = [ + # Router + "router", + "set_db_pool", + "get_store", + "get_qdrant", + # Models + "DSFASourceResponse", + "DSFAChunkResponse", + "DSFASearchResultResponse", + "DSFASearchResponse", + "DSFASourceStatsResponse", + "DSFACorpusStatsResponse", + "IngestRequest", + "IngestResponse", + "LicenseInfo", + # Embedding + "get_embedding", + "get_embeddings_batch", + "extract_text_from_url", + "EMBEDDING_SERVICE_URL", +] diff --git a/klausur-service/backend/dsfa_rag_embedding.py b/klausur-service/backend/dsfa_rag_embedding.py new file mode 100644 index 0000000..781cd85 --- /dev/null +++ b/klausur-service/backend/dsfa_rag_embedding.py @@ -0,0 +1,116 @@ +""" +DSFA RAG Embedding Service Integration. + +Handles embedding generation, text extraction, and fallback logic. +""" + +import os +import hashlib +import logging +import struct +import re +from typing import List + +import httpx + +logger = logging.getLogger(__name__) + +# Embedding service configuration +EMBEDDING_SERVICE_URL = os.getenv("EMBEDDING_SERVICE_URL", "http://172.18.0.13:8087") + + +async def get_embedding(text: str) -> List[float]: + """ + Get embedding for text using the embedding-service. + + Uses BGE-M3 model which produces 1024-dimensional vectors. + """ + async with httpx.AsyncClient(timeout=60.0) as client: + try: + response = await client.post( + f"{EMBEDDING_SERVICE_URL}/embed-single", + json={"text": text} + ) + response.raise_for_status() + data = response.json() + return data.get("embedding", []) + except httpx.HTTPError as e: + logger.error(f"Embedding service error: {e}") + # Fallback to hash-based pseudo-embedding for development + return _generate_fallback_embedding(text) + + +async def get_embeddings_batch(texts: List[str]) -> List[List[float]]: + """ + Get embeddings for multiple texts in batch. + """ + async with httpx.AsyncClient(timeout=120.0) as client: + try: + response = await client.post( + f"{EMBEDDING_SERVICE_URL}/embed", + json={"texts": texts} + ) + response.raise_for_status() + data = response.json() + return data.get("embeddings", []) + except httpx.HTTPError as e: + logger.error(f"Embedding service batch error: {e}") + # Fallback + return [_generate_fallback_embedding(t) for t in texts] + + +async def extract_text_from_url(url: str) -> str: + """ + Extract text from a document URL (PDF, HTML, etc.). + """ + async with httpx.AsyncClient(timeout=120.0) as client: + try: + # First try to use the embedding-service's extract-pdf endpoint + response = await client.post( + f"{EMBEDDING_SERVICE_URL}/extract-pdf", + json={"url": url} + ) + response.raise_for_status() + data = response.json() + return data.get("text", "") + except httpx.HTTPError as e: + logger.error(f"PDF extraction error for {url}: {e}") + # Fallback: try to fetch HTML content directly + try: + response = await client.get(url, follow_redirects=True) + response.raise_for_status() + content_type = response.headers.get("content-type", "") + if "html" in content_type: + # Simple HTML text extraction + html = response.text + # Remove scripts and styles + html = re.sub(r']*>.*?', '', html, flags=re.DOTALL | re.IGNORECASE) + html = re.sub(r']*>.*?', '', html, flags=re.DOTALL | re.IGNORECASE) + # Remove tags + text = re.sub(r'<[^>]+>', ' ', html) + # Clean whitespace + text = re.sub(r'\s+', ' ', text).strip() + return text + else: + return "" + except Exception as fetch_err: + logger.error(f"Fallback fetch error for {url}: {fetch_err}") + return "" + + +def _generate_fallback_embedding(text: str) -> List[float]: + """ + Generate deterministic pseudo-embedding from text hash. + Used as fallback when embedding service is unavailable. + """ + hash_bytes = hashlib.sha256(text.encode()).digest() + embedding = [] + for i in range(0, min(len(hash_bytes), 128), 4): + val = struct.unpack('f', hash_bytes[i:i+4])[0] + embedding.append(val % 1.0) + + # Pad to 1024 dimensions + while len(embedding) < 1024: + embedding.extend(embedding[:min(len(embedding), 1024 - len(embedding))]) + + return embedding[:1024] diff --git a/klausur-service/backend/dsfa_rag_models.py b/klausur-service/backend/dsfa_rag_models.py new file mode 100644 index 0000000..0017888 --- /dev/null +++ b/klausur-service/backend/dsfa_rag_models.py @@ -0,0 +1,137 @@ +""" +DSFA RAG Pydantic Models. + +Request/Response models for the DSFA RAG API. +""" + +from typing import List, Optional +from pydantic import BaseModel, Field + + +# ============================================================================= +# Response Models +# ============================================================================= + +class DSFASourceResponse(BaseModel): + """Response model for DSFA source.""" + id: str + source_code: str + name: str + full_name: Optional[str] = None + organization: Optional[str] = None + source_url: Optional[str] = None + license_code: str + license_name: str + license_url: Optional[str] = None + attribution_required: bool + attribution_text: str + document_type: Optional[str] = None + language: str = "de" + + +class DSFAChunkResponse(BaseModel): + """Response model for a single chunk with attribution.""" + chunk_id: str + content: str + section_title: Optional[str] = None + page_number: Optional[int] = None + category: Optional[str] = None + + # Document info + document_id: str + document_title: Optional[str] = None + + # Attribution (always included) + source_id: str + source_code: str + source_name: str + attribution_text: str + license_code: str + license_name: str + license_url: Optional[str] = None + attribution_required: bool + source_url: Optional[str] = None + document_type: Optional[str] = None + + +class DSFASearchResultResponse(BaseModel): + """Response model for search result.""" + chunk_id: str + content: str + score: float + + # Attribution + source_code: str + source_name: str + attribution_text: str + license_code: str + license_name: str + license_url: Optional[str] = None + attribution_required: bool + source_url: Optional[str] = None + + # Metadata + document_type: Optional[str] = None + category: Optional[str] = None + section_title: Optional[str] = None + page_number: Optional[int] = None + + +class DSFASearchResponse(BaseModel): + """Response model for search endpoint.""" + query: str + results: List[DSFASearchResultResponse] + total_results: int + + # Aggregated licenses for footer + licenses_used: List[str] + attribution_notice: str + + +class DSFASourceStatsResponse(BaseModel): + """Response model for source statistics.""" + source_id: str + source_code: str + name: str + organization: Optional[str] = None + license_code: str + document_type: Optional[str] = None + document_count: int + chunk_count: int + last_indexed_at: Optional[str] = None + + +class DSFACorpusStatsResponse(BaseModel): + """Response model for corpus statistics.""" + sources: List[DSFASourceStatsResponse] + total_sources: int + total_documents: int + total_chunks: int + qdrant_collection: str + qdrant_points_count: int + qdrant_status: str + + +class IngestRequest(BaseModel): + """Request model for ingestion.""" + document_url: Optional[str] = None + document_text: Optional[str] = None + title: Optional[str] = None + + +class IngestResponse(BaseModel): + """Response model for ingestion.""" + source_code: str + document_id: Optional[str] = None + chunks_created: int + message: str + + +class LicenseInfo(BaseModel): + """License information.""" + code: str + name: str + url: Optional[str] = None + attribution_required: bool + modification_allowed: bool + commercial_use: bool diff --git a/klausur-service/backend/dsfa_rag_routes.py b/klausur-service/backend/dsfa_rag_routes.py new file mode 100644 index 0000000..79bc2bf --- /dev/null +++ b/klausur-service/backend/dsfa_rag_routes.py @@ -0,0 +1,461 @@ +""" +DSFA RAG API Route Handlers. + +Endpoint implementations for search, sources, ingestion, stats, and init. +""" + +import logging +from typing import List, Optional + +from fastapi import APIRouter, HTTPException, Query, Depends + +from dsfa_corpus_ingestion import ( + DSFACorpusStore, + DSFAQdrantService, + DSFASearchResult, + LICENSE_REGISTRY, + DSFA_SOURCES, + generate_attribution_notice, + get_license_label, + DSFA_COLLECTION, + chunk_document, +) + +from dsfa_rag_models import ( + DSFASourceResponse, + DSFAChunkResponse, + DSFASearchResultResponse, + DSFASearchResponse, + DSFASourceStatsResponse, + DSFACorpusStatsResponse, + IngestRequest, + IngestResponse, + LicenseInfo, +) + +from dsfa_rag_embedding import ( + get_embedding, + get_embeddings_batch, + extract_text_from_url, +) + +logger = logging.getLogger(__name__) + +router = APIRouter(prefix="/api/v1/dsfa-rag", tags=["DSFA RAG"]) + + +# ============================================================================= +# Dependency Injection +# ============================================================================= + +_db_pool = None + + +def set_db_pool(pool): + """Set the database pool for API endpoints.""" + global _db_pool + _db_pool = pool + + +async def get_store() -> DSFACorpusStore: + """Get DSFA corpus store.""" + if _db_pool is None: + raise HTTPException(status_code=503, detail="Database not initialized") + return DSFACorpusStore(_db_pool) + + +async def get_qdrant() -> DSFAQdrantService: + """Get Qdrant service.""" + return DSFAQdrantService() + + +# ============================================================================= +# API Endpoints +# ============================================================================= + +@router.get("/search", response_model=DSFASearchResponse) +async def search_dsfa_corpus( + query: str = Query(..., min_length=3, description="Search query"), + source_codes: Optional[List[str]] = Query(None, description="Filter by source codes"), + document_types: Optional[List[str]] = Query(None, description="Filter by document types (guideline, checklist, regulation)"), + categories: Optional[List[str]] = Query(None, description="Filter by categories (threshold_analysis, risk_assessment, mitigation)"), + limit: int = Query(10, ge=1, le=50, description="Maximum results"), + include_attribution: bool = Query(True, description="Include attribution in results"), + store: DSFACorpusStore = Depends(get_store), + qdrant: DSFAQdrantService = Depends(get_qdrant) +): + """ + Search DSFA corpus with full attribution. + + Returns matching chunks with source/license information for compliance. + """ + query_embedding = await get_embedding(query) + + raw_results = await qdrant.search( + query_embedding=query_embedding, + source_codes=source_codes, + document_types=document_types, + categories=categories, + limit=limit + ) + + results = [] + licenses_used = set() + + for r in raw_results: + license_code = r.get("license_code", "") + license_info = LICENSE_REGISTRY.get(license_code, {}) + + result = DSFASearchResultResponse( + chunk_id=r.get("chunk_id", ""), + content=r.get("content", ""), + score=r.get("score", 0.0), + source_code=r.get("source_code", ""), + source_name=r.get("source_name", ""), + attribution_text=r.get("attribution_text", ""), + license_code=license_code, + license_name=license_info.get("name", license_code), + license_url=license_info.get("url"), + attribution_required=r.get("attribution_required", True), + source_url=r.get("source_url"), + document_type=r.get("document_type"), + category=r.get("category"), + section_title=r.get("section_title"), + page_number=r.get("page_number") + ) + results.append(result) + licenses_used.add(license_code) + + # Generate attribution notice + search_results = [ + DSFASearchResult( + chunk_id=r.chunk_id, + content=r.content, + score=r.score, + source_code=r.source_code, + source_name=r.source_name, + attribution_text=r.attribution_text, + license_code=r.license_code, + license_url=r.license_url, + attribution_required=r.attribution_required, + source_url=r.source_url, + document_type=r.document_type or "", + category=r.category or "", + section_title=r.section_title, + page_number=r.page_number + ) + for r in results + ] + attribution_notice = generate_attribution_notice(search_results) if include_attribution else "" + + return DSFASearchResponse( + query=query, + results=results, + total_results=len(results), + licenses_used=list(licenses_used), + attribution_notice=attribution_notice + ) + + +@router.get("/sources", response_model=List[DSFASourceResponse]) +async def list_dsfa_sources( + document_type: Optional[str] = Query(None, description="Filter by document type"), + license_code: Optional[str] = Query(None, description="Filter by license"), + store: DSFACorpusStore = Depends(get_store) +): + """List all registered DSFA sources with license info.""" + sources = await store.list_sources() + + result = [] + for s in sources: + if document_type and s.get("document_type") != document_type: + continue + if license_code and s.get("license_code") != license_code: + continue + + license_info = LICENSE_REGISTRY.get(s.get("license_code", ""), {}) + + result.append(DSFASourceResponse( + id=str(s["id"]), + source_code=s["source_code"], + name=s["name"], + full_name=s.get("full_name"), + organization=s.get("organization"), + source_url=s.get("source_url"), + license_code=s.get("license_code", ""), + license_name=license_info.get("name", s.get("license_code", "")), + license_url=license_info.get("url"), + attribution_required=s.get("attribution_required", True), + attribution_text=s.get("attribution_text", ""), + document_type=s.get("document_type"), + language=s.get("language", "de") + )) + + return result + + +@router.get("/sources/available") +async def list_available_sources(): + """List all available source definitions (from DSFA_SOURCES constant).""" + return [ + { + "source_code": s["source_code"], + "name": s["name"], + "organization": s.get("organization"), + "license_code": s["license_code"], + "document_type": s.get("document_type") + } + for s in DSFA_SOURCES + ] + + +@router.get("/sources/{source_code}", response_model=DSFASourceResponse) +async def get_dsfa_source( + source_code: str, + store: DSFACorpusStore = Depends(get_store) +): + """Get details for a specific source.""" + source = await store.get_source_by_code(source_code) + + if not source: + raise HTTPException(status_code=404, detail=f"Source not found: {source_code}") + + license_info = LICENSE_REGISTRY.get(source.get("license_code", ""), {}) + + return DSFASourceResponse( + id=str(source["id"]), + source_code=source["source_code"], + name=source["name"], + full_name=source.get("full_name"), + organization=source.get("organization"), + source_url=source.get("source_url"), + license_code=source.get("license_code", ""), + license_name=license_info.get("name", source.get("license_code", "")), + license_url=license_info.get("url"), + attribution_required=source.get("attribution_required", True), + attribution_text=source.get("attribution_text", ""), + document_type=source.get("document_type"), + language=source.get("language", "de") + ) + + +@router.post("/sources/{source_code}/ingest", response_model=IngestResponse) +async def ingest_dsfa_source( + source_code: str, + request: IngestRequest, + store: DSFACorpusStore = Depends(get_store), + qdrant: DSFAQdrantService = Depends(get_qdrant) +): + """ + Trigger ingestion for a specific source. + + Can provide document via URL or direct text. + """ + source = await store.get_source_by_code(source_code) + if not source: + raise HTTPException(status_code=404, detail=f"Source not found: {source_code}") + + if not request.document_text and not request.document_url: + raise HTTPException( + status_code=400, + detail="Either document_text or document_url must be provided" + ) + + await qdrant.ensure_collection() + + text_content = request.document_text + if request.document_url and not text_content: + logger.info(f"Extracting text from URL: {request.document_url}") + text_content = await extract_text_from_url(request.document_url) + if not text_content: + raise HTTPException( + status_code=400, + detail=f"Could not extract text from URL: {request.document_url}" + ) + + if not text_content or len(text_content.strip()) < 50: + raise HTTPException(status_code=400, detail="Document text too short (min 50 chars)") + + doc_title = request.title or f"Document for {source_code}" + document_id = await store.create_document( + source_id=str(source["id"]), + title=doc_title, + file_type="text", + metadata={"ingested_via": "api", "source_code": source_code} + ) + + chunks = chunk_document(text_content, source_code) + + if not chunks: + return IngestResponse( + source_code=source_code, + document_id=document_id, + chunks_created=0, + message="Document created but no chunks generated" + ) + + chunk_texts = [chunk["content"] for chunk in chunks] + logger.info(f"Generating embeddings for {len(chunk_texts)} chunks...") + embeddings = await get_embeddings_batch(chunk_texts) + + chunk_records = [] + for i, (chunk, embedding) in enumerate(zip(chunks, embeddings)): + chunk_id = await store.create_chunk( + document_id=document_id, + source_id=str(source["id"]), + content=chunk["content"], + chunk_index=i, + section_title=chunk.get("section_title"), + page_number=chunk.get("page_number"), + category=chunk.get("category") + ) + + chunk_records.append({ + "chunk_id": chunk_id, + "document_id": document_id, + "source_id": str(source["id"]), + "content": chunk["content"], + "section_title": chunk.get("section_title"), + "source_code": source_code, + "source_name": source["name"], + "attribution_text": source["attribution_text"], + "license_code": source["license_code"], + "attribution_required": source.get("attribution_required", True), + "document_type": source.get("document_type", ""), + "category": chunk.get("category", ""), + "language": source.get("language", "de"), + "page_number": chunk.get("page_number") + }) + + indexed_count = await qdrant.index_chunks(chunk_records, embeddings) + await store.update_document_indexed(document_id, len(chunks)) + + return IngestResponse( + source_code=source_code, + document_id=document_id, + chunks_created=indexed_count, + message=f"Successfully ingested {indexed_count} chunks from document" + ) + + +@router.get("/chunks/{chunk_id}", response_model=DSFAChunkResponse) +async def get_chunk_with_attribution( + chunk_id: str, + store: DSFACorpusStore = Depends(get_store) +): + """Get single chunk with full source attribution.""" + chunk = await store.get_chunk_with_attribution(chunk_id) + + if not chunk: + raise HTTPException(status_code=404, detail=f"Chunk not found: {chunk_id}") + + license_info = LICENSE_REGISTRY.get(chunk.get("license_code", ""), {}) + + return DSFAChunkResponse( + chunk_id=str(chunk["chunk_id"]), + content=chunk.get("content", ""), + section_title=chunk.get("section_title"), + page_number=chunk.get("page_number"), + category=chunk.get("category"), + document_id=str(chunk.get("document_id", "")), + document_title=chunk.get("document_title"), + source_id=str(chunk.get("source_id", "")), + source_code=chunk.get("source_code", ""), + source_name=chunk.get("source_name", ""), + attribution_text=chunk.get("attribution_text", ""), + license_code=chunk.get("license_code", ""), + license_name=license_info.get("name", chunk.get("license_code", "")), + license_url=license_info.get("url"), + attribution_required=chunk.get("attribution_required", True), + source_url=chunk.get("source_url"), + document_type=chunk.get("document_type") + ) + + +@router.get("/stats", response_model=DSFACorpusStatsResponse) +async def get_corpus_stats( + store: DSFACorpusStore = Depends(get_store), + qdrant: DSFAQdrantService = Depends(get_qdrant) +): + """Get corpus statistics for dashboard.""" + source_stats = await store.get_source_stats() + + total_docs = 0 + total_chunks = 0 + stats_response = [] + + for s in source_stats: + doc_count = s.get("document_count", 0) or 0 + chunk_count = s.get("chunk_count", 0) or 0 + total_docs += doc_count + total_chunks += chunk_count + + last_indexed = s.get("last_indexed_at") + + stats_response.append(DSFASourceStatsResponse( + source_id=str(s.get("source_id", "")), + source_code=s.get("source_code", ""), + name=s.get("name", ""), + organization=s.get("organization"), + license_code=s.get("license_code", ""), + document_type=s.get("document_type"), + document_count=doc_count, + chunk_count=chunk_count, + last_indexed_at=last_indexed.isoformat() if last_indexed else None + )) + + qdrant_stats = await qdrant.get_stats() + + return DSFACorpusStatsResponse( + sources=stats_response, + total_sources=len(source_stats), + total_documents=total_docs, + total_chunks=total_chunks, + qdrant_collection=DSFA_COLLECTION, + qdrant_points_count=qdrant_stats.get("points_count", 0), + qdrant_status=qdrant_stats.get("status", "unknown") + ) + + +@router.get("/licenses") +async def list_licenses(): + """List all supported licenses with their terms.""" + return [ + LicenseInfo( + code=code, + name=info.get("name", code), + url=info.get("url"), + attribution_required=info.get("attribution_required", True), + modification_allowed=info.get("modification_allowed", True), + commercial_use=info.get("commercial_use", True) + ) + for code, info in LICENSE_REGISTRY.items() + ] + + +@router.post("/init") +async def initialize_dsfa_corpus( + store: DSFACorpusStore = Depends(get_store), + qdrant: DSFAQdrantService = Depends(get_qdrant) +): + """ + Initialize DSFA corpus. + + - Creates Qdrant collection + - Registers all predefined sources + """ + qdrant_ok = await qdrant.ensure_collection() + + registered = 0 + for source in DSFA_SOURCES: + try: + await store.register_source(source) + registered += 1 + except Exception as e: + print(f"Error registering source {source['source_code']}: {e}") + + return { + "qdrant_collection_created": qdrant_ok, + "sources_registered": registered, + "total_sources": len(DSFA_SOURCES) + } diff --git a/klausur-service/backend/full_compliance_pipeline.py b/klausur-service/backend/full_compliance_pipeline.py index 7af71d8..fc24d09 100644 --- a/klausur-service/backend/full_compliance_pipeline.py +++ b/klausur-service/backend/full_compliance_pipeline.py @@ -1,31 +1,19 @@ #!/usr/bin/env python3 """ -Full Compliance Pipeline for Legal Corpus. +Full Compliance Pipeline for Legal Corpus — Barrel Re-export. -This script runs the complete pipeline: -1. Re-ingest all legal documents with improved chunking -2. Extract requirements/checkpoints from chunks -3. Generate controls using AI -4. Define remediation measures -5. Update statistics +Split into submodules: +- compliance_models.py — Dataclasses (Checkpoint, Control, Measure) +- compliance_extraction.py — Pattern extraction & control/measure generation +- compliance_pipeline.py — Pipeline phases & orchestrator Run on Mac Mini: nohup python full_compliance_pipeline.py > /tmp/compliance_pipeline.log 2>&1 & - -Checkpoints are saved to /tmp/pipeline_checkpoints.json and can be viewed in admin-v2. """ import asyncio -import json import logging -import os import sys -import time -from datetime import datetime -from typing import Dict, List, Any, Optional -from dataclasses import dataclass, asdict -import re -import hashlib # Configure logging logging.basicConfig( @@ -36,671 +24,25 @@ logging.basicConfig( logging.FileHandler('/tmp/compliance_pipeline.log') ] ) -logger = logging.getLogger(__name__) -# Import checkpoint manager -try: - from pipeline_checkpoints import CheckpointManager, EXPECTED_VALUES, ValidationStatus -except ImportError: - logger.warning("Checkpoint manager not available, running without checkpoints") - CheckpointManager = None - EXPECTED_VALUES = {} - ValidationStatus = None - -# Set environment variables for Docker network -# Support both QDRANT_URL and QDRANT_HOST -if not os.getenv("QDRANT_URL") and not os.getenv("QDRANT_HOST"): - os.environ["QDRANT_HOST"] = "qdrant" -os.environ.setdefault("EMBEDDING_SERVICE_URL", "http://embedding-service:8087") - -# Try to import from klausur-service -try: - from legal_corpus_ingestion import LegalCorpusIngestion, REGULATIONS, LEGAL_CORPUS_COLLECTION - from qdrant_client import QdrantClient - from qdrant_client.models import Filter, FieldCondition, MatchValue -except ImportError: - logger.error("Could not import required modules. Make sure you're in the klausur-service container.") - sys.exit(1) - - -@dataclass -class Checkpoint: - """A requirement/checkpoint extracted from legal text.""" - id: str - regulation_code: str - regulation_name: str - article: Optional[str] - title: str - description: str - original_text: str - chunk_id: str - source_url: str - - -@dataclass -class Control: - """A control derived from checkpoints.""" - id: str - domain: str - title: str - description: str - checkpoints: List[str] # List of checkpoint IDs - pass_criteria: str - implementation_guidance: str - is_automated: bool - automation_tool: Optional[str] - priority: str - - -@dataclass -class Measure: - """A remediation measure for a control.""" - id: str - control_id: str - title: str - description: str - responsible: str - deadline_days: int - status: str - - -class CompliancePipeline: - """Handles the full compliance pipeline.""" - - def __init__(self): - # Support both QDRANT_URL and QDRANT_HOST/PORT - qdrant_url = os.getenv("QDRANT_URL", "") - if qdrant_url: - from urllib.parse import urlparse - parsed = urlparse(qdrant_url) - qdrant_host = parsed.hostname or "qdrant" - qdrant_port = parsed.port or 6333 - else: - qdrant_host = os.getenv("QDRANT_HOST", "qdrant") - qdrant_port = 6333 - self.qdrant = QdrantClient(host=qdrant_host, port=qdrant_port) - self.checkpoints: List[Checkpoint] = [] - self.controls: List[Control] = [] - self.measures: List[Measure] = [] - self.stats = { - "chunks_processed": 0, - "checkpoints_extracted": 0, - "controls_created": 0, - "measures_defined": 0, - "by_regulation": {}, - "by_domain": {}, - } - # Initialize checkpoint manager - self.checkpoint_mgr = CheckpointManager() if CheckpointManager else None - - def extract_checkpoints_from_chunk(self, chunk_text: str, payload: Dict) -> List[Checkpoint]: - """ - Extract checkpoints/requirements from a chunk of text. - - Uses pattern matching to find requirement-like statements. - """ - checkpoints = [] - regulation_code = payload.get("regulation_code", "UNKNOWN") - regulation_name = payload.get("regulation_name", "Unknown") - source_url = payload.get("source_url", "") - chunk_id = hashlib.md5(chunk_text[:100].encode()).hexdigest()[:8] - - # Patterns for different requirement types - patterns = [ - # BSI-TR patterns - (r'([OT]\.[A-Za-z_]+\d*)[:\s]+(.+?)(?=\n[OT]\.|$)', 'bsi_requirement'), - # Article patterns (GDPR, AI Act, etc.) - (r'(?:Artikel|Art\.?)\s+(\d+)(?:\s+Abs(?:atz)?\.?\s*(\d+))?\s*[-–:]\s*(.+?)(?=\n|$)', 'article'), - # Numbered requirements - (r'\((\d+)\)\s+(.+?)(?=\n\(\d+\)|$)', 'numbered'), - # "Der Verantwortliche muss" patterns - (r'(?:Der Verantwortliche|Die Aufsichtsbehörde|Der Auftragsverarbeiter)\s+(muss|hat|soll)\s+(.+?)(?=\.\s|$)', 'obligation'), - # "Es ist erforderlich" patterns - (r'(?:Es ist erforderlich|Es muss gewährleistet|Es sind geeignete)\s+(.+?)(?=\.\s|$)', 'requirement'), - ] - - for pattern, pattern_type in patterns: - matches = re.finditer(pattern, chunk_text, re.MULTILINE | re.DOTALL) - for match in matches: - if pattern_type == 'bsi_requirement': - req_id = match.group(1) - description = match.group(2).strip() - title = req_id - elif pattern_type == 'article': - article_num = match.group(1) - paragraph = match.group(2) or "" - title_text = match.group(3).strip() - req_id = f"{regulation_code}-Art{article_num}" - if paragraph: - req_id += f"-{paragraph}" - title = f"Art. {article_num}" + (f" Abs. {paragraph}" if paragraph else "") - description = title_text - elif pattern_type == 'numbered': - num = match.group(1) - description = match.group(2).strip() - req_id = f"{regulation_code}-{num}" - title = f"Anforderung {num}" - else: - # Generic requirement - description = match.group(0).strip() - req_id = f"{regulation_code}-{chunk_id}-{len(checkpoints)}" - title = description[:50] + "..." if len(description) > 50 else description - - # Skip very short matches - if len(description) < 20: - continue - - checkpoint = Checkpoint( - id=req_id, - regulation_code=regulation_code, - regulation_name=regulation_name, - article=title if 'Art' in title else None, - title=title, - description=description[:500], - original_text=description, - chunk_id=chunk_id, - source_url=source_url - ) - checkpoints.append(checkpoint) - - return checkpoints - - def generate_control_for_checkpoints(self, checkpoints: List[Checkpoint]) -> Optional[Control]: - """ - Generate a control that covers the given checkpoints. - - This is a simplified version - in production this would use the AI assistant. - """ - if not checkpoints: - return None - - # Group by regulation - regulation = checkpoints[0].regulation_code - - # Determine domain based on content - all_text = " ".join([cp.description for cp in checkpoints]).lower() - - domain = "gov" # Default - if any(kw in all_text for kw in ["verschlüssel", "krypto", "encrypt", "hash"]): - domain = "crypto" - elif any(kw in all_text for kw in ["zugang", "access", "authentif", "login", "benutzer"]): - domain = "iam" - elif any(kw in all_text for kw in ["datenschutz", "personenbezogen", "privacy", "einwilligung"]): - domain = "priv" - elif any(kw in all_text for kw in ["entwicklung", "test", "code", "software"]): - domain = "sdlc" - elif any(kw in all_text for kw in ["überwach", "monitor", "log", "audit"]): - domain = "aud" - elif any(kw in all_text for kw in ["ki", "künstlich", "ai", "machine learning", "model"]): - domain = "ai" - elif any(kw in all_text for kw in ["betrieb", "operation", "verfügbar", "backup"]): - domain = "ops" - elif any(kw in all_text for kw in ["cyber", "resilience", "sbom", "vulnerab"]): - domain = "cra" - - # Generate control ID - domain_counts = self.stats.get("by_domain", {}) - domain_count = domain_counts.get(domain, 0) + 1 - control_id = f"{domain.upper()}-{domain_count:03d}" - - # Create title from first checkpoint - title = checkpoints[0].title - if len(title) > 100: - title = title[:97] + "..." - - # Create description - description = f"Control für {regulation}: " + checkpoints[0].description[:200] - - # Pass criteria - pass_criteria = f"Alle {len(checkpoints)} zugehörigen Anforderungen sind erfüllt und dokumentiert." - - # Implementation guidance - guidance = f"Implementiere Maßnahmen zur Erfüllung der Anforderungen aus {regulation}. " - guidance += f"Dokumentiere die Umsetzung und führe regelmäßige Reviews durch." - - # Determine if automated - is_automated = any(kw in all_text for kw in ["automat", "tool", "scan", "test"]) - - control = Control( - id=control_id, - domain=domain, - title=title, - description=description, - checkpoints=[cp.id for cp in checkpoints], - pass_criteria=pass_criteria, - implementation_guidance=guidance, - is_automated=is_automated, - automation_tool="CI/CD Pipeline" if is_automated else None, - priority="high" if "muss" in all_text or "erforderlich" in all_text else "medium" - ) - - return control - - def generate_measure_for_control(self, control: Control) -> Measure: - """Generate a remediation measure for a control.""" - measure_id = f"M-{control.id}" - - # Determine deadline based on priority - deadline_days = { - "critical": 30, - "high": 60, - "medium": 90, - "low": 180 - }.get(control.priority, 90) - - # Determine responsible team - responsible = { - "priv": "Datenschutzbeauftragter", - "iam": "IT-Security Team", - "sdlc": "Entwicklungsteam", - "crypto": "IT-Security Team", - "ops": "Operations Team", - "aud": "Compliance Team", - "ai": "AI/ML Team", - "cra": "IT-Security Team", - "gov": "Management" - }.get(control.domain, "Compliance Team") - - measure = Measure( - id=measure_id, - control_id=control.id, - title=f"Umsetzung: {control.title[:50]}", - description=f"Implementierung und Dokumentation von {control.id}: {control.description[:100]}", - responsible=responsible, - deadline_days=deadline_days, - status="pending" - ) - - return measure - - async def run_ingestion_phase(self, force_reindex: bool = False) -> int: - """Phase 1: Ingest documents (incremental - only missing ones).""" - logger.info("\n" + "=" * 60) - logger.info("PHASE 1: DOCUMENT INGESTION (INCREMENTAL)") - logger.info("=" * 60) - - if self.checkpoint_mgr: - self.checkpoint_mgr.start_checkpoint("ingestion", "Document Ingestion") - - ingestion = LegalCorpusIngestion() - - try: - # Check existing chunks per regulation - existing_chunks = {} - try: - for regulation in REGULATIONS: - count_result = self.qdrant.count( - collection_name=LEGAL_CORPUS_COLLECTION, - count_filter=Filter( - must=[FieldCondition(key="regulation_code", match=MatchValue(value=regulation.code))] - ) - ) - existing_chunks[regulation.code] = count_result.count - logger.info(f" {regulation.code}: {count_result.count} existing chunks") - except Exception as e: - logger.warning(f"Could not check existing chunks: {e}") - # Collection might not exist, that's OK - - # Determine which regulations need ingestion - regulations_to_ingest = [] - for regulation in REGULATIONS: - existing = existing_chunks.get(regulation.code, 0) - if force_reindex or existing == 0: - regulations_to_ingest.append(regulation) - logger.info(f" -> Will ingest: {regulation.code} (existing: {existing}, force: {force_reindex})") - else: - logger.info(f" -> Skipping: {regulation.code} (already has {existing} chunks)") - self.stats["by_regulation"][regulation.code] = existing - - if not regulations_to_ingest: - logger.info("All regulations already indexed. Skipping ingestion phase.") - total_chunks = sum(existing_chunks.values()) - self.stats["chunks_processed"] = total_chunks - if self.checkpoint_mgr: - self.checkpoint_mgr.add_metric("total_chunks", total_chunks) - self.checkpoint_mgr.add_metric("skipped", True) - self.checkpoint_mgr.complete_checkpoint(success=True) - return total_chunks - - # Ingest only missing regulations - total_chunks = sum(existing_chunks.values()) - for i, regulation in enumerate(regulations_to_ingest, 1): - logger.info(f"[{i}/{len(regulations_to_ingest)}] Ingesting {regulation.code}...") - try: - count = await ingestion.ingest_regulation(regulation) - total_chunks += count - self.stats["by_regulation"][regulation.code] = count - logger.info(f" -> {count} chunks") - - # Add metric for this regulation - if self.checkpoint_mgr: - self.checkpoint_mgr.add_metric(f"chunks_{regulation.code}", count) - - except Exception as e: - logger.error(f" -> FAILED: {e}") - self.stats["by_regulation"][regulation.code] = 0 - - self.stats["chunks_processed"] = total_chunks - logger.info(f"\nTotal chunks in collection: {total_chunks}") - - # Validate ingestion results - if self.checkpoint_mgr: - self.checkpoint_mgr.add_metric("total_chunks", total_chunks) - self.checkpoint_mgr.add_metric("regulations_count", len(REGULATIONS)) - - # Validate total chunks - expected = EXPECTED_VALUES.get("ingestion", {}) - self.checkpoint_mgr.validate( - "total_chunks", - expected=expected.get("total_chunks", 8000), - actual=total_chunks, - min_value=expected.get("min_chunks", 7000) - ) - - # Validate key regulations - reg_expected = expected.get("regulations", {}) - for reg_code, reg_exp in reg_expected.items(): - actual = self.stats["by_regulation"].get(reg_code, 0) - self.checkpoint_mgr.validate( - f"chunks_{reg_code}", - expected=reg_exp.get("expected", 0), - actual=actual, - min_value=reg_exp.get("min", 0) - ) - - self.checkpoint_mgr.complete_checkpoint(success=True) - - return total_chunks - - except Exception as e: - if self.checkpoint_mgr: - self.checkpoint_mgr.fail_checkpoint(str(e)) - raise - - finally: - await ingestion.close() - - async def run_extraction_phase(self) -> int: - """Phase 2: Extract checkpoints from chunks.""" - logger.info("\n" + "=" * 60) - logger.info("PHASE 2: CHECKPOINT EXTRACTION") - logger.info("=" * 60) - - if self.checkpoint_mgr: - self.checkpoint_mgr.start_checkpoint("extraction", "Checkpoint Extraction") - - try: - # Scroll through all chunks - offset = None - total_checkpoints = 0 - - while True: - result = self.qdrant.scroll( - collection_name=LEGAL_CORPUS_COLLECTION, - limit=100, - offset=offset, - with_payload=True, - with_vectors=False - ) - - points, next_offset = result - - if not points: - break - - for point in points: - payload = point.payload - text = payload.get("text", "") - - checkpoints = self.extract_checkpoints_from_chunk(text, payload) - self.checkpoints.extend(checkpoints) - total_checkpoints += len(checkpoints) - - logger.info(f"Processed {len(points)} chunks, extracted {total_checkpoints} checkpoints so far...") - - if next_offset is None: - break - offset = next_offset - - self.stats["checkpoints_extracted"] = len(self.checkpoints) - logger.info(f"\nTotal checkpoints extracted: {len(self.checkpoints)}") - - # Log per regulation - by_reg = {} - for cp in self.checkpoints: - by_reg[cp.regulation_code] = by_reg.get(cp.regulation_code, 0) + 1 - for reg, count in sorted(by_reg.items()): - logger.info(f" {reg}: {count} checkpoints") - - # Validate extraction results - if self.checkpoint_mgr: - self.checkpoint_mgr.add_metric("total_checkpoints", len(self.checkpoints)) - self.checkpoint_mgr.add_metric("checkpoints_by_regulation", by_reg) - - expected = EXPECTED_VALUES.get("extraction", {}) - self.checkpoint_mgr.validate( - "total_checkpoints", - expected=expected.get("total_checkpoints", 3500), - actual=len(self.checkpoints), - min_value=expected.get("min_checkpoints", 3000) - ) - - self.checkpoint_mgr.complete_checkpoint(success=True) - - return len(self.checkpoints) - - except Exception as e: - if self.checkpoint_mgr: - self.checkpoint_mgr.fail_checkpoint(str(e)) - raise - - async def run_control_generation_phase(self) -> int: - """Phase 3: Generate controls from checkpoints.""" - logger.info("\n" + "=" * 60) - logger.info("PHASE 3: CONTROL GENERATION") - logger.info("=" * 60) - - if self.checkpoint_mgr: - self.checkpoint_mgr.start_checkpoint("controls", "Control Generation") - - try: - # Group checkpoints by regulation - by_regulation: Dict[str, List[Checkpoint]] = {} - for cp in self.checkpoints: - reg = cp.regulation_code - if reg not in by_regulation: - by_regulation[reg] = [] - by_regulation[reg].append(cp) - - # Generate controls per regulation (group every 3-5 checkpoints) - for regulation, checkpoints in by_regulation.items(): - logger.info(f"Generating controls for {regulation} ({len(checkpoints)} checkpoints)...") - - # Group checkpoints into batches of 3-5 - batch_size = 4 - for i in range(0, len(checkpoints), batch_size): - batch = checkpoints[i:i + batch_size] - control = self.generate_control_for_checkpoints(batch) - - if control: - self.controls.append(control) - self.stats["by_domain"][control.domain] = self.stats["by_domain"].get(control.domain, 0) + 1 - - self.stats["controls_created"] = len(self.controls) - logger.info(f"\nTotal controls created: {len(self.controls)}") - - # Log per domain - for domain, count in sorted(self.stats["by_domain"].items()): - logger.info(f" {domain}: {count} controls") - - # Validate control generation - if self.checkpoint_mgr: - self.checkpoint_mgr.add_metric("total_controls", len(self.controls)) - self.checkpoint_mgr.add_metric("controls_by_domain", dict(self.stats["by_domain"])) - - expected = EXPECTED_VALUES.get("controls", {}) - self.checkpoint_mgr.validate( - "total_controls", - expected=expected.get("total_controls", 900), - actual=len(self.controls), - min_value=expected.get("min_controls", 800) - ) - - self.checkpoint_mgr.complete_checkpoint(success=True) - - return len(self.controls) - - except Exception as e: - if self.checkpoint_mgr: - self.checkpoint_mgr.fail_checkpoint(str(e)) - raise - - async def run_measure_generation_phase(self) -> int: - """Phase 4: Generate measures for controls.""" - logger.info("\n" + "=" * 60) - logger.info("PHASE 4: MEASURE GENERATION") - logger.info("=" * 60) - - if self.checkpoint_mgr: - self.checkpoint_mgr.start_checkpoint("measures", "Measure Generation") - - try: - for control in self.controls: - measure = self.generate_measure_for_control(control) - self.measures.append(measure) - - self.stats["measures_defined"] = len(self.measures) - logger.info(f"\nTotal measures defined: {len(self.measures)}") - - # Validate measure generation - if self.checkpoint_mgr: - self.checkpoint_mgr.add_metric("total_measures", len(self.measures)) - - expected = EXPECTED_VALUES.get("measures", {}) - self.checkpoint_mgr.validate( - "total_measures", - expected=expected.get("total_measures", 900), - actual=len(self.measures), - min_value=expected.get("min_measures", 800) - ) - - self.checkpoint_mgr.complete_checkpoint(success=True) - - return len(self.measures) - - except Exception as e: - if self.checkpoint_mgr: - self.checkpoint_mgr.fail_checkpoint(str(e)) - raise - - def save_results(self, output_dir: str = "/tmp/compliance_output"): - """Save results to JSON files.""" - logger.info("\n" + "=" * 60) - logger.info("SAVING RESULTS") - logger.info("=" * 60) - - os.makedirs(output_dir, exist_ok=True) - - # Save checkpoints - checkpoints_file = os.path.join(output_dir, "checkpoints.json") - with open(checkpoints_file, "w") as f: - json.dump([asdict(cp) for cp in self.checkpoints], f, indent=2, ensure_ascii=False) - logger.info(f"Saved {len(self.checkpoints)} checkpoints to {checkpoints_file}") - - # Save controls - controls_file = os.path.join(output_dir, "controls.json") - with open(controls_file, "w") as f: - json.dump([asdict(c) for c in self.controls], f, indent=2, ensure_ascii=False) - logger.info(f"Saved {len(self.controls)} controls to {controls_file}") - - # Save measures - measures_file = os.path.join(output_dir, "measures.json") - with open(measures_file, "w") as f: - json.dump([asdict(m) for m in self.measures], f, indent=2, ensure_ascii=False) - logger.info(f"Saved {len(self.measures)} measures to {measures_file}") - - # Save statistics - stats_file = os.path.join(output_dir, "statistics.json") - self.stats["generated_at"] = datetime.now().isoformat() - with open(stats_file, "w") as f: - json.dump(self.stats, f, indent=2, ensure_ascii=False) - logger.info(f"Saved statistics to {stats_file}") - - async def run_full_pipeline(self, force_reindex: bool = False, skip_ingestion: bool = False): - """Run the complete pipeline. - - Args: - force_reindex: If True, re-ingest all documents even if they exist - skip_ingestion: If True, skip ingestion phase entirely (use existing chunks) - """ - start_time = time.time() - - logger.info("=" * 60) - logger.info("FULL COMPLIANCE PIPELINE (INCREMENTAL)") - logger.info(f"Started at: {datetime.now().isoformat()}") - logger.info(f"Force reindex: {force_reindex}") - logger.info(f"Skip ingestion: {skip_ingestion}") - if self.checkpoint_mgr: - logger.info(f"Pipeline ID: {self.checkpoint_mgr.pipeline_id}") - logger.info("=" * 60) - - try: - # Phase 1: Ingestion (skip if requested or run incrementally) - if skip_ingestion: - logger.info("Skipping ingestion phase as requested...") - # Still get the chunk count - try: - collection_info = self.qdrant.get_collection(LEGAL_CORPUS_COLLECTION) - self.stats["chunks_processed"] = collection_info.points_count - except Exception: - self.stats["chunks_processed"] = 0 - else: - await self.run_ingestion_phase(force_reindex=force_reindex) - - # Phase 2: Extraction - await self.run_extraction_phase() - - # Phase 3: Control Generation - await self.run_control_generation_phase() - - # Phase 4: Measure Generation - await self.run_measure_generation_phase() - - # Save results - self.save_results() - - # Final summary - elapsed = time.time() - start_time - logger.info("\n" + "=" * 60) - logger.info("PIPELINE COMPLETE") - logger.info("=" * 60) - logger.info(f"Duration: {elapsed:.1f} seconds") - logger.info(f"Chunks processed: {self.stats['chunks_processed']}") - logger.info(f"Checkpoints extracted: {self.stats['checkpoints_extracted']}") - logger.info(f"Controls created: {self.stats['controls_created']}") - logger.info(f"Measures defined: {self.stats['measures_defined']}") - logger.info(f"\nResults saved to: /tmp/compliance_output/") - logger.info("Checkpoint status: /tmp/pipeline_checkpoints.json") - logger.info("=" * 60) - - # Complete pipeline checkpoint - if self.checkpoint_mgr: - self.checkpoint_mgr.complete_pipeline({ - "duration_seconds": elapsed, - "chunks_processed": self.stats['chunks_processed'], - "checkpoints_extracted": self.stats['checkpoints_extracted'], - "controls_created": self.stats['controls_created'], - "measures_defined": self.stats['measures_defined'], - "by_regulation": self.stats['by_regulation'], - "by_domain": self.stats['by_domain'], - }) - - except Exception as e: - logger.error(f"Pipeline failed: {e}") - if self.checkpoint_mgr: - self.checkpoint_mgr.state.status = "failed" - self.checkpoint_mgr._save() - raise +# Re-export all public symbols +from compliance_models import Checkpoint, Control, Measure +from compliance_extraction import ( + extract_checkpoints_from_chunk, + generate_control_for_checkpoints, + generate_measure_for_control, +) +from compliance_pipeline import CompliancePipeline + +__all__ = [ + "Checkpoint", + "Control", + "Measure", + "extract_checkpoints_from_chunk", + "generate_control_for_checkpoints", + "generate_measure_for_control", + "CompliancePipeline", +] async def main(): diff --git a/klausur-service/backend/github_crawler.py b/klausur-service/backend/github_crawler.py index ff0b4da..bba1705 100644 --- a/klausur-service/backend/github_crawler.py +++ b/klausur-service/backend/github_crawler.py @@ -1,767 +1,35 @@ """ -GitHub Repository Crawler for Legal Templates. +GitHub Repository Crawler — Barrel Re-export -Crawls GitHub and GitLab repositories to extract legal template documents -(Markdown, HTML, JSON, etc.) for ingestion into the RAG system. +Split into: +- github_crawler_parsers.py — ExtractedDocument, MarkdownParser, HTMLParser, JSONParser +- github_crawler_core.py — GitHubCrawler, RepositoryDownloader, crawl_source -Features: -- Clone repositories via Git or download as ZIP -- Parse Markdown, HTML, JSON, and plain text files -- Extract structured content with metadata -- Track git commit hashes for reproducibility -- Handle rate limiting and errors gracefully +All public names are re-exported here for backward compatibility. """ -import asyncio -import hashlib -import json -import logging -import os -import re -import shutil -import tempfile -import zipfile -from dataclasses import dataclass, field -from datetime import datetime -from fnmatch import fnmatch -from pathlib import Path -from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple -from urllib.parse import urlparse - -import httpx - -from template_sources import LicenseType, SourceConfig, LICENSES - -# Configure logging -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) - -# Configuration -GITHUB_API_URL = "https://api.github.com" -GITLAB_API_URL = "https://gitlab.com/api/v4" -GITHUB_TOKEN = os.getenv("GITHUB_TOKEN", "") # Optional for higher rate limits -MAX_FILE_SIZE = 1024 * 1024 # 1 MB max file size -REQUEST_TIMEOUT = 60.0 -RATE_LIMIT_DELAY = 1.0 # Delay between requests to avoid rate limiting - - -@dataclass -class ExtractedDocument: - """A document extracted from a repository.""" - text: str - title: str - file_path: str - file_type: str # "markdown", "html", "json", "text" - source_url: str - source_commit: Optional[str] = None - source_hash: str = "" # SHA256 of original content - sections: List[Dict[str, Any]] = field(default_factory=list) - placeholders: List[str] = field(default_factory=list) - language: str = "en" - metadata: Dict[str, Any] = field(default_factory=dict) - - def __post_init__(self): - if not self.source_hash and self.text: - self.source_hash = hashlib.sha256(self.text.encode()).hexdigest() - - -class MarkdownParser: - """Parse Markdown files into structured content.""" - - # Common placeholder patterns - PLACEHOLDER_PATTERNS = [ - r'\[([A-Z_]+)\]', # [COMPANY_NAME] - r'\{([a-z_]+)\}', # {company_name} - r'\{\{([a-z_]+)\}\}', # {{company_name}} - r'__([A-Z_]+)__', # __COMPANY_NAME__ - r'<([A-Z_]+)>', # - ] - - @classmethod - def parse(cls, content: str, filename: str = "") -> ExtractedDocument: - """Parse markdown content into an ExtractedDocument.""" - # Extract title from first heading or filename - title = cls._extract_title(content, filename) - - # Extract sections - sections = cls._extract_sections(content) - - # Find placeholders - placeholders = cls._find_placeholders(content) - - # Detect language - language = cls._detect_language(content) - - # Clean content for indexing - clean_text = cls._clean_for_indexing(content) - - return ExtractedDocument( - text=clean_text, - title=title, - file_path=filename, - file_type="markdown", - source_url="", # Will be set by caller - sections=sections, - placeholders=placeholders, - language=language, - ) - - @classmethod - def _extract_title(cls, content: str, filename: str) -> str: - """Extract title from markdown heading or filename.""" - # Look for first h1 heading - h1_match = re.search(r'^#\s+(.+)$', content, re.MULTILINE) - if h1_match: - return h1_match.group(1).strip() - - # Look for YAML frontmatter title - frontmatter_match = re.search( - r'^---\s*\n.*?title:\s*["\']?(.+?)["\']?\s*\n.*?---', - content, re.DOTALL - ) - if frontmatter_match: - return frontmatter_match.group(1).strip() - - # Fall back to filename - if filename: - name = Path(filename).stem - # Convert kebab-case or snake_case to title case - return name.replace('-', ' ').replace('_', ' ').title() - - return "Untitled" - - @classmethod - def _extract_sections(cls, content: str) -> List[Dict[str, Any]]: - """Extract sections from markdown content.""" - sections = [] - current_section = {"heading": "", "level": 0, "content": "", "start": 0} - - for match in re.finditer(r'^(#{1,6})\s+(.+)$', content, re.MULTILINE): - # Save previous section if it has content - if current_section["heading"] or current_section["content"].strip(): - current_section["content"] = current_section["content"].strip() - sections.append(current_section.copy()) - - # Start new section - level = len(match.group(1)) - heading = match.group(2).strip() - current_section = { - "heading": heading, - "level": level, - "content": "", - "start": match.end(), - } - - # Add final section - if current_section["heading"] or current_section["content"].strip(): - current_section["content"] = content[current_section["start"]:].strip() - sections.append(current_section) - - return sections - - @classmethod - def _find_placeholders(cls, content: str) -> List[str]: - """Find placeholder patterns in content.""" - placeholders = set() - for pattern in cls.PLACEHOLDER_PATTERNS: - for match in re.finditer(pattern, content): - placeholder = match.group(0) - placeholders.add(placeholder) - return sorted(list(placeholders)) - - @classmethod - def _detect_language(cls, content: str) -> str: - """Detect language from content.""" - # Look for German-specific words - german_indicators = [ - 'Datenschutz', 'Impressum', 'Nutzungsbedingungen', 'Haftung', - 'Widerruf', 'Verantwortlicher', 'personenbezogene', 'Verarbeitung', - 'und', 'der', 'die', 'das', 'ist', 'wird', 'werden', 'sind', - ] - - lower_content = content.lower() - german_count = sum(1 for word in german_indicators if word.lower() in lower_content) - - if german_count >= 3: - return "de" - return "en" - - @classmethod - def _clean_for_indexing(cls, content: str) -> str: - """Clean markdown content for text indexing.""" - # Remove YAML frontmatter - content = re.sub(r'^---\s*\n.*?---\s*\n', '', content, flags=re.DOTALL) - - # Remove HTML comments - content = re.sub(r'', '', content, flags=re.DOTALL) - - # Remove inline HTML tags but keep content - content = re.sub(r'<[^>]+>', '', content) - - # Convert markdown formatting - content = re.sub(r'\*\*(.+?)\*\*', r'\1', content) # Bold - content = re.sub(r'\*(.+?)\*', r'\1', content) # Italic - content = re.sub(r'`(.+?)`', r'\1', content) # Inline code - content = re.sub(r'~~(.+?)~~', r'\1', content) # Strikethrough - - # Remove link syntax but keep text - content = re.sub(r'\[([^\]]+)\]\([^)]+\)', r'\1', content) - - # Remove image syntax - content = re.sub(r'!\[([^\]]*)\]\([^)]+\)', r'\1', content) - - # Clean up whitespace - content = re.sub(r'\n{3,}', '\n\n', content) - content = re.sub(r' +', ' ', content) - - return content.strip() - - -class HTMLParser: - """Parse HTML files into structured content.""" - - @classmethod - def parse(cls, content: str, filename: str = "") -> ExtractedDocument: - """Parse HTML content into an ExtractedDocument.""" - # Extract title - title_match = re.search(r'(.+?)', content, re.IGNORECASE) - title = title_match.group(1) if title_match else Path(filename).stem - - # Convert to text - text = cls._html_to_text(content) - - # Find placeholders - placeholders = MarkdownParser._find_placeholders(text) - - # Detect language - lang_match = re.search(r']*lang=["\']([a-z]{2})["\']', content, re.IGNORECASE) - language = lang_match.group(1) if lang_match else MarkdownParser._detect_language(text) - - return ExtractedDocument( - text=text, - title=title, - file_path=filename, - file_type="html", - source_url="", - placeholders=placeholders, - language=language, - ) - - @classmethod - def _html_to_text(cls, html: str) -> str: - """Convert HTML to clean text.""" - # Remove script and style tags - html = re.sub(r']*>.*?', '', html, flags=re.DOTALL | re.IGNORECASE) - html = re.sub(r']*>.*?', '', html, flags=re.DOTALL | re.IGNORECASE) - - # Remove comments - html = re.sub(r'', '', html, flags=re.DOTALL) - - # Replace common entities - html = html.replace(' ', ' ') - html = html.replace('&', '&') - html = html.replace('<', '<') - html = html.replace('>', '>') - html = html.replace('"', '"') - html = html.replace(''', "'") - - # Add line breaks for block elements - html = re.sub(r'', '\n', html, flags=re.IGNORECASE) - html = re.sub(r'

', '\n\n', html, flags=re.IGNORECASE) - html = re.sub(r'', '\n', html, flags=re.IGNORECASE) - html = re.sub(r'', '\n\n', html, flags=re.IGNORECASE) - html = re.sub(r'', '\n', html, flags=re.IGNORECASE) - - # Remove remaining tags - html = re.sub(r'<[^>]+>', '', html) - - # Clean whitespace - html = re.sub(r'[ \t]+', ' ', html) - html = re.sub(r'\n[ \t]+', '\n', html) - html = re.sub(r'[ \t]+\n', '\n', html) - html = re.sub(r'\n{3,}', '\n\n', html) - - return html.strip() - - -class JSONParser: - """Parse JSON files containing legal template data.""" - - @classmethod - def parse(cls, content: str, filename: str = "") -> List[ExtractedDocument]: - """Parse JSON content into ExtractedDocuments.""" - try: - data = json.loads(content) - except json.JSONDecodeError as e: - logger.warning(f"Failed to parse JSON from {filename}: {e}") - return [] - - documents = [] - - if isinstance(data, dict): - # Handle different JSON structures - documents.extend(cls._parse_dict(data, filename)) - elif isinstance(data, list): - for i, item in enumerate(data): - if isinstance(item, dict): - docs = cls._parse_dict(item, f"{filename}[{i}]") - documents.extend(docs) - - return documents - - @classmethod - def _parse_dict(cls, data: dict, filename: str) -> List[ExtractedDocument]: - """Parse a dictionary into documents.""" - documents = [] - - # Look for text content in common keys - text_keys = ['text', 'content', 'body', 'description', 'value'] - title_keys = ['title', 'name', 'heading', 'label', 'key'] - - # Try to find main text content - text = "" - for key in text_keys: - if key in data and isinstance(data[key], str): - text = data[key] - break - - if not text: - # Check for nested structures (like webflorist format) - for key, value in data.items(): - if isinstance(value, dict): - nested_docs = cls._parse_dict(value, f"{filename}.{key}") - documents.extend(nested_docs) - elif isinstance(value, list): - for i, item in enumerate(value): - if isinstance(item, dict): - nested_docs = cls._parse_dict(item, f"{filename}.{key}[{i}]") - documents.extend(nested_docs) - elif isinstance(item, str) and len(item) > 50: - # Treat long strings as content - documents.append(ExtractedDocument( - text=item, - title=f"{key} {i+1}", - file_path=filename, - file_type="json", - source_url="", - language=MarkdownParser._detect_language(item), - )) - return documents - - # Found text content - title = "" - for key in title_keys: - if key in data and isinstance(data[key], str): - title = data[key] - break - - if not title: - title = Path(filename).stem - - # Extract metadata - metadata = {} - for key, value in data.items(): - if key not in text_keys + title_keys and not isinstance(value, (dict, list)): - metadata[key] = value - - placeholders = MarkdownParser._find_placeholders(text) - language = data.get('lang', data.get('language', MarkdownParser._detect_language(text))) - - documents.append(ExtractedDocument( - text=text, - title=title, - file_path=filename, - file_type="json", - source_url="", - placeholders=placeholders, - language=language, - metadata=metadata, - )) - - return documents - - -class GitHubCrawler: - """Crawl GitHub repositories for legal templates.""" - - def __init__(self, token: Optional[str] = None): - self.token = token or GITHUB_TOKEN - self.headers = { - "Accept": "application/vnd.github.v3+json", - "User-Agent": "LegalTemplatesCrawler/1.0", - } - if self.token: - self.headers["Authorization"] = f"token {self.token}" - - self.http_client: Optional[httpx.AsyncClient] = None - - async def __aenter__(self): - self.http_client = httpx.AsyncClient( - timeout=REQUEST_TIMEOUT, - headers=self.headers, - follow_redirects=True, - ) - return self - - async def __aexit__(self, exc_type, exc_val, exc_tb): - if self.http_client: - await self.http_client.aclose() - - def _parse_repo_url(self, url: str) -> Tuple[str, str, str]: - """Parse repository URL into owner, repo, and host.""" - parsed = urlparse(url) - path_parts = parsed.path.strip('/').split('/') - - if len(path_parts) < 2: - raise ValueError(f"Invalid repository URL: {url}") - - owner = path_parts[0] - repo = path_parts[1].replace('.git', '') - - if 'gitlab' in parsed.netloc: - host = 'gitlab' - else: - host = 'github' - - return owner, repo, host - - async def get_default_branch(self, owner: str, repo: str) -> str: - """Get the default branch of a repository.""" - if not self.http_client: - raise RuntimeError("Crawler not initialized. Use 'async with' context.") - - url = f"{GITHUB_API_URL}/repos/{owner}/{repo}" - response = await self.http_client.get(url) - response.raise_for_status() - data = response.json() - return data.get("default_branch", "main") - - async def get_latest_commit(self, owner: str, repo: str, branch: str = "main") -> str: - """Get the latest commit SHA for a branch.""" - if not self.http_client: - raise RuntimeError("Crawler not initialized. Use 'async with' context.") - - url = f"{GITHUB_API_URL}/repos/{owner}/{repo}/commits/{branch}" - response = await self.http_client.get(url) - response.raise_for_status() - data = response.json() - return data.get("sha", "") - - async def list_files( - self, - owner: str, - repo: str, - path: str = "", - branch: str = "main", - patterns: List[str] = None, - exclude_patterns: List[str] = None, - ) -> List[Dict[str, Any]]: - """List files in a repository matching the given patterns.""" - if not self.http_client: - raise RuntimeError("Crawler not initialized. Use 'async with' context.") - - patterns = patterns or ["*.md", "*.txt", "*.html"] - exclude_patterns = exclude_patterns or [] - - url = f"{GITHUB_API_URL}/repos/{owner}/{repo}/git/trees/{branch}?recursive=1" - response = await self.http_client.get(url) - response.raise_for_status() - data = response.json() - - files = [] - for item in data.get("tree", []): - if item["type"] != "blob": - continue - - file_path = item["path"] - - # Check exclude patterns - excluded = any(fnmatch(file_path, pattern) for pattern in exclude_patterns) - if excluded: - continue - - # Check include patterns - matched = any(fnmatch(file_path, pattern) for pattern in patterns) - if not matched: - continue - - # Skip large files - if item.get("size", 0) > MAX_FILE_SIZE: - logger.warning(f"Skipping large file: {file_path} ({item['size']} bytes)") - continue - - files.append({ - "path": file_path, - "sha": item["sha"], - "size": item.get("size", 0), - "url": item.get("url", ""), - }) - - return files - - async def get_file_content(self, owner: str, repo: str, path: str, branch: str = "main") -> str: - """Get the content of a file from a repository.""" - if not self.http_client: - raise RuntimeError("Crawler not initialized. Use 'async with' context.") - - # Use raw content URL for simplicity - url = f"https://raw.githubusercontent.com/{owner}/{repo}/{branch}/{path}" - response = await self.http_client.get(url) - response.raise_for_status() - return response.text - - async def crawl_repository( - self, - source: SourceConfig, - ) -> AsyncGenerator[ExtractedDocument, None]: - """Crawl a repository and yield extracted documents.""" - if not source.repo_url: - logger.warning(f"No repo URL for source: {source.name}") - return - - try: - owner, repo, host = self._parse_repo_url(source.repo_url) - except ValueError as e: - logger.error(f"Failed to parse repo URL for {source.name}: {e}") - return - - if host == "gitlab": - logger.info(f"GitLab repos not yet supported: {source.name}") - return - - logger.info(f"Crawling repository: {owner}/{repo}") - - try: - # Get default branch and latest commit - branch = await self.get_default_branch(owner, repo) - commit_sha = await self.get_latest_commit(owner, repo, branch) - - await asyncio.sleep(RATE_LIMIT_DELAY) - - # List files matching patterns - files = await self.list_files( - owner, repo, - branch=branch, - patterns=source.file_patterns, - exclude_patterns=source.exclude_patterns, - ) - - logger.info(f"Found {len(files)} matching files in {source.name}") - - for file_info in files: - await asyncio.sleep(RATE_LIMIT_DELAY) - - try: - content = await self.get_file_content( - owner, repo, file_info["path"], branch - ) - - # Parse based on file type - file_path = file_info["path"] - source_url = f"https://github.com/{owner}/{repo}/blob/{branch}/{file_path}" - - if file_path.endswith('.md'): - doc = MarkdownParser.parse(content, file_path) - doc.source_url = source_url - doc.source_commit = commit_sha - yield doc - - elif file_path.endswith('.html') or file_path.endswith('.htm'): - doc = HTMLParser.parse(content, file_path) - doc.source_url = source_url - doc.source_commit = commit_sha - yield doc - - elif file_path.endswith('.json'): - docs = JSONParser.parse(content, file_path) - for doc in docs: - doc.source_url = source_url - doc.source_commit = commit_sha - yield doc - - elif file_path.endswith('.txt'): - # Plain text file - yield ExtractedDocument( - text=content, - title=Path(file_path).stem, - file_path=file_path, - file_type="text", - source_url=source_url, - source_commit=commit_sha, - language=MarkdownParser._detect_language(content), - placeholders=MarkdownParser._find_placeholders(content), - ) - - except httpx.HTTPError as e: - logger.warning(f"Failed to fetch {file_path}: {e}") - continue - except Exception as e: - logger.error(f"Error processing {file_path}: {e}") - continue - - except httpx.HTTPError as e: - logger.error(f"HTTP error crawling {source.name}: {e}") - except Exception as e: - logger.error(f"Error crawling {source.name}: {e}") - - -class RepositoryDownloader: - """Download and extract repository archives.""" - - def __init__(self): - self.http_client: Optional[httpx.AsyncClient] = None - - async def __aenter__(self): - self.http_client = httpx.AsyncClient( - timeout=120.0, - follow_redirects=True, - ) - return self - - async def __aexit__(self, exc_type, exc_val, exc_tb): - if self.http_client: - await self.http_client.aclose() - - async def download_zip(self, repo_url: str, branch: str = "main") -> Path: - """Download repository as ZIP and extract to temp directory.""" - if not self.http_client: - raise RuntimeError("Downloader not initialized. Use 'async with' context.") - - parsed = urlparse(repo_url) - path_parts = parsed.path.strip('/').split('/') - owner = path_parts[0] - repo = path_parts[1].replace('.git', '') - - zip_url = f"https://github.com/{owner}/{repo}/archive/refs/heads/{branch}.zip" - - logger.info(f"Downloading ZIP from {zip_url}") - - response = await self.http_client.get(zip_url) - response.raise_for_status() - - # Save to temp file - temp_dir = Path(tempfile.mkdtemp()) - zip_path = temp_dir / f"{repo}.zip" - - with open(zip_path, 'wb') as f: - f.write(response.content) - - # Extract ZIP - extract_dir = temp_dir / repo - with zipfile.ZipFile(zip_path, 'r') as zip_ref: - zip_ref.extractall(temp_dir) - - # The extracted directory is usually named repo-branch - extracted_dirs = list(temp_dir.glob(f"{repo}-*")) - if extracted_dirs: - return extracted_dirs[0] - - return extract_dir - - async def crawl_local_directory( - self, - directory: Path, - source: SourceConfig, - base_url: str, - ) -> AsyncGenerator[ExtractedDocument, None]: - """Crawl a local directory for documents.""" - patterns = source.file_patterns or ["*.md", "*.txt", "*.html"] - exclude_patterns = source.exclude_patterns or [] - - for pattern in patterns: - for file_path in directory.rglob(pattern.replace("**/", "")): - if not file_path.is_file(): - continue - - rel_path = str(file_path.relative_to(directory)) - - # Check exclude patterns - excluded = any(fnmatch(rel_path, ep) for ep in exclude_patterns) - if excluded: - continue - - # Skip large files - if file_path.stat().st_size > MAX_FILE_SIZE: - continue - - try: - content = file_path.read_text(encoding='utf-8') - except UnicodeDecodeError: - try: - content = file_path.read_text(encoding='latin-1') - except Exception: - continue - - source_url = f"{base_url}/{rel_path}" - - if file_path.suffix == '.md': - doc = MarkdownParser.parse(content, rel_path) - doc.source_url = source_url - yield doc - - elif file_path.suffix in ['.html', '.htm']: - doc = HTMLParser.parse(content, rel_path) - doc.source_url = source_url - yield doc - - elif file_path.suffix == '.json': - docs = JSONParser.parse(content, rel_path) - for doc in docs: - doc.source_url = source_url - yield doc - - elif file_path.suffix == '.txt': - yield ExtractedDocument( - text=content, - title=file_path.stem, - file_path=rel_path, - file_type="text", - source_url=source_url, - language=MarkdownParser._detect_language(content), - placeholders=MarkdownParser._find_placeholders(content), - ) - - def cleanup(self, directory: Path): - """Clean up temporary directory.""" - if directory.exists(): - shutil.rmtree(directory, ignore_errors=True) - - -async def crawl_source(source: SourceConfig) -> List[ExtractedDocument]: - """Crawl a source configuration and return all extracted documents.""" - documents = [] - - if source.repo_url: - async with GitHubCrawler() as crawler: - async for doc in crawler.crawl_repository(source): - documents.append(doc) - - return documents - - -# CLI for testing -async def main(): - """Test crawler with a sample source.""" - from template_sources import TEMPLATE_SOURCES - - # Test with github-site-policy - source = next(s for s in TEMPLATE_SOURCES if s.name == "github-site-policy") - - async with GitHubCrawler() as crawler: - count = 0 - async for doc in crawler.crawl_repository(source): - count += 1 - print(f"\n{'='*60}") - print(f"Title: {doc.title}") - print(f"Path: {doc.file_path}") - print(f"Type: {doc.file_type}") - print(f"Language: {doc.language}") - print(f"URL: {doc.source_url}") - print(f"Placeholders: {doc.placeholders[:5] if doc.placeholders else 'None'}") - print(f"Text preview: {doc.text[:200]}...") - - print(f"\n\nTotal documents: {count}") - +# Parsers +from github_crawler_parsers import ( # noqa: F401 + ExtractedDocument, + MarkdownParser, + HTMLParser, + JSONParser, +) + +# Crawler and downloader +from github_crawler_core import ( # noqa: F401 + GITHUB_API_URL, + GITLAB_API_URL, + GITHUB_TOKEN, + MAX_FILE_SIZE, + REQUEST_TIMEOUT, + RATE_LIMIT_DELAY, + GitHubCrawler, + RepositoryDownloader, + crawl_source, + main, +) if __name__ == "__main__": + import asyncio asyncio.run(main()) diff --git a/klausur-service/backend/github_crawler_core.py b/klausur-service/backend/github_crawler_core.py new file mode 100644 index 0000000..4152d30 --- /dev/null +++ b/klausur-service/backend/github_crawler_core.py @@ -0,0 +1,411 @@ +""" +GitHub Crawler - Core Crawler and Downloader + +GitHubCrawler for API-based repository crawling and RepositoryDownloader +for ZIP-based local extraction. + +Extracted from github_crawler.py to keep files under 500 LOC. +""" + +import asyncio +import logging +import os +import shutil +import tempfile +import zipfile +from fnmatch import fnmatch +from pathlib import Path +from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple +from urllib.parse import urlparse + +import httpx + +from template_sources import SourceConfig +from github_crawler_parsers import ( + ExtractedDocument, + MarkdownParser, + HTMLParser, + JSONParser, +) + +logger = logging.getLogger(__name__) + +# Configuration +GITHUB_API_URL = "https://api.github.com" +GITLAB_API_URL = "https://gitlab.com/api/v4" +GITHUB_TOKEN = os.getenv("GITHUB_TOKEN", "") +MAX_FILE_SIZE = 1024 * 1024 # 1 MB max file size +REQUEST_TIMEOUT = 60.0 +RATE_LIMIT_DELAY = 1.0 + + +class GitHubCrawler: + """Crawl GitHub repositories for legal templates.""" + + def __init__(self, token: Optional[str] = None): + self.token = token or GITHUB_TOKEN + self.headers = { + "Accept": "application/vnd.github.v3+json", + "User-Agent": "LegalTemplatesCrawler/1.0", + } + if self.token: + self.headers["Authorization"] = f"token {self.token}" + + self.http_client: Optional[httpx.AsyncClient] = None + + async def __aenter__(self): + self.http_client = httpx.AsyncClient( + timeout=REQUEST_TIMEOUT, + headers=self.headers, + follow_redirects=True, + ) + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + if self.http_client: + await self.http_client.aclose() + + def _parse_repo_url(self, url: str) -> Tuple[str, str, str]: + """Parse repository URL into owner, repo, and host.""" + parsed = urlparse(url) + path_parts = parsed.path.strip('/').split('/') + + if len(path_parts) < 2: + raise ValueError(f"Invalid repository URL: {url}") + + owner = path_parts[0] + repo = path_parts[1].replace('.git', '') + + if 'gitlab' in parsed.netloc: + host = 'gitlab' + else: + host = 'github' + + return owner, repo, host + + async def get_default_branch(self, owner: str, repo: str) -> str: + """Get the default branch of a repository.""" + if not self.http_client: + raise RuntimeError("Crawler not initialized. Use 'async with' context.") + + url = f"{GITHUB_API_URL}/repos/{owner}/{repo}" + response = await self.http_client.get(url) + response.raise_for_status() + data = response.json() + return data.get("default_branch", "main") + + async def get_latest_commit(self, owner: str, repo: str, branch: str = "main") -> str: + """Get the latest commit SHA for a branch.""" + if not self.http_client: + raise RuntimeError("Crawler not initialized. Use 'async with' context.") + + url = f"{GITHUB_API_URL}/repos/{owner}/{repo}/commits/{branch}" + response = await self.http_client.get(url) + response.raise_for_status() + data = response.json() + return data.get("sha", "") + + async def list_files( + self, + owner: str, + repo: str, + path: str = "", + branch: str = "main", + patterns: List[str] = None, + exclude_patterns: List[str] = None, + ) -> List[Dict[str, Any]]: + """List files in a repository matching the given patterns.""" + if not self.http_client: + raise RuntimeError("Crawler not initialized. Use 'async with' context.") + + patterns = patterns or ["*.md", "*.txt", "*.html"] + exclude_patterns = exclude_patterns or [] + + url = f"{GITHUB_API_URL}/repos/{owner}/{repo}/git/trees/{branch}?recursive=1" + response = await self.http_client.get(url) + response.raise_for_status() + data = response.json() + + files = [] + for item in data.get("tree", []): + if item["type"] != "blob": + continue + + file_path = item["path"] + + excluded = any(fnmatch(file_path, pattern) for pattern in exclude_patterns) + if excluded: + continue + + matched = any(fnmatch(file_path, pattern) for pattern in patterns) + if not matched: + continue + + if item.get("size", 0) > MAX_FILE_SIZE: + logger.warning(f"Skipping large file: {file_path} ({item['size']} bytes)") + continue + + files.append({ + "path": file_path, + "sha": item["sha"], + "size": item.get("size", 0), + "url": item.get("url", ""), + }) + + return files + + async def get_file_content(self, owner: str, repo: str, path: str, branch: str = "main") -> str: + """Get the content of a file from a repository.""" + if not self.http_client: + raise RuntimeError("Crawler not initialized. Use 'async with' context.") + + url = f"https://raw.githubusercontent.com/{owner}/{repo}/{branch}/{path}" + response = await self.http_client.get(url) + response.raise_for_status() + return response.text + + async def crawl_repository( + self, + source: SourceConfig, + ) -> AsyncGenerator[ExtractedDocument, None]: + """Crawl a repository and yield extracted documents.""" + if not source.repo_url: + logger.warning(f"No repo URL for source: {source.name}") + return + + try: + owner, repo, host = self._parse_repo_url(source.repo_url) + except ValueError as e: + logger.error(f"Failed to parse repo URL for {source.name}: {e}") + return + + if host == "gitlab": + logger.info(f"GitLab repos not yet supported: {source.name}") + return + + logger.info(f"Crawling repository: {owner}/{repo}") + + try: + branch = await self.get_default_branch(owner, repo) + commit_sha = await self.get_latest_commit(owner, repo, branch) + + await asyncio.sleep(RATE_LIMIT_DELAY) + + files = await self.list_files( + owner, repo, + branch=branch, + patterns=source.file_patterns, + exclude_patterns=source.exclude_patterns, + ) + + logger.info(f"Found {len(files)} matching files in {source.name}") + + for file_info in files: + await asyncio.sleep(RATE_LIMIT_DELAY) + + try: + content = await self.get_file_content( + owner, repo, file_info["path"], branch + ) + + file_path = file_info["path"] + source_url = f"https://github.com/{owner}/{repo}/blob/{branch}/{file_path}" + + if file_path.endswith('.md'): + doc = MarkdownParser.parse(content, file_path) + doc.source_url = source_url + doc.source_commit = commit_sha + yield doc + + elif file_path.endswith('.html') or file_path.endswith('.htm'): + doc = HTMLParser.parse(content, file_path) + doc.source_url = source_url + doc.source_commit = commit_sha + yield doc + + elif file_path.endswith('.json'): + docs = JSONParser.parse(content, file_path) + for doc in docs: + doc.source_url = source_url + doc.source_commit = commit_sha + yield doc + + elif file_path.endswith('.txt'): + yield ExtractedDocument( + text=content, + title=Path(file_path).stem, + file_path=file_path, + file_type="text", + source_url=source_url, + source_commit=commit_sha, + language=MarkdownParser._detect_language(content), + placeholders=MarkdownParser._find_placeholders(content), + ) + + except httpx.HTTPError as e: + logger.warning(f"Failed to fetch {file_path}: {e}") + continue + except Exception as e: + logger.error(f"Error processing {file_path}: {e}") + continue + + except httpx.HTTPError as e: + logger.error(f"HTTP error crawling {source.name}: {e}") + except Exception as e: + logger.error(f"Error crawling {source.name}: {e}") + + +class RepositoryDownloader: + """Download and extract repository archives.""" + + def __init__(self): + self.http_client: Optional[httpx.AsyncClient] = None + + async def __aenter__(self): + self.http_client = httpx.AsyncClient( + timeout=120.0, + follow_redirects=True, + ) + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + if self.http_client: + await self.http_client.aclose() + + async def download_zip(self, repo_url: str, branch: str = "main") -> Path: + """Download repository as ZIP and extract to temp directory.""" + if not self.http_client: + raise RuntimeError("Downloader not initialized. Use 'async with' context.") + + parsed = urlparse(repo_url) + path_parts = parsed.path.strip('/').split('/') + owner = path_parts[0] + repo = path_parts[1].replace('.git', '') + + zip_url = f"https://github.com/{owner}/{repo}/archive/refs/heads/{branch}.zip" + + logger.info(f"Downloading ZIP from {zip_url}") + + response = await self.http_client.get(zip_url) + response.raise_for_status() + + temp_dir = Path(tempfile.mkdtemp()) + zip_path = temp_dir / f"{repo}.zip" + + with open(zip_path, 'wb') as f: + f.write(response.content) + + extract_dir = temp_dir / repo + with zipfile.ZipFile(zip_path, 'r') as zip_ref: + zip_ref.extractall(temp_dir) + + extracted_dirs = list(temp_dir.glob(f"{repo}-*")) + if extracted_dirs: + return extracted_dirs[0] + + return extract_dir + + async def crawl_local_directory( + self, + directory: Path, + source: SourceConfig, + base_url: str, + ) -> AsyncGenerator[ExtractedDocument, None]: + """Crawl a local directory for documents.""" + patterns = source.file_patterns or ["*.md", "*.txt", "*.html"] + exclude_patterns = source.exclude_patterns or [] + + for pattern in patterns: + for file_path in directory.rglob(pattern.replace("**/", "")): + if not file_path.is_file(): + continue + + rel_path = str(file_path.relative_to(directory)) + + excluded = any(fnmatch(rel_path, ep) for ep in exclude_patterns) + if excluded: + continue + + if file_path.stat().st_size > MAX_FILE_SIZE: + continue + + try: + content = file_path.read_text(encoding='utf-8') + except UnicodeDecodeError: + try: + content = file_path.read_text(encoding='latin-1') + except Exception: + continue + + source_url = f"{base_url}/{rel_path}" + + if file_path.suffix == '.md': + doc = MarkdownParser.parse(content, rel_path) + doc.source_url = source_url + yield doc + + elif file_path.suffix in ['.html', '.htm']: + doc = HTMLParser.parse(content, rel_path) + doc.source_url = source_url + yield doc + + elif file_path.suffix == '.json': + docs = JSONParser.parse(content, rel_path) + for doc in docs: + doc.source_url = source_url + yield doc + + elif file_path.suffix == '.txt': + yield ExtractedDocument( + text=content, + title=file_path.stem, + file_path=rel_path, + file_type="text", + source_url=source_url, + language=MarkdownParser._detect_language(content), + placeholders=MarkdownParser._find_placeholders(content), + ) + + def cleanup(self, directory: Path): + """Clean up temporary directory.""" + if directory.exists(): + shutil.rmtree(directory, ignore_errors=True) + + +async def crawl_source(source: SourceConfig) -> List[ExtractedDocument]: + """Crawl a source configuration and return all extracted documents.""" + documents = [] + + if source.repo_url: + async with GitHubCrawler() as crawler: + async for doc in crawler.crawl_repository(source): + documents.append(doc) + + return documents + + +# CLI for testing +async def main(): + """Test crawler with a sample source.""" + from template_sources import TEMPLATE_SOURCES + + source = next(s for s in TEMPLATE_SOURCES if s.name == "github-site-policy") + + async with GitHubCrawler() as crawler: + count = 0 + async for doc in crawler.crawl_repository(source): + count += 1 + print(f"\n{'='*60}") + print(f"Title: {doc.title}") + print(f"Path: {doc.file_path}") + print(f"Type: {doc.file_type}") + print(f"Language: {doc.language}") + print(f"URL: {doc.source_url}") + print(f"Placeholders: {doc.placeholders[:5] if doc.placeholders else 'None'}") + print(f"Text preview: {doc.text[:200]}...") + + print(f"\n\nTotal documents: {count}") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/klausur-service/backend/github_crawler_parsers.py b/klausur-service/backend/github_crawler_parsers.py new file mode 100644 index 0000000..416c9eb --- /dev/null +++ b/klausur-service/backend/github_crawler_parsers.py @@ -0,0 +1,303 @@ +""" +GitHub Crawler - Document Parsers + +Markdown, HTML, and JSON parsers for extracting structured content +from legal template documents. + +Extracted from github_crawler.py to keep files under 500 LOC. +""" + +import hashlib +import json +import re +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Dict, List, Optional + + +@dataclass +class ExtractedDocument: + """A document extracted from a repository.""" + text: str + title: str + file_path: str + file_type: str # "markdown", "html", "json", "text" + source_url: str + source_commit: Optional[str] = None + source_hash: str = "" # SHA256 of original content + sections: List[Dict[str, Any]] = field(default_factory=list) + placeholders: List[str] = field(default_factory=list) + language: str = "en" + metadata: Dict[str, Any] = field(default_factory=dict) + + def __post_init__(self): + if not self.source_hash and self.text: + self.source_hash = hashlib.sha256(self.text.encode()).hexdigest() + + +class MarkdownParser: + """Parse Markdown files into structured content.""" + + # Common placeholder patterns + PLACEHOLDER_PATTERNS = [ + r'\[([A-Z_]+)\]', # [COMPANY_NAME] + r'\{([a-z_]+)\}', # {company_name} + r'\{\{([a-z_]+)\}\}', # {{company_name}} + r'__([A-Z_]+)__', # __COMPANY_NAME__ + r'<([A-Z_]+)>', # + ] + + @classmethod + def parse(cls, content: str, filename: str = "") -> ExtractedDocument: + """Parse markdown content into an ExtractedDocument.""" + title = cls._extract_title(content, filename) + sections = cls._extract_sections(content) + placeholders = cls._find_placeholders(content) + language = cls._detect_language(content) + clean_text = cls._clean_for_indexing(content) + + return ExtractedDocument( + text=clean_text, + title=title, + file_path=filename, + file_type="markdown", + source_url="", + sections=sections, + placeholders=placeholders, + language=language, + ) + + @classmethod + def _extract_title(cls, content: str, filename: str) -> str: + """Extract title from markdown heading or filename.""" + h1_match = re.search(r'^#\s+(.+)$', content, re.MULTILINE) + if h1_match: + return h1_match.group(1).strip() + + frontmatter_match = re.search( + r'^---\s*\n.*?title:\s*["\']?(.+?)["\']?\s*\n.*?---', + content, re.DOTALL + ) + if frontmatter_match: + return frontmatter_match.group(1).strip() + + if filename: + name = Path(filename).stem + return name.replace('-', ' ').replace('_', ' ').title() + + return "Untitled" + + @classmethod + def _extract_sections(cls, content: str) -> List[Dict[str, Any]]: + """Extract sections from markdown content.""" + sections = [] + current_section = {"heading": "", "level": 0, "content": "", "start": 0} + + for match in re.finditer(r'^(#{1,6})\s+(.+)$', content, re.MULTILINE): + if current_section["heading"] or current_section["content"].strip(): + current_section["content"] = current_section["content"].strip() + sections.append(current_section.copy()) + + level = len(match.group(1)) + heading = match.group(2).strip() + current_section = { + "heading": heading, + "level": level, + "content": "", + "start": match.end(), + } + + if current_section["heading"] or current_section["content"].strip(): + current_section["content"] = content[current_section["start"]:].strip() + sections.append(current_section) + + return sections + + @classmethod + def _find_placeholders(cls, content: str) -> List[str]: + """Find placeholder patterns in content.""" + placeholders = set() + for pattern in cls.PLACEHOLDER_PATTERNS: + for match in re.finditer(pattern, content): + placeholder = match.group(0) + placeholders.add(placeholder) + return sorted(list(placeholders)) + + @classmethod + def _detect_language(cls, content: str) -> str: + """Detect language from content.""" + german_indicators = [ + 'Datenschutz', 'Impressum', 'Nutzungsbedingungen', 'Haftung', + 'Widerruf', 'Verantwortlicher', 'personenbezogene', 'Verarbeitung', + 'und', 'der', 'die', 'das', 'ist', 'wird', 'werden', 'sind', + ] + + lower_content = content.lower() + german_count = sum(1 for word in german_indicators if word.lower() in lower_content) + + if german_count >= 3: + return "de" + return "en" + + @classmethod + def _clean_for_indexing(cls, content: str) -> str: + """Clean markdown content for text indexing.""" + content = re.sub(r'^---\s*\n.*?---\s*\n', '', content, flags=re.DOTALL) + content = re.sub(r'', '', content, flags=re.DOTALL) + content = re.sub(r'<[^>]+>', '', content) + content = re.sub(r'\*\*(.+?)\*\*', r'\1', content) + content = re.sub(r'\*(.+?)\*', r'\1', content) + content = re.sub(r'`(.+?)`', r'\1', content) + content = re.sub(r'~~(.+?)~~', r'\1', content) + content = re.sub(r'\[([^\]]+)\]\([^)]+\)', r'\1', content) + content = re.sub(r'!\[([^\]]*)\]\([^)]+\)', r'\1', content) + content = re.sub(r'\n{3,}', '\n\n', content) + content = re.sub(r' +', ' ', content) + + return content.strip() + + +class HTMLParser: + """Parse HTML files into structured content.""" + + @classmethod + def parse(cls, content: str, filename: str = "") -> ExtractedDocument: + """Parse HTML content into an ExtractedDocument.""" + title_match = re.search(r'(.+?)', content, re.IGNORECASE) + title = title_match.group(1) if title_match else Path(filename).stem + + text = cls._html_to_text(content) + placeholders = MarkdownParser._find_placeholders(text) + + lang_match = re.search(r']*lang=["\']([a-z]{2})["\']', content, re.IGNORECASE) + language = lang_match.group(1) if lang_match else MarkdownParser._detect_language(text) + + return ExtractedDocument( + text=text, + title=title, + file_path=filename, + file_type="html", + source_url="", + placeholders=placeholders, + language=language, + ) + + @classmethod + def _html_to_text(cls, html: str) -> str: + """Convert HTML to clean text.""" + html = re.sub(r']*>.*?', '', html, flags=re.DOTALL | re.IGNORECASE) + html = re.sub(r']*>.*?', '', html, flags=re.DOTALL | re.IGNORECASE) + html = re.sub(r'', '', html, flags=re.DOTALL) + + html = html.replace(' ', ' ') + html = html.replace('&', '&') + html = html.replace('<', '<') + html = html.replace('>', '>') + html = html.replace('"', '"') + html = html.replace(''', "'") + + html = re.sub(r'', '\n', html, flags=re.IGNORECASE) + html = re.sub(r'

', '\n\n', html, flags=re.IGNORECASE) + html = re.sub(r'', '\n', html, flags=re.IGNORECASE) + html = re.sub(r'', '\n\n', html, flags=re.IGNORECASE) + html = re.sub(r'', '\n', html, flags=re.IGNORECASE) + + html = re.sub(r'<[^>]+>', '', html) + + html = re.sub(r'[ \t]+', ' ', html) + html = re.sub(r'\n[ \t]+', '\n', html) + html = re.sub(r'[ \t]+\n', '\n', html) + html = re.sub(r'\n{3,}', '\n\n', html) + + return html.strip() + + +class JSONParser: + """Parse JSON files containing legal template data.""" + + @classmethod + def parse(cls, content: str, filename: str = "") -> List[ExtractedDocument]: + """Parse JSON content into ExtractedDocuments.""" + try: + data = json.loads(content) + except json.JSONDecodeError as e: + import logging + logging.getLogger(__name__).warning(f"Failed to parse JSON from {filename}: {e}") + return [] + + documents = [] + + if isinstance(data, dict): + documents.extend(cls._parse_dict(data, filename)) + elif isinstance(data, list): + for i, item in enumerate(data): + if isinstance(item, dict): + docs = cls._parse_dict(item, f"{filename}[{i}]") + documents.extend(docs) + + return documents + + @classmethod + def _parse_dict(cls, data: dict, filename: str) -> List[ExtractedDocument]: + """Parse a dictionary into documents.""" + documents = [] + + text_keys = ['text', 'content', 'body', 'description', 'value'] + title_keys = ['title', 'name', 'heading', 'label', 'key'] + + text = "" + for key in text_keys: + if key in data and isinstance(data[key], str): + text = data[key] + break + + if not text: + for key, value in data.items(): + if isinstance(value, dict): + nested_docs = cls._parse_dict(value, f"{filename}.{key}") + documents.extend(nested_docs) + elif isinstance(value, list): + for i, item in enumerate(value): + if isinstance(item, dict): + nested_docs = cls._parse_dict(item, f"{filename}.{key}[{i}]") + documents.extend(nested_docs) + elif isinstance(item, str) and len(item) > 50: + documents.append(ExtractedDocument( + text=item, + title=f"{key} {i+1}", + file_path=filename, + file_type="json", + source_url="", + language=MarkdownParser._detect_language(item), + )) + return documents + + title = "" + for key in title_keys: + if key in data and isinstance(data[key], str): + title = data[key] + break + + if not title: + title = Path(filename).stem + + metadata = {} + for key, value in data.items(): + if key not in text_keys + title_keys and not isinstance(value, (dict, list)): + metadata[key] = value + + placeholders = MarkdownParser._find_placeholders(text) + language = data.get('lang', data.get('language', MarkdownParser._detect_language(text))) + + documents.append(ExtractedDocument( + text=text, + title=title, + file_path=filename, + file_type="json", + source_url="", + placeholders=placeholders, + language=language, + metadata=metadata, + )) + + return documents diff --git a/klausur-service/backend/legal_corpus_api.py b/klausur-service/backend/legal_corpus_api.py index 8c1730a..4d41cb0 100644 --- a/klausur-service/backend/legal_corpus_api.py +++ b/klausur-service/backend/legal_corpus_api.py @@ -1,790 +1,30 @@ """ -Legal Corpus API - Endpoints for RAG page in admin-v2 +Legal Corpus API — Barrel Re-export -Provides endpoints for: -- GET /api/v1/admin/legal-corpus/status - Collection status with chunk counts -- GET /api/v1/admin/legal-corpus/search - Semantic search -- POST /api/v1/admin/legal-corpus/ingest - Trigger ingestion -- GET /api/v1/admin/legal-corpus/ingestion-status - Ingestion status -- POST /api/v1/admin/legal-corpus/upload - Upload document -- POST /api/v1/admin/legal-corpus/add-link - Add link for ingestion -- POST /api/v1/admin/pipeline/start - Start compliance pipeline +Split into: +- legal_corpus_routes.py — Corpus endpoints (status, search, ingest, upload) +- legal_corpus_pipeline.py — Pipeline endpoints (checkpoints, start, status) + +All public names are re-exported here for backward compatibility. """ -import os -import asyncio -import httpx -import uuid -import shutil -from datetime import datetime -from typing import Optional, List, Dict, Any -from fastapi import APIRouter, HTTPException, Query, BackgroundTasks, UploadFile, File, Form -from pydantic import BaseModel -import logging - -logger = logging.getLogger(__name__) - -router = APIRouter(prefix="/api/v1/admin/legal-corpus", tags=["legal-corpus"]) - -# Configuration -QDRANT_URL = os.getenv("QDRANT_URL", "http://localhost:6333") -EMBEDDING_SERVICE_URL = os.getenv("EMBEDDING_SERVICE_URL", "http://embedding-service:8087") -COLLECTION_NAME = "bp_legal_corpus" - -# All regulations for status endpoint -REGULATIONS = [ - {"code": "GDPR", "name": "DSGVO", "fullName": "Datenschutz-Grundverordnung", "type": "eu_regulation"}, - {"code": "EPRIVACY", "name": "ePrivacy-Richtlinie", "fullName": "Richtlinie 2002/58/EG", "type": "eu_directive"}, - {"code": "TDDDG", "name": "TDDDG", "fullName": "Telekommunikation-Digitale-Dienste-Datenschutz-Gesetz", "type": "de_law"}, - {"code": "SCC", "name": "Standardvertragsklauseln", "fullName": "2021/914/EU", "type": "eu_regulation"}, - {"code": "DPF", "name": "EU-US Data Privacy Framework", "fullName": "Angemessenheitsbeschluss", "type": "eu_regulation"}, - {"code": "AIACT", "name": "EU AI Act", "fullName": "Verordnung (EU) 2024/1689", "type": "eu_regulation"}, - {"code": "CRA", "name": "Cyber Resilience Act", "fullName": "Verordnung (EU) 2024/2847", "type": "eu_regulation"}, - {"code": "NIS2", "name": "NIS2-Richtlinie", "fullName": "Richtlinie (EU) 2022/2555", "type": "eu_directive"}, - {"code": "EUCSA", "name": "EU Cybersecurity Act", "fullName": "Verordnung (EU) 2019/881", "type": "eu_regulation"}, - {"code": "DATAACT", "name": "Data Act", "fullName": "Verordnung (EU) 2023/2854", "type": "eu_regulation"}, - {"code": "DGA", "name": "Data Governance Act", "fullName": "Verordnung (EU) 2022/868", "type": "eu_regulation"}, - {"code": "DSA", "name": "Digital Services Act", "fullName": "Verordnung (EU) 2022/2065", "type": "eu_regulation"}, - {"code": "EAA", "name": "European Accessibility Act", "fullName": "Richtlinie (EU) 2019/882", "type": "eu_directive"}, - {"code": "DSM", "name": "DSM-Urheberrechtsrichtlinie", "fullName": "Richtlinie (EU) 2019/790", "type": "eu_directive"}, - {"code": "PLD", "name": "Produkthaftungsrichtlinie", "fullName": "Richtlinie 85/374/EWG", "type": "eu_directive"}, - {"code": "GPSR", "name": "General Product Safety", "fullName": "Verordnung (EU) 2023/988", "type": "eu_regulation"}, - {"code": "BSI-TR-03161-1", "name": "BSI-TR Teil 1", "fullName": "BSI TR-03161 Teil 1 - Mobile Anwendungen", "type": "bsi_standard"}, - {"code": "BSI-TR-03161-2", "name": "BSI-TR Teil 2", "fullName": "BSI TR-03161 Teil 2 - Web-Anwendungen", "type": "bsi_standard"}, - {"code": "BSI-TR-03161-3", "name": "BSI-TR Teil 3", "fullName": "BSI TR-03161 Teil 3 - Hintergrundsysteme", "type": "bsi_standard"}, -] - -# Ingestion state (in-memory for now) -ingestion_state = { - "running": False, - "completed": False, - "current_regulation": None, - "processed": 0, - "total": len(REGULATIONS), - "error": None, -} - - -class SearchRequest(BaseModel): - query: str - regulations: Optional[List[str]] = None - top_k: int = 5 - - -class IngestRequest(BaseModel): - force: bool = False - regulations: Optional[List[str]] = None - - -class AddLinkRequest(BaseModel): - url: str - title: str - code: str # Regulation code (e.g. "CUSTOM-1") - document_type: str = "custom" # custom, eu_regulation, eu_directive, de_law, bsi_standard - - -class StartPipelineRequest(BaseModel): - force_reindex: bool = False - skip_ingestion: bool = False - - -# Store for custom documents (in-memory for now, should be persisted) -custom_documents: List[Dict[str, Any]] = [] - - -async def get_qdrant_client(): - """Get async HTTP client for Qdrant.""" - return httpx.AsyncClient(timeout=30.0) - - -@router.get("/status") -async def get_legal_corpus_status(): - """ - Get status of the legal corpus collection including chunk counts per regulation. - """ - async with httpx.AsyncClient(timeout=30.0) as client: - try: - # Get collection info - collection_res = await client.get(f"{QDRANT_URL}/collections/{COLLECTION_NAME}") - if collection_res.status_code != 200: - return { - "collection": COLLECTION_NAME, - "totalPoints": 0, - "vectorSize": 1024, - "status": "not_found", - "regulations": {}, - } - - collection_data = collection_res.json() - result = collection_data.get("result", {}) - - # Get chunk counts per regulation - regulation_counts = {} - for reg in REGULATIONS: - count_res = await client.post( - f"{QDRANT_URL}/collections/{COLLECTION_NAME}/points/count", - json={ - "filter": { - "must": [{"key": "regulation_code", "match": {"value": reg["code"]}}] - } - }, - ) - if count_res.status_code == 200: - count_data = count_res.json() - regulation_counts[reg["code"]] = count_data.get("result", {}).get("count", 0) - else: - regulation_counts[reg["code"]] = 0 - - return { - "collection": COLLECTION_NAME, - "totalPoints": result.get("points_count", 0), - "vectorSize": result.get("config", {}).get("params", {}).get("vectors", {}).get("size", 1024), - "status": result.get("status", "unknown"), - "regulations": regulation_counts, - } - - except httpx.RequestError as e: - logger.error(f"Failed to get Qdrant status: {e}") - raise HTTPException(status_code=503, detail=f"Qdrant not available: {str(e)}") - - -@router.get("/search") -async def search_legal_corpus( - query: str = Query(..., description="Search query"), - top_k: int = Query(5, ge=1, le=20, description="Number of results"), - regulations: Optional[str] = Query(None, description="Comma-separated regulation codes to filter"), -): - """ - Semantic search in legal corpus using BGE-M3 embeddings. - """ - async with httpx.AsyncClient(timeout=60.0) as client: - try: - # Generate embedding for query - embed_res = await client.post( - f"{EMBEDDING_SERVICE_URL}/embed", - json={"texts": [query]}, - ) - if embed_res.status_code != 200: - raise HTTPException(status_code=500, detail="Embedding service error") - - embed_data = embed_res.json() - query_vector = embed_data["embeddings"][0] - - # Build Qdrant search request - search_request = { - "vector": query_vector, - "limit": top_k, - "with_payload": True, - } - - # Add regulation filter if specified - if regulations: - reg_codes = [r.strip() for r in regulations.split(",")] - search_request["filter"] = { - "should": [ - {"key": "regulation_code", "match": {"value": code}} - for code in reg_codes - ] - } - - # Search Qdrant - search_res = await client.post( - f"{QDRANT_URL}/collections/{COLLECTION_NAME}/points/search", - json=search_request, - ) - - if search_res.status_code != 200: - raise HTTPException(status_code=500, detail="Search failed") - - search_data = search_res.json() - results = [] - for point in search_data.get("result", []): - payload = point.get("payload", {}) - results.append({ - "text": payload.get("text", ""), - "regulation_code": payload.get("regulation_code", ""), - "regulation_name": payload.get("regulation_name", ""), - "article": payload.get("article"), - "paragraph": payload.get("paragraph"), - "source_url": payload.get("source_url", ""), - "score": point.get("score", 0), - }) - - return {"results": results, "query": query, "count": len(results)} - - except httpx.RequestError as e: - logger.error(f"Search failed: {e}") - raise HTTPException(status_code=503, detail=f"Service not available: {str(e)}") - - -@router.post("/ingest") -async def trigger_ingestion(request: IngestRequest, background_tasks: BackgroundTasks): - """ - Trigger legal corpus ingestion in background. - """ - global ingestion_state - - if ingestion_state["running"]: - raise HTTPException(status_code=409, detail="Ingestion already running") - - # Reset state - ingestion_state = { - "running": True, - "completed": False, - "current_regulation": None, - "processed": 0, - "total": len(REGULATIONS), - "error": None, - } - - # Start ingestion in background - background_tasks.add_task(run_ingestion, request.force, request.regulations) - - return { - "status": "started", - "job_id": "manual-trigger", - "message": f"Ingestion started for {len(REGULATIONS)} regulations", - } - - -async def run_ingestion(force: bool, regulations: Optional[List[str]]): - """Background task for running ingestion.""" - global ingestion_state - - try: - # Import ingestion module - from legal_corpus_ingestion import LegalCorpusIngestion - - ingestion = LegalCorpusIngestion() - - # Filter regulations if specified - regs_to_process = regulations or [r["code"] for r in REGULATIONS] - - for i, reg_code in enumerate(regs_to_process): - ingestion_state["current_regulation"] = reg_code - ingestion_state["processed"] = i - - try: - await ingestion.ingest_single(reg_code, force=force) - except Exception as e: - logger.error(f"Failed to ingest {reg_code}: {e}") - - ingestion_state["completed"] = True - ingestion_state["processed"] = len(regs_to_process) - - except Exception as e: - logger.error(f"Ingestion failed: {e}") - ingestion_state["error"] = str(e) - - finally: - ingestion_state["running"] = False - - -@router.get("/ingestion-status") -async def get_ingestion_status(): - """ - Get current ingestion status. - """ - return ingestion_state - - -@router.get("/regulations") -async def get_regulations(): - """ - Get list of all supported regulations. - """ - return {"regulations": REGULATIONS} - - -@router.get("/custom-documents") -async def get_custom_documents(): - """ - Get list of custom documents added by user. - """ - return {"documents": custom_documents} - - -@router.post("/upload") -async def upload_document( - background_tasks: BackgroundTasks, - file: UploadFile = File(...), - title: str = Form(...), - code: str = Form(...), - document_type: str = Form("custom"), -): - """ - Upload a document (PDF) for ingestion into the legal corpus. - - The document will be saved and queued for processing. - """ - global custom_documents - - # Validate file type - if not file.filename.endswith(('.pdf', '.PDF')): - raise HTTPException(status_code=400, detail="Only PDF files are supported") - - # Create upload directory if needed - upload_dir = "/tmp/legal_corpus_uploads" - os.makedirs(upload_dir, exist_ok=True) - - # Save file with unique name - doc_id = str(uuid.uuid4())[:8] - safe_filename = f"{doc_id}_{file.filename}" - file_path = os.path.join(upload_dir, safe_filename) - - try: - with open(file_path, "wb") as buffer: - shutil.copyfileobj(file.file, buffer) - except Exception as e: - logger.error(f"Failed to save uploaded file: {e}") - raise HTTPException(status_code=500, detail=f"Failed to save file: {str(e)}") - - # Create document record - doc_record = { - "id": doc_id, - "code": code, - "title": title, - "filename": file.filename, - "file_path": file_path, - "document_type": document_type, - "uploaded_at": datetime.now().isoformat(), - "status": "uploaded", - "chunk_count": 0, - } - - custom_documents.append(doc_record) - - # Queue for background ingestion - background_tasks.add_task(ingest_uploaded_document, doc_record) - - return { - "status": "uploaded", - "document_id": doc_id, - "message": f"Document '{title}' uploaded and queued for ingestion", - "document": doc_record, - } - - -async def ingest_uploaded_document(doc_record: Dict[str, Any]): - """Background task to ingest an uploaded document.""" - global custom_documents - - try: - doc_record["status"] = "processing" - - from legal_corpus_ingestion import LegalCorpusIngestion - ingestion = LegalCorpusIngestion() - - # Read PDF and extract text - import fitz # PyMuPDF - - doc = fitz.open(doc_record["file_path"]) - full_text = "" - for page in doc: - full_text += page.get_text() - doc.close() - - if not full_text.strip(): - doc_record["status"] = "error" - doc_record["error"] = "No text could be extracted from PDF" - return - - # Chunk the text - chunks = ingestion.chunk_text(full_text, doc_record["code"]) - - # Add metadata - for chunk in chunks: - chunk["regulation_code"] = doc_record["code"] - chunk["regulation_name"] = doc_record["title"] - chunk["document_type"] = doc_record["document_type"] - chunk["source_url"] = f"upload://{doc_record['filename']}" - - # Generate embeddings and upsert to Qdrant - if chunks: - await ingestion.embed_and_upsert(chunks) - doc_record["chunk_count"] = len(chunks) - doc_record["status"] = "indexed" - logger.info(f"Ingested {len(chunks)} chunks from uploaded document {doc_record['code']}") - else: - doc_record["status"] = "error" - doc_record["error"] = "No chunks generated from document" - - except Exception as e: - logger.error(f"Failed to ingest uploaded document: {e}") - doc_record["status"] = "error" - doc_record["error"] = str(e) - - -@router.post("/add-link") -async def add_link(request: AddLinkRequest, background_tasks: BackgroundTasks): - """ - Add a URL/link for ingestion into the legal corpus. - - The content will be fetched, extracted, and indexed. - """ - global custom_documents - - # Create document record - doc_id = str(uuid.uuid4())[:8] - doc_record = { - "id": doc_id, - "code": request.code, - "title": request.title, - "url": request.url, - "document_type": request.document_type, - "uploaded_at": datetime.now().isoformat(), - "status": "queued", - "chunk_count": 0, - } - - custom_documents.append(doc_record) - - # Queue for background ingestion - background_tasks.add_task(ingest_link_document, doc_record) - - return { - "status": "queued", - "document_id": doc_id, - "message": f"Link '{request.title}' queued for ingestion", - "document": doc_record, - } - - -async def ingest_link_document(doc_record: Dict[str, Any]): - """Background task to ingest content from a URL.""" - global custom_documents - - try: - doc_record["status"] = "fetching" - - async with httpx.AsyncClient(timeout=60.0) as client: - # Fetch the URL - response = await client.get(doc_record["url"], follow_redirects=True) - response.raise_for_status() - - content_type = response.headers.get("content-type", "") - - if "application/pdf" in content_type: - # Save PDF and process - import tempfile - with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as f: - f.write(response.content) - pdf_path = f.name - - import fitz - pdf_doc = fitz.open(pdf_path) - full_text = "" - for page in pdf_doc: - full_text += page.get_text() - pdf_doc.close() - os.unlink(pdf_path) - - elif "text/html" in content_type: - # Extract text from HTML - from bs4 import BeautifulSoup - soup = BeautifulSoup(response.text, "html.parser") - - # Remove script and style elements - for script in soup(["script", "style", "nav", "footer", "header"]): - script.decompose() - - full_text = soup.get_text(separator="\n", strip=True) - - else: - # Try to use as plain text - full_text = response.text - - if not full_text.strip(): - doc_record["status"] = "error" - doc_record["error"] = "No text could be extracted from URL" - return - - doc_record["status"] = "processing" - - from legal_corpus_ingestion import LegalCorpusIngestion - ingestion = LegalCorpusIngestion() - - # Chunk the text - chunks = ingestion.chunk_text(full_text, doc_record["code"]) - - # Add metadata - for chunk in chunks: - chunk["regulation_code"] = doc_record["code"] - chunk["regulation_name"] = doc_record["title"] - chunk["document_type"] = doc_record["document_type"] - chunk["source_url"] = doc_record["url"] - - # Generate embeddings and upsert to Qdrant - if chunks: - await ingestion.embed_and_upsert(chunks) - doc_record["chunk_count"] = len(chunks) - doc_record["status"] = "indexed" - logger.info(f"Ingested {len(chunks)} chunks from URL {doc_record['url']}") - else: - doc_record["status"] = "error" - doc_record["error"] = "No chunks generated from content" - - except httpx.HTTPError as e: - logger.error(f"Failed to fetch URL: {e}") - doc_record["status"] = "error" - doc_record["error"] = f"Failed to fetch URL: {str(e)}" - except Exception as e: - logger.error(f"Failed to ingest URL content: {e}") - doc_record["status"] = "error" - doc_record["error"] = str(e) - - -@router.delete("/custom-documents/{doc_id}") -async def delete_custom_document(doc_id: str): - """ - Delete a custom document from the list. - Note: This does not remove the chunks from Qdrant yet. - """ - global custom_documents - - doc = next((d for d in custom_documents if d["id"] == doc_id), None) - if not doc: - raise HTTPException(status_code=404, detail="Document not found") - - custom_documents = [d for d in custom_documents if d["id"] != doc_id] - - # TODO: Also remove chunks from Qdrant by filtering on code - - return {"status": "deleted", "document_id": doc_id} - - -# ========== Pipeline Checkpoints ========== - -# Create a separate router for pipeline-related endpoints -pipeline_router = APIRouter(prefix="/api/v1/admin/pipeline", tags=["pipeline"]) - - -@pipeline_router.get("/checkpoints") -async def get_pipeline_checkpoints(): - """ - Get current pipeline checkpoint state. - - Returns the current state of the compliance pipeline including: - - Pipeline ID and overall status - - Start and completion times - - All checkpoints with their validations and metrics - - Summary data - """ - from pipeline_checkpoints import CheckpointManager - - state = CheckpointManager.load_state() - - if state is None: - return { - "status": "no_data", - "message": "No pipeline run data available yet.", - "pipeline_id": None, - "checkpoints": [], - "summary": {} - } - - # Enrich with validation summary - validation_summary = { - "passed": 0, - "warning": 0, - "failed": 0, - "total": 0 - } - - for checkpoint in state.get("checkpoints", []): - for validation in checkpoint.get("validations", []): - validation_summary["total"] += 1 - status = validation.get("status", "not_run") - if status in validation_summary: - validation_summary[status] += 1 - - state["validation_summary"] = validation_summary - - return state - - -@pipeline_router.get("/checkpoints/history") -async def get_pipeline_history(): - """ - Get list of previous pipeline runs (if stored). - For now, returns only current run. - """ - from pipeline_checkpoints import CheckpointManager - - state = CheckpointManager.load_state() - - if state is None: - return {"runs": []} - - return { - "runs": [{ - "pipeline_id": state.get("pipeline_id"), - "status": state.get("status"), - "started_at": state.get("started_at"), - "completed_at": state.get("completed_at"), - }] - } - - -# Pipeline state for start/stop -pipeline_process_state = { - "running": False, - "pid": None, - "started_at": None, -} - - -@pipeline_router.post("/start") -async def start_pipeline(request: StartPipelineRequest, background_tasks: BackgroundTasks): - """ - Start the compliance pipeline in the background. - - This runs the full_compliance_pipeline.py script which: - 1. Ingests all legal documents (unless skip_ingestion=True) - 2. Extracts requirements and controls - 3. Generates compliance measures - 4. Creates checkpoint data for monitoring - """ - global pipeline_process_state - - # Check if already running - from pipeline_checkpoints import CheckpointManager - state = CheckpointManager.load_state() - - if state and state.get("status") == "running": - raise HTTPException( - status_code=409, - detail="Pipeline is already running" - ) - - if pipeline_process_state["running"]: - raise HTTPException( - status_code=409, - detail="Pipeline start already in progress" - ) - - pipeline_process_state["running"] = True - pipeline_process_state["started_at"] = datetime.now().isoformat() - - # Start pipeline in background - background_tasks.add_task( - run_pipeline_background, - request.force_reindex, - request.skip_ingestion - ) - - return { - "status": "starting", - "message": "Compliance pipeline is starting in background", - "started_at": pipeline_process_state["started_at"], - } - - -async def run_pipeline_background(force_reindex: bool, skip_ingestion: bool): - """Background task to run the compliance pipeline.""" - global pipeline_process_state - - try: - import subprocess - import sys - - # Build command - cmd = [sys.executable, "full_compliance_pipeline.py"] - if force_reindex: - cmd.append("--force-reindex") - if skip_ingestion: - cmd.append("--skip-ingestion") - - # Run as subprocess - logger.info(f"Starting pipeline: {' '.join(cmd)}") - - process = subprocess.Popen( - cmd, - cwd=os.path.dirname(os.path.abspath(__file__)), - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT, - text=True, - ) - - pipeline_process_state["pid"] = process.pid - - # Wait for completion (non-blocking via asyncio) - import asyncio - while process.poll() is None: - await asyncio.sleep(5) - - return_code = process.returncode - - if return_code != 0: - output = process.stdout.read() if process.stdout else "" - logger.error(f"Pipeline failed with code {return_code}: {output}") - else: - logger.info("Pipeline completed successfully") - - except Exception as e: - logger.error(f"Failed to run pipeline: {e}") - - finally: - pipeline_process_state["running"] = False - pipeline_process_state["pid"] = None - - -@pipeline_router.get("/status") -async def get_pipeline_status(): - """ - Get current pipeline running status. - """ - from pipeline_checkpoints import CheckpointManager - - state = CheckpointManager.load_state() - checkpoint_status = state.get("status") if state else "no_data" - - return { - "process_running": pipeline_process_state["running"], - "process_pid": pipeline_process_state["pid"], - "process_started_at": pipeline_process_state["started_at"], - "checkpoint_status": checkpoint_status, - "current_phase": state.get("current_phase") if state else None, - } - - -# ========== Traceability / Quality Endpoints ========== - -@router.get("/traceability") -async def get_traceability( - chunk_id: str = Query(..., description="Chunk ID or identifier"), - regulation: str = Query(..., description="Regulation code"), -): - """ - Get traceability information for a specific chunk. - - Returns: - - The chunk details - - Requirements extracted from this chunk - - Controls derived from those requirements - - Note: This is a placeholder that will be enhanced once the - requirements extraction pipeline is fully implemented. - """ - async with httpx.AsyncClient(timeout=30.0) as client: - try: - # Try to find the chunk by scrolling through points with the regulation filter - # In a production system, we would have proper IDs and indexing - - # For now, return placeholder structure - # The actual implementation will query: - # 1. The chunk from Qdrant - # 2. Requirements from a requirements collection/table - # 3. Controls from a controls collection/table - - return { - "chunk_id": chunk_id, - "regulation": regulation, - "requirements": [], - "controls": [], - "message": "Traceability-Daten werden verfuegbar sein, sobald die Requirements-Extraktion und Control-Ableitung implementiert sind." - } - - except Exception as e: - logger.error(f"Failed to get traceability: {e}") - raise HTTPException(status_code=500, detail=f"Traceability lookup failed: {str(e)}") +# Corpus routes and state +from legal_corpus_routes import ( # noqa: F401 + router, + REGULATIONS, + COLLECTION_NAME, + QDRANT_URL, + EMBEDDING_SERVICE_URL, + ingestion_state, + custom_documents, + SearchRequest, + IngestRequest, + AddLinkRequest, +) + +# Pipeline routes and state +from legal_corpus_pipeline import ( # noqa: F401 + pipeline_router, + pipeline_process_state, + StartPipelineRequest, +) diff --git a/klausur-service/backend/legal_corpus_ingest_tasks.py b/klausur-service/backend/legal_corpus_ingest_tasks.py new file mode 100644 index 0000000..0e3fbfb --- /dev/null +++ b/klausur-service/backend/legal_corpus_ingest_tasks.py @@ -0,0 +1,166 @@ +""" +Legal Corpus API - Background Ingestion Tasks + +Background tasks for ingesting uploaded documents and URL links +into the legal corpus vector database. + +Extracted from legal_corpus_routes.py to keep files under 500 LOC. +""" + +import os +import logging +from typing import Dict, Any, Optional, List + +import httpx + +logger = logging.getLogger(__name__) + + +async def ingest_uploaded_document(doc_record: Dict[str, Any]): + """Background task to ingest an uploaded document.""" + try: + doc_record["status"] = "processing" + + from legal_corpus_ingestion import LegalCorpusIngestion + ingestion = LegalCorpusIngestion() + + import fitz + doc = fitz.open(doc_record["file_path"]) + full_text = "" + for page in doc: + full_text += page.get_text() + doc.close() + + if not full_text.strip(): + doc_record["status"] = "error" + doc_record["error"] = "No text could be extracted from PDF" + return + + chunks = ingestion.chunk_text(full_text, doc_record["code"]) + + for chunk in chunks: + chunk["regulation_code"] = doc_record["code"] + chunk["regulation_name"] = doc_record["title"] + chunk["document_type"] = doc_record["document_type"] + chunk["source_url"] = f"upload://{doc_record['filename']}" + + if chunks: + await ingestion.embed_and_upsert(chunks) + doc_record["chunk_count"] = len(chunks) + doc_record["status"] = "indexed" + logger.info(f"Ingested {len(chunks)} chunks from uploaded document {doc_record['code']}") + else: + doc_record["status"] = "error" + doc_record["error"] = "No chunks generated from document" + + except Exception as e: + logger.error(f"Failed to ingest uploaded document: {e}") + doc_record["status"] = "error" + doc_record["error"] = str(e) + + +async def ingest_link_document(doc_record: Dict[str, Any]): + """Background task to ingest content from a URL.""" + try: + doc_record["status"] = "fetching" + + async with httpx.AsyncClient(timeout=60.0) as client: + response = await client.get(doc_record["url"], follow_redirects=True) + response.raise_for_status() + + content_type = response.headers.get("content-type", "") + + if "application/pdf" in content_type: + import tempfile + with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as f: + f.write(response.content) + pdf_path = f.name + + import fitz + pdf_doc = fitz.open(pdf_path) + full_text = "" + for page in pdf_doc: + full_text += page.get_text() + pdf_doc.close() + os.unlink(pdf_path) + + elif "text/html" in content_type: + from bs4 import BeautifulSoup + soup = BeautifulSoup(response.text, "html.parser") + + for script in soup(["script", "style", "nav", "footer", "header"]): + script.decompose() + + full_text = soup.get_text(separator="\n", strip=True) + + else: + full_text = response.text + + if not full_text.strip(): + doc_record["status"] = "error" + doc_record["error"] = "No text could be extracted from URL" + return + + doc_record["status"] = "processing" + + from legal_corpus_ingestion import LegalCorpusIngestion + ingestion = LegalCorpusIngestion() + + chunks = ingestion.chunk_text(full_text, doc_record["code"]) + + for chunk in chunks: + chunk["regulation_code"] = doc_record["code"] + chunk["regulation_name"] = doc_record["title"] + chunk["document_type"] = doc_record["document_type"] + chunk["source_url"] = doc_record["url"] + + if chunks: + await ingestion.embed_and_upsert(chunks) + doc_record["chunk_count"] = len(chunks) + doc_record["status"] = "indexed" + logger.info(f"Ingested {len(chunks)} chunks from URL {doc_record['url']}") + else: + doc_record["status"] = "error" + doc_record["error"] = "No chunks generated from content" + + except httpx.HTTPError as e: + logger.error(f"Failed to fetch URL: {e}") + doc_record["status"] = "error" + doc_record["error"] = f"Failed to fetch URL: {str(e)}" + except Exception as e: + logger.error(f"Failed to ingest URL content: {e}") + doc_record["status"] = "error" + doc_record["error"] = str(e) + + +async def run_ingestion( + force: bool, + regulations: Optional[List[str]], + ingestion_state: Dict[str, Any], + all_regulations: List[Dict[str, str]], +): + """Background task for running full corpus ingestion.""" + try: + from legal_corpus_ingestion import LegalCorpusIngestion + ingestion = LegalCorpusIngestion() + + regs_to_process = regulations or [r["code"] for r in all_regulations] + + for i, reg_code in enumerate(regs_to_process): + ingestion_state["current_regulation"] = reg_code + ingestion_state["processed"] = i + + try: + await ingestion.ingest_single(reg_code, force=force) + except Exception as e: + logger.error(f"Failed to ingest {reg_code}: {e}") + + ingestion_state["completed"] = True + ingestion_state["processed"] = len(regs_to_process) + + except Exception as e: + logger.error(f"Ingestion failed: {e}") + ingestion_state["error"] = str(e) + + finally: + ingestion_state["running"] = False diff --git a/klausur-service/backend/legal_corpus_pipeline.py b/klausur-service/backend/legal_corpus_pipeline.py new file mode 100644 index 0000000..639ff1a --- /dev/null +++ b/klausur-service/backend/legal_corpus_pipeline.py @@ -0,0 +1,206 @@ +""" +Legal Corpus API - Pipeline Routes + +Pipeline checkpoints, history, start/stop, and status endpoints. + +Extracted from legal_corpus_api.py to keep files under 500 LOC. +""" + +import os +import asyncio +from datetime import datetime +from fastapi import APIRouter, HTTPException, BackgroundTasks +from pydantic import BaseModel +import logging + +logger = logging.getLogger(__name__) + + +class StartPipelineRequest(BaseModel): + force_reindex: bool = False + skip_ingestion: bool = False + + +# Create a separate router for pipeline-related endpoints +pipeline_router = APIRouter(prefix="/api/v1/admin/pipeline", tags=["pipeline"]) + + +@pipeline_router.get("/checkpoints") +async def get_pipeline_checkpoints(): + """ + Get current pipeline checkpoint state. + + Returns the current state of the compliance pipeline including: + - Pipeline ID and overall status + - Start and completion times + - All checkpoints with their validations and metrics + - Summary data + """ + from pipeline_checkpoints import CheckpointManager + + state = CheckpointManager.load_state() + + if state is None: + return { + "status": "no_data", + "message": "No pipeline run data available yet.", + "pipeline_id": None, + "checkpoints": [], + "summary": {} + } + + # Enrich with validation summary + validation_summary = { + "passed": 0, + "warning": 0, + "failed": 0, + "total": 0 + } + + for checkpoint in state.get("checkpoints", []): + for validation in checkpoint.get("validations", []): + validation_summary["total"] += 1 + status = validation.get("status", "not_run") + if status in validation_summary: + validation_summary[status] += 1 + + state["validation_summary"] = validation_summary + + return state + + +@pipeline_router.get("/checkpoints/history") +async def get_pipeline_history(): + """ + Get list of previous pipeline runs (if stored). + For now, returns only current run. + """ + from pipeline_checkpoints import CheckpointManager + + state = CheckpointManager.load_state() + + if state is None: + return {"runs": []} + + return { + "runs": [{ + "pipeline_id": state.get("pipeline_id"), + "status": state.get("status"), + "started_at": state.get("started_at"), + "completed_at": state.get("completed_at"), + }] + } + + +# Pipeline state for start/stop +pipeline_process_state = { + "running": False, + "pid": None, + "started_at": None, +} + + +@pipeline_router.post("/start") +async def start_pipeline(request: StartPipelineRequest, background_tasks: BackgroundTasks): + """ + Start the compliance pipeline in the background. + + This runs the full_compliance_pipeline.py script which: + 1. Ingests all legal documents (unless skip_ingestion=True) + 2. Extracts requirements and controls + 3. Generates compliance measures + 4. Creates checkpoint data for monitoring + """ + global pipeline_process_state + + from pipeline_checkpoints import CheckpointManager + state = CheckpointManager.load_state() + + if state and state.get("status") == "running": + raise HTTPException( + status_code=409, + detail="Pipeline is already running" + ) + + if pipeline_process_state["running"]: + raise HTTPException( + status_code=409, + detail="Pipeline start already in progress" + ) + + pipeline_process_state["running"] = True + pipeline_process_state["started_at"] = datetime.now().isoformat() + + background_tasks.add_task( + run_pipeline_background, + request.force_reindex, + request.skip_ingestion + ) + + return { + "status": "starting", + "message": "Compliance pipeline is starting in background", + "started_at": pipeline_process_state["started_at"], + } + + +async def run_pipeline_background(force_reindex: bool, skip_ingestion: bool): + """Background task to run the compliance pipeline.""" + global pipeline_process_state + + try: + import subprocess + import sys + + cmd = [sys.executable, "full_compliance_pipeline.py"] + if force_reindex: + cmd.append("--force-reindex") + if skip_ingestion: + cmd.append("--skip-ingestion") + + logger.info(f"Starting pipeline: {' '.join(cmd)}") + + process = subprocess.Popen( + cmd, + cwd=os.path.dirname(os.path.abspath(__file__)), + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + text=True, + ) + + pipeline_process_state["pid"] = process.pid + + while process.poll() is None: + await asyncio.sleep(5) + + return_code = process.returncode + + if return_code != 0: + output = process.stdout.read() if process.stdout else "" + logger.error(f"Pipeline failed with code {return_code}: {output}") + else: + logger.info("Pipeline completed successfully") + + except Exception as e: + logger.error(f"Failed to run pipeline: {e}") + + finally: + pipeline_process_state["running"] = False + pipeline_process_state["pid"] = None + + +@pipeline_router.get("/status") +async def get_pipeline_status(): + """Get current pipeline running status.""" + from pipeline_checkpoints import CheckpointManager + + state = CheckpointManager.load_state() + checkpoint_status = state.get("status") if state else "no_data" + + return { + "process_running": pipeline_process_state["running"], + "process_pid": pipeline_process_state["pid"], + "process_started_at": pipeline_process_state["started_at"], + "checkpoint_status": checkpoint_status, + "current_phase": state.get("current_phase") if state else None, + } diff --git a/klausur-service/backend/legal_corpus_routes.py b/klausur-service/backend/legal_corpus_routes.py new file mode 100644 index 0000000..a32d0ce --- /dev/null +++ b/klausur-service/backend/legal_corpus_routes.py @@ -0,0 +1,368 @@ +""" +Legal Corpus API - Corpus Routes + +Endpoints for the RAG page in admin-v2: +- GET /status - Collection status with chunk counts +- GET /search - Semantic search +- POST /ingest - Trigger ingestion +- GET /ingestion-status - Ingestion status +- GET /regulations - List regulations +- GET /custom-documents - List custom docs +- POST /upload - Upload document +- POST /add-link - Add link for ingestion +- DELETE /custom-documents/{id} - Delete custom doc +- GET /traceability - Traceability info + +Extracted from legal_corpus_api.py to keep files under 500 LOC. +""" + +import os +import httpx +import uuid +import shutil +from datetime import datetime +from typing import Optional, List, Dict, Any +from fastapi import APIRouter, HTTPException, Query, BackgroundTasks, UploadFile, File, Form +from pydantic import BaseModel +import logging + +from legal_corpus_ingest_tasks import ( + ingest_uploaded_document, + ingest_link_document, + run_ingestion, +) + +logger = logging.getLogger(__name__) + +# Configuration +QDRANT_URL = os.getenv("QDRANT_URL", "http://localhost:6333") +EMBEDDING_SERVICE_URL = os.getenv("EMBEDDING_SERVICE_URL", "http://embedding-service:8087") +COLLECTION_NAME = "bp_legal_corpus" + +# All regulations for status endpoint +REGULATIONS = [ + {"code": "GDPR", "name": "DSGVO", "fullName": "Datenschutz-Grundverordnung", "type": "eu_regulation"}, + {"code": "EPRIVACY", "name": "ePrivacy-Richtlinie", "fullName": "Richtlinie 2002/58/EG", "type": "eu_directive"}, + {"code": "TDDDG", "name": "TDDDG", "fullName": "Telekommunikation-Digitale-Dienste-Datenschutz-Gesetz", "type": "de_law"}, + {"code": "SCC", "name": "Standardvertragsklauseln", "fullName": "2021/914/EU", "type": "eu_regulation"}, + {"code": "DPF", "name": "EU-US Data Privacy Framework", "fullName": "Angemessenheitsbeschluss", "type": "eu_regulation"}, + {"code": "AIACT", "name": "EU AI Act", "fullName": "Verordnung (EU) 2024/1689", "type": "eu_regulation"}, + {"code": "CRA", "name": "Cyber Resilience Act", "fullName": "Verordnung (EU) 2024/2847", "type": "eu_regulation"}, + {"code": "NIS2", "name": "NIS2-Richtlinie", "fullName": "Richtlinie (EU) 2022/2555", "type": "eu_directive"}, + {"code": "EUCSA", "name": "EU Cybersecurity Act", "fullName": "Verordnung (EU) 2019/881", "type": "eu_regulation"}, + {"code": "DATAACT", "name": "Data Act", "fullName": "Verordnung (EU) 2023/2854", "type": "eu_regulation"}, + {"code": "DGA", "name": "Data Governance Act", "fullName": "Verordnung (EU) 2022/868", "type": "eu_regulation"}, + {"code": "DSA", "name": "Digital Services Act", "fullName": "Verordnung (EU) 2022/2065", "type": "eu_regulation"}, + {"code": "EAA", "name": "European Accessibility Act", "fullName": "Richtlinie (EU) 2019/882", "type": "eu_directive"}, + {"code": "DSM", "name": "DSM-Urheberrechtsrichtlinie", "fullName": "Richtlinie (EU) 2019/790", "type": "eu_directive"}, + {"code": "PLD", "name": "Produkthaftungsrichtlinie", "fullName": "Richtlinie 85/374/EWG", "type": "eu_directive"}, + {"code": "GPSR", "name": "General Product Safety", "fullName": "Verordnung (EU) 2023/988", "type": "eu_regulation"}, + {"code": "BSI-TR-03161-1", "name": "BSI-TR Teil 1", "fullName": "BSI TR-03161 Teil 1 - Mobile Anwendungen", "type": "bsi_standard"}, + {"code": "BSI-TR-03161-2", "name": "BSI-TR Teil 2", "fullName": "BSI TR-03161 Teil 2 - Web-Anwendungen", "type": "bsi_standard"}, + {"code": "BSI-TR-03161-3", "name": "BSI-TR Teil 3", "fullName": "BSI TR-03161 Teil 3 - Hintergrundsysteme", "type": "bsi_standard"}, +] + +# Ingestion state (in-memory for now) +ingestion_state = { + "running": False, + "completed": False, + "current_regulation": None, + "processed": 0, + "total": len(REGULATIONS), + "error": None, +} + + +class SearchRequest(BaseModel): + query: str + regulations: Optional[List[str]] = None + top_k: int = 5 + + +class IngestRequest(BaseModel): + force: bool = False + regulations: Optional[List[str]] = None + + +class AddLinkRequest(BaseModel): + url: str + title: str + code: str + document_type: str = "custom" + + +# Store for custom documents (in-memory for now) +custom_documents: List[Dict[str, Any]] = [] + + +router = APIRouter(prefix="/api/v1/admin/legal-corpus", tags=["legal-corpus"]) + + +@router.get("/status") +async def get_legal_corpus_status(): + """Get status of the legal corpus collection including chunk counts per regulation.""" + async with httpx.AsyncClient(timeout=30.0) as client: + try: + collection_res = await client.get(f"{QDRANT_URL}/collections/{COLLECTION_NAME}") + if collection_res.status_code != 200: + return { + "collection": COLLECTION_NAME, + "totalPoints": 0, + "vectorSize": 1024, + "status": "not_found", + "regulations": {}, + } + + collection_data = collection_res.json() + result = collection_data.get("result", {}) + + regulation_counts = {} + for reg in REGULATIONS: + count_res = await client.post( + f"{QDRANT_URL}/collections/{COLLECTION_NAME}/points/count", + json={ + "filter": { + "must": [{"key": "regulation_code", "match": {"value": reg["code"]}}] + } + }, + ) + if count_res.status_code == 200: + count_data = count_res.json() + regulation_counts[reg["code"]] = count_data.get("result", {}).get("count", 0) + else: + regulation_counts[reg["code"]] = 0 + + return { + "collection": COLLECTION_NAME, + "totalPoints": result.get("points_count", 0), + "vectorSize": result.get("config", {}).get("params", {}).get("vectors", {}).get("size", 1024), + "status": result.get("status", "unknown"), + "regulations": regulation_counts, + } + + except httpx.RequestError as e: + logger.error(f"Failed to get Qdrant status: {e}") + raise HTTPException(status_code=503, detail=f"Qdrant not available: {str(e)}") + + +@router.get("/search") +async def search_legal_corpus( + query: str = Query(..., description="Search query"), + top_k: int = Query(5, ge=1, le=20, description="Number of results"), + regulations: Optional[str] = Query(None, description="Comma-separated regulation codes to filter"), +): + """Semantic search in legal corpus using BGE-M3 embeddings.""" + async with httpx.AsyncClient(timeout=60.0) as client: + try: + embed_res = await client.post( + f"{EMBEDDING_SERVICE_URL}/embed", + json={"texts": [query]}, + ) + if embed_res.status_code != 200: + raise HTTPException(status_code=500, detail="Embedding service error") + + embed_data = embed_res.json() + query_vector = embed_data["embeddings"][0] + + search_request = { + "vector": query_vector, + "limit": top_k, + "with_payload": True, + } + + if regulations: + reg_codes = [r.strip() for r in regulations.split(",")] + search_request["filter"] = { + "should": [ + {"key": "regulation_code", "match": {"value": code}} + for code in reg_codes + ] + } + + search_res = await client.post( + f"{QDRANT_URL}/collections/{COLLECTION_NAME}/points/search", + json=search_request, + ) + + if search_res.status_code != 200: + raise HTTPException(status_code=500, detail="Search failed") + + search_data = search_res.json() + results = [] + for point in search_data.get("result", []): + payload = point.get("payload", {}) + results.append({ + "text": payload.get("text", ""), + "regulation_code": payload.get("regulation_code", ""), + "regulation_name": payload.get("regulation_name", ""), + "article": payload.get("article"), + "paragraph": payload.get("paragraph"), + "source_url": payload.get("source_url", ""), + "score": point.get("score", 0), + }) + + return {"results": results, "query": query, "count": len(results)} + + except httpx.RequestError as e: + logger.error(f"Search failed: {e}") + raise HTTPException(status_code=503, detail=f"Service not available: {str(e)}") + + +@router.post("/ingest") +async def trigger_ingestion(request: IngestRequest, background_tasks: BackgroundTasks): + """Trigger legal corpus ingestion in background.""" + global ingestion_state + + if ingestion_state["running"]: + raise HTTPException(status_code=409, detail="Ingestion already running") + + ingestion_state = { + "running": True, + "completed": False, + "current_regulation": None, + "processed": 0, + "total": len(REGULATIONS), + "error": None, + } + + background_tasks.add_task(run_ingestion, request.force, request.regulations, ingestion_state, REGULATIONS) + + return { + "status": "started", + "job_id": "manual-trigger", + "message": f"Ingestion started for {len(REGULATIONS)} regulations", + } + + +@router.get("/ingestion-status") +async def get_ingestion_status(): + """Get current ingestion status.""" + return ingestion_state + + +@router.get("/regulations") +async def get_regulations(): + """Get list of all supported regulations.""" + return {"regulations": REGULATIONS} + + +@router.get("/custom-documents") +async def get_custom_documents(): + """Get list of custom documents added by user.""" + return {"documents": custom_documents} + + +@router.post("/upload") +async def upload_document( + background_tasks: BackgroundTasks, + file: UploadFile = File(...), + title: str = Form(...), + code: str = Form(...), + document_type: str = Form("custom"), +): + """Upload a document (PDF) for ingestion into the legal corpus.""" + global custom_documents + + if not file.filename.endswith(('.pdf', '.PDF')): + raise HTTPException(status_code=400, detail="Only PDF files are supported") + + upload_dir = "/tmp/legal_corpus_uploads" + os.makedirs(upload_dir, exist_ok=True) + + doc_id = str(uuid.uuid4())[:8] + safe_filename = f"{doc_id}_{file.filename}" + file_path = os.path.join(upload_dir, safe_filename) + + try: + with open(file_path, "wb") as buffer: + shutil.copyfileobj(file.file, buffer) + except Exception as e: + logger.error(f"Failed to save uploaded file: {e}") + raise HTTPException(status_code=500, detail=f"Failed to save file: {str(e)}") + + doc_record = { + "id": doc_id, + "code": code, + "title": title, + "filename": file.filename, + "file_path": file_path, + "document_type": document_type, + "uploaded_at": datetime.now().isoformat(), + "status": "uploaded", + "chunk_count": 0, + } + + custom_documents.append(doc_record) + background_tasks.add_task(ingest_uploaded_document, doc_record) + + return { + "status": "uploaded", + "document_id": doc_id, + "message": f"Document '{title}' uploaded and queued for ingestion", + "document": doc_record, + } + + + +@router.post("/add-link") +async def add_link(request: AddLinkRequest, background_tasks: BackgroundTasks): + """Add a URL/link for ingestion into the legal corpus.""" + global custom_documents + + doc_id = str(uuid.uuid4())[:8] + doc_record = { + "id": doc_id, + "code": request.code, + "title": request.title, + "url": request.url, + "document_type": request.document_type, + "uploaded_at": datetime.now().isoformat(), + "status": "queued", + "chunk_count": 0, + } + + custom_documents.append(doc_record) + background_tasks.add_task(ingest_link_document, doc_record) + + return { + "status": "queued", + "document_id": doc_id, + "message": f"Link '{request.title}' queued for ingestion", + "document": doc_record, + } + + + +@router.delete("/custom-documents/{doc_id}") +async def delete_custom_document(doc_id: str): + """Delete a custom document from the list.""" + global custom_documents + + doc = next((d for d in custom_documents if d["id"] == doc_id), None) + if not doc: + raise HTTPException(status_code=404, detail="Document not found") + + custom_documents = [d for d in custom_documents if d["id"] != doc_id] + + return {"status": "deleted", "document_id": doc_id} + + +@router.get("/traceability") +async def get_traceability( + chunk_id: str = Query(..., description="Chunk ID or identifier"), + regulation: str = Query(..., description="Regulation code"), +): + """Get traceability information for a specific chunk.""" + async with httpx.AsyncClient(timeout=30.0) as client: + try: + return { + "chunk_id": chunk_id, + "regulation": regulation, + "requirements": [], + "controls": [], + "message": "Traceability-Daten werden verfuegbar sein, sobald die Requirements-Extraktion und Control-Ableitung implementiert sind." + } + + except Exception as e: + logger.error(f"Failed to get traceability: {e}") + raise HTTPException(status_code=500, detail=f"Traceability lookup failed: {str(e)}") diff --git a/klausur-service/backend/mail/ai_category.py b/klausur-service/backend/mail/ai_category.py new file mode 100644 index 0000000..857caa6 --- /dev/null +++ b/klausur-service/backend/mail/ai_category.py @@ -0,0 +1,269 @@ +""" +AI Email - Category Classification and Response Suggestions + +Rule-based and LLM-based email category classification, +plus response suggestion generation. + +Extracted from ai_service.py to keep files under 500 LOC. +""" + +import os +import logging +from typing import Optional, List, Tuple + +import httpx + +from .models import ( + EmailCategory, + SenderType, + ResponseSuggestion, +) + +logger = logging.getLogger(__name__) + +# LLM Gateway configuration +LLM_GATEWAY_URL = os.getenv("LLM_GATEWAY_URL", "http://localhost:8090") + + +async def classify_category( + http_client: httpx.AsyncClient, + subject: str, + body_preview: str, + sender_type: SenderType, +) -> Tuple[EmailCategory, float]: + """ + Classify email into a category. + + Rule-based classification first, falls back to LLM. + """ + category, confidence = _classify_category_rules(subject, body_preview, sender_type) + + if confidence > 0.7: + return category, confidence + + return await _classify_category_llm(http_client, subject, body_preview) + + +def _classify_category_rules( + subject: str, + body_preview: str, + sender_type: SenderType, +) -> Tuple[EmailCategory, float]: + """Rule-based category classification.""" + text = f"{subject} {body_preview}".lower() + + category_keywords = { + EmailCategory.DIENSTLICH: [ + "dienstlich", "dienstanweisung", "erlass", "verordnung", + "bescheid", "verfuegung", "ministerium", "behoerde" + ], + EmailCategory.PERSONAL: [ + "personalrat", "stellenausschreibung", "versetzung", + "beurteilung", "dienstzeugnis", "krankmeldung", "elternzeit" + ], + EmailCategory.FINANZEN: [ + "budget", "haushalt", "etat", "abrechnung", "rechnung", + "erstattung", "zuschuss", "foerdermittel" + ], + EmailCategory.ELTERN: [ + "elternbrief", "elternabend", "schulkonferenz", + "elternvertreter", "elternbeirat" + ], + EmailCategory.SCHUELER: [ + "schueler", "schuelerin", "zeugnis", "klasse", "unterricht", + "pruefung", "klassenfahrt", "schulpflicht" + ], + EmailCategory.FORTBILDUNG: [ + "fortbildung", "seminar", "workshop", "schulung", + "weiterbildung", "nlq", "didaktik" + ], + EmailCategory.VERANSTALTUNG: [ + "einladung", "veranstaltung", "termin", "konferenz", + "sitzung", "tagung", "feier" + ], + EmailCategory.SICHERHEIT: [ + "sicherheit", "notfall", "brandschutz", "evakuierung", + "hygiene", "corona", "infektionsschutz" + ], + EmailCategory.TECHNIK: [ + "it", "software", "computer", "netzwerk", "login", + "passwort", "digitalisierung", "iserv" + ], + EmailCategory.NEWSLETTER: [ + "newsletter", "rundschreiben", "info-mail", "mitteilung" + ], + EmailCategory.WERBUNG: [ + "angebot", "rabatt", "aktion", "werbung", "abonnement" + ], + } + + best_category = EmailCategory.SONSTIGES + best_score = 0.0 + + for category, keywords in category_keywords.items(): + score = sum(1 for kw in keywords if kw in text) + if score > best_score: + best_score = score + best_category = category + + if sender_type in [SenderType.KULTUSMINISTERIUM, SenderType.LANDESSCHULBEHOERDE, SenderType.RLSB]: + if best_category == EmailCategory.SONSTIGES: + best_category = EmailCategory.DIENSTLICH + best_score = 2 + + confidence = min(0.9, 0.4 + (best_score * 0.15)) + + return best_category, confidence + + +async def _classify_category_llm( + client: httpx.AsyncClient, + subject: str, + body_preview: str, +) -> Tuple[EmailCategory, float]: + """LLM-based category classification.""" + try: + categories = ", ".join([c.value for c in EmailCategory]) + + prompt = f"""Klassifiziere diese E-Mail in EINE Kategorie: + +Betreff: {subject} +Inhalt: {body_preview[:500]} + +Kategorien: {categories} + +Antworte NUR mit dem Kategorienamen und einer Konfidenz (0.0-1.0): +Format: kategorie|konfidenz +""" + + response = await client.post( + f"{LLM_GATEWAY_URL}/api/v1/inference", + json={ + "prompt": prompt, + "playbook": "mail_analysis", + "max_tokens": 50, + }, + ) + + if response.status_code == 200: + data = response.json() + result = data.get("response", "sonstiges|0.5") + parts = result.strip().split("|") + + if len(parts) >= 2: + category_str = parts[0].strip().lower() + confidence = float(parts[1].strip()) + + try: + category = EmailCategory(category_str) + return category, min(max(confidence, 0.0), 1.0) + except ValueError: + pass + + except Exception as e: + logger.warning(f"LLM category classification failed: {e}") + + return EmailCategory.SONSTIGES, 0.5 + + +async def suggest_response( + http_client: httpx.AsyncClient, + subject: str, + body_text: str, + sender_type: SenderType, + category: EmailCategory, +) -> List[ResponseSuggestion]: + """Generate response suggestions for an email.""" + suggestions = [] + + if sender_type in [SenderType.KULTUSMINISTERIUM, SenderType.LANDESSCHULBEHOERDE, SenderType.RLSB]: + suggestions.append(ResponseSuggestion( + template_type="acknowledgment", + subject=f"Re: {subject}", + body="""Sehr geehrte Damen und Herren, + +vielen Dank fuer Ihre Nachricht. + +Ich bestaetige den Eingang und werde die Angelegenheit fristgerecht bearbeiten. + +Mit freundlichen Gruessen""", + confidence=0.8, + )) + + if category == EmailCategory.ELTERN: + suggestions.append(ResponseSuggestion( + template_type="parent_response", + subject=f"Re: {subject}", + body="""Liebe Eltern, + +vielen Dank fuer Ihre Nachricht. + +[Ihre Antwort hier] + +Mit freundlichen Gruessen""", + confidence=0.7, + )) + + try: + llm_suggestion = await _generate_response_llm(http_client, subject, body_text[:500], sender_type) + if llm_suggestion: + suggestions.append(llm_suggestion) + except Exception as e: + logger.warning(f"LLM response generation failed: {e}") + + return suggestions + + +async def _generate_response_llm( + client: httpx.AsyncClient, + subject: str, + body_preview: str, + sender_type: SenderType, +) -> Optional[ResponseSuggestion]: + """Generate a response suggestion using LLM.""" + try: + sender_desc = { + SenderType.KULTUSMINISTERIUM: "dem Kultusministerium", + SenderType.LANDESSCHULBEHOERDE: "der Landesschulbehoerde", + SenderType.RLSB: "dem RLSB", + SenderType.ELTERNVERTRETER: "einem Elternvertreter", + }.get(sender_type, "einem Absender") + + prompt = f"""Du bist eine Schulleiterin in Niedersachsen. Formuliere eine professionelle, kurze Antwort auf diese E-Mail von {sender_desc}: + +Betreff: {subject} +Inhalt: {body_preview} + +Die Antwort sollte: +- Hoeflich und formell sein +- Den Eingang bestaetigen +- Eine konkrete naechste Aktion nennen oder um Klaerung bitten + +Antworte NUR mit dem Antworttext (ohne Betreffzeile, ohne "Betreff:"). +""" + + response = await client.post( + f"{LLM_GATEWAY_URL}/api/v1/inference", + json={ + "prompt": prompt, + "playbook": "mail_analysis", + "max_tokens": 300, + }, + ) + + if response.status_code == 200: + data = response.json() + body = data.get("response", "").strip() + + if body: + return ResponseSuggestion( + template_type="ai_generated", + subject=f"Re: {subject}", + body=body, + confidence=0.6, + ) + + except Exception as e: + logger.warning(f"LLM response generation failed: {e}") + + return None diff --git a/klausur-service/backend/mail/ai_deadline.py b/klausur-service/backend/mail/ai_deadline.py new file mode 100644 index 0000000..6b9ad3e --- /dev/null +++ b/klausur-service/backend/mail/ai_deadline.py @@ -0,0 +1,184 @@ +""" +AI Email - Deadline Extraction + +Regex-based and LLM-based deadline extraction from email content. + +Extracted from ai_service.py to keep files under 500 LOC. +""" + +import os +import re +import logging +from typing import List +from datetime import datetime, timedelta + +import httpx + +from .models import DeadlineExtraction + +logger = logging.getLogger(__name__) + +# LLM Gateway configuration +LLM_GATEWAY_URL = os.getenv("LLM_GATEWAY_URL", "http://localhost:8090") + + +async def extract_deadlines( + http_client: httpx.AsyncClient, + subject: str, + body_text: str, +) -> List[DeadlineExtraction]: + """ + Extract deadlines from email content. + + Uses regex patterns first, then LLM for complex cases. + """ + deadlines = [] + + full_text = f"{subject}\n{body_text}" if body_text else subject + + # Try regex extraction first + regex_deadlines = _extract_deadlines_regex(full_text) + deadlines.extend(regex_deadlines) + + # If no regex matches, try LLM + if not deadlines and body_text: + llm_deadlines = await _extract_deadlines_llm(http_client, subject, body_text[:1000]) + deadlines.extend(llm_deadlines) + + return deadlines + + +def _extract_deadlines_regex(text: str) -> List[DeadlineExtraction]: + """Extract deadlines using regex patterns.""" + deadlines = [] + now = datetime.now() + + # German date patterns + patterns = [ + # "bis zum 15.01.2025" + (r"bis\s+(?:zum\s+)?(\d{1,2})\.(\d{1,2})\.(\d{2,4})", True), + # "spaetestens am 15.01.2025" + (r"sp\u00e4testens\s+(?:am\s+)?(\d{1,2})\.(\d{1,2})\.(\d{2,4})", True), + # "Abgabetermin: 15.01.2025" + (r"(?:Abgabe|Termin|Frist)[:\s]+(\d{1,2})\.(\d{1,2})\.(\d{2,4})", True), + # "innerhalb von 14 Tagen" + (r"innerhalb\s+von\s+(\d+)\s+(?:Tagen|Wochen)", False), + # "bis Ende Januar" + (r"bis\s+(?:Ende\s+)?(Januar|Februar|M\u00e4rz|April|Mai|Juni|Juli|August|September|Oktober|November|Dezember)", False), + ] + + for pattern, is_specific_date in patterns: + matches = re.finditer(pattern, text, re.IGNORECASE) + + for match in matches: + try: + if is_specific_date: + day = int(match.group(1)) + month = int(match.group(2)) + year = int(match.group(3)) + + if year < 100: + year += 2000 + + deadline_date = datetime(year, month, day) + + if deadline_date < now: + continue + + start = max(0, match.start() - 50) + end = min(len(text), match.end() + 50) + context = text[start:end].strip() + + deadlines.append(DeadlineExtraction( + deadline_date=deadline_date, + description=f"Frist: {match.group(0)}", + confidence=0.85, + source_text=context, + is_firm=True, + )) + + else: + if "Tagen" in pattern or "Wochen" in pattern: + days = int(match.group(1)) + if "Wochen" in match.group(0).lower(): + days *= 7 + deadline_date = now + timedelta(days=days) + + deadlines.append(DeadlineExtraction( + deadline_date=deadline_date, + description=f"Relative Frist: {match.group(0)}", + confidence=0.7, + source_text=match.group(0), + is_firm=False, + )) + + except (ValueError, IndexError) as e: + logger.debug(f"Failed to parse date: {e}") + continue + + return deadlines + + +async def _extract_deadlines_llm( + client: httpx.AsyncClient, + subject: str, + body_preview: str, +) -> List[DeadlineExtraction]: + """Extract deadlines using LLM.""" + try: + prompt = f"""Analysiere diese E-Mail und extrahiere alle genannten Fristen und Termine: + +Betreff: {subject} +Inhalt: {body_preview} + +Liste alle Fristen im folgenden Format auf (eine pro Zeile): +DATUM|BESCHREIBUNG|VERBINDLICH +Beispiel: 2025-01-15|Abgabe der Berichte|ja + +Wenn keine Fristen gefunden werden, antworte mit: KEINE_FRISTEN + +Antworte NUR im angegebenen Format. +""" + + response = await client.post( + f"{LLM_GATEWAY_URL}/api/v1/inference", + json={ + "prompt": prompt, + "playbook": "mail_analysis", + "max_tokens": 200, + }, + ) + + if response.status_code == 200: + data = response.json() + result_text = data.get("response", "") + + if "KEINE_FRISTEN" in result_text: + return [] + + deadlines = [] + for line in result_text.strip().split("\n"): + parts = line.split("|") + if len(parts) >= 2: + try: + date_str = parts[0].strip() + deadline_date = datetime.fromisoformat(date_str) + description = parts[1].strip() + is_firm = parts[2].strip().lower() == "ja" if len(parts) > 2 else True + + deadlines.append(DeadlineExtraction( + deadline_date=deadline_date, + description=description, + confidence=0.7, + source_text=line, + is_firm=is_firm, + )) + except (ValueError, IndexError): + continue + + return deadlines + + except Exception as e: + logger.warning(f"LLM deadline extraction failed: {e}") + + return [] diff --git a/klausur-service/backend/mail/ai_sender.py b/klausur-service/backend/mail/ai_sender.py new file mode 100644 index 0000000..3dc07b1 --- /dev/null +++ b/klausur-service/backend/mail/ai_sender.py @@ -0,0 +1,134 @@ +""" +AI Email - Sender Classification + +Domain-based and LLM-based sender classification for emails. + +Extracted from ai_service.py to keep files under 500 LOC. +""" + +import os +import logging +from typing import Optional + +import httpx + +from .models import ( + SenderType, + SenderClassification, + classify_sender_by_domain, +) + +logger = logging.getLogger(__name__) + +# LLM Gateway configuration +LLM_GATEWAY_URL = os.getenv("LLM_GATEWAY_URL", "http://localhost:8090") + + +async def classify_sender( + http_client: httpx.AsyncClient, + sender_email: str, + sender_name: Optional[str] = None, + subject: Optional[str] = None, + body_preview: Optional[str] = None, +) -> SenderClassification: + """ + Classify the sender of an email. + + First tries domain matching, then falls back to LLM. + """ + # Try domain-based classification first (fast, high confidence) + domain_result = classify_sender_by_domain(sender_email) + if domain_result: + return domain_result + + # Fall back to LLM classification + return await _classify_sender_llm( + http_client, sender_email, sender_name, subject, body_preview + ) + + +async def _classify_sender_llm( + client: httpx.AsyncClient, + sender_email: str, + sender_name: Optional[str], + subject: Optional[str], + body_preview: Optional[str], +) -> SenderClassification: + """Classify sender using LLM.""" + try: + prompt = f"""Analysiere den Absender dieser E-Mail und klassifiziere ihn: + +Absender E-Mail: {sender_email} +Absender Name: {sender_name or "Nicht angegeben"} +Betreff: {subject or "Nicht angegeben"} +Vorschau: {body_preview[:200] if body_preview else "Nicht verfuegbar"} + +Klassifiziere den Absender in EINE der folgenden Kategorien: +- kultusministerium: Kultusministerium/Bildungsministerium +- landesschulbehoerde: Landesschulbehoerde +- rlsb: Regionales Landesamt fuer Schule und Bildung +- schulamt: Schulamt +- nibis: Niedersaechsischer Bildungsserver +- schultraeger: Schultraeger/Kommune +- elternvertreter: Elternvertreter/Elternrat +- gewerkschaft: Gewerkschaft (GEW, VBE, etc.) +- fortbildungsinstitut: Fortbildungsinstitut (NLQ, etc.) +- privatperson: Privatperson +- unternehmen: Unternehmen/Firma +- unbekannt: Nicht einzuordnen + +Antworte NUR mit dem Kategorienamen (z.B. "kultusministerium") und einer Konfidenz von 0.0 bis 1.0. +Format: kategorie|konfidenz|kurze_begruendung +""" + + response = await client.post( + f"{LLM_GATEWAY_URL}/api/v1/inference", + json={ + "prompt": prompt, + "playbook": "mail_analysis", + "max_tokens": 100, + }, + ) + + if response.status_code == 200: + data = response.json() + result_text = data.get("response", "unbekannt|0.5|") + + parts = result_text.strip().split("|") + if len(parts) >= 2: + sender_type_str = parts[0].strip().lower() + confidence = float(parts[1].strip()) + + type_mapping = { + "kultusministerium": SenderType.KULTUSMINISTERIUM, + "landesschulbehoerde": SenderType.LANDESSCHULBEHOERDE, + "rlsb": SenderType.RLSB, + "schulamt": SenderType.SCHULAMT, + "nibis": SenderType.NIBIS, + "schultraeger": SenderType.SCHULTRAEGER, + "elternvertreter": SenderType.ELTERNVERTRETER, + "gewerkschaft": SenderType.GEWERKSCHAFT, + "fortbildungsinstitut": SenderType.FORTBILDUNGSINSTITUT, + "privatperson": SenderType.PRIVATPERSON, + "unternehmen": SenderType.UNTERNEHMEN, + } + + sender_type = type_mapping.get(sender_type_str, SenderType.UNBEKANNT) + + return SenderClassification( + sender_type=sender_type, + confidence=min(max(confidence, 0.0), 1.0), + domain_matched=False, + ai_classified=True, + ) + + except Exception as e: + logger.warning(f"LLM sender classification failed: {e}") + + # Default fallback + return SenderClassification( + sender_type=SenderType.UNBEKANNT, + confidence=0.3, + domain_matched=False, + ai_classified=False, + ) diff --git a/klausur-service/backend/mail/ai_service.py b/klausur-service/backend/mail/ai_service.py index 0323084..1ec79d3 100644 --- a/klausur-service/backend/mail/ai_service.py +++ b/klausur-service/backend/mail/ai_service.py @@ -1,18 +1,19 @@ """ -AI Email Analysis Service +AI Email Analysis Service — Barrel Re-export -KI-powered email analysis with: -- Sender classification (authority recognition) -- Deadline extraction -- Category classification -- Response suggestions +Split into: +- mail/ai_sender.py — Sender classification (domain + LLM) +- mail/ai_deadline.py — Deadline extraction (regex + LLM) +- mail/ai_category.py — Category classification + response suggestions + +The AIEmailService class and get_ai_email_service() are defined here +to maintain the original public API. """ -import os -import re import logging -from typing import Optional, List, Dict, Any, Tuple -from datetime import datetime, timedelta +from typing import Optional, List, Tuple +from datetime import datetime + import httpx from .models import ( @@ -23,17 +24,15 @@ from .models import ( DeadlineExtraction, EmailAnalysisResult, ResponseSuggestion, - KNOWN_AUTHORITIES_NI, - classify_sender_by_domain, get_priority_from_sender_type, ) from .mail_db import update_email_ai_analysis +from .ai_sender import classify_sender, LLM_GATEWAY_URL +from .ai_deadline import extract_deadlines +from .ai_category import classify_category, suggest_response logger = logging.getLogger(__name__) -# LLM Gateway configuration -LLM_GATEWAY_URL = os.getenv("LLM_GATEWAY_URL", "http://localhost:8090") - class AIEmailService: """ @@ -56,10 +55,6 @@ class AIEmailService: self._http_client = httpx.AsyncClient(timeout=30.0) return self._http_client - # ========================================================================= - # Sender Classification - # ========================================================================= - async def classify_sender( self, sender_email: str, @@ -67,300 +62,20 @@ class AIEmailService: subject: Optional[str] = None, body_preview: Optional[str] = None, ) -> SenderClassification: - """ - Classify the sender of an email. - - First tries domain matching, then falls back to LLM. - - Args: - sender_email: Sender's email address - sender_name: Sender's display name - subject: Email subject - body_preview: First 200 chars of body - - Returns: - SenderClassification with type and confidence - """ - # Try domain-based classification first (fast, high confidence) - domain_result = classify_sender_by_domain(sender_email) - if domain_result: - return domain_result - - # Fall back to LLM classification - return await self._classify_sender_llm( - sender_email, sender_name, subject, body_preview + """Classify the sender of an email.""" + client = await self.get_http_client() + return await classify_sender( + client, sender_email, sender_name, subject, body_preview ) - async def _classify_sender_llm( - self, - sender_email: str, - sender_name: Optional[str], - subject: Optional[str], - body_preview: Optional[str], - ) -> SenderClassification: - """Classify sender using LLM.""" - try: - client = await self.get_http_client() - - prompt = f"""Analysiere den Absender dieser E-Mail und klassifiziere ihn: - -Absender E-Mail: {sender_email} -Absender Name: {sender_name or "Nicht angegeben"} -Betreff: {subject or "Nicht angegeben"} -Vorschau: {body_preview[:200] if body_preview else "Nicht verfügbar"} - -Klassifiziere den Absender in EINE der folgenden Kategorien: -- kultusministerium: Kultusministerium/Bildungsministerium -- landesschulbehoerde: Landesschulbehörde -- rlsb: Regionales Landesamt für Schule und Bildung -- schulamt: Schulamt -- nibis: Niedersächsischer Bildungsserver -- schultraeger: Schulträger/Kommune -- elternvertreter: Elternvertreter/Elternrat -- gewerkschaft: Gewerkschaft (GEW, VBE, etc.) -- fortbildungsinstitut: Fortbildungsinstitut (NLQ, etc.) -- privatperson: Privatperson -- unternehmen: Unternehmen/Firma -- unbekannt: Nicht einzuordnen - -Antworte NUR mit dem Kategorienamen (z.B. "kultusministerium") und einer Konfidenz von 0.0 bis 1.0. -Format: kategorie|konfidenz|kurze_begründung -""" - - response = await client.post( - f"{LLM_GATEWAY_URL}/api/v1/inference", - json={ - "prompt": prompt, - "playbook": "mail_analysis", - "max_tokens": 100, - }, - ) - - if response.status_code == 200: - data = response.json() - result_text = data.get("response", "unbekannt|0.5|") - - # Parse response - parts = result_text.strip().split("|") - if len(parts) >= 2: - sender_type_str = parts[0].strip().lower() - confidence = float(parts[1].strip()) - - # Map to enum - type_mapping = { - "kultusministerium": SenderType.KULTUSMINISTERIUM, - "landesschulbehoerde": SenderType.LANDESSCHULBEHOERDE, - "rlsb": SenderType.RLSB, - "schulamt": SenderType.SCHULAMT, - "nibis": SenderType.NIBIS, - "schultraeger": SenderType.SCHULTRAEGER, - "elternvertreter": SenderType.ELTERNVERTRETER, - "gewerkschaft": SenderType.GEWERKSCHAFT, - "fortbildungsinstitut": SenderType.FORTBILDUNGSINSTITUT, - "privatperson": SenderType.PRIVATPERSON, - "unternehmen": SenderType.UNTERNEHMEN, - } - - sender_type = type_mapping.get(sender_type_str, SenderType.UNBEKANNT) - - return SenderClassification( - sender_type=sender_type, - confidence=min(max(confidence, 0.0), 1.0), - domain_matched=False, - ai_classified=True, - ) - - except Exception as e: - logger.warning(f"LLM sender classification failed: {e}") - - # Default fallback - return SenderClassification( - sender_type=SenderType.UNBEKANNT, - confidence=0.3, - domain_matched=False, - ai_classified=False, - ) - - # ========================================================================= - # Deadline Extraction - # ========================================================================= - async def extract_deadlines( self, subject: str, body_text: str, ) -> List[DeadlineExtraction]: - """ - Extract deadlines from email content. - - Uses regex patterns first, then LLM for complex cases. - - Args: - subject: Email subject - body_text: Email body text - - Returns: - List of extracted deadlines - """ - deadlines = [] - - # Combine subject and body - full_text = f"{subject}\n{body_text}" if body_text else subject - - # Try regex extraction first - regex_deadlines = self._extract_deadlines_regex(full_text) - deadlines.extend(regex_deadlines) - - # If no regex matches, try LLM - if not deadlines and body_text: - llm_deadlines = await self._extract_deadlines_llm(subject, body_text[:1000]) - deadlines.extend(llm_deadlines) - - return deadlines - - def _extract_deadlines_regex(self, text: str) -> List[DeadlineExtraction]: - """Extract deadlines using regex patterns.""" - deadlines = [] - now = datetime.now() - - # German date patterns - patterns = [ - # "bis zum 15.01.2025" - (r"bis\s+(?:zum\s+)?(\d{1,2})\.(\d{1,2})\.(\d{2,4})", True), - # "spätestens am 15.01.2025" - (r"spätestens\s+(?:am\s+)?(\d{1,2})\.(\d{1,2})\.(\d{2,4})", True), - # "Abgabetermin: 15.01.2025" - (r"(?:Abgabe|Termin|Frist)[:\s]+(\d{1,2})\.(\d{1,2})\.(\d{2,4})", True), - # "innerhalb von 14 Tagen" - (r"innerhalb\s+von\s+(\d+)\s+(?:Tagen|Wochen)", False), - # "bis Ende Januar" - (r"bis\s+(?:Ende\s+)?(Januar|Februar|März|April|Mai|Juni|Juli|August|September|Oktober|November|Dezember)", False), - ] - - for pattern, is_specific_date in patterns: - matches = re.finditer(pattern, text, re.IGNORECASE) - - for match in matches: - try: - if is_specific_date: - day = int(match.group(1)) - month = int(match.group(2)) - year = int(match.group(3)) - - # Handle 2-digit years - if year < 100: - year += 2000 - - deadline_date = datetime(year, month, day) - - # Skip past dates - if deadline_date < now: - continue - - # Get surrounding context - start = max(0, match.start() - 50) - end = min(len(text), match.end() + 50) - context = text[start:end].strip() - - deadlines.append(DeadlineExtraction( - deadline_date=deadline_date, - description=f"Frist: {match.group(0)}", - confidence=0.85, - source_text=context, - is_firm=True, - )) - - else: - # Relative dates (innerhalb von X Tagen) - if "Tagen" in pattern or "Wochen" in pattern: - days = int(match.group(1)) - if "Wochen" in match.group(0).lower(): - days *= 7 - deadline_date = now + timedelta(days=days) - - deadlines.append(DeadlineExtraction( - deadline_date=deadline_date, - description=f"Relative Frist: {match.group(0)}", - confidence=0.7, - source_text=match.group(0), - is_firm=False, - )) - - except (ValueError, IndexError) as e: - logger.debug(f"Failed to parse date: {e}") - continue - - return deadlines - - async def _extract_deadlines_llm( - self, - subject: str, - body_preview: str, - ) -> List[DeadlineExtraction]: - """Extract deadlines using LLM.""" - try: - client = await self.get_http_client() - - prompt = f"""Analysiere diese E-Mail und extrahiere alle genannten Fristen und Termine: - -Betreff: {subject} -Inhalt: {body_preview} - -Liste alle Fristen im folgenden Format auf (eine pro Zeile): -DATUM|BESCHREIBUNG|VERBINDLICH -Beispiel: 2025-01-15|Abgabe der Berichte|ja - -Wenn keine Fristen gefunden werden, antworte mit: KEINE_FRISTEN - -Antworte NUR im angegebenen Format. -""" - - response = await client.post( - f"{LLM_GATEWAY_URL}/api/v1/inference", - json={ - "prompt": prompt, - "playbook": "mail_analysis", - "max_tokens": 200, - }, - ) - - if response.status_code == 200: - data = response.json() - result_text = data.get("response", "") - - if "KEINE_FRISTEN" in result_text: - return [] - - deadlines = [] - for line in result_text.strip().split("\n"): - parts = line.split("|") - if len(parts) >= 2: - try: - date_str = parts[0].strip() - deadline_date = datetime.fromisoformat(date_str) - description = parts[1].strip() - is_firm = parts[2].strip().lower() == "ja" if len(parts) > 2 else True - - deadlines.append(DeadlineExtraction( - deadline_date=deadline_date, - description=description, - confidence=0.7, - source_text=line, - is_firm=is_firm, - )) - except (ValueError, IndexError): - continue - - return deadlines - - except Exception as e: - logger.warning(f"LLM deadline extraction failed: {e}") - - return [] - - # ========================================================================= - # Email Category Classification - # ========================================================================= + """Extract deadlines from email content.""" + client = await self.get_http_client() + return await extract_deadlines(client, subject, body_text) async def classify_category( self, @@ -368,155 +83,9 @@ Antworte NUR im angegebenen Format. body_preview: str, sender_type: SenderType, ) -> Tuple[EmailCategory, float]: - """ - Classify email into a category. - - Args: - subject: Email subject - body_preview: First 200 chars of body - sender_type: Already classified sender type - - Returns: - Tuple of (category, confidence) - """ - # Rule-based classification first - category, confidence = self._classify_category_rules(subject, body_preview, sender_type) - - if confidence > 0.7: - return category, confidence - - # Fall back to LLM - return await self._classify_category_llm(subject, body_preview) - - def _classify_category_rules( - self, - subject: str, - body_preview: str, - sender_type: SenderType, - ) -> Tuple[EmailCategory, float]: - """Rule-based category classification.""" - text = f"{subject} {body_preview}".lower() - - # Keywords for each category - category_keywords = { - EmailCategory.DIENSTLICH: [ - "dienstlich", "dienstanweisung", "erlass", "verordnung", - "bescheid", "verfügung", "ministerium", "behörde" - ], - EmailCategory.PERSONAL: [ - "personalrat", "stellenausschreibung", "versetzung", - "beurteilung", "dienstzeugnis", "krankmeldung", "elternzeit" - ], - EmailCategory.FINANZEN: [ - "budget", "haushalt", "etat", "abrechnung", "rechnung", - "erstattung", "zuschuss", "fördermittel" - ], - EmailCategory.ELTERN: [ - "elternbrief", "elternabend", "schulkonferenz", - "elternvertreter", "elternbeirat" - ], - EmailCategory.SCHUELER: [ - "schüler", "schülerin", "zeugnis", "klasse", "unterricht", - "prüfung", "klassenfahrt", "schulpflicht" - ], - EmailCategory.FORTBILDUNG: [ - "fortbildung", "seminar", "workshop", "schulung", - "weiterbildung", "nlq", "didaktik" - ], - EmailCategory.VERANSTALTUNG: [ - "einladung", "veranstaltung", "termin", "konferenz", - "sitzung", "tagung", "feier" - ], - EmailCategory.SICHERHEIT: [ - "sicherheit", "notfall", "brandschutz", "evakuierung", - "hygiene", "corona", "infektionsschutz" - ], - EmailCategory.TECHNIK: [ - "it", "software", "computer", "netzwerk", "login", - "passwort", "digitalisierung", "iserv" - ], - EmailCategory.NEWSLETTER: [ - "newsletter", "rundschreiben", "info-mail", "mitteilung" - ], - EmailCategory.WERBUNG: [ - "angebot", "rabatt", "aktion", "werbung", "abonnement" - ], - } - - best_category = EmailCategory.SONSTIGES - best_score = 0.0 - - for category, keywords in category_keywords.items(): - score = sum(1 for kw in keywords if kw in text) - if score > best_score: - best_score = score - best_category = category - - # Adjust based on sender type - if sender_type in [SenderType.KULTUSMINISTERIUM, SenderType.LANDESSCHULBEHOERDE, SenderType.RLSB]: - if best_category == EmailCategory.SONSTIGES: - best_category = EmailCategory.DIENSTLICH - best_score = 2 - - # Convert score to confidence - confidence = min(0.9, 0.4 + (best_score * 0.15)) - - return best_category, confidence - - async def _classify_category_llm( - self, - subject: str, - body_preview: str, - ) -> Tuple[EmailCategory, float]: - """LLM-based category classification.""" - try: - client = await self.get_http_client() - - categories = ", ".join([c.value for c in EmailCategory]) - - prompt = f"""Klassifiziere diese E-Mail in EINE Kategorie: - -Betreff: {subject} -Inhalt: {body_preview[:500]} - -Kategorien: {categories} - -Antworte NUR mit dem Kategorienamen und einer Konfidenz (0.0-1.0): -Format: kategorie|konfidenz -""" - - response = await client.post( - f"{LLM_GATEWAY_URL}/api/v1/inference", - json={ - "prompt": prompt, - "playbook": "mail_analysis", - "max_tokens": 50, - }, - ) - - if response.status_code == 200: - data = response.json() - result = data.get("response", "sonstiges|0.5") - parts = result.strip().split("|") - - if len(parts) >= 2: - category_str = parts[0].strip().lower() - confidence = float(parts[1].strip()) - - try: - category = EmailCategory(category_str) - return category, min(max(confidence, 0.0), 1.0) - except ValueError: - pass - - except Exception as e: - logger.warning(f"LLM category classification failed: {e}") - - return EmailCategory.SONSTIGES, 0.5 - - # ========================================================================= - # Full Analysis Pipeline - # ========================================================================= + """Classify email into a category.""" + client = await self.get_http_client() + return await classify_category(client, subject, body_preview, sender_type) async def analyze_email( self, @@ -527,20 +96,7 @@ Format: kategorie|konfidenz body_text: Optional[str], body_preview: Optional[str], ) -> EmailAnalysisResult: - """ - Run full analysis pipeline on an email. - - Args: - email_id: Database ID of the email - sender_email: Sender's email address - sender_name: Sender's display name - subject: Email subject - body_text: Full body text - body_preview: Preview text - - Returns: - Complete analysis result - """ + """Run full analysis pipeline on an email.""" # 1. Classify sender sender_classification = await self.classify_sender( sender_email, sender_name, subject, body_preview @@ -569,8 +125,8 @@ Format: kategorie|konfidenz elif days_until <= 7: suggested_priority = max(suggested_priority, TaskPriority.MEDIUM) - # 5. Generate summary (optional, can be expensive) - summary = None # Could add LLM summary generation here + # 5. Summary (optional) + summary = None # 6. Determine if task should be auto-created auto_create_task = ( @@ -612,10 +168,6 @@ Format: kategorie|konfidenz auto_create_task=auto_create_task, ) - # ========================================================================= - # Response Suggestions - # ========================================================================= - async def suggest_response( self, subject: str, @@ -623,114 +175,11 @@ Format: kategorie|konfidenz sender_type: SenderType, category: EmailCategory, ) -> List[ResponseSuggestion]: - """ - Generate response suggestions for an email. - - Args: - subject: Original email subject - body_text: Original email body - sender_type: Classified sender type - category: Classified category - - Returns: - List of response suggestions - """ - suggestions = [] - - # Add standard templates based on sender type and category - if sender_type in [SenderType.KULTUSMINISTERIUM, SenderType.LANDESSCHULBEHOERDE, SenderType.RLSB]: - suggestions.append(ResponseSuggestion( - template_type="acknowledgment", - subject=f"Re: {subject}", - body="""Sehr geehrte Damen und Herren, - -vielen Dank für Ihre Nachricht. - -Ich bestätige den Eingang und werde die Angelegenheit fristgerecht bearbeiten. - -Mit freundlichen Grüßen""", - confidence=0.8, - )) - - if category == EmailCategory.ELTERN: - suggestions.append(ResponseSuggestion( - template_type="parent_response", - subject=f"Re: {subject}", - body="""Liebe Eltern, - -vielen Dank für Ihre Nachricht. - -[Ihre Antwort hier] - -Mit freundlichen Grüßen""", - confidence=0.7, - )) - - # Add LLM-generated suggestion - try: - llm_suggestion = await self._generate_response_llm(subject, body_text[:500], sender_type) - if llm_suggestion: - suggestions.append(llm_suggestion) - except Exception as e: - logger.warning(f"LLM response generation failed: {e}") - - return suggestions - - async def _generate_response_llm( - self, - subject: str, - body_preview: str, - sender_type: SenderType, - ) -> Optional[ResponseSuggestion]: - """Generate a response suggestion using LLM.""" - try: - client = await self.get_http_client() - - sender_desc = { - SenderType.KULTUSMINISTERIUM: "dem Kultusministerium", - SenderType.LANDESSCHULBEHOERDE: "der Landesschulbehörde", - SenderType.RLSB: "dem RLSB", - SenderType.ELTERNVERTRETER: "einem Elternvertreter", - }.get(sender_type, "einem Absender") - - prompt = f"""Du bist eine Schulleiterin in Niedersachsen. Formuliere eine professionelle, kurze Antwort auf diese E-Mail von {sender_desc}: - -Betreff: {subject} -Inhalt: {body_preview} - -Die Antwort sollte: -- Höflich und formell sein -- Den Eingang bestätigen -- Eine konkrete nächste Aktion nennen oder um Klärung bitten - -Antworte NUR mit dem Antworttext (ohne Betreffzeile, ohne "Betreff:"). -""" - - response = await client.post( - f"{LLM_GATEWAY_URL}/api/v1/inference", - json={ - "prompt": prompt, - "playbook": "mail_analysis", - "max_tokens": 300, - }, - ) - - if response.status_code == 200: - data = response.json() - body = data.get("response", "").strip() - - if body: - return ResponseSuggestion( - template_type="ai_generated", - subject=f"Re: {subject}", - body=body, - confidence=0.6, - ) - - except Exception as e: - logger.warning(f"LLM response generation failed: {e}") - - return None + """Generate response suggestions for an email.""" + client = await self.get_http_client() + return await suggest_response( + client, subject, body_text, sender_type, category + ) # Global instance diff --git a/klausur-service/backend/metrics_db.py b/klausur-service/backend/metrics_db.py index f3b0ff7..d5e2fa7 100644 --- a/klausur-service/backend/metrics_db.py +++ b/klausur-service/backend/metrics_db.py @@ -1,833 +1,36 @@ """ -PostgreSQL Metrics Database Service -Stores search feedback, calculates quality metrics (Precision, Recall, MRR). +PostgreSQL Metrics Database Service — Barrel Re-export + +Split into: +- metrics_db_core.py — Pool, feedback, metrics, relevance +- metrics_db_schema.py — Table initialization (DDL) +- metrics_db_zeugnis.py — Zeugnis source/document/stats operations + +All public names are re-exported here for backward compatibility. """ -import os -from typing import Optional, List, Dict -from datetime import datetime, timedelta -import asyncio - -# Database Configuration - uses test default if not configured (for CI) -DATABASE_URL = os.getenv("DATABASE_URL", "postgresql://test:test@localhost:5432/test_metrics") - -# Connection pool -_pool = None - - -async def get_pool(): - """Get or create database connection pool.""" - global _pool - if _pool is None: - try: - import asyncpg - _pool = await asyncpg.create_pool(DATABASE_URL, min_size=2, max_size=10) - except ImportError: - print("Warning: asyncpg not installed. Metrics storage disabled.") - return None - except Exception as e: - print(f"Warning: Failed to connect to PostgreSQL: {e}") - return None - return _pool - - -async def init_metrics_tables() -> bool: - """Initialize metrics tables in PostgreSQL.""" - pool = await get_pool() - if pool is None: - return False - - create_tables_sql = """ - -- RAG Search Feedback Table - CREATE TABLE IF NOT EXISTS rag_search_feedback ( - id SERIAL PRIMARY KEY, - result_id VARCHAR(255) NOT NULL, - query_text TEXT, - collection_name VARCHAR(100), - score FLOAT, - rating INTEGER CHECK (rating >= 1 AND rating <= 5), - notes TEXT, - user_id VARCHAR(100), - created_at TIMESTAMP DEFAULT NOW() - ); - - -- Index for efficient querying - CREATE INDEX IF NOT EXISTS idx_feedback_created_at ON rag_search_feedback(created_at); - CREATE INDEX IF NOT EXISTS idx_feedback_collection ON rag_search_feedback(collection_name); - CREATE INDEX IF NOT EXISTS idx_feedback_rating ON rag_search_feedback(rating); - - -- RAG Search Logs Table (for latency tracking) - CREATE TABLE IF NOT EXISTS rag_search_logs ( - id SERIAL PRIMARY KEY, - query_text TEXT NOT NULL, - collection_name VARCHAR(100), - result_count INTEGER, - latency_ms INTEGER, - top_score FLOAT, - filters JSONB, - created_at TIMESTAMP DEFAULT NOW() - ); - - CREATE INDEX IF NOT EXISTS idx_search_logs_created_at ON rag_search_logs(created_at); - - -- RAG Upload History Table - CREATE TABLE IF NOT EXISTS rag_upload_history ( - id SERIAL PRIMARY KEY, - filename VARCHAR(500) NOT NULL, - collection_name VARCHAR(100), - year INTEGER, - pdfs_extracted INTEGER, - minio_path VARCHAR(1000), - uploaded_by VARCHAR(100), - created_at TIMESTAMP DEFAULT NOW() - ); - - CREATE INDEX IF NOT EXISTS idx_upload_history_created_at ON rag_upload_history(created_at); - - -- Binäre Relevanz-Judgments für echte Precision/Recall - CREATE TABLE IF NOT EXISTS rag_relevance_judgments ( - id SERIAL PRIMARY KEY, - query_id VARCHAR(255) NOT NULL, - query_text TEXT NOT NULL, - result_id VARCHAR(255) NOT NULL, - result_rank INTEGER, - is_relevant BOOLEAN NOT NULL, - collection_name VARCHAR(100), - user_id VARCHAR(100), - created_at TIMESTAMP DEFAULT NOW() - ); - - CREATE INDEX IF NOT EXISTS idx_relevance_query ON rag_relevance_judgments(query_id); - CREATE INDEX IF NOT EXISTS idx_relevance_created_at ON rag_relevance_judgments(created_at); - - -- Zeugnisse Source Tracking - CREATE TABLE IF NOT EXISTS zeugnis_sources ( - id VARCHAR(36) PRIMARY KEY, - bundesland VARCHAR(10) NOT NULL, - name VARCHAR(255) NOT NULL, - base_url TEXT, - license_type VARCHAR(50) NOT NULL, - training_allowed BOOLEAN DEFAULT FALSE, - verified_by VARCHAR(100), - verified_at TIMESTAMP, - created_at TIMESTAMP DEFAULT NOW(), - updated_at TIMESTAMP DEFAULT NOW() - ); - - CREATE INDEX IF NOT EXISTS idx_zeugnis_sources_bundesland ON zeugnis_sources(bundesland); - - -- Zeugnisse Seed URLs - CREATE TABLE IF NOT EXISTS zeugnis_seed_urls ( - id VARCHAR(36) PRIMARY KEY, - source_id VARCHAR(36) REFERENCES zeugnis_sources(id), - url TEXT NOT NULL, - doc_type VARCHAR(50), - status VARCHAR(20) DEFAULT 'pending', - last_crawled TIMESTAMP, - error_message TEXT, - created_at TIMESTAMP DEFAULT NOW() - ); - - CREATE INDEX IF NOT EXISTS idx_zeugnis_seed_urls_source ON zeugnis_seed_urls(source_id); - CREATE INDEX IF NOT EXISTS idx_zeugnis_seed_urls_status ON zeugnis_seed_urls(status); - - -- Zeugnisse Documents - CREATE TABLE IF NOT EXISTS zeugnis_documents ( - id VARCHAR(36) PRIMARY KEY, - seed_url_id VARCHAR(36) REFERENCES zeugnis_seed_urls(id), - title VARCHAR(500), - url TEXT NOT NULL, - content_hash VARCHAR(64), - minio_path TEXT, - training_allowed BOOLEAN DEFAULT FALSE, - indexed_in_qdrant BOOLEAN DEFAULT FALSE, - file_size INTEGER, - content_type VARCHAR(100), - created_at TIMESTAMP DEFAULT NOW(), - updated_at TIMESTAMP DEFAULT NOW() - ); - - CREATE INDEX IF NOT EXISTS idx_zeugnis_documents_seed ON zeugnis_documents(seed_url_id); - CREATE INDEX IF NOT EXISTS idx_zeugnis_documents_hash ON zeugnis_documents(content_hash); - - -- Zeugnisse Document Versions - CREATE TABLE IF NOT EXISTS zeugnis_document_versions ( - id VARCHAR(36) PRIMARY KEY, - document_id VARCHAR(36) REFERENCES zeugnis_documents(id), - version INTEGER NOT NULL, - content_hash VARCHAR(64), - minio_path TEXT, - change_summary TEXT, - created_at TIMESTAMP DEFAULT NOW() - ); - - CREATE INDEX IF NOT EXISTS idx_zeugnis_versions_doc ON zeugnis_document_versions(document_id); - - -- Zeugnisse Usage Events (Audit Trail) - CREATE TABLE IF NOT EXISTS zeugnis_usage_events ( - id VARCHAR(36) PRIMARY KEY, - document_id VARCHAR(36) REFERENCES zeugnis_documents(id), - event_type VARCHAR(50) NOT NULL, - user_id VARCHAR(100), - details JSONB, - created_at TIMESTAMP DEFAULT NOW() - ); - - CREATE INDEX IF NOT EXISTS idx_zeugnis_events_doc ON zeugnis_usage_events(document_id); - CREATE INDEX IF NOT EXISTS idx_zeugnis_events_type ON zeugnis_usage_events(event_type); - CREATE INDEX IF NOT EXISTS idx_zeugnis_events_created ON zeugnis_usage_events(created_at); - - -- Crawler Queue - CREATE TABLE IF NOT EXISTS zeugnis_crawler_queue ( - id VARCHAR(36) PRIMARY KEY, - source_id VARCHAR(36) REFERENCES zeugnis_sources(id), - priority INTEGER DEFAULT 5, - status VARCHAR(20) DEFAULT 'pending', - started_at TIMESTAMP, - completed_at TIMESTAMP, - documents_found INTEGER DEFAULT 0, - documents_indexed INTEGER DEFAULT 0, - error_count INTEGER DEFAULT 0, - created_at TIMESTAMP DEFAULT NOW() - ); - - CREATE INDEX IF NOT EXISTS idx_crawler_queue_status ON zeugnis_crawler_queue(status); - """ - - try: - async with pool.acquire() as conn: - await conn.execute(create_tables_sql) - print("RAG metrics tables initialized") - return True - except Exception as e: - print(f"Failed to initialize metrics tables: {e}") - return False - - -# ============================================================================= -# Feedback Storage -# ============================================================================= - -async def store_feedback( - result_id: str, - rating: int, - query_text: Optional[str] = None, - collection_name: Optional[str] = None, - score: Optional[float] = None, - notes: Optional[str] = None, - user_id: Optional[str] = None, -) -> bool: - """Store search result feedback.""" - pool = await get_pool() - if pool is None: - return False - - try: - async with pool.acquire() as conn: - await conn.execute( - """ - INSERT INTO rag_search_feedback - (result_id, query_text, collection_name, score, rating, notes, user_id) - VALUES ($1, $2, $3, $4, $5, $6, $7) - """, - result_id, query_text, collection_name, score, rating, notes, user_id - ) - return True - except Exception as e: - print(f"Failed to store feedback: {e}") - return False - - -async def log_search( - query_text: str, - collection_name: str, - result_count: int, - latency_ms: int, - top_score: Optional[float] = None, - filters: Optional[Dict] = None, -) -> bool: - """Log a search for metrics tracking.""" - pool = await get_pool() - if pool is None: - return False - - try: - import json - async with pool.acquire() as conn: - await conn.execute( - """ - INSERT INTO rag_search_logs - (query_text, collection_name, result_count, latency_ms, top_score, filters) - VALUES ($1, $2, $3, $4, $5, $6) - """, - query_text, collection_name, result_count, latency_ms, top_score, - json.dumps(filters) if filters else None - ) - return True - except Exception as e: - print(f"Failed to log search: {e}") - return False - - -async def log_upload( - filename: str, - collection_name: str, - year: int, - pdfs_extracted: int, - minio_path: Optional[str] = None, - uploaded_by: Optional[str] = None, -) -> bool: - """Log an upload for history tracking.""" - pool = await get_pool() - if pool is None: - return False - - try: - async with pool.acquire() as conn: - await conn.execute( - """ - INSERT INTO rag_upload_history - (filename, collection_name, year, pdfs_extracted, minio_path, uploaded_by) - VALUES ($1, $2, $3, $4, $5, $6) - """, - filename, collection_name, year, pdfs_extracted, minio_path, uploaded_by - ) - return True - except Exception as e: - print(f"Failed to log upload: {e}") - return False - - -# ============================================================================= -# Metrics Calculation -# ============================================================================= - -async def calculate_metrics( - collection_name: Optional[str] = None, - days: int = 7, -) -> Dict: - """ - Calculate RAG quality metrics from stored feedback. - - Returns: - Dict with precision, recall, MRR, latency, etc. - """ - pool = await get_pool() - if pool is None: - return {"error": "Database not available", "connected": False} - - try: - async with pool.acquire() as conn: - # Date filter - since = datetime.now() - timedelta(days=days) - - # Collection filter - collection_filter = "" - params = [since] - if collection_name: - collection_filter = "AND collection_name = $2" - params.append(collection_name) - - # Total feedback count - total_feedback = await conn.fetchval( - f""" - SELECT COUNT(*) FROM rag_search_feedback - WHERE created_at >= $1 {collection_filter} - """, - *params - ) - - # Rating distribution - rating_dist = await conn.fetch( - f""" - SELECT rating, COUNT(*) as count - FROM rag_search_feedback - WHERE created_at >= $1 {collection_filter} - GROUP BY rating - ORDER BY rating DESC - """, - *params - ) - - # Average rating (proxy for precision) - avg_rating = await conn.fetchval( - f""" - SELECT AVG(rating) FROM rag_search_feedback - WHERE created_at >= $1 {collection_filter} - """, - *params - ) - - # Score distribution - score_dist = await conn.fetch( - f""" - SELECT - CASE - WHEN score >= 0.9 THEN '0.9+' - WHEN score >= 0.7 THEN '0.7-0.9' - WHEN score >= 0.5 THEN '0.5-0.7' - ELSE '<0.5' - END as range, - COUNT(*) as count - FROM rag_search_feedback - WHERE created_at >= $1 AND score IS NOT NULL {collection_filter} - GROUP BY range - ORDER BY range DESC - """, - *params - ) - - # Search logs for latency - latency_stats = await conn.fetchrow( - f""" - SELECT - AVG(latency_ms) as avg_latency, - COUNT(*) as total_searches, - AVG(result_count) as avg_results - FROM rag_search_logs - WHERE created_at >= $1 {collection_filter.replace('collection_name', 'collection_name')} - """, - *params - ) - - # Calculate precision@5 (% of top 5 rated 4+) - precision_at_5 = await conn.fetchval( - f""" - SELECT - CASE WHEN COUNT(*) > 0 - THEN CAST(SUM(CASE WHEN rating >= 4 THEN 1 ELSE 0 END) AS FLOAT) / COUNT(*) - ELSE 0 END - FROM rag_search_feedback - WHERE created_at >= $1 {collection_filter} - """, - *params - ) or 0 - - # Calculate MRR (Mean Reciprocal Rank) - simplified - # Using average rating as proxy for relevance - mrr = (avg_rating or 0) / 5.0 - - # Error rate (ratings of 1 or 2) - error_count = sum( - r['count'] for r in rating_dist if r['rating'] and r['rating'] <= 2 - ) - error_rate = (error_count / total_feedback * 100) if total_feedback > 0 else 0 - - # Score distribution as percentages - total_scored = sum(s['count'] for s in score_dist) - score_distribution = {} - for s in score_dist: - if total_scored > 0: - score_distribution[s['range']] = round(s['count'] / total_scored * 100) - else: - score_distribution[s['range']] = 0 - - return { - "connected": True, - "period_days": days, - "precision_at_5": round(precision_at_5, 2), - "recall_at_10": round(precision_at_5 * 1.1, 2), # Estimated - "mrr": round(mrr, 2), - "avg_latency_ms": round(latency_stats['avg_latency'] or 0), - "total_ratings": total_feedback, - "total_searches": latency_stats['total_searches'] or 0, - "error_rate": round(error_rate, 1), - "score_distribution": score_distribution, - "rating_distribution": { - str(r['rating']): r['count'] for r in rating_dist if r['rating'] - }, - } - - except Exception as e: - print(f"Failed to calculate metrics: {e}") - return {"error": str(e), "connected": False} - - -async def get_recent_feedback(limit: int = 20) -> List[Dict]: - """Get recent feedback entries.""" - pool = await get_pool() - if pool is None: - return [] - - try: - async with pool.acquire() as conn: - rows = await conn.fetch( - """ - SELECT result_id, rating, query_text, collection_name, score, notes, created_at - FROM rag_search_feedback - ORDER BY created_at DESC - LIMIT $1 - """, - limit - ) - return [ - { - "result_id": r['result_id'], - "rating": r['rating'], - "query_text": r['query_text'], - "collection_name": r['collection_name'], - "score": r['score'], - "notes": r['notes'], - "created_at": r['created_at'].isoformat() if r['created_at'] else None, - } - for r in rows - ] - except Exception as e: - print(f"Failed to get recent feedback: {e}") - return [] - - -async def get_upload_history(limit: int = 20) -> List[Dict]: - """Get recent upload history.""" - pool = await get_pool() - if pool is None: - return [] - - try: - async with pool.acquire() as conn: - rows = await conn.fetch( - """ - SELECT filename, collection_name, year, pdfs_extracted, minio_path, uploaded_by, created_at - FROM rag_upload_history - ORDER BY created_at DESC - LIMIT $1 - """, - limit - ) - return [ - { - "filename": r['filename'], - "collection_name": r['collection_name'], - "year": r['year'], - "pdfs_extracted": r['pdfs_extracted'], - "minio_path": r['minio_path'], - "uploaded_by": r['uploaded_by'], - "created_at": r['created_at'].isoformat() if r['created_at'] else None, - } - for r in rows - ] - except Exception as e: - print(f"Failed to get upload history: {e}") - return [] - - -# ============================================================================= -# Relevance Judgments (Binary Precision/Recall) -# ============================================================================= - -async def store_relevance_judgment( - query_id: str, - query_text: str, - result_id: str, - is_relevant: bool, - result_rank: Optional[int] = None, - collection_name: Optional[str] = None, - user_id: Optional[str] = None, -) -> bool: - """Store binary relevance judgment for Precision/Recall calculation.""" - pool = await get_pool() - if pool is None: - return False - - try: - async with pool.acquire() as conn: - await conn.execute( - """ - INSERT INTO rag_relevance_judgments - (query_id, query_text, result_id, result_rank, is_relevant, collection_name, user_id) - VALUES ($1, $2, $3, $4, $5, $6, $7) - ON CONFLICT DO NOTHING - """, - query_id, query_text, result_id, result_rank, is_relevant, collection_name, user_id - ) - return True - except Exception as e: - print(f"Failed to store relevance judgment: {e}") - return False - - -async def calculate_precision_recall( - collection_name: Optional[str] = None, - days: int = 7, - k: int = 10, -) -> Dict: - """ - Calculate true Precision@k and Recall@k from binary relevance judgments. - - Precision@k = (Relevant docs in top k) / k - Recall@k = (Relevant docs in top k) / (Total relevant docs for query) - """ - pool = await get_pool() - if pool is None: - return {"error": "Database not available", "connected": False} - - try: - async with pool.acquire() as conn: - since = datetime.now() - timedelta(days=days) - - collection_filter = "" - params = [since, k] - if collection_name: - collection_filter = "AND collection_name = $3" - params.append(collection_name) - - # Get precision@k per query, then average - precision_result = await conn.fetchval( - f""" - WITH query_precision AS ( - SELECT - query_id, - COUNT(CASE WHEN is_relevant THEN 1 END)::FLOAT / - GREATEST(COUNT(*), 1) as precision - FROM rag_relevance_judgments - WHERE created_at >= $1 - AND (result_rank IS NULL OR result_rank <= $2) - {collection_filter} - GROUP BY query_id - ) - SELECT AVG(precision) FROM query_precision - """, - *params - ) or 0 - - # Get recall@k per query, then average - recall_result = await conn.fetchval( - f""" - WITH query_recall AS ( - SELECT - query_id, - COUNT(CASE WHEN is_relevant AND (result_rank IS NULL OR result_rank <= $2) THEN 1 END)::FLOAT / - GREATEST(COUNT(CASE WHEN is_relevant THEN 1 END), 1) as recall - FROM rag_relevance_judgments - WHERE created_at >= $1 - {collection_filter} - GROUP BY query_id - ) - SELECT AVG(recall) FROM query_recall - """, - *params - ) or 0 - - # Total judgments - total_judgments = await conn.fetchval( - f""" - SELECT COUNT(*) FROM rag_relevance_judgments - WHERE created_at >= $1 {collection_filter} - """, - since, *([collection_name] if collection_name else []) - ) - - # Unique queries - unique_queries = await conn.fetchval( - f""" - SELECT COUNT(DISTINCT query_id) FROM rag_relevance_judgments - WHERE created_at >= $1 {collection_filter} - """, - since, *([collection_name] if collection_name else []) - ) - - return { - "connected": True, - "period_days": days, - "k": k, - "precision_at_k": round(precision_result, 3), - "recall_at_k": round(recall_result, 3), - "f1_score": round( - 2 * precision_result * recall_result / max(precision_result + recall_result, 0.001), 3 - ), - "total_judgments": total_judgments or 0, - "unique_queries": unique_queries or 0, - } - - except Exception as e: - print(f"Failed to calculate precision/recall: {e}") - return {"error": str(e), "connected": False} - - -# ============================================================================= -# Zeugnis Database Operations -# ============================================================================= - -async def get_zeugnis_sources() -> List[Dict]: - """Get all zeugnis sources (Bundesländer).""" - pool = await get_pool() - if pool is None: - return [] - - try: - async with pool.acquire() as conn: - rows = await conn.fetch( - """ - SELECT id, bundesland, name, base_url, license_type, training_allowed, - verified_by, verified_at, created_at, updated_at - FROM zeugnis_sources - ORDER BY bundesland - """ - ) - return [dict(r) for r in rows] - except Exception as e: - print(f"Failed to get zeugnis sources: {e}") - return [] - - -async def upsert_zeugnis_source( - id: str, - bundesland: str, - name: str, - license_type: str, - training_allowed: bool, - base_url: Optional[str] = None, - verified_by: Optional[str] = None, -) -> bool: - """Insert or update a zeugnis source.""" - pool = await get_pool() - if pool is None: - return False - - try: - async with pool.acquire() as conn: - await conn.execute( - """ - INSERT INTO zeugnis_sources (id, bundesland, name, base_url, license_type, training_allowed, verified_by, verified_at) - VALUES ($1, $2, $3, $4, $5, $6, $7, NOW()) - ON CONFLICT (id) DO UPDATE SET - name = EXCLUDED.name, - base_url = EXCLUDED.base_url, - license_type = EXCLUDED.license_type, - training_allowed = EXCLUDED.training_allowed, - verified_by = EXCLUDED.verified_by, - verified_at = NOW(), - updated_at = NOW() - """, - id, bundesland, name, base_url, license_type, training_allowed, verified_by - ) - return True - except Exception as e: - print(f"Failed to upsert zeugnis source: {e}") - return False - - -async def get_zeugnis_documents( - bundesland: Optional[str] = None, - limit: int = 100, - offset: int = 0, -) -> List[Dict]: - """Get zeugnis documents with optional filtering.""" - pool = await get_pool() - if pool is None: - return [] - - try: - async with pool.acquire() as conn: - if bundesland: - rows = await conn.fetch( - """ - SELECT d.*, s.bundesland, s.name as source_name - FROM zeugnis_documents d - JOIN zeugnis_seed_urls u ON d.seed_url_id = u.id - JOIN zeugnis_sources s ON u.source_id = s.id - WHERE s.bundesland = $1 - ORDER BY d.created_at DESC - LIMIT $2 OFFSET $3 - """, - bundesland, limit, offset - ) - else: - rows = await conn.fetch( - """ - SELECT d.*, s.bundesland, s.name as source_name - FROM zeugnis_documents d - JOIN zeugnis_seed_urls u ON d.seed_url_id = u.id - JOIN zeugnis_sources s ON u.source_id = s.id - ORDER BY d.created_at DESC - LIMIT $1 OFFSET $2 - """, - limit, offset - ) - return [dict(r) for r in rows] - except Exception as e: - print(f"Failed to get zeugnis documents: {e}") - return [] - - -async def get_zeugnis_stats() -> Dict: - """Get zeugnis crawler statistics.""" - pool = await get_pool() - if pool is None: - return {"error": "Database not available"} - - try: - async with pool.acquire() as conn: - # Total sources - sources = await conn.fetchval("SELECT COUNT(*) FROM zeugnis_sources") - - # Total documents - documents = await conn.fetchval("SELECT COUNT(*) FROM zeugnis_documents") - - # Indexed documents - indexed = await conn.fetchval( - "SELECT COUNT(*) FROM zeugnis_documents WHERE indexed_in_qdrant = true" - ) - - # Training allowed - training_allowed = await conn.fetchval( - "SELECT COUNT(*) FROM zeugnis_documents WHERE training_allowed = true" - ) - - # Per Bundesland stats - per_bundesland = await conn.fetch( - """ - SELECT s.bundesland, s.name, s.training_allowed, COUNT(d.id) as doc_count - FROM zeugnis_sources s - LEFT JOIN zeugnis_seed_urls u ON s.id = u.source_id - LEFT JOIN zeugnis_documents d ON u.id = d.seed_url_id - GROUP BY s.bundesland, s.name, s.training_allowed - ORDER BY s.bundesland - """ - ) - - # Active crawls - active_crawls = await conn.fetchval( - "SELECT COUNT(*) FROM zeugnis_crawler_queue WHERE status = 'running'" - ) - - return { - "total_sources": sources or 0, - "total_documents": documents or 0, - "indexed_documents": indexed or 0, - "training_allowed_documents": training_allowed or 0, - "active_crawls": active_crawls or 0, - "per_bundesland": [dict(r) for r in per_bundesland], - } - except Exception as e: - print(f"Failed to get zeugnis stats: {e}") - return {"error": str(e)} - - -async def log_zeugnis_event( - document_id: str, - event_type: str, - user_id: Optional[str] = None, - details: Optional[Dict] = None, -) -> bool: - """Log a zeugnis usage event for audit trail.""" - pool = await get_pool() - if pool is None: - return False - - try: - import json - import uuid - async with pool.acquire() as conn: - await conn.execute( - """ - INSERT INTO zeugnis_usage_events (id, document_id, event_type, user_id, details) - VALUES ($1, $2, $3, $4, $5) - """, - str(uuid.uuid4()), document_id, event_type, user_id, - json.dumps(details) if details else None - ) - return True - except Exception as e: - print(f"Failed to log zeugnis event: {e}") - return False +# Schema: table initialization +from metrics_db_schema import init_metrics_tables # noqa: F401 + +# Core: pool, feedback, search logs, metrics, relevance +from metrics_db_core import ( # noqa: F401 + DATABASE_URL, + get_pool, + store_feedback, + log_search, + log_upload, + calculate_metrics, + get_recent_feedback, + get_upload_history, + store_relevance_judgment, + calculate_precision_recall, +) + +# Zeugnis operations +from metrics_db_zeugnis import ( # noqa: F401 + get_zeugnis_sources, + upsert_zeugnis_source, + get_zeugnis_documents, + get_zeugnis_stats, + log_zeugnis_event, +) diff --git a/klausur-service/backend/metrics_db_core.py b/klausur-service/backend/metrics_db_core.py new file mode 100644 index 0000000..663f77f --- /dev/null +++ b/klausur-service/backend/metrics_db_core.py @@ -0,0 +1,459 @@ +""" +PostgreSQL Metrics Database - Core Operations + +Connection pool, table initialization, feedback storage, search logging, +upload history, metrics calculation, and relevance judgments. + +Extracted from metrics_db.py to keep files under 500 LOC. +""" + +import os +from typing import Optional, List, Dict +from datetime import datetime, timedelta + +# Database Configuration - uses test default if not configured (for CI) +DATABASE_URL = os.getenv("DATABASE_URL", "postgresql://test:test@localhost:5432/test_metrics") + +# Connection pool +_pool = None + + +async def get_pool(): + """Get or create database connection pool.""" + global _pool + if _pool is None: + try: + import asyncpg + _pool = await asyncpg.create_pool(DATABASE_URL, min_size=2, max_size=10) + except ImportError: + print("Warning: asyncpg not installed. Metrics storage disabled.") + return None + except Exception as e: + print(f"Warning: Failed to connect to PostgreSQL: {e}") + return None + return _pool + + + +# ============================================================================= +# Feedback Storage +# ============================================================================= + +async def store_feedback( + result_id: str, + rating: int, + query_text: Optional[str] = None, + collection_name: Optional[str] = None, + score: Optional[float] = None, + notes: Optional[str] = None, + user_id: Optional[str] = None, +) -> bool: + """Store search result feedback.""" + pool = await get_pool() + if pool is None: + return False + + try: + async with pool.acquire() as conn: + await conn.execute( + """ + INSERT INTO rag_search_feedback + (result_id, query_text, collection_name, score, rating, notes, user_id) + VALUES ($1, $2, $3, $4, $5, $6, $7) + """, + result_id, query_text, collection_name, score, rating, notes, user_id + ) + return True + except Exception as e: + print(f"Failed to store feedback: {e}") + return False + + +async def log_search( + query_text: str, + collection_name: str, + result_count: int, + latency_ms: int, + top_score: Optional[float] = None, + filters: Optional[Dict] = None, +) -> bool: + """Log a search for metrics tracking.""" + pool = await get_pool() + if pool is None: + return False + + try: + import json + async with pool.acquire() as conn: + await conn.execute( + """ + INSERT INTO rag_search_logs + (query_text, collection_name, result_count, latency_ms, top_score, filters) + VALUES ($1, $2, $3, $4, $5, $6) + """, + query_text, collection_name, result_count, latency_ms, top_score, + json.dumps(filters) if filters else None + ) + return True + except Exception as e: + print(f"Failed to log search: {e}") + return False + + +async def log_upload( + filename: str, + collection_name: str, + year: int, + pdfs_extracted: int, + minio_path: Optional[str] = None, + uploaded_by: Optional[str] = None, +) -> bool: + """Log an upload for history tracking.""" + pool = await get_pool() + if pool is None: + return False + + try: + async with pool.acquire() as conn: + await conn.execute( + """ + INSERT INTO rag_upload_history + (filename, collection_name, year, pdfs_extracted, minio_path, uploaded_by) + VALUES ($1, $2, $3, $4, $5, $6) + """, + filename, collection_name, year, pdfs_extracted, minio_path, uploaded_by + ) + return True + except Exception as e: + print(f"Failed to log upload: {e}") + return False + + +# ============================================================================= +# Metrics Calculation +# ============================================================================= + +async def calculate_metrics( + collection_name: Optional[str] = None, + days: int = 7, +) -> Dict: + """ + Calculate RAG quality metrics from stored feedback. + + Returns: + Dict with precision, recall, MRR, latency, etc. + """ + pool = await get_pool() + if pool is None: + return {"error": "Database not available", "connected": False} + + try: + async with pool.acquire() as conn: + since = datetime.now() - timedelta(days=days) + + collection_filter = "" + params = [since] + if collection_name: + collection_filter = "AND collection_name = $2" + params.append(collection_name) + + total_feedback = await conn.fetchval( + f""" + SELECT COUNT(*) FROM rag_search_feedback + WHERE created_at >= $1 {collection_filter} + """, + *params + ) + + rating_dist = await conn.fetch( + f""" + SELECT rating, COUNT(*) as count + FROM rag_search_feedback + WHERE created_at >= $1 {collection_filter} + GROUP BY rating + ORDER BY rating DESC + """, + *params + ) + + avg_rating = await conn.fetchval( + f""" + SELECT AVG(rating) FROM rag_search_feedback + WHERE created_at >= $1 {collection_filter} + """, + *params + ) + + score_dist = await conn.fetch( + f""" + SELECT + CASE + WHEN score >= 0.9 THEN '0.9+' + WHEN score >= 0.7 THEN '0.7-0.9' + WHEN score >= 0.5 THEN '0.5-0.7' + ELSE '<0.5' + END as range, + COUNT(*) as count + FROM rag_search_feedback + WHERE created_at >= $1 AND score IS NOT NULL {collection_filter} + GROUP BY range + ORDER BY range DESC + """, + *params + ) + + latency_stats = await conn.fetchrow( + f""" + SELECT + AVG(latency_ms) as avg_latency, + COUNT(*) as total_searches, + AVG(result_count) as avg_results + FROM rag_search_logs + WHERE created_at >= $1 {collection_filter.replace('collection_name', 'collection_name')} + """, + *params + ) + + precision_at_5 = await conn.fetchval( + f""" + SELECT + CASE WHEN COUNT(*) > 0 + THEN CAST(SUM(CASE WHEN rating >= 4 THEN 1 ELSE 0 END) AS FLOAT) / COUNT(*) + ELSE 0 END + FROM rag_search_feedback + WHERE created_at >= $1 {collection_filter} + """, + *params + ) or 0 + + mrr = (avg_rating or 0) / 5.0 + + error_count = sum( + r['count'] for r in rating_dist if r['rating'] and r['rating'] <= 2 + ) + error_rate = (error_count / total_feedback * 100) if total_feedback > 0 else 0 + + total_scored = sum(s['count'] for s in score_dist) + score_distribution = {} + for s in score_dist: + if total_scored > 0: + score_distribution[s['range']] = round(s['count'] / total_scored * 100) + else: + score_distribution[s['range']] = 0 + + return { + "connected": True, + "period_days": days, + "precision_at_5": round(precision_at_5, 2), + "recall_at_10": round(precision_at_5 * 1.1, 2), + "mrr": round(mrr, 2), + "avg_latency_ms": round(latency_stats['avg_latency'] or 0), + "total_ratings": total_feedback, + "total_searches": latency_stats['total_searches'] or 0, + "error_rate": round(error_rate, 1), + "score_distribution": score_distribution, + "rating_distribution": { + str(r['rating']): r['count'] for r in rating_dist if r['rating'] + }, + } + + except Exception as e: + print(f"Failed to calculate metrics: {e}") + return {"error": str(e), "connected": False} + + +async def get_recent_feedback(limit: int = 20) -> List[Dict]: + """Get recent feedback entries.""" + pool = await get_pool() + if pool is None: + return [] + + try: + async with pool.acquire() as conn: + rows = await conn.fetch( + """ + SELECT result_id, rating, query_text, collection_name, score, notes, created_at + FROM rag_search_feedback + ORDER BY created_at DESC + LIMIT $1 + """, + limit + ) + return [ + { + "result_id": r['result_id'], + "rating": r['rating'], + "query_text": r['query_text'], + "collection_name": r['collection_name'], + "score": r['score'], + "notes": r['notes'], + "created_at": r['created_at'].isoformat() if r['created_at'] else None, + } + for r in rows + ] + except Exception as e: + print(f"Failed to get recent feedback: {e}") + return [] + + +async def get_upload_history(limit: int = 20) -> List[Dict]: + """Get recent upload history.""" + pool = await get_pool() + if pool is None: + return [] + + try: + async with pool.acquire() as conn: + rows = await conn.fetch( + """ + SELECT filename, collection_name, year, pdfs_extracted, minio_path, uploaded_by, created_at + FROM rag_upload_history + ORDER BY created_at DESC + LIMIT $1 + """, + limit + ) + return [ + { + "filename": r['filename'], + "collection_name": r['collection_name'], + "year": r['year'], + "pdfs_extracted": r['pdfs_extracted'], + "minio_path": r['minio_path'], + "uploaded_by": r['uploaded_by'], + "created_at": r['created_at'].isoformat() if r['created_at'] else None, + } + for r in rows + ] + except Exception as e: + print(f"Failed to get upload history: {e}") + return [] + + +# ============================================================================= +# Relevance Judgments (Binary Precision/Recall) +# ============================================================================= + +async def store_relevance_judgment( + query_id: str, + query_text: str, + result_id: str, + is_relevant: bool, + result_rank: Optional[int] = None, + collection_name: Optional[str] = None, + user_id: Optional[str] = None, +) -> bool: + """Store binary relevance judgment for Precision/Recall calculation.""" + pool = await get_pool() + if pool is None: + return False + + try: + async with pool.acquire() as conn: + await conn.execute( + """ + INSERT INTO rag_relevance_judgments + (query_id, query_text, result_id, result_rank, is_relevant, collection_name, user_id) + VALUES ($1, $2, $3, $4, $5, $6, $7) + ON CONFLICT DO NOTHING + """, + query_id, query_text, result_id, result_rank, is_relevant, collection_name, user_id + ) + return True + except Exception as e: + print(f"Failed to store relevance judgment: {e}") + return False + + +async def calculate_precision_recall( + collection_name: Optional[str] = None, + days: int = 7, + k: int = 10, +) -> Dict: + """ + Calculate true Precision@k and Recall@k from binary relevance judgments. + + Precision@k = (Relevant docs in top k) / k + Recall@k = (Relevant docs in top k) / (Total relevant docs for query) + """ + pool = await get_pool() + if pool is None: + return {"error": "Database not available", "connected": False} + + try: + async with pool.acquire() as conn: + since = datetime.now() - timedelta(days=days) + + collection_filter = "" + params = [since, k] + if collection_name: + collection_filter = "AND collection_name = $3" + params.append(collection_name) + + precision_result = await conn.fetchval( + f""" + WITH query_precision AS ( + SELECT + query_id, + COUNT(CASE WHEN is_relevant THEN 1 END)::FLOAT / + GREATEST(COUNT(*), 1) as precision + FROM rag_relevance_judgments + WHERE created_at >= $1 + AND (result_rank IS NULL OR result_rank <= $2) + {collection_filter} + GROUP BY query_id + ) + SELECT AVG(precision) FROM query_precision + """, + *params + ) or 0 + + recall_result = await conn.fetchval( + f""" + WITH query_recall AS ( + SELECT + query_id, + COUNT(CASE WHEN is_relevant AND (result_rank IS NULL OR result_rank <= $2) THEN 1 END)::FLOAT / + GREATEST(COUNT(CASE WHEN is_relevant THEN 1 END), 1) as recall + FROM rag_relevance_judgments + WHERE created_at >= $1 + {collection_filter} + GROUP BY query_id + ) + SELECT AVG(recall) FROM query_recall + """, + *params + ) or 0 + + total_judgments = await conn.fetchval( + f""" + SELECT COUNT(*) FROM rag_relevance_judgments + WHERE created_at >= $1 {collection_filter} + """, + since, *([collection_name] if collection_name else []) + ) + + unique_queries = await conn.fetchval( + f""" + SELECT COUNT(DISTINCT query_id) FROM rag_relevance_judgments + WHERE created_at >= $1 {collection_filter} + """, + since, *([collection_name] if collection_name else []) + ) + + return { + "connected": True, + "period_days": days, + "k": k, + "precision_at_k": round(precision_result, 3), + "recall_at_k": round(recall_result, 3), + "f1_score": round( + 2 * precision_result * recall_result / max(precision_result + recall_result, 0.001), 3 + ), + "total_judgments": total_judgments or 0, + "unique_queries": unique_queries or 0, + } + + except Exception as e: + print(f"Failed to calculate precision/recall: {e}") + return {"error": str(e), "connected": False} diff --git a/klausur-service/backend/metrics_db_schema.py b/klausur-service/backend/metrics_db_schema.py new file mode 100644 index 0000000..ce7dedc --- /dev/null +++ b/klausur-service/backend/metrics_db_schema.py @@ -0,0 +1,182 @@ +""" +PostgreSQL Metrics Database - Schema Initialization + +Table creation DDL for all metrics, feedback, and zeugnis tables. + +Extracted from metrics_db_core.py to keep files under 500 LOC. +""" + +from metrics_db_core import get_pool + + +async def init_metrics_tables() -> bool: + """Initialize metrics tables in PostgreSQL.""" + pool = await get_pool() + if pool is None: + return False + + create_tables_sql = """ + -- RAG Search Feedback Table + CREATE TABLE IF NOT EXISTS rag_search_feedback ( + id SERIAL PRIMARY KEY, + result_id VARCHAR(255) NOT NULL, + query_text TEXT, + collection_name VARCHAR(100), + score FLOAT, + rating INTEGER CHECK (rating >= 1 AND rating <= 5), + notes TEXT, + user_id VARCHAR(100), + created_at TIMESTAMP DEFAULT NOW() + ); + + -- Index for efficient querying + CREATE INDEX IF NOT EXISTS idx_feedback_created_at ON rag_search_feedback(created_at); + CREATE INDEX IF NOT EXISTS idx_feedback_collection ON rag_search_feedback(collection_name); + CREATE INDEX IF NOT EXISTS idx_feedback_rating ON rag_search_feedback(rating); + + -- RAG Search Logs Table (for latency tracking) + CREATE TABLE IF NOT EXISTS rag_search_logs ( + id SERIAL PRIMARY KEY, + query_text TEXT NOT NULL, + collection_name VARCHAR(100), + result_count INTEGER, + latency_ms INTEGER, + top_score FLOAT, + filters JSONB, + created_at TIMESTAMP DEFAULT NOW() + ); + + CREATE INDEX IF NOT EXISTS idx_search_logs_created_at ON rag_search_logs(created_at); + + -- RAG Upload History Table + CREATE TABLE IF NOT EXISTS rag_upload_history ( + id SERIAL PRIMARY KEY, + filename VARCHAR(500) NOT NULL, + collection_name VARCHAR(100), + year INTEGER, + pdfs_extracted INTEGER, + minio_path VARCHAR(1000), + uploaded_by VARCHAR(100), + created_at TIMESTAMP DEFAULT NOW() + ); + + CREATE INDEX IF NOT EXISTS idx_upload_history_created_at ON rag_upload_history(created_at); + + -- Binaere Relevanz-Judgments fuer echte Precision/Recall + CREATE TABLE IF NOT EXISTS rag_relevance_judgments ( + id SERIAL PRIMARY KEY, + query_id VARCHAR(255) NOT NULL, + query_text TEXT NOT NULL, + result_id VARCHAR(255) NOT NULL, + result_rank INTEGER, + is_relevant BOOLEAN NOT NULL, + collection_name VARCHAR(100), + user_id VARCHAR(100), + created_at TIMESTAMP DEFAULT NOW() + ); + + CREATE INDEX IF NOT EXISTS idx_relevance_query ON rag_relevance_judgments(query_id); + CREATE INDEX IF NOT EXISTS idx_relevance_created_at ON rag_relevance_judgments(created_at); + + -- Zeugnisse Source Tracking + CREATE TABLE IF NOT EXISTS zeugnis_sources ( + id VARCHAR(36) PRIMARY KEY, + bundesland VARCHAR(10) NOT NULL, + name VARCHAR(255) NOT NULL, + base_url TEXT, + license_type VARCHAR(50) NOT NULL, + training_allowed BOOLEAN DEFAULT FALSE, + verified_by VARCHAR(100), + verified_at TIMESTAMP, + created_at TIMESTAMP DEFAULT NOW(), + updated_at TIMESTAMP DEFAULT NOW() + ); + + CREATE INDEX IF NOT EXISTS idx_zeugnis_sources_bundesland ON zeugnis_sources(bundesland); + + -- Zeugnisse Seed URLs + CREATE TABLE IF NOT EXISTS zeugnis_seed_urls ( + id VARCHAR(36) PRIMARY KEY, + source_id VARCHAR(36) REFERENCES zeugnis_sources(id), + url TEXT NOT NULL, + doc_type VARCHAR(50), + status VARCHAR(20) DEFAULT 'pending', + last_crawled TIMESTAMP, + error_message TEXT, + created_at TIMESTAMP DEFAULT NOW() + ); + + CREATE INDEX IF NOT EXISTS idx_zeugnis_seed_urls_source ON zeugnis_seed_urls(source_id); + CREATE INDEX IF NOT EXISTS idx_zeugnis_seed_urls_status ON zeugnis_seed_urls(status); + + -- Zeugnisse Documents + CREATE TABLE IF NOT EXISTS zeugnis_documents ( + id VARCHAR(36) PRIMARY KEY, + seed_url_id VARCHAR(36) REFERENCES zeugnis_seed_urls(id), + title VARCHAR(500), + url TEXT NOT NULL, + content_hash VARCHAR(64), + minio_path TEXT, + training_allowed BOOLEAN DEFAULT FALSE, + indexed_in_qdrant BOOLEAN DEFAULT FALSE, + file_size INTEGER, + content_type VARCHAR(100), + created_at TIMESTAMP DEFAULT NOW(), + updated_at TIMESTAMP DEFAULT NOW() + ); + + CREATE INDEX IF NOT EXISTS idx_zeugnis_documents_seed ON zeugnis_documents(seed_url_id); + CREATE INDEX IF NOT EXISTS idx_zeugnis_documents_hash ON zeugnis_documents(content_hash); + + -- Zeugnisse Document Versions + CREATE TABLE IF NOT EXISTS zeugnis_document_versions ( + id VARCHAR(36) PRIMARY KEY, + document_id VARCHAR(36) REFERENCES zeugnis_documents(id), + version INTEGER NOT NULL, + content_hash VARCHAR(64), + minio_path TEXT, + change_summary TEXT, + created_at TIMESTAMP DEFAULT NOW() + ); + + CREATE INDEX IF NOT EXISTS idx_zeugnis_versions_doc ON zeugnis_document_versions(document_id); + + -- Zeugnisse Usage Events (Audit Trail) + CREATE TABLE IF NOT EXISTS zeugnis_usage_events ( + id VARCHAR(36) PRIMARY KEY, + document_id VARCHAR(36) REFERENCES zeugnis_documents(id), + event_type VARCHAR(50) NOT NULL, + user_id VARCHAR(100), + details JSONB, + created_at TIMESTAMP DEFAULT NOW() + ); + + CREATE INDEX IF NOT EXISTS idx_zeugnis_events_doc ON zeugnis_usage_events(document_id); + CREATE INDEX IF NOT EXISTS idx_zeugnis_events_type ON zeugnis_usage_events(event_type); + CREATE INDEX IF NOT EXISTS idx_zeugnis_events_created ON zeugnis_usage_events(created_at); + + -- Crawler Queue + CREATE TABLE IF NOT EXISTS zeugnis_crawler_queue ( + id VARCHAR(36) PRIMARY KEY, + source_id VARCHAR(36) REFERENCES zeugnis_sources(id), + priority INTEGER DEFAULT 5, + status VARCHAR(20) DEFAULT 'pending', + started_at TIMESTAMP, + completed_at TIMESTAMP, + documents_found INTEGER DEFAULT 0, + documents_indexed INTEGER DEFAULT 0, + error_count INTEGER DEFAULT 0, + created_at TIMESTAMP DEFAULT NOW() + ); + + CREATE INDEX IF NOT EXISTS idx_crawler_queue_status ON zeugnis_crawler_queue(status); + """ + + try: + async with pool.acquire() as conn: + await conn.execute(create_tables_sql) + print("RAG metrics tables initialized") + return True + except Exception as e: + print(f"Failed to initialize metrics tables: {e}") + return False diff --git a/klausur-service/backend/metrics_db_zeugnis.py b/klausur-service/backend/metrics_db_zeugnis.py new file mode 100644 index 0000000..94acd6a --- /dev/null +++ b/klausur-service/backend/metrics_db_zeugnis.py @@ -0,0 +1,193 @@ +""" +PostgreSQL Metrics Database - Zeugnis Operations + +Zeugnis source management, document queries, statistics, and event logging. + +Extracted from metrics_db.py to keep files under 500 LOC. +""" + +from typing import Optional, List, Dict + +from metrics_db_core import get_pool + + +# ============================================================================= +# Zeugnis Database Operations +# ============================================================================= + +async def get_zeugnis_sources() -> List[Dict]: + """Get all zeugnis sources (Bundeslaender).""" + pool = await get_pool() + if pool is None: + return [] + + try: + async with pool.acquire() as conn: + rows = await conn.fetch( + """ + SELECT id, bundesland, name, base_url, license_type, training_allowed, + verified_by, verified_at, created_at, updated_at + FROM zeugnis_sources + ORDER BY bundesland + """ + ) + return [dict(r) for r in rows] + except Exception as e: + print(f"Failed to get zeugnis sources: {e}") + return [] + + +async def upsert_zeugnis_source( + id: str, + bundesland: str, + name: str, + license_type: str, + training_allowed: bool, + base_url: Optional[str] = None, + verified_by: Optional[str] = None, +) -> bool: + """Insert or update a zeugnis source.""" + pool = await get_pool() + if pool is None: + return False + + try: + async with pool.acquire() as conn: + await conn.execute( + """ + INSERT INTO zeugnis_sources (id, bundesland, name, base_url, license_type, training_allowed, verified_by, verified_at) + VALUES ($1, $2, $3, $4, $5, $6, $7, NOW()) + ON CONFLICT (id) DO UPDATE SET + name = EXCLUDED.name, + base_url = EXCLUDED.base_url, + license_type = EXCLUDED.license_type, + training_allowed = EXCLUDED.training_allowed, + verified_by = EXCLUDED.verified_by, + verified_at = NOW(), + updated_at = NOW() + """, + id, bundesland, name, base_url, license_type, training_allowed, verified_by + ) + return True + except Exception as e: + print(f"Failed to upsert zeugnis source: {e}") + return False + + +async def get_zeugnis_documents( + bundesland: Optional[str] = None, + limit: int = 100, + offset: int = 0, +) -> List[Dict]: + """Get zeugnis documents with optional filtering.""" + pool = await get_pool() + if pool is None: + return [] + + try: + async with pool.acquire() as conn: + if bundesland: + rows = await conn.fetch( + """ + SELECT d.*, s.bundesland, s.name as source_name + FROM zeugnis_documents d + JOIN zeugnis_seed_urls u ON d.seed_url_id = u.id + JOIN zeugnis_sources s ON u.source_id = s.id + WHERE s.bundesland = $1 + ORDER BY d.created_at DESC + LIMIT $2 OFFSET $3 + """, + bundesland, limit, offset + ) + else: + rows = await conn.fetch( + """ + SELECT d.*, s.bundesland, s.name as source_name + FROM zeugnis_documents d + JOIN zeugnis_seed_urls u ON d.seed_url_id = u.id + JOIN zeugnis_sources s ON u.source_id = s.id + ORDER BY d.created_at DESC + LIMIT $1 OFFSET $2 + """, + limit, offset + ) + return [dict(r) for r in rows] + except Exception as e: + print(f"Failed to get zeugnis documents: {e}") + return [] + + +async def get_zeugnis_stats() -> Dict: + """Get zeugnis crawler statistics.""" + pool = await get_pool() + if pool is None: + return {"error": "Database not available"} + + try: + async with pool.acquire() as conn: + sources = await conn.fetchval("SELECT COUNT(*) FROM zeugnis_sources") + documents = await conn.fetchval("SELECT COUNT(*) FROM zeugnis_documents") + + indexed = await conn.fetchval( + "SELECT COUNT(*) FROM zeugnis_documents WHERE indexed_in_qdrant = true" + ) + + training_allowed = await conn.fetchval( + "SELECT COUNT(*) FROM zeugnis_documents WHERE training_allowed = true" + ) + + per_bundesland = await conn.fetch( + """ + SELECT s.bundesland, s.name, s.training_allowed, COUNT(d.id) as doc_count + FROM zeugnis_sources s + LEFT JOIN zeugnis_seed_urls u ON s.id = u.source_id + LEFT JOIN zeugnis_documents d ON u.id = d.seed_url_id + GROUP BY s.bundesland, s.name, s.training_allowed + ORDER BY s.bundesland + """ + ) + + active_crawls = await conn.fetchval( + "SELECT COUNT(*) FROM zeugnis_crawler_queue WHERE status = 'running'" + ) + + return { + "total_sources": sources or 0, + "total_documents": documents or 0, + "indexed_documents": indexed or 0, + "training_allowed_documents": training_allowed or 0, + "active_crawls": active_crawls or 0, + "per_bundesland": [dict(r) for r in per_bundesland], + } + except Exception as e: + print(f"Failed to get zeugnis stats: {e}") + return {"error": str(e)} + + +async def log_zeugnis_event( + document_id: str, + event_type: str, + user_id: Optional[str] = None, + details: Optional[Dict] = None, +) -> bool: + """Log a zeugnis usage event for audit trail.""" + pool = await get_pool() + if pool is None: + return False + + try: + import json + import uuid + async with pool.acquire() as conn: + await conn.execute( + """ + INSERT INTO zeugnis_usage_events (id, document_id, event_type, user_id, details) + VALUES ($1, $2, $3, $4, $5) + """, + str(uuid.uuid4()), document_id, event_type, user_id, + json.dumps(details) if details else None + ) + return True + except Exception as e: + print(f"Failed to log zeugnis event: {e}") + return False diff --git a/klausur-service/backend/ocr_labeling_api.py b/klausur-service/backend/ocr_labeling_api.py index 43e7f61..924f964 100644 --- a/klausur-service/backend/ocr_labeling_api.py +++ b/klausur-service/backend/ocr_labeling_api.py @@ -1,845 +1,81 @@ """ -OCR Labeling API for Handwriting Training Data Collection +OCR Labeling API — Barrel Re-export -DATENSCHUTZ/PRIVACY: -- Alle Verarbeitung erfolgt lokal (Mac Mini mit Ollama) -- Keine Daten werden an externe Server gesendet -- Bilder werden mit SHA256-Hash dedupliziert -- Export nur für lokales Fine-Tuning (TrOCR, llama3.2-vision) +Split into: +- ocr_labeling_models.py — Pydantic models and constants +- ocr_labeling_helpers.py — OCR wrappers, image storage, hashing +- ocr_labeling_routes.py — Session/queue/labeling route handlers +- ocr_labeling_upload_routes.py — Upload, run-OCR, export route handlers -Endpoints: -- POST /sessions - Create labeling session -- POST /sessions/{id}/upload - Upload images for labeling -- GET /queue - Get next items to label -- POST /confirm - Confirm OCR as correct -- POST /correct - Save corrected ground truth -- POST /skip - Skip unusable item -- GET /stats - Get labeling statistics -- POST /export - Export training data +All public names are re-exported here for backward compatibility. """ -from fastapi import APIRouter, HTTPException, UploadFile, File, Form, Query, BackgroundTasks -from pydantic import BaseModel -from typing import Optional, List, Dict, Any -from datetime import datetime -import uuid -import hashlib -import os -import base64 - -# Import database functions -from metrics_db import ( - create_ocr_labeling_session, - get_ocr_labeling_sessions, - get_ocr_labeling_session, - add_ocr_labeling_item, - get_ocr_labeling_queue, - get_ocr_labeling_item, - confirm_ocr_label, - correct_ocr_label, - skip_ocr_item, - get_ocr_labeling_stats, - export_training_samples, - get_training_samples, +# Models +from ocr_labeling_models import ( # noqa: F401 + LOCAL_STORAGE_PATH, + SessionCreate, + SessionResponse, + ItemResponse, + ConfirmRequest, + CorrectRequest, + SkipRequest, + ExportRequest, + StatsResponse, ) -# Try to import Vision OCR service -try: - import sys - sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..', 'backend', 'klausur', 'services')) - from vision_ocr_service import get_vision_ocr_service, VisionOCRService - VISION_OCR_AVAILABLE = True -except ImportError: - VISION_OCR_AVAILABLE = False - print("Warning: Vision OCR service not available") +# Helpers +from ocr_labeling_helpers import ( # noqa: F401 + VISION_OCR_AVAILABLE, + PADDLEOCR_AVAILABLE, + TROCR_AVAILABLE, + DONUT_AVAILABLE, + MINIO_AVAILABLE, + TRAINING_EXPORT_AVAILABLE, + compute_image_hash, + run_ocr_on_image, + run_vision_ocr_wrapper, + run_paddleocr_wrapper, + run_trocr_wrapper, + run_donut_wrapper, + save_image_locally, + get_image_url, +) -# Try to import PaddleOCR from hybrid_vocab_extractor +# Conditional re-exports from helpers' optional imports try: - from hybrid_vocab_extractor import run_paddle_ocr - PADDLEOCR_AVAILABLE = True + from minio_storage import upload_ocr_image, get_ocr_image, MINIO_BUCKET # noqa: F401 except ImportError: - PADDLEOCR_AVAILABLE = False - print("Warning: PaddleOCR not available") + pass -# Try to import TrOCR service try: - from services.trocr_service import run_trocr_ocr - TROCR_AVAILABLE = True -except ImportError: - TROCR_AVAILABLE = False - print("Warning: TrOCR service not available") - -# Try to import Donut service -try: - from services.donut_ocr_service import run_donut_ocr - DONUT_AVAILABLE = True -except ImportError: - DONUT_AVAILABLE = False - print("Warning: Donut OCR service not available") - -# Try to import MinIO storage -try: - from minio_storage import upload_ocr_image, get_ocr_image, MINIO_BUCKET - MINIO_AVAILABLE = True -except ImportError: - MINIO_AVAILABLE = False - print("Warning: MinIO storage not available, using local storage") - -# Try to import Training Export Service -try: - from training_export_service import ( + from training_export_service import ( # noqa: F401 TrainingExportService, TrainingSample, get_training_export_service, ) - TRAINING_EXPORT_AVAILABLE = True except ImportError: - TRAINING_EXPORT_AVAILABLE = False - print("Warning: Training export service not available") - -router = APIRouter(prefix="/api/v1/ocr-label", tags=["OCR Labeling"]) - -# Local storage path (fallback if MinIO not available) -LOCAL_STORAGE_PATH = os.getenv("OCR_STORAGE_PATH", "/app/ocr-labeling") - - -# ============================================================================= -# Pydantic Models -# ============================================================================= - -class SessionCreate(BaseModel): - name: str - source_type: str = "klausur" # klausur, handwriting_sample, scan - description: Optional[str] = None - ocr_model: Optional[str] = "llama3.2-vision:11b" - - -class SessionResponse(BaseModel): - id: str - name: str - source_type: str - description: Optional[str] - ocr_model: Optional[str] - total_items: int - labeled_items: int - confirmed_items: int - corrected_items: int - skipped_items: int - created_at: datetime - - -class ItemResponse(BaseModel): - id: str - session_id: str - session_name: str - image_path: str - image_url: Optional[str] - ocr_text: Optional[str] - ocr_confidence: Optional[float] - ground_truth: Optional[str] - status: str - metadata: Optional[Dict] - created_at: datetime - - -class ConfirmRequest(BaseModel): - item_id: str - label_time_seconds: Optional[int] = None - - -class CorrectRequest(BaseModel): - item_id: str - ground_truth: str - label_time_seconds: Optional[int] = None - - -class SkipRequest(BaseModel): - item_id: str - - -class ExportRequest(BaseModel): - export_format: str = "generic" # generic, trocr, llama_vision - session_id: Optional[str] = None - batch_id: Optional[str] = None - - -class StatsResponse(BaseModel): - total_sessions: Optional[int] = None - total_items: int - labeled_items: int - confirmed_items: int - corrected_items: int - pending_items: int - exportable_items: Optional[int] = None - accuracy_rate: float - avg_label_time_seconds: Optional[float] = None - - -# ============================================================================= -# Helper Functions -# ============================================================================= - -def compute_image_hash(image_data: bytes) -> str: - """Compute SHA256 hash of image data.""" - return hashlib.sha256(image_data).hexdigest() - - -async def run_ocr_on_image(image_data: bytes, filename: str, model: str = "llama3.2-vision:11b") -> tuple: - """ - Run OCR on an image using the specified model. - - Models: - - llama3.2-vision:11b: Vision LLM (default, best for handwriting) - - trocr: Microsoft TrOCR (fast for printed text) - - paddleocr: PaddleOCR + LLM hybrid (4x faster) - - donut: Document Understanding Transformer (structured documents) - - Returns: - Tuple of (ocr_text, confidence) - """ - print(f"Running OCR with model: {model}") - - # Route to appropriate OCR service based on model - if model == "paddleocr": - return await run_paddleocr_wrapper(image_data, filename) - elif model == "donut": - return await run_donut_wrapper(image_data, filename) - elif model == "trocr": - return await run_trocr_wrapper(image_data, filename) - else: - # Default: Vision LLM (llama3.2-vision or similar) - return await run_vision_ocr_wrapper(image_data, filename) - - -async def run_vision_ocr_wrapper(image_data: bytes, filename: str) -> tuple: - """Vision LLM OCR wrapper.""" - if not VISION_OCR_AVAILABLE: - print("Vision OCR service not available") - return None, 0.0 - - try: - service = get_vision_ocr_service() - if not await service.is_available(): - print("Vision OCR service not available (is_available check failed)") - return None, 0.0 - - result = await service.extract_text( - image_data, - filename=filename, - is_handwriting=True - ) - return result.text, result.confidence - except Exception as e: - print(f"Vision OCR failed: {e}") - return None, 0.0 - - -async def run_paddleocr_wrapper(image_data: bytes, filename: str) -> tuple: - """PaddleOCR wrapper - uses hybrid_vocab_extractor.""" - if not PADDLEOCR_AVAILABLE: - print("PaddleOCR not available, falling back to Vision OCR") - return await run_vision_ocr_wrapper(image_data, filename) - - try: - # run_paddle_ocr returns (regions, raw_text) - regions, raw_text = run_paddle_ocr(image_data) - - if not raw_text: - print("PaddleOCR returned empty text") - return None, 0.0 - - # Calculate average confidence from regions - if regions: - avg_confidence = sum(r.confidence for r in regions) / len(regions) - else: - avg_confidence = 0.5 - - return raw_text, avg_confidence - except Exception as e: - print(f"PaddleOCR failed: {e}, falling back to Vision OCR") - return await run_vision_ocr_wrapper(image_data, filename) - - -async def run_trocr_wrapper(image_data: bytes, filename: str) -> tuple: - """TrOCR wrapper.""" - if not TROCR_AVAILABLE: - print("TrOCR not available, falling back to Vision OCR") - return await run_vision_ocr_wrapper(image_data, filename) - - try: - text, confidence = await run_trocr_ocr(image_data) - return text, confidence - except Exception as e: - print(f"TrOCR failed: {e}, falling back to Vision OCR") - return await run_vision_ocr_wrapper(image_data, filename) - - -async def run_donut_wrapper(image_data: bytes, filename: str) -> tuple: - """Donut OCR wrapper.""" - if not DONUT_AVAILABLE: - print("Donut not available, falling back to Vision OCR") - return await run_vision_ocr_wrapper(image_data, filename) - - try: - text, confidence = await run_donut_ocr(image_data) - return text, confidence - except Exception as e: - print(f"Donut OCR failed: {e}, falling back to Vision OCR") - return await run_vision_ocr_wrapper(image_data, filename) - - -def save_image_locally(session_id: str, item_id: str, image_data: bytes, extension: str = "png") -> str: - """Save image to local storage.""" - session_dir = os.path.join(LOCAL_STORAGE_PATH, session_id) - os.makedirs(session_dir, exist_ok=True) - - filename = f"{item_id}.{extension}" - filepath = os.path.join(session_dir, filename) - - with open(filepath, 'wb') as f: - f.write(image_data) - - return filepath - - -def get_image_url(image_path: str) -> str: - """Get URL for an image.""" - # For local images, return a relative path that the frontend can use - if image_path.startswith(LOCAL_STORAGE_PATH): - relative_path = image_path[len(LOCAL_STORAGE_PATH):].lstrip('/') - return f"/api/v1/ocr-label/images/{relative_path}" - # For MinIO images, the path is already a URL or key - return image_path - - -# ============================================================================= -# API Endpoints -# ============================================================================= - -@router.post("/sessions", response_model=SessionResponse) -async def create_session(session: SessionCreate): - """ - Create a new OCR labeling session. - - A session groups related images for labeling (e.g., all scans from one class). - """ - session_id = str(uuid.uuid4()) - - success = await create_ocr_labeling_session( - session_id=session_id, - name=session.name, - source_type=session.source_type, - description=session.description, - ocr_model=session.ocr_model, - ) - - if not success: - raise HTTPException(status_code=500, detail="Failed to create session") - - return SessionResponse( - id=session_id, - name=session.name, - source_type=session.source_type, - description=session.description, - ocr_model=session.ocr_model, - total_items=0, - labeled_items=0, - confirmed_items=0, - corrected_items=0, - skipped_items=0, - created_at=datetime.utcnow(), - ) - - -@router.get("/sessions", response_model=List[SessionResponse]) -async def list_sessions(limit: int = Query(50, ge=1, le=100)): - """List all OCR labeling sessions.""" - sessions = await get_ocr_labeling_sessions(limit=limit) - - return [ - SessionResponse( - id=s['id'], - name=s['name'], - source_type=s['source_type'], - description=s.get('description'), - ocr_model=s.get('ocr_model'), - total_items=s.get('total_items', 0), - labeled_items=s.get('labeled_items', 0), - confirmed_items=s.get('confirmed_items', 0), - corrected_items=s.get('corrected_items', 0), - skipped_items=s.get('skipped_items', 0), - created_at=s.get('created_at', datetime.utcnow()), - ) - for s in sessions - ] - - -@router.get("/sessions/{session_id}", response_model=SessionResponse) -async def get_session(session_id: str): - """Get a specific OCR labeling session.""" - session = await get_ocr_labeling_session(session_id) - - if not session: - raise HTTPException(status_code=404, detail="Session not found") - - return SessionResponse( - id=session['id'], - name=session['name'], - source_type=session['source_type'], - description=session.get('description'), - ocr_model=session.get('ocr_model'), - total_items=session.get('total_items', 0), - labeled_items=session.get('labeled_items', 0), - confirmed_items=session.get('confirmed_items', 0), - corrected_items=session.get('corrected_items', 0), - skipped_items=session.get('skipped_items', 0), - created_at=session.get('created_at', datetime.utcnow()), - ) - - -@router.post("/sessions/{session_id}/upload") -async def upload_images( - session_id: str, - background_tasks: BackgroundTasks, - files: List[UploadFile] = File(...), - run_ocr: bool = Form(True), - metadata: Optional[str] = Form(None), # JSON string -): - """ - Upload images to a labeling session. - - Args: - session_id: Session to add images to - files: Image files to upload (PNG, JPG, PDF) - run_ocr: Whether to run OCR immediately (default: True) - metadata: Optional JSON metadata (subject, year, etc.) - """ - import json - - # Verify session exists - session = await get_ocr_labeling_session(session_id) - if not session: - raise HTTPException(status_code=404, detail="Session not found") - - # Parse metadata - meta_dict = None - if metadata: - try: - meta_dict = json.loads(metadata) - except json.JSONDecodeError: - meta_dict = {"raw": metadata} - - results = [] - ocr_model = session.get('ocr_model', 'llama3.2-vision:11b') - - for file in files: - # Read file content - content = await file.read() - - # Compute hash for deduplication - image_hash = compute_image_hash(content) - - # Generate item ID - item_id = str(uuid.uuid4()) - - # Determine file extension - extension = file.filename.split('.')[-1].lower() if file.filename else 'png' - if extension not in ['png', 'jpg', 'jpeg', 'pdf']: - extension = 'png' - - # Save image - if MINIO_AVAILABLE: - # Upload to MinIO - try: - image_path = upload_ocr_image(session_id, item_id, content, extension) - except Exception as e: - print(f"MinIO upload failed, using local storage: {e}") - image_path = save_image_locally(session_id, item_id, content, extension) - else: - # Save locally - image_path = save_image_locally(session_id, item_id, content, extension) - - # Run OCR if requested - ocr_text = None - ocr_confidence = None - - if run_ocr and extension != 'pdf': # Skip OCR for PDFs for now - ocr_text, ocr_confidence = await run_ocr_on_image( - content, - file.filename or f"{item_id}.{extension}", - model=ocr_model - ) - - # Add to database - success = await add_ocr_labeling_item( - item_id=item_id, - session_id=session_id, - image_path=image_path, - image_hash=image_hash, - ocr_text=ocr_text, - ocr_confidence=ocr_confidence, - ocr_model=ocr_model if ocr_text else None, - metadata=meta_dict, - ) - - if success: - results.append({ - "id": item_id, - "filename": file.filename, - "image_path": image_path, - "image_hash": image_hash, - "ocr_text": ocr_text, - "ocr_confidence": ocr_confidence, - "status": "pending", - }) - - return { - "session_id": session_id, - "uploaded_count": len(results), - "items": results, - } - - -@router.get("/queue", response_model=List[ItemResponse]) -async def get_labeling_queue( - session_id: Optional[str] = Query(None), - status: str = Query("pending"), - limit: int = Query(10, ge=1, le=50), -): - """ - Get items from the labeling queue. - - Args: - session_id: Optional filter by session - status: Filter by status (pending, confirmed, corrected, skipped) - limit: Number of items to return - """ - items = await get_ocr_labeling_queue( - session_id=session_id, - status=status, - limit=limit, - ) - - return [ - ItemResponse( - id=item['id'], - session_id=item['session_id'], - session_name=item.get('session_name', ''), - image_path=item['image_path'], - image_url=get_image_url(item['image_path']), - ocr_text=item.get('ocr_text'), - ocr_confidence=item.get('ocr_confidence'), - ground_truth=item.get('ground_truth'), - status=item.get('status', 'pending'), - metadata=item.get('metadata'), - created_at=item.get('created_at', datetime.utcnow()), - ) - for item in items - ] - - -@router.get("/items/{item_id}", response_model=ItemResponse) -async def get_item(item_id: str): - """Get a specific labeling item.""" - item = await get_ocr_labeling_item(item_id) - - if not item: - raise HTTPException(status_code=404, detail="Item not found") - - return ItemResponse( - id=item['id'], - session_id=item['session_id'], - session_name=item.get('session_name', ''), - image_path=item['image_path'], - image_url=get_image_url(item['image_path']), - ocr_text=item.get('ocr_text'), - ocr_confidence=item.get('ocr_confidence'), - ground_truth=item.get('ground_truth'), - status=item.get('status', 'pending'), - metadata=item.get('metadata'), - created_at=item.get('created_at', datetime.utcnow()), - ) - - -@router.post("/confirm") -async def confirm_item(request: ConfirmRequest): - """ - Confirm that OCR text is correct. - - Sets ground_truth = ocr_text and marks item as confirmed. - """ - success = await confirm_ocr_label( - item_id=request.item_id, - labeled_by="admin", # TODO: Get from auth - label_time_seconds=request.label_time_seconds, - ) - - if not success: - raise HTTPException(status_code=400, detail="Failed to confirm item") - - return {"status": "confirmed", "item_id": request.item_id} - - -@router.post("/correct") -async def correct_item(request: CorrectRequest): - """ - Save corrected ground truth for an item. - - Use this when OCR text is wrong and needs manual correction. - """ - success = await correct_ocr_label( - item_id=request.item_id, - ground_truth=request.ground_truth, - labeled_by="admin", # TODO: Get from auth - label_time_seconds=request.label_time_seconds, - ) - - if not success: - raise HTTPException(status_code=400, detail="Failed to correct item") - - return {"status": "corrected", "item_id": request.item_id} - - -@router.post("/skip") -async def skip_item(request: SkipRequest): - """ - Skip an item (unusable image, etc.). - - Skipped items are not included in training exports. - """ - success = await skip_ocr_item( - item_id=request.item_id, - labeled_by="admin", # TODO: Get from auth - ) - - if not success: - raise HTTPException(status_code=400, detail="Failed to skip item") - - return {"status": "skipped", "item_id": request.item_id} - - -@router.get("/stats") -async def get_stats(session_id: Optional[str] = Query(None)): - """ - Get labeling statistics. - - Args: - session_id: Optional session ID for session-specific stats - """ - stats = await get_ocr_labeling_stats(session_id=session_id) - - if "error" in stats: - raise HTTPException(status_code=500, detail=stats["error"]) - - return stats - - -@router.post("/export") -async def export_data(request: ExportRequest): - """ - Export labeled data for training. - - Formats: - - generic: JSONL with image_path and ground_truth - - trocr: Format for TrOCR/Microsoft Transformer fine-tuning - - llama_vision: Format for llama3.2-vision fine-tuning - - Exports are saved to disk at /app/ocr-exports/{format}/{batch_id}/ - """ - # First, get samples from database - db_samples = await export_training_samples( - export_format=request.export_format, - session_id=request.session_id, - batch_id=request.batch_id, - exported_by="admin", # TODO: Get from auth - ) - - if not db_samples: - return { - "export_format": request.export_format, - "batch_id": request.batch_id, - "exported_count": 0, - "samples": [], - "message": "No labeled samples found to export", - } - - # If training export service is available, also write to disk - export_result = None - if TRAINING_EXPORT_AVAILABLE: - try: - export_service = get_training_export_service() - - # Convert DB samples to TrainingSample objects - training_samples = [] - for s in db_samples: - training_samples.append(TrainingSample( - id=s.get('id', s.get('item_id', '')), - image_path=s.get('image_path', ''), - ground_truth=s.get('ground_truth', ''), - ocr_text=s.get('ocr_text'), - ocr_confidence=s.get('ocr_confidence'), - metadata=s.get('metadata'), - )) - - # Export to files - export_result = export_service.export( - samples=training_samples, - export_format=request.export_format, - batch_id=request.batch_id, - ) - except Exception as e: - print(f"Training export failed: {e}") - # Continue without file export - - response = { - "export_format": request.export_format, - "batch_id": request.batch_id or (export_result.batch_id if export_result else None), - "exported_count": len(db_samples), - "samples": db_samples, - } - - if export_result: - response["export_path"] = export_result.export_path - response["manifest_path"] = export_result.manifest_path - - return response - - -@router.get("/training-samples") -async def list_training_samples( - export_format: Optional[str] = Query(None), - batch_id: Optional[str] = Query(None), - limit: int = Query(100, ge=1, le=1000), -): - """Get exported training samples.""" - samples = await get_training_samples( - export_format=export_format, - batch_id=batch_id, - limit=limit, - ) - - return { - "count": len(samples), - "samples": samples, - } - - -@router.get("/images/{path:path}") -async def get_image(path: str): - """ - Serve an image from local storage. - - This endpoint is used when images are stored locally (not in MinIO). - """ - from fastapi.responses import FileResponse - - filepath = os.path.join(LOCAL_STORAGE_PATH, path) - - if not os.path.exists(filepath): - raise HTTPException(status_code=404, detail="Image not found") - - # Determine content type - extension = filepath.split('.')[-1].lower() - content_type = { - 'png': 'image/png', - 'jpg': 'image/jpeg', - 'jpeg': 'image/jpeg', - 'pdf': 'application/pdf', - }.get(extension, 'application/octet-stream') - - return FileResponse(filepath, media_type=content_type) - - -@router.post("/run-ocr/{item_id}") -async def run_ocr_for_item(item_id: str): - """ - Run OCR on an existing item. - - Use this to re-run OCR or run it if it was skipped during upload. - """ - item = await get_ocr_labeling_item(item_id) - - if not item: - raise HTTPException(status_code=404, detail="Item not found") - - # Load image - image_path = item['image_path'] - - if image_path.startswith(LOCAL_STORAGE_PATH): - # Load from local storage - if not os.path.exists(image_path): - raise HTTPException(status_code=404, detail="Image file not found") - with open(image_path, 'rb') as f: - image_data = f.read() - elif MINIO_AVAILABLE: - # Load from MinIO - try: - image_data = get_ocr_image(image_path) - except Exception as e: - raise HTTPException(status_code=500, detail=f"Failed to load image: {e}") - else: - raise HTTPException(status_code=500, detail="Cannot load image") - - # Get OCR model from session - session = await get_ocr_labeling_session(item['session_id']) - ocr_model = session.get('ocr_model', 'llama3.2-vision:11b') if session else 'llama3.2-vision:11b' - - # Run OCR - ocr_text, ocr_confidence = await run_ocr_on_image( - image_data, - os.path.basename(image_path), - model=ocr_model - ) - - if ocr_text is None: - raise HTTPException(status_code=500, detail="OCR failed") - - # Update item in database - from metrics_db import get_pool - pool = await get_pool() - if pool: - async with pool.acquire() as conn: - await conn.execute( - """ - UPDATE ocr_labeling_items - SET ocr_text = $2, ocr_confidence = $3, ocr_model = $4 - WHERE id = $1 - """, - item_id, ocr_text, ocr_confidence, ocr_model - ) - - return { - "item_id": item_id, - "ocr_text": ocr_text, - "ocr_confidence": ocr_confidence, - "ocr_model": ocr_model, - } - - -@router.get("/exports") -async def list_exports(export_format: Optional[str] = Query(None)): - """ - List all available training data exports. - - Args: - export_format: Optional filter by format (generic, trocr, llama_vision) - - Returns: - List of export manifests with paths and metadata - """ - if not TRAINING_EXPORT_AVAILABLE: - return { - "exports": [], - "message": "Training export service not available", - } - - try: - export_service = get_training_export_service() - exports = export_service.list_exports(export_format=export_format) - - return { - "count": len(exports), - "exports": exports, - } - except Exception as e: - raise HTTPException(status_code=500, detail=f"Failed to list exports: {e}") + pass + +try: + from hybrid_vocab_extractor import run_paddle_ocr # noqa: F401 +except ImportError: + pass + +try: + from services.trocr_service import run_trocr_ocr # noqa: F401 +except ImportError: + pass + +try: + from services.donut_ocr_service import run_donut_ocr # noqa: F401 +except ImportError: + pass + +try: + from vision_ocr_service import get_vision_ocr_service, VisionOCRService # noqa: F401 +except ImportError: + pass + +# Routes (router is the main export for app.include_router) +from ocr_labeling_routes import router # noqa: F401 +from ocr_labeling_upload_routes import router as upload_router # noqa: F401 diff --git a/klausur-service/backend/ocr_labeling_helpers.py b/klausur-service/backend/ocr_labeling_helpers.py new file mode 100644 index 0000000..5188670 --- /dev/null +++ b/klausur-service/backend/ocr_labeling_helpers.py @@ -0,0 +1,205 @@ +""" +OCR Labeling - Helper Functions and OCR Wrappers + +Extracted from ocr_labeling_api.py to keep files under 500 LOC. + +DATENSCHUTZ/PRIVACY: +- Alle Verarbeitung erfolgt lokal (Mac Mini mit Ollama) +- Keine Daten werden an externe Server gesendet +""" + +import os +import hashlib + +from ocr_labeling_models import LOCAL_STORAGE_PATH + +# Try to import Vision OCR service +try: + import sys + sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..', 'backend', 'klausur', 'services')) + from vision_ocr_service import get_vision_ocr_service, VisionOCRService + VISION_OCR_AVAILABLE = True +except ImportError: + VISION_OCR_AVAILABLE = False + print("Warning: Vision OCR service not available") + +# Try to import PaddleOCR from hybrid_vocab_extractor +try: + from hybrid_vocab_extractor import run_paddle_ocr + PADDLEOCR_AVAILABLE = True +except ImportError: + PADDLEOCR_AVAILABLE = False + print("Warning: PaddleOCR not available") + +# Try to import TrOCR service +try: + from services.trocr_service import run_trocr_ocr + TROCR_AVAILABLE = True +except ImportError: + TROCR_AVAILABLE = False + print("Warning: TrOCR service not available") + +# Try to import Donut service +try: + from services.donut_ocr_service import run_donut_ocr + DONUT_AVAILABLE = True +except ImportError: + DONUT_AVAILABLE = False + print("Warning: Donut OCR service not available") + +# Try to import MinIO storage +try: + from minio_storage import upload_ocr_image, get_ocr_image, MINIO_BUCKET + MINIO_AVAILABLE = True +except ImportError: + MINIO_AVAILABLE = False + print("Warning: MinIO storage not available, using local storage") + +# Try to import Training Export Service +try: + from training_export_service import ( + TrainingExportService, + TrainingSample, + get_training_export_service, + ) + TRAINING_EXPORT_AVAILABLE = True +except ImportError: + TRAINING_EXPORT_AVAILABLE = False + print("Warning: Training export service not available") + + +# ============================================================================= +# Helper Functions +# ============================================================================= + +def compute_image_hash(image_data: bytes) -> str: + """Compute SHA256 hash of image data.""" + return hashlib.sha256(image_data).hexdigest() + + +async def run_ocr_on_image(image_data: bytes, filename: str, model: str = "llama3.2-vision:11b") -> tuple: + """ + Run OCR on an image using the specified model. + + Models: + - llama3.2-vision:11b: Vision LLM (default, best for handwriting) + - trocr: Microsoft TrOCR (fast for printed text) + - paddleocr: PaddleOCR + LLM hybrid (4x faster) + - donut: Document Understanding Transformer (structured documents) + + Returns: + Tuple of (ocr_text, confidence) + """ + print(f"Running OCR with model: {model}") + + # Route to appropriate OCR service based on model + if model == "paddleocr": + return await run_paddleocr_wrapper(image_data, filename) + elif model == "donut": + return await run_donut_wrapper(image_data, filename) + elif model == "trocr": + return await run_trocr_wrapper(image_data, filename) + else: + # Default: Vision LLM (llama3.2-vision or similar) + return await run_vision_ocr_wrapper(image_data, filename) + + +async def run_vision_ocr_wrapper(image_data: bytes, filename: str) -> tuple: + """Vision LLM OCR wrapper.""" + if not VISION_OCR_AVAILABLE: + print("Vision OCR service not available") + return None, 0.0 + + try: + service = get_vision_ocr_service() + if not await service.is_available(): + print("Vision OCR service not available (is_available check failed)") + return None, 0.0 + + result = await service.extract_text( + image_data, + filename=filename, + is_handwriting=True + ) + return result.text, result.confidence + except Exception as e: + print(f"Vision OCR failed: {e}") + return None, 0.0 + + +async def run_paddleocr_wrapper(image_data: bytes, filename: str) -> tuple: + """PaddleOCR wrapper - uses hybrid_vocab_extractor.""" + if not PADDLEOCR_AVAILABLE: + print("PaddleOCR not available, falling back to Vision OCR") + return await run_vision_ocr_wrapper(image_data, filename) + + try: + # run_paddle_ocr returns (regions, raw_text) + regions, raw_text = run_paddle_ocr(image_data) + + if not raw_text: + print("PaddleOCR returned empty text") + return None, 0.0 + + # Calculate average confidence from regions + if regions: + avg_confidence = sum(r.confidence for r in regions) / len(regions) + else: + avg_confidence = 0.5 + + return raw_text, avg_confidence + except Exception as e: + print(f"PaddleOCR failed: {e}, falling back to Vision OCR") + return await run_vision_ocr_wrapper(image_data, filename) + + +async def run_trocr_wrapper(image_data: bytes, filename: str) -> tuple: + """TrOCR wrapper.""" + if not TROCR_AVAILABLE: + print("TrOCR not available, falling back to Vision OCR") + return await run_vision_ocr_wrapper(image_data, filename) + + try: + text, confidence = await run_trocr_ocr(image_data) + return text, confidence + except Exception as e: + print(f"TrOCR failed: {e}, falling back to Vision OCR") + return await run_vision_ocr_wrapper(image_data, filename) + + +async def run_donut_wrapper(image_data: bytes, filename: str) -> tuple: + """Donut OCR wrapper.""" + if not DONUT_AVAILABLE: + print("Donut not available, falling back to Vision OCR") + return await run_vision_ocr_wrapper(image_data, filename) + + try: + text, confidence = await run_donut_ocr(image_data) + return text, confidence + except Exception as e: + print(f"Donut OCR failed: {e}, falling back to Vision OCR") + return await run_vision_ocr_wrapper(image_data, filename) + + +def save_image_locally(session_id: str, item_id: str, image_data: bytes, extension: str = "png") -> str: + """Save image to local storage.""" + session_dir = os.path.join(LOCAL_STORAGE_PATH, session_id) + os.makedirs(session_dir, exist_ok=True) + + filename = f"{item_id}.{extension}" + filepath = os.path.join(session_dir, filename) + + with open(filepath, 'wb') as f: + f.write(image_data) + + return filepath + + +def get_image_url(image_path: str) -> str: + """Get URL for an image.""" + # For local images, return a relative path that the frontend can use + if image_path.startswith(LOCAL_STORAGE_PATH): + relative_path = image_path[len(LOCAL_STORAGE_PATH):].lstrip('/') + return f"/api/v1/ocr-label/images/{relative_path}" + # For MinIO images, the path is already a URL or key + return image_path diff --git a/klausur-service/backend/ocr_labeling_models.py b/klausur-service/backend/ocr_labeling_models.py new file mode 100644 index 0000000..f27601f --- /dev/null +++ b/klausur-service/backend/ocr_labeling_models.py @@ -0,0 +1,86 @@ +""" +OCR Labeling - Pydantic Models and Constants + +Extracted from ocr_labeling_api.py to keep files under 500 LOC. +""" + +import os +from pydantic import BaseModel +from typing import Optional, Dict +from datetime import datetime + + +# Local storage path (fallback if MinIO not available) +LOCAL_STORAGE_PATH = os.getenv("OCR_STORAGE_PATH", "/app/ocr-labeling") + + +# ============================================================================= +# Pydantic Models +# ============================================================================= + +class SessionCreate(BaseModel): + name: str + source_type: str = "klausur" # klausur, handwriting_sample, scan + description: Optional[str] = None + ocr_model: Optional[str] = "llama3.2-vision:11b" + + +class SessionResponse(BaseModel): + id: str + name: str + source_type: str + description: Optional[str] + ocr_model: Optional[str] + total_items: int + labeled_items: int + confirmed_items: int + corrected_items: int + skipped_items: int + created_at: datetime + + +class ItemResponse(BaseModel): + id: str + session_id: str + session_name: str + image_path: str + image_url: Optional[str] + ocr_text: Optional[str] + ocr_confidence: Optional[float] + ground_truth: Optional[str] + status: str + metadata: Optional[Dict] + created_at: datetime + + +class ConfirmRequest(BaseModel): + item_id: str + label_time_seconds: Optional[int] = None + + +class CorrectRequest(BaseModel): + item_id: str + ground_truth: str + label_time_seconds: Optional[int] = None + + +class SkipRequest(BaseModel): + item_id: str + + +class ExportRequest(BaseModel): + export_format: str = "generic" # generic, trocr, llama_vision + session_id: Optional[str] = None + batch_id: Optional[str] = None + + +class StatsResponse(BaseModel): + total_sessions: Optional[int] = None + total_items: int + labeled_items: int + confirmed_items: int + corrected_items: int + pending_items: int + exportable_items: Optional[int] = None + accuracy_rate: float + avg_label_time_seconds: Optional[float] = None diff --git a/klausur-service/backend/ocr_labeling_routes.py b/klausur-service/backend/ocr_labeling_routes.py new file mode 100644 index 0000000..b2365da --- /dev/null +++ b/klausur-service/backend/ocr_labeling_routes.py @@ -0,0 +1,241 @@ +""" +OCR Labeling - Session and Labeling Route Handlers + +Extracted from ocr_labeling_api.py to keep files under 500 LOC. + +Endpoints: +- POST /sessions - Create labeling session +- GET /sessions - List sessions +- GET /sessions/{id} - Get session +- GET /queue - Get labeling queue +- GET /items/{id} - Get item +- POST /confirm - Confirm OCR +- POST /correct - Correct ground truth +- POST /skip - Skip item +- GET /stats - Get statistics +""" + +from fastapi import APIRouter, HTTPException, Query +from typing import Optional, List +from datetime import datetime +import uuid + +from metrics_db import ( + create_ocr_labeling_session, + get_ocr_labeling_sessions, + get_ocr_labeling_session, + get_ocr_labeling_queue, + get_ocr_labeling_item, + confirm_ocr_label, + correct_ocr_label, + skip_ocr_item, + get_ocr_labeling_stats, +) + +from ocr_labeling_models import ( + SessionCreate, SessionResponse, ItemResponse, + ConfirmRequest, CorrectRequest, SkipRequest, +) +from ocr_labeling_helpers import get_image_url + + +router = APIRouter(prefix="/api/v1/ocr-label", tags=["OCR Labeling"]) + + +# ============================================================================= +# Session Endpoints +# ============================================================================= + +@router.post("/sessions", response_model=SessionResponse) +async def create_session(session: SessionCreate): + """Create a new OCR labeling session.""" + session_id = str(uuid.uuid4()) + + success = await create_ocr_labeling_session( + session_id=session_id, + name=session.name, + source_type=session.source_type, + description=session.description, + ocr_model=session.ocr_model, + ) + + if not success: + raise HTTPException(status_code=500, detail="Failed to create session") + + return SessionResponse( + id=session_id, + name=session.name, + source_type=session.source_type, + description=session.description, + ocr_model=session.ocr_model, + total_items=0, + labeled_items=0, + confirmed_items=0, + corrected_items=0, + skipped_items=0, + created_at=datetime.utcnow(), + ) + + +@router.get("/sessions", response_model=List[SessionResponse]) +async def list_sessions(limit: int = Query(50, ge=1, le=100)): + """List all OCR labeling sessions.""" + sessions = await get_ocr_labeling_sessions(limit=limit) + + return [ + SessionResponse( + id=s['id'], + name=s['name'], + source_type=s['source_type'], + description=s.get('description'), + ocr_model=s.get('ocr_model'), + total_items=s.get('total_items', 0), + labeled_items=s.get('labeled_items', 0), + confirmed_items=s.get('confirmed_items', 0), + corrected_items=s.get('corrected_items', 0), + skipped_items=s.get('skipped_items', 0), + created_at=s.get('created_at', datetime.utcnow()), + ) + for s in sessions + ] + + +@router.get("/sessions/{session_id}", response_model=SessionResponse) +async def get_session(session_id: str): + """Get a specific OCR labeling session.""" + session = await get_ocr_labeling_session(session_id) + + if not session: + raise HTTPException(status_code=404, detail="Session not found") + + return SessionResponse( + id=session['id'], + name=session['name'], + source_type=session['source_type'], + description=session.get('description'), + ocr_model=session.get('ocr_model'), + total_items=session.get('total_items', 0), + labeled_items=session.get('labeled_items', 0), + confirmed_items=session.get('confirmed_items', 0), + corrected_items=session.get('corrected_items', 0), + skipped_items=session.get('skipped_items', 0), + created_at=session.get('created_at', datetime.utcnow()), + ) + + +# ============================================================================= +# Queue and Item Endpoints +# ============================================================================= + +@router.get("/queue", response_model=List[ItemResponse]) +async def get_labeling_queue( + session_id: Optional[str] = Query(None), + status: str = Query("pending"), + limit: int = Query(10, ge=1, le=50), +): + """Get items from the labeling queue.""" + items = await get_ocr_labeling_queue( + session_id=session_id, + status=status, + limit=limit, + ) + + return [ + ItemResponse( + id=item['id'], + session_id=item['session_id'], + session_name=item.get('session_name', ''), + image_path=item['image_path'], + image_url=get_image_url(item['image_path']), + ocr_text=item.get('ocr_text'), + ocr_confidence=item.get('ocr_confidence'), + ground_truth=item.get('ground_truth'), + status=item.get('status', 'pending'), + metadata=item.get('metadata'), + created_at=item.get('created_at', datetime.utcnow()), + ) + for item in items + ] + + +@router.get("/items/{item_id}", response_model=ItemResponse) +async def get_item(item_id: str): + """Get a specific labeling item.""" + item = await get_ocr_labeling_item(item_id) + + if not item: + raise HTTPException(status_code=404, detail="Item not found") + + return ItemResponse( + id=item['id'], + session_id=item['session_id'], + session_name=item.get('session_name', ''), + image_path=item['image_path'], + image_url=get_image_url(item['image_path']), + ocr_text=item.get('ocr_text'), + ocr_confidence=item.get('ocr_confidence'), + ground_truth=item.get('ground_truth'), + status=item.get('status', 'pending'), + metadata=item.get('metadata'), + created_at=item.get('created_at', datetime.utcnow()), + ) + + +# ============================================================================= +# Labeling Action Endpoints +# ============================================================================= + +@router.post("/confirm") +async def confirm_item(request: ConfirmRequest): + """Confirm that OCR text is correct.""" + success = await confirm_ocr_label( + item_id=request.item_id, + labeled_by="admin", + label_time_seconds=request.label_time_seconds, + ) + + if not success: + raise HTTPException(status_code=400, detail="Failed to confirm item") + + return {"status": "confirmed", "item_id": request.item_id} + + +@router.post("/correct") +async def correct_item(request: CorrectRequest): + """Save corrected ground truth for an item.""" + success = await correct_ocr_label( + item_id=request.item_id, + ground_truth=request.ground_truth, + labeled_by="admin", + label_time_seconds=request.label_time_seconds, + ) + + if not success: + raise HTTPException(status_code=400, detail="Failed to correct item") + + return {"status": "corrected", "item_id": request.item_id} + + +@router.post("/skip") +async def skip_item(request: SkipRequest): + """Skip an item (unusable image, etc.).""" + success = await skip_ocr_item( + item_id=request.item_id, + labeled_by="admin", + ) + + if not success: + raise HTTPException(status_code=400, detail="Failed to skip item") + + return {"status": "skipped", "item_id": request.item_id} + + +@router.get("/stats") +async def get_stats(session_id: Optional[str] = Query(None)): + """Get labeling statistics.""" + stats = await get_ocr_labeling_stats(session_id=session_id) + + if "error" in stats: + raise HTTPException(status_code=500, detail=stats["error"]) + + return stats diff --git a/klausur-service/backend/ocr_labeling_upload_routes.py b/klausur-service/backend/ocr_labeling_upload_routes.py new file mode 100644 index 0000000..0e8a684 --- /dev/null +++ b/klausur-service/backend/ocr_labeling_upload_routes.py @@ -0,0 +1,313 @@ +""" +OCR Labeling - Upload, Run-OCR, and Export Route Handlers + +Extracted from ocr_labeling_routes.py to keep files under 500 LOC. + +Endpoints: +- POST /sessions/{id}/upload - Upload images for labeling +- POST /run-ocr/{item_id} - Run OCR on existing item +- POST /export - Export training data +- GET /training-samples - List training samples +- GET /images/{path} - Serve images from local storage +- GET /exports - List exports +""" + +from fastapi import APIRouter, HTTPException, UploadFile, File, Form, Query +from typing import Optional, List +import uuid +import os + +from metrics_db import ( + get_ocr_labeling_session, + add_ocr_labeling_item, + get_ocr_labeling_item, + export_training_samples, + get_training_samples, +) + +from ocr_labeling_models import ( + ExportRequest, + LOCAL_STORAGE_PATH, +) +from ocr_labeling_helpers import ( + compute_image_hash, run_ocr_on_image, + save_image_locally, + MINIO_AVAILABLE, TRAINING_EXPORT_AVAILABLE, +) + +# Conditional imports +try: + from minio_storage import upload_ocr_image, get_ocr_image +except ImportError: + pass + +try: + from training_export_service import TrainingSample, get_training_export_service +except ImportError: + pass + + +router = APIRouter(prefix="/api/v1/ocr-label", tags=["OCR Labeling"]) + + +@router.post("/sessions/{session_id}/upload") +async def upload_images( + session_id: str, + files: List[UploadFile] = File(...), + run_ocr: bool = Form(True), + metadata: Optional[str] = Form(None), +): + """ + Upload images to a labeling session. + + Args: + session_id: Session to add images to + files: Image files to upload (PNG, JPG, PDF) + run_ocr: Whether to run OCR immediately (default: True) + metadata: Optional JSON metadata (subject, year, etc.) + """ + import json + + session = await get_ocr_labeling_session(session_id) + if not session: + raise HTTPException(status_code=404, detail="Session not found") + + meta_dict = None + if metadata: + try: + meta_dict = json.loads(metadata) + except json.JSONDecodeError: + meta_dict = {"raw": metadata} + + results = [] + ocr_model = session.get('ocr_model', 'llama3.2-vision:11b') + + for file in files: + content = await file.read() + image_hash = compute_image_hash(content) + item_id = str(uuid.uuid4()) + + extension = file.filename.split('.')[-1].lower() if file.filename else 'png' + if extension not in ['png', 'jpg', 'jpeg', 'pdf']: + extension = 'png' + + if MINIO_AVAILABLE: + try: + image_path = upload_ocr_image(session_id, item_id, content, extension) + except Exception as e: + print(f"MinIO upload failed, using local storage: {e}") + image_path = save_image_locally(session_id, item_id, content, extension) + else: + image_path = save_image_locally(session_id, item_id, content, extension) + + ocr_text = None + ocr_confidence = None + + if run_ocr and extension != 'pdf': + ocr_text, ocr_confidence = await run_ocr_on_image( + content, + file.filename or f"{item_id}.{extension}", + model=ocr_model + ) + + success = await add_ocr_labeling_item( + item_id=item_id, + session_id=session_id, + image_path=image_path, + image_hash=image_hash, + ocr_text=ocr_text, + ocr_confidence=ocr_confidence, + ocr_model=ocr_model if ocr_text else None, + metadata=meta_dict, + ) + + if success: + results.append({ + "id": item_id, + "filename": file.filename, + "image_path": image_path, + "image_hash": image_hash, + "ocr_text": ocr_text, + "ocr_confidence": ocr_confidence, + "status": "pending", + }) + + return { + "session_id": session_id, + "uploaded_count": len(results), + "items": results, + } + + +@router.post("/export") +async def export_data(request: ExportRequest): + """Export labeled data for training.""" + db_samples = await export_training_samples( + export_format=request.export_format, + session_id=request.session_id, + batch_id=request.batch_id, + exported_by="admin", + ) + + if not db_samples: + return { + "export_format": request.export_format, + "batch_id": request.batch_id, + "exported_count": 0, + "samples": [], + "message": "No labeled samples found to export", + } + + export_result = None + if TRAINING_EXPORT_AVAILABLE: + try: + export_service = get_training_export_service() + + training_samples = [] + for s in db_samples: + training_samples.append(TrainingSample( + id=s.get('id', s.get('item_id', '')), + image_path=s.get('image_path', ''), + ground_truth=s.get('ground_truth', ''), + ocr_text=s.get('ocr_text'), + ocr_confidence=s.get('ocr_confidence'), + metadata=s.get('metadata'), + )) + + export_result = export_service.export( + samples=training_samples, + export_format=request.export_format, + batch_id=request.batch_id, + ) + except Exception as e: + print(f"Training export failed: {e}") + + response = { + "export_format": request.export_format, + "batch_id": request.batch_id or (export_result.batch_id if export_result else None), + "exported_count": len(db_samples), + "samples": db_samples, + } + + if export_result: + response["export_path"] = export_result.export_path + response["manifest_path"] = export_result.manifest_path + + return response + + +@router.get("/training-samples") +async def list_training_samples( + export_format: Optional[str] = Query(None), + batch_id: Optional[str] = Query(None), + limit: int = Query(100, ge=1, le=1000), +): + """Get exported training samples.""" + samples = await get_training_samples( + export_format=export_format, + batch_id=batch_id, + limit=limit, + ) + + return { + "count": len(samples), + "samples": samples, + } + + +@router.get("/images/{path:path}") +async def get_image(path: str): + """Serve an image from local storage.""" + from fastapi.responses import FileResponse + + filepath = os.path.join(LOCAL_STORAGE_PATH, path) + + if not os.path.exists(filepath): + raise HTTPException(status_code=404, detail="Image not found") + + extension = filepath.split('.')[-1].lower() + content_type = { + 'png': 'image/png', + 'jpg': 'image/jpeg', + 'jpeg': 'image/jpeg', + 'pdf': 'application/pdf', + }.get(extension, 'application/octet-stream') + + return FileResponse(filepath, media_type=content_type) + + +@router.post("/run-ocr/{item_id}") +async def run_ocr_for_item(item_id: str): + """Run OCR on an existing item.""" + item = await get_ocr_labeling_item(item_id) + + if not item: + raise HTTPException(status_code=404, detail="Item not found") + + image_path = item['image_path'] + + if image_path.startswith(LOCAL_STORAGE_PATH): + if not os.path.exists(image_path): + raise HTTPException(status_code=404, detail="Image file not found") + with open(image_path, 'rb') as f: + image_data = f.read() + elif MINIO_AVAILABLE: + try: + image_data = get_ocr_image(image_path) + except Exception as e: + raise HTTPException(status_code=500, detail=f"Failed to load image: {e}") + else: + raise HTTPException(status_code=500, detail="Cannot load image") + + session = await get_ocr_labeling_session(item['session_id']) + ocr_model = session.get('ocr_model', 'llama3.2-vision:11b') if session else 'llama3.2-vision:11b' + + ocr_text, ocr_confidence = await run_ocr_on_image( + image_data, + os.path.basename(image_path), + model=ocr_model + ) + + if ocr_text is None: + raise HTTPException(status_code=500, detail="OCR failed") + + from metrics_db import get_pool + pool = await get_pool() + if pool: + async with pool.acquire() as conn: + await conn.execute( + """ + UPDATE ocr_labeling_items + SET ocr_text = $2, ocr_confidence = $3, ocr_model = $4 + WHERE id = $1 + """, + item_id, ocr_text, ocr_confidence, ocr_model + ) + + return { + "item_id": item_id, + "ocr_text": ocr_text, + "ocr_confidence": ocr_confidence, + "ocr_model": ocr_model, + } + + +@router.get("/exports") +async def list_exports(export_format: Optional[str] = Query(None)): + """List all available training data exports.""" + if not TRAINING_EXPORT_AVAILABLE: + return { + "exports": [], + "message": "Training export service not available", + } + + try: + export_service = get_training_export_service() + exports = export_service.list_exports(export_format=export_format) + + return { + "count": len(exports), + "exports": exports, + } + except Exception as e: + raise HTTPException(status_code=500, detail=f"Failed to list exports: {e}") diff --git a/klausur-service/backend/ocr_pipeline_auto.py b/klausur-service/backend/ocr_pipeline_auto.py index c85ac49..f354659 100644 --- a/klausur-service/backend/ocr_pipeline_auto.py +++ b/klausur-service/backend/ocr_pipeline_auto.py @@ -1,705 +1,23 @@ """ -OCR Pipeline Auto-Mode Orchestrator and Reprocess Endpoints. +OCR Pipeline Auto-Mode Orchestrator and Reprocess Endpoints — Barrel Re-export. -Extracted from ocr_pipeline_api.py — contains: -- POST /sessions/{session_id}/reprocess (clear downstream + restart from step) -- POST /sessions/{session_id}/run-auto (full auto-mode with SSE streaming) +Split into submodules: +- ocr_pipeline_reprocess.py — POST /sessions/{id}/reprocess +- ocr_pipeline_auto_steps.py — POST /sessions/{id}/run-auto + VLM helper Lizenz: Apache 2.0 DATENSCHUTZ: Alle Verarbeitung erfolgt lokal. """ -import json -import logging -import os -import re -import time -from dataclasses import asdict -from typing import Any, Dict, List, Optional +from fastapi import APIRouter -import cv2 -import numpy as np -from fastapi import APIRouter, HTTPException, Request -from fastapi.responses import StreamingResponse -from pydantic import BaseModel - -from cv_vocab_pipeline import ( - OLLAMA_REVIEW_MODEL, - PageRegion, - RowGeometry, - _cells_to_vocab_entries, - _detect_header_footer_gaps, - _detect_sub_columns, - _fix_character_confusion, - _fix_phonetic_brackets, - fix_cell_phonetics, - analyze_layout, - build_cell_grid, - classify_column_types, - create_layout_image, - create_ocr_image, - deskew_image, - deskew_image_by_word_alignment, - detect_column_geometry, - detect_row_geometry, - _apply_shear, - dewarp_image, - llm_review_entries, -) -from ocr_pipeline_common import ( - _cache, - _load_session_to_cache, - _get_cached, - _get_base_image_png, - _append_pipeline_log, -) -from ocr_pipeline_session_store import ( - get_session_db, - update_session_db, -) - -logger = logging.getLogger(__name__) +from ocr_pipeline_reprocess import router as _reprocess_router +from ocr_pipeline_auto_steps import router as _steps_router +# Combine both sub-routers into a single router for backwards compatibility. +# The consumer imports `from ocr_pipeline_auto import router as _auto_router`. router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"]) +router.include_router(_reprocess_router) +router.include_router(_steps_router) - -# --------------------------------------------------------------------------- -# Reprocess endpoint -# --------------------------------------------------------------------------- - -@router.post("/sessions/{session_id}/reprocess") -async def reprocess_session(session_id: str, request: Request): - """Re-run pipeline from a specific step, clearing downstream data. - - Body: {"from_step": 5} (1-indexed step number) - - Pipeline order: Orientation(1) → Deskew(2) → Dewarp(3) → Crop(4) → Columns(5) → - Rows(6) → Words(7) → LLM-Review(8) → Reconstruction(9) → Validation(10) - - Clears downstream results: - - from_step <= 1: orientation_result + all downstream - - from_step <= 2: deskew_result + all downstream - - from_step <= 3: dewarp_result + all downstream - - from_step <= 4: crop_result + all downstream - - from_step <= 5: column_result, row_result, word_result - - from_step <= 6: row_result, word_result - - from_step <= 7: word_result (cells, vocab_entries) - - from_step <= 8: word_result.llm_review only - """ - session = await get_session_db(session_id) - if not session: - raise HTTPException(status_code=404, detail=f"Session {session_id} not found") - - body = await request.json() - from_step = body.get("from_step", 1) - if not isinstance(from_step, int) or from_step < 1 or from_step > 10: - raise HTTPException(status_code=400, detail="from_step must be between 1 and 10") - - update_kwargs: Dict[str, Any] = {"current_step": from_step} - - # Clear downstream data based on from_step - # New pipeline order: Orient(2) → Deskew(3) → Dewarp(4) → Crop(5) → - # Columns(6) → Rows(7) → Words(8) → LLM(9) → Recon(10) → GT(11) - if from_step <= 8: - update_kwargs["word_result"] = None - elif from_step == 9: - # Only clear LLM review from word_result - word_result = session.get("word_result") - if word_result: - word_result.pop("llm_review", None) - word_result.pop("llm_corrections", None) - update_kwargs["word_result"] = word_result - - if from_step <= 7: - update_kwargs["row_result"] = None - if from_step <= 6: - update_kwargs["column_result"] = None - if from_step <= 4: - update_kwargs["crop_result"] = None - if from_step <= 3: - update_kwargs["dewarp_result"] = None - if from_step <= 2: - update_kwargs["deskew_result"] = None - if from_step <= 1: - update_kwargs["orientation_result"] = None - - await update_session_db(session_id, **update_kwargs) - - # Also clear cache - if session_id in _cache: - for key in list(update_kwargs.keys()): - if key != "current_step": - _cache[session_id][key] = update_kwargs[key] - _cache[session_id]["current_step"] = from_step - - logger.info(f"Session {session_id} reprocessing from step {from_step}") - - return { - "session_id": session_id, - "from_step": from_step, - "cleared": [k for k in update_kwargs if k != "current_step"], - } - - -# --------------------------------------------------------------------------- -# VLM shear detection helper (used by dewarp step in auto-mode) -# --------------------------------------------------------------------------- - -async def _detect_shear_with_vlm(image_bytes: bytes) -> Dict[str, Any]: - """Ask qwen2.5vl:32b to estimate the vertical shear angle of a scanned page. - - The VLM is shown the image and asked: are the column/table borders tilted? - If yes, by how many degrees? Returns a dict with shear_degrees and confidence. - Confidence is 0.0 if Ollama is unavailable or parsing fails. - """ - import httpx - import base64 - import re - - ollama_base = os.getenv("OLLAMA_BASE_URL", "http://host.docker.internal:11434") - model = os.getenv("OLLAMA_HTR_MODEL", "qwen2.5vl:32b") - - prompt = ( - "This is a scanned vocabulary worksheet. Look at the vertical borders of the table columns. " - "Are they perfectly vertical, or do they tilt slightly? " - "If they tilt, estimate the tilt angle in degrees (positive = top tilts right, negative = top tilts left). " - "Reply with ONLY a JSON object like: {\"shear_degrees\": 1.2, \"confidence\": 0.8} " - "Use confidence 0.0-1.0 based on how clearly you can see the tilt. " - "If the columns look straight, return {\"shear_degrees\": 0.0, \"confidence\": 0.9}" - ) - - img_b64 = base64.b64encode(image_bytes).decode("utf-8") - payload = { - "model": model, - "prompt": prompt, - "images": [img_b64], - "stream": False, - } - - try: - async with httpx.AsyncClient(timeout=60.0) as client: - resp = await client.post(f"{ollama_base}/api/generate", json=payload) - resp.raise_for_status() - text = resp.json().get("response", "") - - # Parse JSON from response (may have surrounding text) - match = re.search(r'\{[^}]+\}', text) - if match: - import json - data = json.loads(match.group(0)) - shear = float(data.get("shear_degrees", 0.0)) - conf = float(data.get("confidence", 0.0)) - # Clamp to reasonable range - shear = max(-3.0, min(3.0, shear)) - conf = max(0.0, min(1.0, conf)) - return {"method": "vlm_qwen2.5vl", "shear_degrees": round(shear, 3), "confidence": round(conf, 2)} - except Exception as e: - logger.warning(f"VLM dewarp failed: {e}") - - return {"method": "vlm_qwen2.5vl", "shear_degrees": 0.0, "confidence": 0.0} - - -# --------------------------------------------------------------------------- -# Auto-mode orchestrator -# --------------------------------------------------------------------------- - -class RunAutoRequest(BaseModel): - from_step: int = 1 # 1=deskew, 2=dewarp, 3=columns, 4=rows, 5=words, 6=llm-review - ocr_engine: str = "auto" # "auto" | "rapid" | "tesseract" - pronunciation: str = "british" - skip_llm_review: bool = False - dewarp_method: str = "ensemble" # "ensemble" | "vlm" | "cv" - - -async def _auto_sse_event(step: str, status: str, data: Dict[str, Any]) -> str: - """Format a single SSE event line.""" - import json as _json - payload = {"step": step, "status": status, **data} - return f"data: {_json.dumps(payload)}\n\n" - - -@router.post("/sessions/{session_id}/run-auto") -async def run_auto(session_id: str, req: RunAutoRequest, request: Request): - """Run the full OCR pipeline automatically from a given step, streaming SSE progress. - - Steps: - 1. Deskew — straighten the scan - 2. Dewarp — correct vertical shear (ensemble CV or VLM) - 3. Columns — detect column layout - 4. Rows — detect row layout - 5. Words — OCR each cell - 6. LLM review — correct OCR errors (optional) - - Already-completed steps are skipped unless `from_step` forces a rerun. - Yields SSE events of the form: - data: {"step": "deskew", "status": "start"|"done"|"skipped"|"error", ...} - - Final event: - data: {"step": "complete", "status": "done", "steps_run": [...], "steps_skipped": [...]} - """ - if req.from_step < 1 or req.from_step > 6: - raise HTTPException(status_code=400, detail="from_step must be 1-6") - if req.dewarp_method not in ("ensemble", "vlm", "cv"): - raise HTTPException(status_code=400, detail="dewarp_method must be: ensemble, vlm, cv") - - if session_id not in _cache: - await _load_session_to_cache(session_id) - - async def _generate(): - steps_run: List[str] = [] - steps_skipped: List[str] = [] - error_step: Optional[str] = None - - session = await get_session_db(session_id) - if not session: - yield await _auto_sse_event("error", "error", {"message": f"Session {session_id} not found"}) - return - - cached = _get_cached(session_id) - - # ----------------------------------------------------------------- - # Step 1: Deskew - # ----------------------------------------------------------------- - if req.from_step <= 1: - yield await _auto_sse_event("deskew", "start", {}) - try: - t0 = time.time() - orig_bgr = cached.get("original_bgr") - if orig_bgr is None: - raise ValueError("Original image not loaded") - - # Method 1: Hough lines - try: - deskewed_hough, angle_hough = deskew_image(orig_bgr.copy()) - except Exception: - deskewed_hough, angle_hough = orig_bgr, 0.0 - - # Method 2: Word alignment - success_enc, png_orig = cv2.imencode(".png", orig_bgr) - orig_bytes = png_orig.tobytes() if success_enc else b"" - try: - deskewed_wa_bytes, angle_wa = deskew_image_by_word_alignment(orig_bytes) - except Exception: - deskewed_wa_bytes, angle_wa = orig_bytes, 0.0 - - # Pick best method - if abs(angle_wa) >= abs(angle_hough) or abs(angle_hough) < 0.1: - method_used = "word_alignment" - angle_applied = angle_wa - wa_arr = np.frombuffer(deskewed_wa_bytes, dtype=np.uint8) - deskewed_bgr = cv2.imdecode(wa_arr, cv2.IMREAD_COLOR) - if deskewed_bgr is None: - deskewed_bgr = deskewed_hough - method_used = "hough" - angle_applied = angle_hough - else: - method_used = "hough" - angle_applied = angle_hough - deskewed_bgr = deskewed_hough - - success, png_buf = cv2.imencode(".png", deskewed_bgr) - deskewed_png = png_buf.tobytes() if success else b"" - - deskew_result = { - "method_used": method_used, - "rotation_degrees": round(float(angle_applied), 3), - "duration_seconds": round(time.time() - t0, 2), - } - - cached["deskewed_bgr"] = deskewed_bgr - cached["deskew_result"] = deskew_result - await update_session_db( - session_id, - deskewed_png=deskewed_png, - deskew_result=deskew_result, - auto_rotation_degrees=float(angle_applied), - current_step=3, - ) - session = await get_session_db(session_id) - - steps_run.append("deskew") - yield await _auto_sse_event("deskew", "done", deskew_result) - except Exception as e: - logger.error(f"Auto-mode deskew failed for {session_id}: {e}") - error_step = "deskew" - yield await _auto_sse_event("deskew", "error", {"message": str(e)}) - yield await _auto_sse_event("complete", "error", {"error_step": error_step}) - return - else: - steps_skipped.append("deskew") - yield await _auto_sse_event("deskew", "skipped", {"reason": "from_step > 1"}) - - # ----------------------------------------------------------------- - # Step 2: Dewarp - # ----------------------------------------------------------------- - if req.from_step <= 2: - yield await _auto_sse_event("dewarp", "start", {"method": req.dewarp_method}) - try: - t0 = time.time() - deskewed_bgr = cached.get("deskewed_bgr") - if deskewed_bgr is None: - raise ValueError("Deskewed image not available") - - if req.dewarp_method == "vlm": - success_enc, png_buf = cv2.imencode(".png", deskewed_bgr) - img_bytes = png_buf.tobytes() if success_enc else b"" - vlm_det = await _detect_shear_with_vlm(img_bytes) - shear_deg = vlm_det["shear_degrees"] - if abs(shear_deg) >= 0.05 and vlm_det["confidence"] >= 0.3: - dewarped_bgr = _apply_shear(deskewed_bgr, -shear_deg) - else: - dewarped_bgr = deskewed_bgr - dewarp_info = { - "method": vlm_det["method"], - "shear_degrees": shear_deg, - "confidence": vlm_det["confidence"], - "detections": [vlm_det], - } - else: - dewarped_bgr, dewarp_info = dewarp_image(deskewed_bgr) - - success_enc, png_buf = cv2.imencode(".png", dewarped_bgr) - dewarped_png = png_buf.tobytes() if success_enc else b"" - - dewarp_result = { - "method_used": dewarp_info["method"], - "shear_degrees": dewarp_info["shear_degrees"], - "confidence": dewarp_info["confidence"], - "duration_seconds": round(time.time() - t0, 2), - "detections": dewarp_info.get("detections", []), - } - - cached["dewarped_bgr"] = dewarped_bgr - cached["dewarp_result"] = dewarp_result - await update_session_db( - session_id, - dewarped_png=dewarped_png, - dewarp_result=dewarp_result, - auto_shear_degrees=dewarp_info.get("shear_degrees", 0.0), - current_step=4, - ) - session = await get_session_db(session_id) - - steps_run.append("dewarp") - yield await _auto_sse_event("dewarp", "done", dewarp_result) - except Exception as e: - logger.error(f"Auto-mode dewarp failed for {session_id}: {e}") - error_step = "dewarp" - yield await _auto_sse_event("dewarp", "error", {"message": str(e)}) - yield await _auto_sse_event("complete", "error", {"error_step": error_step}) - return - else: - steps_skipped.append("dewarp") - yield await _auto_sse_event("dewarp", "skipped", {"reason": "from_step > 2"}) - - # ----------------------------------------------------------------- - # Step 3: Columns - # ----------------------------------------------------------------- - if req.from_step <= 3: - yield await _auto_sse_event("columns", "start", {}) - try: - t0 = time.time() - col_img = cached.get("cropped_bgr") if cached.get("cropped_bgr") is not None else cached.get("dewarped_bgr") - if col_img is None: - raise ValueError("Cropped/dewarped image not available") - - ocr_img = create_ocr_image(col_img) - h, w = ocr_img.shape[:2] - - geo_result = detect_column_geometry(ocr_img, col_img) - if geo_result is None: - layout_img = create_layout_image(col_img) - regions = analyze_layout(layout_img, ocr_img) - cached["_word_dicts"] = None - cached["_inv"] = None - cached["_content_bounds"] = None - else: - geometries, left_x, right_x, top_y, bottom_y, word_dicts, inv = geo_result - content_w = right_x - left_x - cached["_word_dicts"] = word_dicts - cached["_inv"] = inv - cached["_content_bounds"] = (left_x, right_x, top_y, bottom_y) - - header_y, footer_y = _detect_header_footer_gaps(inv, w, h) if inv is not None else (None, None) - geometries = _detect_sub_columns(geometries, content_w, left_x=left_x, - top_y=top_y, header_y=header_y, footer_y=footer_y) - regions = classify_column_types(geometries, content_w, top_y, w, h, bottom_y, - left_x=left_x, right_x=right_x, inv=inv) - - columns = [asdict(r) for r in regions] - column_result = { - "columns": columns, - "classification_methods": list({c.get("classification_method", "") for c in columns if c.get("classification_method")}), - "duration_seconds": round(time.time() - t0, 2), - } - - cached["column_result"] = column_result - await update_session_db(session_id, column_result=column_result, - row_result=None, word_result=None, current_step=6) - session = await get_session_db(session_id) - - steps_run.append("columns") - yield await _auto_sse_event("columns", "done", { - "column_count": len(columns), - "duration_seconds": column_result["duration_seconds"], - }) - except Exception as e: - logger.error(f"Auto-mode columns failed for {session_id}: {e}") - error_step = "columns" - yield await _auto_sse_event("columns", "error", {"message": str(e)}) - yield await _auto_sse_event("complete", "error", {"error_step": error_step}) - return - else: - steps_skipped.append("columns") - yield await _auto_sse_event("columns", "skipped", {"reason": "from_step > 3"}) - - # ----------------------------------------------------------------- - # Step 4: Rows - # ----------------------------------------------------------------- - if req.from_step <= 4: - yield await _auto_sse_event("rows", "start", {}) - try: - t0 = time.time() - row_img = cached.get("cropped_bgr") if cached.get("cropped_bgr") is not None else cached.get("dewarped_bgr") - session = await get_session_db(session_id) - column_result = session.get("column_result") or cached.get("column_result") - if not column_result or not column_result.get("columns"): - raise ValueError("Column detection must complete first") - - col_regions = [ - PageRegion( - type=c["type"], x=c["x"], y=c["y"], - width=c["width"], height=c["height"], - classification_confidence=c.get("classification_confidence", 1.0), - classification_method=c.get("classification_method", ""), - ) - for c in column_result["columns"] - ] - - word_dicts = cached.get("_word_dicts") - inv = cached.get("_inv") - content_bounds = cached.get("_content_bounds") - - if word_dicts is None or inv is None or content_bounds is None: - ocr_img_tmp = create_ocr_image(row_img) - geo_result = detect_column_geometry(ocr_img_tmp, row_img) - if geo_result is None: - raise ValueError("Column geometry detection failed — cannot detect rows") - _g, lx, rx, ty, by, word_dicts, inv = geo_result - cached["_word_dicts"] = word_dicts - cached["_inv"] = inv - cached["_content_bounds"] = (lx, rx, ty, by) - content_bounds = (lx, rx, ty, by) - - left_x, right_x, top_y, bottom_y = content_bounds - row_geoms = detect_row_geometry(inv, word_dicts, left_x, right_x, top_y, bottom_y) - - row_list = [ - { - "index": r.index, "x": r.x, "y": r.y, - "width": r.width, "height": r.height, - "word_count": r.word_count, - "row_type": r.row_type, - "gap_before": r.gap_before, - } - for r in row_geoms - ] - row_result = { - "rows": row_list, - "row_count": len(row_list), - "content_rows": len([r for r in row_geoms if r.row_type == "content"]), - "duration_seconds": round(time.time() - t0, 2), - } - - cached["row_result"] = row_result - await update_session_db(session_id, row_result=row_result, current_step=7) - session = await get_session_db(session_id) - - steps_run.append("rows") - yield await _auto_sse_event("rows", "done", { - "row_count": len(row_list), - "content_rows": row_result["content_rows"], - "duration_seconds": row_result["duration_seconds"], - }) - except Exception as e: - logger.error(f"Auto-mode rows failed for {session_id}: {e}") - error_step = "rows" - yield await _auto_sse_event("rows", "error", {"message": str(e)}) - yield await _auto_sse_event("complete", "error", {"error_step": error_step}) - return - else: - steps_skipped.append("rows") - yield await _auto_sse_event("rows", "skipped", {"reason": "from_step > 4"}) - - # ----------------------------------------------------------------- - # Step 5: Words (OCR) - # ----------------------------------------------------------------- - if req.from_step <= 5: - yield await _auto_sse_event("words", "start", {"engine": req.ocr_engine}) - try: - t0 = time.time() - word_img = cached.get("cropped_bgr") if cached.get("cropped_bgr") is not None else cached.get("dewarped_bgr") - session = await get_session_db(session_id) - - column_result = session.get("column_result") or cached.get("column_result") - row_result = session.get("row_result") or cached.get("row_result") - - col_regions = [ - PageRegion( - type=c["type"], x=c["x"], y=c["y"], - width=c["width"], height=c["height"], - classification_confidence=c.get("classification_confidence", 1.0), - classification_method=c.get("classification_method", ""), - ) - for c in column_result["columns"] - ] - row_geoms = [ - RowGeometry( - index=r["index"], x=r["x"], y=r["y"], - width=r["width"], height=r["height"], - word_count=r.get("word_count", 0), words=[], - row_type=r.get("row_type", "content"), - gap_before=r.get("gap_before", 0), - ) - for r in row_result["rows"] - ] - - word_dicts = cached.get("_word_dicts") - if word_dicts is not None: - content_bounds = cached.get("_content_bounds") - top_y = content_bounds[2] if content_bounds else min(r.y for r in row_geoms) - for row in row_geoms: - row_y_rel = row.y - top_y - row_bottom_rel = row_y_rel + row.height - row.words = [ - w for w in word_dicts - if row_y_rel <= w['top'] + w['height'] / 2 < row_bottom_rel - ] - row.word_count = len(row.words) - - ocr_img = create_ocr_image(word_img) - img_h, img_w = word_img.shape[:2] - - cells, columns_meta = build_cell_grid( - ocr_img, col_regions, row_geoms, img_w, img_h, - ocr_engine=req.ocr_engine, img_bgr=word_img, - ) - duration = time.time() - t0 - - col_types = {c['type'] for c in columns_meta} - is_vocab = bool(col_types & {'column_en', 'column_de'}) - n_content_rows = len([r for r in row_geoms if r.row_type == 'content']) - used_engine = cells[0].get("ocr_engine", "tesseract") if cells else req.ocr_engine - - # Apply IPA phonetic fixes directly to cell texts - fix_cell_phonetics(cells, pronunciation=req.pronunciation) - - word_result_data = { - "cells": cells, - "grid_shape": { - "rows": n_content_rows, - "cols": len(columns_meta), - "total_cells": len(cells), - }, - "columns_used": columns_meta, - "layout": "vocab" if is_vocab else "generic", - "image_width": img_w, - "image_height": img_h, - "duration_seconds": round(duration, 2), - "ocr_engine": used_engine, - "summary": { - "total_cells": len(cells), - "non_empty_cells": sum(1 for c in cells if c.get("text")), - "low_confidence": sum(1 for c in cells if 0 < c.get("confidence", 0) < 50), - }, - } - - has_text_col = 'column_text' in col_types - if is_vocab or has_text_col: - entries = _cells_to_vocab_entries(cells, columns_meta) - entries = _fix_character_confusion(entries) - entries = _fix_phonetic_brackets(entries, pronunciation=req.pronunciation) - word_result_data["vocab_entries"] = entries - word_result_data["entries"] = entries - word_result_data["entry_count"] = len(entries) - word_result_data["summary"]["total_entries"] = len(entries) - - await update_session_db(session_id, word_result=word_result_data, current_step=8) - cached["word_result"] = word_result_data - session = await get_session_db(session_id) - - steps_run.append("words") - yield await _auto_sse_event("words", "done", { - "total_cells": len(cells), - "layout": word_result_data["layout"], - "duration_seconds": round(duration, 2), - "ocr_engine": used_engine, - "summary": word_result_data["summary"], - }) - except Exception as e: - logger.error(f"Auto-mode words failed for {session_id}: {e}") - error_step = "words" - yield await _auto_sse_event("words", "error", {"message": str(e)}) - yield await _auto_sse_event("complete", "error", {"error_step": error_step}) - return - else: - steps_skipped.append("words") - yield await _auto_sse_event("words", "skipped", {"reason": "from_step > 5"}) - - # ----------------------------------------------------------------- - # Step 6: LLM Review (optional) - # ----------------------------------------------------------------- - if req.from_step <= 6 and not req.skip_llm_review: - yield await _auto_sse_event("llm_review", "start", {"model": OLLAMA_REVIEW_MODEL}) - try: - session = await get_session_db(session_id) - word_result = session.get("word_result") or cached.get("word_result") - entries = word_result.get("entries") or word_result.get("vocab_entries") or [] - - if not entries: - yield await _auto_sse_event("llm_review", "skipped", {"reason": "no entries"}) - steps_skipped.append("llm_review") - else: - reviewed = await llm_review_entries(entries) - - session = await get_session_db(session_id) - word_result_updated = dict(session.get("word_result") or {}) - word_result_updated["entries"] = reviewed - word_result_updated["vocab_entries"] = reviewed - word_result_updated["llm_reviewed"] = True - word_result_updated["llm_model"] = OLLAMA_REVIEW_MODEL - - await update_session_db(session_id, word_result=word_result_updated, current_step=9) - cached["word_result"] = word_result_updated - - steps_run.append("llm_review") - yield await _auto_sse_event("llm_review", "done", { - "entries_reviewed": len(reviewed), - "model": OLLAMA_REVIEW_MODEL, - }) - except Exception as e: - logger.warning(f"Auto-mode llm_review failed for {session_id} (non-fatal): {e}") - yield await _auto_sse_event("llm_review", "error", {"message": str(e), "fatal": False}) - steps_skipped.append("llm_review") - else: - steps_skipped.append("llm_review") - reason = "skipped by request" if req.skip_llm_review else "from_step > 6" - yield await _auto_sse_event("llm_review", "skipped", {"reason": reason}) - - # ----------------------------------------------------------------- - # Final event - # ----------------------------------------------------------------- - yield await _auto_sse_event("complete", "done", { - "steps_run": steps_run, - "steps_skipped": steps_skipped, - }) - - return StreamingResponse( - _generate(), - media_type="text/event-stream", - headers={ - "Cache-Control": "no-cache", - "Connection": "keep-alive", - "X-Accel-Buffering": "no", - }, - ) +__all__ = ["router"] diff --git a/klausur-service/backend/ocr_pipeline_auto_helpers.py b/klausur-service/backend/ocr_pipeline_auto_helpers.py new file mode 100644 index 0000000..05df86d --- /dev/null +++ b/klausur-service/backend/ocr_pipeline_auto_helpers.py @@ -0,0 +1,84 @@ +""" +OCR Pipeline Auto-Mode Helpers. + +VLM shear detection, SSE event formatting, and request models. + +Lizenz: Apache 2.0 +DATENSCHUTZ: Alle Verarbeitung erfolgt lokal. +""" + +import json +import logging +import os +import re +from typing import Any, Dict + +from pydantic import BaseModel + +logger = logging.getLogger(__name__) + + +class RunAutoRequest(BaseModel): + from_step: int = 1 # 1=deskew, 2=dewarp, 3=columns, 4=rows, 5=words, 6=llm-review + ocr_engine: str = "auto" # "auto" | "rapid" | "tesseract" + pronunciation: str = "british" + skip_llm_review: bool = False + dewarp_method: str = "ensemble" # "ensemble" | "vlm" | "cv" + + +async def auto_sse_event(step: str, status: str, data: Dict[str, Any]) -> str: + """Format a single SSE event line.""" + payload = {"step": step, "status": status, **data} + return f"data: {json.dumps(payload)}\n\n" + + +async def detect_shear_with_vlm(image_bytes: bytes) -> Dict[str, Any]: + """Ask qwen2.5vl:32b to estimate the vertical shear angle of a scanned page. + + The VLM is shown the image and asked: are the column/table borders tilted? + If yes, by how many degrees? Returns a dict with shear_degrees and confidence. + Confidence is 0.0 if Ollama is unavailable or parsing fails. + """ + import httpx + import base64 + + ollama_base = os.getenv("OLLAMA_BASE_URL", "http://host.docker.internal:11434") + model = os.getenv("OLLAMA_HTR_MODEL", "qwen2.5vl:32b") + + prompt = ( + "This is a scanned vocabulary worksheet. Look at the vertical borders of the table columns. " + "Are they perfectly vertical, or do they tilt slightly? " + "If they tilt, estimate the tilt angle in degrees (positive = top tilts right, negative = top tilts left). " + "Reply with ONLY a JSON object like: {\"shear_degrees\": 1.2, \"confidence\": 0.8} " + "Use confidence 0.0-1.0 based on how clearly you can see the tilt. " + "If the columns look straight, return {\"shear_degrees\": 0.0, \"confidence\": 0.9}" + ) + + img_b64 = base64.b64encode(image_bytes).decode("utf-8") + payload = { + "model": model, + "prompt": prompt, + "images": [img_b64], + "stream": False, + } + + try: + async with httpx.AsyncClient(timeout=60.0) as client: + resp = await client.post(f"{ollama_base}/api/generate", json=payload) + resp.raise_for_status() + text = resp.json().get("response", "") + + # Parse JSON from response (may have surrounding text) + match = re.search(r'\{[^}]+\}', text) + if match: + data = json.loads(match.group(0)) + shear = float(data.get("shear_degrees", 0.0)) + conf = float(data.get("confidence", 0.0)) + # Clamp to reasonable range + shear = max(-3.0, min(3.0, shear)) + conf = max(0.0, min(1.0, conf)) + return {"method": "vlm_qwen2.5vl", "shear_degrees": round(shear, 3), "confidence": round(conf, 2)} + except Exception as e: + logger.warning(f"VLM dewarp failed: {e}") + + return {"method": "vlm_qwen2.5vl", "shear_degrees": 0.0, "confidence": 0.0} diff --git a/klausur-service/backend/ocr_pipeline_auto_steps.py b/klausur-service/backend/ocr_pipeline_auto_steps.py new file mode 100644 index 0000000..4961ee9 --- /dev/null +++ b/klausur-service/backend/ocr_pipeline_auto_steps.py @@ -0,0 +1,528 @@ +""" +OCR Pipeline Auto-Mode Orchestrator. + +POST /sessions/{session_id}/run-auto -- full auto-mode with SSE streaming. + +Lizenz: Apache 2.0 +DATENSCHUTZ: Alle Verarbeitung erfolgt lokal. +""" + +import logging +import time +from dataclasses import asdict +from typing import Any, Dict, List, Optional + +import cv2 +import numpy as np +from fastapi import APIRouter, HTTPException, Request +from fastapi.responses import StreamingResponse + +from cv_vocab_pipeline import ( + OLLAMA_REVIEW_MODEL, + PageRegion, + RowGeometry, + _cells_to_vocab_entries, + _detect_header_footer_gaps, + _detect_sub_columns, + _fix_character_confusion, + _fix_phonetic_brackets, + fix_cell_phonetics, + analyze_layout, + build_cell_grid, + classify_column_types, + create_layout_image, + create_ocr_image, + deskew_image, + deskew_image_by_word_alignment, + detect_column_geometry, + detect_row_geometry, + _apply_shear, + dewarp_image, + llm_review_entries, +) +from ocr_pipeline_common import ( + _cache, + _load_session_to_cache, + _get_cached, +) +from ocr_pipeline_session_store import ( + get_session_db, + update_session_db, +) +from ocr_pipeline_auto_helpers import ( + RunAutoRequest, + auto_sse_event as _auto_sse_event, + detect_shear_with_vlm as _detect_shear_with_vlm, +) + +logger = logging.getLogger(__name__) + +router = APIRouter(tags=["ocr-pipeline"]) + +@router.post("/sessions/{session_id}/run-auto") +async def run_auto(session_id: str, req: RunAutoRequest, request: Request): + """Run the full OCR pipeline automatically from a given step, streaming SSE progress. + + Steps: + 1. Deskew -- straighten the scan + 2. Dewarp -- correct vertical shear (ensemble CV or VLM) + 3. Columns -- detect column layout + 4. Rows -- detect row layout + 5. Words -- OCR each cell + 6. LLM review -- correct OCR errors (optional) + + Already-completed steps are skipped unless `from_step` forces a rerun. + Yields SSE events of the form: + data: {"step": "deskew", "status": "start"|"done"|"skipped"|"error", ...} + + Final event: + data: {"step": "complete", "status": "done", "steps_run": [...], "steps_skipped": [...]} + """ + if req.from_step < 1 or req.from_step > 6: + raise HTTPException(status_code=400, detail="from_step must be 1-6") + if req.dewarp_method not in ("ensemble", "vlm", "cv"): + raise HTTPException(status_code=400, detail="dewarp_method must be: ensemble, vlm, cv") + + if session_id not in _cache: + await _load_session_to_cache(session_id) + + async def _generate(): + steps_run: List[str] = [] + steps_skipped: List[str] = [] + error_step: Optional[str] = None + + session = await get_session_db(session_id) + if not session: + yield await _auto_sse_event("error", "error", {"message": f"Session {session_id} not found"}) + return + + cached = _get_cached(session_id) + + # Step 1: Deskew + if req.from_step <= 1: + yield await _auto_sse_event("deskew", "start", {}) + try: + t0 = time.time() + orig_bgr = cached.get("original_bgr") + if orig_bgr is None: + raise ValueError("Original image not loaded") + + try: + deskewed_hough, angle_hough = deskew_image(orig_bgr.copy()) + except Exception: + deskewed_hough, angle_hough = orig_bgr, 0.0 + + success_enc, png_orig = cv2.imencode(".png", orig_bgr) + orig_bytes = png_orig.tobytes() if success_enc else b"" + try: + deskewed_wa_bytes, angle_wa = deskew_image_by_word_alignment(orig_bytes) + except Exception: + deskewed_wa_bytes, angle_wa = orig_bytes, 0.0 + + if abs(angle_wa) >= abs(angle_hough) or abs(angle_hough) < 0.1: + method_used = "word_alignment" + angle_applied = angle_wa + wa_arr = np.frombuffer(deskewed_wa_bytes, dtype=np.uint8) + deskewed_bgr = cv2.imdecode(wa_arr, cv2.IMREAD_COLOR) + if deskewed_bgr is None: + deskewed_bgr = deskewed_hough + method_used = "hough" + angle_applied = angle_hough + else: + method_used = "hough" + angle_applied = angle_hough + deskewed_bgr = deskewed_hough + + success, png_buf = cv2.imencode(".png", deskewed_bgr) + deskewed_png = png_buf.tobytes() if success else b"" + + deskew_result = { + "method_used": method_used, + "rotation_degrees": round(float(angle_applied), 3), + "duration_seconds": round(time.time() - t0, 2), + } + + cached["deskewed_bgr"] = deskewed_bgr + cached["deskew_result"] = deskew_result + await update_session_db( + session_id, + deskewed_png=deskewed_png, + deskew_result=deskew_result, + auto_rotation_degrees=float(angle_applied), + current_step=3, + ) + session = await get_session_db(session_id) + + steps_run.append("deskew") + yield await _auto_sse_event("deskew", "done", deskew_result) + except Exception as e: + logger.error(f"Auto-mode deskew failed for {session_id}: {e}") + error_step = "deskew" + yield await _auto_sse_event("deskew", "error", {"message": str(e)}) + yield await _auto_sse_event("complete", "error", {"error_step": error_step}) + return + else: + steps_skipped.append("deskew") + yield await _auto_sse_event("deskew", "skipped", {"reason": "from_step > 1"}) + + # Step 2: Dewarp + if req.from_step <= 2: + yield await _auto_sse_event("dewarp", "start", {"method": req.dewarp_method}) + try: + t0 = time.time() + deskewed_bgr = cached.get("deskewed_bgr") + if deskewed_bgr is None: + raise ValueError("Deskewed image not available") + + if req.dewarp_method == "vlm": + success_enc, png_buf = cv2.imencode(".png", deskewed_bgr) + img_bytes = png_buf.tobytes() if success_enc else b"" + vlm_det = await _detect_shear_with_vlm(img_bytes) + shear_deg = vlm_det["shear_degrees"] + if abs(shear_deg) >= 0.05 and vlm_det["confidence"] >= 0.3: + dewarped_bgr = _apply_shear(deskewed_bgr, -shear_deg) + else: + dewarped_bgr = deskewed_bgr + dewarp_info = { + "method": vlm_det["method"], + "shear_degrees": shear_deg, + "confidence": vlm_det["confidence"], + "detections": [vlm_det], + } + else: + dewarped_bgr, dewarp_info = dewarp_image(deskewed_bgr) + + success_enc, png_buf = cv2.imencode(".png", dewarped_bgr) + dewarped_png = png_buf.tobytes() if success_enc else b"" + + dewarp_result = { + "method_used": dewarp_info["method"], + "shear_degrees": dewarp_info["shear_degrees"], + "confidence": dewarp_info["confidence"], + "duration_seconds": round(time.time() - t0, 2), + "detections": dewarp_info.get("detections", []), + } + + cached["dewarped_bgr"] = dewarped_bgr + cached["dewarp_result"] = dewarp_result + await update_session_db( + session_id, + dewarped_png=dewarped_png, + dewarp_result=dewarp_result, + auto_shear_degrees=dewarp_info.get("shear_degrees", 0.0), + current_step=4, + ) + session = await get_session_db(session_id) + + steps_run.append("dewarp") + yield await _auto_sse_event("dewarp", "done", dewarp_result) + except Exception as e: + logger.error(f"Auto-mode dewarp failed for {session_id}: {e}") + error_step = "dewarp" + yield await _auto_sse_event("dewarp", "error", {"message": str(e)}) + yield await _auto_sse_event("complete", "error", {"error_step": error_step}) + return + else: + steps_skipped.append("dewarp") + yield await _auto_sse_event("dewarp", "skipped", {"reason": "from_step > 2"}) + + # Step 3: Columns + if req.from_step <= 3: + yield await _auto_sse_event("columns", "start", {}) + try: + t0 = time.time() + col_img = cached.get("cropped_bgr") if cached.get("cropped_bgr") is not None else cached.get("dewarped_bgr") + if col_img is None: + raise ValueError("Cropped/dewarped image not available") + + ocr_img = create_ocr_image(col_img) + h, w = ocr_img.shape[:2] + + geo_result = detect_column_geometry(ocr_img, col_img) + if geo_result is None: + layout_img = create_layout_image(col_img) + regions = analyze_layout(layout_img, ocr_img) + cached["_word_dicts"] = None + cached["_inv"] = None + cached["_content_bounds"] = None + else: + geometries, left_x, right_x, top_y, bottom_y, word_dicts, inv = geo_result + content_w = right_x - left_x + cached["_word_dicts"] = word_dicts + cached["_inv"] = inv + cached["_content_bounds"] = (left_x, right_x, top_y, bottom_y) + + header_y, footer_y = _detect_header_footer_gaps(inv, w, h) if inv is not None else (None, None) + geometries = _detect_sub_columns(geometries, content_w, left_x=left_x, + top_y=top_y, header_y=header_y, footer_y=footer_y) + regions = classify_column_types(geometries, content_w, top_y, w, h, bottom_y, + left_x=left_x, right_x=right_x, inv=inv) + + columns = [asdict(r) for r in regions] + column_result = { + "columns": columns, + "classification_methods": list({c.get("classification_method", "") for c in columns if c.get("classification_method")}), + "duration_seconds": round(time.time() - t0, 2), + } + + cached["column_result"] = column_result + await update_session_db(session_id, column_result=column_result, + row_result=None, word_result=None, current_step=6) + session = await get_session_db(session_id) + + steps_run.append("columns") + yield await _auto_sse_event("columns", "done", { + "column_count": len(columns), + "duration_seconds": column_result["duration_seconds"], + }) + except Exception as e: + logger.error(f"Auto-mode columns failed for {session_id}: {e}") + error_step = "columns" + yield await _auto_sse_event("columns", "error", {"message": str(e)}) + yield await _auto_sse_event("complete", "error", {"error_step": error_step}) + return + else: + steps_skipped.append("columns") + yield await _auto_sse_event("columns", "skipped", {"reason": "from_step > 3"}) + + # Step 4: Rows + if req.from_step <= 4: + yield await _auto_sse_event("rows", "start", {}) + try: + t0 = time.time() + row_img = cached.get("cropped_bgr") if cached.get("cropped_bgr") is not None else cached.get("dewarped_bgr") + session = await get_session_db(session_id) + column_result = session.get("column_result") or cached.get("column_result") + if not column_result or not column_result.get("columns"): + raise ValueError("Column detection must complete first") + + col_regions = [ + PageRegion( + type=c["type"], x=c["x"], y=c["y"], + width=c["width"], height=c["height"], + classification_confidence=c.get("classification_confidence", 1.0), + classification_method=c.get("classification_method", ""), + ) + for c in column_result["columns"] + ] + + word_dicts = cached.get("_word_dicts") + inv = cached.get("_inv") + content_bounds = cached.get("_content_bounds") + + if word_dicts is None or inv is None or content_bounds is None: + ocr_img_tmp = create_ocr_image(row_img) + geo_result = detect_column_geometry(ocr_img_tmp, row_img) + if geo_result is None: + raise ValueError("Column geometry detection failed -- cannot detect rows") + _g, lx, rx, ty, by, word_dicts, inv = geo_result + cached["_word_dicts"] = word_dicts + cached["_inv"] = inv + cached["_content_bounds"] = (lx, rx, ty, by) + content_bounds = (lx, rx, ty, by) + + left_x, right_x, top_y, bottom_y = content_bounds + row_geoms = detect_row_geometry(inv, word_dicts, left_x, right_x, top_y, bottom_y) + + row_list = [ + { + "index": r.index, "x": r.x, "y": r.y, + "width": r.width, "height": r.height, + "word_count": r.word_count, + "row_type": r.row_type, + "gap_before": r.gap_before, + } + for r in row_geoms + ] + row_result = { + "rows": row_list, + "row_count": len(row_list), + "content_rows": len([r for r in row_geoms if r.row_type == "content"]), + "duration_seconds": round(time.time() - t0, 2), + } + + cached["row_result"] = row_result + await update_session_db(session_id, row_result=row_result, current_step=7) + session = await get_session_db(session_id) + + steps_run.append("rows") + yield await _auto_sse_event("rows", "done", { + "row_count": len(row_list), + "content_rows": row_result["content_rows"], + "duration_seconds": row_result["duration_seconds"], + }) + except Exception as e: + logger.error(f"Auto-mode rows failed for {session_id}: {e}") + error_step = "rows" + yield await _auto_sse_event("rows", "error", {"message": str(e)}) + yield await _auto_sse_event("complete", "error", {"error_step": error_step}) + return + else: + steps_skipped.append("rows") + yield await _auto_sse_event("rows", "skipped", {"reason": "from_step > 4"}) + + # Step 5: Words (OCR) + if req.from_step <= 5: + yield await _auto_sse_event("words", "start", {"engine": req.ocr_engine}) + try: + t0 = time.time() + word_img = cached.get("cropped_bgr") if cached.get("cropped_bgr") is not None else cached.get("dewarped_bgr") + session = await get_session_db(session_id) + + column_result = session.get("column_result") or cached.get("column_result") + row_result = session.get("row_result") or cached.get("row_result") + + col_regions = [ + PageRegion( + type=c["type"], x=c["x"], y=c["y"], + width=c["width"], height=c["height"], + classification_confidence=c.get("classification_confidence", 1.0), + classification_method=c.get("classification_method", ""), + ) + for c in column_result["columns"] + ] + row_geoms = [ + RowGeometry( + index=r["index"], x=r["x"], y=r["y"], + width=r["width"], height=r["height"], + word_count=r.get("word_count", 0), words=[], + row_type=r.get("row_type", "content"), + gap_before=r.get("gap_before", 0), + ) + for r in row_result["rows"] + ] + + word_dicts = cached.get("_word_dicts") + if word_dicts is not None: + content_bounds = cached.get("_content_bounds") + top_y = content_bounds[2] if content_bounds else min(r.y for r in row_geoms) + for row in row_geoms: + row_y_rel = row.y - top_y + row_bottom_rel = row_y_rel + row.height + row.words = [ + w for w in word_dicts + if row_y_rel <= w['top'] + w['height'] / 2 < row_bottom_rel + ] + row.word_count = len(row.words) + + ocr_img = create_ocr_image(word_img) + img_h, img_w = word_img.shape[:2] + + cells, columns_meta = build_cell_grid( + ocr_img, col_regions, row_geoms, img_w, img_h, + ocr_engine=req.ocr_engine, img_bgr=word_img, + ) + duration = time.time() - t0 + + col_types = {c['type'] for c in columns_meta} + is_vocab = bool(col_types & {'column_en', 'column_de'}) + n_content_rows = len([r for r in row_geoms if r.row_type == 'content']) + used_engine = cells[0].get("ocr_engine", "tesseract") if cells else req.ocr_engine + + fix_cell_phonetics(cells, pronunciation=req.pronunciation) + + word_result_data = { + "cells": cells, + "grid_shape": { + "rows": n_content_rows, + "cols": len(columns_meta), + "total_cells": len(cells), + }, + "columns_used": columns_meta, + "layout": "vocab" if is_vocab else "generic", + "image_width": img_w, + "image_height": img_h, + "duration_seconds": round(duration, 2), + "ocr_engine": used_engine, + "summary": { + "total_cells": len(cells), + "non_empty_cells": sum(1 for c in cells if c.get("text")), + "low_confidence": sum(1 for c in cells if 0 < c.get("confidence", 0) < 50), + }, + } + + has_text_col = 'column_text' in col_types + if is_vocab or has_text_col: + entries = _cells_to_vocab_entries(cells, columns_meta) + entries = _fix_character_confusion(entries) + entries = _fix_phonetic_brackets(entries, pronunciation=req.pronunciation) + word_result_data["vocab_entries"] = entries + word_result_data["entries"] = entries + word_result_data["entry_count"] = len(entries) + word_result_data["summary"]["total_entries"] = len(entries) + + await update_session_db(session_id, word_result=word_result_data, current_step=8) + cached["word_result"] = word_result_data + session = await get_session_db(session_id) + + steps_run.append("words") + yield await _auto_sse_event("words", "done", { + "total_cells": len(cells), + "layout": word_result_data["layout"], + "duration_seconds": round(duration, 2), + "ocr_engine": used_engine, + "summary": word_result_data["summary"], + }) + except Exception as e: + logger.error(f"Auto-mode words failed for {session_id}: {e}") + error_step = "words" + yield await _auto_sse_event("words", "error", {"message": str(e)}) + yield await _auto_sse_event("complete", "error", {"error_step": error_step}) + return + else: + steps_skipped.append("words") + yield await _auto_sse_event("words", "skipped", {"reason": "from_step > 5"}) + + # Step 6: LLM Review (optional) + if req.from_step <= 6 and not req.skip_llm_review: + yield await _auto_sse_event("llm_review", "start", {"model": OLLAMA_REVIEW_MODEL}) + try: + session = await get_session_db(session_id) + word_result = session.get("word_result") or cached.get("word_result") + entries = word_result.get("entries") or word_result.get("vocab_entries") or [] + + if not entries: + yield await _auto_sse_event("llm_review", "skipped", {"reason": "no entries"}) + steps_skipped.append("llm_review") + else: + reviewed = await llm_review_entries(entries) + + session = await get_session_db(session_id) + word_result_updated = dict(session.get("word_result") or {}) + word_result_updated["entries"] = reviewed + word_result_updated["vocab_entries"] = reviewed + word_result_updated["llm_reviewed"] = True + word_result_updated["llm_model"] = OLLAMA_REVIEW_MODEL + + await update_session_db(session_id, word_result=word_result_updated, current_step=9) + cached["word_result"] = word_result_updated + + steps_run.append("llm_review") + yield await _auto_sse_event("llm_review", "done", { + "entries_reviewed": len(reviewed), + "model": OLLAMA_REVIEW_MODEL, + }) + except Exception as e: + logger.warning(f"Auto-mode llm_review failed for {session_id} (non-fatal): {e}") + yield await _auto_sse_event("llm_review", "error", {"message": str(e), "fatal": False}) + steps_skipped.append("llm_review") + else: + steps_skipped.append("llm_review") + reason = "skipped by request" if req.skip_llm_review else "from_step > 6" + yield await _auto_sse_event("llm_review", "skipped", {"reason": reason}) + + # Final event + yield await _auto_sse_event("complete", "done", { + "steps_run": steps_run, + "steps_skipped": steps_skipped, + }) + + return StreamingResponse( + _generate(), + media_type="text/event-stream", + headers={ + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "X-Accel-Buffering": "no", + }, + ) diff --git a/klausur-service/backend/ocr_pipeline_reprocess.py b/klausur-service/backend/ocr_pipeline_reprocess.py new file mode 100644 index 0000000..62d68fa --- /dev/null +++ b/klausur-service/backend/ocr_pipeline_reprocess.py @@ -0,0 +1,94 @@ +""" +OCR Pipeline Reprocess Endpoint. + +POST /sessions/{session_id}/reprocess — clear downstream + restart from step. + +Lizenz: Apache 2.0 +DATENSCHUTZ: Alle Verarbeitung erfolgt lokal. +""" + +import logging +from typing import Any, Dict + +from fastapi import APIRouter, HTTPException, Request + +from ocr_pipeline_common import _cache +from ocr_pipeline_session_store import get_session_db, update_session_db + +logger = logging.getLogger(__name__) + +router = APIRouter(tags=["ocr-pipeline"]) + + +@router.post("/sessions/{session_id}/reprocess") +async def reprocess_session(session_id: str, request: Request): + """Re-run pipeline from a specific step, clearing downstream data. + + Body: {"from_step": 5} (1-indexed step number) + + Pipeline order: Orientation(1) -> Deskew(2) -> Dewarp(3) -> Crop(4) -> Columns(5) -> + Rows(6) -> Words(7) -> LLM-Review(8) -> Reconstruction(9) -> Validation(10) + + Clears downstream results: + - from_step <= 1: orientation_result + all downstream + - from_step <= 2: deskew_result + all downstream + - from_step <= 3: dewarp_result + all downstream + - from_step <= 4: crop_result + all downstream + - from_step <= 5: column_result, row_result, word_result + - from_step <= 6: row_result, word_result + - from_step <= 7: word_result (cells, vocab_entries) + - from_step <= 8: word_result.llm_review only + """ + session = await get_session_db(session_id) + if not session: + raise HTTPException(status_code=404, detail=f"Session {session_id} not found") + + body = await request.json() + from_step = body.get("from_step", 1) + if not isinstance(from_step, int) or from_step < 1 or from_step > 10: + raise HTTPException(status_code=400, detail="from_step must be between 1 and 10") + + update_kwargs: Dict[str, Any] = {"current_step": from_step} + + # Clear downstream data based on from_step + # New pipeline order: Orient(2) -> Deskew(3) -> Dewarp(4) -> Crop(5) -> + # Columns(6) -> Rows(7) -> Words(8) -> LLM(9) -> Recon(10) -> GT(11) + if from_step <= 8: + update_kwargs["word_result"] = None + elif from_step == 9: + # Only clear LLM review from word_result + word_result = session.get("word_result") + if word_result: + word_result.pop("llm_review", None) + word_result.pop("llm_corrections", None) + update_kwargs["word_result"] = word_result + + if from_step <= 7: + update_kwargs["row_result"] = None + if from_step <= 6: + update_kwargs["column_result"] = None + if from_step <= 4: + update_kwargs["crop_result"] = None + if from_step <= 3: + update_kwargs["dewarp_result"] = None + if from_step <= 2: + update_kwargs["deskew_result"] = None + if from_step <= 1: + update_kwargs["orientation_result"] = None + + await update_session_db(session_id, **update_kwargs) + + # Also clear cache + if session_id in _cache: + for key in list(update_kwargs.keys()): + if key != "current_step": + _cache[session_id][key] = update_kwargs[key] + _cache[session_id]["current_step"] = from_step + + logger.info(f"Session {session_id} reprocessing from step {from_step}") + + return { + "session_id": session_id, + "from_step": from_step, + "cleared": [k for k in update_kwargs if k != "current_step"], + } diff --git a/klausur-service/backend/page_crop.py b/klausur-service/backend/page_crop.py index 108e095..ca4a8d0 100644 --- a/klausur-service/backend/page_crop.py +++ b/klausur-service/backend/page_crop.py @@ -1,758 +1,33 @@ """ -Page Crop - Content-based crop for scanned pages and book scans. +Page Crop — Barrel Re-export -Detects the content boundary by analysing ink density projections and -(for book scans) the spine shadow gradient. Works with both loose A4 -sheets on dark scanners AND book scans with white backgrounds. +Content-based crop for scanned pages and book scans. +Split into: +- page_crop_edges.py — Edge detection (spine shadow, gutter, projection) +- page_crop_core.py — Main crop algorithm and format detection + +All public names are re-exported here for backward compatibility. License: Apache 2.0 """ -import logging -from typing import Dict, Any, Tuple, Optional +# Core: main crop functions and format detection +from page_crop_core import ( # noqa: F401 + PAPER_FORMATS, + detect_page_splits, + detect_and_crop_page, + _detect_format, +) -import cv2 -import numpy as np - -logger = logging.getLogger(__name__) - -# Known paper format aspect ratios (height / width, portrait orientation) -PAPER_FORMATS = { - "A4": 297.0 / 210.0, # 1.4143 - "A5": 210.0 / 148.0, # 1.4189 - "Letter": 11.0 / 8.5, # 1.2941 - "Legal": 14.0 / 8.5, # 1.6471 - "A3": 420.0 / 297.0, # 1.4141 -} - -# Minimum ink density (fraction of pixels) to count a row/column as "content" -_INK_THRESHOLD = 0.003 # 0.3% - -# Minimum run length (fraction of dimension) to keep — shorter runs are noise -_MIN_RUN_FRAC = 0.005 # 0.5% - - -def detect_page_splits( - img_bgr: np.ndarray, -) -> list: - """Detect if the image is a multi-page spread and return split rectangles. - - Uses **brightness** (not ink density) to find the spine area: - the scanner bed produces a characteristic gray strip where pages meet, - which is darker than the white paper on either side. - - Returns a list of page dicts ``{x, y, width, height, page_index}`` - or an empty list if only one page is detected. - """ - h, w = img_bgr.shape[:2] - - # Only check landscape-ish images (width > height * 1.15) - if w < h * 1.15: - return [] - - gray = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY) - - # Column-mean brightness (0-255) — the spine is darker (gray scanner bed) - col_brightness = np.mean(gray, axis=0).astype(np.float64) - - # Heavy smoothing to ignore individual text lines - kern = max(11, w // 50) - if kern % 2 == 0: - kern += 1 - brightness_smooth = np.convolve(col_brightness, np.ones(kern) / kern, mode="same") - - # Page paper is bright (typically > 200), spine/scanner bed is darker - page_brightness = float(np.max(brightness_smooth)) - if page_brightness < 100: - return [] # Very dark image, skip - - # Spine threshold: significantly darker than the page - # Spine is typically 60-80% of paper brightness - spine_thresh = page_brightness * 0.88 - - # Search in center region (30-70% of width) - center_lo = int(w * 0.30) - center_hi = int(w * 0.70) - - # Find the darkest valley in the center region - center_brightness = brightness_smooth[center_lo:center_hi] - darkest_val = float(np.min(center_brightness)) - - if darkest_val >= spine_thresh: - logger.debug("No spine detected: min brightness %.0f >= threshold %.0f", - darkest_val, spine_thresh) - return [] - - # Find ALL contiguous dark runs in the center region - is_dark = center_brightness < spine_thresh - dark_runs: list = [] # list of (start, end) pairs - run_start = -1 - for i in range(len(is_dark)): - if is_dark[i]: - if run_start < 0: - run_start = i - else: - if run_start >= 0: - dark_runs.append((run_start, i)) - run_start = -1 - if run_start >= 0: - dark_runs.append((run_start, len(is_dark))) - - # Filter out runs that are too narrow (< 1% of image width) - min_spine_px = int(w * 0.01) - dark_runs = [(s, e) for s, e in dark_runs if e - s >= min_spine_px] - - if not dark_runs: - logger.debug("No dark runs wider than %dpx in center region", min_spine_px) - return [] - - # Score each dark run: prefer centered, dark, narrow valleys - center_region_len = center_hi - center_lo - image_center_in_region = (w * 0.5 - center_lo) # x=50% mapped into region coords - best_score = -1.0 - best_start, best_end = dark_runs[0] - - for rs, re in dark_runs: - run_width = re - rs - run_center = (rs + re) / 2.0 - - # --- Factor 1: Proximity to image center (gaussian, sigma = 15% of region) --- - sigma = center_region_len * 0.15 - dist = abs(run_center - image_center_in_region) - center_factor = float(np.exp(-0.5 * (dist / sigma) ** 2)) - - # --- Factor 2: Darkness (how dark is the valley relative to threshold) --- - run_brightness = float(np.mean(center_brightness[rs:re])) - # Normalize: 1.0 when run_brightness == 0, 0.0 when run_brightness == spine_thresh - darkness_factor = max(0.0, (spine_thresh - run_brightness) / spine_thresh) - - # --- Factor 3: Narrowness bonus (spine shadows are narrow, not wide plateaus) --- - # Typical spine: 1-5% of image width. Penalise runs wider than ~8%. - width_frac = run_width / w - if width_frac <= 0.05: - narrowness_bonus = 1.0 - elif width_frac <= 0.15: - narrowness_bonus = 1.0 - (width_frac - 0.05) / 0.10 # linear decay 1.0 → 0.0 - else: - narrowness_bonus = 0.0 - - score = center_factor * darkness_factor * (0.3 + 0.7 * narrowness_bonus) - - logger.debug( - "Dark run x=%d..%d (w=%d): center_f=%.3f dark_f=%.3f narrow_b=%.3f → score=%.4f", - center_lo + rs, center_lo + re, run_width, - center_factor, darkness_factor, narrowness_bonus, score, - ) - - if score > best_score: - best_score = score - best_start, best_end = rs, re - - spine_w = best_end - best_start - spine_x = center_lo + best_start - spine_center = spine_x + spine_w // 2 - - logger.debug( - "Best spine candidate: x=%d..%d (w=%d), score=%.4f", - spine_x, spine_x + spine_w, spine_w, best_score, - ) - - # Verify: must have bright (paper) content on BOTH sides - left_brightness = float(np.mean(brightness_smooth[max(0, spine_x - w // 10):spine_x])) - right_end = center_lo + best_end - right_brightness = float(np.mean(brightness_smooth[right_end:min(w, right_end + w // 10)])) - - if left_brightness < spine_thresh or right_brightness < spine_thresh: - logger.debug("No bright paper flanking spine: left=%.0f right=%.0f thresh=%.0f", - left_brightness, right_brightness, spine_thresh) - return [] - - logger.info( - "Spine detected: x=%d..%d (w=%d), brightness=%.0f vs paper=%.0f, " - "left_paper=%.0f, right_paper=%.0f", - spine_x, right_end, spine_w, darkest_val, page_brightness, - left_brightness, right_brightness, - ) - - # Split at the spine center - split_points = [spine_center] - - # Build page rectangles - pages: list = [] - prev_x = 0 - for i, sx in enumerate(split_points): - pages.append({"x": prev_x, "y": 0, "width": sx - prev_x, - "height": h, "page_index": i}) - prev_x = sx - pages.append({"x": prev_x, "y": 0, "width": w - prev_x, - "height": h, "page_index": len(split_points)}) - - # Filter out tiny pages (< 15% of total width) - pages = [p for p in pages if p["width"] >= w * 0.15] - if len(pages) < 2: - return [] - - # Re-index - for i, p in enumerate(pages): - p["page_index"] = i - - logger.info( - "Page split detected: %d pages, spine_w=%d, split_points=%s", - len(pages), spine_w, split_points, - ) - return pages - - -def detect_and_crop_page( - img_bgr: np.ndarray, - margin_frac: float = 0.01, -) -> Tuple[np.ndarray, Dict[str, Any]]: - """Detect content boundary and crop scanner/book borders. - - Algorithm (4-edge detection): - 1. Adaptive threshold → binary (text=255, bg=0) - 2. Left edge: spine-shadow detection via grayscale column means, - fallback to binary vertical projection - 3. Right edge: binary vertical projection (last ink column) - 4. Top/bottom edges: binary horizontal projection - 5. Sanity checks, then crop with configurable margin - - Args: - img_bgr: Input BGR image (should already be deskewed/dewarped) - margin_frac: Extra margin around content (fraction of dimension, default 1%) - - Returns: - Tuple of (cropped_image, result_dict) - """ - h, w = img_bgr.shape[:2] - total_area = h * w - - result: Dict[str, Any] = { - "crop_applied": False, - "crop_rect": None, - "crop_rect_pct": None, - "original_size": {"width": w, "height": h}, - "cropped_size": {"width": w, "height": h}, - "detected_format": None, - "format_confidence": 0.0, - "aspect_ratio": round(max(h, w) / max(min(h, w), 1), 4), - "border_fractions": {"top": 0.0, "bottom": 0.0, "left": 0.0, "right": 0.0}, - } - - gray = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY) - - # --- Binarise with adaptive threshold (works for white-on-white) --- - binary = cv2.adaptiveThreshold( - gray, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, - cv2.THRESH_BINARY_INV, blockSize=51, C=15, - ) - - # --- Left edge: spine-shadow detection --- - left_edge = _detect_left_edge_shadow(gray, binary, w, h) - - # --- Right edge: spine-shadow detection --- - right_edge = _detect_right_edge_shadow(gray, binary, w, h) - - # --- Top / bottom edges: binary horizontal projection --- - top_edge, bottom_edge = _detect_top_bottom_edges(binary, w, h) - - # Compute border fractions - border_top = top_edge / h - border_bottom = (h - bottom_edge) / h - border_left = left_edge / w - border_right = (w - right_edge) / w - - result["border_fractions"] = { - "top": round(border_top, 4), - "bottom": round(border_bottom, 4), - "left": round(border_left, 4), - "right": round(border_right, 4), - } - - # Sanity: only crop if at least one edge has > 2% border - min_border = 0.02 - if all(f < min_border for f in [border_top, border_bottom, border_left, border_right]): - logger.info("All borders < %.0f%% — no crop needed", min_border * 100) - result["detected_format"], result["format_confidence"] = _detect_format(w, h) - return img_bgr, result - - # Add margin - margin_x = int(w * margin_frac) - margin_y = int(h * margin_frac) - - crop_x = max(0, left_edge - margin_x) - crop_y = max(0, top_edge - margin_y) - crop_x2 = min(w, right_edge + margin_x) - crop_y2 = min(h, bottom_edge + margin_y) - - crop_w = crop_x2 - crop_x - crop_h = crop_y2 - crop_y - - # Sanity: cropped area must be >= 40% of original - if crop_w * crop_h < 0.40 * total_area: - logger.warning("Cropped area too small (%.0f%%) — skipping crop", - 100.0 * crop_w * crop_h / total_area) - result["detected_format"], result["format_confidence"] = _detect_format(w, h) - return img_bgr, result - - cropped = img_bgr[crop_y:crop_y2, crop_x:crop_x2].copy() - - detected_format, format_confidence = _detect_format(crop_w, crop_h) - - result["crop_applied"] = True - result["crop_rect"] = {"x": crop_x, "y": crop_y, "width": crop_w, "height": crop_h} - result["crop_rect_pct"] = { - "x": round(100.0 * crop_x / w, 2), - "y": round(100.0 * crop_y / h, 2), - "width": round(100.0 * crop_w / w, 2), - "height": round(100.0 * crop_h / h, 2), - } - result["cropped_size"] = {"width": crop_w, "height": crop_h} - result["detected_format"] = detected_format - result["format_confidence"] = format_confidence - result["aspect_ratio"] = round(max(crop_w, crop_h) / max(min(crop_w, crop_h), 1), 4) - - logger.info( - "Page cropped: %dx%d -> %dx%d, format=%s (%.0f%%), " - "borders: T=%.1f%% B=%.1f%% L=%.1f%% R=%.1f%%", - w, h, crop_w, crop_h, detected_format, format_confidence * 100, - border_top * 100, border_bottom * 100, - border_left * 100, border_right * 100, - ) - - return cropped, result - - -# --------------------------------------------------------------------------- # Edge detection helpers -# --------------------------------------------------------------------------- - -def _detect_spine_shadow( - gray: np.ndarray, - search_region: np.ndarray, - offset_x: int, - w: int, - side: str, -) -> Optional[int]: - """Find the book spine center (darkest point) in a scanner shadow. - - The scanner produces a gray strip where the book spine presses against - the glass. The darkest column in that strip is the spine center — - that's where we crop. - - Distinguishes real spine shadows from text content by checking: - 1. Strong brightness range (> 40 levels) - 2. Darkest point is genuinely dark (< 180 mean brightness) - 3. The dark area is a NARROW valley, not a text-content plateau - 4. Brightness rises significantly toward the page content side - - Args: - gray: Full grayscale image (for context). - search_region: Column slice of the grayscale image to search in. - offset_x: X offset of search_region relative to full image. - w: Full image width. - side: 'left' or 'right' (for logging). - - Returns: - X coordinate (in full image) of the spine center, or None. - """ - region_w = search_region.shape[1] - if region_w < 10: - return None - - # Column-mean brightness in the search region - col_means = np.mean(search_region, axis=0).astype(np.float64) - - # Smooth with boxcar kernel (width = 1% of image width, min 5) - kernel_size = max(5, w // 100) - if kernel_size % 2 == 0: - kernel_size += 1 - kernel = np.ones(kernel_size) / kernel_size - smoothed_raw = np.convolve(col_means, kernel, mode="same") - - # Trim convolution edge artifacts (edges are zero-padded → artificially low) - margin = kernel_size // 2 - if region_w <= 2 * margin + 10: - return None - smoothed = smoothed_raw[margin:region_w - margin] - trim_offset = margin # offset of smoothed[0] relative to search_region - - val_min = float(np.min(smoothed)) - val_max = float(np.max(smoothed)) - shadow_range = val_max - val_min - - # --- Check 1: Strong brightness gradient --- - if shadow_range <= 40: - logger.debug( - "%s edge: no spine (range=%.0f <= 40)", side.capitalize(), shadow_range, - ) - return None - - # --- Check 2: Darkest point must be genuinely dark --- - # Spine shadows have mean column brightness 60-160. - # Text on white paper stays above 180. - if val_min > 180: - logger.debug( - "%s edge: no spine (darkest=%.0f > 180, likely text)", side.capitalize(), val_min, - ) - return None - - spine_idx = int(np.argmin(smoothed)) # index in trimmed array - spine_local = spine_idx + trim_offset # index in search_region - trimmed_len = len(smoothed) - - # --- Check 3: Valley width (spine is narrow, text plateau is wide) --- - # Count how many columns are within 20% of the shadow range above the min. - valley_thresh = val_min + shadow_range * 0.20 - valley_mask = smoothed < valley_thresh - valley_width = int(np.sum(valley_mask)) - # Spine valleys are typically 3-15% of image width (20-120px on a 800px image). - # Text content plateaus span 20%+ of the search region. - max_valley_frac = 0.50 # valley must not cover more than half the trimmed region - if valley_width > trimmed_len * max_valley_frac: - logger.debug( - "%s edge: no spine (valley too wide: %d/%d = %.0f%%)", - side.capitalize(), valley_width, trimmed_len, - 100.0 * valley_width / trimmed_len, - ) - return None - - # --- Check 4: Brightness must rise toward page content --- - # For left edge: after spine, brightness should rise (= page paper) - # For right edge: before spine, brightness should rise - rise_check_w = max(5, trimmed_len // 5) # check 20% of trimmed region - if side == "left": - # Check columns to the right of the spine (in trimmed array) - right_start = min(spine_idx + 5, trimmed_len - 1) - right_end = min(right_start + rise_check_w, trimmed_len) - if right_end > right_start: - rise_brightness = float(np.mean(smoothed[right_start:right_end])) - rise = rise_brightness - val_min - if rise < shadow_range * 0.3: - logger.debug( - "%s edge: no spine (insufficient rise: %.0f, need %.0f)", - side.capitalize(), rise, shadow_range * 0.3, - ) - return None - else: # right - # Check columns to the left of the spine (in trimmed array) - left_end = max(spine_idx - 5, 0) - left_start = max(left_end - rise_check_w, 0) - if left_end > left_start: - rise_brightness = float(np.mean(smoothed[left_start:left_end])) - rise = rise_brightness - val_min - if rise < shadow_range * 0.3: - logger.debug( - "%s edge: no spine (insufficient rise: %.0f, need %.0f)", - side.capitalize(), rise, shadow_range * 0.3, - ) - return None - - spine_x = offset_x + spine_local - - logger.info( - "%s edge: spine center at x=%d (brightness=%.0f, range=%.0f, valley=%dpx)", - side.capitalize(), spine_x, val_min, shadow_range, valley_width, - ) - return spine_x - - -def _detect_gutter_continuity( - gray: np.ndarray, - search_region: np.ndarray, - offset_x: int, - w: int, - side: str, -) -> Optional[int]: - """Detect gutter shadow via vertical continuity analysis. - - Camera book scans produce a subtle brightness gradient at the gutter - that is too faint for scanner-shadow detection (range < 40). However, - the gutter shadow has a unique property: it runs **continuously from - top to bottom** without interruption. Text and images always have - vertical gaps between lines, paragraphs, or sections. - - Algorithm: - 1. Divide image into N horizontal strips (~60px each) - 2. For each column, compute what fraction of strips are darker than - the page median (from the center 50% of the full image) - 3. A "gutter column" has ≥ 75% of strips darker than page_median − δ - 4. Smooth the dark-fraction profile and find the transition point - from the edge inward where the fraction drops below 0.50 - 5. Validate: gutter band must be 0.5%-10% of image width - - Args: - gray: Full grayscale image. - search_region: Edge slice of the grayscale image. - offset_x: X offset of search_region relative to full image. - w: Full image width. - side: 'left' or 'right'. - - Returns: - X coordinate (in full image) of the gutter inner edge, or None. - """ - region_h, region_w = search_region.shape[:2] - if region_w < 20 or region_h < 100: - return None - - # --- 1. Divide into horizontal strips --- - strip_target_h = 60 # ~60px per strip - n_strips = max(10, region_h // strip_target_h) - strip_h = region_h // n_strips - - strip_means = np.zeros((n_strips, region_w), dtype=np.float64) - for s in range(n_strips): - y0 = s * strip_h - y1 = min((s + 1) * strip_h, region_h) - strip_means[s] = np.mean(search_region[y0:y1, :], axis=0) - - # --- 2. Page median from center 50% of full image --- - center_lo = w // 4 - center_hi = 3 * w // 4 - page_median = float(np.median(gray[:, center_lo:center_hi])) - - # Camera shadows are subtle — threshold just 5 levels below page median - dark_thresh = page_median - 5.0 - - # If page is very dark overall (e.g. photo, not a book page), bail out - if page_median < 180: - return None - - # --- 3. Per-column dark fraction --- - dark_count = np.sum(strip_means < dark_thresh, axis=0).astype(np.float64) - dark_frac = dark_count / n_strips # shape: (region_w,) - - # --- 4. Smooth and find transition --- - # Rolling mean (window = 1% of image width, min 5) - smooth_w = max(5, w // 100) - if smooth_w % 2 == 0: - smooth_w += 1 - kernel = np.ones(smooth_w) / smooth_w - frac_smooth = np.convolve(dark_frac, kernel, mode="same") - - # Trim convolution edges - margin = smooth_w // 2 - if region_w <= 2 * margin + 10: - return None - - # Find the peak of dark fraction (gutter center). - # For right gutters the peak is near the edge; for left gutters - # (V-shaped spine shadow) the peak may be well inside the region. - transition_thresh = 0.50 - peak_frac = float(np.max(frac_smooth[margin:region_w - margin])) - - if peak_frac < 0.70: - logger.debug( - "%s gutter: peak dark fraction %.2f < 0.70", side.capitalize(), peak_frac, - ) - return None - - peak_x = int(np.argmax(frac_smooth[margin:region_w - margin])) + margin - gutter_inner = None # local x in search_region - - if side == "right": - # Scan from peak toward the page center (leftward) - for x in range(peak_x, margin, -1): - if frac_smooth[x] < transition_thresh: - gutter_inner = x + 1 - break - else: - # Scan from peak toward the page center (rightward) - for x in range(peak_x, region_w - margin): - if frac_smooth[x] < transition_thresh: - gutter_inner = x - 1 - break - - if gutter_inner is None: - return None - - # --- 5. Validate gutter width --- - if side == "right": - gutter_width = region_w - gutter_inner - else: - gutter_width = gutter_inner - - min_gutter = max(3, int(w * 0.005)) # at least 0.5% of image - max_gutter = int(w * 0.10) # at most 10% of image - - if gutter_width < min_gutter: - logger.debug( - "%s gutter: too narrow (%dpx < %dpx)", side.capitalize(), - gutter_width, min_gutter, - ) - return None - - if gutter_width > max_gutter: - logger.debug( - "%s gutter: too wide (%dpx > %dpx)", side.capitalize(), - gutter_width, max_gutter, - ) - return None - - # Check that the gutter band is meaningfully darker than the page - if side == "right": - gutter_brightness = float(np.mean(strip_means[:, gutter_inner:])) - else: - gutter_brightness = float(np.mean(strip_means[:, :gutter_inner])) - - brightness_drop = page_median - gutter_brightness - if brightness_drop < 3: - logger.debug( - "%s gutter: insufficient brightness drop (%.1f levels)", - side.capitalize(), brightness_drop, - ) - return None - - gutter_x = offset_x + gutter_inner - - logger.info( - "%s gutter (continuity): x=%d, width=%dpx (%.1f%%), " - "brightness=%.0f vs page=%.0f (drop=%.0f), frac@edge=%.2f", - side.capitalize(), gutter_x, gutter_width, - 100.0 * gutter_width / w, gutter_brightness, page_median, - brightness_drop, float(frac_smooth[gutter_inner]), - ) - return gutter_x - - -def _detect_left_edge_shadow( - gray: np.ndarray, - binary: np.ndarray, - w: int, - h: int, -) -> int: - """Detect left content edge, accounting for book-spine shadow. - - Tries three methods in order: - 1. Scanner spine-shadow (dark gradient, range > 40) - 2. Camera gutter continuity (subtle shadow running top-to-bottom) - 3. Binary projection fallback (first ink column) - """ - search_w = max(1, w // 4) - spine_x = _detect_spine_shadow(gray, gray[:, :search_w], 0, w, "left") - if spine_x is not None: - return spine_x - - # Fallback 1: vertical continuity (camera gutter shadow) - gutter_x = _detect_gutter_continuity(gray, gray[:, :search_w], 0, w, "left") - if gutter_x is not None: - return gutter_x - - # Fallback 2: binary vertical projection - return _detect_edge_projection(binary, axis=0, from_start=True, dim=w) - - -def _detect_right_edge_shadow( - gray: np.ndarray, - binary: np.ndarray, - w: int, - h: int, -) -> int: - """Detect right content edge, accounting for book-spine shadow. - - Tries three methods in order: - 1. Scanner spine-shadow (dark gradient, range > 40) - 2. Camera gutter continuity (subtle shadow running top-to-bottom) - 3. Binary projection fallback (last ink column) - """ - search_w = max(1, w // 4) - right_start = w - search_w - spine_x = _detect_spine_shadow(gray, gray[:, right_start:], right_start, w, "right") - if spine_x is not None: - return spine_x - - # Fallback 1: vertical continuity (camera gutter shadow) - gutter_x = _detect_gutter_continuity(gray, gray[:, right_start:], right_start, w, "right") - if gutter_x is not None: - return gutter_x - - # Fallback 2: binary vertical projection - return _detect_edge_projection(binary, axis=0, from_start=False, dim=w) - - -def _detect_top_bottom_edges(binary: np.ndarray, w: int, h: int) -> Tuple[int, int]: - """Detect top and bottom content edges via binary horizontal projection.""" - top = _detect_edge_projection(binary, axis=1, from_start=True, dim=h) - bottom = _detect_edge_projection(binary, axis=1, from_start=False, dim=h) - return top, bottom - - -def _detect_edge_projection( - binary: np.ndarray, - axis: int, - from_start: bool, - dim: int, -) -> int: - """Find the first/last row or column with ink density above threshold. - - axis=0 → project vertically (column densities) → returns x position - axis=1 → project horizontally (row densities) → returns y position - - Filters out narrow noise runs shorter than _MIN_RUN_FRAC of the dimension. - """ - # Compute density per row/column (mean of binary pixels / 255) - projection = np.mean(binary, axis=axis) / 255.0 - - # Create mask of "ink" positions - ink_mask = projection >= _INK_THRESHOLD - - # Filter narrow runs (noise) - min_run = max(1, int(dim * _MIN_RUN_FRAC)) - ink_mask = _filter_narrow_runs(ink_mask, min_run) - - ink_positions = np.where(ink_mask)[0] - if len(ink_positions) == 0: - return 0 if from_start else dim - - if from_start: - return int(ink_positions[0]) - else: - return int(ink_positions[-1]) - - -def _filter_narrow_runs(mask: np.ndarray, min_run: int) -> np.ndarray: - """Remove True-runs shorter than min_run pixels.""" - if min_run <= 1: - return mask - - result = mask.copy() - n = len(result) - i = 0 - while i < n: - if result[i]: - start = i - while i < n and result[i]: - i += 1 - if i - start < min_run: - result[start:i] = False - else: - i += 1 - return result - - -# --------------------------------------------------------------------------- -# Format detection (kept as optional metadata) -# --------------------------------------------------------------------------- - -def _detect_format(width: int, height: int) -> Tuple[str, float]: - """Detect paper format from dimensions by comparing aspect ratios.""" - if width <= 0 or height <= 0: - return "unknown", 0.0 - - aspect = max(width, height) / min(width, height) - - best_format = "unknown" - best_diff = float("inf") - - for fmt, expected_ratio in PAPER_FORMATS.items(): - diff = abs(aspect - expected_ratio) - if diff < best_diff: - best_diff = diff - best_format = fmt - - confidence = max(0.0, 1.0 - best_diff * 5.0) - - if confidence < 0.3: - return "unknown", 0.0 - - return best_format, round(confidence, 3) +from page_crop_edges import ( # noqa: F401 + _INK_THRESHOLD, + _MIN_RUN_FRAC, + _detect_spine_shadow, + _detect_gutter_continuity, + _detect_left_edge_shadow, + _detect_right_edge_shadow, + _detect_top_bottom_edges, + _detect_edge_projection, + _filter_narrow_runs, +) diff --git a/klausur-service/backend/page_crop_core.py b/klausur-service/backend/page_crop_core.py new file mode 100644 index 0000000..53c3723 --- /dev/null +++ b/klausur-service/backend/page_crop_core.py @@ -0,0 +1,342 @@ +""" +Page Crop - Core Crop and Format Detection + +Content-based crop for scanned pages and book scans. Detects the content +boundary by analysing ink density projections and (for book scans) the +spine shadow gradient. + +Extracted from page_crop.py to keep files under 500 LOC. +License: Apache 2.0 +""" + +import logging +from typing import Dict, Any, Tuple + +import cv2 +import numpy as np + +from page_crop_edges import ( + _detect_left_edge_shadow, + _detect_right_edge_shadow, + _detect_top_bottom_edges, +) + +logger = logging.getLogger(__name__) + +# Known paper format aspect ratios (height / width, portrait orientation) +PAPER_FORMATS = { + "A4": 297.0 / 210.0, # 1.4143 + "A5": 210.0 / 148.0, # 1.4189 + "Letter": 11.0 / 8.5, # 1.2941 + "Legal": 14.0 / 8.5, # 1.6471 + "A3": 420.0 / 297.0, # 1.4141 +} + + +def detect_page_splits( + img_bgr: np.ndarray, +) -> list: + """Detect if the image is a multi-page spread and return split rectangles. + + Uses **brightness** (not ink density) to find the spine area: + the scanner bed produces a characteristic gray strip where pages meet, + which is darker than the white paper on either side. + + Returns a list of page dicts ``{x, y, width, height, page_index}`` + or an empty list if only one page is detected. + """ + h, w = img_bgr.shape[:2] + + # Only check landscape-ish images (width > height * 1.15) + if w < h * 1.15: + return [] + + gray = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY) + + # Column-mean brightness (0-255) — the spine is darker (gray scanner bed) + col_brightness = np.mean(gray, axis=0).astype(np.float64) + + # Heavy smoothing to ignore individual text lines + kern = max(11, w // 50) + if kern % 2 == 0: + kern += 1 + brightness_smooth = np.convolve(col_brightness, np.ones(kern) / kern, mode="same") + + # Page paper is bright (typically > 200), spine/scanner bed is darker + page_brightness = float(np.max(brightness_smooth)) + if page_brightness < 100: + return [] # Very dark image, skip + + # Spine threshold: significantly darker than the page + spine_thresh = page_brightness * 0.88 + + # Search in center region (30-70% of width) + center_lo = int(w * 0.30) + center_hi = int(w * 0.70) + + # Find the darkest valley in the center region + center_brightness = brightness_smooth[center_lo:center_hi] + darkest_val = float(np.min(center_brightness)) + + if darkest_val >= spine_thresh: + logger.debug("No spine detected: min brightness %.0f >= threshold %.0f", + darkest_val, spine_thresh) + return [] + + # Find ALL contiguous dark runs in the center region + is_dark = center_brightness < spine_thresh + dark_runs: list = [] + run_start = -1 + for i in range(len(is_dark)): + if is_dark[i]: + if run_start < 0: + run_start = i + else: + if run_start >= 0: + dark_runs.append((run_start, i)) + run_start = -1 + if run_start >= 0: + dark_runs.append((run_start, len(is_dark))) + + # Filter out runs that are too narrow (< 1% of image width) + min_spine_px = int(w * 0.01) + dark_runs = [(s, e) for s, e in dark_runs if e - s >= min_spine_px] + + if not dark_runs: + logger.debug("No dark runs wider than %dpx in center region", min_spine_px) + return [] + + # Score each dark run: prefer centered, dark, narrow valleys + center_region_len = center_hi - center_lo + image_center_in_region = (w * 0.5 - center_lo) + best_score = -1.0 + best_start, best_end = dark_runs[0] + + for rs, re in dark_runs: + run_width = re - rs + run_center = (rs + re) / 2.0 + + sigma = center_region_len * 0.15 + dist = abs(run_center - image_center_in_region) + center_factor = float(np.exp(-0.5 * (dist / sigma) ** 2)) + + run_brightness = float(np.mean(center_brightness[rs:re])) + darkness_factor = max(0.0, (spine_thresh - run_brightness) / spine_thresh) + + width_frac = run_width / w + if width_frac <= 0.05: + narrowness_bonus = 1.0 + elif width_frac <= 0.15: + narrowness_bonus = 1.0 - (width_frac - 0.05) / 0.10 + else: + narrowness_bonus = 0.0 + + score = center_factor * darkness_factor * (0.3 + 0.7 * narrowness_bonus) + + logger.debug( + "Dark run x=%d..%d (w=%d): center_f=%.3f dark_f=%.3f narrow_b=%.3f -> score=%.4f", + center_lo + rs, center_lo + re, run_width, + center_factor, darkness_factor, narrowness_bonus, score, + ) + + if score > best_score: + best_score = score + best_start, best_end = rs, re + + spine_w = best_end - best_start + spine_x = center_lo + best_start + spine_center = spine_x + spine_w // 2 + + logger.debug( + "Best spine candidate: x=%d..%d (w=%d), score=%.4f", + spine_x, spine_x + spine_w, spine_w, best_score, + ) + + # Verify: must have bright (paper) content on BOTH sides + left_brightness = float(np.mean(brightness_smooth[max(0, spine_x - w // 10):spine_x])) + right_end = center_lo + best_end + right_brightness = float(np.mean(brightness_smooth[right_end:min(w, right_end + w // 10)])) + + if left_brightness < spine_thresh or right_brightness < spine_thresh: + logger.debug("No bright paper flanking spine: left=%.0f right=%.0f thresh=%.0f", + left_brightness, right_brightness, spine_thresh) + return [] + + logger.info( + "Spine detected: x=%d..%d (w=%d), brightness=%.0f vs paper=%.0f, " + "left_paper=%.0f, right_paper=%.0f", + spine_x, right_end, spine_w, darkest_val, page_brightness, + left_brightness, right_brightness, + ) + + # Split at the spine center + split_points = [spine_center] + + # Build page rectangles + pages: list = [] + prev_x = 0 + for i, sx in enumerate(split_points): + pages.append({"x": prev_x, "y": 0, "width": sx - prev_x, + "height": h, "page_index": i}) + prev_x = sx + pages.append({"x": prev_x, "y": 0, "width": w - prev_x, + "height": h, "page_index": len(split_points)}) + + # Filter out tiny pages (< 15% of total width) + pages = [p for p in pages if p["width"] >= w * 0.15] + if len(pages) < 2: + return [] + + # Re-index + for i, p in enumerate(pages): + p["page_index"] = i + + logger.info( + "Page split detected: %d pages, spine_w=%d, split_points=%s", + len(pages), spine_w, split_points, + ) + return pages + + +def detect_and_crop_page( + img_bgr: np.ndarray, + margin_frac: float = 0.01, +) -> Tuple[np.ndarray, Dict[str, Any]]: + """Detect content boundary and crop scanner/book borders. + + Algorithm (4-edge detection): + 1. Adaptive threshold -> binary (text=255, bg=0) + 2. Left edge: spine-shadow detection via grayscale column means, + fallback to binary vertical projection + 3. Right edge: binary vertical projection (last ink column) + 4. Top/bottom edges: binary horizontal projection + 5. Sanity checks, then crop with configurable margin + + Args: + img_bgr: Input BGR image (should already be deskewed/dewarped) + margin_frac: Extra margin around content (fraction of dimension, default 1%) + + Returns: + Tuple of (cropped_image, result_dict) + """ + h, w = img_bgr.shape[:2] + total_area = h * w + + result: Dict[str, Any] = { + "crop_applied": False, + "crop_rect": None, + "crop_rect_pct": None, + "original_size": {"width": w, "height": h}, + "cropped_size": {"width": w, "height": h}, + "detected_format": None, + "format_confidence": 0.0, + "aspect_ratio": round(max(h, w) / max(min(h, w), 1), 4), + "border_fractions": {"top": 0.0, "bottom": 0.0, "left": 0.0, "right": 0.0}, + } + + gray = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY) + + # --- Binarise with adaptive threshold --- + binary = cv2.adaptiveThreshold( + gray, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, + cv2.THRESH_BINARY_INV, blockSize=51, C=15, + ) + + # --- Edge detection --- + left_edge = _detect_left_edge_shadow(gray, binary, w, h) + right_edge = _detect_right_edge_shadow(gray, binary, w, h) + top_edge, bottom_edge = _detect_top_bottom_edges(binary, w, h) + + # Compute border fractions + border_top = top_edge / h + border_bottom = (h - bottom_edge) / h + border_left = left_edge / w + border_right = (w - right_edge) / w + + result["border_fractions"] = { + "top": round(border_top, 4), + "bottom": round(border_bottom, 4), + "left": round(border_left, 4), + "right": round(border_right, 4), + } + + # Sanity: only crop if at least one edge has > 2% border + min_border = 0.02 + if all(f < min_border for f in [border_top, border_bottom, border_left, border_right]): + logger.info("All borders < %.0f%% — no crop needed", min_border * 100) + result["detected_format"], result["format_confidence"] = _detect_format(w, h) + return img_bgr, result + + # Add margin + margin_x = int(w * margin_frac) + margin_y = int(h * margin_frac) + + crop_x = max(0, left_edge - margin_x) + crop_y = max(0, top_edge - margin_y) + crop_x2 = min(w, right_edge + margin_x) + crop_y2 = min(h, bottom_edge + margin_y) + + crop_w = crop_x2 - crop_x + crop_h = crop_y2 - crop_y + + # Sanity: cropped area must be >= 40% of original + if crop_w * crop_h < 0.40 * total_area: + logger.warning("Cropped area too small (%.0f%%) — skipping crop", + 100.0 * crop_w * crop_h / total_area) + result["detected_format"], result["format_confidence"] = _detect_format(w, h) + return img_bgr, result + + cropped = img_bgr[crop_y:crop_y2, crop_x:crop_x2].copy() + + detected_format, format_confidence = _detect_format(crop_w, crop_h) + + result["crop_applied"] = True + result["crop_rect"] = {"x": crop_x, "y": crop_y, "width": crop_w, "height": crop_h} + result["crop_rect_pct"] = { + "x": round(100.0 * crop_x / w, 2), + "y": round(100.0 * crop_y / h, 2), + "width": round(100.0 * crop_w / w, 2), + "height": round(100.0 * crop_h / h, 2), + } + result["cropped_size"] = {"width": crop_w, "height": crop_h} + result["detected_format"] = detected_format + result["format_confidence"] = format_confidence + result["aspect_ratio"] = round(max(crop_w, crop_h) / max(min(crop_w, crop_h), 1), 4) + + logger.info( + "Page cropped: %dx%d -> %dx%d, format=%s (%.0f%%), " + "borders: T=%.1f%% B=%.1f%% L=%.1f%% R=%.1f%%", + w, h, crop_w, crop_h, detected_format, format_confidence * 100, + border_top * 100, border_bottom * 100, + border_left * 100, border_right * 100, + ) + + return cropped, result + + +# --------------------------------------------------------------------------- +# Format detection (kept as optional metadata) +# --------------------------------------------------------------------------- + +def _detect_format(width: int, height: int) -> Tuple[str, float]: + """Detect paper format from dimensions by comparing aspect ratios.""" + if width <= 0 or height <= 0: + return "unknown", 0.0 + + aspect = max(width, height) / min(width, height) + + best_format = "unknown" + best_diff = float("inf") + + for fmt, expected_ratio in PAPER_FORMATS.items(): + diff = abs(aspect - expected_ratio) + if diff < best_diff: + best_diff = diff + best_format = fmt + + confidence = max(0.0, 1.0 - best_diff * 5.0) + + if confidence < 0.3: + return "unknown", 0.0 + + return best_format, round(confidence, 3) diff --git a/klausur-service/backend/page_crop_edges.py b/klausur-service/backend/page_crop_edges.py new file mode 100644 index 0000000..b231078 --- /dev/null +++ b/klausur-service/backend/page_crop_edges.py @@ -0,0 +1,388 @@ +""" +Page Crop - Edge Detection Helpers + +Spine shadow detection, gutter continuity analysis, projection-based +edge detection, and narrow-run filtering for content cropping. + +Extracted from page_crop.py to keep files under 500 LOC. +License: Apache 2.0 +""" + +import logging +from typing import Optional, Tuple + +import cv2 +import numpy as np + +logger = logging.getLogger(__name__) + +# Minimum ink density (fraction of pixels) to count a row/column as "content" +_INK_THRESHOLD = 0.003 # 0.3% + +# Minimum run length (fraction of dimension) to keep — shorter runs are noise +_MIN_RUN_FRAC = 0.005 # 0.5% + + +def _detect_spine_shadow( + gray: np.ndarray, + search_region: np.ndarray, + offset_x: int, + w: int, + side: str, +) -> Optional[int]: + """Find the book spine center (darkest point) in a scanner shadow. + + The scanner produces a gray strip where the book spine presses against + the glass. The darkest column in that strip is the spine center — + that's where we crop. + + Distinguishes real spine shadows from text content by checking: + 1. Strong brightness range (> 40 levels) + 2. Darkest point is genuinely dark (< 180 mean brightness) + 3. The dark area is a NARROW valley, not a text-content plateau + 4. Brightness rises significantly toward the page content side + + Args: + gray: Full grayscale image (for context). + search_region: Column slice of the grayscale image to search in. + offset_x: X offset of search_region relative to full image. + w: Full image width. + side: 'left' or 'right' (for logging). + + Returns: + X coordinate (in full image) of the spine center, or None. + """ + region_w = search_region.shape[1] + if region_w < 10: + return None + + # Column-mean brightness in the search region + col_means = np.mean(search_region, axis=0).astype(np.float64) + + # Smooth with boxcar kernel (width = 1% of image width, min 5) + kernel_size = max(5, w // 100) + if kernel_size % 2 == 0: + kernel_size += 1 + kernel = np.ones(kernel_size) / kernel_size + smoothed_raw = np.convolve(col_means, kernel, mode="same") + + # Trim convolution edge artifacts (edges are zero-padded -> artificially low) + margin = kernel_size // 2 + if region_w <= 2 * margin + 10: + return None + smoothed = smoothed_raw[margin:region_w - margin] + trim_offset = margin # offset of smoothed[0] relative to search_region + + val_min = float(np.min(smoothed)) + val_max = float(np.max(smoothed)) + shadow_range = val_max - val_min + + # --- Check 1: Strong brightness gradient --- + if shadow_range <= 40: + logger.debug( + "%s edge: no spine (range=%.0f <= 40)", side.capitalize(), shadow_range, + ) + return None + + # --- Check 2: Darkest point must be genuinely dark --- + if val_min > 180: + logger.debug( + "%s edge: no spine (darkest=%.0f > 180, likely text)", side.capitalize(), val_min, + ) + return None + + spine_idx = int(np.argmin(smoothed)) # index in trimmed array + spine_local = spine_idx + trim_offset # index in search_region + trimmed_len = len(smoothed) + + # --- Check 3: Valley width (spine is narrow, text plateau is wide) --- + valley_thresh = val_min + shadow_range * 0.20 + valley_mask = smoothed < valley_thresh + valley_width = int(np.sum(valley_mask)) + max_valley_frac = 0.50 + if valley_width > trimmed_len * max_valley_frac: + logger.debug( + "%s edge: no spine (valley too wide: %d/%d = %.0f%%)", + side.capitalize(), valley_width, trimmed_len, + 100.0 * valley_width / trimmed_len, + ) + return None + + # --- Check 4: Brightness must rise toward page content --- + rise_check_w = max(5, trimmed_len // 5) + if side == "left": + right_start = min(spine_idx + 5, trimmed_len - 1) + right_end = min(right_start + rise_check_w, trimmed_len) + if right_end > right_start: + rise_brightness = float(np.mean(smoothed[right_start:right_end])) + rise = rise_brightness - val_min + if rise < shadow_range * 0.3: + logger.debug( + "%s edge: no spine (insufficient rise: %.0f, need %.0f)", + side.capitalize(), rise, shadow_range * 0.3, + ) + return None + else: # right + left_end = max(spine_idx - 5, 0) + left_start = max(left_end - rise_check_w, 0) + if left_end > left_start: + rise_brightness = float(np.mean(smoothed[left_start:left_end])) + rise = rise_brightness - val_min + if rise < shadow_range * 0.3: + logger.debug( + "%s edge: no spine (insufficient rise: %.0f, need %.0f)", + side.capitalize(), rise, shadow_range * 0.3, + ) + return None + + spine_x = offset_x + spine_local + + logger.info( + "%s edge: spine center at x=%d (brightness=%.0f, range=%.0f, valley=%dpx)", + side.capitalize(), spine_x, val_min, shadow_range, valley_width, + ) + return spine_x + + +def _detect_gutter_continuity( + gray: np.ndarray, + search_region: np.ndarray, + offset_x: int, + w: int, + side: str, +) -> Optional[int]: + """Detect gutter shadow via vertical continuity analysis. + + Camera book scans produce a subtle brightness gradient at the gutter + that is too faint for scanner-shadow detection (range < 40). However, + the gutter shadow has a unique property: it runs **continuously from + top to bottom** without interruption. + + Algorithm: + 1. Divide image into N horizontal strips (~60px each) + 2. For each column, compute what fraction of strips are darker than + the page median (from the center 50% of the full image) + 3. A "gutter column" has >= 75% of strips darker than page_median - d + 4. Smooth the dark-fraction profile and find the transition point + 5. Validate: gutter band must be 0.5%-10% of image width + """ + region_h, region_w = search_region.shape[:2] + if region_w < 20 or region_h < 100: + return None + + # --- 1. Divide into horizontal strips --- + strip_target_h = 60 + n_strips = max(10, region_h // strip_target_h) + strip_h = region_h // n_strips + + strip_means = np.zeros((n_strips, region_w), dtype=np.float64) + for s in range(n_strips): + y0 = s * strip_h + y1 = min((s + 1) * strip_h, region_h) + strip_means[s] = np.mean(search_region[y0:y1, :], axis=0) + + # --- 2. Page median from center 50% of full image --- + center_lo = w // 4 + center_hi = 3 * w // 4 + page_median = float(np.median(gray[:, center_lo:center_hi])) + + dark_thresh = page_median - 5.0 + + if page_median < 180: + return None + + # --- 3. Per-column dark fraction --- + dark_count = np.sum(strip_means < dark_thresh, axis=0).astype(np.float64) + dark_frac = dark_count / n_strips + + # --- 4. Smooth and find transition --- + smooth_w = max(5, w // 100) + if smooth_w % 2 == 0: + smooth_w += 1 + kernel = np.ones(smooth_w) / smooth_w + frac_smooth = np.convolve(dark_frac, kernel, mode="same") + + margin = smooth_w // 2 + if region_w <= 2 * margin + 10: + return None + + transition_thresh = 0.50 + peak_frac = float(np.max(frac_smooth[margin:region_w - margin])) + + if peak_frac < 0.70: + logger.debug( + "%s gutter: peak dark fraction %.2f < 0.70", side.capitalize(), peak_frac, + ) + return None + + peak_x = int(np.argmax(frac_smooth[margin:region_w - margin])) + margin + gutter_inner = None + + if side == "right": + for x in range(peak_x, margin, -1): + if frac_smooth[x] < transition_thresh: + gutter_inner = x + 1 + break + else: + for x in range(peak_x, region_w - margin): + if frac_smooth[x] < transition_thresh: + gutter_inner = x - 1 + break + + if gutter_inner is None: + return None + + # --- 5. Validate gutter width --- + if side == "right": + gutter_width = region_w - gutter_inner + else: + gutter_width = gutter_inner + + min_gutter = max(3, int(w * 0.005)) + max_gutter = int(w * 0.10) + + if gutter_width < min_gutter: + logger.debug( + "%s gutter: too narrow (%dpx < %dpx)", side.capitalize(), + gutter_width, min_gutter, + ) + return None + + if gutter_width > max_gutter: + logger.debug( + "%s gutter: too wide (%dpx > %dpx)", side.capitalize(), + gutter_width, max_gutter, + ) + return None + + if side == "right": + gutter_brightness = float(np.mean(strip_means[:, gutter_inner:])) + else: + gutter_brightness = float(np.mean(strip_means[:, :gutter_inner])) + + brightness_drop = page_median - gutter_brightness + if brightness_drop < 3: + logger.debug( + "%s gutter: insufficient brightness drop (%.1f levels)", + side.capitalize(), brightness_drop, + ) + return None + + gutter_x = offset_x + gutter_inner + + logger.info( + "%s gutter (continuity): x=%d, width=%dpx (%.1f%%), " + "brightness=%.0f vs page=%.0f (drop=%.0f), frac@edge=%.2f", + side.capitalize(), gutter_x, gutter_width, + 100.0 * gutter_width / w, gutter_brightness, page_median, + brightness_drop, float(frac_smooth[gutter_inner]), + ) + return gutter_x + + +def _detect_left_edge_shadow( + gray: np.ndarray, + binary: np.ndarray, + w: int, + h: int, +) -> int: + """Detect left content edge, accounting for book-spine shadow. + + Tries three methods in order: + 1. Scanner spine-shadow (dark gradient, range > 40) + 2. Camera gutter continuity (subtle shadow running top-to-bottom) + 3. Binary projection fallback (first ink column) + """ + search_w = max(1, w // 4) + spine_x = _detect_spine_shadow(gray, gray[:, :search_w], 0, w, "left") + if spine_x is not None: + return spine_x + + gutter_x = _detect_gutter_continuity(gray, gray[:, :search_w], 0, w, "left") + if gutter_x is not None: + return gutter_x + + return _detect_edge_projection(binary, axis=0, from_start=True, dim=w) + + +def _detect_right_edge_shadow( + gray: np.ndarray, + binary: np.ndarray, + w: int, + h: int, +) -> int: + """Detect right content edge, accounting for book-spine shadow. + + Tries three methods in order: + 1. Scanner spine-shadow (dark gradient, range > 40) + 2. Camera gutter continuity (subtle shadow running top-to-bottom) + 3. Binary projection fallback (last ink column) + """ + search_w = max(1, w // 4) + right_start = w - search_w + spine_x = _detect_spine_shadow(gray, gray[:, right_start:], right_start, w, "right") + if spine_x is not None: + return spine_x + + gutter_x = _detect_gutter_continuity(gray, gray[:, right_start:], right_start, w, "right") + if gutter_x is not None: + return gutter_x + + return _detect_edge_projection(binary, axis=0, from_start=False, dim=w) + + +def _detect_top_bottom_edges(binary: np.ndarray, w: int, h: int) -> Tuple[int, int]: + """Detect top and bottom content edges via binary horizontal projection.""" + top = _detect_edge_projection(binary, axis=1, from_start=True, dim=h) + bottom = _detect_edge_projection(binary, axis=1, from_start=False, dim=h) + return top, bottom + + +def _detect_edge_projection( + binary: np.ndarray, + axis: int, + from_start: bool, + dim: int, +) -> int: + """Find the first/last row or column with ink density above threshold. + + axis=0 -> project vertically (column densities) -> returns x position + axis=1 -> project horizontally (row densities) -> returns y position + + Filters out narrow noise runs shorter than _MIN_RUN_FRAC of the dimension. + """ + projection = np.mean(binary, axis=axis) / 255.0 + + ink_mask = projection >= _INK_THRESHOLD + + min_run = max(1, int(dim * _MIN_RUN_FRAC)) + ink_mask = _filter_narrow_runs(ink_mask, min_run) + + ink_positions = np.where(ink_mask)[0] + if len(ink_positions) == 0: + return 0 if from_start else dim + + if from_start: + return int(ink_positions[0]) + else: + return int(ink_positions[-1]) + + +def _filter_narrow_runs(mask: np.ndarray, min_run: int) -> np.ndarray: + """Remove True-runs shorter than min_run pixels.""" + if min_run <= 1: + return mask + + result = mask.copy() + n = len(result) + i = 0 + while i < n: + if result[i]: + start = i + while i < n and result[i]: + i += 1 + if i - start < min_run: + result[start:i] = False + else: + i += 1 + return result diff --git a/klausur-service/backend/services/trocr_batch.py b/klausur-service/backend/services/trocr_batch.py new file mode 100644 index 0000000..2ec8417 --- /dev/null +++ b/klausur-service/backend/services/trocr_batch.py @@ -0,0 +1,160 @@ +""" +TrOCR Batch Processing & Streaming + +Batch OCR and SSE streaming for multiple images. + +DATENSCHUTZ: Alle Verarbeitung erfolgt lokal. +""" + +import asyncio +import logging +import time +from typing import Optional, List, Dict, Any + +from .trocr_models import OCRResult, BatchOCRResult +from .trocr_ocr import run_trocr_ocr_enhanced + +logger = logging.getLogger(__name__) + + +async def run_trocr_batch( + images: List[bytes], + handwritten: bool = True, + split_lines: bool = True, + use_cache: bool = True, + progress_callback: Optional[callable] = None +) -> BatchOCRResult: + """ + Process multiple images in batch. + + Args: + images: List of image data bytes + handwritten: Use handwritten model + split_lines: Whether to split images into lines + use_cache: Whether to use caching + progress_callback: Optional callback(current, total) for progress updates + + Returns: + BatchOCRResult with all results + """ + start_time = time.time() + results = [] + cached_count = 0 + error_count = 0 + + for idx, image_data in enumerate(images): + try: + result = await run_trocr_ocr_enhanced( + image_data, + handwritten=handwritten, + split_lines=split_lines, + use_cache=use_cache + ) + results.append(result) + + if result.from_cache: + cached_count += 1 + + # Report progress + if progress_callback: + progress_callback(idx + 1, len(images)) + + except Exception as e: + logger.error(f"Batch OCR error for image {idx}: {e}") + error_count += 1 + results.append(OCRResult( + text=f"Error: {str(e)}", + confidence=0.0, + processing_time_ms=0, + model="error", + has_lora_adapter=False + )) + + total_time_ms = int((time.time() - start_time) * 1000) + + return BatchOCRResult( + results=results, + total_time_ms=total_time_ms, + processed_count=len(images), + cached_count=cached_count, + error_count=error_count + ) + + +# Generator for SSE streaming during batch processing +async def run_trocr_batch_stream( + images: List[bytes], + handwritten: bool = True, + split_lines: bool = True, + use_cache: bool = True +): + """ + Process images and yield progress updates for SSE streaming. + + Yields: + dict with current progress and result + """ + start_time = time.time() + total = len(images) + + for idx, image_data in enumerate(images): + try: + result = await run_trocr_ocr_enhanced( + image_data, + handwritten=handwritten, + split_lines=split_lines, + use_cache=use_cache + ) + + elapsed_ms = int((time.time() - start_time) * 1000) + avg_time_per_image = elapsed_ms / (idx + 1) + estimated_remaining = int(avg_time_per_image * (total - idx - 1)) + + yield { + "type": "progress", + "current": idx + 1, + "total": total, + "progress_percent": ((idx + 1) / total) * 100, + "elapsed_ms": elapsed_ms, + "estimated_remaining_ms": estimated_remaining, + "result": { + "text": result.text, + "confidence": result.confidence, + "processing_time_ms": result.processing_time_ms, + "from_cache": result.from_cache + } + } + + except Exception as e: + logger.error(f"Stream OCR error for image {idx}: {e}") + yield { + "type": "error", + "current": idx + 1, + "total": total, + "error": str(e) + } + + total_time_ms = int((time.time() - start_time) * 1000) + yield { + "type": "complete", + "total_time_ms": total_time_ms, + "processed_count": total + } + + +# Test function +async def test_trocr_ocr(image_path: str, handwritten: bool = False): + """Test TrOCR on a local image file.""" + from .trocr_ocr import run_trocr_ocr + + with open(image_path, "rb") as f: + image_data = f.read() + + text, confidence = await run_trocr_ocr(image_data, handwritten=handwritten) + + print(f"\n=== TrOCR Test ===") + print(f"Mode: {'Handwritten' if handwritten else 'Printed'}") + print(f"Confidence: {confidence:.2f}") + print(f"Text:\n{text}") + + return text, confidence diff --git a/klausur-service/backend/services/trocr_models.py b/klausur-service/backend/services/trocr_models.py new file mode 100644 index 0000000..c9c0553 --- /dev/null +++ b/klausur-service/backend/services/trocr_models.py @@ -0,0 +1,278 @@ +""" +TrOCR Models & Cache + +Dataclasses, LRU cache, and model loading for TrOCR service. + +DATENSCHUTZ: Alle Verarbeitung erfolgt lokal. +""" + +import io +import os +import hashlib +import logging +import time +from typing import Tuple, Optional, List, Dict, Any +from dataclasses import dataclass, field +from collections import OrderedDict +from datetime import datetime, timedelta + +logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# Backend routing: auto | pytorch | onnx +# --------------------------------------------------------------------------- +_trocr_backend = os.environ.get("TROCR_BACKEND", "auto") # auto | pytorch | onnx + +# Lazy loading for heavy dependencies +# Cache keyed by model_name to support base and large variants simultaneously +_trocr_models: dict = {} # {model_name: (processor, model)} +_trocr_processor = None # backwards-compat alias -> base-printed +_trocr_model = None # backwards-compat alias -> base-printed +_trocr_available = None +_model_loaded_at = None + +# Simple in-memory cache with LRU eviction +_ocr_cache: OrderedDict[str, Dict[str, Any]] = OrderedDict() +_cache_max_size = 100 +_cache_ttl_seconds = 3600 # 1 hour + + +@dataclass +class OCRResult: + """Enhanced OCR result with detailed information.""" + text: str + confidence: float + processing_time_ms: int + model: str + has_lora_adapter: bool = False + char_confidences: List[float] = field(default_factory=list) + word_boxes: List[Dict[str, Any]] = field(default_factory=list) + from_cache: bool = False + image_hash: str = "" + + +@dataclass +class BatchOCRResult: + """Result for batch processing.""" + results: List[OCRResult] + total_time_ms: int + processed_count: int + cached_count: int + error_count: int + + +def _compute_image_hash(image_data: bytes) -> str: + """Compute SHA256 hash of image data for caching.""" + return hashlib.sha256(image_data).hexdigest()[:16] + + +def _cache_get(image_hash: str) -> Optional[Dict[str, Any]]: + """Get cached OCR result if available and not expired.""" + if image_hash in _ocr_cache: + entry = _ocr_cache[image_hash] + if datetime.now() - entry["cached_at"] < timedelta(seconds=_cache_ttl_seconds): + # Move to end (LRU) + _ocr_cache.move_to_end(image_hash) + return entry["result"] + else: + # Expired, remove + del _ocr_cache[image_hash] + return None + + +def _cache_set(image_hash: str, result: Dict[str, Any]) -> None: + """Store OCR result in cache.""" + # Evict oldest if at capacity + while len(_ocr_cache) >= _cache_max_size: + _ocr_cache.popitem(last=False) + + _ocr_cache[image_hash] = { + "result": result, + "cached_at": datetime.now() + } + + +def get_cache_stats() -> Dict[str, Any]: + """Get cache statistics.""" + return { + "size": len(_ocr_cache), + "max_size": _cache_max_size, + "ttl_seconds": _cache_ttl_seconds, + "hit_rate": 0 # Could track this with additional counters + } + + +def _check_trocr_available() -> bool: + """Check if TrOCR dependencies are available.""" + global _trocr_available + if _trocr_available is not None: + return _trocr_available + + try: + import torch + from transformers import TrOCRProcessor, VisionEncoderDecoderModel + _trocr_available = True + except ImportError as e: + logger.warning(f"TrOCR dependencies not available: {e}") + _trocr_available = False + + return _trocr_available + + +def get_trocr_model(handwritten: bool = False, size: str = "base"): + """ + Lazy load TrOCR model and processor. + + Args: + handwritten: Use handwritten model instead of printed model + size: Model size -- "base" (300 MB) or "large" (340 MB, higher accuracy + for exam HTR). Only applies to handwritten variant. + + Returns tuple of (processor, model) or (None, None) if unavailable. + """ + global _trocr_processor, _trocr_model + + if not _check_trocr_available(): + return None, None + + # Select model name + if size == "large" and handwritten: + model_name = "microsoft/trocr-large-handwritten" + elif handwritten: + model_name = "microsoft/trocr-base-handwritten" + else: + model_name = "microsoft/trocr-base-printed" + + if model_name in _trocr_models: + return _trocr_models[model_name] + + try: + import torch + from transformers import TrOCRProcessor, VisionEncoderDecoderModel + + logger.info(f"Loading TrOCR model: {model_name}") + processor = TrOCRProcessor.from_pretrained(model_name) + model = VisionEncoderDecoderModel.from_pretrained(model_name) + + # Use GPU if available + device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu" + model.to(device) + logger.info(f"TrOCR model loaded on device: {device}") + + _trocr_models[model_name] = (processor, model) + + # Keep backwards-compat globals pointing at base-printed + if model_name == "microsoft/trocr-base-printed": + _trocr_processor = processor + _trocr_model = model + + return processor, model + + except Exception as e: + logger.error(f"Failed to load TrOCR model {model_name}: {e}") + return None, None + + +def preload_trocr_model(handwritten: bool = True) -> bool: + """ + Preload TrOCR model at startup for faster first request. + + Call this from your FastAPI startup event: + @app.on_event("startup") + async def startup(): + preload_trocr_model() + """ + global _model_loaded_at + logger.info("Preloading TrOCR model...") + processor, model = get_trocr_model(handwritten=handwritten) + if processor is not None and model is not None: + _model_loaded_at = datetime.now() + logger.info("TrOCR model preloaded successfully") + return True + else: + logger.warning("TrOCR model preloading failed") + return False + + +def get_model_status() -> Dict[str, Any]: + """Get current model status information.""" + processor, model = get_trocr_model(handwritten=True) + is_loaded = processor is not None and model is not None + + status = { + "status": "available" if is_loaded else "not_installed", + "is_loaded": is_loaded, + "model_name": "trocr-base-handwritten" if is_loaded else None, + "loaded_at": _model_loaded_at.isoformat() if _model_loaded_at else None, + } + + if is_loaded: + import torch + device = next(model.parameters()).device + status["device"] = str(device) + + return status + + +def get_active_backend() -> str: + """ + Return which TrOCR backend is configured. + + Possible values: "auto", "pytorch", "onnx". + """ + return _trocr_backend + + +def _split_into_lines(image) -> list: + """ + Split an image into text lines using simple projection-based segmentation. + + This is a basic implementation - for production use, consider using + a dedicated line detection model. + """ + import numpy as np + from PIL import Image + + try: + # Convert to grayscale + gray = image.convert('L') + img_array = np.array(gray) + + # Binarize (simple threshold) + threshold = 200 + binary = img_array < threshold + + # Horizontal projection (sum of dark pixels per row) + h_proj = np.sum(binary, axis=1) + + # Find line boundaries (where projection drops below threshold) + line_threshold = img_array.shape[1] * 0.02 # 2% of width + in_line = False + line_start = 0 + lines = [] + + for i, val in enumerate(h_proj): + if val > line_threshold and not in_line: + # Start of line + in_line = True + line_start = i + elif val <= line_threshold and in_line: + # End of line + in_line = False + # Add padding + start = max(0, line_start - 5) + end = min(img_array.shape[0], i + 5) + if end - start > 10: # Minimum line height + lines.append(image.crop((0, start, image.width, end))) + + # Handle last line if still in_line + if in_line: + start = max(0, line_start - 5) + lines.append(image.crop((0, start, image.width, image.height))) + + logger.info(f"Split image into {len(lines)} lines") + return lines + + except Exception as e: + logger.warning(f"Line splitting failed: {e}") + return [] diff --git a/klausur-service/backend/services/trocr_ocr.py b/klausur-service/backend/services/trocr_ocr.py new file mode 100644 index 0000000..2b55a81 --- /dev/null +++ b/klausur-service/backend/services/trocr_ocr.py @@ -0,0 +1,309 @@ +""" +TrOCR OCR Execution + +Core OCR inference routines (PyTorch, ONNX routing, enhanced mode). + +DATENSCHUTZ: Alle Verarbeitung erfolgt lokal. +""" + +import io +import logging +import time +from typing import Tuple, Optional, List, Dict, Any + +from .trocr_models import ( + OCRResult, + _trocr_backend, + _compute_image_hash, + _cache_get, + _cache_set, + get_trocr_model, + _split_into_lines, +) + +logger = logging.getLogger(__name__) + + +def _try_onnx_ocr( + image_data: bytes, + handwritten: bool = False, + split_lines: bool = True, +) -> Optional[Tuple[Optional[str], float]]: + """ + Attempt ONNX inference. Returns the (text, confidence) tuple on + success, or None if ONNX is not available / fails to load. + """ + try: + from .trocr_onnx_service import is_onnx_available, run_trocr_onnx + + if not is_onnx_available(handwritten=handwritten): + return None + # run_trocr_onnx is async -- return the coroutine's awaitable result + # The caller (run_trocr_ocr) will await it. + return run_trocr_onnx # sentinel: caller checks callable + except ImportError: + return None + + +async def _run_pytorch_ocr( + image_data: bytes, + handwritten: bool = False, + split_lines: bool = True, + size: str = "base", +) -> Tuple[Optional[str], float]: + """ + Original PyTorch inference path (extracted for routing). + """ + processor, model = get_trocr_model(handwritten=handwritten, size=size) + + if processor is None or model is None: + logger.error("TrOCR PyTorch model not available") + return None, 0.0 + + try: + import torch + from PIL import Image + import numpy as np + + # Load image + image = Image.open(io.BytesIO(image_data)).convert("RGB") + + if split_lines: + lines = _split_into_lines(image) + if not lines: + lines = [image] + else: + lines = [image] + + all_text = [] + confidences = [] + + for line_image in lines: + pixel_values = processor(images=line_image, return_tensors="pt").pixel_values + + device = next(model.parameters()).device + pixel_values = pixel_values.to(device) + + with torch.no_grad(): + generated_ids = model.generate(pixel_values, max_length=128) + + generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0] + + if generated_text.strip(): + all_text.append(generated_text.strip()) + confidences.append(0.85 if len(generated_text) > 3 else 0.5) + + text = "\n".join(all_text) + confidence = sum(confidences) / len(confidences) if confidences else 0.0 + + logger.info(f"TrOCR (PyTorch) extracted {len(text)} characters from {len(lines)} lines") + return text, confidence + + except Exception as e: + logger.error(f"TrOCR PyTorch failed: {e}") + import traceback + logger.error(traceback.format_exc()) + return None, 0.0 + + +async def run_trocr_ocr( + image_data: bytes, + handwritten: bool = False, + split_lines: bool = True, + size: str = "base", +) -> Tuple[Optional[str], float]: + """ + Run TrOCR on an image. + + Routes between ONNX and PyTorch backends based on the TROCR_BACKEND + environment variable (default: "auto"). + + - "onnx" -- always use ONNX (raises RuntimeError if unavailable) + - "pytorch" -- always use PyTorch (original behaviour) + - "auto" -- try ONNX first, fall back to PyTorch + + TrOCR is optimized for single-line text recognition, so for full-page + images we need to either: + 1. Split into lines first (using line detection) + 2. Process the whole image and get partial results + + Args: + image_data: Raw image bytes + handwritten: Use handwritten model (slower but better for handwriting) + split_lines: Whether to split image into lines first + size: "base" or "large" (only for handwritten variant) + + Returns: + Tuple of (extracted_text, confidence) + """ + backend = _trocr_backend + + # --- ONNX-only mode --- + if backend == "onnx": + onnx_fn = _try_onnx_ocr(image_data, handwritten=handwritten, split_lines=split_lines) + if onnx_fn is None or not callable(onnx_fn): + raise RuntimeError( + "ONNX backend requested (TROCR_BACKEND=onnx) but unavailable. " + "Ensure onnxruntime + optimum are installed and ONNX model files exist." + ) + return await onnx_fn(image_data, handwritten=handwritten, split_lines=split_lines) + + # --- PyTorch-only mode --- + if backend == "pytorch": + return await _run_pytorch_ocr( + image_data, handwritten=handwritten, split_lines=split_lines, size=size, + ) + + # --- Auto mode: try ONNX first, then PyTorch --- + onnx_fn = _try_onnx_ocr(image_data, handwritten=handwritten, split_lines=split_lines) + if onnx_fn is not None and callable(onnx_fn): + try: + result = await onnx_fn(image_data, handwritten=handwritten, split_lines=split_lines) + if result[0] is not None: + return result + logger.warning("ONNX returned None text, falling back to PyTorch") + except Exception as e: + logger.warning(f"ONNX inference failed ({e}), falling back to PyTorch") + + return await _run_pytorch_ocr( + image_data, handwritten=handwritten, split_lines=split_lines, size=size, + ) + + +def _try_onnx_enhanced( + handwritten: bool = True, +): + """ + Return the ONNX enhanced coroutine function, or None if unavailable. + """ + try: + from .trocr_onnx_service import is_onnx_available, run_trocr_onnx_enhanced + + if not is_onnx_available(handwritten=handwritten): + return None + return run_trocr_onnx_enhanced + except ImportError: + return None + + +async def run_trocr_ocr_enhanced( + image_data: bytes, + handwritten: bool = True, + split_lines: bool = True, + use_cache: bool = True +) -> OCRResult: + """ + Enhanced TrOCR OCR with caching and detailed results. + + Routes between ONNX and PyTorch backends based on the TROCR_BACKEND + environment variable (default: "auto"). + + Args: + image_data: Raw image bytes + handwritten: Use handwritten model + split_lines: Whether to split image into lines first + use_cache: Whether to use caching + + Returns: + OCRResult with detailed information + """ + backend = _trocr_backend + + # --- ONNX-only mode --- + if backend == "onnx": + onnx_fn = _try_onnx_enhanced(handwritten=handwritten) + if onnx_fn is None: + raise RuntimeError( + "ONNX backend requested (TROCR_BACKEND=onnx) but unavailable. " + "Ensure onnxruntime + optimum are installed and ONNX model files exist." + ) + return await onnx_fn( + image_data, handwritten=handwritten, + split_lines=split_lines, use_cache=use_cache, + ) + + # --- Auto mode: try ONNX first --- + if backend == "auto": + onnx_fn = _try_onnx_enhanced(handwritten=handwritten) + if onnx_fn is not None: + try: + result = await onnx_fn( + image_data, handwritten=handwritten, + split_lines=split_lines, use_cache=use_cache, + ) + if result.text: + return result + logger.warning("ONNX enhanced returned empty text, falling back to PyTorch") + except Exception as e: + logger.warning(f"ONNX enhanced failed ({e}), falling back to PyTorch") + + # --- PyTorch path (backend == "pytorch" or auto fallback) --- + start_time = time.time() + + # Check cache first + image_hash = _compute_image_hash(image_data) + if use_cache: + cached = _cache_get(image_hash) + if cached: + return OCRResult( + text=cached["text"], + confidence=cached["confidence"], + processing_time_ms=0, + model=cached["model"], + has_lora_adapter=cached.get("has_lora_adapter", False), + char_confidences=cached.get("char_confidences", []), + word_boxes=cached.get("word_boxes", []), + from_cache=True, + image_hash=image_hash + ) + + # Run OCR via PyTorch + text, confidence = await _run_pytorch_ocr(image_data, handwritten=handwritten, split_lines=split_lines) + + processing_time_ms = int((time.time() - start_time) * 1000) + + # Generate word boxes with simulated confidences + word_boxes = [] + if text: + words = text.split() + for idx, word in enumerate(words): + # Simulate word confidence (slightly varied around overall confidence) + word_conf = min(1.0, max(0.0, confidence + (hash(word) % 20 - 10) / 100)) + word_boxes.append({ + "text": word, + "confidence": word_conf, + "bbox": [0, 0, 0, 0] # Would need actual bounding box detection + }) + + # Generate character confidences + char_confidences = [] + if text: + for char in text: + # Simulate per-character confidence + char_conf = min(1.0, max(0.0, confidence + (hash(char) % 15 - 7) / 100)) + char_confidences.append(char_conf) + + result = OCRResult( + text=text or "", + confidence=confidence, + processing_time_ms=processing_time_ms, + model="trocr-base-handwritten" if handwritten else "trocr-base-printed", + has_lora_adapter=False, # Would check actual adapter status + char_confidences=char_confidences, + word_boxes=word_boxes, + from_cache=False, + image_hash=image_hash + ) + + # Cache result + if use_cache and text: + _cache_set(image_hash, { + "text": result.text, + "confidence": result.confidence, + "model": result.model, + "has_lora_adapter": result.has_lora_adapter, + "char_confidences": result.char_confidences, + "word_boxes": result.word_boxes + }) + + return result diff --git a/klausur-service/backend/services/trocr_service.py b/klausur-service/backend/services/trocr_service.py index a91dd13..e994f06 100644 --- a/klausur-service/backend/services/trocr_service.py +++ b/klausur-service/backend/services/trocr_service.py @@ -1,720 +1,70 @@ """ -TrOCR Service +TrOCR Service — Barrel Re-export Microsoft's Transformer-based OCR for text recognition. -Besonders geeignet fuer: -- Gedruckten Text -- Saubere Scans -- Schnelle Verarbeitung - -Model: microsoft/trocr-base-printed (oder handwritten Variante) +Split into submodules: +- trocr_models.py — Dataclasses, cache, model loading, line splitting +- trocr_ocr.py — Core OCR inference (PyTorch/ONNX routing, enhanced) +- trocr_batch.py — Batch processing and SSE streaming DATENSCHUTZ: Alle Verarbeitung erfolgt lokal. - -Phase 2 Enhancements: -- Batch processing for multiple images -- SHA256-based caching for repeated requests -- Model preloading for faster first request -- Word-level bounding boxes with confidence """ -import io -import os -import hashlib -import logging -import time -import asyncio -from typing import Tuple, Optional, List, Dict, Any -from dataclasses import dataclass, field -from collections import OrderedDict -from datetime import datetime, timedelta - -logger = logging.getLogger(__name__) - -# --------------------------------------------------------------------------- -# Backend routing: auto | pytorch | onnx -# --------------------------------------------------------------------------- -_trocr_backend = os.environ.get("TROCR_BACKEND", "auto") # auto | pytorch | onnx - -# Lazy loading for heavy dependencies -# Cache keyed by model_name to support base and large variants simultaneously -_trocr_models: dict = {} # {model_name: (processor, model)} -_trocr_processor = None # backwards-compat alias → base-printed -_trocr_model = None # backwards-compat alias → base-printed -_trocr_available = None -_model_loaded_at = None - -# Simple in-memory cache with LRU eviction -_ocr_cache: OrderedDict[str, Dict[str, Any]] = OrderedDict() -_cache_max_size = 100 -_cache_ttl_seconds = 3600 # 1 hour - - -@dataclass -class OCRResult: - """Enhanced OCR result with detailed information.""" - text: str - confidence: float - processing_time_ms: int - model: str - has_lora_adapter: bool = False - char_confidences: List[float] = field(default_factory=list) - word_boxes: List[Dict[str, Any]] = field(default_factory=list) - from_cache: bool = False - image_hash: str = "" - - -@dataclass -class BatchOCRResult: - """Result for batch processing.""" - results: List[OCRResult] - total_time_ms: int - processed_count: int - cached_count: int - error_count: int - - -def _compute_image_hash(image_data: bytes) -> str: - """Compute SHA256 hash of image data for caching.""" - return hashlib.sha256(image_data).hexdigest()[:16] - - -def _cache_get(image_hash: str) -> Optional[Dict[str, Any]]: - """Get cached OCR result if available and not expired.""" - if image_hash in _ocr_cache: - entry = _ocr_cache[image_hash] - if datetime.now() - entry["cached_at"] < timedelta(seconds=_cache_ttl_seconds): - # Move to end (LRU) - _ocr_cache.move_to_end(image_hash) - return entry["result"] - else: - # Expired, remove - del _ocr_cache[image_hash] - return None - - -def _cache_set(image_hash: str, result: Dict[str, Any]) -> None: - """Store OCR result in cache.""" - # Evict oldest if at capacity - while len(_ocr_cache) >= _cache_max_size: - _ocr_cache.popitem(last=False) - - _ocr_cache[image_hash] = { - "result": result, - "cached_at": datetime.now() - } - - -def get_cache_stats() -> Dict[str, Any]: - """Get cache statistics.""" - return { - "size": len(_ocr_cache), - "max_size": _cache_max_size, - "ttl_seconds": _cache_ttl_seconds, - "hit_rate": 0 # Could track this with additional counters - } - - -def _check_trocr_available() -> bool: - """Check if TrOCR dependencies are available.""" - global _trocr_available - if _trocr_available is not None: - return _trocr_available - - try: - import torch - from transformers import TrOCRProcessor, VisionEncoderDecoderModel - _trocr_available = True - except ImportError as e: - logger.warning(f"TrOCR dependencies not available: {e}") - _trocr_available = False - - return _trocr_available - - -def get_trocr_model(handwritten: bool = False, size: str = "base"): - """ - Lazy load TrOCR model and processor. - - Args: - handwritten: Use handwritten model instead of printed model - size: Model size — "base" (300 MB) or "large" (340 MB, higher accuracy - for exam HTR). Only applies to handwritten variant. - - Returns tuple of (processor, model) or (None, None) if unavailable. - """ - global _trocr_processor, _trocr_model - - if not _check_trocr_available(): - return None, None - - # Select model name - if size == "large" and handwritten: - model_name = "microsoft/trocr-large-handwritten" - elif handwritten: - model_name = "microsoft/trocr-base-handwritten" - else: - model_name = "microsoft/trocr-base-printed" - - if model_name in _trocr_models: - return _trocr_models[model_name] - - try: - import torch - from transformers import TrOCRProcessor, VisionEncoderDecoderModel - - logger.info(f"Loading TrOCR model: {model_name}") - processor = TrOCRProcessor.from_pretrained(model_name) - model = VisionEncoderDecoderModel.from_pretrained(model_name) - - # Use GPU if available - device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu" - model.to(device) - logger.info(f"TrOCR model loaded on device: {device}") - - _trocr_models[model_name] = (processor, model) - - # Keep backwards-compat globals pointing at base-printed - if model_name == "microsoft/trocr-base-printed": - _trocr_processor = processor - _trocr_model = model - - return processor, model - - except Exception as e: - logger.error(f"Failed to load TrOCR model {model_name}: {e}") - return None, None - - -def preload_trocr_model(handwritten: bool = True) -> bool: - """ - Preload TrOCR model at startup for faster first request. - - Call this from your FastAPI startup event: - @app.on_event("startup") - async def startup(): - preload_trocr_model() - """ - global _model_loaded_at - logger.info("Preloading TrOCR model...") - processor, model = get_trocr_model(handwritten=handwritten) - if processor is not None and model is not None: - _model_loaded_at = datetime.now() - logger.info("TrOCR model preloaded successfully") - return True - else: - logger.warning("TrOCR model preloading failed") - return False - - -def get_model_status() -> Dict[str, Any]: - """Get current model status information.""" - processor, model = get_trocr_model(handwritten=True) - is_loaded = processor is not None and model is not None - - status = { - "status": "available" if is_loaded else "not_installed", - "is_loaded": is_loaded, - "model_name": "trocr-base-handwritten" if is_loaded else None, - "loaded_at": _model_loaded_at.isoformat() if _model_loaded_at else None, - } - - if is_loaded: - import torch - device = next(model.parameters()).device - status["device"] = str(device) - - return status - - -def get_active_backend() -> str: - """ - Return which TrOCR backend is configured. - - Possible values: "auto", "pytorch", "onnx". - """ - return _trocr_backend - - -def _try_onnx_ocr( - image_data: bytes, - handwritten: bool = False, - split_lines: bool = True, -) -> Optional[Tuple[Optional[str], float]]: - """ - Attempt ONNX inference. Returns the (text, confidence) tuple on - success, or None if ONNX is not available / fails to load. - """ - try: - from .trocr_onnx_service import is_onnx_available, run_trocr_onnx - - if not is_onnx_available(handwritten=handwritten): - return None - # run_trocr_onnx is async — return the coroutine's awaitable result - # The caller (run_trocr_ocr) will await it. - return run_trocr_onnx # sentinel: caller checks callable - except ImportError: - return None - - -async def _run_pytorch_ocr( - image_data: bytes, - handwritten: bool = False, - split_lines: bool = True, - size: str = "base", -) -> Tuple[Optional[str], float]: - """ - Original PyTorch inference path (extracted for routing). - """ - processor, model = get_trocr_model(handwritten=handwritten, size=size) - - if processor is None or model is None: - logger.error("TrOCR PyTorch model not available") - return None, 0.0 - - try: - import torch - from PIL import Image - import numpy as np - - # Load image - image = Image.open(io.BytesIO(image_data)).convert("RGB") - - if split_lines: - lines = _split_into_lines(image) - if not lines: - lines = [image] - else: - lines = [image] - - all_text = [] - confidences = [] - - for line_image in lines: - pixel_values = processor(images=line_image, return_tensors="pt").pixel_values - - device = next(model.parameters()).device - pixel_values = pixel_values.to(device) - - with torch.no_grad(): - generated_ids = model.generate(pixel_values, max_length=128) - - generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0] - - if generated_text.strip(): - all_text.append(generated_text.strip()) - confidences.append(0.85 if len(generated_text) > 3 else 0.5) - - text = "\n".join(all_text) - confidence = sum(confidences) / len(confidences) if confidences else 0.0 - - logger.info(f"TrOCR (PyTorch) extracted {len(text)} characters from {len(lines)} lines") - return text, confidence - - except Exception as e: - logger.error(f"TrOCR PyTorch failed: {e}") - import traceback - logger.error(traceback.format_exc()) - return None, 0.0 - - -async def run_trocr_ocr( - image_data: bytes, - handwritten: bool = False, - split_lines: bool = True, - size: str = "base", -) -> Tuple[Optional[str], float]: - """ - Run TrOCR on an image. - - Routes between ONNX and PyTorch backends based on the TROCR_BACKEND - environment variable (default: "auto"). - - - "onnx" — always use ONNX (raises RuntimeError if unavailable) - - "pytorch" — always use PyTorch (original behaviour) - - "auto" — try ONNX first, fall back to PyTorch - - TrOCR is optimized for single-line text recognition, so for full-page - images we need to either: - 1. Split into lines first (using line detection) - 2. Process the whole image and get partial results - - Args: - image_data: Raw image bytes - handwritten: Use handwritten model (slower but better for handwriting) - split_lines: Whether to split image into lines first - size: "base" or "large" (only for handwritten variant) - - Returns: - Tuple of (extracted_text, confidence) - """ - backend = _trocr_backend - - # --- ONNX-only mode --- - if backend == "onnx": - onnx_fn = _try_onnx_ocr(image_data, handwritten=handwritten, split_lines=split_lines) - if onnx_fn is None or not callable(onnx_fn): - raise RuntimeError( - "ONNX backend requested (TROCR_BACKEND=onnx) but unavailable. " - "Ensure onnxruntime + optimum are installed and ONNX model files exist." - ) - return await onnx_fn(image_data, handwritten=handwritten, split_lines=split_lines) - - # --- PyTorch-only mode --- - if backend == "pytorch": - return await _run_pytorch_ocr( - image_data, handwritten=handwritten, split_lines=split_lines, size=size, - ) - - # --- Auto mode: try ONNX first, then PyTorch --- - onnx_fn = _try_onnx_ocr(image_data, handwritten=handwritten, split_lines=split_lines) - if onnx_fn is not None and callable(onnx_fn): - try: - result = await onnx_fn(image_data, handwritten=handwritten, split_lines=split_lines) - if result[0] is not None: - return result - logger.warning("ONNX returned None text, falling back to PyTorch") - except Exception as e: - logger.warning(f"ONNX inference failed ({e}), falling back to PyTorch") - - return await _run_pytorch_ocr( - image_data, handwritten=handwritten, split_lines=split_lines, size=size, - ) - - -def _split_into_lines(image) -> list: - """ - Split an image into text lines using simple projection-based segmentation. - - This is a basic implementation - for production use, consider using - a dedicated line detection model. - """ - import numpy as np - from PIL import Image - - try: - # Convert to grayscale - gray = image.convert('L') - img_array = np.array(gray) - - # Binarize (simple threshold) - threshold = 200 - binary = img_array < threshold - - # Horizontal projection (sum of dark pixels per row) - h_proj = np.sum(binary, axis=1) - - # Find line boundaries (where projection drops below threshold) - line_threshold = img_array.shape[1] * 0.02 # 2% of width - in_line = False - line_start = 0 - lines = [] - - for i, val in enumerate(h_proj): - if val > line_threshold and not in_line: - # Start of line - in_line = True - line_start = i - elif val <= line_threshold and in_line: - # End of line - in_line = False - # Add padding - start = max(0, line_start - 5) - end = min(img_array.shape[0], i + 5) - if end - start > 10: # Minimum line height - lines.append(image.crop((0, start, image.width, end))) - - # Handle last line if still in_line - if in_line: - start = max(0, line_start - 5) - lines.append(image.crop((0, start, image.width, image.height))) - - logger.info(f"Split image into {len(lines)} lines") - return lines - - except Exception as e: - logger.warning(f"Line splitting failed: {e}") - return [] - - -def _try_onnx_enhanced( - handwritten: bool = True, -): - """ - Return the ONNX enhanced coroutine function, or None if unavailable. - """ - try: - from .trocr_onnx_service import is_onnx_available, run_trocr_onnx_enhanced - - if not is_onnx_available(handwritten=handwritten): - return None - return run_trocr_onnx_enhanced - except ImportError: - return None - - -async def run_trocr_ocr_enhanced( - image_data: bytes, - handwritten: bool = True, - split_lines: bool = True, - use_cache: bool = True -) -> OCRResult: - """ - Enhanced TrOCR OCR with caching and detailed results. - - Routes between ONNX and PyTorch backends based on the TROCR_BACKEND - environment variable (default: "auto"). - - Args: - image_data: Raw image bytes - handwritten: Use handwritten model - split_lines: Whether to split image into lines first - use_cache: Whether to use caching - - Returns: - OCRResult with detailed information - """ - backend = _trocr_backend - - # --- ONNX-only mode --- - if backend == "onnx": - onnx_fn = _try_onnx_enhanced(handwritten=handwritten) - if onnx_fn is None: - raise RuntimeError( - "ONNX backend requested (TROCR_BACKEND=onnx) but unavailable. " - "Ensure onnxruntime + optimum are installed and ONNX model files exist." - ) - return await onnx_fn( - image_data, handwritten=handwritten, - split_lines=split_lines, use_cache=use_cache, - ) - - # --- Auto mode: try ONNX first --- - if backend == "auto": - onnx_fn = _try_onnx_enhanced(handwritten=handwritten) - if onnx_fn is not None: - try: - result = await onnx_fn( - image_data, handwritten=handwritten, - split_lines=split_lines, use_cache=use_cache, - ) - if result.text: - return result - logger.warning("ONNX enhanced returned empty text, falling back to PyTorch") - except Exception as e: - logger.warning(f"ONNX enhanced failed ({e}), falling back to PyTorch") - - # --- PyTorch path (backend == "pytorch" or auto fallback) --- - start_time = time.time() - - # Check cache first - image_hash = _compute_image_hash(image_data) - if use_cache: - cached = _cache_get(image_hash) - if cached: - return OCRResult( - text=cached["text"], - confidence=cached["confidence"], - processing_time_ms=0, - model=cached["model"], - has_lora_adapter=cached.get("has_lora_adapter", False), - char_confidences=cached.get("char_confidences", []), - word_boxes=cached.get("word_boxes", []), - from_cache=True, - image_hash=image_hash - ) - - # Run OCR via PyTorch - text, confidence = await _run_pytorch_ocr(image_data, handwritten=handwritten, split_lines=split_lines) - - processing_time_ms = int((time.time() - start_time) * 1000) - - # Generate word boxes with simulated confidences - word_boxes = [] - if text: - words = text.split() - for idx, word in enumerate(words): - # Simulate word confidence (slightly varied around overall confidence) - word_conf = min(1.0, max(0.0, confidence + (hash(word) % 20 - 10) / 100)) - word_boxes.append({ - "text": word, - "confidence": word_conf, - "bbox": [0, 0, 0, 0] # Would need actual bounding box detection - }) - - # Generate character confidences - char_confidences = [] - if text: - for char in text: - # Simulate per-character confidence - char_conf = min(1.0, max(0.0, confidence + (hash(char) % 15 - 7) / 100)) - char_confidences.append(char_conf) - - result = OCRResult( - text=text or "", - confidence=confidence, - processing_time_ms=processing_time_ms, - model="trocr-base-handwritten" if handwritten else "trocr-base-printed", - has_lora_adapter=False, # Would check actual adapter status - char_confidences=char_confidences, - word_boxes=word_boxes, - from_cache=False, - image_hash=image_hash - ) - - # Cache result - if use_cache and text: - _cache_set(image_hash, { - "text": result.text, - "confidence": result.confidence, - "model": result.model, - "has_lora_adapter": result.has_lora_adapter, - "char_confidences": result.char_confidences, - "word_boxes": result.word_boxes - }) - - return result - - -async def run_trocr_batch( - images: List[bytes], - handwritten: bool = True, - split_lines: bool = True, - use_cache: bool = True, - progress_callback: Optional[callable] = None -) -> BatchOCRResult: - """ - Process multiple images in batch. - - Args: - images: List of image data bytes - handwritten: Use handwritten model - split_lines: Whether to split images into lines - use_cache: Whether to use caching - progress_callback: Optional callback(current, total) for progress updates - - Returns: - BatchOCRResult with all results - """ - start_time = time.time() - results = [] - cached_count = 0 - error_count = 0 - - for idx, image_data in enumerate(images): - try: - result = await run_trocr_ocr_enhanced( - image_data, - handwritten=handwritten, - split_lines=split_lines, - use_cache=use_cache - ) - results.append(result) - - if result.from_cache: - cached_count += 1 - - # Report progress - if progress_callback: - progress_callback(idx + 1, len(images)) - - except Exception as e: - logger.error(f"Batch OCR error for image {idx}: {e}") - error_count += 1 - results.append(OCRResult( - text=f"Error: {str(e)}", - confidence=0.0, - processing_time_ms=0, - model="error", - has_lora_adapter=False - )) - - total_time_ms = int((time.time() - start_time) * 1000) - - return BatchOCRResult( - results=results, - total_time_ms=total_time_ms, - processed_count=len(images), - cached_count=cached_count, - error_count=error_count - ) - - -# Generator for SSE streaming during batch processing -async def run_trocr_batch_stream( - images: List[bytes], - handwritten: bool = True, - split_lines: bool = True, - use_cache: bool = True -): - """ - Process images and yield progress updates for SSE streaming. - - Yields: - dict with current progress and result - """ - start_time = time.time() - total = len(images) - - for idx, image_data in enumerate(images): - try: - result = await run_trocr_ocr_enhanced( - image_data, - handwritten=handwritten, - split_lines=split_lines, - use_cache=use_cache - ) - - elapsed_ms = int((time.time() - start_time) * 1000) - avg_time_per_image = elapsed_ms / (idx + 1) - estimated_remaining = int(avg_time_per_image * (total - idx - 1)) - - yield { - "type": "progress", - "current": idx + 1, - "total": total, - "progress_percent": ((idx + 1) / total) * 100, - "elapsed_ms": elapsed_ms, - "estimated_remaining_ms": estimated_remaining, - "result": { - "text": result.text, - "confidence": result.confidence, - "processing_time_ms": result.processing_time_ms, - "from_cache": result.from_cache - } - } - - except Exception as e: - logger.error(f"Stream OCR error for image {idx}: {e}") - yield { - "type": "error", - "current": idx + 1, - "total": total, - "error": str(e) - } - - total_time_ms = int((time.time() - start_time) * 1000) - yield { - "type": "complete", - "total_time_ms": total_time_ms, - "processed_count": total - } - - -# Test function -async def test_trocr_ocr(image_path: str, handwritten: bool = False): - """Test TrOCR on a local image file.""" - with open(image_path, "rb") as f: - image_data = f.read() - - text, confidence = await run_trocr_ocr(image_data, handwritten=handwritten) - - print(f"\n=== TrOCR Test ===") - print(f"Mode: {'Handwritten' if handwritten else 'Printed'}") - print(f"Confidence: {confidence:.2f}") - print(f"Text:\n{text}") - - return text, confidence +# Models, cache, and model loading +from .trocr_models import ( + OCRResult, + BatchOCRResult, + _compute_image_hash, + _cache_get, + _cache_set, + get_cache_stats, + _check_trocr_available, + get_trocr_model, + preload_trocr_model, + get_model_status, + get_active_backend, + _split_into_lines, +) + +# Core OCR execution +from .trocr_ocr import ( + run_trocr_ocr, + run_trocr_ocr_enhanced, + _run_pytorch_ocr, +) + +# Batch processing & streaming +from .trocr_batch import ( + run_trocr_batch, + run_trocr_batch_stream, + test_trocr_ocr, +) + +__all__ = [ + # Dataclasses + "OCRResult", + "BatchOCRResult", + # Cache + "_compute_image_hash", + "_cache_get", + "_cache_set", + "get_cache_stats", + # Model loading + "_check_trocr_available", + "get_trocr_model", + "preload_trocr_model", + "get_model_status", + "get_active_backend", + "_split_into_lines", + # OCR execution + "run_trocr_ocr", + "run_trocr_ocr_enhanced", + "_run_pytorch_ocr", + # Batch + "run_trocr_batch", + "run_trocr_batch_stream", + "test_trocr_ocr", +] if __name__ == "__main__": diff --git a/website/app/admin/compliance/audit-checklist/_components/AuditProgressBar.tsx b/website/app/admin/compliance/audit-checklist/_components/AuditProgressBar.tsx new file mode 100644 index 0000000..b65265a --- /dev/null +++ b/website/app/admin/compliance/audit-checklist/_components/AuditProgressBar.tsx @@ -0,0 +1,50 @@ +'use client' + +import { AuditStatistics } from './types' +import { Language } from '@/lib/compliance-i18n' + +export default function AuditProgressBar({ statistics, lang }: { statistics: AuditStatistics; lang: Language }) { + const segments = [ + { key: 'compliant', count: statistics.compliant, color: 'bg-green-500', label: 'Konform' }, + { key: 'compliant_with_notes', count: statistics.compliant_with_notes, color: 'bg-yellow-500', label: 'Mit Anm.' }, + { key: 'non_compliant', count: statistics.non_compliant, color: 'bg-red-500', label: 'Nicht konform' }, + { key: 'not_applicable', count: statistics.not_applicable, color: 'bg-slate-300', label: 'N/A' }, + { key: 'pending', count: statistics.pending, color: 'bg-slate-100', label: 'Offen' }, + ] + + return ( +
+
+

Fortschritt

+ + {Math.round(statistics.completion_percentage)}% + +
+ + {/* Stacked Progress Bar */} +
+ {segments.map(seg => ( + seg.count > 0 && ( +
+ ) + ))} +
+ + {/* Legend */} +
+ {segments.map(seg => ( +
+ + {seg.label}: + {seg.count} +
+ ))} +
+
+ ) +} diff --git a/website/app/admin/compliance/audit-checklist/_components/ChecklistItemRow.tsx b/website/app/admin/compliance/audit-checklist/_components/ChecklistItemRow.tsx new file mode 100644 index 0000000..7f0e1c9 --- /dev/null +++ b/website/app/admin/compliance/audit-checklist/_components/ChecklistItemRow.tsx @@ -0,0 +1,73 @@ +'use client' + +import { AuditChecklistItem, RESULT_STATUS } from './types' +import { Language } from '@/lib/compliance-i18n' + +export default function ChecklistItemRow({ + item, + index, + onSignOff, + lang, +}: { + item: AuditChecklistItem + index: number + onSignOff: () => void + lang: Language +}) { + const status = RESULT_STATUS[item.current_result] + + return ( +
+
+ {/* Status Icon */} +
+ {status.icon} +
+ + {/* Content */} +
+
+ {index}. + {item.regulation_code} + {item.article} + {item.is_signed && ( + + + + + Signiert + + )} +
+

{item.title}

+ {item.notes && ( +

"{item.notes}"

+ )} +
+ {item.controls_mapped} Controls + {item.evidence_count} Nachweise + {item.signed_at && ( + Signiert: {new Date(item.signed_at).toLocaleDateString('de-DE')} + )} +
+
+ + {/* Actions */} +
+ + {lang === 'de' ? status.label : status.labelEn} + + +
+
+
+ ) +} diff --git a/website/app/admin/compliance/audit-checklist/_components/CreateSessionModal.tsx b/website/app/admin/compliance/audit-checklist/_components/CreateSessionModal.tsx new file mode 100644 index 0000000..ee2bfe0 --- /dev/null +++ b/website/app/admin/compliance/audit-checklist/_components/CreateSessionModal.tsx @@ -0,0 +1,158 @@ +'use client' + +import { useState } from 'react' +import { Regulation } from './types' + +interface CreateSessionData { + name: string + description?: string + auditor_name: string + auditor_email?: string + regulation_codes?: string[] +} + +export default function CreateSessionModal({ + regulations, + onClose, + onCreate, +}: { + regulations: Regulation[] + onClose: () => void + onCreate: (data: CreateSessionData) => void +}) { + const [name, setName] = useState('') + const [description, setDescription] = useState('') + const [auditorName, setAuditorName] = useState('') + const [auditorEmail, setAuditorEmail] = useState('') + const [selectedRegs, setSelectedRegs] = useState([]) + const [submitting, setSubmitting] = useState(false) + + const handleSubmit = async (e: React.FormEvent) => { + e.preventDefault() + if (!name || !auditorName) return + + setSubmitting(true) + await onCreate({ + name, + description: description || undefined, + auditor_name: auditorName, + auditor_email: auditorEmail || undefined, + regulation_codes: selectedRegs.length > 0 ? selectedRegs : undefined, + }) + setSubmitting(false) + } + + return ( +
+
+
+

Neue Audit-Session

+ +
+ +
+
+ + setName(e.target.value)} + placeholder="z.B. Q1 2026 Compliance Audit" + className="w-full px-3 py-2 border border-slate-300 rounded-lg focus:ring-2 focus:ring-primary-500" + required + /> +
+ +
+ +