[split-required] Split 700-870 LOC files across all services

backend-lehrer (11 files):
- llm_gateway/routes/schools.py (867 → 5), recording_api.py (848 → 6)
- messenger_api.py (840 → 5), print_generator.py (824 → 5)
- unit_analytics_api.py (751 → 5), classroom/routes/context.py (726 → 4)
- llm_gateway/routes/edu_search_seeds.py (710 → 4)

klausur-service (12 files):
- ocr_labeling_api.py (845 → 4), metrics_db.py (833 → 4)
- legal_corpus_api.py (790 → 4), page_crop.py (758 → 3)
- mail/ai_service.py (747 → 4), github_crawler.py (767 → 3)
- trocr_service.py (730 → 4), full_compliance_pipeline.py (723 → 4)
- dsfa_rag_api.py (715 → 4), ocr_pipeline_auto.py (705 → 4)

website (6 pages):
- audit-checklist (867 → 8), content (806 → 6)
- screen-flow (790 → 4), scraper (789 → 5)
- zeugnisse (776 → 5), modules (745 → 4)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
Benjamin Admin
2026-04-25 08:01:18 +02:00
parent b6983ab1dc
commit 34da9f4cda
106 changed files with 16500 additions and 16947 deletions
+3
View File
@@ -39,6 +39,9 @@
**/lib/sdk/vvt-baseline-catalog.ts | owner=admin-lehrer | reason=Pure data catalog (630 LOC, BaselineTemplate[] literals) | review=2027-01-01
**/lib/sdk/loeschfristen-baseline-catalog.ts | owner=admin-lehrer | reason=Pure data catalog (578 LOC, retention period templates) | review=2027-01-01
# Single SSE generator orchestrating 6 pipeline steps — cannot split generator context
**/ocr_pipeline_auto_steps.py | owner=klausur | reason=run_auto is a single async generator yielding SSE events across 6 steps (528 LOC) | review=2026-10-01
# Legacy — TEMPORAER bis Refactoring abgeschlossen
# Dateien hier werden Phase fuer Phase abgearbeitet und entfernt.
# KEINE neuen Ausnahmen ohne [guardrail-change] Commit-Marker!
+193
View File
@@ -0,0 +1,193 @@
"""
AI Processing - Print Version Generator: Cloze (Lueckentext).
Generates printable HTML for cloze/fill-in-the-blank worksheets.
"""
from pathlib import Path
import json
import random
import logging
from .core import BEREINIGT_DIR
logger = logging.getLogger(__name__)
def generate_print_version_cloze(cloze_path: Path, include_answers: bool = False) -> Path:
    """
    Render a cloze (fill-in-the-blank) JSON file as a printable HTML page.

    The page is written into ``BEREINIGT_DIR``. Its file name is the stem of
    ``cloze_path`` with every ``_cloze`` removed, followed by
    ``_cloze_solutions.html`` (answer sheet) or ``_cloze_print.html``
    (question sheet). The question sheet ends with a shuffled word bank.

    All text taken from the JSON file is HTML-escaped before it is placed in
    the markup, and ``null`` values fall back to the same defaults as missing
    keys.

    Args:
        cloze_path: Path to the ``*_cloze.json`` file.
        include_answers: If True, gaps are filled with the expected words and
            no word bank is printed.

    Returns:
        Path of the generated HTML file.

    Raises:
        FileNotFoundError: If ``cloze_path`` does not exist.
    """
    from html import escape  # function-scope import keeps this block self-contained

    def _esc(value) -> str:
        # JSON values may be numbers or null; null renders as empty text.
        return "" if value is None else escape(str(value))

    if not cloze_path.exists():
        raise FileNotFoundError(f"Cloze-Datei nicht gefunden: {cloze_path}")

    cloze_data = json.loads(cloze_path.read_text(encoding="utf-8"))
    items = cloze_data.get("cloze_items") or []
    metadata = cloze_data.get("metadata") or {}
    # Escaped once here because the title is used in <title> and in <h1>.
    title = _esc(metadata.get("source_title") or "Arbeitsblatt")
    subject = metadata.get("subject") or ""
    grade = metadata.get("grade_level") or ""
    total_gaps = metadata.get("total_gaps", 0)

    html_parts = []
    html_parts.append("""<!DOCTYPE html>
<html lang="de">
<head>
<meta charset="UTF-8">
<title>""" + title + """ - Lueckentext</title>
<style>
@media print {
.no-print { display: none; }
.page-break { page-break-before: always; }
}
body {
font-family: Arial, sans-serif;
max-width: 800px;
margin: 40px auto;
padding: 20px;
line-height: 1.8;
}
h1 { font-size: 24px; margin-bottom: 8px; }
.meta { color: #666; margin-bottom: 24px; }
.cloze-item {
margin-bottom: 24px;
padding: 16px;
background: #f9f9f9;
border-radius: 8px;
}
.cloze-number {
font-weight: bold;
color: #333;
margin-bottom: 8px;
}
.cloze-sentence {
font-size: 16px;
line-height: 2;
}
.gap {
display: inline-block;
min-width: 80px;
border-bottom: 2px solid #333;
margin: 0 4px;
text-align: center;
}
.gap-filled {
display: inline-block;
padding: 2px 8px;
background: #e8f5e9;
border: 1px solid #4caf50;
border-radius: 4px;
font-weight: bold;
}
.translation {
margin-top: 12px;
padding: 8px;
background: #e3f2fd;
border-left: 3px solid #2196f3;
font-size: 14px;
color: #555;
}
.translation-label {
font-size: 12px;
color: #777;
margin-bottom: 4px;
}
.word-bank {
margin-top: 32px;
padding: 16px;
background: #fff3e0;
border-radius: 8px;
}
.word-bank-title {
font-weight: bold;
margin-bottom: 12px;
}
.word {
display: inline-block;
padding: 4px 12px;
margin: 4px;
background: white;
border: 1px solid #ddd;
border-radius: 4px;
}
</style>
</head>
<body>
""")

    # Header
    version_text = "Loesungsblatt" if include_answers else "Lueckentext"
    html_parts.append(f"<h1>{title} - {version_text}</h1>")

    meta_parts = []
    if subject:
        meta_parts.append(f"Fach: {_esc(subject)}")
    if grade:
        meta_parts.append(f"Klasse: {_esc(grade)}")
    meta_parts.append(f"Luecken gesamt: {_esc(total_gaps)}")
    html_parts.append(f"<div class='meta'>{' | '.join(meta_parts)}</div>")

    # Raw gap words collected for the word bank (escaped when rendered).
    all_words = []

    # Cloze items
    for idx, item in enumerate(items, 1):
        html_parts.append("<div class='cloze-item'>")
        html_parts.append(f"<div class='cloze-number'>{idx}.</div>")

        gaps = item.get("gaps") or []
        # Escape first; the "___" placeholders survive escaping unchanged.
        sentence = _esc(item.get("sentence_with_gaps") or "")

        if include_answers:
            # Answer sheet: fill the placeholders left to right, one per gap.
            for gap in gaps:
                filled = f"<span class='gap-filled'>{_esc(gap.get('word'))}</span>"
                sentence = sentence.replace("___", filled, 1)
        else:
            # Question sheet: every placeholder becomes a blank line.
            sentence = sentence.replace("___", "<span class='gap'>&nbsp;</span>")
            all_words.extend(gap.get("word", "") for gap in gaps)

        html_parts.append(f"<div class='cloze-sentence'>{sentence}</div>")

        # Optional translation of the full sentence
        translation = item.get("translation") or {}
        if translation:
            lang_name = translation.get("language_name") or "Uebersetzung"
            full_sentence = translation.get("full_sentence") or ""
            if full_sentence:
                html_parts.append("<div class='translation'>")
                html_parts.append(f"<div class='translation-label'>{_esc(lang_name)}:</div>")
                html_parts.append(_esc(full_sentence))
                html_parts.append("</div>")

        html_parts.append("</div>")

    # Word bank (question sheet only)
    if not include_answers and all_words:
        random.shuffle(all_words)  # hide the original gap order
        html_parts.append("<div class='word-bank'>")
        html_parts.append("<div class='word-bank-title'>Wortbank (diese Woerter fehlen):</div>")
        for word in all_words:
            html_parts.append(f"<span class='word'>{_esc(word)}</span>")
        html_parts.append("</div>")

    html_parts.append("</body></html>")

    # Save
    suffix = "_cloze_solutions.html" if include_answers else "_cloze_print.html"
    out_name = cloze_path.stem.replace("_cloze", "") + suffix
    out_path = BEREINIGT_DIR / out_name
    out_path.write_text("\n".join(html_parts), encoding="utf-8")

    logger.info("Cloze Print-Version gespeichert: %s", out_path.name)
    return out_path
+18 -820
View File
@@ -1,824 +1,22 @@
"""
AI Processing - Print Version Generator.
AI Processing - Print Version Generator — Barrel Re-export.
Generiert druckbare HTML-Versionen für verschiedene Arbeitsblatt-Typen.
Generiert druckbare HTML-Versionen fuer verschiedene Arbeitsblatt-Typen.
Split into:
- print_qa.py: Q&A print generation
- print_cloze.py: Cloze/Lueckentext print generation
- print_mc.py: Multiple Choice print generation
- print_worksheet.py: General worksheet print generation
"""
from pathlib import Path
import json
import random
import logging
from .core import BEREINIGT_DIR
logger = logging.getLogger(__name__)
def generate_print_version_qa(qa_path: Path, include_answers: bool = False) -> Path:
    """
    Render a question/answer JSON file as a printable HTML page.

    The page is written into ``BEREINIGT_DIR``. Its file name is the stem of
    ``qa_path`` with every ``_qa`` removed, followed by ``_qa_solutions.html``
    (answer sheet) or ``_qa_print.html`` (question sheet). The question sheet
    shows three empty answer lines per question; the answer sheet shows the
    answer and, if present, the key terms.

    All text taken from the JSON file is HTML-escaped before it is placed in
    the markup, and ``null`` values fall back to the same defaults as missing
    keys.

    Args:
        qa_path: Path to the ``*_qa.json`` file.
        include_answers: If True, generate the answer sheet.

    Returns:
        Path of the generated HTML file.

    Raises:
        FileNotFoundError: If ``qa_path`` does not exist.
    """
    from html import escape  # function-scope import keeps this block self-contained

    def _esc(value) -> str:
        # JSON values may be numbers or null; null renders as empty text.
        return "" if value is None else escape(str(value))

    if not qa_path.exists():
        raise FileNotFoundError(f"Q&A-Datei nicht gefunden: {qa_path}")

    qa_data = json.loads(qa_path.read_text(encoding="utf-8"))
    items = qa_data.get("qa_items") or []
    metadata = qa_data.get("metadata") or {}
    # Escaped once here because the title is used in <title> and in <h1>.
    title = _esc(metadata.get("source_title") or "Arbeitsblatt")
    subject = metadata.get("subject") or ""
    grade = metadata.get("grade_level") or ""

    html_parts = []
    html_parts.append("""<!DOCTYPE html>
<html lang="de">
<head>
<meta charset="UTF-8">
<title>""" + title + """ - Fragen</title>
<style>
@media print {
.no-print { display: none; }
.page-break { page-break-before: always; }
}
body {
font-family: Arial, sans-serif;
max-width: 800px;
margin: 40px auto;
padding: 20px;
line-height: 1.6;
}
h1 { font-size: 24px; margin-bottom: 8px; }
.meta { color: #666; margin-bottom: 24px; }
.question-block {
margin-bottom: 32px;
padding-bottom: 16px;
border-bottom: 1px dashed #ccc;
}
.question-number {
font-weight: bold;
color: #333;
}
.question-text {
font-size: 16px;
margin: 8px 0;
}
.answer-space {
border: 1px solid #ddd;
min-height: 60px;
margin-top: 12px;
background: #fafafa;
}
.answer-lines {
margin-top: 12px;
}
.answer-line {
border-bottom: 1px solid #999;
height: 28px;
}
.answer {
margin-top: 8px;
padding: 8px;
background: #e8f5e9;
border-left: 3px solid #4caf50;
}
.key-terms {
font-size: 12px;
color: #666;
margin-top: 8px;
}
.key-terms span {
background: #fff3e0;
padding: 2px 6px;
border-radius: 3px;
margin-right: 4px;
}
</style>
</head>
<body>
""")

    # Header
    version_text = "Lösungsblatt" if include_answers else "Fragenblatt"
    html_parts.append(f"<h1>{title} - {version_text}</h1>")

    meta_parts = []
    if subject:
        meta_parts.append(f"Fach: {_esc(subject)}")
    if grade:
        meta_parts.append(f"Klasse: {_esc(grade)}")
    meta_parts.append(f"Anzahl Fragen: {len(items)}")
    html_parts.append(f"<div class='meta'>{' | '.join(meta_parts)}</div>")

    # Questions
    for idx, item in enumerate(items, 1):
        html_parts.append("<div class='question-block'>")
        html_parts.append(f"<div class='question-number'>Frage {idx}</div>")
        html_parts.append(f"<div class='question-text'>{_esc(item.get('question'))}</div>")

        if include_answers:
            # Answer sheet: show the answer and the key terms.
            html_parts.append(
                f"<div class='answer'><strong>Antwort:</strong> {_esc(item.get('answer'))}</div>"
            )
            key_terms = item.get("key_terms") or []
            if key_terms:
                terms_html = " ".join(f"<span>{_esc(term)}</span>" for term in key_terms)
                html_parts.append(f"<div class='key-terms'>Wichtige Begriffe: {terms_html}</div>")
        else:
            # Question sheet: three empty lines to write on.
            html_parts.append("<div class='answer-lines'>")
            for _ in range(3):
                html_parts.append("<div class='answer-line'></div>")
            html_parts.append("</div>")

        html_parts.append("</div>")

    html_parts.append("</body></html>")

    # Save
    suffix = "_qa_solutions.html" if include_answers else "_qa_print.html"
    out_name = qa_path.stem.replace("_qa", "") + suffix
    out_path = BEREINIGT_DIR / out_name
    out_path.write_text("\n".join(html_parts), encoding="utf-8")

    logger.info("Print-Version gespeichert: %s", out_path.name)
    return out_path
def generate_print_version_cloze(cloze_path: Path, include_answers: bool = False) -> Path:
    """
    Render a cloze (fill-in-the-blank) JSON file as a printable HTML page.

    The page is written into ``BEREINIGT_DIR``. Its file name is the stem of
    ``cloze_path`` with every ``_cloze`` removed, followed by
    ``_cloze_solutions.html`` (answer sheet) or ``_cloze_print.html``
    (question sheet). The question sheet ends with a shuffled word bank.

    All text taken from the JSON file is HTML-escaped before it is placed in
    the markup, and ``null`` values fall back to the same defaults as missing
    keys.

    Args:
        cloze_path: Path to the ``*_cloze.json`` file.
        include_answers: If True, gaps are filled with the expected words and
            no word bank is printed.

    Returns:
        Path of the generated HTML file.

    Raises:
        FileNotFoundError: If ``cloze_path`` does not exist.
    """
    from html import escape  # function-scope import keeps this block self-contained

    def _esc(value) -> str:
        # JSON values may be numbers or null; null renders as empty text.
        return "" if value is None else escape(str(value))

    if not cloze_path.exists():
        raise FileNotFoundError(f"Cloze-Datei nicht gefunden: {cloze_path}")

    cloze_data = json.loads(cloze_path.read_text(encoding="utf-8"))
    items = cloze_data.get("cloze_items") or []
    metadata = cloze_data.get("metadata") or {}
    # Escaped once here because the title is used in <title> and in <h1>.
    title = _esc(metadata.get("source_title") or "Arbeitsblatt")
    subject = metadata.get("subject") or ""
    grade = metadata.get("grade_level") or ""
    total_gaps = metadata.get("total_gaps", 0)

    html_parts = []
    html_parts.append("""<!DOCTYPE html>
<html lang="de">
<head>
<meta charset="UTF-8">
<title>""" + title + """ - Lückentext</title>
<style>
@media print {
.no-print { display: none; }
.page-break { page-break-before: always; }
}
body {
font-family: Arial, sans-serif;
max-width: 800px;
margin: 40px auto;
padding: 20px;
line-height: 1.8;
}
h1 { font-size: 24px; margin-bottom: 8px; }
.meta { color: #666; margin-bottom: 24px; }
.cloze-item {
margin-bottom: 24px;
padding: 16px;
background: #f9f9f9;
border-radius: 8px;
}
.cloze-number {
font-weight: bold;
color: #333;
margin-bottom: 8px;
}
.cloze-sentence {
font-size: 16px;
line-height: 2;
}
.gap {
display: inline-block;
min-width: 80px;
border-bottom: 2px solid #333;
margin: 0 4px;
text-align: center;
}
.gap-filled {
display: inline-block;
padding: 2px 8px;
background: #e8f5e9;
border: 1px solid #4caf50;
border-radius: 4px;
font-weight: bold;
}
.translation {
margin-top: 12px;
padding: 8px;
background: #e3f2fd;
border-left: 3px solid #2196f3;
font-size: 14px;
color: #555;
}
.translation-label {
font-size: 12px;
color: #777;
margin-bottom: 4px;
}
.word-bank {
margin-top: 32px;
padding: 16px;
background: #fff3e0;
border-radius: 8px;
}
.word-bank-title {
font-weight: bold;
margin-bottom: 12px;
}
.word {
display: inline-block;
padding: 4px 12px;
margin: 4px;
background: white;
border: 1px solid #ddd;
border-radius: 4px;
}
</style>
</head>
<body>
""")

    # Header
    version_text = "Lösungsblatt" if include_answers else "Lückentext"
    html_parts.append(f"<h1>{title} - {version_text}</h1>")

    meta_parts = []
    if subject:
        meta_parts.append(f"Fach: {_esc(subject)}")
    if grade:
        meta_parts.append(f"Klasse: {_esc(grade)}")
    meta_parts.append(f"Lücken gesamt: {_esc(total_gaps)}")
    html_parts.append(f"<div class='meta'>{' | '.join(meta_parts)}</div>")

    # Raw gap words collected for the word bank (escaped when rendered).
    all_words = []

    # Cloze items
    for idx, item in enumerate(items, 1):
        html_parts.append("<div class='cloze-item'>")
        html_parts.append(f"<div class='cloze-number'>{idx}.</div>")

        gaps = item.get("gaps") or []
        # Escape first; the "___" placeholders survive escaping unchanged.
        sentence = _esc(item.get("sentence_with_gaps") or "")

        if include_answers:
            # Answer sheet: fill the placeholders left to right, one per gap.
            for gap in gaps:
                filled = f"<span class='gap-filled'>{_esc(gap.get('word'))}</span>"
                sentence = sentence.replace("___", filled, 1)
        else:
            # Question sheet: every placeholder becomes a blank line.
            sentence = sentence.replace("___", "<span class='gap'>&nbsp;</span>")
            all_words.extend(gap.get("word", "") for gap in gaps)

        html_parts.append(f"<div class='cloze-sentence'>{sentence}</div>")

        # Optional translation of the full sentence
        translation = item.get("translation") or {}
        if translation:
            lang_name = translation.get("language_name") or "Übersetzung"
            full_sentence = translation.get("full_sentence") or ""
            if full_sentence:
                html_parts.append("<div class='translation'>")
                html_parts.append(f"<div class='translation-label'>{_esc(lang_name)}:</div>")
                html_parts.append(_esc(full_sentence))
                html_parts.append("</div>")

        html_parts.append("</div>")

    # Word bank (question sheet only)
    if not include_answers and all_words:
        random.shuffle(all_words)  # hide the original gap order
        html_parts.append("<div class='word-bank'>")
        html_parts.append("<div class='word-bank-title'>Wortbank (diese Wörter fehlen):</div>")
        for word in all_words:
            html_parts.append(f"<span class='word'>{_esc(word)}</span>")
        html_parts.append("</div>")

    html_parts.append("</body></html>")

    # Save
    suffix = "_cloze_solutions.html" if include_answers else "_cloze_print.html"
    out_name = cloze_path.stem.replace("_cloze", "") + suffix
    out_path = BEREINIGT_DIR / out_name
    out_path.write_text("\n".join(html_parts), encoding="utf-8")

    logger.info("Cloze Print-Version gespeichert: %s", out_path.name)
    return out_path
def generate_print_version_mc(mc_path: Path, include_answers: bool = False) -> str:
    """
    Build a printable HTML page for a multiple-choice worksheet.

    The question sheet contains a short instruction box and the options of
    every question. The answer sheet additionally highlights the option whose
    ``id`` equals the question's ``correct_answer``, shows the explanation if
    one is present, and ends with a compact answer key.

    All text taken from the JSON file is HTML-escaped before it is placed in
    the markup, and ``null`` values fall back to the same defaults as missing
    keys.

    Args:
        mc_path: Path to the ``*_mc.json`` file.
        include_answers: If True, generate the answer sheet.

    Returns:
        The complete HTML document as a string.

    Raises:
        FileNotFoundError: If ``mc_path`` does not exist.
    """
    from html import escape  # function-scope import keeps this block self-contained

    def _esc(value) -> str:
        # JSON values may be numbers or null; null renders as empty text.
        return "" if value is None else escape(str(value))

    if not mc_path.exists():
        raise FileNotFoundError(f"MC-Datei nicht gefunden: {mc_path}")

    mc_data = json.loads(mc_path.read_text(encoding="utf-8"))
    questions = mc_data.get("questions") or []
    metadata = mc_data.get("metadata") or {}
    # Escaped once here because the title is used in <title> and in <h1>.
    title = _esc(metadata.get("source_title") or "Arbeitsblatt")
    subject = metadata.get("subject") or ""
    grade = metadata.get("grade_level") or ""

    html_parts = []
    # NOTE: "\u2713" below is resolved by Python to the check mark character,
    # which the CSS then prints inside the checked option circle.
    html_parts.append("""<!DOCTYPE html>
<html lang="de">
<head>
<meta charset="UTF-8">
<title>""" + title + """ - Multiple Choice</title>
<style>
@media print {
.no-print { display: none; }
.page-break { page-break-before: always; }
body { font-size: 14pt; }
}
body {
font-family: Arial, Helvetica, sans-serif;
max-width: 800px;
margin: 40px auto;
padding: 20px;
line-height: 1.6;
color: #000;
}
h1 {
font-size: 28px;
margin-bottom: 8px;
border-bottom: 2px solid #000;
padding-bottom: 8px;
}
.meta {
color: #333;
margin-bottom: 32px;
font-size: 14px;
}
.instructions {
background: #f5f5f5;
padding: 12px 16px;
border-radius: 4px;
margin-bottom: 24px;
font-size: 14px;
}
.question-block {
margin-bottom: 28px;
padding-bottom: 16px;
border-bottom: 1px solid #ddd;
}
.question-number {
font-weight: bold;
font-size: 18px;
color: #000;
margin-bottom: 8px;
}
.question-text {
font-size: 16px;
margin: 8px 0 16px 0;
line-height: 1.5;
}
.options {
margin-left: 20px;
}
.option {
display: flex;
align-items: flex-start;
margin-bottom: 12px;
padding: 8px 12px;
border: 1px solid #ccc;
border-radius: 4px;
background: #fff;
}
.option-correct {
background: #e8f5e9;
border-color: #4caf50;
border-width: 2px;
}
.option-checkbox {
width: 20px;
height: 20px;
border: 2px solid #333;
border-radius: 50%;
margin-right: 12px;
flex-shrink: 0;
display: flex;
align-items: center;
justify-content: center;
}
.option-checkbox.checked::after {
content: "\u2713";
font-weight: bold;
color: #4caf50;
}
.option-label {
font-weight: bold;
margin-right: 8px;
min-width: 24px;
}
.option-text {
flex: 1;
}
.explanation {
margin-top: 8px;
padding: 8px 12px;
background: #e3f2fd;
border-left: 3px solid #2196f3;
font-size: 13px;
color: #333;
}
.answer-key {
margin-top: 40px;
padding: 16px;
background: #f5f5f5;
border-radius: 8px;
}
.answer-key-title {
font-weight: bold;
font-size: 18px;
margin-bottom: 12px;
border-bottom: 1px solid #999;
padding-bottom: 8px;
}
.answer-key-grid {
display: grid;
grid-template-columns: repeat(5, 1fr);
gap: 8px;
}
.answer-key-item {
padding: 8px;
text-align: center;
background: white;
border: 1px solid #ddd;
border-radius: 4px;
}
.answer-key-q {
font-weight: bold;
}
.answer-key-a {
color: #4caf50;
font-weight: bold;
}
</style>
</head>
<body>
""")

    # Header
    version_text = "Lösungsblatt" if include_answers else "Multiple Choice Test"
    html_parts.append(f"<h1>{title}</h1>")
    html_parts.append(f"<div class='meta'><strong>{version_text}</strong>")
    if subject:
        html_parts.append(f" | Fach: {_esc(subject)}")
    if grade:
        html_parts.append(f" | Klasse: {_esc(grade)}")
    html_parts.append(f" | Anzahl Fragen: {len(questions)}</div>")

    if not include_answers:
        html_parts.append("<div class='instructions'>")
        html_parts.append("<strong>Anleitung:</strong> Kreuze bei jeder Frage die richtige Antwort an. ")
        html_parts.append("Es ist immer nur eine Antwort richtig.")
        html_parts.append("</div>")

    # Questions
    for idx, q in enumerate(questions, 1):
        html_parts.append("<div class='question-block'>")
        html_parts.append(f"<div class='question-number'>Frage {idx}</div>")
        html_parts.append(f"<div class='question-text'>{_esc(q.get('question'))}</div>")

        html_parts.append("<div class='options'>")
        correct_answer = q.get("correct_answer", "")

        for opt in q.get("options") or []:
            opt_id = opt.get("id", "")
            # Compare the raw values; escaping happens only when rendering.
            is_correct = opt_id == correct_answer

            opt_class = "option"
            checkbox_class = "option-checkbox"
            if include_answers and is_correct:
                opt_class += " option-correct"
                checkbox_class += " checked"

            html_parts.append(f"<div class='{opt_class}'>")
            html_parts.append(f"<div class='{checkbox_class}'></div>")
            html_parts.append(f"<span class='option-label'>{_esc(opt_id)})</span>")
            html_parts.append(f"<span class='option-text'>{_esc(opt.get('text'))}</span>")
            html_parts.append("</div>")

        html_parts.append("</div>")

        # Explanation only on the answer sheet
        if include_answers and q.get("explanation"):
            html_parts.append(
                f"<div class='explanation'><strong>Erklärung:</strong> {_esc(q.get('explanation'))}</div>"
            )

        html_parts.append("</div>")

    # Compact answer key, only on the answer sheet
    if include_answers:
        html_parts.append("<div class='answer-key'>")
        html_parts.append("<div class='answer-key-title'>Lösungsschlüssel</div>")
        html_parts.append("<div class='answer-key-grid'>")
        for idx, q in enumerate(questions, 1):
            html_parts.append("<div class='answer-key-item'>")
            html_parts.append(f"<span class='answer-key-q'>{idx}.</span> ")
            html_parts.append(f"<span class='answer-key-a'>{_esc(q.get('correct_answer', ''))}</span>")
            html_parts.append("</div>")
        html_parts.append("</div>")
        html_parts.append("</div>")

    html_parts.append("</body></html>")

    return "\n".join(html_parts)
def generate_print_version_worksheet(analysis_path: Path) -> str:
    """
    Build a print-optimised HTML page from a worksheet analysis JSON file.

    The page consists of the title, optional subject/grade line, optional
    instructions box, the main text (``printed_blocks`` if present, otherwise
    ``canonical_text`` split into paragraphs at blank lines), the task list
    and a fixed footer. A "print" button is included and hidden when printing
    via the ``no-print`` class.

    All text taken from the JSON file is HTML-escaped before it is placed in
    the markup.

    Args:
        analysis_path: Path to the ``*_analyse.json`` file.

    Returns:
        The complete HTML document as a string.

    Raises:
        FileNotFoundError: If ``analysis_path`` does not exist.
        RuntimeError: If the file does not contain valid JSON.
    """
    from html import escape  # function-scope import keeps this block self-contained

    def _esc(value) -> str:
        # JSON values may be numbers or null; null renders as empty text.
        return "" if value is None else escape(str(value))

    if not analysis_path.exists():
        raise FileNotFoundError(f"Analysedatei nicht gefunden: {analysis_path}")

    try:
        data = json.loads(analysis_path.read_text(encoding="utf-8"))
    except json.JSONDecodeError as e:
        raise RuntimeError(f"Analyse-Datei enthält kein gültiges JSON: {analysis_path}\n{e}") from e

    # Escaped once here because the title is used in <title> and in <h1>.
    title = _esc(data.get("title") or "Arbeitsblatt")
    subject = data.get("subject") or ""
    grade_level = data.get("grade_level") or ""
    instructions = data.get("instructions") or ""
    tasks = data.get("tasks", []) or []
    canonical_text = data.get("canonical_text") or ""
    printed_blocks = data.get("printed_blocks") or []

    # Display names for the known task types; unknown types are shown as-is.
    type_labels = {
        "fill_in_blank": "Lückentext",
        "multiple_choice": "Multiple Choice",
        "free_text": "Freitext",
        "matching": "Zuordnung",
        "labeling": "Beschriftung",
        "calculation": "Rechnung",
        "other": "Aufgabe",
    }

    html_parts = []
    html_parts.append("""<!DOCTYPE html>
<html lang="de">
<head>
<meta charset="UTF-8">
<title>""" + title + """</title>
<style>
@page {
size: A4;
margin: 20mm;
}
@media print {
body {
font-size: 14pt !important;
-webkit-print-color-adjust: exact;
print-color-adjust: exact;
}
.no-print { display: none !important; }
.page-break { page-break-before: always; }
}
* { box-sizing: border-box; }
body {
font-family: Arial, "Helvetica Neue", sans-serif;
max-width: 800px;
margin: 0 auto;
padding: 30px;
line-height: 1.7;
font-size: 16px;
color: #000;
background: #fff;
}
h1 {
font-size: 28px;
margin: 0 0 8px 0;
padding-bottom: 8px;
border-bottom: 3px solid #000;
}
h2 {
font-size: 20px;
margin: 28px 0 12px 0;
padding-bottom: 4px;
border-bottom: 1px solid #666;
}
.meta {
font-size: 14px;
color: #333;
margin-bottom: 20px;
padding: 8px 0;
}
.meta span {
margin-right: 20px;
}
.instructions {
margin: 20px 0;
padding: 16px;
border: 2px solid #333;
background: #f5f5f5;
font-size: 15px;
}
.instructions-label {
font-weight: bold;
margin-bottom: 8px;
}
.text-section {
margin: 24px 0;
}
.text-block {
margin-bottom: 16px;
text-align: justify;
}
.text-block-title {
font-weight: bold;
font-size: 17px;
margin-bottom: 8px;
}
.task-section {
margin-top: 32px;
}
.task {
margin-bottom: 24px;
padding: 16px;
border: 1px solid #999;
background: #fafafa;
}
.task-header {
font-weight: bold;
font-size: 16px;
margin-bottom: 12px;
padding-bottom: 8px;
border-bottom: 1px dashed #666;
}
.task-content {
font-size: 15px;
}
.gap-line {
display: inline-block;
border-bottom: 2px solid #000;
min-width: 100px;
margin: 0 6px;
}
.answer-lines {
margin-top: 16px;
}
.answer-line {
border-bottom: 1px solid #333;
height: 36px;
margin-bottom: 4px;
}
.footer {
margin-top: 40px;
padding-top: 16px;
border-top: 1px solid #ccc;
font-size: 11px;
color: #666;
text-align: center;
}
/* Print Button - versteckt beim Drucken */
.print-button {
position: fixed;
top: 20px;
right: 20px;
padding: 12px 24px;
background: #333;
color: #fff;
border: none;
border-radius: 6px;
cursor: pointer;
font-size: 14px;
}
.print-button:hover {
background: #555;
}
</style>
</head>
<body>
<button class="print-button no-print" onclick="window.print()">🖨️ Drucken</button>
""")

    # Title
    html_parts.append(f"<h1>{title}</h1>")

    # Meta information
    meta_parts = []
    if subject:
        meta_parts.append(f"<span><strong>Fach:</strong> {_esc(subject)}</span>")
    if grade_level:
        meta_parts.append(f"<span><strong>Klasse:</strong> {_esc(grade_level)}</span>")
    if meta_parts:
        html_parts.append(f"<div class='meta'>{''.join(meta_parts)}</div>")

    # Instructions
    if instructions:
        html_parts.append("<div class='instructions'>")
        html_parts.append("<div class='instructions-label'>Arbeitsanweisung:</div>")
        html_parts.append(f"<div>{_esc(instructions)}</div>")
        html_parts.append("</div>")

    # Main text: structured blocks take precedence over the plain text.
    if printed_blocks:
        html_parts.append("<section class='text-section'>")
        for block in printed_blocks:
            role = (block.get("role") or "body").lower()
            text = (block.get("text") or "").strip()
            if not text:
                continue
            if role == "title":
                html_parts.append(
                    f"<div class='text-block'><div class='text-block-title'>{_esc(text)}</div></div>"
                )
            else:
                html_parts.append(f"<div class='text-block'>{_esc(text)}</div>")
        html_parts.append("</section>")
    elif canonical_text:
        html_parts.append("<section class='text-section'>")
        # Paragraphs are separated by blank lines; normalise Windows newlines first.
        paragraphs = [
            p.strip()
            for p in canonical_text.replace("\r\n", "\n").split("\n\n")
            if p.strip()
        ]
        for p in paragraphs:
            html_parts.append(f"<div class='text-block'>{_esc(p)}</div>")
        html_parts.append("</section>")

    # Tasks
    if tasks:
        html_parts.append("<section class='task-section'>")
        html_parts.append("<h2>Aufgaben</h2>")

        for idx, task in enumerate(tasks, start=1):
            t_type = task.get("type") or "Aufgabe"
            desc = task.get("description") or ""
            text_with_gaps = task.get("text_with_gaps")

            html_parts.append("<div class='task'>")

            # Task header
            type_label = type_labels.get(t_type, t_type)
            html_parts.append(f"<div class='task-header'>Aufgabe {idx}: {_esc(type_label)}</div>")

            if desc:
                html_parts.append(f"<div class='task-content'>{_esc(desc)}</div>")

            if text_with_gaps:
                # Escape first; the "___" placeholders survive escaping unchanged.
                rendered = _esc(text_with_gaps).replace("___", "<span class='gap-line'>&nbsp;</span>")
                html_parts.append(f"<div class='task-content' style='margin-top:12px;'>{rendered}</div>")

            # Answer lines for free-text tasks and for tasks without any content
            if t_type in ["free_text", "other"] or (not text_with_gaps and not desc):
                html_parts.append("<div class='answer-lines'>")
                for _ in range(3):
                    html_parts.append("<div class='answer-line'></div>")
                html_parts.append("</div>")

            html_parts.append("</div>")

        html_parts.append("</section>")

    # Footer
    html_parts.append("<div class='footer'>")
    html_parts.append("Dieses Arbeitsblatt wurde automatisch aus einem Scan rekonstruiert.")
    html_parts.append("</div>")

    html_parts.append("</body></html>")

    return "\n".join(html_parts)
from .print_qa import generate_print_version_qa
from .print_cloze import generate_print_version_cloze
from .print_mc import generate_print_version_mc
from .print_worksheet import generate_print_version_worksheet
# Public names of this barrel module: the four print generators imported
# above from the print_qa, print_cloze, print_mc and print_worksheet modules.
__all__ = [
"generate_print_version_qa",
"generate_print_version_cloze",
"generate_print_version_mc",
"generate_print_version_worksheet",
]
+240
View File
@@ -0,0 +1,240 @@
"""
AI Processing - Print Version Generator: Multiple Choice.
Generates printable HTML for multiple-choice worksheets.
"""
from pathlib import Path
import json
import logging
logger = logging.getLogger(__name__)
def generate_print_version_mc(mc_path: Path, include_answers: bool = False) -> str:
    """
    Build a printable HTML page for a multiple-choice worksheet.

    The question sheet contains a short instruction box and the options of
    every question. The answer sheet additionally highlights the option whose
    ``id`` equals the question's ``correct_answer``, shows the explanation if
    one is present, and ends with a compact answer key.

    All text taken from the JSON file is HTML-escaped before it is placed in
    the markup, and ``null`` values fall back to the same defaults as missing
    keys.

    Args:
        mc_path: Path to the ``*_mc.json`` file.
        include_answers: If True, generate the answer sheet.

    Returns:
        The complete HTML document as a string.

    Raises:
        FileNotFoundError: If ``mc_path`` does not exist.
    """
    from html import escape  # function-scope import keeps this block self-contained

    def _esc(value) -> str:
        # JSON values may be numbers or null; null renders as empty text.
        return "" if value is None else escape(str(value))

    if not mc_path.exists():
        raise FileNotFoundError(f"MC-Datei nicht gefunden: {mc_path}")

    mc_data = json.loads(mc_path.read_text(encoding="utf-8"))
    questions = mc_data.get("questions") or []
    metadata = mc_data.get("metadata") or {}
    # Escaped once here because the title is used in <title> and in <h1>.
    title = _esc(metadata.get("source_title") or "Arbeitsblatt")
    subject = metadata.get("subject") or ""
    grade = metadata.get("grade_level") or ""

    html_parts = []
    # NOTE: "\u2713" below is resolved by Python to the check mark character,
    # which the CSS then prints inside the checked option circle.
    html_parts.append("""<!DOCTYPE html>
<html lang="de">
<head>
<meta charset="UTF-8">
<title>""" + title + """ - Multiple Choice</title>
<style>
@media print {
.no-print { display: none; }
.page-break { page-break-before: always; }
body { font-size: 14pt; }
}
body {
font-family: Arial, Helvetica, sans-serif;
max-width: 800px;
margin: 40px auto;
padding: 20px;
line-height: 1.6;
color: #000;
}
h1 {
font-size: 28px;
margin-bottom: 8px;
border-bottom: 2px solid #000;
padding-bottom: 8px;
}
.meta {
color: #333;
margin-bottom: 32px;
font-size: 14px;
}
.instructions {
background: #f5f5f5;
padding: 12px 16px;
border-radius: 4px;
margin-bottom: 24px;
font-size: 14px;
}
.question-block {
margin-bottom: 28px;
padding-bottom: 16px;
border-bottom: 1px solid #ddd;
}
.question-number {
font-weight: bold;
font-size: 18px;
color: #000;
margin-bottom: 8px;
}
.question-text {
font-size: 16px;
margin: 8px 0 16px 0;
line-height: 1.5;
}
.options {
margin-left: 20px;
}
.option {
display: flex;
align-items: flex-start;
margin-bottom: 12px;
padding: 8px 12px;
border: 1px solid #ccc;
border-radius: 4px;
background: #fff;
}
.option-correct {
background: #e8f5e9;
border-color: #4caf50;
border-width: 2px;
}
.option-checkbox {
width: 20px;
height: 20px;
border: 2px solid #333;
border-radius: 50%;
margin-right: 12px;
flex-shrink: 0;
display: flex;
align-items: center;
justify-content: center;
}
.option-checkbox.checked::after {
content: "\u2713";
font-weight: bold;
color: #4caf50;
}
.option-label {
font-weight: bold;
margin-right: 8px;
min-width: 24px;
}
.option-text {
flex: 1;
}
.explanation {
margin-top: 8px;
padding: 8px 12px;
background: #e3f2fd;
border-left: 3px solid #2196f3;
font-size: 13px;
color: #333;
}
.answer-key {
margin-top: 40px;
padding: 16px;
background: #f5f5f5;
border-radius: 8px;
}
.answer-key-title {
font-weight: bold;
font-size: 18px;
margin-bottom: 12px;
border-bottom: 1px solid #999;
padding-bottom: 8px;
}
.answer-key-grid {
display: grid;
grid-template-columns: repeat(5, 1fr);
gap: 8px;
}
.answer-key-item {
padding: 8px;
text-align: center;
background: white;
border: 1px solid #ddd;
border-radius: 4px;
}
.answer-key-q {
font-weight: bold;
}
.answer-key-a {
color: #4caf50;
font-weight: bold;
}
</style>
</head>
<body>
""")

    # Header
    version_text = "Loesungsblatt" if include_answers else "Multiple Choice Test"
    html_parts.append(f"<h1>{title}</h1>")
    html_parts.append(f"<div class='meta'><strong>{version_text}</strong>")
    if subject:
        html_parts.append(f" | Fach: {_esc(subject)}")
    if grade:
        html_parts.append(f" | Klasse: {_esc(grade)}")
    html_parts.append(f" | Anzahl Fragen: {len(questions)}</div>")

    if not include_answers:
        html_parts.append("<div class='instructions'>")
        html_parts.append("<strong>Anleitung:</strong> Kreuze bei jeder Frage die richtige Antwort an. ")
        html_parts.append("Es ist immer nur eine Antwort richtig.")
        html_parts.append("</div>")

    # Questions
    for idx, q in enumerate(questions, 1):
        html_parts.append("<div class='question-block'>")
        html_parts.append(f"<div class='question-number'>Frage {idx}</div>")
        html_parts.append(f"<div class='question-text'>{_esc(q.get('question'))}</div>")

        html_parts.append("<div class='options'>")
        correct_answer = q.get("correct_answer", "")

        for opt in q.get("options") or []:
            opt_id = opt.get("id", "")
            # Compare the raw values; escaping happens only when rendering.
            is_correct = opt_id == correct_answer

            opt_class = "option"
            checkbox_class = "option-checkbox"
            if include_answers and is_correct:
                opt_class += " option-correct"
                checkbox_class += " checked"

            html_parts.append(f"<div class='{opt_class}'>")
            html_parts.append(f"<div class='{checkbox_class}'></div>")
            html_parts.append(f"<span class='option-label'>{_esc(opt_id)})</span>")
            html_parts.append(f"<span class='option-text'>{_esc(opt.get('text'))}</span>")
            html_parts.append("</div>")

        html_parts.append("</div>")

        # Explanation only on the answer sheet
        if include_answers and q.get("explanation"):
            html_parts.append(
                f"<div class='explanation'><strong>Erklaerung:</strong> {_esc(q.get('explanation'))}</div>"
            )

        html_parts.append("</div>")

    # Compact answer key, only on the answer sheet
    if include_answers:
        html_parts.append("<div class='answer-key'>")
        html_parts.append("<div class='answer-key-title'>Loesungsschluessel</div>")
        html_parts.append("<div class='answer-key-grid'>")
        for idx, q in enumerate(questions, 1):
            html_parts.append("<div class='answer-key-item'>")
            html_parts.append(f"<span class='answer-key-q'>{idx}.</span> ")
            html_parts.append(f"<span class='answer-key-a'>{_esc(q.get('correct_answer', ''))}</span>")
            html_parts.append("</div>")
        html_parts.append("</div>")
        html_parts.append("</div>")

    html_parts.append("</body></html>")

    return "\n".join(html_parts)
+149
View File
@@ -0,0 +1,149 @@
"""
AI Processing - Print Version Generator: Q&A.
Generates printable HTML for question-answer worksheets.
"""
from pathlib import Path
import json
import logging
from .core import BEREINIGT_DIR
logger = logging.getLogger(__name__)
def generate_print_version_qa(qa_path: Path, include_answers: bool = False) -> Path:
    """
    Generate a printable HTML version of the question/answer pairs.

    Every value taken from the JSON file (title, subject, grade, questions,
    answers, key terms) is HTML-escaped before it is embedded, so text such
    as ``3 < 5`` or ``A & B`` is rendered literally instead of breaking the
    markup. Missing or ``null`` values fall back to the same defaults as
    absent keys.

    Args:
        qa_path: Path to the ``*_qa.json`` file.
        include_answers: True for the solution sheet (for parents), which
            shows answers and key terms; False for the question sheet with
            blank answer lines.

    Returns:
        Path of the HTML file written into ``BEREINIGT_DIR``.

    Raises:
        FileNotFoundError: If ``qa_path`` does not exist.
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    # Local import keeps this function self-contained; only names from the
    # standard library are added.
    from html import escape

    def _text(value) -> str:
        """Return *value* as HTML-safe text; ``None`` becomes an empty string."""
        return "" if value is None else escape(str(value))

    if not qa_path.exists():
        raise FileNotFoundError(f"Q&A-Datei nicht gefunden: {qa_path}")

    qa_data = json.loads(qa_path.read_text(encoding="utf-8"))
    # ``or`` (instead of a .get default) also covers keys that are present but null.
    items = qa_data.get("qa_items") or []
    metadata = qa_data.get("metadata") or {}

    title = _text(metadata.get("source_title") or "Arbeitsblatt")
    subject = metadata.get("subject") or ""
    grade = metadata.get("grade_level") or ""

    html_parts = []
    html_parts.append("""<!DOCTYPE html>
<html lang="de">
<head>
<meta charset="UTF-8">
<title>""" + title + """ - Fragen</title>
<style>
@media print {
.no-print { display: none; }
.page-break { page-break-before: always; }
}
body {
font-family: Arial, sans-serif;
max-width: 800px;
margin: 40px auto;
padding: 20px;
line-height: 1.6;
}
h1 { font-size: 24px; margin-bottom: 8px; }
.meta { color: #666; margin-bottom: 24px; }
.question-block {
margin-bottom: 32px;
padding-bottom: 16px;
border-bottom: 1px dashed #ccc;
}
.question-number {
font-weight: bold;
color: #333;
}
.question-text {
font-size: 16px;
margin: 8px 0;
}
.answer-space {
border: 1px solid #ddd;
min-height: 60px;
margin-top: 12px;
background: #fafafa;
}
.answer-lines {
margin-top: 12px;
}
.answer-line {
border-bottom: 1px solid #999;
height: 28px;
}
.answer {
margin-top: 8px;
padding: 8px;
background: #e8f5e9;
border-left: 3px solid #4caf50;
}
.key-terms {
font-size: 12px;
color: #666;
margin-top: 8px;
}
.key-terms span {
background: #fff3e0;
padding: 2px 6px;
border-radius: 3px;
margin-right: 4px;
}
</style>
</head>
<body>
""")

    # Header
    version_text = "Loesungsblatt" if include_answers else "Fragenblatt"
    html_parts.append(f"<h1>{title} - {version_text}</h1>")

    meta_parts = []
    if subject:
        meta_parts.append(f"Fach: {_text(subject)}")
    if grade:
        meta_parts.append(f"Klasse: {_text(grade)}")
    meta_parts.append(f"Anzahl Fragen: {len(items)}")
    html_parts.append(f"<div class='meta'>{' | '.join(meta_parts)}</div>")

    # Questions
    for idx, item in enumerate(items, 1):
        html_parts.append("<div class='question-block'>")
        html_parts.append(f"<div class='question-number'>Frage {idx}</div>")
        html_parts.append(f"<div class='question-text'>{_text(item.get('question'))}</div>")

        if include_answers:
            # Solution sheet: show the answer and the key terms
            html_parts.append(
                f"<div class='answer'><strong>Antwort:</strong> {_text(item.get('answer'))}</div>"
            )
            key_terms = item.get("key_terms") or []
            if key_terms:
                terms_html = " ".join(f"<span>{_text(term)}</span>" for term in key_terms)
                html_parts.append(f"<div class='key-terms'>Wichtige Begriffe: {terms_html}</div>")
        else:
            # Question sheet: three blank lines to write the answer on
            html_parts.append("<div class='answer-lines'>")
            for _ in range(3):
                html_parts.append("<div class='answer-line'></div>")
            html_parts.append("</div>")

        html_parts.append("</div>")

    html_parts.append("</body></html>")

    # Save next to the other cleaned-up artefacts
    suffix = "_qa_solutions.html" if include_answers else "_qa_print.html"
    out_name = qa_path.stem.replace("_qa", "") + suffix
    out_path = BEREINIGT_DIR / out_name
    out_path.write_text("\n".join(html_parts), encoding="utf-8")

    logger.info(f"Print-Version gespeichert: {out_path.name}")
    return out_path
@@ -0,0 +1,294 @@
"""
AI Processing - Print Version Generator: Worksheet.
Generates print-optimized HTML for general worksheets from analysis data.
"""
from pathlib import Path
import json
import logging
logger = logging.getLogger(__name__)
def generate_print_version_worksheet(analysis_path: Path) -> str:
    """
    Generate a print-optimized HTML version of the worksheet.

    Properties:
    - Large, easily readable font (16pt)
    - Suitable for black-and-white / grayscale printing
    - Clear structure for print
    - No interactive elements

    Title, subject, grade level and instructions are HTML-escaped before
    being embedded in the body, so characters such as ``<`` and ``&`` are
    rendered literally. The raw title is handed to ``_build_html_head``
    unchanged.

    Args:
        analysis_path: Path to the ``*_analyse.json`` file.

    Returns:
        HTML string that can be served directly.

    Raises:
        FileNotFoundError: If ``analysis_path`` does not exist.
        RuntimeError: If the file does not contain valid JSON.
    """
    # Local import keeps this function self-contained (standard library only).
    from html import escape

    if not analysis_path.exists():
        raise FileNotFoundError(f"Analysedatei nicht gefunden: {analysis_path}")

    try:
        data = json.loads(analysis_path.read_text(encoding="utf-8"))
    except json.JSONDecodeError as e:
        raise RuntimeError(f"Analyse-Datei enthaelt kein gueltiges JSON: {analysis_path}\n{e}") from e

    # ``or`` also replaces keys that are present but null/empty.
    title = data.get("title") or "Arbeitsblatt"
    subject = data.get("subject") or ""
    grade_level = data.get("grade_level") or ""
    instructions = data.get("instructions") or ""
    tasks = data.get("tasks", []) or []
    canonical_text = data.get("canonical_text") or ""
    printed_blocks = data.get("printed_blocks") or []

    html_parts = []
    html_parts.append(_build_html_head(title))

    # Title
    html_parts.append(f"<h1>{escape(str(title))}</h1>")

    # Meta information
    meta_parts = []
    if subject:
        meta_parts.append(f"<span><strong>Fach:</strong> {escape(str(subject))}</span>")
    if grade_level:
        meta_parts.append(f"<span><strong>Klasse:</strong> {escape(str(grade_level))}</span>")
    if meta_parts:
        html_parts.append(f"<div class='meta'>{''.join(meta_parts)}</div>")

    # Instructions
    if instructions:
        html_parts.append("<div class='instructions'>")
        html_parts.append("<div class='instructions-label'>Arbeitsanweisung:</div>")
        html_parts.append(f"<div>{escape(str(instructions))}</div>")
        html_parts.append("</div>")

    # Main text / printed blocks
    _build_text_section(html_parts, printed_blocks, canonical_text)

    # Tasks
    _build_tasks_section(html_parts, tasks)

    # Footer
    html_parts.append("<div class='footer'>")
    html_parts.append("Dieses Arbeitsblatt wurde automatisch aus einem Scan rekonstruiert.")
    html_parts.append("</div>")

    html_parts.append("</body></html>")
    return "\n".join(html_parts)
def _build_html_head(title: str) -> str:
"""Build the HTML head with print-optimized styles."""
return """<!DOCTYPE html>
<html lang="de">
<head>
<meta charset="UTF-8">
<title>""" + title + """</title>
<style>
@page {
size: A4;
margin: 20mm;
}
@media print {
body {
font-size: 14pt !important;
-webkit-print-color-adjust: exact;
print-color-adjust: exact;
}
.no-print { display: none !important; }
.page-break { page-break-before: always; }
}
* { box-sizing: border-box; }
body {
font-family: Arial, "Helvetica Neue", sans-serif;
max-width: 800px;
margin: 0 auto;
padding: 30px;
line-height: 1.7;
font-size: 16px;
color: #000;
background: #fff;
}
h1 {
font-size: 28px;
margin: 0 0 8px 0;
padding-bottom: 8px;
border-bottom: 3px solid #000;
}
h2 {
font-size: 20px;
margin: 28px 0 12px 0;
padding-bottom: 4px;
border-bottom: 1px solid #666;
}
.meta {
font-size: 14px;
color: #333;
margin-bottom: 20px;
padding: 8px 0;
}
.meta span {
margin-right: 20px;
}
.instructions {
margin: 20px 0;
padding: 16px;
border: 2px solid #333;
background: #f5f5f5;
font-size: 15px;
}
.instructions-label {
font-weight: bold;
margin-bottom: 8px;
}
.text-section {
margin: 24px 0;
}
.text-block {
margin-bottom: 16px;
text-align: justify;
}
.text-block-title {
font-weight: bold;
font-size: 17px;
margin-bottom: 8px;
}
.task-section {
margin-top: 32px;
}
.task {
margin-bottom: 24px;
padding: 16px;
border: 1px solid #999;
background: #fafafa;
}
.task-header {
font-weight: bold;
font-size: 16px;
margin-bottom: 12px;
padding-bottom: 8px;
border-bottom: 1px dashed #666;
}
.task-content {
font-size: 15px;
}
.gap-line {
display: inline-block;
border-bottom: 2px solid #000;
min-width: 100px;
margin: 0 6px;
}
.answer-lines {
margin-top: 16px;
}
.answer-line {
border-bottom: 1px solid #333;
height: 36px;
margin-bottom: 4px;
}
.footer {
margin-top: 40px;
padding-top: 16px;
border-top: 1px solid #ccc;
font-size: 11px;
color: #666;
text-align: center;
}
/* Print Button - versteckt beim Drucken */
.print-button {
position: fixed;
top: 20px;
right: 20px;
padding: 12px 24px;
background: #333;
color: #fff;
border: none;
border-radius: 6px;
cursor: pointer;
font-size: 14px;
}
.print-button:hover {
background: #555;
}
</style>
</head>
<body>
<button class="print-button no-print" onclick="window.print()">Drucken</button>
"""
def _build_text_section(html_parts: list, printed_blocks: list, canonical_text: str):
"""Build the text section from printed blocks or canonical text."""
if printed_blocks:
html_parts.append("<section class='text-section'>")
for block in printed_blocks:
role = (block.get("role") or "body").lower()
text = (block.get("text") or "").strip()
if not text:
continue
if role == "title":
html_parts.append(f"<div class='text-block'><div class='text-block-title'>{text}</div></div>")
else:
html_parts.append(f"<div class='text-block'>{text}</div>")
html_parts.append("</section>")
elif canonical_text:
html_parts.append("<section class='text-section'>")
paragraphs = [
p.strip()
for p in canonical_text.replace("\r\n", "\n").split("\n\n")
if p.strip()
]
for p in paragraphs:
html_parts.append(f"<div class='text-block'>{p}</div>")
html_parts.append("</section>")
def _build_tasks_section(html_parts: list, tasks: list):
"""Build the tasks section."""
if not tasks:
return
html_parts.append("<section class='task-section'>")
html_parts.append("<h2>Aufgaben</h2>")
type_labels = {
"fill_in_blank": "Lueckentext",
"multiple_choice": "Multiple Choice",
"free_text": "Freitext",
"matching": "Zuordnung",
"labeling": "Beschriftung",
"calculation": "Rechnung",
"other": "Aufgabe"
}
for idx, task in enumerate(tasks, start=1):
t_type = task.get("type") or "Aufgabe"
desc = task.get("description") or ""
text_with_gaps = task.get("text_with_gaps")
html_parts.append("<div class='task'>")
type_label = type_labels.get(t_type, t_type)
html_parts.append(f"<div class='task-header'>Aufgabe {idx}: {type_label}</div>")
if desc:
html_parts.append(f"<div class='task-content'>{desc}</div>")
if text_with_gaps:
rendered = text_with_gaps.replace("___", "<span class='gap-line'>&nbsp;</span>")
html_parts.append(f"<div class='task-content' style='margin-top:12px;'>{rendered}</div>")
# Antwortlinien fuer Freitext-Aufgaben
if t_type in ["free_text", "other"] or (not text_with_gaps and not desc):
html_parts.append("<div class='answer-lines'>")
for _ in range(3):
html_parts.append("<div class='answer-line'></div>")
html_parts.append("</div>")
html_parts.append("</div>")
html_parts.append("</section>")
+16 -717
View File
@@ -1,726 +1,25 @@
"""
Classroom API - Context Routes
Classroom API - Context Routes — Barrel Re-export.
Split into submodules:
- context_core.py — Teacher context, onboarding endpoints
- context_events.py — Events & routines CRUD
- context_static.py — Static data, suggestions, sidebar, school year path
School year context, events, routines, and suggestions endpoints (Phase 8).
"""
from typing import Dict, Any, Optional
from datetime import datetime
import logging
from fastapi import APIRouter
from fastapi import APIRouter, HTTPException, Query, Depends
from classroom_engine import (
FEDERAL_STATES,
SCHOOL_TYPES,
MacroPhaseEnum,
)
from ..models import (
TeacherContextResponse,
SchoolInfo,
SchoolYearInfo,
MacroPhaseInfo,
CoreCounts,
ContextFlags,
UpdateContextRequest,
CreateEventRequest,
EventResponse,
CreateRoutineRequest,
RoutineResponse,
)
from ..services.persistence import (
init_db_if_needed,
DB_ENABLED,
SessionLocal,
)
logger = logging.getLogger(__name__)
from .context_core import router as _core_router
from .context_events import router as _events_router
from .context_static import router as _static_router
# Single aggregate router kept for backwards compatibility: consumers still do
# `from .routes.context import router as context_router`.
router = APIRouter(tags=["Context"])
for _sub_router in (_core_router, _events_router, _static_router):
    # Registration order: core, events, static.
    router.include_router(_sub_router)
del _sub_router  # do not leave the loop variable behind in the module namespace
def get_db():
    """
    Database session dependency.

    Yields a session created by ``SessionLocal`` and closes it once the
    consumer is done. When the database is disabled or no session factory is
    configured, ``None`` is yielded instead.
    """
    if not (DB_ENABLED and SessionLocal):
        yield None
        return

    session = SessionLocal()
    try:
        yield session
    finally:
        # Always release the session, even if the consumer raised.
        session.close()
def _get_macro_phase_label(phase) -> str:
"""Gibt den Anzeigenamen einer Makro-Phase zurueck."""
labels = {
"onboarding": "Einrichtung",
"schuljahresstart": "Schuljahresstart",
"unterrichtsaufbau": "Unterrichtsaufbau",
"leistungsphase_1": "Leistungsphase 1",
"halbjahresabschluss": "Halbjahresabschluss",
"leistungsphase_2": "Leistungsphase 2",
"jahresabschluss": "Jahresabschluss",
}
phase_value = phase.value if hasattr(phase, 'value') else str(phase)
return labels.get(phase_value, phase_value)
# === Context Endpoints ===
@router.get("/v1/context", response_model=TeacherContextResponse)
async def get_teacher_context(
    teacher_id: str = Query(..., description="Teacher ID"),
    db=Depends(get_db)
):
    """
    Return the current macro context of a teacher.

    The context contains:
    - School information (federal state, school type)
    - School-year data (current year, week)
    - Macro phase (ONBOARDING through JAHRESABSCHLUSS)
    - Counters (classes, scheduled exams, etc.)
    - Status flags (onboarding completed, etc.)

    Without a database a static default context is returned.

    Raises:
        HTTPException: 500 if loading the context from the database fails.
    """
    if DB_ENABLED and db:
        try:
            from classroom_engine.repository import TeacherContextRepository, SchoolyearEventRepository

            repo = TeacherContextRepository(db)
            context = repo.get_or_create(teacher_id)

            # Resolve the defaults once so that the code and its display name
            # always describe the same entry (previously an unset state was
            # reported as "BY" with an empty name).
            federal_state = context.federal_state or "BY"
            school_type = context.school_type or "gymnasium"

            # Compute counters: exams among the events of the next 30 days
            event_repo = SchoolyearEventRepository(db)
            upcoming_exams = event_repo.get_upcoming(teacher_id, days=30)
            exams_count = sum(1 for e in upcoming_exams if e.event_type.value == "exam")

            return TeacherContextResponse(
                schema_version="1.0",
                teacher_id=teacher_id,
                school=SchoolInfo(
                    federal_state=federal_state,
                    federal_state_name=FEDERAL_STATES.get(federal_state, ""),
                    school_type=school_type,
                    school_type_name=SCHOOL_TYPES.get(school_type, ""),
                ),
                school_year=SchoolYearInfo(
                    id=context.schoolyear or "2024-2025",
                    start=context.schoolyear_start.isoformat() if context.schoolyear_start else None,
                    current_week=context.current_week or 1,
                ),
                macro_phase=MacroPhaseInfo(
                    id=context.macro_phase.value,
                    label=_get_macro_phase_label(context.macro_phase),
                    confidence=1.0,
                ),
                core_counts=CoreCounts(
                    classes=1 if context.has_classes else 0,
                    exams_scheduled=exams_count,
                    corrections_pending=0,
                ),
                flags=ContextFlags(
                    onboarding_completed=context.onboarding_completed,
                    has_classes=context.has_classes,
                    has_schedule=context.has_schedule,
                    is_exam_period=context.is_exam_period,
                    is_before_holidays=context.is_before_holidays,
                ),
            )
        except Exception as e:
            logger.error(f"Failed to get teacher context: {e}")
            raise HTTPException(status_code=500, detail=f"Fehler beim Laden des Kontexts: {e}")

    # Fallback without a database
    return TeacherContextResponse(
        schema_version="1.0",
        teacher_id=teacher_id,
        school=SchoolInfo(
            federal_state="BY",
            federal_state_name="Bayern",
            school_type="gymnasium",
            school_type_name="Gymnasium",
        ),
        school_year=SchoolYearInfo(
            id="2024-2025",
            start=None,
            current_week=1,
        ),
        macro_phase=MacroPhaseInfo(
            id="onboarding",
            label="Einrichtung",
            confidence=1.0,
        ),
        core_counts=CoreCounts(),
        flags=ContextFlags(),
    )
@router.put("/v1/context", response_model=TeacherContextResponse)
async def update_teacher_context(
    teacher_id: str,
    request: UpdateContextRequest,
    db=Depends(get_db)
):
    """
    Update the context of a teacher and return the refreshed context.

    Raises:
        HTTPException: 503 if the database is not available; 400 for an
            unknown federal state or school type, or for a ``schoolyear_start``
            that is not a valid ISO date; 500 for any other failure.
    """
    if not DB_ENABLED or not db:
        raise HTTPException(status_code=503, detail="Datenbank nicht verfuegbar")

    try:
        from classroom_engine.repository import TeacherContextRepository
        repo = TeacherContextRepository(db)

        # Validation
        if request.federal_state and request.federal_state not in FEDERAL_STATES:
            raise HTTPException(status_code=400, detail=f"Ungueltiges Bundesland: {request.federal_state}")
        if request.school_type and request.school_type not in SCHOOL_TYPES:
            raise HTTPException(status_code=400, detail=f"Ungueltige Schulart: {request.school_type}")

        # Parse datetime if provided; a malformed value is a client error (400),
        # not a server error.
        schoolyear_start = None
        if request.schoolyear_start:
            try:
                # fromisoformat does not accept a trailing "Z" on all Python versions
                schoolyear_start = datetime.fromisoformat(request.schoolyear_start.replace('Z', '+00:00'))
            except ValueError as exc:
                raise HTTPException(
                    status_code=400,
                    detail=f"Ungueltiges Datum (schoolyear_start): {request.schoolyear_start}",
                ) from exc

        repo.update_context(
            teacher_id=teacher_id,
            federal_state=request.federal_state,
            school_type=request.school_type,
            schoolyear=request.schoolyear,
            schoolyear_start=schoolyear_start,
            macro_phase=request.macro_phase,
            current_week=request.current_week,
        )

        return await get_teacher_context(teacher_id, db)

    except HTTPException:
        # Deliberate HTTP errors (validation, nested handler) pass through unchanged.
        raise
    except Exception as e:
        logger.error(f"Failed to update teacher context: {e}")
        raise HTTPException(status_code=500, detail=f"Fehler beim Aktualisieren: {e}")
@router.post("/v1/context/complete-onboarding")
async def complete_onboarding(
    teacher_id: str = Query(...),
    db=Depends(get_db)
):
    """
    Mark the onboarding as completed.

    Without a database a static success payload is returned. Any repository
    failure is reported as HTTP 500.
    """
    if not (DB_ENABLED and db):
        return {"success": True, "macro_phase": "schuljahresstart", "note": "DB not available"}

    try:
        from classroom_engine.repository import TeacherContextRepository

        updated = TeacherContextRepository(db).complete_onboarding(teacher_id)
        payload = {
            "success": True,
            "macro_phase": updated.macro_phase.value,
            "teacher_id": teacher_id,
        }
        return payload
    except Exception as exc:
        logger.error(f"Failed to complete onboarding: {exc}")
        raise HTTPException(status_code=500, detail=f"Fehler: {exc}")
@router.post("/v1/context/reset-onboarding")
async def reset_onboarding(
    teacher_id: str = Query(...),
    db=Depends(get_db)
):
    """
    Reset the onboarding (for tests).

    Clears the completed flag, sets the macro phase back to ONBOARDING and
    commits. Without a database a static success payload is returned.
    """
    if not (DB_ENABLED and db):
        return {"success": True, "macro_phase": "onboarding", "note": "DB not available"}

    try:
        from classroom_engine.repository import TeacherContextRepository

        teacher_context = TeacherContextRepository(db).get_or_create(teacher_id)
        teacher_context.onboarding_completed = False
        teacher_context.macro_phase = MacroPhaseEnum.ONBOARDING
        db.commit()
        db.refresh(teacher_context)
    except Exception as exc:
        logger.error(f"Failed to reset onboarding: {exc}")
        raise HTTPException(status_code=500, detail=f"Fehler: {exc}")

    return {
        "success": True,
        "macro_phase": "onboarding",
        "teacher_id": teacher_id,
    }
# === Events Endpoints ===
@router.get("/v1/events")
async def get_events(
    teacher_id: str = Query(...),
    status: Optional[str] = None,
    event_type: Optional[str] = None,
    limit: int = 50,
    db=Depends(get_db)
):
    """
    Fetch the events of a teacher, optionally filtered by status and type.

    Returns an empty result when no database is available; repository
    failures are reported as HTTP 500.
    """
    if not (DB_ENABLED and db):
        return {"events": [], "count": 0}

    try:
        from classroom_engine.repository import SchoolyearEventRepository

        event_repo = SchoolyearEventRepository(db)
        found = event_repo.get_by_teacher(
            teacher_id, status=status, event_type=event_type, limit=limit
        )
        serialized = [event_repo.to_dict(event) for event in found]
        return {"events": serialized, "count": len(found)}
    except Exception as exc:
        logger.error(f"Failed to get events: {exc}")
        raise HTTPException(status_code=500, detail=f"Fehler: {exc}")
@router.get("/v1/events/upcoming")
async def get_upcoming_events(
    teacher_id: str = Query(...),
    days: int = 30,
    limit: int = 10,
    db=Depends(get_db)
):
    """
    Fetch the upcoming events of the next ``days`` days.

    Returns an empty result when no database is available; repository
    failures are reported as HTTP 500.
    """
    if not (DB_ENABLED and db):
        return {"events": [], "count": 0}

    try:
        from classroom_engine.repository import SchoolyearEventRepository

        event_repo = SchoolyearEventRepository(db)
        upcoming = event_repo.get_upcoming(teacher_id, days=days, limit=limit)
        serialized = [event_repo.to_dict(event) for event in upcoming]
        return {"events": serialized, "count": len(upcoming)}
    except Exception as exc:
        logger.error(f"Failed to get upcoming events: {exc}")
        raise HTTPException(status_code=500, detail=f"Fehler: {exc}")
@router.post("/v1/events", response_model=EventResponse)
async def create_event(
    teacher_id: str,
    request: CreateEventRequest,
    db=Depends(get_db)
):
    """
    Create a new school-year event.

    Raises:
        HTTPException: 503 if the database is not available; 400 if
            ``start_date`` or ``end_date`` is not a valid ISO date; 500 for
            any other failure.
    """
    if not DB_ENABLED or not db:
        raise HTTPException(status_code=503, detail="Datenbank nicht verfuegbar")

    try:
        from classroom_engine.repository import SchoolyearEventRepository
        repo = SchoolyearEventRepository(db)

        # Malformed dates are a client error (400), not a server error.
        try:
            # fromisoformat does not accept a trailing "Z" on all Python versions
            start_date = datetime.fromisoformat(request.start_date.replace('Z', '+00:00'))
            end_date = None
            if request.end_date:
                end_date = datetime.fromisoformat(request.end_date.replace('Z', '+00:00'))
        except ValueError as exc:
            raise HTTPException(status_code=400, detail=f"Ungueltiges Datumsformat: {exc}") from exc

        event = repo.create(
            teacher_id=teacher_id,
            title=request.title,
            event_type=request.event_type,
            start_date=start_date,
            end_date=end_date,
            class_id=request.class_id,
            subject=request.subject,
            description=request.description,
            needs_preparation=request.needs_preparation,
            reminder_days_before=request.reminder_days_before,
        )

        return EventResponse(
            id=event.id,
            teacher_id=event.teacher_id,
            event_type=event.event_type.value,
            title=event.title,
            description=event.description,
            start_date=event.start_date.isoformat(),
            end_date=event.end_date.isoformat() if event.end_date else None,
            class_id=event.class_id,
            subject=event.subject,
            status=event.status.value,
            needs_preparation=event.needs_preparation,
            preparation_done=event.preparation_done,
            reminder_days_before=event.reminder_days_before,
        )
    except HTTPException:
        # Keep deliberate HTTP errors (e.g. the 400 above) from being turned into a 500.
        raise
    except Exception as e:
        logger.error(f"Failed to create event: {e}")
        raise HTTPException(status_code=500, detail=f"Fehler: {e}")
@router.delete("/v1/events/{event_id}")
async def delete_event(event_id: str, db=Depends(get_db)):
    """
    Delete an event.

    Responds with 503 without a database, 404 when the repository reports
    that nothing was deleted, and 500 for any other failure.
    """
    if not (DB_ENABLED and db):
        raise HTTPException(status_code=503, detail="Datenbank nicht verfuegbar")

    try:
        from classroom_engine.repository import SchoolyearEventRepository

        was_deleted = SchoolyearEventRepository(db).delete(event_id)
        if not was_deleted:
            raise HTTPException(status_code=404, detail="Event nicht gefunden")
        return {"success": True, "deleted_id": event_id}
    except HTTPException:
        # Let the 404 through unchanged.
        raise
    except Exception as exc:
        logger.error(f"Failed to delete event: {exc}")
        raise HTTPException(status_code=500, detail=f"Fehler: {exc}")
# === Routines Endpoints ===
@router.get("/v1/routines")
async def get_routines(
    teacher_id: str = Query(...),
    is_active: bool = True,
    routine_type: Optional[str] = None,
    db=Depends(get_db)
):
    """
    Fetch the routines of a teacher, filtered by active flag and type.

    Returns an empty result when no database is available; repository
    failures are reported as HTTP 500.
    """
    if not (DB_ENABLED and db):
        return {"routines": [], "count": 0}

    try:
        from classroom_engine.repository import RecurringRoutineRepository

        routine_repo = RecurringRoutineRepository(db)
        found = routine_repo.get_by_teacher(
            teacher_id, is_active=is_active, routine_type=routine_type
        )
        serialized = [routine_repo.to_dict(routine) for routine in found]
        return {"routines": serialized, "count": len(found)}
    except Exception as exc:
        logger.error(f"Failed to get routines: {exc}")
        raise HTTPException(status_code=500, detail=f"Fehler: {exc}")
@router.get("/v1/routines/today")
async def get_today_routines(teacher_id: str = Query(...), db=Depends(get_db)):
    """
    Fetch the routines that take place today.

    Returns an empty result when no database is available; repository
    failures are reported as HTTP 500.
    """
    if not (DB_ENABLED and db):
        return {"routines": [], "count": 0}

    try:
        from classroom_engine.repository import RecurringRoutineRepository

        routine_repo = RecurringRoutineRepository(db)
        todays = routine_repo.get_today(teacher_id)
        serialized = [routine_repo.to_dict(routine) for routine in todays]
        return {"routines": serialized, "count": len(todays)}
    except Exception as exc:
        logger.error(f"Failed to get today's routines: {exc}")
        raise HTTPException(status_code=500, detail=f"Fehler: {exc}")
@router.post("/v1/routines", response_model=RoutineResponse)
async def create_routine(
    teacher_id: str,
    request: CreateRoutineRequest,
    db=Depends(get_db)
):
    """
    Create a new recurring routine.

    Responds with 503 without a database and 500 for any failure while
    creating or serializing the routine.
    """
    if not (DB_ENABLED and db):
        raise HTTPException(status_code=503, detail="Datenbank nicht verfuegbar")

    try:
        from classroom_engine.repository import RecurringRoutineRepository

        created = RecurringRoutineRepository(db).create(
            teacher_id=teacher_id,
            title=request.title,
            routine_type=request.routine_type,
            recurrence_pattern=request.recurrence_pattern,
            day_of_week=request.day_of_week,
            day_of_month=request.day_of_month,
            time_of_day=request.time_of_day,
            duration_minutes=request.duration_minutes,
            description=request.description,
        )

        # The time of day is optional; serialize it only when it is set.
        if created.time_of_day:
            time_of_day_iso = created.time_of_day.isoformat()
        else:
            time_of_day_iso = None

        return RoutineResponse(
            id=created.id,
            teacher_id=created.teacher_id,
            routine_type=created.routine_type.value,
            title=created.title,
            description=created.description,
            recurrence_pattern=created.recurrence_pattern.value,
            day_of_week=created.day_of_week,
            day_of_month=created.day_of_month,
            time_of_day=time_of_day_iso,
            duration_minutes=created.duration_minutes,
            is_active=created.is_active,
        )
    except Exception as exc:
        logger.error(f"Failed to create routine: {exc}")
        raise HTTPException(status_code=500, detail=f"Fehler: {exc}")
@router.delete("/v1/routines/{routine_id}")
async def delete_routine(routine_id: str, db=Depends(get_db)):
    """
    Delete a routine.

    Responds with 503 without a database, 404 when the repository reports
    that nothing was deleted, and 500 for any other failure.
    """
    if not (DB_ENABLED and db):
        raise HTTPException(status_code=503, detail="Datenbank nicht verfuegbar")

    try:
        from classroom_engine.repository import RecurringRoutineRepository

        was_deleted = RecurringRoutineRepository(db).delete(routine_id)
        if not was_deleted:
            raise HTTPException(status_code=404, detail="Routine nicht gefunden")
        return {"success": True, "deleted_id": routine_id}
    except HTTPException:
        # Let the 404 through unchanged.
        raise
    except Exception as exc:
        logger.error(f"Failed to delete routine: {exc}")
        raise HTTPException(status_code=500, detail=f"Fehler: {exc}")
# === Static Data Endpoints ===
@router.get("/v1/federal-states")
async def get_federal_states():
    """Return all federal states as ``{"id", "name"}`` entries."""
    entries = []
    for state_id, state_name in FEDERAL_STATES.items():
        entries.append({"id": state_id, "name": state_name})
    return {"federal_states": entries}
@router.get("/v1/school-types")
async def get_school_types():
    """Return all school types as ``{"id", "name"}`` entries."""
    entries = []
    for type_id, type_name in SCHOOL_TYPES.items():
        entries.append({"id": type_id, "name": type_name})
    return {"school_types": entries}
@router.get("/v1/macro-phases")
async def get_macro_phases():
    """Return all macro phases with descriptions, numbered from 1 in order."""
    # (id, label, description); the "order" field is the 1-based position.
    phase_rows = (
        ("onboarding", "Einrichtung", "Ersteinrichtung (Klassen, Stundenplan)"),
        ("schuljahresstart", "Schuljahresstart", "Erste 2-3 Wochen des Schuljahres"),
        ("unterrichtsaufbau", "Unterrichtsaufbau", "Routinen etablieren, erste Bewertungen"),
        ("leistungsphase_1", "Leistungsphase 1", "Erste Klassenarbeiten und Klausuren"),
        ("halbjahresabschluss", "Halbjahresabschluss", "Notenschluss, Zeugnisse, Konferenzen"),
        ("leistungsphase_2", "Leistungsphase 2", "Zweites Halbjahr, Pruefungsvorbereitung"),
        ("jahresabschluss", "Jahresabschluss", "Finale Noten, Versetzung, Schuljahresende"),
    )
    return {
        "macro_phases": [
            {"id": phase_id, "label": label, "description": description, "order": position}
            for position, (phase_id, label, description) in enumerate(phase_rows, start=1)
        ]
    }
@router.get("/v1/event-types")
async def get_event_types():
    """Return all event types as ``{"id", "label"}`` entries."""
    type_rows = (
        ("exam", "Klassenarbeit/Klausur"),
        ("parent_evening", "Elternabend"),
        ("trip", "Klassenfahrt/Ausflug"),
        ("project", "Projektwoche"),
        ("internship", "Praktikum"),
        ("presentation", "Referate/Praesentationen"),
        ("sports_day", "Sporttag"),
        ("school_festival", "Schulfest"),
        ("parent_consultation", "Elternsprechtag"),
        ("grade_deadline", "Notenschluss"),
        ("report_cards", "Zeugnisausgabe"),
        ("holiday_start", "Ferienbeginn"),
        ("holiday_end", "Ferienende"),
        ("other", "Sonstiges"),
    )
    return {
        "event_types": [{"id": type_id, "label": label} for type_id, label in type_rows]
    }
@router.get("/v1/routine-types")
async def get_routine_types():
    """Return all routine types as ``{"id", "label"}`` entries."""
    type_rows = (
        ("teacher_conference", "Lehrerkonferenz"),
        ("subject_conference", "Fachkonferenz"),
        ("office_hours", "Sprechstunde"),
        ("team_meeting", "Teamsitzung"),
        ("supervision", "Pausenaufsicht"),
        ("correction_time", "Korrekturzeit"),
        ("prep_time", "Vorbereitungszeit"),
        ("other", "Sonstiges"),
    )
    return {
        "routine_types": [{"id": type_id, "label": label} for type_id, label in type_rows]
    }
# === Suggestions & Sidebar ===
@router.get("/v1/suggestions")
async def get_suggestions(
    teacher_id: str = Query(...),
    limit: int = Query(5, ge=1, le=20),
    db=Depends(get_db)
):
    """
    Generate context-based suggestions for a teacher.

    Without a database an empty default payload is returned; generator
    failures are reported as HTTP 500.
    """
    if not (DB_ENABLED and db):
        # Static payload describing a fresh, empty state.
        return {
            "active_contexts": [],
            "suggestions": [],
            "signals_summary": {
                "macro_phase": "onboarding",
                "current_week": 1,
                "has_classes": False,
                "exams_soon": 0,
                "routines_today": 0,
            },
            "total_suggestions": 0,
        }

    try:
        from classroom_engine.suggestions import SuggestionGenerator

        return SuggestionGenerator(db).generate(teacher_id, limit=limit)
    except Exception as exc:
        logger.error(f"Failed to generate suggestions: {exc}")
        raise HTTPException(status_code=500, detail=f"Fehler: {exc}")
@router.get("/v1/sidebar")
async def get_sidebar(
    teacher_id: str = Query(...),
    mode: str = Query("companion"),
    db=Depends(get_db)
):
    """
    Generate the dynamic sidebar model.

    ``mode == "companion"`` yields the companion layout, whose "Jetzt relevant"
    list is filled from generated suggestions when a database is available.
    Every other mode yields the classic navigation tree.
    """
    if mode != "companion":
        classic_items = [
            {"id": "dashboard", "label": "Dashboard", "icon": "dashboard", "url": "/dashboard"},
            {"id": "lesson", "label": "Stundenmodus", "icon": "timer", "url": "/lesson"},
            {"id": "classes", "label": "Klassen", "icon": "groups", "url": "/classes"},
            {"id": "exams", "label": "Klausuren", "icon": "quiz", "url": "/exams"},
            {"id": "grades", "label": "Noten", "icon": "calculate", "url": "/grades"},
            {"id": "calendar", "label": "Kalender", "icon": "calendar_month", "url": "/calendar"},
            {"id": "materials", "label": "Materialien", "icon": "folder_open", "url": "/materials"},
            {"id": "settings", "label": "Einstellungen", "icon": "settings", "url": "/settings"},
        ]
        return {
            "mode": "classic",
            "sections": [
                {"id": "NAVIGATION", "type": "tree", "items": classic_items},
            ],
        }

    relevant_items = []
    if DB_ENABLED and db:
        try:
            from classroom_engine.suggestions import SuggestionGenerator

            generated = SuggestionGenerator(db).generate(teacher_id, limit=5)
            relevant_items = [
                {
                    "id": suggestion["id"],
                    "label": suggestion["title"],
                    "state": "recommended" if suggestion["priority"] > 70 else "default",
                    "badge": suggestion.get("badge"),
                    "icon": suggestion.get("icon", "lightbulb"),
                    "action_url": suggestion.get("action_url"),
                }
                for suggestion in generated.get("suggestions", [])
            ]
        except Exception as exc:
            # Best effort: the sidebar is still rendered without suggestions.
            logger.warning(f"Failed to get suggestions for sidebar: {exc}")

    placeholder_items = [
        {"id": "no_suggestions", "label": "Keine Vorschlaege", "state": "default", "icon": "check_circle"}
    ]
    search_section = {"id": "SEARCH", "type": "search_bar", "placeholder": "Suchen..."}
    relevant_section = {
        "id": "NOW_RELEVANT",
        "type": "list",
        "title": "Jetzt relevant",
        "items": relevant_items or placeholder_items,
    }
    modules_section = {
        "id": "ALL_MODULES",
        "type": "folder",
        "label": "Alle Module",
        "icon": "folder",
        "collapsed": True,
        "items": [
            {"id": "lesson", "label": "Stundenmodus", "icon": "timer"},
            {"id": "classes", "label": "Klassen", "icon": "groups"},
            {"id": "exams", "label": "Klausuren", "icon": "quiz"},
            {"id": "grades", "label": "Noten", "icon": "calculate"},
            {"id": "calendar", "label": "Kalender", "icon": "calendar_month"},
            {"id": "materials", "label": "Materialien", "icon": "folder_open"},
        ],
    }
    actions_section = {
        "id": "QUICK_ACTIONS",
        "type": "actions",
        "title": "Kurzaktionen",
        "items": [
            {"id": "scan", "label": "Scan hochladen", "icon": "upload_file"},
            {"id": "note", "label": "Notiz erstellen", "icon": "note_add"},
        ],
    }

    return {
        "mode": "companion",
        "sections": [search_section, relevant_section, modules_section, actions_section],
    }
@router.get("/v1/path")
async def get_schoolyear_path(teacher_id: str = Query(...), db=Depends(get_db)):
    """
    Generate the school-year path with milestones.

    Each milestone is marked ``done``, ``current`` or ``upcoming`` relative to
    the teacher's current macro phase. The phase defaults to ``onboarding``
    when no database is available or the lookup fails.
    """
    current_phase = "onboarding"
    if DB_ENABLED and db:
        try:
            from classroom_engine.repository import TeacherContextRepository

            teacher_context = TeacherContextRepository(db).get_or_create(teacher_id)
            current_phase = teacher_context.macro_phase.value
        except Exception as exc:
            # Best effort: fall back to the default phase.
            logger.warning(f"Failed to get context for path: {exc}")

    phase_order = [
        "onboarding", "schuljahresstart", "unterrichtsaufbau",
        "leistungsphase_1", "halbjahresabschluss", "leistungsphase_2", "jahresabschluss",
    ]

    def _position(phase_name, missing):
        """Return the index of *phase_name* in the phase order, or *missing*."""
        try:
            return phase_order.index(phase_name)
        except ValueError:
            return missing

    current_index = _position(current_phase, 0)

    milestones = [
        {"id": "MS_START", "label": "Start", "phase": "onboarding", "icon": "flag"},
        {"id": "MS_SETUP", "label": "Einrichtung", "phase": "schuljahresstart", "icon": "tune"},
        {"id": "MS_ROUTINE", "label": "Routinen", "phase": "unterrichtsaufbau", "icon": "repeat"},
        {"id": "MS_EXAM_1", "label": "Klausuren", "phase": "leistungsphase_1", "icon": "quiz"},
        {"id": "MS_HALFYEAR", "label": "Halbjahr", "phase": "halbjahresabschluss", "icon": "event"},
        {"id": "MS_EXAM_2", "label": "Pruefungen", "phase": "leistungsphase_2", "icon": "school"},
        {"id": "MS_END", "label": "Abschluss", "phase": "jahresabschluss", "icon": "celebration"},
    ]

    for entry in milestones:
        # Unknown phases sort after everything and therefore count as upcoming.
        entry_index = _position(entry["phase"], 999)
        if entry_index == current_index:
            entry["status"] = "current"
        elif entry_index < current_index:
            entry["status"] = "done"
        else:
            entry["status"] = "upcoming"

    current_milestone_id = next(
        (entry["id"] for entry in milestones if entry["status"] == "current"),
        milestones[0]["id"]
    )

    # Progress is the position of the current phase on a 0-100 scale.
    last_index = len(phase_order) - 1
    if len(phase_order) > 1:
        progress = int((current_index / last_index) * 100)
    else:
        progress = 0

    return {
        "milestones": milestones,
        "current_milestone_id": current_milestone_id,
        "progress_percent": progress,
        "current_phase": current_phase,
    }
# Star-imports of this module expose only `router`.
__all__ = ["router"]
@@ -0,0 +1,247 @@
"""
Classroom API - Context Core Routes
Teacher context, onboarding endpoints (Phase 8).
"""
from typing import Dict, Any
from datetime import datetime
import logging
from fastapi import APIRouter, HTTPException, Query, Depends
from classroom_engine import (
FEDERAL_STATES,
SCHOOL_TYPES,
MacroPhaseEnum,
)
from ..models import (
TeacherContextResponse,
SchoolInfo,
SchoolYearInfo,
MacroPhaseInfo,
CoreCounts,
ContextFlags,
UpdateContextRequest,
)
from ..services.persistence import (
init_db_if_needed,
DB_ENABLED,
SessionLocal,
)
logger = logging.getLogger(__name__)
router = APIRouter(tags=["Context"])
def get_db():
    """Yield a database session for one request, or None without persistence.

    The session is closed once the consumer is done with it.
    """
    if not (DB_ENABLED and SessionLocal):
        yield None
        return

    session = SessionLocal()
    try:
        yield session
    finally:
        session.close()
def _get_macro_phase_label(phase) -> str:
"""Gibt den Anzeigenamen einer Makro-Phase zurueck."""
labels = {
"onboarding": "Einrichtung",
"schuljahresstart": "Schuljahresstart",
"unterrichtsaufbau": "Unterrichtsaufbau",
"leistungsphase_1": "Leistungsphase 1",
"halbjahresabschluss": "Halbjahresabschluss",
"leistungsphase_2": "Leistungsphase 2",
"jahresabschluss": "Jahresabschluss",
}
phase_value = phase.value if hasattr(phase, 'value') else str(phase)
return labels.get(phase_value, phase_value)
# === Context Endpoints ===
@router.get("/v1/context", response_model=TeacherContextResponse)
async def get_teacher_context(
    teacher_id: str = Query(..., description="Teacher ID"),
    db=Depends(get_db)
):
    """
    Return the current macro context of a teacher.

    The context contains:
    - school information (federal state, school type)
    - school year data (current year, week)
    - macro phase (ONBOARDING through JAHRESABSCHLUSS)
    - counters (classes, scheduled exams, ...)
    - status flags (onboarding completed, ...)

    When the database is disabled or no session is available, a static
    default context is returned instead.

    Raises:
        HTTPException: 500 if loading the context from the database fails.
    """
    if DB_ENABLED and db:
        try:
            from classroom_engine.repository import TeacherContextRepository, SchoolyearEventRepository

            repo = TeacherContextRepository(db)
            context = repo.get_or_create(teacher_id)

            # Count the exams among the events of the next 30 days.
            event_repo = SchoolyearEventRepository(db)
            upcoming_events = event_repo.get_upcoming(teacher_id, days=30)
            exams_count = sum(1 for ev in upcoming_events if ev.event_type.value == "exam")

            # Resolve the defaults once so that the returned code and its
            # display name always describe the same federal state / school type.
            federal_state = context.federal_state or "BY"
            school_type = context.school_type or "gymnasium"

            return TeacherContextResponse(
                schema_version="1.0",
                teacher_id=teacher_id,
                school=SchoolInfo(
                    federal_state=federal_state,
                    federal_state_name=FEDERAL_STATES.get(federal_state, ""),
                    school_type=school_type,
                    school_type_name=SCHOOL_TYPES.get(school_type, ""),
                ),
                school_year=SchoolYearInfo(
                    id=context.schoolyear or "2024-2025",
                    start=context.schoolyear_start.isoformat() if context.schoolyear_start else None,
                    current_week=context.current_week or 1,
                ),
                macro_phase=MacroPhaseInfo(
                    id=context.macro_phase.value,
                    label=_get_macro_phase_label(context.macro_phase),
                    confidence=1.0,
                ),
                core_counts=CoreCounts(
                    classes=1 if context.has_classes else 0,
                    exams_scheduled=exams_count,
                    corrections_pending=0,
                ),
                flags=ContextFlags(
                    onboarding_completed=context.onboarding_completed,
                    has_classes=context.has_classes,
                    has_schedule=context.has_schedule,
                    is_exam_period=context.is_exam_period,
                    is_before_holidays=context.is_before_holidays,
                ),
            )
        except Exception as e:
            logger.error(f"Failed to get teacher context: {e}")
            raise HTTPException(status_code=500, detail=f"Fehler beim Laden des Kontexts: {e}")

    # Fallback without a database: static onboarding context.
    return TeacherContextResponse(
        schema_version="1.0",
        teacher_id=teacher_id,
        school=SchoolInfo(
            federal_state="BY",
            federal_state_name="Bayern",
            school_type="gymnasium",
            school_type_name="Gymnasium",
        ),
        school_year=SchoolYearInfo(
            id="2024-2025",
            start=None,
            current_week=1,
        ),
        macro_phase=MacroPhaseInfo(
            id="onboarding",
            label="Einrichtung",
            confidence=1.0,
        ),
        core_counts=CoreCounts(),
        flags=ContextFlags(),
    )
@router.put("/v1/context", response_model=TeacherContextResponse)
async def update_teacher_context(
    teacher_id: str,
    request: UpdateContextRequest,
    db=Depends(get_db)
):
    """
    Update the context of a teacher and return the refreshed context.

    Raises:
        HTTPException: 503 if the database is unavailable, 400 for an unknown
            federal state / school type or a malformed ``schoolyear_start``,
            500 for any other failure.
    """
    if not DB_ENABLED or not db:
        raise HTTPException(status_code=503, detail="Datenbank nicht verfuegbar")

    try:
        from classroom_engine.repository import TeacherContextRepository
        repo = TeacherContextRepository(db)

        # Validate the request before touching the stored context.
        if request.federal_state and request.federal_state not in FEDERAL_STATES:
            raise HTTPException(status_code=400, detail=f"Ungueltiges Bundesland: {request.federal_state}")
        if request.school_type and request.school_type not in SCHOOL_TYPES:
            raise HTTPException(status_code=400, detail=f"Ungueltige Schulart: {request.school_type}")

        # Parse the optional ISO timestamp; a trailing "Z" is mapped to an
        # explicit UTC offset. A malformed value is a client error (400),
        # not a server error.
        schoolyear_start = None
        if request.schoolyear_start:
            try:
                schoolyear_start = datetime.fromisoformat(request.schoolyear_start.replace('Z', '+00:00'))
            except ValueError:
                raise HTTPException(
                    status_code=400,
                    detail=f"Ungueltiges Datum: {request.schoolyear_start}",
                )

        repo.update_context(
            teacher_id=teacher_id,
            federal_state=request.federal_state,
            school_type=request.school_type,
            schoolyear=request.schoolyear,
            schoolyear_start=schoolyear_start,
            macro_phase=request.macro_phase,
            current_week=request.current_week,
        )

        return await get_teacher_context(teacher_id, db)

    except HTTPException:
        # Pass deliberate HTTP errors (400/500 from above) through unchanged.
        raise
    except Exception as e:
        logger.error(f"Failed to update teacher context: {e}")
        raise HTTPException(status_code=500, detail=f"Fehler beim Aktualisieren: {e}")
@router.post("/v1/context/complete-onboarding")
async def complete_onboarding(teacher_id: str = Query(...), db=Depends(get_db)):
    """Mark the onboarding as completed and report the resulting macro phase."""
    if not (DB_ENABLED and db):
        # Without persistence the transition is only reported, not stored.
        return {"success": True, "macro_phase": "schuljahresstart", "note": "DB not available"}

    try:
        from classroom_engine.repository import TeacherContextRepository

        updated = TeacherContextRepository(db).complete_onboarding(teacher_id)
        payload = {
            "success": True,
            "macro_phase": updated.macro_phase.value,
            "teacher_id": teacher_id,
        }
    except Exception as e:
        logger.error(f"Failed to complete onboarding: {e}")
        raise HTTPException(status_code=500, detail=f"Fehler: {e}")
    return payload
@router.post("/v1/context/reset-onboarding")
async def reset_onboarding(teacher_id: str = Query(...), db=Depends(get_db)):
    """Reset the onboarding state of a teacher (intended for tests)."""
    if not (DB_ENABLED and db):
        return {"success": True, "macro_phase": "onboarding", "note": "DB not available"}

    try:
        from classroom_engine.repository import TeacherContextRepository

        teacher_context = TeacherContextRepository(db).get_or_create(teacher_id)
        # Put the stored context back into its initial onboarding state.
        teacher_context.onboarding_completed = False
        teacher_context.macro_phase = MacroPhaseEnum.ONBOARDING
        db.commit()
        db.refresh(teacher_context)
    except Exception as e:
        logger.error(f"Failed to reset onboarding: {e}")
        raise HTTPException(status_code=500, detail=f"Fehler: {e}")

    return {
        "success": True,
        "macro_phase": "onboarding",
        "teacher_id": teacher_id,
    }
@@ -0,0 +1,266 @@
"""
Classroom API - Events & Routines Routes
School year events, recurring routines endpoints (Phase 8).
"""
from typing import Optional
from datetime import datetime
import logging
from fastapi import APIRouter, HTTPException, Query, Depends
from ..models import (
CreateEventRequest,
EventResponse,
CreateRoutineRequest,
RoutineResponse,
)
from ..services.persistence import (
DB_ENABLED,
SessionLocal,
)
logger = logging.getLogger(__name__)
router = APIRouter(tags=["Context"])
def get_db():
    """Yield a database session for one request, or None without persistence.

    The session is closed once the consumer is done with it.
    """
    if not (DB_ENABLED and SessionLocal):
        yield None
        return

    session = SessionLocal()
    try:
        yield session
    finally:
        session.close()
# === Events Endpoints ===
@router.get("/v1/events")
async def get_events(
    teacher_id: str = Query(...),
    status: Optional[str] = None,
    event_type: Optional[str] = None,
    limit: int = 50,
    db=Depends(get_db)
):
    """Return the events of a teacher, optionally filtered by status and type."""
    if not (DB_ENABLED and db):
        return {"events": [], "count": 0}

    try:
        from classroom_engine.repository import SchoolyearEventRepository

        event_repo = SchoolyearEventRepository(db)
        found = event_repo.get_by_teacher(
            teacher_id, status=status, event_type=event_type, limit=limit
        )
        serialized = [event_repo.to_dict(item) for item in found]
        return {"events": serialized, "count": len(found)}
    except Exception as e:
        logger.error(f"Failed to get events: {e}")
        raise HTTPException(status_code=500, detail=f"Fehler: {e}")
@router.get("/v1/events/upcoming")
async def get_upcoming_events(
    teacher_id: str = Query(...),
    days: int = 30,
    limit: int = 10,
    db=Depends(get_db)
):
    """Return the upcoming events of a teacher within the next ``days`` days."""
    if not (DB_ENABLED and db):
        return {"events": [], "count": 0}

    try:
        from classroom_engine.repository import SchoolyearEventRepository

        event_repo = SchoolyearEventRepository(db)
        found = event_repo.get_upcoming(teacher_id, days=days, limit=limit)
        serialized = [event_repo.to_dict(item) for item in found]
        return {"events": serialized, "count": len(found)}
    except Exception as e:
        logger.error(f"Failed to get upcoming events: {e}")
        raise HTTPException(status_code=500, detail=f"Fehler: {e}")
@router.post("/v1/events", response_model=EventResponse)
async def create_event(
    teacher_id: str,
    request: CreateEventRequest,
    db=Depends(get_db)
):
    """
    Create a new school year event.

    Raises:
        HTTPException: 503 if the database is unavailable, 400 if
            ``start_date`` or ``end_date`` is not a valid ISO timestamp,
            500 for any other failure.
    """
    if not DB_ENABLED or not db:
        raise HTTPException(status_code=503, detail="Datenbank nicht verfuegbar")

    try:
        from classroom_engine.repository import SchoolyearEventRepository
        repo = SchoolyearEventRepository(db)

        # Parse the ISO timestamps; a trailing "Z" is mapped to an explicit
        # UTC offset. Malformed input is a client error (400), not a 500.
        try:
            start_date = datetime.fromisoformat(request.start_date.replace('Z', '+00:00'))
            end_date = None
            if request.end_date:
                end_date = datetime.fromisoformat(request.end_date.replace('Z', '+00:00'))
        except ValueError as exc:
            raise HTTPException(status_code=400, detail=f"Ungueltiges Datum: {exc}")

        event = repo.create(
            teacher_id=teacher_id,
            title=request.title,
            event_type=request.event_type,
            start_date=start_date,
            end_date=end_date,
            class_id=request.class_id,
            subject=request.subject,
            description=request.description,
            needs_preparation=request.needs_preparation,
            reminder_days_before=request.reminder_days_before,
        )

        return EventResponse(
            id=event.id,
            teacher_id=event.teacher_id,
            event_type=event.event_type.value,
            title=event.title,
            description=event.description,
            start_date=event.start_date.isoformat(),
            end_date=event.end_date.isoformat() if event.end_date else None,
            class_id=event.class_id,
            subject=event.subject,
            status=event.status.value,
            needs_preparation=event.needs_preparation,
            preparation_done=event.preparation_done,
            reminder_days_before=event.reminder_days_before,
        )
    except HTTPException:
        # Keep deliberate HTTP errors (e.g. the 400 above) from being
        # rewrapped as a 500 by the generic handler below.
        raise
    except Exception as e:
        logger.error(f"Failed to create event: {e}")
        raise HTTPException(status_code=500, detail=f"Fehler: {e}")
@router.delete("/v1/events/{event_id}")
async def delete_event(event_id: str, db=Depends(get_db)):
    """Delete an event; responds with 404 when the repository reports no deletion."""
    if not (DB_ENABLED and db):
        raise HTTPException(status_code=503, detail="Datenbank nicht verfuegbar")

    try:
        from classroom_engine.repository import SchoolyearEventRepository

        removed = SchoolyearEventRepository(db).delete(event_id)
        if not removed:
            raise HTTPException(status_code=404, detail="Event nicht gefunden")
        return {"success": True, "deleted_id": event_id}
    except HTTPException:
        # The 404 above must reach the client unchanged.
        raise
    except Exception as e:
        logger.error(f"Failed to delete event: {e}")
        raise HTTPException(status_code=500, detail=f"Fehler: {e}")
# === Routines Endpoints ===
@router.get("/v1/routines")
async def get_routines(
    teacher_id: str = Query(...),
    is_active: bool = True,
    routine_type: Optional[str] = None,
    db=Depends(get_db)
):
    """Return the routines of a teacher, filtered by active flag and type."""
    if not (DB_ENABLED and db):
        return {"routines": [], "count": 0}

    try:
        from classroom_engine.repository import RecurringRoutineRepository

        routine_repo = RecurringRoutineRepository(db)
        found = routine_repo.get_by_teacher(
            teacher_id, is_active=is_active, routine_type=routine_type
        )
        serialized = [routine_repo.to_dict(item) for item in found]
        return {"routines": serialized, "count": len(found)}
    except Exception as e:
        logger.error(f"Failed to get routines: {e}")
        raise HTTPException(status_code=500, detail=f"Fehler: {e}")
@router.get("/v1/routines/today")
async def get_today_routines(teacher_id: str = Query(...), db=Depends(get_db)):
    """Return the routines of a teacher that take place today."""
    if not (DB_ENABLED and db):
        return {"routines": [], "count": 0}

    try:
        from classroom_engine.repository import RecurringRoutineRepository

        routine_repo = RecurringRoutineRepository(db)
        found = routine_repo.get_today(teacher_id)
        serialized = [routine_repo.to_dict(item) for item in found]
        return {"routines": serialized, "count": len(found)}
    except Exception as e:
        logger.error(f"Failed to get today's routines: {e}")
        raise HTTPException(status_code=500, detail=f"Fehler: {e}")
@router.post("/v1/routines", response_model=RoutineResponse)
async def create_routine(
    teacher_id: str,
    request: CreateRoutineRequest,
    db=Depends(get_db)
):
    """Create a new recurring routine and return its stored representation."""
    if not (DB_ENABLED and db):
        raise HTTPException(status_code=503, detail="Datenbank nicht verfuegbar")

    try:
        from classroom_engine.repository import RecurringRoutineRepository

        created = RecurringRoutineRepository(db).create(
            teacher_id=teacher_id,
            title=request.title,
            routine_type=request.routine_type,
            recurrence_pattern=request.recurrence_pattern,
            day_of_week=request.day_of_week,
            day_of_month=request.day_of_month,
            time_of_day=request.time_of_day,
            duration_minutes=request.duration_minutes,
            description=request.description,
        )

        # The time is optional; serialize it only when one is stored.
        if created.time_of_day:
            time_of_day = created.time_of_day.isoformat()
        else:
            time_of_day = None

        return RoutineResponse(
            id=created.id,
            teacher_id=created.teacher_id,
            routine_type=created.routine_type.value,
            title=created.title,
            description=created.description,
            recurrence_pattern=created.recurrence_pattern.value,
            day_of_week=created.day_of_week,
            day_of_month=created.day_of_month,
            time_of_day=time_of_day,
            duration_minutes=created.duration_minutes,
            is_active=created.is_active,
        )
    except Exception as e:
        logger.error(f"Failed to create routine: {e}")
        raise HTTPException(status_code=500, detail=f"Fehler: {e}")
@router.delete("/v1/routines/{routine_id}")
async def delete_routine(routine_id: str, db=Depends(get_db)):
    """Delete a routine; responds with 404 when the repository reports no deletion."""
    if not (DB_ENABLED and db):
        raise HTTPException(status_code=503, detail="Datenbank nicht verfuegbar")

    try:
        from classroom_engine.repository import RecurringRoutineRepository

        removed = RecurringRoutineRepository(db).delete(routine_id)
        if not removed:
            raise HTTPException(status_code=404, detail="Routine nicht gefunden")
        return {"success": True, "deleted_id": routine_id}
    except HTTPException:
        # The 404 above must reach the client unchanged.
        raise
    except Exception as e:
        logger.error(f"Failed to delete routine: {e}")
        raise HTTPException(status_code=500, detail=f"Fehler: {e}")
@@ -0,0 +1,281 @@
"""
Classroom API - Static Data, Suggestions & Sidebar Routes
Federal states, school types, macro phases, event/routine types,
suggestions, sidebar model, and school year path.
"""
from typing import Optional
import logging
from fastapi import APIRouter, HTTPException, Query, Depends
from classroom_engine import FEDERAL_STATES, SCHOOL_TYPES
from ..services.persistence import (
DB_ENABLED,
SessionLocal,
)
logger = logging.getLogger(__name__)
router = APIRouter(tags=["Context"])
def get_db():
    """Yield a database session for one request, or None without persistence.

    The session is closed once the consumer is done with it.
    """
    if not (DB_ENABLED and SessionLocal):
        yield None
        return

    session = SessionLocal()
    try:
        yield session
    finally:
        session.close()
# === Static Data Endpoints ===
@router.get("/v1/federal-states")
async def get_federal_states():
    """Return all federal states as ``{"id", "name"}`` entries."""
    states = []
    for code, display_name in FEDERAL_STATES.items():
        states.append({"id": code, "name": display_name})
    return {"federal_states": states}
@router.get("/v1/school-types")
async def get_school_types():
    """Return all school types as ``{"id", "name"}`` entries."""
    school_types = []
    for code, display_name in SCHOOL_TYPES.items():
        school_types.append({"id": code, "name": display_name})
    return {"school_types": school_types}
@router.get("/v1/macro-phases")
async def get_macro_phases():
    """Return all macro phases with label, description and 1-based order."""
    # (id, label, description) in chronological order; "order" is derived
    # from the position in this tuple.
    phase_rows = (
        ("onboarding", "Einrichtung", "Ersteinrichtung (Klassen, Stundenplan)"),
        ("schuljahresstart", "Schuljahresstart", "Erste 2-3 Wochen des Schuljahres"),
        ("unterrichtsaufbau", "Unterrichtsaufbau", "Routinen etablieren, erste Bewertungen"),
        ("leistungsphase_1", "Leistungsphase 1", "Erste Klassenarbeiten und Klausuren"),
        ("halbjahresabschluss", "Halbjahresabschluss", "Notenschluss, Zeugnisse, Konferenzen"),
        ("leistungsphase_2", "Leistungsphase 2", "Zweites Halbjahr, Pruefungsvorbereitung"),
        ("jahresabschluss", "Jahresabschluss", "Finale Noten, Versetzung, Schuljahresende"),
    )
    return {
        "macro_phases": [
            {"id": phase_id, "label": label, "description": description, "order": position}
            for position, (phase_id, label, description) in enumerate(phase_rows, start=1)
        ]
    }
@router.get("/v1/event-types")
async def get_event_types():
    """Return all event types as ``{"id", "label"}`` entries."""
    type_rows = (
        ("exam", "Klassenarbeit/Klausur"),
        ("parent_evening", "Elternabend"),
        ("trip", "Klassenfahrt/Ausflug"),
        ("project", "Projektwoche"),
        ("internship", "Praktikum"),
        ("presentation", "Referate/Praesentationen"),
        ("sports_day", "Sporttag"),
        ("school_festival", "Schulfest"),
        ("parent_consultation", "Elternsprechtag"),
        ("grade_deadline", "Notenschluss"),
        ("report_cards", "Zeugnisausgabe"),
        ("holiday_start", "Ferienbeginn"),
        ("holiday_end", "Ferienende"),
        ("other", "Sonstiges"),
    )
    return {
        "event_types": [{"id": type_id, "label": label} for type_id, label in type_rows]
    }
@router.get("/v1/routine-types")
async def get_routine_types():
    """Return all routine types as ``{"id", "label"}`` entries."""
    type_rows = (
        ("teacher_conference", "Lehrerkonferenz"),
        ("subject_conference", "Fachkonferenz"),
        ("office_hours", "Sprechstunde"),
        ("team_meeting", "Teamsitzung"),
        ("supervision", "Pausenaufsicht"),
        ("correction_time", "Korrekturzeit"),
        ("prep_time", "Vorbereitungszeit"),
        ("other", "Sonstiges"),
    )
    return {
        "routine_types": [{"id": type_id, "label": label} for type_id, label in type_rows]
    }
# === Suggestions & Sidebar ===
@router.get("/v1/suggestions")
async def get_suggestions(
    teacher_id: str = Query(...),
    limit: int = Query(5, ge=1, le=20),
    db=Depends(get_db)
):
    """Generate context-based suggestions for a teacher.

    Without a database an empty suggestion set with default signals is returned.
    """
    if not (DB_ENABLED and db):
        return {
            "active_contexts": [],
            "suggestions": [],
            "signals_summary": {
                "macro_phase": "onboarding",
                "current_week": 1,
                "has_classes": False,
                "exams_soon": 0,
                "routines_today": 0,
            },
            "total_suggestions": 0,
        }

    try:
        from classroom_engine.suggestions import SuggestionGenerator

        return SuggestionGenerator(db).generate(teacher_id, limit=limit)
    except Exception as e:
        logger.error(f"Failed to generate suggestions: {e}")
        raise HTTPException(status_code=500, detail=f"Fehler: {e}")
@router.get("/v1/sidebar")
async def get_sidebar(
    teacher_id: str = Query(...),
    mode: str = Query("companion"),
    db=Depends(get_db)
):
    """Build the dynamic sidebar model.

    ``mode == "companion"`` yields the suggestion-driven layout; any other
    value yields the classic navigation tree.
    """
    if mode != "companion":
        # (id, label, icon, url) of the classic navigation entries.
        nav_rows = (
            ("dashboard", "Dashboard", "dashboard", "/dashboard"),
            ("lesson", "Stundenmodus", "timer", "/lesson"),
            ("classes", "Klassen", "groups", "/classes"),
            ("exams", "Klausuren", "quiz", "/exams"),
            ("grades", "Noten", "calculate", "/grades"),
            ("calendar", "Kalender", "calendar_month", "/calendar"),
            ("materials", "Materialien", "folder_open", "/materials"),
            ("settings", "Einstellungen", "settings", "/settings"),
        )
        return {
            "mode": "classic",
            "sections": [
                {
                    "id": "NAVIGATION",
                    "type": "tree",
                    "items": [
                        {"id": item_id, "label": label, "icon": icon, "url": url}
                        for item_id, label, icon, url in nav_rows
                    ],
                },
            ],
        }

    now_relevant = []
    if DB_ENABLED and db:
        try:
            from classroom_engine.suggestions import SuggestionGenerator

            generated = SuggestionGenerator(db).generate(teacher_id, limit=5)
            # Assigned only after the whole list was built, so a failure
            # midway leaves the list empty.
            now_relevant = [
                {
                    "id": suggestion["id"],
                    "label": suggestion["title"],
                    "state": "recommended" if suggestion["priority"] > 70 else "default",
                    "badge": suggestion.get("badge"),
                    "icon": suggestion.get("icon", "lightbulb"),
                    "action_url": suggestion.get("action_url"),
                }
                for suggestion in generated.get("suggestions", [])
            ]
        except Exception as e:
            # Best effort: the sidebar is still served without suggestions.
            logger.warning(f"Failed to get suggestions for sidebar: {e}")

    placeholder_items = [
        {"id": "no_suggestions", "label": "Keine Vorschlaege", "state": "default", "icon": "check_circle"}
    ]
    search_section = {"id": "SEARCH", "type": "search_bar", "placeholder": "Suchen..."}
    relevant_section = {
        "id": "NOW_RELEVANT",
        "type": "list",
        "title": "Jetzt relevant",
        "items": now_relevant or placeholder_items,
    }
    modules_section = {
        "id": "ALL_MODULES",
        "type": "folder",
        "label": "Alle Module",
        "icon": "folder",
        "collapsed": True,
        "items": [
            {"id": "lesson", "label": "Stundenmodus", "icon": "timer"},
            {"id": "classes", "label": "Klassen", "icon": "groups"},
            {"id": "exams", "label": "Klausuren", "icon": "quiz"},
            {"id": "grades", "label": "Noten", "icon": "calculate"},
            {"id": "calendar", "label": "Kalender", "icon": "calendar_month"},
            {"id": "materials", "label": "Materialien", "icon": "folder_open"},
        ],
    }
    actions_section = {
        "id": "QUICK_ACTIONS",
        "type": "actions",
        "title": "Kurzaktionen",
        "items": [
            {"id": "scan", "label": "Scan hochladen", "icon": "upload_file"},
            {"id": "note", "label": "Notiz erstellen", "icon": "note_add"},
        ],
    }
    return {
        "mode": "companion",
        "sections": [search_section, relevant_section, modules_section, actions_section],
    }
@router.get("/v1/path")
async def get_schoolyear_path(teacher_id: str = Query(...), db=Depends(get_db)):
    """Build the school year path: ordered milestones with status and progress."""
    current_phase = "onboarding"
    if DB_ENABLED and db:
        try:
            from classroom_engine.repository import TeacherContextRepository

            teacher_context = TeacherContextRepository(db).get_or_create(teacher_id)
            current_phase = teacher_context.macro_phase.value
        except Exception as e:
            # Best effort: fall back to the onboarding phase.
            logger.warning(f"Failed to get context for path: {e}")

    phase_order = [
        "onboarding", "schuljahresstart", "unterrichtsaufbau",
        "leistungsphase_1", "halbjahresabschluss", "leistungsphase_2", "jahresabschluss",
    ]
    try:
        current_index = phase_order.index(current_phase)
    except ValueError:
        # Unknown phases are treated like the very first one.
        current_index = 0

    # (id, label, phase, icon) per milestone, in chronological order.
    milestone_rows = (
        ("MS_START", "Start", "onboarding", "flag"),
        ("MS_SETUP", "Einrichtung", "schuljahresstart", "tune"),
        ("MS_ROUTINE", "Routinen", "unterrichtsaufbau", "repeat"),
        ("MS_EXAM_1", "Klausuren", "leistungsphase_1", "quiz"),
        ("MS_HALFYEAR", "Halbjahr", "halbjahresabschluss", "event"),
        ("MS_EXAM_2", "Pruefungen", "leistungsphase_2", "school"),
        ("MS_END", "Abschluss", "jahresabschluss", "celebration"),
    )

    milestones = []
    for milestone_id, label, phase, icon in milestone_rows:
        try:
            phase_index = phase_order.index(phase)
        except ValueError:
            phase_index = 999
        if phase_index > current_index:
            status = "upcoming"
        elif phase_index == current_index:
            status = "current"
        else:
            status = "done"
        milestones.append(
            {"id": milestone_id, "label": label, "phase": phase, "icon": icon, "status": status}
        )

    current_milestone_id = milestones[0]["id"]
    for entry in milestones:
        if entry["status"] == "current":
            current_milestone_id = entry["id"]
            break

    last_index = len(phase_order) - 1
    progress = int((current_index / last_index) * 100) if last_index > 0 else 0

    return {
        "milestones": milestones,
        "current_milestone_id": current_milestone_id,
        "progress_percent": progress,
        "current_phase": current_phase,
    }
@@ -0,0 +1,386 @@
"""
EduSearch Seeds CRUD Routes.
List, get, create, update, delete, and bulk import for seed URLs.
"""
import os
import logging
from typing import Optional, List
from datetime import datetime
from fastapi import APIRouter, HTTPException, Query
import asyncpg
from .edu_search_models import (
CategoryResponse,
SeedCreate,
SeedUpdate,
SeedResponse,
SeedsListResponse,
BulkImportRequest,
BulkImportResponse,
)
logger = logging.getLogger(__name__)
router = APIRouter(tags=["edu-search"])
# Database connection pool
_pool: Optional[asyncpg.Pool] = None
async def get_db_pool() -> asyncpg.Pool:
    """Return the module-wide connection pool, creating it on first use.

    Raises:
        RuntimeError: if the pool has to be created and ``DATABASE_URL`` is unset.
    """
    global _pool
    if _pool is not None:
        return _pool

    dsn = os.environ.get("DATABASE_URL")
    if not dsn:
        raise RuntimeError("DATABASE_URL nicht konfiguriert - bitte via Vault oder Umgebungsvariable setzen")
    _pool = await asyncpg.create_pool(dsn, min_size=2, max_size=10)
    return _pool
@router.get("/categories", response_model=List[CategoryResponse])
async def list_categories():
    """List all active seed categories ordered by ``sort_order``."""
    pool = await get_db_pool()
    async with pool.acquire() as conn:
        records = await conn.fetch("""
            SELECT id, name, display_name, description, icon, sort_order, is_active
            FROM edu_search_categories
            WHERE is_active = TRUE
            ORDER BY sort_order
        """)

    categories = []
    for record in records:
        categories.append(
            CategoryResponse(
                id=str(record["id"]),
                name=record["name"],
                display_name=record["display_name"],
                description=record["description"],
                icon=record["icon"],
                sort_order=record["sort_order"],
                is_active=record["is_active"],
            )
        )
    return categories
@router.get("/seeds", response_model=SeedsListResponse)
async def list_seeds(
    category: Optional[str] = Query(None, description="Filter by category name"),
    state: Optional[str] = Query(None, description="Filter by state code"),
    enabled: Optional[bool] = Query(None, description="Filter by enabled status"),
    search: Optional[str] = Query(None, description="Search in name/url"),
    page: int = Query(1, ge=1),
    page_size: int = Query(50, ge=1, le=200),
):
    """List seeds with optional filtering and pagination.

    All filter values are passed as positional query arguments; only the
    placeholders (``$1``, ``$2``, ...) are interpolated into the SQL text.
    """
    pool = await get_db_pool()

    async with pool.acquire() as conn:
        params = []

        def bind(value):
            """Register a query argument and return its ``$n`` placeholder."""
            params.append(value)
            return f"${len(params)}"

        # Assemble the WHERE clause from the filters that were supplied.
        conditions = []
        if category:
            conditions.append(f"c.name = {bind(category)}")
        if state:
            conditions.append(f"s.state = {bind(state)}")
        if enabled is not None:
            conditions.append(f"s.enabled = {bind(enabled)}")
        if search:
            pattern = bind(f"%{search}%")
            conditions.append(f"(s.name ILIKE {pattern} OR s.url ILIKE {pattern})")
        where_clause = " AND ".join(conditions) or "TRUE"

        # Total number of matches, independent of the requested page.
        count_query = f"""
            SELECT COUNT(*) FROM edu_search_seeds s
            LEFT JOIN edu_search_categories c ON s.category_id = c.id
            WHERE {where_clause}
        """
        total = await conn.fetchval(count_query, *params)

        # The page window is bound after counting so that the count query
        # only receives the filter arguments.
        limit_placeholder = bind(page_size)
        offset_placeholder = bind((page - 1) * page_size)
        query = f"""
            SELECT
                s.id, s.url, s.name, s.description,
                c.name as category, c.display_name as category_display_name,
                s.source_type, s.scope, s.state, s.trust_boost, s.enabled,
                s.crawl_depth, s.crawl_frequency, s.last_crawled_at,
                s.last_crawl_status, s.last_crawl_docs, s.total_documents,
                s.created_at, s.updated_at
            FROM edu_search_seeds s
            LEFT JOIN edu_search_categories c ON s.category_id = c.id
            WHERE {where_clause}
            ORDER BY c.sort_order, s.name
            LIMIT {limit_placeholder} OFFSET {offset_placeholder}
        """
        records = await conn.fetch(query, *params)
        seeds = [_row_to_seed_response(record) for record in records]

    return SeedsListResponse(seeds=seeds, total=total, page=page, page_size=page_size)
@router.get("/seeds/{seed_id}", response_model=SeedResponse)
async def get_seed(seed_id: str):
    """Return a single seed by ID, or respond with 404 if no row matches."""
    pool = await get_db_pool()
    async with pool.acquire() as conn:
        record = await conn.fetchrow("""
            SELECT
                s.id, s.url, s.name, s.description,
                c.name as category, c.display_name as category_display_name,
                s.source_type, s.scope, s.state, s.trust_boost, s.enabled,
                s.crawl_depth, s.crawl_frequency, s.last_crawled_at,
                s.last_crawl_status, s.last_crawl_docs, s.total_documents,
                s.created_at, s.updated_at
            FROM edu_search_seeds s
            LEFT JOIN edu_search_categories c ON s.category_id = c.id
            WHERE s.id = $1
        """, seed_id)

    if record:
        return _row_to_seed_response(record)
    raise HTTPException(status_code=404, detail="Seed nicht gefunden")
@router.post("/seeds", response_model=SeedResponse, status_code=201)
async def create_seed(seed: SeedCreate):
    """Create a new seed URL.

    Responds with 409 when the insert violates a unique constraint.
    """
    pool = await get_db_pool()

    async with pool.acquire() as conn:
        # Translate the optional category name into its ID (None if no
        # category was given or no row matches).
        category_id = None
        if seed.category_name:
            category_id = await conn.fetchval(
                "SELECT id FROM edu_search_categories WHERE name = $1",
                seed.category_name
            )

        insert_args = (
            seed.url, seed.name, seed.description, category_id,
            seed.source_type, seed.scope, seed.state, seed.trust_boost,
            seed.enabled, seed.crawl_depth, seed.crawl_frequency,
        )
        try:
            inserted = await conn.fetchrow("""
                INSERT INTO edu_search_seeds (
                    url, name, description, category_id, source_type, scope,
                    state, trust_boost, enabled, crawl_depth, crawl_frequency
                ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
                RETURNING id, created_at, updated_at
            """, *insert_args)
        except asyncpg.UniqueViolationError:
            raise HTTPException(status_code=409, detail="URL existiert bereits")

    # Echo the submitted values; only id and timestamps come from the database.
    return SeedResponse(
        id=str(inserted["id"]),
        url=seed.url,
        name=seed.name,
        description=seed.description,
        category=seed.category_name,
        category_display_name=None,
        source_type=seed.source_type,
        scope=seed.scope,
        state=seed.state,
        trust_boost=seed.trust_boost,
        enabled=seed.enabled,
        crawl_depth=seed.crawl_depth,
        crawl_frequency=seed.crawl_frequency,
        last_crawled_at=None,
        last_crawl_status=None,
        last_crawl_docs=0,
        total_documents=0,
        created_at=inserted["created_at"],
        updated_at=inserted["updated_at"],
    )
@router.put("/seeds/{seed_id}", response_model=SeedResponse)
async def update_seed(seed_id: str, seed: SeedUpdate):
    """Update an existing seed with the fields that are set on the request.

    Only fields that are not ``None`` are written; ``updated_at`` is always
    refreshed.

    Raises:
        HTTPException: 400 if no field is set, 404 if no seed has this ID,
            409 if the update violates a unique constraint (e.g. the new URL
            already exists) - the same mapping ``create_seed`` uses.
    """
    pool = await get_db_pool()

    async with pool.acquire() as conn:
        updates = []
        params = []

        def add_update(column, value):
            """Queue ``column = $n`` and bind *value* as the n-th argument."""
            params.append(value)
            updates.append(f"{column} = ${len(params)}")

        for column in ("url", "name", "description"):
            value = getattr(seed, column)
            if value is not None:
                add_update(column, value)

        if seed.category_name is not None:
            # The table stores the category by ID, so resolve the name first.
            category_id = await conn.fetchval(
                "SELECT id FROM edu_search_categories WHERE name = $1",
                seed.category_name
            )
            add_update("category_id", category_id)

        for column in (
            "source_type", "scope", "state", "trust_boost",
            "enabled", "crawl_depth", "crawl_frequency",
        ):
            value = getattr(seed, column)
            if value is not None:
                add_update(column, value)

        if not updates:
            raise HTTPException(status_code=400, detail="Keine Felder zum Aktualisieren")

        updates.append("updated_at = NOW()")
        params.append(seed_id)

        query = f"""
            UPDATE edu_search_seeds
            SET {", ".join(updates)}
            WHERE id = ${len(params)}
            RETURNING id
        """

        try:
            result = await conn.fetchrow(query, *params)
        except asyncpg.UniqueViolationError:
            # Without this, a duplicate URL would surface as an unhandled 500.
            raise HTTPException(status_code=409, detail="URL existiert bereits")

        if not result:
            raise HTTPException(status_code=404, detail="Seed nicht gefunden")

    # Return the updated seed; fetched after the connection above was
    # released so that only one pooled connection is held at a time.
    return await get_seed(seed_id)
@router.delete("/seeds/{seed_id}")
async def delete_seed(seed_id: str):
    """Delete a seed; responds with 404 when no row was removed."""
    pool = await get_db_pool()
    async with pool.acquire() as conn:
        status = await conn.execute(
            "DELETE FROM edu_search_seeds WHERE id = $1",
            seed_id
        )

    # The status string of the DELETE tells whether a row was affected.
    if status != "DELETE 0":
        return {"status": "deleted", "id": seed_id}
    raise HTTPException(status_code=404, detail="Seed nicht gefunden")
@router.post("/seeds/bulk-import", response_model=BulkImportResponse)
async def bulk_import_seeds(request: BulkImportRequest):
    """Bulk import seeds, skipping URLs that already exist.

    Each seed is inserted on its own; one failing seed does not abort the
    import. The response reports how many seeds were inserted, how many were
    skipped as duplicates, and one ``"<url>: <error>"`` entry per other failure.
    """
    pool = await get_db_pool()
    imported = 0
    skipped = 0
    errors = []

    async with pool.acquire() as conn:
        # Pre-fetch all category IDs so the loop needs no per-seed lookup.
        rows = await conn.fetch("SELECT id, name FROM edu_search_categories")
        categories = {row["name"]: row["id"] for row in rows}

        for seed in request.seeds:
            try:
                category_id = categories.get(seed.category_name) if seed.category_name else None
                status = await conn.execute("""
                    INSERT INTO edu_search_seeds (
                        url, name, description, category_id, source_type, scope,
                        state, trust_boost, enabled, crawl_depth, crawl_frequency
                    ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
                    ON CONFLICT (url) DO NOTHING
                """,
                    seed.url, seed.name, seed.description, category_id,
                    seed.source_type, seed.scope, seed.state, seed.trust_boost,
                    seed.enabled, seed.crawl_depth, seed.crawl_frequency
                )
            except asyncpg.UniqueViolationError:
                # A unique constraint other than the one on url was violated.
                skipped += 1
            except Exception as e:
                errors.append(f"{seed.url}: {str(e)}")
            else:
                # ON CONFLICT ... DO NOTHING raises nothing for a duplicate
                # URL; the command status ("INSERT 0 <rows>") is the only
                # signal whether a row was actually written.
                if status == "INSERT 0 1":
                    imported += 1
                else:
                    skipped += 1

    return BulkImportResponse(imported=imported, skipped=skipped, errors=errors)
def _row_to_seed_response(row) -> SeedResponse:
    """Build a SeedResponse from a seed row joined with its category.

    The ID is stringified, ``trust_boost`` is coerced to float, and the two
    document counters fall back to 0 when the stored value is falsy.
    """
    copied_columns = (
        "url", "name", "description", "category", "category_display_name",
        "source_type", "scope", "state", "enabled", "crawl_depth",
        "crawl_frequency", "last_crawled_at", "last_crawl_status",
        "created_at", "updated_at",
    )
    values = {column: row[column] for column in copied_columns}
    values["id"] = str(row["id"])
    values["trust_boost"] = float(row["trust_boost"])
    values["last_crawl_docs"] = row["last_crawl_docs"] or 0
    values["total_documents"] = row["total_documents"] or 0
    return SeedResponse(**values)
@@ -0,0 +1,137 @@
"""
EduSearch Seeds Pydantic Models.
Request/Response models for the education search seed URL API.
"""
from typing import Optional, List
from datetime import datetime
from pydantic import BaseModel, Field
class CategoryResponse(BaseModel):
    """Category response model."""
    # Identifier, exposed as a string.
    id: str
    name: str
    display_name: str
    # Optional presentation details.
    description: Optional[str] = None
    icon: Optional[str] = None
    # Required ordering key and active flag.
    sort_order: int
    is_active: bool
class SeedBase(BaseModel):
    """Base seed model for creation/update."""
    # Required; at most 500 characters.
    url: str = Field(..., max_length=500)
    # Required; at most 255 characters.
    name: str = Field(..., max_length=255)
    description: Optional[str] = None
    category_name: Optional[str] = Field(None, description="Category name (federal, states, etc.)")
    source_type: str = Field("GOV", description="GOV, EDU, UNI, etc.")
    scope: str = Field("FEDERAL", description="FEDERAL, STATE, etc.")
    # Optional; at most 5 characters.
    state: Optional[str] = Field(None, max_length=5, description="State code (BW, BY, etc.)")
    # Defaults to 0.50 and must lie within [0.0, 1.0].
    trust_boost: float = Field(0.50, ge=0.0, le=1.0)
    enabled: bool = True
    # Defaults to 2 and must lie within [1, 5].
    crawl_depth: int = Field(2, ge=1, le=5)
    crawl_frequency: str = Field("weekly", description="hourly, daily, weekly, monthly")
class SeedCreate(SeedBase):
    """Seed creation model."""
    # Adds no fields of its own; everything is inherited from SeedBase.
    pass
class SeedUpdate(BaseModel):
    """Seed update model (all fields optional)."""
    # Every field defaults to None; length and range limits match SeedBase.
    url: Optional[str] = Field(None, max_length=500)
    name: Optional[str] = Field(None, max_length=255)
    description: Optional[str] = None
    category_name: Optional[str] = None
    source_type: Optional[str] = None
    scope: Optional[str] = None
    state: Optional[str] = Field(None, max_length=5)
    trust_boost: Optional[float] = Field(None, ge=0.0, le=1.0)
    enabled: Optional[bool] = None
    crawl_depth: Optional[int] = Field(None, ge=1, le=5)
    crawl_frequency: Optional[str] = None
class SeedResponse(BaseModel):
    """Seed response model."""
    # Identifier, exposed as a string.
    id: str
    url: str
    name: str
    description: Optional[str] = None
    # Category name and display name; both may be absent.
    category: Optional[str] = None
    category_display_name: Optional[str] = None
    source_type: str
    scope: str
    state: Optional[str] = None
    trust_boost: float
    enabled: bool
    crawl_depth: int
    crawl_frequency: str
    # Crawl bookkeeping; the two counters default to 0.
    last_crawled_at: Optional[datetime] = None
    last_crawl_status: Optional[str] = None
    last_crawl_docs: int = 0
    total_documents: int = 0
    created_at: datetime
    updated_at: datetime
class SeedsListResponse(BaseModel):
    """List response with pagination info."""
    # Seeds contained in the requested page.
    seeds: List[SeedResponse]
    # Count of all matching seeds, plus the page coordinates of this response.
    total: int
    page: int
    page_size: int
class StatsResponse(BaseModel):
    """Crawl statistics response."""
    total_seeds: int
    enabled_seeds: int
    total_documents: int
    # Untyped dicts holding per-category and per-state seed counts.
    seeds_by_category: dict
    seeds_by_state: dict
    # May be absent.
    last_crawl_time: Optional[datetime] = None
class BulkImportRequest(BaseModel):
    """Bulk import request."""
    # Seeds to import; each entry is validated as a SeedCreate.
    seeds: List[SeedCreate]
class BulkImportResponse(BaseModel):
    """Bulk import response."""
    # Counters for inserted and skipped seeds.
    imported: int
    skipped: int
    # One message per seed that failed.
    errors: List[str]
class CrawlStatusUpdate(BaseModel):
    """Crawl status update from edu-search-service."""
    # Both seed_url and status are required.
    seed_url: str = Field(..., description="The seed URL that was crawled")
    status: str = Field(..., description="Crawl status: success, error, partial")
    # Non-negative; defaults to 0.
    documents_crawled: int = Field(0, ge=0, description="Number of documents crawled")
    error_message: Optional[str] = Field(None, description="Error message if status is error")
    # Non-negative; defaults to 0.0.
    crawl_duration_seconds: float = Field(0.0, ge=0.0, description="Duration of the crawl in seconds")
class CrawlStatusResponse(BaseModel):
    """Response for crawl status update."""
    success: bool
    # URL of the seed the update was applied to.
    seed_url: str
    # Human-readable result message.
    message: str
class BulkCrawlStatusUpdate(BaseModel):
    """Bulk crawl status update."""
    # Individual status updates, each validated as a CrawlStatusUpdate.
    updates: List[CrawlStatusUpdate]
class BulkCrawlStatusResponse(BaseModel):
    """Response for bulk crawl status update."""
    # Counters for applied and failed updates.
    updated: int
    failed: int
    # One message per failed update.
    errors: List[str]
@@ -1,710 +1,58 @@
"""
EduSearch Seeds API Routes.
EduSearch Seeds API Routes — Barrel Re-export.
Split into submodules:
- edu_search_models.py — Pydantic request/response models
- edu_search_crud.py — CRUD endpoints (list, get, create, update, delete, bulk import)
- edu_search_status.py — Stats, export for crawler, crawl status feedback
CRUD operations for managing education search crawler seed URLs.
Direct database access to PostgreSQL.
"""
import os
import logging
from typing import Optional, List
from datetime import datetime
from uuid import UUID
from fastapi import APIRouter
from fastapi import APIRouter, HTTPException, Depends, Query
from pydantic import BaseModel, Field, HttpUrl
import asyncpg
from .edu_search_crud import router as _crud_router, get_db_pool
from .edu_search_status import router as _status_router
logger = logging.getLogger(__name__)
# Re-export models for consumers that import types from this module
from .edu_search_models import (
CategoryResponse,
SeedBase,
SeedCreate,
SeedUpdate,
SeedResponse,
SeedsListResponse,
StatsResponse,
BulkImportRequest,
BulkImportResponse,
CrawlStatusUpdate,
CrawlStatusResponse,
BulkCrawlStatusUpdate,
BulkCrawlStatusResponse,
)
# Combine both sub-routers into a single router for backwards compatibility.
# The consumer imports `from .edu_search_seeds import router as edu_search_seeds_router`.
router = APIRouter(prefix="/edu-search", tags=["edu-search"])
# Database connection pool
_pool: Optional[asyncpg.Pool] = None
async def get_db_pool() -> asyncpg.Pool:
    """Return the module-wide asyncpg pool, creating it on first use.

    Raises:
        RuntimeError: when ``DATABASE_URL`` is unset or empty.
    """
    global _pool
    if _pool is not None:
        return _pool

    dsn = os.environ.get("DATABASE_URL")
    if not dsn:
        raise RuntimeError("DATABASE_URL nicht konfiguriert - bitte via Vault oder Umgebungsvariable setzen")
    _pool = await asyncpg.create_pool(dsn, min_size=2, max_size=10)
    return _pool
# =============================================================================
# Pydantic Models
# =============================================================================
class CategoryResponse(BaseModel):
    """Category response model."""
    # Identifier, exposed as a string.
    id: str
    name: str
    display_name: str
    # Optional presentation details.
    description: Optional[str] = None
    icon: Optional[str] = None
    # Required ordering key and active flag.
    sort_order: int
    is_active: bool
class SeedBase(BaseModel):
    """Base seed model for creation/update."""
    # Required; at most 500 characters.
    url: str = Field(..., max_length=500)
    # Required; at most 255 characters.
    name: str = Field(..., max_length=255)
    description: Optional[str] = None
    category_name: Optional[str] = Field(None, description="Category name (federal, states, etc.)")
    source_type: str = Field("GOV", description="GOV, EDU, UNI, etc.")
    scope: str = Field("FEDERAL", description="FEDERAL, STATE, etc.")
    # Optional; at most 5 characters.
    state: Optional[str] = Field(None, max_length=5, description="State code (BW, BY, etc.)")
    # Defaults to 0.50 and must lie within [0.0, 1.0].
    trust_boost: float = Field(0.50, ge=0.0, le=1.0)
    enabled: bool = True
    # Defaults to 2 and must lie within [1, 5].
    crawl_depth: int = Field(2, ge=1, le=5)
    crawl_frequency: str = Field("weekly", description="hourly, daily, weekly, monthly")
class SeedCreate(SeedBase):
    """Seed creation model."""
    # Adds no fields of its own; everything is inherited from SeedBase.
    pass
class SeedUpdate(BaseModel):
    """Seed update model (all fields optional)."""
    # Every field defaults to None; length and range limits match SeedBase.
    url: Optional[str] = Field(None, max_length=500)
    name: Optional[str] = Field(None, max_length=255)
    description: Optional[str] = None
    category_name: Optional[str] = None
    source_type: Optional[str] = None
    scope: Optional[str] = None
    state: Optional[str] = Field(None, max_length=5)
    trust_boost: Optional[float] = Field(None, ge=0.0, le=1.0)
    enabled: Optional[bool] = None
    crawl_depth: Optional[int] = Field(None, ge=1, le=5)
    crawl_frequency: Optional[str] = None
class SeedResponse(BaseModel):
    """Seed response model."""
    # Identifier, exposed as a string.
    id: str
    url: str
    name: str
    description: Optional[str] = None
    # Category name and display name; both may be absent.
    category: Optional[str] = None
    category_display_name: Optional[str] = None
    source_type: str
    scope: str
    state: Optional[str] = None
    trust_boost: float
    enabled: bool
    crawl_depth: int
    crawl_frequency: str
    # Crawl bookkeeping; the two counters default to 0.
    last_crawled_at: Optional[datetime] = None
    last_crawl_status: Optional[str] = None
    last_crawl_docs: int = 0
    total_documents: int = 0
    created_at: datetime
    updated_at: datetime
class SeedsListResponse(BaseModel):
    """List response with pagination info."""
    # Seeds contained in the requested page.
    seeds: List[SeedResponse]
    # Count of all matching seeds, plus the page coordinates of this response.
    total: int
    page: int
    page_size: int
class StatsResponse(BaseModel):
    """Crawl statistics response."""
    total_seeds: int
    enabled_seeds: int
    total_documents: int
    # Untyped dicts holding per-category and per-state seed counts.
    seeds_by_category: dict
    seeds_by_state: dict
    # May be absent.
    last_crawl_time: Optional[datetime] = None
class BulkImportRequest(BaseModel):
    """Bulk import request."""
    # Seeds to import; each entry is validated as a SeedCreate.
    seeds: List[SeedCreate]
class BulkImportResponse(BaseModel):
    """Bulk import response."""
    # Counters for inserted and skipped seeds.
    imported: int
    skipped: int
    # One message per seed that failed.
    errors: List[str]
# =============================================================================
# API Endpoints
# =============================================================================
@router.get("/categories", response_model=List[CategoryResponse])
async def list_categories():
    """Return every active seed category, ordered by ``sort_order``."""
    db_pool = await get_db_pool()
    async with db_pool.acquire() as connection:
        records = await connection.fetch("""
            SELECT id, name, display_name, description, icon, sort_order, is_active
            FROM edu_search_categories
            WHERE is_active = TRUE
            ORDER BY sort_order
        """)

    # All columns pass through unchanged except the ID, which is stringified.
    plain_columns = ("name", "display_name", "description", "icon", "sort_order", "is_active")
    categories = []
    for record in records:
        attributes = {column: record[column] for column in plain_columns}
        categories.append(CategoryResponse(id=str(record["id"]), **attributes))
    return categories
@router.get("/seeds", response_model=SeedsListResponse)
async def list_seeds(
    category: Optional[str] = Query(None, description="Filter by category name"),
    state: Optional[str] = Query(None, description="Filter by state code"),
    enabled: Optional[bool] = Query(None, description="Filter by enabled status"),
    search: Optional[str] = Query(None, description="Search in name/url"),
    page: int = Query(1, ge=1),
    page_size: int = Query(50, ge=1, le=200),
):
    """Return one page of seeds, optionally filtered.

    Filters are combined with AND. ``search`` is matched case-insensitively
    as a substring of the seed name or URL. ``total`` counts all rows that
    match the filters, independent of paging.
    """
    db_pool = await get_db_pool()
    bind_values = []

    def _placeholder(value) -> str:
        """Register a bind value and return its positional placeholder."""
        bind_values.append(value)
        return f"${len(bind_values)}"

    async with db_pool.acquire() as connection:
        # Assemble the WHERE clause from whichever filters were supplied.
        predicates = []
        if category:
            predicates.append(f"c.name = {_placeholder(category)}")
        if state:
            predicates.append(f"s.state = {_placeholder(state)}")
        if enabled is not None:
            predicates.append(f"s.enabled = {_placeholder(enabled)}")
        if search:
            # One bind value serves both ILIKE comparisons.
            pattern_ref = _placeholder(f"%{search}%")
            predicates.append(f"(s.name ILIKE {pattern_ref} OR s.url ILIKE {pattern_ref})")
        where_clause = " AND ".join(predicates) or "TRUE"

        # Total is computed before the paging parameters are bound.
        count_query = f"""
            SELECT COUNT(*) FROM edu_search_seeds s
            LEFT JOIN edu_search_categories c ON s.category_id = c.id
            WHERE {where_clause}
        """
        total = await connection.fetchval(count_query, *bind_values)

        limit_ref = _placeholder(page_size)
        offset_ref = _placeholder((page - 1) * page_size)
        query = f"""
            SELECT
                s.id, s.url, s.name, s.description,
                c.name as category, c.display_name as category_display_name,
                s.source_type, s.scope, s.state, s.trust_boost, s.enabled,
                s.crawl_depth, s.crawl_frequency, s.last_crawled_at,
                s.last_crawl_status, s.last_crawl_docs, s.total_documents,
                s.created_at, s.updated_at
            FROM edu_search_seeds s
            LEFT JOIN edu_search_categories c ON s.category_id = c.id
            WHERE {where_clause}
            ORDER BY c.sort_order, s.name
            LIMIT {limit_ref} OFFSET {offset_ref}
        """
        records = await connection.fetch(query, *bind_values)

    copied_columns = (
        "url", "name", "description", "category", "category_display_name",
        "source_type", "scope", "state", "enabled", "crawl_depth",
        "crawl_frequency", "last_crawled_at", "last_crawl_status",
        "created_at", "updated_at",
    )
    seeds = []
    for record in records:
        attributes = {column: record[column] for column in copied_columns}
        seeds.append(
            SeedResponse(
                id=str(record["id"]),
                trust_boost=float(record["trust_boost"]),
                last_crawl_docs=record["last_crawl_docs"] or 0,
                total_documents=record["total_documents"] or 0,
                **attributes,
            )
        )

    return SeedsListResponse(seeds=seeds, total=total, page=page, page_size=page_size)
@router.get("/seeds/{seed_id}", response_model=SeedResponse)
async def get_seed(seed_id: str):
    """Fetch one seed, joined with its category, by ID.

    Raises:
        HTTPException: 404 when no seed has the given ID.
    """
    db_pool = await get_db_pool()
    async with db_pool.acquire() as connection:
        record = await connection.fetchrow("""
            SELECT
                s.id, s.url, s.name, s.description,
                c.name as category, c.display_name as category_display_name,
                s.source_type, s.scope, s.state, s.trust_boost, s.enabled,
                s.crawl_depth, s.crawl_frequency, s.last_crawled_at,
                s.last_crawl_status, s.last_crawl_docs, s.total_documents,
                s.created_at, s.updated_at
            FROM edu_search_seeds s
            LEFT JOIN edu_search_categories c ON s.category_id = c.id
            WHERE s.id = $1
        """, seed_id)

    if not record:
        raise HTTPException(status_code=404, detail="Seed nicht gefunden")

    copied_columns = (
        "url", "name", "description", "category", "category_display_name",
        "source_type", "scope", "state", "enabled", "crawl_depth",
        "crawl_frequency", "last_crawled_at", "last_crawl_status",
        "created_at", "updated_at",
    )
    attributes = {column: record[column] for column in copied_columns}
    return SeedResponse(
        id=str(record["id"]),
        trust_boost=float(record["trust_boost"]),
        # Falsy counters are reported as 0.
        last_crawl_docs=record["last_crawl_docs"] or 0,
        total_documents=record["total_documents"] or 0,
        **attributes,
    )
@router.post("/seeds", response_model=SeedResponse, status_code=201)
async def create_seed(seed: SeedCreate):
    """Insert a new seed and return it.

    The response echoes the submitted values together with the generated ID
    and timestamps; ``category_display_name`` is always None and the crawl
    fields start out empty.

    Raises:
        HTTPException: 409 when the insert hits a unique violation.
    """
    db_pool = await get_db_pool()
    async with db_pool.acquire() as connection:
        # Resolve the category name to its ID only when one was given.
        category_id = (
            await connection.fetchval(
                "SELECT id FROM edu_search_categories WHERE name = $1",
                seed.category_name
            )
            if seed.category_name
            else None
        )

        insert_args = (
            seed.url, seed.name, seed.description, category_id,
            seed.source_type, seed.scope, seed.state, seed.trust_boost,
            seed.enabled, seed.crawl_depth, seed.crawl_frequency,
        )
        try:
            inserted = await connection.fetchrow("""
                INSERT INTO edu_search_seeds (
                    url, name, description, category_id, source_type, scope,
                    state, trust_boost, enabled, crawl_depth, crawl_frequency
                ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
                RETURNING id, created_at, updated_at
            """, *insert_args)
        except asyncpg.UniqueViolationError:
            raise HTTPException(status_code=409, detail="URL existiert bereits")

    echoed = {
        "url": seed.url,
        "name": seed.name,
        "description": seed.description,
        "category": seed.category_name,
        "source_type": seed.source_type,
        "scope": seed.scope,
        "state": seed.state,
        "trust_boost": seed.trust_boost,
        "enabled": seed.enabled,
        "crawl_depth": seed.crawl_depth,
        "crawl_frequency": seed.crawl_frequency,
    }
    return SeedResponse(
        id=str(inserted["id"]),
        category_display_name=None,
        last_crawled_at=None,
        last_crawl_status=None,
        last_crawl_docs=0,
        total_documents=0,
        created_at=inserted["created_at"],
        updated_at=inserted["updated_at"],
        **echoed,
    )
@router.put("/seeds/{seed_id}", response_model=SeedResponse)
async def update_seed(seed_id: str, seed: SeedUpdate):
    """Apply a partial update to a seed and return the refreshed record.

    Only fields whose value is not None are written. ``category_name`` is
    resolved to a category ID before being stored. ``updated_at`` is always
    refreshed.

    Raises:
        HTTPException: 400 when no field was supplied, 404 when the seed
            does not exist.
    """
    # Order matches the order in which the columns appear in the SET clause.
    updatable_fields = (
        "url", "name", "description", "category_name", "source_type", "scope",
        "state", "trust_boost", "enabled", "crawl_depth", "crawl_frequency",
    )

    db_pool = await get_db_pool()
    async with db_pool.acquire() as connection:
        assignments = []
        bind_values = []
        for field_name in updatable_fields:
            value = getattr(seed, field_name)
            if value is None:
                continue
            column = field_name
            if field_name == "category_name":
                # Stored as a foreign key, so look the ID up first.
                value = await connection.fetchval(
                    "SELECT id FROM edu_search_categories WHERE name = $1",
                    value
                )
                column = "category_id"
            bind_values.append(value)
            assignments.append(f"{column} = ${len(bind_values)}")

        if not assignments:
            raise HTTPException(status_code=400, detail="Keine Felder zum Aktualisieren")

        assignments.append("updated_at = NOW()")
        bind_values.append(seed_id)
        set_clause = ", ".join(assignments)
        query = f"""
            UPDATE edu_search_seeds
            SET {set_clause}
            WHERE id = ${len(bind_values)}
            RETURNING id
        """
        updated = await connection.fetchrow(query, *bind_values)
        if not updated:
            raise HTTPException(status_code=404, detail="Seed nicht gefunden")

    # Re-read the seed so the response includes the joined category columns.
    return await get_seed(seed_id)
@router.delete("/seeds/{seed_id}")
async def delete_seed(seed_id: str):
    """Remove the seed with the given ID.

    Raises:
        HTTPException: 404 when the DELETE statement matched no row.

    Returns:
        A confirmation dict echoing the deleted ID.
    """
    db_pool = await get_db_pool()
    async with db_pool.acquire() as connection:
        command_tag = await connection.execute(
            "DELETE FROM edu_search_seeds WHERE id = $1",
            seed_id,
        )

    # The command tag "DELETE 0" means no row was removed.
    nothing_deleted = command_tag == "DELETE 0"
    if nothing_deleted:
        raise HTTPException(status_code=404, detail="Seed nicht gefunden")
    return {"status": "deleted", "id": seed_id}
@router.post("/seeds/bulk-import", response_model=BulkImportResponse)
async def bulk_import_seeds(request: BulkImportRequest):
    """Bulk import seeds, skipping URLs that already exist.

    Each seed is inserted with ``ON CONFLICT (url) DO NOTHING``. Because that
    clause suppresses the unique-violation error, the command tag returned by
    ``execute`` is inspected to tell a real insert ("INSERT 0 1") from a
    skipped duplicate ("INSERT 0 0").

    Returns:
        BulkImportResponse with the number of inserted rows, the number of
        skipped duplicates, and one error string per seed that failed.
    """
    pool = await get_db_pool()
    imported = 0
    skipped = 0
    errors = []
    async with pool.acquire() as conn:
        # Pre-fetch all category IDs so each seed needs no extra lookup.
        rows = await conn.fetch("SELECT id, name FROM edu_search_categories")
        categories = {row["name"]: row["id"] for row in rows}
        for seed in request.seeds:
            try:
                # Unknown or missing category names resolve to NULL.
                category_id = categories.get(seed.category_name) if seed.category_name else None
                status = await conn.execute("""
                    INSERT INTO edu_search_seeds (
                        url, name, description, category_id, source_type, scope,
                        state, trust_boost, enabled, crawl_depth, crawl_frequency
                    ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
                    ON CONFLICT (url) DO NOTHING
                """,
                    seed.url, seed.name, seed.description, category_id,
                    seed.source_type, seed.scope, seed.state, seed.trust_boost,
                    seed.enabled, seed.crawl_depth, seed.crawl_frequency
                )
            except asyncpg.UniqueViolationError:
                # Kept for unique constraints other than the url conflict target.
                skipped += 1
            except Exception as e:
                errors.append(f"{seed.url}: {str(e)}")
            else:
                if status == "INSERT 0 0":
                    # Conflict on url: nothing was written.
                    skipped += 1
                else:
                    imported += 1
    return BulkImportResponse(imported=imported, skipped=skipped, errors=errors)
@router.get("/stats", response_model=StatsResponse)
async def get_stats():
    """Aggregate seed statistics.

    Reports the seed counts (all and enabled), the sum of ``total_documents``,
    seed counts per category and per state (seeds without a state are
    reported under the key 'federal'), and the latest ``last_crawled_at``.
    """
    db_pool = await get_db_pool()
    async with db_pool.acquire() as connection:
        seed_count = await connection.fetchval("SELECT COUNT(*) FROM edu_search_seeds")
        enabled_count = await connection.fetchval("SELECT COUNT(*) FROM edu_search_seeds WHERE enabled = TRUE")
        document_sum = await connection.fetchval("SELECT COALESCE(SUM(total_documents), 0) FROM edu_search_seeds")

        # The LEFT JOIN keeps categories that have no seeds (count 0).
        category_records = await connection.fetch("""
            SELECT c.name, COUNT(s.id) as count
            FROM edu_search_categories c
            LEFT JOIN edu_search_seeds s ON c.id = s.category_id
            GROUP BY c.name
        """)
        per_category = dict((record["name"], record["count"]) for record in category_records)

        state_records = await connection.fetch("""
            SELECT COALESCE(state, 'federal') as state, COUNT(*) as count
            FROM edu_search_seeds
            GROUP BY state
        """)
        per_state = dict((record["state"], record["count"]) for record in state_records)

        newest_crawl = await connection.fetchval(
            "SELECT MAX(last_crawled_at) FROM edu_search_seeds"
        )

    return StatsResponse(
        total_seeds=seed_count,
        enabled_seeds=enabled_count,
        total_documents=document_sum,
        seeds_by_category=per_category,
        seeds_by_state=per_state,
        last_crawl_time=newest_crawl,
    )
# Export for external use (edu-search-service)
@router.get("/seeds/export/for-crawler")
async def export_seeds_for_crawler():
    """Export all enabled seeds, highest trust boost first.

    Returns:
        A dict with the seed list (shortened key names), the number of
        exported seeds, and a naive UTC timestamp in ISO format.
    """
    db_pool = await get_db_pool()
    async with db_pool.acquire() as connection:
        records = await connection.fetch("""
            SELECT
                s.url, s.trust_boost, s.source_type, s.scope, s.state,
                s.crawl_depth, c.name as category
            FROM edu_search_seeds s
            LEFT JOIN edu_search_categories c ON s.category_id = c.id
            WHERE s.enabled = TRUE
            ORDER BY s.trust_boost DESC
        """)

    exported = []
    for record in records:
        exported.append({
            "url": record["url"],
            "trust": float(record["trust_boost"]),
            "source": record["source_type"],
            "scope": record["scope"],
            "state": record["state"],
            "depth": record["crawl_depth"],
            "category": record["category"],
        })

    return {
        "seeds": exported,
        "total": len(records),
        "exported_at": datetime.utcnow().isoformat(),
    }
# =============================================================================
# Crawl Status Feedback (from edu-search-service)
# =============================================================================
class CrawlStatusUpdate(BaseModel):
    """Crawl status update from edu-search-service."""
    # Both seed_url and status are required.
    seed_url: str = Field(..., description="The seed URL that was crawled")
    status: str = Field(..., description="Crawl status: success, error, partial")
    # Non-negative; defaults to 0.
    documents_crawled: int = Field(0, ge=0, description="Number of documents crawled")
    error_message: Optional[str] = Field(None, description="Error message if status is error")
    # Non-negative; defaults to 0.0.
    crawl_duration_seconds: float = Field(0.0, ge=0.0, description="Duration of the crawl in seconds")
class CrawlStatusResponse(BaseModel):
    """Response for crawl status update."""
    success: bool
    # URL of the seed the update was applied to.
    seed_url: str
    # Human-readable result message.
    message: str
@router.post("/seeds/crawl-status", response_model=CrawlStatusResponse)
async def update_crawl_status(update: CrawlStatusUpdate):
    """Update crawl status for a seed URL (called by edu-search-service).

    The seed is located by URL and updated in a single statement, so the
    ``total_documents`` increment happens inside the database. A separate
    read followed by a write could lose counts when two reports for the same
    seed arrive concurrently.

    Raises:
        HTTPException: 404 when no seed has the reported URL.
    """
    pool = await get_db_pool()
    async with pool.acquire() as conn:
        # documents_crawled is bound twice ($3 and $4) so each placeholder has
        # exactly one usage and its type is inferred unambiguously.
        row = await conn.fetchrow("""
            UPDATE edu_search_seeds
            SET
                last_crawled_at = NOW(),
                last_crawl_status = $2,
                last_crawl_docs = $3,
                total_documents = COALESCE(total_documents, 0) + $4,
                updated_at = NOW()
            WHERE url = $1
            RETURNING id
        """, update.seed_url, update.status, update.documents_crawled, update.documents_crawled)
        if not row:
            raise HTTPException(
                status_code=404,
                detail=f"Seed nicht gefunden: {update.seed_url}"
            )
    logger.info(
        "Crawl status updated: %s - status=%s, docs=%s, duration=%.1fs",
        update.seed_url,
        update.status,
        update.documents_crawled,
        update.crawl_duration_seconds,
    )
    return CrawlStatusResponse(
        success=True,
        seed_url=update.seed_url,
        message=f"Status aktualisiert: {update.documents_crawled} Dokumente gecrawlt"
    )
class BulkCrawlStatusUpdate(BaseModel):
    """Bulk crawl status update."""
    # Individual status updates, each validated as a CrawlStatusUpdate.
    updates: List[CrawlStatusUpdate]
class BulkCrawlStatusResponse(BaseModel):
    """Response for bulk crawl status update."""
    # Counters for applied and failed updates.
    updated: int
    failed: int
    # One message per failed update.
    errors: List[str]
@router.post("/seeds/crawl-status/bulk", response_model=BulkCrawlStatusResponse)
async def bulk_update_crawl_status(request: BulkCrawlStatusUpdate):
    """Bulk update crawl status for multiple seeds.

    Each update is applied with one atomic UPDATE keyed by URL, so the
    ``total_documents`` increment cannot be lost to a concurrent report.
    Unknown URLs and per-item errors are recorded without aborting the rest
    of the batch.

    Returns:
        BulkCrawlStatusResponse with the number of applied and failed
        updates and one message per failure.
    """
    pool = await get_db_pool()
    updated = 0
    failed = 0
    errors = []
    async with pool.acquire() as conn:
        for update in request.updates:
            try:
                # documents_crawled is bound twice ($3 and $4) so each
                # placeholder has exactly one usage.
                row = await conn.fetchrow("""
                    UPDATE edu_search_seeds
                    SET
                        last_crawled_at = NOW(),
                        last_crawl_status = $2,
                        last_crawl_docs = $3,
                        total_documents = COALESCE(total_documents, 0) + $4,
                        updated_at = NOW()
                    WHERE url = $1
                    RETURNING id
                """, update.seed_url, update.status, update.documents_crawled, update.documents_crawled)
            except Exception as e:
                failed += 1
                errors.append(f"{update.seed_url}: {str(e)}")
                continue
            if not row:
                failed += 1
                errors.append(f"Seed nicht gefunden: {update.seed_url}")
                continue
            updated += 1
    logger.info("Bulk crawl status update: %s updated, %s failed", updated, failed)
    return BulkCrawlStatusResponse(
        updated=updated,
        failed=failed,
        errors=errors
    )
# Mount the CRUD and status sub-routers on the combined router.
router.include_router(_crud_router)
router.include_router(_status_router)
# Names re-exported by this barrel module.
__all__ = [
    "router",
    "get_db_pool",
    # Models
    "CategoryResponse",
    "SeedBase",
    "SeedCreate",
    "SeedUpdate",
    "SeedResponse",
    "SeedsListResponse",
    "StatsResponse",
    "BulkImportRequest",
    "BulkImportResponse",
    "CrawlStatusUpdate",
    "CrawlStatusResponse",
    "BulkCrawlStatusUpdate",
    "BulkCrawlStatusResponse",
]
@@ -0,0 +1,198 @@
"""
EduSearch Seeds Stats & Crawl Status Routes.
Statistics, export for crawler, and crawl status feedback endpoints.
"""
import logging
from typing import List
from datetime import datetime
from fastapi import APIRouter, HTTPException
import asyncpg
from .edu_search_models import (
StatsResponse,
CrawlStatusUpdate,
CrawlStatusResponse,
BulkCrawlStatusUpdate,
BulkCrawlStatusResponse,
)
from .edu_search_crud import get_db_pool
logger = logging.getLogger(__name__)
router = APIRouter(tags=["edu-search"])
@router.get("/stats", response_model=StatsResponse)
async def get_stats():
    """Aggregate seed statistics.

    Reports the seed counts (all and enabled), the sum of ``total_documents``,
    seed counts per category and per state (seeds without a state are
    reported under the key 'federal'), and the latest ``last_crawled_at``.
    """
    db_pool = await get_db_pool()
    async with db_pool.acquire() as connection:
        seed_count = await connection.fetchval("SELECT COUNT(*) FROM edu_search_seeds")
        enabled_count = await connection.fetchval("SELECT COUNT(*) FROM edu_search_seeds WHERE enabled = TRUE")
        document_sum = await connection.fetchval("SELECT COALESCE(SUM(total_documents), 0) FROM edu_search_seeds")

        # The LEFT JOIN keeps categories that have no seeds (count 0).
        category_records = await connection.fetch("""
            SELECT c.name, COUNT(s.id) as count
            FROM edu_search_categories c
            LEFT JOIN edu_search_seeds s ON c.id = s.category_id
            GROUP BY c.name
        """)
        per_category = dict((record["name"], record["count"]) for record in category_records)

        state_records = await connection.fetch("""
            SELECT COALESCE(state, 'federal') as state, COUNT(*) as count
            FROM edu_search_seeds
            GROUP BY state
        """)
        per_state = dict((record["state"], record["count"]) for record in state_records)

        newest_crawl = await connection.fetchval(
            "SELECT MAX(last_crawled_at) FROM edu_search_seeds"
        )

    return StatsResponse(
        total_seeds=seed_count,
        enabled_seeds=enabled_count,
        total_documents=document_sum,
        seeds_by_category=per_category,
        seeds_by_state=per_state,
        last_crawl_time=newest_crawl,
    )
# Export for external use (edu-search-service)
@router.get("/seeds/export/for-crawler")
async def export_seeds_for_crawler():
    """Export all enabled seeds, highest trust boost first.

    Returns:
        A dict with the seed list (shortened key names), the number of
        exported seeds, and a naive UTC timestamp in ISO format.
    """
    db_pool = await get_db_pool()
    async with db_pool.acquire() as connection:
        records = await connection.fetch("""
            SELECT
                s.url, s.trust_boost, s.source_type, s.scope, s.state,
                s.crawl_depth, c.name as category
            FROM edu_search_seeds s
            LEFT JOIN edu_search_categories c ON s.category_id = c.id
            WHERE s.enabled = TRUE
            ORDER BY s.trust_boost DESC
        """)

    exported = []
    for record in records:
        exported.append({
            "url": record["url"],
            "trust": float(record["trust_boost"]),
            "source": record["source_type"],
            "scope": record["scope"],
            "state": record["state"],
            "depth": record["crawl_depth"],
            "category": record["category"],
        })

    return {
        "seeds": exported,
        "total": len(records),
        "exported_at": datetime.utcnow().isoformat(),
    }
# =============================================================================
# Crawl Status Feedback (from edu-search-service)
# =============================================================================
@router.post("/seeds/crawl-status", response_model=CrawlStatusResponse)
async def update_crawl_status(update: CrawlStatusUpdate):
    """Update crawl status for a seed URL (called by edu-search-service).

    The seed is located by URL and updated in a single statement, so the
    ``total_documents`` increment happens inside the database. A separate
    read followed by a write could lose counts when two reports for the same
    seed arrive concurrently.

    Raises:
        HTTPException: 404 when no seed has the reported URL.
    """
    pool = await get_db_pool()
    async with pool.acquire() as conn:
        # documents_crawled is bound twice ($3 and $4) so each placeholder has
        # exactly one usage and its type is inferred unambiguously.
        row = await conn.fetchrow("""
            UPDATE edu_search_seeds
            SET
                last_crawled_at = NOW(),
                last_crawl_status = $2,
                last_crawl_docs = $3,
                total_documents = COALESCE(total_documents, 0) + $4,
                updated_at = NOW()
            WHERE url = $1
            RETURNING id
        """, update.seed_url, update.status, update.documents_crawled, update.documents_crawled)
        if not row:
            raise HTTPException(
                status_code=404,
                detail=f"Seed nicht gefunden: {update.seed_url}"
            )
    logger.info(
        "Crawl status updated: %s - status=%s, docs=%s, duration=%.1fs",
        update.seed_url,
        update.status,
        update.documents_crawled,
        update.crawl_duration_seconds,
    )
    return CrawlStatusResponse(
        success=True,
        seed_url=update.seed_url,
        message=f"Status aktualisiert: {update.documents_crawled} Dokumente gecrawlt"
    )
@router.post("/seeds/crawl-status/bulk", response_model=BulkCrawlStatusResponse)
async def bulk_update_crawl_status(request: BulkCrawlStatusUpdate):
    """Bulk update crawl status for multiple seeds.

    Each update is applied with one atomic UPDATE keyed by URL, so the
    ``total_documents`` increment cannot be lost to a concurrent report.
    Unknown URLs and per-item errors are recorded without aborting the rest
    of the batch.

    Returns:
        BulkCrawlStatusResponse with the number of applied and failed
        updates and one message per failure.
    """
    pool = await get_db_pool()
    updated = 0
    failed = 0
    errors = []
    async with pool.acquire() as conn:
        for update in request.updates:
            try:
                # documents_crawled is bound twice ($3 and $4) so each
                # placeholder has exactly one usage.
                row = await conn.fetchrow("""
                    UPDATE edu_search_seeds
                    SET
                        last_crawled_at = NOW(),
                        last_crawl_status = $2,
                        last_crawl_docs = $3,
                        total_documents = COALESCE(total_documents, 0) + $4,
                        updated_at = NOW()
                    WHERE url = $1
                    RETURNING id
                """, update.seed_url, update.status, update.documents_crawled, update.documents_crawled)
            except Exception as e:
                failed += 1
                errors.append(f"{update.seed_url}: {str(e)}")
                continue
            if not row:
                failed += 1
                errors.append(f"Seed nicht gefunden: {update.seed_url}")
                continue
            updated += 1
    logger.info("Bulk crawl status update: %s updated, %s failed", updated, failed)
    return BulkCrawlStatusResponse(
        updated=updated,
        failed=failed,
        errors=errors
    )
+30 -859
View File
@@ -1,867 +1,38 @@
"""
Schools API Routes.
Schools API Routes — Barrel Re-export.
CRUD operations for managing German schools (~40,000 schools).
Direct database access to PostgreSQL.
Split into:
- schools_models.py: Pydantic models
- schools_db.py: Database connection pool
- schools_crud.py: School CRUD & stats routes
- schools_staff.py: Staff CRUD & search routes
"""
import os
import logging
from typing import Optional, List
from datetime import datetime
from uuid import UUID
from fastapi import APIRouter
from fastapi import APIRouter, HTTPException, Query
from pydantic import BaseModel, Field
import asyncpg
logger = logging.getLogger(__name__)
from .schools_crud import router as _crud_router
from .schools_staff import router as _staff_router
# Single router that merges both sub-module routers
router = APIRouter(prefix="/schools", tags=["schools"])
# Database connection pool
_pool: Optional[asyncpg.Pool] = None
async def get_db_pool() -> asyncpg.Pool:
    """Get or create the database connection pool.

    The connection string is read from ``DATABASE_URL`` only. No fallback
    DSN with embedded credentials is used, matching the edu-search pool
    helper, which also refuses to start without configuration.

    Raises:
        RuntimeError: when ``DATABASE_URL`` is unset or empty.
    """
    global _pool
    if _pool is None:
        database_url = os.environ.get("DATABASE_URL")
        if not database_url:
            # Fail loudly instead of connecting with hard-coded credentials.
            raise RuntimeError("DATABASE_URL nicht konfiguriert - bitte via Vault oder Umgebungsvariable setzen")
        _pool = await asyncpg.create_pool(database_url, min_size=2, max_size=10)
    return _pool
# =============================================================================
# Pydantic Models
# =============================================================================
class SchoolTypeResponse(BaseModel):
    """School type response model."""
    # Identifier, exposed as a string.
    id: str
    name: str
    # Optional short name, category and description.
    name_short: Optional[str] = None
    category: Optional[str] = None
    description: Optional[str] = None
class SchoolBase(BaseModel):
    """Base school model for creation/update."""
    # Required; at most 255 characters.
    name: str = Field(..., max_length=255)
    school_number: Optional[str] = Field(None, max_length=20)
    school_type_id: Optional[str] = None
    school_type_raw: Optional[str] = None
    # Required; at most 10 characters.
    state: str = Field(..., max_length=10)
    # Address and location fields; all optional.
    district: Optional[str] = None
    city: Optional[str] = None
    postal_code: Optional[str] = None
    street: Optional[str] = None
    address_full: Optional[str] = None
    latitude: Optional[float] = None
    longitude: Optional[float] = None
    # Contact details of the school; all optional.
    website: Optional[str] = None
    email: Optional[str] = None
    phone: Optional[str] = None
    fax: Optional[str] = None
    # Principal and secretary contact details; all optional.
    principal_name: Optional[str] = None
    principal_title: Optional[str] = None
    principal_email: Optional[str] = None
    principal_phone: Optional[str] = None
    secretary_name: Optional[str] = None
    secretary_email: Optional[str] = None
    secretary_phone: Optional[str] = None
    # Size and founding figures; all optional.
    student_count: Optional[int] = None
    teacher_count: Optional[int] = None
    class_count: Optional[int] = None
    founded_year: Optional[int] = None
    # is_public defaults to True; the other flags default to None.
    is_public: bool = True
    is_all_day: Optional[bool] = None
    has_inclusion: Optional[bool] = None
    languages: Optional[List[str]] = None
    specializations: Optional[List[str]] = None
    # Optional provenance of the record.
    source: Optional[str] = None
    source_url: Optional[str] = None
class SchoolCreate(SchoolBase):
    """Payload for creating a school; adds no fields beyond ``SchoolBase``."""
class SchoolUpdate(BaseModel):
    """Partial-update payload for a school.

    Every field is optional and defaults to ``None``.
    """

    # --- Identification -----------------------------------------------------
    name: Optional[str] = Field(None, max_length=255)
    school_number: Optional[str] = None
    school_type_id: Optional[str] = None

    # --- Location -----------------------------------------------------------
    state: Optional[str] = None
    district: Optional[str] = None
    city: Optional[str] = None
    postal_code: Optional[str] = None
    street: Optional[str] = None

    # --- Contact ------------------------------------------------------------
    website: Optional[str] = None
    email: Optional[str] = None
    phone: Optional[str] = None
    principal_name: Optional[str] = None

    # --- Figures and status -------------------------------------------------
    student_count: Optional[int] = None
    teacher_count: Optional[int] = None
    is_active: Optional[bool] = None
class SchoolResponse(BaseModel):
    """Response payload for a single school.

    ``id``, ``name``, ``state``, ``created_at`` and ``updated_at`` are
    required; all remaining fields carry defaults.
    """

    # --- Identification -----------------------------------------------------
    id: str
    name: str
    school_number: Optional[str] = None

    # --- School type (name, short name, category) ---------------------------
    school_type: Optional[str] = None
    school_type_short: Optional[str] = None
    school_category: Optional[str] = None

    # --- Location -----------------------------------------------------------
    state: str
    district: Optional[str] = None
    city: Optional[str] = None
    postal_code: Optional[str] = None
    street: Optional[str] = None
    address_full: Optional[str] = None
    latitude: Optional[float] = None
    longitude: Optional[float] = None

    # --- Contact ------------------------------------------------------------
    website: Optional[str] = None
    email: Optional[str] = None
    phone: Optional[str] = None
    fax: Optional[str] = None
    principal_name: Optional[str] = None
    principal_email: Optional[str] = None

    # --- Figures and characteristics ----------------------------------------
    student_count: Optional[int] = None
    teacher_count: Optional[int] = None
    is_public: bool = True
    is_all_day: Optional[bool] = None
    staff_count: int = 0

    # --- Provenance and bookkeeping -----------------------------------------
    source: Optional[str] = None
    crawled_at: Optional[datetime] = None
    is_active: bool = True
    created_at: datetime
    updated_at: datetime
class SchoolsListResponse(BaseModel):
    """One page of schools together with the pagination values."""

    # Schools on the current page.
    schools: List[SchoolResponse]
    # Total count, plus the page number and page size of this response.
    total: int
    page: int
    page_size: int
class SchoolStaffBase(BaseModel):
    """Shared field set for school staff payloads.

    ``last_name`` is the only required field.
    """

    # --- Name ---------------------------------------------------------------
    first_name: Optional[str] = None
    last_name: str
    full_name: Optional[str] = None
    title: Optional[str] = None

    # --- Role ---------------------------------------------------------------
    position: Optional[str] = None
    position_type: Optional[str] = None
    subjects: Optional[List[str]] = None

    # --- Contact ------------------------------------------------------------
    email: Optional[str] = None
    phone: Optional[str] = None
class SchoolStaffCreate(SchoolStaffBase):
    """Staff creation payload: the base staff fields plus a school id."""

    # Identifier of the school the staff member belongs to (required).
    school_id: str
class SchoolStaffResponse(SchoolStaffBase):
    """Response payload for one staff member, extending the base staff fields."""

    # --- Identifiers --------------------------------------------------------
    id: str
    school_id: str
    school_name: Optional[str] = None

    # --- Links --------------------------------------------------------------
    profile_url: Optional[str] = None
    photo_url: Optional[str] = None

    # --- Bookkeeping --------------------------------------------------------
    is_active: bool = True
    created_at: datetime
class SchoolStaffListResponse(BaseModel):
    """List of staff members together with a total count."""

    staff: List[SchoolStaffResponse]
    total: int
class SchoolStatsResponse(BaseModel):
    """Aggregate statistics about schools and staff."""

    # --- Overall counts -----------------------------------------------------
    total_schools: int
    total_staff: int

    # --- Breakdowns (untyped dicts) -----------------------------------------
    schools_by_state: dict
    schools_by_type: dict

    # --- Data-completeness counts -------------------------------------------
    schools_with_website: int
    schools_with_email: int
    schools_with_principal: int

    # --- Summed figures -----------------------------------------------------
    total_students: int
    total_teachers: int

    # --- Crawl bookkeeping --------------------------------------------------
    last_crawl_time: Optional[datetime] = None
class BulkImportRequest(BaseModel):
    """Request body for the bulk import: a list of school creation payloads."""

    schools: List[SchoolCreate]
class BulkImportResponse(BaseModel):
    """Outcome summary of a bulk import."""

    # Per-outcome counters.
    imported: int
    updated: int
    skipped: int
    # Error messages collected during the import.
    errors: List[str]
# =============================================================================
# School Type Endpoints
# =============================================================================
@router.get("/types", response_model=List[SchoolTypeResponse])
async def list_school_types():
    """Return every school type, ordered by category and then by name."""
    db = await get_db_pool()
    async with db.acquire() as conn:
        records = await conn.fetch("""
            SELECT id, name, name_short, category, description
            FROM school_types
            ORDER BY category, name
        """)

    school_types = []
    for record in records:
        # The selected column names match the model's field names; only the
        # id needs converting to a string.
        fields = dict(record)
        fields["id"] = str(fields["id"])
        school_types.append(SchoolTypeResponse(**fields))
    return school_types
# =============================================================================
# School Endpoints
# =============================================================================
@router.get("", response_model=SchoolsListResponse)
async def list_schools(
    state: Optional[str] = Query(None, description="Filter by state code (BW, BY, etc.)"),
    school_type: Optional[str] = Query(None, description="Filter by school type name"),
    city: Optional[str] = Query(None, description="Filter by city"),
    district: Optional[str] = Query(None, description="Filter by district"),
    postal_code: Optional[str] = Query(None, description="Filter by postal code prefix"),
    search: Optional[str] = Query(None, description="Search in name, city"),
    has_email: Optional[bool] = Query(None, description="Filter schools with email"),
    has_website: Optional[bool] = Query(None, description="Filter schools with website"),
    is_public: Optional[bool] = Query(None, description="Filter public/private schools"),
    page: int = Query(1, ge=1),
    page_size: int = Query(50, ge=1, le=200),
):
    """Return one page of active schools matching the supplied filters.

    All supplied filters are AND-ed together:

    * ``state`` is upper-cased and compared for equality.
    * ``school_type`` must equal the joined school type's name.
    * ``city`` is compared case-insensitively for equality.
    * ``district`` is a case-insensitive substring match.
    * ``postal_code`` is a prefix match.
    * ``search`` is a case-insensitive substring match on name, city or
      district.
    * ``has_email`` / ``has_website``: ``True`` requires a non-NULL column,
      ``False`` requires a NULL column.
    * ``is_public`` is compared for equality.

    Results are ordered by state, city and name. ``total`` is the number of
    rows matching the filters, independent of paging.
    """
    where_parts = ["s.is_active = TRUE"]
    filter_args = []

    def bind(value):
        """Queue *value* as a query argument and return its ``$n`` placeholder."""
        filter_args.append(value)
        return f"${len(filter_args)}"

    if state:
        where_parts.append(f"s.state = {bind(state.upper())}")
    if school_type:
        where_parts.append(f"st.name = {bind(school_type)}")
    if city:
        where_parts.append(f"LOWER(s.city) = LOWER({bind(city)})")
    if district:
        district_ph = bind(f"%{district}%")
        where_parts.append(f"LOWER(s.district) LIKE LOWER({district_ph})")
    if postal_code:
        postal_ph = bind(f"{postal_code}%")
        where_parts.append(f"s.postal_code LIKE {postal_ph}")
    if search:
        # One bound value is referenced by all three LIKE comparisons.
        search_ph = bind(f"%{search}%")
        where_parts.append(f"""
            (LOWER(s.name) LIKE LOWER({search_ph})
            OR LOWER(s.city) LIKE LOWER({search_ph})
            OR LOWER(s.district) LIKE LOWER({search_ph}))
        """)
    if has_email is not None:
        where_parts.append("s.email IS NOT NULL" if has_email else "s.email IS NULL")
    if has_website is not None:
        where_parts.append("s.website IS NOT NULL" if has_website else "s.website IS NULL")
    if is_public is not None:
        where_parts.append(f"s.is_public = {bind(is_public)}")

    where_sql = " AND ".join(where_parts)

    # LIMIT and OFFSET take the two placeholders after the filter arguments.
    limit_idx = len(filter_args) + 1
    skip = (page - 1) * page_size

    count_sql = f"""
        SELECT COUNT(*) FROM schools s
        LEFT JOIN school_types st ON s.school_type_id = st.id
        WHERE {where_sql}
    """
    page_sql = f"""
        SELECT
        s.id, s.name, s.school_number, s.state, s.district, s.city,
        s.postal_code, s.street, s.address_full, s.latitude, s.longitude,
        s.website, s.email, s.phone, s.fax,
        s.principal_name, s.principal_email,
        s.student_count, s.teacher_count,
        s.is_public, s.is_all_day, s.source, s.crawled_at,
        s.is_active, s.created_at, s.updated_at,
        st.name as school_type, st.name_short as school_type_short, st.category as school_category,
        (SELECT COUNT(*) FROM school_staff ss WHERE ss.school_id = s.id AND ss.is_active = TRUE) as staff_count
        FROM schools s
        LEFT JOIN school_types st ON s.school_type_id = st.id
        WHERE {where_sql}
        ORDER BY s.state, s.city, s.name
        LIMIT ${limit_idx} OFFSET ${limit_idx + 1}
    """

    db = await get_db_pool()
    async with db.acquire() as conn:
        total = await conn.fetchval(count_sql, *filter_args)
        records = await conn.fetch(page_sql, *filter_args, page_size, skip)

    # The selected column names/aliases match SchoolResponse's field names;
    # only the id needs converting to a string.
    schools = [
        SchoolResponse(**{**dict(record), "id": str(record["id"])})
        for record in records
    ]
    return SchoolsListResponse(
        schools=schools,
        total=total,
        page=page,
        page_size=page_size,
    )
@router.get("/stats", response_model=SchoolStatsResponse)
async def get_school_stats():
    """Return aggregate statistics about schools and staff.

    Counts and sums cover rows with ``is_active = TRUE``; the last crawl time
    is the maximum ``crawled_at`` over all schools. Schools without a
    matching school type are reported under the type name ``'Unbekannt'``.
    """
    db = await get_db_pool()
    async with db.acquire() as conn:
        # Headline counts and sums, gathered in one round trip.
        summary = await conn.fetchrow("""
            SELECT
            (SELECT COUNT(*) FROM schools WHERE is_active = TRUE) as total_schools,
            (SELECT COUNT(*) FROM school_staff WHERE is_active = TRUE) as total_staff,
            (SELECT COUNT(*) FROM schools WHERE is_active = TRUE AND website IS NOT NULL) as with_website,
            (SELECT COUNT(*) FROM schools WHERE is_active = TRUE AND email IS NOT NULL) as with_email,
            (SELECT COUNT(*) FROM schools WHERE is_active = TRUE AND principal_name IS NOT NULL) as with_principal,
            (SELECT COALESCE(SUM(student_count), 0) FROM schools WHERE is_active = TRUE) as total_students,
            (SELECT COALESCE(SUM(teacher_count), 0) FROM schools WHERE is_active = TRUE) as total_teachers,
            (SELECT MAX(crawled_at) FROM schools) as last_crawl
        """)
        # Active schools per state.
        per_state = await conn.fetch("""
            SELECT state, COUNT(*) as count
            FROM schools
            WHERE is_active = TRUE
            GROUP BY state
            ORDER BY state
        """)
        # Active schools per school type, most frequent first.
        per_type = await conn.fetch("""
            SELECT COALESCE(st.name, 'Unbekannt') as type_name, COUNT(*) as count
            FROM schools s
            LEFT JOIN school_types st ON s.school_type_id = st.id
            WHERE s.is_active = TRUE
            GROUP BY st.name
            ORDER BY count DESC
        """)

    return SchoolStatsResponse(
        total_schools=summary["total_schools"],
        total_staff=summary["total_staff"],
        schools_by_state={rec["state"]: rec["count"] for rec in per_state},
        schools_by_type={rec["type_name"]: rec["count"] for rec in per_type},
        schools_with_website=summary["with_website"],
        schools_with_email=summary["with_email"],
        schools_with_principal=summary["with_principal"],
        total_students=summary["total_students"],
        total_teachers=summary["total_teachers"],
        last_crawl_time=summary["last_crawl"],
    )
@router.get("/{school_id}", response_model=SchoolResponse)
async def get_school(school_id: str):
    """Return the school with the given id.

    The lookup does not filter on ``is_active``.

    Raises:
        HTTPException: 404 when no school has that id.
    """
    db = await get_db_pool()
    async with db.acquire() as conn:
        record = await conn.fetchrow("""
            SELECT
            s.id, s.name, s.school_number, s.state, s.district, s.city,
            s.postal_code, s.street, s.address_full, s.latitude, s.longitude,
            s.website, s.email, s.phone, s.fax,
            s.principal_name, s.principal_email,
            s.student_count, s.teacher_count,
            s.is_public, s.is_all_day, s.source, s.crawled_at,
            s.is_active, s.created_at, s.updated_at,
            st.name as school_type, st.name_short as school_type_short, st.category as school_category,
            (SELECT COUNT(*) FROM school_staff ss WHERE ss.school_id = s.id AND ss.is_active = TRUE) as staff_count
            FROM schools s
            LEFT JOIN school_types st ON s.school_type_id = st.id
            WHERE s.id = $1
        """, school_id)

    if record is None:
        raise HTTPException(status_code=404, detail="School not found")

    # The selected column names/aliases match SchoolResponse's field names;
    # only the id needs converting to a string.
    payload = dict(record)
    payload["id"] = str(payload["id"])
    return SchoolResponse(**payload)
@router.post("/bulk-import", response_model=BulkImportResponse)
async def bulk_import_schools(request: BulkImportRequest):
    """Bulk import schools, updating records that already exist.

    For each submitted school an existing row is looked up first by
    ``school_number`` + ``state`` and, failing that (and only when a city is
    given), by case-insensitive ``name`` + ``city`` plus ``state``. A match is
    updated; otherwise a new row is inserted with ``crawled_at = NOW()``.

    On update, ``name`` and ``is_public`` are always overwritten, while the
    COALESCE-wrapped columns keep their stored value when the submitted value
    is ``None``.

    A failure on one school is recorded in ``errors`` and the import moves on
    to the next school. Once more than 100 errors have occurred the import
    stops, and the returned list holds the first 100 messages followed by a
    truncation notice.

    Returns:
        BulkImportResponse with the ``imported`` / ``updated`` counters and
        the collected error messages. ``skipped`` is always 0 because no code
        path increments it.
    """
    max_reported_errors = 100
    truncation_notice = "... (more errors truncated)"

    pool = await get_db_pool()
    imported = 0
    updated = 0
    skipped = 0
    errors = []

    async with pool.acquire() as conn:
        # Lower-cased school type name -> id, used to resolve school_type_raw.
        type_rows = await conn.fetch("SELECT id, name FROM school_types")
        type_map = {row["name"].lower(): str(row["id"]) for row in type_rows}

        for school in request.schools:
            try:
                school_type_id = None
                if school.school_type_raw:
                    school_type_id = type_map.get(school.school_type_raw.lower())

                # Look for an existing row: school_number + state first, then
                # name + city + state.
                existing = None
                if school.school_number:
                    existing = await conn.fetchrow(
                        "SELECT id FROM schools WHERE school_number = $1 AND state = $2",
                        school.school_number, school.state
                    )
                if not existing and school.city:
                    existing = await conn.fetchrow(
                        "SELECT id FROM schools WHERE LOWER(name) = LOWER($1) AND LOWER(city) = LOWER($2) AND state = $3",
                        school.name, school.city, school.state
                    )

                # Values shared by both statements, in placeholder order:
                # $5..$24 of the UPDATE and $6..$25 of the INSERT.
                shared_values = (
                    school.district,
                    school.city,
                    school.postal_code,
                    school.street,
                    school.address_full,
                    school.latitude,
                    school.longitude,
                    school.website,
                    school.email,
                    school.phone,
                    school.fax,
                    school.principal_name,
                    school.principal_title,
                    school.principal_email,
                    school.principal_phone,
                    school.student_count,
                    school.teacher_count,
                    school.is_public,
                    school.source,
                    school.source_url,
                )

                if existing:
                    await conn.execute("""
                        UPDATE schools SET
                        name = $2,
                        school_type_id = COALESCE($3, school_type_id),
                        school_type_raw = COALESCE($4, school_type_raw),
                        district = COALESCE($5, district),
                        city = COALESCE($6, city),
                        postal_code = COALESCE($7, postal_code),
                        street = COALESCE($8, street),
                        address_full = COALESCE($9, address_full),
                        latitude = COALESCE($10, latitude),
                        longitude = COALESCE($11, longitude),
                        website = COALESCE($12, website),
                        email = COALESCE($13, email),
                        phone = COALESCE($14, phone),
                        fax = COALESCE($15, fax),
                        principal_name = COALESCE($16, principal_name),
                        principal_title = COALESCE($17, principal_title),
                        principal_email = COALESCE($18, principal_email),
                        principal_phone = COALESCE($19, principal_phone),
                        student_count = COALESCE($20, student_count),
                        teacher_count = COALESCE($21, teacher_count),
                        is_public = $22,
                        source = COALESCE($23, source),
                        source_url = COALESCE($24, source_url),
                        updated_at = NOW()
                        WHERE id = $1
                    """,
                        existing["id"],
                        school.name,
                        school_type_id,
                        school.school_type_raw,
                        *shared_values,
                    )
                    updated += 1
                else:
                    await conn.execute("""
                        INSERT INTO schools (
                        name, school_number, school_type_id, school_type_raw,
                        state, district, city, postal_code, street, address_full,
                        latitude, longitude, website, email, phone, fax,
                        principal_name, principal_title, principal_email, principal_phone,
                        student_count, teacher_count, is_public,
                        source, source_url, crawled_at
                        ) VALUES (
                        $1, $2, $3, $4, $5, $6, $7, $8, $9, $10,
                        $11, $12, $13, $14, $15, $16, $17, $18, $19, $20,
                        $21, $22, $23, $24, $25, NOW()
                        )
                    """,
                        school.name,
                        school.school_number,
                        school_type_id,
                        school.school_type_raw,
                        school.state,
                        *shared_values,
                    )
                    imported += 1
            except Exception as e:
                # Deliberate best-effort: one bad record must not abort the
                # whole import.
                errors.append(f"Error importing {school.name}: {str(e)}")
                if len(errors) > max_reported_errors:
                    # Replace the overflow entry with the notice so it reaches
                    # the caller; previously the notice was appended and then
                    # sliced away again by errors[:100].
                    errors[max_reported_errors:] = [truncation_notice]
                    break

    return BulkImportResponse(
        imported=imported,
        updated=updated,
        skipped=skipped,
        errors=errors,
    )
# =============================================================================
# School Staff Endpoints
# =============================================================================
@router.get("/{school_id}/staff", response_model=SchoolStaffListResponse)
async def get_school_staff(school_id: str):
    """Return the active staff members of one school.

    Rows are ordered by position type (principal, vice principal, secretary,
    then everything else) and then by last name. ``total`` is the number of
    returned staff members. An id that matches no staff rows yields an empty
    list; no 404 is raised here.
    """
    db = await get_db_pool()
    async with db.acquire() as conn:
        records = await conn.fetch("""
            SELECT
            ss.id, ss.school_id, ss.first_name, ss.last_name, ss.full_name,
            ss.title, ss.position, ss.position_type, ss.subjects,
            ss.email, ss.phone, ss.profile_url, ss.photo_url,
            ss.is_active, ss.created_at,
            s.name as school_name
            FROM school_staff ss
            JOIN schools s ON ss.school_id = s.id
            WHERE ss.school_id = $1 AND ss.is_active = TRUE
            ORDER BY
            CASE ss.position_type
            WHEN 'principal' THEN 1
            WHEN 'vice_principal' THEN 2
            WHEN 'secretary' THEN 3
            ELSE 4
            END,
            ss.last_name
        """, school_id)

    members = []
    for record in records:
        # Column names/aliases match the model's field names; both
        # identifiers need converting to strings.
        fields = dict(record)
        fields["id"] = str(fields["id"])
        fields["school_id"] = str(fields["school_id"])
        members.append(SchoolStaffResponse(**fields))

    return SchoolStaffListResponse(staff=members, total=len(members))
@router.post("/{school_id}/staff", response_model=SchoolStaffResponse)
async def create_school_staff(school_id: str, staff: SchoolStaffBase):
    """Insert a staff member for an existing school and return it.

    When ``staff.full_name`` is empty, it is assembled from the title and
    first name (each only if set) followed by the last name, joined by
    single spaces.

    Raises:
        HTTPException: 404 when no school has the given id.
    """
    db = await get_db_pool()
    async with db.acquire() as conn:
        school_row = await conn.fetchrow("SELECT name FROM schools WHERE id = $1", school_id)
        if not school_row:
            raise HTTPException(status_code=404, detail="School not found")

        # Fall back to "<title> <first_name> <last_name>", skipping unset parts.
        display_name = staff.full_name
        if not display_name:
            name_parts = [part for part in (staff.title, staff.first_name) if part]
            name_parts.append(staff.last_name)
            display_name = " ".join(name_parts)

        inserted = await conn.fetchrow("""
            INSERT INTO school_staff (
            school_id, first_name, last_name, full_name, title,
            position, position_type, subjects, email, phone
            ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
            RETURNING id, created_at
        """,
            school_id,
            staff.first_name,
            staff.last_name,
            display_name,
            staff.title,
            staff.position,
            staff.position_type,
            staff.subjects,
            staff.email,
            staff.phone,
        )

    # Echo the submitted values plus the database-generated id and timestamp.
    return SchoolStaffResponse(
        id=str(inserted["id"]),
        school_id=school_id,
        school_name=school_row["name"],
        created_at=inserted["created_at"],
        is_active=True,
        first_name=staff.first_name,
        last_name=staff.last_name,
        full_name=display_name,
        title=staff.title,
        position=staff.position,
        position_type=staff.position_type,
        subjects=staff.subjects,
        email=staff.email,
        phone=staff.phone,
    )
# =============================================================================
# Search Endpoints
# =============================================================================
# NOTE(review): `/{school_id}/staff` is registered on this same router earlier
# in the file. FastAPI matches routes in registration order, so a GET to
# `/search/staff` may be answered by that route with school_id="search" —
# confirm, and register this route first if so.
@router.get("/search/staff", response_model=SchoolStaffListResponse)
async def search_school_staff(
    q: Optional[str] = Query(None, description="Search query"),
    state: Optional[str] = Query(None, description="Filter by state"),
    position_type: Optional[str] = Query(None, description="Filter by position type"),
    has_email: Optional[bool] = Query(None, description="Only staff with email"),
    page: int = Query(1, ge=1),
    page_size: int = Query(50, ge=1, le=200),
):
    """Search active staff of active schools, one page at a time.

    All supplied filters are AND-ed together:

    * ``q`` is a case-insensitive substring match on the staff member's full
      name, last name, or the school's name.
    * ``state`` is upper-cased and compared for equality.
    * ``position_type`` is compared for equality.
    * ``has_email=True`` requires a non-NULL email; ``False`` and ``None``
      apply no email filter.

    Results are ordered by last name, then first name. ``total`` is the
    number of matching rows independent of paging.
    """
    where_parts = ["ss.is_active = TRUE", "s.is_active = TRUE"]
    filter_args = []

    def bind(value):
        """Queue *value* as a query argument and return its ``$n`` placeholder."""
        filter_args.append(value)
        return f"${len(filter_args)}"

    if q:
        # One bound value is referenced by all three LIKE comparisons.
        q_ph = bind(f"%{q}%")
        where_parts.append(f"""
            (LOWER(ss.full_name) LIKE LOWER({q_ph})
            OR LOWER(ss.last_name) LIKE LOWER({q_ph})
            OR LOWER(s.name) LIKE LOWER({q_ph}))
        """)
    if state:
        where_parts.append(f"s.state = {bind(state.upper())}")
    if position_type:
        where_parts.append(f"ss.position_type = {bind(position_type)}")
    if has_email:
        where_parts.append("ss.email IS NOT NULL")

    where_sql = " AND ".join(where_parts)

    # LIMIT and OFFSET take the two placeholders after the filter arguments.
    limit_idx = len(filter_args) + 1
    skip = (page - 1) * page_size

    db = await get_db_pool()
    async with db.acquire() as conn:
        total = await conn.fetchval(f"""
            SELECT COUNT(*) FROM school_staff ss
            JOIN schools s ON ss.school_id = s.id
            WHERE {where_sql}
        """, *filter_args)

        records = await conn.fetch(f"""
            SELECT
            ss.id, ss.school_id, ss.first_name, ss.last_name, ss.full_name,
            ss.title, ss.position, ss.position_type, ss.subjects,
            ss.email, ss.phone, ss.profile_url, ss.photo_url,
            ss.is_active, ss.created_at,
            s.name as school_name
            FROM school_staff ss
            JOIN schools s ON ss.school_id = s.id
            WHERE {where_sql}
            ORDER BY ss.last_name, ss.first_name
            LIMIT ${limit_idx} OFFSET ${limit_idx + 1}
        """, *filter_args, page_size, skip)

    members = []
    for record in records:
        # Column names/aliases match the model's field names; both
        # identifiers need converting to strings.
        fields = dict(record)
        fields["id"] = str(fields["id"])
        fields["school_id"] = str(fields["school_id"])
        members.append(SchoolStaffResponse(**fields))

    return SchoolStaffListResponse(staff=members, total=total)
# Attach the sub-routers to this module's router. The CRUD router is included
# first, so its routes are registered ahead of the staff router's.
router.include_router(_crud_router)
router.include_router(_staff_router)
# Re-export models for any external consumers
from .schools_models import ( # noqa: E402, F401
SchoolTypeResponse,
SchoolBase,
SchoolCreate,
SchoolUpdate,
SchoolResponse,
SchoolsListResponse,
SchoolStaffBase,
SchoolStaffCreate,
SchoolStaffResponse,
SchoolStaffListResponse,
SchoolStatsResponse,
BulkImportRequest,
BulkImportResponse,
)
from .schools_db import get_db_pool # noqa: E402, F401
@@ -0,0 +1,464 @@
"""
Schools API - School CRUD & Stats Routes.
List, get, stats, and bulk-import endpoints for schools.
"""
import logging
from typing import Optional
from fastapi import APIRouter, HTTPException, Query
from .schools_db import get_db_pool
from .schools_models import (
SchoolResponse,
SchoolsListResponse,
SchoolStatsResponse,
SchoolTypeResponse,
BulkImportRequest,
BulkImportResponse,
)
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Router for the endpoints defined in this module; routes are tagged "schools".
router = APIRouter(tags=["schools"])
# =============================================================================
# School Type Endpoints
# =============================================================================
@router.get("/types", response_model=list[SchoolTypeResponse])
async def list_school_types():
    """Return every school type, ordered by category and then by name."""
    db = await get_db_pool()
    async with db.acquire() as conn:
        records = await conn.fetch("""
            SELECT id, name, name_short, category, description
            FROM school_types
            ORDER BY category, name
        """)

    school_types = []
    for record in records:
        # The selected column names match the model's field names; only the
        # id needs converting to a string.
        fields = dict(record)
        fields["id"] = str(fields["id"])
        school_types.append(SchoolTypeResponse(**fields))
    return school_types
# =============================================================================
# School Endpoints
# =============================================================================
@router.get("", response_model=SchoolsListResponse)
async def list_schools(
    state: Optional[str] = Query(None, description="Filter by state code (BW, BY, etc.)"),
    school_type: Optional[str] = Query(None, description="Filter by school type name"),
    city: Optional[str] = Query(None, description="Filter by city"),
    district: Optional[str] = Query(None, description="Filter by district"),
    postal_code: Optional[str] = Query(None, description="Filter by postal code prefix"),
    search: Optional[str] = Query(None, description="Search in name, city"),
    has_email: Optional[bool] = Query(None, description="Filter schools with email"),
    has_website: Optional[bool] = Query(None, description="Filter schools with website"),
    is_public: Optional[bool] = Query(None, description="Filter public/private schools"),
    page: int = Query(1, ge=1),
    page_size: int = Query(50, ge=1, le=200),
):
    """Return one page of active schools matching the supplied filters.

    All supplied filters are AND-ed together:

    * ``state`` is upper-cased and compared for equality.
    * ``school_type`` must equal the joined school type's name.
    * ``city`` is compared case-insensitively for equality.
    * ``district`` is a case-insensitive substring match.
    * ``postal_code`` is a prefix match.
    * ``search`` is a case-insensitive substring match on name, city or
      district.
    * ``has_email`` / ``has_website``: ``True`` requires a non-NULL column,
      ``False`` requires a NULL column.
    * ``is_public`` is compared for equality.

    Results are ordered by state, city and name. ``total`` is the number of
    rows matching the filters, independent of paging.
    """
    where_parts = ["s.is_active = TRUE"]
    filter_args = []

    def bind(value):
        """Queue *value* as a query argument and return its ``$n`` placeholder."""
        filter_args.append(value)
        return f"${len(filter_args)}"

    if state:
        where_parts.append(f"s.state = {bind(state.upper())}")
    if school_type:
        where_parts.append(f"st.name = {bind(school_type)}")
    if city:
        where_parts.append(f"LOWER(s.city) = LOWER({bind(city)})")
    if district:
        district_ph = bind(f"%{district}%")
        where_parts.append(f"LOWER(s.district) LIKE LOWER({district_ph})")
    if postal_code:
        postal_ph = bind(f"{postal_code}%")
        where_parts.append(f"s.postal_code LIKE {postal_ph}")
    if search:
        # One bound value is referenced by all three LIKE comparisons.
        search_ph = bind(f"%{search}%")
        where_parts.append(f"""
            (LOWER(s.name) LIKE LOWER({search_ph})
            OR LOWER(s.city) LIKE LOWER({search_ph})
            OR LOWER(s.district) LIKE LOWER({search_ph}))
        """)
    if has_email is not None:
        where_parts.append("s.email IS NOT NULL" if has_email else "s.email IS NULL")
    if has_website is not None:
        where_parts.append("s.website IS NOT NULL" if has_website else "s.website IS NULL")
    if is_public is not None:
        where_parts.append(f"s.is_public = {bind(is_public)}")

    where_sql = " AND ".join(where_parts)

    # LIMIT and OFFSET take the two placeholders after the filter arguments.
    limit_idx = len(filter_args) + 1
    skip = (page - 1) * page_size

    count_sql = f"""
        SELECT COUNT(*) FROM schools s
        LEFT JOIN school_types st ON s.school_type_id = st.id
        WHERE {where_sql}
    """
    page_sql = f"""
        SELECT
        s.id, s.name, s.school_number, s.state, s.district, s.city,
        s.postal_code, s.street, s.address_full, s.latitude, s.longitude,
        s.website, s.email, s.phone, s.fax,
        s.principal_name, s.principal_email,
        s.student_count, s.teacher_count,
        s.is_public, s.is_all_day, s.source, s.crawled_at,
        s.is_active, s.created_at, s.updated_at,
        st.name as school_type, st.name_short as school_type_short, st.category as school_category,
        (SELECT COUNT(*) FROM school_staff ss WHERE ss.school_id = s.id AND ss.is_active = TRUE) as staff_count
        FROM schools s
        LEFT JOIN school_types st ON s.school_type_id = st.id
        WHERE {where_sql}
        ORDER BY s.state, s.city, s.name
        LIMIT ${limit_idx} OFFSET ${limit_idx + 1}
    """

    db = await get_db_pool()
    async with db.acquire() as conn:
        total = await conn.fetchval(count_sql, *filter_args)
        records = await conn.fetch(page_sql, *filter_args, page_size, skip)

    # The selected column names/aliases match SchoolResponse's field names;
    # only the id needs converting to a string.
    schools = [
        SchoolResponse(**{**dict(record), "id": str(record["id"])})
        for record in records
    ]
    return SchoolsListResponse(
        schools=schools,
        total=total,
        page=page,
        page_size=page_size,
    )
@router.get("/stats", response_model=SchoolStatsResponse)
async def get_school_stats():
    """Return aggregate statistics about schools and staff.

    Counts and sums cover rows with ``is_active = TRUE``; the last crawl time
    is the maximum ``crawled_at`` over all schools. Schools without a
    matching school type are reported under the type name ``'Unbekannt'``.
    """
    db = await get_db_pool()
    async with db.acquire() as conn:
        # Headline counts and sums, gathered in one round trip.
        summary = await conn.fetchrow("""
            SELECT
            (SELECT COUNT(*) FROM schools WHERE is_active = TRUE) as total_schools,
            (SELECT COUNT(*) FROM school_staff WHERE is_active = TRUE) as total_staff,
            (SELECT COUNT(*) FROM schools WHERE is_active = TRUE AND website IS NOT NULL) as with_website,
            (SELECT COUNT(*) FROM schools WHERE is_active = TRUE AND email IS NOT NULL) as with_email,
            (SELECT COUNT(*) FROM schools WHERE is_active = TRUE AND principal_name IS NOT NULL) as with_principal,
            (SELECT COALESCE(SUM(student_count), 0) FROM schools WHERE is_active = TRUE) as total_students,
            (SELECT COALESCE(SUM(teacher_count), 0) FROM schools WHERE is_active = TRUE) as total_teachers,
            (SELECT MAX(crawled_at) FROM schools) as last_crawl
        """)
        # Active schools per state.
        per_state = await conn.fetch("""
            SELECT state, COUNT(*) as count
            FROM schools
            WHERE is_active = TRUE
            GROUP BY state
            ORDER BY state
        """)
        # Active schools per school type, most frequent first.
        per_type = await conn.fetch("""
            SELECT COALESCE(st.name, 'Unbekannt') as type_name, COUNT(*) as count
            FROM schools s
            LEFT JOIN school_types st ON s.school_type_id = st.id
            WHERE s.is_active = TRUE
            GROUP BY st.name
            ORDER BY count DESC
        """)

    return SchoolStatsResponse(
        total_schools=summary["total_schools"],
        total_staff=summary["total_staff"],
        schools_by_state={rec["state"]: rec["count"] for rec in per_state},
        schools_by_type={rec["type_name"]: rec["count"] for rec in per_type},
        schools_with_website=summary["with_website"],
        schools_with_email=summary["with_email"],
        schools_with_principal=summary["with_principal"],
        total_students=summary["total_students"],
        total_teachers=summary["total_teachers"],
        last_crawl_time=summary["last_crawl"],
    )
@router.get("/{school_id}", response_model=SchoolResponse)
async def get_school(school_id: str):
    """Return the school with the given id.

    The lookup does not filter on ``is_active``.

    Raises:
        HTTPException: 404 when no school has that id.
    """
    db = await get_db_pool()
    async with db.acquire() as conn:
        record = await conn.fetchrow("""
            SELECT
            s.id, s.name, s.school_number, s.state, s.district, s.city,
            s.postal_code, s.street, s.address_full, s.latitude, s.longitude,
            s.website, s.email, s.phone, s.fax,
            s.principal_name, s.principal_email,
            s.student_count, s.teacher_count,
            s.is_public, s.is_all_day, s.source, s.crawled_at,
            s.is_active, s.created_at, s.updated_at,
            st.name as school_type, st.name_short as school_type_short, st.category as school_category,
            (SELECT COUNT(*) FROM school_staff ss WHERE ss.school_id = s.id AND ss.is_active = TRUE) as staff_count
            FROM schools s
            LEFT JOIN school_types st ON s.school_type_id = st.id
            WHERE s.id = $1
        """, school_id)

    if record is None:
        raise HTTPException(status_code=404, detail="School not found")

    # The selected column names/aliases match SchoolResponse's field names;
    # only the id needs converting to a string.
    payload = dict(record)
    payload["id"] = str(payload["id"])
    return SchoolResponse(**payload)
@router.post("/bulk-import", response_model=BulkImportResponse)
async def bulk_import_schools(request: BulkImportRequest):
    """Bulk import schools, updating records that already exist.

    A school is treated as existing when a row matches on
    ``school_number + state`` or, failing that, on case-insensitive
    ``name + city`` within the same ``state``. On update, NULL inputs keep the
    stored value (COALESCE); ``name`` and ``is_public`` are always overwritten.

    Returns:
        BulkImportResponse with the number of inserted (``imported``) and
        updated rows, the number of records that were not written
        (``skipped``: failed records plus any records left unprocessed after
        the error limit was reached), and up to ``max_errors`` error messages
        followed by a truncation marker when the import was aborted early.
    """
    max_errors = 100

    pool = await get_db_pool()
    imported = 0
    updated = 0
    skipped = 0
    errors = []

    async with pool.acquire() as conn:
        # Map lower-cased school type names to their IDs.
        type_rows = await conn.fetch("SELECT id, name FROM school_types")
        type_map = {row["name"].lower(): str(row["id"]) for row in type_rows}

        for index, school in enumerate(request.schools):
            try:
                # Resolve the school type ID from the raw type label, if any.
                school_type_id = None
                if school.school_type_raw:
                    school_type_id = type_map.get(school.school_type_raw.lower())

                # Look up an existing row: school_number + state first,
                # then name + city + state.
                existing = None
                if school.school_number:
                    existing = await conn.fetchrow(
                        "SELECT id FROM schools WHERE school_number = $1 AND state = $2",
                        school.school_number, school.state
                    )
                if not existing and school.city:
                    existing = await conn.fetchrow(
                        "SELECT id FROM schools WHERE LOWER(name) = LOWER($1) AND LOWER(city) = LOWER($2) AND state = $3",
                        school.name, school.city, school.state
                    )

                # Values that appear in the same order at the tail of both the
                # UPDATE ($5..$24) and the INSERT ($6..$25) parameter lists.
                details = (
                    school.district,
                    school.city,
                    school.postal_code,
                    school.street,
                    school.address_full,
                    school.latitude,
                    school.longitude,
                    school.website,
                    school.email,
                    school.phone,
                    school.fax,
                    school.principal_name,
                    school.principal_title,
                    school.principal_email,
                    school.principal_phone,
                    school.student_count,
                    school.teacher_count,
                    school.is_public,
                    school.source,
                    school.source_url,
                )

                if existing:
                    await conn.execute("""
                        UPDATE schools SET
                            name = $2,
                            school_type_id = COALESCE($3, school_type_id),
                            school_type_raw = COALESCE($4, school_type_raw),
                            district = COALESCE($5, district),
                            city = COALESCE($6, city),
                            postal_code = COALESCE($7, postal_code),
                            street = COALESCE($8, street),
                            address_full = COALESCE($9, address_full),
                            latitude = COALESCE($10, latitude),
                            longitude = COALESCE($11, longitude),
                            website = COALESCE($12, website),
                            email = COALESCE($13, email),
                            phone = COALESCE($14, phone),
                            fax = COALESCE($15, fax),
                            principal_name = COALESCE($16, principal_name),
                            principal_title = COALESCE($17, principal_title),
                            principal_email = COALESCE($18, principal_email),
                            principal_phone = COALESCE($19, principal_phone),
                            student_count = COALESCE($20, student_count),
                            teacher_count = COALESCE($21, teacher_count),
                            is_public = $22,
                            source = COALESCE($23, source),
                            source_url = COALESCE($24, source_url),
                            updated_at = NOW()
                        WHERE id = $1
                    """,
                        existing["id"],
                        school.name,
                        school_type_id,
                        school.school_type_raw,
                        *details,
                    )
                    updated += 1
                else:
                    await conn.execute("""
                        INSERT INTO schools (
                            name, school_number, school_type_id, school_type_raw,
                            state, district, city, postal_code, street, address_full,
                            latitude, longitude, website, email, phone, fax,
                            principal_name, principal_title, principal_email, principal_phone,
                            student_count, teacher_count, is_public,
                            source, source_url, crawled_at
                        ) VALUES (
                            $1, $2, $3, $4, $5, $6, $7, $8, $9, $10,
                            $11, $12, $13, $14, $15, $16, $17, $18, $19, $20,
                            $21, $22, $23, $24, $25, NOW()
                        )
                    """,
                        school.name,
                        school.school_number,
                        school_type_id,
                        school.school_type_raw,
                        school.state,
                        *details,
                    )
                    imported += 1

            except Exception as e:
                # One bad record must not abort the batch; record it and go on.
                skipped += 1
                errors.append(f"Error importing {school.name}: {str(e)}")

                remaining = len(request.schools) - index - 1
                if len(errors) >= max_errors and remaining:
                    # Abort once the error limit is hit. The marker is added
                    # after the limit check and the list is returned unsliced,
                    # so the marker actually reaches the client.
                    errors.append("... (more errors truncated)")
                    skipped += remaining
                    break

    return BulkImportResponse(
        imported=imported,
        updated=updated,
        skipped=skipped,
        errors=errors,
    )
@@ -0,0 +1,25 @@
"""
Schools API - Database Connection.
Shared database pool for school endpoints.
"""
import os
from typing import Optional
import asyncpg
# Module-wide asyncpg pool; stays None until the first get_db_pool() call.
_pool: Optional[asyncpg.Pool] = None


async def get_db_pool() -> asyncpg.Pool:
    """Return the shared connection pool, creating it on first use.

    The DSN is read from the ``DATABASE_URL`` environment variable; the pool
    is created with 2 to 10 connections.

    NOTE(review): the fallback DSN below embeds a username and password in
    source; consider requiring DATABASE_URL instead.
    NOTE(review): there is an ``await`` between the ``None`` check and the
    assignment, so concurrent first callers could each create a pool.
    """
    global _pool
    if _pool is not None:
        return _pool

    dsn = os.environ.get(
        "DATABASE_URL",
        "postgresql://breakpilot:breakpilot123@postgres:5432/breakpilot_db"
    )
    _pool = await asyncpg.create_pool(dsn, min_size=2, max_size=10)
    return _pool
@@ -0,0 +1,200 @@
"""
Schools API - Pydantic Models.
Data models for school and school staff endpoints.
"""
from typing import Optional, List
from datetime import datetime
from pydantic import BaseModel, Field
# =============================================================================
# School Type Models
# =============================================================================
class SchoolTypeResponse(BaseModel):
    """School type response model.

    ``id`` and ``name`` are required; the remaining fields default to None.
    """
    id: str
    name: str
    name_short: Optional[str] = None
    category: Optional[str] = None
    description: Optional[str] = None
# =============================================================================
# School Models
# =============================================================================
class SchoolBase(BaseModel):
    """Base school model for creation/update.

    Only ``name`` (max 255 chars) and ``state`` (max 10 chars) are required;
    every other field is optional. ``is_public`` defaults to True.
    """
    # Identity and classification
    name: str = Field(..., max_length=255)
    school_number: Optional[str] = Field(None, max_length=20)
    school_type_id: Optional[str] = None
    school_type_raw: Optional[str] = None
    # Location
    state: str = Field(..., max_length=10)
    district: Optional[str] = None
    city: Optional[str] = None
    postal_code: Optional[str] = None
    street: Optional[str] = None
    address_full: Optional[str] = None
    latitude: Optional[float] = None
    longitude: Optional[float] = None
    # General contact details
    website: Optional[str] = None
    email: Optional[str] = None
    phone: Optional[str] = None
    fax: Optional[str] = None
    # Principal
    principal_name: Optional[str] = None
    principal_title: Optional[str] = None
    principal_email: Optional[str] = None
    principal_phone: Optional[str] = None
    # Secretary
    secretary_name: Optional[str] = None
    secretary_email: Optional[str] = None
    secretary_phone: Optional[str] = None
    # Size and history
    student_count: Optional[int] = None
    teacher_count: Optional[int] = None
    class_count: Optional[int] = None
    founded_year: Optional[int] = None
    # Profile flags and offerings
    is_public: bool = True
    is_all_day: Optional[bool] = None
    has_inclusion: Optional[bool] = None
    languages: Optional[List[str]] = None
    specializations: Optional[List[str]] = None
    # Provenance of the record
    source: Optional[str] = None
    source_url: Optional[str] = None
class SchoolCreate(SchoolBase):
    """School creation model; adds no fields beyond SchoolBase."""
    pass
class SchoolUpdate(BaseModel):
    """School update model (all fields optional).

    Derives from BaseModel directly, not from SchoolBase, and declares only
    the subset of fields listed here.
    """
    name: Optional[str] = Field(None, max_length=255)
    school_number: Optional[str] = None
    school_type_id: Optional[str] = None
    state: Optional[str] = None
    district: Optional[str] = None
    city: Optional[str] = None
    postal_code: Optional[str] = None
    street: Optional[str] = None
    website: Optional[str] = None
    email: Optional[str] = None
    phone: Optional[str] = None
    principal_name: Optional[str] = None
    student_count: Optional[int] = None
    teacher_count: Optional[int] = None
    is_active: Optional[bool] = None
class SchoolResponse(BaseModel):
    """School response model.

    Required fields: ``id``, ``name``, ``state``, ``created_at`` and
    ``updated_at``. All other fields have defaults.
    """
    id: str
    name: str
    school_number: Optional[str] = None
    # School type details
    school_type: Optional[str] = None
    school_type_short: Optional[str] = None
    school_category: Optional[str] = None
    # Location
    state: str
    district: Optional[str] = None
    city: Optional[str] = None
    postal_code: Optional[str] = None
    street: Optional[str] = None
    address_full: Optional[str] = None
    latitude: Optional[float] = None
    longitude: Optional[float] = None
    # Contact details
    website: Optional[str] = None
    email: Optional[str] = None
    phone: Optional[str] = None
    fax: Optional[str] = None
    principal_name: Optional[str] = None
    principal_email: Optional[str] = None
    # Size and flags
    student_count: Optional[int] = None
    teacher_count: Optional[int] = None
    is_public: bool = True
    is_all_day: Optional[bool] = None
    staff_count: int = 0
    # Provenance and record state
    source: Optional[str] = None
    crawled_at: Optional[datetime] = None
    is_active: bool = True
    created_at: datetime
    updated_at: datetime
class SchoolsListResponse(BaseModel):
    """List response with pagination info.

    Carries one page of schools together with the total count and the
    page / page_size values.
    """
    schools: List[SchoolResponse]
    total: int
    page: int
    page_size: int
class SchoolStatsResponse(BaseModel):
    """School statistics response.

    All counters are required; only ``last_crawl_time`` is optional.
    """
    total_schools: int
    total_staff: int
    # Untyped dicts; presumably label -> count. Verify against the stats route.
    schools_by_state: dict
    schools_by_type: dict
    schools_with_website: int
    schools_with_email: int
    schools_with_principal: int
    total_students: int
    total_teachers: int
    last_crawl_time: Optional[datetime] = None
class BulkImportRequest(BaseModel):
    """Bulk import request: a list of SchoolCreate payloads."""
    # No length limit is declared on this list.
    schools: List[SchoolCreate]
class BulkImportResponse(BaseModel):
    """Bulk import response: per-outcome counters plus error messages."""
    imported: int
    updated: int
    skipped: int
    errors: List[str]
# =============================================================================
# School Staff Models
# =============================================================================
class SchoolStaffBase(BaseModel):
    """Base school staff model.

    ``last_name`` is the only required field.
    """
    first_name: Optional[str] = None
    last_name: str
    full_name: Optional[str] = None
    title: Optional[str] = None
    position: Optional[str] = None
    position_type: Optional[str] = None
    subjects: Optional[List[str]] = None
    email: Optional[str] = None
    phone: Optional[str] = None
class SchoolStaffCreate(SchoolStaffBase):
    """School staff creation model: SchoolStaffBase plus a required school_id."""
    school_id: str
class SchoolStaffResponse(SchoolStaffBase):
    """School staff response model.

    Extends SchoolStaffBase with identifiers, the owning school's name,
    profile/photo URLs, the active flag and the creation timestamp.
    """
    id: str
    school_id: str
    school_name: Optional[str] = None
    profile_url: Optional[str] = None
    photo_url: Optional[str] = None
    is_active: bool = True
    created_at: datetime
class SchoolStaffListResponse(BaseModel):
    """Staff list response: the staff entries plus a total count."""
    staff: List[SchoolStaffResponse]
    total: int
@@ -0,0 +1,233 @@
"""
Schools API - Staff Routes.
CRUD and search endpoints for school staff members.
"""
import logging
from typing import Optional
from fastapi import APIRouter, HTTPException, Query
from .schools_db import get_db_pool
from .schools_models import (
SchoolStaffBase,
SchoolStaffResponse,
SchoolStaffListResponse,
)
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Router for the staff endpoints; no prefix is set here.
router = APIRouter(tags=["schools"])
# =============================================================================
# Helpers
# =============================================================================


def _staff_from_row(row) -> SchoolStaffResponse:
    """Map a joined school_staff/schools row to a SchoolStaffResponse."""
    return SchoolStaffResponse(
        id=str(row["id"]),
        school_id=str(row["school_id"]),
        school_name=row["school_name"],
        first_name=row["first_name"],
        last_name=row["last_name"],
        full_name=row["full_name"],
        title=row["title"],
        position=row["position"],
        position_type=row["position_type"],
        subjects=row["subjects"],
        email=row["email"],
        phone=row["phone"],
        profile_url=row["profile_url"],
        photo_url=row["photo_url"],
        is_active=row["is_active"],
        created_at=row["created_at"],
    )


# =============================================================================
# Search Endpoints
# =============================================================================

# This route must be registered BEFORE "/{school_id}/staff". Routes are
# matched in registration order, so otherwise GET /search/staff is captured by
# the parameterised route with school_id="search" and never reaches this
# handler.
@router.get("/search/staff", response_model=SchoolStaffListResponse)
async def search_school_staff(
    q: Optional[str] = Query(None, description="Search query"),
    state: Optional[str] = Query(None, description="Filter by state"),
    position_type: Optional[str] = Query(None, description="Filter by position type"),
    has_email: Optional[bool] = Query(None, description="Only staff with email"),
    page: int = Query(1, ge=1),
    page_size: int = Query(50, ge=1, le=200),
):
    """Search active staff of active schools across all schools.

    Args:
        q: Case-insensitive substring matched against the staff full name,
            staff last name and school name.
        state: Exact match on the school's state (upper-cased before use).
        position_type: Exact match on the staff position type.
        has_email: When true, only staff with a non-NULL email are returned;
            false or omitted applies no email filter.
        page: 1-based page number.
        page_size: Rows per page (1-200).

    Returns:
        SchoolStaffListResponse with the requested page and the total number
        of matching rows (not just the page size).
    """
    pool = await get_db_pool()

    async with pool.acquire() as conn:
        # Only fixed SQL fragments are interpolated; user input is always
        # passed through positional parameters.
        conditions = ["ss.is_active = TRUE", "s.is_active = TRUE"]
        params = []
        param_idx = 1

        if q:
            conditions.append(f"""
                (LOWER(ss.full_name) LIKE LOWER(${param_idx})
                OR LOWER(ss.last_name) LIKE LOWER(${param_idx})
                OR LOWER(s.name) LIKE LOWER(${param_idx}))
            """)
            params.append(f"%{q}%")
            param_idx += 1

        if state:
            conditions.append(f"s.state = ${param_idx}")
            params.append(state.upper())
            param_idx += 1

        if position_type:
            conditions.append(f"ss.position_type = ${param_idx}")
            params.append(position_type)
            param_idx += 1

        if has_email:
            conditions.append("ss.email IS NOT NULL")

        where_clause = " AND ".join(conditions)

        # Count total
        total = await conn.fetchval(f"""
            SELECT COUNT(*) FROM school_staff ss
            JOIN schools s ON ss.school_id = s.id
            WHERE {where_clause}
        """, *params)

        # Fetch the requested page; LIMIT/OFFSET take the next two parameters.
        offset = (page - 1) * page_size
        rows = await conn.fetch(f"""
            SELECT
                ss.id, ss.school_id, ss.first_name, ss.last_name, ss.full_name,
                ss.title, ss.position, ss.position_type, ss.subjects,
                ss.email, ss.phone, ss.profile_url, ss.photo_url,
                ss.is_active, ss.created_at,
                s.name as school_name
            FROM school_staff ss
            JOIN schools s ON ss.school_id = s.id
            WHERE {where_clause}
            ORDER BY ss.last_name, ss.first_name
            LIMIT ${param_idx} OFFSET ${param_idx + 1}
        """, *params, page_size, offset)

        return SchoolStaffListResponse(
            staff=[_staff_from_row(row) for row in rows],
            total=total,
        )


# =============================================================================
# School Staff Endpoints
# =============================================================================


@router.get("/{school_id}/staff", response_model=SchoolStaffListResponse)
async def get_school_staff(school_id: str):
    """Get the active staff members of one school.

    Results are ordered principal, vice principal, secretary, then everyone
    else, and by last name within each group. No 404 is raised for an unknown
    school; the list is simply empty.
    """
    pool = await get_db_pool()

    async with pool.acquire() as conn:
        rows = await conn.fetch("""
            SELECT
                ss.id, ss.school_id, ss.first_name, ss.last_name, ss.full_name,
                ss.title, ss.position, ss.position_type, ss.subjects,
                ss.email, ss.phone, ss.profile_url, ss.photo_url,
                ss.is_active, ss.created_at,
                s.name as school_name
            FROM school_staff ss
            JOIN schools s ON ss.school_id = s.id
            WHERE ss.school_id = $1 AND ss.is_active = TRUE
            ORDER BY
                CASE ss.position_type
                    WHEN 'principal' THEN 1
                    WHEN 'vice_principal' THEN 2
                    WHEN 'secretary' THEN 3
                    ELSE 4
                END,
                ss.last_name
        """, school_id)

        staff = [_staff_from_row(row) for row in rows]

        return SchoolStaffListResponse(
            staff=staff,
            total=len(staff),
        )


@router.post("/{school_id}/staff", response_model=SchoolStaffResponse)
async def create_school_staff(school_id: str, staff: SchoolStaffBase):
    """Add a staff member to a school.

    When ``staff.full_name`` is empty it is built from title, first name and
    last name (space separated, missing parts skipped).

    Raises:
        HTTPException: 404 when the school does not exist.
    """
    pool = await get_db_pool()

    async with pool.acquire() as conn:
        # Verify school exists (its name is also echoed in the response).
        school = await conn.fetchrow("SELECT name FROM schools WHERE id = $1", school_id)
        if not school:
            raise HTTPException(status_code=404, detail="School not found")

        # Derive the display name when the client did not supply one.
        full_name = staff.full_name
        if not full_name:
            name_parts = [staff.title, staff.first_name, staff.last_name]
            full_name = " ".join(part for part in name_parts if part)

        row = await conn.fetchrow("""
            INSERT INTO school_staff (
                school_id, first_name, last_name, full_name, title,
                position, position_type, subjects, email, phone
            ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
            RETURNING id, created_at
        """,
            school_id,
            staff.first_name,
            staff.last_name,
            full_name,
            staff.title,
            staff.position,
            staff.position_type,
            staff.subjects,
            staff.email,
            staff.phone,
        )

        return SchoolStaffResponse(
            id=str(row["id"]),
            school_id=school_id,
            school_name=school["name"],
            first_name=staff.first_name,
            last_name=staff.last_name,
            full_name=full_name,
            title=staff.title,
            position=staff.position,
            position_type=staff.position_type,
            subjects=staff.subjects,
            email=staff.email,
            phone=staff.phone,
            is_active=True,
            created_at=row["created_at"],
        )
+13 -832
View File
@@ -1,840 +1,21 @@
"""
BreakPilot Messenger API
BreakPilot Messenger API — Barrel Re-export.
Stellt Endpoints fuer:
- Kontaktverwaltung (CRUD)
- Konversationen
- Nachrichten
- CSV-Import fuer Kontakte
- Gruppenmanagement
Stellt Endpoints fuer Kontakte, Konversationen, Nachrichten,
CSV-Import, Gruppenmanagement und Templates bereit.
DSGVO-konform: Alle Daten werden lokal gespeichert.
Split into:
- messenger_models.py: Pydantic models
- messenger_helpers.py: JSON file storage & default templates
- messenger_contacts.py: Contact CRUD & CSV import/export
- messenger_conversations.py: Conversations, messages, groups, templates, stats
"""
import os
import csv
import uuid
import json
from io import StringIO
from datetime import datetime
from typing import List, Optional, Dict, Any
from pathlib import Path
from fastapi import APIRouter
from fastapi import APIRouter, HTTPException, UploadFile, File, Query
from pydantic import BaseModel, Field
from messenger_contacts import router as _contacts_router
from messenger_conversations import router as _conversations_router
# All messenger routes are served under /api/messenger.
router = APIRouter(prefix="/api/messenger", tags=["Messenger"])

# Data storage (JSON-based for simple persistence), kept next to this module.
DATA_DIR = Path(__file__).parent / "data" / "messenger"
# Side effect at import time: make sure the storage directory exists.
DATA_DIR.mkdir(parents=True, exist_ok=True)

# One JSON file per entity type.
CONTACTS_FILE = DATA_DIR / "contacts.json"
CONVERSATIONS_FILE = DATA_DIR / "conversations.json"
MESSAGES_FILE = DATA_DIR / "messages.json"
GROUPS_FILE = DATA_DIR / "groups.json"
# ==========================================
# PYDANTIC MODELS
# ==========================================
class ContactBase(BaseModel):
    """Base model for contacts.

    Only ``name`` (1-200 characters) is required. ``role`` defaults to
    "parent" and ``preferred_channel`` to "email"; the allowed values named in
    the field descriptions are not enforced by this model.
    """
    name: str = Field(..., min_length=1, max_length=200)
    email: Optional[str] = None
    phone: Optional[str] = None
    role: str = Field(default="parent", description="parent, teacher, staff, student")
    student_name: Optional[str] = Field(None, description="Name des zugehoerigen Schuelers")
    class_name: Optional[str] = Field(None, description="Klasse z.B. 10a")
    notes: Optional[str] = None
    tags: List[str] = Field(default_factory=list)
    matrix_id: Optional[str] = Field(None, description="Matrix-ID z.B. @user:matrix.org")
    preferred_channel: str = Field(default="email", description="email, matrix, pwa")
class ContactCreate(ContactBase):
    """Model for a new contact; adds no fields beyond ContactBase."""
    pass
class Contact(ContactBase):
    """Complete contact with ID, timestamps and presence fields."""
    id: str
    # Timestamps are plain strings in this model.
    created_at: str
    updated_at: str
    online: bool = False
    last_seen: Optional[str] = None
class ContactUpdate(BaseModel):
    """Update model for contacts; every field is optional."""
    name: Optional[str] = None
    email: Optional[str] = None
    phone: Optional[str] = None
    role: Optional[str] = None
    student_name: Optional[str] = None
    class_name: Optional[str] = None
    notes: Optional[str] = None
    tags: Optional[List[str]] = None
    matrix_id: Optional[str] = None
    preferred_channel: Optional[str] = None
class GroupBase(BaseModel):
    """Base model for groups.

    ``name`` (1-100 characters) is required; ``group_type`` defaults to
    "class".
    """
    name: str = Field(..., min_length=1, max_length=100)
    description: Optional[str] = None
    group_type: str = Field(default="class", description="class, department, custom")
class GroupCreate(GroupBase):
    """Model for a new group: GroupBase plus the initial member IDs."""
    member_ids: List[str] = Field(default_factory=list)
class Group(GroupBase):
    """Complete group with ID, member IDs and timestamps."""
    id: str
    member_ids: List[str] = []
    created_at: str
    updated_at: str
class MessageBase(BaseModel):
    """Base model for messages.

    ``content`` must be non-empty. ``send_email`` asks for the message to be
    sent by email as well.
    """
    content: str = Field(..., min_length=1)
    content_type: str = Field(default="text", description="text, file, image")
    file_url: Optional[str] = None
    send_email: bool = Field(default=False, description="Nachricht auch per Email senden")
class MessageCreate(MessageBase):
    """Model for a new message: MessageBase plus the target conversation ID."""
    conversation_id: str
class Message(MessageBase):
    """Complete message with ID, read state and email delivery state."""
    id: str
    conversation_id: str
    sender_id: str  # "self" for the user's own messages
    timestamp: str
    # Read tracking
    read: bool = False
    read_at: Optional[str] = None
    # Email delivery tracking
    email_sent: bool = False
    email_sent_at: Optional[str] = None
    email_error: Optional[str] = None
class ConversationBase(BaseModel):
    """Base model for conversations: optional name plus the group flag."""
    name: Optional[str] = None
    is_group: bool = False
class Conversation(ConversationBase):
    """Complete conversation with ID, participants and summary fields."""
    id: str
    participant_ids: List[str] = []
    group_id: Optional[str] = None
    created_at: str
    updated_at: str
    # Summary of the most recent message and the unread counter.
    last_message: Optional[str] = None
    last_message_time: Optional[str] = None
    unread_count: int = 0
class CSVImportResult(BaseModel):
    """Result of a CSV import: counters, error messages and created contacts."""
    imported: int
    skipped: int
    errors: List[str]
    contacts: List[Contact]
# ==========================================
# DATA HELPERS
# ==========================================
def load_json(filepath: Path) -> List[Dict]:
    """Load a JSON list from ``filepath``.

    Returns an empty list when the file does not exist, cannot be read, does
    not contain valid JSON, or contains something other than a list. Every
    case except "file missing" is logged as a warning, because callers
    typically save the returned list back and would otherwise overwrite a
    damaged store without any trace.
    """
    import logging

    log = logging.getLogger(__name__)
    try:
        with open(filepath, "r", encoding="utf-8") as f:
            data = json.load(f)
    except FileNotFoundError:
        # A store that was never written is the normal initial state.
        return []
    except (OSError, ValueError) as exc:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        log.warning("Could not load JSON from %s: %s", filepath, exc)
        return []

    if not isinstance(data, list):
        # Callers iterate over and append to the result; anything but a list
        # would break them later in a less obvious place.
        log.warning("Ignoring %s: expected a JSON list, got %s", filepath, type(data).__name__)
        return []
    return data
def save_json(filepath: Path, data: List[Dict]):
    """Write ``data`` to ``filepath`` as pretty-printed UTF-8 JSON.

    The data is first written to a sibling ``<name>.tmp`` file, which is then
    moved over the target. If serialisation or writing fails, the existing
    file keeps its previous content instead of being left truncated; the
    temporary file is removed and the error is re-raised.
    """
    tmp_path = filepath.with_name(filepath.name + ".tmp")
    try:
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(data, f, ensure_ascii=False, indent=2)
        tmp_path.replace(filepath)
    except BaseException:
        # Do not leave a partial temp file behind.
        if tmp_path.exists():
            tmp_path.unlink()
        raise
def get_contacts() -> List[Dict]:
    """Load all contacts from CONTACTS_FILE via load_json."""
    return load_json(CONTACTS_FILE)
def save_contacts(contacts: List[Dict]) -> None:
    """Write the full contact list to CONTACTS_FILE via save_json."""
    save_json(CONTACTS_FILE, contacts)
def get_conversations() -> List[Dict]:
    """Load all conversations from CONVERSATIONS_FILE via load_json."""
    return load_json(CONVERSATIONS_FILE)
def save_conversations(conversations: List[Dict]) -> None:
    """Write the full conversation list to CONVERSATIONS_FILE via save_json."""
    save_json(CONVERSATIONS_FILE, conversations)
def get_messages() -> List[Dict]:
    """Load all messages (of every conversation) from MESSAGES_FILE via load_json."""
    return load_json(MESSAGES_FILE)
def save_messages(messages: List[Dict]) -> None:
    """Write the full message list to MESSAGES_FILE via save_json."""
    save_json(MESSAGES_FILE, messages)
def get_groups() -> List[Dict]:
    """Load all groups from GROUPS_FILE via load_json."""
    return load_json(GROUPS_FILE)
def save_groups(groups: List[Dict]) -> None:
    """Write the full group list to GROUPS_FILE via save_json."""
    save_json(GROUPS_FILE, groups)
# ==========================================
# CONTACTS ENDPOINTS
# ==========================================
@router.get("/contacts", response_model=List[Contact])
async def list_contacts(
    role: Optional[str] = Query(None, description="Filter by role"),
    class_name: Optional[str] = Query(None, description="Filter by class"),
    search: Optional[str] = Query(None, description="Search in name/email")
):
    """List contacts, optionally narrowed by role, class and a search term.

    ``role`` and ``class_name`` are exact matches. ``search`` is a
    case-insensitive substring test against name, email and student name.
    Empty or omitted filters are ignored.
    """
    needle = search.lower() if search else None

    def _keep(entry: Dict) -> bool:
        # Each active filter must pass; inactive filters are skipped.
        if role and entry.get("role") != role:
            return False
        if class_name and entry.get("class_name") != class_name:
            return False
        if needle is None:
            return True
        searchable = (
            entry.get("name", ""),
            entry.get("email") or "",
            entry.get("student_name") or "",
        )
        return any(needle in text.lower() for text in searchable)

    return [entry for entry in get_contacts() if _keep(entry)]
@router.post("/contacts", response_model=Contact)
async def create_contact(contact: ContactCreate):
    """Create a new contact and persist it.

    Raises:
        HTTPException: 400 when another contact already uses the same email.
    """
    stored = get_contacts()

    # Reject duplicates by email (only checked when an email was given).
    if contact.email and any(entry.get("email") == contact.email for entry in stored):
        raise HTTPException(status_code=400, detail="Kontakt mit dieser Email existiert bereits")

    timestamp = datetime.utcnow().isoformat()
    record = {
        "id": str(uuid.uuid4()),
        "created_at": timestamp,
        "updated_at": timestamp,
        "online": False,
        "last_seen": None,
    }
    # Submitted fields come last so they are laid out after the metadata.
    record.update(contact.dict())

    stored.append(record)
    save_contacts(stored)
    return record
@router.get("/contacts/{contact_id}", response_model=Contact)
async def get_contact(contact_id: str):
    """Return the contact with the given ID.

    Raises:
        HTTPException: 404 when no contact has that ID.
    """
    for entry in get_contacts():
        if entry["id"] == contact_id:
            return entry
    raise HTTPException(status_code=404, detail="Kontakt nicht gefunden")
@router.put("/contacts/{contact_id}", response_model=Contact)
async def update_contact(contact_id: str, update: ContactUpdate):
    """Apply the fields present in the request to a contact and persist it.

    Fields the client did not send are left untouched.

    Raises:
        HTTPException: 404 when no contact has that ID.
    """
    stored = get_contacts()

    for entry in stored:
        if entry["id"] != contact_id:
            continue
        entry.update(update.dict(exclude_unset=True))
        entry["updated_at"] = datetime.utcnow().isoformat()
        save_contacts(stored)
        return entry

    raise HTTPException(status_code=404, detail="Kontakt nicht gefunden")
@router.delete("/contacts/{contact_id}")
async def delete_contact(contact_id: str):
    """Delete a contact.

    Reports ``deleted`` even when no contact had the given ID.
    """
    remaining = [entry for entry in get_contacts() if entry["id"] != contact_id]
    save_contacts(remaining)
    return {"status": "deleted", "id": contact_id}
@router.post("/contacts/import", response_model=CSVImportResult)
async def import_contacts_csv(file: UploadFile = File(...)):
    """
    Import contacts from a CSV file.

    Expected columns (German aliases such as ``kontakt``, ``e-mail``,
    ``telefon``, ``rolle``, ``schueler``, ``klasse``, ``notizen`` are
    accepted too):
    - name (required)
    - email
    - phone
    - role (parent/teacher/staff/student)
    - student_name
    - class_name
    - notes
    - tags (comma-separated)
    - matrix_id
    - preferred_channel (email/matrix/pwa, anything else becomes "email")

    The file is decoded as UTF-8 (a leading BOM is stripped), falling back to
    Latin-1. Semicolon is tried as delimiter first, then comma. Rows without
    a name or with an already known email are skipped and reported.

    Raises:
        HTTPException: 400 when the upload is not named ``*.csv``.
    """
    # The filename may be missing; the extension check ignores case.
    if not (file.filename or "").lower().endswith('.csv'):
        raise HTTPException(status_code=400, detail="Nur CSV-Dateien werden unterstuetzt")

    content = await file.read()
    try:
        # utf-8-sig drops the BOM that spreadsheet exports often prepend;
        # otherwise the first header would become "\ufeffname" and every row
        # would be rejected as nameless.
        text = content.decode('utf-8-sig')
    except UnicodeDecodeError:
        text = content.decode('latin-1')

    contacts = get_contacts()
    existing_emails = {c.get("email") for c in contacts if c.get("email")}

    imported = []
    skipped = 0
    errors = []

    # Header names accepted for the required name column (see row parsing).
    name_columns = ('name', 'kontakt', 'elternname')

    def _has_name_column(fieldnames) -> bool:
        return any((f or "").lower().strip() in name_columns for f in fieldnames or [])

    reader = csv.DictReader(StringIO(text), delimiter=';')  # German CSVs mostly use semicolons
    if not _has_name_column(reader.fieldnames):
        # No usable header with ';', so retry with comma.
        reader = csv.DictReader(StringIO(text), delimiter=',')

    for row_num, row in enumerate(reader, start=2):
        try:
            # Normalise column names. Surplus cells are stored by DictReader
            # under the key None (e.g. a trailing delimiter) and are dropped.
            row = {
                k.lower().strip(): v.strip() if v else ""
                for k, v in row.items()
                if k is not None
            }

            name = row.get('name') or row.get('kontakt') or row.get('elternname')
            if not name:
                errors.append(f"Zeile {row_num}: Name fehlt")
                skipped += 1
                continue

            email = row.get('email') or row.get('e-mail') or row.get('mail')
            if email and email in existing_emails:
                errors.append(f"Zeile {row_num}: Email {email} existiert bereits")
                skipped += 1
                continue

            now = datetime.utcnow().isoformat()

            tags_str = row.get('tags') or row.get('kategorien') or ""
            tags = [t.strip() for t in tags_str.split(',') if t.strip()]

            # Read Matrix ID and preferred channel; unknown channels fall
            # back to email.
            matrix_id = row.get('matrix_id') or row.get('matrix') or None
            preferred_channel = row.get('preferred_channel') or row.get('kanal') or "email"
            if preferred_channel not in ["email", "matrix", "pwa"]:
                preferred_channel = "email"

            new_contact = {
                "id": str(uuid.uuid4()),
                "name": name,
                "email": email if email else None,
                "phone": row.get('phone') or row.get('telefon') or row.get('tel'),
                "role": row.get('role') or row.get('rolle') or "parent",
                "student_name": row.get('student_name') or row.get('schueler') or row.get('kind'),
                "class_name": row.get('class_name') or row.get('klasse'),
                "notes": row.get('notes') or row.get('notizen') or row.get('bemerkungen'),
                "tags": tags,
                "matrix_id": matrix_id if matrix_id else None,
                "preferred_channel": preferred_channel,
                "created_at": now,
                "updated_at": now,
                "online": False,
                "last_seen": None
            }

            contacts.append(new_contact)
            imported.append(new_contact)
            if email:
                # Also catches duplicates within the same file.
                existing_emails.add(email)

        except Exception as e:
            # A malformed row must not abort the whole import.
            errors.append(f"Zeile {row_num}: {str(e)}")
            skipped += 1

    save_contacts(contacts)

    return CSVImportResult(
        imported=len(imported),
        skipped=skipped,
        errors=errors[:20],  # return at most 20 errors
        contacts=imported
    )
@router.get("/contacts/export/csv")
async def export_contacts_csv():
    """Export all contacts as a semicolon-separated CSV download (kontakte.csv)."""
    from fastapi.responses import StreamingResponse

    # Columns copied verbatim from the stored contact (empty when absent).
    plain_columns = ['name', 'email', 'phone', 'role', 'student_name', 'class_name', 'notes']
    fieldnames = plain_columns + ['tags', 'matrix_id', 'preferred_channel']

    def _as_row(contact):
        row = {column: contact.get(column, '') for column in plain_columns}
        row['tags'] = ','.join(contact.get('tags', []))
        row['matrix_id'] = contact.get('matrix_id', '')
        row['preferred_channel'] = contact.get('preferred_channel', 'email')
        return row

    buffer = StringIO()
    writer = csv.DictWriter(buffer, fieldnames=fieldnames, delimiter=';')
    writer.writeheader()
    for contact in get_contacts():
        writer.writerow(_as_row(contact))

    # The whole document is emitted as a single chunk.
    return StreamingResponse(
        iter([buffer.getvalue()]),
        media_type="text/csv",
        headers={"Content-Disposition": "attachment; filename=kontakte.csv"}
    )
# ==========================================
# GROUPS ENDPOINTS
# ==========================================
@router.get("/groups", response_model=List[Group])
async def list_groups():
    """List all groups exactly as stored (no filtering or sorting)."""
    return get_groups()
@router.post("/groups", response_model=Group)
async def create_group(group: GroupCreate):
    """Create a new group and persist it."""
    stored = get_groups()

    timestamp = datetime.utcnow().isoformat()
    record = {
        "id": str(uuid.uuid4()),
        "created_at": timestamp,
        "updated_at": timestamp,
    }
    # Submitted fields (name, description, type, members) follow the metadata.
    record.update(group.dict())

    stored.append(record)
    save_groups(stored)
    return record
@router.put("/groups/{group_id}/members")
async def update_group_members(group_id: str, member_ids: List[str]):
    """Replace the member list of a group and persist the change.

    Raises:
        HTTPException: 404 when no group has that ID.
    """
    stored = get_groups()

    for entry in stored:
        if entry["id"] != group_id:
            continue
        entry["member_ids"] = member_ids
        entry["updated_at"] = datetime.utcnow().isoformat()
        save_groups(stored)
        return entry

    raise HTTPException(status_code=404, detail="Gruppe nicht gefunden")
@router.delete("/groups/{group_id}")
async def delete_group(group_id: str):
    """Delete a group.

    Reports ``deleted`` even when no group had the given ID.
    """
    remaining = [entry for entry in get_groups() if entry["id"] != group_id]
    save_groups(remaining)
    return {"status": "deleted", "id": group_id}
# ==========================================
# CONVERSATIONS ENDPOINTS
# ==========================================
@router.get("/conversations", response_model=List[Conversation])
async def list_conversations():
    """List all conversations, most recently active first.

    For each conversation the unread count (unread messages not sent by
    "self") and a preview of the latest message (first 50 characters) are
    computed from the stored messages. These values are only added to the
    response; nothing is written back to storage.
    """
    conversations = get_conversations()

    # Group messages by conversation once instead of rescanning the full
    # message list for every conversation.
    messages_by_conversation: Dict[Any, List[Dict]] = {}
    for message in get_messages():
        messages_by_conversation.setdefault(message.get("conversation_id"), []).append(message)

    # Add unread count and last message.
    for conv in conversations:
        conv_messages = messages_by_conversation.get(conv["id"], [])
        conv["unread_count"] = len([m for m in conv_messages if not m.get("read") and m.get("sender_id") != "self"])

        if conv_messages:
            last_msg = max(conv_messages, key=lambda m: m.get("timestamp", ""))
            conv["last_message"] = last_msg.get("content", "")[:50]
            conv["last_message_time"] = last_msg.get("timestamp")

    # Sort by last message; conversations without messages sort last.
    conversations.sort(key=lambda c: c.get("last_message_time") or "", reverse=True)

    return conversations
@router.post("/conversations", response_model=Conversation)
async def create_conversation(contact_id: Optional[str] = None, group_id: Optional[str] = None):
    """
    Create a new conversation, either 1:1 with a contact or for a group.

    When ``contact_id`` is given and a non-group conversation with that
    contact already exists, that conversation is returned instead. When
    ``group_id`` is given (and no such 1:1 conversation was found), a group
    conversation is created from the group's name and members.

    Raises:
        HTTPException: 400 when neither ID is given; 404 when the referenced
            group or contact does not exist.
    """
    stored = get_conversations()

    if not contact_id and not group_id:
        raise HTTPException(status_code=400, detail="Entweder contact_id oder group_id erforderlich")

    # Reuse an existing 1:1 conversation with this contact.
    if contact_id:
        for candidate in stored:
            if not candidate.get("is_group") and contact_id in candidate.get("participant_ids", []):
                return candidate

    timestamp = datetime.utcnow().isoformat()

    # Work out the parts that differ between group and 1:1 conversations.
    if group_id:
        group = next((g for g in get_groups() if g["id"] == group_id), None)
        if not group:
            raise HTTPException(status_code=404, detail="Gruppe nicht gefunden")
        title = group.get("name")
        participants = group.get("member_ids", [])
    else:
        contact = next((c for c in get_contacts() if c["id"] == contact_id), None)
        if not contact:
            raise HTTPException(status_code=404, detail="Kontakt nicht gefunden")
        title = contact.get("name")
        participants = [contact_id]

    conversation = {
        "id": str(uuid.uuid4()),
        "name": title,
        "is_group": bool(group_id),
        "participant_ids": participants,
        "group_id": group_id if group_id else None,
        "created_at": timestamp,
        "updated_at": timestamp,
        "last_message": None,
        "last_message_time": None,
        "unread_count": 0
    }

    stored.append(conversation)
    save_conversations(stored)
    return conversation
@router.get("/conversations/{conversation_id}", response_model=Conversation)
async def get_conversation(conversation_id: str):
    """Return the conversation with the given ID.

    Raises:
        HTTPException: 404 when no conversation has that ID.
    """
    for entry in get_conversations():
        if entry["id"] == conversation_id:
            return entry
    raise HTTPException(status_code=404, detail="Konversation nicht gefunden")
@router.delete("/conversations/{conversation_id}")
async def delete_conversation(conversation_id: str):
    """Delete a conversation together with all of its messages.

    Reports ``deleted`` even when no conversation had the given ID.
    """
    kept_conversations = [c for c in get_conversations() if c["id"] != conversation_id]
    save_conversations(kept_conversations)

    # Drop the messages that belonged to the removed conversation.
    kept_messages = [m for m in get_messages() if m.get("conversation_id") != conversation_id]
    save_messages(kept_messages)

    return {"status": "deleted", "id": conversation_id}
# ==========================================
# MESSAGES ENDPOINTS
# ==========================================
@router.get("/conversations/{conversation_id}/messages", response_model=List[Message])
async def list_messages(
    conversation_id: str,
    limit: int = Query(50, ge=1, le=200),
    before: Optional[str] = Query(None, description="Load messages before this timestamp")
):
    """Return the newest ``limit`` messages of a conversation, oldest first.

    When ``before`` is given, only messages whose timestamp string sorts
    strictly before it are considered.
    """
    def _timestamp(msg):
        return msg.get("timestamp", "")

    selected = [
        msg for msg in get_messages()
        if msg.get("conversation_id") == conversation_id
        and (not before or _timestamp(msg) < before)
    ]
    # Ascending by timestamp, so the tail slice holds the most recent messages.
    return sorted(selected, key=_timestamp)[-limit:]
@router.post("/conversations/{conversation_id}/messages", response_model=Message)
async def send_message(conversation_id: str, message: MessageBase):
    """
    Send a message in a conversation.

    If ``send_email`` is true, the conversation is not a group and its first
    participant has an email address, the message is also sent by email.
    Email failures are recorded in ``email_error`` and never abort the request.

    Raises:
        HTTPException: 404 when the conversation does not exist.
    """
    conversations = get_conversations()
    target = next((c for c in conversations if c["id"] == conversation_id), None)
    if not target:
        raise HTTPException(status_code=404, detail="Konversation nicht gefunden")

    sent_at = datetime.utcnow().isoformat()
    record = {
        "id": str(uuid.uuid4()),
        "conversation_id": conversation_id,
        "sender_id": "self",
        "timestamp": sent_at,
        "read": True,
        "read_at": sent_at,
        "email_sent": False,
        "email_sent_at": None,
        "email_error": None,
    }
    record.update(message.dict())

    # Resolve the email recipient only for 1:1 conversations with email requested.
    recipient = None
    if message.send_email and not target.get("is_group"):
        participant_ids = target.get("participant_ids", [])
        if participant_ids:
            recipient = next(
                (c for c in get_contacts() if c["id"] == participant_ids[0]),
                None,
            )

    if recipient and recipient.get("email"):
        try:
            from email_service import email_service
            outcome = email_service.send_messenger_notification(
                to_email=recipient["email"],
                to_name=recipient.get("name", ""),
                sender_name="BreakPilot Lehrer",  # TODO: use the current user's name
                message_content=message.content
            )
            if outcome.success:
                record["email_sent"] = True
                record["email_sent_at"] = outcome.sent_at
            else:
                record["email_error"] = outcome.error
        except Exception as exc:
            # Best effort: the message is stored even if mailing fails.
            record["email_error"] = str(exc)

    stored = get_messages()
    stored.append(record)
    save_messages(stored)

    # ``target`` is the entry inside ``conversations``, so mutate it directly.
    target["last_message"] = message.content[:50]
    target["last_message_time"] = sent_at
    target["updated_at"] = sent_at
    save_conversations(conversations)

    return record
@router.put("/messages/{message_id}/read")
async def mark_message_read(message_id: str):
    """Mark a single message as read and stamp ``read_at`` with the current UTC time.

    Raises:
        HTTPException: 404 when no message has that id.
    """
    messages = get_messages()
    for entry in messages:
        if entry["id"] == message_id:
            entry["read"] = True
            entry["read_at"] = datetime.utcnow().isoformat()
            save_messages(messages)
            return {"status": "read", "id": message_id}
    raise HTTPException(status_code=404, detail="Nachricht nicht gefunden")
@router.put("/conversations/{conversation_id}/read-all")
async def mark_all_messages_read(conversation_id: str):
    """Mark every unread message of a conversation as read.

    All affected messages share one ``read_at`` timestamp.
    """
    messages = get_messages()
    stamp = datetime.utcnow().isoformat()

    pending = [
        msg for msg in messages
        if msg.get("conversation_id") == conversation_id and not msg.get("read")
    ]
    for msg in pending:
        msg.update(read=True, read_at=stamp)

    save_messages(messages)
    return {"status": "all_read", "conversation_id": conversation_id}
# ==========================================
# TEMPLATES ENDPOINTS
# ==========================================
# Built-in message templates. Each entry has an id, a display name, the
# template text and a category. The bracketed tokens in the text ([DATUM],
# [UHRZEIT], [THEMA], [NAME]) are stored verbatim as part of the content.
DEFAULT_TEMPLATES = [
    {
        "id": "1",
        "name": "Terminbestaetigung",
        "content": "Vielen Dank fuer Ihre Terminanfrage. Ich bestaetige den Termin am [DATUM] um [UHRZEIT]. Bitte geben Sie mir Bescheid, falls sich etwas aendern sollte.",
        "category": "termin"
    },
    {
        "id": "2",
        "name": "Hausaufgaben-Info",
        "content": "Zur Information: Die Hausaufgaben fuer diese Woche umfassen [THEMA]. Abgabetermin ist [DATUM]. Bei Fragen stehe ich gerne zur Verfuegung.",
        "category": "hausaufgaben"
    },
    {
        "id": "3",
        "name": "Entschuldigung bestaetigen",
        "content": "Ich bestaetige den Erhalt der Entschuldigung fuer [NAME] am [DATUM]. Die Fehlzeiten wurden entsprechend vermerkt.",
        "category": "entschuldigung"
    },
    {
        "id": "4",
        "name": "Gespraechsanfrage",
        "content": "Ich wuerde gerne einen Termin fuer ein Gespraech mit Ihnen vereinbaren, um [THEMA] zu besprechen. Waeren Sie am [DATUM] um [UHRZEIT] verfuegbar?",
        "category": "gespraech"
    },
    {
        "id": "5",
        "name": "Krankmeldung bestaetigen",
        "content": "Vielen Dank fuer Ihre Krankmeldung fuer [NAME]. Ich wuensche gute Besserung. Bitte reichen Sie eine schriftliche Entschuldigung nach, sobald Ihr Kind wieder gesund ist.",
        "category": "krankmeldung"
    }
]
@router.get("/templates")
async def list_templates():
    """Return all message templates.

    On first use (no ``templates.json`` yet) the built-in defaults are
    written to disk and returned.
    """
    templates_path = DATA_DIR / "templates.json"
    if not templates_path.exists():
        save_json(templates_path, DEFAULT_TEMPLATES)
        return DEFAULT_TEMPLATES
    return load_json(templates_path)
@router.post("/templates")
async def create_template(name: str, content: str, category: str = "custom"):
    """Create a new template and persist it.

    When no template file exists yet, the new template is appended to a copy
    of the built-in defaults.
    """
    templates_path = DATA_DIR / "templates.json"
    if templates_path.exists():
        templates = load_json(templates_path)
    else:
        templates = list(DEFAULT_TEMPLATES)

    created = dict(id=str(uuid.uuid4()), name=name, content=content, category=category)
    templates.append(created)
    save_json(templates_path, templates)
    return created
@router.delete("/templates/{template_id}")
async def delete_template(template_id: str):
    """Delete a template by id and persist the remaining templates.

    The response reports ``"deleted"`` whether or not the id existed.
    """
    templates_path = DATA_DIR / "templates.json"
    if templates_path.exists():
        current = load_json(templates_path)
    else:
        current = list(DEFAULT_TEMPLATES)

    kept = [tpl for tpl in current if tpl["id"] != template_id]
    save_json(templates_path, kept)
    return {"status": "deleted", "id": template_id}
# ==========================================
# STATS ENDPOINT
# ==========================================
@router.get("/stats")
async def get_messenger_stats():
    """Return aggregate counters for the messenger.

    ``unread_messages`` counts messages that are not marked read and were not
    sent by ``"self"``. ``contacts_by_role`` maps each role to the number of
    contacts holding it; a contact without a ``role`` key counts as ``"parent"``.
    """
    contacts = get_contacts()
    conversations = get_conversations()
    messages = get_messages()
    groups = get_groups()

    unread_total = sum(1 for m in messages if not m.get("read") and m.get("sender_id") != "self")

    # Tally under the same key that selects the bucket. Previously a contact
    # without "role" created a "parent" bucket but was never counted in it.
    contacts_by_role = {}
    for contact in contacts:
        role = contact.get("role", "parent")
        contacts_by_role[role] = contacts_by_role.get(role, 0) + 1

    return {
        "total_contacts": len(contacts),
        "total_groups": len(groups),
        "total_conversations": len(conversations),
        "total_messages": len(messages),
        "unread_messages": unread_total,
        "contacts_by_role": contacts_by_role
    }
# Attach the contact and conversation sub-routers to this module's router.
router.include_router(_contacts_router)
router.include_router(_conversations_router)
+251
View File
@@ -0,0 +1,251 @@
"""
Messenger API - Contact Routes.
CRUD, CSV import/export for contacts.
"""
import csv
import uuid
from io import StringIO
from datetime import datetime
from typing import List, Optional
from fastapi import APIRouter, HTTPException, UploadFile, File, Query
from fastapi.responses import StreamingResponse
from messenger_models import (
Contact,
ContactCreate,
ContactUpdate,
CSVImportResult,
)
from messenger_helpers import get_contacts, save_contacts
router = APIRouter(tags=["Messenger"])
# ==========================================
# CONTACTS ENDPOINTS
# ==========================================
@router.get("/contacts", response_model=List[Contact])
async def list_contacts(
    role: Optional[str] = Query(None, description="Filter by role"),
    class_name: Optional[str] = Query(None, description="Filter by class"),
    search: Optional[str] = Query(None, description="Search in name/email")
):
    """List contacts, optionally filtered.

    ``role`` and ``class_name`` are exact matches. ``search`` is a
    case-insensitive substring match against name, email and student name.
    All given filters must match.
    """
    needle = search.lower() if search else None

    def _keep(entry):
        if role and entry.get("role") != role:
            return False
        if class_name and entry.get("class_name") != class_name:
            return False
        if needle is None:
            return True
        haystacks = (
            entry.get("name", ""),
            entry.get("email") or "",
            entry.get("student_name") or "",
        )
        return any(needle in text.lower() for text in haystacks)

    return [entry for entry in get_contacts() if _keep(entry)]
@router.post("/contacts", response_model=Contact)
async def create_contact(contact: ContactCreate):
    """Create and persist a new contact.

    Raises:
        HTTPException: 400 when another contact already uses the same email.
    """
    contacts = get_contacts()

    # Email must be unique, but only when one was supplied.
    if contact.email and any(c.get("email") == contact.email for c in contacts):
        raise HTTPException(status_code=400, detail="Kontakt mit dieser Email existiert bereits")

    stamp = datetime.utcnow().isoformat()
    record = {
        "id": str(uuid.uuid4()),
        "created_at": stamp,
        "updated_at": stamp,
        "online": False,
        "last_seen": None,
    }
    record.update(contact.dict())

    contacts.append(record)
    save_contacts(contacts)
    return record
@router.get("/contacts/{contact_id}", response_model=Contact)
async def get_contact(contact_id: str):
    """Return the contact with the given id.

    Raises:
        HTTPException: 404 when no contact has that id.
    """
    for candidate in get_contacts():
        if candidate["id"] == contact_id:
            return candidate
    raise HTTPException(status_code=404, detail="Kontakt nicht gefunden")
@router.put("/contacts/{contact_id}", response_model=Contact)
async def update_contact(contact_id: str, update: ContactUpdate):
    """Apply a partial update to a contact.

    Only fields explicitly present in the request overwrite stored values;
    ``updated_at`` is refreshed.

    Raises:
        HTTPException: 404 when no contact has that id.
    """
    contacts = get_contacts()
    for entry in contacts:
        if entry["id"] != contact_id:
            continue
        entry.update(update.dict(exclude_unset=True))
        entry["updated_at"] = datetime.utcnow().isoformat()
        save_contacts(contacts)
        return entry
    raise HTTPException(status_code=404, detail="Kontakt nicht gefunden")
@router.delete("/contacts/{contact_id}")
async def delete_contact(contact_id: str):
    """Delete a contact by id.

    The contact store is rewritten and ``"deleted"`` is reported even when
    the id was not present.
    """
    remaining = [entry for entry in get_contacts() if entry["id"] != contact_id]
    save_contacts(remaining)
    return {"status": "deleted", "id": contact_id}
@router.post("/contacts/import", response_model=CSVImportResult)
async def import_contacts_csv(file: UploadFile = File(...)):
    """
    Import contacts from an uploaded CSV file.

    Header names are matched case-insensitively (aliases in parentheses):
    - name (kontakt, elternname) -- required
    - email (e-mail, mail)
    - phone (telefon, tel)
    - role (rolle) -- parent/teacher/staff/student, defaults to "parent"
    - student_name (schueler, kind)
    - class_name (klasse)
    - notes (notizen, bemerkungen)
    - tags (kategorien) -- comma-separated
    - matrix_id (matrix)
    - preferred_channel (kanal) -- email/matrix/pwa, anything else becomes "email"

    The file is parsed as semicolon-separated when its header contains a name
    column, otherwise as comma-separated. Rows without a name or with an email
    that already exists are skipped and reported; at most 20 error messages
    are returned.

    Raises:
        HTTPException: 400 when the uploaded file name does not end in ".csv".
    """
    # The file name may be missing; the extension check ignores case.
    if not (file.filename or "").lower().endswith('.csv'):
        raise HTTPException(status_code=400, detail="Nur CSV-Dateien werden unterstuetzt")

    content = await file.read()
    try:
        # utf-8-sig drops a leading BOM, which would otherwise become part
        # of the first header name and break the column lookup.
        text = content.decode('utf-8-sig')
    except UnicodeDecodeError:
        text = content.decode('latin-1')

    contacts = get_contacts()
    existing_emails = {c.get("email") for c in contacts if c.get("email")}

    imported = []
    skipped = 0
    errors = []

    # Try semicolon first; accept it when any recognised name column is in
    # the (trimmed, lower-cased) header, otherwise fall back to comma.
    name_columns = {'name', 'kontakt', 'elternname'}
    reader = csv.DictReader(StringIO(text), delimiter=';')
    header = {(f or "").strip().lower() for f in (reader.fieldnames or [])}
    if not header & name_columns:
        reader = csv.DictReader(StringIO(text), delimiter=',')

    # Data rows start on line 2 of the file (line 1 is the header).
    for row_num, row in enumerate(reader, start=2):
        try:
            # Normalise column names; surplus cells are collected by
            # DictReader under the key None and are ignored here.
            row = {
                k.lower().strip(): (v.strip() if v else "")
                for k, v in row.items()
                if k is not None
            }

            name = row.get('name') or row.get('kontakt') or row.get('elternname')
            if not name:
                errors.append(f"Zeile {row_num}: Name fehlt")
                skipped += 1
                continue

            email = row.get('email') or row.get('e-mail') or row.get('mail')
            if email and email in existing_emails:
                errors.append(f"Zeile {row_num}: Email {email} existiert bereits")
                skipped += 1
                continue

            now = datetime.utcnow().isoformat()
            tags_str = row.get('tags') or row.get('kategorien') or ""
            tags = [t.strip() for t in tags_str.split(',') if t.strip()]

            matrix_id = row.get('matrix_id') or row.get('matrix') or None
            preferred_channel = row.get('preferred_channel') or row.get('kanal') or "email"
            if preferred_channel not in ["email", "matrix", "pwa"]:
                preferred_channel = "email"

            new_contact = {
                "id": str(uuid.uuid4()),
                "name": name,
                "email": email if email else None,
                "phone": row.get('phone') or row.get('telefon') or row.get('tel'),
                "role": row.get('role') or row.get('rolle') or "parent",
                "student_name": row.get('student_name') or row.get('schueler') or row.get('kind'),
                "class_name": row.get('class_name') or row.get('klasse'),
                "notes": row.get('notes') or row.get('notizen') or row.get('bemerkungen'),
                "tags": tags,
                "matrix_id": matrix_id if matrix_id else None,
                "preferred_channel": preferred_channel,
                "created_at": now,
                "updated_at": now,
                "online": False,
                "last_seen": None
            }

            contacts.append(new_contact)
            imported.append(new_contact)
            if email:
                # Also rejects duplicates within the same file.
                existing_emails.add(email)

        except Exception as e:
            # Best effort per row: one bad row must not abort the whole import.
            errors.append(f"Zeile {row_num}: {str(e)}")
            skipped += 1

    save_contacts(contacts)

    return CSVImportResult(
        imported=len(imported),
        skipped=skipped,
        errors=errors[:20],  # return at most 20 errors
        contacts=imported
    )
@router.get("/contacts/export/csv")
async def export_contacts_csv():
    """Export all contacts as a semicolon-separated CSV download (``kontakte.csv``)."""
    plain_columns = ('name', 'email', 'phone', 'role', 'student_name', 'class_name', 'notes')
    fieldnames = [*plain_columns, 'tags', 'matrix_id', 'preferred_channel']

    def _as_row(contact):
        row = {column: contact.get(column, '') for column in plain_columns}
        row['tags'] = ','.join(contact.get('tags', []))
        row['matrix_id'] = contact.get('matrix_id', '')
        row['preferred_channel'] = contact.get('preferred_channel', 'email')
        return row

    buffer = StringIO()
    writer = csv.DictWriter(buffer, fieldnames=fieldnames, delimiter=';')
    writer.writeheader()
    writer.writerows(_as_row(contact) for contact in get_contacts())

    return StreamingResponse(
        iter([buffer.getvalue()]),
        media_type="text/csv",
        headers={"Content-Disposition": "attachment; filename=kontakte.csv"}
    )
+405
View File
@@ -0,0 +1,405 @@
"""
Messenger API - Conversation, Message, Group, Template & Stats Routes.
Conversations CRUD, message send/read, groups, templates, stats.
"""
import uuid
from datetime import datetime
from typing import List, Optional
from fastapi import APIRouter, HTTPException, Query
from messenger_models import (
Conversation,
Group,
GroupCreate,
Message,
MessageBase,
)
from messenger_helpers import (
DATA_DIR,
DEFAULT_TEMPLATES,
get_contacts,
get_conversations,
save_conversations,
get_messages,
save_messages,
get_groups,
save_groups,
load_json,
save_json,
)
router = APIRouter(tags=["Messenger"])
# ==========================================
# GROUPS ENDPOINTS
# ==========================================
@router.get("/groups", response_model=List[Group])
async def list_groups():
    """Return all stored groups."""
    return get_groups()
@router.post("/groups", response_model=Group)
async def create_group(group: GroupCreate):
    """Create and persist a new group with a fresh id and timestamps."""
    groups = get_groups()
    stamp = datetime.utcnow().isoformat()

    record = {"id": str(uuid.uuid4()), "created_at": stamp, "updated_at": stamp}
    record.update(group.dict())

    groups.append(record)
    save_groups(groups)
    return record
@router.put("/groups/{group_id}/members")
async def update_group_members(group_id: str, member_ids: List[str]):
    """Replace the member list of a group and refresh ``updated_at``.

    Raises:
        HTTPException: 404 when no group has that id.
    """
    groups = get_groups()
    for entry in groups:
        if entry["id"] != group_id:
            continue
        entry["member_ids"] = member_ids
        entry["updated_at"] = datetime.utcnow().isoformat()
        save_groups(groups)
        return entry
    raise HTTPException(status_code=404, detail="Gruppe nicht gefunden")
@router.delete("/groups/{group_id}")
async def delete_group(group_id: str):
    """Delete a group by id.

    The group store is rewritten and ``"deleted"`` is reported even when the
    id was not present.
    """
    remaining = [entry for entry in get_groups() if entry["id"] != group_id]
    save_groups(remaining)
    return {"status": "deleted", "id": group_id}
# ==========================================
# CONVERSATIONS ENDPOINTS
# ==========================================
@router.get("/conversations", response_model=List[Conversation])
async def list_conversations():
    """List all conversations, most recently active first.

    For each conversation ``unread_count`` is recomputed (unread messages not
    sent by ``"self"``), and ``last_message`` / ``last_message_time`` are taken
    from the message with the greatest timestamp, if there is one.
    ``last_message`` is truncated to 50 characters.
    """
    conversations = get_conversations()

    # Group messages by conversation once instead of rescanning the whole
    # message list for every conversation.
    messages_by_conversation = {}
    for msg in get_messages():
        messages_by_conversation.setdefault(msg.get("conversation_id"), []).append(msg)

    for conv in conversations:
        conv_messages = messages_by_conversation.get(conv["id"], [])
        conv["unread_count"] = sum(
            1 for m in conv_messages
            if not m.get("read") and m.get("sender_id") != "self"
        )

        if conv_messages:
            last_msg = max(conv_messages, key=lambda m: m.get("timestamp", ""))
            conv["last_message"] = last_msg.get("content", "")[:50]
            conv["last_message_time"] = last_msg.get("timestamp")

    # Conversations without any message sort last (empty string key).
    conversations.sort(key=lambda c: c.get("last_message_time") or "", reverse=True)

    return conversations
@router.post("/conversations", response_model=Conversation)
async def create_conversation(contact_id: Optional[str] = None, group_id: Optional[str] = None):
    """
    Create a new conversation, either 1:1 with a contact or for a group.

    If ``contact_id`` is given and a non-group conversation with that contact
    already exists, it is returned instead of creating a new one. When
    ``group_id`` is given, a group conversation is created.

    Raises:
        HTTPException: 400 when neither id is given; 404 when the referenced
            group or contact does not exist.
    """
    conversations = get_conversations()

    if not (contact_id or group_id):
        raise HTTPException(status_code=400, detail="Entweder contact_id oder group_id erforderlich")

    # Reuse an existing 1:1 conversation with this contact.
    if contact_id:
        for conv in conversations:
            if not conv.get("is_group") and contact_id in conv.get("participant_ids", []):
                return conv

    stamp = datetime.utcnow().isoformat()

    if group_id:
        group = next((g for g in get_groups() if g["id"] == group_id), None)
        if not group:
            raise HTTPException(status_code=404, detail="Gruppe nicht gefunden")
        title = group.get("name")
        is_group = True
        participants = group.get("member_ids", [])
        linked_group = group_id
    else:
        contact = next((c for c in get_contacts() if c["id"] == contact_id), None)
        if not contact:
            raise HTTPException(status_code=404, detail="Kontakt nicht gefunden")
        title = contact.get("name")
        is_group = False
        participants = [contact_id]
        linked_group = None

    created = {
        "id": str(uuid.uuid4()),
        "name": title,
        "is_group": is_group,
        "participant_ids": participants,
        "group_id": linked_group,
        "created_at": stamp,
        "updated_at": stamp,
        "last_message": None,
        "last_message_time": None,
        "unread_count": 0
    }

    conversations.append(created)
    save_conversations(conversations)
    return created
@router.get("/conversations/{conversation_id}", response_model=Conversation)
async def get_conversation(conversation_id: str):
    """Return the stored conversation with the given id.

    Raises:
        HTTPException: 404 when no conversation has that id.
    """
    for candidate in get_conversations():
        if candidate["id"] == conversation_id:
            return candidate
    raise HTTPException(status_code=404, detail="Konversation nicht gefunden")
@router.delete("/conversations/{conversation_id}")
async def delete_conversation(conversation_id: str):
    """Delete a conversation together with all messages that belong to it.

    Both stores are rewritten even when the id is unknown; the response
    always reports ``"deleted"``.
    """
    remaining_conversations = [
        conv for conv in get_conversations() if conv["id"] != conversation_id
    ]
    save_conversations(remaining_conversations)

    remaining_messages = [
        msg for msg in get_messages()
        if msg.get("conversation_id") != conversation_id
    ]
    save_messages(remaining_messages)

    return {"status": "deleted", "id": conversation_id}
# ==========================================
# MESSAGES ENDPOINTS
# ==========================================
@router.get("/conversations/{conversation_id}/messages", response_model=List[Message])
async def list_messages(
    conversation_id: str,
    limit: int = Query(50, ge=1, le=200),
    before: Optional[str] = Query(None, description="Load messages before this timestamp")
):
    """Return the newest ``limit`` messages of a conversation, oldest first.

    When ``before`` is given, only messages whose timestamp string sorts
    strictly before it are considered.
    """
    def _timestamp(msg):
        return msg.get("timestamp", "")

    selected = [
        msg for msg in get_messages()
        if msg.get("conversation_id") == conversation_id
        and (not before or _timestamp(msg) < before)
    ]
    # Ascending by timestamp, so the tail slice holds the most recent messages.
    return sorted(selected, key=_timestamp)[-limit:]
@router.post("/conversations/{conversation_id}/messages", response_model=Message)
async def send_message(conversation_id: str, message: MessageBase):
    """
    Send a message in a conversation.

    If ``send_email`` is true, the conversation is not a group and its first
    participant has an email address, the message is also sent by email.
    Email failures are recorded in ``email_error`` and never abort the request.

    Raises:
        HTTPException: 404 when the conversation does not exist.
    """
    conversations = get_conversations()
    target = next((c for c in conversations if c["id"] == conversation_id), None)
    if not target:
        raise HTTPException(status_code=404, detail="Konversation nicht gefunden")

    sent_at = datetime.utcnow().isoformat()
    record = {
        "id": str(uuid.uuid4()),
        "conversation_id": conversation_id,
        "sender_id": "self",
        "timestamp": sent_at,
        "read": True,
        "read_at": sent_at,
        "email_sent": False,
        "email_sent_at": None,
        "email_error": None,
    }
    record.update(message.dict())

    # Resolve the email recipient only for 1:1 conversations with email requested.
    recipient = None
    if message.send_email and not target.get("is_group"):
        participant_ids = target.get("participant_ids", [])
        if participant_ids:
            recipient = next(
                (c for c in get_contacts() if c["id"] == participant_ids[0]),
                None,
            )

    if recipient and recipient.get("email"):
        try:
            from email_service import email_service
            outcome = email_service.send_messenger_notification(
                to_email=recipient["email"],
                to_name=recipient.get("name", ""),
                sender_name="BreakPilot Lehrer",
                message_content=message.content
            )
            if outcome.success:
                record["email_sent"] = True
                record["email_sent_at"] = outcome.sent_at
            else:
                record["email_error"] = outcome.error
        except Exception as exc:
            # Best effort: the message is stored even if mailing fails.
            record["email_error"] = str(exc)

    stored = get_messages()
    stored.append(record)
    save_messages(stored)

    # ``target`` is the entry inside ``conversations``, so mutate it directly.
    target["last_message"] = message.content[:50]
    target["last_message_time"] = sent_at
    target["updated_at"] = sent_at
    save_conversations(conversations)

    return record
@router.put("/messages/{message_id}/read")
async def mark_message_read(message_id: str):
    """Mark a single message as read and stamp ``read_at`` with the current UTC time.

    Raises:
        HTTPException: 404 when no message has that id.
    """
    messages = get_messages()
    for entry in messages:
        if entry["id"] == message_id:
            entry["read"] = True
            entry["read_at"] = datetime.utcnow().isoformat()
            save_messages(messages)
            return {"status": "read", "id": message_id}
    raise HTTPException(status_code=404, detail="Nachricht nicht gefunden")
@router.put("/conversations/{conversation_id}/read-all")
async def mark_all_messages_read(conversation_id: str):
    """Mark every unread message of a conversation as read.

    All affected messages share one ``read_at`` timestamp.
    """
    messages = get_messages()
    stamp = datetime.utcnow().isoformat()

    pending = [
        msg for msg in messages
        if msg.get("conversation_id") == conversation_id and not msg.get("read")
    ]
    for msg in pending:
        msg.update(read=True, read_at=stamp)

    save_messages(messages)
    return {"status": "all_read", "conversation_id": conversation_id}
# ==========================================
# TEMPLATES ENDPOINTS
# ==========================================
@router.get("/templates")
async def list_templates():
    """Return all message templates.

    On first use (no ``templates.json`` yet) the built-in defaults are
    written to disk and returned.
    """
    templates_path = DATA_DIR / "templates.json"
    if not templates_path.exists():
        save_json(templates_path, DEFAULT_TEMPLATES)
        return DEFAULT_TEMPLATES
    return load_json(templates_path)
@router.post("/templates")
async def create_template(name: str, content: str, category: str = "custom"):
    """Create a new template and persist it.

    When no template file exists yet, the new template is appended to a copy
    of the built-in defaults.
    """
    templates_path = DATA_DIR / "templates.json"
    if templates_path.exists():
        templates = load_json(templates_path)
    else:
        templates = list(DEFAULT_TEMPLATES)

    created = dict(id=str(uuid.uuid4()), name=name, content=content, category=category)
    templates.append(created)
    save_json(templates_path, templates)
    return created
@router.delete("/templates/{template_id}")
async def delete_template(template_id: str):
    """Delete a template by id and persist the remaining templates.

    The response reports ``"deleted"`` whether or not the id existed.
    """
    templates_path = DATA_DIR / "templates.json"
    if templates_path.exists():
        current = load_json(templates_path)
    else:
        current = list(DEFAULT_TEMPLATES)

    kept = [tpl for tpl in current if tpl["id"] != template_id]
    save_json(templates_path, kept)
    return {"status": "deleted", "id": template_id}
# ==========================================
# STATS ENDPOINT
# ==========================================
@router.get("/stats")
async def get_messenger_stats():
    """Return aggregate counters for the messenger.

    ``unread_messages`` counts messages that are not marked read and were not
    sent by ``"self"``. ``contacts_by_role`` maps each role to the number of
    contacts holding it; a contact without a ``role`` key counts as ``"parent"``.
    """
    contacts = get_contacts()
    conversations = get_conversations()
    messages = get_messages()
    groups = get_groups()

    unread_total = sum(1 for m in messages if not m.get("read") and m.get("sender_id") != "self")

    # Tally under the same key that selects the bucket. Previously a contact
    # without "role" created a "parent" bucket but was never counted in it.
    contacts_by_role = {}
    for contact in contacts:
        role = contact.get("role", "parent")
        contacts_by_role[role] = contacts_by_role.get(role, 0) + 1

    return {
        "total_contacts": len(contacts),
        "total_groups": len(groups),
        "total_conversations": len(conversations),
        "total_messages": len(messages),
        "unread_messages": unread_total,
        "contacts_by_role": contacts_by_role
    }
+105
View File
@@ -0,0 +1,105 @@
"""
Messenger API - Data Helpers.
JSON-based file storage for contacts, conversations, messages, and groups.
"""
import json
from typing import List, Dict
from pathlib import Path
# Data storage (JSON files, for simple persistence).
# The directory is created as a side effect of importing this module.
DATA_DIR = Path(__file__).parent / "data" / "messenger"
DATA_DIR.mkdir(parents=True, exist_ok=True)
# One JSON file per record type.
CONTACTS_FILE = DATA_DIR / "contacts.json"
CONVERSATIONS_FILE = DATA_DIR / "conversations.json"
MESSAGES_FILE = DATA_DIR / "messages.json"
GROUPS_FILE = DATA_DIR / "groups.json"
def load_json(filepath: Path) -> List[Dict]:
    """Load the JSON content of ``filepath``.

    Returns an empty list when the file does not exist. An unreadable or
    malformed file is also treated as empty, but a warning is logged so the
    problem does not go unnoticed.
    """
    try:
        # Open directly instead of checking exists() first; the check and
        # the open could disagree if the file changes in between.
        with open(filepath, "r", encoding="utf-8") as f:
            return json.load(f)
    except FileNotFoundError:
        return []
    except (OSError, ValueError) as exc:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        import logging
        logging.getLogger(__name__).warning(
            "Could not read %s, treating it as empty: %s", filepath, exc
        )
        return []
def save_json(filepath: Path, data: List[Dict]):
    """Write ``data`` to ``filepath`` as pretty-printed UTF-8 JSON.

    The data is written to a sibling temporary file first and then moved
    over the target. A failure while serialising or writing therefore leaves
    the previous file content untouched instead of a truncated file.
    """
    tmp_path = filepath.with_name(filepath.name + ".tmp")
    try:
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(data, f, ensure_ascii=False, indent=2)
        tmp_path.replace(filepath)
    except Exception:
        # Do not leave a half-written temporary file behind.
        if tmp_path.exists():
            tmp_path.unlink()
        raise
def get_contacts() -> List[Dict]:
    """Load all contacts from ``CONTACTS_FILE``."""
    return load_json(CONTACTS_FILE)
def save_contacts(contacts: List[Dict]) -> None:
    """Write the given contacts to ``CONTACTS_FILE``."""
    save_json(CONTACTS_FILE, contacts)
def get_conversations() -> List[Dict]:
    """Load all conversations from ``CONVERSATIONS_FILE``."""
    return load_json(CONVERSATIONS_FILE)
def save_conversations(conversations: List[Dict]) -> None:
    """Write the given conversations to ``CONVERSATIONS_FILE``."""
    save_json(CONVERSATIONS_FILE, conversations)
def get_messages() -> List[Dict]:
    """Load all messages from ``MESSAGES_FILE``."""
    return load_json(MESSAGES_FILE)
def save_messages(messages: List[Dict]) -> None:
    """Write the given messages to ``MESSAGES_FILE``."""
    save_json(MESSAGES_FILE, messages)
def get_groups() -> List[Dict]:
    """Load all groups from ``GROUPS_FILE``."""
    return load_json(GROUPS_FILE)
def save_groups(groups: List[Dict]) -> None:
    """Write the given groups to ``GROUPS_FILE``."""
    save_json(GROUPS_FILE, groups)
# ==========================================
# DEFAULT TEMPLATES
# ==========================================
# Built-in message templates. Each entry has an id, a display name, the
# template text and a category. The bracketed tokens in the text ([DATUM],
# [UHRZEIT], [THEMA], [NAME]) are stored verbatim as part of the content.
DEFAULT_TEMPLATES = [
    {
        "id": "1",
        "name": "Terminbestaetigung",
        "content": "Vielen Dank fuer Ihre Terminanfrage. Ich bestaetige den Termin am [DATUM] um [UHRZEIT]. Bitte geben Sie mir Bescheid, falls sich etwas aendern sollte.",
        "category": "termin"
    },
    {
        "id": "2",
        "name": "Hausaufgaben-Info",
        "content": "Zur Information: Die Hausaufgaben fuer diese Woche umfassen [THEMA]. Abgabetermin ist [DATUM]. Bei Fragen stehe ich gerne zur Verfuegung.",
        "category": "hausaufgaben"
    },
    {
        "id": "3",
        "name": "Entschuldigung bestaetigen",
        "content": "Ich bestaetige den Erhalt der Entschuldigung fuer [NAME] am [DATUM]. Die Fehlzeiten wurden entsprechend vermerkt.",
        "category": "entschuldigung"
    },
    {
        "id": "4",
        "name": "Gespraechsanfrage",
        "content": "Ich wuerde gerne einen Termin fuer ein Gespraech mit Ihnen vereinbaren, um [THEMA] zu besprechen. Waeren Sie am [DATUM] um [UHRZEIT] verfuegbar?",
        "category": "gespraech"
    },
    {
        "id": "5",
        "name": "Krankmeldung bestaetigen",
        "content": "Vielen Dank fuer Ihre Krankmeldung fuer [NAME]. Ich wuensche gute Besserung. Bitte reichen Sie eine schriftliche Entschuldigung nach, sobald Ihr Kind wieder gesund ist.",
        "category": "krankmeldung"
    }
]
+139
View File
@@ -0,0 +1,139 @@
"""
Messenger API - Pydantic Models.
Data models for contacts, conversations, messages, and groups.
"""
from typing import List, Optional
from pydantic import BaseModel, Field
# ==========================================
# CONTACT MODELS
# ==========================================
class ContactBase(BaseModel):
    """Base model for contacts: the fields a client may supply."""
    # Only ``name`` is required; it must be 1-200 characters long.
    name: str = Field(..., min_length=1, max_length=200)
    email: Optional[str] = None
    phone: Optional[str] = None
    # Free-form string; the description lists the expected values but
    # nothing here enforces them.
    role: str = Field(default="parent", description="parent, teacher, staff, student")
    student_name: Optional[str] = Field(None, description="Name des zugehoerigen Schuelers")
    class_name: Optional[str] = Field(None, description="Klasse z.B. 10a")
    notes: Optional[str] = None
    tags: List[str] = Field(default_factory=list)
    matrix_id: Optional[str] = Field(None, description="Matrix-ID z.B. @user:matrix.org")
    preferred_channel: str = Field(default="email", description="email, matrix, pwa")
class ContactCreate(ContactBase):
    """Model for creating a new contact; adds no fields to ContactBase."""
    pass
class Contact(ContactBase):
    """Complete contact, including its id, timestamps and presence fields."""
    id: str
    # Timestamps are carried as strings.
    created_at: str
    updated_at: str
    online: bool = False
    last_seen: Optional[str] = None
class ContactUpdate(BaseModel):
    """Update model for contacts; every field is optional."""
    name: Optional[str] = None
    email: Optional[str] = None
    phone: Optional[str] = None
    role: Optional[str] = None
    student_name: Optional[str] = None
    class_name: Optional[str] = None
    notes: Optional[str] = None
    tags: Optional[List[str]] = None
    matrix_id: Optional[str] = None
    preferred_channel: Optional[str] = None
# ==========================================
# GROUP MODELS
# ==========================================
class GroupBase(BaseModel):
    """Base model for groups."""
    # Required, 1-100 characters.
    name: str = Field(..., min_length=1, max_length=100)
    description: Optional[str] = None
    group_type: str = Field(default="class", description="class, department, custom")
class GroupCreate(GroupBase):
    """Model for creating a new group, optionally with initial members."""
    member_ids: List[str] = Field(default_factory=list)
class Group(GroupBase):
    """Complete group, including its id, members and timestamps."""
    id: str
    member_ids: List[str] = []
    created_at: str
    updated_at: str
# ==========================================
# MESSAGE MODELS
# ==========================================
class MessageBase(BaseModel):
    """Base model for messages: the fields a client may supply."""
    # Required and must not be empty.
    content: str = Field(..., min_length=1)
    content_type: str = Field(default="text", description="text, file, image")
    file_url: Optional[str] = None
    send_email: bool = Field(default=False, description="Nachricht auch per Email senden")
class MessageCreate(MessageBase):
    """Model for a new message that names its target conversation."""
    conversation_id: str
class Message(MessageBase):
    """Complete message, including its id, read state and email delivery state."""
    id: str
    conversation_id: str
    sender_id: str  # "self" for the user's own messages
    timestamp: str
    read: bool = False
    read_at: Optional[str] = None
    # Outcome of the optional email delivery.
    email_sent: bool = False
    email_sent_at: Optional[str] = None
    email_error: Optional[str] = None
# ==========================================
# CONVERSATION MODELS
# ==========================================
class ConversationBase(BaseModel):
    """Base model for conversations."""
    name: Optional[str] = None
    is_group: bool = False
class Conversation(ConversationBase):
    """Complete conversation, including its id, participants and last-message summary."""
    id: str
    participant_ids: List[str] = []
    group_id: Optional[str] = None
    created_at: str
    updated_at: str
    last_message: Optional[str] = None
    last_message_time: Optional[str] = None
    unread_count: int = 0
class CSVImportResult(BaseModel):
    """Result of a CSV import."""
    imported: int  # number of contacts created
    skipped: int  # number of rows not imported
    errors: List[str]  # human-readable messages for skipped rows
    contacts: List[Contact]  # the contacts that were created
+14 -840
View File
@@ -1,848 +1,22 @@
"""
BreakPilot Recording API
BreakPilot Recording API — Barrel Re-export.
Verwaltet Jibri Meeting-Aufzeichnungen und deren Metadaten.
Empfaengt Webhooks von Jibri nach Upload zu MinIO.
Split into:
- recording_models.py: Pydantic models & config
- recording_helpers.py: In-memory storage & utilities
- recording_routes.py: Core recording CRUD routes
- recording_transcription.py: Transcription routes
- recording_minutes.py: Meeting minutes routes
"""
import os
import uuid
from datetime import datetime, timedelta
from typing import Optional, List
from pydantic import BaseModel, Field
from fastapi import APIRouter
from fastapi import APIRouter, HTTPException, Query, Depends, Request
from fastapi.responses import JSONResponse
from recording_routes import router as _routes_router
from recording_transcription import router as _transcription_router
from recording_minutes import router as _minutes_router
router = APIRouter(prefix="/api/recordings", tags=["Recordings"])
# ==========================================
# ENVIRONMENT CONFIGURATION
# ==========================================
# MinIO connection settings; every value can be overridden via environment.
MINIO_ENDPOINT = os.getenv("MINIO_ENDPOINT", "minio:9000")
MINIO_ACCESS_KEY = os.getenv("MINIO_ACCESS_KEY", "breakpilot")
# NOTE(review): the fallback is a hard-coded credential; deployments should
# always set MINIO_SECRET_KEY explicitly instead of relying on this default.
MINIO_SECRET_KEY = os.getenv("MINIO_SECRET_KEY", "breakpilot123")
MINIO_BUCKET = os.getenv("MINIO_BUCKET", "breakpilot-recordings")
# Only the string "true" (any casing) yields True; everything else is False.
MINIO_SECURE = os.getenv("MINIO_SECURE", "false").lower() == "true"
# Default retention period in days (DSGVO compliance).
# int() raises ValueError at import time if the variable is not numeric.
DEFAULT_RETENTION_DAYS = int(os.getenv("RECORDING_RETENTION_DAYS", "365"))
# ==========================================
# PYDANTIC MODELS
# ==========================================
class JibriWebhookPayload(BaseModel):
    """Webhook payload from Jibri finalize.sh script."""
    # All fields are required except audio_path, which defaults to None.
    event: str = Field(..., description="Event type: recording_completed")
    recording_name: str = Field(..., description="Unique recording identifier")
    storage_path: str = Field(..., description="Path in MinIO bucket")
    audio_path: Optional[str] = Field(None, description="Extracted audio path")
    file_size_bytes: int = Field(..., description="Video file size in bytes")
    # Carried as a plain string; no datetime parsing is applied by the model.
    timestamp: str = Field(..., description="ISO timestamp of upload")
class RecordingCreate(BaseModel):
    """Manual recording creation (for testing)."""
    # Only meeting_id and storage_path are required.
    meeting_id: str
    title: Optional[str] = None
    storage_path: str
    audio_path: Optional[str] = None
    duration_seconds: Optional[int] = None
    participant_count: Optional[int] = 0
    # Default is the module-level DEFAULT_RETENTION_DAYS, read once at import.
    retention_days: Optional[int] = DEFAULT_RETENTION_DAYS
class RecordingResponse(BaseModel):
    """Recording details response."""
    id: str
    meeting_id: str
    # NOTE(review): the Optional fields below have no default value; under
    # Pydantic v2 they are presumably required-but-nullable -- confirm version.
    title: Optional[str]
    storage_path: str
    audio_path: Optional[str]
    file_size_bytes: Optional[int]
    duration_seconds: Optional[int]
    participant_count: int
    status: str
    recorded_at: datetime
    retention_days: int
    retention_expires_at: datetime
    # Status and ID of an associated transcription; None when omitted.
    transcription_status: Optional[str] = None
    transcription_id: Optional[str] = None
class RecordingListResponse(BaseModel):
    """Paginated list of recordings."""
    # The recordings of the requested page only.
    recordings: List[RecordingResponse]
    # Total count reported alongside the page (not the page length).
    total: int
    page: int
    page_size: int
class TranscriptionRequest(BaseModel):
    """Request to start transcription."""
    # Every field has a default, so an empty request body is valid.
    language: str = Field(default="de", description="Language code: de, en, etc.")
    model: str = Field(default="large-v3", description="Whisper model to use")
    priority: int = Field(default=0, description="Queue priority (higher = sooner)")
class TranscriptionStatusResponse(BaseModel):
    """Transcription status and progress."""
    id: str
    recording_id: str
    status: str
    language: str
    model: str
    # NOTE(review): the Optional fields below have no default value; under
    # Pydantic v2 they are presumably required-but-nullable -- confirm version.
    word_count: Optional[int]
    confidence_score: Optional[float]
    processing_duration_seconds: Optional[int]
    error_message: Optional[str]
    created_at: datetime
    completed_at: Optional[datetime]
# ==========================================
# IN-MEMORY STORAGE (Dev Mode)
# ==========================================
# In production, these would be database queries.
# Module-level containers: contents live only as long as the process does.
_recordings_store: dict = {}      # recording_id -> recording dict
_transcriptions_store: dict = {}  # transcription_id -> transcription dict
_audit_log: list = []             # audit entries appended by log_audit()
def log_audit(
    action: str,
    recording_id: Optional[str] = None,
    transcription_id: Optional[str] = None,
    user_id: Optional[str] = None,
    metadata: Optional[dict] = None
):
    """Append one audit entry (DSGVO compliance) to the in-memory audit log.

    Each entry receives a fresh UUID and a naive-UTC ISO timestamp. A missing
    or empty ``metadata`` is stored as an empty dict.
    """
    entry = {
        "id": str(uuid.uuid4()),
        "recording_id": recording_id,
        "transcription_id": transcription_id,
        "user_id": user_id,
        "action": action,
        "metadata": metadata if metadata else {},
        "created_at": datetime.utcnow().isoformat(),
    }
    _audit_log.append(entry)
# ==========================================
# WEBHOOK ENDPOINT (Jibri)
# ==========================================
@router.post("/webhook")
async def jibri_webhook(payload: JibriWebhookPayload, request: Request):
    """
    Webhook endpoint called by Jibri finalize.sh after upload.

    Registers a new recording entry in the in-memory store and writes an
    audit entry. Any event other than "recording_completed" is answered
    with HTTP 400. ``request`` is accepted but not used.
    """
    if payload.event != "recording_completed":
        return JSONResponse(
            status_code=400,
            content={"error": f"Unknown event type: {payload.event}"}
        )

    # recording_name format: meetingId_timestamp. partition() yields the whole
    # name when there is no underscore, so no extra fallback is needed.
    meeting_id = payload.recording_name.partition("_")[0]

    recording_id = str(uuid.uuid4())
    # One timestamp for all three fields so they cannot drift apart.
    now_iso = datetime.utcnow().isoformat()

    _recordings_store[recording_id] = {
        "id": recording_id,
        "meeting_id": meeting_id,
        "jibri_session_id": payload.recording_name,
        "title": f"Recording {meeting_id}",
        "storage_path": payload.storage_path,
        "audio_path": payload.audio_path,
        "file_size_bytes": payload.file_size_bytes,
        "duration_seconds": None,  # Will be updated after analysis
        "participant_count": 0,
        "status": "uploaded",
        "recorded_at": now_iso,
        "retention_days": DEFAULT_RETENTION_DAYS,
        "created_at": now_iso,
        "updated_at": now_iso
    }

    # Record the creation in the audit log
    log_audit(
        action="created",
        recording_id=recording_id,
        metadata={
            "source": "jibri_webhook",
            "storage_path": payload.storage_path,
            "file_size_bytes": payload.file_size_bytes
        }
    )

    return {
        "success": True,
        "recording_id": recording_id,
        "meeting_id": meeting_id,
        "status": "uploaded",
        "message": "Recording registered successfully"
    }
# ==========================================
# HEALTH & AUDIT ENDPOINTS (must be before parameterized routes)
# ==========================================
@router.get("/health")
async def recordings_health():
    """Report service health with store sizes and the MinIO target."""
    report = {"status": "healthy"}
    report["recordings_count"] = len(_recordings_store)
    report["transcriptions_count"] = len(_transcriptions_store)
    report["minio_endpoint"] = MINIO_ENDPOINT
    report["bucket"] = MINIO_BUCKET
    return report
@router.get("/audit/log")
async def get_audit_log(
    recording_id: Optional[str] = Query(None),
    action: Optional[str] = Query(None),
    limit: int = Query(100, ge=1, le=1000)
):
    """
    Get audit log entries (DSGVO compliance).

    Optional filters on recording ID and action are applied, the result is
    ordered newest first, and at most ``limit`` entries are returned together
    with the total number of matches.

    NOTE(review): described as admin-only, but no authorization check is
    visible in this handler -- confirm it is enforced elsewhere.
    """
    matches = [
        entry for entry in _audit_log
        if (not recording_id or entry.get("recording_id") == recording_id)
        and (not action or entry.get("action") == action)
    ]

    # Newest entries first
    ordered = sorted(matches, key=lambda entry: entry["created_at"], reverse=True)

    return {
        "entries": ordered[:limit],
        "total": len(ordered)
    }
# ==========================================
# RECORDING MANAGEMENT ENDPOINTS
# ==========================================
@router.get("", response_model=RecordingListResponse)
async def list_recordings(
    status: Optional[str] = Query(None, description="Filter by status"),
    meeting_id: Optional[str] = Query(None, description="Filter by meeting ID"),
    page: int = Query(1, ge=1, description="Page number"),
    page_size: int = Query(20, ge=1, le=100, description="Items per page")
):
    """
    List recordings, newest first, with optional filtering and pagination.

    ``total`` in the response counts all matches, not just the returned page.
    """
    matching = [
        rec for rec in _recordings_store.values()
        if (not status or rec["status"] == status)
        and (not meeting_id or rec["meeting_id"] == meeting_id)
    ]
    ordered = sorted(matching, key=lambda rec: rec["recorded_at"], reverse=True)

    # Cut out the requested page
    offset = (page - 1) * page_size
    window = ordered[offset:offset + page_size]

    items = []
    for rec in window:
        recorded = datetime.fromisoformat(rec["recorded_at"])
        expires = recorded + timedelta(days=rec["retention_days"])

        # First transcription that references this recording, if any
        linked = next(
            (t for t in _transcriptions_store.values() if t["recording_id"] == rec["id"]),
            None
        )
        trans_status = linked["status"] if linked else None
        trans_id = linked["id"] if linked else None

        items.append(RecordingResponse(
            id=rec["id"],
            meeting_id=rec["meeting_id"],
            title=rec.get("title"),
            storage_path=rec["storage_path"],
            audio_path=rec.get("audio_path"),
            file_size_bytes=rec.get("file_size_bytes"),
            duration_seconds=rec.get("duration_seconds"),
            participant_count=rec.get("participant_count", 0),
            status=rec["status"],
            recorded_at=recorded,
            retention_days=rec["retention_days"],
            retention_expires_at=expires,
            transcription_status=trans_status,
            transcription_id=trans_id
        ))

    return RecordingListResponse(
        recordings=items,
        total=len(ordered),
        page=page,
        page_size=page_size
    )
@router.get("/{recording_id}", response_model=RecordingResponse)
async def get_recording(recording_id: str):
    """
    Return the details of one recording and audit the access as "viewed".

    Raises HTTP 404 when the recording ID is unknown.
    """
    rec = _recordings_store.get(recording_id)
    if not rec:
        raise HTTPException(status_code=404, detail="Recording not found")

    # Every successful lookup is audited
    log_audit(action="viewed", recording_id=recording_id)

    recorded = datetime.fromisoformat(rec["recorded_at"])
    expires = recorded + timedelta(days=rec["retention_days"])

    # First transcription that references this recording, if any
    linked = next(
        (t for t in _transcriptions_store.values() if t["recording_id"] == recording_id),
        None
    )
    trans_status = linked["status"] if linked else None
    trans_id = linked["id"] if linked else None

    return RecordingResponse(
        id=rec["id"],
        meeting_id=rec["meeting_id"],
        title=rec.get("title"),
        storage_path=rec["storage_path"],
        audio_path=rec.get("audio_path"),
        file_size_bytes=rec.get("file_size_bytes"),
        duration_seconds=rec.get("duration_seconds"),
        participant_count=rec.get("participant_count", 0),
        status=rec["status"],
        recorded_at=recorded,
        retention_days=rec["retention_days"],
        retention_expires_at=expires,
        transcription_status=trans_status,
        transcription_id=trans_id
    )
@router.delete("/{recording_id}")
async def delete_recording(
    recording_id: str,
    reason: str = Query(..., description="Reason for deletion (DSGVO audit)")
):
    """
    Soft-delete a recording (DSGVO compliance).

    The entry stays in the store with status "deleted" and a deletion
    timestamp; the mandatory reason is written to the audit log.
    Raises HTTP 404 when the recording ID is unknown.
    """
    target = _recordings_store.get(recording_id)
    if not target:
        raise HTTPException(status_code=404, detail="Recording not found")

    # Mark as deleted in place; the entry itself is kept
    target.update(
        status="deleted",
        deleted_at=datetime.utcnow().isoformat(),
        updated_at=datetime.utcnow().isoformat(),
    )

    # Audit the deletion together with its reason
    log_audit(
        action="deleted",
        recording_id=recording_id,
        metadata={"reason": reason}
    )

    return {
        "success": True,
        "recording_id": recording_id,
        "status": "deleted",
        "message": "Recording marked for deletion"
    }
# ==========================================
# TRANSCRIPTION ENDPOINTS
# ==========================================
@router.post("/{recording_id}/transcribe", response_model=TranscriptionStatusResponse)
async def start_transcription(recording_id: str, request: TranscriptionRequest):
    """
    Create a pending transcription entry for a recording.

    Raises HTTP 404 for an unknown recording, 400 for a deleted one, and 409
    when a transcription that has not failed already exists. No worker job is
    queued yet (see TODO below).
    """
    recording = _recordings_store.get(recording_id)
    if not recording:
        raise HTTPException(status_code=404, detail="Recording not found")
    if recording["status"] == "deleted":
        raise HTTPException(status_code=400, detail="Cannot transcribe deleted recording")

    # Reject when a transcription that has not failed already exists
    for candidate in _transcriptions_store.values():
        if candidate["recording_id"] == recording_id and candidate["status"] != "failed":
            raise HTTPException(
                status_code=409,
                detail=f"Transcription already exists with status: {candidate['status']}"
            )

    transcription_id = str(uuid.uuid4())
    now = datetime.utcnow()
    stamp = now.isoformat()

    _transcriptions_store[transcription_id] = {
        "id": transcription_id,
        "recording_id": recording_id,
        "language": request.language,
        "model": request.model,
        "status": "pending",
        "full_text": None,
        "word_count": None,
        "confidence_score": None,
        "vtt_path": None,
        "srt_path": None,
        "json_path": None,
        "error_message": None,
        "processing_started_at": None,
        "processing_completed_at": None,
        "processing_duration_seconds": None,
        "created_at": stamp,
        "updated_at": stamp
    }

    # The recording itself moves to "processing"
    recording["status"] = "processing"
    recording["updated_at"] = stamp

    log_audit(
        action="transcription_started",
        recording_id=recording_id,
        transcription_id=transcription_id,
        metadata={"language": request.language, "model": request.model}
    )

    # TODO: Queue job to Redis/Valkey for transcription worker
    # from redis import Redis
    # from rq import Queue
    # q = Queue(connection=Redis.from_url(os.getenv("REDIS_URL")))
    # q.enqueue('transcription_worker.tasks.transcribe', transcription_id, ...)

    return TranscriptionStatusResponse(
        id=transcription_id,
        recording_id=recording_id,
        status="pending",
        language=request.language,
        model=request.model,
        word_count=None,
        confidence_score=None,
        processing_duration_seconds=None,
        error_message=None,
        created_at=now,
        completed_at=None
    )
@router.get("/{recording_id}/transcription", response_model=TranscriptionStatusResponse)
async def get_transcription_status(recording_id: str):
    """
    Return status details of the first transcription stored for a recording.

    Raises HTTP 404 when no transcription references the recording.
    """
    found = None
    for candidate in _transcriptions_store.values():
        if candidate["recording_id"] == recording_id:
            found = candidate
            break

    if not found:
        raise HTTPException(status_code=404, detail="No transcription found for this recording")

    # completed_at is only parsed when a completion timestamp is stored
    finished_raw = found.get("processing_completed_at")
    finished = datetime.fromisoformat(found["processing_completed_at"]) if finished_raw else None

    return TranscriptionStatusResponse(
        id=found["id"],
        recording_id=found["recording_id"],
        status=found["status"],
        language=found["language"],
        model=found["model"],
        word_count=found.get("word_count"),
        confidence_score=found.get("confidence_score"),
        processing_duration_seconds=found.get("processing_duration_seconds"),
        error_message=found.get("error_message"),
        created_at=datetime.fromisoformat(found["created_at"]),
        completed_at=finished
    )
@router.get("/{recording_id}/transcription/text")
async def get_transcription_text(recording_id: str):
    """
    Return the full text of a completed transcription.

    Raises HTTP 404 when no transcription exists and 400 while it is not in
    status "completed".
    """
    found = None
    for candidate in _transcriptions_store.values():
        if candidate["recording_id"] == recording_id:
            found = candidate
            break

    if not found:
        raise HTTPException(status_code=404, detail="No transcription found for this recording")

    current_status = found["status"]
    if current_status != "completed":
        raise HTTPException(
            status_code=400,
            detail=f"Transcription not ready. Status: {current_status}"
        )

    return {
        "transcription_id": found["id"],
        "recording_id": recording_id,
        "language": found["language"],
        "text": found.get("full_text", ""),
        "word_count": found.get("word_count", 0)
    }
@router.get("/{recording_id}/transcription/vtt")
async def get_transcription_vtt(recording_id: str):
    """
    Download a completed transcription as a WebVTT subtitle file.

    The cues are synthesized from the stored text: one cue per sentence with
    a fixed three-second duration. Raises HTTP 404 when no transcription
    exists and 400 while it is not in status "completed".
    """
    from fastapi.responses import PlainTextResponse

    found = None
    for candidate in _transcriptions_store.values():
        if candidate["recording_id"] == recording_id:
            found = candidate
            break

    if not found:
        raise HTTPException(status_code=404, detail="No transcription found for this recording")

    current_status = found["status"]
    if current_status != "completed":
        raise HTTPException(
            status_code=400,
            detail=f"Transcription not ready. Status: {current_status}"
        )

    # In production, this would read from the stored VTT file
    chunks = ["WEBVTT\n\n"]
    full_text = found.get("full_text", "")
    if full_text:
        # Break after every period, then drop blank fragments
        stripped = (raw.strip() for raw in full_text.replace(".", ".\n").split("\n"))
        for position, sentence in enumerate(line for line in stripped if line):
            # Estimate ~3 seconds per sentence
            begin = format_vtt_time(position * 3000)
            finish = format_vtt_time((position + 1) * 3000)
            chunks.append(f"{begin} --> {finish}\n{sentence}\n\n")

    return PlainTextResponse(
        content="".join(chunks),
        media_type="text/vtt",
        headers={"Content-Disposition": f"attachment; filename=transcript_{recording_id}.vtt"}
    )
@router.get("/{recording_id}/transcription/srt")
async def get_transcription_srt(recording_id: str):
    """
    Download a completed transcription as an SRT subtitle file.

    The cues are synthesized from the stored text: one numbered cue per
    sentence with a fixed three-second duration. Raises HTTP 404 when no
    transcription exists and 400 while it is not in status "completed".
    """
    from fastapi.responses import PlainTextResponse

    found = None
    for candidate in _transcriptions_store.values():
        if candidate["recording_id"] == recording_id:
            found = candidate
            break

    if not found:
        raise HTTPException(status_code=404, detail="No transcription found for this recording")

    current_status = found["status"]
    if current_status != "completed":
        raise HTTPException(
            status_code=400,
            detail=f"Transcription not ready. Status: {current_status}"
        )

    chunks = []
    full_text = found.get("full_text", "")
    if full_text:
        # Break after every period, then drop blank fragments
        stripped = (raw.strip() for raw in full_text.replace(".", ".\n").split("\n"))
        for position, sentence in enumerate(line for line in stripped if line):
            begin = format_srt_time(position * 3000)
            finish = format_srt_time((position + 1) * 3000)
            # SRT cue numbers start at 1
            chunks.append(f"{position + 1}\n{begin} --> {finish}\n{sentence}\n\n")

    return PlainTextResponse(
        content="".join(chunks),
        media_type="text/plain",
        headers={"Content-Disposition": f"attachment; filename=transcript_{recording_id}.srt"}
    )
def format_vtt_time(ms: int) -> str:
    """Render a millisecond count as a WebVTT timestamp (HH:MM:SS.mmm)."""
    # Peel off hours, minutes and seconds; what remains are milliseconds.
    hours, below_hour = divmod(ms, 3600000)
    minutes, below_minute = divmod(below_hour, 60000)
    seconds, millis = divmod(below_minute, 1000)
    return f"{hours:02d}:{minutes:02d}:{seconds:02d}.{millis:03d}"
def format_srt_time(ms: int) -> str:
    """Render a millisecond count as an SRT timestamp (HH:MM:SS,mmm)."""
    # Peel off hours, minutes and seconds; what remains are milliseconds.
    hours, below_hour = divmod(ms, 3600000)
    minutes, below_minute = divmod(below_hour, 60000)
    seconds, millis = divmod(below_minute, 1000)
    return f"{hours:02d}:{minutes:02d}:{seconds:02d},{millis:03d}"
@router.get("/{recording_id}/download")
async def download_recording(recording_id: str):
    """
    Return download information for a recording and audit the access.

    Currently only the storage location is reported; in production this
    would generate a presigned URL to MinIO. Raises HTTP 404 for an unknown
    recording and 410 for a deleted one.
    """
    rec = _recordings_store.get(recording_id)
    if not rec:
        raise HTTPException(status_code=404, detail="Recording not found")
    if rec["status"] == "deleted":
        raise HTTPException(status_code=410, detail="Recording has been deleted")

    # Every permitted download request is audited
    log_audit(action="downloaded", recording_id=recording_id)

    # Placeholder payload until presigned MinIO URLs are implemented
    info = {"recording_id": recording_id}
    info["storage_path"] = rec["storage_path"]
    info["file_size_bytes"] = rec.get("file_size_bytes")
    info["message"] = "In production, this would redirect to a presigned MinIO URL"
    return info
# ==========================================
# MEETING MINUTES ENDPOINTS
# ==========================================
# In-memory store for meeting minutes (dev mode).
# Keyed by recording_id; values are the serialized minutes dicts.
_minutes_store: dict = {}
@router.post("/{recording_id}/minutes")
async def generate_meeting_minutes(
    recording_id: str,
    title: Optional[str] = Query(None, description="Meeting-Titel"),
    model: str = Query("breakpilot-teacher-8b", description="LLM Modell")
):
    """
    Generate AI-based meeting minutes from a completed transcription.

    Uses the LLM gateway (Ollama/vLLM) for local processing. Minutes that are
    already stored with status "completed" are returned unchanged.

    Raises HTTP 404 for an unknown recording, 400 when the transcription is
    missing, not completed or empty, and 500 when generation fails.
    """
    from meeting_minutes_generator import get_minutes_generator

    # The recording must exist
    recording = _recordings_store.get(recording_id)
    if not recording:
        raise HTTPException(status_code=404, detail="Recording not found")

    # A completed transcription is required
    transcription = next(
        (t for t in _transcriptions_store.values() if t["recording_id"] == recording_id),
        None
    )
    if not transcription:
        raise HTTPException(status_code=400, detail="No transcription found. Please transcribe first.")
    if transcription["status"] != "completed":
        raise HTTPException(
            status_code=400,
            detail=f"Transcription not ready. Status: {transcription['status']}"
        )

    # Reuse minutes that were already generated
    existing = _minutes_store.get(recording_id)
    if existing and existing.get("status") == "completed":
        return existing

    transcript_text = transcription.get("full_text", "")
    if not transcript_text:
        raise HTTPException(status_code=400, detail="Transcription has no text content")

    generator = get_minutes_generator()
    recorded_at = recording.get("recorded_at")
    duration_seconds = recording.get("duration_seconds")

    try:
        minutes = await generator.generate(
            transcript=transcript_text,
            recording_id=recording_id,
            transcription_id=transcription["id"],
            title=title,
            # First ten characters of the ISO timestamp are the date part
            date=recorded_at[:10] if recorded_at else None,
            duration_minutes=duration_seconds // 60 if duration_seconds else None,
            participant_count=recording.get("participant_count", 0),
            model=model
        )

        # Store with generated_at as an ISO string
        minutes_dict = minutes.model_dump()
        minutes_dict["generated_at"] = minutes.generated_at.isoformat()
        _minutes_store[recording_id] = minutes_dict

        log_audit(
            action="minutes_generated",
            recording_id=recording_id,
            metadata={"model": model, "generation_time": minutes.generation_time_seconds}
        )

        return minutes_dict
    except Exception as e:
        # Chain the cause so the original traceback is kept for debugging
        raise HTTPException(status_code=500, detail=f"Minutes generation failed: {str(e)}") from e
@router.get("/{recording_id}/minutes")
async def get_meeting_minutes(recording_id: str):
    """
    Return the stored meeting minutes for a recording.

    Raises HTTP 404 when nothing has been generated yet.
    """
    stored = _minutes_store.get(recording_id)
    if stored:
        return stored
    raise HTTPException(status_code=404, detail="No meeting minutes found. Generate them first with POST.")
@router.get("/{recording_id}/minutes/markdown")
async def get_minutes_markdown(recording_id: str):
    """
    Export the stored meeting minutes as a Markdown attachment.

    Raises HTTP 404 when no minutes are stored for the recording.
    """
    from fastapi.responses import PlainTextResponse
    from meeting_minutes_generator import minutes_to_markdown, MeetingMinutes

    stored = _minutes_store.get(recording_id)
    if not stored:
        raise HTTPException(status_code=404, detail="No meeting minutes found")

    # Rebuild the model from a shallow copy; generated_at is stored as a string
    payload = dict(stored)
    generated_at = payload.get("generated_at")
    if isinstance(generated_at, str):
        payload["generated_at"] = datetime.fromisoformat(generated_at)

    rendered = minutes_to_markdown(MeetingMinutes(**payload))
    return PlainTextResponse(
        content=rendered,
        media_type="text/markdown",
        headers={"Content-Disposition": f"attachment; filename=protokoll_{recording_id}.md"}
    )
@router.get("/{recording_id}/minutes/html")
async def get_minutes_html(recording_id: str):
    """
    Export the stored meeting minutes as an HTML page.

    Raises HTTP 404 when no minutes are stored for the recording.
    """
    from fastapi.responses import HTMLResponse
    from meeting_minutes_generator import minutes_to_html, MeetingMinutes

    stored = _minutes_store.get(recording_id)
    if not stored:
        raise HTTPException(status_code=404, detail="No meeting minutes found")

    # Rebuild the model from a shallow copy; generated_at is stored as a string
    payload = dict(stored)
    generated_at = payload.get("generated_at")
    if isinstance(generated_at, str):
        payload["generated_at"] = datetime.fromisoformat(generated_at)

    return HTMLResponse(content=minutes_to_html(MeetingMinutes(**payload)))
@router.get("/{recording_id}/minutes/pdf")
async def get_minutes_pdf(recording_id: str):
    """
    Export the stored meeting minutes as a PDF attachment.

    Requires WeasyPrint (pip install weasyprint). Raises HTTP 404 when no
    minutes are stored and 501 when WeasyPrint cannot be imported.
    """
    from meeting_minutes_generator import minutes_to_html, MeetingMinutes

    minutes_dict = _minutes_store.get(recording_id)
    if not minutes_dict:
        raise HTTPException(status_code=404, detail="No meeting minutes found")

    # Convert dict back to MeetingMinutes; generated_at is stored as a string
    minutes_dict_copy = minutes_dict.copy()
    if isinstance(minutes_dict_copy.get("generated_at"), str):
        minutes_dict_copy["generated_at"] = datetime.fromisoformat(minutes_dict_copy["generated_at"])

    minutes = MeetingMinutes(**minutes_dict_copy)
    html = minutes_to_html(minutes)

    # Guard only the optional import, so that an ImportError raised while
    # rendering is not misreported as "weasyprint is not installed".
    try:
        from weasyprint import HTML
    except ImportError as exc:
        raise HTTPException(
            status_code=501,
            detail="PDF export not available. Install weasyprint: pip install weasyprint"
        ) from exc

    from fastapi.responses import Response

    pdf_bytes = HTML(string=html).write_pdf()
    return Response(
        content=pdf_bytes,
        media_type="application/pdf",
        headers={"Content-Disposition": f"attachment; filename=protokoll_{recording_id}.pdf"}
    )
# Mount the three split sub-routers on this module's router.
router.include_router(_routes_router)
router.include_router(_transcription_router)
router.include_router(_minutes_router)
+57
View File
@@ -0,0 +1,57 @@
"""
Recording API - In-Memory Storage & Helpers.
Shared state and utility functions for recording endpoints.
"""
import uuid
from datetime import datetime
from typing import Optional
# ==========================================
# IN-MEMORY STORAGE (Dev Mode)
# ==========================================
# In production, these would be database queries.
# Module-level containers: contents live only as long as the process does.
_recordings_store: dict = {}      # recording_id -> recording dict
_transcriptions_store: dict = {}  # transcription_id -> transcription dict
_audit_log: list = []             # audit entries appended by log_audit()
_minutes_store: dict = {}         # recording_id -> serialized minutes dict
def log_audit(
    action: str,
    recording_id: Optional[str] = None,
    transcription_id: Optional[str] = None,
    user_id: Optional[str] = None,
    metadata: Optional[dict] = None
):
    """Append one audit entry (DSGVO compliance) to the in-memory audit log.

    Each entry receives a fresh UUID and a naive-UTC ISO timestamp. A missing
    or empty ``metadata`` is stored as an empty dict.
    """
    entry = {
        "id": str(uuid.uuid4()),
        "recording_id": recording_id,
        "transcription_id": transcription_id,
        "user_id": user_id,
        "action": action,
        "metadata": metadata if metadata else {},
        "created_at": datetime.utcnow().isoformat(),
    }
    _audit_log.append(entry)
def format_vtt_time(ms: int) -> str:
    """Render a millisecond count as a WebVTT timestamp (HH:MM:SS.mmm)."""
    # Peel off hours, minutes and seconds; what remains are milliseconds.
    hours, below_hour = divmod(ms, 3600000)
    minutes, below_minute = divmod(below_hour, 60000)
    seconds, millis = divmod(below_minute, 1000)
    return f"{hours:02d}:{minutes:02d}:{seconds:02d}.{millis:03d}"
def format_srt_time(ms: int) -> str:
    """Render a millisecond count as an SRT timestamp (HH:MM:SS,mmm)."""
    # Peel off hours, minutes and seconds; what remains are milliseconds.
    hours, below_hour = divmod(ms, 3600000)
    minutes, below_minute = divmod(below_hour, 60000)
    seconds, millis = divmod(below_minute, 1000)
    return f"{hours:02d}:{minutes:02d}:{seconds:02d},{millis:03d}"
+187
View File
@@ -0,0 +1,187 @@
"""
Recording API - Meeting Minutes Routes.
Generate, retrieve, and export KI-based meeting minutes.
"""
from datetime import datetime
from typing import Optional
from fastapi import APIRouter, HTTPException, Query
from fastapi.responses import PlainTextResponse, HTMLResponse
from recording_helpers import (
_recordings_store,
_transcriptions_store,
_minutes_store,
log_audit,
)
# Sub-router without its own prefix; routes below are relative paths.
router = APIRouter(tags=["Recordings"])
# ==========================================
# MEETING MINUTES ENDPOINTS
# ==========================================
@router.post("/{recording_id}/minutes")
async def generate_meeting_minutes(
    recording_id: str,
    title: Optional[str] = Query(None, description="Meeting-Titel"),
    model: str = Query("breakpilot-teacher-8b", description="LLM Modell")
):
    """
    Generate AI-based meeting minutes from a completed transcription.

    Uses the LLM gateway (Ollama/vLLM) for local processing. Minutes that are
    already stored with status "completed" are returned unchanged.

    Raises HTTP 404 for an unknown recording, 400 when the transcription is
    missing, not completed or empty, and 500 when generation fails.
    """
    from meeting_minutes_generator import get_minutes_generator

    # The recording must exist
    recording = _recordings_store.get(recording_id)
    if not recording:
        raise HTTPException(status_code=404, detail="Recording not found")

    # A completed transcription is required
    transcription = next(
        (t for t in _transcriptions_store.values() if t["recording_id"] == recording_id),
        None
    )
    if not transcription:
        raise HTTPException(status_code=400, detail="No transcription found. Please transcribe first.")
    if transcription["status"] != "completed":
        raise HTTPException(
            status_code=400,
            detail=f"Transcription not ready. Status: {transcription['status']}"
        )

    # Reuse minutes that were already generated
    existing = _minutes_store.get(recording_id)
    if existing and existing.get("status") == "completed":
        return existing

    transcript_text = transcription.get("full_text", "")
    if not transcript_text:
        raise HTTPException(status_code=400, detail="Transcription has no text content")

    generator = get_minutes_generator()
    recorded_at = recording.get("recorded_at")
    duration_seconds = recording.get("duration_seconds")

    try:
        minutes = await generator.generate(
            transcript=transcript_text,
            recording_id=recording_id,
            transcription_id=transcription["id"],
            title=title,
            # First ten characters of the ISO timestamp are the date part
            date=recorded_at[:10] if recorded_at else None,
            duration_minutes=duration_seconds // 60 if duration_seconds else None,
            participant_count=recording.get("participant_count", 0),
            model=model
        )

        # Store with generated_at as an ISO string
        minutes_dict = minutes.model_dump()
        minutes_dict["generated_at"] = minutes.generated_at.isoformat()
        _minutes_store[recording_id] = minutes_dict

        log_audit(
            action="minutes_generated",
            recording_id=recording_id,
            metadata={"model": model, "generation_time": minutes.generation_time_seconds}
        )

        return minutes_dict
    except Exception as e:
        # Chain the cause so the original traceback is kept for debugging
        raise HTTPException(status_code=500, detail=f"Minutes generation failed: {str(e)}") from e
@router.get("/{recording_id}/minutes")
async def get_meeting_minutes(recording_id: str):
    """
    Return the stored meeting minutes for a recording.

    Raises HTTP 404 when nothing has been generated yet.
    """
    stored = _minutes_store.get(recording_id)
    if stored:
        return stored
    raise HTTPException(status_code=404, detail="No meeting minutes found. Generate them first with POST.")
def _load_minutes(recording_id: str):
    """Rebuild a MeetingMinutes object from the stored minutes dict.

    Raises HTTP 404 when no minutes are stored for the recording. The stored
    dict is not modified; conversion happens on a shallow copy.
    """
    from meeting_minutes_generator import MeetingMinutes

    stored = _minutes_store.get(recording_id)
    if not stored:
        raise HTTPException(status_code=404, detail="No meeting minutes found")

    # generated_at is stored as an ISO string and parsed back to a datetime
    payload = dict(stored)
    generated_at = payload.get("generated_at")
    if isinstance(generated_at, str):
        payload["generated_at"] = datetime.fromisoformat(generated_at)

    return MeetingMinutes(**payload)
@router.get("/{recording_id}/minutes/markdown")
async def get_minutes_markdown(recording_id: str):
    """
    Export the stored meeting minutes as a Markdown attachment.
    """
    from meeting_minutes_generator import minutes_to_markdown

    rendered = minutes_to_markdown(_load_minutes(recording_id))
    disposition = f"attachment; filename=protokoll_{recording_id}.md"

    return PlainTextResponse(
        content=rendered,
        media_type="text/markdown",
        headers={"Content-Disposition": disposition}
    )
@router.get("/{recording_id}/minutes/html")
async def get_minutes_html(recording_id: str):
    """
    Export the stored meeting minutes as an HTML page.
    """
    from meeting_minutes_generator import minutes_to_html

    rendered = minutes_to_html(_load_minutes(recording_id))
    return HTMLResponse(content=rendered)
@router.get("/{recording_id}/minutes/pdf")
async def get_minutes_pdf(recording_id: str):
    """
    Export the stored meeting minutes as a PDF attachment.

    Requires WeasyPrint (pip install weasyprint). Raises HTTP 501 when
    WeasyPrint cannot be imported.
    """
    from meeting_minutes_generator import minutes_to_html

    minutes = _load_minutes(recording_id)
    html = minutes_to_html(minutes)

    # Guard only the optional import, so that an ImportError raised while
    # rendering is not misreported as "weasyprint is not installed".
    try:
        from weasyprint import HTML
    except ImportError as exc:
        raise HTTPException(
            status_code=501,
            detail="PDF export not available. Install weasyprint: pip install weasyprint"
        ) from exc

    from fastapi.responses import Response

    pdf_bytes = HTML(string=html).write_pdf()
    return Response(
        content=pdf_bytes,
        media_type="application/pdf",
        headers={"Content-Disposition": f"attachment; filename=protokoll_{recording_id}.pdf"}
    )
+98
View File
@@ -0,0 +1,98 @@
"""
Recording API - Pydantic Models & Configuration.
Data models for recording, transcription, and webhook endpoints.
"""
import os
from datetime import datetime
from typing import Optional, List
from pydantic import BaseModel, Field
# ==========================================
# ENVIRONMENT CONFIGURATION
# ==========================================
# MinIO connection settings; every value can be overridden via environment.
MINIO_ENDPOINT = os.getenv("MINIO_ENDPOINT", "minio:9000")
MINIO_ACCESS_KEY = os.getenv("MINIO_ACCESS_KEY", "breakpilot")
# NOTE(review): the fallback is a hard-coded credential; deployments should
# always set MINIO_SECRET_KEY explicitly instead of relying on this default.
MINIO_SECRET_KEY = os.getenv("MINIO_SECRET_KEY", "breakpilot123")
MINIO_BUCKET = os.getenv("MINIO_BUCKET", "breakpilot-recordings")
# Only the string "true" (any casing) yields True; everything else is False.
MINIO_SECURE = os.getenv("MINIO_SECURE", "false").lower() == "true"
# Default retention period in days (DSGVO compliance).
# int() raises ValueError at import time if the variable is not numeric.
DEFAULT_RETENTION_DAYS = int(os.getenv("RECORDING_RETENTION_DAYS", "365"))
# ==========================================
# PYDANTIC MODELS
# ==========================================
class JibriWebhookPayload(BaseModel):
    """Webhook payload from Jibri finalize.sh script."""
    # All fields are required except audio_path, which defaults to None.
    event: str = Field(..., description="Event type: recording_completed")
    recording_name: str = Field(..., description="Unique recording identifier")
    storage_path: str = Field(..., description="Path in MinIO bucket")
    audio_path: Optional[str] = Field(None, description="Extracted audio path")
    file_size_bytes: int = Field(..., description="Video file size in bytes")
    # Carried as a plain string; no datetime parsing is applied by the model.
    timestamp: str = Field(..., description="ISO timestamp of upload")
class RecordingCreate(BaseModel):
    """Manual recording creation (for testing).

    Only ``meeting_id`` and ``storage_path`` are required; the remaining
    fields fall back to the defaults declared below.
    """
    meeting_id: str
    title: Optional[str] = None
    storage_path: str
    audio_path: Optional[str] = None
    duration_seconds: Optional[int] = None
    participant_count: Optional[int] = 0
    # Falls back to the module-level DEFAULT_RETENTION_DAYS setting.
    retention_days: Optional[int] = DEFAULT_RETENTION_DAYS
class RecordingResponse(BaseModel):
    """Recording details response.

    Describes one recording plus, when available, the status and id of a
    transcription associated with it.
    """
    id: str
    meeting_id: str
    title: Optional[str]
    storage_path: str
    audio_path: Optional[str]
    file_size_bytes: Optional[int]
    duration_seconds: Optional[int]
    participant_count: int
    status: str
    recorded_at: datetime
    retention_days: int
    retention_expires_at: datetime
    # Both transcription fields stay None when no transcription is attached.
    transcription_status: Optional[str] = None
    transcription_id: Optional[str] = None
class RecordingListResponse(BaseModel):
    """Paginated list of recordings.

    ``recordings`` holds the items of one page; ``page`` and ``page_size``
    describe that page and ``total`` carries an overall count.
    """
    recordings: List[RecordingResponse]
    total: int
    page: int
    page_size: int
class TranscriptionRequest(BaseModel):
    """Request to start transcription.

    Every field has a default, so an empty request body is valid.
    """
    language: str = Field(default="de", description="Language code: de, en, etc.")
    model: str = Field(default="large-v3", description="Whisper model to use")
    priority: int = Field(default=0, description="Queue priority (higher = sooner)")
class TranscriptionStatusResponse(BaseModel):
    """Transcription status and progress.

    Result fields (word count, confidence, duration, error, completion time)
    are Optional and may be None.
    """
    id: str
    recording_id: str
    status: str
    language: str
    model: str
    word_count: Optional[int]
    confidence_score: Optional[float]
    processing_duration_seconds: Optional[int]
    error_message: Optional[str]
    created_at: datetime
    completed_at: Optional[datetime]
+307
View File
@@ -0,0 +1,307 @@
"""
Recording API - Core Recording Routes.
Webhook, CRUD, health, audit, and download endpoints.
"""
import uuid
from datetime import datetime, timedelta
from typing import Optional
from fastapi import APIRouter, HTTPException, Query, Request
from fastapi.responses import JSONResponse
from recording_models import (
JibriWebhookPayload,
RecordingResponse,
RecordingListResponse,
MINIO_ENDPOINT,
MINIO_BUCKET,
DEFAULT_RETENTION_DAYS,
)
from recording_helpers import (
_recordings_store,
_transcriptions_store,
_audit_log,
log_audit,
)
router = APIRouter(tags=["Recordings"])
# ==========================================
# WEBHOOK ENDPOINT (Jibri)
# ==========================================
@router.post("/webhook")
async def jibri_webhook(payload: JibriWebhookPayload, request: Request):
    """
    Register a recording after Jibri's finalize.sh has uploaded it.

    Creates a new entry in the recordings store and writes a "created"
    audit record. Transcription is NOT started here.

    Args:
        payload: Webhook body; only the ``recording_completed`` event is accepted.
        request: The incoming request (not used by the handler).

    Returns:
        A dict with the new ``recording_id``, the derived ``meeting_id`` and
        status ``uploaded``; a 400 JSONResponse for any other event type.
    """
    if payload.event != "recording_completed":
        return JSONResponse(
            status_code=400,
            content={"error": f"Unknown event type: {payload.event}"}
        )

    # Extract meeting_id from recording_name (format: meetingId_timestamp).
    # str.split always yields at least one element, so no fallback is needed.
    # NOTE(review): a meeting id that itself contains "_" is cut at its first
    # underscore - confirm ids never contain one.
    meeting_id = payload.recording_name.split("_")[0]

    recording_id = str(uuid.uuid4())
    # Read the clock once so recorded_at, created_at and updated_at agree.
    now_iso = datetime.utcnow().isoformat()

    recording = {
        "id": recording_id,
        "meeting_id": meeting_id,
        "jibri_session_id": payload.recording_name,
        "title": f"Recording {meeting_id}",
        "storage_path": payload.storage_path,
        "audio_path": payload.audio_path,
        "file_size_bytes": payload.file_size_bytes,
        "duration_seconds": None,  # Will be updated after analysis
        "participant_count": 0,
        "status": "uploaded",
        "recorded_at": now_iso,
        "retention_days": DEFAULT_RETENTION_DAYS,
        "created_at": now_iso,
        "updated_at": now_iso
    }
    _recordings_store[recording_id] = recording

    # Log the creation
    log_audit(
        action="created",
        recording_id=recording_id,
        metadata={
            "source": "jibri_webhook",
            "storage_path": payload.storage_path,
            "file_size_bytes": payload.file_size_bytes
        }
    )

    return {
        "success": True,
        "recording_id": recording_id,
        "meeting_id": meeting_id,
        "status": "uploaded",
        "message": "Recording registered successfully"
    }
# ==========================================
# HEALTH & AUDIT ENDPOINTS (must be before parameterized routes)
# ==========================================
@router.get("/health")
async def recordings_health():
    """Report service health together with store sizes and MinIO settings."""
    recordings_total = len(_recordings_store)
    transcriptions_total = len(_transcriptions_store)
    return {
        "status": "healthy",
        "recordings_count": recordings_total,
        "transcriptions_count": transcriptions_total,
        "minio_endpoint": MINIO_ENDPOINT,
        "bucket": MINIO_BUCKET
    }
@router.get("/audit/log")
async def get_audit_log(
    recording_id: Optional[str] = Query(None),
    action: Optional[str] = Query(None),
    limit: int = Query(100, ge=1, le=1000)
):
    """
    Get audit log entries (DSGVO compliance).

    Admin-only endpoint for reviewing recording access history.

    Entries can be narrowed by ``recording_id`` and/or ``action``. They are
    returned newest first, truncated to ``limit``; ``total`` counts all
    matching entries before truncation.
    """
    def _selected(entry):
        # An unset (falsy) filter matches every entry.
        if recording_id and entry.get("recording_id") != recording_id:
            return False
        if action and entry.get("action") != action:
            return False
        return True

    # Sort by created_at descending
    matching = sorted(
        (entry for entry in _audit_log if _selected(entry)),
        key=lambda entry: entry["created_at"],
        reverse=True,
    )

    return {
        "entries": matching[:limit],
        "total": len(matching)
    }
# ==========================================
# RECORDING MANAGEMENT ENDPOINTS
# ==========================================
@router.get("", response_model=RecordingListResponse)
async def list_recordings(
    status: Optional[str] = Query(None, description="Filter by status"),
    meeting_id: Optional[str] = Query(None, description="Filter by meeting ID"),
    page: int = Query(1, ge=1, description="Page number"),
    page_size: int = Query(20, ge=1, le=100, description="Items per page")
):
    """
    List all recordings with optional filtering.

    Supports pagination and filtering by status or meeting ID.

    Recordings are ordered by ``recorded_at`` descending. Each item carries
    the status/id of the most recently created transcription for that
    recording, so a retried transcription is reported rather than an earlier
    failed attempt.
    """
    # Filter recordings
    recordings = list(_recordings_store.values())
    if status:
        recordings = [r for r in recordings if r["status"] == status]
    if meeting_id:
        recordings = [r for r in recordings if r["meeting_id"] == meeting_id]

    # Sort by recorded_at descending
    recordings.sort(key=lambda x: x["recorded_at"], reverse=True)

    # Paginate
    total = len(recordings)
    start = (page - 1) * page_size
    page_recordings = recordings[start:start + page_size]

    # Index the newest transcription per recording in one pass instead of
    # scanning the whole transcription store once per row.
    page_ids = {rec["id"] for rec in page_recordings}
    latest_trans = {}
    for candidate in _transcriptions_store.values():
        rid = candidate["recording_id"]
        if rid not in page_ids:
            continue
        current = latest_trans.get(rid)
        # ">=" lets the later-stored entry win when timestamps tie.
        if current is None or (candidate.get("created_at") or "") >= (current.get("created_at") or ""):
            latest_trans[rid] = candidate

    # Convert to response format
    result = []
    for rec in page_recordings:
        recorded_at = datetime.fromisoformat(rec["recorded_at"])
        retention_expires = recorded_at + timedelta(days=rec["retention_days"])
        trans = latest_trans.get(rec["id"])

        result.append(RecordingResponse(
            id=rec["id"],
            meeting_id=rec["meeting_id"],
            title=rec.get("title"),
            storage_path=rec["storage_path"],
            audio_path=rec.get("audio_path"),
            file_size_bytes=rec.get("file_size_bytes"),
            duration_seconds=rec.get("duration_seconds"),
            participant_count=rec.get("participant_count", 0),
            status=rec["status"],
            recorded_at=recorded_at,
            retention_days=rec["retention_days"],
            retention_expires_at=retention_expires,
            transcription_status=trans["status"] if trans else None,
            transcription_id=trans["id"] if trans else None
        ))

    return RecordingListResponse(
        recordings=result,
        total=total,
        page=page,
        page_size=page_size
    )
@router.get("/{recording_id}", response_model=RecordingResponse)
async def get_recording(recording_id: str):
    """
    Get details for a specific recording.

    Writes a "viewed" audit record. The response carries the status/id of
    the most recently created transcription for the recording, so a retry is
    reported instead of an earlier failed attempt.

    Raises:
        HTTPException: 404 if the recording id is unknown.
    """
    recording = _recordings_store.get(recording_id)
    if not recording:
        raise HTTPException(status_code=404, detail="Recording not found")

    # Log view action
    log_audit(action="viewed", recording_id=recording_id)

    recorded_at = datetime.fromisoformat(recording["recorded_at"])
    retention_expires = recorded_at + timedelta(days=recording["retention_days"])

    # Pick the newest transcription; a failed first attempt must not mask a retry.
    trans = max(
        (t for t in _transcriptions_store.values() if t["recording_id"] == recording_id),
        key=lambda t: t.get("created_at") or "",
        default=None,
    )

    return RecordingResponse(
        id=recording["id"],
        meeting_id=recording["meeting_id"],
        title=recording.get("title"),
        storage_path=recording["storage_path"],
        audio_path=recording.get("audio_path"),
        file_size_bytes=recording.get("file_size_bytes"),
        duration_seconds=recording.get("duration_seconds"),
        participant_count=recording.get("participant_count", 0),
        status=recording["status"],
        recorded_at=recorded_at,
        retention_days=recording["retention_days"],
        retention_expires_at=retention_expires,
        transcription_status=trans["status"] if trans else None,
        transcription_id=trans["id"] if trans else None
    )
@router.delete("/{recording_id}")
async def delete_recording(
    recording_id: str,
    reason: str = Query(..., description="Reason for deletion (DSGVO audit)")
):
    """
    Soft-delete a recording (DSGVO compliance).

    The recording is marked as deleted but retained for audit purposes.
    Actual file deletion happens after the audit retention period.

    The entry stays in the store with status ``deleted``; the given reason
    is written to the audit log.

    Raises:
        HTTPException: 404 if the recording id is unknown.
    """
    recording = _recordings_store.get(recording_id)
    if not recording:
        raise HTTPException(status_code=404, detail="Recording not found")

    # Soft delete; read the clock once so deleted_at and updated_at are identical.
    now_iso = datetime.utcnow().isoformat()
    recording["status"] = "deleted"
    recording["deleted_at"] = now_iso
    recording["updated_at"] = now_iso

    # Log deletion with reason
    log_audit(
        action="deleted",
        recording_id=recording_id,
        metadata={"reason": reason}
    )

    return {
        "success": True,
        "recording_id": recording_id,
        "status": "deleted",
        "message": "Recording marked for deletion"
    }
@router.get("/{recording_id}/download")
async def download_recording(recording_id: str):
    """
    Download the recording file.

    In production, this would generate a presigned URL to MinIO. For now the
    handler only reports where the file is stored and writes a "downloaded"
    audit record.

    Raises:
        HTTPException: 404 if the recording id is unknown, 410 if the
            recording has been soft-deleted.
    """
    entry = _recordings_store.get(recording_id)
    if not entry:
        raise HTTPException(status_code=404, detail="Recording not found")
    if entry["status"] == "deleted":
        raise HTTPException(status_code=410, detail="Recording has been deleted")

    # Log download action
    log_audit(action="downloaded", recording_id=recording_id)

    # Placeholder payload until presigned MinIO URLs are generated.
    location_info = {
        "recording_id": recording_id,
        "storage_path": entry["storage_path"],
        "file_size_bytes": entry.get("file_size_bytes"),
        "message": "In production, this would redirect to a presigned MinIO URL"
    }
    return location_info
+250
View File
@@ -0,0 +1,250 @@
"""
Recording API - Transcription Routes.
Start transcription, get status, download VTT/SRT subtitle files.
"""
import uuid
from datetime import datetime
from typing import Optional
from fastapi import APIRouter, HTTPException
from fastapi.responses import PlainTextResponse
from recording_models import (
TranscriptionRequest,
TranscriptionStatusResponse,
)
from recording_helpers import (
_recordings_store,
_transcriptions_store,
log_audit,
format_vtt_time,
format_srt_time,
)
router = APIRouter(tags=["Recordings"])
# ==========================================
# TRANSCRIPTION ENDPOINTS
# ==========================================
@router.post("/{recording_id}/transcribe", response_model=TranscriptionStatusResponse)
async def start_transcription(recording_id: str, request: TranscriptionRequest):
    """
    Start transcription for a recording.

    Creates a ``pending`` transcription entry, sets the recording status to
    ``processing`` and writes an audit record. Dispatching to a worker is
    still a TODO.

    Raises:
        HTTPException: 404 if the recording is unknown, 400 if it has been
            deleted, 409 if a non-failed transcription already exists.
    """
    recording = _recordings_store.get(recording_id)
    if not recording:
        raise HTTPException(status_code=404, detail="Recording not found")
    if recording["status"] == "deleted":
        raise HTTPException(status_code=400, detail="Cannot transcribe deleted recording")

    # Refuse when any non-failed transcription is already attached; only
    # failed attempts may be retried.
    for existing in _transcriptions_store.values():
        if existing["recording_id"] == recording_id and existing["status"] != "failed":
            raise HTTPException(
                status_code=409,
                detail=f"Transcription already exists with status: {existing['status']}"
            )

    # Create transcription entry
    transcription_id = str(uuid.uuid4())
    now = datetime.utcnow()
    stamp = now.isoformat()

    _transcriptions_store[transcription_id] = {
        "id": transcription_id,
        "recording_id": recording_id,
        "language": request.language,
        "model": request.model,
        "status": "pending",
        "full_text": None,
        "word_count": None,
        "confidence_score": None,
        "vtt_path": None,
        "srt_path": None,
        "json_path": None,
        "error_message": None,
        "processing_started_at": None,
        "processing_completed_at": None,
        "processing_duration_seconds": None,
        "created_at": stamp,
        "updated_at": stamp
    }

    # Update recording status
    recording["status"] = "processing"
    recording["updated_at"] = stamp

    # Log transcription start
    log_audit(
        action="transcription_started",
        recording_id=recording_id,
        transcription_id=transcription_id,
        metadata={"language": request.language, "model": request.model}
    )

    # TODO: Queue job to Redis/Valkey for transcription worker

    return TranscriptionStatusResponse(
        id=transcription_id,
        recording_id=recording_id,
        status="pending",
        language=request.language,
        model=request.model,
        word_count=None,
        confidence_score=None,
        processing_duration_seconds=None,
        error_message=None,
        created_at=now,
        completed_at=None
    )
@router.get("/{recording_id}/transcription", response_model=TranscriptionStatusResponse)
async def get_transcription_status(recording_id: str):
    """
    Get transcription status for a recording.

    When several transcriptions exist (a failed attempt followed by a
    retry), the most recently created one is reported.

    Raises:
        HTTPException: 404 if the recording has no transcription.
    """
    # Newest first: a failed first attempt must not hide a later retry.
    transcription = max(
        (t for t in _transcriptions_store.values() if t["recording_id"] == recording_id),
        key=lambda t: t.get("created_at") or "",
        default=None,
    )
    if not transcription:
        raise HTTPException(status_code=404, detail="No transcription found for this recording")

    completed_raw = transcription.get("processing_completed_at")

    return TranscriptionStatusResponse(
        id=transcription["id"],
        recording_id=transcription["recording_id"],
        status=transcription["status"],
        language=transcription["language"],
        model=transcription["model"],
        word_count=transcription.get("word_count"),
        confidence_score=transcription.get("confidence_score"),
        processing_duration_seconds=transcription.get("processing_duration_seconds"),
        error_message=transcription.get("error_message"),
        created_at=datetime.fromisoformat(transcription["created_at"]),
        completed_at=datetime.fromisoformat(completed_raw) if completed_raw else None
    )
@router.get("/{recording_id}/transcription/text")
async def get_transcription_text(recording_id: str):
    """
    Get the full transcription text.

    Uses the most recently created transcription of the recording, so a
    completed retry is served even if an earlier attempt failed.

    Raises:
        HTTPException: 404 if no transcription exists, 400 if the selected
            transcription is not ``completed``.
    """
    # Newest first: a failed first attempt must not hide a later retry.
    transcription = max(
        (t for t in _transcriptions_store.values() if t["recording_id"] == recording_id),
        key=lambda t: t.get("created_at") or "",
        default=None,
    )
    if not transcription:
        raise HTTPException(status_code=404, detail="No transcription found for this recording")
    if transcription["status"] != "completed":
        raise HTTPException(
            status_code=400,
            detail=f"Transcription not ready. Status: {transcription['status']}"
        )

    return {
        "transcription_id": transcription["id"],
        "recording_id": recording_id,
        "language": transcription["language"],
        "text": transcription.get("full_text", ""),
        "word_count": transcription.get("word_count", 0)
    }
@router.get("/{recording_id}/transcription/vtt")
async def get_transcription_vtt(recording_id: str):
    """
    Download transcription as WebVTT subtitle file.

    Uses the most recently created transcription of the recording. The text
    is split after each "." and every non-empty sentence becomes one cue
    with a fixed 3000 time-unit window (the value passed to
    ``format_vtt_time``); these are placeholder timings, not real ones.

    Raises:
        HTTPException: 404 if no transcription exists, 400 if the selected
            transcription is not ``completed``.
    """
    # Newest first: a failed first attempt must not hide a later retry.
    transcription = max(
        (t for t in _transcriptions_store.values() if t["recording_id"] == recording_id),
        key=lambda t: t.get("created_at") or "",
        default=None,
    )
    if not transcription:
        raise HTTPException(status_code=404, detail="No transcription found for this recording")
    if transcription["status"] != "completed":
        raise HTTPException(
            status_code=400,
            detail=f"Transcription not ready. Status: {transcription['status']}"
        )

    # Generate VTT content; collect pieces and join once instead of
    # growing a string with "+=" inside the loop.
    pieces = ["WEBVTT\n\n"]
    text = transcription.get("full_text", "")
    if text:
        time_offset = 0
        for sentence in text.replace(".", ".\n").split("\n"):
            sentence = sentence.strip()
            if not sentence:
                continue
            start = format_vtt_time(time_offset)
            time_offset += 3000
            end = format_vtt_time(time_offset)
            pieces.append(f"{start} --> {end}\n{sentence}\n\n")
    vtt_content = "".join(pieces)

    return PlainTextResponse(
        content=vtt_content,
        media_type="text/vtt",
        headers={"Content-Disposition": f"attachment; filename=transcript_{recording_id}.vtt"}
    )
@router.get("/{recording_id}/transcription/srt")
async def get_transcription_srt(recording_id: str):
    """
    Download transcription as SRT subtitle file.

    Uses the most recently created transcription of the recording. The text
    is split after each "." and every non-empty sentence becomes one
    numbered cue with a fixed 3000 time-unit window (the value passed to
    ``format_srt_time``); these are placeholder timings, not real ones.

    Raises:
        HTTPException: 404 if no transcription exists, 400 if the selected
            transcription is not ``completed``.
    """
    # Newest first: a failed first attempt must not hide a later retry.
    transcription = max(
        (t for t in _transcriptions_store.values() if t["recording_id"] == recording_id),
        key=lambda t: t.get("created_at") or "",
        default=None,
    )
    if not transcription:
        raise HTTPException(status_code=404, detail="No transcription found for this recording")
    if transcription["status"] != "completed":
        raise HTTPException(
            status_code=400,
            detail=f"Transcription not ready. Status: {transcription['status']}"
        )

    # Generate SRT content; collect cues and join once instead of
    # growing a string with "+=" inside the loop.
    cues = []
    text = transcription.get("full_text", "")
    if text:
        time_offset = 0
        stripped = (part.strip() for part in text.replace(".", ".\n").split("\n"))
        # Cue numbers start at 1 and count only non-empty sentences.
        for index, sentence in enumerate((s for s in stripped if s), start=1):
            start = format_srt_time(time_offset)
            time_offset += 3000
            end = format_srt_time(time_offset)
            cues.append(f"{index}\n{start} --> {end}\n{sentence}\n\n")
    srt_content = "".join(cues)

    return PlainTextResponse(
        content=srt_content,
        media_type="text/plain",
        headers={"Content-Disposition": f"attachment; filename=transcript_{recording_id}.srt"}
    )
+20 -746
View File
@@ -1,751 +1,25 @@
# ==============================================
# Breakpilot Drive - Unit Analytics API
# ==============================================
# Erweiterte Analytics fuer Lernfortschritt:
# - Pre/Post Gain Visualisierung
# - Misconception-Tracking
# - Stop-Level Analytics
# - Aggregierte Klassen-Statistiken
# - Export-Funktionen
"""
Breakpilot Drive - Unit Analytics API — Barrel Re-export.
from fastapi import APIRouter, HTTPException, Query, Depends, Request
from pydantic import BaseModel, Field
from typing import List, Optional, Dict, Any
from datetime import datetime, timedelta
from enum import Enum
import os
import logging
import statistics
Erweiterte Analytics fuer Lernfortschritt:
- Pre/Post Gain Visualisierung
- Misconception-Tracking
- Stop-Level Analytics
- Aggregierte Klassen-Statistiken
- Export-Funktionen
logger = logging.getLogger(__name__)
Split into:
- unit_analytics_models.py: Pydantic models & enums
- unit_analytics_helpers.py: Database access & computation helpers
- unit_analytics_routes.py: Core analytics endpoint handlers
- unit_analytics_export.py: Export & dashboard endpoints
"""
# Feature flags
USE_DATABASE = os.getenv("GAME_USE_DATABASE", "true").lower() == "true"
from fastapi import APIRouter
from unit_analytics_routes import router as _routes_router
from unit_analytics_export import router as _export_router
router = APIRouter(prefix="/api/analytics", tags=["Unit Analytics"])
# ==============================================
# Pydantic Models
# ==============================================
class TimeRange(str, Enum):
    """Time range for analytics queries.

    Members are also ``str`` instances; handlers in this module pass
    ``.value`` to the database layer.
    """
    WEEK = "week"
    MONTH = "month"
    QUARTER = "quarter"
    ALL = "all"
class LearningGainData(BaseModel):
    """Pre/Post learning gain data point for one student and unit."""
    student_id: str
    student_name: str
    unit_id: str
    precheck_score: float
    postcheck_score: float
    # Computed in this module as postcheck_score - precheck_score.
    learning_gain: float
    # Filled in after construction by get_learning_gain_analysis.
    percentile: Optional[float] = None
class LearningGainSummary(BaseModel):
    """Aggregated learning gain statistics for one unit."""
    unit_id: str
    unit_title: str
    total_students: int
    avg_precheck: float
    avg_postcheck: float
    avg_gain: float
    median_gain: float
    std_deviation: float
    positive_gain_count: int
    negative_gain_count: int
    no_change_count: int
    # Bucket label -> count. calculate_gain_distribution produces the keys
    # "< -20%", "-20% to -10%", "-10% to 0%", "0% to 10%", "10% to 20%", "> 20%".
    gain_distribution: Dict[str, int]
    individual_gains: List[LearningGainData]
class StopPerformance(BaseModel):
    """Performance data for a single stop"""
    stop_id: str
    stop_label: str
    attempts_total: int
    success_rate: float
    avg_time_seconds: float
    avg_attempts_before_success: float
    common_errors: List[str]
    difficulty_rating: float  # 1-5 based on performance (see calculate_difficulty_rating)
class UnitPerformanceDetail(BaseModel):
    """Detailed unit performance breakdown, including per-stop statistics."""
    unit_id: str
    unit_title: str
    template: str
    total_sessions: int
    completed_sessions: int
    completion_rate: float
    avg_duration_minutes: float
    stops: List[StopPerformance]
    bottleneck_stops: List[str]  # Stop ids where students struggle most
class MisconceptionEntry(BaseModel):
    """Individual misconception tracking"""
    concept_id: str
    concept_label: str
    misconception_text: str
    frequency: int
    affected_student_ids: List[str]
    unit_id: str
    stop_id: str
    # "precheck", "postcheck", "interaction"; get_misconception_report falls
    # back to "unknown" when the source row has no value.
    detected_via: str
    first_detected: datetime
    last_detected: datetime
class MisconceptionReport(BaseModel):
    """Comprehensive misconception report"""
    class_id: Optional[str]
    time_range: str
    total_misconceptions: int
    unique_concepts: int
    most_common: List[MisconceptionEntry]
    by_unit: Dict[str, List[MisconceptionEntry]]  # keyed by unit_id
    trending_up: List[MisconceptionEntry]  # Getting more frequent
    resolved: List[MisconceptionEntry]  # No longer appearing
class StudentProgressTimeline(BaseModel):
    """Timeline of student progress"""
    student_id: str
    student_name: str
    units_completed: int
    total_time_minutes: int
    avg_score: float
    # One of "improving", "stable", "declining" or "insufficient_data"
    # (the values calculate_trend can return).
    trend: str
    timeline: List[Dict[str, Any]]  # List of session events
class ClassComparisonData(BaseModel):
    """Data for comparing class performance (one instance per class)."""
    class_id: str
    class_name: str
    student_count: int
    units_assigned: int
    avg_completion_rate: float
    avg_learning_gain: float
    # NOTE(review): the time unit is not stated anywhere in this file - confirm.
    avg_time_per_unit: float
class ExportFormat(str, Enum):
    """Export format options; members are also ``str`` instances."""
    JSON = "json"
    CSV = "csv"
# ==============================================
# Database Integration
# ==============================================
_analytics_db = None


async def get_analytics_database():
    """Return the analytics database handle, creating it on first use.

    Returns None when the USE_DATABASE flag is off or when initialization
    fails. A failed initialization only logs a warning, so the next call
    tries again.
    """
    global _analytics_db

    if not USE_DATABASE:
        return None

    # Already initialized on an earlier call.
    if _analytics_db is not None:
        return _analytics_db

    try:
        # Imported lazily; a missing module is handled as ImportError below.
        from unit.database import get_analytics_db
        _analytics_db = await get_analytics_db()
        logger.info("Analytics database initialized")
    except ImportError:
        logger.warning("Analytics database module not available")
    except Exception as e:
        logger.warning(f"Analytics database not available: {e}")

    return _analytics_db
# ==============================================
# Helper Functions
# ==============================================
def calculate_gain_distribution(gains: List[float]) -> Dict[str, int]:
    """Count learning gains per percentage bucket.

    Each gain is a fraction that is scaled by 100 and assigned to the first
    bucket whose exclusive upper bound exceeds it; values of 20 or more land
    in "> 20%".

    Returns:
        Mapping of bucket label to count; every label is present, even
        when its count is zero.
    """
    # (exclusive upper bound in percent, label), in ascending order.
    bounded_buckets = (
        (-20, "< -20%"),
        (-10, "-20% to -10%"),
        (0, "-10% to 0%"),
        (10, "0% to 10%"),
        (20, "10% to 20%"),
    )
    overflow_label = "> 20%"

    distribution = {label: 0 for _, label in bounded_buckets}
    distribution[overflow_label] = 0

    for gain in gains:
        percent = gain * 100
        label = next(
            (name for bound, name in bounded_buckets if percent < bound),
            overflow_label,
        )
        distribution[label] += 1

    return distribution
def calculate_trend(scores: List[float]) -> str:
    """Classify a score series by the slope of its least-squares line.

    The x axis is the position of each score in the list.

    Returns:
        "insufficient_data" for fewer than three scores, otherwise
        "improving" (slope > 0.05), "declining" (slope < -0.05) or "stable".
    """
    count = len(scores)
    if count < 3:
        return "insufficient_data"

    # Simple linear regression over (index, score) pairs.
    x_mean = (count - 1) / 2
    y_mean = sum(scores) / count
    x_offsets = [position - x_mean for position in range(count)]

    covariance = sum(dx * (score - y_mean) for dx, score in zip(x_offsets, scores))
    x_variance = sum(dx ** 2 for dx in x_offsets)

    if x_variance == 0:
        return "stable"

    slope = covariance / x_variance
    if slope > 0.05:
        return "improving"
    if slope < -0.05:
        return "declining"
    return "stable"
def calculate_difficulty_rating(success_rate: float, avg_attempts: float) -> float:
    """Calculate difficulty rating 1-5 based on success metrics.

    Args:
        success_rate: Fraction of successful attempts (expected 0-1).
        avg_attempts: Average number of attempts.

    Returns:
        A rating clamped to the range 1.0-5.0; lower success rates and more
        attempts give a higher rating.
    """
    # Lower success rate and higher attempts = higher difficulty
    base_difficulty = (1 - success_rate) * 3 + 1  # 1-4 range for rates in 0-1
    # Clamp to 0-1: an avg_attempts below 1 (e.g. 0 when nothing was recorded)
    # must not push the rating below its base value.
    attempt_modifier = max(0.0, min(avg_attempts - 1, 1))
    # Final clamp keeps out-of-range inputs inside the documented 1-5 scale.
    return float(max(1.0, min(5.0, base_difficulty + attempt_modifier)))
# ==============================================
# API Endpoints - Learning Gain
# ==============================================
# NOTE: Static routes must come BEFORE dynamic routes like /{unit_id}
@router.get("/learning-gain/compare")
async def compare_learning_gains(
    unit_ids: str = Query(..., description="Comma-separated unit IDs"),
    class_id: Optional[str] = Query(None),
    time_range: TimeRange = Query(TimeRange.MONTH),
) -> Dict[str, Any]:
    """
    Compare learning gains across multiple units.

    Runs the single-unit learning-gain analysis for every id in
    ``unit_ids`` and returns the results ordered by average gain, highest
    first. Units whose analysis raises are logged and left out.
    """
    # Drop empty entries caused by stray or trailing commas ("a,,b", "a,")
    # so they are not analysed as a unit with an empty id.
    unit_list = [u.strip() for u in unit_ids.split(",") if u.strip()]
    comparisons = []

    for unit_id in unit_list:
        try:
            summary = await get_learning_gain_analysis(unit_id, class_id, time_range)
            comparisons.append({
                "unit_id": unit_id,
                "avg_gain": summary.avg_gain,
                "median_gain": summary.median_gain,
                "total_students": summary.total_students,
                # max(..., 1) avoids division by zero for units without students.
                "positive_rate": summary.positive_gain_count / max(summary.total_students, 1),
            })
        except Exception as e:
            logger.error(f"Failed to get comparison for {unit_id}: {e}")

    return {
        "time_range": time_range.value,
        "class_id": class_id,
        "comparisons": sorted(comparisons, key=lambda x: x["avg_gain"], reverse=True),
    }
@router.get("/learning-gain/{unit_id}", response_model=LearningGainSummary)
async def get_learning_gain_analysis(
    unit_id: str,
    class_id: Optional[str] = Query(None, description="Filter by class"),
    time_range: TimeRange = Query(TimeRange.MONTH, description="Time range for analysis"),
) -> LearningGainSummary:
    """
    Get detailed pre/post learning gain analysis for a unit.

    Shows individual gains, aggregated statistics, and distribution.

    Only sessions that have both a precheck and a postcheck score are
    counted. If the database is unavailable, the query fails, or no such
    session exists, an all-zero summary is returned.
    """
    # Local import, like the lazy database import used elsewhere in this module.
    from bisect import bisect_left

    db = await get_analytics_database()
    individual_gains = []

    if db:
        try:
            # Get all sessions with pre/post scores for this unit
            sessions = await db.get_unit_sessions_with_scores(
                unit_id=unit_id,
                class_id=class_id,
                time_range=time_range.value
            )
            for session in sessions:
                if session.get("precheck_score") is not None and session.get("postcheck_score") is not None:
                    gain = session["postcheck_score"] - session["precheck_score"]
                    individual_gains.append(LearningGainData(
                        student_id=session["student_id"],
                        # Fall back to a shortened id when no name is stored.
                        student_name=session.get("student_name", session["student_id"][:8]),
                        unit_id=unit_id,
                        precheck_score=session["precheck_score"],
                        postcheck_score=session["postcheck_score"],
                        learning_gain=gain,
                    ))
        except Exception as e:
            logger.error(f"Failed to get learning gain data: {e}")

    # Calculate statistics
    if not individual_gains:
        # Return empty summary
        return LearningGainSummary(
            unit_id=unit_id,
            unit_title=f"Unit {unit_id}",
            total_students=0,
            avg_precheck=0.0,
            avg_postcheck=0.0,
            avg_gain=0.0,
            median_gain=0.0,
            std_deviation=0.0,
            positive_gain_count=0,
            negative_gain_count=0,
            no_change_count=0,
            gain_distribution={},
            individual_gains=[],
        )

    gains = [g.learning_gain for g in individual_gains]
    prechecks = [g.precheck_score for g in individual_gains]
    postchecks = [g.postcheck_score for g in individual_gains]

    avg_gain = statistics.mean(gains)
    median_gain = statistics.median(gains)
    # stdev needs at least two data points.
    std_dev = statistics.stdev(gains) if len(gains) > 1 else 0.0

    # Calculate percentiles. bisect_left returns the index of the first
    # occurrence in the sorted list - the same rank list.index gave, but in
    # O(log n) per student instead of O(n). Ties share the lowest rank.
    sorted_gains = sorted(gains)
    for data in individual_gains:
        rank = bisect_left(sorted_gains, data.learning_gain) + 1
        data.percentile = rank / len(sorted_gains) * 100

    return LearningGainSummary(
        unit_id=unit_id,
        unit_title=f"Unit {unit_id}",
        total_students=len(individual_gains),
        avg_precheck=statistics.mean(prechecks),
        avg_postcheck=statistics.mean(postchecks),
        avg_gain=avg_gain,
        median_gain=median_gain,
        std_deviation=std_dev,
        # Gains within +/-0.01 count as "no change".
        positive_gain_count=sum(1 for g in gains if g > 0.01),
        negative_gain_count=sum(1 for g in gains if g < -0.01),
        no_change_count=sum(1 for g in gains if -0.01 <= g <= 0.01),
        gain_distribution=calculate_gain_distribution(gains),
        individual_gains=sorted(individual_gains, key=lambda x: x.learning_gain, reverse=True),
    )
# ==============================================
# API Endpoints - Stop-Level Analytics
# ==============================================
@router.get("/unit/{unit_id}/stops", response_model=UnitPerformanceDetail)
async def get_unit_stop_analytics(
    unit_id: str,
    class_id: Optional[str] = Query(None),
    time_range: TimeRange = Query(TimeRange.MONTH),
) -> UnitPerformanceDetail:
    """
    Get detailed stop-level performance analytics.

    Identifies bottleneck stops where students struggle most.

    Without a database, or when the query fails, the unit-level figures
    fall back to defaults. Stops collected before a failure are kept.
    """
    db = await get_analytics_database()
    stops_data = []
    # Stays empty unless the overall-stats query below succeeds.
    unit_stats = {}

    if db:
        try:
            # Get stop-level telemetry
            stop_rows = await db.get_stop_performance(
                unit_id=unit_id,
                class_id=class_id,
                time_range=time_range.value
            )
            for row in stop_rows:
                # NOTE(review): the rating assumes a 0.5 success rate when the
                # value is missing, while the reported rate defaults to 0.0.
                rating = calculate_difficulty_rating(
                    row.get("success_rate", 0.5),
                    row.get("avg_attempts", 1.0)
                )
                stops_data.append(StopPerformance(
                    stop_id=row["stop_id"],
                    stop_label=row.get("stop_label", row["stop_id"]),
                    attempts_total=row.get("total_attempts", 0),
                    success_rate=row.get("success_rate", 0.0),
                    avg_time_seconds=row.get("avg_time_seconds", 0.0),
                    avg_attempts_before_success=row.get("avg_attempts", 1.0),
                    common_errors=row.get("common_errors", []),
                    difficulty_rating=rating,
                ))

            # Get overall unit stats
            unit_stats = await db.get_unit_overall_stats(unit_id, class_id, time_range.value)
        except Exception as e:
            logger.error(f"Failed to get stop analytics: {e}")

    # Identify bottleneck stops (difficulty > 3.5 or success rate < 0.6)
    bottlenecks = []
    for stop in stops_data:
        if stop.difficulty_rating > 3.5 or stop.success_rate < 0.6:
            bottlenecks.append(stop.stop_id)

    return UnitPerformanceDetail(
        unit_id=unit_id,
        unit_title=f"Unit {unit_id}",
        template=unit_stats.get("template", "unknown"),
        total_sessions=unit_stats.get("total_sessions", 0),
        completed_sessions=unit_stats.get("completed_sessions", 0),
        completion_rate=unit_stats.get("completion_rate", 0.0),
        avg_duration_minutes=unit_stats.get("avg_duration_minutes", 0.0),
        stops=stops_data,
        bottleneck_stops=bottlenecks,
    )
# ==============================================
# API Endpoints - Misconception Tracking
# ==============================================
@router.get("/misconceptions", response_model=MisconceptionReport)
async def get_misconception_report(
    class_id: Optional[str] = Query(None),
    unit_id: Optional[str] = Query(None),
    time_range: TimeRange = Query(TimeRange.MONTH),
    limit: int = Query(20, ge=1, le=100),
) -> MisconceptionReport:
    """
    Get comprehensive misconception report.

    Shows most common misconceptions and their frequency.

    ``trending_up`` is currently just the first three entries and
    ``resolved`` is always empty; both would need historical data.
    """
    db = await get_analytics_database()
    misconceptions = []

    if db:
        try:
            rows = await db.get_misconceptions(
                class_id=class_id,
                unit_id=unit_id,
                time_range=time_range.value,
                limit=limit
            )
            for row in rows:
                misconceptions.append(MisconceptionEntry(
                    concept_id=row["concept_id"],
                    concept_label=row["concept_label"],
                    misconception_text=row["misconception_text"],
                    frequency=row["frequency"],
                    affected_student_ids=row.get("student_ids", []),
                    unit_id=row["unit_id"],
                    stop_id=row["stop_id"],
                    detected_via=row.get("detected_via", "unknown"),
                    first_detected=row.get("first_detected", datetime.utcnow()),
                    last_detected=row.get("last_detected", datetime.utcnow()),
                ))
        except Exception as e:
            logger.error(f"Failed to get misconceptions: {e}")

    # Group by unit
    by_unit = {}
    for entry in misconceptions:
        by_unit.setdefault(entry.unit_id, []).append(entry)

    # Identify trending misconceptions (would need historical comparison in production)
    trending_up = misconceptions[:3] if misconceptions else []
    resolved = []  # Would identify from historical data

    total_frequency = sum(entry.frequency for entry in misconceptions)
    concept_ids = {entry.concept_id for entry in misconceptions}
    ranked = sorted(misconceptions, key=lambda entry: entry.frequency, reverse=True)

    return MisconceptionReport(
        class_id=class_id,
        time_range=time_range.value,
        total_misconceptions=total_frequency,
        unique_concepts=len(concept_ids),
        most_common=ranked[:10],
        by_unit=by_unit,
        trending_up=trending_up,
        resolved=resolved,
    )
@router.get("/misconceptions/student/{student_id}")
async def get_student_misconceptions(
    student_id: str,
    time_range: TimeRange = Query(TimeRange.ALL),
) -> Dict[str, Any]:
    """
    Get misconceptions for a specific student.

    Useful for personalized remediation.

    Up to five remediation suggestions are derived from the first
    misconceptions returned. Without a database, or on error, both lists
    in the response are empty.
    """
    db = await get_analytics_database()

    if not db:
        return {
            "student_id": student_id,
            "misconceptions": [],
            "recommended_remediation": [],
        }

    try:
        found = await db.get_student_misconceptions(
            student_id=student_id,
            time_range=time_range.value
        )
        suggestions = []
        for item in found[:5]:
            suggestions.append(
                {"concept": item["concept_label"], "activity": f"Review {item['unit_id']}/{item['stop_id']}"}
            )
        return {
            "student_id": student_id,
            "misconceptions": found,
            "recommended_remediation": suggestions
        }
    except Exception as e:
        logger.error(f"Failed to get student misconceptions: {e}")

    return {
        "student_id": student_id,
        "misconceptions": [],
        "recommended_remediation": [],
    }
# ==============================================
# API Endpoints - Student Progress Timeline
# ==============================================
@router.get("/student/{student_id}/timeline", response_model=StudentProgressTimeline)
async def get_student_timeline(
    student_id: str,
    time_range: TimeRange = Query(TimeRange.ALL),
) -> StudentProgressTimeline:
    """
    Get detailed progress timeline for a student.

    Shows all unit sessions and performance trend.

    The trend and average score are based on postcheck scores only.
    Without a database, or on error, the timeline may be empty or partial.
    """
    db = await get_analytics_database()
    timeline = []
    scores = []

    if db:
        try:
            sessions = await db.get_student_sessions(
                student_id=student_id,
                time_range=time_range.value
            )
            for session in sessions:
                # "or 0" also covers a key that is present with a None value;
                # .get(key, 0) would return None there and "None // 60" would
                # raise, silently truncating the timeline via the except below.
                duration_seconds = session.get("duration_seconds") or 0
                timeline.append({
                    "date": session.get("started_at"),
                    "unit_id": session.get("unit_id"),
                    "completed": session.get("completed_at") is not None,
                    "precheck": session.get("precheck_score"),
                    "postcheck": session.get("postcheck_score"),
                    "duration_minutes": duration_seconds // 60,
                })
                if session.get("postcheck_score") is not None:
                    scores.append(session["postcheck_score"])
        except Exception as e:
            logger.error(f"Failed to get student timeline: {e}")

    trend = calculate_trend(scores) if scores else "insufficient_data"

    return StudentProgressTimeline(
        student_id=student_id,
        student_name=f"Student {student_id[:8]}",  # Would load actual name
        units_completed=sum(1 for t in timeline if t["completed"]),
        total_time_minutes=sum(t["duration_minutes"] for t in timeline),
        avg_score=statistics.mean(scores) if scores else 0.0,
        trend=trend,
        timeline=timeline,
    )
# ==============================================
# API Endpoints - Class Comparison
# ==============================================
@router.get("/compare/classes", response_model=List[ClassComparisonData])
async def compare_classes(
    class_ids: str = Query(..., description="Comma-separated class IDs"),
    time_range: TimeRange = Query(TimeRange.MONTH),
) -> List[ClassComparisonData]:
    """
    Compare performance across multiple classes.

    Returns one entry per class whose statistics could be loaded, sorted by
    average learning gain (highest first). Classes whose lookup fails are
    logged and left out; without a database the result is empty.
    """
    # Drop blank items ("a,,b", trailing comma) so no empty class ID is queried.
    class_list = [c.strip() for c in class_ids.split(",") if c.strip()]
    comparisons = []
    db = await get_analytics_database()
    if db:
        for class_id in class_list:
            try:
                stats = await db.get_class_aggregate_stats(class_id, time_range.value)
                comparisons.append(ClassComparisonData(
                    class_id=class_id,
                    class_name=stats.get("class_name", f"Klasse {class_id[:8]}"),
                    student_count=stats.get("student_count", 0),
                    units_assigned=stats.get("units_assigned", 0),
                    avg_completion_rate=stats.get("avg_completion_rate", 0.0),
                    avg_learning_gain=stats.get("avg_learning_gain", 0.0),
                    avg_time_per_unit=stats.get("avg_time_per_unit", 0.0),
                ))
            except Exception as e:
                logger.error(f"Failed to get stats for class {class_id}: {e}")
    return sorted(comparisons, key=lambda x: x.avg_learning_gain, reverse=True)
# ==============================================
# API Endpoints - Export
# ==============================================
@router.get("/export/learning-gains")
async def export_learning_gains(
    unit_id: Optional[str] = Query(None),
    class_id: Optional[str] = Query(None),
    time_range: TimeRange = Query(TimeRange.ALL),
    format: ExportFormat = Query(ExportFormat.JSON),
) -> Any:
    """
    Export learning gain data as JSON (default) or as a CSV attachment.

    If no database is available, or the export query fails, the export is
    produced with no data rows.
    """
    import csv
    import io
    from fastapi.responses import Response

    db = await get_analytics_database()
    data = []
    if db:
        try:
            data = await db.export_learning_gains(
                unit_id=unit_id,
                class_id=class_id,
                time_range=time_range.value
            )
        except Exception as e:
            logger.error(f"Failed to export data: {e}")
    if format == ExportFormat.CSV:
        # csv.writer quotes values containing commas, quotes or newlines, which
        # the former hand-built f-string rows did not. lineterminator="\n"
        # keeps the line endings of the previous output. None values are
        # written as empty cells.
        buffer = io.StringIO()
        writer = csv.writer(buffer, lineterminator="\n")
        writer.writerow(["student_id", "unit_id", "precheck", "postcheck", "gain"])
        for row in data:
            writer.writerow([
                row["student_id"],
                row["unit_id"],
                row.get("precheck", ""),
                row.get("postcheck", ""),
                row.get("gain", ""),
            ])
        return Response(
            content=buffer.getvalue(),
            media_type="text/csv",
            headers={"Content-Disposition": "attachment; filename=learning_gains.csv"}
        )
    return {
        "export_date": datetime.utcnow().isoformat(),
        "filters": {
            "unit_id": unit_id,
            "class_id": class_id,
            "time_range": time_range.value,
        },
        "data": data,
    }
@router.get("/export/misconceptions")
async def export_misconceptions(
    class_id: Optional[str] = Query(None),
    format: ExportFormat = Query(ExportFormat.JSON),
) -> Any:
    """
    Export misconception data for further analysis.

    Fetches the last month's misconception report (up to 100 entries) and
    returns every entry, most frequent first, as JSON (default) or as a CSV
    attachment.
    """
    report = await get_misconception_report(
        class_id=class_id,
        unit_id=None,
        time_range=TimeRange.MONTH,
        limit=100
    )
    # report.most_common is capped at 10 entries by the report itself, so the
    # export would never contain more than 10 rows despite limit=100.
    # report.by_unit groups *all* fetched entries, so flatten that instead.
    entries = sorted(
        (m for unit_entries in report.by_unit.values() for m in unit_entries),
        key=lambda m: m.frequency,
        reverse=True,
    )
    if format == ExportFormat.CSV:
        import csv
        import io
        from fastapi.responses import Response

        buffer = io.StringIO()
        buffer.write("concept_id,concept_label,misconception,frequency,unit_id,stop_id\n")
        # QUOTE_NONNUMERIC reproduces the previous layout (text quoted,
        # frequency bare) and additionally escapes embedded double quotes.
        writer = csv.writer(buffer, quoting=csv.QUOTE_NONNUMERIC, lineterminator="\n")
        for m in entries:
            writer.writerow([
                m.concept_id, m.concept_label, m.misconception_text,
                m.frequency, m.unit_id, m.stop_id,
            ])
        return Response(
            content=buffer.getvalue(),
            media_type="text/csv",
            headers={"Content-Disposition": "attachment; filename=misconceptions.csv"}
        )
    return {
        "export_date": datetime.utcnow().isoformat(),
        "class_id": class_id,
        "total_entries": len(entries),
        "data": [m.model_dump() for m in entries],
    }
# ==============================================
# API Endpoints - Dashboard Aggregates
# ==============================================
@router.get("/dashboard/overview")
async def get_analytics_overview(
    time_range: TimeRange = Query(TimeRange.MONTH),
) -> Dict[str, Any]:
    """
    Return the high-level analytics overview shown on the dashboard.

    The overview comes straight from the analytics database. Without a
    database, or when the query fails, a zeroed overview for the requested
    time range is returned.
    """
    zeroed_overview = {
        "time_range": time_range.value,
        "total_sessions": 0,
        "unique_students": 0,
        "avg_completion_rate": 0.0,
        "avg_learning_gain": 0.0,
        "most_played_units": [],
        "struggling_concepts": [],
        "active_classes": 0,
    }

    db = await get_analytics_database()
    if not db:
        return zeroed_overview

    try:
        return await db.get_analytics_overview(time_range.value)
    except Exception as e:
        logger.error(f"Failed to get analytics overview: {e}")
        return zeroed_overview
@router.get("/health")
async def health_check() -> Dict[str, Any]:
    """Report service health plus whether the analytics database is reachable."""
    db = await get_analytics_database()
    database_state = "connected" if db else "disconnected"
    # The service status is always "healthy"; only the database field varies.
    return {
        "status": "healthy",
        "service": "unit-analytics",
        "database": database_state,
    }
# Attach the routes of both sub-routers to this module's `router`.
router.include_router(_routes_router)
router.include_router(_export_router)
+145
View File
@@ -0,0 +1,145 @@
"""
Unit Analytics API - Export & Dashboard Routes.
Export endpoints for learning gains and misconceptions, plus dashboard overview.
"""
import logging
from datetime import datetime
from typing import Optional, Dict, Any
from fastapi import APIRouter, Query
from fastapi.responses import Response
from unit_analytics_models import TimeRange, ExportFormat
from unit_analytics_helpers import get_analytics_database
# Module logger, named after this module.
logger = logging.getLogger(__name__)
# Router for the export, dashboard and health endpoints defined below.
router = APIRouter(tags=["Unit Analytics"])
# ==============================================
# API Endpoints - Export
# ==============================================
@router.get("/export/learning-gains")
async def export_learning_gains(
    unit_id: Optional[str] = Query(None),
    class_id: Optional[str] = Query(None),
    time_range: TimeRange = Query(TimeRange.ALL),
    format: ExportFormat = Query(ExportFormat.JSON),
) -> Any:
    """
    Export learning gain data as JSON (default) or as a CSV attachment.

    If no database is available, or the export query fails, the export is
    produced with no data rows.
    """
    import csv
    import io

    db = await get_analytics_database()
    data = []
    if db:
        try:
            data = await db.export_learning_gains(
                unit_id=unit_id, class_id=class_id, time_range=time_range.value
            )
        except Exception as e:
            logger.error(f"Failed to export data: {e}")
    if format == ExportFormat.CSV:
        # csv.writer quotes values containing commas, quotes or newlines, which
        # the former hand-built f-string rows did not. lineterminator="\n"
        # keeps the line endings of the previous output. None values are
        # written as empty cells.
        buffer = io.StringIO()
        writer = csv.writer(buffer, lineterminator="\n")
        writer.writerow(["student_id", "unit_id", "precheck", "postcheck", "gain"])
        for row in data:
            writer.writerow([
                row["student_id"],
                row["unit_id"],
                row.get("precheck", ""),
                row.get("postcheck", ""),
                row.get("gain", ""),
            ])
        return Response(
            content=buffer.getvalue(),
            media_type="text/csv",
            headers={"Content-Disposition": "attachment; filename=learning_gains.csv"}
        )
    return {
        "export_date": datetime.utcnow().isoformat(),
        "filters": {
            "unit_id": unit_id, "class_id": class_id, "time_range": time_range.value,
        },
        "data": data,
    }
@router.get("/export/misconceptions")
async def export_misconceptions(
    class_id: Optional[str] = Query(None),
    format: ExportFormat = Query(ExportFormat.JSON),
) -> Any:
    """
    Export misconception data for further analysis.

    Fetches the last month's misconception report (up to 100 entries) and
    returns every entry, most frequent first, as JSON (default) or as a CSV
    attachment.
    """
    import csv
    import io
    # Import here to avoid circular dependency
    from unit_analytics_routes import get_misconception_report

    report = await get_misconception_report(
        class_id=class_id, unit_id=None,
        time_range=TimeRange.MONTH, limit=100
    )
    # report.most_common is capped at 10 entries by the report itself, so the
    # export would never contain more than 10 rows despite limit=100.
    # report.by_unit groups *all* fetched entries, so flatten that instead.
    entries = sorted(
        (m for unit_entries in report.by_unit.values() for m in unit_entries),
        key=lambda m: m.frequency,
        reverse=True,
    )
    if format == ExportFormat.CSV:
        buffer = io.StringIO()
        buffer.write("concept_id,concept_label,misconception,frequency,unit_id,stop_id\n")
        # QUOTE_NONNUMERIC reproduces the previous layout (text quoted,
        # frequency bare) and additionally escapes embedded double quotes.
        writer = csv.writer(buffer, quoting=csv.QUOTE_NONNUMERIC, lineterminator="\n")
        for m in entries:
            writer.writerow([
                m.concept_id, m.concept_label, m.misconception_text,
                m.frequency, m.unit_id, m.stop_id,
            ])
        return Response(
            content=buffer.getvalue(),
            media_type="text/csv",
            headers={"Content-Disposition": "attachment; filename=misconceptions.csv"}
        )
    return {
        "export_date": datetime.utcnow().isoformat(),
        "class_id": class_id,
        "total_entries": len(entries),
        "data": [m.model_dump() for m in entries],
    }
# ==============================================
# API Endpoints - Dashboard Aggregates
# ==============================================
@router.get("/dashboard/overview")
async def get_analytics_overview(
    time_range: TimeRange = Query(TimeRange.MONTH),
) -> Dict[str, Any]:
    """
    Return the high-level analytics overview shown on the dashboard.

    The overview comes straight from the analytics database. Without a
    database, or when the query fails, a zeroed overview for the requested
    time range is returned.
    """
    zeroed_overview = {
        "time_range": time_range.value,
        "total_sessions": 0,
        "unique_students": 0,
        "avg_completion_rate": 0.0,
        "avg_learning_gain": 0.0,
        "most_played_units": [],
        "struggling_concepts": [],
        "active_classes": 0,
    }

    db = await get_analytics_database()
    if not db:
        return zeroed_overview

    try:
        return await db.get_analytics_overview(time_range.value)
    except Exception as e:
        logger.error(f"Failed to get analytics overview: {e}")
        return zeroed_overview
@router.get("/health")
async def health_check() -> Dict[str, Any]:
    """Report service health plus whether the analytics database is reachable."""
    db = await get_analytics_database()
    database_state = "connected" if db else "disconnected"
    # The service status is always "healthy"; only the database field varies.
    return {
        "status": "healthy",
        "service": "unit-analytics",
        "database": database_state,
    }
+97
View File
@@ -0,0 +1,97 @@
"""
Unit Analytics API - Helpers.
Database access, statistical computation, and utility functions.
"""
import os
import logging
from typing import List, Dict, Optional
logger = logging.getLogger(__name__)
# Feature flags
# Database access is enabled unless GAME_USE_DATABASE is set to something
# other than "true" (compared case-insensitively); the default is "true".
USE_DATABASE = os.getenv("GAME_USE_DATABASE", "true").lower() == "true"
# Database singleton
# Filled lazily by get_analytics_database(); stays None until an
# initialization attempt succeeds.
_analytics_db = None
async def get_analytics_database():
    """
    Return the shared analytics database handle, creating it on first use.

    Returns None when database access is disabled via USE_DATABASE, or when
    the database could not be initialized. A failed initialization is logged
    as a warning and retried on the next call.
    """
    global _analytics_db

    if not USE_DATABASE:
        return None

    # Already initialized: reuse the cached handle.
    if _analytics_db is not None:
        return _analytics_db

    try:
        # Imported lazily so this module still loads when the backend is absent.
        from unit.database import get_analytics_db
        _analytics_db = await get_analytics_db()
        logger.info("Analytics database initialized")
    except ImportError:
        logger.warning("Analytics database module not available")
    except Exception as e:
        logger.warning(f"Analytics database not available: {e}")

    return _analytics_db
def calculate_gain_distribution(gains: List[float]) -> Dict[str, int]:
    """
    Count learning gains per percentage bucket.

    Each gain is a fraction (0.1 means 10%). Every bucket includes its lower
    bound and excludes its upper bound; the last bucket collects everything
    from 20% upwards. All six buckets are always present in the result.
    """
    # (exclusive upper bound in percent, bucket label), in ascending order.
    bounded_buckets = (
        (-20, "< -20%"),
        (-10, "-20% to -10%"),
        (0, "-10% to 0%"),
        (10, "0% to 10%"),
        (20, "10% to 20%"),
    )
    top_bucket = "> 20%"

    distribution = {label: 0 for _, label in bounded_buckets}
    distribution[top_bucket] = 0

    for gain in gains:
        gain_percent = gain * 100
        for upper_bound, label in bounded_buckets:
            if gain_percent < upper_bound:
                distribution[label] += 1
                break
        else:
            # Not below any bound: 20% or more.
            distribution[top_bucket] += 1
    return distribution
def calculate_trend(scores: List[float]) -> str:
    """
    Classify a score series as "improving", "declining" or "stable".

    The slope of a least-squares line through (index, score) decides the
    label: above 0.05 is improving, below -0.05 is declining, anything else
    is stable. Fewer than three scores yield "insufficient_data".
    """
    n = len(scores)
    if n < 3:
        return "insufficient_data"

    # Simple linear regression over the positions 0..n-1.
    x_mean = (n - 1) / 2
    y_mean = sum(scores) / n
    x_offsets = [i - x_mean for i in range(n)]

    numerator = sum(dx * (score - y_mean) for dx, score in zip(x_offsets, scores))
    denominator = sum(dx ** 2 for dx in x_offsets)
    if denominator == 0:
        return "stable"

    slope = numerator / denominator
    if slope > 0.05:
        return "improving"
    if slope < -0.05:
        return "declining"
    return "stable"
def calculate_difficulty_rating(success_rate: float, avg_attempts: float) -> float:
    """
    Calculate difficulty rating 1-5 based on success metrics.

    Args:
        success_rate: Fraction of successful attempts; values outside 0-1
            are clamped into that range.
        avg_attempts: Average number of attempts. Attempts beyond the first
            raise the rating by at most one point; values below 1 add nothing.

    Returns:
        A rating between 1.0 (easy) and 5.0 (hard).
    """
    # Clamp so an out-of-range rate cannot push the rating outside 1-5.
    success_rate = min(max(success_rate, 0.0), 1.0)
    # Lower success rate and higher attempts = higher difficulty
    base_difficulty = (1 - success_rate) * 3 + 1  # 1-4 range
    # Floor at 0: avg_attempts below 1 used to make the modifier negative,
    # which could return ratings below 1.
    attempt_modifier = min(max(avg_attempts - 1, 0.0), 1.0)  # 0-1 range
    return min(5.0, base_difficulty + attempt_modifier)
+127
View File
@@ -0,0 +1,127 @@
"""
Unit Analytics API - Pydantic Models.
Data models for learning gains, stop performance, misconceptions,
student progress, class comparison, and export.
"""
from typing import List, Optional, Dict, Any
from datetime import datetime
from enum import Enum
from pydantic import BaseModel, Field
class TimeRange(str, Enum):
    """Time range for analytics queries"""
    # Members are plain strings (str mixin); the endpoints pass `.value`
    # on to the analytics database calls.
    WEEK = "week"
    MONTH = "month"
    QUARTER = "quarter"
    ALL = "all"
class LearningGainData(BaseModel):
    """Pre/Post learning gain data point"""
    student_id: str
    student_name: str
    unit_id: str
    precheck_score: float
    postcheck_score: float
    # postcheck_score - precheck_score, as computed by the learning-gain endpoint
    learning_gain: float
    # Rank of this gain among all gains of the analysis, in percent;
    # None until the analysis has filled it in
    percentile: Optional[float] = None
class LearningGainSummary(BaseModel):
    """Aggregated learning gain statistics"""
    unit_id: str
    unit_title: str
    # Number of data points in individual_gains
    total_students: int
    avg_precheck: float
    avg_postcheck: float
    avg_gain: float
    median_gain: float
    # Sample standard deviation of the gains; 0.0 for a single data point
    std_deviation: float
    # The learning-gain endpoint counts gains within +/-0.01 as "no change"
    positive_gain_count: int
    negative_gain_count: int
    no_change_count: int
    # Bucket label -> count; empty when there are no gains
    gain_distribution: Dict[str, int]
    # Sorted by learning gain, highest first
    individual_gains: List[LearningGainData]
class StopPerformance(BaseModel):
    """Performance data for a single stop"""
    stop_id: str
    # Falls back to the stop ID when the database row has no label
    stop_label: str
    attempts_total: int
    success_rate: float
    avg_time_seconds: float
    avg_attempts_before_success: float
    common_errors: List[str]
    difficulty_rating: float  # 1-5 based on performance
class UnitPerformanceDetail(BaseModel):
    """Detailed unit performance breakdown"""
    unit_id: str
    unit_title: str
    # "unknown" when the unit statistics provide no template
    template: str
    total_sessions: int
    completed_sessions: int
    completion_rate: float
    avg_duration_minutes: float
    stops: List[StopPerformance]
    # IDs of stops with difficulty_rating > 3.5 or success_rate < 0.6
    bottleneck_stops: List[str]  # Stops where students struggle most
class MisconceptionEntry(BaseModel):
    """Individual misconception tracking"""
    concept_id: str
    concept_label: str
    misconception_text: str
    # Used to rank entries and summed for the report total
    frequency: int
    affected_student_ids: List[str]
    # Unit and stop the misconception is attributed to
    unit_id: str
    stop_id: str
    # The report endpoint stores "unknown" when the source row has no value
    detected_via: str  # "precheck", "postcheck", "interaction"
    # Both default to the time of the request when the source row lacks them
    first_detected: datetime
    last_detected: datetime
class MisconceptionReport(BaseModel):
    """Comprehensive misconception report"""
    # None when the report is not restricted to one class
    class_id: Optional[str]
    time_range: str
    # Sum of the frequencies of all entries in the report
    total_misconceptions: int
    # Number of distinct concept IDs
    unique_concepts: int
    # At most the ten most frequent entries, most frequent first
    most_common: List[MisconceptionEntry]
    # All entries of the report, grouped by unit ID
    by_unit: Dict[str, List[MisconceptionEntry]]
    # NOTE(review): the endpoint currently fills this with the first three
    # entries and leaves `resolved` empty; no historical comparison is made.
    trending_up: List[MisconceptionEntry]  # Getting more frequent
    resolved: List[MisconceptionEntry]  # No longer appearing
class StudentProgressTimeline(BaseModel):
    """Timeline of student progress"""
    student_id: str
    student_name: str
    # Number of timeline entries flagged as completed
    units_completed: int
    # Sum of the per-session duration_minutes values
    total_time_minutes: int
    # Mean of the available postcheck scores; 0.0 when there are none
    avg_score: float
    # "improving", "stable", "declining", or "insufficient_data" when fewer
    # than three scores exist
    trend: str
    # One dict per session with the keys: date, unit_id, completed, precheck,
    # postcheck, duration_minutes
    timeline: List[Dict[str, Any]]  # List of session events
class ClassComparisonData(BaseModel):
    """Data for comparing class performance"""
    class_id: str
    class_name: str
    student_count: int
    units_assigned: int
    avg_completion_rate: float
    # Sort key of the class comparison endpoint (highest first)
    avg_learning_gain: float
    avg_time_per_unit: float
class ExportFormat(str, Enum):
    """Export format options"""
    # The export endpoints return a plain JSON body for JSON and a
    # downloadable text/csv attachment for CSV.
    JSON = "json"
    CSV = "csv"
+394
View File
@@ -0,0 +1,394 @@
"""
Unit Analytics API - Routes.
All API endpoints for learning gain, stop-level, misconception,
student timeline, class comparison, export, and dashboard analytics.
"""
import logging
import statistics
from datetime import datetime
from typing import Optional, Dict, Any, List
from fastapi import APIRouter, Query
from unit_analytics_models import (
TimeRange,
LearningGainData,
LearningGainSummary,
StopPerformance,
UnitPerformanceDetail,
MisconceptionEntry,
MisconceptionReport,
StudentProgressTimeline,
ClassComparisonData,
)
from unit_analytics_helpers import (
get_analytics_database,
calculate_gain_distribution,
calculate_trend,
calculate_difficulty_rating,
)
# Module logger, named after this module.
logger = logging.getLogger(__name__)
# Router for the analytics endpoints defined below.
router = APIRouter(tags=["Unit Analytics"])
# ==============================================
# API Endpoints - Learning Gain
# ==============================================
# NOTE: Static routes must come BEFORE dynamic routes like /{unit_id}
@router.get("/learning-gain/compare")
async def compare_learning_gains(
    unit_ids: str = Query(..., description="Comma-separated unit IDs"),
    class_id: Optional[str] = Query(None),
    time_range: TimeRange = Query(TimeRange.MONTH),
) -> Dict[str, Any]:
    """
    Compare learning gains across multiple units.

    Runs the per-unit learning gain analysis for every requested unit and
    returns the key figures sorted by average gain (highest first). Units
    whose analysis fails are logged and left out of the comparison.
    """
    # Drop blank items ("a,,b", trailing comma) so no empty unit ID is analysed.
    unit_list = [u.strip() for u in unit_ids.split(",") if u.strip()]
    comparisons = []
    for unit_id in unit_list:
        try:
            summary = await get_learning_gain_analysis(unit_id, class_id, time_range)
            comparisons.append({
                "unit_id": unit_id,
                "avg_gain": summary.avg_gain,
                "median_gain": summary.median_gain,
                "total_students": summary.total_students,
                # max(..., 1) avoids a division by zero for units without data.
                "positive_rate": summary.positive_gain_count / max(summary.total_students, 1),
            })
        except Exception as e:
            logger.error(f"Failed to get comparison for {unit_id}: {e}")
    return {
        "time_range": time_range.value,
        "class_id": class_id,
        "comparisons": sorted(comparisons, key=lambda x: x["avg_gain"], reverse=True),
    }
@router.get("/learning-gain/{unit_id}", response_model=LearningGainSummary)
async def get_learning_gain_analysis(
    unit_id: str,
    class_id: Optional[str] = Query(None, description="Filter by class"),
    time_range: TimeRange = Query(TimeRange.MONTH, description="Time range for analysis"),
) -> LearningGainSummary:
    """
    Get detailed pre/post learning gain analysis for a unit.

    Only sessions that have both a precheck and a postcheck score contribute.
    If no database is available, the lookup fails, or no session qualifies,
    a zeroed summary is returned.
    """
    from bisect import bisect_left

    db = await get_analytics_database()
    individual_gains = []
    if db:
        try:
            sessions = await db.get_unit_sessions_with_scores(
                unit_id=unit_id,
                class_id=class_id,
                time_range=time_range.value
            )
            for session in sessions:
                precheck = session.get("precheck_score")
                postcheck = session.get("postcheck_score")
                if precheck is None or postcheck is None:
                    continue
                individual_gains.append(LearningGainData(
                    student_id=session["student_id"],
                    # `or` also covers student_name=None. `.get(key, default)`
                    # would pass None on, the model would reject it, and the
                    # except below would drop all remaining sessions.
                    student_name=session.get("student_name") or session["student_id"][:8],
                    unit_id=unit_id,
                    precheck_score=precheck,
                    postcheck_score=postcheck,
                    learning_gain=postcheck - precheck,
                ))
        except Exception as e:
            logger.error(f"Failed to get learning gain data: {e}")
    # Calculate statistics
    if not individual_gains:
        return LearningGainSummary(
            unit_id=unit_id,
            unit_title=f"Unit {unit_id}",
            total_students=0,
            avg_precheck=0.0, avg_postcheck=0.0,
            avg_gain=0.0, median_gain=0.0, std_deviation=0.0,
            positive_gain_count=0, negative_gain_count=0, no_change_count=0,
            gain_distribution={}, individual_gains=[],
        )
    gains = [g.learning_gain for g in individual_gains]
    prechecks = [g.precheck_score for g in individual_gains]
    postchecks = [g.postcheck_score for g in individual_gains]
    avg_gain = statistics.mean(gains)
    median_gain = statistics.median(gains)
    std_dev = statistics.stdev(gains) if len(gains) > 1 else 0.0
    # Calculate percentiles. bisect_left finds the first position of a value
    # in the sorted list, the same rank list.index() gave (ties share the
    # lowest rank), without scanning the list for every student.
    sorted_gains = sorted(gains)
    for data in individual_gains:
        rank = bisect_left(sorted_gains, data.learning_gain) + 1
        data.percentile = rank / len(sorted_gains) * 100
    return LearningGainSummary(
        unit_id=unit_id,
        unit_title=f"Unit {unit_id}",
        total_students=len(individual_gains),
        avg_precheck=statistics.mean(prechecks),
        avg_postcheck=statistics.mean(postchecks),
        avg_gain=avg_gain,
        median_gain=median_gain,
        std_deviation=std_dev,
        # Gains within +/-0.01 count as "no change".
        positive_gain_count=sum(1 for g in gains if g > 0.01),
        negative_gain_count=sum(1 for g in gains if g < -0.01),
        no_change_count=sum(1 for g in gains if -0.01 <= g <= 0.01),
        gain_distribution=calculate_gain_distribution(gains),
        individual_gains=sorted(individual_gains, key=lambda x: x.learning_gain, reverse=True),
    )
# ==============================================
# API Endpoints - Stop-Level Analytics
# ==============================================
@router.get("/unit/{unit_id}/stops", response_model=UnitPerformanceDetail)
async def get_unit_stop_analytics(
    unit_id: str,
    class_id: Optional[str] = Query(None),
    time_range: TimeRange = Query(TimeRange.MONTH),
) -> UnitPerformanceDetail:
    """
    Return the per-stop performance breakdown of a unit.

    Each stop gets a difficulty rating, and stops that are hard
    (rating above 3.5) or rarely passed (success rate below 0.6) are listed
    as bottlenecks. Without a database, or when a query fails, the unit-level
    figures fall back to their defaults.
    """
    db = await get_analytics_database()
    stops_data = []
    # Stays empty unless the unit-level query below succeeds.
    unit_stats = {}

    if db:
        try:
            stop_stats = await db.get_stop_performance(
                unit_id=unit_id, class_id=class_id, time_range=time_range.value
            )
            for stop in stop_stats:
                # NOTE(review): the rating assumes a success rate of 0.5 when
                # the row has none, while the reported success_rate then
                # defaults to 0.0 - confirm this is intended.
                difficulty = calculate_difficulty_rating(
                    stop.get("success_rate", 0.5),
                    stop.get("avg_attempts", 1.0)
                )
                stop_entry = StopPerformance(
                    stop_id=stop["stop_id"],
                    stop_label=stop.get("stop_label", stop["stop_id"]),
                    attempts_total=stop.get("total_attempts", 0),
                    success_rate=stop.get("success_rate", 0.0),
                    avg_time_seconds=stop.get("avg_time_seconds", 0.0),
                    avg_attempts_before_success=stop.get("avg_attempts", 1.0),
                    common_errors=stop.get("common_errors", []),
                    difficulty_rating=difficulty,
                )
                stops_data.append(stop_entry)
            unit_stats = await db.get_unit_overall_stats(unit_id, class_id, time_range.value)
        except Exception as e:
            logger.error(f"Failed to get stop analytics: {e}")

    # Identify bottleneck stops
    bottlenecks = []
    for s in stops_data:
        if s.difficulty_rating > 3.5 or s.success_rate < 0.6:
            bottlenecks.append(s.stop_id)

    return UnitPerformanceDetail(
        unit_id=unit_id,
        unit_title=f"Unit {unit_id}",
        template=unit_stats.get("template", "unknown"),
        total_sessions=unit_stats.get("total_sessions", 0),
        completed_sessions=unit_stats.get("completed_sessions", 0),
        completion_rate=unit_stats.get("completion_rate", 0.0),
        avg_duration_minutes=unit_stats.get("avg_duration_minutes", 0.0),
        stops=stops_data,
        bottleneck_stops=bottlenecks,
    )
# ==============================================
# API Endpoints - Misconception Tracking
# ==============================================
@router.get("/misconceptions", response_model=MisconceptionReport)
async def get_misconception_report(
    class_id: Optional[str] = Query(None),
    unit_id: Optional[str] = Query(None),
    time_range: TimeRange = Query(TimeRange.MONTH),
    limit: int = Query(20, ge=1, le=100),
) -> MisconceptionReport:
    """
    Build the misconception report for a class and/or unit.

    The report contains totals, the ten most frequent misconceptions and all
    fetched entries grouped by unit. Without a database, or when the lookup
    fails, the report covers whatever entries were collected up to that point.
    """
    db = await get_analytics_database()
    misconceptions = []

    if db:
        try:
            raw_misconceptions = await db.get_misconceptions(
                class_id=class_id, unit_id=unit_id,
                time_range=time_range.value, limit=limit
            )
            # Appended one by one so entries converted before a failure are kept.
            for m in raw_misconceptions:
                entry = MisconceptionEntry(
                    concept_id=m["concept_id"],
                    concept_label=m["concept_label"],
                    misconception_text=m["misconception_text"],
                    frequency=m["frequency"],
                    affected_student_ids=m.get("student_ids", []),
                    unit_id=m["unit_id"],
                    stop_id=m["stop_id"],
                    detected_via=m.get("detected_via", "unknown"),
                    first_detected=m.get("first_detected", datetime.utcnow()),
                    last_detected=m.get("last_detected", datetime.utcnow()),
                )
                misconceptions.append(entry)
        except Exception as e:
            logger.error(f"Failed to get misconceptions: {e}")

    # Group by unit
    by_unit = {}
    for m in misconceptions:
        by_unit.setdefault(m.unit_id, []).append(m)

    ranked = sorted(misconceptions, key=lambda x: x.frequency, reverse=True)
    distinct_concepts = {m.concept_id for m in misconceptions}

    return MisconceptionReport(
        class_id=class_id,
        time_range=time_range.value,
        total_misconceptions=sum(m.frequency for m in misconceptions),
        unique_concepts=len(distinct_concepts),
        most_common=ranked[:10],
        by_unit=by_unit,
        # Placeholders: the first three fetched entries, and no resolved ones.
        trending_up=misconceptions[:3],
        resolved=[],
    )
@router.get("/misconceptions/student/{student_id}")
async def get_student_misconceptions(
    student_id: str,
    time_range: TimeRange = Query(TimeRange.ALL),
) -> Dict[str, Any]:
    """
    Return the misconceptions recorded for a single student.

    The response also carries up to five remediation suggestions, one per
    leading misconception, each pointing at the unit/stop where it occurred.
    When no database is available, or the lookup fails, an empty result is
    returned instead of an error.
    """
    empty_result = {
        "student_id": student_id,
        "misconceptions": [],
        "recommended_remediation": [],
    }

    db = await get_analytics_database()
    if not db:
        return empty_result

    try:
        misconceptions = await db.get_student_misconceptions(
            student_id=student_id, time_range=time_range.value
        )
        # Only the first five entries are turned into suggestions.
        remediation = []
        for m in misconceptions[:5]:
            remediation.append({
                "concept": m["concept_label"],
                "activity": f"Review {m['unit_id']}/{m['stop_id']}",
            })
        return {
            "student_id": student_id,
            "misconceptions": misconceptions,
            "recommended_remediation": remediation,
        }
    except Exception as e:
        # Best effort: log and fall through to the empty result.
        logger.error(f"Failed to get student misconceptions: {e}")

    return empty_result
# ==============================================
# API Endpoints - Student Progress Timeline
# ==============================================
@router.get("/student/{student_id}/timeline", response_model=StudentProgressTimeline)
async def get_student_timeline(
    student_id: str,
    time_range: TimeRange = Query(TimeRange.ALL),
) -> StudentProgressTimeline:
    """
    Get detailed progress timeline for a student.

    Builds one timeline entry per session returned by the analytics database
    and derives the performance trend from the postcheck scores. If no
    database is available, or the lookup fails, the timeline contains
    whatever was collected up to that point (possibly nothing).
    """
    db = await get_analytics_database()
    timeline = []
    scores = []
    if db:
        try:
            sessions = await db.get_student_sessions(
                student_id=student_id, time_range=time_range.value
            )
            for session in sessions:
                # A row can carry duration_seconds=None (e.g. an unfinished
                # session). `.get(key, 0)` only covers a *missing* key, so
                # None used to raise TypeError and drop the remaining sessions.
                duration_seconds = session.get("duration_seconds") or 0
                postcheck = session.get("postcheck_score")
                timeline.append({
                    "date": session.get("started_at"),
                    "unit_id": session.get("unit_id"),
                    "completed": session.get("completed_at") is not None,
                    "precheck": session.get("precheck_score"),
                    "postcheck": postcheck,
                    # int() keeps whole minutes integral for float durations.
                    "duration_minutes": int(duration_seconds // 60),
                })
                if postcheck is not None:
                    scores.append(postcheck)
        except Exception as e:
            logger.error(f"Failed to get student timeline: {e}")
    trend = calculate_trend(scores) if scores else "insufficient_data"
    return StudentProgressTimeline(
        student_id=student_id,
        student_name=f"Student {student_id[:8]}",
        units_completed=sum(1 for t in timeline if t["completed"]),
        total_time_minutes=sum(t["duration_minutes"] for t in timeline),
        avg_score=statistics.mean(scores) if scores else 0.0,
        trend=trend,
        timeline=timeline,
    )
# ==============================================
# API Endpoints - Class Comparison
# ==============================================
@router.get("/compare/classes", response_model=List[ClassComparisonData])
async def compare_classes(
    class_ids: str = Query(..., description="Comma-separated class IDs"),
    time_range: TimeRange = Query(TimeRange.MONTH),
) -> List[ClassComparisonData]:
    """
    Compare performance across multiple classes.

    Returns one entry per class whose statistics could be loaded, sorted by
    average learning gain (highest first). Classes whose lookup fails are
    logged and left out; without a database the result is empty.
    """
    # Drop blank items ("a,,b", trailing comma) so no empty class ID is queried.
    class_list = [c.strip() for c in class_ids.split(",") if c.strip()]
    comparisons = []
    db = await get_analytics_database()
    if db:
        for class_id in class_list:
            try:
                stats = await db.get_class_aggregate_stats(class_id, time_range.value)
                comparisons.append(ClassComparisonData(
                    class_id=class_id,
                    class_name=stats.get("class_name", f"Klasse {class_id[:8]}"),
                    student_count=stats.get("student_count", 0),
                    units_assigned=stats.get("units_assigned", 0),
                    avg_completion_rate=stats.get("avg_completion_rate", 0.0),
                    avg_learning_gain=stats.get("avg_learning_gain", 0.0),
                    avg_time_per_unit=stats.get("avg_time_per_unit", 0.0),
                ))
            except Exception as e:
                logger.error(f"Failed to get stats for class {class_id}: {e}")
    return sorted(comparisons, key=lambda x: x.avg_learning_gain, reverse=True)
@@ -0,0 +1,200 @@
"""
Compliance Extraction & Generation.
Functions for extracting checkpoints from legal text chunks,
generating controls, and creating remediation measures.
"""
import re
import hashlib
import logging
from typing import Dict, List, Optional
from compliance_models import Checkpoint, Control, Measure
# Module logger, named after this module.
logger = logging.getLogger(__name__)
def extract_checkpoints_from_chunk(chunk_text: str, payload: Dict) -> List[Checkpoint]:
    """
    Extract checkpoints/requirements from a chunk of text.

    Uses pattern matching to find requirement-like statements: BSI-TR style
    requirement IDs, article references, numbered paragraphs and German
    obligation phrases.

    Args:
        chunk_text: The text chunk to scan.
        payload: Chunk metadata; "regulation_code", "regulation_name" and
            "source_url" are read, each with a fallback when missing.

    Returns:
        One Checkpoint per match whose description has at least 20
        characters, in pattern order.
    """
    checkpoints = []
    regulation_code = payload.get("regulation_code", "UNKNOWN")
    regulation_name = payload.get("regulation_name", "Unknown")
    source_url = payload.get("source_url", "")
    # Short identifier derived from the first 100 characters of the chunk.
    chunk_id = hashlib.md5(chunk_text[:100].encode()).hexdigest()[:8]
    # Patterns for different requirement types
    patterns = [
        # BSI-TR patterns
        (r'([OT]\.[A-Za-z_]+\d*)[:\s]+(.+?)(?=\n[OT]\.|$)', 'bsi_requirement'),
        # Article patterns (GDPR, AI Act, etc.)
        (r'(?:Artikel|Art\.?)\s+(\d+)(?:\s+Abs(?:atz)?\.?\s*(\d+))?\s*[-\u2013:]\s*(.+?)(?=\n|$)', 'article'),
        # Numbered requirements
        (r'\((\d+)\)\s+(.+?)(?=\n\(\d+\)|$)', 'numbered'),
        # "Der Verantwortliche muss" patterns
        (r'(?:Der Verantwortliche|Die Aufsichtsbeh\u00f6rde|Der Auftragsverarbeiter)\s+(muss|hat|soll)\s+(.+?)(?=\.\s|$)', 'obligation'),
        # "Es ist erforderlich" patterns
        (r'(?:Es ist erforderlich|Es muss gew\u00e4hrleistet|Es sind geeignete)\s+(.+?)(?=\.\s|$)', 'requirement'),
    ]
    for pattern, pattern_type in patterns:
        for match in re.finditer(pattern, chunk_text, re.MULTILINE | re.DOTALL):
            # Only article matches carry an article reference. The previous
            # check ("'Art' in title") also fired for generic matches whose
            # truncated description happened to contain "Art".
            article = None
            if pattern_type == 'bsi_requirement':
                req_id = match.group(1)
                description = match.group(2).strip()
                title = req_id
            elif pattern_type == 'article':
                article_num = match.group(1)
                paragraph = match.group(2) or ""
                title_text = match.group(3).strip()
                req_id = f"{regulation_code}-Art{article_num}"
                if paragraph:
                    req_id += f"-{paragraph}"
                title = f"Art. {article_num}" + (f" Abs. {paragraph}" if paragraph else "")
                description = title_text
                article = title
            elif pattern_type == 'numbered':
                num = match.group(1)
                description = match.group(2).strip()
                req_id = f"{regulation_code}-{num}"
                title = f"Anforderung {num}"
            else:
                # Generic requirement: the whole match is the description and
                # the ID is made unique per chunk by the running count.
                description = match.group(0).strip()
                req_id = f"{regulation_code}-{chunk_id}-{len(checkpoints)}"
                title = description[:50] + "..." if len(description) > 50 else description
            # Skip very short matches
            if len(description) < 20:
                continue
            checkpoints.append(Checkpoint(
                id=req_id,
                regulation_code=regulation_code,
                regulation_name=regulation_name,
                article=article,
                title=title,
                description=description[:500],
                original_text=description,
                chunk_id=chunk_id,
                source_url=source_url
            ))
    return checkpoints
def generate_control_for_checkpoints(
    checkpoints: List[Checkpoint],
    domain_counts: Dict[str, int],
) -> Optional[Control]:
    """
    Generate a control that covers the given checkpoints.

    This is a simplified version - in production this would use the AI assistant.

    Args:
        checkpoints: Checkpoints the control should cover; the first one
            supplies the regulation, title and description.
        domain_counts: Number of controls already generated per domain, used
            to number the new control. It is read, not modified.

    Returns:
        The generated Control, or None when `checkpoints` is empty.
    """
    if not checkpoints:
        return None
    # Group by regulation
    regulation = checkpoints[0].regulation_code
    # Determine domain based on content
    all_text = " ".join([cp.description for cp in checkpoints]).lower()

    # (domain, keywords matched anywhere in the text, keywords matched only as
    # whole words), checked in this order; the first hit wins. "ki" and "ai"
    # are whole-word only: as substrings they match everyday words such as
    # "maintain" or "email" and wrongly classified text as AI.
    domain_rules = [
        ("crypto", ["verschl\u00fcssel", "krypto", "encrypt", "hash"], []),
        ("iam", ["zugang", "access", "authentif", "login", "benutzer"], []),
        ("priv", ["datenschutz", "personenbezogen", "privacy", "einwilligung"], []),
        ("sdlc", ["entwicklung", "test", "code", "software"], []),
        ("aud", ["\u00fcberwach", "monitor", "log", "audit"], []),
        ("ai", ["k\u00fcnstlich", "machine learning", "model"], ["ki", "ai"]),
        ("ops", ["betrieb", "operation", "verf\u00fcgbar", "backup"], []),
        ("cra", ["cyber", "resilience", "sbom", "vulnerab"], []),
    ]
    domain = "gov"  # Default
    for candidate, substrings, whole_words in domain_rules:
        if any(kw in all_text for kw in substrings) or any(
            re.search(rf"\b{re.escape(word)}\b", all_text) for word in whole_words
        ):
            domain = candidate
            break
    # Generate control ID
    domain_count = domain_counts.get(domain, 0) + 1
    control_id = f"{domain.upper()}-{domain_count:03d}"
    # Create title from first checkpoint
    title = checkpoints[0].title
    if len(title) > 100:
        title = title[:97] + "..."
    # Create description
    description = f"Control f\u00fcr {regulation}: " + checkpoints[0].description[:200]
    # Pass criteria
    pass_criteria = f"Alle {len(checkpoints)} zugeh\u00f6rigen Anforderungen sind erf\u00fcllt und dokumentiert."
    # Implementation guidance
    guidance = f"Implementiere Ma\u00dfnahmen zur Erf\u00fcllung der Anforderungen aus {regulation}. "
    guidance += "Dokumentiere die Umsetzung und f\u00fchre regelm\u00e4\u00dfige Reviews durch."
    # Determine if automated
    is_automated = any(kw in all_text for kw in ["automat", "tool", "scan", "test"])
    return Control(
        id=control_id,
        domain=domain,
        title=title,
        description=description,
        checkpoints=[cp.id for cp in checkpoints],
        pass_criteria=pass_criteria,
        implementation_guidance=guidance,
        is_automated=is_automated,
        automation_tool="CI/CD Pipeline" if is_automated else None,
        priority="high" if "muss" in all_text or "erforderlich" in all_text else "medium"
    )
def generate_measure_for_control(control: Control) -> Measure:
    """Build the remediation measure that accompanies ``control``.

    The measure ID is the control ID prefixed with ``M-``. The deadline is
    looked up from the control's priority and the responsible party from its
    domain. Unknown priorities fall back to 90 days, unknown domains to the
    "Compliance Team". Every new measure starts in status ``pending``.

    Args:
        control: The control the measure is derived from.

    Returns:
        A new ``Measure`` linked to ``control`` through ``control_id``.
    """
    # Days until the measure is due, keyed by control priority.
    days_by_priority = {"critical": 30, "high": 60, "medium": 90, "low": 180}

    # Responsible team or role, keyed by control domain.
    owner_by_domain = {
        "priv": "Datenschutzbeauftragter",
        "iam": "IT-Security Team",
        "sdlc": "Entwicklungsteam",
        "crypto": "IT-Security Team",
        "ops": "Operations Team",
        "aud": "Compliance Team",
        "ai": "AI/ML Team",
        "cra": "IT-Security Team",
        "gov": "Management",
    }

    # Title and description are truncated so the measure texts stay short.
    short_title = control.title[:50]
    short_description = control.description[:100]

    return Measure(
        id=f"M-{control.id}",
        control_id=control.id,
        title=f"Umsetzung: {short_title}",
        description=f"Implementierung und Dokumentation von {control.id}: {short_description}",
        responsible=owner_by_domain.get(control.domain, "Compliance Team"),
        deadline_days=days_by_priority.get(control.priority, 90),
        status="pending",
    )
@@ -0,0 +1,49 @@
"""
Compliance Pipeline Data Models.
Dataclasses for checkpoints, controls, and measures.
"""
from typing import Optional, List
from dataclasses import dataclass
@dataclass
class Checkpoint:
    """A requirement/checkpoint extracted from legal text.

    The pipeline collects these from corpus chunks, groups them by
    ``regulation_code`` and derives controls from batches of them.
    """
    id: str  # Checkpoint identifier; controls store these IDs in their ``checkpoints`` list
    regulation_code: str  # Regulation code; used as the grouping key during control generation
    regulation_name: str  # presumably the human-readable regulation name -- TODO confirm
    article: Optional[str]  # presumably the article/section reference, if known -- TODO confirm
    title: str  # Short title; the first checkpoint's title becomes the control title
    description: str  # Description; its first 200 characters seed the control description
    original_text: str  # presumably the passage the checkpoint was extracted from -- TODO confirm
    chunk_id: str  # presumably the ID of the originating corpus chunk -- TODO confirm
    source_url: str  # presumably the URL of the source document -- TODO confirm
@dataclass
class Control:
    """A control derived from checkpoints.

    Created by the control generator from a batch of checkpoints and later
    used as the input for measure generation.
    """
    id: str  # Generated as "<DOMAIN>-<NNN>": upper-cased domain plus a zero-padded counter
    domain: str  # Domain key, e.g. "gov", "crypto", "iam", "priv", "sdlc", "aud", "ai", "ops", "cra"
    title: str  # Title of the first checkpoint, shortened to at most 100 characters
    description: str  # Regulation-prefixed text built from the first checkpoint's description
    checkpoints: List[str]  # List of checkpoint IDs
    pass_criteria: str  # Human-readable statement of when the control counts as fulfilled
    implementation_guidance: str  # Free-text guidance on how to implement the control
    is_automated: bool  # True when the checkpoint text hints at automation (keyword match)
    automation_tool: Optional[str]  # "CI/CD Pipeline" for automated controls, otherwise None
    priority: str  # The generator emits "high" or "medium"; measure deadlines also know "critical"/"low"
@dataclass
class Measure:
    """A remediation measure for a control.

    One measure is generated per control by ``generate_measure_for_control``.
    """
    id: str  # "M-" followed by the ID of the control
    control_id: str  # ID of the control this measure belongs to
    title: str  # "Umsetzung: " plus the first 50 characters of the control title
    description: str  # Text built from the control ID and the first 100 characters of its description
    responsible: str  # Team or role derived from the control domain
    deadline_days: int  # Days until due, derived from the control priority (90 if unknown)
    status: str  # Generated measures start as "pending"
@@ -0,0 +1,441 @@
"""
Compliance Pipeline Execution.
Pipeline phases (ingestion, extraction, control generation, measures)
and orchestration logic.
"""
import asyncio
import json
import logging
import os
import sys
import time
from datetime import datetime
from typing import Dict, List, Any
from dataclasses import asdict
from compliance_models import Checkpoint, Control, Measure
from compliance_extraction import (
extract_checkpoints_from_chunk,
generate_control_for_checkpoints,
generate_measure_for_control,
)
logger = logging.getLogger(__name__)
# Import checkpoint manager
try:
from pipeline_checkpoints import CheckpointManager, EXPECTED_VALUES, ValidationStatus
except ImportError:
logger.warning("Checkpoint manager not available, running without checkpoints")
CheckpointManager = None
EXPECTED_VALUES = {}
ValidationStatus = None
# Set environment variables for Docker network
if not os.getenv("QDRANT_URL") and not os.getenv("QDRANT_HOST"):
os.environ["QDRANT_HOST"] = "qdrant"
os.environ.setdefault("EMBEDDING_SERVICE_URL", "http://embedding-service:8087")
# Try to import from klausur-service
try:
from legal_corpus_ingestion import LegalCorpusIngestion, REGULATIONS, LEGAL_CORPUS_COLLECTION
from qdrant_client import QdrantClient
from qdrant_client.models import Filter, FieldCondition, MatchValue
except ImportError:
logger.error("Could not import required modules. Make sure you're in the klausur-service container.")
sys.exit(1)
class CompliancePipeline:
    """Orchestrate the full compliance pipeline.

    The pipeline runs four phases in order -- document ingestion, checkpoint
    extraction, control generation and measure generation -- and then writes
    the results as JSON files. When a ``CheckpointManager`` is available,
    every phase reports metrics and validation results to it.
    """

    def __init__(self):
        """Connect to Qdrant and initialise empty result containers."""
        qdrant_host, qdrant_port = self._resolve_qdrant_endpoint()
        self.qdrant = QdrantClient(host=qdrant_host, port=qdrant_port)
        self.checkpoints: List[Checkpoint] = []
        self.controls: List[Control] = []
        self.measures: List[Measure] = []
        self.stats = {
            "chunks_processed": 0,
            "checkpoints_extracted": 0,
            "controls_created": 0,
            "measures_defined": 0,
            "by_regulation": {},
            "by_domain": {},
        }
        # Checkpoint tracking is optional; it is disabled when the manager
        # class is not available.
        self.checkpoint_mgr = CheckpointManager() if CheckpointManager else None

    @staticmethod
    def _resolve_qdrant_endpoint():
        """Return the ``(host, port)`` of Qdrant from the environment.

        ``QDRANT_URL`` takes precedence. Otherwise ``QDRANT_HOST`` and
        ``QDRANT_PORT`` are used. Missing or invalid parts fall back to host
        ``qdrant`` and port ``6333``.
        """
        from urllib.parse import urlparse

        qdrant_url = os.getenv("QDRANT_URL", "")
        if qdrant_url:
            parsed = urlparse(qdrant_url)
            return parsed.hostname or "qdrant", parsed.port or 6333

        host = os.getenv("QDRANT_HOST", "qdrant")
        raw_port = os.getenv("QDRANT_PORT", "")
        if not raw_port:
            return host, 6333
        try:
            return host, int(raw_port)
        except ValueError:
            logger.warning(f"Invalid QDRANT_PORT {raw_port!r}, falling back to 6333")
            return host, 6333

    @staticmethod
    def _log_banner(title: str) -> None:
        """Log ``title`` framed by separator lines."""
        logger.info("\n" + "=" * 60)
        logger.info(title)
        logger.info("=" * 60)

    @staticmethod
    def _write_json(path: str, payload: Any) -> None:
        """Write ``payload`` to ``path`` as indented UTF-8 JSON."""
        # ensure_ascii=False emits raw non-ASCII characters, so the file
        # encoding must not depend on the platform default.
        with open(path, "w", encoding="utf-8") as f:
            json.dump(payload, f, indent=2, ensure_ascii=False)

    async def run_ingestion_phase(self, force_reindex: bool = False) -> int:
        """Phase 1: ingest documents (incremental - only missing ones).

        Args:
            force_reindex: If True, ingest every regulation even when chunks
                for it already exist.

        Returns:
            The chunk total recorded in ``stats["chunks_processed"]``.

        Raises:
            Exception: Any error outside the per-regulation handling is
                reported to the checkpoint manager and re-raised.
        """
        self._log_banner("PHASE 1: DOCUMENT INGESTION (INCREMENTAL)")
        if self.checkpoint_mgr:
            self.checkpoint_mgr.start_checkpoint("ingestion", "Document Ingestion")

        ingestion = LegalCorpusIngestion()
        try:
            # Count existing chunks per regulation. A failure here is not
            # fatal: regulations without a count are treated as missing.
            existing_chunks = {}
            try:
                for regulation in REGULATIONS:
                    count_result = self.qdrant.count(
                        collection_name=LEGAL_CORPUS_COLLECTION,
                        count_filter=Filter(
                            must=[FieldCondition(key="regulation_code", match=MatchValue(value=regulation.code))]
                        )
                    )
                    existing_chunks[regulation.code] = count_result.count
                    logger.info(f" {regulation.code}: {count_result.count} existing chunks")
            except Exception as e:
                logger.warning(f"Could not check existing chunks: {e}")

            # Decide which regulations need ingestion.
            regulations_to_ingest = []
            for regulation in REGULATIONS:
                existing = existing_chunks.get(regulation.code, 0)
                if force_reindex or existing == 0:
                    regulations_to_ingest.append(regulation)
                    logger.info(f" -> Will ingest: {regulation.code} (existing: {existing}, force: {force_reindex})")
                else:
                    logger.info(f" -> Skipping: {regulation.code} (already has {existing} chunks)")
                    self.stats["by_regulation"][regulation.code] = existing

            if not regulations_to_ingest:
                logger.info("All regulations already indexed. Skipping ingestion phase.")
                total_chunks = sum(existing_chunks.values())
                self.stats["chunks_processed"] = total_chunks
                if self.checkpoint_mgr:
                    self.checkpoint_mgr.add_metric("total_chunks", total_chunks)
                    self.checkpoint_mgr.add_metric("skipped", True)
                    self.checkpoint_mgr.complete_checkpoint(success=True)
                return total_chunks

            # Ingest only the selected regulations.
            # NOTE(review): with force_reindex the total starts at the existing
            # count and then adds the re-ingested count; if ingestion replaces
            # chunks instead of appending, this double-counts -- confirm.
            total_chunks = sum(existing_chunks.values())
            for i, regulation in enumerate(regulations_to_ingest, 1):
                logger.info(f"[{i}/{len(regulations_to_ingest)}] Ingesting {regulation.code}...")
                try:
                    count = await ingestion.ingest_regulation(regulation)
                    total_chunks += count
                    self.stats["by_regulation"][regulation.code] = count
                    logger.info(f" -> {count} chunks")
                    if self.checkpoint_mgr:
                        self.checkpoint_mgr.add_metric(f"chunks_{regulation.code}", count)
                except Exception as e:
                    # One failing regulation must not abort the others.
                    logger.error(f" -> FAILED: {e}")
                    self.stats["by_regulation"][regulation.code] = 0

            self.stats["chunks_processed"] = total_chunks
            logger.info(f"\nTotal chunks in collection: {total_chunks}")

            # Validate ingestion results against the expected values.
            if self.checkpoint_mgr:
                self.checkpoint_mgr.add_metric("total_chunks", total_chunks)
                self.checkpoint_mgr.add_metric("regulations_count", len(REGULATIONS))
                expected = EXPECTED_VALUES.get("ingestion", {})
                self.checkpoint_mgr.validate(
                    "total_chunks",
                    expected=expected.get("total_chunks", 8000),
                    actual=total_chunks,
                    min_value=expected.get("min_chunks", 7000)
                )
                reg_expected = expected.get("regulations", {})
                for reg_code, reg_exp in reg_expected.items():
                    actual = self.stats["by_regulation"].get(reg_code, 0)
                    self.checkpoint_mgr.validate(
                        f"chunks_{reg_code}",
                        expected=reg_exp.get("expected", 0),
                        actual=actual,
                        min_value=reg_exp.get("min", 0)
                    )
                self.checkpoint_mgr.complete_checkpoint(success=True)
            return total_chunks
        except Exception as e:
            if self.checkpoint_mgr:
                self.checkpoint_mgr.fail_checkpoint(str(e))
            raise
        finally:
            await ingestion.close()

    async def run_extraction_phase(self) -> int:
        """Phase 2: extract checkpoints from all chunks in the collection.

        Scrolls through the collection in pages of 100 points and appends the
        checkpoints found in every chunk to ``self.checkpoints``.

        Returns:
            The number of checkpoints held in ``self.checkpoints``.

        Raises:
            Exception: Any error is reported to the checkpoint manager and
                re-raised.
        """
        self._log_banner("PHASE 2: CHECKPOINT EXTRACTION")
        if self.checkpoint_mgr:
            self.checkpoint_mgr.start_checkpoint("extraction", "Checkpoint Extraction")

        try:
            offset = None
            total_checkpoints = 0
            while True:
                points, next_offset = self.qdrant.scroll(
                    collection_name=LEGAL_CORPUS_COLLECTION,
                    limit=100,
                    offset=offset,
                    with_payload=True,
                    with_vectors=False
                )
                if not points:
                    break
                for point in points:
                    payload = point.payload
                    text = payload.get("text", "")
                    cps = extract_checkpoints_from_chunk(text, payload)
                    self.checkpoints.extend(cps)
                    total_checkpoints += len(cps)
                logger.info(f"Processed {len(points)} chunks, extracted {total_checkpoints} checkpoints so far...")
                # A missing next offset marks the last page.
                if next_offset is None:
                    break
                offset = next_offset

            self.stats["checkpoints_extracted"] = len(self.checkpoints)
            logger.info(f"\nTotal checkpoints extracted: {len(self.checkpoints)}")

            by_reg = {}
            for cp in self.checkpoints:
                by_reg[cp.regulation_code] = by_reg.get(cp.regulation_code, 0) + 1
            for reg, count in sorted(by_reg.items()):
                logger.info(f" {reg}: {count} checkpoints")

            if self.checkpoint_mgr:
                self.checkpoint_mgr.add_metric("total_checkpoints", len(self.checkpoints))
                self.checkpoint_mgr.add_metric("checkpoints_by_regulation", by_reg)
                expected = EXPECTED_VALUES.get("extraction", {})
                self.checkpoint_mgr.validate(
                    "total_checkpoints",
                    expected=expected.get("total_checkpoints", 3500),
                    actual=len(self.checkpoints),
                    min_value=expected.get("min_checkpoints", 3000)
                )
                self.checkpoint_mgr.complete_checkpoint(success=True)
            return len(self.checkpoints)
        except Exception as e:
            if self.checkpoint_mgr:
                self.checkpoint_mgr.fail_checkpoint(str(e))
            raise

    async def run_control_generation_phase(self) -> int:
        """Phase 3: generate controls from the extracted checkpoints.

        Checkpoints are grouped by regulation and turned into one control per
        batch of up to four checkpoints.

        Returns:
            The number of controls held in ``self.controls``.

        Raises:
            Exception: Any error is reported to the checkpoint manager and
                re-raised.
        """
        self._log_banner("PHASE 3: CONTROL GENERATION")
        if self.checkpoint_mgr:
            self.checkpoint_mgr.start_checkpoint("controls", "Control Generation")

        try:
            # Group checkpoints by regulation.
            by_regulation: Dict[str, List[Checkpoint]] = {}
            for cp in self.checkpoints:
                by_regulation.setdefault(cp.regulation_code, []).append(cp)

            batch_size = 4
            for regulation, checkpoints in by_regulation.items():
                logger.info(f"Generating controls for {regulation} ({len(checkpoints)} checkpoints)...")
                for i in range(0, len(checkpoints), batch_size):
                    batch = checkpoints[i:i + batch_size]
                    # The per-domain counts let the generator number the
                    # control IDs consecutively within each domain.
                    control = generate_control_for_checkpoints(batch, self.stats.get("by_domain", {}))
                    if control:
                        self.controls.append(control)
                        self.stats["by_domain"][control.domain] = self.stats["by_domain"].get(control.domain, 0) + 1

            self.stats["controls_created"] = len(self.controls)
            logger.info(f"\nTotal controls created: {len(self.controls)}")
            for domain, count in sorted(self.stats["by_domain"].items()):
                logger.info(f" {domain}: {count} controls")

            if self.checkpoint_mgr:
                self.checkpoint_mgr.add_metric("total_controls", len(self.controls))
                self.checkpoint_mgr.add_metric("controls_by_domain", dict(self.stats["by_domain"]))
                expected = EXPECTED_VALUES.get("controls", {})
                self.checkpoint_mgr.validate(
                    "total_controls",
                    expected=expected.get("total_controls", 900),
                    actual=len(self.controls),
                    min_value=expected.get("min_controls", 800)
                )
                self.checkpoint_mgr.complete_checkpoint(success=True)
            return len(self.controls)
        except Exception as e:
            if self.checkpoint_mgr:
                self.checkpoint_mgr.fail_checkpoint(str(e))
            raise

    async def run_measure_generation_phase(self) -> int:
        """Phase 4: generate one measure per control.

        Returns:
            The number of measures held in ``self.measures``.

        Raises:
            Exception: Any error is reported to the checkpoint manager and
                re-raised.
        """
        self._log_banner("PHASE 4: MEASURE GENERATION")
        if self.checkpoint_mgr:
            self.checkpoint_mgr.start_checkpoint("measures", "Measure Generation")

        try:
            for control in self.controls:
                self.measures.append(generate_measure_for_control(control))

            self.stats["measures_defined"] = len(self.measures)
            logger.info(f"\nTotal measures defined: {len(self.measures)}")

            if self.checkpoint_mgr:
                self.checkpoint_mgr.add_metric("total_measures", len(self.measures))
                expected = EXPECTED_VALUES.get("measures", {})
                self.checkpoint_mgr.validate(
                    "total_measures",
                    expected=expected.get("total_measures", 900),
                    actual=len(self.measures),
                    min_value=expected.get("min_measures", 800)
                )
                self.checkpoint_mgr.complete_checkpoint(success=True)
            return len(self.measures)
        except Exception as e:
            if self.checkpoint_mgr:
                self.checkpoint_mgr.fail_checkpoint(str(e))
            raise

    def save_results(self, output_dir: str = "/tmp/compliance_output"):
        """Save checkpoints, controls, measures and statistics as JSON files.

        Also stamps ``stats["generated_at"]`` with the current local time.

        Args:
            output_dir: Target directory; created if it does not exist.
        """
        self._log_banner("SAVING RESULTS")
        os.makedirs(output_dir, exist_ok=True)

        checkpoints_file = os.path.join(output_dir, "checkpoints.json")
        self._write_json(checkpoints_file, [asdict(cp) for cp in self.checkpoints])
        logger.info(f"Saved {len(self.checkpoints)} checkpoints to {checkpoints_file}")

        controls_file = os.path.join(output_dir, "controls.json")
        self._write_json(controls_file, [asdict(c) for c in self.controls])
        logger.info(f"Saved {len(self.controls)} controls to {controls_file}")

        measures_file = os.path.join(output_dir, "measures.json")
        self._write_json(measures_file, [asdict(m) for m in self.measures])
        logger.info(f"Saved {len(self.measures)} measures to {measures_file}")

        stats_file = os.path.join(output_dir, "statistics.json")
        self.stats["generated_at"] = datetime.now().isoformat()
        self._write_json(stats_file, self.stats)
        logger.info(f"Saved statistics to {stats_file}")

    async def run_full_pipeline(self, force_reindex: bool = False, skip_ingestion: bool = False):
        """Run the complete pipeline and save the results.

        Args:
            force_reindex: If True, re-ingest all documents even if they exist.
            skip_ingestion: If True, skip the ingestion phase entirely and use
                the chunks already in the collection.

        Raises:
            Exception: Any phase failure is logged, marks the pipeline state
                as failed (when checkpoints are enabled) and is re-raised.
        """
        start_time = time.time()
        logger.info("=" * 60)
        logger.info("FULL COMPLIANCE PIPELINE (INCREMENTAL)")
        logger.info(f"Started at: {datetime.now().isoformat()}")
        logger.info(f"Force reindex: {force_reindex}")
        logger.info(f"Skip ingestion: {skip_ingestion}")
        if self.checkpoint_mgr:
            logger.info(f"Pipeline ID: {self.checkpoint_mgr.pipeline_id}")
        logger.info("=" * 60)

        try:
            if skip_ingestion:
                logger.info("Skipping ingestion phase as requested...")
                try:
                    collection_info = self.qdrant.get_collection(LEGAL_CORPUS_COLLECTION)
                    self.stats["chunks_processed"] = collection_info.points_count
                except Exception as e:
                    # Best effort: the chunk count is informational only, but
                    # the failure should be visible in the log.
                    logger.warning(f"Could not read chunk count from Qdrant: {e}")
                    self.stats["chunks_processed"] = 0
            else:
                await self.run_ingestion_phase(force_reindex=force_reindex)

            await self.run_extraction_phase()
            await self.run_control_generation_phase()
            await self.run_measure_generation_phase()
            self.save_results()

            elapsed = time.time() - start_time
            self._log_banner("PIPELINE COMPLETE")
            logger.info(f"Duration: {elapsed:.1f} seconds")
            logger.info(f"Chunks processed: {self.stats['chunks_processed']}")
            logger.info(f"Checkpoints extracted: {self.stats['checkpoints_extracted']}")
            logger.info(f"Controls created: {self.stats['controls_created']}")
            logger.info(f"Measures defined: {self.stats['measures_defined']}")
            logger.info("\nResults saved to: /tmp/compliance_output/")
            logger.info("Checkpoint status: /tmp/pipeline_checkpoints.json")
            logger.info("=" * 60)

            if self.checkpoint_mgr:
                self.checkpoint_mgr.complete_pipeline({
                    "duration_seconds": elapsed,
                    "chunks_processed": self.stats['chunks_processed'],
                    "checkpoints_extracted": self.stats['checkpoints_extracted'],
                    "controls_created": self.stats['controls_created'],
                    "measures_defined": self.stats['measures_defined'],
                    "by_regulation": self.stats['by_regulation'],
                    "by_domain": self.stats['by_domain'],
                })
        except Exception as e:
            logger.error(f"Pipeline failed: {e}")
            if self.checkpoint_mgr:
                self.checkpoint_mgr.state.status = "failed"
                self.checkpoint_mgr._save()
            raise
+54 -702
View File
@@ -1,7 +1,10 @@
"""
DSFA RAG API Endpoints.
DSFA RAG API Endpoints — Barrel Re-export.
Provides REST API for searching DSFA corpus with full source attribution.
Split into submodules:
- dsfa_rag_models.py — Pydantic request/response models
- dsfa_rag_embedding.py — Embedding service integration & text extraction
- dsfa_rag_routes.py — Route handlers (search, sources, ingest, stats)
Endpoints:
- GET /api/v1/dsfa-rag/search - Semantic search with attribution
@@ -11,705 +14,54 @@ Endpoints:
- GET /api/v1/dsfa-rag/stats - Get corpus statistics
"""
import os
import uuid
import logging
from typing import List, Optional
from dataclasses import dataclass, asdict
import httpx
from fastapi import APIRouter, HTTPException, Query, Depends
from pydantic import BaseModel, Field
logger = logging.getLogger(__name__)
# Embedding service configuration
EMBEDDING_SERVICE_URL = os.getenv("EMBEDDING_SERVICE_URL", "http://172.18.0.13:8087")
# Import from ingestion module
from dsfa_corpus_ingestion import (
DSFACorpusStore,
DSFAQdrantService,
DSFASearchResult,
LICENSE_REGISTRY,
DSFA_SOURCES,
generate_attribution_notice,
get_license_label,
DSFA_COLLECTION,
chunk_document
# Models
from dsfa_rag_models import (
DSFASourceResponse,
DSFAChunkResponse,
DSFASearchResultResponse,
DSFASearchResponse,
DSFASourceStatsResponse,
DSFACorpusStatsResponse,
IngestRequest,
IngestResponse,
LicenseInfo,
)
router = APIRouter(prefix="/api/v1/dsfa-rag", tags=["DSFA RAG"])
# =============================================================================
# Pydantic Models
# =============================================================================
class DSFASourceResponse(BaseModel):
    """Response model for a registered DSFA source, including license info."""
    id: str  # Source ID; the endpoints pass the stored ID through str()
    source_code: str  # Source code, also used as the path parameter of the source endpoints
    name: str
    full_name: Optional[str] = None
    organization: Optional[str] = None
    source_url: Optional[str] = None
    license_code: str  # Key into LICENSE_REGISTRY; "" when the source has none
    license_name: str  # Registry name; the endpoints fall back to the license code
    license_url: Optional[str] = None  # Taken from the license registry entry, if present
    attribution_required: bool  # The endpoints default this to True when the source has no value
    attribution_text: str
    document_type: Optional[str] = None
    language: str = "de"
class DSFAChunkResponse(BaseModel):
    """Response model for a single chunk with attribution.

    Returned by the chunk lookup endpoint; combines the chunk content with
    the document it belongs to and the source/license attribution.
    """
    chunk_id: str
    content: str
    section_title: Optional[str] = None
    page_number: Optional[int] = None
    category: Optional[str] = None
    # Document info
    document_id: str
    document_title: Optional[str] = None
    # Attribution (always included)
    source_id: str
    source_code: str
    source_name: str
    attribution_text: str
    license_code: str
    license_name: str  # Registry name; the endpoint falls back to the license code
    license_url: Optional[str] = None
    attribution_required: bool  # The endpoint defaults this to True when the chunk has no value
    source_url: Optional[str] = None
    document_type: Optional[str] = None
class DSFASearchResultResponse(BaseModel):
    """Response model for one search hit with its attribution data."""
    chunk_id: str
    content: str
    score: float  # Score reported by the vector search; 0.0 when the raw hit has none
    # Attribution
    source_code: str
    source_name: str
    attribution_text: str
    license_code: str
    license_name: str  # Registry name; the search endpoint falls back to the license code
    license_url: Optional[str] = None
    attribution_required: bool  # The search endpoint defaults this to True
    source_url: Optional[str] = None
    # Metadata
    document_type: Optional[str] = None
    category: Optional[str] = None
    section_title: Optional[str] = None
    page_number: Optional[int] = None
class DSFASearchResponse(BaseModel):
    """Response model for the search endpoint."""
    query: str  # The query string as received
    results: List[DSFASearchResultResponse]
    total_results: int  # Number of entries in ``results``
    # Aggregated licenses for footer
    licenses_used: List[str]  # Distinct license codes occurring in ``results``
    attribution_notice: str  # Generated notice; "" when attribution was not requested
class DSFASourceStatsResponse(BaseModel):
    """Response model for the statistics of a single source."""
    source_id: str
    source_code: str
    name: str
    organization: Optional[str] = None
    license_code: str
    document_type: Optional[str] = None
    document_count: int
    chunk_count: int
    last_indexed_at: Optional[str] = None  # presumably a timestamp rendered as text -- TODO confirm format
class DSFACorpusStatsResponse(BaseModel):
    """Response model for corpus-wide statistics.

    Combines the per-source statistics with overall totals and the state of
    the Qdrant collection.
    """
    sources: List[DSFASourceStatsResponse]  # One entry per source
    total_sources: int
    total_documents: int
    total_chunks: int
    qdrant_collection: str
    qdrant_points_count: int
    qdrant_status: str
class IngestRequest(BaseModel):
    """Request model for ingestion.

    The ingest endpoint requires ``document_text`` or ``document_url``; when
    both are given, the text is used and the URL is not fetched.
    """
    document_url: Optional[str] = None  # URL to extract the text from when no text is supplied
    document_text: Optional[str] = None  # Document text supplied directly
    title: Optional[str] = None  # Document title; the endpoint derives a default when omitted
class IngestResponse(BaseModel):
    """Response model for ingestion."""
    source_code: str  # Source the document was ingested for
    document_id: Optional[str] = None  # ID of the created document record
    chunks_created: int  # Number of indexed chunks; 0 when chunking produced nothing
    message: str  # Human-readable outcome
class LicenseInfo(BaseModel):
    """License information.

    NOTE(review): no endpoint in the visible part of this file returns this
    model -- confirm where it is used.
    """
    code: str
    name: str
    url: Optional[str] = None
    attribution_required: bool
    modification_allowed: bool
    commercial_use: bool
# =============================================================================
# Dependency Injection
# =============================================================================
# Database pool (will be set from main.py)
_db_pool = None
def set_db_pool(pool) -> None:
    """Set the database pool for API endpoints.

    Stores ``pool`` in the module-level ``_db_pool`` variable, which
    ``get_store`` reads when it builds a ``DSFACorpusStore``.

    Args:
        pool: The database pool object to hand to ``DSFACorpusStore``; its
            concrete type is not visible in this module.
    """
    global _db_pool
    _db_pool = pool
async def get_store() -> DSFACorpusStore:
    """Provide a DSFA corpus store backed by the configured database pool.

    Returns:
        A new ``DSFACorpusStore`` wrapping the module-level pool.

    Raises:
        HTTPException: 503 when no database pool has been set yet.
    """
    if _db_pool is not None:
        return DSFACorpusStore(_db_pool)
    raise HTTPException(status_code=503, detail="Database not initialized")
async def get_qdrant() -> DSFAQdrantService:
    """Get Qdrant service.

    Returns:
        A new ``DSFAQdrantService`` instance, created on every call.
    """
    return DSFAQdrantService()
# =============================================================================
# Embedding Service Integration
# =============================================================================
async def get_embedding(text: str) -> List[float]:
    """Fetch the embedding vector for ``text`` from the embedding service.

    Posts the text to the service's ``/embed-single`` endpoint. According to
    the original notes the service uses the BGE-M3 model, which produces
    1024-dimensional vectors.

    Args:
        text: The text to embed.

    Returns:
        The ``embedding`` list from the service reply (empty if the key is
        missing), or a hash-based pseudo-embedding when the HTTP call fails.
    """
    endpoint = f"{EMBEDDING_SERVICE_URL}/embed-single"
    async with httpx.AsyncClient(timeout=60.0) as client:
        try:
            reply = await client.post(endpoint, json={"text": text})
            reply.raise_for_status()
            return reply.json().get("embedding", [])
        except httpx.HTTPError as e:
            logger.error(f"Embedding service error: {e}")
            # Development fallback: deterministic pseudo-embedding.
            return _generate_fallback_embedding(text)
async def get_embeddings_batch(texts: List[str]) -> List[List[float]]:
    """Fetch embeddings for several texts with a single service call.

    Args:
        texts: The texts to embed.

    Returns:
        The ``embeddings`` list from the service reply (empty if the key is
        missing), or one hash-based pseudo-embedding per text when the HTTP
        call fails.
    """
    endpoint = f"{EMBEDDING_SERVICE_URL}/embed"
    async with httpx.AsyncClient(timeout=120.0) as client:
        try:
            reply = await client.post(endpoint, json={"texts": texts})
            reply.raise_for_status()
            return reply.json().get("embeddings", [])
        except httpx.HTTPError as e:
            logger.error(f"Embedding service batch error: {e}")
            # Fallback: build the pseudo-embeddings one text at a time.
            fallback = []
            for item in texts:
                fallback.append(_generate_fallback_embedding(item))
            return fallback
async def extract_text_from_url(url: str) -> str:
    """Extract plain text from the document at ``url``.

    The embedding service's ``/extract-pdf`` endpoint is tried first. If that
    HTTP call fails, the URL is fetched directly and, for HTML responses, the
    markup is reduced to whitespace-normalised text.

    NOTE(review): the ingest endpoint passes a caller-supplied URL here and
    the fallback fetches it from the server side -- consider restricting the
    allowed targets.

    Args:
        url: Location of the document.

    Returns:
        The extracted text, or ``""`` when the fallback response is not HTML
        or the fallback fetch fails.
    """
    import re

    async with httpx.AsyncClient(timeout=120.0) as client:
        try:
            # Preferred path: let the embedding service do the extraction.
            extraction = await client.post(
                f"{EMBEDDING_SERVICE_URL}/extract-pdf",
                json={"url": url}
            )
            extraction.raise_for_status()
            return extraction.json().get("text", "")
        except httpx.HTTPError as e:
            logger.error(f"PDF extraction error for {url}: {e}")
            # Fallback path: fetch the page ourselves.
            try:
                page = await client.get(url, follow_redirects=True)
                page.raise_for_status()
                if "html" not in page.headers.get("content-type", ""):
                    return ""
                markup = page.text
                # Drop script and style elements including their content.
                for pattern in (r'<script[^>]*>.*?</script>', r'<style[^>]*>.*?</style>'):
                    markup = re.sub(pattern, '', markup, flags=re.DOTALL | re.IGNORECASE)
                # Replace the remaining tags by spaces, then collapse whitespace.
                plain = re.sub(r'<[^>]+>', ' ', markup)
                return re.sub(r'\s+', ' ', plain).strip()
            except Exception as fetch_err:
                logger.error(f"Fallback fetch error for {url}: {fetch_err}")
                return ""
def _generate_fallback_embedding(text: str) -> List[float]:
    """Generate a deterministic pseudo-embedding from the hash of ``text``.

    Used as a fallback when the embedding service is unavailable. The SHA-256
    digest is read as eight 32-bit floats, each reduced modulo 1.0, and the
    resulting eight values are repeated until the vector has 1024 entries.

    Bit patterns that decode to NaN or infinity are mapped to 0.0; without
    that guard ``value % 1.0`` would put NaN into the vector.

    Args:
        text: The text to derive the vector from.

    Returns:
        A list of 1024 finite floats that is identical for identical input.
    """
    import hashlib
    import math
    import struct

    dimensions = 1024
    digest = hashlib.sha256(text.encode()).digest()

    seed = []
    for (raw,) in struct.iter_unpack('f', digest):
        # Random bytes can decode to NaN/inf; replace those by a fixed value.
        seed.append(raw % 1.0 if math.isfinite(raw) else 0.0)

    # Tile the seed values up to the target dimensionality.
    repeats = -(-dimensions // len(seed))
    return (seed * repeats)[:dimensions]
# =============================================================================
# API Endpoints
# =============================================================================
@router.get("/search", response_model=DSFASearchResponse)
async def search_dsfa_corpus(
    query: str = Query(..., min_length=3, description="Search query"),
    source_codes: Optional[List[str]] = Query(None, description="Filter by source codes"),
    document_types: Optional[List[str]] = Query(None, description="Filter by document types (guideline, checklist, regulation)"),
    categories: Optional[List[str]] = Query(None, description="Filter by categories (threshold_analysis, risk_assessment, mitigation)"),
    limit: int = Query(10, ge=1, le=50, description="Maximum results"),
    include_attribution: bool = Query(True, description="Include attribution in results"),
    store: DSFACorpusStore = Depends(get_store),
    qdrant: DSFAQdrantService = Depends(get_qdrant)
):
    """Search the DSFA corpus and return hits with full attribution.

    The query is embedded, matched against Qdrant with the optional filters,
    and every hit is enriched with license name and URL from
    ``LICENSE_REGISTRY``.

    Args:
        query: Search text (at least 3 characters).
        source_codes: Optional filter on source codes.
        document_types: Optional filter on document types.
        categories: Optional filter on categories.
        limit: Maximum number of hits (1-50).
        include_attribution: When False, ``attribution_notice`` is returned
            as an empty string.
        store: Corpus store dependency; not used in the body.
        qdrant: Qdrant service used for the vector search.

    Returns:
        A ``DSFASearchResponse`` with the hits, the distinct license codes in
        sorted order, and the attribution notice.
    """
    query_embedding = await get_embedding(query)

    raw_results = await qdrant.search(
        query_embedding=query_embedding,
        source_codes=source_codes,
        document_types=document_types,
        categories=categories,
        limit=limit
    )

    # Transform the raw hits into response models.
    results = []
    licenses_used = set()
    for r in raw_results:
        license_code = r.get("license_code", "")
        license_info = LICENSE_REGISTRY.get(license_code, {})
        results.append(DSFASearchResultResponse(
            chunk_id=r.get("chunk_id", ""),
            content=r.get("content", ""),
            score=r.get("score", 0.0),
            source_code=r.get("source_code", ""),
            source_name=r.get("source_name", ""),
            attribution_text=r.get("attribution_text", ""),
            license_code=license_code,
            license_name=license_info.get("name", license_code),
            license_url=license_info.get("url"),
            attribution_required=r.get("attribution_required", True),
            source_url=r.get("source_url"),
            document_type=r.get("document_type"),
            category=r.get("category"),
            section_title=r.get("section_title"),
            page_number=r.get("page_number")
        ))
        licenses_used.add(license_code)

    # The notice generator works on DSFASearchResult objects; build them only
    # when a notice is actually requested.
    attribution_notice = ""
    if include_attribution:
        search_results = [
            DSFASearchResult(
                chunk_id=r.chunk_id,
                content=r.content,
                score=r.score,
                source_code=r.source_code,
                source_name=r.source_name,
                attribution_text=r.attribution_text,
                license_code=r.license_code,
                license_url=r.license_url,
                attribution_required=r.attribution_required,
                source_url=r.source_url,
                document_type=r.document_type or "",
                category=r.category or "",
                section_title=r.section_title,
                page_number=r.page_number
            )
            for r in results
        ]
        attribution_notice = generate_attribution_notice(search_results)

    return DSFASearchResponse(
        query=query,
        results=results,
        total_results=len(results),
        # Sorted so identical searches yield an identical response.
        licenses_used=sorted(licenses_used),
        attribution_notice=attribution_notice
    )
@router.get("/sources", response_model=List[DSFASourceResponse])
async def list_dsfa_sources(
    document_type: Optional[str] = Query(None, description="Filter by document type"),
    license_code: Optional[str] = Query(None, description="Filter by license"),
    store: DSFACorpusStore = Depends(get_store)
):
    """Return the registered DSFA sources together with their license info.

    Args:
        document_type: When given, only sources of this document type.
        license_code: When given, only sources under this license.
        store: Corpus store that supplies the source records.

    Returns:
        One ``DSFASourceResponse`` per source that passes both filters.
    """
    rows = await store.list_sources()
    responses = []
    for row in rows:
        # A filter that is not set accepts every source.
        type_ok = not document_type or row.get("document_type") == document_type
        license_ok = not license_code or row.get("license_code") == license_code
        if not (type_ok and license_ok):
            continue

        row_license = row.get("license_code", "")
        registry_entry = LICENSE_REGISTRY.get(row_license, {})
        responses.append(DSFASourceResponse(
            id=str(row["id"]),
            source_code=row["source_code"],
            name=row["name"],
            full_name=row.get("full_name"),
            organization=row.get("organization"),
            source_url=row.get("source_url"),
            license_code=row_license,
            license_name=registry_entry.get("name", row_license),
            license_url=registry_entry.get("url"),
            attribution_required=row.get("attribution_required", True),
            attribution_text=row.get("attribution_text", ""),
            document_type=row.get("document_type"),
            language=row.get("language", "de")
        ))
    return responses
@router.get("/sources/available")
async def list_available_sources():
    """Return a summary of every source definition in ``DSFA_SOURCES``.

    Returns:
        A list of dicts with the keys ``source_code``, ``name``,
        ``organization``, ``license_code`` and ``document_type``.
    """
    summaries = []
    for definition in DSFA_SOURCES:
        summary = {
            "source_code": definition["source_code"],
            "name": definition["name"],
            "organization": definition.get("organization"),
            "license_code": definition["license_code"],
            "document_type": definition.get("document_type"),
        }
        summaries.append(summary)
    return summaries
@router.get("/sources/{source_code}", response_model=DSFASourceResponse)
async def get_dsfa_source(
    source_code: str,
    store: DSFACorpusStore = Depends(get_store)
):
    """Return the details of one source, looked up by its code.

    Args:
        source_code: Code of the requested source.
        store: Corpus store that supplies the source record.

    Returns:
        The source as ``DSFASourceResponse``, enriched with license name and
        URL from ``LICENSE_REGISTRY``.

    Raises:
        HTTPException: 404 when no source with this code exists.
    """
    record = await store.get_source_by_code(source_code)
    if not record:
        raise HTTPException(status_code=404, detail=f"Source not found: {source_code}")

    record_license = record.get("license_code", "")
    registry_entry = LICENSE_REGISTRY.get(record_license, {})
    payload = dict(
        id=str(record["id"]),
        source_code=record["source_code"],
        name=record["name"],
        full_name=record.get("full_name"),
        organization=record.get("organization"),
        source_url=record.get("source_url"),
        license_code=record_license,
        license_name=registry_entry.get("name", record_license),
        license_url=registry_entry.get("url"),
        attribution_required=record.get("attribution_required", True),
        attribution_text=record.get("attribution_text", ""),
        document_type=record.get("document_type"),
        language=record.get("language", "de"),
    )
    return DSFASourceResponse(**payload)
@router.post("/sources/{source_code}/ingest", response_model=IngestResponse)
async def ingest_dsfa_source(
    source_code: str,
    request: IngestRequest,
    store: DSFACorpusStore = Depends(get_store),
    qdrant: DSFAQdrantService = Depends(get_qdrant)
):
    """Ingest one document for a source, given as URL or as direct text.

    The text is chunked, embedded in one batch, stored chunk by chunk in the
    corpus store and finally indexed in Qdrant.

    Args:
        source_code: Code of the source the document belongs to.
        request: Carries ``document_text`` or ``document_url`` and an
            optional ``title``.
        store: Corpus store used for document and chunk records.
        qdrant: Qdrant service used for indexing.

    Returns:
        An ``IngestResponse`` with the document ID and the number of indexed
        chunks (0 when chunking produced nothing).

    Raises:
        HTTPException: 404 for an unknown source; 400 when neither text nor
            URL is given, the URL yields no text, or the text is shorter than
            50 characters; 502 when the number of embeddings does not match
            the number of chunks.
    """
    source = await store.get_source_by_code(source_code)
    if not source:
        raise HTTPException(status_code=404, detail=f"Source not found: {source_code}")

    # Need either URL or text.
    if not request.document_text and not request.document_url:
        raise HTTPException(
            status_code=400,
            detail="Either document_text or document_url must be provided"
        )

    await qdrant.ensure_collection()

    # Directly supplied text wins; the URL is only fetched without it.
    text_content = request.document_text
    if request.document_url and not text_content:
        logger.info(f"Extracting text from URL: {request.document_url}")
        text_content = await extract_text_from_url(request.document_url)
        if not text_content:
            raise HTTPException(
                status_code=400,
                detail=f"Could not extract text from URL: {request.document_url}"
            )

    if not text_content or len(text_content.strip()) < 50:
        raise HTTPException(status_code=400, detail="Document text too short (min 50 chars)")

    source_id = str(source["id"])
    doc_title = request.title or f"Document for {source_code}"
    document_id = await store.create_document(
        source_id=source_id,
        title=doc_title,
        file_type="text",
        metadata={"ingested_via": "api", "source_code": source_code}
    )

    chunks = chunk_document(text_content, source_code)
    if not chunks:
        return IngestResponse(
            source_code=source_code,
            document_id=document_id,
            chunks_created=0,
            message="Document created but no chunks generated"
        )

    # Generate embeddings in one batch for efficiency.
    chunk_texts = [chunk["content"] for chunk in chunks]
    logger.info(f"Generating embeddings for {len(chunk_texts)} chunks...")
    embeddings = await get_embeddings_batch(chunk_texts)

    # zip() below would silently drop chunks without an embedding and still
    # report success, so a count mismatch is treated as an upstream failure.
    if len(embeddings) != len(chunks):
        logger.error(
            f"Embedding count mismatch for {source_code}: "
            f"{len(embeddings)} embeddings for {len(chunks)} chunks"
        )
        raise HTTPException(
            status_code=502,
            detail="Embedding service returned an unexpected number of embeddings"
        )

    # Create the chunk records and collect the payloads for Qdrant.
    chunk_records = []
    for i, chunk in enumerate(chunks):
        chunk_id = await store.create_chunk(
            document_id=document_id,
            source_id=source_id,
            content=chunk["content"],
            chunk_index=i,
            section_title=chunk.get("section_title"),
            page_number=chunk.get("page_number"),
            category=chunk.get("category")
        )
        chunk_records.append({
            "chunk_id": chunk_id,
            "document_id": document_id,
            "source_id": source_id,
            "content": chunk["content"],
            "section_title": chunk.get("section_title"),
            "source_code": source_code,
            "source_name": source["name"],
            "attribution_text": source["attribution_text"],
            "license_code": source["license_code"],
            "attribution_required": source.get("attribution_required", True),
            "document_type": source.get("document_type", ""),
            "category": chunk.get("category", ""),
            "language": source.get("language", "de"),
            "page_number": chunk.get("page_number")
        })

    indexed_count = await qdrant.index_chunks(chunk_records, embeddings)

    await store.update_document_indexed(document_id, len(chunks))

    return IngestResponse(
        source_code=source_code,
        document_id=document_id,
        chunks_created=indexed_count,
        message=f"Successfully ingested {indexed_count} chunks from document"
    )
@router.get("/chunks/{chunk_id}", response_model=DSFAChunkResponse)
async def get_chunk_with_attribution(
    chunk_id: str,
    store: DSFACorpusStore = Depends(get_store)
):
    """
    Return one stored chunk together with its source and license attribution.

    Responds with HTTP 404 when the store has no chunk for ``chunk_id``.
    """
    row = await store.get_chunk_with_attribution(chunk_id)
    if not row:
        raise HTTPException(status_code=404, detail=f"Chunk not found: {chunk_id}")

    code = row.get("license_code", "")
    registry_entry = LICENSE_REGISTRY.get(code, {})

    # Identifier columns are stringified; the license name falls back to the
    # raw license code when the registry has no entry for it.
    fields = {
        "chunk_id": str(row["chunk_id"]),
        "content": row.get("content", ""),
        "section_title": row.get("section_title"),
        "page_number": row.get("page_number"),
        "category": row.get("category"),
        "document_id": str(row.get("document_id", "")),
        "document_title": row.get("document_title"),
        "source_id": str(row.get("source_id", "")),
        "source_code": row.get("source_code", ""),
        "source_name": row.get("source_name", ""),
        "attribution_text": row.get("attribution_text", ""),
        "license_code": code,
        "license_name": registry_entry.get("name", code),
        "license_url": registry_entry.get("url"),
        "attribution_required": row.get("attribution_required", True),
        "source_url": row.get("source_url"),
        "document_type": row.get("document_type"),
    }
    return DSFAChunkResponse(**fields)
@router.get("/stats", response_model=DSFACorpusStatsResponse)
async def get_corpus_stats(
    store: DSFACorpusStore = Depends(get_store),
    qdrant: DSFAQdrantService = Depends(get_qdrant)
):
    """
    Assemble dashboard statistics for the DSFA corpus.

    Per-source document/chunk counts come from the corpus store; the
    collection status and point count come from the Qdrant service.
    """
    rows = await store.get_source_stats()

    def _count(row, key):
        # A missing or None aggregate is reported as 0.
        return row.get(key, 0) or 0

    per_source = []
    for row in rows:
        indexed_at = row.get("last_indexed_at")
        per_source.append(
            DSFASourceStatsResponse(
                source_id=str(row.get("source_id", "")),
                source_code=row.get("source_code", ""),
                name=row.get("name", ""),
                organization=row.get("organization"),
                license_code=row.get("license_code", ""),
                document_type=row.get("document_type"),
                document_count=_count(row, "document_count"),
                chunk_count=_count(row, "chunk_count"),
                last_indexed_at=indexed_at.isoformat() if indexed_at else None,
            )
        )

    documents_total = sum(_count(row, "document_count") for row in rows)
    chunks_total = sum(_count(row, "chunk_count") for row in rows)

    vector_stats = await qdrant.get_stats()
    return DSFACorpusStatsResponse(
        sources=per_source,
        total_sources=len(rows),
        total_documents=documents_total,
        total_chunks=chunks_total,
        qdrant_collection=DSFA_COLLECTION,
        qdrant_points_count=vector_stats.get("points_count", 0),
        qdrant_status=vector_stats.get("status", "unknown"),
    )
@router.get("/licenses")
async def list_licenses():
    """Return every entry of ``LICENSE_REGISTRY`` as a ``LicenseInfo`` record."""
    catalogue = []
    for license_code, terms in LICENSE_REGISTRY.items():
        # A missing name falls back to the code; missing flags default to True.
        catalogue.append(
            LicenseInfo(
                code=license_code,
                name=terms.get("name", license_code),
                url=terms.get("url"),
                attribution_required=terms.get("attribution_required", True),
                modification_allowed=terms.get("modification_allowed", True),
                commercial_use=terms.get("commercial_use", True),
            )
        )
    return catalogue
@router.post("/init")
async def initialize_dsfa_corpus(
    store: DSFACorpusStore = Depends(get_store),
    qdrant: DSFAQdrantService = Depends(get_qdrant)
):
    """
    Initialize DSFA corpus.

    - Ensures the Qdrant collection exists
    - Registers all predefined sources from ``DSFA_SOURCES``

    Registration is best-effort: a source that fails to register is logged
    (with traceback) and skipped, and the remaining sources are still
    processed.

    Returns:
        A dict with ``qdrant_collection_created`` (result of
        ``ensure_collection``), ``sources_registered`` (number of successful
        registrations) and ``total_sources`` (size of ``DSFA_SOURCES``).
    """
    qdrant_ok = await qdrant.ensure_collection()

    registered = 0
    for source in DSFA_SOURCES:
        try:
            await store.register_source(source)
            registered += 1
        except Exception:
            # Use the module logger instead of print so the failure reaches
            # the configured log handlers; .get() keeps the handler itself
            # from raising on a malformed source definition.
            logger.exception(
                "Error registering source %s",
                source.get("source_code", "<unknown>"),
            )

    return {
        "qdrant_collection_created": qdrant_ok,
        "sources_registered": registered,
        "total_sources": len(DSFA_SOURCES)
    }
# Embedding utilities
from dsfa_rag_embedding import (
get_embedding,
get_embeddings_batch,
extract_text_from_url,
EMBEDDING_SERVICE_URL,
)
# Routes (router + set_db_pool)
from dsfa_rag_routes import (
router,
set_db_pool,
get_store,
get_qdrant,
)
# Public names re-exported by this module, grouped by where they come from.
__all__ = [
    # Router and dependency helpers (imported from dsfa_rag_routes)
    "router",
    "set_db_pool",
    "get_store",
    "get_qdrant",
    # Request/response models
    "DSFASourceResponse",
    "DSFAChunkResponse",
    "DSFASearchResultResponse",
    "DSFASearchResponse",
    "DSFASourceStatsResponse",
    "DSFACorpusStatsResponse",
    "IngestRequest",
    "IngestResponse",
    "LicenseInfo",
    # Embedding utilities (imported from dsfa_rag_embedding)
    "get_embedding",
    "get_embeddings_batch",
    "extract_text_from_url",
    "EMBEDDING_SERVICE_URL",
]
@@ -0,0 +1,116 @@
"""
DSFA RAG Embedding Service Integration.
Handles embedding generation, text extraction, and fallback logic.
"""
import os
import hashlib
import logging
import struct
import re
from typing import List
import httpx
logger = logging.getLogger(__name__)
# Embedding service configuration.
# Base URL used for the /embed-single, /embed and /extract-pdf requests in
# this module. Override it with the EMBEDDING_SERVICE_URL environment variable.
# NOTE(review): the default is a hard-coded IP address — confirm it is still
# valid for the target deployment.
EMBEDDING_SERVICE_URL = os.getenv("EMBEDDING_SERVICE_URL", "http://172.18.0.13:8087")
async def get_embedding(text: str) -> List[float]:
    """
    Get embedding for text using the embedding-service.

    Uses BGE-M3 model which produces 1024-dimensional vectors.

    Args:
        text: The text to embed.

    Returns:
        The embedding reported by the service. When the request fails, the
        response is not valid JSON, or the response carries no embedding, a
        deterministic hash-based pseudo-embedding is returned instead
        (development fallback — it has no semantic meaning).
    """
    async with httpx.AsyncClient(timeout=60.0) as client:
        try:
            response = await client.post(
                f"{EMBEDDING_SERVICE_URL}/embed-single",
                json={"text": text}
            )
            response.raise_for_status()
            data = response.json()
        except (httpx.HTTPError, ValueError) as e:
            # ValueError covers a body that cannot be decoded as JSON.
            logger.error(f"Embedding service error: {e}")
            # Fallback to hash-based pseudo-embedding for development
            return _generate_fallback_embedding(text)

    embedding = data.get("embedding") if isinstance(data, dict) else None
    if not embedding:
        # An empty vector would only fail later in the vector store; treat a
        # malformed response like any other service error.
        logger.error("Embedding service returned no embedding, using fallback")
        return _generate_fallback_embedding(text)
    return embedding
async def get_embeddings_batch(texts: List[str]) -> List[List[float]]:
    """
    Get embeddings for multiple texts in batch.

    Args:
        texts: The texts to embed, in order.

    Returns:
        One embedding per input text, in the same order. When the request
        fails, the response is not valid JSON, or the number of returned
        embeddings does not match ``len(texts)``, hash-based pseudo-embeddings
        are returned for all texts (development fallback). An empty input
        returns an empty list without contacting the service.
    """
    if not texts:
        return []

    async with httpx.AsyncClient(timeout=120.0) as client:
        try:
            response = await client.post(
                f"{EMBEDDING_SERVICE_URL}/embed",
                json={"texts": texts}
            )
            response.raise_for_status()
            data = response.json()
        except (httpx.HTTPError, ValueError) as e:
            # ValueError covers a body that cannot be decoded as JSON.
            logger.error(f"Embedding service batch error: {e}")
            # Fallback
            return [_generate_fallback_embedding(t) for t in texts]

    embeddings = data.get("embeddings") if isinstance(data, dict) else None
    if not isinstance(embeddings, list) or len(embeddings) != len(texts):
        # Callers pair texts and embeddings positionally; a short or missing
        # list would silently drop chunks, so fall back for the whole batch.
        logger.error(
            "Embedding service batch returned %s embeddings for %d texts, using fallback",
            len(embeddings) if isinstance(embeddings, list) else "no",
            len(texts),
        )
        return [_generate_fallback_embedding(t) for t in texts]
    return embeddings
async def extract_text_from_url(url: str) -> str:
    """
    Extract text from a document URL (PDF, HTML, etc.).

    The embedding-service's ``/extract-pdf`` endpoint is tried first. If that
    request fails (HTTP error or undecodable JSON), the URL is fetched
    directly and, when the response is HTML, reduced to plain text by
    stripping scripts, styles and tags and decoding HTML entities.

    Args:
        url: Location of the document to extract.

    Returns:
        The extracted text, or an empty string when nothing could be
        extracted (non-HTML fallback response or any fetch error).
    """
    # Local import: only needed for the HTML fallback path.
    import html as html_lib

    async with httpx.AsyncClient(timeout=120.0) as client:
        try:
            # First try to use the embedding-service's extract-pdf endpoint
            response = await client.post(
                f"{EMBEDDING_SERVICE_URL}/extract-pdf",
                json={"url": url}
            )
            response.raise_for_status()
            data = response.json()
            # Guard against an explicit null so the declared str return holds.
            return data.get("text") or ""
        except (httpx.HTTPError, ValueError) as e:
            # ValueError covers a body that cannot be decoded as JSON.
            logger.error(f"PDF extraction error for {url}: {e}")

            # Fallback: try to fetch HTML content directly
            try:
                response = await client.get(url, follow_redirects=True)
                response.raise_for_status()

                content_type = response.headers.get("content-type", "").lower()
                if "html" not in content_type:
                    return ""

                # Simple HTML text extraction
                markup = response.text
                # Remove scripts and styles
                markup = re.sub(r'<script[^>]*>.*?</script>', '', markup, flags=re.DOTALL | re.IGNORECASE)
                markup = re.sub(r'<style[^>]*>.*?</style>', '', markup, flags=re.DOTALL | re.IGNORECASE)
                # Remove tags
                text = re.sub(r'<[^>]+>', ' ', markup)
                # Decode entities only after tag removal, so escaped markup
                # such as "&lt;b&gt;" stays literal text.
                text = html_lib.unescape(text)
                # Clean whitespace
                return re.sub(r'\s+', ' ', text).strip()
            except Exception as fetch_err:
                # Deliberately broad: this is a best-effort fallback and some
                # client errors (e.g. invalid URLs) are not httpx.HTTPError.
                logger.error(f"Fallback fetch error for {url}: {fetch_err}")
                return ""
def _generate_fallback_embedding(text: str) -> List[float]:
    """
    Generate deterministic pseudo-embedding from text hash.

    Used as fallback when embedding service is unavailable. The SHA-256
    digest of ``text`` is read as eight little-endian float32 values, each
    reduced modulo 1.0, and that 8-value pattern is repeated to fill 1024
    dimensions.

    Args:
        text: The text to derive the pseudo-embedding from.

    Returns:
        A list of 1024 finite floats. The values carry no semantic meaning.
    """
    # Local import keeps the block self-contained.
    import math

    digest = hashlib.sha256(text.encode()).digest()

    seed = []
    for offset in range(0, len(digest), 4):
        # Explicit little-endian so the result does not depend on the host
        # byte order.
        val = struct.unpack('<f', digest[offset:offset + 4])[0]
        # Arbitrary bit patterns can decode to NaN or +/-inf, and "% 1.0" of
        # those is NaN, which is not a usable vector component. Map them to 0.0.
        seed.append(val % 1.0 if math.isfinite(val) else 0.0)

    # Pad to 1024 dimensions by repeating the seed pattern.
    repeats = -(-1024 // len(seed))
    return (seed * repeats)[:1024]
+137
View File
@@ -0,0 +1,137 @@
"""
DSFA RAG Pydantic Models.
Request/Response models for the DSFA RAG API.
"""
from typing import List, Optional
from pydantic import BaseModel, Field
# =============================================================================
# Response Models
# =============================================================================
class DSFASourceResponse(BaseModel):
    """Response model describing one registered DSFA source and its license."""
    id: str  # source identifier, serialized as a string
    source_code: str  # short code used to address the source in URLs
    name: str
    full_name: Optional[str] = None
    organization: Optional[str] = None
    source_url: Optional[str] = None
    license_code: str
    license_name: str  # human-readable license name
    license_url: Optional[str] = None
    attribution_required: bool
    attribution_text: str
    document_type: Optional[str] = None
    language: str = "de"  # defaults to German
class DSFAChunkResponse(BaseModel):
    """Response model for a single chunk with attribution."""
    chunk_id: str
    content: str  # chunk text
    section_title: Optional[str] = None
    page_number: Optional[int] = None
    category: Optional[str] = None

    # Document the chunk belongs to
    document_id: str
    document_title: Optional[str] = None

    # Attribution (always included): source and license of the chunk
    source_id: str
    source_code: str
    source_name: str
    attribution_text: str
    license_code: str
    license_name: str
    license_url: Optional[str] = None
    attribution_required: bool
    source_url: Optional[str] = None
    document_type: Optional[str] = None
class DSFASearchResultResponse(BaseModel):
    """Response model for one search hit, including its attribution data."""
    chunk_id: str
    content: str  # text of the matching chunk
    score: float  # relevance score reported by the vector search

    # Attribution: source and license of the matching chunk
    source_code: str
    source_name: str
    attribution_text: str
    license_code: str
    license_name: str
    license_url: Optional[str] = None
    attribution_required: bool
    source_url: Optional[str] = None

    # Metadata: optional classification and location of the chunk
    document_type: Optional[str] = None
    category: Optional[str] = None
    section_title: Optional[str] = None
    page_number: Optional[int] = None
class DSFASearchResponse(BaseModel):
    """Response model for search endpoint."""
    query: str  # the query string as received
    results: List[DSFASearchResultResponse]
    total_results: int  # number of entries in ``results``

    # Aggregated licenses for footer
    licenses_used: List[str]  # license codes occurring in ``results``
    attribution_notice: str  # combined attribution text for the results
class DSFASourceStatsResponse(BaseModel):
    """Response model for the statistics of a single source."""
    source_id: str
    source_code: str
    name: str
    organization: Optional[str] = None
    license_code: str
    document_type: Optional[str] = None
    document_count: int  # documents recorded for this source
    chunk_count: int  # chunks recorded for this source
    last_indexed_at: Optional[str] = None  # ISO-8601 timestamp string, or None
class DSFACorpusStatsResponse(BaseModel):
    """Response model for corpus-wide statistics."""
    sources: List[DSFASourceStatsResponse]  # one entry per source
    total_sources: int
    total_documents: int  # sum of the per-source document counts
    total_chunks: int  # sum of the per-source chunk counts
    qdrant_collection: str  # name of the Qdrant collection
    qdrant_points_count: int
    qdrant_status: str
class IngestRequest(BaseModel):
    """
    Request model for ingestion.

    The model itself does not enforce that ``document_url`` or
    ``document_text`` is set; all three fields are optional here.
    """
    document_url: Optional[str] = None  # URL to extract the document text from
    document_text: Optional[str] = None  # document text supplied directly
    title: Optional[str] = None  # optional title for the created document
class IngestResponse(BaseModel):
    """Response model for ingestion."""
    source_code: str  # source the document was ingested for
    document_id: Optional[str] = None  # identifier of the created document
    chunks_created: int
    message: str  # human-readable outcome description
class LicenseInfo(BaseModel):
    """License information: identification plus the usage terms as flags."""
    code: str  # license code
    name: str  # human-readable license name
    url: Optional[str] = None
    attribution_required: bool
    modification_allowed: bool
    commercial_use: bool
+461
View File
@@ -0,0 +1,461 @@
"""
DSFA RAG API Route Handlers.
Endpoint implementations for search, sources, ingestion, stats, and init.
"""
import logging
from typing import List, Optional
from fastapi import APIRouter, HTTPException, Query, Depends
from dsfa_corpus_ingestion import (
DSFACorpusStore,
DSFAQdrantService,
DSFASearchResult,
LICENSE_REGISTRY,
DSFA_SOURCES,
generate_attribution_notice,
get_license_label,
DSFA_COLLECTION,
chunk_document,
)
from dsfa_rag_models import (
DSFASourceResponse,
DSFAChunkResponse,
DSFASearchResultResponse,
DSFASearchResponse,
DSFASourceStatsResponse,
DSFACorpusStatsResponse,
IngestRequest,
IngestResponse,
LicenseInfo,
)
from dsfa_rag_embedding import (
get_embedding,
get_embeddings_batch,
extract_text_from_url,
)
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/v1/dsfa-rag", tags=["DSFA RAG"])
# =============================================================================
# Dependency Injection
# =============================================================================
# Module-level database pool; stays None until set_db_pool() is called.
_db_pool = None


def set_db_pool(pool) -> None:
    """
    Set the database pool for API endpoints.

    Args:
        pool: Pool object stored in the module-level ``_db_pool`` and later
            handed to ``DSFACorpusStore`` by ``get_store``.
    """
    global _db_pool
    _db_pool = pool
async def get_store() -> DSFACorpusStore:
    """
    Provide a corpus store bound to the configured database pool.

    Raises:
        HTTPException: 503 when no pool has been set yet.
    """
    if _db_pool is not None:
        return DSFACorpusStore(_db_pool)
    raise HTTPException(status_code=503, detail="Database not initialized")
async def get_qdrant() -> DSFAQdrantService:
    """
    Get Qdrant service.

    Used as a FastAPI dependency by the endpoints in this module; a new
    ``DSFAQdrantService`` is constructed on every call.
    """
    return DSFAQdrantService()
# =============================================================================
# API Endpoints
# =============================================================================
@router.get("/search", response_model=DSFASearchResponse)
async def search_dsfa_corpus(
    query: str = Query(..., min_length=3, description="Search query"),
    source_codes: Optional[List[str]] = Query(None, description="Filter by source codes"),
    document_types: Optional[List[str]] = Query(None, description="Filter by document types (guideline, checklist, regulation)"),
    categories: Optional[List[str]] = Query(None, description="Filter by categories (threshold_analysis, risk_assessment, mitigation)"),
    limit: int = Query(10, ge=1, le=50, description="Maximum results"),
    include_attribution: bool = Query(True, description="Include attribution in results"),
    store: DSFACorpusStore = Depends(get_store),
    qdrant: DSFAQdrantService = Depends(get_qdrant)
):
    """
    Search DSFA corpus with full attribution.

    Embeds the query, runs a filtered vector search and returns matching
    chunks with source/license information for compliance.

    Returns:
        A ``DSFASearchResponse`` with the hits, the sorted list of distinct
        non-empty license codes among them, and — when ``include_attribution``
        is true — a generated attribution notice (otherwise an empty string).

    Note:
        ``store`` is not used in the body; the dependency still makes the
        endpoint answer 503 while the database pool is not initialized.
    """
    query_embedding = await get_embedding(query)

    raw_results = await qdrant.search(
        query_embedding=query_embedding,
        source_codes=source_codes,
        document_types=document_types,
        categories=categories,
        limit=limit
    )

    results = []
    licenses_used = set()
    for r in raw_results:
        license_code = r.get("license_code", "")
        license_info = LICENSE_REGISTRY.get(license_code, {})

        results.append(DSFASearchResultResponse(
            chunk_id=r.get("chunk_id", ""),
            content=r.get("content", ""),
            score=r.get("score", 0.0),
            source_code=r.get("source_code", ""),
            source_name=r.get("source_name", ""),
            attribution_text=r.get("attribution_text", ""),
            license_code=license_code,
            # Fall back to the raw code when the registry has no entry.
            license_name=license_info.get("name", license_code),
            license_url=license_info.get("url"),
            attribution_required=r.get("attribution_required", True),
            source_url=r.get("source_url"),
            document_type=r.get("document_type"),
            category=r.get("category"),
            section_title=r.get("section_title"),
            page_number=r.get("page_number")
        ))
        # An empty code carries no license information for the footer.
        if license_code:
            licenses_used.add(license_code)

    # Generate attribution notice. The intermediate DSFASearchResult objects
    # are only needed for the notice, so skip them when it is not requested.
    attribution_notice = ""
    if include_attribution:
        search_results = [
            DSFASearchResult(
                chunk_id=r.chunk_id,
                content=r.content,
                score=r.score,
                source_code=r.source_code,
                source_name=r.source_name,
                attribution_text=r.attribution_text,
                license_code=r.license_code,
                license_url=r.license_url,
                attribution_required=r.attribution_required,
                source_url=r.source_url,
                document_type=r.document_type or "",
                category=r.category or "",
                section_title=r.section_title,
                page_number=r.page_number
            )
            for r in results
        ]
        attribution_notice = generate_attribution_notice(search_results)

    return DSFASearchResponse(
        query=query,
        results=results,
        total_results=len(results),
        # Sorted so the response does not depend on set iteration order.
        licenses_used=sorted(licenses_used),
        attribution_notice=attribution_notice
    )
@router.get("/sources", response_model=List[DSFASourceResponse])
async def list_dsfa_sources(
    document_type: Optional[str] = Query(None, description="Filter by document type"),
    license_code: Optional[str] = Query(None, description="Filter by license"),
    store: DSFACorpusStore = Depends(get_store)
):
    """
    Return the registered DSFA sources, each enriched with license details.

    The optional ``document_type`` and ``license_code`` filters are applied
    in memory to the full list returned by the store.
    """
    def _selected(row) -> bool:
        # A filter that is not set accepts every row.
        type_ok = not document_type or row.get("document_type") == document_type
        license_ok = not license_code or row.get("license_code") == license_code
        return type_ok and license_ok

    def _to_response(row):
        code = row.get("license_code", "")
        registry_entry = LICENSE_REGISTRY.get(code, {})
        return DSFASourceResponse(
            id=str(row["id"]),
            source_code=row["source_code"],
            name=row["name"],
            full_name=row.get("full_name"),
            organization=row.get("organization"),
            source_url=row.get("source_url"),
            license_code=code,
            license_name=registry_entry.get("name", code),
            license_url=registry_entry.get("url"),
            attribution_required=row.get("attribution_required", True),
            attribution_text=row.get("attribution_text", ""),
            document_type=row.get("document_type"),
            language=row.get("language", "de"),
        )

    all_sources = await store.list_sources()
    return [_to_response(row) for row in all_sources if _selected(row)]
@router.get("/sources/available")
async def list_available_sources():
    """Return a summary dict for every source definition in ``DSFA_SOURCES``."""
    summaries = []
    for definition in DSFA_SOURCES:
        # source_code, name and license_code are mandatory keys; the rest
        # default to None when absent.
        summaries.append({
            "source_code": definition["source_code"],
            "name": definition["name"],
            "organization": definition.get("organization"),
            "license_code": definition["license_code"],
            "document_type": definition.get("document_type"),
        })
    return summaries
@router.get("/sources/{source_code}", response_model=DSFASourceResponse)
async def get_dsfa_source(
    source_code: str,
    store: DSFACorpusStore = Depends(get_store)
):
    """
    Return the registered source identified by ``source_code``.

    Responds with HTTP 404 when the store knows no such source.
    """
    record = await store.get_source_by_code(source_code)
    if not record:
        raise HTTPException(status_code=404, detail=f"Source not found: {source_code}")

    code = record.get("license_code", "")
    registry_entry = LICENSE_REGISTRY.get(code, {})

    # The license name falls back to the raw code when the registry has no
    # entry for it.
    fields = {
        "id": str(record["id"]),
        "source_code": record["source_code"],
        "name": record["name"],
        "full_name": record.get("full_name"),
        "organization": record.get("organization"),
        "source_url": record.get("source_url"),
        "license_code": code,
        "license_name": registry_entry.get("name", code),
        "license_url": registry_entry.get("url"),
        "attribution_required": record.get("attribution_required", True),
        "attribution_text": record.get("attribution_text", ""),
        "document_type": record.get("document_type"),
        "language": record.get("language", "de"),
    }
    return DSFASourceResponse(**fields)
@router.post("/sources/{source_code}/ingest", response_model=IngestResponse)
async def ingest_dsfa_source(
    source_code: str,
    request: IngestRequest,
    store: DSFACorpusStore = Depends(get_store),
    qdrant: DSFAQdrantService = Depends(get_qdrant)
):
    """
    Trigger ingestion for a specific source.

    Can provide document via URL or direct text; direct text takes
    precedence when both are given. The text is chunked, embedded in one
    batch, stored chunk by chunk and finally indexed in Qdrant.

    Raises:
        HTTPException: 404 for an unknown source; 400 when neither text nor
            URL is given, when no text could be extracted from the URL, or
            when the text is shorter than 50 characters; 502 when the number
            of embeddings does not match the number of chunks.
    """
    source = await store.get_source_by_code(source_code)
    if not source:
        raise HTTPException(status_code=404, detail=f"Source not found: {source_code}")

    if not request.document_text and not request.document_url:
        raise HTTPException(
            status_code=400,
            detail="Either document_text or document_url must be provided"
        )

    await qdrant.ensure_collection()

    text_content = request.document_text
    if request.document_url and not text_content:
        logger.info(f"Extracting text from URL: {request.document_url}")
        text_content = await extract_text_from_url(request.document_url)
        if not text_content:
            raise HTTPException(
                status_code=400,
                detail=f"Could not extract text from URL: {request.document_url}"
            )

    if not text_content or len(text_content.strip()) < 50:
        raise HTTPException(status_code=400, detail="Document text too short (min 50 chars)")

    doc_title = request.title or f"Document for {source_code}"
    document_id = await store.create_document(
        source_id=str(source["id"]),
        title=doc_title,
        file_type="text",
        metadata={"ingested_via": "api", "source_code": source_code}
    )

    chunks = chunk_document(text_content, source_code)
    if not chunks:
        return IngestResponse(
            source_code=source_code,
            document_id=document_id,
            chunks_created=0,
            message="Document created but no chunks generated"
        )

    chunk_texts = [chunk["content"] for chunk in chunks]
    logger.info(f"Generating embeddings for {len(chunk_texts)} chunks...")
    embeddings = await get_embeddings_batch(chunk_texts)

    # Chunks and embeddings are paired by position below and in the Qdrant
    # index call; refuse a mismatch instead of silently dropping or
    # misaligning chunks.
    if len(embeddings) != len(chunks):
        logger.error(
            "Embedding count mismatch for document %s: %d chunks, %d embeddings",
            document_id, len(chunks), len(embeddings),
        )
        raise HTTPException(
            status_code=502,
            detail="Embedding service returned an unexpected number of embeddings"
        )

    chunk_records = []
    for i, chunk in enumerate(chunks):
        chunk_id = await store.create_chunk(
            document_id=document_id,
            source_id=str(source["id"]),
            content=chunk["content"],
            chunk_index=i,
            section_title=chunk.get("section_title"),
            page_number=chunk.get("page_number"),
            category=chunk.get("category")
        )

        # Payload handed to the Qdrant indexer: chunk data plus the source's
        # attribution and license fields.
        chunk_records.append({
            "chunk_id": chunk_id,
            "document_id": document_id,
            "source_id": str(source["id"]),
            "content": chunk["content"],
            "section_title": chunk.get("section_title"),
            "source_code": source_code,
            "source_name": source["name"],
            "attribution_text": source["attribution_text"],
            "license_code": source["license_code"],
            "attribution_required": source.get("attribution_required", True),
            "document_type": source.get("document_type", ""),
            "category": chunk.get("category", ""),
            "language": source.get("language", "de"),
            "page_number": chunk.get("page_number")
        })

    indexed_count = await qdrant.index_chunks(chunk_records, embeddings)

    await store.update_document_indexed(document_id, len(chunks))

    return IngestResponse(
        source_code=source_code,
        document_id=document_id,
        chunks_created=indexed_count,
        message=f"Successfully ingested {indexed_count} chunks from document"
    )
@router.get("/chunks/{chunk_id}", response_model=DSFAChunkResponse)
async def get_chunk_with_attribution(
    chunk_id: str,
    store: DSFACorpusStore = Depends(get_store)
):
    """
    Look up a single chunk and return it with its source and license details.

    Responds with HTTP 404 when the store has no chunk for ``chunk_id``.
    """
    record = await store.get_chunk_with_attribution(chunk_id)
    if not record:
        raise HTTPException(status_code=404, detail=f"Chunk not found: {chunk_id}")

    license_code = record.get("license_code", "")
    license_entry = LICENSE_REGISTRY.get(license_code, {})

    # Identifier columns are stringified; an unregistered license keeps its
    # raw code as the display name.
    response_fields = dict(
        chunk_id=str(record["chunk_id"]),
        content=record.get("content", ""),
        section_title=record.get("section_title"),
        page_number=record.get("page_number"),
        category=record.get("category"),
        document_id=str(record.get("document_id", "")),
        document_title=record.get("document_title"),
        source_id=str(record.get("source_id", "")),
        source_code=record.get("source_code", ""),
        source_name=record.get("source_name", ""),
        attribution_text=record.get("attribution_text", ""),
        license_code=license_code,
        license_name=license_entry.get("name", license_code),
        license_url=license_entry.get("url"),
        attribution_required=record.get("attribution_required", True),
        source_url=record.get("source_url"),
        document_type=record.get("document_type"),
    )
    return DSFAChunkResponse(**response_fields)
@router.get("/stats", response_model=DSFACorpusStatsResponse)
async def get_corpus_stats(
    store: DSFACorpusStore = Depends(get_store),
    qdrant: DSFAQdrantService = Depends(get_qdrant)
):
    """
    Collect corpus statistics for the dashboard.

    Combines the per-source counts from the corpus store with the collection
    status and point count reported by the Qdrant service.
    """
    stat_rows = await store.get_source_stats()

    def _as_count(row, key):
        # A missing or None aggregate counts as 0.
        return row.get(key, 0) or 0

    entries = []
    for row in stat_rows:
        stamp = row.get("last_indexed_at")
        entries.append(
            DSFASourceStatsResponse(
                source_id=str(row.get("source_id", "")),
                source_code=row.get("source_code", ""),
                name=row.get("name", ""),
                organization=row.get("organization"),
                license_code=row.get("license_code", ""),
                document_type=row.get("document_type"),
                document_count=_as_count(row, "document_count"),
                chunk_count=_as_count(row, "chunk_count"),
                last_indexed_at=stamp.isoformat() if stamp else None,
            )
        )

    doc_total = sum(_as_count(row, "document_count") for row in stat_rows)
    chunk_total = sum(_as_count(row, "chunk_count") for row in stat_rows)

    collection_info = await qdrant.get_stats()
    return DSFACorpusStatsResponse(
        sources=entries,
        total_sources=len(stat_rows),
        total_documents=doc_total,
        total_chunks=chunk_total,
        qdrant_collection=DSFA_COLLECTION,
        qdrant_points_count=collection_info.get("points_count", 0),
        qdrant_status=collection_info.get("status", "unknown"),
    )
@router.get("/licenses")
async def list_licenses():
    """Expose the contents of ``LICENSE_REGISTRY`` as ``LicenseInfo`` records."""
    def _describe(code, info):
        # A missing name falls back to the code; missing flags default to True.
        return LicenseInfo(
            code=code,
            name=info.get("name", code),
            url=info.get("url"),
            attribution_required=info.get("attribution_required", True),
            modification_allowed=info.get("modification_allowed", True),
            commercial_use=info.get("commercial_use", True),
        )

    licenses = []
    for code, info in LICENSE_REGISTRY.items():
        licenses.append(_describe(code, info))
    return licenses
@router.post("/init")
async def initialize_dsfa_corpus(
    store: DSFACorpusStore = Depends(get_store),
    qdrant: DSFAQdrantService = Depends(get_qdrant)
):
    """
    Initialize DSFA corpus.

    - Ensures the Qdrant collection exists
    - Registers all predefined sources from ``DSFA_SOURCES``

    Registration is best-effort: a failing source is logged with its
    traceback and skipped, and the loop continues with the next one.

    Returns:
        A dict with ``qdrant_collection_created`` (result of
        ``ensure_collection``), ``sources_registered`` (number of successful
        registrations) and ``total_sources`` (size of ``DSFA_SOURCES``).
    """
    qdrant_ok = await qdrant.ensure_collection()

    registered = 0
    for source in DSFA_SOURCES:
        try:
            await store.register_source(source)
            registered += 1
        except Exception:
            # Report through the module logger (not print) so the failure is
            # captured by the configured handlers; .get() avoids a KeyError
            # inside the handler for a malformed source definition.
            logger.exception(
                "Error registering source %s",
                source.get("source_code", "<unknown>"),
            )

    return {
        "qdrant_collection_created": qdrant_ok,
        "sources_registered": registered,
        "total_sources": len(DSFA_SOURCES)
    }
@@ -1,31 +1,19 @@
#!/usr/bin/env python3
"""
Full Compliance Pipeline for Legal Corpus.
Full Compliance Pipeline for Legal Corpus — Barrel Re-export.
This script runs the complete pipeline:
1. Re-ingest all legal documents with improved chunking
2. Extract requirements/checkpoints from chunks
3. Generate controls using AI
4. Define remediation measures
5. Update statistics
Split into submodules:
- compliance_models.py — Dataclasses (Checkpoint, Control, Measure)
- compliance_extraction.py — Pattern extraction & control/measure generation
- compliance_pipeline.py — Pipeline phases & orchestrator
Run on Mac Mini:
nohup python full_compliance_pipeline.py > /tmp/compliance_pipeline.log 2>&1 &
Checkpoints are saved to /tmp/pipeline_checkpoints.json and can be viewed in admin-v2.
"""
import asyncio
import json
import logging
import os
import sys
import time
from datetime import datetime
from typing import Dict, List, Any, Optional
from dataclasses import dataclass, asdict
import re
import hashlib
# Configure logging
logging.basicConfig(
@@ -36,671 +24,25 @@ logging.basicConfig(
logging.FileHandler('/tmp/compliance_pipeline.log')
]
)
logger = logging.getLogger(__name__)
# Import checkpoint manager
try:
from pipeline_checkpoints import CheckpointManager, EXPECTED_VALUES, ValidationStatus
except ImportError:
logger.warning("Checkpoint manager not available, running without checkpoints")
CheckpointManager = None
EXPECTED_VALUES = {}
ValidationStatus = None
# Set environment variables for Docker network
# Support both QDRANT_URL and QDRANT_HOST
if not os.getenv("QDRANT_URL") and not os.getenv("QDRANT_HOST"):
os.environ["QDRANT_HOST"] = "qdrant"
os.environ.setdefault("EMBEDDING_SERVICE_URL", "http://embedding-service:8087")
# Try to import from klausur-service
try:
from legal_corpus_ingestion import LegalCorpusIngestion, REGULATIONS, LEGAL_CORPUS_COLLECTION
from qdrant_client import QdrantClient
from qdrant_client.models import Filter, FieldCondition, MatchValue
except ImportError:
logger.error("Could not import required modules. Make sure you're in the klausur-service container.")
sys.exit(1)
@dataclass
class Checkpoint:
    """A requirement/checkpoint extracted from legal text."""
    id: str  # requirement identifier built by the extraction step
    regulation_code: str  # code of the regulation the text belongs to
    regulation_name: str
    article: Optional[str]  # article reference when one applies, otherwise None
    title: str
    description: str  # requirement text (the extractor stores at most 500 characters)
    original_text: str  # full matched requirement text
    chunk_id: str  # short hash identifying the chunk the text was found in
    source_url: str
@dataclass
class Control:
    """A control derived from checkpoints."""
    id: str  # "<DOMAIN>-<NNN>" as built by generate_control_for_checkpoints
    domain: str  # short domain code such as "gov", "crypto", "iam", "priv"
    title: str
    description: str
    checkpoints: List[str]  # List of checkpoint IDs
    pass_criteria: str  # condition under which the control counts as fulfilled
    implementation_guidance: str
    is_automated: bool
    automation_tool: Optional[str]  # set only when is_automated is True
    priority: str  # e.g. "high" or "medium"
@dataclass
class Measure:
    """A remediation measure for a control."""
    id: str  # "M-<control id>" as built by generate_measure_for_control
    control_id: str  # id of the control this measure implements
    title: str
    description: str
    responsible: str  # team or role responsible for the measure
    deadline_days: int  # deadline in days, derived from the control's priority
    status: str  # created as "pending"
class CompliancePipeline:
"""Handles the full compliance pipeline."""
def __init__(self):
    """
    Set up the Qdrant client, result containers and statistics.

    Connection settings are read from the environment: ``QDRANT_URL``
    (host and port parsed from the URL) takes precedence; otherwise
    ``QDRANT_HOST`` and ``QDRANT_PORT`` are used. Defaults are host
    ``qdrant`` and port ``6333``.
    """
    # Support both QDRANT_URL and QDRANT_HOST/PORT
    qdrant_url = os.getenv("QDRANT_URL", "")
    if qdrant_url:
        from urllib.parse import urlparse
        parsed = urlparse(qdrant_url)
        qdrant_host = parsed.hostname or "qdrant"
        qdrant_port = parsed.port or 6333
    else:
        qdrant_host = os.getenv("QDRANT_HOST", "qdrant")
        try:
            qdrant_port = int(os.getenv("QDRANT_PORT", "6333"))
        except ValueError:
            # Container tooling may inject non-numeric values such as
            # "tcp://10.0.0.1:6333" under this name; keep the default then.
            qdrant_port = 6333

    self.qdrant = QdrantClient(host=qdrant_host, port=qdrant_port)
    self.checkpoints: List[Checkpoint] = []
    self.controls: List[Control] = []
    self.measures: List[Measure] = []
    self.stats = {
        "chunks_processed": 0,
        "checkpoints_extracted": 0,
        "controls_created": 0,
        "measures_defined": 0,
        "by_regulation": {},
        "by_domain": {},
    }
    # Initialize checkpoint manager (None when the optional import failed)
    self.checkpoint_mgr = CheckpointManager() if CheckpointManager else None
def extract_checkpoints_from_chunk(self, chunk_text: str, payload: Dict) -> List[Checkpoint]:
    """
    Extract checkpoints/requirements from a chunk of text.

    Uses pattern matching to find requirement-like statements. Every
    pattern is applied to the whole chunk, so one passage can yield
    several checkpoints.

    Args:
        chunk_text: Text of the chunk to scan.
        payload: Chunk metadata; ``regulation_code``, ``regulation_name``
            and ``source_url`` are read from it (with defaults).

    Returns:
        The extracted checkpoints. Matches whose description is shorter
        than 20 characters are skipped.
    """
    checkpoints = []
    regulation_code = payload.get("regulation_code", "UNKNOWN")
    regulation_name = payload.get("regulation_name", "Unknown")
    source_url = payload.get("source_url", "")
    # Short identifier derived from the first 100 characters of the chunk.
    chunk_id = hashlib.md5(chunk_text[:100].encode()).hexdigest()[:8]

    # Patterns for different requirement types
    patterns = [
        # BSI-TR patterns
        (r'([OT]\.[A-Za-z_]+\d*)[:\s]+(.+?)(?=\n[OT]\.|$)', 'bsi_requirement'),
        # Article patterns (GDPR, AI Act, etc.)
        (r'(?:Artikel|Art\.?)\s+(\d+)(?:\s+Abs(?:atz)?\.?\s*(\d+))?\s*[-:]\s*(.+?)(?=\n|$)', 'article'),
        # Numbered requirements
        (r'\((\d+)\)\s+(.+?)(?=\n\(\d+\)|$)', 'numbered'),
        # "Der Verantwortliche muss" patterns
        (r'(?:Der Verantwortliche|Die Aufsichtsbehörde|Der Auftragsverarbeiter)\s+(muss|hat|soll)\s+(.+?)(?=\.\s|$)', 'obligation'),
        # "Es ist erforderlich" patterns
        (r'(?:Es ist erforderlich|Es muss gewährleistet|Es sind geeignete)\s+(.+?)(?=\.\s|$)', 'requirement'),
    ]

    for pattern, pattern_type in patterns:
        for match in re.finditer(pattern, chunk_text, re.MULTILINE | re.DOTALL):
            # Only article-pattern matches carry an article reference. A
            # substring test on the title would also fire for generic
            # matches that merely contain the German word "Art".
            article = None

            if pattern_type == 'bsi_requirement':
                req_id = match.group(1)
                description = match.group(2).strip()
                title = req_id
            elif pattern_type == 'article':
                article_num = match.group(1)
                paragraph = match.group(2) or ""
                title_text = match.group(3).strip()
                req_id = f"{regulation_code}-Art{article_num}"
                if paragraph:
                    req_id += f"-{paragraph}"
                title = f"Art. {article_num}" + (f" Abs. {paragraph}" if paragraph else "")
                description = title_text
                article = title
            elif pattern_type == 'numbered':
                num = match.group(1)
                description = match.group(2).strip()
                req_id = f"{regulation_code}-{num}"
                title = f"Anforderung {num}"
            else:
                # Generic requirement ('obligation' / 'requirement' patterns):
                # the whole match is the description.
                description = match.group(0).strip()
                req_id = f"{regulation_code}-{chunk_id}-{len(checkpoints)}"
                title = description[:50] + "..." if len(description) > 50 else description

            # Skip very short matches
            if len(description) < 20:
                continue

            checkpoints.append(Checkpoint(
                id=req_id,
                regulation_code=regulation_code,
                regulation_name=regulation_name,
                article=article,
                title=title,
                description=description[:500],
                original_text=description,
                chunk_id=chunk_id,
                source_url=source_url
            ))

    return checkpoints
def generate_control_for_checkpoints(self, checkpoints: List[Checkpoint]) -> Optional[Control]:
    """
    Build one control covering the given checkpoints.

    Simplified, keyword-driven implementation (the original notes that a
    production version would use the AI assistant). Domain, automation flag
    and priority are derived from keyword hits in the combined, lower-cased
    checkpoint descriptions; title and description come from the first
    checkpoint.

    Returns:
        The generated control, or None when ``checkpoints`` is empty.
    """
    if not checkpoints:
        return None

    lead = checkpoints[0]
    regulation = lead.regulation_code
    all_text = " ".join([cp.description for cp in checkpoints]).lower()

    # Ordered rule table: the first domain with a keyword hit wins,
    # "gov" is the default when nothing matches.
    domain_rules = (
        ("crypto", ("verschlüssel", "krypto", "encrypt", "hash")),
        ("iam", ("zugang", "access", "authentif", "login", "benutzer")),
        ("priv", ("datenschutz", "personenbezogen", "privacy", "einwilligung")),
        ("sdlc", ("entwicklung", "test", "code", "software")),
        ("aud", ("überwach", "monitor", "log", "audit")),
        ("ai", ("ki", "künstlich", "ai", "machine learning", "model")),
        ("ops", ("betrieb", "operation", "verfügbar", "backup")),
        ("cra", ("cyber", "resilience", "sbom", "vulnerab")),
    )
    domain = next(
        (name for name, keywords in domain_rules
         if any(kw in all_text for kw in keywords)),
        "gov",
    )

    # Control ID: upper-cased domain plus the next sequence number according
    # to the per-domain counters in self.stats.
    sequence = self.stats.get("by_domain", {}).get(domain, 0) + 1
    control_id = f"{domain.upper()}-{sequence:03d}"

    # Title from the first checkpoint, shortened to at most 100 characters.
    title = lead.title
    if len(title) > 100:
        title = title[:97] + "..."

    description = f"Control für {regulation}: " + lead.description[:200]
    pass_criteria = f"Alle {len(checkpoints)} zugehörigen Anforderungen sind erfüllt und dokumentiert."
    guidance = (
        f"Implementiere Maßnahmen zur Erfüllung der Anforderungen aus {regulation}. "
        + "Dokumentiere die Umsetzung und führe regelmäßige Reviews durch."
    )

    is_automated = any(kw in all_text for kw in ["automat", "tool", "scan", "test"])
    if "muss" in all_text or "erforderlich" in all_text:
        priority = "high"
    else:
        priority = "medium"

    return Control(
        id=control_id,
        domain=domain,
        title=title,
        description=description,
        checkpoints=[cp.id for cp in checkpoints],
        pass_criteria=pass_criteria,
        implementation_guidance=guidance,
        is_automated=is_automated,
        automation_tool="CI/CD Pipeline" if is_automated else None,
        priority=priority,
    )
def generate_measure_for_control(self, control: Control) -> Measure:
    """Build the remediation measure that accompanies ``control``.

    The measure ID is the control ID prefixed with ``M-``. The deadline is
    looked up from the control's priority and the responsible party from
    its domain.

    Args:
        control: The control the measure implements.

    Returns:
        A ``Measure`` in status ``"pending"``.
    """
    # Days until the measure is due, keyed by control priority.
    days_by_priority = {"critical": 30, "high": 60, "medium": 90, "low": 180}

    # Party that owns the measure, keyed by control domain.
    owner_by_domain = {
        "priv": "Datenschutzbeauftragter",
        "iam": "IT-Security Team",
        "sdlc": "Entwicklungsteam",
        "crypto": "IT-Security Team",
        "ops": "Operations Team",
        "aud": "Compliance Team",
        "ai": "AI/ML Team",
        "cra": "IT-Security Team",
        "gov": "Management",
    }

    return Measure(
        id=f"M-{control.id}",
        control_id=control.id,
        title=f"Umsetzung: {control.title[:50]}",
        description=f"Implementierung und Dokumentation von {control.id}: {control.description[:100]}",
        # Unknown domains are owned by the compliance team.
        responsible=owner_by_domain.get(control.domain, "Compliance Team"),
        # Unknown priorities get the "medium" deadline.
        deadline_days=days_by_priority.get(control.priority, 90),
        status="pending",
    )
async def run_ingestion_phase(self, force_reindex: bool = False) -> int:
    """Phase 1: ingest regulation documents that are missing from the corpus.

    Counts the chunks already stored per regulation, ingests every
    regulation that has none (or all of them when ``force_reindex`` is set)
    and records per-regulation and total chunk counts in ``self.stats``.
    When a checkpoint manager is configured, the phase is tracked and the
    totals are validated against ``EXPECTED_VALUES["ingestion"]``.

    Args:
        force_reindex: Ingest every regulation even if chunks already exist.

    Returns:
        The pre-existing chunk count plus the chunks produced by this run.

    Raises:
        Exception: Unexpected errors are recorded on the checkpoint (if one
            is configured) and re-raised. A failure of a single regulation
            is only logged and does not abort the phase.
    """
    for banner_line in ("\n" + "=" * 60, "PHASE 1: DOCUMENT INGESTION (INCREMENTAL)", "=" * 60):
        logger.info(banner_line)

    if self.checkpoint_mgr:
        self.checkpoint_mgr.start_checkpoint("ingestion", "Document Ingestion")

    ingestion = LegalCorpusIngestion()
    try:
        # How many chunks does each regulation already have in the collection?
        existing_chunks = {}
        try:
            for regulation in REGULATIONS:
                code_filter = Filter(
                    must=[FieldCondition(key="regulation_code", match=MatchValue(value=regulation.code))]
                )
                count_result = self.qdrant.count(
                    collection_name=LEGAL_CORPUS_COLLECTION,
                    count_filter=code_filter,
                )
                existing_chunks[regulation.code] = count_result.count
                logger.info(f" {regulation.code}: {count_result.count} existing chunks")
        except Exception as e:
            # The collection might not exist yet; counts gathered so far are kept.
            logger.warning(f"Could not check existing chunks: {e}")

        # Split the regulations into "ingest now" and "already indexed".
        regulations_to_ingest = []
        for regulation in REGULATIONS:
            existing = existing_chunks.get(regulation.code, 0)
            if not force_reindex and existing != 0:
                logger.info(f" -> Skipping: {regulation.code} (already has {existing} chunks)")
                self.stats["by_regulation"][regulation.code] = existing
                continue
            regulations_to_ingest.append(regulation)
            logger.info(f" -> Will ingest: {regulation.code} (existing: {existing}, force: {force_reindex})")

        # NOTE(review): with force_reindex the running total starts from the
        # existing counts and then adds the re-ingested counts on top. Whether
        # that double counts depends on how ingest_regulation stores chunks;
        # confirm against LegalCorpusIngestion.
        total_chunks = sum(existing_chunks.values())

        if not regulations_to_ingest:
            logger.info("All regulations already indexed. Skipping ingestion phase.")
            self.stats["chunks_processed"] = total_chunks
            if self.checkpoint_mgr:
                self.checkpoint_mgr.add_metric("total_chunks", total_chunks)
                self.checkpoint_mgr.add_metric("skipped", True)
                self.checkpoint_mgr.complete_checkpoint(success=True)
            return total_chunks

        # Ingest only the missing (or forced) regulations.
        for i, regulation in enumerate(regulations_to_ingest, 1):
            logger.info(f"[{i}/{len(regulations_to_ingest)}] Ingesting {regulation.code}...")
            try:
                count = await ingestion.ingest_regulation(regulation)
                total_chunks += count
                self.stats["by_regulation"][regulation.code] = count
                logger.info(f" -> {count} chunks")
                if self.checkpoint_mgr:
                    self.checkpoint_mgr.add_metric(f"chunks_{regulation.code}", count)
            except Exception as e:
                # One failing regulation must not stop the remaining ones.
                logger.error(f" -> FAILED: {e}")
                self.stats["by_regulation"][regulation.code] = 0

        self.stats["chunks_processed"] = total_chunks
        logger.info(f"\nTotal chunks in collection: {total_chunks}")

        if self.checkpoint_mgr:
            self.checkpoint_mgr.add_metric("total_chunks", total_chunks)
            self.checkpoint_mgr.add_metric("regulations_count", len(REGULATIONS))

            expected = EXPECTED_VALUES.get("ingestion", {})
            # Overall chunk count.
            self.checkpoint_mgr.validate(
                "total_chunks",
                expected=expected.get("total_chunks", 8000),
                actual=total_chunks,
                min_value=expected.get("min_chunks", 7000),
            )
            # Chunk counts of the regulations that have explicit expectations.
            for reg_code, reg_exp in expected.get("regulations", {}).items():
                self.checkpoint_mgr.validate(
                    f"chunks_{reg_code}",
                    expected=reg_exp.get("expected", 0),
                    actual=self.stats["by_regulation"].get(reg_code, 0),
                    min_value=reg_exp.get("min", 0),
                )
            self.checkpoint_mgr.complete_checkpoint(success=True)

        return total_chunks
    except Exception as e:
        if self.checkpoint_mgr:
            self.checkpoint_mgr.fail_checkpoint(str(e))
        raise
    finally:
        # Always release the ingestion client, on success and on failure.
        await ingestion.close()
async def run_extraction_phase(self) -> int:
    """Phase 2: extract checkpoints from every chunk in the corpus.

    Pages through the collection 100 points at a time, runs
    ``extract_checkpoints_from_chunk`` on each payload and appends the
    results to ``self.checkpoints``. When a checkpoint manager is configured
    the totals are validated against ``EXPECTED_VALUES["extraction"]``.

    Returns:
        The number of checkpoints held in ``self.checkpoints`` afterwards.

    Raises:
        Exception: Any error is recorded on the checkpoint (if one is
            configured) and re-raised.
    """
    for banner_line in ("\n" + "=" * 60, "PHASE 2: CHECKPOINT EXTRACTION", "=" * 60):
        logger.info(banner_line)

    if self.checkpoint_mgr:
        self.checkpoint_mgr.start_checkpoint("extraction", "Checkpoint Extraction")

    try:
        cursor = None
        total_checkpoints = 0
        while True:
            points, next_offset = self.qdrant.scroll(
                collection_name=LEGAL_CORPUS_COLLECTION,
                limit=100,
                offset=cursor,
                with_payload=True,
                with_vectors=False,
            )
            if not points:
                break

            for point in points:
                payload = point.payload
                found = self.extract_checkpoints_from_chunk(payload.get("text", ""), payload)
                self.checkpoints.extend(found)
                total_checkpoints += len(found)

            logger.info(f"Processed {len(points)} chunks, extracted {total_checkpoints} checkpoints so far...")

            # A missing next offset marks the last page.
            if next_offset is None:
                break
            cursor = next_offset

        self.stats["checkpoints_extracted"] = len(self.checkpoints)
        logger.info(f"\nTotal checkpoints extracted: {len(self.checkpoints)}")

        # Tally checkpoints per regulation for the log and the metrics.
        by_reg = {}
        for checkpoint in self.checkpoints:
            code = checkpoint.regulation_code
            by_reg[code] = by_reg.get(code, 0) + 1
        for reg, count in sorted(by_reg.items()):
            logger.info(f" {reg}: {count} checkpoints")

        if self.checkpoint_mgr:
            self.checkpoint_mgr.add_metric("total_checkpoints", len(self.checkpoints))
            self.checkpoint_mgr.add_metric("checkpoints_by_regulation", by_reg)
            expected = EXPECTED_VALUES.get("extraction", {})
            self.checkpoint_mgr.validate(
                "total_checkpoints",
                expected=expected.get("total_checkpoints", 3500),
                actual=len(self.checkpoints),
                min_value=expected.get("min_checkpoints", 3000),
            )
            self.checkpoint_mgr.complete_checkpoint(success=True)

        return len(self.checkpoints)
    except Exception as e:
        if self.checkpoint_mgr:
            self.checkpoint_mgr.fail_checkpoint(str(e))
        raise
async def run_control_generation_phase(self) -> int:
    """Phase 3: turn the extracted checkpoints into controls.

    Checkpoints are grouped by regulation and then cut into consecutive
    batches of four; each batch is handed to
    ``generate_control_for_checkpoints``. Every control produced is appended
    to ``self.controls`` and counted per domain in ``self.stats``.

    Returns:
        The number of controls held in ``self.controls`` afterwards.

    Raises:
        Exception: Any error is recorded on the checkpoint (if one is
            configured) and re-raised.
    """
    for banner_line in ("\n" + "=" * 60, "PHASE 3: CONTROL GENERATION", "=" * 60):
        logger.info(banner_line)

    if self.checkpoint_mgr:
        self.checkpoint_mgr.start_checkpoint("controls", "Control Generation")

    try:
        # Bucket the checkpoints by the regulation they came from.
        by_regulation: Dict[str, List[Checkpoint]] = {}
        for checkpoint in self.checkpoints:
            by_regulation.setdefault(checkpoint.regulation_code, []).append(checkpoint)

        batch_size = 4
        for regulation, checkpoints in by_regulation.items():
            logger.info(f"Generating controls for {regulation} ({len(checkpoints)} checkpoints)...")
            for start in range(0, len(checkpoints), batch_size):
                control = self.generate_control_for_checkpoints(checkpoints[start:start + batch_size])
                if not control:
                    continue
                self.controls.append(control)
                # The per-domain counter is updated right away, before the
                # next control is generated.
                domain_counts = self.stats["by_domain"]
                domain_counts[control.domain] = domain_counts.get(control.domain, 0) + 1

        self.stats["controls_created"] = len(self.controls)
        logger.info(f"\nTotal controls created: {len(self.controls)}")

        for domain, count in sorted(self.stats["by_domain"].items()):
            logger.info(f" {domain}: {count} controls")

        if self.checkpoint_mgr:
            self.checkpoint_mgr.add_metric("total_controls", len(self.controls))
            self.checkpoint_mgr.add_metric("controls_by_domain", dict(self.stats["by_domain"]))
            expected = EXPECTED_VALUES.get("controls", {})
            self.checkpoint_mgr.validate(
                "total_controls",
                expected=expected.get("total_controls", 900),
                actual=len(self.controls),
                min_value=expected.get("min_controls", 800),
            )
            self.checkpoint_mgr.complete_checkpoint(success=True)

        return len(self.controls)
    except Exception as e:
        if self.checkpoint_mgr:
            self.checkpoint_mgr.fail_checkpoint(str(e))
        raise
async def run_measure_generation_phase(self) -> int:
    """Phase 4: derive one remediation measure per control.

    Appends the result of ``generate_measure_for_control`` for every entry
    of ``self.controls`` to ``self.measures``. When a checkpoint manager is
    configured the total is validated against ``EXPECTED_VALUES["measures"]``.

    Returns:
        The number of measures held in ``self.measures`` afterwards.

    Raises:
        Exception: Any error is recorded on the checkpoint (if one is
            configured) and re-raised.
    """
    for banner_line in ("\n" + "=" * 60, "PHASE 4: MEASURE GENERATION", "=" * 60):
        logger.info(banner_line)

    if self.checkpoint_mgr:
        self.checkpoint_mgr.start_checkpoint("measures", "Measure Generation")

    try:
        self.measures.extend(
            self.generate_measure_for_control(control) for control in self.controls
        )

        measure_count = len(self.measures)
        self.stats["measures_defined"] = measure_count
        logger.info(f"\nTotal measures defined: {measure_count}")

        if self.checkpoint_mgr:
            self.checkpoint_mgr.add_metric("total_measures", len(self.measures))
            expected = EXPECTED_VALUES.get("measures", {})
            self.checkpoint_mgr.validate(
                "total_measures",
                expected=expected.get("total_measures", 900),
                actual=len(self.measures),
                min_value=expected.get("min_measures", 800),
            )
            self.checkpoint_mgr.complete_checkpoint(success=True)

        return len(self.measures)
    except Exception as e:
        if self.checkpoint_mgr:
            self.checkpoint_mgr.fail_checkpoint(str(e))
        raise
def save_results(self, output_dir: str = "/tmp/compliance_output"):
    """Write checkpoints, controls, measures and statistics as JSON files.

    Creates ``output_dir`` if necessary and writes ``checkpoints.json``,
    ``controls.json``, ``measures.json`` and ``statistics.json`` into it.
    ``self.stats["generated_at"]`` is set to the current local time before
    the statistics are written.

    All files are written as UTF-8. The data is dumped with
    ``ensure_ascii=False``, so non-ASCII characters (e.g. German umlauts)
    reach the file unescaped; relying on the platform default encoding
    could raise ``UnicodeEncodeError`` or produce JSON that is not UTF-8.

    Args:
        output_dir: Directory that receives the four JSON files.
    """
    logger.info("\n" + "=" * 60)
    logger.info("SAVING RESULTS")
    logger.info("=" * 60)

    os.makedirs(output_dir, exist_ok=True)

    # The three dataclass collections share the same dump logic.
    datasets = (
        ("checkpoints", self.checkpoints),
        ("controls", self.controls),
        ("measures", self.measures),
    )
    for label, items in datasets:
        target = os.path.join(output_dir, f"{label}.json")
        with open(target, "w", encoding="utf-8") as f:
            json.dump([asdict(item) for item in items], f, indent=2, ensure_ascii=False)
        logger.info(f"Saved {len(items)} {label} to {target}")

    # Statistics are a plain dict and get a generation timestamp.
    stats_file = os.path.join(output_dir, "statistics.json")
    self.stats["generated_at"] = datetime.now().isoformat()
    with open(stats_file, "w", encoding="utf-8") as f:
        json.dump(self.stats, f, indent=2, ensure_ascii=False)
    logger.info(f"Saved statistics to {stats_file}")
async def run_full_pipeline(self, force_reindex: bool = False, skip_ingestion: bool = False):
    """Run ingestion, extraction, control and measure generation in order.

    After the four phases the results are written with ``save_results`` and
    a summary is logged. When a checkpoint manager is configured the
    pipeline is marked complete (or failed) on it.

    Args:
        force_reindex: If True, re-ingest all documents even if they exist.
        skip_ingestion: If True, skip the ingestion phase entirely and only
            read the current chunk count from the collection.

    Raises:
        Exception: Any error from a phase is logged, recorded as a failed
            pipeline on the checkpoint manager (if any) and re-raised.
    """
    started = time.time()
    logger.info("=" * 60)
    logger.info("FULL COMPLIANCE PIPELINE (INCREMENTAL)")
    logger.info(f"Started at: {datetime.now().isoformat()}")
    logger.info(f"Force reindex: {force_reindex}")
    logger.info(f"Skip ingestion: {skip_ingestion}")
    if self.checkpoint_mgr:
        logger.info(f"Pipeline ID: {self.checkpoint_mgr.pipeline_id}")
    logger.info("=" * 60)

    try:
        # Phase 1: ingestion (incremental unless skipped on request).
        if not skip_ingestion:
            await self.run_ingestion_phase(force_reindex=force_reindex)
        else:
            logger.info("Skipping ingestion phase as requested...")
            # The chunk count is still reported; a missing collection counts as 0.
            try:
                collection_info = self.qdrant.get_collection(LEGAL_CORPUS_COLLECTION)
                self.stats["chunks_processed"] = collection_info.points_count
            except Exception:
                self.stats["chunks_processed"] = 0

        # Phases 2-4 always run, each building on the previous one's output.
        await self.run_extraction_phase()
        await self.run_control_generation_phase()
        await self.run_measure_generation_phase()

        self.save_results()

        elapsed = time.time() - started
        logger.info("\n" + "=" * 60)
        logger.info("PIPELINE COMPLETE")
        logger.info("=" * 60)
        logger.info(f"Duration: {elapsed:.1f} seconds")
        logger.info(f"Chunks processed: {self.stats['chunks_processed']}")
        logger.info(f"Checkpoints extracted: {self.stats['checkpoints_extracted']}")
        logger.info(f"Controls created: {self.stats['controls_created']}")
        logger.info(f"Measures defined: {self.stats['measures_defined']}")
        logger.info("\nResults saved to: /tmp/compliance_output/")
        logger.info("Checkpoint status: /tmp/pipeline_checkpoints.json")
        logger.info("=" * 60)

        if self.checkpoint_mgr:
            summary = {
                "duration_seconds": elapsed,
                "chunks_processed": self.stats['chunks_processed'],
                "checkpoints_extracted": self.stats['checkpoints_extracted'],
                "controls_created": self.stats['controls_created'],
                "measures_defined": self.stats['measures_defined'],
                "by_regulation": self.stats['by_regulation'],
                "by_domain": self.stats['by_domain'],
            }
            self.checkpoint_mgr.complete_pipeline(summary)
    except Exception as e:
        logger.error(f"Pipeline failed: {e}")
        if self.checkpoint_mgr:
            # Persist the failed state before the error propagates.
            self.checkpoint_mgr.state.status = "failed"
            self.checkpoint_mgr._save()
        raise
# Re-export all public symbols
from compliance_models import Checkpoint, Control, Measure
from compliance_extraction import (
extract_checkpoints_from_chunk,
generate_control_for_checkpoints,
generate_measure_for_control,
)
from compliance_pipeline import CompliancePipeline
# Explicit public API of this module: exactly the names imported above, so
# that a star-import of this module exposes the models, the three
# extraction/generation helpers and the pipeline class.
__all__ = [
"Checkpoint",
"Control",
"Measure",
"extract_checkpoints_from_chunk",
"generate_control_for_checkpoints",
"generate_measure_for_control",
"CompliancePipeline",
]
async def main():
+27 -759
View File
@@ -1,767 +1,35 @@
"""
GitHub Repository Crawler for Legal Templates.
GitHub Repository Crawler — Barrel Re-export
Crawls GitHub and GitLab repositories to extract legal template documents
(Markdown, HTML, JSON, etc.) for ingestion into the RAG system.
Split into:
- github_crawler_parsers.py — ExtractedDocument, MarkdownParser, HTMLParser, JSONParser
- github_crawler_core.py — GitHubCrawler, RepositoryDownloader, crawl_source
Features:
- Clone repositories via Git or download as ZIP
- Parse Markdown, HTML, JSON, and plain text files
- Extract structured content with metadata
- Track git commit hashes for reproducibility
- Handle rate limiting and errors gracefully
All public names are re-exported here for backward compatibility.
"""
import asyncio
import hashlib
import json
import logging
import os
import re
import shutil
import tempfile
import zipfile
from dataclasses import dataclass, field
from datetime import datetime
from fnmatch import fnmatch
from pathlib import Path
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple
from urllib.parse import urlparse
import httpx
from template_sources import LicenseType, SourceConfig, LICENSES
# Configure logging
# NOTE(review): basicConfig runs at import time and configures the root
# logger for the whole process - confirm this is intended for a library module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Configuration
# Base URL of the GitHub REST API (repository, commit and tree lookups).
GITHUB_API_URL = "https://api.github.com"
# Base URL of the GitLab API; not referenced by the code visible here.
GITLAB_API_URL = "https://gitlab.com/api/v4"
GITHUB_TOKEN = os.getenv("GITHUB_TOKEN", "") # Optional; sent as Authorization header when set (higher rate limits)
MAX_FILE_SIZE = 1024 * 1024 # 1 MB; larger files are skipped by the crawlers
# Timeout handed to the crawler's httpx.AsyncClient.
REQUEST_TIMEOUT = 60.0
RATE_LIMIT_DELAY = 1.0 # Passed to asyncio.sleep between requests to avoid rate limiting
@dataclass
class ExtractedDocument:
    """A single document extracted from a repository.

    ``source_hash`` is derived from ``text`` after construction unless the
    caller supplies a hash explicitly.
    """

    text: str
    title: str
    file_path: str
    file_type: str  # one of "markdown", "html", "json", "text"
    source_url: str
    source_commit: Optional[str] = None
    source_hash: str = ""  # SHA-256 hex digest of ``text``; computed when left empty
    sections: List[Dict[str, Any]] = field(default_factory=list)
    placeholders: List[str] = field(default_factory=list)
    language: str = "en"
    metadata: Dict[str, Any] = field(default_factory=dict)

    def __post_init__(self):
        """Fill in ``source_hash`` when it is empty and there is text to hash."""
        if self.source_hash or not self.text:
            return
        digest = hashlib.sha256(self.text.encode())
        self.source_hash = digest.hexdigest()
class MarkdownParser:
    """Parse Markdown files into structured content.

    All helpers are classmethods. ``_find_placeholders`` and
    ``_detect_language`` work on arbitrary text, not only on Markdown.
    """

    # Common placeholder patterns
    PLACEHOLDER_PATTERNS = [
        r'\[([A-Z_]+)\]',      # [COMPANY_NAME]
        r'\{([a-z_]+)\}',      # {company_name}
        r'\{\{([a-z_]+)\}\}',  # {{company_name}}
        r'__([A-Z_]+)__',      # __COMPANY_NAME__
        r'<([A-Z_]+)>',        # <COMPANY_NAME>
    ]

    # Distinctive German legal vocabulary. Matched as substrings so that
    # inflected and compound forms ("personenbezogener") count as well.
    _GERMAN_TERMS = (
        'datenschutz', 'impressum', 'nutzungsbedingungen', 'haftung',
        'widerruf', 'verantwortlicher', 'personenbezogene', 'verarbeitung',
    )
    # Short German function words. Matched as whole words only, because as
    # substrings they occur in ordinary English ("under", "order", "list").
    _GERMAN_FUNCTION_WORDS = frozenset(
        ('und', 'der', 'die', 'das', 'ist', 'wird', 'werden', 'sind')
    )

    @classmethod
    def parse(cls, content: str, filename: str = "") -> ExtractedDocument:
        """Parse markdown content into an ExtractedDocument.

        Args:
            content: Raw Markdown source.
            filename: Path of the file; used as ``file_path`` and as the
                title fallback.

        Returns:
            A document of type ``"markdown"`` whose ``text`` is the cleaned
            content. ``source_url`` is left empty for the caller to set.
        """
        return ExtractedDocument(
            text=cls._clean_for_indexing(content),
            title=cls._extract_title(content, filename),
            file_path=filename,
            file_type="markdown",
            source_url="",  # Will be set by caller
            sections=cls._extract_sections(content),
            placeholders=cls._find_placeholders(content),
            language=cls._detect_language(content),
        )

    @classmethod
    def _extract_title(cls, content: str, filename: str) -> str:
        """Return the first H1, else the frontmatter title, else a name built from ``filename``."""
        # Look for first h1 heading
        h1_match = re.search(r'^#\s+(.+)$', content, re.MULTILINE)
        if h1_match:
            return h1_match.group(1).strip()

        # Look for YAML frontmatter title
        frontmatter_match = re.search(
            r'^---\s*\n.*?title:\s*["\']?(.+?)["\']?\s*\n.*?---',
            content, re.DOTALL
        )
        if frontmatter_match:
            return frontmatter_match.group(1).strip()

        # Fall back to filename
        if filename:
            name = Path(filename).stem
            # Convert kebab-case or snake_case to title case
            return name.replace('-', ' ').replace('_', ' ').title()

        return "Untitled"

    @classmethod
    def _extract_sections(cls, content: str) -> List[Dict[str, Any]]:
        """Split the content into sections at Markdown headings.

        Each section is a dict with ``heading``, ``level`` (1-6, or 0 for
        text before the first heading), ``content`` (the stripped text up to
        the next heading) and ``start`` (offset where the body begins).
        Text before the first heading becomes a section with an empty
        heading; empty heading-less sections are omitted.
        """
        sections = []
        current = {"heading": "", "level": 0, "content": "", "start": 0}

        for match in re.finditer(r'^(#{1,6})\s+(.+)$', content, re.MULTILINE):
            # The body of the section being closed ends where this heading begins.
            current["content"] = content[current["start"]:match.start()].strip()
            if current["heading"] or current["content"]:
                sections.append(current)

            current = {
                "heading": match.group(2).strip(),
                "level": len(match.group(1)),
                "content": "",
                "start": match.end(),
            }

        # The last section runs to the end of the document.
        current["content"] = content[current["start"]:].strip()
        if current["heading"] or current["content"]:
            sections.append(current)

        return sections

    @classmethod
    def _find_placeholders(cls, content: str) -> List[str]:
        """Return the sorted, de-duplicated placeholder tokens found in ``content``."""
        found = {
            match.group(0)
            for pattern in cls.PLACEHOLDER_PATTERNS
            for match in re.finditer(pattern, content)
        }
        return sorted(found)

    @classmethod
    def _detect_language(cls, content: str) -> str:
        """Return ``"de"`` when at least three German indicators occur, else ``"en"``.

        Legal terms are searched as substrings; short function words only
        count when they appear as whole words.
        """
        lowered = content.lower()
        # Runs of letters only (no digits or underscores), Unicode-aware.
        words = set(re.findall(r'[^\W\d_]+', lowered))

        hits = sum(1 for term in cls._GERMAN_TERMS if term in lowered)
        hits += len(cls._GERMAN_FUNCTION_WORDS & words)

        return "de" if hits >= 3 else "en"

    @classmethod
    def _clean_for_indexing(cls, content: str) -> str:
        """Strip Markdown and HTML markup so that only indexable text remains."""
        # Remove YAML frontmatter
        content = re.sub(r'^---\s*\n.*?---\s*\n', '', content, flags=re.DOTALL)

        # Remove HTML comments
        content = re.sub(r'<!--.*?-->', '', content, flags=re.DOTALL)

        # Remove inline HTML tags but keep content
        content = re.sub(r'<[^>]+>', '', content)

        # Convert markdown formatting
        content = re.sub(r'\*\*(.+?)\*\*', r'\1', content)  # Bold
        content = re.sub(r'\*(.+?)\*', r'\1', content)      # Italic
        content = re.sub(r'`(.+?)`', r'\1', content)        # Inline code
        content = re.sub(r'~~(.+?)~~', r'\1', content)      # Strikethrough

        # Images must be handled before links: an image is a link prefixed
        # with "!", so the link rule would otherwise leave a stray "!".
        content = re.sub(r'!\[([^\]]*)\]\([^)]+\)', r'\1', content)

        # Remove link syntax but keep text
        content = re.sub(r'\[([^\]]+)\]\([^)]+\)', r'\1', content)

        # Clean up whitespace
        content = re.sub(r'\n{3,}', '\n\n', content)
        content = re.sub(r' +', ' ', content)

        return content.strip()
class HTMLParser:
    """Parse HTML files into structured content."""

    @classmethod
    def parse(cls, content: str, filename: str = "") -> ExtractedDocument:
        """Parse HTML content into an ExtractedDocument.

        Args:
            content: Raw HTML source.
            filename: Path of the file; its stem is the title fallback.

        Returns:
            A document of type ``"html"`` whose ``text`` is the tag-free
            content. ``source_url`` is left empty for the caller to set.
        """
        from html import unescape

        # Title: <title> element (entities decoded), else the file stem.
        title_match = re.search(r'<title>(.+?)</title>', content, re.IGNORECASE)
        if title_match:
            title = unescape(title_match.group(1)).strip()
        else:
            title = Path(filename).stem

        # Convert to text
        text = cls._html_to_text(content)

        # Find placeholders
        placeholders = MarkdownParser._find_placeholders(text)

        # Language: explicit lang attribute wins over content-based detection.
        lang_match = re.search(r'<html[^>]*lang=["\']([a-z]{2})["\']', content, re.IGNORECASE)
        language = lang_match.group(1) if lang_match else MarkdownParser._detect_language(text)

        return ExtractedDocument(
            text=text,
            title=title,
            file_path=filename,
            file_type="html",
            source_url="",
            placeholders=placeholders,
            language=language,
        )

    @classmethod
    def _html_to_text(cls, html: str) -> str:
        """Convert HTML to clean text.

        Tags are removed *before* character references are decoded. Decoding
        first would turn escaped text such as ``&lt;NAME&gt;`` into
        ``<NAME>`` and the tag stripper would then delete it.
        """
        # Imported by name because the ``html`` parameter shadows the module.
        from html import unescape

        # Remove script and style tags
        html = re.sub(r'<script[^>]*>.*?</script>', '', html, flags=re.DOTALL | re.IGNORECASE)
        html = re.sub(r'<style[^>]*>.*?</style>', '', html, flags=re.DOTALL | re.IGNORECASE)

        # Remove comments
        html = re.sub(r'<!--.*?-->', '', html, flags=re.DOTALL)

        # Add line breaks for block elements
        html = re.sub(r'<br\s*/?>', '\n', html, flags=re.IGNORECASE)
        html = re.sub(r'</p>', '\n\n', html, flags=re.IGNORECASE)
        html = re.sub(r'</div>', '\n', html, flags=re.IGNORECASE)
        html = re.sub(r'</h[1-6]>', '\n\n', html, flags=re.IGNORECASE)
        html = re.sub(r'</li>', '\n', html, flags=re.IGNORECASE)

        # Remove remaining tags
        html = re.sub(r'<[^>]+>', '', html)

        # Decode every character reference exactly once. Non-breaking spaces
        # become plain spaces so the whitespace clean-up below collapses them.
        html = unescape(html).replace('\xa0', ' ')

        # Clean whitespace
        html = re.sub(r'[ \t]+', ' ', html)
        html = re.sub(r'\n[ \t]+', '\n', html)
        html = re.sub(r'[ \t]+\n', '\n', html)
        html = re.sub(r'\n{3,}', '\n\n', html)

        return html.strip()
class JSONParser:
    """Parse JSON files containing legal template data."""

    @classmethod
    def parse(cls, content: str, filename: str = "") -> List[ExtractedDocument]:
        """Parse JSON content into ExtractedDocuments.

        A top-level object is parsed directly; for a top-level array every
        object element is parsed on its own. Invalid JSON is logged and
        yields an empty list, as does any other top-level value.
        """
        try:
            data = json.loads(content)
        except json.JSONDecodeError as e:
            logger.warning(f"Failed to parse JSON from {filename}: {e}")
            return []

        if isinstance(data, dict):
            return list(cls._parse_dict(data, filename))

        collected = []
        if isinstance(data, list):
            for index, entry in enumerate(data):
                if isinstance(entry, dict):
                    collected += cls._parse_dict(entry, f"{filename}[{index}]")
        return collected

    @classmethod
    def _parse_dict(cls, data: dict, filename: str) -> List[ExtractedDocument]:
        """Parse a dictionary into documents.

        If the dict carries a non-empty string under one of the known text
        keys it becomes a single document. Otherwise nested dicts, lists of
        dicts and long strings inside lists are searched recursively.
        """
        text_keys = ['text', 'content', 'body', 'description', 'value']
        title_keys = ['title', 'name', 'heading', 'label', 'key']

        # First text key that holds a string wins.
        text = next(
            (data[key] for key in text_keys if key in data and isinstance(data[key], str)),
            "",
        )

        if not text:
            # No direct text: descend into nested structures (like webflorist format).
            nested = []
            for key, value in data.items():
                if isinstance(value, dict):
                    nested.extend(cls._parse_dict(value, f"{filename}.{key}"))
                    continue
                if not isinstance(value, list):
                    continue
                for i, item in enumerate(value):
                    if isinstance(item, dict):
                        nested.extend(cls._parse_dict(item, f"{filename}.{key}[{i}]"))
                    elif isinstance(item, str) and len(item) > 50:
                        # Long strings are treated as content of their own.
                        nested.append(ExtractedDocument(
                            text=item,
                            title=f"{key} {i+1}",
                            file_path=filename,
                            file_type="json",
                            source_url="",
                            language=MarkdownParser._detect_language(item),
                        ))
            return nested

        # Direct text found: first title key that holds a string, else the file stem.
        title = next(
            (data[key] for key in title_keys if key in data and isinstance(data[key], str)),
            "",
        ) or Path(filename).stem

        # Scalar values outside the text/title keys are kept as metadata.
        reserved = text_keys + title_keys
        metadata = {
            key: value
            for key, value in data.items()
            if key not in reserved and not isinstance(value, (dict, list))
        }

        language = data.get('lang', data.get('language', MarkdownParser._detect_language(text)))

        return [ExtractedDocument(
            text=text,
            title=title,
            file_path=filename,
            file_type="json",
            source_url="",
            placeholders=MarkdownParser._find_placeholders(text),
            language=language,
            metadata=metadata,
        )]
class GitHubCrawler:
    """Crawl GitHub repositories for legal templates.

    Use as an async context manager: the HTTP client only exists inside the
    ``async with`` block, and every request method raises ``RuntimeError``
    outside of it.
    """

    def __init__(self, token: Optional[str] = None):
        """Prepare the request headers.

        Args:
            token: API token; falls back to ``GITHUB_TOKEN`` when omitted.
        """
        self.token = token or GITHUB_TOKEN
        self.headers = {
            "Accept": "application/vnd.github.v3+json",
            "User-Agent": "LegalTemplatesCrawler/1.0",
        }
        if self.token:
            self.headers["Authorization"] = f"token {self.token}"
        self.http_client: Optional[httpx.AsyncClient] = None

    async def __aenter__(self):
        self.http_client = httpx.AsyncClient(
            timeout=REQUEST_TIMEOUT,
            headers=self.headers,
            follow_redirects=True,
        )
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        if self.http_client:
            await self.http_client.aclose()

    def _require_client(self):
        """Return the active HTTP client or raise if the context was not entered."""
        if not self.http_client:
            raise RuntimeError("Crawler not initialized. Use 'async with' context.")
        return self.http_client

    def _parse_repo_url(self, url: str) -> Tuple[str, str, str]:
        """Parse repository URL into owner, repo, and host.

        Only a trailing ``.git`` is removed from the repository name, so
        names such as ``user.github.io`` are left intact.

        Raises:
            ValueError: If the URL path has fewer than two segments.
        """
        parsed = urlparse(url)
        path_parts = parsed.path.strip('/').split('/')

        if len(path_parts) < 2:
            raise ValueError(f"Invalid repository URL: {url}")

        owner, repo = path_parts[0], path_parts[1]
        if repo.endswith('.git'):
            repo = repo[:-len('.git')]

        host = 'gitlab' if 'gitlab' in parsed.netloc else 'github'
        return owner, repo, host

    async def get_default_branch(self, owner: str, repo: str) -> str:
        """Get the default branch of a repository (``"main"`` if not reported)."""
        client = self._require_client()
        response = await client.get(f"{GITHUB_API_URL}/repos/{owner}/{repo}")
        response.raise_for_status()
        return response.json().get("default_branch", "main")

    async def get_latest_commit(self, owner: str, repo: str, branch: str = "main") -> str:
        """Get the latest commit SHA for a branch (``""`` if not reported)."""
        client = self._require_client()
        response = await client.get(f"{GITHUB_API_URL}/repos/{owner}/{repo}/commits/{branch}")
        response.raise_for_status()
        return response.json().get("sha", "")

    async def list_files(
        self,
        owner: str,
        repo: str,
        path: str = "",
        branch: str = "main",
        patterns: Optional[List[str]] = None,
        exclude_patterns: Optional[List[str]] = None,
    ) -> List[Dict[str, Any]]:
        """List files in a repository matching the given patterns.

        Args:
            owner: Repository owner.
            repo: Repository name.
            path: Accepted for interface compatibility; not used.
            branch: Branch whose tree is listed.
            patterns: ``fnmatch`` include patterns; defaults to Markdown,
                text and HTML files.
            exclude_patterns: ``fnmatch`` patterns of paths to skip.

        Returns:
            One dict per file with ``path``, ``sha``, ``size`` and ``url``.
            Files larger than ``MAX_FILE_SIZE`` are skipped.
        """
        client = self._require_client()

        patterns = patterns or ["*.md", "*.txt", "*.html"]
        exclude_patterns = exclude_patterns or []

        url = f"{GITHUB_API_URL}/repos/{owner}/{repo}/git/trees/{branch}?recursive=1"
        response = await client.get(url)
        response.raise_for_status()
        data = response.json()

        if data.get("truncated"):
            # The API caps recursive listings; make the gap visible.
            logger.warning(f"Tree listing for {owner}/{repo} is truncated; some files will be missed")

        files = []
        for item in data.get("tree", []):
            if item["type"] != "blob":
                continue

            file_path = item["path"]

            if any(fnmatch(file_path, pattern) for pattern in exclude_patterns):
                continue
            if not any(fnmatch(file_path, pattern) for pattern in patterns):
                continue

            # Skip large files
            if item.get("size", 0) > MAX_FILE_SIZE:
                logger.warning(f"Skipping large file: {file_path} ({item['size']} bytes)")
                continue

            files.append({
                "path": file_path,
                "sha": item["sha"],
                "size": item.get("size", 0),
                "url": item.get("url", ""),
            })

        return files

    async def get_file_content(self, owner: str, repo: str, path: str, branch: str = "main") -> str:
        """Get the content of a file from a repository."""
        client = self._require_client()
        # Use raw content URL for simplicity
        url = f"https://raw.githubusercontent.com/{owner}/{repo}/{branch}/{path}"
        response = await client.get(url)
        response.raise_for_status()
        return response.text

    @staticmethod
    def _parse_file(
        content: str,
        file_path: str,
        source_url: str,
        commit_sha: str,
    ) -> List[ExtractedDocument]:
        """Turn one file's content into documents, chosen by file extension.

        Unsupported extensions yield an empty list. Every returned document
        carries ``source_url`` and ``commit_sha``.
        """
        if file_path.endswith('.md'):
            docs = [MarkdownParser.parse(content, file_path)]
        elif file_path.endswith(('.html', '.htm')):
            docs = [HTMLParser.parse(content, file_path)]
        elif file_path.endswith('.json'):
            docs = JSONParser.parse(content, file_path)
        elif file_path.endswith('.txt'):
            # Plain text file
            docs = [ExtractedDocument(
                text=content,
                title=Path(file_path).stem,
                file_path=file_path,
                file_type="text",
                source_url=source_url,
                language=MarkdownParser._detect_language(content),
                placeholders=MarkdownParser._find_placeholders(content),
            )]
        else:
            docs = []

        for doc in docs:
            doc.source_url = source_url
            doc.source_commit = commit_sha
        return docs

    async def crawl_repository(
        self,
        source: SourceConfig,
    ) -> AsyncGenerator[ExtractedDocument, None]:
        """Crawl a repository and yield extracted documents.

        Sources without a repository URL, with an unparsable URL or hosted
        on GitLab are logged and yield nothing. Errors on a single file are
        logged and the crawl continues with the next file; errors while
        listing the repository end the crawl.
        """
        if not source.repo_url:
            logger.warning(f"No repo URL for source: {source.name}")
            return

        try:
            owner, repo, host = self._parse_repo_url(source.repo_url)
        except ValueError as e:
            logger.error(f"Failed to parse repo URL for {source.name}: {e}")
            return

        if host == "gitlab":
            logger.info(f"GitLab repos not yet supported: {source.name}")
            return

        logger.info(f"Crawling repository: {owner}/{repo}")

        try:
            # Get default branch and latest commit
            branch = await self.get_default_branch(owner, repo)
            commit_sha = await self.get_latest_commit(owner, repo, branch)

            await asyncio.sleep(RATE_LIMIT_DELAY)

            # List files matching patterns
            files = await self.list_files(
                owner, repo,
                branch=branch,
                patterns=source.file_patterns,
                exclude_patterns=source.exclude_patterns,
            )

            logger.info(f"Found {len(files)} matching files in {source.name}")

            for file_info in files:
                await asyncio.sleep(RATE_LIMIT_DELAY)

                # Bound before the try block so the error handlers below
                # always report the file that actually failed.
                file_path = file_info["path"]
                source_url = f"https://github.com/{owner}/{repo}/blob/{branch}/{file_path}"

                try:
                    content = await self.get_file_content(owner, repo, file_path, branch)
                    docs = self._parse_file(content, file_path, source_url, commit_sha)
                except httpx.HTTPError as e:
                    logger.warning(f"Failed to fetch {file_path}: {e}")
                    continue
                except Exception as e:
                    logger.error(f"Error processing {file_path}: {e}")
                    continue

                for doc in docs:
                    yield doc

        except httpx.HTTPError as e:
            logger.error(f"HTTP error crawling {source.name}: {e}")
        except Exception as e:
            logger.error(f"Error crawling {source.name}: {e}")
class RepositoryDownloader:
    """Download and extract repository archives.

    Use as an async context manager; ``download_zip`` needs the HTTP client
    that only exists inside the ``async with`` block.
    """

    def __init__(self):
        self.http_client: Optional[httpx.AsyncClient] = None

    async def __aenter__(self):
        self.http_client = httpx.AsyncClient(
            timeout=120.0,
            follow_redirects=True,
        )
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        if self.http_client:
            await self.http_client.aclose()

    @staticmethod
    def _parse_owner_repo(repo_url: str) -> Tuple[str, str]:
        """Split a repository URL into ``(owner, repo)``.

        Only a trailing ``.git`` is removed from the repository name.

        Raises:
            ValueError: If the URL path has fewer than two segments.
        """
        path_parts = urlparse(repo_url).path.strip('/').split('/')
        if len(path_parts) < 2:
            raise ValueError(f"Invalid repository URL: {repo_url}")
        owner, repo = path_parts[0], path_parts[1]
        if repo.endswith('.git'):
            repo = repo[:-len('.git')]
        return owner, repo

    async def download_zip(self, repo_url: str, branch: str = "main") -> Path:
        """Download repository as ZIP and extract to temp directory.

        Args:
            repo_url: URL of the GitHub repository.
            branch: Branch whose archive is downloaded.

        Returns:
            Path of the extracted repository directory inside a fresh
            temporary directory. The caller is responsible for removing it
            (see ``cleanup``).

        Raises:
            RuntimeError: If called outside the ``async with`` block.
            ValueError: If ``repo_url`` does not name an owner and a repo.
        """
        if not self.http_client:
            raise RuntimeError("Downloader not initialized. Use 'async with' context.")

        owner, repo = self._parse_owner_repo(repo_url)
        zip_url = f"https://github.com/{owner}/{repo}/archive/refs/heads/{branch}.zip"

        logger.info(f"Downloading ZIP from {zip_url}")

        response = await self.http_client.get(zip_url)
        response.raise_for_status()

        temp_dir = Path(tempfile.mkdtemp())
        try:
            # Save to temp file
            zip_path = temp_dir / f"{repo}.zip"
            zip_path.write_bytes(response.content)

            # Extract ZIP
            with zipfile.ZipFile(zip_path, 'r') as zip_ref:
                zip_ref.extractall(temp_dir)

            # The archive is no longer needed once it is unpacked.
            zip_path.unlink()
        except Exception:
            # Do not leave a half-written temp directory behind.
            shutil.rmtree(temp_dir, ignore_errors=True)
            raise

        # The extracted directory is usually named repo-branch
        extracted_dirs = list(temp_dir.glob(f"{repo}-*"))
        if extracted_dirs:
            return extracted_dirs[0]
        return temp_dir / repo

    async def crawl_local_directory(
        self,
        directory: Path,
        source: SourceConfig,
        base_url: str,
    ) -> AsyncGenerator[ExtractedDocument, None]:
        """Crawl a local directory for documents.

        Args:
            directory: Root directory to search recursively.
            source: Supplies ``file_patterns`` and ``exclude_patterns``.
            base_url: Prefix for each document's ``source_url``.

        Yields:
            Documents parsed from matching Markdown, HTML, JSON and text
            files. A file matched by several patterns is processed once.
        """
        patterns = source.file_patterns or ["*.md", "*.txt", "*.html"]
        exclude_patterns = source.exclude_patterns or []

        seen = set()
        for pattern in patterns:
            for file_path in directory.rglob(pattern.replace("**/", "")):
                # Overlapping patterns must not yield the same file twice.
                if file_path in seen or not file_path.is_file():
                    continue
                seen.add(file_path)

                rel_path = str(file_path.relative_to(directory))

                # Check exclude patterns
                if any(fnmatch(rel_path, ep) for ep in exclude_patterns):
                    continue

                # Skip large files
                if file_path.stat().st_size > MAX_FILE_SIZE:
                    continue

                try:
                    content = file_path.read_text(encoding='utf-8')
                except UnicodeDecodeError:
                    # Not UTF-8: retry with latin-1 and skip the file if that fails too.
                    try:
                        content = file_path.read_text(encoding='latin-1')
                    except Exception:
                        continue

                source_url = f"{base_url}/{rel_path}"

                if file_path.suffix == '.md':
                    doc = MarkdownParser.parse(content, rel_path)
                    doc.source_url = source_url
                    yield doc

                elif file_path.suffix in ['.html', '.htm']:
                    doc = HTMLParser.parse(content, rel_path)
                    doc.source_url = source_url
                    yield doc

                elif file_path.suffix == '.json':
                    for doc in JSONParser.parse(content, rel_path):
                        doc.source_url = source_url
                        yield doc

                elif file_path.suffix == '.txt':
                    yield ExtractedDocument(
                        text=content,
                        title=file_path.stem,
                        file_path=rel_path,
                        file_type="text",
                        source_url=source_url,
                        language=MarkdownParser._detect_language(content),
                        placeholders=MarkdownParser._find_placeholders(content),
                    )

    def cleanup(self, directory: Path):
        """Clean up temporary directory."""
        if directory.exists():
            shutil.rmtree(directory, ignore_errors=True)
async def crawl_source(source: SourceConfig) -> List[ExtractedDocument]:
    """Collect every document the crawler yields for ``source`` into a list.

    Returns an empty list when the source has no ``repo_url``.
    """
    if not source.repo_url:
        return []

    async with GitHubCrawler() as crawler:
        return [doc async for doc in crawler.crawl_repository(source)]
# CLI for testing
async def main():
    """Crawl the ``github-site-policy`` sample source and print each document."""
    from template_sources import TEMPLATE_SOURCES

    # Test with github-site-policy
    source = next(s for s in TEMPLATE_SOURCES if s.name == "github-site-policy")

    async with GitHubCrawler() as crawler:
        total = 0
        async for doc in crawler.crawl_repository(source):
            total += 1
            shown_placeholders = doc.placeholders[:5] if doc.placeholders else 'None'
            print(f"\n{'='*60}")
            print(f"Title: {doc.title}")
            print(f"Path: {doc.file_path}")
            print(f"Type: {doc.file_type}")
            print(f"Language: {doc.language}")
            print(f"URL: {doc.source_url}")
            print(f"Placeholders: {shown_placeholders}")
            print(f"Text preview: {doc.text[:200]}...")

        print(f"\n\nTotal documents: {total}")
# Parsers
from github_crawler_parsers import ( # noqa: F401
ExtractedDocument,
MarkdownParser,
HTMLParser,
JSONParser,
)
# Crawler and downloader
from github_crawler_core import ( # noqa: F401
GITHUB_API_URL,
GITLAB_API_URL,
GITHUB_TOKEN,
MAX_FILE_SIZE,
REQUEST_TIMEOUT,
RATE_LIMIT_DELAY,
GitHubCrawler,
RepositoryDownloader,
crawl_source,
main,
)
if __name__ == "__main__":
import asyncio
asyncio.run(main())
@@ -0,0 +1,411 @@
"""
GitHub Crawler - Core Crawler and Downloader
GitHubCrawler for API-based repository crawling and RepositoryDownloader
for ZIP-based local extraction.
Extracted from github_crawler.py to keep files under 500 LOC.
"""
import asyncio
import logging
import os
import shutil
import tempfile
import zipfile
from fnmatch import fnmatch
from pathlib import Path
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple
from urllib.parse import urlparse
import httpx
from template_sources import SourceConfig
from github_crawler_parsers import (
ExtractedDocument,
MarkdownParser,
HTMLParser,
JSONParser,
)
logger = logging.getLogger(__name__)
# Configuration
GITHUB_API_URL = "https://api.github.com"
GITLAB_API_URL = "https://gitlab.com/api/v4"
GITHUB_TOKEN = os.getenv("GITHUB_TOKEN", "")
MAX_FILE_SIZE = 1024 * 1024 # 1 MB max file size
REQUEST_TIMEOUT = 60.0
RATE_LIMIT_DELAY = 1.0
class GitHubCrawler:
    """Crawl GitHub repositories for legal templates.

    Use as an async context manager: entering creates the shared
    ``httpx.AsyncClient``, exiting closes it. Every network method raises
    ``RuntimeError`` when called outside that context.
    """

    def __init__(self, token: Optional[str] = None):
        # Fall back to the module-level GITHUB_TOKEN when no token is passed.
        self.token = token or GITHUB_TOKEN
        self.headers = {
            "Accept": "application/vnd.github.v3+json",
            "User-Agent": "LegalTemplatesCrawler/1.0",
        }
        if self.token:
            self.headers["Authorization"] = f"token {self.token}"
        self.http_client: Optional[httpx.AsyncClient] = None

    async def __aenter__(self):
        self.http_client = httpx.AsyncClient(
            timeout=REQUEST_TIMEOUT,
            headers=self.headers,
            follow_redirects=True,
        )
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        if self.http_client:
            await self.http_client.aclose()

    def _parse_repo_url(self, url: str) -> Tuple[str, str, str]:
        """Split a repository URL into ``(owner, repo, host)``.

        ``host`` is ``'gitlab'`` when the URL's netloc contains "gitlab",
        otherwise ``'github'``.

        Raises:
            ValueError: If the URL path has fewer than two components.
        """
        parsed = urlparse(url)
        path_parts = parsed.path.strip('/').split('/')
        if len(path_parts) < 2:
            raise ValueError(f"Invalid repository URL: {url}")

        owner = path_parts[0]
        repo = path_parts[1]
        if repo.endswith('.git'):
            # Strip only a trailing ".git"; a plain replace() would also
            # mangle names such as "acme.github.io".
            repo = repo[:-len('.git')]

        host = 'gitlab' if 'gitlab' in parsed.netloc else 'github'
        return owner, repo, host

    async def get_default_branch(self, owner: str, repo: str) -> str:
        """Return the repository's default branch ("main" if not reported)."""
        if not self.http_client:
            raise RuntimeError("Crawler not initialized. Use 'async with' context.")

        url = f"{GITHUB_API_URL}/repos/{owner}/{repo}"
        response = await self.http_client.get(url)
        response.raise_for_status()
        return response.json().get("default_branch", "main")

    async def get_latest_commit(self, owner: str, repo: str, branch: str = "main") -> str:
        """Return the latest commit SHA of ``branch`` ("" if not reported)."""
        if not self.http_client:
            raise RuntimeError("Crawler not initialized. Use 'async with' context.")

        url = f"{GITHUB_API_URL}/repos/{owner}/{repo}/commits/{branch}"
        response = await self.http_client.get(url)
        response.raise_for_status()
        return response.json().get("sha", "")

    async def list_files(
        self,
        owner: str,
        repo: str,
        path: str = "",
        branch: str = "main",
        patterns: Optional[List[str]] = None,
        exclude_patterns: Optional[List[str]] = None,
    ) -> List[Dict[str, Any]]:
        """List blobs in the repository tree that match the given patterns.

        Args:
            owner: Repository owner.
            repo: Repository name.
            path: Accepted for interface compatibility; not used for filtering.
            branch: Branch whose recursive tree is listed.
            patterns: ``fnmatch`` patterns a path must match
                (default ``*.md``, ``*.txt``, ``*.html``).
            exclude_patterns: ``fnmatch`` patterns that exclude a path.

        Returns:
            Dicts with ``path``, ``sha``, ``size`` and ``url`` for every
            matching file not larger than ``MAX_FILE_SIZE``.
        """
        if not self.http_client:
            raise RuntimeError("Crawler not initialized. Use 'async with' context.")

        patterns = patterns or ["*.md", "*.txt", "*.html"]
        exclude_patterns = exclude_patterns or []

        url = f"{GITHUB_API_URL}/repos/{owner}/{repo}/git/trees/{branch}?recursive=1"
        response = await self.http_client.get(url)
        response.raise_for_status()
        data = response.json()

        if data.get("truncated"):
            # The API caps recursive tree listings; make the gap visible.
            logger.warning(
                f"File tree for {owner}/{repo} was truncated by the API; some files may be missing"
            )

        files = []
        for item in data.get("tree", []):
            if item["type"] != "blob":
                continue

            file_path = item["path"]
            if any(fnmatch(file_path, pattern) for pattern in exclude_patterns):
                continue
            if not any(fnmatch(file_path, pattern) for pattern in patterns):
                continue

            if item.get("size", 0) > MAX_FILE_SIZE:
                logger.warning(f"Skipping large file: {file_path} ({item['size']} bytes)")
                continue

            files.append({
                "path": file_path,
                "sha": item["sha"],
                "size": item.get("size", 0),
                "url": item.get("url", ""),
            })

        return files

    async def get_file_content(self, owner: str, repo: str, path: str, branch: str = "main") -> str:
        """Fetch a file's text from raw.githubusercontent.com."""
        if not self.http_client:
            raise RuntimeError("Crawler not initialized. Use 'async with' context.")

        url = f"https://raw.githubusercontent.com/{owner}/{repo}/{branch}/{path}"
        response = await self.http_client.get(url)
        response.raise_for_status()
        return response.text

    async def crawl_repository(
        self,
        source: SourceConfig,
    ) -> AsyncGenerator[ExtractedDocument, None]:
        """Crawl the source's repository and yield extracted documents.

        Errors are logged rather than raised: a failure on a single file skips
        that file, a failure on the repository level ends the crawl.
        """
        if not source.repo_url:
            logger.warning(f"No repo URL for source: {source.name}")
            return

        try:
            owner, repo, host = self._parse_repo_url(source.repo_url)
        except ValueError as e:
            logger.error(f"Failed to parse repo URL for {source.name}: {e}")
            return

        if host == "gitlab":
            logger.info(f"GitLab repos not yet supported: {source.name}")
            return

        logger.info(f"Crawling repository: {owner}/{repo}")

        try:
            branch = await self.get_default_branch(owner, repo)
            commit_sha = await self.get_latest_commit(owner, repo, branch)

            await asyncio.sleep(RATE_LIMIT_DELAY)

            files = await self.list_files(
                owner, repo,
                branch=branch,
                patterns=source.file_patterns,
                exclude_patterns=source.exclude_patterns,
            )

            logger.info(f"Found {len(files)} matching files in {source.name}")

            for file_info in files:
                # Bind the path before the try block so the handlers below
                # always name the file that actually failed.
                file_path = file_info["path"]
                await asyncio.sleep(RATE_LIMIT_DELAY)

                try:
                    content = await self.get_file_content(owner, repo, file_path, branch)
                    source_url = f"https://github.com/{owner}/{repo}/blob/{branch}/{file_path}"

                    if file_path.endswith('.md'):
                        doc = MarkdownParser.parse(content, file_path)
                        doc.source_url = source_url
                        doc.source_commit = commit_sha
                        yield doc

                    elif file_path.endswith(('.html', '.htm')):
                        doc = HTMLParser.parse(content, file_path)
                        doc.source_url = source_url
                        doc.source_commit = commit_sha
                        yield doc

                    elif file_path.endswith('.json'):
                        for doc in JSONParser.parse(content, file_path):
                            doc.source_url = source_url
                            doc.source_commit = commit_sha
                            yield doc

                    elif file_path.endswith('.txt'):
                        yield ExtractedDocument(
                            text=content,
                            title=Path(file_path).stem,
                            file_path=file_path,
                            file_type="text",
                            source_url=source_url,
                            source_commit=commit_sha,
                            language=MarkdownParser._detect_language(content),
                            placeholders=MarkdownParser._find_placeholders(content),
                        )

                except httpx.HTTPError as e:
                    logger.warning(f"Failed to fetch {file_path}: {e}")
                    continue
                except Exception as e:
                    logger.error(f"Error processing {file_path}: {e}")
                    continue

        except httpx.HTTPError as e:
            logger.error(f"HTTP error crawling {source.name}: {e}")
        except Exception as e:
            logger.error(f"Error crawling {source.name}: {e}")
class RepositoryDownloader:
    """Download repository ZIP archives and crawl their extracted contents.

    Use as an async context manager: entering creates the ``httpx.AsyncClient``
    that :meth:`download_zip` needs, exiting closes it.
    """

    def __init__(self):
        # Created in __aenter__; remains None outside an ``async with`` block.
        self.http_client: Optional[httpx.AsyncClient] = None

    async def __aenter__(self):
        self.http_client = httpx.AsyncClient(
            timeout=120.0,
            follow_redirects=True,
        )
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        if self.http_client:
            await self.http_client.aclose()

    async def download_zip(self, repo_url: str, branch: str = "main") -> Path:
        """Download a GitHub repository as ZIP and extract it to a temp directory.

        Args:
            repo_url: Repository URL whose path starts with ``owner/repo``.
            branch: Branch whose archive is downloaded.

        Returns:
            Path of the extracted ``<repo>-*`` directory if one exists,
            otherwise ``<temp_dir>/<repo>``.

        Raises:
            RuntimeError: If called outside the ``async with`` context.
            ValueError: If ``repo_url`` does not contain ``owner/repo``.
            httpx.HTTPError: If the download fails.
        """
        if not self.http_client:
            raise RuntimeError("Downloader not initialized. Use 'async with' context.")

        path_parts = urlparse(repo_url).path.strip('/').split('/')
        if len(path_parts) < 2:
            # Fail with a clear message instead of a bare IndexError.
            raise ValueError(f"Invalid repository URL: {repo_url}")
        owner = path_parts[0]
        repo = path_parts[1]
        if repo.endswith('.git'):
            # Strip only a trailing ".git"; a plain replace() would also
            # mangle names such as "acme.github.io".
            repo = repo[:-len('.git')]

        zip_url = f"https://github.com/{owner}/{repo}/archive/refs/heads/{branch}.zip"
        logger.info(f"Downloading ZIP from {zip_url}")
        response = await self.http_client.get(zip_url)
        response.raise_for_status()

        temp_dir = Path(tempfile.mkdtemp())
        try:
            zip_path = temp_dir / f"{repo}.zip"
            zip_path.write_bytes(response.content)
            with zipfile.ZipFile(zip_path, 'r') as zip_ref:
                zip_ref.extractall(temp_dir)
        except Exception:
            # Do not leave a half-populated temp directory behind.
            shutil.rmtree(temp_dir, ignore_errors=True)
            raise

        # The archive normally unpacks into a single "<repo>-<branch>" folder.
        extracted_dirs = [p for p in temp_dir.glob(f"{repo}-*") if p.is_dir()]
        if extracted_dirs:
            return extracted_dirs[0]
        return temp_dir / repo

    @staticmethod
    def _iter_matching_files(directory: Path, patterns: List[str], exclude_patterns: List[str]):
        """Yield ``(path, relative_path)`` once per regular file matching any pattern."""
        seen = set()
        for pattern in patterns:
            for file_path in directory.rglob(pattern.replace("**/", "")):
                # A file can match several patterns; report it only once.
                if file_path in seen or not file_path.is_file():
                    continue
                seen.add(file_path)
                rel_path = str(file_path.relative_to(directory))
                if any(fnmatch(rel_path, ep) for ep in exclude_patterns):
                    continue
                yield file_path, rel_path

    async def crawl_local_directory(
        self,
        directory: Path,
        source: SourceConfig,
        base_url: str,
    ) -> AsyncGenerator[ExtractedDocument, None]:
        """Crawl a local directory and yield one document per supported file.

        Files larger than ``MAX_FILE_SIZE`` and files matching the source's
        exclude patterns are skipped. Text is read as UTF-8 with a latin-1
        fallback.
        """
        patterns = source.file_patterns or ["*.md", "*.txt", "*.html"]
        exclude_patterns = source.exclude_patterns or []

        for file_path, rel_path in self._iter_matching_files(directory, patterns, exclude_patterns):
            if file_path.stat().st_size > MAX_FILE_SIZE:
                continue

            try:
                content = file_path.read_text(encoding='utf-8')
            except UnicodeDecodeError:
                try:
                    content = file_path.read_text(encoding='latin-1')
                except Exception as e:
                    logger.warning(f"Skipping unreadable file {rel_path}: {e}")
                    continue

            source_url = f"{base_url}/{rel_path}"
            suffix = file_path.suffix

            if suffix == '.md':
                doc = MarkdownParser.parse(content, rel_path)
                doc.source_url = source_url
                yield doc
            elif suffix in ('.html', '.htm'):
                doc = HTMLParser.parse(content, rel_path)
                doc.source_url = source_url
                yield doc
            elif suffix == '.json':
                for doc in JSONParser.parse(content, rel_path):
                    doc.source_url = source_url
                    yield doc
            elif suffix == '.txt':
                yield ExtractedDocument(
                    text=content,
                    title=file_path.stem,
                    file_path=rel_path,
                    file_type="text",
                    source_url=source_url,
                    language=MarkdownParser._detect_language(content),
                    placeholders=MarkdownParser._find_placeholders(content),
                )

    def cleanup(self, directory: Path):
        """Remove ``directory`` and its contents, ignoring removal errors."""
        if directory.exists():
            shutil.rmtree(directory, ignore_errors=True)
async def crawl_source(source: SourceConfig) -> List[ExtractedDocument]:
    """Collect every document the crawler yields for ``source`` into a list.

    Returns an empty list when the source has no ``repo_url``.
    """
    if not source.repo_url:
        return []

    async with GitHubCrawler() as crawler:
        return [doc async for doc in crawler.crawl_repository(source)]
# CLI for testing
async def main():
    """Crawl the ``github-site-policy`` sample source and print each document."""
    from template_sources import TEMPLATE_SOURCES

    source = next(s for s in TEMPLATE_SOURCES if s.name == "github-site-policy")

    async with GitHubCrawler() as crawler:
        total = 0
        async for doc in crawler.crawl_repository(source):
            total += 1
            shown_placeholders = doc.placeholders[:5] if doc.placeholders else 'None'
            print(f"\n{'='*60}")
            print(f"Title: {doc.title}")
            print(f"Path: {doc.file_path}")
            print(f"Type: {doc.file_type}")
            print(f"Language: {doc.language}")
            print(f"URL: {doc.source_url}")
            print(f"Placeholders: {shown_placeholders}")
            print(f"Text preview: {doc.text[:200]}...")

        print(f"\n\nTotal documents: {total}")
if __name__ == "__main__":
asyncio.run(main())
@@ -0,0 +1,303 @@
"""
GitHub Crawler - Document Parsers
Markdown, HTML, and JSON parsers for extracting structured content
from legal template documents.
Extracted from github_crawler.py to keep files under 500 LOC.
"""
import hashlib
import json
import re
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, List, Optional
@dataclass
class ExtractedDocument:
    """A single document extracted from a repository.

    ``source_hash`` is filled in automatically with the SHA-256 hex digest of
    ``text`` when it is not supplied and ``text`` is non-empty.
    """

    text: str
    title: str
    file_path: str
    file_type: str  # "markdown", "html", "json", "text"
    source_url: str
    source_commit: Optional[str] = None
    source_hash: str = ""  # SHA256 of original content
    sections: List[Dict[str, Any]] = field(default_factory=list)
    placeholders: List[str] = field(default_factory=list)
    language: str = "en"
    metadata: Dict[str, Any] = field(default_factory=dict)

    def __post_init__(self):
        # Keep an explicitly supplied hash; nothing to hash for empty text.
        if self.source_hash or not self.text:
            return
        digest = hashlib.sha256(self.text.encode())
        self.source_hash = digest.hexdigest()
class MarkdownParser:
    """Parse Markdown files into structured content."""

    # Common placeholder patterns
    PLACEHOLDER_PATTERNS = [
        r'\[([A-Z_]+)\]',  # [COMPANY_NAME]
        r'\{([a-z_]+)\}',  # {company_name}
        r'\{\{([a-z_]+)\}\}',  # {{company_name}}
        r'__([A-Z_]+)__',  # __COMPANY_NAME__
        r'<([A-Z_]+)>',  # <COMPANY_NAME>
    ]

    # Domain terms are matched as substrings so German compounds count too
    # (e.g. "Datenschutzerklärung" contains "datenschutz").
    _GERMAN_TERMS = (
        'datenschutz', 'impressum', 'nutzungsbedingungen', 'haftung',
        'widerruf', 'verantwortlicher', 'personenbezogene', 'verarbeitung',
    )
    # Short function words must match whole words only; as substrings they
    # also occur in English ("und" in "under", "ist" in "list").
    _GERMAN_WORDS = frozenset(
        ('und', 'der', 'die', 'das', 'ist', 'wird', 'werden', 'sind')
    )

    @classmethod
    def parse(cls, content: str, filename: str = "") -> ExtractedDocument:
        """Parse markdown content into an ExtractedDocument.

        The returned document has an empty ``source_url``; callers set it.
        """
        title = cls._extract_title(content, filename)
        sections = cls._extract_sections(content)
        placeholders = cls._find_placeholders(content)
        language = cls._detect_language(content)
        clean_text = cls._clean_for_indexing(content)

        return ExtractedDocument(
            text=clean_text,
            title=title,
            file_path=filename,
            file_type="markdown",
            source_url="",
            sections=sections,
            placeholders=placeholders,
            language=language,
        )

    @classmethod
    def _extract_title(cls, content: str, filename: str) -> str:
        """Return the first H1, else a frontmatter title, else a name from the file."""
        h1_match = re.search(r'^#\s+(.+)$', content, re.MULTILINE)
        if h1_match:
            return h1_match.group(1).strip()

        frontmatter_match = re.search(
            r'^---\s*\n.*?title:\s*["\']?(.+?)["\']?\s*\n.*?---',
            content, re.DOTALL
        )
        if frontmatter_match:
            return frontmatter_match.group(1).strip()

        if filename:
            name = Path(filename).stem
            return name.replace('-', ' ').replace('_', ' ').title()

        return "Untitled"

    @classmethod
    def _extract_sections(cls, content: str) -> List[Dict[str, Any]]:
        """Split markdown into sections, one per heading.

        Each section is a dict with ``heading``, ``level``, ``content`` and
        ``start`` (offset where the section body begins). Non-empty text
        before the first heading becomes a section with an empty heading and
        level 0.
        """
        sections = []
        current = {"heading": "", "level": 0, "content": "", "start": 0}

        for match in re.finditer(r'^(#{1,6})\s+(.+)$', content, re.MULTILINE):
            # The body of the section being closed runs up to this heading.
            current["content"] = content[current["start"]:match.start()].strip()
            if current["heading"] or current["content"]:
                sections.append(current)

            current = {
                "heading": match.group(2).strip(),
                "level": len(match.group(1)),
                "content": "",
                "start": match.end(),
            }

        # The last (or only) section runs to the end of the document.
        current["content"] = content[current["start"]:].strip()
        if current["heading"] or current["content"]:
            sections.append(current)

        return sections

    @classmethod
    def _find_placeholders(cls, content: str) -> List[str]:
        """Return the sorted, de-duplicated placeholder tokens found in content."""
        placeholders = set()
        for pattern in cls.PLACEHOLDER_PATTERNS:
            placeholders.update(match.group(0) for match in re.finditer(pattern, content))
        return sorted(placeholders)

    @classmethod
    def _detect_language(cls, content: str) -> str:
        """Return ``"de"`` if at least three German indicators occur, else ``"en"``."""
        lower_content = content.lower()
        words = set(re.findall(r'[^\W\d_]+', lower_content))

        german_count = sum(1 for term in cls._GERMAN_TERMS if term in lower_content)
        german_count += sum(1 for word in cls._GERMAN_WORDS if word in words)

        return "de" if german_count >= 3 else "en"

    @classmethod
    def _clean_for_indexing(cls, content: str) -> str:
        """Strip frontmatter, HTML and inline markdown markup for text indexing."""
        content = re.sub(r'^---\s*\n.*?---\s*\n', '', content, flags=re.DOTALL)
        content = re.sub(r'<!--.*?-->', '', content, flags=re.DOTALL)
        content = re.sub(r'<[^>]+>', '', content)
        content = re.sub(r'\*\*(.+?)\*\*', r'\1', content)
        content = re.sub(r'\*(.+?)\*', r'\1', content)
        content = re.sub(r'`(.+?)`', r'\1', content)
        content = re.sub(r'~~(.+?)~~', r'\1', content)
        # Images must be unwrapped before links: the link pattern also matches
        # the "[alt](url)" part of "![alt](url)" and would leave a stray "!".
        content = re.sub(r'!\[([^\]]*)\]\([^)]+\)', r'\1', content)
        content = re.sub(r'\[([^\]]+)\]\([^)]+\)', r'\1', content)
        content = re.sub(r'\n{3,}', '\n\n', content)
        content = re.sub(r' +', ' ', content)
        return content.strip()
class HTMLParser:
    """Parse HTML files into structured content."""

    @classmethod
    def parse(cls, content: str, filename: str = "") -> ExtractedDocument:
        """Parse HTML content into an ExtractedDocument.

        The title comes from ``<title>`` (entity-decoded, whitespace
        collapsed) and falls back to the file stem. The language comes from
        the ``lang`` attribute of ``<html>`` (primary subtag only, lowercased)
        and falls back to text-based detection. ``source_url`` is left empty
        for the caller to set.
        """
        from html import unescape

        title = ""
        title_match = re.search(
            r'<title[^>]*>(.*?)</title>', content, re.IGNORECASE | re.DOTALL
        )
        if title_match:
            title = " ".join(unescape(title_match.group(1)).split())
        if not title:
            title = Path(filename).stem

        text = cls._html_to_text(content)
        placeholders = MarkdownParser._find_placeholders(text)

        # Accept region-qualified tags such as lang="de-DE" and keep only "de".
        lang_match = re.search(
            r'<html[^>]*lang=["\']([a-z]{2})(?:[-_][a-z0-9]+)*["\']',
            content, re.IGNORECASE
        )
        if lang_match:
            language = lang_match.group(1).lower()
        else:
            language = MarkdownParser._detect_language(text)

        return ExtractedDocument(
            text=text,
            title=title,
            file_path=filename,
            file_type="html",
            source_url="",
            placeholders=placeholders,
            language=language,
        )

    @classmethod
    def _html_to_text(cls, html: str) -> str:
        """Convert HTML to plain text with normalised whitespace."""
        from html import unescape

        # Drop content that is never visible text.
        html = re.sub(r'<script[^>]*>.*?</script>', '', html, flags=re.DOTALL | re.IGNORECASE)
        html = re.sub(r'<style[^>]*>.*?</style>', '', html, flags=re.DOTALL | re.IGNORECASE)
        html = re.sub(r'<!--.*?-->', '', html, flags=re.DOTALL)

        # Turn block-level boundaries into line breaks before removing tags.
        html = re.sub(r'<br\s*/?>', '\n', html, flags=re.IGNORECASE)
        html = re.sub(r'</p>', '\n\n', html, flags=re.IGNORECASE)
        html = re.sub(r'</div>', '\n', html, flags=re.IGNORECASE)
        html = re.sub(r'</h[1-6]>', '\n\n', html, flags=re.IGNORECASE)
        html = re.sub(r'</li>', '\n', html, flags=re.IGNORECASE)

        html = re.sub(r'<[^>]+>', '', html)

        # Decode entities only after the tags are gone; otherwise decoded
        # "&lt;" / "&gt;" would be mistaken for markup and text between them
        # deleted. unescape() covers all named and numeric entities in a
        # single pass, so "&amp;lt;" is not decoded twice.
        html = unescape(html).replace('\xa0', ' ')

        html = re.sub(r'[ \t]+', ' ', html)
        html = re.sub(r'\n[ \t]+', '\n', html)
        html = re.sub(r'[ \t]+\n', '\n', html)
        html = re.sub(r'\n{3,}', '\n\n', html)

        return html.strip()
class JSONParser:
    """Parse JSON files containing legal template data."""

    @classmethod
    def parse(cls, content: str, filename: str = "") -> List[ExtractedDocument]:
        """Turn a JSON string into documents; invalid JSON yields an empty list.

        A top-level object is parsed directly; for a top-level array each
        object element is parsed with an ``[index]`` suffix on the filename.
        Any other top-level value produces no documents.
        """
        try:
            payload = json.loads(content)
        except json.JSONDecodeError as e:
            import logging
            logging.getLogger(__name__).warning(f"Failed to parse JSON from {filename}: {e}")
            return []

        if isinstance(payload, dict):
            return list(cls._parse_dict(payload, filename))

        found = []
        if isinstance(payload, list):
            for index, entry in enumerate(payload):
                if isinstance(entry, dict):
                    found.extend(cls._parse_dict(entry, f"{filename}[{index}]"))
        return found

    @classmethod
    def _parse_dict(cls, data: dict, filename: str) -> List[ExtractedDocument]:
        """Build documents from one JSON object.

        If the object carries text under a known key, a single document is
        produced from it. Otherwise nested objects and lists are searched
        recursively, and list strings longer than 50 characters become
        documents of their own.
        """
        text_keys = ['text', 'content', 'body', 'description', 'value']
        title_keys = ['title', 'name', 'heading', 'label', 'key']

        # First known key holding a string wins, even if that string is empty.
        text = next(
            (data[k] for k in text_keys if k in data and isinstance(data[k], str)),
            "",
        )

        if not text:
            nested = []
            for key, value in data.items():
                if isinstance(value, dict):
                    nested.extend(cls._parse_dict(value, f"{filename}.{key}"))
                    continue
                if not isinstance(value, list):
                    continue
                for i, item in enumerate(value):
                    if isinstance(item, dict):
                        nested.extend(cls._parse_dict(item, f"{filename}.{key}[{i}]"))
                    elif isinstance(item, str) and len(item) > 50:
                        nested.append(ExtractedDocument(
                            text=item,
                            title=f"{key} {i+1}",
                            file_path=filename,
                            file_type="json",
                            source_url="",
                            language=MarkdownParser._detect_language(item),
                        ))
            return nested

        title = next(
            (data[k] for k in title_keys if k in data and isinstance(data[k], str)),
            "",
        )
        if not title:
            title = Path(filename).stem

        # Scalar values that are neither text nor title become metadata.
        reserved = set(text_keys) | set(title_keys)
        metadata = {
            key: value
            for key, value in data.items()
            if key not in reserved and not isinstance(value, (dict, list))
        }

        placeholders = MarkdownParser._find_placeholders(text)
        detected = MarkdownParser._detect_language(text)
        language = data.get('lang', data.get('language', detected))

        return [ExtractedDocument(
            text=text,
            title=title,
            file_path=filename,
            file_type="json",
            source_url="",
            placeholders=placeholders,
            language=language,
            metadata=metadata,
        )]
+26 -786
View File
@@ -1,790 +1,30 @@
"""
Legal Corpus API - Endpoints for RAG page in admin-v2
Legal Corpus API — Barrel Re-export
Provides endpoints for:
- GET /api/v1/admin/legal-corpus/status - Collection status with chunk counts
- GET /api/v1/admin/legal-corpus/search - Semantic search
- POST /api/v1/admin/legal-corpus/ingest - Trigger ingestion
- GET /api/v1/admin/legal-corpus/ingestion-status - Ingestion status
- POST /api/v1/admin/legal-corpus/upload - Upload document
- POST /api/v1/admin/legal-corpus/add-link - Add link for ingestion
- POST /api/v1/admin/pipeline/start - Start compliance pipeline
Split into:
- legal_corpus_routes.py — Corpus endpoints (status, search, ingest, upload)
- legal_corpus_pipeline.py — Pipeline endpoints (checkpoints, start, status)
All public names are re-exported here for backward compatibility.
"""
import os
import asyncio
import httpx
import uuid
import shutil
from datetime import datetime
from typing import Optional, List, Dict, Any
from fastapi import APIRouter, HTTPException, Query, BackgroundTasks, UploadFile, File, Form
from pydantic import BaseModel
import logging
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/v1/admin/legal-corpus", tags=["legal-corpus"])
# Configuration
QDRANT_URL = os.getenv("QDRANT_URL", "http://localhost:6333")
EMBEDDING_SERVICE_URL = os.getenv("EMBEDDING_SERVICE_URL", "http://embedding-service:8087")
COLLECTION_NAME = "bp_legal_corpus"
# All regulations for status endpoint
REGULATIONS = [
{"code": "GDPR", "name": "DSGVO", "fullName": "Datenschutz-Grundverordnung", "type": "eu_regulation"},
{"code": "EPRIVACY", "name": "ePrivacy-Richtlinie", "fullName": "Richtlinie 2002/58/EG", "type": "eu_directive"},
{"code": "TDDDG", "name": "TDDDG", "fullName": "Telekommunikation-Digitale-Dienste-Datenschutz-Gesetz", "type": "de_law"},
{"code": "SCC", "name": "Standardvertragsklauseln", "fullName": "2021/914/EU", "type": "eu_regulation"},
{"code": "DPF", "name": "EU-US Data Privacy Framework", "fullName": "Angemessenheitsbeschluss", "type": "eu_regulation"},
{"code": "AIACT", "name": "EU AI Act", "fullName": "Verordnung (EU) 2024/1689", "type": "eu_regulation"},
{"code": "CRA", "name": "Cyber Resilience Act", "fullName": "Verordnung (EU) 2024/2847", "type": "eu_regulation"},
{"code": "NIS2", "name": "NIS2-Richtlinie", "fullName": "Richtlinie (EU) 2022/2555", "type": "eu_directive"},
{"code": "EUCSA", "name": "EU Cybersecurity Act", "fullName": "Verordnung (EU) 2019/881", "type": "eu_regulation"},
{"code": "DATAACT", "name": "Data Act", "fullName": "Verordnung (EU) 2023/2854", "type": "eu_regulation"},
{"code": "DGA", "name": "Data Governance Act", "fullName": "Verordnung (EU) 2022/868", "type": "eu_regulation"},
{"code": "DSA", "name": "Digital Services Act", "fullName": "Verordnung (EU) 2022/2065", "type": "eu_regulation"},
{"code": "EAA", "name": "European Accessibility Act", "fullName": "Richtlinie (EU) 2019/882", "type": "eu_directive"},
{"code": "DSM", "name": "DSM-Urheberrechtsrichtlinie", "fullName": "Richtlinie (EU) 2019/790", "type": "eu_directive"},
{"code": "PLD", "name": "Produkthaftungsrichtlinie", "fullName": "Richtlinie 85/374/EWG", "type": "eu_directive"},
{"code": "GPSR", "name": "General Product Safety", "fullName": "Verordnung (EU) 2023/988", "type": "eu_regulation"},
{"code": "BSI-TR-03161-1", "name": "BSI-TR Teil 1", "fullName": "BSI TR-03161 Teil 1 - Mobile Anwendungen", "type": "bsi_standard"},
{"code": "BSI-TR-03161-2", "name": "BSI-TR Teil 2", "fullName": "BSI TR-03161 Teil 2 - Web-Anwendungen", "type": "bsi_standard"},
{"code": "BSI-TR-03161-3", "name": "BSI-TR Teil 3", "fullName": "BSI TR-03161 Teil 3 - Hintergrundsysteme", "type": "bsi_standard"},
]
# Ingestion state (in-memory for now)
ingestion_state = {
"running": False,
"completed": False,
"current_regulation": None,
"processed": 0,
"total": len(REGULATIONS),
"error": None,
}
class SearchRequest(BaseModel):
    """Request body for a corpus search."""

    query: str
    # Optional list of regulation codes; defaults to None.
    regulations: Optional[List[str]] = None
    top_k: int = 5
class IngestRequest(BaseModel):
    """Request body for triggering corpus ingestion."""

    # Passed through to the ingestion backend's ``force`` argument.
    force: bool = False
    # Regulation codes to ingest; when empty or None, all regulations are used.
    regulations: Optional[List[str]] = None
class AddLinkRequest(BaseModel):
    """Request body for adding a link to be ingested."""

    url: str
    title: str
    code: str  # Regulation code (e.g. "CUSTOM-1")
    document_type: str = "custom"  # custom, eu_regulation, eu_directive, de_law, bsi_standard
class StartPipelineRequest(BaseModel):
    """Request body for starting the compliance pipeline."""

    force_reindex: bool = False
    skip_ingestion: bool = False
# Store for custom documents (in-memory for now, should be persisted)
custom_documents: List[Dict[str, Any]] = []
async def get_qdrant_client():
    """Create a new async HTTP client (30 s timeout) for Qdrant requests.

    The client is returned open; this function does not close it.
    """
    client = httpx.AsyncClient(timeout=30.0)
    return client
@router.get("/status")
async def get_legal_corpus_status():
    """
    Report the state of the legal corpus collection.

    Returns the collection's point count, vector size and status together
    with a chunk count for every entry in REGULATIONS. If the collection
    request does not return 200, a "not_found" placeholder is returned.
    Raises HTTP 503 when a request to Qdrant fails.
    """
    collection_url = f"{QDRANT_URL}/collections/{COLLECTION_NAME}"

    async with httpx.AsyncClient(timeout=30.0) as client:
        try:
            collection_res = await client.get(collection_url)
            if collection_res.status_code != 200:
                return {
                    "collection": COLLECTION_NAME,
                    "totalPoints": 0,
                    "vectorSize": 1024,
                    "status": "not_found",
                    "regulations": {},
                }

            info = collection_res.json().get("result", {})

            # One count request per regulation, filtered by its code.
            per_regulation = {}
            for reg in REGULATIONS:
                code = reg["code"]
                count_res = await client.post(
                    f"{collection_url}/points/count",
                    json={
                        "filter": {
                            "must": [{"key": "regulation_code", "match": {"value": code}}]
                        }
                    },
                )
                if count_res.status_code != 200:
                    per_regulation[code] = 0
                    continue
                per_regulation[code] = count_res.json().get("result", {}).get("count", 0)

            vectors_cfg = info.get("config", {}).get("params", {}).get("vectors", {})
            return {
                "collection": COLLECTION_NAME,
                "totalPoints": info.get("points_count", 0),
                "vectorSize": vectors_cfg.get("size", 1024),
                "status": info.get("status", "unknown"),
                "regulations": per_regulation,
            }

        except httpx.RequestError as e:
            logger.error(f"Failed to get Qdrant status: {e}")
            raise HTTPException(status_code=503, detail=f"Qdrant not available: {str(e)}")
@router.get("/search")
async def search_legal_corpus(
    query: str = Query(..., description="Search query"),
    top_k: int = Query(5, ge=1, le=20, description="Number of results"),
    regulations: Optional[str] = Query(None, description="Comma-separated regulation codes to filter"),
):
    """
    Semantic search in legal corpus using BGE-M3 embeddings.

    The query is embedded via the embedding service and the vector is
    searched in Qdrant, optionally restricted to the given regulation codes.
    Raises HTTP 500 when either service answers with a non-200 status and
    HTTP 503 when a request fails at transport level.
    """
    async with httpx.AsyncClient(timeout=60.0) as client:
        try:
            # Generate embedding for query
            embed_res = await client.post(
                f"{EMBEDDING_SERVICE_URL}/embed",
                json={"texts": [query]},
            )
            if embed_res.status_code != 200:
                raise HTTPException(status_code=500, detail="Embedding service error")

            query_vector = embed_res.json()["embeddings"][0]

            # Build Qdrant search request
            search_request = {
                "vector": query_vector,
                "limit": top_k,
                "with_payload": True,
            }

            # Add regulation filter if specified. Blank entries (stray or
            # trailing commas) are dropped so they do not become clauses
            # matching an empty regulation code.
            if regulations:
                reg_codes = [code.strip() for code in regulations.split(",") if code.strip()]
                if reg_codes:
                    search_request["filter"] = {
                        "should": [
                            {"key": "regulation_code", "match": {"value": code}}
                            for code in reg_codes
                        ]
                    }

            # Search Qdrant
            search_res = await client.post(
                f"{QDRANT_URL}/collections/{COLLECTION_NAME}/points/search",
                json=search_request,
            )
            if search_res.status_code != 200:
                raise HTTPException(status_code=500, detail="Search failed")

            results = []
            for point in search_res.json().get("result", []):
                payload = point.get("payload", {})
                results.append({
                    "text": payload.get("text", ""),
                    "regulation_code": payload.get("regulation_code", ""),
                    "regulation_name": payload.get("regulation_name", ""),
                    "article": payload.get("article"),
                    "paragraph": payload.get("paragraph"),
                    "source_url": payload.get("source_url", ""),
                    "score": point.get("score", 0),
                })

            return {"results": results, "query": query, "count": len(results)}

        except httpx.RequestError as e:
            logger.error(f"Search failed: {e}")
            raise HTTPException(status_code=503, detail=f"Service not available: {str(e)}")
@router.post("/ingest")
async def trigger_ingestion(request: IngestRequest, background_tasks: BackgroundTasks):
    """
    Trigger legal corpus ingestion in background.

    Raises HTTP 409 if an ingestion run is already in progress.
    """
    if ingestion_state["running"]:
        raise HTTPException(status_code=409, detail="Ingestion already running")

    # Same selection rule as run_ingestion, so "total" and the message
    # describe the regulations that will actually be processed.
    selected = request.regulations or [r["code"] for r in REGULATIONS]

    # Reset state in place instead of rebinding the global: modules that
    # imported ``ingestion_state`` by name keep seeing the live dict.
    ingestion_state.update({
        "running": True,
        "completed": False,
        "current_regulation": None,
        "processed": 0,
        "total": len(selected),
        "error": None,
    })

    # Start ingestion in background
    background_tasks.add_task(run_ingestion, request.force, request.regulations)

    return {
        "status": "started",
        "job_id": "manual-trigger",
        "message": f"Ingestion started for {len(selected)} regulations",
    }
async def run_ingestion(force: bool, regulations: Optional[List[str]]):
    """Background task: ingest the selected regulations one after another.

    Progress is written to ``ingestion_state``. A failure for a single
    regulation is logged and the run continues; any other failure is stored
    under ``error``. ``running`` is always cleared at the end.
    """
    try:
        # Imported lazily, inside the task.
        from legal_corpus_ingestion import LegalCorpusIngestion

        ingestion = LegalCorpusIngestion()

        # Fall back to every known regulation when none were requested.
        codes = regulations if regulations else [entry["code"] for entry in REGULATIONS]

        for position, code in enumerate(codes):
            ingestion_state["current_regulation"] = code
            ingestion_state["processed"] = position

            try:
                await ingestion.ingest_single(code, force=force)
            except Exception as e:
                logger.error(f"Failed to ingest {code}: {e}")

        ingestion_state["completed"] = True
        ingestion_state["processed"] = len(codes)

    except Exception as e:
        logger.error(f"Ingestion failed: {e}")
        ingestion_state["error"] = str(e)

    finally:
        ingestion_state["running"] = False
@router.get("/ingestion-status")
async def get_ingestion_status():
    """Return the module-level ingestion state dict as it currently stands."""
    current_state = ingestion_state
    return current_state
@router.get("/regulations")
async def get_regulations():
    """Return the REGULATIONS list wrapped under the ``regulations`` key."""
    response = {"regulations": REGULATIONS}
    return response
@router.get("/custom-documents")
async def get_custom_documents():
    """Return the in-memory list of custom documents under the ``documents`` key."""
    response = {"documents": custom_documents}
    return response
@router.post("/upload")
async def upload_document(
    background_tasks: BackgroundTasks,
    file: UploadFile = File(...),
    title: str = Form(...),
    code: str = Form(...),
    document_type: str = Form("custom"),
):
    """
    Upload a document (PDF) for ingestion into the legal corpus.

    The file is stored under ``/tmp/legal_corpus_uploads``, a record is
    appended to ``custom_documents`` and ingestion is queued as a background
    task.

    Args:
        background_tasks: FastAPI task queue used to schedule ingestion.
        file: The uploaded PDF.
        title: Display title; later stored on each chunk as ``regulation_name``.
        code: Document code; later stored on each chunk as ``regulation_code``.
        document_type: Free-form type label, defaults to ``"custom"``.

    Returns:
        A dict with ``status``, ``document_id``, ``message`` and the full
        ``document`` record.

    Raises:
        HTTPException: 400 if the upload is not a PDF, 500 if the file cannot
            be written to disk.
    """
    global custom_documents

    # Validate file type. The filename is optional in multipart uploads, and
    # the extension check is case-insensitive so that e.g. "x.Pdf" is accepted.
    original_name = file.filename or ""
    if not original_name.lower().endswith(".pdf"):
        raise HTTPException(status_code=400, detail="Only PDF files are supported")

    # Create upload directory if needed
    upload_dir = "/tmp/legal_corpus_uploads"
    os.makedirs(upload_dir, exist_ok=True)

    # Save file with unique name. Only the last path component of the
    # client-supplied name is used, so separators in the name can neither
    # escape nor nest below upload_dir.
    doc_id = str(uuid.uuid4())[:8]
    base_name = os.path.basename(original_name.replace("\\", "/"))
    safe_filename = f"{doc_id}_{base_name}"
    file_path = os.path.join(upload_dir, safe_filename)

    try:
        with open(file_path, "wb") as buffer:
            shutil.copyfileobj(file.file, buffer)
    except Exception as e:
        logger.error(f"Failed to save uploaded file: {e}")
        raise HTTPException(status_code=500, detail=f"Failed to save file: {str(e)}")

    # Create document record
    doc_record = {
        "id": doc_id,
        "code": code,
        "title": title,
        "filename": file.filename,
        "file_path": file_path,
        "document_type": document_type,
        "uploaded_at": datetime.now().isoformat(),
        "status": "uploaded",
        "chunk_count": 0,
    }
    custom_documents.append(doc_record)

    # Queue for background ingestion
    background_tasks.add_task(ingest_uploaded_document, doc_record)
    return {
        "status": "uploaded",
        "document_id": doc_id,
        "message": f"Document '{title}' uploaded and queued for ingestion",
        "document": doc_record,
    }
async def ingest_uploaded_document(doc_record: Dict[str, Any]):
    """
    Background task to ingest an uploaded document.

    Extracts the text of the PDF at ``doc_record["file_path"]``, chunks it,
    tags every chunk with the document's metadata and upserts the chunks.
    Progress and failures are reported by mutating ``doc_record`` in place
    (``status``, ``chunk_count``, ``error``); nothing is raised to the caller.

    Args:
        doc_record: The record created by the upload endpoint. Reads
            ``file_path``, ``code``, ``title``, ``document_type`` and
            ``filename``.
    """
    global custom_documents
    try:
        doc_record["status"] = "processing"

        from legal_corpus_ingestion import LegalCorpusIngestion
        ingestion = LegalCorpusIngestion()

        # Read PDF and extract text
        import fitz  # PyMuPDF
        doc = fitz.open(doc_record["file_path"])
        try:
            # join() instead of repeated += keeps extraction linear in the
            # number of pages.
            full_text = "".join(page.get_text() for page in doc)
        finally:
            # Release the file handle even when a page fails to extract.
            doc.close()

        if not full_text.strip():
            doc_record["status"] = "error"
            doc_record["error"] = "No text could be extracted from PDF"
            return

        # Chunk the text
        chunks = ingestion.chunk_text(full_text, doc_record["code"])

        # Add metadata
        for chunk in chunks:
            chunk["regulation_code"] = doc_record["code"]
            chunk["regulation_name"] = doc_record["title"]
            chunk["document_type"] = doc_record["document_type"]
            chunk["source_url"] = f"upload://{doc_record['filename']}"

        # Generate embeddings and upsert to Qdrant
        if chunks:
            await ingestion.embed_and_upsert(chunks)
            doc_record["chunk_count"] = len(chunks)
            doc_record["status"] = "indexed"
            logger.info(f"Ingested {len(chunks)} chunks from uploaded document {doc_record['code']}")
        else:
            doc_record["status"] = "error"
            doc_record["error"] = "No chunks generated from document"
    except Exception as e:
        logger.error(f"Failed to ingest uploaded document: {e}")
        doc_record["status"] = "error"
        doc_record["error"] = str(e)
@router.post("/add-link")
async def add_link(request: AddLinkRequest, background_tasks: BackgroundTasks):
    """
    Register a URL for ingestion into the legal corpus.

    A record for the link is appended to ``custom_documents`` and
    ``ingest_link_document`` is queued as a background task for it.

    Returns:
        A dict with ``status``, ``document_id``, ``message`` and the new
        ``document`` record.
    """
    global custom_documents

    new_id = str(uuid.uuid4())[:8]
    queued_at = datetime.now().isoformat()

    # Create document record
    doc_record = dict(
        id=new_id,
        code=request.code,
        title=request.title,
        url=request.url,
        document_type=request.document_type,
        uploaded_at=queued_at,
        status="queued",
        chunk_count=0,
    )
    custom_documents.append(doc_record)

    # Queue for background ingestion
    background_tasks.add_task(ingest_link_document, doc_record)

    response_body = {
        "status": "queued",
        "document_id": new_id,
        "message": f"Link '{request.title}' queued for ingestion",
        "document": doc_record,
    }
    return response_body
async def ingest_link_document(doc_record: Dict[str, Any]):
    """
    Background task to ingest content from a URL.

    Fetches ``doc_record["url"]``, extracts text depending on the response's
    content type (PDF, HTML, otherwise the raw response text), chunks it,
    tags the chunks with the document's metadata and upserts them. Progress
    and failures are reported by mutating ``doc_record`` in place (``status``,
    ``chunk_count``, ``error``); nothing is raised to the caller.

    Args:
        doc_record: The record created by the add-link endpoint. Reads
            ``url``, ``code``, ``title`` and ``document_type``.
    """
    global custom_documents
    try:
        doc_record["status"] = "fetching"

        # NOTE(review): the URL is fetched exactly as stored on the record,
        # redirects included; confirm it is validated before it gets here.
        async with httpx.AsyncClient(timeout=60.0) as client:
            # Fetch the URL
            response = await client.get(doc_record["url"], follow_redirects=True)
            response.raise_for_status()

        content_type = response.headers.get("content-type", "")

        if "application/pdf" in content_type:
            # Save PDF and process
            import tempfile
            with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as f:
                f.write(response.content)
                pdf_path = f.name
            try:
                import fitz
                pdf_doc = fitz.open(pdf_path)
                try:
                    full_text = "".join(page.get_text() for page in pdf_doc)
                finally:
                    pdf_doc.close()
            finally:
                # delete=False above means the file must always be removed by
                # hand, including when extraction fails.
                os.unlink(pdf_path)
        elif "text/html" in content_type:
            # Extract text from HTML
            from bs4 import BeautifulSoup
            soup = BeautifulSoup(response.text, "html.parser")
            # Remove script and style elements
            for script in soup(["script", "style", "nav", "footer", "header"]):
                script.decompose()
            full_text = soup.get_text(separator="\n", strip=True)
        else:
            # Try to use as plain text
            full_text = response.text

        if not full_text.strip():
            doc_record["status"] = "error"
            doc_record["error"] = "No text could be extracted from URL"
            return

        doc_record["status"] = "processing"

        from legal_corpus_ingestion import LegalCorpusIngestion
        ingestion = LegalCorpusIngestion()

        # Chunk the text
        chunks = ingestion.chunk_text(full_text, doc_record["code"])

        # Add metadata
        for chunk in chunks:
            chunk["regulation_code"] = doc_record["code"]
            chunk["regulation_name"] = doc_record["title"]
            chunk["document_type"] = doc_record["document_type"]
            chunk["source_url"] = doc_record["url"]

        # Generate embeddings and upsert to Qdrant
        if chunks:
            await ingestion.embed_and_upsert(chunks)
            doc_record["chunk_count"] = len(chunks)
            doc_record["status"] = "indexed"
            logger.info(f"Ingested {len(chunks)} chunks from URL {doc_record['url']}")
        else:
            doc_record["status"] = "error"
            doc_record["error"] = "No chunks generated from content"
    except httpx.HTTPError as e:
        logger.error(f"Failed to fetch URL: {e}")
        doc_record["status"] = "error"
        doc_record["error"] = f"Failed to fetch URL: {str(e)}"
    except Exception as e:
        logger.error(f"Failed to ingest URL content: {e}")
        doc_record["status"] = "error"
        doc_record["error"] = str(e)
@router.delete("/custom-documents/{doc_id}")
async def delete_custom_document(doc_id: str):
    """
    Remove a custom document record from the in-memory list.

    Only the record is dropped; chunks already written to Qdrant stay there.

    Raises:
        HTTPException: 404 if no record has the given id.
    """
    global custom_documents

    known = any(entry["id"] == doc_id for entry in custom_documents)
    if not known:
        raise HTTPException(status_code=404, detail="Document not found")

    remaining = []
    for entry in custom_documents:
        if entry["id"] != doc_id:
            remaining.append(entry)
    custom_documents = remaining

    # TODO: Also remove chunks from Qdrant by filtering on code
    return {"status": "deleted", "document_id": doc_id}
# ========== Pipeline Checkpoints ==========
# Separate router for the pipeline endpoints (checkpoints, history, start,
# status), mounted under its own /api/v1/admin/pipeline prefix.
pipeline_router = APIRouter(prefix="/api/v1/admin/pipeline", tags=["pipeline"])
@pipeline_router.get("/checkpoints")
async def get_pipeline_checkpoints():
    """
    Return the stored pipeline checkpoint state plus a validation tally.

    The state loaded through ``CheckpointManager.load_state()`` is returned
    with an added ``validation_summary`` key that counts validations overall
    (``total``) and per status (``passed``, ``warning``, ``failed``). When no
    state exists, a ``no_data`` placeholder is returned instead.
    """
    from pipeline_checkpoints import CheckpointManager

    state = CheckpointManager.load_state()
    if state is None:
        return {
            "status": "no_data",
            "message": "No pipeline run data available yet.",
            "pipeline_id": None,
            "checkpoints": [],
            "summary": {}
        }

    # Enrich with validation summary
    tally = {"passed": 0, "warning": 0, "failed": 0, "total": 0}
    outcomes = [
        validation.get("status", "not_run")
        for checkpoint in state.get("checkpoints", [])
        for validation in checkpoint.get("validations", [])
    ]
    for outcome in outcomes:
        tally["total"] += 1
        # Statuses without a bucket (e.g. the "not_run" default) only count
        # towards the total.
        if outcome in tally:
            tally[outcome] += 1

    state["validation_summary"] = tally
    return state
@pipeline_router.get("/checkpoints/history")
async def get_pipeline_history():
    """
    List known pipeline runs.

    Only the current checkpoint state is available, so the list holds at
    most one entry (none when no state exists).
    """
    from pipeline_checkpoints import CheckpointManager

    state = CheckpointManager.load_state()
    if state is None:
        return {"runs": []}

    summary_fields = ("pipeline_id", "status", "started_at", "completed_at")
    current_run = {field: state.get(field) for field in summary_fields}
    return {"runs": [current_run]}
# Pipeline state for start/stop
# In-memory bookkeeping for the pipeline subprocess started through this API.
pipeline_process_state = {
# True from the moment start_pipeline accepts a request until the background task ends.
"running": False,
# PID of the spawned pipeline process; None while no process is tracked.
"pid": None,
# ISO timestamp written by start_pipeline; the background task does not clear it.
"started_at": None,
}
@pipeline_router.post("/start")
async def start_pipeline(request: StartPipelineRequest, background_tasks: BackgroundTasks):
    """
    Queue a run of the compliance pipeline as a background task.

    The run is refused when either the stored checkpoint state reports
    ``running`` or this process has already accepted a start request.

    Args:
        request: ``force_reindex`` and ``skip_ingestion`` are handed to
            ``run_pipeline_background``.
        background_tasks: FastAPI task queue used to schedule the run.

    Returns:
        A dict with ``status``, ``message`` and the ``started_at`` timestamp.

    Raises:
        HTTPException: 409 in both refusal cases described above.
    """
    from pipeline_checkpoints import CheckpointManager

    # Check if already running
    checkpoint_state = CheckpointManager.load_state()
    if checkpoint_state and checkpoint_state.get("status") == "running":
        raise HTTPException(
            status_code=409,
            detail="Pipeline is already running"
        )
    if pipeline_process_state["running"]:
        raise HTTPException(
            status_code=409,
            detail="Pipeline start already in progress"
        )

    started_at = datetime.now().isoformat()
    pipeline_process_state["running"] = True
    pipeline_process_state["started_at"] = started_at

    # Start pipeline in background
    background_tasks.add_task(
        run_pipeline_background,
        request.force_reindex,
        request.skip_ingestion
    )

    return {
        "status": "starting",
        "message": "Compliance pipeline is starting in background",
        "started_at": pipeline_process_state["started_at"],
    }
async def run_pipeline_background(force_reindex: bool, skip_ingestion: bool):
    """
    Background task to run the compliance pipeline.

    Runs ``full_compliance_pipeline.py`` with the current interpreter from
    this module's directory, records the child's PID in
    ``pipeline_process_state`` and logs the outcome. All failures are logged;
    nothing is raised to the caller.

    Args:
        force_reindex: Adds ``--force-reindex`` to the command line when true.
        skip_ingestion: Adds ``--skip-ingestion`` to the command line when true.
    """
    global pipeline_process_state
    try:
        import asyncio
        import sys

        # Build command
        cmd = [sys.executable, "full_compliance_pipeline.py"]
        if force_reindex:
            cmd.append("--force-reindex")
        if skip_ingestion:
            cmd.append("--skip-ingestion")

        # Run as subprocess
        logger.info(f"Starting pipeline: {' '.join(cmd)}")
        process = await asyncio.create_subprocess_exec(
            *cmd,
            cwd=os.path.dirname(os.path.abspath(__file__)),
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.STDOUT,
        )
        pipeline_process_state["pid"] = process.pid

        # communicate() drains the pipe while waiting. Polling a child whose
        # stdout is a pipe nobody reads blocks it as soon as the pipe buffer
        # is full, so the pipeline would never finish.
        raw_output, _ = await process.communicate()

        return_code = process.returncode
        if return_code != 0:
            output = raw_output.decode(errors="replace") if raw_output else ""
            logger.error(f"Pipeline failed with code {return_code}: {output}")
        else:
            logger.info("Pipeline completed successfully")
    except Exception as e:
        logger.error(f"Failed to run pipeline: {e}")
    finally:
        pipeline_process_state["running"] = False
        pipeline_process_state["pid"] = None
@pipeline_router.get("/status")
async def get_pipeline_status():
    """
    Report the tracked process state together with the checkpoint status.

    ``checkpoint_status`` is ``"no_data"`` and ``current_phase`` is ``None``
    when no checkpoint state is available.
    """
    from pipeline_checkpoints import CheckpointManager

    state = CheckpointManager.load_state()
    if state:
        checkpoint_status = state.get("status")
        current_phase = state.get("current_phase")
    else:
        checkpoint_status = "no_data"
        current_phase = None

    return {
        "process_running": pipeline_process_state["running"],
        "process_pid": pipeline_process_state["pid"],
        "process_started_at": pipeline_process_state["started_at"],
        "checkpoint_status": checkpoint_status,
        "current_phase": current_phase,
    }
# ========== Traceability / Quality Endpoints ==========
@router.get("/traceability")
async def get_traceability(
    chunk_id: str = Query(..., description="Chunk ID or identifier"),
    regulation: str = Query(..., description="Regulation code"),
):
    """
    Get traceability information for a specific chunk.

    Placeholder: the request parameters are echoed back with empty
    ``requirements`` and ``controls`` lists and an explanatory ``message``.
    No backend is contacted yet.

    Args:
        chunk_id: Chunk ID or identifier.
        regulation: Regulation code.

    Returns:
        A dict with ``chunk_id``, ``regulation``, ``requirements``,
        ``controls`` and ``message``.
    """
    # The real implementation is expected to query:
    #   1. the chunk from Qdrant,
    #   2. requirements from a requirements collection/table,
    #   3. controls from a controls collection/table.
    # Until then no HTTP client is opened and nothing here can fail, so the
    # placeholder is returned directly.
    return {
        "chunk_id": chunk_id,
        "regulation": regulation,
        "requirements": [],
        "controls": [],
        "message": "Traceability-Daten werden verfuegbar sein, sobald die Requirements-Extraktion und Control-Ableitung implementiert sind."
    }
# Corpus routes and state
from legal_corpus_routes import ( # noqa: F401
router,
REGULATIONS,
COLLECTION_NAME,
QDRANT_URL,
EMBEDDING_SERVICE_URL,
ingestion_state,
custom_documents,
SearchRequest,
IngestRequest,
AddLinkRequest,
)
# Pipeline routes and state
from legal_corpus_pipeline import ( # noqa: F401
pipeline_router,
pipeline_process_state,
StartPipelineRequest,
)
@@ -0,0 +1,166 @@
"""
Legal Corpus API - Background Ingestion Tasks
Background tasks for ingesting uploaded documents and URL links
into the legal corpus vector database.
Extracted from legal_corpus_routes.py to keep files under 500 LOC.
"""
import os
import logging
from typing import Dict, Any, Optional, List
import httpx
logger = logging.getLogger(__name__)
async def ingest_uploaded_document(doc_record: Dict[str, Any]):
    """
    Background task to ingest an uploaded document.

    Extracts the text of the PDF at ``doc_record["file_path"]``, chunks it,
    tags every chunk with the document's metadata and upserts the chunks.
    Progress and failures are reported by mutating ``doc_record`` in place
    (``status``, ``chunk_count``, ``error``); nothing is raised to the caller.

    Args:
        doc_record: Document record to process. Reads ``file_path``,
            ``code``, ``title``, ``document_type`` and ``filename``.
    """
    try:
        doc_record["status"] = "processing"

        from legal_corpus_ingestion import LegalCorpusIngestion
        ingestion = LegalCorpusIngestion()

        import fitz  # PyMuPDF
        doc = fitz.open(doc_record["file_path"])
        try:
            # join() instead of repeated += keeps extraction linear in the
            # number of pages.
            full_text = "".join(page.get_text() for page in doc)
        finally:
            # Release the file handle even when a page fails to extract.
            doc.close()

        if not full_text.strip():
            doc_record["status"] = "error"
            doc_record["error"] = "No text could be extracted from PDF"
            return

        chunks = ingestion.chunk_text(full_text, doc_record["code"])

        # Tag every chunk so it can be attributed to this upload.
        for chunk in chunks:
            chunk["regulation_code"] = doc_record["code"]
            chunk["regulation_name"] = doc_record["title"]
            chunk["document_type"] = doc_record["document_type"]
            chunk["source_url"] = f"upload://{doc_record['filename']}"

        if chunks:
            await ingestion.embed_and_upsert(chunks)
            doc_record["chunk_count"] = len(chunks)
            doc_record["status"] = "indexed"
            logger.info(f"Ingested {len(chunks)} chunks from uploaded document {doc_record['code']}")
        else:
            doc_record["status"] = "error"
            doc_record["error"] = "No chunks generated from document"
    except Exception as e:
        logger.error(f"Failed to ingest uploaded document: {e}")
        doc_record["status"] = "error"
        doc_record["error"] = str(e)
async def ingest_link_document(doc_record: Dict[str, Any]):
    """
    Background task to ingest content from a URL.

    Fetches ``doc_record["url"]``, extracts text depending on the response's
    content type (PDF, HTML, otherwise the raw response text), chunks it,
    tags the chunks with the document's metadata and upserts them. Progress
    and failures are reported by mutating ``doc_record`` in place (``status``,
    ``chunk_count``, ``error``); nothing is raised to the caller.

    Args:
        doc_record: Document record to process. Reads ``url``, ``code``,
            ``title`` and ``document_type``.
    """
    try:
        doc_record["status"] = "fetching"

        # NOTE(review): the URL is fetched exactly as stored on the record,
        # redirects included; confirm callers validate it first.
        async with httpx.AsyncClient(timeout=60.0) as client:
            response = await client.get(doc_record["url"], follow_redirects=True)
            response.raise_for_status()

        content_type = response.headers.get("content-type", "")

        if "application/pdf" in content_type:
            import tempfile
            with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as f:
                f.write(response.content)
                pdf_path = f.name
            try:
                import fitz
                pdf_doc = fitz.open(pdf_path)
                try:
                    full_text = "".join(page.get_text() for page in pdf_doc)
                finally:
                    pdf_doc.close()
            finally:
                # delete=False above means the file must always be removed by
                # hand, including when extraction fails.
                os.unlink(pdf_path)
        elif "text/html" in content_type:
            from bs4 import BeautifulSoup
            soup = BeautifulSoup(response.text, "html.parser")
            # Drop non-content elements before extracting the visible text.
            for script in soup(["script", "style", "nav", "footer", "header"]):
                script.decompose()
            full_text = soup.get_text(separator="\n", strip=True)
        else:
            # Any other content type is treated as plain text.
            full_text = response.text

        if not full_text.strip():
            doc_record["status"] = "error"
            doc_record["error"] = "No text could be extracted from URL"
            return

        doc_record["status"] = "processing"

        from legal_corpus_ingestion import LegalCorpusIngestion
        ingestion = LegalCorpusIngestion()

        chunks = ingestion.chunk_text(full_text, doc_record["code"])

        for chunk in chunks:
            chunk["regulation_code"] = doc_record["code"]
            chunk["regulation_name"] = doc_record["title"]
            chunk["document_type"] = doc_record["document_type"]
            chunk["source_url"] = doc_record["url"]

        if chunks:
            await ingestion.embed_and_upsert(chunks)
            doc_record["chunk_count"] = len(chunks)
            doc_record["status"] = "indexed"
            logger.info(f"Ingested {len(chunks)} chunks from URL {doc_record['url']}")
        else:
            doc_record["status"] = "error"
            doc_record["error"] = "No chunks generated from content"
    except httpx.HTTPError as e:
        logger.error(f"Failed to fetch URL: {e}")
        doc_record["status"] = "error"
        doc_record["error"] = f"Failed to fetch URL: {str(e)}"
    except Exception as e:
        logger.error(f"Failed to ingest URL content: {e}")
        doc_record["status"] = "error"
        doc_record["error"] = str(e)
async def run_ingestion(
    force: bool,
    regulations: Optional[List[str]],
    ingestion_state: Dict[str, Any],
    all_regulations: List[Dict[str, str]],
):
    """
    Background task: ingest regulations one by one and track progress.

    Progress is written into the ``ingestion_state`` dict supplied by the
    caller. A failure of a single regulation is logged and the loop
    continues; any other failure is stored under ``ingestion_state["error"]``.

    Args:
        force: Passed to ``LegalCorpusIngestion.ingest_single`` as ``force``.
        regulations: Codes to ingest; a falsy value means every code in
            ``all_regulations``.
        ingestion_state: Mutable progress dict, updated in place.
        all_regulations: Catalogue entries, each providing a ``"code"`` key.
    """
    try:
        # Imported lazily, inside the guarded region, so an import failure is
        # reported through ingestion_state["error"] as well.
        from legal_corpus_ingestion import LegalCorpusIngestion

        worker = LegalCorpusIngestion()
        pending = regulations if regulations else [entry["code"] for entry in all_regulations]

        position = 0
        for reg_code in pending:
            ingestion_state["current_regulation"] = reg_code
            ingestion_state["processed"] = position
            position += 1
            try:
                await worker.ingest_single(reg_code, force=force)
            except Exception as e:
                # Best effort: one broken regulation must not stop the run.
                logger.error(f"Failed to ingest {reg_code}: {e}")

        ingestion_state["completed"] = True
        ingestion_state["processed"] = len(pending)
    except Exception as e:
        logger.error(f"Ingestion failed: {e}")
        ingestion_state["error"] = str(e)
    finally:
        ingestion_state["running"] = False
@@ -0,0 +1,206 @@
"""
Legal Corpus API - Pipeline Routes
Pipeline checkpoints, history, start/stop, and status endpoints.
Extracted from legal_corpus_api.py to keep files under 500 LOC.
"""
import os
import asyncio
from datetime import datetime
from fastapi import APIRouter, HTTPException, BackgroundTasks
from pydantic import BaseModel
import logging
logger = logging.getLogger(__name__)
class StartPipelineRequest(BaseModel):
    """Request body of the pipeline start endpoint."""

    # When true, the pipeline script is started with --force-reindex.
    force_reindex: bool = False
    # When true, the pipeline script is started with --skip-ingestion.
    skip_ingestion: bool = False
# Separate router for the pipeline endpoints (checkpoints, history, start,
# status), mounted under its own /api/v1/admin/pipeline prefix.
pipeline_router = APIRouter(prefix="/api/v1/admin/pipeline", tags=["pipeline"])
@pipeline_router.get("/checkpoints")
async def get_pipeline_checkpoints():
    """
    Return the stored pipeline checkpoint state plus a validation tally.

    The state loaded through ``CheckpointManager.load_state()`` is returned
    with an added ``validation_summary`` key that counts validations overall
    (``total``) and per status (``passed``, ``warning``, ``failed``). When no
    state exists, a ``no_data`` placeholder is returned instead.
    """
    from pipeline_checkpoints import CheckpointManager

    state = CheckpointManager.load_state()
    if state is None:
        return {
            "status": "no_data",
            "message": "No pipeline run data available yet.",
            "pipeline_id": None,
            "checkpoints": [],
            "summary": {}
        }

    # Enrich with validation summary
    tally = {"passed": 0, "warning": 0, "failed": 0, "total": 0}
    outcomes = [
        validation.get("status", "not_run")
        for checkpoint in state.get("checkpoints", [])
        for validation in checkpoint.get("validations", [])
    ]
    for outcome in outcomes:
        tally["total"] += 1
        # Statuses without a bucket (e.g. the "not_run" default) only count
        # towards the total.
        if outcome in tally:
            tally[outcome] += 1

    state["validation_summary"] = tally
    return state
@pipeline_router.get("/checkpoints/history")
async def get_pipeline_history():
    """
    List known pipeline runs.

    Only the current checkpoint state is available, so the list holds at
    most one entry (none when no state exists).
    """
    from pipeline_checkpoints import CheckpointManager

    state = CheckpointManager.load_state()
    if state is None:
        return {"runs": []}

    summary_fields = ("pipeline_id", "status", "started_at", "completed_at")
    current_run = {field: state.get(field) for field in summary_fields}
    return {"runs": [current_run]}
# Pipeline state for start/stop
# In-memory bookkeeping for the pipeline subprocess started through this API.
pipeline_process_state = {
# True from the moment start_pipeline accepts a request until the background task ends.
"running": False,
# PID of the spawned pipeline process; None while no process is tracked.
"pid": None,
# ISO timestamp written by start_pipeline; the background task does not clear it.
"started_at": None,
}
@pipeline_router.post("/start")
async def start_pipeline(request: StartPipelineRequest, background_tasks: BackgroundTasks):
    """
    Queue a run of the compliance pipeline as a background task.

    The run is refused when either the stored checkpoint state reports
    ``running`` or this process has already accepted a start request.

    Args:
        request: ``force_reindex`` and ``skip_ingestion`` are handed to
            ``run_pipeline_background``.
        background_tasks: FastAPI task queue used to schedule the run.

    Returns:
        A dict with ``status``, ``message`` and the ``started_at`` timestamp.

    Raises:
        HTTPException: 409 in both refusal cases described above.
    """
    from pipeline_checkpoints import CheckpointManager

    checkpoint_state = CheckpointManager.load_state()
    if checkpoint_state and checkpoint_state.get("status") == "running":
        raise HTTPException(
            status_code=409,
            detail="Pipeline is already running"
        )
    if pipeline_process_state["running"]:
        raise HTTPException(
            status_code=409,
            detail="Pipeline start already in progress"
        )

    started_at = datetime.now().isoformat()
    pipeline_process_state["running"] = True
    pipeline_process_state["started_at"] = started_at

    background_tasks.add_task(
        run_pipeline_background,
        request.force_reindex,
        request.skip_ingestion
    )

    return {
        "status": "starting",
        "message": "Compliance pipeline is starting in background",
        "started_at": pipeline_process_state["started_at"],
    }
async def run_pipeline_background(force_reindex: bool, skip_ingestion: bool):
    """
    Background task to run the compliance pipeline.

    Runs ``full_compliance_pipeline.py`` with the current interpreter from
    this module's directory, records the child's PID in
    ``pipeline_process_state`` and logs the outcome. All failures are logged;
    nothing is raised to the caller.

    Args:
        force_reindex: Adds ``--force-reindex`` to the command line when true.
        skip_ingestion: Adds ``--skip-ingestion`` to the command line when true.
    """
    global pipeline_process_state
    try:
        import sys

        cmd = [sys.executable, "full_compliance_pipeline.py"]
        if force_reindex:
            cmd.append("--force-reindex")
        if skip_ingestion:
            cmd.append("--skip-ingestion")

        logger.info(f"Starting pipeline: {' '.join(cmd)}")
        process = await asyncio.create_subprocess_exec(
            *cmd,
            cwd=os.path.dirname(os.path.abspath(__file__)),
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.STDOUT,
        )
        pipeline_process_state["pid"] = process.pid

        # communicate() drains the pipe while waiting. Polling a child whose
        # stdout is a pipe nobody reads blocks it as soon as the pipe buffer
        # is full, so the pipeline would never finish.
        raw_output, _ = await process.communicate()

        return_code = process.returncode
        if return_code != 0:
            output = raw_output.decode(errors="replace") if raw_output else ""
            logger.error(f"Pipeline failed with code {return_code}: {output}")
        else:
            logger.info("Pipeline completed successfully")
    except Exception as e:
        logger.error(f"Failed to run pipeline: {e}")
    finally:
        pipeline_process_state["running"] = False
        pipeline_process_state["pid"] = None
@pipeline_router.get("/status")
async def get_pipeline_status():
    """
    Report the tracked process state together with the checkpoint status.

    ``checkpoint_status`` is ``"no_data"`` and ``current_phase`` is ``None``
    when no checkpoint state is available.
    """
    from pipeline_checkpoints import CheckpointManager

    state = CheckpointManager.load_state()
    if state:
        checkpoint_status = state.get("status")
        current_phase = state.get("current_phase")
    else:
        checkpoint_status = "no_data"
        current_phase = None

    return {
        "process_running": pipeline_process_state["running"],
        "process_pid": pipeline_process_state["pid"],
        "process_started_at": pipeline_process_state["started_at"],
        "checkpoint_status": checkpoint_status,
        "current_phase": current_phase,
    }
@@ -0,0 +1,368 @@
"""
Legal Corpus API - Corpus Routes
Endpoints for the RAG page in admin-v2:
- GET /status - Collection status with chunk counts
- GET /search - Semantic search
- POST /ingest - Trigger ingestion
- GET /ingestion-status - Ingestion status
- GET /regulations - List regulations
- GET /custom-documents - List custom docs
- POST /upload - Upload document
- POST /add-link - Add link for ingestion
- DELETE /custom-documents/{id} - Delete custom doc
- GET /traceability - Traceability info
Extracted from legal_corpus_api.py to keep files under 500 LOC.
"""
import os
import httpx
import uuid
import shutil
from datetime import datetime
from typing import Optional, List, Dict, Any
from fastapi import APIRouter, HTTPException, Query, BackgroundTasks, UploadFile, File, Form
from pydantic import BaseModel
import logging
from legal_corpus_ingest_tasks import (
ingest_uploaded_document,
ingest_link_document,
run_ingestion,
)
logger = logging.getLogger(__name__)
# Configuration
# Base URL of the Qdrant REST API; overridable through the QDRANT_URL environment variable.
QDRANT_URL = os.getenv("QDRANT_URL", "http://localhost:6333")
# Base URL of the embedding service whose /embed endpoint the search route calls;
# overridable through the EMBEDDING_SERVICE_URL environment variable.
EMBEDDING_SERVICE_URL = os.getenv("EMBEDDING_SERVICE_URL", "http://embedding-service:8087")
# Qdrant collection queried by the status and search routes.
COLLECTION_NAME = "bp_legal_corpus"
# Catalogue of supported regulations. Used by the status route (one chunk
# count per "code"), by the ingest route (default set and total) and returned
# as-is by the regulations route. Each entry has "code", "name", "fullName"
# and "type".
REGULATIONS = [
{"code": "GDPR", "name": "DSGVO", "fullName": "Datenschutz-Grundverordnung", "type": "eu_regulation"},
{"code": "EPRIVACY", "name": "ePrivacy-Richtlinie", "fullName": "Richtlinie 2002/58/EG", "type": "eu_directive"},
{"code": "TDDDG", "name": "TDDDG", "fullName": "Telekommunikation-Digitale-Dienste-Datenschutz-Gesetz", "type": "de_law"},
{"code": "SCC", "name": "Standardvertragsklauseln", "fullName": "2021/914/EU", "type": "eu_regulation"},
{"code": "DPF", "name": "EU-US Data Privacy Framework", "fullName": "Angemessenheitsbeschluss", "type": "eu_regulation"},
{"code": "AIACT", "name": "EU AI Act", "fullName": "Verordnung (EU) 2024/1689", "type": "eu_regulation"},
{"code": "CRA", "name": "Cyber Resilience Act", "fullName": "Verordnung (EU) 2024/2847", "type": "eu_regulation"},
{"code": "NIS2", "name": "NIS2-Richtlinie", "fullName": "Richtlinie (EU) 2022/2555", "type": "eu_directive"},
{"code": "EUCSA", "name": "EU Cybersecurity Act", "fullName": "Verordnung (EU) 2019/881", "type": "eu_regulation"},
{"code": "DATAACT", "name": "Data Act", "fullName": "Verordnung (EU) 2023/2854", "type": "eu_regulation"},
{"code": "DGA", "name": "Data Governance Act", "fullName": "Verordnung (EU) 2022/868", "type": "eu_regulation"},
{"code": "DSA", "name": "Digital Services Act", "fullName": "Verordnung (EU) 2022/2065", "type": "eu_regulation"},
{"code": "EAA", "name": "European Accessibility Act", "fullName": "Richtlinie (EU) 2019/882", "type": "eu_directive"},
{"code": "DSM", "name": "DSM-Urheberrechtsrichtlinie", "fullName": "Richtlinie (EU) 2019/790", "type": "eu_directive"},
{"code": "PLD", "name": "Produkthaftungsrichtlinie", "fullName": "Richtlinie 85/374/EWG", "type": "eu_directive"},
{"code": "GPSR", "name": "General Product Safety", "fullName": "Verordnung (EU) 2023/988", "type": "eu_regulation"},
{"code": "BSI-TR-03161-1", "name": "BSI-TR Teil 1", "fullName": "BSI TR-03161 Teil 1 - Mobile Anwendungen", "type": "bsi_standard"},
{"code": "BSI-TR-03161-2", "name": "BSI-TR Teil 2", "fullName": "BSI TR-03161 Teil 2 - Web-Anwendungen", "type": "bsi_standard"},
{"code": "BSI-TR-03161-3", "name": "BSI-TR Teil 3", "fullName": "BSI TR-03161 Teil 3 - Hintergrundsysteme", "type": "bsi_standard"},
]
# Ingestion state (in-memory for now)
# Returned verbatim by the ingestion-status route and updated by the
# background ingestion task.
ingestion_state = {
# True while an ingestion run is in progress; guards against parallel runs.
"running": False,
# Set to True once the ingestion loop has gone through all requested codes.
"completed": False,
# Code of the regulation currently being ingested.
"current_regulation": None,
# Number of regulations already handled in the current run.
"processed": 0,
"total": len(REGULATIONS),
# Message of the failure that aborted the run, if any.
"error": None,
}
class SearchRequest(BaseModel):
    """Search parameters as a request model.

    NOTE(review): no route in this module takes this model; the search route
    reads query parameters instead. Confirm whether it is still needed beyond
    being re-exported.
    """

    query: str
    # Optional list of regulation codes.
    regulations: Optional[List[str]] = None
    top_k: int = 5
class IngestRequest(BaseModel):
    """Request body of the ingest route."""

    # Passed through to the background ingestion task.
    force: bool = False
    # Regulation codes to ingest; None (or empty) means the whole catalogue.
    regulations: Optional[List[str]] = None
class AddLinkRequest(BaseModel):
    """Request body of the add-link route; its fields are copied onto the document record."""

    # Address the background task fetches.
    url: str
    title: str
    code: str
    document_type: str = "custom"
# Store for custom documents (in-memory for now)
# One record per upload / added link; the background ingestion tasks mutate
# these records in place to report status.
custom_documents: List[Dict[str, Any]] = []
# Router carrying every legal-corpus endpoint defined in this module.
router = APIRouter(prefix="/api/v1/admin/legal-corpus", tags=["legal-corpus"])
@router.get("/status")
async def get_legal_corpus_status():
    """
    Report the state of the legal corpus collection in Qdrant.

    Returns:
        A dict with the collection name, total point count, vector size,
        collection status and one chunk count per regulation code. When the
        collection lookup does not return HTTP 200, a ``not_found``
        placeholder with zero points is returned instead.

    Raises:
        HTTPException: 503 if Qdrant cannot be reached.
    """
    collection_url = f"{QDRANT_URL}/collections/{COLLECTION_NAME}"

    async with httpx.AsyncClient(timeout=30.0) as client:
        try:
            info_response = await client.get(collection_url)
            if info_response.status_code != 200:
                return {
                    "collection": COLLECTION_NAME,
                    "totalPoints": 0,
                    "vectorSize": 1024,
                    "status": "not_found",
                    "regulations": {},
                }
            info = info_response.json().get("result", {})

            # One count request per catalogue entry, filtered on its code.
            per_regulation = {}
            for entry in REGULATIONS:
                code = entry["code"]
                count_query = {
                    "filter": {
                        "must": [{"key": "regulation_code", "match": {"value": code}}]
                    }
                }
                counted = await client.post(f"{collection_url}/points/count", json=count_query)
                if counted.status_code != 200:
                    # A failed count is reported as zero chunks.
                    per_regulation[code] = 0
                    continue
                per_regulation[code] = counted.json().get("result", {}).get("count", 0)

            vector_config = info.get("config", {}).get("params", {}).get("vectors", {})
            return {
                "collection": COLLECTION_NAME,
                "totalPoints": info.get("points_count", 0),
                "vectorSize": vector_config.get("size", 1024),
                "status": info.get("status", "unknown"),
                "regulations": per_regulation,
            }
        except httpx.RequestError as e:
            logger.error(f"Failed to get Qdrant status: {e}")
            raise HTTPException(status_code=503, detail=f"Qdrant not available: {str(e)}")
@router.get("/search")
async def search_legal_corpus(
    query: str = Query(..., description="Search query"),
    top_k: int = Query(5, ge=1, le=20, description="Number of results"),
    regulations: Optional[str] = Query(None, description="Comma-separated regulation codes to filter"),
):
    """
    Semantic search over the legal corpus.

    The query is embedded through the embedding service and the resulting
    vector is searched in the Qdrant collection, optionally restricted to
    the given regulation codes.

    Returns:
        A dict with ``results`` (one entry per hit), the original ``query``
        and ``count``.

    Raises:
        HTTPException: 500 if the embedding service or Qdrant answers with a
            non-200 status, 503 if either service cannot be reached.
    """

    def as_result(point):
        # Flatten a Qdrant hit into the response shape.
        payload = point.get("payload", {})
        return {
            "text": payload.get("text", ""),
            "regulation_code": payload.get("regulation_code", ""),
            "regulation_name": payload.get("regulation_name", ""),
            "article": payload.get("article"),
            "paragraph": payload.get("paragraph"),
            "source_url": payload.get("source_url", ""),
            "score": point.get("score", 0),
        }

    async with httpx.AsyncClient(timeout=60.0) as client:
        try:
            embed_res = await client.post(
                f"{EMBEDDING_SERVICE_URL}/embed",
                json={"texts": [query]},
            )
            if embed_res.status_code != 200:
                raise HTTPException(status_code=500, detail="Embedding service error")

            qdrant_query = {
                "vector": embed_res.json()["embeddings"][0],
                "limit": top_k,
                "with_payload": True,
            }
            if regulations:
                # "should" makes a hit match when it belongs to any listed code.
                wanted_codes = [part.strip() for part in regulations.split(",")]
                conditions = []
                for code in wanted_codes:
                    conditions.append({"key": "regulation_code", "match": {"value": code}})
                qdrant_query["filter"] = {"should": conditions}

            search_res = await client.post(
                f"{QDRANT_URL}/collections/{COLLECTION_NAME}/points/search",
                json=qdrant_query,
            )
            if search_res.status_code != 200:
                raise HTTPException(status_code=500, detail="Search failed")

            results = [as_result(point) for point in search_res.json().get("result", [])]
            return {"results": results, "query": query, "count": len(results)}
        except httpx.RequestError as e:
            logger.error(f"Search failed: {e}")
            raise HTTPException(status_code=503, detail=f"Service not available: {str(e)}")
@router.post("/ingest")
async def trigger_ingestion(request: IngestRequest, background_tasks: BackgroundTasks):
    """
    Trigger legal corpus ingestion in the background.

    Args:
        request: Ingestion options. ``regulations`` optionally restricts the
            run to a subset of regulation codes; ``force`` is passed through
            to the ingestion task unchanged.
        background_tasks: FastAPI task queue used to schedule ``run_ingestion``.

    Returns:
        A dict with ``status``, ``job_id`` and a human-readable ``message``.

    Raises:
        HTTPException: 409 if an ingestion run is already in progress.
    """
    if ingestion_state["running"]:
        raise HTTPException(status_code=409, detail="Ingestion already running")

    # run_ingestion works on `request.regulations` when given and on the
    # whole catalogue otherwise. "total" must follow the same rule, otherwise
    # "processed" can never line up with "total" for a subset run.
    total = len(request.regulations or REGULATIONS)

    # Reset in place instead of rebinding the module global: other modules
    # import `ingestion_state` by name and would otherwise keep a stale dict.
    ingestion_state.clear()
    ingestion_state.update({
        "running": True,
        "completed": False,
        "current_regulation": None,
        "processed": 0,
        "total": total,
        "error": None,
    })

    background_tasks.add_task(run_ingestion, request.force, request.regulations, ingestion_state, REGULATIONS)
    return {
        "status": "started",
        "job_id": "manual-trigger",
        "message": f"Ingestion started for {total} regulations",
    }
@router.get("/ingestion-status")
async def get_ingestion_status():
    """Return the module-level ingestion progress dict as the response body."""
    snapshot = ingestion_state
    return snapshot
@router.get("/regulations")
async def get_regulations():
    """Return the static ``REGULATIONS`` catalogue under the ``regulations`` key."""
    payload = {"regulations": REGULATIONS}
    return payload
@router.get("/custom-documents")
async def get_custom_documents():
    """Return every user-added document record under the ``documents`` key."""
    payload = {"documents": custom_documents}
    return payload
@router.post("/upload")
async def upload_document(
    background_tasks: BackgroundTasks,
    file: UploadFile = File(...),
    title: str = Form(...),
    code: str = Form(...),
    document_type: str = Form("custom"),
):
    """
    Upload a document (PDF) for ingestion into the legal corpus.

    The file is stored under ``/tmp/legal_corpus_uploads``, a record is
    appended to ``custom_documents`` and ingestion is queued as a background
    task.

    Args:
        background_tasks: FastAPI task queue used to schedule ingestion.
        file: The uploaded PDF.
        title: Display title, copied onto the document record.
        code: Document code, copied onto the document record.
        document_type: Free-form type label, defaults to ``"custom"``.

    Returns:
        A dict with ``status``, ``document_id``, ``message`` and the full
        ``document`` record.

    Raises:
        HTTPException: 400 if the upload is not a PDF, 500 if the file cannot
            be written to disk.
    """
    global custom_documents

    # The filename is optional in multipart uploads, and the extension check
    # is case-insensitive so that e.g. "x.Pdf" is accepted.
    original_name = file.filename or ""
    if not original_name.lower().endswith(".pdf"):
        raise HTTPException(status_code=400, detail="Only PDF files are supported")

    upload_dir = "/tmp/legal_corpus_uploads"
    os.makedirs(upload_dir, exist_ok=True)

    # Only the last path component of the client-supplied name is used, so
    # separators in the name can neither escape nor nest below upload_dir.
    doc_id = str(uuid.uuid4())[:8]
    base_name = os.path.basename(original_name.replace("\\", "/"))
    safe_filename = f"{doc_id}_{base_name}"
    file_path = os.path.join(upload_dir, safe_filename)

    try:
        with open(file_path, "wb") as buffer:
            shutil.copyfileobj(file.file, buffer)
    except Exception as e:
        logger.error(f"Failed to save uploaded file: {e}")
        raise HTTPException(status_code=500, detail=f"Failed to save file: {str(e)}")

    doc_record = {
        "id": doc_id,
        "code": code,
        "title": title,
        "filename": file.filename,
        "file_path": file_path,
        "document_type": document_type,
        "uploaded_at": datetime.now().isoformat(),
        "status": "uploaded",
        "chunk_count": 0,
    }
    custom_documents.append(doc_record)

    background_tasks.add_task(ingest_uploaded_document, doc_record)
    return {
        "status": "uploaded",
        "document_id": doc_id,
        "message": f"Document '{title}' uploaded and queued for ingestion",
        "document": doc_record,
    }
@router.post("/add-link")
async def add_link(request: AddLinkRequest, background_tasks: BackgroundTasks):
    """
    Register a URL for ingestion into the legal corpus.

    Appends a new record to ``custom_documents``, schedules
    ``ingest_link_document`` as a background task and echoes the record back.
    """
    global custom_documents

    new_id = str(uuid.uuid4())[:8]
    record = dict(
        id=new_id,
        code=request.code,
        title=request.title,
        url=request.url,
        document_type=request.document_type,
        uploaded_at=datetime.now().isoformat(),
        status="queued",
        chunk_count=0,
    )

    custom_documents.append(record)
    background_tasks.add_task(ingest_link_document, record)

    response = {"status": "queued", "document_id": new_id}
    response["message"] = f"Link '{request.title}' queued for ingestion"
    response["document"] = record
    return response
@router.delete("/custom-documents/{doc_id}")
async def delete_custom_document(doc_id: str):
    """
    Remove a custom document from the in-memory list.

    Raises:
        HTTPException 404: no document with ``doc_id`` exists.
    """
    global custom_documents

    if not any(d["id"] == doc_id for d in custom_documents):
        raise HTTPException(status_code=404, detail="Document not found")

    # Mutate the list in place (like the append-based routes do) instead of
    # rebinding the name, so every other reference to the same list object
    # observes the deletion as well.
    custom_documents[:] = [d for d in custom_documents if d["id"] != doc_id]

    return {"status": "deleted", "document_id": doc_id}
@router.get("/traceability")
async def get_traceability(
    chunk_id: str = Query(..., description="Chunk ID or identifier"),
    regulation: str = Query(..., description="Regulation code"),
):
    """
    Get traceability information for a specific chunk.

    Currently a placeholder: it echoes the query parameters together with
    empty ``requirements`` / ``controls`` lists and an explanatory message.
    No outbound request is made, so no HTTP client is opened and there is
    nothing that can fail here.
    """
    return {
        "chunk_id": chunk_id,
        "regulation": regulation,
        "requirements": [],
        "controls": [],
        "message": "Traceability-Daten werden verfuegbar sein, sobald die Requirements-Extraktion und Control-Ableitung implementiert sind."
    }
+269
View File
@@ -0,0 +1,269 @@
"""
AI Email - Category Classification and Response Suggestions
Rule-based and LLM-based email category classification,
plus response suggestion generation.
Extracted from ai_service.py to keep files under 500 LOC.
"""
import os
import logging
from typing import Optional, List, Tuple
import httpx
from .models import (
EmailCategory,
SenderType,
ResponseSuggestion,
)
logger = logging.getLogger(__name__)
# LLM Gateway configuration
LLM_GATEWAY_URL = os.getenv("LLM_GATEWAY_URL", "http://localhost:8090")
async def classify_category(
    http_client: httpx.AsyncClient,
    subject: str,
    body_preview: str,
    sender_type: SenderType,
) -> Tuple[EmailCategory, float]:
    """
    Assign an email to a category.

    The keyword rules run first; their verdict is used only when its
    confidence exceeds 0.7. Otherwise the LLM classifier decides.
    """
    rule_result = _classify_category_rules(subject, body_preview, sender_type)
    rule_category, rule_confidence = rule_result
    if not rule_confidence > 0.7:
        return await _classify_category_llm(http_client, subject, body_preview)
    return rule_category, rule_confidence
def _classify_category_rules(
    subject: str,
    body_preview: str,
    sender_type: SenderType,
) -> Tuple[EmailCategory, float]:
    """
    Rule-based category classification.

    Counts keyword hits per category in the lower-cased subject + preview and
    picks the category with the most hits (first one wins on ties). The hit
    count is converted to a confidence of ``0.4 + 0.15 * hits``, capped at 0.9.

    Returns:
        Tuple of (category, confidence). ``EmailCategory.SONSTIGES`` with
        confidence 0.4 when nothing matched and the sender is no authority.
    """
    import re  # local import: this module has no top-level ``re`` import

    text = f"{subject} {body_preview}".lower()
    # The keyword lists below are ASCII-transliterated ("behoerde", "schueler").
    # Fold umlauts in the input the same way so both "Behörde" and "Behoerde"
    # are recognised.
    for umlaut, ascii_form in (
        ("\u00e4", "ae"),
        ("\u00f6", "oe"),
        ("\u00fc", "ue"),
        ("\u00df", "ss"),
    ):
        text = text.replace(umlaut, ascii_form)
    # Whole words of the text, used for very short keywords (see below).
    words = set(re.findall(r"[a-z0-9]+", text))

    category_keywords = {
        EmailCategory.DIENSTLICH: [
            "dienstlich", "dienstanweisung", "erlass", "verordnung",
            "bescheid", "verfuegung", "ministerium", "behoerde"
        ],
        EmailCategory.PERSONAL: [
            "personalrat", "stellenausschreibung", "versetzung",
            "beurteilung", "dienstzeugnis", "krankmeldung", "elternzeit"
        ],
        EmailCategory.FINANZEN: [
            "budget", "haushalt", "etat", "abrechnung", "rechnung",
            "erstattung", "zuschuss", "foerdermittel"
        ],
        EmailCategory.ELTERN: [
            "elternbrief", "elternabend", "schulkonferenz",
            "elternvertreter", "elternbeirat"
        ],
        EmailCategory.SCHUELER: [
            "schueler", "schuelerin", "zeugnis", "klasse", "unterricht",
            "pruefung", "klassenfahrt", "schulpflicht"
        ],
        EmailCategory.FORTBILDUNG: [
            "fortbildung", "seminar", "workshop", "schulung",
            "weiterbildung", "nlq", "didaktik"
        ],
        EmailCategory.VERANSTALTUNG: [
            "einladung", "veranstaltung", "termin", "konferenz",
            "sitzung", "tagung", "feier"
        ],
        EmailCategory.SICHERHEIT: [
            "sicherheit", "notfall", "brandschutz", "evakuierung",
            "hygiene", "corona", "infektionsschutz"
        ],
        EmailCategory.TECHNIK: [
            "it", "software", "computer", "netzwerk", "login",
            "passwort", "digitalisierung", "iserv"
        ],
        EmailCategory.NEWSLETTER: [
            "newsletter", "rundschreiben", "info-mail", "mitteilung"
        ],
        EmailCategory.WERBUNG: [
            "angebot", "rabatt", "aktion", "werbung", "abonnement"
        ],
    }

    def _hit(keyword: str) -> bool:
        # Abbreviations such as "it" or "nlq" must match as whole words;
        # as substrings they would fire on "mit", "bitte", "zeit", ...
        # Longer keywords stay substring matches so that German compounds
        # and inflections ("Elternabende") are still found.
        if len(keyword) <= 3:
            return keyword in words
        return keyword in text

    best_category = EmailCategory.SONSTIGES
    best_score = 0.0
    for category, keywords in category_keywords.items():
        score = sum(1 for kw in keywords if _hit(kw))
        if score > best_score:
            best_score = score
            best_category = category

    # Mail from school authorities without any keyword hit is treated as
    # official business with a moderate score.
    if sender_type in [SenderType.KULTUSMINISTERIUM, SenderType.LANDESSCHULBEHOERDE, SenderType.RLSB]:
        if best_category == EmailCategory.SONSTIGES:
            best_category = EmailCategory.DIENSTLICH
            best_score = 2

    # Convert the hit count to a confidence value.
    confidence = min(0.9, 0.4 + (best_score * 0.15))
    return best_category, confidence
async def _classify_category_llm(
    client: httpx.AsyncClient,
    subject: str,
    body_preview: str,
) -> Tuple[EmailCategory, float]:
    """
    Ask the LLM gateway for a category and confidence.

    Expects a reply of the form ``kategorie|konfidenz``. Any failure (request
    error, non-200 status, malformed reply, unknown category) results in
    ``(EmailCategory.SONSTIGES, 0.5)``.
    """
    fallback = (EmailCategory.SONSTIGES, 0.5)
    try:
        categories = ", ".join(member.value for member in EmailCategory)
        prompt = f"""Klassifiziere diese E-Mail in EINE Kategorie:
Betreff: {subject}
Inhalt: {body_preview[:500]}
Kategorien: {categories}
Antworte NUR mit dem Kategorienamen und einer Konfidenz (0.0-1.0):
Format: kategorie|konfidenz
"""
        payload = {
            "prompt": prompt,
            "playbook": "mail_analysis",
            "max_tokens": 50,
        }
        reply = await client.post(f"{LLM_GATEWAY_URL}/api/v1/inference", json=payload)
        if reply.status_code != 200:
            return fallback

        answer = reply.json().get("response", "sonstiges|0.5")
        fields = answer.strip().split("|")
        if len(fields) < 2:
            return fallback

        label = fields[0].strip().lower()
        # A non-numeric confidence raises here and is handled (and logged) below.
        score = float(fields[1].strip())
        try:
            matched = EmailCategory(label)
        except ValueError:
            # Unknown category name: silently use the fallback.
            return fallback
        return matched, min(max(score, 0.0), 1.0)
    except Exception as e:
        logger.warning(f"LLM category classification failed: {e}")
    return fallback
async def suggest_response(
    http_client: httpx.AsyncClient,
    subject: str,
    body_text: str,
    sender_type: SenderType,
    category: EmailCategory,
) -> List[ResponseSuggestion]:
    """
    Build reply suggestions for an email.

    Up to three suggestions are produced, in this order: an acknowledgment
    template for authority senders, a parent template for the ELTERN
    category, and an LLM-generated reply when the gateway delivers one.
    """
    reply_subject = f"Re: {subject}"
    authority_senders = (
        SenderType.KULTUSMINISTERIUM,
        SenderType.LANDESSCHULBEHOERDE,
        SenderType.RLSB,
    )
    result = []

    if sender_type in authority_senders:
        acknowledgment = ResponseSuggestion(
            template_type="acknowledgment",
            subject=reply_subject,
            body="""Sehr geehrte Damen und Herren,
vielen Dank fuer Ihre Nachricht.
Ich bestaetige den Eingang und werde die Angelegenheit fristgerecht bearbeiten.
Mit freundlichen Gruessen""",
            confidence=0.8,
        )
        result.append(acknowledgment)

    if category == EmailCategory.ELTERN:
        parent_reply = ResponseSuggestion(
            template_type="parent_response",
            subject=reply_subject,
            body="""Liebe Eltern,
vielen Dank fuer Ihre Nachricht.
[Ihre Antwort hier]
Mit freundlichen Gruessen""",
            confidence=0.7,
        )
        result.append(parent_reply)

    # The LLM suggestion is best-effort; templates are returned regardless.
    try:
        generated = await _generate_response_llm(
            http_client, subject, body_text[:500], sender_type
        )
        if generated:
            result.append(generated)
    except Exception as e:
        logger.warning(f"LLM response generation failed: {e}")

    return result
async def _generate_response_llm(
    client: httpx.AsyncClient,
    subject: str,
    body_preview: str,
    sender_type: SenderType,
) -> Optional[ResponseSuggestion]:
    """
    Request a reply draft from the LLM gateway.

    Returns a ``ResponseSuggestion`` with ``template_type="ai_generated"``
    when the gateway answers with status 200 and a non-empty text; otherwise
    ``None``. Errors are logged as warnings and also yield ``None``.
    """
    try:
        known_senders = {
            SenderType.KULTUSMINISTERIUM: "dem Kultusministerium",
            SenderType.LANDESSCHULBEHOERDE: "der Landesschulbehoerde",
            SenderType.RLSB: "dem RLSB",
            SenderType.ELTERNVERTRETER: "einem Elternvertreter",
        }
        sender_desc = known_senders.get(sender_type, "einem Absender")
        prompt = f"""Du bist eine Schulleiterin in Niedersachsen. Formuliere eine professionelle, kurze Antwort auf diese E-Mail von {sender_desc}:
Betreff: {subject}
Inhalt: {body_preview}
Die Antwort sollte:
- Hoeflich und formell sein
- Den Eingang bestaetigen
- Eine konkrete naechste Aktion nennen oder um Klaerung bitten
Antworte NUR mit dem Antworttext (ohne Betreffzeile, ohne "Betreff:").
"""
        request_body = {
            "prompt": prompt,
            "playbook": "mail_analysis",
            "max_tokens": 300,
        }
        reply = await client.post(
            f"{LLM_GATEWAY_URL}/api/v1/inference",
            json=request_body,
        )
        if reply.status_code != 200:
            return None

        draft = reply.json().get("response", "").strip()
        if not draft:
            return None
        return ResponseSuggestion(
            template_type="ai_generated",
            subject=f"Re: {subject}",
            body=draft,
            confidence=0.6,
        )
    except Exception as e:
        logger.warning(f"LLM response generation failed: {e}")
    return None
+184
View File
@@ -0,0 +1,184 @@
"""
AI Email - Deadline Extraction
Regex-based and LLM-based deadline extraction from email content.
Extracted from ai_service.py to keep files under 500 LOC.
"""
import os
import re
import logging
from typing import List
from datetime import datetime, timedelta
import httpx
from .models import DeadlineExtraction
logger = logging.getLogger(__name__)
# LLM Gateway configuration
LLM_GATEWAY_URL = os.getenv("LLM_GATEWAY_URL", "http://localhost:8090")
async def extract_deadlines(
    http_client: httpx.AsyncClient,
    subject: str,
    body_text: str,
) -> List[DeadlineExtraction]:
    """
    Collect deadlines mentioned in an email.

    Regex extraction runs over subject and body. The LLM is consulted
    (with the first 1000 characters of the body) only when the regex pass
    found nothing and a body is present.
    """
    if body_text:
        searchable = f"{subject}\n{body_text}"
    else:
        searchable = subject

    found = list(_extract_deadlines_regex(searchable))
    if found or not body_text:
        return found

    found.extend(await _extract_deadlines_llm(http_client, subject, body_text[:1000]))
    return found
def _extract_deadlines_regex(text: str) -> List[DeadlineExtraction]:
    """
    Extract deadlines using regex patterns.

    Recognises explicit German dates ("bis zum 15.01.2025", "spaetestens am ...",
    "Frist: ...") and relative periods ("innerhalb von 14 Tagen / 2 Wochen").
    Explicit dates before today are skipped; a deadline dated today is kept.

    Args:
        text: Subject and body of the email as one string.

    Returns:
        List of extracted deadlines (possibly empty). Explicit dates are
        reported as firm with confidence 0.85, relative periods as non-firm
        with confidence 0.7.
    """
    deadlines = []
    now = datetime.now()
    today = now.date()

    # German date patterns; the flag marks patterns that capture day/month/year.
    patterns = [
        # "bis zum 15.01.2025"
        (r"bis\s+(?:zum\s+)?(\d{1,2})\.(\d{1,2})\.(\d{2,4})", True),
        # "spaetestens am 15.01.2025" (with umlaut or transliterated)
        (r"sp(?:\u00e4|ae)testens\s+(?:am\s+)?(\d{1,2})\.(\d{1,2})\.(\d{2,4})", True),
        # "Abgabetermin: 15.01.2025"
        (r"(?:Abgabe|Termin|Frist)[:\s]+(\d{1,2})\.(\d{1,2})\.(\d{2,4})", True),
        # "innerhalb von 14 Tagen"
        (r"innerhalb\s+von\s+(\d+)\s+(?:Tagen|Wochen)", False),
        # "bis Ende Januar"
        # NOTE: matched but not converted into a deadline below (no day/year).
        (r"bis\s+(?:Ende\s+)?(Januar|Februar|M\u00e4rz|April|Mai|Juni|Juli|August|September|Oktober|November|Dezember)", False),
    ]

    for pattern, is_specific_date in patterns:
        for match in re.finditer(pattern, text, re.IGNORECASE):
            try:
                if is_specific_date:
                    day = int(match.group(1))
                    month = int(match.group(2))
                    year = int(match.group(3))
                    # Two-digit years are interpreted as 20xx.
                    if year < 100:
                        year += 2000
                    deadline_date = datetime(year, month, day)

                    # Compare calendar days: the parsed date is midnight, so a
                    # comparison against "now" would discard today's deadline.
                    if deadline_date.date() < today:
                        continue

                    # Keep up to 50 characters of context on either side.
                    start = max(0, match.start() - 50)
                    end = min(len(text), match.end() + 50)
                    context = text[start:end].strip()

                    deadlines.append(DeadlineExtraction(
                        deadline_date=deadline_date,
                        description=f"Frist: {match.group(0)}",
                        confidence=0.85,
                        source_text=context,
                        is_firm=True,
                    ))
                else:
                    # Only the "innerhalb von X Tagen/Wochen" pattern is handled.
                    if "Tagen" in pattern or "Wochen" in pattern:
                        days = int(match.group(1))
                        # The matched text is lower-cased, so the needle must
                        # be lower-case too; otherwise weeks count as days.
                        if "wochen" in match.group(0).lower():
                            days *= 7
                        deadline_date = now + timedelta(days=days)
                        deadlines.append(DeadlineExtraction(
                            deadline_date=deadline_date,
                            description=f"Relative Frist: {match.group(0)}",
                            confidence=0.7,
                            source_text=match.group(0),
                            is_firm=False,
                        ))
            except (ValueError, IndexError) as e:
                # e.g. "31.02.2025" - not a valid calendar date.
                logger.debug(f"Failed to parse date: {e}")
                continue

    return deadlines
async def _extract_deadlines_llm(
    client: httpx.AsyncClient,
    subject: str,
    body_preview: str,
) -> List[DeadlineExtraction]:
    """
    Ask the LLM gateway to list deadlines found in an email.

    The reply is expected as one ``DATUM|BESCHREIBUNG|VERBINDLICH`` entry per
    line, or the marker ``KEINE_FRISTEN``. Lines whose date is not ISO
    formatted are ignored. Any request failure yields an empty list.
    """
    try:
        prompt = f"""Analysiere diese E-Mail und extrahiere alle genannten Fristen und Termine:
Betreff: {subject}
Inhalt: {body_preview}
Liste alle Fristen im folgenden Format auf (eine pro Zeile):
DATUM|BESCHREIBUNG|VERBINDLICH
Beispiel: 2025-01-15|Abgabe der Berichte|ja
Wenn keine Fristen gefunden werden, antworte mit: KEINE_FRISTEN
Antworte NUR im angegebenen Format.
"""
        request_body = {
            "prompt": prompt,
            "playbook": "mail_analysis",
            "max_tokens": 200,
        }
        reply = await client.post(
            f"{LLM_GATEWAY_URL}/api/v1/inference",
            json=request_body,
        )
        if reply.status_code == 200:
            answer = reply.json().get("response", "")
            if "KEINE_FRISTEN" in answer:
                return []

            extracted = []
            for row in answer.strip().split("\n"):
                fields = row.split("|")
                if len(fields) < 2:
                    continue
                try:
                    when = datetime.fromisoformat(fields[0].strip())
                    what = fields[1].strip()
                    # A missing third column counts as a firm deadline.
                    if len(fields) > 2:
                        firm = fields[2].strip().lower() == "ja"
                    else:
                        firm = True
                    extracted.append(DeadlineExtraction(
                        deadline_date=when,
                        description=what,
                        confidence=0.7,
                        source_text=row,
                        is_firm=firm,
                    ))
                except (ValueError, IndexError):
                    continue
            return extracted
    except Exception as e:
        logger.warning(f"LLM deadline extraction failed: {e}")
    return []
+134
View File
@@ -0,0 +1,134 @@
"""
AI Email - Sender Classification
Domain-based and LLM-based sender classification for emails.
Extracted from ai_service.py to keep files under 500 LOC.
"""
import os
import logging
from typing import Optional
import httpx
from .models import (
SenderType,
SenderClassification,
classify_sender_by_domain,
)
logger = logging.getLogger(__name__)
# LLM Gateway configuration
LLM_GATEWAY_URL = os.getenv("LLM_GATEWAY_URL", "http://localhost:8090")
async def classify_sender(
    http_client: httpx.AsyncClient,
    sender_email: str,
    sender_name: Optional[str] = None,
    subject: Optional[str] = None,
    body_preview: Optional[str] = None,
) -> SenderClassification:
    """
    Determine what kind of sender wrote an email.

    A result from ``classify_sender_by_domain`` is returned directly when
    there is one; only without a domain match is the LLM classifier called.
    """
    by_domain = classify_sender_by_domain(sender_email)
    if not by_domain:
        # No known domain: let the LLM decide from the available metadata.
        return await _classify_sender_llm(
            http_client, sender_email, sender_name, subject, body_preview
        )
    return by_domain
async def _classify_sender_llm(
    client: httpx.AsyncClient,
    sender_email: str,
    sender_name: Optional[str],
    subject: Optional[str],
    body_preview: Optional[str],
) -> SenderClassification:
    """
    Classify the sender through the LLM gateway.

    Expects a reply of the form ``kategorie|konfidenz|kurze_begruendung``.
    Unknown category names map to ``SenderType.UNBEKANNT``; the confidence is
    clamped to [0.0, 1.0]. If the request fails, the status is not 200 or the
    reply is malformed, a default ``UNBEKANNT`` classification with
    confidence 0.3 and ``ai_classified=False`` is returned.
    """
    try:
        prompt = f"""Analysiere den Absender dieser E-Mail und klassifiziere ihn:
Absender E-Mail: {sender_email}
Absender Name: {sender_name or "Nicht angegeben"}
Betreff: {subject or "Nicht angegeben"}
Vorschau: {body_preview[:200] if body_preview else "Nicht verfuegbar"}
Klassifiziere den Absender in EINE der folgenden Kategorien:
- kultusministerium: Kultusministerium/Bildungsministerium
- landesschulbehoerde: Landesschulbehoerde
- rlsb: Regionales Landesamt fuer Schule und Bildung
- schulamt: Schulamt
- nibis: Niedersaechsischer Bildungsserver
- schultraeger: Schultraeger/Kommune
- elternvertreter: Elternvertreter/Elternrat
- gewerkschaft: Gewerkschaft (GEW, VBE, etc.)
- fortbildungsinstitut: Fortbildungsinstitut (NLQ, etc.)
- privatperson: Privatperson
- unternehmen: Unternehmen/Firma
- unbekannt: Nicht einzuordnen
Antworte NUR mit dem Kategorienamen (z.B. "kultusministerium") und einer Konfidenz von 0.0 bis 1.0.
Format: kategorie|konfidenz|kurze_begruendung
"""
        request_body = {
            "prompt": prompt,
            "playbook": "mail_analysis",
            "max_tokens": 100,
        }
        reply = await client.post(
            f"{LLM_GATEWAY_URL}/api/v1/inference",
            json=request_body,
        )
        if reply.status_code == 200:
            answer = reply.json().get("response", "unbekannt|0.5|")
            fields = answer.strip().split("|")
            if len(fields) >= 2:
                label = fields[0].strip().lower()
                # A non-numeric confidence raises and ends in the default below.
                score = float(fields[1].strip())
                known_types = {
                    "kultusministerium": SenderType.KULTUSMINISTERIUM,
                    "landesschulbehoerde": SenderType.LANDESSCHULBEHOERDE,
                    "rlsb": SenderType.RLSB,
                    "schulamt": SenderType.SCHULAMT,
                    "nibis": SenderType.NIBIS,
                    "schultraeger": SenderType.SCHULTRAEGER,
                    "elternvertreter": SenderType.ELTERNVERTRETER,
                    "gewerkschaft": SenderType.GEWERKSCHAFT,
                    "fortbildungsinstitut": SenderType.FORTBILDUNGSINSTITUT,
                    "privatperson": SenderType.PRIVATPERSON,
                    "unternehmen": SenderType.UNTERNEHMEN,
                }
                return SenderClassification(
                    sender_type=known_types.get(label, SenderType.UNBEKANNT),
                    confidence=min(max(score, 0.0), 1.0),
                    domain_matched=False,
                    ai_classified=True,
                )
    except Exception as e:
        logger.warning(f"LLM sender classification failed: {e}")

    # Default when the LLM gave no usable answer.
    return SenderClassification(
        sender_type=SenderType.UNBEKANNT,
        confidence=0.3,
        domain_matched=False,
        ai_classified=False,
    )
+27 -578
View File
@@ -1,18 +1,19 @@
"""
AI Email Analysis Service
AI Email Analysis Service — Barrel Re-export
KI-powered email analysis with:
- Sender classification (authority recognition)
- Deadline extraction
- Category classification
- Response suggestions
Split into:
- mail/ai_sender.py — Sender classification (domain + LLM)
- mail/ai_deadline.py — Deadline extraction (regex + LLM)
- mail/ai_category.py — Category classification + response suggestions
The AIEmailService class and get_ai_email_service() are defined here
to maintain the original public API.
"""
import os
import re
import logging
from typing import Optional, List, Dict, Any, Tuple
from datetime import datetime, timedelta
from typing import Optional, List, Tuple
from datetime import datetime
import httpx
from .models import (
@@ -23,17 +24,15 @@ from .models import (
DeadlineExtraction,
EmailAnalysisResult,
ResponseSuggestion,
KNOWN_AUTHORITIES_NI,
classify_sender_by_domain,
get_priority_from_sender_type,
)
from .mail_db import update_email_ai_analysis
from .ai_sender import classify_sender, LLM_GATEWAY_URL
from .ai_deadline import extract_deadlines
from .ai_category import classify_category, suggest_response
logger = logging.getLogger(__name__)
# LLM Gateway configuration
LLM_GATEWAY_URL = os.getenv("LLM_GATEWAY_URL", "http://localhost:8090")
class AIEmailService:
"""
@@ -56,10 +55,6 @@ class AIEmailService:
self._http_client = httpx.AsyncClient(timeout=30.0)
return self._http_client
# =========================================================================
# Sender Classification
# =========================================================================
async def classify_sender(
self,
sender_email: str,
@@ -67,300 +62,20 @@ class AIEmailService:
subject: Optional[str] = None,
body_preview: Optional[str] = None,
) -> SenderClassification:
"""
Classify the sender of an email.
First tries domain matching, then falls back to LLM.
Args:
sender_email: Sender's email address
sender_name: Sender's display name
subject: Email subject
body_preview: First 200 chars of body
Returns:
SenderClassification with type and confidence
"""
# Try domain-based classification first (fast, high confidence)
domain_result = classify_sender_by_domain(sender_email)
if domain_result:
return domain_result
# Fall back to LLM classification
return await self._classify_sender_llm(
sender_email, sender_name, subject, body_preview
)
async def _classify_sender_llm(
self,
sender_email: str,
sender_name: Optional[str],
subject: Optional[str],
body_preview: Optional[str],
) -> SenderClassification:
"""Classify sender using LLM."""
try:
"""Classify the sender of an email."""
client = await self.get_http_client()
prompt = f"""Analysiere den Absender dieser E-Mail und klassifiziere ihn:
Absender E-Mail: {sender_email}
Absender Name: {sender_name or "Nicht angegeben"}
Betreff: {subject or "Nicht angegeben"}
Vorschau: {body_preview[:200] if body_preview else "Nicht verfügbar"}
Klassifiziere den Absender in EINE der folgenden Kategorien:
- kultusministerium: Kultusministerium/Bildungsministerium
- landesschulbehoerde: Landesschulbehörde
- rlsb: Regionales Landesamt für Schule und Bildung
- schulamt: Schulamt
- nibis: Niedersächsischer Bildungsserver
- schultraeger: Schulträger/Kommune
- elternvertreter: Elternvertreter/Elternrat
- gewerkschaft: Gewerkschaft (GEW, VBE, etc.)
- fortbildungsinstitut: Fortbildungsinstitut (NLQ, etc.)
- privatperson: Privatperson
- unternehmen: Unternehmen/Firma
- unbekannt: Nicht einzuordnen
Antworte NUR mit dem Kategorienamen (z.B. "kultusministerium") und einer Konfidenz von 0.0 bis 1.0.
Format: kategorie|konfidenz|kurze_begründung
"""
response = await client.post(
f"{LLM_GATEWAY_URL}/api/v1/inference",
json={
"prompt": prompt,
"playbook": "mail_analysis",
"max_tokens": 100,
},
return await classify_sender(
client, sender_email, sender_name, subject, body_preview
)
if response.status_code == 200:
data = response.json()
result_text = data.get("response", "unbekannt|0.5|")
# Parse response
parts = result_text.strip().split("|")
if len(parts) >= 2:
sender_type_str = parts[0].strip().lower()
confidence = float(parts[1].strip())
# Map to enum
type_mapping = {
"kultusministerium": SenderType.KULTUSMINISTERIUM,
"landesschulbehoerde": SenderType.LANDESSCHULBEHOERDE,
"rlsb": SenderType.RLSB,
"schulamt": SenderType.SCHULAMT,
"nibis": SenderType.NIBIS,
"schultraeger": SenderType.SCHULTRAEGER,
"elternvertreter": SenderType.ELTERNVERTRETER,
"gewerkschaft": SenderType.GEWERKSCHAFT,
"fortbildungsinstitut": SenderType.FORTBILDUNGSINSTITUT,
"privatperson": SenderType.PRIVATPERSON,
"unternehmen": SenderType.UNTERNEHMEN,
}
sender_type = type_mapping.get(sender_type_str, SenderType.UNBEKANNT)
return SenderClassification(
sender_type=sender_type,
confidence=min(max(confidence, 0.0), 1.0),
domain_matched=False,
ai_classified=True,
)
except Exception as e:
logger.warning(f"LLM sender classification failed: {e}")
# Default fallback
return SenderClassification(
sender_type=SenderType.UNBEKANNT,
confidence=0.3,
domain_matched=False,
ai_classified=False,
)
# =========================================================================
# Deadline Extraction
# =========================================================================
async def extract_deadlines(
self,
subject: str,
body_text: str,
) -> List[DeadlineExtraction]:
"""
Extract deadlines from email content.
Uses regex patterns first, then LLM for complex cases.
Args:
subject: Email subject
body_text: Email body text
Returns:
List of extracted deadlines
"""
deadlines = []
# Combine subject and body
full_text = f"{subject}\n{body_text}" if body_text else subject
# Try regex extraction first
regex_deadlines = self._extract_deadlines_regex(full_text)
deadlines.extend(regex_deadlines)
# If no regex matches, try LLM
if not deadlines and body_text:
llm_deadlines = await self._extract_deadlines_llm(subject, body_text[:1000])
deadlines.extend(llm_deadlines)
return deadlines
def _extract_deadlines_regex(self, text: str) -> List[DeadlineExtraction]:
"""Extract deadlines using regex patterns."""
deadlines = []
now = datetime.now()
# German date patterns
patterns = [
# "bis zum 15.01.2025"
(r"bis\s+(?:zum\s+)?(\d{1,2})\.(\d{1,2})\.(\d{2,4})", True),
# "spätestens am 15.01.2025"
(r"spätestens\s+(?:am\s+)?(\d{1,2})\.(\d{1,2})\.(\d{2,4})", True),
# "Abgabetermin: 15.01.2025"
(r"(?:Abgabe|Termin|Frist)[:\s]+(\d{1,2})\.(\d{1,2})\.(\d{2,4})", True),
# "innerhalb von 14 Tagen"
(r"innerhalb\s+von\s+(\d+)\s+(?:Tagen|Wochen)", False),
# "bis Ende Januar"
(r"bis\s+(?:Ende\s+)?(Januar|Februar|März|April|Mai|Juni|Juli|August|September|Oktober|November|Dezember)", False),
]
for pattern, is_specific_date in patterns:
matches = re.finditer(pattern, text, re.IGNORECASE)
for match in matches:
try:
if is_specific_date:
day = int(match.group(1))
month = int(match.group(2))
year = int(match.group(3))
# Handle 2-digit years
if year < 100:
year += 2000
deadline_date = datetime(year, month, day)
# Skip past dates
if deadline_date < now:
continue
# Get surrounding context
start = max(0, match.start() - 50)
end = min(len(text), match.end() + 50)
context = text[start:end].strip()
deadlines.append(DeadlineExtraction(
deadline_date=deadline_date,
description=f"Frist: {match.group(0)}",
confidence=0.85,
source_text=context,
is_firm=True,
))
else:
# Relative dates (innerhalb von X Tagen)
if "Tagen" in pattern or "Wochen" in pattern:
days = int(match.group(1))
if "Wochen" in match.group(0).lower():
days *= 7
deadline_date = now + timedelta(days=days)
deadlines.append(DeadlineExtraction(
deadline_date=deadline_date,
description=f"Relative Frist: {match.group(0)}",
confidence=0.7,
source_text=match.group(0),
is_firm=False,
))
except (ValueError, IndexError) as e:
logger.debug(f"Failed to parse date: {e}")
continue
return deadlines
async def _extract_deadlines_llm(
self,
subject: str,
body_preview: str,
) -> List[DeadlineExtraction]:
"""Extract deadlines using LLM."""
try:
"""Extract deadlines from email content."""
client = await self.get_http_client()
prompt = f"""Analysiere diese E-Mail und extrahiere alle genannten Fristen und Termine:
Betreff: {subject}
Inhalt: {body_preview}
Liste alle Fristen im folgenden Format auf (eine pro Zeile):
DATUM|BESCHREIBUNG|VERBINDLICH
Beispiel: 2025-01-15|Abgabe der Berichte|ja
Wenn keine Fristen gefunden werden, antworte mit: KEINE_FRISTEN
Antworte NUR im angegebenen Format.
"""
response = await client.post(
f"{LLM_GATEWAY_URL}/api/v1/inference",
json={
"prompt": prompt,
"playbook": "mail_analysis",
"max_tokens": 200,
},
)
if response.status_code == 200:
data = response.json()
result_text = data.get("response", "")
if "KEINE_FRISTEN" in result_text:
return []
deadlines = []
for line in result_text.strip().split("\n"):
parts = line.split("|")
if len(parts) >= 2:
try:
date_str = parts[0].strip()
deadline_date = datetime.fromisoformat(date_str)
description = parts[1].strip()
is_firm = parts[2].strip().lower() == "ja" if len(parts) > 2 else True
deadlines.append(DeadlineExtraction(
deadline_date=deadline_date,
description=description,
confidence=0.7,
source_text=line,
is_firm=is_firm,
))
except (ValueError, IndexError):
continue
return deadlines
except Exception as e:
logger.warning(f"LLM deadline extraction failed: {e}")
return []
# =========================================================================
# Email Category Classification
# =========================================================================
return await extract_deadlines(client, subject, body_text)
async def classify_category(
self,
@@ -368,155 +83,9 @@ Antworte NUR im angegebenen Format.
body_preview: str,
sender_type: SenderType,
) -> Tuple[EmailCategory, float]:
"""
Classify email into a category.
Args:
subject: Email subject
body_preview: First 200 chars of body
sender_type: Already classified sender type
Returns:
Tuple of (category, confidence)
"""
# Rule-based classification first
category, confidence = self._classify_category_rules(subject, body_preview, sender_type)
if confidence > 0.7:
return category, confidence
# Fall back to LLM
return await self._classify_category_llm(subject, body_preview)
def _classify_category_rules(
self,
subject: str,
body_preview: str,
sender_type: SenderType,
) -> Tuple[EmailCategory, float]:
"""Rule-based category classification."""
text = f"{subject} {body_preview}".lower()
# Keywords for each category
category_keywords = {
EmailCategory.DIENSTLICH: [
"dienstlich", "dienstanweisung", "erlass", "verordnung",
"bescheid", "verfügung", "ministerium", "behörde"
],
EmailCategory.PERSONAL: [
"personalrat", "stellenausschreibung", "versetzung",
"beurteilung", "dienstzeugnis", "krankmeldung", "elternzeit"
],
EmailCategory.FINANZEN: [
"budget", "haushalt", "etat", "abrechnung", "rechnung",
"erstattung", "zuschuss", "fördermittel"
],
EmailCategory.ELTERN: [
"elternbrief", "elternabend", "schulkonferenz",
"elternvertreter", "elternbeirat"
],
EmailCategory.SCHUELER: [
"schüler", "schülerin", "zeugnis", "klasse", "unterricht",
"prüfung", "klassenfahrt", "schulpflicht"
],
EmailCategory.FORTBILDUNG: [
"fortbildung", "seminar", "workshop", "schulung",
"weiterbildung", "nlq", "didaktik"
],
EmailCategory.VERANSTALTUNG: [
"einladung", "veranstaltung", "termin", "konferenz",
"sitzung", "tagung", "feier"
],
EmailCategory.SICHERHEIT: [
"sicherheit", "notfall", "brandschutz", "evakuierung",
"hygiene", "corona", "infektionsschutz"
],
EmailCategory.TECHNIK: [
"it", "software", "computer", "netzwerk", "login",
"passwort", "digitalisierung", "iserv"
],
EmailCategory.NEWSLETTER: [
"newsletter", "rundschreiben", "info-mail", "mitteilung"
],
EmailCategory.WERBUNG: [
"angebot", "rabatt", "aktion", "werbung", "abonnement"
],
}
best_category = EmailCategory.SONSTIGES
best_score = 0.0
for category, keywords in category_keywords.items():
score = sum(1 for kw in keywords if kw in text)
if score > best_score:
best_score = score
best_category = category
# Adjust based on sender type
if sender_type in [SenderType.KULTUSMINISTERIUM, SenderType.LANDESSCHULBEHOERDE, SenderType.RLSB]:
if best_category == EmailCategory.SONSTIGES:
best_category = EmailCategory.DIENSTLICH
best_score = 2
# Convert score to confidence
confidence = min(0.9, 0.4 + (best_score * 0.15))
return best_category, confidence
async def _classify_category_llm(
self,
subject: str,
body_preview: str,
) -> Tuple[EmailCategory, float]:
"""LLM-based category classification."""
try:
"""Classify email into a category."""
client = await self.get_http_client()
categories = ", ".join([c.value for c in EmailCategory])
prompt = f"""Klassifiziere diese E-Mail in EINE Kategorie:
Betreff: {subject}
Inhalt: {body_preview[:500]}
Kategorien: {categories}
Antworte NUR mit dem Kategorienamen und einer Konfidenz (0.0-1.0):
Format: kategorie|konfidenz
"""
response = await client.post(
f"{LLM_GATEWAY_URL}/api/v1/inference",
json={
"prompt": prompt,
"playbook": "mail_analysis",
"max_tokens": 50,
},
)
if response.status_code == 200:
data = response.json()
result = data.get("response", "sonstiges|0.5")
parts = result.strip().split("|")
if len(parts) >= 2:
category_str = parts[0].strip().lower()
confidence = float(parts[1].strip())
try:
category = EmailCategory(category_str)
return category, min(max(confidence, 0.0), 1.0)
except ValueError:
pass
except Exception as e:
logger.warning(f"LLM category classification failed: {e}")
return EmailCategory.SONSTIGES, 0.5
# =========================================================================
# Full Analysis Pipeline
# =========================================================================
return await classify_category(client, subject, body_preview, sender_type)
async def analyze_email(
self,
@@ -527,20 +96,7 @@ Format: kategorie|konfidenz
body_text: Optional[str],
body_preview: Optional[str],
) -> EmailAnalysisResult:
"""
Run full analysis pipeline on an email.
Args:
email_id: Database ID of the email
sender_email: Sender's email address
sender_name: Sender's display name
subject: Email subject
body_text: Full body text
body_preview: Preview text
Returns:
Complete analysis result
"""
"""Run full analysis pipeline on an email."""
# 1. Classify sender
sender_classification = await self.classify_sender(
sender_email, sender_name, subject, body_preview
@@ -569,8 +125,8 @@ Format: kategorie|konfidenz
elif days_until <= 7:
suggested_priority = max(suggested_priority, TaskPriority.MEDIUM)
# 5. Generate summary (optional, can be expensive)
summary = None # Could add LLM summary generation here
# 5. Summary (optional)
summary = None
# 6. Determine if task should be auto-created
auto_create_task = (
@@ -612,10 +168,6 @@ Format: kategorie|konfidenz
auto_create_task=auto_create_task,
)
# =========================================================================
# Response Suggestions
# =========================================================================
async def suggest_response(
self,
subject: str,
@@ -623,115 +175,12 @@ Format: kategorie|konfidenz
sender_type: SenderType,
category: EmailCategory,
) -> List[ResponseSuggestion]:
"""
Generate response suggestions for an email.
Args:
subject: Original email subject
body_text: Original email body
sender_type: Classified sender type
category: Classified category
Returns:
List of response suggestions
"""
suggestions = []
# Add standard templates based on sender type and category
if sender_type in [SenderType.KULTUSMINISTERIUM, SenderType.LANDESSCHULBEHOERDE, SenderType.RLSB]:
suggestions.append(ResponseSuggestion(
template_type="acknowledgment",
subject=f"Re: {subject}",
body="""Sehr geehrte Damen und Herren,
vielen Dank für Ihre Nachricht.
Ich bestätige den Eingang und werde die Angelegenheit fristgerecht bearbeiten.
Mit freundlichen Grüßen""",
confidence=0.8,
))
if category == EmailCategory.ELTERN:
suggestions.append(ResponseSuggestion(
template_type="parent_response",
subject=f"Re: {subject}",
body="""Liebe Eltern,
vielen Dank für Ihre Nachricht.
[Ihre Antwort hier]
Mit freundlichen Grüßen""",
confidence=0.7,
))
# Add LLM-generated suggestion
try:
llm_suggestion = await self._generate_response_llm(subject, body_text[:500], sender_type)
if llm_suggestion:
suggestions.append(llm_suggestion)
except Exception as e:
logger.warning(f"LLM response generation failed: {e}")
return suggestions
async def _generate_response_llm(
self,
subject: str,
body_preview: str,
sender_type: SenderType,
) -> Optional[ResponseSuggestion]:
"""Generate a response suggestion using LLM."""
try:
"""Generate response suggestions for an email."""
client = await self.get_http_client()
sender_desc = {
SenderType.KULTUSMINISTERIUM: "dem Kultusministerium",
SenderType.LANDESSCHULBEHOERDE: "der Landesschulbehörde",
SenderType.RLSB: "dem RLSB",
SenderType.ELTERNVERTRETER: "einem Elternvertreter",
}.get(sender_type, "einem Absender")
prompt = f"""Du bist eine Schulleiterin in Niedersachsen. Formuliere eine professionelle, kurze Antwort auf diese E-Mail von {sender_desc}:
Betreff: {subject}
Inhalt: {body_preview}
Die Antwort sollte:
- Höflich und formell sein
- Den Eingang bestätigen
- Eine konkrete nächste Aktion nennen oder um Klärung bitten
Antworte NUR mit dem Antworttext (ohne Betreffzeile, ohne "Betreff:").
"""
response = await client.post(
f"{LLM_GATEWAY_URL}/api/v1/inference",
json={
"prompt": prompt,
"playbook": "mail_analysis",
"max_tokens": 300,
},
return await suggest_response(
client, subject, body_text, sender_type, category
)
if response.status_code == 200:
data = response.json()
body = data.get("response", "").strip()
if body:
return ResponseSuggestion(
template_type="ai_generated",
subject=f"Re: {subject}",
body=body,
confidence=0.6,
)
except Exception as e:
logger.warning(f"LLM response generation failed: {e}")
return None
# Global instance
_ai_service: Optional[AIEmailService] = None
+33 -830
View File
@@ -1,833 +1,36 @@
"""
PostgreSQL Metrics Database Service
Stores search feedback, calculates quality metrics (Precision, Recall, MRR).
PostgreSQL Metrics Database Service — Barrel Re-export
Split into:
- metrics_db_core.py — Pool, feedback, metrics, relevance
- metrics_db_schema.py — Table initialization (DDL)
- metrics_db_zeugnis.py — Zeugnis source/document/stats operations
All public names are re-exported here for backward compatibility.
"""
import os
from typing import Optional, List, Dict
from datetime import datetime, timedelta
import asyncio
# Database Configuration - uses test default if not configured (for CI)
DATABASE_URL = os.getenv("DATABASE_URL", "postgresql://test:test@localhost:5432/test_metrics")
# Connection pool
_pool = None
async def get_pool():
    """Return the shared asyncpg pool, creating it on first use.

    Returns:
        The module-level pool, or ``None`` when asyncpg cannot be imported
        or pool creation raises (a warning is printed in both cases).
    """
    global _pool
    if _pool is not None:
        return _pool

    try:
        # Imported lazily so a missing asyncpg only disables metrics storage.
        import asyncpg
        _pool = await asyncpg.create_pool(DATABASE_URL, min_size=2, max_size=10)
    except ImportError:
        print("Warning: asyncpg not installed. Metrics storage disabled.")
        return None
    except Exception as exc:
        print(f"Warning: Failed to connect to PostgreSQL: {exc}")
        return None
    return _pool
async def init_metrics_tables() -> bool:
    """Create the RAG metrics and Zeugnis tables and their indexes.

    Every CREATE statement in the script uses ``IF NOT EXISTS``.

    Returns:
        ``True`` when the DDL script executed; ``False`` when no pool is
        available or execution raised (the error is printed).
    """
    db = await get_pool()
    if db is None:
        return False

    # One multi-statement script, sent to the server in a single execute().
    ddl_script = """
    -- RAG Search Feedback Table
    CREATE TABLE IF NOT EXISTS rag_search_feedback (
        id SERIAL PRIMARY KEY,
        result_id VARCHAR(255) NOT NULL,
        query_text TEXT,
        collection_name VARCHAR(100),
        score FLOAT,
        rating INTEGER CHECK (rating >= 1 AND rating <= 5),
        notes TEXT,
        user_id VARCHAR(100),
        created_at TIMESTAMP DEFAULT NOW()
    );
    -- Index for efficient querying
    CREATE INDEX IF NOT EXISTS idx_feedback_created_at ON rag_search_feedback(created_at);
    CREATE INDEX IF NOT EXISTS idx_feedback_collection ON rag_search_feedback(collection_name);
    CREATE INDEX IF NOT EXISTS idx_feedback_rating ON rag_search_feedback(rating);
    -- RAG Search Logs Table (for latency tracking)
    CREATE TABLE IF NOT EXISTS rag_search_logs (
        id SERIAL PRIMARY KEY,
        query_text TEXT NOT NULL,
        collection_name VARCHAR(100),
        result_count INTEGER,
        latency_ms INTEGER,
        top_score FLOAT,
        filters JSONB,
        created_at TIMESTAMP DEFAULT NOW()
    );
    CREATE INDEX IF NOT EXISTS idx_search_logs_created_at ON rag_search_logs(created_at);
    -- RAG Upload History Table
    CREATE TABLE IF NOT EXISTS rag_upload_history (
        id SERIAL PRIMARY KEY,
        filename VARCHAR(500) NOT NULL,
        collection_name VARCHAR(100),
        year INTEGER,
        pdfs_extracted INTEGER,
        minio_path VARCHAR(1000),
        uploaded_by VARCHAR(100),
        created_at TIMESTAMP DEFAULT NOW()
    );
    CREATE INDEX IF NOT EXISTS idx_upload_history_created_at ON rag_upload_history(created_at);
    -- Binäre Relevanz-Judgments für echte Precision/Recall
    CREATE TABLE IF NOT EXISTS rag_relevance_judgments (
        id SERIAL PRIMARY KEY,
        query_id VARCHAR(255) NOT NULL,
        query_text TEXT NOT NULL,
        result_id VARCHAR(255) NOT NULL,
        result_rank INTEGER,
        is_relevant BOOLEAN NOT NULL,
        collection_name VARCHAR(100),
        user_id VARCHAR(100),
        created_at TIMESTAMP DEFAULT NOW()
    );
    CREATE INDEX IF NOT EXISTS idx_relevance_query ON rag_relevance_judgments(query_id);
    CREATE INDEX IF NOT EXISTS idx_relevance_created_at ON rag_relevance_judgments(created_at);
    -- Zeugnisse Source Tracking
    CREATE TABLE IF NOT EXISTS zeugnis_sources (
        id VARCHAR(36) PRIMARY KEY,
        bundesland VARCHAR(10) NOT NULL,
        name VARCHAR(255) NOT NULL,
        base_url TEXT,
        license_type VARCHAR(50) NOT NULL,
        training_allowed BOOLEAN DEFAULT FALSE,
        verified_by VARCHAR(100),
        verified_at TIMESTAMP,
        created_at TIMESTAMP DEFAULT NOW(),
        updated_at TIMESTAMP DEFAULT NOW()
    );
    CREATE INDEX IF NOT EXISTS idx_zeugnis_sources_bundesland ON zeugnis_sources(bundesland);
    -- Zeugnisse Seed URLs
    CREATE TABLE IF NOT EXISTS zeugnis_seed_urls (
        id VARCHAR(36) PRIMARY KEY,
        source_id VARCHAR(36) REFERENCES zeugnis_sources(id),
        url TEXT NOT NULL,
        doc_type VARCHAR(50),
        status VARCHAR(20) DEFAULT 'pending',
        last_crawled TIMESTAMP,
        error_message TEXT,
        created_at TIMESTAMP DEFAULT NOW()
    );
    CREATE INDEX IF NOT EXISTS idx_zeugnis_seed_urls_source ON zeugnis_seed_urls(source_id);
    CREATE INDEX IF NOT EXISTS idx_zeugnis_seed_urls_status ON zeugnis_seed_urls(status);
    -- Zeugnisse Documents
    CREATE TABLE IF NOT EXISTS zeugnis_documents (
        id VARCHAR(36) PRIMARY KEY,
        seed_url_id VARCHAR(36) REFERENCES zeugnis_seed_urls(id),
        title VARCHAR(500),
        url TEXT NOT NULL,
        content_hash VARCHAR(64),
        minio_path TEXT,
        training_allowed BOOLEAN DEFAULT FALSE,
        indexed_in_qdrant BOOLEAN DEFAULT FALSE,
        file_size INTEGER,
        content_type VARCHAR(100),
        created_at TIMESTAMP DEFAULT NOW(),
        updated_at TIMESTAMP DEFAULT NOW()
    );
    CREATE INDEX IF NOT EXISTS idx_zeugnis_documents_seed ON zeugnis_documents(seed_url_id);
    CREATE INDEX IF NOT EXISTS idx_zeugnis_documents_hash ON zeugnis_documents(content_hash);
    -- Zeugnisse Document Versions
    CREATE TABLE IF NOT EXISTS zeugnis_document_versions (
        id VARCHAR(36) PRIMARY KEY,
        document_id VARCHAR(36) REFERENCES zeugnis_documents(id),
        version INTEGER NOT NULL,
        content_hash VARCHAR(64),
        minio_path TEXT,
        change_summary TEXT,
        created_at TIMESTAMP DEFAULT NOW()
    );
    CREATE INDEX IF NOT EXISTS idx_zeugnis_versions_doc ON zeugnis_document_versions(document_id);
    -- Zeugnisse Usage Events (Audit Trail)
    CREATE TABLE IF NOT EXISTS zeugnis_usage_events (
        id VARCHAR(36) PRIMARY KEY,
        document_id VARCHAR(36) REFERENCES zeugnis_documents(id),
        event_type VARCHAR(50) NOT NULL,
        user_id VARCHAR(100),
        details JSONB,
        created_at TIMESTAMP DEFAULT NOW()
    );
    CREATE INDEX IF NOT EXISTS idx_zeugnis_events_doc ON zeugnis_usage_events(document_id);
    CREATE INDEX IF NOT EXISTS idx_zeugnis_events_type ON zeugnis_usage_events(event_type);
    CREATE INDEX IF NOT EXISTS idx_zeugnis_events_created ON zeugnis_usage_events(created_at);
    -- Crawler Queue
    CREATE TABLE IF NOT EXISTS zeugnis_crawler_queue (
        id VARCHAR(36) PRIMARY KEY,
        source_id VARCHAR(36) REFERENCES zeugnis_sources(id),
        priority INTEGER DEFAULT 5,
        status VARCHAR(20) DEFAULT 'pending',
        started_at TIMESTAMP,
        completed_at TIMESTAMP,
        documents_found INTEGER DEFAULT 0,
        documents_indexed INTEGER DEFAULT 0,
        error_count INTEGER DEFAULT 0,
        created_at TIMESTAMP DEFAULT NOW()
    );
    CREATE INDEX IF NOT EXISTS idx_crawler_queue_status ON zeugnis_crawler_queue(status);
    """

    try:
        async with db.acquire() as connection:
            await connection.execute(ddl_script)
        print("RAG metrics tables initialized")
    except Exception as exc:
        print(f"Failed to initialize metrics tables: {exc}")
        return False
    return True
# =============================================================================
# Feedback Storage
# =============================================================================
async def store_feedback(
    result_id: str,
    rating: int,
    query_text: Optional[str] = None,
    collection_name: Optional[str] = None,
    score: Optional[float] = None,
    notes: Optional[str] = None,
    user_id: Optional[str] = None,
) -> bool:
    """Insert one search-result rating into ``rag_search_feedback``.

    Returns:
        ``True`` when the row was inserted; ``False`` when no pool is
        available or the insert raised (the error is printed).
    """
    db = await get_pool()
    if db is None:
        return False

    insert_sql = """
        INSERT INTO rag_search_feedback
        (result_id, query_text, collection_name, score, rating, notes, user_id)
        VALUES ($1, $2, $3, $4, $5, $6, $7)
        """
    # Order must match the column list above.
    values = (result_id, query_text, collection_name, score, rating, notes, user_id)
    try:
        async with db.acquire() as connection:
            await connection.execute(insert_sql, *values)
    except Exception as exc:
        print(f"Failed to store feedback: {exc}")
        return False
    return True
async def log_search(
    query_text: str,
    collection_name: str,
    result_count: int,
    latency_ms: int,
    top_score: Optional[float] = None,
    filters: Optional[Dict] = None,
) -> bool:
    """Record one search in ``rag_search_logs`` for metrics tracking.

    ``filters`` is stored as a JSON string; a missing or empty dict is
    stored as NULL.

    Returns:
        ``True`` when the row was inserted; ``False`` when no pool is
        available or serialization/insert raised (the error is printed).
    """
    db = await get_pool()
    if db is None:
        return False

    insert_sql = """
        INSERT INTO rag_search_logs
        (query_text, collection_name, result_count, latency_ms, top_score, filters)
        VALUES ($1, $2, $3, $4, $5, $6)
        """
    try:
        import json

        # Serialized inside the try so an unserializable filter also yields False.
        filters_json = None
        if filters:
            filters_json = json.dumps(filters)
        async with db.acquire() as connection:
            await connection.execute(
                insert_sql,
                query_text, collection_name, result_count, latency_ms, top_score,
                filters_json,
            )
    except Exception as exc:
        print(f"Failed to log search: {exc}")
        return False
    return True
async def log_upload(
    filename: str,
    collection_name: str,
    year: int,
    pdfs_extracted: int,
    minio_path: Optional[str] = None,
    uploaded_by: Optional[str] = None,
) -> bool:
    """Record one upload in ``rag_upload_history``.

    Returns:
        ``True`` when the row was inserted; ``False`` when no pool is
        available or the insert raised (the error is printed).
    """
    db = await get_pool()
    if db is None:
        return False

    insert_sql = """
        INSERT INTO rag_upload_history
        (filename, collection_name, year, pdfs_extracted, minio_path, uploaded_by)
        VALUES ($1, $2, $3, $4, $5, $6)
        """
    # Order must match the column list above.
    values = (filename, collection_name, year, pdfs_extracted, minio_path, uploaded_by)
    try:
        async with db.acquire() as connection:
            await connection.execute(insert_sql, *values)
    except Exception as exc:
        print(f"Failed to log upload: {exc}")
        return False
    return True
# =============================================================================
# Metrics Calculation
# =============================================================================
async def calculate_metrics(
    collection_name: Optional[str] = None,
    days: int = 7,
) -> Dict:
    """Calculate RAG quality metrics from stored feedback and search logs.

    The values are rating-based proxies, not true IR metrics:
    ``precision_at_5`` is the share of ratings >= 4, ``mrr`` is the average
    rating divided by 5, and ``recall_at_10`` is an estimate derived from
    ``precision_at_5`` (capped at 1.0).

    Args:
        collection_name: Restrict all queries to this collection when given.
        days: Size of the look-back window in days.

    Returns:
        A dict with ``connected: True`` and the metric fields on success, or
        ``{"error": ..., "connected": False}`` when no pool is available or
        a query raised.
    """
    pool = await get_pool()
    if pool is None:
        return {"error": "Database not available", "connected": False}

    try:
        async with pool.acquire() as conn:
            # Date filter
            since = datetime.now() - timedelta(days=days)

            # Collection filter. Only this fixed fragment is interpolated
            # into the SQL; the collection name itself is bound as $2.
            collection_filter = ""
            params = [since]
            if collection_name:
                collection_filter = "AND collection_name = $2"
                params.append(collection_name)

            # Total feedback count
            total_feedback = await conn.fetchval(
                f"""
                SELECT COUNT(*) FROM rag_search_feedback
                WHERE created_at >= $1 {collection_filter}
                """,
                *params
            )

            # Rating distribution
            rating_dist = await conn.fetch(
                f"""
                SELECT rating, COUNT(*) as count
                FROM rag_search_feedback
                WHERE created_at >= $1 {collection_filter}
                GROUP BY rating
                ORDER BY rating DESC
                """,
                *params
            )

            # Average rating (proxy for precision)
            avg_rating = await conn.fetchval(
                f"""
                SELECT AVG(rating) FROM rag_search_feedback
                WHERE created_at >= $1 {collection_filter}
                """,
                *params
            )

            # Score distribution
            score_dist = await conn.fetch(
                f"""
                SELECT
                    CASE
                        WHEN score >= 0.9 THEN '0.9+'
                        WHEN score >= 0.7 THEN '0.7-0.9'
                        WHEN score >= 0.5 THEN '0.5-0.7'
                        ELSE '<0.5'
                    END as range,
                    COUNT(*) as count
                FROM rag_search_feedback
                WHERE created_at >= $1 AND score IS NOT NULL {collection_filter}
                GROUP BY range
                ORDER BY range DESC
                """,
                *params
            )

            # Search logs for latency (same column name, so the same filter applies)
            latency_stats = await conn.fetchrow(
                f"""
                SELECT
                    AVG(latency_ms) as avg_latency,
                    COUNT(*) as total_searches,
                    AVG(result_count) as avg_results
                FROM rag_search_logs
                WHERE created_at >= $1 {collection_filter}
                """,
                *params
            )

            # Share of feedback rated 4+ (proxy for precision@5)
            precision_at_5 = float(await conn.fetchval(
                f"""
                SELECT
                    CASE WHEN COUNT(*) > 0
                    THEN CAST(SUM(CASE WHEN rating >= 4 THEN 1 ELSE 0 END) AS FLOAT) / COUNT(*)
                    ELSE 0 END
                FROM rag_search_feedback
                WHERE created_at >= $1 {collection_filter}
                """,
                *params
            ) or 0)

            # Simplified MRR: average rating as a proxy for relevance.
            # AVG() over an INTEGER column is NUMERIC, which asyncpg returns
            # as Decimal; convert first because Decimal / float raises TypeError.
            mrr = float(avg_rating or 0) / 5.0

            # Error rate (ratings of 1 or 2)
            error_count = sum(
                r['count'] for r in rating_dist if r['rating'] and r['rating'] <= 2
            )
            error_rate = (error_count / total_feedback * 100) if total_feedback > 0 else 0

            # Score distribution as percentages
            total_scored = sum(s['count'] for s in score_dist)
            score_distribution = {
                s['range']: round(s['count'] / total_scored * 100) if total_scored > 0 else 0
                for s in score_dist
            }

            return {
                "connected": True,
                "period_days": days,
                "precision_at_5": round(precision_at_5, 2),
                # Estimated; capped because a recall above 1.0 is meaningless.
                "recall_at_10": round(min(precision_at_5 * 1.1, 1.0), 2),
                "mrr": round(mrr, 2),
                "avg_latency_ms": round(latency_stats['avg_latency'] or 0),
                "total_ratings": total_feedback,
                "total_searches": latency_stats['total_searches'] or 0,
                "error_rate": round(error_rate, 1),
                "score_distribution": score_distribution,
                "rating_distribution": {
                    str(r['rating']): r['count'] for r in rating_dist if r['rating']
                },
            }
    except Exception as e:
        print(f"Failed to calculate metrics: {e}")
        return {"error": str(e), "connected": False}
async def get_recent_feedback(limit: int = 20) -> List[Dict]:
    """Return the newest feedback rows, most recent first.

    Args:
        limit: Maximum number of rows to return.

    Returns:
        A list of dicts with ``created_at`` rendered as an ISO string (or
        ``None``); an empty list when no pool is available or the query
        raised (the error is printed).
    """
    db = await get_pool()
    if db is None:
        return []

    plain_columns = (
        "result_id", "rating", "query_text", "collection_name", "score", "notes",
    )
    try:
        async with db.acquire() as connection:
            records = await connection.fetch(
                """
                SELECT result_id, rating, query_text, collection_name, score, notes, created_at
                FROM rag_search_feedback
                ORDER BY created_at DESC
                LIMIT $1
                """,
                limit
            )
            entries = []
            for record in records:
                entry = {column: record[column] for column in plain_columns}
                created = record['created_at']
                entry["created_at"] = created.isoformat() if created else None
                entries.append(entry)
            return entries
    except Exception as exc:
        print(f"Failed to get recent feedback: {exc}")
        return []
async def get_upload_history(limit: int = 20) -> List[Dict]:
    """Return the newest upload-history rows, most recent first.

    Args:
        limit: Maximum number of rows to return.

    Returns:
        A list of dicts with ``created_at`` rendered as an ISO string (or
        ``None``); an empty list when no pool is available or the query
        raised (the error is printed).
    """
    db = await get_pool()
    if db is None:
        return []

    plain_columns = (
        "filename", "collection_name", "year", "pdfs_extracted", "minio_path", "uploaded_by",
    )
    try:
        async with db.acquire() as connection:
            records = await connection.fetch(
                """
                SELECT filename, collection_name, year, pdfs_extracted, minio_path, uploaded_by, created_at
                FROM rag_upload_history
                ORDER BY created_at DESC
                LIMIT $1
                """,
                limit
            )
            history = []
            for record in records:
                entry = {column: record[column] for column in plain_columns}
                created = record['created_at']
                entry["created_at"] = created.isoformat() if created else None
                history.append(entry)
            return history
    except Exception as exc:
        print(f"Failed to get upload history: {exc}")
        return []
# =============================================================================
# Relevance Judgments (Binary Precision/Recall)
# =============================================================================
async def store_relevance_judgment(
    query_id: str,
    query_text: str,
    result_id: str,
    is_relevant: bool,
    result_rank: Optional[int] = None,
    collection_name: Optional[str] = None,
    user_id: Optional[str] = None,
) -> bool:
    """Insert one binary relevance judgment into ``rag_relevance_judgments``.

    The insert uses ``ON CONFLICT DO NOTHING``.

    Returns:
        ``True`` when the statement executed; ``False`` when no pool is
        available or the insert raised (the error is printed).
    """
    db = await get_pool()
    if db is None:
        return False

    insert_sql = """
        INSERT INTO rag_relevance_judgments
        (query_id, query_text, result_id, result_rank, is_relevant, collection_name, user_id)
        VALUES ($1, $2, $3, $4, $5, $6, $7)
        ON CONFLICT DO NOTHING
        """
    # Column order differs from the parameter order: rank comes before is_relevant.
    values = (query_id, query_text, result_id, result_rank, is_relevant, collection_name, user_id)
    try:
        async with db.acquire() as connection:
            await connection.execute(insert_sql, *values)
    except Exception as exc:
        print(f"Failed to store relevance judgment: {exc}")
        return False
    return True
async def calculate_precision_recall(
    collection_name: Optional[str] = None,
    days: int = 7,
    k: int = 10,
) -> Dict:
    """Calculate Precision@k and Recall@k from binary relevance judgments.

    As computed by the queries below (per query, then averaged):

    * precision = relevant judgments within the cutoff / all judgments
      within the cutoff (the denominator is the judged count, not ``k``);
    * recall = relevant judgments within the cutoff / all relevant
      judgments for the query.

    A judgment is "within the cutoff" when ``result_rank`` is NULL or <= k.

    Args:
        collection_name: Restrict all queries to this collection when given.
        days: Size of the look-back window in days.
        k: Rank cutoff.

    Returns:
        A dict with ``connected: True`` and the metric fields on success, or
        ``{"error": ..., "connected": False}`` when no pool is available or
        a query raised.
    """
    pool = await get_pool()
    if pool is None:
        return {"error": "Database not available", "connected": False}

    try:
        async with pool.acquire() as conn:
            since = datetime.now() - timedelta(days=days)

            # The rank queries bind (since, k[, collection]) while the count
            # queries bind (since[, collection]), so the collection
            # placeholder number differs. Each query gets a filter that
            # matches its own argument list; reusing the $3 filter with two
            # arguments made the count queries fail.
            rank_filter = ""
            rank_params = [since, k]
            count_filter = ""
            count_params = [since]
            if collection_name:
                rank_filter = "AND collection_name = $3"
                rank_params.append(collection_name)
                count_filter = "AND collection_name = $2"
                count_params.append(collection_name)

            # Get precision@k per query, then average
            precision_result = await conn.fetchval(
                f"""
                WITH query_precision AS (
                    SELECT
                        query_id,
                        COUNT(CASE WHEN is_relevant THEN 1 END)::FLOAT /
                        GREATEST(COUNT(*), 1) as precision
                    FROM rag_relevance_judgments
                    WHERE created_at >= $1
                    AND (result_rank IS NULL OR result_rank <= $2)
                    {rank_filter}
                    GROUP BY query_id
                )
                SELECT AVG(precision) FROM query_precision
                """,
                *rank_params
            ) or 0

            # Get recall@k per query, then average
            recall_result = await conn.fetchval(
                f"""
                WITH query_recall AS (
                    SELECT
                        query_id,
                        COUNT(CASE WHEN is_relevant AND (result_rank IS NULL OR result_rank <= $2) THEN 1 END)::FLOAT /
                        GREATEST(COUNT(CASE WHEN is_relevant THEN 1 END), 1) as recall
                    FROM rag_relevance_judgments
                    WHERE created_at >= $1
                    {rank_filter}
                    GROUP BY query_id
                )
                SELECT AVG(recall) FROM query_recall
                """,
                *rank_params
            ) or 0

            # Total judgments
            total_judgments = await conn.fetchval(
                f"""
                SELECT COUNT(*) FROM rag_relevance_judgments
                WHERE created_at >= $1 {count_filter}
                """,
                *count_params
            )

            # Unique queries
            unique_queries = await conn.fetchval(
                f"""
                SELECT COUNT(DISTINCT query_id) FROM rag_relevance_judgments
                WHERE created_at >= $1 {count_filter}
                """,
                *count_params
            )

            return {
                "connected": True,
                "period_days": days,
                "k": k,
                "precision_at_k": round(precision_result, 3),
                "recall_at_k": round(recall_result, 3),
                # max() guards against division by zero when both are 0.
                "f1_score": round(
                    2 * precision_result * recall_result / max(precision_result + recall_result, 0.001), 3
                ),
                "total_judgments": total_judgments or 0,
                "unique_queries": unique_queries or 0,
            }
    except Exception as e:
        print(f"Failed to calculate precision/recall: {e}")
        return {"error": str(e), "connected": False}
# =============================================================================
# Zeugnis Database Operations
# =============================================================================
async def get_zeugnis_sources() -> List[Dict]:
    """Return all rows of ``zeugnis_sources`` ordered by ``bundesland``.

    Returns:
        One dict per source; an empty list when no pool is available or the
        query raised (the error is printed).
    """
    db = await get_pool()
    if db is None:
        return []

    try:
        async with db.acquire() as connection:
            records = await connection.fetch(
                """
                SELECT id, bundesland, name, base_url, license_type, training_allowed,
                verified_by, verified_at, created_at, updated_at
                FROM zeugnis_sources
                ORDER BY bundesland
                """
            )
            return list(map(dict, records))
    except Exception as exc:
        print(f"Failed to get zeugnis sources: {exc}")
        return []
async def upsert_zeugnis_source(
    id: str,
    bundesland: str,
    name: str,
    license_type: str,
    training_allowed: bool,
    base_url: Optional[str] = None,
    verified_by: Optional[str] = None,
) -> bool:
    """Insert a Zeugnis source, or update the existing row with the same ``id``.

    On conflict the statement refreshes name, base_url, license_type,
    training_allowed, verified_by and the verified/updated timestamps;
    ``bundesland`` is not in the update list.

    Returns:
        ``True`` when the statement executed; ``False`` when no pool is
        available or it raised (the error is printed).
    """
    db = await get_pool()
    if db is None:
        return False

    upsert_sql = """
        INSERT INTO zeugnis_sources (id, bundesland, name, base_url, license_type, training_allowed, verified_by, verified_at)
        VALUES ($1, $2, $3, $4, $5, $6, $7, NOW())
        ON CONFLICT (id) DO UPDATE SET
        name = EXCLUDED.name,
        base_url = EXCLUDED.base_url,
        license_type = EXCLUDED.license_type,
        training_allowed = EXCLUDED.training_allowed,
        verified_by = EXCLUDED.verified_by,
        verified_at = NOW(),
        updated_at = NOW()
        """
    # Column order differs from the parameter order: base_url is bound as $4.
    values = (id, bundesland, name, base_url, license_type, training_allowed, verified_by)
    try:
        async with db.acquire() as connection:
            await connection.execute(upsert_sql, *values)
    except Exception as exc:
        print(f"Failed to upsert zeugnis source: {exc}")
        return False
    return True
async def get_zeugnis_documents(
    bundesland: Optional[str] = None,
    limit: int = 100,
    offset: int = 0,
) -> List[Dict]:
    """Return Zeugnis documents joined with their source, newest first.

    Args:
        bundesland: Only return documents of this Bundesland when given.
        limit: Maximum number of rows to return.
        offset: Number of rows to skip.

    Returns:
        One dict per document, including ``bundesland`` and ``source_name``;
        an empty list when no pool is available or the query raised (the
        error is printed).
    """
    db = await get_pool()
    if db is None:
        return []

    # Choose the statement and its arguments first, then run a single fetch.
    if bundesland:
        query = """
            SELECT d.*, s.bundesland, s.name as source_name
            FROM zeugnis_documents d
            JOIN zeugnis_seed_urls u ON d.seed_url_id = u.id
            JOIN zeugnis_sources s ON u.source_id = s.id
            WHERE s.bundesland = $1
            ORDER BY d.created_at DESC
            LIMIT $2 OFFSET $3
            """
        args = (bundesland, limit, offset)
    else:
        query = """
            SELECT d.*, s.bundesland, s.name as source_name
            FROM zeugnis_documents d
            JOIN zeugnis_seed_urls u ON d.seed_url_id = u.id
            JOIN zeugnis_sources s ON u.source_id = s.id
            ORDER BY d.created_at DESC
            LIMIT $1 OFFSET $2
            """
        args = (limit, offset)

    try:
        async with db.acquire() as connection:
            records = await connection.fetch(query, *args)
            return list(map(dict, records))
    except Exception as exc:
        print(f"Failed to get zeugnis documents: {exc}")
        return []
async def get_zeugnis_stats() -> Dict:
    """Collect Zeugnis crawler statistics.

    Returns:
        A dict with total counts, the number of crawls in status
        ``'running'`` and a per-Bundesland breakdown; ``{"error": ...}``
        when no pool is available or a query raised.
    """
    db = await get_pool()
    if db is None:
        return {"error": "Database not available"}

    try:
        async with db.acquire() as connection:

            async def count(sql: str) -> int:
                # A missing count is reported as 0.
                return (await connection.fetchval(sql)) or 0

            total_sources = await count("SELECT COUNT(*) FROM zeugnis_sources")
            total_documents = await count("SELECT COUNT(*) FROM zeugnis_documents")
            indexed_documents = await count(
                "SELECT COUNT(*) FROM zeugnis_documents WHERE indexed_in_qdrant = true"
            )
            trainable_documents = await count(
                "SELECT COUNT(*) FROM zeugnis_documents WHERE training_allowed = true"
            )

            # LEFT JOINs keep sources that have no documents yet (doc_count 0).
            breakdown = await connection.fetch(
                """
                SELECT s.bundesland, s.name, s.training_allowed, COUNT(d.id) as doc_count
                FROM zeugnis_sources s
                LEFT JOIN zeugnis_seed_urls u ON s.id = u.source_id
                LEFT JOIN zeugnis_documents d ON u.id = d.seed_url_id
                GROUP BY s.bundesland, s.name, s.training_allowed
                ORDER BY s.bundesland
                """
            )

            running_crawls = await count(
                "SELECT COUNT(*) FROM zeugnis_crawler_queue WHERE status = 'running'"
            )

            return {
                "total_sources": total_sources,
                "total_documents": total_documents,
                "indexed_documents": indexed_documents,
                "training_allowed_documents": trainable_documents,
                "active_crawls": running_crawls,
                "per_bundesland": list(map(dict, breakdown)),
            }
    except Exception as exc:
        print(f"Failed to get zeugnis stats: {exc}")
        return {"error": str(exc)}
async def log_zeugnis_event(
    document_id: str,
    event_type: str,
    user_id: Optional[str] = None,
    details: Optional[Dict] = None,
) -> bool:
    """Write one audit-trail row to ``zeugnis_usage_events``.

    A fresh UUID4 string is used as the row id; ``details`` is stored as a
    JSON string, with a missing or empty dict stored as NULL.

    Returns:
        ``True`` when the row was inserted; ``False`` when no pool is
        available or serialization/insert raised (the error is printed).
    """
    db = await get_pool()
    if db is None:
        return False

    insert_sql = """
        INSERT INTO zeugnis_usage_events (id, document_id, event_type, user_id, details)
        VALUES ($1, $2, $3, $4, $5)
        """
    try:
        import json
        import uuid

        async with db.acquire() as connection:
            event_id = str(uuid.uuid4())
            details_json = None
            if details:
                details_json = json.dumps(details)
            await connection.execute(
                insert_sql,
                event_id, document_id, event_type, user_id,
                details_json,
            )
    except Exception as exc:
        print(f"Failed to log zeugnis event: {exc}")
        return False
    return True
# Schema: table initialization
from metrics_db_schema import init_metrics_tables # noqa: F401
# Core: pool, feedback, search logs, metrics, relevance
from metrics_db_core import ( # noqa: F401
DATABASE_URL,
get_pool,
store_feedback,
log_search,
log_upload,
calculate_metrics,
get_recent_feedback,
get_upload_history,
store_relevance_judgment,
calculate_precision_recall,
)
# Zeugnis operations
from metrics_db_zeugnis import ( # noqa: F401
get_zeugnis_sources,
upsert_zeugnis_source,
get_zeugnis_documents,
get_zeugnis_stats,
log_zeugnis_event,
)
+459
View File
@@ -0,0 +1,459 @@
"""
PostgreSQL Metrics Database - Core Operations
Connection pool, feedback storage, search logging, upload history,
metrics calculation, and relevance judgments. Table initialization (DDL)
lives in metrics_db_schema.py.
Extracted from metrics_db.py to keep files under 500 LOC.
"""
import os
from typing import Optional, List, Dict
from datetime import datetime, timedelta
# Database Configuration - uses test default if not configured (for CI)
DATABASE_URL = os.getenv("DATABASE_URL", "postgresql://test:test@localhost:5432/test_metrics")
# Connection pool
_pool = None
async def get_pool():
    """Return the module-wide asyncpg pool, building it lazily.

    Returns:
        The cached pool, or ``None`` when asyncpg cannot be imported or the
        pool cannot be created (a warning is printed in both cases).
    """
    global _pool
    if _pool is not None:
        # Already created by an earlier call.
        return _pool

    try:
        import asyncpg
        _pool = await asyncpg.create_pool(DATABASE_URL, min_size=2, max_size=10)
    except ImportError:
        print("Warning: asyncpg not installed. Metrics storage disabled.")
        return None
    except Exception as err:
        print(f"Warning: Failed to connect to PostgreSQL: {err}")
        return None
    return _pool
# =============================================================================
# Feedback Storage
# =============================================================================
async def store_feedback(
    result_id: str,
    rating: int,
    query_text: Optional[str] = None,
    collection_name: Optional[str] = None,
    score: Optional[float] = None,
    notes: Optional[str] = None,
    user_id: Optional[str] = None,
) -> bool:
    """Save a rating for a single search result in ``rag_search_feedback``.

    Returns:
        ``True`` when the row was inserted; ``False`` when no pool is
        available or the insert raised (the error is printed).
    """
    db_pool = await get_pool()
    if db_pool is None:
        return False

    statement = """
        INSERT INTO rag_search_feedback
        (result_id, query_text, collection_name, score, rating, notes, user_id)
        VALUES ($1, $2, $3, $4, $5, $6, $7)
        """
    try:
        async with db_pool.acquire() as connection:
            await connection.execute(
                statement,
                result_id, query_text, collection_name, score, rating, notes, user_id,
            )
    except Exception as err:
        print(f"Failed to store feedback: {err}")
        return False
    return True
async def log_search(
    query_text: str,
    collection_name: str,
    result_count: int,
    latency_ms: int,
    top_score: Optional[float] = None,
    filters: Optional[Dict] = None,
) -> bool:
    """Append one search to ``rag_search_logs`` for metrics tracking.

    ``filters`` is written as a JSON string; a missing or empty dict is
    written as NULL.

    Returns:
        ``True`` when the row was inserted; ``False`` when no pool is
        available or serialization/insert raised (the error is printed).
    """
    db_pool = await get_pool()
    if db_pool is None:
        return False

    statement = """
        INSERT INTO rag_search_logs
        (query_text, collection_name, result_count, latency_ms, top_score, filters)
        VALUES ($1, $2, $3, $4, $5, $6)
        """
    try:
        import json

        # Kept inside the try so an unserializable filter also yields False.
        encoded_filters = json.dumps(filters) if filters else None
        async with db_pool.acquire() as connection:
            await connection.execute(
                statement,
                query_text, collection_name, result_count, latency_ms, top_score,
                encoded_filters,
            )
    except Exception as err:
        print(f"Failed to log search: {err}")
        return False
    return True
async def log_upload(
    filename: str,
    collection_name: str,
    year: int,
    pdfs_extracted: int,
    minio_path: Optional[str] = None,
    uploaded_by: Optional[str] = None,
) -> bool:
    """Append one upload to ``rag_upload_history``.

    Returns:
        ``True`` when the row was inserted; ``False`` when no pool is
        available or the insert raised (the error is printed).
    """
    db_pool = await get_pool()
    if db_pool is None:
        return False

    statement = """
        INSERT INTO rag_upload_history
        (filename, collection_name, year, pdfs_extracted, minio_path, uploaded_by)
        VALUES ($1, $2, $3, $4, $5, $6)
        """
    try:
        async with db_pool.acquire() as connection:
            await connection.execute(
                statement,
                filename, collection_name, year, pdfs_extracted, minio_path, uploaded_by,
            )
    except Exception as err:
        print(f"Failed to log upload: {err}")
        return False
    return True
# =============================================================================
# Metrics Calculation
# =============================================================================
async def calculate_metrics(
    collection_name: Optional[str] = None,
    days: int = 7,
) -> Dict:
    """Calculate RAG quality metrics from stored feedback and search logs.

    The values are rating-based proxies, not true IR metrics:
    ``precision_at_5`` is the share of ratings >= 4, ``mrr`` is the average
    rating divided by 5, and ``recall_at_10`` is an estimate derived from
    ``precision_at_5`` (capped at 1.0).

    Args:
        collection_name: Restrict all queries to this collection when given.
        days: Size of the look-back window in days.

    Returns:
        A dict with ``connected: True`` and the metric fields on success, or
        ``{"error": ..., "connected": False}`` when no pool is available or
        a query raised.
    """
    pool = await get_pool()
    if pool is None:
        return {"error": "Database not available", "connected": False}

    try:
        async with pool.acquire() as conn:
            since = datetime.now() - timedelta(days=days)

            # Only this fixed fragment is interpolated into the SQL; the
            # collection name itself is always bound as parameter $2.
            collection_filter = ""
            params = [since]
            if collection_name:
                collection_filter = "AND collection_name = $2"
                params.append(collection_name)

            total_feedback = await conn.fetchval(
                f"""
                SELECT COUNT(*) FROM rag_search_feedback
                WHERE created_at >= $1 {collection_filter}
                """,
                *params
            )

            rating_dist = await conn.fetch(
                f"""
                SELECT rating, COUNT(*) as count
                FROM rag_search_feedback
                WHERE created_at >= $1 {collection_filter}
                GROUP BY rating
                ORDER BY rating DESC
                """,
                *params
            )

            avg_rating = await conn.fetchval(
                f"""
                SELECT AVG(rating) FROM rag_search_feedback
                WHERE created_at >= $1 {collection_filter}
                """,
                *params
            )

            score_dist = await conn.fetch(
                f"""
                SELECT
                    CASE
                        WHEN score >= 0.9 THEN '0.9+'
                        WHEN score >= 0.7 THEN '0.7-0.9'
                        WHEN score >= 0.5 THEN '0.5-0.7'
                        ELSE '<0.5'
                    END as range,
                    COUNT(*) as count
                FROM rag_search_feedback
                WHERE created_at >= $1 AND score IS NOT NULL {collection_filter}
                GROUP BY range
                ORDER BY range DESC
                """,
                *params
            )

            # rag_search_logs uses the same column name, so the same filter applies.
            latency_stats = await conn.fetchrow(
                f"""
                SELECT
                    AVG(latency_ms) as avg_latency,
                    COUNT(*) as total_searches,
                    AVG(result_count) as avg_results
                FROM rag_search_logs
                WHERE created_at >= $1 {collection_filter}
                """,
                *params
            )

            # Share of feedback rated 4+ (proxy for precision@5).
            precision_at_5 = float(await conn.fetchval(
                f"""
                SELECT
                    CASE WHEN COUNT(*) > 0
                    THEN CAST(SUM(CASE WHEN rating >= 4 THEN 1 ELSE 0 END) AS FLOAT) / COUNT(*)
                    ELSE 0 END
                FROM rag_search_feedback
                WHERE created_at >= $1 {collection_filter}
                """,
                *params
            ) or 0)

            # AVG() over an INTEGER column is NUMERIC, which asyncpg returns
            # as Decimal; convert first because Decimal / float raises TypeError.
            mrr = float(avg_rating or 0) / 5.0

            # Ratings of 1 or 2 count as errors.
            error_count = sum(
                r['count'] for r in rating_dist if r['rating'] and r['rating'] <= 2
            )
            error_rate = (error_count / total_feedback * 100) if total_feedback > 0 else 0

            total_scored = sum(s['count'] for s in score_dist)
            score_distribution = {
                s['range']: round(s['count'] / total_scored * 100) if total_scored > 0 else 0
                for s in score_dist
            }

            return {
                "connected": True,
                "period_days": days,
                "precision_at_5": round(precision_at_5, 2),
                # Estimated; capped because a recall above 1.0 is meaningless.
                "recall_at_10": round(min(precision_at_5 * 1.1, 1.0), 2),
                "mrr": round(mrr, 2),
                "avg_latency_ms": round(latency_stats['avg_latency'] or 0),
                "total_ratings": total_feedback,
                "total_searches": latency_stats['total_searches'] or 0,
                "error_rate": round(error_rate, 1),
                "score_distribution": score_distribution,
                "rating_distribution": {
                    str(r['rating']): r['count'] for r in rating_dist if r['rating']
                },
            }
    except Exception as e:
        print(f"Failed to calculate metrics: {e}")
        return {"error": str(e), "connected": False}
async def get_recent_feedback(limit: int = 20) -> List[Dict]:
    """Fetch the latest feedback rows, newest first.

    Args:
        limit: Maximum number of rows to return.

    Returns:
        A list of dicts with ``created_at`` rendered as an ISO string (or
        ``None``); an empty list when no pool is available or the query
        raised (the error is printed).
    """
    db_pool = await get_pool()
    if db_pool is None:
        return []

    def to_entry(record) -> Dict:
        # Timestamps are rendered as ISO strings; a missing one stays None.
        created = record['created_at']
        return {
            "result_id": record['result_id'],
            "rating": record['rating'],
            "query_text": record['query_text'],
            "collection_name": record['collection_name'],
            "score": record['score'],
            "notes": record['notes'],
            "created_at": created.isoformat() if created else None,
        }

    try:
        async with db_pool.acquire() as connection:
            records = await connection.fetch(
                """
                SELECT result_id, rating, query_text, collection_name, score, notes, created_at
                FROM rag_search_feedback
                ORDER BY created_at DESC
                LIMIT $1
                """,
                limit
            )
            return list(map(to_entry, records))
    except Exception as err:
        print(f"Failed to get recent feedback: {err}")
        return []
async def get_upload_history(limit: int = 20) -> List[Dict]:
    """Fetch the latest upload-history rows, newest first.

    Args:
        limit: Maximum number of rows to return.

    Returns:
        A list of dicts with ``created_at`` rendered as an ISO string (or
        ``None``); an empty list when no pool is available or the query
        raised (the error is printed).
    """
    db_pool = await get_pool()
    if db_pool is None:
        return []

    def to_entry(record) -> Dict:
        # Timestamps are rendered as ISO strings; a missing one stays None.
        created = record['created_at']
        return {
            "filename": record['filename'],
            "collection_name": record['collection_name'],
            "year": record['year'],
            "pdfs_extracted": record['pdfs_extracted'],
            "minio_path": record['minio_path'],
            "uploaded_by": record['uploaded_by'],
            "created_at": created.isoformat() if created else None,
        }

    try:
        async with db_pool.acquire() as connection:
            records = await connection.fetch(
                """
                SELECT filename, collection_name, year, pdfs_extracted, minio_path, uploaded_by, created_at
                FROM rag_upload_history
                ORDER BY created_at DESC
                LIMIT $1
                """,
                limit
            )
            return list(map(to_entry, records))
    except Exception as err:
        print(f"Failed to get upload history: {err}")
        return []
# =============================================================================
# Relevance Judgments (Binary Precision/Recall)
# =============================================================================
async def store_relevance_judgment(
    query_id: str,
    query_text: str,
    result_id: str,
    is_relevant: bool,
    result_rank: Optional[int] = None,
    collection_name: Optional[str] = None,
    user_id: Optional[str] = None,
) -> bool:
    """Insert one binary relevance judgment into ``rag_relevance_judgments``.

    Args:
        query_id: Identifier grouping all judgments of one query.
        query_text: The query string that was judged.
        result_id: Identifier of the judged result.
        is_relevant: Whether the result was judged relevant.
        result_rank: Rank of the result in the result list, if known.
        collection_name: Collection the result came from, if known.
        user_id: Identifier of the judging user, if known.

    Returns:
        ``True`` when the INSERT statement ran without raising, ``False``
        when no pool is available or the statement failed.
    """
    pool = await get_pool()
    if pool is None:
        return False

    # NOTE(review): the table DDL visible alongside this module declares no
    # unique constraint besides the serial id, so ON CONFLICT DO NOTHING
    # presumably never triggers and repeated judgments accumulate - confirm.
    statement = """
        INSERT INTO rag_relevance_judgments
        (query_id, query_text, result_id, result_rank, is_relevant, collection_name, user_id)
        VALUES ($1, $2, $3, $4, $5, $6, $7)
        ON CONFLICT DO NOTHING
    """
    # Order must match the column list of the INSERT above.
    values = (query_id, query_text, result_id, result_rank, is_relevant, collection_name, user_id)

    try:
        async with pool.acquire() as conn:
            await conn.execute(statement, *values)
    except Exception as e:
        print(f"Failed to store relevance judgment: {e}")
        return False
    return True
async def calculate_precision_recall(
    collection_name: Optional[str] = None,
    days: int = 7,
    k: int = 10,
) -> Dict:
    """
    Calculate Precision@k and Recall@k from binary relevance judgments.

    Per query, precision is the share of relevant judgments among the judged
    results within the top ``k`` (rows without a rank count as inside the
    cut-off); recall is the share of a query's relevant judgments that fall
    inside the top ``k``. Both are averaged over all queries in the window.

    Args:
        collection_name: Restrict the calculation to one collection.
        days: Size of the look-back window in days.
        k: Rank cut-off.

    Returns:
        A dict with ``precision_at_k``, ``recall_at_k``, ``f1_score``,
        ``total_judgments`` and ``unique_queries``; or
        ``{"error": ..., "connected": False}`` when the database is
        unavailable or a query fails.
    """
    pool = await get_pool()
    if pool is None:
        return {"error": "Database not available", "connected": False}

    try:
        async with pool.acquire() as conn:
            since = datetime.now() - timedelta(days=days)

            # The ranked queries bind ($1=since, $2=k) and the count queries
            # bind only ($1=since), so the optional collection filter needs a
            # different placeholder index in each. Reusing "$3" for the count
            # queries made every collection-filtered call fail.
            ranked_filter = ""
            count_filter = ""
            ranked_args = [since, k]
            count_args = [since]
            if collection_name:
                ranked_filter = "AND collection_name = $3"
                count_filter = "AND collection_name = $2"
                ranked_args.append(collection_name)
                count_args.append(collection_name)

            precision_result = await conn.fetchval(
                f"""
                WITH query_precision AS (
                    SELECT
                        query_id,
                        COUNT(CASE WHEN is_relevant THEN 1 END)::FLOAT /
                        GREATEST(COUNT(*), 1) as precision
                    FROM rag_relevance_judgments
                    WHERE created_at >= $1
                    AND (result_rank IS NULL OR result_rank <= $2)
                    {ranked_filter}
                    GROUP BY query_id
                )
                SELECT AVG(precision) FROM query_precision
                """,
                *ranked_args
            ) or 0

            recall_result = await conn.fetchval(
                f"""
                WITH query_recall AS (
                    SELECT
                        query_id,
                        COUNT(CASE WHEN is_relevant AND (result_rank IS NULL OR result_rank <= $2) THEN 1 END)::FLOAT /
                        GREATEST(COUNT(CASE WHEN is_relevant THEN 1 END), 1) as recall
                    FROM rag_relevance_judgments
                    WHERE created_at >= $1
                    {ranked_filter}
                    GROUP BY query_id
                )
                SELECT AVG(recall) FROM query_recall
                """,
                *ranked_args
            ) or 0

            total_judgments = await conn.fetchval(
                f"""
                SELECT COUNT(*) FROM rag_relevance_judgments
                WHERE created_at >= $1 {count_filter}
                """,
                *count_args
            )

            unique_queries = await conn.fetchval(
                f"""
                SELECT COUNT(DISTINCT query_id) FROM rag_relevance_judgments
                WHERE created_at >= $1 {count_filter}
                """,
                *count_args
            )

            return {
                "connected": True,
                "period_days": days,
                "k": k,
                "precision_at_k": round(precision_result, 3),
                "recall_at_k": round(recall_result, 3),
                # The denominator is floored to avoid dividing by zero when
                # both metrics are 0.
                "f1_score": round(
                    2 * precision_result * recall_result / max(precision_result + recall_result, 0.001), 3
                ),
                "total_judgments": total_judgments or 0,
                "unique_queries": unique_queries or 0,
            }
    except Exception as e:
        print(f"Failed to calculate precision/recall: {e}")
        return {"error": str(e), "connected": False}
@@ -0,0 +1,182 @@
"""
PostgreSQL Metrics Database - Schema Initialization
Table creation DDL for all metrics, feedback, and zeugnis tables.
Extracted from metrics_db_core.py to keep files under 500 LOC.
"""
from metrics_db_core import get_pool
async def init_metrics_tables() -> bool:
    """Initialize metrics tables in PostgreSQL.

    Sends one multi-statement DDL script that creates the RAG feedback,
    search-log, upload-history and relevance-judgment tables as well as the
    zeugnis source/seed/document/version/event/queue tables, each with its
    indexes. Every statement uses ``IF NOT EXISTS``.

    Returns:
        ``True`` when the script executed without raising, ``False`` when no
        pool is available or execution failed.
    """
    pool = await get_pool()
    if pool is None:
        return False
    # Statement order matters: tables referenced by a foreign key
    # (zeugnis_sources -> zeugnis_seed_urls -> zeugnis_documents) are created
    # before the tables that reference them.
    create_tables_sql = """
    -- RAG Search Feedback Table
    CREATE TABLE IF NOT EXISTS rag_search_feedback (
    id SERIAL PRIMARY KEY,
    result_id VARCHAR(255) NOT NULL,
    query_text TEXT,
    collection_name VARCHAR(100),
    score FLOAT,
    rating INTEGER CHECK (rating >= 1 AND rating <= 5),
    notes TEXT,
    user_id VARCHAR(100),
    created_at TIMESTAMP DEFAULT NOW()
    );
    -- Index for efficient querying
    CREATE INDEX IF NOT EXISTS idx_feedback_created_at ON rag_search_feedback(created_at);
    CREATE INDEX IF NOT EXISTS idx_feedback_collection ON rag_search_feedback(collection_name);
    CREATE INDEX IF NOT EXISTS idx_feedback_rating ON rag_search_feedback(rating);
    -- RAG Search Logs Table (for latency tracking)
    CREATE TABLE IF NOT EXISTS rag_search_logs (
    id SERIAL PRIMARY KEY,
    query_text TEXT NOT NULL,
    collection_name VARCHAR(100),
    result_count INTEGER,
    latency_ms INTEGER,
    top_score FLOAT,
    filters JSONB,
    created_at TIMESTAMP DEFAULT NOW()
    );
    CREATE INDEX IF NOT EXISTS idx_search_logs_created_at ON rag_search_logs(created_at);
    -- RAG Upload History Table
    CREATE TABLE IF NOT EXISTS rag_upload_history (
    id SERIAL PRIMARY KEY,
    filename VARCHAR(500) NOT NULL,
    collection_name VARCHAR(100),
    year INTEGER,
    pdfs_extracted INTEGER,
    minio_path VARCHAR(1000),
    uploaded_by VARCHAR(100),
    created_at TIMESTAMP DEFAULT NOW()
    );
    CREATE INDEX IF NOT EXISTS idx_upload_history_created_at ON rag_upload_history(created_at);
    -- Binaere Relevanz-Judgments fuer echte Precision/Recall
    CREATE TABLE IF NOT EXISTS rag_relevance_judgments (
    id SERIAL PRIMARY KEY,
    query_id VARCHAR(255) NOT NULL,
    query_text TEXT NOT NULL,
    result_id VARCHAR(255) NOT NULL,
    result_rank INTEGER,
    is_relevant BOOLEAN NOT NULL,
    collection_name VARCHAR(100),
    user_id VARCHAR(100),
    created_at TIMESTAMP DEFAULT NOW()
    );
    CREATE INDEX IF NOT EXISTS idx_relevance_query ON rag_relevance_judgments(query_id);
    CREATE INDEX IF NOT EXISTS idx_relevance_created_at ON rag_relevance_judgments(created_at);
    -- Zeugnisse Source Tracking
    CREATE TABLE IF NOT EXISTS zeugnis_sources (
    id VARCHAR(36) PRIMARY KEY,
    bundesland VARCHAR(10) NOT NULL,
    name VARCHAR(255) NOT NULL,
    base_url TEXT,
    license_type VARCHAR(50) NOT NULL,
    training_allowed BOOLEAN DEFAULT FALSE,
    verified_by VARCHAR(100),
    verified_at TIMESTAMP,
    created_at TIMESTAMP DEFAULT NOW(),
    updated_at TIMESTAMP DEFAULT NOW()
    );
    CREATE INDEX IF NOT EXISTS idx_zeugnis_sources_bundesland ON zeugnis_sources(bundesland);
    -- Zeugnisse Seed URLs
    CREATE TABLE IF NOT EXISTS zeugnis_seed_urls (
    id VARCHAR(36) PRIMARY KEY,
    source_id VARCHAR(36) REFERENCES zeugnis_sources(id),
    url TEXT NOT NULL,
    doc_type VARCHAR(50),
    status VARCHAR(20) DEFAULT 'pending',
    last_crawled TIMESTAMP,
    error_message TEXT,
    created_at TIMESTAMP DEFAULT NOW()
    );
    CREATE INDEX IF NOT EXISTS idx_zeugnis_seed_urls_source ON zeugnis_seed_urls(source_id);
    CREATE INDEX IF NOT EXISTS idx_zeugnis_seed_urls_status ON zeugnis_seed_urls(status);
    -- Zeugnisse Documents
    CREATE TABLE IF NOT EXISTS zeugnis_documents (
    id VARCHAR(36) PRIMARY KEY,
    seed_url_id VARCHAR(36) REFERENCES zeugnis_seed_urls(id),
    title VARCHAR(500),
    url TEXT NOT NULL,
    content_hash VARCHAR(64),
    minio_path TEXT,
    training_allowed BOOLEAN DEFAULT FALSE,
    indexed_in_qdrant BOOLEAN DEFAULT FALSE,
    file_size INTEGER,
    content_type VARCHAR(100),
    created_at TIMESTAMP DEFAULT NOW(),
    updated_at TIMESTAMP DEFAULT NOW()
    );
    CREATE INDEX IF NOT EXISTS idx_zeugnis_documents_seed ON zeugnis_documents(seed_url_id);
    CREATE INDEX IF NOT EXISTS idx_zeugnis_documents_hash ON zeugnis_documents(content_hash);
    -- Zeugnisse Document Versions
    CREATE TABLE IF NOT EXISTS zeugnis_document_versions (
    id VARCHAR(36) PRIMARY KEY,
    document_id VARCHAR(36) REFERENCES zeugnis_documents(id),
    version INTEGER NOT NULL,
    content_hash VARCHAR(64),
    minio_path TEXT,
    change_summary TEXT,
    created_at TIMESTAMP DEFAULT NOW()
    );
    CREATE INDEX IF NOT EXISTS idx_zeugnis_versions_doc ON zeugnis_document_versions(document_id);
    -- Zeugnisse Usage Events (Audit Trail)
    CREATE TABLE IF NOT EXISTS zeugnis_usage_events (
    id VARCHAR(36) PRIMARY KEY,
    document_id VARCHAR(36) REFERENCES zeugnis_documents(id),
    event_type VARCHAR(50) NOT NULL,
    user_id VARCHAR(100),
    details JSONB,
    created_at TIMESTAMP DEFAULT NOW()
    );
    CREATE INDEX IF NOT EXISTS idx_zeugnis_events_doc ON zeugnis_usage_events(document_id);
    CREATE INDEX IF NOT EXISTS idx_zeugnis_events_type ON zeugnis_usage_events(event_type);
    CREATE INDEX IF NOT EXISTS idx_zeugnis_events_created ON zeugnis_usage_events(created_at);
    -- Crawler Queue
    CREATE TABLE IF NOT EXISTS zeugnis_crawler_queue (
    id VARCHAR(36) PRIMARY KEY,
    source_id VARCHAR(36) REFERENCES zeugnis_sources(id),
    priority INTEGER DEFAULT 5,
    status VARCHAR(20) DEFAULT 'pending',
    started_at TIMESTAMP,
    completed_at TIMESTAMP,
    documents_found INTEGER DEFAULT 0,
    documents_indexed INTEGER DEFAULT 0,
    error_count INTEGER DEFAULT 0,
    created_at TIMESTAMP DEFAULT NOW()
    );
    CREATE INDEX IF NOT EXISTS idx_crawler_queue_status ON zeugnis_crawler_queue(status);
    """
    try:
        async with pool.acquire() as conn:
            # The whole script is sent in a single execute() call.
            await conn.execute(create_tables_sql)
        print("RAG metrics tables initialized")
        return True
    except Exception as e:
        print(f"Failed to initialize metrics tables: {e}")
        return False
@@ -0,0 +1,193 @@
"""
PostgreSQL Metrics Database - Zeugnis Operations
Zeugnis source management, document queries, statistics, and event logging.
Extracted from metrics_db.py to keep files under 500 LOC.
"""
from typing import Optional, List, Dict
from metrics_db_core import get_pool
# =============================================================================
# Zeugnis Database Operations
# =============================================================================
async def get_zeugnis_sources() -> List[Dict]:
    """Return every row of ``zeugnis_sources`` ordered by ``bundesland``.

    Returns:
        One dict per source row; an empty list when no pool is available or
        the query raises.
    """
    pool = await get_pool()
    if pool is None:
        return []

    query = """
        SELECT id, bundesland, name, base_url, license_type, training_allowed,
        verified_by, verified_at, created_at, updated_at
        FROM zeugnis_sources
        ORDER BY bundesland
    """
    try:
        async with pool.acquire() as conn:
            records = await conn.fetch(query)
        sources = []
        for record in records:
            sources.append(dict(record))
        return sources
    except Exception as e:
        print(f"Failed to get zeugnis sources: {e}")
        return []
async def upsert_zeugnis_source(
    id: str,
    bundesland: str,
    name: str,
    license_type: str,
    training_allowed: bool,
    base_url: Optional[str] = None,
    verified_by: Optional[str] = None,
) -> bool:
    """Create a ``zeugnis_sources`` row or refresh the existing one.

    On an ``id`` conflict the name, base URL, license type, training flag
    and verifier are overwritten and ``verified_at``/``updated_at`` are set
    to ``NOW()``. ``bundesland`` is only written on insert; it is not part
    of the conflict update.

    Returns:
        ``True`` when the statement ran without raising, ``False`` when no
        pool is available or the statement failed.
    """
    pool = await get_pool()
    if pool is None:
        return False

    statement = """
        INSERT INTO zeugnis_sources (id, bundesland, name, base_url, license_type, training_allowed, verified_by, verified_at)
        VALUES ($1, $2, $3, $4, $5, $6, $7, NOW())
        ON CONFLICT (id) DO UPDATE SET
        name = EXCLUDED.name,
        base_url = EXCLUDED.base_url,
        license_type = EXCLUDED.license_type,
        training_allowed = EXCLUDED.training_allowed,
        verified_by = EXCLUDED.verified_by,
        verified_at = NOW(),
        updated_at = NOW()
    """
    # Positional order follows the placeholders, not the function signature.
    values = (id, bundesland, name, base_url, license_type, training_allowed, verified_by)

    try:
        async with pool.acquire() as conn:
            await conn.execute(statement, *values)
    except Exception as e:
        print(f"Failed to upsert zeugnis source: {e}")
        return False
    return True
async def get_zeugnis_documents(
    bundesland: Optional[str] = None,
    limit: int = 100,
    offset: int = 0,
) -> List[Dict]:
    """Return zeugnis documents joined with their source, newest first.

    Args:
        bundesland: When truthy, only documents whose source has this
            ``bundesland`` value are returned.
        limit: Maximum number of rows.
        offset: Number of rows to skip.

    Returns:
        One dict per document (all document columns plus ``bundesland`` and
        ``source_name``); an empty list when no pool is available or the
        query raises.
    """
    pool = await get_pool()
    if pool is None:
        return []

    # Select the statement and its positional arguments first so that a
    # single fetch call serves both variants.
    if bundesland:
        query = """
            SELECT d.*, s.bundesland, s.name as source_name
            FROM zeugnis_documents d
            JOIN zeugnis_seed_urls u ON d.seed_url_id = u.id
            JOIN zeugnis_sources s ON u.source_id = s.id
            WHERE s.bundesland = $1
            ORDER BY d.created_at DESC
            LIMIT $2 OFFSET $3
        """
        args = (bundesland, limit, offset)
    else:
        query = """
            SELECT d.*, s.bundesland, s.name as source_name
            FROM zeugnis_documents d
            JOIN zeugnis_seed_urls u ON d.seed_url_id = u.id
            JOIN zeugnis_sources s ON u.source_id = s.id
            ORDER BY d.created_at DESC
            LIMIT $1 OFFSET $2
        """
        args = (limit, offset)

    try:
        async with pool.acquire() as conn:
            records = await conn.fetch(query, *args)
        return [dict(record) for record in records]
    except Exception as e:
        print(f"Failed to get zeugnis documents: {e}")
        return []
async def get_zeugnis_stats() -> Dict:
    """Collect aggregate counters about zeugnis sources, documents and crawls.

    Returns:
        A dict with total/indexed/training-allowed document counts, the
        number of queue entries with status ``'running'`` and a per-source
        breakdown; or ``{"error": ...}`` when no pool is available or a
        query raises.
    """
    pool = await get_pool()
    if pool is None:
        return {"error": "Database not available"}

    # (result key, COUNT query) pairs, executed in this order.
    count_queries = (
        ("total_sources", "SELECT COUNT(*) FROM zeugnis_sources"),
        ("total_documents", "SELECT COUNT(*) FROM zeugnis_documents"),
        ("indexed_documents", "SELECT COUNT(*) FROM zeugnis_documents WHERE indexed_in_qdrant = true"),
        ("training_allowed_documents", "SELECT COUNT(*) FROM zeugnis_documents WHERE training_allowed = true"),
    )
    breakdown_query = """
        SELECT s.bundesland, s.name, s.training_allowed, COUNT(d.id) as doc_count
        FROM zeugnis_sources s
        LEFT JOIN zeugnis_seed_urls u ON s.id = u.source_id
        LEFT JOIN zeugnis_documents d ON u.id = d.seed_url_id
        GROUP BY s.bundesland, s.name, s.training_allowed
        ORDER BY s.bundesland
    """

    try:
        stats = {}
        async with pool.acquire() as conn:
            for key, sql in count_queries:
                stats[key] = (await conn.fetchval(sql)) or 0
            breakdown = await conn.fetch(breakdown_query)
            running = await conn.fetchval(
                "SELECT COUNT(*) FROM zeugnis_crawler_queue WHERE status = 'running'"
            )
        stats["active_crawls"] = running or 0
        stats["per_bundesland"] = [dict(row) for row in breakdown]
        return stats
    except Exception as e:
        print(f"Failed to get zeugnis stats: {e}")
        return {"error": str(e)}
async def log_zeugnis_event(
    document_id: str,
    event_type: str,
    user_id: Optional[str] = None,
    details: Optional[Dict] = None,
) -> bool:
    """Append one row to ``zeugnis_usage_events``.

    A fresh UUID4 string is generated as the row id. ``details`` is stored
    as a JSON string when truthy and as NULL otherwise.

    Returns:
        ``True`` when the INSERT ran without raising, ``False`` when no pool
        is available or anything (serialization included) failed.
    """
    pool = await get_pool()
    if pool is None:
        return False

    statement = """
        INSERT INTO zeugnis_usage_events (id, document_id, event_type, user_id, details)
        VALUES ($1, $2, $3, $4, $5)
    """
    try:
        import json
        import uuid

        # Serialization stays inside the try so that a non-serializable
        # payload is reported as False rather than raised.
        payload = None
        if details:
            payload = json.dumps(details)
        event_id = str(uuid.uuid4())

        async with pool.acquire() as conn:
            await conn.execute(statement, event_id, document_id, event_type, user_id, payload)
        return True
    except Exception as e:
        print(f"Failed to log zeugnis event: {e}")
        return False
+64 -828
View File
@@ -1,845 +1,81 @@
"""
OCR Labeling API for Handwriting Training Data Collection
OCR Labeling API — Barrel Re-export
DATENSCHUTZ/PRIVACY:
- Alle Verarbeitung erfolgt lokal (Mac Mini mit Ollama)
- Keine Daten werden an externe Server gesendet
- Bilder werden mit SHA256-Hash dedupliziert
- Export nur für lokales Fine-Tuning (TrOCR, llama3.2-vision)
Split into:
- ocr_labeling_models.py — Pydantic models and constants
- ocr_labeling_helpers.py — OCR wrappers, image storage, hashing
- ocr_labeling_routes.py — Session/queue/labeling route handlers
- ocr_labeling_upload_routes.py — Upload, run-OCR, export route handlers
Endpoints:
- POST /sessions - Create labeling session
- POST /sessions/{id}/upload - Upload images for labeling
- GET /queue - Get next items to label
- POST /confirm - Confirm OCR as correct
- POST /correct - Save corrected ground truth
- POST /skip - Skip unusable item
- GET /stats - Get labeling statistics
- POST /export - Export training data
All public names are re-exported here for backward compatibility.
"""
from fastapi import APIRouter, HTTPException, UploadFile, File, Form, Query, BackgroundTasks
from pydantic import BaseModel
from typing import Optional, List, Dict, Any
from datetime import datetime
import uuid
import hashlib
import os
import base64
# Import database functions
from metrics_db import (
create_ocr_labeling_session,
get_ocr_labeling_sessions,
get_ocr_labeling_session,
add_ocr_labeling_item,
get_ocr_labeling_queue,
get_ocr_labeling_item,
confirm_ocr_label,
correct_ocr_label,
skip_ocr_item,
get_ocr_labeling_stats,
export_training_samples,
get_training_samples,
# Models
from ocr_labeling_models import ( # noqa: F401
LOCAL_STORAGE_PATH,
SessionCreate,
SessionResponse,
ItemResponse,
ConfirmRequest,
CorrectRequest,
SkipRequest,
ExportRequest,
StatsResponse,
)
# Try to import Vision OCR service
try:
import sys
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..', 'backend', 'klausur', 'services'))
from vision_ocr_service import get_vision_ocr_service, VisionOCRService
VISION_OCR_AVAILABLE = True
except ImportError:
VISION_OCR_AVAILABLE = False
print("Warning: Vision OCR service not available")
# Helpers
from ocr_labeling_helpers import ( # noqa: F401
VISION_OCR_AVAILABLE,
PADDLEOCR_AVAILABLE,
TROCR_AVAILABLE,
DONUT_AVAILABLE,
MINIO_AVAILABLE,
TRAINING_EXPORT_AVAILABLE,
compute_image_hash,
run_ocr_on_image,
run_vision_ocr_wrapper,
run_paddleocr_wrapper,
run_trocr_wrapper,
run_donut_wrapper,
save_image_locally,
get_image_url,
)
# Try to import PaddleOCR from hybrid_vocab_extractor
# Conditional re-exports from helpers' optional imports
try:
from hybrid_vocab_extractor import run_paddle_ocr
PADDLEOCR_AVAILABLE = True
from minio_storage import upload_ocr_image, get_ocr_image, MINIO_BUCKET # noqa: F401
except ImportError:
PADDLEOCR_AVAILABLE = False
print("Warning: PaddleOCR not available")
pass
# Try to import TrOCR service
try:
from services.trocr_service import run_trocr_ocr
TROCR_AVAILABLE = True
except ImportError:
TROCR_AVAILABLE = False
print("Warning: TrOCR service not available")
# Try to import Donut service
try:
from services.donut_ocr_service import run_donut_ocr
DONUT_AVAILABLE = True
except ImportError:
DONUT_AVAILABLE = False
print("Warning: Donut OCR service not available")
# Try to import MinIO storage
try:
from minio_storage import upload_ocr_image, get_ocr_image, MINIO_BUCKET
MINIO_AVAILABLE = True
except ImportError:
MINIO_AVAILABLE = False
print("Warning: MinIO storage not available, using local storage")
# Try to import Training Export Service
try:
from training_export_service import (
from training_export_service import ( # noqa: F401
TrainingExportService,
TrainingSample,
get_training_export_service,
)
TRAINING_EXPORT_AVAILABLE = True
except ImportError:
TRAINING_EXPORT_AVAILABLE = False
print("Warning: Training export service not available")
# Router for the OCR labeling endpoints; every route below is registered
# under the /api/v1/ocr-label prefix.
router = APIRouter(prefix="/api/v1/ocr-label", tags=["OCR Labeling"])
# Local storage path (fallback if MinIO not available); can be overridden
# through the OCR_STORAGE_PATH environment variable.
LOCAL_STORAGE_PATH = os.getenv("OCR_STORAGE_PATH", "/app/ocr-labeling")
# =============================================================================
# Pydantic Models
# =============================================================================
# Request body accepted by the session-creation endpoint.
class SessionCreate(BaseModel):
    name: str
    source_type: str = "klausur"  # klausur, handwriting_sample, scan
    description: Optional[str] = None
    # OCR model identifier stored with the session; an explicit null is accepted.
    ocr_model: Optional[str] = "llama3.2-vision:11b"
# Response shape of the session endpoints (create, list, get).
class SessionResponse(BaseModel):
    id: str
    name: str
    source_type: str
    description: Optional[str]
    ocr_model: Optional[str]
    # Progress counters; the route handlers fall back to 0 when a counter is
    # missing from the database row.
    total_items: int
    labeled_items: int
    confirmed_items: int
    corrected_items: int
    skipped_items: int
    created_at: datetime
# Response shape for a single labeling item (queue and item endpoints).
class ItemResponse(BaseModel):
    id: str
    session_id: str
    session_name: str
    # Storage location as recorded in the database.
    image_path: str
    # Value derived from image_path via get_image_url().
    image_url: Optional[str]
    ocr_text: Optional[str]
    ocr_confidence: Optional[float]
    ground_truth: Optional[str]
    status: str
    metadata: Optional[Dict]
    created_at: datetime
# Request body for POST /confirm.
class ConfirmRequest(BaseModel):
    item_id: str
    # Optional labeling duration, passed through to confirm_ocr_label().
    label_time_seconds: Optional[int] = None
# Request body for POST /correct.
class CorrectRequest(BaseModel):
    item_id: str
    # Manually corrected text that replaces the OCR output as ground truth.
    ground_truth: str
    # Optional labeling duration, passed through to correct_ocr_label().
    label_time_seconds: Optional[int] = None
# Request body for POST /skip.
class SkipRequest(BaseModel):
    item_id: str
# Request body for POST /export.
class ExportRequest(BaseModel):
    export_format: str = "generic"  # generic, trocr, llama_vision
    # Optional filter: restrict the export to one session.
    session_id: Optional[str] = None
    # Optional batch identifier forwarded to the export functions.
    batch_id: Optional[str] = None
# Shape of labeling statistics.
# NOTE(review): the /stats route visible in this file returns the raw stats
# dict without declaring this model as response_model - confirm where it is used.
class StatsResponse(BaseModel):
    total_sessions: Optional[int] = None
    total_items: int
    labeled_items: int
    confirmed_items: int
    corrected_items: int
    pending_items: int
    exportable_items: Optional[int] = None
    accuracy_rate: float
    avg_label_time_seconds: Optional[float] = None
# =============================================================================
# Helper Functions
# =============================================================================
def compute_image_hash(image_data: bytes) -> str:
    """Return the hex-encoded SHA-256 digest of the raw image bytes."""
    hasher = hashlib.sha256()
    hasher.update(image_data)
    return hasher.hexdigest()
async def run_ocr_on_image(image_data: bytes, filename: str, model: str = "llama3.2-vision:11b") -> tuple:
    """
    Dispatch an image to the OCR backend selected by ``model``.

    Recognized model names:
    - paddleocr: PaddleOCR wrapper
    - donut: Donut wrapper
    - trocr: TrOCR wrapper
    Any other value (including the default llama3.2-vision:11b) is handled
    by the Vision LLM wrapper.

    Returns:
        Tuple of (ocr_text, confidence) as produced by the chosen wrapper.
    """
    print(f"Running OCR with model: {model}")

    # Guard clauses, one per dedicated backend; everything else falls
    # through to the Vision LLM.
    if model == "paddleocr":
        return await run_paddleocr_wrapper(image_data, filename)
    if model == "donut":
        return await run_donut_wrapper(image_data, filename)
    if model == "trocr":
        return await run_trocr_wrapper(image_data, filename)
    return await run_vision_ocr_wrapper(image_data, filename)
async def run_vision_ocr_wrapper(image_data: bytes, filename: str) -> tuple:
    """Run the Vision LLM OCR service on one image.

    Returns:
        ``(text, confidence)`` from the service result, or ``(None, 0.0)``
        when the service could not be imported, reports itself unavailable,
        or raises.
    """
    if not VISION_OCR_AVAILABLE:
        print("Vision OCR service not available")
        return None, 0.0

    try:
        ocr_service = get_vision_ocr_service()
        reachable = await ocr_service.is_available()
        if not reachable:
            print("Vision OCR service not available (is_available check failed)")
            return None, 0.0

        extraction = await ocr_service.extract_text(
            image_data,
            filename=filename,
            is_handwriting=True,
        )
        return extraction.text, extraction.confidence
    except Exception as e:
        print(f"Vision OCR failed: {e}")
        return None, 0.0
async def run_paddleocr_wrapper(image_data: bytes, filename: str) -> tuple:
    """Run PaddleOCR on one image, falling back to Vision OCR on failure.

    Returns:
        ``(raw_text, average region confidence)``; the confidence is 0.5
        when no regions are returned. ``(None, 0.0)`` when PaddleOCR yields
        empty text. When PaddleOCR is not importable or raises, the result
        of the Vision OCR wrapper is returned instead.
    """
    if not PADDLEOCR_AVAILABLE:
        print("PaddleOCR not available, falling back to Vision OCR")
        return await run_vision_ocr_wrapper(image_data, filename)

    try:
        # run_paddle_ocr returns (regions, raw_text)
        regions, raw_text = run_paddle_ocr(image_data)
        if not raw_text:
            print("PaddleOCR returned empty text")
            return None, 0.0

        # Mean of the per-region confidences; neutral 0.5 without regions.
        mean_confidence = (
            sum(region.confidence for region in regions) / len(regions)
            if regions
            else 0.5
        )
        return raw_text, mean_confidence
    except Exception as e:
        print(f"PaddleOCR failed: {e}, falling back to Vision OCR")
        return await run_vision_ocr_wrapper(image_data, filename)
async def run_trocr_wrapper(image_data: bytes, filename: str) -> tuple:
    """Run TrOCR on one image, falling back to Vision OCR.

    Returns:
        ``(text, confidence)`` from TrOCR, or the Vision OCR wrapper's
        result when TrOCR is not importable or raises.
    """
    if TROCR_AVAILABLE:
        try:
            text, confidence = await run_trocr_ocr(image_data)
            return text, confidence
        except Exception as e:
            print(f"TrOCR failed: {e}, falling back to Vision OCR")
    else:
        print("TrOCR not available, falling back to Vision OCR")

    # Reached only when TrOCR is missing or raised.
    return await run_vision_ocr_wrapper(image_data, filename)
async def run_donut_wrapper(image_data: bytes, filename: str) -> tuple:
    """Run Donut OCR on one image, falling back to Vision OCR.

    Returns:
        ``(text, confidence)`` from Donut, or the Vision OCR wrapper's
        result when Donut is not importable or raises.
    """
    if DONUT_AVAILABLE:
        try:
            text, confidence = await run_donut_ocr(image_data)
            return text, confidence
        except Exception as e:
            print(f"Donut OCR failed: {e}, falling back to Vision OCR")
    else:
        print("Donut not available, falling back to Vision OCR")

    # Reached only when Donut is missing or raised.
    return await run_vision_ocr_wrapper(image_data, filename)
def save_image_locally(session_id: str, item_id: str, image_data: bytes, extension: str = "png") -> str:
    """Write image bytes to ``LOCAL_STORAGE_PATH/<session_id>/<item_id>.<extension>``.

    The session directory is created when it does not exist yet.

    Returns:
        The path of the written file.
    """
    target_dir = os.path.join(LOCAL_STORAGE_PATH, session_id)
    os.makedirs(target_dir, exist_ok=True)

    destination = os.path.join(target_dir, f"{item_id}.{extension}")
    with open(destination, 'wb') as handle:
        handle.write(image_data)
    return destination
def get_image_url(image_path: str) -> str:
    """Translate a stored image path into the value the frontend should use.

    Paths located inside ``LOCAL_STORAGE_PATH`` are mapped to
    ``/api/v1/ocr-label/images/<relative path>``. Anything else (e.g. a
    MinIO key or URL) is returned unchanged.

    Args:
        image_path: Path or key as stored for the labeling item.

    Returns:
        The relative API URL for local images, otherwise ``image_path``.
    """
    # Compare on a directory boundary: a plain startswith() on the storage
    # root would also match sibling directories that merely share the prefix
    # (e.g. "/app/ocr-labeling-old/...").
    storage_root = LOCAL_STORAGE_PATH.rstrip('/')
    inside_storage = (
        image_path == storage_root
        or image_path.startswith(storage_root + '/')
    )
    if inside_storage:
        relative_path = image_path[len(storage_root):].lstrip('/')
        return f"/api/v1/ocr-label/images/{relative_path}"

    # For MinIO images, the path is already a URL or key
    return image_path
# =============================================================================
# API Endpoints
# =============================================================================
@router.post("/sessions", response_model=SessionResponse)
async def create_session(session: SessionCreate):
    """
    Create a new OCR labeling session.

    A session groups related images for labeling (e.g., all scans from one
    class). A fresh UUID4 is used as the session id; the response reports
    all counters as 0.

    Raises:
        HTTPException: 500 when the session could not be persisted.
    """
    new_id = str(uuid.uuid4())

    # Fields shared by the database call and the response.
    shared_fields = {
        "name": session.name,
        "source_type": session.source_type,
        "description": session.description,
        "ocr_model": session.ocr_model,
    }

    created = await create_ocr_labeling_session(session_id=new_id, **shared_fields)
    if not created:
        raise HTTPException(status_code=500, detail="Failed to create session")

    return SessionResponse(
        id=new_id,
        total_items=0,
        labeled_items=0,
        confirmed_items=0,
        corrected_items=0,
        skipped_items=0,
        created_at=datetime.utcnow(),
        **shared_fields,
    )
@router.get("/sessions", response_model=List[SessionResponse])
async def list_sessions(limit: int = Query(50, ge=1, le=100)):
    """List OCR labeling sessions (at most ``limit`` entries)."""
    rows = await get_ocr_labeling_sessions(limit=limit)

    responses = []
    for row in rows:
        responses.append(
            SessionResponse(
                id=row['id'],
                name=row['name'],
                source_type=row['source_type'],
                description=row.get('description'),
                ocr_model=row.get('ocr_model'),
                # Missing counters are reported as 0.
                total_items=row.get('total_items', 0),
                labeled_items=row.get('labeled_items', 0),
                confirmed_items=row.get('confirmed_items', 0),
                corrected_items=row.get('corrected_items', 0),
                skipped_items=row.get('skipped_items', 0),
                created_at=row.get('created_at', datetime.utcnow()),
            )
        )
    return responses
@router.get("/sessions/{session_id}", response_model=SessionResponse)
async def get_session(session_id: str):
    """Get a specific OCR labeling session.

    Raises:
        HTTPException: 404 when no session with this id exists.
    """
    row = await get_ocr_labeling_session(session_id)
    if not row:
        raise HTTPException(status_code=404, detail="Session not found")

    # Missing counters are reported as 0.
    counters = {
        field: row.get(field, 0)
        for field in (
            'total_items',
            'labeled_items',
            'confirmed_items',
            'corrected_items',
            'skipped_items',
        )
    }
    return SessionResponse(
        id=row['id'],
        name=row['name'],
        source_type=row['source_type'],
        description=row.get('description'),
        ocr_model=row.get('ocr_model'),
        created_at=row.get('created_at', datetime.utcnow()),
        **counters,
    )
@router.post("/sessions/{session_id}/upload")
async def upload_images(
    session_id: str,
    background_tasks: BackgroundTasks,
    files: List[UploadFile] = File(...),
    run_ocr: bool = Form(True),
    metadata: Optional[str] = Form(None),  # JSON string
):
    """
    Upload images to a labeling session.

    Each file is hashed, stored (MinIO when available, otherwise local
    disk), optionally run through OCR and registered as a labeling item.

    Args:
        session_id: Session to add images to
        background_tasks: Accepted for the route signature; not used here
        files: Image files to upload (PNG, JPG, PDF)
        run_ocr: Whether to run OCR immediately (default: True)
        metadata: Optional JSON metadata (subject, year, etc.)

    Returns:
        Dict with ``session_id``, ``uploaded_count`` and the list of items
        that were successfully registered in the database.

    Raises:
        HTTPException: 404 when the session does not exist.
    """
    import json

    # Verify session exists
    session = await get_ocr_labeling_session(session_id)
    if not session:
        raise HTTPException(status_code=404, detail="Session not found")

    # Parse metadata; unparseable input is kept verbatim under "raw"
    meta_dict = None
    if metadata:
        try:
            meta_dict = json.loads(metadata)
        except json.JSONDecodeError:
            meta_dict = {"raw": metadata}

    results = []
    # Use `or` instead of a .get() default: the key can be present with a
    # None value (SessionCreate allows ocr_model=null), and .get() would
    # then hand None to the OCR dispatcher.
    ocr_model = session.get('ocr_model') or 'llama3.2-vision:11b'

    for file in files:
        # Read file content
        content = await file.read()

        # Compute hash for deduplication
        image_hash = compute_image_hash(content)

        # Generate item ID
        item_id = str(uuid.uuid4())

        # Determine file extension; anything unexpected is stored as png
        extension = file.filename.split('.')[-1].lower() if file.filename else 'png'
        if extension not in ['png', 'jpg', 'jpeg', 'pdf']:
            extension = 'png'

        # Save image
        if MINIO_AVAILABLE:
            # Upload to MinIO, falling back to local disk on failure
            try:
                image_path = upload_ocr_image(session_id, item_id, content, extension)
            except Exception as e:
                print(f"MinIO upload failed, using local storage: {e}")
                image_path = save_image_locally(session_id, item_id, content, extension)
        else:
            # Save locally
            image_path = save_image_locally(session_id, item_id, content, extension)

        # Run OCR if requested
        ocr_text = None
        ocr_confidence = None
        if run_ocr and extension != 'pdf':  # Skip OCR for PDFs for now
            ocr_text, ocr_confidence = await run_ocr_on_image(
                content,
                file.filename or f"{item_id}.{extension}",
                model=ocr_model
            )

        # Add to database; the model is only recorded when OCR produced text
        success = await add_ocr_labeling_item(
            item_id=item_id,
            session_id=session_id,
            image_path=image_path,
            image_hash=image_hash,
            ocr_text=ocr_text,
            ocr_confidence=ocr_confidence,
            ocr_model=ocr_model if ocr_text else None,
            metadata=meta_dict,
        )

        # Items that could not be registered are left out of the response
        if success:
            results.append({
                "id": item_id,
                "filename": file.filename,
                "image_path": image_path,
                "image_hash": image_hash,
                "ocr_text": ocr_text,
                "ocr_confidence": ocr_confidence,
                "status": "pending",
            })

    return {
        "session_id": session_id,
        "uploaded_count": len(results),
        "items": results,
    }
@router.get("/queue", response_model=List[ItemResponse])
async def get_labeling_queue(
    session_id: Optional[str] = Query(None),
    status: str = Query("pending"),
    limit: int = Query(10, ge=1, le=50),
):
    """
    Get items from the labeling queue.

    Args:
        session_id: Optional filter by session
        status: Filter by status (pending, confirmed, corrected, skipped)
        limit: Number of items to return
    """
    rows = await get_ocr_labeling_queue(
        session_id=session_id,
        status=status,
        limit=limit,
    )

    queue = []
    for row in rows:
        stored_path = row['image_path']
        queue.append(
            ItemResponse(
                id=row['id'],
                session_id=row['session_id'],
                session_name=row.get('session_name', ''),
                image_path=stored_path,
                image_url=get_image_url(stored_path),
                ocr_text=row.get('ocr_text'),
                ocr_confidence=row.get('ocr_confidence'),
                ground_truth=row.get('ground_truth'),
                status=row.get('status', 'pending'),
                metadata=row.get('metadata'),
                created_at=row.get('created_at', datetime.utcnow()),
            )
        )
    return queue
@router.get("/items/{item_id}", response_model=ItemResponse)
async def get_item(item_id: str):
    """Get a specific labeling item.

    Raises:
        HTTPException: 404 when no item with this id exists.
    """
    row = await get_ocr_labeling_item(item_id)
    if not row:
        raise HTTPException(status_code=404, detail="Item not found")

    stored_path = row['image_path']
    return ItemResponse(
        id=row['id'],
        session_id=row['session_id'],
        session_name=row.get('session_name', ''),
        image_path=stored_path,
        image_url=get_image_url(stored_path),
        ocr_text=row.get('ocr_text'),
        ocr_confidence=row.get('ocr_confidence'),
        ground_truth=row.get('ground_truth'),
        status=row.get('status', 'pending'),
        metadata=row.get('metadata'),
        created_at=row.get('created_at', datetime.utcnow()),
    )
@router.post("/confirm")
async def confirm_item(request: ConfirmRequest):
    """
    Confirm that OCR text is correct.

    Delegates to confirm_ocr_label (which, per the original description,
    sets ground_truth = ocr_text and marks the item as confirmed).

    Raises:
        HTTPException: 400 when the item could not be confirmed.
    """
    item_id = request.item_id
    confirmed = await confirm_ocr_label(
        item_id=item_id,
        labeled_by="admin",  # TODO: Get from auth
        label_time_seconds=request.label_time_seconds,
    )
    if confirmed:
        return {"status": "confirmed", "item_id": item_id}
    raise HTTPException(status_code=400, detail="Failed to confirm item")
@router.post("/correct")
async def correct_item(request: CorrectRequest):
    """
    Save corrected ground truth for an item.

    Use this when OCR text is wrong and needs manual correction.

    Raises:
        HTTPException: 400 when the correction could not be stored.
    """
    item_id = request.item_id
    corrected = await correct_ocr_label(
        item_id=item_id,
        ground_truth=request.ground_truth,
        labeled_by="admin",  # TODO: Get from auth
        label_time_seconds=request.label_time_seconds,
    )
    if corrected:
        return {"status": "corrected", "item_id": item_id}
    raise HTTPException(status_code=400, detail="Failed to correct item")
@router.post("/skip")
async def skip_item(request: SkipRequest):
    """
    Skip an item (unusable image, etc.).

    Skipped items are not included in training exports.

    Raises:
        HTTPException: 400 when the item could not be marked as skipped.
    """
    item_id = request.item_id
    skipped = await skip_ocr_item(
        item_id=item_id,
        labeled_by="admin",  # TODO: Get from auth
    )
    if skipped:
        return {"status": "skipped", "item_id": item_id}
    raise HTTPException(status_code=400, detail="Failed to skip item")
@router.get("/stats")
async def get_stats(session_id: Optional[str] = Query(None)):
    """
    Get labeling statistics.

    Args:
        session_id: Optional session ID for session-specific stats

    Raises:
        HTTPException: 500 when the statistics dict carries an "error" key.
    """
    labeling_stats = await get_ocr_labeling_stats(session_id=session_id)
    if "error" not in labeling_stats:
        return labeling_stats
    raise HTTPException(status_code=500, detail=labeling_stats["error"])
@router.post("/export")
async def export_data(request: ExportRequest):
    """
    Export labeled data for training.

    Formats:
    - generic: JSONL with image_path and ground_truth
    - trocr: Format for TrOCR/Microsoft Transformer fine-tuning
    - llama_vision: Format for llama3.2-vision fine-tuning

    Exports are saved to disk at /app/ocr-exports/{format}/{batch_id}/
    when the training export service is available; a failure of the file
    export is logged and does not fail the request.
    """
    export_format = request.export_format

    # First, get samples from database
    db_samples = await export_training_samples(
        export_format=export_format,
        session_id=request.session_id,
        batch_id=request.batch_id,
        exported_by="admin",  # TODO: Get from auth
    )

    if not db_samples:
        return {
            "export_format": export_format,
            "batch_id": request.batch_id,
            "exported_count": 0,
            "samples": [],
            "message": "No labeled samples found to export",
        }

    # If training export service is available, also write to disk
    file_export = None
    if TRAINING_EXPORT_AVAILABLE:
        try:
            exporter = get_training_export_service()

            # Convert DB samples to TrainingSample objects
            training_samples = [
                TrainingSample(
                    id=sample.get('id', sample.get('item_id', '')),
                    image_path=sample.get('image_path', ''),
                    ground_truth=sample.get('ground_truth', ''),
                    ocr_text=sample.get('ocr_text'),
                    ocr_confidence=sample.get('ocr_confidence'),
                    metadata=sample.get('metadata'),
                )
                for sample in db_samples
            ]

            # Export to files
            file_export = exporter.export(
                samples=training_samples,
                export_format=export_format,
                batch_id=request.batch_id,
            )
        except Exception as e:
            print(f"Training export failed: {e}")
            # Continue without file export

    # Prefer the requested batch id; otherwise use the one from the file export.
    fallback_batch_id = file_export.batch_id if file_export else None
    response = {
        "export_format": export_format,
        "batch_id": request.batch_id or fallback_batch_id,
        "exported_count": len(db_samples),
        "samples": db_samples,
    }

    if file_export:
        response["export_path"] = file_export.export_path
        response["manifest_path"] = file_export.manifest_path

    return response
@router.get("/training-samples")
async def list_training_samples(
    export_format: Optional[str] = Query(None),
    batch_id: Optional[str] = Query(None),
    limit: int = Query(100, ge=1, le=1000),
):
    """Return exported training samples together with their count."""
    filters = {
        "export_format": export_format,
        "batch_id": batch_id,
        "limit": limit,
    }
    found = await get_training_samples(**filters)
    return {"count": len(found), "samples": found}
@router.get("/images/{path:path}")
async def get_image(path: str):
    """
    Serve an image from local storage.

    This endpoint is used when images are stored locally (not in MinIO).

    Args:
        path: File path relative to LOCAL_STORAGE_PATH, taken from the URL.

    Raises:
        HTTPException: 404 when the resolved path lies outside
            LOCAL_STORAGE_PATH or does not name an existing regular file.
    """
    from fastapi.responses import FileResponse

    # `path` is attacker-controlled: "../" segments or an absolute path would
    # make os.path.join escape the storage root. Resolve both sides and
    # require the target to stay inside the root before touching the file.
    storage_root = os.path.realpath(LOCAL_STORAGE_PATH)
    requested = os.path.join(LOCAL_STORAGE_PATH, path)
    filepath = os.path.realpath(requested)
    try:
        inside_root = os.path.commonpath([storage_root, filepath]) == storage_root
    except ValueError:
        # commonpath raises for paths that cannot be compared
        # (e.g. different drives); treat that as "outside".
        inside_root = False

    # isfile (not exists) so that directories are rejected as well.
    if not inside_root or not os.path.isfile(filepath):
        raise HTTPException(status_code=404, detail="Image not found")

    # Determine content type from the requested name.
    extension = requested.split('.')[-1].lower()
    content_type = {
        'png': 'image/png',
        'jpg': 'image/jpeg',
        'jpeg': 'image/jpeg',
        'pdf': 'application/pdf',
    }.get(extension, 'application/octet-stream')
    return FileResponse(filepath, media_type=content_type)
@router.post("/run-ocr/{item_id}")
async def run_ocr_for_item(item_id: str):
    """
    Run OCR on an existing item.

    Use this to re-run OCR or to run it if it was skipped during upload.
    The image is read from local storage when its path starts with
    LOCAL_STORAGE_PATH, otherwise from MinIO (if available). The OCR model
    is taken from the item's session.

    Raises:
        HTTPException: 404 when the item or its local file is missing;
            500 when the image cannot be loaded or OCR yields no text.
    """
    default_model = 'llama3.2-vision:11b'

    item = await get_ocr_labeling_item(item_id)
    if not item:
        raise HTTPException(status_code=404, detail="Item not found")

    # Load the image bytes.
    image_path = item['image_path']
    if not image_path.startswith(LOCAL_STORAGE_PATH):
        if not MINIO_AVAILABLE:
            raise HTTPException(status_code=500, detail="Cannot load image")
        # Load from MinIO.
        try:
            image_data = get_ocr_image(image_path)
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Failed to load image: {e}")
    else:
        # Load from local storage.
        if not os.path.exists(image_path):
            raise HTTPException(status_code=404, detail="Image file not found")
        with open(image_path, 'rb') as f:
            image_data = f.read()

    # The OCR model comes from the session, with a fixed fallback.
    session = await get_ocr_labeling_session(item['session_id'])
    ocr_model = session.get('ocr_model', default_model) if session else default_model

    ocr_text, ocr_confidence = await run_ocr_on_image(
        image_data,
        os.path.basename(image_path),
        model=ocr_model
    )
    if ocr_text is None:
        raise HTTPException(status_code=500, detail="OCR failed")

    # Persist the result; skipped when no connection pool is available.
    from metrics_db import get_pool
    update_sql = (
        "\n"
        "UPDATE ocr_labeling_items\n"
        "SET ocr_text = $2, ocr_confidence = $3, ocr_model = $4\n"
        "WHERE id = $1\n"
    )
    pool = await get_pool()
    if pool:
        async with pool.acquire() as conn:
            await conn.execute(update_sql, item_id, ocr_text, ocr_confidence, ocr_model)

    return {
        "item_id": item_id,
        "ocr_text": ocr_text,
        "ocr_confidence": ocr_confidence,
        "ocr_model": ocr_model,
    }
@router.get("/exports")
async def list_exports(export_format: Optional[str] = Query(None)):
    """
    List all available training data exports.

    Args:
        export_format: Optional filter by format (generic, trocr, llama_vision)

    Returns:
        A dict with the exports and their count, or an empty list plus a
        message when the training export service is not available.

    Raises:
        HTTPException: 500 when the export service fails.
    """
    if TRAINING_EXPORT_AVAILABLE:
        try:
            service = get_training_export_service()
            found = service.list_exports(export_format=export_format)
            return {"count": len(found), "exports": found}
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Failed to list exports: {e}")

    return {
        "exports": [],
        "message": "Training export service not available",
    }
pass
try:
from hybrid_vocab_extractor import run_paddle_ocr # noqa: F401
except ImportError:
pass
try:
from services.trocr_service import run_trocr_ocr # noqa: F401
except ImportError:
pass
try:
from services.donut_ocr_service import run_donut_ocr # noqa: F401
except ImportError:
pass
try:
from vision_ocr_service import get_vision_ocr_service, VisionOCRService # noqa: F401
except ImportError:
pass
# Routes (router is the main export for app.include_router)
from ocr_labeling_routes import router # noqa: F401
from ocr_labeling_upload_routes import router as upload_router # noqa: F401
@@ -0,0 +1,205 @@
"""
OCR Labeling - Helper Functions and OCR Wrappers
Extracted from ocr_labeling_api.py to keep files under 500 LOC.
DATENSCHUTZ/PRIVACY:
- Alle Verarbeitung erfolgt lokal (Mac Mini mit Ollama)
- Keine Daten werden an externe Server gesendet
"""
import os
import hashlib
from ocr_labeling_models import LOCAL_STORAGE_PATH
# Try to import Vision OCR service
try:
import sys
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..', 'backend', 'klausur', 'services'))
from vision_ocr_service import get_vision_ocr_service, VisionOCRService
VISION_OCR_AVAILABLE = True
except ImportError:
VISION_OCR_AVAILABLE = False
print("Warning: Vision OCR service not available")
# Try to import PaddleOCR from hybrid_vocab_extractor
try:
from hybrid_vocab_extractor import run_paddle_ocr
PADDLEOCR_AVAILABLE = True
except ImportError:
PADDLEOCR_AVAILABLE = False
print("Warning: PaddleOCR not available")
# Try to import TrOCR service
try:
from services.trocr_service import run_trocr_ocr
TROCR_AVAILABLE = True
except ImportError:
TROCR_AVAILABLE = False
print("Warning: TrOCR service not available")
# Try to import Donut service
try:
from services.donut_ocr_service import run_donut_ocr
DONUT_AVAILABLE = True
except ImportError:
DONUT_AVAILABLE = False
print("Warning: Donut OCR service not available")
# Try to import MinIO storage
try:
from minio_storage import upload_ocr_image, get_ocr_image, MINIO_BUCKET
MINIO_AVAILABLE = True
except ImportError:
MINIO_AVAILABLE = False
print("Warning: MinIO storage not available, using local storage")
# Try to import Training Export Service
try:
from training_export_service import (
TrainingExportService,
TrainingSample,
get_training_export_service,
)
TRAINING_EXPORT_AVAILABLE = True
except ImportError:
TRAINING_EXPORT_AVAILABLE = False
print("Warning: Training export service not available")
# =============================================================================
# Helper Functions
# =============================================================================
def compute_image_hash(image_data: bytes) -> str:
    """Return the hexadecimal SHA256 digest of the given image bytes."""
    digest = hashlib.sha256()
    digest.update(image_data)
    return digest.hexdigest()
async def run_ocr_on_image(image_data: bytes, filename: str, model: str = "llama3.2-vision:11b") -> tuple:
    """
    Run OCR on an image with the engine selected by ``model``.

    Models:
    - llama3.2-vision:11b: Vision LLM (default)
    - trocr: Microsoft TrOCR
    - paddleocr: PaddleOCR
    - donut: Document Understanding Transformer

    Any other model name is routed to the Vision LLM wrapper.

    Returns:
        Tuple of (ocr_text, confidence)
    """
    print(f"Running OCR with model: {model}")

    # Pick the wrapper first, then call it once.
    if model == "paddleocr":
        handler = run_paddleocr_wrapper
    elif model == "donut":
        handler = run_donut_wrapper
    elif model == "trocr":
        handler = run_trocr_wrapper
    else:
        # Default: Vision LLM (llama3.2-vision or similar)
        handler = run_vision_ocr_wrapper
    return await handler(image_data, filename)
async def run_vision_ocr_wrapper(image_data: bytes, filename: str) -> tuple:
    """
    Run the Vision LLM OCR service on an image.

    Returns:
        (text, confidence) from the service, or (None, 0.0) when the
        service is not importable, reports itself unavailable, or raises.
    """
    failure = (None, 0.0)

    if not VISION_OCR_AVAILABLE:
        print("Vision OCR service not available")
        return failure

    try:
        service = get_vision_ocr_service()
        reachable = await service.is_available()
        if not reachable:
            print("Vision OCR service not available (is_available check failed)")
            return failure
        extracted = await service.extract_text(
            image_data,
            filename=filename,
            is_handwriting=True
        )
        return extracted.text, extracted.confidence
    except Exception as e:
        print(f"Vision OCR failed: {e}")
        return failure
async def run_paddleocr_wrapper(image_data: bytes, filename: str) -> tuple:
    """
    Run PaddleOCR (via hybrid_vocab_extractor) on an image.

    Falls back to the Vision OCR wrapper when PaddleOCR is not importable
    or raises. Empty PaddleOCR text yields (None, 0.0) without fallback.

    Returns:
        (raw_text, confidence) where confidence is the mean region
        confidence, or 0.5 when no regions were returned.
    """
    if not PADDLEOCR_AVAILABLE:
        print("PaddleOCR not available, falling back to Vision OCR")
        return await run_vision_ocr_wrapper(image_data, filename)

    try:
        # run_paddle_ocr returns (regions, raw_text)
        regions, raw_text = run_paddle_ocr(image_data)
        if not raw_text:
            print("PaddleOCR returned empty text")
            return None, 0.0

        # Average the per-region confidences; 0.5 when there are no regions.
        avg_confidence = (
            sum(r.confidence for r in regions) / len(regions)
            if regions
            else 0.5
        )
        return raw_text, avg_confidence
    except Exception as e:
        print(f"PaddleOCR failed: {e}, falling back to Vision OCR")
        return await run_vision_ocr_wrapper(image_data, filename)
async def run_trocr_wrapper(image_data: bytes, filename: str) -> tuple:
    """
    Run TrOCR on an image.

    Falls back to the Vision OCR wrapper when TrOCR is not importable or
    raises.

    Returns:
        Tuple of (text, confidence)
    """
    if TROCR_AVAILABLE:
        try:
            recognized = await run_trocr_ocr(image_data)
            text, confidence = recognized
            return text, confidence
        except Exception as e:
            print(f"TrOCR failed: {e}, falling back to Vision OCR")
    else:
        print("TrOCR not available, falling back to Vision OCR")
    return await run_vision_ocr_wrapper(image_data, filename)
async def run_donut_wrapper(image_data: bytes, filename: str) -> tuple:
    """
    Run Donut OCR on an image.

    Falls back to the Vision OCR wrapper when Donut is not importable or
    raises.

    Returns:
        Tuple of (text, confidence)
    """
    if DONUT_AVAILABLE:
        try:
            recognized = await run_donut_ocr(image_data)
            text, confidence = recognized
            return text, confidence
        except Exception as e:
            print(f"Donut OCR failed: {e}, falling back to Vision OCR")
    else:
        print("Donut not available, falling back to Vision OCR")
    return await run_vision_ocr_wrapper(image_data, filename)
def save_image_locally(session_id: str, item_id: str, image_data: bytes, extension: str = "png") -> str:
    """
    Write image bytes to the local storage directory of a session.

    The file is stored as ``<LOCAL_STORAGE_PATH>/<session_id>/<item_id>.<extension>``;
    the session directory is created when missing.

    Returns:
        The path of the written file.
    """
    target_dir = os.path.join(LOCAL_STORAGE_PATH, session_id)
    os.makedirs(target_dir, exist_ok=True)

    destination = os.path.join(target_dir, f"{item_id}.{extension}")
    with open(destination, 'wb') as out:
        out.write(image_data)
    return destination
def get_image_url(image_path: str) -> str:
    """
    Return the URL under which an image can be fetched.

    Paths below LOCAL_STORAGE_PATH are mapped to the local image endpoint;
    any other value (e.g. a MinIO key or URL) is returned unchanged.
    """
    if not image_path.startswith(LOCAL_STORAGE_PATH):
        # Not a local file: the path is already a URL or key.
        return image_path

    # Strip the storage root so the frontend gets a relative path.
    relative_path = image_path[len(LOCAL_STORAGE_PATH):].lstrip('/')
    return f"/api/v1/ocr-label/images/{relative_path}"
@@ -0,0 +1,86 @@
"""
OCR Labeling - Pydantic Models and Constants
Extracted from ocr_labeling_api.py to keep files under 500 LOC.
"""
import os
from pydantic import BaseModel
from typing import Optional, Dict
from datetime import datetime
# Local storage path (fallback if MinIO not available)
LOCAL_STORAGE_PATH = os.getenv("OCR_STORAGE_PATH", "/app/ocr-labeling")
# =============================================================================
# Pydantic Models
# =============================================================================
class SessionCreate(BaseModel):
    """Request body for creating an OCR labeling session."""

    name: str
    # One of: klausur, handwriting_sample, scan
    source_type: str = "klausur"
    description: Optional[str] = None
    # OCR model used for the session's items
    ocr_model: Optional[str] = "llama3.2-vision:11b"
class SessionResponse(BaseModel):
    """Response model describing a labeling session and its item counters."""

    # Identity and configuration
    id: str
    name: str
    source_type: str
    description: Optional[str]
    ocr_model: Optional[str]

    # Per-session item counters
    total_items: int
    labeled_items: int
    confirmed_items: int
    corrected_items: int
    skipped_items: int

    created_at: datetime
class ItemResponse(BaseModel):
    """Response model for a single labeling item."""

    id: str
    session_id: str
    session_name: str

    # Storage location and the URL derived from it
    image_path: str
    image_url: Optional[str]

    # OCR output and the human-provided label
    ocr_text: Optional[str]
    ocr_confidence: Optional[float]
    ground_truth: Optional[str]

    status: str
    metadata: Optional[Dict]
    created_at: datetime
class ConfirmRequest(BaseModel):
    """Request body for confirming the OCR text of an item."""

    item_id: str
    # Time spent labeling, in seconds (optional)
    label_time_seconds: Optional[int] = None
class CorrectRequest(BaseModel):
    """Request body for saving a corrected ground truth for an item."""

    item_id: str
    # The corrected text
    ground_truth: str
    # Time spent labeling, in seconds (optional)
    label_time_seconds: Optional[int] = None
class SkipRequest(BaseModel):
    """Request body for skipping an item."""

    item_id: str
class ExportRequest(BaseModel):
    """Request body for exporting labeled data for training."""

    # One of: generic, trocr, llama_vision
    export_format: str = "generic"
    # Optional filters / identifiers forwarded to the export
    session_id: Optional[str] = None
    batch_id: Optional[str] = None
class StatsResponse(BaseModel):
    """Model for labeling statistics (item counters, accuracy, label time)."""

    total_sessions: Optional[int] = None

    # Item counters
    total_items: int
    labeled_items: int
    confirmed_items: int
    corrected_items: int
    pending_items: int
    exportable_items: Optional[int] = None

    accuracy_rate: float
    avg_label_time_seconds: Optional[float] = None
@@ -0,0 +1,241 @@
"""
OCR Labeling - Session and Labeling Route Handlers
Extracted from ocr_labeling_api.py to keep files under 500 LOC.
Endpoints:
- POST /sessions - Create labeling session
- GET /sessions - List sessions
- GET /sessions/{id} - Get session
- GET /queue - Get labeling queue
- GET /items/{id} - Get item
- POST /confirm - Confirm OCR
- POST /correct - Correct ground truth
- POST /skip - Skip item
- GET /stats - Get statistics
"""
from fastapi import APIRouter, HTTPException, Query
from typing import Optional, List
from datetime import datetime
import uuid
from metrics_db import (
create_ocr_labeling_session,
get_ocr_labeling_sessions,
get_ocr_labeling_session,
get_ocr_labeling_queue,
get_ocr_labeling_item,
confirm_ocr_label,
correct_ocr_label,
skip_ocr_item,
get_ocr_labeling_stats,
)
from ocr_labeling_models import (
SessionCreate, SessionResponse, ItemResponse,
ConfirmRequest, CorrectRequest, SkipRequest,
)
from ocr_labeling_helpers import get_image_url
router = APIRouter(prefix="/api/v1/ocr-label", tags=["OCR Labeling"])
# =============================================================================
# Session Endpoints
# =============================================================================
@router.post("/sessions", response_model=SessionResponse)
async def create_session(session: SessionCreate):
    """
    Create a new OCR labeling session.

    A fresh UUID is generated as the session id. The response echoes the
    submitted fields with all item counters at zero.

    Raises:
        HTTPException: 500 when the session could not be stored.
    """
    new_id = str(uuid.uuid4())
    submitted = {
        "name": session.name,
        "source_type": session.source_type,
        "description": session.description,
        "ocr_model": session.ocr_model,
    }

    stored = await create_ocr_labeling_session(session_id=new_id, **submitted)
    if not stored:
        raise HTTPException(status_code=500, detail="Failed to create session")

    return SessionResponse(
        id=new_id,
        **submitted,
        total_items=0,
        labeled_items=0,
        confirmed_items=0,
        corrected_items=0,
        skipped_items=0,
        created_at=datetime.utcnow(),
    )
@router.get("/sessions", response_model=List[SessionResponse])
async def list_sessions(limit: int = Query(50, ge=1, le=100)):
    """
    List OCR labeling sessions.

    Args:
        limit: Maximum number of sessions to fetch (1-100, default 50).
    """
    rows = await get_ocr_labeling_sessions(limit=limit)

    responses = []
    for row in rows:
        # Missing counters default to 0; a missing timestamp defaults to now.
        responses.append(SessionResponse(
            id=row['id'],
            name=row['name'],
            source_type=row['source_type'],
            description=row.get('description'),
            ocr_model=row.get('ocr_model'),
            total_items=row.get('total_items', 0),
            labeled_items=row.get('labeled_items', 0),
            confirmed_items=row.get('confirmed_items', 0),
            corrected_items=row.get('corrected_items', 0),
            skipped_items=row.get('skipped_items', 0),
            created_at=row.get('created_at', datetime.utcnow()),
        ))
    return responses
@router.get("/sessions/{session_id}", response_model=SessionResponse)
async def get_session(session_id: str):
    """
    Return a single OCR labeling session.

    Raises:
        HTTPException: 404 when no session with this id exists.
    """
    row = await get_ocr_labeling_session(session_id)
    if not row:
        raise HTTPException(status_code=404, detail="Session not found")

    # Missing counters default to 0; a missing timestamp defaults to now.
    return SessionResponse(
        id=row['id'],
        name=row['name'],
        source_type=row['source_type'],
        description=row.get('description'),
        ocr_model=row.get('ocr_model'),
        total_items=row.get('total_items', 0),
        labeled_items=row.get('labeled_items', 0),
        confirmed_items=row.get('confirmed_items', 0),
        corrected_items=row.get('corrected_items', 0),
        skipped_items=row.get('skipped_items', 0),
        created_at=row.get('created_at', datetime.utcnow()),
    )
# =============================================================================
# Queue and Item Endpoints
# =============================================================================
@router.get("/queue", response_model=List[ItemResponse])
async def get_labeling_queue(
    session_id: Optional[str] = Query(None),
    status: str = Query("pending"),
    limit: int = Query(10, ge=1, le=50),
):
    """
    Return items from the labeling queue.

    Args:
        session_id: Optional session to restrict the queue to.
        status: Item status to fetch (default "pending").
        limit: Maximum number of items (1-50, default 10).
    """
    rows = await get_ocr_labeling_queue(
        session_id=session_id,
        status=status,
        limit=limit,
    )

    queue = []
    for row in rows:
        stored_path = row['image_path']
        queue.append(ItemResponse(
            id=row['id'],
            session_id=row['session_id'],
            session_name=row.get('session_name', ''),
            image_path=stored_path,
            image_url=get_image_url(stored_path),
            ocr_text=row.get('ocr_text'),
            ocr_confidence=row.get('ocr_confidence'),
            ground_truth=row.get('ground_truth'),
            status=row.get('status', 'pending'),
            metadata=row.get('metadata'),
            created_at=row.get('created_at', datetime.utcnow()),
        ))
    return queue
@router.get("/items/{item_id}", response_model=ItemResponse)
async def get_item(item_id: str):
    """
    Return a single labeling item.

    Raises:
        HTTPException: 404 when no item with this id exists.
    """
    row = await get_ocr_labeling_item(item_id)
    if not row:
        raise HTTPException(status_code=404, detail="Item not found")

    stored_path = row['image_path']
    return ItemResponse(
        id=row['id'],
        session_id=row['session_id'],
        session_name=row.get('session_name', ''),
        image_path=stored_path,
        image_url=get_image_url(stored_path),
        ocr_text=row.get('ocr_text'),
        ocr_confidence=row.get('ocr_confidence'),
        ground_truth=row.get('ground_truth'),
        status=row.get('status', 'pending'),
        metadata=row.get('metadata'),
        created_at=row.get('created_at', datetime.utcnow()),
    )
# =============================================================================
# Labeling Action Endpoints
# =============================================================================
@router.post("/confirm")
async def confirm_item(request: ConfirmRequest):
    """
    Confirm that the OCR text of an item is correct.

    Raises:
        HTTPException: 400 when ``confirm_ocr_label`` reports failure.
    """
    item_id = request.item_id
    confirmed = await confirm_ocr_label(
        item_id=item_id,
        labeled_by="admin",  # labeler is hard-coded here
        label_time_seconds=request.label_time_seconds,
    )
    if confirmed:
        return {"status": "confirmed", "item_id": item_id}
    raise HTTPException(status_code=400, detail="Failed to confirm item")
@router.post("/correct")
async def correct_item(request: CorrectRequest):
    """
    Store a corrected ground truth for an item.

    Raises:
        HTTPException: 400 when ``correct_ocr_label`` reports failure.
    """
    item_id = request.item_id
    corrected = await correct_ocr_label(
        item_id=item_id,
        ground_truth=request.ground_truth,
        labeled_by="admin",  # labeler is hard-coded here
        label_time_seconds=request.label_time_seconds,
    )
    if corrected:
        return {"status": "corrected", "item_id": item_id}
    raise HTTPException(status_code=400, detail="Failed to correct item")
@router.post("/skip")
async def skip_item(request: SkipRequest):
    """
    Mark an item as skipped (unusable image, etc.).

    Raises:
        HTTPException: 400 when ``skip_ocr_item`` reports failure.
    """
    item_id = request.item_id
    skipped = await skip_ocr_item(
        item_id=item_id,
        labeled_by="admin",  # labeler is hard-coded here
    )
    if skipped:
        return {"status": "skipped", "item_id": item_id}
    raise HTTPException(status_code=400, detail="Failed to skip item")
@router.get("/stats")
async def get_stats(session_id: Optional[str] = Query(None)):
    """
    Return labeling statistics, optionally for a single session.

    Raises:
        HTTPException: 500 when the returned stats contain an "error" key.
    """
    stats = await get_ocr_labeling_stats(session_id=session_id)
    if "error" not in stats:
        return stats
    raise HTTPException(status_code=500, detail=stats["error"])
@@ -0,0 +1,313 @@
"""
OCR Labeling - Upload, Run-OCR, and Export Route Handlers
Extracted from ocr_labeling_routes.py to keep files under 500 LOC.
Endpoints:
- POST /sessions/{id}/upload - Upload images for labeling
- POST /run-ocr/{item_id} - Run OCR on existing item
- POST /export - Export training data
- GET /training-samples - List training samples
- GET /images/{path} - Serve images from local storage
- GET /exports - List exports
"""
from fastapi import APIRouter, HTTPException, UploadFile, File, Form, Query
from typing import Optional, List
import uuid
import os
from metrics_db import (
get_ocr_labeling_session,
add_ocr_labeling_item,
get_ocr_labeling_item,
export_training_samples,
get_training_samples,
)
from ocr_labeling_models import (
ExportRequest,
LOCAL_STORAGE_PATH,
)
from ocr_labeling_helpers import (
compute_image_hash, run_ocr_on_image,
save_image_locally,
MINIO_AVAILABLE, TRAINING_EXPORT_AVAILABLE,
)
# Conditional imports
try:
from minio_storage import upload_ocr_image, get_ocr_image
except ImportError:
pass
try:
from training_export_service import TrainingSample, get_training_export_service
except ImportError:
pass
router = APIRouter(prefix="/api/v1/ocr-label", tags=["OCR Labeling"])
@router.post("/sessions/{session_id}/upload")
async def upload_images(
    session_id: str,
    files: List[UploadFile] = File(...),
    run_ocr: bool = Form(True),
    metadata: Optional[str] = Form(None),
):
    """
    Upload images to a labeling session.

    Each file is stored (MinIO when available, local storage otherwise or
    on MinIO failure), optionally run through OCR, and registered as a
    labeling item. Only items that were registered successfully appear in
    the response.

    Args:
        session_id: Session to add images to
        files: Image files to upload (PNG, JPG, PDF)
        run_ocr: Whether to run OCR immediately (default: True); PDFs are
            never OCR'd here
        metadata: Optional JSON metadata; non-JSON text is kept as
            {"raw": <text>}

    Raises:
        HTTPException: 404 when the session does not exist.
    """
    import json

    allowed_extensions = ('png', 'jpg', 'jpeg', 'pdf')

    session = await get_ocr_labeling_session(session_id)
    if not session:
        raise HTTPException(status_code=404, detail="Session not found")

    # Parse the optional metadata once; it is shared by all uploaded files.
    meta_dict = None
    if metadata:
        try:
            meta_dict = json.loads(metadata)
        except json.JSONDecodeError:
            meta_dict = {"raw": metadata}

    ocr_model = session.get('ocr_model', 'llama3.2-vision:11b')
    results = []

    for file in files:
        content = await file.read()
        image_hash = compute_image_hash(content)
        item_id = str(uuid.uuid4())

        # Unknown or missing extensions are treated as PNG.
        extension = file.filename.split('.')[-1].lower() if file.filename else 'png'
        if extension not in allowed_extensions:
            extension = 'png'

        # Prefer MinIO; use local storage when it is unavailable or fails.
        store_locally = not MINIO_AVAILABLE
        if not store_locally:
            try:
                image_path = upload_ocr_image(session_id, item_id, content, extension)
            except Exception as e:
                print(f"MinIO upload failed, using local storage: {e}")
                store_locally = True
        if store_locally:
            image_path = save_image_locally(session_id, item_id, content, extension)

        ocr_text, ocr_confidence = None, None
        if run_ocr and extension != 'pdf':
            ocr_text, ocr_confidence = await run_ocr_on_image(
                content,
                file.filename or f"{item_id}.{extension}",
                model=ocr_model
            )

        added = await add_ocr_labeling_item(
            item_id=item_id,
            session_id=session_id,
            image_path=image_path,
            image_hash=image_hash,
            ocr_text=ocr_text,
            ocr_confidence=ocr_confidence,
            ocr_model=ocr_model if ocr_text else None,
            metadata=meta_dict,
        )
        if not added:
            continue

        results.append({
            "id": item_id,
            "filename": file.filename,
            "image_path": image_path,
            "image_hash": image_hash,
            "ocr_text": ocr_text,
            "ocr_confidence": ocr_confidence,
            "status": "pending",
        })

    return {
        "session_id": session_id,
        "uploaded_count": len(results),
        "items": results,
    }
@router.post("/export")
async def export_data(request: ExportRequest):
    """
    Export labeled data for training.

    Samples come from ``export_training_samples``. When the training
    export service is available they are also passed to it; a failure
    there is only printed and the database samples are still returned.

    Returns:
        A dict with export_format, batch_id, exported_count and samples,
        plus export_path / manifest_path when the service export succeeded.
    """
    rows = await export_training_samples(
        export_format=request.export_format,
        session_id=request.session_id,
        batch_id=request.batch_id,
        exported_by="admin",
    )
    if not rows:
        return {
            "export_format": request.export_format,
            "batch_id": request.batch_id,
            "exported_count": 0,
            "samples": [],
            "message": "No labeled samples found to export",
        }

    # Best effort: hand the samples to the export service as well.
    written = None
    if TRAINING_EXPORT_AVAILABLE:
        try:
            export_service = get_training_export_service()
            payload = [
                TrainingSample(
                    id=row.get('id', row.get('item_id', '')),
                    image_path=row.get('image_path', ''),
                    ground_truth=row.get('ground_truth', ''),
                    ocr_text=row.get('ocr_text'),
                    ocr_confidence=row.get('ocr_confidence'),
                    metadata=row.get('metadata'),
                )
                for row in rows
            ]
            written = export_service.export(
                samples=payload,
                export_format=request.export_format,
                batch_id=request.batch_id,
            )
        except Exception as e:
            print(f"Training export failed: {e}")

    fallback_batch = written.batch_id if written else None
    response = {
        "export_format": request.export_format,
        "batch_id": request.batch_id or fallback_batch,
        "exported_count": len(rows),
        "samples": rows,
    }
    if written:
        response.update(
            export_path=written.export_path,
            manifest_path=written.manifest_path,
        )
    return response
@router.get("/training-samples")
async def list_training_samples(
    export_format: Optional[str] = Query(None),
    batch_id: Optional[str] = Query(None),
    limit: int = Query(100, ge=1, le=1000),
):
    """Return exported training samples together with their count."""
    filters = {
        "export_format": export_format,
        "batch_id": batch_id,
        "limit": limit,
    }
    found = await get_training_samples(**filters)
    return {"count": len(found), "samples": found}
@router.get("/images/{path:path}")
async def get_image(path: str):
    """
    Serve an image from local storage.

    Args:
        path: File path relative to LOCAL_STORAGE_PATH, taken from the URL.

    Raises:
        HTTPException: 404 when the resolved path lies outside
            LOCAL_STORAGE_PATH or does not name an existing regular file.
    """
    from fastapi.responses import FileResponse

    # `path` is attacker-controlled: "../" segments or an absolute path would
    # make os.path.join escape the storage root. Resolve both sides and
    # require the target to stay inside the root before touching the file.
    storage_root = os.path.realpath(LOCAL_STORAGE_PATH)
    requested = os.path.join(LOCAL_STORAGE_PATH, path)
    filepath = os.path.realpath(requested)
    try:
        inside_root = os.path.commonpath([storage_root, filepath]) == storage_root
    except ValueError:
        # commonpath raises for paths that cannot be compared
        # (e.g. different drives); treat that as "outside".
        inside_root = False

    # isfile (not exists) so that directories are rejected as well.
    if not inside_root or not os.path.isfile(filepath):
        raise HTTPException(status_code=404, detail="Image not found")

    # Content type is derived from the requested name.
    extension = requested.split('.')[-1].lower()
    content_type = {
        'png': 'image/png',
        'jpg': 'image/jpeg',
        'jpeg': 'image/jpeg',
        'pdf': 'application/pdf',
    }.get(extension, 'application/octet-stream')
    return FileResponse(filepath, media_type=content_type)
@router.post("/run-ocr/{item_id}")
async def run_ocr_for_item(item_id: str):
    """
    Run OCR on an existing item and store the result.

    The image is read from local storage when its path starts with
    LOCAL_STORAGE_PATH, otherwise from MinIO (if available). The OCR model
    is taken from the item's session.

    Raises:
        HTTPException: 404 when the item or its local file is missing;
            500 when the image cannot be loaded or OCR yields no text.
    """
    default_model = 'llama3.2-vision:11b'

    item = await get_ocr_labeling_item(item_id)
    if not item:
        raise HTTPException(status_code=404, detail="Item not found")

    image_path = item['image_path']
    if not image_path.startswith(LOCAL_STORAGE_PATH):
        if not MINIO_AVAILABLE:
            raise HTTPException(status_code=500, detail="Cannot load image")
        try:
            image_data = get_ocr_image(image_path)
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Failed to load image: {e}")
    else:
        if not os.path.exists(image_path):
            raise HTTPException(status_code=404, detail="Image file not found")
        with open(image_path, 'rb') as f:
            image_data = f.read()

    # The OCR model comes from the session, with a fixed fallback.
    session = await get_ocr_labeling_session(item['session_id'])
    ocr_model = session.get('ocr_model', default_model) if session else default_model

    ocr_text, ocr_confidence = await run_ocr_on_image(
        image_data,
        os.path.basename(image_path),
        model=ocr_model
    )
    if ocr_text is None:
        raise HTTPException(status_code=500, detail="OCR failed")

    # Persist the result; skipped when no connection pool is available.
    from metrics_db import get_pool
    update_sql = (
        "\n"
        "UPDATE ocr_labeling_items\n"
        "SET ocr_text = $2, ocr_confidence = $3, ocr_model = $4\n"
        "WHERE id = $1\n"
    )
    pool = await get_pool()
    if pool:
        async with pool.acquire() as conn:
            await conn.execute(update_sql, item_id, ocr_text, ocr_confidence, ocr_model)

    return {
        "item_id": item_id,
        "ocr_text": ocr_text,
        "ocr_confidence": ocr_confidence,
        "ocr_model": ocr_model,
    }
@router.get("/exports")
async def list_exports(export_format: Optional[str] = Query(None)):
    """
    List all available training data exports.

    Returns an empty list plus a message when the training export service
    is not available.

    Raises:
        HTTPException: 500 when the export service fails.
    """
    if TRAINING_EXPORT_AVAILABLE:
        try:
            service = get_training_export_service()
            found = service.list_exports(export_format=export_format)
            return {"count": len(found), "exports": found}
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Failed to list exports: {e}")

    return {
        "exports": [],
        "message": "Training export service not available",
    }
+12 -694
View File
@@ -1,705 +1,23 @@
"""
OCR Pipeline Auto-Mode Orchestrator and Reprocess Endpoints.
OCR Pipeline Auto-Mode Orchestrator and Reprocess Endpoints — Barrel Re-export.
Extracted from ocr_pipeline_api.py — contains:
- POST /sessions/{session_id}/reprocess (clear downstream + restart from step)
- POST /sessions/{session_id}/run-auto (full auto-mode with SSE streaming)
Split into submodules:
- ocr_pipeline_reprocess.py — POST /sessions/{id}/reprocess
- ocr_pipeline_auto_steps.py — POST /sessions/{id}/run-auto + VLM helper
Lizenz: Apache 2.0
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
"""
import json
import logging
import os
import re
import time
from dataclasses import asdict
from typing import Any, Dict, List, Optional
from fastapi import APIRouter
import cv2
import numpy as np
from fastapi import APIRouter, HTTPException, Request
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from cv_vocab_pipeline import (
OLLAMA_REVIEW_MODEL,
PageRegion,
RowGeometry,
_cells_to_vocab_entries,
_detect_header_footer_gaps,
_detect_sub_columns,
_fix_character_confusion,
_fix_phonetic_brackets,
fix_cell_phonetics,
analyze_layout,
build_cell_grid,
classify_column_types,
create_layout_image,
create_ocr_image,
deskew_image,
deskew_image_by_word_alignment,
detect_column_geometry,
detect_row_geometry,
_apply_shear,
dewarp_image,
llm_review_entries,
)
from ocr_pipeline_common import (
_cache,
_load_session_to_cache,
_get_cached,
_get_base_image_png,
_append_pipeline_log,
)
from ocr_pipeline_session_store import (
get_session_db,
update_session_db,
)
logger = logging.getLogger(__name__)
from ocr_pipeline_reprocess import router as _reprocess_router
from ocr_pipeline_auto_steps import router as _steps_router
# Combine both sub-routers into a single router for backwards compatibility.
# The consumer imports `from ocr_pipeline_auto import router as _auto_router`.
router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"])
router.include_router(_reprocess_router)
router.include_router(_steps_router)
# ---------------------------------------------------------------------------
# Reprocess endpoint
# ---------------------------------------------------------------------------
@router.post("/sessions/{session_id}/reprocess")
async def reprocess_session(session_id: str, request: Request):
"""Re-run the pipeline from a specific step, clearing downstream data.

Body: {"from_step": 5} (integer from 1 to 10; defaults to 1 when omitted)

Fields reset by the code below, depending on from_step:
- from_step <= 8: word_result is set to None
- from_step == 9: only the llm_review / llm_corrections keys are removed
from the stored word_result
- from_step <= 7: row_result
- from_step <= 6: column_result
- from_step <= 4: crop_result
- from_step <= 3: dewarp_result
- from_step <= 2: deskew_result
- from_step <= 1: orientation_result

The session's current_step is set to from_step, and the in-memory cache
entry of the session (if present) is updated with the same values.

Returns a dict with session_id, from_step and "cleared": the names of all
fields written besides current_step.

Raises HTTPException 404 when the session does not exist and 400 when
from_step is not an int between 1 and 10.

NOTE(review): an earlier description gave the pipeline order as
Orientation(1) → Deskew(2) → Dewarp(3) → Crop(4) → Columns(5) → Rows(6) →
Words(7) → LLM-Review(8) → Reconstruction(9) → Validation(10), while the
in-code comment numbers the steps 2..11. The thresholds for the first four
results follow the former, the later ones the latter — confirm which
numbering callers actually send.
"""
session = await get_session_db(session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
body = await request.json()
from_step = body.get("from_step", 1)
if not isinstance(from_step, int) or from_step < 1 or from_step > 10:
raise HTTPException(status_code=400, detail="from_step must be between 1 and 10")
# Collects every field to write; current_step is always updated.
update_kwargs: Dict[str, Any] = {"current_step": from_step}
# Clear downstream data based on from_step
# New pipeline order: Orient(2) → Deskew(3) → Dewarp(4) → Crop(5) →
# Columns(6) → Rows(7) → Words(8) → LLM(9) → Recon(10) → GT(11)
# NOTE(review): the validation above caps from_step at 10 — confirm against this order.
if from_step <= 8:
update_kwargs["word_result"] = None
elif from_step == 9:
# Only clear LLM review from word_result (the rest of word_result is kept)
word_result = session.get("word_result")
if word_result:
word_result.pop("llm_review", None)
word_result.pop("llm_corrections", None)
update_kwargs["word_result"] = word_result
if from_step <= 7:
update_kwargs["row_result"] = None
if from_step <= 6:
update_kwargs["column_result"] = None
if from_step <= 4:
update_kwargs["crop_result"] = None
if from_step <= 3:
update_kwargs["dewarp_result"] = None
if from_step <= 2:
update_kwargs["deskew_result"] = None
if from_step <= 1:
update_kwargs["orientation_result"] = None
await update_session_db(session_id, **update_kwargs)
# Also clear cache: mirror the values just written to the database
if session_id in _cache:
for key in list(update_kwargs.keys()):
if key != "current_step":
_cache[session_id][key] = update_kwargs[key]
_cache[session_id]["current_step"] = from_step
logger.info(f"Session {session_id} reprocessing from step {from_step}")
return {
"session_id": session_id,
"from_step": from_step,
"cleared": [k for k in update_kwargs if k != "current_step"],
}
# ---------------------------------------------------------------------------
# VLM shear detection helper (used by dewarp step in auto-mode)
# ---------------------------------------------------------------------------
async def _detect_shear_with_vlm(image_bytes: bytes) -> Dict[str, Any]:
    """Ask a vision-language model to estimate the vertical shear of a scanned page.

    The image is base64-encoded and posted to the Ollama ``/api/generate``
    endpoint together with a prompt asking for a small JSON object.  The first
    ``{...}`` span found in the model's reply is parsed as that object.

    The model name comes from ``OLLAMA_HTR_MODEL`` (default ``qwen2.5vl:32b``)
    and the server from ``OLLAMA_BASE_URL``; the ``method`` label in the result
    is fixed regardless of the configured model.

    Args:
        image_bytes: Encoded image bytes to show to the model.

    Returns:
        A dict with ``method``, ``shear_degrees`` (clamped to [-3.0, 3.0] and
        rounded to 3 decimals) and ``confidence`` (clamped to [0.0, 1.0] and
        rounded to 2 decimals).  If the request fails, the reply contains no
        parseable JSON object, or the reported numbers are not finite, both
        ``shear_degrees`` and ``confidence`` are 0.0.
    """
    import httpx
    import base64
    import json
    import math
    import re

    ollama_base = os.getenv("OLLAMA_BASE_URL", "http://host.docker.internal:11434")
    model = os.getenv("OLLAMA_HTR_MODEL", "qwen2.5vl:32b")
    prompt = (
        "This is a scanned vocabulary worksheet. Look at the vertical borders of the table columns. "
        "Are they perfectly vertical, or do they tilt slightly? "
        "If they tilt, estimate the tilt angle in degrees (positive = top tilts right, negative = top tilts left). "
        "Reply with ONLY a JSON object like: {\"shear_degrees\": 1.2, \"confidence\": 0.8} "
        "Use confidence 0.0-1.0 based on how clearly you can see the tilt. "
        "If the columns look straight, return {\"shear_degrees\": 0.0, \"confidence\": 0.9}"
    )
    img_b64 = base64.b64encode(image_bytes).decode("utf-8")
    payload = {
        "model": model,
        "prompt": prompt,
        "images": [img_b64],
        "stream": False,
    }
    try:
        async with httpx.AsyncClient(timeout=60.0) as client:
            resp = await client.post(f"{ollama_base}/api/generate", json=payload)
            resp.raise_for_status()
            text = resp.json().get("response", "")
        # Parse JSON from response (may have surrounding text)
        match = re.search(r'\{[^}]+\}', text)
        if match:
            data = json.loads(match.group(0))
            shear = float(data.get("shear_degrees", 0.0))
            conf = float(data.get("confidence", 0.0))
            # json.loads accepts NaN/Infinity.  A NaN would slip through the
            # max/min clamp below as shear=3.0 / confidence=1.0, so treat any
            # non-finite value as an unusable answer instead.
            if math.isfinite(shear) and math.isfinite(conf):
                # Clamp to reasonable range
                shear = max(-3.0, min(3.0, shear))
                conf = max(0.0, min(1.0, conf))
                return {"method": "vlm_qwen2.5vl", "shear_degrees": round(shear, 3), "confidence": round(conf, 2)}
            logger.warning(f"VLM dewarp returned non-finite values: {data!r}")
    except Exception as e:
        # Best effort: any failure (network, HTTP status, bad JSON, bad types)
        # degrades to the zero-confidence result below.
        logger.warning(f"VLM dewarp failed: {e}")
    return {"method": "vlm_qwen2.5vl", "shear_degrees": 0.0, "confidence": 0.0}
# ---------------------------------------------------------------------------
# Auto-mode orchestrator
# ---------------------------------------------------------------------------
class RunAutoRequest(BaseModel):
    """Request body for the ``run-auto`` endpoint (auto-mode OCR pipeline)."""

    # First pipeline step to execute; lower-numbered steps are skipped.
    from_step: int = 1  # 1=deskew, 2=dewarp, 3=columns, 4=rows, 5=words, 6=llm-review
    # OCR engine selector handed to the word-recognition step.
    ocr_engine: str = "auto"  # "auto" | "rapid" | "tesseract"
    # Pronunciation variant handed to the phonetic fix-up helpers.
    pronunciation: str = "british"
    # When True, the LLM review step is skipped.
    skip_llm_review: bool = False
    # How the dewarp step determines the shear to correct.
    dewarp_method: str = "ensemble"  # "ensemble" | "vlm" | "cv"
async def _auto_sse_event(step: str, status: str, data: Dict[str, Any]) -> str:
"""Format a single SSE event line."""
import json as _json
payload = {"step": step, "status": status, **data}
return f"data: {_json.dumps(payload)}\n\n"
@router.post("/sessions/{session_id}/run-auto")
async def run_auto(session_id: str, req: RunAutoRequest, request: Request):
    """Run the full OCR pipeline automatically from a given step, streaming SSE progress.

    Steps:
      1. Deskew — straighten the scan
      2. Dewarp — correct vertical shear (ensemble CV or VLM)
      3. Columns — detect column layout
      4. Rows — detect row layout
      5. Words — OCR each cell
      6. LLM review — correct OCR errors (optional)

    Steps numbered below ``req.from_step`` are skipped; every step from
    ``req.from_step`` onward is (re)run.  A skipped step's results must already
    be available in the session or cache, otherwise the first step that needs
    them reports an error.

    Yields SSE events of the form:
      data: {"step": "deskew", "status": "start"|"done"|"skipped"|"error", ...}

    Final event:
      data: {"step": "complete", "status": "done", "steps_run": [...], "steps_skipped": [...]}

    A failure in steps 1-5 is fatal: the step's ``error`` event is followed by
    ``{"step": "complete", "status": "error", "error_step": ...}`` and the stream
    ends.  A failure in the LLM review is non-fatal (``"fatal": false``).

    Args:
        session_id: Identifier of the OCR session to process.
        req: Auto-mode options (start step, OCR engine, dewarp method, ...).
        request: The incoming request (not used by the pipeline itself).

    Returns:
        A ``StreamingResponse`` with media type ``text/event-stream``.

    Raises:
        HTTPException: 400 if ``from_step`` is outside 1-6 or ``dewarp_method``
            is not one of ``ensemble``, ``vlm``, ``cv``.
    """
    if req.from_step < 1 or req.from_step > 6:
        raise HTTPException(status_code=400, detail="from_step must be 1-6")
    if req.dewarp_method not in ("ensemble", "vlm", "cv"):
        raise HTTPException(status_code=400, detail="dewarp_method must be: ensemble, vlm, cv")

    if session_id not in _cache:
        await _load_session_to_cache(session_id)

    async def _generate():
        steps_run: List[str] = []
        steps_skipped: List[str] = []
        error_step: Optional[str] = None

        session = await get_session_db(session_id)
        if not session:
            # The response has already started streaming, so a missing session
            # is reported as an SSE event rather than an HTTP status.
            yield await _auto_sse_event("error", "error", {"message": f"Session {session_id} not found"})
            return

        cached = _get_cached(session_id)

        # -----------------------------------------------------------------
        # Step 1: Deskew
        # -----------------------------------------------------------------
        if req.from_step <= 1:
            yield await _auto_sse_event("deskew", "start", {})
            try:
                t0 = time.time()
                orig_bgr = cached.get("original_bgr")
                if orig_bgr is None:
                    raise ValueError("Original image not loaded")

                # Method 1: Hough lines
                try:
                    deskewed_hough, angle_hough = deskew_image(orig_bgr.copy())
                except Exception:
                    deskewed_hough, angle_hough = orig_bgr, 0.0

                # Method 2: Word alignment
                success_enc, png_orig = cv2.imencode(".png", orig_bgr)
                orig_bytes = png_orig.tobytes() if success_enc else b""
                try:
                    deskewed_wa_bytes, angle_wa = deskew_image_by_word_alignment(orig_bytes)
                except Exception:
                    deskewed_wa_bytes, angle_wa = orig_bytes, 0.0

                # Pick best method
                if abs(angle_wa) >= abs(angle_hough) or abs(angle_hough) < 0.1:
                    method_used = "word_alignment"
                    angle_applied = angle_wa
                    wa_arr = np.frombuffer(deskewed_wa_bytes, dtype=np.uint8)
                    deskewed_bgr = cv2.imdecode(wa_arr, cv2.IMREAD_COLOR)
                    if deskewed_bgr is None:
                        # Word-alignment output could not be decoded: fall back to Hough.
                        deskewed_bgr = deskewed_hough
                        method_used = "hough"
                        angle_applied = angle_hough
                else:
                    method_used = "hough"
                    angle_applied = angle_hough
                    deskewed_bgr = deskewed_hough

                success, png_buf = cv2.imencode(".png", deskewed_bgr)
                deskewed_png = png_buf.tobytes() if success else b""

                deskew_result = {
                    "method_used": method_used,
                    "rotation_degrees": round(float(angle_applied), 3),
                    "duration_seconds": round(time.time() - t0, 2),
                }

                cached["deskewed_bgr"] = deskewed_bgr
                cached["deskew_result"] = deskew_result
                await update_session_db(
                    session_id,
                    deskewed_png=deskewed_png,
                    deskew_result=deskew_result,
                    auto_rotation_degrees=float(angle_applied),
                    current_step=3,
                )
                session = await get_session_db(session_id)

                steps_run.append("deskew")
                yield await _auto_sse_event("deskew", "done", deskew_result)
            except Exception as e:
                logger.error(f"Auto-mode deskew failed for {session_id}: {e}")
                error_step = "deskew"
                yield await _auto_sse_event("deskew", "error", {"message": str(e)})
                yield await _auto_sse_event("complete", "error", {"error_step": error_step})
                return
        else:
            steps_skipped.append("deskew")
            yield await _auto_sse_event("deskew", "skipped", {"reason": "from_step > 1"})

        # -----------------------------------------------------------------
        # Step 2: Dewarp
        # -----------------------------------------------------------------
        if req.from_step <= 2:
            yield await _auto_sse_event("dewarp", "start", {"method": req.dewarp_method})
            try:
                t0 = time.time()
                deskewed_bgr = cached.get("deskewed_bgr")
                if deskewed_bgr is None:
                    raise ValueError("Deskewed image not available")

                if req.dewarp_method == "vlm":
                    success_enc, png_buf = cv2.imencode(".png", deskewed_bgr)
                    img_bytes = png_buf.tobytes() if success_enc else b""
                    vlm_det = await _detect_shear_with_vlm(img_bytes)
                    shear_deg = vlm_det["shear_degrees"]
                    # Only correct when the detected shear is noticeable and the
                    # model is reasonably confident.
                    if abs(shear_deg) >= 0.05 and vlm_det["confidence"] >= 0.3:
                        dewarped_bgr = _apply_shear(deskewed_bgr, -shear_deg)
                    else:
                        dewarped_bgr = deskewed_bgr
                    dewarp_info = {
                        "method": vlm_det["method"],
                        "shear_degrees": shear_deg,
                        "confidence": vlm_det["confidence"],
                        "detections": [vlm_det],
                    }
                else:
                    # "ensemble" and "cv" both take this path.
                    dewarped_bgr, dewarp_info = dewarp_image(deskewed_bgr)

                success_enc, png_buf = cv2.imencode(".png", dewarped_bgr)
                dewarped_png = png_buf.tobytes() if success_enc else b""

                dewarp_result = {
                    "method_used": dewarp_info["method"],
                    "shear_degrees": dewarp_info["shear_degrees"],
                    "confidence": dewarp_info["confidence"],
                    "duration_seconds": round(time.time() - t0, 2),
                    "detections": dewarp_info.get("detections", []),
                }

                cached["dewarped_bgr"] = dewarped_bgr
                cached["dewarp_result"] = dewarp_result
                await update_session_db(
                    session_id,
                    dewarped_png=dewarped_png,
                    dewarp_result=dewarp_result,
                    auto_shear_degrees=dewarp_info.get("shear_degrees", 0.0),
                    current_step=4,
                )
                session = await get_session_db(session_id)

                steps_run.append("dewarp")
                yield await _auto_sse_event("dewarp", "done", dewarp_result)
            except Exception as e:
                logger.error(f"Auto-mode dewarp failed for {session_id}: {e}")
                error_step = "dewarp"
                yield await _auto_sse_event("dewarp", "error", {"message": str(e)})
                yield await _auto_sse_event("complete", "error", {"error_step": error_step})
                return
        else:
            steps_skipped.append("dewarp")
            yield await _auto_sse_event("dewarp", "skipped", {"reason": "from_step > 2"})

        # -----------------------------------------------------------------
        # Step 3: Columns
        # -----------------------------------------------------------------
        if req.from_step <= 3:
            yield await _auto_sse_event("columns", "start", {})
            try:
                t0 = time.time()
                # Prefer the cropped image when one is cached, else the dewarped one.
                col_img = cached.get("cropped_bgr") if cached.get("cropped_bgr") is not None else cached.get("dewarped_bgr")
                if col_img is None:
                    raise ValueError("Cropped/dewarped image not available")

                ocr_img = create_ocr_image(col_img)
                h, w = ocr_img.shape[:2]

                geo_result = detect_column_geometry(ocr_img, col_img)
                if geo_result is None:
                    # Geometry detection gave nothing: fall back to layout analysis
                    # and clear the intermediates that later steps would reuse.
                    layout_img = create_layout_image(col_img)
                    regions = analyze_layout(layout_img, ocr_img)
                    cached["_word_dicts"] = None
                    cached["_inv"] = None
                    cached["_content_bounds"] = None
                else:
                    geometries, left_x, right_x, top_y, bottom_y, word_dicts, inv = geo_result
                    content_w = right_x - left_x
                    # Keep intermediates so the rows/words steps need not recompute them.
                    cached["_word_dicts"] = word_dicts
                    cached["_inv"] = inv
                    cached["_content_bounds"] = (left_x, right_x, top_y, bottom_y)
                    header_y, footer_y = _detect_header_footer_gaps(inv, w, h) if inv is not None else (None, None)
                    geometries = _detect_sub_columns(geometries, content_w, left_x=left_x,
                                                     top_y=top_y, header_y=header_y, footer_y=footer_y)
                    regions = classify_column_types(geometries, content_w, top_y, w, h, bottom_y,
                                                    left_x=left_x, right_x=right_x, inv=inv)

                columns = [asdict(r) for r in regions]
                column_result = {
                    "columns": columns,
                    "classification_methods": list({c.get("classification_method", "") for c in columns if c.get("classification_method")}),
                    "duration_seconds": round(time.time() - t0, 2),
                }

                cached["column_result"] = column_result
                # New columns invalidate any previously stored rows and words.
                await update_session_db(session_id, column_result=column_result,
                                        row_result=None, word_result=None, current_step=6)
                session = await get_session_db(session_id)

                steps_run.append("columns")
                yield await _auto_sse_event("columns", "done", {
                    "column_count": len(columns),
                    "duration_seconds": column_result["duration_seconds"],
                })
            except Exception as e:
                logger.error(f"Auto-mode columns failed for {session_id}: {e}")
                error_step = "columns"
                yield await _auto_sse_event("columns", "error", {"message": str(e)})
                yield await _auto_sse_event("complete", "error", {"error_step": error_step})
                return
        else:
            steps_skipped.append("columns")
            yield await _auto_sse_event("columns", "skipped", {"reason": "from_step > 3"})

        # -----------------------------------------------------------------
        # Step 4: Rows
        # -----------------------------------------------------------------
        if req.from_step <= 4:
            yield await _auto_sse_event("rows", "start", {})
            try:
                t0 = time.time()
                row_img = cached.get("cropped_bgr") if cached.get("cropped_bgr") is not None else cached.get("dewarped_bgr")
                session = await get_session_db(session_id)
                column_result = session.get("column_result") or cached.get("column_result")
                if not column_result or not column_result.get("columns"):
                    raise ValueError("Column detection must complete first")

                col_regions = [
                    PageRegion(
                        type=c["type"], x=c["x"], y=c["y"],
                        width=c["width"], height=c["height"],
                        classification_confidence=c.get("classification_confidence", 1.0),
                        classification_method=c.get("classification_method", ""),
                    )
                    for c in column_result["columns"]
                ]

                word_dicts = cached.get("_word_dicts")
                inv = cached.get("_inv")
                content_bounds = cached.get("_content_bounds")

                if word_dicts is None or inv is None or content_bounds is None:
                    # Intermediates are missing (e.g. the columns step was skipped):
                    # recompute them from the image.
                    if row_img is None:
                        raise ValueError("Cropped/dewarped image not available")
                    ocr_img_tmp = create_ocr_image(row_img)
                    geo_result = detect_column_geometry(ocr_img_tmp, row_img)
                    if geo_result is None:
                        raise ValueError("Column geometry detection failed — cannot detect rows")
                    _g, lx, rx, ty, by, word_dicts, inv = geo_result
                    cached["_word_dicts"] = word_dicts
                    cached["_inv"] = inv
                    cached["_content_bounds"] = (lx, rx, ty, by)
                    content_bounds = (lx, rx, ty, by)

                left_x, right_x, top_y, bottom_y = content_bounds
                row_geoms = detect_row_geometry(inv, word_dicts, left_x, right_x, top_y, bottom_y)

                row_list = [
                    {
                        "index": r.index, "x": r.x, "y": r.y,
                        "width": r.width, "height": r.height,
                        "word_count": r.word_count,
                        "row_type": r.row_type,
                        "gap_before": r.gap_before,
                    }
                    for r in row_geoms
                ]
                row_result = {
                    "rows": row_list,
                    "row_count": len(row_list),
                    "content_rows": len([r for r in row_geoms if r.row_type == "content"]),
                    "duration_seconds": round(time.time() - t0, 2),
                }

                cached["row_result"] = row_result
                await update_session_db(session_id, row_result=row_result, current_step=7)
                session = await get_session_db(session_id)

                steps_run.append("rows")
                yield await _auto_sse_event("rows", "done", {
                    "row_count": len(row_list),
                    "content_rows": row_result["content_rows"],
                    "duration_seconds": row_result["duration_seconds"],
                })
            except Exception as e:
                logger.error(f"Auto-mode rows failed for {session_id}: {e}")
                error_step = "rows"
                yield await _auto_sse_event("rows", "error", {"message": str(e)})
                yield await _auto_sse_event("complete", "error", {"error_step": error_step})
                return
        else:
            steps_skipped.append("rows")
            yield await _auto_sse_event("rows", "skipped", {"reason": "from_step > 4"})

        # -----------------------------------------------------------------
        # Step 5: Words (OCR)
        # -----------------------------------------------------------------
        if req.from_step <= 5:
            yield await _auto_sse_event("words", "start", {"engine": req.ocr_engine})
            try:
                t0 = time.time()
                word_img = cached.get("cropped_bgr") if cached.get("cropped_bgr") is not None else cached.get("dewarped_bgr")
                session = await get_session_db(session_id)

                column_result = session.get("column_result") or cached.get("column_result")
                row_result = session.get("row_result") or cached.get("row_result")

                # Fail with an explicit message when a prerequisite is missing
                # (possible when earlier steps were skipped via from_step),
                # instead of an opaque "'NoneType' object is not subscriptable".
                if word_img is None:
                    raise ValueError("Cropped/dewarped image not available")
                if not column_result or "columns" not in column_result:
                    raise ValueError("Column detection must complete first")
                if not row_result or "rows" not in row_result:
                    raise ValueError("Row detection must complete first")

                col_regions = [
                    PageRegion(
                        type=c["type"], x=c["x"], y=c["y"],
                        width=c["width"], height=c["height"],
                        classification_confidence=c.get("classification_confidence", 1.0),
                        classification_method=c.get("classification_method", ""),
                    )
                    for c in column_result["columns"]
                ]
                row_geoms = [
                    RowGeometry(
                        index=r["index"], x=r["x"], y=r["y"],
                        width=r["width"], height=r["height"],
                        word_count=r.get("word_count", 0), words=[],
                        row_type=r.get("row_type", "content"),
                        gap_before=r.get("gap_before", 0),
                    )
                    for r in row_result["rows"]
                ]

                word_dicts = cached.get("_word_dicts")
                if word_dicts is not None:
                    content_bounds = cached.get("_content_bounds")
                    top_y = content_bounds[2] if content_bounds else min(r.y for r in row_geoms)
                    # Assign each cached word to the row that contains its vertical
                    # centre; word 'top' is compared against rows shifted by top_y.
                    for row in row_geoms:
                        row_y_rel = row.y - top_y
                        row_bottom_rel = row_y_rel + row.height
                        row.words = [
                            w for w in word_dicts
                            if row_y_rel <= w['top'] + w['height'] / 2 < row_bottom_rel
                        ]
                        row.word_count = len(row.words)

                ocr_img = create_ocr_image(word_img)
                img_h, img_w = word_img.shape[:2]

                cells, columns_meta = build_cell_grid(
                    ocr_img, col_regions, row_geoms, img_w, img_h,
                    ocr_engine=req.ocr_engine, img_bgr=word_img,
                )
                duration = time.time() - t0

                col_types = {c['type'] for c in columns_meta}
                is_vocab = bool(col_types & {'column_en', 'column_de'})
                n_content_rows = len([r for r in row_geoms if r.row_type == 'content'])
                used_engine = cells[0].get("ocr_engine", "tesseract") if cells else req.ocr_engine

                # Apply IPA phonetic fixes directly to cell texts
                fix_cell_phonetics(cells, pronunciation=req.pronunciation)

                word_result_data = {
                    "cells": cells,
                    "grid_shape": {
                        "rows": n_content_rows,
                        "cols": len(columns_meta),
                        "total_cells": len(cells),
                    },
                    "columns_used": columns_meta,
                    "layout": "vocab" if is_vocab else "generic",
                    "image_width": img_w,
                    "image_height": img_h,
                    "duration_seconds": round(duration, 2),
                    "ocr_engine": used_engine,
                    "summary": {
                        "total_cells": len(cells),
                        "non_empty_cells": sum(1 for c in cells if c.get("text")),
                        "low_confidence": sum(1 for c in cells if 0 < c.get("confidence", 0) < 50),
                    },
                }

                has_text_col = 'column_text' in col_types
                if is_vocab or has_text_col:
                    entries = _cells_to_vocab_entries(cells, columns_meta)
                    entries = _fix_character_confusion(entries)
                    entries = _fix_phonetic_brackets(entries, pronunciation=req.pronunciation)
                    # The same list is stored under both keys.
                    word_result_data["vocab_entries"] = entries
                    word_result_data["entries"] = entries
                    word_result_data["entry_count"] = len(entries)
                    word_result_data["summary"]["total_entries"] = len(entries)

                await update_session_db(session_id, word_result=word_result_data, current_step=8)
                cached["word_result"] = word_result_data
                session = await get_session_db(session_id)

                steps_run.append("words")
                yield await _auto_sse_event("words", "done", {
                    "total_cells": len(cells),
                    "layout": word_result_data["layout"],
                    "duration_seconds": round(duration, 2),
                    "ocr_engine": used_engine,
                    "summary": word_result_data["summary"],
                })
            except Exception as e:
                logger.error(f"Auto-mode words failed for {session_id}: {e}")
                error_step = "words"
                yield await _auto_sse_event("words", "error", {"message": str(e)})
                yield await _auto_sse_event("complete", "error", {"error_step": error_step})
                return
        else:
            steps_skipped.append("words")
            yield await _auto_sse_event("words", "skipped", {"reason": "from_step > 5"})

        # -----------------------------------------------------------------
        # Step 6: LLM Review (optional)
        # -----------------------------------------------------------------
        if req.from_step <= 6 and not req.skip_llm_review:
            yield await _auto_sse_event("llm_review", "start", {"model": OLLAMA_REVIEW_MODEL})
            try:
                session = await get_session_db(session_id)
                word_result = session.get("word_result") or cached.get("word_result")
                if word_result is None:
                    # Explicit message instead of an AttributeError on None.
                    raise ValueError("Word recognition must complete first")
                entries = word_result.get("entries") or word_result.get("vocab_entries") or []

                if not entries:
                    yield await _auto_sse_event("llm_review", "skipped", {"reason": "no entries"})
                    steps_skipped.append("llm_review")
                else:
                    reviewed = await llm_review_entries(entries)

                    # Re-read the session after the review call and merge the
                    # reviewed entries into the currently stored word_result.
                    session = await get_session_db(session_id)
                    word_result_updated = dict(session.get("word_result") or {})
                    word_result_updated["entries"] = reviewed
                    word_result_updated["vocab_entries"] = reviewed
                    word_result_updated["llm_reviewed"] = True
                    word_result_updated["llm_model"] = OLLAMA_REVIEW_MODEL

                    await update_session_db(session_id, word_result=word_result_updated, current_step=9)
                    cached["word_result"] = word_result_updated

                    steps_run.append("llm_review")
                    yield await _auto_sse_event("llm_review", "done", {
                        "entries_reviewed": len(reviewed),
                        "model": OLLAMA_REVIEW_MODEL,
                    })
            except Exception as e:
                # The review is optional: report the failure but still finish
                # the run with a normal "complete" event.
                logger.warning(f"Auto-mode llm_review failed for {session_id} (non-fatal): {e}")
                yield await _auto_sse_event("llm_review", "error", {"message": str(e), "fatal": False})
                steps_skipped.append("llm_review")
        else:
            steps_skipped.append("llm_review")
            reason = "skipped by request" if req.skip_llm_review else "from_step > 6"
            yield await _auto_sse_event("llm_review", "skipped", {"reason": reason})

        # -----------------------------------------------------------------
        # Final event
        # -----------------------------------------------------------------
        yield await _auto_sse_event("complete", "done", {
            "steps_run": steps_run,
            "steps_skipped": steps_skipped,
        })

    return StreamingResponse(
        _generate(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            # Header commonly used to ask nginx not to buffer the response.
            "X-Accel-Buffering": "no",
        },
    )
# Only `router` is exported by `from <module> import *`.
__all__ = ["router"]
@@ -0,0 +1,84 @@
"""
OCR Pipeline Auto-Mode Helpers.
VLM shear detection, SSE event formatting, and request models.
Lizenz: Apache 2.0
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
"""
import json
import logging
import os
import re
from typing import Any, Dict
from pydantic import BaseModel
logger = logging.getLogger(__name__)
class RunAutoRequest(BaseModel):
    """Request body for the ``run-auto`` endpoint (auto-mode OCR pipeline)."""

    # First pipeline step to execute; lower-numbered steps are skipped.
    from_step: int = 1  # 1=deskew, 2=dewarp, 3=columns, 4=rows, 5=words, 6=llm-review
    # OCR engine selector handed to the word-recognition step.
    ocr_engine: str = "auto"  # "auto" | "rapid" | "tesseract"
    # Pronunciation variant handed to the phonetic fix-up helpers.
    pronunciation: str = "british"
    # When True, the LLM review step is skipped.
    skip_llm_review: bool = False
    # How the dewarp step determines the shear to correct.
    dewarp_method: str = "ensemble"  # "ensemble" | "vlm" | "cv"
async def auto_sse_event(step: str, status: str, data: Dict[str, Any]) -> str:
    """Serialize one pipeline progress update as a Server-Sent-Events frame.

    Args:
        step: Name of the pipeline step the event belongs to.
        status: Status of that step.
        data: Extra fields merged into the payload; a key named ``step`` or
            ``status`` overrides the corresponding argument.

    Returns:
        A ``data: <json>`` line terminated by a blank line.
    """
    event: Dict[str, Any] = {"step": step, "status": status}
    event.update(data)
    encoded = json.dumps(event)
    return "data: " + encoded + "\n\n"
async def detect_shear_with_vlm(image_bytes: bytes) -> Dict[str, Any]:
    """Ask a vision-language model to estimate the vertical shear of a scanned page.

    The image is base64-encoded and posted to the Ollama ``/api/generate``
    endpoint together with a prompt asking for a small JSON object.  The first
    ``{...}`` span found in the model's reply is parsed as that object.

    The model name comes from ``OLLAMA_HTR_MODEL`` (default ``qwen2.5vl:32b``)
    and the server from ``OLLAMA_BASE_URL``; the ``method`` label in the result
    is fixed regardless of the configured model.

    Args:
        image_bytes: Encoded image bytes to show to the model.

    Returns:
        A dict with ``method``, ``shear_degrees`` (clamped to [-3.0, 3.0] and
        rounded to 3 decimals) and ``confidence`` (clamped to [0.0, 1.0] and
        rounded to 2 decimals).  If the request fails, the reply contains no
        parseable JSON object, or the reported numbers are not finite, both
        ``shear_degrees`` and ``confidence`` are 0.0.
    """
    import httpx
    import base64
    import math

    ollama_base = os.getenv("OLLAMA_BASE_URL", "http://host.docker.internal:11434")
    model = os.getenv("OLLAMA_HTR_MODEL", "qwen2.5vl:32b")
    prompt = (
        "This is a scanned vocabulary worksheet. Look at the vertical borders of the table columns. "
        "Are they perfectly vertical, or do they tilt slightly? "
        "If they tilt, estimate the tilt angle in degrees (positive = top tilts right, negative = top tilts left). "
        "Reply with ONLY a JSON object like: {\"shear_degrees\": 1.2, \"confidence\": 0.8} "
        "Use confidence 0.0-1.0 based on how clearly you can see the tilt. "
        "If the columns look straight, return {\"shear_degrees\": 0.0, \"confidence\": 0.9}"
    )
    img_b64 = base64.b64encode(image_bytes).decode("utf-8")
    payload = {
        "model": model,
        "prompt": prompt,
        "images": [img_b64],
        "stream": False,
    }
    try:
        async with httpx.AsyncClient(timeout=60.0) as client:
            resp = await client.post(f"{ollama_base}/api/generate", json=payload)
            resp.raise_for_status()
            text = resp.json().get("response", "")
        # Parse JSON from response (may have surrounding text)
        match = re.search(r'\{[^}]+\}', text)
        if match:
            data = json.loads(match.group(0))
            shear = float(data.get("shear_degrees", 0.0))
            conf = float(data.get("confidence", 0.0))
            # json.loads accepts NaN/Infinity.  A NaN would slip through the
            # max/min clamp below as shear=3.0 / confidence=1.0, so treat any
            # non-finite value as an unusable answer instead.
            if math.isfinite(shear) and math.isfinite(conf):
                # Clamp to reasonable range
                shear = max(-3.0, min(3.0, shear))
                conf = max(0.0, min(1.0, conf))
                return {"method": "vlm_qwen2.5vl", "shear_degrees": round(shear, 3), "confidence": round(conf, 2)}
            logger.warning(f"VLM dewarp returned non-finite values: {data!r}")
    except Exception as e:
        # Best effort: any failure (network, HTTP status, bad JSON, bad types)
        # degrades to the zero-confidence result below.
        logger.warning(f"VLM dewarp failed: {e}")
    return {"method": "vlm_qwen2.5vl", "shear_degrees": 0.0, "confidence": 0.0}
@@ -0,0 +1,528 @@
"""
OCR Pipeline Auto-Mode Orchestrator.
POST /sessions/{session_id}/run-auto -- full auto-mode with SSE streaming.
Lizenz: Apache 2.0
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
"""
import logging
import time
from dataclasses import asdict
from typing import Any, Dict, List, Optional
import cv2
import numpy as np
from fastapi import APIRouter, HTTPException, Request
from fastapi.responses import StreamingResponse
from cv_vocab_pipeline import (
OLLAMA_REVIEW_MODEL,
PageRegion,
RowGeometry,
_cells_to_vocab_entries,
_detect_header_footer_gaps,
_detect_sub_columns,
_fix_character_confusion,
_fix_phonetic_brackets,
fix_cell_phonetics,
analyze_layout,
build_cell_grid,
classify_column_types,
create_layout_image,
create_ocr_image,
deskew_image,
deskew_image_by_word_alignment,
detect_column_geometry,
detect_row_geometry,
_apply_shear,
dewarp_image,
llm_review_entries,
)
from ocr_pipeline_common import (
_cache,
_load_session_to_cache,
_get_cached,
)
from ocr_pipeline_session_store import (
get_session_db,
update_session_db,
)
from ocr_pipeline_auto_helpers import (
RunAutoRequest,
auto_sse_event as _auto_sse_event,
detect_shear_with_vlm as _detect_shear_with_vlm,
)
logger = logging.getLogger(__name__)
router = APIRouter(tags=["ocr-pipeline"])
@router.post("/sessions/{session_id}/run-auto")
async def run_auto(session_id: str, req: RunAutoRequest, request: Request):
"""Run the full OCR pipeline automatically from a given step, streaming SSE progress.
Steps:
1. Deskew -- straighten the scan
2. Dewarp -- correct vertical shear (ensemble CV or VLM)
3. Columns -- detect column layout
4. Rows -- detect row layout
5. Words -- OCR each cell
6. LLM review -- correct OCR errors (optional)
Already-completed steps are skipped unless `from_step` forces a rerun.
Yields SSE events of the form:
data: {"step": "deskew", "status": "start"|"done"|"skipped"|"error", ...}
Final event:
data: {"step": "complete", "status": "done", "steps_run": [...], "steps_skipped": [...]}
"""
if req.from_step < 1 or req.from_step > 6:
raise HTTPException(status_code=400, detail="from_step must be 1-6")
if req.dewarp_method not in ("ensemble", "vlm", "cv"):
raise HTTPException(status_code=400, detail="dewarp_method must be: ensemble, vlm, cv")
if session_id not in _cache:
await _load_session_to_cache(session_id)
async def _generate():
steps_run: List[str] = []
steps_skipped: List[str] = []
error_step: Optional[str] = None
session = await get_session_db(session_id)
if not session:
yield await _auto_sse_event("error", "error", {"message": f"Session {session_id} not found"})
return
cached = _get_cached(session_id)
# Step 1: Deskew
if req.from_step <= 1:
yield await _auto_sse_event("deskew", "start", {})
try:
t0 = time.time()
orig_bgr = cached.get("original_bgr")
if orig_bgr is None:
raise ValueError("Original image not loaded")
try:
deskewed_hough, angle_hough = deskew_image(orig_bgr.copy())
except Exception:
deskewed_hough, angle_hough = orig_bgr, 0.0
success_enc, png_orig = cv2.imencode(".png", orig_bgr)
orig_bytes = png_orig.tobytes() if success_enc else b""
try:
deskewed_wa_bytes, angle_wa = deskew_image_by_word_alignment(orig_bytes)
except Exception:
deskewed_wa_bytes, angle_wa = orig_bytes, 0.0
if abs(angle_wa) >= abs(angle_hough) or abs(angle_hough) < 0.1:
method_used = "word_alignment"
angle_applied = angle_wa
wa_arr = np.frombuffer(deskewed_wa_bytes, dtype=np.uint8)
deskewed_bgr = cv2.imdecode(wa_arr, cv2.IMREAD_COLOR)
if deskewed_bgr is None:
deskewed_bgr = deskewed_hough
method_used = "hough"
angle_applied = angle_hough
else:
method_used = "hough"
angle_applied = angle_hough
deskewed_bgr = deskewed_hough
success, png_buf = cv2.imencode(".png", deskewed_bgr)
deskewed_png = png_buf.tobytes() if success else b""
deskew_result = {
"method_used": method_used,
"rotation_degrees": round(float(angle_applied), 3),
"duration_seconds": round(time.time() - t0, 2),
}
cached["deskewed_bgr"] = deskewed_bgr
cached["deskew_result"] = deskew_result
await update_session_db(
session_id,
deskewed_png=deskewed_png,
deskew_result=deskew_result,
auto_rotation_degrees=float(angle_applied),
current_step=3,
)
session = await get_session_db(session_id)
steps_run.append("deskew")
yield await _auto_sse_event("deskew", "done", deskew_result)
except Exception as e:
logger.error(f"Auto-mode deskew failed for {session_id}: {e}")
error_step = "deskew"
yield await _auto_sse_event("deskew", "error", {"message": str(e)})
yield await _auto_sse_event("complete", "error", {"error_step": error_step})
return
else:
steps_skipped.append("deskew")
yield await _auto_sse_event("deskew", "skipped", {"reason": "from_step > 1"})
# Step 2: Dewarp
if req.from_step <= 2:
yield await _auto_sse_event("dewarp", "start", {"method": req.dewarp_method})
try:
t0 = time.time()
deskewed_bgr = cached.get("deskewed_bgr")
if deskewed_bgr is None:
raise ValueError("Deskewed image not available")
if req.dewarp_method == "vlm":
success_enc, png_buf = cv2.imencode(".png", deskewed_bgr)
img_bytes = png_buf.tobytes() if success_enc else b""
vlm_det = await _detect_shear_with_vlm(img_bytes)
shear_deg = vlm_det["shear_degrees"]
if abs(shear_deg) >= 0.05 and vlm_det["confidence"] >= 0.3:
dewarped_bgr = _apply_shear(deskewed_bgr, -shear_deg)
else:
dewarped_bgr = deskewed_bgr
dewarp_info = {
"method": vlm_det["method"],
"shear_degrees": shear_deg,
"confidence": vlm_det["confidence"],
"detections": [vlm_det],
}
else:
dewarped_bgr, dewarp_info = dewarp_image(deskewed_bgr)
success_enc, png_buf = cv2.imencode(".png", dewarped_bgr)
dewarped_png = png_buf.tobytes() if success_enc else b""
dewarp_result = {
"method_used": dewarp_info["method"],
"shear_degrees": dewarp_info["shear_degrees"],
"confidence": dewarp_info["confidence"],
"duration_seconds": round(time.time() - t0, 2),
"detections": dewarp_info.get("detections", []),
}
cached["dewarped_bgr"] = dewarped_bgr
cached["dewarp_result"] = dewarp_result
await update_session_db(
session_id,
dewarped_png=dewarped_png,
dewarp_result=dewarp_result,
auto_shear_degrees=dewarp_info.get("shear_degrees", 0.0),
current_step=4,
)
session = await get_session_db(session_id)
steps_run.append("dewarp")
yield await _auto_sse_event("dewarp", "done", dewarp_result)
except Exception as e:
logger.error(f"Auto-mode dewarp failed for {session_id}: {e}")
error_step = "dewarp"
yield await _auto_sse_event("dewarp", "error", {"message": str(e)})
yield await _auto_sse_event("complete", "error", {"error_step": error_step})
return
else:
steps_skipped.append("dewarp")
yield await _auto_sse_event("dewarp", "skipped", {"reason": "from_step > 2"})
# Step 3: Columns
if req.from_step <= 3:
yield await _auto_sse_event("columns", "start", {})
try:
t0 = time.time()
col_img = cached.get("cropped_bgr") if cached.get("cropped_bgr") is not None else cached.get("dewarped_bgr")
if col_img is None:
raise ValueError("Cropped/dewarped image not available")
ocr_img = create_ocr_image(col_img)
h, w = ocr_img.shape[:2]
geo_result = detect_column_geometry(ocr_img, col_img)
if geo_result is None:
layout_img = create_layout_image(col_img)
regions = analyze_layout(layout_img, ocr_img)
cached["_word_dicts"] = None
cached["_inv"] = None
cached["_content_bounds"] = None
else:
geometries, left_x, right_x, top_y, bottom_y, word_dicts, inv = geo_result
content_w = right_x - left_x
cached["_word_dicts"] = word_dicts
cached["_inv"] = inv
cached["_content_bounds"] = (left_x, right_x, top_y, bottom_y)
header_y, footer_y = _detect_header_footer_gaps(inv, w, h) if inv is not None else (None, None)
geometries = _detect_sub_columns(geometries, content_w, left_x=left_x,
top_y=top_y, header_y=header_y, footer_y=footer_y)
regions = classify_column_types(geometries, content_w, top_y, w, h, bottom_y,
left_x=left_x, right_x=right_x, inv=inv)
columns = [asdict(r) for r in regions]
column_result = {
"columns": columns,
"classification_methods": list({c.get("classification_method", "") for c in columns if c.get("classification_method")}),
"duration_seconds": round(time.time() - t0, 2),
}
cached["column_result"] = column_result
await update_session_db(session_id, column_result=column_result,
row_result=None, word_result=None, current_step=6)
session = await get_session_db(session_id)
steps_run.append("columns")
yield await _auto_sse_event("columns", "done", {
"column_count": len(columns),
"duration_seconds": column_result["duration_seconds"],
})
except Exception as e:
logger.error(f"Auto-mode columns failed for {session_id}: {e}")
error_step = "columns"
yield await _auto_sse_event("columns", "error", {"message": str(e)})
yield await _auto_sse_event("complete", "error", {"error_step": error_step})
return
else:
steps_skipped.append("columns")
yield await _auto_sse_event("columns", "skipped", {"reason": "from_step > 3"})
# Step 4: Rows
if req.from_step <= 4:
yield await _auto_sse_event("rows", "start", {})
try:
t0 = time.time()
row_img = cached.get("cropped_bgr") if cached.get("cropped_bgr") is not None else cached.get("dewarped_bgr")
session = await get_session_db(session_id)
column_result = session.get("column_result") or cached.get("column_result")
if not column_result or not column_result.get("columns"):
raise ValueError("Column detection must complete first")
col_regions = [
PageRegion(
type=c["type"], x=c["x"], y=c["y"],
width=c["width"], height=c["height"],
classification_confidence=c.get("classification_confidence", 1.0),
classification_method=c.get("classification_method", ""),
)
for c in column_result["columns"]
]
word_dicts = cached.get("_word_dicts")
inv = cached.get("_inv")
content_bounds = cached.get("_content_bounds")
if word_dicts is None or inv is None or content_bounds is None:
ocr_img_tmp = create_ocr_image(row_img)
geo_result = detect_column_geometry(ocr_img_tmp, row_img)
if geo_result is None:
raise ValueError("Column geometry detection failed -- cannot detect rows")
_g, lx, rx, ty, by, word_dicts, inv = geo_result
cached["_word_dicts"] = word_dicts
cached["_inv"] = inv
cached["_content_bounds"] = (lx, rx, ty, by)
content_bounds = (lx, rx, ty, by)
left_x, right_x, top_y, bottom_y = content_bounds
row_geoms = detect_row_geometry(inv, word_dicts, left_x, right_x, top_y, bottom_y)
row_list = [
{
"index": r.index, "x": r.x, "y": r.y,
"width": r.width, "height": r.height,
"word_count": r.word_count,
"row_type": r.row_type,
"gap_before": r.gap_before,
}
for r in row_geoms
]
row_result = {
"rows": row_list,
"row_count": len(row_list),
"content_rows": len([r for r in row_geoms if r.row_type == "content"]),
"duration_seconds": round(time.time() - t0, 2),
}
cached["row_result"] = row_result
await update_session_db(session_id, row_result=row_result, current_step=7)
session = await get_session_db(session_id)
steps_run.append("rows")
yield await _auto_sse_event("rows", "done", {
"row_count": len(row_list),
"content_rows": row_result["content_rows"],
"duration_seconds": row_result["duration_seconds"],
})
except Exception as e:
logger.error(f"Auto-mode rows failed for {session_id}: {e}")
error_step = "rows"
yield await _auto_sse_event("rows", "error", {"message": str(e)})
yield await _auto_sse_event("complete", "error", {"error_step": error_step})
return
else:
steps_skipped.append("rows")
yield await _auto_sse_event("rows", "skipped", {"reason": "from_step > 4"})
# Step 5: Words (OCR)
if req.from_step <= 5:
yield await _auto_sse_event("words", "start", {"engine": req.ocr_engine})
try:
t0 = time.time()
word_img = cached.get("cropped_bgr") if cached.get("cropped_bgr") is not None else cached.get("dewarped_bgr")
session = await get_session_db(session_id)
column_result = session.get("column_result") or cached.get("column_result")
row_result = session.get("row_result") or cached.get("row_result")
col_regions = [
PageRegion(
type=c["type"], x=c["x"], y=c["y"],
width=c["width"], height=c["height"],
classification_confidence=c.get("classification_confidence", 1.0),
classification_method=c.get("classification_method", ""),
)
for c in column_result["columns"]
]
row_geoms = [
RowGeometry(
index=r["index"], x=r["x"], y=r["y"],
width=r["width"], height=r["height"],
word_count=r.get("word_count", 0), words=[],
row_type=r.get("row_type", "content"),
gap_before=r.get("gap_before", 0),
)
for r in row_result["rows"]
]
word_dicts = cached.get("_word_dicts")
if word_dicts is not None:
content_bounds = cached.get("_content_bounds")
top_y = content_bounds[2] if content_bounds else min(r.y for r in row_geoms)
for row in row_geoms:
row_y_rel = row.y - top_y
row_bottom_rel = row_y_rel + row.height
row.words = [
w for w in word_dicts
if row_y_rel <= w['top'] + w['height'] / 2 < row_bottom_rel
]
row.word_count = len(row.words)
ocr_img = create_ocr_image(word_img)
img_h, img_w = word_img.shape[:2]
cells, columns_meta = build_cell_grid(
ocr_img, col_regions, row_geoms, img_w, img_h,
ocr_engine=req.ocr_engine, img_bgr=word_img,
)
duration = time.time() - t0
col_types = {c['type'] for c in columns_meta}
is_vocab = bool(col_types & {'column_en', 'column_de'})
n_content_rows = len([r for r in row_geoms if r.row_type == 'content'])
used_engine = cells[0].get("ocr_engine", "tesseract") if cells else req.ocr_engine
fix_cell_phonetics(cells, pronunciation=req.pronunciation)
word_result_data = {
"cells": cells,
"grid_shape": {
"rows": n_content_rows,
"cols": len(columns_meta),
"total_cells": len(cells),
},
"columns_used": columns_meta,
"layout": "vocab" if is_vocab else "generic",
"image_width": img_w,
"image_height": img_h,
"duration_seconds": round(duration, 2),
"ocr_engine": used_engine,
"summary": {
"total_cells": len(cells),
"non_empty_cells": sum(1 for c in cells if c.get("text")),
"low_confidence": sum(1 for c in cells if 0 < c.get("confidence", 0) < 50),
},
}
has_text_col = 'column_text' in col_types
if is_vocab or has_text_col:
entries = _cells_to_vocab_entries(cells, columns_meta)
entries = _fix_character_confusion(entries)
entries = _fix_phonetic_brackets(entries, pronunciation=req.pronunciation)
word_result_data["vocab_entries"] = entries
word_result_data["entries"] = entries
word_result_data["entry_count"] = len(entries)
word_result_data["summary"]["total_entries"] = len(entries)
await update_session_db(session_id, word_result=word_result_data, current_step=8)
cached["word_result"] = word_result_data
session = await get_session_db(session_id)
steps_run.append("words")
yield await _auto_sse_event("words", "done", {
"total_cells": len(cells),
"layout": word_result_data["layout"],
"duration_seconds": round(duration, 2),
"ocr_engine": used_engine,
"summary": word_result_data["summary"],
})
except Exception as e:
logger.error(f"Auto-mode words failed for {session_id}: {e}")
error_step = "words"
yield await _auto_sse_event("words", "error", {"message": str(e)})
yield await _auto_sse_event("complete", "error", {"error_step": error_step})
return
else:
steps_skipped.append("words")
yield await _auto_sse_event("words", "skipped", {"reason": "from_step > 5"})
# Step 6: LLM Review (optional)
if req.from_step <= 6 and not req.skip_llm_review:
yield await _auto_sse_event("llm_review", "start", {"model": OLLAMA_REVIEW_MODEL})
try:
session = await get_session_db(session_id)
word_result = session.get("word_result") or cached.get("word_result")
entries = word_result.get("entries") or word_result.get("vocab_entries") or []
if not entries:
yield await _auto_sse_event("llm_review", "skipped", {"reason": "no entries"})
steps_skipped.append("llm_review")
else:
reviewed = await llm_review_entries(entries)
session = await get_session_db(session_id)
word_result_updated = dict(session.get("word_result") or {})
word_result_updated["entries"] = reviewed
word_result_updated["vocab_entries"] = reviewed
word_result_updated["llm_reviewed"] = True
word_result_updated["llm_model"] = OLLAMA_REVIEW_MODEL
await update_session_db(session_id, word_result=word_result_updated, current_step=9)
cached["word_result"] = word_result_updated
steps_run.append("llm_review")
yield await _auto_sse_event("llm_review", "done", {
"entries_reviewed": len(reviewed),
"model": OLLAMA_REVIEW_MODEL,
})
except Exception as e:
logger.warning(f"Auto-mode llm_review failed for {session_id} (non-fatal): {e}")
yield await _auto_sse_event("llm_review", "error", {"message": str(e), "fatal": False})
steps_skipped.append("llm_review")
else:
steps_skipped.append("llm_review")
reason = "skipped by request" if req.skip_llm_review else "from_step > 6"
yield await _auto_sse_event("llm_review", "skipped", {"reason": reason})
# Final event
yield await _auto_sse_event("complete", "done", {
"steps_run": steps_run,
"steps_skipped": steps_skipped,
})
return StreamingResponse(
_generate(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no",
},
)
@@ -0,0 +1,94 @@
"""
OCR Pipeline Reprocess Endpoint.
POST /sessions/{session_id}/reprocess — clear downstream + restart from step.
Lizenz: Apache 2.0
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
"""
import logging
from typing import Any, Dict
from fastapi import APIRouter, HTTPException, Request
from ocr_pipeline_common import _cache
from ocr_pipeline_session_store import get_session_db, update_session_db
logger = logging.getLogger(__name__)
router = APIRouter(tags=["ocr-pipeline"])
@router.post("/sessions/{session_id}/reprocess")
async def reprocess_session(session_id: str, request: Request):
    """Reset a session so the pipeline can be re-run from a given step.

    Body: {"from_step": N} with an integer N in 1..10 (defaults to 1).

    The stored results are cleared according to these thresholds (this is
    what the code below implements):

    - from_step <= 8: word_result is cleared
    - from_step == 9: only the "llm_review" / "llm_corrections" keys are
      removed from word_result
    - from_step <= 7: row_result is cleared
    - from_step <= 6: column_result is cleared
    - from_step <= 4: crop_result is cleared
    - from_step <= 3: dewarp_result is cleared
    - from_step <= 2: deskew_result is cleared
    - from_step <= 1: orientation_result is cleared

    The same values are mirrored into the in-memory cache entry of the
    session, if one exists, and current_step is set to from_step.

    Returns:
        Dict with "session_id", "from_step" and "cleared" (names of the
        session fields that were written, excluding current_step).

    Raises:
        HTTPException 404: the session does not exist.
        HTTPException 400: the body is not a JSON object or from_step is
            not an integer between 1 and 10.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    # A malformed or non-object body is a client error, not a server crash.
    try:
        body = await request.json()
    except ValueError:
        raise HTTPException(status_code=400, detail="Request body must be valid JSON") from None
    if not isinstance(body, dict):
        raise HTTPException(status_code=400, detail="Request body must be a JSON object")

    from_step = body.get("from_step", 1)
    # bool is a subclass of int; reject it explicitly so `true` is not read as step 1.
    if isinstance(from_step, bool) or not isinstance(from_step, int) or not 1 <= from_step <= 10:
        raise HTTPException(status_code=400, detail="from_step must be between 1 and 10")

    update_kwargs: Dict[str, Any] = {"current_step": from_step}

    # Clear downstream data based on from_step
    # New pipeline order: Orient(2) -> Deskew(3) -> Dewarp(4) -> Crop(5) ->
    # Columns(6) -> Rows(7) -> Words(8) -> LLM(9) -> Recon(10) -> GT(11)
    # NOTE(review): the word/row/column thresholds follow the numbering above,
    # while the crop/dewarp/deskew/orientation thresholds look like they still
    # use a numbering that starts one lower — confirm which one is intended.
    if from_step <= 8:
        update_kwargs["word_result"] = None
    elif from_step == 9:
        # Only clear LLM review from word_result
        word_result = session.get("word_result")
        if word_result:
            # Work on a copy so the loaded session dict is not mutated in place.
            word_result = dict(word_result)
            word_result.pop("llm_review", None)
            word_result.pop("llm_corrections", None)
            update_kwargs["word_result"] = word_result

    if from_step <= 7:
        update_kwargs["row_result"] = None
    if from_step <= 6:
        update_kwargs["column_result"] = None
    if from_step <= 4:
        update_kwargs["crop_result"] = None
    if from_step <= 3:
        update_kwargs["dewarp_result"] = None
    if from_step <= 2:
        update_kwargs["deskew_result"] = None
    if from_step <= 1:
        update_kwargs["orientation_result"] = None

    await update_session_db(session_id, **update_kwargs)

    # Also clear cache (update_kwargs already carries current_step == from_step)
    if session_id in _cache:
        _cache[session_id].update(update_kwargs)

    logger.info(f"Session {session_id} reprocessing from step {from_step}")

    return {
        "session_id": session_id,
        "from_step": from_step,
        "cleared": [k for k in update_kwargs if k != "current_step"],
    }
+25 -750
View File
@@ -1,758 +1,33 @@
"""
Page Crop - Content-based crop for scanned pages and book scans.
Page Crop — Barrel Re-export
Detects the content boundary by analysing ink density projections and
(for book scans) the spine shadow gradient. Works with both loose A4
sheets on dark scanners AND book scans with white backgrounds.
Content-based crop for scanned pages and book scans.
Split into:
- page_crop_edges.py — Edge detection (spine shadow, gutter, projection)
- page_crop_core.py — Main crop algorithm and format detection
All public names are re-exported here for backward compatibility.
License: Apache 2.0
"""
import logging
from typing import Dict, Any, Tuple, Optional
# Core: main crop functions and format detection
from page_crop_core import ( # noqa: F401
PAPER_FORMATS,
detect_page_splits,
detect_and_crop_page,
_detect_format,
)
import cv2
import numpy as np
logger = logging.getLogger(__name__)
# Known paper format aspect ratios (height / width, portrait orientation)
# Note: the three A-series ratios below differ by less than 0.005, so a
# closest-ratio match among them depends on very small differences.
PAPER_FORMATS = {
    "A4": 297.0 / 210.0,  # 1.4143
    "A5": 210.0 / 148.0,  # 1.4189
    "Letter": 11.0 / 8.5,  # 1.2941
    "Legal": 14.0 / 8.5,  # 1.6471
    "A3": 420.0 / 297.0,  # 1.4141
}

# Minimum ink density (fraction of pixels) to count a row/column as "content"
_INK_THRESHOLD = 0.003  # 0.3%

# Minimum run length (fraction of dimension) to keep — shorter runs are noise
_MIN_RUN_FRAC = 0.005  # 0.5%
def detect_page_splits(
    img_bgr: np.ndarray,
) -> list:
    """Detect if the image is a multi-page spread and return split rectangles.

    Uses **brightness** (not ink density) to find the spine area:
    the scanner bed produces a characteristic gray strip where pages meet,
    which is darker than the white paper on either side.

    Args:
        img_bgr: Input image; it is converted with ``cv2.COLOR_BGR2GRAY``.

    Returns:
        A list of page dicts ``{x, y, width, height, page_index}``
        or an empty list if only one page is detected. Because a single
        split point is produced, the list holds at most two pages.
    """
    h, w = img_bgr.shape[:2]

    # Only check landscape-ish images (width > height * 1.15)
    if w < h * 1.15:
        return []

    gray = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY)

    # Column-mean brightness (0-255) — the spine is darker (gray scanner bed)
    col_brightness = np.mean(gray, axis=0).astype(np.float64)

    # Heavy smoothing to ignore individual text lines
    # (boxcar kernel, forced to an odd width so it is centred on each column)
    kern = max(11, w // 50)
    if kern % 2 == 0:
        kern += 1
    brightness_smooth = np.convolve(col_brightness, np.ones(kern) / kern, mode="same")

    # Page paper is bright (typically > 200), spine/scanner bed is darker
    page_brightness = float(np.max(brightness_smooth))
    if page_brightness < 100:
        return []  # Very dark image, skip

    # Spine threshold: significantly darker than the page
    # Spine is typically 60-80% of paper brightness
    spine_thresh = page_brightness * 0.88

    # Search in center region (30-70% of width)
    center_lo = int(w * 0.30)
    center_hi = int(w * 0.70)

    # Find the darkest valley in the center region
    center_brightness = brightness_smooth[center_lo:center_hi]
    darkest_val = float(np.min(center_brightness))

    if darkest_val >= spine_thresh:
        logger.debug("No spine detected: min brightness %.0f >= threshold %.0f",
                     darkest_val, spine_thresh)
        return []

    # Find ALL contiguous dark runs in the center region
    # (indices are relative to center_lo; each run is a half-open [start, end) pair)
    is_dark = center_brightness < spine_thresh
    dark_runs: list = []  # list of (start, end) pairs
    run_start = -1
    for i in range(len(is_dark)):
        if is_dark[i]:
            if run_start < 0:
                run_start = i
        else:
            if run_start >= 0:
                dark_runs.append((run_start, i))
                run_start = -1
    # Close a run that is still open at the end of the region
    if run_start >= 0:
        dark_runs.append((run_start, len(is_dark)))

    # Filter out runs that are too narrow (< 1% of image width)
    min_spine_px = int(w * 0.01)
    dark_runs = [(s, e) for s, e in dark_runs if e - s >= min_spine_px]

    if not dark_runs:
        logger.debug("No dark runs wider than %dpx in center region", min_spine_px)
        return []

    # Score each dark run: prefer centered, dark, narrow valleys
    center_region_len = center_hi - center_lo
    image_center_in_region = (w * 0.5 - center_lo)  # x=50% mapped into region coords
    best_score = -1.0
    best_start, best_end = dark_runs[0]

    for rs, re in dark_runs:
        run_width = re - rs
        run_center = (rs + re) / 2.0

        # --- Factor 1: Proximity to image center (gaussian, sigma = 15% of region) ---
        sigma = center_region_len * 0.15
        dist = abs(run_center - image_center_in_region)
        center_factor = float(np.exp(-0.5 * (dist / sigma) ** 2))

        # --- Factor 2: Darkness (how dark is the valley relative to threshold) ---
        run_brightness = float(np.mean(center_brightness[rs:re]))
        # Normalize: 1.0 when run_brightness == 0, 0.0 when run_brightness == spine_thresh
        darkness_factor = max(0.0, (spine_thresh - run_brightness) / spine_thresh)

        # --- Factor 3: Narrowness bonus (spine shadows are narrow, not wide plateaus) ---
        # Typical spine: 1-5% of image width. Penalise runs wider than ~8%.
        width_frac = run_width / w
        if width_frac <= 0.05:
            narrowness_bonus = 1.0
        elif width_frac <= 0.15:
            narrowness_bonus = 1.0 - (width_frac - 0.05) / 0.10  # linear decay 1.0 → 0.0
        else:
            narrowness_bonus = 0.0

        # Narrowness only modulates the score (between 0.3x and 1.0x); a wide
        # run is penalised but can still win if nothing better exists.
        score = center_factor * darkness_factor * (0.3 + 0.7 * narrowness_bonus)

        logger.debug(
            "Dark run x=%d..%d (w=%d): center_f=%.3f dark_f=%.3f narrow_b=%.3f → score=%.4f",
            center_lo + rs, center_lo + re, run_width,
            center_factor, darkness_factor, narrowness_bonus, score,
        )

        if score > best_score:
            best_score = score
            best_start, best_end = rs, re

    # Convert the winning run back to full-image x coordinates
    spine_w = best_end - best_start
    spine_x = center_lo + best_start
    spine_center = spine_x + spine_w // 2

    logger.debug(
        "Best spine candidate: x=%d..%d (w=%d), score=%.4f",
        spine_x, spine_x + spine_w, spine_w, best_score,
    )

    # Verify: must have bright (paper) content on BOTH sides
    # (mean over a band of up to w // 10 columns directly next to the run)
    left_brightness = float(np.mean(brightness_smooth[max(0, spine_x - w // 10):spine_x]))
    right_end = center_lo + best_end
    right_brightness = float(np.mean(brightness_smooth[right_end:min(w, right_end + w // 10)]))

    if left_brightness < spine_thresh or right_brightness < spine_thresh:
        logger.debug("No bright paper flanking spine: left=%.0f right=%.0f thresh=%.0f",
                     left_brightness, right_brightness, spine_thresh)
        return []

    logger.info(
        "Spine detected: x=%d..%d (w=%d), brightness=%.0f vs paper=%.0f, "
        "left_paper=%.0f, right_paper=%.0f",
        spine_x, right_end, spine_w, darkest_val, page_brightness,
        left_brightness, right_brightness,
    )

    # Split at the spine center
    split_points = [spine_center]

    # Build page rectangles
    pages: list = []
    prev_x = 0
    for i, sx in enumerate(split_points):
        pages.append({"x": prev_x, "y": 0, "width": sx - prev_x,
                      "height": h, "page_index": i})
        prev_x = sx
    # Remainder to the right of the last split point
    pages.append({"x": prev_x, "y": 0, "width": w - prev_x,
                  "height": h, "page_index": len(split_points)})

    # Filter out tiny pages (< 15% of total width)
    pages = [p for p in pages if p["width"] >= w * 0.15]
    if len(pages) < 2:
        return []

    # Re-index
    for i, p in enumerate(pages):
        p["page_index"] = i

    logger.info(
        "Page split detected: %d pages, spine_w=%d, split_points=%s",
        len(pages), spine_w, split_points,
    )
    return pages
def detect_and_crop_page(
    img_bgr: np.ndarray,
    margin_frac: float = 0.01,
) -> Tuple[np.ndarray, Dict[str, Any]]:
    """Crop a page image down to its detected content area.

    Steps:
        1. Adaptive threshold of the grayscale image (inverted binary).
        2. Left and right edges via the shadow-aware edge detectors.
        3. Top and bottom edges via horizontal projection of the binary image.
        4. Skip cropping if every border is below 2% or if the crop would
           keep less than 40% of the original area.
        5. Otherwise crop, widened by ``margin_frac`` on each side.

    Args:
        img_bgr: Input BGR image (should already be deskewed/dewarped).
        margin_frac: Extra margin around the content, as a fraction of the
            respective image dimension (default 1%).

    Returns:
        Tuple ``(image, result)``. ``image`` is the cropped copy, or the
        unchanged input when no crop was applied. ``result`` carries crop
        rectangle, sizes, border fractions and the detected paper format.
    """
    img_h, img_w = img_bgr.shape[:2]
    full_area = img_h * img_w

    info: Dict[str, Any] = {
        "crop_applied": False,
        "crop_rect": None,
        "crop_rect_pct": None,
        "original_size": {"width": img_w, "height": img_h},
        "cropped_size": {"width": img_w, "height": img_h},
        "detected_format": None,
        "format_confidence": 0.0,
        "aspect_ratio": round(max(img_h, img_w) / max(min(img_h, img_w), 1), 4),
        "border_fractions": {"top": 0.0, "bottom": 0.0, "left": 0.0, "right": 0.0},
    }

    def _uncropped() -> Tuple[np.ndarray, Dict[str, Any]]:
        # Report the format of the untouched image and hand it back as-is.
        fmt_name, fmt_conf = _detect_format(img_w, img_h)
        info["detected_format"] = fmt_name
        info["format_confidence"] = fmt_conf
        return img_bgr, info

    gray = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY)

    # Adaptive threshold so that faint content on a bright background still registers.
    ink = cv2.adaptiveThreshold(
        gray, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
        cv2.THRESH_BINARY_INV, blockSize=51, C=15,
    )

    x_left = _detect_left_edge_shadow(gray, ink, img_w, img_h)
    x_right = _detect_right_edge_shadow(gray, ink, img_w, img_h)
    y_top, y_bottom = _detect_top_bottom_edges(ink, img_w, img_h)

    # Width of each border relative to the matching image dimension.
    fractions = {
        "top": y_top / img_h,
        "bottom": (img_h - y_bottom) / img_h,
        "left": x_left / img_w,
        "right": (img_w - x_right) / img_w,
    }
    info["border_fractions"] = {side: round(frac, 4) for side, frac in fractions.items()}

    # Nothing worth trimming unless at least one border reaches 2%.
    min_border = 0.02
    if all(frac < min_border for frac in fractions.values()):
        logger.info("All borders < %.0f%% — no crop needed", min_border * 100)
        return _uncropped()

    # Widen the content box by the requested margin, clamped to the image.
    pad_x = int(img_w * margin_frac)
    pad_y = int(img_h * margin_frac)
    x0 = max(0, x_left - pad_x)
    y0 = max(0, y_top - pad_y)
    x1 = min(img_w, x_right + pad_x)
    y1 = min(img_h, y_bottom + pad_y)
    out_w = x1 - x0
    out_h = y1 - y0

    # Refuse crops that would discard more than 60% of the image.
    if out_w * out_h < 0.40 * full_area:
        logger.warning("Cropped area too small (%.0f%%) — skipping crop",
                       100.0 * out_w * out_h / full_area)
        return _uncropped()

    cropped = img_bgr[y0:y1, x0:x1].copy()
    fmt_name, fmt_conf = _detect_format(out_w, out_h)

    info.update({
        "crop_applied": True,
        "crop_rect": {"x": x0, "y": y0, "width": out_w, "height": out_h},
        "crop_rect_pct": {
            "x": round(100.0 * x0 / img_w, 2),
            "y": round(100.0 * y0 / img_h, 2),
            "width": round(100.0 * out_w / img_w, 2),
            "height": round(100.0 * out_h / img_h, 2),
        },
        "cropped_size": {"width": out_w, "height": out_h},
        "detected_format": fmt_name,
        "format_confidence": fmt_conf,
        "aspect_ratio": round(max(out_w, out_h) / max(min(out_w, out_h), 1), 4),
    })

    logger.info(
        "Page cropped: %dx%d -> %dx%d, format=%s (%.0f%%), "
        "borders: T=%.1f%% B=%.1f%% L=%.1f%% R=%.1f%%",
        img_w, img_h, out_w, out_h, fmt_name, fmt_conf * 100,
        fractions["top"] * 100, fractions["bottom"] * 100,
        fractions["left"] * 100, fractions["right"] * 100,
    )
    return cropped, info
# ---------------------------------------------------------------------------
# Edge detection helpers
# ---------------------------------------------------------------------------
def _detect_spine_shadow(
    gray: np.ndarray,
    search_region: np.ndarray,
    offset_x: int,
    w: int,
    side: str,
) -> Optional[int]:
    """Find the book spine center (darkest point) in a scanner shadow.

    The scanner produces a gray strip where the book spine presses against
    the glass. The darkest column in that strip is the spine center —
    that's where we crop.

    Distinguishes real spine shadows from text content by checking:
    1. Strong brightness range (> 40 levels)
    2. Darkest point is genuinely dark (< 180 mean brightness)
    3. The dark area is a NARROW valley, not a text-content plateau
    4. Brightness rises significantly toward the page content side

    Args:
        gray: Full grayscale image (for context; not read in this function).
        search_region: Column slice of the grayscale image to search in.
        offset_x: X offset of search_region relative to full image.
        w: Full image width.
        side: 'left' or 'right'. Selects the direction of check 4 and is
            used in log messages; any value other than 'left' is treated
            as the right side.

    Returns:
        X coordinate (in full image) of the spine center, or None when the
        region is too narrow or any of the four checks fails.
    """
    region_w = search_region.shape[1]
    if region_w < 10:
        return None

    # Column-mean brightness in the search region
    col_means = np.mean(search_region, axis=0).astype(np.float64)

    # Smooth with boxcar kernel (width = 1% of image width, min 5)
    kernel_size = max(5, w // 100)
    if kernel_size % 2 == 0:
        kernel_size += 1
    kernel = np.ones(kernel_size) / kernel_size
    smoothed_raw = np.convolve(col_means, kernel, mode="same")

    # Trim convolution edge artifacts (edges are zero-padded → artificially low)
    margin = kernel_size // 2
    if region_w <= 2 * margin + 10:
        return None
    smoothed = smoothed_raw[margin:region_w - margin]
    trim_offset = margin  # offset of smoothed[0] relative to search_region

    val_min = float(np.min(smoothed))
    val_max = float(np.max(smoothed))
    shadow_range = val_max - val_min

    # --- Check 1: Strong brightness gradient ---
    if shadow_range <= 40:
        logger.debug(
            "%s edge: no spine (range=%.0f <= 40)", side.capitalize(), shadow_range,
        )
        return None

    # --- Check 2: Darkest point must be genuinely dark ---
    # Spine shadows have mean column brightness 60-160.
    # Text on white paper stays above 180.
    if val_min > 180:
        logger.debug(
            "%s edge: no spine (darkest=%.0f > 180, likely text)", side.capitalize(), val_min,
        )
        return None

    spine_idx = int(np.argmin(smoothed))  # index in trimmed array
    spine_local = spine_idx + trim_offset  # index in search_region
    trimmed_len = len(smoothed)

    # --- Check 3: Valley width (spine is narrow, text plateau is wide) ---
    # Count how many columns are within 20% of the shadow range above the min.
    # (the count covers the whole trimmed region, not only columns adjacent
    # to the minimum)
    valley_thresh = val_min + shadow_range * 0.20
    valley_mask = smoothed < valley_thresh
    valley_width = int(np.sum(valley_mask))
    # Spine valleys are typically 3-15% of image width (20-120px on a 800px image).
    # Text content plateaus span 20%+ of the search region.
    max_valley_frac = 0.50  # valley must not cover more than half the trimmed region
    if valley_width > trimmed_len * max_valley_frac:
        logger.debug(
            "%s edge: no spine (valley too wide: %d/%d = %.0f%%)",
            side.capitalize(), valley_width, trimmed_len,
            100.0 * valley_width / trimmed_len,
        )
        return None

    # --- Check 4: Brightness must rise toward page content ---
    # For left edge: after spine, brightness should rise (= page paper)
    # For right edge: before spine, brightness should rise
    # The window starts 5 columns away from the minimum; if the window is
    # empty (minimum at the far end of the region) the check is skipped.
    rise_check_w = max(5, trimmed_len // 5)  # check 20% of trimmed region
    if side == "left":
        # Check columns to the right of the spine (in trimmed array)
        right_start = min(spine_idx + 5, trimmed_len - 1)
        right_end = min(right_start + rise_check_w, trimmed_len)
        if right_end > right_start:
            rise_brightness = float(np.mean(smoothed[right_start:right_end]))
            rise = rise_brightness - val_min
            if rise < shadow_range * 0.3:
                logger.debug(
                    "%s edge: no spine (insufficient rise: %.0f, need %.0f)",
                    side.capitalize(), rise, shadow_range * 0.3,
                )
                return None
    else:  # right
        # Check columns to the left of the spine (in trimmed array)
        left_end = max(spine_idx - 5, 0)
        left_start = max(left_end - rise_check_w, 0)
        if left_end > left_start:
            rise_brightness = float(np.mean(smoothed[left_start:left_end]))
            rise = rise_brightness - val_min
            if rise < shadow_range * 0.3:
                logger.debug(
                    "%s edge: no spine (insufficient rise: %.0f, need %.0f)",
                    side.capitalize(), rise, shadow_range * 0.3,
                )
                return None

    # Translate from search-region coordinates to full-image coordinates
    spine_x = offset_x + spine_local

    logger.info(
        "%s edge: spine center at x=%d (brightness=%.0f, range=%.0f, valley=%dpx)",
        side.capitalize(), spine_x, val_min, shadow_range, valley_width,
    )
    return spine_x
def _detect_gutter_continuity(
    gray: np.ndarray,
    search_region: np.ndarray,
    offset_x: int,
    w: int,
    side: str,
) -> Optional[int]:
    """Detect gutter shadow via vertical continuity analysis.

    Camera book scans produce a subtle brightness gradient at the gutter
    that is too faint for scanner-shadow detection (range < 40). However,
    the gutter shadow has a unique property: it runs **continuously from
    top to bottom** without interruption. Text and images always have
    vertical gaps between lines, paragraphs, or sections.

    Algorithm:
    1. Divide image into N horizontal strips (~60px each, at least 10)
    2. For each column, compute what fraction of strips are darker than
       the page median (from the center 50% of the full image) minus 5
       brightness levels
    3. Smooth that dark-fraction profile; its peak must reach 0.70,
       otherwise no gutter is reported
    4. From the peak, walk toward the page center until the smoothed
       fraction drops below 0.50 — that is the gutter's inner edge
    5. Validate: gutter band must be 0.5%-10% of image width and at least
       3 brightness levels darker than the page median

    Args:
        gray: Full grayscale image.
        search_region: Edge slice of the grayscale image.
        offset_x: X offset of search_region relative to full image.
        w: Full image width.
        side: 'left' or 'right'; any value other than 'right' is handled
            as the left side.

    Returns:
        X coordinate (in full image) of the gutter inner edge, or None.
    """
    region_h, region_w = search_region.shape[:2]
    if region_w < 20 or region_h < 100:
        return None

    # --- 1. Divide into horizontal strips ---
    strip_target_h = 60  # ~60px per strip
    n_strips = max(10, region_h // strip_target_h)
    strip_h = region_h // n_strips

    # strip_means[s, x] = mean brightness of column x within strip s
    strip_means = np.zeros((n_strips, region_w), dtype=np.float64)
    for s in range(n_strips):
        y0 = s * strip_h
        y1 = min((s + 1) * strip_h, region_h)
        strip_means[s] = np.mean(search_region[y0:y1, :], axis=0)

    # --- 2. Page median from center 50% of full image ---
    center_lo = w // 4
    center_hi = 3 * w // 4
    page_median = float(np.median(gray[:, center_lo:center_hi]))

    # Camera shadows are subtle — threshold just 5 levels below page median
    dark_thresh = page_median - 5.0

    # If page is very dark overall (e.g. photo, not a book page), bail out
    if page_median < 180:
        return None

    # --- 3. Per-column dark fraction ---
    dark_count = np.sum(strip_means < dark_thresh, axis=0).astype(np.float64)
    dark_frac = dark_count / n_strips  # shape: (region_w,)

    # --- 4. Smooth and find transition ---
    # Rolling mean (window = 1% of image width, min 5)
    smooth_w = max(5, w // 100)
    if smooth_w % 2 == 0:
        smooth_w += 1
    kernel = np.ones(smooth_w) / smooth_w
    frac_smooth = np.convolve(dark_frac, kernel, mode="same")

    # Trim convolution edges
    margin = smooth_w // 2
    if region_w <= 2 * margin + 10:
        return None

    # Find the peak of dark fraction (gutter center).
    # For right gutters the peak is near the edge; for left gutters
    # (V-shaped spine shadow) the peak may be well inside the region.
    transition_thresh = 0.50
    peak_frac = float(np.max(frac_smooth[margin:region_w - margin]))
    if peak_frac < 0.70:
        logger.debug(
            "%s gutter: peak dark fraction %.2f < 0.70", side.capitalize(), peak_frac,
        )
        return None

    # argmax is taken on the trimmed slice, so add margin to get a region index
    peak_x = int(np.argmax(frac_smooth[margin:region_w - margin])) + margin
    gutter_inner = None  # local x in search_region

    if side == "right":
        # Scan from peak toward the page center (leftward)
        for x in range(peak_x, margin, -1):
            if frac_smooth[x] < transition_thresh:
                gutter_inner = x + 1
                break
    else:
        # Scan from peak toward the page center (rightward)
        for x in range(peak_x, region_w - margin):
            if frac_smooth[x] < transition_thresh:
                gutter_inner = x - 1
                break

    # No transition found: the dark band never ends within the search region
    if gutter_inner is None:
        return None

    # --- 5. Validate gutter width ---
    # Width is measured from the outer edge of the search region to gutter_inner
    if side == "right":
        gutter_width = region_w - gutter_inner
    else:
        gutter_width = gutter_inner

    min_gutter = max(3, int(w * 0.005))  # at least 0.5% of image
    max_gutter = int(w * 0.10)  # at most 10% of image

    if gutter_width < min_gutter:
        logger.debug(
            "%s gutter: too narrow (%dpx < %dpx)", side.capitalize(),
            gutter_width, min_gutter,
        )
        return None

    if gutter_width > max_gutter:
        logger.debug(
            "%s gutter: too wide (%dpx > %dpx)", side.capitalize(),
            gutter_width, max_gutter,
        )
        return None

    # Check that the gutter band is meaningfully darker than the page
    if side == "right":
        gutter_brightness = float(np.mean(strip_means[:, gutter_inner:]))
    else:
        gutter_brightness = float(np.mean(strip_means[:, :gutter_inner]))

    brightness_drop = page_median - gutter_brightness
    if brightness_drop < 3:
        logger.debug(
            "%s gutter: insufficient brightness drop (%.1f levels)",
            side.capitalize(), brightness_drop,
        )
        return None

    # Translate from search-region coordinates to full-image coordinates
    gutter_x = offset_x + gutter_inner

    logger.info(
        "%s gutter (continuity): x=%d, width=%dpx (%.1f%%), "
        "brightness=%.0f vs page=%.0f (drop=%.0f), frac@edge=%.2f",
        side.capitalize(), gutter_x, gutter_width,
        100.0 * gutter_width / w, gutter_brightness, page_median,
        brightness_drop, float(frac_smooth[gutter_inner]),
    )
    return gutter_x
def _detect_left_edge_shadow(
    gray: np.ndarray,
    binary: np.ndarray,
    w: int,
    h: int,
) -> int:
    """Locate the left content edge of the page.

    The leftmost quarter of the grayscale image is examined by two
    shadow detectors in turn; the first one that reports a position wins:

    1. Scanner spine-shadow detection
    2. Camera gutter continuity detection

    If neither finds anything, the first ink column of the binary
    vertical projection is returned.
    """
    band_w = max(1, w // 4)
    left_band = gray[:, :band_w]

    for detector in (_detect_spine_shadow, _detect_gutter_continuity):
        edge_x = detector(gray, left_band, 0, w, "left")
        if edge_x is not None:
            return edge_x

    # Last resort: binary vertical projection
    return _detect_edge_projection(binary, axis=0, from_start=True, dim=w)
def _detect_right_edge_shadow(
    gray: np.ndarray,
    binary: np.ndarray,
    w: int,
    h: int,
) -> int:
    """Locate the right content edge of the page.

    The rightmost quarter of the grayscale image is examined by two
    shadow detectors in turn; the first one that reports a position wins:

    1. Scanner spine-shadow detection
    2. Camera gutter continuity detection

    If neither finds anything, the last ink column of the binary
    vertical projection is returned.
    """
    band_w = max(1, w // 4)
    band_x0 = w - band_w
    right_band = gray[:, band_x0:]

    for detector in (_detect_spine_shadow, _detect_gutter_continuity):
        edge_x = detector(gray, right_band, band_x0, w, "right")
        if edge_x is not None:
            return edge_x

    # Last resort: binary vertical projection
    return _detect_edge_projection(binary, axis=0, from_start=False, dim=w)
def _detect_top_bottom_edges(binary: np.ndarray, w: int, h: int) -> Tuple[int, int]:
    """Return ``(top, bottom)`` content edges from the horizontal projection of ``binary``."""
    first_row, last_row = (
        _detect_edge_projection(binary, axis=1, from_start=from_top, dim=h)
        for from_top in (True, False)
    )
    return first_row, last_row
def _detect_edge_projection(
    binary: np.ndarray,
    axis: int,
    from_start: bool,
    dim: int,
) -> int:
    """Return the first or last position whose ink density reaches the threshold.

    axis=0 → densities per column → result is an x position
    axis=1 → densities per row → result is a y position

    Runs of ink positions shorter than ``_MIN_RUN_FRAC`` of ``dim`` are
    discarded as noise before the edge is picked. When no ink remains,
    0 is returned for ``from_start=True`` and ``dim`` otherwise.
    """
    # Mean of the 0/255 binary pixels, scaled to a 0..1 density
    density = np.mean(binary, axis=axis) / 255.0

    shortest_run = max(1, int(dim * _MIN_RUN_FRAC))
    has_ink = _filter_narrow_runs(density >= _INK_THRESHOLD, shortest_run)

    hits = np.flatnonzero(has_ink)
    if hits.size == 0:
        return 0 if from_start else dim

    return int(hits[0] if from_start else hits[-1])
def _filter_narrow_runs(mask: np.ndarray, min_run: int) -> np.ndarray:
"""Remove True-runs shorter than min_run pixels."""
if min_run <= 1:
return mask
result = mask.copy()
n = len(result)
i = 0
while i < n:
if result[i]:
start = i
while i < n and result[i]:
i += 1
if i - start < min_run:
result[start:i] = False
else:
i += 1
return result
# ---------------------------------------------------------------------------
# Format detection (kept as optional metadata)
# ---------------------------------------------------------------------------
def _detect_format(width: int, height: int) -> Tuple[str, float]:
"""Detect paper format from dimensions by comparing aspect ratios."""
if width <= 0 or height <= 0:
return "unknown", 0.0
aspect = max(width, height) / min(width, height)
best_format = "unknown"
best_diff = float("inf")
for fmt, expected_ratio in PAPER_FORMATS.items():
diff = abs(aspect - expected_ratio)
if diff < best_diff:
best_diff = diff
best_format = fmt
confidence = max(0.0, 1.0 - best_diff * 5.0)
if confidence < 0.3:
return "unknown", 0.0
return best_format, round(confidence, 3)
from page_crop_edges import ( # noqa: F401
_INK_THRESHOLD,
_MIN_RUN_FRAC,
_detect_spine_shadow,
_detect_gutter_continuity,
_detect_left_edge_shadow,
_detect_right_edge_shadow,
_detect_top_bottom_edges,
_detect_edge_projection,
_filter_narrow_runs,
)
+342
View File
@@ -0,0 +1,342 @@
"""
Page Crop - Core Crop and Format Detection
Content-based crop for scanned pages and book scans. Detects the content
boundary by analysing ink density projections and (for book scans) the
spine shadow gradient.
Extracted from page_crop.py to keep files under 500 LOC.
License: Apache 2.0
"""
import logging
from typing import Dict, Any, Tuple
import cv2
import numpy as np
from page_crop_edges import (
_detect_left_edge_shadow,
_detect_right_edge_shadow,
_detect_top_bottom_edges,
)
logger = logging.getLogger(__name__)
# Known paper format aspect ratios (height / width, portrait orientation).
# NOTE: the ISO A-series entries (A3, A4, A5) all sit at roughly 1.414-1.419,
# so a lookup by aspect ratio alone cannot reliably tell them apart.
PAPER_FORMATS = {
    "A4": 297.0 / 210.0,       # 1.4143
    "A5": 210.0 / 148.0,       # 1.4189
    "Letter": 11.0 / 8.5,      # 1.2941
    "Legal": 14.0 / 8.5,       # 1.6471
    "A3": 420.0 / 297.0,       # 1.4141
}
def detect_page_splits(
    img_bgr: np.ndarray,
) -> list:
    """Detect if the image is a two-page spread and return split rectangles.

    Uses **brightness** (not ink density) to find the spine area:
    the scanner bed produces a characteristic gray strip where pages meet,
    which is darker than the white paper on either side.

    Steps:
    1. Skip images that are not clearly landscape (w < 1.15 * h).
    2. Smooth the per-column mean brightness with a wide boxcar.
    3. In the central 30-70% of the width, collect runs of columns darker
       than 88% of the brightest smoothed column.
    4. Score each run (centered, dark and narrow is best) and keep the best.
    5. Require bright paper on both sides of the run, then split at its center.

    Args:
        img_bgr: Input BGR image.

    Returns:
        A list of page dicts ``{x, y, width, height, page_index}`` (at most
        two, since only one split point is produced), or an empty list if
        no split is detected.
    """
    h, w = img_bgr.shape[:2]

    # Only check landscape-ish images (width >= height * 1.15)
    if w < h * 1.15:
        return []

    gray = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY)

    # Column-mean brightness (0-255) — the spine is darker (gray scanner bed)
    col_brightness = np.mean(gray, axis=0).astype(np.float64)

    # Heavy smoothing to ignore individual text lines
    # (odd kernel width, at least 11 px, about 2% of the image width)
    kern = max(11, w // 50)
    if kern % 2 == 0:
        kern += 1
    brightness_smooth = np.convolve(col_brightness, np.ones(kern) / kern, mode="same")

    # Page paper is bright (typically > 200), spine/scanner bed is darker
    page_brightness = float(np.max(brightness_smooth))
    if page_brightness < 100:
        return []  # Very dark image, skip

    # Spine threshold: significantly darker than the page
    spine_thresh = page_brightness * 0.88

    # Search in center region (30-70% of width)
    center_lo = int(w * 0.30)
    center_hi = int(w * 0.70)

    # Find the darkest valley in the center region
    center_brightness = brightness_smooth[center_lo:center_hi]
    darkest_val = float(np.min(center_brightness))

    if darkest_val >= spine_thresh:
        logger.debug("No spine detected: min brightness %.0f >= threshold %.0f",
                     darkest_val, spine_thresh)
        return []

    # Find ALL contiguous dark runs in the center region.
    # Runs are (start, end) with end exclusive, indexed relative to center_lo.
    is_dark = center_brightness < spine_thresh
    dark_runs: list = []
    run_start = -1
    for i in range(len(is_dark)):
        if is_dark[i]:
            if run_start < 0:
                run_start = i
        else:
            if run_start >= 0:
                dark_runs.append((run_start, i))
                run_start = -1
    # Close a run that extends to the end of the center region
    if run_start >= 0:
        dark_runs.append((run_start, len(is_dark)))

    # Filter out runs that are too narrow (< 1% of image width)
    min_spine_px = int(w * 0.01)
    dark_runs = [(s, e) for s, e in dark_runs if e - s >= min_spine_px]

    if not dark_runs:
        logger.debug("No dark runs wider than %dpx in center region", min_spine_px)
        return []

    # Score each dark run: prefer centered, dark, narrow valleys
    center_region_len = center_hi - center_lo
    image_center_in_region = (w * 0.5 - center_lo)
    best_score = -1.0
    best_start, best_end = dark_runs[0]

    for rs, re in dark_runs:
        run_width = re - rs
        run_center = (rs + re) / 2.0

        # Gaussian falloff with distance from the image center
        # (sigma = 15% of the center region width)
        sigma = center_region_len * 0.15
        dist = abs(run_center - image_center_in_region)
        center_factor = float(np.exp(-0.5 * (dist / sigma) ** 2))

        # How far below the spine threshold the run is, relative to the threshold
        run_brightness = float(np.mean(center_brightness[rs:re]))
        darkness_factor = max(0.0, (spine_thresh - run_brightness) / spine_thresh)

        # Full bonus up to 5% of image width, linearly down to 0 at 15%
        width_frac = run_width / w
        if width_frac <= 0.05:
            narrowness_bonus = 1.0
        elif width_frac <= 0.15:
            narrowness_bonus = 1.0 - (width_frac - 0.05) / 0.10
        else:
            narrowness_bonus = 0.0

        score = center_factor * darkness_factor * (0.3 + 0.7 * narrowness_bonus)

        logger.debug(
            "Dark run x=%d..%d (w=%d): center_f=%.3f dark_f=%.3f narrow_b=%.3f -> score=%.4f",
            center_lo + rs, center_lo + re, run_width,
            center_factor, darkness_factor, narrowness_bonus, score,
        )

        if score > best_score:
            best_score = score
            best_start, best_end = rs, re

    # Convert the winning run back to full-image coordinates
    spine_w = best_end - best_start
    spine_x = center_lo + best_start
    spine_center = spine_x + spine_w // 2

    logger.debug(
        "Best spine candidate: x=%d..%d (w=%d), score=%.4f",
        spine_x, spine_x + spine_w, spine_w, best_score,
    )

    # Verify: must have bright (paper) content on BOTH sides
    # (mean over a band of w // 10 columns next to each side of the run)
    left_brightness = float(np.mean(brightness_smooth[max(0, spine_x - w // 10):spine_x]))
    right_end = center_lo + best_end
    right_brightness = float(np.mean(brightness_smooth[right_end:min(w, right_end + w // 10)]))

    if left_brightness < spine_thresh or right_brightness < spine_thresh:
        logger.debug("No bright paper flanking spine: left=%.0f right=%.0f thresh=%.0f",
                     left_brightness, right_brightness, spine_thresh)
        return []

    logger.info(
        "Spine detected: x=%d..%d (w=%d), brightness=%.0f vs paper=%.0f, "
        "left_paper=%.0f, right_paper=%.0f",
        spine_x, right_end, spine_w, darkest_val, page_brightness,
        left_brightness, right_brightness,
    )

    # Split at the spine center
    split_points = [spine_center]

    # Build page rectangles (full image height, consecutive x ranges)
    pages: list = []
    prev_x = 0
    for i, sx in enumerate(split_points):
        pages.append({"x": prev_x, "y": 0, "width": sx - prev_x,
                      "height": h, "page_index": i})
        prev_x = sx
    pages.append({"x": prev_x, "y": 0, "width": w - prev_x,
                  "height": h, "page_index": len(split_points)})

    # Filter out tiny pages (< 15% of total width)
    pages = [p for p in pages if p["width"] >= w * 0.15]
    if len(pages) < 2:
        return []

    # Re-index
    for i, p in enumerate(pages):
        p["page_index"] = i

    logger.info(
        "Page split detected: %d pages, spine_w=%d, split_points=%s",
        len(pages), spine_w, split_points,
    )
    return pages
def detect_and_crop_page(
    img_bgr: np.ndarray,
    margin_frac: float = 0.01,
) -> Tuple[np.ndarray, Dict[str, Any]]:
    """Detect content boundary and crop scanner/book borders.

    Algorithm (4-edge detection):
    1. Adaptive threshold -> binary (text=255, bg=0)
    2. Left and right edges: ``_detect_left_edge_shadow`` /
       ``_detect_right_edge_shadow`` (spine shadow, then gutter continuity,
       then binary vertical projection as the last fallback)
    3. Top/bottom edges: binary horizontal projection
    4. Sanity checks, then crop with configurable margin

    The image is returned uncropped (``crop_applied`` stays False) when all
    four borders are below 2% or when the crop would keep less than 40% of
    the original area.

    Args:
        img_bgr: Input BGR image (should already be deskewed/dewarped)
        margin_frac: Extra margin around content (fraction of dimension, default 1%)

    Returns:
        Tuple of (cropped_image, result_dict). ``result_dict`` holds
        ``crop_applied``, ``crop_rect``, ``crop_rect_pct``, ``original_size``,
        ``cropped_size``, ``detected_format``, ``format_confidence``,
        ``aspect_ratio`` and ``border_fractions``.
    """
    h, w = img_bgr.shape[:2]
    total_area = h * w

    # Defaults describe the "no crop" outcome; overwritten below on success.
    result: Dict[str, Any] = {
        "crop_applied": False,
        "crop_rect": None,
        "crop_rect_pct": None,
        "original_size": {"width": w, "height": h},
        "cropped_size": {"width": w, "height": h},
        "detected_format": None,
        "format_confidence": 0.0,
        "aspect_ratio": round(max(h, w) / max(min(h, w), 1), 4),
        "border_fractions": {"top": 0.0, "bottom": 0.0, "left": 0.0, "right": 0.0},
    }

    gray = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY)

    # --- Binarise with adaptive threshold ---
    binary = cv2.adaptiveThreshold(
        gray, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
        cv2.THRESH_BINARY_INV, blockSize=51, C=15,
    )

    # --- Edge detection ---
    left_edge = _detect_left_edge_shadow(gray, binary, w, h)
    right_edge = _detect_right_edge_shadow(gray, binary, w, h)
    top_edge, bottom_edge = _detect_top_bottom_edges(binary, w, h)

    # Compute border fractions (share of each dimension outside the content)
    border_top = top_edge / h
    border_bottom = (h - bottom_edge) / h
    border_left = left_edge / w
    border_right = (w - right_edge) / w

    result["border_fractions"] = {
        "top": round(border_top, 4),
        "bottom": round(border_bottom, 4),
        "left": round(border_left, 4),
        "right": round(border_right, 4),
    }

    # Sanity: only crop if at least one edge has >= 2% border
    min_border = 0.02
    if all(f < min_border for f in [border_top, border_bottom, border_left, border_right]):
        logger.info("All borders < %.0f%% — no crop needed", min_border * 100)
        result["detected_format"], result["format_confidence"] = _detect_format(w, h)
        return img_bgr, result

    # Add margin, clamped to the image bounds
    margin_x = int(w * margin_frac)
    margin_y = int(h * margin_frac)

    crop_x = max(0, left_edge - margin_x)
    crop_y = max(0, top_edge - margin_y)
    crop_x2 = min(w, right_edge + margin_x)
    crop_y2 = min(h, bottom_edge + margin_y)

    crop_w = crop_x2 - crop_x
    crop_h = crop_y2 - crop_y

    # Sanity: cropped area must be >= 40% of original
    if crop_w * crop_h < 0.40 * total_area:
        logger.warning("Cropped area too small (%.0f%%) — skipping crop",
                       100.0 * crop_w * crop_h / total_area)
        result["detected_format"], result["format_confidence"] = _detect_format(w, h)
        return img_bgr, result

    # Copy so the result does not alias the input array
    cropped = img_bgr[crop_y:crop_y2, crop_x:crop_x2].copy()

    # Format is detected on the cropped size, not the original
    detected_format, format_confidence = _detect_format(crop_w, crop_h)

    result["crop_applied"] = True
    result["crop_rect"] = {"x": crop_x, "y": crop_y, "width": crop_w, "height": crop_h}
    result["crop_rect_pct"] = {
        "x": round(100.0 * crop_x / w, 2),
        "y": round(100.0 * crop_y / h, 2),
        "width": round(100.0 * crop_w / w, 2),
        "height": round(100.0 * crop_h / h, 2),
    }
    result["cropped_size"] = {"width": crop_w, "height": crop_h}
    result["detected_format"] = detected_format
    result["format_confidence"] = format_confidence
    result["aspect_ratio"] = round(max(crop_w, crop_h) / max(min(crop_w, crop_h), 1), 4)

    logger.info(
        "Page cropped: %dx%d -> %dx%d, format=%s (%.0f%%), "
        "borders: T=%.1f%% B=%.1f%% L=%.1f%% R=%.1f%%",
        w, h, crop_w, crop_h, detected_format, format_confidence * 100,
        border_top * 100, border_bottom * 100,
        border_left * 100, border_right * 100,
    )

    return cropped, result
# ---------------------------------------------------------------------------
# Format detection (kept as optional metadata)
# ---------------------------------------------------------------------------
def _detect_format(width: int, height: int) -> Tuple[str, float]:
"""Detect paper format from dimensions by comparing aspect ratios."""
if width <= 0 or height <= 0:
return "unknown", 0.0
aspect = max(width, height) / min(width, height)
best_format = "unknown"
best_diff = float("inf")
for fmt, expected_ratio in PAPER_FORMATS.items():
diff = abs(aspect - expected_ratio)
if diff < best_diff:
best_diff = diff
best_format = fmt
confidence = max(0.0, 1.0 - best_diff * 5.0)
if confidence < 0.3:
return "unknown", 0.0
return best_format, round(confidence, 3)
+388
View File
@@ -0,0 +1,388 @@
"""
Page Crop - Edge Detection Helpers
Spine shadow detection, gutter continuity analysis, projection-based
edge detection, and narrow-run filtering for content cropping.
Extracted from page_crop.py to keep files under 500 LOC.
License: Apache 2.0
"""
import logging
from typing import Optional, Tuple
import cv2
import numpy as np
logger = logging.getLogger(__name__)
# Minimum ink density (fraction of pixels) to count a row/column as "content"
_INK_THRESHOLD = 0.003 # 0.3%
# Minimum run length (fraction of dimension) to keep — shorter runs are noise
_MIN_RUN_FRAC = 0.005 # 0.5%
def _detect_spine_shadow(
    gray: np.ndarray,
    search_region: np.ndarray,
    offset_x: int,
    w: int,
    side: str,
) -> Optional[int]:
    """Find the book spine center (darkest point) in a scanner shadow.

    The scanner produces a gray strip where the book spine presses against
    the glass. The darkest column in that strip is the spine center —
    that's where we crop.

    Distinguishes real spine shadows from text content by checking:
    1. Strong brightness range (> 40 levels)
    2. Darkest point is genuinely dark (<= 180 mean brightness)
    3. The dark area is a NARROW valley (at most 50% of the searched
       columns), not a text-content plateau
    4. Brightness rises by at least 30% of the range toward the page
       content side

    Args:
        gray: Full grayscale image (not referenced by this function).
        search_region: Column slice of the grayscale image to search in.
        offset_x: X offset of search_region relative to full image.
        w: Full image width (sets the smoothing kernel size).
        side: 'left' or 'right'. Used in log messages and to pick the
            direction of the rise check (any value other than 'left' is
            treated as 'right').

    Returns:
        X coordinate (in full image) of the spine center, or None.
    """
    region_w = search_region.shape[1]
    if region_w < 10:
        return None

    # Column-mean brightness in the search region
    col_means = np.mean(search_region, axis=0).astype(np.float64)

    # Smooth with boxcar kernel (width = 1% of image width, min 5, forced odd)
    kernel_size = max(5, w // 100)
    if kernel_size % 2 == 0:
        kernel_size += 1
    kernel = np.ones(kernel_size) / kernel_size
    smoothed_raw = np.convolve(col_means, kernel, mode="same")

    # Trim convolution edge artifacts (edges are zero-padded -> artificially low)
    margin = kernel_size // 2
    if region_w <= 2 * margin + 10:
        return None
    smoothed = smoothed_raw[margin:region_w - margin]
    trim_offset = margin  # offset of smoothed[0] relative to search_region

    val_min = float(np.min(smoothed))
    val_max = float(np.max(smoothed))
    shadow_range = val_max - val_min

    # --- Check 1: Strong brightness gradient ---
    if shadow_range <= 40:
        logger.debug(
            "%s edge: no spine (range=%.0f <= 40)", side.capitalize(), shadow_range,
        )
        return None

    # --- Check 2: Darkest point must be genuinely dark ---
    if val_min > 180:
        logger.debug(
            "%s edge: no spine (darkest=%.0f > 180, likely text)", side.capitalize(), val_min,
        )
        return None

    spine_idx = int(np.argmin(smoothed))  # index in trimmed array
    spine_local = spine_idx + trim_offset  # index in search_region
    trimmed_len = len(smoothed)

    # --- Check 3: Valley width (spine is narrow, text plateau is wide) ---
    # Count columns within the darkest 20% of the brightness range.
    valley_thresh = val_min + shadow_range * 0.20
    valley_mask = smoothed < valley_thresh
    valley_width = int(np.sum(valley_mask))
    max_valley_frac = 0.50
    if valley_width > trimmed_len * max_valley_frac:
        logger.debug(
            "%s edge: no spine (valley too wide: %d/%d = %.0f%%)",
            side.capitalize(), valley_width, trimmed_len,
            100.0 * valley_width / trimmed_len,
        )
        return None

    # --- Check 4: Brightness must rise toward page content ---
    # Compare the mean of a window starting 5 columns away from the spine,
    # on the page side, against the valley minimum. The check is skipped
    # when that window is empty.
    rise_check_w = max(5, trimmed_len // 5)
    if side == "left":
        # Page content lies to the right of a left-side spine
        right_start = min(spine_idx + 5, trimmed_len - 1)
        right_end = min(right_start + rise_check_w, trimmed_len)
        if right_end > right_start:
            rise_brightness = float(np.mean(smoothed[right_start:right_end]))
            rise = rise_brightness - val_min
            if rise < shadow_range * 0.3:
                logger.debug(
                    "%s edge: no spine (insufficient rise: %.0f, need %.0f)",
                    side.capitalize(), rise, shadow_range * 0.3,
                )
                return None
    else:  # right
        # Page content lies to the left of a right-side spine
        left_end = max(spine_idx - 5, 0)
        left_start = max(left_end - rise_check_w, 0)
        if left_end > left_start:
            rise_brightness = float(np.mean(smoothed[left_start:left_end]))
            rise = rise_brightness - val_min
            if rise < shadow_range * 0.3:
                logger.debug(
                    "%s edge: no spine (insufficient rise: %.0f, need %.0f)",
                    side.capitalize(), rise, shadow_range * 0.3,
                )
                return None

    spine_x = offset_x + spine_local

    logger.info(
        "%s edge: spine center at x=%d (brightness=%.0f, range=%.0f, valley=%dpx)",
        side.capitalize(), spine_x, val_min, shadow_range, valley_width,
    )
    return spine_x
def _detect_gutter_continuity(
    gray: np.ndarray,
    search_region: np.ndarray,
    offset_x: int,
    w: int,
    side: str,
) -> Optional[int]:
    """Detect gutter shadow via vertical continuity analysis.

    Camera book scans produce a subtle brightness gradient at the gutter
    that is too faint for scanner-shadow detection (range < 40). However,
    the gutter shadow has a unique property: it runs **continuously from
    top to bottom** without interruption.

    Algorithm:
    1. Divide the search region into N horizontal strips (~60px each,
       at least 10 strips)
    2. For each column, compute what fraction of strips are darker than
       ``page_median - 5`` (median taken from the center 50% of the full
       image; pages with a median below 180 are rejected)
    3. Smooth the dark-fraction profile; its peak must reach 0.70
    4. Walk from the peak toward the page content until the smoothed
       fraction drops below 0.50 — that is the inner gutter boundary
    5. Validate: gutter band must be between ``max(3px, 0.5%)`` and 10% of
       the image width, and at least 3 brightness levels darker than
       the page median

    Args:
        gray: Full grayscale image (used for the page median).
        search_region: Column slice of ``gray`` adjacent to the edge.
        offset_x: X offset of ``search_region`` within the full image.
        w: Full image width.
        side: 'right' scans toward the left from the peak; any other value
            is handled as 'left' and scans toward the right.

    Returns:
        X coordinate (full image) of the inner gutter boundary, or None.
    """
    region_h, region_w = search_region.shape[:2]
    if region_w < 20 or region_h < 100:
        return None

    # --- 1. Divide into horizontal strips ---
    strip_target_h = 60
    n_strips = max(10, region_h // strip_target_h)
    strip_h = region_h // n_strips

    # strip_means[s, x] = mean brightness of column x within strip s.
    # Rows beyond n_strips * strip_h at the bottom are not sampled.
    strip_means = np.zeros((n_strips, region_w), dtype=np.float64)
    for s in range(n_strips):
        y0 = s * strip_h
        y1 = min((s + 1) * strip_h, region_h)
        strip_means[s] = np.mean(search_region[y0:y1, :], axis=0)

    # --- 2. Page median from center 50% of full image ---
    center_lo = w // 4
    center_hi = 3 * w // 4
    page_median = float(np.median(gray[:, center_lo:center_hi]))

    dark_thresh = page_median - 5.0

    # Page must be bright overall for the subtle-shadow test to apply
    if page_median < 180:
        return None

    # --- 3. Per-column dark fraction ---
    dark_count = np.sum(strip_means < dark_thresh, axis=0).astype(np.float64)
    dark_frac = dark_count / n_strips

    # --- 4. Smooth and find transition ---
    smooth_w = max(5, w // 100)
    if smooth_w % 2 == 0:
        smooth_w += 1
    kernel = np.ones(smooth_w) / smooth_w
    frac_smooth = np.convolve(dark_frac, kernel, mode="same")

    # Ignore the outer `margin` columns, where the "same" convolution is
    # computed from a zero-padded window
    margin = smooth_w // 2
    if region_w <= 2 * margin + 10:
        return None

    transition_thresh = 0.50
    peak_frac = float(np.max(frac_smooth[margin:region_w - margin]))

    if peak_frac < 0.70:
        logger.debug(
            "%s gutter: peak dark fraction %.2f < 0.70", side.capitalize(), peak_frac,
        )
        return None

    peak_x = int(np.argmax(frac_smooth[margin:region_w - margin])) + margin
    gutter_inner = None

    if side == "right":
        # Walk left from the peak (toward page content); the boundary is
        # the last column still at or above the transition threshold
        for x in range(peak_x, margin, -1):
            if frac_smooth[x] < transition_thresh:
                gutter_inner = x + 1
                break
    else:
        # Walk right from the peak (toward page content)
        for x in range(peak_x, region_w - margin):
            if frac_smooth[x] < transition_thresh:
                gutter_inner = x - 1
                break

    # No transition found inside the scanned range
    if gutter_inner is None:
        return None

    # --- 5. Validate gutter width ---
    # Width is measured from the boundary to the outer end of the region
    if side == "right":
        gutter_width = region_w - gutter_inner
    else:
        gutter_width = gutter_inner

    min_gutter = max(3, int(w * 0.005))
    max_gutter = int(w * 0.10)

    if gutter_width < min_gutter:
        logger.debug(
            "%s gutter: too narrow (%dpx < %dpx)", side.capitalize(),
            gutter_width, min_gutter,
        )
        return None

    if gutter_width > max_gutter:
        logger.debug(
            "%s gutter: too wide (%dpx > %dpx)", side.capitalize(),
            gutter_width, max_gutter,
        )
        return None

    # Mean brightness of the gutter band itself
    if side == "right":
        gutter_brightness = float(np.mean(strip_means[:, gutter_inner:]))
    else:
        gutter_brightness = float(np.mean(strip_means[:, :gutter_inner]))

    brightness_drop = page_median - gutter_brightness
    if brightness_drop < 3:
        logger.debug(
            "%s gutter: insufficient brightness drop (%.1f levels)",
            side.capitalize(), brightness_drop,
        )
        return None

    gutter_x = offset_x + gutter_inner

    logger.info(
        "%s gutter (continuity): x=%d, width=%dpx (%.1f%%), "
        "brightness=%.0f vs page=%.0f (drop=%.0f), frac@edge=%.2f",
        side.capitalize(), gutter_x, gutter_width,
        100.0 * gutter_width / w, gutter_brightness, page_median,
        brightness_drop, float(frac_smooth[gutter_inner]),
    )
    return gutter_x
def _detect_left_edge_shadow(
    gray: np.ndarray,
    binary: np.ndarray,
    w: int,
    h: int,
) -> int:
    """Locate the left content edge, preferring shadow cues over ink.

    The leftmost quarter of the grayscale image is examined. Detection
    cascades through three strategies and stops at the first hit:

    1. ``_detect_spine_shadow`` (strong scanner spine shadow)
    2. ``_detect_gutter_continuity`` (faint, vertically continuous gutter)
    3. First ink column of the binary vertical projection

    ``h`` is accepted for signature symmetry but is not used.
    """
    strip = gray[:, :max(1, w // 4)]

    edge_x = _detect_spine_shadow(gray, strip, 0, w, "left")
    if edge_x is None:
        edge_x = _detect_gutter_continuity(gray, strip, 0, w, "left")
    if edge_x is None:
        edge_x = _detect_edge_projection(binary, axis=0, from_start=True, dim=w)
    return edge_x
def _detect_right_edge_shadow(
    gray: np.ndarray,
    binary: np.ndarray,
    w: int,
    h: int,
) -> int:
    """Locate the right content edge, preferring shadow cues over ink.

    The rightmost quarter of the grayscale image is examined. Detection
    cascades through three strategies and stops at the first hit:

    1. ``_detect_spine_shadow`` (strong scanner spine shadow)
    2. ``_detect_gutter_continuity`` (faint, vertically continuous gutter)
    3. Last ink column of the binary vertical projection

    ``h`` is accepted for signature symmetry but is not used.
    """
    strip_x0 = w - max(1, w // 4)
    strip = gray[:, strip_x0:]

    edge_x = _detect_spine_shadow(gray, strip, strip_x0, w, "right")
    if edge_x is None:
        edge_x = _detect_gutter_continuity(gray, strip, strip_x0, w, "right")
    if edge_x is None:
        edge_x = _detect_edge_projection(binary, axis=0, from_start=False, dim=w)
    return edge_x
def _detect_top_bottom_edges(binary: np.ndarray, w: int, h: int) -> Tuple[int, int]:
    """Return ``(top, bottom)`` content rows from the horizontal ink projection.

    ``w`` is part of the signature but is not used.
    """
    def scan_rows(from_start: bool) -> int:
        # Row projection; direction decides first vs. last ink row.
        return _detect_edge_projection(binary, axis=1, from_start=from_start, dim=h)

    return scan_rows(True), scan_rows(False)
def _detect_edge_projection(
    binary: np.ndarray,
    axis: int,
    from_start: bool,
    dim: int,
) -> int:
    """Locate the outermost row/column that still contains ink.

    ``axis=0`` collapses rows (result is an x position); ``axis=1``
    collapses columns (result is a y position). A position counts as ink
    when its mean density reaches ``_INK_THRESHOLD``; ink runs shorter than
    ``_MIN_RUN_FRAC * dim`` are discarded as noise first.

    Returns:
        First ink index if ``from_start`` else last ink index. With no ink
        at all: ``0`` if ``from_start`` else ``dim``.
    """
    density = np.mean(binary, axis=axis) / 255.0
    noise_limit = max(1, int(dim * _MIN_RUN_FRAC))
    content = _filter_narrow_runs(density >= _INK_THRESHOLD, noise_limit)

    indices = np.where(content)[0]
    if len(indices) == 0:
        if from_start:
            return 0
        return dim

    pick = 0 if from_start else -1
    return int(indices[pick])
def _filter_narrow_runs(mask: np.ndarray, min_run: int) -> np.ndarray:
"""Remove True-runs shorter than min_run pixels."""
if min_run <= 1:
return mask
result = mask.copy()
n = len(result)
i = 0
while i < n:
if result[i]:
start = i
while i < n and result[i]:
i += 1
if i - start < min_run:
result[start:i] = False
else:
i += 1
return result
@@ -0,0 +1,160 @@
"""
TrOCR Batch Processing & Streaming
Batch OCR and SSE streaming for multiple images.
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
"""
import asyncio
import logging
import time
from typing import Optional, List, Dict, Any
from .trocr_models import OCRResult, BatchOCRResult
from .trocr_ocr import run_trocr_ocr_enhanced
logger = logging.getLogger(__name__)
async def run_trocr_batch(
    images: List[bytes],
    handwritten: bool = True,
    split_lines: bool = True,
    use_cache: bool = True,
    progress_callback: Optional[callable] = None
) -> BatchOCRResult:
    """
    Process multiple images in batch, one after another.

    Exactly one entry is appended to ``results`` per input image, in input
    order. An image whose OCR raises is represented by a placeholder
    ``OCRResult`` (``model="error"``, ``confidence=0.0``) and counted in
    ``error_count``; the batch continues with the next image.

    Args:
        images: List of image data bytes
        handwritten: Use handwritten model
        split_lines: Whether to split images into lines
        use_cache: Whether to use caching
        progress_callback: Optional callback(current, total), invoked once
            per image after it has been processed (successfully or not).
            Exceptions raised by the callback are logged and ignored.

    Returns:
        BatchOCRResult with all results
    """
    start_time = time.time()
    total = len(images)
    results = []
    cached_count = 0
    error_count = 0

    for idx, image_data in enumerate(images):
        try:
            result = await run_trocr_ocr_enhanced(
                image_data,
                handwritten=handwritten,
                split_lines=split_lines,
                use_cache=use_cache
            )
            from_cache = bool(result.from_cache)
        except Exception as e:
            logger.error("Batch OCR error for image %s: %s", idx, e)
            error_count += 1
            from_cache = False
            result = OCRResult(
                text=f"Error: {str(e)}",
                confidence=0.0,
                processing_time_ms=0,
                model="error",
                has_lora_adapter=False
            )

        # Append outside the try block so each image yields exactly one entry.
        results.append(result)
        if from_cache:
            cached_count += 1

        # Report progress. Kept separate from the OCR error handling: a
        # failing callback must not turn a good result into an error entry.
        if progress_callback:
            try:
                progress_callback(idx + 1, total)
            except Exception as e:
                logger.warning("Batch OCR progress callback failed for image %s: %s", idx, e)

    total_time_ms = int((time.time() - start_time) * 1000)

    return BatchOCRResult(
        results=results,
        total_time_ms=total_time_ms,
        processed_count=total,
        cached_count=cached_count,
        error_count=error_count
    )
# Generator for SSE streaming during batch processing
async def run_trocr_batch_stream(
    images: List[bytes],
    handwritten: bool = True,
    split_lines: bool = True,
    use_cache: bool = True
):
    """
    Run OCR image by image, yielding one event dict per image for SSE.

    Yields:
        ``{"type": "progress", ...}`` after each successful image (with
        counters, timing estimate and a compact result),
        ``{"type": "error", ...}`` when an image fails, and a final
        ``{"type": "complete", ...}`` event with the total time.
    """
    started = time.time()
    total = len(images)

    for position, image_data in enumerate(images, start=1):
        try:
            ocr = await run_trocr_ocr_enhanced(
                image_data,
                handwritten=handwritten,
                split_lines=split_lines,
                use_cache=use_cache
            )

            elapsed_ms = int((time.time() - started) * 1000)
            # Linear extrapolation from the average time per image so far.
            per_image_ms = elapsed_ms / position
            remaining_ms = int(per_image_ms * (total - position))

            result_summary = {
                "text": ocr.text,
                "confidence": ocr.confidence,
                "processing_time_ms": ocr.processing_time_ms,
                "from_cache": ocr.from_cache
            }
            yield {
                "type": "progress",
                "current": position,
                "total": total,
                "progress_percent": (position / total) * 100,
                "elapsed_ms": elapsed_ms,
                "estimated_remaining_ms": remaining_ms,
                "result": result_summary
            }

        except Exception as e:
            logger.error(f"Stream OCR error for image {position - 1}: {e}")
            yield {
                "type": "error",
                "current": position,
                "total": total,
                "error": str(e)
            }

    yield {
        "type": "complete",
        "total_time_ms": int((time.time() - started) * 1000),
        "processed_count": total
    }
# Test function
async def test_trocr_ocr(image_path: str, handwritten: bool = False):
    """Manual smoke test: OCR one local image file and print the outcome.

    Returns the ``(text, confidence)`` pair produced by ``run_trocr_ocr``.
    """
    from .trocr_ocr import run_trocr_ocr

    with open(image_path, "rb") as image_file:
        image_data = image_file.read()

    text, confidence = await run_trocr_ocr(image_data, handwritten=handwritten)

    mode = 'Handwritten' if handwritten else 'Printed'
    print("\n=== TrOCR Test ===")
    print(f"Mode: {mode}")
    print(f"Confidence: {confidence:.2f}")
    print(f"Text:\n{text}")

    return text, confidence
@@ -0,0 +1,278 @@
"""
TrOCR Models & Cache
Dataclasses, LRU cache, and model loading for TrOCR service.
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
"""
import io
import os
import hashlib
import logging
import time
from typing import Tuple, Optional, List, Dict, Any
from dataclasses import dataclass, field
from collections import OrderedDict
from datetime import datetime, timedelta
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Backend routing: auto | pytorch | onnx
# ---------------------------------------------------------------------------
_trocr_backend = os.environ.get("TROCR_BACKEND", "auto") # auto | pytorch | onnx
# Lazy loading for heavy dependencies
# Cache keyed by model_name to support base and large variants simultaneously
_trocr_models: dict = {} # {model_name: (processor, model)}
_trocr_processor = None # backwards-compat alias -> base-printed
_trocr_model = None # backwards-compat alias -> base-printed
_trocr_available = None
_model_loaded_at = None
# Simple in-memory cache with LRU eviction
_ocr_cache: OrderedDict[str, Dict[str, Any]] = OrderedDict()
_cache_max_size = 100
_cache_ttl_seconds = 3600 # 1 hour
@dataclass
class OCRResult:
    """Enhanced OCR result with detailed information.

    Field order is part of the positional constructor; only ``text``,
    ``confidence``, ``processing_time_ms`` and ``model`` are required.
    """
    # Recognised text; the batch error path stores "Error: <message>" here
    text: str
    # Confidence score; 0.0 on the batch error path (presumably 0..1 — TODO confirm)
    confidence: float
    # Processing time in milliseconds (0 on the batch error path)
    processing_time_ms: int
    # Model identifier; "error" marks a placeholder for a failed image
    model: str
    has_lora_adapter: bool = False
    # Per-instance lists via default_factory (avoids a shared mutable default)
    char_confidences: List[float] = field(default_factory=list)
    word_boxes: List[Dict[str, Any]] = field(default_factory=list)
    # Batch processing counts results with this flag set as cache hits
    from_cache: bool = False
    # NOTE(review): presumably the truncated SHA256 from _compute_image_hash — confirm
    image_hash: str = ""
@dataclass
class BatchOCRResult:
    """Result for batch processing."""
    # One OCRResult per processed image (failed images get a placeholder)
    results: List[OCRResult]
    # Wall-clock duration of the whole batch in milliseconds
    total_time_ms: int
    # Number of input images, including those that failed
    processed_count: int
    # Number of results that were served from the cache
    cached_count: int
    # Number of images whose OCR raised an exception
    error_count: int
def _compute_image_hash(image_data: bytes) -> str:
"""Compute SHA256 hash of image data for caching."""
return hashlib.sha256(image_data).hexdigest()[:16]
def _cache_get(image_hash: str) -> Optional[Dict[str, Any]]:
    """Look up a cached OCR result.

    A hit that is younger than ``_cache_ttl_seconds`` is marked as most
    recently used and returned. An expired entry is removed. Returns None
    for a miss or an expired entry.
    """
    entry = _ocr_cache.get(image_hash)
    if entry is None:
        return None

    age = datetime.now() - entry["cached_at"]
    if age >= timedelta(seconds=_cache_ttl_seconds):
        # Expired: evict so it no longer occupies a cache slot.
        del _ocr_cache[image_hash]
        return None

    # Fresh hit: move to the most-recently-used end.
    _ocr_cache.move_to_end(image_hash)
    return entry["result"]
def _cache_set(image_hash: str, result: Dict[str, Any]) -> None:
    """Store an OCR result in the LRU cache.

    Args:
        image_hash: Cache key (see ``_compute_image_hash``).
        result: Result payload to store; a ``cached_at`` timestamp is
            recorded alongside it for TTL checks.

    Re-storing an existing key refreshes the entry and makes it the most
    recently used one without evicting any other entry. For a new key,
    the least recently used entries are evicted until there is room.
    """
    # Remove a stale entry for this key first. Otherwise it would count
    # toward capacity (evicting an unrelated entry) and, because assigning
    # to an existing OrderedDict key keeps its position, stay at its old
    # LRU position.
    _ocr_cache.pop(image_hash, None)

    # Evict oldest if at capacity. The emptiness guard keeps popitem() from
    # raising KeyError when the configured max size is <= 0.
    while _ocr_cache and len(_ocr_cache) >= _cache_max_size:
        _ocr_cache.popitem(last=False)

    _ocr_cache[image_hash] = {
        "result": result,
        "cached_at": datetime.now()
    }
def get_cache_stats() -> Dict[str, Any]:
    """Report the current size and configuration of the OCR cache.

    ``hit_rate`` is a fixed placeholder (0); no hit/miss counters are kept.
    """
    stats: Dict[str, Any] = {}
    stats["size"] = len(_ocr_cache)
    stats["max_size"] = _cache_max_size
    stats["ttl_seconds"] = _cache_ttl_seconds
    stats["hit_rate"] = 0  # Could track this with additional counters
    return stats
def _check_trocr_available() -> bool:
    """Tell whether ``torch`` and ``transformers`` can be imported.

    The answer is computed on the first call and memoised in the
    module-level ``_trocr_available`` flag.
    """
    global _trocr_available

    if _trocr_available is None:
        try:
            import torch
            from transformers import TrOCRProcessor, VisionEncoderDecoderModel
        except ImportError as e:
            logger.warning(f"TrOCR dependencies not available: {e}")
            _trocr_available = False
        else:
            _trocr_available = True

    return _trocr_available
def get_trocr_model(handwritten: bool = False, size: str = "base"):
    """
    Return the TrOCR ``(processor, model)`` pair, loading it on first use.

    Loaded pairs are cached per model name in ``_trocr_models``, so several
    variants can be held at once. A failed load is logged and not cached.

    Args:
        handwritten: Use handwritten model instead of printed model
        size: Model size -- "base" or "large". "large" only takes effect
              together with ``handwritten=True``; printed text always uses
              the base model.

    Returns tuple of (processor, model) or (None, None) if unavailable.
    """
    global _trocr_processor, _trocr_model

    if not _check_trocr_available():
        return None, None

    # Select model name
    if not handwritten:
        model_name = "microsoft/trocr-base-printed"
    elif size == "large":
        model_name = "microsoft/trocr-large-handwritten"
    else:
        model_name = "microsoft/trocr-base-handwritten"

    if model_name in _trocr_models:
        return _trocr_models[model_name]

    try:
        import torch
        from transformers import TrOCRProcessor, VisionEncoderDecoderModel

        logger.info(f"Loading TrOCR model: {model_name}")
        processor = TrOCRProcessor.from_pretrained(model_name)
        model = VisionEncoderDecoderModel.from_pretrained(model_name)

        # Prefer CUDA, then Apple MPS, else CPU
        if torch.cuda.is_available():
            device = "cuda"
        elif torch.backends.mps.is_available():
            device = "mps"
        else:
            device = "cpu"
        model.to(device)
        logger.info(f"TrOCR model loaded on device: {device}")

        pair = (processor, model)
        _trocr_models[model_name] = pair

        # Keep backwards-compat globals pointing at base-printed
        if model_name == "microsoft/trocr-base-printed":
            _trocr_processor, _trocr_model = processor, model

        return processor, model

    except Exception as e:
        logger.error(f"Failed to load TrOCR model {model_name}: {e}")
        return None, None
def preload_trocr_model(handwritten: bool = True) -> bool:
    """
    Load the TrOCR model ahead of time so the first request is fast.

    On success the load time is recorded in ``_model_loaded_at``.

    Call this from your FastAPI startup event:

        @app.on_event("startup")
        async def startup():
            preload_trocr_model()

    Returns:
        True if processor and model were both loaded, False otherwise.
    """
    global _model_loaded_at

    logger.info("Preloading TrOCR model...")
    processor, model = get_trocr_model(handwritten=handwritten)

    if processor is None or model is None:
        logger.warning("TrOCR model preloading failed")
        return False

    _model_loaded_at = datetime.now()
    logger.info("TrOCR model preloaded successfully")
    return True
def get_model_status() -> Dict[str, Any]:
    """Describe the state of the handwritten TrOCR model.

    Note that this goes through ``get_trocr_model(handwritten=True)``, so
    it loads the model if it has not been loaded yet.

    Returns:
        Dict with ``status``, ``is_loaded``, ``model_name`` and
        ``loaded_at``; ``device`` is added when the model is loaded.
    """
    processor, model = get_trocr_model(handwritten=True)
    is_loaded = not (processor is None or model is None)

    loaded_at = _model_loaded_at.isoformat() if _model_loaded_at else None
    status = {
        "status": "available" if is_loaded else "not_installed",
        "is_loaded": is_loaded,
        "model_name": "trocr-base-handwritten" if is_loaded else None,
        "loaded_at": loaded_at,
    }

    if not is_loaded:
        return status

    import torch
    # The device of the first parameter stands in for the whole model.
    first_param = next(model.parameters())
    status["device"] = str(first_param.device)
    return status
def get_active_backend() -> str:
    """
    Return the configured TrOCR backend name.

    This is the module-level ``_trocr_backend`` value, read from the
    ``TROCR_BACKEND`` environment variable at import time (default "auto").
    Expected values: "auto", "pytorch", "onnx".
    """
    configured_backend = _trocr_backend
    return configured_backend
def _split_into_lines(image) -> list:
    """
    Cut an image into horizontal text-line crops via a row projection.

    A row belongs to a text line when more than 2% of its pixels are darker
    than 200. Each detected line is padded by 5 rows above and below, and
    crops of 10 rows or fewer are dropped. A line still open at the bottom
    of the image is cropped through to the image bottom.

    This is a basic implementation - for production use, consider using
    a dedicated line detection model.

    Returns:
        List of cropped line images (top to bottom), or an empty list if
        splitting fails.
    """
    import numpy as np
    from PIL import Image

    try:
        # Grayscale pixel matrix
        pixels = np.array(image.convert('L'))
        n_rows, n_cols = pixels.shape[0], pixels.shape[1]

        # Count dark pixels per row (simple fixed threshold)
        dark = pixels < 200
        row_ink = np.sum(dark, axis=1)

        # A row is "text" when more than 2% of its width is dark
        min_ink = n_cols * 0.02

        lines = []
        open_at = None  # first row of the line currently being tracked
        for row, ink in enumerate(row_ink):
            has_text = ink > min_ink
            if open_at is None:
                if has_text:
                    open_at = row
                continue
            if has_text:
                continue

            # Line just ended: pad and keep it if tall enough
            top = max(0, open_at - 5)
            bottom = min(n_rows, row + 5)
            if bottom - top > 10:
                lines.append(image.crop((0, top, image.width, bottom)))
            open_at = None

        # Line running into the bottom edge of the image
        if open_at is not None:
            top = max(0, open_at - 5)
            lines.append(image.crop((0, top, image.width, image.height)))

        logger.info(f"Split image into {len(lines)} lines")
        return lines

    except Exception as e:
        logger.warning(f"Line splitting failed: {e}")
        return []
@@ -0,0 +1,309 @@
"""
TrOCR OCR Execution
Core OCR inference routines (PyTorch, ONNX routing, enhanced mode).
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
"""
import io
import logging
import time
from typing import Tuple, Optional, List, Dict, Any
from .trocr_models import (
OCRResult,
_trocr_backend,
_compute_image_hash,
_cache_get,
_cache_set,
get_trocr_model,
_split_into_lines,
)
logger = logging.getLogger(__name__)
def _try_onnx_ocr(
    image_data: bytes,
    handwritten: bool = False,
    split_lines: bool = True,
) -> Optional[Any]:
    """
    Resolve the ONNX inference entry point without running it.

    Despite the name, no inference happens here. The function only checks
    whether the ONNX backend can be used and, if so, hands back the
    ``run_trocr_onnx`` coroutine function for the caller to call and await.

    Args:
        image_data: Not used; accepted so existing call sites keep working.
        handwritten: Forwarded to is_onnx_available() to check the matching
            model variant.
        split_lines: Not used; accepted so existing call sites keep working.

    Returns:
        The ``run_trocr_onnx`` function when ONNX is available, or None when
        the ONNX service module cannot be imported or reports the requested
        variant as unavailable.
    """
    try:
        from .trocr_onnx_service import is_onnx_available, run_trocr_onnx

        if not is_onnx_available(handwritten=handwritten):
            return None
        # Return the function itself, not a result: run_trocr_ocr() invokes
        # it with the image data and awaits the outcome.
        return run_trocr_onnx
    except ImportError:
        return None
async def _run_pytorch_ocr(
    image_data: bytes,
    handwritten: bool = False,
    split_lines: bool = True,
    size: str = "base",
) -> Tuple[Optional[str], float]:
    """
    Run TrOCR inference through the PyTorch model.

    NOTE(review): the model calls below are synchronous, so this coroutine
    occupies the event loop for the whole inference - confirm that is
    acceptable for the callers.

    Args:
        image_data: Raw image bytes, decoded with PIL.
        handwritten: Select the handwritten model variant.
        split_lines: Split the image into text lines first; the whole image
            is used when splitting is disabled or yields no lines.
        size: Model size forwarded to get_trocr_model().

    Returns:
        Tuple of (text, confidence). Text is the recognized lines joined by
        newlines; confidence is the mean of a per-line heuristic (0.85 for
        lines longer than 3 characters, 0.5 otherwise), or 0.0 when nothing
        was recognized. (None, 0.0) when the model is unavailable or
        inference raised.
    """
    processor, model = get_trocr_model(handwritten=handwritten, size=size)

    if processor is None or model is None:
        logger.error("TrOCR PyTorch model not available")
        return None, 0.0

    try:
        import torch
        from PIL import Image

        # Load image
        image = Image.open(io.BytesIO(image_data)).convert("RGB")

        # Fall back to the full image when splitting is off or finds nothing.
        lines = (_split_into_lines(image) if split_lines else []) or [image]

        # The model does not move between lines, so look up its device once.
        device = next(model.parameters()).device

        all_text = []
        confidences = []

        for line_image in lines:
            pixel_values = processor(images=line_image, return_tensors="pt").pixel_values
            pixel_values = pixel_values.to(device)

            with torch.no_grad():
                generated_ids = model.generate(pixel_values, max_length=128)

            generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

            if generated_text.strip():
                all_text.append(generated_text.strip())
                # Heuristic only: the model's own scores are not consulted.
                confidences.append(0.85 if len(generated_text) > 3 else 0.5)

        text = "\n".join(all_text)
        confidence = sum(confidences) / len(confidences) if confidences else 0.0

        logger.info(f"TrOCR (PyTorch) extracted {len(text)} characters from {len(lines)} lines")
        return text, confidence

    except Exception as e:
        # logger.exception records the message together with the traceback.
        logger.exception(f"TrOCR PyTorch failed: {e}")
        return None, 0.0
async def run_trocr_ocr(
    image_data: bytes,
    handwritten: bool = False,
    split_lines: bool = True,
    size: str = "base",
) -> Tuple[Optional[str], float]:
    """
    Recognize text in an image with TrOCR.

    The backend is chosen by the module-level backend setting (TROCR_BACKEND,
    default "auto"):

    - "onnx"    -- ONNX only; RuntimeError when it is unavailable
    - "pytorch" -- PyTorch only (original behaviour)
    - any other value, including "auto" -- try ONNX, fall back to PyTorch

    TrOCR is optimized for single-line text recognition, so full-page images
    are either split into lines first (``split_lines=True``) or processed as
    a whole with partial results.

    Args:
        image_data: Raw image bytes
        handwritten: Use handwritten model (slower but better for handwriting)
        split_lines: Whether to split image into lines first
        size: "base" or "large" (only for handwritten variant); only the
            PyTorch path receives it

    Returns:
        Tuple of (extracted_text, confidence)

    Raises:
        RuntimeError: backend is "onnx" but ONNX cannot be used.
    """
    backend = _trocr_backend

    # PyTorch-only mode never touches ONNX.
    if backend == "pytorch":
        return await _run_pytorch_ocr(
            image_data, handwritten=handwritten, split_lines=split_lines, size=size,
        )

    onnx_fn = _try_onnx_ocr(image_data, handwritten=handwritten, split_lines=split_lines)
    onnx_usable = onnx_fn is not None and callable(onnx_fn)

    # ONNX-only mode: no fallback, errors from the ONNX call propagate.
    if backend == "onnx":
        if not onnx_usable:
            raise RuntimeError(
                "ONNX backend requested (TROCR_BACKEND=onnx) but unavailable. "
                "Ensure onnxruntime + optimum are installed and ONNX model files exist."
            )
        return await onnx_fn(image_data, handwritten=handwritten, split_lines=split_lines)

    # Auto mode: ONNX first, PyTorch when ONNX is missing, fails or yields None text.
    if onnx_usable:
        try:
            onnx_result = await onnx_fn(image_data, handwritten=handwritten, split_lines=split_lines)
            if onnx_result[0] is None:
                logger.warning("ONNX returned None text, falling back to PyTorch")
            else:
                return onnx_result
        except Exception as e:
            logger.warning(f"ONNX inference failed ({e}), falling back to PyTorch")

    return await _run_pytorch_ocr(
        image_data, handwritten=handwritten, split_lines=split_lines, size=size,
    )
def _try_onnx_enhanced(
    handwritten: bool = True,
):
    """
    Look up the ONNX "enhanced" OCR coroutine function.

    Args:
        handwritten: Forwarded to is_onnx_available() to check the matching
            model variant.

    Returns:
        ``run_trocr_onnx_enhanced`` when the ONNX service can be imported and
        reports the variant as available, otherwise None.
    """
    try:
        from .trocr_onnx_service import is_onnx_available, run_trocr_onnx_enhanced

        onnx_ready = is_onnx_available(handwritten=handwritten)
        return run_trocr_onnx_enhanced if onnx_ready else None
    except ImportError:
        return None
async def run_trocr_ocr_enhanced(
    image_data: bytes,
    handwritten: bool = True,
    split_lines: bool = True,
    use_cache: bool = True
) -> OCRResult:
    """
    Enhanced TrOCR OCR with caching and detailed results.

    Routes between ONNX and PyTorch backends based on the TROCR_BACKEND
    setting (default: "auto"):

    - "onnx": delegate to the ONNX enhanced function; no fallback.
    - "auto": try ONNX first; continue with the PyTorch path when ONNX is
      unavailable, raises, or returns empty text.
    - any other value (including "pytorch"): PyTorch path only.

    On the PyTorch path the per-word and per-character confidences are
    simulated around the overall confidence and bounding boxes are
    placeholders ([0, 0, 0, 0]).

    Args:
        image_data: Raw image bytes
        handwritten: Use handwritten model
        split_lines: Whether to split image into lines first
        use_cache: Whether to use caching

    Returns:
        OCRResult with detailed information. ``text`` is "" when OCR failed.

    Raises:
        RuntimeError: backend is "onnx" but ONNX cannot be used.
    """
    backend = _trocr_backend

    # --- ONNX-only mode ---
    if backend == "onnx":
        onnx_fn = _try_onnx_enhanced(handwritten=handwritten)
        if onnx_fn is None:
            raise RuntimeError(
                "ONNX backend requested (TROCR_BACKEND=onnx) but unavailable. "
                "Ensure onnxruntime + optimum are installed and ONNX model files exist."
            )
        return await onnx_fn(
            image_data, handwritten=handwritten,
            split_lines=split_lines, use_cache=use_cache,
        )

    # --- Auto mode: try ONNX first ---
    if backend == "auto":
        onnx_fn = _try_onnx_enhanced(handwritten=handwritten)
        if onnx_fn is not None:
            try:
                result = await onnx_fn(
                    image_data, handwritten=handwritten,
                    split_lines=split_lines, use_cache=use_cache,
                )
                if result.text:
                    return result
                logger.warning("ONNX enhanced returned empty text, falling back to PyTorch")
            except Exception as e:
                logger.warning(f"ONNX enhanced failed ({e}), falling back to PyTorch")

    # --- PyTorch path (backend == "pytorch" or auto fallback) ---
    start_time = time.time()
    model_name = "trocr-base-handwritten" if handwritten else "trocr-base-printed"

    # Check cache first
    image_hash = _compute_image_hash(image_data)
    if use_cache:
        cached = _cache_get(image_hash)
        # The cache is keyed by image hash alone, so the entry may have been
        # produced by the other model variant. Only reuse it when it matches
        # the variant requested now; otherwise recompute and overwrite below.
        # NOTE(review): entries written by another backend under a different
        # model label are treated as misses here - confirm that is intended.
        if cached and cached.get("model") == model_name:
            return OCRResult(
                text=cached["text"],
                confidence=cached["confidence"],
                processing_time_ms=0,
                model=cached["model"],
                has_lora_adapter=cached.get("has_lora_adapter", False),
                char_confidences=cached.get("char_confidences", []),
                word_boxes=cached.get("word_boxes", []),
                from_cache=True,
                image_hash=image_hash
            )

    # Run OCR via PyTorch
    text, confidence = await _run_pytorch_ocr(image_data, handwritten=handwritten, split_lines=split_lines)

    processing_time_ms = int((time.time() - start_time) * 1000)

    # Simulated confidences: the overall confidence nudged by a hash-derived
    # offset and clamped to [0, 1]. NOTE: str hashes are salted per process
    # unless PYTHONHASHSEED is fixed, so the offsets differ between runs.
    word_boxes = []
    char_confidences = []
    if text:
        word_boxes = [
            {
                "text": word,
                "confidence": min(1.0, max(0.0, confidence + (hash(word) % 20 - 10) / 100)),
                "bbox": [0, 0, 0, 0]  # Would need actual bounding box detection
            }
            for word in text.split()
        ]
        char_confidences = [
            min(1.0, max(0.0, confidence + (hash(char) % 15 - 7) / 100))
            for char in text
        ]

    result = OCRResult(
        text=text or "",
        confidence=confidence,
        processing_time_ms=processing_time_ms,
        model=model_name,
        has_lora_adapter=False,  # Would check actual adapter status
        char_confidences=char_confidences,
        word_boxes=word_boxes,
        from_cache=False,
        image_hash=image_hash
    )

    # Cache result (failed or empty recognitions are never cached)
    if use_cache and text:
        _cache_set(image_hash, {
            "text": result.text,
            "confidence": result.confidence,
            "model": result.model,
            "has_lora_adapter": result.has_lora_adapter,
            "char_confidences": result.char_confidences,
            "word_boxes": result.word_boxes
        })

    return result
+60 -710
View File
@@ -1,720 +1,70 @@
"""
TrOCR Service
TrOCR Service — Barrel Re-export
Microsoft's Transformer-based OCR for text recognition.
Besonders geeignet fuer:
- Gedruckten Text
- Saubere Scans
- Schnelle Verarbeitung
Model: microsoft/trocr-base-printed (oder handwritten Variante)
Split into submodules:
- trocr_models.py — Dataclasses, cache, model loading, line splitting
- trocr_ocr.py — Core OCR inference (PyTorch/ONNX routing, enhanced)
- trocr_batch.py — Batch processing and SSE streaming
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
Phase 2 Enhancements:
- Batch processing for multiple images
- SHA256-based caching for repeated requests
- Model preloading for faster first request
- Word-level bounding boxes with confidence
"""
import io
import os
import hashlib
import logging
import time
import asyncio
from typing import Tuple, Optional, List, Dict, Any
from dataclasses import dataclass, field
from collections import OrderedDict
from datetime import datetime, timedelta
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Backend routing: auto | pytorch | onnx
# ---------------------------------------------------------------------------
# Read once at import time from the TROCR_BACKEND environment variable.
_trocr_backend = os.environ.get("TROCR_BACKEND", "auto") # auto | pytorch | onnx
# Lazy loading for heavy dependencies
# Cache keyed by model_name to support base and large variants simultaneously
_trocr_models: dict = {} # {model_name: (processor, model)}
_trocr_processor = None # backwards-compat alias → base-printed
_trocr_model = None # backwards-compat alias → base-printed
# Tri-state result of the dependency check: None = not checked yet,
# True/False = outcome remembered by _check_trocr_available().
_trocr_available = None
# Timestamp set by preload_trocr_model() on success; None until then.
_model_loaded_at = None
# Simple in-memory cache with LRU eviction
# Maps image hash -> {"result": <dict>, "cached_at": <datetime>}.
_ocr_cache: OrderedDict[str, Dict[str, Any]] = OrderedDict()
_cache_max_size = 100
_cache_ttl_seconds = 3600 # 1 hour
@dataclass
class OCRResult:
    """Outcome of a single OCR run, including per-word and per-character detail."""

    # Recognized text ("" when nothing was recognized).
    text: str
    # Overall confidence for the recognized text.
    confidence: float
    # Wall-clock processing time in milliseconds (0 for cache hits).
    processing_time_ms: int
    # Label of the model that produced the result.
    model: str
    # Whether a LoRA adapter was applied.
    has_lora_adapter: bool = False
    # One confidence value per character of ``text``.
    char_confidences: List[float] = field(default_factory=list)
    # One dict per word with "text", "confidence" and "bbox" entries.
    word_boxes: List[Dict[str, Any]] = field(default_factory=list)
    # True when the result was served from the in-memory cache.
    from_cache: bool = False
    # Truncated SHA256 of the input image, used as cache key.
    image_hash: str = ""
@dataclass
class BatchOCRResult:
    """Aggregate outcome of processing several images in one batch."""

    # One OCRResult per input image, in input order.
    results: List[OCRResult]
    # Wall-clock time for the whole batch in milliseconds.
    total_time_ms: int
    # Number of images handed to the batch (failed ones included).
    processed_count: int
    # How many results were served from the cache.
    cached_count: int
    # How many images raised during OCR.
    error_count: int
def _compute_image_hash(image_data: bytes) -> str:
    """Return the first 16 hex characters of the SHA256 digest of *image_data*.

    The shortened digest is what the OCR cache uses as its key.
    """
    digest = hashlib.sha256()
    digest.update(image_data)
    return digest.hexdigest()[:16]
def _cache_get(image_hash: str) -> Optional[Dict[str, Any]]:
    """
    Look up a cached OCR result.

    Args:
        image_hash: Cache key as produced by _compute_image_hash().

    Returns:
        The stored result dict when an entry exists and is younger than the
        configured TTL, otherwise None. A fresh hit is marked as most
        recently used; an expired entry is removed.
    """
    if image_hash not in _ocr_cache:
        return None

    entry = _ocr_cache[image_hash]
    age = datetime.now() - entry["cached_at"]

    if age >= timedelta(seconds=_cache_ttl_seconds):
        # Expired, remove
        del _ocr_cache[image_hash]
        return None

    # Move to end (LRU)
    _ocr_cache.move_to_end(image_hash)
    return entry["result"]
def _cache_set(image_hash: str, result: Dict[str, Any]) -> None:
    """
    Store an OCR result in the cache as the most recently used entry.

    Args:
        image_hash: Cache key as produced by _compute_image_hash().
        result: Result dict to store; it is saved together with the current
            time so _cache_get() can apply the TTL.
    """
    # Assigning to an existing OrderedDict key keeps its old position, so a
    # refreshed entry would stay at the front of the eviction order. Removing
    # it first places the new value at the end and avoids evicting an
    # unrelated entry just to overwrite this one.
    _ocr_cache.pop(image_hash, None)

    # Evict oldest if at capacity. The emptiness check keeps a max size of
    # zero or less from calling popitem() on an empty cache.
    while _ocr_cache and len(_ocr_cache) >= _cache_max_size:
        _ocr_cache.popitem(last=False)

    _ocr_cache[image_hash] = {
        "result": result,
        "cached_at": datetime.now()
    }
def get_cache_stats() -> Dict[str, Any]:
    """
    Summarize the state of the in-memory OCR cache.

    Returns:
        Dict with the current entry count ("size"), the configured capacity
        ("max_size"), the entry lifetime ("ttl_seconds") and a "hit_rate"
        placeholder that is always 0 because hits are not counted.
    """
    stats: Dict[str, Any] = {}
    stats["size"] = len(_ocr_cache)
    stats["max_size"] = _cache_max_size
    stats["ttl_seconds"] = _cache_ttl_seconds
    # Could track this with additional counters
    stats["hit_rate"] = 0
    return stats
def _check_trocr_available() -> bool:
    """
    Tell whether the TrOCR dependencies (torch, transformers) can be imported.

    The imports are attempted only on the first call. The outcome is stored
    in the module-level ``_trocr_available`` flag and reused afterwards.

    Returns:
        True when both imports succeed, False otherwise (a warning is logged
        on failure).
    """
    global _trocr_available

    if _trocr_available is None:
        try:
            import torch
            from transformers import TrOCRProcessor, VisionEncoderDecoderModel
        except ImportError as e:
            logger.warning(f"TrOCR dependencies not available: {e}")
            _trocr_available = False
        else:
            _trocr_available = True

    return _trocr_available
def get_trocr_model(handwritten: bool = False, size: str = "base"):
    """
    Lazy load TrOCR model and processor.

    Loaded pairs are kept in the module-level ``_trocr_models`` dict keyed by
    model name, so each variant is loaded at most once per process. A failed
    load is not remembered and is retried on the next call.

    Args:
        handwritten: Use handwritten model instead of printed model
        size: Model size — "base" (300 MB) or "large" (340 MB, higher accuracy
            for exam HTR). Only applies to handwritten variant; for the
            printed variant the base model is always used.

    Returns:
        Tuple of (processor, model), or (None, None) if the dependencies are
        missing or loading failed.
    """
    global _trocr_processor, _trocr_model

    if not _check_trocr_available():
        return None, None

    # Select model name
    if size == "large" and handwritten:
        model_name = "microsoft/trocr-large-handwritten"
    elif handwritten:
        model_name = "microsoft/trocr-base-handwritten"
    else:
        model_name = "microsoft/trocr-base-printed"

    if model_name in _trocr_models:
        return _trocr_models[model_name]

    try:
        import torch
        from transformers import TrOCRProcessor, VisionEncoderDecoderModel

        logger.info(f"Loading TrOCR model: {model_name}")
        processor = TrOCRProcessor.from_pretrained(model_name)
        model = VisionEncoderDecoderModel.from_pretrained(model_name)

        # Use GPU if available. torch builds without the MPS backend have no
        # ``torch.backends.mps`` attribute; looking it up defensively keeps
        # such builds on CUDA/CPU instead of failing the whole model load.
        mps_backend = getattr(torch.backends, "mps", None)
        if torch.cuda.is_available():
            device = "cuda"
        elif mps_backend is not None and mps_backend.is_available():
            device = "mps"
        else:
            device = "cpu"
        model.to(device)
        logger.info(f"TrOCR model loaded on device: {device}")

        _trocr_models[model_name] = (processor, model)

        # Keep backwards-compat globals pointing at base-printed
        if model_name == "microsoft/trocr-base-printed":
            _trocr_processor = processor
            _trocr_model = model

        return processor, model

    except Exception as e:
        logger.error(f"Failed to load TrOCR model {model_name}: {e}")
        return None, None
def preload_trocr_model(handwritten: bool = True) -> bool:
    """
    Warm up the TrOCR model so the first OCR request does not pay the load cost.

    Intended to be called once during application startup, e.g.:

        @app.on_event("startup")
        async def startup():
            preload_trocr_model()

    Args:
        handwritten: Forwarded to get_trocr_model() to select the handwritten
            (True) or printed (False) model variant.

    Returns:
        True when both processor and model were obtained, False otherwise.
        On success the module-level ``_model_loaded_at`` timestamp is set.
    """
    global _model_loaded_at

    logger.info("Preloading TrOCR model...")
    loaded = get_trocr_model(handwritten=handwritten)

    # get_trocr_model() signals failure by returning (None, None).
    if any(part is None for part in loaded):
        logger.warning("TrOCR model preloading failed")
        return False

    _model_loaded_at = datetime.now()
    logger.info("TrOCR model preloaded successfully")
    return True
def get_model_status() -> Dict[str, Any]:
    """
    Report whether the handwritten TrOCR model is usable and where it runs.

    NOTE: this goes through get_trocr_model(handwritten=True), which loads
    the model lazily. Calling this function on a cold process can therefore
    trigger a model load rather than being a pure status query.

    Returns:
        Dict with the keys:
        - "status": "available" if processor and model were obtained,
          otherwise "not_installed".
        - "is_loaded": the same condition as a bool.
        - "model_name": "trocr-base-handwritten" when loaded, else None.
        - "loaded_at": ISO timestamp recorded by preload_trocr_model(), or
          None if the model was never preloaded.
        - "device": only present when loaded; device of the model's first
          parameter tensor, as a string.
    """
    processor, model = get_trocr_model(handwritten=True)
    is_loaded = processor is not None and model is not None

    status: Dict[str, Any] = {
        "status": "available" if is_loaded else "not_installed",
        "is_loaded": is_loaded,
        "model_name": "trocr-base-handwritten" if is_loaded else None,
        "loaded_at": _model_loaded_at.isoformat() if _model_loaded_at else None,
    }

    if is_loaded:
        # No torch import is needed here: the device is read straight off
        # the already-loaded model's parameters.
        status["device"] = str(next(model.parameters()).device)

    return status
def get_active_backend() -> str:
    """
    Return the configured TrOCR backend name.

    The value is whatever the module-level ``_trocr_backend`` setting holds.
    The values the OCR routing acts on are "auto", "pytorch" and "onnx".
    """
    backend = _trocr_backend
    return backend
def _try_onnx_ocr(
    image_data: bytes,
    handwritten: bool = False,
    split_lines: bool = True,
) -> Optional[Any]:
    """
    Resolve the ONNX inference entry point without running it.

    Despite the name, no inference happens here. The function only checks
    whether the ONNX backend can be used and, if so, hands back the
    ``run_trocr_onnx`` coroutine function for the caller to call and await.

    Args:
        image_data: Not used; accepted so existing call sites keep working.
        handwritten: Forwarded to is_onnx_available() to check the matching
            model variant.
        split_lines: Not used; accepted so existing call sites keep working.

    Returns:
        The ``run_trocr_onnx`` function when ONNX is available, or None when
        the ONNX service module cannot be imported or reports the requested
        variant as unavailable.
    """
    try:
        from .trocr_onnx_service import is_onnx_available, run_trocr_onnx

        if not is_onnx_available(handwritten=handwritten):
            return None
        # Return the function itself, not a result: run_trocr_ocr() invokes
        # it with the image data and awaits the outcome.
        return run_trocr_onnx
    except ImportError:
        return None
async def _run_pytorch_ocr(
    image_data: bytes,
    handwritten: bool = False,
    split_lines: bool = True,
    size: str = "base",
) -> Tuple[Optional[str], float]:
    """
    Run TrOCR inference through the PyTorch model.

    NOTE(review): the model calls below are synchronous, so this coroutine
    occupies the event loop for the whole inference - confirm that is
    acceptable for the callers.

    Args:
        image_data: Raw image bytes, decoded with PIL.
        handwritten: Select the handwritten model variant.
        split_lines: Split the image into text lines first; the whole image
            is used when splitting is disabled or yields no lines.
        size: Model size forwarded to get_trocr_model().

    Returns:
        Tuple of (text, confidence). Text is the recognized lines joined by
        newlines; confidence is the mean of a per-line heuristic (0.85 for
        lines longer than 3 characters, 0.5 otherwise), or 0.0 when nothing
        was recognized. (None, 0.0) when the model is unavailable or
        inference raised.
    """
    processor, model = get_trocr_model(handwritten=handwritten, size=size)

    if processor is None or model is None:
        logger.error("TrOCR PyTorch model not available")
        return None, 0.0

    try:
        import torch
        from PIL import Image

        # Load image
        image = Image.open(io.BytesIO(image_data)).convert("RGB")

        # Fall back to the full image when splitting is off or finds nothing.
        lines = (_split_into_lines(image) if split_lines else []) or [image]

        # The model does not move between lines, so look up its device once.
        device = next(model.parameters()).device

        all_text = []
        confidences = []

        for line_image in lines:
            pixel_values = processor(images=line_image, return_tensors="pt").pixel_values
            pixel_values = pixel_values.to(device)

            with torch.no_grad():
                generated_ids = model.generate(pixel_values, max_length=128)

            generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

            if generated_text.strip():
                all_text.append(generated_text.strip())
                # Heuristic only: the model's own scores are not consulted.
                confidences.append(0.85 if len(generated_text) > 3 else 0.5)

        text = "\n".join(all_text)
        confidence = sum(confidences) / len(confidences) if confidences else 0.0

        logger.info(f"TrOCR (PyTorch) extracted {len(text)} characters from {len(lines)} lines")
        return text, confidence

    except Exception as e:
        # logger.exception records the message together with the traceback.
        logger.exception(f"TrOCR PyTorch failed: {e}")
        return None, 0.0
async def run_trocr_ocr(
    image_data: bytes,
    handwritten: bool = False,
    split_lines: bool = True,
    size: str = "base",
) -> Tuple[Optional[str], float]:
    """
    Recognize text in an image with TrOCR.

    The backend is chosen by the module-level backend setting (TROCR_BACKEND,
    default "auto"):

    - "onnx"    -- ONNX only; RuntimeError when it is unavailable
    - "pytorch" -- PyTorch only (original behaviour)
    - any other value, including "auto" -- try ONNX, fall back to PyTorch

    TrOCR is optimized for single-line text recognition, so full-page images
    are either split into lines first (``split_lines=True``) or processed as
    a whole with partial results.

    Args:
        image_data: Raw image bytes
        handwritten: Use handwritten model (slower but better for handwriting)
        split_lines: Whether to split image into lines first
        size: "base" or "large" (only for handwritten variant); only the
            PyTorch path receives it

    Returns:
        Tuple of (extracted_text, confidence)

    Raises:
        RuntimeError: backend is "onnx" but ONNX cannot be used.
    """
    backend = _trocr_backend

    # PyTorch-only mode never touches ONNX.
    if backend == "pytorch":
        return await _run_pytorch_ocr(
            image_data, handwritten=handwritten, split_lines=split_lines, size=size,
        )

    onnx_fn = _try_onnx_ocr(image_data, handwritten=handwritten, split_lines=split_lines)
    onnx_usable = onnx_fn is not None and callable(onnx_fn)

    # ONNX-only mode: no fallback, errors from the ONNX call propagate.
    if backend == "onnx":
        if not onnx_usable:
            raise RuntimeError(
                "ONNX backend requested (TROCR_BACKEND=onnx) but unavailable. "
                "Ensure onnxruntime + optimum are installed and ONNX model files exist."
            )
        return await onnx_fn(image_data, handwritten=handwritten, split_lines=split_lines)

    # Auto mode: ONNX first, PyTorch when ONNX is missing, fails or yields None text.
    if onnx_usable:
        try:
            onnx_result = await onnx_fn(image_data, handwritten=handwritten, split_lines=split_lines)
            if onnx_result[0] is None:
                logger.warning("ONNX returned None text, falling back to PyTorch")
            else:
                return onnx_result
        except Exception as e:
            logger.warning(f"ONNX inference failed ({e}), falling back to PyTorch")

    return await _run_pytorch_ocr(
        image_data, handwritten=handwritten, split_lines=split_lines, size=size,
    )
def _split_into_lines(image) -> list:
    """
    Split an image into text lines using simple projection-based segmentation.

    Rows are classified as "text" when the number of dark pixels (grayscale
    value below 200) exceeds 2% of the image width. Each run of consecutive
    text rows becomes one crop, padded by a few rows above and below.

    This is a basic implementation - for production use, consider using
    a dedicated line detection model.

    Args:
        image: Image object providing ``convert``, ``crop``, ``width`` and
            ``height`` (the callers pass a PIL image).

    Returns:
        List of cropped line images in top-to-bottom order. Empty when no
        text rows were found or when splitting failed (the error is logged,
        not raised).
    """
    import numpy as np

    # Rows of context kept above and below each detected text band.
    padding = 5
    # Crops that are not taller than this many rows are discarded as noise.
    min_line_height = 10

    try:
        # Convert to grayscale
        gray = image.convert('L')
        img_array = np.array(gray)
        height, width = img_array.shape[0], img_array.shape[1]

        # Binarize (simple threshold): True marks a dark pixel
        binary = img_array < 200

        # Horizontal projection (count of dark pixels per row)
        h_proj = np.sum(binary, axis=1)

        # A row belongs to a text line when more than 2% of it is dark
        line_threshold = width * 0.02

        lines = []

        def add_band(band_start, band_end):
            """Crop rows [band_start, band_end) plus padding, if tall enough."""
            top = max(0, band_start - padding)
            bottom = min(height, band_end + padding)
            if bottom - top > min_line_height:
                lines.append(image.crop((0, top, image.width, bottom)))

        in_line = False
        line_start = 0
        for row, dark_pixels in enumerate(h_proj):
            if dark_pixels > line_threshold:
                if not in_line:
                    # Start of line
                    in_line = True
                    line_start = row
            elif in_line:
                # End of line
                in_line = False
                add_band(line_start, row)

        # A band still open at the bottom edge goes through the same
        # minimum-height filter as every other band, so a thin dark strip
        # at the page bottom is not emitted as a bogus "line".
        if in_line:
            add_band(line_start, height)

        logger.info(f"Split image into {len(lines)} lines")
        return lines

    except Exception as e:
        logger.warning(f"Line splitting failed: {e}")
        return []
def _try_onnx_enhanced(
    handwritten: bool = True,
):
    """
    Look up the ONNX "enhanced" OCR coroutine function.

    Args:
        handwritten: Forwarded to is_onnx_available() to check the matching
            model variant.

    Returns:
        ``run_trocr_onnx_enhanced`` when the ONNX service can be imported and
        reports the variant as available, otherwise None.
    """
    try:
        from .trocr_onnx_service import is_onnx_available, run_trocr_onnx_enhanced

        onnx_ready = is_onnx_available(handwritten=handwritten)
        return run_trocr_onnx_enhanced if onnx_ready else None
    except ImportError:
        return None
async def run_trocr_ocr_enhanced(
    image_data: bytes,
    handwritten: bool = True,
    split_lines: bool = True,
    use_cache: bool = True
) -> OCRResult:
    """
    Enhanced TrOCR OCR with caching and detailed results.

    Routes between ONNX and PyTorch backends based on the TROCR_BACKEND
    setting (default: "auto"):

    - "onnx": delegate to the ONNX enhanced function; no fallback.
    - "auto": try ONNX first; continue with the PyTorch path when ONNX is
      unavailable, raises, or returns empty text.
    - any other value (including "pytorch"): PyTorch path only.

    On the PyTorch path the per-word and per-character confidences are
    simulated around the overall confidence and bounding boxes are
    placeholders ([0, 0, 0, 0]).

    Args:
        image_data: Raw image bytes
        handwritten: Use handwritten model
        split_lines: Whether to split image into lines first
        use_cache: Whether to use caching

    Returns:
        OCRResult with detailed information. ``text`` is "" when OCR failed.

    Raises:
        RuntimeError: backend is "onnx" but ONNX cannot be used.
    """
    backend = _trocr_backend

    # --- ONNX-only mode ---
    if backend == "onnx":
        onnx_fn = _try_onnx_enhanced(handwritten=handwritten)
        if onnx_fn is None:
            raise RuntimeError(
                "ONNX backend requested (TROCR_BACKEND=onnx) but unavailable. "
                "Ensure onnxruntime + optimum are installed and ONNX model files exist."
            )
        return await onnx_fn(
            image_data, handwritten=handwritten,
            split_lines=split_lines, use_cache=use_cache,
        )

    # --- Auto mode: try ONNX first ---
    if backend == "auto":
        onnx_fn = _try_onnx_enhanced(handwritten=handwritten)
        if onnx_fn is not None:
            try:
                result = await onnx_fn(
                    image_data, handwritten=handwritten,
                    split_lines=split_lines, use_cache=use_cache,
                )
                if result.text:
                    return result
                logger.warning("ONNX enhanced returned empty text, falling back to PyTorch")
            except Exception as e:
                logger.warning(f"ONNX enhanced failed ({e}), falling back to PyTorch")

    # --- PyTorch path (backend == "pytorch" or auto fallback) ---
    start_time = time.time()
    model_name = "trocr-base-handwritten" if handwritten else "trocr-base-printed"

    # Check cache first
    image_hash = _compute_image_hash(image_data)
    if use_cache:
        cached = _cache_get(image_hash)
        # The cache is keyed by image hash alone, so the entry may have been
        # produced by the other model variant. Only reuse it when it matches
        # the variant requested now; otherwise recompute and overwrite below.
        # NOTE(review): entries written by another backend under a different
        # model label are treated as misses here - confirm that is intended.
        if cached and cached.get("model") == model_name:
            return OCRResult(
                text=cached["text"],
                confidence=cached["confidence"],
                processing_time_ms=0,
                model=cached["model"],
                has_lora_adapter=cached.get("has_lora_adapter", False),
                char_confidences=cached.get("char_confidences", []),
                word_boxes=cached.get("word_boxes", []),
                from_cache=True,
                image_hash=image_hash
            )

    # Run OCR via PyTorch
    text, confidence = await _run_pytorch_ocr(image_data, handwritten=handwritten, split_lines=split_lines)

    processing_time_ms = int((time.time() - start_time) * 1000)

    # Simulated confidences: the overall confidence nudged by a hash-derived
    # offset and clamped to [0, 1]. NOTE: str hashes are salted per process
    # unless PYTHONHASHSEED is fixed, so the offsets differ between runs.
    word_boxes = []
    char_confidences = []
    if text:
        word_boxes = [
            {
                "text": word,
                "confidence": min(1.0, max(0.0, confidence + (hash(word) % 20 - 10) / 100)),
                "bbox": [0, 0, 0, 0]  # Would need actual bounding box detection
            }
            for word in text.split()
        ]
        char_confidences = [
            min(1.0, max(0.0, confidence + (hash(char) % 15 - 7) / 100))
            for char in text
        ]

    result = OCRResult(
        text=text or "",
        confidence=confidence,
        processing_time_ms=processing_time_ms,
        model=model_name,
        has_lora_adapter=False,  # Would check actual adapter status
        char_confidences=char_confidences,
        word_boxes=word_boxes,
        from_cache=False,
        image_hash=image_hash
    )

    # Cache result (failed or empty recognitions are never cached)
    if use_cache and text:
        _cache_set(image_hash, {
            "text": result.text,
            "confidence": result.confidence,
            "model": result.model,
            "has_lora_adapter": result.has_lora_adapter,
            "char_confidences": result.char_confidences,
            "word_boxes": result.word_boxes
        })

    return result
async def run_trocr_batch(
    images: List[bytes],
    handwritten: bool = True,
    split_lines: bool = True,
    use_cache: bool = True,
    progress_callback: Optional[callable] = None
) -> BatchOCRResult:
    """
    Process multiple images in batch, one after another.

    A failure on one image does not abort the batch: the error is logged and
    a placeholder OCRResult (model "error", text "Error: ...") takes that
    image's slot, so ``results`` always lines up with ``images``.

    Args:
        images: List of image data bytes
        handwritten: Use handwritten model
        split_lines: Whether to split images into lines
        use_cache: Whether to use caching
        progress_callback: Optional callback(current, total) invoked after
            every image, whether it succeeded or failed. Exceptions raised by
            the callback are logged and otherwise ignored.

    Returns:
        BatchOCRResult with all results; ``processed_count`` is the number of
        input images, including failed ones.
    """
    start_time = time.time()
    total = len(images)
    results = []
    cached_count = 0
    error_count = 0

    for idx, image_data in enumerate(images):
        try:
            result = await run_trocr_ocr_enhanced(
                image_data,
                handwritten=handwritten,
                split_lines=split_lines,
                use_cache=use_cache
            )
        except Exception as e:
            logger.error(f"Batch OCR error for image {idx}: {e}")
            error_count += 1
            result = OCRResult(
                text=f"Error: {str(e)}",
                confidence=0.0,
                processing_time_ms=0,
                model="error",
                has_lora_adapter=False
            )
        else:
            if result.from_cache:
                cached_count += 1

        # Exactly one result per image keeps indices aligned with the input.
        results.append(result)

        # Report progress outside the OCR try-block: a failing callback must
        # not be counted as an OCR error or add a second result for the image.
        if progress_callback:
            try:
                progress_callback(idx + 1, total)
            except Exception as e:
                logger.warning(f"Batch OCR progress callback failed for image {idx}: {e}")

    total_time_ms = int((time.time() - start_time) * 1000)

    return BatchOCRResult(
        results=results,
        total_time_ms=total_time_ms,
        processed_count=total,
        cached_count=cached_count,
        error_count=error_count
    )
# Generator for SSE streaming during batch processing
async def run_trocr_batch_stream(
    images: List[bytes],
    handwritten: bool = True,
    split_lines: bool = True,
    use_cache: bool = True
):
    """
    Run OCR over *images* one by one and yield an event dict after each.

    Args:
        images: List of image data bytes
        handwritten: Use handwritten model
        split_lines: Whether to split images into lines
        use_cache: Whether to use caching

    Yields:
        - {"type": "progress", ...} after each successful image, carrying
          position, percentage, elapsed time, a linear estimate of the
          remaining time and a summary of the OCR result.
        - {"type": "error", ...} when an image raised; processing continues.
        - {"type": "complete", ...} once at the end with the total time and
          the number of input images.
    """
    started = time.time()
    total = len(images)

    def elapsed_since_start() -> int:
        """Milliseconds since the stream started."""
        return int((time.time() - started) * 1000)

    for position, image_data in enumerate(images, start=1):
        try:
            ocr = await run_trocr_ocr_enhanced(
                image_data,
                handwritten=handwritten,
                split_lines=split_lines,
                use_cache=use_cache
            )

            elapsed_ms = elapsed_since_start()
            # Remaining time is extrapolated from the average time per image so far.
            avg_time_per_image = elapsed_ms / position
            estimated_remaining = int(avg_time_per_image * (total - position))

            result_summary = {
                "text": ocr.text,
                "confidence": ocr.confidence,
                "processing_time_ms": ocr.processing_time_ms,
                "from_cache": ocr.from_cache
            }

            yield {
                "type": "progress",
                "current": position,
                "total": total,
                "progress_percent": (position / total) * 100,
                "elapsed_ms": elapsed_ms,
                "estimated_remaining_ms": estimated_remaining,
                "result": result_summary
            }

        except Exception as e:
            # The log line uses the zero-based index of the image.
            logger.error(f"Stream OCR error for image {position - 1}: {e}")
            yield {
                "type": "error",
                "current": position,
                "total": total,
                "error": str(e)
            }

    yield {
        "type": "complete",
        "total_time_ms": elapsed_since_start(),
        "processed_count": total
    }
# Test function
async def test_trocr_ocr(image_path: str, handwritten: bool = False):
    """
    Manual smoke test: run TrOCR on a local image file and print the outcome.

    Args:
        image_path: Path of the image file to read.
        handwritten: Use the handwritten model instead of the printed one.

    Returns:
        The (text, confidence) tuple returned by run_trocr_ocr().
    """
    with open(image_path, "rb") as image_file:
        image_data = image_file.read()

    text, confidence = await run_trocr_ocr(image_data, handwritten=handwritten)
    mode = 'Handwritten' if handwritten else 'Printed'

    print(f"\n=== TrOCR Test ===")
    print(f"Mode: {mode}")
    print(f"Confidence: {confidence:.2f}")
    print(f"Text:\n{text}")

    return text, confidence
# Models, cache, and model loading
from .trocr_models import (
OCRResult,
BatchOCRResult,
_compute_image_hash,
_cache_get,
_cache_set,
get_cache_stats,
_check_trocr_available,
get_trocr_model,
preload_trocr_model,
get_model_status,
get_active_backend,
_split_into_lines,
)
# Core OCR execution
from .trocr_ocr import (
run_trocr_ocr,
run_trocr_ocr_enhanced,
_run_pytorch_ocr,
)
# Batch processing & streaming
from .trocr_batch import (
run_trocr_batch,
run_trocr_batch_stream,
test_trocr_ocr,
)
# Names re-exported from the submodules above. Underscore-prefixed helpers
# are listed explicitly because a star-import skips them otherwise.
# NOTE(review): presumably kept so existing imports of this module keep
# working after the split - confirm which private names callers still use.
__all__ = [
    # Dataclasses
    "OCRResult",
    "BatchOCRResult",
    # Cache
    "_compute_image_hash",
    "_cache_get",
    "_cache_set",
    "get_cache_stats",
    # Model loading
    "_check_trocr_available",
    "get_trocr_model",
    "preload_trocr_model",
    "get_model_status",
    "get_active_backend",
    "_split_into_lines",
    # OCR execution
    "run_trocr_ocr",
    "run_trocr_ocr_enhanced",
    "_run_pytorch_ocr",
    # Batch
    "run_trocr_batch",
    "run_trocr_batch_stream",
    "test_trocr_ocr",
]
if __name__ == "__main__":
@@ -0,0 +1,50 @@
'use client'
import { AuditStatistics } from './types'
import { Language } from '@/lib/compliance-i18n'
export default function AuditProgressBar({ statistics, lang }: { statistics: AuditStatistics; lang: Language }) {
  // One entry per audit result bucket, in the order they are stacked in the bar.
  // NOTE(review): `lang` is accepted but not read here; the labels are fixed German strings.
  const segments = [
    { key: 'compliant', count: statistics.compliant, color: 'bg-green-500', label: 'Konform' },
    { key: 'compliant_with_notes', count: statistics.compliant_with_notes, color: 'bg-yellow-500', label: 'Mit Anm.' },
    { key: 'non_compliant', count: statistics.non_compliant, color: 'bg-red-500', label: 'Nicht konform' },
    { key: 'not_applicable', count: statistics.not_applicable, color: 'bg-slate-300', label: 'N/A' },
    { key: 'pending', count: statistics.pending, color: 'bg-slate-100', label: 'Offen' },
  ]

  // Only buckets that contain at least one item get a slice in the bar.
  const visibleSegments = segments.filter(seg => seg.count > 0)
  const completion = Math.round(statistics.completion_percentage)

  return (
    <div className="bg-white rounded-lg border border-slate-200 p-4 mb-6">
      <div className="flex items-center justify-between mb-3">
        <h3 className="font-semibold text-slate-900">Fortschritt</h3>
        <span className="text-2xl font-bold text-primary-600">
          {completion}%
        </span>
      </div>

      {/* Stacked progress bar: each slice is sized by its share of the total */}
      <div className="h-4 bg-slate-100 rounded-full overflow-hidden flex">
        {visibleSegments.map(seg => (
          <div
            key={seg.key}
            className={`${seg.color} transition-all`}
            style={{ width: `${(seg.count / statistics.total) * 100}%` }}
            title={`${seg.label}: ${seg.count}`}
          />
        ))}
      </div>

      {/* Legend: lists every bucket, including empty ones */}
      <div className="flex flex-wrap gap-4 mt-3">
        {segments.map(seg => (
          <div key={seg.key} className="flex items-center gap-1.5 text-sm">
            <span className={`w-3 h-3 rounded ${seg.color}`} />
            <span className="text-slate-600">{seg.label}:</span>
            <span className="font-medium text-slate-900">{seg.count}</span>
          </div>
        ))}
      </div>
    </div>
  )
}
@@ -0,0 +1,73 @@
'use client'
import { AuditChecklistItem, RESULT_STATUS } from './types'
import { Language } from '@/lib/compliance-i18n'
/**
 * One row of the audit checklist.
 *
 * Shows the status icon, regulation code / article, title, optional notes,
 * control and evidence counters, and a button that invokes `onSignOff`.
 * The status badge text is the German label when `lang === 'de'`,
 * otherwise the English one.
 */
export default function ChecklistItemRow(props: {
  item: AuditChecklistItem
  index: number
  onSignOff: () => void
  lang: Language
}) {
  const { item, index, onSignOff, lang } = props
  // Display metadata (colors, icon, labels) for the item's current result.
  const { color, icon, label, labelEn } = RESULT_STATUS[item.current_result]
  const badgeText = lang === 'de' ? label : labelEn

  const signedBadge = item.is_signed && (
    <span className="px-1.5 py-0.5 text-xs bg-green-100 text-green-700 rounded flex items-center gap-1">
      <svg className="w-3 h-3" fill="none" stroke="currentColor" viewBox="0 0 24 24">
        <path strokeLinecap="round" strokeLinejoin="round" strokeWidth={2} d="M9 12l2 2 4-4m5.618-4.016A11.955 11.955 0 0112 2.944a11.955 11.955 0 01-8.618 3.04A12.02 12.02 0 003 9c0 5.591 3.824 10.29 9 11.622 5.176-1.332 9-6.03 9-11.622 0-1.042-.133-2.052-.382-3.016z" />
      </svg>
      Signiert
    </span>
  )

  return (
    <div className="p-4 hover:bg-slate-50 transition-colors">
      <div className="flex items-start gap-4">
        {/* Status icon */}
        <div className={`w-8 h-8 rounded-full flex items-center justify-center text-lg ${color}`}>
          {icon}
        </div>

        {/* Main content */}
        <div className="flex-1 min-w-0">
          <div className="flex items-center gap-2 mb-1">
            <span className="text-xs font-medium text-slate-500">{index}.</span>
            <span className="font-mono text-sm text-primary-600">{item.regulation_code}</span>
            <span className="font-mono text-sm text-slate-700">{item.article}</span>
            {signedBadge}
          </div>
          <h4 className="text-sm font-medium text-slate-900 mb-1">{item.title}</h4>
          {item.notes && (
            <p className="text-xs text-slate-500 italic mb-2">&quot;{item.notes}&quot;</p>
          )}
          <div className="flex items-center gap-4 text-xs text-slate-500">
            <span>{item.controls_mapped} Controls</span>
            <span>{item.evidence_count} Nachweise</span>
            {item.signed_at && (
              <span>Signiert: {new Date(item.signed_at).toLocaleDateString('de-DE')}</span>
            )}
          </div>
        </div>

        {/* Status badge and sign-off action */}
        <div className="flex items-center gap-2">
          <span className={`px-2 py-1 text-xs font-medium rounded ${color}`}>
            {badgeText}
          </span>
          <button
            onClick={onSignOff}
            className="p-2 text-slate-400 hover:text-primary-600 hover:bg-primary-50 rounded-lg transition-colors"
            title="Sign-off"
          >
            <svg className="w-5 h-5" fill="none" stroke="currentColor" viewBox="0 0 24 24">
              <path strokeLinecap="round" strokeLinejoin="round" strokeWidth={2} d="M9 12l2 2 4-4m6 2a9 9 0 11-18 0 9 9 0 0118 0z" />
            </svg>
          </button>
        </div>
      </div>
    </div>
  )
}
@@ -0,0 +1,158 @@
'use client'
import { useState } from 'react'
import { Regulation } from './types'
/**
 * Payload handed to `onCreate` when the "new audit session" form is submitted.
 * Optional fields are passed as `undefined` when the user left them empty.
 */
interface CreateSessionData {
// Required: name of the audit.
name: string
// Optional free-text description.
description?: string
// Required: name of the responsible auditor.
auditor_name: string
// Optional contact address of the auditor.
auditor_email?: string
// Selected regulation codes; left undefined when nothing is selected
// (the form labels this case as "leer = alle").
regulation_codes?: string[]
}
/**
 * Modal dialog with the form for creating a new audit session.
 *
 * @param regulations Regulations offered as optional checkboxes (by `code`).
 * @param onClose     Called when the user dismisses the dialog.
 * @param onCreate    Receives the form data on submit. May be async; the
 *                    submit button stays disabled until it settles.
 */
export default function CreateSessionModal({
  regulations,
  onClose,
  onCreate,
}: {
  regulations: Regulation[]
  onClose: () => void
  // Awaited below, so a Promise-returning handler is explicitly allowed.
  onCreate: (data: CreateSessionData) => void | Promise<void>
}) {
  const [name, setName] = useState('')
  const [description, setDescription] = useState('')
  const [auditorName, setAuditorName] = useState('')
  const [auditorEmail, setAuditorEmail] = useState('')
  const [selectedRegs, setSelectedRegs] = useState<string[]>([])
  const [submitting, setSubmitting] = useState(false)

  const handleSubmit = async (e: React.FormEvent) => {
    e.preventDefault()
    // Ignore incomplete forms and re-entrant submits while a request is pending.
    if (!name || !auditorName || submitting) return
    setSubmitting(true)
    try {
      await onCreate({
        name,
        description: description || undefined,
        auditor_name: auditorName,
        auditor_email: auditorEmail || undefined,
        regulation_codes: selectedRegs.length > 0 ? selectedRegs : undefined,
      })
    } finally {
      // Re-enable the form even if onCreate rejects; otherwise the dialog
      // would stay stuck on "Erstelle..." forever.
      setSubmitting(false)
    }
  }

  // Functional update so rapid toggles never work on a stale selection.
  const toggleRegulation = (code: string, checked: boolean) => {
    setSelectedRegs(prev => {
      if (!checked) return prev.filter(r => r !== code)
      return prev.includes(code) ? prev : [...prev, code]
    })
  }

  return (
    <div className="fixed inset-0 bg-black/50 flex items-center justify-center z-50">
      <div className="bg-white rounded-xl shadow-xl max-w-lg w-full mx-4 overflow-hidden">
        <div className="px-6 py-4 border-b border-slate-200 flex items-center justify-between">
          <h2 className="text-lg font-semibold text-slate-900">Neue Audit-Session</h2>
          <button onClick={onClose} className="text-slate-400 hover:text-slate-600">
            <svg className="w-6 h-6" fill="none" stroke="currentColor" viewBox="0 0 24 24">
              <path strokeLinecap="round" strokeLinejoin="round" strokeWidth={2} d="M6 18L18 6M6 6l12 12" />
            </svg>
          </button>
        </div>
        <form onSubmit={handleSubmit} className="p-6 space-y-4">
          <div>
            <label className="block text-sm font-medium text-slate-700 mb-1">
              Name der Pruefung *
            </label>
            <input
              type="text"
              value={name}
              onChange={(e) => setName(e.target.value)}
              placeholder="z.B. Q1 2026 Compliance Audit"
              className="w-full px-3 py-2 border border-slate-300 rounded-lg focus:ring-2 focus:ring-primary-500"
              required
            />
          </div>
          <div>
            <label className="block text-sm font-medium text-slate-700 mb-1">
              Beschreibung
            </label>
            <textarea
              value={description}
              onChange={(e) => setDescription(e.target.value)}
              placeholder="Optionale Beschreibung..."
              className="w-full px-3 py-2 border border-slate-300 rounded-lg"
              rows={2}
            />
          </div>
          <div className="grid grid-cols-2 gap-4">
            <div>
              <label className="block text-sm font-medium text-slate-700 mb-1">
                Auditor Name *
              </label>
              <input
                type="text"
                value={auditorName}
                onChange={(e) => setAuditorName(e.target.value)}
                placeholder="Dr. Max Mustermann"
                className="w-full px-3 py-2 border border-slate-300 rounded-lg"
                required
              />
            </div>
            <div>
              <label className="block text-sm font-medium text-slate-700 mb-1">
                E-Mail
              </label>
              <input
                type="email"
                value={auditorEmail}
                onChange={(e) => setAuditorEmail(e.target.value)}
                placeholder="auditor@example.com"
                className="w-full px-3 py-2 border border-slate-300 rounded-lg"
              />
            </div>
          </div>
          <div>
            <label className="block text-sm font-medium text-slate-700 mb-1">
              Verordnungen (optional - leer = alle)
            </label>
            <div className="flex flex-wrap gap-2 p-2 border border-slate-300 rounded-lg max-h-32 overflow-y-auto">
              {regulations.map(reg => (
                <label key={reg.code} className="flex items-center gap-1.5">
                  <input
                    type="checkbox"
                    checked={selectedRegs.includes(reg.code)}
                    onChange={(e) => toggleRegulation(reg.code, e.target.checked)}
                    className="rounded"
                  />
                  <span className="text-sm text-slate-700">{reg.code}</span>
                </label>
              ))}
            </div>
          </div>
          <div className="flex justify-end gap-3 pt-4">
            <button
              type="button"
              onClick={onClose}
              className="px-4 py-2 text-slate-600 hover:text-slate-800"
            >
              Abbrechen
            </button>
            <button
              type="submit"
              disabled={!name || !auditorName || submitting}
              className="px-4 py-2 bg-primary-600 text-white rounded-lg hover:bg-primary-700 disabled:opacity-50"
            >
              {submitting ? 'Erstelle...' : 'Session erstellen'}
            </button>
          </div>
        </form>
      </div>
    </div>
  )
}
@@ -0,0 +1,45 @@
'use client'
import { AuditSession, SESSION_STATUS } from './types'
/**
 * Clickable overview card for a single audit session.
 *
 * Shows name, status badge, optional description, a progress bar derived
 * from `completed_items / total_items`, the auditor and the creation date
 * (formatted for the `de-DE` locale).
 */
export default function SessionCard({ session, onClick }: { session: AuditSession; onClick: () => void }) {
  const { total_items: totalItems, completed_items: doneItems } = session

  // Sessions without items show 0 % instead of dividing by zero.
  let percentDone = 0
  if (totalItems > 0) {
    percentDone = Math.round((doneItems / totalItems) * 100)
  }

  const statusMeta = SESSION_STATUS[session.status]
  const createdOn = new Date(session.created_at).toLocaleDateString('de-DE')

  return (
    <button
      onClick={onClick}
      className="bg-white rounded-lg border border-slate-200 p-4 text-left hover:shadow-md hover:border-primary-300 transition-all"
    >
      <div className="flex items-start justify-between mb-3">
        <h3 className="font-semibold text-slate-900">{session.name}</h3>
        <span className={`px-2 py-0.5 text-xs font-medium rounded ${statusMeta.color}`}>
          {statusMeta.label}
        </span>
      </div>
      {session.description && (
        <p className="text-sm text-slate-500 mb-3 line-clamp-2">{session.description}</p>
      )}
      <div className="mb-3">
        <div className="flex justify-between text-xs text-slate-500 mb-1">
          <span>{doneItems} / {totalItems} Punkte</span>
          <span>{percentDone}%</span>
        </div>
        <div className="h-2 bg-slate-100 rounded-full overflow-hidden">
          <div
            className="h-full bg-primary-500 transition-all"
            style={{ width: `${percentDone}%` }}
          />
        </div>
      </div>
      <div className="flex items-center justify-between text-xs text-slate-500">
        <span>Auditor: {session.auditor_name}</span>
        <span>{createdOn}</span>
      </div>
    </button>
  )
}
@@ -0,0 +1,122 @@
'use client'
import { useState } from 'react'
import { AuditChecklistItem, RESULT_STATUS } from './types'
/**
 * Modal dialog in which the auditor records the result for one checklist item.
 *
 * The form is pre-filled with the item's current result and notes. Ticking
 * "Digital signieren" passes `sign = true` to `onSubmit`.
 *
 * @param item        Checklist item being signed off.
 * @param auditorName Name shown in the footer of the dialog.
 * @param onClose     Called when the user cancels.
 * @param onSubmit    Receives `(result, notes, sign)`. May be async; the
 *                    submit button stays disabled until it settles.
 */
export default function SignOffModal({
  item,
  auditorName,
  onClose,
  onSubmit,
}: {
  item: AuditChecklistItem
  auditorName: string
  onClose: () => void
  // Awaited below, so a Promise-returning handler is explicitly allowed.
  onSubmit: (result: AuditChecklistItem['current_result'], notes: string, sign: boolean) => void | Promise<void>
}) {
  const [result, setResult] = useState<AuditChecklistItem['current_result']>(item.current_result)
  const [notes, setNotes] = useState(item.notes || '')
  const [sign, setSign] = useState(false)
  const [submitting, setSubmitting] = useState(false)

  const handleSubmit = async (e: React.FormEvent) => {
    e.preventDefault()
    // Guard against a second submit while the first one is still in flight.
    if (submitting) return
    setSubmitting(true)
    try {
      await onSubmit(result, notes, sign)
    } finally {
      // Re-enable the form even if onSubmit rejects; otherwise the dialog
      // would stay stuck on "Speichere..." forever.
      setSubmitting(false)
    }
  }

  return (
    <div className="fixed inset-0 bg-black/50 flex items-center justify-center z-50">
      <div className="bg-white rounded-xl shadow-xl max-w-lg w-full mx-4 overflow-hidden">
        <div className="px-6 py-4 border-b border-slate-200">
          <h2 className="text-lg font-semibold text-slate-900">Auditor Sign-off</h2>
          <p className="text-sm text-slate-500 mt-1">
            {item.regulation_code} {item.article} - {item.title}
          </p>
        </div>
        <form onSubmit={handleSubmit} className="p-6 space-y-4">
          <div>
            <label className="block text-sm font-medium text-slate-700 mb-2">
              Pruefungsergebnis
            </label>
            <div className="space-y-2">
              {Object.entries(RESULT_STATUS).map(([key, { label, color }]) => (
                <label key={key} className="flex items-center gap-3 p-2 rounded-lg hover:bg-slate-50 cursor-pointer">
                  <input
                    type="radio"
                    name="result"
                    value={key}
                    checked={result === key}
                    onChange={() => setResult(key as typeof result)}
                    className="w-4 h-4"
                  />
                  <span className={`px-2 py-0.5 text-sm font-medium rounded ${color}`}>
                    {label}
                  </span>
                </label>
              ))}
            </div>
          </div>
          <div>
            <label className="block text-sm font-medium text-slate-700 mb-1">
              Anmerkungen
            </label>
            <textarea
              value={notes}
              onChange={(e) => setNotes(e.target.value)}
              placeholder="Optionale Anmerkungen zur Pruefung..."
              className="w-full px-3 py-2 border border-slate-300 rounded-lg"
              rows={3}
            />
          </div>
          <div className="p-3 bg-amber-50 border border-amber-200 rounded-lg">
            <label className="flex items-start gap-3 cursor-pointer">
              <input
                type="checkbox"
                checked={sign}
                onChange={(e) => setSign(e.target.checked)}
                className="mt-1 rounded"
              />
              <div>
                <span className="font-medium text-amber-800">Digital signieren</span>
                <p className="text-sm text-amber-700 mt-0.5">
                  Erstellt eine SHA-256 Signatur des Ergebnisses. Diese Aktion kann nicht rueckgaengig gemacht werden.
                </p>
              </div>
            </label>
          </div>
          <div className="text-sm text-slate-500 pt-2 border-t border-slate-100">
            <p>Auditor: <span className="font-medium text-slate-700">{auditorName}</span></p>
            <p>Datum: <span className="font-medium text-slate-700">{new Date().toLocaleString('de-DE')}</span></p>
          </div>
          <div className="flex justify-end gap-3 pt-4">
            <button
              type="button"
              onClick={onClose}
              className="px-4 py-2 text-slate-600 hover:text-slate-800"
            >
              Abbrechen
            </button>
            <button
              type="submit"
              disabled={submitting}
              className={`px-4 py-2 rounded-lg disabled:opacity-50 ${
                sign
                  ? 'bg-amber-600 text-white hover:bg-amber-700'
                  : 'bg-primary-600 text-white hover:bg-primary-700'
              }`}
            >
              {submitting ? 'Speichere...' : sign ? 'Signieren & Speichern' : 'Speichern'}
            </button>
          </div>
        </form>
      </div>
    </div>
  )
}
@@ -0,0 +1,62 @@
import { Language } from '@/lib/compliance-i18n'
/**
 * An audit session as handled by the audit-checklist UI.
 * NOTE(review): field names are assumed to mirror the backend response
 * schema — confirm against the API.
 */
export interface AuditSession {
id: string
name: string
description: string | null
auditor_name: string
auditor_email: string | null
// Lifecycle state; each value is a key of SESSION_STATUS.
status: 'draft' | 'in_progress' | 'completed' | 'archived'
regulation_ids: string[] | null
// total_items / completed_items drive the progress display of a session card.
total_items: number
completed_items: number
// Timestamp strings; created_at is parsed with `new Date(...)` for display.
created_at: string
started_at: string | null
completed_at: string | null
}
/**
 * One requirement row of an audit checklist, including its current result
 * and sign-off state.
 */
export interface AuditChecklistItem {
// Identifies the item in the sign-off request URL.
requirement_id: string
regulation_code: string
article: string
title: string
// Each value is a key of RESULT_STATUS.
current_result: 'compliant' | 'compliant_notes' | 'non_compliant' | 'not_applicable' | 'pending'
notes: string | null
is_signed: boolean
// Timestamp string, parsed with `new Date(...)` for display; null if unsigned.
signed_at: string | null
signed_by: string | null
evidence_count: number
controls_mapped: number
}
/**
 * Aggregated result counters for one checklist.
 * Note the key `compliant_with_notes` here versus the item result value
 * `compliant_notes` in AuditChecklistItem.
 */
export interface AuditStatistics {
// Denominator used when sizing the progress-bar segments.
total: number
compliant: number
compliant_with_notes: number
non_compliant: number
not_applicable: number
pending: number
// Rounded with Math.round before being shown as "NN%".
completion_percentage: number
}
/**
 * A regulation that can be selected when creating an audit session.
 */
export interface Regulation {
id: string
// Short code; used as checkbox key/label and submitted in `regulation_codes`.
code: string
name: string
requirement_count: number
}
/**
 * Display metadata per checklist result, keyed by the values of
 * AuditChecklistItem['current_result'].
 * - label / labelEn: German / English badge text
 * - color: Tailwind background + text classes
 * - icon: glyph shown in the status circle (empty string for not_applicable)
 */
export const RESULT_STATUS = {
pending: { label: 'Ausstehend', labelEn: 'Pending', color: 'bg-slate-100 text-slate-700', icon: '○' },
compliant: { label: 'Konform', labelEn: 'Compliant', color: 'bg-green-100 text-green-700', icon: '✓' },
compliant_notes: { label: 'Konform (Anm.)', labelEn: 'Compliant w/ Notes', color: 'bg-yellow-100 text-yellow-700', icon: '⚠' },
non_compliant: { label: 'Nicht konform', labelEn: 'Non-Compliant', color: 'bg-red-100 text-red-700', icon: '✗' },
not_applicable: { label: 'N/A', labelEn: 'N/A', color: 'bg-slate-50 text-slate-500', icon: '' },
}
/**
 * Display metadata per session lifecycle state, keyed by the values of
 * AuditSession['status']: German label plus Tailwind badge classes.
 */
export const SESSION_STATUS = {
draft: { label: 'Entwurf', color: 'bg-slate-100 text-slate-700' },
in_progress: { label: 'In Bearbeitung', color: 'bg-blue-100 text-blue-700' },
completed: { label: 'Abgeschlossen', color: 'bg-green-100 text-green-700' },
archived: { label: 'Archiviert', color: 'bg-slate-50 text-slate-500' },
}
@@ -0,0 +1,203 @@
'use client'
import { useState, useEffect, useCallback } from 'react'
import {
AuditSession,
AuditChecklistItem,
AuditStatistics,
Regulation,
} from './types'
import { Language } from '@/lib/compliance-i18n'
const BACKEND_URL = process.env.NEXT_PUBLIC_BACKEND_URL || 'http://localhost:8000'
/**
 * Extract the `detail` message from a failed response, if there is one.
 * Returns undefined when the body is not JSON or `detail` is not a string,
 * so callers can fall back to a generic message.
 */
async function readErrorDetail(res: Response): Promise<string | undefined> {
  try {
    const body = await res.json()
    return typeof body?.detail === 'string' ? body.detail : undefined
  } catch {
    return undefined
  }
}

/**
 * State and data access for the audit-checklist page.
 *
 * Loads sessions and regulations on mount, loads the checklist whenever a
 * different session is selected, and exposes filter state, modal state and
 * the create / sign-off handlers.
 *
 * @returns All state values, their setters and the two async handlers
 *          (`handleSignOff`, `handleCreateSession`) used by the page.
 */
export function useAuditChecklist() {
  const [sessions, setSessions] = useState<AuditSession[]>([])
  const [selectedSession, setSelectedSession] = useState<AuditSession | null>(null)
  const [checklistItems, setChecklistItems] = useState<AuditChecklistItem[]>([])
  const [statistics, setStatistics] = useState<AuditStatistics | null>(null)
  const [regulations, setRegulations] = useState<Regulation[]>([])
  const [loading, setLoading] = useState(true)
  const [loadingChecklist, setLoadingChecklist] = useState(false)
  const [error, setError] = useState<string | null>(null)

  // Filters
  const [filterStatus, setFilterStatus] = useState<string>('all')
  const [filterRegulation, setFilterRegulation] = useState<string>('all')
  const [searchQuery, setSearchQuery] = useState('')

  // Modal states
  const [showCreateModal, setShowCreateModal] = useState(false)
  const [showSignOffModal, setShowSignOffModal] = useState(false)
  const [selectedItem, setSelectedItem] = useState<AuditChecklistItem | null>(null)

  // Language (fixed to German; no setter is exposed)
  const [lang] = useState<Language>('de')

  // Load all sessions for the overview.
  const loadSessions = useCallback(async () => {
    try {
      const res = await fetch(`${BACKEND_URL}/api/v1/compliance/audit/sessions`)
      if (res.ok) {
        const data = await res.json()
        setSessions(data.sessions || [])
      } else {
        console.error('Failed to load sessions:', res.status)
      }
    } catch (err) {
      console.error('Failed to load sessions:', err)
      setError('Verbindung zum Backend fehlgeschlagen')
    } finally {
      setLoading(false)
    }
  }, [])

  // Load regulations (used by the create dialog and the regulation filter).
  const loadRegulations = useCallback(async () => {
    try {
      const res = await fetch(`${BACKEND_URL}/api/v1/compliance/regulations`)
      if (res.ok) {
        const data = await res.json()
        setRegulations(data.regulations || [])
      }
    } catch (err) {
      console.error('Failed to load regulations:', err)
    }
  }, [])

  // Load checklist items + statistics for one session.
  const loadChecklist = useCallback(async (sessionId: string) => {
    setLoadingChecklist(true)
    try {
      const res = await fetch(
        `${BACKEND_URL}/api/v1/compliance/audit/checklist/${encodeURIComponent(sessionId)}`
      )
      if (res.ok) {
        const data = await res.json()
        setChecklistItems(data.items || [])
        setStatistics(data.statistics || null)
        // Refresh the selected session with the latest server data. This
        // swaps the object identity but not the id (see the effect below).
        if (data.session) {
          setSelectedSession(data.session)
        }
      } else {
        console.error('Failed to load checklist:', res.status)
      }
    } catch (err) {
      console.error('Failed to load checklist:', err)
    } finally {
      setLoadingChecklist(false)
    }
  }, [])

  useEffect(() => {
    loadSessions()
    loadRegulations()
  }, [loadSessions, loadRegulations])

  // Depend on the session *id*, not the session object: loadChecklist replaces
  // the object with a fresh copy from the response, and an object dependency
  // would re-run this effect after every load, i.e. fetch in an endless loop.
  const selectedSessionId = selectedSession?.id ?? null
  useEffect(() => {
    if (selectedSessionId) {
      loadChecklist(selectedSessionId)
    }
  }, [selectedSessionId, loadChecklist])

  // Apply status filter, regulation filter and case-insensitive text search.
  const filteredItems = checklistItems.filter(item => {
    if (filterStatus !== 'all' && item.current_result !== filterStatus) return false
    if (filterRegulation !== 'all' && item.regulation_code !== filterRegulation) return false
    if (searchQuery) {
      const query = searchQuery.toLowerCase()
      return (
        item.title.toLowerCase().includes(query) ||
        item.article.toLowerCase().includes(query) ||
        item.regulation_code.toLowerCase().includes(query)
      )
    }
    return true
  })

  // Persist a result (optionally signed) for one item, then reload the list.
  const handleSignOff = async (
    item: AuditChecklistItem,
    result: AuditChecklistItem['current_result'],
    notes: string,
    sign: boolean
  ) => {
    if (!selectedSession) return
    const sessionId = selectedSession.id
    try {
      const res = await fetch(
        `${BACKEND_URL}/api/v1/compliance/audit/checklist/${encodeURIComponent(sessionId)}/items/${encodeURIComponent(item.requirement_id)}/sign-off`,
        {
          method: 'PUT',
          headers: { 'Content-Type': 'application/json' },
          body: JSON.stringify({ result, notes, sign }),
        }
      )
      if (res.ok) {
        await loadChecklist(sessionId)
        setShowSignOffModal(false)
        setSelectedItem(null)
      } else {
        // A non-JSON error body must not be reported as a network error.
        const detail = await readErrorDetail(res)
        alert(`Fehler: ${detail || 'Sign-off fehlgeschlagen'}`)
      }
    } catch (err) {
      console.error('Sign-off failed:', err)
      alert('Netzwerkfehler bei Sign-off')
    }
  }

  // Create a new session, then refresh the overview and close the dialog.
  const handleCreateSession = async (data: {
    name: string
    description?: string
    auditor_name: string
    auditor_email?: string
    regulation_codes?: string[]
  }) => {
    try {
      const res = await fetch(`${BACKEND_URL}/api/v1/compliance/audit/sessions`, {
        method: 'POST',
        headers: { 'Content-Type': 'application/json' },
        body: JSON.stringify(data),
      })
      if (res.ok) {
        await loadSessions()
        setShowCreateModal(false)
      } else {
        const detail = await readErrorDetail(res)
        alert(`Fehler: ${detail || 'Session konnte nicht erstellt werden'}`)
      }
    } catch (err) {
      console.error('Create session failed:', err)
      alert('Netzwerkfehler')
    }
  }

  return {
    sessions,
    selectedSession,
    setSelectedSession,
    checklistItems,
    filteredItems,
    statistics,
    regulations,
    loading,
    loadingChecklist,
    error,
    filterStatus,
    setFilterStatus,
    filterRegulation,
    setFilterRegulation,
    searchQuery,
    setSearchQuery,
    showCreateModal,
    setShowCreateModal,
    showSignOffModal,
    setShowSignOffModal,
    selectedItem,
    setSelectedItem,
    lang,
    handleSignOff,
    handleCreateSession,
  }
}
@@ -10,212 +10,29 @@
* - Fortschritt und Statistiken zu verfolgen
*/
import { useState, useEffect, useCallback } from 'react'
import Link from 'next/link'
import AdminLayout from '@/components/admin/AdminLayout'
import { getTerm, Language } from '@/lib/compliance-i18n'
// Types based on backend schemas
/** Audit session record shown on the overview cards and the checklist header. */
interface AuditSession {
id: string
name: string
description: string | null
auditor_name: string
auditor_email: string | null
// Each value is a key of SESSION_STATUS.
status: 'draft' | 'in_progress' | 'completed' | 'archived'
regulation_ids: string[] | null
total_items: number
completed_items: number
created_at: string
started_at: string | null
completed_at: string | null
}
/** One requirement row of the checklist with its result and sign-off state. */
interface AuditChecklistItem {
requirement_id: string
regulation_code: string
article: string
title: string
// Each value is a key of RESULT_STATUS.
current_result: 'compliant' | 'compliant_notes' | 'non_compliant' | 'not_applicable' | 'pending'
notes: string | null
is_signed: boolean
signed_at: string | null
signed_by: string | null
evidence_count: number
controls_mapped: number
}
/** Aggregated result counters for one checklist. */
interface AuditStatistics {
total: number
compliant: number
compliant_with_notes: number
non_compliant: number
not_applicable: number
pending: number
completion_percentage: number
}
/** Regulation selectable when creating a session. */
interface Regulation {
id: string
code: string
name: string
requirement_count: number
}
// Result status configuration
// Keyed by AuditChecklistItem['current_result']: German/English label,
// Tailwind badge classes and the glyph shown in the status circle.
const RESULT_STATUS = {
pending: { label: 'Ausstehend', labelEn: 'Pending', color: 'bg-slate-100 text-slate-700', icon: '○' },
compliant: { label: 'Konform', labelEn: 'Compliant', color: 'bg-green-100 text-green-700', icon: '✓' },
compliant_notes: { label: 'Konform (Anm.)', labelEn: 'Compliant w/ Notes', color: 'bg-yellow-100 text-yellow-700', icon: '⚠' },
non_compliant: { label: 'Nicht konform', labelEn: 'Non-Compliant', color: 'bg-red-100 text-red-700', icon: '✗' },
not_applicable: { label: 'N/A', labelEn: 'N/A', color: 'bg-slate-50 text-slate-500', icon: '' },
}
// Keyed by AuditSession['status']: German label plus Tailwind badge classes.
const SESSION_STATUS = {
draft: { label: 'Entwurf', color: 'bg-slate-100 text-slate-700' },
in_progress: { label: 'In Bearbeitung', color: 'bg-blue-100 text-blue-700' },
completed: { label: 'Abgeschlossen', color: 'bg-green-100 text-green-700' },
archived: { label: 'Archiviert', color: 'bg-slate-50 text-slate-500' },
}
import { RESULT_STATUS, SESSION_STATUS } from './_components/types'
import { useAuditChecklist } from './_components/useAuditChecklist'
import SessionCard from './_components/SessionCard'
import AuditProgressBar from './_components/AuditProgressBar'
import ChecklistItemRow from './_components/ChecklistItemRow'
import CreateSessionModal from './_components/CreateSessionModal'
import SignOffModal from './_components/SignOffModal'
export default function AuditChecklistPage() {
const [sessions, setSessions] = useState<AuditSession[]>([])
const [selectedSession, setSelectedSession] = useState<AuditSession | null>(null)
const [checklistItems, setChecklistItems] = useState<AuditChecklistItem[]>([])
const [statistics, setStatistics] = useState<AuditStatistics | null>(null)
const [regulations, setRegulations] = useState<Regulation[]>([])
const [loading, setLoading] = useState(true)
const [loadingChecklist, setLoadingChecklist] = useState(false)
const [error, setError] = useState<string | null>(null)
// Filters
const [filterStatus, setFilterStatus] = useState<string>('all')
const [filterRegulation, setFilterRegulation] = useState<string>('all')
const [searchQuery, setSearchQuery] = useState('')
// Modal states
const [showCreateModal, setShowCreateModal] = useState(false)
const [showSignOffModal, setShowSignOffModal] = useState(false)
const [selectedItem, setSelectedItem] = useState<AuditChecklistItem | null>(null)
// Language
const [lang] = useState<Language>('de')
const BACKEND_URL = process.env.NEXT_PUBLIC_BACKEND_URL || 'http://localhost:8000'
// Load sessions
const loadSessions = useCallback(async () => {
try {
const res = await fetch(`${BACKEND_URL}/api/v1/compliance/audit/sessions`)
if (res.ok) {
const data = await res.json()
setSessions(data.sessions || [])
} else {
console.error('Failed to load sessions:', res.status)
}
} catch (err) {
console.error('Failed to load sessions:', err)
setError('Verbindung zum Backend fehlgeschlagen')
} finally {
setLoading(false)
}
}, [BACKEND_URL])
// Load regulations for filter
const loadRegulations = useCallback(async () => {
try {
const res = await fetch(`${BACKEND_URL}/api/v1/compliance/regulations`)
if (res.ok) {
const data = await res.json()
setRegulations(data.regulations || [])
}
} catch (err) {
console.error('Failed to load regulations:', err)
}
}, [BACKEND_URL])
// Load checklist for selected session
const loadChecklist = useCallback(async (sessionId: string) => {
setLoadingChecklist(true)
try {
const res = await fetch(`${BACKEND_URL}/api/v1/compliance/audit/checklist/${sessionId}`)
if (res.ok) {
const data = await res.json()
setChecklistItems(data.items || [])
setStatistics(data.statistics || null)
// Update session with latest data
if (data.session) {
setSelectedSession(data.session)
}
} else {
console.error('Failed to load checklist:', res.status)
}
} catch (err) {
console.error('Failed to load checklist:', err)
} finally {
setLoadingChecklist(false)
}
}, [BACKEND_URL])
useEffect(() => {
loadSessions()
loadRegulations()
}, [loadSessions, loadRegulations])
useEffect(() => {
if (selectedSession) {
loadChecklist(selectedSession.id)
}
}, [selectedSession, loadChecklist])
// Filter checklist items
const filteredItems = checklistItems.filter(item => {
if (filterStatus !== 'all' && item.current_result !== filterStatus) return false
if (filterRegulation !== 'all' && item.regulation_code !== filterRegulation) return false
if (searchQuery) {
const query = searchQuery.toLowerCase()
return (
item.title.toLowerCase().includes(query) ||
item.article.toLowerCase().includes(query) ||
item.regulation_code.toLowerCase().includes(query)
)
}
return true
})
// Handle sign-off
const handleSignOff = async (
item: AuditChecklistItem,
result: AuditChecklistItem['current_result'],
notes: string,
sign: boolean
) => {
if (!selectedSession) return
try {
const res = await fetch(
`${BACKEND_URL}/api/v1/compliance/audit/checklist/${selectedSession.id}/items/${item.requirement_id}/sign-off`,
{
method: 'PUT',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ result, notes, sign }),
}
)
if (res.ok) {
// Reload checklist to get updated data
await loadChecklist(selectedSession.id)
setShowSignOffModal(false)
setSelectedItem(null)
} else {
const err = await res.json()
alert(`Fehler: ${err.detail || 'Sign-off fehlgeschlagen'}`)
}
} catch (err) {
console.error('Sign-off failed:', err)
alert('Netzwerkfehler bei Sign-off')
}
}
const {
sessions, selectedSession, setSelectedSession,
filteredItems, statistics, regulations,
loading, loadingChecklist, error,
filterStatus, setFilterStatus,
filterRegulation, setFilterRegulation,
searchQuery, setSearchQuery,
showCreateModal, setShowCreateModal,
showSignOffModal, setShowSignOffModal,
selectedItem, setSelectedItem,
lang, handleSignOff, handleCreateSession,
} = useAuditChecklist()
// Session list view when no session is selected
if (!selectedSession && !loading) {
@@ -231,7 +48,6 @@ export default function AuditChecklistPage() {
</svg>
Zurueck zu Compliance
</Link>
<button
onClick={() => setShowCreateModal(true)}
className="px-4 py-2 bg-primary-600 text-white rounded-lg hover:bg-primary-700 flex items-center gap-2"
@@ -244,12 +60,9 @@ export default function AuditChecklistPage() {
</div>
{error && (
<div className="mb-6 p-4 bg-red-50 border border-red-200 rounded-lg text-red-700">
{error}
</div>
<div className="mb-6 p-4 bg-red-50 border border-red-200 rounded-lg text-red-700">{error}</div>
)}
{/* Session Cards */}
<div className="grid grid-cols-1 md:grid-cols-2 lg:grid-cols-3 gap-4">
{sessions.length === 0 ? (
<div className="col-span-full bg-white rounded-lg border border-slate-200 p-8 text-center">
@@ -267,46 +80,22 @@ export default function AuditChecklistPage() {
</div>
) : (
sessions.map(session => (
<SessionCard
key={session.id}
session={session}
onClick={() => setSelectedSession(session)}
/>
<SessionCard key={session.id} session={session} onClick={() => setSelectedSession(session)} />
))
)}
</div>
{/* Create Session Modal */}
{showCreateModal && (
<CreateSessionModal
regulations={regulations}
onClose={() => setShowCreateModal(false)}
onCreate={async (data) => {
try {
const res = await fetch(`${BACKEND_URL}/api/v1/compliance/audit/sessions`, {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify(data),
})
if (res.ok) {
await loadSessions()
setShowCreateModal(false)
} else {
const err = await res.json()
alert(`Fehler: ${err.detail || 'Session konnte nicht erstellt werden'}`)
}
} catch (err) {
console.error('Create session failed:', err)
alert('Netzwerkfehler')
}
}}
onCreate={handleCreateSession}
/>
)}
</AdminLayout>
)
}
// Loading state
if (loading) {
return (
<AdminLayout title="Audit Checkliste" description="Laden...">
@@ -320,7 +109,6 @@ export default function AuditChecklistPage() {
// Checklist view for selected session
return (
<AdminLayout title={selectedSession?.name || 'Audit Checkliste'} description="Pruefung durchfuehren">
{/* Header */}
<div className="mb-6 flex items-center justify-between">
<button
onClick={() => setSelectedSession(null)}
@@ -331,7 +119,6 @@ export default function AuditChecklistPage() {
</svg>
Zurueck zur Uebersicht
</button>
<div className="flex items-center gap-3">
<span className={`px-3 py-1 text-sm font-medium rounded-full ${SESSION_STATUS[selectedSession?.status || 'draft'].color}`}>
{SESSION_STATUS[selectedSession?.status || 'draft'].label}
@@ -348,10 +135,7 @@ export default function AuditChecklistPage() {
</div>
</div>
{/* Progress Bar */}
{statistics && (
<AuditProgressBar statistics={statistics} lang={lang} />
)}
{statistics && <AuditProgressBar statistics={statistics} lang={lang} />}
{/* Filters */}
<div className="bg-white rounded-lg border border-slate-200 p-4 mb-6">
@@ -424,7 +208,6 @@ export default function AuditChecklistPage() {
)}
</div>
{/* Sign-off Modal */}
{showSignOffModal && selectedItem && (
<SignOffModal
item={selectedItem}
@@ -439,429 +222,3 @@ export default function AuditChecklistPage() {
</AdminLayout>
)
}
// Session Card Component
/**
 * Clickable overview card for one audit session: name, status badge,
 * optional description, progress bar, auditor and creation date (de-DE).
 */
function SessionCard({ session, onClick }: { session: AuditSession; onClick: () => void }) {
// Sessions without items show 0 % instead of dividing by zero.
const completionPercent = session.total_items > 0
? Math.round((session.completed_items / session.total_items) * 100)
: 0
return (
<button
onClick={onClick}
className="bg-white rounded-lg border border-slate-200 p-4 text-left hover:shadow-md hover:border-primary-300 transition-all"
>
<div className="flex items-start justify-between mb-3">
<h3 className="font-semibold text-slate-900">{session.name}</h3>
<span className={`px-2 py-0.5 text-xs font-medium rounded ${SESSION_STATUS[session.status].color}`}>
{SESSION_STATUS[session.status].label}
</span>
</div>
{session.description && (
<p className="text-sm text-slate-500 mb-3 line-clamp-2">{session.description}</p>
)}
<div className="mb-3">
<div className="flex justify-between text-xs text-slate-500 mb-1">
<span>{session.completed_items} / {session.total_items} Punkte</span>
<span>{completionPercent}%</span>
</div>
<div className="h-2 bg-slate-100 rounded-full overflow-hidden">
<div
className="h-full bg-primary-500 transition-all"
style={{ width: `${completionPercent}%` }}
/>
</div>
</div>
<div className="flex items-center justify-between text-xs text-slate-500">
<span>Auditor: {session.auditor_name}</span>
<span>{new Date(session.created_at).toLocaleDateString('de-DE')}</span>
</div>
</button>
)
}
// Progress Bar Component
/**
 * Stacked progress bar plus legend for the audit result counters.
 * Segment widths are each count relative to `statistics.total`; `lang` is
 * accepted but not read (labels are fixed German strings).
 */
function AuditProgressBar({ statistics, lang }: { statistics: AuditStatistics; lang: Language }) {
// Array order defines both the stacking order of the bar and the legend order.
const segments = [
{ key: 'compliant', count: statistics.compliant, color: 'bg-green-500', label: 'Konform' },
{ key: 'compliant_with_notes', count: statistics.compliant_with_notes, color: 'bg-yellow-500', label: 'Mit Anm.' },
{ key: 'non_compliant', count: statistics.non_compliant, color: 'bg-red-500', label: 'Nicht konform' },
{ key: 'not_applicable', count: statistics.not_applicable, color: 'bg-slate-300', label: 'N/A' },
{ key: 'pending', count: statistics.pending, color: 'bg-slate-100', label: 'Offen' },
]
return (
<div className="bg-white rounded-lg border border-slate-200 p-4 mb-6">
<div className="flex items-center justify-between mb-3">
<h3 className="font-semibold text-slate-900">Fortschritt</h3>
<span className="text-2xl font-bold text-primary-600">
{Math.round(statistics.completion_percentage)}%
</span>
</div>
{/* Stacked Progress Bar */}
<div className="h-4 bg-slate-100 rounded-full overflow-hidden flex">
{segments.map(seg => (
seg.count > 0 && (
<div
key={seg.key}
className={`${seg.color} transition-all`}
style={{ width: `${(seg.count / statistics.total) * 100}%` }}
title={`${seg.label}: ${seg.count}`}
/>
)
))}
</div>
{/* Legend */}
<div className="flex flex-wrap gap-4 mt-3">
{segments.map(seg => (
<div key={seg.key} className="flex items-center gap-1.5 text-sm">
<span className={`w-3 h-3 rounded ${seg.color}`} />
<span className="text-slate-600">{seg.label}:</span>
<span className="font-medium text-slate-900">{seg.count}</span>
</div>
))}
</div>
</div>
)
}
/**
 * One row of the audit checklist.
 *
 * Shows the status icon, position, regulation code and article, an optional
 * "Signiert" badge, title, optional notes, control/evidence counts, the
 * localized result label and a button that triggers `onSignOff`.
 *
 * The result styling is looked up in RESULT_STATUS by `item.current_result`;
 * the label is German when `lang === 'de'` and English otherwise.
 */
function ChecklistItemRow({
  item,
  index,
  onSignOff,
  lang,
}: {
  item: AuditChecklistItem
  index: number
  onSignOff: () => void
  lang: Language
}) {
  const resultStyle = RESULT_STATUS[item.current_result]
  const resultText = lang === 'de' ? resultStyle.label : resultStyle.labelEn

  // Optional fragments; each stays falsy (and renders nothing) when its source field is falsy.
  const signedBadge = item.is_signed && (
    <span className="px-1.5 py-0.5 text-xs bg-green-100 text-green-700 rounded flex items-center gap-1">
      <svg className="w-3 h-3" fill="none" stroke="currentColor" viewBox="0 0 24 24">
        <path strokeLinecap="round" strokeLinejoin="round" strokeWidth={2} d="M9 12l2 2 4-4m5.618-4.016A11.955 11.955 0 0112 2.944a11.955 11.955 0 01-8.618 3.04A12.02 12.02 0 003 9c0 5.591 3.824 10.29 9 11.622 5.176-1.332 9-6.03 9-11.622 0-1.042-.133-2.052-.382-3.016z" />
      </svg>
      Signiert
    </span>
  )
  const notesLine = item.notes && (
    <p className="text-xs text-slate-500 italic mb-2">"{item.notes}"</p>
  )
  const signedOn = item.signed_at && (
    <span>Signiert: {new Date(item.signed_at).toLocaleDateString('de-DE')}</span>
  )

  return (
    <div className="p-4 hover:bg-slate-50 transition-colors">
      <div className="flex items-start gap-4">
        {/* Result icon */}
        <div className={`w-8 h-8 rounded-full flex items-center justify-center text-lg ${resultStyle.color}`}>
          {resultStyle.icon}
        </div>

        {/* Requirement details */}
        <div className="flex-1 min-w-0">
          <div className="flex items-center gap-2 mb-1">
            <span className="text-xs font-medium text-slate-500">{index}.</span>
            <span className="font-mono text-sm text-primary-600">{item.regulation_code}</span>
            <span className="font-mono text-sm text-slate-700">{item.article}</span>
            {signedBadge}
          </div>
          <h4 className="text-sm font-medium text-slate-900 mb-1">{item.title}</h4>
          {notesLine}
          <div className="flex items-center gap-4 text-xs text-slate-500">
            <span>{item.controls_mapped} Controls</span>
            <span>{item.evidence_count} Nachweise</span>
            {signedOn}
          </div>
        </div>

        {/* Result label and sign-off trigger */}
        <div className="flex items-center gap-2">
          <span className={`px-2 py-1 text-xs font-medium rounded ${resultStyle.color}`}>
            {resultText}
          </span>
          <button
            onClick={onSignOff}
            className="p-2 text-slate-400 hover:text-primary-600 hover:bg-primary-50 rounded-lg transition-colors"
            title="Sign-off"
          >
            <svg className="w-5 h-5" fill="none" stroke="currentColor" viewBox="0 0 24 24">
              <path strokeLinecap="round" strokeLinejoin="round" strokeWidth={2} d="M9 12l2 2 4-4m6 2a9 9 0 11-18 0 9 9 0 0118 0z" />
            </svg>
          </button>
        </div>
      </div>
    </div>
  )
}
/**
 * Modal dialog for creating a new audit session.
 *
 * Collects name, optional description, auditor name, optional auditor e-mail
 * and an optional selection of regulation codes, then passes them to
 * `onCreate`. Empty optional fields are sent as `undefined`; an empty
 * regulation selection is also sent as `undefined` (the form labels this as
 * "leer = alle").
 *
 * @param regulations Regulations offered as checkboxes (only `code` is read).
 * @param onClose     Called by the close icon and the "Abbrechen" button.
 * @param onCreate    Receives the form data. It is awaited, so it may return a
 *                    promise; the submit button stays disabled until it settles.
 */
function CreateSessionModal({
  regulations,
  onClose,
  onCreate,
}: {
  regulations: Regulation[]
  onClose: () => void
  onCreate: (data: { name: string; description?: string; auditor_name: string; auditor_email?: string; regulation_codes?: string[] }) => void | Promise<void>
}) {
  const [name, setName] = useState('')
  const [description, setDescription] = useState('')
  const [auditorName, setAuditorName] = useState('')
  const [auditorEmail, setAuditorEmail] = useState('')
  const [selectedRegs, setSelectedRegs] = useState<string[]>([])
  const [submitting, setSubmitting] = useState(false)

  const handleSubmit = async (e: React.FormEvent) => {
    e.preventDefault()
    if (!name || !auditorName) return
    setSubmitting(true)
    try {
      await onCreate({
        name,
        description: description || undefined,
        auditor_name: auditorName,
        auditor_email: auditorEmail || undefined,
        regulation_codes: selectedRegs.length > 0 ? selectedRegs : undefined,
      })
    } finally {
      // Always re-enable the form; without this a rejected onCreate would
      // leave the submit button disabled for good.
      setSubmitting(false)
    }
  }

  // Functional update so that quick successive toggles never work on a stale array.
  const toggleRegulation = (code: string, checked: boolean) => {
    setSelectedRegs(prev => (checked ? [...prev, code] : prev.filter(r => r !== code)))
  }

  return (
    <div className="fixed inset-0 bg-black/50 flex items-center justify-center z-50">
      <div className="bg-white rounded-xl shadow-xl max-w-lg w-full mx-4 overflow-hidden">
        <div className="px-6 py-4 border-b border-slate-200 flex items-center justify-between">
          <h2 className="text-lg font-semibold text-slate-900">Neue Audit-Session</h2>
          <button onClick={onClose} className="text-slate-400 hover:text-slate-600">
            <svg className="w-6 h-6" fill="none" stroke="currentColor" viewBox="0 0 24 24">
              <path strokeLinecap="round" strokeLinejoin="round" strokeWidth={2} d="M6 18L18 6M6 6l12 12" />
            </svg>
          </button>
        </div>
        <form onSubmit={handleSubmit} className="p-6 space-y-4">
          <div>
            <label className="block text-sm font-medium text-slate-700 mb-1">
              Name der Pruefung *
            </label>
            <input
              type="text"
              value={name}
              onChange={(e) => setName(e.target.value)}
              placeholder="z.B. Q1 2026 Compliance Audit"
              className="w-full px-3 py-2 border border-slate-300 rounded-lg focus:ring-2 focus:ring-primary-500"
              required
            />
          </div>
          <div>
            <label className="block text-sm font-medium text-slate-700 mb-1">
              Beschreibung
            </label>
            <textarea
              value={description}
              onChange={(e) => setDescription(e.target.value)}
              placeholder="Optionale Beschreibung..."
              className="w-full px-3 py-2 border border-slate-300 rounded-lg"
              rows={2}
            />
          </div>
          <div className="grid grid-cols-2 gap-4">
            <div>
              <label className="block text-sm font-medium text-slate-700 mb-1">
                Auditor Name *
              </label>
              <input
                type="text"
                value={auditorName}
                onChange={(e) => setAuditorName(e.target.value)}
                placeholder="Dr. Max Mustermann"
                className="w-full px-3 py-2 border border-slate-300 rounded-lg"
                required
              />
            </div>
            <div>
              <label className="block text-sm font-medium text-slate-700 mb-1">
                E-Mail
              </label>
              <input
                type="email"
                value={auditorEmail}
                onChange={(e) => setAuditorEmail(e.target.value)}
                placeholder="auditor@example.com"
                className="w-full px-3 py-2 border border-slate-300 rounded-lg"
              />
            </div>
          </div>
          <div>
            <label className="block text-sm font-medium text-slate-700 mb-1">
              Verordnungen (optional - leer = alle)
            </label>
            <div className="flex flex-wrap gap-2 p-2 border border-slate-300 rounded-lg max-h-32 overflow-y-auto">
              {regulations.map(reg => (
                <label key={reg.code} className="flex items-center gap-1.5">
                  <input
                    type="checkbox"
                    checked={selectedRegs.includes(reg.code)}
                    onChange={(e) => toggleRegulation(reg.code, e.target.checked)}
                    className="rounded"
                  />
                  <span className="text-sm text-slate-700">{reg.code}</span>
                </label>
              ))}
            </div>
          </div>
          <div className="flex justify-end gap-3 pt-4">
            <button
              type="button"
              onClick={onClose}
              className="px-4 py-2 text-slate-600 hover:text-slate-800"
            >
              Abbrechen
            </button>
            <button
              type="submit"
              disabled={!name || !auditorName || submitting}
              className="px-4 py-2 bg-primary-600 text-white rounded-lg hover:bg-primary-700 disabled:opacity-50"
            >
              {submitting ? 'Erstelle...' : 'Session erstellen'}
            </button>
          </div>
        </form>
      </div>
    </div>
  )
}
/**
 * Modal dialog in which the auditor records the result for one checklist item.
 *
 * The result and notes start from the item's current values; the "digital
 * signieren" checkbox starts unchecked. On submit the three values are passed
 * to `onSubmit`.
 *
 * @param item        Checklist item being assessed.
 * @param auditorName Shown read-only in the footer of the form.
 * @param onClose     Called by the "Abbrechen" button.
 * @param onSubmit    Receives (result, notes, sign). It is awaited, so it may
 *                    return a promise; the submit button stays disabled until
 *                    it settles.
 */
function SignOffModal({
  item,
  auditorName,
  onClose,
  onSubmit,
}: {
  item: AuditChecklistItem
  auditorName: string
  onClose: () => void
  onSubmit: (result: AuditChecklistItem['current_result'], notes: string, sign: boolean) => void | Promise<void>
}) {
  const [result, setResult] = useState<AuditChecklistItem['current_result']>(item.current_result)
  const [notes, setNotes] = useState(item.notes || '')
  const [sign, setSign] = useState(false)
  const [submitting, setSubmitting] = useState(false)

  const handleSubmit = async (e: React.FormEvent) => {
    e.preventDefault()
    setSubmitting(true)
    try {
      await onSubmit(result, notes, sign)
    } finally {
      // Always re-enable the form; without this a rejected onSubmit would
      // leave the button stuck on "Speichere...".
      setSubmitting(false)
    }
  }

  return (
    <div className="fixed inset-0 bg-black/50 flex items-center justify-center z-50">
      <div className="bg-white rounded-xl shadow-xl max-w-lg w-full mx-4 overflow-hidden">
        <div className="px-6 py-4 border-b border-slate-200">
          <h2 className="text-lg font-semibold text-slate-900">Auditor Sign-off</h2>
          <p className="text-sm text-slate-500 mt-1">
            {item.regulation_code} {item.article} - {item.title}
          </p>
        </div>
        <form onSubmit={handleSubmit} className="p-6 space-y-4">
          <div>
            <label className="block text-sm font-medium text-slate-700 mb-2">
              Pruefungsergebnis
            </label>
            <div className="space-y-2">
              {/* One radio option per entry in RESULT_STATUS */}
              {Object.entries(RESULT_STATUS).map(([key, { label, color }]) => (
                <label key={key} className="flex items-center gap-3 p-2 rounded-lg hover:bg-slate-50 cursor-pointer">
                  <input
                    type="radio"
                    name="result"
                    value={key}
                    checked={result === key}
                    onChange={() => setResult(key as typeof result)}
                    className="w-4 h-4"
                  />
                  <span className={`px-2 py-0.5 text-sm font-medium rounded ${color}`}>
                    {label}
                  </span>
                </label>
              ))}
            </div>
          </div>
          <div>
            <label className="block text-sm font-medium text-slate-700 mb-1">
              Anmerkungen
            </label>
            <textarea
              value={notes}
              onChange={(e) => setNotes(e.target.value)}
              placeholder="Optionale Anmerkungen zur Pruefung..."
              className="w-full px-3 py-2 border border-slate-300 rounded-lg"
              rows={3}
            />
          </div>
          <div className="p-3 bg-amber-50 border border-amber-200 rounded-lg">
            <label className="flex items-start gap-3 cursor-pointer">
              <input
                type="checkbox"
                checked={sign}
                onChange={(e) => setSign(e.target.checked)}
                className="mt-1 rounded"
              />
              <div>
                <span className="font-medium text-amber-800">Digital signieren</span>
                <p className="text-sm text-amber-700 mt-0.5">
                  Erstellt eine SHA-256 Signatur des Ergebnisses. Diese Aktion kann nicht rueckgaengig gemacht werden.
                </p>
              </div>
            </label>
          </div>
          <div className="text-sm text-slate-500 pt-2 border-t border-slate-100">
            <p>Auditor: <span className="font-medium text-slate-700">{auditorName}</span></p>
            <p>Datum: <span className="font-medium text-slate-700">{new Date().toLocaleString('de-DE')}</span></p>
          </div>
          <div className="flex justify-end gap-3 pt-4">
            <button
              type="button"
              onClick={onClose}
              className="px-4 py-2 text-slate-600 hover:text-slate-800"
            >
              Abbrechen
            </button>
            <button
              type="submit"
              disabled={submitting}
              className={`px-4 py-2 rounded-lg disabled:opacity-50 ${
                sign
                  ? 'bg-amber-600 text-white hover:bg-amber-700'
                  : 'bg-primary-600 text-white hover:bg-primary-700'
              }`}
            >
              {submitting ? 'Speichere...' : sign ? 'Signieren & Speichern' : 'Speichern'}
            </button>
          </div>
        </form>
      </div>
    </div>
  )
}
@@ -0,0 +1,91 @@
'use client'
import {
ServiceModule, RiskAssessment,
SERVICE_TYPE_CONFIG, CRITICALITY_CONFIG, RELEVANCE_CONFIG,
} from './types'
/** Props of the ModuleDetailPanel side panel. */
interface ModuleDetailPanelProps {
// Module whose details are rendered.
module: ServiceModule;
// While true, the body shows "Lade Details..." instead of the module details.
loadingDetail: boolean;
// While true, the risk button is disabled and the risk panel shows a progress text.
loadingRisk: boolean;
// Whether the AI risk assessment section is rendered below the button.
showRiskPanel: boolean;
// Assessment to display; null renders the "click to start" hint.
riskAssessment: RiskAssessment | null;
// Invoked by the close button in the panel header.
onClose: () => void;
// Invoked with module.id when the risk assessment button is clicked.
onAssessRisk: (moduleId: string) => void;
// Invoked by the close button of the risk assessment section.
onCloseRisk: () => void;
}
/** Badge classes for the overall risk; any value other than critical/high/medium is shown green. */
function overallRiskBadgeClasses(risk: string): string {
  switch (risk) {
    case 'critical':
      return 'bg-red-100 text-red-700'
    case 'high':
      return 'bg-orange-100 text-orange-700'
    case 'medium':
      return 'bg-yellow-100 text-yellow-700'
    default:
      return 'bg-green-100 text-green-700'
  }
}

/** Badge classes for a single risk factor: red for critical/high, yellow for everything else. */
function riskFactorBadgeClasses(severity: string): string {
  if (severity === 'critical' || severity === 'high') {
    return 'bg-red-100 text-red-600'
  }
  return 'bg-yellow-100 text-yellow-600'
}

/**
 * Sticky side panel with the details of one service module.
 *
 * The header (type icon, close button, display name, technical name) is always
 * rendered. The body shows "Lade Details..." while `loadingDetail` is true and
 * otherwise the description, port, criticality, tech stack, data categories,
 * data-processing flags, regulations, owner, repository path and the button
 * that calls `onAssessRisk(module.id)`. When `showRiskPanel` is true, the AI
 * risk assessment section is appended.
 */
export default function ModuleDetailPanel({
  module, loadingDetail, loadingRisk, showRiskPanel, riskAssessment,
  onClose, onAssessRisk, onCloseRisk,
}: ModuleDetailPanelProps) {
  const typeStyle = SERVICE_TYPE_CONFIG[module.service_type]

  // Renders a finished assessment: overall badge, factors, gaps, recommendations.
  const renderAssessment = (assessment: RiskAssessment) => (
    <div className="space-y-3">
      <div className="flex items-center gap-2">
        <span className="text-sm text-gray-600">Gesamtrisiko:</span>
        <span className={`px-2 py-1 rounded text-sm font-medium ${overallRiskBadgeClasses(assessment.overall_risk)}`}>{assessment.overall_risk.toUpperCase()}</span>
        <span className="text-xs text-gray-400">({Math.round(assessment.confidence_score * 100)}% Konfidenz)</span>
      </div>
      {assessment.risk_factors.length > 0 && (
        <div>
          <div className="text-xs text-gray-500 uppercase mb-1">Risikofaktoren</div>
          <div className="space-y-1">
            {assessment.risk_factors.map((factor, i) => (
              <div key={i} className="flex items-center justify-between text-sm bg-white/50 rounded px-2 py-1">
                <span className="text-gray-700">{factor.factor}</span>
                <span className={`text-xs px-1.5 py-0.5 rounded ${riskFactorBadgeClasses(factor.severity)}`}>{factor.severity}</span>
              </div>
            ))}
          </div>
        </div>
      )}
      {assessment.compliance_gaps.length > 0 && (
        <div>
          <div className="text-xs text-gray-500 uppercase mb-1">Compliance-Luecken</div>
          <ul className="text-sm text-gray-700 space-y-1">
            {assessment.compliance_gaps.map((gap, i) => (
              <li key={i} className="flex items-start gap-1">
                <span className="text-red-500"></span>
                <span>{gap}</span>
              </li>
            ))}
          </ul>
        </div>
      )}
      {assessment.recommendations.length > 0 && (
        <div>
          <div className="text-xs text-gray-500 uppercase mb-1">Empfehlungen</div>
          <ul className="text-sm text-gray-700 space-y-1">
            {assessment.recommendations.map((rec, i) => (
              <li key={i} className="flex items-start gap-1">
                <span className="text-green-500"></span>
                <span>{rec}</span>
              </li>
            ))}
          </ul>
        </div>
      )}
    </div>
  )

  // Content of the risk section: progress text, finished assessment, or the start hint.
  const renderRiskContent = () => {
    if (loadingRisk) {
      return (
        <div className="text-center py-4 text-purple-600">
          <div className="animate-pulse">Analysiere Compliance-Risiken...</div>
        </div>
      )
    }
    if (riskAssessment) {
      return renderAssessment(riskAssessment)
    }
    return (
      <div className="text-center py-4 text-gray-500 text-sm">Klicken Sie auf &quot;AI Risikobewertung&quot; um eine Analyse zu starten.</div>
    )
  }

  const renderRiskPanel = () => (
    <div className="mt-4 p-4 bg-gradient-to-br from-purple-50 to-pink-50 rounded-lg border border-purple-200">
      <div className="flex justify-between items-center mb-3">
        <h4 className="font-semibold text-purple-900 flex items-center gap-2"><span>🤖</span> AI Risikobewertung</h4>
        <button onClick={onCloseRisk} className="text-purple-400 hover:text-purple-600"></button>
      </div>
      {renderRiskContent()}
    </div>
  )

  // Only called when loadingDetail is false, so module fields are read in the same situations as before.
  const renderDetails = () => {
    const criticalityStyle = CRITICALITY_CONFIG[module.criticality]
    return (
      <div className="p-4 space-y-4">
        {module.description && (
          <div>
            <div className="text-xs text-gray-500 uppercase mb-1">Beschreibung</div>
            <div className="text-sm text-gray-700">{module.description}</div>
          </div>
        )}
        <div className="grid grid-cols-2 gap-2 text-sm">
          {module.port && (
            <div>
              <span className="text-gray-500">Port:</span>
              <span className="ml-1 font-mono">{module.port}</span>
            </div>
          )}
          <div>
            <span className="text-gray-500">Criticality:</span>
            <span className={`ml-1 px-1.5 py-0.5 rounded text-xs ${criticalityStyle?.bgColor || ''} ${criticalityStyle?.color || ''}`}>{module.criticality}</span>
          </div>
        </div>
        <div>
          <div className="text-xs text-gray-500 uppercase mb-1">Tech Stack</div>
          <div className="flex flex-wrap gap-1">
            {module.technology_stack.map((tech, i) => (
              <span key={i} className="px-2 py-0.5 bg-gray-100 text-gray-700 text-xs rounded">{tech}</span>
            ))}
          </div>
        </div>
        {module.data_categories.length > 0 && (
          <div>
            <div className="text-xs text-gray-500 uppercase mb-1">Daten-Kategorien</div>
            <div className="flex flex-wrap gap-1">
              {module.data_categories.map((cat, i) => (
                <span key={i} className="px-2 py-0.5 bg-blue-50 text-blue-700 text-xs rounded">{cat}</span>
              ))}
            </div>
          </div>
        )}
        <div className="flex flex-wrap gap-2">
          {module.processes_pii && (
            <span className="px-2 py-1 bg-purple-100 text-purple-700 text-xs rounded">Verarbeitet PII</span>
          )}
          {module.ai_components && (
            <span className="px-2 py-1 bg-pink-100 text-pink-700 text-xs rounded">AI-Komponenten</span>
          )}
          {module.processes_health_data && (
            <span className="px-2 py-1 bg-red-100 text-red-700 text-xs rounded">Gesundheitsdaten</span>
          )}
        </div>
        {module.regulations && module.regulations.length > 0 && (
          <div>
            <div className="text-xs text-gray-500 uppercase mb-2">Applicable Regulations ({module.regulations.length})</div>
            <div className="space-y-2">
              {module.regulations.map((reg, i) => {
                const relevanceStyle = RELEVANCE_CONFIG[reg.relevance_level]
                return (
                  <div key={i} className="p-2 bg-gray-50 rounded text-sm">
                    <div className="flex justify-between items-start">
                      <span className="font-medium">{reg.code}</span>
                      <span className={`px-1.5 py-0.5 rounded text-xs ${relevanceStyle?.bgColor || 'bg-gray-100'} ${relevanceStyle?.color || 'text-gray-700'}`}>{reg.relevance_level}</span>
                    </div>
                    <div className="text-gray-500 text-xs">{reg.name}</div>
                    {reg.notes && (
                      <div className="text-gray-600 text-xs mt-1 italic">{reg.notes}</div>
                    )}
                  </div>
                )
              })}
            </div>
          </div>
        )}
        {module.owner_team && (
          <div>
            <div className="text-xs text-gray-500 uppercase mb-1">Owner</div>
            <div className="text-sm text-gray-700">{module.owner_team}</div>
          </div>
        )}
        {module.repository_path && (
          <div>
            <div className="text-xs text-gray-500 uppercase mb-1">Repository</div>
            <code className="text-xs bg-gray-100 px-2 py-1 rounded block">{module.repository_path}</code>
          </div>
        )}
        <div className="pt-2 border-t">
          <button onClick={() => onAssessRisk(module.id)} disabled={loadingRisk} className="w-full px-4 py-2 bg-gradient-to-r from-purple-600 to-pink-600 text-white rounded-lg hover:from-purple-700 hover:to-pink-700 transition disabled:opacity-50 flex items-center justify-center gap-2">
            {loadingRisk ? (<><span className="animate-spin"></span>AI analysiert...</>) : (<><span>🤖</span>AI Risikobewertung</>)}
          </button>
        </div>
        {showRiskPanel && renderRiskPanel()}
      </div>
    )
  }

  return (
    <div className="w-96 bg-white rounded-lg shadow border sticky top-6 h-fit">
      <div className={`px-4 py-3 border-b ${typeStyle?.bgColor || 'bg-gray-100'}`}>
        <div className="flex items-center justify-between">
          <span className="text-lg">{typeStyle?.icon || '📁'}</span>
          <button onClick={onClose} className="text-gray-400 hover:text-gray-600"></button>
        </div>
        <h3 className="font-bold text-lg mt-2">{module.display_name}</h3>
        <div className="text-sm text-gray-600">{module.name}</div>
      </div>
      {loadingDetail ? (
        <div className="p-4 text-center text-gray-500">Lade Details...</div>
      ) : (
        renderDetails()
      )}
    </div>
  )
}
@@ -0,0 +1,71 @@
/**
 * One service module as handled by the module registry UI.
 * NOTE(review): presumably mirrors the `/api/v1/compliance/modules` response items — confirm against the backend schema.
 */
export interface ServiceModule {
// Passed to the detail and risk-assessment requests as the module identifier.
id: string;
// Technical name; shown beneath display_name in the detail panel.
name: string;
// Human-readable title shown as the panel heading.
display_name: string;
description: string | null;
// Used as lookup key into SERVICE_TYPE_CONFIG (icon / colours).
service_type: string;
port: number | null;
technology_stack: string[];
repository_path: string | null;
docker_image: string | null;
data_categories: string[];
processes_pii: boolean;
processes_health_data: boolean;
ai_components: boolean;
// Used as lookup key into CRITICALITY_CONFIG.
criticality: string;
owner_team: string | null;
is_active: boolean;
compliance_score: number | null;
regulation_count: number;
risk_count: number;
created_at: string;
// Optional regulation mappings; presumably only delivered by the detail endpoint — TODO confirm.
regulations?: Array<{
code: string;
name: string;
// Used as lookup key into RELEVANCE_CONFIG.
relevance_level: string;
notes: string | null;
}>;
}
/**
 * Aggregate figures about all modules, as stored by the `/modules/overview` fetch.
 * The Record fields map a string key to a number; presumably counts per category — confirm against the backend.
 */
export interface ModulesOverview {
total_modules: number;
modules_by_type: Record<string, number>;
// The page reads the `critical` entry of this map for its "Critical" tile.
modules_by_criticality: Record<string, number>;
modules_processing_pii: number;
modules_with_ai: number;
average_compliance_score: number | null;
regulations_coverage: Record<string, number>;
}
/** Result of the AI risk assessment request for a single module. */
export interface RiskAssessment {
// The detail panel styles 'critical', 'high' and 'medium' explicitly; every other value is rendered green.
overall_risk: string;
risk_factors: Array<{ factor: string; severity: string; likelihood: string }>;
recommendations: string[];
compliance_gaps: string[];
// Rendered as a percentage via `* 100`, so presumably a 0–1 fraction — TODO confirm.
confidence_score: number;
}
/**
 * Icon and Tailwind text/background colour classes per service type.
 * Lookups use `ServiceModule.service_type`; types not listed here yield `undefined`,
 * so consumers need their own fallback.
 */
export const SERVICE_TYPE_CONFIG: Record<string, { icon: string; color: string; bgColor: string }> = {
backend: { icon: '⚙️', color: 'text-blue-700', bgColor: 'bg-blue-100' },
database: { icon: '🗄️', color: 'text-purple-700', bgColor: 'bg-purple-100' },
ai: { icon: '🤖', color: 'text-pink-700', bgColor: 'bg-pink-100' },
communication: { icon: '💬', color: 'text-green-700', bgColor: 'bg-green-100' },
storage: { icon: '📦', color: 'text-orange-700', bgColor: 'bg-orange-100' },
infrastructure: { icon: '🌐', color: 'text-gray-700', bgColor: 'bg-gray-100' },
monitoring: { icon: '📊', color: 'text-cyan-700', bgColor: 'bg-cyan-100' },
security: { icon: '🔒', color: 'text-red-700', bgColor: 'bg-red-100' },
};
/**
 * Tailwind text/background colour classes per module criticality level.
 * Levels not listed here yield `undefined` on lookup.
 */
export const CRITICALITY_CONFIG: Record<string, { color: string; bgColor: string }> = {
critical: { color: 'text-red-700', bgColor: 'bg-red-100' },
high: { color: 'text-orange-700', bgColor: 'bg-orange-100' },
medium: { color: 'text-yellow-700', bgColor: 'bg-yellow-100' },
low: { color: 'text-green-700', bgColor: 'bg-green-100' },
};
/**
 * Tailwind text/background colour classes per regulation relevance level.
 * Levels not listed here yield `undefined` on lookup.
 */
export const RELEVANCE_CONFIG: Record<string, { color: string; bgColor: string }> = {
critical: { color: 'text-red-700', bgColor: 'bg-red-100' },
high: { color: 'text-orange-700', bgColor: 'bg-orange-100' },
medium: { color: 'text-yellow-700', bgColor: 'bg-yellow-100' },
low: { color: 'text-green-700', bgColor: 'bg-green-100' },
};
@@ -0,0 +1,110 @@
'use client'
import { useState, useEffect } from 'react'
import { ServiceModule, ModulesOverview, RiskAssessment } from './types'
// Backend origin; falls back to the local default when NEXT_PUBLIC_BACKEND_URL is unset or empty.
const BACKEND_URL = process.env.NEXT_PUBLIC_BACKEND_URL || 'http://localhost:8000'
// Common prefix of every compliance endpoint called by this hook.
const API_BASE = `${BACKEND_URL}/api/v1/compliance`
/**
 * State and data access for the Service Module Registry page.
 *
 * Holds the module list, the overview figures, the filter/search state, the
 * currently selected module and the AI risk assessment state, and exposes the
 * functions that talk to the compliance API.
 *
 * The module list and the overview are fetched once on mount. Changing a
 * filter does not refetch by itself; callers trigger `fetchModules`, which
 * sends the type/criticality/PII/AI filters as query parameters. `searchTerm`
 * is applied client-side to name, display name, description and tech stack.
 *
 * @returns The state values, their setters where the page needs them, the
 *          derived `filteredModules` / `modulesByType`, and the fetch actions.
 */
export function useModulesPage() {
  const [modules, setModules] = useState<ServiceModule[]>([])
  const [overview, setOverview] = useState<ModulesOverview | null>(null)
  const [loading, setLoading] = useState(true)
  const [error, setError] = useState<string | null>(null)
  const [typeFilter, setTypeFilter] = useState<string>('all')
  const [criticalityFilter, setCriticalityFilter] = useState<string>('all')
  const [piiFilter, setPiiFilter] = useState<boolean | null>(null)
  const [aiFilter, setAiFilter] = useState<boolean | null>(null)
  const [searchTerm, setSearchTerm] = useState('')
  const [selectedModule, setSelectedModule] = useState<ServiceModule | null>(null)
  const [loadingDetail, setLoadingDetail] = useState(false)
  const [riskAssessment, setRiskAssessment] = useState<RiskAssessment | null>(null)
  const [loadingRisk, setLoadingRisk] = useState(false)
  const [showRiskPanel, setShowRiskPanel] = useState(false)

  // Initial load only; see the hook description for filter handling.
  useEffect(() => { fetchModules(); fetchOverview() }, [])

  /** Loads the module list using the current server-side filters. */
  const fetchModules = async () => {
    try {
      setLoading(true)
      // Drop the error of an earlier failed attempt so a successful reload
      // does not keep showing it.
      setError(null)
      const params = new URLSearchParams()
      if (typeFilter !== 'all') params.append('service_type', typeFilter)
      if (criticalityFilter !== 'all') params.append('criticality', criticalityFilter)
      if (piiFilter !== null) params.append('processes_pii', String(piiFilter))
      if (aiFilter !== null) params.append('ai_components', String(aiFilter))
      const query = params.toString()
      const url = `${API_BASE}/modules${query ? '?' + query : ''}`
      const res = await fetch(url)
      if (!res.ok) throw new Error('Failed to fetch modules')
      const data = await res.json()
      setModules(data.modules || [])
    } catch (err) {
      setError(err instanceof Error ? err.message : 'Unknown error')
    } finally {
      setLoading(false)
    }
  }

  /** Loads the overview figures; failures are only logged. */
  const fetchOverview = async () => {
    try {
      const res = await fetch(`${API_BASE}/modules/overview`)
      if (!res.ok) throw new Error('Failed to fetch overview')
      const data = await res.json()
      setOverview(data)
    } catch (err) {
      console.error('Failed to fetch overview:', err)
    }
  }

  /** Loads one module and makes it the selected module; failures are only logged. */
  const fetchModuleDetail = async (moduleId: string) => {
    try {
      setLoadingDetail(true)
      const res = await fetch(`${API_BASE}/modules/${moduleId}`)
      if (!res.ok) throw new Error('Failed to fetch module details')
      const data = await res.json()
      setSelectedModule(data)
    } catch (err) {
      console.error('Failed to fetch module details:', err)
    } finally {
      setLoadingDetail(false)
    }
  }

  /** Asks the backend to seed the registry, reports the outcome via alert and reloads list and overview. */
  const seedModules = async (force: boolean = false) => {
    try {
      const res = await fetch(`${API_BASE}/modules/seed`, {
        method: 'POST',
        headers: { 'Content-Type': 'application/json' },
        body: JSON.stringify({ force }),
      })
      if (!res.ok) throw new Error('Failed to seed modules')
      const data = await res.json()
      alert(`Seeded ${data.modules_created} modules with ${data.mappings_created} regulation mappings`)
      fetchModules(); fetchOverview()
    } catch (err) {
      alert('Failed to seed modules: ' + (err instanceof Error ? err.message : 'Unknown error'))
    }
  }

  /** Opens the risk panel, clears the previous assessment and requests a new one for the module. */
  const assessModuleRisk = async (moduleId: string) => {
    setLoadingRisk(true); setShowRiskPanel(true); setRiskAssessment(null)
    try {
      const res = await fetch(`${API_BASE}/ai/assess-risk`, {
        method: 'POST',
        headers: { 'Content-Type': 'application/json' },
        body: JSON.stringify({ module_id: moduleId }),
      })
      if (res.ok) {
        const data = await res.json()
        setRiskAssessment(data)
      } else {
        alert('AI-Risikobewertung fehlgeschlagen')
      }
    } catch (err) {
      // Keep the cause visible for debugging, as the other fetchers do.
      console.error('AI risk assessment request failed:', err)
      alert('Netzwerkfehler bei AI-Risikobewertung')
    } finally {
      setLoadingRisk(false)
    }
  }

  // Client-side text search; an empty search term keeps every module.
  const filteredModules = modules.filter(m => {
    if (!searchTerm) return true
    const term = searchTerm.toLowerCase()
    return m.name.toLowerCase().includes(term) || m.display_name.toLowerCase().includes(term) || (m.description && m.description.toLowerCase().includes(term)) || m.technology_stack.some(t => t.toLowerCase().includes(term))
  })

  // Group the visible modules by service type; modules without a type go under 'unknown'.
  const modulesByType = filteredModules.reduce((acc, m) => {
    const type = m.service_type || 'unknown'
    if (!acc[type]) acc[type] = []
    acc[type].push(m)
    return acc
  }, {} as Record<string, ServiceModule[]>)

  return {
    modules, overview, loading, error,
    typeFilter, setTypeFilter, criticalityFilter, setCriticalityFilter,
    piiFilter, setPiiFilter, aiFilter, setAiFilter, searchTerm, setSearchTerm,
    selectedModule, setSelectedModule, loadingDetail,
    riskAssessment, loadingRisk, showRiskPanel, setShowRiskPanel,
    filteredModules, modulesByType,
    fetchModules, fetchModuleDetail, seedModules, assessModuleRisk,
  }
}
+54 -677
View File
@@ -1,214 +1,11 @@
'use client';
import { useState, useEffect } from 'react';
// Types
interface ServiceModule {
id: string;
name: string;
display_name: string;
description: string | null;
service_type: string;
port: number | null;
technology_stack: string[];
repository_path: string | null;
docker_image: string | null;
data_categories: string[];
processes_pii: boolean;
processes_health_data: boolean;
ai_components: boolean;
criticality: string;
owner_team: string | null;
is_active: boolean;
compliance_score: number | null;
regulation_count: number;
risk_count: number;
created_at: string;
regulations?: Array<{
code: string;
name: string;
relevance_level: string;
notes: string | null;
}>;
}
interface ModulesOverview {
total_modules: number;
modules_by_type: Record<string, number>;
modules_by_criticality: Record<string, number>;
modules_processing_pii: number;
modules_with_ai: number;
average_compliance_score: number | null;
regulations_coverage: Record<string, number>;
}
// Service Type Icons and Colors
const SERVICE_TYPE_CONFIG: Record<string, { icon: string; color: string; bgColor: string }> = {
backend: { icon: '⚙️', color: 'text-blue-700', bgColor: 'bg-blue-100' },
database: { icon: '🗄️', color: 'text-purple-700', bgColor: 'bg-purple-100' },
ai: { icon: '🤖', color: 'text-pink-700', bgColor: 'bg-pink-100' },
communication: { icon: '💬', color: 'text-green-700', bgColor: 'bg-green-100' },
storage: { icon: '📦', color: 'text-orange-700', bgColor: 'bg-orange-100' },
infrastructure: { icon: '🌐', color: 'text-gray-700', bgColor: 'bg-gray-100' },
monitoring: { icon: '📊', color: 'text-cyan-700', bgColor: 'bg-cyan-100' },
security: { icon: '🔒', color: 'text-red-700', bgColor: 'bg-red-100' },
};
const CRITICALITY_CONFIG: Record<string, { color: string; bgColor: string }> = {
critical: { color: 'text-red-700', bgColor: 'bg-red-100' },
high: { color: 'text-orange-700', bgColor: 'bg-orange-100' },
medium: { color: 'text-yellow-700', bgColor: 'bg-yellow-100' },
low: { color: 'text-green-700', bgColor: 'bg-green-100' },
};
const RELEVANCE_CONFIG: Record<string, { color: string; bgColor: string }> = {
critical: { color: 'text-red-700', bgColor: 'bg-red-100' },
high: { color: 'text-orange-700', bgColor: 'bg-orange-100' },
medium: { color: 'text-yellow-700', bgColor: 'bg-yellow-100' },
low: { color: 'text-green-700', bgColor: 'bg-green-100' },
};
import { SERVICE_TYPE_CONFIG, CRITICALITY_CONFIG } from './_components/types';
import { useModulesPage } from './_components/useModulesPage';
import ModuleDetailPanel from './_components/ModuleDetailPanel';
export default function ModulesPage() {
const [modules, setModules] = useState<ServiceModule[]>([]);
const [overview, setOverview] = useState<ModulesOverview | null>(null);
const [loading, setLoading] = useState(true);
const [error, setError] = useState<string | null>(null);
// Filters
const [typeFilter, setTypeFilter] = useState<string>('all');
const [criticalityFilter, setCriticalityFilter] = useState<string>('all');
const [piiFilter, setPiiFilter] = useState<boolean | null>(null);
const [aiFilter, setAiFilter] = useState<boolean | null>(null);
const [searchTerm, setSearchTerm] = useState('');
// Selected module for detail view
const [selectedModule, setSelectedModule] = useState<ServiceModule | null>(null);
const [loadingDetail, setLoadingDetail] = useState(false);
// AI Risk Assessment
const [riskAssessment, setRiskAssessment] = useState<{
overall_risk: string;
risk_factors: Array<{ factor: string; severity: string; likelihood: string }>;
recommendations: string[];
compliance_gaps: string[];
confidence_score: number;
} | null>(null);
const [loadingRisk, setLoadingRisk] = useState(false);
const [showRiskPanel, setShowRiskPanel] = useState(false);
const BACKEND_URL = process.env.NEXT_PUBLIC_BACKEND_URL || 'http://localhost:8000';
const API_BASE = `${BACKEND_URL}/api/v1/compliance`;
useEffect(() => {
fetchModules();
fetchOverview();
}, []);
const fetchModules = async () => {
try {
setLoading(true);
const params = new URLSearchParams();
if (typeFilter !== 'all') params.append('service_type', typeFilter);
if (criticalityFilter !== 'all') params.append('criticality', criticalityFilter);
if (piiFilter !== null) params.append('processes_pii', String(piiFilter));
if (aiFilter !== null) params.append('ai_components', String(aiFilter));
const url = `${API_BASE}/modules${params.toString() ? '?' + params.toString() : ''}`;
const res = await fetch(url);
if (!res.ok) throw new Error('Failed to fetch modules');
const data = await res.json();
setModules(data.modules || []);
} catch (err) {
setError(err instanceof Error ? err.message : 'Unknown error');
} finally {
setLoading(false);
}
};
const fetchOverview = async () => {
try {
const res = await fetch(`${API_BASE}/modules/overview`);
if (!res.ok) throw new Error('Failed to fetch overview');
const data = await res.json();
setOverview(data);
} catch (err) {
console.error('Failed to fetch overview:', err);
}
};
const fetchModuleDetail = async (moduleId: string) => {
try {
setLoadingDetail(true);
const res = await fetch(`${API_BASE}/modules/${moduleId}`);
if (!res.ok) throw new Error('Failed to fetch module details');
const data = await res.json();
setSelectedModule(data);
} catch (err) {
console.error('Failed to fetch module details:', err);
} finally {
setLoadingDetail(false);
}
};
const seedModules = async (force: boolean = false) => {
try {
const res = await fetch(`${API_BASE}/modules/seed`, {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ force }),
});
if (!res.ok) throw new Error('Failed to seed modules');
const data = await res.json();
alert(`Seeded ${data.modules_created} modules with ${data.mappings_created} regulation mappings`);
fetchModules();
fetchOverview();
} catch (err) {
alert('Failed to seed modules: ' + (err instanceof Error ? err.message : 'Unknown error'));
}
};
const assessModuleRisk = async (moduleId: string) => {
setLoadingRisk(true);
setShowRiskPanel(true);
setRiskAssessment(null);
try {
const res = await fetch(`${API_BASE}/ai/assess-risk`, {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ module_id: moduleId }),
});
if (res.ok) {
const data = await res.json();
setRiskAssessment(data);
} else {
alert('AI-Risikobewertung fehlgeschlagen');
}
} catch (err) {
alert('Netzwerkfehler bei AI-Risikobewertung');
} finally {
setLoadingRisk(false);
}
};
// Filter modules by search term
const filteredModules = modules.filter(m => {
if (!searchTerm) return true;
const term = searchTerm.toLowerCase();
return (
m.name.toLowerCase().includes(term) ||
m.display_name.toLowerCase().includes(term) ||
(m.description && m.description.toLowerCase().includes(term)) ||
m.technology_stack.some(t => t.toLowerCase().includes(term))
);
});
// Group by type for visualization
const modulesByType = filteredModules.reduce((acc, m) => {
const type = m.service_type || 'unknown';
if (!acc[type]) acc[type] = [];
acc[type].push(m);
return acc;
}, {} as Record<string, ServiceModule[]>);
const mp = useModulesPage();
return (
<div className="p-6 space-y-6">
@@ -216,233 +13,72 @@ export default function ModulesPage() {
<div className="flex justify-between items-center">
<div>
<h1 className="text-2xl font-bold text-gray-900">Service Module Registry</h1>
<p className="text-gray-600 mt-1">
Alle {overview?.total_modules || 0} Breakpilot-Services mit Regulation-Mappings
</p>
<p className="text-gray-600 mt-1">Alle {mp.overview?.total_modules || 0} Breakpilot-Services mit Regulation-Mappings</p>
</div>
<div className="flex gap-2">
<button
onClick={() => seedModules(false)}
className="px-4 py-2 bg-blue-600 text-white rounded-lg hover:bg-blue-700 transition"
>
Seed Modules
</button>
<button
onClick={() => seedModules(true)}
className="px-4 py-2 bg-gray-600 text-white rounded-lg hover:bg-gray-700 transition"
>
Force Re-Seed
</button>
<button onClick={() => mp.seedModules(false)} className="px-4 py-2 bg-blue-600 text-white rounded-lg hover:bg-blue-700 transition">Seed Modules</button>
<button onClick={() => mp.seedModules(true)} className="px-4 py-2 bg-gray-600 text-white rounded-lg hover:bg-gray-700 transition">Force Re-Seed</button>
</div>
</div>
{/* Overview Stats */}
{overview && (
{mp.overview && (
<div className="grid grid-cols-2 md:grid-cols-4 lg:grid-cols-6 gap-4">
<div className="bg-white rounded-lg p-4 shadow border">
<div className="text-3xl font-bold text-blue-600">{overview.total_modules}</div>
<div className="text-sm text-gray-600">Services</div>
</div>
<div className="bg-white rounded-lg p-4 shadow border">
<div className="text-3xl font-bold text-red-600">{overview.modules_by_criticality?.critical || 0}</div>
<div className="text-sm text-gray-600">Critical</div>
</div>
<div className="bg-white rounded-lg p-4 shadow border">
<div className="text-3xl font-bold text-purple-600">{overview.modules_processing_pii}</div>
<div className="text-sm text-gray-600">PII-Processing</div>
</div>
<div className="bg-white rounded-lg p-4 shadow border">
<div className="text-3xl font-bold text-pink-600">{overview.modules_with_ai}</div>
<div className="text-sm text-gray-600">AI-Komponenten</div>
</div>
<div className="bg-white rounded-lg p-4 shadow border">
<div className="text-3xl font-bold text-green-600">
{Object.keys(overview.regulations_coverage || {}).length}
</div>
<div className="text-sm text-gray-600">Regulations</div>
</div>
<div className="bg-white rounded-lg p-4 shadow border">
<div className="text-3xl font-bold text-cyan-600">
{overview.average_compliance_score !== null
? `${overview.average_compliance_score}%`
: 'N/A'}
</div>
<div className="text-sm text-gray-600">Avg. Score</div>
</div>
<div className="bg-white rounded-lg p-4 shadow border"><div className="text-3xl font-bold text-blue-600">{mp.overview.total_modules}</div><div className="text-sm text-gray-600">Services</div></div>
<div className="bg-white rounded-lg p-4 shadow border"><div className="text-3xl font-bold text-red-600">{mp.overview.modules_by_criticality?.critical || 0}</div><div className="text-sm text-gray-600">Critical</div></div>
<div className="bg-white rounded-lg p-4 shadow border"><div className="text-3xl font-bold text-purple-600">{mp.overview.modules_processing_pii}</div><div className="text-sm text-gray-600">PII-Processing</div></div>
<div className="bg-white rounded-lg p-4 shadow border"><div className="text-3xl font-bold text-pink-600">{mp.overview.modules_with_ai}</div><div className="text-sm text-gray-600">AI-Komponenten</div></div>
<div className="bg-white rounded-lg p-4 shadow border"><div className="text-3xl font-bold text-green-600">{Object.keys(mp.overview.regulations_coverage || {}).length}</div><div className="text-sm text-gray-600">Regulations</div></div>
<div className="bg-white rounded-lg p-4 shadow border"><div className="text-3xl font-bold text-cyan-600">{mp.overview.average_compliance_score !== null ? `${mp.overview.average_compliance_score}%` : 'N/A'}</div><div className="text-sm text-gray-600">Avg. Score</div></div>
</div>
)}
{/* Filters */}
<div className="bg-white rounded-lg p-4 shadow border">
<div className="flex flex-wrap gap-4 items-center">
<div>
<label className="block text-xs text-gray-500 mb-1">Service Type</label>
<select
value={typeFilter}
onChange={(e) => { setTypeFilter(e.target.value); }}
className="border rounded px-3 py-2 text-sm"
>
<option value="all">Alle Typen</option>
<option value="backend">Backend</option>
<option value="database">Database</option>
<option value="ai">AI/ML</option>
<option value="communication">Communication</option>
<option value="storage">Storage</option>
<option value="infrastructure">Infrastructure</option>
<option value="monitoring">Monitoring</option>
<option value="security">Security</option>
</select>
</div>
<div>
<label className="block text-xs text-gray-500 mb-1">Criticality</label>
<select
value={criticalityFilter}
onChange={(e) => { setCriticalityFilter(e.target.value); }}
className="border rounded px-3 py-2 text-sm"
>
<option value="all">Alle</option>
<option value="critical">Critical</option>
<option value="high">High</option>
<option value="medium">Medium</option>
<option value="low">Low</option>
</select>
</div>
<div>
<label className="block text-xs text-gray-500 mb-1">PII</label>
<select
value={piiFilter === null ? 'all' : String(piiFilter)}
onChange={(e) => {
const val = e.target.value;
setPiiFilter(val === 'all' ? null : val === 'true');
}}
className="border rounded px-3 py-2 text-sm"
>
<option value="all">Alle</option>
<option value="true">Verarbeitet PII</option>
<option value="false">Keine PII</option>
</select>
</div>
<div>
<label className="block text-xs text-gray-500 mb-1">AI</label>
<select
value={aiFilter === null ? 'all' : String(aiFilter)}
onChange={(e) => {
const val = e.target.value;
setAiFilter(val === 'all' ? null : val === 'true');
}}
className="border rounded px-3 py-2 text-sm"
>
<option value="all">Alle</option>
<option value="true">Mit AI</option>
<option value="false">Ohne AI</option>
</select>
</div>
<div className="flex-1">
<label className="block text-xs text-gray-500 mb-1">Suche</label>
<input
type="text"
placeholder="Service, Beschreibung, Technologie..."
value={searchTerm}
onChange={(e) => setSearchTerm(e.target.value)}
className="border rounded px-3 py-2 text-sm w-full"
/>
</div>
<div className="pt-5">
<button
onClick={fetchModules}
className="px-4 py-2 bg-gray-100 rounded hover:bg-gray-200 transition text-sm"
>
Filter anwenden
</button>
</div>
<div><label className="block text-xs text-gray-500 mb-1">Service Type</label><select value={mp.typeFilter} onChange={(e) => mp.setTypeFilter(e.target.value)} className="border rounded px-3 py-2 text-sm"><option value="all">Alle Typen</option>{['backend','database','ai','communication','storage','infrastructure','monitoring','security'].map(t => (<option key={t} value={t}>{t.charAt(0).toUpperCase()+t.slice(1)}</option>))}</select></div>
<div><label className="block text-xs text-gray-500 mb-1">Criticality</label><select value={mp.criticalityFilter} onChange={(e) => mp.setCriticalityFilter(e.target.value)} className="border rounded px-3 py-2 text-sm"><option value="all">Alle</option>{['critical','high','medium','low'].map(c => (<option key={c} value={c}>{c.charAt(0).toUpperCase()+c.slice(1)}</option>))}</select></div>
<div><label className="block text-xs text-gray-500 mb-1">PII</label><select value={mp.piiFilter === null ? 'all' : String(mp.piiFilter)} onChange={(e) => { const val = e.target.value; mp.setPiiFilter(val === 'all' ? null : val === 'true') }} className="border rounded px-3 py-2 text-sm"><option value="all">Alle</option><option value="true">Verarbeitet PII</option><option value="false">Keine PII</option></select></div>
<div><label className="block text-xs text-gray-500 mb-1">AI</label><select value={mp.aiFilter === null ? 'all' : String(mp.aiFilter)} onChange={(e) => { const val = e.target.value; mp.setAiFilter(val === 'all' ? null : val === 'true') }} className="border rounded px-3 py-2 text-sm"><option value="all">Alle</option><option value="true">Mit AI</option><option value="false">Ohne AI</option></select></div>
<div className="flex-1"><label className="block text-xs text-gray-500 mb-1">Suche</label><input type="text" placeholder="Service, Beschreibung, Technologie..." value={mp.searchTerm} onChange={(e) => mp.setSearchTerm(e.target.value)} className="border rounded px-3 py-2 text-sm w-full" /></div>
<div className="pt-5"><button onClick={mp.fetchModules} className="px-4 py-2 bg-gray-100 rounded hover:bg-gray-200 transition text-sm">Filter anwenden</button></div>
</div>
</div>
{/* Error */}
{error && (
<div className="bg-red-50 border border-red-200 text-red-700 p-4 rounded-lg">
{error}
</div>
)}
{mp.error && (<div className="bg-red-50 border border-red-200 text-red-700 p-4 rounded-lg">{mp.error}</div>)}
{mp.loading && (<div className="text-center py-12 text-gray-500">Lade Module...</div>)}
{/* Loading */}
{loading && (
<div className="text-center py-12 text-gray-500">
Lade Module...
</div>
)}
{/* Main Content - Two Column Layout */}
{!loading && (
{/* Main Content */}
{!mp.loading && (
<div className="flex gap-6">
{/* Module List */}
<div className="flex-1 space-y-4">
{Object.entries(modulesByType).map(([type, typeModules]) => (
{Object.entries(mp.modulesByType).map(([type, typeModules]) => (
<div key={type} className="bg-white rounded-lg shadow border">
<div className={`px-4 py-2 border-b ${SERVICE_TYPE_CONFIG[type]?.bgColor || 'bg-gray-100'}`}>
<span className="text-lg mr-2">{SERVICE_TYPE_CONFIG[type]?.icon || '📁'}</span>
<span className={`font-semibold ${SERVICE_TYPE_CONFIG[type]?.color || 'text-gray-700'}`}>
{type.charAt(0).toUpperCase() + type.slice(1)}
</span>
<span className={`font-semibold ${SERVICE_TYPE_CONFIG[type]?.color || 'text-gray-700'}`}>{type.charAt(0).toUpperCase() + type.slice(1)}</span>
<span className="text-gray-500 ml-2">({typeModules.length})</span>
</div>
<div className="divide-y">
{typeModules.map((module) => (
<div
key={module.id}
onClick={() => fetchModuleDetail(module.name)}
className={`p-4 cursor-pointer hover:bg-gray-50 transition ${
selectedModule?.id === module.id ? 'bg-blue-50' : ''
}`}
>
{typeModules.map((mod) => (
<div key={mod.id} onClick={() => mp.fetchModuleDetail(mod.name)} className={`p-4 cursor-pointer hover:bg-gray-50 transition ${mp.selectedModule?.id === mod.id ? 'bg-blue-50' : ''}`}>
<div className="flex justify-between items-start">
<div className="flex-1">
<div className="flex items-center gap-2">
<span className="font-medium text-gray-900">{module.display_name}</span>
{module.port && (
<span className="text-xs text-gray-400">:{module.port}</span>
)}
</div>
<div className="text-sm text-gray-500 mt-1">{module.name}</div>
{module.description && (
<div className="text-sm text-gray-600 mt-1 line-clamp-2">
{module.description}
</div>
)}
<div className="flex items-center gap-2"><span className="font-medium text-gray-900">{mod.display_name}</span>{mod.port && (<span className="text-xs text-gray-400">:{mod.port}</span>)}</div>
<div className="text-sm text-gray-500 mt-1">{mod.name}</div>
{mod.description && (<div className="text-sm text-gray-600 mt-1 line-clamp-2">{mod.description}</div>)}
<div className="flex flex-wrap gap-1 mt-2">
{module.technology_stack.slice(0, 4).map((tech, i) => (
<span key={i} className="px-2 py-0.5 bg-gray-100 text-gray-600 text-xs rounded">
{tech}
</span>
))}
{module.technology_stack.length > 4 && (
<span className="px-2 py-0.5 text-gray-400 text-xs">
+{module.technology_stack.length - 4}
</span>
)}
{mod.technology_stack.slice(0, 4).map((tech, i) => (<span key={i} className="px-2 py-0.5 bg-gray-100 text-gray-600 text-xs rounded">{tech}</span>))}
{mod.technology_stack.length > 4 && (<span className="px-2 py-0.5 text-gray-400 text-xs">+{mod.technology_stack.length - 4}</span>)}
</div>
</div>
<div className="flex flex-col items-end gap-1">
<span className={`px-2 py-0.5 text-xs rounded ${
CRITICALITY_CONFIG[module.criticality]?.bgColor || 'bg-gray-100'
} ${CRITICALITY_CONFIG[module.criticality]?.color || 'text-gray-700'}`}>
{module.criticality}
</span>
<span className={`px-2 py-0.5 text-xs rounded ${CRITICALITY_CONFIG[mod.criticality]?.bgColor || 'bg-gray-100'} ${CRITICALITY_CONFIG[mod.criticality]?.color || 'text-gray-700'}`}>{mod.criticality}</span>
<div className="flex gap-1 mt-1">
{module.processes_pii && (
<span className="px-1.5 py-0.5 bg-purple-100 text-purple-700 text-xs rounded" title="Verarbeitet PII">
PII
</span>
)}
{module.ai_components && (
<span className="px-1.5 py-0.5 bg-pink-100 text-pink-700 text-xs rounded" title="AI-Komponenten">
AI
</span>
)}
</div>
<div className="text-xs text-gray-400 mt-1">
{module.regulation_count} Regulations
{mod.processes_pii && (<span className="px-1.5 py-0.5 bg-purple-100 text-purple-700 text-xs rounded" title="Verarbeitet PII">PII</span>)}
{mod.ai_components && (<span className="px-1.5 py-0.5 bg-pink-100 text-pink-700 text-xs rounded" title="AI-Komponenten">AI</span>)}
</div>
<div className="text-xs text-gray-400 mt-1">{mod.regulation_count} Regulations</div>
</div>
</div>
</div>
@@ -450,292 +86,33 @@ export default function ModulesPage() {
</div>
</div>
))}
{filteredModules.length === 0 && !loading && (
<div className="text-center py-12 text-gray-500 bg-white rounded-lg shadow border">
Keine Module gefunden.
<button
onClick={() => seedModules(false)}
className="text-blue-600 hover:underline ml-1"
>
Jetzt seeden?
</button>
</div>
{mp.filteredModules.length === 0 && !mp.loading && (
<div className="text-center py-12 text-gray-500 bg-white rounded-lg shadow border">Keine Module gefunden.<button onClick={() => mp.seedModules(false)} className="text-blue-600 hover:underline ml-1">Jetzt seeden?</button></div>
)}
</div>
{/* Detail Panel */}
{selectedModule && (
<div className="w-96 bg-white rounded-lg shadow border sticky top-6 h-fit">
<div className={`px-4 py-3 border-b ${SERVICE_TYPE_CONFIG[selectedModule.service_type]?.bgColor || 'bg-gray-100'}`}>
<div className="flex items-center justify-between">
<span className="text-lg">{SERVICE_TYPE_CONFIG[selectedModule.service_type]?.icon || '📁'}</span>
<button
onClick={() => setSelectedModule(null)}
className="text-gray-400 hover:text-gray-600"
>
</button>
</div>
<h3 className="font-bold text-lg mt-2">{selectedModule.display_name}</h3>
<div className="text-sm text-gray-600">{selectedModule.name}</div>
</div>
{loadingDetail ? (
<div className="p-4 text-center text-gray-500">Lade Details...</div>
) : (
<div className="p-4 space-y-4">
{/* Description */}
{selectedModule.description && (
<div>
<div className="text-xs text-gray-500 uppercase mb-1">Beschreibung</div>
<div className="text-sm text-gray-700">{selectedModule.description}</div>
</div>
)}
{/* Technical Details */}
<div className="grid grid-cols-2 gap-2 text-sm">
{selectedModule.port && (
<div>
<span className="text-gray-500">Port:</span>
<span className="ml-1 font-mono">{selectedModule.port}</span>
</div>
)}
<div>
<span className="text-gray-500">Criticality:</span>
<span className={`ml-1 px-1.5 py-0.5 rounded text-xs ${
CRITICALITY_CONFIG[selectedModule.criticality]?.bgColor || ''
} ${CRITICALITY_CONFIG[selectedModule.criticality]?.color || ''}`}>
{selectedModule.criticality}
</span>
</div>
</div>
{/* Technology Stack */}
<div>
<div className="text-xs text-gray-500 uppercase mb-1">Tech Stack</div>
<div className="flex flex-wrap gap-1">
{selectedModule.technology_stack.map((tech, i) => (
<span key={i} className="px-2 py-0.5 bg-gray-100 text-gray-700 text-xs rounded">
{tech}
</span>
))}
</div>
</div>
{/* Data Categories */}
{selectedModule.data_categories.length > 0 && (
<div>
<div className="text-xs text-gray-500 uppercase mb-1">Daten-Kategorien</div>
<div className="flex flex-wrap gap-1">
{selectedModule.data_categories.map((cat, i) => (
<span key={i} className="px-2 py-0.5 bg-blue-50 text-blue-700 text-xs rounded">
{cat}
</span>
))}
</div>
</div>
)}
{/* Flags */}
<div className="flex flex-wrap gap-2">
{selectedModule.processes_pii && (
<span className="px-2 py-1 bg-purple-100 text-purple-700 text-xs rounded">
Verarbeitet PII
</span>
)}
{selectedModule.ai_components && (
<span className="px-2 py-1 bg-pink-100 text-pink-700 text-xs rounded">
AI-Komponenten
</span>
)}
{selectedModule.processes_health_data && (
<span className="px-2 py-1 bg-red-100 text-red-700 text-xs rounded">
Gesundheitsdaten
</span>
)}
</div>
{/* Regulations */}
{selectedModule.regulations && selectedModule.regulations.length > 0 && (
<div>
<div className="text-xs text-gray-500 uppercase mb-2">
Applicable Regulations ({selectedModule.regulations.length})
</div>
<div className="space-y-2">
{selectedModule.regulations.map((reg, i) => (
<div key={i} className="p-2 bg-gray-50 rounded text-sm">
<div className="flex justify-between items-start">
<span className="font-medium">{reg.code}</span>
<span className={`px-1.5 py-0.5 rounded text-xs ${
RELEVANCE_CONFIG[reg.relevance_level]?.bgColor || 'bg-gray-100'
} ${RELEVANCE_CONFIG[reg.relevance_level]?.color || 'text-gray-700'}`}>
{reg.relevance_level}
</span>
</div>
<div className="text-gray-500 text-xs">{reg.name}</div>
{reg.notes && (
<div className="text-gray-600 text-xs mt-1 italic">{reg.notes}</div>
)}
</div>
))}
</div>
</div>
)}
{/* Owner */}
{selectedModule.owner_team && (
<div>
<div className="text-xs text-gray-500 uppercase mb-1">Owner</div>
<div className="text-sm text-gray-700">{selectedModule.owner_team}</div>
</div>
)}
{/* Repository */}
{selectedModule.repository_path && (
<div>
<div className="text-xs text-gray-500 uppercase mb-1">Repository</div>
<code className="text-xs bg-gray-100 px-2 py-1 rounded block">
{selectedModule.repository_path}
</code>
</div>
)}
{/* AI Risk Assessment Button */}
<div className="pt-2 border-t">
<button
onClick={() => assessModuleRisk(selectedModule.id)}
disabled={loadingRisk}
className="w-full px-4 py-2 bg-gradient-to-r from-purple-600 to-pink-600 text-white rounded-lg hover:from-purple-700 hover:to-pink-700 transition disabled:opacity-50 flex items-center justify-center gap-2"
>
{loadingRisk ? (
<>
<span className="animate-spin"></span>
AI analysiert...
</>
) : (
<>
<span>🤖</span>
AI Risikobewertung
</>
)}
</button>
</div>
{/* AI Risk Assessment Panel */}
{showRiskPanel && (
<div className="mt-4 p-4 bg-gradient-to-br from-purple-50 to-pink-50 rounded-lg border border-purple-200">
<div className="flex justify-between items-center mb-3">
<h4 className="font-semibold text-purple-900 flex items-center gap-2">
<span>🤖</span> AI Risikobewertung
</h4>
<button
onClick={() => setShowRiskPanel(false)}
className="text-purple-400 hover:text-purple-600"
>
</button>
</div>
{loadingRisk ? (
<div className="text-center py-4 text-purple-600">
<div className="animate-pulse">Analysiere Compliance-Risiken...</div>
</div>
) : riskAssessment ? (
<div className="space-y-3">
{/* Overall Risk */}
<div className="flex items-center gap-2">
<span className="text-sm text-gray-600">Gesamtrisiko:</span>
<span className={`px-2 py-1 rounded text-sm font-medium ${
riskAssessment.overall_risk === 'critical' ? 'bg-red-100 text-red-700' :
riskAssessment.overall_risk === 'high' ? 'bg-orange-100 text-orange-700' :
riskAssessment.overall_risk === 'medium' ? 'bg-yellow-100 text-yellow-700' :
'bg-green-100 text-green-700'
}`}>
{riskAssessment.overall_risk.toUpperCase()}
</span>
<span className="text-xs text-gray-400">
({Math.round(riskAssessment.confidence_score * 100)}% Konfidenz)
</span>
</div>
{/* Risk Factors */}
{riskAssessment.risk_factors.length > 0 && (
<div>
<div className="text-xs text-gray-500 uppercase mb-1">Risikofaktoren</div>
<div className="space-y-1">
{riskAssessment.risk_factors.map((factor, i) => (
<div key={i} className="flex items-center justify-between text-sm bg-white/50 rounded px-2 py-1">
<span className="text-gray-700">{factor.factor}</span>
<span className={`text-xs px-1.5 py-0.5 rounded ${
factor.severity === 'critical' || factor.severity === 'high'
? 'bg-red-100 text-red-600'
: 'bg-yellow-100 text-yellow-600'
}`}>
{factor.severity}
</span>
</div>
))}
</div>
</div>
)}
{/* Compliance Gaps */}
{riskAssessment.compliance_gaps.length > 0 && (
<div>
<div className="text-xs text-gray-500 uppercase mb-1">Compliance-Lücken</div>
<ul className="text-sm text-gray-700 space-y-1">
{riskAssessment.compliance_gaps.map((gap, i) => (
<li key={i} className="flex items-start gap-1">
<span className="text-red-500"></span>
<span>{gap}</span>
</li>
))}
</ul>
</div>
)}
{/* Recommendations */}
{riskAssessment.recommendations.length > 0 && (
<div>
<div className="text-xs text-gray-500 uppercase mb-1">Empfehlungen</div>
<ul className="text-sm text-gray-700 space-y-1">
{riskAssessment.recommendations.map((rec, i) => (
<li key={i} className="flex items-start gap-1">
<span className="text-green-500"></span>
<span>{rec}</span>
</li>
))}
</ul>
</div>
)}
</div>
) : (
<div className="text-center py-4 text-gray-500 text-sm">
Klicken Sie auf &quot;AI Risikobewertung&quot; um eine Analyse zu starten.
</div>
)}
</div>
)}
</div>
)}
</div>
{mp.selectedModule && (
<ModuleDetailPanel
module={mp.selectedModule}
loadingDetail={mp.loadingDetail}
loadingRisk={mp.loadingRisk}
showRiskPanel={mp.showRiskPanel}
riskAssessment={mp.riskAssessment}
onClose={() => mp.setSelectedModule(null)}
onAssessRisk={mp.assessModuleRisk}
onCloseRisk={() => mp.setShowRiskPanel(false)}
/>
)}
</div>
)}
{/* Regulations Coverage Overview */}
{overview && overview.regulations_coverage && Object.keys(overview.regulations_coverage).length > 0 && (
{/* Regulations Coverage */}
{mp.overview && mp.overview.regulations_coverage && Object.keys(mp.overview.regulations_coverage).length > 0 && (
<div className="bg-white rounded-lg shadow border p-4">
<h3 className="font-semibold text-gray-900 mb-4">Regulation Coverage</h3>
<div className="grid grid-cols-2 md:grid-cols-4 lg:grid-cols-6 gap-3">
{Object.entries(overview.regulations_coverage)
.sort(([, a], [, b]) => b - a)
.map(([code, count]) => (
<div key={code} className="bg-gray-50 rounded p-3 text-center">
<div className="text-2xl font-bold text-blue-600">{count}</div>
<div className="text-xs text-gray-600 truncate" title={code}>{code}</div>
</div>
{Object.entries(mp.overview.regulations_coverage).sort(([, a], [, b]) => b - a).map(([code, count]) => (
<div key={code} className="bg-gray-50 rounded p-3 text-center"><div className="text-2xl font-bold text-blue-600">{count}</div><div className="text-xs text-gray-600 truncate" title={code}>{code}</div></div>
))}
</div>
</div>
@@ -0,0 +1,162 @@
'use client'
import {
Source, ScraperStatus, ScrapeResult,
PDFDocument, PDFExtractionResult,
} from './types'
import SourceCard from './SourceCard'
/**
 * Props for ScraperTabs: the selected tab plus all data and callbacks the
 * individual tab views need. The component holds no state of its own.
 */
interface ScraperTabsProps {
// Selected tab id. ScraperTabs checks for 'sources', 'pdf' and 'status';
// any other value renders the results ("logs") view.
activeTab: string
// Regulation sources listed in the 'sources' tab.
sources: Source[]
// PDF documents offered for extraction in the 'pdf' tab.
pdfDocuments: PDFDocument[]
// Scraper status shown in the 'status' tab; null when not available.
status: ScraperStatus | null
// True while a scrape is running; disables the scrape buttons.
scraping: boolean
// True while a PDF extraction is running; disables the extract buttons.
extracting: boolean
// Per-source scrape results listed in the results ("logs") view.
results: ScrapeResult[]
// Result of the last PDF extraction; its summary is rendered when non-null.
pdfResult: PDFExtractionResult | null
// Invoked by the "Alle Quellen scrapen" button.
handleScrapeAll: () => void
// Passed to each SourceCard as its onScrape callback.
handleScrapeSingle: (code: string, force: boolean) => void
// Invoked by the extract buttons as (doc.code, true, false) and, for the
// "Force" button, as (doc.code, true, true).
handleExtractPdf: (code: string, saveToDb: boolean, force: boolean) => void
}
/**
 * Renders the body of the scraper page for the currently selected tab.
 *
 * Exactly one view is returned per render, chosen by `activeTab`:
 * 'sources', 'pdf', 'status' (only when `status` is non-null), and the
 * results ("logs") view for everything else. All data and callbacks come
 * from props; the component keeps no local state.
 */
export default function ScraperTabs(props: ScraperTabsProps) {
const { activeTab, sources, pdfDocuments, status, scraping, extracting, results, pdfResult } = props
// Sources tab: "scrape all" button plus one SourceCard per source, grouped
// by source_type.
// NOTE(review): only 'eur_lex' and 'bsi_pdf' sources are rendered here; a
// source with any other source_type would not appear - confirm intended.
if (activeTab === 'sources') {
return (
<div>
<div className="flex justify-between items-center mb-6">
<div>
<h3 className="text-lg font-semibold text-slate-900">Regulierungsquellen</h3>
<p className="text-sm text-slate-500">EU-Lex, BSI-TR und deutsche Gesetze</p>
</div>
<button onClick={props.handleScrapeAll} disabled={scraping} className="px-4 py-2 bg-primary-600 text-white rounded-lg hover:bg-primary-700 transition-colors disabled:opacity-50 disabled:cursor-not-allowed flex items-center gap-2">
{scraping ? (<><svg className="w-4 h-4 animate-spin" fill="none" viewBox="0 0 24 24"><circle className="opacity-25" cx="12" cy="12" r="10" stroke="currentColor" strokeWidth="4" /><path className="opacity-75" fill="currentColor" d="M4 12a8 8 0 018-8V0C5.373 0 0 5.373 0 12h4z" /></svg>Laeuft...</>) : (<><svg className="w-4 h-4" fill="none" stroke="currentColor" viewBox="0 0 24 24"><path strokeLinecap="round" strokeLinejoin="round" strokeWidth={2} d="M4 4v5h.582m15.356 2A8.001 8.001 0 004.582 9m0 0H9m11 11v-5h-.581m0 0a8.003 8.003 0 01-15.357-2m15.357 2H15" /></svg>Alle Quellen scrapen</>)}
</button>
</div>
<div className="space-y-6">
<div>
<h4 className="text-sm font-medium text-slate-700 mb-3 flex items-center gap-2"><span className="text-lg">🇪🇺</span> EU-Regulierungen (EUR-Lex)</h4>
<div className="grid gap-3">{sources.filter(s => s.source_type === 'eur_lex').map(source => (<SourceCard key={source.code} source={source} onScrape={props.handleScrapeSingle} scraping={scraping} />))}</div>
</div>
<div>
<h4 className="text-sm font-medium text-slate-700 mb-3 flex items-center gap-2"><span className="text-lg">🔒</span> BSI Technical Guidelines</h4>
<div className="grid gap-3">{sources.filter(s => s.source_type === 'bsi_pdf').map(source => (<SourceCard key={source.code} source={source} onScrape={props.handleScrapeSingle} scraping={scraping} />))}</div>
</div>
</div>
</div>
)
}
// PDF tab: one card per document with "Extrahieren" and "Force" buttons
// (both disabled while extracting or when the document is not available),
// followed by the summary of the last extraction when pdfResult is set.
// The `extracting` flag is shared, so every card shows the spinner at once.
if (activeTab === 'pdf') {
return (
<div>
<div className="mb-6">
<h3 className="text-lg font-semibold text-slate-900">PDF-Extraktion (PyMuPDF)</h3>
<p className="text-sm text-slate-500">Extrahiert ALLE Pruefaspekte aus BSI-TR-03161 PDFs mit Regex-Pattern-Matching</p>
</div>
<div className="space-y-4">
{pdfDocuments.map(doc => (
<div key={doc.code} className="bg-slate-50 rounded-lg p-4 border border-slate-200">
<div className="flex items-center justify-between">
<div className="flex items-center gap-3">
<span className="text-3xl">📄</span>
<div>
<div className="flex items-center gap-2">
<span className="font-semibold text-slate-900">{doc.code}</span>
<span className={`px-2 py-0.5 rounded text-xs font-medium ${doc.available ? 'bg-green-100 text-green-700' : 'bg-red-100 text-red-700'}`}>{doc.available ? 'Verfuegbar' : 'Nicht gefunden'}</span>
</div>
<div className="text-sm text-slate-600">{doc.name}</div>
<div className="text-xs text-slate-500">{doc.description}</div>
<div className="text-xs text-slate-400 mt-1">Erwartete Pruefaspekte: {doc.expected_aspects}</div>
</div>
</div>
<div className="flex gap-2">
<button onClick={() => props.handleExtractPdf(doc.code, true, false)} disabled={extracting || !doc.available} className="px-4 py-2 bg-primary-600 text-white rounded-lg hover:bg-primary-700 transition-colors disabled:opacity-50 disabled:cursor-not-allowed flex items-center gap-2">
{extracting ? (<><svg className="w-4 h-4 animate-spin" fill="none" viewBox="0 0 24 24"><circle className="opacity-25" cx="12" cy="12" r="10" stroke="currentColor" strokeWidth="4" /><path className="opacity-75" fill="currentColor" d="M4 12a8 8 0 018-8V0C5.373 0 0 5.373 0 12h4z" /></svg>Extrahiere...</>) : (<><svg className="w-4 h-4" fill="none" stroke="currentColor" viewBox="0 0 24 24"><path strokeLinecap="round" strokeLinejoin="round" strokeWidth={2} d="M9 12h6m-6 4h6m2 5H7a2 2 0 01-2-2V5a2 2 0 012-2h5.586a1 1 0 01.707.293l5.414 5.414a1 1 0 01.293.707V19a2 2 0 01-2 2z" /></svg>Extrahieren</>)}
</button>
<button onClick={() => props.handleExtractPdf(doc.code, true, true)} disabled={extracting || !doc.available} className="px-3 py-2 bg-orange-100 text-orange-700 rounded-lg hover:bg-orange-200 transition-colors disabled:opacity-50 disabled:cursor-not-allowed" title="Force: Loescht vorhandene und extrahiert neu">Force</button>
</div>
</div>
</div>
))}
</div>
{pdfResult && (
<div className="mt-6 bg-green-50 rounded-lg p-4 border border-green-200">
<h4 className="font-semibold text-green-800 mb-3">Letztes Extraktions-Ergebnis</h4>
<div className="grid grid-cols-3 gap-4 mb-4">
<div className="text-center p-3 bg-white rounded-lg"><div className="text-2xl font-bold text-green-700">{pdfResult.total_aspects}</div><div className="text-sm text-slate-500">Pruefaspekte gefunden</div></div>
<div className="text-center p-3 bg-white rounded-lg"><div className="text-2xl font-bold text-blue-700">{pdfResult.requirements_created}</div><div className="text-sm text-slate-500">Requirements erstellt</div></div>
<div className="text-center p-3 bg-white rounded-lg"><div className="text-2xl font-bold text-slate-700">{Object.keys(pdfResult.statistics.by_category || {}).length}</div><div className="text-sm text-slate-500">Kategorien</div></div>
</div>
{pdfResult.statistics.by_category && Object.keys(pdfResult.statistics.by_category).length > 0 && (
<div><h5 className="text-sm font-medium text-slate-700 mb-2">Nach Kategorie:</h5><div className="flex flex-wrap gap-2">{Object.entries(pdfResult.statistics.by_category).map(([cat, count]) => (<span key={cat} className="px-2 py-1 bg-white rounded text-xs text-slate-600">{cat}: <strong>{count}</strong></span>))}</div></div>
)}
</div>
)}
<div className="mt-6 bg-blue-50 rounded-lg p-4 border border-blue-200">
<h4 className="font-semibold text-blue-800 mb-2">Wie funktioniert die PDF-Extraktion?</h4>
<ul className="text-sm text-blue-700 space-y-1">
<li>- <strong>PyMuPDF (fitz)</strong> liest den PDF-Text</li>
<li>- <strong>Regex-Pattern</strong> finden Aspekte wie O.Auth_1, O.Sess_2, T.Network_1</li>
<li>- <strong>Kontextanalyse</strong> extrahiert Titel, Kategorie und Anforderungsstufe (MUSS/SOLL/KANN)</li>
<li>- <strong>Automatische Speicherung</strong> erstellt Requirements in der Datenbank</li>
</ul>
</div>
</div>
)
}
// Status tab: run statistics, a status badge and the last error (if any).
// NOTE(review): when activeTab is 'status' but `status` is still null this
// condition is false and the results view below is rendered instead -
// confirm that this fall-through is intended.
if (activeTab === 'status' && status) {
return (
<div className="space-y-6">
<div className="bg-slate-50 rounded-lg p-6">
<div className="flex items-center justify-between mb-4">
<div><h3 className="text-lg font-semibold text-slate-900">Scraper-Status</h3><p className="text-sm text-slate-500">Letzter Lauf: {status.stats.last_run ? new Date(status.stats.last_run).toLocaleString('de-DE') : 'Noch nie'}</p></div>
<div className={`px-3 py-1.5 rounded-full text-sm font-medium ${status.status === 'running' ? 'bg-blue-100 text-blue-700' : status.status === 'error' ? 'bg-red-100 text-red-700' : status.status === 'completed' ? 'bg-green-100 text-green-700' : 'bg-gray-100 text-gray-700'}`}>
{status.status === 'running' ? 'Laeuft' : status.status === 'error' ? 'Fehler' : status.status === 'completed' ? 'Abgeschlossen' : 'Bereit'}
</div>
</div>
<div className="grid grid-cols-3 gap-4">
<div className="text-center p-4 bg-white rounded-lg"><div className="text-2xl font-bold text-slate-900">{status.stats.sources_processed}</div><div className="text-sm text-slate-500">Quellen verarbeitet</div></div>
<div className="text-center p-4 bg-white rounded-lg"><div className="text-2xl font-bold text-green-600">{status.stats.requirements_extracted}</div><div className="text-sm text-slate-500">Anforderungen extrahiert</div></div>
<div className="text-center p-4 bg-white rounded-lg"><div className="text-2xl font-bold text-red-600">{status.stats.errors}</div><div className="text-sm text-slate-500">Fehler</div></div>
</div>
{status.last_error && (<div className="mt-4 p-3 bg-red-50 rounded-lg text-sm text-red-700"><strong>Letzter Fehler:</strong> {status.last_error}</div>)}
</div>
<div className="bg-white border border-slate-200 rounded-lg p-6">
<h4 className="font-semibold text-slate-900 mb-4">Wie funktioniert der Scraper?</h4>
<div className="space-y-3 text-sm text-slate-600">
{[{ n: '1', t: 'EUR-Lex Abruf', d: 'Holt HTML-Version der EU-Verordnung, extrahiert Artikel und Absaetze' }, { n: '2', t: 'BSI-TR Parsing', d: 'Extrahiert Pruefaspekte (O.Auth_1, O.Sess_1, etc.) aus den TR-Dokumenten' }, { n: '3', t: 'Datenbank-Speicherung', d: 'Jede Anforderung wird als Requirement in der Compliance-DB gespeichert' }].map(s => (
<div key={s.n} className="flex items-start gap-3"><div className="w-6 h-6 bg-blue-100 rounded-full flex items-center justify-center text-blue-600 font-bold">{s.n}</div><div><strong>{s.t}</strong>: {s.d}</div></div>
))}
<div className="flex items-start gap-3"><div className="w-6 h-6 bg-green-100 rounded-full flex items-center justify-center text-green-600 font-bold"></div><div><strong>Audit-Workspace</strong>: Anforderungen koennen mit Implementierungsdetails angereichert werden</div></div>
</div>
</div>
</div>
)
}
// Logs tab (default view): one row per scrape result. A row is red when
// result.error is set, yellow when result.reason is set, green otherwise;
// the text shows the error, the reason, or the extracted-requirements count.
return (
<div>
<h3 className="text-lg font-semibold text-slate-900 mb-4">Letzte Ergebnisse</h3>
{results.length === 0 ? (
<div className="text-center py-12 text-slate-500">Keine Ergebnisse vorhanden. Starte einen Scrape-Vorgang.</div>
) : (
<div className="space-y-2">
{results.map((result, idx) => (
<div key={idx} className={`p-3 rounded-lg flex items-center justify-between ${result.error ? 'bg-red-50' : result.reason ? 'bg-yellow-50' : 'bg-green-50'}`}>
<div className="flex items-center gap-3">
<span className="text-lg">{result.error ? '❌' : result.reason ? '⏭️' : '✅'}</span>
<span className="font-medium">{result.code}</span>
<span className="text-sm text-slate-500">{result.error || result.reason || `${result.requirements_extracted} Anforderungen`}</span>
</div>
</div>
))}
</div>
)}
</div>
)
}
@@ -0,0 +1,73 @@
'use client'
import { Source, regulationTypeBadge, sourceTypeBadge } from './types'
/** Props accepted by SourceCard. */
type SourceCardProps = {
  source: Source
  onScrape: (code: string, force: boolean) => void
  scraping: boolean
}

/**
 * Card for one regulation source.
 *
 * Shows the source code, its regulation-type and source-type badges, a
 * shortened URL (full URL in the tooltip), the number of stored requirements,
 * and the scrape buttons. The "Force" button is only rendered when the source
 * already has data. Both buttons are disabled while `scraping` is true.
 *
 * Unknown regulation types fall back to the `industry_standard` badge and
 * unknown source types to the `manual` badge.
 */
export default function SourceCard(props: SourceCardProps) {
  const { source, onScrape, scraping } = props

  const regulationBadge = regulationTypeBadge[source.regulation_type] || regulationTypeBadge.industry_standard
  const originBadge = sourceTypeBadge[source.source_type] || sourceTypeBadge.manual

  // Both badges share the same pill styling and differ only in color classes.
  const pillClass = (color: string) => `px-2 py-0.5 rounded text-xs font-medium ${color}`

  // Long URLs are cut after 60 characters; the title attribute keeps the full value.
  const maxUrlChars = 60
  let shownUrl = source.url
  if (shownUrl.length > maxUrlChars) {
    shownUrl = `${shownUrl.slice(0, maxUrlChars)}...`
  }

  const scrapeIncremental = () => onScrape(source.code, false)
  const scrapeForced = () => onScrape(source.code, true)

  const dataBadge = source.has_data ? (
    <span className="px-3 py-1 bg-green-100 text-green-700 rounded-full text-sm font-medium">
      {source.requirement_count} Anforderungen
    </span>
  ) : (
    <span className="px-3 py-1 bg-gray-100 text-gray-500 rounded-full text-sm">
      Keine Daten
    </span>
  )

  return (
    <div className="bg-white border border-slate-200 rounded-lg p-4 hover:shadow-sm transition-shadow">
      <div className="flex items-center justify-between">
        <div className="flex items-center gap-3">
          <span className="text-2xl">{regulationBadge.icon}</span>
          <div>
            <div className="flex items-center gap-2">
              <span className="font-semibold text-slate-900">{source.code}</span>
              <span className={pillClass(regulationBadge.color)}>{regulationBadge.label}</span>
              <span className={pillClass(originBadge.color)}>{originBadge.label}</span>
            </div>
            <div className="text-sm text-slate-500 truncate max-w-md" title={source.url}>
              {shownUrl}
            </div>
          </div>
        </div>
        <div className="flex items-center gap-3">
          {dataBadge}
          <div className="flex gap-1">
            <button
              onClick={scrapeIncremental}
              disabled={scraping}
              className="px-3 py-1.5 text-sm bg-slate-100 text-slate-700 rounded hover:bg-slate-200 transition-colors disabled:opacity-50 disabled:cursor-not-allowed"
              title="Scrapen (ueberspringt vorhandene)"
            >
              Scrapen
            </button>
            {source.has_data && (
              <button
                onClick={scrapeForced}
                disabled={scraping}
                className="px-3 py-1.5 text-sm bg-orange-100 text-orange-700 rounded hover:bg-orange-200 transition-colors disabled:opacity-50 disabled:cursor-not-allowed"
                title="Force: Loescht vorhandene Daten und scraped neu"
              >
                Force
              </button>
            )}
          </div>
        </div>
      </div>
    </div>
  )
}
@@ -0,0 +1,65 @@
/**
 * One regulation source, as stored from the `sources` array of the scraper's
 * `/sources` response and rendered by SourceCard.
 */
export interface Source {
// Identifier shown on the card and passed to the scrape handlers.
code: string
// Source URL; the card shows at most 60 characters and keeps the full value as tooltip.
url: string
// Lookup key into `sourceTypeBadge` (e.g. 'eur_lex', 'bsi_pdf').
source_type: string
// Lookup key into `regulationTypeBadge` (e.g. 'eu_regulation', 'bsi_standard').
regulation_type: string
// True when requirements exist for this source; enables the count badge and the "Force" button.
has_data: boolean
// Number of requirements; displayed when `has_data` is true and summed for the page statistics.
requirement_count: number
}
/**
 * Scraper state as returned by the `/status` endpoint; polled every two
 * seconds by the hook while a scrape request is in flight.
 */
export interface ScraperStatus {
// Lifecycle state; 'running' makes the page show the progress bar.
status: 'idle' | 'running' | 'completed' | 'error'
// Source currently being processed; shown as "Aktuell: ..." when set.
current_source: string | null
// Most recent error text; shown in the status tab when set.
last_error: string | null
// Counters displayed in the status tab and the statistics cards.
stats: {
sources_processed: number
requirements_extracted: number
errors: number
// Passed to `new Date(...)` for display; presumably an ISO timestamp — confirm against the backend.
last_run: string | null
}
// Not read by the UI code visible here; presumably the codes of all configured sources — confirm.
known_sources: string[]
}
/**
 * One per-source entry of a scrape-all run. The hook concatenates the
 * `success`, `failed` and `skipped` lists of the response into one array of
 * these, which the results tab renders row by row.
 */
export interface ScrapeResult {
// Source code the entry refers to.
code: string
// Outcome label from the backend; the results tab does not read it and
// derives the row style from `error` / `reason` instead.
status: string
// Shown as "<n> Anforderungen" when neither `error` nor `reason` is set.
requirements_extracted?: number
// When set (and no `error`), the row is rendered as skipped (yellow).
reason?: string
// When set, the row is rendered as failed (red) and this text is shown.
error?: string
}
/**
 * A PDF document offered for extraction, taken from the `documents` array of
 * the `/pdf-documents` response.
 */
export interface PDFDocument {
// Identifier sent as `document_code` when extraction is requested.
code: string
// Human-readable name shown below the code.
name: string
// Short description shown below the name.
description: string
// Displayed verbatim after "Erwartete Pruefaspekte:"; a string, not a number.
expected_aspects: string
// False disables both extraction buttons and shows "Nicht gefunden".
available: boolean
}
/**
 * Response body of the `/extract-pdf` POST request; kept by the hook as
 * `pdfResult` and rendered as the "last extraction result" summary.
 */
export interface PDFExtractionResult {
// When true the hook shows a success message built from the two counters below.
success: boolean
// Not read by the UI code visible here; presumably echoes the requested document code — confirm.
source_document: string
// Number of test aspects found in the PDF.
total_aspects: number
// Number of requirements created from those aspects.
requirements_created: number
statistics: {
// Count per category; rendered as "<category>: <count>" chips.
by_category: Record<string, number>
// Count per requirement level; not rendered by the UI code visible here.
by_requirement_level: Record<string, number>
}
}
// Base URL of the backend API. Taken from NEXT_PUBLIC_BACKEND_URL and falling
// back to http://localhost:8000 when the variable is unset or empty.
export const BACKEND_URL = process.env.NEXT_PUBLIC_BACKEND_URL || 'http://localhost:8000'
// Badge label and CSS color classes per `Source.source_type`.
// SourceCard falls back to the `manual` entry for types not listed here.
export const sourceTypeBadge: Record<string, { label: string; color: string }> = {
eur_lex: { label: 'EUR-Lex', color: 'bg-blue-100 text-blue-800' },
bsi_pdf: { label: 'BSI PDF', color: 'bg-green-100 text-green-800' },
gesetze_im_internet: { label: 'Gesetze', color: 'bg-yellow-100 text-yellow-800' },
manual: { label: 'Manuell', color: 'bg-gray-100 text-gray-800' },
}
// Badge label, CSS color classes and icon per `Source.regulation_type`.
// SourceCard falls back to the `industry_standard` entry for types not listed here.
export const regulationTypeBadge: Record<string, { label: string; color: string; icon: string }> = {
eu_regulation: { label: 'EU-Verordnung', color: 'bg-indigo-100 text-indigo-800', icon: '🇪🇺' },
eu_directive: { label: 'EU-Richtlinie', color: 'bg-purple-100 text-purple-800', icon: '📜' },
de_law: { label: 'DE-Gesetz', color: 'bg-yellow-100 text-yellow-800', icon: '🇩🇪' },
bsi_standard: { label: 'BSI-Standard', color: 'bg-green-100 text-green-800', icon: '🔒' },
industry_standard: { label: 'Standard', color: 'bg-gray-100 text-gray-800', icon: '📋' },
}
@@ -0,0 +1,106 @@
'use client'
import { useState, useEffect, useCallback } from 'react'
import {
Source, ScraperStatus, ScrapeResult,
PDFDocument, PDFExtractionResult, BACKEND_URL,
} from './types'
/**
 * Extract a human-readable error text from a failed response.
 *
 * Error bodies are not guaranteed to be JSON (e.g. a proxy error page), and
 * `detail` is not guaranteed to be a string, so anything unusable yields
 * `fallback` instead of a parse error or "[object Object]".
 */
async function readErrorDetail(res: Response, fallback: string): Promise<string> {
  try {
    const body = await res.json()
    return typeof body?.detail === 'string' && body.detail ? body.detail : fallback
  } catch {
    return fallback
  }
}

/** Turn an arbitrary thrown value into a non-empty message for the error banner. */
function errorMessage(err: unknown, fallback: string): string {
  return err instanceof Error && err.message ? err.message : fallback
}

/**
 * State and actions for the compliance scraper admin page.
 *
 * Loads sources, PDF documents and scraper status on mount, polls the status
 * every two seconds while a scrape request is in flight, and exposes handlers
 * for scraping all sources, scraping a single source and extracting a PDF.
 * Success messages clear themselves after 5 s, error messages after 10 s.
 *
 * @returns The active tab with its setter, the loaded data, the busy flags,
 *          the current messages and the three action handlers.
 */
export function useComplianceScraper() {
  const [activeTab, setActiveTab] = useState<'sources' | 'pdf' | 'status' | 'logs'>('sources')
  const [sources, setSources] = useState<Source[]>([])
  const [pdfDocuments, setPdfDocuments] = useState<PDFDocument[]>([])
  const [status, setStatus] = useState<ScraperStatus | null>(null)
  const [loading, setLoading] = useState(true)
  const [scraping, setScraping] = useState(false)
  const [extracting, setExtracting] = useState(false)
  const [error, setError] = useState<string | null>(null)
  const [success, setSuccess] = useState<string | null>(null)
  const [results, setResults] = useState<ScrapeResult[]>([])
  const [pdfResult, setPdfResult] = useState<PDFExtractionResult | null>(null)

  // The three loaders are best-effort: a failure is logged and leaves the
  // previous state untouched, so they never reject.
  const fetchSources = useCallback(async () => {
    try {
      const res = await fetch(`${BACKEND_URL}/api/v1/compliance/scraper/sources`)
      if (res.ok) { const data = await res.json(); setSources(data.sources || []) }
    } catch (err) { console.error('Failed to fetch sources:', err) }
  }, [])

  const fetchPdfDocuments = useCallback(async () => {
    try {
      const res = await fetch(`${BACKEND_URL}/api/v1/compliance/scraper/pdf-documents`)
      if (res.ok) { const data = await res.json(); setPdfDocuments(data.documents || []) }
    } catch (err) { console.error('Failed to fetch PDF documents:', err) }
  }, [])

  const fetchStatus = useCallback(async () => {
    try {
      const res = await fetch(`${BACKEND_URL}/api/v1/compliance/scraper/status`)
      if (res.ok) { const data = await res.json(); setStatus(data) }
    } catch (err) { console.error('Failed to fetch status:', err) }
  }, [])

  // Initial load of all three resources in parallel.
  useEffect(() => {
    const loadData = async () => {
      setLoading(true)
      await Promise.all([fetchSources(), fetchStatus(), fetchPdfDocuments()])
      setLoading(false)
    }
    loadData()
  }, [fetchSources, fetchStatus, fetchPdfDocuments])

  // Poll the status while a scrape request is in flight.
  useEffect(() => {
    if (scraping) { const interval = setInterval(fetchStatus, 2000); return () => clearInterval(interval) }
  }, [scraping, fetchStatus])

  const handleScrapeAll = async () => {
    setScraping(true); setError(null); setSuccess(null); setResults([])
    try {
      const res = await fetch(`${BACKEND_URL}/api/v1/compliance/scraper/scrape-all`, { method: 'POST' })
      if (!res.ok) { throw new Error(await readErrorDetail(res, 'Scraping fehlgeschlagen')) }
      const data = await res.json()
      // Tolerate a response that omits one of the result buckets.
      const { success: succeeded = [], failed = [], skipped = [] } = data.results ?? {}
      setResults([...succeeded, ...failed, ...skipped])
      setSuccess(`Scraping abgeschlossen: ${succeeded.length} erfolgreich, ${skipped.length} uebersprungen, ${failed.length} fehlgeschlagen`)
      await fetchSources()
    } catch (err: unknown) {
      setError(errorMessage(err, 'Scraping fehlgeschlagen'))
    } finally {
      // Polling stops once `scraping` is false; fetch the final status first so
      // a stale 'running' from the last poll does not keep the progress bar up.
      await fetchStatus()
      setScraping(false)
    }
  }

  const handleScrapeSingle = async (code: string, force: boolean = false) => {
    setScraping(true); setError(null); setSuccess(null)
    try {
      // `code` becomes a path segment, so it must be percent-encoded.
      const res = await fetch(`${BACKEND_URL}/api/v1/compliance/scraper/scrape/${encodeURIComponent(code)}?force=${force}`, { method: 'POST' })
      if (!res.ok) { throw new Error(await readErrorDetail(res, 'Scraping fehlgeschlagen')) }
      const data = await res.json()
      if (data.status === 'skipped') { setSuccess(`${code}: Bereits vorhanden (${data.requirement_count} Anforderungen)`) }
      else { setSuccess(`${code}: ${data.requirements_extracted} Anforderungen extrahiert`) }
      await fetchSources()
    } catch (err: unknown) {
      setError(errorMessage(err, 'Scraping fehlgeschlagen'))
    } finally {
      // See handleScrapeAll: refresh the status before polling stops.
      await fetchStatus()
      setScraping(false)
    }
  }

  const handleExtractPdf = async (code: string, saveToDb: boolean = true, force: boolean = false) => {
    setExtracting(true); setError(null); setSuccess(null); setPdfResult(null)
    try {
      const res = await fetch(`${BACKEND_URL}/api/v1/compliance/scraper/extract-pdf`, {
        method: 'POST', headers: { 'Content-Type': 'application/json' },
        body: JSON.stringify({ document_code: code, save_to_db: saveToDb, force }),
      })
      if (!res.ok) { throw new Error(await readErrorDetail(res, 'PDF-Extraktion fehlgeschlagen')) }
      const data: PDFExtractionResult = await res.json()
      setPdfResult(data)
      if (data.success) { setSuccess(`${code}: ${data.total_aspects} Pruefaspekte extrahiert, ${data.requirements_created} Requirements erstellt`) }
      await fetchSources()
    } catch (err: unknown) {
      setError(errorMessage(err, 'PDF-Extraktion fehlgeschlagen'))
    } finally {
      setExtracting(false)
    }
  }

  // Auto-dismiss banners: success after 5 s, error after 10 s.
  useEffect(() => { if (success) { const timer = setTimeout(() => setSuccess(null), 5000); return () => clearTimeout(timer) } }, [success])
  useEffect(() => { if (error) { const timer = setTimeout(() => setError(null), 10000); return () => clearTimeout(timer) } }, [error])

  return {
    activeTab, setActiveTab, sources, pdfDocuments, status,
    loading, scraping, extracting, error, success, results, pdfResult,
    handleScrapeAll, handleScrapeSingle, handleExtractPdf,
  }
}
+38 -709
View File
@@ -7,283 +7,20 @@
* - EUR-Lex regulations (GDPR, AI Act, CRA, NIS2, etc.)
* - BSI Technical Guidelines (TR-03161)
* - German laws
*
* Similar pattern to edu-search and zeugnisse-crawler.
*/
import { useState, useEffect, useCallback } from 'react'
import AdminLayout from '@/components/admin/AdminLayout'
import SystemInfoSection, { SYSTEM_INFO_CONFIGS } from '@/components/admin/SystemInfoSection'
// Types
interface Source {
code: string
url: string
source_type: string
regulation_type: string
has_data: boolean
requirement_count: number
}
interface ScraperStatus {
status: 'idle' | 'running' | 'completed' | 'error'
current_source: string | null
last_error: string | null
stats: {
sources_processed: number
requirements_extracted: number
errors: number
last_run: string | null
}
known_sources: string[]
}
interface ScrapeResult {
code: string
status: string
requirements_extracted?: number
reason?: string
error?: string
}
interface PDFDocument {
code: string
name: string
description: string
expected_aspects: string
available: boolean
}
interface PDFExtractionResult {
success: boolean
source_document: string
total_aspects: number
requirements_created: number
statistics: {
by_category: Record<string, number>
by_requirement_level: Record<string, number>
}
}
const BACKEND_URL = process.env.NEXT_PUBLIC_BACKEND_URL || 'http://localhost:8000'
// Source type badges
const sourceTypeBadge: Record<string, { label: string; color: string }> = {
eur_lex: { label: 'EUR-Lex', color: 'bg-blue-100 text-blue-800' },
bsi_pdf: { label: 'BSI PDF', color: 'bg-green-100 text-green-800' },
gesetze_im_internet: { label: 'Gesetze', color: 'bg-yellow-100 text-yellow-800' },
manual: { label: 'Manuell', color: 'bg-gray-100 text-gray-800' },
}
// Regulation type badges
const regulationTypeBadge: Record<string, { label: string; color: string; icon: string }> = {
eu_regulation: { label: 'EU-Verordnung', color: 'bg-indigo-100 text-indigo-800', icon: '🇪🇺' },
eu_directive: { label: 'EU-Richtlinie', color: 'bg-purple-100 text-purple-800', icon: '📜' },
de_law: { label: 'DE-Gesetz', color: 'bg-yellow-100 text-yellow-800', icon: '🇩🇪' },
bsi_standard: { label: 'BSI-Standard', color: 'bg-green-100 text-green-800', icon: '🔒' },
industry_standard: { label: 'Standard', color: 'bg-gray-100 text-gray-800', icon: '📋' },
}
import { useComplianceScraper } from './_components/useComplianceScraper'
import ScraperTabs from './_components/ScraperTabs'
export default function ComplianceScraperPage() {
const [activeTab, setActiveTab] = useState<'sources' | 'pdf' | 'status' | 'logs'>('sources')
const [sources, setSources] = useState<Source[]>([])
const [pdfDocuments, setPdfDocuments] = useState<PDFDocument[]>([])
const [status, setStatus] = useState<ScraperStatus | null>(null)
const [loading, setLoading] = useState(true)
const [scraping, setScraping] = useState(false)
const [extracting, setExtracting] = useState(false)
const [error, setError] = useState<string | null>(null)
const [success, setSuccess] = useState<string | null>(null)
const [results, setResults] = useState<ScrapeResult[]>([])
const [pdfResult, setPdfResult] = useState<PDFExtractionResult | null>(null)
const scraper = useComplianceScraper()
// Fetch sources
const fetchSources = useCallback(async () => {
try {
const res = await fetch(`${BACKEND_URL}/api/v1/compliance/scraper/sources`)
if (res.ok) {
const data = await res.json()
setSources(data.sources || [])
}
} catch (err) {
console.error('Failed to fetch sources:', err)
}
}, [])
// Fetch PDF documents
const fetchPdfDocuments = useCallback(async () => {
try {
const res = await fetch(`${BACKEND_URL}/api/v1/compliance/scraper/pdf-documents`)
if (res.ok) {
const data = await res.json()
setPdfDocuments(data.documents || [])
}
} catch (err) {
console.error('Failed to fetch PDF documents:', err)
}
}, [])
// Fetch status
const fetchStatus = useCallback(async () => {
try {
const res = await fetch(`${BACKEND_URL}/api/v1/compliance/scraper/status`)
if (res.ok) {
const data = await res.json()
setStatus(data)
}
} catch (err) {
console.error('Failed to fetch status:', err)
}
}, [])
// Initial load
useEffect(() => {
const loadData = async () => {
setLoading(true)
await Promise.all([fetchSources(), fetchStatus(), fetchPdfDocuments()])
setLoading(false)
}
loadData()
}, [fetchSources, fetchStatus, fetchPdfDocuments])
// Poll status while scraping
useEffect(() => {
if (scraping) {
const interval = setInterval(fetchStatus, 2000)
return () => clearInterval(interval)
}
}, [scraping, fetchStatus])
// Scrape all sources
const handleScrapeAll = async () => {
setScraping(true)
setError(null)
setSuccess(null)
setResults([])
try {
const res = await fetch(`${BACKEND_URL}/api/v1/compliance/scraper/scrape-all`, {
method: 'POST',
})
if (!res.ok) {
const data = await res.json()
throw new Error(data.detail || 'Scraping fehlgeschlagen')
}
const data = await res.json()
setResults([
...data.results.success,
...data.results.failed,
...data.results.skipped,
])
setSuccess(`Scraping abgeschlossen: ${data.results.success.length} erfolgreich, ${data.results.skipped.length} uebersprungen, ${data.results.failed.length} fehlgeschlagen`)
// Refresh sources
await fetchSources()
} catch (err: any) {
setError(err.message)
} finally {
setScraping(false)
}
}
// Scrape single source
const handleScrapeSingle = async (code: string, force: boolean = false) => {
setScraping(true)
setError(null)
setSuccess(null)
try {
const res = await fetch(`${BACKEND_URL}/api/v1/compliance/scraper/scrape/${code}?force=${force}`, {
method: 'POST',
})
if (!res.ok) {
const data = await res.json()
throw new Error(data.detail || 'Scraping fehlgeschlagen')
}
const data = await res.json()
if (data.status === 'skipped') {
setSuccess(`${code}: Bereits vorhanden (${data.requirement_count} Anforderungen)`)
} else {
setSuccess(`${code}: ${data.requirements_extracted} Anforderungen extrahiert`)
}
// Refresh sources
await fetchSources()
} catch (err: any) {
setError(err.message)
} finally {
setScraping(false)
}
}
// Extract PDF
const handleExtractPdf = async (code: string, saveToDb: boolean = true, force: boolean = false) => {
setExtracting(true)
setError(null)
setSuccess(null)
setPdfResult(null)
try {
const res = await fetch(`${BACKEND_URL}/api/v1/compliance/scraper/extract-pdf`, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
},
body: JSON.stringify({
document_code: code,
save_to_db: saveToDb,
force: force,
}),
})
if (!res.ok) {
const data = await res.json()
throw new Error(data.detail || 'PDF-Extraktion fehlgeschlagen')
}
const data: PDFExtractionResult = await res.json()
setPdfResult(data)
if (data.success) {
setSuccess(`${code}: ${data.total_aspects} Pruefaspekte extrahiert, ${data.requirements_created} Requirements erstellt`)
}
// Refresh sources
await fetchSources()
} catch (err: any) {
setError(err.message)
} finally {
setExtracting(false)
}
}
// Clear messages
useEffect(() => {
if (success) {
const timer = setTimeout(() => setSuccess(null), 5000)
return () => clearTimeout(timer)
}
}, [success])
useEffect(() => {
if (error) {
const timer = setTimeout(() => setError(null), 10000)
return () => clearTimeout(timer)
}
}, [error])
// Stats cards
const StatsCard = ({ title, value, subtitle, icon }: { title: string; value: number | string; subtitle?: string; icon: string }) => (
<div className="bg-white rounded-lg shadow-sm p-5 border border-slate-200">
<div className="flex items-center">
<div className="flex-shrink-0">
<span className="text-2xl">{icon}</span>
</div>
<div className="flex-shrink-0"><span className="text-2xl">{icon}</span></div>
<div className="ml-4">
<p className="text-sm font-medium text-slate-500">{title}</p>
<p className="text-2xl font-semibold text-slate-900">{value}</p>
@@ -294,12 +31,8 @@ export default function ComplianceScraperPage() {
)
return (
<AdminLayout
title="Compliance Scraper"
description="Extrahiert Anforderungen aus EU-Regulierungen, BSI-Standards und Gesetzen"
>
{/* Loading */}
{loading && (
<AdminLayout title="Compliance Scraper" description="Extrahiert Anforderungen aus EU-Regulierungen, BSI-Standards und Gesetzen">
{scraper.loading && (
<div className="flex items-center justify-center py-12">
<svg className="w-8 h-8 animate-spin text-primary-600" fill="none" viewBox="0 0 24 24">
<circle className="opacity-25" cx="12" cy="12" r="10" stroke="currentColor" strokeWidth="4" />
@@ -309,481 +42,77 @@ export default function ComplianceScraperPage() {
</div>
)}
{!loading && (
{!scraper.loading && (
<>
{/* Messages */}
{error && (
<div className="mb-4 bg-red-50 border border-red-200 text-red-700 px-4 py-3 rounded-lg">
{error}
</div>
{scraper.error && (
<div className="mb-4 bg-red-50 border border-red-200 text-red-700 px-4 py-3 rounded-lg">{scraper.error}</div>
)}
{success && (
<div className="mb-4 bg-green-50 border border-green-200 text-green-700 px-4 py-3 rounded-lg">
{success}
</div>
{scraper.success && (
<div className="mb-4 bg-green-50 border border-green-200 text-green-700 px-4 py-3 rounded-lg">{scraper.success}</div>
)}
{/* Stats Cards */}
<div className="grid grid-cols-1 md:grid-cols-2 lg:grid-cols-4 gap-4 mb-6">
<StatsCard
title="Bekannte Quellen"
value={sources.length}
icon="📚"
/>
<StatsCard
title="Mit Daten"
value={sources.filter(s => s.has_data).length}
subtitle={`${sources.length - sources.filter(s => s.has_data).length} noch zu scrapen`}
icon="✅"
/>
<StatsCard
title="Anforderungen gesamt"
value={sources.reduce((acc, s) => acc + s.requirement_count, 0)}
icon="📋"
/>
<StatsCard
title="Letzter Lauf"
value={status?.stats.last_run ? new Date(status.stats.last_run).toLocaleDateString('de-DE') : 'Nie'}
subtitle={status?.stats.errors ? `${status.stats.errors} Fehler` : undefined}
icon="🕐"
/>
<StatsCard title="Bekannte Quellen" value={scraper.sources.length} icon="📚" />
<StatsCard title="Mit Daten" value={scraper.sources.filter(s => s.has_data).length} subtitle={`${scraper.sources.length - scraper.sources.filter(s => s.has_data).length} noch zu scrapen`} icon="✅" />
<StatsCard title="Anforderungen gesamt" value={scraper.sources.reduce((acc, s) => acc + s.requirement_count, 0)} icon="📋" />
<StatsCard title="Letzter Lauf" value={scraper.status?.stats.last_run ? new Date(scraper.status.stats.last_run).toLocaleDateString('de-DE') : 'Nie'} subtitle={scraper.status?.stats.errors ? `${scraper.status.stats.errors} Fehler` : undefined} icon="🕐" />
</div>
{/* Scraper Status Bar */}
{(scraping || status?.status === 'running') && (
{(scraper.scraping || scraper.status?.status === 'running') && (
<div className="mb-6 p-4 bg-blue-50 border border-blue-200 rounded-lg">
<div className="flex items-center">
<div className="animate-spin rounded-full h-4 w-4 border-2 border-blue-600 border-t-transparent mr-3" />
<div>
<p className="font-medium text-blue-800">Scraper laeuft</p>
{status?.current_source && (
<p className="text-sm text-blue-600">Aktuell: {status.current_source}</p>
)}
{scraper.status?.current_source && (<p className="text-sm text-blue-600">Aktuell: {scraper.status.current_source}</p>)}
</div>
</div>
</div>
)}
{/* Tabs */}
<div className="bg-white rounded-xl shadow-sm border border-slate-200 mb-6">
<div className="border-b border-slate-200">
<nav className="flex -mb-px">
{[
{ id: 'sources', name: 'Quellen', icon: '📚' },
{ id: 'pdf', name: 'PDF-Extraktion', icon: '📄' },
{ id: 'status', name: 'Status', icon: '📊' },
{ id: 'logs', name: 'Ergebnisse', icon: '📝' },
{ id: 'sources' as const, name: 'Quellen', icon: '📚' },
{ id: 'pdf' as const, name: 'PDF-Extraktion', icon: '📄' },
{ id: 'status' as const, name: 'Status', icon: '📊' },
{ id: 'logs' as const, name: 'Ergebnisse', icon: '📝' },
].map(tab => (
<button
key={tab.id}
onClick={() => setActiveTab(tab.id as typeof activeTab)}
className={`px-6 py-4 text-sm font-medium border-b-2 transition-colors ${
activeTab === tab.id
? 'border-primary-600 text-primary-600'
: 'border-transparent text-slate-500 hover:text-slate-700 hover:border-slate-300'
}`}
>
<span className="mr-2">{tab.icon}</span>
{tab.name}
<button key={tab.id} onClick={() => scraper.setActiveTab(tab.id)} className={`px-6 py-4 text-sm font-medium border-b-2 transition-colors ${scraper.activeTab === tab.id ? 'border-primary-600 text-primary-600' : 'border-transparent text-slate-500 hover:text-slate-700 hover:border-slate-300'}`}>
<span className="mr-2">{tab.icon}</span>{tab.name}
</button>
))}
</nav>
</div>
<div className="p-6">
{/* Sources Tab */}
{activeTab === 'sources' && (
<div>
{/* Header */}
<div className="flex justify-between items-center mb-6">
<div>
<h3 className="text-lg font-semibold text-slate-900">Regulierungsquellen</h3>
<p className="text-sm text-slate-500">EU-Lex, BSI-TR und deutsche Gesetze</p>
</div>
<button
onClick={handleScrapeAll}
disabled={scraping}
className="px-4 py-2 bg-primary-600 text-white rounded-lg hover:bg-primary-700 transition-colors disabled:opacity-50 disabled:cursor-not-allowed flex items-center gap-2"
>
{scraping ? (
<>
<svg className="w-4 h-4 animate-spin" fill="none" viewBox="0 0 24 24">
<circle className="opacity-25" cx="12" cy="12" r="10" stroke="currentColor" strokeWidth="4" />
<path className="opacity-75" fill="currentColor" d="M4 12a8 8 0 018-8V0C5.373 0 0 5.373 0 12h4z" />
</svg>
Laeuft...
</>
) : (
<>
<svg className="w-4 h-4" fill="none" stroke="currentColor" viewBox="0 0 24 24">
<path strokeLinecap="round" strokeLinejoin="round" strokeWidth={2} d="M4 4v5h.582m15.356 2A8.001 8.001 0 004.582 9m0 0H9m11 11v-5h-.581m0 0a8.003 8.003 0 01-15.357-2m15.357 2H15" />
</svg>
Alle Quellen scrapen
</>
)}
</button>
</div>
{/* Sources by Type */}
<div className="space-y-6">
{/* EU Regulations */}
<div>
<h4 className="text-sm font-medium text-slate-700 mb-3 flex items-center gap-2">
<span className="text-lg">🇪🇺</span> EU-Regulierungen (EUR-Lex)
</h4>
<div className="grid gap-3">
{sources.filter(s => s.source_type === 'eur_lex').map(source => (
<SourceCard key={source.code} source={source} onScrape={handleScrapeSingle} scraping={scraping} />
))}
</div>
</div>
{/* BSI Standards */}
<div>
<h4 className="text-sm font-medium text-slate-700 mb-3 flex items-center gap-2">
<span className="text-lg">🔒</span> BSI Technical Guidelines
</h4>
<div className="grid gap-3">
{sources.filter(s => s.source_type === 'bsi_pdf').map(source => (
<SourceCard key={source.code} source={source} onScrape={handleScrapeSingle} scraping={scraping} />
))}
</div>
</div>
</div>
</div>
)}
{/* PDF Extraction Tab */}
{activeTab === 'pdf' && (
<div>
<div className="mb-6">
<h3 className="text-lg font-semibold text-slate-900">PDF-Extraktion (PyMuPDF)</h3>
<p className="text-sm text-slate-500">
Extrahiert ALLE Pruefaspekte aus BSI-TR-03161 PDFs mit Regex-Pattern-Matching
</p>
</div>
{/* PDF Documents */}
<div className="space-y-4">
{pdfDocuments.map(doc => (
<div key={doc.code} className="bg-slate-50 rounded-lg p-4 border border-slate-200">
<div className="flex items-center justify-between">
<div className="flex items-center gap-3">
<span className="text-3xl">📄</span>
<div>
<div className="flex items-center gap-2">
<span className="font-semibold text-slate-900">{doc.code}</span>
<span className={`px-2 py-0.5 rounded text-xs font-medium ${
doc.available ? 'bg-green-100 text-green-700' : 'bg-red-100 text-red-700'
}`}>
{doc.available ? 'Verfuegbar' : 'Nicht gefunden'}
</span>
</div>
<div className="text-sm text-slate-600">{doc.name}</div>
<div className="text-xs text-slate-500">{doc.description}</div>
<div className="text-xs text-slate-400 mt-1">
Erwartete Pruefaspekte: {doc.expected_aspects}
</div>
</div>
</div>
<div className="flex gap-2">
<button
onClick={() => handleExtractPdf(doc.code, true, false)}
disabled={extracting || !doc.available}
className="px-4 py-2 bg-primary-600 text-white rounded-lg hover:bg-primary-700 transition-colors disabled:opacity-50 disabled:cursor-not-allowed flex items-center gap-2"
>
{extracting ? (
<>
<svg className="w-4 h-4 animate-spin" fill="none" viewBox="0 0 24 24">
<circle className="opacity-25" cx="12" cy="12" r="10" stroke="currentColor" strokeWidth="4" />
<path className="opacity-75" fill="currentColor" d="M4 12a8 8 0 018-8V0C5.373 0 0 5.373 0 12h4z" />
</svg>
Extrahiere...
</>
) : (
<>
<svg className="w-4 h-4" fill="none" stroke="currentColor" viewBox="0 0 24 24">
<path strokeLinecap="round" strokeLinejoin="round" strokeWidth={2} d="M9 12h6m-6 4h6m2 5H7a2 2 0 01-2-2V5a2 2 0 012-2h5.586a1 1 0 01.707.293l5.414 5.414a1 1 0 01.293.707V19a2 2 0 01-2 2z" />
</svg>
Extrahieren
</>
)}
</button>
<button
onClick={() => handleExtractPdf(doc.code, true, true)}
disabled={extracting || !doc.available}
className="px-3 py-2 bg-orange-100 text-orange-700 rounded-lg hover:bg-orange-200 transition-colors disabled:opacity-50 disabled:cursor-not-allowed"
title="Force: Loescht vorhandene und extrahiert neu"
>
Force
</button>
</div>
</div>
</div>
))}
</div>
{/* Last Extraction Result */}
{pdfResult && (
<div className="mt-6 bg-green-50 rounded-lg p-4 border border-green-200">
<h4 className="font-semibold text-green-800 mb-3">Letztes Extraktions-Ergebnis</h4>
<div className="grid grid-cols-3 gap-4 mb-4">
<div className="text-center p-3 bg-white rounded-lg">
<div className="text-2xl font-bold text-green-700">{pdfResult.total_aspects}</div>
<div className="text-sm text-slate-500">Pruefaspekte gefunden</div>
</div>
<div className="text-center p-3 bg-white rounded-lg">
<div className="text-2xl font-bold text-blue-700">{pdfResult.requirements_created}</div>
<div className="text-sm text-slate-500">Requirements erstellt</div>
</div>
<div className="text-center p-3 bg-white rounded-lg">
<div className="text-2xl font-bold text-slate-700">{Object.keys(pdfResult.statistics.by_category || {}).length}</div>
<div className="text-sm text-slate-500">Kategorien</div>
</div>
</div>
{/* Category Breakdown */}
{pdfResult.statistics.by_category && Object.keys(pdfResult.statistics.by_category).length > 0 && (
<div>
<h5 className="text-sm font-medium text-slate-700 mb-2">Nach Kategorie:</h5>
<div className="flex flex-wrap gap-2">
{Object.entries(pdfResult.statistics.by_category).map(([cat, count]) => (
<span key={cat} className="px-2 py-1 bg-white rounded text-xs text-slate-600">
{cat}: <strong>{count}</strong>
</span>
))}
</div>
</div>
)}
</div>
)}
{/* Info Box */}
<div className="mt-6 bg-blue-50 rounded-lg p-4 border border-blue-200">
<h4 className="font-semibold text-blue-800 mb-2">Wie funktioniert die PDF-Extraktion?</h4>
<ul className="text-sm text-blue-700 space-y-1">
<li> <strong>PyMuPDF (fitz)</strong> liest den PDF-Text</li>
<li> <strong>Regex-Pattern</strong> finden Aspekte wie O.Auth_1, O.Sess_2, T.Network_1</li>
<li> <strong>Kontextanalyse</strong> extrahiert Titel, Kategorie und Anforderungsstufe (MUSS/SOLL/KANN)</li>
<li> <strong>Automatische Speicherung</strong> erstellt Requirements in der Datenbank</li>
</ul>
</div>
</div>
)}
{/* Status Tab */}
{activeTab === 'status' && status && (
<div className="space-y-6">
{/* Current Status */}
<div className="bg-slate-50 rounded-lg p-6">
<div className="flex items-center justify-between mb-4">
<div>
<h3 className="text-lg font-semibold text-slate-900">Scraper-Status</h3>
<p className="text-sm text-slate-500">
Letzter Lauf: {status.stats.last_run ? new Date(status.stats.last_run).toLocaleString('de-DE') : 'Noch nie'}
</p>
</div>
<div className={`px-3 py-1.5 rounded-full text-sm font-medium ${
status.status === 'running' ? 'bg-blue-100 text-blue-700' :
status.status === 'error' ? 'bg-red-100 text-red-700' :
status.status === 'completed' ? 'bg-green-100 text-green-700' :
'bg-gray-100 text-gray-700'
}`}>
{status.status === 'running' ? '🔄 Laeuft' :
status.status === 'error' ? '❌ Fehler' :
status.status === 'completed' ? '✅ Abgeschlossen' :
'⏸️ Bereit'}
</div>
</div>
<div className="grid grid-cols-3 gap-4">
<div className="text-center p-4 bg-white rounded-lg">
<div className="text-2xl font-bold text-slate-900">{status.stats.sources_processed}</div>
<div className="text-sm text-slate-500">Quellen verarbeitet</div>
</div>
<div className="text-center p-4 bg-white rounded-lg">
<div className="text-2xl font-bold text-green-600">{status.stats.requirements_extracted}</div>
<div className="text-sm text-slate-500">Anforderungen extrahiert</div>
</div>
<div className="text-center p-4 bg-white rounded-lg">
<div className="text-2xl font-bold text-red-600">{status.stats.errors}</div>
<div className="text-sm text-slate-500">Fehler</div>
</div>
</div>
{status.last_error && (
<div className="mt-4 p-3 bg-red-50 rounded-lg text-sm text-red-700">
<strong>Letzter Fehler:</strong> {status.last_error}
</div>
)}
</div>
{/* Process Description */}
<div className="bg-white border border-slate-200 rounded-lg p-6">
<h4 className="font-semibold text-slate-900 mb-4">Wie funktioniert der Scraper?</h4>
<div className="space-y-3 text-sm text-slate-600">
<div className="flex items-start gap-3">
<div className="w-6 h-6 bg-blue-100 rounded-full flex items-center justify-center text-blue-600 font-bold">1</div>
<div>
<strong>EUR-Lex Abruf</strong>: Holt HTML-Version der EU-Verordnung, extrahiert Artikel und Absaetze
</div>
</div>
<div className="flex items-start gap-3">
<div className="w-6 h-6 bg-blue-100 rounded-full flex items-center justify-center text-blue-600 font-bold">2</div>
<div>
<strong>BSI-TR Parsing</strong>: Extrahiert Pruefaspekte (O.Auth_1, O.Sess_1, etc.) aus den TR-Dokumenten
</div>
</div>
<div className="flex items-start gap-3">
<div className="w-6 h-6 bg-blue-100 rounded-full flex items-center justify-center text-blue-600 font-bold">3</div>
<div>
<strong>Datenbank-Speicherung</strong>: Jede Anforderung wird als Requirement in der Compliance-DB gespeichert
</div>
</div>
<div className="flex items-start gap-3">
<div className="w-6 h-6 bg-green-100 rounded-full flex items-center justify-center text-green-600 font-bold"></div>
<div>
<strong>Audit-Workspace</strong>: Anforderungen koennen mit Implementierungsdetails angereichert werden
</div>
</div>
</div>
</div>
</div>
)}
{/* Results Tab */}
{activeTab === 'logs' && (
<div>
<h3 className="text-lg font-semibold text-slate-900 mb-4">Letzte Ergebnisse</h3>
{results.length === 0 ? (
<div className="text-center py-12 text-slate-500">
Keine Ergebnisse vorhanden. Starte einen Scrape-Vorgang.
</div>
) : (
<div className="space-y-2">
{results.map((result, idx) => (
<div
key={idx}
className={`p-3 rounded-lg flex items-center justify-between ${
result.error ? 'bg-red-50' :
result.reason ? 'bg-yellow-50' :
'bg-green-50'
}`}
>
<div className="flex items-center gap-3">
<span className="text-lg">
{result.error ? '❌' : result.reason ? '⏭️' : '✅'}
</span>
<span className="font-medium">{result.code}</span>
<span className="text-sm text-slate-500">
{result.error || result.reason || `${result.requirements_extracted} Anforderungen`}
</span>
</div>
</div>
))}
</div>
)}
</div>
)}
<ScraperTabs
activeTab={scraper.activeTab}
sources={scraper.sources}
pdfDocuments={scraper.pdfDocuments}
status={scraper.status}
scraping={scraper.scraping}
extracting={scraper.extracting}
results={scraper.results}
pdfResult={scraper.pdfResult}
handleScrapeAll={scraper.handleScrapeAll}
handleScrapeSingle={scraper.handleScrapeSingle}
handleExtractPdf={scraper.handleExtractPdf}
/>
</div>
</div>
</>
)}
{/* System Info Section */}
<div className="mt-8 border-t border-slate-200 pt-8">
<SystemInfoSection config={SYSTEM_INFO_CONFIGS.complianceScraper || {
title: 'Compliance Scraper',
description: 'Regulation & Requirements Extraction Service',
version: '1.0.0',
features: [
'EUR-Lex HTML Parsing',
'BSI-TR PDF Extraction',
'Automatic Requirement Mapping',
'Incremental Updates',
],
technicalDetails: {
'Backend': 'Python/FastAPI',
'HTTP Client': 'httpx async',
'HTML Parser': 'BeautifulSoup4',
'PDF Parser': 'PyMuPDF (optional)',
'Database': 'PostgreSQL',
},
features: ['EUR-Lex HTML Parsing', 'BSI-TR PDF Extraction', 'Automatic Requirement Mapping', 'Incremental Updates'],
technicalDetails: { 'Backend': 'Python/FastAPI', 'HTTP Client': 'httpx async', 'HTML Parser': 'BeautifulSoup4', 'PDF Parser': 'PyMuPDF (optional)', 'Database': 'PostgreSQL' },
}} />
</div>
</AdminLayout>
)
}
/**
 * Card for a single scraper source.
 *
 * Shows the source code with a regulation-type and a source-type badge, the
 * URL (shortened for display), either the extracted requirement count or a
 * "no data" pill, and the scrape buttons.
 *
 * - "Scrapen" calls `onScrape(source.code, false)`.
 * - "Force" calls `onScrape(source.code, true)` and is only rendered when
 *   `source.has_data` is truthy.
 * - Both buttons are disabled while `scraping` is true.
 */
function SourceCard({
  source,
  onScrape,
  scraping
}: {
  source: Source
  onScrape: (code: string, force: boolean) => void
  scraping: boolean
}) {
  // Unknown types fall back to the 'industry_standard' / 'manual' badge entries.
  const regulationBadge = regulationTypeBadge[source.regulation_type] || regulationTypeBadge.industry_standard
  const originBadge = sourceTypeBadge[source.source_type] || sourceTypeBadge.manual

  // URLs longer than 60 characters are cut and suffixed with '...';
  // the full URL stays available through the title tooltip.
  const shownUrl = source.url.length > 60 ? source.url.substring(0, 60) + '...' : source.url

  const pillBase = 'px-2 py-0.5 rounded text-xs font-medium'
  const triggerScrape = (force: boolean) => () => onScrape(source.code, force)

  const dataPill = source.has_data ? (
    <span className="px-3 py-1 bg-green-100 text-green-700 rounded-full text-sm font-medium">
      {source.requirement_count} Anforderungen
    </span>
  ) : (
    <span className="px-3 py-1 bg-gray-100 text-gray-500 rounded-full text-sm">
      Keine Daten
    </span>
  )

  return (
    <div className="bg-white border border-slate-200 rounded-lg p-4 hover:shadow-sm transition-shadow">
      <div className="flex items-center justify-between">
        <div className="flex items-center gap-3">
          <span className="text-2xl">{regulationBadge.icon}</span>
          <div>
            <div className="flex items-center gap-2">
              <span className="font-semibold text-slate-900">{source.code}</span>
              <span className={`${pillBase} ${regulationBadge.color}`}>
                {regulationBadge.label}
              </span>
              <span className={`${pillBase} ${originBadge.color}`}>
                {originBadge.label}
              </span>
            </div>
            <div className="text-sm text-slate-500 truncate max-w-md" title={source.url}>
              {shownUrl}
            </div>
          </div>
        </div>
        <div className="flex items-center gap-3">
          {dataPill}
          <div className="flex gap-1">
            <button
              onClick={triggerScrape(false)}
              disabled={scraping}
              className="px-3 py-1.5 text-sm bg-slate-100 text-slate-700 rounded hover:bg-slate-200 transition-colors disabled:opacity-50 disabled:cursor-not-allowed"
              title="Scrapen (ueberspringt vorhandene)"
            >
              Scrapen
            </button>
            {source.has_data && (
              <button
                onClick={triggerScrape(true)}
                disabled={scraping}
                className="px-3 py-1.5 text-sm bg-orange-100 text-orange-700 rounded hover:bg-orange-200 transition-colors disabled:opacity-50 disabled:cursor-not-allowed"
                title="Force: Loescht vorhandene Daten und scraped neu"
              >
                Force
              </button>
            )}
          </div>
        </div>
      </div>
    </div>
  )
}
@@ -0,0 +1,184 @@
'use client'
import { WebsiteContent, FeatureContent } from '@/lib/content-types'
import HeroEditor from './HeroEditor'
/**
 * Props for ContentEditorTabs: the current tab, the content being edited,
 * layout/translation helpers, and one update callback per content section.
 */
interface ContentEditorTabsProps {
  // Tab to render: 'hero', 'features', 'faq', 'pricing'; any other value renders the trust/testimonial editor.
  activeTab: string
  content: WebsiteContent
  // Switches text inputs to dir="rtl" and mirrors flex rows.
  isRTL: boolean
  // Looks up a UI string by key (e.g. 'admin_tab_features').
  t: (key: string) => string
  // Forwarded unchanged to HeroEditor, so the type is derived from HeroEditor's
  // own props instead of `any`; field names stay checked against HeroContent.
  updateHero: Parameters<typeof HeroEditor>[0]['updateHero']
  updateFeature: (index: number, field: keyof FeatureContent, value: string) => void
  updateFAQ: (index: number, field: 'question' | 'answer', value: string | string[]) => void
  addFAQ: () => void
  removeFAQ: (index: number) => void
  // `field` may be a top-level plan key or a dotted 'features.<name>' path.
  updatePricing: (index: number, field: string, value: string | number | boolean) => void
  updateTrust: (key: 'item1' | 'item2' | 'item3', field: 'value' | 'label', value: string) => void
  updateTestimonial: (field: 'quote' | 'author' | 'role', value: string) => void
}
/**
 * Renders the editor form for the currently active content tab.
 *
 * 'hero' delegates to HeroEditor; 'features', 'faq' and 'pricing' render one
 * card per list entry; every other tab value renders the trust indicators and
 * the testimonial. All inputs are controlled: values come from `content`, and
 * changes are reported through the update callbacks in `props`.
 */
export default function ContentEditorTabs(props: ContentEditorTabsProps) {
  const {
    activeTab, content, isRTL, t,
    updateHero, updateFeature, updateFAQ, addFAQ, removeFAQ,
    updatePricing, updateTrust, updateTestimonial,
  } = props
  const dir = isRTL ? 'rtl' : 'ltr'

  // Class strings shared by the form controls below.
  const headingCls = 'text-xl font-semibold text-slate-900'
  const cardCls = 'border border-slate-200 rounded-lg p-4'
  const labelCls = 'block text-sm font-medium text-slate-700 mb-1'
  const inputCls = 'w-full px-3 py-2 border border-slate-300 rounded-lg focus:ring-2 focus:ring-primary-500 focus:border-primary-500'
  const monoCls = `${inputCls} font-mono text-sm`
  // Mirrors flex rows in right-to-left mode.
  const rowFlip = isRTL ? 'flex-row-reverse' : ''

  switch (activeTab) {
    case 'hero':
      return <HeroEditor content={content} isRTL={isRTL} t={t} updateHero={updateHero} />

    case 'features':
      return (
        <div className="space-y-6">
          <h2 className={headingCls}>{t('admin_tab_features')}</h2>
          {content.features.map((feature, index) => (
            <div key={feature.id} className={cardCls}>
              <div className="grid gap-4">
                <div className="grid grid-cols-3 gap-4">
                  <div>
                    <label className={labelCls}>Icon</label>
                    <input
                      type="text"
                      value={feature.icon}
                      onChange={(e) => updateFeature(index, 'icon', e.target.value)}
                      className="w-full px-3 py-2 border border-slate-300 rounded-lg text-2xl text-center"
                    />
                  </div>
                  <div className="col-span-2">
                    <label className={labelCls}>Titel</label>
                    <input
                      type="text"
                      value={feature.title}
                      onChange={(e) => updateFeature(index, 'title', e.target.value)}
                      className={inputCls}
                      dir={dir}
                    />
                  </div>
                </div>
                <div>
                  <label className={labelCls}>Beschreibung</label>
                  <textarea
                    value={feature.description}
                    onChange={(e) => updateFeature(index, 'description', e.target.value)}
                    rows={2}
                    className={inputCls}
                    dir={dir}
                  />
                </div>
              </div>
            </div>
          ))}
        </div>
      )

    case 'faq':
      return (
        <div className="space-y-6">
          <div className={`flex items-center justify-between ${rowFlip}`}>
            <h2 className={headingCls}>{t('admin_tab_faq')}</h2>
            <button onClick={addFAQ} className="px-4 py-2 bg-slate-100 text-slate-700 rounded-lg hover:bg-slate-200 transition-colors">{t('admin_add_faq')}</button>
          </div>
          {content.faq.map((item, index) => (
            <div key={index} className={cardCls}>
              <div className={`flex items-start justify-between gap-4 ${rowFlip}`}>
                <div className="flex-1 space-y-4">
                  <div>
                    <label className={labelCls}>{t('admin_question')} {index + 1}</label>
                    <input
                      type="text"
                      value={item.question}
                      onChange={(e) => updateFAQ(index, 'question', e.target.value)}
                      className={inputCls}
                      dir={dir}
                    />
                  </div>
                  <div>
                    <label className={labelCls}>{t('admin_answer')}</label>
                    {/* The answer is stored as a list of lines and edited as one newline-joined text. */}
                    <textarea
                      value={item.answer.join('\n')}
                      onChange={(e) => updateFAQ(index, 'answer', e.target.value)}
                      rows={4}
                      className={monoCls}
                      dir={dir}
                    />
                  </div>
                </div>
                <button onClick={() => removeFAQ(index)} className="p-2 text-red-600 hover:bg-red-50 rounded-lg transition-colors" title="Frage entfernen">
                  <svg className="w-5 h-5" fill="none" viewBox="0 0 24 24" stroke="currentColor">
                    <path strokeLinecap="round" strokeLinejoin="round" strokeWidth={2} d="M19 7l-.867 12.142A2 2 0 0116.138 21H7.862a2 2 0 01-1.995-1.858L5 7m5 4v6m4-6v6m1-10V4a1 1 0 00-1-1h-4a1 1 0 00-1 1v3M4 7h16" />
                  </svg>
                </button>
              </div>
            </div>
          ))}
        </div>
      )

    case 'pricing':
      return (
        <div className="space-y-6">
          <h2 className={headingCls}>{t('admin_tab_pricing')}</h2>
          {content.pricing.map((plan, index) => (
            <div key={plan.id} className={cardCls}>
              <div className="grid gap-4">
                <div className="grid grid-cols-4 gap-4">
                  <div>
                    <label className={labelCls}>Name</label>
                    <input
                      type="text"
                      value={plan.name}
                      onChange={(e) => updatePricing(index, 'name', e.target.value)}
                      className={inputCls}
                      dir={dir}
                    />
                  </div>
                  <div>
                    <label className={labelCls}>Preis (EUR)</label>
                    <input
                      type="number"
                      step="0.01"
                      value={plan.price}
                      onChange={(e) => updatePricing(index, 'price', e.target.value)}
                      className={inputCls}
                    />
                  </div>
                  <div>
                    <label className={labelCls}>Intervall</label>
                    <input
                      type="text"
                      value={plan.interval}
                      onChange={(e) => updatePricing(index, 'interval', e.target.value)}
                      className={inputCls}
                      dir={dir}
                    />
                  </div>
                  <div className="flex items-end">
                    <label className="flex items-center gap-2">
                      <input
                        type="checkbox"
                        checked={plan.popular || false}
                        onChange={(e) => updatePricing(index, 'popular', e.target.checked)}
                        className="w-4 h-4 text-primary-600 rounded"
                      />
                      <span className="text-sm text-slate-700">{t('pricing_popular')}</span>
                    </label>
                  </div>
                </div>
                <div>
                  <label className={labelCls}>Beschreibung</label>
                  <input
                    type="text"
                    value={plan.description}
                    onChange={(e) => updatePricing(index, 'description', e.target.value)}
                    className={inputCls}
                    dir={dir}
                  />
                </div>
                <div className="grid grid-cols-2 gap-4">
                  <div>
                    <label className={labelCls}>{t('pricing_tasks')}</label>
                    <input
                      type="text"
                      value={plan.features.tasks}
                      onChange={(e) => updatePricing(index, 'features.tasks', e.target.value)}
                      className={inputCls}
                      dir={dir}
                    />
                  </div>
                  <div>
                    <label className={labelCls}>Aufgaben-Beschreibung</label>
                    <input
                      type="text"
                      value={plan.features.taskDescription}
                      onChange={(e) => updatePricing(index, 'features.taskDescription', e.target.value)}
                      className={inputCls}
                      dir={dir}
                    />
                  </div>
                </div>
                <div>
                  <label className={labelCls}>Features (eine pro Zeile)</label>
                  <textarea
                    value={plan.features.included.join('\n')}
                    onChange={(e) => updatePricing(index, 'features.included', e.target.value)}
                    rows={4}
                    className={monoCls}
                    dir={dir}
                  />
                </div>
              </div>
            </div>
          ))}
        </div>
      )

    default:
      // 'other' tab: trust indicators and testimonial.
      return (
        <div className="space-y-8">
          <div>
            <h2 className={`${headingCls} mb-4`}>Trust Indicators</h2>
            <div className="grid grid-cols-3 gap-4">
              {(['item1', 'item2', 'item3'] as const).map((key, index) => (
                <div key={key} className={cardCls}>
                  <div className="space-y-4">
                    <div>
                      <label className={labelCls}>Wert {index + 1}</label>
                      <input
                        type="text"
                        value={content.trust[key].value}
                        onChange={(e) => updateTrust(key, 'value', e.target.value)}
                        className={inputCls}
                        dir={dir}
                      />
                    </div>
                    <div>
                      <label className={labelCls}>Label {index + 1}</label>
                      <input
                        type="text"
                        value={content.trust[key].label}
                        onChange={(e) => updateTrust(key, 'label', e.target.value)}
                        className={inputCls}
                        dir={dir}
                      />
                    </div>
                  </div>
                </div>
              ))}
            </div>
          </div>
          <div>
            <h2 className={`${headingCls} mb-4`}>Testimonial</h2>
            <div className={`${cardCls} space-y-4`}>
              <div>
                <label className={labelCls}>Zitat</label>
                <textarea
                  value={content.testimonial.quote}
                  onChange={(e) => updateTestimonial('quote', e.target.value)}
                  rows={3}
                  className={inputCls}
                  dir={dir}
                />
              </div>
              <div className="grid grid-cols-2 gap-4">
                <div>
                  <label className={labelCls}>Autor</label>
                  <input
                    type="text"
                    value={content.testimonial.author}
                    onChange={(e) => updateTestimonial('author', e.target.value)}
                    className={inputCls}
                    dir={dir}
                  />
                </div>
                <div>
                  <label className={labelCls}>Rolle</label>
                  <input
                    type="text"
                    value={content.testimonial.role}
                    onChange={(e) => updateTestimonial('role', e.target.value)}
                    className={inputCls}
                    dir={dir}
                  />
                </div>
              </div>
            </div>
          </div>
        </div>
      )
  }
}
@@ -0,0 +1,106 @@
'use client'
import { WebsiteContent, HeroContent } from '@/lib/content-types'
// Props for HeroEditor.
interface HeroEditorProps {
// Full website content; the editor reads its values from content.hero.
content: WebsiteContent
// When true, the text inputs are rendered with dir="rtl" instead of "ltr".
isRTL: boolean
// Looks up a UI string by key (used for the 'admin_tab_hero' heading).
t: (key: string) => string
// Called on every input change with the hero field name and the new text.
updateHero: (field: keyof HeroContent, value: string) => void
}
/**
 * Form for the hero section of the website content.
 *
 * Renders controlled inputs for badge, title, the two title highlights, the
 * subtitle (textarea) and the three call-to-action texts. Every change is
 * reported through `updateHero(field, newValue)`; the component keeps no state
 * of its own.
 */
export default function HeroEditor({ content, isRTL, t, updateHero }: HeroEditorProps) {
  const dir = isRTL ? 'rtl' : 'ltr'
  const labelCls = 'block text-sm font-medium text-slate-700 mb-1'
  const inputCls = 'w-full px-3 py-2 border border-slate-300 rounded-lg focus:ring-2 focus:ring-primary-500 focus:border-primary-500'

  // Hero fields that are edited through a single-line text input.
  type SingleLineField =
    | 'badge'
    | 'title'
    | 'titleHighlight1'
    | 'titleHighlight2'
    | 'ctaPrimary'
    | 'ctaSecondary'
    | 'ctaHint'

  // One label + text input pair bound to content.hero[field].
  const textField = (label: string, field: SingleLineField) => (
    <div>
      <label className={labelCls}>{label}</label>
      <input
        type="text"
        value={content.hero[field]}
        onChange={(e) => updateHero(field, e.target.value)}
        className={inputCls}
        dir={dir}
      />
    </div>
  )

  return (
    <div className="space-y-6">
      <h2 className="text-xl font-semibold text-slate-900">{t('admin_tab_hero')} Section</h2>
      <div className="grid gap-4">
        {textField('Badge', 'badge')}
        {textField('Titel (vor Highlight)', 'title')}
        <div className="grid grid-cols-2 gap-4">
          {textField('Highlight 1', 'titleHighlight1')}
          {textField('Highlight 2', 'titleHighlight2')}
        </div>
        <div>
          <label className={labelCls}>Untertitel</label>
          <textarea
            value={content.hero.subtitle}
            onChange={(e) => updateHero('subtitle', e.target.value)}
            rows={3}
            className={inputCls}
            dir={dir}
          />
        </div>
        <div className="grid grid-cols-3 gap-4">
          {textField('CTA Primaer', 'ctaPrimary')}
          {textField('CTA Sekundaer', 'ctaSecondary')}
          {textField('CTA Hinweis', 'ctaHint')}
        </div>
      </div>
    </div>
  )
}
@@ -0,0 +1,74 @@
'use client'
import { RefObject } from 'react'
// Props for LivePreviewPanel.
interface LivePreviewPanelProps {
// Currently edited tab; selects the displayed labels and is embedded in the iframe URL.
activeTab: string
// Ref attached to the preview iframe; also used by the reload button.
iframeRef: RefObject<HTMLIFrameElement | null>
}
// Short tab names shown in the preview header.
const PREVIEW_HEADER_LABELS = new Map<string, string>([
  ['hero', 'Hero Section'],
  ['features', 'Features'],
  ['faq', 'FAQ'],
  ['pricing', 'Pricing'],
  ['other', 'Trust & Testimonial'],
])

// Longer tab names shown in the "Du bearbeitest" banner.
const PREVIEW_BANNER_LABELS = new Map<string, string>([
  ['hero', 'Hero Section (Startbereich)'],
  ['features', 'Features (Funktionen)'],
  ['faq', 'FAQ (Haeufige Fragen)'],
  ['pricing', 'Pricing (Preise)'],
  ['other', 'Trust & Testimonial'],
])

/**
 * Browser-style frame with a scaled-down live preview of the website.
 *
 * The iframe loads `/?preview=true&section=<tab>#<tab>`, so its URL changes
 * with `activeTab`. The header shows the tab's short label and a reload
 * button; a banner at the bottom names the section being edited. Tabs without
 * a label entry render no label text.
 *
 * NOTE(review): the iframe combines `allow-same-origin` with `allow-scripts`
 * on same-origin content; confirm this sandbox setting is intended.
 */
export default function LivePreviewPanel({ activeTab, iframeRef }: LivePreviewPanelProps) {
  const headerLabel = PREVIEW_HEADER_LABELS.get(activeTab)
  const bannerLabel = PREVIEW_BANNER_LABELS.get(activeTab)
  const previewUrl = `/?preview=true&section=${activeTab}#${activeTab}`

  // Reloads the framed document; does nothing while the iframe is not mounted.
  const reloadPreview = () => iframeRef.current?.contentWindow?.location.reload()

  return (
    <div className="bg-white rounded-xl border border-slate-200 shadow-sm overflow-hidden">
      {/* Preview Header */}
      <div className="bg-slate-50 border-b border-slate-200 px-4 py-3 flex items-center justify-between">
        <div className="flex items-center gap-2">
          <div className="flex gap-1.5">
            <div className="w-3 h-3 rounded-full bg-red-400"></div>
            <div className="w-3 h-3 rounded-full bg-yellow-400"></div>
            <div className="w-3 h-3 rounded-full bg-green-400"></div>
          </div>
          <span className="text-xs text-slate-500 ml-2">localhost:3000</span>
        </div>
        <div className="flex items-center gap-2">
          <span className="text-xs font-medium text-slate-600 bg-slate-200 px-2 py-1 rounded">
            {headerLabel}
          </span>
          <button
            onClick={reloadPreview}
            className="p-1.5 text-slate-500 hover:text-slate-700 hover:bg-slate-200 rounded transition-colors"
            title="Preview neu laden"
          >
            <svg className="w-4 h-4" fill="none" viewBox="0 0 24 24" stroke="currentColor">
              <path strokeLinecap="round" strokeLinejoin="round" strokeWidth={2} d="M4 4v5h.582m15.356 2A8.001 8.001 0 004.582 9m0 0H9m11 11v-5h-.581m0 0a8.003 8.003 0 01-15.357-2m15.357 2H15" />
            </svg>
          </button>
        </div>
      </div>
      {/* Preview Frame */}
      <div className="relative h-[calc(100vh-340px)] bg-slate-100">
        {/* Rendered at 133.33% size and scaled to 0.75 so the page fits the panel. */}
        <iframe
          ref={iframeRef}
          src={previewUrl}
          className="w-full h-full border-0 scale-75 origin-top-left"
          style={{
            width: '133.33%',
            height: '133.33%',
            transform: 'scale(0.75)',
            transformOrigin: 'top left',
          }}
          title="Website Preview"
          sandbox="allow-same-origin allow-scripts"
        />
        <div className="absolute bottom-4 left-4 right-4 bg-blue-600 text-white px-4 py-2 rounded-lg shadow-lg flex items-center gap-2 text-sm">
          <svg className="w-4 h-4" fill="none" viewBox="0 0 24 24" stroke="currentColor">
            <path strokeLinecap="round" strokeLinejoin="round" strokeWidth={2} d="M13 16h-1v-4h-1m1-4h.01M21 12a9 9 0 11-18 0 9 9 0 0118 0z" />
          </svg>
          <span>
            Du bearbeitest: <strong>{bannerLabel}</strong>
          </span>
        </div>
      </div>
    </div>
  )
}
@@ -0,0 +1,11 @@
// Value sent as the 'x-admin-key' request header when saving content.
// NOTE(review): this constant is imported by a 'use client' module, so it
// presumably ends up in the browser bundle; confirm it is not relied on as a secret.
export const ADMIN_KEY = 'breakpilot-admin-2024'

// Editor tabs of the content admin page.
export type ContentTab = 'hero' | 'features' | 'faq' | 'pricing' | 'other'

// Builds the map entry for the page section with the given element id.
const sectionFor = (id: string): { selector: string; scrollTo: string } => ({
  selector: `#${id}`,
  scrollTo: id,
})

// Editor tab -> website section (CSS selector and scroll target).
// The 'other' tab points at the 'trust' section.
export const SECTION_MAP: Record<string, { selector: string; scrollTo: string }> = {
  hero: sectionFor('hero'),
  features: sectionFor('features'),
  faq: sectionFor('faq'),
  pricing: sectionFor('pricing'),
  other: sectionFor('trust'),
}
@@ -0,0 +1,173 @@
'use client'
import { useState, useEffect, useRef, useCallback } from 'react'
import { WebsiteContent, HeroContent, FeatureContent } from '@/lib/content-types'
import { useLanguage } from '@/lib/LanguageContext'
import { ADMIN_KEY, SECTION_MAP, ContentTab } from './types'
/**
 * State and actions for the website-content admin editor.
 *
 * Loads the content once from GET /api/content, keeps an editable copy in
 * state, and saves it with POST /api/content (JSON body, 'x-admin-key'
 * header). Whenever the active tab changes, a `{ type: 'scrollTo', section }`
 * message is posted to the preview iframe.
 *
 * All `update*` / `add*` / `remove*` actions are no-ops until content has been
 * loaded. They use functional state updates, so several edits issued before
 * the next render are applied on top of each other instead of the last one
 * overwriting the earlier ones.
 *
 * Returns the language helpers from `useLanguage`, the editor state
 * (`content`, `loading`, `saving`, `message`, `activeTab`, `showPreview`),
 * the iframe ref, and the actions.
 */
export function useContentEditor() {
  const { language, setLanguage, t, isRTL } = useLanguage()
  const [content, setContent] = useState<WebsiteContent | null>(null)
  const [loading, setLoading] = useState(true)
  const [saving, setSaving] = useState(false)
  const [message, setMessage] = useState<{ type: 'success' | 'error'; text: string } | null>(null)
  const [activeTab, setActiveTab] = useState<ContentTab>('hero')
  const [showPreview, setShowPreview] = useState(true)
  const iframeRef = useRef<HTMLIFrameElement>(null)

  // Asks the preview iframe to scroll to the section mapped to `tab`.
  const scrollToSection = useCallback((tab: string) => {
    if (!iframeRef.current?.contentWindow) return
    const section = SECTION_MAP[tab]
    if (section) {
      try {
        // NOTE(review): '*' delivers the message whatever document the iframe
        // currently shows; the payload is only a section name.
        iframeRef.current.contentWindow.postMessage(
          { type: 'scrollTo', section: section.scrollTo },
          '*'
        )
      } catch {
        // Best effort: the preview simply does not scroll.
      }
    }
  }, [])

  useEffect(() => {
    scrollToSection(activeTab)
  }, [activeTab, scrollToSection])

  // Load the content once on mount.
  useEffect(() => {
    loadContent()
  }, [])

  // Applies `patch` to the most recent content; does nothing while no content is loaded.
  function patchContent(patch: (current: WebsiteContent) => WebsiteContent) {
    setContent((current) => (current ? patch(current) : current))
  }

  async function loadContent() {
    try {
      const res = await fetch('/api/content')
      if (res.ok) {
        const data = await res.json()
        setContent(data)
      } else {
        setMessage({ type: 'error', text: t('admin_error') })
      }
    } catch {
      setMessage({ type: 'error', text: t('admin_error') })
    } finally {
      setLoading(false)
    }
  }

  async function saveChanges() {
    if (!content) return
    setSaving(true)
    setMessage(null)
    try {
      const res = await fetch('/api/content', {
        method: 'POST',
        headers: {
          'Content-Type': 'application/json',
          'x-admin-key': ADMIN_KEY,
        },
        body: JSON.stringify(content),
      })
      if (res.ok) {
        setMessage({ type: 'success', text: t('admin_saved') })
      } else {
        // Prefer the server's error text; a non-JSON body falls through to the catch below.
        const error = await res.json()
        setMessage({ type: 'error', text: error.error || t('admin_error') })
      }
    } catch {
      setMessage({ type: 'error', text: t('admin_error') })
    } finally {
      setSaving(false)
    }
  }

  function updateHero(field: keyof HeroContent, value: string) {
    patchContent((current) => ({ ...current, hero: { ...current.hero, [field]: value } }))
  }

  function updateFeature(index: number, field: keyof FeatureContent, value: string) {
    patchContent((current) => {
      const features = [...current.features]
      features[index] = { ...features[index], [field]: value }
      return { ...current, features }
    })
  }

  // `answer` is stored as a list of lines: a string is split on '\n', an
  // array is stored as given. `question` only accepts a string.
  function updateFAQ(index: number, field: 'question' | 'answer', value: string | string[]) {
    patchContent((current) => {
      const faq = [...current.faq]
      if (field === 'answer') {
        faq[index] = { ...faq[index], answer: Array.isArray(value) ? value : value.split('\n') }
      } else if (typeof value === 'string') {
        faq[index] = { ...faq[index], question: value }
      }
      return { ...current, faq }
    })
  }

  function addFAQ() {
    patchContent((current) => ({
      ...current,
      faq: [...current.faq, { question: 'Neue Frage?', answer: ['Antwort hier...'] }],
    }))
  }

  function removeFAQ(index: number) {
    patchContent((current) => ({
      ...current,
      faq: current.faq.filter((_, i) => i !== index),
    }))
  }

  // `field` is a top-level plan key or a dotted 'features.<name>' path.
  // 'price' is coerced with Number(), 'popular' with Boolean(), and
  // 'features.included' is split into one entry per line.
  function updatePricing(index: number, field: string, value: string | number | boolean) {
    patchContent((current) => {
      const pricing = [...current.pricing]
      const plan = pricing[index]
      if (field === 'price') {
        pricing[index] = { ...plan, price: Number(value) }
      } else if (field === 'popular') {
        pricing[index] = { ...plan, popular: Boolean(value) }
      } else if (field.startsWith('features.')) {
        const subField = field.replace('features.', '')
        if (subField === 'included' && typeof value === 'string') {
          pricing[index] = {
            ...plan,
            features: { ...plan.features, included: value.split('\n') },
          }
        } else {
          pricing[index] = {
            ...plan,
            features: { ...plan.features, [subField]: value },
          }
        }
      } else {
        pricing[index] = { ...plan, [field]: value }
      }
      return { ...current, pricing }
    })
  }

  function updateTrust(key: 'item1' | 'item2' | 'item3', field: 'value' | 'label', value: string) {
    patchContent((current) => ({
      ...current,
      trust: { ...current.trust, [key]: { ...current.trust[key], [field]: value } },
    }))
  }

  function updateTestimonial(field: 'quote' | 'author' | 'role', value: string) {
    patchContent((current) => ({
      ...current,
      testimonial: { ...current.testimonial, [field]: value },
    }))
  }

  return {
    language, setLanguage, t, isRTL,
    content, loading, saving, message,
    activeTab, setActiveTab,
    showPreview, setShowPreview,
    iframeRef,
    saveChanges,
    updateHero, updateFeature, updateFAQ,
    addFAQ, removeFAQ, updatePricing,
    updateTrust, updateTestimonial,
  }
}
+51 -731
View File
@@ -4,233 +4,53 @@
* Admin Panel fuer Website-Content
*
* Erlaubt das Bearbeiten aller Website-Texte:
* - Hero Section
* - Features
* - FAQ
* - Pricing
* - Trust Indicators
* - Testimonial
* - Hero Section, Features, FAQ, Pricing, Trust Indicators, Testimonial
*
* NEU: Live-Preview der Website zeigt Kontext beim Bearbeiten
*/
import { useState, useEffect, useRef, useCallback } from 'react'
import { WebsiteContent, HeroContent, FeatureContent, FAQItem, PricingPlan } from '@/lib/content-types'
import { useLanguage } from '@/lib/LanguageContext'
import LanguageSelector from '@/components/LanguageSelector'
import AdminLayout from '@/components/admin/AdminLayout'
// Admin Key (in Produktion via Login)
const ADMIN_KEY = 'breakpilot-admin-2024'
// Mapping von Tabs zu Website-Sektionen (CSS Selektoren und Scroll-Positionen)
const SECTION_MAP: Record<string, { selector: string; scrollTo: string }> = {
hero: { selector: '#hero', scrollTo: 'hero' },
features: { selector: '#features', scrollTo: 'features' },
faq: { selector: '#faq', scrollTo: 'faq' },
pricing: { selector: '#pricing', scrollTo: 'pricing' },
other: { selector: '#trust', scrollTo: 'trust' },
}
import { useContentEditor } from './_components/useContentEditor'
import ContentEditorTabs from './_components/ContentEditorTabs'
import LivePreviewPanel from './_components/LivePreviewPanel'
export default function AdminPage() {
const { language, setLanguage, t, isRTL } = useLanguage()
const [content, setContent] = useState<WebsiteContent | null>(null)
const [loading, setLoading] = useState(true)
const [saving, setSaving] = useState(false)
const [message, setMessage] = useState<{ type: 'success' | 'error'; text: string } | null>(null)
const [activeTab, setActiveTab] = useState<'hero' | 'features' | 'faq' | 'pricing' | 'other'>('hero')
const [showPreview, setShowPreview] = useState(true)
const iframeRef = useRef<HTMLIFrameElement>(null)
const editor = useContentEditor()
// Scrollt die Preview zur entsprechenden Sektion
const scrollToSection = useCallback((tab: string) => {
if (!iframeRef.current?.contentWindow) return
const section = SECTION_MAP[tab]
if (section) {
try {
iframeRef.current.contentWindow.postMessage(
{ type: 'scrollTo', section: section.scrollTo },
'*'
)
} catch {
// Same-origin policy - fallback
}
}
}, [])
// Bei Tab-Wechsel zur Sektion scrollen
useEffect(() => {
scrollToSection(activeTab)
}, [activeTab, scrollToSection])
// Content laden
useEffect(() => {
loadContent()
}, [])
async function loadContent() {
try {
const res = await fetch('/api/content')
if (res.ok) {
const data = await res.json()
setContent(data)
} else {
setMessage({ type: 'error', text: t('admin_error') })
}
} catch (error) {
setMessage({ type: 'error', text: t('admin_error') })
} finally {
setLoading(false)
}
}
async function saveChanges() {
if (!content) return
setSaving(true)
setMessage(null)
try {
const res = await fetch('/api/content', {
method: 'POST',
headers: {
'Content-Type': 'application/json',
'x-admin-key': ADMIN_KEY,
},
body: JSON.stringify(content),
})
if (res.ok) {
setMessage({ type: 'success', text: t('admin_saved') })
} else {
const error = await res.json()
setMessage({ type: 'error', text: error.error || t('admin_error') })
}
} catch (error) {
setMessage({ type: 'error', text: t('admin_error') })
} finally {
setSaving(false)
}
}
// Hero Section updaten
function updateHero(field: keyof HeroContent, value: string) {
if (!content) return
setContent({
...content,
hero: { ...content.hero, [field]: value },
})
}
// Feature updaten
function updateFeature(index: number, field: keyof FeatureContent, value: string) {
if (!content) return
const newFeatures = [...content.features]
newFeatures[index] = { ...newFeatures[index], [field]: value }
setContent({ ...content, features: newFeatures })
}
// FAQ updaten
function updateFAQ(index: number, field: 'question' | 'answer', value: string | string[]) {
if (!content) return
const newFAQ = [...content.faq]
if (field === 'answer' && typeof value === 'string') {
// Split by newlines for array
newFAQ[index] = { ...newFAQ[index], answer: value.split('\n') }
} else if (field === 'question' && typeof value === 'string') {
newFAQ[index] = { ...newFAQ[index], question: value }
}
setContent({ ...content, faq: newFAQ })
}
// FAQ hinzufuegen
function addFAQ() {
if (!content) return
setContent({
...content,
faq: [...content.faq, { question: 'Neue Frage?', answer: ['Antwort hier...'] }],
})
}
// FAQ entfernen
function removeFAQ(index: number) {
if (!content) return
const newFAQ = content.faq.filter((_, i) => i !== index)
setContent({ ...content, faq: newFAQ })
}
// Pricing updaten
function updatePricing(index: number, field: string, value: string | number | boolean) {
if (!content) return
const newPricing = [...content.pricing]
if (field === 'price') {
newPricing[index] = { ...newPricing[index], price: Number(value) }
} else if (field === 'popular') {
newPricing[index] = { ...newPricing[index], popular: Boolean(value) }
} else if (field.startsWith('features.')) {
const subField = field.replace('features.', '')
if (subField === 'included' && typeof value === 'string') {
newPricing[index] = {
...newPricing[index],
features: {
...newPricing[index].features,
included: value.split('\n'),
},
}
} else {
newPricing[index] = {
...newPricing[index],
features: {
...newPricing[index].features,
[subField]: value,
},
}
}
} else {
newPricing[index] = { ...newPricing[index], [field]: value }
}
setContent({ ...content, pricing: newPricing })
}
if (loading) {
if (editor.loading) {
return (
<AdminLayout title="Übersetzungen" description="Website Content & Sprachen">
<AdminLayout title="Uebersetzungen" description="Website Content & Sprachen">
<div className="flex items-center justify-center py-12">
<div className="text-xl text-slate-600">{t('admin_loading')}</div>
<div className="text-xl text-slate-600">{editor.t('admin_loading')}</div>
</div>
</AdminLayout>
)
}
if (!content) {
if (!editor.content) {
return (
<AdminLayout title="Übersetzungen" description="Website Content & Sprachen">
<AdminLayout title="Uebersetzungen" description="Website Content & Sprachen">
<div className="flex items-center justify-center py-12">
<div className="text-xl text-red-600">{t('admin_error')}</div>
<div className="text-xl text-red-600">{editor.t('admin_error')}</div>
</div>
</AdminLayout>
)
}
return (
<AdminLayout title="Übersetzungen" description="Website Content & Sprachen">
<div className={isRTL ? 'rtl' : ''} dir={isRTL ? 'rtl' : 'ltr'}>
<AdminLayout title="Uebersetzungen" description="Website Content & Sprachen">
<div className={editor.isRTL ? 'rtl' : ''} dir={editor.isRTL ? 'rtl' : 'ltr'}>
{/* Toolbar */}
<div className={`bg-white rounded-xl border border-slate-200 p-4 mb-6 flex items-center justify-between ${isRTL ? 'flex-row-reverse' : ''}`}>
<div className={`bg-white rounded-xl border border-slate-200 p-4 mb-6 flex items-center justify-between ${editor.isRTL ? 'flex-row-reverse' : ''}`}>
<div className="flex items-center gap-4">
<LanguageSelector
currentLanguage={language}
onLanguageChange={setLanguage}
/>
{/* Preview Toggle */}
<LanguageSelector currentLanguage={editor.language} onLanguageChange={editor.setLanguage} />
<button
onClick={() => setShowPreview(!showPreview)}
onClick={() => editor.setShowPreview(!editor.showPreview)}
className={`flex items-center gap-2 px-3 py-2 rounded-lg text-sm font-medium transition-colors ${
showPreview
? 'bg-blue-100 text-blue-700'
: 'bg-slate-100 text-slate-600 hover:bg-slate-200'
editor.showPreview ? 'bg-blue-100 text-blue-700' : 'bg-slate-100 text-slate-600 hover:bg-slate-200'
}`}
title={showPreview ? 'Preview ausblenden' : 'Preview einblenden'}
title={editor.showPreview ? 'Preview ausblenden' : 'Preview einblenden'}
>
<svg className="w-4 h-4" fill="none" viewBox="0 0 24 24" stroke="currentColor">
<path strokeLinecap="round" strokeLinejoin="round" strokeWidth={2} d="M15 12a3 3 0 11-6 0 3 3 0 016 0z" />
@@ -239,565 +59,65 @@ export default function AdminPage() {
Live-Preview
</button>
</div>
<div className={`flex items-center gap-4 ${isRTL ? 'flex-row-reverse' : ''}`}>
{message && (
<span
className={`px-3 py-1 rounded text-sm ${
message.type === 'success'
? 'bg-green-100 text-green-800'
: 'bg-red-100 text-red-800'
}`}
>
{message.text}
<div className={`flex items-center gap-4 ${editor.isRTL ? 'flex-row-reverse' : ''}`}>
{editor.message && (
<span className={`px-3 py-1 rounded text-sm ${
editor.message.type === 'success' ? 'bg-green-100 text-green-800' : 'bg-red-100 text-red-800'
}`}>
{editor.message.text}
</span>
)}
<button
onClick={saveChanges}
disabled={saving}
onClick={editor.saveChanges}
disabled={editor.saving}
className="bg-primary-600 text-white px-6 py-2 rounded-lg font-medium hover:bg-primary-700 disabled:opacity-50 transition-colors"
>
{saving ? t('admin_saving') : t('admin_save')}
{editor.saving ? editor.t('admin_saving') : editor.t('admin_save')}
</button>
</div>
</div>
{/* Tabs */}
<div className="mb-6">
<div className={`flex gap-1 bg-slate-100 p-1 rounded-lg w-fit ${isRTL ? 'flex-row-reverse' : ''}`}>
<div className={`flex gap-1 bg-slate-100 p-1 rounded-lg w-fit ${editor.isRTL ? 'flex-row-reverse' : ''}`}>
{(['hero', 'features', 'faq', 'pricing', 'other'] as const).map((tab) => (
<button
key={tab}
onClick={() => setActiveTab(tab)}
onClick={() => editor.setActiveTab(tab)}
className={`px-4 py-2 text-sm font-medium rounded-md transition-colors ${
activeTab === tab
? 'bg-white text-slate-900 shadow-sm'
: 'text-slate-600 hover:text-slate-900'
editor.activeTab === tab ? 'bg-white text-slate-900 shadow-sm' : 'text-slate-600 hover:text-slate-900'
}`}
>
{tab === 'hero' && t('admin_tab_hero')}
{tab === 'features' && t('admin_tab_features')}
{tab === 'faq' && t('admin_tab_faq')}
{tab === 'pricing' && t('admin_tab_pricing')}
{tab === 'other' && t('admin_tab_other')}
{tab === 'hero' && editor.t('admin_tab_hero')}
{tab === 'features' && editor.t('admin_tab_features')}
{tab === 'faq' && editor.t('admin_tab_faq')}
{tab === 'pricing' && editor.t('admin_tab_pricing')}
{tab === 'other' && editor.t('admin_tab_other')}
</button>
))}
</div>
</div>
{/* Split Layout: Editor + Preview */}
<div className={`grid gap-6 ${showPreview ? 'grid-cols-2' : 'grid-cols-1'}`}>
{/* Editor Panel */}
<div className={`grid gap-6 ${editor.showPreview ? 'grid-cols-2' : 'grid-cols-1'}`}>
<div className="bg-white rounded-xl border border-slate-200 shadow-sm p-6 max-h-[calc(100vh-280px)] overflow-y-auto">
{/* Hero Tab */}
{activeTab === 'hero' && (
<div className="space-y-6">
<h2 className="text-xl font-semibold text-slate-900">{t('admin_tab_hero')} Section</h2>
<div className="grid gap-4">
<div>
<label className="block text-sm font-medium text-slate-700 mb-1">Badge</label>
<input
type="text"
value={content.hero.badge}
onChange={(e) => updateHero('badge', e.target.value)}
className="w-full px-3 py-2 border border-slate-300 rounded-lg focus:ring-2 focus:ring-primary-500 focus:border-primary-500"
dir={isRTL ? 'rtl' : 'ltr'}
<ContentEditorTabs
activeTab={editor.activeTab}
content={editor.content}
isRTL={editor.isRTL}
t={editor.t}
updateHero={editor.updateHero}
updateFeature={editor.updateFeature}
updateFAQ={editor.updateFAQ}
addFAQ={editor.addFAQ}
removeFAQ={editor.removeFAQ}
updatePricing={editor.updatePricing}
updateTrust={editor.updateTrust}
updateTestimonial={editor.updateTestimonial}
/>
</div>
<div>
<label className="block text-sm font-medium text-slate-700 mb-1">
Titel (vor Highlight)
</label>
<input
type="text"
value={content.hero.title}
onChange={(e) => updateHero('title', e.target.value)}
className="w-full px-3 py-2 border border-slate-300 rounded-lg focus:ring-2 focus:ring-primary-500 focus:border-primary-500"
dir={isRTL ? 'rtl' : 'ltr'}
/>
</div>
<div className="grid grid-cols-2 gap-4">
<div>
<label className="block text-sm font-medium text-slate-700 mb-1">
Highlight 1
</label>
<input
type="text"
value={content.hero.titleHighlight1}
onChange={(e) => updateHero('titleHighlight1', e.target.value)}
className="w-full px-3 py-2 border border-slate-300 rounded-lg focus:ring-2 focus:ring-primary-500 focus:border-primary-500"
dir={isRTL ? 'rtl' : 'ltr'}
/>
</div>
<div>
<label className="block text-sm font-medium text-slate-700 mb-1">
Highlight 2
</label>
<input
type="text"
value={content.hero.titleHighlight2}
onChange={(e) => updateHero('titleHighlight2', e.target.value)}
className="w-full px-3 py-2 border border-slate-300 rounded-lg focus:ring-2 focus:ring-primary-500 focus:border-primary-500"
dir={isRTL ? 'rtl' : 'ltr'}
/>
</div>
</div>
<div>
<label className="block text-sm font-medium text-slate-700 mb-1">Untertitel</label>
<textarea
value={content.hero.subtitle}
onChange={(e) => updateHero('subtitle', e.target.value)}
rows={3}
className="w-full px-3 py-2 border border-slate-300 rounded-lg focus:ring-2 focus:ring-primary-500 focus:border-primary-500"
dir={isRTL ? 'rtl' : 'ltr'}
/>
</div>
<div className="grid grid-cols-3 gap-4">
<div>
<label className="block text-sm font-medium text-slate-700 mb-1">
CTA Primaer
</label>
<input
type="text"
value={content.hero.ctaPrimary}
onChange={(e) => updateHero('ctaPrimary', e.target.value)}
className="w-full px-3 py-2 border border-slate-300 rounded-lg focus:ring-2 focus:ring-primary-500 focus:border-primary-500"
dir={isRTL ? 'rtl' : 'ltr'}
/>
</div>
<div>
<label className="block text-sm font-medium text-slate-700 mb-1">
CTA Sekundaer
</label>
<input
type="text"
value={content.hero.ctaSecondary}
onChange={(e) => updateHero('ctaSecondary', e.target.value)}
className="w-full px-3 py-2 border border-slate-300 rounded-lg focus:ring-2 focus:ring-primary-500 focus:border-primary-500"
dir={isRTL ? 'rtl' : 'ltr'}
/>
</div>
<div>
<label className="block text-sm font-medium text-slate-700 mb-1">CTA Hinweis</label>
<input
type="text"
value={content.hero.ctaHint}
onChange={(e) => updateHero('ctaHint', e.target.value)}
className="w-full px-3 py-2 border border-slate-300 rounded-lg focus:ring-2 focus:ring-primary-500 focus:border-primary-500"
dir={isRTL ? 'rtl' : 'ltr'}
/>
</div>
</div>
</div>
</div>
)}
{/* Features Tab */}
{activeTab === 'features' && (
<div className="space-y-6">
<h2 className="text-xl font-semibold text-slate-900">{t('admin_tab_features')}</h2>
{content.features.map((feature, index) => (
<div key={feature.id} className="border border-slate-200 rounded-lg p-4">
<div className="grid gap-4">
<div className="grid grid-cols-3 gap-4">
<div>
<label className="block text-sm font-medium text-slate-700 mb-1">Icon</label>
<input
type="text"
value={feature.icon}
onChange={(e) => updateFeature(index, 'icon', e.target.value)}
className="w-full px-3 py-2 border border-slate-300 rounded-lg text-2xl text-center"
/>
</div>
<div className="col-span-2">
<label className="block text-sm font-medium text-slate-700 mb-1">Titel</label>
<input
type="text"
value={feature.title}
onChange={(e) => updateFeature(index, 'title', e.target.value)}
className="w-full px-3 py-2 border border-slate-300 rounded-lg focus:ring-2 focus:ring-primary-500 focus:border-primary-500"
dir={isRTL ? 'rtl' : 'ltr'}
/>
</div>
</div>
<div>
<label className="block text-sm font-medium text-slate-700 mb-1">
Beschreibung
</label>
<textarea
value={feature.description}
onChange={(e) => updateFeature(index, 'description', e.target.value)}
rows={2}
className="w-full px-3 py-2 border border-slate-300 rounded-lg focus:ring-2 focus:ring-primary-500 focus:border-primary-500"
dir={isRTL ? 'rtl' : 'ltr'}
/>
</div>
</div>
</div>
))}
</div>
)}
{/* FAQ Tab */}
{activeTab === 'faq' && (
<div className="space-y-6">
<div className={`flex items-center justify-between ${isRTL ? 'flex-row-reverse' : ''}`}>
<h2 className="text-xl font-semibold text-slate-900">{t('admin_tab_faq')}</h2>
<button
onClick={addFAQ}
className="px-4 py-2 bg-slate-100 text-slate-700 rounded-lg hover:bg-slate-200 transition-colors"
>
{t('admin_add_faq')}
</button>
</div>
{content.faq.map((item, index) => (
<div key={index} className="border border-slate-200 rounded-lg p-4">
<div className={`flex items-start justify-between gap-4 ${isRTL ? 'flex-row-reverse' : ''}`}>
<div className="flex-1 space-y-4">
<div>
<label className="block text-sm font-medium text-slate-700 mb-1">
{t('admin_question')} {index + 1}
</label>
<input
type="text"
value={item.question}
onChange={(e) => updateFAQ(index, 'question', e.target.value)}
className="w-full px-3 py-2 border border-slate-300 rounded-lg focus:ring-2 focus:ring-primary-500 focus:border-primary-500"
dir={isRTL ? 'rtl' : 'ltr'}
/>
</div>
<div>
<label className="block text-sm font-medium text-slate-700 mb-1">
{t('admin_answer')}
</label>
<textarea
value={item.answer.join('\n')}
onChange={(e) => updateFAQ(index, 'answer', e.target.value)}
rows={4}
className="w-full px-3 py-2 border border-slate-300 rounded-lg focus:ring-2 focus:ring-primary-500 focus:border-primary-500 font-mono text-sm"
dir={isRTL ? 'rtl' : 'ltr'}
/>
</div>
</div>
<button
onClick={() => removeFAQ(index)}
className="p-2 text-red-600 hover:bg-red-50 rounded-lg transition-colors"
title="Frage entfernen"
>
<svg className="w-5 h-5" fill="none" viewBox="0 0 24 24" stroke="currentColor">
<path
strokeLinecap="round"
strokeLinejoin="round"
strokeWidth={2}
d="M19 7l-.867 12.142A2 2 0 0116.138 21H7.862a2 2 0 01-1.995-1.858L5 7m5 4v6m4-6v6m1-10V4a1 1 0 00-1-1h-4a1 1 0 00-1 1v3M4 7h16"
/>
</svg>
</button>
</div>
</div>
))}
</div>
)}
{/* Pricing Tab */}
{activeTab === 'pricing' && (
<div className="space-y-6">
<h2 className="text-xl font-semibold text-slate-900">{t('admin_tab_pricing')}</h2>
{content.pricing.map((plan, index) => (
<div key={plan.id} className="border border-slate-200 rounded-lg p-4">
<div className="grid gap-4">
<div className="grid grid-cols-4 gap-4">
<div>
<label className="block text-sm font-medium text-slate-700 mb-1">Name</label>
<input
type="text"
value={plan.name}
onChange={(e) => updatePricing(index, 'name', e.target.value)}
className="w-full px-3 py-2 border border-slate-300 rounded-lg focus:ring-2 focus:ring-primary-500 focus:border-primary-500"
dir={isRTL ? 'rtl' : 'ltr'}
/>
</div>
<div>
<label className="block text-sm font-medium text-slate-700 mb-1">
Preis (EUR)
</label>
<input
type="number"
step="0.01"
value={plan.price}
onChange={(e) => updatePricing(index, 'price', e.target.value)}
className="w-full px-3 py-2 border border-slate-300 rounded-lg focus:ring-2 focus:ring-primary-500 focus:border-primary-500"
/>
</div>
<div>
<label className="block text-sm font-medium text-slate-700 mb-1">
Intervall
</label>
<input
type="text"
value={plan.interval}
onChange={(e) => updatePricing(index, 'interval', e.target.value)}
className="w-full px-3 py-2 border border-slate-300 rounded-lg focus:ring-2 focus:ring-primary-500 focus:border-primary-500"
dir={isRTL ? 'rtl' : 'ltr'}
/>
</div>
<div className="flex items-end">
<label className="flex items-center gap-2">
<input
type="checkbox"
checked={plan.popular || false}
onChange={(e) => updatePricing(index, 'popular', e.target.checked)}
className="w-4 h-4 text-primary-600 rounded"
/>
<span className="text-sm text-slate-700">{t('pricing_popular')}</span>
</label>
</div>
</div>
<div>
<label className="block text-sm font-medium text-slate-700 mb-1">
Beschreibung
</label>
<input
type="text"
value={plan.description}
onChange={(e) => updatePricing(index, 'description', e.target.value)}
className="w-full px-3 py-2 border border-slate-300 rounded-lg focus:ring-2 focus:ring-primary-500 focus:border-primary-500"
dir={isRTL ? 'rtl' : 'ltr'}
/>
</div>
<div className="grid grid-cols-2 gap-4">
<div>
<label className="block text-sm font-medium text-slate-700 mb-1">
{t('pricing_tasks')}
</label>
<input
type="text"
value={plan.features.tasks}
onChange={(e) => updatePricing(index, 'features.tasks', e.target.value)}
className="w-full px-3 py-2 border border-slate-300 rounded-lg focus:ring-2 focus:ring-primary-500 focus:border-primary-500"
dir={isRTL ? 'rtl' : 'ltr'}
/>
</div>
<div>
<label className="block text-sm font-medium text-slate-700 mb-1">
Aufgaben-Beschreibung
</label>
<input
type="text"
value={plan.features.taskDescription}
onChange={(e) =>
updatePricing(index, 'features.taskDescription', e.target.value)
}
className="w-full px-3 py-2 border border-slate-300 rounded-lg focus:ring-2 focus:ring-primary-500 focus:border-primary-500"
dir={isRTL ? 'rtl' : 'ltr'}
/>
</div>
</div>
<div>
<label className="block text-sm font-medium text-slate-700 mb-1">
Features (eine pro Zeile)
</label>
<textarea
value={plan.features.included.join('\n')}
onChange={(e) => updatePricing(index, 'features.included', e.target.value)}
rows={4}
className="w-full px-3 py-2 border border-slate-300 rounded-lg focus:ring-2 focus:ring-primary-500 focus:border-primary-500 font-mono text-sm"
dir={isRTL ? 'rtl' : 'ltr'}
/>
</div>
</div>
</div>
))}
</div>
)}
{/* Other Tab */}
{activeTab === 'other' && (
<div className="space-y-8">
{/* Trust Indicators */}
<div>
<h2 className="text-xl font-semibold text-slate-900 mb-4">Trust Indicators</h2>
<div className="grid grid-cols-3 gap-4">
{(['item1', 'item2', 'item3'] as const).map((key, index) => (
<div key={key} className="border border-slate-200 rounded-lg p-4">
<div className="space-y-4">
<div>
<label className="block text-sm font-medium text-slate-700 mb-1">
Wert {index + 1}
</label>
<input
type="text"
value={content.trust[key].value}
onChange={(e) =>
setContent({
...content,
trust: {
...content.trust,
[key]: { ...content.trust[key], value: e.target.value },
},
})
}
className="w-full px-3 py-2 border border-slate-300 rounded-lg focus:ring-2 focus:ring-primary-500 focus:border-primary-500"
dir={isRTL ? 'rtl' : 'ltr'}
/>
</div>
<div>
<label className="block text-sm font-medium text-slate-700 mb-1">
Label {index + 1}
</label>
<input
type="text"
value={content.trust[key].label}
onChange={(e) =>
setContent({
...content,
trust: {
...content.trust,
[key]: { ...content.trust[key], label: e.target.value },
},
})
}
className="w-full px-3 py-2 border border-slate-300 rounded-lg focus:ring-2 focus:ring-primary-500 focus:border-primary-500"
dir={isRTL ? 'rtl' : 'ltr'}
/>
</div>
</div>
</div>
))}
</div>
</div>
{/* Testimonial */}
<div>
<h2 className="text-xl font-semibold text-slate-900 mb-4">Testimonial</h2>
<div className="border border-slate-200 rounded-lg p-4 space-y-4">
<div>
<label className="block text-sm font-medium text-slate-700 mb-1">Zitat</label>
<textarea
value={content.testimonial.quote}
onChange={(e) =>
setContent({
...content,
testimonial: { ...content.testimonial, quote: e.target.value },
})
}
rows={3}
className="w-full px-3 py-2 border border-slate-300 rounded-lg focus:ring-2 focus:ring-primary-500 focus:border-primary-500"
dir={isRTL ? 'rtl' : 'ltr'}
/>
</div>
<div className="grid grid-cols-2 gap-4">
<div>
<label className="block text-sm font-medium text-slate-700 mb-1">Autor</label>
<input
type="text"
value={content.testimonial.author}
onChange={(e) =>
setContent({
...content,
testimonial: { ...content.testimonial, author: e.target.value },
})
}
className="w-full px-3 py-2 border border-slate-300 rounded-lg focus:ring-2 focus:ring-primary-500 focus:border-primary-500"
dir={isRTL ? 'rtl' : 'ltr'}
/>
</div>
<div>
<label className="block text-sm font-medium text-slate-700 mb-1">Rolle</label>
<input
type="text"
value={content.testimonial.role}
onChange={(e) =>
setContent({
...content,
testimonial: { ...content.testimonial, role: e.target.value },
})
}
className="w-full px-3 py-2 border border-slate-300 rounded-lg focus:ring-2 focus:ring-primary-500 focus:border-primary-500"
dir={isRTL ? 'rtl' : 'ltr'}
/>
</div>
</div>
</div>
</div>
</div>
)}
</div>
{/* Live Preview Panel */}
{showPreview && (
<div className="bg-white rounded-xl border border-slate-200 shadow-sm overflow-hidden">
{/* Preview Header */}
<div className="bg-slate-50 border-b border-slate-200 px-4 py-3 flex items-center justify-between">
<div className="flex items-center gap-2">
<div className="flex gap-1.5">
<div className="w-3 h-3 rounded-full bg-red-400"></div>
<div className="w-3 h-3 rounded-full bg-yellow-400"></div>
<div className="w-3 h-3 rounded-full bg-green-400"></div>
</div>
<span className="text-xs text-slate-500 ml-2">localhost:3000</span>
</div>
<div className="flex items-center gap-2">
<span className="text-xs font-medium text-slate-600 bg-slate-200 px-2 py-1 rounded">
{activeTab === 'hero' && 'Hero Section'}
{activeTab === 'features' && 'Features'}
{activeTab === 'faq' && 'FAQ'}
{activeTab === 'pricing' && 'Pricing'}
{activeTab === 'other' && 'Trust & Testimonial'}
</span>
<button
onClick={() => iframeRef.current?.contentWindow?.location.reload()}
className="p-1.5 text-slate-500 hover:text-slate-700 hover:bg-slate-200 rounded transition-colors"
title="Preview neu laden"
>
<svg className="w-4 h-4" fill="none" viewBox="0 0 24 24" stroke="currentColor">
<path strokeLinecap="round" strokeLinejoin="round" strokeWidth={2} d="M4 4v5h.582m15.356 2A8.001 8.001 0 004.582 9m0 0H9m11 11v-5h-.581m0 0a8.003 8.003 0 01-15.357-2m15.357 2H15" />
</svg>
</button>
</div>
</div>
{/* Preview Frame */}
<div className="relative h-[calc(100vh-340px)] bg-slate-100">
<iframe
ref={iframeRef}
src={`/?preview=true&section=${activeTab}#${activeTab}`}
className="w-full h-full border-0 scale-75 origin-top-left"
style={{
width: '133.33%',
height: '133.33%',
transform: 'scale(0.75)',
transformOrigin: 'top left',
}}
title="Website Preview"
sandbox="allow-same-origin allow-scripts"
/>
{/* Sektion-Indikator */}
<div className="absolute bottom-4 left-4 right-4 bg-blue-600 text-white px-4 py-2 rounded-lg shadow-lg flex items-center gap-2 text-sm">
<svg className="w-4 h-4" fill="none" viewBox="0 0 24 24" stroke="currentColor">
<path strokeLinecap="round" strokeLinejoin="round" strokeWidth={2} d="M13 16h-1v-4h-1m1-4h.01M21 12a9 9 0 11-18 0 9 9 0 0118 0z" />
</svg>
<span>
Du bearbeitest: <strong>
{activeTab === 'hero' && 'Hero Section (Startbereich)'}
{activeTab === 'features' && 'Features (Funktionen)'}
{activeTab === 'faq' && 'FAQ (Häufige Fragen)'}
{activeTab === 'pricing' && 'Pricing (Preise)'}
{activeTab === 'other' && 'Trust & Testimonial'}
</strong>
</span>
</div>
</div>
</div>
{editor.showPreview && (
<LivePreviewPanel activeTab={editor.activeTab} iframeRef={editor.iframeRef} />
)}
</div>
</div>
@@ -0,0 +1,98 @@
import { ScreenDefinition, ConnectionDef, FlowType } from './types'
/**
 * Collect every node reachable from a start node (breadth-first).
 *
 * @param startNodeId - Node the search starts from; it is always part of the result.
 * @param connections - Directed edges, each pointing from `source` to `target`.
 * @param direction - `'children'` follows edges source -> target, `'parents'`
 *   follows them target -> source, `'both'` follows them in either direction.
 * @returns The ids of all reachable nodes, in the order they were discovered.
 */
export function findConnectedNodes(
  startNodeId: string,
  connections: ConnectionDef[],
  direction: 'children' | 'parents' | 'both' = 'children'
): Set<string> {
  const followChildren = direction === 'children' || direction === 'both'
  const followParents = direction === 'parents' || direction === 'both'

  // Index the neighbours of every node once, so the traversal does not have
  // to rescan the complete connection list for each visited node.
  const neighbours = new Map<string, string[]>()
  const link = (from: string, to: string) => {
    const list = neighbours.get(from)
    if (list) {
      list.push(to)
    } else {
      neighbours.set(from, [to])
    }
  }
  connections.forEach(conn => {
    // Keep "child first, then parent" per connection so that the discovery
    // order matches a plain in-order scan of `connections`.
    if (followChildren) link(conn.source, conn.target)
    if (followParents) link(conn.target, conn.source)
  })

  const connected = new Set<string>()
  connected.add(startNodeId)
  const queue = [startNodeId]

  // Walk the queue with a read cursor instead of shift(), which would move
  // every remaining element on each dequeue.
  for (let head = 0; head < queue.length; head++) {
    const adjacent = neighbours.get(queue[head]) || []
    adjacent.forEach(next => {
      if (!connected.has(next)) {
        connected.add(next)
        queue.push(next)
      }
    })
  }

  return connected
}
/**
 * Build the URL under which a screen can be embedded.
 *
 * Appends `embed=true` to the query string of `url` (using `&` when the part
 * before any `#` already contains a `?`, otherwise `?`), keeps a `#fragment`
 * at the very end, and prefixes the result with `baseUrl`.
 *
 * @param baseUrl - Prefix that is put in front of `url` unchanged.
 * @param url - Screen URL, optionally containing a query string and/or fragment.
 * @returns The embeddable URL, or `null` when `url` is undefined or empty.
 */
export function constructEmbedUrl(baseUrl: string, url: string | undefined): string | null {
  if (!url) return null

  // Split at the first '#': the query parameter has to go in front of it.
  const hashAt = url.indexOf('#')
  const hasFragment = hashAt !== -1
  const path = hasFragment ? url.substring(0, hashAt) : url
  const fragment = hasFragment ? url.substring(hashAt) : ''

  const joiner = path.includes('?') ? '&' : '?'
  return `${baseUrl}${path}${joiner}embed=true${fragment}`
}
/**
 * Calculate a node position from its category and its index within that category.
 *
 * Every category has a fixed anchor point per flow type. The screens of one
 * category are laid out on a grid starting at that anchor, 160px apart
 * horizontally and 90px apart vertically.
 *
 * @param id - Id of the screen to place.
 * @param category - Category of the screen; unknown categories use the anchor (400, 300).
 * @param screens - All screens of the flow; used to determine the grid slot of `id`.
 * @param flowType - `'studio'` selects the studio anchors, any other value the admin anchors.
 * @returns The `{ x, y }` position of the node.
 */
export function getNodePosition(
  id: string,
  category: string,
  screens: ScreenDefinition[],
  flowType: FlowType
): { x: number; y: number } {
  const studioPositions: Record<string, { x: number; y: number }> = {
    navigation: { x: 400, y: 50 },
    content: { x: 50, y: 250 },
    communication: { x: 750, y: 250 },
    school: { x: 50, y: 500 },
    admin: { x: 750, y: 500 },
    ai: { x: 400, y: 380 },
  }

  const adminPositions: Record<string, { x: number; y: number }> = {
    overview: { x: 400, y: 30 },
    infrastructure: { x: 50, y: 150 },
    compliance: { x: 700, y: 150 },
    ai: { x: 50, y: 350 },
    communication: { x: 400, y: 350 },
    security: { x: 700, y: 350 },
    content: { x: 50, y: 550 },
    game: { x: 400, y: 550 },
    misc: { x: 700, y: 550 },
  }

  const positions = flowType === 'studio' ? studioPositions : adminPositions
  const base = positions[category] || { x: 400, y: 300 }

  const categoryScreens = screens.filter(s => s.category === category)
  const foundIndex = categoryScreens.findIndex(s => s.id === id)
  // findIndex yields -1 for an id that is not part of `screens`; feeding that
  // into the grid math produces a negative row/column and moves the node
  // up-left of the anchor. Put such a node into the next free slot instead.
  const categoryIndex = foundIndex === -1 ? categoryScreens.length : foundIndex

  // Roughly square grid; the +1 keeps at least one column for an empty category.
  const cols = Math.ceil(Math.sqrt(categoryScreens.length + 1))
  const row = Math.floor(categoryIndex / cols)
  const col = categoryIndex % cols

  return {
    x: base.x + col * 160,
    y: base.y + row * 90,
  }
}
@@ -0,0 +1,173 @@
import { ScreenDefinition, ConnectionDef } from './types'
// ============================================
// STUDIO SCREENS (Port 8000)
// ============================================
/**
 * All screens of the studio flow.
 *
 * Each entry's `url` has the form `/app#<id>`. The `category` values used here
 * (navigation, content, communication, school, ai, admin) are the keys of
 * STUDIO_COLORS and STUDIO_LABELS.
 */
export const STUDIO_SCREENS: ScreenDefinition[] = [
{ id: 'lehrer-dashboard', name: 'Mein Dashboard', description: 'Hauptuebersicht mit Widgets', category: 'navigation', icon: '🏠', url: '/app#lehrer-dashboard' },
{ id: 'lehrer-onboarding', name: 'Erste Schritte', description: 'Onboarding & Schnellstart', category: 'navigation', icon: '🚀', url: '/app#lehrer-onboarding' },
{ id: 'hilfe', name: 'Dokumentation', description: 'Hilfe & Anleitungen', category: 'navigation', icon: '📚', url: '/app#hilfe' },
{ id: 'worksheets', name: 'Arbeitsblaetter Studio', description: 'Lernmaterialien erstellen', category: 'content', icon: '📝', url: '/app#worksheets' },
{ id: 'content-creator', name: 'Content Creator', description: 'Inhalte erstellen', category: 'content', icon: '✨', url: '/app#content-creator' },
{ id: 'content-feed', name: 'Content Feed', description: 'Inhalte durchsuchen', category: 'content', icon: '📰', url: '/app#content-feed' },
{ id: 'unit-creator', name: 'Unit Creator', description: 'Lerneinheiten erstellen', category: 'content', icon: '📦', url: '/app#unit-creator' },
{ id: 'letters', name: 'Briefe & Vorlagen', description: 'Brief-Generator', category: 'content', icon: '✉️', url: '/app#letters' },
{ id: 'correction', name: 'Korrektur', description: 'Arbeiten korrigieren', category: 'content', icon: '✏️', url: '/app#correction' },
{ id: 'klausur-korrektur', name: 'Abiturklausuren', description: 'KI-gestuetzte Klausurkorrektur', category: 'content', icon: '📋', url: '/app#klausur-korrektur' },
{ id: 'jitsi', name: 'Videokonferenz', description: 'Jitsi Meet Integration', category: 'communication', icon: '🎥', url: '/app#jitsi' },
{ id: 'messenger', name: 'Messenger', description: 'Matrix E2EE Chat', category: 'communication', icon: '💬', url: '/app#messenger' },
{ id: 'mail', name: 'Unified Inbox', description: 'E-Mail Verwaltung', category: 'communication', icon: '📧', url: '/app#mail' },
{ id: 'school-classes', name: 'Klassen', description: 'Klassenverwaltung', category: 'school', icon: '👥', url: '/app#school-classes' },
{ id: 'school-exams', name: 'Pruefungen', description: 'Pruefungsverwaltung', category: 'school', icon: '📝', url: '/app#school-exams' },
{ id: 'school-grades', name: 'Noten', description: 'Notenverwaltung', category: 'school', icon: '📊', url: '/app#school-grades' },
{ id: 'school-gradebook', name: 'Notenbuch', description: 'Digitales Notenbuch', category: 'school', icon: '📖', url: '/app#school-gradebook' },
{ id: 'school-certificates', name: 'Zeugnisse', description: 'Zeugniserstellung', category: 'school', icon: '🎓', url: '/app#school-certificates' },
{ id: 'companion', name: 'Begleiter & Stunde', description: 'KI-Unterrichtsassistent', category: 'ai', icon: '🤖', url: '/app#companion' },
{ id: 'alerts', name: 'Alerts', description: 'News & Benachrichtigungen', category: 'ai', icon: '🔔', url: '/app#alerts' },
{ id: 'admin', name: 'Einstellungen', description: 'Systemeinstellungen', category: 'admin', icon: '⚙️', url: '/app#admin' },
{ id: 'rbac-admin', name: 'Rollen & Rechte', description: 'Berechtigungsverwaltung', category: 'admin', icon: '🔐', url: '/app#rbac-admin' },
{ id: 'abitur-docs-admin', name: 'Abitur Dokumente', description: 'Erwartungshorizonte', category: 'admin', icon: '📄', url: '/app#abitur-docs-admin' },
{ id: 'system-info', name: 'System Info', description: 'Systeminformationen', category: 'admin', icon: '💻', url: '/app#system-info' },
{ id: 'workflow', name: 'Workflow', description: 'Automatisierungen', category: 'admin', icon: '⚡', url: '/app#workflow' },
]
/**
 * Directed links between studio screens.
 *
 * `source` and `target` are ids of entries in STUDIO_SCREENS; `label` is an
 * optional caption for the link.
 */
export const STUDIO_CONNECTIONS: ConnectionDef[] = [
{ source: 'lehrer-onboarding', target: 'worksheets', label: 'Arbeitsblaetter' },
{ source: 'lehrer-onboarding', target: 'klausur-korrektur', label: 'Abiturklausuren' },
{ source: 'lehrer-onboarding', target: 'correction', label: 'Korrektur' },
{ source: 'lehrer-onboarding', target: 'letters', label: 'Briefe' },
{ source: 'lehrer-onboarding', target: 'school-classes', label: 'Klassen' },
{ source: 'lehrer-onboarding', target: 'jitsi', label: 'Meet' },
{ source: 'lehrer-onboarding', target: 'hilfe', label: 'Doku' },
{ source: 'lehrer-onboarding', target: 'admin', label: 'Settings' },
{ source: 'lehrer-dashboard', target: 'worksheets' },
{ source: 'lehrer-dashboard', target: 'correction' },
{ source: 'lehrer-dashboard', target: 'jitsi' },
{ source: 'lehrer-dashboard', target: 'letters' },
{ source: 'lehrer-dashboard', target: 'messenger' },
{ source: 'lehrer-dashboard', target: 'klausur-korrektur' },
{ source: 'lehrer-dashboard', target: 'companion' },
{ source: 'lehrer-dashboard', target: 'alerts' },
{ source: 'lehrer-dashboard', target: 'mail' },
{ source: 'lehrer-dashboard', target: 'school-classes' },
{ source: 'lehrer-dashboard', target: 'lehrer-onboarding', label: 'Sidebar' },
{ source: 'school-classes', target: 'school-exams' },
{ source: 'school-classes', target: 'school-grades' },
{ source: 'school-grades', target: 'school-gradebook' },
{ source: 'school-gradebook', target: 'school-certificates' },
{ source: 'worksheets', target: 'content-creator' },
{ source: 'worksheets', target: 'unit-creator' },
{ source: 'content-creator', target: 'content-feed' },
{ source: 'klausur-korrektur', target: 'abitur-docs-admin' },
{ source: 'admin', target: 'rbac-admin' },
{ source: 'admin', target: 'system-info' },
{ source: 'admin', target: 'workflow' },
]
// ============================================
// ADMIN SCREENS (Port 3000)
// ============================================
/**
 * All screens of the admin flow.
 *
 * Each entry's `url` is `/admin` or a path below it. The `category` values
 * used here (overview, infrastructure, compliance, ai, communication,
 * security, content, game, misc) are the keys of ADMIN_COLORS and ADMIN_LABELS.
 */
export const ADMIN_SCREENS: ScreenDefinition[] = [
{ id: 'admin-dashboard', name: 'Dashboard', description: 'Uebersicht & Statistiken', category: 'overview', icon: '🏠', url: '/admin' },
{ id: 'admin-onboarding', name: 'Onboarding', description: 'Lern-Wizards fuer alle Module', category: 'overview', icon: '📖', url: '/admin/onboarding' },
{ id: 'admin-gpu', name: 'GPU Infrastruktur', description: 'vast.ai GPU Management', category: 'infrastructure', icon: '🖥️', url: '/admin/gpu' },
{ id: 'admin-middleware', name: 'Middleware', description: 'Middleware Stack & Test', category: 'infrastructure', icon: '🔧', url: '/admin/middleware' },
{ id: 'admin-mac-mini', name: 'Mac Mini', description: 'Headless Mac Mini Control', category: 'infrastructure', icon: '🍎', url: '/admin/mac-mini' },
{ id: 'admin-consent', name: 'Consent Verwaltung', description: 'Rechtliche Dokumente', category: 'compliance', icon: '📄', url: '/admin/consent' },
{ id: 'admin-dsr', name: 'Datenschutzanfragen', description: 'DSGVO Art. 15-21', category: 'compliance', icon: '🔒', url: '/admin/dsr' },
{ id: 'admin-dsms', name: 'DSMS', description: 'Datenschutz-Management', category: 'compliance', icon: '🛡️', url: '/admin/dsms' },
{ id: 'admin-compliance', name: 'Compliance', description: 'GRC & Audit', category: 'compliance', icon: '✅', url: '/admin/compliance' },
{ id: 'admin-docs-audit', name: 'DSGVO-Audit', description: 'Audit-Dokumentation', category: 'compliance', icon: '📋', url: '/admin/docs/audit' },
{ id: 'admin-rag', name: 'Daten & RAG', description: 'Training Data & RAG', category: 'ai', icon: '🗄️', url: '/admin/rag' },
{ id: 'admin-ocr-labeling', name: 'OCR-Labeling', description: 'Handschrift-Training', category: 'ai', icon: '🏷️', url: '/admin/ocr-labeling' },
{ id: 'admin-magic-help', name: 'Magic Help (TrOCR)', description: 'Handschrift-OCR', category: 'ai', icon: '✨', url: '/admin/magic-help' },
{ id: 'admin-companion', name: 'Companion Dev', description: 'Lesson-Modus Entwicklung', category: 'ai', icon: '📚', url: '/admin/companion' },
{ id: 'admin-communication', name: 'Kommunikation', description: 'Matrix & Jitsi Monitoring', category: 'communication', icon: '💬', url: '/admin/communication' },
{ id: 'admin-alerts', name: 'Alerts Monitoring', description: 'Google Alerts & Feeds', category: 'communication', icon: '🔔', url: '/admin/alerts' },
{ id: 'admin-mail', name: 'Unified Inbox', description: 'E-Mail & KI-Analyse', category: 'communication', icon: '📧', url: '/admin/mail' },
{ id: 'admin-security', name: 'Security', description: 'DevSecOps Dashboard', category: 'security', icon: '🔐', url: '/admin/security' },
{ id: 'admin-sbom', name: 'SBOM', description: 'Software Bill of Materials', category: 'security', icon: '📦', url: '/admin/sbom' },
{ id: 'admin-screen-flow', name: 'Screen Flow', description: 'UI Verbindungen', category: 'security', icon: '🔀', url: '/admin/screen-flow' },
{ id: 'admin-content', name: 'Uebersetzungen', description: 'Website Content', category: 'content', icon: '🌍', url: '/admin/content' },
{ id: 'admin-edu-search', name: 'Education Search', description: 'Bildungsquellen & Crawler', category: 'content', icon: '🔍', url: '/admin/edu-search' },
{ id: 'admin-staff-search', name: 'Personensuche', description: 'Uni-Mitarbeiter', category: 'content', icon: '👤', url: '/admin/staff-search' },
{ id: 'admin-uni-crawler', name: 'Uni-Crawler', description: 'Universitaets-Crawling', category: 'content', icon: '🕷️', url: '/admin/uni-crawler' },
{ id: 'admin-game', name: 'Breakpilot Drive', description: 'Lernspiel Klasse 2-6', category: 'game', icon: '🎮', url: '/admin/game' },
{ id: 'admin-unity-bridge', name: 'Unity Bridge', description: 'Unity Editor Steuerung', category: 'game', icon: '⚡', url: '/admin/unity-bridge' },
{ id: 'admin-backlog', name: 'Production Backlog', description: 'Go-Live Checkliste', category: 'misc', icon: '📝', url: '/admin/backlog' },
{ id: 'admin-brandbook', name: 'Brandbook', description: 'Corporate Design', category: 'misc', icon: '🎨', url: '/admin/brandbook' },
{ id: 'admin-docs', name: 'Developer Docs', description: 'API & Architektur', category: 'misc', icon: '📖', url: '/admin/docs' },
{ id: 'admin-pca-platform', name: 'PCA Platform', description: 'Bot-Erkennung', category: 'misc', icon: '💰', url: '/admin/pca-platform' },
]
/**
 * Directed links between admin screens.
 *
 * `source` and `target` are ids of entries in ADMIN_SCREENS; none of these
 * links carries a label.
 */
export const ADMIN_CONNECTIONS: ConnectionDef[] = [
{ source: 'admin-dashboard', target: 'admin-onboarding' },
{ source: 'admin-dashboard', target: 'admin-security' },
{ source: 'admin-dashboard', target: 'admin-compliance' },
{ source: 'admin-onboarding', target: 'admin-gpu' },
{ source: 'admin-onboarding', target: 'admin-consent' },
{ source: 'admin-consent', target: 'admin-dsr' },
{ source: 'admin-dsr', target: 'admin-dsms' },
{ source: 'admin-dsms', target: 'admin-compliance' },
{ source: 'admin-compliance', target: 'admin-docs-audit' },
{ source: 'admin-rag', target: 'admin-ocr-labeling' },
{ source: 'admin-ocr-labeling', target: 'admin-magic-help' },
{ source: 'admin-magic-help', target: 'admin-companion' },
{ source: 'admin-security', target: 'admin-sbom' },
{ source: 'admin-sbom', target: 'admin-screen-flow' },
{ source: 'admin-communication', target: 'admin-alerts' },
{ source: 'admin-alerts', target: 'admin-mail' },
{ source: 'admin-gpu', target: 'admin-middleware' },
{ source: 'admin-middleware', target: 'admin-mac-mini' },
{ source: 'admin-game', target: 'admin-unity-bridge' },
{ source: 'admin-edu-search', target: 'admin-staff-search' },
{ source: 'admin-staff-search', target: 'admin-uni-crawler' },
]
// ============================================
// CATEGORY COLORS & LABELS
// ============================================
/**
 * Color palette per studio category.
 *
 * Keys are category names (the same six keys as STUDIO_LABELS); each value
 * holds three CSS hex colors: `bg`, `border` and `text`.
 * Presumably these are the background, border and text colors of a screen
 * node in the flow view; confirm against the rendering component.
 *
 * This palette is independent of ADMIN_COLORS: the shared keys `content`
 * and `communication` map to different colors there.
 */
export const STUDIO_COLORS: Record<string, { bg: string; border: string; text: string }> = {
navigation: { bg: '#dbeafe', border: '#3b82f6', text: '#1e40af' },
content: { bg: '#dcfce7', border: '#22c55e', text: '#166534' },
communication: { bg: '#fef3c7', border: '#f59e0b', text: '#92400e' },
school: { bg: '#fce7f3', border: '#ec4899', text: '#9d174d' },
admin: { bg: '#f3e8ff', border: '#a855f7', text: '#6b21a8' },
ai: { bg: '#cffafe', border: '#06b6d4', text: '#0e7490' },
}
/**
 * Color palette per admin category.
 *
 * Keys are category names (the same nine keys as ADMIN_LABELS, including
 * the `game` and `misc` values used by the admin screen definitions); each
 * value holds three CSS hex colors: `bg`, `border` and `text`.
 * Presumably these are the background, border and text colors of a screen
 * node in the flow view; confirm against the rendering component.
 *
 * The type is Record<string, ...>, so looking up a category that is not
 * listed here type-checks but yields undefined at runtime.
 */
export const ADMIN_COLORS: Record<string, { bg: string; border: string; text: string }> = {
overview: { bg: '#dbeafe', border: '#3b82f6', text: '#1e40af' },
infrastructure: { bg: '#fef3c7', border: '#f59e0b', text: '#92400e' },
compliance: { bg: '#dcfce7', border: '#22c55e', text: '#166534' },
ai: { bg: '#cffafe', border: '#06b6d4', text: '#0e7490' },
communication: { bg: '#fce7f3', border: '#ec4899', text: '#9d174d' },
security: { bg: '#fee2e2', border: '#ef4444', text: '#991b1b' },
content: { bg: '#f3e8ff', border: '#a855f7', text: '#6b21a8' },
game: { bg: '#fef9c3', border: '#eab308', text: '#713f12' },
misc: { bg: '#f1f5f9', border: '#64748b', text: '#334155' },
}
/**
 * Display label per studio category.
 *
 * Keys match STUDIO_COLORS; values are the user-facing (mostly German)
 * category names.
 */
export const STUDIO_LABELS: Record<string, string> = {
navigation: 'Navigation',
content: 'Content & Tools',
communication: 'Kommunikation',
school: 'Schulverwaltung',
admin: 'Administration',
ai: 'KI & Assistent',
}
/**
 * Display label per admin category.
 *
 * Keys match ADMIN_COLORS; values are the user-facing (mostly German)
 * category names. Umlauts are written in ASCII form ('Uebersicht').
 */
export const ADMIN_LABELS: Record<string, string> = {
overview: 'Uebersicht',
infrastructure: 'Infrastruktur',
compliance: 'DSGVO & Compliance',
ai: 'KI & LLM',
communication: 'Kommunikation',
security: 'Security & DevOps',
content: 'Content & Suche',
game: 'Game & Unity',
misc: 'Sonstiges',
}
@@ -0,0 +1,16 @@
/**
 * Describes one screen shown in the screen-flow overview.
 */
export interface ScreenDefinition {
/** Identifier of the screen; presumably unique and referenced by ConnectionDef.source/target (confirm). */
id: string
/** Display name of the screen. */
name: string
/** Short description of the screen. */
description: string
/** Category key; presumably used to look up colors and labels (confirm against the consumers). */
category: string
/** Icon for the screen; presumably a single emoji character (confirm against the screen data). */
icon: string
/** Optional URL of the screen; may be omitted. */
url?: string
}
/**
 * A connection from one screen to another.
 */
export interface ConnectionDef {
/** Id of the screen the connection starts at; presumably a ScreenDefinition.id (confirm). */
source: string
/** Id of the screen the connection leads to; presumably a ScreenDefinition.id (confirm). */
target: string
/** Optional text for the connection; may be omitted. */
label?: string
}
/** Selects which flow is meant: 'studio' or 'admin'. Presumably chosen by the screen-flow page to pick the matching data set (confirm). */
export type FlowType = 'studio' | 'admin'

Some files were not shown because too many files have changed in this diff Show More