[split-required] Split 700-870 LOC files across all services
backend-lehrer (11 files): - llm_gateway/routes/schools.py (867 → 5), recording_api.py (848 → 6) - messenger_api.py (840 → 5), print_generator.py (824 → 5) - unit_analytics_api.py (751 → 5), classroom/routes/context.py (726 → 4) - llm_gateway/routes/edu_search_seeds.py (710 → 4) klausur-service (12 files): - ocr_labeling_api.py (845 → 4), metrics_db.py (833 → 4) - legal_corpus_api.py (790 → 4), page_crop.py (758 → 3) - mail/ai_service.py (747 → 4), github_crawler.py (767 → 3) - trocr_service.py (730 → 4), full_compliance_pipeline.py (723 → 4) - dsfa_rag_api.py (715 → 4), ocr_pipeline_auto.py (705 → 4) website (6 pages): - audit-checklist (867 → 8), content (806 → 6) - screen-flow (790 → 4), scraper (789 → 5) - zeugnisse (776 → 5), modules (745 → 4) Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -39,6 +39,9 @@
|
||||
**/lib/sdk/vvt-baseline-catalog.ts | owner=admin-lehrer | reason=Pure data catalog (630 LOC, BaselineTemplate[] literals) | review=2027-01-01
|
||||
**/lib/sdk/loeschfristen-baseline-catalog.ts | owner=admin-lehrer | reason=Pure data catalog (578 LOC, retention period templates) | review=2027-01-01
|
||||
|
||||
# Single SSE generator orchestrating 6 pipeline steps — cannot split generator context
|
||||
**/ocr_pipeline_auto_steps.py | owner=klausur | reason=run_auto is a single async generator yielding SSE events across 6 steps (528 LOC) | review=2026-10-01
|
||||
|
||||
# Legacy — TEMPORAER bis Refactoring abgeschlossen
|
||||
# Dateien hier werden Phase fuer Phase abgearbeitet und entfernt.
|
||||
# KEINE neuen Ausnahmen ohne [guardrail-change] Commit-Marker!
|
||||
|
||||
@@ -0,0 +1,193 @@
|
||||
"""
|
||||
AI Processing - Print Version Generator: Cloze (Lueckentext).
|
||||
|
||||
Generates printable HTML for cloze/fill-in-the-blank worksheets.
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
import json
|
||||
import random
|
||||
import logging
|
||||
|
||||
from .core import BEREINIGT_DIR
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def generate_print_version_cloze(cloze_path: Path, include_answers: bool = False) -> Path:
|
||||
"""
|
||||
Generiert eine druckbare HTML-Version der Lueckentexte.
|
||||
|
||||
Args:
|
||||
cloze_path: Pfad zur *_cloze.json Datei
|
||||
include_answers: True fuer Loesungsblatt (fuer Eltern)
|
||||
|
||||
Returns:
|
||||
Pfad zur generierten HTML-Datei
|
||||
"""
|
||||
if not cloze_path.exists():
|
||||
raise FileNotFoundError(f"Cloze-Datei nicht gefunden: {cloze_path}")
|
||||
|
||||
cloze_data = json.loads(cloze_path.read_text(encoding="utf-8"))
|
||||
items = cloze_data.get("cloze_items", [])
|
||||
metadata = cloze_data.get("metadata", {})
|
||||
|
||||
title = metadata.get("source_title", "Arbeitsblatt")
|
||||
subject = metadata.get("subject", "")
|
||||
grade = metadata.get("grade_level", "")
|
||||
total_gaps = metadata.get("total_gaps", 0)
|
||||
|
||||
html_parts = []
|
||||
html_parts.append("""<!DOCTYPE html>
|
||||
<html lang="de">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<title>""" + title + """ - Lueckentext</title>
|
||||
<style>
|
||||
@media print {
|
||||
.no-print { display: none; }
|
||||
.page-break { page-break-before: always; }
|
||||
}
|
||||
body {
|
||||
font-family: Arial, sans-serif;
|
||||
max-width: 800px;
|
||||
margin: 40px auto;
|
||||
padding: 20px;
|
||||
line-height: 1.8;
|
||||
}
|
||||
h1 { font-size: 24px; margin-bottom: 8px; }
|
||||
.meta { color: #666; margin-bottom: 24px; }
|
||||
.cloze-item {
|
||||
margin-bottom: 24px;
|
||||
padding: 16px;
|
||||
background: #f9f9f9;
|
||||
border-radius: 8px;
|
||||
}
|
||||
.cloze-number {
|
||||
font-weight: bold;
|
||||
color: #333;
|
||||
margin-bottom: 8px;
|
||||
}
|
||||
.cloze-sentence {
|
||||
font-size: 16px;
|
||||
line-height: 2;
|
||||
}
|
||||
.gap {
|
||||
display: inline-block;
|
||||
min-width: 80px;
|
||||
border-bottom: 2px solid #333;
|
||||
margin: 0 4px;
|
||||
text-align: center;
|
||||
}
|
||||
.gap-filled {
|
||||
display: inline-block;
|
||||
padding: 2px 8px;
|
||||
background: #e8f5e9;
|
||||
border: 1px solid #4caf50;
|
||||
border-radius: 4px;
|
||||
font-weight: bold;
|
||||
}
|
||||
.translation {
|
||||
margin-top: 12px;
|
||||
padding: 8px;
|
||||
background: #e3f2fd;
|
||||
border-left: 3px solid #2196f3;
|
||||
font-size: 14px;
|
||||
color: #555;
|
||||
}
|
||||
.translation-label {
|
||||
font-size: 12px;
|
||||
color: #777;
|
||||
margin-bottom: 4px;
|
||||
}
|
||||
.word-bank {
|
||||
margin-top: 32px;
|
||||
padding: 16px;
|
||||
background: #fff3e0;
|
||||
border-radius: 8px;
|
||||
}
|
||||
.word-bank-title {
|
||||
font-weight: bold;
|
||||
margin-bottom: 12px;
|
||||
}
|
||||
.word {
|
||||
display: inline-block;
|
||||
padding: 4px 12px;
|
||||
margin: 4px;
|
||||
background: white;
|
||||
border: 1px solid #ddd;
|
||||
border-radius: 4px;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
""")
|
||||
|
||||
# Header
|
||||
version_text = "Loesungsblatt" if include_answers else "Lueckentext"
|
||||
html_parts.append(f"<h1>{title} - {version_text}</h1>")
|
||||
meta_parts = []
|
||||
if subject:
|
||||
meta_parts.append(f"Fach: {subject}")
|
||||
if grade:
|
||||
meta_parts.append(f"Klasse: {grade}")
|
||||
meta_parts.append(f"Luecken gesamt: {total_gaps}")
|
||||
html_parts.append(f"<div class='meta'>{' | '.join(meta_parts)}</div>")
|
||||
|
||||
# Sammle alle Lueckenwoerter fuer Wortbank
|
||||
all_words = []
|
||||
|
||||
# Lueckentexte
|
||||
for idx, item in enumerate(items, 1):
|
||||
html_parts.append("<div class='cloze-item'>")
|
||||
html_parts.append(f"<div class='cloze-number'>{idx}.</div>")
|
||||
|
||||
gaps = item.get("gaps", [])
|
||||
sentence = item.get("sentence_with_gaps", "")
|
||||
|
||||
if include_answers:
|
||||
# Loesungsblatt: Luecken mit Antworten fuellen
|
||||
for gap in gaps:
|
||||
word = gap.get("word", "")
|
||||
sentence = sentence.replace("___", f"<span class='gap-filled'>{word}</span>", 1)
|
||||
else:
|
||||
# Fragenblatt: Luecken als Linien
|
||||
sentence = sentence.replace("___", "<span class='gap'> </span>")
|
||||
# Woerter fuer Wortbank sammeln
|
||||
for gap in gaps:
|
||||
all_words.append(gap.get("word", ""))
|
||||
|
||||
html_parts.append(f"<div class='cloze-sentence'>{sentence}</div>")
|
||||
|
||||
# Uebersetzung anzeigen
|
||||
translation = item.get("translation", {})
|
||||
if translation:
|
||||
lang_name = translation.get("language_name", "Uebersetzung")
|
||||
full_sentence = translation.get("full_sentence", "")
|
||||
if full_sentence:
|
||||
html_parts.append("<div class='translation'>")
|
||||
html_parts.append(f"<div class='translation-label'>{lang_name}:</div>")
|
||||
html_parts.append(full_sentence)
|
||||
html_parts.append("</div>")
|
||||
|
||||
html_parts.append("</div>")
|
||||
|
||||
# Wortbank (nur fuer Fragenblatt)
|
||||
if not include_answers and all_words:
|
||||
random.shuffle(all_words) # Mische die Woerter
|
||||
html_parts.append("<div class='word-bank'>")
|
||||
html_parts.append("<div class='word-bank-title'>Wortbank (diese Woerter fehlen):</div>")
|
||||
for word in all_words:
|
||||
html_parts.append(f"<span class='word'>{word}</span>")
|
||||
html_parts.append("</div>")
|
||||
|
||||
html_parts.append("</body></html>")
|
||||
|
||||
# Speichern
|
||||
suffix = "_cloze_solutions.html" if include_answers else "_cloze_print.html"
|
||||
out_name = cloze_path.stem.replace("_cloze", "") + suffix
|
||||
out_path = BEREINIGT_DIR / out_name
|
||||
out_path.write_text("\n".join(html_parts), encoding="utf-8")
|
||||
|
||||
logger.info(f"Cloze Print-Version gespeichert: {out_path.name}")
|
||||
return out_path
|
||||
@@ -1,824 +1,22 @@
|
||||
"""
|
||||
AI Processing - Print Version Generator.
|
||||
AI Processing - Print Version Generator — Barrel Re-export.
|
||||
|
||||
Generiert druckbare HTML-Versionen für verschiedene Arbeitsblatt-Typen.
|
||||
Generiert druckbare HTML-Versionen fuer verschiedene Arbeitsblatt-Typen.
|
||||
Split into:
|
||||
- print_qa.py: Q&A print generation
|
||||
- print_cloze.py: Cloze/Lueckentext print generation
|
||||
- print_mc.py: Multiple Choice print generation
|
||||
- print_worksheet.py: General worksheet print generation
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
import json
|
||||
import random
|
||||
import logging
|
||||
|
||||
from .core import BEREINIGT_DIR
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def generate_print_version_qa(qa_path: Path, include_answers: bool = False) -> Path:
|
||||
"""
|
||||
Generiert eine druckbare HTML-Version der Frage-Antwort-Paare.
|
||||
|
||||
Args:
|
||||
qa_path: Pfad zur *_qa.json Datei
|
||||
include_answers: True für Lösungsblatt (für Eltern)
|
||||
|
||||
Returns:
|
||||
Pfad zur generierten HTML-Datei
|
||||
"""
|
||||
if not qa_path.exists():
|
||||
raise FileNotFoundError(f"Q&A-Datei nicht gefunden: {qa_path}")
|
||||
|
||||
qa_data = json.loads(qa_path.read_text(encoding="utf-8"))
|
||||
items = qa_data.get("qa_items", [])
|
||||
metadata = qa_data.get("metadata", {})
|
||||
|
||||
title = metadata.get("source_title", "Arbeitsblatt")
|
||||
subject = metadata.get("subject", "")
|
||||
grade = metadata.get("grade_level", "")
|
||||
|
||||
html_parts = []
|
||||
html_parts.append("""<!DOCTYPE html>
|
||||
<html lang="de">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<title>""" + title + """ - Fragen</title>
|
||||
<style>
|
||||
@media print {
|
||||
.no-print { display: none; }
|
||||
.page-break { page-break-before: always; }
|
||||
}
|
||||
body {
|
||||
font-family: Arial, sans-serif;
|
||||
max-width: 800px;
|
||||
margin: 40px auto;
|
||||
padding: 20px;
|
||||
line-height: 1.6;
|
||||
}
|
||||
h1 { font-size: 24px; margin-bottom: 8px; }
|
||||
.meta { color: #666; margin-bottom: 24px; }
|
||||
.question-block {
|
||||
margin-bottom: 32px;
|
||||
padding-bottom: 16px;
|
||||
border-bottom: 1px dashed #ccc;
|
||||
}
|
||||
.question-number {
|
||||
font-weight: bold;
|
||||
color: #333;
|
||||
}
|
||||
.question-text {
|
||||
font-size: 16px;
|
||||
margin: 8px 0;
|
||||
}
|
||||
.answer-space {
|
||||
border: 1px solid #ddd;
|
||||
min-height: 60px;
|
||||
margin-top: 12px;
|
||||
background: #fafafa;
|
||||
}
|
||||
.answer-lines {
|
||||
margin-top: 12px;
|
||||
}
|
||||
.answer-line {
|
||||
border-bottom: 1px solid #999;
|
||||
height: 28px;
|
||||
}
|
||||
.answer {
|
||||
margin-top: 8px;
|
||||
padding: 8px;
|
||||
background: #e8f5e9;
|
||||
border-left: 3px solid #4caf50;
|
||||
}
|
||||
.key-terms {
|
||||
font-size: 12px;
|
||||
color: #666;
|
||||
margin-top: 8px;
|
||||
}
|
||||
.key-terms span {
|
||||
background: #fff3e0;
|
||||
padding: 2px 6px;
|
||||
border-radius: 3px;
|
||||
margin-right: 4px;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
""")
|
||||
|
||||
# Header
|
||||
version_text = "Lösungsblatt" if include_answers else "Fragenblatt"
|
||||
html_parts.append(f"<h1>{title} - {version_text}</h1>")
|
||||
meta_parts = []
|
||||
if subject:
|
||||
meta_parts.append(f"Fach: {subject}")
|
||||
if grade:
|
||||
meta_parts.append(f"Klasse: {grade}")
|
||||
meta_parts.append(f"Anzahl Fragen: {len(items)}")
|
||||
html_parts.append(f"<div class='meta'>{' | '.join(meta_parts)}</div>")
|
||||
|
||||
# Fragen
|
||||
for idx, item in enumerate(items, 1):
|
||||
html_parts.append("<div class='question-block'>")
|
||||
html_parts.append(f"<div class='question-number'>Frage {idx}</div>")
|
||||
html_parts.append(f"<div class='question-text'>{item.get('question', '')}</div>")
|
||||
|
||||
if include_answers:
|
||||
# Lösungsblatt: Antwort anzeigen
|
||||
html_parts.append(f"<div class='answer'><strong>Antwort:</strong> {item.get('answer', '')}</div>")
|
||||
# Schlüsselbegriffe
|
||||
key_terms = item.get("key_terms", [])
|
||||
if key_terms:
|
||||
terms_html = " ".join([f"<span>{term}</span>" for term in key_terms])
|
||||
html_parts.append(f"<div class='key-terms'>Wichtige Begriffe: {terms_html}</div>")
|
||||
else:
|
||||
# Fragenblatt: Antwortlinien
|
||||
html_parts.append("<div class='answer-lines'>")
|
||||
for _ in range(3):
|
||||
html_parts.append("<div class='answer-line'></div>")
|
||||
html_parts.append("</div>")
|
||||
|
||||
html_parts.append("</div>")
|
||||
|
||||
html_parts.append("</body></html>")
|
||||
|
||||
# Speichern
|
||||
suffix = "_qa_solutions.html" if include_answers else "_qa_print.html"
|
||||
out_name = qa_path.stem.replace("_qa", "") + suffix
|
||||
out_path = BEREINIGT_DIR / out_name
|
||||
out_path.write_text("\n".join(html_parts), encoding="utf-8")
|
||||
|
||||
logger.info(f"Print-Version gespeichert: {out_path.name}")
|
||||
return out_path
|
||||
|
||||
|
||||
def generate_print_version_cloze(cloze_path: Path, include_answers: bool = False) -> Path:
|
||||
"""
|
||||
Generiert eine druckbare HTML-Version der Lückentexte.
|
||||
|
||||
Args:
|
||||
cloze_path: Pfad zur *_cloze.json Datei
|
||||
include_answers: True für Lösungsblatt (für Eltern)
|
||||
|
||||
Returns:
|
||||
Pfad zur generierten HTML-Datei
|
||||
"""
|
||||
if not cloze_path.exists():
|
||||
raise FileNotFoundError(f"Cloze-Datei nicht gefunden: {cloze_path}")
|
||||
|
||||
cloze_data = json.loads(cloze_path.read_text(encoding="utf-8"))
|
||||
items = cloze_data.get("cloze_items", [])
|
||||
metadata = cloze_data.get("metadata", {})
|
||||
|
||||
title = metadata.get("source_title", "Arbeitsblatt")
|
||||
subject = metadata.get("subject", "")
|
||||
grade = metadata.get("grade_level", "")
|
||||
total_gaps = metadata.get("total_gaps", 0)
|
||||
|
||||
html_parts = []
|
||||
html_parts.append("""<!DOCTYPE html>
|
||||
<html lang="de">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<title>""" + title + """ - Lückentext</title>
|
||||
<style>
|
||||
@media print {
|
||||
.no-print { display: none; }
|
||||
.page-break { page-break-before: always; }
|
||||
}
|
||||
body {
|
||||
font-family: Arial, sans-serif;
|
||||
max-width: 800px;
|
||||
margin: 40px auto;
|
||||
padding: 20px;
|
||||
line-height: 1.8;
|
||||
}
|
||||
h1 { font-size: 24px; margin-bottom: 8px; }
|
||||
.meta { color: #666; margin-bottom: 24px; }
|
||||
.cloze-item {
|
||||
margin-bottom: 24px;
|
||||
padding: 16px;
|
||||
background: #f9f9f9;
|
||||
border-radius: 8px;
|
||||
}
|
||||
.cloze-number {
|
||||
font-weight: bold;
|
||||
color: #333;
|
||||
margin-bottom: 8px;
|
||||
}
|
||||
.cloze-sentence {
|
||||
font-size: 16px;
|
||||
line-height: 2;
|
||||
}
|
||||
.gap {
|
||||
display: inline-block;
|
||||
min-width: 80px;
|
||||
border-bottom: 2px solid #333;
|
||||
margin: 0 4px;
|
||||
text-align: center;
|
||||
}
|
||||
.gap-filled {
|
||||
display: inline-block;
|
||||
padding: 2px 8px;
|
||||
background: #e8f5e9;
|
||||
border: 1px solid #4caf50;
|
||||
border-radius: 4px;
|
||||
font-weight: bold;
|
||||
}
|
||||
.translation {
|
||||
margin-top: 12px;
|
||||
padding: 8px;
|
||||
background: #e3f2fd;
|
||||
border-left: 3px solid #2196f3;
|
||||
font-size: 14px;
|
||||
color: #555;
|
||||
}
|
||||
.translation-label {
|
||||
font-size: 12px;
|
||||
color: #777;
|
||||
margin-bottom: 4px;
|
||||
}
|
||||
.word-bank {
|
||||
margin-top: 32px;
|
||||
padding: 16px;
|
||||
background: #fff3e0;
|
||||
border-radius: 8px;
|
||||
}
|
||||
.word-bank-title {
|
||||
font-weight: bold;
|
||||
margin-bottom: 12px;
|
||||
}
|
||||
.word {
|
||||
display: inline-block;
|
||||
padding: 4px 12px;
|
||||
margin: 4px;
|
||||
background: white;
|
||||
border: 1px solid #ddd;
|
||||
border-radius: 4px;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
""")
|
||||
|
||||
# Header
|
||||
version_text = "Lösungsblatt" if include_answers else "Lückentext"
|
||||
html_parts.append(f"<h1>{title} - {version_text}</h1>")
|
||||
meta_parts = []
|
||||
if subject:
|
||||
meta_parts.append(f"Fach: {subject}")
|
||||
if grade:
|
||||
meta_parts.append(f"Klasse: {grade}")
|
||||
meta_parts.append(f"Lücken gesamt: {total_gaps}")
|
||||
html_parts.append(f"<div class='meta'>{' | '.join(meta_parts)}</div>")
|
||||
|
||||
# Sammle alle Lückenwörter für Wortbank
|
||||
all_words = []
|
||||
|
||||
# Lückentexte
|
||||
for idx, item in enumerate(items, 1):
|
||||
html_parts.append("<div class='cloze-item'>")
|
||||
html_parts.append(f"<div class='cloze-number'>{idx}.</div>")
|
||||
|
||||
gaps = item.get("gaps", [])
|
||||
sentence = item.get("sentence_with_gaps", "")
|
||||
|
||||
if include_answers:
|
||||
# Lösungsblatt: Lücken mit Antworten füllen
|
||||
for gap in gaps:
|
||||
word = gap.get("word", "")
|
||||
sentence = sentence.replace("___", f"<span class='gap-filled'>{word}</span>", 1)
|
||||
else:
|
||||
# Fragenblatt: Lücken als Linien
|
||||
sentence = sentence.replace("___", "<span class='gap'> </span>")
|
||||
# Wörter für Wortbank sammeln
|
||||
for gap in gaps:
|
||||
all_words.append(gap.get("word", ""))
|
||||
|
||||
html_parts.append(f"<div class='cloze-sentence'>{sentence}</div>")
|
||||
|
||||
# Übersetzung anzeigen
|
||||
translation = item.get("translation", {})
|
||||
if translation:
|
||||
lang_name = translation.get("language_name", "Übersetzung")
|
||||
full_sentence = translation.get("full_sentence", "")
|
||||
if full_sentence:
|
||||
html_parts.append("<div class='translation'>")
|
||||
html_parts.append(f"<div class='translation-label'>{lang_name}:</div>")
|
||||
html_parts.append(full_sentence)
|
||||
html_parts.append("</div>")
|
||||
|
||||
html_parts.append("</div>")
|
||||
|
||||
# Wortbank (nur für Fragenblatt)
|
||||
if not include_answers and all_words:
|
||||
random.shuffle(all_words) # Mische die Wörter
|
||||
html_parts.append("<div class='word-bank'>")
|
||||
html_parts.append("<div class='word-bank-title'>Wortbank (diese Wörter fehlen):</div>")
|
||||
for word in all_words:
|
||||
html_parts.append(f"<span class='word'>{word}</span>")
|
||||
html_parts.append("</div>")
|
||||
|
||||
html_parts.append("</body></html>")
|
||||
|
||||
# Speichern
|
||||
suffix = "_cloze_solutions.html" if include_answers else "_cloze_print.html"
|
||||
out_name = cloze_path.stem.replace("_cloze", "") + suffix
|
||||
out_path = BEREINIGT_DIR / out_name
|
||||
out_path.write_text("\n".join(html_parts), encoding="utf-8")
|
||||
|
||||
logger.info(f"Cloze Print-Version gespeichert: {out_path.name}")
|
||||
return out_path
|
||||
|
||||
|
||||
def generate_print_version_mc(mc_path: Path, include_answers: bool = False) -> str:
|
||||
"""
|
||||
Generiert eine druckbare HTML-Version der Multiple-Choice-Fragen.
|
||||
|
||||
Args:
|
||||
mc_path: Pfad zur *_mc.json Datei
|
||||
include_answers: True für Lösungsblatt mit markierten richtigen Antworten
|
||||
|
||||
Returns:
|
||||
HTML-String (zum direkten Ausliefern)
|
||||
"""
|
||||
if not mc_path.exists():
|
||||
raise FileNotFoundError(f"MC-Datei nicht gefunden: {mc_path}")
|
||||
|
||||
mc_data = json.loads(mc_path.read_text(encoding="utf-8"))
|
||||
questions = mc_data.get("questions", [])
|
||||
metadata = mc_data.get("metadata", {})
|
||||
|
||||
title = metadata.get("source_title", "Arbeitsblatt")
|
||||
subject = metadata.get("subject", "")
|
||||
grade = metadata.get("grade_level", "")
|
||||
|
||||
html_parts = []
|
||||
html_parts.append("""<!DOCTYPE html>
|
||||
<html lang="de">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<title>""" + title + """ - Multiple Choice</title>
|
||||
<style>
|
||||
@media print {
|
||||
.no-print { display: none; }
|
||||
.page-break { page-break-before: always; }
|
||||
body { font-size: 14pt; }
|
||||
}
|
||||
body {
|
||||
font-family: Arial, Helvetica, sans-serif;
|
||||
max-width: 800px;
|
||||
margin: 40px auto;
|
||||
padding: 20px;
|
||||
line-height: 1.6;
|
||||
color: #000;
|
||||
}
|
||||
h1 {
|
||||
font-size: 28px;
|
||||
margin-bottom: 8px;
|
||||
border-bottom: 2px solid #000;
|
||||
padding-bottom: 8px;
|
||||
}
|
||||
.meta {
|
||||
color: #333;
|
||||
margin-bottom: 32px;
|
||||
font-size: 14px;
|
||||
}
|
||||
.instructions {
|
||||
background: #f5f5f5;
|
||||
padding: 12px 16px;
|
||||
border-radius: 4px;
|
||||
margin-bottom: 24px;
|
||||
font-size: 14px;
|
||||
}
|
||||
.question-block {
|
||||
margin-bottom: 28px;
|
||||
padding-bottom: 16px;
|
||||
border-bottom: 1px solid #ddd;
|
||||
}
|
||||
.question-number {
|
||||
font-weight: bold;
|
||||
font-size: 18px;
|
||||
color: #000;
|
||||
margin-bottom: 8px;
|
||||
}
|
||||
.question-text {
|
||||
font-size: 16px;
|
||||
margin: 8px 0 16px 0;
|
||||
line-height: 1.5;
|
||||
}
|
||||
.options {
|
||||
margin-left: 20px;
|
||||
}
|
||||
.option {
|
||||
display: flex;
|
||||
align-items: flex-start;
|
||||
margin-bottom: 12px;
|
||||
padding: 8px 12px;
|
||||
border: 1px solid #ccc;
|
||||
border-radius: 4px;
|
||||
background: #fff;
|
||||
}
|
||||
.option-correct {
|
||||
background: #e8f5e9;
|
||||
border-color: #4caf50;
|
||||
border-width: 2px;
|
||||
}
|
||||
.option-checkbox {
|
||||
width: 20px;
|
||||
height: 20px;
|
||||
border: 2px solid #333;
|
||||
border-radius: 50%;
|
||||
margin-right: 12px;
|
||||
flex-shrink: 0;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
}
|
||||
.option-checkbox.checked::after {
|
||||
content: "✓";
|
||||
font-weight: bold;
|
||||
color: #4caf50;
|
||||
}
|
||||
.option-label {
|
||||
font-weight: bold;
|
||||
margin-right: 8px;
|
||||
min-width: 24px;
|
||||
}
|
||||
.option-text {
|
||||
flex: 1;
|
||||
}
|
||||
.explanation {
|
||||
margin-top: 8px;
|
||||
padding: 8px 12px;
|
||||
background: #e3f2fd;
|
||||
border-left: 3px solid #2196f3;
|
||||
font-size: 13px;
|
||||
color: #333;
|
||||
}
|
||||
.answer-key {
|
||||
margin-top: 40px;
|
||||
padding: 16px;
|
||||
background: #f5f5f5;
|
||||
border-radius: 8px;
|
||||
}
|
||||
.answer-key-title {
|
||||
font-weight: bold;
|
||||
font-size: 18px;
|
||||
margin-bottom: 12px;
|
||||
border-bottom: 1px solid #999;
|
||||
padding-bottom: 8px;
|
||||
}
|
||||
.answer-key-grid {
|
||||
display: grid;
|
||||
grid-template-columns: repeat(5, 1fr);
|
||||
gap: 8px;
|
||||
}
|
||||
.answer-key-item {
|
||||
padding: 8px;
|
||||
text-align: center;
|
||||
background: white;
|
||||
border: 1px solid #ddd;
|
||||
border-radius: 4px;
|
||||
}
|
||||
.answer-key-q {
|
||||
font-weight: bold;
|
||||
}
|
||||
.answer-key-a {
|
||||
color: #4caf50;
|
||||
font-weight: bold;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
""")
|
||||
|
||||
# Header
|
||||
version_text = "Lösungsblatt" if include_answers else "Multiple Choice Test"
|
||||
html_parts.append(f"<h1>{title}</h1>")
|
||||
html_parts.append(f"<div class='meta'><strong>{version_text}</strong>")
|
||||
if subject:
|
||||
html_parts.append(f" | Fach: {subject}")
|
||||
if grade:
|
||||
html_parts.append(f" | Klasse: {grade}")
|
||||
html_parts.append(f" | Anzahl Fragen: {len(questions)}</div>")
|
||||
|
||||
if not include_answers:
|
||||
html_parts.append("<div class='instructions'>")
|
||||
html_parts.append("<strong>Anleitung:</strong> Kreuze bei jeder Frage die richtige Antwort an. ")
|
||||
html_parts.append("Es ist immer nur eine Antwort richtig.")
|
||||
html_parts.append("</div>")
|
||||
|
||||
# Fragen
|
||||
for idx, q in enumerate(questions, 1):
|
||||
html_parts.append("<div class='question-block'>")
|
||||
html_parts.append(f"<div class='question-number'>Frage {idx}</div>")
|
||||
html_parts.append(f"<div class='question-text'>{q.get('question', '')}</div>")
|
||||
|
||||
html_parts.append("<div class='options'>")
|
||||
correct_answer = q.get("correct_answer", "")
|
||||
|
||||
for opt in q.get("options", []):
|
||||
opt_id = opt.get("id", "")
|
||||
is_correct = opt_id == correct_answer
|
||||
|
||||
opt_class = "option"
|
||||
checkbox_class = "option-checkbox"
|
||||
if include_answers and is_correct:
|
||||
opt_class += " option-correct"
|
||||
checkbox_class += " checked"
|
||||
|
||||
html_parts.append(f"<div class='{opt_class}'>")
|
||||
html_parts.append(f"<div class='{checkbox_class}'></div>")
|
||||
html_parts.append(f"<span class='option-label'>{opt_id})</span>")
|
||||
html_parts.append(f"<span class='option-text'>{opt.get('text', '')}</span>")
|
||||
html_parts.append("</div>")
|
||||
|
||||
html_parts.append("</div>")
|
||||
|
||||
# Erklärung nur bei Lösungsblatt
|
||||
if include_answers and q.get("explanation"):
|
||||
html_parts.append(f"<div class='explanation'><strong>Erklärung:</strong> {q.get('explanation')}</div>")
|
||||
|
||||
html_parts.append("</div>")
|
||||
|
||||
# Lösungsschlüssel (kompakt) - nur bei Lösungsblatt
|
||||
if include_answers:
|
||||
html_parts.append("<div class='answer-key'>")
|
||||
html_parts.append("<div class='answer-key-title'>Lösungsschlüssel</div>")
|
||||
html_parts.append("<div class='answer-key-grid'>")
|
||||
for idx, q in enumerate(questions, 1):
|
||||
html_parts.append("<div class='answer-key-item'>")
|
||||
html_parts.append(f"<span class='answer-key-q'>{idx}.</span> ")
|
||||
html_parts.append(f"<span class='answer-key-a'>{q.get('correct_answer', '')}</span>")
|
||||
html_parts.append("</div>")
|
||||
html_parts.append("</div>")
|
||||
html_parts.append("</div>")
|
||||
|
||||
html_parts.append("</body></html>")
|
||||
|
||||
return "\n".join(html_parts)
|
||||
|
||||
|
||||
def generate_print_version_worksheet(analysis_path: Path) -> str:
|
||||
"""
|
||||
Generiert eine druckoptimierte HTML-Version des Arbeitsblatts.
|
||||
|
||||
Eigenschaften:
|
||||
- Große, gut lesbare Schrift (16pt)
|
||||
- Schwarz-weiß / Graustufen-tauglich
|
||||
- Klare Struktur für Druck
|
||||
- Keine interaktiven Elemente
|
||||
|
||||
Args:
|
||||
analysis_path: Pfad zur *_analyse.json Datei
|
||||
|
||||
Returns:
|
||||
HTML-String zum direkten Ausliefern
|
||||
"""
|
||||
if not analysis_path.exists():
|
||||
raise FileNotFoundError(f"Analysedatei nicht gefunden: {analysis_path}")
|
||||
|
||||
try:
|
||||
data = json.loads(analysis_path.read_text(encoding="utf-8"))
|
||||
except json.JSONDecodeError as e:
|
||||
raise RuntimeError(f"Analyse-Datei enthält kein gültiges JSON: {analysis_path}\n{e}") from e
|
||||
|
||||
title = data.get("title") or "Arbeitsblatt"
|
||||
subject = data.get("subject") or ""
|
||||
grade_level = data.get("grade_level") or ""
|
||||
instructions = data.get("instructions") or ""
|
||||
tasks = data.get("tasks", []) or []
|
||||
canonical_text = data.get("canonical_text") or ""
|
||||
printed_blocks = data.get("printed_blocks") or []
|
||||
|
||||
html_parts = []
|
||||
html_parts.append("""<!DOCTYPE html>
|
||||
<html lang="de">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<title>""" + title + """</title>
|
||||
<style>
|
||||
@page {
|
||||
size: A4;
|
||||
margin: 20mm;
|
||||
}
|
||||
@media print {
|
||||
body {
|
||||
font-size: 14pt !important;
|
||||
-webkit-print-color-adjust: exact;
|
||||
print-color-adjust: exact;
|
||||
}
|
||||
.no-print { display: none !important; }
|
||||
.page-break { page-break-before: always; }
|
||||
}
|
||||
* { box-sizing: border-box; }
|
||||
body {
|
||||
font-family: Arial, "Helvetica Neue", sans-serif;
|
||||
max-width: 800px;
|
||||
margin: 0 auto;
|
||||
padding: 30px;
|
||||
line-height: 1.7;
|
||||
font-size: 16px;
|
||||
color: #000;
|
||||
background: #fff;
|
||||
}
|
||||
h1 {
|
||||
font-size: 28px;
|
||||
margin: 0 0 8px 0;
|
||||
padding-bottom: 8px;
|
||||
border-bottom: 3px solid #000;
|
||||
}
|
||||
h2 {
|
||||
font-size: 20px;
|
||||
margin: 28px 0 12px 0;
|
||||
padding-bottom: 4px;
|
||||
border-bottom: 1px solid #666;
|
||||
}
|
||||
.meta {
|
||||
font-size: 14px;
|
||||
color: #333;
|
||||
margin-bottom: 20px;
|
||||
padding: 8px 0;
|
||||
}
|
||||
.meta span {
|
||||
margin-right: 20px;
|
||||
}
|
||||
.instructions {
|
||||
margin: 20px 0;
|
||||
padding: 16px;
|
||||
border: 2px solid #333;
|
||||
background: #f5f5f5;
|
||||
font-size: 15px;
|
||||
}
|
||||
.instructions-label {
|
||||
font-weight: bold;
|
||||
margin-bottom: 8px;
|
||||
}
|
||||
.text-section {
|
||||
margin: 24px 0;
|
||||
}
|
||||
.text-block {
|
||||
margin-bottom: 16px;
|
||||
text-align: justify;
|
||||
}
|
||||
.text-block-title {
|
||||
font-weight: bold;
|
||||
font-size: 17px;
|
||||
margin-bottom: 8px;
|
||||
}
|
||||
.task-section {
|
||||
margin-top: 32px;
|
||||
}
|
||||
.task {
|
||||
margin-bottom: 24px;
|
||||
padding: 16px;
|
||||
border: 1px solid #999;
|
||||
background: #fafafa;
|
||||
}
|
||||
.task-header {
|
||||
font-weight: bold;
|
||||
font-size: 16px;
|
||||
margin-bottom: 12px;
|
||||
padding-bottom: 8px;
|
||||
border-bottom: 1px dashed #666;
|
||||
}
|
||||
.task-content {
|
||||
font-size: 15px;
|
||||
}
|
||||
.gap-line {
|
||||
display: inline-block;
|
||||
border-bottom: 2px solid #000;
|
||||
min-width: 100px;
|
||||
margin: 0 6px;
|
||||
}
|
||||
.answer-lines {
|
||||
margin-top: 16px;
|
||||
}
|
||||
.answer-line {
|
||||
border-bottom: 1px solid #333;
|
||||
height: 36px;
|
||||
margin-bottom: 4px;
|
||||
}
|
||||
.footer {
|
||||
margin-top: 40px;
|
||||
padding-top: 16px;
|
||||
border-top: 1px solid #ccc;
|
||||
font-size: 11px;
|
||||
color: #666;
|
||||
text-align: center;
|
||||
}
|
||||
/* Print Button - versteckt beim Drucken */
|
||||
.print-button {
|
||||
position: fixed;
|
||||
top: 20px;
|
||||
right: 20px;
|
||||
padding: 12px 24px;
|
||||
background: #333;
|
||||
color: #fff;
|
||||
border: none;
|
||||
border-radius: 6px;
|
||||
cursor: pointer;
|
||||
font-size: 14px;
|
||||
}
|
||||
.print-button:hover {
|
||||
background: #555;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<button class="print-button no-print" onclick="window.print()">🖨️ Drucken</button>
|
||||
""")
|
||||
|
||||
# Titel
|
||||
html_parts.append(f"<h1>{title}</h1>")
|
||||
|
||||
# Meta-Informationen
|
||||
meta_parts = []
|
||||
if subject:
|
||||
meta_parts.append(f"<span><strong>Fach:</strong> {subject}</span>")
|
||||
if grade_level:
|
||||
meta_parts.append(f"<span><strong>Klasse:</strong> {grade_level}</span>")
|
||||
if meta_parts:
|
||||
html_parts.append(f"<div class='meta'>{''.join(meta_parts)}</div>")
|
||||
|
||||
# Arbeitsanweisung
|
||||
if instructions:
|
||||
html_parts.append("<div class='instructions'>")
|
||||
html_parts.append("<div class='instructions-label'>Arbeitsanweisung:</div>")
|
||||
html_parts.append(f"<div>{instructions}</div>")
|
||||
html_parts.append("</div>")
|
||||
|
||||
# Haupttext / gedruckte Blöcke
|
||||
if printed_blocks:
|
||||
html_parts.append("<section class='text-section'>")
|
||||
for block in printed_blocks:
|
||||
role = (block.get("role") or "body").lower()
|
||||
text = (block.get("text") or "").strip()
|
||||
if not text:
|
||||
continue
|
||||
if role == "title":
|
||||
html_parts.append(f"<div class='text-block'><div class='text-block-title'>{text}</div></div>")
|
||||
else:
|
||||
html_parts.append(f"<div class='text-block'>{text}</div>")
|
||||
html_parts.append("</section>")
|
||||
elif canonical_text:
|
||||
html_parts.append("<section class='text-section'>")
|
||||
paragraphs = [
|
||||
p.strip()
|
||||
for p in canonical_text.replace("\r\n", "\n").split("\n\n")
|
||||
if p.strip()
|
||||
]
|
||||
for p in paragraphs:
|
||||
html_parts.append(f"<div class='text-block'>{p}</div>")
|
||||
html_parts.append("</section>")
|
||||
|
||||
# Aufgaben
|
||||
if tasks:
|
||||
html_parts.append("<section class='task-section'>")
|
||||
html_parts.append("<h2>Aufgaben</h2>")
|
||||
|
||||
for idx, task in enumerate(tasks, start=1):
|
||||
t_type = task.get("type") or "Aufgabe"
|
||||
desc = task.get("description") or ""
|
||||
text_with_gaps = task.get("text_with_gaps")
|
||||
|
||||
html_parts.append("<div class='task'>")
|
||||
|
||||
# Task-Header
|
||||
type_label = {
|
||||
"fill_in_blank": "Lückentext",
|
||||
"multiple_choice": "Multiple Choice",
|
||||
"free_text": "Freitext",
|
||||
"matching": "Zuordnung",
|
||||
"labeling": "Beschriftung",
|
||||
"calculation": "Rechnung",
|
||||
"other": "Aufgabe"
|
||||
}.get(t_type, t_type)
|
||||
|
||||
html_parts.append(f"<div class='task-header'>Aufgabe {idx}: {type_label}</div>")
|
||||
|
||||
if desc:
|
||||
html_parts.append(f"<div class='task-content'>{desc}</div>")
|
||||
|
||||
if text_with_gaps:
|
||||
rendered = text_with_gaps.replace("___", "<span class='gap-line'> </span>")
|
||||
html_parts.append(f"<div class='task-content' style='margin-top:12px;'>{rendered}</div>")
|
||||
|
||||
# Antwortlinien für Freitext-Aufgaben
|
||||
if t_type in ["free_text", "other"] or (not text_with_gaps and not desc):
|
||||
html_parts.append("<div class='answer-lines'>")
|
||||
for _ in range(3):
|
||||
html_parts.append("<div class='answer-line'></div>")
|
||||
html_parts.append("</div>")
|
||||
|
||||
html_parts.append("</div>")
|
||||
|
||||
html_parts.append("</section>")
|
||||
|
||||
# Fußzeile
|
||||
html_parts.append("<div class='footer'>")
|
||||
html_parts.append("Dieses Arbeitsblatt wurde automatisch aus einem Scan rekonstruiert.")
|
||||
html_parts.append("</div>")
|
||||
|
||||
html_parts.append("</body></html>")
|
||||
|
||||
return "\n".join(html_parts)
|
||||
from .print_qa import generate_print_version_qa
|
||||
from .print_cloze import generate_print_version_cloze
|
||||
from .print_mc import generate_print_version_mc
|
||||
from .print_worksheet import generate_print_version_worksheet
|
||||
|
||||
__all__ = [
|
||||
"generate_print_version_qa",
|
||||
"generate_print_version_cloze",
|
||||
"generate_print_version_mc",
|
||||
"generate_print_version_worksheet",
|
||||
]
|
||||
|
||||
@@ -0,0 +1,240 @@
|
||||
"""
|
||||
AI Processing - Print Version Generator: Multiple Choice.
|
||||
|
||||
Generates printable HTML for multiple-choice worksheets.
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
import json
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def generate_print_version_mc(mc_path: Path, include_answers: bool = False) -> str:
|
||||
"""
|
||||
Generiert eine druckbare HTML-Version der Multiple-Choice-Fragen.
|
||||
|
||||
Args:
|
||||
mc_path: Pfad zur *_mc.json Datei
|
||||
include_answers: True fuer Loesungsblatt mit markierten richtigen Antworten
|
||||
|
||||
Returns:
|
||||
HTML-String (zum direkten Ausliefern)
|
||||
"""
|
||||
if not mc_path.exists():
|
||||
raise FileNotFoundError(f"MC-Datei nicht gefunden: {mc_path}")
|
||||
|
||||
mc_data = json.loads(mc_path.read_text(encoding="utf-8"))
|
||||
questions = mc_data.get("questions", [])
|
||||
metadata = mc_data.get("metadata", {})
|
||||
|
||||
title = metadata.get("source_title", "Arbeitsblatt")
|
||||
subject = metadata.get("subject", "")
|
||||
grade = metadata.get("grade_level", "")
|
||||
|
||||
html_parts = []
|
||||
html_parts.append("""<!DOCTYPE html>
|
||||
<html lang="de">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<title>""" + title + """ - Multiple Choice</title>
|
||||
<style>
|
||||
@media print {
|
||||
.no-print { display: none; }
|
||||
.page-break { page-break-before: always; }
|
||||
body { font-size: 14pt; }
|
||||
}
|
||||
body {
|
||||
font-family: Arial, Helvetica, sans-serif;
|
||||
max-width: 800px;
|
||||
margin: 40px auto;
|
||||
padding: 20px;
|
||||
line-height: 1.6;
|
||||
color: #000;
|
||||
}
|
||||
h1 {
|
||||
font-size: 28px;
|
||||
margin-bottom: 8px;
|
||||
border-bottom: 2px solid #000;
|
||||
padding-bottom: 8px;
|
||||
}
|
||||
.meta {
|
||||
color: #333;
|
||||
margin-bottom: 32px;
|
||||
font-size: 14px;
|
||||
}
|
||||
.instructions {
|
||||
background: #f5f5f5;
|
||||
padding: 12px 16px;
|
||||
border-radius: 4px;
|
||||
margin-bottom: 24px;
|
||||
font-size: 14px;
|
||||
}
|
||||
.question-block {
|
||||
margin-bottom: 28px;
|
||||
padding-bottom: 16px;
|
||||
border-bottom: 1px solid #ddd;
|
||||
}
|
||||
.question-number {
|
||||
font-weight: bold;
|
||||
font-size: 18px;
|
||||
color: #000;
|
||||
margin-bottom: 8px;
|
||||
}
|
||||
.question-text {
|
||||
font-size: 16px;
|
||||
margin: 8px 0 16px 0;
|
||||
line-height: 1.5;
|
||||
}
|
||||
.options {
|
||||
margin-left: 20px;
|
||||
}
|
||||
.option {
|
||||
display: flex;
|
||||
align-items: flex-start;
|
||||
margin-bottom: 12px;
|
||||
padding: 8px 12px;
|
||||
border: 1px solid #ccc;
|
||||
border-radius: 4px;
|
||||
background: #fff;
|
||||
}
|
||||
.option-correct {
|
||||
background: #e8f5e9;
|
||||
border-color: #4caf50;
|
||||
border-width: 2px;
|
||||
}
|
||||
.option-checkbox {
|
||||
width: 20px;
|
||||
height: 20px;
|
||||
border: 2px solid #333;
|
||||
border-radius: 50%;
|
||||
margin-right: 12px;
|
||||
flex-shrink: 0;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
}
|
||||
.option-checkbox.checked::after {
|
||||
content: "\u2713";
|
||||
font-weight: bold;
|
||||
color: #4caf50;
|
||||
}
|
||||
.option-label {
|
||||
font-weight: bold;
|
||||
margin-right: 8px;
|
||||
min-width: 24px;
|
||||
}
|
||||
.option-text {
|
||||
flex: 1;
|
||||
}
|
||||
.explanation {
|
||||
margin-top: 8px;
|
||||
padding: 8px 12px;
|
||||
background: #e3f2fd;
|
||||
border-left: 3px solid #2196f3;
|
||||
font-size: 13px;
|
||||
color: #333;
|
||||
}
|
||||
.answer-key {
|
||||
margin-top: 40px;
|
||||
padding: 16px;
|
||||
background: #f5f5f5;
|
||||
border-radius: 8px;
|
||||
}
|
||||
.answer-key-title {
|
||||
font-weight: bold;
|
||||
font-size: 18px;
|
||||
margin-bottom: 12px;
|
||||
border-bottom: 1px solid #999;
|
||||
padding-bottom: 8px;
|
||||
}
|
||||
.answer-key-grid {
|
||||
display: grid;
|
||||
grid-template-columns: repeat(5, 1fr);
|
||||
gap: 8px;
|
||||
}
|
||||
.answer-key-item {
|
||||
padding: 8px;
|
||||
text-align: center;
|
||||
background: white;
|
||||
border: 1px solid #ddd;
|
||||
border-radius: 4px;
|
||||
}
|
||||
.answer-key-q {
|
||||
font-weight: bold;
|
||||
}
|
||||
.answer-key-a {
|
||||
color: #4caf50;
|
||||
font-weight: bold;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
""")
|
||||
|
||||
# Header
|
||||
version_text = "Loesungsblatt" if include_answers else "Multiple Choice Test"
|
||||
html_parts.append(f"<h1>{title}</h1>")
|
||||
html_parts.append(f"<div class='meta'><strong>{version_text}</strong>")
|
||||
if subject:
|
||||
html_parts.append(f" | Fach: {subject}")
|
||||
if grade:
|
||||
html_parts.append(f" | Klasse: {grade}")
|
||||
html_parts.append(f" | Anzahl Fragen: {len(questions)}</div>")
|
||||
|
||||
if not include_answers:
|
||||
html_parts.append("<div class='instructions'>")
|
||||
html_parts.append("<strong>Anleitung:</strong> Kreuze bei jeder Frage die richtige Antwort an. ")
|
||||
html_parts.append("Es ist immer nur eine Antwort richtig.")
|
||||
html_parts.append("</div>")
|
||||
|
||||
# Fragen
|
||||
for idx, q in enumerate(questions, 1):
|
||||
html_parts.append("<div class='question-block'>")
|
||||
html_parts.append(f"<div class='question-number'>Frage {idx}</div>")
|
||||
html_parts.append(f"<div class='question-text'>{q.get('question', '')}</div>")
|
||||
|
||||
html_parts.append("<div class='options'>")
|
||||
correct_answer = q.get("correct_answer", "")
|
||||
|
||||
for opt in q.get("options", []):
|
||||
opt_id = opt.get("id", "")
|
||||
is_correct = opt_id == correct_answer
|
||||
|
||||
opt_class = "option"
|
||||
checkbox_class = "option-checkbox"
|
||||
if include_answers and is_correct:
|
||||
opt_class += " option-correct"
|
||||
checkbox_class += " checked"
|
||||
|
||||
html_parts.append(f"<div class='{opt_class}'>")
|
||||
html_parts.append(f"<div class='{checkbox_class}'></div>")
|
||||
html_parts.append(f"<span class='option-label'>{opt_id})</span>")
|
||||
html_parts.append(f"<span class='option-text'>{opt.get('text', '')}</span>")
|
||||
html_parts.append("</div>")
|
||||
|
||||
html_parts.append("</div>")
|
||||
|
||||
# Erklaerung nur bei Loesungsblatt
|
||||
if include_answers and q.get("explanation"):
|
||||
html_parts.append(f"<div class='explanation'><strong>Erklaerung:</strong> {q.get('explanation')}</div>")
|
||||
|
||||
html_parts.append("</div>")
|
||||
|
||||
# Loesungsschluessel (kompakt) - nur bei Loesungsblatt
|
||||
if include_answers:
|
||||
html_parts.append("<div class='answer-key'>")
|
||||
html_parts.append("<div class='answer-key-title'>Loesungsschluessel</div>")
|
||||
html_parts.append("<div class='answer-key-grid'>")
|
||||
for idx, q in enumerate(questions, 1):
|
||||
html_parts.append("<div class='answer-key-item'>")
|
||||
html_parts.append(f"<span class='answer-key-q'>{idx}.</span> ")
|
||||
html_parts.append(f"<span class='answer-key-a'>{q.get('correct_answer', '')}</span>")
|
||||
html_parts.append("</div>")
|
||||
html_parts.append("</div>")
|
||||
html_parts.append("</div>")
|
||||
|
||||
html_parts.append("</body></html>")
|
||||
|
||||
return "\n".join(html_parts)
|
||||
@@ -0,0 +1,149 @@
|
||||
"""
|
||||
AI Processing - Print Version Generator: Q&A.
|
||||
|
||||
Generates printable HTML for question-answer worksheets.
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
import json
|
||||
import logging
|
||||
|
||||
from .core import BEREINIGT_DIR
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def generate_print_version_qa(qa_path: Path, include_answers: bool = False) -> Path:
|
||||
"""
|
||||
Generiert eine druckbare HTML-Version der Frage-Antwort-Paare.
|
||||
|
||||
Args:
|
||||
qa_path: Pfad zur *_qa.json Datei
|
||||
include_answers: True fuer Loesungsblatt (fuer Eltern)
|
||||
|
||||
Returns:
|
||||
Pfad zur generierten HTML-Datei
|
||||
"""
|
||||
if not qa_path.exists():
|
||||
raise FileNotFoundError(f"Q&A-Datei nicht gefunden: {qa_path}")
|
||||
|
||||
qa_data = json.loads(qa_path.read_text(encoding="utf-8"))
|
||||
items = qa_data.get("qa_items", [])
|
||||
metadata = qa_data.get("metadata", {})
|
||||
|
||||
title = metadata.get("source_title", "Arbeitsblatt")
|
||||
subject = metadata.get("subject", "")
|
||||
grade = metadata.get("grade_level", "")
|
||||
|
||||
html_parts = []
|
||||
html_parts.append("""<!DOCTYPE html>
|
||||
<html lang="de">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<title>""" + title + """ - Fragen</title>
|
||||
<style>
|
||||
@media print {
|
||||
.no-print { display: none; }
|
||||
.page-break { page-break-before: always; }
|
||||
}
|
||||
body {
|
||||
font-family: Arial, sans-serif;
|
||||
max-width: 800px;
|
||||
margin: 40px auto;
|
||||
padding: 20px;
|
||||
line-height: 1.6;
|
||||
}
|
||||
h1 { font-size: 24px; margin-bottom: 8px; }
|
||||
.meta { color: #666; margin-bottom: 24px; }
|
||||
.question-block {
|
||||
margin-bottom: 32px;
|
||||
padding-bottom: 16px;
|
||||
border-bottom: 1px dashed #ccc;
|
||||
}
|
||||
.question-number {
|
||||
font-weight: bold;
|
||||
color: #333;
|
||||
}
|
||||
.question-text {
|
||||
font-size: 16px;
|
||||
margin: 8px 0;
|
||||
}
|
||||
.answer-space {
|
||||
border: 1px solid #ddd;
|
||||
min-height: 60px;
|
||||
margin-top: 12px;
|
||||
background: #fafafa;
|
||||
}
|
||||
.answer-lines {
|
||||
margin-top: 12px;
|
||||
}
|
||||
.answer-line {
|
||||
border-bottom: 1px solid #999;
|
||||
height: 28px;
|
||||
}
|
||||
.answer {
|
||||
margin-top: 8px;
|
||||
padding: 8px;
|
||||
background: #e8f5e9;
|
||||
border-left: 3px solid #4caf50;
|
||||
}
|
||||
.key-terms {
|
||||
font-size: 12px;
|
||||
color: #666;
|
||||
margin-top: 8px;
|
||||
}
|
||||
.key-terms span {
|
||||
background: #fff3e0;
|
||||
padding: 2px 6px;
|
||||
border-radius: 3px;
|
||||
margin-right: 4px;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
""")
|
||||
|
||||
# Header
|
||||
version_text = "Loesungsblatt" if include_answers else "Fragenblatt"
|
||||
html_parts.append(f"<h1>{title} - {version_text}</h1>")
|
||||
meta_parts = []
|
||||
if subject:
|
||||
meta_parts.append(f"Fach: {subject}")
|
||||
if grade:
|
||||
meta_parts.append(f"Klasse: {grade}")
|
||||
meta_parts.append(f"Anzahl Fragen: {len(items)}")
|
||||
html_parts.append(f"<div class='meta'>{' | '.join(meta_parts)}</div>")
|
||||
|
||||
# Fragen
|
||||
for idx, item in enumerate(items, 1):
|
||||
html_parts.append("<div class='question-block'>")
|
||||
html_parts.append(f"<div class='question-number'>Frage {idx}</div>")
|
||||
html_parts.append(f"<div class='question-text'>{item.get('question', '')}</div>")
|
||||
|
||||
if include_answers:
|
||||
# Loesungsblatt: Antwort anzeigen
|
||||
html_parts.append(f"<div class='answer'><strong>Antwort:</strong> {item.get('answer', '')}</div>")
|
||||
# Schluesselbegriffe
|
||||
key_terms = item.get("key_terms", [])
|
||||
if key_terms:
|
||||
terms_html = " ".join([f"<span>{term}</span>" for term in key_terms])
|
||||
html_parts.append(f"<div class='key-terms'>Wichtige Begriffe: {terms_html}</div>")
|
||||
else:
|
||||
# Fragenblatt: Antwortlinien
|
||||
html_parts.append("<div class='answer-lines'>")
|
||||
for _ in range(3):
|
||||
html_parts.append("<div class='answer-line'></div>")
|
||||
html_parts.append("</div>")
|
||||
|
||||
html_parts.append("</div>")
|
||||
|
||||
html_parts.append("</body></html>")
|
||||
|
||||
# Speichern
|
||||
suffix = "_qa_solutions.html" if include_answers else "_qa_print.html"
|
||||
out_name = qa_path.stem.replace("_qa", "") + suffix
|
||||
out_path = BEREINIGT_DIR / out_name
|
||||
out_path.write_text("\n".join(html_parts), encoding="utf-8")
|
||||
|
||||
logger.info(f"Print-Version gespeichert: {out_path.name}")
|
||||
return out_path
|
||||
@@ -0,0 +1,294 @@
|
||||
"""
|
||||
AI Processing - Print Version Generator: Worksheet.
|
||||
|
||||
Generates print-optimized HTML for general worksheets from analysis data.
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
import json
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def generate_print_version_worksheet(analysis_path: Path) -> str:
|
||||
"""
|
||||
Generiert eine druckoptimierte HTML-Version des Arbeitsblatts.
|
||||
|
||||
Eigenschaften:
|
||||
- Grosse, gut lesbare Schrift (16pt)
|
||||
- Schwarz-weiss / Graustufen-tauglich
|
||||
- Klare Struktur fuer Druck
|
||||
- Keine interaktiven Elemente
|
||||
|
||||
Args:
|
||||
analysis_path: Pfad zur *_analyse.json Datei
|
||||
|
||||
Returns:
|
||||
HTML-String zum direkten Ausliefern
|
||||
"""
|
||||
if not analysis_path.exists():
|
||||
raise FileNotFoundError(f"Analysedatei nicht gefunden: {analysis_path}")
|
||||
|
||||
try:
|
||||
data = json.loads(analysis_path.read_text(encoding="utf-8"))
|
||||
except json.JSONDecodeError as e:
|
||||
raise RuntimeError(f"Analyse-Datei enthaelt kein gueltiges JSON: {analysis_path}\n{e}") from e
|
||||
|
||||
title = data.get("title") or "Arbeitsblatt"
|
||||
subject = data.get("subject") or ""
|
||||
grade_level = data.get("grade_level") or ""
|
||||
instructions = data.get("instructions") or ""
|
||||
tasks = data.get("tasks", []) or []
|
||||
canonical_text = data.get("canonical_text") or ""
|
||||
printed_blocks = data.get("printed_blocks") or []
|
||||
|
||||
html_parts = []
|
||||
html_parts.append(_build_html_head(title))
|
||||
|
||||
# Titel
|
||||
html_parts.append(f"<h1>{title}</h1>")
|
||||
|
||||
# Meta-Informationen
|
||||
meta_parts = []
|
||||
if subject:
|
||||
meta_parts.append(f"<span><strong>Fach:</strong> {subject}</span>")
|
||||
if grade_level:
|
||||
meta_parts.append(f"<span><strong>Klasse:</strong> {grade_level}</span>")
|
||||
if meta_parts:
|
||||
html_parts.append(f"<div class='meta'>{''.join(meta_parts)}</div>")
|
||||
|
||||
# Arbeitsanweisung
|
||||
if instructions:
|
||||
html_parts.append("<div class='instructions'>")
|
||||
html_parts.append("<div class='instructions-label'>Arbeitsanweisung:</div>")
|
||||
html_parts.append(f"<div>{instructions}</div>")
|
||||
html_parts.append("</div>")
|
||||
|
||||
# Haupttext / gedruckte Bloecke
|
||||
_build_text_section(html_parts, printed_blocks, canonical_text)
|
||||
|
||||
# Aufgaben
|
||||
_build_tasks_section(html_parts, tasks)
|
||||
|
||||
# Fusszeile
|
||||
html_parts.append("<div class='footer'>")
|
||||
html_parts.append("Dieses Arbeitsblatt wurde automatisch aus einem Scan rekonstruiert.")
|
||||
html_parts.append("</div>")
|
||||
|
||||
html_parts.append("</body></html>")
|
||||
|
||||
return "\n".join(html_parts)
|
||||
|
||||
|
||||
def _build_html_head(title: str) -> str:
|
||||
"""Build the HTML head with print-optimized styles."""
|
||||
return """<!DOCTYPE html>
|
||||
<html lang="de">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<title>""" + title + """</title>
|
||||
<style>
|
||||
@page {
|
||||
size: A4;
|
||||
margin: 20mm;
|
||||
}
|
||||
@media print {
|
||||
body {
|
||||
font-size: 14pt !important;
|
||||
-webkit-print-color-adjust: exact;
|
||||
print-color-adjust: exact;
|
||||
}
|
||||
.no-print { display: none !important; }
|
||||
.page-break { page-break-before: always; }
|
||||
}
|
||||
* { box-sizing: border-box; }
|
||||
body {
|
||||
font-family: Arial, "Helvetica Neue", sans-serif;
|
||||
max-width: 800px;
|
||||
margin: 0 auto;
|
||||
padding: 30px;
|
||||
line-height: 1.7;
|
||||
font-size: 16px;
|
||||
color: #000;
|
||||
background: #fff;
|
||||
}
|
||||
h1 {
|
||||
font-size: 28px;
|
||||
margin: 0 0 8px 0;
|
||||
padding-bottom: 8px;
|
||||
border-bottom: 3px solid #000;
|
||||
}
|
||||
h2 {
|
||||
font-size: 20px;
|
||||
margin: 28px 0 12px 0;
|
||||
padding-bottom: 4px;
|
||||
border-bottom: 1px solid #666;
|
||||
}
|
||||
.meta {
|
||||
font-size: 14px;
|
||||
color: #333;
|
||||
margin-bottom: 20px;
|
||||
padding: 8px 0;
|
||||
}
|
||||
.meta span {
|
||||
margin-right: 20px;
|
||||
}
|
||||
.instructions {
|
||||
margin: 20px 0;
|
||||
padding: 16px;
|
||||
border: 2px solid #333;
|
||||
background: #f5f5f5;
|
||||
font-size: 15px;
|
||||
}
|
||||
.instructions-label {
|
||||
font-weight: bold;
|
||||
margin-bottom: 8px;
|
||||
}
|
||||
.text-section {
|
||||
margin: 24px 0;
|
||||
}
|
||||
.text-block {
|
||||
margin-bottom: 16px;
|
||||
text-align: justify;
|
||||
}
|
||||
.text-block-title {
|
||||
font-weight: bold;
|
||||
font-size: 17px;
|
||||
margin-bottom: 8px;
|
||||
}
|
||||
.task-section {
|
||||
margin-top: 32px;
|
||||
}
|
||||
.task {
|
||||
margin-bottom: 24px;
|
||||
padding: 16px;
|
||||
border: 1px solid #999;
|
||||
background: #fafafa;
|
||||
}
|
||||
.task-header {
|
||||
font-weight: bold;
|
||||
font-size: 16px;
|
||||
margin-bottom: 12px;
|
||||
padding-bottom: 8px;
|
||||
border-bottom: 1px dashed #666;
|
||||
}
|
||||
.task-content {
|
||||
font-size: 15px;
|
||||
}
|
||||
.gap-line {
|
||||
display: inline-block;
|
||||
border-bottom: 2px solid #000;
|
||||
min-width: 100px;
|
||||
margin: 0 6px;
|
||||
}
|
||||
.answer-lines {
|
||||
margin-top: 16px;
|
||||
}
|
||||
.answer-line {
|
||||
border-bottom: 1px solid #333;
|
||||
height: 36px;
|
||||
margin-bottom: 4px;
|
||||
}
|
||||
.footer {
|
||||
margin-top: 40px;
|
||||
padding-top: 16px;
|
||||
border-top: 1px solid #ccc;
|
||||
font-size: 11px;
|
||||
color: #666;
|
||||
text-align: center;
|
||||
}
|
||||
/* Print Button - versteckt beim Drucken */
|
||||
.print-button {
|
||||
position: fixed;
|
||||
top: 20px;
|
||||
right: 20px;
|
||||
padding: 12px 24px;
|
||||
background: #333;
|
||||
color: #fff;
|
||||
border: none;
|
||||
border-radius: 6px;
|
||||
cursor: pointer;
|
||||
font-size: 14px;
|
||||
}
|
||||
.print-button:hover {
|
||||
background: #555;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<button class="print-button no-print" onclick="window.print()">Drucken</button>
|
||||
"""
|
||||
|
||||
|
||||
def _build_text_section(html_parts: list, printed_blocks: list, canonical_text: str):
|
||||
"""Build the text section from printed blocks or canonical text."""
|
||||
if printed_blocks:
|
||||
html_parts.append("<section class='text-section'>")
|
||||
for block in printed_blocks:
|
||||
role = (block.get("role") or "body").lower()
|
||||
text = (block.get("text") or "").strip()
|
||||
if not text:
|
||||
continue
|
||||
if role == "title":
|
||||
html_parts.append(f"<div class='text-block'><div class='text-block-title'>{text}</div></div>")
|
||||
else:
|
||||
html_parts.append(f"<div class='text-block'>{text}</div>")
|
||||
html_parts.append("</section>")
|
||||
elif canonical_text:
|
||||
html_parts.append("<section class='text-section'>")
|
||||
paragraphs = [
|
||||
p.strip()
|
||||
for p in canonical_text.replace("\r\n", "\n").split("\n\n")
|
||||
if p.strip()
|
||||
]
|
||||
for p in paragraphs:
|
||||
html_parts.append(f"<div class='text-block'>{p}</div>")
|
||||
html_parts.append("</section>")
|
||||
|
||||
|
||||
def _build_tasks_section(html_parts: list, tasks: list):
|
||||
"""Build the tasks section."""
|
||||
if not tasks:
|
||||
return
|
||||
|
||||
html_parts.append("<section class='task-section'>")
|
||||
html_parts.append("<h2>Aufgaben</h2>")
|
||||
|
||||
type_labels = {
|
||||
"fill_in_blank": "Lueckentext",
|
||||
"multiple_choice": "Multiple Choice",
|
||||
"free_text": "Freitext",
|
||||
"matching": "Zuordnung",
|
||||
"labeling": "Beschriftung",
|
||||
"calculation": "Rechnung",
|
||||
"other": "Aufgabe"
|
||||
}
|
||||
|
||||
for idx, task in enumerate(tasks, start=1):
|
||||
t_type = task.get("type") or "Aufgabe"
|
||||
desc = task.get("description") or ""
|
||||
text_with_gaps = task.get("text_with_gaps")
|
||||
|
||||
html_parts.append("<div class='task'>")
|
||||
|
||||
type_label = type_labels.get(t_type, t_type)
|
||||
html_parts.append(f"<div class='task-header'>Aufgabe {idx}: {type_label}</div>")
|
||||
|
||||
if desc:
|
||||
html_parts.append(f"<div class='task-content'>{desc}</div>")
|
||||
|
||||
if text_with_gaps:
|
||||
rendered = text_with_gaps.replace("___", "<span class='gap-line'> </span>")
|
||||
html_parts.append(f"<div class='task-content' style='margin-top:12px;'>{rendered}</div>")
|
||||
|
||||
# Antwortlinien fuer Freitext-Aufgaben
|
||||
if t_type in ["free_text", "other"] or (not text_with_gaps and not desc):
|
||||
html_parts.append("<div class='answer-lines'>")
|
||||
for _ in range(3):
|
||||
html_parts.append("<div class='answer-line'></div>")
|
||||
html_parts.append("</div>")
|
||||
|
||||
html_parts.append("</div>")
|
||||
|
||||
html_parts.append("</section>")
|
||||
@@ -1,726 +1,25 @@
|
||||
"""
|
||||
Classroom API - Context Routes
|
||||
Classroom API - Context Routes — Barrel Re-export.
|
||||
|
||||
Split into submodules:
|
||||
- context_core.py — Teacher context, onboarding endpoints
|
||||
- context_events.py — Events & routines CRUD
|
||||
- context_static.py — Static data, suggestions, sidebar, school year path
|
||||
|
||||
School year context, events, routines, and suggestions endpoints (Phase 8).
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, Optional
|
||||
from datetime import datetime
|
||||
import logging
|
||||
from fastapi import APIRouter
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Query, Depends
|
||||
|
||||
from classroom_engine import (
|
||||
FEDERAL_STATES,
|
||||
SCHOOL_TYPES,
|
||||
MacroPhaseEnum,
|
||||
)
|
||||
|
||||
from ..models import (
|
||||
TeacherContextResponse,
|
||||
SchoolInfo,
|
||||
SchoolYearInfo,
|
||||
MacroPhaseInfo,
|
||||
CoreCounts,
|
||||
ContextFlags,
|
||||
UpdateContextRequest,
|
||||
CreateEventRequest,
|
||||
EventResponse,
|
||||
CreateRoutineRequest,
|
||||
RoutineResponse,
|
||||
)
|
||||
from ..services.persistence import (
|
||||
init_db_if_needed,
|
||||
DB_ENABLED,
|
||||
SessionLocal,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
from .context_core import router as _core_router
|
||||
from .context_events import router as _events_router
|
||||
from .context_static import router as _static_router
|
||||
|
||||
# Combine all sub-routers into a single router for backwards compatibility.
|
||||
# The consumer imports `from .routes.context import router as context_router`.
|
||||
router = APIRouter(tags=["Context"])
|
||||
router.include_router(_core_router)
|
||||
router.include_router(_events_router)
|
||||
router.include_router(_static_router)
|
||||
|
||||
|
||||
def get_db():
|
||||
"""Database session dependency."""
|
||||
if DB_ENABLED and SessionLocal:
|
||||
db = SessionLocal()
|
||||
try:
|
||||
yield db
|
||||
finally:
|
||||
db.close()
|
||||
else:
|
||||
yield None
|
||||
|
||||
|
||||
def _get_macro_phase_label(phase) -> str:
|
||||
"""Gibt den Anzeigenamen einer Makro-Phase zurueck."""
|
||||
labels = {
|
||||
"onboarding": "Einrichtung",
|
||||
"schuljahresstart": "Schuljahresstart",
|
||||
"unterrichtsaufbau": "Unterrichtsaufbau",
|
||||
"leistungsphase_1": "Leistungsphase 1",
|
||||
"halbjahresabschluss": "Halbjahresabschluss",
|
||||
"leistungsphase_2": "Leistungsphase 2",
|
||||
"jahresabschluss": "Jahresabschluss",
|
||||
}
|
||||
phase_value = phase.value if hasattr(phase, 'value') else str(phase)
|
||||
return labels.get(phase_value, phase_value)
|
||||
|
||||
|
||||
# === Context Endpoints ===
|
||||
|
||||
@router.get("/v1/context", response_model=TeacherContextResponse)
|
||||
async def get_teacher_context(
|
||||
teacher_id: str = Query(..., description="Teacher ID"),
|
||||
db=Depends(get_db)
|
||||
):
|
||||
"""
|
||||
Liefert den aktuellen Makro-Kontext eines Lehrers.
|
||||
|
||||
Der Kontext beinhaltet:
|
||||
- Schul-Informationen (Bundesland, Schulart)
|
||||
- Schuljahr-Daten (aktuelles Jahr, Woche)
|
||||
- Makro-Phase (ONBOARDING bis JAHRESABSCHLUSS)
|
||||
- Zaehler (Klassen, geplante Klausuren, etc.)
|
||||
- Status-Flags (Onboarding abgeschlossen, etc.)
|
||||
"""
|
||||
if DB_ENABLED and db:
|
||||
try:
|
||||
from classroom_engine.repository import TeacherContextRepository, SchoolyearEventRepository
|
||||
repo = TeacherContextRepository(db)
|
||||
context = repo.get_or_create(teacher_id)
|
||||
|
||||
# Zaehler berechnen
|
||||
event_repo = SchoolyearEventRepository(db)
|
||||
upcoming_exams = event_repo.get_upcoming(teacher_id, days=30)
|
||||
exams_count = len([e for e in upcoming_exams if e.event_type.value == "exam"])
|
||||
|
||||
return TeacherContextResponse(
|
||||
schema_version="1.0",
|
||||
teacher_id=teacher_id,
|
||||
school=SchoolInfo(
|
||||
federal_state=context.federal_state or "BY",
|
||||
federal_state_name=FEDERAL_STATES.get(context.federal_state, ""),
|
||||
school_type=context.school_type or "gymnasium",
|
||||
school_type_name=SCHOOL_TYPES.get(context.school_type, ""),
|
||||
),
|
||||
school_year=SchoolYearInfo(
|
||||
id=context.schoolyear or "2024-2025",
|
||||
start=context.schoolyear_start.isoformat() if context.schoolyear_start else None,
|
||||
current_week=context.current_week or 1,
|
||||
),
|
||||
macro_phase=MacroPhaseInfo(
|
||||
id=context.macro_phase.value,
|
||||
label=_get_macro_phase_label(context.macro_phase),
|
||||
confidence=1.0,
|
||||
),
|
||||
core_counts=CoreCounts(
|
||||
classes=1 if context.has_classes else 0,
|
||||
exams_scheduled=exams_count,
|
||||
corrections_pending=0,
|
||||
),
|
||||
flags=ContextFlags(
|
||||
onboarding_completed=context.onboarding_completed,
|
||||
has_classes=context.has_classes,
|
||||
has_schedule=context.has_schedule,
|
||||
is_exam_period=context.is_exam_period,
|
||||
is_before_holidays=context.is_before_holidays,
|
||||
),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get teacher context: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"Fehler beim Laden des Kontexts: {e}")
|
||||
|
||||
# Fallback ohne DB
|
||||
return TeacherContextResponse(
|
||||
schema_version="1.0",
|
||||
teacher_id=teacher_id,
|
||||
school=SchoolInfo(
|
||||
federal_state="BY",
|
||||
federal_state_name="Bayern",
|
||||
school_type="gymnasium",
|
||||
school_type_name="Gymnasium",
|
||||
),
|
||||
school_year=SchoolYearInfo(
|
||||
id="2024-2025",
|
||||
start=None,
|
||||
current_week=1,
|
||||
),
|
||||
macro_phase=MacroPhaseInfo(
|
||||
id="onboarding",
|
||||
label="Einrichtung",
|
||||
confidence=1.0,
|
||||
),
|
||||
core_counts=CoreCounts(),
|
||||
flags=ContextFlags(),
|
||||
)
|
||||
|
||||
|
||||
@router.put("/v1/context", response_model=TeacherContextResponse)
|
||||
async def update_teacher_context(
|
||||
teacher_id: str,
|
||||
request: UpdateContextRequest,
|
||||
db=Depends(get_db)
|
||||
):
|
||||
"""
|
||||
Aktualisiert den Kontext eines Lehrers.
|
||||
"""
|
||||
if not DB_ENABLED or not db:
|
||||
raise HTTPException(status_code=503, detail="Datenbank nicht verfuegbar")
|
||||
|
||||
try:
|
||||
from classroom_engine.repository import TeacherContextRepository
|
||||
repo = TeacherContextRepository(db)
|
||||
|
||||
# Validierung
|
||||
if request.federal_state and request.federal_state not in FEDERAL_STATES:
|
||||
raise HTTPException(status_code=400, detail=f"Ungueltiges Bundesland: {request.federal_state}")
|
||||
if request.school_type and request.school_type not in SCHOOL_TYPES:
|
||||
raise HTTPException(status_code=400, detail=f"Ungueltige Schulart: {request.school_type}")
|
||||
|
||||
# Parse datetime if provided
|
||||
schoolyear_start = None
|
||||
if request.schoolyear_start:
|
||||
schoolyear_start = datetime.fromisoformat(request.schoolyear_start.replace('Z', '+00:00'))
|
||||
|
||||
repo.update_context(
|
||||
teacher_id=teacher_id,
|
||||
federal_state=request.federal_state,
|
||||
school_type=request.school_type,
|
||||
schoolyear=request.schoolyear,
|
||||
schoolyear_start=schoolyear_start,
|
||||
macro_phase=request.macro_phase,
|
||||
current_week=request.current_week,
|
||||
)
|
||||
|
||||
return await get_teacher_context(teacher_id, db)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to update teacher context: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"Fehler beim Aktualisieren: {e}")
|
||||
|
||||
|
||||
@router.post("/v1/context/complete-onboarding")
|
||||
async def complete_onboarding(
|
||||
teacher_id: str = Query(...),
|
||||
db=Depends(get_db)
|
||||
):
|
||||
"""Markiert das Onboarding als abgeschlossen."""
|
||||
if not DB_ENABLED or not db:
|
||||
return {"success": True, "macro_phase": "schuljahresstart", "note": "DB not available"}
|
||||
|
||||
try:
|
||||
from classroom_engine.repository import TeacherContextRepository
|
||||
repo = TeacherContextRepository(db)
|
||||
context = repo.complete_onboarding(teacher_id)
|
||||
return {
|
||||
"success": True,
|
||||
"macro_phase": context.macro_phase.value,
|
||||
"teacher_id": teacher_id,
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to complete onboarding: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"Fehler: {e}")
|
||||
|
||||
|
||||
@router.post("/v1/context/reset-onboarding")
|
||||
async def reset_onboarding(
|
||||
teacher_id: str = Query(...),
|
||||
db=Depends(get_db)
|
||||
):
|
||||
"""Setzt das Onboarding zurueck (fuer Tests)."""
|
||||
if not DB_ENABLED or not db:
|
||||
return {"success": True, "macro_phase": "onboarding", "note": "DB not available"}
|
||||
|
||||
try:
|
||||
from classroom_engine.repository import TeacherContextRepository
|
||||
repo = TeacherContextRepository(db)
|
||||
context = repo.get_or_create(teacher_id)
|
||||
context.onboarding_completed = False
|
||||
context.macro_phase = MacroPhaseEnum.ONBOARDING
|
||||
db.commit()
|
||||
db.refresh(context)
|
||||
return {
|
||||
"success": True,
|
||||
"macro_phase": "onboarding",
|
||||
"teacher_id": teacher_id,
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to reset onboarding: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"Fehler: {e}")
|
||||
|
||||
|
||||
# === Events Endpoints ===
|
||||
|
||||
@router.get("/v1/events")
|
||||
async def get_events(
|
||||
teacher_id: str = Query(...),
|
||||
status: Optional[str] = None,
|
||||
event_type: Optional[str] = None,
|
||||
limit: int = 50,
|
||||
db=Depends(get_db)
|
||||
):
|
||||
"""Holt Events eines Lehrers."""
|
||||
if not DB_ENABLED or not db:
|
||||
return {"events": [], "count": 0}
|
||||
|
||||
try:
|
||||
from classroom_engine.repository import SchoolyearEventRepository
|
||||
repo = SchoolyearEventRepository(db)
|
||||
events = repo.get_by_teacher(teacher_id, status=status, event_type=event_type, limit=limit)
|
||||
return {
|
||||
"events": [repo.to_dict(e) for e in events],
|
||||
"count": len(events),
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get events: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"Fehler: {e}")
|
||||
|
||||
|
||||
@router.get("/v1/events/upcoming")
|
||||
async def get_upcoming_events(
|
||||
teacher_id: str = Query(...),
|
||||
days: int = 30,
|
||||
limit: int = 10,
|
||||
db=Depends(get_db)
|
||||
):
|
||||
"""Holt anstehende Events der naechsten X Tage."""
|
||||
if not DB_ENABLED or not db:
|
||||
return {"events": [], "count": 0}
|
||||
|
||||
try:
|
||||
from classroom_engine.repository import SchoolyearEventRepository
|
||||
repo = SchoolyearEventRepository(db)
|
||||
events = repo.get_upcoming(teacher_id, days=days, limit=limit)
|
||||
return {
|
||||
"events": [repo.to_dict(e) for e in events],
|
||||
"count": len(events),
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get upcoming events: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"Fehler: {e}")
|
||||
|
||||
|
||||
@router.post("/v1/events", response_model=EventResponse)
|
||||
async def create_event(
|
||||
teacher_id: str,
|
||||
request: CreateEventRequest,
|
||||
db=Depends(get_db)
|
||||
):
|
||||
"""Erstellt ein neues Schuljahr-Event."""
|
||||
if not DB_ENABLED or not db:
|
||||
raise HTTPException(status_code=503, detail="Datenbank nicht verfuegbar")
|
||||
|
||||
try:
|
||||
from classroom_engine.repository import SchoolyearEventRepository
|
||||
repo = SchoolyearEventRepository(db)
|
||||
start_date = datetime.fromisoformat(request.start_date.replace('Z', '+00:00'))
|
||||
end_date = None
|
||||
if request.end_date:
|
||||
end_date = datetime.fromisoformat(request.end_date.replace('Z', '+00:00'))
|
||||
|
||||
event = repo.create(
|
||||
teacher_id=teacher_id,
|
||||
title=request.title,
|
||||
event_type=request.event_type,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
class_id=request.class_id,
|
||||
subject=request.subject,
|
||||
description=request.description,
|
||||
needs_preparation=request.needs_preparation,
|
||||
reminder_days_before=request.reminder_days_before,
|
||||
)
|
||||
|
||||
return EventResponse(
|
||||
id=event.id,
|
||||
teacher_id=event.teacher_id,
|
||||
event_type=event.event_type.value,
|
||||
title=event.title,
|
||||
description=event.description,
|
||||
start_date=event.start_date.isoformat(),
|
||||
end_date=event.end_date.isoformat() if event.end_date else None,
|
||||
class_id=event.class_id,
|
||||
subject=event.subject,
|
||||
status=event.status.value,
|
||||
needs_preparation=event.needs_preparation,
|
||||
preparation_done=event.preparation_done,
|
||||
reminder_days_before=event.reminder_days_before,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create event: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"Fehler: {e}")
|
||||
|
||||
|
||||
@router.delete("/v1/events/{event_id}")
|
||||
async def delete_event(event_id: str, db=Depends(get_db)):
|
||||
"""Loescht ein Event."""
|
||||
if not DB_ENABLED or not db:
|
||||
raise HTTPException(status_code=503, detail="Datenbank nicht verfuegbar")
|
||||
|
||||
try:
|
||||
from classroom_engine.repository import SchoolyearEventRepository
|
||||
repo = SchoolyearEventRepository(db)
|
||||
if repo.delete(event_id):
|
||||
return {"success": True, "deleted_id": event_id}
|
||||
raise HTTPException(status_code=404, detail="Event nicht gefunden")
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete event: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"Fehler: {e}")
|
||||
|
||||
|
||||
# === Routines Endpoints ===
|
||||
|
||||
@router.get("/v1/routines")
|
||||
async def get_routines(
|
||||
teacher_id: str = Query(...),
|
||||
is_active: bool = True,
|
||||
routine_type: Optional[str] = None,
|
||||
db=Depends(get_db)
|
||||
):
|
||||
"""Holt Routinen eines Lehrers."""
|
||||
if not DB_ENABLED or not db:
|
||||
return {"routines": [], "count": 0}
|
||||
|
||||
try:
|
||||
from classroom_engine.repository import RecurringRoutineRepository
|
||||
repo = RecurringRoutineRepository(db)
|
||||
routines = repo.get_by_teacher(teacher_id, is_active=is_active, routine_type=routine_type)
|
||||
return {
|
||||
"routines": [repo.to_dict(r) for r in routines],
|
||||
"count": len(routines),
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get routines: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"Fehler: {e}")
|
||||
|
||||
|
||||
@router.get("/v1/routines/today")
|
||||
async def get_today_routines(teacher_id: str = Query(...), db=Depends(get_db)):
|
||||
"""Holt Routinen die heute stattfinden."""
|
||||
if not DB_ENABLED or not db:
|
||||
return {"routines": [], "count": 0}
|
||||
|
||||
try:
|
||||
from classroom_engine.repository import RecurringRoutineRepository
|
||||
repo = RecurringRoutineRepository(db)
|
||||
routines = repo.get_today(teacher_id)
|
||||
return {
|
||||
"routines": [repo.to_dict(r) for r in routines],
|
||||
"count": len(routines),
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get today's routines: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"Fehler: {e}")
|
||||
|
||||
|
||||
@router.post("/v1/routines", response_model=RoutineResponse)
|
||||
async def create_routine(
|
||||
teacher_id: str,
|
||||
request: CreateRoutineRequest,
|
||||
db=Depends(get_db)
|
||||
):
|
||||
"""Erstellt eine neue wiederkehrende Routine."""
|
||||
if not DB_ENABLED or not db:
|
||||
raise HTTPException(status_code=503, detail="Datenbank nicht verfuegbar")
|
||||
|
||||
try:
|
||||
from classroom_engine.repository import RecurringRoutineRepository
|
||||
repo = RecurringRoutineRepository(db)
|
||||
routine = repo.create(
|
||||
teacher_id=teacher_id,
|
||||
title=request.title,
|
||||
routine_type=request.routine_type,
|
||||
recurrence_pattern=request.recurrence_pattern,
|
||||
day_of_week=request.day_of_week,
|
||||
day_of_month=request.day_of_month,
|
||||
time_of_day=request.time_of_day,
|
||||
duration_minutes=request.duration_minutes,
|
||||
description=request.description,
|
||||
)
|
||||
|
||||
return RoutineResponse(
|
||||
id=routine.id,
|
||||
teacher_id=routine.teacher_id,
|
||||
routine_type=routine.routine_type.value,
|
||||
title=routine.title,
|
||||
description=routine.description,
|
||||
recurrence_pattern=routine.recurrence_pattern.value,
|
||||
day_of_week=routine.day_of_week,
|
||||
day_of_month=routine.day_of_month,
|
||||
time_of_day=routine.time_of_day.isoformat() if routine.time_of_day else None,
|
||||
duration_minutes=routine.duration_minutes,
|
||||
is_active=routine.is_active,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create routine: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"Fehler: {e}")
|
||||
|
||||
|
||||
@router.delete("/v1/routines/{routine_id}")
|
||||
async def delete_routine(routine_id: str, db=Depends(get_db)):
|
||||
"""Loescht eine Routine."""
|
||||
if not DB_ENABLED or not db:
|
||||
raise HTTPException(status_code=503, detail="Datenbank nicht verfuegbar")
|
||||
|
||||
try:
|
||||
from classroom_engine.repository import RecurringRoutineRepository
|
||||
repo = RecurringRoutineRepository(db)
|
||||
if repo.delete(routine_id):
|
||||
return {"success": True, "deleted_id": routine_id}
|
||||
raise HTTPException(status_code=404, detail="Routine nicht gefunden")
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete routine: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"Fehler: {e}")
|
||||
|
||||
|
||||
# === Static Data Endpoints ===
|
||||
|
||||
@router.get("/v1/federal-states")
|
||||
async def get_federal_states():
|
||||
"""Gibt alle Bundeslaender zurueck."""
|
||||
return {
|
||||
"federal_states": [{"id": k, "name": v} for k, v in FEDERAL_STATES.items()]
|
||||
}
|
||||
|
||||
|
||||
@router.get("/v1/school-types")
|
||||
async def get_school_types():
|
||||
"""Gibt alle Schularten zurueck."""
|
||||
return {
|
||||
"school_types": [{"id": k, "name": v} for k, v in SCHOOL_TYPES.items()]
|
||||
}
|
||||
|
||||
|
||||
@router.get("/v1/macro-phases")
|
||||
async def get_macro_phases():
|
||||
"""Gibt alle Makro-Phasen mit Beschreibungen zurueck."""
|
||||
phases = [
|
||||
{"id": "onboarding", "label": "Einrichtung", "description": "Ersteinrichtung (Klassen, Stundenplan)", "order": 1},
|
||||
{"id": "schuljahresstart", "label": "Schuljahresstart", "description": "Erste 2-3 Wochen des Schuljahres", "order": 2},
|
||||
{"id": "unterrichtsaufbau", "label": "Unterrichtsaufbau", "description": "Routinen etablieren, erste Bewertungen", "order": 3},
|
||||
{"id": "leistungsphase_1", "label": "Leistungsphase 1", "description": "Erste Klassenarbeiten und Klausuren", "order": 4},
|
||||
{"id": "halbjahresabschluss", "label": "Halbjahresabschluss", "description": "Notenschluss, Zeugnisse, Konferenzen", "order": 5},
|
||||
{"id": "leistungsphase_2", "label": "Leistungsphase 2", "description": "Zweites Halbjahr, Pruefungsvorbereitung", "order": 6},
|
||||
{"id": "jahresabschluss", "label": "Jahresabschluss", "description": "Finale Noten, Versetzung, Schuljahresende", "order": 7},
|
||||
]
|
||||
return {"macro_phases": phases}
|
||||
|
||||
|
||||
@router.get("/v1/event-types")
|
||||
async def get_event_types():
|
||||
"""Gibt alle Event-Typen zurueck."""
|
||||
types = [
|
||||
{"id": "exam", "label": "Klassenarbeit/Klausur"},
|
||||
{"id": "parent_evening", "label": "Elternabend"},
|
||||
{"id": "trip", "label": "Klassenfahrt/Ausflug"},
|
||||
{"id": "project", "label": "Projektwoche"},
|
||||
{"id": "internship", "label": "Praktikum"},
|
||||
{"id": "presentation", "label": "Referate/Praesentationen"},
|
||||
{"id": "sports_day", "label": "Sporttag"},
|
||||
{"id": "school_festival", "label": "Schulfest"},
|
||||
{"id": "parent_consultation", "label": "Elternsprechtag"},
|
||||
{"id": "grade_deadline", "label": "Notenschluss"},
|
||||
{"id": "report_cards", "label": "Zeugnisausgabe"},
|
||||
{"id": "holiday_start", "label": "Ferienbeginn"},
|
||||
{"id": "holiday_end", "label": "Ferienende"},
|
||||
{"id": "other", "label": "Sonstiges"},
|
||||
]
|
||||
return {"event_types": types}
|
||||
|
||||
|
||||
@router.get("/v1/routine-types")
|
||||
async def get_routine_types():
|
||||
"""Gibt alle Routine-Typen zurueck."""
|
||||
types = [
|
||||
{"id": "teacher_conference", "label": "Lehrerkonferenz"},
|
||||
{"id": "subject_conference", "label": "Fachkonferenz"},
|
||||
{"id": "office_hours", "label": "Sprechstunde"},
|
||||
{"id": "team_meeting", "label": "Teamsitzung"},
|
||||
{"id": "supervision", "label": "Pausenaufsicht"},
|
||||
{"id": "correction_time", "label": "Korrekturzeit"},
|
||||
{"id": "prep_time", "label": "Vorbereitungszeit"},
|
||||
{"id": "other", "label": "Sonstiges"},
|
||||
]
|
||||
return {"routine_types": types}
|
||||
|
||||
|
||||
# === Suggestions & Sidebar ===
|
||||
|
||||
@router.get("/v1/suggestions")
|
||||
async def get_suggestions(
|
||||
teacher_id: str = Query(...),
|
||||
limit: int = Query(5, ge=1, le=20),
|
||||
db=Depends(get_db)
|
||||
):
|
||||
"""Generiert kontextbasierte Vorschlaege fuer einen Lehrer."""
|
||||
if DB_ENABLED and db:
|
||||
try:
|
||||
from classroom_engine.suggestions import SuggestionGenerator
|
||||
generator = SuggestionGenerator(db)
|
||||
result = generator.generate(teacher_id, limit=limit)
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to generate suggestions: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"Fehler: {e}")
|
||||
|
||||
return {
|
||||
"active_contexts": [],
|
||||
"suggestions": [],
|
||||
"signals_summary": {
|
||||
"macro_phase": "onboarding",
|
||||
"current_week": 1,
|
||||
"has_classes": False,
|
||||
"exams_soon": 0,
|
||||
"routines_today": 0,
|
||||
},
|
||||
"total_suggestions": 0,
|
||||
}
|
||||
|
||||
|
||||
@router.get("/v1/sidebar")
|
||||
async def get_sidebar(
|
||||
teacher_id: str = Query(...),
|
||||
mode: str = Query("companion"),
|
||||
db=Depends(get_db)
|
||||
):
|
||||
"""Generiert das dynamische Sidebar-Model."""
|
||||
if mode == "companion":
|
||||
now_relevant = []
|
||||
if DB_ENABLED and db:
|
||||
try:
|
||||
from classroom_engine.suggestions import SuggestionGenerator
|
||||
generator = SuggestionGenerator(db)
|
||||
result = generator.generate(teacher_id, limit=5)
|
||||
now_relevant = [
|
||||
{
|
||||
"id": s["id"],
|
||||
"label": s["title"],
|
||||
"state": "recommended" if s["priority"] > 70 else "default",
|
||||
"badge": s.get("badge"),
|
||||
"icon": s.get("icon", "lightbulb"),
|
||||
"action_url": s.get("action_url"),
|
||||
}
|
||||
for s in result.get("suggestions", [])
|
||||
]
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get suggestions for sidebar: {e}")
|
||||
|
||||
return {
|
||||
"mode": "companion",
|
||||
"sections": [
|
||||
{"id": "SEARCH", "type": "search_bar", "placeholder": "Suchen..."},
|
||||
{
|
||||
"id": "NOW_RELEVANT",
|
||||
"type": "list",
|
||||
"title": "Jetzt relevant",
|
||||
"items": now_relevant if now_relevant else [
|
||||
{"id": "no_suggestions", "label": "Keine Vorschlaege", "state": "default", "icon": "check_circle"}
|
||||
],
|
||||
},
|
||||
{
|
||||
"id": "ALL_MODULES",
|
||||
"type": "folder",
|
||||
"label": "Alle Module",
|
||||
"icon": "folder",
|
||||
"collapsed": True,
|
||||
"items": [
|
||||
{"id": "lesson", "label": "Stundenmodus", "icon": "timer"},
|
||||
{"id": "classes", "label": "Klassen", "icon": "groups"},
|
||||
{"id": "exams", "label": "Klausuren", "icon": "quiz"},
|
||||
{"id": "grades", "label": "Noten", "icon": "calculate"},
|
||||
{"id": "calendar", "label": "Kalender", "icon": "calendar_month"},
|
||||
{"id": "materials", "label": "Materialien", "icon": "folder_open"},
|
||||
],
|
||||
},
|
||||
{
|
||||
"id": "QUICK_ACTIONS",
|
||||
"type": "actions",
|
||||
"title": "Kurzaktionen",
|
||||
"items": [
|
||||
{"id": "scan", "label": "Scan hochladen", "icon": "upload_file"},
|
||||
{"id": "note", "label": "Notiz erstellen", "icon": "note_add"},
|
||||
],
|
||||
},
|
||||
],
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"mode": "classic",
|
||||
"sections": [
|
||||
{
|
||||
"id": "NAVIGATION",
|
||||
"type": "tree",
|
||||
"items": [
|
||||
{"id": "dashboard", "label": "Dashboard", "icon": "dashboard", "url": "/dashboard"},
|
||||
{"id": "lesson", "label": "Stundenmodus", "icon": "timer", "url": "/lesson"},
|
||||
{"id": "classes", "label": "Klassen", "icon": "groups", "url": "/classes"},
|
||||
{"id": "exams", "label": "Klausuren", "icon": "quiz", "url": "/exams"},
|
||||
{"id": "grades", "label": "Noten", "icon": "calculate", "url": "/grades"},
|
||||
{"id": "calendar", "label": "Kalender", "icon": "calendar_month", "url": "/calendar"},
|
||||
{"id": "materials", "label": "Materialien", "icon": "folder_open", "url": "/materials"},
|
||||
{"id": "settings", "label": "Einstellungen", "icon": "settings", "url": "/settings"},
|
||||
],
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
@router.get("/v1/path")
|
||||
async def get_schoolyear_path(teacher_id: str = Query(...), db=Depends(get_db)):
|
||||
"""Generiert den Schuljahres-Pfad mit Meilensteinen."""
|
||||
current_phase = "onboarding"
|
||||
if DB_ENABLED and db:
|
||||
try:
|
||||
from classroom_engine.repository import TeacherContextRepository
|
||||
repo = TeacherContextRepository(db)
|
||||
context = repo.get_or_create(teacher_id)
|
||||
current_phase = context.macro_phase.value
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get context for path: {e}")
|
||||
|
||||
phase_order = [
|
||||
"onboarding", "schuljahresstart", "unterrichtsaufbau",
|
||||
"leistungsphase_1", "halbjahresabschluss", "leistungsphase_2", "jahresabschluss",
|
||||
]
|
||||
|
||||
current_index = phase_order.index(current_phase) if current_phase in phase_order else 0
|
||||
|
||||
milestones = [
|
||||
{"id": "MS_START", "label": "Start", "phase": "onboarding", "icon": "flag"},
|
||||
{"id": "MS_SETUP", "label": "Einrichtung", "phase": "schuljahresstart", "icon": "tune"},
|
||||
{"id": "MS_ROUTINE", "label": "Routinen", "phase": "unterrichtsaufbau", "icon": "repeat"},
|
||||
{"id": "MS_EXAM_1", "label": "Klausuren", "phase": "leistungsphase_1", "icon": "quiz"},
|
||||
{"id": "MS_HALFYEAR", "label": "Halbjahr", "phase": "halbjahresabschluss", "icon": "event"},
|
||||
{"id": "MS_EXAM_2", "label": "Pruefungen", "phase": "leistungsphase_2", "icon": "school"},
|
||||
{"id": "MS_END", "label": "Abschluss", "phase": "jahresabschluss", "icon": "celebration"},
|
||||
]
|
||||
|
||||
for milestone in milestones:
|
||||
phase = milestone["phase"]
|
||||
phase_index = phase_order.index(phase) if phase in phase_order else 999
|
||||
if phase_index < current_index:
|
||||
milestone["status"] = "done"
|
||||
elif phase_index == current_index:
|
||||
milestone["status"] = "current"
|
||||
else:
|
||||
milestone["status"] = "upcoming"
|
||||
|
||||
current_milestone_id = next(
|
||||
(m["id"] for m in milestones if m["status"] == "current"),
|
||||
milestones[0]["id"]
|
||||
)
|
||||
|
||||
progress = int((current_index / (len(phase_order) - 1)) * 100) if len(phase_order) > 1 else 0
|
||||
|
||||
return {
|
||||
"milestones": milestones,
|
||||
"current_milestone_id": current_milestone_id,
|
||||
"progress_percent": progress,
|
||||
"current_phase": current_phase,
|
||||
}
|
||||
__all__ = ["router"]
|
||||
|
||||
@@ -0,0 +1,247 @@
|
||||
"""
|
||||
Classroom API - Context Core Routes
|
||||
|
||||
Teacher context, onboarding endpoints (Phase 8).
|
||||
"""
|
||||
|
||||
from typing import Dict, Any
|
||||
from datetime import datetime
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Query, Depends
|
||||
|
||||
from classroom_engine import (
|
||||
FEDERAL_STATES,
|
||||
SCHOOL_TYPES,
|
||||
MacroPhaseEnum,
|
||||
)
|
||||
|
||||
from ..models import (
|
||||
TeacherContextResponse,
|
||||
SchoolInfo,
|
||||
SchoolYearInfo,
|
||||
MacroPhaseInfo,
|
||||
CoreCounts,
|
||||
ContextFlags,
|
||||
UpdateContextRequest,
|
||||
)
|
||||
from ..services.persistence import (
|
||||
init_db_if_needed,
|
||||
DB_ENABLED,
|
||||
SessionLocal,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(tags=["Context"])
|
||||
|
||||
|
||||
def get_db():
|
||||
"""Database session dependency."""
|
||||
if DB_ENABLED and SessionLocal:
|
||||
db = SessionLocal()
|
||||
try:
|
||||
yield db
|
||||
finally:
|
||||
db.close()
|
||||
else:
|
||||
yield None
|
||||
|
||||
|
||||
def _get_macro_phase_label(phase) -> str:
|
||||
"""Gibt den Anzeigenamen einer Makro-Phase zurueck."""
|
||||
labels = {
|
||||
"onboarding": "Einrichtung",
|
||||
"schuljahresstart": "Schuljahresstart",
|
||||
"unterrichtsaufbau": "Unterrichtsaufbau",
|
||||
"leistungsphase_1": "Leistungsphase 1",
|
||||
"halbjahresabschluss": "Halbjahresabschluss",
|
||||
"leistungsphase_2": "Leistungsphase 2",
|
||||
"jahresabschluss": "Jahresabschluss",
|
||||
}
|
||||
phase_value = phase.value if hasattr(phase, 'value') else str(phase)
|
||||
return labels.get(phase_value, phase_value)
|
||||
|
||||
|
||||
# === Context Endpoints ===
|
||||
|
||||
@router.get("/v1/context", response_model=TeacherContextResponse)
|
||||
async def get_teacher_context(
|
||||
teacher_id: str = Query(..., description="Teacher ID"),
|
||||
db=Depends(get_db)
|
||||
):
|
||||
"""
|
||||
Liefert den aktuellen Makro-Kontext eines Lehrers.
|
||||
|
||||
Der Kontext beinhaltet:
|
||||
- Schul-Informationen (Bundesland, Schulart)
|
||||
- Schuljahr-Daten (aktuelles Jahr, Woche)
|
||||
- Makro-Phase (ONBOARDING bis JAHRESABSCHLUSS)
|
||||
- Zaehler (Klassen, geplante Klausuren, etc.)
|
||||
- Status-Flags (Onboarding abgeschlossen, etc.)
|
||||
"""
|
||||
if DB_ENABLED and db:
|
||||
try:
|
||||
from classroom_engine.repository import TeacherContextRepository, SchoolyearEventRepository
|
||||
repo = TeacherContextRepository(db)
|
||||
context = repo.get_or_create(teacher_id)
|
||||
|
||||
# Zaehler berechnen
|
||||
event_repo = SchoolyearEventRepository(db)
|
||||
upcoming_exams = event_repo.get_upcoming(teacher_id, days=30)
|
||||
exams_count = len([e for e in upcoming_exams if e.event_type.value == "exam"])
|
||||
|
||||
return TeacherContextResponse(
|
||||
schema_version="1.0",
|
||||
teacher_id=teacher_id,
|
||||
school=SchoolInfo(
|
||||
federal_state=context.federal_state or "BY",
|
||||
federal_state_name=FEDERAL_STATES.get(context.federal_state, ""),
|
||||
school_type=context.school_type or "gymnasium",
|
||||
school_type_name=SCHOOL_TYPES.get(context.school_type, ""),
|
||||
),
|
||||
school_year=SchoolYearInfo(
|
||||
id=context.schoolyear or "2024-2025",
|
||||
start=context.schoolyear_start.isoformat() if context.schoolyear_start else None,
|
||||
current_week=context.current_week or 1,
|
||||
),
|
||||
macro_phase=MacroPhaseInfo(
|
||||
id=context.macro_phase.value,
|
||||
label=_get_macro_phase_label(context.macro_phase),
|
||||
confidence=1.0,
|
||||
),
|
||||
core_counts=CoreCounts(
|
||||
classes=1 if context.has_classes else 0,
|
||||
exams_scheduled=exams_count,
|
||||
corrections_pending=0,
|
||||
),
|
||||
flags=ContextFlags(
|
||||
onboarding_completed=context.onboarding_completed,
|
||||
has_classes=context.has_classes,
|
||||
has_schedule=context.has_schedule,
|
||||
is_exam_period=context.is_exam_period,
|
||||
is_before_holidays=context.is_before_holidays,
|
||||
),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get teacher context: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"Fehler beim Laden des Kontexts: {e}")
|
||||
|
||||
# Fallback ohne DB
|
||||
return TeacherContextResponse(
|
||||
schema_version="1.0",
|
||||
teacher_id=teacher_id,
|
||||
school=SchoolInfo(
|
||||
federal_state="BY",
|
||||
federal_state_name="Bayern",
|
||||
school_type="gymnasium",
|
||||
school_type_name="Gymnasium",
|
||||
),
|
||||
school_year=SchoolYearInfo(
|
||||
id="2024-2025",
|
||||
start=None,
|
||||
current_week=1,
|
||||
),
|
||||
macro_phase=MacroPhaseInfo(
|
||||
id="onboarding",
|
||||
label="Einrichtung",
|
||||
confidence=1.0,
|
||||
),
|
||||
core_counts=CoreCounts(),
|
||||
flags=ContextFlags(),
|
||||
)
|
||||
|
||||
|
||||
@router.put("/v1/context", response_model=TeacherContextResponse)
|
||||
async def update_teacher_context(
|
||||
teacher_id: str,
|
||||
request: UpdateContextRequest,
|
||||
db=Depends(get_db)
|
||||
):
|
||||
"""
|
||||
Aktualisiert den Kontext eines Lehrers.
|
||||
"""
|
||||
if not DB_ENABLED or not db:
|
||||
raise HTTPException(status_code=503, detail="Datenbank nicht verfuegbar")
|
||||
|
||||
try:
|
||||
from classroom_engine.repository import TeacherContextRepository
|
||||
repo = TeacherContextRepository(db)
|
||||
|
||||
# Validierung
|
||||
if request.federal_state and request.federal_state not in FEDERAL_STATES:
|
||||
raise HTTPException(status_code=400, detail=f"Ungueltiges Bundesland: {request.federal_state}")
|
||||
if request.school_type and request.school_type not in SCHOOL_TYPES:
|
||||
raise HTTPException(status_code=400, detail=f"Ungueltige Schulart: {request.school_type}")
|
||||
|
||||
# Parse datetime if provided
|
||||
schoolyear_start = None
|
||||
if request.schoolyear_start:
|
||||
schoolyear_start = datetime.fromisoformat(request.schoolyear_start.replace('Z', '+00:00'))
|
||||
|
||||
repo.update_context(
|
||||
teacher_id=teacher_id,
|
||||
federal_state=request.federal_state,
|
||||
school_type=request.school_type,
|
||||
schoolyear=request.schoolyear,
|
||||
schoolyear_start=schoolyear_start,
|
||||
macro_phase=request.macro_phase,
|
||||
current_week=request.current_week,
|
||||
)
|
||||
|
||||
return await get_teacher_context(teacher_id, db)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to update teacher context: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"Fehler beim Aktualisieren: {e}")
|
||||
|
||||
|
||||
@router.post("/v1/context/complete-onboarding")
|
||||
async def complete_onboarding(
|
||||
teacher_id: str = Query(...),
|
||||
db=Depends(get_db)
|
||||
):
|
||||
"""Markiert das Onboarding als abgeschlossen."""
|
||||
if not DB_ENABLED or not db:
|
||||
return {"success": True, "macro_phase": "schuljahresstart", "note": "DB not available"}
|
||||
|
||||
try:
|
||||
from classroom_engine.repository import TeacherContextRepository
|
||||
repo = TeacherContextRepository(db)
|
||||
context = repo.complete_onboarding(teacher_id)
|
||||
return {
|
||||
"success": True,
|
||||
"macro_phase": context.macro_phase.value,
|
||||
"teacher_id": teacher_id,
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to complete onboarding: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"Fehler: {e}")
|
||||
|
||||
|
||||
@router.post("/v1/context/reset-onboarding")
|
||||
async def reset_onboarding(
|
||||
teacher_id: str = Query(...),
|
||||
db=Depends(get_db)
|
||||
):
|
||||
"""Setzt das Onboarding zurueck (fuer Tests)."""
|
||||
if not DB_ENABLED or not db:
|
||||
return {"success": True, "macro_phase": "onboarding", "note": "DB not available"}
|
||||
|
||||
try:
|
||||
from classroom_engine.repository import TeacherContextRepository
|
||||
repo = TeacherContextRepository(db)
|
||||
context = repo.get_or_create(teacher_id)
|
||||
context.onboarding_completed = False
|
||||
context.macro_phase = MacroPhaseEnum.ONBOARDING
|
||||
db.commit()
|
||||
db.refresh(context)
|
||||
return {
|
||||
"success": True,
|
||||
"macro_phase": "onboarding",
|
||||
"teacher_id": teacher_id,
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to reset onboarding: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"Fehler: {e}")
|
||||
@@ -0,0 +1,266 @@
|
||||
"""
|
||||
Classroom API - Events & Routines Routes
|
||||
|
||||
School year events, recurring routines endpoints (Phase 8).
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
from datetime import datetime
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Query, Depends
|
||||
|
||||
from ..models import (
|
||||
CreateEventRequest,
|
||||
EventResponse,
|
||||
CreateRoutineRequest,
|
||||
RoutineResponse,
|
||||
)
|
||||
from ..services.persistence import (
|
||||
DB_ENABLED,
|
||||
SessionLocal,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(tags=["Context"])
|
||||
|
||||
|
||||
def get_db():
|
||||
"""Database session dependency."""
|
||||
if DB_ENABLED and SessionLocal:
|
||||
db = SessionLocal()
|
||||
try:
|
||||
yield db
|
||||
finally:
|
||||
db.close()
|
||||
else:
|
||||
yield None
|
||||
|
||||
|
||||
# === Events Endpoints ===
|
||||
|
||||
@router.get("/v1/events")
|
||||
async def get_events(
|
||||
teacher_id: str = Query(...),
|
||||
status: Optional[str] = None,
|
||||
event_type: Optional[str] = None,
|
||||
limit: int = 50,
|
||||
db=Depends(get_db)
|
||||
):
|
||||
"""Holt Events eines Lehrers."""
|
||||
if not DB_ENABLED or not db:
|
||||
return {"events": [], "count": 0}
|
||||
|
||||
try:
|
||||
from classroom_engine.repository import SchoolyearEventRepository
|
||||
repo = SchoolyearEventRepository(db)
|
||||
events = repo.get_by_teacher(teacher_id, status=status, event_type=event_type, limit=limit)
|
||||
return {
|
||||
"events": [repo.to_dict(e) for e in events],
|
||||
"count": len(events),
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get events: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"Fehler: {e}")
|
||||
|
||||
|
||||
@router.get("/v1/events/upcoming")
|
||||
async def get_upcoming_events(
|
||||
teacher_id: str = Query(...),
|
||||
days: int = 30,
|
||||
limit: int = 10,
|
||||
db=Depends(get_db)
|
||||
):
|
||||
"""Holt anstehende Events der naechsten X Tage."""
|
||||
if not DB_ENABLED or not db:
|
||||
return {"events": [], "count": 0}
|
||||
|
||||
try:
|
||||
from classroom_engine.repository import SchoolyearEventRepository
|
||||
repo = SchoolyearEventRepository(db)
|
||||
events = repo.get_upcoming(teacher_id, days=days, limit=limit)
|
||||
return {
|
||||
"events": [repo.to_dict(e) for e in events],
|
||||
"count": len(events),
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get upcoming events: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"Fehler: {e}")
|
||||
|
||||
|
||||
@router.post("/v1/events", response_model=EventResponse)
|
||||
async def create_event(
|
||||
teacher_id: str,
|
||||
request: CreateEventRequest,
|
||||
db=Depends(get_db)
|
||||
):
|
||||
"""Erstellt ein neues Schuljahr-Event."""
|
||||
if not DB_ENABLED or not db:
|
||||
raise HTTPException(status_code=503, detail="Datenbank nicht verfuegbar")
|
||||
|
||||
try:
|
||||
from classroom_engine.repository import SchoolyearEventRepository
|
||||
repo = SchoolyearEventRepository(db)
|
||||
start_date = datetime.fromisoformat(request.start_date.replace('Z', '+00:00'))
|
||||
end_date = None
|
||||
if request.end_date:
|
||||
end_date = datetime.fromisoformat(request.end_date.replace('Z', '+00:00'))
|
||||
|
||||
event = repo.create(
|
||||
teacher_id=teacher_id,
|
||||
title=request.title,
|
||||
event_type=request.event_type,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
class_id=request.class_id,
|
||||
subject=request.subject,
|
||||
description=request.description,
|
||||
needs_preparation=request.needs_preparation,
|
||||
reminder_days_before=request.reminder_days_before,
|
||||
)
|
||||
|
||||
return EventResponse(
|
||||
id=event.id,
|
||||
teacher_id=event.teacher_id,
|
||||
event_type=event.event_type.value,
|
||||
title=event.title,
|
||||
description=event.description,
|
||||
start_date=event.start_date.isoformat(),
|
||||
end_date=event.end_date.isoformat() if event.end_date else None,
|
||||
class_id=event.class_id,
|
||||
subject=event.subject,
|
||||
status=event.status.value,
|
||||
needs_preparation=event.needs_preparation,
|
||||
preparation_done=event.preparation_done,
|
||||
reminder_days_before=event.reminder_days_before,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create event: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"Fehler: {e}")
|
||||
|
||||
|
||||
@router.delete("/v1/events/{event_id}")
|
||||
async def delete_event(event_id: str, db=Depends(get_db)):
|
||||
"""Loescht ein Event."""
|
||||
if not DB_ENABLED or not db:
|
||||
raise HTTPException(status_code=503, detail="Datenbank nicht verfuegbar")
|
||||
|
||||
try:
|
||||
from classroom_engine.repository import SchoolyearEventRepository
|
||||
repo = SchoolyearEventRepository(db)
|
||||
if repo.delete(event_id):
|
||||
return {"success": True, "deleted_id": event_id}
|
||||
raise HTTPException(status_code=404, detail="Event nicht gefunden")
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete event: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"Fehler: {e}")
|
||||
|
||||
|
||||
# === Routines Endpoints ===
|
||||
|
||||
@router.get("/v1/routines")
|
||||
async def get_routines(
|
||||
teacher_id: str = Query(...),
|
||||
is_active: bool = True,
|
||||
routine_type: Optional[str] = None,
|
||||
db=Depends(get_db)
|
||||
):
|
||||
"""Holt Routinen eines Lehrers."""
|
||||
if not DB_ENABLED or not db:
|
||||
return {"routines": [], "count": 0}
|
||||
|
||||
try:
|
||||
from classroom_engine.repository import RecurringRoutineRepository
|
||||
repo = RecurringRoutineRepository(db)
|
||||
routines = repo.get_by_teacher(teacher_id, is_active=is_active, routine_type=routine_type)
|
||||
return {
|
||||
"routines": [repo.to_dict(r) for r in routines],
|
||||
"count": len(routines),
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get routines: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"Fehler: {e}")
|
||||
|
||||
|
||||
@router.get("/v1/routines/today")
|
||||
async def get_today_routines(teacher_id: str = Query(...), db=Depends(get_db)):
|
||||
"""Holt Routinen die heute stattfinden."""
|
||||
if not DB_ENABLED or not db:
|
||||
return {"routines": [], "count": 0}
|
||||
|
||||
try:
|
||||
from classroom_engine.repository import RecurringRoutineRepository
|
||||
repo = RecurringRoutineRepository(db)
|
||||
routines = repo.get_today(teacher_id)
|
||||
return {
|
||||
"routines": [repo.to_dict(r) for r in routines],
|
||||
"count": len(routines),
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get today's routines: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"Fehler: {e}")
|
||||
|
||||
|
||||
@router.post("/v1/routines", response_model=RoutineResponse)
|
||||
async def create_routine(
|
||||
teacher_id: str,
|
||||
request: CreateRoutineRequest,
|
||||
db=Depends(get_db)
|
||||
):
|
||||
"""Erstellt eine neue wiederkehrende Routine."""
|
||||
if not DB_ENABLED or not db:
|
||||
raise HTTPException(status_code=503, detail="Datenbank nicht verfuegbar")
|
||||
|
||||
try:
|
||||
from classroom_engine.repository import RecurringRoutineRepository
|
||||
repo = RecurringRoutineRepository(db)
|
||||
routine = repo.create(
|
||||
teacher_id=teacher_id,
|
||||
title=request.title,
|
||||
routine_type=request.routine_type,
|
||||
recurrence_pattern=request.recurrence_pattern,
|
||||
day_of_week=request.day_of_week,
|
||||
day_of_month=request.day_of_month,
|
||||
time_of_day=request.time_of_day,
|
||||
duration_minutes=request.duration_minutes,
|
||||
description=request.description,
|
||||
)
|
||||
|
||||
return RoutineResponse(
|
||||
id=routine.id,
|
||||
teacher_id=routine.teacher_id,
|
||||
routine_type=routine.routine_type.value,
|
||||
title=routine.title,
|
||||
description=routine.description,
|
||||
recurrence_pattern=routine.recurrence_pattern.value,
|
||||
day_of_week=routine.day_of_week,
|
||||
day_of_month=routine.day_of_month,
|
||||
time_of_day=routine.time_of_day.isoformat() if routine.time_of_day else None,
|
||||
duration_minutes=routine.duration_minutes,
|
||||
is_active=routine.is_active,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create routine: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"Fehler: {e}")
|
||||
|
||||
|
||||
@router.delete("/v1/routines/{routine_id}")
|
||||
async def delete_routine(routine_id: str, db=Depends(get_db)):
|
||||
"""Loescht eine Routine."""
|
||||
if not DB_ENABLED or not db:
|
||||
raise HTTPException(status_code=503, detail="Datenbank nicht verfuegbar")
|
||||
|
||||
try:
|
||||
from classroom_engine.repository import RecurringRoutineRepository
|
||||
repo = RecurringRoutineRepository(db)
|
||||
if repo.delete(routine_id):
|
||||
return {"success": True, "deleted_id": routine_id}
|
||||
raise HTTPException(status_code=404, detail="Routine nicht gefunden")
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete routine: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"Fehler: {e}")
|
||||
@@ -0,0 +1,281 @@
|
||||
"""
|
||||
Classroom API - Static Data, Suggestions & Sidebar Routes
|
||||
|
||||
Federal states, school types, macro phases, event/routine types,
|
||||
suggestions, sidebar model, and school year path.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Query, Depends
|
||||
|
||||
from classroom_engine import FEDERAL_STATES, SCHOOL_TYPES
|
||||
|
||||
from ..services.persistence import (
|
||||
DB_ENABLED,
|
||||
SessionLocal,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(tags=["Context"])
|
||||
|
||||
|
||||
def get_db():
|
||||
"""Database session dependency."""
|
||||
if DB_ENABLED and SessionLocal:
|
||||
db = SessionLocal()
|
||||
try:
|
||||
yield db
|
||||
finally:
|
||||
db.close()
|
||||
else:
|
||||
yield None
|
||||
|
||||
|
||||
# === Static Data Endpoints ===
|
||||
|
||||
@router.get("/v1/federal-states")
|
||||
async def get_federal_states():
|
||||
"""Gibt alle Bundeslaender zurueck."""
|
||||
return {
|
||||
"federal_states": [{"id": k, "name": v} for k, v in FEDERAL_STATES.items()]
|
||||
}
|
||||
|
||||
|
||||
@router.get("/v1/school-types")
|
||||
async def get_school_types():
|
||||
"""Gibt alle Schularten zurueck."""
|
||||
return {
|
||||
"school_types": [{"id": k, "name": v} for k, v in SCHOOL_TYPES.items()]
|
||||
}
|
||||
|
||||
|
||||
@router.get("/v1/macro-phases")
|
||||
async def get_macro_phases():
|
||||
"""Gibt alle Makro-Phasen mit Beschreibungen zurueck."""
|
||||
phases = [
|
||||
{"id": "onboarding", "label": "Einrichtung", "description": "Ersteinrichtung (Klassen, Stundenplan)", "order": 1},
|
||||
{"id": "schuljahresstart", "label": "Schuljahresstart", "description": "Erste 2-3 Wochen des Schuljahres", "order": 2},
|
||||
{"id": "unterrichtsaufbau", "label": "Unterrichtsaufbau", "description": "Routinen etablieren, erste Bewertungen", "order": 3},
|
||||
{"id": "leistungsphase_1", "label": "Leistungsphase 1", "description": "Erste Klassenarbeiten und Klausuren", "order": 4},
|
||||
{"id": "halbjahresabschluss", "label": "Halbjahresabschluss", "description": "Notenschluss, Zeugnisse, Konferenzen", "order": 5},
|
||||
{"id": "leistungsphase_2", "label": "Leistungsphase 2", "description": "Zweites Halbjahr, Pruefungsvorbereitung", "order": 6},
|
||||
{"id": "jahresabschluss", "label": "Jahresabschluss", "description": "Finale Noten, Versetzung, Schuljahresende", "order": 7},
|
||||
]
|
||||
return {"macro_phases": phases}
|
||||
|
||||
|
||||
@router.get("/v1/event-types")
|
||||
async def get_event_types():
|
||||
"""Gibt alle Event-Typen zurueck."""
|
||||
types = [
|
||||
{"id": "exam", "label": "Klassenarbeit/Klausur"},
|
||||
{"id": "parent_evening", "label": "Elternabend"},
|
||||
{"id": "trip", "label": "Klassenfahrt/Ausflug"},
|
||||
{"id": "project", "label": "Projektwoche"},
|
||||
{"id": "internship", "label": "Praktikum"},
|
||||
{"id": "presentation", "label": "Referate/Praesentationen"},
|
||||
{"id": "sports_day", "label": "Sporttag"},
|
||||
{"id": "school_festival", "label": "Schulfest"},
|
||||
{"id": "parent_consultation", "label": "Elternsprechtag"},
|
||||
{"id": "grade_deadline", "label": "Notenschluss"},
|
||||
{"id": "report_cards", "label": "Zeugnisausgabe"},
|
||||
{"id": "holiday_start", "label": "Ferienbeginn"},
|
||||
{"id": "holiday_end", "label": "Ferienende"},
|
||||
{"id": "other", "label": "Sonstiges"},
|
||||
]
|
||||
return {"event_types": types}
|
||||
|
||||
|
||||
@router.get("/v1/routine-types")
|
||||
async def get_routine_types():
|
||||
"""Gibt alle Routine-Typen zurueck."""
|
||||
types = [
|
||||
{"id": "teacher_conference", "label": "Lehrerkonferenz"},
|
||||
{"id": "subject_conference", "label": "Fachkonferenz"},
|
||||
{"id": "office_hours", "label": "Sprechstunde"},
|
||||
{"id": "team_meeting", "label": "Teamsitzung"},
|
||||
{"id": "supervision", "label": "Pausenaufsicht"},
|
||||
{"id": "correction_time", "label": "Korrekturzeit"},
|
||||
{"id": "prep_time", "label": "Vorbereitungszeit"},
|
||||
{"id": "other", "label": "Sonstiges"},
|
||||
]
|
||||
return {"routine_types": types}
|
||||
|
||||
|
||||
# === Suggestions & Sidebar ===
|
||||
|
||||
@router.get("/v1/suggestions")
|
||||
async def get_suggestions(
|
||||
teacher_id: str = Query(...),
|
||||
limit: int = Query(5, ge=1, le=20),
|
||||
db=Depends(get_db)
|
||||
):
|
||||
"""Generiert kontextbasierte Vorschlaege fuer einen Lehrer."""
|
||||
if DB_ENABLED and db:
|
||||
try:
|
||||
from classroom_engine.suggestions import SuggestionGenerator
|
||||
generator = SuggestionGenerator(db)
|
||||
result = generator.generate(teacher_id, limit=limit)
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to generate suggestions: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"Fehler: {e}")
|
||||
|
||||
return {
|
||||
"active_contexts": [],
|
||||
"suggestions": [],
|
||||
"signals_summary": {
|
||||
"macro_phase": "onboarding",
|
||||
"current_week": 1,
|
||||
"has_classes": False,
|
||||
"exams_soon": 0,
|
||||
"routines_today": 0,
|
||||
},
|
||||
"total_suggestions": 0,
|
||||
}
|
||||
|
||||
|
||||
@router.get("/v1/sidebar")
|
||||
async def get_sidebar(
|
||||
teacher_id: str = Query(...),
|
||||
mode: str = Query("companion"),
|
||||
db=Depends(get_db)
|
||||
):
|
||||
"""Generiert das dynamische Sidebar-Model."""
|
||||
if mode == "companion":
|
||||
now_relevant = []
|
||||
if DB_ENABLED and db:
|
||||
try:
|
||||
from classroom_engine.suggestions import SuggestionGenerator
|
||||
generator = SuggestionGenerator(db)
|
||||
result = generator.generate(teacher_id, limit=5)
|
||||
now_relevant = [
|
||||
{
|
||||
"id": s["id"],
|
||||
"label": s["title"],
|
||||
"state": "recommended" if s["priority"] > 70 else "default",
|
||||
"badge": s.get("badge"),
|
||||
"icon": s.get("icon", "lightbulb"),
|
||||
"action_url": s.get("action_url"),
|
||||
}
|
||||
for s in result.get("suggestions", [])
|
||||
]
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get suggestions for sidebar: {e}")
|
||||
|
||||
return {
|
||||
"mode": "companion",
|
||||
"sections": [
|
||||
{"id": "SEARCH", "type": "search_bar", "placeholder": "Suchen..."},
|
||||
{
|
||||
"id": "NOW_RELEVANT",
|
||||
"type": "list",
|
||||
"title": "Jetzt relevant",
|
||||
"items": now_relevant if now_relevant else [
|
||||
{"id": "no_suggestions", "label": "Keine Vorschlaege", "state": "default", "icon": "check_circle"}
|
||||
],
|
||||
},
|
||||
{
|
||||
"id": "ALL_MODULES",
|
||||
"type": "folder",
|
||||
"label": "Alle Module",
|
||||
"icon": "folder",
|
||||
"collapsed": True,
|
||||
"items": [
|
||||
{"id": "lesson", "label": "Stundenmodus", "icon": "timer"},
|
||||
{"id": "classes", "label": "Klassen", "icon": "groups"},
|
||||
{"id": "exams", "label": "Klausuren", "icon": "quiz"},
|
||||
{"id": "grades", "label": "Noten", "icon": "calculate"},
|
||||
{"id": "calendar", "label": "Kalender", "icon": "calendar_month"},
|
||||
{"id": "materials", "label": "Materialien", "icon": "folder_open"},
|
||||
],
|
||||
},
|
||||
{
|
||||
"id": "QUICK_ACTIONS",
|
||||
"type": "actions",
|
||||
"title": "Kurzaktionen",
|
||||
"items": [
|
||||
{"id": "scan", "label": "Scan hochladen", "icon": "upload_file"},
|
||||
{"id": "note", "label": "Notiz erstellen", "icon": "note_add"},
|
||||
],
|
||||
},
|
||||
],
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"mode": "classic",
|
||||
"sections": [
|
||||
{
|
||||
"id": "NAVIGATION",
|
||||
"type": "tree",
|
||||
"items": [
|
||||
{"id": "dashboard", "label": "Dashboard", "icon": "dashboard", "url": "/dashboard"},
|
||||
{"id": "lesson", "label": "Stundenmodus", "icon": "timer", "url": "/lesson"},
|
||||
{"id": "classes", "label": "Klassen", "icon": "groups", "url": "/classes"},
|
||||
{"id": "exams", "label": "Klausuren", "icon": "quiz", "url": "/exams"},
|
||||
{"id": "grades", "label": "Noten", "icon": "calculate", "url": "/grades"},
|
||||
{"id": "calendar", "label": "Kalender", "icon": "calendar_month", "url": "/calendar"},
|
||||
{"id": "materials", "label": "Materialien", "icon": "folder_open", "url": "/materials"},
|
||||
{"id": "settings", "label": "Einstellungen", "icon": "settings", "url": "/settings"},
|
||||
],
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
@router.get("/v1/path")
|
||||
async def get_schoolyear_path(teacher_id: str = Query(...), db=Depends(get_db)):
|
||||
"""Generiert den Schuljahres-Pfad mit Meilensteinen."""
|
||||
current_phase = "onboarding"
|
||||
if DB_ENABLED and db:
|
||||
try:
|
||||
from classroom_engine.repository import TeacherContextRepository
|
||||
repo = TeacherContextRepository(db)
|
||||
context = repo.get_or_create(teacher_id)
|
||||
current_phase = context.macro_phase.value
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get context for path: {e}")
|
||||
|
||||
phase_order = [
|
||||
"onboarding", "schuljahresstart", "unterrichtsaufbau",
|
||||
"leistungsphase_1", "halbjahresabschluss", "leistungsphase_2", "jahresabschluss",
|
||||
]
|
||||
|
||||
current_index = phase_order.index(current_phase) if current_phase in phase_order else 0
|
||||
|
||||
milestones = [
|
||||
{"id": "MS_START", "label": "Start", "phase": "onboarding", "icon": "flag"},
|
||||
{"id": "MS_SETUP", "label": "Einrichtung", "phase": "schuljahresstart", "icon": "tune"},
|
||||
{"id": "MS_ROUTINE", "label": "Routinen", "phase": "unterrichtsaufbau", "icon": "repeat"},
|
||||
{"id": "MS_EXAM_1", "label": "Klausuren", "phase": "leistungsphase_1", "icon": "quiz"},
|
||||
{"id": "MS_HALFYEAR", "label": "Halbjahr", "phase": "halbjahresabschluss", "icon": "event"},
|
||||
{"id": "MS_EXAM_2", "label": "Pruefungen", "phase": "leistungsphase_2", "icon": "school"},
|
||||
{"id": "MS_END", "label": "Abschluss", "phase": "jahresabschluss", "icon": "celebration"},
|
||||
]
|
||||
|
||||
for milestone in milestones:
|
||||
phase = milestone["phase"]
|
||||
phase_index = phase_order.index(phase) if phase in phase_order else 999
|
||||
if phase_index < current_index:
|
||||
milestone["status"] = "done"
|
||||
elif phase_index == current_index:
|
||||
milestone["status"] = "current"
|
||||
else:
|
||||
milestone["status"] = "upcoming"
|
||||
|
||||
current_milestone_id = next(
|
||||
(m["id"] for m in milestones if m["status"] == "current"),
|
||||
milestones[0]["id"]
|
||||
)
|
||||
|
||||
progress = int((current_index / (len(phase_order) - 1)) * 100) if len(phase_order) > 1 else 0
|
||||
|
||||
return {
|
||||
"milestones": milestones,
|
||||
"current_milestone_id": current_milestone_id,
|
||||
"progress_percent": progress,
|
||||
"current_phase": current_phase,
|
||||
}
|
||||
@@ -0,0 +1,386 @@
|
||||
"""
|
||||
EduSearch Seeds CRUD Routes.
|
||||
|
||||
List, get, create, update, delete, and bulk import for seed URLs.
|
||||
"""
|
||||
|
||||
import os
|
||||
import logging
|
||||
from typing import Optional, List
|
||||
from datetime import datetime
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Query
|
||||
import asyncpg
|
||||
|
||||
from .edu_search_models import (
|
||||
CategoryResponse,
|
||||
SeedCreate,
|
||||
SeedUpdate,
|
||||
SeedResponse,
|
||||
SeedsListResponse,
|
||||
BulkImportRequest,
|
||||
BulkImportResponse,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(tags=["edu-search"])
|
||||
|
||||
# Database connection pool
|
||||
_pool: Optional[asyncpg.Pool] = None
|
||||
|
||||
|
||||
async def get_db_pool() -> asyncpg.Pool:
|
||||
"""Get or create database connection pool."""
|
||||
global _pool
|
||||
if _pool is None:
|
||||
database_url = os.environ.get("DATABASE_URL")
|
||||
if not database_url:
|
||||
raise RuntimeError("DATABASE_URL nicht konfiguriert - bitte via Vault oder Umgebungsvariable setzen")
|
||||
_pool = await asyncpg.create_pool(database_url, min_size=2, max_size=10)
|
||||
return _pool
|
||||
|
||||
|
||||
@router.get("/categories", response_model=List[CategoryResponse])
|
||||
async def list_categories():
|
||||
"""List all seed categories."""
|
||||
pool = await get_db_pool()
|
||||
async with pool.acquire() as conn:
|
||||
rows = await conn.fetch("""
|
||||
SELECT id, name, display_name, description, icon, sort_order, is_active
|
||||
FROM edu_search_categories
|
||||
WHERE is_active = TRUE
|
||||
ORDER BY sort_order
|
||||
""")
|
||||
return [
|
||||
CategoryResponse(
|
||||
id=str(row["id"]),
|
||||
name=row["name"],
|
||||
display_name=row["display_name"],
|
||||
description=row["description"],
|
||||
icon=row["icon"],
|
||||
sort_order=row["sort_order"],
|
||||
is_active=row["is_active"],
|
||||
)
|
||||
for row in rows
|
||||
]
|
||||
|
||||
|
||||
@router.get("/seeds", response_model=SeedsListResponse)
|
||||
async def list_seeds(
|
||||
category: Optional[str] = Query(None, description="Filter by category name"),
|
||||
state: Optional[str] = Query(None, description="Filter by state code"),
|
||||
enabled: Optional[bool] = Query(None, description="Filter by enabled status"),
|
||||
search: Optional[str] = Query(None, description="Search in name/url"),
|
||||
page: int = Query(1, ge=1),
|
||||
page_size: int = Query(50, ge=1, le=200),
|
||||
):
|
||||
"""List seeds with optional filtering and pagination."""
|
||||
pool = await get_db_pool()
|
||||
async with pool.acquire() as conn:
|
||||
# Build WHERE clause
|
||||
conditions = []
|
||||
params = []
|
||||
param_idx = 1
|
||||
|
||||
if category:
|
||||
conditions.append(f"c.name = ${param_idx}")
|
||||
params.append(category)
|
||||
param_idx += 1
|
||||
|
||||
if state:
|
||||
conditions.append(f"s.state = ${param_idx}")
|
||||
params.append(state)
|
||||
param_idx += 1
|
||||
|
||||
if enabled is not None:
|
||||
conditions.append(f"s.enabled = ${param_idx}")
|
||||
params.append(enabled)
|
||||
param_idx += 1
|
||||
|
||||
if search:
|
||||
conditions.append(f"(s.name ILIKE ${param_idx} OR s.url ILIKE ${param_idx})")
|
||||
params.append(f"%{search}%")
|
||||
param_idx += 1
|
||||
|
||||
where_clause = " AND ".join(conditions) if conditions else "TRUE"
|
||||
|
||||
# Count total
|
||||
count_query = f"""
|
||||
SELECT COUNT(*) FROM edu_search_seeds s
|
||||
LEFT JOIN edu_search_categories c ON s.category_id = c.id
|
||||
WHERE {where_clause}
|
||||
"""
|
||||
total = await conn.fetchval(count_query, *params)
|
||||
|
||||
# Get paginated results
|
||||
offset = (page - 1) * page_size
|
||||
params.extend([page_size, offset])
|
||||
|
||||
query = f"""
|
||||
SELECT
|
||||
s.id, s.url, s.name, s.description,
|
||||
c.name as category, c.display_name as category_display_name,
|
||||
s.source_type, s.scope, s.state, s.trust_boost, s.enabled,
|
||||
s.crawl_depth, s.crawl_frequency, s.last_crawled_at,
|
||||
s.last_crawl_status, s.last_crawl_docs, s.total_documents,
|
||||
s.created_at, s.updated_at
|
||||
FROM edu_search_seeds s
|
||||
LEFT JOIN edu_search_categories c ON s.category_id = c.id
|
||||
WHERE {where_clause}
|
||||
ORDER BY c.sort_order, s.name
|
||||
LIMIT ${param_idx} OFFSET ${param_idx + 1}
|
||||
"""
|
||||
|
||||
rows = await conn.fetch(query, *params)
|
||||
|
||||
seeds = [_row_to_seed_response(row) for row in rows]
|
||||
|
||||
return SeedsListResponse(
|
||||
seeds=seeds,
|
||||
total=total,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/seeds/{seed_id}", response_model=SeedResponse)
|
||||
async def get_seed(seed_id: str):
|
||||
"""Get a single seed by ID."""
|
||||
pool = await get_db_pool()
|
||||
async with pool.acquire() as conn:
|
||||
row = await conn.fetchrow("""
|
||||
SELECT
|
||||
s.id, s.url, s.name, s.description,
|
||||
c.name as category, c.display_name as category_display_name,
|
||||
s.source_type, s.scope, s.state, s.trust_boost, s.enabled,
|
||||
s.crawl_depth, s.crawl_frequency, s.last_crawled_at,
|
||||
s.last_crawl_status, s.last_crawl_docs, s.total_documents,
|
||||
s.created_at, s.updated_at
|
||||
FROM edu_search_seeds s
|
||||
LEFT JOIN edu_search_categories c ON s.category_id = c.id
|
||||
WHERE s.id = $1
|
||||
""", seed_id)
|
||||
|
||||
if not row:
|
||||
raise HTTPException(status_code=404, detail="Seed nicht gefunden")
|
||||
|
||||
return _row_to_seed_response(row)
|
||||
|
||||
|
||||
@router.post("/seeds", response_model=SeedResponse, status_code=201)
|
||||
async def create_seed(seed: SeedCreate):
|
||||
"""Create a new seed URL."""
|
||||
pool = await get_db_pool()
|
||||
async with pool.acquire() as conn:
|
||||
category_id = None
|
||||
if seed.category_name:
|
||||
category_id = await conn.fetchval(
|
||||
"SELECT id FROM edu_search_categories WHERE name = $1",
|
||||
seed.category_name
|
||||
)
|
||||
|
||||
try:
|
||||
row = await conn.fetchrow("""
|
||||
INSERT INTO edu_search_seeds (
|
||||
url, name, description, category_id, source_type, scope,
|
||||
state, trust_boost, enabled, crawl_depth, crawl_frequency
|
||||
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
|
||||
RETURNING id, created_at, updated_at
|
||||
""",
|
||||
seed.url, seed.name, seed.description, category_id,
|
||||
seed.source_type, seed.scope, seed.state, seed.trust_boost,
|
||||
seed.enabled, seed.crawl_depth, seed.crawl_frequency
|
||||
)
|
||||
except asyncpg.UniqueViolationError:
|
||||
raise HTTPException(status_code=409, detail="URL existiert bereits")
|
||||
|
||||
return SeedResponse(
|
||||
id=str(row["id"]),
|
||||
url=seed.url,
|
||||
name=seed.name,
|
||||
description=seed.description,
|
||||
category=seed.category_name,
|
||||
category_display_name=None,
|
||||
source_type=seed.source_type,
|
||||
scope=seed.scope,
|
||||
state=seed.state,
|
||||
trust_boost=seed.trust_boost,
|
||||
enabled=seed.enabled,
|
||||
crawl_depth=seed.crawl_depth,
|
||||
crawl_frequency=seed.crawl_frequency,
|
||||
last_crawled_at=None,
|
||||
last_crawl_status=None,
|
||||
last_crawl_docs=0,
|
||||
total_documents=0,
|
||||
created_at=row["created_at"],
|
||||
updated_at=row["updated_at"],
|
||||
)
|
||||
|
||||
|
||||
@router.put("/seeds/{seed_id}", response_model=SeedResponse)
|
||||
async def update_seed(seed_id: str, seed: SeedUpdate):
|
||||
"""Update an existing seed."""
|
||||
pool = await get_db_pool()
|
||||
async with pool.acquire() as conn:
|
||||
updates = []
|
||||
params = []
|
||||
param_idx = 1
|
||||
|
||||
if seed.url is not None:
|
||||
updates.append(f"url = ${param_idx}")
|
||||
params.append(seed.url)
|
||||
param_idx += 1
|
||||
|
||||
if seed.name is not None:
|
||||
updates.append(f"name = ${param_idx}")
|
||||
params.append(seed.name)
|
||||
param_idx += 1
|
||||
|
||||
if seed.description is not None:
|
||||
updates.append(f"description = ${param_idx}")
|
||||
params.append(seed.description)
|
||||
param_idx += 1
|
||||
|
||||
if seed.category_name is not None:
|
||||
category_id = await conn.fetchval(
|
||||
"SELECT id FROM edu_search_categories WHERE name = $1",
|
||||
seed.category_name
|
||||
)
|
||||
updates.append(f"category_id = ${param_idx}")
|
||||
params.append(category_id)
|
||||
param_idx += 1
|
||||
|
||||
if seed.source_type is not None:
|
||||
updates.append(f"source_type = ${param_idx}")
|
||||
params.append(seed.source_type)
|
||||
param_idx += 1
|
||||
|
||||
if seed.scope is not None:
|
||||
updates.append(f"scope = ${param_idx}")
|
||||
params.append(seed.scope)
|
||||
param_idx += 1
|
||||
|
||||
if seed.state is not None:
|
||||
updates.append(f"state = ${param_idx}")
|
||||
params.append(seed.state)
|
||||
param_idx += 1
|
||||
|
||||
if seed.trust_boost is not None:
|
||||
updates.append(f"trust_boost = ${param_idx}")
|
||||
params.append(seed.trust_boost)
|
||||
param_idx += 1
|
||||
|
||||
if seed.enabled is not None:
|
||||
updates.append(f"enabled = ${param_idx}")
|
||||
params.append(seed.enabled)
|
||||
param_idx += 1
|
||||
|
||||
if seed.crawl_depth is not None:
|
||||
updates.append(f"crawl_depth = ${param_idx}")
|
||||
params.append(seed.crawl_depth)
|
||||
param_idx += 1
|
||||
|
||||
if seed.crawl_frequency is not None:
|
||||
updates.append(f"crawl_frequency = ${param_idx}")
|
||||
params.append(seed.crawl_frequency)
|
||||
param_idx += 1
|
||||
|
||||
if not updates:
|
||||
raise HTTPException(status_code=400, detail="Keine Felder zum Aktualisieren")
|
||||
|
||||
updates.append("updated_at = NOW()")
|
||||
params.append(seed_id)
|
||||
|
||||
query = f"""
|
||||
UPDATE edu_search_seeds
|
||||
SET {", ".join(updates)}
|
||||
WHERE id = ${param_idx}
|
||||
RETURNING id
|
||||
"""
|
||||
|
||||
result = await conn.fetchrow(query, *params)
|
||||
if not result:
|
||||
raise HTTPException(status_code=404, detail="Seed nicht gefunden")
|
||||
|
||||
# Return updated seed
|
||||
return await get_seed(seed_id)
|
||||
|
||||
|
||||
@router.delete("/seeds/{seed_id}")
|
||||
async def delete_seed(seed_id: str):
|
||||
"""Delete a seed."""
|
||||
pool = await get_db_pool()
|
||||
async with pool.acquire() as conn:
|
||||
result = await conn.execute(
|
||||
"DELETE FROM edu_search_seeds WHERE id = $1",
|
||||
seed_id
|
||||
)
|
||||
if result == "DELETE 0":
|
||||
raise HTTPException(status_code=404, detail="Seed nicht gefunden")
|
||||
|
||||
return {"status": "deleted", "id": seed_id}
|
||||
|
||||
|
||||
@router.post("/seeds/bulk-import", response_model=BulkImportResponse)
|
||||
async def bulk_import_seeds(request: BulkImportRequest):
|
||||
"""Bulk import seeds (skip duplicates)."""
|
||||
pool = await get_db_pool()
|
||||
imported = 0
|
||||
skipped = 0
|
||||
errors = []
|
||||
|
||||
async with pool.acquire() as conn:
|
||||
# Pre-fetch all category IDs
|
||||
categories = {}
|
||||
rows = await conn.fetch("SELECT id, name FROM edu_search_categories")
|
||||
for row in rows:
|
||||
categories[row["name"]] = row["id"]
|
||||
|
||||
for seed in request.seeds:
|
||||
try:
|
||||
category_id = categories.get(seed.category_name) if seed.category_name else None
|
||||
|
||||
await conn.execute("""
|
||||
INSERT INTO edu_search_seeds (
|
||||
url, name, description, category_id, source_type, scope,
|
||||
state, trust_boost, enabled, crawl_depth, crawl_frequency
|
||||
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
|
||||
ON CONFLICT (url) DO NOTHING
|
||||
""",
|
||||
seed.url, seed.name, seed.description, category_id,
|
||||
seed.source_type, seed.scope, seed.state, seed.trust_boost,
|
||||
seed.enabled, seed.crawl_depth, seed.crawl_frequency
|
||||
)
|
||||
imported += 1
|
||||
except asyncpg.UniqueViolationError:
|
||||
skipped += 1
|
||||
except Exception as e:
|
||||
errors.append(f"{seed.url}: {str(e)}")
|
||||
|
||||
return BulkImportResponse(imported=imported, skipped=skipped, errors=errors)
|
||||
|
||||
|
||||
def _row_to_seed_response(row) -> SeedResponse:
|
||||
"""Convert a database row to SeedResponse."""
|
||||
return SeedResponse(
|
||||
id=str(row["id"]),
|
||||
url=row["url"],
|
||||
name=row["name"],
|
||||
description=row["description"],
|
||||
category=row["category"],
|
||||
category_display_name=row["category_display_name"],
|
||||
source_type=row["source_type"],
|
||||
scope=row["scope"],
|
||||
state=row["state"],
|
||||
trust_boost=float(row["trust_boost"]),
|
||||
enabled=row["enabled"],
|
||||
crawl_depth=row["crawl_depth"],
|
||||
crawl_frequency=row["crawl_frequency"],
|
||||
last_crawled_at=row["last_crawled_at"],
|
||||
last_crawl_status=row["last_crawl_status"],
|
||||
last_crawl_docs=row["last_crawl_docs"] or 0,
|
||||
total_documents=row["total_documents"] or 0,
|
||||
created_at=row["created_at"],
|
||||
updated_at=row["updated_at"],
|
||||
)
|
||||
@@ -0,0 +1,137 @@
|
||||
"""
|
||||
EduSearch Seeds Pydantic Models.
|
||||
|
||||
Request/Response models for the education search seed URL API.
|
||||
"""
|
||||
|
||||
from typing import Optional, List
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class CategoryResponse(BaseModel):
|
||||
"""Category response model."""
|
||||
id: str
|
||||
name: str
|
||||
display_name: str
|
||||
description: Optional[str] = None
|
||||
icon: Optional[str] = None
|
||||
sort_order: int
|
||||
is_active: bool
|
||||
|
||||
|
||||
class SeedBase(BaseModel):
|
||||
"""Base seed model for creation/update."""
|
||||
url: str = Field(..., max_length=500)
|
||||
name: str = Field(..., max_length=255)
|
||||
description: Optional[str] = None
|
||||
category_name: Optional[str] = Field(None, description="Category name (federal, states, etc.)")
|
||||
source_type: str = Field("GOV", description="GOV, EDU, UNI, etc.")
|
||||
scope: str = Field("FEDERAL", description="FEDERAL, STATE, etc.")
|
||||
state: Optional[str] = Field(None, max_length=5, description="State code (BW, BY, etc.)")
|
||||
trust_boost: float = Field(0.50, ge=0.0, le=1.0)
|
||||
enabled: bool = True
|
||||
crawl_depth: int = Field(2, ge=1, le=5)
|
||||
crawl_frequency: str = Field("weekly", description="hourly, daily, weekly, monthly")
|
||||
|
||||
|
||||
class SeedCreate(SeedBase):
|
||||
"""Seed creation model."""
|
||||
pass
|
||||
|
||||
|
||||
class SeedUpdate(BaseModel):
|
||||
"""Seed update model (all fields optional)."""
|
||||
url: Optional[str] = Field(None, max_length=500)
|
||||
name: Optional[str] = Field(None, max_length=255)
|
||||
description: Optional[str] = None
|
||||
category_name: Optional[str] = None
|
||||
source_type: Optional[str] = None
|
||||
scope: Optional[str] = None
|
||||
state: Optional[str] = Field(None, max_length=5)
|
||||
trust_boost: Optional[float] = Field(None, ge=0.0, le=1.0)
|
||||
enabled: Optional[bool] = None
|
||||
crawl_depth: Optional[int] = Field(None, ge=1, le=5)
|
||||
crawl_frequency: Optional[str] = None
|
||||
|
||||
|
||||
class SeedResponse(BaseModel):
|
||||
"""Seed response model."""
|
||||
id: str
|
||||
url: str
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
category: Optional[str] = None
|
||||
category_display_name: Optional[str] = None
|
||||
source_type: str
|
||||
scope: str
|
||||
state: Optional[str] = None
|
||||
trust_boost: float
|
||||
enabled: bool
|
||||
crawl_depth: int
|
||||
crawl_frequency: str
|
||||
last_crawled_at: Optional[datetime] = None
|
||||
last_crawl_status: Optional[str] = None
|
||||
last_crawl_docs: int = 0
|
||||
total_documents: int = 0
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
|
||||
class SeedsListResponse(BaseModel):
|
||||
"""List response with pagination info."""
|
||||
seeds: List[SeedResponse]
|
||||
total: int
|
||||
page: int
|
||||
page_size: int
|
||||
|
||||
|
||||
class StatsResponse(BaseModel):
|
||||
"""Crawl statistics response."""
|
||||
total_seeds: int
|
||||
enabled_seeds: int
|
||||
total_documents: int
|
||||
seeds_by_category: dict
|
||||
seeds_by_state: dict
|
||||
last_crawl_time: Optional[datetime] = None
|
||||
|
||||
|
||||
class BulkImportRequest(BaseModel):
|
||||
"""Bulk import request."""
|
||||
seeds: List[SeedCreate]
|
||||
|
||||
|
||||
class BulkImportResponse(BaseModel):
|
||||
"""Bulk import response."""
|
||||
imported: int
|
||||
skipped: int
|
||||
errors: List[str]
|
||||
|
||||
|
||||
class CrawlStatusUpdate(BaseModel):
|
||||
"""Crawl status update from edu-search-service."""
|
||||
seed_url: str = Field(..., description="The seed URL that was crawled")
|
||||
status: str = Field(..., description="Crawl status: success, error, partial")
|
||||
documents_crawled: int = Field(0, ge=0, description="Number of documents crawled")
|
||||
error_message: Optional[str] = Field(None, description="Error message if status is error")
|
||||
crawl_duration_seconds: float = Field(0.0, ge=0.0, description="Duration of the crawl in seconds")
|
||||
|
||||
|
||||
class CrawlStatusResponse(BaseModel):
|
||||
"""Response for crawl status update."""
|
||||
success: bool
|
||||
seed_url: str
|
||||
message: str
|
||||
|
||||
|
||||
class BulkCrawlStatusUpdate(BaseModel):
|
||||
"""Bulk crawl status update."""
|
||||
updates: List[CrawlStatusUpdate]
|
||||
|
||||
|
||||
class BulkCrawlStatusResponse(BaseModel):
|
||||
"""Response for bulk crawl status update."""
|
||||
updated: int
|
||||
failed: int
|
||||
errors: List[str]
|
||||
@@ -1,710 +1,58 @@
|
||||
"""
|
||||
EduSearch Seeds API Routes.
|
||||
EduSearch Seeds API Routes — Barrel Re-export.
|
||||
|
||||
Split into submodules:
|
||||
- edu_search_models.py — Pydantic request/response models
|
||||
- edu_search_crud.py — CRUD endpoints (list, get, create, update, delete, bulk import)
|
||||
- edu_search_status.py — Stats, export for crawler, crawl status feedback
|
||||
|
||||
CRUD operations for managing education search crawler seed URLs.
|
||||
Direct database access to PostgreSQL.
|
||||
"""
|
||||
|
||||
import os
|
||||
import logging
|
||||
from typing import Optional, List
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
from fastapi import APIRouter
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Depends, Query
|
||||
from pydantic import BaseModel, Field, HttpUrl
|
||||
import asyncpg
|
||||
from .edu_search_crud import router as _crud_router, get_db_pool
|
||||
from .edu_search_status import router as _status_router
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
# Re-export models for consumers that import types from this module
|
||||
from .edu_search_models import (
|
||||
CategoryResponse,
|
||||
SeedBase,
|
||||
SeedCreate,
|
||||
SeedUpdate,
|
||||
SeedResponse,
|
||||
SeedsListResponse,
|
||||
StatsResponse,
|
||||
BulkImportRequest,
|
||||
BulkImportResponse,
|
||||
CrawlStatusUpdate,
|
||||
CrawlStatusResponse,
|
||||
BulkCrawlStatusUpdate,
|
||||
BulkCrawlStatusResponse,
|
||||
)
|
||||
|
||||
# Combine both sub-routers into a single router for backwards compatibility.
|
||||
# The consumer imports `from .edu_search_seeds import router as edu_search_seeds_router`.
|
||||
router = APIRouter(prefix="/edu-search", tags=["edu-search"])
|
||||
|
||||
# Database connection pool
|
||||
_pool: Optional[asyncpg.Pool] = None
|
||||
|
||||
|
||||
async def get_db_pool() -> asyncpg.Pool:
|
||||
"""Get or create database connection pool."""
|
||||
global _pool
|
||||
if _pool is None:
|
||||
database_url = os.environ.get("DATABASE_URL")
|
||||
if not database_url:
|
||||
raise RuntimeError("DATABASE_URL nicht konfiguriert - bitte via Vault oder Umgebungsvariable setzen")
|
||||
_pool = await asyncpg.create_pool(database_url, min_size=2, max_size=10)
|
||||
return _pool
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Pydantic Models
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class CategoryResponse(BaseModel):
|
||||
"""Category response model."""
|
||||
id: str
|
||||
name: str
|
||||
display_name: str
|
||||
description: Optional[str] = None
|
||||
icon: Optional[str] = None
|
||||
sort_order: int
|
||||
is_active: bool
|
||||
|
||||
|
||||
class SeedBase(BaseModel):
|
||||
"""Base seed model for creation/update."""
|
||||
url: str = Field(..., max_length=500)
|
||||
name: str = Field(..., max_length=255)
|
||||
description: Optional[str] = None
|
||||
category_name: Optional[str] = Field(None, description="Category name (federal, states, etc.)")
|
||||
source_type: str = Field("GOV", description="GOV, EDU, UNI, etc.")
|
||||
scope: str = Field("FEDERAL", description="FEDERAL, STATE, etc.")
|
||||
state: Optional[str] = Field(None, max_length=5, description="State code (BW, BY, etc.)")
|
||||
trust_boost: float = Field(0.50, ge=0.0, le=1.0)
|
||||
enabled: bool = True
|
||||
crawl_depth: int = Field(2, ge=1, le=5)
|
||||
crawl_frequency: str = Field("weekly", description="hourly, daily, weekly, monthly")
|
||||
|
||||
|
||||
class SeedCreate(SeedBase):
|
||||
"""Seed creation model."""
|
||||
pass
|
||||
|
||||
|
||||
class SeedUpdate(BaseModel):
|
||||
"""Seed update model (all fields optional)."""
|
||||
url: Optional[str] = Field(None, max_length=500)
|
||||
name: Optional[str] = Field(None, max_length=255)
|
||||
description: Optional[str] = None
|
||||
category_name: Optional[str] = None
|
||||
source_type: Optional[str] = None
|
||||
scope: Optional[str] = None
|
||||
state: Optional[str] = Field(None, max_length=5)
|
||||
trust_boost: Optional[float] = Field(None, ge=0.0, le=1.0)
|
||||
enabled: Optional[bool] = None
|
||||
crawl_depth: Optional[int] = Field(None, ge=1, le=5)
|
||||
crawl_frequency: Optional[str] = None
|
||||
|
||||
|
||||
class SeedResponse(BaseModel):
|
||||
"""Seed response model."""
|
||||
id: str
|
||||
url: str
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
category: Optional[str] = None
|
||||
category_display_name: Optional[str] = None
|
||||
source_type: str
|
||||
scope: str
|
||||
state: Optional[str] = None
|
||||
trust_boost: float
|
||||
enabled: bool
|
||||
crawl_depth: int
|
||||
crawl_frequency: str
|
||||
last_crawled_at: Optional[datetime] = None
|
||||
last_crawl_status: Optional[str] = None
|
||||
last_crawl_docs: int = 0
|
||||
total_documents: int = 0
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
|
||||
class SeedsListResponse(BaseModel):
|
||||
"""List response with pagination info."""
|
||||
seeds: List[SeedResponse]
|
||||
total: int
|
||||
page: int
|
||||
page_size: int
|
||||
|
||||
|
||||
class StatsResponse(BaseModel):
|
||||
"""Crawl statistics response."""
|
||||
total_seeds: int
|
||||
enabled_seeds: int
|
||||
total_documents: int
|
||||
seeds_by_category: dict
|
||||
seeds_by_state: dict
|
||||
last_crawl_time: Optional[datetime] = None
|
||||
|
||||
|
||||
class BulkImportRequest(BaseModel):
|
||||
"""Bulk import request."""
|
||||
seeds: List[SeedCreate]
|
||||
|
||||
|
||||
class BulkImportResponse(BaseModel):
|
||||
"""Bulk import response."""
|
||||
imported: int
|
||||
skipped: int
|
||||
errors: List[str]
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# API Endpoints
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@router.get("/categories", response_model=List[CategoryResponse])
|
||||
async def list_categories():
|
||||
"""List all seed categories."""
|
||||
pool = await get_db_pool()
|
||||
async with pool.acquire() as conn:
|
||||
rows = await conn.fetch("""
|
||||
SELECT id, name, display_name, description, icon, sort_order, is_active
|
||||
FROM edu_search_categories
|
||||
WHERE is_active = TRUE
|
||||
ORDER BY sort_order
|
||||
""")
|
||||
return [
|
||||
CategoryResponse(
|
||||
id=str(row["id"]),
|
||||
name=row["name"],
|
||||
display_name=row["display_name"],
|
||||
description=row["description"],
|
||||
icon=row["icon"],
|
||||
sort_order=row["sort_order"],
|
||||
is_active=row["is_active"],
|
||||
)
|
||||
for row in rows
|
||||
]
|
||||
|
||||
|
||||
@router.get("/seeds", response_model=SeedsListResponse)
|
||||
async def list_seeds(
|
||||
category: Optional[str] = Query(None, description="Filter by category name"),
|
||||
state: Optional[str] = Query(None, description="Filter by state code"),
|
||||
enabled: Optional[bool] = Query(None, description="Filter by enabled status"),
|
||||
search: Optional[str] = Query(None, description="Search in name/url"),
|
||||
page: int = Query(1, ge=1),
|
||||
page_size: int = Query(50, ge=1, le=200),
|
||||
):
|
||||
"""List seeds with optional filtering and pagination."""
|
||||
pool = await get_db_pool()
|
||||
async with pool.acquire() as conn:
|
||||
# Build WHERE clause
|
||||
conditions = []
|
||||
params = []
|
||||
param_idx = 1
|
||||
|
||||
if category:
|
||||
conditions.append(f"c.name = ${param_idx}")
|
||||
params.append(category)
|
||||
param_idx += 1
|
||||
|
||||
if state:
|
||||
conditions.append(f"s.state = ${param_idx}")
|
||||
params.append(state)
|
||||
param_idx += 1
|
||||
|
||||
if enabled is not None:
|
||||
conditions.append(f"s.enabled = ${param_idx}")
|
||||
params.append(enabled)
|
||||
param_idx += 1
|
||||
|
||||
if search:
|
||||
conditions.append(f"(s.name ILIKE ${param_idx} OR s.url ILIKE ${param_idx})")
|
||||
params.append(f"%{search}%")
|
||||
param_idx += 1
|
||||
|
||||
where_clause = " AND ".join(conditions) if conditions else "TRUE"
|
||||
|
||||
# Count total
|
||||
count_query = f"""
|
||||
SELECT COUNT(*) FROM edu_search_seeds s
|
||||
LEFT JOIN edu_search_categories c ON s.category_id = c.id
|
||||
WHERE {where_clause}
|
||||
"""
|
||||
total = await conn.fetchval(count_query, *params)
|
||||
|
||||
# Get paginated results
|
||||
offset = (page - 1) * page_size
|
||||
params.extend([page_size, offset])
|
||||
|
||||
query = f"""
|
||||
SELECT
|
||||
s.id, s.url, s.name, s.description,
|
||||
c.name as category, c.display_name as category_display_name,
|
||||
s.source_type, s.scope, s.state, s.trust_boost, s.enabled,
|
||||
s.crawl_depth, s.crawl_frequency, s.last_crawled_at,
|
||||
s.last_crawl_status, s.last_crawl_docs, s.total_documents,
|
||||
s.created_at, s.updated_at
|
||||
FROM edu_search_seeds s
|
||||
LEFT JOIN edu_search_categories c ON s.category_id = c.id
|
||||
WHERE {where_clause}
|
||||
ORDER BY c.sort_order, s.name
|
||||
LIMIT ${param_idx} OFFSET ${param_idx + 1}
|
||||
"""
|
||||
|
||||
rows = await conn.fetch(query, *params)
|
||||
|
||||
seeds = [
|
||||
SeedResponse(
|
||||
id=str(row["id"]),
|
||||
url=row["url"],
|
||||
name=row["name"],
|
||||
description=row["description"],
|
||||
category=row["category"],
|
||||
category_display_name=row["category_display_name"],
|
||||
source_type=row["source_type"],
|
||||
scope=row["scope"],
|
||||
state=row["state"],
|
||||
trust_boost=float(row["trust_boost"]),
|
||||
enabled=row["enabled"],
|
||||
crawl_depth=row["crawl_depth"],
|
||||
crawl_frequency=row["crawl_frequency"],
|
||||
last_crawled_at=row["last_crawled_at"],
|
||||
last_crawl_status=row["last_crawl_status"],
|
||||
last_crawl_docs=row["last_crawl_docs"] or 0,
|
||||
total_documents=row["total_documents"] or 0,
|
||||
created_at=row["created_at"],
|
||||
updated_at=row["updated_at"],
|
||||
)
|
||||
for row in rows
|
||||
]
|
||||
|
||||
return SeedsListResponse(
|
||||
seeds=seeds,
|
||||
total=total,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/seeds/{seed_id}", response_model=SeedResponse)
|
||||
async def get_seed(seed_id: str):
|
||||
"""Get a single seed by ID."""
|
||||
pool = await get_db_pool()
|
||||
async with pool.acquire() as conn:
|
||||
row = await conn.fetchrow("""
|
||||
SELECT
|
||||
s.id, s.url, s.name, s.description,
|
||||
c.name as category, c.display_name as category_display_name,
|
||||
s.source_type, s.scope, s.state, s.trust_boost, s.enabled,
|
||||
s.crawl_depth, s.crawl_frequency, s.last_crawled_at,
|
||||
s.last_crawl_status, s.last_crawl_docs, s.total_documents,
|
||||
s.created_at, s.updated_at
|
||||
FROM edu_search_seeds s
|
||||
LEFT JOIN edu_search_categories c ON s.category_id = c.id
|
||||
WHERE s.id = $1
|
||||
""", seed_id)
|
||||
|
||||
if not row:
|
||||
raise HTTPException(status_code=404, detail="Seed nicht gefunden")
|
||||
|
||||
return SeedResponse(
|
||||
id=str(row["id"]),
|
||||
url=row["url"],
|
||||
name=row["name"],
|
||||
description=row["description"],
|
||||
category=row["category"],
|
||||
category_display_name=row["category_display_name"],
|
||||
source_type=row["source_type"],
|
||||
scope=row["scope"],
|
||||
state=row["state"],
|
||||
trust_boost=float(row["trust_boost"]),
|
||||
enabled=row["enabled"],
|
||||
crawl_depth=row["crawl_depth"],
|
||||
crawl_frequency=row["crawl_frequency"],
|
||||
last_crawled_at=row["last_crawled_at"],
|
||||
last_crawl_status=row["last_crawl_status"],
|
||||
last_crawl_docs=row["last_crawl_docs"] or 0,
|
||||
total_documents=row["total_documents"] or 0,
|
||||
created_at=row["created_at"],
|
||||
updated_at=row["updated_at"],
|
||||
)
|
||||
|
||||
|
||||
@router.post("/seeds", response_model=SeedResponse, status_code=201)
|
||||
async def create_seed(seed: SeedCreate):
|
||||
"""Create a new seed URL."""
|
||||
pool = await get_db_pool()
|
||||
async with pool.acquire() as conn:
|
||||
# Get category ID if provided
|
||||
category_id = None
|
||||
if seed.category_name:
|
||||
category_id = await conn.fetchval(
|
||||
"SELECT id FROM edu_search_categories WHERE name = $1",
|
||||
seed.category_name
|
||||
)
|
||||
|
||||
try:
|
||||
row = await conn.fetchrow("""
|
||||
INSERT INTO edu_search_seeds (
|
||||
url, name, description, category_id, source_type, scope,
|
||||
state, trust_boost, enabled, crawl_depth, crawl_frequency
|
||||
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
|
||||
RETURNING id, created_at, updated_at
|
||||
""",
|
||||
seed.url, seed.name, seed.description, category_id,
|
||||
seed.source_type, seed.scope, seed.state, seed.trust_boost,
|
||||
seed.enabled, seed.crawl_depth, seed.crawl_frequency
|
||||
)
|
||||
except asyncpg.UniqueViolationError:
|
||||
raise HTTPException(status_code=409, detail="URL existiert bereits")
|
||||
|
||||
return SeedResponse(
|
||||
id=str(row["id"]),
|
||||
url=seed.url,
|
||||
name=seed.name,
|
||||
description=seed.description,
|
||||
category=seed.category_name,
|
||||
category_display_name=None,
|
||||
source_type=seed.source_type,
|
||||
scope=seed.scope,
|
||||
state=seed.state,
|
||||
trust_boost=seed.trust_boost,
|
||||
enabled=seed.enabled,
|
||||
crawl_depth=seed.crawl_depth,
|
||||
crawl_frequency=seed.crawl_frequency,
|
||||
last_crawled_at=None,
|
||||
last_crawl_status=None,
|
||||
last_crawl_docs=0,
|
||||
total_documents=0,
|
||||
created_at=row["created_at"],
|
||||
updated_at=row["updated_at"],
|
||||
)
|
||||
|
||||
|
||||
@router.put("/seeds/{seed_id}", response_model=SeedResponse)
|
||||
async def update_seed(seed_id: str, seed: SeedUpdate):
|
||||
"""Update an existing seed."""
|
||||
pool = await get_db_pool()
|
||||
async with pool.acquire() as conn:
|
||||
# Build update statement dynamically
|
||||
updates = []
|
||||
params = []
|
||||
param_idx = 1
|
||||
|
||||
if seed.url is not None:
|
||||
updates.append(f"url = ${param_idx}")
|
||||
params.append(seed.url)
|
||||
param_idx += 1
|
||||
|
||||
if seed.name is not None:
|
||||
updates.append(f"name = ${param_idx}")
|
||||
params.append(seed.name)
|
||||
param_idx += 1
|
||||
|
||||
if seed.description is not None:
|
||||
updates.append(f"description = ${param_idx}")
|
||||
params.append(seed.description)
|
||||
param_idx += 1
|
||||
|
||||
if seed.category_name is not None:
|
||||
category_id = await conn.fetchval(
|
||||
"SELECT id FROM edu_search_categories WHERE name = $1",
|
||||
seed.category_name
|
||||
)
|
||||
updates.append(f"category_id = ${param_idx}")
|
||||
params.append(category_id)
|
||||
param_idx += 1
|
||||
|
||||
if seed.source_type is not None:
|
||||
updates.append(f"source_type = ${param_idx}")
|
||||
params.append(seed.source_type)
|
||||
param_idx += 1
|
||||
|
||||
if seed.scope is not None:
|
||||
updates.append(f"scope = ${param_idx}")
|
||||
params.append(seed.scope)
|
||||
param_idx += 1
|
||||
|
||||
if seed.state is not None:
|
||||
updates.append(f"state = ${param_idx}")
|
||||
params.append(seed.state)
|
||||
param_idx += 1
|
||||
|
||||
if seed.trust_boost is not None:
|
||||
updates.append(f"trust_boost = ${param_idx}")
|
||||
params.append(seed.trust_boost)
|
||||
param_idx += 1
|
||||
|
||||
if seed.enabled is not None:
|
||||
updates.append(f"enabled = ${param_idx}")
|
||||
params.append(seed.enabled)
|
||||
param_idx += 1
|
||||
|
||||
if seed.crawl_depth is not None:
|
||||
updates.append(f"crawl_depth = ${param_idx}")
|
||||
params.append(seed.crawl_depth)
|
||||
param_idx += 1
|
||||
|
||||
if seed.crawl_frequency is not None:
|
||||
updates.append(f"crawl_frequency = ${param_idx}")
|
||||
params.append(seed.crawl_frequency)
|
||||
param_idx += 1
|
||||
|
||||
if not updates:
|
||||
raise HTTPException(status_code=400, detail="Keine Felder zum Aktualisieren")
|
||||
|
||||
updates.append("updated_at = NOW()")
|
||||
params.append(seed_id)
|
||||
|
||||
query = f"""
|
||||
UPDATE edu_search_seeds
|
||||
SET {", ".join(updates)}
|
||||
WHERE id = ${param_idx}
|
||||
RETURNING id
|
||||
"""
|
||||
|
||||
result = await conn.fetchrow(query, *params)
|
||||
if not result:
|
||||
raise HTTPException(status_code=404, detail="Seed nicht gefunden")
|
||||
|
||||
# Return updated seed
|
||||
return await get_seed(seed_id)
|
||||
|
||||
|
||||
@router.delete("/seeds/{seed_id}")
|
||||
async def delete_seed(seed_id: str):
|
||||
"""Delete a seed."""
|
||||
pool = await get_db_pool()
|
||||
async with pool.acquire() as conn:
|
||||
result = await conn.execute(
|
||||
"DELETE FROM edu_search_seeds WHERE id = $1",
|
||||
seed_id
|
||||
)
|
||||
if result == "DELETE 0":
|
||||
raise HTTPException(status_code=404, detail="Seed nicht gefunden")
|
||||
|
||||
return {"status": "deleted", "id": seed_id}
|
||||
|
||||
|
||||
@router.post("/seeds/bulk-import", response_model=BulkImportResponse)
|
||||
async def bulk_import_seeds(request: BulkImportRequest):
|
||||
"""Bulk import seeds (skip duplicates)."""
|
||||
pool = await get_db_pool()
|
||||
imported = 0
|
||||
skipped = 0
|
||||
errors = []
|
||||
|
||||
async with pool.acquire() as conn:
|
||||
# Pre-fetch all category IDs
|
||||
categories = {}
|
||||
rows = await conn.fetch("SELECT id, name FROM edu_search_categories")
|
||||
for row in rows:
|
||||
categories[row["name"]] = row["id"]
|
||||
|
||||
for seed in request.seeds:
|
||||
try:
|
||||
category_id = categories.get(seed.category_name) if seed.category_name else None
|
||||
|
||||
await conn.execute("""
|
||||
INSERT INTO edu_search_seeds (
|
||||
url, name, description, category_id, source_type, scope,
|
||||
state, trust_boost, enabled, crawl_depth, crawl_frequency
|
||||
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
|
||||
ON CONFLICT (url) DO NOTHING
|
||||
""",
|
||||
seed.url, seed.name, seed.description, category_id,
|
||||
seed.source_type, seed.scope, seed.state, seed.trust_boost,
|
||||
seed.enabled, seed.crawl_depth, seed.crawl_frequency
|
||||
)
|
||||
imported += 1
|
||||
except asyncpg.UniqueViolationError:
|
||||
skipped += 1
|
||||
except Exception as e:
|
||||
errors.append(f"{seed.url}: {str(e)}")
|
||||
|
||||
return BulkImportResponse(imported=imported, skipped=skipped, errors=errors)
|
||||
|
||||
|
||||
@router.get("/stats", response_model=StatsResponse)
|
||||
async def get_stats():
|
||||
"""Get crawl statistics."""
|
||||
pool = await get_db_pool()
|
||||
async with pool.acquire() as conn:
|
||||
# Basic counts
|
||||
total = await conn.fetchval("SELECT COUNT(*) FROM edu_search_seeds")
|
||||
enabled = await conn.fetchval("SELECT COUNT(*) FROM edu_search_seeds WHERE enabled = TRUE")
|
||||
total_docs = await conn.fetchval("SELECT COALESCE(SUM(total_documents), 0) FROM edu_search_seeds")
|
||||
|
||||
# By category
|
||||
cat_rows = await conn.fetch("""
|
||||
SELECT c.name, COUNT(s.id) as count
|
||||
FROM edu_search_categories c
|
||||
LEFT JOIN edu_search_seeds s ON c.id = s.category_id
|
||||
GROUP BY c.name
|
||||
""")
|
||||
by_category = {row["name"]: row["count"] for row in cat_rows}
|
||||
|
||||
# By state
|
||||
state_rows = await conn.fetch("""
|
||||
SELECT COALESCE(state, 'federal') as state, COUNT(*) as count
|
||||
FROM edu_search_seeds
|
||||
GROUP BY state
|
||||
""")
|
||||
by_state = {row["state"]: row["count"] for row in state_rows}
|
||||
|
||||
# Last crawl time
|
||||
last_crawl = await conn.fetchval(
|
||||
"SELECT MAX(last_crawled_at) FROM edu_search_seeds"
|
||||
)
|
||||
|
||||
return StatsResponse(
|
||||
total_seeds=total,
|
||||
enabled_seeds=enabled,
|
||||
total_documents=total_docs,
|
||||
seeds_by_category=by_category,
|
||||
seeds_by_state=by_state,
|
||||
last_crawl_time=last_crawl,
|
||||
)
|
||||
|
||||
|
||||
# Export for external use (edu-search-service)
|
||||
@router.get("/seeds/export/for-crawler")
|
||||
async def export_seeds_for_crawler():
|
||||
"""Export enabled seeds in format suitable for crawler."""
|
||||
pool = await get_db_pool()
|
||||
async with pool.acquire() as conn:
|
||||
rows = await conn.fetch("""
|
||||
SELECT
|
||||
s.url, s.trust_boost, s.source_type, s.scope, s.state,
|
||||
s.crawl_depth, c.name as category
|
||||
FROM edu_search_seeds s
|
||||
LEFT JOIN edu_search_categories c ON s.category_id = c.id
|
||||
WHERE s.enabled = TRUE
|
||||
ORDER BY s.trust_boost DESC
|
||||
""")
|
||||
|
||||
return {
|
||||
"seeds": [
|
||||
{
|
||||
"url": row["url"],
|
||||
"trust": float(row["trust_boost"]),
|
||||
"source": row["source_type"],
|
||||
"scope": row["scope"],
|
||||
"state": row["state"],
|
||||
"depth": row["crawl_depth"],
|
||||
"category": row["category"],
|
||||
}
|
||||
for row in rows
|
||||
],
|
||||
"total": len(rows),
|
||||
"exported_at": datetime.utcnow().isoformat(),
|
||||
}
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Crawl Status Feedback (from edu-search-service)
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class CrawlStatusUpdate(BaseModel):
|
||||
"""Crawl status update from edu-search-service."""
|
||||
seed_url: str = Field(..., description="The seed URL that was crawled")
|
||||
status: str = Field(..., description="Crawl status: success, error, partial")
|
||||
documents_crawled: int = Field(0, ge=0, description="Number of documents crawled")
|
||||
error_message: Optional[str] = Field(None, description="Error message if status is error")
|
||||
crawl_duration_seconds: float = Field(0.0, ge=0.0, description="Duration of the crawl in seconds")
|
||||
|
||||
|
||||
class CrawlStatusResponse(BaseModel):
|
||||
"""Response for crawl status update."""
|
||||
success: bool
|
||||
seed_url: str
|
||||
message: str
|
||||
|
||||
|
||||
@router.post("/seeds/crawl-status", response_model=CrawlStatusResponse)
|
||||
async def update_crawl_status(update: CrawlStatusUpdate):
|
||||
"""Update crawl status for a seed URL (called by edu-search-service)."""
|
||||
pool = await get_db_pool()
|
||||
async with pool.acquire() as conn:
|
||||
# Find the seed by URL
|
||||
seed = await conn.fetchrow(
|
||||
"SELECT id, total_documents FROM edu_search_seeds WHERE url = $1",
|
||||
update.seed_url
|
||||
)
|
||||
|
||||
if not seed:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Seed nicht gefunden: {update.seed_url}"
|
||||
)
|
||||
|
||||
# Update the seed with crawl status
|
||||
new_total = (seed["total_documents"] or 0) + update.documents_crawled
|
||||
|
||||
await conn.execute("""
|
||||
UPDATE edu_search_seeds
|
||||
SET
|
||||
last_crawled_at = NOW(),
|
||||
last_crawl_status = $2,
|
||||
last_crawl_docs = $3,
|
||||
total_documents = $4,
|
||||
updated_at = NOW()
|
||||
WHERE id = $1
|
||||
""", seed["id"], update.status, update.documents_crawled, new_total)
|
||||
|
||||
logger.info(
|
||||
f"Crawl status updated: {update.seed_url} - "
|
||||
f"status={update.status}, docs={update.documents_crawled}, "
|
||||
f"duration={update.crawl_duration_seconds:.1f}s"
|
||||
)
|
||||
|
||||
return CrawlStatusResponse(
|
||||
success=True,
|
||||
seed_url=update.seed_url,
|
||||
message=f"Status aktualisiert: {update.documents_crawled} Dokumente gecrawlt"
|
||||
)
|
||||
|
||||
|
||||
class BulkCrawlStatusUpdate(BaseModel):
|
||||
"""Bulk crawl status update."""
|
||||
updates: List[CrawlStatusUpdate]
|
||||
|
||||
|
||||
class BulkCrawlStatusResponse(BaseModel):
|
||||
"""Response for bulk crawl status update."""
|
||||
updated: int
|
||||
failed: int
|
||||
errors: List[str]
|
||||
|
||||
|
||||
@router.post("/seeds/crawl-status/bulk", response_model=BulkCrawlStatusResponse)
|
||||
async def bulk_update_crawl_status(request: BulkCrawlStatusUpdate):
|
||||
"""Bulk update crawl status for multiple seeds."""
|
||||
pool = await get_db_pool()
|
||||
updated = 0
|
||||
failed = 0
|
||||
errors = []
|
||||
|
||||
async with pool.acquire() as conn:
|
||||
for update in request.updates:
|
||||
try:
|
||||
seed = await conn.fetchrow(
|
||||
"SELECT id, total_documents FROM edu_search_seeds WHERE url = $1",
|
||||
update.seed_url
|
||||
)
|
||||
|
||||
if not seed:
|
||||
failed += 1
|
||||
errors.append(f"Seed nicht gefunden: {update.seed_url}")
|
||||
continue
|
||||
|
||||
new_total = (seed["total_documents"] or 0) + update.documents_crawled
|
||||
|
||||
await conn.execute("""
|
||||
UPDATE edu_search_seeds
|
||||
SET
|
||||
last_crawled_at = NOW(),
|
||||
last_crawl_status = $2,
|
||||
last_crawl_docs = $3,
|
||||
total_documents = $4,
|
||||
updated_at = NOW()
|
||||
WHERE id = $1
|
||||
""", seed["id"], update.status, update.documents_crawled, new_total)
|
||||
|
||||
updated += 1
|
||||
|
||||
except Exception as e:
|
||||
failed += 1
|
||||
errors.append(f"{update.seed_url}: {str(e)}")
|
||||
|
||||
logger.info(f"Bulk crawl status update: {updated} updated, {failed} failed")
|
||||
|
||||
return BulkCrawlStatusResponse(
|
||||
updated=updated,
|
||||
failed=failed,
|
||||
errors=errors
|
||||
)
|
||||
router.include_router(_crud_router)
|
||||
router.include_router(_status_router)
|
||||
|
||||
__all__ = [
|
||||
"router",
|
||||
"get_db_pool",
|
||||
# Models
|
||||
"CategoryResponse",
|
||||
"SeedBase",
|
||||
"SeedCreate",
|
||||
"SeedUpdate",
|
||||
"SeedResponse",
|
||||
"SeedsListResponse",
|
||||
"StatsResponse",
|
||||
"BulkImportRequest",
|
||||
"BulkImportResponse",
|
||||
"CrawlStatusUpdate",
|
||||
"CrawlStatusResponse",
|
||||
"BulkCrawlStatusUpdate",
|
||||
"BulkCrawlStatusResponse",
|
||||
]
|
||||
|
||||
@@ -0,0 +1,198 @@
|
||||
"""
|
||||
EduSearch Seeds Stats & Crawl Status Routes.
|
||||
|
||||
Statistics, export for crawler, and crawl status feedback endpoints.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import List
|
||||
from datetime import datetime
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
import asyncpg
|
||||
|
||||
from .edu_search_models import (
|
||||
StatsResponse,
|
||||
CrawlStatusUpdate,
|
||||
CrawlStatusResponse,
|
||||
BulkCrawlStatusUpdate,
|
||||
BulkCrawlStatusResponse,
|
||||
)
|
||||
from .edu_search_crud import get_db_pool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(tags=["edu-search"])
|
||||
|
||||
|
||||
@router.get("/stats", response_model=StatsResponse)
|
||||
async def get_stats():
|
||||
"""Get crawl statistics."""
|
||||
pool = await get_db_pool()
|
||||
async with pool.acquire() as conn:
|
||||
# Basic counts
|
||||
total = await conn.fetchval("SELECT COUNT(*) FROM edu_search_seeds")
|
||||
enabled = await conn.fetchval("SELECT COUNT(*) FROM edu_search_seeds WHERE enabled = TRUE")
|
||||
total_docs = await conn.fetchval("SELECT COALESCE(SUM(total_documents), 0) FROM edu_search_seeds")
|
||||
|
||||
# By category
|
||||
cat_rows = await conn.fetch("""
|
||||
SELECT c.name, COUNT(s.id) as count
|
||||
FROM edu_search_categories c
|
||||
LEFT JOIN edu_search_seeds s ON c.id = s.category_id
|
||||
GROUP BY c.name
|
||||
""")
|
||||
by_category = {row["name"]: row["count"] for row in cat_rows}
|
||||
|
||||
# By state
|
||||
state_rows = await conn.fetch("""
|
||||
SELECT COALESCE(state, 'federal') as state, COUNT(*) as count
|
||||
FROM edu_search_seeds
|
||||
GROUP BY state
|
||||
""")
|
||||
by_state = {row["state"]: row["count"] for row in state_rows}
|
||||
|
||||
# Last crawl time
|
||||
last_crawl = await conn.fetchval(
|
||||
"SELECT MAX(last_crawled_at) FROM edu_search_seeds"
|
||||
)
|
||||
|
||||
return StatsResponse(
|
||||
total_seeds=total,
|
||||
enabled_seeds=enabled,
|
||||
total_documents=total_docs,
|
||||
seeds_by_category=by_category,
|
||||
seeds_by_state=by_state,
|
||||
last_crawl_time=last_crawl,
|
||||
)
|
||||
|
||||
|
||||
# Export for external use (edu-search-service)
|
||||
@router.get("/seeds/export/for-crawler")
|
||||
async def export_seeds_for_crawler():
|
||||
"""Export enabled seeds in format suitable for crawler."""
|
||||
pool = await get_db_pool()
|
||||
async with pool.acquire() as conn:
|
||||
rows = await conn.fetch("""
|
||||
SELECT
|
||||
s.url, s.trust_boost, s.source_type, s.scope, s.state,
|
||||
s.crawl_depth, c.name as category
|
||||
FROM edu_search_seeds s
|
||||
LEFT JOIN edu_search_categories c ON s.category_id = c.id
|
||||
WHERE s.enabled = TRUE
|
||||
ORDER BY s.trust_boost DESC
|
||||
""")
|
||||
|
||||
return {
|
||||
"seeds": [
|
||||
{
|
||||
"url": row["url"],
|
||||
"trust": float(row["trust_boost"]),
|
||||
"source": row["source_type"],
|
||||
"scope": row["scope"],
|
||||
"state": row["state"],
|
||||
"depth": row["crawl_depth"],
|
||||
"category": row["category"],
|
||||
}
|
||||
for row in rows
|
||||
],
|
||||
"total": len(rows),
|
||||
"exported_at": datetime.utcnow().isoformat(),
|
||||
}
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Crawl Status Feedback (from edu-search-service)
|
||||
# =============================================================================
|
||||
|
||||
@router.post("/seeds/crawl-status", response_model=CrawlStatusResponse)
|
||||
async def update_crawl_status(update: CrawlStatusUpdate):
|
||||
"""Update crawl status for a seed URL (called by edu-search-service)."""
|
||||
pool = await get_db_pool()
|
||||
async with pool.acquire() as conn:
|
||||
# Find the seed by URL
|
||||
seed = await conn.fetchrow(
|
||||
"SELECT id, total_documents FROM edu_search_seeds WHERE url = $1",
|
||||
update.seed_url
|
||||
)
|
||||
|
||||
if not seed:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Seed nicht gefunden: {update.seed_url}"
|
||||
)
|
||||
|
||||
# Update the seed with crawl status
|
||||
new_total = (seed["total_documents"] or 0) + update.documents_crawled
|
||||
|
||||
await conn.execute("""
|
||||
UPDATE edu_search_seeds
|
||||
SET
|
||||
last_crawled_at = NOW(),
|
||||
last_crawl_status = $2,
|
||||
last_crawl_docs = $3,
|
||||
total_documents = $4,
|
||||
updated_at = NOW()
|
||||
WHERE id = $1
|
||||
""", seed["id"], update.status, update.documents_crawled, new_total)
|
||||
|
||||
logger.info(
|
||||
f"Crawl status updated: {update.seed_url} - "
|
||||
f"status={update.status}, docs={update.documents_crawled}, "
|
||||
f"duration={update.crawl_duration_seconds:.1f}s"
|
||||
)
|
||||
|
||||
return CrawlStatusResponse(
|
||||
success=True,
|
||||
seed_url=update.seed_url,
|
||||
message=f"Status aktualisiert: {update.documents_crawled} Dokumente gecrawlt"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/seeds/crawl-status/bulk", response_model=BulkCrawlStatusResponse)
|
||||
async def bulk_update_crawl_status(request: BulkCrawlStatusUpdate):
|
||||
"""Bulk update crawl status for multiple seeds."""
|
||||
pool = await get_db_pool()
|
||||
updated = 0
|
||||
failed = 0
|
||||
errors = []
|
||||
|
||||
async with pool.acquire() as conn:
|
||||
for update in request.updates:
|
||||
try:
|
||||
seed = await conn.fetchrow(
|
||||
"SELECT id, total_documents FROM edu_search_seeds WHERE url = $1",
|
||||
update.seed_url
|
||||
)
|
||||
|
||||
if not seed:
|
||||
failed += 1
|
||||
errors.append(f"Seed nicht gefunden: {update.seed_url}")
|
||||
continue
|
||||
|
||||
new_total = (seed["total_documents"] or 0) + update.documents_crawled
|
||||
|
||||
await conn.execute("""
|
||||
UPDATE edu_search_seeds
|
||||
SET
|
||||
last_crawled_at = NOW(),
|
||||
last_crawl_status = $2,
|
||||
last_crawl_docs = $3,
|
||||
total_documents = $4,
|
||||
updated_at = NOW()
|
||||
WHERE id = $1
|
||||
""", seed["id"], update.status, update.documents_crawled, new_total)
|
||||
|
||||
updated += 1
|
||||
|
||||
except Exception as e:
|
||||
failed += 1
|
||||
errors.append(f"{update.seed_url}: {str(e)}")
|
||||
|
||||
logger.info(f"Bulk crawl status update: {updated} updated, {failed} failed")
|
||||
|
||||
return BulkCrawlStatusResponse(
|
||||
updated=updated,
|
||||
failed=failed,
|
||||
errors=errors
|
||||
)
|
||||
@@ -1,867 +1,38 @@
|
||||
"""
|
||||
Schools API Routes.
|
||||
Schools API Routes — Barrel Re-export.
|
||||
|
||||
CRUD operations for managing German schools (~40,000 schools).
|
||||
Direct database access to PostgreSQL.
|
||||
Split into:
|
||||
- schools_models.py: Pydantic models
|
||||
- schools_db.py: Database connection pool
|
||||
- schools_crud.py: School CRUD & stats routes
|
||||
- schools_staff.py: Staff CRUD & search routes
|
||||
"""
|
||||
|
||||
import os
|
||||
import logging
|
||||
from typing import Optional, List
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
from fastapi import APIRouter
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Query
|
||||
from pydantic import BaseModel, Field
|
||||
import asyncpg
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
from .schools_crud import router as _crud_router
|
||||
from .schools_staff import router as _staff_router
|
||||
|
||||
# Single router that merges both sub-module routers
|
||||
router = APIRouter(prefix="/schools", tags=["schools"])
|
||||
|
||||
# Database connection pool
|
||||
_pool: Optional[asyncpg.Pool] = None
|
||||
|
||||
|
||||
async def get_db_pool() -> asyncpg.Pool:
|
||||
"""Get or create database connection pool."""
|
||||
global _pool
|
||||
if _pool is None:
|
||||
database_url = os.environ.get(
|
||||
"DATABASE_URL",
|
||||
"postgresql://breakpilot:breakpilot123@postgres:5432/breakpilot_db"
|
||||
)
|
||||
_pool = await asyncpg.create_pool(database_url, min_size=2, max_size=10)
|
||||
return _pool
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Pydantic Models
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class SchoolTypeResponse(BaseModel):
|
||||
"""School type response model."""
|
||||
id: str
|
||||
name: str
|
||||
name_short: Optional[str] = None
|
||||
category: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
|
||||
|
||||
class SchoolBase(BaseModel):
|
||||
"""Base school model for creation/update."""
|
||||
name: str = Field(..., max_length=255)
|
||||
school_number: Optional[str] = Field(None, max_length=20)
|
||||
school_type_id: Optional[str] = None
|
||||
school_type_raw: Optional[str] = None
|
||||
state: str = Field(..., max_length=10)
|
||||
district: Optional[str] = None
|
||||
city: Optional[str] = None
|
||||
postal_code: Optional[str] = None
|
||||
street: Optional[str] = None
|
||||
address_full: Optional[str] = None
|
||||
latitude: Optional[float] = None
|
||||
longitude: Optional[float] = None
|
||||
website: Optional[str] = None
|
||||
email: Optional[str] = None
|
||||
phone: Optional[str] = None
|
||||
fax: Optional[str] = None
|
||||
principal_name: Optional[str] = None
|
||||
principal_title: Optional[str] = None
|
||||
principal_email: Optional[str] = None
|
||||
principal_phone: Optional[str] = None
|
||||
secretary_name: Optional[str] = None
|
||||
secretary_email: Optional[str] = None
|
||||
secretary_phone: Optional[str] = None
|
||||
student_count: Optional[int] = None
|
||||
teacher_count: Optional[int] = None
|
||||
class_count: Optional[int] = None
|
||||
founded_year: Optional[int] = None
|
||||
is_public: bool = True
|
||||
is_all_day: Optional[bool] = None
|
||||
has_inclusion: Optional[bool] = None
|
||||
languages: Optional[List[str]] = None
|
||||
specializations: Optional[List[str]] = None
|
||||
source: Optional[str] = None
|
||||
source_url: Optional[str] = None
|
||||
|
||||
|
||||
class SchoolCreate(SchoolBase):
|
||||
"""School creation model."""
|
||||
pass
|
||||
|
||||
|
||||
class SchoolUpdate(BaseModel):
|
||||
"""School update model (all fields optional)."""
|
||||
name: Optional[str] = Field(None, max_length=255)
|
||||
school_number: Optional[str] = None
|
||||
school_type_id: Optional[str] = None
|
||||
state: Optional[str] = None
|
||||
district: Optional[str] = None
|
||||
city: Optional[str] = None
|
||||
postal_code: Optional[str] = None
|
||||
street: Optional[str] = None
|
||||
website: Optional[str] = None
|
||||
email: Optional[str] = None
|
||||
phone: Optional[str] = None
|
||||
principal_name: Optional[str] = None
|
||||
student_count: Optional[int] = None
|
||||
teacher_count: Optional[int] = None
|
||||
is_active: Optional[bool] = None
|
||||
|
||||
|
||||
class SchoolResponse(BaseModel):
|
||||
"""School response model."""
|
||||
id: str
|
||||
name: str
|
||||
school_number: Optional[str] = None
|
||||
school_type: Optional[str] = None
|
||||
school_type_short: Optional[str] = None
|
||||
school_category: Optional[str] = None
|
||||
state: str
|
||||
district: Optional[str] = None
|
||||
city: Optional[str] = None
|
||||
postal_code: Optional[str] = None
|
||||
street: Optional[str] = None
|
||||
address_full: Optional[str] = None
|
||||
latitude: Optional[float] = None
|
||||
longitude: Optional[float] = None
|
||||
website: Optional[str] = None
|
||||
email: Optional[str] = None
|
||||
phone: Optional[str] = None
|
||||
fax: Optional[str] = None
|
||||
principal_name: Optional[str] = None
|
||||
principal_email: Optional[str] = None
|
||||
student_count: Optional[int] = None
|
||||
teacher_count: Optional[int] = None
|
||||
is_public: bool = True
|
||||
is_all_day: Optional[bool] = None
|
||||
staff_count: int = 0
|
||||
source: Optional[str] = None
|
||||
crawled_at: Optional[datetime] = None
|
||||
is_active: bool = True
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
|
||||
class SchoolsListResponse(BaseModel):
|
||||
"""List response with pagination info."""
|
||||
schools: List[SchoolResponse]
|
||||
total: int
|
||||
page: int
|
||||
page_size: int
|
||||
|
||||
|
||||
class SchoolStaffBase(BaseModel):
|
||||
"""Base school staff model."""
|
||||
first_name: Optional[str] = None
|
||||
last_name: str
|
||||
full_name: Optional[str] = None
|
||||
title: Optional[str] = None
|
||||
position: Optional[str] = None
|
||||
position_type: Optional[str] = None
|
||||
subjects: Optional[List[str]] = None
|
||||
email: Optional[str] = None
|
||||
phone: Optional[str] = None
|
||||
|
||||
|
||||
class SchoolStaffCreate(SchoolStaffBase):
|
||||
"""School staff creation model."""
|
||||
school_id: str
|
||||
|
||||
|
||||
class SchoolStaffResponse(SchoolStaffBase):
|
||||
"""School staff response model."""
|
||||
id: str
|
||||
school_id: str
|
||||
school_name: Optional[str] = None
|
||||
profile_url: Optional[str] = None
|
||||
photo_url: Optional[str] = None
|
||||
is_active: bool = True
|
||||
created_at: datetime
|
||||
|
||||
|
||||
class SchoolStaffListResponse(BaseModel):
|
||||
"""Staff list response."""
|
||||
staff: List[SchoolStaffResponse]
|
||||
total: int
|
||||
|
||||
|
||||
class SchoolStatsResponse(BaseModel):
|
||||
"""School statistics response."""
|
||||
total_schools: int
|
||||
total_staff: int
|
||||
schools_by_state: dict
|
||||
schools_by_type: dict
|
||||
schools_with_website: int
|
||||
schools_with_email: int
|
||||
schools_with_principal: int
|
||||
total_students: int
|
||||
total_teachers: int
|
||||
last_crawl_time: Optional[datetime] = None
|
||||
|
||||
|
||||
class BulkImportRequest(BaseModel):
|
||||
"""Bulk import request."""
|
||||
schools: List[SchoolCreate]
|
||||
|
||||
|
||||
class BulkImportResponse(BaseModel):
|
||||
"""Bulk import response."""
|
||||
imported: int
|
||||
updated: int
|
||||
skipped: int
|
||||
errors: List[str]
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# School Type Endpoints
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@router.get("/types", response_model=List[SchoolTypeResponse])
|
||||
async def list_school_types():
|
||||
"""List all school types."""
|
||||
pool = await get_db_pool()
|
||||
async with pool.acquire() as conn:
|
||||
rows = await conn.fetch("""
|
||||
SELECT id, name, name_short, category, description
|
||||
FROM school_types
|
||||
ORDER BY category, name
|
||||
""")
|
||||
return [
|
||||
SchoolTypeResponse(
|
||||
id=str(row["id"]),
|
||||
name=row["name"],
|
||||
name_short=row["name_short"],
|
||||
category=row["category"],
|
||||
description=row["description"],
|
||||
)
|
||||
for row in rows
|
||||
]
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# School Endpoints
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@router.get("", response_model=SchoolsListResponse)
|
||||
async def list_schools(
|
||||
state: Optional[str] = Query(None, description="Filter by state code (BW, BY, etc.)"),
|
||||
school_type: Optional[str] = Query(None, description="Filter by school type name"),
|
||||
city: Optional[str] = Query(None, description="Filter by city"),
|
||||
district: Optional[str] = Query(None, description="Filter by district"),
|
||||
postal_code: Optional[str] = Query(None, description="Filter by postal code prefix"),
|
||||
search: Optional[str] = Query(None, description="Search in name, city"),
|
||||
has_email: Optional[bool] = Query(None, description="Filter schools with email"),
|
||||
has_website: Optional[bool] = Query(None, description="Filter schools with website"),
|
||||
is_public: Optional[bool] = Query(None, description="Filter public/private schools"),
|
||||
page: int = Query(1, ge=1),
|
||||
page_size: int = Query(50, ge=1, le=200),
|
||||
):
|
||||
"""List schools with optional filtering and pagination."""
|
||||
pool = await get_db_pool()
|
||||
async with pool.acquire() as conn:
|
||||
# Build WHERE clause
|
||||
conditions = ["s.is_active = TRUE"]
|
||||
params = []
|
||||
param_idx = 1
|
||||
|
||||
if state:
|
||||
conditions.append(f"s.state = ${param_idx}")
|
||||
params.append(state.upper())
|
||||
param_idx += 1
|
||||
|
||||
if school_type:
|
||||
conditions.append(f"st.name = ${param_idx}")
|
||||
params.append(school_type)
|
||||
param_idx += 1
|
||||
|
||||
if city:
|
||||
conditions.append(f"LOWER(s.city) = LOWER(${param_idx})")
|
||||
params.append(city)
|
||||
param_idx += 1
|
||||
|
||||
if district:
|
||||
conditions.append(f"LOWER(s.district) LIKE LOWER(${param_idx})")
|
||||
params.append(f"%{district}%")
|
||||
param_idx += 1
|
||||
|
||||
if postal_code:
|
||||
conditions.append(f"s.postal_code LIKE ${param_idx}")
|
||||
params.append(f"{postal_code}%")
|
||||
param_idx += 1
|
||||
|
||||
if search:
|
||||
conditions.append(f"""
|
||||
(LOWER(s.name) LIKE LOWER(${param_idx})
|
||||
OR LOWER(s.city) LIKE LOWER(${param_idx})
|
||||
OR LOWER(s.district) LIKE LOWER(${param_idx}))
|
||||
""")
|
||||
params.append(f"%{search}%")
|
||||
param_idx += 1
|
||||
|
||||
if has_email is not None:
|
||||
if has_email:
|
||||
conditions.append("s.email IS NOT NULL")
|
||||
else:
|
||||
conditions.append("s.email IS NULL")
|
||||
|
||||
if has_website is not None:
|
||||
if has_website:
|
||||
conditions.append("s.website IS NOT NULL")
|
||||
else:
|
||||
conditions.append("s.website IS NULL")
|
||||
|
||||
if is_public is not None:
|
||||
conditions.append(f"s.is_public = ${param_idx}")
|
||||
params.append(is_public)
|
||||
param_idx += 1
|
||||
|
||||
where_clause = " AND ".join(conditions)
|
||||
|
||||
# Count total
|
||||
count_query = f"""
|
||||
SELECT COUNT(*) FROM schools s
|
||||
LEFT JOIN school_types st ON s.school_type_id = st.id
|
||||
WHERE {where_clause}
|
||||
"""
|
||||
total = await conn.fetchval(count_query, *params)
|
||||
|
||||
# Fetch schools
|
||||
offset = (page - 1) * page_size
|
||||
query = f"""
|
||||
SELECT
|
||||
s.id, s.name, s.school_number, s.state, s.district, s.city,
|
||||
s.postal_code, s.street, s.address_full, s.latitude, s.longitude,
|
||||
s.website, s.email, s.phone, s.fax,
|
||||
s.principal_name, s.principal_email,
|
||||
s.student_count, s.teacher_count,
|
||||
s.is_public, s.is_all_day, s.source, s.crawled_at,
|
||||
s.is_active, s.created_at, s.updated_at,
|
||||
st.name as school_type, st.name_short as school_type_short, st.category as school_category,
|
||||
(SELECT COUNT(*) FROM school_staff ss WHERE ss.school_id = s.id AND ss.is_active = TRUE) as staff_count
|
||||
FROM schools s
|
||||
LEFT JOIN school_types st ON s.school_type_id = st.id
|
||||
WHERE {where_clause}
|
||||
ORDER BY s.state, s.city, s.name
|
||||
LIMIT ${param_idx} OFFSET ${param_idx + 1}
|
||||
"""
|
||||
params.extend([page_size, offset])
|
||||
rows = await conn.fetch(query, *params)
|
||||
|
||||
schools = [
|
||||
SchoolResponse(
|
||||
id=str(row["id"]),
|
||||
name=row["name"],
|
||||
school_number=row["school_number"],
|
||||
school_type=row["school_type"],
|
||||
school_type_short=row["school_type_short"],
|
||||
school_category=row["school_category"],
|
||||
state=row["state"],
|
||||
district=row["district"],
|
||||
city=row["city"],
|
||||
postal_code=row["postal_code"],
|
||||
street=row["street"],
|
||||
address_full=row["address_full"],
|
||||
latitude=row["latitude"],
|
||||
longitude=row["longitude"],
|
||||
website=row["website"],
|
||||
email=row["email"],
|
||||
phone=row["phone"],
|
||||
fax=row["fax"],
|
||||
principal_name=row["principal_name"],
|
||||
principal_email=row["principal_email"],
|
||||
student_count=row["student_count"],
|
||||
teacher_count=row["teacher_count"],
|
||||
is_public=row["is_public"],
|
||||
is_all_day=row["is_all_day"],
|
||||
staff_count=row["staff_count"],
|
||||
source=row["source"],
|
||||
crawled_at=row["crawled_at"],
|
||||
is_active=row["is_active"],
|
||||
created_at=row["created_at"],
|
||||
updated_at=row["updated_at"],
|
||||
)
|
||||
for row in rows
|
||||
]
|
||||
|
||||
return SchoolsListResponse(
|
||||
schools=schools,
|
||||
total=total,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/stats", response_model=SchoolStatsResponse)
|
||||
async def get_school_stats():
|
||||
"""Get school statistics."""
|
||||
pool = await get_db_pool()
|
||||
async with pool.acquire() as conn:
|
||||
# Total schools and staff
|
||||
totals = await conn.fetchrow("""
|
||||
SELECT
|
||||
(SELECT COUNT(*) FROM schools WHERE is_active = TRUE) as total_schools,
|
||||
(SELECT COUNT(*) FROM school_staff WHERE is_active = TRUE) as total_staff,
|
||||
(SELECT COUNT(*) FROM schools WHERE is_active = TRUE AND website IS NOT NULL) as with_website,
|
||||
(SELECT COUNT(*) FROM schools WHERE is_active = TRUE AND email IS NOT NULL) as with_email,
|
||||
(SELECT COUNT(*) FROM schools WHERE is_active = TRUE AND principal_name IS NOT NULL) as with_principal,
|
||||
(SELECT COALESCE(SUM(student_count), 0) FROM schools WHERE is_active = TRUE) as total_students,
|
||||
(SELECT COALESCE(SUM(teacher_count), 0) FROM schools WHERE is_active = TRUE) as total_teachers,
|
||||
(SELECT MAX(crawled_at) FROM schools) as last_crawl
|
||||
""")
|
||||
|
||||
# By state
|
||||
state_rows = await conn.fetch("""
|
||||
SELECT state, COUNT(*) as count
|
||||
FROM schools
|
||||
WHERE is_active = TRUE
|
||||
GROUP BY state
|
||||
ORDER BY state
|
||||
""")
|
||||
schools_by_state = {row["state"]: row["count"] for row in state_rows}
|
||||
|
||||
# By type
|
||||
type_rows = await conn.fetch("""
|
||||
SELECT COALESCE(st.name, 'Unbekannt') as type_name, COUNT(*) as count
|
||||
FROM schools s
|
||||
LEFT JOIN school_types st ON s.school_type_id = st.id
|
||||
WHERE s.is_active = TRUE
|
||||
GROUP BY st.name
|
||||
ORDER BY count DESC
|
||||
""")
|
||||
schools_by_type = {row["type_name"]: row["count"] for row in type_rows}
|
||||
|
||||
return SchoolStatsResponse(
|
||||
total_schools=totals["total_schools"],
|
||||
total_staff=totals["total_staff"],
|
||||
schools_by_state=schools_by_state,
|
||||
schools_by_type=schools_by_type,
|
||||
schools_with_website=totals["with_website"],
|
||||
schools_with_email=totals["with_email"],
|
||||
schools_with_principal=totals["with_principal"],
|
||||
total_students=totals["total_students"],
|
||||
total_teachers=totals["total_teachers"],
|
||||
last_crawl_time=totals["last_crawl"],
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{school_id}", response_model=SchoolResponse)
|
||||
async def get_school(school_id: str):
|
||||
"""Get a single school by ID."""
|
||||
pool = await get_db_pool()
|
||||
async with pool.acquire() as conn:
|
||||
row = await conn.fetchrow("""
|
||||
SELECT
|
||||
s.id, s.name, s.school_number, s.state, s.district, s.city,
|
||||
s.postal_code, s.street, s.address_full, s.latitude, s.longitude,
|
||||
s.website, s.email, s.phone, s.fax,
|
||||
s.principal_name, s.principal_email,
|
||||
s.student_count, s.teacher_count,
|
||||
s.is_public, s.is_all_day, s.source, s.crawled_at,
|
||||
s.is_active, s.created_at, s.updated_at,
|
||||
st.name as school_type, st.name_short as school_type_short, st.category as school_category,
|
||||
(SELECT COUNT(*) FROM school_staff ss WHERE ss.school_id = s.id AND ss.is_active = TRUE) as staff_count
|
||||
FROM schools s
|
||||
LEFT JOIN school_types st ON s.school_type_id = st.id
|
||||
WHERE s.id = $1
|
||||
""", school_id)
|
||||
|
||||
if not row:
|
||||
raise HTTPException(status_code=404, detail="School not found")
|
||||
|
||||
return SchoolResponse(
|
||||
id=str(row["id"]),
|
||||
name=row["name"],
|
||||
school_number=row["school_number"],
|
||||
school_type=row["school_type"],
|
||||
school_type_short=row["school_type_short"],
|
||||
school_category=row["school_category"],
|
||||
state=row["state"],
|
||||
district=row["district"],
|
||||
city=row["city"],
|
||||
postal_code=row["postal_code"],
|
||||
street=row["street"],
|
||||
address_full=row["address_full"],
|
||||
latitude=row["latitude"],
|
||||
longitude=row["longitude"],
|
||||
website=row["website"],
|
||||
email=row["email"],
|
||||
phone=row["phone"],
|
||||
fax=row["fax"],
|
||||
principal_name=row["principal_name"],
|
||||
principal_email=row["principal_email"],
|
||||
student_count=row["student_count"],
|
||||
teacher_count=row["teacher_count"],
|
||||
is_public=row["is_public"],
|
||||
is_all_day=row["is_all_day"],
|
||||
staff_count=row["staff_count"],
|
||||
source=row["source"],
|
||||
crawled_at=row["crawled_at"],
|
||||
is_active=row["is_active"],
|
||||
created_at=row["created_at"],
|
||||
updated_at=row["updated_at"],
|
||||
)
|
||||
|
||||
|
||||
@router.post("/bulk-import", response_model=BulkImportResponse)
|
||||
async def bulk_import_schools(request: BulkImportRequest):
|
||||
"""Bulk import schools. Updates existing schools based on school_number + state."""
|
||||
pool = await get_db_pool()
|
||||
imported = 0
|
||||
updated = 0
|
||||
skipped = 0
|
||||
errors = []
|
||||
|
||||
async with pool.acquire() as conn:
|
||||
# Get school type mapping
|
||||
type_rows = await conn.fetch("SELECT id, name FROM school_types")
|
||||
type_map = {row["name"].lower(): str(row["id"]) for row in type_rows}
|
||||
|
||||
for school in request.schools:
|
||||
try:
|
||||
# Find school type ID
|
||||
school_type_id = None
|
||||
if school.school_type_raw:
|
||||
school_type_id = type_map.get(school.school_type_raw.lower())
|
||||
|
||||
# Check if school exists (by school_number + state, or by name + city + state)
|
||||
existing = None
|
||||
if school.school_number:
|
||||
existing = await conn.fetchrow(
|
||||
"SELECT id FROM schools WHERE school_number = $1 AND state = $2",
|
||||
school.school_number, school.state
|
||||
)
|
||||
if not existing and school.city:
|
||||
existing = await conn.fetchrow(
|
||||
"SELECT id FROM schools WHERE LOWER(name) = LOWER($1) AND LOWER(city) = LOWER($2) AND state = $3",
|
||||
school.name, school.city, school.state
|
||||
)
|
||||
|
||||
if existing:
|
||||
# Update existing school
|
||||
await conn.execute("""
|
||||
UPDATE schools SET
|
||||
name = $2,
|
||||
school_type_id = COALESCE($3, school_type_id),
|
||||
school_type_raw = COALESCE($4, school_type_raw),
|
||||
district = COALESCE($5, district),
|
||||
city = COALESCE($6, city),
|
||||
postal_code = COALESCE($7, postal_code),
|
||||
street = COALESCE($8, street),
|
||||
address_full = COALESCE($9, address_full),
|
||||
latitude = COALESCE($10, latitude),
|
||||
longitude = COALESCE($11, longitude),
|
||||
website = COALESCE($12, website),
|
||||
email = COALESCE($13, email),
|
||||
phone = COALESCE($14, phone),
|
||||
fax = COALESCE($15, fax),
|
||||
principal_name = COALESCE($16, principal_name),
|
||||
principal_title = COALESCE($17, principal_title),
|
||||
principal_email = COALESCE($18, principal_email),
|
||||
principal_phone = COALESCE($19, principal_phone),
|
||||
student_count = COALESCE($20, student_count),
|
||||
teacher_count = COALESCE($21, teacher_count),
|
||||
is_public = $22,
|
||||
source = COALESCE($23, source),
|
||||
source_url = COALESCE($24, source_url),
|
||||
updated_at = NOW()
|
||||
WHERE id = $1
|
||||
""",
|
||||
existing["id"],
|
||||
school.name,
|
||||
school_type_id,
|
||||
school.school_type_raw,
|
||||
school.district,
|
||||
school.city,
|
||||
school.postal_code,
|
||||
school.street,
|
||||
school.address_full,
|
||||
school.latitude,
|
||||
school.longitude,
|
||||
school.website,
|
||||
school.email,
|
||||
school.phone,
|
||||
school.fax,
|
||||
school.principal_name,
|
||||
school.principal_title,
|
||||
school.principal_email,
|
||||
school.principal_phone,
|
||||
school.student_count,
|
||||
school.teacher_count,
|
||||
school.is_public,
|
||||
school.source,
|
||||
school.source_url,
|
||||
)
|
||||
updated += 1
|
||||
else:
|
||||
# Insert new school
|
||||
await conn.execute("""
|
||||
INSERT INTO schools (
|
||||
name, school_number, school_type_id, school_type_raw,
|
||||
state, district, city, postal_code, street, address_full,
|
||||
latitude, longitude, website, email, phone, fax,
|
||||
principal_name, principal_title, principal_email, principal_phone,
|
||||
student_count, teacher_count, is_public,
|
||||
source, source_url, crawled_at
|
||||
) VALUES (
|
||||
$1, $2, $3, $4, $5, $6, $7, $8, $9, $10,
|
||||
$11, $12, $13, $14, $15, $16, $17, $18, $19, $20,
|
||||
$21, $22, $23, $24, $25, NOW()
|
||||
)
|
||||
""",
|
||||
school.name,
|
||||
school.school_number,
|
||||
school_type_id,
|
||||
school.school_type_raw,
|
||||
school.state,
|
||||
school.district,
|
||||
school.city,
|
||||
school.postal_code,
|
||||
school.street,
|
||||
school.address_full,
|
||||
school.latitude,
|
||||
school.longitude,
|
||||
school.website,
|
||||
school.email,
|
||||
school.phone,
|
||||
school.fax,
|
||||
school.principal_name,
|
||||
school.principal_title,
|
||||
school.principal_email,
|
||||
school.principal_phone,
|
||||
school.student_count,
|
||||
school.teacher_count,
|
||||
school.is_public,
|
||||
school.source,
|
||||
school.source_url,
|
||||
)
|
||||
imported += 1
|
||||
|
||||
except Exception as e:
|
||||
errors.append(f"Error importing {school.name}: {str(e)}")
|
||||
if len(errors) > 100:
|
||||
errors.append("... (more errors truncated)")
|
||||
break
|
||||
|
||||
return BulkImportResponse(
|
||||
imported=imported,
|
||||
updated=updated,
|
||||
skipped=skipped,
|
||||
errors=errors[:100],
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# School Staff Endpoints
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@router.get("/{school_id}/staff", response_model=SchoolStaffListResponse)
|
||||
async def get_school_staff(school_id: str):
|
||||
"""Get staff members for a school."""
|
||||
pool = await get_db_pool()
|
||||
async with pool.acquire() as conn:
|
||||
rows = await conn.fetch("""
|
||||
SELECT
|
||||
ss.id, ss.school_id, ss.first_name, ss.last_name, ss.full_name,
|
||||
ss.title, ss.position, ss.position_type, ss.subjects,
|
||||
ss.email, ss.phone, ss.profile_url, ss.photo_url,
|
||||
ss.is_active, ss.created_at,
|
||||
s.name as school_name
|
||||
FROM school_staff ss
|
||||
JOIN schools s ON ss.school_id = s.id
|
||||
WHERE ss.school_id = $1 AND ss.is_active = TRUE
|
||||
ORDER BY
|
||||
CASE ss.position_type
|
||||
WHEN 'principal' THEN 1
|
||||
WHEN 'vice_principal' THEN 2
|
||||
WHEN 'secretary' THEN 3
|
||||
ELSE 4
|
||||
END,
|
||||
ss.last_name
|
||||
""", school_id)
|
||||
|
||||
staff = [
|
||||
SchoolStaffResponse(
|
||||
id=str(row["id"]),
|
||||
school_id=str(row["school_id"]),
|
||||
school_name=row["school_name"],
|
||||
first_name=row["first_name"],
|
||||
last_name=row["last_name"],
|
||||
full_name=row["full_name"],
|
||||
title=row["title"],
|
||||
position=row["position"],
|
||||
position_type=row["position_type"],
|
||||
subjects=row["subjects"],
|
||||
email=row["email"],
|
||||
phone=row["phone"],
|
||||
profile_url=row["profile_url"],
|
||||
photo_url=row["photo_url"],
|
||||
is_active=row["is_active"],
|
||||
created_at=row["created_at"],
|
||||
)
|
||||
for row in rows
|
||||
]
|
||||
|
||||
return SchoolStaffListResponse(
|
||||
staff=staff,
|
||||
total=len(staff),
|
||||
)
|
||||
|
||||
|
||||
@router.post("/{school_id}/staff", response_model=SchoolStaffResponse)
|
||||
async def create_school_staff(school_id: str, staff: SchoolStaffBase):
|
||||
"""Add a staff member to a school."""
|
||||
pool = await get_db_pool()
|
||||
async with pool.acquire() as conn:
|
||||
# Verify school exists
|
||||
school = await conn.fetchrow("SELECT name FROM schools WHERE id = $1", school_id)
|
||||
if not school:
|
||||
raise HTTPException(status_code=404, detail="School not found")
|
||||
|
||||
# Create full name
|
||||
full_name = staff.full_name
|
||||
if not full_name:
|
||||
parts = []
|
||||
if staff.title:
|
||||
parts.append(staff.title)
|
||||
if staff.first_name:
|
||||
parts.append(staff.first_name)
|
||||
parts.append(staff.last_name)
|
||||
full_name = " ".join(parts)
|
||||
|
||||
row = await conn.fetchrow("""
|
||||
INSERT INTO school_staff (
|
||||
school_id, first_name, last_name, full_name, title,
|
||||
position, position_type, subjects, email, phone
|
||||
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
|
||||
RETURNING id, created_at
|
||||
""",
|
||||
school_id,
|
||||
staff.first_name,
|
||||
staff.last_name,
|
||||
full_name,
|
||||
staff.title,
|
||||
staff.position,
|
||||
staff.position_type,
|
||||
staff.subjects,
|
||||
staff.email,
|
||||
staff.phone,
|
||||
)
|
||||
|
||||
return SchoolStaffResponse(
|
||||
id=str(row["id"]),
|
||||
school_id=school_id,
|
||||
school_name=school["name"],
|
||||
first_name=staff.first_name,
|
||||
last_name=staff.last_name,
|
||||
full_name=full_name,
|
||||
title=staff.title,
|
||||
position=staff.position,
|
||||
position_type=staff.position_type,
|
||||
subjects=staff.subjects,
|
||||
email=staff.email,
|
||||
phone=staff.phone,
|
||||
is_active=True,
|
||||
created_at=row["created_at"],
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Search Endpoints
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@router.get("/search/staff", response_model=SchoolStaffListResponse)
|
||||
async def search_school_staff(
|
||||
q: Optional[str] = Query(None, description="Search query"),
|
||||
state: Optional[str] = Query(None, description="Filter by state"),
|
||||
position_type: Optional[str] = Query(None, description="Filter by position type"),
|
||||
has_email: Optional[bool] = Query(None, description="Only staff with email"),
|
||||
page: int = Query(1, ge=1),
|
||||
page_size: int = Query(50, ge=1, le=200),
|
||||
):
|
||||
"""Search school staff across all schools."""
|
||||
pool = await get_db_pool()
|
||||
async with pool.acquire() as conn:
|
||||
conditions = ["ss.is_active = TRUE", "s.is_active = TRUE"]
|
||||
params = []
|
||||
param_idx = 1
|
||||
|
||||
if q:
|
||||
conditions.append(f"""
|
||||
(LOWER(ss.full_name) LIKE LOWER(${param_idx})
|
||||
OR LOWER(ss.last_name) LIKE LOWER(${param_idx})
|
||||
OR LOWER(s.name) LIKE LOWER(${param_idx}))
|
||||
""")
|
||||
params.append(f"%{q}%")
|
||||
param_idx += 1
|
||||
|
||||
if state:
|
||||
conditions.append(f"s.state = ${param_idx}")
|
||||
params.append(state.upper())
|
||||
param_idx += 1
|
||||
|
||||
if position_type:
|
||||
conditions.append(f"ss.position_type = ${param_idx}")
|
||||
params.append(position_type)
|
||||
param_idx += 1
|
||||
|
||||
if has_email is not None and has_email:
|
||||
conditions.append("ss.email IS NOT NULL")
|
||||
|
||||
where_clause = " AND ".join(conditions)
|
||||
|
||||
# Count total
|
||||
total = await conn.fetchval(f"""
|
||||
SELECT COUNT(*) FROM school_staff ss
|
||||
JOIN schools s ON ss.school_id = s.id
|
||||
WHERE {where_clause}
|
||||
""", *params)
|
||||
|
||||
# Fetch staff
|
||||
offset = (page - 1) * page_size
|
||||
rows = await conn.fetch(f"""
|
||||
SELECT
|
||||
ss.id, ss.school_id, ss.first_name, ss.last_name, ss.full_name,
|
||||
ss.title, ss.position, ss.position_type, ss.subjects,
|
||||
ss.email, ss.phone, ss.profile_url, ss.photo_url,
|
||||
ss.is_active, ss.created_at,
|
||||
s.name as school_name
|
||||
FROM school_staff ss
|
||||
JOIN schools s ON ss.school_id = s.id
|
||||
WHERE {where_clause}
|
||||
ORDER BY ss.last_name, ss.first_name
|
||||
LIMIT ${param_idx} OFFSET ${param_idx + 1}
|
||||
""", *params, page_size, offset)
|
||||
|
||||
staff = [
|
||||
SchoolStaffResponse(
|
||||
id=str(row["id"]),
|
||||
school_id=str(row["school_id"]),
|
||||
school_name=row["school_name"],
|
||||
first_name=row["first_name"],
|
||||
last_name=row["last_name"],
|
||||
full_name=row["full_name"],
|
||||
title=row["title"],
|
||||
position=row["position"],
|
||||
position_type=row["position_type"],
|
||||
subjects=row["subjects"],
|
||||
email=row["email"],
|
||||
phone=row["phone"],
|
||||
profile_url=row["profile_url"],
|
||||
photo_url=row["photo_url"],
|
||||
is_active=row["is_active"],
|
||||
created_at=row["created_at"],
|
||||
)
|
||||
for row in rows
|
||||
]
|
||||
|
||||
return SchoolStaffListResponse(
|
||||
staff=staff,
|
||||
total=total,
|
||||
)
|
||||
router.include_router(_crud_router)
|
||||
router.include_router(_staff_router)
|
||||
|
||||
# Re-export models for any external consumers
|
||||
from .schools_models import ( # noqa: E402, F401
|
||||
SchoolTypeResponse,
|
||||
SchoolBase,
|
||||
SchoolCreate,
|
||||
SchoolUpdate,
|
||||
SchoolResponse,
|
||||
SchoolsListResponse,
|
||||
SchoolStaffBase,
|
||||
SchoolStaffCreate,
|
||||
SchoolStaffResponse,
|
||||
SchoolStaffListResponse,
|
||||
SchoolStatsResponse,
|
||||
BulkImportRequest,
|
||||
BulkImportResponse,
|
||||
)
|
||||
from .schools_db import get_db_pool # noqa: E402, F401
|
||||
|
||||
@@ -0,0 +1,464 @@
|
||||
"""
|
||||
Schools API - School CRUD & Stats Routes.
|
||||
|
||||
List, get, stats, and bulk-import endpoints for schools.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Query
|
||||
|
||||
from .schools_db import get_db_pool
|
||||
from .schools_models import (
|
||||
SchoolResponse,
|
||||
SchoolsListResponse,
|
||||
SchoolStatsResponse,
|
||||
SchoolTypeResponse,
|
||||
BulkImportRequest,
|
||||
BulkImportResponse,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(tags=["schools"])
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# School Type Endpoints
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@router.get("/types", response_model=list[SchoolTypeResponse])
|
||||
async def list_school_types():
|
||||
"""List all school types."""
|
||||
pool = await get_db_pool()
|
||||
async with pool.acquire() as conn:
|
||||
rows = await conn.fetch("""
|
||||
SELECT id, name, name_short, category, description
|
||||
FROM school_types
|
||||
ORDER BY category, name
|
||||
""")
|
||||
return [
|
||||
SchoolTypeResponse(
|
||||
id=str(row["id"]),
|
||||
name=row["name"],
|
||||
name_short=row["name_short"],
|
||||
category=row["category"],
|
||||
description=row["description"],
|
||||
)
|
||||
for row in rows
|
||||
]
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# School Endpoints
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@router.get("", response_model=SchoolsListResponse)
|
||||
async def list_schools(
|
||||
state: Optional[str] = Query(None, description="Filter by state code (BW, BY, etc.)"),
|
||||
school_type: Optional[str] = Query(None, description="Filter by school type name"),
|
||||
city: Optional[str] = Query(None, description="Filter by city"),
|
||||
district: Optional[str] = Query(None, description="Filter by district"),
|
||||
postal_code: Optional[str] = Query(None, description="Filter by postal code prefix"),
|
||||
search: Optional[str] = Query(None, description="Search in name, city"),
|
||||
has_email: Optional[bool] = Query(None, description="Filter schools with email"),
|
||||
has_website: Optional[bool] = Query(None, description="Filter schools with website"),
|
||||
is_public: Optional[bool] = Query(None, description="Filter public/private schools"),
|
||||
page: int = Query(1, ge=1),
|
||||
page_size: int = Query(50, ge=1, le=200),
|
||||
):
|
||||
"""List schools with optional filtering and pagination."""
|
||||
pool = await get_db_pool()
|
||||
async with pool.acquire() as conn:
|
||||
# Build WHERE clause
|
||||
conditions = ["s.is_active = TRUE"]
|
||||
params = []
|
||||
param_idx = 1
|
||||
|
||||
if state:
|
||||
conditions.append(f"s.state = ${param_idx}")
|
||||
params.append(state.upper())
|
||||
param_idx += 1
|
||||
|
||||
if school_type:
|
||||
conditions.append(f"st.name = ${param_idx}")
|
||||
params.append(school_type)
|
||||
param_idx += 1
|
||||
|
||||
if city:
|
||||
conditions.append(f"LOWER(s.city) = LOWER(${param_idx})")
|
||||
params.append(city)
|
||||
param_idx += 1
|
||||
|
||||
if district:
|
||||
conditions.append(f"LOWER(s.district) LIKE LOWER(${param_idx})")
|
||||
params.append(f"%{district}%")
|
||||
param_idx += 1
|
||||
|
||||
if postal_code:
|
||||
conditions.append(f"s.postal_code LIKE ${param_idx}")
|
||||
params.append(f"{postal_code}%")
|
||||
param_idx += 1
|
||||
|
||||
if search:
|
||||
conditions.append(f"""
|
||||
(LOWER(s.name) LIKE LOWER(${param_idx})
|
||||
OR LOWER(s.city) LIKE LOWER(${param_idx})
|
||||
OR LOWER(s.district) LIKE LOWER(${param_idx}))
|
||||
""")
|
||||
params.append(f"%{search}%")
|
||||
param_idx += 1
|
||||
|
||||
if has_email is not None:
|
||||
if has_email:
|
||||
conditions.append("s.email IS NOT NULL")
|
||||
else:
|
||||
conditions.append("s.email IS NULL")
|
||||
|
||||
if has_website is not None:
|
||||
if has_website:
|
||||
conditions.append("s.website IS NOT NULL")
|
||||
else:
|
||||
conditions.append("s.website IS NULL")
|
||||
|
||||
if is_public is not None:
|
||||
conditions.append(f"s.is_public = ${param_idx}")
|
||||
params.append(is_public)
|
||||
param_idx += 1
|
||||
|
||||
where_clause = " AND ".join(conditions)
|
||||
|
||||
# Count total
|
||||
count_query = f"""
|
||||
SELECT COUNT(*) FROM schools s
|
||||
LEFT JOIN school_types st ON s.school_type_id = st.id
|
||||
WHERE {where_clause}
|
||||
"""
|
||||
total = await conn.fetchval(count_query, *params)
|
||||
|
||||
# Fetch schools
|
||||
offset = (page - 1) * page_size
|
||||
query = f"""
|
||||
SELECT
|
||||
s.id, s.name, s.school_number, s.state, s.district, s.city,
|
||||
s.postal_code, s.street, s.address_full, s.latitude, s.longitude,
|
||||
s.website, s.email, s.phone, s.fax,
|
||||
s.principal_name, s.principal_email,
|
||||
s.student_count, s.teacher_count,
|
||||
s.is_public, s.is_all_day, s.source, s.crawled_at,
|
||||
s.is_active, s.created_at, s.updated_at,
|
||||
st.name as school_type, st.name_short as school_type_short, st.category as school_category,
|
||||
(SELECT COUNT(*) FROM school_staff ss WHERE ss.school_id = s.id AND ss.is_active = TRUE) as staff_count
|
||||
FROM schools s
|
||||
LEFT JOIN school_types st ON s.school_type_id = st.id
|
||||
WHERE {where_clause}
|
||||
ORDER BY s.state, s.city, s.name
|
||||
LIMIT ${param_idx} OFFSET ${param_idx + 1}
|
||||
"""
|
||||
params.extend([page_size, offset])
|
||||
rows = await conn.fetch(query, *params)
|
||||
|
||||
schools = [
|
||||
SchoolResponse(
|
||||
id=str(row["id"]),
|
||||
name=row["name"],
|
||||
school_number=row["school_number"],
|
||||
school_type=row["school_type"],
|
||||
school_type_short=row["school_type_short"],
|
||||
school_category=row["school_category"],
|
||||
state=row["state"],
|
||||
district=row["district"],
|
||||
city=row["city"],
|
||||
postal_code=row["postal_code"],
|
||||
street=row["street"],
|
||||
address_full=row["address_full"],
|
||||
latitude=row["latitude"],
|
||||
longitude=row["longitude"],
|
||||
website=row["website"],
|
||||
email=row["email"],
|
||||
phone=row["phone"],
|
||||
fax=row["fax"],
|
||||
principal_name=row["principal_name"],
|
||||
principal_email=row["principal_email"],
|
||||
student_count=row["student_count"],
|
||||
teacher_count=row["teacher_count"],
|
||||
is_public=row["is_public"],
|
||||
is_all_day=row["is_all_day"],
|
||||
staff_count=row["staff_count"],
|
||||
source=row["source"],
|
||||
crawled_at=row["crawled_at"],
|
||||
is_active=row["is_active"],
|
||||
created_at=row["created_at"],
|
||||
updated_at=row["updated_at"],
|
||||
)
|
||||
for row in rows
|
||||
]
|
||||
|
||||
return SchoolsListResponse(
|
||||
schools=schools,
|
||||
total=total,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/stats", response_model=SchoolStatsResponse)
|
||||
async def get_school_stats():
|
||||
"""Get school statistics."""
|
||||
pool = await get_db_pool()
|
||||
async with pool.acquire() as conn:
|
||||
# Total schools and staff
|
||||
totals = await conn.fetchrow("""
|
||||
SELECT
|
||||
(SELECT COUNT(*) FROM schools WHERE is_active = TRUE) as total_schools,
|
||||
(SELECT COUNT(*) FROM school_staff WHERE is_active = TRUE) as total_staff,
|
||||
(SELECT COUNT(*) FROM schools WHERE is_active = TRUE AND website IS NOT NULL) as with_website,
|
||||
(SELECT COUNT(*) FROM schools WHERE is_active = TRUE AND email IS NOT NULL) as with_email,
|
||||
(SELECT COUNT(*) FROM schools WHERE is_active = TRUE AND principal_name IS NOT NULL) as with_principal,
|
||||
(SELECT COALESCE(SUM(student_count), 0) FROM schools WHERE is_active = TRUE) as total_students,
|
||||
(SELECT COALESCE(SUM(teacher_count), 0) FROM schools WHERE is_active = TRUE) as total_teachers,
|
||||
(SELECT MAX(crawled_at) FROM schools) as last_crawl
|
||||
""")
|
||||
|
||||
# By state
|
||||
state_rows = await conn.fetch("""
|
||||
SELECT state, COUNT(*) as count
|
||||
FROM schools
|
||||
WHERE is_active = TRUE
|
||||
GROUP BY state
|
||||
ORDER BY state
|
||||
""")
|
||||
schools_by_state = {row["state"]: row["count"] for row in state_rows}
|
||||
|
||||
# By type
|
||||
type_rows = await conn.fetch("""
|
||||
SELECT COALESCE(st.name, 'Unbekannt') as type_name, COUNT(*) as count
|
||||
FROM schools s
|
||||
LEFT JOIN school_types st ON s.school_type_id = st.id
|
||||
WHERE s.is_active = TRUE
|
||||
GROUP BY st.name
|
||||
ORDER BY count DESC
|
||||
""")
|
||||
schools_by_type = {row["type_name"]: row["count"] for row in type_rows}
|
||||
|
||||
return SchoolStatsResponse(
|
||||
total_schools=totals["total_schools"],
|
||||
total_staff=totals["total_staff"],
|
||||
schools_by_state=schools_by_state,
|
||||
schools_by_type=schools_by_type,
|
||||
schools_with_website=totals["with_website"],
|
||||
schools_with_email=totals["with_email"],
|
||||
schools_with_principal=totals["with_principal"],
|
||||
total_students=totals["total_students"],
|
||||
total_teachers=totals["total_teachers"],
|
||||
last_crawl_time=totals["last_crawl"],
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{school_id}", response_model=SchoolResponse)
|
||||
async def get_school(school_id: str):
|
||||
"""Get a single school by ID."""
|
||||
pool = await get_db_pool()
|
||||
async with pool.acquire() as conn:
|
||||
row = await conn.fetchrow("""
|
||||
SELECT
|
||||
s.id, s.name, s.school_number, s.state, s.district, s.city,
|
||||
s.postal_code, s.street, s.address_full, s.latitude, s.longitude,
|
||||
s.website, s.email, s.phone, s.fax,
|
||||
s.principal_name, s.principal_email,
|
||||
s.student_count, s.teacher_count,
|
||||
s.is_public, s.is_all_day, s.source, s.crawled_at,
|
||||
s.is_active, s.created_at, s.updated_at,
|
||||
st.name as school_type, st.name_short as school_type_short, st.category as school_category,
|
||||
(SELECT COUNT(*) FROM school_staff ss WHERE ss.school_id = s.id AND ss.is_active = TRUE) as staff_count
|
||||
FROM schools s
|
||||
LEFT JOIN school_types st ON s.school_type_id = st.id
|
||||
WHERE s.id = $1
|
||||
""", school_id)
|
||||
|
||||
if not row:
|
||||
raise HTTPException(status_code=404, detail="School not found")
|
||||
|
||||
return SchoolResponse(
|
||||
id=str(row["id"]),
|
||||
name=row["name"],
|
||||
school_number=row["school_number"],
|
||||
school_type=row["school_type"],
|
||||
school_type_short=row["school_type_short"],
|
||||
school_category=row["school_category"],
|
||||
state=row["state"],
|
||||
district=row["district"],
|
||||
city=row["city"],
|
||||
postal_code=row["postal_code"],
|
||||
street=row["street"],
|
||||
address_full=row["address_full"],
|
||||
latitude=row["latitude"],
|
||||
longitude=row["longitude"],
|
||||
website=row["website"],
|
||||
email=row["email"],
|
||||
phone=row["phone"],
|
||||
fax=row["fax"],
|
||||
principal_name=row["principal_name"],
|
||||
principal_email=row["principal_email"],
|
||||
student_count=row["student_count"],
|
||||
teacher_count=row["teacher_count"],
|
||||
is_public=row["is_public"],
|
||||
is_all_day=row["is_all_day"],
|
||||
staff_count=row["staff_count"],
|
||||
source=row["source"],
|
||||
crawled_at=row["crawled_at"],
|
||||
is_active=row["is_active"],
|
||||
created_at=row["created_at"],
|
||||
updated_at=row["updated_at"],
|
||||
)
|
||||
|
||||
|
||||
@router.post("/bulk-import", response_model=BulkImportResponse)
|
||||
async def bulk_import_schools(request: BulkImportRequest):
|
||||
"""Bulk import schools. Updates existing schools based on school_number + state."""
|
||||
pool = await get_db_pool()
|
||||
imported = 0
|
||||
updated = 0
|
||||
skipped = 0
|
||||
errors = []
|
||||
|
||||
async with pool.acquire() as conn:
|
||||
# Get school type mapping
|
||||
type_rows = await conn.fetch("SELECT id, name FROM school_types")
|
||||
type_map = {row["name"].lower(): str(row["id"]) for row in type_rows}
|
||||
|
||||
for school in request.schools:
|
||||
try:
|
||||
# Find school type ID
|
||||
school_type_id = None
|
||||
if school.school_type_raw:
|
||||
school_type_id = type_map.get(school.school_type_raw.lower())
|
||||
|
||||
# Check if school exists (by school_number + state, or by name + city + state)
|
||||
existing = None
|
||||
if school.school_number:
|
||||
existing = await conn.fetchrow(
|
||||
"SELECT id FROM schools WHERE school_number = $1 AND state = $2",
|
||||
school.school_number, school.state
|
||||
)
|
||||
if not existing and school.city:
|
||||
existing = await conn.fetchrow(
|
||||
"SELECT id FROM schools WHERE LOWER(name) = LOWER($1) AND LOWER(city) = LOWER($2) AND state = $3",
|
||||
school.name, school.city, school.state
|
||||
)
|
||||
|
||||
if existing:
|
||||
# Update existing school
|
||||
await conn.execute("""
|
||||
UPDATE schools SET
|
||||
name = $2,
|
||||
school_type_id = COALESCE($3, school_type_id),
|
||||
school_type_raw = COALESCE($4, school_type_raw),
|
||||
district = COALESCE($5, district),
|
||||
city = COALESCE($6, city),
|
||||
postal_code = COALESCE($7, postal_code),
|
||||
street = COALESCE($8, street),
|
||||
address_full = COALESCE($9, address_full),
|
||||
latitude = COALESCE($10, latitude),
|
||||
longitude = COALESCE($11, longitude),
|
||||
website = COALESCE($12, website),
|
||||
email = COALESCE($13, email),
|
||||
phone = COALESCE($14, phone),
|
||||
fax = COALESCE($15, fax),
|
||||
principal_name = COALESCE($16, principal_name),
|
||||
principal_title = COALESCE($17, principal_title),
|
||||
principal_email = COALESCE($18, principal_email),
|
||||
principal_phone = COALESCE($19, principal_phone),
|
||||
student_count = COALESCE($20, student_count),
|
||||
teacher_count = COALESCE($21, teacher_count),
|
||||
is_public = $22,
|
||||
source = COALESCE($23, source),
|
||||
source_url = COALESCE($24, source_url),
|
||||
updated_at = NOW()
|
||||
WHERE id = $1
|
||||
""",
|
||||
existing["id"],
|
||||
school.name,
|
||||
school_type_id,
|
||||
school.school_type_raw,
|
||||
school.district,
|
||||
school.city,
|
||||
school.postal_code,
|
||||
school.street,
|
||||
school.address_full,
|
||||
school.latitude,
|
||||
school.longitude,
|
||||
school.website,
|
||||
school.email,
|
||||
school.phone,
|
||||
school.fax,
|
||||
school.principal_name,
|
||||
school.principal_title,
|
||||
school.principal_email,
|
||||
school.principal_phone,
|
||||
school.student_count,
|
||||
school.teacher_count,
|
||||
school.is_public,
|
||||
school.source,
|
||||
school.source_url,
|
||||
)
|
||||
updated += 1
|
||||
else:
|
||||
# Insert new school
|
||||
await conn.execute("""
|
||||
INSERT INTO schools (
|
||||
name, school_number, school_type_id, school_type_raw,
|
||||
state, district, city, postal_code, street, address_full,
|
||||
latitude, longitude, website, email, phone, fax,
|
||||
principal_name, principal_title, principal_email, principal_phone,
|
||||
student_count, teacher_count, is_public,
|
||||
source, source_url, crawled_at
|
||||
) VALUES (
|
||||
$1, $2, $3, $4, $5, $6, $7, $8, $9, $10,
|
||||
$11, $12, $13, $14, $15, $16, $17, $18, $19, $20,
|
||||
$21, $22, $23, $24, $25, NOW()
|
||||
)
|
||||
""",
|
||||
school.name,
|
||||
school.school_number,
|
||||
school_type_id,
|
||||
school.school_type_raw,
|
||||
school.state,
|
||||
school.district,
|
||||
school.city,
|
||||
school.postal_code,
|
||||
school.street,
|
||||
school.address_full,
|
||||
school.latitude,
|
||||
school.longitude,
|
||||
school.website,
|
||||
school.email,
|
||||
school.phone,
|
||||
school.fax,
|
||||
school.principal_name,
|
||||
school.principal_title,
|
||||
school.principal_email,
|
||||
school.principal_phone,
|
||||
school.student_count,
|
||||
school.teacher_count,
|
||||
school.is_public,
|
||||
school.source,
|
||||
school.source_url,
|
||||
)
|
||||
imported += 1
|
||||
|
||||
except Exception as e:
|
||||
errors.append(f"Error importing {school.name}: {str(e)}")
|
||||
if len(errors) > 100:
|
||||
errors.append("... (more errors truncated)")
|
||||
break
|
||||
|
||||
return BulkImportResponse(
|
||||
imported=imported,
|
||||
updated=updated,
|
||||
skipped=skipped,
|
||||
errors=errors[:100],
|
||||
)
|
||||
@@ -0,0 +1,25 @@
|
||||
"""
|
||||
Schools API - Database Connection.
|
||||
|
||||
Shared database pool for school endpoints.
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
import asyncpg
|
||||
|
||||
# Database connection pool
|
||||
_pool: Optional[asyncpg.Pool] = None
|
||||
|
||||
|
||||
async def get_db_pool() -> asyncpg.Pool:
|
||||
"""Get or create database connection pool."""
|
||||
global _pool
|
||||
if _pool is None:
|
||||
database_url = os.environ.get(
|
||||
"DATABASE_URL",
|
||||
"postgresql://breakpilot:breakpilot123@postgres:5432/breakpilot_db"
|
||||
)
|
||||
_pool = await asyncpg.create_pool(database_url, min_size=2, max_size=10)
|
||||
return _pool
|
||||
@@ -0,0 +1,200 @@
|
||||
"""
|
||||
Schools API - Pydantic Models.
|
||||
|
||||
Data models for school and school staff endpoints.
|
||||
"""
|
||||
|
||||
from typing import Optional, List
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# School Type Models
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class SchoolTypeResponse(BaseModel):
|
||||
"""School type response model."""
|
||||
id: str
|
||||
name: str
|
||||
name_short: Optional[str] = None
|
||||
category: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# School Models
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class SchoolBase(BaseModel):
|
||||
"""Base school model for creation/update."""
|
||||
name: str = Field(..., max_length=255)
|
||||
school_number: Optional[str] = Field(None, max_length=20)
|
||||
school_type_id: Optional[str] = None
|
||||
school_type_raw: Optional[str] = None
|
||||
state: str = Field(..., max_length=10)
|
||||
district: Optional[str] = None
|
||||
city: Optional[str] = None
|
||||
postal_code: Optional[str] = None
|
||||
street: Optional[str] = None
|
||||
address_full: Optional[str] = None
|
||||
latitude: Optional[float] = None
|
||||
longitude: Optional[float] = None
|
||||
website: Optional[str] = None
|
||||
email: Optional[str] = None
|
||||
phone: Optional[str] = None
|
||||
fax: Optional[str] = None
|
||||
principal_name: Optional[str] = None
|
||||
principal_title: Optional[str] = None
|
||||
principal_email: Optional[str] = None
|
||||
principal_phone: Optional[str] = None
|
||||
secretary_name: Optional[str] = None
|
||||
secretary_email: Optional[str] = None
|
||||
secretary_phone: Optional[str] = None
|
||||
student_count: Optional[int] = None
|
||||
teacher_count: Optional[int] = None
|
||||
class_count: Optional[int] = None
|
||||
founded_year: Optional[int] = None
|
||||
is_public: bool = True
|
||||
is_all_day: Optional[bool] = None
|
||||
has_inclusion: Optional[bool] = None
|
||||
languages: Optional[List[str]] = None
|
||||
specializations: Optional[List[str]] = None
|
||||
source: Optional[str] = None
|
||||
source_url: Optional[str] = None
|
||||
|
||||
|
||||
class SchoolCreate(SchoolBase):
|
||||
"""School creation model."""
|
||||
pass
|
||||
|
||||
|
||||
class SchoolUpdate(BaseModel):
|
||||
"""School update model (all fields optional)."""
|
||||
name: Optional[str] = Field(None, max_length=255)
|
||||
school_number: Optional[str] = None
|
||||
school_type_id: Optional[str] = None
|
||||
state: Optional[str] = None
|
||||
district: Optional[str] = None
|
||||
city: Optional[str] = None
|
||||
postal_code: Optional[str] = None
|
||||
street: Optional[str] = None
|
||||
website: Optional[str] = None
|
||||
email: Optional[str] = None
|
||||
phone: Optional[str] = None
|
||||
principal_name: Optional[str] = None
|
||||
student_count: Optional[int] = None
|
||||
teacher_count: Optional[int] = None
|
||||
is_active: Optional[bool] = None
|
||||
|
||||
|
||||
class SchoolResponse(BaseModel):
|
||||
"""School response model."""
|
||||
id: str
|
||||
name: str
|
||||
school_number: Optional[str] = None
|
||||
school_type: Optional[str] = None
|
||||
school_type_short: Optional[str] = None
|
||||
school_category: Optional[str] = None
|
||||
state: str
|
||||
district: Optional[str] = None
|
||||
city: Optional[str] = None
|
||||
postal_code: Optional[str] = None
|
||||
street: Optional[str] = None
|
||||
address_full: Optional[str] = None
|
||||
latitude: Optional[float] = None
|
||||
longitude: Optional[float] = None
|
||||
website: Optional[str] = None
|
||||
email: Optional[str] = None
|
||||
phone: Optional[str] = None
|
||||
fax: Optional[str] = None
|
||||
principal_name: Optional[str] = None
|
||||
principal_email: Optional[str] = None
|
||||
student_count: Optional[int] = None
|
||||
teacher_count: Optional[int] = None
|
||||
is_public: bool = True
|
||||
is_all_day: Optional[bool] = None
|
||||
staff_count: int = 0
|
||||
source: Optional[str] = None
|
||||
crawled_at: Optional[datetime] = None
|
||||
is_active: bool = True
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
|
||||
class SchoolsListResponse(BaseModel):
|
||||
"""List response with pagination info."""
|
||||
schools: List[SchoolResponse]
|
||||
total: int
|
||||
page: int
|
||||
page_size: int
|
||||
|
||||
|
||||
class SchoolStatsResponse(BaseModel):
|
||||
"""School statistics response."""
|
||||
total_schools: int
|
||||
total_staff: int
|
||||
schools_by_state: dict
|
||||
schools_by_type: dict
|
||||
schools_with_website: int
|
||||
schools_with_email: int
|
||||
schools_with_principal: int
|
||||
total_students: int
|
||||
total_teachers: int
|
||||
last_crawl_time: Optional[datetime] = None
|
||||
|
||||
|
||||
class BulkImportRequest(BaseModel):
|
||||
"""Bulk import request."""
|
||||
schools: List[SchoolCreate]
|
||||
|
||||
|
||||
class BulkImportResponse(BaseModel):
|
||||
"""Bulk import response."""
|
||||
imported: int
|
||||
updated: int
|
||||
skipped: int
|
||||
errors: List[str]
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# School Staff Models
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class SchoolStaffBase(BaseModel):
|
||||
"""Base school staff model."""
|
||||
first_name: Optional[str] = None
|
||||
last_name: str
|
||||
full_name: Optional[str] = None
|
||||
title: Optional[str] = None
|
||||
position: Optional[str] = None
|
||||
position_type: Optional[str] = None
|
||||
subjects: Optional[List[str]] = None
|
||||
email: Optional[str] = None
|
||||
phone: Optional[str] = None
|
||||
|
||||
|
||||
class SchoolStaffCreate(SchoolStaffBase):
|
||||
"""School staff creation model."""
|
||||
school_id: str
|
||||
|
||||
|
||||
class SchoolStaffResponse(SchoolStaffBase):
|
||||
"""School staff response model."""
|
||||
id: str
|
||||
school_id: str
|
||||
school_name: Optional[str] = None
|
||||
profile_url: Optional[str] = None
|
||||
photo_url: Optional[str] = None
|
||||
is_active: bool = True
|
||||
created_at: datetime
|
||||
|
||||
|
||||
class SchoolStaffListResponse(BaseModel):
|
||||
"""Staff list response."""
|
||||
staff: List[SchoolStaffResponse]
|
||||
total: int
|
||||
@@ -0,0 +1,233 @@
|
||||
"""
|
||||
Schools API - Staff Routes.
|
||||
|
||||
CRUD and search endpoints for school staff members.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Query
|
||||
|
||||
from .schools_db import get_db_pool
|
||||
from .schools_models import (
|
||||
SchoolStaffBase,
|
||||
SchoolStaffResponse,
|
||||
SchoolStaffListResponse,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(tags=["schools"])
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# School Staff Endpoints
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@router.get("/{school_id}/staff", response_model=SchoolStaffListResponse)
|
||||
async def get_school_staff(school_id: str):
|
||||
"""Get staff members for a school."""
|
||||
pool = await get_db_pool()
|
||||
async with pool.acquire() as conn:
|
||||
rows = await conn.fetch("""
|
||||
SELECT
|
||||
ss.id, ss.school_id, ss.first_name, ss.last_name, ss.full_name,
|
||||
ss.title, ss.position, ss.position_type, ss.subjects,
|
||||
ss.email, ss.phone, ss.profile_url, ss.photo_url,
|
||||
ss.is_active, ss.created_at,
|
||||
s.name as school_name
|
||||
FROM school_staff ss
|
||||
JOIN schools s ON ss.school_id = s.id
|
||||
WHERE ss.school_id = $1 AND ss.is_active = TRUE
|
||||
ORDER BY
|
||||
CASE ss.position_type
|
||||
WHEN 'principal' THEN 1
|
||||
WHEN 'vice_principal' THEN 2
|
||||
WHEN 'secretary' THEN 3
|
||||
ELSE 4
|
||||
END,
|
||||
ss.last_name
|
||||
""", school_id)
|
||||
|
||||
staff = [
|
||||
SchoolStaffResponse(
|
||||
id=str(row["id"]),
|
||||
school_id=str(row["school_id"]),
|
||||
school_name=row["school_name"],
|
||||
first_name=row["first_name"],
|
||||
last_name=row["last_name"],
|
||||
full_name=row["full_name"],
|
||||
title=row["title"],
|
||||
position=row["position"],
|
||||
position_type=row["position_type"],
|
||||
subjects=row["subjects"],
|
||||
email=row["email"],
|
||||
phone=row["phone"],
|
||||
profile_url=row["profile_url"],
|
||||
photo_url=row["photo_url"],
|
||||
is_active=row["is_active"],
|
||||
created_at=row["created_at"],
|
||||
)
|
||||
for row in rows
|
||||
]
|
||||
|
||||
return SchoolStaffListResponse(
|
||||
staff=staff,
|
||||
total=len(staff),
|
||||
)
|
||||
|
||||
|
||||
@router.post("/{school_id}/staff", response_model=SchoolStaffResponse)
|
||||
async def create_school_staff(school_id: str, staff: SchoolStaffBase):
|
||||
"""Add a staff member to a school."""
|
||||
pool = await get_db_pool()
|
||||
async with pool.acquire() as conn:
|
||||
# Verify school exists
|
||||
school = await conn.fetchrow("SELECT name FROM schools WHERE id = $1", school_id)
|
||||
if not school:
|
||||
raise HTTPException(status_code=404, detail="School not found")
|
||||
|
||||
# Create full name
|
||||
full_name = staff.full_name
|
||||
if not full_name:
|
||||
parts = []
|
||||
if staff.title:
|
||||
parts.append(staff.title)
|
||||
if staff.first_name:
|
||||
parts.append(staff.first_name)
|
||||
parts.append(staff.last_name)
|
||||
full_name = " ".join(parts)
|
||||
|
||||
row = await conn.fetchrow("""
|
||||
INSERT INTO school_staff (
|
||||
school_id, first_name, last_name, full_name, title,
|
||||
position, position_type, subjects, email, phone
|
||||
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
|
||||
RETURNING id, created_at
|
||||
""",
|
||||
school_id,
|
||||
staff.first_name,
|
||||
staff.last_name,
|
||||
full_name,
|
||||
staff.title,
|
||||
staff.position,
|
||||
staff.position_type,
|
||||
staff.subjects,
|
||||
staff.email,
|
||||
staff.phone,
|
||||
)
|
||||
|
||||
return SchoolStaffResponse(
|
||||
id=str(row["id"]),
|
||||
school_id=school_id,
|
||||
school_name=school["name"],
|
||||
first_name=staff.first_name,
|
||||
last_name=staff.last_name,
|
||||
full_name=full_name,
|
||||
title=staff.title,
|
||||
position=staff.position,
|
||||
position_type=staff.position_type,
|
||||
subjects=staff.subjects,
|
||||
email=staff.email,
|
||||
phone=staff.phone,
|
||||
is_active=True,
|
||||
created_at=row["created_at"],
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Search Endpoints
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@router.get("/search/staff", response_model=SchoolStaffListResponse)
|
||||
async def search_school_staff(
|
||||
q: Optional[str] = Query(None, description="Search query"),
|
||||
state: Optional[str] = Query(None, description="Filter by state"),
|
||||
position_type: Optional[str] = Query(None, description="Filter by position type"),
|
||||
has_email: Optional[bool] = Query(None, description="Only staff with email"),
|
||||
page: int = Query(1, ge=1),
|
||||
page_size: int = Query(50, ge=1, le=200),
|
||||
):
|
||||
"""Search school staff across all schools."""
|
||||
pool = await get_db_pool()
|
||||
async with pool.acquire() as conn:
|
||||
conditions = ["ss.is_active = TRUE", "s.is_active = TRUE"]
|
||||
params = []
|
||||
param_idx = 1
|
||||
|
||||
if q:
|
||||
conditions.append(f"""
|
||||
(LOWER(ss.full_name) LIKE LOWER(${param_idx})
|
||||
OR LOWER(ss.last_name) LIKE LOWER(${param_idx})
|
||||
OR LOWER(s.name) LIKE LOWER(${param_idx}))
|
||||
""")
|
||||
params.append(f"%{q}%")
|
||||
param_idx += 1
|
||||
|
||||
if state:
|
||||
conditions.append(f"s.state = ${param_idx}")
|
||||
params.append(state.upper())
|
||||
param_idx += 1
|
||||
|
||||
if position_type:
|
||||
conditions.append(f"ss.position_type = ${param_idx}")
|
||||
params.append(position_type)
|
||||
param_idx += 1
|
||||
|
||||
if has_email is not None and has_email:
|
||||
conditions.append("ss.email IS NOT NULL")
|
||||
|
||||
where_clause = " AND ".join(conditions)
|
||||
|
||||
# Count total
|
||||
total = await conn.fetchval(f"""
|
||||
SELECT COUNT(*) FROM school_staff ss
|
||||
JOIN schools s ON ss.school_id = s.id
|
||||
WHERE {where_clause}
|
||||
""", *params)
|
||||
|
||||
# Fetch staff
|
||||
offset = (page - 1) * page_size
|
||||
rows = await conn.fetch(f"""
|
||||
SELECT
|
||||
ss.id, ss.school_id, ss.first_name, ss.last_name, ss.full_name,
|
||||
ss.title, ss.position, ss.position_type, ss.subjects,
|
||||
ss.email, ss.phone, ss.profile_url, ss.photo_url,
|
||||
ss.is_active, ss.created_at,
|
||||
s.name as school_name
|
||||
FROM school_staff ss
|
||||
JOIN schools s ON ss.school_id = s.id
|
||||
WHERE {where_clause}
|
||||
ORDER BY ss.last_name, ss.first_name
|
||||
LIMIT ${param_idx} OFFSET ${param_idx + 1}
|
||||
""", *params, page_size, offset)
|
||||
|
||||
staff = [
|
||||
SchoolStaffResponse(
|
||||
id=str(row["id"]),
|
||||
school_id=str(row["school_id"]),
|
||||
school_name=row["school_name"],
|
||||
first_name=row["first_name"],
|
||||
last_name=row["last_name"],
|
||||
full_name=row["full_name"],
|
||||
title=row["title"],
|
||||
position=row["position"],
|
||||
position_type=row["position_type"],
|
||||
subjects=row["subjects"],
|
||||
email=row["email"],
|
||||
phone=row["phone"],
|
||||
profile_url=row["profile_url"],
|
||||
photo_url=row["photo_url"],
|
||||
is_active=row["is_active"],
|
||||
created_at=row["created_at"],
|
||||
)
|
||||
for row in rows
|
||||
]
|
||||
|
||||
return SchoolStaffListResponse(
|
||||
staff=staff,
|
||||
total=total,
|
||||
)
|
||||
+13
-832
@@ -1,840 +1,21 @@
|
||||
"""
|
||||
BreakPilot Messenger API
|
||||
BreakPilot Messenger API — Barrel Re-export.
|
||||
|
||||
Stellt Endpoints fuer:
|
||||
- Kontaktverwaltung (CRUD)
|
||||
- Konversationen
|
||||
- Nachrichten
|
||||
- CSV-Import fuer Kontakte
|
||||
- Gruppenmanagement
|
||||
Stellt Endpoints fuer Kontakte, Konversationen, Nachrichten,
|
||||
CSV-Import, Gruppenmanagement und Templates bereit.
|
||||
|
||||
DSGVO-konform: Alle Daten werden lokal gespeichert.
|
||||
Split into:
|
||||
- messenger_models.py: Pydantic models
|
||||
- messenger_helpers.py: JSON file storage & default templates
|
||||
- messenger_contacts.py: Contact CRUD & CSV import/export
|
||||
- messenger_conversations.py: Conversations, messages, groups, templates, stats
|
||||
"""
|
||||
|
||||
import os
|
||||
import csv
|
||||
import uuid
|
||||
import json
|
||||
from io import StringIO
|
||||
from datetime import datetime
|
||||
from typing import List, Optional, Dict, Any
|
||||
from pathlib import Path
|
||||
from fastapi import APIRouter
|
||||
|
||||
from fastapi import APIRouter, HTTPException, UploadFile, File, Query
|
||||
from pydantic import BaseModel, Field
|
||||
from messenger_contacts import router as _contacts_router
|
||||
from messenger_conversations import router as _conversations_router
|
||||
|
||||
router = APIRouter(prefix="/api/messenger", tags=["Messenger"])
|
||||
|
||||
# Datenspeicherung (JSON-basiert fuer einfache Persistenz)
|
||||
DATA_DIR = Path(__file__).parent / "data" / "messenger"
|
||||
DATA_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
CONTACTS_FILE = DATA_DIR / "contacts.json"
|
||||
CONVERSATIONS_FILE = DATA_DIR / "conversations.json"
|
||||
MESSAGES_FILE = DATA_DIR / "messages.json"
|
||||
GROUPS_FILE = DATA_DIR / "groups.json"
|
||||
|
||||
|
||||
# ==========================================
|
||||
# PYDANTIC MODELS
|
||||
# ==========================================
|
||||
|
||||
class ContactBase(BaseModel):
|
||||
"""Basis-Modell fuer Kontakte."""
|
||||
name: str = Field(..., min_length=1, max_length=200)
|
||||
email: Optional[str] = None
|
||||
phone: Optional[str] = None
|
||||
role: str = Field(default="parent", description="parent, teacher, staff, student")
|
||||
student_name: Optional[str] = Field(None, description="Name des zugehoerigen Schuelers")
|
||||
class_name: Optional[str] = Field(None, description="Klasse z.B. 10a")
|
||||
notes: Optional[str] = None
|
||||
tags: List[str] = Field(default_factory=list)
|
||||
matrix_id: Optional[str] = Field(None, description="Matrix-ID z.B. @user:matrix.org")
|
||||
preferred_channel: str = Field(default="email", description="email, matrix, pwa")
|
||||
|
||||
|
||||
class ContactCreate(ContactBase):
|
||||
"""Model fuer neuen Kontakt."""
|
||||
pass
|
||||
|
||||
|
||||
class Contact(ContactBase):
|
||||
"""Vollstaendiger Kontakt mit ID."""
|
||||
id: str
|
||||
created_at: str
|
||||
updated_at: str
|
||||
online: bool = False
|
||||
last_seen: Optional[str] = None
|
||||
|
||||
|
||||
class ContactUpdate(BaseModel):
|
||||
"""Update-Model fuer Kontakte."""
|
||||
name: Optional[str] = None
|
||||
email: Optional[str] = None
|
||||
phone: Optional[str] = None
|
||||
role: Optional[str] = None
|
||||
student_name: Optional[str] = None
|
||||
class_name: Optional[str] = None
|
||||
notes: Optional[str] = None
|
||||
tags: Optional[List[str]] = None
|
||||
matrix_id: Optional[str] = None
|
||||
preferred_channel: Optional[str] = None
|
||||
|
||||
|
||||
class GroupBase(BaseModel):
|
||||
"""Basis-Modell fuer Gruppen."""
|
||||
name: str = Field(..., min_length=1, max_length=100)
|
||||
description: Optional[str] = None
|
||||
group_type: str = Field(default="class", description="class, department, custom")
|
||||
|
||||
|
||||
class GroupCreate(GroupBase):
|
||||
"""Model fuer neue Gruppe."""
|
||||
member_ids: List[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
class Group(GroupBase):
|
||||
"""Vollstaendige Gruppe mit ID."""
|
||||
id: str
|
||||
member_ids: List[str] = []
|
||||
created_at: str
|
||||
updated_at: str
|
||||
|
||||
|
||||
class MessageBase(BaseModel):
|
||||
"""Basis-Modell fuer Nachrichten."""
|
||||
content: str = Field(..., min_length=1)
|
||||
content_type: str = Field(default="text", description="text, file, image")
|
||||
file_url: Optional[str] = None
|
||||
send_email: bool = Field(default=False, description="Nachricht auch per Email senden")
|
||||
|
||||
|
||||
class MessageCreate(MessageBase):
|
||||
"""Model fuer neue Nachricht."""
|
||||
conversation_id: str
|
||||
|
||||
|
||||
class Message(MessageBase):
|
||||
"""Vollstaendige Nachricht mit ID."""
|
||||
id: str
|
||||
conversation_id: str
|
||||
sender_id: str # "self" fuer eigene Nachrichten
|
||||
timestamp: str
|
||||
read: bool = False
|
||||
read_at: Optional[str] = None
|
||||
email_sent: bool = False
|
||||
email_sent_at: Optional[str] = None
|
||||
email_error: Optional[str] = None
|
||||
|
||||
|
||||
class ConversationBase(BaseModel):
|
||||
"""Basis-Modell fuer Konversationen."""
|
||||
name: Optional[str] = None
|
||||
is_group: bool = False
|
||||
|
||||
|
||||
class Conversation(ConversationBase):
|
||||
"""Vollstaendige Konversation mit ID."""
|
||||
id: str
|
||||
participant_ids: List[str] = []
|
||||
group_id: Optional[str] = None
|
||||
created_at: str
|
||||
updated_at: str
|
||||
last_message: Optional[str] = None
|
||||
last_message_time: Optional[str] = None
|
||||
unread_count: int = 0
|
||||
|
||||
|
||||
class CSVImportResult(BaseModel):
|
||||
"""Ergebnis eines CSV-Imports."""
|
||||
imported: int
|
||||
skipped: int
|
||||
errors: List[str]
|
||||
contacts: List[Contact]
|
||||
|
||||
|
||||
# ==========================================
|
||||
# DATA HELPERS
|
||||
# ==========================================
|
||||
|
||||
def load_json(filepath: Path) -> List[Dict]:
|
||||
"""Laedt JSON-Daten aus Datei."""
|
||||
if not filepath.exists():
|
||||
return []
|
||||
try:
|
||||
with open(filepath, "r", encoding="utf-8") as f:
|
||||
return json.load(f)
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
|
||||
def save_json(filepath: Path, data: List[Dict]):
|
||||
"""Speichert Daten in JSON-Datei."""
|
||||
with open(filepath, "w", encoding="utf-8") as f:
|
||||
json.dump(data, f, ensure_ascii=False, indent=2)
|
||||
|
||||
|
||||
def get_contacts() -> List[Dict]:
|
||||
return load_json(CONTACTS_FILE)
|
||||
|
||||
|
||||
def save_contacts(contacts: List[Dict]):
|
||||
save_json(CONTACTS_FILE, contacts)
|
||||
|
||||
|
||||
def get_conversations() -> List[Dict]:
|
||||
return load_json(CONVERSATIONS_FILE)
|
||||
|
||||
|
||||
def save_conversations(conversations: List[Dict]):
|
||||
save_json(CONVERSATIONS_FILE, conversations)
|
||||
|
||||
|
||||
def get_messages() -> List[Dict]:
|
||||
return load_json(MESSAGES_FILE)
|
||||
|
||||
|
||||
def save_messages(messages: List[Dict]):
|
||||
save_json(MESSAGES_FILE, messages)
|
||||
|
||||
|
||||
def get_groups() -> List[Dict]:
|
||||
return load_json(GROUPS_FILE)
|
||||
|
||||
|
||||
def save_groups(groups: List[Dict]):
|
||||
save_json(GROUPS_FILE, groups)
|
||||
|
||||
|
||||
# ==========================================
|
||||
# CONTACTS ENDPOINTS
|
||||
# ==========================================
|
||||
|
||||
@router.get("/contacts", response_model=List[Contact])
|
||||
async def list_contacts(
|
||||
role: Optional[str] = Query(None, description="Filter by role"),
|
||||
class_name: Optional[str] = Query(None, description="Filter by class"),
|
||||
search: Optional[str] = Query(None, description="Search in name/email")
|
||||
):
|
||||
"""Listet alle Kontakte auf."""
|
||||
contacts = get_contacts()
|
||||
|
||||
# Filter anwenden
|
||||
if role:
|
||||
contacts = [c for c in contacts if c.get("role") == role]
|
||||
if class_name:
|
||||
contacts = [c for c in contacts if c.get("class_name") == class_name]
|
||||
if search:
|
||||
search_lower = search.lower()
|
||||
contacts = [c for c in contacts if
|
||||
search_lower in c.get("name", "").lower() or
|
||||
search_lower in (c.get("email") or "").lower() or
|
||||
search_lower in (c.get("student_name") or "").lower()]
|
||||
|
||||
return contacts
|
||||
|
||||
|
||||
@router.post("/contacts", response_model=Contact)
|
||||
async def create_contact(contact: ContactCreate):
|
||||
"""Erstellt einen neuen Kontakt."""
|
||||
contacts = get_contacts()
|
||||
|
||||
# Pruefen ob Email bereits existiert
|
||||
if contact.email:
|
||||
existing = [c for c in contacts if c.get("email") == contact.email]
|
||||
if existing:
|
||||
raise HTTPException(status_code=400, detail="Kontakt mit dieser Email existiert bereits")
|
||||
|
||||
now = datetime.utcnow().isoformat()
|
||||
new_contact = {
|
||||
"id": str(uuid.uuid4()),
|
||||
"created_at": now,
|
||||
"updated_at": now,
|
||||
"online": False,
|
||||
"last_seen": None,
|
||||
**contact.dict()
|
||||
}
|
||||
|
||||
contacts.append(new_contact)
|
||||
save_contacts(contacts)
|
||||
|
||||
return new_contact
|
||||
|
||||
|
||||
@router.get("/contacts/{contact_id}", response_model=Contact)
|
||||
async def get_contact(contact_id: str):
|
||||
"""Ruft einen einzelnen Kontakt ab."""
|
||||
contacts = get_contacts()
|
||||
contact = next((c for c in contacts if c["id"] == contact_id), None)
|
||||
|
||||
if not contact:
|
||||
raise HTTPException(status_code=404, detail="Kontakt nicht gefunden")
|
||||
|
||||
return contact
|
||||
|
||||
|
||||
@router.put("/contacts/{contact_id}", response_model=Contact)
|
||||
async def update_contact(contact_id: str, update: ContactUpdate):
|
||||
"""Aktualisiert einen Kontakt."""
|
||||
contacts = get_contacts()
|
||||
contact_idx = next((i for i, c in enumerate(contacts) if c["id"] == contact_id), None)
|
||||
|
||||
if contact_idx is None:
|
||||
raise HTTPException(status_code=404, detail="Kontakt nicht gefunden")
|
||||
|
||||
update_data = update.dict(exclude_unset=True)
|
||||
contacts[contact_idx].update(update_data)
|
||||
contacts[contact_idx]["updated_at"] = datetime.utcnow().isoformat()
|
||||
|
||||
save_contacts(contacts)
|
||||
return contacts[contact_idx]
|
||||
|
||||
|
||||
@router.delete("/contacts/{contact_id}")
|
||||
async def delete_contact(contact_id: str):
|
||||
"""Loescht einen Kontakt."""
|
||||
contacts = get_contacts()
|
||||
contacts = [c for c in contacts if c["id"] != contact_id]
|
||||
save_contacts(contacts)
|
||||
|
||||
return {"status": "deleted", "id": contact_id}
|
||||
|
||||
|
||||
@router.post("/contacts/import", response_model=CSVImportResult)
|
||||
async def import_contacts_csv(file: UploadFile = File(...)):
|
||||
"""
|
||||
Importiert Kontakte aus einer CSV-Datei.
|
||||
|
||||
Erwartete Spalten:
|
||||
- name (required)
|
||||
- email
|
||||
- phone
|
||||
- role (parent/teacher/staff/student)
|
||||
- student_name
|
||||
- class_name
|
||||
- notes
|
||||
- tags (komma-separiert)
|
||||
"""
|
||||
if not file.filename.endswith('.csv'):
|
||||
raise HTTPException(status_code=400, detail="Nur CSV-Dateien werden unterstuetzt")
|
||||
|
||||
content = await file.read()
|
||||
try:
|
||||
text = content.decode('utf-8')
|
||||
except UnicodeDecodeError:
|
||||
text = content.decode('latin-1')
|
||||
|
||||
contacts = get_contacts()
|
||||
existing_emails = {c.get("email") for c in contacts if c.get("email")}
|
||||
|
||||
imported = []
|
||||
skipped = 0
|
||||
errors = []
|
||||
|
||||
reader = csv.DictReader(StringIO(text), delimiter=';') # Deutsche CSV meist mit Semikolon
|
||||
if not reader.fieldnames or 'name' not in [f.lower() for f in reader.fieldnames]:
|
||||
# Versuche mit Komma
|
||||
reader = csv.DictReader(StringIO(text), delimiter=',')
|
||||
|
||||
for row_num, row in enumerate(reader, start=2):
|
||||
try:
|
||||
# Normalisiere Spaltennamen
|
||||
row = {k.lower().strip(): v.strip() if v else "" for k, v in row.items()}
|
||||
|
||||
name = row.get('name') or row.get('kontakt') or row.get('elternname')
|
||||
if not name:
|
||||
errors.append(f"Zeile {row_num}: Name fehlt")
|
||||
skipped += 1
|
||||
continue
|
||||
|
||||
email = row.get('email') or row.get('e-mail') or row.get('mail')
|
||||
if email and email in existing_emails:
|
||||
errors.append(f"Zeile {row_num}: Email {email} existiert bereits")
|
||||
skipped += 1
|
||||
continue
|
||||
|
||||
now = datetime.utcnow().isoformat()
|
||||
tags_str = row.get('tags') or row.get('kategorien') or ""
|
||||
tags = [t.strip() for t in tags_str.split(',') if t.strip()]
|
||||
|
||||
# Matrix-ID und preferred_channel auslesen
|
||||
matrix_id = row.get('matrix_id') or row.get('matrix') or None
|
||||
preferred_channel = row.get('preferred_channel') or row.get('kanal') or "email"
|
||||
if preferred_channel not in ["email", "matrix", "pwa"]:
|
||||
preferred_channel = "email"
|
||||
|
||||
new_contact = {
|
||||
"id": str(uuid.uuid4()),
|
||||
"name": name,
|
||||
"email": email if email else None,
|
||||
"phone": row.get('phone') or row.get('telefon') or row.get('tel'),
|
||||
"role": row.get('role') or row.get('rolle') or "parent",
|
||||
"student_name": row.get('student_name') or row.get('schueler') or row.get('kind'),
|
||||
"class_name": row.get('class_name') or row.get('klasse'),
|
||||
"notes": row.get('notes') or row.get('notizen') or row.get('bemerkungen'),
|
||||
"tags": tags,
|
||||
"matrix_id": matrix_id if matrix_id else None,
|
||||
"preferred_channel": preferred_channel,
|
||||
"created_at": now,
|
||||
"updated_at": now,
|
||||
"online": False,
|
||||
"last_seen": None
|
||||
}
|
||||
|
||||
contacts.append(new_contact)
|
||||
imported.append(new_contact)
|
||||
if email:
|
||||
existing_emails.add(email)
|
||||
|
||||
except Exception as e:
|
||||
errors.append(f"Zeile {row_num}: {str(e)}")
|
||||
skipped += 1
|
||||
|
||||
save_contacts(contacts)
|
||||
|
||||
return CSVImportResult(
|
||||
imported=len(imported),
|
||||
skipped=skipped,
|
||||
errors=errors[:20], # Maximal 20 Fehler zurueckgeben
|
||||
contacts=imported
|
||||
)
|
||||
|
||||
|
||||
@router.get("/contacts/export/csv")
|
||||
async def export_contacts_csv():
|
||||
"""Exportiert alle Kontakte als CSV."""
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
contacts = get_contacts()
|
||||
|
||||
output = StringIO()
|
||||
fieldnames = ['name', 'email', 'phone', 'role', 'student_name', 'class_name', 'notes', 'tags', 'matrix_id', 'preferred_channel']
|
||||
writer = csv.DictWriter(output, fieldnames=fieldnames, delimiter=';')
|
||||
writer.writeheader()
|
||||
|
||||
for contact in contacts:
|
||||
writer.writerow({
|
||||
'name': contact.get('name', ''),
|
||||
'email': contact.get('email', ''),
|
||||
'phone': contact.get('phone', ''),
|
||||
'role': contact.get('role', ''),
|
||||
'student_name': contact.get('student_name', ''),
|
||||
'class_name': contact.get('class_name', ''),
|
||||
'notes': contact.get('notes', ''),
|
||||
'tags': ','.join(contact.get('tags', [])),
|
||||
'matrix_id': contact.get('matrix_id', ''),
|
||||
'preferred_channel': contact.get('preferred_channel', 'email')
|
||||
})
|
||||
|
||||
output.seek(0)
|
||||
|
||||
return StreamingResponse(
|
||||
iter([output.getvalue()]),
|
||||
media_type="text/csv",
|
||||
headers={"Content-Disposition": "attachment; filename=kontakte.csv"}
|
||||
)
|
||||
|
||||
|
||||
# ==========================================
|
||||
# GROUPS ENDPOINTS
|
||||
# ==========================================
|
||||
|
||||
@router.get("/groups", response_model=List[Group])
|
||||
async def list_groups():
|
||||
"""Listet alle Gruppen auf."""
|
||||
return get_groups()
|
||||
|
||||
|
||||
@router.post("/groups", response_model=Group)
|
||||
async def create_group(group: GroupCreate):
|
||||
"""Erstellt eine neue Gruppe."""
|
||||
groups = get_groups()
|
||||
|
||||
now = datetime.utcnow().isoformat()
|
||||
new_group = {
|
||||
"id": str(uuid.uuid4()),
|
||||
"created_at": now,
|
||||
"updated_at": now,
|
||||
**group.dict()
|
||||
}
|
||||
|
||||
groups.append(new_group)
|
||||
save_groups(groups)
|
||||
|
||||
return new_group
|
||||
|
||||
|
||||
@router.put("/groups/{group_id}/members")
|
||||
async def update_group_members(group_id: str, member_ids: List[str]):
|
||||
"""Aktualisiert die Mitglieder einer Gruppe."""
|
||||
groups = get_groups()
|
||||
group_idx = next((i for i, g in enumerate(groups) if g["id"] == group_id), None)
|
||||
|
||||
if group_idx is None:
|
||||
raise HTTPException(status_code=404, detail="Gruppe nicht gefunden")
|
||||
|
||||
groups[group_idx]["member_ids"] = member_ids
|
||||
groups[group_idx]["updated_at"] = datetime.utcnow().isoformat()
|
||||
|
||||
save_groups(groups)
|
||||
return groups[group_idx]
|
||||
|
||||
|
||||
@router.delete("/groups/{group_id}")
|
||||
async def delete_group(group_id: str):
|
||||
"""Loescht eine Gruppe."""
|
||||
groups = get_groups()
|
||||
groups = [g for g in groups if g["id"] != group_id]
|
||||
save_groups(groups)
|
||||
|
||||
return {"status": "deleted", "id": group_id}
|
||||
|
||||
|
||||
# ==========================================
|
||||
# CONVERSATIONS ENDPOINTS
|
||||
# ==========================================
|
||||
|
||||
@router.get("/conversations", response_model=List[Conversation])
|
||||
async def list_conversations():
|
||||
"""Listet alle Konversationen auf."""
|
||||
conversations = get_conversations()
|
||||
messages = get_messages()
|
||||
|
||||
# Unread count und letzte Nachricht hinzufuegen
|
||||
for conv in conversations:
|
||||
conv_messages = [m for m in messages if m.get("conversation_id") == conv["id"]]
|
||||
conv["unread_count"] = len([m for m in conv_messages if not m.get("read") and m.get("sender_id") != "self"])
|
||||
|
||||
if conv_messages:
|
||||
last_msg = max(conv_messages, key=lambda m: m.get("timestamp", ""))
|
||||
conv["last_message"] = last_msg.get("content", "")[:50]
|
||||
conv["last_message_time"] = last_msg.get("timestamp")
|
||||
|
||||
# Nach letzter Nachricht sortieren
|
||||
conversations.sort(key=lambda c: c.get("last_message_time") or "", reverse=True)
|
||||
|
||||
return conversations
|
||||
|
||||
|
||||
@router.post("/conversations", response_model=Conversation)
|
||||
async def create_conversation(contact_id: Optional[str] = None, group_id: Optional[str] = None):
|
||||
"""
|
||||
Erstellt eine neue Konversation.
|
||||
Entweder mit einem Kontakt (1:1) oder einer Gruppe.
|
||||
"""
|
||||
conversations = get_conversations()
|
||||
|
||||
if not contact_id and not group_id:
|
||||
raise HTTPException(status_code=400, detail="Entweder contact_id oder group_id erforderlich")
|
||||
|
||||
# Pruefen ob Konversation bereits existiert
|
||||
if contact_id:
|
||||
existing = next((c for c in conversations
|
||||
if not c.get("is_group") and contact_id in c.get("participant_ids", [])), None)
|
||||
if existing:
|
||||
return existing
|
||||
|
||||
now = datetime.utcnow().isoformat()
|
||||
|
||||
if group_id:
|
||||
groups = get_groups()
|
||||
group = next((g for g in groups if g["id"] == group_id), None)
|
||||
if not group:
|
||||
raise HTTPException(status_code=404, detail="Gruppe nicht gefunden")
|
||||
|
||||
new_conv = {
|
||||
"id": str(uuid.uuid4()),
|
||||
"name": group.get("name"),
|
||||
"is_group": True,
|
||||
"participant_ids": group.get("member_ids", []),
|
||||
"group_id": group_id,
|
||||
"created_at": now,
|
||||
"updated_at": now,
|
||||
"last_message": None,
|
||||
"last_message_time": None,
|
||||
"unread_count": 0
|
||||
}
|
||||
else:
|
||||
contacts = get_contacts()
|
||||
contact = next((c for c in contacts if c["id"] == contact_id), None)
|
||||
if not contact:
|
||||
raise HTTPException(status_code=404, detail="Kontakt nicht gefunden")
|
||||
|
||||
new_conv = {
|
||||
"id": str(uuid.uuid4()),
|
||||
"name": contact.get("name"),
|
||||
"is_group": False,
|
||||
"participant_ids": [contact_id],
|
||||
"group_id": None,
|
||||
"created_at": now,
|
||||
"updated_at": now,
|
||||
"last_message": None,
|
||||
"last_message_time": None,
|
||||
"unread_count": 0
|
||||
}
|
||||
|
||||
conversations.append(new_conv)
|
||||
save_conversations(conversations)
|
||||
|
||||
return new_conv
|
||||
|
||||
|
||||
@router.get("/conversations/{conversation_id}", response_model=Conversation)
|
||||
async def get_conversation(conversation_id: str):
|
||||
"""Ruft eine Konversation ab."""
|
||||
conversations = get_conversations()
|
||||
conv = next((c for c in conversations if c["id"] == conversation_id), None)
|
||||
|
||||
if not conv:
|
||||
raise HTTPException(status_code=404, detail="Konversation nicht gefunden")
|
||||
|
||||
return conv
|
||||
|
||||
|
||||
@router.delete("/conversations/{conversation_id}")
|
||||
async def delete_conversation(conversation_id: str):
|
||||
"""Loescht eine Konversation und alle zugehoerigen Nachrichten."""
|
||||
conversations = get_conversations()
|
||||
conversations = [c for c in conversations if c["id"] != conversation_id]
|
||||
save_conversations(conversations)
|
||||
|
||||
messages = get_messages()
|
||||
messages = [m for m in messages if m.get("conversation_id") != conversation_id]
|
||||
save_messages(messages)
|
||||
|
||||
return {"status": "deleted", "id": conversation_id}
|
||||
|
||||
|
||||
# ==========================================
|
||||
# MESSAGES ENDPOINTS
|
||||
# ==========================================
|
||||
|
||||
@router.get("/conversations/{conversation_id}/messages", response_model=List[Message])
|
||||
async def list_messages(
|
||||
conversation_id: str,
|
||||
limit: int = Query(50, ge=1, le=200),
|
||||
before: Optional[str] = Query(None, description="Load messages before this timestamp")
|
||||
):
|
||||
"""Ruft Nachrichten einer Konversation ab."""
|
||||
messages = get_messages()
|
||||
conv_messages = [m for m in messages if m.get("conversation_id") == conversation_id]
|
||||
|
||||
if before:
|
||||
conv_messages = [m for m in conv_messages if m.get("timestamp", "") < before]
|
||||
|
||||
# Nach Zeit sortieren (neueste zuletzt)
|
||||
conv_messages.sort(key=lambda m: m.get("timestamp", ""))
|
||||
|
||||
return conv_messages[-limit:]
|
||||
|
||||
|
||||
@router.post("/conversations/{conversation_id}/messages", response_model=Message)
|
||||
async def send_message(conversation_id: str, message: MessageBase):
|
||||
"""
|
||||
Sendet eine Nachricht in einer Konversation.
|
||||
|
||||
Wenn send_email=True und der Kontakt eine Email-Adresse hat,
|
||||
wird die Nachricht auch per Email versendet.
|
||||
"""
|
||||
conversations = get_conversations()
|
||||
conv = next((c for c in conversations if c["id"] == conversation_id), None)
|
||||
|
||||
if not conv:
|
||||
raise HTTPException(status_code=404, detail="Konversation nicht gefunden")
|
||||
|
||||
now = datetime.utcnow().isoformat()
|
||||
|
||||
new_message = {
|
||||
"id": str(uuid.uuid4()),
|
||||
"conversation_id": conversation_id,
|
||||
"sender_id": "self",
|
||||
"timestamp": now,
|
||||
"read": True,
|
||||
"read_at": now,
|
||||
"email_sent": False,
|
||||
"email_sent_at": None,
|
||||
"email_error": None,
|
||||
**message.dict()
|
||||
}
|
||||
|
||||
# Email-Versand wenn gewuenscht
|
||||
if message.send_email and not conv.get("is_group"):
|
||||
# Kontakt laden
|
||||
participant_ids = conv.get("participant_ids", [])
|
||||
if participant_ids:
|
||||
contacts = get_contacts()
|
||||
contact = next((c for c in contacts if c["id"] == participant_ids[0]), None)
|
||||
|
||||
if contact and contact.get("email"):
|
||||
try:
|
||||
from email_service import email_service
|
||||
|
||||
result = email_service.send_messenger_notification(
|
||||
to_email=contact["email"],
|
||||
to_name=contact.get("name", ""),
|
||||
sender_name="BreakPilot Lehrer", # TODO: Aktuellen User-Namen verwenden
|
||||
message_content=message.content
|
||||
)
|
||||
|
||||
if result.success:
|
||||
new_message["email_sent"] = True
|
||||
new_message["email_sent_at"] = result.sent_at
|
||||
else:
|
||||
new_message["email_error"] = result.error
|
||||
|
||||
except Exception as e:
|
||||
new_message["email_error"] = str(e)
|
||||
|
||||
messages = get_messages()
|
||||
messages.append(new_message)
|
||||
save_messages(messages)
|
||||
|
||||
# Konversation aktualisieren
|
||||
conv_idx = next(i for i, c in enumerate(conversations) if c["id"] == conversation_id)
|
||||
conversations[conv_idx]["last_message"] = message.content[:50]
|
||||
conversations[conv_idx]["last_message_time"] = now
|
||||
conversations[conv_idx]["updated_at"] = now
|
||||
save_conversations(conversations)
|
||||
|
||||
return new_message
|
||||
|
||||
|
||||
@router.put("/messages/{message_id}/read")
|
||||
async def mark_message_read(message_id: str):
|
||||
"""Markiert eine Nachricht als gelesen."""
|
||||
messages = get_messages()
|
||||
msg_idx = next((i for i, m in enumerate(messages) if m["id"] == message_id), None)
|
||||
|
||||
if msg_idx is None:
|
||||
raise HTTPException(status_code=404, detail="Nachricht nicht gefunden")
|
||||
|
||||
messages[msg_idx]["read"] = True
|
||||
messages[msg_idx]["read_at"] = datetime.utcnow().isoformat()
|
||||
save_messages(messages)
|
||||
|
||||
return {"status": "read", "id": message_id}
|
||||
|
||||
|
||||
@router.put("/conversations/{conversation_id}/read-all")
|
||||
async def mark_all_messages_read(conversation_id: str):
|
||||
"""Markiert alle Nachrichten einer Konversation als gelesen."""
|
||||
messages = get_messages()
|
||||
now = datetime.utcnow().isoformat()
|
||||
|
||||
for msg in messages:
|
||||
if msg.get("conversation_id") == conversation_id and not msg.get("read"):
|
||||
msg["read"] = True
|
||||
msg["read_at"] = now
|
||||
|
||||
save_messages(messages)
|
||||
|
||||
return {"status": "all_read", "conversation_id": conversation_id}
|
||||
|
||||
|
||||
# ==========================================
|
||||
# TEMPLATES ENDPOINTS
|
||||
# ==========================================
|
||||
|
||||
DEFAULT_TEMPLATES = [
|
||||
{
|
||||
"id": "1",
|
||||
"name": "Terminbestaetigung",
|
||||
"content": "Vielen Dank fuer Ihre Terminanfrage. Ich bestaetige den Termin am [DATUM] um [UHRZEIT]. Bitte geben Sie mir Bescheid, falls sich etwas aendern sollte.",
|
||||
"category": "termin"
|
||||
},
|
||||
{
|
||||
"id": "2",
|
||||
"name": "Hausaufgaben-Info",
|
||||
"content": "Zur Information: Die Hausaufgaben fuer diese Woche umfassen [THEMA]. Abgabetermin ist [DATUM]. Bei Fragen stehe ich gerne zur Verfuegung.",
|
||||
"category": "hausaufgaben"
|
||||
},
|
||||
{
|
||||
"id": "3",
|
||||
"name": "Entschuldigung bestaetigen",
|
||||
"content": "Ich bestaetige den Erhalt der Entschuldigung fuer [NAME] am [DATUM]. Die Fehlzeiten wurden entsprechend vermerkt.",
|
||||
"category": "entschuldigung"
|
||||
},
|
||||
{
|
||||
"id": "4",
|
||||
"name": "Gespraechsanfrage",
|
||||
"content": "Ich wuerde gerne einen Termin fuer ein Gespraech mit Ihnen vereinbaren, um [THEMA] zu besprechen. Waeren Sie am [DATUM] um [UHRZEIT] verfuegbar?",
|
||||
"category": "gespraech"
|
||||
},
|
||||
{
|
||||
"id": "5",
|
||||
"name": "Krankmeldung bestaetigen",
|
||||
"content": "Vielen Dank fuer Ihre Krankmeldung fuer [NAME]. Ich wuensche gute Besserung. Bitte reichen Sie eine schriftliche Entschuldigung nach, sobald Ihr Kind wieder gesund ist.",
|
||||
"category": "krankmeldung"
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
@router.get("/templates")
|
||||
async def list_templates():
|
||||
"""Listet alle Nachrichtenvorlagen auf."""
|
||||
templates_file = DATA_DIR / "templates.json"
|
||||
if templates_file.exists():
|
||||
templates = load_json(templates_file)
|
||||
else:
|
||||
templates = DEFAULT_TEMPLATES
|
||||
save_json(templates_file, templates)
|
||||
|
||||
return templates
|
||||
|
||||
|
||||
@router.post("/templates")
|
||||
async def create_template(name: str, content: str, category: str = "custom"):
|
||||
"""Erstellt eine neue Vorlage."""
|
||||
templates_file = DATA_DIR / "templates.json"
|
||||
templates = load_json(templates_file) if templates_file.exists() else DEFAULT_TEMPLATES.copy()
|
||||
|
||||
new_template = {
|
||||
"id": str(uuid.uuid4()),
|
||||
"name": name,
|
||||
"content": content,
|
||||
"category": category
|
||||
}
|
||||
|
||||
templates.append(new_template)
|
||||
save_json(templates_file, templates)
|
||||
|
||||
return new_template
|
||||
|
||||
|
||||
@router.delete("/templates/{template_id}")
|
||||
async def delete_template(template_id: str):
|
||||
"""Loescht eine Vorlage."""
|
||||
templates_file = DATA_DIR / "templates.json"
|
||||
templates = load_json(templates_file) if templates_file.exists() else DEFAULT_TEMPLATES.copy()
|
||||
|
||||
templates = [t for t in templates if t["id"] != template_id]
|
||||
save_json(templates_file, templates)
|
||||
|
||||
return {"status": "deleted", "id": template_id}
|
||||
|
||||
|
||||
# ==========================================
|
||||
# STATS ENDPOINT
|
||||
# ==========================================
|
||||
|
||||
@router.get("/stats")
|
||||
async def get_messenger_stats():
|
||||
"""Gibt Statistiken zum Messenger zurueck."""
|
||||
contacts = get_contacts()
|
||||
conversations = get_conversations()
|
||||
messages = get_messages()
|
||||
groups = get_groups()
|
||||
|
||||
unread_total = sum(1 for m in messages if not m.get("read") and m.get("sender_id") != "self")
|
||||
|
||||
return {
|
||||
"total_contacts": len(contacts),
|
||||
"total_groups": len(groups),
|
||||
"total_conversations": len(conversations),
|
||||
"total_messages": len(messages),
|
||||
"unread_messages": unread_total,
|
||||
"contacts_by_role": {
|
||||
role: len([c for c in contacts if c.get("role") == role])
|
||||
for role in set(c.get("role", "parent") for c in contacts)
|
||||
}
|
||||
}
|
||||
router.include_router(_contacts_router)
|
||||
router.include_router(_conversations_router)
|
||||
|
||||
@@ -0,0 +1,251 @@
|
||||
"""
|
||||
Messenger API - Contact Routes.
|
||||
|
||||
CRUD, CSV import/export for contacts.
|
||||
"""
|
||||
|
||||
import csv
|
||||
import uuid
|
||||
from io import StringIO
|
||||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import APIRouter, HTTPException, UploadFile, File, Query
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
from messenger_models import (
|
||||
Contact,
|
||||
ContactCreate,
|
||||
ContactUpdate,
|
||||
CSVImportResult,
|
||||
)
|
||||
from messenger_helpers import get_contacts, save_contacts
|
||||
|
||||
router = APIRouter(tags=["Messenger"])
|
||||
|
||||
|
||||
# ==========================================
|
||||
# CONTACTS ENDPOINTS
|
||||
# ==========================================
|
||||
|
||||
@router.get("/contacts", response_model=List[Contact])
|
||||
async def list_contacts(
|
||||
role: Optional[str] = Query(None, description="Filter by role"),
|
||||
class_name: Optional[str] = Query(None, description="Filter by class"),
|
||||
search: Optional[str] = Query(None, description="Search in name/email")
|
||||
):
|
||||
"""Listet alle Kontakte auf."""
|
||||
contacts = get_contacts()
|
||||
|
||||
# Filter anwenden
|
||||
if role:
|
||||
contacts = [c for c in contacts if c.get("role") == role]
|
||||
if class_name:
|
||||
contacts = [c for c in contacts if c.get("class_name") == class_name]
|
||||
if search:
|
||||
search_lower = search.lower()
|
||||
contacts = [c for c in contacts if
|
||||
search_lower in c.get("name", "").lower() or
|
||||
search_lower in (c.get("email") or "").lower() or
|
||||
search_lower in (c.get("student_name") or "").lower()]
|
||||
|
||||
return contacts
|
||||
|
||||
|
||||
@router.post("/contacts", response_model=Contact)
|
||||
async def create_contact(contact: ContactCreate):
|
||||
"""Erstellt einen neuen Kontakt."""
|
||||
contacts = get_contacts()
|
||||
|
||||
# Pruefen ob Email bereits existiert
|
||||
if contact.email:
|
||||
existing = [c for c in contacts if c.get("email") == contact.email]
|
||||
if existing:
|
||||
raise HTTPException(status_code=400, detail="Kontakt mit dieser Email existiert bereits")
|
||||
|
||||
now = datetime.utcnow().isoformat()
|
||||
new_contact = {
|
||||
"id": str(uuid.uuid4()),
|
||||
"created_at": now,
|
||||
"updated_at": now,
|
||||
"online": False,
|
||||
"last_seen": None,
|
||||
**contact.dict()
|
||||
}
|
||||
|
||||
contacts.append(new_contact)
|
||||
save_contacts(contacts)
|
||||
|
||||
return new_contact
|
||||
|
||||
|
||||
@router.get("/contacts/{contact_id}", response_model=Contact)
|
||||
async def get_contact(contact_id: str):
|
||||
"""Ruft einen einzelnen Kontakt ab."""
|
||||
contacts = get_contacts()
|
||||
contact = next((c for c in contacts if c["id"] == contact_id), None)
|
||||
|
||||
if not contact:
|
||||
raise HTTPException(status_code=404, detail="Kontakt nicht gefunden")
|
||||
|
||||
return contact
|
||||
|
||||
|
||||
@router.put("/contacts/{contact_id}", response_model=Contact)
|
||||
async def update_contact(contact_id: str, update: ContactUpdate):
|
||||
"""Aktualisiert einen Kontakt."""
|
||||
contacts = get_contacts()
|
||||
contact_idx = next((i for i, c in enumerate(contacts) if c["id"] == contact_id), None)
|
||||
|
||||
if contact_idx is None:
|
||||
raise HTTPException(status_code=404, detail="Kontakt nicht gefunden")
|
||||
|
||||
update_data = update.dict(exclude_unset=True)
|
||||
contacts[contact_idx].update(update_data)
|
||||
contacts[contact_idx]["updated_at"] = datetime.utcnow().isoformat()
|
||||
|
||||
save_contacts(contacts)
|
||||
return contacts[contact_idx]
|
||||
|
||||
|
||||
@router.delete("/contacts/{contact_id}")
|
||||
async def delete_contact(contact_id: str):
|
||||
"""Loescht einen Kontakt."""
|
||||
contacts = get_contacts()
|
||||
contacts = [c for c in contacts if c["id"] != contact_id]
|
||||
save_contacts(contacts)
|
||||
|
||||
return {"status": "deleted", "id": contact_id}
|
||||
|
||||
|
||||
@router.post("/contacts/import", response_model=CSVImportResult)
|
||||
async def import_contacts_csv(file: UploadFile = File(...)):
|
||||
"""
|
||||
Importiert Kontakte aus einer CSV-Datei.
|
||||
|
||||
Erwartete Spalten:
|
||||
- name (required)
|
||||
- email
|
||||
- phone
|
||||
- role (parent/teacher/staff/student)
|
||||
- student_name
|
||||
- class_name
|
||||
- notes
|
||||
- tags (komma-separiert)
|
||||
"""
|
||||
if not file.filename.endswith('.csv'):
|
||||
raise HTTPException(status_code=400, detail="Nur CSV-Dateien werden unterstuetzt")
|
||||
|
||||
content = await file.read()
|
||||
try:
|
||||
text = content.decode('utf-8')
|
||||
except UnicodeDecodeError:
|
||||
text = content.decode('latin-1')
|
||||
|
||||
contacts = get_contacts()
|
||||
existing_emails = {c.get("email") for c in contacts if c.get("email")}
|
||||
|
||||
imported = []
|
||||
skipped = 0
|
||||
errors = []
|
||||
|
||||
reader = csv.DictReader(StringIO(text), delimiter=';') # Deutsche CSV meist mit Semikolon
|
||||
if not reader.fieldnames or 'name' not in [f.lower() for f in reader.fieldnames]:
|
||||
# Versuche mit Komma
|
||||
reader = csv.DictReader(StringIO(text), delimiter=',')
|
||||
|
||||
for row_num, row in enumerate(reader, start=2):
|
||||
try:
|
||||
# Normalisiere Spaltennamen
|
||||
row = {k.lower().strip(): v.strip() if v else "" for k, v in row.items()}
|
||||
|
||||
name = row.get('name') or row.get('kontakt') or row.get('elternname')
|
||||
if not name:
|
||||
errors.append(f"Zeile {row_num}: Name fehlt")
|
||||
skipped += 1
|
||||
continue
|
||||
|
||||
email = row.get('email') or row.get('e-mail') or row.get('mail')
|
||||
if email and email in existing_emails:
|
||||
errors.append(f"Zeile {row_num}: Email {email} existiert bereits")
|
||||
skipped += 1
|
||||
continue
|
||||
|
||||
now = datetime.utcnow().isoformat()
|
||||
tags_str = row.get('tags') or row.get('kategorien') or ""
|
||||
tags = [t.strip() for t in tags_str.split(',') if t.strip()]
|
||||
|
||||
# Matrix-ID und preferred_channel auslesen
|
||||
matrix_id = row.get('matrix_id') or row.get('matrix') or None
|
||||
preferred_channel = row.get('preferred_channel') or row.get('kanal') or "email"
|
||||
if preferred_channel not in ["email", "matrix", "pwa"]:
|
||||
preferred_channel = "email"
|
||||
|
||||
new_contact = {
|
||||
"id": str(uuid.uuid4()),
|
||||
"name": name,
|
||||
"email": email if email else None,
|
||||
"phone": row.get('phone') or row.get('telefon') or row.get('tel'),
|
||||
"role": row.get('role') or row.get('rolle') or "parent",
|
||||
"student_name": row.get('student_name') or row.get('schueler') or row.get('kind'),
|
||||
"class_name": row.get('class_name') or row.get('klasse'),
|
||||
"notes": row.get('notes') or row.get('notizen') or row.get('bemerkungen'),
|
||||
"tags": tags,
|
||||
"matrix_id": matrix_id if matrix_id else None,
|
||||
"preferred_channel": preferred_channel,
|
||||
"created_at": now,
|
||||
"updated_at": now,
|
||||
"online": False,
|
||||
"last_seen": None
|
||||
}
|
||||
|
||||
contacts.append(new_contact)
|
||||
imported.append(new_contact)
|
||||
if email:
|
||||
existing_emails.add(email)
|
||||
|
||||
except Exception as e:
|
||||
errors.append(f"Zeile {row_num}: {str(e)}")
|
||||
skipped += 1
|
||||
|
||||
save_contacts(contacts)
|
||||
|
||||
return CSVImportResult(
|
||||
imported=len(imported),
|
||||
skipped=skipped,
|
||||
errors=errors[:20], # Maximal 20 Fehler zurueckgeben
|
||||
contacts=imported
|
||||
)
|
||||
|
||||
|
||||
@router.get("/contacts/export/csv")
|
||||
async def export_contacts_csv():
|
||||
"""Exportiert alle Kontakte als CSV."""
|
||||
contacts = get_contacts()
|
||||
|
||||
output = StringIO()
|
||||
fieldnames = ['name', 'email', 'phone', 'role', 'student_name', 'class_name', 'notes', 'tags', 'matrix_id', 'preferred_channel']
|
||||
writer = csv.DictWriter(output, fieldnames=fieldnames, delimiter=';')
|
||||
writer.writeheader()
|
||||
|
||||
for contact in contacts:
|
||||
writer.writerow({
|
||||
'name': contact.get('name', ''),
|
||||
'email': contact.get('email', ''),
|
||||
'phone': contact.get('phone', ''),
|
||||
'role': contact.get('role', ''),
|
||||
'student_name': contact.get('student_name', ''),
|
||||
'class_name': contact.get('class_name', ''),
|
||||
'notes': contact.get('notes', ''),
|
||||
'tags': ','.join(contact.get('tags', [])),
|
||||
'matrix_id': contact.get('matrix_id', ''),
|
||||
'preferred_channel': contact.get('preferred_channel', 'email')
|
||||
})
|
||||
|
||||
output.seek(0)
|
||||
|
||||
return StreamingResponse(
|
||||
iter([output.getvalue()]),
|
||||
media_type="text/csv",
|
||||
headers={"Content-Disposition": "attachment; filename=kontakte.csv"}
|
||||
)
|
||||
@@ -0,0 +1,405 @@
|
||||
"""
|
||||
Messenger API - Conversation, Message, Group, Template & Stats Routes.
|
||||
|
||||
Conversations CRUD, message send/read, groups, templates, stats.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Query
|
||||
|
||||
from messenger_models import (
|
||||
Conversation,
|
||||
Group,
|
||||
GroupCreate,
|
||||
Message,
|
||||
MessageBase,
|
||||
)
|
||||
from messenger_helpers import (
|
||||
DATA_DIR,
|
||||
DEFAULT_TEMPLATES,
|
||||
get_contacts,
|
||||
get_conversations,
|
||||
save_conversations,
|
||||
get_messages,
|
||||
save_messages,
|
||||
get_groups,
|
||||
save_groups,
|
||||
load_json,
|
||||
save_json,
|
||||
)
|
||||
|
||||
router = APIRouter(tags=["Messenger"])
|
||||
|
||||
|
||||
# ==========================================
|
||||
# GROUPS ENDPOINTS
|
||||
# ==========================================
|
||||
|
||||
@router.get("/groups", response_model=List[Group])
|
||||
async def list_groups():
|
||||
"""Listet alle Gruppen auf."""
|
||||
return get_groups()
|
||||
|
||||
|
||||
@router.post("/groups", response_model=Group)
|
||||
async def create_group(group: GroupCreate):
|
||||
"""Erstellt eine neue Gruppe."""
|
||||
groups = get_groups()
|
||||
|
||||
now = datetime.utcnow().isoformat()
|
||||
new_group = {
|
||||
"id": str(uuid.uuid4()),
|
||||
"created_at": now,
|
||||
"updated_at": now,
|
||||
**group.dict()
|
||||
}
|
||||
|
||||
groups.append(new_group)
|
||||
save_groups(groups)
|
||||
|
||||
return new_group
|
||||
|
||||
|
||||
@router.put("/groups/{group_id}/members")
|
||||
async def update_group_members(group_id: str, member_ids: List[str]):
|
||||
"""Aktualisiert die Mitglieder einer Gruppe."""
|
||||
groups = get_groups()
|
||||
group_idx = next((i for i, g in enumerate(groups) if g["id"] == group_id), None)
|
||||
|
||||
if group_idx is None:
|
||||
raise HTTPException(status_code=404, detail="Gruppe nicht gefunden")
|
||||
|
||||
groups[group_idx]["member_ids"] = member_ids
|
||||
groups[group_idx]["updated_at"] = datetime.utcnow().isoformat()
|
||||
|
||||
save_groups(groups)
|
||||
return groups[group_idx]
|
||||
|
||||
|
||||
@router.delete("/groups/{group_id}")
|
||||
async def delete_group(group_id: str):
|
||||
"""Loescht eine Gruppe."""
|
||||
groups = get_groups()
|
||||
groups = [g for g in groups if g["id"] != group_id]
|
||||
save_groups(groups)
|
||||
|
||||
return {"status": "deleted", "id": group_id}
|
||||
|
||||
|
||||
# ==========================================
|
||||
# CONVERSATIONS ENDPOINTS
|
||||
# ==========================================
|
||||
|
||||
@router.get("/conversations", response_model=List[Conversation])
|
||||
async def list_conversations():
|
||||
"""Listet alle Konversationen auf."""
|
||||
conversations = get_conversations()
|
||||
messages = get_messages()
|
||||
|
||||
# Unread count und letzte Nachricht hinzufuegen
|
||||
for conv in conversations:
|
||||
conv_messages = [m for m in messages if m.get("conversation_id") == conv["id"]]
|
||||
conv["unread_count"] = len([m for m in conv_messages if not m.get("read") and m.get("sender_id") != "self"])
|
||||
|
||||
if conv_messages:
|
||||
last_msg = max(conv_messages, key=lambda m: m.get("timestamp", ""))
|
||||
conv["last_message"] = last_msg.get("content", "")[:50]
|
||||
conv["last_message_time"] = last_msg.get("timestamp")
|
||||
|
||||
# Nach letzter Nachricht sortieren
|
||||
conversations.sort(key=lambda c: c.get("last_message_time") or "", reverse=True)
|
||||
|
||||
return conversations
|
||||
|
||||
|
||||
@router.post("/conversations", response_model=Conversation)
|
||||
async def create_conversation(contact_id: Optional[str] = None, group_id: Optional[str] = None):
|
||||
"""
|
||||
Erstellt eine neue Konversation.
|
||||
Entweder mit einem Kontakt (1:1) oder einer Gruppe.
|
||||
"""
|
||||
conversations = get_conversations()
|
||||
|
||||
if not contact_id and not group_id:
|
||||
raise HTTPException(status_code=400, detail="Entweder contact_id oder group_id erforderlich")
|
||||
|
||||
# Pruefen ob Konversation bereits existiert
|
||||
if contact_id:
|
||||
existing = next((c for c in conversations
|
||||
if not c.get("is_group") and contact_id in c.get("participant_ids", [])), None)
|
||||
if existing:
|
||||
return existing
|
||||
|
||||
now = datetime.utcnow().isoformat()
|
||||
|
||||
if group_id:
|
||||
groups = get_groups()
|
||||
group = next((g for g in groups if g["id"] == group_id), None)
|
||||
if not group:
|
||||
raise HTTPException(status_code=404, detail="Gruppe nicht gefunden")
|
||||
|
||||
new_conv = {
|
||||
"id": str(uuid.uuid4()),
|
||||
"name": group.get("name"),
|
||||
"is_group": True,
|
||||
"participant_ids": group.get("member_ids", []),
|
||||
"group_id": group_id,
|
||||
"created_at": now,
|
||||
"updated_at": now,
|
||||
"last_message": None,
|
||||
"last_message_time": None,
|
||||
"unread_count": 0
|
||||
}
|
||||
else:
|
||||
contacts = get_contacts()
|
||||
contact = next((c for c in contacts if c["id"] == contact_id), None)
|
||||
if not contact:
|
||||
raise HTTPException(status_code=404, detail="Kontakt nicht gefunden")
|
||||
|
||||
new_conv = {
|
||||
"id": str(uuid.uuid4()),
|
||||
"name": contact.get("name"),
|
||||
"is_group": False,
|
||||
"participant_ids": [contact_id],
|
||||
"group_id": None,
|
||||
"created_at": now,
|
||||
"updated_at": now,
|
||||
"last_message": None,
|
||||
"last_message_time": None,
|
||||
"unread_count": 0
|
||||
}
|
||||
|
||||
conversations.append(new_conv)
|
||||
save_conversations(conversations)
|
||||
|
||||
return new_conv
|
||||
|
||||
|
||||
@router.get("/conversations/{conversation_id}", response_model=Conversation)
|
||||
async def get_conversation(conversation_id: str):
|
||||
"""Ruft eine Konversation ab."""
|
||||
conversations = get_conversations()
|
||||
conv = next((c for c in conversations if c["id"] == conversation_id), None)
|
||||
|
||||
if not conv:
|
||||
raise HTTPException(status_code=404, detail="Konversation nicht gefunden")
|
||||
|
||||
return conv
|
||||
|
||||
|
||||
@router.delete("/conversations/{conversation_id}")
|
||||
async def delete_conversation(conversation_id: str):
|
||||
"""Loescht eine Konversation und alle zugehoerigen Nachrichten."""
|
||||
conversations = get_conversations()
|
||||
conversations = [c for c in conversations if c["id"] != conversation_id]
|
||||
save_conversations(conversations)
|
||||
|
||||
messages = get_messages()
|
||||
messages = [m for m in messages if m.get("conversation_id") != conversation_id]
|
||||
save_messages(messages)
|
||||
|
||||
return {"status": "deleted", "id": conversation_id}
|
||||
|
||||
|
||||
# ==========================================
|
||||
# MESSAGES ENDPOINTS
|
||||
# ==========================================
|
||||
|
||||
@router.get("/conversations/{conversation_id}/messages", response_model=List[Message])
|
||||
async def list_messages(
|
||||
conversation_id: str,
|
||||
limit: int = Query(50, ge=1, le=200),
|
||||
before: Optional[str] = Query(None, description="Load messages before this timestamp")
|
||||
):
|
||||
"""Ruft Nachrichten einer Konversation ab."""
|
||||
messages = get_messages()
|
||||
conv_messages = [m for m in messages if m.get("conversation_id") == conversation_id]
|
||||
|
||||
if before:
|
||||
conv_messages = [m for m in conv_messages if m.get("timestamp", "") < before]
|
||||
|
||||
# Nach Zeit sortieren (neueste zuletzt)
|
||||
conv_messages.sort(key=lambda m: m.get("timestamp", ""))
|
||||
|
||||
return conv_messages[-limit:]
|
||||
|
||||
|
||||
@router.post("/conversations/{conversation_id}/messages", response_model=Message)
|
||||
async def send_message(conversation_id: str, message: MessageBase):
|
||||
"""
|
||||
Sendet eine Nachricht in einer Konversation.
|
||||
|
||||
Wenn send_email=True und der Kontakt eine Email-Adresse hat,
|
||||
wird die Nachricht auch per Email versendet.
|
||||
"""
|
||||
conversations = get_conversations()
|
||||
conv = next((c for c in conversations if c["id"] == conversation_id), None)
|
||||
|
||||
if not conv:
|
||||
raise HTTPException(status_code=404, detail="Konversation nicht gefunden")
|
||||
|
||||
now = datetime.utcnow().isoformat()
|
||||
|
||||
new_message = {
|
||||
"id": str(uuid.uuid4()),
|
||||
"conversation_id": conversation_id,
|
||||
"sender_id": "self",
|
||||
"timestamp": now,
|
||||
"read": True,
|
||||
"read_at": now,
|
||||
"email_sent": False,
|
||||
"email_sent_at": None,
|
||||
"email_error": None,
|
||||
**message.dict()
|
||||
}
|
||||
|
||||
# Email-Versand wenn gewuenscht
|
||||
if message.send_email and not conv.get("is_group"):
|
||||
# Kontakt laden
|
||||
participant_ids = conv.get("participant_ids", [])
|
||||
if participant_ids:
|
||||
contacts = get_contacts()
|
||||
contact = next((c for c in contacts if c["id"] == participant_ids[0]), None)
|
||||
|
||||
if contact and contact.get("email"):
|
||||
try:
|
||||
from email_service import email_service
|
||||
|
||||
result = email_service.send_messenger_notification(
|
||||
to_email=contact["email"],
|
||||
to_name=contact.get("name", ""),
|
||||
sender_name="BreakPilot Lehrer",
|
||||
message_content=message.content
|
||||
)
|
||||
|
||||
if result.success:
|
||||
new_message["email_sent"] = True
|
||||
new_message["email_sent_at"] = result.sent_at
|
||||
else:
|
||||
new_message["email_error"] = result.error
|
||||
|
||||
except Exception as e:
|
||||
new_message["email_error"] = str(e)
|
||||
|
||||
messages = get_messages()
|
||||
messages.append(new_message)
|
||||
save_messages(messages)
|
||||
|
||||
# Konversation aktualisieren
|
||||
conv_idx = next(i for i, c in enumerate(conversations) if c["id"] == conversation_id)
|
||||
conversations[conv_idx]["last_message"] = message.content[:50]
|
||||
conversations[conv_idx]["last_message_time"] = now
|
||||
conversations[conv_idx]["updated_at"] = now
|
||||
save_conversations(conversations)
|
||||
|
||||
return new_message
|
||||
|
||||
|
||||
@router.put("/messages/{message_id}/read")
|
||||
async def mark_message_read(message_id: str):
|
||||
"""Markiert eine Nachricht als gelesen."""
|
||||
messages = get_messages()
|
||||
msg_idx = next((i for i, m in enumerate(messages) if m["id"] == message_id), None)
|
||||
|
||||
if msg_idx is None:
|
||||
raise HTTPException(status_code=404, detail="Nachricht nicht gefunden")
|
||||
|
||||
messages[msg_idx]["read"] = True
|
||||
messages[msg_idx]["read_at"] = datetime.utcnow().isoformat()
|
||||
save_messages(messages)
|
||||
|
||||
return {"status": "read", "id": message_id}
|
||||
|
||||
|
||||
@router.put("/conversations/{conversation_id}/read-all")
|
||||
async def mark_all_messages_read(conversation_id: str):
|
||||
"""Markiert alle Nachrichten einer Konversation als gelesen."""
|
||||
messages = get_messages()
|
||||
now = datetime.utcnow().isoformat()
|
||||
|
||||
for msg in messages:
|
||||
if msg.get("conversation_id") == conversation_id and not msg.get("read"):
|
||||
msg["read"] = True
|
||||
msg["read_at"] = now
|
||||
|
||||
save_messages(messages)
|
||||
|
||||
return {"status": "all_read", "conversation_id": conversation_id}
|
||||
|
||||
|
||||
# ==========================================
|
||||
# TEMPLATES ENDPOINTS
|
||||
# ==========================================
|
||||
|
||||
@router.get("/templates")
|
||||
async def list_templates():
|
||||
"""Listet alle Nachrichtenvorlagen auf."""
|
||||
templates_file = DATA_DIR / "templates.json"
|
||||
if templates_file.exists():
|
||||
templates = load_json(templates_file)
|
||||
else:
|
||||
templates = DEFAULT_TEMPLATES
|
||||
save_json(templates_file, templates)
|
||||
|
||||
return templates
|
||||
|
||||
|
||||
@router.post("/templates")
|
||||
async def create_template(name: str, content: str, category: str = "custom"):
|
||||
"""Erstellt eine neue Vorlage."""
|
||||
templates_file = DATA_DIR / "templates.json"
|
||||
templates = load_json(templates_file) if templates_file.exists() else DEFAULT_TEMPLATES.copy()
|
||||
|
||||
new_template = {
|
||||
"id": str(uuid.uuid4()),
|
||||
"name": name,
|
||||
"content": content,
|
||||
"category": category
|
||||
}
|
||||
|
||||
templates.append(new_template)
|
||||
save_json(templates_file, templates)
|
||||
|
||||
return new_template
|
||||
|
||||
|
||||
@router.delete("/templates/{template_id}")
|
||||
async def delete_template(template_id: str):
|
||||
"""Loescht eine Vorlage."""
|
||||
templates_file = DATA_DIR / "templates.json"
|
||||
templates = load_json(templates_file) if templates_file.exists() else DEFAULT_TEMPLATES.copy()
|
||||
|
||||
templates = [t for t in templates if t["id"] != template_id]
|
||||
save_json(templates_file, templates)
|
||||
|
||||
return {"status": "deleted", "id": template_id}
|
||||
|
||||
|
||||
# ==========================================
|
||||
# STATS ENDPOINT
|
||||
# ==========================================
|
||||
|
||||
@router.get("/stats")
|
||||
async def get_messenger_stats():
|
||||
"""Gibt Statistiken zum Messenger zurueck."""
|
||||
contacts = get_contacts()
|
||||
conversations = get_conversations()
|
||||
messages = get_messages()
|
||||
groups = get_groups()
|
||||
|
||||
unread_total = sum(1 for m in messages if not m.get("read") and m.get("sender_id") != "self")
|
||||
|
||||
return {
|
||||
"total_contacts": len(contacts),
|
||||
"total_groups": len(groups),
|
||||
"total_conversations": len(conversations),
|
||||
"total_messages": len(messages),
|
||||
"unread_messages": unread_total,
|
||||
"contacts_by_role": {
|
||||
role: len([c for c in contacts if c.get("role") == role])
|
||||
for role in set(c.get("role", "parent") for c in contacts)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,105 @@
|
||||
"""
|
||||
Messenger API - Data Helpers.
|
||||
|
||||
JSON-based file storage for contacts, conversations, messages, and groups.
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import List, Dict
|
||||
from pathlib import Path
|
||||
|
||||
# Datenspeicherung (JSON-basiert fuer einfache Persistenz)
|
||||
DATA_DIR = Path(__file__).parent / "data" / "messenger"
|
||||
DATA_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
CONTACTS_FILE = DATA_DIR / "contacts.json"
|
||||
CONVERSATIONS_FILE = DATA_DIR / "conversations.json"
|
||||
MESSAGES_FILE = DATA_DIR / "messages.json"
|
||||
GROUPS_FILE = DATA_DIR / "groups.json"
|
||||
|
||||
|
||||
def load_json(filepath: Path) -> List[Dict]:
|
||||
"""Laedt JSON-Daten aus Datei."""
|
||||
if not filepath.exists():
|
||||
return []
|
||||
try:
|
||||
with open(filepath, "r", encoding="utf-8") as f:
|
||||
return json.load(f)
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
|
||||
def save_json(filepath: Path, data: List[Dict]):
|
||||
"""Speichert Daten in JSON-Datei."""
|
||||
with open(filepath, "w", encoding="utf-8") as f:
|
||||
json.dump(data, f, ensure_ascii=False, indent=2)
|
||||
|
||||
|
||||
def get_contacts() -> List[Dict]:
|
||||
return load_json(CONTACTS_FILE)
|
||||
|
||||
|
||||
def save_contacts(contacts: List[Dict]):
|
||||
save_json(CONTACTS_FILE, contacts)
|
||||
|
||||
|
||||
def get_conversations() -> List[Dict]:
|
||||
return load_json(CONVERSATIONS_FILE)
|
||||
|
||||
|
||||
def save_conversations(conversations: List[Dict]):
|
||||
save_json(CONVERSATIONS_FILE, conversations)
|
||||
|
||||
|
||||
def get_messages() -> List[Dict]:
|
||||
return load_json(MESSAGES_FILE)
|
||||
|
||||
|
||||
def save_messages(messages: List[Dict]):
|
||||
save_json(MESSAGES_FILE, messages)
|
||||
|
||||
|
||||
def get_groups() -> List[Dict]:
|
||||
return load_json(GROUPS_FILE)
|
||||
|
||||
|
||||
def save_groups(groups: List[Dict]):
|
||||
save_json(GROUPS_FILE, groups)
|
||||
|
||||
|
||||
# ==========================================
|
||||
# DEFAULT TEMPLATES
|
||||
# ==========================================
|
||||
|
||||
DEFAULT_TEMPLATES = [
|
||||
{
|
||||
"id": "1",
|
||||
"name": "Terminbestaetigung",
|
||||
"content": "Vielen Dank fuer Ihre Terminanfrage. Ich bestaetige den Termin am [DATUM] um [UHRZEIT]. Bitte geben Sie mir Bescheid, falls sich etwas aendern sollte.",
|
||||
"category": "termin"
|
||||
},
|
||||
{
|
||||
"id": "2",
|
||||
"name": "Hausaufgaben-Info",
|
||||
"content": "Zur Information: Die Hausaufgaben fuer diese Woche umfassen [THEMA]. Abgabetermin ist [DATUM]. Bei Fragen stehe ich gerne zur Verfuegung.",
|
||||
"category": "hausaufgaben"
|
||||
},
|
||||
{
|
||||
"id": "3",
|
||||
"name": "Entschuldigung bestaetigen",
|
||||
"content": "Ich bestaetige den Erhalt der Entschuldigung fuer [NAME] am [DATUM]. Die Fehlzeiten wurden entsprechend vermerkt.",
|
||||
"category": "entschuldigung"
|
||||
},
|
||||
{
|
||||
"id": "4",
|
||||
"name": "Gespraechsanfrage",
|
||||
"content": "Ich wuerde gerne einen Termin fuer ein Gespraech mit Ihnen vereinbaren, um [THEMA] zu besprechen. Waeren Sie am [DATUM] um [UHRZEIT] verfuegbar?",
|
||||
"category": "gespraech"
|
||||
},
|
||||
{
|
||||
"id": "5",
|
||||
"name": "Krankmeldung bestaetigen",
|
||||
"content": "Vielen Dank fuer Ihre Krankmeldung fuer [NAME]. Ich wuensche gute Besserung. Bitte reichen Sie eine schriftliche Entschuldigung nach, sobald Ihr Kind wieder gesund ist.",
|
||||
"category": "krankmeldung"
|
||||
}
|
||||
]
|
||||
@@ -0,0 +1,139 @@
|
||||
"""
|
||||
Messenger API - Pydantic Models.
|
||||
|
||||
Data models for contacts, conversations, messages, and groups.
|
||||
"""
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
# ==========================================
|
||||
# CONTACT MODELS
|
||||
# ==========================================
|
||||
|
||||
class ContactBase(BaseModel):
|
||||
"""Basis-Modell fuer Kontakte."""
|
||||
name: str = Field(..., min_length=1, max_length=200)
|
||||
email: Optional[str] = None
|
||||
phone: Optional[str] = None
|
||||
role: str = Field(default="parent", description="parent, teacher, staff, student")
|
||||
student_name: Optional[str] = Field(None, description="Name des zugehoerigen Schuelers")
|
||||
class_name: Optional[str] = Field(None, description="Klasse z.B. 10a")
|
||||
notes: Optional[str] = None
|
||||
tags: List[str] = Field(default_factory=list)
|
||||
matrix_id: Optional[str] = Field(None, description="Matrix-ID z.B. @user:matrix.org")
|
||||
preferred_channel: str = Field(default="email", description="email, matrix, pwa")
|
||||
|
||||
|
||||
class ContactCreate(ContactBase):
|
||||
"""Model fuer neuen Kontakt."""
|
||||
pass
|
||||
|
||||
|
||||
class Contact(ContactBase):
|
||||
"""Vollstaendiger Kontakt mit ID."""
|
||||
id: str
|
||||
created_at: str
|
||||
updated_at: str
|
||||
online: bool = False
|
||||
last_seen: Optional[str] = None
|
||||
|
||||
|
||||
class ContactUpdate(BaseModel):
|
||||
"""Update-Model fuer Kontakte."""
|
||||
name: Optional[str] = None
|
||||
email: Optional[str] = None
|
||||
phone: Optional[str] = None
|
||||
role: Optional[str] = None
|
||||
student_name: Optional[str] = None
|
||||
class_name: Optional[str] = None
|
||||
notes: Optional[str] = None
|
||||
tags: Optional[List[str]] = None
|
||||
matrix_id: Optional[str] = None
|
||||
preferred_channel: Optional[str] = None
|
||||
|
||||
|
||||
# ==========================================
|
||||
# GROUP MODELS
|
||||
# ==========================================
|
||||
|
||||
class GroupBase(BaseModel):
|
||||
"""Basis-Modell fuer Gruppen."""
|
||||
name: str = Field(..., min_length=1, max_length=100)
|
||||
description: Optional[str] = None
|
||||
group_type: str = Field(default="class", description="class, department, custom")
|
||||
|
||||
|
||||
class GroupCreate(GroupBase):
|
||||
"""Model fuer neue Gruppe."""
|
||||
member_ids: List[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
class Group(GroupBase):
|
||||
"""Vollstaendige Gruppe mit ID."""
|
||||
id: str
|
||||
member_ids: List[str] = []
|
||||
created_at: str
|
||||
updated_at: str
|
||||
|
||||
|
||||
# ==========================================
|
||||
# MESSAGE MODELS
|
||||
# ==========================================
|
||||
|
||||
class MessageBase(BaseModel):
|
||||
"""Basis-Modell fuer Nachrichten."""
|
||||
content: str = Field(..., min_length=1)
|
||||
content_type: str = Field(default="text", description="text, file, image")
|
||||
file_url: Optional[str] = None
|
||||
send_email: bool = Field(default=False, description="Nachricht auch per Email senden")
|
||||
|
||||
|
||||
class MessageCreate(MessageBase):
|
||||
"""Model fuer neue Nachricht."""
|
||||
conversation_id: str
|
||||
|
||||
|
||||
class Message(MessageBase):
|
||||
"""Vollstaendige Nachricht mit ID."""
|
||||
id: str
|
||||
conversation_id: str
|
||||
sender_id: str # "self" fuer eigene Nachrichten
|
||||
timestamp: str
|
||||
read: bool = False
|
||||
read_at: Optional[str] = None
|
||||
email_sent: bool = False
|
||||
email_sent_at: Optional[str] = None
|
||||
email_error: Optional[str] = None
|
||||
|
||||
|
||||
# ==========================================
|
||||
# CONVERSATION MODELS
|
||||
# ==========================================
|
||||
|
||||
class ConversationBase(BaseModel):
|
||||
"""Basis-Modell fuer Konversationen."""
|
||||
name: Optional[str] = None
|
||||
is_group: bool = False
|
||||
|
||||
|
||||
class Conversation(ConversationBase):
|
||||
"""Vollstaendige Konversation mit ID."""
|
||||
id: str
|
||||
participant_ids: List[str] = []
|
||||
group_id: Optional[str] = None
|
||||
created_at: str
|
||||
updated_at: str
|
||||
last_message: Optional[str] = None
|
||||
last_message_time: Optional[str] = None
|
||||
unread_count: int = 0
|
||||
|
||||
|
||||
class CSVImportResult(BaseModel):
|
||||
"""Ergebnis eines CSV-Imports."""
|
||||
imported: int
|
||||
skipped: int
|
||||
errors: List[str]
|
||||
contacts: List[Contact]
|
||||
+14
-840
@@ -1,848 +1,22 @@
|
||||
"""
|
||||
BreakPilot Recording API
|
||||
BreakPilot Recording API — Barrel Re-export.
|
||||
|
||||
Verwaltet Jibri Meeting-Aufzeichnungen und deren Metadaten.
|
||||
Empfaengt Webhooks von Jibri nach Upload zu MinIO.
|
||||
Split into:
|
||||
- recording_models.py: Pydantic models & config
|
||||
- recording_helpers.py: In-memory storage & utilities
|
||||
- recording_routes.py: Core recording CRUD routes
|
||||
- recording_transcription.py: Transcription routes
|
||||
- recording_minutes.py: Meeting minutes routes
|
||||
"""
|
||||
|
||||
import os
|
||||
import uuid
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional, List
|
||||
from pydantic import BaseModel, Field
|
||||
from fastapi import APIRouter
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Query, Depends, Request
|
||||
from fastapi.responses import JSONResponse
|
||||
from recording_routes import router as _routes_router
|
||||
from recording_transcription import router as _transcription_router
|
||||
from recording_minutes import router as _minutes_router
|
||||
|
||||
router = APIRouter(prefix="/api/recordings", tags=["Recordings"])
|
||||
|
||||
# ==========================================
|
||||
# ENVIRONMENT CONFIGURATION
|
||||
# ==========================================
|
||||
|
||||
MINIO_ENDPOINT = os.getenv("MINIO_ENDPOINT", "minio:9000")
|
||||
MINIO_ACCESS_KEY = os.getenv("MINIO_ACCESS_KEY", "breakpilot")
|
||||
MINIO_SECRET_KEY = os.getenv("MINIO_SECRET_KEY", "breakpilot123")
|
||||
MINIO_BUCKET = os.getenv("MINIO_BUCKET", "breakpilot-recordings")
|
||||
MINIO_SECURE = os.getenv("MINIO_SECURE", "false").lower() == "true"
|
||||
|
||||
# Default retention period in days (DSGVO compliance)
|
||||
DEFAULT_RETENTION_DAYS = int(os.getenv("RECORDING_RETENTION_DAYS", "365"))
|
||||
|
||||
|
||||
# ==========================================
|
||||
# PYDANTIC MODELS
|
||||
# ==========================================
|
||||
|
||||
class JibriWebhookPayload(BaseModel):
|
||||
"""Webhook payload from Jibri finalize.sh script."""
|
||||
event: str = Field(..., description="Event type: recording_completed")
|
||||
recording_name: str = Field(..., description="Unique recording identifier")
|
||||
storage_path: str = Field(..., description="Path in MinIO bucket")
|
||||
audio_path: Optional[str] = Field(None, description="Extracted audio path")
|
||||
file_size_bytes: int = Field(..., description="Video file size in bytes")
|
||||
timestamp: str = Field(..., description="ISO timestamp of upload")
|
||||
|
||||
|
||||
class RecordingCreate(BaseModel):
|
||||
"""Manual recording creation (for testing)."""
|
||||
meeting_id: str
|
||||
title: Optional[str] = None
|
||||
storage_path: str
|
||||
audio_path: Optional[str] = None
|
||||
duration_seconds: Optional[int] = None
|
||||
participant_count: Optional[int] = 0
|
||||
retention_days: Optional[int] = DEFAULT_RETENTION_DAYS
|
||||
|
||||
|
||||
class RecordingResponse(BaseModel):
|
||||
"""Recording details response."""
|
||||
id: str
|
||||
meeting_id: str
|
||||
title: Optional[str]
|
||||
storage_path: str
|
||||
audio_path: Optional[str]
|
||||
file_size_bytes: Optional[int]
|
||||
duration_seconds: Optional[int]
|
||||
participant_count: int
|
||||
status: str
|
||||
recorded_at: datetime
|
||||
retention_days: int
|
||||
retention_expires_at: datetime
|
||||
transcription_status: Optional[str] = None
|
||||
transcription_id: Optional[str] = None
|
||||
|
||||
|
||||
class RecordingListResponse(BaseModel):
|
||||
"""Paginated list of recordings."""
|
||||
recordings: List[RecordingResponse]
|
||||
total: int
|
||||
page: int
|
||||
page_size: int
|
||||
|
||||
|
||||
class TranscriptionRequest(BaseModel):
|
||||
"""Request to start transcription."""
|
||||
language: str = Field(default="de", description="Language code: de, en, etc.")
|
||||
model: str = Field(default="large-v3", description="Whisper model to use")
|
||||
priority: int = Field(default=0, description="Queue priority (higher = sooner)")
|
||||
|
||||
|
||||
class TranscriptionStatusResponse(BaseModel):
|
||||
"""Transcription status and progress."""
|
||||
id: str
|
||||
recording_id: str
|
||||
status: str
|
||||
language: str
|
||||
model: str
|
||||
word_count: Optional[int]
|
||||
confidence_score: Optional[float]
|
||||
processing_duration_seconds: Optional[int]
|
||||
error_message: Optional[str]
|
||||
created_at: datetime
|
||||
completed_at: Optional[datetime]
|
||||
|
||||
|
||||
# ==========================================
|
||||
# IN-MEMORY STORAGE (Dev Mode)
|
||||
# ==========================================
|
||||
# In production, these would be database queries
|
||||
|
||||
_recordings_store: dict = {}
|
||||
_transcriptions_store: dict = {}
|
||||
_audit_log: list = []
|
||||
|
||||
|
||||
def log_audit(
|
||||
action: str,
|
||||
recording_id: Optional[str] = None,
|
||||
transcription_id: Optional[str] = None,
|
||||
user_id: Optional[str] = None,
|
||||
metadata: Optional[dict] = None
|
||||
):
|
||||
"""Log audit event for DSGVO compliance."""
|
||||
_audit_log.append({
|
||||
"id": str(uuid.uuid4()),
|
||||
"recording_id": recording_id,
|
||||
"transcription_id": transcription_id,
|
||||
"user_id": user_id,
|
||||
"action": action,
|
||||
"metadata": metadata or {},
|
||||
"created_at": datetime.utcnow().isoformat()
|
||||
})
|
||||
|
||||
|
||||
# ==========================================
|
||||
# WEBHOOK ENDPOINT (Jibri)
|
||||
# ==========================================
|
||||
|
||||
@router.post("/webhook")
|
||||
async def jibri_webhook(payload: JibriWebhookPayload, request: Request):
|
||||
"""
|
||||
Webhook endpoint called by Jibri finalize.sh after upload.
|
||||
|
||||
This creates a new recording entry and optionally triggers transcription.
|
||||
"""
|
||||
if payload.event != "recording_completed":
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content={"error": f"Unknown event type: {payload.event}"}
|
||||
)
|
||||
|
||||
# Extract meeting_id from recording_name (format: meetingId_timestamp)
|
||||
parts = payload.recording_name.split("_")
|
||||
meeting_id = parts[0] if parts else payload.recording_name
|
||||
|
||||
# Create recording entry
|
||||
recording_id = str(uuid.uuid4())
|
||||
recorded_at = datetime.utcnow()
|
||||
|
||||
recording = {
|
||||
"id": recording_id,
|
||||
"meeting_id": meeting_id,
|
||||
"jibri_session_id": payload.recording_name,
|
||||
"title": f"Recording {meeting_id}",
|
||||
"storage_path": payload.storage_path,
|
||||
"audio_path": payload.audio_path,
|
||||
"file_size_bytes": payload.file_size_bytes,
|
||||
"duration_seconds": None, # Will be updated after analysis
|
||||
"participant_count": 0,
|
||||
"status": "uploaded",
|
||||
"recorded_at": recorded_at.isoformat(),
|
||||
"retention_days": DEFAULT_RETENTION_DAYS,
|
||||
"created_at": datetime.utcnow().isoformat(),
|
||||
"updated_at": datetime.utcnow().isoformat()
|
||||
}
|
||||
|
||||
_recordings_store[recording_id] = recording
|
||||
|
||||
# Log the creation
|
||||
log_audit(
|
||||
action="created",
|
||||
recording_id=recording_id,
|
||||
metadata={
|
||||
"source": "jibri_webhook",
|
||||
"storage_path": payload.storage_path,
|
||||
"file_size_bytes": payload.file_size_bytes
|
||||
}
|
||||
)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"recording_id": recording_id,
|
||||
"meeting_id": meeting_id,
|
||||
"status": "uploaded",
|
||||
"message": "Recording registered successfully"
|
||||
}
|
||||
|
||||
|
||||
# ==========================================
|
||||
# HEALTH & AUDIT ENDPOINTS (must be before parameterized routes)
|
||||
# ==========================================
|
||||
|
||||
@router.get("/health")
|
||||
async def recordings_health():
|
||||
"""Health check for recording service."""
|
||||
return {
|
||||
"status": "healthy",
|
||||
"recordings_count": len(_recordings_store),
|
||||
"transcriptions_count": len(_transcriptions_store),
|
||||
"minio_endpoint": MINIO_ENDPOINT,
|
||||
"bucket": MINIO_BUCKET
|
||||
}
|
||||
|
||||
|
||||
@router.get("/audit/log")
|
||||
async def get_audit_log(
|
||||
recording_id: Optional[str] = Query(None),
|
||||
action: Optional[str] = Query(None),
|
||||
limit: int = Query(100, ge=1, le=1000)
|
||||
):
|
||||
"""
|
||||
Get audit log entries (DSGVO compliance).
|
||||
|
||||
Admin-only endpoint for reviewing recording access history.
|
||||
"""
|
||||
logs = _audit_log.copy()
|
||||
|
||||
if recording_id:
|
||||
logs = [l for l in logs if l.get("recording_id") == recording_id]
|
||||
if action:
|
||||
logs = [l for l in logs if l.get("action") == action]
|
||||
|
||||
# Sort by created_at descending
|
||||
logs.sort(key=lambda x: x["created_at"], reverse=True)
|
||||
|
||||
return {
|
||||
"entries": logs[:limit],
|
||||
"total": len(logs)
|
||||
}
|
||||
|
||||
|
||||
# ==========================================
|
||||
# RECORDING MANAGEMENT ENDPOINTS
|
||||
# ==========================================
|
||||
|
||||
@router.get("", response_model=RecordingListResponse)
|
||||
async def list_recordings(
|
||||
status: Optional[str] = Query(None, description="Filter by status"),
|
||||
meeting_id: Optional[str] = Query(None, description="Filter by meeting ID"),
|
||||
page: int = Query(1, ge=1, description="Page number"),
|
||||
page_size: int = Query(20, ge=1, le=100, description="Items per page")
|
||||
):
|
||||
"""
|
||||
List all recordings with optional filtering.
|
||||
|
||||
Supports pagination and filtering by status or meeting ID.
|
||||
"""
|
||||
# Filter recordings
|
||||
recordings = list(_recordings_store.values())
|
||||
|
||||
if status:
|
||||
recordings = [r for r in recordings if r["status"] == status]
|
||||
if meeting_id:
|
||||
recordings = [r for r in recordings if r["meeting_id"] == meeting_id]
|
||||
|
||||
# Sort by recorded_at descending
|
||||
recordings.sort(key=lambda x: x["recorded_at"], reverse=True)
|
||||
|
||||
# Paginate
|
||||
total = len(recordings)
|
||||
start = (page - 1) * page_size
|
||||
end = start + page_size
|
||||
page_recordings = recordings[start:end]
|
||||
|
||||
# Convert to response format
|
||||
result = []
|
||||
for rec in page_recordings:
|
||||
recorded_at = datetime.fromisoformat(rec["recorded_at"])
|
||||
retention_expires = recorded_at + timedelta(days=rec["retention_days"])
|
||||
|
||||
# Check for transcription
|
||||
trans = next(
|
||||
(t for t in _transcriptions_store.values() if t["recording_id"] == rec["id"]),
|
||||
None
|
||||
)
|
||||
|
||||
result.append(RecordingResponse(
|
||||
id=rec["id"],
|
||||
meeting_id=rec["meeting_id"],
|
||||
title=rec.get("title"),
|
||||
storage_path=rec["storage_path"],
|
||||
audio_path=rec.get("audio_path"),
|
||||
file_size_bytes=rec.get("file_size_bytes"),
|
||||
duration_seconds=rec.get("duration_seconds"),
|
||||
participant_count=rec.get("participant_count", 0),
|
||||
status=rec["status"],
|
||||
recorded_at=recorded_at,
|
||||
retention_days=rec["retention_days"],
|
||||
retention_expires_at=retention_expires,
|
||||
transcription_status=trans["status"] if trans else None,
|
||||
transcription_id=trans["id"] if trans else None
|
||||
))
|
||||
|
||||
return RecordingListResponse(
|
||||
recordings=result,
|
||||
total=total,
|
||||
page=page,
|
||||
page_size=page_size
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{recording_id}", response_model=RecordingResponse)
|
||||
async def get_recording(recording_id: str):
|
||||
"""
|
||||
Get details for a specific recording.
|
||||
"""
|
||||
recording = _recordings_store.get(recording_id)
|
||||
if not recording:
|
||||
raise HTTPException(status_code=404, detail="Recording not found")
|
||||
|
||||
# Log view action
|
||||
log_audit(action="viewed", recording_id=recording_id)
|
||||
|
||||
recorded_at = datetime.fromisoformat(recording["recorded_at"])
|
||||
retention_expires = recorded_at + timedelta(days=recording["retention_days"])
|
||||
|
||||
# Check for transcription
|
||||
trans = next(
|
||||
(t for t in _transcriptions_store.values() if t["recording_id"] == recording_id),
|
||||
None
|
||||
)
|
||||
|
||||
return RecordingResponse(
|
||||
id=recording["id"],
|
||||
meeting_id=recording["meeting_id"],
|
||||
title=recording.get("title"),
|
||||
storage_path=recording["storage_path"],
|
||||
audio_path=recording.get("audio_path"),
|
||||
file_size_bytes=recording.get("file_size_bytes"),
|
||||
duration_seconds=recording.get("duration_seconds"),
|
||||
participant_count=recording.get("participant_count", 0),
|
||||
status=recording["status"],
|
||||
recorded_at=recorded_at,
|
||||
retention_days=recording["retention_days"],
|
||||
retention_expires_at=retention_expires,
|
||||
transcription_status=trans["status"] if trans else None,
|
||||
transcription_id=trans["id"] if trans else None
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/{recording_id}")
|
||||
async def delete_recording(
|
||||
recording_id: str,
|
||||
reason: str = Query(..., description="Reason for deletion (DSGVO audit)")
|
||||
):
|
||||
"""
|
||||
Soft-delete a recording (DSGVO compliance).
|
||||
|
||||
The recording is marked as deleted but retained for audit purposes.
|
||||
Actual file deletion happens after the audit retention period.
|
||||
"""
|
||||
recording = _recordings_store.get(recording_id)
|
||||
if not recording:
|
||||
raise HTTPException(status_code=404, detail="Recording not found")
|
||||
|
||||
# Soft delete
|
||||
recording["status"] = "deleted"
|
||||
recording["deleted_at"] = datetime.utcnow().isoformat()
|
||||
recording["updated_at"] = datetime.utcnow().isoformat()
|
||||
|
||||
# Log deletion with reason
|
||||
log_audit(
|
||||
action="deleted",
|
||||
recording_id=recording_id,
|
||||
metadata={"reason": reason}
|
||||
)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"recording_id": recording_id,
|
||||
"status": "deleted",
|
||||
"message": "Recording marked for deletion"
|
||||
}
|
||||
|
||||
|
||||
# ==========================================
|
||||
# TRANSCRIPTION ENDPOINTS
|
||||
# ==========================================
|
||||
|
||||
@router.post("/{recording_id}/transcribe", response_model=TranscriptionStatusResponse)
|
||||
async def start_transcription(recording_id: str, request: TranscriptionRequest):
|
||||
"""
|
||||
Start transcription for a recording.
|
||||
|
||||
Queues the recording for processing by the transcription worker.
|
||||
"""
|
||||
recording = _recordings_store.get(recording_id)
|
||||
if not recording:
|
||||
raise HTTPException(status_code=404, detail="Recording not found")
|
||||
|
||||
if recording["status"] == "deleted":
|
||||
raise HTTPException(status_code=400, detail="Cannot transcribe deleted recording")
|
||||
|
||||
# Check if transcription already exists
|
||||
existing = next(
|
||||
(t for t in _transcriptions_store.values()
|
||||
if t["recording_id"] == recording_id and t["status"] != "failed"),
|
||||
None
|
||||
)
|
||||
if existing:
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail=f"Transcription already exists with status: {existing['status']}"
|
||||
)
|
||||
|
||||
# Create transcription entry
|
||||
transcription_id = str(uuid.uuid4())
|
||||
now = datetime.utcnow()
|
||||
|
||||
transcription = {
|
||||
"id": transcription_id,
|
||||
"recording_id": recording_id,
|
||||
"language": request.language,
|
||||
"model": request.model,
|
||||
"status": "pending",
|
||||
"full_text": None,
|
||||
"word_count": None,
|
||||
"confidence_score": None,
|
||||
"vtt_path": None,
|
||||
"srt_path": None,
|
||||
"json_path": None,
|
||||
"error_message": None,
|
||||
"processing_started_at": None,
|
||||
"processing_completed_at": None,
|
||||
"processing_duration_seconds": None,
|
||||
"created_at": now.isoformat(),
|
||||
"updated_at": now.isoformat()
|
||||
}
|
||||
|
||||
_transcriptions_store[transcription_id] = transcription
|
||||
|
||||
# Update recording status
|
||||
recording["status"] = "processing"
|
||||
recording["updated_at"] = now.isoformat()
|
||||
|
||||
# Log transcription start
|
||||
log_audit(
|
||||
action="transcription_started",
|
||||
recording_id=recording_id,
|
||||
transcription_id=transcription_id,
|
||||
metadata={"language": request.language, "model": request.model}
|
||||
)
|
||||
|
||||
# TODO: Queue job to Redis/Valkey for transcription worker
|
||||
# from redis import Redis
|
||||
# from rq import Queue
|
||||
# q = Queue(connection=Redis.from_url(os.getenv("REDIS_URL")))
|
||||
# q.enqueue('transcription_worker.tasks.transcribe', transcription_id, ...)
|
||||
|
||||
return TranscriptionStatusResponse(
|
||||
id=transcription_id,
|
||||
recording_id=recording_id,
|
||||
status="pending",
|
||||
language=request.language,
|
||||
model=request.model,
|
||||
word_count=None,
|
||||
confidence_score=None,
|
||||
processing_duration_seconds=None,
|
||||
error_message=None,
|
||||
created_at=now,
|
||||
completed_at=None
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{recording_id}/transcription", response_model=TranscriptionStatusResponse)
|
||||
async def get_transcription_status(recording_id: str):
|
||||
"""
|
||||
Get transcription status for a recording.
|
||||
"""
|
||||
transcription = next(
|
||||
(t for t in _transcriptions_store.values() if t["recording_id"] == recording_id),
|
||||
None
|
||||
)
|
||||
if not transcription:
|
||||
raise HTTPException(status_code=404, detail="No transcription found for this recording")
|
||||
|
||||
return TranscriptionStatusResponse(
|
||||
id=transcription["id"],
|
||||
recording_id=transcription["recording_id"],
|
||||
status=transcription["status"],
|
||||
language=transcription["language"],
|
||||
model=transcription["model"],
|
||||
word_count=transcription.get("word_count"),
|
||||
confidence_score=transcription.get("confidence_score"),
|
||||
processing_duration_seconds=transcription.get("processing_duration_seconds"),
|
||||
error_message=transcription.get("error_message"),
|
||||
created_at=datetime.fromisoformat(transcription["created_at"]),
|
||||
completed_at=(
|
||||
datetime.fromisoformat(transcription["processing_completed_at"])
|
||||
if transcription.get("processing_completed_at") else None
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{recording_id}/transcription/text")
|
||||
async def get_transcription_text(recording_id: str):
|
||||
"""
|
||||
Get the full transcription text.
|
||||
"""
|
||||
transcription = next(
|
||||
(t for t in _transcriptions_store.values() if t["recording_id"] == recording_id),
|
||||
None
|
||||
)
|
||||
if not transcription:
|
||||
raise HTTPException(status_code=404, detail="No transcription found for this recording")
|
||||
|
||||
if transcription["status"] != "completed":
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Transcription not ready. Status: {transcription['status']}"
|
||||
)
|
||||
|
||||
return {
|
||||
"transcription_id": transcription["id"],
|
||||
"recording_id": recording_id,
|
||||
"language": transcription["language"],
|
||||
"text": transcription.get("full_text", ""),
|
||||
"word_count": transcription.get("word_count", 0)
|
||||
}
|
||||
|
||||
|
||||
@router.get("/{recording_id}/transcription/vtt")
|
||||
async def get_transcription_vtt(recording_id: str):
|
||||
"""
|
||||
Download transcription as WebVTT subtitle file.
|
||||
"""
|
||||
from fastapi.responses import PlainTextResponse
|
||||
|
||||
transcription = next(
|
||||
(t for t in _transcriptions_store.values() if t["recording_id"] == recording_id),
|
||||
None
|
||||
)
|
||||
if not transcription:
|
||||
raise HTTPException(status_code=404, detail="No transcription found for this recording")
|
||||
|
||||
if transcription["status"] != "completed":
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Transcription not ready. Status: {transcription['status']}"
|
||||
)
|
||||
|
||||
# Generate VTT content
|
||||
# In production, this would read from the stored VTT file
|
||||
vtt_content = "WEBVTT\n\n"
|
||||
text = transcription.get("full_text", "")
|
||||
|
||||
if text:
|
||||
# Simple VTT generation - split into sentences
|
||||
sentences = text.replace(".", ".\n").split("\n")
|
||||
time_offset = 0
|
||||
for sentence in sentences:
|
||||
sentence = sentence.strip()
|
||||
if sentence:
|
||||
start = format_vtt_time(time_offset)
|
||||
# Estimate ~3 seconds per sentence
|
||||
time_offset += 3000
|
||||
end = format_vtt_time(time_offset)
|
||||
vtt_content += f"{start} --> {end}\n{sentence}\n\n"
|
||||
|
||||
return PlainTextResponse(
|
||||
content=vtt_content,
|
||||
media_type="text/vtt",
|
||||
headers={"Content-Disposition": f"attachment; filename=transcript_{recording_id}.vtt"}
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{recording_id}/transcription/srt")
|
||||
async def get_transcription_srt(recording_id: str):
|
||||
"""
|
||||
Download transcription as SRT subtitle file.
|
||||
"""
|
||||
from fastapi.responses import PlainTextResponse
|
||||
|
||||
transcription = next(
|
||||
(t for t in _transcriptions_store.values() if t["recording_id"] == recording_id),
|
||||
None
|
||||
)
|
||||
if not transcription:
|
||||
raise HTTPException(status_code=404, detail="No transcription found for this recording")
|
||||
|
||||
if transcription["status"] != "completed":
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Transcription not ready. Status: {transcription['status']}"
|
||||
)
|
||||
|
||||
# Generate SRT content
|
||||
srt_content = ""
|
||||
text = transcription.get("full_text", "")
|
||||
|
||||
if text:
|
||||
sentences = text.replace(".", ".\n").split("\n")
|
||||
time_offset = 0
|
||||
index = 1
|
||||
for sentence in sentences:
|
||||
sentence = sentence.strip()
|
||||
if sentence:
|
||||
start = format_srt_time(time_offset)
|
||||
time_offset += 3000
|
||||
end = format_srt_time(time_offset)
|
||||
srt_content += f"{index}\n{start} --> {end}\n{sentence}\n\n"
|
||||
index += 1
|
||||
|
||||
return PlainTextResponse(
|
||||
content=srt_content,
|
||||
media_type="text/plain",
|
||||
headers={"Content-Disposition": f"attachment; filename=transcript_{recording_id}.srt"}
|
||||
)
|
||||
|
||||
|
||||
def format_vtt_time(ms: int) -> str:
|
||||
"""Format milliseconds to VTT timestamp (HH:MM:SS.mmm)."""
|
||||
hours = ms // 3600000
|
||||
minutes = (ms % 3600000) // 60000
|
||||
seconds = (ms % 60000) // 1000
|
||||
millis = ms % 1000
|
||||
return f"{hours:02d}:{minutes:02d}:{seconds:02d}.{millis:03d}"
|
||||
|
||||
|
||||
def format_srt_time(ms: int) -> str:
|
||||
"""Format milliseconds to SRT timestamp (HH:MM:SS,mmm)."""
|
||||
hours = ms // 3600000
|
||||
minutes = (ms % 3600000) // 60000
|
||||
seconds = (ms % 60000) // 1000
|
||||
millis = ms % 1000
|
||||
return f"{hours:02d}:{minutes:02d}:{seconds:02d},{millis:03d}"
|
||||
|
||||
|
||||
@router.get("/{recording_id}/download")
|
||||
async def download_recording(recording_id: str):
|
||||
"""
|
||||
Download the recording file.
|
||||
|
||||
In production, this would generate a presigned URL to MinIO.
|
||||
"""
|
||||
recording = _recordings_store.get(recording_id)
|
||||
if not recording:
|
||||
raise HTTPException(status_code=404, detail="Recording not found")
|
||||
|
||||
if recording["status"] == "deleted":
|
||||
raise HTTPException(status_code=410, detail="Recording has been deleted")
|
||||
|
||||
# Log download action
|
||||
log_audit(action="downloaded", recording_id=recording_id)
|
||||
|
||||
# In production, generate presigned URL to MinIO
|
||||
# For now, return info about where the file is
|
||||
return {
|
||||
"recording_id": recording_id,
|
||||
"storage_path": recording["storage_path"],
|
||||
"file_size_bytes": recording.get("file_size_bytes"),
|
||||
"message": "In production, this would redirect to a presigned MinIO URL"
|
||||
}
|
||||
|
||||
|
||||
# ==========================================
|
||||
# MEETING MINUTES ENDPOINTS
|
||||
# ==========================================
|
||||
|
||||
# In-memory store for meeting minutes (dev mode)
|
||||
_minutes_store: dict = {}
|
||||
|
||||
|
||||
@router.post("/{recording_id}/minutes")
|
||||
async def generate_meeting_minutes(
|
||||
recording_id: str,
|
||||
title: Optional[str] = Query(None, description="Meeting-Titel"),
|
||||
model: str = Query("breakpilot-teacher-8b", description="LLM Modell")
|
||||
):
|
||||
"""
|
||||
Generiert KI-basierte Meeting Minutes aus der Transkription.
|
||||
|
||||
Nutzt das LLM Gateway (Ollama/vLLM) fuer lokale Verarbeitung.
|
||||
"""
|
||||
from meeting_minutes_generator import get_minutes_generator, MeetingMinutes
|
||||
|
||||
# Check recording exists
|
||||
recording = _recordings_store.get(recording_id)
|
||||
if not recording:
|
||||
raise HTTPException(status_code=404, detail="Recording not found")
|
||||
|
||||
# Check transcription exists and is completed
|
||||
transcription = next(
|
||||
(t for t in _transcriptions_store.values() if t["recording_id"] == recording_id),
|
||||
None
|
||||
)
|
||||
if not transcription:
|
||||
raise HTTPException(status_code=400, detail="No transcription found. Please transcribe first.")
|
||||
|
||||
if transcription["status"] != "completed":
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Transcription not ready. Status: {transcription['status']}"
|
||||
)
|
||||
|
||||
# Check if minutes already exist
|
||||
existing = _minutes_store.get(recording_id)
|
||||
if existing and existing.get("status") == "completed":
|
||||
# Return existing minutes
|
||||
return existing
|
||||
|
||||
# Get transcript text
|
||||
transcript_text = transcription.get("full_text", "")
|
||||
if not transcript_text:
|
||||
raise HTTPException(status_code=400, detail="Transcription has no text content")
|
||||
|
||||
# Generate meeting minutes
|
||||
generator = get_minutes_generator()
|
||||
|
||||
try:
|
||||
minutes = await generator.generate(
|
||||
transcript=transcript_text,
|
||||
recording_id=recording_id,
|
||||
transcription_id=transcription["id"],
|
||||
title=title,
|
||||
date=recording.get("recorded_at", "")[:10] if recording.get("recorded_at") else None,
|
||||
duration_minutes=recording.get("duration_seconds", 0) // 60 if recording.get("duration_seconds") else None,
|
||||
participant_count=recording.get("participant_count", 0),
|
||||
model=model
|
||||
)
|
||||
|
||||
# Store minutes
|
||||
minutes_dict = minutes.model_dump()
|
||||
minutes_dict["generated_at"] = minutes.generated_at.isoformat()
|
||||
_minutes_store[recording_id] = minutes_dict
|
||||
|
||||
# Log action
|
||||
log_audit(
|
||||
action="minutes_generated",
|
||||
recording_id=recording_id,
|
||||
metadata={"model": model, "generation_time": minutes.generation_time_seconds}
|
||||
)
|
||||
|
||||
return minutes_dict
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Minutes generation failed: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/{recording_id}/minutes")
|
||||
async def get_meeting_minutes(recording_id: str):
|
||||
"""
|
||||
Ruft generierte Meeting Minutes ab.
|
||||
"""
|
||||
minutes = _minutes_store.get(recording_id)
|
||||
if not minutes:
|
||||
raise HTTPException(status_code=404, detail="No meeting minutes found. Generate them first with POST.")
|
||||
|
||||
return minutes
|
||||
|
||||
|
||||
@router.get("/{recording_id}/minutes/markdown")
|
||||
async def get_minutes_markdown(recording_id: str):
|
||||
"""
|
||||
Exportiert Meeting Minutes als Markdown.
|
||||
"""
|
||||
from fastapi.responses import PlainTextResponse
|
||||
from meeting_minutes_generator import minutes_to_markdown, MeetingMinutes
|
||||
|
||||
minutes_dict = _minutes_store.get(recording_id)
|
||||
if not minutes_dict:
|
||||
raise HTTPException(status_code=404, detail="No meeting minutes found")
|
||||
|
||||
# Convert dict back to MeetingMinutes
|
||||
minutes_dict_copy = minutes_dict.copy()
|
||||
if isinstance(minutes_dict_copy.get("generated_at"), str):
|
||||
minutes_dict_copy["generated_at"] = datetime.fromisoformat(minutes_dict_copy["generated_at"])
|
||||
|
||||
minutes = MeetingMinutes(**minutes_dict_copy)
|
||||
markdown = minutes_to_markdown(minutes)
|
||||
|
||||
return PlainTextResponse(
|
||||
content=markdown,
|
||||
media_type="text/markdown",
|
||||
headers={"Content-Disposition": f"attachment; filename=protokoll_{recording_id}.md"}
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{recording_id}/minutes/html")
|
||||
async def get_minutes_html(recording_id: str):
|
||||
"""
|
||||
Exportiert Meeting Minutes als HTML.
|
||||
"""
|
||||
from fastapi.responses import HTMLResponse
|
||||
from meeting_minutes_generator import minutes_to_html, MeetingMinutes
|
||||
|
||||
minutes_dict = _minutes_store.get(recording_id)
|
||||
if not minutes_dict:
|
||||
raise HTTPException(status_code=404, detail="No meeting minutes found")
|
||||
|
||||
# Convert dict back to MeetingMinutes
|
||||
minutes_dict_copy = minutes_dict.copy()
|
||||
if isinstance(minutes_dict_copy.get("generated_at"), str):
|
||||
minutes_dict_copy["generated_at"] = datetime.fromisoformat(minutes_dict_copy["generated_at"])
|
||||
|
||||
minutes = MeetingMinutes(**minutes_dict_copy)
|
||||
html = minutes_to_html(minutes)
|
||||
|
||||
return HTMLResponse(content=html)
|
||||
|
||||
|
||||
@router.get("/{recording_id}/minutes/pdf")
|
||||
async def get_minutes_pdf(recording_id: str):
|
||||
"""
|
||||
Exportiert Meeting Minutes als PDF.
|
||||
|
||||
Benoetigt WeasyPrint (pip install weasyprint).
|
||||
"""
|
||||
from meeting_minutes_generator import minutes_to_html, MeetingMinutes
|
||||
|
||||
minutes_dict = _minutes_store.get(recording_id)
|
||||
if not minutes_dict:
|
||||
raise HTTPException(status_code=404, detail="No meeting minutes found")
|
||||
|
||||
# Convert dict back to MeetingMinutes
|
||||
minutes_dict_copy = minutes_dict.copy()
|
||||
if isinstance(minutes_dict_copy.get("generated_at"), str):
|
||||
minutes_dict_copy["generated_at"] = datetime.fromisoformat(minutes_dict_copy["generated_at"])
|
||||
|
||||
minutes = MeetingMinutes(**minutes_dict_copy)
|
||||
html = minutes_to_html(minutes)
|
||||
|
||||
try:
|
||||
from weasyprint import HTML
|
||||
from fastapi.responses import Response
|
||||
|
||||
pdf_bytes = HTML(string=html).write_pdf()
|
||||
|
||||
return Response(
|
||||
content=pdf_bytes,
|
||||
media_type="application/pdf",
|
||||
headers={"Content-Disposition": f"attachment; filename=protokoll_{recording_id}.pdf"}
|
||||
)
|
||||
except ImportError:
|
||||
raise HTTPException(
|
||||
status_code=501,
|
||||
detail="PDF export not available. Install weasyprint: pip install weasyprint"
|
||||
)
|
||||
router.include_router(_routes_router)
|
||||
router.include_router(_transcription_router)
|
||||
router.include_router(_minutes_router)
|
||||
|
||||
@@ -0,0 +1,57 @@
|
||||
"""
|
||||
Recording API - In-Memory Storage & Helpers.
|
||||
|
||||
Shared state and utility functions for recording endpoints.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
|
||||
# ==========================================
|
||||
# IN-MEMORY STORAGE (Dev Mode)
|
||||
# ==========================================
|
||||
# In production, these would be database queries
|
||||
|
||||
_recordings_store: dict = {}
|
||||
_transcriptions_store: dict = {}
|
||||
_audit_log: list = []
|
||||
_minutes_store: dict = {}
|
||||
|
||||
|
||||
def log_audit(
|
||||
action: str,
|
||||
recording_id: Optional[str] = None,
|
||||
transcription_id: Optional[str] = None,
|
||||
user_id: Optional[str] = None,
|
||||
metadata: Optional[dict] = None
|
||||
):
|
||||
"""Log audit event for DSGVO compliance."""
|
||||
_audit_log.append({
|
||||
"id": str(uuid.uuid4()),
|
||||
"recording_id": recording_id,
|
||||
"transcription_id": transcription_id,
|
||||
"user_id": user_id,
|
||||
"action": action,
|
||||
"metadata": metadata or {},
|
||||
"created_at": datetime.utcnow().isoformat()
|
||||
})
|
||||
|
||||
|
||||
def format_vtt_time(ms: int) -> str:
|
||||
"""Format milliseconds to VTT timestamp (HH:MM:SS.mmm)."""
|
||||
hours = ms // 3600000
|
||||
minutes = (ms % 3600000) // 60000
|
||||
seconds = (ms % 60000) // 1000
|
||||
millis = ms % 1000
|
||||
return f"{hours:02d}:{minutes:02d}:{seconds:02d}.{millis:03d}"
|
||||
|
||||
|
||||
def format_srt_time(ms: int) -> str:
|
||||
"""Format milliseconds to SRT timestamp (HH:MM:SS,mmm)."""
|
||||
hours = ms // 3600000
|
||||
minutes = (ms % 3600000) // 60000
|
||||
seconds = (ms % 60000) // 1000
|
||||
millis = ms % 1000
|
||||
return f"{hours:02d}:{minutes:02d}:{seconds:02d},{millis:03d}"
|
||||
@@ -0,0 +1,187 @@
|
||||
"""
|
||||
Recording API - Meeting Minutes Routes.
|
||||
|
||||
Generate, retrieve, and export KI-based meeting minutes.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Query
|
||||
from fastapi.responses import PlainTextResponse, HTMLResponse
|
||||
|
||||
from recording_helpers import (
|
||||
_recordings_store,
|
||||
_transcriptions_store,
|
||||
_minutes_store,
|
||||
log_audit,
|
||||
)
|
||||
|
||||
router = APIRouter(tags=["Recordings"])
|
||||
|
||||
|
||||
# ==========================================
|
||||
# MEETING MINUTES ENDPOINTS
|
||||
# ==========================================
|
||||
|
||||
@router.post("/{recording_id}/minutes")
|
||||
async def generate_meeting_minutes(
|
||||
recording_id: str,
|
||||
title: Optional[str] = Query(None, description="Meeting-Titel"),
|
||||
model: str = Query("breakpilot-teacher-8b", description="LLM Modell")
|
||||
):
|
||||
"""
|
||||
Generiert KI-basierte Meeting Minutes aus der Transkription.
|
||||
|
||||
Nutzt das LLM Gateway (Ollama/vLLM) fuer lokale Verarbeitung.
|
||||
"""
|
||||
from meeting_minutes_generator import get_minutes_generator, MeetingMinutes
|
||||
|
||||
# Check recording exists
|
||||
recording = _recordings_store.get(recording_id)
|
||||
if not recording:
|
||||
raise HTTPException(status_code=404, detail="Recording not found")
|
||||
|
||||
# Check transcription exists and is completed
|
||||
transcription = next(
|
||||
(t for t in _transcriptions_store.values() if t["recording_id"] == recording_id),
|
||||
None
|
||||
)
|
||||
if not transcription:
|
||||
raise HTTPException(status_code=400, detail="No transcription found. Please transcribe first.")
|
||||
|
||||
if transcription["status"] != "completed":
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Transcription not ready. Status: {transcription['status']}"
|
||||
)
|
||||
|
||||
# Check if minutes already exist
|
||||
existing = _minutes_store.get(recording_id)
|
||||
if existing and existing.get("status") == "completed":
|
||||
# Return existing minutes
|
||||
return existing
|
||||
|
||||
# Get transcript text
|
||||
transcript_text = transcription.get("full_text", "")
|
||||
if not transcript_text:
|
||||
raise HTTPException(status_code=400, detail="Transcription has no text content")
|
||||
|
||||
# Generate meeting minutes
|
||||
generator = get_minutes_generator()
|
||||
|
||||
try:
|
||||
minutes = await generator.generate(
|
||||
transcript=transcript_text,
|
||||
recording_id=recording_id,
|
||||
transcription_id=transcription["id"],
|
||||
title=title,
|
||||
date=recording.get("recorded_at", "")[:10] if recording.get("recorded_at") else None,
|
||||
duration_minutes=recording.get("duration_seconds", 0) // 60 if recording.get("duration_seconds") else None,
|
||||
participant_count=recording.get("participant_count", 0),
|
||||
model=model
|
||||
)
|
||||
|
||||
# Store minutes
|
||||
minutes_dict = minutes.model_dump()
|
||||
minutes_dict["generated_at"] = minutes.generated_at.isoformat()
|
||||
_minutes_store[recording_id] = minutes_dict
|
||||
|
||||
# Log action
|
||||
log_audit(
|
||||
action="minutes_generated",
|
||||
recording_id=recording_id,
|
||||
metadata={"model": model, "generation_time": minutes.generation_time_seconds}
|
||||
)
|
||||
|
||||
return minutes_dict
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Minutes generation failed: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/{recording_id}/minutes")
|
||||
async def get_meeting_minutes(recording_id: str):
|
||||
"""
|
||||
Ruft generierte Meeting Minutes ab.
|
||||
"""
|
||||
minutes = _minutes_store.get(recording_id)
|
||||
if not minutes:
|
||||
raise HTTPException(status_code=404, detail="No meeting minutes found. Generate them first with POST.")
|
||||
|
||||
return minutes
|
||||
|
||||
|
||||
def _load_minutes(recording_id: str):
|
||||
"""Load and convert stored minutes dict back to MeetingMinutes."""
|
||||
from meeting_minutes_generator import MeetingMinutes
|
||||
|
||||
minutes_dict = _minutes_store.get(recording_id)
|
||||
if not minutes_dict:
|
||||
raise HTTPException(status_code=404, detail="No meeting minutes found")
|
||||
|
||||
minutes_dict_copy = minutes_dict.copy()
|
||||
if isinstance(minutes_dict_copy.get("generated_at"), str):
|
||||
minutes_dict_copy["generated_at"] = datetime.fromisoformat(minutes_dict_copy["generated_at"])
|
||||
|
||||
return MeetingMinutes(**minutes_dict_copy)
|
||||
|
||||
|
||||
@router.get("/{recording_id}/minutes/markdown")
|
||||
async def get_minutes_markdown(recording_id: str):
|
||||
"""
|
||||
Exportiert Meeting Minutes als Markdown.
|
||||
"""
|
||||
from meeting_minutes_generator import minutes_to_markdown
|
||||
|
||||
minutes = _load_minutes(recording_id)
|
||||
markdown = minutes_to_markdown(minutes)
|
||||
|
||||
return PlainTextResponse(
|
||||
content=markdown,
|
||||
media_type="text/markdown",
|
||||
headers={"Content-Disposition": f"attachment; filename=protokoll_{recording_id}.md"}
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{recording_id}/minutes/html")
|
||||
async def get_minutes_html(recording_id: str):
|
||||
"""
|
||||
Exportiert Meeting Minutes als HTML.
|
||||
"""
|
||||
from meeting_minutes_generator import minutes_to_html
|
||||
|
||||
minutes = _load_minutes(recording_id)
|
||||
html = minutes_to_html(minutes)
|
||||
|
||||
return HTMLResponse(content=html)
|
||||
|
||||
|
||||
@router.get("/{recording_id}/minutes/pdf")
|
||||
async def get_minutes_pdf(recording_id: str):
|
||||
"""
|
||||
Exportiert Meeting Minutes als PDF.
|
||||
|
||||
Benoetigt WeasyPrint (pip install weasyprint).
|
||||
"""
|
||||
from meeting_minutes_generator import minutes_to_html
|
||||
|
||||
minutes = _load_minutes(recording_id)
|
||||
html = minutes_to_html(minutes)
|
||||
|
||||
try:
|
||||
from weasyprint import HTML
|
||||
from fastapi.responses import Response
|
||||
|
||||
pdf_bytes = HTML(string=html).write_pdf()
|
||||
|
||||
return Response(
|
||||
content=pdf_bytes,
|
||||
media_type="application/pdf",
|
||||
headers={"Content-Disposition": f"attachment; filename=protokoll_{recording_id}.pdf"}
|
||||
)
|
||||
except ImportError:
|
||||
raise HTTPException(
|
||||
status_code=501,
|
||||
detail="PDF export not available. Install weasyprint: pip install weasyprint"
|
||||
)
|
||||
@@ -0,0 +1,98 @@
|
||||
"""
|
||||
Recording API - Pydantic Models & Configuration.
|
||||
|
||||
Data models for recording, transcription, and webhook endpoints.
|
||||
"""
|
||||
|
||||
import os
|
||||
from datetime import datetime
|
||||
from typing import Optional, List
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
# ==========================================
|
||||
# ENVIRONMENT CONFIGURATION
|
||||
# ==========================================
|
||||
|
||||
MINIO_ENDPOINT = os.getenv("MINIO_ENDPOINT", "minio:9000")
|
||||
MINIO_ACCESS_KEY = os.getenv("MINIO_ACCESS_KEY", "breakpilot")
|
||||
MINIO_SECRET_KEY = os.getenv("MINIO_SECRET_KEY", "breakpilot123")
|
||||
MINIO_BUCKET = os.getenv("MINIO_BUCKET", "breakpilot-recordings")
|
||||
MINIO_SECURE = os.getenv("MINIO_SECURE", "false").lower() == "true"
|
||||
|
||||
# Default retention period in days (DSGVO compliance)
|
||||
DEFAULT_RETENTION_DAYS = int(os.getenv("RECORDING_RETENTION_DAYS", "365"))
|
||||
|
||||
|
||||
# ==========================================
|
||||
# PYDANTIC MODELS
|
||||
# ==========================================
|
||||
|
||||
class JibriWebhookPayload(BaseModel):
|
||||
"""Webhook payload from Jibri finalize.sh script."""
|
||||
event: str = Field(..., description="Event type: recording_completed")
|
||||
recording_name: str = Field(..., description="Unique recording identifier")
|
||||
storage_path: str = Field(..., description="Path in MinIO bucket")
|
||||
audio_path: Optional[str] = Field(None, description="Extracted audio path")
|
||||
file_size_bytes: int = Field(..., description="Video file size in bytes")
|
||||
timestamp: str = Field(..., description="ISO timestamp of upload")
|
||||
|
||||
|
||||
class RecordingCreate(BaseModel):
|
||||
"""Manual recording creation (for testing)."""
|
||||
meeting_id: str
|
||||
title: Optional[str] = None
|
||||
storage_path: str
|
||||
audio_path: Optional[str] = None
|
||||
duration_seconds: Optional[int] = None
|
||||
participant_count: Optional[int] = 0
|
||||
retention_days: Optional[int] = DEFAULT_RETENTION_DAYS
|
||||
|
||||
|
||||
class RecordingResponse(BaseModel):
|
||||
"""Recording details response."""
|
||||
id: str
|
||||
meeting_id: str
|
||||
title: Optional[str]
|
||||
storage_path: str
|
||||
audio_path: Optional[str]
|
||||
file_size_bytes: Optional[int]
|
||||
duration_seconds: Optional[int]
|
||||
participant_count: int
|
||||
status: str
|
||||
recorded_at: datetime
|
||||
retention_days: int
|
||||
retention_expires_at: datetime
|
||||
transcription_status: Optional[str] = None
|
||||
transcription_id: Optional[str] = None
|
||||
|
||||
|
||||
class RecordingListResponse(BaseModel):
|
||||
"""Paginated list of recordings."""
|
||||
recordings: List[RecordingResponse]
|
||||
total: int
|
||||
page: int
|
||||
page_size: int
|
||||
|
||||
|
||||
class TranscriptionRequest(BaseModel):
|
||||
"""Request to start transcription."""
|
||||
language: str = Field(default="de", description="Language code: de, en, etc.")
|
||||
model: str = Field(default="large-v3", description="Whisper model to use")
|
||||
priority: int = Field(default=0, description="Queue priority (higher = sooner)")
|
||||
|
||||
|
||||
class TranscriptionStatusResponse(BaseModel):
|
||||
"""Transcription status and progress."""
|
||||
id: str
|
||||
recording_id: str
|
||||
status: str
|
||||
language: str
|
||||
model: str
|
||||
word_count: Optional[int]
|
||||
confidence_score: Optional[float]
|
||||
processing_duration_seconds: Optional[int]
|
||||
error_message: Optional[str]
|
||||
created_at: datetime
|
||||
completed_at: Optional[datetime]
|
||||
@@ -0,0 +1,307 @@
|
||||
"""
|
||||
Recording API - Core Recording Routes.
|
||||
|
||||
Webhook, CRUD, health, audit, and download endpoints.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Query, Request
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from recording_models import (
|
||||
JibriWebhookPayload,
|
||||
RecordingResponse,
|
||||
RecordingListResponse,
|
||||
MINIO_ENDPOINT,
|
||||
MINIO_BUCKET,
|
||||
DEFAULT_RETENTION_DAYS,
|
||||
)
|
||||
from recording_helpers import (
|
||||
_recordings_store,
|
||||
_transcriptions_store,
|
||||
_audit_log,
|
||||
log_audit,
|
||||
)
|
||||
|
||||
router = APIRouter(tags=["Recordings"])
|
||||
|
||||
|
||||
# ==========================================
|
||||
# WEBHOOK ENDPOINT (Jibri)
|
||||
# ==========================================
|
||||
|
||||
@router.post("/webhook")
|
||||
async def jibri_webhook(payload: JibriWebhookPayload, request: Request):
|
||||
"""
|
||||
Webhook endpoint called by Jibri finalize.sh after upload.
|
||||
|
||||
This creates a new recording entry and optionally triggers transcription.
|
||||
"""
|
||||
if payload.event != "recording_completed":
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content={"error": f"Unknown event type: {payload.event}"}
|
||||
)
|
||||
|
||||
# Extract meeting_id from recording_name (format: meetingId_timestamp)
|
||||
parts = payload.recording_name.split("_")
|
||||
meeting_id = parts[0] if parts else payload.recording_name
|
||||
|
||||
# Create recording entry
|
||||
recording_id = str(uuid.uuid4())
|
||||
recorded_at = datetime.utcnow()
|
||||
|
||||
recording = {
|
||||
"id": recording_id,
|
||||
"meeting_id": meeting_id,
|
||||
"jibri_session_id": payload.recording_name,
|
||||
"title": f"Recording {meeting_id}",
|
||||
"storage_path": payload.storage_path,
|
||||
"audio_path": payload.audio_path,
|
||||
"file_size_bytes": payload.file_size_bytes,
|
||||
"duration_seconds": None, # Will be updated after analysis
|
||||
"participant_count": 0,
|
||||
"status": "uploaded",
|
||||
"recorded_at": recorded_at.isoformat(),
|
||||
"retention_days": DEFAULT_RETENTION_DAYS,
|
||||
"created_at": datetime.utcnow().isoformat(),
|
||||
"updated_at": datetime.utcnow().isoformat()
|
||||
}
|
||||
|
||||
_recordings_store[recording_id] = recording
|
||||
|
||||
# Log the creation
|
||||
log_audit(
|
||||
action="created",
|
||||
recording_id=recording_id,
|
||||
metadata={
|
||||
"source": "jibri_webhook",
|
||||
"storage_path": payload.storage_path,
|
||||
"file_size_bytes": payload.file_size_bytes
|
||||
}
|
||||
)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"recording_id": recording_id,
|
||||
"meeting_id": meeting_id,
|
||||
"status": "uploaded",
|
||||
"message": "Recording registered successfully"
|
||||
}
|
||||
|
||||
|
||||
# ==========================================
|
||||
# HEALTH & AUDIT ENDPOINTS (must be before parameterized routes)
|
||||
# ==========================================
|
||||
|
||||
@router.get("/health")
|
||||
async def recordings_health():
|
||||
"""Health check for recording service."""
|
||||
return {
|
||||
"status": "healthy",
|
||||
"recordings_count": len(_recordings_store),
|
||||
"transcriptions_count": len(_transcriptions_store),
|
||||
"minio_endpoint": MINIO_ENDPOINT,
|
||||
"bucket": MINIO_BUCKET
|
||||
}
|
||||
|
||||
|
||||
@router.get("/audit/log")
|
||||
async def get_audit_log(
|
||||
recording_id: Optional[str] = Query(None),
|
||||
action: Optional[str] = Query(None),
|
||||
limit: int = Query(100, ge=1, le=1000)
|
||||
):
|
||||
"""
|
||||
Get audit log entries (DSGVO compliance).
|
||||
|
||||
Admin-only endpoint for reviewing recording access history.
|
||||
"""
|
||||
logs = _audit_log.copy()
|
||||
|
||||
if recording_id:
|
||||
logs = [l for l in logs if l.get("recording_id") == recording_id]
|
||||
if action:
|
||||
logs = [l for l in logs if l.get("action") == action]
|
||||
|
||||
# Sort by created_at descending
|
||||
logs.sort(key=lambda x: x["created_at"], reverse=True)
|
||||
|
||||
return {
|
||||
"entries": logs[:limit],
|
||||
"total": len(logs)
|
||||
}
|
||||
|
||||
|
||||
# ==========================================
|
||||
# RECORDING MANAGEMENT ENDPOINTS
|
||||
# ==========================================
|
||||
|
||||
@router.get("", response_model=RecordingListResponse)
|
||||
async def list_recordings(
|
||||
status: Optional[str] = Query(None, description="Filter by status"),
|
||||
meeting_id: Optional[str] = Query(None, description="Filter by meeting ID"),
|
||||
page: int = Query(1, ge=1, description="Page number"),
|
||||
page_size: int = Query(20, ge=1, le=100, description="Items per page")
|
||||
):
|
||||
"""
|
||||
List all recordings with optional filtering.
|
||||
|
||||
Supports pagination and filtering by status or meeting ID.
|
||||
"""
|
||||
# Filter recordings
|
||||
recordings = list(_recordings_store.values())
|
||||
|
||||
if status:
|
||||
recordings = [r for r in recordings if r["status"] == status]
|
||||
if meeting_id:
|
||||
recordings = [r for r in recordings if r["meeting_id"] == meeting_id]
|
||||
|
||||
# Sort by recorded_at descending
|
||||
recordings.sort(key=lambda x: x["recorded_at"], reverse=True)
|
||||
|
||||
# Paginate
|
||||
total = len(recordings)
|
||||
start = (page - 1) * page_size
|
||||
end = start + page_size
|
||||
page_recordings = recordings[start:end]
|
||||
|
||||
# Convert to response format
|
||||
result = []
|
||||
for rec in page_recordings:
|
||||
recorded_at = datetime.fromisoformat(rec["recorded_at"])
|
||||
retention_expires = recorded_at + timedelta(days=rec["retention_days"])
|
||||
|
||||
# Check for transcription
|
||||
trans = next(
|
||||
(t for t in _transcriptions_store.values() if t["recording_id"] == rec["id"]),
|
||||
None
|
||||
)
|
||||
|
||||
result.append(RecordingResponse(
|
||||
id=rec["id"],
|
||||
meeting_id=rec["meeting_id"],
|
||||
title=rec.get("title"),
|
||||
storage_path=rec["storage_path"],
|
||||
audio_path=rec.get("audio_path"),
|
||||
file_size_bytes=rec.get("file_size_bytes"),
|
||||
duration_seconds=rec.get("duration_seconds"),
|
||||
participant_count=rec.get("participant_count", 0),
|
||||
status=rec["status"],
|
||||
recorded_at=recorded_at,
|
||||
retention_days=rec["retention_days"],
|
||||
retention_expires_at=retention_expires,
|
||||
transcription_status=trans["status"] if trans else None,
|
||||
transcription_id=trans["id"] if trans else None
|
||||
))
|
||||
|
||||
return RecordingListResponse(
|
||||
recordings=result,
|
||||
total=total,
|
||||
page=page,
|
||||
page_size=page_size
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{recording_id}", response_model=RecordingResponse)
|
||||
async def get_recording(recording_id: str):
|
||||
"""
|
||||
Get details for a specific recording.
|
||||
"""
|
||||
recording = _recordings_store.get(recording_id)
|
||||
if not recording:
|
||||
raise HTTPException(status_code=404, detail="Recording not found")
|
||||
|
||||
# Log view action
|
||||
log_audit(action="viewed", recording_id=recording_id)
|
||||
|
||||
recorded_at = datetime.fromisoformat(recording["recorded_at"])
|
||||
retention_expires = recorded_at + timedelta(days=recording["retention_days"])
|
||||
|
||||
# Check for transcription
|
||||
trans = next(
|
||||
(t for t in _transcriptions_store.values() if t["recording_id"] == recording_id),
|
||||
None
|
||||
)
|
||||
|
||||
return RecordingResponse(
|
||||
id=recording["id"],
|
||||
meeting_id=recording["meeting_id"],
|
||||
title=recording.get("title"),
|
||||
storage_path=recording["storage_path"],
|
||||
audio_path=recording.get("audio_path"),
|
||||
file_size_bytes=recording.get("file_size_bytes"),
|
||||
duration_seconds=recording.get("duration_seconds"),
|
||||
participant_count=recording.get("participant_count", 0),
|
||||
status=recording["status"],
|
||||
recorded_at=recorded_at,
|
||||
retention_days=recording["retention_days"],
|
||||
retention_expires_at=retention_expires,
|
||||
transcription_status=trans["status"] if trans else None,
|
||||
transcription_id=trans["id"] if trans else None
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/{recording_id}")
|
||||
async def delete_recording(
|
||||
recording_id: str,
|
||||
reason: str = Query(..., description="Reason for deletion (DSGVO audit)")
|
||||
):
|
||||
"""
|
||||
Soft-delete a recording (DSGVO compliance).
|
||||
|
||||
The recording is marked as deleted but retained for audit purposes.
|
||||
Actual file deletion happens after the audit retention period.
|
||||
"""
|
||||
recording = _recordings_store.get(recording_id)
|
||||
if not recording:
|
||||
raise HTTPException(status_code=404, detail="Recording not found")
|
||||
|
||||
# Soft delete
|
||||
recording["status"] = "deleted"
|
||||
recording["deleted_at"] = datetime.utcnow().isoformat()
|
||||
recording["updated_at"] = datetime.utcnow().isoformat()
|
||||
|
||||
# Log deletion with reason
|
||||
log_audit(
|
||||
action="deleted",
|
||||
recording_id=recording_id,
|
||||
metadata={"reason": reason}
|
||||
)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"recording_id": recording_id,
|
||||
"status": "deleted",
|
||||
"message": "Recording marked for deletion"
|
||||
}
|
||||
|
||||
|
||||
@router.get("/{recording_id}/download")
|
||||
async def download_recording(recording_id: str):
|
||||
"""
|
||||
Download the recording file.
|
||||
|
||||
In production, this would generate a presigned URL to MinIO.
|
||||
"""
|
||||
recording = _recordings_store.get(recording_id)
|
||||
if not recording:
|
||||
raise HTTPException(status_code=404, detail="Recording not found")
|
||||
|
||||
if recording["status"] == "deleted":
|
||||
raise HTTPException(status_code=410, detail="Recording has been deleted")
|
||||
|
||||
# Log download action
|
||||
log_audit(action="downloaded", recording_id=recording_id)
|
||||
|
||||
# In production, generate presigned URL to MinIO
|
||||
# For now, return info about where the file is
|
||||
return {
|
||||
"recording_id": recording_id,
|
||||
"storage_path": recording["storage_path"],
|
||||
"file_size_bytes": recording.get("file_size_bytes"),
|
||||
"message": "In production, this would redirect to a presigned MinIO URL"
|
||||
}
|
||||
@@ -0,0 +1,250 @@
|
||||
"""
|
||||
Recording API - Transcription Routes.
|
||||
|
||||
Start transcription, get status, download VTT/SRT subtitle files.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from fastapi.responses import PlainTextResponse
|
||||
|
||||
from recording_models import (
|
||||
TranscriptionRequest,
|
||||
TranscriptionStatusResponse,
|
||||
)
|
||||
from recording_helpers import (
|
||||
_recordings_store,
|
||||
_transcriptions_store,
|
||||
log_audit,
|
||||
format_vtt_time,
|
||||
format_srt_time,
|
||||
)
|
||||
|
||||
router = APIRouter(tags=["Recordings"])
|
||||
|
||||
|
||||
# ==========================================
|
||||
# TRANSCRIPTION ENDPOINTS
|
||||
# ==========================================
|
||||
|
||||
@router.post("/{recording_id}/transcribe", response_model=TranscriptionStatusResponse)
|
||||
async def start_transcription(recording_id: str, request: TranscriptionRequest):
|
||||
"""
|
||||
Start transcription for a recording.
|
||||
|
||||
Queues the recording for processing by the transcription worker.
|
||||
"""
|
||||
recording = _recordings_store.get(recording_id)
|
||||
if not recording:
|
||||
raise HTTPException(status_code=404, detail="Recording not found")
|
||||
|
||||
if recording["status"] == "deleted":
|
||||
raise HTTPException(status_code=400, detail="Cannot transcribe deleted recording")
|
||||
|
||||
# Check if transcription already exists
|
||||
existing = next(
|
||||
(t for t in _transcriptions_store.values()
|
||||
if t["recording_id"] == recording_id and t["status"] != "failed"),
|
||||
None
|
||||
)
|
||||
if existing:
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail=f"Transcription already exists with status: {existing['status']}"
|
||||
)
|
||||
|
||||
# Create transcription entry
|
||||
transcription_id = str(uuid.uuid4())
|
||||
now = datetime.utcnow()
|
||||
|
||||
transcription = {
|
||||
"id": transcription_id,
|
||||
"recording_id": recording_id,
|
||||
"language": request.language,
|
||||
"model": request.model,
|
||||
"status": "pending",
|
||||
"full_text": None,
|
||||
"word_count": None,
|
||||
"confidence_score": None,
|
||||
"vtt_path": None,
|
||||
"srt_path": None,
|
||||
"json_path": None,
|
||||
"error_message": None,
|
||||
"processing_started_at": None,
|
||||
"processing_completed_at": None,
|
||||
"processing_duration_seconds": None,
|
||||
"created_at": now.isoformat(),
|
||||
"updated_at": now.isoformat()
|
||||
}
|
||||
|
||||
_transcriptions_store[transcription_id] = transcription
|
||||
|
||||
# Update recording status
|
||||
recording["status"] = "processing"
|
||||
recording["updated_at"] = now.isoformat()
|
||||
|
||||
# Log transcription start
|
||||
log_audit(
|
||||
action="transcription_started",
|
||||
recording_id=recording_id,
|
||||
transcription_id=transcription_id,
|
||||
metadata={"language": request.language, "model": request.model}
|
||||
)
|
||||
|
||||
# TODO: Queue job to Redis/Valkey for transcription worker
|
||||
|
||||
return TranscriptionStatusResponse(
|
||||
id=transcription_id,
|
||||
recording_id=recording_id,
|
||||
status="pending",
|
||||
language=request.language,
|
||||
model=request.model,
|
||||
word_count=None,
|
||||
confidence_score=None,
|
||||
processing_duration_seconds=None,
|
||||
error_message=None,
|
||||
created_at=now,
|
||||
completed_at=None
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{recording_id}/transcription", response_model=TranscriptionStatusResponse)
|
||||
async def get_transcription_status(recording_id: str):
|
||||
"""
|
||||
Get transcription status for a recording.
|
||||
"""
|
||||
transcription = next(
|
||||
(t for t in _transcriptions_store.values() if t["recording_id"] == recording_id),
|
||||
None
|
||||
)
|
||||
if not transcription:
|
||||
raise HTTPException(status_code=404, detail="No transcription found for this recording")
|
||||
|
||||
return TranscriptionStatusResponse(
|
||||
id=transcription["id"],
|
||||
recording_id=transcription["recording_id"],
|
||||
status=transcription["status"],
|
||||
language=transcription["language"],
|
||||
model=transcription["model"],
|
||||
word_count=transcription.get("word_count"),
|
||||
confidence_score=transcription.get("confidence_score"),
|
||||
processing_duration_seconds=transcription.get("processing_duration_seconds"),
|
||||
error_message=transcription.get("error_message"),
|
||||
created_at=datetime.fromisoformat(transcription["created_at"]),
|
||||
completed_at=(
|
||||
datetime.fromisoformat(transcription["processing_completed_at"])
|
||||
if transcription.get("processing_completed_at") else None
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{recording_id}/transcription/text")
|
||||
async def get_transcription_text(recording_id: str):
|
||||
"""
|
||||
Get the full transcription text.
|
||||
"""
|
||||
transcription = next(
|
||||
(t for t in _transcriptions_store.values() if t["recording_id"] == recording_id),
|
||||
None
|
||||
)
|
||||
if not transcription:
|
||||
raise HTTPException(status_code=404, detail="No transcription found for this recording")
|
||||
|
||||
if transcription["status"] != "completed":
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Transcription not ready. Status: {transcription['status']}"
|
||||
)
|
||||
|
||||
return {
|
||||
"transcription_id": transcription["id"],
|
||||
"recording_id": recording_id,
|
||||
"language": transcription["language"],
|
||||
"text": transcription.get("full_text", ""),
|
||||
"word_count": transcription.get("word_count", 0)
|
||||
}
|
||||
|
||||
|
||||
@router.get("/{recording_id}/transcription/vtt")
|
||||
async def get_transcription_vtt(recording_id: str):
|
||||
"""
|
||||
Download transcription as WebVTT subtitle file.
|
||||
"""
|
||||
transcription = next(
|
||||
(t for t in _transcriptions_store.values() if t["recording_id"] == recording_id),
|
||||
None
|
||||
)
|
||||
if not transcription:
|
||||
raise HTTPException(status_code=404, detail="No transcription found for this recording")
|
||||
|
||||
if transcription["status"] != "completed":
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Transcription not ready. Status: {transcription['status']}"
|
||||
)
|
||||
|
||||
# Generate VTT content
|
||||
vtt_content = "WEBVTT\n\n"
|
||||
text = transcription.get("full_text", "")
|
||||
|
||||
if text:
|
||||
sentences = text.replace(".", ".\n").split("\n")
|
||||
time_offset = 0
|
||||
for sentence in sentences:
|
||||
sentence = sentence.strip()
|
||||
if sentence:
|
||||
start = format_vtt_time(time_offset)
|
||||
time_offset += 3000
|
||||
end = format_vtt_time(time_offset)
|
||||
vtt_content += f"{start} --> {end}\n{sentence}\n\n"
|
||||
|
||||
return PlainTextResponse(
|
||||
content=vtt_content,
|
||||
media_type="text/vtt",
|
||||
headers={"Content-Disposition": f"attachment; filename=transcript_{recording_id}.vtt"}
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{recording_id}/transcription/srt")
|
||||
async def get_transcription_srt(recording_id: str):
|
||||
"""
|
||||
Download transcription as SRT subtitle file.
|
||||
"""
|
||||
transcription = next(
|
||||
(t for t in _transcriptions_store.values() if t["recording_id"] == recording_id),
|
||||
None
|
||||
)
|
||||
if not transcription:
|
||||
raise HTTPException(status_code=404, detail="No transcription found for this recording")
|
||||
|
||||
if transcription["status"] != "completed":
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Transcription not ready. Status: {transcription['status']}"
|
||||
)
|
||||
|
||||
# Generate SRT content
|
||||
srt_content = ""
|
||||
text = transcription.get("full_text", "")
|
||||
|
||||
if text:
|
||||
sentences = text.replace(".", ".\n").split("\n")
|
||||
time_offset = 0
|
||||
index = 1
|
||||
for sentence in sentences:
|
||||
sentence = sentence.strip()
|
||||
if sentence:
|
||||
start = format_srt_time(time_offset)
|
||||
time_offset += 3000
|
||||
end = format_srt_time(time_offset)
|
||||
srt_content += f"{index}\n{start} --> {end}\n{sentence}\n\n"
|
||||
index += 1
|
||||
|
||||
return PlainTextResponse(
|
||||
content=srt_content,
|
||||
media_type="text/plain",
|
||||
headers={"Content-Disposition": f"attachment; filename=transcript_{recording_id}.srt"}
|
||||
)
|
||||
@@ -1,751 +1,25 @@
|
||||
# ==============================================
|
||||
# Breakpilot Drive - Unit Analytics API
|
||||
# ==============================================
|
||||
# Erweiterte Analytics fuer Lernfortschritt:
|
||||
# - Pre/Post Gain Visualisierung
|
||||
# - Misconception-Tracking
|
||||
# - Stop-Level Analytics
|
||||
# - Aggregierte Klassen-Statistiken
|
||||
# - Export-Funktionen
|
||||
"""
|
||||
Breakpilot Drive - Unit Analytics API — Barrel Re-export.
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Query, Depends, Request
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import List, Optional, Dict, Any
|
||||
from datetime import datetime, timedelta
|
||||
from enum import Enum
|
||||
import os
|
||||
import logging
|
||||
import statistics
|
||||
Erweiterte Analytics fuer Lernfortschritt:
|
||||
- Pre/Post Gain Visualisierung
|
||||
- Misconception-Tracking
|
||||
- Stop-Level Analytics
|
||||
- Aggregierte Klassen-Statistiken
|
||||
- Export-Funktionen
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
Split into:
|
||||
- unit_analytics_models.py: Pydantic models & enums
|
||||
- unit_analytics_helpers.py: Database access & computation helpers
|
||||
- unit_analytics_routes.py: Core analytics endpoint handlers
|
||||
- unit_analytics_export.py: Export & dashboard endpoints
|
||||
"""
|
||||
|
||||
# Feature flags
|
||||
USE_DATABASE = os.getenv("GAME_USE_DATABASE", "true").lower() == "true"
|
||||
from fastapi import APIRouter
|
||||
|
||||
from unit_analytics_routes import router as _routes_router
|
||||
from unit_analytics_export import router as _export_router
|
||||
|
||||
router = APIRouter(prefix="/api/analytics", tags=["Unit Analytics"])
|
||||
|
||||
|
||||
# ==============================================
|
||||
# Pydantic Models
|
||||
# ==============================================
|
||||
|
||||
class TimeRange(str, Enum):
|
||||
"""Time range for analytics queries"""
|
||||
WEEK = "week"
|
||||
MONTH = "month"
|
||||
QUARTER = "quarter"
|
||||
ALL = "all"
|
||||
|
||||
|
||||
class LearningGainData(BaseModel):
|
||||
"""Pre/Post learning gain data point"""
|
||||
student_id: str
|
||||
student_name: str
|
||||
unit_id: str
|
||||
precheck_score: float
|
||||
postcheck_score: float
|
||||
learning_gain: float
|
||||
percentile: Optional[float] = None
|
||||
|
||||
|
||||
class LearningGainSummary(BaseModel):
|
||||
"""Aggregated learning gain statistics"""
|
||||
unit_id: str
|
||||
unit_title: str
|
||||
total_students: int
|
||||
avg_precheck: float
|
||||
avg_postcheck: float
|
||||
avg_gain: float
|
||||
median_gain: float
|
||||
std_deviation: float
|
||||
positive_gain_count: int
|
||||
negative_gain_count: int
|
||||
no_change_count: int
|
||||
gain_distribution: Dict[str, int] # "-20+", "-10-0", "0-10", "10-20", "20+"
|
||||
individual_gains: List[LearningGainData]
|
||||
|
||||
|
||||
class StopPerformance(BaseModel):
|
||||
"""Performance data for a single stop"""
|
||||
stop_id: str
|
||||
stop_label: str
|
||||
attempts_total: int
|
||||
success_rate: float
|
||||
avg_time_seconds: float
|
||||
avg_attempts_before_success: float
|
||||
common_errors: List[str]
|
||||
difficulty_rating: float # 1-5 based on performance
|
||||
|
||||
|
||||
class UnitPerformanceDetail(BaseModel):
|
||||
"""Detailed unit performance breakdown"""
|
||||
unit_id: str
|
||||
unit_title: str
|
||||
template: str
|
||||
total_sessions: int
|
||||
completed_sessions: int
|
||||
completion_rate: float
|
||||
avg_duration_minutes: float
|
||||
stops: List[StopPerformance]
|
||||
bottleneck_stops: List[str] # Stops where students struggle most
|
||||
|
||||
|
||||
class MisconceptionEntry(BaseModel):
|
||||
"""Individual misconception tracking"""
|
||||
concept_id: str
|
||||
concept_label: str
|
||||
misconception_text: str
|
||||
frequency: int
|
||||
affected_student_ids: List[str]
|
||||
unit_id: str
|
||||
stop_id: str
|
||||
detected_via: str # "precheck", "postcheck", "interaction"
|
||||
first_detected: datetime
|
||||
last_detected: datetime
|
||||
|
||||
|
||||
class MisconceptionReport(BaseModel):
|
||||
"""Comprehensive misconception report"""
|
||||
class_id: Optional[str]
|
||||
time_range: str
|
||||
total_misconceptions: int
|
||||
unique_concepts: int
|
||||
most_common: List[MisconceptionEntry]
|
||||
by_unit: Dict[str, List[MisconceptionEntry]]
|
||||
trending_up: List[MisconceptionEntry] # Getting more frequent
|
||||
resolved: List[MisconceptionEntry] # No longer appearing
|
||||
|
||||
|
||||
class StudentProgressTimeline(BaseModel):
|
||||
"""Timeline of student progress"""
|
||||
student_id: str
|
||||
student_name: str
|
||||
units_completed: int
|
||||
total_time_minutes: int
|
||||
avg_score: float
|
||||
trend: str # "improving", "stable", "declining"
|
||||
timeline: List[Dict[str, Any]] # List of session events
|
||||
|
||||
|
||||
class ClassComparisonData(BaseModel):
|
||||
"""Data for comparing class performance"""
|
||||
class_id: str
|
||||
class_name: str
|
||||
student_count: int
|
||||
units_assigned: int
|
||||
avg_completion_rate: float
|
||||
avg_learning_gain: float
|
||||
avg_time_per_unit: float
|
||||
|
||||
|
||||
class ExportFormat(str, Enum):
|
||||
"""Export format options"""
|
||||
JSON = "json"
|
||||
CSV = "csv"
|
||||
|
||||
|
||||
# ==============================================
|
||||
# Database Integration
|
||||
# ==============================================
|
||||
|
||||
_analytics_db = None
|
||||
|
||||
async def get_analytics_database():
|
||||
"""Get analytics database instance."""
|
||||
global _analytics_db
|
||||
if not USE_DATABASE:
|
||||
return None
|
||||
if _analytics_db is None:
|
||||
try:
|
||||
from unit.database import get_analytics_db
|
||||
_analytics_db = await get_analytics_db()
|
||||
logger.info("Analytics database initialized")
|
||||
except ImportError:
|
||||
logger.warning("Analytics database module not available")
|
||||
except Exception as e:
|
||||
logger.warning(f"Analytics database not available: {e}")
|
||||
return _analytics_db
|
||||
|
||||
|
||||
# ==============================================
|
||||
# Helper Functions
|
||||
# ==============================================
|
||||
|
||||
def calculate_gain_distribution(gains: List[float]) -> Dict[str, int]:
|
||||
"""Calculate distribution of learning gains into buckets."""
|
||||
distribution = {
|
||||
"< -20%": 0,
|
||||
"-20% to -10%": 0,
|
||||
"-10% to 0%": 0,
|
||||
"0% to 10%": 0,
|
||||
"10% to 20%": 0,
|
||||
"> 20%": 0,
|
||||
}
|
||||
|
||||
for gain in gains:
|
||||
gain_percent = gain * 100
|
||||
if gain_percent < -20:
|
||||
distribution["< -20%"] += 1
|
||||
elif gain_percent < -10:
|
||||
distribution["-20% to -10%"] += 1
|
||||
elif gain_percent < 0:
|
||||
distribution["-10% to 0%"] += 1
|
||||
elif gain_percent < 10:
|
||||
distribution["0% to 10%"] += 1
|
||||
elif gain_percent < 20:
|
||||
distribution["10% to 20%"] += 1
|
||||
else:
|
||||
distribution["> 20%"] += 1
|
||||
|
||||
return distribution
|
||||
|
||||
|
||||
def calculate_trend(scores: List[float]) -> str:
|
||||
"""Calculate trend from a series of scores."""
|
||||
if len(scores) < 3:
|
||||
return "insufficient_data"
|
||||
|
||||
# Simple linear regression
|
||||
n = len(scores)
|
||||
x_mean = (n - 1) / 2
|
||||
y_mean = sum(scores) / n
|
||||
|
||||
numerator = sum((i - x_mean) * (scores[i] - y_mean) for i in range(n))
|
||||
denominator = sum((i - x_mean) ** 2 for i in range(n))
|
||||
|
||||
if denominator == 0:
|
||||
return "stable"
|
||||
|
||||
slope = numerator / denominator
|
||||
|
||||
if slope > 0.05:
|
||||
return "improving"
|
||||
elif slope < -0.05:
|
||||
return "declining"
|
||||
else:
|
||||
return "stable"
|
||||
|
||||
|
||||
def calculate_difficulty_rating(success_rate: float, avg_attempts: float) -> float:
|
||||
"""Calculate difficulty rating 1-5 based on success metrics."""
|
||||
# Lower success rate and higher attempts = higher difficulty
|
||||
base_difficulty = (1 - success_rate) * 3 + 1 # 1-4 range
|
||||
attempt_modifier = min(avg_attempts - 1, 1) # 0-1 range
|
||||
return min(5.0, base_difficulty + attempt_modifier)
|
||||
|
||||
|
||||
# ==============================================
|
||||
# API Endpoints - Learning Gain
|
||||
# ==============================================
|
||||
|
||||
# NOTE: Static routes must come BEFORE dynamic routes like /{unit_id}
|
||||
@router.get("/learning-gain/compare")
|
||||
async def compare_learning_gains(
|
||||
unit_ids: str = Query(..., description="Comma-separated unit IDs"),
|
||||
class_id: Optional[str] = Query(None),
|
||||
time_range: TimeRange = Query(TimeRange.MONTH),
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Compare learning gains across multiple units.
|
||||
"""
|
||||
unit_list = [u.strip() for u in unit_ids.split(",")]
|
||||
comparisons = []
|
||||
|
||||
for unit_id in unit_list:
|
||||
try:
|
||||
summary = await get_learning_gain_analysis(unit_id, class_id, time_range)
|
||||
comparisons.append({
|
||||
"unit_id": unit_id,
|
||||
"avg_gain": summary.avg_gain,
|
||||
"median_gain": summary.median_gain,
|
||||
"total_students": summary.total_students,
|
||||
"positive_rate": summary.positive_gain_count / max(summary.total_students, 1),
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get comparison for {unit_id}: {e}")
|
||||
|
||||
return {
|
||||
"time_range": time_range.value,
|
||||
"class_id": class_id,
|
||||
"comparisons": sorted(comparisons, key=lambda x: x["avg_gain"], reverse=True),
|
||||
}
|
||||
|
||||
|
||||
@router.get("/learning-gain/{unit_id}", response_model=LearningGainSummary)
|
||||
async def get_learning_gain_analysis(
|
||||
unit_id: str,
|
||||
class_id: Optional[str] = Query(None, description="Filter by class"),
|
||||
time_range: TimeRange = Query(TimeRange.MONTH, description="Time range for analysis"),
|
||||
) -> LearningGainSummary:
|
||||
"""
|
||||
Get detailed pre/post learning gain analysis for a unit.
|
||||
|
||||
Shows individual gains, aggregated statistics, and distribution.
|
||||
"""
|
||||
db = await get_analytics_database()
|
||||
individual_gains = []
|
||||
|
||||
if db:
|
||||
try:
|
||||
# Get all sessions with pre/post scores for this unit
|
||||
sessions = await db.get_unit_sessions_with_scores(
|
||||
unit_id=unit_id,
|
||||
class_id=class_id,
|
||||
time_range=time_range.value
|
||||
)
|
||||
|
||||
for session in sessions:
|
||||
if session.get("precheck_score") is not None and session.get("postcheck_score") is not None:
|
||||
gain = session["postcheck_score"] - session["precheck_score"]
|
||||
individual_gains.append(LearningGainData(
|
||||
student_id=session["student_id"],
|
||||
student_name=session.get("student_name", session["student_id"][:8]),
|
||||
unit_id=unit_id,
|
||||
precheck_score=session["precheck_score"],
|
||||
postcheck_score=session["postcheck_score"],
|
||||
learning_gain=gain,
|
||||
))
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get learning gain data: {e}")
|
||||
|
||||
# Calculate statistics
|
||||
if not individual_gains:
|
||||
# Return empty summary
|
||||
return LearningGainSummary(
|
||||
unit_id=unit_id,
|
||||
unit_title=f"Unit {unit_id}",
|
||||
total_students=0,
|
||||
avg_precheck=0.0,
|
||||
avg_postcheck=0.0,
|
||||
avg_gain=0.0,
|
||||
median_gain=0.0,
|
||||
std_deviation=0.0,
|
||||
positive_gain_count=0,
|
||||
negative_gain_count=0,
|
||||
no_change_count=0,
|
||||
gain_distribution={},
|
||||
individual_gains=[],
|
||||
)
|
||||
|
||||
gains = [g.learning_gain for g in individual_gains]
|
||||
prechecks = [g.precheck_score for g in individual_gains]
|
||||
postchecks = [g.postcheck_score for g in individual_gains]
|
||||
|
||||
avg_gain = statistics.mean(gains)
|
||||
median_gain = statistics.median(gains)
|
||||
std_dev = statistics.stdev(gains) if len(gains) > 1 else 0.0
|
||||
|
||||
# Calculate percentiles
|
||||
sorted_gains = sorted(gains)
|
||||
for data in individual_gains:
|
||||
rank = sorted_gains.index(data.learning_gain) + 1
|
||||
data.percentile = rank / len(sorted_gains) * 100
|
||||
|
||||
return LearningGainSummary(
|
||||
unit_id=unit_id,
|
||||
unit_title=f"Unit {unit_id}",
|
||||
total_students=len(individual_gains),
|
||||
avg_precheck=statistics.mean(prechecks),
|
||||
avg_postcheck=statistics.mean(postchecks),
|
||||
avg_gain=avg_gain,
|
||||
median_gain=median_gain,
|
||||
std_deviation=std_dev,
|
||||
positive_gain_count=sum(1 for g in gains if g > 0.01),
|
||||
negative_gain_count=sum(1 for g in gains if g < -0.01),
|
||||
no_change_count=sum(1 for g in gains if -0.01 <= g <= 0.01),
|
||||
gain_distribution=calculate_gain_distribution(gains),
|
||||
individual_gains=sorted(individual_gains, key=lambda x: x.learning_gain, reverse=True),
|
||||
)
|
||||
|
||||
|
||||
# ==============================================
|
||||
# API Endpoints - Stop-Level Analytics
|
||||
# ==============================================
|
||||
|
||||
@router.get("/unit/{unit_id}/stops", response_model=UnitPerformanceDetail)
|
||||
async def get_unit_stop_analytics(
|
||||
unit_id: str,
|
||||
class_id: Optional[str] = Query(None),
|
||||
time_range: TimeRange = Query(TimeRange.MONTH),
|
||||
) -> UnitPerformanceDetail:
|
||||
"""
|
||||
Get detailed stop-level performance analytics.
|
||||
|
||||
Identifies bottleneck stops where students struggle most.
|
||||
"""
|
||||
db = await get_analytics_database()
|
||||
stops_data = []
|
||||
|
||||
if db:
|
||||
try:
|
||||
# Get stop-level telemetry
|
||||
stop_stats = await db.get_stop_performance(
|
||||
unit_id=unit_id,
|
||||
class_id=class_id,
|
||||
time_range=time_range.value
|
||||
)
|
||||
|
||||
for stop in stop_stats:
|
||||
difficulty = calculate_difficulty_rating(
|
||||
stop.get("success_rate", 0.5),
|
||||
stop.get("avg_attempts", 1.0)
|
||||
)
|
||||
stops_data.append(StopPerformance(
|
||||
stop_id=stop["stop_id"],
|
||||
stop_label=stop.get("stop_label", stop["stop_id"]),
|
||||
attempts_total=stop.get("total_attempts", 0),
|
||||
success_rate=stop.get("success_rate", 0.0),
|
||||
avg_time_seconds=stop.get("avg_time_seconds", 0.0),
|
||||
avg_attempts_before_success=stop.get("avg_attempts", 1.0),
|
||||
common_errors=stop.get("common_errors", []),
|
||||
difficulty_rating=difficulty,
|
||||
))
|
||||
|
||||
# Get overall unit stats
|
||||
unit_stats = await db.get_unit_overall_stats(unit_id, class_id, time_range.value)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get stop analytics: {e}")
|
||||
unit_stats = {}
|
||||
else:
|
||||
unit_stats = {}
|
||||
|
||||
# Identify bottleneck stops (difficulty > 3.5 or success rate < 0.6)
|
||||
bottlenecks = [
|
||||
s.stop_id for s in stops_data
|
||||
if s.difficulty_rating > 3.5 or s.success_rate < 0.6
|
||||
]
|
||||
|
||||
return UnitPerformanceDetail(
|
||||
unit_id=unit_id,
|
||||
unit_title=f"Unit {unit_id}",
|
||||
template=unit_stats.get("template", "unknown"),
|
||||
total_sessions=unit_stats.get("total_sessions", 0),
|
||||
completed_sessions=unit_stats.get("completed_sessions", 0),
|
||||
completion_rate=unit_stats.get("completion_rate", 0.0),
|
||||
avg_duration_minutes=unit_stats.get("avg_duration_minutes", 0.0),
|
||||
stops=stops_data,
|
||||
bottleneck_stops=bottlenecks,
|
||||
)
|
||||
|
||||
|
||||
# ==============================================
|
||||
# API Endpoints - Misconception Tracking
|
||||
# ==============================================
|
||||
|
||||
@router.get("/misconceptions", response_model=MisconceptionReport)
|
||||
async def get_misconception_report(
|
||||
class_id: Optional[str] = Query(None),
|
||||
unit_id: Optional[str] = Query(None),
|
||||
time_range: TimeRange = Query(TimeRange.MONTH),
|
||||
limit: int = Query(20, ge=1, le=100),
|
||||
) -> MisconceptionReport:
|
||||
"""
|
||||
Get comprehensive misconception report.
|
||||
|
||||
Shows most common misconceptions and their frequency.
|
||||
"""
|
||||
db = await get_analytics_database()
|
||||
misconceptions = []
|
||||
|
||||
if db:
|
||||
try:
|
||||
raw_misconceptions = await db.get_misconceptions(
|
||||
class_id=class_id,
|
||||
unit_id=unit_id,
|
||||
time_range=time_range.value,
|
||||
limit=limit
|
||||
)
|
||||
|
||||
for m in raw_misconceptions:
|
||||
misconceptions.append(MisconceptionEntry(
|
||||
concept_id=m["concept_id"],
|
||||
concept_label=m["concept_label"],
|
||||
misconception_text=m["misconception_text"],
|
||||
frequency=m["frequency"],
|
||||
affected_student_ids=m.get("student_ids", []),
|
||||
unit_id=m["unit_id"],
|
||||
stop_id=m["stop_id"],
|
||||
detected_via=m.get("detected_via", "unknown"),
|
||||
first_detected=m.get("first_detected", datetime.utcnow()),
|
||||
last_detected=m.get("last_detected", datetime.utcnow()),
|
||||
))
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get misconceptions: {e}")
|
||||
|
||||
# Group by unit
|
||||
by_unit = {}
|
||||
for m in misconceptions:
|
||||
if m.unit_id not in by_unit:
|
||||
by_unit[m.unit_id] = []
|
||||
by_unit[m.unit_id].append(m)
|
||||
|
||||
# Identify trending misconceptions (would need historical comparison in production)
|
||||
trending_up = misconceptions[:3] if misconceptions else []
|
||||
resolved = [] # Would identify from historical data
|
||||
|
||||
return MisconceptionReport(
|
||||
class_id=class_id,
|
||||
time_range=time_range.value,
|
||||
total_misconceptions=sum(m.frequency for m in misconceptions),
|
||||
unique_concepts=len(set(m.concept_id for m in misconceptions)),
|
||||
most_common=sorted(misconceptions, key=lambda x: x.frequency, reverse=True)[:10],
|
||||
by_unit=by_unit,
|
||||
trending_up=trending_up,
|
||||
resolved=resolved,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/misconceptions/student/{student_id}")
|
||||
async def get_student_misconceptions(
|
||||
student_id: str,
|
||||
time_range: TimeRange = Query(TimeRange.ALL),
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Get misconceptions for a specific student.
|
||||
|
||||
Useful for personalized remediation.
|
||||
"""
|
||||
db = await get_analytics_database()
|
||||
|
||||
if db:
|
||||
try:
|
||||
misconceptions = await db.get_student_misconceptions(
|
||||
student_id=student_id,
|
||||
time_range=time_range.value
|
||||
)
|
||||
return {
|
||||
"student_id": student_id,
|
||||
"misconceptions": misconceptions,
|
||||
"recommended_remediation": [
|
||||
{"concept": m["concept_label"], "activity": f"Review {m['unit_id']}/{m['stop_id']}"}
|
||||
for m in misconceptions[:5]
|
||||
]
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get student misconceptions: {e}")
|
||||
|
||||
return {
|
||||
"student_id": student_id,
|
||||
"misconceptions": [],
|
||||
"recommended_remediation": [],
|
||||
}
|
||||
|
||||
|
||||
# ==============================================
|
||||
# API Endpoints - Student Progress Timeline
|
||||
# ==============================================
|
||||
|
||||
@router.get("/student/{student_id}/timeline", response_model=StudentProgressTimeline)
|
||||
async def get_student_timeline(
|
||||
student_id: str,
|
||||
time_range: TimeRange = Query(TimeRange.ALL),
|
||||
) -> StudentProgressTimeline:
|
||||
"""
|
||||
Get detailed progress timeline for a student.
|
||||
|
||||
Shows all unit sessions and performance trend.
|
||||
"""
|
||||
db = await get_analytics_database()
|
||||
timeline = []
|
||||
scores = []
|
||||
|
||||
if db:
|
||||
try:
|
||||
sessions = await db.get_student_sessions(
|
||||
student_id=student_id,
|
||||
time_range=time_range.value
|
||||
)
|
||||
|
||||
for session in sessions:
|
||||
timeline.append({
|
||||
"date": session.get("started_at"),
|
||||
"unit_id": session.get("unit_id"),
|
||||
"completed": session.get("completed_at") is not None,
|
||||
"precheck": session.get("precheck_score"),
|
||||
"postcheck": session.get("postcheck_score"),
|
||||
"duration_minutes": session.get("duration_seconds", 0) // 60,
|
||||
})
|
||||
if session.get("postcheck_score") is not None:
|
||||
scores.append(session["postcheck_score"])
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get student timeline: {e}")
|
||||
|
||||
trend = calculate_trend(scores) if scores else "insufficient_data"
|
||||
|
||||
return StudentProgressTimeline(
|
||||
student_id=student_id,
|
||||
student_name=f"Student {student_id[:8]}", # Would load actual name
|
||||
units_completed=sum(1 for t in timeline if t["completed"]),
|
||||
total_time_minutes=sum(t["duration_minutes"] for t in timeline),
|
||||
avg_score=statistics.mean(scores) if scores else 0.0,
|
||||
trend=trend,
|
||||
timeline=timeline,
|
||||
)
|
||||
|
||||
|
||||
# ==============================================
|
||||
# API Endpoints - Class Comparison
|
||||
# ==============================================
|
||||
|
||||
@router.get("/compare/classes", response_model=List[ClassComparisonData])
|
||||
async def compare_classes(
|
||||
class_ids: str = Query(..., description="Comma-separated class IDs"),
|
||||
time_range: TimeRange = Query(TimeRange.MONTH),
|
||||
) -> List[ClassComparisonData]:
|
||||
"""
|
||||
Compare performance across multiple classes.
|
||||
"""
|
||||
class_list = [c.strip() for c in class_ids.split(",")]
|
||||
comparisons = []
|
||||
|
||||
db = await get_analytics_database()
|
||||
if db:
|
||||
for class_id in class_list:
|
||||
try:
|
||||
stats = await db.get_class_aggregate_stats(class_id, time_range.value)
|
||||
comparisons.append(ClassComparisonData(
|
||||
class_id=class_id,
|
||||
class_name=stats.get("class_name", f"Klasse {class_id[:8]}"),
|
||||
student_count=stats.get("student_count", 0),
|
||||
units_assigned=stats.get("units_assigned", 0),
|
||||
avg_completion_rate=stats.get("avg_completion_rate", 0.0),
|
||||
avg_learning_gain=stats.get("avg_learning_gain", 0.0),
|
||||
avg_time_per_unit=stats.get("avg_time_per_unit", 0.0),
|
||||
))
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get stats for class {class_id}: {e}")
|
||||
|
||||
return sorted(comparisons, key=lambda x: x.avg_learning_gain, reverse=True)
|
||||
|
||||
|
||||
# ==============================================
|
||||
# API Endpoints - Export
|
||||
# ==============================================
|
||||
|
||||
@router.get("/export/learning-gains")
|
||||
async def export_learning_gains(
|
||||
unit_id: Optional[str] = Query(None),
|
||||
class_id: Optional[str] = Query(None),
|
||||
time_range: TimeRange = Query(TimeRange.ALL),
|
||||
format: ExportFormat = Query(ExportFormat.JSON),
|
||||
) -> Any:
|
||||
"""
|
||||
Export learning gain data.
|
||||
"""
|
||||
from fastapi.responses import Response
|
||||
|
||||
db = await get_analytics_database()
|
||||
data = []
|
||||
|
||||
if db:
|
||||
try:
|
||||
data = await db.export_learning_gains(
|
||||
unit_id=unit_id,
|
||||
class_id=class_id,
|
||||
time_range=time_range.value
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to export data: {e}")
|
||||
|
||||
if format == ExportFormat.CSV:
|
||||
# Convert to CSV
|
||||
if not data:
|
||||
csv_content = "student_id,unit_id,precheck,postcheck,gain\n"
|
||||
else:
|
||||
csv_content = "student_id,unit_id,precheck,postcheck,gain\n"
|
||||
for row in data:
|
||||
csv_content += f"{row['student_id']},{row['unit_id']},{row.get('precheck', '')},{row.get('postcheck', '')},{row.get('gain', '')}\n"
|
||||
|
||||
return Response(
|
||||
content=csv_content,
|
||||
media_type="text/csv",
|
||||
headers={"Content-Disposition": "attachment; filename=learning_gains.csv"}
|
||||
)
|
||||
|
||||
return {
|
||||
"export_date": datetime.utcnow().isoformat(),
|
||||
"filters": {
|
||||
"unit_id": unit_id,
|
||||
"class_id": class_id,
|
||||
"time_range": time_range.value,
|
||||
},
|
||||
"data": data,
|
||||
}
|
||||
|
||||
|
||||
@router.get("/export/misconceptions")
|
||||
async def export_misconceptions(
|
||||
class_id: Optional[str] = Query(None),
|
||||
format: ExportFormat = Query(ExportFormat.JSON),
|
||||
) -> Any:
|
||||
"""
|
||||
Export misconception data for further analysis.
|
||||
"""
|
||||
report = await get_misconception_report(
|
||||
class_id=class_id,
|
||||
unit_id=None,
|
||||
time_range=TimeRange.MONTH,
|
||||
limit=100
|
||||
)
|
||||
|
||||
if format == ExportFormat.CSV:
|
||||
from fastapi.responses import Response
|
||||
csv_content = "concept_id,concept_label,misconception,frequency,unit_id,stop_id\n"
|
||||
for m in report.most_common:
|
||||
csv_content += f'"{m.concept_id}","{m.concept_label}","{m.misconception_text}",{m.frequency},"{m.unit_id}","{m.stop_id}"\n'
|
||||
|
||||
return Response(
|
||||
content=csv_content,
|
||||
media_type="text/csv",
|
||||
headers={"Content-Disposition": "attachment; filename=misconceptions.csv"}
|
||||
)
|
||||
|
||||
return {
|
||||
"export_date": datetime.utcnow().isoformat(),
|
||||
"class_id": class_id,
|
||||
"total_entries": len(report.most_common),
|
||||
"data": [m.model_dump() for m in report.most_common],
|
||||
}
|
||||
|
||||
|
||||
# ==============================================
|
||||
# API Endpoints - Dashboard Aggregates
|
||||
# ==============================================
|
||||
|
||||
@router.get("/dashboard/overview")
|
||||
async def get_analytics_overview(
|
||||
time_range: TimeRange = Query(TimeRange.MONTH),
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Get high-level analytics overview for dashboard.
|
||||
"""
|
||||
db = await get_analytics_database()
|
||||
|
||||
if db:
|
||||
try:
|
||||
overview = await db.get_analytics_overview(time_range.value)
|
||||
return overview
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get analytics overview: {e}")
|
||||
|
||||
return {
|
||||
"time_range": time_range.value,
|
||||
"total_sessions": 0,
|
||||
"unique_students": 0,
|
||||
"avg_completion_rate": 0.0,
|
||||
"avg_learning_gain": 0.0,
|
||||
"most_played_units": [],
|
||||
"struggling_concepts": [],
|
||||
"active_classes": 0,
|
||||
}
|
||||
|
||||
|
||||
@router.get("/health")
|
||||
async def health_check() -> Dict[str, Any]:
|
||||
"""Health check for analytics API."""
|
||||
db = await get_analytics_database()
|
||||
return {
|
||||
"status": "healthy",
|
||||
"service": "unit-analytics",
|
||||
"database": "connected" if db else "disconnected",
|
||||
}
|
||||
router.include_router(_routes_router)
|
||||
router.include_router(_export_router)
|
||||
|
||||
@@ -0,0 +1,145 @@
|
||||
"""
|
||||
Unit Analytics API - Export & Dashboard Routes.
|
||||
|
||||
Export endpoints for learning gains and misconceptions, plus dashboard overview.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Optional, Dict, Any
|
||||
|
||||
from fastapi import APIRouter, Query
|
||||
from fastapi.responses import Response
|
||||
|
||||
from unit_analytics_models import TimeRange, ExportFormat
|
||||
from unit_analytics_helpers import get_analytics_database
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(tags=["Unit Analytics"])
|
||||
|
||||
|
||||
# ==============================================
|
||||
# API Endpoints - Export
|
||||
# ==============================================
|
||||
|
||||
@router.get("/export/learning-gains")
|
||||
async def export_learning_gains(
|
||||
unit_id: Optional[str] = Query(None),
|
||||
class_id: Optional[str] = Query(None),
|
||||
time_range: TimeRange = Query(TimeRange.ALL),
|
||||
format: ExportFormat = Query(ExportFormat.JSON),
|
||||
) -> Any:
|
||||
"""
|
||||
Export learning gain data.
|
||||
"""
|
||||
db = await get_analytics_database()
|
||||
data = []
|
||||
|
||||
if db:
|
||||
try:
|
||||
data = await db.export_learning_gains(
|
||||
unit_id=unit_id, class_id=class_id, time_range=time_range.value
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to export data: {e}")
|
||||
|
||||
if format == ExportFormat.CSV:
|
||||
if not data:
|
||||
csv_content = "student_id,unit_id,precheck,postcheck,gain\n"
|
||||
else:
|
||||
csv_content = "student_id,unit_id,precheck,postcheck,gain\n"
|
||||
for row in data:
|
||||
csv_content += f"{row['student_id']},{row['unit_id']},{row.get('precheck', '')},{row.get('postcheck', '')},{row.get('gain', '')}\n"
|
||||
|
||||
return Response(
|
||||
content=csv_content,
|
||||
media_type="text/csv",
|
||||
headers={"Content-Disposition": "attachment; filename=learning_gains.csv"}
|
||||
)
|
||||
|
||||
return {
|
||||
"export_date": datetime.utcnow().isoformat(),
|
||||
"filters": {
|
||||
"unit_id": unit_id, "class_id": class_id, "time_range": time_range.value,
|
||||
},
|
||||
"data": data,
|
||||
}
|
||||
|
||||
|
||||
@router.get("/export/misconceptions")
|
||||
async def export_misconceptions(
|
||||
class_id: Optional[str] = Query(None),
|
||||
format: ExportFormat = Query(ExportFormat.JSON),
|
||||
) -> Any:
|
||||
"""
|
||||
Export misconception data for further analysis.
|
||||
"""
|
||||
# Import here to avoid circular dependency
|
||||
from unit_analytics_routes import get_misconception_report
|
||||
|
||||
report = await get_misconception_report(
|
||||
class_id=class_id, unit_id=None,
|
||||
time_range=TimeRange.MONTH, limit=100
|
||||
)
|
||||
|
||||
if format == ExportFormat.CSV:
|
||||
csv_content = "concept_id,concept_label,misconception,frequency,unit_id,stop_id\n"
|
||||
for m in report.most_common:
|
||||
csv_content += f'"{m.concept_id}","{m.concept_label}","{m.misconception_text}",{m.frequency},"{m.unit_id}","{m.stop_id}"\n'
|
||||
|
||||
return Response(
|
||||
content=csv_content,
|
||||
media_type="text/csv",
|
||||
headers={"Content-Disposition": "attachment; filename=misconceptions.csv"}
|
||||
)
|
||||
|
||||
return {
|
||||
"export_date": datetime.utcnow().isoformat(),
|
||||
"class_id": class_id,
|
||||
"total_entries": len(report.most_common),
|
||||
"data": [m.model_dump() for m in report.most_common],
|
||||
}
|
||||
|
||||
|
||||
# ==============================================
|
||||
# API Endpoints - Dashboard Aggregates
|
||||
# ==============================================
|
||||
|
||||
@router.get("/dashboard/overview")
|
||||
async def get_analytics_overview(
|
||||
time_range: TimeRange = Query(TimeRange.MONTH),
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Get high-level analytics overview for dashboard.
|
||||
"""
|
||||
db = await get_analytics_database()
|
||||
|
||||
if db:
|
||||
try:
|
||||
overview = await db.get_analytics_overview(time_range.value)
|
||||
return overview
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get analytics overview: {e}")
|
||||
|
||||
return {
|
||||
"time_range": time_range.value,
|
||||
"total_sessions": 0,
|
||||
"unique_students": 0,
|
||||
"avg_completion_rate": 0.0,
|
||||
"avg_learning_gain": 0.0,
|
||||
"most_played_units": [],
|
||||
"struggling_concepts": [],
|
||||
"active_classes": 0,
|
||||
}
|
||||
|
||||
|
||||
@router.get("/health")
|
||||
async def health_check() -> Dict[str, Any]:
|
||||
"""Health check for analytics API."""
|
||||
db = await get_analytics_database()
|
||||
return {
|
||||
"status": "healthy",
|
||||
"service": "unit-analytics",
|
||||
"database": "connected" if db else "disconnected",
|
||||
}
|
||||
@@ -0,0 +1,97 @@
|
||||
"""
|
||||
Unit Analytics API - Helpers.
|
||||
|
||||
Database access, statistical computation, and utility functions.
|
||||
"""
|
||||
|
||||
import os
|
||||
import logging
|
||||
from typing import List, Dict, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Feature flags
|
||||
USE_DATABASE = os.getenv("GAME_USE_DATABASE", "true").lower() == "true"
|
||||
|
||||
# Database singleton
|
||||
_analytics_db = None
|
||||
|
||||
|
||||
async def get_analytics_database():
|
||||
"""Get analytics database instance."""
|
||||
global _analytics_db
|
||||
if not USE_DATABASE:
|
||||
return None
|
||||
if _analytics_db is None:
|
||||
try:
|
||||
from unit.database import get_analytics_db
|
||||
_analytics_db = await get_analytics_db()
|
||||
logger.info("Analytics database initialized")
|
||||
except ImportError:
|
||||
logger.warning("Analytics database module not available")
|
||||
except Exception as e:
|
||||
logger.warning(f"Analytics database not available: {e}")
|
||||
return _analytics_db
|
||||
|
||||
|
||||
def calculate_gain_distribution(gains: List[float]) -> Dict[str, int]:
|
||||
"""Calculate distribution of learning gains into buckets."""
|
||||
distribution = {
|
||||
"< -20%": 0,
|
||||
"-20% to -10%": 0,
|
||||
"-10% to 0%": 0,
|
||||
"0% to 10%": 0,
|
||||
"10% to 20%": 0,
|
||||
"> 20%": 0,
|
||||
}
|
||||
|
||||
for gain in gains:
|
||||
gain_percent = gain * 100
|
||||
if gain_percent < -20:
|
||||
distribution["< -20%"] += 1
|
||||
elif gain_percent < -10:
|
||||
distribution["-20% to -10%"] += 1
|
||||
elif gain_percent < 0:
|
||||
distribution["-10% to 0%"] += 1
|
||||
elif gain_percent < 10:
|
||||
distribution["0% to 10%"] += 1
|
||||
elif gain_percent < 20:
|
||||
distribution["10% to 20%"] += 1
|
||||
else:
|
||||
distribution["> 20%"] += 1
|
||||
|
||||
return distribution
|
||||
|
||||
|
||||
def calculate_trend(scores: List[float]) -> str:
|
||||
"""Calculate trend from a series of scores."""
|
||||
if len(scores) < 3:
|
||||
return "insufficient_data"
|
||||
|
||||
# Simple linear regression
|
||||
n = len(scores)
|
||||
x_mean = (n - 1) / 2
|
||||
y_mean = sum(scores) / n
|
||||
|
||||
numerator = sum((i - x_mean) * (scores[i] - y_mean) for i in range(n))
|
||||
denominator = sum((i - x_mean) ** 2 for i in range(n))
|
||||
|
||||
if denominator == 0:
|
||||
return "stable"
|
||||
|
||||
slope = numerator / denominator
|
||||
|
||||
if slope > 0.05:
|
||||
return "improving"
|
||||
elif slope < -0.05:
|
||||
return "declining"
|
||||
else:
|
||||
return "stable"
|
||||
|
||||
|
||||
def calculate_difficulty_rating(success_rate: float, avg_attempts: float) -> float:
|
||||
"""Calculate difficulty rating 1-5 based on success metrics."""
|
||||
# Lower success rate and higher attempts = higher difficulty
|
||||
base_difficulty = (1 - success_rate) * 3 + 1 # 1-4 range
|
||||
attempt_modifier = min(avg_attempts - 1, 1) # 0-1 range
|
||||
return min(5.0, base_difficulty + attempt_modifier)
|
||||
@@ -0,0 +1,127 @@
|
||||
"""
|
||||
Unit Analytics API - Pydantic Models.
|
||||
|
||||
Data models for learning gains, stop performance, misconceptions,
|
||||
student progress, class comparison, and export.
|
||||
"""
|
||||
|
||||
from typing import List, Optional, Dict, Any
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class TimeRange(str, Enum):
|
||||
"""Time range for analytics queries"""
|
||||
WEEK = "week"
|
||||
MONTH = "month"
|
||||
QUARTER = "quarter"
|
||||
ALL = "all"
|
||||
|
||||
|
||||
class LearningGainData(BaseModel):
|
||||
"""Pre/Post learning gain data point"""
|
||||
student_id: str
|
||||
student_name: str
|
||||
unit_id: str
|
||||
precheck_score: float
|
||||
postcheck_score: float
|
||||
learning_gain: float
|
||||
percentile: Optional[float] = None
|
||||
|
||||
|
||||
class LearningGainSummary(BaseModel):
|
||||
"""Aggregated learning gain statistics"""
|
||||
unit_id: str
|
||||
unit_title: str
|
||||
total_students: int
|
||||
avg_precheck: float
|
||||
avg_postcheck: float
|
||||
avg_gain: float
|
||||
median_gain: float
|
||||
std_deviation: float
|
||||
positive_gain_count: int
|
||||
negative_gain_count: int
|
||||
no_change_count: int
|
||||
gain_distribution: Dict[str, int]
|
||||
individual_gains: List[LearningGainData]
|
||||
|
||||
|
||||
class StopPerformance(BaseModel):
|
||||
"""Performance data for a single stop"""
|
||||
stop_id: str
|
||||
stop_label: str
|
||||
attempts_total: int
|
||||
success_rate: float
|
||||
avg_time_seconds: float
|
||||
avg_attempts_before_success: float
|
||||
common_errors: List[str]
|
||||
difficulty_rating: float # 1-5 based on performance
|
||||
|
||||
|
||||
class UnitPerformanceDetail(BaseModel):
|
||||
"""Detailed unit performance breakdown"""
|
||||
unit_id: str
|
||||
unit_title: str
|
||||
template: str
|
||||
total_sessions: int
|
||||
completed_sessions: int
|
||||
completion_rate: float
|
||||
avg_duration_minutes: float
|
||||
stops: List[StopPerformance]
|
||||
bottleneck_stops: List[str] # Stops where students struggle most
|
||||
|
||||
|
||||
class MisconceptionEntry(BaseModel):
|
||||
"""Individual misconception tracking"""
|
||||
concept_id: str
|
||||
concept_label: str
|
||||
misconception_text: str
|
||||
frequency: int
|
||||
affected_student_ids: List[str]
|
||||
unit_id: str
|
||||
stop_id: str
|
||||
detected_via: str # "precheck", "postcheck", "interaction"
|
||||
first_detected: datetime
|
||||
last_detected: datetime
|
||||
|
||||
|
||||
class MisconceptionReport(BaseModel):
|
||||
"""Comprehensive misconception report"""
|
||||
class_id: Optional[str]
|
||||
time_range: str
|
||||
total_misconceptions: int
|
||||
unique_concepts: int
|
||||
most_common: List[MisconceptionEntry]
|
||||
by_unit: Dict[str, List[MisconceptionEntry]]
|
||||
trending_up: List[MisconceptionEntry] # Getting more frequent
|
||||
resolved: List[MisconceptionEntry] # No longer appearing
|
||||
|
||||
|
||||
class StudentProgressTimeline(BaseModel):
|
||||
"""Timeline of student progress"""
|
||||
student_id: str
|
||||
student_name: str
|
||||
units_completed: int
|
||||
total_time_minutes: int
|
||||
avg_score: float
|
||||
trend: str # "improving", "stable", "declining"
|
||||
timeline: List[Dict[str, Any]] # List of session events
|
||||
|
||||
|
||||
class ClassComparisonData(BaseModel):
|
||||
"""Data for comparing class performance"""
|
||||
class_id: str
|
||||
class_name: str
|
||||
student_count: int
|
||||
units_assigned: int
|
||||
avg_completion_rate: float
|
||||
avg_learning_gain: float
|
||||
avg_time_per_unit: float
|
||||
|
||||
|
||||
class ExportFormat(str, Enum):
|
||||
"""Export format options"""
|
||||
JSON = "json"
|
||||
CSV = "csv"
|
||||
@@ -0,0 +1,394 @@
|
||||
"""
|
||||
Unit Analytics API - Routes.
|
||||
|
||||
All API endpoints for learning gain, stop-level, misconception,
|
||||
student timeline, class comparison, export, and dashboard analytics.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import statistics
|
||||
from datetime import datetime
|
||||
from typing import Optional, Dict, Any, List
|
||||
|
||||
from fastapi import APIRouter, Query
|
||||
|
||||
from unit_analytics_models import (
|
||||
TimeRange,
|
||||
LearningGainData,
|
||||
LearningGainSummary,
|
||||
StopPerformance,
|
||||
UnitPerformanceDetail,
|
||||
MisconceptionEntry,
|
||||
MisconceptionReport,
|
||||
StudentProgressTimeline,
|
||||
ClassComparisonData,
|
||||
)
|
||||
from unit_analytics_helpers import (
|
||||
get_analytics_database,
|
||||
calculate_gain_distribution,
|
||||
calculate_trend,
|
||||
calculate_difficulty_rating,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(tags=["Unit Analytics"])
|
||||
|
||||
|
||||
# ==============================================
|
||||
# API Endpoints - Learning Gain
|
||||
# ==============================================
|
||||
|
||||
# NOTE: Static routes must come BEFORE dynamic routes like /{unit_id}
|
||||
@router.get("/learning-gain/compare")
|
||||
async def compare_learning_gains(
|
||||
unit_ids: str = Query(..., description="Comma-separated unit IDs"),
|
||||
class_id: Optional[str] = Query(None),
|
||||
time_range: TimeRange = Query(TimeRange.MONTH),
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Compare learning gains across multiple units.
|
||||
"""
|
||||
unit_list = [u.strip() for u in unit_ids.split(",")]
|
||||
comparisons = []
|
||||
|
||||
for unit_id in unit_list:
|
||||
try:
|
||||
summary = await get_learning_gain_analysis(unit_id, class_id, time_range)
|
||||
comparisons.append({
|
||||
"unit_id": unit_id,
|
||||
"avg_gain": summary.avg_gain,
|
||||
"median_gain": summary.median_gain,
|
||||
"total_students": summary.total_students,
|
||||
"positive_rate": summary.positive_gain_count / max(summary.total_students, 1),
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get comparison for {unit_id}: {e}")
|
||||
|
||||
return {
|
||||
"time_range": time_range.value,
|
||||
"class_id": class_id,
|
||||
"comparisons": sorted(comparisons, key=lambda x: x["avg_gain"], reverse=True),
|
||||
}
|
||||
|
||||
|
||||
@router.get("/learning-gain/{unit_id}", response_model=LearningGainSummary)
|
||||
async def get_learning_gain_analysis(
|
||||
unit_id: str,
|
||||
class_id: Optional[str] = Query(None, description="Filter by class"),
|
||||
time_range: TimeRange = Query(TimeRange.MONTH, description="Time range for analysis"),
|
||||
) -> LearningGainSummary:
|
||||
"""
|
||||
Get detailed pre/post learning gain analysis for a unit.
|
||||
"""
|
||||
db = await get_analytics_database()
|
||||
individual_gains = []
|
||||
|
||||
if db:
|
||||
try:
|
||||
sessions = await db.get_unit_sessions_with_scores(
|
||||
unit_id=unit_id,
|
||||
class_id=class_id,
|
||||
time_range=time_range.value
|
||||
)
|
||||
|
||||
for session in sessions:
|
||||
if session.get("precheck_score") is not None and session.get("postcheck_score") is not None:
|
||||
gain = session["postcheck_score"] - session["precheck_score"]
|
||||
individual_gains.append(LearningGainData(
|
||||
student_id=session["student_id"],
|
||||
student_name=session.get("student_name", session["student_id"][:8]),
|
||||
unit_id=unit_id,
|
||||
precheck_score=session["precheck_score"],
|
||||
postcheck_score=session["postcheck_score"],
|
||||
learning_gain=gain,
|
||||
))
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get learning gain data: {e}")
|
||||
|
||||
# Calculate statistics
|
||||
if not individual_gains:
|
||||
return LearningGainSummary(
|
||||
unit_id=unit_id,
|
||||
unit_title=f"Unit {unit_id}",
|
||||
total_students=0,
|
||||
avg_precheck=0.0, avg_postcheck=0.0,
|
||||
avg_gain=0.0, median_gain=0.0, std_deviation=0.0,
|
||||
positive_gain_count=0, negative_gain_count=0, no_change_count=0,
|
||||
gain_distribution={}, individual_gains=[],
|
||||
)
|
||||
|
||||
gains = [g.learning_gain for g in individual_gains]
|
||||
prechecks = [g.precheck_score for g in individual_gains]
|
||||
postchecks = [g.postcheck_score for g in individual_gains]
|
||||
|
||||
avg_gain = statistics.mean(gains)
|
||||
median_gain = statistics.median(gains)
|
||||
std_dev = statistics.stdev(gains) if len(gains) > 1 else 0.0
|
||||
|
||||
# Calculate percentiles
|
||||
sorted_gains = sorted(gains)
|
||||
for data in individual_gains:
|
||||
rank = sorted_gains.index(data.learning_gain) + 1
|
||||
data.percentile = rank / len(sorted_gains) * 100
|
||||
|
||||
return LearningGainSummary(
|
||||
unit_id=unit_id,
|
||||
unit_title=f"Unit {unit_id}",
|
||||
total_students=len(individual_gains),
|
||||
avg_precheck=statistics.mean(prechecks),
|
||||
avg_postcheck=statistics.mean(postchecks),
|
||||
avg_gain=avg_gain,
|
||||
median_gain=median_gain,
|
||||
std_deviation=std_dev,
|
||||
positive_gain_count=sum(1 for g in gains if g > 0.01),
|
||||
negative_gain_count=sum(1 for g in gains if g < -0.01),
|
||||
no_change_count=sum(1 for g in gains if -0.01 <= g <= 0.01),
|
||||
gain_distribution=calculate_gain_distribution(gains),
|
||||
individual_gains=sorted(individual_gains, key=lambda x: x.learning_gain, reverse=True),
|
||||
)
|
||||
|
||||
|
||||
# ==============================================
|
||||
# API Endpoints - Stop-Level Analytics
|
||||
# ==============================================
|
||||
|
||||
@router.get("/unit/{unit_id}/stops", response_model=UnitPerformanceDetail)
|
||||
async def get_unit_stop_analytics(
|
||||
unit_id: str,
|
||||
class_id: Optional[str] = Query(None),
|
||||
time_range: TimeRange = Query(TimeRange.MONTH),
|
||||
) -> UnitPerformanceDetail:
|
||||
"""
|
||||
Get detailed stop-level performance analytics.
|
||||
"""
|
||||
db = await get_analytics_database()
|
||||
stops_data = []
|
||||
|
||||
if db:
|
||||
try:
|
||||
stop_stats = await db.get_stop_performance(
|
||||
unit_id=unit_id, class_id=class_id, time_range=time_range.value
|
||||
)
|
||||
|
||||
for stop in stop_stats:
|
||||
difficulty = calculate_difficulty_rating(
|
||||
stop.get("success_rate", 0.5),
|
||||
stop.get("avg_attempts", 1.0)
|
||||
)
|
||||
stops_data.append(StopPerformance(
|
||||
stop_id=stop["stop_id"],
|
||||
stop_label=stop.get("stop_label", stop["stop_id"]),
|
||||
attempts_total=stop.get("total_attempts", 0),
|
||||
success_rate=stop.get("success_rate", 0.0),
|
||||
avg_time_seconds=stop.get("avg_time_seconds", 0.0),
|
||||
avg_attempts_before_success=stop.get("avg_attempts", 1.0),
|
||||
common_errors=stop.get("common_errors", []),
|
||||
difficulty_rating=difficulty,
|
||||
))
|
||||
|
||||
unit_stats = await db.get_unit_overall_stats(unit_id, class_id, time_range.value)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get stop analytics: {e}")
|
||||
unit_stats = {}
|
||||
else:
|
||||
unit_stats = {}
|
||||
|
||||
# Identify bottleneck stops
|
||||
bottlenecks = [
|
||||
s.stop_id for s in stops_data
|
||||
if s.difficulty_rating > 3.5 or s.success_rate < 0.6
|
||||
]
|
||||
|
||||
return UnitPerformanceDetail(
|
||||
unit_id=unit_id,
|
||||
unit_title=f"Unit {unit_id}",
|
||||
template=unit_stats.get("template", "unknown"),
|
||||
total_sessions=unit_stats.get("total_sessions", 0),
|
||||
completed_sessions=unit_stats.get("completed_sessions", 0),
|
||||
completion_rate=unit_stats.get("completion_rate", 0.0),
|
||||
avg_duration_minutes=unit_stats.get("avg_duration_minutes", 0.0),
|
||||
stops=stops_data,
|
||||
bottleneck_stops=bottlenecks,
|
||||
)
|
||||
|
||||
|
||||
# ==============================================
|
||||
# API Endpoints - Misconception Tracking
|
||||
# ==============================================
|
||||
|
||||
@router.get("/misconceptions", response_model=MisconceptionReport)
|
||||
async def get_misconception_report(
|
||||
class_id: Optional[str] = Query(None),
|
||||
unit_id: Optional[str] = Query(None),
|
||||
time_range: TimeRange = Query(TimeRange.MONTH),
|
||||
limit: int = Query(20, ge=1, le=100),
|
||||
) -> MisconceptionReport:
|
||||
"""
|
||||
Get comprehensive misconception report.
|
||||
"""
|
||||
db = await get_analytics_database()
|
||||
misconceptions = []
|
||||
|
||||
if db:
|
||||
try:
|
||||
raw_misconceptions = await db.get_misconceptions(
|
||||
class_id=class_id, unit_id=unit_id,
|
||||
time_range=time_range.value, limit=limit
|
||||
)
|
||||
|
||||
for m in raw_misconceptions:
|
||||
misconceptions.append(MisconceptionEntry(
|
||||
concept_id=m["concept_id"],
|
||||
concept_label=m["concept_label"],
|
||||
misconception_text=m["misconception_text"],
|
||||
frequency=m["frequency"],
|
||||
affected_student_ids=m.get("student_ids", []),
|
||||
unit_id=m["unit_id"],
|
||||
stop_id=m["stop_id"],
|
||||
detected_via=m.get("detected_via", "unknown"),
|
||||
first_detected=m.get("first_detected", datetime.utcnow()),
|
||||
last_detected=m.get("last_detected", datetime.utcnow()),
|
||||
))
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get misconceptions: {e}")
|
||||
|
||||
# Group by unit
|
||||
by_unit = {}
|
||||
for m in misconceptions:
|
||||
if m.unit_id not in by_unit:
|
||||
by_unit[m.unit_id] = []
|
||||
by_unit[m.unit_id].append(m)
|
||||
|
||||
trending_up = misconceptions[:3] if misconceptions else []
|
||||
resolved = []
|
||||
|
||||
return MisconceptionReport(
|
||||
class_id=class_id,
|
||||
time_range=time_range.value,
|
||||
total_misconceptions=sum(m.frequency for m in misconceptions),
|
||||
unique_concepts=len(set(m.concept_id for m in misconceptions)),
|
||||
most_common=sorted(misconceptions, key=lambda x: x.frequency, reverse=True)[:10],
|
||||
by_unit=by_unit,
|
||||
trending_up=trending_up,
|
||||
resolved=resolved,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/misconceptions/student/{student_id}")
|
||||
async def get_student_misconceptions(
|
||||
student_id: str,
|
||||
time_range: TimeRange = Query(TimeRange.ALL),
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Get misconceptions for a specific student.
|
||||
"""
|
||||
db = await get_analytics_database()
|
||||
|
||||
if db:
|
||||
try:
|
||||
misconceptions = await db.get_student_misconceptions(
|
||||
student_id=student_id, time_range=time_range.value
|
||||
)
|
||||
return {
|
||||
"student_id": student_id,
|
||||
"misconceptions": misconceptions,
|
||||
"recommended_remediation": [
|
||||
{"concept": m["concept_label"], "activity": f"Review {m['unit_id']}/{m['stop_id']}"}
|
||||
for m in misconceptions[:5]
|
||||
]
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get student misconceptions: {e}")
|
||||
|
||||
return {
|
||||
"student_id": student_id,
|
||||
"misconceptions": [],
|
||||
"recommended_remediation": [],
|
||||
}
|
||||
|
||||
|
||||
# ==============================================
|
||||
# API Endpoints - Student Progress Timeline
|
||||
# ==============================================
|
||||
|
||||
@router.get("/student/{student_id}/timeline", response_model=StudentProgressTimeline)
|
||||
async def get_student_timeline(
|
||||
student_id: str,
|
||||
time_range: TimeRange = Query(TimeRange.ALL),
|
||||
) -> StudentProgressTimeline:
|
||||
"""
|
||||
Get detailed progress timeline for a student.
|
||||
"""
|
||||
db = await get_analytics_database()
|
||||
timeline = []
|
||||
scores = []
|
||||
|
||||
if db:
|
||||
try:
|
||||
sessions = await db.get_student_sessions(
|
||||
student_id=student_id, time_range=time_range.value
|
||||
)
|
||||
|
||||
for session in sessions:
|
||||
timeline.append({
|
||||
"date": session.get("started_at"),
|
||||
"unit_id": session.get("unit_id"),
|
||||
"completed": session.get("completed_at") is not None,
|
||||
"precheck": session.get("precheck_score"),
|
||||
"postcheck": session.get("postcheck_score"),
|
||||
"duration_minutes": session.get("duration_seconds", 0) // 60,
|
||||
})
|
||||
if session.get("postcheck_score") is not None:
|
||||
scores.append(session["postcheck_score"])
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get student timeline: {e}")
|
||||
|
||||
trend = calculate_trend(scores) if scores else "insufficient_data"
|
||||
|
||||
return StudentProgressTimeline(
|
||||
student_id=student_id,
|
||||
student_name=f"Student {student_id[:8]}",
|
||||
units_completed=sum(1 for t in timeline if t["completed"]),
|
||||
total_time_minutes=sum(t["duration_minutes"] for t in timeline),
|
||||
avg_score=statistics.mean(scores) if scores else 0.0,
|
||||
trend=trend,
|
||||
timeline=timeline,
|
||||
)
|
||||
|
||||
|
||||
# ==============================================
|
||||
# API Endpoints - Class Comparison
|
||||
# ==============================================
|
||||
|
||||
@router.get("/compare/classes", response_model=List[ClassComparisonData])
|
||||
async def compare_classes(
|
||||
class_ids: str = Query(..., description="Comma-separated class IDs"),
|
||||
time_range: TimeRange = Query(TimeRange.MONTH),
|
||||
) -> List[ClassComparisonData]:
|
||||
"""
|
||||
Compare performance across multiple classes.
|
||||
"""
|
||||
class_list = [c.strip() for c in class_ids.split(",")]
|
||||
comparisons = []
|
||||
|
||||
db = await get_analytics_database()
|
||||
if db:
|
||||
for class_id in class_list:
|
||||
try:
|
||||
stats = await db.get_class_aggregate_stats(class_id, time_range.value)
|
||||
comparisons.append(ClassComparisonData(
|
||||
class_id=class_id,
|
||||
class_name=stats.get("class_name", f"Klasse {class_id[:8]}"),
|
||||
student_count=stats.get("student_count", 0),
|
||||
units_assigned=stats.get("units_assigned", 0),
|
||||
avg_completion_rate=stats.get("avg_completion_rate", 0.0),
|
||||
avg_learning_gain=stats.get("avg_learning_gain", 0.0),
|
||||
avg_time_per_unit=stats.get("avg_time_per_unit", 0.0),
|
||||
))
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get stats for class {class_id}: {e}")
|
||||
|
||||
return sorted(comparisons, key=lambda x: x.avg_learning_gain, reverse=True)
|
||||
|
||||
|
||||
@@ -0,0 +1,200 @@
|
||||
"""
|
||||
Compliance Extraction & Generation.
|
||||
|
||||
Functions for extracting checkpoints from legal text chunks,
|
||||
generating controls, and creating remediation measures.
|
||||
"""
|
||||
|
||||
import re
|
||||
import hashlib
|
||||
import logging
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from compliance_models import Checkpoint, Control, Measure
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def extract_checkpoints_from_chunk(chunk_text: str, payload: Dict) -> List[Checkpoint]:
|
||||
"""
|
||||
Extract checkpoints/requirements from a chunk of text.
|
||||
|
||||
Uses pattern matching to find requirement-like statements.
|
||||
"""
|
||||
checkpoints = []
|
||||
regulation_code = payload.get("regulation_code", "UNKNOWN")
|
||||
regulation_name = payload.get("regulation_name", "Unknown")
|
||||
source_url = payload.get("source_url", "")
|
||||
chunk_id = hashlib.md5(chunk_text[:100].encode()).hexdigest()[:8]
|
||||
|
||||
# Patterns for different requirement types
|
||||
patterns = [
|
||||
# BSI-TR patterns
|
||||
(r'([OT]\.[A-Za-z_]+\d*)[:\s]+(.+?)(?=\n[OT]\.|$)', 'bsi_requirement'),
|
||||
# Article patterns (GDPR, AI Act, etc.)
|
||||
(r'(?:Artikel|Art\.?)\s+(\d+)(?:\s+Abs(?:atz)?\.?\s*(\d+))?\s*[-\u2013:]\s*(.+?)(?=\n|$)', 'article'),
|
||||
# Numbered requirements
|
||||
(r'\((\d+)\)\s+(.+?)(?=\n\(\d+\)|$)', 'numbered'),
|
||||
# "Der Verantwortliche muss" patterns
|
||||
(r'(?:Der Verantwortliche|Die Aufsichtsbeh\u00f6rde|Der Auftragsverarbeiter)\s+(muss|hat|soll)\s+(.+?)(?=\.\s|$)', 'obligation'),
|
||||
# "Es ist erforderlich" patterns
|
||||
(r'(?:Es ist erforderlich|Es muss gew\u00e4hrleistet|Es sind geeignete)\s+(.+?)(?=\.\s|$)', 'requirement'),
|
||||
]
|
||||
|
||||
for pattern, pattern_type in patterns:
|
||||
matches = re.finditer(pattern, chunk_text, re.MULTILINE | re.DOTALL)
|
||||
for match in matches:
|
||||
if pattern_type == 'bsi_requirement':
|
||||
req_id = match.group(1)
|
||||
description = match.group(2).strip()
|
||||
title = req_id
|
||||
elif pattern_type == 'article':
|
||||
article_num = match.group(1)
|
||||
paragraph = match.group(2) or ""
|
||||
title_text = match.group(3).strip()
|
||||
req_id = f"{regulation_code}-Art{article_num}"
|
||||
if paragraph:
|
||||
req_id += f"-{paragraph}"
|
||||
title = f"Art. {article_num}" + (f" Abs. {paragraph}" if paragraph else "")
|
||||
description = title_text
|
||||
elif pattern_type == 'numbered':
|
||||
num = match.group(1)
|
||||
description = match.group(2).strip()
|
||||
req_id = f"{regulation_code}-{num}"
|
||||
title = f"Anforderung {num}"
|
||||
else:
|
||||
# Generic requirement
|
||||
description = match.group(0).strip()
|
||||
req_id = f"{regulation_code}-{chunk_id}-{len(checkpoints)}"
|
||||
title = description[:50] + "..." if len(description) > 50 else description
|
||||
|
||||
# Skip very short matches
|
||||
if len(description) < 20:
|
||||
continue
|
||||
|
||||
checkpoint = Checkpoint(
|
||||
id=req_id,
|
||||
regulation_code=regulation_code,
|
||||
regulation_name=regulation_name,
|
||||
article=title if 'Art' in title else None,
|
||||
title=title,
|
||||
description=description[:500],
|
||||
original_text=description,
|
||||
chunk_id=chunk_id,
|
||||
source_url=source_url
|
||||
)
|
||||
checkpoints.append(checkpoint)
|
||||
|
||||
return checkpoints
|
||||
|
||||
|
||||
def generate_control_for_checkpoints(
|
||||
checkpoints: List[Checkpoint],
|
||||
domain_counts: Dict[str, int],
|
||||
) -> Optional[Control]:
|
||||
"""
|
||||
Generate a control that covers the given checkpoints.
|
||||
|
||||
This is a simplified version - in production this would use the AI assistant.
|
||||
"""
|
||||
if not checkpoints:
|
||||
return None
|
||||
|
||||
# Group by regulation
|
||||
regulation = checkpoints[0].regulation_code
|
||||
|
||||
# Determine domain based on content
|
||||
all_text = " ".join([cp.description for cp in checkpoints]).lower()
|
||||
|
||||
domain = "gov" # Default
|
||||
if any(kw in all_text for kw in ["verschl\u00fcssel", "krypto", "encrypt", "hash"]):
|
||||
domain = "crypto"
|
||||
elif any(kw in all_text for kw in ["zugang", "access", "authentif", "login", "benutzer"]):
|
||||
domain = "iam"
|
||||
elif any(kw in all_text for kw in ["datenschutz", "personenbezogen", "privacy", "einwilligung"]):
|
||||
domain = "priv"
|
||||
elif any(kw in all_text for kw in ["entwicklung", "test", "code", "software"]):
|
||||
domain = "sdlc"
|
||||
elif any(kw in all_text for kw in ["\u00fcberwach", "monitor", "log", "audit"]):
|
||||
domain = "aud"
|
||||
elif any(kw in all_text for kw in ["ki", "k\u00fcnstlich", "ai", "machine learning", "model"]):
|
||||
domain = "ai"
|
||||
elif any(kw in all_text for kw in ["betrieb", "operation", "verf\u00fcgbar", "backup"]):
|
||||
domain = "ops"
|
||||
elif any(kw in all_text for kw in ["cyber", "resilience", "sbom", "vulnerab"]):
|
||||
domain = "cra"
|
||||
|
||||
# Generate control ID
|
||||
domain_count = domain_counts.get(domain, 0) + 1
|
||||
control_id = f"{domain.upper()}-{domain_count:03d}"
|
||||
|
||||
# Create title from first checkpoint
|
||||
title = checkpoints[0].title
|
||||
if len(title) > 100:
|
||||
title = title[:97] + "..."
|
||||
|
||||
# Create description
|
||||
description = f"Control f\u00fcr {regulation}: " + checkpoints[0].description[:200]
|
||||
|
||||
# Pass criteria
|
||||
pass_criteria = f"Alle {len(checkpoints)} zugeh\u00f6rigen Anforderungen sind erf\u00fcllt und dokumentiert."
|
||||
|
||||
# Implementation guidance
|
||||
guidance = f"Implementiere Ma\u00dfnahmen zur Erf\u00fcllung der Anforderungen aus {regulation}. "
|
||||
guidance += f"Dokumentiere die Umsetzung und f\u00fchre regelm\u00e4\u00dfige Reviews durch."
|
||||
|
||||
# Determine if automated
|
||||
is_automated = any(kw in all_text for kw in ["automat", "tool", "scan", "test"])
|
||||
|
||||
control = Control(
|
||||
id=control_id,
|
||||
domain=domain,
|
||||
title=title,
|
||||
description=description,
|
||||
checkpoints=[cp.id for cp in checkpoints],
|
||||
pass_criteria=pass_criteria,
|
||||
implementation_guidance=guidance,
|
||||
is_automated=is_automated,
|
||||
automation_tool="CI/CD Pipeline" if is_automated else None,
|
||||
priority="high" if "muss" in all_text or "erforderlich" in all_text else "medium"
|
||||
)
|
||||
|
||||
return control
|
||||
|
||||
|
||||
def generate_measure_for_control(control: Control) -> Measure:
|
||||
"""Generate a remediation measure for a control."""
|
||||
measure_id = f"M-{control.id}"
|
||||
|
||||
# Determine deadline based on priority
|
||||
deadline_days = {
|
||||
"critical": 30,
|
||||
"high": 60,
|
||||
"medium": 90,
|
||||
"low": 180
|
||||
}.get(control.priority, 90)
|
||||
|
||||
# Determine responsible team
|
||||
responsible = {
|
||||
"priv": "Datenschutzbeauftragter",
|
||||
"iam": "IT-Security Team",
|
||||
"sdlc": "Entwicklungsteam",
|
||||
"crypto": "IT-Security Team",
|
||||
"ops": "Operations Team",
|
||||
"aud": "Compliance Team",
|
||||
"ai": "AI/ML Team",
|
||||
"cra": "IT-Security Team",
|
||||
"gov": "Management"
|
||||
}.get(control.domain, "Compliance Team")
|
||||
|
||||
measure = Measure(
|
||||
id=measure_id,
|
||||
control_id=control.id,
|
||||
title=f"Umsetzung: {control.title[:50]}",
|
||||
description=f"Implementierung und Dokumentation von {control.id}: {control.description[:100]}",
|
||||
responsible=responsible,
|
||||
deadline_days=deadline_days,
|
||||
status="pending"
|
||||
)
|
||||
|
||||
return measure
|
||||
@@ -0,0 +1,49 @@
|
||||
"""
|
||||
Compliance Pipeline Data Models.
|
||||
|
||||
Dataclasses for checkpoints, controls, and measures.
|
||||
"""
|
||||
|
||||
from typing import Optional, List
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class Checkpoint:
|
||||
"""A requirement/checkpoint extracted from legal text."""
|
||||
id: str
|
||||
regulation_code: str
|
||||
regulation_name: str
|
||||
article: Optional[str]
|
||||
title: str
|
||||
description: str
|
||||
original_text: str
|
||||
chunk_id: str
|
||||
source_url: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class Control:
|
||||
"""A control derived from checkpoints."""
|
||||
id: str
|
||||
domain: str
|
||||
title: str
|
||||
description: str
|
||||
checkpoints: List[str] # List of checkpoint IDs
|
||||
pass_criteria: str
|
||||
implementation_guidance: str
|
||||
is_automated: bool
|
||||
automation_tool: Optional[str]
|
||||
priority: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class Measure:
|
||||
"""A remediation measure for a control."""
|
||||
id: str
|
||||
control_id: str
|
||||
title: str
|
||||
description: str
|
||||
responsible: str
|
||||
deadline_days: int
|
||||
status: str
|
||||
@@ -0,0 +1,441 @@
|
||||
"""
|
||||
Compliance Pipeline Execution.
|
||||
|
||||
Pipeline phases (ingestion, extraction, control generation, measures)
|
||||
and orchestration logic.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Any
|
||||
from dataclasses import asdict
|
||||
|
||||
from compliance_models import Checkpoint, Control, Measure
|
||||
from compliance_extraction import (
|
||||
extract_checkpoints_from_chunk,
|
||||
generate_control_for_checkpoints,
|
||||
generate_measure_for_control,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Import checkpoint manager
|
||||
try:
|
||||
from pipeline_checkpoints import CheckpointManager, EXPECTED_VALUES, ValidationStatus
|
||||
except ImportError:
|
||||
logger.warning("Checkpoint manager not available, running without checkpoints")
|
||||
CheckpointManager = None
|
||||
EXPECTED_VALUES = {}
|
||||
ValidationStatus = None
|
||||
|
||||
# Set environment variables for Docker network
|
||||
if not os.getenv("QDRANT_URL") and not os.getenv("QDRANT_HOST"):
|
||||
os.environ["QDRANT_HOST"] = "qdrant"
|
||||
os.environ.setdefault("EMBEDDING_SERVICE_URL", "http://embedding-service:8087")
|
||||
|
||||
# Try to import from klausur-service
|
||||
try:
|
||||
from legal_corpus_ingestion import LegalCorpusIngestion, REGULATIONS, LEGAL_CORPUS_COLLECTION
|
||||
from qdrant_client import QdrantClient
|
||||
from qdrant_client.models import Filter, FieldCondition, MatchValue
|
||||
except ImportError:
|
||||
logger.error("Could not import required modules. Make sure you're in the klausur-service container.")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
class CompliancePipeline:
|
||||
"""Handles the full compliance pipeline."""
|
||||
|
||||
def __init__(self):
|
||||
# Support both QDRANT_URL and QDRANT_HOST/PORT
|
||||
qdrant_url = os.getenv("QDRANT_URL", "")
|
||||
if qdrant_url:
|
||||
from urllib.parse import urlparse
|
||||
parsed = urlparse(qdrant_url)
|
||||
qdrant_host = parsed.hostname or "qdrant"
|
||||
qdrant_port = parsed.port or 6333
|
||||
else:
|
||||
qdrant_host = os.getenv("QDRANT_HOST", "qdrant")
|
||||
qdrant_port = 6333
|
||||
self.qdrant = QdrantClient(host=qdrant_host, port=qdrant_port)
|
||||
self.checkpoints: List[Checkpoint] = []
|
||||
self.controls: List[Control] = []
|
||||
self.measures: List[Measure] = []
|
||||
self.stats = {
|
||||
"chunks_processed": 0,
|
||||
"checkpoints_extracted": 0,
|
||||
"controls_created": 0,
|
||||
"measures_defined": 0,
|
||||
"by_regulation": {},
|
||||
"by_domain": {},
|
||||
}
|
||||
# Initialize checkpoint manager
|
||||
self.checkpoint_mgr = CheckpointManager() if CheckpointManager else None
|
||||
|
||||
async def run_ingestion_phase(self, force_reindex: bool = False) -> int:
|
||||
"""Phase 1: Ingest documents (incremental - only missing ones)."""
|
||||
logger.info("\n" + "=" * 60)
|
||||
logger.info("PHASE 1: DOCUMENT INGESTION (INCREMENTAL)")
|
||||
logger.info("=" * 60)
|
||||
|
||||
if self.checkpoint_mgr:
|
||||
self.checkpoint_mgr.start_checkpoint("ingestion", "Document Ingestion")
|
||||
|
||||
ingestion = LegalCorpusIngestion()
|
||||
|
||||
try:
|
||||
# Check existing chunks per regulation
|
||||
existing_chunks = {}
|
||||
try:
|
||||
for regulation in REGULATIONS:
|
||||
count_result = self.qdrant.count(
|
||||
collection_name=LEGAL_CORPUS_COLLECTION,
|
||||
count_filter=Filter(
|
||||
must=[FieldCondition(key="regulation_code", match=MatchValue(value=regulation.code))]
|
||||
)
|
||||
)
|
||||
existing_chunks[regulation.code] = count_result.count
|
||||
logger.info(f" {regulation.code}: {count_result.count} existing chunks")
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not check existing chunks: {e}")
|
||||
|
||||
# Determine which regulations need ingestion
|
||||
regulations_to_ingest = []
|
||||
for regulation in REGULATIONS:
|
||||
existing = existing_chunks.get(regulation.code, 0)
|
||||
if force_reindex or existing == 0:
|
||||
regulations_to_ingest.append(regulation)
|
||||
logger.info(f" -> Will ingest: {regulation.code} (existing: {existing}, force: {force_reindex})")
|
||||
else:
|
||||
logger.info(f" -> Skipping: {regulation.code} (already has {existing} chunks)")
|
||||
self.stats["by_regulation"][regulation.code] = existing
|
||||
|
||||
if not regulations_to_ingest:
|
||||
logger.info("All regulations already indexed. Skipping ingestion phase.")
|
||||
total_chunks = sum(existing_chunks.values())
|
||||
self.stats["chunks_processed"] = total_chunks
|
||||
if self.checkpoint_mgr:
|
||||
self.checkpoint_mgr.add_metric("total_chunks", total_chunks)
|
||||
self.checkpoint_mgr.add_metric("skipped", True)
|
||||
self.checkpoint_mgr.complete_checkpoint(success=True)
|
||||
return total_chunks
|
||||
|
||||
# Ingest only missing regulations
|
||||
total_chunks = sum(existing_chunks.values())
|
||||
for i, regulation in enumerate(regulations_to_ingest, 1):
|
||||
logger.info(f"[{i}/{len(regulations_to_ingest)}] Ingesting {regulation.code}...")
|
||||
try:
|
||||
count = await ingestion.ingest_regulation(regulation)
|
||||
total_chunks += count
|
||||
self.stats["by_regulation"][regulation.code] = count
|
||||
logger.info(f" -> {count} chunks")
|
||||
|
||||
if self.checkpoint_mgr:
|
||||
self.checkpoint_mgr.add_metric(f"chunks_{regulation.code}", count)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f" -> FAILED: {e}")
|
||||
self.stats["by_regulation"][regulation.code] = 0
|
||||
|
||||
self.stats["chunks_processed"] = total_chunks
|
||||
logger.info(f"\nTotal chunks in collection: {total_chunks}")
|
||||
|
||||
# Validate ingestion results
|
||||
if self.checkpoint_mgr:
|
||||
self.checkpoint_mgr.add_metric("total_chunks", total_chunks)
|
||||
self.checkpoint_mgr.add_metric("regulations_count", len(REGULATIONS))
|
||||
|
||||
expected = EXPECTED_VALUES.get("ingestion", {})
|
||||
self.checkpoint_mgr.validate(
|
||||
"total_chunks",
|
||||
expected=expected.get("total_chunks", 8000),
|
||||
actual=total_chunks,
|
||||
min_value=expected.get("min_chunks", 7000)
|
||||
)
|
||||
|
||||
reg_expected = expected.get("regulations", {})
|
||||
for reg_code, reg_exp in reg_expected.items():
|
||||
actual = self.stats["by_regulation"].get(reg_code, 0)
|
||||
self.checkpoint_mgr.validate(
|
||||
f"chunks_{reg_code}",
|
||||
expected=reg_exp.get("expected", 0),
|
||||
actual=actual,
|
||||
min_value=reg_exp.get("min", 0)
|
||||
)
|
||||
|
||||
self.checkpoint_mgr.complete_checkpoint(success=True)
|
||||
|
||||
return total_chunks
|
||||
|
||||
except Exception as e:
|
||||
if self.checkpoint_mgr:
|
||||
self.checkpoint_mgr.fail_checkpoint(str(e))
|
||||
raise
|
||||
|
||||
finally:
|
||||
await ingestion.close()
|
||||
|
||||
async def run_extraction_phase(self) -> int:
|
||||
"""Phase 2: Extract checkpoints from chunks."""
|
||||
logger.info("\n" + "=" * 60)
|
||||
logger.info("PHASE 2: CHECKPOINT EXTRACTION")
|
||||
logger.info("=" * 60)
|
||||
|
||||
if self.checkpoint_mgr:
|
||||
self.checkpoint_mgr.start_checkpoint("extraction", "Checkpoint Extraction")
|
||||
|
||||
try:
|
||||
offset = None
|
||||
total_checkpoints = 0
|
||||
|
||||
while True:
|
||||
result = self.qdrant.scroll(
|
||||
collection_name=LEGAL_CORPUS_COLLECTION,
|
||||
limit=100,
|
||||
offset=offset,
|
||||
with_payload=True,
|
||||
with_vectors=False
|
||||
)
|
||||
|
||||
points, next_offset = result
|
||||
|
||||
if not points:
|
||||
break
|
||||
|
||||
for point in points:
|
||||
payload = point.payload
|
||||
text = payload.get("text", "")
|
||||
|
||||
cps = extract_checkpoints_from_chunk(text, payload)
|
||||
self.checkpoints.extend(cps)
|
||||
total_checkpoints += len(cps)
|
||||
|
||||
logger.info(f"Processed {len(points)} chunks, extracted {total_checkpoints} checkpoints so far...")
|
||||
|
||||
if next_offset is None:
|
||||
break
|
||||
offset = next_offset
|
||||
|
||||
self.stats["checkpoints_extracted"] = len(self.checkpoints)
|
||||
logger.info(f"\nTotal checkpoints extracted: {len(self.checkpoints)}")
|
||||
|
||||
by_reg = {}
|
||||
for cp in self.checkpoints:
|
||||
by_reg[cp.regulation_code] = by_reg.get(cp.regulation_code, 0) + 1
|
||||
for reg, count in sorted(by_reg.items()):
|
||||
logger.info(f" {reg}: {count} checkpoints")
|
||||
|
||||
if self.checkpoint_mgr:
|
||||
self.checkpoint_mgr.add_metric("total_checkpoints", len(self.checkpoints))
|
||||
self.checkpoint_mgr.add_metric("checkpoints_by_regulation", by_reg)
|
||||
|
||||
expected = EXPECTED_VALUES.get("extraction", {})
|
||||
self.checkpoint_mgr.validate(
|
||||
"total_checkpoints",
|
||||
expected=expected.get("total_checkpoints", 3500),
|
||||
actual=len(self.checkpoints),
|
||||
min_value=expected.get("min_checkpoints", 3000)
|
||||
)
|
||||
|
||||
self.checkpoint_mgr.complete_checkpoint(success=True)
|
||||
|
||||
return len(self.checkpoints)
|
||||
|
||||
except Exception as e:
|
||||
if self.checkpoint_mgr:
|
||||
self.checkpoint_mgr.fail_checkpoint(str(e))
|
||||
raise
|
||||
|
||||
async def run_control_generation_phase(self) -> int:
|
||||
"""Phase 3: Generate controls from checkpoints."""
|
||||
logger.info("\n" + "=" * 60)
|
||||
logger.info("PHASE 3: CONTROL GENERATION")
|
||||
logger.info("=" * 60)
|
||||
|
||||
if self.checkpoint_mgr:
|
||||
self.checkpoint_mgr.start_checkpoint("controls", "Control Generation")
|
||||
|
||||
try:
|
||||
# Group checkpoints by regulation
|
||||
by_regulation: Dict[str, List[Checkpoint]] = {}
|
||||
for cp in self.checkpoints:
|
||||
reg = cp.regulation_code
|
||||
if reg not in by_regulation:
|
||||
by_regulation[reg] = []
|
||||
by_regulation[reg].append(cp)
|
||||
|
||||
# Generate controls per regulation (group every 3-5 checkpoints)
|
||||
for regulation, checkpoints in by_regulation.items():
|
||||
logger.info(f"Generating controls for {regulation} ({len(checkpoints)} checkpoints)...")
|
||||
|
||||
batch_size = 4
|
||||
for i in range(0, len(checkpoints), batch_size):
|
||||
batch = checkpoints[i:i + batch_size]
|
||||
control = generate_control_for_checkpoints(batch, self.stats.get("by_domain", {}))
|
||||
|
||||
if control:
|
||||
self.controls.append(control)
|
||||
self.stats["by_domain"][control.domain] = self.stats["by_domain"].get(control.domain, 0) + 1
|
||||
|
||||
self.stats["controls_created"] = len(self.controls)
|
||||
logger.info(f"\nTotal controls created: {len(self.controls)}")
|
||||
|
||||
for domain, count in sorted(self.stats["by_domain"].items()):
|
||||
logger.info(f" {domain}: {count} controls")
|
||||
|
||||
if self.checkpoint_mgr:
|
||||
self.checkpoint_mgr.add_metric("total_controls", len(self.controls))
|
||||
self.checkpoint_mgr.add_metric("controls_by_domain", dict(self.stats["by_domain"]))
|
||||
|
||||
expected = EXPECTED_VALUES.get("controls", {})
|
||||
self.checkpoint_mgr.validate(
|
||||
"total_controls",
|
||||
expected=expected.get("total_controls", 900),
|
||||
actual=len(self.controls),
|
||||
min_value=expected.get("min_controls", 800)
|
||||
)
|
||||
|
||||
self.checkpoint_mgr.complete_checkpoint(success=True)
|
||||
|
||||
return len(self.controls)
|
||||
|
||||
except Exception as e:
|
||||
if self.checkpoint_mgr:
|
||||
self.checkpoint_mgr.fail_checkpoint(str(e))
|
||||
raise
|
||||
|
||||
async def run_measure_generation_phase(self) -> int:
|
||||
"""Phase 4: Generate measures for controls."""
|
||||
logger.info("\n" + "=" * 60)
|
||||
logger.info("PHASE 4: MEASURE GENERATION")
|
||||
logger.info("=" * 60)
|
||||
|
||||
if self.checkpoint_mgr:
|
||||
self.checkpoint_mgr.start_checkpoint("measures", "Measure Generation")
|
||||
|
||||
try:
|
||||
for control in self.controls:
|
||||
measure = generate_measure_for_control(control)
|
||||
self.measures.append(measure)
|
||||
|
||||
self.stats["measures_defined"] = len(self.measures)
|
||||
logger.info(f"\nTotal measures defined: {len(self.measures)}")
|
||||
|
||||
if self.checkpoint_mgr:
|
||||
self.checkpoint_mgr.add_metric("total_measures", len(self.measures))
|
||||
|
||||
expected = EXPECTED_VALUES.get("measures", {})
|
||||
self.checkpoint_mgr.validate(
|
||||
"total_measures",
|
||||
expected=expected.get("total_measures", 900),
|
||||
actual=len(self.measures),
|
||||
min_value=expected.get("min_measures", 800)
|
||||
)
|
||||
|
||||
self.checkpoint_mgr.complete_checkpoint(success=True)
|
||||
|
||||
return len(self.measures)
|
||||
|
||||
except Exception as e:
|
||||
if self.checkpoint_mgr:
|
||||
self.checkpoint_mgr.fail_checkpoint(str(e))
|
||||
raise
|
||||
|
||||
def save_results(self, output_dir: str = "/tmp/compliance_output"):
|
||||
"""Save results to JSON files."""
|
||||
logger.info("\n" + "=" * 60)
|
||||
logger.info("SAVING RESULTS")
|
||||
logger.info("=" * 60)
|
||||
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
checkpoints_file = os.path.join(output_dir, "checkpoints.json")
|
||||
with open(checkpoints_file, "w") as f:
|
||||
json.dump([asdict(cp) for cp in self.checkpoints], f, indent=2, ensure_ascii=False)
|
||||
logger.info(f"Saved {len(self.checkpoints)} checkpoints to {checkpoints_file}")
|
||||
|
||||
controls_file = os.path.join(output_dir, "controls.json")
|
||||
with open(controls_file, "w") as f:
|
||||
json.dump([asdict(c) for c in self.controls], f, indent=2, ensure_ascii=False)
|
||||
logger.info(f"Saved {len(self.controls)} controls to {controls_file}")
|
||||
|
||||
measures_file = os.path.join(output_dir, "measures.json")
|
||||
with open(measures_file, "w") as f:
|
||||
json.dump([asdict(m) for m in self.measures], f, indent=2, ensure_ascii=False)
|
||||
logger.info(f"Saved {len(self.measures)} measures to {measures_file}")
|
||||
|
||||
stats_file = os.path.join(output_dir, "statistics.json")
|
||||
self.stats["generated_at"] = datetime.now().isoformat()
|
||||
with open(stats_file, "w") as f:
|
||||
json.dump(self.stats, f, indent=2, ensure_ascii=False)
|
||||
logger.info(f"Saved statistics to {stats_file}")
|
||||
|
||||
async def run_full_pipeline(self, force_reindex: bool = False, skip_ingestion: bool = False):
|
||||
"""Run the complete pipeline.
|
||||
|
||||
Args:
|
||||
force_reindex: If True, re-ingest all documents even if they exist
|
||||
skip_ingestion: If True, skip ingestion phase entirely (use existing chunks)
|
||||
"""
|
||||
start_time = time.time()
|
||||
|
||||
logger.info("=" * 60)
|
||||
logger.info("FULL COMPLIANCE PIPELINE (INCREMENTAL)")
|
||||
logger.info(f"Started at: {datetime.now().isoformat()}")
|
||||
logger.info(f"Force reindex: {force_reindex}")
|
||||
logger.info(f"Skip ingestion: {skip_ingestion}")
|
||||
if self.checkpoint_mgr:
|
||||
logger.info(f"Pipeline ID: {self.checkpoint_mgr.pipeline_id}")
|
||||
logger.info("=" * 60)
|
||||
|
||||
try:
|
||||
if skip_ingestion:
|
||||
logger.info("Skipping ingestion phase as requested...")
|
||||
try:
|
||||
collection_info = self.qdrant.get_collection(LEGAL_CORPUS_COLLECTION)
|
||||
self.stats["chunks_processed"] = collection_info.points_count
|
||||
except Exception:
|
||||
self.stats["chunks_processed"] = 0
|
||||
else:
|
||||
await self.run_ingestion_phase(force_reindex=force_reindex)
|
||||
|
||||
await self.run_extraction_phase()
|
||||
await self.run_control_generation_phase()
|
||||
await self.run_measure_generation_phase()
|
||||
self.save_results()
|
||||
|
||||
elapsed = time.time() - start_time
|
||||
logger.info("\n" + "=" * 60)
|
||||
logger.info("PIPELINE COMPLETE")
|
||||
logger.info("=" * 60)
|
||||
logger.info(f"Duration: {elapsed:.1f} seconds")
|
||||
logger.info(f"Chunks processed: {self.stats['chunks_processed']}")
|
||||
logger.info(f"Checkpoints extracted: {self.stats['checkpoints_extracted']}")
|
||||
logger.info(f"Controls created: {self.stats['controls_created']}")
|
||||
logger.info(f"Measures defined: {self.stats['measures_defined']}")
|
||||
logger.info(f"\nResults saved to: /tmp/compliance_output/")
|
||||
logger.info("Checkpoint status: /tmp/pipeline_checkpoints.json")
|
||||
logger.info("=" * 60)
|
||||
|
||||
if self.checkpoint_mgr:
|
||||
self.checkpoint_mgr.complete_pipeline({
|
||||
"duration_seconds": elapsed,
|
||||
"chunks_processed": self.stats['chunks_processed'],
|
||||
"checkpoints_extracted": self.stats['checkpoints_extracted'],
|
||||
"controls_created": self.stats['controls_created'],
|
||||
"measures_defined": self.stats['measures_defined'],
|
||||
"by_regulation": self.stats['by_regulation'],
|
||||
"by_domain": self.stats['by_domain'],
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Pipeline failed: {e}")
|
||||
if self.checkpoint_mgr:
|
||||
self.checkpoint_mgr.state.status = "failed"
|
||||
self.checkpoint_mgr._save()
|
||||
raise
|
||||
@@ -1,7 +1,10 @@
|
||||
"""
|
||||
DSFA RAG API Endpoints.
|
||||
DSFA RAG API Endpoints — Barrel Re-export.
|
||||
|
||||
Provides REST API for searching DSFA corpus with full source attribution.
|
||||
Split into submodules:
|
||||
- dsfa_rag_models.py — Pydantic request/response models
|
||||
- dsfa_rag_embedding.py — Embedding service integration & text extraction
|
||||
- dsfa_rag_routes.py — Route handlers (search, sources, ingest, stats)
|
||||
|
||||
Endpoints:
|
||||
- GET /api/v1/dsfa-rag/search - Semantic search with attribution
|
||||
@@ -11,705 +14,54 @@ Endpoints:
|
||||
- GET /api/v1/dsfa-rag/stats - Get corpus statistics
|
||||
"""
|
||||
|
||||
import os
|
||||
import uuid
|
||||
import logging
|
||||
from typing import List, Optional
|
||||
from dataclasses import dataclass, asdict
|
||||
|
||||
import httpx
|
||||
from fastapi import APIRouter, HTTPException, Query, Depends
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Embedding service configuration
|
||||
EMBEDDING_SERVICE_URL = os.getenv("EMBEDDING_SERVICE_URL", "http://172.18.0.13:8087")
|
||||
|
||||
# Import from ingestion module
|
||||
from dsfa_corpus_ingestion import (
|
||||
DSFACorpusStore,
|
||||
DSFAQdrantService,
|
||||
DSFASearchResult,
|
||||
LICENSE_REGISTRY,
|
||||
DSFA_SOURCES,
|
||||
generate_attribution_notice,
|
||||
get_license_label,
|
||||
DSFA_COLLECTION,
|
||||
chunk_document
|
||||
# Models
|
||||
from dsfa_rag_models import (
|
||||
DSFASourceResponse,
|
||||
DSFAChunkResponse,
|
||||
DSFASearchResultResponse,
|
||||
DSFASearchResponse,
|
||||
DSFASourceStatsResponse,
|
||||
DSFACorpusStatsResponse,
|
||||
IngestRequest,
|
||||
IngestResponse,
|
||||
LicenseInfo,
|
||||
)
|
||||
|
||||
router = APIRouter(prefix="/api/v1/dsfa-rag", tags=["DSFA RAG"])
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Pydantic Models
|
||||
# =============================================================================
|
||||
|
||||
class DSFASourceResponse(BaseModel):
|
||||
"""Response model for DSFA source."""
|
||||
id: str
|
||||
source_code: str
|
||||
name: str
|
||||
full_name: Optional[str] = None
|
||||
organization: Optional[str] = None
|
||||
source_url: Optional[str] = None
|
||||
license_code: str
|
||||
license_name: str
|
||||
license_url: Optional[str] = None
|
||||
attribution_required: bool
|
||||
attribution_text: str
|
||||
document_type: Optional[str] = None
|
||||
language: str = "de"
|
||||
|
||||
|
||||
class DSFAChunkResponse(BaseModel):
|
||||
"""Response model for a single chunk with attribution."""
|
||||
chunk_id: str
|
||||
content: str
|
||||
section_title: Optional[str] = None
|
||||
page_number: Optional[int] = None
|
||||
category: Optional[str] = None
|
||||
|
||||
# Document info
|
||||
document_id: str
|
||||
document_title: Optional[str] = None
|
||||
|
||||
# Attribution (always included)
|
||||
source_id: str
|
||||
source_code: str
|
||||
source_name: str
|
||||
attribution_text: str
|
||||
license_code: str
|
||||
license_name: str
|
||||
license_url: Optional[str] = None
|
||||
attribution_required: bool
|
||||
source_url: Optional[str] = None
|
||||
document_type: Optional[str] = None
|
||||
|
||||
|
||||
class DSFASearchResultResponse(BaseModel):
|
||||
"""Response model for search result."""
|
||||
chunk_id: str
|
||||
content: str
|
||||
score: float
|
||||
|
||||
# Attribution
|
||||
source_code: str
|
||||
source_name: str
|
||||
attribution_text: str
|
||||
license_code: str
|
||||
license_name: str
|
||||
license_url: Optional[str] = None
|
||||
attribution_required: bool
|
||||
source_url: Optional[str] = None
|
||||
|
||||
# Metadata
|
||||
document_type: Optional[str] = None
|
||||
category: Optional[str] = None
|
||||
section_title: Optional[str] = None
|
||||
page_number: Optional[int] = None
|
||||
|
||||
|
||||
class DSFASearchResponse(BaseModel):
|
||||
"""Response model for search endpoint."""
|
||||
query: str
|
||||
results: List[DSFASearchResultResponse]
|
||||
total_results: int
|
||||
|
||||
# Aggregated licenses for footer
|
||||
licenses_used: List[str]
|
||||
attribution_notice: str
|
||||
|
||||
|
||||
class DSFASourceStatsResponse(BaseModel):
|
||||
"""Response model for source statistics."""
|
||||
source_id: str
|
||||
source_code: str
|
||||
name: str
|
||||
organization: Optional[str] = None
|
||||
license_code: str
|
||||
document_type: Optional[str] = None
|
||||
document_count: int
|
||||
chunk_count: int
|
||||
last_indexed_at: Optional[str] = None
|
||||
|
||||
|
||||
class DSFACorpusStatsResponse(BaseModel):
|
||||
"""Response model for corpus statistics."""
|
||||
sources: List[DSFASourceStatsResponse]
|
||||
total_sources: int
|
||||
total_documents: int
|
||||
total_chunks: int
|
||||
qdrant_collection: str
|
||||
qdrant_points_count: int
|
||||
qdrant_status: str
|
||||
|
||||
|
||||
class IngestRequest(BaseModel):
|
||||
"""Request model for ingestion."""
|
||||
document_url: Optional[str] = None
|
||||
document_text: Optional[str] = None
|
||||
title: Optional[str] = None
|
||||
|
||||
|
||||
class IngestResponse(BaseModel):
|
||||
"""Response model for ingestion."""
|
||||
source_code: str
|
||||
document_id: Optional[str] = None
|
||||
chunks_created: int
|
||||
message: str
|
||||
|
||||
|
||||
class LicenseInfo(BaseModel):
|
||||
"""License information."""
|
||||
code: str
|
||||
name: str
|
||||
url: Optional[str] = None
|
||||
attribution_required: bool
|
||||
modification_allowed: bool
|
||||
commercial_use: bool
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Dependency Injection
|
||||
# =============================================================================
|
||||
|
||||
# Database pool (will be set from main.py)
|
||||
_db_pool = None
|
||||
|
||||
|
||||
def set_db_pool(pool):
|
||||
"""Set the database pool for API endpoints."""
|
||||
global _db_pool
|
||||
_db_pool = pool
|
||||
|
||||
|
||||
async def get_store() -> DSFACorpusStore:
|
||||
"""Get DSFA corpus store."""
|
||||
if _db_pool is None:
|
||||
raise HTTPException(status_code=503, detail="Database not initialized")
|
||||
return DSFACorpusStore(_db_pool)
|
||||
|
||||
|
||||
async def get_qdrant() -> DSFAQdrantService:
|
||||
"""Get Qdrant service."""
|
||||
return DSFAQdrantService()
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Embedding Service Integration
|
||||
# =============================================================================
|
||||
|
||||
async def get_embedding(text: str) -> List[float]:
|
||||
"""
|
||||
Get embedding for text using the embedding-service.
|
||||
|
||||
Uses BGE-M3 model which produces 1024-dimensional vectors.
|
||||
"""
|
||||
async with httpx.AsyncClient(timeout=60.0) as client:
|
||||
try:
|
||||
response = await client.post(
|
||||
f"{EMBEDDING_SERVICE_URL}/embed-single",
|
||||
json={"text": text}
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
return data.get("embedding", [])
|
||||
except httpx.HTTPError as e:
|
||||
logger.error(f"Embedding service error: {e}")
|
||||
# Fallback to hash-based pseudo-embedding for development
|
||||
return _generate_fallback_embedding(text)
|
||||
|
||||
|
||||
async def get_embeddings_batch(texts: List[str]) -> List[List[float]]:
|
||||
"""
|
||||
Get embeddings for multiple texts in batch.
|
||||
"""
|
||||
async with httpx.AsyncClient(timeout=120.0) as client:
|
||||
try:
|
||||
response = await client.post(
|
||||
f"{EMBEDDING_SERVICE_URL}/embed",
|
||||
json={"texts": texts}
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
return data.get("embeddings", [])
|
||||
except httpx.HTTPError as e:
|
||||
logger.error(f"Embedding service batch error: {e}")
|
||||
# Fallback
|
||||
return [_generate_fallback_embedding(t) for t in texts]
|
||||
|
||||
|
||||
async def extract_text_from_url(url: str) -> str:
|
||||
"""
|
||||
Extract text from a document URL (PDF, HTML, etc.).
|
||||
"""
|
||||
async with httpx.AsyncClient(timeout=120.0) as client:
|
||||
try:
|
||||
# First try to use the embedding-service's extract-pdf endpoint
|
||||
response = await client.post(
|
||||
f"{EMBEDDING_SERVICE_URL}/extract-pdf",
|
||||
json={"url": url}
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
return data.get("text", "")
|
||||
except httpx.HTTPError as e:
|
||||
logger.error(f"PDF extraction error for {url}: {e}")
|
||||
# Fallback: try to fetch HTML content directly
|
||||
try:
|
||||
response = await client.get(url, follow_redirects=True)
|
||||
response.raise_for_status()
|
||||
content_type = response.headers.get("content-type", "")
|
||||
if "html" in content_type:
|
||||
# Simple HTML text extraction
|
||||
import re
|
||||
html = response.text
|
||||
# Remove scripts and styles
|
||||
html = re.sub(r'<script[^>]*>.*?</script>', '', html, flags=re.DOTALL | re.IGNORECASE)
|
||||
html = re.sub(r'<style[^>]*>.*?</style>', '', html, flags=re.DOTALL | re.IGNORECASE)
|
||||
# Remove tags
|
||||
text = re.sub(r'<[^>]+>', ' ', html)
|
||||
# Clean whitespace
|
||||
text = re.sub(r'\s+', ' ', text).strip()
|
||||
return text
|
||||
else:
|
||||
return ""
|
||||
except Exception as fetch_err:
|
||||
logger.error(f"Fallback fetch error for {url}: {fetch_err}")
|
||||
return ""
|
||||
|
||||
|
||||
def _generate_fallback_embedding(text: str) -> List[float]:
|
||||
"""
|
||||
Generate deterministic pseudo-embedding from text hash.
|
||||
Used as fallback when embedding service is unavailable.
|
||||
"""
|
||||
import hashlib
|
||||
import struct
|
||||
|
||||
hash_bytes = hashlib.sha256(text.encode()).digest()
|
||||
embedding = []
|
||||
for i in range(0, min(len(hash_bytes), 128), 4):
|
||||
val = struct.unpack('f', hash_bytes[i:i+4])[0]
|
||||
embedding.append(val % 1.0)
|
||||
|
||||
# Pad to 1024 dimensions
|
||||
while len(embedding) < 1024:
|
||||
embedding.extend(embedding[:min(len(embedding), 1024 - len(embedding))])
|
||||
|
||||
return embedding[:1024]
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# API Endpoints
|
||||
# =============================================================================
|
||||
|
||||
@router.get("/search", response_model=DSFASearchResponse)
|
||||
async def search_dsfa_corpus(
|
||||
query: str = Query(..., min_length=3, description="Search query"),
|
||||
source_codes: Optional[List[str]] = Query(None, description="Filter by source codes"),
|
||||
document_types: Optional[List[str]] = Query(None, description="Filter by document types (guideline, checklist, regulation)"),
|
||||
categories: Optional[List[str]] = Query(None, description="Filter by categories (threshold_analysis, risk_assessment, mitigation)"),
|
||||
limit: int = Query(10, ge=1, le=50, description="Maximum results"),
|
||||
include_attribution: bool = Query(True, description="Include attribution in results"),
|
||||
store: DSFACorpusStore = Depends(get_store),
|
||||
qdrant: DSFAQdrantService = Depends(get_qdrant)
|
||||
):
|
||||
"""
|
||||
Search DSFA corpus with full attribution.
|
||||
|
||||
Returns matching chunks with source/license information for compliance.
|
||||
"""
|
||||
# Get query embedding
|
||||
query_embedding = await get_embedding(query)
|
||||
|
||||
# Search Qdrant
|
||||
raw_results = await qdrant.search(
|
||||
query_embedding=query_embedding,
|
||||
source_codes=source_codes,
|
||||
document_types=document_types,
|
||||
categories=categories,
|
||||
limit=limit
|
||||
)
|
||||
|
||||
# Transform results
|
||||
results = []
|
||||
licenses_used = set()
|
||||
|
||||
for r in raw_results:
|
||||
license_code = r.get("license_code", "")
|
||||
license_info = LICENSE_REGISTRY.get(license_code, {})
|
||||
|
||||
result = DSFASearchResultResponse(
|
||||
chunk_id=r.get("chunk_id", ""),
|
||||
content=r.get("content", ""),
|
||||
score=r.get("score", 0.0),
|
||||
source_code=r.get("source_code", ""),
|
||||
source_name=r.get("source_name", ""),
|
||||
attribution_text=r.get("attribution_text", ""),
|
||||
license_code=license_code,
|
||||
license_name=license_info.get("name", license_code),
|
||||
license_url=license_info.get("url"),
|
||||
attribution_required=r.get("attribution_required", True),
|
||||
source_url=r.get("source_url"),
|
||||
document_type=r.get("document_type"),
|
||||
category=r.get("category"),
|
||||
section_title=r.get("section_title"),
|
||||
page_number=r.get("page_number")
|
||||
)
|
||||
results.append(result)
|
||||
licenses_used.add(license_code)
|
||||
|
||||
# Generate attribution notice
|
||||
search_results = [
|
||||
DSFASearchResult(
|
||||
chunk_id=r.chunk_id,
|
||||
content=r.content,
|
||||
score=r.score,
|
||||
source_code=r.source_code,
|
||||
source_name=r.source_name,
|
||||
attribution_text=r.attribution_text,
|
||||
license_code=r.license_code,
|
||||
license_url=r.license_url,
|
||||
attribution_required=r.attribution_required,
|
||||
source_url=r.source_url,
|
||||
document_type=r.document_type or "",
|
||||
category=r.category or "",
|
||||
section_title=r.section_title,
|
||||
page_number=r.page_number
|
||||
)
|
||||
for r in results
|
||||
]
|
||||
attribution_notice = generate_attribution_notice(search_results) if include_attribution else ""
|
||||
|
||||
return DSFASearchResponse(
|
||||
query=query,
|
||||
results=results,
|
||||
total_results=len(results),
|
||||
licenses_used=list(licenses_used),
|
||||
attribution_notice=attribution_notice
|
||||
)
|
||||
|
||||
|
||||
@router.get("/sources", response_model=List[DSFASourceResponse])
|
||||
async def list_dsfa_sources(
|
||||
document_type: Optional[str] = Query(None, description="Filter by document type"),
|
||||
license_code: Optional[str] = Query(None, description="Filter by license"),
|
||||
store: DSFACorpusStore = Depends(get_store)
|
||||
):
|
||||
"""List all registered DSFA sources with license info."""
|
||||
sources = await store.list_sources()
|
||||
|
||||
result = []
|
||||
for s in sources:
|
||||
# Apply filters
|
||||
if document_type and s.get("document_type") != document_type:
|
||||
continue
|
||||
if license_code and s.get("license_code") != license_code:
|
||||
continue
|
||||
|
||||
license_info = LICENSE_REGISTRY.get(s.get("license_code", ""), {})
|
||||
|
||||
result.append(DSFASourceResponse(
|
||||
id=str(s["id"]),
|
||||
source_code=s["source_code"],
|
||||
name=s["name"],
|
||||
full_name=s.get("full_name"),
|
||||
organization=s.get("organization"),
|
||||
source_url=s.get("source_url"),
|
||||
license_code=s.get("license_code", ""),
|
||||
license_name=license_info.get("name", s.get("license_code", "")),
|
||||
license_url=license_info.get("url"),
|
||||
attribution_required=s.get("attribution_required", True),
|
||||
attribution_text=s.get("attribution_text", ""),
|
||||
document_type=s.get("document_type"),
|
||||
language=s.get("language", "de")
|
||||
))
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@router.get("/sources/available")
|
||||
async def list_available_sources():
|
||||
"""List all available source definitions (from DSFA_SOURCES constant)."""
|
||||
return [
|
||||
{
|
||||
"source_code": s["source_code"],
|
||||
"name": s["name"],
|
||||
"organization": s.get("organization"),
|
||||
"license_code": s["license_code"],
|
||||
"document_type": s.get("document_type")
|
||||
}
|
||||
for s in DSFA_SOURCES
|
||||
]
|
||||
|
||||
|
||||
@router.get("/sources/{source_code}", response_model=DSFASourceResponse)
|
||||
async def get_dsfa_source(
|
||||
source_code: str,
|
||||
store: DSFACorpusStore = Depends(get_store)
|
||||
):
|
||||
"""Get details for a specific source."""
|
||||
source = await store.get_source_by_code(source_code)
|
||||
|
||||
if not source:
|
||||
raise HTTPException(status_code=404, detail=f"Source not found: {source_code}")
|
||||
|
||||
license_info = LICENSE_REGISTRY.get(source.get("license_code", ""), {})
|
||||
|
||||
return DSFASourceResponse(
|
||||
id=str(source["id"]),
|
||||
source_code=source["source_code"],
|
||||
name=source["name"],
|
||||
full_name=source.get("full_name"),
|
||||
organization=source.get("organization"),
|
||||
source_url=source.get("source_url"),
|
||||
license_code=source.get("license_code", ""),
|
||||
license_name=license_info.get("name", source.get("license_code", "")),
|
||||
license_url=license_info.get("url"),
|
||||
attribution_required=source.get("attribution_required", True),
|
||||
attribution_text=source.get("attribution_text", ""),
|
||||
document_type=source.get("document_type"),
|
||||
language=source.get("language", "de")
|
||||
)
|
||||
|
||||
|
||||
@router.post("/sources/{source_code}/ingest", response_model=IngestResponse)
|
||||
async def ingest_dsfa_source(
|
||||
source_code: str,
|
||||
request: IngestRequest,
|
||||
store: DSFACorpusStore = Depends(get_store),
|
||||
qdrant: DSFAQdrantService = Depends(get_qdrant)
|
||||
):
|
||||
"""
|
||||
Trigger ingestion for a specific source.
|
||||
|
||||
Can provide document via URL or direct text.
|
||||
"""
|
||||
# Get source
|
||||
source = await store.get_source_by_code(source_code)
|
||||
if not source:
|
||||
raise HTTPException(status_code=404, detail=f"Source not found: {source_code}")
|
||||
|
||||
# Need either URL or text
|
||||
if not request.document_text and not request.document_url:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Either document_text or document_url must be provided"
|
||||
)
|
||||
|
||||
# Ensure Qdrant collection exists
|
||||
await qdrant.ensure_collection()
|
||||
|
||||
# Get text content
|
||||
text_content = request.document_text
|
||||
if request.document_url and not text_content:
|
||||
# Download and extract text from URL
|
||||
logger.info(f"Extracting text from URL: {request.document_url}")
|
||||
text_content = await extract_text_from_url(request.document_url)
|
||||
if not text_content:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Could not extract text from URL: {request.document_url}"
|
||||
)
|
||||
|
||||
if not text_content or len(text_content.strip()) < 50:
|
||||
raise HTTPException(status_code=400, detail="Document text too short (min 50 chars)")
|
||||
|
||||
# Create document record
|
||||
doc_title = request.title or f"Document for {source_code}"
|
||||
document_id = await store.create_document(
|
||||
source_id=str(source["id"]),
|
||||
title=doc_title,
|
||||
file_type="text",
|
||||
metadata={"ingested_via": "api", "source_code": source_code}
|
||||
)
|
||||
|
||||
# Chunk the document
|
||||
chunks = chunk_document(text_content, source_code)
|
||||
|
||||
if not chunks:
|
||||
return IngestResponse(
|
||||
source_code=source_code,
|
||||
document_id=document_id,
|
||||
chunks_created=0,
|
||||
message="Document created but no chunks generated"
|
||||
)
|
||||
|
||||
# Generate embeddings in batch for efficiency
|
||||
chunk_texts = [chunk["content"] for chunk in chunks]
|
||||
logger.info(f"Generating embeddings for {len(chunk_texts)} chunks...")
|
||||
embeddings = await get_embeddings_batch(chunk_texts)
|
||||
|
||||
# Create chunk records in PostgreSQL and prepare for Qdrant
|
||||
chunk_records = []
|
||||
for i, (chunk, embedding) in enumerate(zip(chunks, embeddings)):
|
||||
# Create chunk in PostgreSQL
|
||||
chunk_id = await store.create_chunk(
|
||||
document_id=document_id,
|
||||
source_id=str(source["id"]),
|
||||
content=chunk["content"],
|
||||
chunk_index=i,
|
||||
section_title=chunk.get("section_title"),
|
||||
page_number=chunk.get("page_number"),
|
||||
category=chunk.get("category")
|
||||
)
|
||||
|
||||
chunk_records.append({
|
||||
"chunk_id": chunk_id,
|
||||
"document_id": document_id,
|
||||
"source_id": str(source["id"]),
|
||||
"content": chunk["content"],
|
||||
"section_title": chunk.get("section_title"),
|
||||
"source_code": source_code,
|
||||
"source_name": source["name"],
|
||||
"attribution_text": source["attribution_text"],
|
||||
"license_code": source["license_code"],
|
||||
"attribution_required": source.get("attribution_required", True),
|
||||
"document_type": source.get("document_type", ""),
|
||||
"category": chunk.get("category", ""),
|
||||
"language": source.get("language", "de"),
|
||||
"page_number": chunk.get("page_number")
|
||||
})
|
||||
|
||||
# Index in Qdrant
|
||||
indexed_count = await qdrant.index_chunks(chunk_records, embeddings)
|
||||
|
||||
# Update document record
|
||||
await store.update_document_indexed(document_id, len(chunks))
|
||||
|
||||
return IngestResponse(
|
||||
source_code=source_code,
|
||||
document_id=document_id,
|
||||
chunks_created=indexed_count,
|
||||
message=f"Successfully ingested {indexed_count} chunks from document"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/chunks/{chunk_id}", response_model=DSFAChunkResponse)
|
||||
async def get_chunk_with_attribution(
|
||||
chunk_id: str,
|
||||
store: DSFACorpusStore = Depends(get_store)
|
||||
):
|
||||
"""Get single chunk with full source attribution."""
|
||||
chunk = await store.get_chunk_with_attribution(chunk_id)
|
||||
|
||||
if not chunk:
|
||||
raise HTTPException(status_code=404, detail=f"Chunk not found: {chunk_id}")
|
||||
|
||||
license_info = LICENSE_REGISTRY.get(chunk.get("license_code", ""), {})
|
||||
|
||||
return DSFAChunkResponse(
|
||||
chunk_id=str(chunk["chunk_id"]),
|
||||
content=chunk.get("content", ""),
|
||||
section_title=chunk.get("section_title"),
|
||||
page_number=chunk.get("page_number"),
|
||||
category=chunk.get("category"),
|
||||
document_id=str(chunk.get("document_id", "")),
|
||||
document_title=chunk.get("document_title"),
|
||||
source_id=str(chunk.get("source_id", "")),
|
||||
source_code=chunk.get("source_code", ""),
|
||||
source_name=chunk.get("source_name", ""),
|
||||
attribution_text=chunk.get("attribution_text", ""),
|
||||
license_code=chunk.get("license_code", ""),
|
||||
license_name=license_info.get("name", chunk.get("license_code", "")),
|
||||
license_url=license_info.get("url"),
|
||||
attribution_required=chunk.get("attribution_required", True),
|
||||
source_url=chunk.get("source_url"),
|
||||
document_type=chunk.get("document_type")
|
||||
)
|
||||
|
||||
|
||||
@router.get("/stats", response_model=DSFACorpusStatsResponse)
|
||||
async def get_corpus_stats(
|
||||
store: DSFACorpusStore = Depends(get_store),
|
||||
qdrant: DSFAQdrantService = Depends(get_qdrant)
|
||||
):
|
||||
"""Get corpus statistics for dashboard."""
|
||||
# Get PostgreSQL stats
|
||||
source_stats = await store.get_source_stats()
|
||||
|
||||
total_docs = 0
|
||||
total_chunks = 0
|
||||
stats_response = []
|
||||
|
||||
for s in source_stats:
|
||||
doc_count = s.get("document_count", 0) or 0
|
||||
chunk_count = s.get("chunk_count", 0) or 0
|
||||
total_docs += doc_count
|
||||
total_chunks += chunk_count
|
||||
|
||||
last_indexed = s.get("last_indexed_at")
|
||||
|
||||
stats_response.append(DSFASourceStatsResponse(
|
||||
source_id=str(s.get("source_id", "")),
|
||||
source_code=s.get("source_code", ""),
|
||||
name=s.get("name", ""),
|
||||
organization=s.get("organization"),
|
||||
license_code=s.get("license_code", ""),
|
||||
document_type=s.get("document_type"),
|
||||
document_count=doc_count,
|
||||
chunk_count=chunk_count,
|
||||
last_indexed_at=last_indexed.isoformat() if last_indexed else None
|
||||
))
|
||||
|
||||
# Get Qdrant stats
|
||||
qdrant_stats = await qdrant.get_stats()
|
||||
|
||||
return DSFACorpusStatsResponse(
|
||||
sources=stats_response,
|
||||
total_sources=len(source_stats),
|
||||
total_documents=total_docs,
|
||||
total_chunks=total_chunks,
|
||||
qdrant_collection=DSFA_COLLECTION,
|
||||
qdrant_points_count=qdrant_stats.get("points_count", 0),
|
||||
qdrant_status=qdrant_stats.get("status", "unknown")
|
||||
)
|
||||
|
||||
|
||||
@router.get("/licenses")
|
||||
async def list_licenses():
|
||||
"""List all supported licenses with their terms."""
|
||||
return [
|
||||
LicenseInfo(
|
||||
code=code,
|
||||
name=info.get("name", code),
|
||||
url=info.get("url"),
|
||||
attribution_required=info.get("attribution_required", True),
|
||||
modification_allowed=info.get("modification_allowed", True),
|
||||
commercial_use=info.get("commercial_use", True)
|
||||
)
|
||||
for code, info in LICENSE_REGISTRY.items()
|
||||
]
|
||||
|
||||
|
||||
@router.post("/init")
|
||||
async def initialize_dsfa_corpus(
|
||||
store: DSFACorpusStore = Depends(get_store),
|
||||
qdrant: DSFAQdrantService = Depends(get_qdrant)
|
||||
):
|
||||
"""
|
||||
Initialize DSFA corpus.
|
||||
|
||||
- Creates Qdrant collection
|
||||
- Registers all predefined sources
|
||||
"""
|
||||
# Ensure Qdrant collection exists
|
||||
qdrant_ok = await qdrant.ensure_collection()
|
||||
|
||||
# Register all sources
|
||||
registered = 0
|
||||
for source in DSFA_SOURCES:
|
||||
try:
|
||||
await store.register_source(source)
|
||||
registered += 1
|
||||
except Exception as e:
|
||||
print(f"Error registering source {source['source_code']}: {e}")
|
||||
|
||||
return {
|
||||
"qdrant_collection_created": qdrant_ok,
|
||||
"sources_registered": registered,
|
||||
"total_sources": len(DSFA_SOURCES)
|
||||
}
|
||||
# Embedding utilities
|
||||
from dsfa_rag_embedding import (
|
||||
get_embedding,
|
||||
get_embeddings_batch,
|
||||
extract_text_from_url,
|
||||
EMBEDDING_SERVICE_URL,
|
||||
)
|
||||
|
||||
# Routes (router + set_db_pool)
|
||||
from dsfa_rag_routes import (
|
||||
router,
|
||||
set_db_pool,
|
||||
get_store,
|
||||
get_qdrant,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# Router
|
||||
"router",
|
||||
"set_db_pool",
|
||||
"get_store",
|
||||
"get_qdrant",
|
||||
# Models
|
||||
"DSFASourceResponse",
|
||||
"DSFAChunkResponse",
|
||||
"DSFASearchResultResponse",
|
||||
"DSFASearchResponse",
|
||||
"DSFASourceStatsResponse",
|
||||
"DSFACorpusStatsResponse",
|
||||
"IngestRequest",
|
||||
"IngestResponse",
|
||||
"LicenseInfo",
|
||||
# Embedding
|
||||
"get_embedding",
|
||||
"get_embeddings_batch",
|
||||
"extract_text_from_url",
|
||||
"EMBEDDING_SERVICE_URL",
|
||||
]
|
||||
|
||||
@@ -0,0 +1,116 @@
|
||||
"""
|
||||
DSFA RAG Embedding Service Integration.
|
||||
|
||||
Handles embedding generation, text extraction, and fallback logic.
|
||||
"""
|
||||
|
||||
import os
|
||||
import hashlib
|
||||
import logging
|
||||
import struct
|
||||
import re
|
||||
from typing import List
|
||||
|
||||
import httpx
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Embedding service configuration
|
||||
EMBEDDING_SERVICE_URL = os.getenv("EMBEDDING_SERVICE_URL", "http://172.18.0.13:8087")
|
||||
|
||||
|
||||
async def get_embedding(text: str) -> List[float]:
|
||||
"""
|
||||
Get embedding for text using the embedding-service.
|
||||
|
||||
Uses BGE-M3 model which produces 1024-dimensional vectors.
|
||||
"""
|
||||
async with httpx.AsyncClient(timeout=60.0) as client:
|
||||
try:
|
||||
response = await client.post(
|
||||
f"{EMBEDDING_SERVICE_URL}/embed-single",
|
||||
json={"text": text}
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
return data.get("embedding", [])
|
||||
except httpx.HTTPError as e:
|
||||
logger.error(f"Embedding service error: {e}")
|
||||
# Fallback to hash-based pseudo-embedding for development
|
||||
return _generate_fallback_embedding(text)
|
||||
|
||||
|
||||
async def get_embeddings_batch(texts: List[str]) -> List[List[float]]:
|
||||
"""
|
||||
Get embeddings for multiple texts in batch.
|
||||
"""
|
||||
async with httpx.AsyncClient(timeout=120.0) as client:
|
||||
try:
|
||||
response = await client.post(
|
||||
f"{EMBEDDING_SERVICE_URL}/embed",
|
||||
json={"texts": texts}
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
return data.get("embeddings", [])
|
||||
except httpx.HTTPError as e:
|
||||
logger.error(f"Embedding service batch error: {e}")
|
||||
# Fallback
|
||||
return [_generate_fallback_embedding(t) for t in texts]
|
||||
|
||||
|
||||
async def extract_text_from_url(url: str) -> str:
|
||||
"""
|
||||
Extract text from a document URL (PDF, HTML, etc.).
|
||||
"""
|
||||
async with httpx.AsyncClient(timeout=120.0) as client:
|
||||
try:
|
||||
# First try to use the embedding-service's extract-pdf endpoint
|
||||
response = await client.post(
|
||||
f"{EMBEDDING_SERVICE_URL}/extract-pdf",
|
||||
json={"url": url}
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
return data.get("text", "")
|
||||
except httpx.HTTPError as e:
|
||||
logger.error(f"PDF extraction error for {url}: {e}")
|
||||
# Fallback: try to fetch HTML content directly
|
||||
try:
|
||||
response = await client.get(url, follow_redirects=True)
|
||||
response.raise_for_status()
|
||||
content_type = response.headers.get("content-type", "")
|
||||
if "html" in content_type:
|
||||
# Simple HTML text extraction
|
||||
html = response.text
|
||||
# Remove scripts and styles
|
||||
html = re.sub(r'<script[^>]*>.*?</script>', '', html, flags=re.DOTALL | re.IGNORECASE)
|
||||
html = re.sub(r'<style[^>]*>.*?</style>', '', html, flags=re.DOTALL | re.IGNORECASE)
|
||||
# Remove tags
|
||||
text = re.sub(r'<[^>]+>', ' ', html)
|
||||
# Clean whitespace
|
||||
text = re.sub(r'\s+', ' ', text).strip()
|
||||
return text
|
||||
else:
|
||||
return ""
|
||||
except Exception as fetch_err:
|
||||
logger.error(f"Fallback fetch error for {url}: {fetch_err}")
|
||||
return ""
|
||||
|
||||
|
||||
def _generate_fallback_embedding(text: str) -> List[float]:
|
||||
"""
|
||||
Generate deterministic pseudo-embedding from text hash.
|
||||
Used as fallback when embedding service is unavailable.
|
||||
"""
|
||||
hash_bytes = hashlib.sha256(text.encode()).digest()
|
||||
embedding = []
|
||||
for i in range(0, min(len(hash_bytes), 128), 4):
|
||||
val = struct.unpack('f', hash_bytes[i:i+4])[0]
|
||||
embedding.append(val % 1.0)
|
||||
|
||||
# Pad to 1024 dimensions
|
||||
while len(embedding) < 1024:
|
||||
embedding.extend(embedding[:min(len(embedding), 1024 - len(embedding))])
|
||||
|
||||
return embedding[:1024]
|
||||
@@ -0,0 +1,137 @@
|
||||
"""
|
||||
DSFA RAG Pydantic Models.
|
||||
|
||||
Request/Response models for the DSFA RAG API.
|
||||
"""
|
||||
|
||||
from typing import List, Optional
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Response Models
|
||||
# =============================================================================
|
||||
|
||||
class DSFASourceResponse(BaseModel):
|
||||
"""Response model for DSFA source."""
|
||||
id: str
|
||||
source_code: str
|
||||
name: str
|
||||
full_name: Optional[str] = None
|
||||
organization: Optional[str] = None
|
||||
source_url: Optional[str] = None
|
||||
license_code: str
|
||||
license_name: str
|
||||
license_url: Optional[str] = None
|
||||
attribution_required: bool
|
||||
attribution_text: str
|
||||
document_type: Optional[str] = None
|
||||
language: str = "de"
|
||||
|
||||
|
||||
class DSFAChunkResponse(BaseModel):
|
||||
"""Response model for a single chunk with attribution."""
|
||||
chunk_id: str
|
||||
content: str
|
||||
section_title: Optional[str] = None
|
||||
page_number: Optional[int] = None
|
||||
category: Optional[str] = None
|
||||
|
||||
# Document info
|
||||
document_id: str
|
||||
document_title: Optional[str] = None
|
||||
|
||||
# Attribution (always included)
|
||||
source_id: str
|
||||
source_code: str
|
||||
source_name: str
|
||||
attribution_text: str
|
||||
license_code: str
|
||||
license_name: str
|
||||
license_url: Optional[str] = None
|
||||
attribution_required: bool
|
||||
source_url: Optional[str] = None
|
||||
document_type: Optional[str] = None
|
||||
|
||||
|
||||
class DSFASearchResultResponse(BaseModel):
|
||||
"""Response model for search result."""
|
||||
chunk_id: str
|
||||
content: str
|
||||
score: float
|
||||
|
||||
# Attribution
|
||||
source_code: str
|
||||
source_name: str
|
||||
attribution_text: str
|
||||
license_code: str
|
||||
license_name: str
|
||||
license_url: Optional[str] = None
|
||||
attribution_required: bool
|
||||
source_url: Optional[str] = None
|
||||
|
||||
# Metadata
|
||||
document_type: Optional[str] = None
|
||||
category: Optional[str] = None
|
||||
section_title: Optional[str] = None
|
||||
page_number: Optional[int] = None
|
||||
|
||||
|
||||
class DSFASearchResponse(BaseModel):
|
||||
"""Response model for search endpoint."""
|
||||
query: str
|
||||
results: List[DSFASearchResultResponse]
|
||||
total_results: int
|
||||
|
||||
# Aggregated licenses for footer
|
||||
licenses_used: List[str]
|
||||
attribution_notice: str
|
||||
|
||||
|
||||
class DSFASourceStatsResponse(BaseModel):
|
||||
"""Response model for source statistics."""
|
||||
source_id: str
|
||||
source_code: str
|
||||
name: str
|
||||
organization: Optional[str] = None
|
||||
license_code: str
|
||||
document_type: Optional[str] = None
|
||||
document_count: int
|
||||
chunk_count: int
|
||||
last_indexed_at: Optional[str] = None
|
||||
|
||||
|
||||
class DSFACorpusStatsResponse(BaseModel):
|
||||
"""Response model for corpus statistics."""
|
||||
sources: List[DSFASourceStatsResponse]
|
||||
total_sources: int
|
||||
total_documents: int
|
||||
total_chunks: int
|
||||
qdrant_collection: str
|
||||
qdrant_points_count: int
|
||||
qdrant_status: str
|
||||
|
||||
|
||||
class IngestRequest(BaseModel):
|
||||
"""Request model for ingestion."""
|
||||
document_url: Optional[str] = None
|
||||
document_text: Optional[str] = None
|
||||
title: Optional[str] = None
|
||||
|
||||
|
||||
class IngestResponse(BaseModel):
|
||||
"""Response model for ingestion."""
|
||||
source_code: str
|
||||
document_id: Optional[str] = None
|
||||
chunks_created: int
|
||||
message: str
|
||||
|
||||
|
||||
class LicenseInfo(BaseModel):
|
||||
"""License information."""
|
||||
code: str
|
||||
name: str
|
||||
url: Optional[str] = None
|
||||
attribution_required: bool
|
||||
modification_allowed: bool
|
||||
commercial_use: bool
|
||||
@@ -0,0 +1,461 @@
|
||||
"""
|
||||
DSFA RAG API Route Handlers.
|
||||
|
||||
Endpoint implementations for search, sources, ingestion, stats, and init.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Query, Depends
|
||||
|
||||
from dsfa_corpus_ingestion import (
|
||||
DSFACorpusStore,
|
||||
DSFAQdrantService,
|
||||
DSFASearchResult,
|
||||
LICENSE_REGISTRY,
|
||||
DSFA_SOURCES,
|
||||
generate_attribution_notice,
|
||||
get_license_label,
|
||||
DSFA_COLLECTION,
|
||||
chunk_document,
|
||||
)
|
||||
|
||||
from dsfa_rag_models import (
|
||||
DSFASourceResponse,
|
||||
DSFAChunkResponse,
|
||||
DSFASearchResultResponse,
|
||||
DSFASearchResponse,
|
||||
DSFASourceStatsResponse,
|
||||
DSFACorpusStatsResponse,
|
||||
IngestRequest,
|
||||
IngestResponse,
|
||||
LicenseInfo,
|
||||
)
|
||||
|
||||
from dsfa_rag_embedding import (
|
||||
get_embedding,
|
||||
get_embeddings_batch,
|
||||
extract_text_from_url,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/v1/dsfa-rag", tags=["DSFA RAG"])
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Dependency Injection
|
||||
# =============================================================================
|
||||
|
||||
_db_pool = None
|
||||
|
||||
|
||||
def set_db_pool(pool):
|
||||
"""Set the database pool for API endpoints."""
|
||||
global _db_pool
|
||||
_db_pool = pool
|
||||
|
||||
|
||||
async def get_store() -> DSFACorpusStore:
|
||||
"""Get DSFA corpus store."""
|
||||
if _db_pool is None:
|
||||
raise HTTPException(status_code=503, detail="Database not initialized")
|
||||
return DSFACorpusStore(_db_pool)
|
||||
|
||||
|
||||
async def get_qdrant() -> DSFAQdrantService:
|
||||
"""Get Qdrant service."""
|
||||
return DSFAQdrantService()
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# API Endpoints
|
||||
# =============================================================================
|
||||
|
||||
@router.get("/search", response_model=DSFASearchResponse)
|
||||
async def search_dsfa_corpus(
|
||||
query: str = Query(..., min_length=3, description="Search query"),
|
||||
source_codes: Optional[List[str]] = Query(None, description="Filter by source codes"),
|
||||
document_types: Optional[List[str]] = Query(None, description="Filter by document types (guideline, checklist, regulation)"),
|
||||
categories: Optional[List[str]] = Query(None, description="Filter by categories (threshold_analysis, risk_assessment, mitigation)"),
|
||||
limit: int = Query(10, ge=1, le=50, description="Maximum results"),
|
||||
include_attribution: bool = Query(True, description="Include attribution in results"),
|
||||
store: DSFACorpusStore = Depends(get_store),
|
||||
qdrant: DSFAQdrantService = Depends(get_qdrant)
|
||||
):
|
||||
"""
|
||||
Search DSFA corpus with full attribution.
|
||||
|
||||
Returns matching chunks with source/license information for compliance.
|
||||
"""
|
||||
query_embedding = await get_embedding(query)
|
||||
|
||||
raw_results = await qdrant.search(
|
||||
query_embedding=query_embedding,
|
||||
source_codes=source_codes,
|
||||
document_types=document_types,
|
||||
categories=categories,
|
||||
limit=limit
|
||||
)
|
||||
|
||||
results = []
|
||||
licenses_used = set()
|
||||
|
||||
for r in raw_results:
|
||||
license_code = r.get("license_code", "")
|
||||
license_info = LICENSE_REGISTRY.get(license_code, {})
|
||||
|
||||
result = DSFASearchResultResponse(
|
||||
chunk_id=r.get("chunk_id", ""),
|
||||
content=r.get("content", ""),
|
||||
score=r.get("score", 0.0),
|
||||
source_code=r.get("source_code", ""),
|
||||
source_name=r.get("source_name", ""),
|
||||
attribution_text=r.get("attribution_text", ""),
|
||||
license_code=license_code,
|
||||
license_name=license_info.get("name", license_code),
|
||||
license_url=license_info.get("url"),
|
||||
attribution_required=r.get("attribution_required", True),
|
||||
source_url=r.get("source_url"),
|
||||
document_type=r.get("document_type"),
|
||||
category=r.get("category"),
|
||||
section_title=r.get("section_title"),
|
||||
page_number=r.get("page_number")
|
||||
)
|
||||
results.append(result)
|
||||
licenses_used.add(license_code)
|
||||
|
||||
# Generate attribution notice
|
||||
search_results = [
|
||||
DSFASearchResult(
|
||||
chunk_id=r.chunk_id,
|
||||
content=r.content,
|
||||
score=r.score,
|
||||
source_code=r.source_code,
|
||||
source_name=r.source_name,
|
||||
attribution_text=r.attribution_text,
|
||||
license_code=r.license_code,
|
||||
license_url=r.license_url,
|
||||
attribution_required=r.attribution_required,
|
||||
source_url=r.source_url,
|
||||
document_type=r.document_type or "",
|
||||
category=r.category or "",
|
||||
section_title=r.section_title,
|
||||
page_number=r.page_number
|
||||
)
|
||||
for r in results
|
||||
]
|
||||
attribution_notice = generate_attribution_notice(search_results) if include_attribution else ""
|
||||
|
||||
return DSFASearchResponse(
|
||||
query=query,
|
||||
results=results,
|
||||
total_results=len(results),
|
||||
licenses_used=list(licenses_used),
|
||||
attribution_notice=attribution_notice
|
||||
)
|
||||
|
||||
|
||||
@router.get("/sources", response_model=List[DSFASourceResponse])
|
||||
async def list_dsfa_sources(
|
||||
document_type: Optional[str] = Query(None, description="Filter by document type"),
|
||||
license_code: Optional[str] = Query(None, description="Filter by license"),
|
||||
store: DSFACorpusStore = Depends(get_store)
|
||||
):
|
||||
"""List all registered DSFA sources with license info."""
|
||||
sources = await store.list_sources()
|
||||
|
||||
result = []
|
||||
for s in sources:
|
||||
if document_type and s.get("document_type") != document_type:
|
||||
continue
|
||||
if license_code and s.get("license_code") != license_code:
|
||||
continue
|
||||
|
||||
license_info = LICENSE_REGISTRY.get(s.get("license_code", ""), {})
|
||||
|
||||
result.append(DSFASourceResponse(
|
||||
id=str(s["id"]),
|
||||
source_code=s["source_code"],
|
||||
name=s["name"],
|
||||
full_name=s.get("full_name"),
|
||||
organization=s.get("organization"),
|
||||
source_url=s.get("source_url"),
|
||||
license_code=s.get("license_code", ""),
|
||||
license_name=license_info.get("name", s.get("license_code", "")),
|
||||
license_url=license_info.get("url"),
|
||||
attribution_required=s.get("attribution_required", True),
|
||||
attribution_text=s.get("attribution_text", ""),
|
||||
document_type=s.get("document_type"),
|
||||
language=s.get("language", "de")
|
||||
))
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@router.get("/sources/available")
|
||||
async def list_available_sources():
|
||||
"""List all available source definitions (from DSFA_SOURCES constant)."""
|
||||
return [
|
||||
{
|
||||
"source_code": s["source_code"],
|
||||
"name": s["name"],
|
||||
"organization": s.get("organization"),
|
||||
"license_code": s["license_code"],
|
||||
"document_type": s.get("document_type")
|
||||
}
|
||||
for s in DSFA_SOURCES
|
||||
]
|
||||
|
||||
|
||||
@router.get("/sources/{source_code}", response_model=DSFASourceResponse)
|
||||
async def get_dsfa_source(
|
||||
source_code: str,
|
||||
store: DSFACorpusStore = Depends(get_store)
|
||||
):
|
||||
"""Get details for a specific source."""
|
||||
source = await store.get_source_by_code(source_code)
|
||||
|
||||
if not source:
|
||||
raise HTTPException(status_code=404, detail=f"Source not found: {source_code}")
|
||||
|
||||
license_info = LICENSE_REGISTRY.get(source.get("license_code", ""), {})
|
||||
|
||||
return DSFASourceResponse(
|
||||
id=str(source["id"]),
|
||||
source_code=source["source_code"],
|
||||
name=source["name"],
|
||||
full_name=source.get("full_name"),
|
||||
organization=source.get("organization"),
|
||||
source_url=source.get("source_url"),
|
||||
license_code=source.get("license_code", ""),
|
||||
license_name=license_info.get("name", source.get("license_code", "")),
|
||||
license_url=license_info.get("url"),
|
||||
attribution_required=source.get("attribution_required", True),
|
||||
attribution_text=source.get("attribution_text", ""),
|
||||
document_type=source.get("document_type"),
|
||||
language=source.get("language", "de")
|
||||
)
|
||||
|
||||
|
||||
@router.post("/sources/{source_code}/ingest", response_model=IngestResponse)
|
||||
async def ingest_dsfa_source(
|
||||
source_code: str,
|
||||
request: IngestRequest,
|
||||
store: DSFACorpusStore = Depends(get_store),
|
||||
qdrant: DSFAQdrantService = Depends(get_qdrant)
|
||||
):
|
||||
"""
|
||||
Trigger ingestion for a specific source.
|
||||
|
||||
Can provide document via URL or direct text.
|
||||
"""
|
||||
source = await store.get_source_by_code(source_code)
|
||||
if not source:
|
||||
raise HTTPException(status_code=404, detail=f"Source not found: {source_code}")
|
||||
|
||||
if not request.document_text and not request.document_url:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Either document_text or document_url must be provided"
|
||||
)
|
||||
|
||||
await qdrant.ensure_collection()
|
||||
|
||||
text_content = request.document_text
|
||||
if request.document_url and not text_content:
|
||||
logger.info(f"Extracting text from URL: {request.document_url}")
|
||||
text_content = await extract_text_from_url(request.document_url)
|
||||
if not text_content:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Could not extract text from URL: {request.document_url}"
|
||||
)
|
||||
|
||||
if not text_content or len(text_content.strip()) < 50:
|
||||
raise HTTPException(status_code=400, detail="Document text too short (min 50 chars)")
|
||||
|
||||
doc_title = request.title or f"Document for {source_code}"
|
||||
document_id = await store.create_document(
|
||||
source_id=str(source["id"]),
|
||||
title=doc_title,
|
||||
file_type="text",
|
||||
metadata={"ingested_via": "api", "source_code": source_code}
|
||||
)
|
||||
|
||||
chunks = chunk_document(text_content, source_code)
|
||||
|
||||
if not chunks:
|
||||
return IngestResponse(
|
||||
source_code=source_code,
|
||||
document_id=document_id,
|
||||
chunks_created=0,
|
||||
message="Document created but no chunks generated"
|
||||
)
|
||||
|
||||
chunk_texts = [chunk["content"] for chunk in chunks]
|
||||
logger.info(f"Generating embeddings for {len(chunk_texts)} chunks...")
|
||||
embeddings = await get_embeddings_batch(chunk_texts)
|
||||
|
||||
chunk_records = []
|
||||
for i, (chunk, embedding) in enumerate(zip(chunks, embeddings)):
|
||||
chunk_id = await store.create_chunk(
|
||||
document_id=document_id,
|
||||
source_id=str(source["id"]),
|
||||
content=chunk["content"],
|
||||
chunk_index=i,
|
||||
section_title=chunk.get("section_title"),
|
||||
page_number=chunk.get("page_number"),
|
||||
category=chunk.get("category")
|
||||
)
|
||||
|
||||
chunk_records.append({
|
||||
"chunk_id": chunk_id,
|
||||
"document_id": document_id,
|
||||
"source_id": str(source["id"]),
|
||||
"content": chunk["content"],
|
||||
"section_title": chunk.get("section_title"),
|
||||
"source_code": source_code,
|
||||
"source_name": source["name"],
|
||||
"attribution_text": source["attribution_text"],
|
||||
"license_code": source["license_code"],
|
||||
"attribution_required": source.get("attribution_required", True),
|
||||
"document_type": source.get("document_type", ""),
|
||||
"category": chunk.get("category", ""),
|
||||
"language": source.get("language", "de"),
|
||||
"page_number": chunk.get("page_number")
|
||||
})
|
||||
|
||||
indexed_count = await qdrant.index_chunks(chunk_records, embeddings)
|
||||
await store.update_document_indexed(document_id, len(chunks))
|
||||
|
||||
return IngestResponse(
|
||||
source_code=source_code,
|
||||
document_id=document_id,
|
||||
chunks_created=indexed_count,
|
||||
message=f"Successfully ingested {indexed_count} chunks from document"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/chunks/{chunk_id}", response_model=DSFAChunkResponse)
|
||||
async def get_chunk_with_attribution(
|
||||
chunk_id: str,
|
||||
store: DSFACorpusStore = Depends(get_store)
|
||||
):
|
||||
"""Get single chunk with full source attribution."""
|
||||
chunk = await store.get_chunk_with_attribution(chunk_id)
|
||||
|
||||
if not chunk:
|
||||
raise HTTPException(status_code=404, detail=f"Chunk not found: {chunk_id}")
|
||||
|
||||
license_info = LICENSE_REGISTRY.get(chunk.get("license_code", ""), {})
|
||||
|
||||
return DSFAChunkResponse(
|
||||
chunk_id=str(chunk["chunk_id"]),
|
||||
content=chunk.get("content", ""),
|
||||
section_title=chunk.get("section_title"),
|
||||
page_number=chunk.get("page_number"),
|
||||
category=chunk.get("category"),
|
||||
document_id=str(chunk.get("document_id", "")),
|
||||
document_title=chunk.get("document_title"),
|
||||
source_id=str(chunk.get("source_id", "")),
|
||||
source_code=chunk.get("source_code", ""),
|
||||
source_name=chunk.get("source_name", ""),
|
||||
attribution_text=chunk.get("attribution_text", ""),
|
||||
license_code=chunk.get("license_code", ""),
|
||||
license_name=license_info.get("name", chunk.get("license_code", "")),
|
||||
license_url=license_info.get("url"),
|
||||
attribution_required=chunk.get("attribution_required", True),
|
||||
source_url=chunk.get("source_url"),
|
||||
document_type=chunk.get("document_type")
|
||||
)
|
||||
|
||||
|
||||
@router.get("/stats", response_model=DSFACorpusStatsResponse)
|
||||
async def get_corpus_stats(
|
||||
store: DSFACorpusStore = Depends(get_store),
|
||||
qdrant: DSFAQdrantService = Depends(get_qdrant)
|
||||
):
|
||||
"""Get corpus statistics for dashboard."""
|
||||
source_stats = await store.get_source_stats()
|
||||
|
||||
total_docs = 0
|
||||
total_chunks = 0
|
||||
stats_response = []
|
||||
|
||||
for s in source_stats:
|
||||
doc_count = s.get("document_count", 0) or 0
|
||||
chunk_count = s.get("chunk_count", 0) or 0
|
||||
total_docs += doc_count
|
||||
total_chunks += chunk_count
|
||||
|
||||
last_indexed = s.get("last_indexed_at")
|
||||
|
||||
stats_response.append(DSFASourceStatsResponse(
|
||||
source_id=str(s.get("source_id", "")),
|
||||
source_code=s.get("source_code", ""),
|
||||
name=s.get("name", ""),
|
||||
organization=s.get("organization"),
|
||||
license_code=s.get("license_code", ""),
|
||||
document_type=s.get("document_type"),
|
||||
document_count=doc_count,
|
||||
chunk_count=chunk_count,
|
||||
last_indexed_at=last_indexed.isoformat() if last_indexed else None
|
||||
))
|
||||
|
||||
qdrant_stats = await qdrant.get_stats()
|
||||
|
||||
return DSFACorpusStatsResponse(
|
||||
sources=stats_response,
|
||||
total_sources=len(source_stats),
|
||||
total_documents=total_docs,
|
||||
total_chunks=total_chunks,
|
||||
qdrant_collection=DSFA_COLLECTION,
|
||||
qdrant_points_count=qdrant_stats.get("points_count", 0),
|
||||
qdrant_status=qdrant_stats.get("status", "unknown")
|
||||
)
|
||||
|
||||
|
||||
@router.get("/licenses")
|
||||
async def list_licenses():
|
||||
"""List all supported licenses with their terms."""
|
||||
return [
|
||||
LicenseInfo(
|
||||
code=code,
|
||||
name=info.get("name", code),
|
||||
url=info.get("url"),
|
||||
attribution_required=info.get("attribution_required", True),
|
||||
modification_allowed=info.get("modification_allowed", True),
|
||||
commercial_use=info.get("commercial_use", True)
|
||||
)
|
||||
for code, info in LICENSE_REGISTRY.items()
|
||||
]
|
||||
|
||||
|
||||
@router.post("/init")
|
||||
async def initialize_dsfa_corpus(
|
||||
store: DSFACorpusStore = Depends(get_store),
|
||||
qdrant: DSFAQdrantService = Depends(get_qdrant)
|
||||
):
|
||||
"""
|
||||
Initialize DSFA corpus.
|
||||
|
||||
- Creates Qdrant collection
|
||||
- Registers all predefined sources
|
||||
"""
|
||||
qdrant_ok = await qdrant.ensure_collection()
|
||||
|
||||
registered = 0
|
||||
for source in DSFA_SOURCES:
|
||||
try:
|
||||
await store.register_source(source)
|
||||
registered += 1
|
||||
except Exception as e:
|
||||
print(f"Error registering source {source['source_code']}: {e}")
|
||||
|
||||
return {
|
||||
"qdrant_collection_created": qdrant_ok,
|
||||
"sources_registered": registered,
|
||||
"total_sources": len(DSFA_SOURCES)
|
||||
}
|
||||
@@ -1,31 +1,19 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Full Compliance Pipeline for Legal Corpus.
|
||||
Full Compliance Pipeline for Legal Corpus — Barrel Re-export.
|
||||
|
||||
This script runs the complete pipeline:
|
||||
1. Re-ingest all legal documents with improved chunking
|
||||
2. Extract requirements/checkpoints from chunks
|
||||
3. Generate controls using AI
|
||||
4. Define remediation measures
|
||||
5. Update statistics
|
||||
Split into submodules:
|
||||
- compliance_models.py — Dataclasses (Checkpoint, Control, Measure)
|
||||
- compliance_extraction.py — Pattern extraction & control/measure generation
|
||||
- compliance_pipeline.py — Pipeline phases & orchestrator
|
||||
|
||||
Run on Mac Mini:
|
||||
nohup python full_compliance_pipeline.py > /tmp/compliance_pipeline.log 2>&1 &
|
||||
|
||||
Checkpoints are saved to /tmp/pipeline_checkpoints.json and can be viewed in admin-v2.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Any, Optional
|
||||
from dataclasses import dataclass, asdict
|
||||
import re
|
||||
import hashlib
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
@@ -36,671 +24,25 @@ logging.basicConfig(
|
||||
logging.FileHandler('/tmp/compliance_pipeline.log')
|
||||
]
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Import checkpoint manager
|
||||
try:
|
||||
from pipeline_checkpoints import CheckpointManager, EXPECTED_VALUES, ValidationStatus
|
||||
except ImportError:
|
||||
logger.warning("Checkpoint manager not available, running without checkpoints")
|
||||
CheckpointManager = None
|
||||
EXPECTED_VALUES = {}
|
||||
ValidationStatus = None
|
||||
|
||||
# Set environment variables for Docker network
|
||||
# Support both QDRANT_URL and QDRANT_HOST
|
||||
if not os.getenv("QDRANT_URL") and not os.getenv("QDRANT_HOST"):
|
||||
os.environ["QDRANT_HOST"] = "qdrant"
|
||||
os.environ.setdefault("EMBEDDING_SERVICE_URL", "http://embedding-service:8087")
|
||||
|
||||
# Try to import from klausur-service
|
||||
try:
|
||||
from legal_corpus_ingestion import LegalCorpusIngestion, REGULATIONS, LEGAL_CORPUS_COLLECTION
|
||||
from qdrant_client import QdrantClient
|
||||
from qdrant_client.models import Filter, FieldCondition, MatchValue
|
||||
except ImportError:
|
||||
logger.error("Could not import required modules. Make sure you're in the klausur-service container.")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Checkpoint:
|
||||
"""A requirement/checkpoint extracted from legal text."""
|
||||
id: str
|
||||
regulation_code: str
|
||||
regulation_name: str
|
||||
article: Optional[str]
|
||||
title: str
|
||||
description: str
|
||||
original_text: str
|
||||
chunk_id: str
|
||||
source_url: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class Control:
|
||||
"""A control derived from checkpoints."""
|
||||
id: str
|
||||
domain: str
|
||||
title: str
|
||||
description: str
|
||||
checkpoints: List[str] # List of checkpoint IDs
|
||||
pass_criteria: str
|
||||
implementation_guidance: str
|
||||
is_automated: bool
|
||||
automation_tool: Optional[str]
|
||||
priority: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class Measure:
|
||||
"""A remediation measure for a control."""
|
||||
id: str
|
||||
control_id: str
|
||||
title: str
|
||||
description: str
|
||||
responsible: str
|
||||
deadline_days: int
|
||||
status: str
|
||||
|
||||
|
||||
class CompliancePipeline:
|
||||
"""Handles the full compliance pipeline."""
|
||||
|
||||
def __init__(self):
|
||||
# Support both QDRANT_URL and QDRANT_HOST/PORT
|
||||
qdrant_url = os.getenv("QDRANT_URL", "")
|
||||
if qdrant_url:
|
||||
from urllib.parse import urlparse
|
||||
parsed = urlparse(qdrant_url)
|
||||
qdrant_host = parsed.hostname or "qdrant"
|
||||
qdrant_port = parsed.port or 6333
|
||||
else:
|
||||
qdrant_host = os.getenv("QDRANT_HOST", "qdrant")
|
||||
qdrant_port = 6333
|
||||
self.qdrant = QdrantClient(host=qdrant_host, port=qdrant_port)
|
||||
self.checkpoints: List[Checkpoint] = []
|
||||
self.controls: List[Control] = []
|
||||
self.measures: List[Measure] = []
|
||||
self.stats = {
|
||||
"chunks_processed": 0,
|
||||
"checkpoints_extracted": 0,
|
||||
"controls_created": 0,
|
||||
"measures_defined": 0,
|
||||
"by_regulation": {},
|
||||
"by_domain": {},
|
||||
}
|
||||
# Initialize checkpoint manager
|
||||
self.checkpoint_mgr = CheckpointManager() if CheckpointManager else None
|
||||
|
||||
def extract_checkpoints_from_chunk(self, chunk_text: str, payload: Dict) -> List[Checkpoint]:
|
||||
"""
|
||||
Extract checkpoints/requirements from a chunk of text.
|
||||
|
||||
Uses pattern matching to find requirement-like statements.
|
||||
"""
|
||||
checkpoints = []
|
||||
regulation_code = payload.get("regulation_code", "UNKNOWN")
|
||||
regulation_name = payload.get("regulation_name", "Unknown")
|
||||
source_url = payload.get("source_url", "")
|
||||
chunk_id = hashlib.md5(chunk_text[:100].encode()).hexdigest()[:8]
|
||||
|
||||
# Patterns for different requirement types
|
||||
patterns = [
|
||||
# BSI-TR patterns
|
||||
(r'([OT]\.[A-Za-z_]+\d*)[:\s]+(.+?)(?=\n[OT]\.|$)', 'bsi_requirement'),
|
||||
# Article patterns (GDPR, AI Act, etc.)
|
||||
(r'(?:Artikel|Art\.?)\s+(\d+)(?:\s+Abs(?:atz)?\.?\s*(\d+))?\s*[-–:]\s*(.+?)(?=\n|$)', 'article'),
|
||||
# Numbered requirements
|
||||
(r'\((\d+)\)\s+(.+?)(?=\n\(\d+\)|$)', 'numbered'),
|
||||
# "Der Verantwortliche muss" patterns
|
||||
(r'(?:Der Verantwortliche|Die Aufsichtsbehörde|Der Auftragsverarbeiter)\s+(muss|hat|soll)\s+(.+?)(?=\.\s|$)', 'obligation'),
|
||||
# "Es ist erforderlich" patterns
|
||||
(r'(?:Es ist erforderlich|Es muss gewährleistet|Es sind geeignete)\s+(.+?)(?=\.\s|$)', 'requirement'),
|
||||
]
|
||||
|
||||
for pattern, pattern_type in patterns:
|
||||
matches = re.finditer(pattern, chunk_text, re.MULTILINE | re.DOTALL)
|
||||
for match in matches:
|
||||
if pattern_type == 'bsi_requirement':
|
||||
req_id = match.group(1)
|
||||
description = match.group(2).strip()
|
||||
title = req_id
|
||||
elif pattern_type == 'article':
|
||||
article_num = match.group(1)
|
||||
paragraph = match.group(2) or ""
|
||||
title_text = match.group(3).strip()
|
||||
req_id = f"{regulation_code}-Art{article_num}"
|
||||
if paragraph:
|
||||
req_id += f"-{paragraph}"
|
||||
title = f"Art. {article_num}" + (f" Abs. {paragraph}" if paragraph else "")
|
||||
description = title_text
|
||||
elif pattern_type == 'numbered':
|
||||
num = match.group(1)
|
||||
description = match.group(2).strip()
|
||||
req_id = f"{regulation_code}-{num}"
|
||||
title = f"Anforderung {num}"
|
||||
else:
|
||||
# Generic requirement
|
||||
description = match.group(0).strip()
|
||||
req_id = f"{regulation_code}-{chunk_id}-{len(checkpoints)}"
|
||||
title = description[:50] + "..." if len(description) > 50 else description
|
||||
|
||||
# Skip very short matches
|
||||
if len(description) < 20:
|
||||
continue
|
||||
|
||||
checkpoint = Checkpoint(
|
||||
id=req_id,
|
||||
regulation_code=regulation_code,
|
||||
regulation_name=regulation_name,
|
||||
article=title if 'Art' in title else None,
|
||||
title=title,
|
||||
description=description[:500],
|
||||
original_text=description,
|
||||
chunk_id=chunk_id,
|
||||
source_url=source_url
|
||||
)
|
||||
checkpoints.append(checkpoint)
|
||||
|
||||
return checkpoints
|
||||
|
||||
def generate_control_for_checkpoints(self, checkpoints: List[Checkpoint]) -> Optional[Control]:
|
||||
"""
|
||||
Generate a control that covers the given checkpoints.
|
||||
|
||||
This is a simplified version - in production this would use the AI assistant.
|
||||
"""
|
||||
if not checkpoints:
|
||||
return None
|
||||
|
||||
# Group by regulation
|
||||
regulation = checkpoints[0].regulation_code
|
||||
|
||||
# Determine domain based on content
|
||||
all_text = " ".join([cp.description for cp in checkpoints]).lower()
|
||||
|
||||
domain = "gov" # Default
|
||||
if any(kw in all_text for kw in ["verschlüssel", "krypto", "encrypt", "hash"]):
|
||||
domain = "crypto"
|
||||
elif any(kw in all_text for kw in ["zugang", "access", "authentif", "login", "benutzer"]):
|
||||
domain = "iam"
|
||||
elif any(kw in all_text for kw in ["datenschutz", "personenbezogen", "privacy", "einwilligung"]):
|
||||
domain = "priv"
|
||||
elif any(kw in all_text for kw in ["entwicklung", "test", "code", "software"]):
|
||||
domain = "sdlc"
|
||||
elif any(kw in all_text for kw in ["überwach", "monitor", "log", "audit"]):
|
||||
domain = "aud"
|
||||
elif any(kw in all_text for kw in ["ki", "künstlich", "ai", "machine learning", "model"]):
|
||||
domain = "ai"
|
||||
elif any(kw in all_text for kw in ["betrieb", "operation", "verfügbar", "backup"]):
|
||||
domain = "ops"
|
||||
elif any(kw in all_text for kw in ["cyber", "resilience", "sbom", "vulnerab"]):
|
||||
domain = "cra"
|
||||
|
||||
# Generate control ID
|
||||
domain_counts = self.stats.get("by_domain", {})
|
||||
domain_count = domain_counts.get(domain, 0) + 1
|
||||
control_id = f"{domain.upper()}-{domain_count:03d}"
|
||||
|
||||
# Create title from first checkpoint
|
||||
title = checkpoints[0].title
|
||||
if len(title) > 100:
|
||||
title = title[:97] + "..."
|
||||
|
||||
# Create description
|
||||
description = f"Control für {regulation}: " + checkpoints[0].description[:200]
|
||||
|
||||
# Pass criteria
|
||||
pass_criteria = f"Alle {len(checkpoints)} zugehörigen Anforderungen sind erfüllt und dokumentiert."
|
||||
|
||||
# Implementation guidance
|
||||
guidance = f"Implementiere Maßnahmen zur Erfüllung der Anforderungen aus {regulation}. "
|
||||
guidance += f"Dokumentiere die Umsetzung und führe regelmäßige Reviews durch."
|
||||
|
||||
# Determine if automated
|
||||
is_automated = any(kw in all_text for kw in ["automat", "tool", "scan", "test"])
|
||||
|
||||
control = Control(
|
||||
id=control_id,
|
||||
domain=domain,
|
||||
title=title,
|
||||
description=description,
|
||||
checkpoints=[cp.id for cp in checkpoints],
|
||||
pass_criteria=pass_criteria,
|
||||
implementation_guidance=guidance,
|
||||
is_automated=is_automated,
|
||||
automation_tool="CI/CD Pipeline" if is_automated else None,
|
||||
priority="high" if "muss" in all_text or "erforderlich" in all_text else "medium"
|
||||
)
|
||||
|
||||
return control
|
||||
|
||||
def generate_measure_for_control(self, control: Control) -> Measure:
|
||||
"""Generate a remediation measure for a control."""
|
||||
measure_id = f"M-{control.id}"
|
||||
|
||||
# Determine deadline based on priority
|
||||
deadline_days = {
|
||||
"critical": 30,
|
||||
"high": 60,
|
||||
"medium": 90,
|
||||
"low": 180
|
||||
}.get(control.priority, 90)
|
||||
|
||||
# Determine responsible team
|
||||
responsible = {
|
||||
"priv": "Datenschutzbeauftragter",
|
||||
"iam": "IT-Security Team",
|
||||
"sdlc": "Entwicklungsteam",
|
||||
"crypto": "IT-Security Team",
|
||||
"ops": "Operations Team",
|
||||
"aud": "Compliance Team",
|
||||
"ai": "AI/ML Team",
|
||||
"cra": "IT-Security Team",
|
||||
"gov": "Management"
|
||||
}.get(control.domain, "Compliance Team")
|
||||
|
||||
measure = Measure(
|
||||
id=measure_id,
|
||||
control_id=control.id,
|
||||
title=f"Umsetzung: {control.title[:50]}",
|
||||
description=f"Implementierung und Dokumentation von {control.id}: {control.description[:100]}",
|
||||
responsible=responsible,
|
||||
deadline_days=deadline_days,
|
||||
status="pending"
|
||||
)
|
||||
|
||||
return measure
|
||||
|
||||
async def run_ingestion_phase(self, force_reindex: bool = False) -> int:
|
||||
"""Phase 1: Ingest documents (incremental - only missing ones)."""
|
||||
logger.info("\n" + "=" * 60)
|
||||
logger.info("PHASE 1: DOCUMENT INGESTION (INCREMENTAL)")
|
||||
logger.info("=" * 60)
|
||||
|
||||
if self.checkpoint_mgr:
|
||||
self.checkpoint_mgr.start_checkpoint("ingestion", "Document Ingestion")
|
||||
|
||||
ingestion = LegalCorpusIngestion()
|
||||
|
||||
try:
|
||||
# Check existing chunks per regulation
|
||||
existing_chunks = {}
|
||||
try:
|
||||
for regulation in REGULATIONS:
|
||||
count_result = self.qdrant.count(
|
||||
collection_name=LEGAL_CORPUS_COLLECTION,
|
||||
count_filter=Filter(
|
||||
must=[FieldCondition(key="regulation_code", match=MatchValue(value=regulation.code))]
|
||||
)
|
||||
)
|
||||
existing_chunks[regulation.code] = count_result.count
|
||||
logger.info(f" {regulation.code}: {count_result.count} existing chunks")
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not check existing chunks: {e}")
|
||||
# Collection might not exist, that's OK
|
||||
|
||||
# Determine which regulations need ingestion
|
||||
regulations_to_ingest = []
|
||||
for regulation in REGULATIONS:
|
||||
existing = existing_chunks.get(regulation.code, 0)
|
||||
if force_reindex or existing == 0:
|
||||
regulations_to_ingest.append(regulation)
|
||||
logger.info(f" -> Will ingest: {regulation.code} (existing: {existing}, force: {force_reindex})")
|
||||
else:
|
||||
logger.info(f" -> Skipping: {regulation.code} (already has {existing} chunks)")
|
||||
self.stats["by_regulation"][regulation.code] = existing
|
||||
|
||||
if not regulations_to_ingest:
|
||||
logger.info("All regulations already indexed. Skipping ingestion phase.")
|
||||
total_chunks = sum(existing_chunks.values())
|
||||
self.stats["chunks_processed"] = total_chunks
|
||||
if self.checkpoint_mgr:
|
||||
self.checkpoint_mgr.add_metric("total_chunks", total_chunks)
|
||||
self.checkpoint_mgr.add_metric("skipped", True)
|
||||
self.checkpoint_mgr.complete_checkpoint(success=True)
|
||||
return total_chunks
|
||||
|
||||
# Ingest only missing regulations
|
||||
total_chunks = sum(existing_chunks.values())
|
||||
for i, regulation in enumerate(regulations_to_ingest, 1):
|
||||
logger.info(f"[{i}/{len(regulations_to_ingest)}] Ingesting {regulation.code}...")
|
||||
try:
|
||||
count = await ingestion.ingest_regulation(regulation)
|
||||
total_chunks += count
|
||||
self.stats["by_regulation"][regulation.code] = count
|
||||
logger.info(f" -> {count} chunks")
|
||||
|
||||
# Add metric for this regulation
|
||||
if self.checkpoint_mgr:
|
||||
self.checkpoint_mgr.add_metric(f"chunks_{regulation.code}", count)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f" -> FAILED: {e}")
|
||||
self.stats["by_regulation"][regulation.code] = 0
|
||||
|
||||
self.stats["chunks_processed"] = total_chunks
|
||||
logger.info(f"\nTotal chunks in collection: {total_chunks}")
|
||||
|
||||
# Validate ingestion results
|
||||
if self.checkpoint_mgr:
|
||||
self.checkpoint_mgr.add_metric("total_chunks", total_chunks)
|
||||
self.checkpoint_mgr.add_metric("regulations_count", len(REGULATIONS))
|
||||
|
||||
# Validate total chunks
|
||||
expected = EXPECTED_VALUES.get("ingestion", {})
|
||||
self.checkpoint_mgr.validate(
|
||||
"total_chunks",
|
||||
expected=expected.get("total_chunks", 8000),
|
||||
actual=total_chunks,
|
||||
min_value=expected.get("min_chunks", 7000)
|
||||
)
|
||||
|
||||
# Validate key regulations
|
||||
reg_expected = expected.get("regulations", {})
|
||||
for reg_code, reg_exp in reg_expected.items():
|
||||
actual = self.stats["by_regulation"].get(reg_code, 0)
|
||||
self.checkpoint_mgr.validate(
|
||||
f"chunks_{reg_code}",
|
||||
expected=reg_exp.get("expected", 0),
|
||||
actual=actual,
|
||||
min_value=reg_exp.get("min", 0)
|
||||
)
|
||||
|
||||
self.checkpoint_mgr.complete_checkpoint(success=True)
|
||||
|
||||
return total_chunks
|
||||
|
||||
except Exception as e:
|
||||
if self.checkpoint_mgr:
|
||||
self.checkpoint_mgr.fail_checkpoint(str(e))
|
||||
raise
|
||||
|
||||
finally:
|
||||
await ingestion.close()
|
||||
|
||||
async def run_extraction_phase(self) -> int:
|
||||
"""Phase 2: Extract checkpoints from chunks."""
|
||||
logger.info("\n" + "=" * 60)
|
||||
logger.info("PHASE 2: CHECKPOINT EXTRACTION")
|
||||
logger.info("=" * 60)
|
||||
|
||||
if self.checkpoint_mgr:
|
||||
self.checkpoint_mgr.start_checkpoint("extraction", "Checkpoint Extraction")
|
||||
|
||||
try:
|
||||
# Scroll through all chunks
|
||||
offset = None
|
||||
total_checkpoints = 0
|
||||
|
||||
while True:
|
||||
result = self.qdrant.scroll(
|
||||
collection_name=LEGAL_CORPUS_COLLECTION,
|
||||
limit=100,
|
||||
offset=offset,
|
||||
with_payload=True,
|
||||
with_vectors=False
|
||||
)
|
||||
|
||||
points, next_offset = result
|
||||
|
||||
if not points:
|
||||
break
|
||||
|
||||
for point in points:
|
||||
payload = point.payload
|
||||
text = payload.get("text", "")
|
||||
|
||||
checkpoints = self.extract_checkpoints_from_chunk(text, payload)
|
||||
self.checkpoints.extend(checkpoints)
|
||||
total_checkpoints += len(checkpoints)
|
||||
|
||||
logger.info(f"Processed {len(points)} chunks, extracted {total_checkpoints} checkpoints so far...")
|
||||
|
||||
if next_offset is None:
|
||||
break
|
||||
offset = next_offset
|
||||
|
||||
self.stats["checkpoints_extracted"] = len(self.checkpoints)
|
||||
logger.info(f"\nTotal checkpoints extracted: {len(self.checkpoints)}")
|
||||
|
||||
# Log per regulation
|
||||
by_reg = {}
|
||||
for cp in self.checkpoints:
|
||||
by_reg[cp.regulation_code] = by_reg.get(cp.regulation_code, 0) + 1
|
||||
for reg, count in sorted(by_reg.items()):
|
||||
logger.info(f" {reg}: {count} checkpoints")
|
||||
|
||||
# Validate extraction results
|
||||
if self.checkpoint_mgr:
|
||||
self.checkpoint_mgr.add_metric("total_checkpoints", len(self.checkpoints))
|
||||
self.checkpoint_mgr.add_metric("checkpoints_by_regulation", by_reg)
|
||||
|
||||
expected = EXPECTED_VALUES.get("extraction", {})
|
||||
self.checkpoint_mgr.validate(
|
||||
"total_checkpoints",
|
||||
expected=expected.get("total_checkpoints", 3500),
|
||||
actual=len(self.checkpoints),
|
||||
min_value=expected.get("min_checkpoints", 3000)
|
||||
)
|
||||
|
||||
self.checkpoint_mgr.complete_checkpoint(success=True)
|
||||
|
||||
return len(self.checkpoints)
|
||||
|
||||
except Exception as e:
|
||||
if self.checkpoint_mgr:
|
||||
self.checkpoint_mgr.fail_checkpoint(str(e))
|
||||
raise
|
||||
|
||||
async def run_control_generation_phase(self) -> int:
|
||||
"""Phase 3: Generate controls from checkpoints."""
|
||||
logger.info("\n" + "=" * 60)
|
||||
logger.info("PHASE 3: CONTROL GENERATION")
|
||||
logger.info("=" * 60)
|
||||
|
||||
if self.checkpoint_mgr:
|
||||
self.checkpoint_mgr.start_checkpoint("controls", "Control Generation")
|
||||
|
||||
try:
|
||||
# Group checkpoints by regulation
|
||||
by_regulation: Dict[str, List[Checkpoint]] = {}
|
||||
for cp in self.checkpoints:
|
||||
reg = cp.regulation_code
|
||||
if reg not in by_regulation:
|
||||
by_regulation[reg] = []
|
||||
by_regulation[reg].append(cp)
|
||||
|
||||
# Generate controls per regulation (group every 3-5 checkpoints)
|
||||
for regulation, checkpoints in by_regulation.items():
|
||||
logger.info(f"Generating controls for {regulation} ({len(checkpoints)} checkpoints)...")
|
||||
|
||||
# Group checkpoints into batches of 3-5
|
||||
batch_size = 4
|
||||
for i in range(0, len(checkpoints), batch_size):
|
||||
batch = checkpoints[i:i + batch_size]
|
||||
control = self.generate_control_for_checkpoints(batch)
|
||||
|
||||
if control:
|
||||
self.controls.append(control)
|
||||
self.stats["by_domain"][control.domain] = self.stats["by_domain"].get(control.domain, 0) + 1
|
||||
|
||||
self.stats["controls_created"] = len(self.controls)
|
||||
logger.info(f"\nTotal controls created: {len(self.controls)}")
|
||||
|
||||
# Log per domain
|
||||
for domain, count in sorted(self.stats["by_domain"].items()):
|
||||
logger.info(f" {domain}: {count} controls")
|
||||
|
||||
# Validate control generation
|
||||
if self.checkpoint_mgr:
|
||||
self.checkpoint_mgr.add_metric("total_controls", len(self.controls))
|
||||
self.checkpoint_mgr.add_metric("controls_by_domain", dict(self.stats["by_domain"]))
|
||||
|
||||
expected = EXPECTED_VALUES.get("controls", {})
|
||||
self.checkpoint_mgr.validate(
|
||||
"total_controls",
|
||||
expected=expected.get("total_controls", 900),
|
||||
actual=len(self.controls),
|
||||
min_value=expected.get("min_controls", 800)
|
||||
)
|
||||
|
||||
self.checkpoint_mgr.complete_checkpoint(success=True)
|
||||
|
||||
return len(self.controls)
|
||||
|
||||
except Exception as e:
|
||||
if self.checkpoint_mgr:
|
||||
self.checkpoint_mgr.fail_checkpoint(str(e))
|
||||
raise
|
||||
|
||||
async def run_measure_generation_phase(self) -> int:
|
||||
"""Phase 4: Generate measures for controls."""
|
||||
logger.info("\n" + "=" * 60)
|
||||
logger.info("PHASE 4: MEASURE GENERATION")
|
||||
logger.info("=" * 60)
|
||||
|
||||
if self.checkpoint_mgr:
|
||||
self.checkpoint_mgr.start_checkpoint("measures", "Measure Generation")
|
||||
|
||||
try:
|
||||
for control in self.controls:
|
||||
measure = self.generate_measure_for_control(control)
|
||||
self.measures.append(measure)
|
||||
|
||||
self.stats["measures_defined"] = len(self.measures)
|
||||
logger.info(f"\nTotal measures defined: {len(self.measures)}")
|
||||
|
||||
# Validate measure generation
|
||||
if self.checkpoint_mgr:
|
||||
self.checkpoint_mgr.add_metric("total_measures", len(self.measures))
|
||||
|
||||
expected = EXPECTED_VALUES.get("measures", {})
|
||||
self.checkpoint_mgr.validate(
|
||||
"total_measures",
|
||||
expected=expected.get("total_measures", 900),
|
||||
actual=len(self.measures),
|
||||
min_value=expected.get("min_measures", 800)
|
||||
)
|
||||
|
||||
self.checkpoint_mgr.complete_checkpoint(success=True)
|
||||
|
||||
return len(self.measures)
|
||||
|
||||
except Exception as e:
|
||||
if self.checkpoint_mgr:
|
||||
self.checkpoint_mgr.fail_checkpoint(str(e))
|
||||
raise
|
||||
|
||||
def save_results(self, output_dir: str = "/tmp/compliance_output"):
|
||||
"""Save results to JSON files."""
|
||||
logger.info("\n" + "=" * 60)
|
||||
logger.info("SAVING RESULTS")
|
||||
logger.info("=" * 60)
|
||||
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
# Save checkpoints
|
||||
checkpoints_file = os.path.join(output_dir, "checkpoints.json")
|
||||
with open(checkpoints_file, "w") as f:
|
||||
json.dump([asdict(cp) for cp in self.checkpoints], f, indent=2, ensure_ascii=False)
|
||||
logger.info(f"Saved {len(self.checkpoints)} checkpoints to {checkpoints_file}")
|
||||
|
||||
# Save controls
|
||||
controls_file = os.path.join(output_dir, "controls.json")
|
||||
with open(controls_file, "w") as f:
|
||||
json.dump([asdict(c) for c in self.controls], f, indent=2, ensure_ascii=False)
|
||||
logger.info(f"Saved {len(self.controls)} controls to {controls_file}")
|
||||
|
||||
# Save measures
|
||||
measures_file = os.path.join(output_dir, "measures.json")
|
||||
with open(measures_file, "w") as f:
|
||||
json.dump([asdict(m) for m in self.measures], f, indent=2, ensure_ascii=False)
|
||||
logger.info(f"Saved {len(self.measures)} measures to {measures_file}")
|
||||
|
||||
# Save statistics
|
||||
stats_file = os.path.join(output_dir, "statistics.json")
|
||||
self.stats["generated_at"] = datetime.now().isoformat()
|
||||
with open(stats_file, "w") as f:
|
||||
json.dump(self.stats, f, indent=2, ensure_ascii=False)
|
||||
logger.info(f"Saved statistics to {stats_file}")
|
||||
|
||||
async def run_full_pipeline(self, force_reindex: bool = False, skip_ingestion: bool = False):
|
||||
"""Run the complete pipeline.
|
||||
|
||||
Args:
|
||||
force_reindex: If True, re-ingest all documents even if they exist
|
||||
skip_ingestion: If True, skip ingestion phase entirely (use existing chunks)
|
||||
"""
|
||||
start_time = time.time()
|
||||
|
||||
logger.info("=" * 60)
|
||||
logger.info("FULL COMPLIANCE PIPELINE (INCREMENTAL)")
|
||||
logger.info(f"Started at: {datetime.now().isoformat()}")
|
||||
logger.info(f"Force reindex: {force_reindex}")
|
||||
logger.info(f"Skip ingestion: {skip_ingestion}")
|
||||
if self.checkpoint_mgr:
|
||||
logger.info(f"Pipeline ID: {self.checkpoint_mgr.pipeline_id}")
|
||||
logger.info("=" * 60)
|
||||
|
||||
try:
|
||||
# Phase 1: Ingestion (skip if requested or run incrementally)
|
||||
if skip_ingestion:
|
||||
logger.info("Skipping ingestion phase as requested...")
|
||||
# Still get the chunk count
|
||||
try:
|
||||
collection_info = self.qdrant.get_collection(LEGAL_CORPUS_COLLECTION)
|
||||
self.stats["chunks_processed"] = collection_info.points_count
|
||||
except Exception:
|
||||
self.stats["chunks_processed"] = 0
|
||||
else:
|
||||
await self.run_ingestion_phase(force_reindex=force_reindex)
|
||||
|
||||
# Phase 2: Extraction
|
||||
await self.run_extraction_phase()
|
||||
|
||||
# Phase 3: Control Generation
|
||||
await self.run_control_generation_phase()
|
||||
|
||||
# Phase 4: Measure Generation
|
||||
await self.run_measure_generation_phase()
|
||||
|
||||
# Save results
|
||||
self.save_results()
|
||||
|
||||
# Final summary
|
||||
elapsed = time.time() - start_time
|
||||
logger.info("\n" + "=" * 60)
|
||||
logger.info("PIPELINE COMPLETE")
|
||||
logger.info("=" * 60)
|
||||
logger.info(f"Duration: {elapsed:.1f} seconds")
|
||||
logger.info(f"Chunks processed: {self.stats['chunks_processed']}")
|
||||
logger.info(f"Checkpoints extracted: {self.stats['checkpoints_extracted']}")
|
||||
logger.info(f"Controls created: {self.stats['controls_created']}")
|
||||
logger.info(f"Measures defined: {self.stats['measures_defined']}")
|
||||
logger.info(f"\nResults saved to: /tmp/compliance_output/")
|
||||
logger.info("Checkpoint status: /tmp/pipeline_checkpoints.json")
|
||||
logger.info("=" * 60)
|
||||
|
||||
# Complete pipeline checkpoint
|
||||
if self.checkpoint_mgr:
|
||||
self.checkpoint_mgr.complete_pipeline({
|
||||
"duration_seconds": elapsed,
|
||||
"chunks_processed": self.stats['chunks_processed'],
|
||||
"checkpoints_extracted": self.stats['checkpoints_extracted'],
|
||||
"controls_created": self.stats['controls_created'],
|
||||
"measures_defined": self.stats['measures_defined'],
|
||||
"by_regulation": self.stats['by_regulation'],
|
||||
"by_domain": self.stats['by_domain'],
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Pipeline failed: {e}")
|
||||
if self.checkpoint_mgr:
|
||||
self.checkpoint_mgr.state.status = "failed"
|
||||
self.checkpoint_mgr._save()
|
||||
raise
|
||||
# Re-export all public symbols
|
||||
from compliance_models import Checkpoint, Control, Measure
|
||||
from compliance_extraction import (
|
||||
extract_checkpoints_from_chunk,
|
||||
generate_control_for_checkpoints,
|
||||
generate_measure_for_control,
|
||||
)
|
||||
from compliance_pipeline import CompliancePipeline
|
||||
|
||||
__all__ = [
|
||||
"Checkpoint",
|
||||
"Control",
|
||||
"Measure",
|
||||
"extract_checkpoints_from_chunk",
|
||||
"generate_control_for_checkpoints",
|
||||
"generate_measure_for_control",
|
||||
"CompliancePipeline",
|
||||
]
|
||||
|
||||
|
||||
async def main():
|
||||
|
||||
@@ -1,767 +1,35 @@
|
||||
"""
|
||||
GitHub Repository Crawler for Legal Templates.
|
||||
GitHub Repository Crawler — Barrel Re-export
|
||||
|
||||
Crawls GitHub and GitLab repositories to extract legal template documents
|
||||
(Markdown, HTML, JSON, etc.) for ingestion into the RAG system.
|
||||
Split into:
|
||||
- github_crawler_parsers.py — ExtractedDocument, MarkdownParser, HTMLParser, JSONParser
|
||||
- github_crawler_core.py — GitHubCrawler, RepositoryDownloader, crawl_source
|
||||
|
||||
Features:
|
||||
- Clone repositories via Git or download as ZIP
|
||||
- Parse Markdown, HTML, JSON, and plain text files
|
||||
- Extract structured content with metadata
|
||||
- Track git commit hashes for reproducibility
|
||||
- Handle rate limiting and errors gracefully
|
||||
All public names are re-exported here for backward compatibility.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import tempfile
|
||||
import zipfile
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from fnmatch import fnmatch
|
||||
from pathlib import Path
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import httpx
|
||||
|
||||
from template_sources import LicenseType, SourceConfig, LICENSES
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Configuration
|
||||
GITHUB_API_URL = "https://api.github.com"
|
||||
GITLAB_API_URL = "https://gitlab.com/api/v4"
|
||||
GITHUB_TOKEN = os.getenv("GITHUB_TOKEN", "") # Optional for higher rate limits
|
||||
MAX_FILE_SIZE = 1024 * 1024 # 1 MB max file size
|
||||
REQUEST_TIMEOUT = 60.0
|
||||
RATE_LIMIT_DELAY = 1.0 # Delay between requests to avoid rate limiting
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExtractedDocument:
|
||||
"""A document extracted from a repository."""
|
||||
text: str
|
||||
title: str
|
||||
file_path: str
|
||||
file_type: str # "markdown", "html", "json", "text"
|
||||
source_url: str
|
||||
source_commit: Optional[str] = None
|
||||
source_hash: str = "" # SHA256 of original content
|
||||
sections: List[Dict[str, Any]] = field(default_factory=list)
|
||||
placeholders: List[str] = field(default_factory=list)
|
||||
language: str = "en"
|
||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def __post_init__(self):
|
||||
if not self.source_hash and self.text:
|
||||
self.source_hash = hashlib.sha256(self.text.encode()).hexdigest()
|
||||
|
||||
|
||||
class MarkdownParser:
|
||||
"""Parse Markdown files into structured content."""
|
||||
|
||||
# Common placeholder patterns
|
||||
PLACEHOLDER_PATTERNS = [
|
||||
r'\[([A-Z_]+)\]', # [COMPANY_NAME]
|
||||
r'\{([a-z_]+)\}', # {company_name}
|
||||
r'\{\{([a-z_]+)\}\}', # {{company_name}}
|
||||
r'__([A-Z_]+)__', # __COMPANY_NAME__
|
||||
r'<([A-Z_]+)>', # <COMPANY_NAME>
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def parse(cls, content: str, filename: str = "") -> ExtractedDocument:
|
||||
"""Parse markdown content into an ExtractedDocument."""
|
||||
# Extract title from first heading or filename
|
||||
title = cls._extract_title(content, filename)
|
||||
|
||||
# Extract sections
|
||||
sections = cls._extract_sections(content)
|
||||
|
||||
# Find placeholders
|
||||
placeholders = cls._find_placeholders(content)
|
||||
|
||||
# Detect language
|
||||
language = cls._detect_language(content)
|
||||
|
||||
# Clean content for indexing
|
||||
clean_text = cls._clean_for_indexing(content)
|
||||
|
||||
return ExtractedDocument(
|
||||
text=clean_text,
|
||||
title=title,
|
||||
file_path=filename,
|
||||
file_type="markdown",
|
||||
source_url="", # Will be set by caller
|
||||
sections=sections,
|
||||
placeholders=placeholders,
|
||||
language=language,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _extract_title(cls, content: str, filename: str) -> str:
|
||||
"""Extract title from markdown heading or filename."""
|
||||
# Look for first h1 heading
|
||||
h1_match = re.search(r'^#\s+(.+)$', content, re.MULTILINE)
|
||||
if h1_match:
|
||||
return h1_match.group(1).strip()
|
||||
|
||||
# Look for YAML frontmatter title
|
||||
frontmatter_match = re.search(
|
||||
r'^---\s*\n.*?title:\s*["\']?(.+?)["\']?\s*\n.*?---',
|
||||
content, re.DOTALL
|
||||
)
|
||||
if frontmatter_match:
|
||||
return frontmatter_match.group(1).strip()
|
||||
|
||||
# Fall back to filename
|
||||
if filename:
|
||||
name = Path(filename).stem
|
||||
# Convert kebab-case or snake_case to title case
|
||||
return name.replace('-', ' ').replace('_', ' ').title()
|
||||
|
||||
return "Untitled"
|
||||
|
||||
@classmethod
|
||||
def _extract_sections(cls, content: str) -> List[Dict[str, Any]]:
|
||||
"""Extract sections from markdown content."""
|
||||
sections = []
|
||||
current_section = {"heading": "", "level": 0, "content": "", "start": 0}
|
||||
|
||||
for match in re.finditer(r'^(#{1,6})\s+(.+)$', content, re.MULTILINE):
|
||||
# Save previous section if it has content
|
||||
if current_section["heading"] or current_section["content"].strip():
|
||||
current_section["content"] = current_section["content"].strip()
|
||||
sections.append(current_section.copy())
|
||||
|
||||
# Start new section
|
||||
level = len(match.group(1))
|
||||
heading = match.group(2).strip()
|
||||
current_section = {
|
||||
"heading": heading,
|
||||
"level": level,
|
||||
"content": "",
|
||||
"start": match.end(),
|
||||
}
|
||||
|
||||
# Add final section
|
||||
if current_section["heading"] or current_section["content"].strip():
|
||||
current_section["content"] = content[current_section["start"]:].strip()
|
||||
sections.append(current_section)
|
||||
|
||||
return sections
|
||||
|
||||
@classmethod
|
||||
def _find_placeholders(cls, content: str) -> List[str]:
|
||||
"""Find placeholder patterns in content."""
|
||||
placeholders = set()
|
||||
for pattern in cls.PLACEHOLDER_PATTERNS:
|
||||
for match in re.finditer(pattern, content):
|
||||
placeholder = match.group(0)
|
||||
placeholders.add(placeholder)
|
||||
return sorted(list(placeholders))
|
||||
|
||||
@classmethod
|
||||
def _detect_language(cls, content: str) -> str:
|
||||
"""Detect language from content."""
|
||||
# Look for German-specific words
|
||||
german_indicators = [
|
||||
'Datenschutz', 'Impressum', 'Nutzungsbedingungen', 'Haftung',
|
||||
'Widerruf', 'Verantwortlicher', 'personenbezogene', 'Verarbeitung',
|
||||
'und', 'der', 'die', 'das', 'ist', 'wird', 'werden', 'sind',
|
||||
]
|
||||
|
||||
lower_content = content.lower()
|
||||
german_count = sum(1 for word in german_indicators if word.lower() in lower_content)
|
||||
|
||||
if german_count >= 3:
|
||||
return "de"
|
||||
return "en"
|
||||
|
||||
@classmethod
|
||||
def _clean_for_indexing(cls, content: str) -> str:
|
||||
"""Clean markdown content for text indexing."""
|
||||
# Remove YAML frontmatter
|
||||
content = re.sub(r'^---\s*\n.*?---\s*\n', '', content, flags=re.DOTALL)
|
||||
|
||||
# Remove HTML comments
|
||||
content = re.sub(r'<!--.*?-->', '', content, flags=re.DOTALL)
|
||||
|
||||
# Remove inline HTML tags but keep content
|
||||
content = re.sub(r'<[^>]+>', '', content)
|
||||
|
||||
# Convert markdown formatting
|
||||
content = re.sub(r'\*\*(.+?)\*\*', r'\1', content) # Bold
|
||||
content = re.sub(r'\*(.+?)\*', r'\1', content) # Italic
|
||||
content = re.sub(r'`(.+?)`', r'\1', content) # Inline code
|
||||
content = re.sub(r'~~(.+?)~~', r'\1', content) # Strikethrough
|
||||
|
||||
# Remove link syntax but keep text
|
||||
content = re.sub(r'\[([^\]]+)\]\([^)]+\)', r'\1', content)
|
||||
|
||||
# Remove image syntax
|
||||
content = re.sub(r'!\[([^\]]*)\]\([^)]+\)', r'\1', content)
|
||||
|
||||
# Clean up whitespace
|
||||
content = re.sub(r'\n{3,}', '\n\n', content)
|
||||
content = re.sub(r' +', ' ', content)
|
||||
|
||||
return content.strip()
|
||||
|
||||
|
||||
class HTMLParser:
|
||||
"""Parse HTML files into structured content."""
|
||||
|
||||
@classmethod
|
||||
def parse(cls, content: str, filename: str = "") -> ExtractedDocument:
|
||||
"""Parse HTML content into an ExtractedDocument."""
|
||||
# Extract title
|
||||
title_match = re.search(r'<title>(.+?)</title>', content, re.IGNORECASE)
|
||||
title = title_match.group(1) if title_match else Path(filename).stem
|
||||
|
||||
# Convert to text
|
||||
text = cls._html_to_text(content)
|
||||
|
||||
# Find placeholders
|
||||
placeholders = MarkdownParser._find_placeholders(text)
|
||||
|
||||
# Detect language
|
||||
lang_match = re.search(r'<html[^>]*lang=["\']([a-z]{2})["\']', content, re.IGNORECASE)
|
||||
language = lang_match.group(1) if lang_match else MarkdownParser._detect_language(text)
|
||||
|
||||
return ExtractedDocument(
|
||||
text=text,
|
||||
title=title,
|
||||
file_path=filename,
|
||||
file_type="html",
|
||||
source_url="",
|
||||
placeholders=placeholders,
|
||||
language=language,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _html_to_text(cls, html: str) -> str:
|
||||
"""Convert HTML to clean text."""
|
||||
# Remove script and style tags
|
||||
html = re.sub(r'<script[^>]*>.*?</script>', '', html, flags=re.DOTALL | re.IGNORECASE)
|
||||
html = re.sub(r'<style[^>]*>.*?</style>', '', html, flags=re.DOTALL | re.IGNORECASE)
|
||||
|
||||
# Remove comments
|
||||
html = re.sub(r'<!--.*?-->', '', html, flags=re.DOTALL)
|
||||
|
||||
# Replace common entities
|
||||
html = html.replace(' ', ' ')
|
||||
html = html.replace('&', '&')
|
||||
html = html.replace('<', '<')
|
||||
html = html.replace('>', '>')
|
||||
html = html.replace('"', '"')
|
||||
html = html.replace(''', "'")
|
||||
|
||||
# Add line breaks for block elements
|
||||
html = re.sub(r'<br\s*/?>', '\n', html, flags=re.IGNORECASE)
|
||||
html = re.sub(r'</p>', '\n\n', html, flags=re.IGNORECASE)
|
||||
html = re.sub(r'</div>', '\n', html, flags=re.IGNORECASE)
|
||||
html = re.sub(r'</h[1-6]>', '\n\n', html, flags=re.IGNORECASE)
|
||||
html = re.sub(r'</li>', '\n', html, flags=re.IGNORECASE)
|
||||
|
||||
# Remove remaining tags
|
||||
html = re.sub(r'<[^>]+>', '', html)
|
||||
|
||||
# Clean whitespace
|
||||
html = re.sub(r'[ \t]+', ' ', html)
|
||||
html = re.sub(r'\n[ \t]+', '\n', html)
|
||||
html = re.sub(r'[ \t]+\n', '\n', html)
|
||||
html = re.sub(r'\n{3,}', '\n\n', html)
|
||||
|
||||
return html.strip()
|
||||
|
||||
|
||||
class JSONParser:
|
||||
"""Parse JSON files containing legal template data."""
|
||||
|
||||
@classmethod
|
||||
def parse(cls, content: str, filename: str = "") -> List[ExtractedDocument]:
|
||||
"""Parse JSON content into ExtractedDocuments."""
|
||||
try:
|
||||
data = json.loads(content)
|
||||
except json.JSONDecodeError as e:
|
||||
logger.warning(f"Failed to parse JSON from {filename}: {e}")
|
||||
return []
|
||||
|
||||
documents = []
|
||||
|
||||
if isinstance(data, dict):
|
||||
# Handle different JSON structures
|
||||
documents.extend(cls._parse_dict(data, filename))
|
||||
elif isinstance(data, list):
|
||||
for i, item in enumerate(data):
|
||||
if isinstance(item, dict):
|
||||
docs = cls._parse_dict(item, f"{filename}[{i}]")
|
||||
documents.extend(docs)
|
||||
|
||||
return documents
|
||||
|
||||
@classmethod
|
||||
def _parse_dict(cls, data: dict, filename: str) -> List[ExtractedDocument]:
|
||||
"""Parse a dictionary into documents."""
|
||||
documents = []
|
||||
|
||||
# Look for text content in common keys
|
||||
text_keys = ['text', 'content', 'body', 'description', 'value']
|
||||
title_keys = ['title', 'name', 'heading', 'label', 'key']
|
||||
|
||||
# Try to find main text content
|
||||
text = ""
|
||||
for key in text_keys:
|
||||
if key in data and isinstance(data[key], str):
|
||||
text = data[key]
|
||||
break
|
||||
|
||||
if not text:
|
||||
# Check for nested structures (like webflorist format)
|
||||
for key, value in data.items():
|
||||
if isinstance(value, dict):
|
||||
nested_docs = cls._parse_dict(value, f"{filename}.{key}")
|
||||
documents.extend(nested_docs)
|
||||
elif isinstance(value, list):
|
||||
for i, item in enumerate(value):
|
||||
if isinstance(item, dict):
|
||||
nested_docs = cls._parse_dict(item, f"{filename}.{key}[{i}]")
|
||||
documents.extend(nested_docs)
|
||||
elif isinstance(item, str) and len(item) > 50:
|
||||
# Treat long strings as content
|
||||
documents.append(ExtractedDocument(
|
||||
text=item,
|
||||
title=f"{key} {i+1}",
|
||||
file_path=filename,
|
||||
file_type="json",
|
||||
source_url="",
|
||||
language=MarkdownParser._detect_language(item),
|
||||
))
|
||||
return documents
|
||||
|
||||
# Found text content
|
||||
title = ""
|
||||
for key in title_keys:
|
||||
if key in data and isinstance(data[key], str):
|
||||
title = data[key]
|
||||
break
|
||||
|
||||
if not title:
|
||||
title = Path(filename).stem
|
||||
|
||||
# Extract metadata
|
||||
metadata = {}
|
||||
for key, value in data.items():
|
||||
if key not in text_keys + title_keys and not isinstance(value, (dict, list)):
|
||||
metadata[key] = value
|
||||
|
||||
placeholders = MarkdownParser._find_placeholders(text)
|
||||
language = data.get('lang', data.get('language', MarkdownParser._detect_language(text)))
|
||||
|
||||
documents.append(ExtractedDocument(
|
||||
text=text,
|
||||
title=title,
|
||||
file_path=filename,
|
||||
file_type="json",
|
||||
source_url="",
|
||||
placeholders=placeholders,
|
||||
language=language,
|
||||
metadata=metadata,
|
||||
))
|
||||
|
||||
return documents
|
||||
|
||||
|
||||
class GitHubCrawler:
|
||||
"""Crawl GitHub repositories for legal templates."""
|
||||
|
||||
def __init__(self, token: Optional[str] = None):
|
||||
self.token = token or GITHUB_TOKEN
|
||||
self.headers = {
|
||||
"Accept": "application/vnd.github.v3+json",
|
||||
"User-Agent": "LegalTemplatesCrawler/1.0",
|
||||
}
|
||||
if self.token:
|
||||
self.headers["Authorization"] = f"token {self.token}"
|
||||
|
||||
self.http_client: Optional[httpx.AsyncClient] = None
|
||||
|
||||
async def __aenter__(self):
|
||||
self.http_client = httpx.AsyncClient(
|
||||
timeout=REQUEST_TIMEOUT,
|
||||
headers=self.headers,
|
||||
follow_redirects=True,
|
||||
)
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
if self.http_client:
|
||||
await self.http_client.aclose()
|
||||
|
||||
def _parse_repo_url(self, url: str) -> Tuple[str, str, str]:
|
||||
"""Parse repository URL into owner, repo, and host."""
|
||||
parsed = urlparse(url)
|
||||
path_parts = parsed.path.strip('/').split('/')
|
||||
|
||||
if len(path_parts) < 2:
|
||||
raise ValueError(f"Invalid repository URL: {url}")
|
||||
|
||||
owner = path_parts[0]
|
||||
repo = path_parts[1].replace('.git', '')
|
||||
|
||||
if 'gitlab' in parsed.netloc:
|
||||
host = 'gitlab'
|
||||
else:
|
||||
host = 'github'
|
||||
|
||||
return owner, repo, host
|
||||
|
||||
async def get_default_branch(self, owner: str, repo: str) -> str:
|
||||
"""Get the default branch of a repository."""
|
||||
if not self.http_client:
|
||||
raise RuntimeError("Crawler not initialized. Use 'async with' context.")
|
||||
|
||||
url = f"{GITHUB_API_URL}/repos/{owner}/{repo}"
|
||||
response = await self.http_client.get(url)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
return data.get("default_branch", "main")
|
||||
|
||||
async def get_latest_commit(self, owner: str, repo: str, branch: str = "main") -> str:
|
||||
"""Get the latest commit SHA for a branch."""
|
||||
if not self.http_client:
|
||||
raise RuntimeError("Crawler not initialized. Use 'async with' context.")
|
||||
|
||||
url = f"{GITHUB_API_URL}/repos/{owner}/{repo}/commits/{branch}"
|
||||
response = await self.http_client.get(url)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
return data.get("sha", "")
|
||||
|
||||
async def list_files(
|
||||
self,
|
||||
owner: str,
|
||||
repo: str,
|
||||
path: str = "",
|
||||
branch: str = "main",
|
||||
patterns: List[str] = None,
|
||||
exclude_patterns: List[str] = None,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""List files in a repository matching the given patterns."""
|
||||
if not self.http_client:
|
||||
raise RuntimeError("Crawler not initialized. Use 'async with' context.")
|
||||
|
||||
patterns = patterns or ["*.md", "*.txt", "*.html"]
|
||||
exclude_patterns = exclude_patterns or []
|
||||
|
||||
url = f"{GITHUB_API_URL}/repos/{owner}/{repo}/git/trees/{branch}?recursive=1"
|
||||
response = await self.http_client.get(url)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
files = []
|
||||
for item in data.get("tree", []):
|
||||
if item["type"] != "blob":
|
||||
continue
|
||||
|
||||
file_path = item["path"]
|
||||
|
||||
# Check exclude patterns
|
||||
excluded = any(fnmatch(file_path, pattern) for pattern in exclude_patterns)
|
||||
if excluded:
|
||||
continue
|
||||
|
||||
# Check include patterns
|
||||
matched = any(fnmatch(file_path, pattern) for pattern in patterns)
|
||||
if not matched:
|
||||
continue
|
||||
|
||||
# Skip large files
|
||||
if item.get("size", 0) > MAX_FILE_SIZE:
|
||||
logger.warning(f"Skipping large file: {file_path} ({item['size']} bytes)")
|
||||
continue
|
||||
|
||||
files.append({
|
||||
"path": file_path,
|
||||
"sha": item["sha"],
|
||||
"size": item.get("size", 0),
|
||||
"url": item.get("url", ""),
|
||||
})
|
||||
|
||||
return files
|
||||
|
||||
async def get_file_content(self, owner: str, repo: str, path: str, branch: str = "main") -> str:
|
||||
"""Get the content of a file from a repository."""
|
||||
if not self.http_client:
|
||||
raise RuntimeError("Crawler not initialized. Use 'async with' context.")
|
||||
|
||||
# Use raw content URL for simplicity
|
||||
url = f"https://raw.githubusercontent.com/{owner}/{repo}/{branch}/{path}"
|
||||
response = await self.http_client.get(url)
|
||||
response.raise_for_status()
|
||||
return response.text
|
||||
|
||||
async def crawl_repository(
|
||||
self,
|
||||
source: SourceConfig,
|
||||
) -> AsyncGenerator[ExtractedDocument, None]:
|
||||
"""Crawl a repository and yield extracted documents."""
|
||||
if not source.repo_url:
|
||||
logger.warning(f"No repo URL for source: {source.name}")
|
||||
return
|
||||
|
||||
try:
|
||||
owner, repo, host = self._parse_repo_url(source.repo_url)
|
||||
except ValueError as e:
|
||||
logger.error(f"Failed to parse repo URL for {source.name}: {e}")
|
||||
return
|
||||
|
||||
if host == "gitlab":
|
||||
logger.info(f"GitLab repos not yet supported: {source.name}")
|
||||
return
|
||||
|
||||
logger.info(f"Crawling repository: {owner}/{repo}")
|
||||
|
||||
try:
|
||||
# Get default branch and latest commit
|
||||
branch = await self.get_default_branch(owner, repo)
|
||||
commit_sha = await self.get_latest_commit(owner, repo, branch)
|
||||
|
||||
await asyncio.sleep(RATE_LIMIT_DELAY)
|
||||
|
||||
# List files matching patterns
|
||||
files = await self.list_files(
|
||||
owner, repo,
|
||||
branch=branch,
|
||||
patterns=source.file_patterns,
|
||||
exclude_patterns=source.exclude_patterns,
|
||||
)
|
||||
|
||||
logger.info(f"Found {len(files)} matching files in {source.name}")
|
||||
|
||||
for file_info in files:
|
||||
await asyncio.sleep(RATE_LIMIT_DELAY)
|
||||
|
||||
try:
|
||||
content = await self.get_file_content(
|
||||
owner, repo, file_info["path"], branch
|
||||
)
|
||||
|
||||
# Parse based on file type
|
||||
file_path = file_info["path"]
|
||||
source_url = f"https://github.com/{owner}/{repo}/blob/{branch}/{file_path}"
|
||||
|
||||
if file_path.endswith('.md'):
|
||||
doc = MarkdownParser.parse(content, file_path)
|
||||
doc.source_url = source_url
|
||||
doc.source_commit = commit_sha
|
||||
yield doc
|
||||
|
||||
elif file_path.endswith('.html') or file_path.endswith('.htm'):
|
||||
doc = HTMLParser.parse(content, file_path)
|
||||
doc.source_url = source_url
|
||||
doc.source_commit = commit_sha
|
||||
yield doc
|
||||
|
||||
elif file_path.endswith('.json'):
|
||||
docs = JSONParser.parse(content, file_path)
|
||||
for doc in docs:
|
||||
doc.source_url = source_url
|
||||
doc.source_commit = commit_sha
|
||||
yield doc
|
||||
|
||||
elif file_path.endswith('.txt'):
|
||||
# Plain text file
|
||||
yield ExtractedDocument(
|
||||
text=content,
|
||||
title=Path(file_path).stem,
|
||||
file_path=file_path,
|
||||
file_type="text",
|
||||
source_url=source_url,
|
||||
source_commit=commit_sha,
|
||||
language=MarkdownParser._detect_language(content),
|
||||
placeholders=MarkdownParser._find_placeholders(content),
|
||||
)
|
||||
|
||||
except httpx.HTTPError as e:
|
||||
logger.warning(f"Failed to fetch {file_path}: {e}")
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing {file_path}: {e}")
|
||||
continue
|
||||
|
||||
except httpx.HTTPError as e:
|
||||
logger.error(f"HTTP error crawling {source.name}: {e}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error crawling {source.name}: {e}")
|
||||
|
||||
|
||||
class RepositoryDownloader:
|
||||
"""Download and extract repository archives."""
|
||||
|
||||
def __init__(self):
|
||||
self.http_client: Optional[httpx.AsyncClient] = None
|
||||
|
||||
async def __aenter__(self):
|
||||
self.http_client = httpx.AsyncClient(
|
||||
timeout=120.0,
|
||||
follow_redirects=True,
|
||||
)
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
if self.http_client:
|
||||
await self.http_client.aclose()
|
||||
|
||||
async def download_zip(self, repo_url: str, branch: str = "main") -> Path:
|
||||
"""Download repository as ZIP and extract to temp directory."""
|
||||
if not self.http_client:
|
||||
raise RuntimeError("Downloader not initialized. Use 'async with' context.")
|
||||
|
||||
parsed = urlparse(repo_url)
|
||||
path_parts = parsed.path.strip('/').split('/')
|
||||
owner = path_parts[0]
|
||||
repo = path_parts[1].replace('.git', '')
|
||||
|
||||
zip_url = f"https://github.com/{owner}/{repo}/archive/refs/heads/{branch}.zip"
|
||||
|
||||
logger.info(f"Downloading ZIP from {zip_url}")
|
||||
|
||||
response = await self.http_client.get(zip_url)
|
||||
response.raise_for_status()
|
||||
|
||||
# Save to temp file
|
||||
temp_dir = Path(tempfile.mkdtemp())
|
||||
zip_path = temp_dir / f"{repo}.zip"
|
||||
|
||||
with open(zip_path, 'wb') as f:
|
||||
f.write(response.content)
|
||||
|
||||
# Extract ZIP
|
||||
extract_dir = temp_dir / repo
|
||||
with zipfile.ZipFile(zip_path, 'r') as zip_ref:
|
||||
zip_ref.extractall(temp_dir)
|
||||
|
||||
# The extracted directory is usually named repo-branch
|
||||
extracted_dirs = list(temp_dir.glob(f"{repo}-*"))
|
||||
if extracted_dirs:
|
||||
return extracted_dirs[0]
|
||||
|
||||
return extract_dir
|
||||
|
||||
async def crawl_local_directory(
|
||||
self,
|
||||
directory: Path,
|
||||
source: SourceConfig,
|
||||
base_url: str,
|
||||
) -> AsyncGenerator[ExtractedDocument, None]:
|
||||
"""Crawl a local directory for documents."""
|
||||
patterns = source.file_patterns or ["*.md", "*.txt", "*.html"]
|
||||
exclude_patterns = source.exclude_patterns or []
|
||||
|
||||
for pattern in patterns:
|
||||
for file_path in directory.rglob(pattern.replace("**/", "")):
|
||||
if not file_path.is_file():
|
||||
continue
|
||||
|
||||
rel_path = str(file_path.relative_to(directory))
|
||||
|
||||
# Check exclude patterns
|
||||
excluded = any(fnmatch(rel_path, ep) for ep in exclude_patterns)
|
||||
if excluded:
|
||||
continue
|
||||
|
||||
# Skip large files
|
||||
if file_path.stat().st_size > MAX_FILE_SIZE:
|
||||
continue
|
||||
|
||||
try:
|
||||
content = file_path.read_text(encoding='utf-8')
|
||||
except UnicodeDecodeError:
|
||||
try:
|
||||
content = file_path.read_text(encoding='latin-1')
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
source_url = f"{base_url}/{rel_path}"
|
||||
|
||||
if file_path.suffix == '.md':
|
||||
doc = MarkdownParser.parse(content, rel_path)
|
||||
doc.source_url = source_url
|
||||
yield doc
|
||||
|
||||
elif file_path.suffix in ['.html', '.htm']:
|
||||
doc = HTMLParser.parse(content, rel_path)
|
||||
doc.source_url = source_url
|
||||
yield doc
|
||||
|
||||
elif file_path.suffix == '.json':
|
||||
docs = JSONParser.parse(content, rel_path)
|
||||
for doc in docs:
|
||||
doc.source_url = source_url
|
||||
yield doc
|
||||
|
||||
elif file_path.suffix == '.txt':
|
||||
yield ExtractedDocument(
|
||||
text=content,
|
||||
title=file_path.stem,
|
||||
file_path=rel_path,
|
||||
file_type="text",
|
||||
source_url=source_url,
|
||||
language=MarkdownParser._detect_language(content),
|
||||
placeholders=MarkdownParser._find_placeholders(content),
|
||||
)
|
||||
|
||||
def cleanup(self, directory: Path):
|
||||
"""Clean up temporary directory."""
|
||||
if directory.exists():
|
||||
shutil.rmtree(directory, ignore_errors=True)
|
||||
|
||||
|
||||
async def crawl_source(source: SourceConfig) -> List[ExtractedDocument]:
|
||||
"""Crawl a source configuration and return all extracted documents."""
|
||||
documents = []
|
||||
|
||||
if source.repo_url:
|
||||
async with GitHubCrawler() as crawler:
|
||||
async for doc in crawler.crawl_repository(source):
|
||||
documents.append(doc)
|
||||
|
||||
return documents
|
||||
|
||||
|
||||
# CLI for testing
|
||||
async def main():
|
||||
"""Test crawler with a sample source."""
|
||||
from template_sources import TEMPLATE_SOURCES
|
||||
|
||||
# Test with github-site-policy
|
||||
source = next(s for s in TEMPLATE_SOURCES if s.name == "github-site-policy")
|
||||
|
||||
async with GitHubCrawler() as crawler:
|
||||
count = 0
|
||||
async for doc in crawler.crawl_repository(source):
|
||||
count += 1
|
||||
print(f"\n{'='*60}")
|
||||
print(f"Title: {doc.title}")
|
||||
print(f"Path: {doc.file_path}")
|
||||
print(f"Type: {doc.file_type}")
|
||||
print(f"Language: {doc.language}")
|
||||
print(f"URL: {doc.source_url}")
|
||||
print(f"Placeholders: {doc.placeholders[:5] if doc.placeholders else 'None'}")
|
||||
print(f"Text preview: {doc.text[:200]}...")
|
||||
|
||||
print(f"\n\nTotal documents: {count}")
|
||||
|
||||
# Parsers
|
||||
from github_crawler_parsers import ( # noqa: F401
|
||||
ExtractedDocument,
|
||||
MarkdownParser,
|
||||
HTMLParser,
|
||||
JSONParser,
|
||||
)
|
||||
|
||||
# Crawler and downloader
|
||||
from github_crawler_core import ( # noqa: F401
|
||||
GITHUB_API_URL,
|
||||
GITLAB_API_URL,
|
||||
GITHUB_TOKEN,
|
||||
MAX_FILE_SIZE,
|
||||
REQUEST_TIMEOUT,
|
||||
RATE_LIMIT_DELAY,
|
||||
GitHubCrawler,
|
||||
RepositoryDownloader,
|
||||
crawl_source,
|
||||
main,
|
||||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
asyncio.run(main())
|
||||
|
||||
@@ -0,0 +1,411 @@
|
||||
"""
|
||||
GitHub Crawler - Core Crawler and Downloader
|
||||
|
||||
GitHubCrawler for API-based repository crawling and RepositoryDownloader
|
||||
for ZIP-based local extraction.
|
||||
|
||||
Extracted from github_crawler.py to keep files under 500 LOC.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
import zipfile
|
||||
from fnmatch import fnmatch
|
||||
from pathlib import Path
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import httpx
|
||||
|
||||
from template_sources import SourceConfig
|
||||
from github_crawler_parsers import (
|
||||
ExtractedDocument,
|
||||
MarkdownParser,
|
||||
HTMLParser,
|
||||
JSONParser,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Configuration
|
||||
GITHUB_API_URL = "https://api.github.com"
|
||||
GITLAB_API_URL = "https://gitlab.com/api/v4"
|
||||
GITHUB_TOKEN = os.getenv("GITHUB_TOKEN", "")
|
||||
MAX_FILE_SIZE = 1024 * 1024 # 1 MB max file size
|
||||
REQUEST_TIMEOUT = 60.0
|
||||
RATE_LIMIT_DELAY = 1.0
|
||||
|
||||
|
||||
class GitHubCrawler:
|
||||
"""Crawl GitHub repositories for legal templates."""
|
||||
|
||||
def __init__(self, token: Optional[str] = None):
|
||||
self.token = token or GITHUB_TOKEN
|
||||
self.headers = {
|
||||
"Accept": "application/vnd.github.v3+json",
|
||||
"User-Agent": "LegalTemplatesCrawler/1.0",
|
||||
}
|
||||
if self.token:
|
||||
self.headers["Authorization"] = f"token {self.token}"
|
||||
|
||||
self.http_client: Optional[httpx.AsyncClient] = None
|
||||
|
||||
async def __aenter__(self):
|
||||
self.http_client = httpx.AsyncClient(
|
||||
timeout=REQUEST_TIMEOUT,
|
||||
headers=self.headers,
|
||||
follow_redirects=True,
|
||||
)
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
if self.http_client:
|
||||
await self.http_client.aclose()
|
||||
|
||||
def _parse_repo_url(self, url: str) -> Tuple[str, str, str]:
|
||||
"""Parse repository URL into owner, repo, and host."""
|
||||
parsed = urlparse(url)
|
||||
path_parts = parsed.path.strip('/').split('/')
|
||||
|
||||
if len(path_parts) < 2:
|
||||
raise ValueError(f"Invalid repository URL: {url}")
|
||||
|
||||
owner = path_parts[0]
|
||||
repo = path_parts[1].replace('.git', '')
|
||||
|
||||
if 'gitlab' in parsed.netloc:
|
||||
host = 'gitlab'
|
||||
else:
|
||||
host = 'github'
|
||||
|
||||
return owner, repo, host
|
||||
|
||||
async def get_default_branch(self, owner: str, repo: str) -> str:
|
||||
"""Get the default branch of a repository."""
|
||||
if not self.http_client:
|
||||
raise RuntimeError("Crawler not initialized. Use 'async with' context.")
|
||||
|
||||
url = f"{GITHUB_API_URL}/repos/{owner}/{repo}"
|
||||
response = await self.http_client.get(url)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
return data.get("default_branch", "main")
|
||||
|
||||
async def get_latest_commit(self, owner: str, repo: str, branch: str = "main") -> str:
|
||||
"""Get the latest commit SHA for a branch."""
|
||||
if not self.http_client:
|
||||
raise RuntimeError("Crawler not initialized. Use 'async with' context.")
|
||||
|
||||
url = f"{GITHUB_API_URL}/repos/{owner}/{repo}/commits/{branch}"
|
||||
response = await self.http_client.get(url)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
return data.get("sha", "")
|
||||
|
||||
async def list_files(
|
||||
self,
|
||||
owner: str,
|
||||
repo: str,
|
||||
path: str = "",
|
||||
branch: str = "main",
|
||||
patterns: List[str] = None,
|
||||
exclude_patterns: List[str] = None,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""List files in a repository matching the given patterns."""
|
||||
if not self.http_client:
|
||||
raise RuntimeError("Crawler not initialized. Use 'async with' context.")
|
||||
|
||||
patterns = patterns or ["*.md", "*.txt", "*.html"]
|
||||
exclude_patterns = exclude_patterns or []
|
||||
|
||||
url = f"{GITHUB_API_URL}/repos/{owner}/{repo}/git/trees/{branch}?recursive=1"
|
||||
response = await self.http_client.get(url)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
files = []
|
||||
for item in data.get("tree", []):
|
||||
if item["type"] != "blob":
|
||||
continue
|
||||
|
||||
file_path = item["path"]
|
||||
|
||||
excluded = any(fnmatch(file_path, pattern) for pattern in exclude_patterns)
|
||||
if excluded:
|
||||
continue
|
||||
|
||||
matched = any(fnmatch(file_path, pattern) for pattern in patterns)
|
||||
if not matched:
|
||||
continue
|
||||
|
||||
if item.get("size", 0) > MAX_FILE_SIZE:
|
||||
logger.warning(f"Skipping large file: {file_path} ({item['size']} bytes)")
|
||||
continue
|
||||
|
||||
files.append({
|
||||
"path": file_path,
|
||||
"sha": item["sha"],
|
||||
"size": item.get("size", 0),
|
||||
"url": item.get("url", ""),
|
||||
})
|
||||
|
||||
return files
|
||||
|
||||
async def get_file_content(self, owner: str, repo: str, path: str, branch: str = "main") -> str:
|
||||
"""Get the content of a file from a repository."""
|
||||
if not self.http_client:
|
||||
raise RuntimeError("Crawler not initialized. Use 'async with' context.")
|
||||
|
||||
url = f"https://raw.githubusercontent.com/{owner}/{repo}/{branch}/{path}"
|
||||
response = await self.http_client.get(url)
|
||||
response.raise_for_status()
|
||||
return response.text
|
||||
|
||||
async def crawl_repository(
|
||||
self,
|
||||
source: SourceConfig,
|
||||
) -> AsyncGenerator[ExtractedDocument, None]:
|
||||
"""Crawl a repository and yield extracted documents."""
|
||||
if not source.repo_url:
|
||||
logger.warning(f"No repo URL for source: {source.name}")
|
||||
return
|
||||
|
||||
try:
|
||||
owner, repo, host = self._parse_repo_url(source.repo_url)
|
||||
except ValueError as e:
|
||||
logger.error(f"Failed to parse repo URL for {source.name}: {e}")
|
||||
return
|
||||
|
||||
if host == "gitlab":
|
||||
logger.info(f"GitLab repos not yet supported: {source.name}")
|
||||
return
|
||||
|
||||
logger.info(f"Crawling repository: {owner}/{repo}")
|
||||
|
||||
try:
|
||||
branch = await self.get_default_branch(owner, repo)
|
||||
commit_sha = await self.get_latest_commit(owner, repo, branch)
|
||||
|
||||
await asyncio.sleep(RATE_LIMIT_DELAY)
|
||||
|
||||
files = await self.list_files(
|
||||
owner, repo,
|
||||
branch=branch,
|
||||
patterns=source.file_patterns,
|
||||
exclude_patterns=source.exclude_patterns,
|
||||
)
|
||||
|
||||
logger.info(f"Found {len(files)} matching files in {source.name}")
|
||||
|
||||
for file_info in files:
|
||||
await asyncio.sleep(RATE_LIMIT_DELAY)
|
||||
|
||||
try:
|
||||
content = await self.get_file_content(
|
||||
owner, repo, file_info["path"], branch
|
||||
)
|
||||
|
||||
file_path = file_info["path"]
|
||||
source_url = f"https://github.com/{owner}/{repo}/blob/{branch}/{file_path}"
|
||||
|
||||
if file_path.endswith('.md'):
|
||||
doc = MarkdownParser.parse(content, file_path)
|
||||
doc.source_url = source_url
|
||||
doc.source_commit = commit_sha
|
||||
yield doc
|
||||
|
||||
elif file_path.endswith('.html') or file_path.endswith('.htm'):
|
||||
doc = HTMLParser.parse(content, file_path)
|
||||
doc.source_url = source_url
|
||||
doc.source_commit = commit_sha
|
||||
yield doc
|
||||
|
||||
elif file_path.endswith('.json'):
|
||||
docs = JSONParser.parse(content, file_path)
|
||||
for doc in docs:
|
||||
doc.source_url = source_url
|
||||
doc.source_commit = commit_sha
|
||||
yield doc
|
||||
|
||||
elif file_path.endswith('.txt'):
|
||||
yield ExtractedDocument(
|
||||
text=content,
|
||||
title=Path(file_path).stem,
|
||||
file_path=file_path,
|
||||
file_type="text",
|
||||
source_url=source_url,
|
||||
source_commit=commit_sha,
|
||||
language=MarkdownParser._detect_language(content),
|
||||
placeholders=MarkdownParser._find_placeholders(content),
|
||||
)
|
||||
|
||||
except httpx.HTTPError as e:
|
||||
logger.warning(f"Failed to fetch {file_path}: {e}")
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing {file_path}: {e}")
|
||||
continue
|
||||
|
||||
except httpx.HTTPError as e:
|
||||
logger.error(f"HTTP error crawling {source.name}: {e}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error crawling {source.name}: {e}")
|
||||
|
||||
|
||||
class RepositoryDownloader:
|
||||
"""Download and extract repository archives."""
|
||||
|
||||
def __init__(self):
|
||||
self.http_client: Optional[httpx.AsyncClient] = None
|
||||
|
||||
async def __aenter__(self):
|
||||
self.http_client = httpx.AsyncClient(
|
||||
timeout=120.0,
|
||||
follow_redirects=True,
|
||||
)
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
if self.http_client:
|
||||
await self.http_client.aclose()
|
||||
|
||||
async def download_zip(self, repo_url: str, branch: str = "main") -> Path:
|
||||
"""Download repository as ZIP and extract to temp directory."""
|
||||
if not self.http_client:
|
||||
raise RuntimeError("Downloader not initialized. Use 'async with' context.")
|
||||
|
||||
parsed = urlparse(repo_url)
|
||||
path_parts = parsed.path.strip('/').split('/')
|
||||
owner = path_parts[0]
|
||||
repo = path_parts[1].replace('.git', '')
|
||||
|
||||
zip_url = f"https://github.com/{owner}/{repo}/archive/refs/heads/{branch}.zip"
|
||||
|
||||
logger.info(f"Downloading ZIP from {zip_url}")
|
||||
|
||||
response = await self.http_client.get(zip_url)
|
||||
response.raise_for_status()
|
||||
|
||||
temp_dir = Path(tempfile.mkdtemp())
|
||||
zip_path = temp_dir / f"{repo}.zip"
|
||||
|
||||
with open(zip_path, 'wb') as f:
|
||||
f.write(response.content)
|
||||
|
||||
extract_dir = temp_dir / repo
|
||||
with zipfile.ZipFile(zip_path, 'r') as zip_ref:
|
||||
zip_ref.extractall(temp_dir)
|
||||
|
||||
extracted_dirs = list(temp_dir.glob(f"{repo}-*"))
|
||||
if extracted_dirs:
|
||||
return extracted_dirs[0]
|
||||
|
||||
return extract_dir
|
||||
|
||||
async def crawl_local_directory(
|
||||
self,
|
||||
directory: Path,
|
||||
source: SourceConfig,
|
||||
base_url: str,
|
||||
) -> AsyncGenerator[ExtractedDocument, None]:
|
||||
"""Crawl a local directory for documents."""
|
||||
patterns = source.file_patterns or ["*.md", "*.txt", "*.html"]
|
||||
exclude_patterns = source.exclude_patterns or []
|
||||
|
||||
for pattern in patterns:
|
||||
for file_path in directory.rglob(pattern.replace("**/", "")):
|
||||
if not file_path.is_file():
|
||||
continue
|
||||
|
||||
rel_path = str(file_path.relative_to(directory))
|
||||
|
||||
excluded = any(fnmatch(rel_path, ep) for ep in exclude_patterns)
|
||||
if excluded:
|
||||
continue
|
||||
|
||||
if file_path.stat().st_size > MAX_FILE_SIZE:
|
||||
continue
|
||||
|
||||
try:
|
||||
content = file_path.read_text(encoding='utf-8')
|
||||
except UnicodeDecodeError:
|
||||
try:
|
||||
content = file_path.read_text(encoding='latin-1')
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
source_url = f"{base_url}/{rel_path}"
|
||||
|
||||
if file_path.suffix == '.md':
|
||||
doc = MarkdownParser.parse(content, rel_path)
|
||||
doc.source_url = source_url
|
||||
yield doc
|
||||
|
||||
elif file_path.suffix in ['.html', '.htm']:
|
||||
doc = HTMLParser.parse(content, rel_path)
|
||||
doc.source_url = source_url
|
||||
yield doc
|
||||
|
||||
elif file_path.suffix == '.json':
|
||||
docs = JSONParser.parse(content, rel_path)
|
||||
for doc in docs:
|
||||
doc.source_url = source_url
|
||||
yield doc
|
||||
|
||||
elif file_path.suffix == '.txt':
|
||||
yield ExtractedDocument(
|
||||
text=content,
|
||||
title=file_path.stem,
|
||||
file_path=rel_path,
|
||||
file_type="text",
|
||||
source_url=source_url,
|
||||
language=MarkdownParser._detect_language(content),
|
||||
placeholders=MarkdownParser._find_placeholders(content),
|
||||
)
|
||||
|
||||
def cleanup(self, directory: Path):
|
||||
"""Clean up temporary directory."""
|
||||
if directory.exists():
|
||||
shutil.rmtree(directory, ignore_errors=True)
|
||||
|
||||
|
||||
async def crawl_source(source: SourceConfig) -> List[ExtractedDocument]:
|
||||
"""Crawl a source configuration and return all extracted documents."""
|
||||
documents = []
|
||||
|
||||
if source.repo_url:
|
||||
async with GitHubCrawler() as crawler:
|
||||
async for doc in crawler.crawl_repository(source):
|
||||
documents.append(doc)
|
||||
|
||||
return documents
|
||||
|
||||
|
||||
# CLI for testing
|
||||
async def main():
|
||||
"""Test crawler with a sample source."""
|
||||
from template_sources import TEMPLATE_SOURCES
|
||||
|
||||
source = next(s for s in TEMPLATE_SOURCES if s.name == "github-site-policy")
|
||||
|
||||
async with GitHubCrawler() as crawler:
|
||||
count = 0
|
||||
async for doc in crawler.crawl_repository(source):
|
||||
count += 1
|
||||
print(f"\n{'='*60}")
|
||||
print(f"Title: {doc.title}")
|
||||
print(f"Path: {doc.file_path}")
|
||||
print(f"Type: {doc.file_type}")
|
||||
print(f"Language: {doc.language}")
|
||||
print(f"URL: {doc.source_url}")
|
||||
print(f"Placeholders: {doc.placeholders[:5] if doc.placeholders else 'None'}")
|
||||
print(f"Text preview: {doc.text[:200]}...")
|
||||
|
||||
print(f"\n\nTotal documents: {count}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
@@ -0,0 +1,303 @@
|
||||
"""
|
||||
GitHub Crawler - Document Parsers
|
||||
|
||||
Markdown, HTML, and JSON parsers for extracting structured content
|
||||
from legal template documents.
|
||||
|
||||
Extracted from github_crawler.py to keep files under 500 LOC.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExtractedDocument:
|
||||
"""A document extracted from a repository."""
|
||||
text: str
|
||||
title: str
|
||||
file_path: str
|
||||
file_type: str # "markdown", "html", "json", "text"
|
||||
source_url: str
|
||||
source_commit: Optional[str] = None
|
||||
source_hash: str = "" # SHA256 of original content
|
||||
sections: List[Dict[str, Any]] = field(default_factory=list)
|
||||
placeholders: List[str] = field(default_factory=list)
|
||||
language: str = "en"
|
||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def __post_init__(self):
|
||||
if not self.source_hash and self.text:
|
||||
self.source_hash = hashlib.sha256(self.text.encode()).hexdigest()
|
||||
|
||||
|
||||
class MarkdownParser:
|
||||
"""Parse Markdown files into structured content."""
|
||||
|
||||
# Common placeholder patterns
|
||||
PLACEHOLDER_PATTERNS = [
|
||||
r'\[([A-Z_]+)\]', # [COMPANY_NAME]
|
||||
r'\{([a-z_]+)\}', # {company_name}
|
||||
r'\{\{([a-z_]+)\}\}', # {{company_name}}
|
||||
r'__([A-Z_]+)__', # __COMPANY_NAME__
|
||||
r'<([A-Z_]+)>', # <COMPANY_NAME>
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def parse(cls, content: str, filename: str = "") -> ExtractedDocument:
|
||||
"""Parse markdown content into an ExtractedDocument."""
|
||||
title = cls._extract_title(content, filename)
|
||||
sections = cls._extract_sections(content)
|
||||
placeholders = cls._find_placeholders(content)
|
||||
language = cls._detect_language(content)
|
||||
clean_text = cls._clean_for_indexing(content)
|
||||
|
||||
return ExtractedDocument(
|
||||
text=clean_text,
|
||||
title=title,
|
||||
file_path=filename,
|
||||
file_type="markdown",
|
||||
source_url="",
|
||||
sections=sections,
|
||||
placeholders=placeholders,
|
||||
language=language,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _extract_title(cls, content: str, filename: str) -> str:
|
||||
"""Extract title from markdown heading or filename."""
|
||||
h1_match = re.search(r'^#\s+(.+)$', content, re.MULTILINE)
|
||||
if h1_match:
|
||||
return h1_match.group(1).strip()
|
||||
|
||||
frontmatter_match = re.search(
|
||||
r'^---\s*\n.*?title:\s*["\']?(.+?)["\']?\s*\n.*?---',
|
||||
content, re.DOTALL
|
||||
)
|
||||
if frontmatter_match:
|
||||
return frontmatter_match.group(1).strip()
|
||||
|
||||
if filename:
|
||||
name = Path(filename).stem
|
||||
return name.replace('-', ' ').replace('_', ' ').title()
|
||||
|
||||
return "Untitled"
|
||||
|
||||
@classmethod
|
||||
def _extract_sections(cls, content: str) -> List[Dict[str, Any]]:
|
||||
"""Extract sections from markdown content."""
|
||||
sections = []
|
||||
current_section = {"heading": "", "level": 0, "content": "", "start": 0}
|
||||
|
||||
for match in re.finditer(r'^(#{1,6})\s+(.+)$', content, re.MULTILINE):
|
||||
if current_section["heading"] or current_section["content"].strip():
|
||||
current_section["content"] = current_section["content"].strip()
|
||||
sections.append(current_section.copy())
|
||||
|
||||
level = len(match.group(1))
|
||||
heading = match.group(2).strip()
|
||||
current_section = {
|
||||
"heading": heading,
|
||||
"level": level,
|
||||
"content": "",
|
||||
"start": match.end(),
|
||||
}
|
||||
|
||||
if current_section["heading"] or current_section["content"].strip():
|
||||
current_section["content"] = content[current_section["start"]:].strip()
|
||||
sections.append(current_section)
|
||||
|
||||
return sections
|
||||
|
||||
@classmethod
|
||||
def _find_placeholders(cls, content: str) -> List[str]:
|
||||
"""Find placeholder patterns in content."""
|
||||
placeholders = set()
|
||||
for pattern in cls.PLACEHOLDER_PATTERNS:
|
||||
for match in re.finditer(pattern, content):
|
||||
placeholder = match.group(0)
|
||||
placeholders.add(placeholder)
|
||||
return sorted(list(placeholders))
|
||||
|
||||
@classmethod
|
||||
def _detect_language(cls, content: str) -> str:
|
||||
"""Detect language from content."""
|
||||
german_indicators = [
|
||||
'Datenschutz', 'Impressum', 'Nutzungsbedingungen', 'Haftung',
|
||||
'Widerruf', 'Verantwortlicher', 'personenbezogene', 'Verarbeitung',
|
||||
'und', 'der', 'die', 'das', 'ist', 'wird', 'werden', 'sind',
|
||||
]
|
||||
|
||||
lower_content = content.lower()
|
||||
german_count = sum(1 for word in german_indicators if word.lower() in lower_content)
|
||||
|
||||
if german_count >= 3:
|
||||
return "de"
|
||||
return "en"
|
||||
|
||||
@classmethod
|
||||
def _clean_for_indexing(cls, content: str) -> str:
|
||||
"""Clean markdown content for text indexing."""
|
||||
content = re.sub(r'^---\s*\n.*?---\s*\n', '', content, flags=re.DOTALL)
|
||||
content = re.sub(r'<!--.*?-->', '', content, flags=re.DOTALL)
|
||||
content = re.sub(r'<[^>]+>', '', content)
|
||||
content = re.sub(r'\*\*(.+?)\*\*', r'\1', content)
|
||||
content = re.sub(r'\*(.+?)\*', r'\1', content)
|
||||
content = re.sub(r'`(.+?)`', r'\1', content)
|
||||
content = re.sub(r'~~(.+?)~~', r'\1', content)
|
||||
content = re.sub(r'\[([^\]]+)\]\([^)]+\)', r'\1', content)
|
||||
content = re.sub(r'!\[([^\]]*)\]\([^)]+\)', r'\1', content)
|
||||
content = re.sub(r'\n{3,}', '\n\n', content)
|
||||
content = re.sub(r' +', ' ', content)
|
||||
|
||||
return content.strip()
|
||||
|
||||
|
||||
class HTMLParser:
|
||||
"""Parse HTML files into structured content."""
|
||||
|
||||
@classmethod
|
||||
def parse(cls, content: str, filename: str = "") -> ExtractedDocument:
|
||||
"""Parse HTML content into an ExtractedDocument."""
|
||||
title_match = re.search(r'<title>(.+?)</title>', content, re.IGNORECASE)
|
||||
title = title_match.group(1) if title_match else Path(filename).stem
|
||||
|
||||
text = cls._html_to_text(content)
|
||||
placeholders = MarkdownParser._find_placeholders(text)
|
||||
|
||||
lang_match = re.search(r'<html[^>]*lang=["\']([a-z]{2})["\']', content, re.IGNORECASE)
|
||||
language = lang_match.group(1) if lang_match else MarkdownParser._detect_language(text)
|
||||
|
||||
return ExtractedDocument(
|
||||
text=text,
|
||||
title=title,
|
||||
file_path=filename,
|
||||
file_type="html",
|
||||
source_url="",
|
||||
placeholders=placeholders,
|
||||
language=language,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _html_to_text(cls, html: str) -> str:
|
||||
"""Convert HTML to clean text."""
|
||||
html = re.sub(r'<script[^>]*>.*?</script>', '', html, flags=re.DOTALL | re.IGNORECASE)
|
||||
html = re.sub(r'<style[^>]*>.*?</style>', '', html, flags=re.DOTALL | re.IGNORECASE)
|
||||
html = re.sub(r'<!--.*?-->', '', html, flags=re.DOTALL)
|
||||
|
||||
html = html.replace(' ', ' ')
|
||||
html = html.replace('&', '&')
|
||||
html = html.replace('<', '<')
|
||||
html = html.replace('>', '>')
|
||||
html = html.replace('"', '"')
|
||||
html = html.replace(''', "'")
|
||||
|
||||
html = re.sub(r'<br\s*/?>', '\n', html, flags=re.IGNORECASE)
|
||||
html = re.sub(r'</p>', '\n\n', html, flags=re.IGNORECASE)
|
||||
html = re.sub(r'</div>', '\n', html, flags=re.IGNORECASE)
|
||||
html = re.sub(r'</h[1-6]>', '\n\n', html, flags=re.IGNORECASE)
|
||||
html = re.sub(r'</li>', '\n', html, flags=re.IGNORECASE)
|
||||
|
||||
html = re.sub(r'<[^>]+>', '', html)
|
||||
|
||||
html = re.sub(r'[ \t]+', ' ', html)
|
||||
html = re.sub(r'\n[ \t]+', '\n', html)
|
||||
html = re.sub(r'[ \t]+\n', '\n', html)
|
||||
html = re.sub(r'\n{3,}', '\n\n', html)
|
||||
|
||||
return html.strip()
|
||||
|
||||
|
||||
class JSONParser:
|
||||
"""Parse JSON files containing legal template data."""
|
||||
|
||||
@classmethod
|
||||
def parse(cls, content: str, filename: str = "") -> List[ExtractedDocument]:
|
||||
"""Parse JSON content into ExtractedDocuments."""
|
||||
try:
|
||||
data = json.loads(content)
|
||||
except json.JSONDecodeError as e:
|
||||
import logging
|
||||
logging.getLogger(__name__).warning(f"Failed to parse JSON from {filename}: {e}")
|
||||
return []
|
||||
|
||||
documents = []
|
||||
|
||||
if isinstance(data, dict):
|
||||
documents.extend(cls._parse_dict(data, filename))
|
||||
elif isinstance(data, list):
|
||||
for i, item in enumerate(data):
|
||||
if isinstance(item, dict):
|
||||
docs = cls._parse_dict(item, f"{filename}[{i}]")
|
||||
documents.extend(docs)
|
||||
|
||||
return documents
|
||||
|
||||
@classmethod
|
||||
def _parse_dict(cls, data: dict, filename: str) -> List[ExtractedDocument]:
|
||||
"""Parse a dictionary into documents."""
|
||||
documents = []
|
||||
|
||||
text_keys = ['text', 'content', 'body', 'description', 'value']
|
||||
title_keys = ['title', 'name', 'heading', 'label', 'key']
|
||||
|
||||
text = ""
|
||||
for key in text_keys:
|
||||
if key in data and isinstance(data[key], str):
|
||||
text = data[key]
|
||||
break
|
||||
|
||||
if not text:
|
||||
for key, value in data.items():
|
||||
if isinstance(value, dict):
|
||||
nested_docs = cls._parse_dict(value, f"{filename}.{key}")
|
||||
documents.extend(nested_docs)
|
||||
elif isinstance(value, list):
|
||||
for i, item in enumerate(value):
|
||||
if isinstance(item, dict):
|
||||
nested_docs = cls._parse_dict(item, f"{filename}.{key}[{i}]")
|
||||
documents.extend(nested_docs)
|
||||
elif isinstance(item, str) and len(item) > 50:
|
||||
documents.append(ExtractedDocument(
|
||||
text=item,
|
||||
title=f"{key} {i+1}",
|
||||
file_path=filename,
|
||||
file_type="json",
|
||||
source_url="",
|
||||
language=MarkdownParser._detect_language(item),
|
||||
))
|
||||
return documents
|
||||
|
||||
title = ""
|
||||
for key in title_keys:
|
||||
if key in data and isinstance(data[key], str):
|
||||
title = data[key]
|
||||
break
|
||||
|
||||
if not title:
|
||||
title = Path(filename).stem
|
||||
|
||||
metadata = {}
|
||||
for key, value in data.items():
|
||||
if key not in text_keys + title_keys and not isinstance(value, (dict, list)):
|
||||
metadata[key] = value
|
||||
|
||||
placeholders = MarkdownParser._find_placeholders(text)
|
||||
language = data.get('lang', data.get('language', MarkdownParser._detect_language(text)))
|
||||
|
||||
documents.append(ExtractedDocument(
|
||||
text=text,
|
||||
title=title,
|
||||
file_path=filename,
|
||||
file_type="json",
|
||||
source_url="",
|
||||
placeholders=placeholders,
|
||||
language=language,
|
||||
metadata=metadata,
|
||||
))
|
||||
|
||||
return documents
|
||||
@@ -1,790 +1,30 @@
|
||||
"""
|
||||
Legal Corpus API - Endpoints for RAG page in admin-v2
|
||||
Legal Corpus API — Barrel Re-export
|
||||
|
||||
Provides endpoints for:
|
||||
- GET /api/v1/admin/legal-corpus/status - Collection status with chunk counts
|
||||
- GET /api/v1/admin/legal-corpus/search - Semantic search
|
||||
- POST /api/v1/admin/legal-corpus/ingest - Trigger ingestion
|
||||
- GET /api/v1/admin/legal-corpus/ingestion-status - Ingestion status
|
||||
- POST /api/v1/admin/legal-corpus/upload - Upload document
|
||||
- POST /api/v1/admin/legal-corpus/add-link - Add link for ingestion
|
||||
- POST /api/v1/admin/pipeline/start - Start compliance pipeline
|
||||
Split into:
|
||||
- legal_corpus_routes.py — Corpus endpoints (status, search, ingest, upload)
|
||||
- legal_corpus_pipeline.py — Pipeline endpoints (checkpoints, start, status)
|
||||
|
||||
All public names are re-exported here for backward compatibility.
|
||||
"""
|
||||
|
||||
import os
|
||||
import asyncio
|
||||
import httpx
|
||||
import uuid
|
||||
import shutil
|
||||
from datetime import datetime
|
||||
from typing import Optional, List, Dict, Any
|
||||
from fastapi import APIRouter, HTTPException, Query, BackgroundTasks, UploadFile, File, Form
|
||||
from pydantic import BaseModel
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/v1/admin/legal-corpus", tags=["legal-corpus"])
|
||||
|
||||
# Configuration
|
||||
QDRANT_URL = os.getenv("QDRANT_URL", "http://localhost:6333")
|
||||
EMBEDDING_SERVICE_URL = os.getenv("EMBEDDING_SERVICE_URL", "http://embedding-service:8087")
|
||||
COLLECTION_NAME = "bp_legal_corpus"
|
||||
|
||||
# All regulations for status endpoint
|
||||
REGULATIONS = [
|
||||
{"code": "GDPR", "name": "DSGVO", "fullName": "Datenschutz-Grundverordnung", "type": "eu_regulation"},
|
||||
{"code": "EPRIVACY", "name": "ePrivacy-Richtlinie", "fullName": "Richtlinie 2002/58/EG", "type": "eu_directive"},
|
||||
{"code": "TDDDG", "name": "TDDDG", "fullName": "Telekommunikation-Digitale-Dienste-Datenschutz-Gesetz", "type": "de_law"},
|
||||
{"code": "SCC", "name": "Standardvertragsklauseln", "fullName": "2021/914/EU", "type": "eu_regulation"},
|
||||
{"code": "DPF", "name": "EU-US Data Privacy Framework", "fullName": "Angemessenheitsbeschluss", "type": "eu_regulation"},
|
||||
{"code": "AIACT", "name": "EU AI Act", "fullName": "Verordnung (EU) 2024/1689", "type": "eu_regulation"},
|
||||
{"code": "CRA", "name": "Cyber Resilience Act", "fullName": "Verordnung (EU) 2024/2847", "type": "eu_regulation"},
|
||||
{"code": "NIS2", "name": "NIS2-Richtlinie", "fullName": "Richtlinie (EU) 2022/2555", "type": "eu_directive"},
|
||||
{"code": "EUCSA", "name": "EU Cybersecurity Act", "fullName": "Verordnung (EU) 2019/881", "type": "eu_regulation"},
|
||||
{"code": "DATAACT", "name": "Data Act", "fullName": "Verordnung (EU) 2023/2854", "type": "eu_regulation"},
|
||||
{"code": "DGA", "name": "Data Governance Act", "fullName": "Verordnung (EU) 2022/868", "type": "eu_regulation"},
|
||||
{"code": "DSA", "name": "Digital Services Act", "fullName": "Verordnung (EU) 2022/2065", "type": "eu_regulation"},
|
||||
{"code": "EAA", "name": "European Accessibility Act", "fullName": "Richtlinie (EU) 2019/882", "type": "eu_directive"},
|
||||
{"code": "DSM", "name": "DSM-Urheberrechtsrichtlinie", "fullName": "Richtlinie (EU) 2019/790", "type": "eu_directive"},
|
||||
{"code": "PLD", "name": "Produkthaftungsrichtlinie", "fullName": "Richtlinie 85/374/EWG", "type": "eu_directive"},
|
||||
{"code": "GPSR", "name": "General Product Safety", "fullName": "Verordnung (EU) 2023/988", "type": "eu_regulation"},
|
||||
{"code": "BSI-TR-03161-1", "name": "BSI-TR Teil 1", "fullName": "BSI TR-03161 Teil 1 - Mobile Anwendungen", "type": "bsi_standard"},
|
||||
{"code": "BSI-TR-03161-2", "name": "BSI-TR Teil 2", "fullName": "BSI TR-03161 Teil 2 - Web-Anwendungen", "type": "bsi_standard"},
|
||||
{"code": "BSI-TR-03161-3", "name": "BSI-TR Teil 3", "fullName": "BSI TR-03161 Teil 3 - Hintergrundsysteme", "type": "bsi_standard"},
|
||||
]
|
||||
|
||||
# Ingestion state (in-memory for now)
|
||||
ingestion_state = {
|
||||
"running": False,
|
||||
"completed": False,
|
||||
"current_regulation": None,
|
||||
"processed": 0,
|
||||
"total": len(REGULATIONS),
|
||||
"error": None,
|
||||
}
|
||||
|
||||
|
||||
class SearchRequest(BaseModel):
|
||||
query: str
|
||||
regulations: Optional[List[str]] = None
|
||||
top_k: int = 5
|
||||
|
||||
|
||||
class IngestRequest(BaseModel):
|
||||
force: bool = False
|
||||
regulations: Optional[List[str]] = None
|
||||
|
||||
|
||||
class AddLinkRequest(BaseModel):
|
||||
url: str
|
||||
title: str
|
||||
code: str # Regulation code (e.g. "CUSTOM-1")
|
||||
document_type: str = "custom" # custom, eu_regulation, eu_directive, de_law, bsi_standard
|
||||
|
||||
|
||||
class StartPipelineRequest(BaseModel):
|
||||
force_reindex: bool = False
|
||||
skip_ingestion: bool = False
|
||||
|
||||
|
||||
# Store for custom documents (in-memory for now, should be persisted)
|
||||
custom_documents: List[Dict[str, Any]] = []
|
||||
|
||||
|
||||
async def get_qdrant_client():
|
||||
"""Get async HTTP client for Qdrant."""
|
||||
return httpx.AsyncClient(timeout=30.0)
|
||||
|
||||
|
||||
@router.get("/status")
|
||||
async def get_legal_corpus_status():
|
||||
"""
|
||||
Get status of the legal corpus collection including chunk counts per regulation.
|
||||
"""
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
try:
|
||||
# Get collection info
|
||||
collection_res = await client.get(f"{QDRANT_URL}/collections/{COLLECTION_NAME}")
|
||||
if collection_res.status_code != 200:
|
||||
return {
|
||||
"collection": COLLECTION_NAME,
|
||||
"totalPoints": 0,
|
||||
"vectorSize": 1024,
|
||||
"status": "not_found",
|
||||
"regulations": {},
|
||||
}
|
||||
|
||||
collection_data = collection_res.json()
|
||||
result = collection_data.get("result", {})
|
||||
|
||||
# Get chunk counts per regulation
|
||||
regulation_counts = {}
|
||||
for reg in REGULATIONS:
|
||||
count_res = await client.post(
|
||||
f"{QDRANT_URL}/collections/{COLLECTION_NAME}/points/count",
|
||||
json={
|
||||
"filter": {
|
||||
"must": [{"key": "regulation_code", "match": {"value": reg["code"]}}]
|
||||
}
|
||||
},
|
||||
)
|
||||
if count_res.status_code == 200:
|
||||
count_data = count_res.json()
|
||||
regulation_counts[reg["code"]] = count_data.get("result", {}).get("count", 0)
|
||||
else:
|
||||
regulation_counts[reg["code"]] = 0
|
||||
|
||||
return {
|
||||
"collection": COLLECTION_NAME,
|
||||
"totalPoints": result.get("points_count", 0),
|
||||
"vectorSize": result.get("config", {}).get("params", {}).get("vectors", {}).get("size", 1024),
|
||||
"status": result.get("status", "unknown"),
|
||||
"regulations": regulation_counts,
|
||||
}
|
||||
|
||||
except httpx.RequestError as e:
|
||||
logger.error(f"Failed to get Qdrant status: {e}")
|
||||
raise HTTPException(status_code=503, detail=f"Qdrant not available: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/search")
|
||||
async def search_legal_corpus(
|
||||
query: str = Query(..., description="Search query"),
|
||||
top_k: int = Query(5, ge=1, le=20, description="Number of results"),
|
||||
regulations: Optional[str] = Query(None, description="Comma-separated regulation codes to filter"),
|
||||
):
|
||||
"""
|
||||
Semantic search in legal corpus using BGE-M3 embeddings.
|
||||
"""
|
||||
async with httpx.AsyncClient(timeout=60.0) as client:
|
||||
try:
|
||||
# Generate embedding for query
|
||||
embed_res = await client.post(
|
||||
f"{EMBEDDING_SERVICE_URL}/embed",
|
||||
json={"texts": [query]},
|
||||
)
|
||||
if embed_res.status_code != 200:
|
||||
raise HTTPException(status_code=500, detail="Embedding service error")
|
||||
|
||||
embed_data = embed_res.json()
|
||||
query_vector = embed_data["embeddings"][0]
|
||||
|
||||
# Build Qdrant search request
|
||||
search_request = {
|
||||
"vector": query_vector,
|
||||
"limit": top_k,
|
||||
"with_payload": True,
|
||||
}
|
||||
|
||||
# Add regulation filter if specified
|
||||
if regulations:
|
||||
reg_codes = [r.strip() for r in regulations.split(",")]
|
||||
search_request["filter"] = {
|
||||
"should": [
|
||||
{"key": "regulation_code", "match": {"value": code}}
|
||||
for code in reg_codes
|
||||
]
|
||||
}
|
||||
|
||||
# Search Qdrant
|
||||
search_res = await client.post(
|
||||
f"{QDRANT_URL}/collections/{COLLECTION_NAME}/points/search",
|
||||
json=search_request,
|
||||
)
|
||||
|
||||
if search_res.status_code != 200:
|
||||
raise HTTPException(status_code=500, detail="Search failed")
|
||||
|
||||
search_data = search_res.json()
|
||||
results = []
|
||||
for point in search_data.get("result", []):
|
||||
payload = point.get("payload", {})
|
||||
results.append({
|
||||
"text": payload.get("text", ""),
|
||||
"regulation_code": payload.get("regulation_code", ""),
|
||||
"regulation_name": payload.get("regulation_name", ""),
|
||||
"article": payload.get("article"),
|
||||
"paragraph": payload.get("paragraph"),
|
||||
"source_url": payload.get("source_url", ""),
|
||||
"score": point.get("score", 0),
|
||||
})
|
||||
|
||||
return {"results": results, "query": query, "count": len(results)}
|
||||
|
||||
except httpx.RequestError as e:
|
||||
logger.error(f"Search failed: {e}")
|
||||
raise HTTPException(status_code=503, detail=f"Service not available: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/ingest")
|
||||
async def trigger_ingestion(request: IngestRequest, background_tasks: BackgroundTasks):
|
||||
"""
|
||||
Trigger legal corpus ingestion in background.
|
||||
"""
|
||||
global ingestion_state
|
||||
|
||||
if ingestion_state["running"]:
|
||||
raise HTTPException(status_code=409, detail="Ingestion already running")
|
||||
|
||||
# Reset state
|
||||
ingestion_state = {
|
||||
"running": True,
|
||||
"completed": False,
|
||||
"current_regulation": None,
|
||||
"processed": 0,
|
||||
"total": len(REGULATIONS),
|
||||
"error": None,
|
||||
}
|
||||
|
||||
# Start ingestion in background
|
||||
background_tasks.add_task(run_ingestion, request.force, request.regulations)
|
||||
|
||||
return {
|
||||
"status": "started",
|
||||
"job_id": "manual-trigger",
|
||||
"message": f"Ingestion started for {len(REGULATIONS)} regulations",
|
||||
}
|
||||
|
||||
|
||||
async def run_ingestion(force: bool, regulations: Optional[List[str]]):
|
||||
"""Background task for running ingestion."""
|
||||
global ingestion_state
|
||||
|
||||
try:
|
||||
# Import ingestion module
|
||||
from legal_corpus_ingestion import LegalCorpusIngestion
|
||||
|
||||
ingestion = LegalCorpusIngestion()
|
||||
|
||||
# Filter regulations if specified
|
||||
regs_to_process = regulations or [r["code"] for r in REGULATIONS]
|
||||
|
||||
for i, reg_code in enumerate(regs_to_process):
|
||||
ingestion_state["current_regulation"] = reg_code
|
||||
ingestion_state["processed"] = i
|
||||
|
||||
try:
|
||||
await ingestion.ingest_single(reg_code, force=force)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to ingest {reg_code}: {e}")
|
||||
|
||||
ingestion_state["completed"] = True
|
||||
ingestion_state["processed"] = len(regs_to_process)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Ingestion failed: {e}")
|
||||
ingestion_state["error"] = str(e)
|
||||
|
||||
finally:
|
||||
ingestion_state["running"] = False
|
||||
|
||||
|
||||
@router.get("/ingestion-status")
|
||||
async def get_ingestion_status():
|
||||
"""
|
||||
Get current ingestion status.
|
||||
"""
|
||||
return ingestion_state
|
||||
|
||||
|
||||
@router.get("/regulations")
|
||||
async def get_regulations():
|
||||
"""
|
||||
Get list of all supported regulations.
|
||||
"""
|
||||
return {"regulations": REGULATIONS}
|
||||
|
||||
|
||||
@router.get("/custom-documents")
|
||||
async def get_custom_documents():
|
||||
"""
|
||||
Get list of custom documents added by user.
|
||||
"""
|
||||
return {"documents": custom_documents}
|
||||
|
||||
|
||||
@router.post("/upload")
|
||||
async def upload_document(
|
||||
background_tasks: BackgroundTasks,
|
||||
file: UploadFile = File(...),
|
||||
title: str = Form(...),
|
||||
code: str = Form(...),
|
||||
document_type: str = Form("custom"),
|
||||
):
|
||||
"""
|
||||
Upload a document (PDF) for ingestion into the legal corpus.
|
||||
|
||||
The document will be saved and queued for processing.
|
||||
"""
|
||||
global custom_documents
|
||||
|
||||
# Validate file type
|
||||
if not file.filename.endswith(('.pdf', '.PDF')):
|
||||
raise HTTPException(status_code=400, detail="Only PDF files are supported")
|
||||
|
||||
# Create upload directory if needed
|
||||
upload_dir = "/tmp/legal_corpus_uploads"
|
||||
os.makedirs(upload_dir, exist_ok=True)
|
||||
|
||||
# Save file with unique name
|
||||
doc_id = str(uuid.uuid4())[:8]
|
||||
safe_filename = f"{doc_id}_{file.filename}"
|
||||
file_path = os.path.join(upload_dir, safe_filename)
|
||||
|
||||
try:
|
||||
with open(file_path, "wb") as buffer:
|
||||
shutil.copyfileobj(file.file, buffer)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save uploaded file: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"Failed to save file: {str(e)}")
|
||||
|
||||
# Create document record
|
||||
doc_record = {
|
||||
"id": doc_id,
|
||||
"code": code,
|
||||
"title": title,
|
||||
"filename": file.filename,
|
||||
"file_path": file_path,
|
||||
"document_type": document_type,
|
||||
"uploaded_at": datetime.now().isoformat(),
|
||||
"status": "uploaded",
|
||||
"chunk_count": 0,
|
||||
}
|
||||
|
||||
custom_documents.append(doc_record)
|
||||
|
||||
# Queue for background ingestion
|
||||
background_tasks.add_task(ingest_uploaded_document, doc_record)
|
||||
|
||||
return {
|
||||
"status": "uploaded",
|
||||
"document_id": doc_id,
|
||||
"message": f"Document '{title}' uploaded and queued for ingestion",
|
||||
"document": doc_record,
|
||||
}
|
||||
|
||||
|
||||
async def ingest_uploaded_document(doc_record: Dict[str, Any]):
|
||||
"""Background task to ingest an uploaded document."""
|
||||
global custom_documents
|
||||
|
||||
try:
|
||||
doc_record["status"] = "processing"
|
||||
|
||||
from legal_corpus_ingestion import LegalCorpusIngestion
|
||||
ingestion = LegalCorpusIngestion()
|
||||
|
||||
# Read PDF and extract text
|
||||
import fitz # PyMuPDF
|
||||
|
||||
doc = fitz.open(doc_record["file_path"])
|
||||
full_text = ""
|
||||
for page in doc:
|
||||
full_text += page.get_text()
|
||||
doc.close()
|
||||
|
||||
if not full_text.strip():
|
||||
doc_record["status"] = "error"
|
||||
doc_record["error"] = "No text could be extracted from PDF"
|
||||
return
|
||||
|
||||
# Chunk the text
|
||||
chunks = ingestion.chunk_text(full_text, doc_record["code"])
|
||||
|
||||
# Add metadata
|
||||
for chunk in chunks:
|
||||
chunk["regulation_code"] = doc_record["code"]
|
||||
chunk["regulation_name"] = doc_record["title"]
|
||||
chunk["document_type"] = doc_record["document_type"]
|
||||
chunk["source_url"] = f"upload://{doc_record['filename']}"
|
||||
|
||||
# Generate embeddings and upsert to Qdrant
|
||||
if chunks:
|
||||
await ingestion.embed_and_upsert(chunks)
|
||||
doc_record["chunk_count"] = len(chunks)
|
||||
doc_record["status"] = "indexed"
|
||||
logger.info(f"Ingested {len(chunks)} chunks from uploaded document {doc_record['code']}")
|
||||
else:
|
||||
doc_record["status"] = "error"
|
||||
doc_record["error"] = "No chunks generated from document"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to ingest uploaded document: {e}")
|
||||
doc_record["status"] = "error"
|
||||
doc_record["error"] = str(e)
|
||||
|
||||
|
||||
@router.post("/add-link")
|
||||
async def add_link(request: AddLinkRequest, background_tasks: BackgroundTasks):
|
||||
"""
|
||||
Add a URL/link for ingestion into the legal corpus.
|
||||
|
||||
The content will be fetched, extracted, and indexed.
|
||||
"""
|
||||
global custom_documents
|
||||
|
||||
# Create document record
|
||||
doc_id = str(uuid.uuid4())[:8]
|
||||
doc_record = {
|
||||
"id": doc_id,
|
||||
"code": request.code,
|
||||
"title": request.title,
|
||||
"url": request.url,
|
||||
"document_type": request.document_type,
|
||||
"uploaded_at": datetime.now().isoformat(),
|
||||
"status": "queued",
|
||||
"chunk_count": 0,
|
||||
}
|
||||
|
||||
custom_documents.append(doc_record)
|
||||
|
||||
# Queue for background ingestion
|
||||
background_tasks.add_task(ingest_link_document, doc_record)
|
||||
|
||||
return {
|
||||
"status": "queued",
|
||||
"document_id": doc_id,
|
||||
"message": f"Link '{request.title}' queued for ingestion",
|
||||
"document": doc_record,
|
||||
}
|
||||
|
||||
|
||||
async def ingest_link_document(doc_record: Dict[str, Any]):
|
||||
"""Background task to ingest content from a URL."""
|
||||
global custom_documents
|
||||
|
||||
try:
|
||||
doc_record["status"] = "fetching"
|
||||
|
||||
async with httpx.AsyncClient(timeout=60.0) as client:
|
||||
# Fetch the URL
|
||||
response = await client.get(doc_record["url"], follow_redirects=True)
|
||||
response.raise_for_status()
|
||||
|
||||
content_type = response.headers.get("content-type", "")
|
||||
|
||||
if "application/pdf" in content_type:
|
||||
# Save PDF and process
|
||||
import tempfile
|
||||
with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as f:
|
||||
f.write(response.content)
|
||||
pdf_path = f.name
|
||||
|
||||
import fitz
|
||||
pdf_doc = fitz.open(pdf_path)
|
||||
full_text = ""
|
||||
for page in pdf_doc:
|
||||
full_text += page.get_text()
|
||||
pdf_doc.close()
|
||||
os.unlink(pdf_path)
|
||||
|
||||
elif "text/html" in content_type:
|
||||
# Extract text from HTML
|
||||
from bs4 import BeautifulSoup
|
||||
soup = BeautifulSoup(response.text, "html.parser")
|
||||
|
||||
# Remove script and style elements
|
||||
for script in soup(["script", "style", "nav", "footer", "header"]):
|
||||
script.decompose()
|
||||
|
||||
full_text = soup.get_text(separator="\n", strip=True)
|
||||
|
||||
else:
|
||||
# Try to use as plain text
|
||||
full_text = response.text
|
||||
|
||||
if not full_text.strip():
|
||||
doc_record["status"] = "error"
|
||||
doc_record["error"] = "No text could be extracted from URL"
|
||||
return
|
||||
|
||||
doc_record["status"] = "processing"
|
||||
|
||||
from legal_corpus_ingestion import LegalCorpusIngestion
|
||||
ingestion = LegalCorpusIngestion()
|
||||
|
||||
# Chunk the text
|
||||
chunks = ingestion.chunk_text(full_text, doc_record["code"])
|
||||
|
||||
# Add metadata
|
||||
for chunk in chunks:
|
||||
chunk["regulation_code"] = doc_record["code"]
|
||||
chunk["regulation_name"] = doc_record["title"]
|
||||
chunk["document_type"] = doc_record["document_type"]
|
||||
chunk["source_url"] = doc_record["url"]
|
||||
|
||||
# Generate embeddings and upsert to Qdrant
|
||||
if chunks:
|
||||
await ingestion.embed_and_upsert(chunks)
|
||||
doc_record["chunk_count"] = len(chunks)
|
||||
doc_record["status"] = "indexed"
|
||||
logger.info(f"Ingested {len(chunks)} chunks from URL {doc_record['url']}")
|
||||
else:
|
||||
doc_record["status"] = "error"
|
||||
doc_record["error"] = "No chunks generated from content"
|
||||
|
||||
except httpx.HTTPError as e:
|
||||
logger.error(f"Failed to fetch URL: {e}")
|
||||
doc_record["status"] = "error"
|
||||
doc_record["error"] = f"Failed to fetch URL: {str(e)}"
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to ingest URL content: {e}")
|
||||
doc_record["status"] = "error"
|
||||
doc_record["error"] = str(e)
|
||||
|
||||
|
||||
@router.delete("/custom-documents/{doc_id}")
|
||||
async def delete_custom_document(doc_id: str):
|
||||
"""
|
||||
Delete a custom document from the list.
|
||||
Note: This does not remove the chunks from Qdrant yet.
|
||||
"""
|
||||
global custom_documents
|
||||
|
||||
doc = next((d for d in custom_documents if d["id"] == doc_id), None)
|
||||
if not doc:
|
||||
raise HTTPException(status_code=404, detail="Document not found")
|
||||
|
||||
custom_documents = [d for d in custom_documents if d["id"] != doc_id]
|
||||
|
||||
# TODO: Also remove chunks from Qdrant by filtering on code
|
||||
|
||||
return {"status": "deleted", "document_id": doc_id}
|
||||
|
||||
|
||||
# ========== Pipeline Checkpoints ==========
|
||||
|
||||
# Create a separate router for pipeline-related endpoints
|
||||
pipeline_router = APIRouter(prefix="/api/v1/admin/pipeline", tags=["pipeline"])
|
||||
|
||||
|
||||
@pipeline_router.get("/checkpoints")
|
||||
async def get_pipeline_checkpoints():
|
||||
"""
|
||||
Get current pipeline checkpoint state.
|
||||
|
||||
Returns the current state of the compliance pipeline including:
|
||||
- Pipeline ID and overall status
|
||||
- Start and completion times
|
||||
- All checkpoints with their validations and metrics
|
||||
- Summary data
|
||||
"""
|
||||
from pipeline_checkpoints import CheckpointManager
|
||||
|
||||
state = CheckpointManager.load_state()
|
||||
|
||||
if state is None:
|
||||
return {
|
||||
"status": "no_data",
|
||||
"message": "No pipeline run data available yet.",
|
||||
"pipeline_id": None,
|
||||
"checkpoints": [],
|
||||
"summary": {}
|
||||
}
|
||||
|
||||
# Enrich with validation summary
|
||||
validation_summary = {
|
||||
"passed": 0,
|
||||
"warning": 0,
|
||||
"failed": 0,
|
||||
"total": 0
|
||||
}
|
||||
|
||||
for checkpoint in state.get("checkpoints", []):
|
||||
for validation in checkpoint.get("validations", []):
|
||||
validation_summary["total"] += 1
|
||||
status = validation.get("status", "not_run")
|
||||
if status in validation_summary:
|
||||
validation_summary[status] += 1
|
||||
|
||||
state["validation_summary"] = validation_summary
|
||||
|
||||
return state
|
||||
|
||||
|
||||
@pipeline_router.get("/checkpoints/history")
|
||||
async def get_pipeline_history():
|
||||
"""
|
||||
Get list of previous pipeline runs (if stored).
|
||||
For now, returns only current run.
|
||||
"""
|
||||
from pipeline_checkpoints import CheckpointManager
|
||||
|
||||
state = CheckpointManager.load_state()
|
||||
|
||||
if state is None:
|
||||
return {"runs": []}
|
||||
|
||||
return {
|
||||
"runs": [{
|
||||
"pipeline_id": state.get("pipeline_id"),
|
||||
"status": state.get("status"),
|
||||
"started_at": state.get("started_at"),
|
||||
"completed_at": state.get("completed_at"),
|
||||
}]
|
||||
}
|
||||
|
||||
|
||||
# Pipeline state for start/stop
|
||||
pipeline_process_state = {
|
||||
"running": False,
|
||||
"pid": None,
|
||||
"started_at": None,
|
||||
}
|
||||
|
||||
|
||||
@pipeline_router.post("/start")
|
||||
async def start_pipeline(request: StartPipelineRequest, background_tasks: BackgroundTasks):
|
||||
"""
|
||||
Start the compliance pipeline in the background.
|
||||
|
||||
This runs the full_compliance_pipeline.py script which:
|
||||
1. Ingests all legal documents (unless skip_ingestion=True)
|
||||
2. Extracts requirements and controls
|
||||
3. Generates compliance measures
|
||||
4. Creates checkpoint data for monitoring
|
||||
"""
|
||||
global pipeline_process_state
|
||||
|
||||
# Check if already running
|
||||
from pipeline_checkpoints import CheckpointManager
|
||||
state = CheckpointManager.load_state()
|
||||
|
||||
if state and state.get("status") == "running":
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail="Pipeline is already running"
|
||||
)
|
||||
|
||||
if pipeline_process_state["running"]:
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail="Pipeline start already in progress"
|
||||
)
|
||||
|
||||
pipeline_process_state["running"] = True
|
||||
pipeline_process_state["started_at"] = datetime.now().isoformat()
|
||||
|
||||
# Start pipeline in background
|
||||
background_tasks.add_task(
|
||||
run_pipeline_background,
|
||||
request.force_reindex,
|
||||
request.skip_ingestion
|
||||
)
|
||||
|
||||
return {
|
||||
"status": "starting",
|
||||
"message": "Compliance pipeline is starting in background",
|
||||
"started_at": pipeline_process_state["started_at"],
|
||||
}
|
||||
|
||||
|
||||
async def run_pipeline_background(force_reindex: bool, skip_ingestion: bool):
|
||||
"""Background task to run the compliance pipeline."""
|
||||
global pipeline_process_state
|
||||
|
||||
try:
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
# Build command
|
||||
cmd = [sys.executable, "full_compliance_pipeline.py"]
|
||||
if force_reindex:
|
||||
cmd.append("--force-reindex")
|
||||
if skip_ingestion:
|
||||
cmd.append("--skip-ingestion")
|
||||
|
||||
# Run as subprocess
|
||||
logger.info(f"Starting pipeline: {' '.join(cmd)}")
|
||||
|
||||
process = subprocess.Popen(
|
||||
cmd,
|
||||
cwd=os.path.dirname(os.path.abspath(__file__)),
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
text=True,
|
||||
)
|
||||
|
||||
pipeline_process_state["pid"] = process.pid
|
||||
|
||||
# Wait for completion (non-blocking via asyncio)
|
||||
import asyncio
|
||||
while process.poll() is None:
|
||||
await asyncio.sleep(5)
|
||||
|
||||
return_code = process.returncode
|
||||
|
||||
if return_code != 0:
|
||||
output = process.stdout.read() if process.stdout else ""
|
||||
logger.error(f"Pipeline failed with code {return_code}: {output}")
|
||||
else:
|
||||
logger.info("Pipeline completed successfully")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to run pipeline: {e}")
|
||||
|
||||
finally:
|
||||
pipeline_process_state["running"] = False
|
||||
pipeline_process_state["pid"] = None
|
||||
|
||||
|
||||
@pipeline_router.get("/status")
|
||||
async def get_pipeline_status():
|
||||
"""
|
||||
Get current pipeline running status.
|
||||
"""
|
||||
from pipeline_checkpoints import CheckpointManager
|
||||
|
||||
state = CheckpointManager.load_state()
|
||||
checkpoint_status = state.get("status") if state else "no_data"
|
||||
|
||||
return {
|
||||
"process_running": pipeline_process_state["running"],
|
||||
"process_pid": pipeline_process_state["pid"],
|
||||
"process_started_at": pipeline_process_state["started_at"],
|
||||
"checkpoint_status": checkpoint_status,
|
||||
"current_phase": state.get("current_phase") if state else None,
|
||||
}
|
||||
|
||||
|
||||
# ========== Traceability / Quality Endpoints ==========
|
||||
|
||||
@router.get("/traceability")
|
||||
async def get_traceability(
|
||||
chunk_id: str = Query(..., description="Chunk ID or identifier"),
|
||||
regulation: str = Query(..., description="Regulation code"),
|
||||
):
|
||||
"""
|
||||
Get traceability information for a specific chunk.
|
||||
|
||||
Returns:
|
||||
- The chunk details
|
||||
- Requirements extracted from this chunk
|
||||
- Controls derived from those requirements
|
||||
|
||||
Note: This is a placeholder that will be enhanced once the
|
||||
requirements extraction pipeline is fully implemented.
|
||||
"""
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
try:
|
||||
# Try to find the chunk by scrolling through points with the regulation filter
|
||||
# In a production system, we would have proper IDs and indexing
|
||||
|
||||
# For now, return placeholder structure
|
||||
# The actual implementation will query:
|
||||
# 1. The chunk from Qdrant
|
||||
# 2. Requirements from a requirements collection/table
|
||||
# 3. Controls from a controls collection/table
|
||||
|
||||
return {
|
||||
"chunk_id": chunk_id,
|
||||
"regulation": regulation,
|
||||
"requirements": [],
|
||||
"controls": [],
|
||||
"message": "Traceability-Daten werden verfuegbar sein, sobald die Requirements-Extraktion und Control-Ableitung implementiert sind."
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get traceability: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"Traceability lookup failed: {str(e)}")
|
||||
# Corpus routes and state
|
||||
from legal_corpus_routes import ( # noqa: F401
|
||||
router,
|
||||
REGULATIONS,
|
||||
COLLECTION_NAME,
|
||||
QDRANT_URL,
|
||||
EMBEDDING_SERVICE_URL,
|
||||
ingestion_state,
|
||||
custom_documents,
|
||||
SearchRequest,
|
||||
IngestRequest,
|
||||
AddLinkRequest,
|
||||
)
|
||||
|
||||
# Pipeline routes and state
|
||||
from legal_corpus_pipeline import ( # noqa: F401
|
||||
pipeline_router,
|
||||
pipeline_process_state,
|
||||
StartPipelineRequest,
|
||||
)
|
||||
|
||||
@@ -0,0 +1,166 @@
|
||||
"""
|
||||
Legal Corpus API - Background Ingestion Tasks
|
||||
|
||||
Background tasks for ingesting uploaded documents and URL links
|
||||
into the legal corpus vector database.
|
||||
|
||||
Extracted from legal_corpus_routes.py to keep files under 500 LOC.
|
||||
"""
|
||||
|
||||
import os
|
||||
import logging
|
||||
from typing import Dict, Any, Optional, List
|
||||
|
||||
import httpx
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def ingest_uploaded_document(doc_record: Dict[str, Any]):
|
||||
"""Background task to ingest an uploaded document."""
|
||||
try:
|
||||
doc_record["status"] = "processing"
|
||||
|
||||
from legal_corpus_ingestion import LegalCorpusIngestion
|
||||
ingestion = LegalCorpusIngestion()
|
||||
|
||||
import fitz
|
||||
doc = fitz.open(doc_record["file_path"])
|
||||
full_text = ""
|
||||
for page in doc:
|
||||
full_text += page.get_text()
|
||||
doc.close()
|
||||
|
||||
if not full_text.strip():
|
||||
doc_record["status"] = "error"
|
||||
doc_record["error"] = "No text could be extracted from PDF"
|
||||
return
|
||||
|
||||
chunks = ingestion.chunk_text(full_text, doc_record["code"])
|
||||
|
||||
for chunk in chunks:
|
||||
chunk["regulation_code"] = doc_record["code"]
|
||||
chunk["regulation_name"] = doc_record["title"]
|
||||
chunk["document_type"] = doc_record["document_type"]
|
||||
chunk["source_url"] = f"upload://{doc_record['filename']}"
|
||||
|
||||
if chunks:
|
||||
await ingestion.embed_and_upsert(chunks)
|
||||
doc_record["chunk_count"] = len(chunks)
|
||||
doc_record["status"] = "indexed"
|
||||
logger.info(f"Ingested {len(chunks)} chunks from uploaded document {doc_record['code']}")
|
||||
else:
|
||||
doc_record["status"] = "error"
|
||||
doc_record["error"] = "No chunks generated from document"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to ingest uploaded document: {e}")
|
||||
doc_record["status"] = "error"
|
||||
doc_record["error"] = str(e)
|
||||
|
||||
|
||||
async def ingest_link_document(doc_record: Dict[str, Any]):
|
||||
"""Background task to ingest content from a URL."""
|
||||
try:
|
||||
doc_record["status"] = "fetching"
|
||||
|
||||
async with httpx.AsyncClient(timeout=60.0) as client:
|
||||
response = await client.get(doc_record["url"], follow_redirects=True)
|
||||
response.raise_for_status()
|
||||
|
||||
content_type = response.headers.get("content-type", "")
|
||||
|
||||
if "application/pdf" in content_type:
|
||||
import tempfile
|
||||
with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as f:
|
||||
f.write(response.content)
|
||||
pdf_path = f.name
|
||||
|
||||
import fitz
|
||||
pdf_doc = fitz.open(pdf_path)
|
||||
full_text = ""
|
||||
for page in pdf_doc:
|
||||
full_text += page.get_text()
|
||||
pdf_doc.close()
|
||||
os.unlink(pdf_path)
|
||||
|
||||
elif "text/html" in content_type:
|
||||
from bs4 import BeautifulSoup
|
||||
soup = BeautifulSoup(response.text, "html.parser")
|
||||
|
||||
for script in soup(["script", "style", "nav", "footer", "header"]):
|
||||
script.decompose()
|
||||
|
||||
full_text = soup.get_text(separator="\n", strip=True)
|
||||
|
||||
else:
|
||||
full_text = response.text
|
||||
|
||||
if not full_text.strip():
|
||||
doc_record["status"] = "error"
|
||||
doc_record["error"] = "No text could be extracted from URL"
|
||||
return
|
||||
|
||||
doc_record["status"] = "processing"
|
||||
|
||||
from legal_corpus_ingestion import LegalCorpusIngestion
|
||||
ingestion = LegalCorpusIngestion()
|
||||
|
||||
chunks = ingestion.chunk_text(full_text, doc_record["code"])
|
||||
|
||||
for chunk in chunks:
|
||||
chunk["regulation_code"] = doc_record["code"]
|
||||
chunk["regulation_name"] = doc_record["title"]
|
||||
chunk["document_type"] = doc_record["document_type"]
|
||||
chunk["source_url"] = doc_record["url"]
|
||||
|
||||
if chunks:
|
||||
await ingestion.embed_and_upsert(chunks)
|
||||
doc_record["chunk_count"] = len(chunks)
|
||||
doc_record["status"] = "indexed"
|
||||
logger.info(f"Ingested {len(chunks)} chunks from URL {doc_record['url']}")
|
||||
else:
|
||||
doc_record["status"] = "error"
|
||||
doc_record["error"] = "No chunks generated from content"
|
||||
|
||||
except httpx.HTTPError as e:
|
||||
logger.error(f"Failed to fetch URL: {e}")
|
||||
doc_record["status"] = "error"
|
||||
doc_record["error"] = f"Failed to fetch URL: {str(e)}"
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to ingest URL content: {e}")
|
||||
doc_record["status"] = "error"
|
||||
doc_record["error"] = str(e)
|
||||
|
||||
|
||||
async def run_ingestion(
|
||||
force: bool,
|
||||
regulations: Optional[List[str]],
|
||||
ingestion_state: Dict[str, Any],
|
||||
all_regulations: List[Dict[str, str]],
|
||||
):
|
||||
"""Background task for running full corpus ingestion."""
|
||||
try:
|
||||
from legal_corpus_ingestion import LegalCorpusIngestion
|
||||
ingestion = LegalCorpusIngestion()
|
||||
|
||||
regs_to_process = regulations or [r["code"] for r in all_regulations]
|
||||
|
||||
for i, reg_code in enumerate(regs_to_process):
|
||||
ingestion_state["current_regulation"] = reg_code
|
||||
ingestion_state["processed"] = i
|
||||
|
||||
try:
|
||||
await ingestion.ingest_single(reg_code, force=force)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to ingest {reg_code}: {e}")
|
||||
|
||||
ingestion_state["completed"] = True
|
||||
ingestion_state["processed"] = len(regs_to_process)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Ingestion failed: {e}")
|
||||
ingestion_state["error"] = str(e)
|
||||
|
||||
finally:
|
||||
ingestion_state["running"] = False
|
||||
@@ -0,0 +1,206 @@
|
||||
"""
|
||||
Legal Corpus API - Pipeline Routes
|
||||
|
||||
Pipeline checkpoints, history, start/stop, and status endpoints.
|
||||
|
||||
Extracted from legal_corpus_api.py to keep files under 500 LOC.
|
||||
"""
|
||||
|
||||
import os
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
from fastapi import APIRouter, HTTPException, BackgroundTasks
|
||||
from pydantic import BaseModel
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class StartPipelineRequest(BaseModel):
|
||||
force_reindex: bool = False
|
||||
skip_ingestion: bool = False
|
||||
|
||||
|
||||
# Create a separate router for pipeline-related endpoints
|
||||
pipeline_router = APIRouter(prefix="/api/v1/admin/pipeline", tags=["pipeline"])
|
||||
|
||||
|
||||
@pipeline_router.get("/checkpoints")
|
||||
async def get_pipeline_checkpoints():
|
||||
"""
|
||||
Get current pipeline checkpoint state.
|
||||
|
||||
Returns the current state of the compliance pipeline including:
|
||||
- Pipeline ID and overall status
|
||||
- Start and completion times
|
||||
- All checkpoints with their validations and metrics
|
||||
- Summary data
|
||||
"""
|
||||
from pipeline_checkpoints import CheckpointManager
|
||||
|
||||
state = CheckpointManager.load_state()
|
||||
|
||||
if state is None:
|
||||
return {
|
||||
"status": "no_data",
|
||||
"message": "No pipeline run data available yet.",
|
||||
"pipeline_id": None,
|
||||
"checkpoints": [],
|
||||
"summary": {}
|
||||
}
|
||||
|
||||
# Enrich with validation summary
|
||||
validation_summary = {
|
||||
"passed": 0,
|
||||
"warning": 0,
|
||||
"failed": 0,
|
||||
"total": 0
|
||||
}
|
||||
|
||||
for checkpoint in state.get("checkpoints", []):
|
||||
for validation in checkpoint.get("validations", []):
|
||||
validation_summary["total"] += 1
|
||||
status = validation.get("status", "not_run")
|
||||
if status in validation_summary:
|
||||
validation_summary[status] += 1
|
||||
|
||||
state["validation_summary"] = validation_summary
|
||||
|
||||
return state
|
||||
|
||||
|
||||
@pipeline_router.get("/checkpoints/history")
|
||||
async def get_pipeline_history():
|
||||
"""
|
||||
Get list of previous pipeline runs (if stored).
|
||||
For now, returns only current run.
|
||||
"""
|
||||
from pipeline_checkpoints import CheckpointManager
|
||||
|
||||
state = CheckpointManager.load_state()
|
||||
|
||||
if state is None:
|
||||
return {"runs": []}
|
||||
|
||||
return {
|
||||
"runs": [{
|
||||
"pipeline_id": state.get("pipeline_id"),
|
||||
"status": state.get("status"),
|
||||
"started_at": state.get("started_at"),
|
||||
"completed_at": state.get("completed_at"),
|
||||
}]
|
||||
}
|
||||
|
||||
|
||||
# Pipeline state for start/stop
|
||||
pipeline_process_state = {
|
||||
"running": False,
|
||||
"pid": None,
|
||||
"started_at": None,
|
||||
}
|
||||
|
||||
|
||||
@pipeline_router.post("/start")
|
||||
async def start_pipeline(request: StartPipelineRequest, background_tasks: BackgroundTasks):
|
||||
"""
|
||||
Start the compliance pipeline in the background.
|
||||
|
||||
This runs the full_compliance_pipeline.py script which:
|
||||
1. Ingests all legal documents (unless skip_ingestion=True)
|
||||
2. Extracts requirements and controls
|
||||
3. Generates compliance measures
|
||||
4. Creates checkpoint data for monitoring
|
||||
"""
|
||||
global pipeline_process_state
|
||||
|
||||
from pipeline_checkpoints import CheckpointManager
|
||||
state = CheckpointManager.load_state()
|
||||
|
||||
if state and state.get("status") == "running":
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail="Pipeline is already running"
|
||||
)
|
||||
|
||||
if pipeline_process_state["running"]:
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail="Pipeline start already in progress"
|
||||
)
|
||||
|
||||
pipeline_process_state["running"] = True
|
||||
pipeline_process_state["started_at"] = datetime.now().isoformat()
|
||||
|
||||
background_tasks.add_task(
|
||||
run_pipeline_background,
|
||||
request.force_reindex,
|
||||
request.skip_ingestion
|
||||
)
|
||||
|
||||
return {
|
||||
"status": "starting",
|
||||
"message": "Compliance pipeline is starting in background",
|
||||
"started_at": pipeline_process_state["started_at"],
|
||||
}
|
||||
|
||||
|
||||
async def run_pipeline_background(force_reindex: bool, skip_ingestion: bool):
|
||||
"""Background task to run the compliance pipeline."""
|
||||
global pipeline_process_state
|
||||
|
||||
try:
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
cmd = [sys.executable, "full_compliance_pipeline.py"]
|
||||
if force_reindex:
|
||||
cmd.append("--force-reindex")
|
||||
if skip_ingestion:
|
||||
cmd.append("--skip-ingestion")
|
||||
|
||||
logger.info(f"Starting pipeline: {' '.join(cmd)}")
|
||||
|
||||
process = subprocess.Popen(
|
||||
cmd,
|
||||
cwd=os.path.dirname(os.path.abspath(__file__)),
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
text=True,
|
||||
)
|
||||
|
||||
pipeline_process_state["pid"] = process.pid
|
||||
|
||||
while process.poll() is None:
|
||||
await asyncio.sleep(5)
|
||||
|
||||
return_code = process.returncode
|
||||
|
||||
if return_code != 0:
|
||||
output = process.stdout.read() if process.stdout else ""
|
||||
logger.error(f"Pipeline failed with code {return_code}: {output}")
|
||||
else:
|
||||
logger.info("Pipeline completed successfully")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to run pipeline: {e}")
|
||||
|
||||
finally:
|
||||
pipeline_process_state["running"] = False
|
||||
pipeline_process_state["pid"] = None
|
||||
|
||||
|
||||
@pipeline_router.get("/status")
|
||||
async def get_pipeline_status():
|
||||
"""Get current pipeline running status."""
|
||||
from pipeline_checkpoints import CheckpointManager
|
||||
|
||||
state = CheckpointManager.load_state()
|
||||
checkpoint_status = state.get("status") if state else "no_data"
|
||||
|
||||
return {
|
||||
"process_running": pipeline_process_state["running"],
|
||||
"process_pid": pipeline_process_state["pid"],
|
||||
"process_started_at": pipeline_process_state["started_at"],
|
||||
"checkpoint_status": checkpoint_status,
|
||||
"current_phase": state.get("current_phase") if state else None,
|
||||
}
|
||||
@@ -0,0 +1,368 @@
|
||||
"""
|
||||
Legal Corpus API - Corpus Routes
|
||||
|
||||
Endpoints for the RAG page in admin-v2:
|
||||
- GET /status - Collection status with chunk counts
|
||||
- GET /search - Semantic search
|
||||
- POST /ingest - Trigger ingestion
|
||||
- GET /ingestion-status - Ingestion status
|
||||
- GET /regulations - List regulations
|
||||
- GET /custom-documents - List custom docs
|
||||
- POST /upload - Upload document
|
||||
- POST /add-link - Add link for ingestion
|
||||
- DELETE /custom-documents/{id} - Delete custom doc
|
||||
- GET /traceability - Traceability info
|
||||
|
||||
Extracted from legal_corpus_api.py to keep files under 500 LOC.
|
||||
"""
|
||||
|
||||
import os
|
||||
import httpx
|
||||
import uuid
|
||||
import shutil
|
||||
from datetime import datetime
|
||||
from typing import Optional, List, Dict, Any
|
||||
from fastapi import APIRouter, HTTPException, Query, BackgroundTasks, UploadFile, File, Form
|
||||
from pydantic import BaseModel
|
||||
import logging
|
||||
|
||||
from legal_corpus_ingest_tasks import (
|
||||
ingest_uploaded_document,
|
||||
ingest_link_document,
|
||||
run_ingestion,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Configuration
|
||||
QDRANT_URL = os.getenv("QDRANT_URL", "http://localhost:6333")
|
||||
EMBEDDING_SERVICE_URL = os.getenv("EMBEDDING_SERVICE_URL", "http://embedding-service:8087")
|
||||
COLLECTION_NAME = "bp_legal_corpus"
|
||||
|
||||
# All regulations for status endpoint
|
||||
REGULATIONS = [
|
||||
{"code": "GDPR", "name": "DSGVO", "fullName": "Datenschutz-Grundverordnung", "type": "eu_regulation"},
|
||||
{"code": "EPRIVACY", "name": "ePrivacy-Richtlinie", "fullName": "Richtlinie 2002/58/EG", "type": "eu_directive"},
|
||||
{"code": "TDDDG", "name": "TDDDG", "fullName": "Telekommunikation-Digitale-Dienste-Datenschutz-Gesetz", "type": "de_law"},
|
||||
{"code": "SCC", "name": "Standardvertragsklauseln", "fullName": "2021/914/EU", "type": "eu_regulation"},
|
||||
{"code": "DPF", "name": "EU-US Data Privacy Framework", "fullName": "Angemessenheitsbeschluss", "type": "eu_regulation"},
|
||||
{"code": "AIACT", "name": "EU AI Act", "fullName": "Verordnung (EU) 2024/1689", "type": "eu_regulation"},
|
||||
{"code": "CRA", "name": "Cyber Resilience Act", "fullName": "Verordnung (EU) 2024/2847", "type": "eu_regulation"},
|
||||
{"code": "NIS2", "name": "NIS2-Richtlinie", "fullName": "Richtlinie (EU) 2022/2555", "type": "eu_directive"},
|
||||
{"code": "EUCSA", "name": "EU Cybersecurity Act", "fullName": "Verordnung (EU) 2019/881", "type": "eu_regulation"},
|
||||
{"code": "DATAACT", "name": "Data Act", "fullName": "Verordnung (EU) 2023/2854", "type": "eu_regulation"},
|
||||
{"code": "DGA", "name": "Data Governance Act", "fullName": "Verordnung (EU) 2022/868", "type": "eu_regulation"},
|
||||
{"code": "DSA", "name": "Digital Services Act", "fullName": "Verordnung (EU) 2022/2065", "type": "eu_regulation"},
|
||||
{"code": "EAA", "name": "European Accessibility Act", "fullName": "Richtlinie (EU) 2019/882", "type": "eu_directive"},
|
||||
{"code": "DSM", "name": "DSM-Urheberrechtsrichtlinie", "fullName": "Richtlinie (EU) 2019/790", "type": "eu_directive"},
|
||||
{"code": "PLD", "name": "Produkthaftungsrichtlinie", "fullName": "Richtlinie 85/374/EWG", "type": "eu_directive"},
|
||||
{"code": "GPSR", "name": "General Product Safety", "fullName": "Verordnung (EU) 2023/988", "type": "eu_regulation"},
|
||||
{"code": "BSI-TR-03161-1", "name": "BSI-TR Teil 1", "fullName": "BSI TR-03161 Teil 1 - Mobile Anwendungen", "type": "bsi_standard"},
|
||||
{"code": "BSI-TR-03161-2", "name": "BSI-TR Teil 2", "fullName": "BSI TR-03161 Teil 2 - Web-Anwendungen", "type": "bsi_standard"},
|
||||
{"code": "BSI-TR-03161-3", "name": "BSI-TR Teil 3", "fullName": "BSI TR-03161 Teil 3 - Hintergrundsysteme", "type": "bsi_standard"},
|
||||
]
|
||||
|
||||
# Ingestion state (in-memory for now)
|
||||
ingestion_state = {
|
||||
"running": False,
|
||||
"completed": False,
|
||||
"current_regulation": None,
|
||||
"processed": 0,
|
||||
"total": len(REGULATIONS),
|
||||
"error": None,
|
||||
}
|
||||
|
||||
|
||||
class SearchRequest(BaseModel):
|
||||
query: str
|
||||
regulations: Optional[List[str]] = None
|
||||
top_k: int = 5
|
||||
|
||||
|
||||
class IngestRequest(BaseModel):
|
||||
force: bool = False
|
||||
regulations: Optional[List[str]] = None
|
||||
|
||||
|
||||
class AddLinkRequest(BaseModel):
|
||||
url: str
|
||||
title: str
|
||||
code: str
|
||||
document_type: str = "custom"
|
||||
|
||||
|
||||
# Store for custom documents (in-memory for now)
|
||||
custom_documents: List[Dict[str, Any]] = []
|
||||
|
||||
|
||||
router = APIRouter(prefix="/api/v1/admin/legal-corpus", tags=["legal-corpus"])
|
||||
|
||||
|
||||
@router.get("/status")
|
||||
async def get_legal_corpus_status():
|
||||
"""Get status of the legal corpus collection including chunk counts per regulation."""
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
try:
|
||||
collection_res = await client.get(f"{QDRANT_URL}/collections/{COLLECTION_NAME}")
|
||||
if collection_res.status_code != 200:
|
||||
return {
|
||||
"collection": COLLECTION_NAME,
|
||||
"totalPoints": 0,
|
||||
"vectorSize": 1024,
|
||||
"status": "not_found",
|
||||
"regulations": {},
|
||||
}
|
||||
|
||||
collection_data = collection_res.json()
|
||||
result = collection_data.get("result", {})
|
||||
|
||||
regulation_counts = {}
|
||||
for reg in REGULATIONS:
|
||||
count_res = await client.post(
|
||||
f"{QDRANT_URL}/collections/{COLLECTION_NAME}/points/count",
|
||||
json={
|
||||
"filter": {
|
||||
"must": [{"key": "regulation_code", "match": {"value": reg["code"]}}]
|
||||
}
|
||||
},
|
||||
)
|
||||
if count_res.status_code == 200:
|
||||
count_data = count_res.json()
|
||||
regulation_counts[reg["code"]] = count_data.get("result", {}).get("count", 0)
|
||||
else:
|
||||
regulation_counts[reg["code"]] = 0
|
||||
|
||||
return {
|
||||
"collection": COLLECTION_NAME,
|
||||
"totalPoints": result.get("points_count", 0),
|
||||
"vectorSize": result.get("config", {}).get("params", {}).get("vectors", {}).get("size", 1024),
|
||||
"status": result.get("status", "unknown"),
|
||||
"regulations": regulation_counts,
|
||||
}
|
||||
|
||||
except httpx.RequestError as e:
|
||||
logger.error(f"Failed to get Qdrant status: {e}")
|
||||
raise HTTPException(status_code=503, detail=f"Qdrant not available: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/search")
|
||||
async def search_legal_corpus(
|
||||
query: str = Query(..., description="Search query"),
|
||||
top_k: int = Query(5, ge=1, le=20, description="Number of results"),
|
||||
regulations: Optional[str] = Query(None, description="Comma-separated regulation codes to filter"),
|
||||
):
|
||||
"""Semantic search in legal corpus using BGE-M3 embeddings."""
|
||||
async with httpx.AsyncClient(timeout=60.0) as client:
|
||||
try:
|
||||
embed_res = await client.post(
|
||||
f"{EMBEDDING_SERVICE_URL}/embed",
|
||||
json={"texts": [query]},
|
||||
)
|
||||
if embed_res.status_code != 200:
|
||||
raise HTTPException(status_code=500, detail="Embedding service error")
|
||||
|
||||
embed_data = embed_res.json()
|
||||
query_vector = embed_data["embeddings"][0]
|
||||
|
||||
search_request = {
|
||||
"vector": query_vector,
|
||||
"limit": top_k,
|
||||
"with_payload": True,
|
||||
}
|
||||
|
||||
if regulations:
|
||||
reg_codes = [r.strip() for r in regulations.split(",")]
|
||||
search_request["filter"] = {
|
||||
"should": [
|
||||
{"key": "regulation_code", "match": {"value": code}}
|
||||
for code in reg_codes
|
||||
]
|
||||
}
|
||||
|
||||
search_res = await client.post(
|
||||
f"{QDRANT_URL}/collections/{COLLECTION_NAME}/points/search",
|
||||
json=search_request,
|
||||
)
|
||||
|
||||
if search_res.status_code != 200:
|
||||
raise HTTPException(status_code=500, detail="Search failed")
|
||||
|
||||
search_data = search_res.json()
|
||||
results = []
|
||||
for point in search_data.get("result", []):
|
||||
payload = point.get("payload", {})
|
||||
results.append({
|
||||
"text": payload.get("text", ""),
|
||||
"regulation_code": payload.get("regulation_code", ""),
|
||||
"regulation_name": payload.get("regulation_name", ""),
|
||||
"article": payload.get("article"),
|
||||
"paragraph": payload.get("paragraph"),
|
||||
"source_url": payload.get("source_url", ""),
|
||||
"score": point.get("score", 0),
|
||||
})
|
||||
|
||||
return {"results": results, "query": query, "count": len(results)}
|
||||
|
||||
except httpx.RequestError as e:
|
||||
logger.error(f"Search failed: {e}")
|
||||
raise HTTPException(status_code=503, detail=f"Service not available: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/ingest")
|
||||
async def trigger_ingestion(request: IngestRequest, background_tasks: BackgroundTasks):
|
||||
"""Trigger legal corpus ingestion in background."""
|
||||
global ingestion_state
|
||||
|
||||
if ingestion_state["running"]:
|
||||
raise HTTPException(status_code=409, detail="Ingestion already running")
|
||||
|
||||
ingestion_state = {
|
||||
"running": True,
|
||||
"completed": False,
|
||||
"current_regulation": None,
|
||||
"processed": 0,
|
||||
"total": len(REGULATIONS),
|
||||
"error": None,
|
||||
}
|
||||
|
||||
background_tasks.add_task(run_ingestion, request.force, request.regulations, ingestion_state, REGULATIONS)
|
||||
|
||||
return {
|
||||
"status": "started",
|
||||
"job_id": "manual-trigger",
|
||||
"message": f"Ingestion started for {len(REGULATIONS)} regulations",
|
||||
}
|
||||
|
||||
|
||||
@router.get("/ingestion-status")
|
||||
async def get_ingestion_status():
|
||||
"""Get current ingestion status."""
|
||||
return ingestion_state
|
||||
|
||||
|
||||
@router.get("/regulations")
|
||||
async def get_regulations():
|
||||
"""Get list of all supported regulations."""
|
||||
return {"regulations": REGULATIONS}
|
||||
|
||||
|
||||
@router.get("/custom-documents")
|
||||
async def get_custom_documents():
|
||||
"""Get list of custom documents added by user."""
|
||||
return {"documents": custom_documents}
|
||||
|
||||
|
||||
@router.post("/upload")
|
||||
async def upload_document(
|
||||
background_tasks: BackgroundTasks,
|
||||
file: UploadFile = File(...),
|
||||
title: str = Form(...),
|
||||
code: str = Form(...),
|
||||
document_type: str = Form("custom"),
|
||||
):
|
||||
"""Upload a document (PDF) for ingestion into the legal corpus."""
|
||||
global custom_documents
|
||||
|
||||
if not file.filename.endswith(('.pdf', '.PDF')):
|
||||
raise HTTPException(status_code=400, detail="Only PDF files are supported")
|
||||
|
||||
upload_dir = "/tmp/legal_corpus_uploads"
|
||||
os.makedirs(upload_dir, exist_ok=True)
|
||||
|
||||
doc_id = str(uuid.uuid4())[:8]
|
||||
safe_filename = f"{doc_id}_{file.filename}"
|
||||
file_path = os.path.join(upload_dir, safe_filename)
|
||||
|
||||
try:
|
||||
with open(file_path, "wb") as buffer:
|
||||
shutil.copyfileobj(file.file, buffer)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save uploaded file: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"Failed to save file: {str(e)}")
|
||||
|
||||
doc_record = {
|
||||
"id": doc_id,
|
||||
"code": code,
|
||||
"title": title,
|
||||
"filename": file.filename,
|
||||
"file_path": file_path,
|
||||
"document_type": document_type,
|
||||
"uploaded_at": datetime.now().isoformat(),
|
||||
"status": "uploaded",
|
||||
"chunk_count": 0,
|
||||
}
|
||||
|
||||
custom_documents.append(doc_record)
|
||||
background_tasks.add_task(ingest_uploaded_document, doc_record)
|
||||
|
||||
return {
|
||||
"status": "uploaded",
|
||||
"document_id": doc_id,
|
||||
"message": f"Document '{title}' uploaded and queued for ingestion",
|
||||
"document": doc_record,
|
||||
}
|
||||
|
||||
|
||||
|
||||
@router.post("/add-link")
|
||||
async def add_link(request: AddLinkRequest, background_tasks: BackgroundTasks):
|
||||
"""Add a URL/link for ingestion into the legal corpus."""
|
||||
global custom_documents
|
||||
|
||||
doc_id = str(uuid.uuid4())[:8]
|
||||
doc_record = {
|
||||
"id": doc_id,
|
||||
"code": request.code,
|
||||
"title": request.title,
|
||||
"url": request.url,
|
||||
"document_type": request.document_type,
|
||||
"uploaded_at": datetime.now().isoformat(),
|
||||
"status": "queued",
|
||||
"chunk_count": 0,
|
||||
}
|
||||
|
||||
custom_documents.append(doc_record)
|
||||
background_tasks.add_task(ingest_link_document, doc_record)
|
||||
|
||||
return {
|
||||
"status": "queued",
|
||||
"document_id": doc_id,
|
||||
"message": f"Link '{request.title}' queued for ingestion",
|
||||
"document": doc_record,
|
||||
}
|
||||
|
||||
|
||||
|
||||
@router.delete("/custom-documents/{doc_id}")
|
||||
async def delete_custom_document(doc_id: str):
|
||||
"""Delete a custom document from the list."""
|
||||
global custom_documents
|
||||
|
||||
doc = next((d for d in custom_documents if d["id"] == doc_id), None)
|
||||
if not doc:
|
||||
raise HTTPException(status_code=404, detail="Document not found")
|
||||
|
||||
custom_documents = [d for d in custom_documents if d["id"] != doc_id]
|
||||
|
||||
return {"status": "deleted", "document_id": doc_id}
|
||||
|
||||
|
||||
@router.get("/traceability")
|
||||
async def get_traceability(
|
||||
chunk_id: str = Query(..., description="Chunk ID or identifier"),
|
||||
regulation: str = Query(..., description="Regulation code"),
|
||||
):
|
||||
"""Get traceability information for a specific chunk."""
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
try:
|
||||
return {
|
||||
"chunk_id": chunk_id,
|
||||
"regulation": regulation,
|
||||
"requirements": [],
|
||||
"controls": [],
|
||||
"message": "Traceability-Daten werden verfuegbar sein, sobald die Requirements-Extraktion und Control-Ableitung implementiert sind."
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get traceability: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"Traceability lookup failed: {str(e)}")
|
||||
@@ -0,0 +1,269 @@
|
||||
"""
|
||||
AI Email - Category Classification and Response Suggestions
|
||||
|
||||
Rule-based and LLM-based email category classification,
|
||||
plus response suggestion generation.
|
||||
|
||||
Extracted from ai_service.py to keep files under 500 LOC.
|
||||
"""
|
||||
|
||||
import os
|
||||
import logging
|
||||
from typing import Optional, List, Tuple
|
||||
|
||||
import httpx
|
||||
|
||||
from .models import (
|
||||
EmailCategory,
|
||||
SenderType,
|
||||
ResponseSuggestion,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# LLM Gateway configuration
|
||||
LLM_GATEWAY_URL = os.getenv("LLM_GATEWAY_URL", "http://localhost:8090")
|
||||
|
||||
|
||||
async def classify_category(
|
||||
http_client: httpx.AsyncClient,
|
||||
subject: str,
|
||||
body_preview: str,
|
||||
sender_type: SenderType,
|
||||
) -> Tuple[EmailCategory, float]:
|
||||
"""
|
||||
Classify email into a category.
|
||||
|
||||
Rule-based classification first, falls back to LLM.
|
||||
"""
|
||||
category, confidence = _classify_category_rules(subject, body_preview, sender_type)
|
||||
|
||||
if confidence > 0.7:
|
||||
return category, confidence
|
||||
|
||||
return await _classify_category_llm(http_client, subject, body_preview)
|
||||
|
||||
|
||||
def _classify_category_rules(
|
||||
subject: str,
|
||||
body_preview: str,
|
||||
sender_type: SenderType,
|
||||
) -> Tuple[EmailCategory, float]:
|
||||
"""Rule-based category classification."""
|
||||
text = f"{subject} {body_preview}".lower()
|
||||
|
||||
category_keywords = {
|
||||
EmailCategory.DIENSTLICH: [
|
||||
"dienstlich", "dienstanweisung", "erlass", "verordnung",
|
||||
"bescheid", "verfuegung", "ministerium", "behoerde"
|
||||
],
|
||||
EmailCategory.PERSONAL: [
|
||||
"personalrat", "stellenausschreibung", "versetzung",
|
||||
"beurteilung", "dienstzeugnis", "krankmeldung", "elternzeit"
|
||||
],
|
||||
EmailCategory.FINANZEN: [
|
||||
"budget", "haushalt", "etat", "abrechnung", "rechnung",
|
||||
"erstattung", "zuschuss", "foerdermittel"
|
||||
],
|
||||
EmailCategory.ELTERN: [
|
||||
"elternbrief", "elternabend", "schulkonferenz",
|
||||
"elternvertreter", "elternbeirat"
|
||||
],
|
||||
EmailCategory.SCHUELER: [
|
||||
"schueler", "schuelerin", "zeugnis", "klasse", "unterricht",
|
||||
"pruefung", "klassenfahrt", "schulpflicht"
|
||||
],
|
||||
EmailCategory.FORTBILDUNG: [
|
||||
"fortbildung", "seminar", "workshop", "schulung",
|
||||
"weiterbildung", "nlq", "didaktik"
|
||||
],
|
||||
EmailCategory.VERANSTALTUNG: [
|
||||
"einladung", "veranstaltung", "termin", "konferenz",
|
||||
"sitzung", "tagung", "feier"
|
||||
],
|
||||
EmailCategory.SICHERHEIT: [
|
||||
"sicherheit", "notfall", "brandschutz", "evakuierung",
|
||||
"hygiene", "corona", "infektionsschutz"
|
||||
],
|
||||
EmailCategory.TECHNIK: [
|
||||
"it", "software", "computer", "netzwerk", "login",
|
||||
"passwort", "digitalisierung", "iserv"
|
||||
],
|
||||
EmailCategory.NEWSLETTER: [
|
||||
"newsletter", "rundschreiben", "info-mail", "mitteilung"
|
||||
],
|
||||
EmailCategory.WERBUNG: [
|
||||
"angebot", "rabatt", "aktion", "werbung", "abonnement"
|
||||
],
|
||||
}
|
||||
|
||||
best_category = EmailCategory.SONSTIGES
|
||||
best_score = 0.0
|
||||
|
||||
for category, keywords in category_keywords.items():
|
||||
score = sum(1 for kw in keywords if kw in text)
|
||||
if score > best_score:
|
||||
best_score = score
|
||||
best_category = category
|
||||
|
||||
if sender_type in [SenderType.KULTUSMINISTERIUM, SenderType.LANDESSCHULBEHOERDE, SenderType.RLSB]:
|
||||
if best_category == EmailCategory.SONSTIGES:
|
||||
best_category = EmailCategory.DIENSTLICH
|
||||
best_score = 2
|
||||
|
||||
confidence = min(0.9, 0.4 + (best_score * 0.15))
|
||||
|
||||
return best_category, confidence
|
||||
|
||||
|
||||
async def _classify_category_llm(
|
||||
client: httpx.AsyncClient,
|
||||
subject: str,
|
||||
body_preview: str,
|
||||
) -> Tuple[EmailCategory, float]:
|
||||
"""LLM-based category classification."""
|
||||
try:
|
||||
categories = ", ".join([c.value for c in EmailCategory])
|
||||
|
||||
prompt = f"""Klassifiziere diese E-Mail in EINE Kategorie:
|
||||
|
||||
Betreff: {subject}
|
||||
Inhalt: {body_preview[:500]}
|
||||
|
||||
Kategorien: {categories}
|
||||
|
||||
Antworte NUR mit dem Kategorienamen und einer Konfidenz (0.0-1.0):
|
||||
Format: kategorie|konfidenz
|
||||
"""
|
||||
|
||||
response = await client.post(
|
||||
f"{LLM_GATEWAY_URL}/api/v1/inference",
|
||||
json={
|
||||
"prompt": prompt,
|
||||
"playbook": "mail_analysis",
|
||||
"max_tokens": 50,
|
||||
},
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
result = data.get("response", "sonstiges|0.5")
|
||||
parts = result.strip().split("|")
|
||||
|
||||
if len(parts) >= 2:
|
||||
category_str = parts[0].strip().lower()
|
||||
confidence = float(parts[1].strip())
|
||||
|
||||
try:
|
||||
category = EmailCategory(category_str)
|
||||
return category, min(max(confidence, 0.0), 1.0)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"LLM category classification failed: {e}")
|
||||
|
||||
return EmailCategory.SONSTIGES, 0.5
|
||||
|
||||
|
||||
async def suggest_response(
|
||||
http_client: httpx.AsyncClient,
|
||||
subject: str,
|
||||
body_text: str,
|
||||
sender_type: SenderType,
|
||||
category: EmailCategory,
|
||||
) -> List[ResponseSuggestion]:
|
||||
"""Generate response suggestions for an email."""
|
||||
suggestions = []
|
||||
|
||||
if sender_type in [SenderType.KULTUSMINISTERIUM, SenderType.LANDESSCHULBEHOERDE, SenderType.RLSB]:
|
||||
suggestions.append(ResponseSuggestion(
|
||||
template_type="acknowledgment",
|
||||
subject=f"Re: {subject}",
|
||||
body="""Sehr geehrte Damen und Herren,
|
||||
|
||||
vielen Dank fuer Ihre Nachricht.
|
||||
|
||||
Ich bestaetige den Eingang und werde die Angelegenheit fristgerecht bearbeiten.
|
||||
|
||||
Mit freundlichen Gruessen""",
|
||||
confidence=0.8,
|
||||
))
|
||||
|
||||
if category == EmailCategory.ELTERN:
|
||||
suggestions.append(ResponseSuggestion(
|
||||
template_type="parent_response",
|
||||
subject=f"Re: {subject}",
|
||||
body="""Liebe Eltern,
|
||||
|
||||
vielen Dank fuer Ihre Nachricht.
|
||||
|
||||
[Ihre Antwort hier]
|
||||
|
||||
Mit freundlichen Gruessen""",
|
||||
confidence=0.7,
|
||||
))
|
||||
|
||||
try:
|
||||
llm_suggestion = await _generate_response_llm(http_client, subject, body_text[:500], sender_type)
|
||||
if llm_suggestion:
|
||||
suggestions.append(llm_suggestion)
|
||||
except Exception as e:
|
||||
logger.warning(f"LLM response generation failed: {e}")
|
||||
|
||||
return suggestions
|
||||
|
||||
|
||||
async def _generate_response_llm(
|
||||
client: httpx.AsyncClient,
|
||||
subject: str,
|
||||
body_preview: str,
|
||||
sender_type: SenderType,
|
||||
) -> Optional[ResponseSuggestion]:
|
||||
"""Generate a response suggestion using LLM."""
|
||||
try:
|
||||
sender_desc = {
|
||||
SenderType.KULTUSMINISTERIUM: "dem Kultusministerium",
|
||||
SenderType.LANDESSCHULBEHOERDE: "der Landesschulbehoerde",
|
||||
SenderType.RLSB: "dem RLSB",
|
||||
SenderType.ELTERNVERTRETER: "einem Elternvertreter",
|
||||
}.get(sender_type, "einem Absender")
|
||||
|
||||
prompt = f"""Du bist eine Schulleiterin in Niedersachsen. Formuliere eine professionelle, kurze Antwort auf diese E-Mail von {sender_desc}:
|
||||
|
||||
Betreff: {subject}
|
||||
Inhalt: {body_preview}
|
||||
|
||||
Die Antwort sollte:
|
||||
- Hoeflich und formell sein
|
||||
- Den Eingang bestaetigen
|
||||
- Eine konkrete naechste Aktion nennen oder um Klaerung bitten
|
||||
|
||||
Antworte NUR mit dem Antworttext (ohne Betreffzeile, ohne "Betreff:").
|
||||
"""
|
||||
|
||||
response = await client.post(
|
||||
f"{LLM_GATEWAY_URL}/api/v1/inference",
|
||||
json={
|
||||
"prompt": prompt,
|
||||
"playbook": "mail_analysis",
|
||||
"max_tokens": 300,
|
||||
},
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
body = data.get("response", "").strip()
|
||||
|
||||
if body:
|
||||
return ResponseSuggestion(
|
||||
template_type="ai_generated",
|
||||
subject=f"Re: {subject}",
|
||||
body=body,
|
||||
confidence=0.6,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"LLM response generation failed: {e}")
|
||||
|
||||
return None
|
||||
@@ -0,0 +1,184 @@
|
||||
"""
|
||||
AI Email - Deadline Extraction
|
||||
|
||||
Regex-based and LLM-based deadline extraction from email content.
|
||||
|
||||
Extracted from ai_service.py to keep files under 500 LOC.
|
||||
"""
|
||||
|
||||
import os
|
||||
import re
|
||||
import logging
|
||||
from typing import List
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
import httpx
|
||||
|
||||
from .models import DeadlineExtraction
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# LLM Gateway configuration
|
||||
LLM_GATEWAY_URL = os.getenv("LLM_GATEWAY_URL", "http://localhost:8090")
|
||||
|
||||
|
||||
async def extract_deadlines(
|
||||
http_client: httpx.AsyncClient,
|
||||
subject: str,
|
||||
body_text: str,
|
||||
) -> List[DeadlineExtraction]:
|
||||
"""
|
||||
Extract deadlines from email content.
|
||||
|
||||
Uses regex patterns first, then LLM for complex cases.
|
||||
"""
|
||||
deadlines = []
|
||||
|
||||
full_text = f"{subject}\n{body_text}" if body_text else subject
|
||||
|
||||
# Try regex extraction first
|
||||
regex_deadlines = _extract_deadlines_regex(full_text)
|
||||
deadlines.extend(regex_deadlines)
|
||||
|
||||
# If no regex matches, try LLM
|
||||
if not deadlines and body_text:
|
||||
llm_deadlines = await _extract_deadlines_llm(http_client, subject, body_text[:1000])
|
||||
deadlines.extend(llm_deadlines)
|
||||
|
||||
return deadlines
|
||||
|
||||
|
||||
def _extract_deadlines_regex(text: str) -> List[DeadlineExtraction]:
|
||||
"""Extract deadlines using regex patterns."""
|
||||
deadlines = []
|
||||
now = datetime.now()
|
||||
|
||||
# German date patterns
|
||||
patterns = [
|
||||
# "bis zum 15.01.2025"
|
||||
(r"bis\s+(?:zum\s+)?(\d{1,2})\.(\d{1,2})\.(\d{2,4})", True),
|
||||
# "spaetestens am 15.01.2025"
|
||||
(r"sp\u00e4testens\s+(?:am\s+)?(\d{1,2})\.(\d{1,2})\.(\d{2,4})", True),
|
||||
# "Abgabetermin: 15.01.2025"
|
||||
(r"(?:Abgabe|Termin|Frist)[:\s]+(\d{1,2})\.(\d{1,2})\.(\d{2,4})", True),
|
||||
# "innerhalb von 14 Tagen"
|
||||
(r"innerhalb\s+von\s+(\d+)\s+(?:Tagen|Wochen)", False),
|
||||
# "bis Ende Januar"
|
||||
(r"bis\s+(?:Ende\s+)?(Januar|Februar|M\u00e4rz|April|Mai|Juni|Juli|August|September|Oktober|November|Dezember)", False),
|
||||
]
|
||||
|
||||
for pattern, is_specific_date in patterns:
|
||||
matches = re.finditer(pattern, text, re.IGNORECASE)
|
||||
|
||||
for match in matches:
|
||||
try:
|
||||
if is_specific_date:
|
||||
day = int(match.group(1))
|
||||
month = int(match.group(2))
|
||||
year = int(match.group(3))
|
||||
|
||||
if year < 100:
|
||||
year += 2000
|
||||
|
||||
deadline_date = datetime(year, month, day)
|
||||
|
||||
if deadline_date < now:
|
||||
continue
|
||||
|
||||
start = max(0, match.start() - 50)
|
||||
end = min(len(text), match.end() + 50)
|
||||
context = text[start:end].strip()
|
||||
|
||||
deadlines.append(DeadlineExtraction(
|
||||
deadline_date=deadline_date,
|
||||
description=f"Frist: {match.group(0)}",
|
||||
confidence=0.85,
|
||||
source_text=context,
|
||||
is_firm=True,
|
||||
))
|
||||
|
||||
else:
|
||||
if "Tagen" in pattern or "Wochen" in pattern:
|
||||
days = int(match.group(1))
|
||||
if "Wochen" in match.group(0).lower():
|
||||
days *= 7
|
||||
deadline_date = now + timedelta(days=days)
|
||||
|
||||
deadlines.append(DeadlineExtraction(
|
||||
deadline_date=deadline_date,
|
||||
description=f"Relative Frist: {match.group(0)}",
|
||||
confidence=0.7,
|
||||
source_text=match.group(0),
|
||||
is_firm=False,
|
||||
))
|
||||
|
||||
except (ValueError, IndexError) as e:
|
||||
logger.debug(f"Failed to parse date: {e}")
|
||||
continue
|
||||
|
||||
return deadlines
|
||||
|
||||
|
||||
async def _extract_deadlines_llm(
|
||||
client: httpx.AsyncClient,
|
||||
subject: str,
|
||||
body_preview: str,
|
||||
) -> List[DeadlineExtraction]:
|
||||
"""Extract deadlines using LLM."""
|
||||
try:
|
||||
prompt = f"""Analysiere diese E-Mail und extrahiere alle genannten Fristen und Termine:
|
||||
|
||||
Betreff: {subject}
|
||||
Inhalt: {body_preview}
|
||||
|
||||
Liste alle Fristen im folgenden Format auf (eine pro Zeile):
|
||||
DATUM|BESCHREIBUNG|VERBINDLICH
|
||||
Beispiel: 2025-01-15|Abgabe der Berichte|ja
|
||||
|
||||
Wenn keine Fristen gefunden werden, antworte mit: KEINE_FRISTEN
|
||||
|
||||
Antworte NUR im angegebenen Format.
|
||||
"""
|
||||
|
||||
response = await client.post(
|
||||
f"{LLM_GATEWAY_URL}/api/v1/inference",
|
||||
json={
|
||||
"prompt": prompt,
|
||||
"playbook": "mail_analysis",
|
||||
"max_tokens": 200,
|
||||
},
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
result_text = data.get("response", "")
|
||||
|
||||
if "KEINE_FRISTEN" in result_text:
|
||||
return []
|
||||
|
||||
deadlines = []
|
||||
for line in result_text.strip().split("\n"):
|
||||
parts = line.split("|")
|
||||
if len(parts) >= 2:
|
||||
try:
|
||||
date_str = parts[0].strip()
|
||||
deadline_date = datetime.fromisoformat(date_str)
|
||||
description = parts[1].strip()
|
||||
is_firm = parts[2].strip().lower() == "ja" if len(parts) > 2 else True
|
||||
|
||||
deadlines.append(DeadlineExtraction(
|
||||
deadline_date=deadline_date,
|
||||
description=description,
|
||||
confidence=0.7,
|
||||
source_text=line,
|
||||
is_firm=is_firm,
|
||||
))
|
||||
except (ValueError, IndexError):
|
||||
continue
|
||||
|
||||
return deadlines
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"LLM deadline extraction failed: {e}")
|
||||
|
||||
return []
|
||||
@@ -0,0 +1,134 @@
|
||||
"""
|
||||
AI Email - Sender Classification
|
||||
|
||||
Domain-based and LLM-based sender classification for emails.
|
||||
|
||||
Extracted from ai_service.py to keep files under 500 LOC.
|
||||
"""
|
||||
|
||||
import os
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
import httpx
|
||||
|
||||
from .models import (
|
||||
SenderType,
|
||||
SenderClassification,
|
||||
classify_sender_by_domain,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# LLM Gateway configuration
|
||||
LLM_GATEWAY_URL = os.getenv("LLM_GATEWAY_URL", "http://localhost:8090")
|
||||
|
||||
|
||||
async def classify_sender(
|
||||
http_client: httpx.AsyncClient,
|
||||
sender_email: str,
|
||||
sender_name: Optional[str] = None,
|
||||
subject: Optional[str] = None,
|
||||
body_preview: Optional[str] = None,
|
||||
) -> SenderClassification:
|
||||
"""
|
||||
Classify the sender of an email.
|
||||
|
||||
First tries domain matching, then falls back to LLM.
|
||||
"""
|
||||
# Try domain-based classification first (fast, high confidence)
|
||||
domain_result = classify_sender_by_domain(sender_email)
|
||||
if domain_result:
|
||||
return domain_result
|
||||
|
||||
# Fall back to LLM classification
|
||||
return await _classify_sender_llm(
|
||||
http_client, sender_email, sender_name, subject, body_preview
|
||||
)
|
||||
|
||||
|
||||
async def _classify_sender_llm(
|
||||
client: httpx.AsyncClient,
|
||||
sender_email: str,
|
||||
sender_name: Optional[str],
|
||||
subject: Optional[str],
|
||||
body_preview: Optional[str],
|
||||
) -> SenderClassification:
|
||||
"""Classify sender using LLM."""
|
||||
try:
|
||||
prompt = f"""Analysiere den Absender dieser E-Mail und klassifiziere ihn:
|
||||
|
||||
Absender E-Mail: {sender_email}
|
||||
Absender Name: {sender_name or "Nicht angegeben"}
|
||||
Betreff: {subject or "Nicht angegeben"}
|
||||
Vorschau: {body_preview[:200] if body_preview else "Nicht verfuegbar"}
|
||||
|
||||
Klassifiziere den Absender in EINE der folgenden Kategorien:
|
||||
- kultusministerium: Kultusministerium/Bildungsministerium
|
||||
- landesschulbehoerde: Landesschulbehoerde
|
||||
- rlsb: Regionales Landesamt fuer Schule und Bildung
|
||||
- schulamt: Schulamt
|
||||
- nibis: Niedersaechsischer Bildungsserver
|
||||
- schultraeger: Schultraeger/Kommune
|
||||
- elternvertreter: Elternvertreter/Elternrat
|
||||
- gewerkschaft: Gewerkschaft (GEW, VBE, etc.)
|
||||
- fortbildungsinstitut: Fortbildungsinstitut (NLQ, etc.)
|
||||
- privatperson: Privatperson
|
||||
- unternehmen: Unternehmen/Firma
|
||||
- unbekannt: Nicht einzuordnen
|
||||
|
||||
Antworte NUR mit dem Kategorienamen (z.B. "kultusministerium") und einer Konfidenz von 0.0 bis 1.0.
|
||||
Format: kategorie|konfidenz|kurze_begruendung
|
||||
"""
|
||||
|
||||
response = await client.post(
|
||||
f"{LLM_GATEWAY_URL}/api/v1/inference",
|
||||
json={
|
||||
"prompt": prompt,
|
||||
"playbook": "mail_analysis",
|
||||
"max_tokens": 100,
|
||||
},
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
result_text = data.get("response", "unbekannt|0.5|")
|
||||
|
||||
parts = result_text.strip().split("|")
|
||||
if len(parts) >= 2:
|
||||
sender_type_str = parts[0].strip().lower()
|
||||
confidence = float(parts[1].strip())
|
||||
|
||||
type_mapping = {
|
||||
"kultusministerium": SenderType.KULTUSMINISTERIUM,
|
||||
"landesschulbehoerde": SenderType.LANDESSCHULBEHOERDE,
|
||||
"rlsb": SenderType.RLSB,
|
||||
"schulamt": SenderType.SCHULAMT,
|
||||
"nibis": SenderType.NIBIS,
|
||||
"schultraeger": SenderType.SCHULTRAEGER,
|
||||
"elternvertreter": SenderType.ELTERNVERTRETER,
|
||||
"gewerkschaft": SenderType.GEWERKSCHAFT,
|
||||
"fortbildungsinstitut": SenderType.FORTBILDUNGSINSTITUT,
|
||||
"privatperson": SenderType.PRIVATPERSON,
|
||||
"unternehmen": SenderType.UNTERNEHMEN,
|
||||
}
|
||||
|
||||
sender_type = type_mapping.get(sender_type_str, SenderType.UNBEKANNT)
|
||||
|
||||
return SenderClassification(
|
||||
sender_type=sender_type,
|
||||
confidence=min(max(confidence, 0.0), 1.0),
|
||||
domain_matched=False,
|
||||
ai_classified=True,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"LLM sender classification failed: {e}")
|
||||
|
||||
# Default fallback
|
||||
return SenderClassification(
|
||||
sender_type=SenderType.UNBEKANNT,
|
||||
confidence=0.3,
|
||||
domain_matched=False,
|
||||
ai_classified=False,
|
||||
)
|
||||
@@ -1,18 +1,19 @@
|
||||
"""
|
||||
AI Email Analysis Service
|
||||
AI Email Analysis Service — Barrel Re-export
|
||||
|
||||
KI-powered email analysis with:
|
||||
- Sender classification (authority recognition)
|
||||
- Deadline extraction
|
||||
- Category classification
|
||||
- Response suggestions
|
||||
Split into:
|
||||
- mail/ai_sender.py — Sender classification (domain + LLM)
|
||||
- mail/ai_deadline.py — Deadline extraction (regex + LLM)
|
||||
- mail/ai_category.py — Category classification + response suggestions
|
||||
|
||||
The AIEmailService class and get_ai_email_service() are defined here
|
||||
to maintain the original public API.
|
||||
"""
|
||||
|
||||
import os
|
||||
import re
|
||||
import logging
|
||||
from typing import Optional, List, Dict, Any, Tuple
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional, List, Tuple
|
||||
from datetime import datetime
|
||||
|
||||
import httpx
|
||||
|
||||
from .models import (
|
||||
@@ -23,17 +24,15 @@ from .models import (
|
||||
DeadlineExtraction,
|
||||
EmailAnalysisResult,
|
||||
ResponseSuggestion,
|
||||
KNOWN_AUTHORITIES_NI,
|
||||
classify_sender_by_domain,
|
||||
get_priority_from_sender_type,
|
||||
)
|
||||
from .mail_db import update_email_ai_analysis
|
||||
from .ai_sender import classify_sender, LLM_GATEWAY_URL
|
||||
from .ai_deadline import extract_deadlines
|
||||
from .ai_category import classify_category, suggest_response
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# LLM Gateway configuration
|
||||
LLM_GATEWAY_URL = os.getenv("LLM_GATEWAY_URL", "http://localhost:8090")
|
||||
|
||||
|
||||
class AIEmailService:
|
||||
"""
|
||||
@@ -56,10 +55,6 @@ class AIEmailService:
|
||||
self._http_client = httpx.AsyncClient(timeout=30.0)
|
||||
return self._http_client
|
||||
|
||||
# =========================================================================
|
||||
# Sender Classification
|
||||
# =========================================================================
|
||||
|
||||
async def classify_sender(
|
||||
self,
|
||||
sender_email: str,
|
||||
@@ -67,300 +62,20 @@ class AIEmailService:
|
||||
subject: Optional[str] = None,
|
||||
body_preview: Optional[str] = None,
|
||||
) -> SenderClassification:
|
||||
"""
|
||||
Classify the sender of an email.
|
||||
|
||||
First tries domain matching, then falls back to LLM.
|
||||
|
||||
Args:
|
||||
sender_email: Sender's email address
|
||||
sender_name: Sender's display name
|
||||
subject: Email subject
|
||||
body_preview: First 200 chars of body
|
||||
|
||||
Returns:
|
||||
SenderClassification with type and confidence
|
||||
"""
|
||||
# Try domain-based classification first (fast, high confidence)
|
||||
domain_result = classify_sender_by_domain(sender_email)
|
||||
if domain_result:
|
||||
return domain_result
|
||||
|
||||
# Fall back to LLM classification
|
||||
return await self._classify_sender_llm(
|
||||
sender_email, sender_name, subject, body_preview
|
||||
)
|
||||
|
||||
async def _classify_sender_llm(
|
||||
self,
|
||||
sender_email: str,
|
||||
sender_name: Optional[str],
|
||||
subject: Optional[str],
|
||||
body_preview: Optional[str],
|
||||
) -> SenderClassification:
|
||||
"""Classify sender using LLM."""
|
||||
try:
|
||||
"""Classify the sender of an email."""
|
||||
client = await self.get_http_client()
|
||||
|
||||
prompt = f"""Analysiere den Absender dieser E-Mail und klassifiziere ihn:
|
||||
|
||||
Absender E-Mail: {sender_email}
|
||||
Absender Name: {sender_name or "Nicht angegeben"}
|
||||
Betreff: {subject or "Nicht angegeben"}
|
||||
Vorschau: {body_preview[:200] if body_preview else "Nicht verfügbar"}
|
||||
|
||||
Klassifiziere den Absender in EINE der folgenden Kategorien:
|
||||
- kultusministerium: Kultusministerium/Bildungsministerium
|
||||
- landesschulbehoerde: Landesschulbehörde
|
||||
- rlsb: Regionales Landesamt für Schule und Bildung
|
||||
- schulamt: Schulamt
|
||||
- nibis: Niedersächsischer Bildungsserver
|
||||
- schultraeger: Schulträger/Kommune
|
||||
- elternvertreter: Elternvertreter/Elternrat
|
||||
- gewerkschaft: Gewerkschaft (GEW, VBE, etc.)
|
||||
- fortbildungsinstitut: Fortbildungsinstitut (NLQ, etc.)
|
||||
- privatperson: Privatperson
|
||||
- unternehmen: Unternehmen/Firma
|
||||
- unbekannt: Nicht einzuordnen
|
||||
|
||||
Antworte NUR mit dem Kategorienamen (z.B. "kultusministerium") und einer Konfidenz von 0.0 bis 1.0.
|
||||
Format: kategorie|konfidenz|kurze_begründung
|
||||
"""
|
||||
|
||||
response = await client.post(
|
||||
f"{LLM_GATEWAY_URL}/api/v1/inference",
|
||||
json={
|
||||
"prompt": prompt,
|
||||
"playbook": "mail_analysis",
|
||||
"max_tokens": 100,
|
||||
},
|
||||
return await classify_sender(
|
||||
client, sender_email, sender_name, subject, body_preview
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
result_text = data.get("response", "unbekannt|0.5|")
|
||||
|
||||
# Parse response
|
||||
parts = result_text.strip().split("|")
|
||||
if len(parts) >= 2:
|
||||
sender_type_str = parts[0].strip().lower()
|
||||
confidence = float(parts[1].strip())
|
||||
|
||||
# Map to enum
|
||||
type_mapping = {
|
||||
"kultusministerium": SenderType.KULTUSMINISTERIUM,
|
||||
"landesschulbehoerde": SenderType.LANDESSCHULBEHOERDE,
|
||||
"rlsb": SenderType.RLSB,
|
||||
"schulamt": SenderType.SCHULAMT,
|
||||
"nibis": SenderType.NIBIS,
|
||||
"schultraeger": SenderType.SCHULTRAEGER,
|
||||
"elternvertreter": SenderType.ELTERNVERTRETER,
|
||||
"gewerkschaft": SenderType.GEWERKSCHAFT,
|
||||
"fortbildungsinstitut": SenderType.FORTBILDUNGSINSTITUT,
|
||||
"privatperson": SenderType.PRIVATPERSON,
|
||||
"unternehmen": SenderType.UNTERNEHMEN,
|
||||
}
|
||||
|
||||
sender_type = type_mapping.get(sender_type_str, SenderType.UNBEKANNT)
|
||||
|
||||
return SenderClassification(
|
||||
sender_type=sender_type,
|
||||
confidence=min(max(confidence, 0.0), 1.0),
|
||||
domain_matched=False,
|
||||
ai_classified=True,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"LLM sender classification failed: {e}")
|
||||
|
||||
# Default fallback
|
||||
return SenderClassification(
|
||||
sender_type=SenderType.UNBEKANNT,
|
||||
confidence=0.3,
|
||||
domain_matched=False,
|
||||
ai_classified=False,
|
||||
)
|
||||
|
||||
# =========================================================================
|
||||
# Deadline Extraction
|
||||
# =========================================================================
|
||||
|
||||
async def extract_deadlines(
|
||||
self,
|
||||
subject: str,
|
||||
body_text: str,
|
||||
) -> List[DeadlineExtraction]:
|
||||
"""
|
||||
Extract deadlines from email content.
|
||||
|
||||
Uses regex patterns first, then LLM for complex cases.
|
||||
|
||||
Args:
|
||||
subject: Email subject
|
||||
body_text: Email body text
|
||||
|
||||
Returns:
|
||||
List of extracted deadlines
|
||||
"""
|
||||
deadlines = []
|
||||
|
||||
# Combine subject and body
|
||||
full_text = f"{subject}\n{body_text}" if body_text else subject
|
||||
|
||||
# Try regex extraction first
|
||||
regex_deadlines = self._extract_deadlines_regex(full_text)
|
||||
deadlines.extend(regex_deadlines)
|
||||
|
||||
# If no regex matches, try LLM
|
||||
if not deadlines and body_text:
|
||||
llm_deadlines = await self._extract_deadlines_llm(subject, body_text[:1000])
|
||||
deadlines.extend(llm_deadlines)
|
||||
|
||||
return deadlines
|
||||
|
||||
def _extract_deadlines_regex(self, text: str) -> List[DeadlineExtraction]:
|
||||
"""Extract deadlines using regex patterns."""
|
||||
deadlines = []
|
||||
now = datetime.now()
|
||||
|
||||
# German date patterns
|
||||
patterns = [
|
||||
# "bis zum 15.01.2025"
|
||||
(r"bis\s+(?:zum\s+)?(\d{1,2})\.(\d{1,2})\.(\d{2,4})", True),
|
||||
# "spätestens am 15.01.2025"
|
||||
(r"spätestens\s+(?:am\s+)?(\d{1,2})\.(\d{1,2})\.(\d{2,4})", True),
|
||||
# "Abgabetermin: 15.01.2025"
|
||||
(r"(?:Abgabe|Termin|Frist)[:\s]+(\d{1,2})\.(\d{1,2})\.(\d{2,4})", True),
|
||||
# "innerhalb von 14 Tagen"
|
||||
(r"innerhalb\s+von\s+(\d+)\s+(?:Tagen|Wochen)", False),
|
||||
# "bis Ende Januar"
|
||||
(r"bis\s+(?:Ende\s+)?(Januar|Februar|März|April|Mai|Juni|Juli|August|September|Oktober|November|Dezember)", False),
|
||||
]
|
||||
|
||||
for pattern, is_specific_date in patterns:
|
||||
matches = re.finditer(pattern, text, re.IGNORECASE)
|
||||
|
||||
for match in matches:
|
||||
try:
|
||||
if is_specific_date:
|
||||
day = int(match.group(1))
|
||||
month = int(match.group(2))
|
||||
year = int(match.group(3))
|
||||
|
||||
# Handle 2-digit years
|
||||
if year < 100:
|
||||
year += 2000
|
||||
|
||||
deadline_date = datetime(year, month, day)
|
||||
|
||||
# Skip past dates
|
||||
if deadline_date < now:
|
||||
continue
|
||||
|
||||
# Get surrounding context
|
||||
start = max(0, match.start() - 50)
|
||||
end = min(len(text), match.end() + 50)
|
||||
context = text[start:end].strip()
|
||||
|
||||
deadlines.append(DeadlineExtraction(
|
||||
deadline_date=deadline_date,
|
||||
description=f"Frist: {match.group(0)}",
|
||||
confidence=0.85,
|
||||
source_text=context,
|
||||
is_firm=True,
|
||||
))
|
||||
|
||||
else:
|
||||
# Relative dates (innerhalb von X Tagen)
|
||||
if "Tagen" in pattern or "Wochen" in pattern:
|
||||
days = int(match.group(1))
|
||||
if "Wochen" in match.group(0).lower():
|
||||
days *= 7
|
||||
deadline_date = now + timedelta(days=days)
|
||||
|
||||
deadlines.append(DeadlineExtraction(
|
||||
deadline_date=deadline_date,
|
||||
description=f"Relative Frist: {match.group(0)}",
|
||||
confidence=0.7,
|
||||
source_text=match.group(0),
|
||||
is_firm=False,
|
||||
))
|
||||
|
||||
except (ValueError, IndexError) as e:
|
||||
logger.debug(f"Failed to parse date: {e}")
|
||||
continue
|
||||
|
||||
return deadlines
|
||||
|
||||
async def _extract_deadlines_llm(
|
||||
self,
|
||||
subject: str,
|
||||
body_preview: str,
|
||||
) -> List[DeadlineExtraction]:
|
||||
"""Extract deadlines using LLM."""
|
||||
try:
|
||||
"""Extract deadlines from email content."""
|
||||
client = await self.get_http_client()
|
||||
|
||||
prompt = f"""Analysiere diese E-Mail und extrahiere alle genannten Fristen und Termine:
|
||||
|
||||
Betreff: {subject}
|
||||
Inhalt: {body_preview}
|
||||
|
||||
Liste alle Fristen im folgenden Format auf (eine pro Zeile):
|
||||
DATUM|BESCHREIBUNG|VERBINDLICH
|
||||
Beispiel: 2025-01-15|Abgabe der Berichte|ja
|
||||
|
||||
Wenn keine Fristen gefunden werden, antworte mit: KEINE_FRISTEN
|
||||
|
||||
Antworte NUR im angegebenen Format.
|
||||
"""
|
||||
|
||||
response = await client.post(
|
||||
f"{LLM_GATEWAY_URL}/api/v1/inference",
|
||||
json={
|
||||
"prompt": prompt,
|
||||
"playbook": "mail_analysis",
|
||||
"max_tokens": 200,
|
||||
},
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
result_text = data.get("response", "")
|
||||
|
||||
if "KEINE_FRISTEN" in result_text:
|
||||
return []
|
||||
|
||||
deadlines = []
|
||||
for line in result_text.strip().split("\n"):
|
||||
parts = line.split("|")
|
||||
if len(parts) >= 2:
|
||||
try:
|
||||
date_str = parts[0].strip()
|
||||
deadline_date = datetime.fromisoformat(date_str)
|
||||
description = parts[1].strip()
|
||||
is_firm = parts[2].strip().lower() == "ja" if len(parts) > 2 else True
|
||||
|
||||
deadlines.append(DeadlineExtraction(
|
||||
deadline_date=deadline_date,
|
||||
description=description,
|
||||
confidence=0.7,
|
||||
source_text=line,
|
||||
is_firm=is_firm,
|
||||
))
|
||||
except (ValueError, IndexError):
|
||||
continue
|
||||
|
||||
return deadlines
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"LLM deadline extraction failed: {e}")
|
||||
|
||||
return []
|
||||
|
||||
# =========================================================================
|
||||
# Email Category Classification
|
||||
# =========================================================================
|
||||
return await extract_deadlines(client, subject, body_text)
|
||||
|
||||
async def classify_category(
|
||||
self,
|
||||
@@ -368,155 +83,9 @@ Antworte NUR im angegebenen Format.
|
||||
body_preview: str,
|
||||
sender_type: SenderType,
|
||||
) -> Tuple[EmailCategory, float]:
|
||||
"""
|
||||
Classify email into a category.
|
||||
|
||||
Args:
|
||||
subject: Email subject
|
||||
body_preview: First 200 chars of body
|
||||
sender_type: Already classified sender type
|
||||
|
||||
Returns:
|
||||
Tuple of (category, confidence)
|
||||
"""
|
||||
# Rule-based classification first
|
||||
category, confidence = self._classify_category_rules(subject, body_preview, sender_type)
|
||||
|
||||
if confidence > 0.7:
|
||||
return category, confidence
|
||||
|
||||
# Fall back to LLM
|
||||
return await self._classify_category_llm(subject, body_preview)
|
||||
|
||||
def _classify_category_rules(
|
||||
self,
|
||||
subject: str,
|
||||
body_preview: str,
|
||||
sender_type: SenderType,
|
||||
) -> Tuple[EmailCategory, float]:
|
||||
"""Rule-based category classification."""
|
||||
text = f"{subject} {body_preview}".lower()
|
||||
|
||||
# Keywords for each category
|
||||
category_keywords = {
|
||||
EmailCategory.DIENSTLICH: [
|
||||
"dienstlich", "dienstanweisung", "erlass", "verordnung",
|
||||
"bescheid", "verfügung", "ministerium", "behörde"
|
||||
],
|
||||
EmailCategory.PERSONAL: [
|
||||
"personalrat", "stellenausschreibung", "versetzung",
|
||||
"beurteilung", "dienstzeugnis", "krankmeldung", "elternzeit"
|
||||
],
|
||||
EmailCategory.FINANZEN: [
|
||||
"budget", "haushalt", "etat", "abrechnung", "rechnung",
|
||||
"erstattung", "zuschuss", "fördermittel"
|
||||
],
|
||||
EmailCategory.ELTERN: [
|
||||
"elternbrief", "elternabend", "schulkonferenz",
|
||||
"elternvertreter", "elternbeirat"
|
||||
],
|
||||
EmailCategory.SCHUELER: [
|
||||
"schüler", "schülerin", "zeugnis", "klasse", "unterricht",
|
||||
"prüfung", "klassenfahrt", "schulpflicht"
|
||||
],
|
||||
EmailCategory.FORTBILDUNG: [
|
||||
"fortbildung", "seminar", "workshop", "schulung",
|
||||
"weiterbildung", "nlq", "didaktik"
|
||||
],
|
||||
EmailCategory.VERANSTALTUNG: [
|
||||
"einladung", "veranstaltung", "termin", "konferenz",
|
||||
"sitzung", "tagung", "feier"
|
||||
],
|
||||
EmailCategory.SICHERHEIT: [
|
||||
"sicherheit", "notfall", "brandschutz", "evakuierung",
|
||||
"hygiene", "corona", "infektionsschutz"
|
||||
],
|
||||
EmailCategory.TECHNIK: [
|
||||
"it", "software", "computer", "netzwerk", "login",
|
||||
"passwort", "digitalisierung", "iserv"
|
||||
],
|
||||
EmailCategory.NEWSLETTER: [
|
||||
"newsletter", "rundschreiben", "info-mail", "mitteilung"
|
||||
],
|
||||
EmailCategory.WERBUNG: [
|
||||
"angebot", "rabatt", "aktion", "werbung", "abonnement"
|
||||
],
|
||||
}
|
||||
|
||||
best_category = EmailCategory.SONSTIGES
|
||||
best_score = 0.0
|
||||
|
||||
for category, keywords in category_keywords.items():
|
||||
score = sum(1 for kw in keywords if kw in text)
|
||||
if score > best_score:
|
||||
best_score = score
|
||||
best_category = category
|
||||
|
||||
# Adjust based on sender type
|
||||
if sender_type in [SenderType.KULTUSMINISTERIUM, SenderType.LANDESSCHULBEHOERDE, SenderType.RLSB]:
|
||||
if best_category == EmailCategory.SONSTIGES:
|
||||
best_category = EmailCategory.DIENSTLICH
|
||||
best_score = 2
|
||||
|
||||
# Convert score to confidence
|
||||
confidence = min(0.9, 0.4 + (best_score * 0.15))
|
||||
|
||||
return best_category, confidence
|
||||
|
||||
async def _classify_category_llm(
|
||||
self,
|
||||
subject: str,
|
||||
body_preview: str,
|
||||
) -> Tuple[EmailCategory, float]:
|
||||
"""LLM-based category classification."""
|
||||
try:
|
||||
"""Classify email into a category."""
|
||||
client = await self.get_http_client()
|
||||
|
||||
categories = ", ".join([c.value for c in EmailCategory])
|
||||
|
||||
prompt = f"""Klassifiziere diese E-Mail in EINE Kategorie:
|
||||
|
||||
Betreff: {subject}
|
||||
Inhalt: {body_preview[:500]}
|
||||
|
||||
Kategorien: {categories}
|
||||
|
||||
Antworte NUR mit dem Kategorienamen und einer Konfidenz (0.0-1.0):
|
||||
Format: kategorie|konfidenz
|
||||
"""
|
||||
|
||||
response = await client.post(
|
||||
f"{LLM_GATEWAY_URL}/api/v1/inference",
|
||||
json={
|
||||
"prompt": prompt,
|
||||
"playbook": "mail_analysis",
|
||||
"max_tokens": 50,
|
||||
},
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
result = data.get("response", "sonstiges|0.5")
|
||||
parts = result.strip().split("|")
|
||||
|
||||
if len(parts) >= 2:
|
||||
category_str = parts[0].strip().lower()
|
||||
confidence = float(parts[1].strip())
|
||||
|
||||
try:
|
||||
category = EmailCategory(category_str)
|
||||
return category, min(max(confidence, 0.0), 1.0)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"LLM category classification failed: {e}")
|
||||
|
||||
return EmailCategory.SONSTIGES, 0.5
|
||||
|
||||
# =========================================================================
|
||||
# Full Analysis Pipeline
|
||||
# =========================================================================
|
||||
return await classify_category(client, subject, body_preview, sender_type)
|
||||
|
||||
async def analyze_email(
|
||||
self,
|
||||
@@ -527,20 +96,7 @@ Format: kategorie|konfidenz
|
||||
body_text: Optional[str],
|
||||
body_preview: Optional[str],
|
||||
) -> EmailAnalysisResult:
|
||||
"""
|
||||
Run full analysis pipeline on an email.
|
||||
|
||||
Args:
|
||||
email_id: Database ID of the email
|
||||
sender_email: Sender's email address
|
||||
sender_name: Sender's display name
|
||||
subject: Email subject
|
||||
body_text: Full body text
|
||||
body_preview: Preview text
|
||||
|
||||
Returns:
|
||||
Complete analysis result
|
||||
"""
|
||||
"""Run full analysis pipeline on an email."""
|
||||
# 1. Classify sender
|
||||
sender_classification = await self.classify_sender(
|
||||
sender_email, sender_name, subject, body_preview
|
||||
@@ -569,8 +125,8 @@ Format: kategorie|konfidenz
|
||||
elif days_until <= 7:
|
||||
suggested_priority = max(suggested_priority, TaskPriority.MEDIUM)
|
||||
|
||||
# 5. Generate summary (optional, can be expensive)
|
||||
summary = None # Could add LLM summary generation here
|
||||
# 5. Summary (optional)
|
||||
summary = None
|
||||
|
||||
# 6. Determine if task should be auto-created
|
||||
auto_create_task = (
|
||||
@@ -612,10 +168,6 @@ Format: kategorie|konfidenz
|
||||
auto_create_task=auto_create_task,
|
||||
)
|
||||
|
||||
# =========================================================================
|
||||
# Response Suggestions
|
||||
# =========================================================================
|
||||
|
||||
async def suggest_response(
|
||||
self,
|
||||
subject: str,
|
||||
@@ -623,115 +175,12 @@ Format: kategorie|konfidenz
|
||||
sender_type: SenderType,
|
||||
category: EmailCategory,
|
||||
) -> List[ResponseSuggestion]:
|
||||
"""
|
||||
Generate response suggestions for an email.
|
||||
|
||||
Args:
|
||||
subject: Original email subject
|
||||
body_text: Original email body
|
||||
sender_type: Classified sender type
|
||||
category: Classified category
|
||||
|
||||
Returns:
|
||||
List of response suggestions
|
||||
"""
|
||||
suggestions = []
|
||||
|
||||
# Add standard templates based on sender type and category
|
||||
if sender_type in [SenderType.KULTUSMINISTERIUM, SenderType.LANDESSCHULBEHOERDE, SenderType.RLSB]:
|
||||
suggestions.append(ResponseSuggestion(
|
||||
template_type="acknowledgment",
|
||||
subject=f"Re: {subject}",
|
||||
body="""Sehr geehrte Damen und Herren,
|
||||
|
||||
vielen Dank für Ihre Nachricht.
|
||||
|
||||
Ich bestätige den Eingang und werde die Angelegenheit fristgerecht bearbeiten.
|
||||
|
||||
Mit freundlichen Grüßen""",
|
||||
confidence=0.8,
|
||||
))
|
||||
|
||||
if category == EmailCategory.ELTERN:
|
||||
suggestions.append(ResponseSuggestion(
|
||||
template_type="parent_response",
|
||||
subject=f"Re: {subject}",
|
||||
body="""Liebe Eltern,
|
||||
|
||||
vielen Dank für Ihre Nachricht.
|
||||
|
||||
[Ihre Antwort hier]
|
||||
|
||||
Mit freundlichen Grüßen""",
|
||||
confidence=0.7,
|
||||
))
|
||||
|
||||
# Add LLM-generated suggestion
|
||||
try:
|
||||
llm_suggestion = await self._generate_response_llm(subject, body_text[:500], sender_type)
|
||||
if llm_suggestion:
|
||||
suggestions.append(llm_suggestion)
|
||||
except Exception as e:
|
||||
logger.warning(f"LLM response generation failed: {e}")
|
||||
|
||||
return suggestions
|
||||
|
||||
async def _generate_response_llm(
|
||||
self,
|
||||
subject: str,
|
||||
body_preview: str,
|
||||
sender_type: SenderType,
|
||||
) -> Optional[ResponseSuggestion]:
|
||||
"""Generate a response suggestion using LLM."""
|
||||
try:
|
||||
"""Generate response suggestions for an email."""
|
||||
client = await self.get_http_client()
|
||||
|
||||
sender_desc = {
|
||||
SenderType.KULTUSMINISTERIUM: "dem Kultusministerium",
|
||||
SenderType.LANDESSCHULBEHOERDE: "der Landesschulbehörde",
|
||||
SenderType.RLSB: "dem RLSB",
|
||||
SenderType.ELTERNVERTRETER: "einem Elternvertreter",
|
||||
}.get(sender_type, "einem Absender")
|
||||
|
||||
prompt = f"""Du bist eine Schulleiterin in Niedersachsen. Formuliere eine professionelle, kurze Antwort auf diese E-Mail von {sender_desc}:
|
||||
|
||||
Betreff: {subject}
|
||||
Inhalt: {body_preview}
|
||||
|
||||
Die Antwort sollte:
|
||||
- Höflich und formell sein
|
||||
- Den Eingang bestätigen
|
||||
- Eine konkrete nächste Aktion nennen oder um Klärung bitten
|
||||
|
||||
Antworte NUR mit dem Antworttext (ohne Betreffzeile, ohne "Betreff:").
|
||||
"""
|
||||
|
||||
response = await client.post(
|
||||
f"{LLM_GATEWAY_URL}/api/v1/inference",
|
||||
json={
|
||||
"prompt": prompt,
|
||||
"playbook": "mail_analysis",
|
||||
"max_tokens": 300,
|
||||
},
|
||||
return await suggest_response(
|
||||
client, subject, body_text, sender_type, category
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
body = data.get("response", "").strip()
|
||||
|
||||
if body:
|
||||
return ResponseSuggestion(
|
||||
template_type="ai_generated",
|
||||
subject=f"Re: {subject}",
|
||||
body=body,
|
||||
confidence=0.6,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"LLM response generation failed: {e}")
|
||||
|
||||
return None
|
||||
|
||||
|
||||
# Global instance
|
||||
_ai_service: Optional[AIEmailService] = None
|
||||
|
||||
@@ -1,833 +1,36 @@
|
||||
"""
|
||||
PostgreSQL Metrics Database Service
|
||||
Stores search feedback, calculates quality metrics (Precision, Recall, MRR).
|
||||
PostgreSQL Metrics Database Service — Barrel Re-export
|
||||
|
||||
Split into:
|
||||
- metrics_db_core.py — Pool, feedback, metrics, relevance
|
||||
- metrics_db_schema.py — Table initialization (DDL)
|
||||
- metrics_db_zeugnis.py — Zeugnis source/document/stats operations
|
||||
|
||||
All public names are re-exported here for backward compatibility.
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Optional, List, Dict
|
||||
from datetime import datetime, timedelta
|
||||
import asyncio
|
||||
|
||||
# Database Configuration - uses test default if not configured (for CI)
|
||||
DATABASE_URL = os.getenv("DATABASE_URL", "postgresql://test:test@localhost:5432/test_metrics")
|
||||
|
||||
# Connection pool
|
||||
_pool = None
|
||||
|
||||
|
||||
async def get_pool():
|
||||
"""Get or create database connection pool."""
|
||||
global _pool
|
||||
if _pool is None:
|
||||
try:
|
||||
import asyncpg
|
||||
_pool = await asyncpg.create_pool(DATABASE_URL, min_size=2, max_size=10)
|
||||
except ImportError:
|
||||
print("Warning: asyncpg not installed. Metrics storage disabled.")
|
||||
return None
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to connect to PostgreSQL: {e}")
|
||||
return None
|
||||
return _pool
|
||||
|
||||
|
||||
async def init_metrics_tables() -> bool:
|
||||
"""Initialize metrics tables in PostgreSQL."""
|
||||
pool = await get_pool()
|
||||
if pool is None:
|
||||
return False
|
||||
|
||||
create_tables_sql = """
|
||||
-- RAG Search Feedback Table
|
||||
CREATE TABLE IF NOT EXISTS rag_search_feedback (
|
||||
id SERIAL PRIMARY KEY,
|
||||
result_id VARCHAR(255) NOT NULL,
|
||||
query_text TEXT,
|
||||
collection_name VARCHAR(100),
|
||||
score FLOAT,
|
||||
rating INTEGER CHECK (rating >= 1 AND rating <= 5),
|
||||
notes TEXT,
|
||||
user_id VARCHAR(100),
|
||||
created_at TIMESTAMP DEFAULT NOW()
|
||||
);
|
||||
|
||||
-- Index for efficient querying
|
||||
CREATE INDEX IF NOT EXISTS idx_feedback_created_at ON rag_search_feedback(created_at);
|
||||
CREATE INDEX IF NOT EXISTS idx_feedback_collection ON rag_search_feedback(collection_name);
|
||||
CREATE INDEX IF NOT EXISTS idx_feedback_rating ON rag_search_feedback(rating);
|
||||
|
||||
-- RAG Search Logs Table (for latency tracking)
|
||||
CREATE TABLE IF NOT EXISTS rag_search_logs (
|
||||
id SERIAL PRIMARY KEY,
|
||||
query_text TEXT NOT NULL,
|
||||
collection_name VARCHAR(100),
|
||||
result_count INTEGER,
|
||||
latency_ms INTEGER,
|
||||
top_score FLOAT,
|
||||
filters JSONB,
|
||||
created_at TIMESTAMP DEFAULT NOW()
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_search_logs_created_at ON rag_search_logs(created_at);
|
||||
|
||||
-- RAG Upload History Table
|
||||
CREATE TABLE IF NOT EXISTS rag_upload_history (
|
||||
id SERIAL PRIMARY KEY,
|
||||
filename VARCHAR(500) NOT NULL,
|
||||
collection_name VARCHAR(100),
|
||||
year INTEGER,
|
||||
pdfs_extracted INTEGER,
|
||||
minio_path VARCHAR(1000),
|
||||
uploaded_by VARCHAR(100),
|
||||
created_at TIMESTAMP DEFAULT NOW()
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_upload_history_created_at ON rag_upload_history(created_at);
|
||||
|
||||
-- Binäre Relevanz-Judgments für echte Precision/Recall
|
||||
CREATE TABLE IF NOT EXISTS rag_relevance_judgments (
|
||||
id SERIAL PRIMARY KEY,
|
||||
query_id VARCHAR(255) NOT NULL,
|
||||
query_text TEXT NOT NULL,
|
||||
result_id VARCHAR(255) NOT NULL,
|
||||
result_rank INTEGER,
|
||||
is_relevant BOOLEAN NOT NULL,
|
||||
collection_name VARCHAR(100),
|
||||
user_id VARCHAR(100),
|
||||
created_at TIMESTAMP DEFAULT NOW()
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_relevance_query ON rag_relevance_judgments(query_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_relevance_created_at ON rag_relevance_judgments(created_at);
|
||||
|
||||
-- Zeugnisse Source Tracking
|
||||
CREATE TABLE IF NOT EXISTS zeugnis_sources (
|
||||
id VARCHAR(36) PRIMARY KEY,
|
||||
bundesland VARCHAR(10) NOT NULL,
|
||||
name VARCHAR(255) NOT NULL,
|
||||
base_url TEXT,
|
||||
license_type VARCHAR(50) NOT NULL,
|
||||
training_allowed BOOLEAN DEFAULT FALSE,
|
||||
verified_by VARCHAR(100),
|
||||
verified_at TIMESTAMP,
|
||||
created_at TIMESTAMP DEFAULT NOW(),
|
||||
updated_at TIMESTAMP DEFAULT NOW()
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_zeugnis_sources_bundesland ON zeugnis_sources(bundesland);
|
||||
|
||||
-- Zeugnisse Seed URLs
|
||||
CREATE TABLE IF NOT EXISTS zeugnis_seed_urls (
|
||||
id VARCHAR(36) PRIMARY KEY,
|
||||
source_id VARCHAR(36) REFERENCES zeugnis_sources(id),
|
||||
url TEXT NOT NULL,
|
||||
doc_type VARCHAR(50),
|
||||
status VARCHAR(20) DEFAULT 'pending',
|
||||
last_crawled TIMESTAMP,
|
||||
error_message TEXT,
|
||||
created_at TIMESTAMP DEFAULT NOW()
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_zeugnis_seed_urls_source ON zeugnis_seed_urls(source_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_zeugnis_seed_urls_status ON zeugnis_seed_urls(status);
|
||||
|
||||
-- Zeugnisse Documents
|
||||
CREATE TABLE IF NOT EXISTS zeugnis_documents (
|
||||
id VARCHAR(36) PRIMARY KEY,
|
||||
seed_url_id VARCHAR(36) REFERENCES zeugnis_seed_urls(id),
|
||||
title VARCHAR(500),
|
||||
url TEXT NOT NULL,
|
||||
content_hash VARCHAR(64),
|
||||
minio_path TEXT,
|
||||
training_allowed BOOLEAN DEFAULT FALSE,
|
||||
indexed_in_qdrant BOOLEAN DEFAULT FALSE,
|
||||
file_size INTEGER,
|
||||
content_type VARCHAR(100),
|
||||
created_at TIMESTAMP DEFAULT NOW(),
|
||||
updated_at TIMESTAMP DEFAULT NOW()
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_zeugnis_documents_seed ON zeugnis_documents(seed_url_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_zeugnis_documents_hash ON zeugnis_documents(content_hash);
|
||||
|
||||
-- Zeugnisse Document Versions
|
||||
CREATE TABLE IF NOT EXISTS zeugnis_document_versions (
|
||||
id VARCHAR(36) PRIMARY KEY,
|
||||
document_id VARCHAR(36) REFERENCES zeugnis_documents(id),
|
||||
version INTEGER NOT NULL,
|
||||
content_hash VARCHAR(64),
|
||||
minio_path TEXT,
|
||||
change_summary TEXT,
|
||||
created_at TIMESTAMP DEFAULT NOW()
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_zeugnis_versions_doc ON zeugnis_document_versions(document_id);
|
||||
|
||||
-- Zeugnisse Usage Events (Audit Trail)
|
||||
CREATE TABLE IF NOT EXISTS zeugnis_usage_events (
|
||||
id VARCHAR(36) PRIMARY KEY,
|
||||
document_id VARCHAR(36) REFERENCES zeugnis_documents(id),
|
||||
event_type VARCHAR(50) NOT NULL,
|
||||
user_id VARCHAR(100),
|
||||
details JSONB,
|
||||
created_at TIMESTAMP DEFAULT NOW()
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_zeugnis_events_doc ON zeugnis_usage_events(document_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_zeugnis_events_type ON zeugnis_usage_events(event_type);
|
||||
CREATE INDEX IF NOT EXISTS idx_zeugnis_events_created ON zeugnis_usage_events(created_at);
|
||||
|
||||
-- Crawler Queue
|
||||
CREATE TABLE IF NOT EXISTS zeugnis_crawler_queue (
|
||||
id VARCHAR(36) PRIMARY KEY,
|
||||
source_id VARCHAR(36) REFERENCES zeugnis_sources(id),
|
||||
priority INTEGER DEFAULT 5,
|
||||
status VARCHAR(20) DEFAULT 'pending',
|
||||
started_at TIMESTAMP,
|
||||
completed_at TIMESTAMP,
|
||||
documents_found INTEGER DEFAULT 0,
|
||||
documents_indexed INTEGER DEFAULT 0,
|
||||
error_count INTEGER DEFAULT 0,
|
||||
created_at TIMESTAMP DEFAULT NOW()
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_crawler_queue_status ON zeugnis_crawler_queue(status);
|
||||
"""
|
||||
|
||||
try:
|
||||
async with pool.acquire() as conn:
|
||||
await conn.execute(create_tables_sql)
|
||||
print("RAG metrics tables initialized")
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"Failed to initialize metrics tables: {e}")
|
||||
return False
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Feedback Storage
|
||||
# =============================================================================
|
||||
|
||||
async def store_feedback(
|
||||
result_id: str,
|
||||
rating: int,
|
||||
query_text: Optional[str] = None,
|
||||
collection_name: Optional[str] = None,
|
||||
score: Optional[float] = None,
|
||||
notes: Optional[str] = None,
|
||||
user_id: Optional[str] = None,
|
||||
) -> bool:
|
||||
"""Store search result feedback."""
|
||||
pool = await get_pool()
|
||||
if pool is None:
|
||||
return False
|
||||
|
||||
try:
|
||||
async with pool.acquire() as conn:
|
||||
await conn.execute(
|
||||
"""
|
||||
INSERT INTO rag_search_feedback
|
||||
(result_id, query_text, collection_name, score, rating, notes, user_id)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7)
|
||||
""",
|
||||
result_id, query_text, collection_name, score, rating, notes, user_id
|
||||
)
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"Failed to store feedback: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def log_search(
|
||||
query_text: str,
|
||||
collection_name: str,
|
||||
result_count: int,
|
||||
latency_ms: int,
|
||||
top_score: Optional[float] = None,
|
||||
filters: Optional[Dict] = None,
|
||||
) -> bool:
|
||||
"""Log a search for metrics tracking."""
|
||||
pool = await get_pool()
|
||||
if pool is None:
|
||||
return False
|
||||
|
||||
try:
|
||||
import json
|
||||
async with pool.acquire() as conn:
|
||||
await conn.execute(
|
||||
"""
|
||||
INSERT INTO rag_search_logs
|
||||
(query_text, collection_name, result_count, latency_ms, top_score, filters)
|
||||
VALUES ($1, $2, $3, $4, $5, $6)
|
||||
""",
|
||||
query_text, collection_name, result_count, latency_ms, top_score,
|
||||
json.dumps(filters) if filters else None
|
||||
)
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"Failed to log search: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def log_upload(
|
||||
filename: str,
|
||||
collection_name: str,
|
||||
year: int,
|
||||
pdfs_extracted: int,
|
||||
minio_path: Optional[str] = None,
|
||||
uploaded_by: Optional[str] = None,
|
||||
) -> bool:
|
||||
"""Log an upload for history tracking."""
|
||||
pool = await get_pool()
|
||||
if pool is None:
|
||||
return False
|
||||
|
||||
try:
|
||||
async with pool.acquire() as conn:
|
||||
await conn.execute(
|
||||
"""
|
||||
INSERT INTO rag_upload_history
|
||||
(filename, collection_name, year, pdfs_extracted, minio_path, uploaded_by)
|
||||
VALUES ($1, $2, $3, $4, $5, $6)
|
||||
""",
|
||||
filename, collection_name, year, pdfs_extracted, minio_path, uploaded_by
|
||||
)
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"Failed to log upload: {e}")
|
||||
return False
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Metrics Calculation
|
||||
# =============================================================================
|
||||
|
||||
async def calculate_metrics(
|
||||
collection_name: Optional[str] = None,
|
||||
days: int = 7,
|
||||
) -> Dict:
|
||||
"""
|
||||
Calculate RAG quality metrics from stored feedback.
|
||||
|
||||
Returns:
|
||||
Dict with precision, recall, MRR, latency, etc.
|
||||
"""
|
||||
pool = await get_pool()
|
||||
if pool is None:
|
||||
return {"error": "Database not available", "connected": False}
|
||||
|
||||
try:
|
||||
async with pool.acquire() as conn:
|
||||
# Date filter
|
||||
since = datetime.now() - timedelta(days=days)
|
||||
|
||||
# Collection filter
|
||||
collection_filter = ""
|
||||
params = [since]
|
||||
if collection_name:
|
||||
collection_filter = "AND collection_name = $2"
|
||||
params.append(collection_name)
|
||||
|
||||
# Total feedback count
|
||||
total_feedback = await conn.fetchval(
|
||||
f"""
|
||||
SELECT COUNT(*) FROM rag_search_feedback
|
||||
WHERE created_at >= $1 {collection_filter}
|
||||
""",
|
||||
*params
|
||||
)
|
||||
|
||||
# Rating distribution
|
||||
rating_dist = await conn.fetch(
|
||||
f"""
|
||||
SELECT rating, COUNT(*) as count
|
||||
FROM rag_search_feedback
|
||||
WHERE created_at >= $1 {collection_filter}
|
||||
GROUP BY rating
|
||||
ORDER BY rating DESC
|
||||
""",
|
||||
*params
|
||||
)
|
||||
|
||||
# Average rating (proxy for precision)
|
||||
avg_rating = await conn.fetchval(
|
||||
f"""
|
||||
SELECT AVG(rating) FROM rag_search_feedback
|
||||
WHERE created_at >= $1 {collection_filter}
|
||||
""",
|
||||
*params
|
||||
)
|
||||
|
||||
# Score distribution
|
||||
score_dist = await conn.fetch(
|
||||
f"""
|
||||
SELECT
|
||||
CASE
|
||||
WHEN score >= 0.9 THEN '0.9+'
|
||||
WHEN score >= 0.7 THEN '0.7-0.9'
|
||||
WHEN score >= 0.5 THEN '0.5-0.7'
|
||||
ELSE '<0.5'
|
||||
END as range,
|
||||
COUNT(*) as count
|
||||
FROM rag_search_feedback
|
||||
WHERE created_at >= $1 AND score IS NOT NULL {collection_filter}
|
||||
GROUP BY range
|
||||
ORDER BY range DESC
|
||||
""",
|
||||
*params
|
||||
)
|
||||
|
||||
# Search logs for latency
|
||||
latency_stats = await conn.fetchrow(
|
||||
f"""
|
||||
SELECT
|
||||
AVG(latency_ms) as avg_latency,
|
||||
COUNT(*) as total_searches,
|
||||
AVG(result_count) as avg_results
|
||||
FROM rag_search_logs
|
||||
WHERE created_at >= $1 {collection_filter.replace('collection_name', 'collection_name')}
|
||||
""",
|
||||
*params
|
||||
)
|
||||
|
||||
# Calculate precision@5 (% of top 5 rated 4+)
|
||||
precision_at_5 = await conn.fetchval(
|
||||
f"""
|
||||
SELECT
|
||||
CASE WHEN COUNT(*) > 0
|
||||
THEN CAST(SUM(CASE WHEN rating >= 4 THEN 1 ELSE 0 END) AS FLOAT) / COUNT(*)
|
||||
ELSE 0 END
|
||||
FROM rag_search_feedback
|
||||
WHERE created_at >= $1 {collection_filter}
|
||||
""",
|
||||
*params
|
||||
) or 0
|
||||
|
||||
# Calculate MRR (Mean Reciprocal Rank) - simplified
|
||||
# Using average rating as proxy for relevance
|
||||
mrr = (avg_rating or 0) / 5.0
|
||||
|
||||
# Error rate (ratings of 1 or 2)
|
||||
error_count = sum(
|
||||
r['count'] for r in rating_dist if r['rating'] and r['rating'] <= 2
|
||||
)
|
||||
error_rate = (error_count / total_feedback * 100) if total_feedback > 0 else 0
|
||||
|
||||
# Score distribution as percentages
|
||||
total_scored = sum(s['count'] for s in score_dist)
|
||||
score_distribution = {}
|
||||
for s in score_dist:
|
||||
if total_scored > 0:
|
||||
score_distribution[s['range']] = round(s['count'] / total_scored * 100)
|
||||
else:
|
||||
score_distribution[s['range']] = 0
|
||||
|
||||
return {
|
||||
"connected": True,
|
||||
"period_days": days,
|
||||
"precision_at_5": round(precision_at_5, 2),
|
||||
"recall_at_10": round(precision_at_5 * 1.1, 2), # Estimated
|
||||
"mrr": round(mrr, 2),
|
||||
"avg_latency_ms": round(latency_stats['avg_latency'] or 0),
|
||||
"total_ratings": total_feedback,
|
||||
"total_searches": latency_stats['total_searches'] or 0,
|
||||
"error_rate": round(error_rate, 1),
|
||||
"score_distribution": score_distribution,
|
||||
"rating_distribution": {
|
||||
str(r['rating']): r['count'] for r in rating_dist if r['rating']
|
||||
},
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
print(f"Failed to calculate metrics: {e}")
|
||||
return {"error": str(e), "connected": False}
|
||||
|
||||
|
||||
async def get_recent_feedback(limit: int = 20) -> List[Dict]:
|
||||
"""Get recent feedback entries."""
|
||||
pool = await get_pool()
|
||||
if pool is None:
|
||||
return []
|
||||
|
||||
try:
|
||||
async with pool.acquire() as conn:
|
||||
rows = await conn.fetch(
|
||||
"""
|
||||
SELECT result_id, rating, query_text, collection_name, score, notes, created_at
|
||||
FROM rag_search_feedback
|
||||
ORDER BY created_at DESC
|
||||
LIMIT $1
|
||||
""",
|
||||
limit
|
||||
)
|
||||
return [
|
||||
{
|
||||
"result_id": r['result_id'],
|
||||
"rating": r['rating'],
|
||||
"query_text": r['query_text'],
|
||||
"collection_name": r['collection_name'],
|
||||
"score": r['score'],
|
||||
"notes": r['notes'],
|
||||
"created_at": r['created_at'].isoformat() if r['created_at'] else None,
|
||||
}
|
||||
for r in rows
|
||||
]
|
||||
except Exception as e:
|
||||
print(f"Failed to get recent feedback: {e}")
|
||||
return []
|
||||
|
||||
|
||||
async def get_upload_history(limit: int = 20) -> List[Dict]:
|
||||
"""Get recent upload history."""
|
||||
pool = await get_pool()
|
||||
if pool is None:
|
||||
return []
|
||||
|
||||
try:
|
||||
async with pool.acquire() as conn:
|
||||
rows = await conn.fetch(
|
||||
"""
|
||||
SELECT filename, collection_name, year, pdfs_extracted, minio_path, uploaded_by, created_at
|
||||
FROM rag_upload_history
|
||||
ORDER BY created_at DESC
|
||||
LIMIT $1
|
||||
""",
|
||||
limit
|
||||
)
|
||||
return [
|
||||
{
|
||||
"filename": r['filename'],
|
||||
"collection_name": r['collection_name'],
|
||||
"year": r['year'],
|
||||
"pdfs_extracted": r['pdfs_extracted'],
|
||||
"minio_path": r['minio_path'],
|
||||
"uploaded_by": r['uploaded_by'],
|
||||
"created_at": r['created_at'].isoformat() if r['created_at'] else None,
|
||||
}
|
||||
for r in rows
|
||||
]
|
||||
except Exception as e:
|
||||
print(f"Failed to get upload history: {e}")
|
||||
return []
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Relevance Judgments (Binary Precision/Recall)
|
||||
# =============================================================================
|
||||
|
||||
async def store_relevance_judgment(
|
||||
query_id: str,
|
||||
query_text: str,
|
||||
result_id: str,
|
||||
is_relevant: bool,
|
||||
result_rank: Optional[int] = None,
|
||||
collection_name: Optional[str] = None,
|
||||
user_id: Optional[str] = None,
|
||||
) -> bool:
|
||||
"""Store binary relevance judgment for Precision/Recall calculation."""
|
||||
pool = await get_pool()
|
||||
if pool is None:
|
||||
return False
|
||||
|
||||
try:
|
||||
async with pool.acquire() as conn:
|
||||
await conn.execute(
|
||||
"""
|
||||
INSERT INTO rag_relevance_judgments
|
||||
(query_id, query_text, result_id, result_rank, is_relevant, collection_name, user_id)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7)
|
||||
ON CONFLICT DO NOTHING
|
||||
""",
|
||||
query_id, query_text, result_id, result_rank, is_relevant, collection_name, user_id
|
||||
)
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"Failed to store relevance judgment: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def calculate_precision_recall(
|
||||
collection_name: Optional[str] = None,
|
||||
days: int = 7,
|
||||
k: int = 10,
|
||||
) -> Dict:
|
||||
"""
|
||||
Calculate true Precision@k and Recall@k from binary relevance judgments.
|
||||
|
||||
Precision@k = (Relevant docs in top k) / k
|
||||
Recall@k = (Relevant docs in top k) / (Total relevant docs for query)
|
||||
"""
|
||||
pool = await get_pool()
|
||||
if pool is None:
|
||||
return {"error": "Database not available", "connected": False}
|
||||
|
||||
try:
|
||||
async with pool.acquire() as conn:
|
||||
since = datetime.now() - timedelta(days=days)
|
||||
|
||||
collection_filter = ""
|
||||
params = [since, k]
|
||||
if collection_name:
|
||||
collection_filter = "AND collection_name = $3"
|
||||
params.append(collection_name)
|
||||
|
||||
# Get precision@k per query, then average
|
||||
precision_result = await conn.fetchval(
|
||||
f"""
|
||||
WITH query_precision AS (
|
||||
SELECT
|
||||
query_id,
|
||||
COUNT(CASE WHEN is_relevant THEN 1 END)::FLOAT /
|
||||
GREATEST(COUNT(*), 1) as precision
|
||||
FROM rag_relevance_judgments
|
||||
WHERE created_at >= $1
|
||||
AND (result_rank IS NULL OR result_rank <= $2)
|
||||
{collection_filter}
|
||||
GROUP BY query_id
|
||||
)
|
||||
SELECT AVG(precision) FROM query_precision
|
||||
""",
|
||||
*params
|
||||
) or 0
|
||||
|
||||
# Get recall@k per query, then average
|
||||
recall_result = await conn.fetchval(
|
||||
f"""
|
||||
WITH query_recall AS (
|
||||
SELECT
|
||||
query_id,
|
||||
COUNT(CASE WHEN is_relevant AND (result_rank IS NULL OR result_rank <= $2) THEN 1 END)::FLOAT /
|
||||
GREATEST(COUNT(CASE WHEN is_relevant THEN 1 END), 1) as recall
|
||||
FROM rag_relevance_judgments
|
||||
WHERE created_at >= $1
|
||||
{collection_filter}
|
||||
GROUP BY query_id
|
||||
)
|
||||
SELECT AVG(recall) FROM query_recall
|
||||
""",
|
||||
*params
|
||||
) or 0
|
||||
|
||||
# Total judgments
|
||||
total_judgments = await conn.fetchval(
|
||||
f"""
|
||||
SELECT COUNT(*) FROM rag_relevance_judgments
|
||||
WHERE created_at >= $1 {collection_filter}
|
||||
""",
|
||||
since, *([collection_name] if collection_name else [])
|
||||
)
|
||||
|
||||
# Unique queries
|
||||
unique_queries = await conn.fetchval(
|
||||
f"""
|
||||
SELECT COUNT(DISTINCT query_id) FROM rag_relevance_judgments
|
||||
WHERE created_at >= $1 {collection_filter}
|
||||
""",
|
||||
since, *([collection_name] if collection_name else [])
|
||||
)
|
||||
|
||||
return {
|
||||
"connected": True,
|
||||
"period_days": days,
|
||||
"k": k,
|
||||
"precision_at_k": round(precision_result, 3),
|
||||
"recall_at_k": round(recall_result, 3),
|
||||
"f1_score": round(
|
||||
2 * precision_result * recall_result / max(precision_result + recall_result, 0.001), 3
|
||||
),
|
||||
"total_judgments": total_judgments or 0,
|
||||
"unique_queries": unique_queries or 0,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
print(f"Failed to calculate precision/recall: {e}")
|
||||
return {"error": str(e), "connected": False}
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Zeugnis Database Operations
|
||||
# =============================================================================
|
||||
|
||||
async def get_zeugnis_sources() -> List[Dict]:
|
||||
"""Get all zeugnis sources (Bundesländer)."""
|
||||
pool = await get_pool()
|
||||
if pool is None:
|
||||
return []
|
||||
|
||||
try:
|
||||
async with pool.acquire() as conn:
|
||||
rows = await conn.fetch(
|
||||
"""
|
||||
SELECT id, bundesland, name, base_url, license_type, training_allowed,
|
||||
verified_by, verified_at, created_at, updated_at
|
||||
FROM zeugnis_sources
|
||||
ORDER BY bundesland
|
||||
"""
|
||||
)
|
||||
return [dict(r) for r in rows]
|
||||
except Exception as e:
|
||||
print(f"Failed to get zeugnis sources: {e}")
|
||||
return []
|
||||
|
||||
|
||||
async def upsert_zeugnis_source(
|
||||
id: str,
|
||||
bundesland: str,
|
||||
name: str,
|
||||
license_type: str,
|
||||
training_allowed: bool,
|
||||
base_url: Optional[str] = None,
|
||||
verified_by: Optional[str] = None,
|
||||
) -> bool:
|
||||
"""Insert or update a zeugnis source."""
|
||||
pool = await get_pool()
|
||||
if pool is None:
|
||||
return False
|
||||
|
||||
try:
|
||||
async with pool.acquire() as conn:
|
||||
await conn.execute(
|
||||
"""
|
||||
INSERT INTO zeugnis_sources (id, bundesland, name, base_url, license_type, training_allowed, verified_by, verified_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, NOW())
|
||||
ON CONFLICT (id) DO UPDATE SET
|
||||
name = EXCLUDED.name,
|
||||
base_url = EXCLUDED.base_url,
|
||||
license_type = EXCLUDED.license_type,
|
||||
training_allowed = EXCLUDED.training_allowed,
|
||||
verified_by = EXCLUDED.verified_by,
|
||||
verified_at = NOW(),
|
||||
updated_at = NOW()
|
||||
""",
|
||||
id, bundesland, name, base_url, license_type, training_allowed, verified_by
|
||||
)
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"Failed to upsert zeugnis source: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def get_zeugnis_documents(
|
||||
bundesland: Optional[str] = None,
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
) -> List[Dict]:
|
||||
"""Get zeugnis documents with optional filtering."""
|
||||
pool = await get_pool()
|
||||
if pool is None:
|
||||
return []
|
||||
|
||||
try:
|
||||
async with pool.acquire() as conn:
|
||||
if bundesland:
|
||||
rows = await conn.fetch(
|
||||
"""
|
||||
SELECT d.*, s.bundesland, s.name as source_name
|
||||
FROM zeugnis_documents d
|
||||
JOIN zeugnis_seed_urls u ON d.seed_url_id = u.id
|
||||
JOIN zeugnis_sources s ON u.source_id = s.id
|
||||
WHERE s.bundesland = $1
|
||||
ORDER BY d.created_at DESC
|
||||
LIMIT $2 OFFSET $3
|
||||
""",
|
||||
bundesland, limit, offset
|
||||
)
|
||||
else:
|
||||
rows = await conn.fetch(
|
||||
"""
|
||||
SELECT d.*, s.bundesland, s.name as source_name
|
||||
FROM zeugnis_documents d
|
||||
JOIN zeugnis_seed_urls u ON d.seed_url_id = u.id
|
||||
JOIN zeugnis_sources s ON u.source_id = s.id
|
||||
ORDER BY d.created_at DESC
|
||||
LIMIT $1 OFFSET $2
|
||||
""",
|
||||
limit, offset
|
||||
)
|
||||
return [dict(r) for r in rows]
|
||||
except Exception as e:
|
||||
print(f"Failed to get zeugnis documents: {e}")
|
||||
return []
|
||||
|
||||
|
||||
async def get_zeugnis_stats() -> Dict:
|
||||
"""Get zeugnis crawler statistics."""
|
||||
pool = await get_pool()
|
||||
if pool is None:
|
||||
return {"error": "Database not available"}
|
||||
|
||||
try:
|
||||
async with pool.acquire() as conn:
|
||||
# Total sources
|
||||
sources = await conn.fetchval("SELECT COUNT(*) FROM zeugnis_sources")
|
||||
|
||||
# Total documents
|
||||
documents = await conn.fetchval("SELECT COUNT(*) FROM zeugnis_documents")
|
||||
|
||||
# Indexed documents
|
||||
indexed = await conn.fetchval(
|
||||
"SELECT COUNT(*) FROM zeugnis_documents WHERE indexed_in_qdrant = true"
|
||||
)
|
||||
|
||||
# Training allowed
|
||||
training_allowed = await conn.fetchval(
|
||||
"SELECT COUNT(*) FROM zeugnis_documents WHERE training_allowed = true"
|
||||
)
|
||||
|
||||
# Per Bundesland stats
|
||||
per_bundesland = await conn.fetch(
|
||||
"""
|
||||
SELECT s.bundesland, s.name, s.training_allowed, COUNT(d.id) as doc_count
|
||||
FROM zeugnis_sources s
|
||||
LEFT JOIN zeugnis_seed_urls u ON s.id = u.source_id
|
||||
LEFT JOIN zeugnis_documents d ON u.id = d.seed_url_id
|
||||
GROUP BY s.bundesland, s.name, s.training_allowed
|
||||
ORDER BY s.bundesland
|
||||
"""
|
||||
)
|
||||
|
||||
# Active crawls
|
||||
active_crawls = await conn.fetchval(
|
||||
"SELECT COUNT(*) FROM zeugnis_crawler_queue WHERE status = 'running'"
|
||||
)
|
||||
|
||||
return {
|
||||
"total_sources": sources or 0,
|
||||
"total_documents": documents or 0,
|
||||
"indexed_documents": indexed or 0,
|
||||
"training_allowed_documents": training_allowed or 0,
|
||||
"active_crawls": active_crawls or 0,
|
||||
"per_bundesland": [dict(r) for r in per_bundesland],
|
||||
}
|
||||
except Exception as e:
|
||||
print(f"Failed to get zeugnis stats: {e}")
|
||||
return {"error": str(e)}
|
||||
|
||||
|
||||
async def log_zeugnis_event(
|
||||
document_id: str,
|
||||
event_type: str,
|
||||
user_id: Optional[str] = None,
|
||||
details: Optional[Dict] = None,
|
||||
) -> bool:
|
||||
"""Log a zeugnis usage event for audit trail."""
|
||||
pool = await get_pool()
|
||||
if pool is None:
|
||||
return False
|
||||
|
||||
try:
|
||||
import json
|
||||
import uuid
|
||||
async with pool.acquire() as conn:
|
||||
await conn.execute(
|
||||
"""
|
||||
INSERT INTO zeugnis_usage_events (id, document_id, event_type, user_id, details)
|
||||
VALUES ($1, $2, $3, $4, $5)
|
||||
""",
|
||||
str(uuid.uuid4()), document_id, event_type, user_id,
|
||||
json.dumps(details) if details else None
|
||||
)
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"Failed to log zeugnis event: {e}")
|
||||
return False
|
||||
# Schema: table initialization
|
||||
from metrics_db_schema import init_metrics_tables # noqa: F401
|
||||
|
||||
# Core: pool, feedback, search logs, metrics, relevance
|
||||
from metrics_db_core import ( # noqa: F401
|
||||
DATABASE_URL,
|
||||
get_pool,
|
||||
store_feedback,
|
||||
log_search,
|
||||
log_upload,
|
||||
calculate_metrics,
|
||||
get_recent_feedback,
|
||||
get_upload_history,
|
||||
store_relevance_judgment,
|
||||
calculate_precision_recall,
|
||||
)
|
||||
|
||||
# Zeugnis operations
|
||||
from metrics_db_zeugnis import ( # noqa: F401
|
||||
get_zeugnis_sources,
|
||||
upsert_zeugnis_source,
|
||||
get_zeugnis_documents,
|
||||
get_zeugnis_stats,
|
||||
log_zeugnis_event,
|
||||
)
|
||||
|
||||
@@ -0,0 +1,459 @@
|
||||
"""
|
||||
PostgreSQL Metrics Database - Core Operations
|
||||
|
||||
Connection pool, table initialization, feedback storage, search logging,
|
||||
upload history, metrics calculation, and relevance judgments.
|
||||
|
||||
Extracted from metrics_db.py to keep files under 500 LOC.
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Optional, List, Dict
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
# Database Configuration - uses test default if not configured (for CI)
|
||||
DATABASE_URL = os.getenv("DATABASE_URL", "postgresql://test:test@localhost:5432/test_metrics")
|
||||
|
||||
# Connection pool
|
||||
_pool = None
|
||||
|
||||
|
||||
async def get_pool():
|
||||
"""Get or create database connection pool."""
|
||||
global _pool
|
||||
if _pool is None:
|
||||
try:
|
||||
import asyncpg
|
||||
_pool = await asyncpg.create_pool(DATABASE_URL, min_size=2, max_size=10)
|
||||
except ImportError:
|
||||
print("Warning: asyncpg not installed. Metrics storage disabled.")
|
||||
return None
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to connect to PostgreSQL: {e}")
|
||||
return None
|
||||
return _pool
|
||||
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Feedback Storage
|
||||
# =============================================================================
|
||||
|
||||
async def store_feedback(
|
||||
result_id: str,
|
||||
rating: int,
|
||||
query_text: Optional[str] = None,
|
||||
collection_name: Optional[str] = None,
|
||||
score: Optional[float] = None,
|
||||
notes: Optional[str] = None,
|
||||
user_id: Optional[str] = None,
|
||||
) -> bool:
|
||||
"""Store search result feedback."""
|
||||
pool = await get_pool()
|
||||
if pool is None:
|
||||
return False
|
||||
|
||||
try:
|
||||
async with pool.acquire() as conn:
|
||||
await conn.execute(
|
||||
"""
|
||||
INSERT INTO rag_search_feedback
|
||||
(result_id, query_text, collection_name, score, rating, notes, user_id)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7)
|
||||
""",
|
||||
result_id, query_text, collection_name, score, rating, notes, user_id
|
||||
)
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"Failed to store feedback: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def log_search(
|
||||
query_text: str,
|
||||
collection_name: str,
|
||||
result_count: int,
|
||||
latency_ms: int,
|
||||
top_score: Optional[float] = None,
|
||||
filters: Optional[Dict] = None,
|
||||
) -> bool:
|
||||
"""Log a search for metrics tracking."""
|
||||
pool = await get_pool()
|
||||
if pool is None:
|
||||
return False
|
||||
|
||||
try:
|
||||
import json
|
||||
async with pool.acquire() as conn:
|
||||
await conn.execute(
|
||||
"""
|
||||
INSERT INTO rag_search_logs
|
||||
(query_text, collection_name, result_count, latency_ms, top_score, filters)
|
||||
VALUES ($1, $2, $3, $4, $5, $6)
|
||||
""",
|
||||
query_text, collection_name, result_count, latency_ms, top_score,
|
||||
json.dumps(filters) if filters else None
|
||||
)
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"Failed to log search: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def log_upload(
|
||||
filename: str,
|
||||
collection_name: str,
|
||||
year: int,
|
||||
pdfs_extracted: int,
|
||||
minio_path: Optional[str] = None,
|
||||
uploaded_by: Optional[str] = None,
|
||||
) -> bool:
|
||||
"""Log an upload for history tracking."""
|
||||
pool = await get_pool()
|
||||
if pool is None:
|
||||
return False
|
||||
|
||||
try:
|
||||
async with pool.acquire() as conn:
|
||||
await conn.execute(
|
||||
"""
|
||||
INSERT INTO rag_upload_history
|
||||
(filename, collection_name, year, pdfs_extracted, minio_path, uploaded_by)
|
||||
VALUES ($1, $2, $3, $4, $5, $6)
|
||||
""",
|
||||
filename, collection_name, year, pdfs_extracted, minio_path, uploaded_by
|
||||
)
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"Failed to log upload: {e}")
|
||||
return False
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Metrics Calculation
|
||||
# =============================================================================
|
||||
|
||||
async def calculate_metrics(
|
||||
collection_name: Optional[str] = None,
|
||||
days: int = 7,
|
||||
) -> Dict:
|
||||
"""
|
||||
Calculate RAG quality metrics from stored feedback.
|
||||
|
||||
Returns:
|
||||
Dict with precision, recall, MRR, latency, etc.
|
||||
"""
|
||||
pool = await get_pool()
|
||||
if pool is None:
|
||||
return {"error": "Database not available", "connected": False}
|
||||
|
||||
try:
|
||||
async with pool.acquire() as conn:
|
||||
since = datetime.now() - timedelta(days=days)
|
||||
|
||||
collection_filter = ""
|
||||
params = [since]
|
||||
if collection_name:
|
||||
collection_filter = "AND collection_name = $2"
|
||||
params.append(collection_name)
|
||||
|
||||
total_feedback = await conn.fetchval(
|
||||
f"""
|
||||
SELECT COUNT(*) FROM rag_search_feedback
|
||||
WHERE created_at >= $1 {collection_filter}
|
||||
""",
|
||||
*params
|
||||
)
|
||||
|
||||
rating_dist = await conn.fetch(
|
||||
f"""
|
||||
SELECT rating, COUNT(*) as count
|
||||
FROM rag_search_feedback
|
||||
WHERE created_at >= $1 {collection_filter}
|
||||
GROUP BY rating
|
||||
ORDER BY rating DESC
|
||||
""",
|
||||
*params
|
||||
)
|
||||
|
||||
avg_rating = await conn.fetchval(
|
||||
f"""
|
||||
SELECT AVG(rating) FROM rag_search_feedback
|
||||
WHERE created_at >= $1 {collection_filter}
|
||||
""",
|
||||
*params
|
||||
)
|
||||
|
||||
score_dist = await conn.fetch(
|
||||
f"""
|
||||
SELECT
|
||||
CASE
|
||||
WHEN score >= 0.9 THEN '0.9+'
|
||||
WHEN score >= 0.7 THEN '0.7-0.9'
|
||||
WHEN score >= 0.5 THEN '0.5-0.7'
|
||||
ELSE '<0.5'
|
||||
END as range,
|
||||
COUNT(*) as count
|
||||
FROM rag_search_feedback
|
||||
WHERE created_at >= $1 AND score IS NOT NULL {collection_filter}
|
||||
GROUP BY range
|
||||
ORDER BY range DESC
|
||||
""",
|
||||
*params
|
||||
)
|
||||
|
||||
latency_stats = await conn.fetchrow(
|
||||
f"""
|
||||
SELECT
|
||||
AVG(latency_ms) as avg_latency,
|
||||
COUNT(*) as total_searches,
|
||||
AVG(result_count) as avg_results
|
||||
FROM rag_search_logs
|
||||
WHERE created_at >= $1 {collection_filter.replace('collection_name', 'collection_name')}
|
||||
""",
|
||||
*params
|
||||
)
|
||||
|
||||
precision_at_5 = await conn.fetchval(
|
||||
f"""
|
||||
SELECT
|
||||
CASE WHEN COUNT(*) > 0
|
||||
THEN CAST(SUM(CASE WHEN rating >= 4 THEN 1 ELSE 0 END) AS FLOAT) / COUNT(*)
|
||||
ELSE 0 END
|
||||
FROM rag_search_feedback
|
||||
WHERE created_at >= $1 {collection_filter}
|
||||
""",
|
||||
*params
|
||||
) or 0
|
||||
|
||||
mrr = (avg_rating or 0) / 5.0
|
||||
|
||||
error_count = sum(
|
||||
r['count'] for r in rating_dist if r['rating'] and r['rating'] <= 2
|
||||
)
|
||||
error_rate = (error_count / total_feedback * 100) if total_feedback > 0 else 0
|
||||
|
||||
total_scored = sum(s['count'] for s in score_dist)
|
||||
score_distribution = {}
|
||||
for s in score_dist:
|
||||
if total_scored > 0:
|
||||
score_distribution[s['range']] = round(s['count'] / total_scored * 100)
|
||||
else:
|
||||
score_distribution[s['range']] = 0
|
||||
|
||||
return {
|
||||
"connected": True,
|
||||
"period_days": days,
|
||||
"precision_at_5": round(precision_at_5, 2),
|
||||
"recall_at_10": round(precision_at_5 * 1.1, 2),
|
||||
"mrr": round(mrr, 2),
|
||||
"avg_latency_ms": round(latency_stats['avg_latency'] or 0),
|
||||
"total_ratings": total_feedback,
|
||||
"total_searches": latency_stats['total_searches'] or 0,
|
||||
"error_rate": round(error_rate, 1),
|
||||
"score_distribution": score_distribution,
|
||||
"rating_distribution": {
|
||||
str(r['rating']): r['count'] for r in rating_dist if r['rating']
|
||||
},
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
print(f"Failed to calculate metrics: {e}")
|
||||
return {"error": str(e), "connected": False}
|
||||
|
||||
|
||||
async def get_recent_feedback(limit: int = 20) -> List[Dict]:
|
||||
"""Get recent feedback entries."""
|
||||
pool = await get_pool()
|
||||
if pool is None:
|
||||
return []
|
||||
|
||||
try:
|
||||
async with pool.acquire() as conn:
|
||||
rows = await conn.fetch(
|
||||
"""
|
||||
SELECT result_id, rating, query_text, collection_name, score, notes, created_at
|
||||
FROM rag_search_feedback
|
||||
ORDER BY created_at DESC
|
||||
LIMIT $1
|
||||
""",
|
||||
limit
|
||||
)
|
||||
return [
|
||||
{
|
||||
"result_id": r['result_id'],
|
||||
"rating": r['rating'],
|
||||
"query_text": r['query_text'],
|
||||
"collection_name": r['collection_name'],
|
||||
"score": r['score'],
|
||||
"notes": r['notes'],
|
||||
"created_at": r['created_at'].isoformat() if r['created_at'] else None,
|
||||
}
|
||||
for r in rows
|
||||
]
|
||||
except Exception as e:
|
||||
print(f"Failed to get recent feedback: {e}")
|
||||
return []
|
||||
|
||||
|
||||
async def get_upload_history(limit: int = 20) -> List[Dict]:
|
||||
"""Get recent upload history."""
|
||||
pool = await get_pool()
|
||||
if pool is None:
|
||||
return []
|
||||
|
||||
try:
|
||||
async with pool.acquire() as conn:
|
||||
rows = await conn.fetch(
|
||||
"""
|
||||
SELECT filename, collection_name, year, pdfs_extracted, minio_path, uploaded_by, created_at
|
||||
FROM rag_upload_history
|
||||
ORDER BY created_at DESC
|
||||
LIMIT $1
|
||||
""",
|
||||
limit
|
||||
)
|
||||
return [
|
||||
{
|
||||
"filename": r['filename'],
|
||||
"collection_name": r['collection_name'],
|
||||
"year": r['year'],
|
||||
"pdfs_extracted": r['pdfs_extracted'],
|
||||
"minio_path": r['minio_path'],
|
||||
"uploaded_by": r['uploaded_by'],
|
||||
"created_at": r['created_at'].isoformat() if r['created_at'] else None,
|
||||
}
|
||||
for r in rows
|
||||
]
|
||||
except Exception as e:
|
||||
print(f"Failed to get upload history: {e}")
|
||||
return []
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Relevance Judgments (Binary Precision/Recall)
|
||||
# =============================================================================
|
||||
|
||||
async def store_relevance_judgment(
|
||||
query_id: str,
|
||||
query_text: str,
|
||||
result_id: str,
|
||||
is_relevant: bool,
|
||||
result_rank: Optional[int] = None,
|
||||
collection_name: Optional[str] = None,
|
||||
user_id: Optional[str] = None,
|
||||
) -> bool:
|
||||
"""Store binary relevance judgment for Precision/Recall calculation."""
|
||||
pool = await get_pool()
|
||||
if pool is None:
|
||||
return False
|
||||
|
||||
try:
|
||||
async with pool.acquire() as conn:
|
||||
await conn.execute(
|
||||
"""
|
||||
INSERT INTO rag_relevance_judgments
|
||||
(query_id, query_text, result_id, result_rank, is_relevant, collection_name, user_id)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7)
|
||||
ON CONFLICT DO NOTHING
|
||||
""",
|
||||
query_id, query_text, result_id, result_rank, is_relevant, collection_name, user_id
|
||||
)
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"Failed to store relevance judgment: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def calculate_precision_recall(
|
||||
collection_name: Optional[str] = None,
|
||||
days: int = 7,
|
||||
k: int = 10,
|
||||
) -> Dict:
|
||||
"""
|
||||
Calculate true Precision@k and Recall@k from binary relevance judgments.
|
||||
|
||||
Precision@k = (Relevant docs in top k) / k
|
||||
Recall@k = (Relevant docs in top k) / (Total relevant docs for query)
|
||||
"""
|
||||
pool = await get_pool()
|
||||
if pool is None:
|
||||
return {"error": "Database not available", "connected": False}
|
||||
|
||||
try:
|
||||
async with pool.acquire() as conn:
|
||||
since = datetime.now() - timedelta(days=days)
|
||||
|
||||
collection_filter = ""
|
||||
params = [since, k]
|
||||
if collection_name:
|
||||
collection_filter = "AND collection_name = $3"
|
||||
params.append(collection_name)
|
||||
|
||||
precision_result = await conn.fetchval(
|
||||
f"""
|
||||
WITH query_precision AS (
|
||||
SELECT
|
||||
query_id,
|
||||
COUNT(CASE WHEN is_relevant THEN 1 END)::FLOAT /
|
||||
GREATEST(COUNT(*), 1) as precision
|
||||
FROM rag_relevance_judgments
|
||||
WHERE created_at >= $1
|
||||
AND (result_rank IS NULL OR result_rank <= $2)
|
||||
{collection_filter}
|
||||
GROUP BY query_id
|
||||
)
|
||||
SELECT AVG(precision) FROM query_precision
|
||||
""",
|
||||
*params
|
||||
) or 0
|
||||
|
||||
recall_result = await conn.fetchval(
|
||||
f"""
|
||||
WITH query_recall AS (
|
||||
SELECT
|
||||
query_id,
|
||||
COUNT(CASE WHEN is_relevant AND (result_rank IS NULL OR result_rank <= $2) THEN 1 END)::FLOAT /
|
||||
GREATEST(COUNT(CASE WHEN is_relevant THEN 1 END), 1) as recall
|
||||
FROM rag_relevance_judgments
|
||||
WHERE created_at >= $1
|
||||
{collection_filter}
|
||||
GROUP BY query_id
|
||||
)
|
||||
SELECT AVG(recall) FROM query_recall
|
||||
""",
|
||||
*params
|
||||
) or 0
|
||||
|
||||
total_judgments = await conn.fetchval(
|
||||
f"""
|
||||
SELECT COUNT(*) FROM rag_relevance_judgments
|
||||
WHERE created_at >= $1 {collection_filter}
|
||||
""",
|
||||
since, *([collection_name] if collection_name else [])
|
||||
)
|
||||
|
||||
unique_queries = await conn.fetchval(
|
||||
f"""
|
||||
SELECT COUNT(DISTINCT query_id) FROM rag_relevance_judgments
|
||||
WHERE created_at >= $1 {collection_filter}
|
||||
""",
|
||||
since, *([collection_name] if collection_name else [])
|
||||
)
|
||||
|
||||
return {
|
||||
"connected": True,
|
||||
"period_days": days,
|
||||
"k": k,
|
||||
"precision_at_k": round(precision_result, 3),
|
||||
"recall_at_k": round(recall_result, 3),
|
||||
"f1_score": round(
|
||||
2 * precision_result * recall_result / max(precision_result + recall_result, 0.001), 3
|
||||
),
|
||||
"total_judgments": total_judgments or 0,
|
||||
"unique_queries": unique_queries or 0,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
print(f"Failed to calculate precision/recall: {e}")
|
||||
return {"error": str(e), "connected": False}
|
||||
@@ -0,0 +1,182 @@
|
||||
"""
|
||||
PostgreSQL Metrics Database - Schema Initialization
|
||||
|
||||
Table creation DDL for all metrics, feedback, and zeugnis tables.
|
||||
|
||||
Extracted from metrics_db_core.py to keep files under 500 LOC.
|
||||
"""
|
||||
|
||||
from metrics_db_core import get_pool
|
||||
|
||||
|
||||
async def init_metrics_tables() -> bool:
|
||||
"""Initialize metrics tables in PostgreSQL."""
|
||||
pool = await get_pool()
|
||||
if pool is None:
|
||||
return False
|
||||
|
||||
create_tables_sql = """
|
||||
-- RAG Search Feedback Table
|
||||
CREATE TABLE IF NOT EXISTS rag_search_feedback (
|
||||
id SERIAL PRIMARY KEY,
|
||||
result_id VARCHAR(255) NOT NULL,
|
||||
query_text TEXT,
|
||||
collection_name VARCHAR(100),
|
||||
score FLOAT,
|
||||
rating INTEGER CHECK (rating >= 1 AND rating <= 5),
|
||||
notes TEXT,
|
||||
user_id VARCHAR(100),
|
||||
created_at TIMESTAMP DEFAULT NOW()
|
||||
);
|
||||
|
||||
-- Index for efficient querying
|
||||
CREATE INDEX IF NOT EXISTS idx_feedback_created_at ON rag_search_feedback(created_at);
|
||||
CREATE INDEX IF NOT EXISTS idx_feedback_collection ON rag_search_feedback(collection_name);
|
||||
CREATE INDEX IF NOT EXISTS idx_feedback_rating ON rag_search_feedback(rating);
|
||||
|
||||
-- RAG Search Logs Table (for latency tracking)
|
||||
CREATE TABLE IF NOT EXISTS rag_search_logs (
|
||||
id SERIAL PRIMARY KEY,
|
||||
query_text TEXT NOT NULL,
|
||||
collection_name VARCHAR(100),
|
||||
result_count INTEGER,
|
||||
latency_ms INTEGER,
|
||||
top_score FLOAT,
|
||||
filters JSONB,
|
||||
created_at TIMESTAMP DEFAULT NOW()
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_search_logs_created_at ON rag_search_logs(created_at);
|
||||
|
||||
-- RAG Upload History Table
|
||||
CREATE TABLE IF NOT EXISTS rag_upload_history (
|
||||
id SERIAL PRIMARY KEY,
|
||||
filename VARCHAR(500) NOT NULL,
|
||||
collection_name VARCHAR(100),
|
||||
year INTEGER,
|
||||
pdfs_extracted INTEGER,
|
||||
minio_path VARCHAR(1000),
|
||||
uploaded_by VARCHAR(100),
|
||||
created_at TIMESTAMP DEFAULT NOW()
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_upload_history_created_at ON rag_upload_history(created_at);
|
||||
|
||||
-- Binaere Relevanz-Judgments fuer echte Precision/Recall
|
||||
CREATE TABLE IF NOT EXISTS rag_relevance_judgments (
|
||||
id SERIAL PRIMARY KEY,
|
||||
query_id VARCHAR(255) NOT NULL,
|
||||
query_text TEXT NOT NULL,
|
||||
result_id VARCHAR(255) NOT NULL,
|
||||
result_rank INTEGER,
|
||||
is_relevant BOOLEAN NOT NULL,
|
||||
collection_name VARCHAR(100),
|
||||
user_id VARCHAR(100),
|
||||
created_at TIMESTAMP DEFAULT NOW()
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_relevance_query ON rag_relevance_judgments(query_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_relevance_created_at ON rag_relevance_judgments(created_at);
|
||||
|
||||
-- Zeugnisse Source Tracking
|
||||
CREATE TABLE IF NOT EXISTS zeugnis_sources (
|
||||
id VARCHAR(36) PRIMARY KEY,
|
||||
bundesland VARCHAR(10) NOT NULL,
|
||||
name VARCHAR(255) NOT NULL,
|
||||
base_url TEXT,
|
||||
license_type VARCHAR(50) NOT NULL,
|
||||
training_allowed BOOLEAN DEFAULT FALSE,
|
||||
verified_by VARCHAR(100),
|
||||
verified_at TIMESTAMP,
|
||||
created_at TIMESTAMP DEFAULT NOW(),
|
||||
updated_at TIMESTAMP DEFAULT NOW()
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_zeugnis_sources_bundesland ON zeugnis_sources(bundesland);
|
||||
|
||||
-- Zeugnisse Seed URLs
|
||||
CREATE TABLE IF NOT EXISTS zeugnis_seed_urls (
|
||||
id VARCHAR(36) PRIMARY KEY,
|
||||
source_id VARCHAR(36) REFERENCES zeugnis_sources(id),
|
||||
url TEXT NOT NULL,
|
||||
doc_type VARCHAR(50),
|
||||
status VARCHAR(20) DEFAULT 'pending',
|
||||
last_crawled TIMESTAMP,
|
||||
error_message TEXT,
|
||||
created_at TIMESTAMP DEFAULT NOW()
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_zeugnis_seed_urls_source ON zeugnis_seed_urls(source_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_zeugnis_seed_urls_status ON zeugnis_seed_urls(status);
|
||||
|
||||
-- Zeugnisse Documents
|
||||
CREATE TABLE IF NOT EXISTS zeugnis_documents (
|
||||
id VARCHAR(36) PRIMARY KEY,
|
||||
seed_url_id VARCHAR(36) REFERENCES zeugnis_seed_urls(id),
|
||||
title VARCHAR(500),
|
||||
url TEXT NOT NULL,
|
||||
content_hash VARCHAR(64),
|
||||
minio_path TEXT,
|
||||
training_allowed BOOLEAN DEFAULT FALSE,
|
||||
indexed_in_qdrant BOOLEAN DEFAULT FALSE,
|
||||
file_size INTEGER,
|
||||
content_type VARCHAR(100),
|
||||
created_at TIMESTAMP DEFAULT NOW(),
|
||||
updated_at TIMESTAMP DEFAULT NOW()
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_zeugnis_documents_seed ON zeugnis_documents(seed_url_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_zeugnis_documents_hash ON zeugnis_documents(content_hash);
|
||||
|
||||
-- Zeugnisse Document Versions
|
||||
CREATE TABLE IF NOT EXISTS zeugnis_document_versions (
|
||||
id VARCHAR(36) PRIMARY KEY,
|
||||
document_id VARCHAR(36) REFERENCES zeugnis_documents(id),
|
||||
version INTEGER NOT NULL,
|
||||
content_hash VARCHAR(64),
|
||||
minio_path TEXT,
|
||||
change_summary TEXT,
|
||||
created_at TIMESTAMP DEFAULT NOW()
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_zeugnis_versions_doc ON zeugnis_document_versions(document_id);
|
||||
|
||||
-- Zeugnisse Usage Events (Audit Trail)
|
||||
CREATE TABLE IF NOT EXISTS zeugnis_usage_events (
|
||||
id VARCHAR(36) PRIMARY KEY,
|
||||
document_id VARCHAR(36) REFERENCES zeugnis_documents(id),
|
||||
event_type VARCHAR(50) NOT NULL,
|
||||
user_id VARCHAR(100),
|
||||
details JSONB,
|
||||
created_at TIMESTAMP DEFAULT NOW()
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_zeugnis_events_doc ON zeugnis_usage_events(document_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_zeugnis_events_type ON zeugnis_usage_events(event_type);
|
||||
CREATE INDEX IF NOT EXISTS idx_zeugnis_events_created ON zeugnis_usage_events(created_at);
|
||||
|
||||
-- Crawler Queue
|
||||
CREATE TABLE IF NOT EXISTS zeugnis_crawler_queue (
|
||||
id VARCHAR(36) PRIMARY KEY,
|
||||
source_id VARCHAR(36) REFERENCES zeugnis_sources(id),
|
||||
priority INTEGER DEFAULT 5,
|
||||
status VARCHAR(20) DEFAULT 'pending',
|
||||
started_at TIMESTAMP,
|
||||
completed_at TIMESTAMP,
|
||||
documents_found INTEGER DEFAULT 0,
|
||||
documents_indexed INTEGER DEFAULT 0,
|
||||
error_count INTEGER DEFAULT 0,
|
||||
created_at TIMESTAMP DEFAULT NOW()
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_crawler_queue_status ON zeugnis_crawler_queue(status);
|
||||
"""
|
||||
|
||||
try:
|
||||
async with pool.acquire() as conn:
|
||||
await conn.execute(create_tables_sql)
|
||||
print("RAG metrics tables initialized")
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"Failed to initialize metrics tables: {e}")
|
||||
return False
|
||||
@@ -0,0 +1,193 @@
|
||||
"""
|
||||
PostgreSQL Metrics Database - Zeugnis Operations
|
||||
|
||||
Zeugnis source management, document queries, statistics, and event logging.
|
||||
|
||||
Extracted from metrics_db.py to keep files under 500 LOC.
|
||||
"""
|
||||
|
||||
from typing import Optional, List, Dict
|
||||
|
||||
from metrics_db_core import get_pool
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Zeugnis Database Operations
|
||||
# =============================================================================
|
||||
|
||||
async def get_zeugnis_sources() -> List[Dict]:
|
||||
"""Get all zeugnis sources (Bundeslaender)."""
|
||||
pool = await get_pool()
|
||||
if pool is None:
|
||||
return []
|
||||
|
||||
try:
|
||||
async with pool.acquire() as conn:
|
||||
rows = await conn.fetch(
|
||||
"""
|
||||
SELECT id, bundesland, name, base_url, license_type, training_allowed,
|
||||
verified_by, verified_at, created_at, updated_at
|
||||
FROM zeugnis_sources
|
||||
ORDER BY bundesland
|
||||
"""
|
||||
)
|
||||
return [dict(r) for r in rows]
|
||||
except Exception as e:
|
||||
print(f"Failed to get zeugnis sources: {e}")
|
||||
return []
|
||||
|
||||
|
||||
async def upsert_zeugnis_source(
|
||||
id: str,
|
||||
bundesland: str,
|
||||
name: str,
|
||||
license_type: str,
|
||||
training_allowed: bool,
|
||||
base_url: Optional[str] = None,
|
||||
verified_by: Optional[str] = None,
|
||||
) -> bool:
|
||||
"""Insert or update a zeugnis source."""
|
||||
pool = await get_pool()
|
||||
if pool is None:
|
||||
return False
|
||||
|
||||
try:
|
||||
async with pool.acquire() as conn:
|
||||
await conn.execute(
|
||||
"""
|
||||
INSERT INTO zeugnis_sources (id, bundesland, name, base_url, license_type, training_allowed, verified_by, verified_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, NOW())
|
||||
ON CONFLICT (id) DO UPDATE SET
|
||||
name = EXCLUDED.name,
|
||||
base_url = EXCLUDED.base_url,
|
||||
license_type = EXCLUDED.license_type,
|
||||
training_allowed = EXCLUDED.training_allowed,
|
||||
verified_by = EXCLUDED.verified_by,
|
||||
verified_at = NOW(),
|
||||
updated_at = NOW()
|
||||
""",
|
||||
id, bundesland, name, base_url, license_type, training_allowed, verified_by
|
||||
)
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"Failed to upsert zeugnis source: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def get_zeugnis_documents(
|
||||
bundesland: Optional[str] = None,
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
) -> List[Dict]:
|
||||
"""Get zeugnis documents with optional filtering."""
|
||||
pool = await get_pool()
|
||||
if pool is None:
|
||||
return []
|
||||
|
||||
try:
|
||||
async with pool.acquire() as conn:
|
||||
if bundesland:
|
||||
rows = await conn.fetch(
|
||||
"""
|
||||
SELECT d.*, s.bundesland, s.name as source_name
|
||||
FROM zeugnis_documents d
|
||||
JOIN zeugnis_seed_urls u ON d.seed_url_id = u.id
|
||||
JOIN zeugnis_sources s ON u.source_id = s.id
|
||||
WHERE s.bundesland = $1
|
||||
ORDER BY d.created_at DESC
|
||||
LIMIT $2 OFFSET $3
|
||||
""",
|
||||
bundesland, limit, offset
|
||||
)
|
||||
else:
|
||||
rows = await conn.fetch(
|
||||
"""
|
||||
SELECT d.*, s.bundesland, s.name as source_name
|
||||
FROM zeugnis_documents d
|
||||
JOIN zeugnis_seed_urls u ON d.seed_url_id = u.id
|
||||
JOIN zeugnis_sources s ON u.source_id = s.id
|
||||
ORDER BY d.created_at DESC
|
||||
LIMIT $1 OFFSET $2
|
||||
""",
|
||||
limit, offset
|
||||
)
|
||||
return [dict(r) for r in rows]
|
||||
except Exception as e:
|
||||
print(f"Failed to get zeugnis documents: {e}")
|
||||
return []
|
||||
|
||||
|
||||
async def get_zeugnis_stats() -> Dict:
|
||||
"""Get zeugnis crawler statistics."""
|
||||
pool = await get_pool()
|
||||
if pool is None:
|
||||
return {"error": "Database not available"}
|
||||
|
||||
try:
|
||||
async with pool.acquire() as conn:
|
||||
sources = await conn.fetchval("SELECT COUNT(*) FROM zeugnis_sources")
|
||||
documents = await conn.fetchval("SELECT COUNT(*) FROM zeugnis_documents")
|
||||
|
||||
indexed = await conn.fetchval(
|
||||
"SELECT COUNT(*) FROM zeugnis_documents WHERE indexed_in_qdrant = true"
|
||||
)
|
||||
|
||||
training_allowed = await conn.fetchval(
|
||||
"SELECT COUNT(*) FROM zeugnis_documents WHERE training_allowed = true"
|
||||
)
|
||||
|
||||
per_bundesland = await conn.fetch(
|
||||
"""
|
||||
SELECT s.bundesland, s.name, s.training_allowed, COUNT(d.id) as doc_count
|
||||
FROM zeugnis_sources s
|
||||
LEFT JOIN zeugnis_seed_urls u ON s.id = u.source_id
|
||||
LEFT JOIN zeugnis_documents d ON u.id = d.seed_url_id
|
||||
GROUP BY s.bundesland, s.name, s.training_allowed
|
||||
ORDER BY s.bundesland
|
||||
"""
|
||||
)
|
||||
|
||||
active_crawls = await conn.fetchval(
|
||||
"SELECT COUNT(*) FROM zeugnis_crawler_queue WHERE status = 'running'"
|
||||
)
|
||||
|
||||
return {
|
||||
"total_sources": sources or 0,
|
||||
"total_documents": documents or 0,
|
||||
"indexed_documents": indexed or 0,
|
||||
"training_allowed_documents": training_allowed or 0,
|
||||
"active_crawls": active_crawls or 0,
|
||||
"per_bundesland": [dict(r) for r in per_bundesland],
|
||||
}
|
||||
except Exception as e:
|
||||
print(f"Failed to get zeugnis stats: {e}")
|
||||
return {"error": str(e)}
|
||||
|
||||
|
||||
async def log_zeugnis_event(
|
||||
document_id: str,
|
||||
event_type: str,
|
||||
user_id: Optional[str] = None,
|
||||
details: Optional[Dict] = None,
|
||||
) -> bool:
|
||||
"""Log a zeugnis usage event for audit trail."""
|
||||
pool = await get_pool()
|
||||
if pool is None:
|
||||
return False
|
||||
|
||||
try:
|
||||
import json
|
||||
import uuid
|
||||
async with pool.acquire() as conn:
|
||||
await conn.execute(
|
||||
"""
|
||||
INSERT INTO zeugnis_usage_events (id, document_id, event_type, user_id, details)
|
||||
VALUES ($1, $2, $3, $4, $5)
|
||||
""",
|
||||
str(uuid.uuid4()), document_id, event_type, user_id,
|
||||
json.dumps(details) if details else None
|
||||
)
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"Failed to log zeugnis event: {e}")
|
||||
return False
|
||||
@@ -1,845 +1,81 @@
|
||||
"""
|
||||
OCR Labeling API for Handwriting Training Data Collection
|
||||
OCR Labeling API — Barrel Re-export
|
||||
|
||||
DATENSCHUTZ/PRIVACY:
|
||||
- Alle Verarbeitung erfolgt lokal (Mac Mini mit Ollama)
|
||||
- Keine Daten werden an externe Server gesendet
|
||||
- Bilder werden mit SHA256-Hash dedupliziert
|
||||
- Export nur für lokales Fine-Tuning (TrOCR, llama3.2-vision)
|
||||
Split into:
|
||||
- ocr_labeling_models.py — Pydantic models and constants
|
||||
- ocr_labeling_helpers.py — OCR wrappers, image storage, hashing
|
||||
- ocr_labeling_routes.py — Session/queue/labeling route handlers
|
||||
- ocr_labeling_upload_routes.py — Upload, run-OCR, export route handlers
|
||||
|
||||
Endpoints:
|
||||
- POST /sessions - Create labeling session
|
||||
- POST /sessions/{id}/upload - Upload images for labeling
|
||||
- GET /queue - Get next items to label
|
||||
- POST /confirm - Confirm OCR as correct
|
||||
- POST /correct - Save corrected ground truth
|
||||
- POST /skip - Skip unusable item
|
||||
- GET /stats - Get labeling statistics
|
||||
- POST /export - Export training data
|
||||
All public names are re-exported here for backward compatibility.
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, HTTPException, UploadFile, File, Form, Query, BackgroundTasks
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional, List, Dict, Any
|
||||
from datetime import datetime
|
||||
import uuid
|
||||
import hashlib
|
||||
import os
|
||||
import base64
|
||||
|
||||
# Import database functions
|
||||
from metrics_db import (
|
||||
create_ocr_labeling_session,
|
||||
get_ocr_labeling_sessions,
|
||||
get_ocr_labeling_session,
|
||||
add_ocr_labeling_item,
|
||||
get_ocr_labeling_queue,
|
||||
get_ocr_labeling_item,
|
||||
confirm_ocr_label,
|
||||
correct_ocr_label,
|
||||
skip_ocr_item,
|
||||
get_ocr_labeling_stats,
|
||||
export_training_samples,
|
||||
get_training_samples,
|
||||
# Models
|
||||
from ocr_labeling_models import ( # noqa: F401
|
||||
LOCAL_STORAGE_PATH,
|
||||
SessionCreate,
|
||||
SessionResponse,
|
||||
ItemResponse,
|
||||
ConfirmRequest,
|
||||
CorrectRequest,
|
||||
SkipRequest,
|
||||
ExportRequest,
|
||||
StatsResponse,
|
||||
)
|
||||
|
||||
# Try to import Vision OCR service
|
||||
try:
|
||||
import sys
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..', 'backend', 'klausur', 'services'))
|
||||
from vision_ocr_service import get_vision_ocr_service, VisionOCRService
|
||||
VISION_OCR_AVAILABLE = True
|
||||
except ImportError:
|
||||
VISION_OCR_AVAILABLE = False
|
||||
print("Warning: Vision OCR service not available")
|
||||
# Helpers
|
||||
from ocr_labeling_helpers import ( # noqa: F401
|
||||
VISION_OCR_AVAILABLE,
|
||||
PADDLEOCR_AVAILABLE,
|
||||
TROCR_AVAILABLE,
|
||||
DONUT_AVAILABLE,
|
||||
MINIO_AVAILABLE,
|
||||
TRAINING_EXPORT_AVAILABLE,
|
||||
compute_image_hash,
|
||||
run_ocr_on_image,
|
||||
run_vision_ocr_wrapper,
|
||||
run_paddleocr_wrapper,
|
||||
run_trocr_wrapper,
|
||||
run_donut_wrapper,
|
||||
save_image_locally,
|
||||
get_image_url,
|
||||
)
|
||||
|
||||
# Try to import PaddleOCR from hybrid_vocab_extractor
|
||||
# Conditional re-exports from helpers' optional imports
|
||||
try:
|
||||
from hybrid_vocab_extractor import run_paddle_ocr
|
||||
PADDLEOCR_AVAILABLE = True
|
||||
from minio_storage import upload_ocr_image, get_ocr_image, MINIO_BUCKET # noqa: F401
|
||||
except ImportError:
|
||||
PADDLEOCR_AVAILABLE = False
|
||||
print("Warning: PaddleOCR not available")
|
||||
pass
|
||||
|
||||
# Try to import TrOCR service
|
||||
try:
|
||||
from services.trocr_service import run_trocr_ocr
|
||||
TROCR_AVAILABLE = True
|
||||
except ImportError:
|
||||
TROCR_AVAILABLE = False
|
||||
print("Warning: TrOCR service not available")
|
||||
|
||||
# Try to import Donut service
|
||||
try:
|
||||
from services.donut_ocr_service import run_donut_ocr
|
||||
DONUT_AVAILABLE = True
|
||||
except ImportError:
|
||||
DONUT_AVAILABLE = False
|
||||
print("Warning: Donut OCR service not available")
|
||||
|
||||
# Try to import MinIO storage
|
||||
try:
|
||||
from minio_storage import upload_ocr_image, get_ocr_image, MINIO_BUCKET
|
||||
MINIO_AVAILABLE = True
|
||||
except ImportError:
|
||||
MINIO_AVAILABLE = False
|
||||
print("Warning: MinIO storage not available, using local storage")
|
||||
|
||||
# Try to import Training Export Service
|
||||
try:
|
||||
from training_export_service import (
|
||||
from training_export_service import ( # noqa: F401
|
||||
TrainingExportService,
|
||||
TrainingSample,
|
||||
get_training_export_service,
|
||||
)
|
||||
TRAINING_EXPORT_AVAILABLE = True
|
||||
except ImportError:
|
||||
TRAINING_EXPORT_AVAILABLE = False
|
||||
print("Warning: Training export service not available")
|
||||
|
||||
router = APIRouter(prefix="/api/v1/ocr-label", tags=["OCR Labeling"])
|
||||
|
||||
# Local storage path (fallback if MinIO not available)
|
||||
LOCAL_STORAGE_PATH = os.getenv("OCR_STORAGE_PATH", "/app/ocr-labeling")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Pydantic Models
|
||||
# =============================================================================
|
||||
|
||||
class SessionCreate(BaseModel):
|
||||
name: str
|
||||
source_type: str = "klausur" # klausur, handwriting_sample, scan
|
||||
description: Optional[str] = None
|
||||
ocr_model: Optional[str] = "llama3.2-vision:11b"
|
||||
|
||||
|
||||
class SessionResponse(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
source_type: str
|
||||
description: Optional[str]
|
||||
ocr_model: Optional[str]
|
||||
total_items: int
|
||||
labeled_items: int
|
||||
confirmed_items: int
|
||||
corrected_items: int
|
||||
skipped_items: int
|
||||
created_at: datetime
|
||||
|
||||
|
||||
class ItemResponse(BaseModel):
|
||||
id: str
|
||||
session_id: str
|
||||
session_name: str
|
||||
image_path: str
|
||||
image_url: Optional[str]
|
||||
ocr_text: Optional[str]
|
||||
ocr_confidence: Optional[float]
|
||||
ground_truth: Optional[str]
|
||||
status: str
|
||||
metadata: Optional[Dict]
|
||||
created_at: datetime
|
||||
|
||||
|
||||
class ConfirmRequest(BaseModel):
|
||||
item_id: str
|
||||
label_time_seconds: Optional[int] = None
|
||||
|
||||
|
||||
class CorrectRequest(BaseModel):
|
||||
item_id: str
|
||||
ground_truth: str
|
||||
label_time_seconds: Optional[int] = None
|
||||
|
||||
|
||||
class SkipRequest(BaseModel):
|
||||
item_id: str
|
||||
|
||||
|
||||
class ExportRequest(BaseModel):
|
||||
export_format: str = "generic" # generic, trocr, llama_vision
|
||||
session_id: Optional[str] = None
|
||||
batch_id: Optional[str] = None
|
||||
|
||||
|
||||
class StatsResponse(BaseModel):
|
||||
total_sessions: Optional[int] = None
|
||||
total_items: int
|
||||
labeled_items: int
|
||||
confirmed_items: int
|
||||
corrected_items: int
|
||||
pending_items: int
|
||||
exportable_items: Optional[int] = None
|
||||
accuracy_rate: float
|
||||
avg_label_time_seconds: Optional[float] = None
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Helper Functions
|
||||
# =============================================================================
|
||||
|
||||
def compute_image_hash(image_data: bytes) -> str:
|
||||
"""Compute SHA256 hash of image data."""
|
||||
return hashlib.sha256(image_data).hexdigest()
|
||||
|
||||
|
||||
async def run_ocr_on_image(image_data: bytes, filename: str, model: str = "llama3.2-vision:11b") -> tuple:
|
||||
"""
|
||||
Run OCR on an image using the specified model.
|
||||
|
||||
Models:
|
||||
- llama3.2-vision:11b: Vision LLM (default, best for handwriting)
|
||||
- trocr: Microsoft TrOCR (fast for printed text)
|
||||
- paddleocr: PaddleOCR + LLM hybrid (4x faster)
|
||||
- donut: Document Understanding Transformer (structured documents)
|
||||
|
||||
Returns:
|
||||
Tuple of (ocr_text, confidence)
|
||||
"""
|
||||
print(f"Running OCR with model: {model}")
|
||||
|
||||
# Route to appropriate OCR service based on model
|
||||
if model == "paddleocr":
|
||||
return await run_paddleocr_wrapper(image_data, filename)
|
||||
elif model == "donut":
|
||||
return await run_donut_wrapper(image_data, filename)
|
||||
elif model == "trocr":
|
||||
return await run_trocr_wrapper(image_data, filename)
|
||||
else:
|
||||
# Default: Vision LLM (llama3.2-vision or similar)
|
||||
return await run_vision_ocr_wrapper(image_data, filename)
|
||||
|
||||
|
||||
async def run_vision_ocr_wrapper(image_data: bytes, filename: str) -> tuple:
|
||||
"""Vision LLM OCR wrapper."""
|
||||
if not VISION_OCR_AVAILABLE:
|
||||
print("Vision OCR service not available")
|
||||
return None, 0.0
|
||||
|
||||
try:
|
||||
service = get_vision_ocr_service()
|
||||
if not await service.is_available():
|
||||
print("Vision OCR service not available (is_available check failed)")
|
||||
return None, 0.0
|
||||
|
||||
result = await service.extract_text(
|
||||
image_data,
|
||||
filename=filename,
|
||||
is_handwriting=True
|
||||
)
|
||||
return result.text, result.confidence
|
||||
except Exception as e:
|
||||
print(f"Vision OCR failed: {e}")
|
||||
return None, 0.0
|
||||
|
||||
|
||||
async def run_paddleocr_wrapper(image_data: bytes, filename: str) -> tuple:
|
||||
"""PaddleOCR wrapper - uses hybrid_vocab_extractor."""
|
||||
if not PADDLEOCR_AVAILABLE:
|
||||
print("PaddleOCR not available, falling back to Vision OCR")
|
||||
return await run_vision_ocr_wrapper(image_data, filename)
|
||||
|
||||
try:
|
||||
# run_paddle_ocr returns (regions, raw_text)
|
||||
regions, raw_text = run_paddle_ocr(image_data)
|
||||
|
||||
if not raw_text:
|
||||
print("PaddleOCR returned empty text")
|
||||
return None, 0.0
|
||||
|
||||
# Calculate average confidence from regions
|
||||
if regions:
|
||||
avg_confidence = sum(r.confidence for r in regions) / len(regions)
|
||||
else:
|
||||
avg_confidence = 0.5
|
||||
|
||||
return raw_text, avg_confidence
|
||||
except Exception as e:
|
||||
print(f"PaddleOCR failed: {e}, falling back to Vision OCR")
|
||||
return await run_vision_ocr_wrapper(image_data, filename)
|
||||
|
||||
|
||||
async def run_trocr_wrapper(image_data: bytes, filename: str) -> tuple:
|
||||
"""TrOCR wrapper."""
|
||||
if not TROCR_AVAILABLE:
|
||||
print("TrOCR not available, falling back to Vision OCR")
|
||||
return await run_vision_ocr_wrapper(image_data, filename)
|
||||
|
||||
try:
|
||||
text, confidence = await run_trocr_ocr(image_data)
|
||||
return text, confidence
|
||||
except Exception as e:
|
||||
print(f"TrOCR failed: {e}, falling back to Vision OCR")
|
||||
return await run_vision_ocr_wrapper(image_data, filename)
|
||||
|
||||
|
||||
async def run_donut_wrapper(image_data: bytes, filename: str) -> tuple:
|
||||
"""Donut OCR wrapper."""
|
||||
if not DONUT_AVAILABLE:
|
||||
print("Donut not available, falling back to Vision OCR")
|
||||
return await run_vision_ocr_wrapper(image_data, filename)
|
||||
|
||||
try:
|
||||
text, confidence = await run_donut_ocr(image_data)
|
||||
return text, confidence
|
||||
except Exception as e:
|
||||
print(f"Donut OCR failed: {e}, falling back to Vision OCR")
|
||||
return await run_vision_ocr_wrapper(image_data, filename)
|
||||
|
||||
|
||||
def save_image_locally(session_id: str, item_id: str, image_data: bytes, extension: str = "png") -> str:
|
||||
"""Save image to local storage."""
|
||||
session_dir = os.path.join(LOCAL_STORAGE_PATH, session_id)
|
||||
os.makedirs(session_dir, exist_ok=True)
|
||||
|
||||
filename = f"{item_id}.{extension}"
|
||||
filepath = os.path.join(session_dir, filename)
|
||||
|
||||
with open(filepath, 'wb') as f:
|
||||
f.write(image_data)
|
||||
|
||||
return filepath
|
||||
|
||||
|
||||
def get_image_url(image_path: str) -> str:
|
||||
"""Get URL for an image."""
|
||||
# For local images, return a relative path that the frontend can use
|
||||
if image_path.startswith(LOCAL_STORAGE_PATH):
|
||||
relative_path = image_path[len(LOCAL_STORAGE_PATH):].lstrip('/')
|
||||
return f"/api/v1/ocr-label/images/{relative_path}"
|
||||
# For MinIO images, the path is already a URL or key
|
||||
return image_path
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# API Endpoints
|
||||
# =============================================================================
|
||||
|
||||
@router.post("/sessions", response_model=SessionResponse)
|
||||
async def create_session(session: SessionCreate):
|
||||
"""
|
||||
Create a new OCR labeling session.
|
||||
|
||||
A session groups related images for labeling (e.g., all scans from one class).
|
||||
"""
|
||||
session_id = str(uuid.uuid4())
|
||||
|
||||
success = await create_ocr_labeling_session(
|
||||
session_id=session_id,
|
||||
name=session.name,
|
||||
source_type=session.source_type,
|
||||
description=session.description,
|
||||
ocr_model=session.ocr_model,
|
||||
)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(status_code=500, detail="Failed to create session")
|
||||
|
||||
return SessionResponse(
|
||||
id=session_id,
|
||||
name=session.name,
|
||||
source_type=session.source_type,
|
||||
description=session.description,
|
||||
ocr_model=session.ocr_model,
|
||||
total_items=0,
|
||||
labeled_items=0,
|
||||
confirmed_items=0,
|
||||
corrected_items=0,
|
||||
skipped_items=0,
|
||||
created_at=datetime.utcnow(),
|
||||
)
|
||||
|
||||
|
||||
@router.get("/sessions", response_model=List[SessionResponse])
|
||||
async def list_sessions(limit: int = Query(50, ge=1, le=100)):
|
||||
"""List all OCR labeling sessions."""
|
||||
sessions = await get_ocr_labeling_sessions(limit=limit)
|
||||
|
||||
return [
|
||||
SessionResponse(
|
||||
id=s['id'],
|
||||
name=s['name'],
|
||||
source_type=s['source_type'],
|
||||
description=s.get('description'),
|
||||
ocr_model=s.get('ocr_model'),
|
||||
total_items=s.get('total_items', 0),
|
||||
labeled_items=s.get('labeled_items', 0),
|
||||
confirmed_items=s.get('confirmed_items', 0),
|
||||
corrected_items=s.get('corrected_items', 0),
|
||||
skipped_items=s.get('skipped_items', 0),
|
||||
created_at=s.get('created_at', datetime.utcnow()),
|
||||
)
|
||||
for s in sessions
|
||||
]
|
||||
|
||||
|
||||
@router.get("/sessions/{session_id}", response_model=SessionResponse)
|
||||
async def get_session(session_id: str):
|
||||
"""Get a specific OCR labeling session."""
|
||||
session = await get_ocr_labeling_session(session_id)
|
||||
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail="Session not found")
|
||||
|
||||
return SessionResponse(
|
||||
id=session['id'],
|
||||
name=session['name'],
|
||||
source_type=session['source_type'],
|
||||
description=session.get('description'),
|
||||
ocr_model=session.get('ocr_model'),
|
||||
total_items=session.get('total_items', 0),
|
||||
labeled_items=session.get('labeled_items', 0),
|
||||
confirmed_items=session.get('confirmed_items', 0),
|
||||
corrected_items=session.get('corrected_items', 0),
|
||||
skipped_items=session.get('skipped_items', 0),
|
||||
created_at=session.get('created_at', datetime.utcnow()),
|
||||
)
|
||||
|
||||
|
||||
@router.post("/sessions/{session_id}/upload")
|
||||
async def upload_images(
|
||||
session_id: str,
|
||||
background_tasks: BackgroundTasks,
|
||||
files: List[UploadFile] = File(...),
|
||||
run_ocr: bool = Form(True),
|
||||
metadata: Optional[str] = Form(None), # JSON string
|
||||
):
|
||||
"""
|
||||
Upload images to a labeling session.
|
||||
|
||||
Args:
|
||||
session_id: Session to add images to
|
||||
files: Image files to upload (PNG, JPG, PDF)
|
||||
run_ocr: Whether to run OCR immediately (default: True)
|
||||
metadata: Optional JSON metadata (subject, year, etc.)
|
||||
"""
|
||||
import json
|
||||
|
||||
# Verify session exists
|
||||
session = await get_ocr_labeling_session(session_id)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail="Session not found")
|
||||
|
||||
# Parse metadata
|
||||
meta_dict = None
|
||||
if metadata:
|
||||
try:
|
||||
meta_dict = json.loads(metadata)
|
||||
except json.JSONDecodeError:
|
||||
meta_dict = {"raw": metadata}
|
||||
|
||||
results = []
|
||||
ocr_model = session.get('ocr_model', 'llama3.2-vision:11b')
|
||||
|
||||
for file in files:
|
||||
# Read file content
|
||||
content = await file.read()
|
||||
|
||||
# Compute hash for deduplication
|
||||
image_hash = compute_image_hash(content)
|
||||
|
||||
# Generate item ID
|
||||
item_id = str(uuid.uuid4())
|
||||
|
||||
# Determine file extension
|
||||
extension = file.filename.split('.')[-1].lower() if file.filename else 'png'
|
||||
if extension not in ['png', 'jpg', 'jpeg', 'pdf']:
|
||||
extension = 'png'
|
||||
|
||||
# Save image
|
||||
if MINIO_AVAILABLE:
|
||||
# Upload to MinIO
|
||||
try:
|
||||
image_path = upload_ocr_image(session_id, item_id, content, extension)
|
||||
except Exception as e:
|
||||
print(f"MinIO upload failed, using local storage: {e}")
|
||||
image_path = save_image_locally(session_id, item_id, content, extension)
|
||||
else:
|
||||
# Save locally
|
||||
image_path = save_image_locally(session_id, item_id, content, extension)
|
||||
|
||||
# Run OCR if requested
|
||||
ocr_text = None
|
||||
ocr_confidence = None
|
||||
|
||||
if run_ocr and extension != 'pdf': # Skip OCR for PDFs for now
|
||||
ocr_text, ocr_confidence = await run_ocr_on_image(
|
||||
content,
|
||||
file.filename or f"{item_id}.{extension}",
|
||||
model=ocr_model
|
||||
)
|
||||
|
||||
# Add to database
|
||||
success = await add_ocr_labeling_item(
|
||||
item_id=item_id,
|
||||
session_id=session_id,
|
||||
image_path=image_path,
|
||||
image_hash=image_hash,
|
||||
ocr_text=ocr_text,
|
||||
ocr_confidence=ocr_confidence,
|
||||
ocr_model=ocr_model if ocr_text else None,
|
||||
metadata=meta_dict,
|
||||
)
|
||||
|
||||
if success:
|
||||
results.append({
|
||||
"id": item_id,
|
||||
"filename": file.filename,
|
||||
"image_path": image_path,
|
||||
"image_hash": image_hash,
|
||||
"ocr_text": ocr_text,
|
||||
"ocr_confidence": ocr_confidence,
|
||||
"status": "pending",
|
||||
})
|
||||
|
||||
return {
|
||||
"session_id": session_id,
|
||||
"uploaded_count": len(results),
|
||||
"items": results,
|
||||
}
|
||||
|
||||
|
||||
@router.get("/queue", response_model=List[ItemResponse])
|
||||
async def get_labeling_queue(
|
||||
session_id: Optional[str] = Query(None),
|
||||
status: str = Query("pending"),
|
||||
limit: int = Query(10, ge=1, le=50),
|
||||
):
|
||||
"""
|
||||
Get items from the labeling queue.
|
||||
|
||||
Args:
|
||||
session_id: Optional filter by session
|
||||
status: Filter by status (pending, confirmed, corrected, skipped)
|
||||
limit: Number of items to return
|
||||
"""
|
||||
items = await get_ocr_labeling_queue(
|
||||
session_id=session_id,
|
||||
status=status,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
return [
|
||||
ItemResponse(
|
||||
id=item['id'],
|
||||
session_id=item['session_id'],
|
||||
session_name=item.get('session_name', ''),
|
||||
image_path=item['image_path'],
|
||||
image_url=get_image_url(item['image_path']),
|
||||
ocr_text=item.get('ocr_text'),
|
||||
ocr_confidence=item.get('ocr_confidence'),
|
||||
ground_truth=item.get('ground_truth'),
|
||||
status=item.get('status', 'pending'),
|
||||
metadata=item.get('metadata'),
|
||||
created_at=item.get('created_at', datetime.utcnow()),
|
||||
)
|
||||
for item in items
|
||||
]
|
||||
|
||||
|
||||
@router.get("/items/{item_id}", response_model=ItemResponse)
|
||||
async def get_item(item_id: str):
|
||||
"""Get a specific labeling item."""
|
||||
item = await get_ocr_labeling_item(item_id)
|
||||
|
||||
if not item:
|
||||
raise HTTPException(status_code=404, detail="Item not found")
|
||||
|
||||
return ItemResponse(
|
||||
id=item['id'],
|
||||
session_id=item['session_id'],
|
||||
session_name=item.get('session_name', ''),
|
||||
image_path=item['image_path'],
|
||||
image_url=get_image_url(item['image_path']),
|
||||
ocr_text=item.get('ocr_text'),
|
||||
ocr_confidence=item.get('ocr_confidence'),
|
||||
ground_truth=item.get('ground_truth'),
|
||||
status=item.get('status', 'pending'),
|
||||
metadata=item.get('metadata'),
|
||||
created_at=item.get('created_at', datetime.utcnow()),
|
||||
)
|
||||
|
||||
|
||||
@router.post("/confirm")
|
||||
async def confirm_item(request: ConfirmRequest):
|
||||
"""
|
||||
Confirm that OCR text is correct.
|
||||
|
||||
Sets ground_truth = ocr_text and marks item as confirmed.
|
||||
"""
|
||||
success = await confirm_ocr_label(
|
||||
item_id=request.item_id,
|
||||
labeled_by="admin", # TODO: Get from auth
|
||||
label_time_seconds=request.label_time_seconds,
|
||||
)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(status_code=400, detail="Failed to confirm item")
|
||||
|
||||
return {"status": "confirmed", "item_id": request.item_id}
|
||||
|
||||
|
||||
@router.post("/correct")
|
||||
async def correct_item(request: CorrectRequest):
|
||||
"""
|
||||
Save corrected ground truth for an item.
|
||||
|
||||
Use this when OCR text is wrong and needs manual correction.
|
||||
"""
|
||||
success = await correct_ocr_label(
|
||||
item_id=request.item_id,
|
||||
ground_truth=request.ground_truth,
|
||||
labeled_by="admin", # TODO: Get from auth
|
||||
label_time_seconds=request.label_time_seconds,
|
||||
)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(status_code=400, detail="Failed to correct item")
|
||||
|
||||
return {"status": "corrected", "item_id": request.item_id}
|
||||
|
||||
|
||||
@router.post("/skip")
|
||||
async def skip_item(request: SkipRequest):
|
||||
"""
|
||||
Skip an item (unusable image, etc.).
|
||||
|
||||
Skipped items are not included in training exports.
|
||||
"""
|
||||
success = await skip_ocr_item(
|
||||
item_id=request.item_id,
|
||||
labeled_by="admin", # TODO: Get from auth
|
||||
)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(status_code=400, detail="Failed to skip item")
|
||||
|
||||
return {"status": "skipped", "item_id": request.item_id}
|
||||
|
||||
|
||||
@router.get("/stats")
|
||||
async def get_stats(session_id: Optional[str] = Query(None)):
|
||||
"""
|
||||
Get labeling statistics.
|
||||
|
||||
Args:
|
||||
session_id: Optional session ID for session-specific stats
|
||||
"""
|
||||
stats = await get_ocr_labeling_stats(session_id=session_id)
|
||||
|
||||
if "error" in stats:
|
||||
raise HTTPException(status_code=500, detail=stats["error"])
|
||||
|
||||
return stats
|
||||
|
||||
|
||||
@router.post("/export")
|
||||
async def export_data(request: ExportRequest):
|
||||
"""
|
||||
Export labeled data for training.
|
||||
|
||||
Formats:
|
||||
- generic: JSONL with image_path and ground_truth
|
||||
- trocr: Format for TrOCR/Microsoft Transformer fine-tuning
|
||||
- llama_vision: Format for llama3.2-vision fine-tuning
|
||||
|
||||
Exports are saved to disk at /app/ocr-exports/{format}/{batch_id}/
|
||||
"""
|
||||
# First, get samples from database
|
||||
db_samples = await export_training_samples(
|
||||
export_format=request.export_format,
|
||||
session_id=request.session_id,
|
||||
batch_id=request.batch_id,
|
||||
exported_by="admin", # TODO: Get from auth
|
||||
)
|
||||
|
||||
if not db_samples:
|
||||
return {
|
||||
"export_format": request.export_format,
|
||||
"batch_id": request.batch_id,
|
||||
"exported_count": 0,
|
||||
"samples": [],
|
||||
"message": "No labeled samples found to export",
|
||||
}
|
||||
|
||||
# If training export service is available, also write to disk
|
||||
export_result = None
|
||||
if TRAINING_EXPORT_AVAILABLE:
|
||||
try:
|
||||
export_service = get_training_export_service()
|
||||
|
||||
# Convert DB samples to TrainingSample objects
|
||||
training_samples = []
|
||||
for s in db_samples:
|
||||
training_samples.append(TrainingSample(
|
||||
id=s.get('id', s.get('item_id', '')),
|
||||
image_path=s.get('image_path', ''),
|
||||
ground_truth=s.get('ground_truth', ''),
|
||||
ocr_text=s.get('ocr_text'),
|
||||
ocr_confidence=s.get('ocr_confidence'),
|
||||
metadata=s.get('metadata'),
|
||||
))
|
||||
|
||||
# Export to files
|
||||
export_result = export_service.export(
|
||||
samples=training_samples,
|
||||
export_format=request.export_format,
|
||||
batch_id=request.batch_id,
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Training export failed: {e}")
|
||||
# Continue without file export
|
||||
|
||||
response = {
|
||||
"export_format": request.export_format,
|
||||
"batch_id": request.batch_id or (export_result.batch_id if export_result else None),
|
||||
"exported_count": len(db_samples),
|
||||
"samples": db_samples,
|
||||
}
|
||||
|
||||
if export_result:
|
||||
response["export_path"] = export_result.export_path
|
||||
response["manifest_path"] = export_result.manifest_path
|
||||
|
||||
return response
|
||||
|
||||
|
||||
@router.get("/training-samples")
|
||||
async def list_training_samples(
|
||||
export_format: Optional[str] = Query(None),
|
||||
batch_id: Optional[str] = Query(None),
|
||||
limit: int = Query(100, ge=1, le=1000),
|
||||
):
|
||||
"""Get exported training samples."""
|
||||
samples = await get_training_samples(
|
||||
export_format=export_format,
|
||||
batch_id=batch_id,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
return {
|
||||
"count": len(samples),
|
||||
"samples": samples,
|
||||
}
|
||||
|
||||
|
||||
@router.get("/images/{path:path}")
|
||||
async def get_image(path: str):
|
||||
"""
|
||||
Serve an image from local storage.
|
||||
|
||||
This endpoint is used when images are stored locally (not in MinIO).
|
||||
"""
|
||||
from fastapi.responses import FileResponse
|
||||
|
||||
filepath = os.path.join(LOCAL_STORAGE_PATH, path)
|
||||
|
||||
if not os.path.exists(filepath):
|
||||
raise HTTPException(status_code=404, detail="Image not found")
|
||||
|
||||
# Determine content type
|
||||
extension = filepath.split('.')[-1].lower()
|
||||
content_type = {
|
||||
'png': 'image/png',
|
||||
'jpg': 'image/jpeg',
|
||||
'jpeg': 'image/jpeg',
|
||||
'pdf': 'application/pdf',
|
||||
}.get(extension, 'application/octet-stream')
|
||||
|
||||
return FileResponse(filepath, media_type=content_type)
|
||||
|
||||
|
||||
@router.post("/run-ocr/{item_id}")
|
||||
async def run_ocr_for_item(item_id: str):
|
||||
"""
|
||||
Run OCR on an existing item.
|
||||
|
||||
Use this to re-run OCR or run it if it was skipped during upload.
|
||||
"""
|
||||
item = await get_ocr_labeling_item(item_id)
|
||||
|
||||
if not item:
|
||||
raise HTTPException(status_code=404, detail="Item not found")
|
||||
|
||||
# Load image
|
||||
image_path = item['image_path']
|
||||
|
||||
if image_path.startswith(LOCAL_STORAGE_PATH):
|
||||
# Load from local storage
|
||||
if not os.path.exists(image_path):
|
||||
raise HTTPException(status_code=404, detail="Image file not found")
|
||||
with open(image_path, 'rb') as f:
|
||||
image_data = f.read()
|
||||
elif MINIO_AVAILABLE:
|
||||
# Load from MinIO
|
||||
try:
|
||||
image_data = get_ocr_image(image_path)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to load image: {e}")
|
||||
else:
|
||||
raise HTTPException(status_code=500, detail="Cannot load image")
|
||||
|
||||
# Get OCR model from session
|
||||
session = await get_ocr_labeling_session(item['session_id'])
|
||||
ocr_model = session.get('ocr_model', 'llama3.2-vision:11b') if session else 'llama3.2-vision:11b'
|
||||
|
||||
# Run OCR
|
||||
ocr_text, ocr_confidence = await run_ocr_on_image(
|
||||
image_data,
|
||||
os.path.basename(image_path),
|
||||
model=ocr_model
|
||||
)
|
||||
|
||||
if ocr_text is None:
|
||||
raise HTTPException(status_code=500, detail="OCR failed")
|
||||
|
||||
# Update item in database
|
||||
from metrics_db import get_pool
|
||||
pool = await get_pool()
|
||||
if pool:
|
||||
async with pool.acquire() as conn:
|
||||
await conn.execute(
|
||||
"""
|
||||
UPDATE ocr_labeling_items
|
||||
SET ocr_text = $2, ocr_confidence = $3, ocr_model = $4
|
||||
WHERE id = $1
|
||||
""",
|
||||
item_id, ocr_text, ocr_confidence, ocr_model
|
||||
)
|
||||
|
||||
return {
|
||||
"item_id": item_id,
|
||||
"ocr_text": ocr_text,
|
||||
"ocr_confidence": ocr_confidence,
|
||||
"ocr_model": ocr_model,
|
||||
}
|
||||
|
||||
|
||||
@router.get("/exports")
|
||||
async def list_exports(export_format: Optional[str] = Query(None)):
|
||||
"""
|
||||
List all available training data exports.
|
||||
|
||||
Args:
|
||||
export_format: Optional filter by format (generic, trocr, llama_vision)
|
||||
|
||||
Returns:
|
||||
List of export manifests with paths and metadata
|
||||
"""
|
||||
if not TRAINING_EXPORT_AVAILABLE:
|
||||
return {
|
||||
"exports": [],
|
||||
"message": "Training export service not available",
|
||||
}
|
||||
|
||||
try:
|
||||
export_service = get_training_export_service()
|
||||
exports = export_service.list_exports(export_format=export_format)
|
||||
|
||||
return {
|
||||
"count": len(exports),
|
||||
"exports": exports,
|
||||
}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to list exports: {e}")
|
||||
pass
|
||||
|
||||
try:
|
||||
from hybrid_vocab_extractor import run_paddle_ocr # noqa: F401
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
try:
|
||||
from services.trocr_service import run_trocr_ocr # noqa: F401
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
try:
|
||||
from services.donut_ocr_service import run_donut_ocr # noqa: F401
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
try:
|
||||
from vision_ocr_service import get_vision_ocr_service, VisionOCRService # noqa: F401
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
# Routes (router is the main export for app.include_router)
|
||||
from ocr_labeling_routes import router # noqa: F401
|
||||
from ocr_labeling_upload_routes import router as upload_router # noqa: F401
|
||||
|
||||
@@ -0,0 +1,205 @@
|
||||
"""
|
||||
OCR Labeling - Helper Functions and OCR Wrappers
|
||||
|
||||
Extracted from ocr_labeling_api.py to keep files under 500 LOC.
|
||||
|
||||
DATENSCHUTZ/PRIVACY:
|
||||
- Alle Verarbeitung erfolgt lokal (Mac Mini mit Ollama)
|
||||
- Keine Daten werden an externe Server gesendet
|
||||
"""
|
||||
|
||||
import os
|
||||
import hashlib
|
||||
|
||||
from ocr_labeling_models import LOCAL_STORAGE_PATH
|
||||
|
||||
# Try to import Vision OCR service
|
||||
try:
|
||||
import sys
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..', 'backend', 'klausur', 'services'))
|
||||
from vision_ocr_service import get_vision_ocr_service, VisionOCRService
|
||||
VISION_OCR_AVAILABLE = True
|
||||
except ImportError:
|
||||
VISION_OCR_AVAILABLE = False
|
||||
print("Warning: Vision OCR service not available")
|
||||
|
||||
# Try to import PaddleOCR from hybrid_vocab_extractor
|
||||
try:
|
||||
from hybrid_vocab_extractor import run_paddle_ocr
|
||||
PADDLEOCR_AVAILABLE = True
|
||||
except ImportError:
|
||||
PADDLEOCR_AVAILABLE = False
|
||||
print("Warning: PaddleOCR not available")
|
||||
|
||||
# Try to import TrOCR service
|
||||
try:
|
||||
from services.trocr_service import run_trocr_ocr
|
||||
TROCR_AVAILABLE = True
|
||||
except ImportError:
|
||||
TROCR_AVAILABLE = False
|
||||
print("Warning: TrOCR service not available")
|
||||
|
||||
# Try to import Donut service
|
||||
try:
|
||||
from services.donut_ocr_service import run_donut_ocr
|
||||
DONUT_AVAILABLE = True
|
||||
except ImportError:
|
||||
DONUT_AVAILABLE = False
|
||||
print("Warning: Donut OCR service not available")
|
||||
|
||||
# Try to import MinIO storage
|
||||
try:
|
||||
from minio_storage import upload_ocr_image, get_ocr_image, MINIO_BUCKET
|
||||
MINIO_AVAILABLE = True
|
||||
except ImportError:
|
||||
MINIO_AVAILABLE = False
|
||||
print("Warning: MinIO storage not available, using local storage")
|
||||
|
||||
# Try to import Training Export Service
|
||||
try:
|
||||
from training_export_service import (
|
||||
TrainingExportService,
|
||||
TrainingSample,
|
||||
get_training_export_service,
|
||||
)
|
||||
TRAINING_EXPORT_AVAILABLE = True
|
||||
except ImportError:
|
||||
TRAINING_EXPORT_AVAILABLE = False
|
||||
print("Warning: Training export service not available")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Helper Functions
|
||||
# =============================================================================
|
||||
|
||||
def compute_image_hash(image_data: bytes) -> str:
|
||||
"""Compute SHA256 hash of image data."""
|
||||
return hashlib.sha256(image_data).hexdigest()
|
||||
|
||||
|
||||
async def run_ocr_on_image(image_data: bytes, filename: str, model: str = "llama3.2-vision:11b") -> tuple:
|
||||
"""
|
||||
Run OCR on an image using the specified model.
|
||||
|
||||
Models:
|
||||
- llama3.2-vision:11b: Vision LLM (default, best for handwriting)
|
||||
- trocr: Microsoft TrOCR (fast for printed text)
|
||||
- paddleocr: PaddleOCR + LLM hybrid (4x faster)
|
||||
- donut: Document Understanding Transformer (structured documents)
|
||||
|
||||
Returns:
|
||||
Tuple of (ocr_text, confidence)
|
||||
"""
|
||||
print(f"Running OCR with model: {model}")
|
||||
|
||||
# Route to appropriate OCR service based on model
|
||||
if model == "paddleocr":
|
||||
return await run_paddleocr_wrapper(image_data, filename)
|
||||
elif model == "donut":
|
||||
return await run_donut_wrapper(image_data, filename)
|
||||
elif model == "trocr":
|
||||
return await run_trocr_wrapper(image_data, filename)
|
||||
else:
|
||||
# Default: Vision LLM (llama3.2-vision or similar)
|
||||
return await run_vision_ocr_wrapper(image_data, filename)
|
||||
|
||||
|
||||
async def run_vision_ocr_wrapper(image_data: bytes, filename: str) -> tuple:
|
||||
"""Vision LLM OCR wrapper."""
|
||||
if not VISION_OCR_AVAILABLE:
|
||||
print("Vision OCR service not available")
|
||||
return None, 0.0
|
||||
|
||||
try:
|
||||
service = get_vision_ocr_service()
|
||||
if not await service.is_available():
|
||||
print("Vision OCR service not available (is_available check failed)")
|
||||
return None, 0.0
|
||||
|
||||
result = await service.extract_text(
|
||||
image_data,
|
||||
filename=filename,
|
||||
is_handwriting=True
|
||||
)
|
||||
return result.text, result.confidence
|
||||
except Exception as e:
|
||||
print(f"Vision OCR failed: {e}")
|
||||
return None, 0.0
|
||||
|
||||
|
||||
async def run_paddleocr_wrapper(image_data: bytes, filename: str) -> tuple:
|
||||
"""PaddleOCR wrapper - uses hybrid_vocab_extractor."""
|
||||
if not PADDLEOCR_AVAILABLE:
|
||||
print("PaddleOCR not available, falling back to Vision OCR")
|
||||
return await run_vision_ocr_wrapper(image_data, filename)
|
||||
|
||||
try:
|
||||
# run_paddle_ocr returns (regions, raw_text)
|
||||
regions, raw_text = run_paddle_ocr(image_data)
|
||||
|
||||
if not raw_text:
|
||||
print("PaddleOCR returned empty text")
|
||||
return None, 0.0
|
||||
|
||||
# Calculate average confidence from regions
|
||||
if regions:
|
||||
avg_confidence = sum(r.confidence for r in regions) / len(regions)
|
||||
else:
|
||||
avg_confidence = 0.5
|
||||
|
||||
return raw_text, avg_confidence
|
||||
except Exception as e:
|
||||
print(f"PaddleOCR failed: {e}, falling back to Vision OCR")
|
||||
return await run_vision_ocr_wrapper(image_data, filename)
|
||||
|
||||
|
||||
async def run_trocr_wrapper(image_data: bytes, filename: str) -> tuple:
|
||||
"""TrOCR wrapper."""
|
||||
if not TROCR_AVAILABLE:
|
||||
print("TrOCR not available, falling back to Vision OCR")
|
||||
return await run_vision_ocr_wrapper(image_data, filename)
|
||||
|
||||
try:
|
||||
text, confidence = await run_trocr_ocr(image_data)
|
||||
return text, confidence
|
||||
except Exception as e:
|
||||
print(f"TrOCR failed: {e}, falling back to Vision OCR")
|
||||
return await run_vision_ocr_wrapper(image_data, filename)
|
||||
|
||||
|
||||
async def run_donut_wrapper(image_data: bytes, filename: str) -> tuple:
|
||||
"""Donut OCR wrapper."""
|
||||
if not DONUT_AVAILABLE:
|
||||
print("Donut not available, falling back to Vision OCR")
|
||||
return await run_vision_ocr_wrapper(image_data, filename)
|
||||
|
||||
try:
|
||||
text, confidence = await run_donut_ocr(image_data)
|
||||
return text, confidence
|
||||
except Exception as e:
|
||||
print(f"Donut OCR failed: {e}, falling back to Vision OCR")
|
||||
return await run_vision_ocr_wrapper(image_data, filename)
|
||||
|
||||
|
||||
def save_image_locally(session_id: str, item_id: str, image_data: bytes, extension: str = "png") -> str:
|
||||
"""Save image to local storage."""
|
||||
session_dir = os.path.join(LOCAL_STORAGE_PATH, session_id)
|
||||
os.makedirs(session_dir, exist_ok=True)
|
||||
|
||||
filename = f"{item_id}.{extension}"
|
||||
filepath = os.path.join(session_dir, filename)
|
||||
|
||||
with open(filepath, 'wb') as f:
|
||||
f.write(image_data)
|
||||
|
||||
return filepath
|
||||
|
||||
|
||||
def get_image_url(image_path: str) -> str:
|
||||
"""Get URL for an image."""
|
||||
# For local images, return a relative path that the frontend can use
|
||||
if image_path.startswith(LOCAL_STORAGE_PATH):
|
||||
relative_path = image_path[len(LOCAL_STORAGE_PATH):].lstrip('/')
|
||||
return f"/api/v1/ocr-label/images/{relative_path}"
|
||||
# For MinIO images, the path is already a URL or key
|
||||
return image_path
|
||||
@@ -0,0 +1,86 @@
|
||||
"""
|
||||
OCR Labeling - Pydantic Models and Constants
|
||||
|
||||
Extracted from ocr_labeling_api.py to keep files under 500 LOC.
|
||||
"""
|
||||
|
||||
import os
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional, Dict
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
# Local storage path (fallback if MinIO not available)
|
||||
LOCAL_STORAGE_PATH = os.getenv("OCR_STORAGE_PATH", "/app/ocr-labeling")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Pydantic Models
|
||||
# =============================================================================
|
||||
|
||||
class SessionCreate(BaseModel):
|
||||
name: str
|
||||
source_type: str = "klausur" # klausur, handwriting_sample, scan
|
||||
description: Optional[str] = None
|
||||
ocr_model: Optional[str] = "llama3.2-vision:11b"
|
||||
|
||||
|
||||
class SessionResponse(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
source_type: str
|
||||
description: Optional[str]
|
||||
ocr_model: Optional[str]
|
||||
total_items: int
|
||||
labeled_items: int
|
||||
confirmed_items: int
|
||||
corrected_items: int
|
||||
skipped_items: int
|
||||
created_at: datetime
|
||||
|
||||
|
||||
class ItemResponse(BaseModel):
|
||||
id: str
|
||||
session_id: str
|
||||
session_name: str
|
||||
image_path: str
|
||||
image_url: Optional[str]
|
||||
ocr_text: Optional[str]
|
||||
ocr_confidence: Optional[float]
|
||||
ground_truth: Optional[str]
|
||||
status: str
|
||||
metadata: Optional[Dict]
|
||||
created_at: datetime
|
||||
|
||||
|
||||
class ConfirmRequest(BaseModel):
|
||||
item_id: str
|
||||
label_time_seconds: Optional[int] = None
|
||||
|
||||
|
||||
class CorrectRequest(BaseModel):
|
||||
item_id: str
|
||||
ground_truth: str
|
||||
label_time_seconds: Optional[int] = None
|
||||
|
||||
|
||||
class SkipRequest(BaseModel):
|
||||
item_id: str
|
||||
|
||||
|
||||
class ExportRequest(BaseModel):
|
||||
export_format: str = "generic" # generic, trocr, llama_vision
|
||||
session_id: Optional[str] = None
|
||||
batch_id: Optional[str] = None
|
||||
|
||||
|
||||
class StatsResponse(BaseModel):
|
||||
total_sessions: Optional[int] = None
|
||||
total_items: int
|
||||
labeled_items: int
|
||||
confirmed_items: int
|
||||
corrected_items: int
|
||||
pending_items: int
|
||||
exportable_items: Optional[int] = None
|
||||
accuracy_rate: float
|
||||
avg_label_time_seconds: Optional[float] = None
|
||||
@@ -0,0 +1,241 @@
|
||||
"""
|
||||
OCR Labeling - Session and Labeling Route Handlers
|
||||
|
||||
Extracted from ocr_labeling_api.py to keep files under 500 LOC.
|
||||
|
||||
Endpoints:
|
||||
- POST /sessions - Create labeling session
|
||||
- GET /sessions - List sessions
|
||||
- GET /sessions/{id} - Get session
|
||||
- GET /queue - Get labeling queue
|
||||
- GET /items/{id} - Get item
|
||||
- POST /confirm - Confirm OCR
|
||||
- POST /correct - Correct ground truth
|
||||
- POST /skip - Skip item
|
||||
- GET /stats - Get statistics
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Query
|
||||
from typing import Optional, List
|
||||
from datetime import datetime
|
||||
import uuid
|
||||
|
||||
from metrics_db import (
|
||||
create_ocr_labeling_session,
|
||||
get_ocr_labeling_sessions,
|
||||
get_ocr_labeling_session,
|
||||
get_ocr_labeling_queue,
|
||||
get_ocr_labeling_item,
|
||||
confirm_ocr_label,
|
||||
correct_ocr_label,
|
||||
skip_ocr_item,
|
||||
get_ocr_labeling_stats,
|
||||
)
|
||||
|
||||
from ocr_labeling_models import (
|
||||
SessionCreate, SessionResponse, ItemResponse,
|
||||
ConfirmRequest, CorrectRequest, SkipRequest,
|
||||
)
|
||||
from ocr_labeling_helpers import get_image_url
|
||||
|
||||
|
||||
router = APIRouter(prefix="/api/v1/ocr-label", tags=["OCR Labeling"])
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Session Endpoints
|
||||
# =============================================================================
|
||||
|
||||
@router.post("/sessions", response_model=SessionResponse)
|
||||
async def create_session(session: SessionCreate):
|
||||
"""Create a new OCR labeling session."""
|
||||
session_id = str(uuid.uuid4())
|
||||
|
||||
success = await create_ocr_labeling_session(
|
||||
session_id=session_id,
|
||||
name=session.name,
|
||||
source_type=session.source_type,
|
||||
description=session.description,
|
||||
ocr_model=session.ocr_model,
|
||||
)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(status_code=500, detail="Failed to create session")
|
||||
|
||||
return SessionResponse(
|
||||
id=session_id,
|
||||
name=session.name,
|
||||
source_type=session.source_type,
|
||||
description=session.description,
|
||||
ocr_model=session.ocr_model,
|
||||
total_items=0,
|
||||
labeled_items=0,
|
||||
confirmed_items=0,
|
||||
corrected_items=0,
|
||||
skipped_items=0,
|
||||
created_at=datetime.utcnow(),
|
||||
)
|
||||
|
||||
|
||||
@router.get("/sessions", response_model=List[SessionResponse])
|
||||
async def list_sessions(limit: int = Query(50, ge=1, le=100)):
|
||||
"""List all OCR labeling sessions."""
|
||||
sessions = await get_ocr_labeling_sessions(limit=limit)
|
||||
|
||||
return [
|
||||
SessionResponse(
|
||||
id=s['id'],
|
||||
name=s['name'],
|
||||
source_type=s['source_type'],
|
||||
description=s.get('description'),
|
||||
ocr_model=s.get('ocr_model'),
|
||||
total_items=s.get('total_items', 0),
|
||||
labeled_items=s.get('labeled_items', 0),
|
||||
confirmed_items=s.get('confirmed_items', 0),
|
||||
corrected_items=s.get('corrected_items', 0),
|
||||
skipped_items=s.get('skipped_items', 0),
|
||||
created_at=s.get('created_at', datetime.utcnow()),
|
||||
)
|
||||
for s in sessions
|
||||
]
|
||||
|
||||
|
||||
@router.get("/sessions/{session_id}", response_model=SessionResponse)
|
||||
async def get_session(session_id: str):
|
||||
"""Get a specific OCR labeling session."""
|
||||
session = await get_ocr_labeling_session(session_id)
|
||||
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail="Session not found")
|
||||
|
||||
return SessionResponse(
|
||||
id=session['id'],
|
||||
name=session['name'],
|
||||
source_type=session['source_type'],
|
||||
description=session.get('description'),
|
||||
ocr_model=session.get('ocr_model'),
|
||||
total_items=session.get('total_items', 0),
|
||||
labeled_items=session.get('labeled_items', 0),
|
||||
confirmed_items=session.get('confirmed_items', 0),
|
||||
corrected_items=session.get('corrected_items', 0),
|
||||
skipped_items=session.get('skipped_items', 0),
|
||||
created_at=session.get('created_at', datetime.utcnow()),
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Queue and Item Endpoints
|
||||
# =============================================================================
|
||||
|
||||
@router.get("/queue", response_model=List[ItemResponse])
|
||||
async def get_labeling_queue(
|
||||
session_id: Optional[str] = Query(None),
|
||||
status: str = Query("pending"),
|
||||
limit: int = Query(10, ge=1, le=50),
|
||||
):
|
||||
"""Get items from the labeling queue."""
|
||||
items = await get_ocr_labeling_queue(
|
||||
session_id=session_id,
|
||||
status=status,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
return [
|
||||
ItemResponse(
|
||||
id=item['id'],
|
||||
session_id=item['session_id'],
|
||||
session_name=item.get('session_name', ''),
|
||||
image_path=item['image_path'],
|
||||
image_url=get_image_url(item['image_path']),
|
||||
ocr_text=item.get('ocr_text'),
|
||||
ocr_confidence=item.get('ocr_confidence'),
|
||||
ground_truth=item.get('ground_truth'),
|
||||
status=item.get('status', 'pending'),
|
||||
metadata=item.get('metadata'),
|
||||
created_at=item.get('created_at', datetime.utcnow()),
|
||||
)
|
||||
for item in items
|
||||
]
|
||||
|
||||
|
||||
@router.get("/items/{item_id}", response_model=ItemResponse)
|
||||
async def get_item(item_id: str):
|
||||
"""Get a specific labeling item."""
|
||||
item = await get_ocr_labeling_item(item_id)
|
||||
|
||||
if not item:
|
||||
raise HTTPException(status_code=404, detail="Item not found")
|
||||
|
||||
return ItemResponse(
|
||||
id=item['id'],
|
||||
session_id=item['session_id'],
|
||||
session_name=item.get('session_name', ''),
|
||||
image_path=item['image_path'],
|
||||
image_url=get_image_url(item['image_path']),
|
||||
ocr_text=item.get('ocr_text'),
|
||||
ocr_confidence=item.get('ocr_confidence'),
|
||||
ground_truth=item.get('ground_truth'),
|
||||
status=item.get('status', 'pending'),
|
||||
metadata=item.get('metadata'),
|
||||
created_at=item.get('created_at', datetime.utcnow()),
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Labeling Action Endpoints
|
||||
# =============================================================================
|
||||
|
||||
@router.post("/confirm")
|
||||
async def confirm_item(request: ConfirmRequest):
|
||||
"""Confirm that OCR text is correct."""
|
||||
success = await confirm_ocr_label(
|
||||
item_id=request.item_id,
|
||||
labeled_by="admin",
|
||||
label_time_seconds=request.label_time_seconds,
|
||||
)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(status_code=400, detail="Failed to confirm item")
|
||||
|
||||
return {"status": "confirmed", "item_id": request.item_id}
|
||||
|
||||
|
||||
@router.post("/correct")
|
||||
async def correct_item(request: CorrectRequest):
|
||||
"""Save corrected ground truth for an item."""
|
||||
success = await correct_ocr_label(
|
||||
item_id=request.item_id,
|
||||
ground_truth=request.ground_truth,
|
||||
labeled_by="admin",
|
||||
label_time_seconds=request.label_time_seconds,
|
||||
)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(status_code=400, detail="Failed to correct item")
|
||||
|
||||
return {"status": "corrected", "item_id": request.item_id}
|
||||
|
||||
|
||||
@router.post("/skip")
|
||||
async def skip_item(request: SkipRequest):
|
||||
"""Skip an item (unusable image, etc.)."""
|
||||
success = await skip_ocr_item(
|
||||
item_id=request.item_id,
|
||||
labeled_by="admin",
|
||||
)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(status_code=400, detail="Failed to skip item")
|
||||
|
||||
return {"status": "skipped", "item_id": request.item_id}
|
||||
|
||||
|
||||
@router.get("/stats")
|
||||
async def get_stats(session_id: Optional[str] = Query(None)):
|
||||
"""Get labeling statistics."""
|
||||
stats = await get_ocr_labeling_stats(session_id=session_id)
|
||||
|
||||
if "error" in stats:
|
||||
raise HTTPException(status_code=500, detail=stats["error"])
|
||||
|
||||
return stats
|
||||
@@ -0,0 +1,313 @@
|
||||
"""
|
||||
OCR Labeling - Upload, Run-OCR, and Export Route Handlers
|
||||
|
||||
Extracted from ocr_labeling_routes.py to keep files under 500 LOC.
|
||||
|
||||
Endpoints:
|
||||
- POST /sessions/{id}/upload - Upload images for labeling
|
||||
- POST /run-ocr/{item_id} - Run OCR on existing item
|
||||
- POST /export - Export training data
|
||||
- GET /training-samples - List training samples
|
||||
- GET /images/{path} - Serve images from local storage
|
||||
- GET /exports - List exports
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, HTTPException, UploadFile, File, Form, Query
|
||||
from typing import Optional, List
|
||||
import uuid
|
||||
import os
|
||||
|
||||
from metrics_db import (
|
||||
get_ocr_labeling_session,
|
||||
add_ocr_labeling_item,
|
||||
get_ocr_labeling_item,
|
||||
export_training_samples,
|
||||
get_training_samples,
|
||||
)
|
||||
|
||||
from ocr_labeling_models import (
|
||||
ExportRequest,
|
||||
LOCAL_STORAGE_PATH,
|
||||
)
|
||||
from ocr_labeling_helpers import (
|
||||
compute_image_hash, run_ocr_on_image,
|
||||
save_image_locally,
|
||||
MINIO_AVAILABLE, TRAINING_EXPORT_AVAILABLE,
|
||||
)
|
||||
|
||||
# Conditional imports
|
||||
try:
|
||||
from minio_storage import upload_ocr_image, get_ocr_image
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
try:
|
||||
from training_export_service import TrainingSample, get_training_export_service
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
router = APIRouter(prefix="/api/v1/ocr-label", tags=["OCR Labeling"])
|
||||
|
||||
|
||||
@router.post("/sessions/{session_id}/upload")
|
||||
async def upload_images(
|
||||
session_id: str,
|
||||
files: List[UploadFile] = File(...),
|
||||
run_ocr: bool = Form(True),
|
||||
metadata: Optional[str] = Form(None),
|
||||
):
|
||||
"""
|
||||
Upload images to a labeling session.
|
||||
|
||||
Args:
|
||||
session_id: Session to add images to
|
||||
files: Image files to upload (PNG, JPG, PDF)
|
||||
run_ocr: Whether to run OCR immediately (default: True)
|
||||
metadata: Optional JSON metadata (subject, year, etc.)
|
||||
"""
|
||||
import json
|
||||
|
||||
session = await get_ocr_labeling_session(session_id)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail="Session not found")
|
||||
|
||||
meta_dict = None
|
||||
if metadata:
|
||||
try:
|
||||
meta_dict = json.loads(metadata)
|
||||
except json.JSONDecodeError:
|
||||
meta_dict = {"raw": metadata}
|
||||
|
||||
results = []
|
||||
ocr_model = session.get('ocr_model', 'llama3.2-vision:11b')
|
||||
|
||||
for file in files:
|
||||
content = await file.read()
|
||||
image_hash = compute_image_hash(content)
|
||||
item_id = str(uuid.uuid4())
|
||||
|
||||
extension = file.filename.split('.')[-1].lower() if file.filename else 'png'
|
||||
if extension not in ['png', 'jpg', 'jpeg', 'pdf']:
|
||||
extension = 'png'
|
||||
|
||||
if MINIO_AVAILABLE:
|
||||
try:
|
||||
image_path = upload_ocr_image(session_id, item_id, content, extension)
|
||||
except Exception as e:
|
||||
print(f"MinIO upload failed, using local storage: {e}")
|
||||
image_path = save_image_locally(session_id, item_id, content, extension)
|
||||
else:
|
||||
image_path = save_image_locally(session_id, item_id, content, extension)
|
||||
|
||||
ocr_text = None
|
||||
ocr_confidence = None
|
||||
|
||||
if run_ocr and extension != 'pdf':
|
||||
ocr_text, ocr_confidence = await run_ocr_on_image(
|
||||
content,
|
||||
file.filename or f"{item_id}.{extension}",
|
||||
model=ocr_model
|
||||
)
|
||||
|
||||
success = await add_ocr_labeling_item(
|
||||
item_id=item_id,
|
||||
session_id=session_id,
|
||||
image_path=image_path,
|
||||
image_hash=image_hash,
|
||||
ocr_text=ocr_text,
|
||||
ocr_confidence=ocr_confidence,
|
||||
ocr_model=ocr_model if ocr_text else None,
|
||||
metadata=meta_dict,
|
||||
)
|
||||
|
||||
if success:
|
||||
results.append({
|
||||
"id": item_id,
|
||||
"filename": file.filename,
|
||||
"image_path": image_path,
|
||||
"image_hash": image_hash,
|
||||
"ocr_text": ocr_text,
|
||||
"ocr_confidence": ocr_confidence,
|
||||
"status": "pending",
|
||||
})
|
||||
|
||||
return {
|
||||
"session_id": session_id,
|
||||
"uploaded_count": len(results),
|
||||
"items": results,
|
||||
}
|
||||
|
||||
|
||||
@router.post("/export")
|
||||
async def export_data(request: ExportRequest):
|
||||
"""Export labeled data for training."""
|
||||
db_samples = await export_training_samples(
|
||||
export_format=request.export_format,
|
||||
session_id=request.session_id,
|
||||
batch_id=request.batch_id,
|
||||
exported_by="admin",
|
||||
)
|
||||
|
||||
if not db_samples:
|
||||
return {
|
||||
"export_format": request.export_format,
|
||||
"batch_id": request.batch_id,
|
||||
"exported_count": 0,
|
||||
"samples": [],
|
||||
"message": "No labeled samples found to export",
|
||||
}
|
||||
|
||||
export_result = None
|
||||
if TRAINING_EXPORT_AVAILABLE:
|
||||
try:
|
||||
export_service = get_training_export_service()
|
||||
|
||||
training_samples = []
|
||||
for s in db_samples:
|
||||
training_samples.append(TrainingSample(
|
||||
id=s.get('id', s.get('item_id', '')),
|
||||
image_path=s.get('image_path', ''),
|
||||
ground_truth=s.get('ground_truth', ''),
|
||||
ocr_text=s.get('ocr_text'),
|
||||
ocr_confidence=s.get('ocr_confidence'),
|
||||
metadata=s.get('metadata'),
|
||||
))
|
||||
|
||||
export_result = export_service.export(
|
||||
samples=training_samples,
|
||||
export_format=request.export_format,
|
||||
batch_id=request.batch_id,
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Training export failed: {e}")
|
||||
|
||||
response = {
|
||||
"export_format": request.export_format,
|
||||
"batch_id": request.batch_id or (export_result.batch_id if export_result else None),
|
||||
"exported_count": len(db_samples),
|
||||
"samples": db_samples,
|
||||
}
|
||||
|
||||
if export_result:
|
||||
response["export_path"] = export_result.export_path
|
||||
response["manifest_path"] = export_result.manifest_path
|
||||
|
||||
return response
|
||||
|
||||
|
||||
@router.get("/training-samples")
|
||||
async def list_training_samples(
|
||||
export_format: Optional[str] = Query(None),
|
||||
batch_id: Optional[str] = Query(None),
|
||||
limit: int = Query(100, ge=1, le=1000),
|
||||
):
|
||||
"""Get exported training samples."""
|
||||
samples = await get_training_samples(
|
||||
export_format=export_format,
|
||||
batch_id=batch_id,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
return {
|
||||
"count": len(samples),
|
||||
"samples": samples,
|
||||
}
|
||||
|
||||
|
||||
@router.get("/images/{path:path}")
|
||||
async def get_image(path: str):
|
||||
"""Serve an image from local storage."""
|
||||
from fastapi.responses import FileResponse
|
||||
|
||||
filepath = os.path.join(LOCAL_STORAGE_PATH, path)
|
||||
|
||||
if not os.path.exists(filepath):
|
||||
raise HTTPException(status_code=404, detail="Image not found")
|
||||
|
||||
extension = filepath.split('.')[-1].lower()
|
||||
content_type = {
|
||||
'png': 'image/png',
|
||||
'jpg': 'image/jpeg',
|
||||
'jpeg': 'image/jpeg',
|
||||
'pdf': 'application/pdf',
|
||||
}.get(extension, 'application/octet-stream')
|
||||
|
||||
return FileResponse(filepath, media_type=content_type)
|
||||
|
||||
|
||||
@router.post("/run-ocr/{item_id}")
|
||||
async def run_ocr_for_item(item_id: str):
|
||||
"""Run OCR on an existing item."""
|
||||
item = await get_ocr_labeling_item(item_id)
|
||||
|
||||
if not item:
|
||||
raise HTTPException(status_code=404, detail="Item not found")
|
||||
|
||||
image_path = item['image_path']
|
||||
|
||||
if image_path.startswith(LOCAL_STORAGE_PATH):
|
||||
if not os.path.exists(image_path):
|
||||
raise HTTPException(status_code=404, detail="Image file not found")
|
||||
with open(image_path, 'rb') as f:
|
||||
image_data = f.read()
|
||||
elif MINIO_AVAILABLE:
|
||||
try:
|
||||
image_data = get_ocr_image(image_path)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to load image: {e}")
|
||||
else:
|
||||
raise HTTPException(status_code=500, detail="Cannot load image")
|
||||
|
||||
session = await get_ocr_labeling_session(item['session_id'])
|
||||
ocr_model = session.get('ocr_model', 'llama3.2-vision:11b') if session else 'llama3.2-vision:11b'
|
||||
|
||||
ocr_text, ocr_confidence = await run_ocr_on_image(
|
||||
image_data,
|
||||
os.path.basename(image_path),
|
||||
model=ocr_model
|
||||
)
|
||||
|
||||
if ocr_text is None:
|
||||
raise HTTPException(status_code=500, detail="OCR failed")
|
||||
|
||||
from metrics_db import get_pool
|
||||
pool = await get_pool()
|
||||
if pool:
|
||||
async with pool.acquire() as conn:
|
||||
await conn.execute(
|
||||
"""
|
||||
UPDATE ocr_labeling_items
|
||||
SET ocr_text = $2, ocr_confidence = $3, ocr_model = $4
|
||||
WHERE id = $1
|
||||
""",
|
||||
item_id, ocr_text, ocr_confidence, ocr_model
|
||||
)
|
||||
|
||||
return {
|
||||
"item_id": item_id,
|
||||
"ocr_text": ocr_text,
|
||||
"ocr_confidence": ocr_confidence,
|
||||
"ocr_model": ocr_model,
|
||||
}
|
||||
|
||||
|
||||
@router.get("/exports")
|
||||
async def list_exports(export_format: Optional[str] = Query(None)):
|
||||
"""List all available training data exports."""
|
||||
if not TRAINING_EXPORT_AVAILABLE:
|
||||
return {
|
||||
"exports": [],
|
||||
"message": "Training export service not available",
|
||||
}
|
||||
|
||||
try:
|
||||
export_service = get_training_export_service()
|
||||
exports = export_service.list_exports(export_format=export_format)
|
||||
|
||||
return {
|
||||
"count": len(exports),
|
||||
"exports": exports,
|
||||
}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to list exports: {e}")
|
||||
@@ -1,705 +1,23 @@
|
||||
"""
|
||||
OCR Pipeline Auto-Mode Orchestrator and Reprocess Endpoints.
|
||||
OCR Pipeline Auto-Mode Orchestrator and Reprocess Endpoints — Barrel Re-export.
|
||||
|
||||
Extracted from ocr_pipeline_api.py — contains:
|
||||
- POST /sessions/{session_id}/reprocess (clear downstream + restart from step)
|
||||
- POST /sessions/{session_id}/run-auto (full auto-mode with SSE streaming)
|
||||
Split into submodules:
|
||||
- ocr_pipeline_reprocess.py — POST /sessions/{id}/reprocess
|
||||
- ocr_pipeline_auto_steps.py — POST /sessions/{id}/run-auto + VLM helper
|
||||
|
||||
Lizenz: Apache 2.0
|
||||
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
from dataclasses import asdict
|
||||
from typing import Any, Dict, List, Optional
|
||||
from fastapi import APIRouter
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from fastapi import APIRouter, HTTPException, Request
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
from cv_vocab_pipeline import (
|
||||
OLLAMA_REVIEW_MODEL,
|
||||
PageRegion,
|
||||
RowGeometry,
|
||||
_cells_to_vocab_entries,
|
||||
_detect_header_footer_gaps,
|
||||
_detect_sub_columns,
|
||||
_fix_character_confusion,
|
||||
_fix_phonetic_brackets,
|
||||
fix_cell_phonetics,
|
||||
analyze_layout,
|
||||
build_cell_grid,
|
||||
classify_column_types,
|
||||
create_layout_image,
|
||||
create_ocr_image,
|
||||
deskew_image,
|
||||
deskew_image_by_word_alignment,
|
||||
detect_column_geometry,
|
||||
detect_row_geometry,
|
||||
_apply_shear,
|
||||
dewarp_image,
|
||||
llm_review_entries,
|
||||
)
|
||||
from ocr_pipeline_common import (
|
||||
_cache,
|
||||
_load_session_to_cache,
|
||||
_get_cached,
|
||||
_get_base_image_png,
|
||||
_append_pipeline_log,
|
||||
)
|
||||
from ocr_pipeline_session_store import (
|
||||
get_session_db,
|
||||
update_session_db,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
from ocr_pipeline_reprocess import router as _reprocess_router
|
||||
from ocr_pipeline_auto_steps import router as _steps_router
|
||||
|
||||
# Combine both sub-routers into a single router for backwards compatibility.
|
||||
# The consumer imports `from ocr_pipeline_auto import router as _auto_router`.
|
||||
router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"])
|
||||
router.include_router(_reprocess_router)
|
||||
router.include_router(_steps_router)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Reprocess endpoint
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.post("/sessions/{session_id}/reprocess")
|
||||
async def reprocess_session(session_id: str, request: Request):
|
||||
"""Re-run pipeline from a specific step, clearing downstream data.
|
||||
|
||||
Body: {"from_step": 5} (1-indexed step number)
|
||||
|
||||
Pipeline order: Orientation(1) → Deskew(2) → Dewarp(3) → Crop(4) → Columns(5) →
|
||||
Rows(6) → Words(7) → LLM-Review(8) → Reconstruction(9) → Validation(10)
|
||||
|
||||
Clears downstream results:
|
||||
- from_step <= 1: orientation_result + all downstream
|
||||
- from_step <= 2: deskew_result + all downstream
|
||||
- from_step <= 3: dewarp_result + all downstream
|
||||
- from_step <= 4: crop_result + all downstream
|
||||
- from_step <= 5: column_result, row_result, word_result
|
||||
- from_step <= 6: row_result, word_result
|
||||
- from_step <= 7: word_result (cells, vocab_entries)
|
||||
- from_step <= 8: word_result.llm_review only
|
||||
"""
|
||||
session = await get_session_db(session_id)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
|
||||
|
||||
body = await request.json()
|
||||
from_step = body.get("from_step", 1)
|
||||
if not isinstance(from_step, int) or from_step < 1 or from_step > 10:
|
||||
raise HTTPException(status_code=400, detail="from_step must be between 1 and 10")
|
||||
|
||||
update_kwargs: Dict[str, Any] = {"current_step": from_step}
|
||||
|
||||
# Clear downstream data based on from_step
|
||||
# New pipeline order: Orient(2) → Deskew(3) → Dewarp(4) → Crop(5) →
|
||||
# Columns(6) → Rows(7) → Words(8) → LLM(9) → Recon(10) → GT(11)
|
||||
if from_step <= 8:
|
||||
update_kwargs["word_result"] = None
|
||||
elif from_step == 9:
|
||||
# Only clear LLM review from word_result
|
||||
word_result = session.get("word_result")
|
||||
if word_result:
|
||||
word_result.pop("llm_review", None)
|
||||
word_result.pop("llm_corrections", None)
|
||||
update_kwargs["word_result"] = word_result
|
||||
|
||||
if from_step <= 7:
|
||||
update_kwargs["row_result"] = None
|
||||
if from_step <= 6:
|
||||
update_kwargs["column_result"] = None
|
||||
if from_step <= 4:
|
||||
update_kwargs["crop_result"] = None
|
||||
if from_step <= 3:
|
||||
update_kwargs["dewarp_result"] = None
|
||||
if from_step <= 2:
|
||||
update_kwargs["deskew_result"] = None
|
||||
if from_step <= 1:
|
||||
update_kwargs["orientation_result"] = None
|
||||
|
||||
await update_session_db(session_id, **update_kwargs)
|
||||
|
||||
# Also clear cache
|
||||
if session_id in _cache:
|
||||
for key in list(update_kwargs.keys()):
|
||||
if key != "current_step":
|
||||
_cache[session_id][key] = update_kwargs[key]
|
||||
_cache[session_id]["current_step"] = from_step
|
||||
|
||||
logger.info(f"Session {session_id} reprocessing from step {from_step}")
|
||||
|
||||
return {
|
||||
"session_id": session_id,
|
||||
"from_step": from_step,
|
||||
"cleared": [k for k in update_kwargs if k != "current_step"],
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# VLM shear detection helper (used by dewarp step in auto-mode)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def _detect_shear_with_vlm(image_bytes: bytes) -> Dict[str, Any]:
|
||||
"""Ask qwen2.5vl:32b to estimate the vertical shear angle of a scanned page.
|
||||
|
||||
The VLM is shown the image and asked: are the column/table borders tilted?
|
||||
If yes, by how many degrees? Returns a dict with shear_degrees and confidence.
|
||||
Confidence is 0.0 if Ollama is unavailable or parsing fails.
|
||||
"""
|
||||
import httpx
|
||||
import base64
|
||||
import re
|
||||
|
||||
ollama_base = os.getenv("OLLAMA_BASE_URL", "http://host.docker.internal:11434")
|
||||
model = os.getenv("OLLAMA_HTR_MODEL", "qwen2.5vl:32b")
|
||||
|
||||
prompt = (
|
||||
"This is a scanned vocabulary worksheet. Look at the vertical borders of the table columns. "
|
||||
"Are they perfectly vertical, or do they tilt slightly? "
|
||||
"If they tilt, estimate the tilt angle in degrees (positive = top tilts right, negative = top tilts left). "
|
||||
"Reply with ONLY a JSON object like: {\"shear_degrees\": 1.2, \"confidence\": 0.8} "
|
||||
"Use confidence 0.0-1.0 based on how clearly you can see the tilt. "
|
||||
"If the columns look straight, return {\"shear_degrees\": 0.0, \"confidence\": 0.9}"
|
||||
)
|
||||
|
||||
img_b64 = base64.b64encode(image_bytes).decode("utf-8")
|
||||
payload = {
|
||||
"model": model,
|
||||
"prompt": prompt,
|
||||
"images": [img_b64],
|
||||
"stream": False,
|
||||
}
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=60.0) as client:
|
||||
resp = await client.post(f"{ollama_base}/api/generate", json=payload)
|
||||
resp.raise_for_status()
|
||||
text = resp.json().get("response", "")
|
||||
|
||||
# Parse JSON from response (may have surrounding text)
|
||||
match = re.search(r'\{[^}]+\}', text)
|
||||
if match:
|
||||
import json
|
||||
data = json.loads(match.group(0))
|
||||
shear = float(data.get("shear_degrees", 0.0))
|
||||
conf = float(data.get("confidence", 0.0))
|
||||
# Clamp to reasonable range
|
||||
shear = max(-3.0, min(3.0, shear))
|
||||
conf = max(0.0, min(1.0, conf))
|
||||
return {"method": "vlm_qwen2.5vl", "shear_degrees": round(shear, 3), "confidence": round(conf, 2)}
|
||||
except Exception as e:
|
||||
logger.warning(f"VLM dewarp failed: {e}")
|
||||
|
||||
return {"method": "vlm_qwen2.5vl", "shear_degrees": 0.0, "confidence": 0.0}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Auto-mode orchestrator
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class RunAutoRequest(BaseModel):
|
||||
from_step: int = 1 # 1=deskew, 2=dewarp, 3=columns, 4=rows, 5=words, 6=llm-review
|
||||
ocr_engine: str = "auto" # "auto" | "rapid" | "tesseract"
|
||||
pronunciation: str = "british"
|
||||
skip_llm_review: bool = False
|
||||
dewarp_method: str = "ensemble" # "ensemble" | "vlm" | "cv"
|
||||
|
||||
|
||||
async def _auto_sse_event(step: str, status: str, data: Dict[str, Any]) -> str:
|
||||
"""Format a single SSE event line."""
|
||||
import json as _json
|
||||
payload = {"step": step, "status": status, **data}
|
||||
return f"data: {_json.dumps(payload)}\n\n"
|
||||
|
||||
|
||||
@router.post("/sessions/{session_id}/run-auto")
|
||||
async def run_auto(session_id: str, req: RunAutoRequest, request: Request):
|
||||
"""Run the full OCR pipeline automatically from a given step, streaming SSE progress.
|
||||
|
||||
Steps:
|
||||
1. Deskew — straighten the scan
|
||||
2. Dewarp — correct vertical shear (ensemble CV or VLM)
|
||||
3. Columns — detect column layout
|
||||
4. Rows — detect row layout
|
||||
5. Words — OCR each cell
|
||||
6. LLM review — correct OCR errors (optional)
|
||||
|
||||
Already-completed steps are skipped unless `from_step` forces a rerun.
|
||||
Yields SSE events of the form:
|
||||
data: {"step": "deskew", "status": "start"|"done"|"skipped"|"error", ...}
|
||||
|
||||
Final event:
|
||||
data: {"step": "complete", "status": "done", "steps_run": [...], "steps_skipped": [...]}
|
||||
"""
|
||||
if req.from_step < 1 or req.from_step > 6:
|
||||
raise HTTPException(status_code=400, detail="from_step must be 1-6")
|
||||
if req.dewarp_method not in ("ensemble", "vlm", "cv"):
|
||||
raise HTTPException(status_code=400, detail="dewarp_method must be: ensemble, vlm, cv")
|
||||
|
||||
if session_id not in _cache:
|
||||
await _load_session_to_cache(session_id)
|
||||
|
||||
async def _generate():
|
||||
steps_run: List[str] = []
|
||||
steps_skipped: List[str] = []
|
||||
error_step: Optional[str] = None
|
||||
|
||||
session = await get_session_db(session_id)
|
||||
if not session:
|
||||
yield await _auto_sse_event("error", "error", {"message": f"Session {session_id} not found"})
|
||||
return
|
||||
|
||||
cached = _get_cached(session_id)
|
||||
|
||||
# -----------------------------------------------------------------
|
||||
# Step 1: Deskew
|
||||
# -----------------------------------------------------------------
|
||||
if req.from_step <= 1:
|
||||
yield await _auto_sse_event("deskew", "start", {})
|
||||
try:
|
||||
t0 = time.time()
|
||||
orig_bgr = cached.get("original_bgr")
|
||||
if orig_bgr is None:
|
||||
raise ValueError("Original image not loaded")
|
||||
|
||||
# Method 1: Hough lines
|
||||
try:
|
||||
deskewed_hough, angle_hough = deskew_image(orig_bgr.copy())
|
||||
except Exception:
|
||||
deskewed_hough, angle_hough = orig_bgr, 0.0
|
||||
|
||||
# Method 2: Word alignment
|
||||
success_enc, png_orig = cv2.imencode(".png", orig_bgr)
|
||||
orig_bytes = png_orig.tobytes() if success_enc else b""
|
||||
try:
|
||||
deskewed_wa_bytes, angle_wa = deskew_image_by_word_alignment(orig_bytes)
|
||||
except Exception:
|
||||
deskewed_wa_bytes, angle_wa = orig_bytes, 0.0
|
||||
|
||||
# Pick best method
|
||||
if abs(angle_wa) >= abs(angle_hough) or abs(angle_hough) < 0.1:
|
||||
method_used = "word_alignment"
|
||||
angle_applied = angle_wa
|
||||
wa_arr = np.frombuffer(deskewed_wa_bytes, dtype=np.uint8)
|
||||
deskewed_bgr = cv2.imdecode(wa_arr, cv2.IMREAD_COLOR)
|
||||
if deskewed_bgr is None:
|
||||
deskewed_bgr = deskewed_hough
|
||||
method_used = "hough"
|
||||
angle_applied = angle_hough
|
||||
else:
|
||||
method_used = "hough"
|
||||
angle_applied = angle_hough
|
||||
deskewed_bgr = deskewed_hough
|
||||
|
||||
success, png_buf = cv2.imencode(".png", deskewed_bgr)
|
||||
deskewed_png = png_buf.tobytes() if success else b""
|
||||
|
||||
deskew_result = {
|
||||
"method_used": method_used,
|
||||
"rotation_degrees": round(float(angle_applied), 3),
|
||||
"duration_seconds": round(time.time() - t0, 2),
|
||||
}
|
||||
|
||||
cached["deskewed_bgr"] = deskewed_bgr
|
||||
cached["deskew_result"] = deskew_result
|
||||
await update_session_db(
|
||||
session_id,
|
||||
deskewed_png=deskewed_png,
|
||||
deskew_result=deskew_result,
|
||||
auto_rotation_degrees=float(angle_applied),
|
||||
current_step=3,
|
||||
)
|
||||
session = await get_session_db(session_id)
|
||||
|
||||
steps_run.append("deskew")
|
||||
yield await _auto_sse_event("deskew", "done", deskew_result)
|
||||
except Exception as e:
|
||||
logger.error(f"Auto-mode deskew failed for {session_id}: {e}")
|
||||
error_step = "deskew"
|
||||
yield await _auto_sse_event("deskew", "error", {"message": str(e)})
|
||||
yield await _auto_sse_event("complete", "error", {"error_step": error_step})
|
||||
return
|
||||
else:
|
||||
steps_skipped.append("deskew")
|
||||
yield await _auto_sse_event("deskew", "skipped", {"reason": "from_step > 1"})
|
||||
|
||||
# -----------------------------------------------------------------
|
||||
# Step 2: Dewarp
|
||||
# -----------------------------------------------------------------
|
||||
if req.from_step <= 2:
|
||||
yield await _auto_sse_event("dewarp", "start", {"method": req.dewarp_method})
|
||||
try:
|
||||
t0 = time.time()
|
||||
deskewed_bgr = cached.get("deskewed_bgr")
|
||||
if deskewed_bgr is None:
|
||||
raise ValueError("Deskewed image not available")
|
||||
|
||||
if req.dewarp_method == "vlm":
|
||||
success_enc, png_buf = cv2.imencode(".png", deskewed_bgr)
|
||||
img_bytes = png_buf.tobytes() if success_enc else b""
|
||||
vlm_det = await _detect_shear_with_vlm(img_bytes)
|
||||
shear_deg = vlm_det["shear_degrees"]
|
||||
if abs(shear_deg) >= 0.05 and vlm_det["confidence"] >= 0.3:
|
||||
dewarped_bgr = _apply_shear(deskewed_bgr, -shear_deg)
|
||||
else:
|
||||
dewarped_bgr = deskewed_bgr
|
||||
dewarp_info = {
|
||||
"method": vlm_det["method"],
|
||||
"shear_degrees": shear_deg,
|
||||
"confidence": vlm_det["confidence"],
|
||||
"detections": [vlm_det],
|
||||
}
|
||||
else:
|
||||
dewarped_bgr, dewarp_info = dewarp_image(deskewed_bgr)
|
||||
|
||||
success_enc, png_buf = cv2.imencode(".png", dewarped_bgr)
|
||||
dewarped_png = png_buf.tobytes() if success_enc else b""
|
||||
|
||||
dewarp_result = {
|
||||
"method_used": dewarp_info["method"],
|
||||
"shear_degrees": dewarp_info["shear_degrees"],
|
||||
"confidence": dewarp_info["confidence"],
|
||||
"duration_seconds": round(time.time() - t0, 2),
|
||||
"detections": dewarp_info.get("detections", []),
|
||||
}
|
||||
|
||||
cached["dewarped_bgr"] = dewarped_bgr
|
||||
cached["dewarp_result"] = dewarp_result
|
||||
await update_session_db(
|
||||
session_id,
|
||||
dewarped_png=dewarped_png,
|
||||
dewarp_result=dewarp_result,
|
||||
auto_shear_degrees=dewarp_info.get("shear_degrees", 0.0),
|
||||
current_step=4,
|
||||
)
|
||||
session = await get_session_db(session_id)
|
||||
|
||||
steps_run.append("dewarp")
|
||||
yield await _auto_sse_event("dewarp", "done", dewarp_result)
|
||||
except Exception as e:
|
||||
logger.error(f"Auto-mode dewarp failed for {session_id}: {e}")
|
||||
error_step = "dewarp"
|
||||
yield await _auto_sse_event("dewarp", "error", {"message": str(e)})
|
||||
yield await _auto_sse_event("complete", "error", {"error_step": error_step})
|
||||
return
|
||||
else:
|
||||
steps_skipped.append("dewarp")
|
||||
yield await _auto_sse_event("dewarp", "skipped", {"reason": "from_step > 2"})
|
||||
|
||||
# -----------------------------------------------------------------
|
||||
# Step 3: Columns
|
||||
# -----------------------------------------------------------------
|
||||
if req.from_step <= 3:
|
||||
yield await _auto_sse_event("columns", "start", {})
|
||||
try:
|
||||
t0 = time.time()
|
||||
col_img = cached.get("cropped_bgr") if cached.get("cropped_bgr") is not None else cached.get("dewarped_bgr")
|
||||
if col_img is None:
|
||||
raise ValueError("Cropped/dewarped image not available")
|
||||
|
||||
ocr_img = create_ocr_image(col_img)
|
||||
h, w = ocr_img.shape[:2]
|
||||
|
||||
geo_result = detect_column_geometry(ocr_img, col_img)
|
||||
if geo_result is None:
|
||||
layout_img = create_layout_image(col_img)
|
||||
regions = analyze_layout(layout_img, ocr_img)
|
||||
cached["_word_dicts"] = None
|
||||
cached["_inv"] = None
|
||||
cached["_content_bounds"] = None
|
||||
else:
|
||||
geometries, left_x, right_x, top_y, bottom_y, word_dicts, inv = geo_result
|
||||
content_w = right_x - left_x
|
||||
cached["_word_dicts"] = word_dicts
|
||||
cached["_inv"] = inv
|
||||
cached["_content_bounds"] = (left_x, right_x, top_y, bottom_y)
|
||||
|
||||
header_y, footer_y = _detect_header_footer_gaps(inv, w, h) if inv is not None else (None, None)
|
||||
geometries = _detect_sub_columns(geometries, content_w, left_x=left_x,
|
||||
top_y=top_y, header_y=header_y, footer_y=footer_y)
|
||||
regions = classify_column_types(geometries, content_w, top_y, w, h, bottom_y,
|
||||
left_x=left_x, right_x=right_x, inv=inv)
|
||||
|
||||
columns = [asdict(r) for r in regions]
|
||||
column_result = {
|
||||
"columns": columns,
|
||||
"classification_methods": list({c.get("classification_method", "") for c in columns if c.get("classification_method")}),
|
||||
"duration_seconds": round(time.time() - t0, 2),
|
||||
}
|
||||
|
||||
cached["column_result"] = column_result
|
||||
await update_session_db(session_id, column_result=column_result,
|
||||
row_result=None, word_result=None, current_step=6)
|
||||
session = await get_session_db(session_id)
|
||||
|
||||
steps_run.append("columns")
|
||||
yield await _auto_sse_event("columns", "done", {
|
||||
"column_count": len(columns),
|
||||
"duration_seconds": column_result["duration_seconds"],
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(f"Auto-mode columns failed for {session_id}: {e}")
|
||||
error_step = "columns"
|
||||
yield await _auto_sse_event("columns", "error", {"message": str(e)})
|
||||
yield await _auto_sse_event("complete", "error", {"error_step": error_step})
|
||||
return
|
||||
else:
|
||||
steps_skipped.append("columns")
|
||||
yield await _auto_sse_event("columns", "skipped", {"reason": "from_step > 3"})
|
||||
|
||||
# -----------------------------------------------------------------
|
||||
# Step 4: Rows
|
||||
# -----------------------------------------------------------------
|
||||
if req.from_step <= 4:
|
||||
yield await _auto_sse_event("rows", "start", {})
|
||||
try:
|
||||
t0 = time.time()
|
||||
row_img = cached.get("cropped_bgr") if cached.get("cropped_bgr") is not None else cached.get("dewarped_bgr")
|
||||
session = await get_session_db(session_id)
|
||||
column_result = session.get("column_result") or cached.get("column_result")
|
||||
if not column_result or not column_result.get("columns"):
|
||||
raise ValueError("Column detection must complete first")
|
||||
|
||||
col_regions = [
|
||||
PageRegion(
|
||||
type=c["type"], x=c["x"], y=c["y"],
|
||||
width=c["width"], height=c["height"],
|
||||
classification_confidence=c.get("classification_confidence", 1.0),
|
||||
classification_method=c.get("classification_method", ""),
|
||||
)
|
||||
for c in column_result["columns"]
|
||||
]
|
||||
|
||||
word_dicts = cached.get("_word_dicts")
|
||||
inv = cached.get("_inv")
|
||||
content_bounds = cached.get("_content_bounds")
|
||||
|
||||
if word_dicts is None or inv is None or content_bounds is None:
|
||||
ocr_img_tmp = create_ocr_image(row_img)
|
||||
geo_result = detect_column_geometry(ocr_img_tmp, row_img)
|
||||
if geo_result is None:
|
||||
raise ValueError("Column geometry detection failed — cannot detect rows")
|
||||
_g, lx, rx, ty, by, word_dicts, inv = geo_result
|
||||
cached["_word_dicts"] = word_dicts
|
||||
cached["_inv"] = inv
|
||||
cached["_content_bounds"] = (lx, rx, ty, by)
|
||||
content_bounds = (lx, rx, ty, by)
|
||||
|
||||
left_x, right_x, top_y, bottom_y = content_bounds
|
||||
row_geoms = detect_row_geometry(inv, word_dicts, left_x, right_x, top_y, bottom_y)
|
||||
|
||||
row_list = [
|
||||
{
|
||||
"index": r.index, "x": r.x, "y": r.y,
|
||||
"width": r.width, "height": r.height,
|
||||
"word_count": r.word_count,
|
||||
"row_type": r.row_type,
|
||||
"gap_before": r.gap_before,
|
||||
}
|
||||
for r in row_geoms
|
||||
]
|
||||
row_result = {
|
||||
"rows": row_list,
|
||||
"row_count": len(row_list),
|
||||
"content_rows": len([r for r in row_geoms if r.row_type == "content"]),
|
||||
"duration_seconds": round(time.time() - t0, 2),
|
||||
}
|
||||
|
||||
cached["row_result"] = row_result
|
||||
await update_session_db(session_id, row_result=row_result, current_step=7)
|
||||
session = await get_session_db(session_id)
|
||||
|
||||
steps_run.append("rows")
|
||||
yield await _auto_sse_event("rows", "done", {
|
||||
"row_count": len(row_list),
|
||||
"content_rows": row_result["content_rows"],
|
||||
"duration_seconds": row_result["duration_seconds"],
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(f"Auto-mode rows failed for {session_id}: {e}")
|
||||
error_step = "rows"
|
||||
yield await _auto_sse_event("rows", "error", {"message": str(e)})
|
||||
yield await _auto_sse_event("complete", "error", {"error_step": error_step})
|
||||
return
|
||||
else:
|
||||
steps_skipped.append("rows")
|
||||
yield await _auto_sse_event("rows", "skipped", {"reason": "from_step > 4"})
|
||||
|
||||
# -----------------------------------------------------------------
|
||||
# Step 5: Words (OCR)
|
||||
# -----------------------------------------------------------------
|
||||
if req.from_step <= 5:
|
||||
yield await _auto_sse_event("words", "start", {"engine": req.ocr_engine})
|
||||
try:
|
||||
t0 = time.time()
|
||||
word_img = cached.get("cropped_bgr") if cached.get("cropped_bgr") is not None else cached.get("dewarped_bgr")
|
||||
session = await get_session_db(session_id)
|
||||
|
||||
column_result = session.get("column_result") or cached.get("column_result")
|
||||
row_result = session.get("row_result") or cached.get("row_result")
|
||||
|
||||
col_regions = [
|
||||
PageRegion(
|
||||
type=c["type"], x=c["x"], y=c["y"],
|
||||
width=c["width"], height=c["height"],
|
||||
classification_confidence=c.get("classification_confidence", 1.0),
|
||||
classification_method=c.get("classification_method", ""),
|
||||
)
|
||||
for c in column_result["columns"]
|
||||
]
|
||||
row_geoms = [
|
||||
RowGeometry(
|
||||
index=r["index"], x=r["x"], y=r["y"],
|
||||
width=r["width"], height=r["height"],
|
||||
word_count=r.get("word_count", 0), words=[],
|
||||
row_type=r.get("row_type", "content"),
|
||||
gap_before=r.get("gap_before", 0),
|
||||
)
|
||||
for r in row_result["rows"]
|
||||
]
|
||||
|
||||
word_dicts = cached.get("_word_dicts")
|
||||
if word_dicts is not None:
|
||||
content_bounds = cached.get("_content_bounds")
|
||||
top_y = content_bounds[2] if content_bounds else min(r.y for r in row_geoms)
|
||||
for row in row_geoms:
|
||||
row_y_rel = row.y - top_y
|
||||
row_bottom_rel = row_y_rel + row.height
|
||||
row.words = [
|
||||
w for w in word_dicts
|
||||
if row_y_rel <= w['top'] + w['height'] / 2 < row_bottom_rel
|
||||
]
|
||||
row.word_count = len(row.words)
|
||||
|
||||
ocr_img = create_ocr_image(word_img)
|
||||
img_h, img_w = word_img.shape[:2]
|
||||
|
||||
cells, columns_meta = build_cell_grid(
|
||||
ocr_img, col_regions, row_geoms, img_w, img_h,
|
||||
ocr_engine=req.ocr_engine, img_bgr=word_img,
|
||||
)
|
||||
duration = time.time() - t0
|
||||
|
||||
col_types = {c['type'] for c in columns_meta}
|
||||
is_vocab = bool(col_types & {'column_en', 'column_de'})
|
||||
n_content_rows = len([r for r in row_geoms if r.row_type == 'content'])
|
||||
used_engine = cells[0].get("ocr_engine", "tesseract") if cells else req.ocr_engine
|
||||
|
||||
# Apply IPA phonetic fixes directly to cell texts
|
||||
fix_cell_phonetics(cells, pronunciation=req.pronunciation)
|
||||
|
||||
word_result_data = {
|
||||
"cells": cells,
|
||||
"grid_shape": {
|
||||
"rows": n_content_rows,
|
||||
"cols": len(columns_meta),
|
||||
"total_cells": len(cells),
|
||||
},
|
||||
"columns_used": columns_meta,
|
||||
"layout": "vocab" if is_vocab else "generic",
|
||||
"image_width": img_w,
|
||||
"image_height": img_h,
|
||||
"duration_seconds": round(duration, 2),
|
||||
"ocr_engine": used_engine,
|
||||
"summary": {
|
||||
"total_cells": len(cells),
|
||||
"non_empty_cells": sum(1 for c in cells if c.get("text")),
|
||||
"low_confidence": sum(1 for c in cells if 0 < c.get("confidence", 0) < 50),
|
||||
},
|
||||
}
|
||||
|
||||
has_text_col = 'column_text' in col_types
|
||||
if is_vocab or has_text_col:
|
||||
entries = _cells_to_vocab_entries(cells, columns_meta)
|
||||
entries = _fix_character_confusion(entries)
|
||||
entries = _fix_phonetic_brackets(entries, pronunciation=req.pronunciation)
|
||||
word_result_data["vocab_entries"] = entries
|
||||
word_result_data["entries"] = entries
|
||||
word_result_data["entry_count"] = len(entries)
|
||||
word_result_data["summary"]["total_entries"] = len(entries)
|
||||
|
||||
await update_session_db(session_id, word_result=word_result_data, current_step=8)
|
||||
cached["word_result"] = word_result_data
|
||||
session = await get_session_db(session_id)
|
||||
|
||||
steps_run.append("words")
|
||||
yield await _auto_sse_event("words", "done", {
|
||||
"total_cells": len(cells),
|
||||
"layout": word_result_data["layout"],
|
||||
"duration_seconds": round(duration, 2),
|
||||
"ocr_engine": used_engine,
|
||||
"summary": word_result_data["summary"],
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(f"Auto-mode words failed for {session_id}: {e}")
|
||||
error_step = "words"
|
||||
yield await _auto_sse_event("words", "error", {"message": str(e)})
|
||||
yield await _auto_sse_event("complete", "error", {"error_step": error_step})
|
||||
return
|
||||
else:
|
||||
steps_skipped.append("words")
|
||||
yield await _auto_sse_event("words", "skipped", {"reason": "from_step > 5"})
|
||||
|
||||
# -----------------------------------------------------------------
|
||||
# Step 6: LLM Review (optional)
|
||||
# -----------------------------------------------------------------
|
||||
if req.from_step <= 6 and not req.skip_llm_review:
|
||||
yield await _auto_sse_event("llm_review", "start", {"model": OLLAMA_REVIEW_MODEL})
|
||||
try:
|
||||
session = await get_session_db(session_id)
|
||||
word_result = session.get("word_result") or cached.get("word_result")
|
||||
entries = word_result.get("entries") or word_result.get("vocab_entries") or []
|
||||
|
||||
if not entries:
|
||||
yield await _auto_sse_event("llm_review", "skipped", {"reason": "no entries"})
|
||||
steps_skipped.append("llm_review")
|
||||
else:
|
||||
reviewed = await llm_review_entries(entries)
|
||||
|
||||
session = await get_session_db(session_id)
|
||||
word_result_updated = dict(session.get("word_result") or {})
|
||||
word_result_updated["entries"] = reviewed
|
||||
word_result_updated["vocab_entries"] = reviewed
|
||||
word_result_updated["llm_reviewed"] = True
|
||||
word_result_updated["llm_model"] = OLLAMA_REVIEW_MODEL
|
||||
|
||||
await update_session_db(session_id, word_result=word_result_updated, current_step=9)
|
||||
cached["word_result"] = word_result_updated
|
||||
|
||||
steps_run.append("llm_review")
|
||||
yield await _auto_sse_event("llm_review", "done", {
|
||||
"entries_reviewed": len(reviewed),
|
||||
"model": OLLAMA_REVIEW_MODEL,
|
||||
})
|
||||
except Exception as e:
|
||||
logger.warning(f"Auto-mode llm_review failed for {session_id} (non-fatal): {e}")
|
||||
yield await _auto_sse_event("llm_review", "error", {"message": str(e), "fatal": False})
|
||||
steps_skipped.append("llm_review")
|
||||
else:
|
||||
steps_skipped.append("llm_review")
|
||||
reason = "skipped by request" if req.skip_llm_review else "from_step > 6"
|
||||
yield await _auto_sse_event("llm_review", "skipped", {"reason": reason})
|
||||
|
||||
# -----------------------------------------------------------------
|
||||
# Final event
|
||||
# -----------------------------------------------------------------
|
||||
yield await _auto_sse_event("complete", "done", {
|
||||
"steps_run": steps_run,
|
||||
"steps_skipped": steps_skipped,
|
||||
})
|
||||
|
||||
return StreamingResponse(
|
||||
_generate(),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no",
|
||||
},
|
||||
)
|
||||
__all__ = ["router"]
|
||||
|
||||
@@ -0,0 +1,84 @@
|
||||
"""
|
||||
OCR Pipeline Auto-Mode Helpers.
|
||||
|
||||
VLM shear detection, SSE event formatting, and request models.
|
||||
|
||||
Lizenz: Apache 2.0
|
||||
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from typing import Any, Dict
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RunAutoRequest(BaseModel):
|
||||
from_step: int = 1 # 1=deskew, 2=dewarp, 3=columns, 4=rows, 5=words, 6=llm-review
|
||||
ocr_engine: str = "auto" # "auto" | "rapid" | "tesseract"
|
||||
pronunciation: str = "british"
|
||||
skip_llm_review: bool = False
|
||||
dewarp_method: str = "ensemble" # "ensemble" | "vlm" | "cv"
|
||||
|
||||
|
||||
async def auto_sse_event(step: str, status: str, data: Dict[str, Any]) -> str:
|
||||
"""Format a single SSE event line."""
|
||||
payload = {"step": step, "status": status, **data}
|
||||
return f"data: {json.dumps(payload)}\n\n"
|
||||
|
||||
|
||||
async def detect_shear_with_vlm(image_bytes: bytes) -> Dict[str, Any]:
|
||||
"""Ask qwen2.5vl:32b to estimate the vertical shear angle of a scanned page.
|
||||
|
||||
The VLM is shown the image and asked: are the column/table borders tilted?
|
||||
If yes, by how many degrees? Returns a dict with shear_degrees and confidence.
|
||||
Confidence is 0.0 if Ollama is unavailable or parsing fails.
|
||||
"""
|
||||
import httpx
|
||||
import base64
|
||||
|
||||
ollama_base = os.getenv("OLLAMA_BASE_URL", "http://host.docker.internal:11434")
|
||||
model = os.getenv("OLLAMA_HTR_MODEL", "qwen2.5vl:32b")
|
||||
|
||||
prompt = (
|
||||
"This is a scanned vocabulary worksheet. Look at the vertical borders of the table columns. "
|
||||
"Are they perfectly vertical, or do they tilt slightly? "
|
||||
"If they tilt, estimate the tilt angle in degrees (positive = top tilts right, negative = top tilts left). "
|
||||
"Reply with ONLY a JSON object like: {\"shear_degrees\": 1.2, \"confidence\": 0.8} "
|
||||
"Use confidence 0.0-1.0 based on how clearly you can see the tilt. "
|
||||
"If the columns look straight, return {\"shear_degrees\": 0.0, \"confidence\": 0.9}"
|
||||
)
|
||||
|
||||
img_b64 = base64.b64encode(image_bytes).decode("utf-8")
|
||||
payload = {
|
||||
"model": model,
|
||||
"prompt": prompt,
|
||||
"images": [img_b64],
|
||||
"stream": False,
|
||||
}
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=60.0) as client:
|
||||
resp = await client.post(f"{ollama_base}/api/generate", json=payload)
|
||||
resp.raise_for_status()
|
||||
text = resp.json().get("response", "")
|
||||
|
||||
# Parse JSON from response (may have surrounding text)
|
||||
match = re.search(r'\{[^}]+\}', text)
|
||||
if match:
|
||||
data = json.loads(match.group(0))
|
||||
shear = float(data.get("shear_degrees", 0.0))
|
||||
conf = float(data.get("confidence", 0.0))
|
||||
# Clamp to reasonable range
|
||||
shear = max(-3.0, min(3.0, shear))
|
||||
conf = max(0.0, min(1.0, conf))
|
||||
return {"method": "vlm_qwen2.5vl", "shear_degrees": round(shear, 3), "confidence": round(conf, 2)}
|
||||
except Exception as e:
|
||||
logger.warning(f"VLM dewarp failed: {e}")
|
||||
|
||||
return {"method": "vlm_qwen2.5vl", "shear_degrees": 0.0, "confidence": 0.0}
|
||||
@@ -0,0 +1,528 @@
|
||||
"""
|
||||
OCR Pipeline Auto-Mode Orchestrator.
|
||||
|
||||
POST /sessions/{session_id}/run-auto -- full auto-mode with SSE streaming.
|
||||
|
||||
Lizenz: Apache 2.0
|
||||
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import asdict
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from fastapi import APIRouter, HTTPException, Request
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
from cv_vocab_pipeline import (
|
||||
OLLAMA_REVIEW_MODEL,
|
||||
PageRegion,
|
||||
RowGeometry,
|
||||
_cells_to_vocab_entries,
|
||||
_detect_header_footer_gaps,
|
||||
_detect_sub_columns,
|
||||
_fix_character_confusion,
|
||||
_fix_phonetic_brackets,
|
||||
fix_cell_phonetics,
|
||||
analyze_layout,
|
||||
build_cell_grid,
|
||||
classify_column_types,
|
||||
create_layout_image,
|
||||
create_ocr_image,
|
||||
deskew_image,
|
||||
deskew_image_by_word_alignment,
|
||||
detect_column_geometry,
|
||||
detect_row_geometry,
|
||||
_apply_shear,
|
||||
dewarp_image,
|
||||
llm_review_entries,
|
||||
)
|
||||
from ocr_pipeline_common import (
|
||||
_cache,
|
||||
_load_session_to_cache,
|
||||
_get_cached,
|
||||
)
|
||||
from ocr_pipeline_session_store import (
|
||||
get_session_db,
|
||||
update_session_db,
|
||||
)
|
||||
from ocr_pipeline_auto_helpers import (
|
||||
RunAutoRequest,
|
||||
auto_sse_event as _auto_sse_event,
|
||||
detect_shear_with_vlm as _detect_shear_with_vlm,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(tags=["ocr-pipeline"])
|
||||
|
||||
@router.post("/sessions/{session_id}/run-auto")
|
||||
async def run_auto(session_id: str, req: RunAutoRequest, request: Request):
|
||||
"""Run the full OCR pipeline automatically from a given step, streaming SSE progress.
|
||||
|
||||
Steps:
|
||||
1. Deskew -- straighten the scan
|
||||
2. Dewarp -- correct vertical shear (ensemble CV or VLM)
|
||||
3. Columns -- detect column layout
|
||||
4. Rows -- detect row layout
|
||||
5. Words -- OCR each cell
|
||||
6. LLM review -- correct OCR errors (optional)
|
||||
|
||||
Already-completed steps are skipped unless `from_step` forces a rerun.
|
||||
Yields SSE events of the form:
|
||||
data: {"step": "deskew", "status": "start"|"done"|"skipped"|"error", ...}
|
||||
|
||||
Final event:
|
||||
data: {"step": "complete", "status": "done", "steps_run": [...], "steps_skipped": [...]}
|
||||
"""
|
||||
if req.from_step < 1 or req.from_step > 6:
|
||||
raise HTTPException(status_code=400, detail="from_step must be 1-6")
|
||||
if req.dewarp_method not in ("ensemble", "vlm", "cv"):
|
||||
raise HTTPException(status_code=400, detail="dewarp_method must be: ensemble, vlm, cv")
|
||||
|
||||
if session_id not in _cache:
|
||||
await _load_session_to_cache(session_id)
|
||||
|
||||
async def _generate():
|
||||
steps_run: List[str] = []
|
||||
steps_skipped: List[str] = []
|
||||
error_step: Optional[str] = None
|
||||
|
||||
session = await get_session_db(session_id)
|
||||
if not session:
|
||||
yield await _auto_sse_event("error", "error", {"message": f"Session {session_id} not found"})
|
||||
return
|
||||
|
||||
cached = _get_cached(session_id)
|
||||
|
||||
# Step 1: Deskew
|
||||
if req.from_step <= 1:
|
||||
yield await _auto_sse_event("deskew", "start", {})
|
||||
try:
|
||||
t0 = time.time()
|
||||
orig_bgr = cached.get("original_bgr")
|
||||
if orig_bgr is None:
|
||||
raise ValueError("Original image not loaded")
|
||||
|
||||
try:
|
||||
deskewed_hough, angle_hough = deskew_image(orig_bgr.copy())
|
||||
except Exception:
|
||||
deskewed_hough, angle_hough = orig_bgr, 0.0
|
||||
|
||||
success_enc, png_orig = cv2.imencode(".png", orig_bgr)
|
||||
orig_bytes = png_orig.tobytes() if success_enc else b""
|
||||
try:
|
||||
deskewed_wa_bytes, angle_wa = deskew_image_by_word_alignment(orig_bytes)
|
||||
except Exception:
|
||||
deskewed_wa_bytes, angle_wa = orig_bytes, 0.0
|
||||
|
||||
if abs(angle_wa) >= abs(angle_hough) or abs(angle_hough) < 0.1:
|
||||
method_used = "word_alignment"
|
||||
angle_applied = angle_wa
|
||||
wa_arr = np.frombuffer(deskewed_wa_bytes, dtype=np.uint8)
|
||||
deskewed_bgr = cv2.imdecode(wa_arr, cv2.IMREAD_COLOR)
|
||||
if deskewed_bgr is None:
|
||||
deskewed_bgr = deskewed_hough
|
||||
method_used = "hough"
|
||||
angle_applied = angle_hough
|
||||
else:
|
||||
method_used = "hough"
|
||||
angle_applied = angle_hough
|
||||
deskewed_bgr = deskewed_hough
|
||||
|
||||
success, png_buf = cv2.imencode(".png", deskewed_bgr)
|
||||
deskewed_png = png_buf.tobytes() if success else b""
|
||||
|
||||
deskew_result = {
|
||||
"method_used": method_used,
|
||||
"rotation_degrees": round(float(angle_applied), 3),
|
||||
"duration_seconds": round(time.time() - t0, 2),
|
||||
}
|
||||
|
||||
cached["deskewed_bgr"] = deskewed_bgr
|
||||
cached["deskew_result"] = deskew_result
|
||||
await update_session_db(
|
||||
session_id,
|
||||
deskewed_png=deskewed_png,
|
||||
deskew_result=deskew_result,
|
||||
auto_rotation_degrees=float(angle_applied),
|
||||
current_step=3,
|
||||
)
|
||||
session = await get_session_db(session_id)
|
||||
|
||||
steps_run.append("deskew")
|
||||
yield await _auto_sse_event("deskew", "done", deskew_result)
|
||||
except Exception as e:
|
||||
logger.error(f"Auto-mode deskew failed for {session_id}: {e}")
|
||||
error_step = "deskew"
|
||||
yield await _auto_sse_event("deskew", "error", {"message": str(e)})
|
||||
yield await _auto_sse_event("complete", "error", {"error_step": error_step})
|
||||
return
|
||||
else:
|
||||
steps_skipped.append("deskew")
|
||||
yield await _auto_sse_event("deskew", "skipped", {"reason": "from_step > 1"})
|
||||
|
||||
# Step 2: Dewarp
|
||||
if req.from_step <= 2:
|
||||
yield await _auto_sse_event("dewarp", "start", {"method": req.dewarp_method})
|
||||
try:
|
||||
t0 = time.time()
|
||||
deskewed_bgr = cached.get("deskewed_bgr")
|
||||
if deskewed_bgr is None:
|
||||
raise ValueError("Deskewed image not available")
|
||||
|
||||
if req.dewarp_method == "vlm":
|
||||
success_enc, png_buf = cv2.imencode(".png", deskewed_bgr)
|
||||
img_bytes = png_buf.tobytes() if success_enc else b""
|
||||
vlm_det = await _detect_shear_with_vlm(img_bytes)
|
||||
shear_deg = vlm_det["shear_degrees"]
|
||||
if abs(shear_deg) >= 0.05 and vlm_det["confidence"] >= 0.3:
|
||||
dewarped_bgr = _apply_shear(deskewed_bgr, -shear_deg)
|
||||
else:
|
||||
dewarped_bgr = deskewed_bgr
|
||||
dewarp_info = {
|
||||
"method": vlm_det["method"],
|
||||
"shear_degrees": shear_deg,
|
||||
"confidence": vlm_det["confidence"],
|
||||
"detections": [vlm_det],
|
||||
}
|
||||
else:
|
||||
dewarped_bgr, dewarp_info = dewarp_image(deskewed_bgr)
|
||||
|
||||
success_enc, png_buf = cv2.imencode(".png", dewarped_bgr)
|
||||
dewarped_png = png_buf.tobytes() if success_enc else b""
|
||||
|
||||
dewarp_result = {
|
||||
"method_used": dewarp_info["method"],
|
||||
"shear_degrees": dewarp_info["shear_degrees"],
|
||||
"confidence": dewarp_info["confidence"],
|
||||
"duration_seconds": round(time.time() - t0, 2),
|
||||
"detections": dewarp_info.get("detections", []),
|
||||
}
|
||||
|
||||
cached["dewarped_bgr"] = dewarped_bgr
|
||||
cached["dewarp_result"] = dewarp_result
|
||||
await update_session_db(
|
||||
session_id,
|
||||
dewarped_png=dewarped_png,
|
||||
dewarp_result=dewarp_result,
|
||||
auto_shear_degrees=dewarp_info.get("shear_degrees", 0.0),
|
||||
current_step=4,
|
||||
)
|
||||
session = await get_session_db(session_id)
|
||||
|
||||
steps_run.append("dewarp")
|
||||
yield await _auto_sse_event("dewarp", "done", dewarp_result)
|
||||
except Exception as e:
|
||||
logger.error(f"Auto-mode dewarp failed for {session_id}: {e}")
|
||||
error_step = "dewarp"
|
||||
yield await _auto_sse_event("dewarp", "error", {"message": str(e)})
|
||||
yield await _auto_sse_event("complete", "error", {"error_step": error_step})
|
||||
return
|
||||
else:
|
||||
steps_skipped.append("dewarp")
|
||||
yield await _auto_sse_event("dewarp", "skipped", {"reason": "from_step > 2"})
|
||||
|
||||
# Step 3: Columns
|
||||
if req.from_step <= 3:
|
||||
yield await _auto_sse_event("columns", "start", {})
|
||||
try:
|
||||
t0 = time.time()
|
||||
col_img = cached.get("cropped_bgr") if cached.get("cropped_bgr") is not None else cached.get("dewarped_bgr")
|
||||
if col_img is None:
|
||||
raise ValueError("Cropped/dewarped image not available")
|
||||
|
||||
ocr_img = create_ocr_image(col_img)
|
||||
h, w = ocr_img.shape[:2]
|
||||
|
||||
geo_result = detect_column_geometry(ocr_img, col_img)
|
||||
if geo_result is None:
|
||||
layout_img = create_layout_image(col_img)
|
||||
regions = analyze_layout(layout_img, ocr_img)
|
||||
cached["_word_dicts"] = None
|
||||
cached["_inv"] = None
|
||||
cached["_content_bounds"] = None
|
||||
else:
|
||||
geometries, left_x, right_x, top_y, bottom_y, word_dicts, inv = geo_result
|
||||
content_w = right_x - left_x
|
||||
cached["_word_dicts"] = word_dicts
|
||||
cached["_inv"] = inv
|
||||
cached["_content_bounds"] = (left_x, right_x, top_y, bottom_y)
|
||||
|
||||
header_y, footer_y = _detect_header_footer_gaps(inv, w, h) if inv is not None else (None, None)
|
||||
geometries = _detect_sub_columns(geometries, content_w, left_x=left_x,
|
||||
top_y=top_y, header_y=header_y, footer_y=footer_y)
|
||||
regions = classify_column_types(geometries, content_w, top_y, w, h, bottom_y,
|
||||
left_x=left_x, right_x=right_x, inv=inv)
|
||||
|
||||
columns = [asdict(r) for r in regions]
|
||||
column_result = {
|
||||
"columns": columns,
|
||||
"classification_methods": list({c.get("classification_method", "") for c in columns if c.get("classification_method")}),
|
||||
"duration_seconds": round(time.time() - t0, 2),
|
||||
}
|
||||
|
||||
cached["column_result"] = column_result
|
||||
await update_session_db(session_id, column_result=column_result,
|
||||
row_result=None, word_result=None, current_step=6)
|
||||
session = await get_session_db(session_id)
|
||||
|
||||
steps_run.append("columns")
|
||||
yield await _auto_sse_event("columns", "done", {
|
||||
"column_count": len(columns),
|
||||
"duration_seconds": column_result["duration_seconds"],
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(f"Auto-mode columns failed for {session_id}: {e}")
|
||||
error_step = "columns"
|
||||
yield await _auto_sse_event("columns", "error", {"message": str(e)})
|
||||
yield await _auto_sse_event("complete", "error", {"error_step": error_step})
|
||||
return
|
||||
else:
|
||||
steps_skipped.append("columns")
|
||||
yield await _auto_sse_event("columns", "skipped", {"reason": "from_step > 3"})
|
||||
|
||||
# Step 4: Rows
|
||||
if req.from_step <= 4:
|
||||
yield await _auto_sse_event("rows", "start", {})
|
||||
try:
|
||||
t0 = time.time()
|
||||
row_img = cached.get("cropped_bgr") if cached.get("cropped_bgr") is not None else cached.get("dewarped_bgr")
|
||||
session = await get_session_db(session_id)
|
||||
column_result = session.get("column_result") or cached.get("column_result")
|
||||
if not column_result or not column_result.get("columns"):
|
||||
raise ValueError("Column detection must complete first")
|
||||
|
||||
col_regions = [
|
||||
PageRegion(
|
||||
type=c["type"], x=c["x"], y=c["y"],
|
||||
width=c["width"], height=c["height"],
|
||||
classification_confidence=c.get("classification_confidence", 1.0),
|
||||
classification_method=c.get("classification_method", ""),
|
||||
)
|
||||
for c in column_result["columns"]
|
||||
]
|
||||
|
||||
word_dicts = cached.get("_word_dicts")
|
||||
inv = cached.get("_inv")
|
||||
content_bounds = cached.get("_content_bounds")
|
||||
|
||||
if word_dicts is None or inv is None or content_bounds is None:
|
||||
ocr_img_tmp = create_ocr_image(row_img)
|
||||
geo_result = detect_column_geometry(ocr_img_tmp, row_img)
|
||||
if geo_result is None:
|
||||
raise ValueError("Column geometry detection failed -- cannot detect rows")
|
||||
_g, lx, rx, ty, by, word_dicts, inv = geo_result
|
||||
cached["_word_dicts"] = word_dicts
|
||||
cached["_inv"] = inv
|
||||
cached["_content_bounds"] = (lx, rx, ty, by)
|
||||
content_bounds = (lx, rx, ty, by)
|
||||
|
||||
left_x, right_x, top_y, bottom_y = content_bounds
|
||||
row_geoms = detect_row_geometry(inv, word_dicts, left_x, right_x, top_y, bottom_y)
|
||||
|
||||
row_list = [
|
||||
{
|
||||
"index": r.index, "x": r.x, "y": r.y,
|
||||
"width": r.width, "height": r.height,
|
||||
"word_count": r.word_count,
|
||||
"row_type": r.row_type,
|
||||
"gap_before": r.gap_before,
|
||||
}
|
||||
for r in row_geoms
|
||||
]
|
||||
row_result = {
|
||||
"rows": row_list,
|
||||
"row_count": len(row_list),
|
||||
"content_rows": len([r for r in row_geoms if r.row_type == "content"]),
|
||||
"duration_seconds": round(time.time() - t0, 2),
|
||||
}
|
||||
|
||||
cached["row_result"] = row_result
|
||||
await update_session_db(session_id, row_result=row_result, current_step=7)
|
||||
session = await get_session_db(session_id)
|
||||
|
||||
steps_run.append("rows")
|
||||
yield await _auto_sse_event("rows", "done", {
|
||||
"row_count": len(row_list),
|
||||
"content_rows": row_result["content_rows"],
|
||||
"duration_seconds": row_result["duration_seconds"],
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(f"Auto-mode rows failed for {session_id}: {e}")
|
||||
error_step = "rows"
|
||||
yield await _auto_sse_event("rows", "error", {"message": str(e)})
|
||||
yield await _auto_sse_event("complete", "error", {"error_step": error_step})
|
||||
return
|
||||
else:
|
||||
steps_skipped.append("rows")
|
||||
yield await _auto_sse_event("rows", "skipped", {"reason": "from_step > 4"})
|
||||
|
||||
# Step 5: Words (OCR)
|
||||
if req.from_step <= 5:
|
||||
yield await _auto_sse_event("words", "start", {"engine": req.ocr_engine})
|
||||
try:
|
||||
t0 = time.time()
|
||||
word_img = cached.get("cropped_bgr") if cached.get("cropped_bgr") is not None else cached.get("dewarped_bgr")
|
||||
session = await get_session_db(session_id)
|
||||
|
||||
column_result = session.get("column_result") or cached.get("column_result")
|
||||
row_result = session.get("row_result") or cached.get("row_result")
|
||||
|
||||
col_regions = [
|
||||
PageRegion(
|
||||
type=c["type"], x=c["x"], y=c["y"],
|
||||
width=c["width"], height=c["height"],
|
||||
classification_confidence=c.get("classification_confidence", 1.0),
|
||||
classification_method=c.get("classification_method", ""),
|
||||
)
|
||||
for c in column_result["columns"]
|
||||
]
|
||||
row_geoms = [
|
||||
RowGeometry(
|
||||
index=r["index"], x=r["x"], y=r["y"],
|
||||
width=r["width"], height=r["height"],
|
||||
word_count=r.get("word_count", 0), words=[],
|
||||
row_type=r.get("row_type", "content"),
|
||||
gap_before=r.get("gap_before", 0),
|
||||
)
|
||||
for r in row_result["rows"]
|
||||
]
|
||||
|
||||
word_dicts = cached.get("_word_dicts")
|
||||
if word_dicts is not None:
|
||||
content_bounds = cached.get("_content_bounds")
|
||||
top_y = content_bounds[2] if content_bounds else min(r.y for r in row_geoms)
|
||||
for row in row_geoms:
|
||||
row_y_rel = row.y - top_y
|
||||
row_bottom_rel = row_y_rel + row.height
|
||||
row.words = [
|
||||
w for w in word_dicts
|
||||
if row_y_rel <= w['top'] + w['height'] / 2 < row_bottom_rel
|
||||
]
|
||||
row.word_count = len(row.words)
|
||||
|
||||
ocr_img = create_ocr_image(word_img)
|
||||
img_h, img_w = word_img.shape[:2]
|
||||
|
||||
cells, columns_meta = build_cell_grid(
|
||||
ocr_img, col_regions, row_geoms, img_w, img_h,
|
||||
ocr_engine=req.ocr_engine, img_bgr=word_img,
|
||||
)
|
||||
duration = time.time() - t0
|
||||
|
||||
col_types = {c['type'] for c in columns_meta}
|
||||
is_vocab = bool(col_types & {'column_en', 'column_de'})
|
||||
n_content_rows = len([r for r in row_geoms if r.row_type == 'content'])
|
||||
used_engine = cells[0].get("ocr_engine", "tesseract") if cells else req.ocr_engine
|
||||
|
||||
fix_cell_phonetics(cells, pronunciation=req.pronunciation)
|
||||
|
||||
word_result_data = {
|
||||
"cells": cells,
|
||||
"grid_shape": {
|
||||
"rows": n_content_rows,
|
||||
"cols": len(columns_meta),
|
||||
"total_cells": len(cells),
|
||||
},
|
||||
"columns_used": columns_meta,
|
||||
"layout": "vocab" if is_vocab else "generic",
|
||||
"image_width": img_w,
|
||||
"image_height": img_h,
|
||||
"duration_seconds": round(duration, 2),
|
||||
"ocr_engine": used_engine,
|
||||
"summary": {
|
||||
"total_cells": len(cells),
|
||||
"non_empty_cells": sum(1 for c in cells if c.get("text")),
|
||||
"low_confidence": sum(1 for c in cells if 0 < c.get("confidence", 0) < 50),
|
||||
},
|
||||
}
|
||||
|
||||
has_text_col = 'column_text' in col_types
|
||||
if is_vocab or has_text_col:
|
||||
entries = _cells_to_vocab_entries(cells, columns_meta)
|
||||
entries = _fix_character_confusion(entries)
|
||||
entries = _fix_phonetic_brackets(entries, pronunciation=req.pronunciation)
|
||||
word_result_data["vocab_entries"] = entries
|
||||
word_result_data["entries"] = entries
|
||||
word_result_data["entry_count"] = len(entries)
|
||||
word_result_data["summary"]["total_entries"] = len(entries)
|
||||
|
||||
await update_session_db(session_id, word_result=word_result_data, current_step=8)
|
||||
cached["word_result"] = word_result_data
|
||||
session = await get_session_db(session_id)
|
||||
|
||||
steps_run.append("words")
|
||||
yield await _auto_sse_event("words", "done", {
|
||||
"total_cells": len(cells),
|
||||
"layout": word_result_data["layout"],
|
||||
"duration_seconds": round(duration, 2),
|
||||
"ocr_engine": used_engine,
|
||||
"summary": word_result_data["summary"],
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(f"Auto-mode words failed for {session_id}: {e}")
|
||||
error_step = "words"
|
||||
yield await _auto_sse_event("words", "error", {"message": str(e)})
|
||||
yield await _auto_sse_event("complete", "error", {"error_step": error_step})
|
||||
return
|
||||
else:
|
||||
steps_skipped.append("words")
|
||||
yield await _auto_sse_event("words", "skipped", {"reason": "from_step > 5"})
|
||||
|
||||
# Step 6: LLM Review (optional)
|
||||
if req.from_step <= 6 and not req.skip_llm_review:
|
||||
yield await _auto_sse_event("llm_review", "start", {"model": OLLAMA_REVIEW_MODEL})
|
||||
try:
|
||||
session = await get_session_db(session_id)
|
||||
word_result = session.get("word_result") or cached.get("word_result")
|
||||
entries = word_result.get("entries") or word_result.get("vocab_entries") or []
|
||||
|
||||
if not entries:
|
||||
yield await _auto_sse_event("llm_review", "skipped", {"reason": "no entries"})
|
||||
steps_skipped.append("llm_review")
|
||||
else:
|
||||
reviewed = await llm_review_entries(entries)
|
||||
|
||||
session = await get_session_db(session_id)
|
||||
word_result_updated = dict(session.get("word_result") or {})
|
||||
word_result_updated["entries"] = reviewed
|
||||
word_result_updated["vocab_entries"] = reviewed
|
||||
word_result_updated["llm_reviewed"] = True
|
||||
word_result_updated["llm_model"] = OLLAMA_REVIEW_MODEL
|
||||
|
||||
await update_session_db(session_id, word_result=word_result_updated, current_step=9)
|
||||
cached["word_result"] = word_result_updated
|
||||
|
||||
steps_run.append("llm_review")
|
||||
yield await _auto_sse_event("llm_review", "done", {
|
||||
"entries_reviewed": len(reviewed),
|
||||
"model": OLLAMA_REVIEW_MODEL,
|
||||
})
|
||||
except Exception as e:
|
||||
logger.warning(f"Auto-mode llm_review failed for {session_id} (non-fatal): {e}")
|
||||
yield await _auto_sse_event("llm_review", "error", {"message": str(e), "fatal": False})
|
||||
steps_skipped.append("llm_review")
|
||||
else:
|
||||
steps_skipped.append("llm_review")
|
||||
reason = "skipped by request" if req.skip_llm_review else "from_step > 6"
|
||||
yield await _auto_sse_event("llm_review", "skipped", {"reason": reason})
|
||||
|
||||
# Final event
|
||||
yield await _auto_sse_event("complete", "done", {
|
||||
"steps_run": steps_run,
|
||||
"steps_skipped": steps_skipped,
|
||||
})
|
||||
|
||||
return StreamingResponse(
|
||||
_generate(),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no",
|
||||
},
|
||||
)
|
||||
@@ -0,0 +1,94 @@
|
||||
"""
|
||||
OCR Pipeline Reprocess Endpoint.
|
||||
|
||||
POST /sessions/{session_id}/reprocess — clear downstream + restart from step.
|
||||
|
||||
Lizenz: Apache 2.0
|
||||
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Request
|
||||
|
||||
from ocr_pipeline_common import _cache
|
||||
from ocr_pipeline_session_store import get_session_db, update_session_db
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(tags=["ocr-pipeline"])
|
||||
|
||||
|
||||
@router.post("/sessions/{session_id}/reprocess")
|
||||
async def reprocess_session(session_id: str, request: Request):
|
||||
"""Re-run pipeline from a specific step, clearing downstream data.
|
||||
|
||||
Body: {"from_step": 5} (1-indexed step number)
|
||||
|
||||
Pipeline order: Orientation(1) -> Deskew(2) -> Dewarp(3) -> Crop(4) -> Columns(5) ->
|
||||
Rows(6) -> Words(7) -> LLM-Review(8) -> Reconstruction(9) -> Validation(10)
|
||||
|
||||
Clears downstream results:
|
||||
- from_step <= 1: orientation_result + all downstream
|
||||
- from_step <= 2: deskew_result + all downstream
|
||||
- from_step <= 3: dewarp_result + all downstream
|
||||
- from_step <= 4: crop_result + all downstream
|
||||
- from_step <= 5: column_result, row_result, word_result
|
||||
- from_step <= 6: row_result, word_result
|
||||
- from_step <= 7: word_result (cells, vocab_entries)
|
||||
- from_step <= 8: word_result.llm_review only
|
||||
"""
|
||||
session = await get_session_db(session_id)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
|
||||
|
||||
body = await request.json()
|
||||
from_step = body.get("from_step", 1)
|
||||
if not isinstance(from_step, int) or from_step < 1 or from_step > 10:
|
||||
raise HTTPException(status_code=400, detail="from_step must be between 1 and 10")
|
||||
|
||||
update_kwargs: Dict[str, Any] = {"current_step": from_step}
|
||||
|
||||
# Clear downstream data based on from_step
|
||||
# New pipeline order: Orient(2) -> Deskew(3) -> Dewarp(4) -> Crop(5) ->
|
||||
# Columns(6) -> Rows(7) -> Words(8) -> LLM(9) -> Recon(10) -> GT(11)
|
||||
if from_step <= 8:
|
||||
update_kwargs["word_result"] = None
|
||||
elif from_step == 9:
|
||||
# Only clear LLM review from word_result
|
||||
word_result = session.get("word_result")
|
||||
if word_result:
|
||||
word_result.pop("llm_review", None)
|
||||
word_result.pop("llm_corrections", None)
|
||||
update_kwargs["word_result"] = word_result
|
||||
|
||||
if from_step <= 7:
|
||||
update_kwargs["row_result"] = None
|
||||
if from_step <= 6:
|
||||
update_kwargs["column_result"] = None
|
||||
if from_step <= 4:
|
||||
update_kwargs["crop_result"] = None
|
||||
if from_step <= 3:
|
||||
update_kwargs["dewarp_result"] = None
|
||||
if from_step <= 2:
|
||||
update_kwargs["deskew_result"] = None
|
||||
if from_step <= 1:
|
||||
update_kwargs["orientation_result"] = None
|
||||
|
||||
await update_session_db(session_id, **update_kwargs)
|
||||
|
||||
# Also clear cache
|
||||
if session_id in _cache:
|
||||
for key in list(update_kwargs.keys()):
|
||||
if key != "current_step":
|
||||
_cache[session_id][key] = update_kwargs[key]
|
||||
_cache[session_id]["current_step"] = from_step
|
||||
|
||||
logger.info(f"Session {session_id} reprocessing from step {from_step}")
|
||||
|
||||
return {
|
||||
"session_id": session_id,
|
||||
"from_step": from_step,
|
||||
"cleared": [k for k in update_kwargs if k != "current_step"],
|
||||
}
|
||||
@@ -1,758 +1,33 @@
|
||||
"""
|
||||
Page Crop - Content-based crop for scanned pages and book scans.
|
||||
Page Crop — Barrel Re-export
|
||||
|
||||
Detects the content boundary by analysing ink density projections and
|
||||
(for book scans) the spine shadow gradient. Works with both loose A4
|
||||
sheets on dark scanners AND book scans with white backgrounds.
|
||||
Content-based crop for scanned pages and book scans.
|
||||
|
||||
Split into:
|
||||
- page_crop_edges.py — Edge detection (spine shadow, gutter, projection)
|
||||
- page_crop_core.py — Main crop algorithm and format detection
|
||||
|
||||
All public names are re-exported here for backward compatibility.
|
||||
License: Apache 2.0
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Dict, Any, Tuple, Optional
|
||||
# Core: main crop functions and format detection
|
||||
from page_crop_core import ( # noqa: F401
|
||||
PAPER_FORMATS,
|
||||
detect_page_splits,
|
||||
detect_and_crop_page,
|
||||
_detect_format,
|
||||
)
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Known paper format aspect ratios (height / width, portrait orientation)
|
||||
PAPER_FORMATS = {
|
||||
"A4": 297.0 / 210.0, # 1.4143
|
||||
"A5": 210.0 / 148.0, # 1.4189
|
||||
"Letter": 11.0 / 8.5, # 1.2941
|
||||
"Legal": 14.0 / 8.5, # 1.6471
|
||||
"A3": 420.0 / 297.0, # 1.4141
|
||||
}
|
||||
|
||||
# Minimum ink density (fraction of pixels) to count a row/column as "content"
|
||||
_INK_THRESHOLD = 0.003 # 0.3%
|
||||
|
||||
# Minimum run length (fraction of dimension) to keep — shorter runs are noise
|
||||
_MIN_RUN_FRAC = 0.005 # 0.5%
|
||||
|
||||
|
||||
def detect_page_splits(
|
||||
img_bgr: np.ndarray,
|
||||
) -> list:
|
||||
"""Detect if the image is a multi-page spread and return split rectangles.
|
||||
|
||||
Uses **brightness** (not ink density) to find the spine area:
|
||||
the scanner bed produces a characteristic gray strip where pages meet,
|
||||
which is darker than the white paper on either side.
|
||||
|
||||
Returns a list of page dicts ``{x, y, width, height, page_index}``
|
||||
or an empty list if only one page is detected.
|
||||
"""
|
||||
h, w = img_bgr.shape[:2]
|
||||
|
||||
# Only check landscape-ish images (width > height * 1.15)
|
||||
if w < h * 1.15:
|
||||
return []
|
||||
|
||||
gray = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY)
|
||||
|
||||
# Column-mean brightness (0-255) — the spine is darker (gray scanner bed)
|
||||
col_brightness = np.mean(gray, axis=0).astype(np.float64)
|
||||
|
||||
# Heavy smoothing to ignore individual text lines
|
||||
kern = max(11, w // 50)
|
||||
if kern % 2 == 0:
|
||||
kern += 1
|
||||
brightness_smooth = np.convolve(col_brightness, np.ones(kern) / kern, mode="same")
|
||||
|
||||
# Page paper is bright (typically > 200), spine/scanner bed is darker
|
||||
page_brightness = float(np.max(brightness_smooth))
|
||||
if page_brightness < 100:
|
||||
return [] # Very dark image, skip
|
||||
|
||||
# Spine threshold: significantly darker than the page
|
||||
# Spine is typically 60-80% of paper brightness
|
||||
spine_thresh = page_brightness * 0.88
|
||||
|
||||
# Search in center region (30-70% of width)
|
||||
center_lo = int(w * 0.30)
|
||||
center_hi = int(w * 0.70)
|
||||
|
||||
# Find the darkest valley in the center region
|
||||
center_brightness = brightness_smooth[center_lo:center_hi]
|
||||
darkest_val = float(np.min(center_brightness))
|
||||
|
||||
if darkest_val >= spine_thresh:
|
||||
logger.debug("No spine detected: min brightness %.0f >= threshold %.0f",
|
||||
darkest_val, spine_thresh)
|
||||
return []
|
||||
|
||||
# Find ALL contiguous dark runs in the center region
|
||||
is_dark = center_brightness < spine_thresh
|
||||
dark_runs: list = [] # list of (start, end) pairs
|
||||
run_start = -1
|
||||
for i in range(len(is_dark)):
|
||||
if is_dark[i]:
|
||||
if run_start < 0:
|
||||
run_start = i
|
||||
else:
|
||||
if run_start >= 0:
|
||||
dark_runs.append((run_start, i))
|
||||
run_start = -1
|
||||
if run_start >= 0:
|
||||
dark_runs.append((run_start, len(is_dark)))
|
||||
|
||||
# Filter out runs that are too narrow (< 1% of image width)
|
||||
min_spine_px = int(w * 0.01)
|
||||
dark_runs = [(s, e) for s, e in dark_runs if e - s >= min_spine_px]
|
||||
|
||||
if not dark_runs:
|
||||
logger.debug("No dark runs wider than %dpx in center region", min_spine_px)
|
||||
return []
|
||||
|
||||
# Score each dark run: prefer centered, dark, narrow valleys
|
||||
center_region_len = center_hi - center_lo
|
||||
image_center_in_region = (w * 0.5 - center_lo) # x=50% mapped into region coords
|
||||
best_score = -1.0
|
||||
best_start, best_end = dark_runs[0]
|
||||
|
||||
for rs, re in dark_runs:
|
||||
run_width = re - rs
|
||||
run_center = (rs + re) / 2.0
|
||||
|
||||
# --- Factor 1: Proximity to image center (gaussian, sigma = 15% of region) ---
|
||||
sigma = center_region_len * 0.15
|
||||
dist = abs(run_center - image_center_in_region)
|
||||
center_factor = float(np.exp(-0.5 * (dist / sigma) ** 2))
|
||||
|
||||
# --- Factor 2: Darkness (how dark is the valley relative to threshold) ---
|
||||
run_brightness = float(np.mean(center_brightness[rs:re]))
|
||||
# Normalize: 1.0 when run_brightness == 0, 0.0 when run_brightness == spine_thresh
|
||||
darkness_factor = max(0.0, (spine_thresh - run_brightness) / spine_thresh)
|
||||
|
||||
# --- Factor 3: Narrowness bonus (spine shadows are narrow, not wide plateaus) ---
|
||||
# Typical spine: 1-5% of image width. Penalise runs wider than ~8%.
|
||||
width_frac = run_width / w
|
||||
if width_frac <= 0.05:
|
||||
narrowness_bonus = 1.0
|
||||
elif width_frac <= 0.15:
|
||||
narrowness_bonus = 1.0 - (width_frac - 0.05) / 0.10 # linear decay 1.0 → 0.0
|
||||
else:
|
||||
narrowness_bonus = 0.0
|
||||
|
||||
score = center_factor * darkness_factor * (0.3 + 0.7 * narrowness_bonus)
|
||||
|
||||
logger.debug(
|
||||
"Dark run x=%d..%d (w=%d): center_f=%.3f dark_f=%.3f narrow_b=%.3f → score=%.4f",
|
||||
center_lo + rs, center_lo + re, run_width,
|
||||
center_factor, darkness_factor, narrowness_bonus, score,
|
||||
)
|
||||
|
||||
if score > best_score:
|
||||
best_score = score
|
||||
best_start, best_end = rs, re
|
||||
|
||||
spine_w = best_end - best_start
|
||||
spine_x = center_lo + best_start
|
||||
spine_center = spine_x + spine_w // 2
|
||||
|
||||
logger.debug(
|
||||
"Best spine candidate: x=%d..%d (w=%d), score=%.4f",
|
||||
spine_x, spine_x + spine_w, spine_w, best_score,
|
||||
)
|
||||
|
||||
# Verify: must have bright (paper) content on BOTH sides
|
||||
left_brightness = float(np.mean(brightness_smooth[max(0, spine_x - w // 10):spine_x]))
|
||||
right_end = center_lo + best_end
|
||||
right_brightness = float(np.mean(brightness_smooth[right_end:min(w, right_end + w // 10)]))
|
||||
|
||||
if left_brightness < spine_thresh or right_brightness < spine_thresh:
|
||||
logger.debug("No bright paper flanking spine: left=%.0f right=%.0f thresh=%.0f",
|
||||
left_brightness, right_brightness, spine_thresh)
|
||||
return []
|
||||
|
||||
logger.info(
|
||||
"Spine detected: x=%d..%d (w=%d), brightness=%.0f vs paper=%.0f, "
|
||||
"left_paper=%.0f, right_paper=%.0f",
|
||||
spine_x, right_end, spine_w, darkest_val, page_brightness,
|
||||
left_brightness, right_brightness,
|
||||
)
|
||||
|
||||
# Split at the spine center
|
||||
split_points = [spine_center]
|
||||
|
||||
# Build page rectangles
|
||||
pages: list = []
|
||||
prev_x = 0
|
||||
for i, sx in enumerate(split_points):
|
||||
pages.append({"x": prev_x, "y": 0, "width": sx - prev_x,
|
||||
"height": h, "page_index": i})
|
||||
prev_x = sx
|
||||
pages.append({"x": prev_x, "y": 0, "width": w - prev_x,
|
||||
"height": h, "page_index": len(split_points)})
|
||||
|
||||
# Filter out tiny pages (< 15% of total width)
|
||||
pages = [p for p in pages if p["width"] >= w * 0.15]
|
||||
if len(pages) < 2:
|
||||
return []
|
||||
|
||||
# Re-index
|
||||
for i, p in enumerate(pages):
|
||||
p["page_index"] = i
|
||||
|
||||
logger.info(
|
||||
"Page split detected: %d pages, spine_w=%d, split_points=%s",
|
||||
len(pages), spine_w, split_points,
|
||||
)
|
||||
return pages
|
||||
|
||||
|
||||
def detect_and_crop_page(
|
||||
img_bgr: np.ndarray,
|
||||
margin_frac: float = 0.01,
|
||||
) -> Tuple[np.ndarray, Dict[str, Any]]:
|
||||
"""Detect content boundary and crop scanner/book borders.
|
||||
|
||||
Algorithm (4-edge detection):
|
||||
1. Adaptive threshold → binary (text=255, bg=0)
|
||||
2. Left edge: spine-shadow detection via grayscale column means,
|
||||
fallback to binary vertical projection
|
||||
3. Right edge: binary vertical projection (last ink column)
|
||||
4. Top/bottom edges: binary horizontal projection
|
||||
5. Sanity checks, then crop with configurable margin
|
||||
|
||||
Args:
|
||||
img_bgr: Input BGR image (should already be deskewed/dewarped)
|
||||
margin_frac: Extra margin around content (fraction of dimension, default 1%)
|
||||
|
||||
Returns:
|
||||
Tuple of (cropped_image, result_dict)
|
||||
"""
|
||||
h, w = img_bgr.shape[:2]
|
||||
total_area = h * w
|
||||
|
||||
result: Dict[str, Any] = {
|
||||
"crop_applied": False,
|
||||
"crop_rect": None,
|
||||
"crop_rect_pct": None,
|
||||
"original_size": {"width": w, "height": h},
|
||||
"cropped_size": {"width": w, "height": h},
|
||||
"detected_format": None,
|
||||
"format_confidence": 0.0,
|
||||
"aspect_ratio": round(max(h, w) / max(min(h, w), 1), 4),
|
||||
"border_fractions": {"top": 0.0, "bottom": 0.0, "left": 0.0, "right": 0.0},
|
||||
}
|
||||
|
||||
gray = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY)
|
||||
|
||||
# --- Binarise with adaptive threshold (works for white-on-white) ---
|
||||
binary = cv2.adaptiveThreshold(
|
||||
gray, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
|
||||
cv2.THRESH_BINARY_INV, blockSize=51, C=15,
|
||||
)
|
||||
|
||||
# --- Left edge: spine-shadow detection ---
|
||||
left_edge = _detect_left_edge_shadow(gray, binary, w, h)
|
||||
|
||||
# --- Right edge: spine-shadow detection ---
|
||||
right_edge = _detect_right_edge_shadow(gray, binary, w, h)
|
||||
|
||||
# --- Top / bottom edges: binary horizontal projection ---
|
||||
top_edge, bottom_edge = _detect_top_bottom_edges(binary, w, h)
|
||||
|
||||
# Compute border fractions
|
||||
border_top = top_edge / h
|
||||
border_bottom = (h - bottom_edge) / h
|
||||
border_left = left_edge / w
|
||||
border_right = (w - right_edge) / w
|
||||
|
||||
result["border_fractions"] = {
|
||||
"top": round(border_top, 4),
|
||||
"bottom": round(border_bottom, 4),
|
||||
"left": round(border_left, 4),
|
||||
"right": round(border_right, 4),
|
||||
}
|
||||
|
||||
# Sanity: only crop if at least one edge has > 2% border
|
||||
min_border = 0.02
|
||||
if all(f < min_border for f in [border_top, border_bottom, border_left, border_right]):
|
||||
logger.info("All borders < %.0f%% — no crop needed", min_border * 100)
|
||||
result["detected_format"], result["format_confidence"] = _detect_format(w, h)
|
||||
return img_bgr, result
|
||||
|
||||
# Add margin
|
||||
margin_x = int(w * margin_frac)
|
||||
margin_y = int(h * margin_frac)
|
||||
|
||||
crop_x = max(0, left_edge - margin_x)
|
||||
crop_y = max(0, top_edge - margin_y)
|
||||
crop_x2 = min(w, right_edge + margin_x)
|
||||
crop_y2 = min(h, bottom_edge + margin_y)
|
||||
|
||||
crop_w = crop_x2 - crop_x
|
||||
crop_h = crop_y2 - crop_y
|
||||
|
||||
# Sanity: cropped area must be >= 40% of original
|
||||
if crop_w * crop_h < 0.40 * total_area:
|
||||
logger.warning("Cropped area too small (%.0f%%) — skipping crop",
|
||||
100.0 * crop_w * crop_h / total_area)
|
||||
result["detected_format"], result["format_confidence"] = _detect_format(w, h)
|
||||
return img_bgr, result
|
||||
|
||||
cropped = img_bgr[crop_y:crop_y2, crop_x:crop_x2].copy()
|
||||
|
||||
detected_format, format_confidence = _detect_format(crop_w, crop_h)
|
||||
|
||||
result["crop_applied"] = True
|
||||
result["crop_rect"] = {"x": crop_x, "y": crop_y, "width": crop_w, "height": crop_h}
|
||||
result["crop_rect_pct"] = {
|
||||
"x": round(100.0 * crop_x / w, 2),
|
||||
"y": round(100.0 * crop_y / h, 2),
|
||||
"width": round(100.0 * crop_w / w, 2),
|
||||
"height": round(100.0 * crop_h / h, 2),
|
||||
}
|
||||
result["cropped_size"] = {"width": crop_w, "height": crop_h}
|
||||
result["detected_format"] = detected_format
|
||||
result["format_confidence"] = format_confidence
|
||||
result["aspect_ratio"] = round(max(crop_w, crop_h) / max(min(crop_w, crop_h), 1), 4)
|
||||
|
||||
logger.info(
|
||||
"Page cropped: %dx%d -> %dx%d, format=%s (%.0f%%), "
|
||||
"borders: T=%.1f%% B=%.1f%% L=%.1f%% R=%.1f%%",
|
||||
w, h, crop_w, crop_h, detected_format, format_confidence * 100,
|
||||
border_top * 100, border_bottom * 100,
|
||||
border_left * 100, border_right * 100,
|
||||
)
|
||||
|
||||
return cropped, result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Edge detection helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _detect_spine_shadow(
|
||||
gray: np.ndarray,
|
||||
search_region: np.ndarray,
|
||||
offset_x: int,
|
||||
w: int,
|
||||
side: str,
|
||||
) -> Optional[int]:
|
||||
"""Find the book spine center (darkest point) in a scanner shadow.
|
||||
|
||||
The scanner produces a gray strip where the book spine presses against
|
||||
the glass. The darkest column in that strip is the spine center —
|
||||
that's where we crop.
|
||||
|
||||
Distinguishes real spine shadows from text content by checking:
|
||||
1. Strong brightness range (> 40 levels)
|
||||
2. Darkest point is genuinely dark (< 180 mean brightness)
|
||||
3. The dark area is a NARROW valley, not a text-content plateau
|
||||
4. Brightness rises significantly toward the page content side
|
||||
|
||||
Args:
|
||||
gray: Full grayscale image (for context).
|
||||
search_region: Column slice of the grayscale image to search in.
|
||||
offset_x: X offset of search_region relative to full image.
|
||||
w: Full image width.
|
||||
side: 'left' or 'right' (for logging).
|
||||
|
||||
Returns:
|
||||
X coordinate (in full image) of the spine center, or None.
|
||||
"""
|
||||
region_w = search_region.shape[1]
|
||||
if region_w < 10:
|
||||
return None
|
||||
|
||||
# Column-mean brightness in the search region
|
||||
col_means = np.mean(search_region, axis=0).astype(np.float64)
|
||||
|
||||
# Smooth with boxcar kernel (width = 1% of image width, min 5)
|
||||
kernel_size = max(5, w // 100)
|
||||
if kernel_size % 2 == 0:
|
||||
kernel_size += 1
|
||||
kernel = np.ones(kernel_size) / kernel_size
|
||||
smoothed_raw = np.convolve(col_means, kernel, mode="same")
|
||||
|
||||
# Trim convolution edge artifacts (edges are zero-padded → artificially low)
|
||||
margin = kernel_size // 2
|
||||
if region_w <= 2 * margin + 10:
|
||||
return None
|
||||
smoothed = smoothed_raw[margin:region_w - margin]
|
||||
trim_offset = margin # offset of smoothed[0] relative to search_region
|
||||
|
||||
val_min = float(np.min(smoothed))
|
||||
val_max = float(np.max(smoothed))
|
||||
shadow_range = val_max - val_min
|
||||
|
||||
# --- Check 1: Strong brightness gradient ---
|
||||
if shadow_range <= 40:
|
||||
logger.debug(
|
||||
"%s edge: no spine (range=%.0f <= 40)", side.capitalize(), shadow_range,
|
||||
)
|
||||
return None
|
||||
|
||||
# --- Check 2: Darkest point must be genuinely dark ---
|
||||
# Spine shadows have mean column brightness 60-160.
|
||||
# Text on white paper stays above 180.
|
||||
if val_min > 180:
|
||||
logger.debug(
|
||||
"%s edge: no spine (darkest=%.0f > 180, likely text)", side.capitalize(), val_min,
|
||||
)
|
||||
return None
|
||||
|
||||
spine_idx = int(np.argmin(smoothed)) # index in trimmed array
|
||||
spine_local = spine_idx + trim_offset # index in search_region
|
||||
trimmed_len = len(smoothed)
|
||||
|
||||
# --- Check 3: Valley width (spine is narrow, text plateau is wide) ---
|
||||
# Count how many columns are within 20% of the shadow range above the min.
|
||||
valley_thresh = val_min + shadow_range * 0.20
|
||||
valley_mask = smoothed < valley_thresh
|
||||
valley_width = int(np.sum(valley_mask))
|
||||
# Spine valleys are typically 3-15% of image width (20-120px on a 800px image).
|
||||
# Text content plateaus span 20%+ of the search region.
|
||||
max_valley_frac = 0.50 # valley must not cover more than half the trimmed region
|
||||
if valley_width > trimmed_len * max_valley_frac:
|
||||
logger.debug(
|
||||
"%s edge: no spine (valley too wide: %d/%d = %.0f%%)",
|
||||
side.capitalize(), valley_width, trimmed_len,
|
||||
100.0 * valley_width / trimmed_len,
|
||||
)
|
||||
return None
|
||||
|
||||
# --- Check 4: Brightness must rise toward page content ---
|
||||
# For left edge: after spine, brightness should rise (= page paper)
|
||||
# For right edge: before spine, brightness should rise
|
||||
rise_check_w = max(5, trimmed_len // 5) # check 20% of trimmed region
|
||||
if side == "left":
|
||||
# Check columns to the right of the spine (in trimmed array)
|
||||
right_start = min(spine_idx + 5, trimmed_len - 1)
|
||||
right_end = min(right_start + rise_check_w, trimmed_len)
|
||||
if right_end > right_start:
|
||||
rise_brightness = float(np.mean(smoothed[right_start:right_end]))
|
||||
rise = rise_brightness - val_min
|
||||
if rise < shadow_range * 0.3:
|
||||
logger.debug(
|
||||
"%s edge: no spine (insufficient rise: %.0f, need %.0f)",
|
||||
side.capitalize(), rise, shadow_range * 0.3,
|
||||
)
|
||||
return None
|
||||
else: # right
|
||||
# Check columns to the left of the spine (in trimmed array)
|
||||
left_end = max(spine_idx - 5, 0)
|
||||
left_start = max(left_end - rise_check_w, 0)
|
||||
if left_end > left_start:
|
||||
rise_brightness = float(np.mean(smoothed[left_start:left_end]))
|
||||
rise = rise_brightness - val_min
|
||||
if rise < shadow_range * 0.3:
|
||||
logger.debug(
|
||||
"%s edge: no spine (insufficient rise: %.0f, need %.0f)",
|
||||
side.capitalize(), rise, shadow_range * 0.3,
|
||||
)
|
||||
return None
|
||||
|
||||
spine_x = offset_x + spine_local
|
||||
|
||||
logger.info(
|
||||
"%s edge: spine center at x=%d (brightness=%.0f, range=%.0f, valley=%dpx)",
|
||||
side.capitalize(), spine_x, val_min, shadow_range, valley_width,
|
||||
)
|
||||
return spine_x
|
||||
|
||||
|
||||
def _detect_gutter_continuity(
|
||||
gray: np.ndarray,
|
||||
search_region: np.ndarray,
|
||||
offset_x: int,
|
||||
w: int,
|
||||
side: str,
|
||||
) -> Optional[int]:
|
||||
"""Detect gutter shadow via vertical continuity analysis.
|
||||
|
||||
Camera book scans produce a subtle brightness gradient at the gutter
|
||||
that is too faint for scanner-shadow detection (range < 40). However,
|
||||
the gutter shadow has a unique property: it runs **continuously from
|
||||
top to bottom** without interruption. Text and images always have
|
||||
vertical gaps between lines, paragraphs, or sections.
|
||||
|
||||
Algorithm:
|
||||
1. Divide image into N horizontal strips (~60px each)
|
||||
2. For each column, compute what fraction of strips are darker than
|
||||
the page median (from the center 50% of the full image)
|
||||
3. A "gutter column" has ≥ 75% of strips darker than page_median − δ
|
||||
4. Smooth the dark-fraction profile and find the transition point
|
||||
from the edge inward where the fraction drops below 0.50
|
||||
5. Validate: gutter band must be 0.5%-10% of image width
|
||||
|
||||
Args:
|
||||
gray: Full grayscale image.
|
||||
search_region: Edge slice of the grayscale image.
|
||||
offset_x: X offset of search_region relative to full image.
|
||||
w: Full image width.
|
||||
side: 'left' or 'right'.
|
||||
|
||||
Returns:
|
||||
X coordinate (in full image) of the gutter inner edge, or None.
|
||||
"""
|
||||
region_h, region_w = search_region.shape[:2]
|
||||
if region_w < 20 or region_h < 100:
|
||||
return None
|
||||
|
||||
# --- 1. Divide into horizontal strips ---
|
||||
strip_target_h = 60 # ~60px per strip
|
||||
n_strips = max(10, region_h // strip_target_h)
|
||||
strip_h = region_h // n_strips
|
||||
|
||||
strip_means = np.zeros((n_strips, region_w), dtype=np.float64)
|
||||
for s in range(n_strips):
|
||||
y0 = s * strip_h
|
||||
y1 = min((s + 1) * strip_h, region_h)
|
||||
strip_means[s] = np.mean(search_region[y0:y1, :], axis=0)
|
||||
|
||||
# --- 2. Page median from center 50% of full image ---
|
||||
center_lo = w // 4
|
||||
center_hi = 3 * w // 4
|
||||
page_median = float(np.median(gray[:, center_lo:center_hi]))
|
||||
|
||||
# Camera shadows are subtle — threshold just 5 levels below page median
|
||||
dark_thresh = page_median - 5.0
|
||||
|
||||
# If page is very dark overall (e.g. photo, not a book page), bail out
|
||||
if page_median < 180:
|
||||
return None
|
||||
|
||||
# --- 3. Per-column dark fraction ---
|
||||
dark_count = np.sum(strip_means < dark_thresh, axis=0).astype(np.float64)
|
||||
dark_frac = dark_count / n_strips # shape: (region_w,)
|
||||
|
||||
# --- 4. Smooth and find transition ---
|
||||
# Rolling mean (window = 1% of image width, min 5)
|
||||
smooth_w = max(5, w // 100)
|
||||
if smooth_w % 2 == 0:
|
||||
smooth_w += 1
|
||||
kernel = np.ones(smooth_w) / smooth_w
|
||||
frac_smooth = np.convolve(dark_frac, kernel, mode="same")
|
||||
|
||||
# Trim convolution edges
|
||||
margin = smooth_w // 2
|
||||
if region_w <= 2 * margin + 10:
|
||||
return None
|
||||
|
||||
# Find the peak of dark fraction (gutter center).
|
||||
# For right gutters the peak is near the edge; for left gutters
|
||||
# (V-shaped spine shadow) the peak may be well inside the region.
|
||||
transition_thresh = 0.50
|
||||
peak_frac = float(np.max(frac_smooth[margin:region_w - margin]))
|
||||
|
||||
if peak_frac < 0.70:
|
||||
logger.debug(
|
||||
"%s gutter: peak dark fraction %.2f < 0.70", side.capitalize(), peak_frac,
|
||||
)
|
||||
return None
|
||||
|
||||
peak_x = int(np.argmax(frac_smooth[margin:region_w - margin])) + margin
|
||||
gutter_inner = None # local x in search_region
|
||||
|
||||
if side == "right":
|
||||
# Scan from peak toward the page center (leftward)
|
||||
for x in range(peak_x, margin, -1):
|
||||
if frac_smooth[x] < transition_thresh:
|
||||
gutter_inner = x + 1
|
||||
break
|
||||
else:
|
||||
# Scan from peak toward the page center (rightward)
|
||||
for x in range(peak_x, region_w - margin):
|
||||
if frac_smooth[x] < transition_thresh:
|
||||
gutter_inner = x - 1
|
||||
break
|
||||
|
||||
if gutter_inner is None:
|
||||
return None
|
||||
|
||||
# --- 5. Validate gutter width ---
|
||||
if side == "right":
|
||||
gutter_width = region_w - gutter_inner
|
||||
else:
|
||||
gutter_width = gutter_inner
|
||||
|
||||
min_gutter = max(3, int(w * 0.005)) # at least 0.5% of image
|
||||
max_gutter = int(w * 0.10) # at most 10% of image
|
||||
|
||||
if gutter_width < min_gutter:
|
||||
logger.debug(
|
||||
"%s gutter: too narrow (%dpx < %dpx)", side.capitalize(),
|
||||
gutter_width, min_gutter,
|
||||
)
|
||||
return None
|
||||
|
||||
if gutter_width > max_gutter:
|
||||
logger.debug(
|
||||
"%s gutter: too wide (%dpx > %dpx)", side.capitalize(),
|
||||
gutter_width, max_gutter,
|
||||
)
|
||||
return None
|
||||
|
||||
# Check that the gutter band is meaningfully darker than the page
|
||||
if side == "right":
|
||||
gutter_brightness = float(np.mean(strip_means[:, gutter_inner:]))
|
||||
else:
|
||||
gutter_brightness = float(np.mean(strip_means[:, :gutter_inner]))
|
||||
|
||||
brightness_drop = page_median - gutter_brightness
|
||||
if brightness_drop < 3:
|
||||
logger.debug(
|
||||
"%s gutter: insufficient brightness drop (%.1f levels)",
|
||||
side.capitalize(), brightness_drop,
|
||||
)
|
||||
return None
|
||||
|
||||
gutter_x = offset_x + gutter_inner
|
||||
|
||||
logger.info(
|
||||
"%s gutter (continuity): x=%d, width=%dpx (%.1f%%), "
|
||||
"brightness=%.0f vs page=%.0f (drop=%.0f), frac@edge=%.2f",
|
||||
side.capitalize(), gutter_x, gutter_width,
|
||||
100.0 * gutter_width / w, gutter_brightness, page_median,
|
||||
brightness_drop, float(frac_smooth[gutter_inner]),
|
||||
)
|
||||
return gutter_x
|
||||
|
||||
|
||||
def _detect_left_edge_shadow(
|
||||
gray: np.ndarray,
|
||||
binary: np.ndarray,
|
||||
w: int,
|
||||
h: int,
|
||||
) -> int:
|
||||
"""Detect left content edge, accounting for book-spine shadow.
|
||||
|
||||
Tries three methods in order:
|
||||
1. Scanner spine-shadow (dark gradient, range > 40)
|
||||
2. Camera gutter continuity (subtle shadow running top-to-bottom)
|
||||
3. Binary projection fallback (first ink column)
|
||||
"""
|
||||
search_w = max(1, w // 4)
|
||||
spine_x = _detect_spine_shadow(gray, gray[:, :search_w], 0, w, "left")
|
||||
if spine_x is not None:
|
||||
return spine_x
|
||||
|
||||
# Fallback 1: vertical continuity (camera gutter shadow)
|
||||
gutter_x = _detect_gutter_continuity(gray, gray[:, :search_w], 0, w, "left")
|
||||
if gutter_x is not None:
|
||||
return gutter_x
|
||||
|
||||
# Fallback 2: binary vertical projection
|
||||
return _detect_edge_projection(binary, axis=0, from_start=True, dim=w)
|
||||
|
||||
|
||||
def _detect_right_edge_shadow(
|
||||
gray: np.ndarray,
|
||||
binary: np.ndarray,
|
||||
w: int,
|
||||
h: int,
|
||||
) -> int:
|
||||
"""Detect right content edge, accounting for book-spine shadow.
|
||||
|
||||
Tries three methods in order:
|
||||
1. Scanner spine-shadow (dark gradient, range > 40)
|
||||
2. Camera gutter continuity (subtle shadow running top-to-bottom)
|
||||
3. Binary projection fallback (last ink column)
|
||||
"""
|
||||
search_w = max(1, w // 4)
|
||||
right_start = w - search_w
|
||||
spine_x = _detect_spine_shadow(gray, gray[:, right_start:], right_start, w, "right")
|
||||
if spine_x is not None:
|
||||
return spine_x
|
||||
|
||||
# Fallback 1: vertical continuity (camera gutter shadow)
|
||||
gutter_x = _detect_gutter_continuity(gray, gray[:, right_start:], right_start, w, "right")
|
||||
if gutter_x is not None:
|
||||
return gutter_x
|
||||
|
||||
# Fallback 2: binary vertical projection
|
||||
return _detect_edge_projection(binary, axis=0, from_start=False, dim=w)
|
||||
|
||||
|
||||
def _detect_top_bottom_edges(binary: np.ndarray, w: int, h: int) -> Tuple[int, int]:
|
||||
"""Detect top and bottom content edges via binary horizontal projection."""
|
||||
top = _detect_edge_projection(binary, axis=1, from_start=True, dim=h)
|
||||
bottom = _detect_edge_projection(binary, axis=1, from_start=False, dim=h)
|
||||
return top, bottom
|
||||
|
||||
|
||||
def _detect_edge_projection(
|
||||
binary: np.ndarray,
|
||||
axis: int,
|
||||
from_start: bool,
|
||||
dim: int,
|
||||
) -> int:
|
||||
"""Find the first/last row or column with ink density above threshold.
|
||||
|
||||
axis=0 → project vertically (column densities) → returns x position
|
||||
axis=1 → project horizontally (row densities) → returns y position
|
||||
|
||||
Filters out narrow noise runs shorter than _MIN_RUN_FRAC of the dimension.
|
||||
"""
|
||||
# Compute density per row/column (mean of binary pixels / 255)
|
||||
projection = np.mean(binary, axis=axis) / 255.0
|
||||
|
||||
# Create mask of "ink" positions
|
||||
ink_mask = projection >= _INK_THRESHOLD
|
||||
|
||||
# Filter narrow runs (noise)
|
||||
min_run = max(1, int(dim * _MIN_RUN_FRAC))
|
||||
ink_mask = _filter_narrow_runs(ink_mask, min_run)
|
||||
|
||||
ink_positions = np.where(ink_mask)[0]
|
||||
if len(ink_positions) == 0:
|
||||
return 0 if from_start else dim
|
||||
|
||||
if from_start:
|
||||
return int(ink_positions[0])
|
||||
else:
|
||||
return int(ink_positions[-1])
|
||||
|
||||
|
||||
def _filter_narrow_runs(mask: np.ndarray, min_run: int) -> np.ndarray:
|
||||
"""Remove True-runs shorter than min_run pixels."""
|
||||
if min_run <= 1:
|
||||
return mask
|
||||
|
||||
result = mask.copy()
|
||||
n = len(result)
|
||||
i = 0
|
||||
while i < n:
|
||||
if result[i]:
|
||||
start = i
|
||||
while i < n and result[i]:
|
||||
i += 1
|
||||
if i - start < min_run:
|
||||
result[start:i] = False
|
||||
else:
|
||||
i += 1
|
||||
return result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Format detection (kept as optional metadata)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _detect_format(width: int, height: int) -> Tuple[str, float]:
|
||||
"""Detect paper format from dimensions by comparing aspect ratios."""
|
||||
if width <= 0 or height <= 0:
|
||||
return "unknown", 0.0
|
||||
|
||||
aspect = max(width, height) / min(width, height)
|
||||
|
||||
best_format = "unknown"
|
||||
best_diff = float("inf")
|
||||
|
||||
for fmt, expected_ratio in PAPER_FORMATS.items():
|
||||
diff = abs(aspect - expected_ratio)
|
||||
if diff < best_diff:
|
||||
best_diff = diff
|
||||
best_format = fmt
|
||||
|
||||
confidence = max(0.0, 1.0 - best_diff * 5.0)
|
||||
|
||||
if confidence < 0.3:
|
||||
return "unknown", 0.0
|
||||
|
||||
return best_format, round(confidence, 3)
|
||||
from page_crop_edges import ( # noqa: F401
|
||||
_INK_THRESHOLD,
|
||||
_MIN_RUN_FRAC,
|
||||
_detect_spine_shadow,
|
||||
_detect_gutter_continuity,
|
||||
_detect_left_edge_shadow,
|
||||
_detect_right_edge_shadow,
|
||||
_detect_top_bottom_edges,
|
||||
_detect_edge_projection,
|
||||
_filter_narrow_runs,
|
||||
)
|
||||
|
||||
@@ -0,0 +1,342 @@
|
||||
"""
|
||||
Page Crop - Core Crop and Format Detection
|
||||
|
||||
Content-based crop for scanned pages and book scans. Detects the content
|
||||
boundary by analysing ink density projections and (for book scans) the
|
||||
spine shadow gradient.
|
||||
|
||||
Extracted from page_crop.py to keep files under 500 LOC.
|
||||
License: Apache 2.0
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Dict, Any, Tuple
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
from page_crop_edges import (
|
||||
_detect_left_edge_shadow,
|
||||
_detect_right_edge_shadow,
|
||||
_detect_top_bottom_edges,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Known paper format aspect ratios (height / width, portrait orientation)
|
||||
PAPER_FORMATS = {
|
||||
"A4": 297.0 / 210.0, # 1.4143
|
||||
"A5": 210.0 / 148.0, # 1.4189
|
||||
"Letter": 11.0 / 8.5, # 1.2941
|
||||
"Legal": 14.0 / 8.5, # 1.6471
|
||||
"A3": 420.0 / 297.0, # 1.4141
|
||||
}
|
||||
|
||||
|
||||
def detect_page_splits(
|
||||
img_bgr: np.ndarray,
|
||||
) -> list:
|
||||
"""Detect if the image is a multi-page spread and return split rectangles.
|
||||
|
||||
Uses **brightness** (not ink density) to find the spine area:
|
||||
the scanner bed produces a characteristic gray strip where pages meet,
|
||||
which is darker than the white paper on either side.
|
||||
|
||||
Returns a list of page dicts ``{x, y, width, height, page_index}``
|
||||
or an empty list if only one page is detected.
|
||||
"""
|
||||
h, w = img_bgr.shape[:2]
|
||||
|
||||
# Only check landscape-ish images (width > height * 1.15)
|
||||
if w < h * 1.15:
|
||||
return []
|
||||
|
||||
gray = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY)
|
||||
|
||||
# Column-mean brightness (0-255) — the spine is darker (gray scanner bed)
|
||||
col_brightness = np.mean(gray, axis=0).astype(np.float64)
|
||||
|
||||
# Heavy smoothing to ignore individual text lines
|
||||
kern = max(11, w // 50)
|
||||
if kern % 2 == 0:
|
||||
kern += 1
|
||||
brightness_smooth = np.convolve(col_brightness, np.ones(kern) / kern, mode="same")
|
||||
|
||||
# Page paper is bright (typically > 200), spine/scanner bed is darker
|
||||
page_brightness = float(np.max(brightness_smooth))
|
||||
if page_brightness < 100:
|
||||
return [] # Very dark image, skip
|
||||
|
||||
# Spine threshold: significantly darker than the page
|
||||
spine_thresh = page_brightness * 0.88
|
||||
|
||||
# Search in center region (30-70% of width)
|
||||
center_lo = int(w * 0.30)
|
||||
center_hi = int(w * 0.70)
|
||||
|
||||
# Find the darkest valley in the center region
|
||||
center_brightness = brightness_smooth[center_lo:center_hi]
|
||||
darkest_val = float(np.min(center_brightness))
|
||||
|
||||
if darkest_val >= spine_thresh:
|
||||
logger.debug("No spine detected: min brightness %.0f >= threshold %.0f",
|
||||
darkest_val, spine_thresh)
|
||||
return []
|
||||
|
||||
# Find ALL contiguous dark runs in the center region
|
||||
is_dark = center_brightness < spine_thresh
|
||||
dark_runs: list = []
|
||||
run_start = -1
|
||||
for i in range(len(is_dark)):
|
||||
if is_dark[i]:
|
||||
if run_start < 0:
|
||||
run_start = i
|
||||
else:
|
||||
if run_start >= 0:
|
||||
dark_runs.append((run_start, i))
|
||||
run_start = -1
|
||||
if run_start >= 0:
|
||||
dark_runs.append((run_start, len(is_dark)))
|
||||
|
||||
# Filter out runs that are too narrow (< 1% of image width)
|
||||
min_spine_px = int(w * 0.01)
|
||||
dark_runs = [(s, e) for s, e in dark_runs if e - s >= min_spine_px]
|
||||
|
||||
if not dark_runs:
|
||||
logger.debug("No dark runs wider than %dpx in center region", min_spine_px)
|
||||
return []
|
||||
|
||||
# Score each dark run: prefer centered, dark, narrow valleys
|
||||
center_region_len = center_hi - center_lo
|
||||
image_center_in_region = (w * 0.5 - center_lo)
|
||||
best_score = -1.0
|
||||
best_start, best_end = dark_runs[0]
|
||||
|
||||
for rs, re in dark_runs:
|
||||
run_width = re - rs
|
||||
run_center = (rs + re) / 2.0
|
||||
|
||||
sigma = center_region_len * 0.15
|
||||
dist = abs(run_center - image_center_in_region)
|
||||
center_factor = float(np.exp(-0.5 * (dist / sigma) ** 2))
|
||||
|
||||
run_brightness = float(np.mean(center_brightness[rs:re]))
|
||||
darkness_factor = max(0.0, (spine_thresh - run_brightness) / spine_thresh)
|
||||
|
||||
width_frac = run_width / w
|
||||
if width_frac <= 0.05:
|
||||
narrowness_bonus = 1.0
|
||||
elif width_frac <= 0.15:
|
||||
narrowness_bonus = 1.0 - (width_frac - 0.05) / 0.10
|
||||
else:
|
||||
narrowness_bonus = 0.0
|
||||
|
||||
score = center_factor * darkness_factor * (0.3 + 0.7 * narrowness_bonus)
|
||||
|
||||
logger.debug(
|
||||
"Dark run x=%d..%d (w=%d): center_f=%.3f dark_f=%.3f narrow_b=%.3f -> score=%.4f",
|
||||
center_lo + rs, center_lo + re, run_width,
|
||||
center_factor, darkness_factor, narrowness_bonus, score,
|
||||
)
|
||||
|
||||
if score > best_score:
|
||||
best_score = score
|
||||
best_start, best_end = rs, re
|
||||
|
||||
spine_w = best_end - best_start
|
||||
spine_x = center_lo + best_start
|
||||
spine_center = spine_x + spine_w // 2
|
||||
|
||||
logger.debug(
|
||||
"Best spine candidate: x=%d..%d (w=%d), score=%.4f",
|
||||
spine_x, spine_x + spine_w, spine_w, best_score,
|
||||
)
|
||||
|
||||
# Verify: must have bright (paper) content on BOTH sides
|
||||
left_brightness = float(np.mean(brightness_smooth[max(0, spine_x - w // 10):spine_x]))
|
||||
right_end = center_lo + best_end
|
||||
right_brightness = float(np.mean(brightness_smooth[right_end:min(w, right_end + w // 10)]))
|
||||
|
||||
if left_brightness < spine_thresh or right_brightness < spine_thresh:
|
||||
logger.debug("No bright paper flanking spine: left=%.0f right=%.0f thresh=%.0f",
|
||||
left_brightness, right_brightness, spine_thresh)
|
||||
return []
|
||||
|
||||
logger.info(
|
||||
"Spine detected: x=%d..%d (w=%d), brightness=%.0f vs paper=%.0f, "
|
||||
"left_paper=%.0f, right_paper=%.0f",
|
||||
spine_x, right_end, spine_w, darkest_val, page_brightness,
|
||||
left_brightness, right_brightness,
|
||||
)
|
||||
|
||||
# Split at the spine center
|
||||
split_points = [spine_center]
|
||||
|
||||
# Build page rectangles
|
||||
pages: list = []
|
||||
prev_x = 0
|
||||
for i, sx in enumerate(split_points):
|
||||
pages.append({"x": prev_x, "y": 0, "width": sx - prev_x,
|
||||
"height": h, "page_index": i})
|
||||
prev_x = sx
|
||||
pages.append({"x": prev_x, "y": 0, "width": w - prev_x,
|
||||
"height": h, "page_index": len(split_points)})
|
||||
|
||||
# Filter out tiny pages (< 15% of total width)
|
||||
pages = [p for p in pages if p["width"] >= w * 0.15]
|
||||
if len(pages) < 2:
|
||||
return []
|
||||
|
||||
# Re-index
|
||||
for i, p in enumerate(pages):
|
||||
p["page_index"] = i
|
||||
|
||||
logger.info(
|
||||
"Page split detected: %d pages, spine_w=%d, split_points=%s",
|
||||
len(pages), spine_w, split_points,
|
||||
)
|
||||
return pages
|
||||
|
||||
|
||||
def detect_and_crop_page(
|
||||
img_bgr: np.ndarray,
|
||||
margin_frac: float = 0.01,
|
||||
) -> Tuple[np.ndarray, Dict[str, Any]]:
|
||||
"""Detect content boundary and crop scanner/book borders.
|
||||
|
||||
Algorithm (4-edge detection):
|
||||
1. Adaptive threshold -> binary (text=255, bg=0)
|
||||
2. Left edge: spine-shadow detection via grayscale column means,
|
||||
fallback to binary vertical projection
|
||||
3. Right edge: binary vertical projection (last ink column)
|
||||
4. Top/bottom edges: binary horizontal projection
|
||||
5. Sanity checks, then crop with configurable margin
|
||||
|
||||
Args:
|
||||
img_bgr: Input BGR image (should already be deskewed/dewarped)
|
||||
margin_frac: Extra margin around content (fraction of dimension, default 1%)
|
||||
|
||||
Returns:
|
||||
Tuple of (cropped_image, result_dict)
|
||||
"""
|
||||
h, w = img_bgr.shape[:2]
|
||||
total_area = h * w
|
||||
|
||||
result: Dict[str, Any] = {
|
||||
"crop_applied": False,
|
||||
"crop_rect": None,
|
||||
"crop_rect_pct": None,
|
||||
"original_size": {"width": w, "height": h},
|
||||
"cropped_size": {"width": w, "height": h},
|
||||
"detected_format": None,
|
||||
"format_confidence": 0.0,
|
||||
"aspect_ratio": round(max(h, w) / max(min(h, w), 1), 4),
|
||||
"border_fractions": {"top": 0.0, "bottom": 0.0, "left": 0.0, "right": 0.0},
|
||||
}
|
||||
|
||||
gray = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY)
|
||||
|
||||
# --- Binarise with adaptive threshold ---
|
||||
binary = cv2.adaptiveThreshold(
|
||||
gray, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
|
||||
cv2.THRESH_BINARY_INV, blockSize=51, C=15,
|
||||
)
|
||||
|
||||
# --- Edge detection ---
|
||||
left_edge = _detect_left_edge_shadow(gray, binary, w, h)
|
||||
right_edge = _detect_right_edge_shadow(gray, binary, w, h)
|
||||
top_edge, bottom_edge = _detect_top_bottom_edges(binary, w, h)
|
||||
|
||||
# Compute border fractions
|
||||
border_top = top_edge / h
|
||||
border_bottom = (h - bottom_edge) / h
|
||||
border_left = left_edge / w
|
||||
border_right = (w - right_edge) / w
|
||||
|
||||
result["border_fractions"] = {
|
||||
"top": round(border_top, 4),
|
||||
"bottom": round(border_bottom, 4),
|
||||
"left": round(border_left, 4),
|
||||
"right": round(border_right, 4),
|
||||
}
|
||||
|
||||
# Sanity: only crop if at least one edge has > 2% border
|
||||
min_border = 0.02
|
||||
if all(f < min_border for f in [border_top, border_bottom, border_left, border_right]):
|
||||
logger.info("All borders < %.0f%% — no crop needed", min_border * 100)
|
||||
result["detected_format"], result["format_confidence"] = _detect_format(w, h)
|
||||
return img_bgr, result
|
||||
|
||||
# Add margin
|
||||
margin_x = int(w * margin_frac)
|
||||
margin_y = int(h * margin_frac)
|
||||
|
||||
crop_x = max(0, left_edge - margin_x)
|
||||
crop_y = max(0, top_edge - margin_y)
|
||||
crop_x2 = min(w, right_edge + margin_x)
|
||||
crop_y2 = min(h, bottom_edge + margin_y)
|
||||
|
||||
crop_w = crop_x2 - crop_x
|
||||
crop_h = crop_y2 - crop_y
|
||||
|
||||
# Sanity: cropped area must be >= 40% of original
|
||||
if crop_w * crop_h < 0.40 * total_area:
|
||||
logger.warning("Cropped area too small (%.0f%%) — skipping crop",
|
||||
100.0 * crop_w * crop_h / total_area)
|
||||
result["detected_format"], result["format_confidence"] = _detect_format(w, h)
|
||||
return img_bgr, result
|
||||
|
||||
cropped = img_bgr[crop_y:crop_y2, crop_x:crop_x2].copy()
|
||||
|
||||
detected_format, format_confidence = _detect_format(crop_w, crop_h)
|
||||
|
||||
result["crop_applied"] = True
|
||||
result["crop_rect"] = {"x": crop_x, "y": crop_y, "width": crop_w, "height": crop_h}
|
||||
result["crop_rect_pct"] = {
|
||||
"x": round(100.0 * crop_x / w, 2),
|
||||
"y": round(100.0 * crop_y / h, 2),
|
||||
"width": round(100.0 * crop_w / w, 2),
|
||||
"height": round(100.0 * crop_h / h, 2),
|
||||
}
|
||||
result["cropped_size"] = {"width": crop_w, "height": crop_h}
|
||||
result["detected_format"] = detected_format
|
||||
result["format_confidence"] = format_confidence
|
||||
result["aspect_ratio"] = round(max(crop_w, crop_h) / max(min(crop_w, crop_h), 1), 4)
|
||||
|
||||
logger.info(
|
||||
"Page cropped: %dx%d -> %dx%d, format=%s (%.0f%%), "
|
||||
"borders: T=%.1f%% B=%.1f%% L=%.1f%% R=%.1f%%",
|
||||
w, h, crop_w, crop_h, detected_format, format_confidence * 100,
|
||||
border_top * 100, border_bottom * 100,
|
||||
border_left * 100, border_right * 100,
|
||||
)
|
||||
|
||||
return cropped, result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Format detection (kept as optional metadata)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _detect_format(width: int, height: int) -> Tuple[str, float]:
|
||||
"""Detect paper format from dimensions by comparing aspect ratios."""
|
||||
if width <= 0 or height <= 0:
|
||||
return "unknown", 0.0
|
||||
|
||||
aspect = max(width, height) / min(width, height)
|
||||
|
||||
best_format = "unknown"
|
||||
best_diff = float("inf")
|
||||
|
||||
for fmt, expected_ratio in PAPER_FORMATS.items():
|
||||
diff = abs(aspect - expected_ratio)
|
||||
if diff < best_diff:
|
||||
best_diff = diff
|
||||
best_format = fmt
|
||||
|
||||
confidence = max(0.0, 1.0 - best_diff * 5.0)
|
||||
|
||||
if confidence < 0.3:
|
||||
return "unknown", 0.0
|
||||
|
||||
return best_format, round(confidence, 3)
|
||||
@@ -0,0 +1,388 @@
|
||||
"""
|
||||
Page Crop - Edge Detection Helpers
|
||||
|
||||
Spine shadow detection, gutter continuity analysis, projection-based
|
||||
edge detection, and narrow-run filtering for content cropping.
|
||||
|
||||
Extracted from page_crop.py to keep files under 500 LOC.
|
||||
License: Apache 2.0
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Minimum ink density (fraction of pixels) to count a row/column as "content"
|
||||
_INK_THRESHOLD = 0.003 # 0.3%
|
||||
|
||||
# Minimum run length (fraction of dimension) to keep — shorter runs are noise
|
||||
_MIN_RUN_FRAC = 0.005 # 0.5%
|
||||
|
||||
|
||||
def _detect_spine_shadow(
|
||||
gray: np.ndarray,
|
||||
search_region: np.ndarray,
|
||||
offset_x: int,
|
||||
w: int,
|
||||
side: str,
|
||||
) -> Optional[int]:
|
||||
"""Find the book spine center (darkest point) in a scanner shadow.
|
||||
|
||||
The scanner produces a gray strip where the book spine presses against
|
||||
the glass. The darkest column in that strip is the spine center —
|
||||
that's where we crop.
|
||||
|
||||
Distinguishes real spine shadows from text content by checking:
|
||||
1. Strong brightness range (> 40 levels)
|
||||
2. Darkest point is genuinely dark (< 180 mean brightness)
|
||||
3. The dark area is a NARROW valley, not a text-content plateau
|
||||
4. Brightness rises significantly toward the page content side
|
||||
|
||||
Args:
|
||||
gray: Full grayscale image (for context).
|
||||
search_region: Column slice of the grayscale image to search in.
|
||||
offset_x: X offset of search_region relative to full image.
|
||||
w: Full image width.
|
||||
side: 'left' or 'right' (for logging).
|
||||
|
||||
Returns:
|
||||
X coordinate (in full image) of the spine center, or None.
|
||||
"""
|
||||
region_w = search_region.shape[1]
|
||||
if region_w < 10:
|
||||
return None
|
||||
|
||||
# Column-mean brightness in the search region
|
||||
col_means = np.mean(search_region, axis=0).astype(np.float64)
|
||||
|
||||
# Smooth with boxcar kernel (width = 1% of image width, min 5)
|
||||
kernel_size = max(5, w // 100)
|
||||
if kernel_size % 2 == 0:
|
||||
kernel_size += 1
|
||||
kernel = np.ones(kernel_size) / kernel_size
|
||||
smoothed_raw = np.convolve(col_means, kernel, mode="same")
|
||||
|
||||
# Trim convolution edge artifacts (edges are zero-padded -> artificially low)
|
||||
margin = kernel_size // 2
|
||||
if region_w <= 2 * margin + 10:
|
||||
return None
|
||||
smoothed = smoothed_raw[margin:region_w - margin]
|
||||
trim_offset = margin # offset of smoothed[0] relative to search_region
|
||||
|
||||
val_min = float(np.min(smoothed))
|
||||
val_max = float(np.max(smoothed))
|
||||
shadow_range = val_max - val_min
|
||||
|
||||
# --- Check 1: Strong brightness gradient ---
|
||||
if shadow_range <= 40:
|
||||
logger.debug(
|
||||
"%s edge: no spine (range=%.0f <= 40)", side.capitalize(), shadow_range,
|
||||
)
|
||||
return None
|
||||
|
||||
# --- Check 2: Darkest point must be genuinely dark ---
|
||||
if val_min > 180:
|
||||
logger.debug(
|
||||
"%s edge: no spine (darkest=%.0f > 180, likely text)", side.capitalize(), val_min,
|
||||
)
|
||||
return None
|
||||
|
||||
spine_idx = int(np.argmin(smoothed)) # index in trimmed array
|
||||
spine_local = spine_idx + trim_offset # index in search_region
|
||||
trimmed_len = len(smoothed)
|
||||
|
||||
# --- Check 3: Valley width (spine is narrow, text plateau is wide) ---
|
||||
valley_thresh = val_min + shadow_range * 0.20
|
||||
valley_mask = smoothed < valley_thresh
|
||||
valley_width = int(np.sum(valley_mask))
|
||||
max_valley_frac = 0.50
|
||||
if valley_width > trimmed_len * max_valley_frac:
|
||||
logger.debug(
|
||||
"%s edge: no spine (valley too wide: %d/%d = %.0f%%)",
|
||||
side.capitalize(), valley_width, trimmed_len,
|
||||
100.0 * valley_width / trimmed_len,
|
||||
)
|
||||
return None
|
||||
|
||||
# --- Check 4: Brightness must rise toward page content ---
|
||||
rise_check_w = max(5, trimmed_len // 5)
|
||||
if side == "left":
|
||||
right_start = min(spine_idx + 5, trimmed_len - 1)
|
||||
right_end = min(right_start + rise_check_w, trimmed_len)
|
||||
if right_end > right_start:
|
||||
rise_brightness = float(np.mean(smoothed[right_start:right_end]))
|
||||
rise = rise_brightness - val_min
|
||||
if rise < shadow_range * 0.3:
|
||||
logger.debug(
|
||||
"%s edge: no spine (insufficient rise: %.0f, need %.0f)",
|
||||
side.capitalize(), rise, shadow_range * 0.3,
|
||||
)
|
||||
return None
|
||||
else: # right
|
||||
left_end = max(spine_idx - 5, 0)
|
||||
left_start = max(left_end - rise_check_w, 0)
|
||||
if left_end > left_start:
|
||||
rise_brightness = float(np.mean(smoothed[left_start:left_end]))
|
||||
rise = rise_brightness - val_min
|
||||
if rise < shadow_range * 0.3:
|
||||
logger.debug(
|
||||
"%s edge: no spine (insufficient rise: %.0f, need %.0f)",
|
||||
side.capitalize(), rise, shadow_range * 0.3,
|
||||
)
|
||||
return None
|
||||
|
||||
spine_x = offset_x + spine_local
|
||||
|
||||
logger.info(
|
||||
"%s edge: spine center at x=%d (brightness=%.0f, range=%.0f, valley=%dpx)",
|
||||
side.capitalize(), spine_x, val_min, shadow_range, valley_width,
|
||||
)
|
||||
return spine_x
|
||||
|
||||
|
||||
def _detect_gutter_continuity(
|
||||
gray: np.ndarray,
|
||||
search_region: np.ndarray,
|
||||
offset_x: int,
|
||||
w: int,
|
||||
side: str,
|
||||
) -> Optional[int]:
|
||||
"""Detect gutter shadow via vertical continuity analysis.
|
||||
|
||||
Camera book scans produce a subtle brightness gradient at the gutter
|
||||
that is too faint for scanner-shadow detection (range < 40). However,
|
||||
the gutter shadow has a unique property: it runs **continuously from
|
||||
top to bottom** without interruption.
|
||||
|
||||
Algorithm:
|
||||
1. Divide image into N horizontal strips (~60px each)
|
||||
2. For each column, compute what fraction of strips are darker than
|
||||
the page median (from the center 50% of the full image)
|
||||
3. A "gutter column" has >= 75% of strips darker than page_median - d
|
||||
4. Smooth the dark-fraction profile and find the transition point
|
||||
5. Validate: gutter band must be 0.5%-10% of image width
|
||||
"""
|
||||
region_h, region_w = search_region.shape[:2]
|
||||
if region_w < 20 or region_h < 100:
|
||||
return None
|
||||
|
||||
# --- 1. Divide into horizontal strips ---
|
||||
strip_target_h = 60
|
||||
n_strips = max(10, region_h // strip_target_h)
|
||||
strip_h = region_h // n_strips
|
||||
|
||||
strip_means = np.zeros((n_strips, region_w), dtype=np.float64)
|
||||
for s in range(n_strips):
|
||||
y0 = s * strip_h
|
||||
y1 = min((s + 1) * strip_h, region_h)
|
||||
strip_means[s] = np.mean(search_region[y0:y1, :], axis=0)
|
||||
|
||||
# --- 2. Page median from center 50% of full image ---
|
||||
center_lo = w // 4
|
||||
center_hi = 3 * w // 4
|
||||
page_median = float(np.median(gray[:, center_lo:center_hi]))
|
||||
|
||||
dark_thresh = page_median - 5.0
|
||||
|
||||
if page_median < 180:
|
||||
return None
|
||||
|
||||
# --- 3. Per-column dark fraction ---
|
||||
dark_count = np.sum(strip_means < dark_thresh, axis=0).astype(np.float64)
|
||||
dark_frac = dark_count / n_strips
|
||||
|
||||
# --- 4. Smooth and find transition ---
|
||||
smooth_w = max(5, w // 100)
|
||||
if smooth_w % 2 == 0:
|
||||
smooth_w += 1
|
||||
kernel = np.ones(smooth_w) / smooth_w
|
||||
frac_smooth = np.convolve(dark_frac, kernel, mode="same")
|
||||
|
||||
margin = smooth_w // 2
|
||||
if region_w <= 2 * margin + 10:
|
||||
return None
|
||||
|
||||
transition_thresh = 0.50
|
||||
peak_frac = float(np.max(frac_smooth[margin:region_w - margin]))
|
||||
|
||||
if peak_frac < 0.70:
|
||||
logger.debug(
|
||||
"%s gutter: peak dark fraction %.2f < 0.70", side.capitalize(), peak_frac,
|
||||
)
|
||||
return None
|
||||
|
||||
peak_x = int(np.argmax(frac_smooth[margin:region_w - margin])) + margin
|
||||
gutter_inner = None
|
||||
|
||||
if side == "right":
|
||||
for x in range(peak_x, margin, -1):
|
||||
if frac_smooth[x] < transition_thresh:
|
||||
gutter_inner = x + 1
|
||||
break
|
||||
else:
|
||||
for x in range(peak_x, region_w - margin):
|
||||
if frac_smooth[x] < transition_thresh:
|
||||
gutter_inner = x - 1
|
||||
break
|
||||
|
||||
if gutter_inner is None:
|
||||
return None
|
||||
|
||||
# --- 5. Validate gutter width ---
|
||||
if side == "right":
|
||||
gutter_width = region_w - gutter_inner
|
||||
else:
|
||||
gutter_width = gutter_inner
|
||||
|
||||
min_gutter = max(3, int(w * 0.005))
|
||||
max_gutter = int(w * 0.10)
|
||||
|
||||
if gutter_width < min_gutter:
|
||||
logger.debug(
|
||||
"%s gutter: too narrow (%dpx < %dpx)", side.capitalize(),
|
||||
gutter_width, min_gutter,
|
||||
)
|
||||
return None
|
||||
|
||||
if gutter_width > max_gutter:
|
||||
logger.debug(
|
||||
"%s gutter: too wide (%dpx > %dpx)", side.capitalize(),
|
||||
gutter_width, max_gutter,
|
||||
)
|
||||
return None
|
||||
|
||||
if side == "right":
|
||||
gutter_brightness = float(np.mean(strip_means[:, gutter_inner:]))
|
||||
else:
|
||||
gutter_brightness = float(np.mean(strip_means[:, :gutter_inner]))
|
||||
|
||||
brightness_drop = page_median - gutter_brightness
|
||||
if brightness_drop < 3:
|
||||
logger.debug(
|
||||
"%s gutter: insufficient brightness drop (%.1f levels)",
|
||||
side.capitalize(), brightness_drop,
|
||||
)
|
||||
return None
|
||||
|
||||
gutter_x = offset_x + gutter_inner
|
||||
|
||||
logger.info(
|
||||
"%s gutter (continuity): x=%d, width=%dpx (%.1f%%), "
|
||||
"brightness=%.0f vs page=%.0f (drop=%.0f), frac@edge=%.2f",
|
||||
side.capitalize(), gutter_x, gutter_width,
|
||||
100.0 * gutter_width / w, gutter_brightness, page_median,
|
||||
brightness_drop, float(frac_smooth[gutter_inner]),
|
||||
)
|
||||
return gutter_x
|
||||
|
||||
|
||||
def _detect_left_edge_shadow(
|
||||
gray: np.ndarray,
|
||||
binary: np.ndarray,
|
||||
w: int,
|
||||
h: int,
|
||||
) -> int:
|
||||
"""Detect left content edge, accounting for book-spine shadow.
|
||||
|
||||
Tries three methods in order:
|
||||
1. Scanner spine-shadow (dark gradient, range > 40)
|
||||
2. Camera gutter continuity (subtle shadow running top-to-bottom)
|
||||
3. Binary projection fallback (first ink column)
|
||||
"""
|
||||
search_w = max(1, w // 4)
|
||||
spine_x = _detect_spine_shadow(gray, gray[:, :search_w], 0, w, "left")
|
||||
if spine_x is not None:
|
||||
return spine_x
|
||||
|
||||
gutter_x = _detect_gutter_continuity(gray, gray[:, :search_w], 0, w, "left")
|
||||
if gutter_x is not None:
|
||||
return gutter_x
|
||||
|
||||
return _detect_edge_projection(binary, axis=0, from_start=True, dim=w)
|
||||
|
||||
|
||||
def _detect_right_edge_shadow(
|
||||
gray: np.ndarray,
|
||||
binary: np.ndarray,
|
||||
w: int,
|
||||
h: int,
|
||||
) -> int:
|
||||
"""Detect right content edge, accounting for book-spine shadow.
|
||||
|
||||
Tries three methods in order:
|
||||
1. Scanner spine-shadow (dark gradient, range > 40)
|
||||
2. Camera gutter continuity (subtle shadow running top-to-bottom)
|
||||
3. Binary projection fallback (last ink column)
|
||||
"""
|
||||
search_w = max(1, w // 4)
|
||||
right_start = w - search_w
|
||||
spine_x = _detect_spine_shadow(gray, gray[:, right_start:], right_start, w, "right")
|
||||
if spine_x is not None:
|
||||
return spine_x
|
||||
|
||||
gutter_x = _detect_gutter_continuity(gray, gray[:, right_start:], right_start, w, "right")
|
||||
if gutter_x is not None:
|
||||
return gutter_x
|
||||
|
||||
return _detect_edge_projection(binary, axis=0, from_start=False, dim=w)
|
||||
|
||||
|
||||
def _detect_top_bottom_edges(binary: np.ndarray, w: int, h: int) -> Tuple[int, int]:
|
||||
"""Detect top and bottom content edges via binary horizontal projection."""
|
||||
top = _detect_edge_projection(binary, axis=1, from_start=True, dim=h)
|
||||
bottom = _detect_edge_projection(binary, axis=1, from_start=False, dim=h)
|
||||
return top, bottom
|
||||
|
||||
|
||||
def _detect_edge_projection(
|
||||
binary: np.ndarray,
|
||||
axis: int,
|
||||
from_start: bool,
|
||||
dim: int,
|
||||
) -> int:
|
||||
"""Find the first/last row or column with ink density above threshold.
|
||||
|
||||
axis=0 -> project vertically (column densities) -> returns x position
|
||||
axis=1 -> project horizontally (row densities) -> returns y position
|
||||
|
||||
Filters out narrow noise runs shorter than _MIN_RUN_FRAC of the dimension.
|
||||
"""
|
||||
projection = np.mean(binary, axis=axis) / 255.0
|
||||
|
||||
ink_mask = projection >= _INK_THRESHOLD
|
||||
|
||||
min_run = max(1, int(dim * _MIN_RUN_FRAC))
|
||||
ink_mask = _filter_narrow_runs(ink_mask, min_run)
|
||||
|
||||
ink_positions = np.where(ink_mask)[0]
|
||||
if len(ink_positions) == 0:
|
||||
return 0 if from_start else dim
|
||||
|
||||
if from_start:
|
||||
return int(ink_positions[0])
|
||||
else:
|
||||
return int(ink_positions[-1])
|
||||
|
||||
|
||||
def _filter_narrow_runs(mask: np.ndarray, min_run: int) -> np.ndarray:
|
||||
"""Remove True-runs shorter than min_run pixels."""
|
||||
if min_run <= 1:
|
||||
return mask
|
||||
|
||||
result = mask.copy()
|
||||
n = len(result)
|
||||
i = 0
|
||||
while i < n:
|
||||
if result[i]:
|
||||
start = i
|
||||
while i < n and result[i]:
|
||||
i += 1
|
||||
if i - start < min_run:
|
||||
result[start:i] = False
|
||||
else:
|
||||
i += 1
|
||||
return result
|
||||
@@ -0,0 +1,160 @@
|
||||
"""
|
||||
TrOCR Batch Processing & Streaming
|
||||
|
||||
Batch OCR and SSE streaming for multiple images.
|
||||
|
||||
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from typing import Optional, List, Dict, Any
|
||||
|
||||
from .trocr_models import OCRResult, BatchOCRResult
|
||||
from .trocr_ocr import run_trocr_ocr_enhanced
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def run_trocr_batch(
|
||||
images: List[bytes],
|
||||
handwritten: bool = True,
|
||||
split_lines: bool = True,
|
||||
use_cache: bool = True,
|
||||
progress_callback: Optional[callable] = None
|
||||
) -> BatchOCRResult:
|
||||
"""
|
||||
Process multiple images in batch.
|
||||
|
||||
Args:
|
||||
images: List of image data bytes
|
||||
handwritten: Use handwritten model
|
||||
split_lines: Whether to split images into lines
|
||||
use_cache: Whether to use caching
|
||||
progress_callback: Optional callback(current, total) for progress updates
|
||||
|
||||
Returns:
|
||||
BatchOCRResult with all results
|
||||
"""
|
||||
start_time = time.time()
|
||||
results = []
|
||||
cached_count = 0
|
||||
error_count = 0
|
||||
|
||||
for idx, image_data in enumerate(images):
|
||||
try:
|
||||
result = await run_trocr_ocr_enhanced(
|
||||
image_data,
|
||||
handwritten=handwritten,
|
||||
split_lines=split_lines,
|
||||
use_cache=use_cache
|
||||
)
|
||||
results.append(result)
|
||||
|
||||
if result.from_cache:
|
||||
cached_count += 1
|
||||
|
||||
# Report progress
|
||||
if progress_callback:
|
||||
progress_callback(idx + 1, len(images))
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Batch OCR error for image {idx}: {e}")
|
||||
error_count += 1
|
||||
results.append(OCRResult(
|
||||
text=f"Error: {str(e)}",
|
||||
confidence=0.0,
|
||||
processing_time_ms=0,
|
||||
model="error",
|
||||
has_lora_adapter=False
|
||||
))
|
||||
|
||||
total_time_ms = int((time.time() - start_time) * 1000)
|
||||
|
||||
return BatchOCRResult(
|
||||
results=results,
|
||||
total_time_ms=total_time_ms,
|
||||
processed_count=len(images),
|
||||
cached_count=cached_count,
|
||||
error_count=error_count
|
||||
)
|
||||
|
||||
|
||||
# Generator for SSE streaming during batch processing
|
||||
async def run_trocr_batch_stream(
|
||||
images: List[bytes],
|
||||
handwritten: bool = True,
|
||||
split_lines: bool = True,
|
||||
use_cache: bool = True
|
||||
):
|
||||
"""
|
||||
Process images and yield progress updates for SSE streaming.
|
||||
|
||||
Yields:
|
||||
dict with current progress and result
|
||||
"""
|
||||
start_time = time.time()
|
||||
total = len(images)
|
||||
|
||||
for idx, image_data in enumerate(images):
|
||||
try:
|
||||
result = await run_trocr_ocr_enhanced(
|
||||
image_data,
|
||||
handwritten=handwritten,
|
||||
split_lines=split_lines,
|
||||
use_cache=use_cache
|
||||
)
|
||||
|
||||
elapsed_ms = int((time.time() - start_time) * 1000)
|
||||
avg_time_per_image = elapsed_ms / (idx + 1)
|
||||
estimated_remaining = int(avg_time_per_image * (total - idx - 1))
|
||||
|
||||
yield {
|
||||
"type": "progress",
|
||||
"current": idx + 1,
|
||||
"total": total,
|
||||
"progress_percent": ((idx + 1) / total) * 100,
|
||||
"elapsed_ms": elapsed_ms,
|
||||
"estimated_remaining_ms": estimated_remaining,
|
||||
"result": {
|
||||
"text": result.text,
|
||||
"confidence": result.confidence,
|
||||
"processing_time_ms": result.processing_time_ms,
|
||||
"from_cache": result.from_cache
|
||||
}
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Stream OCR error for image {idx}: {e}")
|
||||
yield {
|
||||
"type": "error",
|
||||
"current": idx + 1,
|
||||
"total": total,
|
||||
"error": str(e)
|
||||
}
|
||||
|
||||
total_time_ms = int((time.time() - start_time) * 1000)
|
||||
yield {
|
||||
"type": "complete",
|
||||
"total_time_ms": total_time_ms,
|
||||
"processed_count": total
|
||||
}
|
||||
|
||||
|
||||
# Test function
|
||||
async def test_trocr_ocr(image_path: str, handwritten: bool = False):
|
||||
"""Test TrOCR on a local image file."""
|
||||
from .trocr_ocr import run_trocr_ocr
|
||||
|
||||
with open(image_path, "rb") as f:
|
||||
image_data = f.read()
|
||||
|
||||
text, confidence = await run_trocr_ocr(image_data, handwritten=handwritten)
|
||||
|
||||
print(f"\n=== TrOCR Test ===")
|
||||
print(f"Mode: {'Handwritten' if handwritten else 'Printed'}")
|
||||
print(f"Confidence: {confidence:.2f}")
|
||||
print(f"Text:\n{text}")
|
||||
|
||||
return text, confidence
|
||||
@@ -0,0 +1,278 @@
|
||||
"""
|
||||
TrOCR Models & Cache
|
||||
|
||||
Dataclasses, LRU cache, and model loading for TrOCR service.
|
||||
|
||||
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
|
||||
"""
|
||||
|
||||
import io
|
||||
import os
|
||||
import hashlib
|
||||
import logging
|
||||
import time
|
||||
from typing import Tuple, Optional, List, Dict, Any
|
||||
from dataclasses import dataclass, field
|
||||
from collections import OrderedDict
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Backend routing: auto | pytorch | onnx
|
||||
# ---------------------------------------------------------------------------
|
||||
_trocr_backend = os.environ.get("TROCR_BACKEND", "auto") # auto | pytorch | onnx
|
||||
|
||||
# Lazy loading for heavy dependencies
|
||||
# Cache keyed by model_name to support base and large variants simultaneously
|
||||
_trocr_models: dict = {} # {model_name: (processor, model)}
|
||||
_trocr_processor = None # backwards-compat alias -> base-printed
|
||||
_trocr_model = None # backwards-compat alias -> base-printed
|
||||
_trocr_available = None
|
||||
_model_loaded_at = None
|
||||
|
||||
# Simple in-memory cache with LRU eviction
|
||||
_ocr_cache: OrderedDict[str, Dict[str, Any]] = OrderedDict()
|
||||
_cache_max_size = 100
|
||||
_cache_ttl_seconds = 3600 # 1 hour
|
||||
|
||||
|
||||
@dataclass
|
||||
class OCRResult:
|
||||
"""Enhanced OCR result with detailed information."""
|
||||
text: str
|
||||
confidence: float
|
||||
processing_time_ms: int
|
||||
model: str
|
||||
has_lora_adapter: bool = False
|
||||
char_confidences: List[float] = field(default_factory=list)
|
||||
word_boxes: List[Dict[str, Any]] = field(default_factory=list)
|
||||
from_cache: bool = False
|
||||
image_hash: str = ""
|
||||
|
||||
|
||||
@dataclass
|
||||
class BatchOCRResult:
|
||||
"""Result for batch processing."""
|
||||
results: List[OCRResult]
|
||||
total_time_ms: int
|
||||
processed_count: int
|
||||
cached_count: int
|
||||
error_count: int
|
||||
|
||||
|
||||
def _compute_image_hash(image_data: bytes) -> str:
|
||||
"""Compute SHA256 hash of image data for caching."""
|
||||
return hashlib.sha256(image_data).hexdigest()[:16]
|
||||
|
||||
|
||||
def _cache_get(image_hash: str) -> Optional[Dict[str, Any]]:
|
||||
"""Get cached OCR result if available and not expired."""
|
||||
if image_hash in _ocr_cache:
|
||||
entry = _ocr_cache[image_hash]
|
||||
if datetime.now() - entry["cached_at"] < timedelta(seconds=_cache_ttl_seconds):
|
||||
# Move to end (LRU)
|
||||
_ocr_cache.move_to_end(image_hash)
|
||||
return entry["result"]
|
||||
else:
|
||||
# Expired, remove
|
||||
del _ocr_cache[image_hash]
|
||||
return None
|
||||
|
||||
|
||||
def _cache_set(image_hash: str, result: Dict[str, Any]) -> None:
|
||||
"""Store OCR result in cache."""
|
||||
# Evict oldest if at capacity
|
||||
while len(_ocr_cache) >= _cache_max_size:
|
||||
_ocr_cache.popitem(last=False)
|
||||
|
||||
_ocr_cache[image_hash] = {
|
||||
"result": result,
|
||||
"cached_at": datetime.now()
|
||||
}
|
||||
|
||||
|
||||
def get_cache_stats() -> Dict[str, Any]:
|
||||
"""Get cache statistics."""
|
||||
return {
|
||||
"size": len(_ocr_cache),
|
||||
"max_size": _cache_max_size,
|
||||
"ttl_seconds": _cache_ttl_seconds,
|
||||
"hit_rate": 0 # Could track this with additional counters
|
||||
}
|
||||
|
||||
|
||||
def _check_trocr_available() -> bool:
|
||||
"""Check if TrOCR dependencies are available."""
|
||||
global _trocr_available
|
||||
if _trocr_available is not None:
|
||||
return _trocr_available
|
||||
|
||||
try:
|
||||
import torch
|
||||
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
|
||||
_trocr_available = True
|
||||
except ImportError as e:
|
||||
logger.warning(f"TrOCR dependencies not available: {e}")
|
||||
_trocr_available = False
|
||||
|
||||
return _trocr_available
|
||||
|
||||
|
||||
def get_trocr_model(handwritten: bool = False, size: str = "base"):
|
||||
"""
|
||||
Lazy load TrOCR model and processor.
|
||||
|
||||
Args:
|
||||
handwritten: Use handwritten model instead of printed model
|
||||
size: Model size -- "base" (300 MB) or "large" (340 MB, higher accuracy
|
||||
for exam HTR). Only applies to handwritten variant.
|
||||
|
||||
Returns tuple of (processor, model) or (None, None) if unavailable.
|
||||
"""
|
||||
global _trocr_processor, _trocr_model
|
||||
|
||||
if not _check_trocr_available():
|
||||
return None, None
|
||||
|
||||
# Select model name
|
||||
if size == "large" and handwritten:
|
||||
model_name = "microsoft/trocr-large-handwritten"
|
||||
elif handwritten:
|
||||
model_name = "microsoft/trocr-base-handwritten"
|
||||
else:
|
||||
model_name = "microsoft/trocr-base-printed"
|
||||
|
||||
if model_name in _trocr_models:
|
||||
return _trocr_models[model_name]
|
||||
|
||||
try:
|
||||
import torch
|
||||
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
|
||||
|
||||
logger.info(f"Loading TrOCR model: {model_name}")
|
||||
processor = TrOCRProcessor.from_pretrained(model_name)
|
||||
model = VisionEncoderDecoderModel.from_pretrained(model_name)
|
||||
|
||||
# Use GPU if available
|
||||
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
|
||||
model.to(device)
|
||||
logger.info(f"TrOCR model loaded on device: {device}")
|
||||
|
||||
_trocr_models[model_name] = (processor, model)
|
||||
|
||||
# Keep backwards-compat globals pointing at base-printed
|
||||
if model_name == "microsoft/trocr-base-printed":
|
||||
_trocr_processor = processor
|
||||
_trocr_model = model
|
||||
|
||||
return processor, model
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load TrOCR model {model_name}: {e}")
|
||||
return None, None
|
||||
|
||||
|
||||
def preload_trocr_model(handwritten: bool = True) -> bool:
|
||||
"""
|
||||
Preload TrOCR model at startup for faster first request.
|
||||
|
||||
Call this from your FastAPI startup event:
|
||||
@app.on_event("startup")
|
||||
async def startup():
|
||||
preload_trocr_model()
|
||||
"""
|
||||
global _model_loaded_at
|
||||
logger.info("Preloading TrOCR model...")
|
||||
processor, model = get_trocr_model(handwritten=handwritten)
|
||||
if processor is not None and model is not None:
|
||||
_model_loaded_at = datetime.now()
|
||||
logger.info("TrOCR model preloaded successfully")
|
||||
return True
|
||||
else:
|
||||
logger.warning("TrOCR model preloading failed")
|
||||
return False
|
||||
|
||||
|
||||
def get_model_status() -> Dict[str, Any]:
|
||||
"""Get current model status information."""
|
||||
processor, model = get_trocr_model(handwritten=True)
|
||||
is_loaded = processor is not None and model is not None
|
||||
|
||||
status = {
|
||||
"status": "available" if is_loaded else "not_installed",
|
||||
"is_loaded": is_loaded,
|
||||
"model_name": "trocr-base-handwritten" if is_loaded else None,
|
||||
"loaded_at": _model_loaded_at.isoformat() if _model_loaded_at else None,
|
||||
}
|
||||
|
||||
if is_loaded:
|
||||
import torch
|
||||
device = next(model.parameters()).device
|
||||
status["device"] = str(device)
|
||||
|
||||
return status
|
||||
|
||||
|
||||
def get_active_backend() -> str:
|
||||
"""
|
||||
Return which TrOCR backend is configured.
|
||||
|
||||
Possible values: "auto", "pytorch", "onnx".
|
||||
"""
|
||||
return _trocr_backend
|
||||
|
||||
|
||||
def _split_into_lines(image) -> list:
|
||||
"""
|
||||
Split an image into text lines using simple projection-based segmentation.
|
||||
|
||||
This is a basic implementation - for production use, consider using
|
||||
a dedicated line detection model.
|
||||
"""
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
try:
|
||||
# Convert to grayscale
|
||||
gray = image.convert('L')
|
||||
img_array = np.array(gray)
|
||||
|
||||
# Binarize (simple threshold)
|
||||
threshold = 200
|
||||
binary = img_array < threshold
|
||||
|
||||
# Horizontal projection (sum of dark pixels per row)
|
||||
h_proj = np.sum(binary, axis=1)
|
||||
|
||||
# Find line boundaries (where projection drops below threshold)
|
||||
line_threshold = img_array.shape[1] * 0.02 # 2% of width
|
||||
in_line = False
|
||||
line_start = 0
|
||||
lines = []
|
||||
|
||||
for i, val in enumerate(h_proj):
|
||||
if val > line_threshold and not in_line:
|
||||
# Start of line
|
||||
in_line = True
|
||||
line_start = i
|
||||
elif val <= line_threshold and in_line:
|
||||
# End of line
|
||||
in_line = False
|
||||
# Add padding
|
||||
start = max(0, line_start - 5)
|
||||
end = min(img_array.shape[0], i + 5)
|
||||
if end - start > 10: # Minimum line height
|
||||
lines.append(image.crop((0, start, image.width, end)))
|
||||
|
||||
# Handle last line if still in_line
|
||||
if in_line:
|
||||
start = max(0, line_start - 5)
|
||||
lines.append(image.crop((0, start, image.width, image.height)))
|
||||
|
||||
logger.info(f"Split image into {len(lines)} lines")
|
||||
return lines
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Line splitting failed: {e}")
|
||||
return []
|
||||
@@ -0,0 +1,309 @@
|
||||
"""
|
||||
TrOCR OCR Execution
|
||||
|
||||
Core OCR inference routines (PyTorch, ONNX routing, enhanced mode).
|
||||
|
||||
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
|
||||
"""
|
||||
|
||||
import io
|
||||
import logging
|
||||
import time
|
||||
from typing import Tuple, Optional, List, Dict, Any
|
||||
|
||||
from .trocr_models import (
|
||||
OCRResult,
|
||||
_trocr_backend,
|
||||
_compute_image_hash,
|
||||
_cache_get,
|
||||
_cache_set,
|
||||
get_trocr_model,
|
||||
_split_into_lines,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _try_onnx_ocr(
|
||||
image_data: bytes,
|
||||
handwritten: bool = False,
|
||||
split_lines: bool = True,
|
||||
) -> Optional[Tuple[Optional[str], float]]:
|
||||
"""
|
||||
Attempt ONNX inference. Returns the (text, confidence) tuple on
|
||||
success, or None if ONNX is not available / fails to load.
|
||||
"""
|
||||
try:
|
||||
from .trocr_onnx_service import is_onnx_available, run_trocr_onnx
|
||||
|
||||
if not is_onnx_available(handwritten=handwritten):
|
||||
return None
|
||||
# run_trocr_onnx is async -- return the coroutine's awaitable result
|
||||
# The caller (run_trocr_ocr) will await it.
|
||||
return run_trocr_onnx # sentinel: caller checks callable
|
||||
except ImportError:
|
||||
return None
|
||||
|
||||
|
||||
async def _run_pytorch_ocr(
|
||||
image_data: bytes,
|
||||
handwritten: bool = False,
|
||||
split_lines: bool = True,
|
||||
size: str = "base",
|
||||
) -> Tuple[Optional[str], float]:
|
||||
"""
|
||||
Original PyTorch inference path (extracted for routing).
|
||||
"""
|
||||
processor, model = get_trocr_model(handwritten=handwritten, size=size)
|
||||
|
||||
if processor is None or model is None:
|
||||
logger.error("TrOCR PyTorch model not available")
|
||||
return None, 0.0
|
||||
|
||||
try:
|
||||
import torch
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
|
||||
# Load image
|
||||
image = Image.open(io.BytesIO(image_data)).convert("RGB")
|
||||
|
||||
if split_lines:
|
||||
lines = _split_into_lines(image)
|
||||
if not lines:
|
||||
lines = [image]
|
||||
else:
|
||||
lines = [image]
|
||||
|
||||
all_text = []
|
||||
confidences = []
|
||||
|
||||
for line_image in lines:
|
||||
pixel_values = processor(images=line_image, return_tensors="pt").pixel_values
|
||||
|
||||
device = next(model.parameters()).device
|
||||
pixel_values = pixel_values.to(device)
|
||||
|
||||
with torch.no_grad():
|
||||
generated_ids = model.generate(pixel_values, max_length=128)
|
||||
|
||||
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
||||
|
||||
if generated_text.strip():
|
||||
all_text.append(generated_text.strip())
|
||||
confidences.append(0.85 if len(generated_text) > 3 else 0.5)
|
||||
|
||||
text = "\n".join(all_text)
|
||||
confidence = sum(confidences) / len(confidences) if confidences else 0.0
|
||||
|
||||
logger.info(f"TrOCR (PyTorch) extracted {len(text)} characters from {len(lines)} lines")
|
||||
return text, confidence
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"TrOCR PyTorch failed: {e}")
|
||||
import traceback
|
||||
logger.error(traceback.format_exc())
|
||||
return None, 0.0
|
||||
|
||||
|
||||
async def run_trocr_ocr(
|
||||
image_data: bytes,
|
||||
handwritten: bool = False,
|
||||
split_lines: bool = True,
|
||||
size: str = "base",
|
||||
) -> Tuple[Optional[str], float]:
|
||||
"""
|
||||
Run TrOCR on an image.
|
||||
|
||||
Routes between ONNX and PyTorch backends based on the TROCR_BACKEND
|
||||
environment variable (default: "auto").
|
||||
|
||||
- "onnx" -- always use ONNX (raises RuntimeError if unavailable)
|
||||
- "pytorch" -- always use PyTorch (original behaviour)
|
||||
- "auto" -- try ONNX first, fall back to PyTorch
|
||||
|
||||
TrOCR is optimized for single-line text recognition, so for full-page
|
||||
images we need to either:
|
||||
1. Split into lines first (using line detection)
|
||||
2. Process the whole image and get partial results
|
||||
|
||||
Args:
|
||||
image_data: Raw image bytes
|
||||
handwritten: Use handwritten model (slower but better for handwriting)
|
||||
split_lines: Whether to split image into lines first
|
||||
size: "base" or "large" (only for handwritten variant)
|
||||
|
||||
Returns:
|
||||
Tuple of (extracted_text, confidence)
|
||||
"""
|
||||
backend = _trocr_backend
|
||||
|
||||
# --- ONNX-only mode ---
|
||||
if backend == "onnx":
|
||||
onnx_fn = _try_onnx_ocr(image_data, handwritten=handwritten, split_lines=split_lines)
|
||||
if onnx_fn is None or not callable(onnx_fn):
|
||||
raise RuntimeError(
|
||||
"ONNX backend requested (TROCR_BACKEND=onnx) but unavailable. "
|
||||
"Ensure onnxruntime + optimum are installed and ONNX model files exist."
|
||||
)
|
||||
return await onnx_fn(image_data, handwritten=handwritten, split_lines=split_lines)
|
||||
|
||||
# --- PyTorch-only mode ---
|
||||
if backend == "pytorch":
|
||||
return await _run_pytorch_ocr(
|
||||
image_data, handwritten=handwritten, split_lines=split_lines, size=size,
|
||||
)
|
||||
|
||||
# --- Auto mode: try ONNX first, then PyTorch ---
|
||||
onnx_fn = _try_onnx_ocr(image_data, handwritten=handwritten, split_lines=split_lines)
|
||||
if onnx_fn is not None and callable(onnx_fn):
|
||||
try:
|
||||
result = await onnx_fn(image_data, handwritten=handwritten, split_lines=split_lines)
|
||||
if result[0] is not None:
|
||||
return result
|
||||
logger.warning("ONNX returned None text, falling back to PyTorch")
|
||||
except Exception as e:
|
||||
logger.warning(f"ONNX inference failed ({e}), falling back to PyTorch")
|
||||
|
||||
return await _run_pytorch_ocr(
|
||||
image_data, handwritten=handwritten, split_lines=split_lines, size=size,
|
||||
)
|
||||
|
||||
|
||||
def _try_onnx_enhanced(
|
||||
handwritten: bool = True,
|
||||
):
|
||||
"""
|
||||
Return the ONNX enhanced coroutine function, or None if unavailable.
|
||||
"""
|
||||
try:
|
||||
from .trocr_onnx_service import is_onnx_available, run_trocr_onnx_enhanced
|
||||
|
||||
if not is_onnx_available(handwritten=handwritten):
|
||||
return None
|
||||
return run_trocr_onnx_enhanced
|
||||
except ImportError:
|
||||
return None
|
||||
|
||||
|
||||
async def run_trocr_ocr_enhanced(
|
||||
image_data: bytes,
|
||||
handwritten: bool = True,
|
||||
split_lines: bool = True,
|
||||
use_cache: bool = True
|
||||
) -> OCRResult:
|
||||
"""
|
||||
Enhanced TrOCR OCR with caching and detailed results.
|
||||
|
||||
Routes between ONNX and PyTorch backends based on the TROCR_BACKEND
|
||||
environment variable (default: "auto").
|
||||
|
||||
Args:
|
||||
image_data: Raw image bytes
|
||||
handwritten: Use handwritten model
|
||||
split_lines: Whether to split image into lines first
|
||||
use_cache: Whether to use caching
|
||||
|
||||
Returns:
|
||||
OCRResult with detailed information
|
||||
"""
|
||||
backend = _trocr_backend
|
||||
|
||||
# --- ONNX-only mode ---
|
||||
if backend == "onnx":
|
||||
onnx_fn = _try_onnx_enhanced(handwritten=handwritten)
|
||||
if onnx_fn is None:
|
||||
raise RuntimeError(
|
||||
"ONNX backend requested (TROCR_BACKEND=onnx) but unavailable. "
|
||||
"Ensure onnxruntime + optimum are installed and ONNX model files exist."
|
||||
)
|
||||
return await onnx_fn(
|
||||
image_data, handwritten=handwritten,
|
||||
split_lines=split_lines, use_cache=use_cache,
|
||||
)
|
||||
|
||||
# --- Auto mode: try ONNX first ---
|
||||
if backend == "auto":
|
||||
onnx_fn = _try_onnx_enhanced(handwritten=handwritten)
|
||||
if onnx_fn is not None:
|
||||
try:
|
||||
result = await onnx_fn(
|
||||
image_data, handwritten=handwritten,
|
||||
split_lines=split_lines, use_cache=use_cache,
|
||||
)
|
||||
if result.text:
|
||||
return result
|
||||
logger.warning("ONNX enhanced returned empty text, falling back to PyTorch")
|
||||
except Exception as e:
|
||||
logger.warning(f"ONNX enhanced failed ({e}), falling back to PyTorch")
|
||||
|
||||
# --- PyTorch path (backend == "pytorch" or auto fallback) ---
|
||||
start_time = time.time()
|
||||
|
||||
# Check cache first
|
||||
image_hash = _compute_image_hash(image_data)
|
||||
if use_cache:
|
||||
cached = _cache_get(image_hash)
|
||||
if cached:
|
||||
return OCRResult(
|
||||
text=cached["text"],
|
||||
confidence=cached["confidence"],
|
||||
processing_time_ms=0,
|
||||
model=cached["model"],
|
||||
has_lora_adapter=cached.get("has_lora_adapter", False),
|
||||
char_confidences=cached.get("char_confidences", []),
|
||||
word_boxes=cached.get("word_boxes", []),
|
||||
from_cache=True,
|
||||
image_hash=image_hash
|
||||
)
|
||||
|
||||
# Run OCR via PyTorch
|
||||
text, confidence = await _run_pytorch_ocr(image_data, handwritten=handwritten, split_lines=split_lines)
|
||||
|
||||
processing_time_ms = int((time.time() - start_time) * 1000)
|
||||
|
||||
# Generate word boxes with simulated confidences
|
||||
word_boxes = []
|
||||
if text:
|
||||
words = text.split()
|
||||
for idx, word in enumerate(words):
|
||||
# Simulate word confidence (slightly varied around overall confidence)
|
||||
word_conf = min(1.0, max(0.0, confidence + (hash(word) % 20 - 10) / 100))
|
||||
word_boxes.append({
|
||||
"text": word,
|
||||
"confidence": word_conf,
|
||||
"bbox": [0, 0, 0, 0] # Would need actual bounding box detection
|
||||
})
|
||||
|
||||
# Generate character confidences
|
||||
char_confidences = []
|
||||
if text:
|
||||
for char in text:
|
||||
# Simulate per-character confidence
|
||||
char_conf = min(1.0, max(0.0, confidence + (hash(char) % 15 - 7) / 100))
|
||||
char_confidences.append(char_conf)
|
||||
|
||||
result = OCRResult(
|
||||
text=text or "",
|
||||
confidence=confidence,
|
||||
processing_time_ms=processing_time_ms,
|
||||
model="trocr-base-handwritten" if handwritten else "trocr-base-printed",
|
||||
has_lora_adapter=False, # Would check actual adapter status
|
||||
char_confidences=char_confidences,
|
||||
word_boxes=word_boxes,
|
||||
from_cache=False,
|
||||
image_hash=image_hash
|
||||
)
|
||||
|
||||
# Cache result
|
||||
if use_cache and text:
|
||||
_cache_set(image_hash, {
|
||||
"text": result.text,
|
||||
"confidence": result.confidence,
|
||||
"model": result.model,
|
||||
"has_lora_adapter": result.has_lora_adapter,
|
||||
"char_confidences": result.char_confidences,
|
||||
"word_boxes": result.word_boxes
|
||||
})
|
||||
|
||||
return result
|
||||
@@ -1,720 +1,70 @@
|
||||
"""
|
||||
TrOCR Service
|
||||
TrOCR Service — Barrel Re-export
|
||||
|
||||
Microsoft's Transformer-based OCR for text recognition.
|
||||
Besonders geeignet fuer:
|
||||
- Gedruckten Text
|
||||
- Saubere Scans
|
||||
- Schnelle Verarbeitung
|
||||
|
||||
Model: microsoft/trocr-base-printed (oder handwritten Variante)
|
||||
Split into submodules:
|
||||
- trocr_models.py — Dataclasses, cache, model loading, line splitting
|
||||
- trocr_ocr.py — Core OCR inference (PyTorch/ONNX routing, enhanced)
|
||||
- trocr_batch.py — Batch processing and SSE streaming
|
||||
|
||||
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
|
||||
|
||||
Phase 2 Enhancements:
|
||||
- Batch processing for multiple images
|
||||
- SHA256-based caching for repeated requests
|
||||
- Model preloading for faster first request
|
||||
- Word-level bounding boxes with confidence
|
||||
"""
|
||||
|
||||
import io
|
||||
import os
|
||||
import hashlib
|
||||
import logging
|
||||
import time
|
||||
import asyncio
|
||||
from typing import Tuple, Optional, List, Dict, Any
|
||||
from dataclasses import dataclass, field
|
||||
from collections import OrderedDict
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Backend routing: auto | pytorch | onnx
|
||||
# ---------------------------------------------------------------------------
|
||||
_trocr_backend = os.environ.get("TROCR_BACKEND", "auto") # auto | pytorch | onnx
|
||||
|
||||
# Lazy loading for heavy dependencies
|
||||
# Cache keyed by model_name to support base and large variants simultaneously
|
||||
_trocr_models: dict = {} # {model_name: (processor, model)}
|
||||
_trocr_processor = None # backwards-compat alias → base-printed
|
||||
_trocr_model = None # backwards-compat alias → base-printed
|
||||
_trocr_available = None
|
||||
_model_loaded_at = None
|
||||
|
||||
# Simple in-memory cache with LRU eviction
|
||||
_ocr_cache: OrderedDict[str, Dict[str, Any]] = OrderedDict()
|
||||
_cache_max_size = 100
|
||||
_cache_ttl_seconds = 3600 # 1 hour
|
||||
|
||||
|
||||
@dataclass
|
||||
class OCRResult:
|
||||
"""Enhanced OCR result with detailed information."""
|
||||
text: str
|
||||
confidence: float
|
||||
processing_time_ms: int
|
||||
model: str
|
||||
has_lora_adapter: bool = False
|
||||
char_confidences: List[float] = field(default_factory=list)
|
||||
word_boxes: List[Dict[str, Any]] = field(default_factory=list)
|
||||
from_cache: bool = False
|
||||
image_hash: str = ""
|
||||
|
||||
|
||||
@dataclass
|
||||
class BatchOCRResult:
|
||||
"""Result for batch processing."""
|
||||
results: List[OCRResult]
|
||||
total_time_ms: int
|
||||
processed_count: int
|
||||
cached_count: int
|
||||
error_count: int
|
||||
|
||||
|
||||
def _compute_image_hash(image_data: bytes) -> str:
|
||||
"""Compute SHA256 hash of image data for caching."""
|
||||
return hashlib.sha256(image_data).hexdigest()[:16]
|
||||
|
||||
|
||||
def _cache_get(image_hash: str) -> Optional[Dict[str, Any]]:
|
||||
"""Get cached OCR result if available and not expired."""
|
||||
if image_hash in _ocr_cache:
|
||||
entry = _ocr_cache[image_hash]
|
||||
if datetime.now() - entry["cached_at"] < timedelta(seconds=_cache_ttl_seconds):
|
||||
# Move to end (LRU)
|
||||
_ocr_cache.move_to_end(image_hash)
|
||||
return entry["result"]
|
||||
else:
|
||||
# Expired, remove
|
||||
del _ocr_cache[image_hash]
|
||||
return None
|
||||
|
||||
|
||||
def _cache_set(image_hash: str, result: Dict[str, Any]) -> None:
|
||||
"""Store OCR result in cache."""
|
||||
# Evict oldest if at capacity
|
||||
while len(_ocr_cache) >= _cache_max_size:
|
||||
_ocr_cache.popitem(last=False)
|
||||
|
||||
_ocr_cache[image_hash] = {
|
||||
"result": result,
|
||||
"cached_at": datetime.now()
|
||||
}
|
||||
|
||||
|
||||
def get_cache_stats() -> Dict[str, Any]:
|
||||
"""Get cache statistics."""
|
||||
return {
|
||||
"size": len(_ocr_cache),
|
||||
"max_size": _cache_max_size,
|
||||
"ttl_seconds": _cache_ttl_seconds,
|
||||
"hit_rate": 0 # Could track this with additional counters
|
||||
}
|
||||
|
||||
|
||||
def _check_trocr_available() -> bool:
|
||||
"""Check if TrOCR dependencies are available."""
|
||||
global _trocr_available
|
||||
if _trocr_available is not None:
|
||||
return _trocr_available
|
||||
|
||||
try:
|
||||
import torch
|
||||
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
|
||||
_trocr_available = True
|
||||
except ImportError as e:
|
||||
logger.warning(f"TrOCR dependencies not available: {e}")
|
||||
_trocr_available = False
|
||||
|
||||
return _trocr_available
|
||||
|
||||
|
||||
def get_trocr_model(handwritten: bool = False, size: str = "base"):
|
||||
"""
|
||||
Lazy load TrOCR model and processor.
|
||||
|
||||
Args:
|
||||
handwritten: Use handwritten model instead of printed model
|
||||
size: Model size — "base" (300 MB) or "large" (340 MB, higher accuracy
|
||||
for exam HTR). Only applies to handwritten variant.
|
||||
|
||||
Returns tuple of (processor, model) or (None, None) if unavailable.
|
||||
"""
|
||||
global _trocr_processor, _trocr_model
|
||||
|
||||
if not _check_trocr_available():
|
||||
return None, None
|
||||
|
||||
# Select model name
|
||||
if size == "large" and handwritten:
|
||||
model_name = "microsoft/trocr-large-handwritten"
|
||||
elif handwritten:
|
||||
model_name = "microsoft/trocr-base-handwritten"
|
||||
else:
|
||||
model_name = "microsoft/trocr-base-printed"
|
||||
|
||||
if model_name in _trocr_models:
|
||||
return _trocr_models[model_name]
|
||||
|
||||
try:
|
||||
import torch
|
||||
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
|
||||
|
||||
logger.info(f"Loading TrOCR model: {model_name}")
|
||||
processor = TrOCRProcessor.from_pretrained(model_name)
|
||||
model = VisionEncoderDecoderModel.from_pretrained(model_name)
|
||||
|
||||
# Use GPU if available
|
||||
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
|
||||
model.to(device)
|
||||
logger.info(f"TrOCR model loaded on device: {device}")
|
||||
|
||||
_trocr_models[model_name] = (processor, model)
|
||||
|
||||
# Keep backwards-compat globals pointing at base-printed
|
||||
if model_name == "microsoft/trocr-base-printed":
|
||||
_trocr_processor = processor
|
||||
_trocr_model = model
|
||||
|
||||
return processor, model
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load TrOCR model {model_name}: {e}")
|
||||
return None, None
|
||||
|
||||
|
||||
def preload_trocr_model(handwritten: bool = True) -> bool:
|
||||
"""
|
||||
Preload TrOCR model at startup for faster first request.
|
||||
|
||||
Call this from your FastAPI startup event:
|
||||
@app.on_event("startup")
|
||||
async def startup():
|
||||
preload_trocr_model()
|
||||
"""
|
||||
global _model_loaded_at
|
||||
logger.info("Preloading TrOCR model...")
|
||||
processor, model = get_trocr_model(handwritten=handwritten)
|
||||
if processor is not None and model is not None:
|
||||
_model_loaded_at = datetime.now()
|
||||
logger.info("TrOCR model preloaded successfully")
|
||||
return True
|
||||
else:
|
||||
logger.warning("TrOCR model preloading failed")
|
||||
return False
|
||||
|
||||
|
||||
def get_model_status() -> Dict[str, Any]:
|
||||
"""Get current model status information."""
|
||||
processor, model = get_trocr_model(handwritten=True)
|
||||
is_loaded = processor is not None and model is not None
|
||||
|
||||
status = {
|
||||
"status": "available" if is_loaded else "not_installed",
|
||||
"is_loaded": is_loaded,
|
||||
"model_name": "trocr-base-handwritten" if is_loaded else None,
|
||||
"loaded_at": _model_loaded_at.isoformat() if _model_loaded_at else None,
|
||||
}
|
||||
|
||||
if is_loaded:
|
||||
import torch
|
||||
device = next(model.parameters()).device
|
||||
status["device"] = str(device)
|
||||
|
||||
return status
|
||||
|
||||
|
||||
def get_active_backend() -> str:
|
||||
"""
|
||||
Return which TrOCR backend is configured.
|
||||
|
||||
Possible values: "auto", "pytorch", "onnx".
|
||||
"""
|
||||
return _trocr_backend
|
||||
|
||||
|
||||
def _try_onnx_ocr(
|
||||
image_data: bytes,
|
||||
handwritten: bool = False,
|
||||
split_lines: bool = True,
|
||||
) -> Optional[Tuple[Optional[str], float]]:
|
||||
"""
|
||||
Attempt ONNX inference. Returns the (text, confidence) tuple on
|
||||
success, or None if ONNX is not available / fails to load.
|
||||
"""
|
||||
try:
|
||||
from .trocr_onnx_service import is_onnx_available, run_trocr_onnx
|
||||
|
||||
if not is_onnx_available(handwritten=handwritten):
|
||||
return None
|
||||
# run_trocr_onnx is async — return the coroutine's awaitable result
|
||||
# The caller (run_trocr_ocr) will await it.
|
||||
return run_trocr_onnx # sentinel: caller checks callable
|
||||
except ImportError:
|
||||
return None
|
||||
|
||||
|
||||
async def _run_pytorch_ocr(
|
||||
image_data: bytes,
|
||||
handwritten: bool = False,
|
||||
split_lines: bool = True,
|
||||
size: str = "base",
|
||||
) -> Tuple[Optional[str], float]:
|
||||
"""
|
||||
Original PyTorch inference path (extracted for routing).
|
||||
"""
|
||||
processor, model = get_trocr_model(handwritten=handwritten, size=size)
|
||||
|
||||
if processor is None or model is None:
|
||||
logger.error("TrOCR PyTorch model not available")
|
||||
return None, 0.0
|
||||
|
||||
try:
|
||||
import torch
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
|
||||
# Load image
|
||||
image = Image.open(io.BytesIO(image_data)).convert("RGB")
|
||||
|
||||
if split_lines:
|
||||
lines = _split_into_lines(image)
|
||||
if not lines:
|
||||
lines = [image]
|
||||
else:
|
||||
lines = [image]
|
||||
|
||||
all_text = []
|
||||
confidences = []
|
||||
|
||||
for line_image in lines:
|
||||
pixel_values = processor(images=line_image, return_tensors="pt").pixel_values
|
||||
|
||||
device = next(model.parameters()).device
|
||||
pixel_values = pixel_values.to(device)
|
||||
|
||||
with torch.no_grad():
|
||||
generated_ids = model.generate(pixel_values, max_length=128)
|
||||
|
||||
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
||||
|
||||
if generated_text.strip():
|
||||
all_text.append(generated_text.strip())
|
||||
confidences.append(0.85 if len(generated_text) > 3 else 0.5)
|
||||
|
||||
text = "\n".join(all_text)
|
||||
confidence = sum(confidences) / len(confidences) if confidences else 0.0
|
||||
|
||||
logger.info(f"TrOCR (PyTorch) extracted {len(text)} characters from {len(lines)} lines")
|
||||
return text, confidence
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"TrOCR PyTorch failed: {e}")
|
||||
import traceback
|
||||
logger.error(traceback.format_exc())
|
||||
return None, 0.0
|
||||
|
||||
|
||||
async def run_trocr_ocr(
|
||||
image_data: bytes,
|
||||
handwritten: bool = False,
|
||||
split_lines: bool = True,
|
||||
size: str = "base",
|
||||
) -> Tuple[Optional[str], float]:
|
||||
"""
|
||||
Run TrOCR on an image.
|
||||
|
||||
Routes between ONNX and PyTorch backends based on the TROCR_BACKEND
|
||||
environment variable (default: "auto").
|
||||
|
||||
- "onnx" — always use ONNX (raises RuntimeError if unavailable)
|
||||
- "pytorch" — always use PyTorch (original behaviour)
|
||||
- "auto" — try ONNX first, fall back to PyTorch
|
||||
|
||||
TrOCR is optimized for single-line text recognition, so for full-page
|
||||
images we need to either:
|
||||
1. Split into lines first (using line detection)
|
||||
2. Process the whole image and get partial results
|
||||
|
||||
Args:
|
||||
image_data: Raw image bytes
|
||||
handwritten: Use handwritten model (slower but better for handwriting)
|
||||
split_lines: Whether to split image into lines first
|
||||
size: "base" or "large" (only for handwritten variant)
|
||||
|
||||
Returns:
|
||||
Tuple of (extracted_text, confidence)
|
||||
"""
|
||||
backend = _trocr_backend
|
||||
|
||||
# --- ONNX-only mode ---
|
||||
if backend == "onnx":
|
||||
onnx_fn = _try_onnx_ocr(image_data, handwritten=handwritten, split_lines=split_lines)
|
||||
if onnx_fn is None or not callable(onnx_fn):
|
||||
raise RuntimeError(
|
||||
"ONNX backend requested (TROCR_BACKEND=onnx) but unavailable. "
|
||||
"Ensure onnxruntime + optimum are installed and ONNX model files exist."
|
||||
)
|
||||
return await onnx_fn(image_data, handwritten=handwritten, split_lines=split_lines)
|
||||
|
||||
# --- PyTorch-only mode ---
|
||||
if backend == "pytorch":
|
||||
return await _run_pytorch_ocr(
|
||||
image_data, handwritten=handwritten, split_lines=split_lines, size=size,
|
||||
)
|
||||
|
||||
# --- Auto mode: try ONNX first, then PyTorch ---
|
||||
onnx_fn = _try_onnx_ocr(image_data, handwritten=handwritten, split_lines=split_lines)
|
||||
if onnx_fn is not None and callable(onnx_fn):
|
||||
try:
|
||||
result = await onnx_fn(image_data, handwritten=handwritten, split_lines=split_lines)
|
||||
if result[0] is not None:
|
||||
return result
|
||||
logger.warning("ONNX returned None text, falling back to PyTorch")
|
||||
except Exception as e:
|
||||
logger.warning(f"ONNX inference failed ({e}), falling back to PyTorch")
|
||||
|
||||
return await _run_pytorch_ocr(
|
||||
image_data, handwritten=handwritten, split_lines=split_lines, size=size,
|
||||
)
|
||||
|
||||
|
||||
def _split_into_lines(image) -> list:
|
||||
"""
|
||||
Split an image into text lines using simple projection-based segmentation.
|
||||
|
||||
This is a basic implementation - for production use, consider using
|
||||
a dedicated line detection model.
|
||||
"""
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
try:
|
||||
# Convert to grayscale
|
||||
gray = image.convert('L')
|
||||
img_array = np.array(gray)
|
||||
|
||||
# Binarize (simple threshold)
|
||||
threshold = 200
|
||||
binary = img_array < threshold
|
||||
|
||||
# Horizontal projection (sum of dark pixels per row)
|
||||
h_proj = np.sum(binary, axis=1)
|
||||
|
||||
# Find line boundaries (where projection drops below threshold)
|
||||
line_threshold = img_array.shape[1] * 0.02 # 2% of width
|
||||
in_line = False
|
||||
line_start = 0
|
||||
lines = []
|
||||
|
||||
for i, val in enumerate(h_proj):
|
||||
if val > line_threshold and not in_line:
|
||||
# Start of line
|
||||
in_line = True
|
||||
line_start = i
|
||||
elif val <= line_threshold and in_line:
|
||||
# End of line
|
||||
in_line = False
|
||||
# Add padding
|
||||
start = max(0, line_start - 5)
|
||||
end = min(img_array.shape[0], i + 5)
|
||||
if end - start > 10: # Minimum line height
|
||||
lines.append(image.crop((0, start, image.width, end)))
|
||||
|
||||
# Handle last line if still in_line
|
||||
if in_line:
|
||||
start = max(0, line_start - 5)
|
||||
lines.append(image.crop((0, start, image.width, image.height)))
|
||||
|
||||
logger.info(f"Split image into {len(lines)} lines")
|
||||
return lines
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Line splitting failed: {e}")
|
||||
return []
|
||||
|
||||
|
||||
def _try_onnx_enhanced(
|
||||
handwritten: bool = True,
|
||||
):
|
||||
"""
|
||||
Return the ONNX enhanced coroutine function, or None if unavailable.
|
||||
"""
|
||||
try:
|
||||
from .trocr_onnx_service import is_onnx_available, run_trocr_onnx_enhanced
|
||||
|
||||
if not is_onnx_available(handwritten=handwritten):
|
||||
return None
|
||||
return run_trocr_onnx_enhanced
|
||||
except ImportError:
|
||||
return None
|
||||
|
||||
|
||||
async def run_trocr_ocr_enhanced(
|
||||
image_data: bytes,
|
||||
handwritten: bool = True,
|
||||
split_lines: bool = True,
|
||||
use_cache: bool = True
|
||||
) -> OCRResult:
|
||||
"""
|
||||
Enhanced TrOCR OCR with caching and detailed results.
|
||||
|
||||
Routes between ONNX and PyTorch backends based on the TROCR_BACKEND
|
||||
environment variable (default: "auto").
|
||||
|
||||
Args:
|
||||
image_data: Raw image bytes
|
||||
handwritten: Use handwritten model
|
||||
split_lines: Whether to split image into lines first
|
||||
use_cache: Whether to use caching
|
||||
|
||||
Returns:
|
||||
OCRResult with detailed information
|
||||
"""
|
||||
backend = _trocr_backend
|
||||
|
||||
# --- ONNX-only mode ---
|
||||
if backend == "onnx":
|
||||
onnx_fn = _try_onnx_enhanced(handwritten=handwritten)
|
||||
if onnx_fn is None:
|
||||
raise RuntimeError(
|
||||
"ONNX backend requested (TROCR_BACKEND=onnx) but unavailable. "
|
||||
"Ensure onnxruntime + optimum are installed and ONNX model files exist."
|
||||
)
|
||||
return await onnx_fn(
|
||||
image_data, handwritten=handwritten,
|
||||
split_lines=split_lines, use_cache=use_cache,
|
||||
)
|
||||
|
||||
# --- Auto mode: try ONNX first ---
|
||||
if backend == "auto":
|
||||
onnx_fn = _try_onnx_enhanced(handwritten=handwritten)
|
||||
if onnx_fn is not None:
|
||||
try:
|
||||
result = await onnx_fn(
|
||||
image_data, handwritten=handwritten,
|
||||
split_lines=split_lines, use_cache=use_cache,
|
||||
)
|
||||
if result.text:
|
||||
return result
|
||||
logger.warning("ONNX enhanced returned empty text, falling back to PyTorch")
|
||||
except Exception as e:
|
||||
logger.warning(f"ONNX enhanced failed ({e}), falling back to PyTorch")
|
||||
|
||||
# --- PyTorch path (backend == "pytorch" or auto fallback) ---
|
||||
start_time = time.time()
|
||||
|
||||
# Check cache first
|
||||
image_hash = _compute_image_hash(image_data)
|
||||
if use_cache:
|
||||
cached = _cache_get(image_hash)
|
||||
if cached:
|
||||
return OCRResult(
|
||||
text=cached["text"],
|
||||
confidence=cached["confidence"],
|
||||
processing_time_ms=0,
|
||||
model=cached["model"],
|
||||
has_lora_adapter=cached.get("has_lora_adapter", False),
|
||||
char_confidences=cached.get("char_confidences", []),
|
||||
word_boxes=cached.get("word_boxes", []),
|
||||
from_cache=True,
|
||||
image_hash=image_hash
|
||||
)
|
||||
|
||||
# Run OCR via PyTorch
|
||||
text, confidence = await _run_pytorch_ocr(image_data, handwritten=handwritten, split_lines=split_lines)
|
||||
|
||||
processing_time_ms = int((time.time() - start_time) * 1000)
|
||||
|
||||
# Generate word boxes with simulated confidences
|
||||
word_boxes = []
|
||||
if text:
|
||||
words = text.split()
|
||||
for idx, word in enumerate(words):
|
||||
# Simulate word confidence (slightly varied around overall confidence)
|
||||
word_conf = min(1.0, max(0.0, confidence + (hash(word) % 20 - 10) / 100))
|
||||
word_boxes.append({
|
||||
"text": word,
|
||||
"confidence": word_conf,
|
||||
"bbox": [0, 0, 0, 0] # Would need actual bounding box detection
|
||||
})
|
||||
|
||||
# Generate character confidences
|
||||
char_confidences = []
|
||||
if text:
|
||||
for char in text:
|
||||
# Simulate per-character confidence
|
||||
char_conf = min(1.0, max(0.0, confidence + (hash(char) % 15 - 7) / 100))
|
||||
char_confidences.append(char_conf)
|
||||
|
||||
result = OCRResult(
|
||||
text=text or "",
|
||||
confidence=confidence,
|
||||
processing_time_ms=processing_time_ms,
|
||||
model="trocr-base-handwritten" if handwritten else "trocr-base-printed",
|
||||
has_lora_adapter=False, # Would check actual adapter status
|
||||
char_confidences=char_confidences,
|
||||
word_boxes=word_boxes,
|
||||
from_cache=False,
|
||||
image_hash=image_hash
|
||||
)
|
||||
|
||||
# Cache result
|
||||
if use_cache and text:
|
||||
_cache_set(image_hash, {
|
||||
"text": result.text,
|
||||
"confidence": result.confidence,
|
||||
"model": result.model,
|
||||
"has_lora_adapter": result.has_lora_adapter,
|
||||
"char_confidences": result.char_confidences,
|
||||
"word_boxes": result.word_boxes
|
||||
})
|
||||
|
||||
return result
|
||||
|
||||
|
||||
async def run_trocr_batch(
|
||||
images: List[bytes],
|
||||
handwritten: bool = True,
|
||||
split_lines: bool = True,
|
||||
use_cache: bool = True,
|
||||
progress_callback: Optional[callable] = None
|
||||
) -> BatchOCRResult:
|
||||
"""
|
||||
Process multiple images in batch.
|
||||
|
||||
Args:
|
||||
images: List of image data bytes
|
||||
handwritten: Use handwritten model
|
||||
split_lines: Whether to split images into lines
|
||||
use_cache: Whether to use caching
|
||||
progress_callback: Optional callback(current, total) for progress updates
|
||||
|
||||
Returns:
|
||||
BatchOCRResult with all results
|
||||
"""
|
||||
start_time = time.time()
|
||||
results = []
|
||||
cached_count = 0
|
||||
error_count = 0
|
||||
|
||||
for idx, image_data in enumerate(images):
|
||||
try:
|
||||
result = await run_trocr_ocr_enhanced(
|
||||
image_data,
|
||||
handwritten=handwritten,
|
||||
split_lines=split_lines,
|
||||
use_cache=use_cache
|
||||
)
|
||||
results.append(result)
|
||||
|
||||
if result.from_cache:
|
||||
cached_count += 1
|
||||
|
||||
# Report progress
|
||||
if progress_callback:
|
||||
progress_callback(idx + 1, len(images))
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Batch OCR error for image {idx}: {e}")
|
||||
error_count += 1
|
||||
results.append(OCRResult(
|
||||
text=f"Error: {str(e)}",
|
||||
confidence=0.0,
|
||||
processing_time_ms=0,
|
||||
model="error",
|
||||
has_lora_adapter=False
|
||||
))
|
||||
|
||||
total_time_ms = int((time.time() - start_time) * 1000)
|
||||
|
||||
return BatchOCRResult(
|
||||
results=results,
|
||||
total_time_ms=total_time_ms,
|
||||
processed_count=len(images),
|
||||
cached_count=cached_count,
|
||||
error_count=error_count
|
||||
)
|
||||
|
||||
|
||||
# Generator for SSE streaming during batch processing
|
||||
async def run_trocr_batch_stream(
|
||||
images: List[bytes],
|
||||
handwritten: bool = True,
|
||||
split_lines: bool = True,
|
||||
use_cache: bool = True
|
||||
):
|
||||
"""
|
||||
Process images and yield progress updates for SSE streaming.
|
||||
|
||||
Yields:
|
||||
dict with current progress and result
|
||||
"""
|
||||
start_time = time.time()
|
||||
total = len(images)
|
||||
|
||||
for idx, image_data in enumerate(images):
|
||||
try:
|
||||
result = await run_trocr_ocr_enhanced(
|
||||
image_data,
|
||||
handwritten=handwritten,
|
||||
split_lines=split_lines,
|
||||
use_cache=use_cache
|
||||
)
|
||||
|
||||
elapsed_ms = int((time.time() - start_time) * 1000)
|
||||
avg_time_per_image = elapsed_ms / (idx + 1)
|
||||
estimated_remaining = int(avg_time_per_image * (total - idx - 1))
|
||||
|
||||
yield {
|
||||
"type": "progress",
|
||||
"current": idx + 1,
|
||||
"total": total,
|
||||
"progress_percent": ((idx + 1) / total) * 100,
|
||||
"elapsed_ms": elapsed_ms,
|
||||
"estimated_remaining_ms": estimated_remaining,
|
||||
"result": {
|
||||
"text": result.text,
|
||||
"confidence": result.confidence,
|
||||
"processing_time_ms": result.processing_time_ms,
|
||||
"from_cache": result.from_cache
|
||||
}
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Stream OCR error for image {idx}: {e}")
|
||||
yield {
|
||||
"type": "error",
|
||||
"current": idx + 1,
|
||||
"total": total,
|
||||
"error": str(e)
|
||||
}
|
||||
|
||||
total_time_ms = int((time.time() - start_time) * 1000)
|
||||
yield {
|
||||
"type": "complete",
|
||||
"total_time_ms": total_time_ms,
|
||||
"processed_count": total
|
||||
}
|
||||
|
||||
|
||||
# Test function
|
||||
async def test_trocr_ocr(image_path: str, handwritten: bool = False):
|
||||
"""Test TrOCR on a local image file."""
|
||||
with open(image_path, "rb") as f:
|
||||
image_data = f.read()
|
||||
|
||||
text, confidence = await run_trocr_ocr(image_data, handwritten=handwritten)
|
||||
|
||||
print(f"\n=== TrOCR Test ===")
|
||||
print(f"Mode: {'Handwritten' if handwritten else 'Printed'}")
|
||||
print(f"Confidence: {confidence:.2f}")
|
||||
print(f"Text:\n{text}")
|
||||
|
||||
return text, confidence
|
||||
# Models, cache, and model loading
|
||||
from .trocr_models import (
|
||||
OCRResult,
|
||||
BatchOCRResult,
|
||||
_compute_image_hash,
|
||||
_cache_get,
|
||||
_cache_set,
|
||||
get_cache_stats,
|
||||
_check_trocr_available,
|
||||
get_trocr_model,
|
||||
preload_trocr_model,
|
||||
get_model_status,
|
||||
get_active_backend,
|
||||
_split_into_lines,
|
||||
)
|
||||
|
||||
# Core OCR execution
|
||||
from .trocr_ocr import (
|
||||
run_trocr_ocr,
|
||||
run_trocr_ocr_enhanced,
|
||||
_run_pytorch_ocr,
|
||||
)
|
||||
|
||||
# Batch processing & streaming
|
||||
from .trocr_batch import (
|
||||
run_trocr_batch,
|
||||
run_trocr_batch_stream,
|
||||
test_trocr_ocr,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# Dataclasses
|
||||
"OCRResult",
|
||||
"BatchOCRResult",
|
||||
# Cache
|
||||
"_compute_image_hash",
|
||||
"_cache_get",
|
||||
"_cache_set",
|
||||
"get_cache_stats",
|
||||
# Model loading
|
||||
"_check_trocr_available",
|
||||
"get_trocr_model",
|
||||
"preload_trocr_model",
|
||||
"get_model_status",
|
||||
"get_active_backend",
|
||||
"_split_into_lines",
|
||||
# OCR execution
|
||||
"run_trocr_ocr",
|
||||
"run_trocr_ocr_enhanced",
|
||||
"_run_pytorch_ocr",
|
||||
# Batch
|
||||
"run_trocr_batch",
|
||||
"run_trocr_batch_stream",
|
||||
"test_trocr_ocr",
|
||||
]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -0,0 +1,50 @@
|
||||
'use client'
|
||||
|
||||
import { AuditStatistics } from './types'
|
||||
import { Language } from '@/lib/compliance-i18n'
|
||||
|
||||
export default function AuditProgressBar({ statistics, lang }: { statistics: AuditStatistics; lang: Language }) {
|
||||
const segments = [
|
||||
{ key: 'compliant', count: statistics.compliant, color: 'bg-green-500', label: 'Konform' },
|
||||
{ key: 'compliant_with_notes', count: statistics.compliant_with_notes, color: 'bg-yellow-500', label: 'Mit Anm.' },
|
||||
{ key: 'non_compliant', count: statistics.non_compliant, color: 'bg-red-500', label: 'Nicht konform' },
|
||||
{ key: 'not_applicable', count: statistics.not_applicable, color: 'bg-slate-300', label: 'N/A' },
|
||||
{ key: 'pending', count: statistics.pending, color: 'bg-slate-100', label: 'Offen' },
|
||||
]
|
||||
|
||||
return (
|
||||
<div className="bg-white rounded-lg border border-slate-200 p-4 mb-6">
|
||||
<div className="flex items-center justify-between mb-3">
|
||||
<h3 className="font-semibold text-slate-900">Fortschritt</h3>
|
||||
<span className="text-2xl font-bold text-primary-600">
|
||||
{Math.round(statistics.completion_percentage)}%
|
||||
</span>
|
||||
</div>
|
||||
|
||||
{/* Stacked Progress Bar */}
|
||||
<div className="h-4 bg-slate-100 rounded-full overflow-hidden flex">
|
||||
{segments.map(seg => (
|
||||
seg.count > 0 && (
|
||||
<div
|
||||
key={seg.key}
|
||||
className={`${seg.color} transition-all`}
|
||||
style={{ width: `${(seg.count / statistics.total) * 100}%` }}
|
||||
title={`${seg.label}: ${seg.count}`}
|
||||
/>
|
||||
)
|
||||
))}
|
||||
</div>
|
||||
|
||||
{/* Legend */}
|
||||
<div className="flex flex-wrap gap-4 mt-3">
|
||||
{segments.map(seg => (
|
||||
<div key={seg.key} className="flex items-center gap-1.5 text-sm">
|
||||
<span className={`w-3 h-3 rounded ${seg.color}`} />
|
||||
<span className="text-slate-600">{seg.label}:</span>
|
||||
<span className="font-medium text-slate-900">{seg.count}</span>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
@@ -0,0 +1,73 @@
|
||||
'use client'
|
||||
|
||||
import { AuditChecklistItem, RESULT_STATUS } from './types'
|
||||
import { Language } from '@/lib/compliance-i18n'
|
||||
|
||||
export default function ChecklistItemRow({
|
||||
item,
|
||||
index,
|
||||
onSignOff,
|
||||
lang,
|
||||
}: {
|
||||
item: AuditChecklistItem
|
||||
index: number
|
||||
onSignOff: () => void
|
||||
lang: Language
|
||||
}) {
|
||||
const status = RESULT_STATUS[item.current_result]
|
||||
|
||||
return (
|
||||
<div className="p-4 hover:bg-slate-50 transition-colors">
|
||||
<div className="flex items-start gap-4">
|
||||
{/* Status Icon */}
|
||||
<div className={`w-8 h-8 rounded-full flex items-center justify-center text-lg ${status.color}`}>
|
||||
{status.icon}
|
||||
</div>
|
||||
|
||||
{/* Content */}
|
||||
<div className="flex-1 min-w-0">
|
||||
<div className="flex items-center gap-2 mb-1">
|
||||
<span className="text-xs font-medium text-slate-500">{index}.</span>
|
||||
<span className="font-mono text-sm text-primary-600">{item.regulation_code}</span>
|
||||
<span className="font-mono text-sm text-slate-700">{item.article}</span>
|
||||
{item.is_signed && (
|
||||
<span className="px-1.5 py-0.5 text-xs bg-green-100 text-green-700 rounded flex items-center gap-1">
|
||||
<svg className="w-3 h-3" fill="none" stroke="currentColor" viewBox="0 0 24 24">
|
||||
<path strokeLinecap="round" strokeLinejoin="round" strokeWidth={2} d="M9 12l2 2 4-4m5.618-4.016A11.955 11.955 0 0112 2.944a11.955 11.955 0 01-8.618 3.04A12.02 12.02 0 003 9c0 5.591 3.824 10.29 9 11.622 5.176-1.332 9-6.03 9-11.622 0-1.042-.133-2.052-.382-3.016z" />
|
||||
</svg>
|
||||
Signiert
|
||||
</span>
|
||||
)}
|
||||
</div>
|
||||
<h4 className="text-sm font-medium text-slate-900 mb-1">{item.title}</h4>
|
||||
{item.notes && (
|
||||
<p className="text-xs text-slate-500 italic mb-2">"{item.notes}"</p>
|
||||
)}
|
||||
<div className="flex items-center gap-4 text-xs text-slate-500">
|
||||
<span>{item.controls_mapped} Controls</span>
|
||||
<span>{item.evidence_count} Nachweise</span>
|
||||
{item.signed_at && (
|
||||
<span>Signiert: {new Date(item.signed_at).toLocaleDateString('de-DE')}</span>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Actions */}
|
||||
<div className="flex items-center gap-2">
|
||||
<span className={`px-2 py-1 text-xs font-medium rounded ${status.color}`}>
|
||||
{lang === 'de' ? status.label : status.labelEn}
|
||||
</span>
|
||||
<button
|
||||
onClick={onSignOff}
|
||||
className="p-2 text-slate-400 hover:text-primary-600 hover:bg-primary-50 rounded-lg transition-colors"
|
||||
title="Sign-off"
|
||||
>
|
||||
<svg className="w-5 h-5" fill="none" stroke="currentColor" viewBox="0 0 24 24">
|
||||
<path strokeLinecap="round" strokeLinejoin="round" strokeWidth={2} d="M9 12l2 2 4-4m6 2a9 9 0 11-18 0 9 9 0 0118 0z" />
|
||||
</svg>
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
@@ -0,0 +1,158 @@
|
||||
'use client'
|
||||
|
||||
import { useState } from 'react'
|
||||
import { Regulation } from './types'
|
||||
|
||||
interface CreateSessionData {
|
||||
name: string
|
||||
description?: string
|
||||
auditor_name: string
|
||||
auditor_email?: string
|
||||
regulation_codes?: string[]
|
||||
}
|
||||
|
||||
export default function CreateSessionModal({
|
||||
regulations,
|
||||
onClose,
|
||||
onCreate,
|
||||
}: {
|
||||
regulations: Regulation[]
|
||||
onClose: () => void
|
||||
onCreate: (data: CreateSessionData) => void
|
||||
}) {
|
||||
const [name, setName] = useState('')
|
||||
const [description, setDescription] = useState('')
|
||||
const [auditorName, setAuditorName] = useState('')
|
||||
const [auditorEmail, setAuditorEmail] = useState('')
|
||||
const [selectedRegs, setSelectedRegs] = useState<string[]>([])
|
||||
const [submitting, setSubmitting] = useState(false)
|
||||
|
||||
const handleSubmit = async (e: React.FormEvent) => {
|
||||
e.preventDefault()
|
||||
if (!name || !auditorName) return
|
||||
|
||||
setSubmitting(true)
|
||||
await onCreate({
|
||||
name,
|
||||
description: description || undefined,
|
||||
auditor_name: auditorName,
|
||||
auditor_email: auditorEmail || undefined,
|
||||
regulation_codes: selectedRegs.length > 0 ? selectedRegs : undefined,
|
||||
})
|
||||
setSubmitting(false)
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="fixed inset-0 bg-black/50 flex items-center justify-center z-50">
|
||||
<div className="bg-white rounded-xl shadow-xl max-w-lg w-full mx-4 overflow-hidden">
|
||||
<div className="px-6 py-4 border-b border-slate-200 flex items-center justify-between">
|
||||
<h2 className="text-lg font-semibold text-slate-900">Neue Audit-Session</h2>
|
||||
<button onClick={onClose} className="text-slate-400 hover:text-slate-600">
|
||||
<svg className="w-6 h-6" fill="none" stroke="currentColor" viewBox="0 0 24 24">
|
||||
<path strokeLinecap="round" strokeLinejoin="round" strokeWidth={2} d="M6 18L18 6M6 6l12 12" />
|
||||
</svg>
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<form onSubmit={handleSubmit} className="p-6 space-y-4">
|
||||
<div>
|
||||
<label className="block text-sm font-medium text-slate-700 mb-1">
|
||||
Name der Pruefung *
|
||||
</label>
|
||||
<input
|
||||
type="text"
|
||||
value={name}
|
||||
onChange={(e) => setName(e.target.value)}
|
||||
placeholder="z.B. Q1 2026 Compliance Audit"
|
||||
className="w-full px-3 py-2 border border-slate-300 rounded-lg focus:ring-2 focus:ring-primary-500"
|
||||
required
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div>
|
||||
<label className="block text-sm font-medium text-slate-700 mb-1">
|
||||
Beschreibung
|
||||
</label>
|
||||
<textarea
|
||||
value={description}
|
||||
onChange={(e) => setDescription(e.target.value)}
|
||||
placeholder="Optionale Beschreibung..."
|
||||
className="w-full px-3 py-2 border border-slate-300 rounded-lg"
|
||||
rows={2}
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div className="grid grid-cols-2 gap-4">
|
||||
<div>
|
||||
<label className="block text-sm font-medium text-slate-700 mb-1">
|
||||
Auditor Name *
|
||||
</label>
|
||||
<input
|
||||
type="text"
|
||||
value={auditorName}
|
||||
onChange={(e) => setAuditorName(e.target.value)}
|
||||
placeholder="Dr. Max Mustermann"
|
||||
className="w-full px-3 py-2 border border-slate-300 rounded-lg"
|
||||
required
|
||||
/>
|
||||
</div>
|
||||
<div>
|
||||
<label className="block text-sm font-medium text-slate-700 mb-1">
|
||||
E-Mail
|
||||
</label>
|
||||
<input
|
||||
type="email"
|
||||
value={auditorEmail}
|
||||
onChange={(e) => setAuditorEmail(e.target.value)}
|
||||
placeholder="auditor@example.com"
|
||||
className="w-full px-3 py-2 border border-slate-300 rounded-lg"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div>
|
||||
<label className="block text-sm font-medium text-slate-700 mb-1">
|
||||
Verordnungen (optional - leer = alle)
|
||||
</label>
|
||||
<div className="flex flex-wrap gap-2 p-2 border border-slate-300 rounded-lg max-h-32 overflow-y-auto">
|
||||
{regulations.map(reg => (
|
||||
<label key={reg.code} className="flex items-center gap-1.5">
|
||||
<input
|
||||
type="checkbox"
|
||||
checked={selectedRegs.includes(reg.code)}
|
||||
onChange={(e) => {
|
||||
if (e.target.checked) {
|
||||
setSelectedRegs([...selectedRegs, reg.code])
|
||||
} else {
|
||||
setSelectedRegs(selectedRegs.filter(r => r !== reg.code))
|
||||
}
|
||||
}}
|
||||
className="rounded"
|
||||
/>
|
||||
<span className="text-sm text-slate-700">{reg.code}</span>
|
||||
</label>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="flex justify-end gap-3 pt-4">
|
||||
<button
|
||||
type="button"
|
||||
onClick={onClose}
|
||||
className="px-4 py-2 text-slate-600 hover:text-slate-800"
|
||||
>
|
||||
Abbrechen
|
||||
</button>
|
||||
<button
|
||||
type="submit"
|
||||
disabled={!name || !auditorName || submitting}
|
||||
className="px-4 py-2 bg-primary-600 text-white rounded-lg hover:bg-primary-700 disabled:opacity-50"
|
||||
>
|
||||
{submitting ? 'Erstelle...' : 'Session erstellen'}
|
||||
</button>
|
||||
</div>
|
||||
</form>
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
@@ -0,0 +1,45 @@
|
||||
'use client'
|
||||
|
||||
import { AuditSession, SESSION_STATUS } from './types'
|
||||
|
||||
export default function SessionCard({ session, onClick }: { session: AuditSession; onClick: () => void }) {
|
||||
const completionPercent = session.total_items > 0
|
||||
? Math.round((session.completed_items / session.total_items) * 100)
|
||||
: 0
|
||||
|
||||
return (
|
||||
<button
|
||||
onClick={onClick}
|
||||
className="bg-white rounded-lg border border-slate-200 p-4 text-left hover:shadow-md hover:border-primary-300 transition-all"
|
||||
>
|
||||
<div className="flex items-start justify-between mb-3">
|
||||
<h3 className="font-semibold text-slate-900">{session.name}</h3>
|
||||
<span className={`px-2 py-0.5 text-xs font-medium rounded ${SESSION_STATUS[session.status].color}`}>
|
||||
{SESSION_STATUS[session.status].label}
|
||||
</span>
|
||||
</div>
|
||||
|
||||
{session.description && (
|
||||
<p className="text-sm text-slate-500 mb-3 line-clamp-2">{session.description}</p>
|
||||
)}
|
||||
|
||||
<div className="mb-3">
|
||||
<div className="flex justify-between text-xs text-slate-500 mb-1">
|
||||
<span>{session.completed_items} / {session.total_items} Punkte</span>
|
||||
<span>{completionPercent}%</span>
|
||||
</div>
|
||||
<div className="h-2 bg-slate-100 rounded-full overflow-hidden">
|
||||
<div
|
||||
className="h-full bg-primary-500 transition-all"
|
||||
style={{ width: `${completionPercent}%` }}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="flex items-center justify-between text-xs text-slate-500">
|
||||
<span>Auditor: {session.auditor_name}</span>
|
||||
<span>{new Date(session.created_at).toLocaleDateString('de-DE')}</span>
|
||||
</div>
|
||||
</button>
|
||||
)
|
||||
}
|
||||
@@ -0,0 +1,122 @@
|
||||
'use client'
|
||||
|
||||
import { useState } from 'react'
|
||||
import { AuditChecklistItem, RESULT_STATUS } from './types'
|
||||
|
||||
export default function SignOffModal({
|
||||
item,
|
||||
auditorName,
|
||||
onClose,
|
||||
onSubmit,
|
||||
}: {
|
||||
item: AuditChecklistItem
|
||||
auditorName: string
|
||||
onClose: () => void
|
||||
onSubmit: (result: AuditChecklistItem['current_result'], notes: string, sign: boolean) => void
|
||||
}) {
|
||||
const [result, setResult] = useState<AuditChecklistItem['current_result']>(item.current_result)
|
||||
const [notes, setNotes] = useState(item.notes || '')
|
||||
const [sign, setSign] = useState(false)
|
||||
const [submitting, setSubmitting] = useState(false)
|
||||
|
||||
const handleSubmit = async (e: React.FormEvent) => {
|
||||
e.preventDefault()
|
||||
setSubmitting(true)
|
||||
await onSubmit(result, notes, sign)
|
||||
setSubmitting(false)
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="fixed inset-0 bg-black/50 flex items-center justify-center z-50">
|
||||
<div className="bg-white rounded-xl shadow-xl max-w-lg w-full mx-4 overflow-hidden">
|
||||
<div className="px-6 py-4 border-b border-slate-200">
|
||||
<h2 className="text-lg font-semibold text-slate-900">Auditor Sign-off</h2>
|
||||
<p className="text-sm text-slate-500 mt-1">
|
||||
{item.regulation_code} {item.article} - {item.title}
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<form onSubmit={handleSubmit} className="p-6 space-y-4">
|
||||
<div>
|
||||
<label className="block text-sm font-medium text-slate-700 mb-2">
|
||||
Pruefungsergebnis
|
||||
</label>
|
||||
<div className="space-y-2">
|
||||
{Object.entries(RESULT_STATUS).map(([key, { label, color }]) => (
|
||||
<label key={key} className="flex items-center gap-3 p-2 rounded-lg hover:bg-slate-50 cursor-pointer">
|
||||
<input
|
||||
type="radio"
|
||||
name="result"
|
||||
value={key}
|
||||
checked={result === key}
|
||||
onChange={() => setResult(key as typeof result)}
|
||||
className="w-4 h-4"
|
||||
/>
|
||||
<span className={`px-2 py-0.5 text-sm font-medium rounded ${color}`}>
|
||||
{label}
|
||||
</span>
|
||||
</label>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div>
|
||||
<label className="block text-sm font-medium text-slate-700 mb-1">
|
||||
Anmerkungen
|
||||
</label>
|
||||
<textarea
|
||||
value={notes}
|
||||
onChange={(e) => setNotes(e.target.value)}
|
||||
placeholder="Optionale Anmerkungen zur Pruefung..."
|
||||
className="w-full px-3 py-2 border border-slate-300 rounded-lg"
|
||||
rows={3}
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div className="p-3 bg-amber-50 border border-amber-200 rounded-lg">
|
||||
<label className="flex items-start gap-3 cursor-pointer">
|
||||
<input
|
||||
type="checkbox"
|
||||
checked={sign}
|
||||
onChange={(e) => setSign(e.target.checked)}
|
||||
className="mt-1 rounded"
|
||||
/>
|
||||
<div>
|
||||
<span className="font-medium text-amber-800">Digital signieren</span>
|
||||
<p className="text-sm text-amber-700 mt-0.5">
|
||||
Erstellt eine SHA-256 Signatur des Ergebnisses. Diese Aktion kann nicht rueckgaengig gemacht werden.
|
||||
</p>
|
||||
</div>
|
||||
</label>
|
||||
</div>
|
||||
|
||||
<div className="text-sm text-slate-500 pt-2 border-t border-slate-100">
|
||||
<p>Auditor: <span className="font-medium text-slate-700">{auditorName}</span></p>
|
||||
<p>Datum: <span className="font-medium text-slate-700">{new Date().toLocaleString('de-DE')}</span></p>
|
||||
</div>
|
||||
|
||||
<div className="flex justify-end gap-3 pt-4">
|
||||
<button
|
||||
type="button"
|
||||
onClick={onClose}
|
||||
className="px-4 py-2 text-slate-600 hover:text-slate-800"
|
||||
>
|
||||
Abbrechen
|
||||
</button>
|
||||
<button
|
||||
type="submit"
|
||||
disabled={submitting}
|
||||
className={`px-4 py-2 rounded-lg disabled:opacity-50 ${
|
||||
sign
|
||||
? 'bg-amber-600 text-white hover:bg-amber-700'
|
||||
: 'bg-primary-600 text-white hover:bg-primary-700'
|
||||
}`}
|
||||
>
|
||||
{submitting ? 'Speichere...' : sign ? 'Signieren & Speichern' : 'Speichern'}
|
||||
</button>
|
||||
</div>
|
||||
</form>
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
@@ -0,0 +1,62 @@
|
||||
import { Language } from '@/lib/compliance-i18n'
|
||||
|
||||
export interface AuditSession {
|
||||
id: string
|
||||
name: string
|
||||
description: string | null
|
||||
auditor_name: string
|
||||
auditor_email: string | null
|
||||
status: 'draft' | 'in_progress' | 'completed' | 'archived'
|
||||
regulation_ids: string[] | null
|
||||
total_items: number
|
||||
completed_items: number
|
||||
created_at: string
|
||||
started_at: string | null
|
||||
completed_at: string | null
|
||||
}
|
||||
|
||||
export interface AuditChecklistItem {
|
||||
requirement_id: string
|
||||
regulation_code: string
|
||||
article: string
|
||||
title: string
|
||||
current_result: 'compliant' | 'compliant_notes' | 'non_compliant' | 'not_applicable' | 'pending'
|
||||
notes: string | null
|
||||
is_signed: boolean
|
||||
signed_at: string | null
|
||||
signed_by: string | null
|
||||
evidence_count: number
|
||||
controls_mapped: number
|
||||
}
|
||||
|
||||
export interface AuditStatistics {
|
||||
total: number
|
||||
compliant: number
|
||||
compliant_with_notes: number
|
||||
non_compliant: number
|
||||
not_applicable: number
|
||||
pending: number
|
||||
completion_percentage: number
|
||||
}
|
||||
|
||||
export interface Regulation {
|
||||
id: string
|
||||
code: string
|
||||
name: string
|
||||
requirement_count: number
|
||||
}
|
||||
|
||||
export const RESULT_STATUS = {
|
||||
pending: { label: 'Ausstehend', labelEn: 'Pending', color: 'bg-slate-100 text-slate-700', icon: '○' },
|
||||
compliant: { label: 'Konform', labelEn: 'Compliant', color: 'bg-green-100 text-green-700', icon: '✓' },
|
||||
compliant_notes: { label: 'Konform (Anm.)', labelEn: 'Compliant w/ Notes', color: 'bg-yellow-100 text-yellow-700', icon: '⚠' },
|
||||
non_compliant: { label: 'Nicht konform', labelEn: 'Non-Compliant', color: 'bg-red-100 text-red-700', icon: '✗' },
|
||||
not_applicable: { label: 'N/A', labelEn: 'N/A', color: 'bg-slate-50 text-slate-500', icon: '−' },
|
||||
}
|
||||
|
||||
export const SESSION_STATUS = {
|
||||
draft: { label: 'Entwurf', color: 'bg-slate-100 text-slate-700' },
|
||||
in_progress: { label: 'In Bearbeitung', color: 'bg-blue-100 text-blue-700' },
|
||||
completed: { label: 'Abgeschlossen', color: 'bg-green-100 text-green-700' },
|
||||
archived: { label: 'Archiviert', color: 'bg-slate-50 text-slate-500' },
|
||||
}
|
||||
@@ -0,0 +1,203 @@
|
||||
'use client'
|
||||
|
||||
import { useState, useEffect, useCallback } from 'react'
|
||||
import {
|
||||
AuditSession,
|
||||
AuditChecklistItem,
|
||||
AuditStatistics,
|
||||
Regulation,
|
||||
} from './types'
|
||||
import { Language } from '@/lib/compliance-i18n'
|
||||
|
||||
const BACKEND_URL = process.env.NEXT_PUBLIC_BACKEND_URL || 'http://localhost:8000'
|
||||
|
||||
export function useAuditChecklist() {
|
||||
const [sessions, setSessions] = useState<AuditSession[]>([])
|
||||
const [selectedSession, setSelectedSession] = useState<AuditSession | null>(null)
|
||||
const [checklistItems, setChecklistItems] = useState<AuditChecklistItem[]>([])
|
||||
const [statistics, setStatistics] = useState<AuditStatistics | null>(null)
|
||||
const [regulations, setRegulations] = useState<Regulation[]>([])
|
||||
const [loading, setLoading] = useState(true)
|
||||
const [loadingChecklist, setLoadingChecklist] = useState(false)
|
||||
const [error, setError] = useState<string | null>(null)
|
||||
|
||||
// Filters
|
||||
const [filterStatus, setFilterStatus] = useState<string>('all')
|
||||
const [filterRegulation, setFilterRegulation] = useState<string>('all')
|
||||
const [searchQuery, setSearchQuery] = useState('')
|
||||
|
||||
// Modal states
|
||||
const [showCreateModal, setShowCreateModal] = useState(false)
|
||||
const [showSignOffModal, setShowSignOffModal] = useState(false)
|
||||
const [selectedItem, setSelectedItem] = useState<AuditChecklistItem | null>(null)
|
||||
|
||||
// Language
|
||||
const [lang] = useState<Language>('de')
|
||||
|
||||
// Load sessions
|
||||
const loadSessions = useCallback(async () => {
|
||||
try {
|
||||
const res = await fetch(`${BACKEND_URL}/api/v1/compliance/audit/sessions`)
|
||||
if (res.ok) {
|
||||
const data = await res.json()
|
||||
setSessions(data.sessions || [])
|
||||
} else {
|
||||
console.error('Failed to load sessions:', res.status)
|
||||
}
|
||||
} catch (err) {
|
||||
console.error('Failed to load sessions:', err)
|
||||
setError('Verbindung zum Backend fehlgeschlagen')
|
||||
} finally {
|
||||
setLoading(false)
|
||||
}
|
||||
}, [])
|
||||
|
||||
// Load regulations for filter
|
||||
const loadRegulations = useCallback(async () => {
|
||||
try {
|
||||
const res = await fetch(`${BACKEND_URL}/api/v1/compliance/regulations`)
|
||||
if (res.ok) {
|
||||
const data = await res.json()
|
||||
setRegulations(data.regulations || [])
|
||||
}
|
||||
} catch (err) {
|
||||
console.error('Failed to load regulations:', err)
|
||||
}
|
||||
}, [])
|
||||
|
||||
// Load checklist for selected session
|
||||
const loadChecklist = useCallback(async (sessionId: string) => {
|
||||
setLoadingChecklist(true)
|
||||
try {
|
||||
const res = await fetch(`${BACKEND_URL}/api/v1/compliance/audit/checklist/${sessionId}`)
|
||||
if (res.ok) {
|
||||
const data = await res.json()
|
||||
setChecklistItems(data.items || [])
|
||||
setStatistics(data.statistics || null)
|
||||
if (data.session) {
|
||||
setSelectedSession(data.session)
|
||||
}
|
||||
} else {
|
||||
console.error('Failed to load checklist:', res.status)
|
||||
}
|
||||
} catch (err) {
|
||||
console.error('Failed to load checklist:', err)
|
||||
} finally {
|
||||
setLoadingChecklist(false)
|
||||
}
|
||||
}, [])
|
||||
|
||||
useEffect(() => {
|
||||
loadSessions()
|
||||
loadRegulations()
|
||||
}, [loadSessions, loadRegulations])
|
||||
|
||||
useEffect(() => {
|
||||
if (selectedSession) {
|
||||
loadChecklist(selectedSession.id)
|
||||
}
|
||||
}, [selectedSession, loadChecklist])
|
||||
|
||||
// Filter checklist items
|
||||
const filteredItems = checklistItems.filter(item => {
|
||||
if (filterStatus !== 'all' && item.current_result !== filterStatus) return false
|
||||
if (filterRegulation !== 'all' && item.regulation_code !== filterRegulation) return false
|
||||
if (searchQuery) {
|
||||
const query = searchQuery.toLowerCase()
|
||||
return (
|
||||
item.title.toLowerCase().includes(query) ||
|
||||
item.article.toLowerCase().includes(query) ||
|
||||
item.regulation_code.toLowerCase().includes(query)
|
||||
)
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
// Handle sign-off
|
||||
const handleSignOff = async (
|
||||
item: AuditChecklistItem,
|
||||
result: AuditChecklistItem['current_result'],
|
||||
notes: string,
|
||||
sign: boolean
|
||||
) => {
|
||||
if (!selectedSession) return
|
||||
|
||||
try {
|
||||
const res = await fetch(
|
||||
`${BACKEND_URL}/api/v1/compliance/audit/checklist/${selectedSession.id}/items/${item.requirement_id}/sign-off`,
|
||||
{
|
||||
method: 'PUT',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({ result, notes, sign }),
|
||||
}
|
||||
)
|
||||
|
||||
if (res.ok) {
|
||||
await loadChecklist(selectedSession.id)
|
||||
setShowSignOffModal(false)
|
||||
setSelectedItem(null)
|
||||
} else {
|
||||
const err = await res.json()
|
||||
alert(`Fehler: ${err.detail || 'Sign-off fehlgeschlagen'}`)
|
||||
}
|
||||
} catch (err) {
|
||||
console.error('Sign-off failed:', err)
|
||||
alert('Netzwerkfehler bei Sign-off')
|
||||
}
|
||||
}
|
||||
|
||||
// Create session
|
||||
const handleCreateSession = async (data: {
|
||||
name: string
|
||||
description?: string
|
||||
auditor_name: string
|
||||
auditor_email?: string
|
||||
regulation_codes?: string[]
|
||||
}) => {
|
||||
try {
|
||||
const res = await fetch(`${BACKEND_URL}/api/v1/compliance/audit/sessions`, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify(data),
|
||||
})
|
||||
if (res.ok) {
|
||||
await loadSessions()
|
||||
setShowCreateModal(false)
|
||||
} else {
|
||||
const err = await res.json()
|
||||
alert(`Fehler: ${err.detail || 'Session konnte nicht erstellt werden'}`)
|
||||
}
|
||||
} catch (err) {
|
||||
console.error('Create session failed:', err)
|
||||
alert('Netzwerkfehler')
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
sessions,
|
||||
selectedSession,
|
||||
setSelectedSession,
|
||||
checklistItems,
|
||||
filteredItems,
|
||||
statistics,
|
||||
regulations,
|
||||
loading,
|
||||
loadingChecklist,
|
||||
error,
|
||||
filterStatus,
|
||||
setFilterStatus,
|
||||
filterRegulation,
|
||||
setFilterRegulation,
|
||||
searchQuery,
|
||||
setSearchQuery,
|
||||
showCreateModal,
|
||||
setShowCreateModal,
|
||||
showSignOffModal,
|
||||
setShowSignOffModal,
|
||||
selectedItem,
|
||||
setSelectedItem,
|
||||
lang,
|
||||
handleSignOff,
|
||||
handleCreateSession,
|
||||
}
|
||||
}
|
||||
@@ -10,212 +10,29 @@
|
||||
* - Fortschritt und Statistiken zu verfolgen
|
||||
*/
|
||||
|
||||
import { useState, useEffect, useCallback } from 'react'
|
||||
import Link from 'next/link'
|
||||
import AdminLayout from '@/components/admin/AdminLayout'
|
||||
import { getTerm, Language } from '@/lib/compliance-i18n'
|
||||
|
||||
// Types based on backend schemas
|
||||
interface AuditSession {
|
||||
id: string
|
||||
name: string
|
||||
description: string | null
|
||||
auditor_name: string
|
||||
auditor_email: string | null
|
||||
status: 'draft' | 'in_progress' | 'completed' | 'archived'
|
||||
regulation_ids: string[] | null
|
||||
total_items: number
|
||||
completed_items: number
|
||||
created_at: string
|
||||
started_at: string | null
|
||||
completed_at: string | null
|
||||
}
|
||||
|
||||
interface AuditChecklistItem {
|
||||
requirement_id: string
|
||||
regulation_code: string
|
||||
article: string
|
||||
title: string
|
||||
current_result: 'compliant' | 'compliant_notes' | 'non_compliant' | 'not_applicable' | 'pending'
|
||||
notes: string | null
|
||||
is_signed: boolean
|
||||
signed_at: string | null
|
||||
signed_by: string | null
|
||||
evidence_count: number
|
||||
controls_mapped: number
|
||||
}
|
||||
|
||||
interface AuditStatistics {
|
||||
total: number
|
||||
compliant: number
|
||||
compliant_with_notes: number
|
||||
non_compliant: number
|
||||
not_applicable: number
|
||||
pending: number
|
||||
completion_percentage: number
|
||||
}
|
||||
|
||||
interface Regulation {
|
||||
id: string
|
||||
code: string
|
||||
name: string
|
||||
requirement_count: number
|
||||
}
|
||||
|
||||
// Result status configuration
|
||||
const RESULT_STATUS = {
|
||||
pending: { label: 'Ausstehend', labelEn: 'Pending', color: 'bg-slate-100 text-slate-700', icon: '○' },
|
||||
compliant: { label: 'Konform', labelEn: 'Compliant', color: 'bg-green-100 text-green-700', icon: '✓' },
|
||||
compliant_notes: { label: 'Konform (Anm.)', labelEn: 'Compliant w/ Notes', color: 'bg-yellow-100 text-yellow-700', icon: '⚠' },
|
||||
non_compliant: { label: 'Nicht konform', labelEn: 'Non-Compliant', color: 'bg-red-100 text-red-700', icon: '✗' },
|
||||
not_applicable: { label: 'N/A', labelEn: 'N/A', color: 'bg-slate-50 text-slate-500', icon: '−' },
|
||||
}
|
||||
|
||||
const SESSION_STATUS = {
|
||||
draft: { label: 'Entwurf', color: 'bg-slate-100 text-slate-700' },
|
||||
in_progress: { label: 'In Bearbeitung', color: 'bg-blue-100 text-blue-700' },
|
||||
completed: { label: 'Abgeschlossen', color: 'bg-green-100 text-green-700' },
|
||||
archived: { label: 'Archiviert', color: 'bg-slate-50 text-slate-500' },
|
||||
}
|
||||
import { RESULT_STATUS, SESSION_STATUS } from './_components/types'
|
||||
import { useAuditChecklist } from './_components/useAuditChecklist'
|
||||
import SessionCard from './_components/SessionCard'
|
||||
import AuditProgressBar from './_components/AuditProgressBar'
|
||||
import ChecklistItemRow from './_components/ChecklistItemRow'
|
||||
import CreateSessionModal from './_components/CreateSessionModal'
|
||||
import SignOffModal from './_components/SignOffModal'
|
||||
|
||||
export default function AuditChecklistPage() {
|
||||
const [sessions, setSessions] = useState<AuditSession[]>([])
|
||||
const [selectedSession, setSelectedSession] = useState<AuditSession | null>(null)
|
||||
const [checklistItems, setChecklistItems] = useState<AuditChecklistItem[]>([])
|
||||
const [statistics, setStatistics] = useState<AuditStatistics | null>(null)
|
||||
const [regulations, setRegulations] = useState<Regulation[]>([])
|
||||
const [loading, setLoading] = useState(true)
|
||||
const [loadingChecklist, setLoadingChecklist] = useState(false)
|
||||
const [error, setError] = useState<string | null>(null)
|
||||
|
||||
// Filters
|
||||
const [filterStatus, setFilterStatus] = useState<string>('all')
|
||||
const [filterRegulation, setFilterRegulation] = useState<string>('all')
|
||||
const [searchQuery, setSearchQuery] = useState('')
|
||||
|
||||
// Modal states
|
||||
const [showCreateModal, setShowCreateModal] = useState(false)
|
||||
const [showSignOffModal, setShowSignOffModal] = useState(false)
|
||||
const [selectedItem, setSelectedItem] = useState<AuditChecklistItem | null>(null)
|
||||
|
||||
// Language
|
||||
const [lang] = useState<Language>('de')
|
||||
|
||||
const BACKEND_URL = process.env.NEXT_PUBLIC_BACKEND_URL || 'http://localhost:8000'
|
||||
|
||||
// Load sessions
|
||||
const loadSessions = useCallback(async () => {
|
||||
try {
|
||||
const res = await fetch(`${BACKEND_URL}/api/v1/compliance/audit/sessions`)
|
||||
if (res.ok) {
|
||||
const data = await res.json()
|
||||
setSessions(data.sessions || [])
|
||||
} else {
|
||||
console.error('Failed to load sessions:', res.status)
|
||||
}
|
||||
} catch (err) {
|
||||
console.error('Failed to load sessions:', err)
|
||||
setError('Verbindung zum Backend fehlgeschlagen')
|
||||
} finally {
|
||||
setLoading(false)
|
||||
}
|
||||
}, [BACKEND_URL])
|
||||
|
||||
// Load regulations for filter
|
||||
const loadRegulations = useCallback(async () => {
|
||||
try {
|
||||
const res = await fetch(`${BACKEND_URL}/api/v1/compliance/regulations`)
|
||||
if (res.ok) {
|
||||
const data = await res.json()
|
||||
setRegulations(data.regulations || [])
|
||||
}
|
||||
} catch (err) {
|
||||
console.error('Failed to load regulations:', err)
|
||||
}
|
||||
}, [BACKEND_URL])
|
||||
|
||||
// Load checklist for selected session
|
||||
const loadChecklist = useCallback(async (sessionId: string) => {
|
||||
setLoadingChecklist(true)
|
||||
try {
|
||||
const res = await fetch(`${BACKEND_URL}/api/v1/compliance/audit/checklist/${sessionId}`)
|
||||
if (res.ok) {
|
||||
const data = await res.json()
|
||||
setChecklistItems(data.items || [])
|
||||
setStatistics(data.statistics || null)
|
||||
// Update session with latest data
|
||||
if (data.session) {
|
||||
setSelectedSession(data.session)
|
||||
}
|
||||
} else {
|
||||
console.error('Failed to load checklist:', res.status)
|
||||
}
|
||||
} catch (err) {
|
||||
console.error('Failed to load checklist:', err)
|
||||
} finally {
|
||||
setLoadingChecklist(false)
|
||||
}
|
||||
}, [BACKEND_URL])
|
||||
|
||||
useEffect(() => {
|
||||
loadSessions()
|
||||
loadRegulations()
|
||||
}, [loadSessions, loadRegulations])
|
||||
|
||||
useEffect(() => {
|
||||
if (selectedSession) {
|
||||
loadChecklist(selectedSession.id)
|
||||
}
|
||||
}, [selectedSession, loadChecklist])
|
||||
|
||||
// Filter checklist items
|
||||
const filteredItems = checklistItems.filter(item => {
|
||||
if (filterStatus !== 'all' && item.current_result !== filterStatus) return false
|
||||
if (filterRegulation !== 'all' && item.regulation_code !== filterRegulation) return false
|
||||
if (searchQuery) {
|
||||
const query = searchQuery.toLowerCase()
|
||||
return (
|
||||
item.title.toLowerCase().includes(query) ||
|
||||
item.article.toLowerCase().includes(query) ||
|
||||
item.regulation_code.toLowerCase().includes(query)
|
||||
)
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
// Handle sign-off
|
||||
const handleSignOff = async (
|
||||
item: AuditChecklistItem,
|
||||
result: AuditChecklistItem['current_result'],
|
||||
notes: string,
|
||||
sign: boolean
|
||||
) => {
|
||||
if (!selectedSession) return
|
||||
|
||||
try {
|
||||
const res = await fetch(
|
||||
`${BACKEND_URL}/api/v1/compliance/audit/checklist/${selectedSession.id}/items/${item.requirement_id}/sign-off`,
|
||||
{
|
||||
method: 'PUT',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({ result, notes, sign }),
|
||||
}
|
||||
)
|
||||
|
||||
if (res.ok) {
|
||||
// Reload checklist to get updated data
|
||||
await loadChecklist(selectedSession.id)
|
||||
setShowSignOffModal(false)
|
||||
setSelectedItem(null)
|
||||
} else {
|
||||
const err = await res.json()
|
||||
alert(`Fehler: ${err.detail || 'Sign-off fehlgeschlagen'}`)
|
||||
}
|
||||
} catch (err) {
|
||||
console.error('Sign-off failed:', err)
|
||||
alert('Netzwerkfehler bei Sign-off')
|
||||
}
|
||||
}
|
||||
const {
|
||||
sessions, selectedSession, setSelectedSession,
|
||||
filteredItems, statistics, regulations,
|
||||
loading, loadingChecklist, error,
|
||||
filterStatus, setFilterStatus,
|
||||
filterRegulation, setFilterRegulation,
|
||||
searchQuery, setSearchQuery,
|
||||
showCreateModal, setShowCreateModal,
|
||||
showSignOffModal, setShowSignOffModal,
|
||||
selectedItem, setSelectedItem,
|
||||
lang, handleSignOff, handleCreateSession,
|
||||
} = useAuditChecklist()
|
||||
|
||||
// Session list view when no session is selected
|
||||
if (!selectedSession && !loading) {
|
||||
@@ -231,7 +48,6 @@ export default function AuditChecklistPage() {
|
||||
</svg>
|
||||
Zurueck zu Compliance
|
||||
</Link>
|
||||
|
||||
<button
|
||||
onClick={() => setShowCreateModal(true)}
|
||||
className="px-4 py-2 bg-primary-600 text-white rounded-lg hover:bg-primary-700 flex items-center gap-2"
|
||||
@@ -244,12 +60,9 @@ export default function AuditChecklistPage() {
|
||||
</div>
|
||||
|
||||
{error && (
|
||||
<div className="mb-6 p-4 bg-red-50 border border-red-200 rounded-lg text-red-700">
|
||||
{error}
|
||||
</div>
|
||||
<div className="mb-6 p-4 bg-red-50 border border-red-200 rounded-lg text-red-700">{error}</div>
|
||||
)}
|
||||
|
||||
{/* Session Cards */}
|
||||
<div className="grid grid-cols-1 md:grid-cols-2 lg:grid-cols-3 gap-4">
|
||||
{sessions.length === 0 ? (
|
||||
<div className="col-span-full bg-white rounded-lg border border-slate-200 p-8 text-center">
|
||||
@@ -267,46 +80,22 @@ export default function AuditChecklistPage() {
|
||||
</div>
|
||||
) : (
|
||||
sessions.map(session => (
|
||||
<SessionCard
|
||||
key={session.id}
|
||||
session={session}
|
||||
onClick={() => setSelectedSession(session)}
|
||||
/>
|
||||
<SessionCard key={session.id} session={session} onClick={() => setSelectedSession(session)} />
|
||||
))
|
||||
)}
|
||||
</div>
|
||||
|
||||
{/* Create Session Modal */}
|
||||
{showCreateModal && (
|
||||
<CreateSessionModal
|
||||
regulations={regulations}
|
||||
onClose={() => setShowCreateModal(false)}
|
||||
onCreate={async (data) => {
|
||||
try {
|
||||
const res = await fetch(`${BACKEND_URL}/api/v1/compliance/audit/sessions`, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify(data),
|
||||
})
|
||||
if (res.ok) {
|
||||
await loadSessions()
|
||||
setShowCreateModal(false)
|
||||
} else {
|
||||
const err = await res.json()
|
||||
alert(`Fehler: ${err.detail || 'Session konnte nicht erstellt werden'}`)
|
||||
}
|
||||
} catch (err) {
|
||||
console.error('Create session failed:', err)
|
||||
alert('Netzwerkfehler')
|
||||
}
|
||||
}}
|
||||
onCreate={handleCreateSession}
|
||||
/>
|
||||
)}
|
||||
</AdminLayout>
|
||||
)
|
||||
}
|
||||
|
||||
// Loading state
|
||||
if (loading) {
|
||||
return (
|
||||
<AdminLayout title="Audit Checkliste" description="Laden...">
|
||||
@@ -320,7 +109,6 @@ export default function AuditChecklistPage() {
|
||||
// Checklist view for selected session
|
||||
return (
|
||||
<AdminLayout title={selectedSession?.name || 'Audit Checkliste'} description="Pruefung durchfuehren">
|
||||
{/* Header */}
|
||||
<div className="mb-6 flex items-center justify-between">
|
||||
<button
|
||||
onClick={() => setSelectedSession(null)}
|
||||
@@ -331,7 +119,6 @@ export default function AuditChecklistPage() {
|
||||
</svg>
|
||||
Zurueck zur Uebersicht
|
||||
</button>
|
||||
|
||||
<div className="flex items-center gap-3">
|
||||
<span className={`px-3 py-1 text-sm font-medium rounded-full ${SESSION_STATUS[selectedSession?.status || 'draft'].color}`}>
|
||||
{SESSION_STATUS[selectedSession?.status || 'draft'].label}
|
||||
@@ -348,10 +135,7 @@ export default function AuditChecklistPage() {
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Progress Bar */}
|
||||
{statistics && (
|
||||
<AuditProgressBar statistics={statistics} lang={lang} />
|
||||
)}
|
||||
{statistics && <AuditProgressBar statistics={statistics} lang={lang} />}
|
||||
|
||||
{/* Filters */}
|
||||
<div className="bg-white rounded-lg border border-slate-200 p-4 mb-6">
|
||||
@@ -424,7 +208,6 @@ export default function AuditChecklistPage() {
|
||||
)}
|
||||
</div>
|
||||
|
||||
{/* Sign-off Modal */}
|
||||
{showSignOffModal && selectedItem && (
|
||||
<SignOffModal
|
||||
item={selectedItem}
|
||||
@@ -439,429 +222,3 @@ export default function AuditChecklistPage() {
|
||||
</AdminLayout>
|
||||
)
|
||||
}
|
||||
|
||||
// Session Card Component
|
||||
function SessionCard({ session, onClick }: { session: AuditSession; onClick: () => void }) {
|
||||
const completionPercent = session.total_items > 0
|
||||
? Math.round((session.completed_items / session.total_items) * 100)
|
||||
: 0
|
||||
|
||||
return (
|
||||
<button
|
||||
onClick={onClick}
|
||||
className="bg-white rounded-lg border border-slate-200 p-4 text-left hover:shadow-md hover:border-primary-300 transition-all"
|
||||
>
|
||||
<div className="flex items-start justify-between mb-3">
|
||||
<h3 className="font-semibold text-slate-900">{session.name}</h3>
|
||||
<span className={`px-2 py-0.5 text-xs font-medium rounded ${SESSION_STATUS[session.status].color}`}>
|
||||
{SESSION_STATUS[session.status].label}
|
||||
</span>
|
||||
</div>
|
||||
|
||||
{session.description && (
|
||||
<p className="text-sm text-slate-500 mb-3 line-clamp-2">{session.description}</p>
|
||||
)}
|
||||
|
||||
<div className="mb-3">
|
||||
<div className="flex justify-between text-xs text-slate-500 mb-1">
|
||||
<span>{session.completed_items} / {session.total_items} Punkte</span>
|
||||
<span>{completionPercent}%</span>
|
||||
</div>
|
||||
<div className="h-2 bg-slate-100 rounded-full overflow-hidden">
|
||||
<div
|
||||
className="h-full bg-primary-500 transition-all"
|
||||
style={{ width: `${completionPercent}%` }}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="flex items-center justify-between text-xs text-slate-500">
|
||||
<span>Auditor: {session.auditor_name}</span>
|
||||
<span>{new Date(session.created_at).toLocaleDateString('de-DE')}</span>
|
||||
</div>
|
||||
</button>
|
||||
)
|
||||
}
|
||||
|
||||
// Progress Bar Component
|
||||
function AuditProgressBar({ statistics, lang }: { statistics: AuditStatistics; lang: Language }) {
|
||||
const segments = [
|
||||
{ key: 'compliant', count: statistics.compliant, color: 'bg-green-500', label: 'Konform' },
|
||||
{ key: 'compliant_with_notes', count: statistics.compliant_with_notes, color: 'bg-yellow-500', label: 'Mit Anm.' },
|
||||
{ key: 'non_compliant', count: statistics.non_compliant, color: 'bg-red-500', label: 'Nicht konform' },
|
||||
{ key: 'not_applicable', count: statistics.not_applicable, color: 'bg-slate-300', label: 'N/A' },
|
||||
{ key: 'pending', count: statistics.pending, color: 'bg-slate-100', label: 'Offen' },
|
||||
]
|
||||
|
||||
return (
|
||||
<div className="bg-white rounded-lg border border-slate-200 p-4 mb-6">
|
||||
<div className="flex items-center justify-between mb-3">
|
||||
<h3 className="font-semibold text-slate-900">Fortschritt</h3>
|
||||
<span className="text-2xl font-bold text-primary-600">
|
||||
{Math.round(statistics.completion_percentage)}%
|
||||
</span>
|
||||
</div>
|
||||
|
||||
{/* Stacked Progress Bar */}
|
||||
<div className="h-4 bg-slate-100 rounded-full overflow-hidden flex">
|
||||
{segments.map(seg => (
|
||||
seg.count > 0 && (
|
||||
<div
|
||||
key={seg.key}
|
||||
className={`${seg.color} transition-all`}
|
||||
style={{ width: `${(seg.count / statistics.total) * 100}%` }}
|
||||
title={`${seg.label}: ${seg.count}`}
|
||||
/>
|
||||
)
|
||||
))}
|
||||
</div>
|
||||
|
||||
{/* Legend */}
|
||||
<div className="flex flex-wrap gap-4 mt-3">
|
||||
{segments.map(seg => (
|
||||
<div key={seg.key} className="flex items-center gap-1.5 text-sm">
|
||||
<span className={`w-3 h-3 rounded ${seg.color}`} />
|
||||
<span className="text-slate-600">{seg.label}:</span>
|
||||
<span className="font-medium text-slate-900">{seg.count}</span>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
// Checklist Item Row Component
|
||||
function ChecklistItemRow({
|
||||
item,
|
||||
index,
|
||||
onSignOff,
|
||||
lang,
|
||||
}: {
|
||||
item: AuditChecklistItem
|
||||
index: number
|
||||
onSignOff: () => void
|
||||
lang: Language
|
||||
}) {
|
||||
const status = RESULT_STATUS[item.current_result]
|
||||
|
||||
return (
|
||||
<div className="p-4 hover:bg-slate-50 transition-colors">
|
||||
<div className="flex items-start gap-4">
|
||||
{/* Status Icon */}
|
||||
<div className={`w-8 h-8 rounded-full flex items-center justify-center text-lg ${status.color}`}>
|
||||
{status.icon}
|
||||
</div>
|
||||
|
||||
{/* Content */}
|
||||
<div className="flex-1 min-w-0">
|
||||
<div className="flex items-center gap-2 mb-1">
|
||||
<span className="text-xs font-medium text-slate-500">{index}.</span>
|
||||
<span className="font-mono text-sm text-primary-600">{item.regulation_code}</span>
|
||||
<span className="font-mono text-sm text-slate-700">{item.article}</span>
|
||||
{item.is_signed && (
|
||||
<span className="px-1.5 py-0.5 text-xs bg-green-100 text-green-700 rounded flex items-center gap-1">
|
||||
<svg className="w-3 h-3" fill="none" stroke="currentColor" viewBox="0 0 24 24">
|
||||
<path strokeLinecap="round" strokeLinejoin="round" strokeWidth={2} d="M9 12l2 2 4-4m5.618-4.016A11.955 11.955 0 0112 2.944a11.955 11.955 0 01-8.618 3.04A12.02 12.02 0 003 9c0 5.591 3.824 10.29 9 11.622 5.176-1.332 9-6.03 9-11.622 0-1.042-.133-2.052-.382-3.016z" />
|
||||
</svg>
|
||||
Signiert
|
||||
</span>
|
||||
)}
|
||||
</div>
|
||||
<h4 className="text-sm font-medium text-slate-900 mb-1">{item.title}</h4>
|
||||
{item.notes && (
|
||||
<p className="text-xs text-slate-500 italic mb-2">"{item.notes}"</p>
|
||||
)}
|
||||
<div className="flex items-center gap-4 text-xs text-slate-500">
|
||||
<span>{item.controls_mapped} Controls</span>
|
||||
<span>{item.evidence_count} Nachweise</span>
|
||||
{item.signed_at && (
|
||||
<span>Signiert: {new Date(item.signed_at).toLocaleDateString('de-DE')}</span>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Actions */}
|
||||
<div className="flex items-center gap-2">
|
||||
<span className={`px-2 py-1 text-xs font-medium rounded ${status.color}`}>
|
||||
{lang === 'de' ? status.label : status.labelEn}
|
||||
</span>
|
||||
<button
|
||||
onClick={onSignOff}
|
||||
className="p-2 text-slate-400 hover:text-primary-600 hover:bg-primary-50 rounded-lg transition-colors"
|
||||
title="Sign-off"
|
||||
>
|
||||
<svg className="w-5 h-5" fill="none" stroke="currentColor" viewBox="0 0 24 24">
|
||||
<path strokeLinecap="round" strokeLinejoin="round" strokeWidth={2} d="M9 12l2 2 4-4m6 2a9 9 0 11-18 0 9 9 0 0118 0z" />
|
||||
</svg>
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
// Create Session Modal
|
||||
function CreateSessionModal({
|
||||
regulations,
|
||||
onClose,
|
||||
onCreate,
|
||||
}: {
|
||||
regulations: Regulation[]
|
||||
onClose: () => void
|
||||
onCreate: (data: { name: string; description?: string; auditor_name: string; auditor_email?: string; regulation_codes?: string[] }) => void
|
||||
}) {
|
||||
const [name, setName] = useState('')
|
||||
const [description, setDescription] = useState('')
|
||||
const [auditorName, setAuditorName] = useState('')
|
||||
const [auditorEmail, setAuditorEmail] = useState('')
|
||||
const [selectedRegs, setSelectedRegs] = useState<string[]>([])
|
||||
const [submitting, setSubmitting] = useState(false)
|
||||
|
||||
const handleSubmit = async (e: React.FormEvent) => {
|
||||
e.preventDefault()
|
||||
if (!name || !auditorName) return
|
||||
|
||||
setSubmitting(true)
|
||||
await onCreate({
|
||||
name,
|
||||
description: description || undefined,
|
||||
auditor_name: auditorName,
|
||||
auditor_email: auditorEmail || undefined,
|
||||
regulation_codes: selectedRegs.length > 0 ? selectedRegs : undefined,
|
||||
})
|
||||
setSubmitting(false)
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="fixed inset-0 bg-black/50 flex items-center justify-center z-50">
|
||||
<div className="bg-white rounded-xl shadow-xl max-w-lg w-full mx-4 overflow-hidden">
|
||||
<div className="px-6 py-4 border-b border-slate-200 flex items-center justify-between">
|
||||
<h2 className="text-lg font-semibold text-slate-900">Neue Audit-Session</h2>
|
||||
<button onClick={onClose} className="text-slate-400 hover:text-slate-600">
|
||||
<svg className="w-6 h-6" fill="none" stroke="currentColor" viewBox="0 0 24 24">
|
||||
<path strokeLinecap="round" strokeLinejoin="round" strokeWidth={2} d="M6 18L18 6M6 6l12 12" />
|
||||
</svg>
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<form onSubmit={handleSubmit} className="p-6 space-y-4">
|
||||
<div>
|
||||
<label className="block text-sm font-medium text-slate-700 mb-1">
|
||||
Name der Pruefung *
|
||||
</label>
|
||||
<input
|
||||
type="text"
|
||||
value={name}
|
||||
onChange={(e) => setName(e.target.value)}
|
||||
placeholder="z.B. Q1 2026 Compliance Audit"
|
||||
className="w-full px-3 py-2 border border-slate-300 rounded-lg focus:ring-2 focus:ring-primary-500"
|
||||
required
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div>
|
||||
<label className="block text-sm font-medium text-slate-700 mb-1">
|
||||
Beschreibung
|
||||
</label>
|
||||
<textarea
|
||||
value={description}
|
||||
onChange={(e) => setDescription(e.target.value)}
|
||||
placeholder="Optionale Beschreibung..."
|
||||
className="w-full px-3 py-2 border border-slate-300 rounded-lg"
|
||||
rows={2}
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div className="grid grid-cols-2 gap-4">
|
||||
<div>
|
||||
<label className="block text-sm font-medium text-slate-700 mb-1">
|
||||
Auditor Name *
|
||||
</label>
|
||||
<input
|
||||
type="text"
|
||||
value={auditorName}
|
||||
onChange={(e) => setAuditorName(e.target.value)}
|
||||
placeholder="Dr. Max Mustermann"
|
||||
className="w-full px-3 py-2 border border-slate-300 rounded-lg"
|
||||
required
|
||||
/>
|
||||
</div>
|
||||
<div>
|
||||
<label className="block text-sm font-medium text-slate-700 mb-1">
|
||||
E-Mail
|
||||
</label>
|
||||
<input
|
||||
type="email"
|
||||
value={auditorEmail}
|
||||
onChange={(e) => setAuditorEmail(e.target.value)}
|
||||
placeholder="auditor@example.com"
|
||||
className="w-full px-3 py-2 border border-slate-300 rounded-lg"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div>
|
||||
<label className="block text-sm font-medium text-slate-700 mb-1">
|
||||
Verordnungen (optional - leer = alle)
|
||||
</label>
|
||||
<div className="flex flex-wrap gap-2 p-2 border border-slate-300 rounded-lg max-h-32 overflow-y-auto">
|
||||
{regulations.map(reg => (
|
||||
<label key={reg.code} className="flex items-center gap-1.5">
|
||||
<input
|
||||
type="checkbox"
|
||||
checked={selectedRegs.includes(reg.code)}
|
||||
onChange={(e) => {
|
||||
if (e.target.checked) {
|
||||
setSelectedRegs([...selectedRegs, reg.code])
|
||||
} else {
|
||||
setSelectedRegs(selectedRegs.filter(r => r !== reg.code))
|
||||
}
|
||||
}}
|
||||
className="rounded"
|
||||
/>
|
||||
<span className="text-sm text-slate-700">{reg.code}</span>
|
||||
</label>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="flex justify-end gap-3 pt-4">
|
||||
<button
|
||||
type="button"
|
||||
onClick={onClose}
|
||||
className="px-4 py-2 text-slate-600 hover:text-slate-800"
|
||||
>
|
||||
Abbrechen
|
||||
</button>
|
||||
<button
|
||||
type="submit"
|
||||
disabled={!name || !auditorName || submitting}
|
||||
className="px-4 py-2 bg-primary-600 text-white rounded-lg hover:bg-primary-700 disabled:opacity-50"
|
||||
>
|
||||
{submitting ? 'Erstelle...' : 'Session erstellen'}
|
||||
</button>
|
||||
</div>
|
||||
</form>
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
// Sign-off Modal
|
||||
function SignOffModal({
|
||||
item,
|
||||
auditorName,
|
||||
onClose,
|
||||
onSubmit,
|
||||
}: {
|
||||
item: AuditChecklistItem
|
||||
auditorName: string
|
||||
onClose: () => void
|
||||
onSubmit: (result: AuditChecklistItem['current_result'], notes: string, sign: boolean) => void
|
||||
}) {
|
||||
const [result, setResult] = useState<AuditChecklistItem['current_result']>(item.current_result)
|
||||
const [notes, setNotes] = useState(item.notes || '')
|
||||
const [sign, setSign] = useState(false)
|
||||
const [submitting, setSubmitting] = useState(false)
|
||||
|
||||
const handleSubmit = async (e: React.FormEvent) => {
|
||||
e.preventDefault()
|
||||
setSubmitting(true)
|
||||
await onSubmit(result, notes, sign)
|
||||
setSubmitting(false)
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="fixed inset-0 bg-black/50 flex items-center justify-center z-50">
|
||||
<div className="bg-white rounded-xl shadow-xl max-w-lg w-full mx-4 overflow-hidden">
|
||||
<div className="px-6 py-4 border-b border-slate-200">
|
||||
<h2 className="text-lg font-semibold text-slate-900">Auditor Sign-off</h2>
|
||||
<p className="text-sm text-slate-500 mt-1">
|
||||
{item.regulation_code} {item.article} - {item.title}
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<form onSubmit={handleSubmit} className="p-6 space-y-4">
|
||||
<div>
|
||||
<label className="block text-sm font-medium text-slate-700 mb-2">
|
||||
Pruefungsergebnis
|
||||
</label>
|
||||
<div className="space-y-2">
|
||||
{Object.entries(RESULT_STATUS).map(([key, { label, color }]) => (
|
||||
<label key={key} className="flex items-center gap-3 p-2 rounded-lg hover:bg-slate-50 cursor-pointer">
|
||||
<input
|
||||
type="radio"
|
||||
name="result"
|
||||
value={key}
|
||||
checked={result === key}
|
||||
onChange={() => setResult(key as typeof result)}
|
||||
className="w-4 h-4"
|
||||
/>
|
||||
<span className={`px-2 py-0.5 text-sm font-medium rounded ${color}`}>
|
||||
{label}
|
||||
</span>
|
||||
</label>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div>
|
||||
<label className="block text-sm font-medium text-slate-700 mb-1">
|
||||
Anmerkungen
|
||||
</label>
|
||||
<textarea
|
||||
value={notes}
|
||||
onChange={(e) => setNotes(e.target.value)}
|
||||
placeholder="Optionale Anmerkungen zur Pruefung..."
|
||||
className="w-full px-3 py-2 border border-slate-300 rounded-lg"
|
||||
rows={3}
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div className="p-3 bg-amber-50 border border-amber-200 rounded-lg">
|
||||
<label className="flex items-start gap-3 cursor-pointer">
|
||||
<input
|
||||
type="checkbox"
|
||||
checked={sign}
|
||||
onChange={(e) => setSign(e.target.checked)}
|
||||
className="mt-1 rounded"
|
||||
/>
|
||||
<div>
|
||||
<span className="font-medium text-amber-800">Digital signieren</span>
|
||||
<p className="text-sm text-amber-700 mt-0.5">
|
||||
Erstellt eine SHA-256 Signatur des Ergebnisses. Diese Aktion kann nicht rueckgaengig gemacht werden.
|
||||
</p>
|
||||
</div>
|
||||
</label>
|
||||
</div>
|
||||
|
||||
<div className="text-sm text-slate-500 pt-2 border-t border-slate-100">
|
||||
<p>Auditor: <span className="font-medium text-slate-700">{auditorName}</span></p>
|
||||
<p>Datum: <span className="font-medium text-slate-700">{new Date().toLocaleString('de-DE')}</span></p>
|
||||
</div>
|
||||
|
||||
<div className="flex justify-end gap-3 pt-4">
|
||||
<button
|
||||
type="button"
|
||||
onClick={onClose}
|
||||
className="px-4 py-2 text-slate-600 hover:text-slate-800"
|
||||
>
|
||||
Abbrechen
|
||||
</button>
|
||||
<button
|
||||
type="submit"
|
||||
disabled={submitting}
|
||||
className={`px-4 py-2 rounded-lg disabled:opacity-50 ${
|
||||
sign
|
||||
? 'bg-amber-600 text-white hover:bg-amber-700'
|
||||
: 'bg-primary-600 text-white hover:bg-primary-700'
|
||||
}`}
|
||||
>
|
||||
{submitting ? 'Speichere...' : sign ? 'Signieren & Speichern' : 'Speichern'}
|
||||
</button>
|
||||
</div>
|
||||
</form>
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
@@ -0,0 +1,91 @@
|
||||
'use client'
|
||||
|
||||
import {
|
||||
ServiceModule, RiskAssessment,
|
||||
SERVICE_TYPE_CONFIG, CRITICALITY_CONFIG, RELEVANCE_CONFIG,
|
||||
} from './types'
|
||||
|
||||
interface ModuleDetailPanelProps {
|
||||
module: ServiceModule;
|
||||
loadingDetail: boolean;
|
||||
loadingRisk: boolean;
|
||||
showRiskPanel: boolean;
|
||||
riskAssessment: RiskAssessment | null;
|
||||
onClose: () => void;
|
||||
onAssessRisk: (moduleId: string) => void;
|
||||
onCloseRisk: () => void;
|
||||
}
|
||||
|
||||
export default function ModuleDetailPanel({
|
||||
module, loadingDetail, loadingRisk, showRiskPanel, riskAssessment,
|
||||
onClose, onAssessRisk, onCloseRisk,
|
||||
}: ModuleDetailPanelProps) {
|
||||
return (
|
||||
<div className="w-96 bg-white rounded-lg shadow border sticky top-6 h-fit">
|
||||
<div className={`px-4 py-3 border-b ${SERVICE_TYPE_CONFIG[module.service_type]?.bgColor || 'bg-gray-100'}`}>
|
||||
<div className="flex items-center justify-between">
|
||||
<span className="text-lg">{SERVICE_TYPE_CONFIG[module.service_type]?.icon || '📁'}</span>
|
||||
<button onClick={onClose} className="text-gray-400 hover:text-gray-600">✕</button>
|
||||
</div>
|
||||
<h3 className="font-bold text-lg mt-2">{module.display_name}</h3>
|
||||
<div className="text-sm text-gray-600">{module.name}</div>
|
||||
</div>
|
||||
|
||||
{loadingDetail ? (
|
||||
<div className="p-4 text-center text-gray-500">Lade Details...</div>
|
||||
) : (
|
||||
<div className="p-4 space-y-4">
|
||||
{module.description && (<div><div className="text-xs text-gray-500 uppercase mb-1">Beschreibung</div><div className="text-sm text-gray-700">{module.description}</div></div>)}
|
||||
<div className="grid grid-cols-2 gap-2 text-sm">
|
||||
{module.port && (<div><span className="text-gray-500">Port:</span><span className="ml-1 font-mono">{module.port}</span></div>)}
|
||||
<div><span className="text-gray-500">Criticality:</span><span className={`ml-1 px-1.5 py-0.5 rounded text-xs ${CRITICALITY_CONFIG[module.criticality]?.bgColor || ''} ${CRITICALITY_CONFIG[module.criticality]?.color || ''}`}>{module.criticality}</span></div>
|
||||
</div>
|
||||
<div><div className="text-xs text-gray-500 uppercase mb-1">Tech Stack</div><div className="flex flex-wrap gap-1">{module.technology_stack.map((tech, i) => (<span key={i} className="px-2 py-0.5 bg-gray-100 text-gray-700 text-xs rounded">{tech}</span>))}</div></div>
|
||||
{module.data_categories.length > 0 && (<div><div className="text-xs text-gray-500 uppercase mb-1">Daten-Kategorien</div><div className="flex flex-wrap gap-1">{module.data_categories.map((cat, i) => (<span key={i} className="px-2 py-0.5 bg-blue-50 text-blue-700 text-xs rounded">{cat}</span>))}</div></div>)}
|
||||
<div className="flex flex-wrap gap-2">
|
||||
{module.processes_pii && (<span className="px-2 py-1 bg-purple-100 text-purple-700 text-xs rounded">Verarbeitet PII</span>)}
|
||||
{module.ai_components && (<span className="px-2 py-1 bg-pink-100 text-pink-700 text-xs rounded">AI-Komponenten</span>)}
|
||||
{module.processes_health_data && (<span className="px-2 py-1 bg-red-100 text-red-700 text-xs rounded">Gesundheitsdaten</span>)}
|
||||
</div>
|
||||
{module.regulations && module.regulations.length > 0 && (
|
||||
<div>
|
||||
<div className="text-xs text-gray-500 uppercase mb-2">Applicable Regulations ({module.regulations.length})</div>
|
||||
<div className="space-y-2">
|
||||
{module.regulations.map((reg, i) => (
|
||||
<div key={i} className="p-2 bg-gray-50 rounded text-sm">
|
||||
<div className="flex justify-between items-start"><span className="font-medium">{reg.code}</span><span className={`px-1.5 py-0.5 rounded text-xs ${RELEVANCE_CONFIG[reg.relevance_level]?.bgColor || 'bg-gray-100'} ${RELEVANCE_CONFIG[reg.relevance_level]?.color || 'text-gray-700'}`}>{reg.relevance_level}</span></div>
|
||||
<div className="text-gray-500 text-xs">{reg.name}</div>
|
||||
{reg.notes && (<div className="text-gray-600 text-xs mt-1 italic">{reg.notes}</div>)}
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
{module.owner_team && (<div><div className="text-xs text-gray-500 uppercase mb-1">Owner</div><div className="text-sm text-gray-700">{module.owner_team}</div></div>)}
|
||||
{module.repository_path && (<div><div className="text-xs text-gray-500 uppercase mb-1">Repository</div><code className="text-xs bg-gray-100 px-2 py-1 rounded block">{module.repository_path}</code></div>)}
|
||||
<div className="pt-2 border-t">
|
||||
<button onClick={() => onAssessRisk(module.id)} disabled={loadingRisk} className="w-full px-4 py-2 bg-gradient-to-r from-purple-600 to-pink-600 text-white rounded-lg hover:from-purple-700 hover:to-pink-700 transition disabled:opacity-50 flex items-center justify-center gap-2">
|
||||
{loadingRisk ? (<><span className="animate-spin">⚙️</span>AI analysiert...</>) : (<><span>🤖</span>AI Risikobewertung</>)}
|
||||
</button>
|
||||
</div>
|
||||
{showRiskPanel && (
|
||||
<div className="mt-4 p-4 bg-gradient-to-br from-purple-50 to-pink-50 rounded-lg border border-purple-200">
|
||||
<div className="flex justify-between items-center mb-3">
|
||||
<h4 className="font-semibold text-purple-900 flex items-center gap-2"><span>🤖</span> AI Risikobewertung</h4>
|
||||
<button onClick={onCloseRisk} className="text-purple-400 hover:text-purple-600">✕</button>
|
||||
</div>
|
||||
{loadingRisk ? (<div className="text-center py-4 text-purple-600"><div className="animate-pulse">Analysiere Compliance-Risiken...</div></div>) : riskAssessment ? (
|
||||
<div className="space-y-3">
|
||||
<div className="flex items-center gap-2"><span className="text-sm text-gray-600">Gesamtrisiko:</span><span className={`px-2 py-1 rounded text-sm font-medium ${riskAssessment.overall_risk === 'critical' ? 'bg-red-100 text-red-700' : riskAssessment.overall_risk === 'high' ? 'bg-orange-100 text-orange-700' : riskAssessment.overall_risk === 'medium' ? 'bg-yellow-100 text-yellow-700' : 'bg-green-100 text-green-700'}`}>{riskAssessment.overall_risk.toUpperCase()}</span><span className="text-xs text-gray-400">({Math.round(riskAssessment.confidence_score * 100)}% Konfidenz)</span></div>
|
||||
{riskAssessment.risk_factors.length > 0 && (<div><div className="text-xs text-gray-500 uppercase mb-1">Risikofaktoren</div><div className="space-y-1">{riskAssessment.risk_factors.map((factor, i) => (<div key={i} className="flex items-center justify-between text-sm bg-white/50 rounded px-2 py-1"><span className="text-gray-700">{factor.factor}</span><span className={`text-xs px-1.5 py-0.5 rounded ${factor.severity === 'critical' || factor.severity === 'high' ? 'bg-red-100 text-red-600' : 'bg-yellow-100 text-yellow-600'}`}>{factor.severity}</span></div>))}</div></div>)}
|
||||
{riskAssessment.compliance_gaps.length > 0 && (<div><div className="text-xs text-gray-500 uppercase mb-1">Compliance-Luecken</div><ul className="text-sm text-gray-700 space-y-1">{riskAssessment.compliance_gaps.map((gap, i) => (<li key={i} className="flex items-start gap-1"><span className="text-red-500">⚠</span><span>{gap}</span></li>))}</ul></div>)}
|
||||
{riskAssessment.recommendations.length > 0 && (<div><div className="text-xs text-gray-500 uppercase mb-1">Empfehlungen</div><ul className="text-sm text-gray-700 space-y-1">{riskAssessment.recommendations.map((rec, i) => (<li key={i} className="flex items-start gap-1"><span className="text-green-500">→</span><span>{rec}</span></li>))}</ul></div>)}
|
||||
</div>
|
||||
) : (<div className="text-center py-4 text-gray-500 text-sm">Klicken Sie auf "AI Risikobewertung" um eine Analyse zu starten.</div>)}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
@@ -0,0 +1,71 @@
|
||||
export interface ServiceModule {
|
||||
id: string;
|
||||
name: string;
|
||||
display_name: string;
|
||||
description: string | null;
|
||||
service_type: string;
|
||||
port: number | null;
|
||||
technology_stack: string[];
|
||||
repository_path: string | null;
|
||||
docker_image: string | null;
|
||||
data_categories: string[];
|
||||
processes_pii: boolean;
|
||||
processes_health_data: boolean;
|
||||
ai_components: boolean;
|
||||
criticality: string;
|
||||
owner_team: string | null;
|
||||
is_active: boolean;
|
||||
compliance_score: number | null;
|
||||
regulation_count: number;
|
||||
risk_count: number;
|
||||
created_at: string;
|
||||
regulations?: Array<{
|
||||
code: string;
|
||||
name: string;
|
||||
relevance_level: string;
|
||||
notes: string | null;
|
||||
}>;
|
||||
}
|
||||
|
||||
export interface ModulesOverview {
|
||||
total_modules: number;
|
||||
modules_by_type: Record<string, number>;
|
||||
modules_by_criticality: Record<string, number>;
|
||||
modules_processing_pii: number;
|
||||
modules_with_ai: number;
|
||||
average_compliance_score: number | null;
|
||||
regulations_coverage: Record<string, number>;
|
||||
}
|
||||
|
||||
export interface RiskAssessment {
|
||||
overall_risk: string;
|
||||
risk_factors: Array<{ factor: string; severity: string; likelihood: string }>;
|
||||
recommendations: string[];
|
||||
compliance_gaps: string[];
|
||||
confidence_score: number;
|
||||
}
|
||||
|
||||
export const SERVICE_TYPE_CONFIG: Record<string, { icon: string; color: string; bgColor: string }> = {
|
||||
backend: { icon: '⚙️', color: 'text-blue-700', bgColor: 'bg-blue-100' },
|
||||
database: { icon: '🗄️', color: 'text-purple-700', bgColor: 'bg-purple-100' },
|
||||
ai: { icon: '🤖', color: 'text-pink-700', bgColor: 'bg-pink-100' },
|
||||
communication: { icon: '💬', color: 'text-green-700', bgColor: 'bg-green-100' },
|
||||
storage: { icon: '📦', color: 'text-orange-700', bgColor: 'bg-orange-100' },
|
||||
infrastructure: { icon: '🌐', color: 'text-gray-700', bgColor: 'bg-gray-100' },
|
||||
monitoring: { icon: '📊', color: 'text-cyan-700', bgColor: 'bg-cyan-100' },
|
||||
security: { icon: '🔒', color: 'text-red-700', bgColor: 'bg-red-100' },
|
||||
};
|
||||
|
||||
export const CRITICALITY_CONFIG: Record<string, { color: string; bgColor: string }> = {
|
||||
critical: { color: 'text-red-700', bgColor: 'bg-red-100' },
|
||||
high: { color: 'text-orange-700', bgColor: 'bg-orange-100' },
|
||||
medium: { color: 'text-yellow-700', bgColor: 'bg-yellow-100' },
|
||||
low: { color: 'text-green-700', bgColor: 'bg-green-100' },
|
||||
};
|
||||
|
||||
export const RELEVANCE_CONFIG: Record<string, { color: string; bgColor: string }> = {
|
||||
critical: { color: 'text-red-700', bgColor: 'bg-red-100' },
|
||||
high: { color: 'text-orange-700', bgColor: 'bg-orange-100' },
|
||||
medium: { color: 'text-yellow-700', bgColor: 'bg-yellow-100' },
|
||||
low: { color: 'text-green-700', bgColor: 'bg-green-100' },
|
||||
};
|
||||
@@ -0,0 +1,110 @@
|
||||
'use client'
|
||||
|
||||
import { useState, useEffect } from 'react'
|
||||
import { ServiceModule, ModulesOverview, RiskAssessment } from './types'
|
||||
|
||||
const BACKEND_URL = process.env.NEXT_PUBLIC_BACKEND_URL || 'http://localhost:8000'
|
||||
const API_BASE = `${BACKEND_URL}/api/v1/compliance`
|
||||
|
||||
export function useModulesPage() {
|
||||
const [modules, setModules] = useState<ServiceModule[]>([])
|
||||
const [overview, setOverview] = useState<ModulesOverview | null>(null)
|
||||
const [loading, setLoading] = useState(true)
|
||||
const [error, setError] = useState<string | null>(null)
|
||||
|
||||
const [typeFilter, setTypeFilter] = useState<string>('all')
|
||||
const [criticalityFilter, setCriticalityFilter] = useState<string>('all')
|
||||
const [piiFilter, setPiiFilter] = useState<boolean | null>(null)
|
||||
const [aiFilter, setAiFilter] = useState<boolean | null>(null)
|
||||
const [searchTerm, setSearchTerm] = useState('')
|
||||
|
||||
const [selectedModule, setSelectedModule] = useState<ServiceModule | null>(null)
|
||||
const [loadingDetail, setLoadingDetail] = useState(false)
|
||||
|
||||
const [riskAssessment, setRiskAssessment] = useState<RiskAssessment | null>(null)
|
||||
const [loadingRisk, setLoadingRisk] = useState(false)
|
||||
const [showRiskPanel, setShowRiskPanel] = useState(false)
|
||||
|
||||
useEffect(() => { fetchModules(); fetchOverview() }, [])
|
||||
|
||||
const fetchModules = async () => {
|
||||
try {
|
||||
setLoading(true)
|
||||
const params = new URLSearchParams()
|
||||
if (typeFilter !== 'all') params.append('service_type', typeFilter)
|
||||
if (criticalityFilter !== 'all') params.append('criticality', criticalityFilter)
|
||||
if (piiFilter !== null) params.append('processes_pii', String(piiFilter))
|
||||
if (aiFilter !== null) params.append('ai_components', String(aiFilter))
|
||||
const url = `${API_BASE}/modules${params.toString() ? '?' + params.toString() : ''}`
|
||||
const res = await fetch(url)
|
||||
if (!res.ok) throw new Error('Failed to fetch modules')
|
||||
const data = await res.json()
|
||||
setModules(data.modules || [])
|
||||
} catch (err) {
|
||||
setError(err instanceof Error ? err.message : 'Unknown error')
|
||||
} finally { setLoading(false) }
|
||||
}
|
||||
|
||||
const fetchOverview = async () => {
|
||||
try {
|
||||
const res = await fetch(`${API_BASE}/modules/overview`)
|
||||
if (!res.ok) throw new Error('Failed to fetch overview')
|
||||
const data = await res.json()
|
||||
setOverview(data)
|
||||
} catch (err) { console.error('Failed to fetch overview:', err) }
|
||||
}
|
||||
|
||||
const fetchModuleDetail = async (moduleId: string) => {
|
||||
try {
|
||||
setLoadingDetail(true)
|
||||
const res = await fetch(`${API_BASE}/modules/${moduleId}`)
|
||||
if (!res.ok) throw new Error('Failed to fetch module details')
|
||||
const data = await res.json()
|
||||
setSelectedModule(data)
|
||||
} catch (err) { console.error('Failed to fetch module details:', err) }
|
||||
finally { setLoadingDetail(false) }
|
||||
}
|
||||
|
||||
const seedModules = async (force: boolean = false) => {
|
||||
try {
|
||||
const res = await fetch(`${API_BASE}/modules/seed`, { method: 'POST', headers: { 'Content-Type': 'application/json' }, body: JSON.stringify({ force }) })
|
||||
if (!res.ok) throw new Error('Failed to seed modules')
|
||||
const data = await res.json()
|
||||
alert(`Seeded ${data.modules_created} modules with ${data.mappings_created} regulation mappings`)
|
||||
fetchModules(); fetchOverview()
|
||||
} catch (err) { alert('Failed to seed modules: ' + (err instanceof Error ? err.message : 'Unknown error')) }
|
||||
}
|
||||
|
||||
const assessModuleRisk = async (moduleId: string) => {
|
||||
setLoadingRisk(true); setShowRiskPanel(true); setRiskAssessment(null)
|
||||
try {
|
||||
const res = await fetch(`${API_BASE}/ai/assess-risk`, { method: 'POST', headers: { 'Content-Type': 'application/json' }, body: JSON.stringify({ module_id: moduleId }) })
|
||||
if (res.ok) { const data = await res.json(); setRiskAssessment(data) }
|
||||
else { alert('AI-Risikobewertung fehlgeschlagen') }
|
||||
} catch (err) { alert('Netzwerkfehler bei AI-Risikobewertung') }
|
||||
finally { setLoadingRisk(false) }
|
||||
}
|
||||
|
||||
const filteredModules = modules.filter(m => {
|
||||
if (!searchTerm) return true
|
||||
const term = searchTerm.toLowerCase()
|
||||
return m.name.toLowerCase().includes(term) || m.display_name.toLowerCase().includes(term) || (m.description && m.description.toLowerCase().includes(term)) || m.technology_stack.some(t => t.toLowerCase().includes(term))
|
||||
})
|
||||
|
||||
const modulesByType = filteredModules.reduce((acc, m) => {
|
||||
const type = m.service_type || 'unknown'
|
||||
if (!acc[type]) acc[type] = []
|
||||
acc[type].push(m)
|
||||
return acc
|
||||
}, {} as Record<string, ServiceModule[]>)
|
||||
|
||||
return {
|
||||
modules, overview, loading, error,
|
||||
typeFilter, setTypeFilter, criticalityFilter, setCriticalityFilter,
|
||||
piiFilter, setPiiFilter, aiFilter, setAiFilter, searchTerm, setSearchTerm,
|
||||
selectedModule, setSelectedModule, loadingDetail,
|
||||
riskAssessment, loadingRisk, showRiskPanel, setShowRiskPanel,
|
||||
filteredModules, modulesByType,
|
||||
fetchModules, fetchModuleDetail, seedModules, assessModuleRisk,
|
||||
}
|
||||
}
|
||||
@@ -1,214 +1,11 @@
|
||||
'use client';
|
||||
|
||||
import { useState, useEffect } from 'react';
|
||||
|
||||
// Types
|
||||
interface ServiceModule {
|
||||
id: string;
|
||||
name: string;
|
||||
display_name: string;
|
||||
description: string | null;
|
||||
service_type: string;
|
||||
port: number | null;
|
||||
technology_stack: string[];
|
||||
repository_path: string | null;
|
||||
docker_image: string | null;
|
||||
data_categories: string[];
|
||||
processes_pii: boolean;
|
||||
processes_health_data: boolean;
|
||||
ai_components: boolean;
|
||||
criticality: string;
|
||||
owner_team: string | null;
|
||||
is_active: boolean;
|
||||
compliance_score: number | null;
|
||||
regulation_count: number;
|
||||
risk_count: number;
|
||||
created_at: string;
|
||||
regulations?: Array<{
|
||||
code: string;
|
||||
name: string;
|
||||
relevance_level: string;
|
||||
notes: string | null;
|
||||
}>;
|
||||
}
|
||||
|
||||
interface ModulesOverview {
|
||||
total_modules: number;
|
||||
modules_by_type: Record<string, number>;
|
||||
modules_by_criticality: Record<string, number>;
|
||||
modules_processing_pii: number;
|
||||
modules_with_ai: number;
|
||||
average_compliance_score: number | null;
|
||||
regulations_coverage: Record<string, number>;
|
||||
}
|
||||
|
||||
// Service Type Icons and Colors
|
||||
const SERVICE_TYPE_CONFIG: Record<string, { icon: string; color: string; bgColor: string }> = {
|
||||
backend: { icon: '⚙️', color: 'text-blue-700', bgColor: 'bg-blue-100' },
|
||||
database: { icon: '🗄️', color: 'text-purple-700', bgColor: 'bg-purple-100' },
|
||||
ai: { icon: '🤖', color: 'text-pink-700', bgColor: 'bg-pink-100' },
|
||||
communication: { icon: '💬', color: 'text-green-700', bgColor: 'bg-green-100' },
|
||||
storage: { icon: '📦', color: 'text-orange-700', bgColor: 'bg-orange-100' },
|
||||
infrastructure: { icon: '🌐', color: 'text-gray-700', bgColor: 'bg-gray-100' },
|
||||
monitoring: { icon: '📊', color: 'text-cyan-700', bgColor: 'bg-cyan-100' },
|
||||
security: { icon: '🔒', color: 'text-red-700', bgColor: 'bg-red-100' },
|
||||
};
|
||||
|
||||
const CRITICALITY_CONFIG: Record<string, { color: string; bgColor: string }> = {
|
||||
critical: { color: 'text-red-700', bgColor: 'bg-red-100' },
|
||||
high: { color: 'text-orange-700', bgColor: 'bg-orange-100' },
|
||||
medium: { color: 'text-yellow-700', bgColor: 'bg-yellow-100' },
|
||||
low: { color: 'text-green-700', bgColor: 'bg-green-100' },
|
||||
};
|
||||
|
||||
const RELEVANCE_CONFIG: Record<string, { color: string; bgColor: string }> = {
|
||||
critical: { color: 'text-red-700', bgColor: 'bg-red-100' },
|
||||
high: { color: 'text-orange-700', bgColor: 'bg-orange-100' },
|
||||
medium: { color: 'text-yellow-700', bgColor: 'bg-yellow-100' },
|
||||
low: { color: 'text-green-700', bgColor: 'bg-green-100' },
|
||||
};
|
||||
import { SERVICE_TYPE_CONFIG, CRITICALITY_CONFIG } from './_components/types';
|
||||
import { useModulesPage } from './_components/useModulesPage';
|
||||
import ModuleDetailPanel from './_components/ModuleDetailPanel';
|
||||
|
||||
export default function ModulesPage() {
|
||||
const [modules, setModules] = useState<ServiceModule[]>([]);
|
||||
const [overview, setOverview] = useState<ModulesOverview | null>(null);
|
||||
const [loading, setLoading] = useState(true);
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
|
||||
// Filters
|
||||
const [typeFilter, setTypeFilter] = useState<string>('all');
|
||||
const [criticalityFilter, setCriticalityFilter] = useState<string>('all');
|
||||
const [piiFilter, setPiiFilter] = useState<boolean | null>(null);
|
||||
const [aiFilter, setAiFilter] = useState<boolean | null>(null);
|
||||
const [searchTerm, setSearchTerm] = useState('');
|
||||
|
||||
// Selected module for detail view
|
||||
const [selectedModule, setSelectedModule] = useState<ServiceModule | null>(null);
|
||||
const [loadingDetail, setLoadingDetail] = useState(false);
|
||||
|
||||
// AI Risk Assessment
|
||||
const [riskAssessment, setRiskAssessment] = useState<{
|
||||
overall_risk: string;
|
||||
risk_factors: Array<{ factor: string; severity: string; likelihood: string }>;
|
||||
recommendations: string[];
|
||||
compliance_gaps: string[];
|
||||
confidence_score: number;
|
||||
} | null>(null);
|
||||
const [loadingRisk, setLoadingRisk] = useState(false);
|
||||
const [showRiskPanel, setShowRiskPanel] = useState(false);
|
||||
|
||||
const BACKEND_URL = process.env.NEXT_PUBLIC_BACKEND_URL || 'http://localhost:8000';
|
||||
const API_BASE = `${BACKEND_URL}/api/v1/compliance`;
|
||||
|
||||
useEffect(() => {
|
||||
fetchModules();
|
||||
fetchOverview();
|
||||
}, []);
|
||||
|
||||
const fetchModules = async () => {
|
||||
try {
|
||||
setLoading(true);
|
||||
const params = new URLSearchParams();
|
||||
if (typeFilter !== 'all') params.append('service_type', typeFilter);
|
||||
if (criticalityFilter !== 'all') params.append('criticality', criticalityFilter);
|
||||
if (piiFilter !== null) params.append('processes_pii', String(piiFilter));
|
||||
if (aiFilter !== null) params.append('ai_components', String(aiFilter));
|
||||
|
||||
const url = `${API_BASE}/modules${params.toString() ? '?' + params.toString() : ''}`;
|
||||
const res = await fetch(url);
|
||||
if (!res.ok) throw new Error('Failed to fetch modules');
|
||||
const data = await res.json();
|
||||
setModules(data.modules || []);
|
||||
} catch (err) {
|
||||
setError(err instanceof Error ? err.message : 'Unknown error');
|
||||
} finally {
|
||||
setLoading(false);
|
||||
}
|
||||
};
|
||||
|
||||
const fetchOverview = async () => {
|
||||
try {
|
||||
const res = await fetch(`${API_BASE}/modules/overview`);
|
||||
if (!res.ok) throw new Error('Failed to fetch overview');
|
||||
const data = await res.json();
|
||||
setOverview(data);
|
||||
} catch (err) {
|
||||
console.error('Failed to fetch overview:', err);
|
||||
}
|
||||
};
|
||||
|
||||
const fetchModuleDetail = async (moduleId: string) => {
|
||||
try {
|
||||
setLoadingDetail(true);
|
||||
const res = await fetch(`${API_BASE}/modules/${moduleId}`);
|
||||
if (!res.ok) throw new Error('Failed to fetch module details');
|
||||
const data = await res.json();
|
||||
setSelectedModule(data);
|
||||
} catch (err) {
|
||||
console.error('Failed to fetch module details:', err);
|
||||
} finally {
|
||||
setLoadingDetail(false);
|
||||
}
|
||||
};
|
||||
|
||||
const seedModules = async (force: boolean = false) => {
|
||||
try {
|
||||
const res = await fetch(`${API_BASE}/modules/seed`, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({ force }),
|
||||
});
|
||||
if (!res.ok) throw new Error('Failed to seed modules');
|
||||
const data = await res.json();
|
||||
alert(`Seeded ${data.modules_created} modules with ${data.mappings_created} regulation mappings`);
|
||||
fetchModules();
|
||||
fetchOverview();
|
||||
} catch (err) {
|
||||
alert('Failed to seed modules: ' + (err instanceof Error ? err.message : 'Unknown error'));
|
||||
}
|
||||
};
|
||||
|
||||
const assessModuleRisk = async (moduleId: string) => {
|
||||
setLoadingRisk(true);
|
||||
setShowRiskPanel(true);
|
||||
setRiskAssessment(null);
|
||||
try {
|
||||
const res = await fetch(`${API_BASE}/ai/assess-risk`, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({ module_id: moduleId }),
|
||||
});
|
||||
if (res.ok) {
|
||||
const data = await res.json();
|
||||
setRiskAssessment(data);
|
||||
} else {
|
||||
alert('AI-Risikobewertung fehlgeschlagen');
|
||||
}
|
||||
} catch (err) {
|
||||
alert('Netzwerkfehler bei AI-Risikobewertung');
|
||||
} finally {
|
||||
setLoadingRisk(false);
|
||||
}
|
||||
};
|
||||
|
||||
// Filter modules by search term
|
||||
const filteredModules = modules.filter(m => {
|
||||
if (!searchTerm) return true;
|
||||
const term = searchTerm.toLowerCase();
|
||||
return (
|
||||
m.name.toLowerCase().includes(term) ||
|
||||
m.display_name.toLowerCase().includes(term) ||
|
||||
(m.description && m.description.toLowerCase().includes(term)) ||
|
||||
m.technology_stack.some(t => t.toLowerCase().includes(term))
|
||||
);
|
||||
});
|
||||
|
||||
// Group by type for visualization
|
||||
const modulesByType = filteredModules.reduce((acc, m) => {
|
||||
const type = m.service_type || 'unknown';
|
||||
if (!acc[type]) acc[type] = [];
|
||||
acc[type].push(m);
|
||||
return acc;
|
||||
}, {} as Record<string, ServiceModule[]>);
|
||||
const mp = useModulesPage();
|
||||
|
||||
return (
|
||||
<div className="p-6 space-y-6">
|
||||
@@ -216,233 +13,72 @@ export default function ModulesPage() {
|
||||
<div className="flex justify-between items-center">
|
||||
<div>
|
||||
<h1 className="text-2xl font-bold text-gray-900">Service Module Registry</h1>
|
||||
<p className="text-gray-600 mt-1">
|
||||
Alle {overview?.total_modules || 0} Breakpilot-Services mit Regulation-Mappings
|
||||
</p>
|
||||
<p className="text-gray-600 mt-1">Alle {mp.overview?.total_modules || 0} Breakpilot-Services mit Regulation-Mappings</p>
|
||||
</div>
|
||||
<div className="flex gap-2">
|
||||
<button
|
||||
onClick={() => seedModules(false)}
|
||||
className="px-4 py-2 bg-blue-600 text-white rounded-lg hover:bg-blue-700 transition"
|
||||
>
|
||||
Seed Modules
|
||||
</button>
|
||||
<button
|
||||
onClick={() => seedModules(true)}
|
||||
className="px-4 py-2 bg-gray-600 text-white rounded-lg hover:bg-gray-700 transition"
|
||||
>
|
||||
Force Re-Seed
|
||||
</button>
|
||||
<button onClick={() => mp.seedModules(false)} className="px-4 py-2 bg-blue-600 text-white rounded-lg hover:bg-blue-700 transition">Seed Modules</button>
|
||||
<button onClick={() => mp.seedModules(true)} className="px-4 py-2 bg-gray-600 text-white rounded-lg hover:bg-gray-700 transition">Force Re-Seed</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Overview Stats */}
|
||||
{overview && (
|
||||
{mp.overview && (
|
||||
<div className="grid grid-cols-2 md:grid-cols-4 lg:grid-cols-6 gap-4">
|
||||
<div className="bg-white rounded-lg p-4 shadow border">
|
||||
<div className="text-3xl font-bold text-blue-600">{overview.total_modules}</div>
|
||||
<div className="text-sm text-gray-600">Services</div>
|
||||
</div>
|
||||
<div className="bg-white rounded-lg p-4 shadow border">
|
||||
<div className="text-3xl font-bold text-red-600">{overview.modules_by_criticality?.critical || 0}</div>
|
||||
<div className="text-sm text-gray-600">Critical</div>
|
||||
</div>
|
||||
<div className="bg-white rounded-lg p-4 shadow border">
|
||||
<div className="text-3xl font-bold text-purple-600">{overview.modules_processing_pii}</div>
|
||||
<div className="text-sm text-gray-600">PII-Processing</div>
|
||||
</div>
|
||||
<div className="bg-white rounded-lg p-4 shadow border">
|
||||
<div className="text-3xl font-bold text-pink-600">{overview.modules_with_ai}</div>
|
||||
<div className="text-sm text-gray-600">AI-Komponenten</div>
|
||||
</div>
|
||||
<div className="bg-white rounded-lg p-4 shadow border">
|
||||
<div className="text-3xl font-bold text-green-600">
|
||||
{Object.keys(overview.regulations_coverage || {}).length}
|
||||
</div>
|
||||
<div className="text-sm text-gray-600">Regulations</div>
|
||||
</div>
|
||||
<div className="bg-white rounded-lg p-4 shadow border">
|
||||
<div className="text-3xl font-bold text-cyan-600">
|
||||
{overview.average_compliance_score !== null
|
||||
? `${overview.average_compliance_score}%`
|
||||
: 'N/A'}
|
||||
</div>
|
||||
<div className="text-sm text-gray-600">Avg. Score</div>
|
||||
</div>
|
||||
<div className="bg-white rounded-lg p-4 shadow border"><div className="text-3xl font-bold text-blue-600">{mp.overview.total_modules}</div><div className="text-sm text-gray-600">Services</div></div>
|
||||
<div className="bg-white rounded-lg p-4 shadow border"><div className="text-3xl font-bold text-red-600">{mp.overview.modules_by_criticality?.critical || 0}</div><div className="text-sm text-gray-600">Critical</div></div>
|
||||
<div className="bg-white rounded-lg p-4 shadow border"><div className="text-3xl font-bold text-purple-600">{mp.overview.modules_processing_pii}</div><div className="text-sm text-gray-600">PII-Processing</div></div>
|
||||
<div className="bg-white rounded-lg p-4 shadow border"><div className="text-3xl font-bold text-pink-600">{mp.overview.modules_with_ai}</div><div className="text-sm text-gray-600">AI-Komponenten</div></div>
|
||||
<div className="bg-white rounded-lg p-4 shadow border"><div className="text-3xl font-bold text-green-600">{Object.keys(mp.overview.regulations_coverage || {}).length}</div><div className="text-sm text-gray-600">Regulations</div></div>
|
||||
<div className="bg-white rounded-lg p-4 shadow border"><div className="text-3xl font-bold text-cyan-600">{mp.overview.average_compliance_score !== null ? `${mp.overview.average_compliance_score}%` : 'N/A'}</div><div className="text-sm text-gray-600">Avg. Score</div></div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Filters */}
|
||||
<div className="bg-white rounded-lg p-4 shadow border">
|
||||
<div className="flex flex-wrap gap-4 items-center">
|
||||
<div>
|
||||
<label className="block text-xs text-gray-500 mb-1">Service Type</label>
|
||||
<select
|
||||
value={typeFilter}
|
||||
onChange={(e) => { setTypeFilter(e.target.value); }}
|
||||
className="border rounded px-3 py-2 text-sm"
|
||||
>
|
||||
<option value="all">Alle Typen</option>
|
||||
<option value="backend">Backend</option>
|
||||
<option value="database">Database</option>
|
||||
<option value="ai">AI/ML</option>
|
||||
<option value="communication">Communication</option>
|
||||
<option value="storage">Storage</option>
|
||||
<option value="infrastructure">Infrastructure</option>
|
||||
<option value="monitoring">Monitoring</option>
|
||||
<option value="security">Security</option>
|
||||
</select>
|
||||
</div>
|
||||
<div>
|
||||
<label className="block text-xs text-gray-500 mb-1">Criticality</label>
|
||||
<select
|
||||
value={criticalityFilter}
|
||||
onChange={(e) => { setCriticalityFilter(e.target.value); }}
|
||||
className="border rounded px-3 py-2 text-sm"
|
||||
>
|
||||
<option value="all">Alle</option>
|
||||
<option value="critical">Critical</option>
|
||||
<option value="high">High</option>
|
||||
<option value="medium">Medium</option>
|
||||
<option value="low">Low</option>
|
||||
</select>
|
||||
</div>
|
||||
<div>
|
||||
<label className="block text-xs text-gray-500 mb-1">PII</label>
|
||||
<select
|
||||
value={piiFilter === null ? 'all' : String(piiFilter)}
|
||||
onChange={(e) => {
|
||||
const val = e.target.value;
|
||||
setPiiFilter(val === 'all' ? null : val === 'true');
|
||||
}}
|
||||
className="border rounded px-3 py-2 text-sm"
|
||||
>
|
||||
<option value="all">Alle</option>
|
||||
<option value="true">Verarbeitet PII</option>
|
||||
<option value="false">Keine PII</option>
|
||||
</select>
|
||||
</div>
|
||||
<div>
|
||||
<label className="block text-xs text-gray-500 mb-1">AI</label>
|
||||
<select
|
||||
value={aiFilter === null ? 'all' : String(aiFilter)}
|
||||
onChange={(e) => {
|
||||
const val = e.target.value;
|
||||
setAiFilter(val === 'all' ? null : val === 'true');
|
||||
}}
|
||||
className="border rounded px-3 py-2 text-sm"
|
||||
>
|
||||
<option value="all">Alle</option>
|
||||
<option value="true">Mit AI</option>
|
||||
<option value="false">Ohne AI</option>
|
||||
</select>
|
||||
</div>
|
||||
<div className="flex-1">
|
||||
<label className="block text-xs text-gray-500 mb-1">Suche</label>
|
||||
<input
|
||||
type="text"
|
||||
placeholder="Service, Beschreibung, Technologie..."
|
||||
value={searchTerm}
|
||||
onChange={(e) => setSearchTerm(e.target.value)}
|
||||
className="border rounded px-3 py-2 text-sm w-full"
|
||||
/>
|
||||
</div>
|
||||
<div className="pt-5">
|
||||
<button
|
||||
onClick={fetchModules}
|
||||
className="px-4 py-2 bg-gray-100 rounded hover:bg-gray-200 transition text-sm"
|
||||
>
|
||||
Filter anwenden
|
||||
</button>
|
||||
</div>
|
||||
<div><label className="block text-xs text-gray-500 mb-1">Service Type</label><select value={mp.typeFilter} onChange={(e) => mp.setTypeFilter(e.target.value)} className="border rounded px-3 py-2 text-sm"><option value="all">Alle Typen</option>{['backend','database','ai','communication','storage','infrastructure','monitoring','security'].map(t => (<option key={t} value={t}>{t.charAt(0).toUpperCase()+t.slice(1)}</option>))}</select></div>
|
||||
<div><label className="block text-xs text-gray-500 mb-1">Criticality</label><select value={mp.criticalityFilter} onChange={(e) => mp.setCriticalityFilter(e.target.value)} className="border rounded px-3 py-2 text-sm"><option value="all">Alle</option>{['critical','high','medium','low'].map(c => (<option key={c} value={c}>{c.charAt(0).toUpperCase()+c.slice(1)}</option>))}</select></div>
|
||||
<div><label className="block text-xs text-gray-500 mb-1">PII</label><select value={mp.piiFilter === null ? 'all' : String(mp.piiFilter)} onChange={(e) => { const val = e.target.value; mp.setPiiFilter(val === 'all' ? null : val === 'true') }} className="border rounded px-3 py-2 text-sm"><option value="all">Alle</option><option value="true">Verarbeitet PII</option><option value="false">Keine PII</option></select></div>
|
||||
<div><label className="block text-xs text-gray-500 mb-1">AI</label><select value={mp.aiFilter === null ? 'all' : String(mp.aiFilter)} onChange={(e) => { const val = e.target.value; mp.setAiFilter(val === 'all' ? null : val === 'true') }} className="border rounded px-3 py-2 text-sm"><option value="all">Alle</option><option value="true">Mit AI</option><option value="false">Ohne AI</option></select></div>
|
||||
<div className="flex-1"><label className="block text-xs text-gray-500 mb-1">Suche</label><input type="text" placeholder="Service, Beschreibung, Technologie..." value={mp.searchTerm} onChange={(e) => mp.setSearchTerm(e.target.value)} className="border rounded px-3 py-2 text-sm w-full" /></div>
|
||||
<div className="pt-5"><button onClick={mp.fetchModules} className="px-4 py-2 bg-gray-100 rounded hover:bg-gray-200 transition text-sm">Filter anwenden</button></div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Error */}
|
||||
{error && (
|
||||
<div className="bg-red-50 border border-red-200 text-red-700 p-4 rounded-lg">
|
||||
{error}
|
||||
</div>
|
||||
)}
|
||||
{mp.error && (<div className="bg-red-50 border border-red-200 text-red-700 p-4 rounded-lg">{mp.error}</div>)}
|
||||
{mp.loading && (<div className="text-center py-12 text-gray-500">Lade Module...</div>)}
|
||||
|
||||
{/* Loading */}
|
||||
{loading && (
|
||||
<div className="text-center py-12 text-gray-500">
|
||||
Lade Module...
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Main Content - Two Column Layout */}
|
||||
{!loading && (
|
||||
{/* Main Content */}
|
||||
{!mp.loading && (
|
||||
<div className="flex gap-6">
|
||||
{/* Module List */}
|
||||
<div className="flex-1 space-y-4">
|
||||
{Object.entries(modulesByType).map(([type, typeModules]) => (
|
||||
{Object.entries(mp.modulesByType).map(([type, typeModules]) => (
|
||||
<div key={type} className="bg-white rounded-lg shadow border">
|
||||
<div className={`px-4 py-2 border-b ${SERVICE_TYPE_CONFIG[type]?.bgColor || 'bg-gray-100'}`}>
|
||||
<span className="text-lg mr-2">{SERVICE_TYPE_CONFIG[type]?.icon || '📁'}</span>
|
||||
<span className={`font-semibold ${SERVICE_TYPE_CONFIG[type]?.color || 'text-gray-700'}`}>
|
||||
{type.charAt(0).toUpperCase() + type.slice(1)}
|
||||
</span>
|
||||
<span className={`font-semibold ${SERVICE_TYPE_CONFIG[type]?.color || 'text-gray-700'}`}>{type.charAt(0).toUpperCase() + type.slice(1)}</span>
|
||||
<span className="text-gray-500 ml-2">({typeModules.length})</span>
|
||||
</div>
|
||||
<div className="divide-y">
|
||||
{typeModules.map((module) => (
|
||||
<div
|
||||
key={module.id}
|
||||
onClick={() => fetchModuleDetail(module.name)}
|
||||
className={`p-4 cursor-pointer hover:bg-gray-50 transition ${
|
||||
selectedModule?.id === module.id ? 'bg-blue-50' : ''
|
||||
}`}
|
||||
>
|
||||
{typeModules.map((mod) => (
|
||||
<div key={mod.id} onClick={() => mp.fetchModuleDetail(mod.name)} className={`p-4 cursor-pointer hover:bg-gray-50 transition ${mp.selectedModule?.id === mod.id ? 'bg-blue-50' : ''}`}>
|
||||
<div className="flex justify-between items-start">
|
||||
<div className="flex-1">
|
||||
<div className="flex items-center gap-2">
|
||||
<span className="font-medium text-gray-900">{module.display_name}</span>
|
||||
{module.port && (
|
||||
<span className="text-xs text-gray-400">:{module.port}</span>
|
||||
)}
|
||||
</div>
|
||||
<div className="text-sm text-gray-500 mt-1">{module.name}</div>
|
||||
{module.description && (
|
||||
<div className="text-sm text-gray-600 mt-1 line-clamp-2">
|
||||
{module.description}
|
||||
</div>
|
||||
)}
|
||||
<div className="flex items-center gap-2"><span className="font-medium text-gray-900">{mod.display_name}</span>{mod.port && (<span className="text-xs text-gray-400">:{mod.port}</span>)}</div>
|
||||
<div className="text-sm text-gray-500 mt-1">{mod.name}</div>
|
||||
{mod.description && (<div className="text-sm text-gray-600 mt-1 line-clamp-2">{mod.description}</div>)}
|
||||
<div className="flex flex-wrap gap-1 mt-2">
|
||||
{module.technology_stack.slice(0, 4).map((tech, i) => (
|
||||
<span key={i} className="px-2 py-0.5 bg-gray-100 text-gray-600 text-xs rounded">
|
||||
{tech}
|
||||
</span>
|
||||
))}
|
||||
{module.technology_stack.length > 4 && (
|
||||
<span className="px-2 py-0.5 text-gray-400 text-xs">
|
||||
+{module.technology_stack.length - 4}
|
||||
</span>
|
||||
)}
|
||||
{mod.technology_stack.slice(0, 4).map((tech, i) => (<span key={i} className="px-2 py-0.5 bg-gray-100 text-gray-600 text-xs rounded">{tech}</span>))}
|
||||
{mod.technology_stack.length > 4 && (<span className="px-2 py-0.5 text-gray-400 text-xs">+{mod.technology_stack.length - 4}</span>)}
|
||||
</div>
|
||||
</div>
|
||||
<div className="flex flex-col items-end gap-1">
|
||||
<span className={`px-2 py-0.5 text-xs rounded ${
|
||||
CRITICALITY_CONFIG[module.criticality]?.bgColor || 'bg-gray-100'
|
||||
} ${CRITICALITY_CONFIG[module.criticality]?.color || 'text-gray-700'}`}>
|
||||
{module.criticality}
|
||||
</span>
|
||||
<span className={`px-2 py-0.5 text-xs rounded ${CRITICALITY_CONFIG[mod.criticality]?.bgColor || 'bg-gray-100'} ${CRITICALITY_CONFIG[mod.criticality]?.color || 'text-gray-700'}`}>{mod.criticality}</span>
|
||||
<div className="flex gap-1 mt-1">
|
||||
{module.processes_pii && (
|
||||
<span className="px-1.5 py-0.5 bg-purple-100 text-purple-700 text-xs rounded" title="Verarbeitet PII">
|
||||
PII
|
||||
</span>
|
||||
)}
|
||||
{module.ai_components && (
|
||||
<span className="px-1.5 py-0.5 bg-pink-100 text-pink-700 text-xs rounded" title="AI-Komponenten">
|
||||
AI
|
||||
</span>
|
||||
)}
|
||||
</div>
|
||||
<div className="text-xs text-gray-400 mt-1">
|
||||
{module.regulation_count} Regulations
|
||||
{mod.processes_pii && (<span className="px-1.5 py-0.5 bg-purple-100 text-purple-700 text-xs rounded" title="Verarbeitet PII">PII</span>)}
|
||||
{mod.ai_components && (<span className="px-1.5 py-0.5 bg-pink-100 text-pink-700 text-xs rounded" title="AI-Komponenten">AI</span>)}
|
||||
</div>
|
||||
<div className="text-xs text-gray-400 mt-1">{mod.regulation_count} Regulations</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
@@ -450,292 +86,33 @@ export default function ModulesPage() {
|
||||
</div>
|
||||
</div>
|
||||
))}
|
||||
|
||||
{filteredModules.length === 0 && !loading && (
|
||||
<div className="text-center py-12 text-gray-500 bg-white rounded-lg shadow border">
|
||||
Keine Module gefunden.
|
||||
<button
|
||||
onClick={() => seedModules(false)}
|
||||
className="text-blue-600 hover:underline ml-1"
|
||||
>
|
||||
Jetzt seeden?
|
||||
</button>
|
||||
</div>
|
||||
{mp.filteredModules.length === 0 && !mp.loading && (
|
||||
<div className="text-center py-12 text-gray-500 bg-white rounded-lg shadow border">Keine Module gefunden.<button onClick={() => mp.seedModules(false)} className="text-blue-600 hover:underline ml-1">Jetzt seeden?</button></div>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{/* Detail Panel */}
|
||||
{selectedModule && (
|
||||
<div className="w-96 bg-white rounded-lg shadow border sticky top-6 h-fit">
|
||||
<div className={`px-4 py-3 border-b ${SERVICE_TYPE_CONFIG[selectedModule.service_type]?.bgColor || 'bg-gray-100'}`}>
|
||||
<div className="flex items-center justify-between">
|
||||
<span className="text-lg">{SERVICE_TYPE_CONFIG[selectedModule.service_type]?.icon || '📁'}</span>
|
||||
<button
|
||||
onClick={() => setSelectedModule(null)}
|
||||
className="text-gray-400 hover:text-gray-600"
|
||||
>
|
||||
✕
|
||||
</button>
|
||||
</div>
|
||||
<h3 className="font-bold text-lg mt-2">{selectedModule.display_name}</h3>
|
||||
<div className="text-sm text-gray-600">{selectedModule.name}</div>
|
||||
</div>
|
||||
|
||||
{loadingDetail ? (
|
||||
<div className="p-4 text-center text-gray-500">Lade Details...</div>
|
||||
) : (
|
||||
<div className="p-4 space-y-4">
|
||||
{/* Description */}
|
||||
{selectedModule.description && (
|
||||
<div>
|
||||
<div className="text-xs text-gray-500 uppercase mb-1">Beschreibung</div>
|
||||
<div className="text-sm text-gray-700">{selectedModule.description}</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Technical Details */}
|
||||
<div className="grid grid-cols-2 gap-2 text-sm">
|
||||
{selectedModule.port && (
|
||||
<div>
|
||||
<span className="text-gray-500">Port:</span>
|
||||
<span className="ml-1 font-mono">{selectedModule.port}</span>
|
||||
</div>
|
||||
)}
|
||||
<div>
|
||||
<span className="text-gray-500">Criticality:</span>
|
||||
<span className={`ml-1 px-1.5 py-0.5 rounded text-xs ${
|
||||
CRITICALITY_CONFIG[selectedModule.criticality]?.bgColor || ''
|
||||
} ${CRITICALITY_CONFIG[selectedModule.criticality]?.color || ''}`}>
|
||||
{selectedModule.criticality}
|
||||
</span>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Technology Stack */}
|
||||
<div>
|
||||
<div className="text-xs text-gray-500 uppercase mb-1">Tech Stack</div>
|
||||
<div className="flex flex-wrap gap-1">
|
||||
{selectedModule.technology_stack.map((tech, i) => (
|
||||
<span key={i} className="px-2 py-0.5 bg-gray-100 text-gray-700 text-xs rounded">
|
||||
{tech}
|
||||
</span>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Data Categories */}
|
||||
{selectedModule.data_categories.length > 0 && (
|
||||
<div>
|
||||
<div className="text-xs text-gray-500 uppercase mb-1">Daten-Kategorien</div>
|
||||
<div className="flex flex-wrap gap-1">
|
||||
{selectedModule.data_categories.map((cat, i) => (
|
||||
<span key={i} className="px-2 py-0.5 bg-blue-50 text-blue-700 text-xs rounded">
|
||||
{cat}
|
||||
</span>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Flags */}
|
||||
<div className="flex flex-wrap gap-2">
|
||||
{selectedModule.processes_pii && (
|
||||
<span className="px-2 py-1 bg-purple-100 text-purple-700 text-xs rounded">
|
||||
Verarbeitet PII
|
||||
</span>
|
||||
)}
|
||||
{selectedModule.ai_components && (
|
||||
<span className="px-2 py-1 bg-pink-100 text-pink-700 text-xs rounded">
|
||||
AI-Komponenten
|
||||
</span>
|
||||
)}
|
||||
{selectedModule.processes_health_data && (
|
||||
<span className="px-2 py-1 bg-red-100 text-red-700 text-xs rounded">
|
||||
Gesundheitsdaten
|
||||
</span>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{/* Regulations */}
|
||||
{selectedModule.regulations && selectedModule.regulations.length > 0 && (
|
||||
<div>
|
||||
<div className="text-xs text-gray-500 uppercase mb-2">
|
||||
Applicable Regulations ({selectedModule.regulations.length})
|
||||
</div>
|
||||
<div className="space-y-2">
|
||||
{selectedModule.regulations.map((reg, i) => (
|
||||
<div key={i} className="p-2 bg-gray-50 rounded text-sm">
|
||||
<div className="flex justify-between items-start">
|
||||
<span className="font-medium">{reg.code}</span>
|
||||
<span className={`px-1.5 py-0.5 rounded text-xs ${
|
||||
RELEVANCE_CONFIG[reg.relevance_level]?.bgColor || 'bg-gray-100'
|
||||
} ${RELEVANCE_CONFIG[reg.relevance_level]?.color || 'text-gray-700'}`}>
|
||||
{reg.relevance_level}
|
||||
</span>
|
||||
</div>
|
||||
<div className="text-gray-500 text-xs">{reg.name}</div>
|
||||
{reg.notes && (
|
||||
<div className="text-gray-600 text-xs mt-1 italic">{reg.notes}</div>
|
||||
)}
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Owner */}
|
||||
{selectedModule.owner_team && (
|
||||
<div>
|
||||
<div className="text-xs text-gray-500 uppercase mb-1">Owner</div>
|
||||
<div className="text-sm text-gray-700">{selectedModule.owner_team}</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Repository */}
|
||||
{selectedModule.repository_path && (
|
||||
<div>
|
||||
<div className="text-xs text-gray-500 uppercase mb-1">Repository</div>
|
||||
<code className="text-xs bg-gray-100 px-2 py-1 rounded block">
|
||||
{selectedModule.repository_path}
|
||||
</code>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* AI Risk Assessment Button */}
|
||||
<div className="pt-2 border-t">
|
||||
<button
|
||||
onClick={() => assessModuleRisk(selectedModule.id)}
|
||||
disabled={loadingRisk}
|
||||
className="w-full px-4 py-2 bg-gradient-to-r from-purple-600 to-pink-600 text-white rounded-lg hover:from-purple-700 hover:to-pink-700 transition disabled:opacity-50 flex items-center justify-center gap-2"
|
||||
>
|
||||
{loadingRisk ? (
|
||||
<>
|
||||
<span className="animate-spin">⚙️</span>
|
||||
AI analysiert...
|
||||
</>
|
||||
) : (
|
||||
<>
|
||||
<span>🤖</span>
|
||||
AI Risikobewertung
|
||||
</>
|
||||
)}
|
||||
</button>
|
||||
</div>
|
||||
|
||||
{/* AI Risk Assessment Panel */}
|
||||
{showRiskPanel && (
|
||||
<div className="mt-4 p-4 bg-gradient-to-br from-purple-50 to-pink-50 rounded-lg border border-purple-200">
|
||||
<div className="flex justify-between items-center mb-3">
|
||||
<h4 className="font-semibold text-purple-900 flex items-center gap-2">
|
||||
<span>🤖</span> AI Risikobewertung
|
||||
</h4>
|
||||
<button
|
||||
onClick={() => setShowRiskPanel(false)}
|
||||
className="text-purple-400 hover:text-purple-600"
|
||||
>
|
||||
✕
|
||||
</button>
|
||||
</div>
|
||||
|
||||
{loadingRisk ? (
|
||||
<div className="text-center py-4 text-purple-600">
|
||||
<div className="animate-pulse">Analysiere Compliance-Risiken...</div>
|
||||
</div>
|
||||
) : riskAssessment ? (
|
||||
<div className="space-y-3">
|
||||
{/* Overall Risk */}
|
||||
<div className="flex items-center gap-2">
|
||||
<span className="text-sm text-gray-600">Gesamtrisiko:</span>
|
||||
<span className={`px-2 py-1 rounded text-sm font-medium ${
|
||||
riskAssessment.overall_risk === 'critical' ? 'bg-red-100 text-red-700' :
|
||||
riskAssessment.overall_risk === 'high' ? 'bg-orange-100 text-orange-700' :
|
||||
riskAssessment.overall_risk === 'medium' ? 'bg-yellow-100 text-yellow-700' :
|
||||
'bg-green-100 text-green-700'
|
||||
}`}>
|
||||
{riskAssessment.overall_risk.toUpperCase()}
|
||||
</span>
|
||||
<span className="text-xs text-gray-400">
|
||||
({Math.round(riskAssessment.confidence_score * 100)}% Konfidenz)
|
||||
</span>
|
||||
</div>
|
||||
|
||||
{/* Risk Factors */}
|
||||
{riskAssessment.risk_factors.length > 0 && (
|
||||
<div>
|
||||
<div className="text-xs text-gray-500 uppercase mb-1">Risikofaktoren</div>
|
||||
<div className="space-y-1">
|
||||
{riskAssessment.risk_factors.map((factor, i) => (
|
||||
<div key={i} className="flex items-center justify-between text-sm bg-white/50 rounded px-2 py-1">
|
||||
<span className="text-gray-700">{factor.factor}</span>
|
||||
<span className={`text-xs px-1.5 py-0.5 rounded ${
|
||||
factor.severity === 'critical' || factor.severity === 'high'
|
||||
? 'bg-red-100 text-red-600'
|
||||
: 'bg-yellow-100 text-yellow-600'
|
||||
}`}>
|
||||
{factor.severity}
|
||||
</span>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Compliance Gaps */}
|
||||
{riskAssessment.compliance_gaps.length > 0 && (
|
||||
<div>
|
||||
<div className="text-xs text-gray-500 uppercase mb-1">Compliance-Lücken</div>
|
||||
<ul className="text-sm text-gray-700 space-y-1">
|
||||
{riskAssessment.compliance_gaps.map((gap, i) => (
|
||||
<li key={i} className="flex items-start gap-1">
|
||||
<span className="text-red-500">⚠</span>
|
||||
<span>{gap}</span>
|
||||
</li>
|
||||
))}
|
||||
</ul>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Recommendations */}
|
||||
{riskAssessment.recommendations.length > 0 && (
|
||||
<div>
|
||||
<div className="text-xs text-gray-500 uppercase mb-1">Empfehlungen</div>
|
||||
<ul className="text-sm text-gray-700 space-y-1">
|
||||
{riskAssessment.recommendations.map((rec, i) => (
|
||||
<li key={i} className="flex items-start gap-1">
|
||||
<span className="text-green-500">→</span>
|
||||
<span>{rec}</span>
|
||||
</li>
|
||||
))}
|
||||
</ul>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
) : (
|
||||
<div className="text-center py-4 text-gray-500 text-sm">
|
||||
Klicken Sie auf "AI Risikobewertung" um eine Analyse zu starten.
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
{mp.selectedModule && (
|
||||
<ModuleDetailPanel
|
||||
module={mp.selectedModule}
|
||||
loadingDetail={mp.loadingDetail}
|
||||
loadingRisk={mp.loadingRisk}
|
||||
showRiskPanel={mp.showRiskPanel}
|
||||
riskAssessment={mp.riskAssessment}
|
||||
onClose={() => mp.setSelectedModule(null)}
|
||||
onAssessRisk={mp.assessModuleRisk}
|
||||
onCloseRisk={() => mp.setShowRiskPanel(false)}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Regulations Coverage Overview */}
|
||||
{overview && overview.regulations_coverage && Object.keys(overview.regulations_coverage).length > 0 && (
|
||||
{/* Regulations Coverage */}
|
||||
{mp.overview && mp.overview.regulations_coverage && Object.keys(mp.overview.regulations_coverage).length > 0 && (
|
||||
<div className="bg-white rounded-lg shadow border p-4">
|
||||
<h3 className="font-semibold text-gray-900 mb-4">Regulation Coverage</h3>
|
||||
<div className="grid grid-cols-2 md:grid-cols-4 lg:grid-cols-6 gap-3">
|
||||
{Object.entries(overview.regulations_coverage)
|
||||
.sort(([, a], [, b]) => b - a)
|
||||
.map(([code, count]) => (
|
||||
<div key={code} className="bg-gray-50 rounded p-3 text-center">
|
||||
<div className="text-2xl font-bold text-blue-600">{count}</div>
|
||||
<div className="text-xs text-gray-600 truncate" title={code}>{code}</div>
|
||||
</div>
|
||||
{Object.entries(mp.overview.regulations_coverage).sort(([, a], [, b]) => b - a).map(([code, count]) => (
|
||||
<div key={code} className="bg-gray-50 rounded p-3 text-center"><div className="text-2xl font-bold text-blue-600">{count}</div><div className="text-xs text-gray-600 truncate" title={code}>{code}</div></div>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -0,0 +1,162 @@
|
||||
'use client'
|
||||
|
||||
import {
|
||||
Source, ScraperStatus, ScrapeResult,
|
||||
PDFDocument, PDFExtractionResult,
|
||||
} from './types'
|
||||
import SourceCard from './SourceCard'
|
||||
|
||||
interface ScraperTabsProps {
|
||||
activeTab: string
|
||||
sources: Source[]
|
||||
pdfDocuments: PDFDocument[]
|
||||
status: ScraperStatus | null
|
||||
scraping: boolean
|
||||
extracting: boolean
|
||||
results: ScrapeResult[]
|
||||
pdfResult: PDFExtractionResult | null
|
||||
handleScrapeAll: () => void
|
||||
handleScrapeSingle: (code: string, force: boolean) => void
|
||||
handleExtractPdf: (code: string, saveToDb: boolean, force: boolean) => void
|
||||
}
|
||||
|
||||
export default function ScraperTabs(props: ScraperTabsProps) {
|
||||
const { activeTab, sources, pdfDocuments, status, scraping, extracting, results, pdfResult } = props
|
||||
|
||||
if (activeTab === 'sources') {
|
||||
return (
|
||||
<div>
|
||||
<div className="flex justify-between items-center mb-6">
|
||||
<div>
|
||||
<h3 className="text-lg font-semibold text-slate-900">Regulierungsquellen</h3>
|
||||
<p className="text-sm text-slate-500">EU-Lex, BSI-TR und deutsche Gesetze</p>
|
||||
</div>
|
||||
<button onClick={props.handleScrapeAll} disabled={scraping} className="px-4 py-2 bg-primary-600 text-white rounded-lg hover:bg-primary-700 transition-colors disabled:opacity-50 disabled:cursor-not-allowed flex items-center gap-2">
|
||||
{scraping ? (<><svg className="w-4 h-4 animate-spin" fill="none" viewBox="0 0 24 24"><circle className="opacity-25" cx="12" cy="12" r="10" stroke="currentColor" strokeWidth="4" /><path className="opacity-75" fill="currentColor" d="M4 12a8 8 0 018-8V0C5.373 0 0 5.373 0 12h4z" /></svg>Laeuft...</>) : (<><svg className="w-4 h-4" fill="none" stroke="currentColor" viewBox="0 0 24 24"><path strokeLinecap="round" strokeLinejoin="round" strokeWidth={2} d="M4 4v5h.582m15.356 2A8.001 8.001 0 004.582 9m0 0H9m11 11v-5h-.581m0 0a8.003 8.003 0 01-15.357-2m15.357 2H15" /></svg>Alle Quellen scrapen</>)}
|
||||
</button>
|
||||
</div>
|
||||
<div className="space-y-6">
|
||||
<div>
|
||||
<h4 className="text-sm font-medium text-slate-700 mb-3 flex items-center gap-2"><span className="text-lg">🇪🇺</span> EU-Regulierungen (EUR-Lex)</h4>
|
||||
<div className="grid gap-3">{sources.filter(s => s.source_type === 'eur_lex').map(source => (<SourceCard key={source.code} source={source} onScrape={props.handleScrapeSingle} scraping={scraping} />))}</div>
|
||||
</div>
|
||||
<div>
|
||||
<h4 className="text-sm font-medium text-slate-700 mb-3 flex items-center gap-2"><span className="text-lg">🔒</span> BSI Technical Guidelines</h4>
|
||||
<div className="grid gap-3">{sources.filter(s => s.source_type === 'bsi_pdf').map(source => (<SourceCard key={source.code} source={source} onScrape={props.handleScrapeSingle} scraping={scraping} />))}</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
if (activeTab === 'pdf') {
|
||||
return (
|
||||
<div>
|
||||
<div className="mb-6">
|
||||
<h3 className="text-lg font-semibold text-slate-900">PDF-Extraktion (PyMuPDF)</h3>
|
||||
<p className="text-sm text-slate-500">Extrahiert ALLE Pruefaspekte aus BSI-TR-03161 PDFs mit Regex-Pattern-Matching</p>
|
||||
</div>
|
||||
<div className="space-y-4">
|
||||
{pdfDocuments.map(doc => (
|
||||
<div key={doc.code} className="bg-slate-50 rounded-lg p-4 border border-slate-200">
|
||||
<div className="flex items-center justify-between">
|
||||
<div className="flex items-center gap-3">
|
||||
<span className="text-3xl">📄</span>
|
||||
<div>
|
||||
<div className="flex items-center gap-2">
|
||||
<span className="font-semibold text-slate-900">{doc.code}</span>
|
||||
<span className={`px-2 py-0.5 rounded text-xs font-medium ${doc.available ? 'bg-green-100 text-green-700' : 'bg-red-100 text-red-700'}`}>{doc.available ? 'Verfuegbar' : 'Nicht gefunden'}</span>
|
||||
</div>
|
||||
<div className="text-sm text-slate-600">{doc.name}</div>
|
||||
<div className="text-xs text-slate-500">{doc.description}</div>
|
||||
<div className="text-xs text-slate-400 mt-1">Erwartete Pruefaspekte: {doc.expected_aspects}</div>
|
||||
</div>
|
||||
</div>
|
||||
<div className="flex gap-2">
|
||||
<button onClick={() => props.handleExtractPdf(doc.code, true, false)} disabled={extracting || !doc.available} className="px-4 py-2 bg-primary-600 text-white rounded-lg hover:bg-primary-700 transition-colors disabled:opacity-50 disabled:cursor-not-allowed flex items-center gap-2">
|
||||
{extracting ? (<><svg className="w-4 h-4 animate-spin" fill="none" viewBox="0 0 24 24"><circle className="opacity-25" cx="12" cy="12" r="10" stroke="currentColor" strokeWidth="4" /><path className="opacity-75" fill="currentColor" d="M4 12a8 8 0 018-8V0C5.373 0 0 5.373 0 12h4z" /></svg>Extrahiere...</>) : (<><svg className="w-4 h-4" fill="none" stroke="currentColor" viewBox="0 0 24 24"><path strokeLinecap="round" strokeLinejoin="round" strokeWidth={2} d="M9 12h6m-6 4h6m2 5H7a2 2 0 01-2-2V5a2 2 0 012-2h5.586a1 1 0 01.707.293l5.414 5.414a1 1 0 01.293.707V19a2 2 0 01-2 2z" /></svg>Extrahieren</>)}
|
||||
</button>
|
||||
<button onClick={() => props.handleExtractPdf(doc.code, true, true)} disabled={extracting || !doc.available} className="px-3 py-2 bg-orange-100 text-orange-700 rounded-lg hover:bg-orange-200 transition-colors disabled:opacity-50 disabled:cursor-not-allowed" title="Force: Loescht vorhandene und extrahiert neu">Force</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
{pdfResult && (
|
||||
<div className="mt-6 bg-green-50 rounded-lg p-4 border border-green-200">
|
||||
<h4 className="font-semibold text-green-800 mb-3">Letztes Extraktions-Ergebnis</h4>
|
||||
<div className="grid grid-cols-3 gap-4 mb-4">
|
||||
<div className="text-center p-3 bg-white rounded-lg"><div className="text-2xl font-bold text-green-700">{pdfResult.total_aspects}</div><div className="text-sm text-slate-500">Pruefaspekte gefunden</div></div>
|
||||
<div className="text-center p-3 bg-white rounded-lg"><div className="text-2xl font-bold text-blue-700">{pdfResult.requirements_created}</div><div className="text-sm text-slate-500">Requirements erstellt</div></div>
|
||||
<div className="text-center p-3 bg-white rounded-lg"><div className="text-2xl font-bold text-slate-700">{Object.keys(pdfResult.statistics.by_category || {}).length}</div><div className="text-sm text-slate-500">Kategorien</div></div>
|
||||
</div>
|
||||
{pdfResult.statistics.by_category && Object.keys(pdfResult.statistics.by_category).length > 0 && (
|
||||
<div><h5 className="text-sm font-medium text-slate-700 mb-2">Nach Kategorie:</h5><div className="flex flex-wrap gap-2">{Object.entries(pdfResult.statistics.by_category).map(([cat, count]) => (<span key={cat} className="px-2 py-1 bg-white rounded text-xs text-slate-600">{cat}: <strong>{count}</strong></span>))}</div></div>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
<div className="mt-6 bg-blue-50 rounded-lg p-4 border border-blue-200">
|
||||
<h4 className="font-semibold text-blue-800 mb-2">Wie funktioniert die PDF-Extraktion?</h4>
|
||||
<ul className="text-sm text-blue-700 space-y-1">
|
||||
<li>- <strong>PyMuPDF (fitz)</strong> liest den PDF-Text</li>
|
||||
<li>- <strong>Regex-Pattern</strong> finden Aspekte wie O.Auth_1, O.Sess_2, T.Network_1</li>
|
||||
<li>- <strong>Kontextanalyse</strong> extrahiert Titel, Kategorie und Anforderungsstufe (MUSS/SOLL/KANN)</li>
|
||||
<li>- <strong>Automatische Speicherung</strong> erstellt Requirements in der Datenbank</li>
|
||||
</ul>
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
if (activeTab === 'status' && status) {
|
||||
return (
|
||||
<div className="space-y-6">
|
||||
<div className="bg-slate-50 rounded-lg p-6">
|
||||
<div className="flex items-center justify-between mb-4">
|
||||
<div><h3 className="text-lg font-semibold text-slate-900">Scraper-Status</h3><p className="text-sm text-slate-500">Letzter Lauf: {status.stats.last_run ? new Date(status.stats.last_run).toLocaleString('de-DE') : 'Noch nie'}</p></div>
|
||||
<div className={`px-3 py-1.5 rounded-full text-sm font-medium ${status.status === 'running' ? 'bg-blue-100 text-blue-700' : status.status === 'error' ? 'bg-red-100 text-red-700' : status.status === 'completed' ? 'bg-green-100 text-green-700' : 'bg-gray-100 text-gray-700'}`}>
|
||||
{status.status === 'running' ? 'Laeuft' : status.status === 'error' ? 'Fehler' : status.status === 'completed' ? 'Abgeschlossen' : 'Bereit'}
|
||||
</div>
|
||||
</div>
|
||||
<div className="grid grid-cols-3 gap-4">
|
||||
<div className="text-center p-4 bg-white rounded-lg"><div className="text-2xl font-bold text-slate-900">{status.stats.sources_processed}</div><div className="text-sm text-slate-500">Quellen verarbeitet</div></div>
|
||||
<div className="text-center p-4 bg-white rounded-lg"><div className="text-2xl font-bold text-green-600">{status.stats.requirements_extracted}</div><div className="text-sm text-slate-500">Anforderungen extrahiert</div></div>
|
||||
<div className="text-center p-4 bg-white rounded-lg"><div className="text-2xl font-bold text-red-600">{status.stats.errors}</div><div className="text-sm text-slate-500">Fehler</div></div>
|
||||
</div>
|
||||
{status.last_error && (<div className="mt-4 p-3 bg-red-50 rounded-lg text-sm text-red-700"><strong>Letzter Fehler:</strong> {status.last_error}</div>)}
|
||||
</div>
|
||||
<div className="bg-white border border-slate-200 rounded-lg p-6">
|
||||
<h4 className="font-semibold text-slate-900 mb-4">Wie funktioniert der Scraper?</h4>
|
||||
<div className="space-y-3 text-sm text-slate-600">
|
||||
{[{ n: '1', t: 'EUR-Lex Abruf', d: 'Holt HTML-Version der EU-Verordnung, extrahiert Artikel und Absaetze' }, { n: '2', t: 'BSI-TR Parsing', d: 'Extrahiert Pruefaspekte (O.Auth_1, O.Sess_1, etc.) aus den TR-Dokumenten' }, { n: '3', t: 'Datenbank-Speicherung', d: 'Jede Anforderung wird als Requirement in der Compliance-DB gespeichert' }].map(s => (
|
||||
<div key={s.n} className="flex items-start gap-3"><div className="w-6 h-6 bg-blue-100 rounded-full flex items-center justify-center text-blue-600 font-bold">{s.n}</div><div><strong>{s.t}</strong>: {s.d}</div></div>
|
||||
))}
|
||||
<div className="flex items-start gap-3"><div className="w-6 h-6 bg-green-100 rounded-full flex items-center justify-center text-green-600 font-bold">✓</div><div><strong>Audit-Workspace</strong>: Anforderungen koennen mit Implementierungsdetails angereichert werden</div></div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
// logs tab
|
||||
return (
|
||||
<div>
|
||||
<h3 className="text-lg font-semibold text-slate-900 mb-4">Letzte Ergebnisse</h3>
|
||||
{results.length === 0 ? (
|
||||
<div className="text-center py-12 text-slate-500">Keine Ergebnisse vorhanden. Starte einen Scrape-Vorgang.</div>
|
||||
) : (
|
||||
<div className="space-y-2">
|
||||
{results.map((result, idx) => (
|
||||
<div key={idx} className={`p-3 rounded-lg flex items-center justify-between ${result.error ? 'bg-red-50' : result.reason ? 'bg-yellow-50' : 'bg-green-50'}`}>
|
||||
<div className="flex items-center gap-3">
|
||||
<span className="text-lg">{result.error ? '❌' : result.reason ? '⏭️' : '✅'}</span>
|
||||
<span className="font-medium">{result.code}</span>
|
||||
<span className="text-sm text-slate-500">{result.error || result.reason || `${result.requirements_extracted} Anforderungen`}</span>
|
||||
</div>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
@@ -0,0 +1,73 @@
|
||||
'use client'
|
||||
|
||||
import { Source, regulationTypeBadge, sourceTypeBadge } from './types'
|
||||
|
||||
export default function SourceCard({
|
||||
source,
|
||||
onScrape,
|
||||
scraping,
|
||||
}: {
|
||||
source: Source
|
||||
onScrape: (code: string, force: boolean) => void
|
||||
scraping: boolean
|
||||
}) {
|
||||
const regType = regulationTypeBadge[source.regulation_type] || regulationTypeBadge.industry_standard
|
||||
const srcType = sourceTypeBadge[source.source_type] || sourceTypeBadge.manual
|
||||
|
||||
return (
|
||||
<div className="bg-white border border-slate-200 rounded-lg p-4 hover:shadow-sm transition-shadow">
|
||||
<div className="flex items-center justify-between">
|
||||
<div className="flex items-center gap-3">
|
||||
<span className="text-2xl">{regType.icon}</span>
|
||||
<div>
|
||||
<div className="flex items-center gap-2">
|
||||
<span className="font-semibold text-slate-900">{source.code}</span>
|
||||
<span className={`px-2 py-0.5 rounded text-xs font-medium ${regType.color}`}>
|
||||
{regType.label}
|
||||
</span>
|
||||
<span className={`px-2 py-0.5 rounded text-xs font-medium ${srcType.color}`}>
|
||||
{srcType.label}
|
||||
</span>
|
||||
</div>
|
||||
<div className="text-sm text-slate-500 truncate max-w-md" title={source.url}>
|
||||
{source.url.length > 60 ? source.url.substring(0, 60) + '...' : source.url}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="flex items-center gap-3">
|
||||
{source.has_data ? (
|
||||
<span className="px-3 py-1 bg-green-100 text-green-700 rounded-full text-sm font-medium">
|
||||
{source.requirement_count} Anforderungen
|
||||
</span>
|
||||
) : (
|
||||
<span className="px-3 py-1 bg-gray-100 text-gray-500 rounded-full text-sm">
|
||||
Keine Daten
|
||||
</span>
|
||||
)}
|
||||
|
||||
<div className="flex gap-1">
|
||||
<button
|
||||
onClick={() => onScrape(source.code, false)}
|
||||
disabled={scraping}
|
||||
className="px-3 py-1.5 text-sm bg-slate-100 text-slate-700 rounded hover:bg-slate-200 transition-colors disabled:opacity-50 disabled:cursor-not-allowed"
|
||||
title="Scrapen (ueberspringt vorhandene)"
|
||||
>
|
||||
Scrapen
|
||||
</button>
|
||||
{source.has_data && (
|
||||
<button
|
||||
onClick={() => onScrape(source.code, true)}
|
||||
disabled={scraping}
|
||||
className="px-3 py-1.5 text-sm bg-orange-100 text-orange-700 rounded hover:bg-orange-200 transition-colors disabled:opacity-50 disabled:cursor-not-allowed"
|
||||
title="Force: Loescht vorhandene Daten und scraped neu"
|
||||
>
|
||||
Force
|
||||
</button>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
@@ -0,0 +1,65 @@
|
||||
export interface Source {
|
||||
code: string
|
||||
url: string
|
||||
source_type: string
|
||||
regulation_type: string
|
||||
has_data: boolean
|
||||
requirement_count: number
|
||||
}
|
||||
|
||||
export interface ScraperStatus {
|
||||
status: 'idle' | 'running' | 'completed' | 'error'
|
||||
current_source: string | null
|
||||
last_error: string | null
|
||||
stats: {
|
||||
sources_processed: number
|
||||
requirements_extracted: number
|
||||
errors: number
|
||||
last_run: string | null
|
||||
}
|
||||
known_sources: string[]
|
||||
}
|
||||
|
||||
export interface ScrapeResult {
|
||||
code: string
|
||||
status: string
|
||||
requirements_extracted?: number
|
||||
reason?: string
|
||||
error?: string
|
||||
}
|
||||
|
||||
export interface PDFDocument {
|
||||
code: string
|
||||
name: string
|
||||
description: string
|
||||
expected_aspects: string
|
||||
available: boolean
|
||||
}
|
||||
|
||||
export interface PDFExtractionResult {
|
||||
success: boolean
|
||||
source_document: string
|
||||
total_aspects: number
|
||||
requirements_created: number
|
||||
statistics: {
|
||||
by_category: Record<string, number>
|
||||
by_requirement_level: Record<string, number>
|
||||
}
|
||||
}
|
||||
|
||||
export const BACKEND_URL = process.env.NEXT_PUBLIC_BACKEND_URL || 'http://localhost:8000'
|
||||
|
||||
export const sourceTypeBadge: Record<string, { label: string; color: string }> = {
|
||||
eur_lex: { label: 'EUR-Lex', color: 'bg-blue-100 text-blue-800' },
|
||||
bsi_pdf: { label: 'BSI PDF', color: 'bg-green-100 text-green-800' },
|
||||
gesetze_im_internet: { label: 'Gesetze', color: 'bg-yellow-100 text-yellow-800' },
|
||||
manual: { label: 'Manuell', color: 'bg-gray-100 text-gray-800' },
|
||||
}
|
||||
|
||||
export const regulationTypeBadge: Record<string, { label: string; color: string; icon: string }> = {
|
||||
eu_regulation: { label: 'EU-Verordnung', color: 'bg-indigo-100 text-indigo-800', icon: '🇪🇺' },
|
||||
eu_directive: { label: 'EU-Richtlinie', color: 'bg-purple-100 text-purple-800', icon: '📜' },
|
||||
de_law: { label: 'DE-Gesetz', color: 'bg-yellow-100 text-yellow-800', icon: '🇩🇪' },
|
||||
bsi_standard: { label: 'BSI-Standard', color: 'bg-green-100 text-green-800', icon: '🔒' },
|
||||
industry_standard: { label: 'Standard', color: 'bg-gray-100 text-gray-800', icon: '📋' },
|
||||
}
|
||||
@@ -0,0 +1,106 @@
|
||||
'use client'
|
||||
|
||||
import { useState, useEffect, useCallback } from 'react'
|
||||
import {
|
||||
Source, ScraperStatus, ScrapeResult,
|
||||
PDFDocument, PDFExtractionResult, BACKEND_URL,
|
||||
} from './types'
|
||||
|
||||
export function useComplianceScraper() {
|
||||
const [activeTab, setActiveTab] = useState<'sources' | 'pdf' | 'status' | 'logs'>('sources')
|
||||
const [sources, setSources] = useState<Source[]>([])
|
||||
const [pdfDocuments, setPdfDocuments] = useState<PDFDocument[]>([])
|
||||
const [status, setStatus] = useState<ScraperStatus | null>(null)
|
||||
const [loading, setLoading] = useState(true)
|
||||
const [scraping, setScraping] = useState(false)
|
||||
const [extracting, setExtracting] = useState(false)
|
||||
const [error, setError] = useState<string | null>(null)
|
||||
const [success, setSuccess] = useState<string | null>(null)
|
||||
const [results, setResults] = useState<ScrapeResult[]>([])
|
||||
const [pdfResult, setPdfResult] = useState<PDFExtractionResult | null>(null)
|
||||
|
||||
const fetchSources = useCallback(async () => {
|
||||
try {
|
||||
const res = await fetch(`${BACKEND_URL}/api/v1/compliance/scraper/sources`)
|
||||
if (res.ok) { const data = await res.json(); setSources(data.sources || []) }
|
||||
} catch (err) { console.error('Failed to fetch sources:', err) }
|
||||
}, [])
|
||||
|
||||
const fetchPdfDocuments = useCallback(async () => {
|
||||
try {
|
||||
const res = await fetch(`${BACKEND_URL}/api/v1/compliance/scraper/pdf-documents`)
|
||||
if (res.ok) { const data = await res.json(); setPdfDocuments(data.documents || []) }
|
||||
} catch (err) { console.error('Failed to fetch PDF documents:', err) }
|
||||
}, [])
|
||||
|
||||
const fetchStatus = useCallback(async () => {
|
||||
try {
|
||||
const res = await fetch(`${BACKEND_URL}/api/v1/compliance/scraper/status`)
|
||||
if (res.ok) { const data = await res.json(); setStatus(data) }
|
||||
} catch (err) { console.error('Failed to fetch status:', err) }
|
||||
}, [])
|
||||
|
||||
useEffect(() => {
|
||||
const loadData = async () => {
|
||||
setLoading(true)
|
||||
await Promise.all([fetchSources(), fetchStatus(), fetchPdfDocuments()])
|
||||
setLoading(false)
|
||||
}
|
||||
loadData()
|
||||
}, [fetchSources, fetchStatus, fetchPdfDocuments])
|
||||
|
||||
useEffect(() => {
|
||||
if (scraping) { const interval = setInterval(fetchStatus, 2000); return () => clearInterval(interval) }
|
||||
}, [scraping, fetchStatus])
|
||||
|
||||
const handleScrapeAll = async () => {
|
||||
setScraping(true); setError(null); setSuccess(null); setResults([])
|
||||
try {
|
||||
const res = await fetch(`${BACKEND_URL}/api/v1/compliance/scraper/scrape-all`, { method: 'POST' })
|
||||
if (!res.ok) { const data = await res.json(); throw new Error(data.detail || 'Scraping fehlgeschlagen') }
|
||||
const data = await res.json()
|
||||
setResults([...data.results.success, ...data.results.failed, ...data.results.skipped])
|
||||
setSuccess(`Scraping abgeschlossen: ${data.results.success.length} erfolgreich, ${data.results.skipped.length} uebersprungen, ${data.results.failed.length} fehlgeschlagen`)
|
||||
await fetchSources()
|
||||
} catch (err: any) { setError(err.message) }
|
||||
finally { setScraping(false) }
|
||||
}
|
||||
|
||||
const handleScrapeSingle = async (code: string, force: boolean = false) => {
|
||||
setScraping(true); setError(null); setSuccess(null)
|
||||
try {
|
||||
const res = await fetch(`${BACKEND_URL}/api/v1/compliance/scraper/scrape/${code}?force=${force}`, { method: 'POST' })
|
||||
if (!res.ok) { const data = await res.json(); throw new Error(data.detail || 'Scraping fehlgeschlagen') }
|
||||
const data = await res.json()
|
||||
if (data.status === 'skipped') { setSuccess(`${code}: Bereits vorhanden (${data.requirement_count} Anforderungen)`) }
|
||||
else { setSuccess(`${code}: ${data.requirements_extracted} Anforderungen extrahiert`) }
|
||||
await fetchSources()
|
||||
} catch (err: any) { setError(err.message) }
|
||||
finally { setScraping(false) }
|
||||
}
|
||||
|
||||
const handleExtractPdf = async (code: string, saveToDb: boolean = true, force: boolean = false) => {
|
||||
setExtracting(true); setError(null); setSuccess(null); setPdfResult(null)
|
||||
try {
|
||||
const res = await fetch(`${BACKEND_URL}/api/v1/compliance/scraper/extract-pdf`, {
|
||||
method: 'POST', headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({ document_code: code, save_to_db: saveToDb, force }),
|
||||
})
|
||||
if (!res.ok) { const data = await res.json(); throw new Error(data.detail || 'PDF-Extraktion fehlgeschlagen') }
|
||||
const data: PDFExtractionResult = await res.json()
|
||||
setPdfResult(data)
|
||||
if (data.success) { setSuccess(`${code}: ${data.total_aspects} Pruefaspekte extrahiert, ${data.requirements_created} Requirements erstellt`) }
|
||||
await fetchSources()
|
||||
} catch (err: any) { setError(err.message) }
|
||||
finally { setExtracting(false) }
|
||||
}
|
||||
|
||||
useEffect(() => { if (success) { const timer = setTimeout(() => setSuccess(null), 5000); return () => clearTimeout(timer) } }, [success])
|
||||
useEffect(() => { if (error) { const timer = setTimeout(() => setError(null), 10000); return () => clearTimeout(timer) } }, [error])
|
||||
|
||||
return {
|
||||
activeTab, setActiveTab, sources, pdfDocuments, status,
|
||||
loading, scraping, extracting, error, success, results, pdfResult,
|
||||
handleScrapeAll, handleScrapeSingle, handleExtractPdf,
|
||||
}
|
||||
}
|
||||
@@ -7,283 +7,20 @@
|
||||
* - EUR-Lex regulations (GDPR, AI Act, CRA, NIS2, etc.)
|
||||
* - BSI Technical Guidelines (TR-03161)
|
||||
* - German laws
|
||||
*
|
||||
* Similar pattern to edu-search and zeugnisse-crawler.
|
||||
*/
|
||||
|
||||
import { useState, useEffect, useCallback } from 'react'
|
||||
import AdminLayout from '@/components/admin/AdminLayout'
|
||||
import SystemInfoSection, { SYSTEM_INFO_CONFIGS } from '@/components/admin/SystemInfoSection'
|
||||
|
||||
// Types
|
||||
interface Source {
|
||||
code: string
|
||||
url: string
|
||||
source_type: string
|
||||
regulation_type: string
|
||||
has_data: boolean
|
||||
requirement_count: number
|
||||
}
|
||||
|
||||
interface ScraperStatus {
|
||||
status: 'idle' | 'running' | 'completed' | 'error'
|
||||
current_source: string | null
|
||||
last_error: string | null
|
||||
stats: {
|
||||
sources_processed: number
|
||||
requirements_extracted: number
|
||||
errors: number
|
||||
last_run: string | null
|
||||
}
|
||||
known_sources: string[]
|
||||
}
|
||||
|
||||
interface ScrapeResult {
|
||||
code: string
|
||||
status: string
|
||||
requirements_extracted?: number
|
||||
reason?: string
|
||||
error?: string
|
||||
}
|
||||
|
||||
interface PDFDocument {
|
||||
code: string
|
||||
name: string
|
||||
description: string
|
||||
expected_aspects: string
|
||||
available: boolean
|
||||
}
|
||||
|
||||
interface PDFExtractionResult {
|
||||
success: boolean
|
||||
source_document: string
|
||||
total_aspects: number
|
||||
requirements_created: number
|
||||
statistics: {
|
||||
by_category: Record<string, number>
|
||||
by_requirement_level: Record<string, number>
|
||||
}
|
||||
}
|
||||
|
||||
const BACKEND_URL = process.env.NEXT_PUBLIC_BACKEND_URL || 'http://localhost:8000'
|
||||
|
||||
// Source type badges
|
||||
const sourceTypeBadge: Record<string, { label: string; color: string }> = {
|
||||
eur_lex: { label: 'EUR-Lex', color: 'bg-blue-100 text-blue-800' },
|
||||
bsi_pdf: { label: 'BSI PDF', color: 'bg-green-100 text-green-800' },
|
||||
gesetze_im_internet: { label: 'Gesetze', color: 'bg-yellow-100 text-yellow-800' },
|
||||
manual: { label: 'Manuell', color: 'bg-gray-100 text-gray-800' },
|
||||
}
|
||||
|
||||
// Regulation type badges
|
||||
const regulationTypeBadge: Record<string, { label: string; color: string; icon: string }> = {
|
||||
eu_regulation: { label: 'EU-Verordnung', color: 'bg-indigo-100 text-indigo-800', icon: '🇪🇺' },
|
||||
eu_directive: { label: 'EU-Richtlinie', color: 'bg-purple-100 text-purple-800', icon: '📜' },
|
||||
de_law: { label: 'DE-Gesetz', color: 'bg-yellow-100 text-yellow-800', icon: '🇩🇪' },
|
||||
bsi_standard: { label: 'BSI-Standard', color: 'bg-green-100 text-green-800', icon: '🔒' },
|
||||
industry_standard: { label: 'Standard', color: 'bg-gray-100 text-gray-800', icon: '📋' },
|
||||
}
|
||||
import { useComplianceScraper } from './_components/useComplianceScraper'
|
||||
import ScraperTabs from './_components/ScraperTabs'
|
||||
|
||||
export default function ComplianceScraperPage() {
|
||||
const [activeTab, setActiveTab] = useState<'sources' | 'pdf' | 'status' | 'logs'>('sources')
|
||||
const [sources, setSources] = useState<Source[]>([])
|
||||
const [pdfDocuments, setPdfDocuments] = useState<PDFDocument[]>([])
|
||||
const [status, setStatus] = useState<ScraperStatus | null>(null)
|
||||
const [loading, setLoading] = useState(true)
|
||||
const [scraping, setScraping] = useState(false)
|
||||
const [extracting, setExtracting] = useState(false)
|
||||
const [error, setError] = useState<string | null>(null)
|
||||
const [success, setSuccess] = useState<string | null>(null)
|
||||
const [results, setResults] = useState<ScrapeResult[]>([])
|
||||
const [pdfResult, setPdfResult] = useState<PDFExtractionResult | null>(null)
|
||||
const scraper = useComplianceScraper()
|
||||
|
||||
// Fetch sources
|
||||
const fetchSources = useCallback(async () => {
|
||||
try {
|
||||
const res = await fetch(`${BACKEND_URL}/api/v1/compliance/scraper/sources`)
|
||||
if (res.ok) {
|
||||
const data = await res.json()
|
||||
setSources(data.sources || [])
|
||||
}
|
||||
} catch (err) {
|
||||
console.error('Failed to fetch sources:', err)
|
||||
}
|
||||
}, [])
|
||||
|
||||
// Fetch PDF documents
|
||||
const fetchPdfDocuments = useCallback(async () => {
|
||||
try {
|
||||
const res = await fetch(`${BACKEND_URL}/api/v1/compliance/scraper/pdf-documents`)
|
||||
if (res.ok) {
|
||||
const data = await res.json()
|
||||
setPdfDocuments(data.documents || [])
|
||||
}
|
||||
} catch (err) {
|
||||
console.error('Failed to fetch PDF documents:', err)
|
||||
}
|
||||
}, [])
|
||||
|
||||
// Fetch status
|
||||
const fetchStatus = useCallback(async () => {
|
||||
try {
|
||||
const res = await fetch(`${BACKEND_URL}/api/v1/compliance/scraper/status`)
|
||||
if (res.ok) {
|
||||
const data = await res.json()
|
||||
setStatus(data)
|
||||
}
|
||||
} catch (err) {
|
||||
console.error('Failed to fetch status:', err)
|
||||
}
|
||||
}, [])
|
||||
|
||||
// Initial load
|
||||
useEffect(() => {
|
||||
const loadData = async () => {
|
||||
setLoading(true)
|
||||
await Promise.all([fetchSources(), fetchStatus(), fetchPdfDocuments()])
|
||||
setLoading(false)
|
||||
}
|
||||
loadData()
|
||||
}, [fetchSources, fetchStatus, fetchPdfDocuments])
|
||||
|
||||
// Poll status while scraping
|
||||
useEffect(() => {
|
||||
if (scraping) {
|
||||
const interval = setInterval(fetchStatus, 2000)
|
||||
return () => clearInterval(interval)
|
||||
}
|
||||
}, [scraping, fetchStatus])
|
||||
|
||||
// Scrape all sources
|
||||
const handleScrapeAll = async () => {
|
||||
setScraping(true)
|
||||
setError(null)
|
||||
setSuccess(null)
|
||||
setResults([])
|
||||
|
||||
try {
|
||||
const res = await fetch(`${BACKEND_URL}/api/v1/compliance/scraper/scrape-all`, {
|
||||
method: 'POST',
|
||||
})
|
||||
|
||||
if (!res.ok) {
|
||||
const data = await res.json()
|
||||
throw new Error(data.detail || 'Scraping fehlgeschlagen')
|
||||
}
|
||||
|
||||
const data = await res.json()
|
||||
setResults([
|
||||
...data.results.success,
|
||||
...data.results.failed,
|
||||
...data.results.skipped,
|
||||
])
|
||||
setSuccess(`Scraping abgeschlossen: ${data.results.success.length} erfolgreich, ${data.results.skipped.length} uebersprungen, ${data.results.failed.length} fehlgeschlagen`)
|
||||
|
||||
// Refresh sources
|
||||
await fetchSources()
|
||||
} catch (err: any) {
|
||||
setError(err.message)
|
||||
} finally {
|
||||
setScraping(false)
|
||||
}
|
||||
}
|
||||
|
||||
// Scrape single source
|
||||
const handleScrapeSingle = async (code: string, force: boolean = false) => {
|
||||
setScraping(true)
|
||||
setError(null)
|
||||
setSuccess(null)
|
||||
|
||||
try {
|
||||
const res = await fetch(`${BACKEND_URL}/api/v1/compliance/scraper/scrape/${code}?force=${force}`, {
|
||||
method: 'POST',
|
||||
})
|
||||
|
||||
if (!res.ok) {
|
||||
const data = await res.json()
|
||||
throw new Error(data.detail || 'Scraping fehlgeschlagen')
|
||||
}
|
||||
|
||||
const data = await res.json()
|
||||
|
||||
if (data.status === 'skipped') {
|
||||
setSuccess(`${code}: Bereits vorhanden (${data.requirement_count} Anforderungen)`)
|
||||
} else {
|
||||
setSuccess(`${code}: ${data.requirements_extracted} Anforderungen extrahiert`)
|
||||
}
|
||||
|
||||
// Refresh sources
|
||||
await fetchSources()
|
||||
} catch (err: any) {
|
||||
setError(err.message)
|
||||
} finally {
|
||||
setScraping(false)
|
||||
}
|
||||
}
|
||||
|
||||
// Extract PDF
|
||||
const handleExtractPdf = async (code: string, saveToDb: boolean = true, force: boolean = false) => {
|
||||
setExtracting(true)
|
||||
setError(null)
|
||||
setSuccess(null)
|
||||
setPdfResult(null)
|
||||
|
||||
try {
|
||||
const res = await fetch(`${BACKEND_URL}/api/v1/compliance/scraper/extract-pdf`, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
body: JSON.stringify({
|
||||
document_code: code,
|
||||
save_to_db: saveToDb,
|
||||
force: force,
|
||||
}),
|
||||
})
|
||||
|
||||
if (!res.ok) {
|
||||
const data = await res.json()
|
||||
throw new Error(data.detail || 'PDF-Extraktion fehlgeschlagen')
|
||||
}
|
||||
|
||||
const data: PDFExtractionResult = await res.json()
|
||||
setPdfResult(data)
|
||||
|
||||
if (data.success) {
|
||||
setSuccess(`${code}: ${data.total_aspects} Pruefaspekte extrahiert, ${data.requirements_created} Requirements erstellt`)
|
||||
}
|
||||
|
||||
// Refresh sources
|
||||
await fetchSources()
|
||||
} catch (err: any) {
|
||||
setError(err.message)
|
||||
} finally {
|
||||
setExtracting(false)
|
||||
}
|
||||
}
|
||||
|
||||
// Clear messages
|
||||
useEffect(() => {
|
||||
if (success) {
|
||||
const timer = setTimeout(() => setSuccess(null), 5000)
|
||||
return () => clearTimeout(timer)
|
||||
}
|
||||
}, [success])
|
||||
|
||||
useEffect(() => {
|
||||
if (error) {
|
||||
const timer = setTimeout(() => setError(null), 10000)
|
||||
return () => clearTimeout(timer)
|
||||
}
|
||||
}, [error])
|
||||
|
||||
// Stats cards
|
||||
const StatsCard = ({ title, value, subtitle, icon }: { title: string; value: number | string; subtitle?: string; icon: string }) => (
|
||||
<div className="bg-white rounded-lg shadow-sm p-5 border border-slate-200">
|
||||
<div className="flex items-center">
|
||||
<div className="flex-shrink-0">
|
||||
<span className="text-2xl">{icon}</span>
|
||||
</div>
|
||||
<div className="flex-shrink-0"><span className="text-2xl">{icon}</span></div>
|
||||
<div className="ml-4">
|
||||
<p className="text-sm font-medium text-slate-500">{title}</p>
|
||||
<p className="text-2xl font-semibold text-slate-900">{value}</p>
|
||||
@@ -294,12 +31,8 @@ export default function ComplianceScraperPage() {
|
||||
)
|
||||
|
||||
return (
|
||||
<AdminLayout
|
||||
title="Compliance Scraper"
|
||||
description="Extrahiert Anforderungen aus EU-Regulierungen, BSI-Standards und Gesetzen"
|
||||
>
|
||||
{/* Loading */}
|
||||
{loading && (
|
||||
<AdminLayout title="Compliance Scraper" description="Extrahiert Anforderungen aus EU-Regulierungen, BSI-Standards und Gesetzen">
|
||||
{scraper.loading && (
|
||||
<div className="flex items-center justify-center py-12">
|
||||
<svg className="w-8 h-8 animate-spin text-primary-600" fill="none" viewBox="0 0 24 24">
|
||||
<circle className="opacity-25" cx="12" cy="12" r="10" stroke="currentColor" strokeWidth="4" />
|
||||
@@ -309,481 +42,77 @@ export default function ComplianceScraperPage() {
|
||||
</div>
|
||||
)}
|
||||
|
||||
{!loading && (
|
||||
{!scraper.loading && (
|
||||
<>
|
||||
{/* Messages */}
|
||||
{error && (
|
||||
<div className="mb-4 bg-red-50 border border-red-200 text-red-700 px-4 py-3 rounded-lg">
|
||||
{error}
|
||||
</div>
|
||||
{scraper.error && (
|
||||
<div className="mb-4 bg-red-50 border border-red-200 text-red-700 px-4 py-3 rounded-lg">{scraper.error}</div>
|
||||
)}
|
||||
{success && (
|
||||
<div className="mb-4 bg-green-50 border border-green-200 text-green-700 px-4 py-3 rounded-lg">
|
||||
{success}
|
||||
</div>
|
||||
{scraper.success && (
|
||||
<div className="mb-4 bg-green-50 border border-green-200 text-green-700 px-4 py-3 rounded-lg">{scraper.success}</div>
|
||||
)}
|
||||
|
||||
{/* Stats Cards */}
|
||||
<div className="grid grid-cols-1 md:grid-cols-2 lg:grid-cols-4 gap-4 mb-6">
|
||||
<StatsCard
|
||||
title="Bekannte Quellen"
|
||||
value={sources.length}
|
||||
icon="📚"
|
||||
/>
|
||||
<StatsCard
|
||||
title="Mit Daten"
|
||||
value={sources.filter(s => s.has_data).length}
|
||||
subtitle={`${sources.length - sources.filter(s => s.has_data).length} noch zu scrapen`}
|
||||
icon="✅"
|
||||
/>
|
||||
<StatsCard
|
||||
title="Anforderungen gesamt"
|
||||
value={sources.reduce((acc, s) => acc + s.requirement_count, 0)}
|
||||
icon="📋"
|
||||
/>
|
||||
<StatsCard
|
||||
title="Letzter Lauf"
|
||||
value={status?.stats.last_run ? new Date(status.stats.last_run).toLocaleDateString('de-DE') : 'Nie'}
|
||||
subtitle={status?.stats.errors ? `${status.stats.errors} Fehler` : undefined}
|
||||
icon="🕐"
|
||||
/>
|
||||
<StatsCard title="Bekannte Quellen" value={scraper.sources.length} icon="📚" />
|
||||
<StatsCard title="Mit Daten" value={scraper.sources.filter(s => s.has_data).length} subtitle={`${scraper.sources.length - scraper.sources.filter(s => s.has_data).length} noch zu scrapen`} icon="✅" />
|
||||
<StatsCard title="Anforderungen gesamt" value={scraper.sources.reduce((acc, s) => acc + s.requirement_count, 0)} icon="📋" />
|
||||
<StatsCard title="Letzter Lauf" value={scraper.status?.stats.last_run ? new Date(scraper.status.stats.last_run).toLocaleDateString('de-DE') : 'Nie'} subtitle={scraper.status?.stats.errors ? `${scraper.status.stats.errors} Fehler` : undefined} icon="🕐" />
|
||||
</div>
|
||||
|
||||
{/* Scraper Status Bar */}
|
||||
{(scraping || status?.status === 'running') && (
|
||||
{(scraper.scraping || scraper.status?.status === 'running') && (
|
||||
<div className="mb-6 p-4 bg-blue-50 border border-blue-200 rounded-lg">
|
||||
<div className="flex items-center">
|
||||
<div className="animate-spin rounded-full h-4 w-4 border-2 border-blue-600 border-t-transparent mr-3" />
|
||||
<div>
|
||||
<p className="font-medium text-blue-800">Scraper laeuft</p>
|
||||
{status?.current_source && (
|
||||
<p className="text-sm text-blue-600">Aktuell: {status.current_source}</p>
|
||||
)}
|
||||
{scraper.status?.current_source && (<p className="text-sm text-blue-600">Aktuell: {scraper.status.current_source}</p>)}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Tabs */}
|
||||
<div className="bg-white rounded-xl shadow-sm border border-slate-200 mb-6">
|
||||
<div className="border-b border-slate-200">
|
||||
<nav className="flex -mb-px">
|
||||
{[
|
||||
{ id: 'sources', name: 'Quellen', icon: '📚' },
|
||||
{ id: 'pdf', name: 'PDF-Extraktion', icon: '📄' },
|
||||
{ id: 'status', name: 'Status', icon: '📊' },
|
||||
{ id: 'logs', name: 'Ergebnisse', icon: '📝' },
|
||||
{ id: 'sources' as const, name: 'Quellen', icon: '📚' },
|
||||
{ id: 'pdf' as const, name: 'PDF-Extraktion', icon: '📄' },
|
||||
{ id: 'status' as const, name: 'Status', icon: '📊' },
|
||||
{ id: 'logs' as const, name: 'Ergebnisse', icon: '📝' },
|
||||
].map(tab => (
|
||||
<button
|
||||
key={tab.id}
|
||||
onClick={() => setActiveTab(tab.id as typeof activeTab)}
|
||||
className={`px-6 py-4 text-sm font-medium border-b-2 transition-colors ${
|
||||
activeTab === tab.id
|
||||
? 'border-primary-600 text-primary-600'
|
||||
: 'border-transparent text-slate-500 hover:text-slate-700 hover:border-slate-300'
|
||||
}`}
|
||||
>
|
||||
<span className="mr-2">{tab.icon}</span>
|
||||
{tab.name}
|
||||
<button key={tab.id} onClick={() => scraper.setActiveTab(tab.id)} className={`px-6 py-4 text-sm font-medium border-b-2 transition-colors ${scraper.activeTab === tab.id ? 'border-primary-600 text-primary-600' : 'border-transparent text-slate-500 hover:text-slate-700 hover:border-slate-300'}`}>
|
||||
<span className="mr-2">{tab.icon}</span>{tab.name}
|
||||
</button>
|
||||
))}
|
||||
</nav>
|
||||
</div>
|
||||
|
||||
<div className="p-6">
|
||||
{/* Sources Tab */}
|
||||
{activeTab === 'sources' && (
|
||||
<div>
|
||||
{/* Header */}
|
||||
<div className="flex justify-between items-center mb-6">
|
||||
<div>
|
||||
<h3 className="text-lg font-semibold text-slate-900">Regulierungsquellen</h3>
|
||||
<p className="text-sm text-slate-500">EU-Lex, BSI-TR und deutsche Gesetze</p>
|
||||
</div>
|
||||
<button
|
||||
onClick={handleScrapeAll}
|
||||
disabled={scraping}
|
||||
className="px-4 py-2 bg-primary-600 text-white rounded-lg hover:bg-primary-700 transition-colors disabled:opacity-50 disabled:cursor-not-allowed flex items-center gap-2"
|
||||
>
|
||||
{scraping ? (
|
||||
<>
|
||||
<svg className="w-4 h-4 animate-spin" fill="none" viewBox="0 0 24 24">
|
||||
<circle className="opacity-25" cx="12" cy="12" r="10" stroke="currentColor" strokeWidth="4" />
|
||||
<path className="opacity-75" fill="currentColor" d="M4 12a8 8 0 018-8V0C5.373 0 0 5.373 0 12h4z" />
|
||||
</svg>
|
||||
Laeuft...
|
||||
</>
|
||||
) : (
|
||||
<>
|
||||
<svg className="w-4 h-4" fill="none" stroke="currentColor" viewBox="0 0 24 24">
|
||||
<path strokeLinecap="round" strokeLinejoin="round" strokeWidth={2} d="M4 4v5h.582m15.356 2A8.001 8.001 0 004.582 9m0 0H9m11 11v-5h-.581m0 0a8.003 8.003 0 01-15.357-2m15.357 2H15" />
|
||||
</svg>
|
||||
Alle Quellen scrapen
|
||||
</>
|
||||
)}
|
||||
</button>
|
||||
</div>
|
||||
|
||||
{/* Sources by Type */}
|
||||
<div className="space-y-6">
|
||||
{/* EU Regulations */}
|
||||
<div>
|
||||
<h4 className="text-sm font-medium text-slate-700 mb-3 flex items-center gap-2">
|
||||
<span className="text-lg">🇪🇺</span> EU-Regulierungen (EUR-Lex)
|
||||
</h4>
|
||||
<div className="grid gap-3">
|
||||
{sources.filter(s => s.source_type === 'eur_lex').map(source => (
|
||||
<SourceCard key={source.code} source={source} onScrape={handleScrapeSingle} scraping={scraping} />
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* BSI Standards */}
|
||||
<div>
|
||||
<h4 className="text-sm font-medium text-slate-700 mb-3 flex items-center gap-2">
|
||||
<span className="text-lg">🔒</span> BSI Technical Guidelines
|
||||
</h4>
|
||||
<div className="grid gap-3">
|
||||
{sources.filter(s => s.source_type === 'bsi_pdf').map(source => (
|
||||
<SourceCard key={source.code} source={source} onScrape={handleScrapeSingle} scraping={scraping} />
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* PDF Extraction Tab */}
|
||||
{activeTab === 'pdf' && (
|
||||
<div>
|
||||
<div className="mb-6">
|
||||
<h3 className="text-lg font-semibold text-slate-900">PDF-Extraktion (PyMuPDF)</h3>
|
||||
<p className="text-sm text-slate-500">
|
||||
Extrahiert ALLE Pruefaspekte aus BSI-TR-03161 PDFs mit Regex-Pattern-Matching
|
||||
</p>
|
||||
</div>
|
||||
|
||||
{/* PDF Documents */}
|
||||
<div className="space-y-4">
|
||||
{pdfDocuments.map(doc => (
|
||||
<div key={doc.code} className="bg-slate-50 rounded-lg p-4 border border-slate-200">
|
||||
<div className="flex items-center justify-between">
|
||||
<div className="flex items-center gap-3">
|
||||
<span className="text-3xl">📄</span>
|
||||
<div>
|
||||
<div className="flex items-center gap-2">
|
||||
<span className="font-semibold text-slate-900">{doc.code}</span>
|
||||
<span className={`px-2 py-0.5 rounded text-xs font-medium ${
|
||||
doc.available ? 'bg-green-100 text-green-700' : 'bg-red-100 text-red-700'
|
||||
}`}>
|
||||
{doc.available ? 'Verfuegbar' : 'Nicht gefunden'}
|
||||
</span>
|
||||
</div>
|
||||
<div className="text-sm text-slate-600">{doc.name}</div>
|
||||
<div className="text-xs text-slate-500">{doc.description}</div>
|
||||
<div className="text-xs text-slate-400 mt-1">
|
||||
Erwartete Pruefaspekte: {doc.expected_aspects}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="flex gap-2">
|
||||
<button
|
||||
onClick={() => handleExtractPdf(doc.code, true, false)}
|
||||
disabled={extracting || !doc.available}
|
||||
className="px-4 py-2 bg-primary-600 text-white rounded-lg hover:bg-primary-700 transition-colors disabled:opacity-50 disabled:cursor-not-allowed flex items-center gap-2"
|
||||
>
|
||||
{extracting ? (
|
||||
<>
|
||||
<svg className="w-4 h-4 animate-spin" fill="none" viewBox="0 0 24 24">
|
||||
<circle className="opacity-25" cx="12" cy="12" r="10" stroke="currentColor" strokeWidth="4" />
|
||||
<path className="opacity-75" fill="currentColor" d="M4 12a8 8 0 018-8V0C5.373 0 0 5.373 0 12h4z" />
|
||||
</svg>
|
||||
Extrahiere...
|
||||
</>
|
||||
) : (
|
||||
<>
|
||||
<svg className="w-4 h-4" fill="none" stroke="currentColor" viewBox="0 0 24 24">
|
||||
<path strokeLinecap="round" strokeLinejoin="round" strokeWidth={2} d="M9 12h6m-6 4h6m2 5H7a2 2 0 01-2-2V5a2 2 0 012-2h5.586a1 1 0 01.707.293l5.414 5.414a1 1 0 01.293.707V19a2 2 0 01-2 2z" />
|
||||
</svg>
|
||||
Extrahieren
|
||||
</>
|
||||
)}
|
||||
</button>
|
||||
<button
|
||||
onClick={() => handleExtractPdf(doc.code, true, true)}
|
||||
disabled={extracting || !doc.available}
|
||||
className="px-3 py-2 bg-orange-100 text-orange-700 rounded-lg hover:bg-orange-200 transition-colors disabled:opacity-50 disabled:cursor-not-allowed"
|
||||
title="Force: Loescht vorhandene und extrahiert neu"
|
||||
>
|
||||
Force
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
|
||||
{/* Last Extraction Result */}
|
||||
{pdfResult && (
|
||||
<div className="mt-6 bg-green-50 rounded-lg p-4 border border-green-200">
|
||||
<h4 className="font-semibold text-green-800 mb-3">Letztes Extraktions-Ergebnis</h4>
|
||||
<div className="grid grid-cols-3 gap-4 mb-4">
|
||||
<div className="text-center p-3 bg-white rounded-lg">
|
||||
<div className="text-2xl font-bold text-green-700">{pdfResult.total_aspects}</div>
|
||||
<div className="text-sm text-slate-500">Pruefaspekte gefunden</div>
|
||||
</div>
|
||||
<div className="text-center p-3 bg-white rounded-lg">
|
||||
<div className="text-2xl font-bold text-blue-700">{pdfResult.requirements_created}</div>
|
||||
<div className="text-sm text-slate-500">Requirements erstellt</div>
|
||||
</div>
|
||||
<div className="text-center p-3 bg-white rounded-lg">
|
||||
<div className="text-2xl font-bold text-slate-700">{Object.keys(pdfResult.statistics.by_category || {}).length}</div>
|
||||
<div className="text-sm text-slate-500">Kategorien</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Category Breakdown */}
|
||||
{pdfResult.statistics.by_category && Object.keys(pdfResult.statistics.by_category).length > 0 && (
|
||||
<div>
|
||||
<h5 className="text-sm font-medium text-slate-700 mb-2">Nach Kategorie:</h5>
|
||||
<div className="flex flex-wrap gap-2">
|
||||
{Object.entries(pdfResult.statistics.by_category).map(([cat, count]) => (
|
||||
<span key={cat} className="px-2 py-1 bg-white rounded text-xs text-slate-600">
|
||||
{cat}: <strong>{count}</strong>
|
||||
</span>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Info Box */}
|
||||
<div className="mt-6 bg-blue-50 rounded-lg p-4 border border-blue-200">
|
||||
<h4 className="font-semibold text-blue-800 mb-2">Wie funktioniert die PDF-Extraktion?</h4>
|
||||
<ul className="text-sm text-blue-700 space-y-1">
|
||||
<li>• <strong>PyMuPDF (fitz)</strong> liest den PDF-Text</li>
|
||||
<li>• <strong>Regex-Pattern</strong> finden Aspekte wie O.Auth_1, O.Sess_2, T.Network_1</li>
|
||||
<li>• <strong>Kontextanalyse</strong> extrahiert Titel, Kategorie und Anforderungsstufe (MUSS/SOLL/KANN)</li>
|
||||
<li>• <strong>Automatische Speicherung</strong> erstellt Requirements in der Datenbank</li>
|
||||
</ul>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Status Tab */}
|
||||
{activeTab === 'status' && status && (
|
||||
<div className="space-y-6">
|
||||
{/* Current Status */}
|
||||
<div className="bg-slate-50 rounded-lg p-6">
|
||||
<div className="flex items-center justify-between mb-4">
|
||||
<div>
|
||||
<h3 className="text-lg font-semibold text-slate-900">Scraper-Status</h3>
|
||||
<p className="text-sm text-slate-500">
|
||||
Letzter Lauf: {status.stats.last_run ? new Date(status.stats.last_run).toLocaleString('de-DE') : 'Noch nie'}
|
||||
</p>
|
||||
</div>
|
||||
<div className={`px-3 py-1.5 rounded-full text-sm font-medium ${
|
||||
status.status === 'running' ? 'bg-blue-100 text-blue-700' :
|
||||
status.status === 'error' ? 'bg-red-100 text-red-700' :
|
||||
status.status === 'completed' ? 'bg-green-100 text-green-700' :
|
||||
'bg-gray-100 text-gray-700'
|
||||
}`}>
|
||||
{status.status === 'running' ? '🔄 Laeuft' :
|
||||
status.status === 'error' ? '❌ Fehler' :
|
||||
status.status === 'completed' ? '✅ Abgeschlossen' :
|
||||
'⏸️ Bereit'}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="grid grid-cols-3 gap-4">
|
||||
<div className="text-center p-4 bg-white rounded-lg">
|
||||
<div className="text-2xl font-bold text-slate-900">{status.stats.sources_processed}</div>
|
||||
<div className="text-sm text-slate-500">Quellen verarbeitet</div>
|
||||
</div>
|
||||
<div className="text-center p-4 bg-white rounded-lg">
|
||||
<div className="text-2xl font-bold text-green-600">{status.stats.requirements_extracted}</div>
|
||||
<div className="text-sm text-slate-500">Anforderungen extrahiert</div>
|
||||
</div>
|
||||
<div className="text-center p-4 bg-white rounded-lg">
|
||||
<div className="text-2xl font-bold text-red-600">{status.stats.errors}</div>
|
||||
<div className="text-sm text-slate-500">Fehler</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{status.last_error && (
|
||||
<div className="mt-4 p-3 bg-red-50 rounded-lg text-sm text-red-700">
|
||||
<strong>Letzter Fehler:</strong> {status.last_error}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{/* Process Description */}
|
||||
<div className="bg-white border border-slate-200 rounded-lg p-6">
|
||||
<h4 className="font-semibold text-slate-900 mb-4">Wie funktioniert der Scraper?</h4>
|
||||
<div className="space-y-3 text-sm text-slate-600">
|
||||
<div className="flex items-start gap-3">
|
||||
<div className="w-6 h-6 bg-blue-100 rounded-full flex items-center justify-center text-blue-600 font-bold">1</div>
|
||||
<div>
|
||||
<strong>EUR-Lex Abruf</strong>: Holt HTML-Version der EU-Verordnung, extrahiert Artikel und Absaetze
|
||||
</div>
|
||||
</div>
|
||||
<div className="flex items-start gap-3">
|
||||
<div className="w-6 h-6 bg-blue-100 rounded-full flex items-center justify-center text-blue-600 font-bold">2</div>
|
||||
<div>
|
||||
<strong>BSI-TR Parsing</strong>: Extrahiert Pruefaspekte (O.Auth_1, O.Sess_1, etc.) aus den TR-Dokumenten
|
||||
</div>
|
||||
</div>
|
||||
<div className="flex items-start gap-3">
|
||||
<div className="w-6 h-6 bg-blue-100 rounded-full flex items-center justify-center text-blue-600 font-bold">3</div>
|
||||
<div>
|
||||
<strong>Datenbank-Speicherung</strong>: Jede Anforderung wird als Requirement in der Compliance-DB gespeichert
|
||||
</div>
|
||||
</div>
|
||||
<div className="flex items-start gap-3">
|
||||
<div className="w-6 h-6 bg-green-100 rounded-full flex items-center justify-center text-green-600 font-bold">✓</div>
|
||||
<div>
|
||||
<strong>Audit-Workspace</strong>: Anforderungen koennen mit Implementierungsdetails angereichert werden
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Results Tab */}
|
||||
{activeTab === 'logs' && (
|
||||
<div>
|
||||
<h3 className="text-lg font-semibold text-slate-900 mb-4">Letzte Ergebnisse</h3>
|
||||
|
||||
{results.length === 0 ? (
|
||||
<div className="text-center py-12 text-slate-500">
|
||||
Keine Ergebnisse vorhanden. Starte einen Scrape-Vorgang.
|
||||
</div>
|
||||
) : (
|
||||
<div className="space-y-2">
|
||||
{results.map((result, idx) => (
|
||||
<div
|
||||
key={idx}
|
||||
className={`p-3 rounded-lg flex items-center justify-between ${
|
||||
result.error ? 'bg-red-50' :
|
||||
result.reason ? 'bg-yellow-50' :
|
||||
'bg-green-50'
|
||||
}`}
|
||||
>
|
||||
<div className="flex items-center gap-3">
|
||||
<span className="text-lg">
|
||||
{result.error ? '❌' : result.reason ? '⏭️' : '✅'}
|
||||
</span>
|
||||
<span className="font-medium">{result.code}</span>
|
||||
<span className="text-sm text-slate-500">
|
||||
{result.error || result.reason || `${result.requirements_extracted} Anforderungen`}
|
||||
</span>
|
||||
</div>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
<ScraperTabs
|
||||
activeTab={scraper.activeTab}
|
||||
sources={scraper.sources}
|
||||
pdfDocuments={scraper.pdfDocuments}
|
||||
status={scraper.status}
|
||||
scraping={scraper.scraping}
|
||||
extracting={scraper.extracting}
|
||||
results={scraper.results}
|
||||
pdfResult={scraper.pdfResult}
|
||||
handleScrapeAll={scraper.handleScrapeAll}
|
||||
handleScrapeSingle={scraper.handleScrapeSingle}
|
||||
handleExtractPdf={scraper.handleExtractPdf}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
</>
|
||||
)}
|
||||
|
||||
{/* System Info Section */}
|
||||
<div className="mt-8 border-t border-slate-200 pt-8">
|
||||
<SystemInfoSection config={SYSTEM_INFO_CONFIGS.complianceScraper || {
|
||||
title: 'Compliance Scraper',
|
||||
description: 'Regulation & Requirements Extraction Service',
|
||||
version: '1.0.0',
|
||||
features: [
|
||||
'EUR-Lex HTML Parsing',
|
||||
'BSI-TR PDF Extraction',
|
||||
'Automatic Requirement Mapping',
|
||||
'Incremental Updates',
|
||||
],
|
||||
technicalDetails: {
|
||||
'Backend': 'Python/FastAPI',
|
||||
'HTTP Client': 'httpx async',
|
||||
'HTML Parser': 'BeautifulSoup4',
|
||||
'PDF Parser': 'PyMuPDF (optional)',
|
||||
'Database': 'PostgreSQL',
|
||||
},
|
||||
features: ['EUR-Lex HTML Parsing', 'BSI-TR PDF Extraction', 'Automatic Requirement Mapping', 'Incremental Updates'],
|
||||
technicalDetails: { 'Backend': 'Python/FastAPI', 'HTTP Client': 'httpx async', 'HTML Parser': 'BeautifulSoup4', 'PDF Parser': 'PyMuPDF (optional)', 'Database': 'PostgreSQL' },
|
||||
}} />
|
||||
</div>
|
||||
</AdminLayout>
|
||||
)
|
||||
}
|
||||
|
||||
// Source Card Component
|
||||
function SourceCard({
|
||||
source,
|
||||
onScrape,
|
||||
scraping
|
||||
}: {
|
||||
source: Source
|
||||
onScrape: (code: string, force: boolean) => void
|
||||
scraping: boolean
|
||||
}) {
|
||||
const regType = regulationTypeBadge[source.regulation_type] || regulationTypeBadge.industry_standard
|
||||
const srcType = sourceTypeBadge[source.source_type] || sourceTypeBadge.manual
|
||||
|
||||
return (
|
||||
<div className="bg-white border border-slate-200 rounded-lg p-4 hover:shadow-sm transition-shadow">
|
||||
<div className="flex items-center justify-between">
|
||||
<div className="flex items-center gap-3">
|
||||
<span className="text-2xl">{regType.icon}</span>
|
||||
<div>
|
||||
<div className="flex items-center gap-2">
|
||||
<span className="font-semibold text-slate-900">{source.code}</span>
|
||||
<span className={`px-2 py-0.5 rounded text-xs font-medium ${regType.color}`}>
|
||||
{regType.label}
|
||||
</span>
|
||||
<span className={`px-2 py-0.5 rounded text-xs font-medium ${srcType.color}`}>
|
||||
{srcType.label}
|
||||
</span>
|
||||
</div>
|
||||
<div className="text-sm text-slate-500 truncate max-w-md" title={source.url}>
|
||||
{source.url.length > 60 ? source.url.substring(0, 60) + '...' : source.url}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="flex items-center gap-3">
|
||||
{source.has_data ? (
|
||||
<span className="px-3 py-1 bg-green-100 text-green-700 rounded-full text-sm font-medium">
|
||||
{source.requirement_count} Anforderungen
|
||||
</span>
|
||||
) : (
|
||||
<span className="px-3 py-1 bg-gray-100 text-gray-500 rounded-full text-sm">
|
||||
Keine Daten
|
||||
</span>
|
||||
)}
|
||||
|
||||
<div className="flex gap-1">
|
||||
<button
|
||||
onClick={() => onScrape(source.code, false)}
|
||||
disabled={scraping}
|
||||
className="px-3 py-1.5 text-sm bg-slate-100 text-slate-700 rounded hover:bg-slate-200 transition-colors disabled:opacity-50 disabled:cursor-not-allowed"
|
||||
title="Scrapen (ueberspringt vorhandene)"
|
||||
>
|
||||
Scrapen
|
||||
</button>
|
||||
{source.has_data && (
|
||||
<button
|
||||
onClick={() => onScrape(source.code, true)}
|
||||
disabled={scraping}
|
||||
className="px-3 py-1.5 text-sm bg-orange-100 text-orange-700 rounded hover:bg-orange-200 transition-colors disabled:opacity-50 disabled:cursor-not-allowed"
|
||||
title="Force: Loescht vorhandene Daten und scraped neu"
|
||||
>
|
||||
Force
|
||||
</button>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
@@ -0,0 +1,184 @@
|
||||
'use client'
|
||||
|
||||
import { WebsiteContent, FeatureContent } from '@/lib/content-types'
|
||||
import HeroEditor from './HeroEditor'
|
||||
|
||||
interface ContentEditorTabsProps {
|
||||
activeTab: string
|
||||
content: WebsiteContent
|
||||
isRTL: boolean
|
||||
t: (key: string) => string
|
||||
updateHero: (field: any, value: string) => void
|
||||
updateFeature: (index: number, field: keyof FeatureContent, value: string) => void
|
||||
updateFAQ: (index: number, field: 'question' | 'answer', value: string | string[]) => void
|
||||
addFAQ: () => void
|
||||
removeFAQ: (index: number) => void
|
||||
updatePricing: (index: number, field: string, value: string | number | boolean) => void
|
||||
updateTrust: (key: 'item1' | 'item2' | 'item3', field: 'value' | 'label', value: string) => void
|
||||
updateTestimonial: (field: 'quote' | 'author' | 'role', value: string) => void
|
||||
}
|
||||
|
||||
export default function ContentEditorTabs(props: ContentEditorTabsProps) {
|
||||
const { activeTab, content, isRTL, t } = props
|
||||
const dir = isRTL ? 'rtl' : 'ltr'
|
||||
|
||||
if (activeTab === 'hero') {
|
||||
return <HeroEditor content={content} isRTL={isRTL} t={t} updateHero={props.updateHero} />
|
||||
}
|
||||
|
||||
if (activeTab === 'features') {
|
||||
return (
|
||||
<div className="space-y-6">
|
||||
<h2 className="text-xl font-semibold text-slate-900">{t('admin_tab_features')}</h2>
|
||||
{content.features.map((feature, index) => (
|
||||
<div key={feature.id} className="border border-slate-200 rounded-lg p-4">
|
||||
<div className="grid gap-4">
|
||||
<div className="grid grid-cols-3 gap-4">
|
||||
<div>
|
||||
<label className="block text-sm font-medium text-slate-700 mb-1">Icon</label>
|
||||
<input type="text" value={feature.icon} onChange={(e) => props.updateFeature(index, 'icon', e.target.value)} className="w-full px-3 py-2 border border-slate-300 rounded-lg text-2xl text-center" />
|
||||
</div>
|
||||
<div className="col-span-2">
|
||||
<label className="block text-sm font-medium text-slate-700 mb-1">Titel</label>
|
||||
<input type="text" value={feature.title} onChange={(e) => props.updateFeature(index, 'title', e.target.value)} className="w-full px-3 py-2 border border-slate-300 rounded-lg focus:ring-2 focus:ring-primary-500 focus:border-primary-500" dir={dir} />
|
||||
</div>
|
||||
</div>
|
||||
<div>
|
||||
<label className="block text-sm font-medium text-slate-700 mb-1">Beschreibung</label>
|
||||
<textarea value={feature.description} onChange={(e) => props.updateFeature(index, 'description', e.target.value)} rows={2} className="w-full px-3 py-2 border border-slate-300 rounded-lg focus:ring-2 focus:ring-primary-500 focus:border-primary-500" dir={dir} />
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
if (activeTab === 'faq') {
|
||||
return (
|
||||
<div className="space-y-6">
|
||||
<div className={`flex items-center justify-between ${isRTL ? 'flex-row-reverse' : ''}`}>
|
||||
<h2 className="text-xl font-semibold text-slate-900">{t('admin_tab_faq')}</h2>
|
||||
<button onClick={props.addFAQ} className="px-4 py-2 bg-slate-100 text-slate-700 rounded-lg hover:bg-slate-200 transition-colors">{t('admin_add_faq')}</button>
|
||||
</div>
|
||||
{content.faq.map((item, index) => (
|
||||
<div key={index} className="border border-slate-200 rounded-lg p-4">
|
||||
<div className={`flex items-start justify-between gap-4 ${isRTL ? 'flex-row-reverse' : ''}`}>
|
||||
<div className="flex-1 space-y-4">
|
||||
<div>
|
||||
<label className="block text-sm font-medium text-slate-700 mb-1">{t('admin_question')} {index + 1}</label>
|
||||
<input type="text" value={item.question} onChange={(e) => props.updateFAQ(index, 'question', e.target.value)} className="w-full px-3 py-2 border border-slate-300 rounded-lg focus:ring-2 focus:ring-primary-500 focus:border-primary-500" dir={dir} />
|
||||
</div>
|
||||
<div>
|
||||
<label className="block text-sm font-medium text-slate-700 mb-1">{t('admin_answer')}</label>
|
||||
<textarea value={item.answer.join('\n')} onChange={(e) => props.updateFAQ(index, 'answer', e.target.value)} rows={4} className="w-full px-3 py-2 border border-slate-300 rounded-lg focus:ring-2 focus:ring-primary-500 focus:border-primary-500 font-mono text-sm" dir={dir} />
|
||||
</div>
|
||||
</div>
|
||||
<button onClick={() => props.removeFAQ(index)} className="p-2 text-red-600 hover:bg-red-50 rounded-lg transition-colors" title="Frage entfernen">
|
||||
<svg className="w-5 h-5" fill="none" viewBox="0 0 24 24" stroke="currentColor">
|
||||
<path strokeLinecap="round" strokeLinejoin="round" strokeWidth={2} d="M19 7l-.867 12.142A2 2 0 0116.138 21H7.862a2 2 0 01-1.995-1.858L5 7m5 4v6m4-6v6m1-10V4a1 1 0 00-1-1h-4a1 1 0 00-1 1v3M4 7h16" />
|
||||
</svg>
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
if (activeTab === 'pricing') {
|
||||
return (
|
||||
<div className="space-y-6">
|
||||
<h2 className="text-xl font-semibold text-slate-900">{t('admin_tab_pricing')}</h2>
|
||||
{content.pricing.map((plan, index) => (
|
||||
<div key={plan.id} className="border border-slate-200 rounded-lg p-4">
|
||||
<div className="grid gap-4">
|
||||
<div className="grid grid-cols-4 gap-4">
|
||||
<div>
|
||||
<label className="block text-sm font-medium text-slate-700 mb-1">Name</label>
|
||||
<input type="text" value={plan.name} onChange={(e) => props.updatePricing(index, 'name', e.target.value)} className="w-full px-3 py-2 border border-slate-300 rounded-lg focus:ring-2 focus:ring-primary-500 focus:border-primary-500" dir={dir} />
|
||||
</div>
|
||||
<div>
|
||||
<label className="block text-sm font-medium text-slate-700 mb-1">Preis (EUR)</label>
|
||||
<input type="number" step="0.01" value={plan.price} onChange={(e) => props.updatePricing(index, 'price', e.target.value)} className="w-full px-3 py-2 border border-slate-300 rounded-lg focus:ring-2 focus:ring-primary-500 focus:border-primary-500" />
|
||||
</div>
|
||||
<div>
|
||||
<label className="block text-sm font-medium text-slate-700 mb-1">Intervall</label>
|
||||
<input type="text" value={plan.interval} onChange={(e) => props.updatePricing(index, 'interval', e.target.value)} className="w-full px-3 py-2 border border-slate-300 rounded-lg focus:ring-2 focus:ring-primary-500 focus:border-primary-500" dir={dir} />
|
||||
</div>
|
||||
<div className="flex items-end">
|
||||
<label className="flex items-center gap-2">
|
||||
<input type="checkbox" checked={plan.popular || false} onChange={(e) => props.updatePricing(index, 'popular', e.target.checked)} className="w-4 h-4 text-primary-600 rounded" />
|
||||
<span className="text-sm text-slate-700">{t('pricing_popular')}</span>
|
||||
</label>
|
||||
</div>
|
||||
</div>
|
||||
<div>
|
||||
<label className="block text-sm font-medium text-slate-700 mb-1">Beschreibung</label>
|
||||
<input type="text" value={plan.description} onChange={(e) => props.updatePricing(index, 'description', e.target.value)} className="w-full px-3 py-2 border border-slate-300 rounded-lg focus:ring-2 focus:ring-primary-500 focus:border-primary-500" dir={dir} />
|
||||
</div>
|
||||
<div className="grid grid-cols-2 gap-4">
|
||||
<div>
|
||||
<label className="block text-sm font-medium text-slate-700 mb-1">{t('pricing_tasks')}</label>
|
||||
<input type="text" value={plan.features.tasks} onChange={(e) => props.updatePricing(index, 'features.tasks', e.target.value)} className="w-full px-3 py-2 border border-slate-300 rounded-lg focus:ring-2 focus:ring-primary-500 focus:border-primary-500" dir={dir} />
|
||||
</div>
|
||||
<div>
|
||||
<label className="block text-sm font-medium text-slate-700 mb-1">Aufgaben-Beschreibung</label>
|
||||
<input type="text" value={plan.features.taskDescription} onChange={(e) => props.updatePricing(index, 'features.taskDescription', e.target.value)} className="w-full px-3 py-2 border border-slate-300 rounded-lg focus:ring-2 focus:ring-primary-500 focus:border-primary-500" dir={dir} />
|
||||
</div>
|
||||
</div>
|
||||
<div>
|
||||
<label className="block text-sm font-medium text-slate-700 mb-1">Features (eine pro Zeile)</label>
|
||||
<textarea value={plan.features.included.join('\n')} onChange={(e) => props.updatePricing(index, 'features.included', e.target.value)} rows={4} className="w-full px-3 py-2 border border-slate-300 rounded-lg focus:ring-2 focus:ring-primary-500 focus:border-primary-500 font-mono text-sm" dir={dir} />
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
// 'other' tab
|
||||
return (
|
||||
<div className="space-y-8">
|
||||
<div>
|
||||
<h2 className="text-xl font-semibold text-slate-900 mb-4">Trust Indicators</h2>
|
||||
<div className="grid grid-cols-3 gap-4">
|
||||
{(['item1', 'item2', 'item3'] as const).map((key, index) => (
|
||||
<div key={key} className="border border-slate-200 rounded-lg p-4">
|
||||
<div className="space-y-4">
|
||||
<div>
|
||||
<label className="block text-sm font-medium text-slate-700 mb-1">Wert {index + 1}</label>
|
||||
<input type="text" value={content.trust[key].value} onChange={(e) => props.updateTrust(key, 'value', e.target.value)} className="w-full px-3 py-2 border border-slate-300 rounded-lg focus:ring-2 focus:ring-primary-500 focus:border-primary-500" dir={dir} />
|
||||
</div>
|
||||
<div>
|
||||
<label className="block text-sm font-medium text-slate-700 mb-1">Label {index + 1}</label>
|
||||
<input type="text" value={content.trust[key].label} onChange={(e) => props.updateTrust(key, 'label', e.target.value)} className="w-full px-3 py-2 border border-slate-300 rounded-lg focus:ring-2 focus:ring-primary-500 focus:border-primary-500" dir={dir} />
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
<div>
|
||||
<h2 className="text-xl font-semibold text-slate-900 mb-4">Testimonial</h2>
|
||||
<div className="border border-slate-200 rounded-lg p-4 space-y-4">
|
||||
<div>
|
||||
<label className="block text-sm font-medium text-slate-700 mb-1">Zitat</label>
|
||||
<textarea value={content.testimonial.quote} onChange={(e) => props.updateTestimonial('quote', e.target.value)} rows={3} className="w-full px-3 py-2 border border-slate-300 rounded-lg focus:ring-2 focus:ring-primary-500 focus:border-primary-500" dir={dir} />
|
||||
</div>
|
||||
<div className="grid grid-cols-2 gap-4">
|
||||
<div>
|
||||
<label className="block text-sm font-medium text-slate-700 mb-1">Autor</label>
|
||||
<input type="text" value={content.testimonial.author} onChange={(e) => props.updateTestimonial('author', e.target.value)} className="w-full px-3 py-2 border border-slate-300 rounded-lg focus:ring-2 focus:ring-primary-500 focus:border-primary-500" dir={dir} />
|
||||
</div>
|
||||
<div>
|
||||
<label className="block text-sm font-medium text-slate-700 mb-1">Rolle</label>
|
||||
<input type="text" value={content.testimonial.role} onChange={(e) => props.updateTestimonial('role', e.target.value)} className="w-full px-3 py-2 border border-slate-300 rounded-lg focus:ring-2 focus:ring-primary-500 focus:border-primary-500" dir={dir} />
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
@@ -0,0 +1,106 @@
|
||||
'use client'
|
||||
|
||||
import { WebsiteContent, HeroContent } from '@/lib/content-types'
|
||||
|
||||
interface HeroEditorProps {
|
||||
content: WebsiteContent
|
||||
isRTL: boolean
|
||||
t: (key: string) => string
|
||||
updateHero: (field: keyof HeroContent, value: string) => void
|
||||
}
|
||||
|
||||
export default function HeroEditor({ content, isRTL, t, updateHero }: HeroEditorProps) {
|
||||
const dir = isRTL ? 'rtl' : 'ltr'
|
||||
|
||||
return (
|
||||
<div className="space-y-6">
|
||||
<h2 className="text-xl font-semibold text-slate-900">{t('admin_tab_hero')} Section</h2>
|
||||
<div className="grid gap-4">
|
||||
<div>
|
||||
<label className="block text-sm font-medium text-slate-700 mb-1">Badge</label>
|
||||
<input
|
||||
type="text"
|
||||
value={content.hero.badge}
|
||||
onChange={(e) => updateHero('badge', e.target.value)}
|
||||
className="w-full px-3 py-2 border border-slate-300 rounded-lg focus:ring-2 focus:ring-primary-500 focus:border-primary-500"
|
||||
dir={dir}
|
||||
/>
|
||||
</div>
|
||||
<div>
|
||||
<label className="block text-sm font-medium text-slate-700 mb-1">Titel (vor Highlight)</label>
|
||||
<input
|
||||
type="text"
|
||||
value={content.hero.title}
|
||||
onChange={(e) => updateHero('title', e.target.value)}
|
||||
className="w-full px-3 py-2 border border-slate-300 rounded-lg focus:ring-2 focus:ring-primary-500 focus:border-primary-500"
|
||||
dir={dir}
|
||||
/>
|
||||
</div>
|
||||
<div className="grid grid-cols-2 gap-4">
|
||||
<div>
|
||||
<label className="block text-sm font-medium text-slate-700 mb-1">Highlight 1</label>
|
||||
<input
|
||||
type="text"
|
||||
value={content.hero.titleHighlight1}
|
||||
onChange={(e) => updateHero('titleHighlight1', e.target.value)}
|
||||
className="w-full px-3 py-2 border border-slate-300 rounded-lg focus:ring-2 focus:ring-primary-500 focus:border-primary-500"
|
||||
dir={dir}
|
||||
/>
|
||||
</div>
|
||||
<div>
|
||||
<label className="block text-sm font-medium text-slate-700 mb-1">Highlight 2</label>
|
||||
<input
|
||||
type="text"
|
||||
value={content.hero.titleHighlight2}
|
||||
onChange={(e) => updateHero('titleHighlight2', e.target.value)}
|
||||
className="w-full px-3 py-2 border border-slate-300 rounded-lg focus:ring-2 focus:ring-primary-500 focus:border-primary-500"
|
||||
dir={dir}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
<div>
|
||||
<label className="block text-sm font-medium text-slate-700 mb-1">Untertitel</label>
|
||||
<textarea
|
||||
value={content.hero.subtitle}
|
||||
onChange={(e) => updateHero('subtitle', e.target.value)}
|
||||
rows={3}
|
||||
className="w-full px-3 py-2 border border-slate-300 rounded-lg focus:ring-2 focus:ring-primary-500 focus:border-primary-500"
|
||||
dir={dir}
|
||||
/>
|
||||
</div>
|
||||
<div className="grid grid-cols-3 gap-4">
|
||||
<div>
|
||||
<label className="block text-sm font-medium text-slate-700 mb-1">CTA Primaer</label>
|
||||
<input
|
||||
type="text"
|
||||
value={content.hero.ctaPrimary}
|
||||
onChange={(e) => updateHero('ctaPrimary', e.target.value)}
|
||||
className="w-full px-3 py-2 border border-slate-300 rounded-lg focus:ring-2 focus:ring-primary-500 focus:border-primary-500"
|
||||
dir={dir}
|
||||
/>
|
||||
</div>
|
||||
<div>
|
||||
<label className="block text-sm font-medium text-slate-700 mb-1">CTA Sekundaer</label>
|
||||
<input
|
||||
type="text"
|
||||
value={content.hero.ctaSecondary}
|
||||
onChange={(e) => updateHero('ctaSecondary', e.target.value)}
|
||||
className="w-full px-3 py-2 border border-slate-300 rounded-lg focus:ring-2 focus:ring-primary-500 focus:border-primary-500"
|
||||
dir={dir}
|
||||
/>
|
||||
</div>
|
||||
<div>
|
||||
<label className="block text-sm font-medium text-slate-700 mb-1">CTA Hinweis</label>
|
||||
<input
|
||||
type="text"
|
||||
value={content.hero.ctaHint}
|
||||
onChange={(e) => updateHero('ctaHint', e.target.value)}
|
||||
className="w-full px-3 py-2 border border-slate-300 rounded-lg focus:ring-2 focus:ring-primary-500 focus:border-primary-500"
|
||||
dir={dir}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
@@ -0,0 +1,74 @@
|
||||
'use client'
|
||||
|
||||
import { RefObject } from 'react'
|
||||
|
||||
interface LivePreviewPanelProps {
|
||||
activeTab: string
|
||||
iframeRef: RefObject<HTMLIFrameElement | null>
|
||||
}
|
||||
|
||||
export default function LivePreviewPanel({ activeTab, iframeRef }: LivePreviewPanelProps) {
|
||||
return (
|
||||
<div className="bg-white rounded-xl border border-slate-200 shadow-sm overflow-hidden">
|
||||
{/* Preview Header */}
|
||||
<div className="bg-slate-50 border-b border-slate-200 px-4 py-3 flex items-center justify-between">
|
||||
<div className="flex items-center gap-2">
|
||||
<div className="flex gap-1.5">
|
||||
<div className="w-3 h-3 rounded-full bg-red-400"></div>
|
||||
<div className="w-3 h-3 rounded-full bg-yellow-400"></div>
|
||||
<div className="w-3 h-3 rounded-full bg-green-400"></div>
|
||||
</div>
|
||||
<span className="text-xs text-slate-500 ml-2">localhost:3000</span>
|
||||
</div>
|
||||
<div className="flex items-center gap-2">
|
||||
<span className="text-xs font-medium text-slate-600 bg-slate-200 px-2 py-1 rounded">
|
||||
{activeTab === 'hero' && 'Hero Section'}
|
||||
{activeTab === 'features' && 'Features'}
|
||||
{activeTab === 'faq' && 'FAQ'}
|
||||
{activeTab === 'pricing' && 'Pricing'}
|
||||
{activeTab === 'other' && 'Trust & Testimonial'}
|
||||
</span>
|
||||
<button
|
||||
onClick={() => iframeRef.current?.contentWindow?.location.reload()}
|
||||
className="p-1.5 text-slate-500 hover:text-slate-700 hover:bg-slate-200 rounded transition-colors"
|
||||
title="Preview neu laden"
|
||||
>
|
||||
<svg className="w-4 h-4" fill="none" viewBox="0 0 24 24" stroke="currentColor">
|
||||
<path strokeLinecap="round" strokeLinejoin="round" strokeWidth={2} d="M4 4v5h.582m15.356 2A8.001 8.001 0 004.582 9m0 0H9m11 11v-5h-.581m0 0a8.003 8.003 0 01-15.357-2m15.357 2H15" />
|
||||
</svg>
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
{/* Preview Frame */}
|
||||
<div className="relative h-[calc(100vh-340px)] bg-slate-100">
|
||||
<iframe
|
||||
ref={iframeRef}
|
||||
src={`/?preview=true§ion=${activeTab}#${activeTab}`}
|
||||
className="w-full h-full border-0 scale-75 origin-top-left"
|
||||
style={{
|
||||
width: '133.33%',
|
||||
height: '133.33%',
|
||||
transform: 'scale(0.75)',
|
||||
transformOrigin: 'top left',
|
||||
}}
|
||||
title="Website Preview"
|
||||
sandbox="allow-same-origin allow-scripts"
|
||||
/>
|
||||
<div className="absolute bottom-4 left-4 right-4 bg-blue-600 text-white px-4 py-2 rounded-lg shadow-lg flex items-center gap-2 text-sm">
|
||||
<svg className="w-4 h-4" fill="none" viewBox="0 0 24 24" stroke="currentColor">
|
||||
<path strokeLinecap="round" strokeLinejoin="round" strokeWidth={2} d="M13 16h-1v-4h-1m1-4h.01M21 12a9 9 0 11-18 0 9 9 0 0118 0z" />
|
||||
</svg>
|
||||
<span>
|
||||
Du bearbeitest: <strong>
|
||||
{activeTab === 'hero' && 'Hero Section (Startbereich)'}
|
||||
{activeTab === 'features' && 'Features (Funktionen)'}
|
||||
{activeTab === 'faq' && 'FAQ (Haeufige Fragen)'}
|
||||
{activeTab === 'pricing' && 'Pricing (Preise)'}
|
||||
{activeTab === 'other' && 'Trust & Testimonial'}
|
||||
</strong>
|
||||
</span>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
@@ -0,0 +1,11 @@
|
||||
export const ADMIN_KEY = 'breakpilot-admin-2024'
|
||||
|
||||
export const SECTION_MAP: Record<string, { selector: string; scrollTo: string }> = {
|
||||
hero: { selector: '#hero', scrollTo: 'hero' },
|
||||
features: { selector: '#features', scrollTo: 'features' },
|
||||
faq: { selector: '#faq', scrollTo: 'faq' },
|
||||
pricing: { selector: '#pricing', scrollTo: 'pricing' },
|
||||
other: { selector: '#trust', scrollTo: 'trust' },
|
||||
}
|
||||
|
||||
export type ContentTab = 'hero' | 'features' | 'faq' | 'pricing' | 'other'
|
||||
@@ -0,0 +1,173 @@
|
||||
'use client'
|
||||
|
||||
import { useState, useEffect, useRef, useCallback } from 'react'
|
||||
import { WebsiteContent, HeroContent, FeatureContent } from '@/lib/content-types'
|
||||
import { useLanguage } from '@/lib/LanguageContext'
|
||||
import { ADMIN_KEY, SECTION_MAP, ContentTab } from './types'
|
||||
|
||||
export function useContentEditor() {
|
||||
const { language, setLanguage, t, isRTL } = useLanguage()
|
||||
const [content, setContent] = useState<WebsiteContent | null>(null)
|
||||
const [loading, setLoading] = useState(true)
|
||||
const [saving, setSaving] = useState(false)
|
||||
const [message, setMessage] = useState<{ type: 'success' | 'error'; text: string } | null>(null)
|
||||
const [activeTab, setActiveTab] = useState<ContentTab>('hero')
|
||||
const [showPreview, setShowPreview] = useState(true)
|
||||
const iframeRef = useRef<HTMLIFrameElement>(null)
|
||||
|
||||
const scrollToSection = useCallback((tab: string) => {
|
||||
if (!iframeRef.current?.contentWindow) return
|
||||
const section = SECTION_MAP[tab]
|
||||
if (section) {
|
||||
try {
|
||||
iframeRef.current.contentWindow.postMessage(
|
||||
{ type: 'scrollTo', section: section.scrollTo },
|
||||
'*'
|
||||
)
|
||||
} catch {
|
||||
// Same-origin policy - fallback
|
||||
}
|
||||
}
|
||||
}, [])
|
||||
|
||||
useEffect(() => {
|
||||
scrollToSection(activeTab)
|
||||
}, [activeTab, scrollToSection])
|
||||
|
||||
useEffect(() => {
|
||||
loadContent()
|
||||
}, [])
|
||||
|
||||
async function loadContent() {
|
||||
try {
|
||||
const res = await fetch('/api/content')
|
||||
if (res.ok) {
|
||||
const data = await res.json()
|
||||
setContent(data)
|
||||
} else {
|
||||
setMessage({ type: 'error', text: t('admin_error') })
|
||||
}
|
||||
} catch (error) {
|
||||
setMessage({ type: 'error', text: t('admin_error') })
|
||||
} finally {
|
||||
setLoading(false)
|
||||
}
|
||||
}
|
||||
|
||||
async function saveChanges() {
|
||||
if (!content) return
|
||||
setSaving(true)
|
||||
setMessage(null)
|
||||
try {
|
||||
const res = await fetch('/api/content', {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
'x-admin-key': ADMIN_KEY,
|
||||
},
|
||||
body: JSON.stringify(content),
|
||||
})
|
||||
if (res.ok) {
|
||||
setMessage({ type: 'success', text: t('admin_saved') })
|
||||
} else {
|
||||
const error = await res.json()
|
||||
setMessage({ type: 'error', text: error.error || t('admin_error') })
|
||||
}
|
||||
} catch (error) {
|
||||
setMessage({ type: 'error', text: t('admin_error') })
|
||||
} finally {
|
||||
setSaving(false)
|
||||
}
|
||||
}
|
||||
|
||||
function updateHero(field: keyof HeroContent, value: string) {
|
||||
if (!content) return
|
||||
setContent({ ...content, hero: { ...content.hero, [field]: value } })
|
||||
}
|
||||
|
||||
function updateFeature(index: number, field: keyof FeatureContent, value: string) {
|
||||
if (!content) return
|
||||
const newFeatures = [...content.features]
|
||||
newFeatures[index] = { ...newFeatures[index], [field]: value }
|
||||
setContent({ ...content, features: newFeatures })
|
||||
}
|
||||
|
||||
function updateFAQ(index: number, field: 'question' | 'answer', value: string | string[]) {
|
||||
if (!content) return
|
||||
const newFAQ = [...content.faq]
|
||||
if (field === 'answer' && typeof value === 'string') {
|
||||
newFAQ[index] = { ...newFAQ[index], answer: value.split('\n') }
|
||||
} else if (field === 'question' && typeof value === 'string') {
|
||||
newFAQ[index] = { ...newFAQ[index], question: value }
|
||||
}
|
||||
setContent({ ...content, faq: newFAQ })
|
||||
}
|
||||
|
||||
function addFAQ() {
|
||||
if (!content) return
|
||||
setContent({
|
||||
...content,
|
||||
faq: [...content.faq, { question: 'Neue Frage?', answer: ['Antwort hier...'] }],
|
||||
})
|
||||
}
|
||||
|
||||
function removeFAQ(index: number) {
|
||||
if (!content) return
|
||||
const newFAQ = content.faq.filter((_, i) => i !== index)
|
||||
setContent({ ...content, faq: newFAQ })
|
||||
}
|
||||
|
||||
function updatePricing(index: number, field: string, value: string | number | boolean) {
|
||||
if (!content) return
|
||||
const newPricing = [...content.pricing]
|
||||
if (field === 'price') {
|
||||
newPricing[index] = { ...newPricing[index], price: Number(value) }
|
||||
} else if (field === 'popular') {
|
||||
newPricing[index] = { ...newPricing[index], popular: Boolean(value) }
|
||||
} else if (field.startsWith('features.')) {
|
||||
const subField = field.replace('features.', '')
|
||||
if (subField === 'included' && typeof value === 'string') {
|
||||
newPricing[index] = {
|
||||
...newPricing[index],
|
||||
features: { ...newPricing[index].features, included: value.split('\n') },
|
||||
}
|
||||
} else {
|
||||
newPricing[index] = {
|
||||
...newPricing[index],
|
||||
features: { ...newPricing[index].features, [subField]: value },
|
||||
}
|
||||
}
|
||||
} else {
|
||||
newPricing[index] = { ...newPricing[index], [field]: value }
|
||||
}
|
||||
setContent({ ...content, pricing: newPricing })
|
||||
}
|
||||
|
||||
function updateTrust(key: 'item1' | 'item2' | 'item3', field: 'value' | 'label', value: string) {
|
||||
if (!content) return
|
||||
setContent({
|
||||
...content,
|
||||
trust: { ...content.trust, [key]: { ...content.trust[key], [field]: value } },
|
||||
})
|
||||
}
|
||||
|
||||
function updateTestimonial(field: 'quote' | 'author' | 'role', value: string) {
|
||||
if (!content) return
|
||||
setContent({
|
||||
...content,
|
||||
testimonial: { ...content.testimonial, [field]: value },
|
||||
})
|
||||
}
|
||||
|
||||
return {
|
||||
language, setLanguage, t, isRTL,
|
||||
content, loading, saving, message,
|
||||
activeTab, setActiveTab,
|
||||
showPreview, setShowPreview,
|
||||
iframeRef,
|
||||
saveChanges,
|
||||
updateHero, updateFeature, updateFAQ,
|
||||
addFAQ, removeFAQ, updatePricing,
|
||||
updateTrust, updateTestimonial,
|
||||
}
|
||||
}
|
||||
@@ -4,233 +4,53 @@
|
||||
* Admin Panel fuer Website-Content
|
||||
*
|
||||
* Erlaubt das Bearbeiten aller Website-Texte:
|
||||
* - Hero Section
|
||||
* - Features
|
||||
* - FAQ
|
||||
* - Pricing
|
||||
* - Trust Indicators
|
||||
* - Testimonial
|
||||
* - Hero Section, Features, FAQ, Pricing, Trust Indicators, Testimonial
|
||||
*
|
||||
* NEU: Live-Preview der Website zeigt Kontext beim Bearbeiten
|
||||
*/
|
||||
|
||||
import { useState, useEffect, useRef, useCallback } from 'react'
|
||||
import { WebsiteContent, HeroContent, FeatureContent, FAQItem, PricingPlan } from '@/lib/content-types'
|
||||
import { useLanguage } from '@/lib/LanguageContext'
|
||||
import LanguageSelector from '@/components/LanguageSelector'
|
||||
import AdminLayout from '@/components/admin/AdminLayout'
|
||||
|
||||
// Admin Key (in Produktion via Login)
|
||||
const ADMIN_KEY = 'breakpilot-admin-2024'
|
||||
|
||||
// Mapping von Tabs zu Website-Sektionen (CSS Selektoren und Scroll-Positionen)
|
||||
const SECTION_MAP: Record<string, { selector: string; scrollTo: string }> = {
|
||||
hero: { selector: '#hero', scrollTo: 'hero' },
|
||||
features: { selector: '#features', scrollTo: 'features' },
|
||||
faq: { selector: '#faq', scrollTo: 'faq' },
|
||||
pricing: { selector: '#pricing', scrollTo: 'pricing' },
|
||||
other: { selector: '#trust', scrollTo: 'trust' },
|
||||
}
|
||||
import { useContentEditor } from './_components/useContentEditor'
|
||||
import ContentEditorTabs from './_components/ContentEditorTabs'
|
||||
import LivePreviewPanel from './_components/LivePreviewPanel'
|
||||
|
||||
export default function AdminPage() {
|
||||
const { language, setLanguage, t, isRTL } = useLanguage()
|
||||
const [content, setContent] = useState<WebsiteContent | null>(null)
|
||||
const [loading, setLoading] = useState(true)
|
||||
const [saving, setSaving] = useState(false)
|
||||
const [message, setMessage] = useState<{ type: 'success' | 'error'; text: string } | null>(null)
|
||||
const [activeTab, setActiveTab] = useState<'hero' | 'features' | 'faq' | 'pricing' | 'other'>('hero')
|
||||
const [showPreview, setShowPreview] = useState(true)
|
||||
const iframeRef = useRef<HTMLIFrameElement>(null)
|
||||
const editor = useContentEditor()
|
||||
|
||||
// Scrollt die Preview zur entsprechenden Sektion
|
||||
const scrollToSection = useCallback((tab: string) => {
|
||||
if (!iframeRef.current?.contentWindow) return
|
||||
const section = SECTION_MAP[tab]
|
||||
if (section) {
|
||||
try {
|
||||
iframeRef.current.contentWindow.postMessage(
|
||||
{ type: 'scrollTo', section: section.scrollTo },
|
||||
'*'
|
||||
)
|
||||
} catch {
|
||||
// Same-origin policy - fallback
|
||||
}
|
||||
}
|
||||
}, [])
|
||||
|
||||
// Bei Tab-Wechsel zur Sektion scrollen
|
||||
useEffect(() => {
|
||||
scrollToSection(activeTab)
|
||||
}, [activeTab, scrollToSection])
|
||||
|
||||
// Content laden
|
||||
useEffect(() => {
|
||||
loadContent()
|
||||
}, [])
|
||||
|
||||
async function loadContent() {
|
||||
try {
|
||||
const res = await fetch('/api/content')
|
||||
if (res.ok) {
|
||||
const data = await res.json()
|
||||
setContent(data)
|
||||
} else {
|
||||
setMessage({ type: 'error', text: t('admin_error') })
|
||||
}
|
||||
} catch (error) {
|
||||
setMessage({ type: 'error', text: t('admin_error') })
|
||||
} finally {
|
||||
setLoading(false)
|
||||
}
|
||||
}
|
||||
|
||||
async function saveChanges() {
|
||||
if (!content) return
|
||||
|
||||
setSaving(true)
|
||||
setMessage(null)
|
||||
|
||||
try {
|
||||
const res = await fetch('/api/content', {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
'x-admin-key': ADMIN_KEY,
|
||||
},
|
||||
body: JSON.stringify(content),
|
||||
})
|
||||
|
||||
if (res.ok) {
|
||||
setMessage({ type: 'success', text: t('admin_saved') })
|
||||
} else {
|
||||
const error = await res.json()
|
||||
setMessage({ type: 'error', text: error.error || t('admin_error') })
|
||||
}
|
||||
} catch (error) {
|
||||
setMessage({ type: 'error', text: t('admin_error') })
|
||||
} finally {
|
||||
setSaving(false)
|
||||
}
|
||||
}
|
||||
|
||||
// Hero Section updaten
|
||||
function updateHero(field: keyof HeroContent, value: string) {
|
||||
if (!content) return
|
||||
setContent({
|
||||
...content,
|
||||
hero: { ...content.hero, [field]: value },
|
||||
})
|
||||
}
|
||||
|
||||
// Feature updaten
|
||||
function updateFeature(index: number, field: keyof FeatureContent, value: string) {
|
||||
if (!content) return
|
||||
const newFeatures = [...content.features]
|
||||
newFeatures[index] = { ...newFeatures[index], [field]: value }
|
||||
setContent({ ...content, features: newFeatures })
|
||||
}
|
||||
|
||||
// FAQ updaten
|
||||
function updateFAQ(index: number, field: 'question' | 'answer', value: string | string[]) {
|
||||
if (!content) return
|
||||
const newFAQ = [...content.faq]
|
||||
if (field === 'answer' && typeof value === 'string') {
|
||||
// Split by newlines for array
|
||||
newFAQ[index] = { ...newFAQ[index], answer: value.split('\n') }
|
||||
} else if (field === 'question' && typeof value === 'string') {
|
||||
newFAQ[index] = { ...newFAQ[index], question: value }
|
||||
}
|
||||
setContent({ ...content, faq: newFAQ })
|
||||
}
|
||||
|
||||
// FAQ hinzufuegen
|
||||
function addFAQ() {
|
||||
if (!content) return
|
||||
setContent({
|
||||
...content,
|
||||
faq: [...content.faq, { question: 'Neue Frage?', answer: ['Antwort hier...'] }],
|
||||
})
|
||||
}
|
||||
|
||||
// FAQ entfernen
|
||||
function removeFAQ(index: number) {
|
||||
if (!content) return
|
||||
const newFAQ = content.faq.filter((_, i) => i !== index)
|
||||
setContent({ ...content, faq: newFAQ })
|
||||
}
|
||||
|
||||
// Pricing updaten
|
||||
function updatePricing(index: number, field: string, value: string | number | boolean) {
|
||||
if (!content) return
|
||||
const newPricing = [...content.pricing]
|
||||
if (field === 'price') {
|
||||
newPricing[index] = { ...newPricing[index], price: Number(value) }
|
||||
} else if (field === 'popular') {
|
||||
newPricing[index] = { ...newPricing[index], popular: Boolean(value) }
|
||||
} else if (field.startsWith('features.')) {
|
||||
const subField = field.replace('features.', '')
|
||||
if (subField === 'included' && typeof value === 'string') {
|
||||
newPricing[index] = {
|
||||
...newPricing[index],
|
||||
features: {
|
||||
...newPricing[index].features,
|
||||
included: value.split('\n'),
|
||||
},
|
||||
}
|
||||
} else {
|
||||
newPricing[index] = {
|
||||
...newPricing[index],
|
||||
features: {
|
||||
...newPricing[index].features,
|
||||
[subField]: value,
|
||||
},
|
||||
}
|
||||
}
|
||||
} else {
|
||||
newPricing[index] = { ...newPricing[index], [field]: value }
|
||||
}
|
||||
setContent({ ...content, pricing: newPricing })
|
||||
}
|
||||
|
||||
if (loading) {
|
||||
if (editor.loading) {
|
||||
return (
|
||||
<AdminLayout title="Übersetzungen" description="Website Content & Sprachen">
|
||||
<AdminLayout title="Uebersetzungen" description="Website Content & Sprachen">
|
||||
<div className="flex items-center justify-center py-12">
|
||||
<div className="text-xl text-slate-600">{t('admin_loading')}</div>
|
||||
<div className="text-xl text-slate-600">{editor.t('admin_loading')}</div>
|
||||
</div>
|
||||
</AdminLayout>
|
||||
)
|
||||
}
|
||||
|
||||
if (!content) {
|
||||
if (!editor.content) {
|
||||
return (
|
||||
<AdminLayout title="Übersetzungen" description="Website Content & Sprachen">
|
||||
<AdminLayout title="Uebersetzungen" description="Website Content & Sprachen">
|
||||
<div className="flex items-center justify-center py-12">
|
||||
<div className="text-xl text-red-600">{t('admin_error')}</div>
|
||||
<div className="text-xl text-red-600">{editor.t('admin_error')}</div>
|
||||
</div>
|
||||
</AdminLayout>
|
||||
)
|
||||
}
|
||||
|
||||
return (
|
||||
<AdminLayout title="Übersetzungen" description="Website Content & Sprachen">
|
||||
<div className={isRTL ? 'rtl' : ''} dir={isRTL ? 'rtl' : 'ltr'}>
|
||||
<AdminLayout title="Uebersetzungen" description="Website Content & Sprachen">
|
||||
<div className={editor.isRTL ? 'rtl' : ''} dir={editor.isRTL ? 'rtl' : 'ltr'}>
|
||||
{/* Toolbar */}
|
||||
<div className={`bg-white rounded-xl border border-slate-200 p-4 mb-6 flex items-center justify-between ${isRTL ? 'flex-row-reverse' : ''}`}>
|
||||
<div className={`bg-white rounded-xl border border-slate-200 p-4 mb-6 flex items-center justify-between ${editor.isRTL ? 'flex-row-reverse' : ''}`}>
|
||||
<div className="flex items-center gap-4">
|
||||
<LanguageSelector
|
||||
currentLanguage={language}
|
||||
onLanguageChange={setLanguage}
|
||||
/>
|
||||
{/* Preview Toggle */}
|
||||
<LanguageSelector currentLanguage={editor.language} onLanguageChange={editor.setLanguage} />
|
||||
<button
|
||||
onClick={() => setShowPreview(!showPreview)}
|
||||
onClick={() => editor.setShowPreview(!editor.showPreview)}
|
||||
className={`flex items-center gap-2 px-3 py-2 rounded-lg text-sm font-medium transition-colors ${
|
||||
showPreview
|
||||
? 'bg-blue-100 text-blue-700'
|
||||
: 'bg-slate-100 text-slate-600 hover:bg-slate-200'
|
||||
editor.showPreview ? 'bg-blue-100 text-blue-700' : 'bg-slate-100 text-slate-600 hover:bg-slate-200'
|
||||
}`}
|
||||
title={showPreview ? 'Preview ausblenden' : 'Preview einblenden'}
|
||||
title={editor.showPreview ? 'Preview ausblenden' : 'Preview einblenden'}
|
||||
>
|
||||
<svg className="w-4 h-4" fill="none" viewBox="0 0 24 24" stroke="currentColor">
|
||||
<path strokeLinecap="round" strokeLinejoin="round" strokeWidth={2} d="M15 12a3 3 0 11-6 0 3 3 0 016 0z" />
|
||||
@@ -239,565 +59,65 @@ export default function AdminPage() {
|
||||
Live-Preview
|
||||
</button>
|
||||
</div>
|
||||
<div className={`flex items-center gap-4 ${isRTL ? 'flex-row-reverse' : ''}`}>
|
||||
{message && (
|
||||
<span
|
||||
className={`px-3 py-1 rounded text-sm ${
|
||||
message.type === 'success'
|
||||
? 'bg-green-100 text-green-800'
|
||||
: 'bg-red-100 text-red-800'
|
||||
}`}
|
||||
>
|
||||
{message.text}
|
||||
<div className={`flex items-center gap-4 ${editor.isRTL ? 'flex-row-reverse' : ''}`}>
|
||||
{editor.message && (
|
||||
<span className={`px-3 py-1 rounded text-sm ${
|
||||
editor.message.type === 'success' ? 'bg-green-100 text-green-800' : 'bg-red-100 text-red-800'
|
||||
}`}>
|
||||
{editor.message.text}
|
||||
</span>
|
||||
)}
|
||||
<button
|
||||
onClick={saveChanges}
|
||||
disabled={saving}
|
||||
onClick={editor.saveChanges}
|
||||
disabled={editor.saving}
|
||||
className="bg-primary-600 text-white px-6 py-2 rounded-lg font-medium hover:bg-primary-700 disabled:opacity-50 transition-colors"
|
||||
>
|
||||
{saving ? t('admin_saving') : t('admin_save')}
|
||||
{editor.saving ? editor.t('admin_saving') : editor.t('admin_save')}
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Tabs */}
|
||||
<div className="mb-6">
|
||||
<div className={`flex gap-1 bg-slate-100 p-1 rounded-lg w-fit ${isRTL ? 'flex-row-reverse' : ''}`}>
|
||||
<div className={`flex gap-1 bg-slate-100 p-1 rounded-lg w-fit ${editor.isRTL ? 'flex-row-reverse' : ''}`}>
|
||||
{(['hero', 'features', 'faq', 'pricing', 'other'] as const).map((tab) => (
|
||||
<button
|
||||
key={tab}
|
||||
onClick={() => setActiveTab(tab)}
|
||||
onClick={() => editor.setActiveTab(tab)}
|
||||
className={`px-4 py-2 text-sm font-medium rounded-md transition-colors ${
|
||||
activeTab === tab
|
||||
? 'bg-white text-slate-900 shadow-sm'
|
||||
: 'text-slate-600 hover:text-slate-900'
|
||||
editor.activeTab === tab ? 'bg-white text-slate-900 shadow-sm' : 'text-slate-600 hover:text-slate-900'
|
||||
}`}
|
||||
>
|
||||
{tab === 'hero' && t('admin_tab_hero')}
|
||||
{tab === 'features' && t('admin_tab_features')}
|
||||
{tab === 'faq' && t('admin_tab_faq')}
|
||||
{tab === 'pricing' && t('admin_tab_pricing')}
|
||||
{tab === 'other' && t('admin_tab_other')}
|
||||
{tab === 'hero' && editor.t('admin_tab_hero')}
|
||||
{tab === 'features' && editor.t('admin_tab_features')}
|
||||
{tab === 'faq' && editor.t('admin_tab_faq')}
|
||||
{tab === 'pricing' && editor.t('admin_tab_pricing')}
|
||||
{tab === 'other' && editor.t('admin_tab_other')}
|
||||
</button>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Split Layout: Editor + Preview */}
|
||||
<div className={`grid gap-6 ${showPreview ? 'grid-cols-2' : 'grid-cols-1'}`}>
|
||||
{/* Editor Panel */}
|
||||
<div className={`grid gap-6 ${editor.showPreview ? 'grid-cols-2' : 'grid-cols-1'}`}>
|
||||
<div className="bg-white rounded-xl border border-slate-200 shadow-sm p-6 max-h-[calc(100vh-280px)] overflow-y-auto">
|
||||
{/* Hero Tab */}
|
||||
{activeTab === 'hero' && (
|
||||
<div className="space-y-6">
|
||||
<h2 className="text-xl font-semibold text-slate-900">{t('admin_tab_hero')} Section</h2>
|
||||
|
||||
<div className="grid gap-4">
|
||||
<div>
|
||||
<label className="block text-sm font-medium text-slate-700 mb-1">Badge</label>
|
||||
<input
|
||||
type="text"
|
||||
value={content.hero.badge}
|
||||
onChange={(e) => updateHero('badge', e.target.value)}
|
||||
className="w-full px-3 py-2 border border-slate-300 rounded-lg focus:ring-2 focus:ring-primary-500 focus:border-primary-500"
|
||||
dir={isRTL ? 'rtl' : 'ltr'}
|
||||
<ContentEditorTabs
|
||||
activeTab={editor.activeTab}
|
||||
content={editor.content}
|
||||
isRTL={editor.isRTL}
|
||||
t={editor.t}
|
||||
updateHero={editor.updateHero}
|
||||
updateFeature={editor.updateFeature}
|
||||
updateFAQ={editor.updateFAQ}
|
||||
addFAQ={editor.addFAQ}
|
||||
removeFAQ={editor.removeFAQ}
|
||||
updatePricing={editor.updatePricing}
|
||||
updateTrust={editor.updateTrust}
|
||||
updateTestimonial={editor.updateTestimonial}
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div>
|
||||
<label className="block text-sm font-medium text-slate-700 mb-1">
|
||||
Titel (vor Highlight)
|
||||
</label>
|
||||
<input
|
||||
type="text"
|
||||
value={content.hero.title}
|
||||
onChange={(e) => updateHero('title', e.target.value)}
|
||||
className="w-full px-3 py-2 border border-slate-300 rounded-lg focus:ring-2 focus:ring-primary-500 focus:border-primary-500"
|
||||
dir={isRTL ? 'rtl' : 'ltr'}
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div className="grid grid-cols-2 gap-4">
|
||||
<div>
|
||||
<label className="block text-sm font-medium text-slate-700 mb-1">
|
||||
Highlight 1
|
||||
</label>
|
||||
<input
|
||||
type="text"
|
||||
value={content.hero.titleHighlight1}
|
||||
onChange={(e) => updateHero('titleHighlight1', e.target.value)}
|
||||
className="w-full px-3 py-2 border border-slate-300 rounded-lg focus:ring-2 focus:ring-primary-500 focus:border-primary-500"
|
||||
dir={isRTL ? 'rtl' : 'ltr'}
|
||||
/>
|
||||
</div>
|
||||
<div>
|
||||
<label className="block text-sm font-medium text-slate-700 mb-1">
|
||||
Highlight 2
|
||||
</label>
|
||||
<input
|
||||
type="text"
|
||||
value={content.hero.titleHighlight2}
|
||||
onChange={(e) => updateHero('titleHighlight2', e.target.value)}
|
||||
className="w-full px-3 py-2 border border-slate-300 rounded-lg focus:ring-2 focus:ring-primary-500 focus:border-primary-500"
|
||||
dir={isRTL ? 'rtl' : 'ltr'}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div>
|
||||
<label className="block text-sm font-medium text-slate-700 mb-1">Untertitel</label>
|
||||
<textarea
|
||||
value={content.hero.subtitle}
|
||||
onChange={(e) => updateHero('subtitle', e.target.value)}
|
||||
rows={3}
|
||||
className="w-full px-3 py-2 border border-slate-300 rounded-lg focus:ring-2 focus:ring-primary-500 focus:border-primary-500"
|
||||
dir={isRTL ? 'rtl' : 'ltr'}
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div className="grid grid-cols-3 gap-4">
|
||||
<div>
|
||||
<label className="block text-sm font-medium text-slate-700 mb-1">
|
||||
CTA Primaer
|
||||
</label>
|
||||
<input
|
||||
type="text"
|
||||
value={content.hero.ctaPrimary}
|
||||
onChange={(e) => updateHero('ctaPrimary', e.target.value)}
|
||||
className="w-full px-3 py-2 border border-slate-300 rounded-lg focus:ring-2 focus:ring-primary-500 focus:border-primary-500"
|
||||
dir={isRTL ? 'rtl' : 'ltr'}
|
||||
/>
|
||||
</div>
|
||||
<div>
|
||||
<label className="block text-sm font-medium text-slate-700 mb-1">
|
||||
CTA Sekundaer
|
||||
</label>
|
||||
<input
|
||||
type="text"
|
||||
value={content.hero.ctaSecondary}
|
||||
onChange={(e) => updateHero('ctaSecondary', e.target.value)}
|
||||
className="w-full px-3 py-2 border border-slate-300 rounded-lg focus:ring-2 focus:ring-primary-500 focus:border-primary-500"
|
||||
dir={isRTL ? 'rtl' : 'ltr'}
|
||||
/>
|
||||
</div>
|
||||
<div>
|
||||
<label className="block text-sm font-medium text-slate-700 mb-1">CTA Hinweis</label>
|
||||
<input
|
||||
type="text"
|
||||
value={content.hero.ctaHint}
|
||||
onChange={(e) => updateHero('ctaHint', e.target.value)}
|
||||
className="w-full px-3 py-2 border border-slate-300 rounded-lg focus:ring-2 focus:ring-primary-500 focus:border-primary-500"
|
||||
dir={isRTL ? 'rtl' : 'ltr'}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Features Tab */}
|
||||
{activeTab === 'features' && (
|
||||
<div className="space-y-6">
|
||||
<h2 className="text-xl font-semibold text-slate-900">{t('admin_tab_features')}</h2>
|
||||
|
||||
{content.features.map((feature, index) => (
|
||||
<div key={feature.id} className="border border-slate-200 rounded-lg p-4">
|
||||
<div className="grid gap-4">
|
||||
<div className="grid grid-cols-3 gap-4">
|
||||
<div>
|
||||
<label className="block text-sm font-medium text-slate-700 mb-1">Icon</label>
|
||||
<input
|
||||
type="text"
|
||||
value={feature.icon}
|
||||
onChange={(e) => updateFeature(index, 'icon', e.target.value)}
|
||||
className="w-full px-3 py-2 border border-slate-300 rounded-lg text-2xl text-center"
|
||||
/>
|
||||
</div>
|
||||
<div className="col-span-2">
|
||||
<label className="block text-sm font-medium text-slate-700 mb-1">Titel</label>
|
||||
<input
|
||||
type="text"
|
||||
value={feature.title}
|
||||
onChange={(e) => updateFeature(index, 'title', e.target.value)}
|
||||
className="w-full px-3 py-2 border border-slate-300 rounded-lg focus:ring-2 focus:ring-primary-500 focus:border-primary-500"
|
||||
dir={isRTL ? 'rtl' : 'ltr'}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
<div>
|
||||
<label className="block text-sm font-medium text-slate-700 mb-1">
|
||||
Beschreibung
|
||||
</label>
|
||||
<textarea
|
||||
value={feature.description}
|
||||
onChange={(e) => updateFeature(index, 'description', e.target.value)}
|
||||
rows={2}
|
||||
className="w-full px-3 py-2 border border-slate-300 rounded-lg focus:ring-2 focus:ring-primary-500 focus:border-primary-500"
|
||||
dir={isRTL ? 'rtl' : 'ltr'}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* FAQ Tab */}
|
||||
{activeTab === 'faq' && (
|
||||
<div className="space-y-6">
|
||||
<div className={`flex items-center justify-between ${isRTL ? 'flex-row-reverse' : ''}`}>
|
||||
<h2 className="text-xl font-semibold text-slate-900">{t('admin_tab_faq')}</h2>
|
||||
<button
|
||||
onClick={addFAQ}
|
||||
className="px-4 py-2 bg-slate-100 text-slate-700 rounded-lg hover:bg-slate-200 transition-colors"
|
||||
>
|
||||
{t('admin_add_faq')}
|
||||
</button>
|
||||
</div>
|
||||
|
||||
{content.faq.map((item, index) => (
|
||||
<div key={index} className="border border-slate-200 rounded-lg p-4">
|
||||
<div className={`flex items-start justify-between gap-4 ${isRTL ? 'flex-row-reverse' : ''}`}>
|
||||
<div className="flex-1 space-y-4">
|
||||
<div>
|
||||
<label className="block text-sm font-medium text-slate-700 mb-1">
|
||||
{t('admin_question')} {index + 1}
|
||||
</label>
|
||||
<input
|
||||
type="text"
|
||||
value={item.question}
|
||||
onChange={(e) => updateFAQ(index, 'question', e.target.value)}
|
||||
className="w-full px-3 py-2 border border-slate-300 rounded-lg focus:ring-2 focus:ring-primary-500 focus:border-primary-500"
|
||||
dir={isRTL ? 'rtl' : 'ltr'}
|
||||
/>
|
||||
</div>
|
||||
<div>
|
||||
<label className="block text-sm font-medium text-slate-700 mb-1">
|
||||
{t('admin_answer')}
|
||||
</label>
|
||||
<textarea
|
||||
value={item.answer.join('\n')}
|
||||
onChange={(e) => updateFAQ(index, 'answer', e.target.value)}
|
||||
rows={4}
|
||||
className="w-full px-3 py-2 border border-slate-300 rounded-lg focus:ring-2 focus:ring-primary-500 focus:border-primary-500 font-mono text-sm"
|
||||
dir={isRTL ? 'rtl' : 'ltr'}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
<button
|
||||
onClick={() => removeFAQ(index)}
|
||||
className="p-2 text-red-600 hover:bg-red-50 rounded-lg transition-colors"
|
||||
title="Frage entfernen"
|
||||
>
|
||||
<svg className="w-5 h-5" fill="none" viewBox="0 0 24 24" stroke="currentColor">
|
||||
<path
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
strokeWidth={2}
|
||||
d="M19 7l-.867 12.142A2 2 0 0116.138 21H7.862a2 2 0 01-1.995-1.858L5 7m5 4v6m4-6v6m1-10V4a1 1 0 00-1-1h-4a1 1 0 00-1 1v3M4 7h16"
|
||||
/>
|
||||
</svg>
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Pricing Tab */}
|
||||
{activeTab === 'pricing' && (
|
||||
<div className="space-y-6">
|
||||
<h2 className="text-xl font-semibold text-slate-900">{t('admin_tab_pricing')}</h2>
|
||||
|
||||
{content.pricing.map((plan, index) => (
|
||||
<div key={plan.id} className="border border-slate-200 rounded-lg p-4">
|
||||
<div className="grid gap-4">
|
||||
<div className="grid grid-cols-4 gap-4">
|
||||
<div>
|
||||
<label className="block text-sm font-medium text-slate-700 mb-1">Name</label>
|
||||
<input
|
||||
type="text"
|
||||
value={plan.name}
|
||||
onChange={(e) => updatePricing(index, 'name', e.target.value)}
|
||||
className="w-full px-3 py-2 border border-slate-300 rounded-lg focus:ring-2 focus:ring-primary-500 focus:border-primary-500"
|
||||
dir={isRTL ? 'rtl' : 'ltr'}
|
||||
/>
|
||||
</div>
|
||||
<div>
|
||||
<label className="block text-sm font-medium text-slate-700 mb-1">
|
||||
Preis (EUR)
|
||||
</label>
|
||||
<input
|
||||
type="number"
|
||||
step="0.01"
|
||||
value={plan.price}
|
||||
onChange={(e) => updatePricing(index, 'price', e.target.value)}
|
||||
className="w-full px-3 py-2 border border-slate-300 rounded-lg focus:ring-2 focus:ring-primary-500 focus:border-primary-500"
|
||||
/>
|
||||
</div>
|
||||
<div>
|
||||
<label className="block text-sm font-medium text-slate-700 mb-1">
|
||||
Intervall
|
||||
</label>
|
||||
<input
|
||||
type="text"
|
||||
value={plan.interval}
|
||||
onChange={(e) => updatePricing(index, 'interval', e.target.value)}
|
||||
className="w-full px-3 py-2 border border-slate-300 rounded-lg focus:ring-2 focus:ring-primary-500 focus:border-primary-500"
|
||||
dir={isRTL ? 'rtl' : 'ltr'}
|
||||
/>
|
||||
</div>
|
||||
<div className="flex items-end">
|
||||
<label className="flex items-center gap-2">
|
||||
<input
|
||||
type="checkbox"
|
||||
checked={plan.popular || false}
|
||||
onChange={(e) => updatePricing(index, 'popular', e.target.checked)}
|
||||
className="w-4 h-4 text-primary-600 rounded"
|
||||
/>
|
||||
<span className="text-sm text-slate-700">{t('pricing_popular')}</span>
|
||||
</label>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div>
|
||||
<label className="block text-sm font-medium text-slate-700 mb-1">
|
||||
Beschreibung
|
||||
</label>
|
||||
<input
|
||||
type="text"
|
||||
value={plan.description}
|
||||
onChange={(e) => updatePricing(index, 'description', e.target.value)}
|
||||
className="w-full px-3 py-2 border border-slate-300 rounded-lg focus:ring-2 focus:ring-primary-500 focus:border-primary-500"
|
||||
dir={isRTL ? 'rtl' : 'ltr'}
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div className="grid grid-cols-2 gap-4">
|
||||
<div>
|
||||
<label className="block text-sm font-medium text-slate-700 mb-1">
|
||||
{t('pricing_tasks')}
|
||||
</label>
|
||||
<input
|
||||
type="text"
|
||||
value={plan.features.tasks}
|
||||
onChange={(e) => updatePricing(index, 'features.tasks', e.target.value)}
|
||||
className="w-full px-3 py-2 border border-slate-300 rounded-lg focus:ring-2 focus:ring-primary-500 focus:border-primary-500"
|
||||
dir={isRTL ? 'rtl' : 'ltr'}
|
||||
/>
|
||||
</div>
|
||||
<div>
|
||||
<label className="block text-sm font-medium text-slate-700 mb-1">
|
||||
Aufgaben-Beschreibung
|
||||
</label>
|
||||
<input
|
||||
type="text"
|
||||
value={plan.features.taskDescription}
|
||||
onChange={(e) =>
|
||||
updatePricing(index, 'features.taskDescription', e.target.value)
|
||||
}
|
||||
className="w-full px-3 py-2 border border-slate-300 rounded-lg focus:ring-2 focus:ring-primary-500 focus:border-primary-500"
|
||||
dir={isRTL ? 'rtl' : 'ltr'}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div>
|
||||
<label className="block text-sm font-medium text-slate-700 mb-1">
|
||||
Features (eine pro Zeile)
|
||||
</label>
|
||||
<textarea
|
||||
value={plan.features.included.join('\n')}
|
||||
onChange={(e) => updatePricing(index, 'features.included', e.target.value)}
|
||||
rows={4}
|
||||
className="w-full px-3 py-2 border border-slate-300 rounded-lg focus:ring-2 focus:ring-primary-500 focus:border-primary-500 font-mono text-sm"
|
||||
dir={isRTL ? 'rtl' : 'ltr'}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Other Tab */}
|
||||
{activeTab === 'other' && (
|
||||
<div className="space-y-8">
|
||||
{/* Trust Indicators */}
|
||||
<div>
|
||||
<h2 className="text-xl font-semibold text-slate-900 mb-4">Trust Indicators</h2>
|
||||
<div className="grid grid-cols-3 gap-4">
|
||||
{(['item1', 'item2', 'item3'] as const).map((key, index) => (
|
||||
<div key={key} className="border border-slate-200 rounded-lg p-4">
|
||||
<div className="space-y-4">
|
||||
<div>
|
||||
<label className="block text-sm font-medium text-slate-700 mb-1">
|
||||
Wert {index + 1}
|
||||
</label>
|
||||
<input
|
||||
type="text"
|
||||
value={content.trust[key].value}
|
||||
onChange={(e) =>
|
||||
setContent({
|
||||
...content,
|
||||
trust: {
|
||||
...content.trust,
|
||||
[key]: { ...content.trust[key], value: e.target.value },
|
||||
},
|
||||
})
|
||||
}
|
||||
className="w-full px-3 py-2 border border-slate-300 rounded-lg focus:ring-2 focus:ring-primary-500 focus:border-primary-500"
|
||||
dir={isRTL ? 'rtl' : 'ltr'}
|
||||
/>
|
||||
</div>
|
||||
<div>
|
||||
<label className="block text-sm font-medium text-slate-700 mb-1">
|
||||
Label {index + 1}
|
||||
</label>
|
||||
<input
|
||||
type="text"
|
||||
value={content.trust[key].label}
|
||||
onChange={(e) =>
|
||||
setContent({
|
||||
...content,
|
||||
trust: {
|
||||
...content.trust,
|
||||
[key]: { ...content.trust[key], label: e.target.value },
|
||||
},
|
||||
})
|
||||
}
|
||||
className="w-full px-3 py-2 border border-slate-300 rounded-lg focus:ring-2 focus:ring-primary-500 focus:border-primary-500"
|
||||
dir={isRTL ? 'rtl' : 'ltr'}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Testimonial */}
|
||||
<div>
|
||||
<h2 className="text-xl font-semibold text-slate-900 mb-4">Testimonial</h2>
|
||||
<div className="border border-slate-200 rounded-lg p-4 space-y-4">
|
||||
<div>
|
||||
<label className="block text-sm font-medium text-slate-700 mb-1">Zitat</label>
|
||||
<textarea
|
||||
value={content.testimonial.quote}
|
||||
onChange={(e) =>
|
||||
setContent({
|
||||
...content,
|
||||
testimonial: { ...content.testimonial, quote: e.target.value },
|
||||
})
|
||||
}
|
||||
rows={3}
|
||||
className="w-full px-3 py-2 border border-slate-300 rounded-lg focus:ring-2 focus:ring-primary-500 focus:border-primary-500"
|
||||
dir={isRTL ? 'rtl' : 'ltr'}
|
||||
/>
|
||||
</div>
|
||||
<div className="grid grid-cols-2 gap-4">
|
||||
<div>
|
||||
<label className="block text-sm font-medium text-slate-700 mb-1">Autor</label>
|
||||
<input
|
||||
type="text"
|
||||
value={content.testimonial.author}
|
||||
onChange={(e) =>
|
||||
setContent({
|
||||
...content,
|
||||
testimonial: { ...content.testimonial, author: e.target.value },
|
||||
})
|
||||
}
|
||||
className="w-full px-3 py-2 border border-slate-300 rounded-lg focus:ring-2 focus:ring-primary-500 focus:border-primary-500"
|
||||
dir={isRTL ? 'rtl' : 'ltr'}
|
||||
/>
|
||||
</div>
|
||||
<div>
|
||||
<label className="block text-sm font-medium text-slate-700 mb-1">Rolle</label>
|
||||
<input
|
||||
type="text"
|
||||
value={content.testimonial.role}
|
||||
onChange={(e) =>
|
||||
setContent({
|
||||
...content,
|
||||
testimonial: { ...content.testimonial, role: e.target.value },
|
||||
})
|
||||
}
|
||||
className="w-full px-3 py-2 border border-slate-300 rounded-lg focus:ring-2 focus:ring-primary-500 focus:border-primary-500"
|
||||
dir={isRTL ? 'rtl' : 'ltr'}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{/* Live Preview Panel */}
|
||||
{showPreview && (
|
||||
<div className="bg-white rounded-xl border border-slate-200 shadow-sm overflow-hidden">
|
||||
{/* Preview Header */}
|
||||
<div className="bg-slate-50 border-b border-slate-200 px-4 py-3 flex items-center justify-between">
|
||||
<div className="flex items-center gap-2">
|
||||
<div className="flex gap-1.5">
|
||||
<div className="w-3 h-3 rounded-full bg-red-400"></div>
|
||||
<div className="w-3 h-3 rounded-full bg-yellow-400"></div>
|
||||
<div className="w-3 h-3 rounded-full bg-green-400"></div>
|
||||
</div>
|
||||
<span className="text-xs text-slate-500 ml-2">localhost:3000</span>
|
||||
</div>
|
||||
<div className="flex items-center gap-2">
|
||||
<span className="text-xs font-medium text-slate-600 bg-slate-200 px-2 py-1 rounded">
|
||||
{activeTab === 'hero' && 'Hero Section'}
|
||||
{activeTab === 'features' && 'Features'}
|
||||
{activeTab === 'faq' && 'FAQ'}
|
||||
{activeTab === 'pricing' && 'Pricing'}
|
||||
{activeTab === 'other' && 'Trust & Testimonial'}
|
||||
</span>
|
||||
<button
|
||||
onClick={() => iframeRef.current?.contentWindow?.location.reload()}
|
||||
className="p-1.5 text-slate-500 hover:text-slate-700 hover:bg-slate-200 rounded transition-colors"
|
||||
title="Preview neu laden"
|
||||
>
|
||||
<svg className="w-4 h-4" fill="none" viewBox="0 0 24 24" stroke="currentColor">
|
||||
<path strokeLinecap="round" strokeLinejoin="round" strokeWidth={2} d="M4 4v5h.582m15.356 2A8.001 8.001 0 004.582 9m0 0H9m11 11v-5h-.581m0 0a8.003 8.003 0 01-15.357-2m15.357 2H15" />
|
||||
</svg>
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
{/* Preview Frame */}
|
||||
<div className="relative h-[calc(100vh-340px)] bg-slate-100">
|
||||
<iframe
|
||||
ref={iframeRef}
|
||||
src={`/?preview=true§ion=${activeTab}#${activeTab}`}
|
||||
className="w-full h-full border-0 scale-75 origin-top-left"
|
||||
style={{
|
||||
width: '133.33%',
|
||||
height: '133.33%',
|
||||
transform: 'scale(0.75)',
|
||||
transformOrigin: 'top left',
|
||||
}}
|
||||
title="Website Preview"
|
||||
sandbox="allow-same-origin allow-scripts"
|
||||
/>
|
||||
{/* Sektion-Indikator */}
|
||||
<div className="absolute bottom-4 left-4 right-4 bg-blue-600 text-white px-4 py-2 rounded-lg shadow-lg flex items-center gap-2 text-sm">
|
||||
<svg className="w-4 h-4" fill="none" viewBox="0 0 24 24" stroke="currentColor">
|
||||
<path strokeLinecap="round" strokeLinejoin="round" strokeWidth={2} d="M13 16h-1v-4h-1m1-4h.01M21 12a9 9 0 11-18 0 9 9 0 0118 0z" />
|
||||
</svg>
|
||||
<span>
|
||||
Du bearbeitest: <strong>
|
||||
{activeTab === 'hero' && 'Hero Section (Startbereich)'}
|
||||
{activeTab === 'features' && 'Features (Funktionen)'}
|
||||
{activeTab === 'faq' && 'FAQ (Häufige Fragen)'}
|
||||
{activeTab === 'pricing' && 'Pricing (Preise)'}
|
||||
{activeTab === 'other' && 'Trust & Testimonial'}
|
||||
</strong>
|
||||
</span>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
{editor.showPreview && (
|
||||
<LivePreviewPanel activeTab={editor.activeTab} iframeRef={editor.iframeRef} />
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -0,0 +1,98 @@
|
||||
import { ScreenDefinition, ConnectionDef, FlowType } from './types'
|
||||
|
||||
/**
|
||||
* Find all connected nodes recursively from a start node.
|
||||
*/
|
||||
export function findConnectedNodes(
|
||||
startNodeId: string,
|
||||
connections: ConnectionDef[],
|
||||
direction: 'children' | 'parents' | 'both' = 'children'
|
||||
): Set<string> {
|
||||
const connected = new Set<string>()
|
||||
connected.add(startNodeId)
|
||||
|
||||
const queue = [startNodeId]
|
||||
while (queue.length > 0) {
|
||||
const current = queue.shift()!
|
||||
|
||||
connections.forEach(conn => {
|
||||
if ((direction === 'children' || direction === 'both') && conn.source === current) {
|
||||
if (!connected.has(conn.target)) {
|
||||
connected.add(conn.target)
|
||||
queue.push(conn.target)
|
||||
}
|
||||
}
|
||||
if ((direction === 'parents' || direction === 'both') && conn.target === current) {
|
||||
if (!connected.has(conn.source)) {
|
||||
connected.add(conn.source)
|
||||
queue.push(conn.source)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
return connected
|
||||
}
|
||||
|
||||
/**
|
||||
* Construct an embeddable URL from a base URL and screen URL.
|
||||
*/
|
||||
export function constructEmbedUrl(baseUrl: string, url: string | undefined): string | null {
|
||||
if (!url) return null
|
||||
|
||||
const hashIndex = url.indexOf('#')
|
||||
if (hashIndex !== -1) {
|
||||
const basePart = url.substring(0, hashIndex)
|
||||
const hashPart = url.substring(hashIndex)
|
||||
const separator = basePart.includes('?') ? '&' : '?'
|
||||
return `${baseUrl}${basePart}${separator}embed=true${hashPart}`
|
||||
} else {
|
||||
const separator = url.includes('?') ? '&' : '?'
|
||||
return `${baseUrl}${url}${separator}embed=true`
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Calculate node position based on category and index within that category.
|
||||
*/
|
||||
export function getNodePosition(
|
||||
id: string,
|
||||
category: string,
|
||||
screens: ScreenDefinition[],
|
||||
flowType: FlowType
|
||||
) {
|
||||
const studioPositions: Record<string, { x: number; y: number }> = {
|
||||
navigation: { x: 400, y: 50 },
|
||||
content: { x: 50, y: 250 },
|
||||
communication: { x: 750, y: 250 },
|
||||
school: { x: 50, y: 500 },
|
||||
admin: { x: 750, y: 500 },
|
||||
ai: { x: 400, y: 380 },
|
||||
}
|
||||
|
||||
const adminPositions: Record<string, { x: number; y: number }> = {
|
||||
overview: { x: 400, y: 30 },
|
||||
infrastructure: { x: 50, y: 150 },
|
||||
compliance: { x: 700, y: 150 },
|
||||
ai: { x: 50, y: 350 },
|
||||
communication: { x: 400, y: 350 },
|
||||
security: { x: 700, y: 350 },
|
||||
content: { x: 50, y: 550 },
|
||||
game: { x: 400, y: 550 },
|
||||
misc: { x: 700, y: 550 },
|
||||
}
|
||||
|
||||
const positions = flowType === 'studio' ? studioPositions : adminPositions
|
||||
const base = positions[category] || { x: 400, y: 300 }
|
||||
const categoryScreens = screens.filter(s => s.category === category)
|
||||
const categoryIndex = categoryScreens.findIndex(s => s.id === id)
|
||||
|
||||
const cols = Math.ceil(Math.sqrt(categoryScreens.length + 1))
|
||||
const row = Math.floor(categoryIndex / cols)
|
||||
const col = categoryIndex % cols
|
||||
|
||||
return {
|
||||
x: base.x + col * 160,
|
||||
y: base.y + row * 90,
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,173 @@
|
||||
import { ScreenDefinition, ConnectionDef } from './types'
|
||||
|
||||
// ============================================
|
||||
// STUDIO SCREENS (Port 8000)
|
||||
// ============================================
|
||||
|
||||
export const STUDIO_SCREENS: ScreenDefinition[] = [
|
||||
{ id: 'lehrer-dashboard', name: 'Mein Dashboard', description: 'Hauptuebersicht mit Widgets', category: 'navigation', icon: '🏠', url: '/app#lehrer-dashboard' },
|
||||
{ id: 'lehrer-onboarding', name: 'Erste Schritte', description: 'Onboarding & Schnellstart', category: 'navigation', icon: '🚀', url: '/app#lehrer-onboarding' },
|
||||
{ id: 'hilfe', name: 'Dokumentation', description: 'Hilfe & Anleitungen', category: 'navigation', icon: '📚', url: '/app#hilfe' },
|
||||
{ id: 'worksheets', name: 'Arbeitsblaetter Studio', description: 'Lernmaterialien erstellen', category: 'content', icon: '📝', url: '/app#worksheets' },
|
||||
{ id: 'content-creator', name: 'Content Creator', description: 'Inhalte erstellen', category: 'content', icon: '✨', url: '/app#content-creator' },
|
||||
{ id: 'content-feed', name: 'Content Feed', description: 'Inhalte durchsuchen', category: 'content', icon: '📰', url: '/app#content-feed' },
|
||||
{ id: 'unit-creator', name: 'Unit Creator', description: 'Lerneinheiten erstellen', category: 'content', icon: '📦', url: '/app#unit-creator' },
|
||||
{ id: 'letters', name: 'Briefe & Vorlagen', description: 'Brief-Generator', category: 'content', icon: '✉️', url: '/app#letters' },
|
||||
{ id: 'correction', name: 'Korrektur', description: 'Arbeiten korrigieren', category: 'content', icon: '✏️', url: '/app#correction' },
|
||||
{ id: 'klausur-korrektur', name: 'Abiturklausuren', description: 'KI-gestuetzte Klausurkorrektur', category: 'content', icon: '📋', url: '/app#klausur-korrektur' },
|
||||
{ id: 'jitsi', name: 'Videokonferenz', description: 'Jitsi Meet Integration', category: 'communication', icon: '🎥', url: '/app#jitsi' },
|
||||
{ id: 'messenger', name: 'Messenger', description: 'Matrix E2EE Chat', category: 'communication', icon: '💬', url: '/app#messenger' },
|
||||
{ id: 'mail', name: 'Unified Inbox', description: 'E-Mail Verwaltung', category: 'communication', icon: '📧', url: '/app#mail' },
|
||||
{ id: 'school-classes', name: 'Klassen', description: 'Klassenverwaltung', category: 'school', icon: '👥', url: '/app#school-classes' },
|
||||
{ id: 'school-exams', name: 'Pruefungen', description: 'Pruefungsverwaltung', category: 'school', icon: '📝', url: '/app#school-exams' },
|
||||
{ id: 'school-grades', name: 'Noten', description: 'Notenverwaltung', category: 'school', icon: '📊', url: '/app#school-grades' },
|
||||
{ id: 'school-gradebook', name: 'Notenbuch', description: 'Digitales Notenbuch', category: 'school', icon: '📖', url: '/app#school-gradebook' },
|
||||
{ id: 'school-certificates', name: 'Zeugnisse', description: 'Zeugniserstellung', category: 'school', icon: '🎓', url: '/app#school-certificates' },
|
||||
{ id: 'companion', name: 'Begleiter & Stunde', description: 'KI-Unterrichtsassistent', category: 'ai', icon: '🤖', url: '/app#companion' },
|
||||
{ id: 'alerts', name: 'Alerts', description: 'News & Benachrichtigungen', category: 'ai', icon: '🔔', url: '/app#alerts' },
|
||||
{ id: 'admin', name: 'Einstellungen', description: 'Systemeinstellungen', category: 'admin', icon: '⚙️', url: '/app#admin' },
|
||||
{ id: 'rbac-admin', name: 'Rollen & Rechte', description: 'Berechtigungsverwaltung', category: 'admin', icon: '🔐', url: '/app#rbac-admin' },
|
||||
{ id: 'abitur-docs-admin', name: 'Abitur Dokumente', description: 'Erwartungshorizonte', category: 'admin', icon: '📄', url: '/app#abitur-docs-admin' },
|
||||
{ id: 'system-info', name: 'System Info', description: 'Systeminformationen', category: 'admin', icon: '💻', url: '/app#system-info' },
|
||||
{ id: 'workflow', name: 'Workflow', description: 'Automatisierungen', category: 'admin', icon: '⚡', url: '/app#workflow' },
|
||||
]
|
||||
|
||||
export const STUDIO_CONNECTIONS: ConnectionDef[] = [
|
||||
{ source: 'lehrer-onboarding', target: 'worksheets', label: 'Arbeitsblaetter' },
|
||||
{ source: 'lehrer-onboarding', target: 'klausur-korrektur', label: 'Abiturklausuren' },
|
||||
{ source: 'lehrer-onboarding', target: 'correction', label: 'Korrektur' },
|
||||
{ source: 'lehrer-onboarding', target: 'letters', label: 'Briefe' },
|
||||
{ source: 'lehrer-onboarding', target: 'school-classes', label: 'Klassen' },
|
||||
{ source: 'lehrer-onboarding', target: 'jitsi', label: 'Meet' },
|
||||
{ source: 'lehrer-onboarding', target: 'hilfe', label: 'Doku' },
|
||||
{ source: 'lehrer-onboarding', target: 'admin', label: 'Settings' },
|
||||
{ source: 'lehrer-dashboard', target: 'worksheets' },
|
||||
{ source: 'lehrer-dashboard', target: 'correction' },
|
||||
{ source: 'lehrer-dashboard', target: 'jitsi' },
|
||||
{ source: 'lehrer-dashboard', target: 'letters' },
|
||||
{ source: 'lehrer-dashboard', target: 'messenger' },
|
||||
{ source: 'lehrer-dashboard', target: 'klausur-korrektur' },
|
||||
{ source: 'lehrer-dashboard', target: 'companion' },
|
||||
{ source: 'lehrer-dashboard', target: 'alerts' },
|
||||
{ source: 'lehrer-dashboard', target: 'mail' },
|
||||
{ source: 'lehrer-dashboard', target: 'school-classes' },
|
||||
{ source: 'lehrer-dashboard', target: 'lehrer-onboarding', label: 'Sidebar' },
|
||||
{ source: 'school-classes', target: 'school-exams' },
|
||||
{ source: 'school-classes', target: 'school-grades' },
|
||||
{ source: 'school-grades', target: 'school-gradebook' },
|
||||
{ source: 'school-gradebook', target: 'school-certificates' },
|
||||
{ source: 'worksheets', target: 'content-creator' },
|
||||
{ source: 'worksheets', target: 'unit-creator' },
|
||||
{ source: 'content-creator', target: 'content-feed' },
|
||||
{ source: 'klausur-korrektur', target: 'abitur-docs-admin' },
|
||||
{ source: 'admin', target: 'rbac-admin' },
|
||||
{ source: 'admin', target: 'system-info' },
|
||||
{ source: 'admin', target: 'workflow' },
|
||||
]
|
||||
|
||||
// ============================================
|
||||
// ADMIN SCREENS (Port 3000)
|
||||
// ============================================
|
||||
|
||||
export const ADMIN_SCREENS: ScreenDefinition[] = [
|
||||
{ id: 'admin-dashboard', name: 'Dashboard', description: 'Uebersicht & Statistiken', category: 'overview', icon: '🏠', url: '/admin' },
|
||||
{ id: 'admin-onboarding', name: 'Onboarding', description: 'Lern-Wizards fuer alle Module', category: 'overview', icon: '📖', url: '/admin/onboarding' },
|
||||
{ id: 'admin-gpu', name: 'GPU Infrastruktur', description: 'vast.ai GPU Management', category: 'infrastructure', icon: '🖥️', url: '/admin/gpu' },
|
||||
{ id: 'admin-middleware', name: 'Middleware', description: 'Middleware Stack & Test', category: 'infrastructure', icon: '🔧', url: '/admin/middleware' },
|
||||
{ id: 'admin-mac-mini', name: 'Mac Mini', description: 'Headless Mac Mini Control', category: 'infrastructure', icon: '🍎', url: '/admin/mac-mini' },
|
||||
{ id: 'admin-consent', name: 'Consent Verwaltung', description: 'Rechtliche Dokumente', category: 'compliance', icon: '📄', url: '/admin/consent' },
|
||||
{ id: 'admin-dsr', name: 'Datenschutzanfragen', description: 'DSGVO Art. 15-21', category: 'compliance', icon: '🔒', url: '/admin/dsr' },
|
||||
{ id: 'admin-dsms', name: 'DSMS', description: 'Datenschutz-Management', category: 'compliance', icon: '🛡️', url: '/admin/dsms' },
|
||||
{ id: 'admin-compliance', name: 'Compliance', description: 'GRC & Audit', category: 'compliance', icon: '✅', url: '/admin/compliance' },
|
||||
{ id: 'admin-docs-audit', name: 'DSGVO-Audit', description: 'Audit-Dokumentation', category: 'compliance', icon: '📋', url: '/admin/docs/audit' },
|
||||
{ id: 'admin-rag', name: 'Daten & RAG', description: 'Training Data & RAG', category: 'ai', icon: '🗄️', url: '/admin/rag' },
|
||||
{ id: 'admin-ocr-labeling', name: 'OCR-Labeling', description: 'Handschrift-Training', category: 'ai', icon: '🏷️', url: '/admin/ocr-labeling' },
|
||||
{ id: 'admin-magic-help', name: 'Magic Help (TrOCR)', description: 'Handschrift-OCR', category: 'ai', icon: '✨', url: '/admin/magic-help' },
|
||||
{ id: 'admin-companion', name: 'Companion Dev', description: 'Lesson-Modus Entwicklung', category: 'ai', icon: '📚', url: '/admin/companion' },
|
||||
{ id: 'admin-communication', name: 'Kommunikation', description: 'Matrix & Jitsi Monitoring', category: 'communication', icon: '💬', url: '/admin/communication' },
|
||||
{ id: 'admin-alerts', name: 'Alerts Monitoring', description: 'Google Alerts & Feeds', category: 'communication', icon: '🔔', url: '/admin/alerts' },
|
||||
{ id: 'admin-mail', name: 'Unified Inbox', description: 'E-Mail & KI-Analyse', category: 'communication', icon: '📧', url: '/admin/mail' },
|
||||
{ id: 'admin-security', name: 'Security', description: 'DevSecOps Dashboard', category: 'security', icon: '🔐', url: '/admin/security' },
|
||||
{ id: 'admin-sbom', name: 'SBOM', description: 'Software Bill of Materials', category: 'security', icon: '📦', url: '/admin/sbom' },
|
||||
{ id: 'admin-screen-flow', name: 'Screen Flow', description: 'UI Verbindungen', category: 'security', icon: '🔀', url: '/admin/screen-flow' },
|
||||
{ id: 'admin-content', name: 'Uebersetzungen', description: 'Website Content', category: 'content', icon: '🌍', url: '/admin/content' },
|
||||
{ id: 'admin-edu-search', name: 'Education Search', description: 'Bildungsquellen & Crawler', category: 'content', icon: '🔍', url: '/admin/edu-search' },
|
||||
{ id: 'admin-staff-search', name: 'Personensuche', description: 'Uni-Mitarbeiter', category: 'content', icon: '👤', url: '/admin/staff-search' },
|
||||
{ id: 'admin-uni-crawler', name: 'Uni-Crawler', description: 'Universitaets-Crawling', category: 'content', icon: '🕷️', url: '/admin/uni-crawler' },
|
||||
{ id: 'admin-game', name: 'Breakpilot Drive', description: 'Lernspiel Klasse 2-6', category: 'game', icon: '🎮', url: '/admin/game' },
|
||||
{ id: 'admin-unity-bridge', name: 'Unity Bridge', description: 'Unity Editor Steuerung', category: 'game', icon: '⚡', url: '/admin/unity-bridge' },
|
||||
{ id: 'admin-backlog', name: 'Production Backlog', description: 'Go-Live Checkliste', category: 'misc', icon: '📝', url: '/admin/backlog' },
|
||||
{ id: 'admin-brandbook', name: 'Brandbook', description: 'Corporate Design', category: 'misc', icon: '🎨', url: '/admin/brandbook' },
|
||||
{ id: 'admin-docs', name: 'Developer Docs', description: 'API & Architektur', category: 'misc', icon: '📖', url: '/admin/docs' },
|
||||
{ id: 'admin-pca-platform', name: 'PCA Platform', description: 'Bot-Erkennung', category: 'misc', icon: '💰', url: '/admin/pca-platform' },
|
||||
]
|
||||
|
||||
export const ADMIN_CONNECTIONS: ConnectionDef[] = [
|
||||
{ source: 'admin-dashboard', target: 'admin-onboarding' },
|
||||
{ source: 'admin-dashboard', target: 'admin-security' },
|
||||
{ source: 'admin-dashboard', target: 'admin-compliance' },
|
||||
{ source: 'admin-onboarding', target: 'admin-gpu' },
|
||||
{ source: 'admin-onboarding', target: 'admin-consent' },
|
||||
{ source: 'admin-consent', target: 'admin-dsr' },
|
||||
{ source: 'admin-dsr', target: 'admin-dsms' },
|
||||
{ source: 'admin-dsms', target: 'admin-compliance' },
|
||||
{ source: 'admin-compliance', target: 'admin-docs-audit' },
|
||||
{ source: 'admin-rag', target: 'admin-ocr-labeling' },
|
||||
{ source: 'admin-ocr-labeling', target: 'admin-magic-help' },
|
||||
{ source: 'admin-magic-help', target: 'admin-companion' },
|
||||
{ source: 'admin-security', target: 'admin-sbom' },
|
||||
{ source: 'admin-sbom', target: 'admin-screen-flow' },
|
||||
{ source: 'admin-communication', target: 'admin-alerts' },
|
||||
{ source: 'admin-alerts', target: 'admin-mail' },
|
||||
{ source: 'admin-gpu', target: 'admin-middleware' },
|
||||
{ source: 'admin-middleware', target: 'admin-mac-mini' },
|
||||
{ source: 'admin-game', target: 'admin-unity-bridge' },
|
||||
{ source: 'admin-edu-search', target: 'admin-staff-search' },
|
||||
{ source: 'admin-staff-search', target: 'admin-uni-crawler' },
|
||||
]
|
||||
|
||||
// ============================================
|
||||
// CATEGORY COLORS & LABELS
|
||||
// ============================================
|
||||
|
||||
export const STUDIO_COLORS: Record<string, { bg: string; border: string; text: string }> = {
|
||||
navigation: { bg: '#dbeafe', border: '#3b82f6', text: '#1e40af' },
|
||||
content: { bg: '#dcfce7', border: '#22c55e', text: '#166534' },
|
||||
communication: { bg: '#fef3c7', border: '#f59e0b', text: '#92400e' },
|
||||
school: { bg: '#fce7f3', border: '#ec4899', text: '#9d174d' },
|
||||
admin: { bg: '#f3e8ff', border: '#a855f7', text: '#6b21a8' },
|
||||
ai: { bg: '#cffafe', border: '#06b6d4', text: '#0e7490' },
|
||||
}
|
||||
|
||||
export const ADMIN_COLORS: Record<string, { bg: string; border: string; text: string }> = {
|
||||
overview: { bg: '#dbeafe', border: '#3b82f6', text: '#1e40af' },
|
||||
infrastructure: { bg: '#fef3c7', border: '#f59e0b', text: '#92400e' },
|
||||
compliance: { bg: '#dcfce7', border: '#22c55e', text: '#166534' },
|
||||
ai: { bg: '#cffafe', border: '#06b6d4', text: '#0e7490' },
|
||||
communication: { bg: '#fce7f3', border: '#ec4899', text: '#9d174d' },
|
||||
security: { bg: '#fee2e2', border: '#ef4444', text: '#991b1b' },
|
||||
content: { bg: '#f3e8ff', border: '#a855f7', text: '#6b21a8' },
|
||||
game: { bg: '#fef9c3', border: '#eab308', text: '#713f12' },
|
||||
misc: { bg: '#f1f5f9', border: '#64748b', text: '#334155' },
|
||||
}
|
||||
|
||||
export const STUDIO_LABELS: Record<string, string> = {
|
||||
navigation: 'Navigation',
|
||||
content: 'Content & Tools',
|
||||
communication: 'Kommunikation',
|
||||
school: 'Schulverwaltung',
|
||||
admin: 'Administration',
|
||||
ai: 'KI & Assistent',
|
||||
}
|
||||
|
||||
export const ADMIN_LABELS: Record<string, string> = {
|
||||
overview: 'Uebersicht',
|
||||
infrastructure: 'Infrastruktur',
|
||||
compliance: 'DSGVO & Compliance',
|
||||
ai: 'KI & LLM',
|
||||
communication: 'Kommunikation',
|
||||
security: 'Security & DevOps',
|
||||
content: 'Content & Suche',
|
||||
game: 'Game & Unity',
|
||||
misc: 'Sonstiges',
|
||||
}
|
||||
@@ -0,0 +1,16 @@
|
||||
export interface ScreenDefinition {
|
||||
id: string
|
||||
name: string
|
||||
description: string
|
||||
category: string
|
||||
icon: string
|
||||
url?: string
|
||||
}
|
||||
|
||||
export interface ConnectionDef {
|
||||
source: string
|
||||
target: string
|
||||
label?: string
|
||||
}
|
||||
|
||||
export type FlowType = 'studio' | 'admin'
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user