diff --git a/admin-lehrer/app/(admin)/infrastructure/sbom/wizard/_components/WizardWidgets.tsx b/admin-lehrer/app/(admin)/infrastructure/sbom/wizard/_components/WizardWidgets.tsx new file mode 100644 index 0000000..5bf21d7 --- /dev/null +++ b/admin-lehrer/app/(admin)/infrastructure/sbom/wizard/_components/WizardWidgets.tsx @@ -0,0 +1,179 @@ +'use client' + +/** + * SBOM Wizard Sub-Components + * + * WizardStepper, EducationCard, CategoryDemo - extracted from page.tsx. + */ + +import type { WizardStep } from './wizard-content' +import { EDUCATION_CONTENT } from './wizard-content' + +export function WizardStepper({ + steps, + currentStep, + onStepClick +}: { + steps: WizardStep[] + currentStep: number + onStepClick: (index: number) => void +}) { + return ( +
+ {steps.map((step, index) => ( +
+ + {index < steps.length - 1 && ( +
+ )} +
+ ))} +
+ ) +} + +export function EducationCard({ stepId }: { stepId: string }) { + const content = EDUCATION_CONTENT[stepId] + if (!content) return null + + return ( +
+

+ 📖 + {content.title} +

+
+ {content.content.map((line, index) => ( +

$1') + .replace(/→/g, '') + .replace(/← NEU!/g, '← NEU!') + }} + /> + ))} +

+ {content.tips && content.tips.length > 0 && ( +
+

💡 Tipps:

+ {content.tips.map((tip, index) => ( +

• {tip}

+ ))} +
+ )} +
+ ) +} + +export function CategoryDemo({ stepId }: { stepId: string }) { + if (stepId === 'categories') { + const categories = [ + { name: 'infrastructure', color: 'blue', count: 45 }, + { name: 'security-tools', color: 'red', count: 12 }, + { name: 'python', color: 'yellow', count: 35 }, + { name: 'go', color: 'cyan', count: 18 }, + { name: 'nodejs', color: 'green', count: 55 }, + { name: 'unity', color: 'amber', count: 7, isNew: true }, + { name: 'csharp', color: 'fuchsia', count: 3, isNew: true }, + { name: 'game', color: 'rose', count: 1, isNew: true }, + ] + return ( +
+

Live-Vorschau: Kategorien

+
+ {categories.map((cat) => ( +
+

{cat.name}

+

{cat.count} Komponenten

+ {'isNew' in cat && cat.isNew && ( + NEU + )} +
+ ))} +
+
+ ) + } + + if (stepId === 'unity-game') { + const unityComponents = [ + { name: 'Unity Engine', version: '6000.0', license: 'Unity EULA' }, + { name: 'URP', version: '17.x', license: 'Unity Companion' }, + { name: 'TextMeshPro', version: '3.2', license: 'Unity Companion' }, + { name: 'Mathematics', version: '1.3', license: 'Unity Companion' }, + { name: 'Newtonsoft.Json', version: '3.2', license: 'MIT' }, + ] + return ( +
+

Unity Packages (Breakpilot Drive)

+
+ {unityComponents.map((comp) => ( +
+
+ unity + {comp.name} +
+
+ {comp.version} + {comp.license} +
+
+ ))} +
+
+ ) + } + + if (stepId === 'licenses') { + const licenses = [ + { name: 'MIT', count: 85, color: 'green', risk: 'Niedrig' }, + { name: 'Apache 2.0', count: 45, color: 'green', risk: 'Niedrig' }, + { name: 'BSD', count: 12, color: 'green', risk: 'Niedrig' }, + { name: 'Unity EULA', count: 1, color: 'yellow', risk: 'Mittel' }, + { name: 'Unity Companion', count: 6, color: 'yellow', risk: 'Mittel' }, + { name: 'AGPL', count: 2, color: 'orange', risk: 'Hoch' }, + ] + return ( +
+

Lizenz-Uebersicht

+
+ {licenses.map((lic) => ( +
+
+ {lic.name} + ({lic.count}) +
+ + Risiko: {lic.risk} + +
+ ))} +
+
+ ) + } + + return null +} diff --git a/admin-lehrer/app/(admin)/infrastructure/sbom/wizard/_components/wizard-content.ts b/admin-lehrer/app/(admin)/infrastructure/sbom/wizard/_components/wizard-content.ts new file mode 100644 index 0000000..d3fbb30 --- /dev/null +++ b/admin-lehrer/app/(admin)/infrastructure/sbom/wizard/_components/wizard-content.ts @@ -0,0 +1,204 @@ +/** + * SBOM Wizard - Education Content & Constants + * + * Extracted from wizard/page.tsx. + */ + +export type StepStatus = 'pending' | 'active' | 'completed' + +export interface WizardStep { + id: string + name: string + icon: string + status: StepStatus +} + +export const STEPS: WizardStep[] = [ + { id: 'welcome', name: 'Willkommen', icon: '📋', status: 'pending' }, + { id: 'what-is-sbom', name: 'Was ist SBOM?', icon: '❓', status: 'pending' }, + { id: 'why-important', name: 'Warum wichtig?', icon: '⚠️', status: 'pending' }, + { id: 'categories', name: 'Kategorien', icon: '📁', status: 'pending' }, + { id: 'infrastructure', name: 'Infrastruktur', icon: '🏗️', status: 'pending' }, + { id: 'unity-game', name: 'Unity & Game', icon: '🎮', status: 'pending' }, + { id: 'licenses', name: 'Lizenzen', icon: '📜', status: 'pending' }, + { id: 'summary', name: 'Zusammenfassung', icon: '✅', status: 'pending' }, +] + +export const EDUCATION_CONTENT: Record = { + 'welcome': { + title: 'Willkommen zum SBOM-Wizard!', + content: [ + 'Eine **Software Bill of Materials (SBOM)** ist wie ein Zutaten-Etikett fuer Software.', + 'Sie listet alle Komponenten auf, aus denen eine Anwendung besteht:', + '• Open-Source-Bibliotheken', + '• Frameworks und Engines', + '• Infrastruktur-Dienste', + '• Entwicklungs-Tools', + 'In diesem Wizard lernst du, warum SBOMs wichtig sind und welche Komponenten BreakPilot verwendet - inklusive der neuen **Breakpilot Drive** (Unity) Komponenten.', + ], + tips: [ + 'SBOMs sind seit 2021 fuer US-Regierungsauftraege Pflicht', + 'Die EU plant aehnliche Vorschriften im Cyber Resilience Act', + ], + }, + 
'what-is-sbom': { + title: 'Was ist eine SBOM?', + content: [ + '**SBOM = Software Bill of Materials**', + 'Eine SBOM ist eine vollstaendige Liste aller Software-Komponenten:', + '**Enthaltene Informationen:**', + '• Name der Komponente', + '• Version', + '• Lizenz (MIT, Apache, GPL, etc.)', + '• Herkunft (Source URL)', + '• Typ (Library, Service, Tool)', + '**Formate:**', + '• SPDX (Linux Foundation Standard)', + '• CycloneDX (OWASP Standard)', + '• SWID Tags (ISO Standard)', + 'BreakPilot verwendet eine eigene Darstellung im Admin-Panel, die alle relevanten Infos zeigt.', + ], + tips: [ + 'Eine SBOM ist wie ein Beipackzettel fuer Medikamente', + 'Sie ermoeglicht schnelle Reaktion bei Sicherheitsluecken', + ], + }, + 'why-important': { + title: 'Warum sind SBOMs wichtig?', + content: [ + '**1. Sicherheit (Security)**', + 'Wenn eine Sicherheitsluecke in einer Bibliothek entdeckt wird (z.B. Log4j), kannst du sofort pruefen ob du betroffen bist.', + '**2. Compliance (Lizenz-Einhaltung)**', + 'Verschiedene Lizenzen haben verschiedene Anforderungen:', + '• MIT: Fast keine Einschraenkungen', + '• GPL: Copyleft - abgeleitete Werke muessen auch GPL sein', + '• Proprietary: Kommerzielle Nutzung eingeschraenkt', + '**3. Supply Chain Security**', + 'Moderne Software besteht aus hunderten Abhaengigkeiten. Eine SBOM macht diese Kette transparent.', + '**4. 
Regulatorische Anforderungen**', + 'US Executive Order 14028 verlangt SBOMs fuer Regierungssoftware.', + ], + tips: [ + 'Log4Shell (2021) betraf Millionen von Systemen', + 'Mit SBOM: Betroffenheit in Minuten geprueft', + ], + }, + 'categories': { + title: 'SBOM-Kategorien in BreakPilot', + content: [ + 'Die BreakPilot SBOM ist in Kategorien unterteilt:', + '**infrastructure** (Blau)', + '→ Kern-Infrastruktur: PostgreSQL, Valkey, Keycloak, Docker', + '**security-tools** (Rot)', + '→ Sicherheits-Tools: Trivy, Gitleaks, Semgrep', + '**python** (Gelb)', + '→ Python-Backend: FastAPI, Pydantic, httpx', + '**go** (Cyan)', + '→ Go-Services: Gin, GORM, JWT', + '**nodejs** (Gruen)', + '→ Frontend: Next.js, React, Tailwind', + '**unity** (Amber) ← NEU!', + '→ Game Engine: Unity 6, URP, TextMeshPro', + '**csharp** (Fuchsia) ← NEU!', + '→ C#/.NET: .NET Standard, UnityWebRequest', + '**game** (Rose) ← NEU!', + '→ Breakpilot Drive Service', + ], + tips: [ + 'Klicke auf eine Kategorie um zu filtern', + 'Die neuen Unity/Game-Kategorien wurden fuer Breakpilot Drive hinzugefuegt', + ], + }, + 'infrastructure': { + title: 'Infrastruktur-Komponenten', + content: [ + 'BreakPilot basiert auf robuster Infrastruktur:', + '**Datenbanken:**', + '• PostgreSQL 16 - Relationale Datenbank', + '• Valkey 8 - In-Memory Cache (Redis-Fork)', + '• ChromaDB - Vector Store fuer RAG', + '**Auth & Security:**', + '• Keycloak 23 - Identity & Access Management', + '• HashiCorp Vault - Secrets Management', + '**Container & Orchestrierung:**', + '• Docker - Container Runtime', + '• Traefik - Reverse Proxy', + '**Kommunikation:**', + '• Matrix Synapse - Chat/Messaging', + '• Jitsi Meet - Video-Konferenzen', + ], + tips: [ + 'Alle Services laufen in Docker-Containern', + 'Ports sind in docker-compose.yml definiert', + ], + }, + 'unity-game': { + title: 'Unity & Breakpilot Drive', + content: [ + '**Neu hinzugefuegt fuer Breakpilot Drive:**', + '**Unity Engine (6000.0)**', + '→ Die Game Engine fuer das 
Lernspiel', + '→ Lizenz: Unity EULA (kostenlos bis 100k Revenue)', + '**Universal Render Pipeline (17.x)**', + '→ Optimierte Grafik-Pipeline fuer WebGL', + '→ Lizenz: Unity Companion License', + '**TextMeshPro (3.2)**', + '→ Fortgeschrittenes Text-Rendering', + '**Unity Mathematics (1.3)**', + '→ SIMD-optimierte Mathe-Bibliothek', + '**Newtonsoft.Json (3.2)**', + '→ JSON-Serialisierung fuer API-Kommunikation', + '**C# Abhaengigkeiten:**', + '• .NET Standard 2.1', + '• UnityWebRequest (HTTP Client)', + '• System.Text.Json', + ], + tips: [ + 'Unity 6 ist die neueste LTS-Version', + 'WebGL-Builds sind ~30-50 MB gross', + ], + }, + 'licenses': { + title: 'Lizenz-Compliance', + content: [ + '**Lizenz-Typen in BreakPilot:**', + '**Permissive (Unkompliziert):**', + '• MIT - Die meisten JS/Python Libs', + '• Apache 2.0 - FastAPI, Keycloak', + '• BSD - PostgreSQL', + '**Copyleft (Vorsicht bei Aenderungen):**', + '• GPL - Wenige Komponenten', + '• AGPL - Jitsi (Server-Side OK)', + '**Proprietary:**', + '• Unity EULA - Kostenlos bis 100k Revenue', + '• Unity Companion - Packages an Engine gebunden', + '**Wichtig:**', + 'Alle verwendeten Lizenzen sind mit kommerziellem Einsatz kompatibel. 
Bei Fragen: Rechtsabteilung konsultieren.', + ], + tips: [ + 'MIT und Apache 2.0 sind am unproblematischsten', + 'AGPL erfordert Source-Code-Freigabe bei Modifikation', + ], + }, + 'summary': { + title: 'Zusammenfassung', + content: [ + 'Du hast die SBOM von BreakPilot kennengelernt:', + '✅ Was eine SBOM ist und warum sie wichtig ist', + '✅ Die verschiedenen Kategorien (8 Stueck)', + '✅ Infrastruktur-Komponenten', + '✅ Die neuen Unity/Game-Komponenten fuer Breakpilot Drive', + '✅ Lizenz-Typen und Compliance', + '**Im SBOM-Dashboard kannst du:**', + '• Nach Kategorie filtern', + '• Nach Namen suchen', + '• Lizenzen pruefen', + '• Komponenten-Details ansehen', + '**180+ Komponenten** sind dokumentiert und nachverfolgbar.', + ], + tips: [ + 'Pruefe regelmaessig auf veraltete Komponenten', + 'Bei neuen Abhaengigkeiten: SBOM aktualisieren', + ], + }, +} diff --git a/admin-lehrer/app/(admin)/infrastructure/sbom/wizard/page.tsx b/admin-lehrer/app/(admin)/infrastructure/sbom/wizard/page.tsx index 25c17b9..870d30e 100644 --- a/admin-lehrer/app/(admin)/infrastructure/sbom/wizard/page.tsx +++ b/admin-lehrer/app/(admin)/infrastructure/sbom/wizard/page.tsx @@ -3,408 +3,13 @@ /** * SBOM (Software Bill of Materials) - Lern-Wizard * - * Migriert von /admin/sbom/wizard (website) nach /infrastructure/sbom/wizard (admin-v2) - * - * Interaktiver Wizard zum Verstehen der SBOM: - * - Was ist eine SBOM? - * - Warum ist sie wichtig? - * - Kategorien erklaert - * - Breakpilot Drive (Unity/C#/Game) Komponenten - * - Lizenzen und Compliance + * Interaktiver Wizard zum Verstehen der SBOM. 
*/ import { useState } from 'react' import Link from 'next/link' - -// ============================================== -// Types -// ============================================== - -type StepStatus = 'pending' | 'active' | 'completed' - -interface WizardStep { - id: string - name: string - icon: string - status: StepStatus -} - -// ============================================== -// Constants -// ============================================== - -const STEPS: WizardStep[] = [ - { id: 'welcome', name: 'Willkommen', icon: '📋', status: 'pending' }, - { id: 'what-is-sbom', name: 'Was ist SBOM?', icon: '❓', status: 'pending' }, - { id: 'why-important', name: 'Warum wichtig?', icon: '⚠️', status: 'pending' }, - { id: 'categories', name: 'Kategorien', icon: '📁', status: 'pending' }, - { id: 'infrastructure', name: 'Infrastruktur', icon: '🏗️', status: 'pending' }, - { id: 'unity-game', name: 'Unity & Game', icon: '🎮', status: 'pending' }, - { id: 'licenses', name: 'Lizenzen', icon: '📜', status: 'pending' }, - { id: 'summary', name: 'Zusammenfassung', icon: '✅', status: 'pending' }, -] - -const EDUCATION_CONTENT: Record = { - 'welcome': { - title: 'Willkommen zum SBOM-Wizard!', - content: [ - 'Eine **Software Bill of Materials (SBOM)** ist wie ein Zutaten-Etikett fuer Software.', - 'Sie listet alle Komponenten auf, aus denen eine Anwendung besteht:', - '• Open-Source-Bibliotheken', - '• Frameworks und Engines', - '• Infrastruktur-Dienste', - '• Entwicklungs-Tools', - 'In diesem Wizard lernst du, warum SBOMs wichtig sind und welche Komponenten BreakPilot verwendet - inklusive der neuen **Breakpilot Drive** (Unity) Komponenten.', - ], - tips: [ - 'SBOMs sind seit 2021 fuer US-Regierungsauftraege Pflicht', - 'Die EU plant aehnliche Vorschriften im Cyber Resilience Act', - ], - }, - 'what-is-sbom': { - title: 'Was ist eine SBOM?', - content: [ - '**SBOM = Software Bill of Materials**', - 'Eine SBOM ist eine vollstaendige Liste aller Software-Komponenten:', - '**Enthaltene 
Informationen:**', - '• Name der Komponente', - '• Version', - '• Lizenz (MIT, Apache, GPL, etc.)', - '• Herkunft (Source URL)', - '• Typ (Library, Service, Tool)', - '**Formate:**', - '• SPDX (Linux Foundation Standard)', - '• CycloneDX (OWASP Standard)', - '• SWID Tags (ISO Standard)', - 'BreakPilot verwendet eine eigene Darstellung im Admin-Panel, die alle relevanten Infos zeigt.', - ], - tips: [ - 'Eine SBOM ist wie ein Beipackzettel fuer Medikamente', - 'Sie ermoeglicht schnelle Reaktion bei Sicherheitsluecken', - ], - }, - 'why-important': { - title: 'Warum sind SBOMs wichtig?', - content: [ - '**1. Sicherheit (Security)**', - 'Wenn eine Sicherheitsluecke in einer Bibliothek entdeckt wird (z.B. Log4j), kannst du sofort pruefen ob du betroffen bist.', - '**2. Compliance (Lizenz-Einhaltung)**', - 'Verschiedene Lizenzen haben verschiedene Anforderungen:', - '• MIT: Fast keine Einschraenkungen', - '• GPL: Copyleft - abgeleitete Werke muessen auch GPL sein', - '• Proprietary: Kommerzielle Nutzung eingeschraenkt', - '**3. Supply Chain Security**', - 'Moderne Software besteht aus hunderten Abhaengigkeiten. Eine SBOM macht diese Kette transparent.', - '**4. 
Regulatorische Anforderungen**', - 'US Executive Order 14028 verlangt SBOMs fuer Regierungssoftware.', - ], - tips: [ - 'Log4Shell (2021) betraf Millionen von Systemen', - 'Mit SBOM: Betroffenheit in Minuten geprueft', - ], - }, - 'categories': { - title: 'SBOM-Kategorien in BreakPilot', - content: [ - 'Die BreakPilot SBOM ist in Kategorien unterteilt:', - '**infrastructure** (Blau)', - '→ Kern-Infrastruktur: PostgreSQL, Valkey, Keycloak, Docker', - '**security-tools** (Rot)', - '→ Sicherheits-Tools: Trivy, Gitleaks, Semgrep', - '**python** (Gelb)', - '→ Python-Backend: FastAPI, Pydantic, httpx', - '**go** (Cyan)', - '→ Go-Services: Gin, GORM, JWT', - '**nodejs** (Gruen)', - '→ Frontend: Next.js, React, Tailwind', - '**unity** (Amber) ← NEU!', - '→ Game Engine: Unity 6, URP, TextMeshPro', - '**csharp** (Fuchsia) ← NEU!', - '→ C#/.NET: .NET Standard, UnityWebRequest', - '**game** (Rose) ← NEU!', - '→ Breakpilot Drive Service', - ], - tips: [ - 'Klicke auf eine Kategorie um zu filtern', - 'Die neuen Unity/Game-Kategorien wurden fuer Breakpilot Drive hinzugefuegt', - ], - }, - 'infrastructure': { - title: 'Infrastruktur-Komponenten', - content: [ - 'BreakPilot basiert auf robuster Infrastruktur:', - '**Datenbanken:**', - '• PostgreSQL 16 - Relationale Datenbank', - '• Valkey 8 - In-Memory Cache (Redis-Fork)', - '• ChromaDB - Vector Store fuer RAG', - '**Auth & Security:**', - '• Keycloak 23 - Identity & Access Management', - '• HashiCorp Vault - Secrets Management', - '**Container & Orchestrierung:**', - '• Docker - Container Runtime', - '• Traefik - Reverse Proxy', - '**Kommunikation:**', - '• Matrix Synapse - Chat/Messaging', - '• Jitsi Meet - Video-Konferenzen', - ], - tips: [ - 'Alle Services laufen in Docker-Containern', - 'Ports sind in docker-compose.yml definiert', - ], - }, - 'unity-game': { - title: 'Unity & Breakpilot Drive', - content: [ - '**Neu hinzugefuegt fuer Breakpilot Drive:**', - '**Unity Engine (6000.0)**', - '→ Die Game Engine fuer das 
Lernspiel', - '→ Lizenz: Unity EULA (kostenlos bis 100k Revenue)', - '**Universal Render Pipeline (17.x)**', - '→ Optimierte Grafik-Pipeline fuer WebGL', - '→ Lizenz: Unity Companion License', - '**TextMeshPro (3.2)**', - '→ Fortgeschrittenes Text-Rendering', - '**Unity Mathematics (1.3)**', - '→ SIMD-optimierte Mathe-Bibliothek', - '**Newtonsoft.Json (3.2)**', - '→ JSON-Serialisierung fuer API-Kommunikation', - '**C# Abhaengigkeiten:**', - '• .NET Standard 2.1', - '• UnityWebRequest (HTTP Client)', - '• System.Text.Json', - ], - tips: [ - 'Unity 6 ist die neueste LTS-Version', - 'WebGL-Builds sind ~30-50 MB gross', - ], - }, - 'licenses': { - title: 'Lizenz-Compliance', - content: [ - '**Lizenz-Typen in BreakPilot:**', - '**Permissive (Unkompliziert):**', - '• MIT - Die meisten JS/Python Libs', - '• Apache 2.0 - FastAPI, Keycloak', - '• BSD - PostgreSQL', - '**Copyleft (Vorsicht bei Aenderungen):**', - '• GPL - Wenige Komponenten', - '• AGPL - Jitsi (Server-Side OK)', - '**Proprietary:**', - '• Unity EULA - Kostenlos bis 100k Revenue', - '• Unity Companion - Packages an Engine gebunden', - '**Wichtig:**', - 'Alle verwendeten Lizenzen sind mit kommerziellem Einsatz kompatibel. 
Bei Fragen: Rechtsabteilung konsultieren.', - ], - tips: [ - 'MIT und Apache 2.0 sind am unproblematischsten', - 'AGPL erfordert Source-Code-Freigabe bei Modifikation', - ], - }, - 'summary': { - title: 'Zusammenfassung', - content: [ - 'Du hast die SBOM von BreakPilot kennengelernt:', - '✅ Was eine SBOM ist und warum sie wichtig ist', - '✅ Die verschiedenen Kategorien (8 Stueck)', - '✅ Infrastruktur-Komponenten', - '✅ Die neuen Unity/Game-Komponenten fuer Breakpilot Drive', - '✅ Lizenz-Typen und Compliance', - '**Im SBOM-Dashboard kannst du:**', - '• Nach Kategorie filtern', - '• Nach Namen suchen', - '• Lizenzen pruefen', - '• Komponenten-Details ansehen', - '**180+ Komponenten** sind dokumentiert und nachverfolgbar.', - ], - tips: [ - 'Pruefe regelmaessig auf veraltete Komponenten', - 'Bei neuen Abhaengigkeiten: SBOM aktualisieren', - ], - }, -} - -// ============================================== -// Components -// ============================================== - -function WizardStepper({ - steps, - currentStep, - onStepClick -}: { - steps: WizardStep[] - currentStep: number - onStepClick: (index: number) => void -}) { - return ( -
- {steps.map((step, index) => ( -
- - {index < steps.length - 1 && ( -
- )} -
- ))} -
- ) -} - -function EducationCard({ stepId }: { stepId: string }) { - const content = EDUCATION_CONTENT[stepId] - if (!content) return null - - return ( -
-

- 📖 - {content.title} -

-
- {content.content.map((line, index) => ( -

$1') - .replace(/→/g, '') - .replace(/← NEU!/g, '← NEU!') - }} - /> - ))} -

- {content.tips && content.tips.length > 0 && ( -
-

💡 Tipps:

- {content.tips.map((tip, index) => ( -

• {tip}

- ))} -
- )} -
- ) -} - -function CategoryDemo({ stepId }: { stepId: string }) { - if (stepId === 'categories') { - const categories = [ - { name: 'infrastructure', color: 'blue', count: 45 }, - { name: 'security-tools', color: 'red', count: 12 }, - { name: 'python', color: 'yellow', count: 35 }, - { name: 'go', color: 'cyan', count: 18 }, - { name: 'nodejs', color: 'green', count: 55 }, - { name: 'unity', color: 'amber', count: 7, isNew: true }, - { name: 'csharp', color: 'fuchsia', count: 3, isNew: true }, - { name: 'game', color: 'rose', count: 1, isNew: true }, - ] - - return ( -
-

Live-Vorschau: Kategorien

-
- {categories.map((cat) => ( -
-

{cat.name}

-

{cat.count} Komponenten

- {cat.isNew && ( - NEU - )} -
- ))} -
-
- ) - } - - if (stepId === 'unity-game') { - const unityComponents = [ - { name: 'Unity Engine', version: '6000.0', license: 'Unity EULA' }, - { name: 'URP', version: '17.x', license: 'Unity Companion' }, - { name: 'TextMeshPro', version: '3.2', license: 'Unity Companion' }, - { name: 'Mathematics', version: '1.3', license: 'Unity Companion' }, - { name: 'Newtonsoft.Json', version: '3.2', license: 'MIT' }, - ] - - return ( -
-

Unity Packages (Breakpilot Drive)

-
- {unityComponents.map((comp) => ( -
-
- unity - {comp.name} -
-
- {comp.version} - {comp.license} -
-
- ))} -
-
- ) - } - - if (stepId === 'licenses') { - const licenses = [ - { name: 'MIT', count: 85, color: 'green', risk: 'Niedrig' }, - { name: 'Apache 2.0', count: 45, color: 'green', risk: 'Niedrig' }, - { name: 'BSD', count: 12, color: 'green', risk: 'Niedrig' }, - { name: 'Unity EULA', count: 1, color: 'yellow', risk: 'Mittel' }, - { name: 'Unity Companion', count: 6, color: 'yellow', risk: 'Mittel' }, - { name: 'AGPL', count: 2, color: 'orange', risk: 'Hoch' }, - ] - - return ( -
-

Lizenz-Uebersicht

-
- {licenses.map((lic) => ( -
-
- {lic.name} - ({lic.count}) -
- - Risiko: {lic.risk} - -
- ))} -
-
- ) - } - - return null -} - -// ============================================== -// Main Component -// ============================================== +import { STEPS, type WizardStep } from './_components/wizard-content' +import { WizardStepper, EducationCard, CategoryDemo } from './_components/WizardWidgets' export default function SBOMWizardPage() { const [currentStep, setCurrentStep] = useState(0) diff --git a/admin-lehrer/components/common/DataFlowDiagram.tsx b/admin-lehrer/components/common/DataFlowDiagram.tsx index 7ce5643..ef20b65 100644 --- a/admin-lehrer/components/common/DataFlowDiagram.tsx +++ b/admin-lehrer/components/common/DataFlowDiagram.tsx @@ -12,6 +12,7 @@ import { MODULE_REGISTRY, type BackendModule } from '@/lib/module-registry' +import { DataFlowDiagramDetails, SERVICE_COLORS, STATUS_COLORS } from './DataFlowDiagramDetails' interface NodePosition { x: number @@ -26,20 +27,6 @@ interface ServiceGroup { modules: BackendModule[] } -const SERVICE_COLORS: Record = { - 'consent-service': '#8b5cf6', // purple - 'python-backend': '#f59e0b', // amber - 'klausur-service': '#10b981', // emerald - 'voice-service': '#3b82f6', // blue -} - -const STATUS_COLORS = { - connected: '#22c55e', - partial: '#eab308', - 'not-connected': '#ef4444', - deprecated: '#6b7280' -} - export function DataFlowDiagram() { const [selectedModule, setSelectedModule] = useState(null) const [hoveredModule, setHoveredModule] = useState(null) @@ -471,39 +458,10 @@ export function DataFlowDiagram() { {/* Selected Module Details */} {selectedModule && ( -
-

- {MODULE_REGISTRY.find(m => m.id === selectedModule)?.name} -

-
-

ID: {selectedModule}

- {MODULE_REGISTRY.find(m => m.id === selectedModule)?.dependencies && ( -

- Abhaengigkeiten: - {MODULE_REGISTRY.find(m => m.id === selectedModule)?.dependencies?.map(dep => ( - - ))} -

- )} - {MODULE_REGISTRY.find(m => m.id === selectedModule)?.frontend.adminV2Page && ( -

- Frontend: - m.id === selectedModule)?.frontend.adminV2Page} - className="ml-2 text-purple-600 hover:underline" - > - {MODULE_REGISTRY.find(m => m.id === selectedModule)?.frontend.adminV2Page} - -

- )} -
-
+ )}
) diff --git a/admin-lehrer/components/common/DataFlowDiagramDetails.tsx b/admin-lehrer/components/common/DataFlowDiagramDetails.tsx new file mode 100644 index 0000000..91b1f0b --- /dev/null +++ b/admin-lehrer/components/common/DataFlowDiagramDetails.tsx @@ -0,0 +1,64 @@ +'use client' + +/** + * DataFlowDiagram Module Details Panel + * + * Extracted from DataFlowDiagram.tsx for the selected module details panel. + */ + +import { MODULE_REGISTRY, type BackendModule } from '@/lib/module-registry' + +interface DataFlowDiagramDetailsProps { + selectedModule: string + onSelectModule: (id: string | null) => void +} + +export function DataFlowDiagramDetails({ selectedModule, onSelectModule }: DataFlowDiagramDetailsProps) { + const module = MODULE_REGISTRY.find(m => m.id === selectedModule) + if (!module) return null + + return ( +
+

{module.name}

+
+

ID: {selectedModule}

+ {module.dependencies && ( +

+ Abhaengigkeiten: + {module.dependencies.map(dep => ( + + ))} +

+ )} + {module.frontend.adminV2Page && ( +

+ Frontend: + + {module.frontend.adminV2Page} + +

+ )} +
+
+ ) +} + +export const SERVICE_COLORS: Record = { + 'consent-service': '#8b5cf6', + 'python-backend': '#f59e0b', + 'klausur-service': '#10b981', + 'voice-service': '#3b82f6', +} + +export const STATUS_COLORS = { + connected: '#22c55e', + partial: '#eab308', + 'not-connected': '#ef4444', + deprecated: '#6b7280', +} diff --git a/admin-lehrer/components/infrastructure/DevOpsPipelineSidebar.tsx b/admin-lehrer/components/infrastructure/DevOpsPipelineSidebar.tsx index 14d2908..2d93e61 100644 --- a/admin-lehrer/components/infrastructure/DevOpsPipelineSidebar.tsx +++ b/admin-lehrer/components/infrastructure/DevOpsPipelineSidebar.tsx @@ -279,255 +279,7 @@ export function DevOpsPipelineSidebar({ ) } -// ============================================================================= -// Responsive Version with Mobile FAB + Drawer -// ============================================================================= - -/** - * Responsive DevOps Sidebar mit Mobile FAB + Drawer - * - * Desktop (xl+): Fixierte Sidebar rechts - * Mobile/Tablet: Floating Action Button unten rechts, oeffnet Drawer - */ -export function DevOpsPipelineSidebarResponsive({ - currentTool, - compact = false, - className = '', - fabPosition = 'bottom-right', -}: DevOpsPipelineSidebarResponsiveProps) { - const [isMobileOpen, setIsMobileOpen] = useState(false) - const liveStatus = usePipelineLiveStatus() - - // Close drawer on escape key - useEffect(() => { - const handleEscape = (e: KeyboardEvent) => { - if (e.key === 'Escape') setIsMobileOpen(false) - } - window.addEventListener('keydown', handleEscape) - return () => window.removeEventListener('keydown', handleEscape) - }, []) - - // Prevent body scroll when drawer is open - useEffect(() => { - if (isMobileOpen) { - document.body.style.overflow = 'hidden' - } else { - document.body.style.overflow = '' - } - return () => { - document.body.style.overflow = '' - } - }, [isMobileOpen]) - - const fabPositionClasses = fabPosition === 'bottom-right' - ? 
'right-4 bottom-20' - : 'left-4 bottom-20' - - // Calculate total badge count for FAB - const totalBadgeCount = liveStatus - ? liveStatus.backlogCount + liveStatus.securityFindingsCount - : 0 - - return ( - <> - {/* Desktop: Fixed Sidebar */} -
- -
- - {/* Mobile/Tablet: FAB */} - - - {/* Mobile/Tablet: Drawer Overlay */} - {isMobileOpen && ( -
- {/* Backdrop */} -
setIsMobileOpen(false)} - /> - - {/* Drawer */} -
- {/* Drawer Header */} -
-
- - - - - DevOps Pipeline - - {liveStatus?.isRunning && ( - - )} -
- -
- - {/* Drawer Content */} -
- {/* Tool Links */} -
- {DEVOPS_PIPELINE_MODULES.map((tool) => ( - setIsMobileOpen(false)} - className={`flex items-center gap-3 px-4 py-3 rounded-xl transition-all ${ - currentTool === tool.id - ? 'bg-orange-100 dark:bg-orange-900/30 text-orange-700 dark:text-orange-300 font-medium shadow-sm' - : 'text-slate-600 dark:text-slate-400 hover:bg-slate-100 dark:hover:bg-gray-800' - }`} - > - -
-
{tool.name}
-
- {tool.description} -
-
- {/* Status badges */} - {tool.id === 'tests' && liveStatus && ( - - )} - {tool.id === 'security' && liveStatus && ( - - )} - {currentTool === tool.id && ( - - )} - - ))} -
- - {/* Pipeline Flow Visualization */} -
-
- Pipeline Flow -
-
-
- 📝 - Code -
- -
- 🏗️ - Build -
- -
- - Test -
- -
- 🚀 - Deploy -
-
-
- - {/* Quick Info */} -
-
- {currentTool === 'ci-cd' && ( - <> - Aktuell: Woodpecker Pipelines und Deployments verwalten - - )} - {currentTool === 'tests' && ( - <> - Aktuell: 280+ Tests ueber alle Services ueberwachen - - )} - {currentTool === 'sbom' && ( - <> - Aktuell: Abhaengigkeiten und Lizenzen pruefen - - )} - {currentTool === 'security' && ( - <> - Aktuell: Vulnerabilities und Security-Scans analysieren - - )} -
-
- - {/* Quick Action: Pipeline triggern */} -
- -
- - {/* Link to Infrastructure Overview */} -
- setIsMobileOpen(false)} - className="flex items-center gap-2 px-3 py-2 text-sm text-orange-600 dark:text-orange-400 hover:bg-orange-50 dark:hover:bg-orange-900/20 rounded-lg transition-colors" - > - - - - Zur Infrastructure-Uebersicht - -
-
-
-
- )} - - {/* CSS for slide-in animation */} - - - ) -} +// Re-export responsive version for backwards compatibility +export { DevOpsPipelineSidebarResponsive } from './DevOpsPipelineSidebarResponsive' export default DevOpsPipelineSidebar diff --git a/admin-lehrer/components/infrastructure/DevOpsPipelineSidebarResponsive.tsx b/admin-lehrer/components/infrastructure/DevOpsPipelineSidebarResponsive.tsx new file mode 100644 index 0000000..f170820 --- /dev/null +++ b/admin-lehrer/components/infrastructure/DevOpsPipelineSidebarResponsive.tsx @@ -0,0 +1,190 @@ +'use client' + +/** + * Responsive DevOps Pipeline Sidebar (Mobile FAB + Drawer) + * + * Extracted from DevOpsPipelineSidebar.tsx. + * Desktop (xl+): Fixierte Sidebar rechts + * Mobile/Tablet: Floating Action Button unten rechts, oeffnet Drawer + */ + +import Link from 'next/link' +import { useState, useEffect } from 'react' +import type { + DevOpsPipelineSidebarResponsiveProps, + PipelineLiveStatus, +} from '@/types/infrastructure-modules' +import { DEVOPS_PIPELINE_MODULES } from '@/types/infrastructure-modules' +import { DevOpsPipelineSidebar } from './DevOpsPipelineSidebar' + +// Server/Pipeline Icon fuer Header +const ServerIcon = () => ( + + + +) + +// Play Icon fuer Quick Action +const PlayIcon = () => ( + + + + +) + +// Inline ToolIcon (duplicated to avoid circular imports) +const ToolIcon = ({ id }: { id: string }) => { + const icons: Record = { + 'ci-cd': , + 'tests': , + 'sbom': , + 'security': , + } + return icons[id] || null +} + +interface StatusBadgeProps { + count: number + type: 'backlog' | 'security' | 'running' +} + +function StatusBadge({ count, type }: StatusBadgeProps) { + if (count === 0) return null + const colors = { + backlog: 'bg-amber-500', + security: 'bg-red-500', + running: 'bg-green-500 animate-pulse', + } + return ( + + {count} + + ) +} + +function usePipelineLiveStatus(): PipelineLiveStatus | null { + const [status, setStatus] = useState(null) + useEffect(() => { /* placeholder for 
live status fetch */ }, []) + return status +} + +export function DevOpsPipelineSidebarResponsive({ + currentTool, + compact = false, + className = '', + fabPosition = 'bottom-right', +}: DevOpsPipelineSidebarResponsiveProps) { + const [isMobileOpen, setIsMobileOpen] = useState(false) + const liveStatus = usePipelineLiveStatus() + + useEffect(() => { + const handleEscape = (e: KeyboardEvent) => { + if (e.key === 'Escape') setIsMobileOpen(false) + } + window.addEventListener('keydown', handleEscape) + return () => window.removeEventListener('keydown', handleEscape) + }, []) + + useEffect(() => { + if (isMobileOpen) { + document.body.style.overflow = 'hidden' + } else { + document.body.style.overflow = '' + } + return () => { document.body.style.overflow = '' } + }, [isMobileOpen]) + + const fabPositionClasses = fabPosition === 'bottom-right' ? 'right-4 bottom-20' : 'left-4 bottom-20' + const totalBadgeCount = liveStatus ? liveStatus.backlogCount + liveStatus.securityFindingsCount : 0 + + return ( + <> + {/* Desktop: Fixed Sidebar */} +
+ +
+ + {/* Mobile/Tablet: FAB */} + + + {/* Mobile/Tablet: Drawer Overlay */} + {isMobileOpen && ( +
+
setIsMobileOpen(false)} /> +
+
+
+ + DevOps Pipeline + {liveStatus?.isRunning && } +
+ +
+
+
+ {DEVOPS_PIPELINE_MODULES.map((tool) => ( + setIsMobileOpen(false)} className={`flex items-center gap-3 px-4 py-3 rounded-xl transition-all ${currentTool === tool.id ? 'bg-orange-100 dark:bg-orange-900/30 text-orange-700 dark:text-orange-300 font-medium shadow-sm' : 'text-slate-600 dark:text-slate-400 hover:bg-slate-100 dark:hover:bg-gray-800'}`}> + +
+
{tool.name}
+
{tool.description}
+
+ {tool.id === 'tests' && liveStatus && } + {tool.id === 'security' && liveStatus && } + {currentTool === tool.id && } + + ))} +
+
+
Pipeline Flow
+
+ {[{e:'📝',l:'Code'},{e:'🏗️',l:'Build'},{e:'✅',l:'Test'},{e:'🚀',l:'Deploy'}].map((s,i,a)=>( + + {s.e}{s.l} + {i} + + ))} +
+
+
+ +
+
+ setIsMobileOpen(false)} className="flex items-center gap-2 px-3 py-2 text-sm text-orange-600 dark:text-orange-400 hover:bg-orange-50 dark:hover:bg-orange-900/20 rounded-lg transition-colors"> + + Zur Infrastructure-Uebersicht + +
+
+
+
+ )} + + + + ) +} diff --git a/admin-lehrer/components/ocr-pipeline/StepGridReview.tsx b/admin-lehrer/components/ocr-pipeline/StepGridReview.tsx index 8829961..4411fa8 100644 --- a/admin-lehrer/components/ocr-pipeline/StepGridReview.tsx +++ b/admin-lehrer/components/ocr-pipeline/StepGridReview.tsx @@ -14,6 +14,7 @@ import type { GridZone, LayoutDividers } from '@/components/grid-editor/types' import { GridToolbar } from '@/components/grid-editor/GridToolbar' import { GridTable } from '@/components/grid-editor/GridTable' import { ImageLayoutEditor } from '@/components/grid-editor/ImageLayoutEditor' +import { ReviewStatsBar } from './StepGridReviewStats' const KLAUSUR_API = '/klausur-api' @@ -236,108 +237,29 @@ export function StepGridReview({ sessionId, onNext, saveRef }: StepGridReviewPro return (
{/* Review Stats Bar */} -
- - {grid.summary.total_zones} Zone(n), {grid.summary.total_columns} Spalten,{' '} - {grid.summary.total_rows} Zeilen, {grid.summary.total_cells} Zellen - - {grid.dictionary_detection?.is_dictionary && ( - - Woerterbuch ({Math.round(grid.dictionary_detection.confidence * 100)}%) - - )} - {grid.page_number?.text && ( - - S. {grid.page_number.number ?? grid.page_number.text} - - )} - {lowConfCells.length > 0 && ( - - {lowConfCells.length} niedrige Konfidenz - - )} - - {acceptedRows.size}/{totalRows} Zeilen akzeptiert - - {acceptedRows.size < totalRows && ( - - )} - {/* OCR Quality Steps (A/B Testing) */} - | - - - - - | - - - -
- - - - {grid.duration_seconds.toFixed(1)}s - -
-
+ setShowImage(!showImage)} + /> {/* Toolbar */}
diff --git a/admin-lehrer/components/ocr-pipeline/StepGridReviewStats.tsx b/admin-lehrer/components/ocr-pipeline/StepGridReviewStats.tsx new file mode 100644 index 0000000..a711827 --- /dev/null +++ b/admin-lehrer/components/ocr-pipeline/StepGridReviewStats.tsx @@ -0,0 +1,180 @@ +'use client' + +/** + * StepGridReview Stats Bar & OCR Quality Controls + * + * Extracted from StepGridReview.tsx to stay under 500 LOC. + */ + +import type { GridZone } from '@/components/grid-editor/types' + +interface GridSummary { + total_zones: number + total_columns: number + total_rows: number + total_cells: number +} + +interface DictionaryDetection { + is_dictionary: boolean + confidence: number +} + +interface PageNumber { + text?: string + number?: number | null +} + +interface ReviewStatsBarProps { + summary: GridSummary + dictionaryDetection?: DictionaryDetection | null + pageNumber?: PageNumber | null + lowConfCount: number + acceptedCount: number + totalRows: number + ocrEnhance: boolean + ocrMaxCols: number + ocrMinConf: number + visionFusion: boolean + documentCategory: string + durationSeconds: number + showImage: boolean + onOcrEnhanceChange: (v: boolean) => void + onOcrMaxColsChange: (v: number) => void + onOcrMinConfChange: (v: number) => void + onVisionFusionChange: (v: boolean) => void + onDocumentCategoryChange: (v: string) => void + onAcceptAll: () => void + onAutoCorrect: () => number + onToggleImage: () => void +} + +export function ReviewStatsBar({ + summary, + dictionaryDetection, + pageNumber, + lowConfCount, + acceptedCount, + totalRows, + ocrEnhance, + ocrMaxCols, + ocrMinConf, + visionFusion, + documentCategory, + durationSeconds, + showImage, + onOcrEnhanceChange, + onOcrMaxColsChange, + onOcrMinConfChange, + onVisionFusionChange, + onDocumentCategoryChange, + onAcceptAll, + onAutoCorrect, + onToggleImage, +}: ReviewStatsBarProps) { + return ( +
+ + {summary.total_zones} Zone(n), {summary.total_columns} Spalten,{' '} + {summary.total_rows} Zeilen, {summary.total_cells} Zellen + + {dictionaryDetection?.is_dictionary && ( + + Woerterbuch ({Math.round(dictionaryDetection.confidence * 100)}%) + + )} + {pageNumber?.text && ( + + S. {pageNumber.number ?? pageNumber.text} + + )} + {lowConfCount > 0 && ( + + {lowConfCount} niedrige Konfidenz + + )} + + {acceptedCount}/{totalRows} Zeilen akzeptiert + + {acceptedCount < totalRows && ( + + )} + + {/* OCR Quality Steps */} + | + + + + + | + + + +
+ + + + {durationSeconds.toFixed(1)}s + +
+
+ ) +} diff --git a/admin-lehrer/components/ocr/GridOverlay.tsx b/admin-lehrer/components/ocr/GridOverlay.tsx index 0014bc9..23fbe61 100644 --- a/admin-lehrer/components/ocr/GridOverlay.tsx +++ b/admin-lehrer/components/ocr/GridOverlay.tsx @@ -474,76 +474,5 @@ export function GridOverlay({ ) } -/** - * GridStats Component - */ -interface GridStatsProps { - stats: GridData['stats'] - deskewAngle?: number - source?: string - className?: string -} - -export function GridStats({ stats, deskewAngle, source, className }: GridStatsProps) { - const coveragePercent = Math.round(stats.coverage * 100) - - return ( -
-
- Erkannt: {stats.recognized} -
- {(stats.manual ?? 0) > 0 && ( -
- Manuell: {stats.manual} -
- )} - {stats.problematic > 0 && ( -
- Problematisch: {stats.problematic} -
- )} -
- Leer: {stats.empty} -
-
- Abdeckung: {coveragePercent}% -
- {deskewAngle !== undefined && deskewAngle !== 0 && ( -
- Begradigt: {deskewAngle.toFixed(1)} -
- )} - {source && ( -
- Quelle: {source === 'tesseract+grid_service' ? 'Tesseract' : 'Vision LLM'} -
- )} -
- ) -} - -/** - * Legend Component for GridOverlay - */ -export function GridLegend({ className }: { className?: string }) { - return ( -
-
-
- Erkannt -
-
-
- Problematisch -
-
-
- Manuell korrigiert -
-
-
- Leer -
-
- ) -} +// Re-export widgets from sibling file for backwards compatibility +export { GridStats, GridLegend } from './GridOverlayWidgets' diff --git a/admin-lehrer/components/ocr/GridOverlayWidgets.tsx b/admin-lehrer/components/ocr/GridOverlayWidgets.tsx new file mode 100644 index 0000000..2c20856 --- /dev/null +++ b/admin-lehrer/components/ocr/GridOverlayWidgets.tsx @@ -0,0 +1,84 @@ +'use client' + +/** + * GridOverlay Widgets - GridStats and GridLegend + * + * Extracted from GridOverlay.tsx to keep each file under 500 LOC. + */ + +import { cn } from '@/lib/utils' +import type { GridData } from './GridOverlay' + +/** + * GridStats Component + */ +interface GridStatsProps { + stats: GridData['stats'] + deskewAngle?: number + source?: string + className?: string +} + +export function GridStats({ stats, deskewAngle, source, className }: GridStatsProps) { + const coveragePercent = Math.round(stats.coverage * 100) + + return ( +
+
+ Erkannt: {stats.recognized} +
+ {(stats.manual ?? 0) > 0 && ( +
+ Manuell: {stats.manual} +
+ )} + {stats.problematic > 0 && ( +
+ Problematisch: {stats.problematic} +
+ )} +
+ Leer: {stats.empty} +
+
+ Abdeckung: {coveragePercent}% +
+ {deskewAngle !== undefined && deskewAngle !== 0 && ( +
+ Begradigt: {deskewAngle.toFixed(1)} +
+ )} + {source && ( +
+ Quelle: {source === 'tesseract+grid_service' ? 'Tesseract' : 'Vision LLM'} +
+ )} +
+ ) +} + +/** + * Legend Component for GridOverlay + */ +export function GridLegend({ className }: { className?: string }) { + return ( +
+
+
+ Erkannt +
+
+
+ Problematisch +
+
+
+ Manuell korrigiert +
+
+
+ Leer +
+
+ ) +} diff --git a/agent-core/brain/__init__.py b/agent-core/brain/__init__.py index 789ecff..4bb6e7a 100644 --- a/agent-core/brain/__init__.py +++ b/agent-core/brain/__init__.py @@ -7,15 +7,31 @@ Provides: - KnowledgeGraph: Entity relationships and semantic connections """ -from agent_core.brain.memory_store import MemoryStore, Memory -from agent_core.brain.context_manager import ConversationContext, ContextManager -from agent_core.brain.knowledge_graph import KnowledgeGraph, Entity, Relationship +from agent_core.brain.memory_models import Memory +from agent_core.brain.memory_store import MemoryStore +from agent_core.brain.context_models import ( + MessageRole, + Message, + ConversationContext, +) +from agent_core.brain.context_manager import ContextManager +from agent_core.brain.knowledge_models import ( + EntityType, + RelationshipType, + Entity, + Relationship, +) +from agent_core.brain.knowledge_graph import KnowledgeGraph __all__ = [ "MemoryStore", "Memory", + "MessageRole", + "Message", "ConversationContext", "ContextManager", + "EntityType", + "RelationshipType", "KnowledgeGraph", "Entity", "Relationship", diff --git a/agent-core/brain/context_manager.py b/agent-core/brain/context_manager.py index e142d33..6d7dfca 100644 --- a/agent-core/brain/context_manager.py +++ b/agent-core/brain/context_manager.py @@ -1,317 +1,22 @@ """ Context Management for Breakpilot Agents -Provides conversation context with: -- Message history with compression -- Entity extraction and tracking -- Intent history -- Context summarization +Manages conversation contexts for multiple sessions with persistence. 
""" from typing import Dict, Any, List, Optional, Callable, Awaitable -from dataclasses import dataclass, field -from datetime import datetime, timezone -from enum import Enum import json import logging +from agent_core.brain.context_models import ( + MessageRole, + Message, + ConversationContext, +) + logger = logging.getLogger(__name__) -class MessageRole(Enum): - """Message roles in a conversation""" - SYSTEM = "system" - USER = "user" - ASSISTANT = "assistant" - TOOL = "tool" - - -@dataclass -class Message: - """Represents a message in a conversation""" - role: MessageRole - content: str - timestamp: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) - metadata: Dict[str, Any] = field(default_factory=dict) - - def to_dict(self) -> Dict[str, Any]: - return { - "role": self.role.value, - "content": self.content, - "timestamp": self.timestamp.isoformat(), - "metadata": self.metadata - } - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "Message": - return cls( - role=MessageRole(data["role"]), - content=data["content"], - timestamp=datetime.fromisoformat(data["timestamp"]) if "timestamp" in data else datetime.now(timezone.utc), - metadata=data.get("metadata", {}) - ) - - -@dataclass -class ConversationContext: - """ - Context for a running conversation. - - Maintains: - - Message history with automatic compression - - Extracted entities - - Intent history - - Conversation summary - """ - messages: List[Message] = field(default_factory=list) - entities: Dict[str, Any] = field(default_factory=dict) - intent_history: List[str] = field(default_factory=list) - summary: Optional[str] = None - max_messages: int = 50 - system_prompt: Optional[str] = None - metadata: Dict[str, Any] = field(default_factory=dict) - - def add_message( - self, - role: MessageRole, - content: str, - metadata: Optional[Dict[str, Any]] = None - ) -> Message: - """ - Adds a message to the conversation. 
- - Args: - role: Message role - content: Message content - metadata: Optional message metadata - - Returns: - The created Message - """ - message = Message( - role=role, - content=content, - metadata=metadata or {} - ) - self.messages.append(message) - - # Compress if needed - if len(self.messages) > self.max_messages: - self._compress_history() - - return message - - def add_user_message( - self, - content: str, - metadata: Optional[Dict[str, Any]] = None - ) -> Message: - """Convenience method to add a user message""" - return self.add_message(MessageRole.USER, content, metadata) - - def add_assistant_message( - self, - content: str, - metadata: Optional[Dict[str, Any]] = None - ) -> Message: - """Convenience method to add an assistant message""" - return self.add_message(MessageRole.ASSISTANT, content, metadata) - - def add_system_message( - self, - content: str, - metadata: Optional[Dict[str, Any]] = None - ) -> Message: - """Convenience method to add a system message""" - return self.add_message(MessageRole.SYSTEM, content, metadata) - - def add_intent(self, intent: str) -> None: - """ - Records an intent in the history. - - Args: - intent: The detected intent - """ - self.intent_history.append(intent) - # Keep last 20 intents - if len(self.intent_history) > 20: - self.intent_history = self.intent_history[-20:] - - def set_entity(self, name: str, value: Any) -> None: - """ - Sets an entity value. - - Args: - name: Entity name - value: Entity value - """ - self.entities[name] = value - - def get_entity(self, name: str, default: Any = None) -> Any: - """ - Gets an entity value. - - Args: - name: Entity name - default: Default value if not found - - Returns: - Entity value or default - """ - return self.entities.get(name, default) - - def get_last_message(self, role: Optional[MessageRole] = None) -> Optional[Message]: - """ - Gets the last message, optionally filtered by role. 
- - Args: - role: Optional role filter - - Returns: - The last matching message or None - """ - if not self.messages: - return None - - if role is None: - return self.messages[-1] - - for msg in reversed(self.messages): - if msg.role == role: - return msg - - return None - - def get_messages_for_llm(self) -> List[Dict[str, str]]: - """ - Gets messages formatted for LLM API calls. - - Returns: - List of message dicts with role and content - """ - result = [] - - # Add system prompt first - if self.system_prompt: - result.append({ - "role": "system", - "content": self.system_prompt - }) - - # Add summary if we have one and history was compressed - if self.summary: - result.append({ - "role": "system", - "content": f"Previous conversation summary: {self.summary}" - }) - - # Add recent messages - for msg in self.messages: - result.append({ - "role": msg.role.value, - "content": msg.content - }) - - return result - - def _compress_history(self) -> None: - """ - Compresses older messages to save context window space. - - Keeps: - - System messages - - Last 20 messages - - Creates summary of compressed middle messages - """ - # Keep system messages - system_msgs = [m for m in self.messages if m.role == MessageRole.SYSTEM] - - # Keep last 20 messages - recent_msgs = self.messages[-20:] - - # Middle messages to summarize - middle_start = len(system_msgs) - middle_end = len(self.messages) - 20 - middle_msgs = self.messages[middle_start:middle_end] - - if middle_msgs: - # Create a basic summary (can be enhanced with LLM-based summarization) - self.summary = self._create_summary(middle_msgs) - - # Combine - self.messages = system_msgs + recent_msgs - - logger.debug( - f"Compressed conversation: {middle_end - middle_start} messages summarized" - ) - - def _create_summary(self, messages: List[Message]) -> str: - """ - Creates a summary of messages. - - This is a basic implementation - can be enhanced with LLM-based summarization. 
- - Args: - messages: Messages to summarize - - Returns: - Summary string - """ - # Count message types - user_count = sum(1 for m in messages if m.role == MessageRole.USER) - assistant_count = sum(1 for m in messages if m.role == MessageRole.ASSISTANT) - - # Extract key topics (simplified - could use NLP) - topics = set() - for msg in messages: - # Simple keyword extraction - words = msg.content.lower().split() - # Filter common words - keywords = [w for w in words if len(w) > 5][:3] - topics.update(keywords) - - topics_str = ", ".join(list(topics)[:5]) - - return ( - f"Earlier conversation: {user_count} user messages, " - f"{assistant_count} assistant responses. " - f"Topics discussed: {topics_str}" - ) - - def clear(self) -> None: - """Clears all context""" - self.messages.clear() - self.entities.clear() - self.intent_history.clear() - self.summary = None - - def to_dict(self) -> Dict[str, Any]: - """Serializes context to dict""" - return { - "messages": [m.to_dict() for m in self.messages], - "entities": self.entities, - "intent_history": self.intent_history, - "summary": self.summary, - "max_messages": self.max_messages, - "system_prompt": self.system_prompt, - "metadata": self.metadata - } - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "ConversationContext": - """Deserializes context from dict""" - ctx = cls( - messages=[Message.from_dict(m) for m in data.get("messages", [])], - entities=data.get("entities", {}), - intent_history=data.get("intent_history", []), - summary=data.get("summary"), - max_messages=data.get("max_messages", 50), - system_prompt=data.get("system_prompt"), - metadata=data.get("metadata", {}) - ) - return ctx - - class ContextManager: """ Manages conversation contexts for multiple sessions. 
diff --git a/agent-core/brain/context_models.py b/agent-core/brain/context_models.py new file mode 100644 index 0000000..1a38a86 --- /dev/null +++ b/agent-core/brain/context_models.py @@ -0,0 +1,307 @@ +""" +Context Models for Breakpilot Agents + +Data classes for conversation messages and context management. +""" + +from typing import Dict, Any, List, Optional +from dataclasses import dataclass, field +from datetime import datetime, timezone +from enum import Enum +import logging + +logger = logging.getLogger(__name__) + + +class MessageRole(Enum): + """Message roles in a conversation""" + SYSTEM = "system" + USER = "user" + ASSISTANT = "assistant" + TOOL = "tool" + + +@dataclass +class Message: + """Represents a message in a conversation""" + role: MessageRole + content: str + timestamp: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) + metadata: Dict[str, Any] = field(default_factory=dict) + + def to_dict(self) -> Dict[str, Any]: + return { + "role": self.role.value, + "content": self.content, + "timestamp": self.timestamp.isoformat(), + "metadata": self.metadata + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "Message": + return cls( + role=MessageRole(data["role"]), + content=data["content"], + timestamp=datetime.fromisoformat(data["timestamp"]) if "timestamp" in data else datetime.now(timezone.utc), + metadata=data.get("metadata", {}) + ) + + +@dataclass +class ConversationContext: + """ + Context for a running conversation. 
+ + Maintains: + - Message history with automatic compression + - Extracted entities + - Intent history + - Conversation summary + """ + messages: List[Message] = field(default_factory=list) + entities: Dict[str, Any] = field(default_factory=dict) + intent_history: List[str] = field(default_factory=list) + summary: Optional[str] = None + max_messages: int = 50 + system_prompt: Optional[str] = None + metadata: Dict[str, Any] = field(default_factory=dict) + + def add_message( + self, + role: MessageRole, + content: str, + metadata: Optional[Dict[str, Any]] = None + ) -> Message: + """ + Adds a message to the conversation. + + Args: + role: Message role + content: Message content + metadata: Optional message metadata + + Returns: + The created Message + """ + message = Message( + role=role, + content=content, + metadata=metadata or {} + ) + self.messages.append(message) + + # Compress if needed + if len(self.messages) > self.max_messages: + self._compress_history() + + return message + + def add_user_message( + self, + content: str, + metadata: Optional[Dict[str, Any]] = None + ) -> Message: + """Convenience method to add a user message""" + return self.add_message(MessageRole.USER, content, metadata) + + def add_assistant_message( + self, + content: str, + metadata: Optional[Dict[str, Any]] = None + ) -> Message: + """Convenience method to add an assistant message""" + return self.add_message(MessageRole.ASSISTANT, content, metadata) + + def add_system_message( + self, + content: str, + metadata: Optional[Dict[str, Any]] = None + ) -> Message: + """Convenience method to add a system message""" + return self.add_message(MessageRole.SYSTEM, content, metadata) + + def add_intent(self, intent: str) -> None: + """ + Records an intent in the history. 
+ + Args: + intent: The detected intent + """ + self.intent_history.append(intent) + # Keep last 20 intents + if len(self.intent_history) > 20: + self.intent_history = self.intent_history[-20:] + + def set_entity(self, name: str, value: Any) -> None: + """ + Sets an entity value. + + Args: + name: Entity name + value: Entity value + """ + self.entities[name] = value + + def get_entity(self, name: str, default: Any = None) -> Any: + """ + Gets an entity value. + + Args: + name: Entity name + default: Default value if not found + + Returns: + Entity value or default + """ + return self.entities.get(name, default) + + def get_last_message(self, role: Optional[MessageRole] = None) -> Optional[Message]: + """ + Gets the last message, optionally filtered by role. + + Args: + role: Optional role filter + + Returns: + The last matching message or None + """ + if not self.messages: + return None + + if role is None: + return self.messages[-1] + + for msg in reversed(self.messages): + if msg.role == role: + return msg + + return None + + def get_messages_for_llm(self) -> List[Dict[str, str]]: + """ + Gets messages formatted for LLM API calls. + + Returns: + List of message dicts with role and content + """ + result = [] + + # Add system prompt first + if self.system_prompt: + result.append({ + "role": "system", + "content": self.system_prompt + }) + + # Add summary if we have one and history was compressed + if self.summary: + result.append({ + "role": "system", + "content": f"Previous conversation summary: {self.summary}" + }) + + # Add recent messages + for msg in self.messages: + result.append({ + "role": msg.role.value, + "content": msg.content + }) + + return result + + def _compress_history(self) -> None: + """ + Compresses older messages to save context window space. 
+ + Keeps: + - System messages + - Last 20 messages + - Creates summary of compressed middle messages + """ + # Keep system messages + system_msgs = [m for m in self.messages if m.role == MessageRole.SYSTEM] + + # Keep last 20 messages + recent_msgs = self.messages[-20:] + + # Middle messages to summarize + middle_start = len(system_msgs) + middle_end = len(self.messages) - 20 + middle_msgs = self.messages[middle_start:middle_end] + + if middle_msgs: + # Create a basic summary (can be enhanced with LLM-based summarization) + self.summary = self._create_summary(middle_msgs) + + # Combine + self.messages = system_msgs + recent_msgs + + logger.debug( + f"Compressed conversation: {middle_end - middle_start} messages summarized" + ) + + def _create_summary(self, messages: List[Message]) -> str: + """ + Creates a summary of messages. + + This is a basic implementation - can be enhanced with LLM-based summarization. + + Args: + messages: Messages to summarize + + Returns: + Summary string + """ + # Count message types + user_count = sum(1 for m in messages if m.role == MessageRole.USER) + assistant_count = sum(1 for m in messages if m.role == MessageRole.ASSISTANT) + + # Extract key topics (simplified - could use NLP) + topics = set() + for msg in messages: + # Simple keyword extraction + words = msg.content.lower().split() + # Filter common words + keywords = [w for w in words if len(w) > 5][:3] + topics.update(keywords) + + topics_str = ", ".join(list(topics)[:5]) + + return ( + f"Earlier conversation: {user_count} user messages, " + f"{assistant_count} assistant responses. 
" + f"Topics discussed: {topics_str}" + ) + + def clear(self) -> None: + """Clears all context""" + self.messages.clear() + self.entities.clear() + self.intent_history.clear() + self.summary = None + + def to_dict(self) -> Dict[str, Any]: + """Serializes context to dict""" + return { + "messages": [m.to_dict() for m in self.messages], + "entities": self.entities, + "intent_history": self.intent_history, + "summary": self.summary, + "max_messages": self.max_messages, + "system_prompt": self.system_prompt, + "metadata": self.metadata + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "ConversationContext": + """Deserializes context from dict""" + ctx = cls( + messages=[Message.from_dict(m) for m in data.get("messages", [])], + entities=data.get("entities", {}), + intent_history=data.get("intent_history", []), + summary=data.get("summary"), + max_messages=data.get("max_messages", 50), + system_prompt=data.get("system_prompt"), + metadata=data.get("metadata", {}) + ) + return ctx diff --git a/agent-core/brain/knowledge_graph.py b/agent-core/brain/knowledge_graph.py index 6696b5a..6376b14 100644 --- a/agent-core/brain/knowledge_graph.py +++ b/agent-core/brain/knowledge_graph.py @@ -9,109 +9,20 @@ Provides entity and relationship management: """ from typing import Dict, Any, List, Optional, Set, Tuple -from dataclasses import dataclass, field from datetime import datetime, timezone -from enum import Enum import json import logging +from agent_core.brain.knowledge_models import ( + EntityType, + RelationshipType, + Entity, + Relationship, +) + logger = logging.getLogger(__name__) -class EntityType(Enum): - """Types of entities in the knowledge graph""" - STUDENT = "student" - TEACHER = "teacher" - CLASS = "class" - SUBJECT = "subject" - ASSIGNMENT = "assignment" - EXAM = "exam" - TOPIC = "topic" - CONCEPT = "concept" - RESOURCE = "resource" - CUSTOM = "custom" - - -class RelationshipType(Enum): - """Types of relationships between entities""" - BELONGS_TO = 
"belongs_to" # Student belongs to class - TEACHES = "teaches" # Teacher teaches subject - ASSIGNED_TO = "assigned_to" # Assignment assigned to student - COVERS = "covers" # Exam covers topic - REQUIRES = "requires" # Topic requires concept - RELATED_TO = "related_to" # General relationship - PARENT_OF = "parent_of" # Hierarchical relationship - CREATED_BY = "created_by" # Creator relationship - GRADED_BY = "graded_by" # Grading relationship - - -@dataclass -class Entity: - """Represents an entity in the knowledge graph""" - id: str - entity_type: EntityType - name: str - properties: Dict[str, Any] = field(default_factory=dict) - created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) - updated_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) - - def to_dict(self) -> Dict[str, Any]: - return { - "id": self.id, - "entity_type": self.entity_type.value, - "name": self.name, - "properties": self.properties, - "created_at": self.created_at.isoformat(), - "updated_at": self.updated_at.isoformat() - } - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "Entity": - return cls( - id=data["id"], - entity_type=EntityType(data["entity_type"]), - name=data["name"], - properties=data.get("properties", {}), - created_at=datetime.fromisoformat(data["created_at"]), - updated_at=datetime.fromisoformat(data["updated_at"]) - ) - - -@dataclass -class Relationship: - """Represents a relationship between two entities""" - id: str - source_id: str - target_id: str - relationship_type: RelationshipType - properties: Dict[str, Any] = field(default_factory=dict) - weight: float = 1.0 - created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) - - def to_dict(self) -> Dict[str, Any]: - return { - "id": self.id, - "source_id": self.source_id, - "target_id": self.target_id, - "relationship_type": self.relationship_type.value, - "properties": self.properties, - "weight": self.weight, - "created_at": 
self.created_at.isoformat() - } - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "Relationship": - return cls( - id=data["id"], - source_id=data["source_id"], - target_id=data["target_id"], - relationship_type=RelationshipType(data["relationship_type"]), - properties=data.get("properties", {}), - weight=data.get("weight", 1.0), - created_at=datetime.fromisoformat(data["created_at"]) - ) - - class KnowledgeGraph: """ Knowledge graph for managing entity relationships. diff --git a/agent-core/brain/knowledge_models.py b/agent-core/brain/knowledge_models.py new file mode 100644 index 0000000..063ef7f --- /dev/null +++ b/agent-core/brain/knowledge_models.py @@ -0,0 +1,104 @@ +""" +Knowledge Graph Models for Breakpilot Agents + +Entity and relationship data classes, plus type enumerations. +""" + +from typing import Dict, Any +from dataclasses import dataclass, field +from datetime import datetime, timezone +from enum import Enum + + +class EntityType(Enum): + """Types of entities in the knowledge graph""" + STUDENT = "student" + TEACHER = "teacher" + CLASS = "class" + SUBJECT = "subject" + ASSIGNMENT = "assignment" + EXAM = "exam" + TOPIC = "topic" + CONCEPT = "concept" + RESOURCE = "resource" + CUSTOM = "custom" + + +class RelationshipType(Enum): + """Types of relationships between entities""" + BELONGS_TO = "belongs_to" # Student belongs to class + TEACHES = "teaches" # Teacher teaches subject + ASSIGNED_TO = "assigned_to" # Assignment assigned to student + COVERS = "covers" # Exam covers topic + REQUIRES = "requires" # Topic requires concept + RELATED_TO = "related_to" # General relationship + PARENT_OF = "parent_of" # Hierarchical relationship + CREATED_BY = "created_by" # Creator relationship + GRADED_BY = "graded_by" # Grading relationship + + +@dataclass +class Entity: + """Represents an entity in the knowledge graph""" + id: str + entity_type: EntityType + name: str + properties: Dict[str, Any] = field(default_factory=dict) + created_at: datetime = 
field(default_factory=lambda: datetime.now(timezone.utc)) + updated_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) + + def to_dict(self) -> Dict[str, Any]: + return { + "id": self.id, + "entity_type": self.entity_type.value, + "name": self.name, + "properties": self.properties, + "created_at": self.created_at.isoformat(), + "updated_at": self.updated_at.isoformat() + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "Entity": + return cls( + id=data["id"], + entity_type=EntityType(data["entity_type"]), + name=data["name"], + properties=data.get("properties", {}), + created_at=datetime.fromisoformat(data["created_at"]), + updated_at=datetime.fromisoformat(data["updated_at"]) + ) + + +@dataclass +class Relationship: + """Represents a relationship between two entities""" + id: str + source_id: str + target_id: str + relationship_type: RelationshipType + properties: Dict[str, Any] = field(default_factory=dict) + weight: float = 1.0 + created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) + + def to_dict(self) -> Dict[str, Any]: + return { + "id": self.id, + "source_id": self.source_id, + "target_id": self.target_id, + "relationship_type": self.relationship_type.value, + "properties": self.properties, + "weight": self.weight, + "created_at": self.created_at.isoformat() + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "Relationship": + return cls( + id=data["id"], + source_id=data["source_id"], + target_id=data["target_id"], + relationship_type=RelationshipType(data["relationship_type"]), + properties=data.get("properties", {}), + weight=data.get("weight", 1.0), + created_at=datetime.fromisoformat(data["created_at"]) + ) diff --git a/agent-core/brain/memory_models.py b/agent-core/brain/memory_models.py new file mode 100644 index 0000000..6283df7 --- /dev/null +++ b/agent-core/brain/memory_models.py @@ -0,0 +1,53 @@ +""" +Memory Models for Breakpilot Agents + +Data classes for memory 
items used by MemoryStore. +""" + +from typing import Dict, Any, Optional +from datetime import datetime, timezone +from dataclasses import dataclass, field + + +@dataclass +class Memory: + """Represents a stored memory item""" + key: str + value: Any + agent_id: str + created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) + expires_at: Optional[datetime] = None + access_count: int = 0 + last_accessed: Optional[datetime] = None + metadata: Dict[str, Any] = field(default_factory=dict) + + def to_dict(self) -> Dict[str, Any]: + return { + "key": self.key, + "value": self.value, + "agent_id": self.agent_id, + "created_at": self.created_at.isoformat(), + "expires_at": self.expires_at.isoformat() if self.expires_at else None, + "access_count": self.access_count, + "last_accessed": self.last_accessed.isoformat() if self.last_accessed else None, + "metadata": self.metadata + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "Memory": + return cls( + key=data["key"], + value=data["value"], + agent_id=data["agent_id"], + created_at=datetime.fromisoformat(data["created_at"]), + expires_at=datetime.fromisoformat(data["expires_at"]) if data.get("expires_at") else None, + access_count=data.get("access_count", 0), + last_accessed=datetime.fromisoformat(data["last_accessed"]) if data.get("last_accessed") else None, + metadata=data.get("metadata", {}) + ) + + def is_expired(self) -> bool: + """Check if the memory has expired""" + if not self.expires_at: + return False + return datetime.now(timezone.utc) > self.expires_at diff --git a/agent-core/brain/memory_store.py b/agent-core/brain/memory_store.py index f41afbf..ec351de 100644 --- a/agent-core/brain/memory_store.py +++ b/agent-core/brain/memory_store.py @@ -1,92 +1,24 @@ """ Memory Store for Breakpilot Agents -Provides long-term memory with: -- TTL-based expiration -- Access count tracking -- Pattern-based search -- Hybrid Valkey + PostgreSQL persistence +Hybrid Valkey + PostgreSQL 
persistence with TTL, access tracking, and pattern search. """ from typing import List, Dict, Any, Optional from datetime import datetime, timezone, timedelta -from dataclasses import dataclass, field import json import logging import hashlib +from agent_core.brain.memory_models import Memory + logger = logging.getLogger(__name__) -@dataclass -class Memory: - """Represents a stored memory item""" - key: str - value: Any - agent_id: str - created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) - expires_at: Optional[datetime] = None - access_count: int = 0 - last_accessed: Optional[datetime] = None - metadata: Dict[str, Any] = field(default_factory=dict) - - def to_dict(self) -> Dict[str, Any]: - return { - "key": self.key, - "value": self.value, - "agent_id": self.agent_id, - "created_at": self.created_at.isoformat(), - "expires_at": self.expires_at.isoformat() if self.expires_at else None, - "access_count": self.access_count, - "last_accessed": self.last_accessed.isoformat() if self.last_accessed else None, - "metadata": self.metadata - } - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "Memory": - return cls( - key=data["key"], - value=data["value"], - agent_id=data["agent_id"], - created_at=datetime.fromisoformat(data["created_at"]), - expires_at=datetime.fromisoformat(data["expires_at"]) if data.get("expires_at") else None, - access_count=data.get("access_count", 0), - last_accessed=datetime.fromisoformat(data["last_accessed"]) if data.get("last_accessed") else None, - metadata=data.get("metadata", {}) - ) - - def is_expired(self) -> bool: - """Check if the memory has expired""" - if not self.expires_at: - return False - return datetime.now(timezone.utc) > self.expires_at - - class MemoryStore: - """ - Long-term memory store for agents. 
+ """Long-term memory store with TTL, access tracking, and hybrid persistence.""" - Stores facts, decisions, and learning progress with: - - TTL-based expiration - - Access tracking for importance scoring - - Pattern-based retrieval - - Hybrid persistence (Valkey for fast access, PostgreSQL for durability) - """ - - def __init__( - self, - redis_client=None, - db_pool=None, - namespace: str = "breakpilot" - ): - """ - Initialize the memory store. - - Args: - redis_client: Async Redis/Valkey client - db_pool: Async PostgreSQL connection pool - namespace: Key namespace for isolation - """ + def __init__(self, redis_client=None, db_pool=None, namespace: str = "breakpilot"): self.redis = redis_client self.db_pool = db_pool self.namespace = namespace @@ -103,26 +35,10 @@ class MemoryStore: return key async def remember( - self, - key: str, - value: Any, - agent_id: str, - ttl_days: int = 30, - metadata: Optional[Dict[str, Any]] = None + self, key: str, value: Any, agent_id: str, + ttl_days: int = 30, metadata: Optional[Dict[str, Any]] = None ) -> Memory: - """ - Stores a memory. - - Args: - key: Unique key for the memory - value: Value to store (must be JSON-serializable) - agent_id: ID of the agent storing the memory - ttl_days: Time to live in days (0 = no expiration) - metadata: Optional additional metadata - - Returns: - The created Memory object - """ + """Stores a memory with optional TTL and metadata.""" expires_at = None if ttl_days > 0: expires_at = datetime.now(timezone.utc) + timedelta(days=ttl_days) @@ -143,32 +59,14 @@ class MemoryStore: return memory async def recall(self, key: str) -> Optional[Any]: - """ - Retrieves a memory value by key. 
- - Args: - key: The memory key - - Returns: - The stored value or None if not found/expired - """ + """Retrieves a memory value by key, or None if not found/expired.""" memory = await self.get_memory(key) if memory: return memory.value return None async def get_memory(self, key: str) -> Optional[Memory]: - """ - Retrieves a full Memory object by key. - - Updates access count and last_accessed timestamp. - - Args: - key: The memory key - - Returns: - Memory object or None if not found/expired - """ + """Retrieves a full Memory object by key, updating access count.""" # Check local cache if key in self._local_cache: memory = self._local_cache[key] diff --git a/agent-core/orchestrator/__init__.py b/agent-core/orchestrator/__init__.py index 8a7784a..b6289c3 100644 --- a/agent-core/orchestrator/__init__.py +++ b/agent-core/orchestrator/__init__.py @@ -12,11 +12,13 @@ from agent_core.orchestrator.message_bus import ( AgentMessage, MessagePriority, ) -from agent_core.orchestrator.supervisor import ( - AgentSupervisor, - AgentInfo, +from agent_core.orchestrator.supervisor_models import ( AgentStatus, + RestartPolicy, + AgentInfo, + AgentFactory, ) +from agent_core.orchestrator.supervisor import AgentSupervisor from agent_core.orchestrator.task_router import ( TaskRouter, RoutingResult, @@ -30,6 +32,8 @@ __all__ = [ "AgentSupervisor", "AgentInfo", "AgentStatus", + "RestartPolicy", + "AgentFactory", "TaskRouter", "RoutingResult", "RoutingStrategy", diff --git a/agent-core/orchestrator/supervisor.py b/agent-core/orchestrator/supervisor.py index 72ecef4..2d98008 100644 --- a/agent-core/orchestrator/supervisor.py +++ b/agent-core/orchestrator/supervisor.py @@ -1,17 +1,11 @@ """ Agent Supervisor for Breakpilot -Provides: -- Agent lifecycle management -- Health monitoring -- Restart policies -- Load balancing +Agent lifecycle management, health monitoring, restart policies, load balancing. 
""" -from typing import Dict, Optional, Callable, Awaitable, List, Any -from dataclasses import dataclass, field +from typing import Dict, Optional, List, Any from datetime import datetime, timezone, timedelta -from enum import Enum import asyncio import logging @@ -21,91 +15,24 @@ from agent_core.orchestrator.message_bus import ( AgentMessage, MessagePriority, ) +from agent_core.orchestrator.supervisor_models import ( + AgentStatus, + RestartPolicy, + AgentInfo, + AgentFactory, +) logger = logging.getLogger(__name__) -class AgentStatus(Enum): - """Agent lifecycle states""" - INITIALIZING = "initializing" - STARTING = "starting" - RUNNING = "running" - PAUSED = "paused" - STOPPING = "stopping" - STOPPED = "stopped" - ERROR = "error" - RESTARTING = "restarting" - - -class RestartPolicy(Enum): - """Agent restart policies""" - NEVER = "never" - ON_FAILURE = "on_failure" - ALWAYS = "always" - - -@dataclass -class AgentInfo: - """Information about a registered agent""" - agent_id: str - agent_type: str - status: AgentStatus = AgentStatus.INITIALIZING - current_task: Optional[str] = None - started_at: Optional[datetime] = None - last_activity: Optional[datetime] = None - error_count: int = 0 - restart_count: int = 0 - max_restarts: int = 3 - restart_policy: RestartPolicy = RestartPolicy.ON_FAILURE - metadata: Dict[str, Any] = field(default_factory=dict) - capacity: int = 10 # Max concurrent tasks - current_load: int = 0 - - def is_healthy(self) -> bool: - """Check if agent is healthy""" - return self.status == AgentStatus.RUNNING and self.error_count < 3 - - def is_available(self) -> bool: - """Check if agent can accept new tasks""" - return ( - self.status == AgentStatus.RUNNING and - self.current_load < self.capacity - ) - - def utilization(self) -> float: - """Returns agent utilization (0-1)""" - return self.current_load / max(self.capacity, 1) - - -AgentFactory = Callable[[str], Awaitable[Any]] - - class AgentSupervisor: - """ - Supervises and coordinates all agents. 
- - Responsibilities: - - Agent registration and lifecycle - - Health monitoring via heartbeat - - Restart policies - - Load balancing - - Alert escalation - """ + """Supervises agents: lifecycle, health monitoring, restart policies, load balancing.""" def __init__( - self, - message_bus: MessageBus, + self, message_bus: MessageBus, heartbeat_monitor: Optional[HeartbeatMonitor] = None, check_interval_seconds: int = 10 ): - """ - Initialize the supervisor. - - Args: - message_bus: Message bus for inter-agent communication - heartbeat_monitor: Heartbeat monitor for liveness checks - check_interval_seconds: How often to run health checks - """ self.bus = message_bus self.heartbeat = heartbeat_monitor or HeartbeatMonitor() self.check_interval = check_interval_seconds diff --git a/agent-core/orchestrator/supervisor_models.py b/agent-core/orchestrator/supervisor_models.py new file mode 100644 index 0000000..68048f1 --- /dev/null +++ b/agent-core/orchestrator/supervisor_models.py @@ -0,0 +1,65 @@ +""" +Supervisor Models for Breakpilot Agents + +Data classes and enumerations for agent lifecycle management. 
+""" + +from typing import Dict, Optional, Any, Callable, Awaitable +from dataclasses import dataclass, field +from datetime import datetime +from enum import Enum + + +class AgentStatus(Enum): + """Agent lifecycle states""" + INITIALIZING = "initializing" + STARTING = "starting" + RUNNING = "running" + PAUSED = "paused" + STOPPING = "stopping" + STOPPED = "stopped" + ERROR = "error" + RESTARTING = "restarting" + + +class RestartPolicy(Enum): + """Agent restart policies""" + NEVER = "never" + ON_FAILURE = "on_failure" + ALWAYS = "always" + + +@dataclass +class AgentInfo: + """Information about a registered agent""" + agent_id: str + agent_type: str + status: AgentStatus = AgentStatus.INITIALIZING + current_task: Optional[str] = None + started_at: Optional[datetime] = None + last_activity: Optional[datetime] = None + error_count: int = 0 + restart_count: int = 0 + max_restarts: int = 3 + restart_policy: RestartPolicy = RestartPolicy.ON_FAILURE + metadata: Dict[str, Any] = field(default_factory=dict) + capacity: int = 10 # Max concurrent tasks + current_load: int = 0 + + def is_healthy(self) -> bool: + """Check if agent is healthy""" + return self.status == AgentStatus.RUNNING and self.error_count < 3 + + def is_available(self) -> bool: + """Check if agent can accept new tasks""" + return ( + self.status == AgentStatus.RUNNING and + self.current_load < self.capacity + ) + + def utilization(self) -> float: + """Returns agent utilization (0-1)""" + return self.current_load / max(self.capacity, 1) + + +AgentFactory = Callable[[str], Awaitable[Any]] diff --git a/agent-core/sessions/__init__.py b/agent-core/sessions/__init__.py index 3d26c3b..bbe1132 100644 --- a/agent-core/sessions/__init__.py +++ b/agent-core/sessions/__init__.py @@ -8,11 +8,12 @@ Provides: - SessionState: Session state enumeration """ -from agent_core.sessions.session_manager import ( +from agent_core.sessions.session_models import ( AgentSession, - SessionManager, SessionState, + SessionCheckpoint, ) 
+from agent_core.sessions.session_manager import SessionManager from agent_core.sessions.heartbeat import HeartbeatMonitor from agent_core.sessions.checkpoint import CheckpointManager @@ -20,6 +21,7 @@ __all__ = [ "AgentSession", "SessionManager", "SessionState", + "SessionCheckpoint", "HeartbeatMonitor", "CheckpointManager", ] diff --git a/agent-core/sessions/session_manager.py b/agent-core/sessions/session_manager.py index 0e12e74..2bc6b74 100644 --- a/agent-core/sessions/session_manager.py +++ b/agent-core/sessions/session_manager.py @@ -2,189 +2,25 @@ Session Management for Breakpilot Agents Provides session lifecycle management with: -- State tracking (ACTIVE, PAUSED, COMPLETED, FAILED) -- Checkpoint-based recovery -- Heartbeat integration - Hybrid Valkey + PostgreSQL persistence +- Session CRUD operations +- Stale session cleanup """ -from dataclasses import dataclass, field from datetime import datetime, timezone, timedelta from typing import Dict, Any, Optional, List -from enum import Enum -import uuid import json import logging +from agent_core.sessions.session_models import ( + SessionState, + SessionCheckpoint, + AgentSession, +) + logger = logging.getLogger(__name__) -class SessionState(Enum): - """Agent session states""" - ACTIVE = "active" - PAUSED = "paused" - COMPLETED = "completed" - FAILED = "failed" - - -@dataclass -class SessionCheckpoint: - """Represents a checkpoint in an agent session""" - name: str - timestamp: datetime - data: Dict[str, Any] - - def to_dict(self) -> Dict[str, Any]: - return { - "name": self.name, - "timestamp": self.timestamp.isoformat(), - "data": self.data - } - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "SessionCheckpoint": - return cls( - name=data["name"], - timestamp=datetime.fromisoformat(data["timestamp"]), - data=data["data"] - ) - - -@dataclass -class AgentSession: - """ - Represents an active agent session. 
- - Attributes: - session_id: Unique session identifier - agent_type: Type of agent (tutor, grader, quality-judge, alert, orchestrator) - user_id: Associated user ID - state: Current session state - created_at: Session creation timestamp - last_heartbeat: Last heartbeat timestamp - context: Session context data - checkpoints: List of session checkpoints for recovery - """ - session_id: str = field(default_factory=lambda: str(uuid.uuid4())) - agent_type: str = "" - user_id: str = "" - state: SessionState = SessionState.ACTIVE - created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) - last_heartbeat: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) - context: Dict[str, Any] = field(default_factory=dict) - checkpoints: List[SessionCheckpoint] = field(default_factory=list) - metadata: Dict[str, Any] = field(default_factory=dict) - - def checkpoint(self, name: str, data: Dict[str, Any]) -> SessionCheckpoint: - """ - Creates a checkpoint for recovery. 
- - Args: - name: Checkpoint name (e.g., "task_received", "processing_complete") - data: Checkpoint data to store - - Returns: - The created checkpoint - """ - checkpoint = SessionCheckpoint( - name=name, - timestamp=datetime.now(timezone.utc), - data=data - ) - self.checkpoints.append(checkpoint) - logger.debug(f"Session {self.session_id}: Checkpoint '{name}' created") - return checkpoint - - def heartbeat(self) -> None: - """Updates the heartbeat timestamp""" - self.last_heartbeat = datetime.now(timezone.utc) - - def pause(self) -> None: - """Pauses the session""" - self.state = SessionState.PAUSED - self.checkpoint("session_paused", {"previous_state": "active"}) - - def resume(self) -> None: - """Resumes a paused session""" - if self.state == SessionState.PAUSED: - self.state = SessionState.ACTIVE - self.heartbeat() - self.checkpoint("session_resumed", {}) - - def complete(self, result: Optional[Dict[str, Any]] = None) -> None: - """Marks the session as completed""" - self.state = SessionState.COMPLETED - self.checkpoint("session_completed", {"result": result or {}}) - - def fail(self, error: str, error_details: Optional[Dict[str, Any]] = None) -> None: - """Marks the session as failed""" - self.state = SessionState.FAILED - self.checkpoint("session_failed", { - "error": error, - "details": error_details or {} - }) - - def get_last_checkpoint(self, name: Optional[str] = None) -> Optional[SessionCheckpoint]: - """ - Gets the last checkpoint, optionally filtered by name. 
- - Args: - name: Optional checkpoint name to filter by - - Returns: - The last matching checkpoint or None - """ - if not self.checkpoints: - return None - - if name: - matching = [cp for cp in self.checkpoints if cp.name == name] - return matching[-1] if matching else None - - return self.checkpoints[-1] - - def get_duration(self) -> timedelta: - """Returns the session duration""" - end_time = datetime.now(timezone.utc) - if self.state in (SessionState.COMPLETED, SessionState.FAILED): - last_cp = self.get_last_checkpoint() - if last_cp: - end_time = last_cp.timestamp - return end_time - self.created_at - - def to_dict(self) -> Dict[str, Any]: - """Serializes the session to a dictionary""" - return { - "session_id": self.session_id, - "agent_type": self.agent_type, - "user_id": self.user_id, - "state": self.state.value, - "created_at": self.created_at.isoformat(), - "last_heartbeat": self.last_heartbeat.isoformat(), - "context": self.context, - "checkpoints": [cp.to_dict() for cp in self.checkpoints], - "metadata": self.metadata - } - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "AgentSession": - """Deserializes a session from a dictionary""" - return cls( - session_id=data["session_id"], - agent_type=data["agent_type"], - user_id=data["user_id"], - state=SessionState(data["state"]), - created_at=datetime.fromisoformat(data["created_at"]), - last_heartbeat=datetime.fromisoformat(data["last_heartbeat"]), - context=data.get("context", {}), - checkpoints=[ - SessionCheckpoint.from_dict(cp) - for cp in data.get("checkpoints", []) - ], - metadata=data.get("metadata", {}) - ) - - class SessionManager: """ Manages agent sessions with hybrid Valkey + PostgreSQL persistence. 
@@ -303,7 +139,6 @@ class SessionManager: """ session.heartbeat() self._local_cache[session.session_id] = session - self._local_cache[session.session_id] = session await self._persist_session(session) async def delete_session(self, session_id: str) -> bool: diff --git a/agent-core/sessions/session_models.py b/agent-core/sessions/session_models.py new file mode 100644 index 0000000..514b4c9 --- /dev/null +++ b/agent-core/sessions/session_models.py @@ -0,0 +1,180 @@ +""" +Session Models for Breakpilot Agents + +Data classes for agent sessions, checkpoints, and state tracking. +""" + +from dataclasses import dataclass, field +from datetime import datetime, timezone, timedelta +from typing import Dict, Any, Optional, List +from enum import Enum +import uuid +import logging + +logger = logging.getLogger(__name__) + + +class SessionState(Enum): + """Agent session states""" + ACTIVE = "active" + PAUSED = "paused" + COMPLETED = "completed" + FAILED = "failed" + + +@dataclass +class SessionCheckpoint: + """Represents a checkpoint in an agent session""" + name: str + timestamp: datetime + data: Dict[str, Any] + + def to_dict(self) -> Dict[str, Any]: + return { + "name": self.name, + "timestamp": self.timestamp.isoformat(), + "data": self.data + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "SessionCheckpoint": + return cls( + name=data["name"], + timestamp=datetime.fromisoformat(data["timestamp"]), + data=data["data"] + ) + + +@dataclass +class AgentSession: + """ + Represents an active agent session. 
+ + Attributes: + session_id: Unique session identifier + agent_type: Type of agent (tutor, grader, quality-judge, alert, orchestrator) + user_id: Associated user ID + state: Current session state + created_at: Session creation timestamp + last_heartbeat: Last heartbeat timestamp + context: Session context data + checkpoints: List of session checkpoints for recovery + """ + session_id: str = field(default_factory=lambda: str(uuid.uuid4())) + agent_type: str = "" + user_id: str = "" + state: SessionState = SessionState.ACTIVE + created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) + last_heartbeat: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) + context: Dict[str, Any] = field(default_factory=dict) + checkpoints: List[SessionCheckpoint] = field(default_factory=list) + metadata: Dict[str, Any] = field(default_factory=dict) + + def checkpoint(self, name: str, data: Dict[str, Any]) -> SessionCheckpoint: + """ + Creates a checkpoint for recovery. 
+ + Args: + name: Checkpoint name (e.g., "task_received", "processing_complete") + data: Checkpoint data to store + + Returns: + The created checkpoint + """ + checkpoint = SessionCheckpoint( + name=name, + timestamp=datetime.now(timezone.utc), + data=data + ) + self.checkpoints.append(checkpoint) + logger.debug(f"Session {self.session_id}: Checkpoint '{name}' created") + return checkpoint + + def heartbeat(self) -> None: + """Updates the heartbeat timestamp""" + self.last_heartbeat = datetime.now(timezone.utc) + + def pause(self) -> None: + """Pauses the session""" + self.state = SessionState.PAUSED + self.checkpoint("session_paused", {"previous_state": "active"}) + + def resume(self) -> None: + """Resumes a paused session""" + if self.state == SessionState.PAUSED: + self.state = SessionState.ACTIVE + self.heartbeat() + self.checkpoint("session_resumed", {}) + + def complete(self, result: Optional[Dict[str, Any]] = None) -> None: + """Marks the session as completed""" + self.state = SessionState.COMPLETED + self.checkpoint("session_completed", {"result": result or {}}) + + def fail(self, error: str, error_details: Optional[Dict[str, Any]] = None) -> None: + """Marks the session as failed""" + self.state = SessionState.FAILED + self.checkpoint("session_failed", { + "error": error, + "details": error_details or {} + }) + + def get_last_checkpoint(self, name: Optional[str] = None) -> Optional[SessionCheckpoint]: + """ + Gets the last checkpoint, optionally filtered by name. 
+ + Args: + name: Optional checkpoint name to filter by + + Returns: + The last matching checkpoint or None + """ + if not self.checkpoints: + return None + + if name: + matching = [cp for cp in self.checkpoints if cp.name == name] + return matching[-1] if matching else None + + return self.checkpoints[-1] + + def get_duration(self) -> timedelta: + """Returns the session duration""" + end_time = datetime.now(timezone.utc) + if self.state in (SessionState.COMPLETED, SessionState.FAILED): + last_cp = self.get_last_checkpoint() + if last_cp: + end_time = last_cp.timestamp + return end_time - self.created_at + + def to_dict(self) -> Dict[str, Any]: + """Serializes the session to a dictionary""" + return { + "session_id": self.session_id, + "agent_type": self.agent_type, + "user_id": self.user_id, + "state": self.state.value, + "created_at": self.created_at.isoformat(), + "last_heartbeat": self.last_heartbeat.isoformat(), + "context": self.context, + "checkpoints": [cp.to_dict() for cp in self.checkpoints], + "metadata": self.metadata + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "AgentSession": + """Deserializes a session from a dictionary""" + return cls( + session_id=data["session_id"], + agent_type=data["agent_type"], + user_id=data["user_id"], + state=SessionState(data["state"]), + created_at=datetime.fromisoformat(data["created_at"]), + last_heartbeat=datetime.fromisoformat(data["last_heartbeat"]), + context=data.get("context", {}), + checkpoints=[ + SessionCheckpoint.from_dict(cp) + for cp in data.get("checkpoints", []) + ], + metadata=data.get("metadata", {}) + ) diff --git a/backend-lehrer/ai_processor/export/print_templates.py b/backend-lehrer/ai_processor/export/print_templates.py new file mode 100644 index 0000000..b3df69a --- /dev/null +++ b/backend-lehrer/ai_processor/export/print_templates.py @@ -0,0 +1,244 @@ +""" +AI Processor - HTML Templates for Print Versions + +Contains HTML/CSS header templates for Q&A, Cloze, and Multiple 
Choice print output. +""" + + +def get_qa_html_header(title: str) -> str: + """Get HTML header for Q&A print version.""" + return f""" + + + +{title} - Fragen + + + +""" + + +def get_cloze_html_header(title: str) -> str: + """Get HTML header for cloze print version.""" + return f""" + + + +{title} - Lueckentext + + + +""" + + +def get_mc_html_header(title: str) -> str: + """Get HTML header for MC print version.""" + return f""" + + + +{title} - Multiple Choice + + + +""" diff --git a/backend-lehrer/ai_processor/export/print_versions.py b/backend-lehrer/ai_processor/export/print_versions.py index 1619a87..5f06009 100644 --- a/backend-lehrer/ai_processor/export/print_versions.py +++ b/backend-lehrer/ai_processor/export/print_versions.py @@ -10,6 +10,7 @@ import logging import random from ..config import BEREINIGT_DIR +from .print_templates import get_qa_html_header, get_cloze_html_header, get_mc_html_header logger = logging.getLogger(__name__) @@ -37,7 +38,7 @@ def generate_print_version_qa(qa_path: Path, include_answers: bool = False) -> P grade = metadata.get("grade_level", "") html_parts = [] - html_parts.append(_get_qa_html_header(title)) + html_parts.append(get_qa_html_header(title)) # Header version_text = "Loesungsblatt" if include_answers else "Fragenblatt" @@ -106,7 +107,7 @@ def generate_print_version_cloze(cloze_path: Path, include_answers: bool = False total_gaps = metadata.get("total_gaps", 0) html_parts = [] - html_parts.append(_get_cloze_html_header(title)) + html_parts.append(get_cloze_html_header(title)) # Header version_text = "Loesungsblatt" if include_answers else "Lueckentext" @@ -200,7 +201,7 @@ def generate_print_version_mc(mc_path: Path, include_answers: bool = False) -> s grade = metadata.get("grade_level", "") html_parts = [] - html_parts.append(_get_mc_html_header(title)) + html_parts.append(get_mc_html_header(title)) # Header version_text = "Loesungsblatt" if include_answers else "Multiple Choice Test" @@ -267,242 +268,3 @@ def 
generate_print_version_mc(mc_path: Path, include_answers: bool = False) -> s html_parts.append("") return "\n".join(html_parts) - - -def _get_qa_html_header(title: str) -> str: - """Get HTML header for Q&A print version.""" - return f""" - - - -{title} - Fragen - - - -""" - - -def _get_cloze_html_header(title: str) -> str: - """Get HTML header for cloze print version.""" - return f""" - - - -{title} - Lueckentext - - - -""" - - -def _get_mc_html_header(title: str) -> str: - """Get HTML header for MC print version.""" - return f""" - - - -{title} - Multiple Choice - - - -""" diff --git a/backend-lehrer/alerts_agent/api/digests.py b/backend-lehrer/alerts_agent/api/digests.py index 5d1d762..31d7ca7 100644 --- a/backend-lehrer/alerts_agent/api/digests.py +++ b/backend-lehrer/alerts_agent/api/digests.py @@ -9,13 +9,10 @@ Endpoints: - POST /digests/{id}/send-email - Digest per E-Mail versenden """ -import uuid import io -from typing import Optional, List from datetime import datetime, timedelta -from fastapi import APIRouter, Depends, HTTPException, Query, Response +from fastapi import APIRouter, Depends, HTTPException, Query from fastapi.responses import StreamingResponse -from pydantic import BaseModel, Field from sqlalchemy.orm import Session as DBSession from ..db.database import get_db @@ -23,126 +20,27 @@ from ..db.models import ( AlertDigestDB, UserAlertSubscriptionDB, DigestStatusEnum ) from ..processing.digest_generator import DigestGenerator +from .digests_models import ( + DigestDetail, + DigestListResponse, + GenerateDigestRequest, + GenerateDigestResponse, + SendEmailRequest, + SendEmailResponse, + digest_to_list_item, + digest_to_detail, +) +from .digests_email import generate_pdf_from_html, send_digest_by_email router = APIRouter(prefix="/digests", tags=["digests"]) -# ============================================================================ -# Request/Response Models -# ============================================================================ - 
-class DigestListItem(BaseModel): - """Kurze Digest-Info fuer Liste.""" - id: str - period_start: datetime - period_end: datetime - total_alerts: int - critical_count: int - urgent_count: int - status: str - created_at: datetime - - -class DigestDetail(BaseModel): - """Vollstaendige Digest-Details.""" - id: str - subscription_id: Optional[str] - user_id: str - period_start: datetime - period_end: datetime - summary_html: str - summary_pdf_url: Optional[str] - total_alerts: int - critical_count: int - urgent_count: int - important_count: int - review_count: int - info_count: int - status: str - sent_at: Optional[datetime] - created_at: datetime - - -class DigestListResponse(BaseModel): - """Response fuer Digest-Liste.""" - digests: List[DigestListItem] - total: int - - -class GenerateDigestRequest(BaseModel): - """Request fuer manuelle Digest-Generierung.""" - weeks_back: int = Field(default=1, ge=1, le=4, description="Wochen zurueck") - force_regenerate: bool = Field(default=False, description="Vorhandenen Digest ueberschreiben") - - -class GenerateDigestResponse(BaseModel): - """Response fuer Digest-Generierung.""" - status: str - digest_id: Optional[str] - message: str - - -class SendEmailRequest(BaseModel): - """Request fuer E-Mail-Versand.""" - email: Optional[str] = Field(default=None, description="E-Mail-Adresse (optional, sonst aus Subscription)") - - -class SendEmailResponse(BaseModel): - """Response fuer E-Mail-Versand.""" - status: str - sent_to: str - message: str - - -# ============================================================================ -# Helper Functions -# ============================================================================ - def get_user_id_from_request() -> str: - """ - Extrahiert User-ID aus Request. - TODO: JWT-Token auswerten, aktuell Dummy. - """ + """Extrahiert User-ID aus Request. 
TODO: JWT-Token auswerten.""" return "demo-user" -def _digest_to_list_item(digest: AlertDigestDB) -> DigestListItem: - """Konvertiere DB-Model zu List-Item.""" - return DigestListItem( - id=digest.id, - period_start=digest.period_start, - period_end=digest.period_end, - total_alerts=digest.total_alerts or 0, - critical_count=digest.critical_count or 0, - urgent_count=digest.urgent_count or 0, - status=digest.status.value if digest.status else "pending", - created_at=digest.created_at - ) - - -def _digest_to_detail(digest: AlertDigestDB) -> DigestDetail: - """Konvertiere DB-Model zu Detail.""" - return DigestDetail( - id=digest.id, - subscription_id=digest.subscription_id, - user_id=digest.user_id, - period_start=digest.period_start, - period_end=digest.period_end, - summary_html=digest.summary_html or "", - summary_pdf_url=digest.summary_pdf_url, - total_alerts=digest.total_alerts or 0, - critical_count=digest.critical_count or 0, - urgent_count=digest.urgent_count or 0, - important_count=digest.important_count or 0, - review_count=digest.review_count or 0, - info_count=digest.info_count or 0, - status=digest.status.value if digest.status else "pending", - sent_at=digest.sent_at, - created_at=digest.created_at - ) - - # ============================================================================ # Endpoints # ============================================================================ @@ -153,11 +51,7 @@ async def list_digests( offset: int = Query(0, ge=0), db: DBSession = Depends(get_db) ): - """ - Liste alle Digests des aktuellen Users. - - Sortiert nach Erstellungsdatum (neueste zuerst). 
- """ + """Liste alle Digests des aktuellen Users.""" user_id = get_user_id_from_request() query = db.query(AlertDigestDB).filter( @@ -168,18 +62,14 @@ async def list_digests( digests = query.offset(offset).limit(limit).all() return DigestListResponse( - digests=[_digest_to_list_item(d) for d in digests], + digests=[digest_to_list_item(d) for d in digests], total=total ) @router.get("/latest", response_model=DigestDetail) -async def get_latest_digest( - db: DBSession = Depends(get_db) -): - """ - Hole den neuesten Digest des Users. - """ +async def get_latest_digest(db: DBSession = Depends(get_db)): + """Hole den neuesten Digest des Users.""" user_id = get_user_id_from_request() digest = db.query(AlertDigestDB).filter( @@ -189,17 +79,12 @@ async def get_latest_digest( if not digest: raise HTTPException(status_code=404, detail="Kein Digest vorhanden") - return _digest_to_detail(digest) + return digest_to_detail(digest) @router.get("/{digest_id}", response_model=DigestDetail) -async def get_digest( - digest_id: str, - db: DBSession = Depends(get_db) -): - """ - Hole Details eines spezifischen Digests. - """ +async def get_digest(digest_id: str, db: DBSession = Depends(get_db)): + """Hole Details eines spezifischen Digests.""" user_id = get_user_id_from_request() digest = db.query(AlertDigestDB).filter( @@ -210,17 +95,12 @@ async def get_digest( if not digest: raise HTTPException(status_code=404, detail="Digest nicht gefunden") - return _digest_to_detail(digest) + return digest_to_detail(digest) @router.get("/{digest_id}/pdf") -async def get_digest_pdf( - digest_id: str, - db: DBSession = Depends(get_db) -): - """ - Generiere und lade PDF-Version des Digests herunter. 
- """ +async def get_digest_pdf(digest_id: str, db: DBSession = Depends(get_db)): + """Generiere und lade PDF-Version des Digests herunter.""" user_id = get_user_id_from_request() digest = db.query(AlertDigestDB).filter( @@ -230,35 +110,26 @@ async def get_digest_pdf( if not digest: raise HTTPException(status_code=404, detail="Digest nicht gefunden") - if not digest.summary_html: raise HTTPException(status_code=400, detail="Digest hat keinen Inhalt") - # PDF generieren try: pdf_bytes = await generate_pdf_from_html(digest.summary_html) except Exception as e: raise HTTPException(status_code=500, detail=f"PDF-Generierung fehlgeschlagen: {str(e)}") - # Dateiname filename = f"wochenbericht_{digest.period_start.strftime('%Y%m%d')}_{digest.period_end.strftime('%Y%m%d')}.pdf" return StreamingResponse( io.BytesIO(pdf_bytes), media_type="application/pdf", - headers={ - "Content-Disposition": f"attachment; filename={filename}" - } + headers={"Content-Disposition": f"attachment; filename={filename}"} ) @router.get("/latest/pdf") -async def get_latest_digest_pdf( - db: DBSession = Depends(get_db) -): - """ - PDF des neuesten Digests herunterladen. 
- """ +async def get_latest_digest_pdf(db: DBSession = Depends(get_db)): + """PDF des neuesten Digests herunterladen.""" user_id = get_user_id_from_request() digest = db.query(AlertDigestDB).filter( @@ -267,11 +138,9 @@ async def get_latest_digest_pdf( if not digest: raise HTTPException(status_code=404, detail="Kein Digest vorhanden") - if not digest.summary_html: raise HTTPException(status_code=400, detail="Digest hat keinen Inhalt") - # PDF generieren try: pdf_bytes = await generate_pdf_from_html(digest.summary_html) except Exception as e: @@ -282,9 +151,7 @@ async def get_latest_digest_pdf( return StreamingResponse( io.BytesIO(pdf_bytes), media_type="application/pdf", - headers={ - "Content-Disposition": f"attachment; filename={filename}" - } + headers={"Content-Disposition": f"attachment; filename={filename}"} ) @@ -293,16 +160,10 @@ async def generate_digest( request: GenerateDigestRequest = None, db: DBSession = Depends(get_db) ): - """ - Generiere einen neuen Digest manuell. - - Normalerweise werden Digests automatisch woechentlich generiert. - Diese Route erlaubt manuelle Generierung fuer Tests oder On-Demand. 
- """ + """Generiere einen neuen Digest manuell.""" user_id = get_user_id_from_request() weeks_back = request.weeks_back if request else 1 - # Pruefe ob bereits ein Digest fuer diesen Zeitraum existiert now = datetime.utcnow() period_end = now - timedelta(days=now.weekday()) period_start = period_end - timedelta(weeks=weeks_back) @@ -315,12 +176,10 @@ async def generate_digest( if existing and not (request and request.force_regenerate): return GenerateDigestResponse( - status="exists", - digest_id=existing.id, + status="exists", digest_id=existing.id, message="Digest fuer diesen Zeitraum existiert bereits" ) - # Generiere neuen Digest generator = DigestGenerator(db) try: @@ -328,14 +187,12 @@ async def generate_digest( if digest: return GenerateDigestResponse( - status="success", - digest_id=digest.id, + status="success", digest_id=digest.id, message="Digest erfolgreich generiert" ) else: return GenerateDigestResponse( - status="empty", - digest_id=None, + status="empty", digest_id=None, message="Keine Alerts fuer diesen Zeitraum vorhanden" ) except Exception as e: @@ -348,9 +205,7 @@ async def send_digest_email( request: SendEmailRequest = None, db: DBSession = Depends(get_db) ): - """ - Versende Digest per E-Mail. 
- """ + """Versende Digest per E-Mail.""" user_id = get_user_id_from_request() digest = db.query(AlertDigestDB).filter( @@ -361,12 +216,10 @@ async def send_digest_email( if not digest: raise HTTPException(status_code=404, detail="Digest nicht gefunden") - # E-Mail-Adresse ermitteln email = None if request and request.email: email = request.email else: - # Aus Subscription holen subscription = db.query(UserAlertSubscriptionDB).filter( UserAlertSubscriptionDB.id == digest.subscription_id ).first() @@ -376,176 +229,18 @@ async def send_digest_email( if not email: raise HTTPException(status_code=400, detail="Keine E-Mail-Adresse angegeben") - # E-Mail versenden try: await send_digest_by_email(digest, email) - # Status aktualisieren digest.status = DigestStatusEnum.SENT digest.sent_at = datetime.utcnow() db.commit() return SendEmailResponse( - status="success", - sent_to=email, + status="success", sent_to=email, message="E-Mail erfolgreich versendet" ) except Exception as e: digest.status = DigestStatusEnum.FAILED db.commit() raise HTTPException(status_code=500, detail=f"E-Mail-Versand fehlgeschlagen: {str(e)}") - - -# ============================================================================ -# PDF Generation -# ============================================================================ - -async def generate_pdf_from_html(html_content: str) -> bytes: - """ - Generiere PDF aus HTML. - - Verwendet WeasyPrint oder wkhtmltopdf als Fallback. 
- """ - try: - # Versuche WeasyPrint (bevorzugt) - from weasyprint import HTML - pdf_bytes = HTML(string=html_content).write_pdf() - return pdf_bytes - except ImportError: - pass - - try: - # Fallback: wkhtmltopdf via pdfkit - import pdfkit - pdf_bytes = pdfkit.from_string(html_content, False) - return pdf_bytes - except ImportError: - pass - - try: - # Fallback: xhtml2pdf - from xhtml2pdf import pisa - result = io.BytesIO() - pisa.CreatePDF(io.StringIO(html_content), dest=result) - return result.getvalue() - except ImportError: - pass - - # Letzter Fallback: Einfache Text-Konvertierung - raise ImportError( - "Keine PDF-Bibliothek verfuegbar. " - "Installieren Sie: pip install weasyprint oder pip install pdfkit oder pip install xhtml2pdf" - ) - - -# ============================================================================ -# Email Sending -# ============================================================================ - -async def send_digest_by_email(digest: AlertDigestDB, recipient_email: str): - """ - Versende Digest per E-Mail. - - Verwendet: - - Lokalen SMTP-Server (Postfix/Sendmail) - - SMTP-Relay (z.B. SES, Mailgun) - - SendGrid API - """ - import os - import smtplib - from email.mime.text import MIMEText - from email.mime.multipart import MIMEMultipart - from email.mime.application import MIMEApplication - - # E-Mail zusammenstellen - msg = MIMEMultipart('alternative') - msg['Subject'] = f"Wochenbericht: {digest.period_start.strftime('%d.%m.%Y')} - {digest.period_end.strftime('%d.%m.%Y')}" - msg['From'] = os.getenv('SMTP_FROM', 'alerts@breakpilot.app') - msg['To'] = recipient_email - - # Text-Version - text_content = f""" -BreakPilot Alerts - Wochenbericht - -Zeitraum: {digest.period_start.strftime('%d.%m.%Y')} - {digest.period_end.strftime('%d.%m.%Y')} -Gesamt: {digest.total_alerts} Meldungen -Kritisch: {digest.critical_count} -Dringend: {digest.urgent_count} - -Oeffnen Sie die HTML-Version fuer die vollstaendige Uebersicht. 
- ---- -Diese E-Mail wurde automatisch von BreakPilot Alerts generiert. - """ - msg.attach(MIMEText(text_content, 'plain', 'utf-8')) - - # HTML-Version - if digest.summary_html: - msg.attach(MIMEText(digest.summary_html, 'html', 'utf-8')) - - # PDF-Anhang (optional) - try: - pdf_bytes = await generate_pdf_from_html(digest.summary_html) - pdf_attachment = MIMEApplication(pdf_bytes, _subtype='pdf') - pdf_attachment.add_header( - 'Content-Disposition', 'attachment', - filename=f"wochenbericht_{digest.period_start.strftime('%Y%m%d')}.pdf" - ) - msg.attach(pdf_attachment) - except Exception: - pass # PDF-Anhang ist optional - - # Senden - smtp_host = os.getenv('SMTP_HOST', 'localhost') - smtp_port = int(os.getenv('SMTP_PORT', '25')) - smtp_user = os.getenv('SMTP_USER', '') - smtp_pass = os.getenv('SMTP_PASS', '') - - try: - if smtp_port == 465: - # SSL - server = smtplib.SMTP_SSL(smtp_host, smtp_port) - else: - server = smtplib.SMTP(smtp_host, smtp_port) - if smtp_port == 587: - server.starttls() - - if smtp_user and smtp_pass: - server.login(smtp_user, smtp_pass) - - server.send_message(msg) - server.quit() - - except Exception as e: - # Fallback: SendGrid API - sendgrid_key = os.getenv('SENDGRID_API_KEY') - if sendgrid_key: - await send_via_sendgrid(msg, sendgrid_key) - else: - raise e - - -async def send_via_sendgrid(msg, api_key: str): - """Fallback: SendGrid API.""" - import httpx - - async with httpx.AsyncClient() as client: - response = await client.post( - "https://api.sendgrid.com/v3/mail/send", - headers={ - "Authorization": f"Bearer {api_key}", - "Content-Type": "application/json" - }, - json={ - "personalizations": [{"to": [{"email": msg['To']}]}], - "from": {"email": msg['From']}, - "subject": msg['Subject'], - "content": [ - {"type": "text/plain", "value": msg.get_payload(0).get_payload()}, - {"type": "text/html", "value": msg.get_payload(1).get_payload() if len(msg.get_payload()) > 1 else ""} - ] - } - ) - - if response.status_code >= 400: - raise 
Exception(f"SendGrid error: {response.status_code}") diff --git a/backend-lehrer/alerts_agent/api/digests_email.py b/backend-lehrer/alerts_agent/api/digests_email.py new file mode 100644 index 0000000..b4ed763 --- /dev/null +++ b/backend-lehrer/alerts_agent/api/digests_email.py @@ -0,0 +1,146 @@ +""" +Alert Digests - PDF-Generierung und E-Mail-Versand. +""" + +import io +import logging + +from ..db.models import AlertDigestDB + +logger = logging.getLogger(__name__) + + +async def generate_pdf_from_html(html_content: str) -> bytes: + """ + Generiere PDF aus HTML. + + Verwendet WeasyPrint oder wkhtmltopdf als Fallback. + """ + try: + from weasyprint import HTML + pdf_bytes = HTML(string=html_content).write_pdf() + return pdf_bytes + except ImportError: + pass + + try: + import pdfkit + pdf_bytes = pdfkit.from_string(html_content, False) + return pdf_bytes + except ImportError: + pass + + try: + from xhtml2pdf import pisa + result = io.BytesIO() + pisa.CreatePDF(io.StringIO(html_content), dest=result) + return result.getvalue() + except ImportError: + pass + + raise ImportError( + "Keine PDF-Bibliothek verfuegbar. " + "Installieren Sie: pip install weasyprint oder pip install pdfkit oder pip install xhtml2pdf" + ) + + +async def send_digest_by_email(digest: AlertDigestDB, recipient_email: str): + """ + Versende Digest per E-Mail. + + Verwendet: + - Lokalen SMTP-Server (Postfix/Sendmail) + - SMTP-Relay (z.B. 
SES, Mailgun) + - SendGrid API + """ + import os + import smtplib + from email.mime.text import MIMEText + from email.mime.multipart import MIMEMultipart + from email.mime.application import MIMEApplication + + msg = MIMEMultipart('alternative') + msg['Subject'] = f"Wochenbericht: {digest.period_start.strftime('%d.%m.%Y')} - {digest.period_end.strftime('%d.%m.%Y')}" + msg['From'] = os.getenv('SMTP_FROM', 'alerts@breakpilot.app') + msg['To'] = recipient_email + + text_content = f""" +BreakPilot Alerts - Wochenbericht + +Zeitraum: {digest.period_start.strftime('%d.%m.%Y')} - {digest.period_end.strftime('%d.%m.%Y')} +Gesamt: {digest.total_alerts} Meldungen +Kritisch: {digest.critical_count} +Dringend: {digest.urgent_count} + +Oeffnen Sie die HTML-Version fuer die vollstaendige Uebersicht. + +--- +Diese E-Mail wurde automatisch von BreakPilot Alerts generiert. + """ + msg.attach(MIMEText(text_content, 'plain', 'utf-8')) + + if digest.summary_html: + msg.attach(MIMEText(digest.summary_html, 'html', 'utf-8')) + + try: + pdf_bytes = await generate_pdf_from_html(digest.summary_html) + pdf_attachment = MIMEApplication(pdf_bytes, _subtype='pdf') + pdf_attachment.add_header( + 'Content-Disposition', 'attachment', + filename=f"wochenbericht_{digest.period_start.strftime('%Y%m%d')}.pdf" + ) + msg.attach(pdf_attachment) + except Exception: + pass # PDF-Anhang ist optional + + smtp_host = os.getenv('SMTP_HOST', 'localhost') + smtp_port = int(os.getenv('SMTP_PORT', '25')) + smtp_user = os.getenv('SMTP_USER', '') + smtp_pass = os.getenv('SMTP_PASS', '') + + try: + if smtp_port == 465: + server = smtplib.SMTP_SSL(smtp_host, smtp_port) + else: + server = smtplib.SMTP(smtp_host, smtp_port) + if smtp_port == 587: + server.starttls() + + if smtp_user and smtp_pass: + server.login(smtp_user, smtp_pass) + + server.send_message(msg) + server.quit() + + except Exception as e: + sendgrid_key = os.getenv('SENDGRID_API_KEY') + if sendgrid_key: + await send_via_sendgrid(msg, sendgrid_key) + 
else: + raise e + + +async def send_via_sendgrid(msg, api_key: str): + """Fallback: SendGrid API.""" + import httpx + + async with httpx.AsyncClient() as client: + response = await client.post( + "https://api.sendgrid.com/v3/mail/send", + headers={ + "Authorization": f"Bearer {api_key}", + "Content-Type": "application/json" + }, + json={ + "personalizations": [{"to": [{"email": msg['To']}]}], + "from": {"email": msg['From']}, + "subject": msg['Subject'], + "content": [ + {"type": "text/plain", "value": msg.get_payload(0).get_payload()}, + {"type": "text/html", "value": msg.get_payload(1).get_payload() if len(msg.get_payload()) > 1 else ""} + ] + } + ) + + if response.status_code >= 400: + raise Exception(f"SendGrid error: {response.status_code}") diff --git a/backend-lehrer/alerts_agent/api/digests_models.py b/backend-lehrer/alerts_agent/api/digests_models.py new file mode 100644 index 0000000..9c94d5b --- /dev/null +++ b/backend-lehrer/alerts_agent/api/digests_models.py @@ -0,0 +1,116 @@ +""" +Alert Digests - Request/Response Models und Konverter. 
+""" + +from typing import Optional, List +from datetime import datetime +from pydantic import BaseModel, Field + +from ..db.models import AlertDigestDB + + +# ============================================================================ +# Request/Response Models +# ============================================================================ + +class DigestListItem(BaseModel): + """Kurze Digest-Info fuer Liste.""" + id: str + period_start: datetime + period_end: datetime + total_alerts: int + critical_count: int + urgent_count: int + status: str + created_at: datetime + + +class DigestDetail(BaseModel): + """Vollstaendige Digest-Details.""" + id: str + subscription_id: Optional[str] + user_id: str + period_start: datetime + period_end: datetime + summary_html: str + summary_pdf_url: Optional[str] + total_alerts: int + critical_count: int + urgent_count: int + important_count: int + review_count: int + info_count: int + status: str + sent_at: Optional[datetime] + created_at: datetime + + +class DigestListResponse(BaseModel): + """Response fuer Digest-Liste.""" + digests: List[DigestListItem] + total: int + + +class GenerateDigestRequest(BaseModel): + """Request fuer manuelle Digest-Generierung.""" + weeks_back: int = Field(default=1, ge=1, le=4, description="Wochen zurueck") + force_regenerate: bool = Field(default=False, description="Vorhandenen Digest ueberschreiben") + + +class GenerateDigestResponse(BaseModel): + """Response fuer Digest-Generierung.""" + status: str + digest_id: Optional[str] + message: str + + +class SendEmailRequest(BaseModel): + """Request fuer E-Mail-Versand.""" + email: Optional[str] = Field(default=None, description="E-Mail-Adresse (optional)") + + +class SendEmailResponse(BaseModel): + """Response fuer E-Mail-Versand.""" + status: str + sent_to: str + message: str + + +# ============================================================================ +# Converter Functions +# 
============================================================================ + +def digest_to_list_item(digest: AlertDigestDB) -> DigestListItem: + """Konvertiere DB-Model zu List-Item.""" + return DigestListItem( + id=digest.id, + period_start=digest.period_start, + period_end=digest.period_end, + total_alerts=digest.total_alerts or 0, + critical_count=digest.critical_count or 0, + urgent_count=digest.urgent_count or 0, + status=digest.status.value if digest.status else "pending", + created_at=digest.created_at + ) + + +def digest_to_detail(digest: AlertDigestDB) -> DigestDetail: + """Konvertiere DB-Model zu Detail.""" + return DigestDetail( + id=digest.id, + subscription_id=digest.subscription_id, + user_id=digest.user_id, + period_start=digest.period_start, + period_end=digest.period_end, + summary_html=digest.summary_html or "", + summary_pdf_url=digest.summary_pdf_url, + total_alerts=digest.total_alerts or 0, + critical_count=digest.critical_count or 0, + urgent_count=digest.urgent_count or 0, + important_count=digest.important_count or 0, + review_count=digest.review_count or 0, + info_count=digest.info_count or 0, + status=digest.status.value if digest.status else "pending", + sent_at=digest.sent_at, + created_at=digest.created_at + ) diff --git a/backend-lehrer/alerts_agent/api/routes.py b/backend-lehrer/alerts_agent/api/routes.py index 8f52761..630bba4 100644 --- a/backend-lehrer/alerts_agent/api/routes.py +++ b/backend-lehrer/alerts_agent/api/routes.py @@ -1,5 +1,5 @@ """ -API Routes für Alerts Agent. +API Routes fuer Alerts Agent. 
Endpoints: - POST /alerts/ingest - Manuell Alerts importieren @@ -13,12 +13,18 @@ Endpoints: import os from datetime import datetime from typing import Optional -from fastapi import APIRouter, Depends, HTTPException, Query -from pydantic import BaseModel, Field +from fastapi import APIRouter, HTTPException, Query from ..models.alert_item import AlertItem, AlertStatus from ..models.relevance_profile import RelevanceProfile, PriorityItem from ..processing.relevance_scorer import RelevanceDecision, RelevanceScorer +from .schemas import ( + AlertIngestRequest, AlertIngestResponse, + AlertRunRequest, AlertRunResponse, + InboxItem, InboxResponse, + FeedbackRequest, FeedbackResponse, + ProfilePriorityRequest, ProfileUpdateRequest, ProfileResponse, +) router = APIRouter(prefix="/alerts", tags=["alerts"]) @@ -30,113 +36,13 @@ ALERTS_USE_LLM = os.getenv("ALERTS_USE_LLM", "false").lower() == "true" # ============================================================================ -# In-Memory Storage (später durch DB ersetzen) +# In-Memory Storage (spaeter durch DB ersetzen) # ============================================================================ _alerts_store: dict[str, AlertItem] = {} _profile_store: dict[str, RelevanceProfile] = {} -# ============================================================================ -# Request/Response Models -# ============================================================================ - -class AlertIngestRequest(BaseModel): - """Request für manuelles Alert-Import.""" - title: str = Field(..., min_length=1, max_length=500) - url: str = Field(..., min_length=1) - snippet: Optional[str] = Field(default=None, max_length=2000) - topic_label: str = Field(default="Manual Import") - published_at: Optional[datetime] = None - - -class AlertIngestResponse(BaseModel): - """Response für Alert-Import.""" - id: str - status: str - message: str - - -class AlertRunRequest(BaseModel): - """Request für Scoring-Pipeline.""" - limit: int = Field(default=50, 
ge=1, le=200) - skip_scored: bool = Field(default=True) - - -class AlertRunResponse(BaseModel): - """Response für Scoring-Pipeline.""" - processed: int - keep: int - drop: int - review: int - errors: int - duration_ms: int - - -class InboxItem(BaseModel): - """Ein Item in der Inbox.""" - id: str - title: str - url: str - snippet: Optional[str] - topic_label: str - published_at: Optional[datetime] - relevance_score: Optional[float] - relevance_decision: Optional[str] - relevance_summary: Optional[str] - status: str - - -class InboxResponse(BaseModel): - """Response für Inbox-Abfrage.""" - items: list[InboxItem] - total: int - page: int - page_size: int - - -class FeedbackRequest(BaseModel): - """Request für Relevanz-Feedback.""" - alert_id: str - is_relevant: bool - reason: Optional[str] = None - tags: list[str] = Field(default_factory=list) - - -class FeedbackResponse(BaseModel): - """Response für Feedback.""" - success: bool - message: str - profile_updated: bool - - -class ProfilePriorityRequest(BaseModel): - """Priority für Profile-Update.""" - label: str - weight: float = Field(default=0.5, ge=0.0, le=1.0) - keywords: list[str] = Field(default_factory=list) - description: Optional[str] = None - - -class ProfileUpdateRequest(BaseModel): - """Request für Profile-Update.""" - priorities: Optional[list[ProfilePriorityRequest]] = None - exclusions: Optional[list[str]] = None - policies: Optional[dict] = None - - -class ProfileResponse(BaseModel): - """Response für Profile.""" - id: str - priorities: list[dict] - exclusions: list[str] - policies: dict - total_scored: int - total_kept: int - total_dropped: int - accuracy_estimate: Optional[float] - - # ============================================================================ # Endpoints # ============================================================================ @@ -146,7 +52,7 @@ async def ingest_alert(request: AlertIngestRequest): """ Manuell einen Alert importieren. 
- Nützlich für Tests oder manuelles Hinzufügen von Artikeln. + Nuetzlich fuer Tests oder manuelles Hinzufuegen von Artikeln. """ alert = AlertItem( title=request.title, @@ -168,13 +74,13 @@ async def ingest_alert(request: AlertIngestRequest): @router.post("/run", response_model=AlertRunResponse) async def run_scoring_pipeline(request: AlertRunRequest): """ - Scoring-Pipeline für neue Alerts starten. + Scoring-Pipeline fuer neue Alerts starten. Bewertet alle unbewerteten Alerts und klassifiziert sie in KEEP, DROP oder REVIEW. - Wenn ALERTS_USE_LLM=true, wird das LLM Gateway für Scoring verwendet. - Sonst wird ein schnelles Keyword-basiertes Scoring durchgeführt. + Wenn ALERTS_USE_LLM=true, wird das LLM Gateway fuer Scoring verwendet. + Sonst wird ein schnelles Keyword-basiertes Scoring durchgefuehrt. """ import time start = time.time() @@ -193,7 +99,7 @@ async def run_scoring_pipeline(request: AlertRunRequest): keep = drop = review = errors = 0 - # Profil für Scoring laden + # Profil fuer Scoring laden profile = _profile_store.get("default") if not profile: profile = RelevanceProfile.create_default_education_profile() @@ -201,7 +107,7 @@ async def run_scoring_pipeline(request: AlertRunRequest): _profile_store["default"] = profile if ALERTS_USE_LLM and LLM_API_KEY: - # LLM-basiertes Scoring über Gateway + # LLM-basiertes Scoring ueber Gateway scorer = RelevanceScorer( gateway_url=LLM_GATEWAY_URL, api_key=LLM_API_KEY, @@ -227,12 +133,12 @@ async def run_scoring_pipeline(request: AlertRunRequest): snippet_lower = (alert.snippet or "").lower() combined = title_lower + " " + snippet_lower - # Ausschlüsse aus Profil prüfen + # Ausschluesse aus Profil pruefen if any(excl.lower() in combined for excl in profile.exclusions): alert.relevance_score = 0.15 alert.relevance_decision = RelevanceDecision.DROP.value drop += 1 - # Prioritäten aus Profil prüfen + # Prioritaeten aus Profil pruefen elif any( p.label.lower() in combined or any(kw.lower() in combined for kw in (p.keywords 
if hasattr(p, 'keywords') else [])) @@ -285,9 +191,9 @@ async def get_inbox( # Pagination total = len(alerts) - start = (page - 1) * page_size - end = start + page_size - page_alerts = alerts[start:end] + start_idx = (page - 1) * page_size + end_idx = start_idx + page_size + page_alerts = alerts[start_idx:end_idx] items = [ InboxItem( @@ -327,7 +233,7 @@ async def submit_feedback(request: FeedbackRequest): # Alert Status aktualisieren alert.status = AlertStatus.REVIEWED - # Profile aktualisieren (Default-Profile für Demo) + # Profile aktualisieren (Default-Profile fuer Demo) profile = _profile_store.get("default") if not profile: profile = RelevanceProfile.create_default_education_profile() @@ -353,7 +259,7 @@ async def get_profile(user_id: Optional[str] = Query(default=None)): """ Relevanz-Profil abrufen. - Ohne user_id wird das Default-Profil zurückgegeben. + Ohne user_id wird das Default-Profil zurueckgegeben. """ profile_id = user_id or "default" profile = _profile_store.get(profile_id) @@ -385,7 +291,7 @@ async def update_profile( """ Relevanz-Profil aktualisieren. - Erlaubt Anpassung von Prioritäten, Ausschlüssen und Policies. + Erlaubt Anpassung von Prioritaeten, Ausschluessen und Policies. """ profile_id = user_id or "default" profile = _profile_store.get(profile_id) @@ -431,34 +337,24 @@ async def update_profile( @router.get("/stats") async def get_stats(): """ - Statistiken über Alerts und Scoring. - - Gibt Statistiken im Format zurück, das das Frontend erwartet: - - total_alerts, new_alerts, kept_alerts, review_alerts, dropped_alerts - - total_topics, active_topics, total_rules + Statistiken ueber Alerts und Scoring. 
""" alerts = list(_alerts_store.values()) total = len(alerts) - # Zähle nach Status und Decision new_alerts = sum(1 for a in alerts if a.status == AlertStatus.NEW) kept_alerts = sum(1 for a in alerts if a.relevance_decision == "KEEP") review_alerts = sum(1 for a in alerts if a.relevance_decision == "REVIEW") dropped_alerts = sum(1 for a in alerts if a.relevance_decision == "DROP") - # Topics und Rules (In-Memory hat diese nicht, aber wir geben 0 zurück) - # Bei DB-Implementierung würden wir hier die Repositories nutzen total_topics = 0 active_topics = 0 total_rules = 0 - # Versuche DB-Statistiken zu laden wenn verfügbar try: from alerts_agent.db import get_db from alerts_agent.db.repository import TopicRepository, RuleRepository - from contextlib import contextmanager - # Versuche eine DB-Session zu bekommen db_gen = get_db() db = next(db_gen, None) if db: @@ -478,15 +374,12 @@ async def get_stats(): except StopIteration: pass except Exception: - # DB nicht verfügbar, nutze In-Memory Defaults pass - # Berechne Durchschnittsscore scored_alerts = [a for a in alerts if a.relevance_score is not None] avg_score = sum(a.relevance_score for a in scored_alerts) / len(scored_alerts) if scored_alerts else 0.0 return { - # Frontend-kompatibles Format "total_alerts": total, "new_alerts": new_alerts, "kept_alerts": kept_alerts, @@ -496,7 +389,6 @@ async def get_stats(): "active_topics": active_topics, "total_rules": total_rules, "avg_score": avg_score, - # Zusätzliche Details (Abwärtskompatibilität) "by_status": { "new": new_alerts, "scored": sum(1 for a in alerts if a.status == AlertStatus.SCORED), diff --git a/backend-lehrer/alerts_agent/api/schemas.py b/backend-lehrer/alerts_agent/api/schemas.py new file mode 100644 index 0000000..3f0b302 --- /dev/null +++ b/backend-lehrer/alerts_agent/api/schemas.py @@ -0,0 +1,111 @@ +""" +Request/Response Schemas fuer Alerts Agent API. 
+""" + +from datetime import datetime +from typing import Optional +from pydantic import BaseModel, Field + + +# ============================================================================ +# Request Models +# ============================================================================ + +class AlertIngestRequest(BaseModel): + """Request fuer manuelles Alert-Import.""" + title: str = Field(..., min_length=1, max_length=500) + url: str = Field(..., min_length=1) + snippet: Optional[str] = Field(default=None, max_length=2000) + topic_label: str = Field(default="Manual Import") + published_at: Optional[datetime] = None + + +class AlertRunRequest(BaseModel): + """Request fuer Scoring-Pipeline.""" + limit: int = Field(default=50, ge=1, le=200) + skip_scored: bool = Field(default=True) + + +class FeedbackRequest(BaseModel): + """Request fuer Relevanz-Feedback.""" + alert_id: str + is_relevant: bool + reason: Optional[str] = None + tags: list[str] = Field(default_factory=list) + + +class ProfilePriorityRequest(BaseModel): + """Priority fuer Profile-Update.""" + label: str + weight: float = Field(default=0.5, ge=0.0, le=1.0) + keywords: list[str] = Field(default_factory=list) + description: Optional[str] = None + + +class ProfileUpdateRequest(BaseModel): + """Request fuer Profile-Update.""" + priorities: Optional[list[ProfilePriorityRequest]] = None + exclusions: Optional[list[str]] = None + policies: Optional[dict] = None + + +# ============================================================================ +# Response Models +# ============================================================================ + +class AlertIngestResponse(BaseModel): + """Response fuer Alert-Import.""" + id: str + status: str + message: str + + +class AlertRunResponse(BaseModel): + """Response fuer Scoring-Pipeline.""" + processed: int + keep: int + drop: int + review: int + errors: int + duration_ms: int + + +class InboxItem(BaseModel): + """Ein Item in der Inbox.""" + id: str + title: str + 
url: str + snippet: Optional[str] + topic_label: str + published_at: Optional[datetime] + relevance_score: Optional[float] + relevance_decision: Optional[str] + relevance_summary: Optional[str] + status: str + + +class InboxResponse(BaseModel): + """Response fuer Inbox-Abfrage.""" + items: list[InboxItem] + total: int + page: int + page_size: int + + +class FeedbackResponse(BaseModel): + """Response fuer Feedback.""" + success: bool + message: str + profile_updated: bool + + +class ProfileResponse(BaseModel): + """Response fuer Profile.""" + id: str + priorities: list[dict] + exclusions: list[str] + policies: dict + total_scored: int + total_kept: int + total_dropped: int + accuracy_estimate: Optional[float] diff --git a/backend-lehrer/alerts_agent/api/wizard.py b/backend-lehrer/alerts_agent/api/wizard.py index c4010ca..ea4a059 100644 --- a/backend-lehrer/alerts_agent/api/wizard.py +++ b/backend-lehrer/alerts_agent/api/wizard.py @@ -7,21 +7,12 @@ Verwaltet den 3-Schritt Setup-Wizard: 3. Bestätigung und Aktivierung Zusätzlich: Migration-Wizard für bestehende Google Alerts. 
- -Endpoints: -- GET /wizard/state - Aktuellen Wizard-Status abrufen -- PUT /wizard/step/{step} - Schritt speichern -- POST /wizard/complete - Wizard abschließen -- POST /wizard/reset - Wizard zurücksetzen -- POST /wizard/migrate/email - E-Mail-Migration starten -- POST /wizard/migrate/rss - RSS-Import """ import uuid -from typing import Optional, List, Dict, Any +from typing import List, Dict, Any from datetime import datetime -from fastapi import APIRouter, Depends, HTTPException, Query -from pydantic import BaseModel, Field +from fastapi import APIRouter, Depends, HTTPException from sqlalchemy.orm import Session as DBSession from ..db.database import get_db @@ -29,77 +20,22 @@ from ..db.models import ( UserAlertSubscriptionDB, AlertTemplateDB, AlertSourceDB, AlertModeEnum, UserRoleEnum, MigrationModeEnum, FeedTypeEnum ) +from .wizard_models import ( + WizardState, + Step1Data, + Step2Data, + Step3Data, + StepResponse, + MigrateEmailRequest, + MigrateEmailResponse, + MigrateRssRequest, + MigrateRssResponse, +) router = APIRouter(prefix="/wizard", tags=["wizard"]) -# ============================================================================ -# Request/Response Models -# ============================================================================ - -class WizardState(BaseModel): - """Aktueller Wizard-Status.""" - subscription_id: Optional[str] = None - current_step: int = 0 # 0=nicht gestartet, 1-3=Schritte, 4=abgeschlossen - is_completed: bool = False - step_data: Dict[str, Any] = {} - recommended_templates: List[Dict[str, Any]] = [] - - -class Step1Data(BaseModel): - """Daten für Schritt 1: Rollenwahl.""" - role: str = Field(..., description="lehrkraft, schulleitung, it_beauftragte") - - -class Step2Data(BaseModel): - """Daten für Schritt 2: Template-Auswahl.""" - template_ids: List[str] = Field(..., min_length=1, max_length=3) - - -class Step3Data(BaseModel): - """Daten für Schritt 3: Bestätigung.""" - notification_email: Optional[str] = None - digest_enabled: 
bool = True - digest_frequency: str = "weekly" - - -class StepResponse(BaseModel): - """Response für Schritt-Update.""" - status: str - current_step: int - next_step: int - message: str - recommended_templates: List[Dict[str, Any]] = [] - - -class MigrateEmailRequest(BaseModel): - """Request für E-Mail-Migration.""" - original_label: Optional[str] = Field(default=None, description="Beschreibung des Alerts") - - -class MigrateEmailResponse(BaseModel): - """Response für E-Mail-Migration.""" - status: str - inbound_address: str - instructions: List[str] - source_id: str - - -class MigrateRssRequest(BaseModel): - """Request für RSS-Import.""" - rss_urls: List[str] = Field(..., min_length=1, max_length=20) - labels: Optional[List[str]] = None - - -class MigrateRssResponse(BaseModel): - """Response für RSS-Import.""" - status: str - sources_created: int - topics_created: int - message: str - - # ============================================================================ # Helper Functions # ============================================================================ @@ -144,13 +80,9 @@ def _get_recommended_templates(db: DBSession, role: str) -> List[Dict[str, Any]] for t in templates: if role in (t.target_roles or []): result.append({ - "id": t.id, - "slug": t.slug, - "name": t.name, - "description": t.description, - "icon": t.icon, - "category": t.category, - "recommended": True, + "id": t.id, "slug": t.slug, "name": t.name, + "description": t.description, "icon": t.icon, + "category": t.category, "recommended": True, }) return result @@ -167,14 +99,8 @@ def _generate_inbound_address(user_id: str, source_id: str) -> str: # ============================================================================ @router.get("/state", response_model=WizardState) -async def get_wizard_state( - db: DBSession = Depends(get_db) -): - """ - Hole aktuellen Wizard-Status. - - Gibt Schritt, gespeicherte Daten und empfohlene Templates zurück. 
- """ +async def get_wizard_state(db: DBSession = Depends(get_db)): + """Hole aktuellen Wizard-Status.""" user_id = get_user_id_from_request() subscription = db.query(UserAlertSubscriptionDB).filter( @@ -182,15 +108,8 @@ async def get_wizard_state( ).order_by(UserAlertSubscriptionDB.created_at.desc()).first() if not subscription: - return WizardState( - subscription_id=None, - current_step=0, - is_completed=False, - step_data={}, - recommended_templates=[], - ) + return WizardState() - # Empfohlene Templates basierend auf Rolle role = subscription.user_role.value if subscription.user_role else None recommended = _get_recommended_templates(db, role) if role else [] @@ -204,61 +123,37 @@ async def get_wizard_state( @router.put("/step/1", response_model=StepResponse) -async def save_step_1( - data: Step1Data, - db: DBSession = Depends(get_db) -): - """ - Schritt 1: Rolle speichern. - - Wählt die Rolle des Nutzers und gibt passende Template-Empfehlungen. - """ +async def save_step_1(data: Step1Data, db: DBSession = Depends(get_db)): + """Schritt 1: Rolle speichern.""" user_id = get_user_id_from_request() - # Validiere Rolle try: role = UserRoleEnum(data.role) except ValueError: - raise HTTPException( - status_code=400, - detail="Ungültige Rolle. Erlaubt: 'lehrkraft', 'schulleitung', 'it_beauftragte'" - ) + raise HTTPException(status_code=400, detail="Ungültige Rolle. 
Erlaubt: 'lehrkraft', 'schulleitung', 'it_beauftragte'") subscription = _get_or_create_subscription(db, user_id) - - # Update subscription.user_role = role subscription.wizard_step = 1 wizard_state = subscription.wizard_state or {} wizard_state["step1"] = {"role": data.role} subscription.wizard_state = wizard_state subscription.updated_at = datetime.utcnow() - db.commit() db.refresh(subscription) - # Empfohlene Templates recommended = _get_recommended_templates(db, data.role) return StepResponse( - status="success", - current_step=1, - next_step=2, + status="success", current_step=1, next_step=2, message=f"Rolle '{data.role}' gespeichert. Bitte wählen Sie jetzt Ihre Themen.", recommended_templates=recommended, ) @router.put("/step/2", response_model=StepResponse) -async def save_step_2( - data: Step2Data, - db: DBSession = Depends(get_db) -): - """ - Schritt 2: Templates auswählen. - - Speichert die ausgewählten Templates (1-3). - """ +async def save_step_2(data: Step2Data, db: DBSession = Depends(get_db)): + """Schritt 2: Templates auswählen.""" user_id = get_user_id_from_request() subscription = db.query(UserAlertSubscriptionDB).filter( @@ -269,46 +164,28 @@ async def save_step_2( if not subscription: raise HTTPException(status_code=400, detail="Bitte zuerst Schritt 1 abschließen") - # Validiere Template-IDs - templates = db.query(AlertTemplateDB).filter( - AlertTemplateDB.id.in_(data.template_ids) - ).all() + templates = db.query(AlertTemplateDB).filter(AlertTemplateDB.id.in_(data.template_ids)).all() if len(templates) != len(data.template_ids): raise HTTPException(status_code=400, detail="Eine oder mehrere Template-IDs sind ungültig") - # Update subscription.selected_template_ids = data.template_ids subscription.wizard_step = 2 wizard_state = subscription.wizard_state or {} - wizard_state["step2"] = { - "template_ids": data.template_ids, - "template_names": [t.name for t in templates], - } + wizard_state["step2"] = {"template_ids": data.template_ids, 
"template_names": [t.name for t in templates]} subscription.wizard_state = wizard_state subscription.updated_at = datetime.utcnow() - db.commit() return StepResponse( - status="success", - current_step=2, - next_step=3, + status="success", current_step=2, next_step=3, message=f"{len(templates)} Themen ausgewählt. Bitte bestätigen Sie Ihre Auswahl.", - recommended_templates=[], ) @router.put("/step/3", response_model=StepResponse) -async def save_step_3( - data: Step3Data, - db: DBSession = Depends(get_db) -): - """ - Schritt 3: Digest-Einstellungen und Bestätigung. - - Speichert E-Mail und Digest-Präferenzen. - """ +async def save_step_3(data: Step3Data, db: DBSession = Depends(get_db)): + """Schritt 3: Digest-Einstellungen und Bestätigung.""" user_id = get_user_id_from_request() subscription = db.query(UserAlertSubscriptionDB).filter( @@ -318,16 +195,13 @@ async def save_step_3( if not subscription: raise HTTPException(status_code=400, detail="Bitte zuerst Schritte 1 und 2 abschließen") - if not subscription.selected_template_ids: raise HTTPException(status_code=400, detail="Bitte zuerst Templates auswählen (Schritt 2)") - # Update subscription.notification_email = data.notification_email subscription.digest_enabled = data.digest_enabled subscription.digest_frequency = data.digest_frequency subscription.wizard_step = 3 - wizard_state = subscription.wizard_state or {} wizard_state["step3"] = { "notification_email": data.notification_email, @@ -336,27 +210,17 @@ async def save_step_3( } subscription.wizard_state = wizard_state subscription.updated_at = datetime.utcnow() - db.commit() return StepResponse( - status="success", - current_step=3, - next_step=4, + status="success", current_step=3, next_step=4, message="Einstellungen gespeichert. 
Klicken Sie auf 'Jetzt starten' um den Wizard abzuschließen.", - recommended_templates=[], ) @router.post("/complete") -async def complete_wizard( - db: DBSession = Depends(get_db) -): - """ - Wizard abschließen und Templates aktivieren. - - Erstellt Topics, Rules und Profile basierend auf den gewählten Templates. - """ +async def complete_wizard(db: DBSession = Depends(get_db)): + """Wizard abschließen und Templates aktivieren.""" user_id = get_user_id_from_request() subscription = db.query(UserAlertSubscriptionDB).filter( @@ -366,18 +230,14 @@ async def complete_wizard( if not subscription: raise HTTPException(status_code=400, detail="Kein aktiver Wizard gefunden") - if not subscription.selected_template_ids: raise HTTPException(status_code=400, detail="Bitte zuerst Templates auswählen") - # Aktiviere Templates (über Subscription-Endpoint) from .subscriptions import activate_template, ActivateTemplateRequest - # Markiere als abgeschlossen subscription.wizard_completed = True subscription.wizard_step = 4 subscription.updated_at = datetime.utcnow() - db.commit() return { @@ -390,9 +250,7 @@ async def complete_wizard( @router.post("/reset") -async def reset_wizard( - db: DBSession = Depends(get_db) -): +async def reset_wizard(db: DBSession = Depends(get_db)): """Wizard zurücksetzen (für Neustart).""" user_id = get_user_id_from_request() @@ -405,10 +263,7 @@ async def reset_wizard( db.delete(subscription) db.commit() - return { - "status": "success", - "message": "Wizard zurückgesetzt. Sie können neu beginnen.", - } + return {"status": "success", "message": "Wizard zurückgesetzt. 
Sie können neu beginnen."} # ============================================================================ @@ -416,29 +271,16 @@ async def reset_wizard( # ============================================================================ @router.post("/migrate/email", response_model=MigrateEmailResponse) -async def start_email_migration( - request: MigrateEmailRequest = None, - db: DBSession = Depends(get_db) -): - """ - Starte E-Mail-Migration für bestehende Google Alerts. - - Generiert eine eindeutige Inbound-E-Mail-Adresse, an die der Nutzer - seine Google Alerts weiterleiten kann. - """ +async def start_email_migration(request: MigrateEmailRequest = None, db: DBSession = Depends(get_db)): + """Starte E-Mail-Migration für bestehende Google Alerts.""" user_id = get_user_id_from_request() - # Erstelle AlertSource source = AlertSourceDB( - id=str(uuid.uuid4()), - user_id=user_id, + id=str(uuid.uuid4()), user_id=user_id, source_type=FeedTypeEnum.EMAIL, original_label=request.original_label if request else "Google Alert Migration", - migration_mode=MigrationModeEnum.FORWARD, - is_active=True, + migration_mode=MigrationModeEnum.FORWARD, is_active=True, ) - - # Generiere Inbound-Adresse source.inbound_address = _generate_inbound_address(user_id, source.id) db.add(source) @@ -446,9 +288,7 @@ async def start_email_migration( db.refresh(source) return MigrateEmailResponse( - status="success", - inbound_address=source.inbound_address, - source_id=source.id, + status="success", inbound_address=source.inbound_address, source_id=source.id, instructions=[ "1. Öffnen Sie Google Alerts (google.com/alerts)", "2. Klicken Sie auf das Bearbeiten-Symbol bei Ihrem Alert", @@ -460,74 +300,49 @@ async def start_email_migration( @router.post("/migrate/rss", response_model=MigrateRssResponse) -async def import_rss_feeds( - request: MigrateRssRequest, - db: DBSession = Depends(get_db) -): - """ - Importiere bestehende Google Alert RSS-Feeds. 
- - Erstellt für jede RSS-URL einen AlertSource und Topic. - """ +async def import_rss_feeds(request: MigrateRssRequest, db: DBSession = Depends(get_db)): + """Importiere bestehende Google Alert RSS-Feeds.""" user_id = get_user_id_from_request() from ..db.models import AlertTopicDB - sources_created = 0 - topics_created = 0 + sources_created, topics_created = 0, 0 for i, url in enumerate(request.rss_urls): - # Label aus Request oder generieren label = None if request.labels and i < len(request.labels): label = request.labels[i] if not label: label = f"RSS Feed {i + 1}" - # Erstelle AlertSource source = AlertSourceDB( - id=str(uuid.uuid4()), - user_id=user_id, - source_type=FeedTypeEnum.RSS, - original_label=label, - rss_url=url, - migration_mode=MigrationModeEnum.IMPORT, - is_active=True, + id=str(uuid.uuid4()), user_id=user_id, + source_type=FeedTypeEnum.RSS, original_label=label, + rss_url=url, migration_mode=MigrationModeEnum.IMPORT, is_active=True, ) db.add(source) sources_created += 1 - # Erstelle Topic topic = AlertTopicDB( - id=str(uuid.uuid4()), - user_id=user_id, - name=label, - description=f"Importiert aus RSS: {url[:50]}...", - feed_url=url, - feed_type=FeedTypeEnum.RSS, - is_active=True, - fetch_interval_minutes=60, + id=str(uuid.uuid4()), user_id=user_id, + name=label, description=f"Importiert aus RSS: {url[:50]}...", + feed_url=url, feed_type=FeedTypeEnum.RSS, + is_active=True, fetch_interval_minutes=60, ) db.add(topic) - - # Verknüpfe Source mit Topic source.topic_id = topic.id topics_created += 1 db.commit() return MigrateRssResponse( - status="success", - sources_created=sources_created, - topics_created=topics_created, + status="success", sources_created=sources_created, topics_created=topics_created, message=f"{sources_created} RSS-Feeds importiert. 
Die Alerts werden automatisch abgerufen.", ) @router.get("/migrate/sources") -async def list_migration_sources( - db: DBSession = Depends(get_db) -): +async def list_migration_sources(db: DBSession = Depends(get_db)): """Liste alle Migration-Quellen des Users.""" user_id = get_user_id_from_request() diff --git a/backend-lehrer/alerts_agent/api/wizard_models.py b/backend-lehrer/alerts_agent/api/wizard_models.py new file mode 100644 index 0000000..6cfea14 --- /dev/null +++ b/backend-lehrer/alerts_agent/api/wizard_models.py @@ -0,0 +1,68 @@ +""" +Wizard API - Request/Response Models. +""" + +from typing import Optional, List, Dict, Any +from pydantic import BaseModel, Field + + +class WizardState(BaseModel): + """Aktueller Wizard-Status.""" + subscription_id: Optional[str] = None + current_step: int = 0 # 0=nicht gestartet, 1-3=Schritte, 4=abgeschlossen + is_completed: bool = False + step_data: Dict[str, Any] = {} + recommended_templates: List[Dict[str, Any]] = [] + + +class Step1Data(BaseModel): + """Daten für Schritt 1: Rollenwahl.""" + role: str = Field(..., description="lehrkraft, schulleitung, it_beauftragte") + + +class Step2Data(BaseModel): + """Daten für Schritt 2: Template-Auswahl.""" + template_ids: List[str] = Field(..., min_length=1, max_length=3) + + +class Step3Data(BaseModel): + """Daten für Schritt 3: Bestätigung.""" + notification_email: Optional[str] = None + digest_enabled: bool = True + digest_frequency: str = "weekly" + + +class StepResponse(BaseModel): + """Response für Schritt-Update.""" + status: str + current_step: int + next_step: int + message: str + recommended_templates: List[Dict[str, Any]] = [] + + +class MigrateEmailRequest(BaseModel): + """Request für E-Mail-Migration.""" + original_label: Optional[str] = Field(default=None, description="Beschreibung des Alerts") + + +class MigrateEmailResponse(BaseModel): + """Response für E-Mail-Migration.""" + status: str + inbound_address: str + instructions: List[str] + source_id: str + + +class 
MigrateRssRequest(BaseModel): + """Request für RSS-Import.""" + rss_urls: List[str] = Field(..., min_length=1, max_length=20) + labels: Optional[List[str]] = None + + +class MigrateRssResponse(BaseModel): + """Response für RSS-Import.""" + status: str + sources_created: int + topics_created: int + message: str diff --git a/backend-lehrer/alerts_agent/processing/rule_engine.py b/backend-lehrer/alerts_agent/processing/rule_engine.py index 1eede0e..01dc3a4 100644 --- a/backend-lehrer/alerts_agent/processing/rule_engine.py +++ b/backend-lehrer/alerts_agent/processing/rule_engine.py @@ -2,277 +2,49 @@ Rule Engine für Alerts Agent. Evaluiert Regeln gegen Alert-Items und führt Aktionen aus. - -Regel-Struktur: -- Bedingungen: [{field, operator, value}, ...] (AND-verknüpft) -- Aktion: keep, drop, tag, email, webhook, slack -- Priorität: Höhere Priorität wird zuerst evaluiert +Batch-Verarbeitung und Action-Anwendung. """ -import re + import logging -from dataclasses import dataclass -from typing import List, Dict, Any, Optional, Callable -from enum import Enum +from typing import List, Dict, Any, Optional from alerts_agent.db.models import AlertItemDB, AlertRuleDB, RuleActionEnum +from .rule_models import ( + ConditionOperator, + RuleCondition, + RuleMatch, + get_field_value, + evaluate_condition, + evaluate_rule, + evaluate_rules_for_alert, + create_keyword_rule, + create_exclusion_rule, + create_score_threshold_rule, +) + logger = logging.getLogger(__name__) - -class ConditionOperator(str, Enum): - """Operatoren für Regel-Bedingungen.""" - CONTAINS = "contains" - NOT_CONTAINS = "not_contains" - EQUALS = "equals" - NOT_EQUALS = "not_equals" - STARTS_WITH = "starts_with" - ENDS_WITH = "ends_with" - REGEX = "regex" - GREATER_THAN = "gt" - LESS_THAN = "lt" - GREATER_EQUAL = "gte" - LESS_EQUAL = "lte" - IN_LIST = "in" - NOT_IN_LIST = "not_in" - - -@dataclass -class RuleCondition: - """Eine einzelne Regel-Bedingung.""" - field: str # "title", "snippet", "url", "source", 
"relevance_score" - operator: ConditionOperator - value: Any # str, float, list - - @classmethod - def from_dict(cls, data: Dict) -> "RuleCondition": - """Erstellt eine Bedingung aus einem Dict.""" - return cls( - field=data.get("field", ""), - operator=ConditionOperator(data.get("operator", data.get("op", "contains"))), - value=data.get("value", ""), - ) - - -@dataclass -class RuleMatch: - """Ergebnis einer Regel-Evaluierung.""" - rule_id: str - rule_name: str - matched: bool - action: RuleActionEnum - action_config: Dict[str, Any] - conditions_met: List[str] # Welche Bedingungen haben gematched - - -def get_field_value(alert: AlertItemDB, field: str) -> Any: - """ - Extrahiert einen Feldwert aus einem Alert. - - Args: - alert: Alert-Item - field: Feldname - - Returns: - Feldwert oder None - """ - field_map = { - "title": alert.title, - "snippet": alert.snippet, - "url": alert.url, - "source": alert.source.value if alert.source else "", - "status": alert.status.value if alert.status else "", - "relevance_score": alert.relevance_score, - "relevance_decision": alert.relevance_decision.value if alert.relevance_decision else "", - "lang": alert.lang, - "topic_id": alert.topic_id, - } - - return field_map.get(field) - - -def evaluate_condition( - alert: AlertItemDB, - condition: RuleCondition, -) -> bool: - """ - Evaluiert eine einzelne Bedingung gegen einen Alert. 
- - Args: - alert: Alert-Item - condition: Zu evaluierende Bedingung - - Returns: - True wenn Bedingung erfüllt - """ - field_value = get_field_value(alert, condition.field) - - if field_value is None: - return False - - op = condition.operator - target = condition.value - - try: - # String-Operationen (case-insensitive) - if isinstance(field_value, str): - field_lower = field_value.lower() - target_lower = str(target).lower() if isinstance(target, str) else target - - if op == ConditionOperator.CONTAINS: - return target_lower in field_lower - - elif op == ConditionOperator.NOT_CONTAINS: - return target_lower not in field_lower - - elif op == ConditionOperator.EQUALS: - return field_lower == target_lower - - elif op == ConditionOperator.NOT_EQUALS: - return field_lower != target_lower - - elif op == ConditionOperator.STARTS_WITH: - return field_lower.startswith(target_lower) - - elif op == ConditionOperator.ENDS_WITH: - return field_lower.endswith(target_lower) - - elif op == ConditionOperator.REGEX: - try: - return bool(re.search(str(target), field_value, re.IGNORECASE)) - except re.error: - logger.warning(f"Invalid regex pattern: {target}") - return False - - elif op == ConditionOperator.IN_LIST: - if isinstance(target, list): - return any(t.lower() in field_lower for t in target if isinstance(t, str)) - return False - - elif op == ConditionOperator.NOT_IN_LIST: - if isinstance(target, list): - return not any(t.lower() in field_lower for t in target if isinstance(t, str)) - return True - - # Numerische Operationen - elif isinstance(field_value, (int, float)): - target_num = float(target) if target else 0 - - if op == ConditionOperator.EQUALS: - return field_value == target_num - - elif op == ConditionOperator.NOT_EQUALS: - return field_value != target_num - - elif op == ConditionOperator.GREATER_THAN: - return field_value > target_num - - elif op == ConditionOperator.LESS_THAN: - return field_value < target_num - - elif op == ConditionOperator.GREATER_EQUAL: - 
return field_value >= target_num - - elif op == ConditionOperator.LESS_EQUAL: - return field_value <= target_num - - except Exception as e: - logger.error(f"Error evaluating condition: {e}") - return False - - return False - - -def evaluate_rule( - alert: AlertItemDB, - rule: AlertRuleDB, -) -> RuleMatch: - """ - Evaluiert eine Regel gegen einen Alert. - - Alle Bedingungen müssen erfüllt sein (AND-Verknüpfung). - - Args: - alert: Alert-Item - rule: Zu evaluierende Regel - - Returns: - RuleMatch-Ergebnis - """ - conditions = rule.conditions or [] - conditions_met = [] - all_matched = True - - for cond_dict in conditions: - condition = RuleCondition.from_dict(cond_dict) - if evaluate_condition(alert, condition): - conditions_met.append(f"{condition.field} {condition.operator.value} {condition.value}") - else: - all_matched = False - - # Wenn keine Bedingungen definiert sind, matcht die Regel immer - if not conditions: - all_matched = True - - return RuleMatch( - rule_id=rule.id, - rule_name=rule.name, - matched=all_matched, - action=rule.action_type, - action_config=rule.action_config or {}, - conditions_met=conditions_met, - ) - - -def evaluate_rules_for_alert( - alert: AlertItemDB, - rules: List[AlertRuleDB], -) -> Optional[RuleMatch]: - """ - Evaluiert alle Regeln gegen einen Alert und gibt den ersten Match zurück. - - Regeln werden nach Priorität (absteigend) evaluiert. 
- - Args: - alert: Alert-Item - rules: Liste von Regeln (sollte bereits nach Priorität sortiert sein) - - Returns: - Erster RuleMatch oder None - """ - for rule in rules: - if not rule.is_active: - continue - - # Topic-Filter: Regel gilt nur für bestimmtes Topic - if rule.topic_id and rule.topic_id != alert.topic_id: - continue - - match = evaluate_rule(alert, rule) - - if match.matched: - logger.debug( - f"Rule '{rule.name}' matched alert '{alert.id[:8]}': " - f"{match.conditions_met}" - ) - return match - - return None +# Re-export for backward compatibility +__all__ = [ + "ConditionOperator", + "RuleCondition", + "RuleMatch", + "get_field_value", + "evaluate_condition", + "evaluate_rule", + "evaluate_rules_for_alert", + "RuleEngine", + "create_keyword_rule", + "create_exclusion_rule", + "create_score_threshold_rule", +] class RuleEngine: - """ - Rule Engine für Batch-Verarbeitung von Alerts. - - Verwendet für das Scoring von mehreren Alerts gleichzeitig. - """ + """Rule Engine für Batch-Verarbeitung von Alerts.""" def __init__(self, db_session): - """ - Initialisiert die Rule Engine. - - Args: - db_session: SQLAlchemy Session - """ self.db = db_session self._rules_cache: Optional[List[AlertRuleDB]] = None @@ -282,42 +54,19 @@ class RuleEngine: from alerts_agent.db.repository import RuleRepository repo = RuleRepository(self.db) self._rules_cache = repo.get_active() - return self._rules_cache def clear_cache(self) -> None: """Leert den Regel-Cache.""" self._rules_cache = None - def process_alert( - self, - alert: AlertItemDB, - ) -> Optional[RuleMatch]: - """ - Verarbeitet einen Alert mit allen aktiven Regeln. 
- - Args: - alert: Alert-Item - - Returns: - RuleMatch wenn eine Regel matcht, sonst None - """ + def process_alert(self, alert: AlertItemDB) -> Optional[RuleMatch]: + """Verarbeitet einen Alert mit allen aktiven Regeln.""" rules = self._get_active_rules() return evaluate_rules_for_alert(alert, rules) - def process_alerts( - self, - alerts: List[AlertItemDB], - ) -> Dict[str, RuleMatch]: - """ - Verarbeitet mehrere Alerts mit allen aktiven Regeln. - - Args: - alerts: Liste von Alert-Items - - Returns: - Dict von alert_id -> RuleMatch (nur für gematschte Alerts) - """ + def process_alerts(self, alerts: List[AlertItemDB]) -> Dict[str, RuleMatch]: + """Verarbeitet mehrere Alerts mit allen aktiven Regeln.""" rules = self._get_active_rules() results = {} @@ -328,21 +77,8 @@ class RuleEngine: return results - def apply_rule_actions( - self, - alert: AlertItemDB, - match: RuleMatch, - ) -> Dict[str, Any]: - """ - Wendet die Regel-Aktion auf einen Alert an. - - Args: - alert: Alert-Item - match: RuleMatch mit Aktionsinformationen - - Returns: - Dict mit Ergebnis der Aktion - """ + def apply_rule_actions(self, alert: AlertItemDB, match: RuleMatch) -> Dict[str, Any]: + """Wendet die Regel-Aktion auf einen Alert an.""" from alerts_agent.db.repository import AlertItemRepository, RuleRepository alert_repo = AlertItemRepository(self.db) @@ -350,36 +86,26 @@ class RuleEngine: action = match.action config = match.action_config - result = {"action": action.value, "success": False} try: if action == RuleActionEnum.KEEP: - # Alert als KEEP markieren alert_repo.update_scoring( - alert_id=alert.id, - score=1.0, - decision="KEEP", - reasons=["rule_match"], - summary=f"Matched rule: {match.rule_name}", + alert_id=alert.id, score=1.0, decision="KEEP", + reasons=["rule_match"], summary=f"Matched rule: {match.rule_name}", model="rule_engine", ) result["success"] = True elif action == RuleActionEnum.DROP: - # Alert als DROP markieren alert_repo.update_scoring( - alert_id=alert.id, - 
score=0.0, - decision="DROP", - reasons=["rule_match"], - summary=f"Dropped by rule: {match.rule_name}", + alert_id=alert.id, score=0.0, decision="DROP", + reasons=["rule_match"], summary=f"Dropped by rule: {match.rule_name}", model="rule_engine", ) result["success"] = True elif action == RuleActionEnum.TAG: - # Tags hinzufügen tags = config.get("tags", []) if tags: existing_tags = alert.user_tags or [] @@ -389,27 +115,20 @@ class RuleEngine: result["success"] = True elif action == RuleActionEnum.EMAIL: - # E-Mail-Benachrichtigung senden - # Wird von Actions-Modul behandelt result["email_config"] = config result["success"] = True - result["deferred"] = True # Wird später gesendet + result["deferred"] = True elif action == RuleActionEnum.WEBHOOK: - # Webhook aufrufen - # Wird von Actions-Modul behandelt result["webhook_config"] = config result["success"] = True result["deferred"] = True elif action == RuleActionEnum.SLACK: - # Slack-Nachricht senden - # Wird von Actions-Modul behandelt result["slack_config"] = config result["success"] = True result["deferred"] = True - # Match-Count erhöhen rule_repo.increment_match_count(match.rule_id) except Exception as e: @@ -417,96 +136,3 @@ class RuleEngine: result["error"] = str(e) return result - - -# Convenience-Funktionen für einfache Nutzung -def create_keyword_rule( - name: str, - keywords: List[str], - action: str = "keep", - field: str = "title", -) -> Dict: - """ - Erstellt eine Keyword-basierte Regel. 
- - Args: - name: Regelname - keywords: Liste von Keywords (OR-verknüpft über IN_LIST) - action: Aktion (keep, drop, tag) - field: Feld zum Prüfen (title, snippet, url) - - Returns: - Regel-Definition als Dict - """ - return { - "name": name, - "conditions": [ - { - "field": field, - "operator": "in", - "value": keywords, - } - ], - "action_type": action, - "action_config": {}, - } - - -def create_exclusion_rule( - name: str, - excluded_terms: List[str], - field: str = "title", -) -> Dict: - """ - Erstellt eine Ausschluss-Regel. - - Args: - name: Regelname - excluded_terms: Liste von auszuschließenden Begriffen - field: Feld zum Prüfen - - Returns: - Regel-Definition als Dict - """ - return { - "name": name, - "conditions": [ - { - "field": field, - "operator": "in", - "value": excluded_terms, - } - ], - "action_type": "drop", - "action_config": {}, - } - - -def create_score_threshold_rule( - name: str, - min_score: float, - action: str = "keep", -) -> Dict: - """ - Erstellt eine Score-basierte Regel. - - Args: - name: Regelname - min_score: Mindest-Score - action: Aktion bei Erreichen des Scores - - Returns: - Regel-Definition als Dict - """ - return { - "name": name, - "conditions": [ - { - "field": "relevance_score", - "operator": "gte", - "value": min_score, - } - ], - "action_type": action, - "action_config": {}, - } diff --git a/backend-lehrer/alerts_agent/processing/rule_models.py b/backend-lehrer/alerts_agent/processing/rule_models.py new file mode 100644 index 0000000..974af96 --- /dev/null +++ b/backend-lehrer/alerts_agent/processing/rule_models.py @@ -0,0 +1,206 @@ +""" +Rule Engine - Models, Condition Evaluation, and Convenience Functions. + +Datenmodelle und Evaluierungs-Logik fuer Alert-Regeln. 
+""" + +import re +import logging +from dataclasses import dataclass +from typing import List, Dict, Any, Optional +from enum import Enum + +from alerts_agent.db.models import AlertItemDB, AlertRuleDB, RuleActionEnum + +logger = logging.getLogger(__name__) + + +class ConditionOperator(str, Enum): + """Operatoren für Regel-Bedingungen.""" + CONTAINS = "contains" + NOT_CONTAINS = "not_contains" + EQUALS = "equals" + NOT_EQUALS = "not_equals" + STARTS_WITH = "starts_with" + ENDS_WITH = "ends_with" + REGEX = "regex" + GREATER_THAN = "gt" + LESS_THAN = "lt" + GREATER_EQUAL = "gte" + LESS_EQUAL = "lte" + IN_LIST = "in" + NOT_IN_LIST = "not_in" + + +@dataclass +class RuleCondition: + """Eine einzelne Regel-Bedingung.""" + field: str + operator: ConditionOperator + value: Any + + @classmethod + def from_dict(cls, data: Dict) -> "RuleCondition": + return cls( + field=data.get("field", ""), + operator=ConditionOperator(data.get("operator", data.get("op", "contains"))), + value=data.get("value", ""), + ) + + +@dataclass +class RuleMatch: + """Ergebnis einer Regel-Evaluierung.""" + rule_id: str + rule_name: str + matched: bool + action: RuleActionEnum + action_config: Dict[str, Any] + conditions_met: List[str] + + +def get_field_value(alert: AlertItemDB, field: str) -> Any: + """Extrahiert einen Feldwert aus einem Alert.""" + field_map = { + "title": alert.title, + "snippet": alert.snippet, + "url": alert.url, + "source": alert.source.value if alert.source else "", + "status": alert.status.value if alert.status else "", + "relevance_score": alert.relevance_score, + "relevance_decision": alert.relevance_decision.value if alert.relevance_decision else "", + "lang": alert.lang, + "topic_id": alert.topic_id, + } + return field_map.get(field) + + +def evaluate_condition(alert: AlertItemDB, condition: RuleCondition) -> bool: + """Evaluiert eine einzelne Bedingung gegen einen Alert.""" + field_value = get_field_value(alert, condition.field) + if field_value is None: + return False + 
+ op = condition.operator + target = condition.value + + try: + if isinstance(field_value, str): + field_lower = field_value.lower() + target_lower = str(target).lower() if isinstance(target, str) else target + + if op == ConditionOperator.CONTAINS: + return target_lower in field_lower + elif op == ConditionOperator.NOT_CONTAINS: + return target_lower not in field_lower + elif op == ConditionOperator.EQUALS: + return field_lower == target_lower + elif op == ConditionOperator.NOT_EQUALS: + return field_lower != target_lower + elif op == ConditionOperator.STARTS_WITH: + return field_lower.startswith(target_lower) + elif op == ConditionOperator.ENDS_WITH: + return field_lower.endswith(target_lower) + elif op == ConditionOperator.REGEX: + try: + return bool(re.search(str(target), field_value, re.IGNORECASE)) + except re.error: + logger.warning(f"Invalid regex pattern: {target}") + return False + elif op == ConditionOperator.IN_LIST: + if isinstance(target, list): + return any(t.lower() in field_lower for t in target if isinstance(t, str)) + return False + elif op == ConditionOperator.NOT_IN_LIST: + if isinstance(target, list): + return not any(t.lower() in field_lower for t in target if isinstance(t, str)) + return True + + elif isinstance(field_value, (int, float)): + target_num = float(target) if target else 0 + if op == ConditionOperator.EQUALS: + return field_value == target_num + elif op == ConditionOperator.NOT_EQUALS: + return field_value != target_num + elif op == ConditionOperator.GREATER_THAN: + return field_value > target_num + elif op == ConditionOperator.LESS_THAN: + return field_value < target_num + elif op == ConditionOperator.GREATER_EQUAL: + return field_value >= target_num + elif op == ConditionOperator.LESS_EQUAL: + return field_value <= target_num + + except Exception as e: + logger.error(f"Error evaluating condition: {e}") + return False + + return False + + +def evaluate_rule(alert: AlertItemDB, rule: AlertRuleDB) -> RuleMatch: + """Evaluiert eine 
Regel gegen einen Alert (AND-Verknüpfung).""" + conditions = rule.conditions or [] + conditions_met = [] + all_matched = True + + for cond_dict in conditions: + condition = RuleCondition.from_dict(cond_dict) + if evaluate_condition(alert, condition): + conditions_met.append(f"{condition.field} {condition.operator.value} {condition.value}") + else: + all_matched = False + + if not conditions: + all_matched = True + + return RuleMatch( + rule_id=rule.id, rule_name=rule.name, matched=all_matched, + action=rule.action_type, action_config=rule.action_config or {}, + conditions_met=conditions_met, + ) + + +def evaluate_rules_for_alert(alert: AlertItemDB, rules: List[AlertRuleDB]) -> Optional[RuleMatch]: + """Evaluiert alle Regeln gegen einen Alert und gibt den ersten Match zurück.""" + for rule in rules: + if not rule.is_active: + continue + if rule.topic_id and rule.topic_id != alert.topic_id: + continue + + match = evaluate_rule(alert, rule) + if match.matched: + logger.debug(f"Rule '{rule.name}' matched alert '{alert.id[:8]}': {match.conditions_met}") + return match + + return None + + +# Convenience-Funktionen + +def create_keyword_rule(name: str, keywords: List[str], action: str = "keep", field: str = "title") -> Dict: + """Erstellt eine Keyword-basierte Regel.""" + return { + "name": name, + "conditions": [{"field": field, "operator": "in", "value": keywords}], + "action_type": action, "action_config": {}, + } + + +def create_exclusion_rule(name: str, excluded_terms: List[str], field: str = "title") -> Dict: + """Erstellt eine Ausschluss-Regel.""" + return { + "name": name, + "conditions": [{"field": field, "operator": "in", "value": excluded_terms}], + "action_type": "drop", "action_config": {}, + } + + +def create_score_threshold_rule(name: str, min_score: float, action: str = "keep") -> Dict: + """Erstellt eine Score-basierte Regel.""" + return { + "name": name, + "conditions": [{"field": "relevance_score", "operator": "gte", "value": min_score}], + 
"action_type": action, "action_config": {}, + } diff --git a/backend-lehrer/auth/__init__.py b/backend-lehrer/auth/__init__.py index b56b38b..a3778a4 100644 --- a/backend-lehrer/auth/__init__.py +++ b/backend-lehrer/auth/__init__.py @@ -4,15 +4,11 @@ BreakPilot Authentication Module Hybrid authentication supporting both Keycloak and local JWT tokens. """ -from .keycloak_auth import ( +from .keycloak_models import ( # Config KeycloakConfig, KeycloakUser, - # Authenticators - KeycloakAuthenticator, - HybridAuthenticator, - # Exceptions KeycloakAuthError, TokenExpiredError, @@ -21,6 +17,14 @@ from .keycloak_auth import ( # Factory functions get_keycloak_config_from_env, +) + +from .keycloak_auth import ( + # Authenticators + KeycloakAuthenticator, + HybridAuthenticator, + + # Factory functions get_authenticator, get_auth, diff --git a/backend-lehrer/auth/keycloak_auth.py b/backend-lehrer/auth/keycloak_auth.py index 3449169..a8d8e71 100644 --- a/backend-lehrer/auth/keycloak_auth.py +++ b/backend-lehrer/auth/keycloak_auth.py @@ -14,110 +14,24 @@ import os import httpx import jwt from jwt import PyJWKClient -from datetime import datetime, timezone -from typing import Optional, Dict, Any, List -from dataclasses import dataclass -from functools import lru_cache import logging +from typing import Optional, Dict, Any + +from .keycloak_models import ( + KeycloakConfig, + KeycloakUser, + KeycloakAuthError, + TokenExpiredError, + TokenInvalidError, + KeycloakConfigError, + get_keycloak_config_from_env, +) logger = logging.getLogger(__name__) -@dataclass -class KeycloakConfig: - """Keycloak connection configuration.""" - server_url: str - realm: str - client_id: str - client_secret: Optional[str] = None - verify_ssl: bool = True - - @property - def issuer_url(self) -> str: - return f"{self.server_url}/realms/{self.realm}" - - @property - def jwks_url(self) -> str: - return f"{self.issuer_url}/protocol/openid-connect/certs" - - @property - def token_url(self) -> str: - return 
f"{self.issuer_url}/protocol/openid-connect/token" - - @property - def userinfo_url(self) -> str: - return f"{self.issuer_url}/protocol/openid-connect/userinfo" - - -@dataclass -class KeycloakUser: - """User information extracted from Keycloak token.""" - user_id: str # Keycloak subject (sub) - email: str - email_verified: bool - name: Optional[str] - given_name: Optional[str] - family_name: Optional[str] - realm_roles: List[str] # Keycloak realm roles - client_roles: Dict[str, List[str]] # Client-specific roles - groups: List[str] # Keycloak groups - tenant_id: Optional[str] # Custom claim for school/tenant - raw_claims: Dict[str, Any] # All claims for debugging - - def has_realm_role(self, role: str) -> bool: - """Check if user has a specific realm role.""" - return role in self.realm_roles - - def has_client_role(self, client_id: str, role: str) -> bool: - """Check if user has a specific client role.""" - client_roles = self.client_roles.get(client_id, []) - return role in client_roles - - def is_admin(self) -> bool: - """Check if user has admin role.""" - return self.has_realm_role("admin") or self.has_realm_role("schul_admin") - - def is_teacher(self) -> bool: - """Check if user is a teacher.""" - return self.has_realm_role("teacher") or self.has_realm_role("lehrer") - - -class KeycloakAuthError(Exception): - """Base exception for Keycloak authentication errors.""" - pass - - -class TokenExpiredError(KeycloakAuthError): - """Token has expired.""" - pass - - -class TokenInvalidError(KeycloakAuthError): - """Token is invalid.""" - pass - - -class KeycloakConfigError(KeycloakAuthError): - """Keycloak configuration error.""" - pass - - class KeycloakAuthenticator: - """ - Validates JWT tokens against Keycloak. 
- - Usage: - config = KeycloakConfig( - server_url="https://keycloak.example.com", - realm="breakpilot", - client_id="breakpilot-backend" - ) - auth = KeycloakAuthenticator(config) - - user = await auth.validate_token(token) - if user.is_teacher(): - # Grant access - """ + """Validates JWT tokens against Keycloak.""" def __init__(self, config: KeycloakConfig): self.config = config @@ -126,64 +40,29 @@ class KeycloakAuthenticator: @property def jwks_client(self) -> PyJWKClient: - """Lazy-load JWKS client.""" if self._jwks_client is None: - self._jwks_client = PyJWKClient( - self.config.jwks_url, - cache_keys=True, - lifespan=3600 # Cache keys for 1 hour - ) + self._jwks_client = PyJWKClient(self.config.jwks_url, cache_keys=True, lifespan=3600) return self._jwks_client async def get_http_client(self) -> httpx.AsyncClient: - """Get or create async HTTP client.""" if self._http_client is None or self._http_client.is_closed: - self._http_client = httpx.AsyncClient( - verify=self.config.verify_ssl, - timeout=30.0 - ) + self._http_client = httpx.AsyncClient(verify=self.config.verify_ssl, timeout=30.0) return self._http_client async def close(self): - """Close HTTP client.""" if self._http_client and not self._http_client.is_closed: await self._http_client.aclose() def validate_token_sync(self, token: str) -> KeycloakUser: - """ - Synchronously validate a JWT token against Keycloak JWKS. 
- - Args: - token: The JWT access token - - Returns: - KeycloakUser with extracted claims - - Raises: - TokenExpiredError: If token has expired - TokenInvalidError: If token signature is invalid - """ + """Synchronously validate a JWT token against Keycloak JWKS.""" try: - # Get signing key from JWKS signing_key = self.jwks_client.get_signing_key_from_jwt(token) - - # Decode and validate token payload = jwt.decode( - token, - signing_key.key, - algorithms=["RS256"], - audience=self.config.client_id, - issuer=self.config.issuer_url, - options={ - "verify_exp": True, - "verify_iat": True, - "verify_aud": True, - "verify_iss": True - } + token, signing_key.key, algorithms=["RS256"], + audience=self.config.client_id, issuer=self.config.issuer_url, + options={"verify_exp": True, "verify_iat": True, "verify_aud": True, "verify_iss": True} ) - return self._extract_user(payload) - except jwt.ExpiredSignatureError: raise TokenExpiredError("Token has expired") except jwt.InvalidAudienceError: @@ -197,27 +76,14 @@ class KeycloakAuthenticator: raise TokenInvalidError(f"Token validation failed: {e}") async def validate_token(self, token: str) -> KeycloakUser: - """ - Asynchronously validate a JWT token. - - Note: JWKS fetching is synchronous due to PyJWKClient limitations, - but this wrapper allows async context usage. - """ + """Asynchronously validate a JWT token.""" return self.validate_token_sync(token) async def get_userinfo(self, token: str) -> Dict[str, Any]: - """ - Fetch user info from Keycloak userinfo endpoint. - - This provides additional user claims not in the access token. 
- """ + """Fetch user info from Keycloak userinfo endpoint.""" client = await self.get_http_client() - try: - response = await client.get( - self.config.userinfo_url, - headers={"Authorization": f"Bearer {token}"} - ) + response = await client.get(self.config.userinfo_url, headers={"Authorization": f"Bearer {token}"}) response.raise_for_status() return response.json() except httpx.HTTPStatusError as e: @@ -227,94 +93,51 @@ class KeycloakAuthenticator: def _extract_user(self, payload: Dict[str, Any]) -> KeycloakUser: """Extract KeycloakUser from JWT payload.""" - - # Extract realm roles realm_access = payload.get("realm_access", {}) realm_roles = realm_access.get("roles", []) - # Extract client roles resource_access = payload.get("resource_access", {}) client_roles = {} for client_id, access in resource_access.items(): client_roles[client_id] = access.get("roles", []) - # Extract groups groups = payload.get("groups", []) - - # Extract custom tenant claim (if configured in Keycloak) tenant_id = payload.get("tenant_id") or payload.get("school_id") return KeycloakUser( - user_id=payload.get("sub", ""), - email=payload.get("email", ""), + user_id=payload.get("sub", ""), email=payload.get("email", ""), email_verified=payload.get("email_verified", False), - name=payload.get("name"), - given_name=payload.get("given_name"), + name=payload.get("name"), given_name=payload.get("given_name"), family_name=payload.get("family_name"), - realm_roles=realm_roles, - client_roles=client_roles, - groups=groups, - tenant_id=tenant_id, - raw_claims=payload + realm_roles=realm_roles, client_roles=client_roles, + groups=groups, tenant_id=tenant_id, raw_claims=payload ) -# ============================================= -# HYBRID AUTH: Keycloak + Local JWT -# ============================================= - class HybridAuthenticator: - """ - Hybrid authenticator supporting both Keycloak and local JWT tokens. 
+ """Hybrid authenticator supporting both Keycloak and local JWT tokens.""" - This allows gradual migration from local JWT to Keycloak: - 1. Development: Use local JWT (fast, no external dependencies) - 2. Production: Use Keycloak for full IAM capabilities - - Token type detection: - - Keycloak tokens: Have 'iss' claim matching Keycloak URL - - Local tokens: Have 'iss' claim as 'breakpilot' or no 'iss' - """ - - def __init__( - self, - keycloak_config: Optional[KeycloakConfig] = None, - local_jwt_secret: Optional[str] = None, - environment: str = "development" - ): + def __init__(self, keycloak_config=None, local_jwt_secret=None, environment="development"): self.environment = environment self.keycloak_enabled = keycloak_config is not None self.local_jwt_secret = local_jwt_secret - - if keycloak_config: - self.keycloak_auth = KeycloakAuthenticator(keycloak_config) - else: - self.keycloak_auth = None + self.keycloak_auth = KeycloakAuthenticator(keycloak_config) if keycloak_config else None async def validate_token(self, token: str) -> Dict[str, Any]: - """ - Validate token using appropriate method. - - Returns a unified user dict compatible with existing code. 
- """ + """Validate token using appropriate method.""" if not token: raise TokenInvalidError("No token provided") - # Try to peek at the token to determine type try: - # Decode without verification to check issuer unverified = jwt.decode(token, options={"verify_signature": False}) issuer = unverified.get("iss", "") except jwt.InvalidTokenError: raise TokenInvalidError("Cannot decode token") - # Check if it's a Keycloak token if self.keycloak_auth and self.keycloak_auth.config.issuer_url in issuer: - # Validate with Keycloak kc_user = await self.keycloak_auth.validate_token(token) return self._keycloak_user_to_dict(kc_user) - # Fall back to local JWT validation if self.local_jwt_secret: return self._validate_local_token(token) @@ -326,13 +149,7 @@ class HybridAuthenticator: raise KeycloakConfigError("Local JWT secret not configured") try: - payload = jwt.decode( - token, - self.local_jwt_secret, - algorithms=["HS256"] - ) - - # Map local token claims to unified format + payload = jwt.decode(token, self.local_jwt_secret, algorithms=["HS256"]) return { "user_id": payload.get("user_id", payload.get("sub", "")), "email": payload.get("email", ""), @@ -349,7 +166,6 @@ class HybridAuthenticator: def _keycloak_user_to_dict(self, user: KeycloakUser) -> Dict[str, Any]: """Convert KeycloakUser to dict compatible with existing code.""" - # Map Keycloak roles to our role system role = "user" if user.is_admin(): role = "admin" @@ -357,20 +173,15 @@ class HybridAuthenticator: role = "teacher" return { - "user_id": user.user_id, - "email": user.email, + "user_id": user.user_id, "email": user.email, "name": user.name or f"{user.given_name or ''} {user.family_name or ''}".strip(), - "role": role, - "realm_roles": user.realm_roles, - "client_roles": user.client_roles, - "groups": user.groups, - "tenant_id": user.tenant_id, - "email_verified": user.email_verified, + "role": role, "realm_roles": user.realm_roles, + "client_roles": user.client_roles, "groups": user.groups, + "tenant_id": 
user.tenant_id, "email_verified": user.email_verified, "auth_method": "keycloak" } async def close(self): - """Cleanup resources.""" if self.keycloak_auth: await self.keycloak_auth.close() @@ -379,57 +190,17 @@ class HybridAuthenticator: # FACTORY FUNCTIONS # ============================================= -def get_keycloak_config_from_env() -> Optional[KeycloakConfig]: - """ - Create KeycloakConfig from environment variables. - - Required env vars: - - KEYCLOAK_SERVER_URL: e.g., https://keycloak.breakpilot.app - - KEYCLOAK_REALM: e.g., breakpilot - - KEYCLOAK_CLIENT_ID: e.g., breakpilot-backend - - Optional: - - KEYCLOAK_CLIENT_SECRET: For confidential clients - - KEYCLOAK_VERIFY_SSL: Default true - """ - server_url = os.environ.get("KEYCLOAK_SERVER_URL") - realm = os.environ.get("KEYCLOAK_REALM") - client_id = os.environ.get("KEYCLOAK_CLIENT_ID") - - if not all([server_url, realm, client_id]): - logger.info("Keycloak not configured, using local JWT only") - return None - - return KeycloakConfig( - server_url=server_url, - realm=realm, - client_id=client_id, - client_secret=os.environ.get("KEYCLOAK_CLIENT_SECRET"), - verify_ssl=os.environ.get("KEYCLOAK_VERIFY_SSL", "true").lower() == "true" - ) - - def get_authenticator() -> HybridAuthenticator: - """ - Get configured authenticator instance. - - Uses environment variables to determine configuration. 
- """ + """Get configured authenticator instance.""" keycloak_config = get_keycloak_config_from_env() - - # JWT_SECRET is required - no default fallback in production jwt_secret = os.environ.get("JWT_SECRET") environment = os.environ.get("ENVIRONMENT", "development") if not jwt_secret and environment == "production": - raise KeycloakConfigError( - "JWT_SECRET environment variable is required in production" - ) + raise KeycloakConfigError("JWT_SECRET environment variable is required in production") return HybridAuthenticator( - keycloak_config=keycloak_config, - local_jwt_secret=jwt_secret, - environment=environment + keycloak_config=keycloak_config, local_jwt_secret=jwt_secret, environment=environment ) @@ -439,7 +210,6 @@ def get_authenticator() -> HybridAuthenticator: from fastapi import Request, HTTPException, Depends -# Global authenticator instance (lazy-initialized) _authenticator: Optional[HybridAuthenticator] = None @@ -452,26 +222,16 @@ def get_auth() -> HybridAuthenticator: async def get_current_user(request: Request) -> Dict[str, Any]: - """ - FastAPI dependency to get current authenticated user. 
- - Usage: - @app.get("/api/protected") - async def protected_endpoint(user: dict = Depends(get_current_user)): - return {"user_id": user["user_id"]} - """ + """FastAPI dependency to get current authenticated user.""" auth_header = request.headers.get("authorization", "") if not auth_header.startswith("Bearer "): - # Check for development mode environment = os.environ.get("ENVIRONMENT", "development") if environment == "development": - # Return demo user in development without token return { "user_id": "10000000-0000-0000-0000-000000000024", "email": "demo@breakpilot.app", - "role": "admin", - "realm_roles": ["admin"], + "role": "admin", "realm_roles": ["admin"], "tenant_id": "a0000000-0000-0000-0000-000000000001", "auth_method": "development_bypass" } @@ -492,24 +252,11 @@ async def get_current_user(request: Request) -> Dict[str, Any]: async def require_role(required_role: str): - """ - FastAPI dependency factory for role-based access. - - Usage: - @app.get("/api/admin-only") - async def admin_endpoint(user: dict = Depends(require_role("admin"))): - return {"message": "Admin access granted"} - """ + """FastAPI dependency factory for role-based access.""" async def role_checker(user: dict = Depends(get_current_user)) -> dict: user_role = user.get("role", "user") realm_roles = user.get("realm_roles", []) - if user_role == required_role or required_role in realm_roles: return user - - raise HTTPException( - status_code=403, - detail=f"Role '{required_role}' required" - ) - + raise HTTPException(status_code=403, detail=f"Role '{required_role}' required") return role_checker diff --git a/backend-lehrer/auth/keycloak_models.py b/backend-lehrer/auth/keycloak_models.py new file mode 100644 index 0000000..31efe9d --- /dev/null +++ b/backend-lehrer/auth/keycloak_models.py @@ -0,0 +1,104 @@ +""" +Keycloak Authentication - Models, Config, and Exceptions. 
+""" + +import os +import logging +from typing import Optional, Dict, Any, List +from dataclasses import dataclass + +logger = logging.getLogger(__name__) + + +@dataclass +class KeycloakConfig: + """Keycloak connection configuration.""" + server_url: str + realm: str + client_id: str + client_secret: Optional[str] = None + verify_ssl: bool = True + + @property + def issuer_url(self) -> str: + return f"{self.server_url}/realms/{self.realm}" + + @property + def jwks_url(self) -> str: + return f"{self.issuer_url}/protocol/openid-connect/certs" + + @property + def token_url(self) -> str: + return f"{self.issuer_url}/protocol/openid-connect/token" + + @property + def userinfo_url(self) -> str: + return f"{self.issuer_url}/protocol/openid-connect/userinfo" + + +@dataclass +class KeycloakUser: + """User information extracted from Keycloak token.""" + user_id: str + email: str + email_verified: bool + name: Optional[str] + given_name: Optional[str] + family_name: Optional[str] + realm_roles: List[str] + client_roles: Dict[str, List[str]] + groups: List[str] + tenant_id: Optional[str] + raw_claims: Dict[str, Any] + + def has_realm_role(self, role: str) -> bool: + return role in self.realm_roles + + def has_client_role(self, client_id: str, role: str) -> bool: + client_roles = self.client_roles.get(client_id, []) + return role in client_roles + + def is_admin(self) -> bool: + return self.has_realm_role("admin") or self.has_realm_role("schul_admin") + + def is_teacher(self) -> bool: + return self.has_realm_role("teacher") or self.has_realm_role("lehrer") + + +class KeycloakAuthError(Exception): + """Base exception for Keycloak authentication errors.""" + pass + + +class TokenExpiredError(KeycloakAuthError): + """Token has expired.""" + pass + + +class TokenInvalidError(KeycloakAuthError): + """Token is invalid.""" + pass + + +class KeycloakConfigError(KeycloakAuthError): + """Keycloak configuration error.""" + pass + + +def get_keycloak_config_from_env() -> 
Optional[KeycloakConfig]: + """Create KeycloakConfig from environment variables.""" + server_url = os.environ.get("KEYCLOAK_SERVER_URL") + realm = os.environ.get("KEYCLOAK_REALM") + client_id = os.environ.get("KEYCLOAK_CLIENT_ID") + + if not all([server_url, realm, client_id]): + logger.info("Keycloak not configured, using local JWT only") + return None + + return KeycloakConfig( + server_url=server_url, + realm=realm, + client_id=client_id, + client_secret=os.environ.get("KEYCLOAK_CLIENT_SECRET"), + verify_ssl=os.environ.get("KEYCLOAK_VERIFY_SSL", "true").lower() == "true" + ) diff --git a/backend-lehrer/classroom/models.py b/backend-lehrer/classroom/models.py index d8332bf..9dd024e 100644 --- a/backend-lehrer/classroom/models.py +++ b/backend-lehrer/classroom/models.py @@ -2,567 +2,68 @@ Classroom API - Pydantic Models Alle Request- und Response-Models fuer die Classroom API. +Barrel re-export aus aufgeteilten Modulen. """ -from typing import Dict, List, Optional, Any -from pydantic import BaseModel, Field - - -# === Session Models === - -class CreateSessionRequest(BaseModel): - """Request zum Erstellen einer neuen Session.""" - teacher_id: str = Field(..., description="ID des Lehrers") - class_id: str = Field(..., description="ID der Klasse") - subject: str = Field(..., description="Unterrichtsfach") - topic: Optional[str] = Field(None, description="Thema der Stunde") - phase_durations: Optional[Dict[str, int]] = Field( - None, - description="Optionale individuelle Phasendauern in Minuten" - ) - - -class NotesRequest(BaseModel): - """Request zum Aktualisieren von Notizen.""" - notes: str = Field("", description="Stundennotizen") - homework: str = Field("", description="Hausaufgaben") - - -class ExtendTimeRequest(BaseModel): - """Request zum Verlaengern der aktuellen Phase (Feature f28).""" - minutes: int = Field(5, ge=1, le=30, description="Zusaetzliche Minuten (1-30)") - - -class PhaseInfo(BaseModel): - """Informationen zu einer Phase.""" - phase: str - 
display_name: str - icon: str - duration_minutes: int - is_completed: bool - is_current: bool - is_future: bool - - -class TimerStatus(BaseModel): - """Timer-Status einer Phase.""" - remaining_seconds: int - remaining_formatted: str - total_seconds: int - total_formatted: str - elapsed_seconds: int - elapsed_formatted: str - percentage_remaining: int - percentage_elapsed: int - percentage: int = Field(description="Alias fuer percentage_remaining (Visual Timer)") - warning: bool - overtime: bool - overtime_seconds: int - overtime_formatted: Optional[str] - is_paused: bool = Field(False, description="Ist der Timer pausiert?") - - -class SuggestionItem(BaseModel): - """Ein Aktivitaets-Vorschlag.""" - id: str - title: str - description: str - activity_type: str - estimated_minutes: int - icon: str - content_url: Optional[str] - - -class SessionResponse(BaseModel): - """Vollstaendige Session-Response.""" - session_id: str - teacher_id: str - class_id: str - subject: str - topic: Optional[str] - current_phase: str - phase_display_name: str - phase_started_at: Optional[str] - lesson_started_at: Optional[str] - lesson_ended_at: Optional[str] - timer: TimerStatus - phases: List[PhaseInfo] - phase_history: List[Dict[str, Any]] - notes: str - homework: str - is_active: bool - is_ended: bool - is_paused: bool = Field(False, description="Ist die Stunde pausiert?") - - -class SuggestionsResponse(BaseModel): - """Response fuer Vorschlaege.""" - suggestions: List[SuggestionItem] - current_phase: str - phase_display_name: str - total_available: int - - -class PhasesListResponse(BaseModel): - """Liste aller verfuegbaren Phasen.""" - phases: List[Dict[str, Any]] - - -class ActiveSessionsResponse(BaseModel): - """Liste aktiver Sessions.""" - sessions: List[Dict[str, Any]] - count: int - - -# === Session History Models === - -class SessionHistoryItem(BaseModel): - """Einzelner Eintrag in der Session-History.""" - session_id: str - teacher_id: str - class_id: str - subject: str - topic: 
Optional[str] - lesson_started_at: Optional[str] - lesson_ended_at: Optional[str] - total_duration_minutes: Optional[int] - phases_completed: int - notes: str - homework: str - - -class SessionHistoryResponse(BaseModel): - """Response fuer Session-History.""" - sessions: List[SessionHistoryItem] - total_count: int - limit: int - offset: int - - -# === Template Models === - -class TemplateCreate(BaseModel): - """Request zum Erstellen einer Vorlage.""" - name: str = Field(..., min_length=1, max_length=200, description="Name der Vorlage") - description: str = Field("", max_length=1000, description="Beschreibung") - subject: str = Field("", max_length=100, description="Fach") - grade_level: str = Field("", max_length=50, description="Klassenstufe (z.B. '7', '10')") - phase_durations: Optional[Dict[str, int]] = Field( - None, - description="Phasendauern in Minuten" - ) - default_topic: str = Field("", max_length=500, description="Vorausgefuelltes Thema") - default_notes: str = Field("", description="Vorausgefuellte Notizen") - is_public: bool = Field(False, description="Vorlage fuer alle sichtbar?") - - -class TemplateUpdate(BaseModel): - """Request zum Aktualisieren einer Vorlage.""" - name: Optional[str] = Field(None, min_length=1, max_length=200) - description: Optional[str] = Field(None, max_length=1000) - subject: Optional[str] = Field(None, max_length=100) - grade_level: Optional[str] = Field(None, max_length=50) - phase_durations: Optional[Dict[str, int]] = None - default_topic: Optional[str] = Field(None, max_length=500) - default_notes: Optional[str] = None - is_public: Optional[bool] = None - - -class TemplateResponse(BaseModel): - """Response fuer eine einzelne Vorlage.""" - template_id: str - teacher_id: str - name: str - description: str - subject: str - grade_level: str - phase_durations: Dict[str, int] - default_topic: str - default_notes: str - is_public: bool - usage_count: int - total_duration_minutes: int - created_at: Optional[str] - updated_at: 
Optional[str] - is_system_template: bool = False - - -class TemplateListResponse(BaseModel): - """Response fuer Template-Liste.""" - templates: List[TemplateResponse] - total_count: int - - -# === Homework Models === - -class CreateHomeworkRequest(BaseModel): - """Request zum Erstellen einer Hausaufgabe.""" - teacher_id: str - class_id: str - subject: str - title: str = Field(..., max_length=300) - description: str = "" - session_id: Optional[str] = None - due_date: Optional[str] = Field(None, description="ISO-Format Datum") - - -class UpdateHomeworkRequest(BaseModel): - """Request zum Aktualisieren einer Hausaufgabe.""" - title: Optional[str] = Field(None, max_length=300) - description: Optional[str] = None - due_date: Optional[str] = Field(None, description="ISO-Format Datum") - status: Optional[str] = Field(None, description="assigned, in_progress, completed") - - -class HomeworkResponse(BaseModel): - """Response fuer eine Hausaufgabe.""" - homework_id: str - teacher_id: str - class_id: str - subject: str - title: str - description: str - session_id: Optional[str] - due_date: Optional[str] - status: str - is_overdue: bool - created_at: Optional[str] - updated_at: Optional[str] - - -class HomeworkListResponse(BaseModel): - """Response fuer Liste von Hausaufgaben.""" - homework: List[HomeworkResponse] - total: int - - -# === Material Models === - -class CreateMaterialRequest(BaseModel): - """Request zum Erstellen eines Materials.""" - teacher_id: str - title: str = Field(..., max_length=300) - material_type: str = Field("document", description="document, link, video, image, worksheet, presentation, other") - url: Optional[str] = Field(None, max_length=2000) - description: str = "" - phase: Optional[str] = Field(None, description="einstieg, erarbeitung, sicherung, transfer, reflexion") - subject: str = "" - grade_level: str = "" - tags: List[str] = [] - is_public: bool = False - session_id: Optional[str] = None - - -class UpdateMaterialRequest(BaseModel): - 
"""Request zum Aktualisieren eines Materials.""" - title: Optional[str] = Field(None, max_length=300) - material_type: Optional[str] = None - url: Optional[str] = Field(None, max_length=2000) - description: Optional[str] = None - phase: Optional[str] = None - subject: Optional[str] = None - grade_level: Optional[str] = None - tags: Optional[List[str]] = None - is_public: Optional[bool] = None - - -class MaterialResponse(BaseModel): - """Response fuer ein Material.""" - material_id: str - teacher_id: str - title: str - material_type: str - url: Optional[str] - description: str - phase: Optional[str] - subject: str - grade_level: str - tags: List[str] - is_public: bool - usage_count: int - session_id: Optional[str] - created_at: Optional[str] - updated_at: Optional[str] - - -class MaterialListResponse(BaseModel): - """Response fuer Liste von Materialien.""" - materials: List[MaterialResponse] - total: int - - -# === Analytics Models === - -class SessionSummaryResponse(BaseModel): - """Response fuer Session-Summary.""" - session_id: str - teacher_id: str - class_id: str - subject: str - topic: Optional[str] - date: Optional[str] - date_formatted: str - total_duration_seconds: int - total_duration_formatted: str - planned_duration_seconds: int - planned_duration_formatted: str - phases_completed: int - total_phases: int - completion_percentage: int - phase_statistics: List[Dict[str, Any]] - total_overtime_seconds: int - total_overtime_formatted: str - phases_with_overtime: int - total_pause_count: int - total_pause_seconds: int - reflection_notes: str = "" - reflection_rating: Optional[int] = None - key_learnings: List[str] = [] - - -class TeacherAnalyticsResponse(BaseModel): - """Response fuer Lehrer-Analytics.""" - teacher_id: str - period_start: Optional[str] - period_end: Optional[str] - total_sessions: int - completed_sessions: int - total_teaching_minutes: int - total_teaching_hours: float - avg_phase_durations: Dict[str, int] - sessions_with_overtime: int - 
overtime_percentage: int - avg_overtime_seconds: int - avg_overtime_formatted: str - most_overtime_phase: Optional[str] - avg_pause_count: float - avg_pause_duration_seconds: int - subjects_taught: Dict[str, int] - classes_taught: Dict[str, int] - - -class ReflectionCreate(BaseModel): - """Request-Body fuer Reflection-Erstellung.""" - session_id: str - teacher_id: str - notes: str = "" - overall_rating: Optional[int] = Field(None, ge=1, le=5) - what_worked: List[str] = [] - improvements: List[str] = [] - notes_for_next_lesson: str = "" - - -class ReflectionUpdate(BaseModel): - """Request-Body fuer Reflection-Update.""" - notes: Optional[str] = None - overall_rating: Optional[int] = Field(None, ge=1, le=5) - what_worked: Optional[List[str]] = None - improvements: Optional[List[str]] = None - notes_for_next_lesson: Optional[str] = None - - -class ReflectionResponse(BaseModel): - """Response fuer eine einzelne Reflection.""" - reflection_id: str - session_id: str - teacher_id: str - notes: str - overall_rating: Optional[int] - what_worked: List[str] - improvements: List[str] - notes_for_next_lesson: str - created_at: Optional[str] - updated_at: Optional[str] - - -# === Feedback Models === - -class FeedbackCreate(BaseModel): - """Request zum Erstellen von Feedback.""" - title: str = Field(..., min_length=3, max_length=500, description="Kurzer Titel") - description: str = Field(..., min_length=10, description="Beschreibung") - feedback_type: str = Field("improvement", description="bug, feature_request, improvement, praise, question") - priority: str = Field("medium", description="critical, high, medium, low") - teacher_name: str = Field("", description="Name des Lehrers") - teacher_email: str = Field("", description="E-Mail fuer Rueckfragen") - context_url: str = Field("", description="URL wo Feedback gegeben wurde") - context_phase: str = Field("", description="Aktuelle Phase") - context_session_id: Optional[str] = Field(None, description="Session-ID falls aktiv") - 
related_feature: Optional[str] = Field(None, description="Verwandtes Feature") - - -class FeedbackResponse(BaseModel): - """Response fuer Feedback.""" - id: str - teacher_id: str - teacher_name: str - title: str - description: str - feedback_type: str - priority: str - status: str - created_at: str - response: Optional[str] = None - - -class FeedbackListResponse(BaseModel): - """Liste von Feedbacks.""" - feedbacks: List[Dict[str, Any]] - total: int - - -class FeedbackStatsResponse(BaseModel): - """Feedback-Statistiken.""" - total: int - by_status: Dict[str, int] - by_type: Dict[str, int] - by_priority: Dict[str, int] - - -# === Settings Models === - -class TeacherSettingsResponse(BaseModel): - """Response fuer Lehrer-Einstellungen.""" - teacher_id: str - default_phase_durations: Dict[str, int] - audio_enabled: bool = True - high_contrast: bool = False - show_statistics: bool = True - - -class UpdatePhaseDurationsRequest(BaseModel): - """Request zum Aktualisieren der Phasen-Dauern.""" - durations: Dict[str, int] = Field( - ..., - description="Phasen-Dauern in Minuten, z.B. 
{'einstieg': 10, 'erarbeitung': 25}", - examples=[{"einstieg": 10, "erarbeitung": 25, "sicherung": 10, "transfer": 8, "reflexion": 5}] - ) - - -class UpdatePreferencesRequest(BaseModel): - """Request zum Aktualisieren der UI-Praeferenzen.""" - audio_enabled: Optional[bool] = None - high_contrast: Optional[bool] = None - show_statistics: Optional[bool] = None - - -# === Context Models === - -class SchoolInfo(BaseModel): - """Schul-Informationen.""" - federal_state: str - federal_state_name: str = "" - school_type: str - school_type_name: str = "" - - -class SchoolYearInfo(BaseModel): - """Schuljahr-Informationen.""" - id: str - start: Optional[str] = None - current_week: int = 1 - - -class MacroPhaseInfo(BaseModel): - """Makro-Phase Informationen.""" - id: str - label: str - confidence: float = 1.0 - - -class CoreCounts(BaseModel): - """Kern-Zaehler fuer den Kontext.""" - classes: int = 0 - exams_scheduled: int = 0 - corrections_pending: int = 0 - - -class ContextFlags(BaseModel): - """Status-Flags des Kontexts.""" - onboarding_completed: bool = False - has_classes: bool = False - has_schedule: bool = False - is_exam_period: bool = False - is_before_holidays: bool = False - - -class TeacherContextResponse(BaseModel): - """Response fuer GET /v1/context.""" - schema_version: str = "1.0" - teacher_id: str - school: SchoolInfo - school_year: SchoolYearInfo - macro_phase: MacroPhaseInfo - core_counts: CoreCounts - flags: ContextFlags - - -class UpdateContextRequest(BaseModel): - """Request zum Aktualisieren des Kontexts.""" - federal_state: Optional[str] = None - school_type: Optional[str] = None - schoolyear: Optional[str] = None - schoolyear_start: Optional[str] = None - macro_phase: Optional[str] = None - current_week: Optional[int] = None - - -# === Event Models === - -class CreateEventRequest(BaseModel): - """Request zum Erstellen eines Events.""" - title: str - event_type: str = "other" - start_date: str - end_date: Optional[str] = None - class_id: Optional[str] = 
None - subject: Optional[str] = None - description: str = "" - needs_preparation: bool = True - reminder_days_before: int = 7 - - -class EventResponse(BaseModel): - """Response fuer ein Event.""" - id: str - teacher_id: str - event_type: str - title: str - description: str - start_date: str - end_date: Optional[str] - class_id: Optional[str] - subject: Optional[str] - status: str - needs_preparation: bool - preparation_done: bool - reminder_days_before: int - - -# === Routine Models === - -class CreateRoutineRequest(BaseModel): - """Request zum Erstellen einer Routine.""" - title: str - routine_type: str = "other" - recurrence_pattern: str = "weekly" - day_of_week: Optional[int] = None - day_of_month: Optional[int] = None - time_of_day: Optional[str] = None - duration_minutes: int = 60 - description: str = "" - - -class RoutineResponse(BaseModel): - """Response fuer eine Routine.""" - id: str - teacher_id: str - routine_type: str - title: str - description: str - recurrence_pattern: str - day_of_week: Optional[int] - day_of_month: Optional[int] - time_of_day: Optional[str] - duration_minutes: int - is_active: bool +# Session & Phase Models +from .models_session import ( + CreateSessionRequest, + NotesRequest, + ExtendTimeRequest, + PhaseInfo, + TimerStatus, + SuggestionItem, + SessionResponse, + SuggestionsResponse, + PhasesListResponse, + ActiveSessionsResponse, + SessionHistoryItem, + SessionHistoryResponse, +) + +# Template, Homework, Material Models +from .models_templates import ( + TemplateCreate, + TemplateUpdate, + TemplateResponse, + TemplateListResponse, + CreateHomeworkRequest, + UpdateHomeworkRequest, + HomeworkResponse, + HomeworkListResponse, + CreateMaterialRequest, + UpdateMaterialRequest, + MaterialResponse, + MaterialListResponse, +) + +# Analytics, Reflection, Feedback, Settings Models +from .models_analytics import ( + SessionSummaryResponse, + TeacherAnalyticsResponse, + ReflectionCreate, + ReflectionUpdate, + ReflectionResponse, + 
FeedbackCreate, + FeedbackResponse, + FeedbackListResponse, + FeedbackStatsResponse, + TeacherSettingsResponse, + UpdatePhaseDurationsRequest, + UpdatePreferencesRequest, +) + +# Context, Event, Routine Models +from .models_context import ( + SchoolInfo, + SchoolYearInfo, + MacroPhaseInfo, + CoreCounts, + ContextFlags, + TeacherContextResponse, + UpdateContextRequest, + CreateEventRequest, + EventResponse, + CreateRoutineRequest, + RoutineResponse, +) diff --git a/backend-lehrer/classroom/models_analytics.py b/backend-lehrer/classroom/models_analytics.py new file mode 100644 index 0000000..6c1673d --- /dev/null +++ b/backend-lehrer/classroom/models_analytics.py @@ -0,0 +1,161 @@ +""" +Classroom API - Analytics, Reflection, Feedback, Settings Pydantic Models. +""" + +from typing import Dict, List, Optional, Any +from pydantic import BaseModel, Field + + +# === Analytics Models === + +class SessionSummaryResponse(BaseModel): + """Response fuer Session-Summary.""" + session_id: str + teacher_id: str + class_id: str + subject: str + topic: Optional[str] + date: Optional[str] + date_formatted: str + total_duration_seconds: int + total_duration_formatted: str + planned_duration_seconds: int + planned_duration_formatted: str + phases_completed: int + total_phases: int + completion_percentage: int + phase_statistics: List[Dict[str, Any]] + total_overtime_seconds: int + total_overtime_formatted: str + phases_with_overtime: int + total_pause_count: int + total_pause_seconds: int + reflection_notes: str = "" + reflection_rating: Optional[int] = None + key_learnings: List[str] = [] + + +class TeacherAnalyticsResponse(BaseModel): + """Response fuer Lehrer-Analytics.""" + teacher_id: str + period_start: Optional[str] + period_end: Optional[str] + total_sessions: int + completed_sessions: int + total_teaching_minutes: int + total_teaching_hours: float + avg_phase_durations: Dict[str, int] + sessions_with_overtime: int + overtime_percentage: int + avg_overtime_seconds: int + 
avg_overtime_formatted: str + most_overtime_phase: Optional[str] + avg_pause_count: float + avg_pause_duration_seconds: int + subjects_taught: Dict[str, int] + classes_taught: Dict[str, int] + + +class ReflectionCreate(BaseModel): + """Request-Body fuer Reflection-Erstellung.""" + session_id: str + teacher_id: str + notes: str = "" + overall_rating: Optional[int] = Field(None, ge=1, le=5) + what_worked: List[str] = [] + improvements: List[str] = [] + notes_for_next_lesson: str = "" + + +class ReflectionUpdate(BaseModel): + """Request-Body fuer Reflection-Update.""" + notes: Optional[str] = None + overall_rating: Optional[int] = Field(None, ge=1, le=5) + what_worked: Optional[List[str]] = None + improvements: Optional[List[str]] = None + notes_for_next_lesson: Optional[str] = None + + +class ReflectionResponse(BaseModel): + """Response fuer eine einzelne Reflection.""" + reflection_id: str + session_id: str + teacher_id: str + notes: str + overall_rating: Optional[int] + what_worked: List[str] + improvements: List[str] + notes_for_next_lesson: str + created_at: Optional[str] + updated_at: Optional[str] + + +# === Feedback Models === + +class FeedbackCreate(BaseModel): + """Request zum Erstellen von Feedback.""" + title: str = Field(..., min_length=3, max_length=500, description="Kurzer Titel") + description: str = Field(..., min_length=10, description="Beschreibung") + feedback_type: str = Field("improvement", description="bug, feature_request, improvement, praise, question") + priority: str = Field("medium", description="critical, high, medium, low") + teacher_name: str = Field("", description="Name des Lehrers") + teacher_email: str = Field("", description="E-Mail fuer Rueckfragen") + context_url: str = Field("", description="URL wo Feedback gegeben wurde") + context_phase: str = Field("", description="Aktuelle Phase") + context_session_id: Optional[str] = Field(None, description="Session-ID falls aktiv") + related_feature: Optional[str] = Field(None, 
description="Verwandtes Feature") + + +class FeedbackResponse(BaseModel): + """Response fuer Feedback.""" + id: str + teacher_id: str + teacher_name: str + title: str + description: str + feedback_type: str + priority: str + status: str + created_at: str + response: Optional[str] = None + + +class FeedbackListResponse(BaseModel): + """Liste von Feedbacks.""" + feedbacks: List[Dict[str, Any]] + total: int + + +class FeedbackStatsResponse(BaseModel): + """Feedback-Statistiken.""" + total: int + by_status: Dict[str, int] + by_type: Dict[str, int] + by_priority: Dict[str, int] + + +# === Settings Models === + +class TeacherSettingsResponse(BaseModel): + """Response fuer Lehrer-Einstellungen.""" + teacher_id: str + default_phase_durations: Dict[str, int] + audio_enabled: bool = True + high_contrast: bool = False + show_statistics: bool = True + + +class UpdatePhaseDurationsRequest(BaseModel): + """Request zum Aktualisieren der Phasen-Dauern.""" + durations: Dict[str, int] = Field( + ..., + description="Phasen-Dauern in Minuten, z.B. {'einstieg': 10, 'erarbeitung': 25}", + examples=[{"einstieg": 10, "erarbeitung": 25, "sicherung": 10, "transfer": 8, "reflexion": 5}] + ) + + +class UpdatePreferencesRequest(BaseModel): + """Request zum Aktualisieren der UI-Praeferenzen.""" + audio_enabled: Optional[bool] = None + high_contrast: Optional[bool] = None + show_statistics: Optional[bool] = None diff --git a/backend-lehrer/classroom/models_context.py b/backend-lehrer/classroom/models_context.py new file mode 100644 index 0000000..4f743fc --- /dev/null +++ b/backend-lehrer/classroom/models_context.py @@ -0,0 +1,128 @@ +""" +Classroom API - Context, Event, Routine Pydantic Models. 
+""" + +from typing import Optional +from pydantic import BaseModel, Field + + +# === Context Models === + +class SchoolInfo(BaseModel): + """Schul-Informationen.""" + federal_state: str + federal_state_name: str = "" + school_type: str + school_type_name: str = "" + + +class SchoolYearInfo(BaseModel): + """Schuljahr-Informationen.""" + id: str + start: Optional[str] = None + current_week: int = 1 + + +class MacroPhaseInfo(BaseModel): + """Makro-Phase Informationen.""" + id: str + label: str + confidence: float = 1.0 + + +class CoreCounts(BaseModel): + """Kern-Zaehler fuer den Kontext.""" + classes: int = 0 + exams_scheduled: int = 0 + corrections_pending: int = 0 + + +class ContextFlags(BaseModel): + """Status-Flags des Kontexts.""" + onboarding_completed: bool = False + has_classes: bool = False + has_schedule: bool = False + is_exam_period: bool = False + is_before_holidays: bool = False + + +class TeacherContextResponse(BaseModel): + """Response fuer GET /v1/context.""" + schema_version: str = "1.0" + teacher_id: str + school: SchoolInfo + school_year: SchoolYearInfo + macro_phase: MacroPhaseInfo + core_counts: CoreCounts + flags: ContextFlags + + +class UpdateContextRequest(BaseModel): + """Request zum Aktualisieren des Kontexts.""" + federal_state: Optional[str] = None + school_type: Optional[str] = None + schoolyear: Optional[str] = None + schoolyear_start: Optional[str] = None + macro_phase: Optional[str] = None + current_week: Optional[int] = None + + +# === Event Models === + +class CreateEventRequest(BaseModel): + """Request zum Erstellen eines Events.""" + title: str + event_type: str = "other" + start_date: str + end_date: Optional[str] = None + class_id: Optional[str] = None + subject: Optional[str] = None + description: str = "" + needs_preparation: bool = True + reminder_days_before: int = 7 + + +class EventResponse(BaseModel): + """Response fuer ein Event.""" + id: str + teacher_id: str + event_type: str + title: str + description: str + 
start_date: str + end_date: Optional[str] + class_id: Optional[str] + subject: Optional[str] + status: str + needs_preparation: bool + preparation_done: bool + reminder_days_before: int + + +# === Routine Models === + +class CreateRoutineRequest(BaseModel): + """Request zum Erstellen einer Routine.""" + title: str + routine_type: str = "other" + recurrence_pattern: str = "weekly" + day_of_week: Optional[int] = None + day_of_month: Optional[int] = None + time_of_day: Optional[str] = None + duration_minutes: int = 60 + description: str = "" + + +class RoutineResponse(BaseModel): + """Response fuer eine Routine.""" + id: str + teacher_id: str + routine_type: str + title: str + description: str + recurrence_pattern: str + day_of_week: Optional[int] + day_of_month: Optional[int] + time_of_day: Optional[str] + duration_minutes: int + is_active: bool diff --git a/backend-lehrer/classroom/models_session.py b/backend-lehrer/classroom/models_session.py new file mode 100644 index 0000000..938256b --- /dev/null +++ b/backend-lehrer/classroom/models_session.py @@ -0,0 +1,137 @@ +""" +Classroom API - Session & Phase Pydantic Models. 
+""" + +from typing import Dict, List, Optional, Any +from pydantic import BaseModel, Field + + +# === Session Models === + +class CreateSessionRequest(BaseModel): + """Request zum Erstellen einer neuen Session.""" + teacher_id: str = Field(..., description="ID des Lehrers") + class_id: str = Field(..., description="ID der Klasse") + subject: str = Field(..., description="Unterrichtsfach") + topic: Optional[str] = Field(None, description="Thema der Stunde") + phase_durations: Optional[Dict[str, int]] = Field( + None, + description="Optionale individuelle Phasendauern in Minuten" + ) + + +class NotesRequest(BaseModel): + """Request zum Aktualisieren von Notizen.""" + notes: str = Field("", description="Stundennotizen") + homework: str = Field("", description="Hausaufgaben") + + +class ExtendTimeRequest(BaseModel): + """Request zum Verlaengern der aktuellen Phase (Feature f28).""" + minutes: int = Field(5, ge=1, le=30, description="Zusaetzliche Minuten (1-30)") + + +class PhaseInfo(BaseModel): + """Informationen zu einer Phase.""" + phase: str + display_name: str + icon: str + duration_minutes: int + is_completed: bool + is_current: bool + is_future: bool + + +class TimerStatus(BaseModel): + """Timer-Status einer Phase.""" + remaining_seconds: int + remaining_formatted: str + total_seconds: int + total_formatted: str + elapsed_seconds: int + elapsed_formatted: str + percentage_remaining: int + percentage_elapsed: int + percentage: int = Field(description="Alias fuer percentage_remaining (Visual Timer)") + warning: bool + overtime: bool + overtime_seconds: int + overtime_formatted: Optional[str] + is_paused: bool = Field(False, description="Ist der Timer pausiert?") + + +class SuggestionItem(BaseModel): + """Ein Aktivitaets-Vorschlag.""" + id: str + title: str + description: str + activity_type: str + estimated_minutes: int + icon: str + content_url: Optional[str] + + +class SessionResponse(BaseModel): + """Vollstaendige Session-Response.""" + session_id: str + 
teacher_id: str + class_id: str + subject: str + topic: Optional[str] + current_phase: str + phase_display_name: str + phase_started_at: Optional[str] + lesson_started_at: Optional[str] + lesson_ended_at: Optional[str] + timer: TimerStatus + phases: List[PhaseInfo] + phase_history: List[Dict[str, Any]] + notes: str + homework: str + is_active: bool + is_ended: bool + is_paused: bool = Field(False, description="Ist die Stunde pausiert?") + + +class SuggestionsResponse(BaseModel): + """Response fuer Vorschlaege.""" + suggestions: List[SuggestionItem] + current_phase: str + phase_display_name: str + total_available: int + + +class PhasesListResponse(BaseModel): + """Liste aller verfuegbaren Phasen.""" + phases: List[Dict[str, Any]] + + +class ActiveSessionsResponse(BaseModel): + """Liste aktiver Sessions.""" + sessions: List[Dict[str, Any]] + count: int + + +# === Session History Models === + +class SessionHistoryItem(BaseModel): + """Einzelner Eintrag in der Session-History.""" + session_id: str + teacher_id: str + class_id: str + subject: str + topic: Optional[str] + lesson_started_at: Optional[str] + lesson_ended_at: Optional[str] + total_duration_minutes: Optional[int] + phases_completed: int + notes: str + homework: str + + +class SessionHistoryResponse(BaseModel): + """Response fuer Session-History.""" + sessions: List[SessionHistoryItem] + total_count: int + limit: int + offset: int diff --git a/backend-lehrer/classroom/models_templates.py b/backend-lehrer/classroom/models_templates.py new file mode 100644 index 0000000..64f188b --- /dev/null +++ b/backend-lehrer/classroom/models_templates.py @@ -0,0 +1,158 @@ +""" +Classroom API - Template, Homework, Material Pydantic Models. 
+""" + +from typing import Dict, List, Optional +from pydantic import BaseModel, Field + + +# === Template Models === + +class TemplateCreate(BaseModel): + """Request zum Erstellen einer Vorlage.""" + name: str = Field(..., min_length=1, max_length=200, description="Name der Vorlage") + description: str = Field("", max_length=1000, description="Beschreibung") + subject: str = Field("", max_length=100, description="Fach") + grade_level: str = Field("", max_length=50, description="Klassenstufe (z.B. '7', '10')") + phase_durations: Optional[Dict[str, int]] = Field( + None, + description="Phasendauern in Minuten" + ) + default_topic: str = Field("", max_length=500, description="Vorausgefuelltes Thema") + default_notes: str = Field("", description="Vorausgefuellte Notizen") + is_public: bool = Field(False, description="Vorlage fuer alle sichtbar?") + + +class TemplateUpdate(BaseModel): + """Request zum Aktualisieren einer Vorlage.""" + name: Optional[str] = Field(None, min_length=1, max_length=200) + description: Optional[str] = Field(None, max_length=1000) + subject: Optional[str] = Field(None, max_length=100) + grade_level: Optional[str] = Field(None, max_length=50) + phase_durations: Optional[Dict[str, int]] = None + default_topic: Optional[str] = Field(None, max_length=500) + default_notes: Optional[str] = None + is_public: Optional[bool] = None + + +class TemplateResponse(BaseModel): + """Response fuer eine einzelne Vorlage.""" + template_id: str + teacher_id: str + name: str + description: str + subject: str + grade_level: str + phase_durations: Dict[str, int] + default_topic: str + default_notes: str + is_public: bool + usage_count: int + total_duration_minutes: int + created_at: Optional[str] + updated_at: Optional[str] + is_system_template: bool = False + + +class TemplateListResponse(BaseModel): + """Response fuer Template-Liste.""" + templates: List[TemplateResponse] + total_count: int + + +# === Homework Models === + +class CreateHomeworkRequest(BaseModel): 
+ """Request zum Erstellen einer Hausaufgabe.""" + teacher_id: str + class_id: str + subject: str + title: str = Field(..., max_length=300) + description: str = "" + session_id: Optional[str] = None + due_date: Optional[str] = Field(None, description="ISO-Format Datum") + + +class UpdateHomeworkRequest(BaseModel): + """Request zum Aktualisieren einer Hausaufgabe.""" + title: Optional[str] = Field(None, max_length=300) + description: Optional[str] = None + due_date: Optional[str] = Field(None, description="ISO-Format Datum") + status: Optional[str] = Field(None, description="assigned, in_progress, completed") + + +class HomeworkResponse(BaseModel): + """Response fuer eine Hausaufgabe.""" + homework_id: str + teacher_id: str + class_id: str + subject: str + title: str + description: str + session_id: Optional[str] + due_date: Optional[str] + status: str + is_overdue: bool + created_at: Optional[str] + updated_at: Optional[str] + + +class HomeworkListResponse(BaseModel): + """Response fuer Liste von Hausaufgaben.""" + homework: List[HomeworkResponse] + total: int + + +# === Material Models === + +class CreateMaterialRequest(BaseModel): + """Request zum Erstellen eines Materials.""" + teacher_id: str + title: str = Field(..., max_length=300) + material_type: str = Field("document", description="document, link, video, image, worksheet, presentation, other") + url: Optional[str] = Field(None, max_length=2000) + description: str = "" + phase: Optional[str] = Field(None, description="einstieg, erarbeitung, sicherung, transfer, reflexion") + subject: str = "" + grade_level: str = "" + tags: List[str] = [] + is_public: bool = False + session_id: Optional[str] = None + + +class UpdateMaterialRequest(BaseModel): + """Request zum Aktualisieren eines Materials.""" + title: Optional[str] = Field(None, max_length=300) + material_type: Optional[str] = None + url: Optional[str] = Field(None, max_length=2000) + description: Optional[str] = None + phase: Optional[str] = None + 
subject: Optional[str] = None + grade_level: Optional[str] = None + tags: Optional[List[str]] = None + is_public: Optional[bool] = None + + +class MaterialResponse(BaseModel): + """Response fuer ein Material.""" + material_id: str + teacher_id: str + title: str + material_type: str + url: Optional[str] + description: str + phase: Optional[str] + subject: str + grade_level: str + tags: List[str] + is_public: bool + usage_count: int + session_id: Optional[str] + created_at: Optional[str] + updated_at: Optional[str] + + +class MaterialListResponse(BaseModel): + """Response fuer Liste von Materialien.""" + materials: List[MaterialResponse] + total: int diff --git a/backend-lehrer/classroom/routes/sessions.py b/backend-lehrer/classroom/routes/sessions.py index ef3286a..55e26c6 100644 --- a/backend-lehrer/classroom/routes/sessions.py +++ b/backend-lehrer/classroom/routes/sessions.py @@ -1,525 +1,17 @@ """ -Classroom API - Session Routes +Classroom API - Session Routes (barrel re-export) -Session management endpoints: create, get, start, next-phase, end, etc. +Combines core session routes and action routes into a single router. 
""" -from uuid import uuid4 -from typing import Dict, Optional, Any -from datetime import datetime -import logging +from fastapi import APIRouter -from fastapi import APIRouter, HTTPException, Query +from .sessions_core import router as core_router, build_session_response +from .sessions_actions import router as actions_router -from classroom_engine import ( - LessonPhase, - LessonSession, - LessonStateMachine, - PhaseTimer, - SuggestionEngine, - LESSON_PHASES, -) +router = APIRouter() +router.include_router(core_router) +router.include_router(actions_router) -from ..models import ( - CreateSessionRequest, - NotesRequest, - ExtendTimeRequest, - PhaseInfo, - TimerStatus, - SuggestionItem, - SessionResponse, - SuggestionsResponse, - PhasesListResponse, - ActiveSessionsResponse, - SessionHistoryItem, - SessionHistoryResponse, -) -from ..services.persistence import ( - sessions, - init_db_if_needed, - persist_session, - get_session_or_404, - DB_ENABLED, - SessionLocal, -) -from ..websocket_manager import notify_phase_change, notify_session_ended - -logger = logging.getLogger(__name__) - -router = APIRouter(tags=["Sessions"]) - - -def build_session_response(session: LessonSession) -> SessionResponse: - """Baut die vollstaendige Session-Response.""" - fsm = LessonStateMachine() - timer = PhaseTimer() - - timer_status = timer.get_phase_status(session) - phases_info = fsm.get_phases_info(session) - - return SessionResponse( - session_id=session.session_id, - teacher_id=session.teacher_id, - class_id=session.class_id, - subject=session.subject, - topic=session.topic, - current_phase=session.current_phase.value, - phase_display_name=session.get_phase_display_name(), - phase_started_at=session.phase_started_at.isoformat() if session.phase_started_at else None, - lesson_started_at=session.lesson_started_at.isoformat() if session.lesson_started_at else None, - lesson_ended_at=session.lesson_ended_at.isoformat() if session.lesson_ended_at else None, - 
timer=TimerStatus(**timer_status), - phases=[PhaseInfo(**p) for p in phases_info], - phase_history=session.phase_history, - notes=session.notes, - homework=session.homework, - is_active=fsm.is_lesson_active(session), - is_ended=fsm.is_lesson_ended(session), - is_paused=session.is_paused, - ) - - -# === Session CRUD Endpoints === - -@router.post("/sessions", response_model=SessionResponse) -async def create_session(request: CreateSessionRequest) -> SessionResponse: - """ - Erstellt eine neue Unterrichtsstunde (Session). - - Die Stunde ist nach Erstellung im Status NOT_STARTED. - Zum Starten muss /sessions/{id}/start aufgerufen werden. - """ - init_db_if_needed() - - # Default-Dauern mit uebergebenen Werten mergen - phase_durations = { - "einstieg": 8, - "erarbeitung": 20, - "sicherung": 10, - "transfer": 7, - "reflexion": 5, - } - if request.phase_durations: - phase_durations.update(request.phase_durations) - - session = LessonSession( - session_id=str(uuid4()), - teacher_id=request.teacher_id, - class_id=request.class_id, - subject=request.subject, - topic=request.topic, - phase_durations=phase_durations, - ) - - sessions[session.session_id] = session - persist_session(session) - return build_session_response(session) - - -@router.get("/sessions/{session_id}", response_model=SessionResponse) -async def get_session(session_id: str) -> SessionResponse: - """ - Ruft den aktuellen Status einer Session ab. - - Enthaelt alle Informationen inkl. Timer-Status und Phasen-Timeline. - """ - session = get_session_or_404(session_id) - return build_session_response(session) - - -@router.post("/sessions/{session_id}/start", response_model=SessionResponse) -async def start_lesson(session_id: str) -> SessionResponse: - """ - Startet die Unterrichtsstunde. - - Wechselt von NOT_STARTED zur ersten Phase (EINSTIEG). 
- """ - session = get_session_or_404(session_id) - - if session.current_phase != LessonPhase.NOT_STARTED: - raise HTTPException( - status_code=400, - detail=f"Stunde bereits gestartet (aktuelle Phase: {session.current_phase.value})" - ) - - fsm = LessonStateMachine() - session = fsm.transition(session, LessonPhase.EINSTIEG) - - persist_session(session) - return build_session_response(session) - - -@router.post("/sessions/{session_id}/next-phase", response_model=SessionResponse) -async def next_phase(session_id: str) -> SessionResponse: - """ - Wechselt zur naechsten Phase. - - Wirft 400 wenn keine naechste Phase verfuegbar (z.B. bei ENDED). - """ - session = get_session_or_404(session_id) - - fsm = LessonStateMachine() - next_p = fsm.next_phase(session.current_phase) - - if not next_p: - raise HTTPException( - status_code=400, - detail=f"Keine naechste Phase verfuegbar (aktuelle Phase: {session.current_phase.value})" - ) - - session = fsm.transition(session, next_p) - persist_session(session) - - # WebSocket-Benachrichtigung - response = build_session_response(session) - await notify_phase_change(session_id, session.current_phase.value, { - "phase_display_name": session.get_phase_display_name(), - "is_ended": session.current_phase == LessonPhase.ENDED - }) - return response - - -@router.post("/sessions/{session_id}/end", response_model=SessionResponse) -async def end_lesson(session_id: str) -> SessionResponse: - """ - Beendet die Unterrichtsstunde sofort. - - Kann von jeder aktiven Phase aus aufgerufen werden. - """ - session = get_session_or_404(session_id) - - if session.current_phase == LessonPhase.ENDED: - raise HTTPException(status_code=400, detail="Stunde bereits beendet") - - if session.current_phase == LessonPhase.NOT_STARTED: - raise HTTPException(status_code=400, detail="Stunde noch nicht gestartet") - - # Direkt zur Endphase springen (ueberspringt evtl. 
Phasen) - fsm = LessonStateMachine() - - # Phasen bis zum Ende durchlaufen - while session.current_phase != LessonPhase.ENDED: - next_p = fsm.next_phase(session.current_phase) - if next_p: - session = fsm.transition(session, next_p) - else: - break - - persist_session(session) - - # WebSocket-Benachrichtigung - await notify_session_ended(session_id) - return build_session_response(session) - - -# === Quick Actions (Feature f26/f27/f28) === - -@router.post("/sessions/{session_id}/pause", response_model=SessionResponse) -async def toggle_pause(session_id: str) -> SessionResponse: - """ - Pausiert oder setzt die laufende Stunde fort (Feature f27). - - Toggle-Funktion: Wenn pausiert -> fortsetzen, wenn laufend -> pausieren. - Die Pause-Zeit wird nicht auf die Phasendauer angerechnet. - """ - session = get_session_or_404(session_id) - - # Nur aktive Phasen koennen pausiert werden - if session.current_phase in [LessonPhase.NOT_STARTED, LessonPhase.ENDED]: - raise HTTPException( - status_code=400, - detail="Stunde ist nicht aktiv" - ) - - if session.is_paused: - # Fortsetzen: Pause-Zeit zur Gesamt-Pause addieren - if session.pause_started_at: - pause_duration = (datetime.utcnow() - session.pause_started_at).total_seconds() - session.total_paused_seconds += int(pause_duration) - - session.is_paused = False - session.pause_started_at = None - else: - # Pausieren - session.is_paused = True - session.pause_started_at = datetime.utcnow() - - persist_session(session) - return build_session_response(session) - - -@router.post("/sessions/{session_id}/extend", response_model=SessionResponse) -async def extend_phase(session_id: str, request: ExtendTimeRequest) -> SessionResponse: - """ - Verlaengert die aktuelle Phase um zusaetzliche Minuten (Feature f28). - - Nuetzlich wenn mehr Zeit benoetigt wird, z.B. fuer vertiefte Diskussionen. 
- """ - session = get_session_or_404(session_id) - - # Nur aktive Phasen koennen verlaengert werden - if session.current_phase in [LessonPhase.NOT_STARTED, LessonPhase.ENDED]: - raise HTTPException( - status_code=400, - detail="Stunde ist nicht aktiv" - ) - - # Aktuelle Phasendauer erhoehen - phase_id = session.current_phase.value - current_duration = session.phase_durations.get(phase_id, 10) - session.phase_durations[phase_id] = current_duration + request.minutes - - persist_session(session) - return build_session_response(session) - - -@router.get("/sessions/{session_id}/timer", response_model=TimerStatus) -async def get_timer(session_id: str) -> TimerStatus: - """ - Ruft den Timer-Status der aktuellen Phase ab. - - Enthaelt verbleibende Zeit, Warnung und Overtime-Status. - Sollte alle 5 Sekunden gepollt werden. - """ - session = get_session_or_404(session_id) - timer = PhaseTimer() - status = timer.get_phase_status(session) - return TimerStatus(**status) - - -@router.get("/sessions/{session_id}/suggestions", response_model=SuggestionsResponse) -async def get_suggestions( - session_id: str, - limit: int = Query(3, ge=1, le=10, description="Anzahl Vorschlaege") -) -> SuggestionsResponse: - """ - Ruft phasenspezifische Aktivitaets-Vorschlaege ab. - - Die Vorschlaege aendern sich je nach aktueller Phase. - """ - session = get_session_or_404(session_id) - engine = SuggestionEngine() - response = engine.get_suggestions_response(session, limit) - - return SuggestionsResponse( - suggestions=[SuggestionItem(**s) for s in response["suggestions"]], - current_phase=response["current_phase"], - phase_display_name=response["phase_display_name"], - total_available=response["total_available"], - ) - - -@router.put("/sessions/{session_id}/notes", response_model=SessionResponse) -async def update_notes(session_id: str, request: NotesRequest) -> SessionResponse: - """ - Aktualisiert Notizen und Hausaufgaben der Stunde. 
- """ - session = get_session_or_404(session_id) - session.notes = request.notes - session.homework = request.homework - persist_session(session) - return build_session_response(session) - - -@router.delete("/sessions/{session_id}") -async def delete_session(session_id: str) -> Dict[str, str]: - """ - Loescht eine Session. - """ - if session_id not in sessions: - raise HTTPException(status_code=404, detail="Session nicht gefunden") - - del sessions[session_id] - - # Auch aus DB loeschen - if DB_ENABLED: - try: - from ..services.persistence import delete_session_from_db - delete_session_from_db(session_id) - except Exception as e: - logger.error(f"Failed to delete session {session_id} from DB: {e}") - - return {"status": "deleted", "session_id": session_id} - - -# === Session History (Feature f17) === - -@router.get("/history/{teacher_id}", response_model=SessionHistoryResponse) -async def get_session_history( - teacher_id: str, - limit: int = Query(20, ge=1, le=100, description="Max. Anzahl Eintraege"), - offset: int = Query(0, ge=0, description="Offset fuer Pagination") -) -> SessionHistoryResponse: - """ - Ruft die Session-History eines Lehrers ab (Feature f17). - - Zeigt abgeschlossene Unterrichtsstunden mit Statistiken. - Nur verfuegbar wenn DB aktiviert ist. 
- """ - init_db_if_needed() - - if not DB_ENABLED: - # Fallback: In-Memory Sessions filtern - ended_sessions = [ - s for s in sessions.values() - if s.teacher_id == teacher_id and s.current_phase == LessonPhase.ENDED - ] - ended_sessions.sort( - key=lambda x: x.lesson_ended_at or datetime.min, - reverse=True - ) - paginated = ended_sessions[offset:offset + limit] - - items = [] - for s in paginated: - duration = None - if s.lesson_started_at and s.lesson_ended_at: - duration = int((s.lesson_ended_at - s.lesson_started_at).total_seconds() / 60) - - items.append(SessionHistoryItem( - session_id=s.session_id, - teacher_id=s.teacher_id, - class_id=s.class_id, - subject=s.subject, - topic=s.topic, - lesson_started_at=s.lesson_started_at.isoformat() if s.lesson_started_at else None, - lesson_ended_at=s.lesson_ended_at.isoformat() if s.lesson_ended_at else None, - total_duration_minutes=duration, - phases_completed=len(s.phase_history), - notes=s.notes, - homework=s.homework, - )) - - return SessionHistoryResponse( - sessions=items, - total_count=len(ended_sessions), - limit=limit, - offset=offset, - ) - - # DB-basierte History - try: - from classroom_engine.repository import SessionRepository - db = SessionLocal() - repo = SessionRepository(db) - - # Beendete Sessions abrufen - db_sessions = repo.get_history_by_teacher(teacher_id, limit, offset) - - # Gesamtanzahl ermitteln - from classroom_engine.db_models import LessonSessionDB, LessonPhaseEnum - total_count = db.query(LessonSessionDB).filter( - LessonSessionDB.teacher_id == teacher_id, - LessonSessionDB.current_phase == LessonPhaseEnum.ENDED - ).count() - - items = [] - for db_session in db_sessions: - duration = None - if db_session.lesson_started_at and db_session.lesson_ended_at: - duration = int((db_session.lesson_ended_at - db_session.lesson_started_at).total_seconds() / 60) - - phase_history = db_session.phase_history or [] - - items.append(SessionHistoryItem( - session_id=db_session.id, - 
teacher_id=db_session.teacher_id, - class_id=db_session.class_id, - subject=db_session.subject, - topic=db_session.topic, - lesson_started_at=db_session.lesson_started_at.isoformat() if db_session.lesson_started_at else None, - lesson_ended_at=db_session.lesson_ended_at.isoformat() if db_session.lesson_ended_at else None, - total_duration_minutes=duration, - phases_completed=len(phase_history), - notes=db_session.notes or "", - homework=db_session.homework or "", - )) - - db.close() - - return SessionHistoryResponse( - sessions=items, - total_count=total_count, - limit=limit, - offset=offset, - ) - - except Exception as e: - logger.error(f"Failed to get session history: {e}") - raise HTTPException(status_code=500, detail="Fehler beim Laden der History") - - -# === Utility Endpoints === - -@router.get("/phases", response_model=PhasesListResponse) -async def list_phases() -> PhasesListResponse: - """ - Listet alle verfuegbaren Unterrichtsphasen mit Metadaten. - """ - phases = [] - for phase_id, config in LESSON_PHASES.items(): - phases.append({ - "phase": phase_id, - "display_name": config["display_name"], - "default_duration_minutes": config["default_duration_minutes"], - "activities": config["activities"], - "icon": config["icon"], - "description": config.get("description", ""), - }) - return PhasesListResponse(phases=phases) - - -@router.get("/sessions", response_model=ActiveSessionsResponse) -async def list_active_sessions( - teacher_id: Optional[str] = Query(None, description="Filter nach Lehrer") -) -> ActiveSessionsResponse: - """ - Listet alle (optionally gefilterten) Sessions. 
- """ - sessions_list = [] - for session in sessions.values(): - if teacher_id and session.teacher_id != teacher_id: - continue - - fsm = LessonStateMachine() - sessions_list.append({ - "session_id": session.session_id, - "teacher_id": session.teacher_id, - "class_id": session.class_id, - "subject": session.subject, - "current_phase": session.current_phase.value, - "is_active": fsm.is_lesson_active(session), - "lesson_started_at": session.lesson_started_at.isoformat() if session.lesson_started_at else None, - }) - - return ActiveSessionsResponse( - sessions=sessions_list, - count=len(sessions_list) - ) - - -@router.get("/health") -async def health_check() -> Dict[str, Any]: - """ - Health-Check fuer den Classroom Service. - """ - from sqlalchemy import text - - db_status = "disabled" - if DB_ENABLED: - try: - db = SessionLocal() - db.execute(text("SELECT 1")) - db.close() - db_status = "connected" - except Exception as e: - db_status = f"error: {str(e)}" - - return { - "status": "healthy", - "service": "classroom-engine", - "active_sessions": len(sessions), - "db_enabled": DB_ENABLED, - "db_status": db_status, - "timestamp": datetime.utcnow().isoformat(), - } +# Re-export for backward compatibility +__all__ = ["router", "build_session_response"] diff --git a/backend-lehrer/classroom/routes/sessions_actions.py b/backend-lehrer/classroom/routes/sessions_actions.py new file mode 100644 index 0000000..3f23693 --- /dev/null +++ b/backend-lehrer/classroom/routes/sessions_actions.py @@ -0,0 +1,173 @@ +""" +Classroom API - Session Actions Routes + +Quick actions (pause, extend, timer), suggestions, utility endpoints. 
+""" + +from typing import Dict, Optional, Any +from datetime import datetime +import logging + +from fastapi import APIRouter, HTTPException, Query +from sqlalchemy import text + +from classroom_engine import ( + LessonPhase, + LessonStateMachine, + PhaseTimer, + SuggestionEngine, + LESSON_PHASES, +) + +from ..models import ( + ExtendTimeRequest, + TimerStatus, + SuggestionItem, + SuggestionsResponse, + PhasesListResponse, + ActiveSessionsResponse, +) +from ..services.persistence import ( + sessions, + persist_session, + get_session_or_404, + DB_ENABLED, + SessionLocal, +) +from .sessions_core import build_session_response, SessionResponse + +logger = logging.getLogger(__name__) + +router = APIRouter(tags=["Sessions"]) + + +# === Quick Actions (Feature f26/f27/f28) === + +@router.post("/sessions/{session_id}/pause", response_model=SessionResponse) +async def toggle_pause(session_id: str) -> SessionResponse: + """Pausiert oder setzt die laufende Stunde fort (Feature f27).""" + session = get_session_or_404(session_id) + + if session.current_phase in [LessonPhase.NOT_STARTED, LessonPhase.ENDED]: + raise HTTPException(status_code=400, detail="Stunde ist nicht aktiv") + + if session.is_paused: + if session.pause_started_at: + pause_duration = (datetime.utcnow() - session.pause_started_at).total_seconds() + session.total_paused_seconds += int(pause_duration) + session.is_paused = False + session.pause_started_at = None + else: + session.is_paused = True + session.pause_started_at = datetime.utcnow() + + persist_session(session) + return build_session_response(session) + + +@router.post("/sessions/{session_id}/extend", response_model=SessionResponse) +async def extend_phase(session_id: str, request: ExtendTimeRequest) -> SessionResponse: + """Verlaengert die aktuelle Phase um zusaetzliche Minuten (Feature f28).""" + session = get_session_or_404(session_id) + + if session.current_phase in [LessonPhase.NOT_STARTED, LessonPhase.ENDED]: + raise HTTPException(status_code=400, 
detail="Stunde ist nicht aktiv") + + phase_id = session.current_phase.value + current_duration = session.phase_durations.get(phase_id, 10) + session.phase_durations[phase_id] = current_duration + request.minutes + + persist_session(session) + return build_session_response(session) + + +@router.get("/sessions/{session_id}/timer", response_model=TimerStatus) +async def get_timer(session_id: str) -> TimerStatus: + """Ruft den Timer-Status der aktuellen Phase ab.""" + session = get_session_or_404(session_id) + timer = PhaseTimer() + status = timer.get_phase_status(session) + return TimerStatus(**status) + + +@router.get("/sessions/{session_id}/suggestions", response_model=SuggestionsResponse) +async def get_suggestions( + session_id: str, + limit: int = Query(3, ge=1, le=10) +) -> SuggestionsResponse: + """Ruft phasenspezifische Aktivitaets-Vorschlaege ab.""" + session = get_session_or_404(session_id) + engine = SuggestionEngine() + response = engine.get_suggestions_response(session, limit) + + return SuggestionsResponse( + suggestions=[SuggestionItem(**s) for s in response["suggestions"]], + current_phase=response["current_phase"], + phase_display_name=response["phase_display_name"], + total_available=response["total_available"], + ) + + +# === Utility Endpoints === + +@router.get("/phases", response_model=PhasesListResponse) +async def list_phases() -> PhasesListResponse: + """Listet alle verfuegbaren Unterrichtsphasen mit Metadaten.""" + phases = [] + for phase_id, config in LESSON_PHASES.items(): + phases.append({ + "phase": phase_id, + "display_name": config["display_name"], + "default_duration_minutes": config["default_duration_minutes"], + "activities": config["activities"], + "icon": config["icon"], + "description": config.get("description", ""), + }) + return PhasesListResponse(phases=phases) + + +@router.get("/sessions", response_model=ActiveSessionsResponse) +async def list_active_sessions( + teacher_id: Optional[str] = Query(None) +) -> 
ActiveSessionsResponse: + """Listet alle (optionally gefilterten) Sessions.""" + sessions_list = [] + for session in sessions.values(): + if teacher_id and session.teacher_id != teacher_id: + continue + + fsm = LessonStateMachine() + sessions_list.append({ + "session_id": session.session_id, + "teacher_id": session.teacher_id, + "class_id": session.class_id, + "subject": session.subject, + "current_phase": session.current_phase.value, + "is_active": fsm.is_lesson_active(session), + "lesson_started_at": session.lesson_started_at.isoformat() if session.lesson_started_at else None, + }) + + return ActiveSessionsResponse(sessions=sessions_list, count=len(sessions_list)) + + +@router.get("/health") +async def health_check() -> Dict[str, Any]: + """Health-Check fuer den Classroom Service.""" + db_status = "disabled" + if DB_ENABLED: + try: + db = SessionLocal() + db.execute(text("SELECT 1")) + db.close() + db_status = "connected" + except Exception as e: + db_status = f"error: {str(e)}" + + return { + "status": "healthy", + "service": "classroom-engine", + "active_sessions": len(sessions), + "db_enabled": DB_ENABLED, + "db_status": db_status, + "timestamp": datetime.utcnow().isoformat(), + } diff --git a/backend-lehrer/classroom/routes/sessions_core.py b/backend-lehrer/classroom/routes/sessions_core.py new file mode 100644 index 0000000..1f4e529 --- /dev/null +++ b/backend-lehrer/classroom/routes/sessions_core.py @@ -0,0 +1,283 @@ +""" +Classroom API - Session Core Routes + +Session CRUD, lifecycle, and history endpoints. 
+""" + +from uuid import uuid4 +from typing import Dict, Optional, Any +from datetime import datetime +import logging + +from fastapi import APIRouter, HTTPException, Query + +from classroom_engine import ( + LessonPhase, + LessonSession, + LessonStateMachine, + PhaseTimer, + LESSON_PHASES, +) + +from ..models import ( + CreateSessionRequest, + NotesRequest, + PhaseInfo, + TimerStatus, + SessionResponse, + PhasesListResponse, + ActiveSessionsResponse, + SessionHistoryItem, + SessionHistoryResponse, +) +from ..services.persistence import ( + sessions, + init_db_if_needed, + persist_session, + get_session_or_404, + DB_ENABLED, + SessionLocal, +) +from ..websocket_manager import notify_phase_change, notify_session_ended + +logger = logging.getLogger(__name__) + +router = APIRouter(tags=["Sessions"]) + + +def build_session_response(session: LessonSession) -> SessionResponse: + """Baut die vollstaendige Session-Response.""" + fsm = LessonStateMachine() + timer = PhaseTimer() + + timer_status = timer.get_phase_status(session) + phases_info = fsm.get_phases_info(session) + + return SessionResponse( + session_id=session.session_id, + teacher_id=session.teacher_id, + class_id=session.class_id, + subject=session.subject, + topic=session.topic, + current_phase=session.current_phase.value, + phase_display_name=session.get_phase_display_name(), + phase_started_at=session.phase_started_at.isoformat() if session.phase_started_at else None, + lesson_started_at=session.lesson_started_at.isoformat() if session.lesson_started_at else None, + lesson_ended_at=session.lesson_ended_at.isoformat() if session.lesson_ended_at else None, + timer=TimerStatus(**timer_status), + phases=[PhaseInfo(**p) for p in phases_info], + phase_history=session.phase_history, + notes=session.notes, + homework=session.homework, + is_active=fsm.is_lesson_active(session), + is_ended=fsm.is_lesson_ended(session), + is_paused=session.is_paused, + ) + + +# === Session CRUD Endpoints === + 
+@router.post("/sessions", response_model=SessionResponse) +async def create_session(request: CreateSessionRequest) -> SessionResponse: + """Erstellt eine neue Unterrichtsstunde (Session).""" + init_db_if_needed() + + phase_durations = { + "einstieg": 8, "erarbeitung": 20, "sicherung": 10, + "transfer": 7, "reflexion": 5, + } + if request.phase_durations: + phase_durations.update(request.phase_durations) + + session = LessonSession( + session_id=str(uuid4()), + teacher_id=request.teacher_id, + class_id=request.class_id, + subject=request.subject, + topic=request.topic, + phase_durations=phase_durations, + ) + + sessions[session.session_id] = session + persist_session(session) + return build_session_response(session) + + +@router.get("/sessions/{session_id}", response_model=SessionResponse) +async def get_session(session_id: str) -> SessionResponse: + """Ruft den aktuellen Status einer Session ab.""" + session = get_session_or_404(session_id) + return build_session_response(session) + + +@router.post("/sessions/{session_id}/start", response_model=SessionResponse) +async def start_lesson(session_id: str) -> SessionResponse: + """Startet die Unterrichtsstunde.""" + session = get_session_or_404(session_id) + + if session.current_phase != LessonPhase.NOT_STARTED: + raise HTTPException( + status_code=400, + detail=f"Stunde bereits gestartet (aktuelle Phase: {session.current_phase.value})" + ) + + fsm = LessonStateMachine() + session = fsm.transition(session, LessonPhase.EINSTIEG) + persist_session(session) + return build_session_response(session) + + +@router.post("/sessions/{session_id}/next-phase", response_model=SessionResponse) +async def next_phase(session_id: str) -> SessionResponse: + """Wechselt zur naechsten Phase.""" + session = get_session_or_404(session_id) + + fsm = LessonStateMachine() + next_p = fsm.next_phase(session.current_phase) + + if not next_p: + raise HTTPException( + status_code=400, + detail=f"Keine naechste Phase verfuegbar (aktuelle Phase: 
{session.current_phase.value})" + ) + + session = fsm.transition(session, next_p) + persist_session(session) + + response = build_session_response(session) + await notify_phase_change(session_id, session.current_phase.value, { + "phase_display_name": session.get_phase_display_name(), + "is_ended": session.current_phase == LessonPhase.ENDED + }) + return response + + +@router.post("/sessions/{session_id}/end", response_model=SessionResponse) +async def end_lesson(session_id: str) -> SessionResponse: + """Beendet die Unterrichtsstunde sofort.""" + session = get_session_or_404(session_id) + + if session.current_phase == LessonPhase.ENDED: + raise HTTPException(status_code=400, detail="Stunde bereits beendet") + if session.current_phase == LessonPhase.NOT_STARTED: + raise HTTPException(status_code=400, detail="Stunde noch nicht gestartet") + + fsm = LessonStateMachine() + while session.current_phase != LessonPhase.ENDED: + next_p = fsm.next_phase(session.current_phase) + if next_p: + session = fsm.transition(session, next_p) + else: + break + + persist_session(session) + await notify_session_ended(session_id) + return build_session_response(session) + + +@router.put("/sessions/{session_id}/notes", response_model=SessionResponse) +async def update_notes(session_id: str, request: NotesRequest) -> SessionResponse: + """Aktualisiert Notizen und Hausaufgaben der Stunde.""" + session = get_session_or_404(session_id) + session.notes = request.notes + session.homework = request.homework + persist_session(session) + return build_session_response(session) + + +@router.delete("/sessions/{session_id}") +async def delete_session(session_id: str) -> Dict[str, str]: + """Loescht eine Session.""" + if session_id not in sessions: + raise HTTPException(status_code=404, detail="Session nicht gefunden") + + del sessions[session_id] + + if DB_ENABLED: + try: + from ..services.persistence import delete_session_from_db + delete_session_from_db(session_id) + except Exception as e: + 
logger.error(f"Failed to delete session {session_id} from DB: {e}") + + return {"status": "deleted", "session_id": session_id} + + +# === Session History (Feature f17) === + +@router.get("/history/{teacher_id}", response_model=SessionHistoryResponse) +async def get_session_history( + teacher_id: str, + limit: int = Query(20, ge=1, le=100), + offset: int = Query(0, ge=0) +) -> SessionHistoryResponse: + """Ruft die Session-History eines Lehrers ab (Feature f17).""" + init_db_if_needed() + + if not DB_ENABLED: + ended_sessions = [ + s for s in sessions.values() + if s.teacher_id == teacher_id and s.current_phase == LessonPhase.ENDED + ] + ended_sessions.sort(key=lambda x: x.lesson_ended_at or datetime.min, reverse=True) + paginated = ended_sessions[offset:offset + limit] + + items = [] + for s in paginated: + duration = None + if s.lesson_started_at and s.lesson_ended_at: + duration = int((s.lesson_ended_at - s.lesson_started_at).total_seconds() / 60) + + items.append(SessionHistoryItem( + session_id=s.session_id, teacher_id=s.teacher_id, + class_id=s.class_id, subject=s.subject, topic=s.topic, + lesson_started_at=s.lesson_started_at.isoformat() if s.lesson_started_at else None, + lesson_ended_at=s.lesson_ended_at.isoformat() if s.lesson_ended_at else None, + total_duration_minutes=duration, + phases_completed=len(s.phase_history), + notes=s.notes, homework=s.homework, + )) + + return SessionHistoryResponse( + sessions=items, total_count=len(ended_sessions), limit=limit, offset=offset, + ) + + try: + from classroom_engine.repository import SessionRepository + db = SessionLocal() + repo = SessionRepository(db) + db_sessions = repo.get_history_by_teacher(teacher_id, limit, offset) + + from classroom_engine.db_models import LessonSessionDB, LessonPhaseEnum + total_count = db.query(LessonSessionDB).filter( + LessonSessionDB.teacher_id == teacher_id, + LessonSessionDB.current_phase == LessonPhaseEnum.ENDED + ).count() + + items = [] + for db_session in db_sessions: + 
duration = None + if db_session.lesson_started_at and db_session.lesson_ended_at: + duration = int((db_session.lesson_ended_at - db_session.lesson_started_at).total_seconds() / 60) + + phase_history = db_session.phase_history or [] + + items.append(SessionHistoryItem( + session_id=db_session.id, teacher_id=db_session.teacher_id, + class_id=db_session.class_id, subject=db_session.subject, topic=db_session.topic, + lesson_started_at=db_session.lesson_started_at.isoformat() if db_session.lesson_started_at else None, + lesson_ended_at=db_session.lesson_ended_at.isoformat() if db_session.lesson_ended_at else None, + total_duration_minutes=duration, + phases_completed=len(phase_history), + notes=db_session.notes or "", homework=db_session.homework or "", + )) + + db.close() + + return SessionHistoryResponse( + sessions=items, total_count=total_count, limit=limit, offset=offset, + ) + + except Exception as e: + logger.error(f"Failed to get session history: {e}") + raise HTTPException(status_code=500, detail="Fehler beim Laden der History") diff --git a/backend-lehrer/classroom_engine/__init__.py b/backend-lehrer/classroom_engine/__init__.py index 23284bb..4362c38 100644 --- a/backend-lehrer/classroom_engine/__init__.py +++ b/backend-lehrer/classroom_engine/__init__.py @@ -32,7 +32,8 @@ from .models import ( ) from .fsm import LessonStateMachine from .timer import PhaseTimer -from .suggestions import SuggestionEngine, PHASE_SUGGESTIONS, SUBJECT_SUGGESTIONS +from .suggestions import SuggestionEngine +from .suggestion_data import PHASE_SUGGESTIONS, SUBJECT_SUGGESTIONS from .context_models import ( MacroPhaseEnum, EventTypeEnum, diff --git a/backend-lehrer/classroom_engine/analytics.py b/backend-lehrer/classroom_engine/analytics.py index f0d4fa1..ca4dad1 100644 --- a/backend-lehrer/classroom_engine/analytics.py +++ b/backend-lehrer/classroom_engine/analytics.py @@ -11,256 +11,28 @@ WICHTIG: Keine wertenden Metriken (z.B. "Sie haben 70% geredet"). 
Fokus auf neutrale, hilfreiche Statistiken. """ -from dataclasses import dataclass, field -from datetime import datetime, timedelta +from datetime import datetime from typing import Optional, List, Dict, Any -from enum import Enum +from .analytics_models import ( + PhaseStatistics, + SessionSummary, + TeacherAnalytics, + LessonReflection, +) -# ==================== Analytics Models ==================== +# Re-export models for backward compatibility +__all__ = [ + "PhaseStatistics", + "SessionSummary", + "TeacherAnalytics", + "LessonReflection", + "AnalyticsCalculator", +] -@dataclass -class PhaseStatistics: - """Statistik fuer eine einzelne Phase.""" - phase: str - display_name: str - - # Dauer-Metriken - planned_duration_seconds: int - actual_duration_seconds: int - difference_seconds: int # positiv = laenger als geplant - - # Overtime - had_overtime: bool - overtime_seconds: int = 0 - - # Erweiterungen - was_extended: bool = False - extension_minutes: int = 0 - - # Pausen - pause_count: int = 0 - total_pause_seconds: int = 0 - - def to_dict(self) -> Dict[str, Any]: - return { - "phase": self.phase, - "display_name": self.display_name, - "planned_duration_seconds": self.planned_duration_seconds, - "actual_duration_seconds": self.actual_duration_seconds, - "difference_seconds": self.difference_seconds, - "difference_formatted": self._format_difference(), - "had_overtime": self.had_overtime, - "overtime_seconds": self.overtime_seconds, - "overtime_formatted": self._format_seconds(self.overtime_seconds), - "was_extended": self.was_extended, - "extension_minutes": self.extension_minutes, - "pause_count": self.pause_count, - "total_pause_seconds": self.total_pause_seconds, - } - - def _format_difference(self) -> str: - """Formatiert die Differenz als +/-MM:SS.""" - prefix = "+" if self.difference_seconds >= 0 else "" - return f"{prefix}{self._format_seconds(abs(self.difference_seconds))}" - - def _format_seconds(self, seconds: int) -> str: - """Formatiert Sekunden als 
MM:SS.""" - mins = seconds // 60 - secs = seconds % 60 - return f"{mins:02d}:{secs:02d}" - - -@dataclass -class SessionSummary: - """ - Zusammenfassung einer Unterrichtsstunde. - - Wird nach Stundenende generiert und fuer das Lehrer-Dashboard verwendet. - """ - session_id: str - teacher_id: str - class_id: str - subject: str - topic: Optional[str] - date: datetime - - # Dauer - total_duration_seconds: int - planned_duration_seconds: int - - # Phasen-Statistiken - phases_completed: int - total_phases: int = 5 - phase_statistics: List[PhaseStatistics] = field(default_factory=list) - - # Overtime-Zusammenfassung - total_overtime_seconds: int = 0 - phases_with_overtime: int = 0 - - # Pausen-Zusammenfassung - total_pause_count: int = 0 - total_pause_seconds: int = 0 - - # Post-Lesson Reflection - reflection_notes: str = "" - reflection_rating: Optional[int] = None # 1-5 Sterne (optional) - key_learnings: List[str] = field(default_factory=list) - - def to_dict(self) -> Dict[str, Any]: - return { - "session_id": self.session_id, - "teacher_id": self.teacher_id, - "class_id": self.class_id, - "subject": self.subject, - "topic": self.topic, - "date": self.date.isoformat() if self.date else None, - "date_formatted": self._format_date(), - "total_duration_seconds": self.total_duration_seconds, - "total_duration_formatted": self._format_seconds(self.total_duration_seconds), - "planned_duration_seconds": self.planned_duration_seconds, - "planned_duration_formatted": self._format_seconds(self.planned_duration_seconds), - "phases_completed": self.phases_completed, - "total_phases": self.total_phases, - "completion_percentage": round(self.phases_completed / self.total_phases * 100), - "phase_statistics": [p.to_dict() for p in self.phase_statistics], - "total_overtime_seconds": self.total_overtime_seconds, - "total_overtime_formatted": self._format_seconds(self.total_overtime_seconds), - "phases_with_overtime": self.phases_with_overtime, - "total_pause_count": 
self.total_pause_count, - "total_pause_seconds": self.total_pause_seconds, - "reflection_notes": self.reflection_notes, - "reflection_rating": self.reflection_rating, - "key_learnings": self.key_learnings, - } - - def _format_seconds(self, seconds: int) -> str: - mins = seconds // 60 - secs = seconds % 60 - return f"{mins:02d}:{secs:02d}" - - def _format_date(self) -> str: - if not self.date: - return "" - return self.date.strftime("%d.%m.%Y %H:%M") - - -@dataclass -class TeacherAnalytics: - """ - Aggregierte Statistiken fuer einen Lehrer. - - Zeigt Trends und Muster ueber mehrere Stunden. - """ - teacher_id: str - period_start: datetime - period_end: datetime - - # Stunden-Uebersicht - total_sessions: int = 0 - completed_sessions: int = 0 - total_teaching_minutes: int = 0 - - # Durchschnittliche Phasendauern - avg_phase_durations: Dict[str, float] = field(default_factory=dict) - - # Overtime-Trends - sessions_with_overtime: int = 0 - avg_overtime_seconds: float = 0 - most_overtime_phase: Optional[str] = None - - # Pausen-Statistik - avg_pause_count: float = 0 - avg_pause_duration_seconds: float = 0 - - # Faecher-Verteilung - subjects_taught: Dict[str, int] = field(default_factory=dict) - - # Klassen-Verteilung - classes_taught: Dict[str, int] = field(default_factory=dict) - - def to_dict(self) -> Dict[str, Any]: - return { - "teacher_id": self.teacher_id, - "period_start": self.period_start.isoformat() if self.period_start else None, - "period_end": self.period_end.isoformat() if self.period_end else None, - "total_sessions": self.total_sessions, - "completed_sessions": self.completed_sessions, - "total_teaching_minutes": self.total_teaching_minutes, - "total_teaching_hours": round(self.total_teaching_minutes / 60, 1), - "avg_phase_durations": self.avg_phase_durations, - "sessions_with_overtime": self.sessions_with_overtime, - "overtime_percentage": round(self.sessions_with_overtime / max(self.total_sessions, 1) * 100), - "avg_overtime_seconds": 
round(self.avg_overtime_seconds), - "avg_overtime_formatted": self._format_seconds(int(self.avg_overtime_seconds)), - "most_overtime_phase": self.most_overtime_phase, - "avg_pause_count": round(self.avg_pause_count, 1), - "avg_pause_duration_seconds": round(self.avg_pause_duration_seconds), - "subjects_taught": self.subjects_taught, - "classes_taught": self.classes_taught, - } - - def _format_seconds(self, seconds: int) -> str: - mins = seconds // 60 - secs = seconds % 60 - return f"{mins:02d}:{secs:02d}" - - -# ==================== Reflection Model ==================== - -@dataclass -class LessonReflection: - """ - Post-Lesson Reflection (Feature). - - Ermoeglicht Lehrern, nach der Stunde Notizen zu machen. - Keine Bewertung, nur Reflexion. - """ - reflection_id: str - session_id: str - teacher_id: str - - # Reflexionsnotizen - notes: str = "" - - # Optional: Sterne-Bewertung (selbst-eingeschaetzt) - overall_rating: Optional[int] = None # 1-5 - - # Was hat gut funktioniert? - what_worked: List[str] = field(default_factory=list) - - # Was wuerde ich naechstes Mal anders machen? - improvements: List[str] = field(default_factory=list) - - # Notizen fuer naechste Stunde - notes_for_next_lesson: str = "" - - created_at: Optional[datetime] = None - updated_at: Optional[datetime] = None - - def to_dict(self) -> Dict[str, Any]: - return { - "reflection_id": self.reflection_id, - "session_id": self.session_id, - "teacher_id": self.teacher_id, - "notes": self.notes, - "overall_rating": self.overall_rating, - "what_worked": self.what_worked, - "improvements": self.improvements, - "notes_for_next_lesson": self.notes_for_next_lesson, - "created_at": self.created_at.isoformat() if self.created_at else None, - "updated_at": self.updated_at.isoformat() if self.updated_at else None, - } - - -# ==================== Analytics Calculator ==================== class AnalyticsCalculator: - """ - Berechnet Analytics aus Session-Daten. - - Verwendet In-Memory-Daten oder DB-Daten. 
- """ + """Berechnet Analytics aus Session-Daten.""" PHASE_DISPLAY_NAMES = { "einstieg": "Einstieg", @@ -276,24 +48,13 @@ class AnalyticsCalculator: session_data: Dict[str, Any], phase_history: List[Dict[str, Any]] ) -> SessionSummary: - """ - Berechnet die Zusammenfassung einer Session. - - Args: - session_data: Session-Dictionary (aus LessonSession.to_dict()) - phase_history: Liste der Phasen-History-Eintraege - - Returns: - SessionSummary mit allen berechneten Statistiken - """ - # Basis-Daten + """Berechnet die Zusammenfassung einer Session.""" session_id = session_data.get("session_id", "") teacher_id = session_data.get("teacher_id", "") class_id = session_data.get("class_id", "") subject = session_data.get("subject", "") topic = session_data.get("topic") - # Timestamps lesson_started = session_data.get("lesson_started_at") lesson_ended = session_data.get("lesson_ended_at") @@ -302,16 +63,13 @@ class AnalyticsCalculator: if isinstance(lesson_ended, str): lesson_ended = datetime.fromisoformat(lesson_ended.replace("Z", "+00:00")) - # Dauer berechnen total_duration = 0 if lesson_started and lesson_ended: total_duration = int((lesson_ended - lesson_started).total_seconds()) - # Geplante Dauer phase_durations = session_data.get("phase_durations", {}) - planned_duration = sum(phase_durations.values()) * 60 # Minuten zu Sekunden + planned_duration = sum(phase_durations.values()) * 60 - # Phasen-Statistiken berechnen phase_stats = [] total_overtime = 0 phases_with_overtime = 0 @@ -324,18 +82,10 @@ class AnalyticsCalculator: if phase in ["not_started", "ended"]: continue - # Geplante Dauer fuer diese Phase planned_seconds = phase_durations.get(phase, 0) * 60 - - # Tatsaechliche Dauer - actual_seconds = entry.get("duration_seconds", 0) - if actual_seconds is None: - actual_seconds = 0 - - # Differenz + actual_seconds = entry.get("duration_seconds", 0) or 0 difference = actual_seconds - planned_seconds - # Overtime (nur positive Differenz zaehlt) had_overtime = 
difference > 0 overtime_seconds = max(0, difference) @@ -343,13 +93,11 @@ class AnalyticsCalculator: total_overtime += overtime_seconds phases_with_overtime += 1 - # Pausen pause_count = entry.get("pause_count", 0) or 0 pause_seconds = entry.get("total_pause_seconds", 0) or 0 total_pause_count += pause_count total_pause_seconds += pause_seconds - # Phase als abgeschlossen zaehlen if entry.get("ended_at"): phases_completed += 1 @@ -368,16 +116,12 @@ class AnalyticsCalculator: )) return SessionSummary( - session_id=session_id, - teacher_id=teacher_id, - class_id=class_id, - subject=subject, - topic=topic, + session_id=session_id, teacher_id=teacher_id, + class_id=class_id, subject=subject, topic=topic, date=lesson_started or datetime.now(), total_duration_seconds=total_duration, planned_duration_seconds=planned_duration, - phases_completed=phases_completed, - total_phases=5, + phases_completed=phases_completed, total_phases=5, phase_statistics=phase_stats, total_overtime_seconds=total_overtime, phases_with_overtime=phases_with_overtime, @@ -392,31 +136,15 @@ class AnalyticsCalculator: period_start: datetime, period_end: datetime ) -> TeacherAnalytics: - """ - Berechnet aggregierte Statistiken fuer einen Lehrer. 
- - Args: - sessions: Liste von Session-Dictionaries - period_start: Beginn des Zeitraums - period_end: Ende des Zeitraums - - Returns: - TeacherAnalytics mit aggregierten Statistiken - """ + """Berechnet aggregierte Statistiken fuer einen Lehrer.""" if not sessions: - return TeacherAnalytics( - teacher_id="", - period_start=period_start, - period_end=period_end, - ) + return TeacherAnalytics(teacher_id="", period_start=period_start, period_end=period_end) teacher_id = sessions[0].get("teacher_id", "") - # Basis-Zaehler total_sessions = len(sessions) completed_sessions = sum(1 for s in sessions if s.get("lesson_ended_at")) - # Gesamtdauer berechnen total_minutes = 0 for session in sessions: started = session.get("lesson_started_at") @@ -428,41 +156,29 @@ class AnalyticsCalculator: ended = datetime.fromisoformat(ended.replace("Z", "+00:00")) total_minutes += (ended - started).total_seconds() / 60 - # Durchschnittliche Phasendauern phase_durations_sum: Dict[str, List[int]] = { - "einstieg": [], - "erarbeitung": [], - "sicherung": [], - "transfer": [], - "reflexion": [], + "einstieg": [], "erarbeitung": [], "sicherung": [], + "transfer": [], "reflexion": [], } - # Overtime-Tracking overtime_count = 0 overtime_seconds_total = 0 phase_overtime: Dict[str, int] = {} - - # Pausen-Tracking pause_counts = [] pause_durations = [] - - # Faecher und Klassen subjects: Dict[str, int] = {} classes: Dict[str, int] = {} for session in sessions: - # Fach und Klasse zaehlen subject = session.get("subject", "") class_id = session.get("class_id", "") subjects[subject] = subjects.get(subject, 0) + 1 classes[class_id] = classes.get(class_id, 0) + 1 - # Phase History analysieren history = session.get("phase_history", []) session_has_overtime = False session_pause_count = 0 session_pause_duration = 0 - phase_durations_dict = session.get("phase_durations", {}) for entry in history: @@ -471,7 +187,6 @@ class AnalyticsCalculator: duration = entry.get("duration_seconds", 0) or 0 
phase_durations_sum[phase].append(duration) - # Overtime berechnen planned = phase_durations_dict.get(phase, 0) * 60 if duration > planned: overtime = duration - planned @@ -479,35 +194,25 @@ class AnalyticsCalculator: session_has_overtime = True phase_overtime[phase] = phase_overtime.get(phase, 0) + overtime - # Pausen zaehlen session_pause_count += entry.get("pause_count", 0) or 0 session_pause_duration += entry.get("total_pause_seconds", 0) or 0 if session_has_overtime: overtime_count += 1 - pause_counts.append(session_pause_count) pause_durations.append(session_pause_duration) - # Durchschnitte berechnen avg_durations = {} for phase, durations in phase_durations_sum.items(): - if durations: - avg_durations[phase] = round(sum(durations) / len(durations)) - else: - avg_durations[phase] = 0 + avg_durations[phase] = round(sum(durations) / len(durations)) if durations else 0 - # Phase mit meistem Overtime finden most_overtime_phase = None if phase_overtime: most_overtime_phase = max(phase_overtime, key=phase_overtime.get) return TeacherAnalytics( - teacher_id=teacher_id, - period_start=period_start, - period_end=period_end, - total_sessions=total_sessions, - completed_sessions=completed_sessions, + teacher_id=teacher_id, period_start=period_start, period_end=period_end, + total_sessions=total_sessions, completed_sessions=completed_sessions, total_teaching_minutes=int(total_minutes), avg_phase_durations=avg_durations, sessions_with_overtime=overtime_count, @@ -515,6 +220,5 @@ class AnalyticsCalculator: most_overtime_phase=most_overtime_phase, avg_pause_count=sum(pause_counts) / max(len(pause_counts), 1), avg_pause_duration_seconds=sum(pause_durations) / max(len(pause_durations), 1), - subjects_taught=subjects, - classes_taught=classes, + subjects_taught=subjects, classes_taught=classes, ) diff --git a/backend-lehrer/classroom_engine/analytics_models.py b/backend-lehrer/classroom_engine/analytics_models.py new file mode 100644 index 0000000..c18008e --- /dev/null +++ 
b/backend-lehrer/classroom_engine/analytics_models.py @@ -0,0 +1,205 @@ +""" +Analytics Models - Datenstrukturen fuer Classroom Analytics. + +Enthaelt PhaseStatistics, SessionSummary, TeacherAnalytics, LessonReflection. +""" + +from dataclasses import dataclass, field +from datetime import datetime +from typing import Optional, List, Dict, Any + + +@dataclass +class PhaseStatistics: + """Statistik fuer eine einzelne Phase.""" + phase: str + display_name: str + + # Dauer-Metriken + planned_duration_seconds: int + actual_duration_seconds: int + difference_seconds: int # positiv = laenger als geplant + + # Overtime + had_overtime: bool + overtime_seconds: int = 0 + + # Erweiterungen + was_extended: bool = False + extension_minutes: int = 0 + + # Pausen + pause_count: int = 0 + total_pause_seconds: int = 0 + + def to_dict(self) -> Dict[str, Any]: + return { + "phase": self.phase, + "display_name": self.display_name, + "planned_duration_seconds": self.planned_duration_seconds, + "actual_duration_seconds": self.actual_duration_seconds, + "difference_seconds": self.difference_seconds, + "difference_formatted": self._format_difference(), + "had_overtime": self.had_overtime, + "overtime_seconds": self.overtime_seconds, + "overtime_formatted": self._format_seconds(self.overtime_seconds), + "was_extended": self.was_extended, + "extension_minutes": self.extension_minutes, + "pause_count": self.pause_count, + "total_pause_seconds": self.total_pause_seconds, + } + + def _format_difference(self) -> str: + prefix = "+" if self.difference_seconds >= 0 else "" + return f"{prefix}{self._format_seconds(abs(self.difference_seconds))}" + + def _format_seconds(self, seconds: int) -> str: + mins = seconds // 60 + secs = seconds % 60 + return f"{mins:02d}:{secs:02d}" + + +@dataclass +class SessionSummary: + """Zusammenfassung einer Unterrichtsstunde.""" + session_id: str + teacher_id: str + class_id: str + subject: str + topic: Optional[str] + date: datetime + + total_duration_seconds: int 
+ planned_duration_seconds: int + + phases_completed: int + total_phases: int = 5 + phase_statistics: List[PhaseStatistics] = field(default_factory=list) + + total_overtime_seconds: int = 0 + phases_with_overtime: int = 0 + + total_pause_count: int = 0 + total_pause_seconds: int = 0 + + reflection_notes: str = "" + reflection_rating: Optional[int] = None + key_learnings: List[str] = field(default_factory=list) + + def to_dict(self) -> Dict[str, Any]: + return { + "session_id": self.session_id, + "teacher_id": self.teacher_id, + "class_id": self.class_id, + "subject": self.subject, + "topic": self.topic, + "date": self.date.isoformat() if self.date else None, + "date_formatted": self._format_date(), + "total_duration_seconds": self.total_duration_seconds, + "total_duration_formatted": self._format_seconds(self.total_duration_seconds), + "planned_duration_seconds": self.planned_duration_seconds, + "planned_duration_formatted": self._format_seconds(self.planned_duration_seconds), + "phases_completed": self.phases_completed, + "total_phases": self.total_phases, + "completion_percentage": round(self.phases_completed / self.total_phases * 100), + "phase_statistics": [p.to_dict() for p in self.phase_statistics], + "total_overtime_seconds": self.total_overtime_seconds, + "total_overtime_formatted": self._format_seconds(self.total_overtime_seconds), + "phases_with_overtime": self.phases_with_overtime, + "total_pause_count": self.total_pause_count, + "total_pause_seconds": self.total_pause_seconds, + "reflection_notes": self.reflection_notes, + "reflection_rating": self.reflection_rating, + "key_learnings": self.key_learnings, + } + + def _format_seconds(self, seconds: int) -> str: + mins = seconds // 60 + secs = seconds % 60 + return f"{mins:02d}:{secs:02d}" + + def _format_date(self) -> str: + if not self.date: + return "" + return self.date.strftime("%d.%m.%Y %H:%M") + + +@dataclass +class TeacherAnalytics: + """Aggregierte Statistiken fuer einen Lehrer.""" + teacher_id: 
str + period_start: datetime + period_end: datetime + + total_sessions: int = 0 + completed_sessions: int = 0 + total_teaching_minutes: int = 0 + + avg_phase_durations: Dict[str, float] = field(default_factory=dict) + + sessions_with_overtime: int = 0 + avg_overtime_seconds: float = 0 + most_overtime_phase: Optional[str] = None + + avg_pause_count: float = 0 + avg_pause_duration_seconds: float = 0 + + subjects_taught: Dict[str, int] = field(default_factory=dict) + classes_taught: Dict[str, int] = field(default_factory=dict) + + def to_dict(self) -> Dict[str, Any]: + return { + "teacher_id": self.teacher_id, + "period_start": self.period_start.isoformat() if self.period_start else None, + "period_end": self.period_end.isoformat() if self.period_end else None, + "total_sessions": self.total_sessions, + "completed_sessions": self.completed_sessions, + "total_teaching_minutes": self.total_teaching_minutes, + "total_teaching_hours": round(self.total_teaching_minutes / 60, 1), + "avg_phase_durations": self.avg_phase_durations, + "sessions_with_overtime": self.sessions_with_overtime, + "overtime_percentage": round(self.sessions_with_overtime / max(self.total_sessions, 1) * 100), + "avg_overtime_seconds": round(self.avg_overtime_seconds), + "avg_overtime_formatted": self._format_seconds(int(self.avg_overtime_seconds)), + "most_overtime_phase": self.most_overtime_phase, + "avg_pause_count": round(self.avg_pause_count, 1), + "avg_pause_duration_seconds": round(self.avg_pause_duration_seconds), + "subjects_taught": self.subjects_taught, + "classes_taught": self.classes_taught, + } + + def _format_seconds(self, seconds: int) -> str: + mins = seconds // 60 + secs = seconds % 60 + return f"{mins:02d}:{secs:02d}" + + +@dataclass +class LessonReflection: + """Post-Lesson Reflection (Feature).""" + reflection_id: str + session_id: str + teacher_id: str + + notes: str = "" + overall_rating: Optional[int] = None + what_worked: List[str] = field(default_factory=list) + improvements: 
List[str] = field(default_factory=list) + notes_for_next_lesson: str = "" + + created_at: Optional[datetime] = None + updated_at: Optional[datetime] = None + + def to_dict(self) -> Dict[str, Any]: + return { + "reflection_id": self.reflection_id, + "session_id": self.session_id, + "teacher_id": self.teacher_id, + "notes": self.notes, + "overall_rating": self.overall_rating, + "what_worked": self.what_worked, + "improvements": self.improvements, + "notes_for_next_lesson": self.notes_for_next_lesson, + "created_at": self.created_at.isoformat() if self.created_at else None, + "updated_at": self.updated_at.isoformat() if self.updated_at else None, + } diff --git a/backend-lehrer/classroom_engine/suggestion_data.py b/backend-lehrer/classroom_engine/suggestion_data.py new file mode 100644 index 0000000..7a9e245 --- /dev/null +++ b/backend-lehrer/classroom_engine/suggestion_data.py @@ -0,0 +1,494 @@ +""" +Phasenspezifische und fachspezifische Vorschlags-Daten (Feature f18). + +Enthaelt die vordefinierten Vorschlaege fuer allgemeine Phasen +und fachspezifische Aktivitaeten. 
+""" + +from typing import List, Dict, Any + +from .models import LessonPhase + + +# Unterstuetzte Faecher fuer fachspezifische Vorschlaege +SUPPORTED_SUBJECTS = [ + "mathematik", "mathe", "math", + "deutsch", + "englisch", "english", + "biologie", "bio", + "physik", + "chemie", + "geschichte", + "geografie", "erdkunde", + "kunst", + "musik", + "sport", + "informatik", +] + + +# Fachspezifische Vorschlaege (Feature f18) +SUBJECT_SUGGESTIONS: Dict[str, Dict[LessonPhase, List[Dict[str, Any]]]] = { + "mathematik": { + LessonPhase.EINSTIEG: [ + { + "id": "math_warm_up", + "title": "Kopfrechnen-Challenge", + "description": "5 schnelle Kopfrechenaufgaben zum Aufwaermen", + "activity_type": "warmup", + "estimated_minutes": 3, + "icon": "calculate", + "subjects": ["mathematik", "mathe"], + }, + { + "id": "math_puzzle", + "title": "Mathematisches Raetsel", + "description": "Ein kniffliges Zahlenraetsel als Einstieg", + "activity_type": "motivation", + "estimated_minutes": 5, + "icon": "extension", + "subjects": ["mathematik", "mathe"], + }, + ], + LessonPhase.ERARBEITUNG: [ + { + "id": "math_geogebra", + "title": "GeoGebra-Exploration", + "description": "Interaktive Visualisierung mit GeoGebra", + "activity_type": "individual_work", + "estimated_minutes": 15, + "icon": "functions", + "subjects": ["mathematik", "mathe"], + }, + { + "id": "math_peer_explain", + "title": "Rechenweg erklaeren", + "description": "Schueler erklaeren sich gegenseitig ihre Loesungswege", + "activity_type": "partner_work", + "estimated_minutes": 10, + "icon": "groups", + "subjects": ["mathematik", "mathe"], + }, + ], + LessonPhase.SICHERUNG: [ + { + "id": "math_formula_card", + "title": "Formelkarte erstellen", + "description": "Wichtigste Formeln auf einer Karte festhalten", + "activity_type": "documentation", + "estimated_minutes": 5, + "icon": "note_alt", + "subjects": ["mathematik", "mathe"], + }, + ], + }, + "deutsch": { + LessonPhase.EINSTIEG: [ + { + "id": "deutsch_wordle", + "title": 
"Wordle-Variante", + "description": "Wort des Tages erraten", + "activity_type": "warmup", + "estimated_minutes": 4, + "icon": "abc", + "subjects": ["deutsch"], + }, + { + "id": "deutsch_zitat", + "title": "Zitat-Interpretation", + "description": "Ein literarisches Zitat gemeinsam deuten", + "activity_type": "motivation", + "estimated_minutes": 5, + "icon": "format_quote", + "subjects": ["deutsch"], + }, + ], + LessonPhase.ERARBEITUNG: [ + { + "id": "deutsch_textarbeit", + "title": "Textanalyse in Gruppen", + "description": "Gruppenarbeit zu verschiedenen Textabschnitten", + "activity_type": "group_work", + "estimated_minutes": 15, + "icon": "menu_book", + "subjects": ["deutsch"], + }, + { + "id": "deutsch_schreibworkshop", + "title": "Schreibwerkstatt", + "description": "Kreatives Schreiben mit Peer-Feedback", + "activity_type": "individual_work", + "estimated_minutes": 20, + "icon": "edit_note", + "subjects": ["deutsch"], + }, + ], + LessonPhase.SICHERUNG: [ + { + "id": "deutsch_zusammenfassung", + "title": "Text-Zusammenfassung", + "description": "Die wichtigsten Punkte in 3 Saetzen formulieren", + "activity_type": "summary", + "estimated_minutes": 5, + "icon": "summarize", + "subjects": ["deutsch"], + }, + ], + }, + "englisch": { + LessonPhase.EINSTIEG: [ + { + "id": "english_smalltalk", + "title": "Small Talk Warm-Up", + "description": "2-Minuten Gespraeche zu einem Alltagsthema", + "activity_type": "warmup", + "estimated_minutes": 4, + "icon": "chat", + "subjects": ["englisch", "english"], + }, + { + "id": "english_video", + "title": "Authentic Video Clip", + "description": "Kurzer Clip aus einer englischen Serie oder Nachricht", + "activity_type": "motivation", + "estimated_minutes": 5, + "icon": "movie", + "subjects": ["englisch", "english"], + }, + ], + LessonPhase.ERARBEITUNG: [ + { + "id": "english_role_play", + "title": "Role Play Activity", + "description": "Dialoguebung in authentischen Situationen", + "activity_type": "partner_work", + 
"estimated_minutes": 12, + "icon": "theater_comedy", + "subjects": ["englisch", "english"], + }, + { + "id": "english_reading_circle", + "title": "Reading Circle", + "description": "Gemeinsames Lesen mit verteilten Rollen", + "activity_type": "group_work", + "estimated_minutes": 15, + "icon": "auto_stories", + "subjects": ["englisch", "english"], + }, + ], + }, + "biologie": { + LessonPhase.EINSTIEG: [ + { + "id": "bio_nature_question", + "title": "Naturfrage", + "description": "Eine spannende Frage aus der Natur diskutieren", + "activity_type": "motivation", + "estimated_minutes": 5, + "icon": "eco", + "subjects": ["biologie", "bio"], + }, + ], + LessonPhase.ERARBEITUNG: [ + { + "id": "bio_experiment", + "title": "Mini-Experiment", + "description": "Einfaches Experiment zum Thema durchfuehren", + "activity_type": "group_work", + "estimated_minutes": 20, + "icon": "science", + "subjects": ["biologie", "bio"], + }, + { + "id": "bio_diagram", + "title": "Biologische Zeichnung", + "description": "Beschriftete Zeichnung eines Organismus", + "activity_type": "individual_work", + "estimated_minutes": 15, + "icon": "draw", + "subjects": ["biologie", "bio"], + }, + ], + }, + "physik": { + LessonPhase.EINSTIEG: [ + { + "id": "physik_demo", + "title": "Phaenomen-Demo", + "description": "Ein physikalisches Phaenomen vorfuehren", + "activity_type": "motivation", + "estimated_minutes": 5, + "icon": "bolt", + "subjects": ["physik"], + }, + ], + LessonPhase.ERARBEITUNG: [ + { + "id": "physik_simulation", + "title": "PhET-Simulation", + "description": "Interaktive Simulation von phet.colorado.edu", + "activity_type": "individual_work", + "estimated_minutes": 15, + "icon": "smart_toy", + "subjects": ["physik"], + }, + { + "id": "physik_rechnung", + "title": "Physikalische Rechnung", + "description": "Rechenaufgabe mit physikalischem Kontext", + "activity_type": "partner_work", + "estimated_minutes": 12, + "icon": "calculate", + "subjects": ["physik"], + }, + ], + }, + "informatik": 
{ + LessonPhase.EINSTIEG: [ + { + "id": "info_code_puzzle", + "title": "Code-Puzzle", + "description": "Kurzen Code-Schnipsel analysieren - was macht er?", + "activity_type": "warmup", + "estimated_minutes": 4, + "icon": "code", + "subjects": ["informatik"], + }, + ], + LessonPhase.ERARBEITUNG: [ + { + "id": "info_live_coding", + "title": "Live Coding", + "description": "Gemeinsam Code entwickeln mit Erklaerungen", + "activity_type": "instruction", + "estimated_minutes": 15, + "icon": "terminal", + "subjects": ["informatik"], + }, + { + "id": "info_pair_programming", + "title": "Pair Programming", + "description": "Zu zweit programmieren - Driver und Navigator", + "activity_type": "partner_work", + "estimated_minutes": 20, + "icon": "computer", + "subjects": ["informatik"], + }, + ], + }, +} + + +# Vordefinierte allgemeine Vorschlaege pro Phase +PHASE_SUGGESTIONS: Dict[LessonPhase, List[Dict[str, Any]]] = { + LessonPhase.EINSTIEG: [ + { + "id": "warmup_quiz", + "title": "Kurzes Quiz zum Einstieg", + "description": "Aktivieren Sie das Vorwissen der Schueler mit 3-5 Fragen zum Thema", + "activity_type": "warmup", + "estimated_minutes": 3, + "icon": "quiz" + }, + { + "id": "problem_story", + "title": "Problemgeschichte erzaehlen", + "description": "Stellen Sie ein alltagsnahes Problem vor, das zum Thema fuehrt", + "activity_type": "motivation", + "estimated_minutes": 5, + "icon": "auto_stories" + }, + { + "id": "video_intro", + "title": "Kurzes Erklaervideo", + "description": "Zeigen Sie ein 2-3 Minuten Video zur Einfuehrung ins Thema", + "activity_type": "motivation", + "estimated_minutes": 4, + "icon": "play_circle" + }, + { + "id": "brainstorming", + "title": "Brainstorming", + "description": "Sammeln Sie Ideen und Vorkenntnisse der Schueler an der Tafel", + "activity_type": "warmup", + "estimated_minutes": 5, + "icon": "psychology" + }, + { + "id": "daily_challenge", + "title": "Tagesaufgabe vorstellen", + "description": "Praesentieren Sie die zentrale Frage oder 
Aufgabe der Stunde", + "activity_type": "problem_introduction", + "estimated_minutes": 3, + "icon": "flag" + } + ], + LessonPhase.ERARBEITUNG: [ + { + "id": "think_pair_share", + "title": "Think-Pair-Share", + "description": "Schueler denken erst einzeln nach, tauschen sich dann zu zweit aus und praesentieren im Plenum", + "activity_type": "partner_work", + "estimated_minutes": 10, + "icon": "groups" + }, + { + "id": "worksheet_digital", + "title": "Digitales Arbeitsblatt", + "description": "Schueler bearbeiten ein interaktives Arbeitsblatt am Tablet oder Computer", + "activity_type": "individual_work", + "estimated_minutes": 15, + "icon": "description" + }, + { + "id": "station_learning", + "title": "Stationenlernen", + "description": "Verschiedene Stationen mit unterschiedlichen Aufgaben und Materialien", + "activity_type": "group_work", + "estimated_minutes": 20, + "icon": "hub" + }, + { + "id": "expert_puzzle", + "title": "Expertenrunde (Jigsaw)", + "description": "Schueler werden Experten fuer ein Teilthema und lehren es anderen", + "activity_type": "group_work", + "estimated_minutes": 15, + "icon": "extension" + }, + { + "id": "guided_instruction", + "title": "Geleitete Instruktion", + "description": "Schrittweise Erklaerung mit Uebungsphasen zwischendurch", + "activity_type": "instruction", + "estimated_minutes": 12, + "icon": "school" + }, + { + "id": "pair_programming", + "title": "Partnerarbeit", + "description": "Zwei Schueler loesen gemeinsam eine Aufgabe", + "activity_type": "partner_work", + "estimated_minutes": 10, + "icon": "people" + } + ], + LessonPhase.SICHERUNG: [ + { + "id": "mindmap_class", + "title": "Gemeinsame Mindmap", + "description": "Ergebnisse als Mindmap an der Tafel oder digital sammeln und strukturieren", + "activity_type": "visualization", + "estimated_minutes": 8, + "icon": "account_tree" + }, + { + "id": "exit_ticket", + "title": "Exit Ticket", + "description": "Schueler notieren 3 Dinge die sie gelernt haben und 1 offene Frage", 
+ "activity_type": "summary", + "estimated_minutes": 5, + "icon": "sticky_note_2" + }, + { + "id": "gallery_walk", + "title": "Galerie-Rundgang", + "description": "Schueler praesentieren ihre Ergebnisse und geben sich Feedback", + "activity_type": "presentation", + "estimated_minutes": 10, + "icon": "photo_library" + }, + { + "id": "key_points", + "title": "Kernpunkte zusammenfassen", + "description": "Gemeinsam die wichtigsten Erkenntnisse der Stunde formulieren", + "activity_type": "summary", + "estimated_minutes": 5, + "icon": "format_list_bulleted" + }, + { + "id": "quick_check", + "title": "Schneller Wissenscheck", + "description": "5 kurze Fragen zur Ueberpruefung des Verstaendnisses", + "activity_type": "documentation", + "estimated_minutes": 5, + "icon": "fact_check" + } + ], + LessonPhase.TRANSFER: [ + { + "id": "real_world_example", + "title": "Alltagsbeispiele finden", + "description": "Schueler suchen Beispiele aus ihrem Alltag, wo das Gelernte vorkommt", + "activity_type": "application", + "estimated_minutes": 5, + "icon": "public" + }, + { + "id": "challenge_task", + "title": "Knobelaufgabe", + "description": "Eine anspruchsvollere Aufgabe fuer schnelle Schueler oder als Bonus", + "activity_type": "differentiation", + "estimated_minutes": 7, + "icon": "psychology" + }, + { + "id": "creative_application", + "title": "Kreative Anwendung", + "description": "Schueler wenden das Gelernte in einem kreativen Projekt an", + "activity_type": "application", + "estimated_minutes": 10, + "icon": "palette" + }, + { + "id": "peer_teaching", + "title": "Peer-Teaching", + "description": "Schueler erklaeren sich gegenseitig das Gelernte", + "activity_type": "real_world_connection", + "estimated_minutes": 5, + "icon": "supervisor_account" + } + ], + LessonPhase.REFLEXION: [ + { + "id": "thumbs_feedback", + "title": "Daumen-Feedback", + "description": "Schnelle Stimmungsabfrage: Daumen hoch/mitte/runter", + "activity_type": "feedback", + "estimated_minutes": 2, + 
"icon": "thumb_up" + }, + { + "id": "homework_assign", + "title": "Hausaufgabe vergeben", + "description": "Passende Hausaufgabe zur Vertiefung des Gelernten", + "activity_type": "homework", + "estimated_minutes": 3, + "icon": "home_work" + }, + { + "id": "one_word", + "title": "Ein-Wort-Reflexion", + "description": "Jeder Schueler nennt ein Wort, das die Stunde beschreibt", + "activity_type": "feedback", + "estimated_minutes": 3, + "icon": "chat" + }, + { + "id": "preview_next", + "title": "Ausblick naechste Stunde", + "description": "Kurzer Ausblick auf das Thema der naechsten Stunde", + "activity_type": "preview", + "estimated_minutes": 2, + "icon": "event" + }, + { + "id": "learning_log", + "title": "Lerntagebuch", + "description": "Schueler notieren ihre wichtigsten Erkenntnisse im Lerntagebuch", + "activity_type": "feedback", + "estimated_minutes": 4, + "icon": "menu_book" + } + ] +} diff --git a/backend-lehrer/classroom_engine/suggestions.py b/backend-lehrer/classroom_engine/suggestions.py index 3aa148b..b8e36d2 100644 --- a/backend-lehrer/classroom_engine/suggestions.py +++ b/backend-lehrer/classroom_engine/suggestions.py @@ -8,490 +8,11 @@ und optional dem Fach. 
from typing import List, Dict, Any, Optional from .models import LessonPhase, LessonSession, PhaseSuggestion - - -# Unterstuetzte Faecher fuer fachspezifische Vorschlaege -SUPPORTED_SUBJECTS = [ - "mathematik", "mathe", "math", - "deutsch", - "englisch", "english", - "biologie", "bio", - "physik", - "chemie", - "geschichte", - "geografie", "erdkunde", - "kunst", - "musik", - "sport", - "informatik", -] - - -# Fachspezifische Vorschlaege (Feature f18) -SUBJECT_SUGGESTIONS: Dict[str, Dict[LessonPhase, List[Dict[str, Any]]]] = { - "mathematik": { - LessonPhase.EINSTIEG: [ - { - "id": "math_warm_up", - "title": "Kopfrechnen-Challenge", - "description": "5 schnelle Kopfrechenaufgaben zum Aufwaermen", - "activity_type": "warmup", - "estimated_minutes": 3, - "icon": "calculate", - "subjects": ["mathematik", "mathe"], - }, - { - "id": "math_puzzle", - "title": "Mathematisches Raetsel", - "description": "Ein kniffliges Zahlenraetsel als Einstieg", - "activity_type": "motivation", - "estimated_minutes": 5, - "icon": "extension", - "subjects": ["mathematik", "mathe"], - }, - ], - LessonPhase.ERARBEITUNG: [ - { - "id": "math_geogebra", - "title": "GeoGebra-Exploration", - "description": "Interaktive Visualisierung mit GeoGebra", - "activity_type": "individual_work", - "estimated_minutes": 15, - "icon": "functions", - "subjects": ["mathematik", "mathe"], - }, - { - "id": "math_peer_explain", - "title": "Rechenweg erklaeren", - "description": "Schueler erklaeren sich gegenseitig ihre Loesungswege", - "activity_type": "partner_work", - "estimated_minutes": 10, - "icon": "groups", - "subjects": ["mathematik", "mathe"], - }, - ], - LessonPhase.SICHERUNG: [ - { - "id": "math_formula_card", - "title": "Formelkarte erstellen", - "description": "Wichtigste Formeln auf einer Karte festhalten", - "activity_type": "documentation", - "estimated_minutes": 5, - "icon": "note_alt", - "subjects": ["mathematik", "mathe"], - }, - ], - }, - "deutsch": { - LessonPhase.EINSTIEG: [ - { - "id": 
"deutsch_wordle", - "title": "Wordle-Variante", - "description": "Wort des Tages erraten", - "activity_type": "warmup", - "estimated_minutes": 4, - "icon": "abc", - "subjects": ["deutsch"], - }, - { - "id": "deutsch_zitat", - "title": "Zitat-Interpretation", - "description": "Ein literarisches Zitat gemeinsam deuten", - "activity_type": "motivation", - "estimated_minutes": 5, - "icon": "format_quote", - "subjects": ["deutsch"], - }, - ], - LessonPhase.ERARBEITUNG: [ - { - "id": "deutsch_textarbeit", - "title": "Textanalyse in Gruppen", - "description": "Gruppenarbeit zu verschiedenen Textabschnitten", - "activity_type": "group_work", - "estimated_minutes": 15, - "icon": "menu_book", - "subjects": ["deutsch"], - }, - { - "id": "deutsch_schreibworkshop", - "title": "Schreibwerkstatt", - "description": "Kreatives Schreiben mit Peer-Feedback", - "activity_type": "individual_work", - "estimated_minutes": 20, - "icon": "edit_note", - "subjects": ["deutsch"], - }, - ], - LessonPhase.SICHERUNG: [ - { - "id": "deutsch_zusammenfassung", - "title": "Text-Zusammenfassung", - "description": "Die wichtigsten Punkte in 3 Saetzen formulieren", - "activity_type": "summary", - "estimated_minutes": 5, - "icon": "summarize", - "subjects": ["deutsch"], - }, - ], - }, - "englisch": { - LessonPhase.EINSTIEG: [ - { - "id": "english_smalltalk", - "title": "Small Talk Warm-Up", - "description": "2-Minuten Gespraeche zu einem Alltagsthema", - "activity_type": "warmup", - "estimated_minutes": 4, - "icon": "chat", - "subjects": ["englisch", "english"], - }, - { - "id": "english_video", - "title": "Authentic Video Clip", - "description": "Kurzer Clip aus einer englischen Serie oder Nachricht", - "activity_type": "motivation", - "estimated_minutes": 5, - "icon": "movie", - "subjects": ["englisch", "english"], - }, - ], - LessonPhase.ERARBEITUNG: [ - { - "id": "english_role_play", - "title": "Role Play Activity", - "description": "Dialoguebung in authentischen Situationen", - "activity_type": 
"partner_work", - "estimated_minutes": 12, - "icon": "theater_comedy", - "subjects": ["englisch", "english"], - }, - { - "id": "english_reading_circle", - "title": "Reading Circle", - "description": "Gemeinsames Lesen mit verteilten Rollen", - "activity_type": "group_work", - "estimated_minutes": 15, - "icon": "auto_stories", - "subjects": ["englisch", "english"], - }, - ], - }, - "biologie": { - LessonPhase.EINSTIEG: [ - { - "id": "bio_nature_question", - "title": "Naturfrage", - "description": "Eine spannende Frage aus der Natur diskutieren", - "activity_type": "motivation", - "estimated_minutes": 5, - "icon": "eco", - "subjects": ["biologie", "bio"], - }, - ], - LessonPhase.ERARBEITUNG: [ - { - "id": "bio_experiment", - "title": "Mini-Experiment", - "description": "Einfaches Experiment zum Thema durchfuehren", - "activity_type": "group_work", - "estimated_minutes": 20, - "icon": "science", - "subjects": ["biologie", "bio"], - }, - { - "id": "bio_diagram", - "title": "Biologische Zeichnung", - "description": "Beschriftete Zeichnung eines Organismus", - "activity_type": "individual_work", - "estimated_minutes": 15, - "icon": "draw", - "subjects": ["biologie", "bio"], - }, - ], - }, - "physik": { - LessonPhase.EINSTIEG: [ - { - "id": "physik_demo", - "title": "Phaenomen-Demo", - "description": "Ein physikalisches Phaenomen vorfuehren", - "activity_type": "motivation", - "estimated_minutes": 5, - "icon": "bolt", - "subjects": ["physik"], - }, - ], - LessonPhase.ERARBEITUNG: [ - { - "id": "physik_simulation", - "title": "PhET-Simulation", - "description": "Interaktive Simulation von phet.colorado.edu", - "activity_type": "individual_work", - "estimated_minutes": 15, - "icon": "smart_toy", - "subjects": ["physik"], - }, - { - "id": "physik_rechnung", - "title": "Physikalische Rechnung", - "description": "Rechenaufgabe mit physikalischem Kontext", - "activity_type": "partner_work", - "estimated_minutes": 12, - "icon": "calculate", - "subjects": ["physik"], - }, - ], - 
}, - "informatik": { - LessonPhase.EINSTIEG: [ - { - "id": "info_code_puzzle", - "title": "Code-Puzzle", - "description": "Kurzen Code-Schnipsel analysieren - was macht er?", - "activity_type": "warmup", - "estimated_minutes": 4, - "icon": "code", - "subjects": ["informatik"], - }, - ], - LessonPhase.ERARBEITUNG: [ - { - "id": "info_live_coding", - "title": "Live Coding", - "description": "Gemeinsam Code entwickeln mit Erklaerungen", - "activity_type": "instruction", - "estimated_minutes": 15, - "icon": "terminal", - "subjects": ["informatik"], - }, - { - "id": "info_pair_programming", - "title": "Pair Programming", - "description": "Zu zweit programmieren - Driver und Navigator", - "activity_type": "partner_work", - "estimated_minutes": 20, - "icon": "computer", - "subjects": ["informatik"], - }, - ], - }, -} - - -# Vordefinierte allgemeine Vorschlaege pro Phase -PHASE_SUGGESTIONS: Dict[LessonPhase, List[Dict[str, Any]]] = { - LessonPhase.EINSTIEG: [ - { - "id": "warmup_quiz", - "title": "Kurzes Quiz zum Einstieg", - "description": "Aktivieren Sie das Vorwissen der Schueler mit 3-5 Fragen zum Thema", - "activity_type": "warmup", - "estimated_minutes": 3, - "icon": "quiz" - }, - { - "id": "problem_story", - "title": "Problemgeschichte erzaehlen", - "description": "Stellen Sie ein alltagsnahes Problem vor, das zum Thema fuehrt", - "activity_type": "motivation", - "estimated_minutes": 5, - "icon": "auto_stories" - }, - { - "id": "video_intro", - "title": "Kurzes Erklaervideo", - "description": "Zeigen Sie ein 2-3 Minuten Video zur Einfuehrung ins Thema", - "activity_type": "motivation", - "estimated_minutes": 4, - "icon": "play_circle" - }, - { - "id": "brainstorming", - "title": "Brainstorming", - "description": "Sammeln Sie Ideen und Vorkenntnisse der Schueler an der Tafel", - "activity_type": "warmup", - "estimated_minutes": 5, - "icon": "psychology" - }, - { - "id": "daily_challenge", - "title": "Tagesaufgabe vorstellen", - "description": "Praesentieren Sie die 
zentrale Frage oder Aufgabe der Stunde", - "activity_type": "problem_introduction", - "estimated_minutes": 3, - "icon": "flag" - } - ], - LessonPhase.ERARBEITUNG: [ - { - "id": "think_pair_share", - "title": "Think-Pair-Share", - "description": "Schueler denken erst einzeln nach, tauschen sich dann zu zweit aus und praesentieren im Plenum", - "activity_type": "partner_work", - "estimated_minutes": 10, - "icon": "groups" - }, - { - "id": "worksheet_digital", - "title": "Digitales Arbeitsblatt", - "description": "Schueler bearbeiten ein interaktives Arbeitsblatt am Tablet oder Computer", - "activity_type": "individual_work", - "estimated_minutes": 15, - "icon": "description" - }, - { - "id": "station_learning", - "title": "Stationenlernen", - "description": "Verschiedene Stationen mit unterschiedlichen Aufgaben und Materialien", - "activity_type": "group_work", - "estimated_minutes": 20, - "icon": "hub" - }, - { - "id": "expert_puzzle", - "title": "Expertenrunde (Jigsaw)", - "description": "Schueler werden Experten fuer ein Teilthema und lehren es anderen", - "activity_type": "group_work", - "estimated_minutes": 15, - "icon": "extension" - }, - { - "id": "guided_instruction", - "title": "Geleitete Instruktion", - "description": "Schrittweise Erklaerung mit Uebungsphasen zwischendurch", - "activity_type": "instruction", - "estimated_minutes": 12, - "icon": "school" - }, - { - "id": "pair_programming", - "title": "Partnerarbeit", - "description": "Zwei Schueler loesen gemeinsam eine Aufgabe", - "activity_type": "partner_work", - "estimated_minutes": 10, - "icon": "people" - } - ], - LessonPhase.SICHERUNG: [ - { - "id": "mindmap_class", - "title": "Gemeinsame Mindmap", - "description": "Ergebnisse als Mindmap an der Tafel oder digital sammeln und strukturieren", - "activity_type": "visualization", - "estimated_minutes": 8, - "icon": "account_tree" - }, - { - "id": "exit_ticket", - "title": "Exit Ticket", - "description": "Schueler notieren 3 Dinge die sie gelernt haben 
und 1 offene Frage", - "activity_type": "summary", - "estimated_minutes": 5, - "icon": "sticky_note_2" - }, - { - "id": "gallery_walk", - "title": "Galerie-Rundgang", - "description": "Schueler praesentieren ihre Ergebnisse und geben sich Feedback", - "activity_type": "presentation", - "estimated_minutes": 10, - "icon": "photo_library" - }, - { - "id": "key_points", - "title": "Kernpunkte zusammenfassen", - "description": "Gemeinsam die wichtigsten Erkenntnisse der Stunde formulieren", - "activity_type": "summary", - "estimated_minutes": 5, - "icon": "format_list_bulleted" - }, - { - "id": "quick_check", - "title": "Schneller Wissenscheck", - "description": "5 kurze Fragen zur Ueberpruefung des Verstaendnisses", - "activity_type": "documentation", - "estimated_minutes": 5, - "icon": "fact_check" - } - ], - LessonPhase.TRANSFER: [ - { - "id": "real_world_example", - "title": "Alltagsbeispiele finden", - "description": "Schueler suchen Beispiele aus ihrem Alltag, wo das Gelernte vorkommt", - "activity_type": "application", - "estimated_minutes": 5, - "icon": "public" - }, - { - "id": "challenge_task", - "title": "Knobelaufgabe", - "description": "Eine anspruchsvollere Aufgabe fuer schnelle Schueler oder als Bonus", - "activity_type": "differentiation", - "estimated_minutes": 7, - "icon": "psychology" - }, - { - "id": "creative_application", - "title": "Kreative Anwendung", - "description": "Schueler wenden das Gelernte in einem kreativen Projekt an", - "activity_type": "application", - "estimated_minutes": 10, - "icon": "palette" - }, - { - "id": "peer_teaching", - "title": "Peer-Teaching", - "description": "Schueler erklaeren sich gegenseitig das Gelernte", - "activity_type": "real_world_connection", - "estimated_minutes": 5, - "icon": "supervisor_account" - } - ], - LessonPhase.REFLEXION: [ - { - "id": "thumbs_feedback", - "title": "Daumen-Feedback", - "description": "Schnelle Stimmungsabfrage: Daumen hoch/mitte/runter", - "activity_type": "feedback", - 
"estimated_minutes": 2, - "icon": "thumb_up" - }, - { - "id": "homework_assign", - "title": "Hausaufgabe vergeben", - "description": "Passende Hausaufgabe zur Vertiefung des Gelernten", - "activity_type": "homework", - "estimated_minutes": 3, - "icon": "home_work" - }, - { - "id": "one_word", - "title": "Ein-Wort-Reflexion", - "description": "Jeder Schueler nennt ein Wort, das die Stunde beschreibt", - "activity_type": "feedback", - "estimated_minutes": 3, - "icon": "chat" - }, - { - "id": "preview_next", - "title": "Ausblick naechste Stunde", - "description": "Kurzer Ausblick auf das Thema der naechsten Stunde", - "activity_type": "preview", - "estimated_minutes": 2, - "icon": "event" - }, - { - "id": "learning_log", - "title": "Lerntagebuch", - "description": "Schueler notieren ihre wichtigsten Erkenntnisse im Lerntagebuch", - "activity_type": "feedback", - "estimated_minutes": 4, - "icon": "menu_book" - } - ] -} +from .suggestion_data import ( + SUPPORTED_SUBJECTS, + SUBJECT_SUGGESTIONS, + PHASE_SUGGESTIONS, +) class SuggestionEngine: diff --git a/backend-lehrer/content_generators/__init__.py b/backend-lehrer/content_generators/__init__.py index 2313c04..2292f89 100644 --- a/backend-lehrer/content_generators/__init__.py +++ b/backend-lehrer/content_generators/__init__.py @@ -11,10 +11,12 @@ from .h5p_generator import ( generate_h5p_manifest, ) -from .pdf_generator import ( - PDFGenerator, +from .worksheet_models import ( Worksheet, WorksheetSection, +) +from .pdf_generator import ( + PDFGenerator, generate_worksheet_html, generate_worksheet_pdf, ) diff --git a/backend-lehrer/content_generators/pdf_generator.py b/backend-lehrer/content_generators/pdf_generator.py index d41c442..c036732 100644 --- a/backend-lehrer/content_generators/pdf_generator.py +++ b/backend-lehrer/content_generators/pdf_generator.py @@ -12,252 +12,9 @@ Structure: 6. 
Reflection Questions """ -import io -from dataclasses import dataclass -from typing import Any, Optional, Union +from typing import Optional, Union -# Note: In production, use reportlab or weasyprint for actual PDF generation -# This module generates an intermediate format that can be converted to PDF - - -@dataclass -class WorksheetSection: - """A section of the worksheet""" - title: str - content_type: str # "text", "table", "exercises", "blanks" - content: Any - difficulty: int = 1 # 1-4 - - -@dataclass -class Worksheet: - """Complete worksheet structure""" - title: str - subtitle: str - unit_id: str - locale: str - sections: list[WorksheetSection] - footer: str = "" - - def to_html(self) -> str: - """Convert worksheet to HTML (for PDF conversion via weasyprint)""" - html_parts = [ - "", - "", - "", - "", - "", - "", - "", - f"

{self.title}

", - f"

{self.subtitle}

", - ] - - for section in self.sections: - html_parts.append(self._render_section(section)) - - html_parts.extend([ - f"
{self.footer}
", - "", - "" - ]) - - return "\n".join(html_parts) - - def _get_styles(self) -> str: - return """ - @page { - size: A4; - margin: 2cm; - } - body { - font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif; - font-size: 11pt; - line-height: 1.5; - color: #333; - } - header { - text-align: center; - margin-bottom: 1.5em; - border-bottom: 2px solid #2c5282; - padding-bottom: 1em; - } - h1 { - color: #2c5282; - margin-bottom: 0.25em; - font-size: 20pt; - } - .subtitle { - color: #666; - font-style: italic; - } - h2 { - color: #2c5282; - border-bottom: 1px solid #e2e8f0; - padding-bottom: 0.25em; - margin-top: 1.5em; - font-size: 14pt; - } - h3 { - color: #4a5568; - font-size: 12pt; - } - table { - width: 100%; - border-collapse: collapse; - margin: 1em 0; - } - th, td { - border: 1px solid #e2e8f0; - padding: 0.5em; - text-align: left; - } - th { - background-color: #edf2f7; - font-weight: bold; - } - .exercise { - margin: 1em 0; - padding: 1em; - background-color: #f7fafc; - border-left: 4px solid #4299e1; - } - .exercise-number { - font-weight: bold; - color: #2c5282; - } - .blank { - display: inline-block; - min-width: 100px; - border-bottom: 1px solid #333; - margin: 0 0.25em; - } - .difficulty { - font-size: 9pt; - color: #718096; - } - .difficulty-1 { color: #48bb78; } - .difficulty-2 { color: #4299e1; } - .difficulty-3 { color: #ed8936; } - .difficulty-4 { color: #f56565; } - .reflection { - margin-top: 2em; - padding: 1em; - background-color: #fffaf0; - border: 1px dashed #ed8936; - } - .write-area { - min-height: 80px; - border: 1px solid #e2e8f0; - margin: 0.5em 0; - background-color: #fff; - } - footer { - margin-top: 2em; - padding-top: 1em; - border-top: 1px solid #e2e8f0; - font-size: 9pt; - color: #718096; - text-align: center; - } - ul, ol { - margin: 0.5em 0; - padding-left: 1.5em; - } - .objectives { - background-color: #ebf8ff; - padding: 1em; - border-radius: 4px; - } - """ - - def _render_section(self, section: WorksheetSection) -> str: - 
parts = [f"

{section.title}

"] - - if section.content_type == "text": - parts.append(f"

{section.content}

") - - elif section.content_type == "objectives": - parts.append("
    ") - for obj in section.content: - parts.append(f"
  • {obj}
  • ") - parts.append("
") - - elif section.content_type == "table": - parts.append("") - for header in section.content.get("headers", []): - parts.append(f"") - parts.append("") - for row in section.content.get("rows", []): - parts.append("") - for cell in row: - parts.append(f"") - parts.append("") - parts.append("
{header}
{cell}
") - - elif section.content_type == "exercises": - for i, ex in enumerate(section.content, 1): - diff_class = f"difficulty-{ex.get('difficulty', 1)}" - diff_stars = "*" * ex.get("difficulty", 1) - parts.append(f""" -
- Aufgabe {i} - ({diff_stars}) -

{ex.get('question', '')}

- {self._render_exercise_input(ex)} -
- """) - - elif section.content_type == "blanks": - text = section.content - # Replace *word* with blank - import re - text = re.sub(r'\*([^*]+)\*', r"", text) - parts.append(f"

{text}

") - - elif section.content_type == "reflection": - parts.append("
") - parts.append(f"

{section.content.get('prompt', '')}

") - parts.append("
") - parts.append("
") - - parts.append("
") - return "\n".join(parts) - - def _render_exercise_input(self, exercise: dict) -> str: - ex_type = exercise.get("type", "text") - - if ex_type == "multiple_choice": - options = exercise.get("options", []) - parts = ["
    "] - for opt in options: - parts.append(f"
  • □ {opt}
  • ") - parts.append("
") - return "\n".join(parts) - - elif ex_type == "matching": - left = exercise.get("left", []) - right = exercise.get("right", []) - parts = [""] - for i, item in enumerate(left): - right_item = right[i] if i < len(right) else "" - parts.append(f"") - parts.append("
BegriffZuordnung
{item}
") - return "\n".join(parts) - - elif ex_type == "sequence": - items = exercise.get("items", []) - parts = ["

Bringe in die richtige Reihenfolge:

    "] - for item in items: - parts.append(f"
  1. ") - parts.append("
") - parts.append(f"

Begriffe: {', '.join(items)}

") - return "\n".join(parts) - - else: - return "
" +from .worksheet_models import Worksheet, WorksheetSection class PDFGenerator: @@ -267,15 +24,7 @@ class PDFGenerator: self.locale = locale def generate_from_unit(self, unit: dict) -> Worksheet: - """ - Generate a worksheet from a unit definition. - - Args: - unit: Unit definition dictionary - - Returns: - Worksheet object - """ + """Generate a worksheet from a unit definition.""" unit_id = unit.get("unit_id", "unknown") title = self._get_localized(unit.get("title"), "Arbeitsblatt") objectives = unit.get("learning_objectives", []) @@ -283,51 +32,36 @@ class PDFGenerator: sections = [] - # Learning Objectives if objectives: sections.append(WorksheetSection( - title="Lernziele", - content_type="objectives", - content=objectives + title="Lernziele", content_type="objectives", content=objectives )) - # Vocabulary Table vocab_section = self._create_vocabulary_section(stops) if vocab_section: sections.append(vocab_section) - # Key Concepts Summary concepts_section = self._create_concepts_section(stops) if concepts_section: sections.append(concepts_section) - # Basic Exercises basic_exercises = self._create_basic_exercises(stops) if basic_exercises: sections.append(WorksheetSection( - title="Ubungen - Basis", - content_type="exercises", - content=basic_exercises, - difficulty=1 + title="Ubungen - Basis", content_type="exercises", + content=basic_exercises, difficulty=1 )) - # Challenge Exercises challenge_exercises = self._create_challenge_exercises(stops, unit) if challenge_exercises: sections.append(WorksheetSection( - title="Ubungen - Herausforderung", - content_type="exercises", - content=challenge_exercises, - difficulty=3 + title="Ubungen - Herausforderung", content_type="exercises", + content=challenge_exercises, difficulty=3 )) - # Reflection sections.append(WorksheetSection( - title="Reflexion", - content_type="reflection", - content={ - "prompt": "Erklaere in eigenen Worten, was du heute gelernt hast:" - } + title="Reflexion", content_type="reflection", + 
content={"prompt": "Erklaere in eigenen Worten, was du heute gelernt hast:"} )) return Worksheet( @@ -370,12 +104,8 @@ class PDFGenerator: return None return WorksheetSection( - title="Wichtige Begriffe", - content_type="table", - content={ - "headers": ["Begriff", "Erklarung"], - "rows": rows - } + title="Wichtige Begriffe", content_type="table", + content={"headers": ["Begriff", "Erklarung"], "rows": rows} ) def _create_concepts_section(self, stops: list) -> Optional[WorksheetSection]: @@ -392,19 +122,14 @@ class PDFGenerator: return None return WorksheetSection( - title="Zusammenfassung", - content_type="table", - content={ - "headers": ["Station", "Was hast du gelernt?"], - "rows": rows - } + title="Zusammenfassung", content_type="table", + content={"headers": ["Station", "Was hast du gelernt?"], "rows": rows} ) def _create_basic_exercises(self, stops: list) -> list[dict]: """Create basic difficulty exercises""" exercises = [] - # Vocabulary matching vocab_items = [] for stop in stops: for v in stop.get("vocab", []): @@ -422,7 +147,6 @@ class PDFGenerator: "difficulty": 1 }) - # True/False from concepts for stop in stops[:3]: concept = stop.get("concept", {}) why = self._get_localized(concept.get("why")) @@ -435,7 +159,6 @@ class PDFGenerator: }) break - # Sequence ordering (for FlightPath) if len(stops) >= 4: labels = [self._get_localized(s.get("label")) for s in stops[:6] if self._get_localized(s.get("label"))] if len(labels) >= 4: @@ -455,7 +178,6 @@ class PDFGenerator: """Create challenging exercises""" exercises = [] - # Misconception identification for stop in stops: concept = stop.get("concept", {}) misconception = self._get_localized(concept.get("common_misconception")) @@ -472,14 +194,12 @@ class PDFGenerator: if len(exercises) >= 2: break - # Transfer/Application question exercises.append({ "type": "text", "question": "Erklaere einem Freund in 2-3 Satzen, was du gelernt hast:", "difficulty": 3 }) - # Critical thinking exercises.append({ "type": 
"text", "question": "Was moechtest du noch mehr uber dieses Thema erfahren?", @@ -490,35 +210,14 @@ class PDFGenerator: def generate_worksheet_html(unit_definition: dict, locale: str = "de-DE") -> str: - """ - Generate HTML worksheet from unit definition. - - Args: - unit_definition: The unit JSON definition - locale: Target locale for content - - Returns: - HTML string ready for PDF conversion - """ + """Generate HTML worksheet from unit definition.""" generator = PDFGenerator(locale=locale) worksheet = generator.generate_from_unit(unit_definition) return worksheet.to_html() def generate_worksheet_pdf(unit_definition: dict, locale: str = "de-DE") -> bytes: - """ - Generate PDF worksheet from unit definition. - - Requires weasyprint to be installed: - pip install weasyprint - - Args: - unit_definition: The unit JSON definition - locale: Target locale for content - - Returns: - PDF bytes - """ + """Generate PDF worksheet from unit definition.""" try: from weasyprint import HTML except ImportError: diff --git a/backend-lehrer/content_generators/worksheet_models.py b/backend-lehrer/content_generators/worksheet_models.py new file mode 100644 index 0000000..e245f5b --- /dev/null +++ b/backend-lehrer/content_generators/worksheet_models.py @@ -0,0 +1,247 @@ +""" +Worksheet Models - Datenstrukturen und HTML-Rendering fuer Arbeitsblaetter. +""" + +import re +from dataclasses import dataclass +from typing import Any + + +@dataclass +class WorksheetSection: + """A section of the worksheet""" + title: str + content_type: str # "text", "table", "exercises", "blanks" + content: Any + difficulty: int = 1 # 1-4 + + +@dataclass +class Worksheet: + """Complete worksheet structure""" + title: str + subtitle: str + unit_id: str + locale: str + sections: list[WorksheetSection] + footer: str = "" + + def to_html(self) -> str: + """Convert worksheet to HTML (for PDF conversion via weasyprint)""" + html_parts = [ + "", + "", + "", + "", + "", + "", + "", + f"

{self.title}

", + f"

{self.subtitle}

", + ] + + for section in self.sections: + html_parts.append(_render_section(section)) + + html_parts.extend([ + f"
{self.footer}
", + "", + "" + ]) + + return "\n".join(html_parts) + + +def _get_styles() -> str: + return """ + @page { + size: A4; + margin: 2cm; + } + body { + font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif; + font-size: 11pt; + line-height: 1.5; + color: #333; + } + header { + text-align: center; + margin-bottom: 1.5em; + border-bottom: 2px solid #2c5282; + padding-bottom: 1em; + } + h1 { + color: #2c5282; + margin-bottom: 0.25em; + font-size: 20pt; + } + .subtitle { + color: #666; + font-style: italic; + } + h2 { + color: #2c5282; + border-bottom: 1px solid #e2e8f0; + padding-bottom: 0.25em; + margin-top: 1.5em; + font-size: 14pt; + } + h3 { + color: #4a5568; + font-size: 12pt; + } + table { + width: 100%; + border-collapse: collapse; + margin: 1em 0; + } + th, td { + border: 1px solid #e2e8f0; + padding: 0.5em; + text-align: left; + } + th { + background-color: #edf2f7; + font-weight: bold; + } + .exercise { + margin: 1em 0; + padding: 1em; + background-color: #f7fafc; + border-left: 4px solid #4299e1; + } + .exercise-number { + font-weight: bold; + color: #2c5282; + } + .blank { + display: inline-block; + min-width: 100px; + border-bottom: 1px solid #333; + margin: 0 0.25em; + } + .difficulty { + font-size: 9pt; + color: #718096; + } + .difficulty-1 { color: #48bb78; } + .difficulty-2 { color: #4299e1; } + .difficulty-3 { color: #ed8936; } + .difficulty-4 { color: #f56565; } + .reflection { + margin-top: 2em; + padding: 1em; + background-color: #fffaf0; + border: 1px dashed #ed8936; + } + .write-area { + min-height: 80px; + border: 1px solid #e2e8f0; + margin: 0.5em 0; + background-color: #fff; + } + footer { + margin-top: 2em; + padding-top: 1em; + border-top: 1px solid #e2e8f0; + font-size: 9pt; + color: #718096; + text-align: center; + } + ul, ol { + margin: 0.5em 0; + padding-left: 1.5em; + } + .objectives { + background-color: #ebf8ff; + padding: 1em; + border-radius: 4px; + } + """ + + +def _render_section(section: WorksheetSection) -> str: + parts = [f"

{section.title}

"] + + if section.content_type == "text": + parts.append(f"

{section.content}

") + + elif section.content_type == "objectives": + parts.append("
    ") + for obj in section.content: + parts.append(f"
  • {obj}
  • ") + parts.append("
") + + elif section.content_type == "table": + parts.append("") + for header in section.content.get("headers", []): + parts.append(f"") + parts.append("") + for row in section.content.get("rows", []): + parts.append("") + for cell in row: + parts.append(f"") + parts.append("") + parts.append("
{header}
{cell}
") + + elif section.content_type == "exercises": + for i, ex in enumerate(section.content, 1): + diff_class = f"difficulty-{ex.get('difficulty', 1)}" + diff_stars = "*" * ex.get("difficulty", 1) + parts.append(f""" +
+ Aufgabe {i} + ({diff_stars}) +

{ex.get('question', '')}

+ {_render_exercise_input(ex)} +
+ """) + + elif section.content_type == "blanks": + text = section.content + text = re.sub(r'\*([^*]+)\*', r"", text) + parts.append(f"

{text}

") + + elif section.content_type == "reflection": + parts.append("
") + parts.append(f"

{section.content.get('prompt', '')}

") + parts.append("
") + parts.append("
") + + parts.append("
") + return "\n".join(parts) + + +def _render_exercise_input(exercise: dict) -> str: + ex_type = exercise.get("type", "text") + + if ex_type == "multiple_choice": + options = exercise.get("options", []) + parts = ["
    "] + for opt in options: + parts.append(f"
  • □ {opt}
  • ") + parts.append("
") + return "\n".join(parts) + + elif ex_type == "matching": + left = exercise.get("left", []) + right = exercise.get("right", []) + parts = [""] + for i, item in enumerate(left): + parts.append(f"") + parts.append("
BegriffZuordnung
{item}
") + return "\n".join(parts) + + elif ex_type == "sequence": + items = exercise.get("items", []) + parts = ["

Bringe in die richtige Reihenfolge:

    "] + for item in items: + parts.append(f"
  1. ") + parts.append("
") + parts.append(f"

Begriffe: {', '.join(items)}

") + return "\n".join(parts) + + else: + return "
" diff --git a/backend-lehrer/generators/quiz_generator.py b/backend-lehrer/generators/quiz_generator.py index 3f4b6e5..4732cd1 100644 --- a/backend-lehrer/generators/quiz_generator.py +++ b/backend-lehrer/generators/quiz_generator.py @@ -10,66 +10,27 @@ Generiert: import logging import json -import re -from typing import List, Dict, Any, Optional, Tuple -from dataclasses import dataclass -from enum import Enum +from typing import List, Dict, Any, Optional + +from .quiz_models import ( + QuizType, + TrueFalseQuestion, + MatchingPair, + SortingItem, + OpenQuestion, + Quiz, +) +from .quiz_helpers import ( + extract_factual_sentences, + negate_sentence, + extract_definitions, + extract_sequence, + extract_keywords, +) logger = logging.getLogger(__name__) -class QuizType(str, Enum): - """Typen von Quiz-Aufgaben.""" - TRUE_FALSE = "true_false" - MATCHING = "matching" - SORTING = "sorting" - OPEN_ENDED = "open_ended" - - -@dataclass -class TrueFalseQuestion: - """Eine Wahr/Falsch-Frage.""" - statement: str - is_true: bool - explanation: str - source_reference: Optional[str] = None - - -@dataclass -class MatchingPair: - """Ein Zuordnungspaar.""" - left: str - right: str - hint: Optional[str] = None - - -@dataclass -class SortingItem: - """Ein Element zum Sortieren.""" - text: str - correct_position: int - category: Optional[str] = None - - -@dataclass -class OpenQuestion: - """Eine offene Frage.""" - question: str - model_answer: str - keywords: List[str] - points: int = 1 - - -@dataclass -class Quiz: - """Ein komplettes Quiz.""" - quiz_type: QuizType - title: str - questions: List[Any] # Je nach Typ unterschiedlich - topic: Optional[str] = None - difficulty: str = "medium" - - class QuizGenerator: """ Generiert verschiedene Quiz-Typen aus Quelltexten. 
@@ -146,13 +107,12 @@ class QuizGenerator: return self._generate_true_false_llm(source_text, num_questions, difficulty) # Automatische Generierung - sentences = self._extract_factual_sentences(source_text) + sentences = extract_factual_sentences(source_text) questions = [] for i, sentence in enumerate(sentences[:num_questions]): # Abwechselnd wahre und falsche Aussagen if i % 2 == 0: - # Wahre Aussage questions.append(TrueFalseQuestion( statement=sentence, is_true=True, @@ -160,8 +120,7 @@ class QuizGenerator: source_reference=sentence[:50] )) else: - # Falsche Aussage (Negation) - false_statement = self._negate_sentence(sentence) + false_statement = negate_sentence(sentence) questions.append(TrueFalseQuestion( statement=false_statement, is_true=False, @@ -222,9 +181,8 @@ Antworte im JSON-Format: if self.llm_client: return self._generate_matching_llm(source_text, num_pairs, difficulty) - # Automatische Generierung: Begriff -> Definition pairs = [] - definitions = self._extract_definitions(source_text) + definitions = extract_definitions(source_text) for term, definition in definitions[:num_pairs]: pairs.append(MatchingPair( @@ -286,9 +244,8 @@ Antworte im JSON-Format: if self.llm_client: return self._generate_sorting_llm(source_text, num_items, difficulty) - # Automatische Generierung: Chronologische Reihenfolge items = [] - steps = self._extract_sequence(source_text) + steps = extract_sequence(source_text) for i, step in enumerate(steps[:num_items]): items.append(SortingItem( @@ -349,9 +306,8 @@ Antworte im JSON-Format: if self.llm_client: return self._generate_open_ended_llm(source_text, num_questions, difficulty) - # Automatische Generierung questions = [] - sentences = self._extract_factual_sentences(source_text) + sentences = extract_factual_sentences(source_text) question_starters = [ "Was bedeutet", @@ -362,8 +318,7 @@ Antworte im JSON-Format: ] for i, sentence in enumerate(sentences[:num_questions]): - # Extrahiere Schlüsselwort - keywords = 
self._extract_keywords(sentence) + keywords = extract_keywords(sentence) if keywords: keyword = keywords[0] starter = question_starters[i % len(question_starters)] @@ -421,76 +376,6 @@ Antworte im JSON-Format: logger.error(f"LLM error: {e}") return self._generate_open_ended(source_text, num_questions, difficulty) - # Hilfsmethoden - - def _extract_factual_sentences(self, text: str) -> List[str]: - """Extrahiert Fakten-Sätze aus dem Text.""" - sentences = re.split(r'[.!?]+', text) - factual = [] - - for sentence in sentences: - sentence = sentence.strip() - # Filtere zu kurze oder fragende Sätze - if len(sentence) > 20 and '?' not in sentence: - factual.append(sentence) - - return factual - - def _negate_sentence(self, sentence: str) -> str: - """Negiert eine Aussage einfach.""" - # Einfache Negation durch Einfügen von "nicht" - words = sentence.split() - if len(words) > 2: - # Nach erstem Verb "nicht" einfügen - for i, word in enumerate(words): - if word.endswith(('t', 'en', 'st')) and i > 0: - words.insert(i + 1, 'nicht') - break - return ' '.join(words) - - def _extract_definitions(self, text: str) -> List[Tuple[str, str]]: - """Extrahiert Begriff-Definition-Paare.""" - definitions = [] - - # Suche nach Mustern wie "X ist Y" oder "X bezeichnet Y" - patterns = [ - r'(\w+)\s+ist\s+(.+?)[.]', - r'(\w+)\s+bezeichnet\s+(.+?)[.]', - r'(\w+)\s+bedeutet\s+(.+?)[.]', - r'(\w+):\s+(.+?)[.]', - ] - - for pattern in patterns: - matches = re.findall(pattern, text) - for term, definition in matches: - if len(definition) > 10: - definitions.append((term, definition.strip())) - - return definitions - - def _extract_sequence(self, text: str) -> List[str]: - """Extrahiert eine Sequenz von Schritten.""" - steps = [] - - # Suche nach nummerierten Schritten - numbered = re.findall(r'\d+[.)]\s*([^.]+)', text) - steps.extend(numbered) - - # Suche nach Signalwörtern - signal_words = ['zuerst', 'dann', 'danach', 'anschließend', 'schließlich'] - for word in signal_words: - pattern = 
rf'{word}\s+([^.]+)' - matches = re.findall(pattern, text, re.IGNORECASE) - steps.extend(matches) - - return steps - - def _extract_keywords(self, text: str) -> List[str]: - """Extrahiert Schlüsselwörter.""" - # Längere Wörter mit Großbuchstaben (meist Substantive) - words = re.findall(r'\b[A-ZÄÖÜ][a-zäöüß]+\b', text) - return list(set(words))[:5] - def _empty_quiz(self, quiz_type: QuizType, title: str) -> Quiz: """Erstellt leeres Quiz bei Fehler.""" return Quiz( @@ -549,7 +434,6 @@ Antworte im JSON-Format: return self._true_false_to_h5p(quiz) elif quiz.quiz_type == QuizType.MATCHING: return self._matching_to_h5p(quiz) - # Weitere Typen... return {} def _true_false_to_h5p(self, quiz: Quiz) -> Dict[str, Any]: diff --git a/backend-lehrer/generators/quiz_helpers.py b/backend-lehrer/generators/quiz_helpers.py new file mode 100644 index 0000000..650dc76 --- /dev/null +++ b/backend-lehrer/generators/quiz_helpers.py @@ -0,0 +1,70 @@ +""" +Quiz Helpers - Text-Verarbeitungs-Hilfsfunktionen fuer Quiz-Generierung. +""" + +import re +from typing import List, Tuple + + +def extract_factual_sentences(text: str) -> List[str]: + """Extrahiert Fakten-Sätze aus dem Text.""" + sentences = re.split(r'[.!?]+', text) + factual = [] + + for sentence in sentences: + sentence = sentence.strip() + if len(sentence) > 20 and '?' 
not in sentence: + factual.append(sentence) + + return factual + + +def negate_sentence(sentence: str) -> str: + """Negiert eine Aussage einfach.""" + words = sentence.split() + if len(words) > 2: + for i, word in enumerate(words): + if word.endswith(('t', 'en', 'st')) and i > 0: + words.insert(i + 1, 'nicht') + break + return ' '.join(words) + + +def extract_definitions(text: str) -> List[Tuple[str, str]]: + """Extrahiert Begriff-Definition-Paare.""" + definitions = [] + patterns = [ + r'(\w+)\s+ist\s+(.+?)[.]', + r'(\w+)\s+bezeichnet\s+(.+?)[.]', + r'(\w+)\s+bedeutet\s+(.+?)[.]', + r'(\w+):\s+(.+?)[.]', + ] + + for pattern in patterns: + matches = re.findall(pattern, text) + for term, definition in matches: + if len(definition) > 10: + definitions.append((term, definition.strip())) + + return definitions + + +def extract_sequence(text: str) -> List[str]: + """Extrahiert eine Sequenz von Schritten.""" + steps = [] + numbered = re.findall(r'\d+[.)]\s*([^.]+)', text) + steps.extend(numbered) + + signal_words = ['zuerst', 'dann', 'danach', 'anschließend', 'schließlich'] + for word in signal_words: + pattern = rf'{word}\s+([^.]+)' + matches = re.findall(pattern, text, re.IGNORECASE) + steps.extend(matches) + + return steps + + +def extract_keywords(text: str) -> List[str]: + """Extrahiert Schlüsselwörter.""" + words = re.findall(r'\b[A-ZÄÖÜ][a-zäöüß]+\b', text) + return list(set(words))[:5] diff --git a/backend-lehrer/generators/quiz_models.py b/backend-lehrer/generators/quiz_models.py new file mode 100644 index 0000000..d466811 --- /dev/null +++ b/backend-lehrer/generators/quiz_models.py @@ -0,0 +1,65 @@ +""" +Quiz Models - Datenmodelle fuer Quiz-Generierung. 
+ +Enthaelt alle Dataclasses und Enums fuer Quiz-Typen: +- True/False Fragen +- Zuordnungsaufgaben (Matching) +- Sortieraufgaben +- Offene Fragen +""" + +from typing import List, Any, Optional +from dataclasses import dataclass +from enum import Enum + + +class QuizType(str, Enum): + """Typen von Quiz-Aufgaben.""" + TRUE_FALSE = "true_false" + MATCHING = "matching" + SORTING = "sorting" + OPEN_ENDED = "open_ended" + + +@dataclass +class TrueFalseQuestion: + """Eine Wahr/Falsch-Frage.""" + statement: str + is_true: bool + explanation: str + source_reference: Optional[str] = None + + +@dataclass +class MatchingPair: + """Ein Zuordnungspaar.""" + left: str + right: str + hint: Optional[str] = None + + +@dataclass +class SortingItem: + """Ein Element zum Sortieren.""" + text: str + correct_position: int + category: Optional[str] = None + + +@dataclass +class OpenQuestion: + """Eine offene Frage.""" + question: str + model_answer: str + keywords: List[str] + points: int = 1 + + +@dataclass +class Quiz: + """Ein komplettes Quiz.""" + quiz_type: QuizType + title: str + questions: List[Any] # Je nach Typ unterschiedlich + topic: Optional[str] = None + difficulty: str = "medium" diff --git a/backend-lehrer/llm_gateway/routes/comparison.py b/backend-lehrer/llm_gateway/routes/comparison.py index b662d40..b4ab076 100644 --- a/backend-lehrer/llm_gateway/routes/comparison.py +++ b/backend-lehrer/llm_gateway/routes/comparison.py @@ -9,378 +9,33 @@ Dieses Modul ermoeglicht: import asyncio import logging -import time import uuid from datetime import datetime, timezone from typing import Optional -from pydantic import BaseModel, Field from fastapi import APIRouter, HTTPException, Depends -from ..models.chat import ChatMessage from ..middleware.auth import verify_api_key +from .comparison_models import ( + ComparisonRequest, + LLMResponse, + ComparisonResponse, + SavedComparison, + _comparisons_store, + _system_prompts_store, +) +from .comparison_providers import ( + call_openai, + 
call_claude, + search_tavily, + search_edusearch, + call_selfhosted_with_search, +) logger = logging.getLogger(__name__) router = APIRouter(prefix="/comparison", tags=["LLM Comparison"]) -class ComparisonRequest(BaseModel): - """Request fuer LLM-Vergleich.""" - prompt: str = Field(..., description="User prompt (z.B. Lehrer-Frage)") - system_prompt: Optional[str] = Field(None, description="Optionaler System Prompt") - enable_openai: bool = Field(True, description="OpenAI/ChatGPT aktivieren") - enable_claude: bool = Field(True, description="Claude aktivieren") - enable_selfhosted_tavily: bool = Field(True, description="Self-hosted + Tavily aktivieren") - enable_selfhosted_edusearch: bool = Field(True, description="Self-hosted + EduSearch aktivieren") - - # Parameter fuer Self-hosted Modelle - selfhosted_model: str = Field("llama3.2:3b", description="Self-hosted Modell") - temperature: float = Field(0.7, ge=0.0, le=2.0, description="Temperature") - top_p: float = Field(0.9, ge=0.0, le=1.0, description="Top-p Sampling") - max_tokens: int = Field(2048, ge=1, le=8192, description="Max Tokens") - - # Search Parameter - search_results_count: int = Field(5, ge=1, le=20, description="Anzahl Suchergebnisse") - edu_search_filters: Optional[dict] = Field(None, description="Filter fuer EduSearch") - - -class LLMResponse(BaseModel): - """Antwort eines einzelnen LLM.""" - provider: str - model: str - response: str - latency_ms: int - tokens_used: Optional[int] = None - search_results: Optional[list] = None - error: Optional[str] = None - timestamp: datetime = Field(default_factory=datetime.utcnow) - - -class ComparisonResponse(BaseModel): - """Gesamt-Antwort des Vergleichs.""" - comparison_id: str - prompt: str - system_prompt: Optional[str] - responses: list[LLMResponse] - created_at: datetime = Field(default_factory=datetime.utcnow) - - -class SavedComparison(BaseModel): - """Gespeicherter Vergleich fuer QA.""" - comparison_id: str - prompt: str - system_prompt: Optional[str] - 
responses: list[LLMResponse] - notes: Optional[str] = None - rating: Optional[dict] = None # {"openai": 4, "claude": 5, ...} - created_at: datetime - created_by: Optional[str] = None - - -# In-Memory Storage (in Production: Database) -_comparisons_store: dict[str, SavedComparison] = {} -_system_prompts_store: dict[str, dict] = { - "default": { - "id": "default", - "name": "Standard Lehrer-Assistent", - "prompt": """Du bist ein hilfreicher Assistent fuer Lehrkraefte in Deutschland. -Deine Aufgaben: -- Hilfe bei der Unterrichtsplanung -- Erklaerung von Fachinhalten -- Erstellung von Arbeitsblaettern und Pruefungen -- Beratung zu paedagogischen Methoden - -Antworte immer auf Deutsch und beachte den deutschen Lehrplankontext.""", - "created_at": datetime.now(timezone.utc).isoformat(), - }, - "curriculum": { - "id": "curriculum", - "name": "Lehrplan-Experte", - "prompt": """Du bist ein Experte fuer deutsche Lehrplaene und Bildungsstandards. -Du kennst: -- Lehrplaene aller 16 Bundeslaender -- KMK Bildungsstandards -- Kompetenzorientierung im deutschen Bildungssystem - -Beziehe dich immer auf konkrete Lehrplanvorgaben wenn moeglich.""", - "created_at": datetime.now(timezone.utc).isoformat(), - }, - "worksheet": { - "id": "worksheet", - "name": "Arbeitsblatt-Generator", - "prompt": """Du bist ein spezialisierter Assistent fuer die Erstellung von Arbeitsblaettern. 
-Erstelle didaktisch sinnvolle Aufgaben mit: -- Klaren Arbeitsanweisungen -- Differenzierungsmoeglichkeiten -- Loesungshinweisen - -Format: Markdown mit klarer Struktur.""", - "created_at": datetime.now(timezone.utc).isoformat(), - }, -} - - -async def _call_openai(prompt: str, system_prompt: Optional[str]) -> LLMResponse: - """Ruft OpenAI ChatGPT auf.""" - import os - import httpx - - start_time = time.time() - api_key = os.getenv("OPENAI_API_KEY") - - if not api_key: - return LLMResponse( - provider="openai", - model="gpt-4o-mini", - response="", - latency_ms=0, - error="OPENAI_API_KEY nicht konfiguriert" - ) - - messages = [] - if system_prompt: - messages.append({"role": "system", "content": system_prompt}) - messages.append({"role": "user", "content": prompt}) - - try: - async with httpx.AsyncClient(timeout=60.0) as client: - response = await client.post( - "https://api.openai.com/v1/chat/completions", - headers={ - "Authorization": f"Bearer {api_key}", - "Content-Type": "application/json", - }, - json={ - "model": "gpt-4o-mini", - "messages": messages, - "temperature": 0.7, - "max_tokens": 2048, - }, - ) - response.raise_for_status() - data = response.json() - - latency_ms = int((time.time() - start_time) * 1000) - content = data["choices"][0]["message"]["content"] - tokens = data.get("usage", {}).get("total_tokens") - - return LLMResponse( - provider="openai", - model="gpt-4o-mini", - response=content, - latency_ms=latency_ms, - tokens_used=tokens, - ) - except Exception as e: - return LLMResponse( - provider="openai", - model="gpt-4o-mini", - response="", - latency_ms=int((time.time() - start_time) * 1000), - error=str(e), - ) - - -async def _call_claude(prompt: str, system_prompt: Optional[str]) -> LLMResponse: - """Ruft Anthropic Claude auf.""" - import os - - start_time = time.time() - api_key = os.getenv("ANTHROPIC_API_KEY") - - if not api_key: - return LLMResponse( - provider="claude", - model="claude-3-5-sonnet-20241022", - response="", - 
latency_ms=0, - error="ANTHROPIC_API_KEY nicht konfiguriert" - ) - - try: - import anthropic - client = anthropic.AsyncAnthropic(api_key=api_key) - - response = await client.messages.create( - model="claude-3-5-sonnet-20241022", - max_tokens=2048, - system=system_prompt or "", - messages=[{"role": "user", "content": prompt}], - ) - - latency_ms = int((time.time() - start_time) * 1000) - content = response.content[0].text if response.content else "" - tokens = response.usage.input_tokens + response.usage.output_tokens - - return LLMResponse( - provider="claude", - model="claude-3-5-sonnet-20241022", - response=content, - latency_ms=latency_ms, - tokens_used=tokens, - ) - except Exception as e: - return LLMResponse( - provider="claude", - model="claude-3-5-sonnet-20241022", - response="", - latency_ms=int((time.time() - start_time) * 1000), - error=str(e), - ) - - -async def _search_tavily(query: str, count: int = 5) -> list[dict]: - """Sucht mit Tavily API.""" - import os - import httpx - - api_key = os.getenv("TAVILY_API_KEY") - if not api_key: - return [] - - try: - async with httpx.AsyncClient(timeout=30.0) as client: - response = await client.post( - "https://api.tavily.com/search", - json={ - "api_key": api_key, - "query": query, - "max_results": count, - "include_domains": [ - "kmk.org", "bildungsserver.de", "bpb.de", - "bayern.de", "nrw.de", "berlin.de", - ], - }, - ) - response.raise_for_status() - data = response.json() - return data.get("results", []) - except Exception as e: - logger.error(f"Tavily search error: {e}") - return [] - - -async def _search_edusearch(query: str, count: int = 5, filters: Optional[dict] = None) -> list[dict]: - """Sucht mit EduSearch API.""" - import os - import httpx - - edu_search_url = os.getenv("EDU_SEARCH_URL", "http://edu-search-service:8084") - - try: - async with httpx.AsyncClient(timeout=30.0) as client: - payload = { - "q": query, - "limit": count, - "mode": "keyword", - } - if filters: - payload["filters"] = filters - 
- response = await client.post( - f"{edu_search_url}/v1/search", - json=payload, - ) - response.raise_for_status() - data = response.json() - - # Formatiere Ergebnisse - results = [] - for r in data.get("results", []): - results.append({ - "title": r.get("title", ""), - "url": r.get("url", ""), - "content": r.get("snippet", ""), - "score": r.get("scores", {}).get("final", 0), - }) - return results - except Exception as e: - logger.error(f"EduSearch error: {e}") - return [] - - -async def _call_selfhosted_with_search( - prompt: str, - system_prompt: Optional[str], - search_provider: str, - search_results: list[dict], - model: str, - temperature: float, - top_p: float, - max_tokens: int, -) -> LLMResponse: - """Ruft Self-hosted LLM mit Suchergebnissen auf.""" - import os - import httpx - - start_time = time.time() - ollama_url = os.getenv("OLLAMA_URL", "http://localhost:11434") - - # Baue Kontext aus Suchergebnissen - context_parts = [] - for i, result in enumerate(search_results, 1): - context_parts.append(f"[{i}] {result.get('title', 'Untitled')}") - context_parts.append(f" URL: {result.get('url', '')}") - context_parts.append(f" {result.get('content', '')[:500]}") - context_parts.append("") - - search_context = "\n".join(context_parts) - - # Erweitere System Prompt mit Suchergebnissen - augmented_system = f"""{system_prompt or ''} - -Du hast Zugriff auf folgende Suchergebnisse aus {"Tavily" if search_provider == "tavily" else "EduSearch (deutsche Bildungsquellen)"}: - -{search_context} - -Nutze diese Quellen um deine Antwort zu unterstuetzen. 
Zitiere relevante Quellen mit [Nummer].""" - - messages = [ - {"role": "system", "content": augmented_system}, - {"role": "user", "content": prompt}, - ] - - try: - async with httpx.AsyncClient(timeout=120.0) as client: - response = await client.post( - f"{ollama_url}/api/chat", - json={ - "model": model, - "messages": messages, - "stream": False, - "options": { - "temperature": temperature, - "top_p": top_p, - "num_predict": max_tokens, - }, - }, - ) - response.raise_for_status() - data = response.json() - - latency_ms = int((time.time() - start_time) * 1000) - content = data.get("message", {}).get("content", "") - tokens = data.get("prompt_eval_count", 0) + data.get("eval_count", 0) - - return LLMResponse( - provider=f"selfhosted_{search_provider}", - model=model, - response=content, - latency_ms=latency_ms, - tokens_used=tokens, - search_results=search_results, - ) - except Exception as e: - return LLMResponse( - provider=f"selfhosted_{search_provider}", - model=model, - response="", - latency_ms=int((time.time() - start_time) * 1000), - error=str(e), - search_results=search_results, - ) - - @router.post("/run", response_model=ComparisonResponse) async def run_comparison( request: ComparisonRequest, @@ -395,23 +50,19 @@ async def run_comparison( comparison_id = f"cmp-{uuid.uuid4().hex[:12]}" tasks = [] - # System Prompt vorbereiten system_prompt = request.system_prompt - # OpenAI if request.enable_openai: - tasks.append(("openai", _call_openai(request.prompt, system_prompt))) + tasks.append(("openai", call_openai(request.prompt, system_prompt))) - # Claude if request.enable_claude: - tasks.append(("claude", _call_claude(request.prompt, system_prompt))) + tasks.append(("claude", call_claude(request.prompt, system_prompt))) - # Self-hosted + Tavily if request.enable_selfhosted_tavily: - tavily_results = await _search_tavily(request.prompt, request.search_results_count) + tavily_results = await search_tavily(request.prompt, request.search_results_count) 
tasks.append(( "selfhosted_tavily", - _call_selfhosted_with_search( + call_selfhosted_with_search( request.prompt, system_prompt, "tavily", @@ -423,16 +74,15 @@ async def run_comparison( ) )) - # Self-hosted + EduSearch if request.enable_selfhosted_edusearch: - edu_results = await _search_edusearch( + edu_results = await search_edusearch( request.prompt, request.search_results_count, request.edu_search_filters, ) tasks.append(( "selfhosted_edusearch", - _call_selfhosted_with_search( + call_selfhosted_with_search( request.prompt, system_prompt, "edusearch", @@ -444,7 +94,6 @@ async def run_comparison( ) )) - # Parallele Ausfuehrung responses = [] if tasks: results = await asyncio.gather(*[t[1] for t in tasks], return_exceptions=True) diff --git a/backend-lehrer/llm_gateway/routes/comparison_models.py b/backend-lehrer/llm_gateway/routes/comparison_models.py new file mode 100644 index 0000000..3652a57 --- /dev/null +++ b/backend-lehrer/llm_gateway/routes/comparison_models.py @@ -0,0 +1,103 @@ +""" +LLM Comparison - Pydantic Models und In-Memory Storage. +""" + +from datetime import datetime, timezone +from typing import Optional +from pydantic import BaseModel, Field + + +class ComparisonRequest(BaseModel): + """Request fuer LLM-Vergleich.""" + prompt: str = Field(..., description="User prompt (z.B. 
Lehrer-Frage)") + system_prompt: Optional[str] = Field(None, description="Optionaler System Prompt") + enable_openai: bool = Field(True, description="OpenAI/ChatGPT aktivieren") + enable_claude: bool = Field(True, description="Claude aktivieren") + enable_selfhosted_tavily: bool = Field(True, description="Self-hosted + Tavily aktivieren") + enable_selfhosted_edusearch: bool = Field(True, description="Self-hosted + EduSearch aktivieren") + + # Parameter fuer Self-hosted Modelle + selfhosted_model: str = Field("llama3.2:3b", description="Self-hosted Modell") + temperature: float = Field(0.7, ge=0.0, le=2.0, description="Temperature") + top_p: float = Field(0.9, ge=0.0, le=1.0, description="Top-p Sampling") + max_tokens: int = Field(2048, ge=1, le=8192, description="Max Tokens") + + # Search Parameter + search_results_count: int = Field(5, ge=1, le=20, description="Anzahl Suchergebnisse") + edu_search_filters: Optional[dict] = Field(None, description="Filter fuer EduSearch") + + +class LLMResponse(BaseModel): + """Antwort eines einzelnen LLM.""" + provider: str + model: str + response: str + latency_ms: int + tokens_used: Optional[int] = None + search_results: Optional[list] = None + error: Optional[str] = None + timestamp: datetime = Field(default_factory=datetime.utcnow) + + +class ComparisonResponse(BaseModel): + """Gesamt-Antwort des Vergleichs.""" + comparison_id: str + prompt: str + system_prompt: Optional[str] + responses: list[LLMResponse] + created_at: datetime = Field(default_factory=datetime.utcnow) + + +class SavedComparison(BaseModel): + """Gespeicherter Vergleich fuer QA.""" + comparison_id: str + prompt: str + system_prompt: Optional[str] + responses: list[LLMResponse] + notes: Optional[str] = None + rating: Optional[dict] = None # {"openai": 4, "claude": 5, ...} + created_at: datetime + created_by: Optional[str] = None + + +# In-Memory Storage (in Production: Database) +_comparisons_store: dict[str, SavedComparison] = {} +_system_prompts_store: 
dict[str, dict] = { + "default": { + "id": "default", + "name": "Standard Lehrer-Assistent", + "prompt": """Du bist ein hilfreicher Assistent fuer Lehrkraefte in Deutschland. +Deine Aufgaben: +- Hilfe bei der Unterrichtsplanung +- Erklaerung von Fachinhalten +- Erstellung von Arbeitsblaettern und Pruefungen +- Beratung zu paedagogischen Methoden + +Antworte immer auf Deutsch und beachte den deutschen Lehrplankontext.""", + "created_at": datetime.now(timezone.utc).isoformat(), + }, + "curriculum": { + "id": "curriculum", + "name": "Lehrplan-Experte", + "prompt": """Du bist ein Experte fuer deutsche Lehrplaene und Bildungsstandards. +Du kennst: +- Lehrplaene aller 16 Bundeslaender +- KMK Bildungsstandards +- Kompetenzorientierung im deutschen Bildungssystem + +Beziehe dich immer auf konkrete Lehrplanvorgaben wenn moeglich.""", + "created_at": datetime.now(timezone.utc).isoformat(), + }, + "worksheet": { + "id": "worksheet", + "name": "Arbeitsblatt-Generator", + "prompt": """Du bist ein spezialisierter Assistent fuer die Erstellung von Arbeitsblaettern. +Erstelle didaktisch sinnvolle Aufgaben mit: +- Klaren Arbeitsanweisungen +- Differenzierungsmoeglichkeiten +- Loesungshinweisen + +Format: Markdown mit klarer Struktur.""", + "created_at": datetime.now(timezone.utc).isoformat(), + }, +} diff --git a/backend-lehrer/llm_gateway/routes/comparison_providers.py b/backend-lehrer/llm_gateway/routes/comparison_providers.py new file mode 100644 index 0000000..36237c2 --- /dev/null +++ b/backend-lehrer/llm_gateway/routes/comparison_providers.py @@ -0,0 +1,270 @@ +""" +LLM Comparison - Provider-Aufrufe (OpenAI, Claude, Self-hosted, Search). 
+""" + +import logging +import time +from typing import Optional + +from .comparison_models import LLMResponse + +logger = logging.getLogger(__name__) + + +async def call_openai(prompt: str, system_prompt: Optional[str]) -> LLMResponse: + """Ruft OpenAI ChatGPT auf.""" + import os + import httpx + + start_time = time.time() + api_key = os.getenv("OPENAI_API_KEY") + + if not api_key: + return LLMResponse( + provider="openai", + model="gpt-4o-mini", + response="", + latency_ms=0, + error="OPENAI_API_KEY nicht konfiguriert" + ) + + messages = [] + if system_prompt: + messages.append({"role": "system", "content": system_prompt}) + messages.append({"role": "user", "content": prompt}) + + try: + async with httpx.AsyncClient(timeout=60.0) as client: + response = await client.post( + "https://api.openai.com/v1/chat/completions", + headers={ + "Authorization": f"Bearer {api_key}", + "Content-Type": "application/json", + }, + json={ + "model": "gpt-4o-mini", + "messages": messages, + "temperature": 0.7, + "max_tokens": 2048, + }, + ) + response.raise_for_status() + data = response.json() + + latency_ms = int((time.time() - start_time) * 1000) + content = data["choices"][0]["message"]["content"] + tokens = data.get("usage", {}).get("total_tokens") + + return LLMResponse( + provider="openai", + model="gpt-4o-mini", + response=content, + latency_ms=latency_ms, + tokens_used=tokens, + ) + except Exception as e: + return LLMResponse( + provider="openai", + model="gpt-4o-mini", + response="", + latency_ms=int((time.time() - start_time) * 1000), + error=str(e), + ) + + +async def call_claude(prompt: str, system_prompt: Optional[str]) -> LLMResponse: + """Ruft Anthropic Claude auf.""" + import os + + start_time = time.time() + api_key = os.getenv("ANTHROPIC_API_KEY") + + if not api_key: + return LLMResponse( + provider="claude", + model="claude-3-5-sonnet-20241022", + response="", + latency_ms=0, + error="ANTHROPIC_API_KEY nicht konfiguriert" + ) + + try: + import anthropic + client 
= anthropic.AsyncAnthropic(api_key=api_key) + + response = await client.messages.create( + model="claude-3-5-sonnet-20241022", + max_tokens=2048, + system=system_prompt or "", + messages=[{"role": "user", "content": prompt}], + ) + + latency_ms = int((time.time() - start_time) * 1000) + content = response.content[0].text if response.content else "" + tokens = response.usage.input_tokens + response.usage.output_tokens + + return LLMResponse( + provider="claude", + model="claude-3-5-sonnet-20241022", + response=content, + latency_ms=latency_ms, + tokens_used=tokens, + ) + except Exception as e: + return LLMResponse( + provider="claude", + model="claude-3-5-sonnet-20241022", + response="", + latency_ms=int((time.time() - start_time) * 1000), + error=str(e), + ) + + +async def search_tavily(query: str, count: int = 5) -> list[dict]: + """Sucht mit Tavily API.""" + import os + import httpx + + api_key = os.getenv("TAVILY_API_KEY") + if not api_key: + return [] + + try: + async with httpx.AsyncClient(timeout=30.0) as client: + response = await client.post( + "https://api.tavily.com/search", + json={ + "api_key": api_key, + "query": query, + "max_results": count, + "include_domains": [ + "kmk.org", "bildungsserver.de", "bpb.de", + "bayern.de", "nrw.de", "berlin.de", + ], + }, + ) + response.raise_for_status() + data = response.json() + return data.get("results", []) + except Exception as e: + logger.error(f"Tavily search error: {e}") + return [] + + +async def search_edusearch(query: str, count: int = 5, filters: Optional[dict] = None) -> list[dict]: + """Sucht mit EduSearch API.""" + import os + import httpx + + edu_search_url = os.getenv("EDU_SEARCH_URL", "http://edu-search-service:8084") + + try: + async with httpx.AsyncClient(timeout=30.0) as client: + payload = { + "q": query, + "limit": count, + "mode": "keyword", + } + if filters: + payload["filters"] = filters + + response = await client.post( + f"{edu_search_url}/v1/search", + json=payload, + ) + 
response.raise_for_status() + data = response.json() + + results = [] + for r in data.get("results", []): + results.append({ + "title": r.get("title", ""), + "url": r.get("url", ""), + "content": r.get("snippet", ""), + "score": r.get("scores", {}).get("final", 0), + }) + return results + except Exception as e: + logger.error(f"EduSearch error: {e}") + return [] + + +async def call_selfhosted_with_search( + prompt: str, + system_prompt: Optional[str], + search_provider: str, + search_results: list[dict], + model: str, + temperature: float, + top_p: float, + max_tokens: int, +) -> LLMResponse: + """Ruft Self-hosted LLM mit Suchergebnissen auf.""" + import os + import httpx + + start_time = time.time() + ollama_url = os.getenv("OLLAMA_URL", "http://localhost:11434") + + # Baue Kontext aus Suchergebnissen + context_parts = [] + for i, result in enumerate(search_results, 1): + context_parts.append(f"[{i}] {result.get('title', 'Untitled')}") + context_parts.append(f" URL: {result.get('url', '')}") + context_parts.append(f" {result.get('content', '')[:500]}") + context_parts.append("") + + search_context = "\n".join(context_parts) + + augmented_system = f"""{system_prompt or ''} + +Du hast Zugriff auf folgende Suchergebnisse aus {"Tavily" if search_provider == "tavily" else "EduSearch (deutsche Bildungsquellen)"}: + +{search_context} + +Nutze diese Quellen um deine Antwort zu unterstuetzen. 
Zitiere relevante Quellen mit [Nummer].""" + + messages = [ + {"role": "system", "content": augmented_system}, + {"role": "user", "content": prompt}, + ] + + try: + async with httpx.AsyncClient(timeout=120.0) as client: + response = await client.post( + f"{ollama_url}/api/chat", + json={ + "model": model, + "messages": messages, + "stream": False, + "options": { + "temperature": temperature, + "top_p": top_p, + "num_predict": max_tokens, + }, + }, + ) + response.raise_for_status() + data = response.json() + + latency_ms = int((time.time() - start_time) * 1000) + content = data.get("message", {}).get("content", "") + tokens = data.get("prompt_eval_count", 0) + data.get("eval_count", 0) + + return LLMResponse( + provider=f"selfhosted_{search_provider}", + model=model, + response=content, + latency_ms=latency_ms, + tokens_used=tokens, + search_results=search_results, + ) + except Exception as e: + return LLMResponse( + provider=f"selfhosted_{search_provider}", + model=model, + response="", + latency_ms=int((time.time() - start_time) * 1000), + error=str(e), + search_results=search_results, + ) diff --git a/backend-lehrer/llm_gateway/services/inference.py b/backend-lehrer/llm_gateway/services/inference.py index 756afc5..e39f68e 100644 --- a/backend-lehrer/llm_gateway/services/inference.py +++ b/backend-lehrer/llm_gateway/services/inference.py @@ -8,10 +8,8 @@ Unterstützt: """ import httpx -import json import logging from typing import AsyncIterator, Optional -from dataclasses import dataclass from ..config import get_config, LLMBackendConfig from ..models.chat import ( @@ -20,26 +18,23 @@ from ..models.chat import ( ChatCompletionChunk, ChatMessage, ChatChoice, - StreamChoice, - ChatChoiceDelta, Usage, ModelInfo, ModelListResponse, ) +from .inference_backends import ( + InferenceResult, + call_ollama, + stream_ollama, + call_openai_compatible, + stream_openai_compatible, + call_anthropic, + stream_anthropic, +) logger = logging.getLogger(__name__) -@dataclass -class 
InferenceResult: - """Ergebnis einer Inference-Anfrage.""" - content: str - model: str - backend: str - usage: Optional[Usage] = None - finish_reason: str = "stop" - - class InferenceService: """Service für LLM Inference über verschiedene Backends.""" @@ -68,26 +63,17 @@ class InferenceService: return None def _map_model_to_backend(self, model: str) -> tuple[str, LLMBackendConfig]: - """ - Mapped ein Modell-Name zum entsprechenden Backend. - - Beispiele: - - "breakpilot-teacher-8b" → Ollama/vLLM mit llama3.1:8b - - "claude-3-5-sonnet" → Anthropic - """ + """Mapped ein Modell-Name zum entsprechenden Backend.""" model_lower = model.lower() - # Explizite Claude-Modelle → Anthropic if "claude" in model_lower: if self.config.anthropic and self.config.anthropic.enabled: return self.config.anthropic.default_model, self.config.anthropic raise ValueError("Anthropic backend not configured") - # BreakPilot Modelle → primäres Backend if "breakpilot" in model_lower or "teacher" in model_lower: backend = self._get_available_backend() if backend: - # Map zu tatsächlichem Modell-Namen if "70b" in model_lower: actual_model = "llama3.1:70b" if backend.name == "ollama" else "meta-llama/Meta-Llama-3.1-70B-Instruct" else: @@ -95,7 +81,6 @@ class InferenceService: return actual_model, backend raise ValueError("No LLM backend available") - # Mistral Modelle if "mistral" in model_lower: backend = self._get_available_backend() if backend: @@ -103,409 +88,64 @@ class InferenceService: return actual_model, backend raise ValueError("No LLM backend available") - # Fallback: verwende Modell-Name direkt backend = self._get_available_backend() if backend: return model, backend raise ValueError("No LLM backend available") - async def _call_ollama( - self, - backend: LLMBackendConfig, - model: str, - request: ChatCompletionRequest, - ) -> InferenceResult: - """Ruft Ollama API auf (nicht OpenAI-kompatibel).""" - client = await self.get_client() - - # Ollama verwendet eigenes Format - messages = 
[{"role": m.role, "content": m.content or ""} for m in request.messages] - - payload = { - "model": model, - "messages": messages, - "stream": False, - "options": { - "temperature": request.temperature, - "top_p": request.top_p, - }, - } - - if request.max_tokens: - payload["options"]["num_predict"] = request.max_tokens - - response = await client.post( - f"{backend.base_url}/api/chat", - json=payload, - timeout=backend.timeout, - ) - response.raise_for_status() - data = response.json() - - return InferenceResult( - content=data.get("message", {}).get("content", ""), - model=model, - backend="ollama", - usage=Usage( - prompt_tokens=data.get("prompt_eval_count", 0), - completion_tokens=data.get("eval_count", 0), - total_tokens=data.get("prompt_eval_count", 0) + data.get("eval_count", 0), - ), - finish_reason="stop" if data.get("done") else "length", - ) - - async def _stream_ollama( - self, - backend: LLMBackendConfig, - model: str, - request: ChatCompletionRequest, - response_id: str, - ) -> AsyncIterator[ChatCompletionChunk]: - """Streamt von Ollama.""" - client = await self.get_client() - - messages = [{"role": m.role, "content": m.content or ""} for m in request.messages] - - payload = { - "model": model, - "messages": messages, - "stream": True, - "options": { - "temperature": request.temperature, - "top_p": request.top_p, - }, - } - - if request.max_tokens: - payload["options"]["num_predict"] = request.max_tokens - - async with client.stream( - "POST", - f"{backend.base_url}/api/chat", - json=payload, - timeout=backend.timeout, - ) as response: - response.raise_for_status() - async for line in response.aiter_lines(): - if not line: - continue - try: - data = json.loads(line) - content = data.get("message", {}).get("content", "") - done = data.get("done", False) - - yield ChatCompletionChunk( - id=response_id, - model=model, - choices=[ - StreamChoice( - index=0, - delta=ChatChoiceDelta(content=content), - finish_reason="stop" if done else None, - ) - ], - ) - 
except json.JSONDecodeError: - continue - - async def _call_openai_compatible( - self, - backend: LLMBackendConfig, - model: str, - request: ChatCompletionRequest, - ) -> InferenceResult: - """Ruft OpenAI-kompatible API auf (vLLM, etc.).""" - client = await self.get_client() - - headers = {"Content-Type": "application/json"} - if backend.api_key: - headers["Authorization"] = f"Bearer {backend.api_key}" - - payload = { - "model": model, - "messages": [m.model_dump(exclude_none=True) for m in request.messages], - "stream": False, - "temperature": request.temperature, - "top_p": request.top_p, - } - - if request.max_tokens: - payload["max_tokens"] = request.max_tokens - if request.stop: - payload["stop"] = request.stop - - response = await client.post( - f"{backend.base_url}/v1/chat/completions", - json=payload, - headers=headers, - timeout=backend.timeout, - ) - response.raise_for_status() - data = response.json() - - choice = data.get("choices", [{}])[0] - usage_data = data.get("usage", {}) - - return InferenceResult( - content=choice.get("message", {}).get("content", ""), - model=model, - backend=backend.name, - usage=Usage( - prompt_tokens=usage_data.get("prompt_tokens", 0), - completion_tokens=usage_data.get("completion_tokens", 0), - total_tokens=usage_data.get("total_tokens", 0), - ), - finish_reason=choice.get("finish_reason", "stop"), - ) - - async def _stream_openai_compatible( - self, - backend: LLMBackendConfig, - model: str, - request: ChatCompletionRequest, - response_id: str, - ) -> AsyncIterator[ChatCompletionChunk]: - """Streamt von OpenAI-kompatibler API.""" - client = await self.get_client() - - headers = {"Content-Type": "application/json"} - if backend.api_key: - headers["Authorization"] = f"Bearer {backend.api_key}" - - payload = { - "model": model, - "messages": [m.model_dump(exclude_none=True) for m in request.messages], - "stream": True, - "temperature": request.temperature, - "top_p": request.top_p, - } - - if request.max_tokens: - 
payload["max_tokens"] = request.max_tokens - - async with client.stream( - "POST", - f"{backend.base_url}/v1/chat/completions", - json=payload, - headers=headers, - timeout=backend.timeout, - ) as response: - response.raise_for_status() - async for line in response.aiter_lines(): - if not line or not line.startswith("data: "): - continue - data_str = line[6:] # Remove "data: " prefix - if data_str == "[DONE]": - break - try: - data = json.loads(data_str) - choice = data.get("choices", [{}])[0] - delta = choice.get("delta", {}) - - yield ChatCompletionChunk( - id=response_id, - model=model, - choices=[ - StreamChoice( - index=0, - delta=ChatChoiceDelta( - role=delta.get("role"), - content=delta.get("content"), - ), - finish_reason=choice.get("finish_reason"), - ) - ], - ) - except json.JSONDecodeError: - continue - - async def _call_anthropic( - self, - backend: LLMBackendConfig, - model: str, - request: ChatCompletionRequest, - ) -> InferenceResult: - """Ruft Anthropic Claude API auf.""" - # Anthropic SDK verwenden (bereits installiert) - try: - import anthropic - except ImportError: - raise ImportError("anthropic package required for Claude API") - - client = anthropic.AsyncAnthropic(api_key=backend.api_key) - - # System message extrahieren - system_content = "" - messages = [] - for msg in request.messages: - if msg.role == "system": - system_content += (msg.content or "") + "\n" - else: - messages.append({"role": msg.role, "content": msg.content or ""}) - - response = await client.messages.create( - model=model, - max_tokens=request.max_tokens or 4096, - system=system_content.strip() if system_content else None, - messages=messages, - temperature=request.temperature, - top_p=request.top_p, - ) - - content = "" - if response.content: - content = response.content[0].text if response.content[0].type == "text" else "" - - return InferenceResult( - content=content, - model=model, - backend="anthropic", - usage=Usage( - prompt_tokens=response.usage.input_tokens, - 
completion_tokens=response.usage.output_tokens, - total_tokens=response.usage.input_tokens + response.usage.output_tokens, - ), - finish_reason="stop" if response.stop_reason == "end_turn" else response.stop_reason or "stop", - ) - - async def _stream_anthropic( - self, - backend: LLMBackendConfig, - model: str, - request: ChatCompletionRequest, - response_id: str, - ) -> AsyncIterator[ChatCompletionChunk]: - """Streamt von Anthropic Claude API.""" - try: - import anthropic - except ImportError: - raise ImportError("anthropic package required for Claude API") - - client = anthropic.AsyncAnthropic(api_key=backend.api_key) - - # System message extrahieren - system_content = "" - messages = [] - for msg in request.messages: - if msg.role == "system": - system_content += (msg.content or "") + "\n" - else: - messages.append({"role": msg.role, "content": msg.content or ""}) - - async with client.messages.stream( - model=model, - max_tokens=request.max_tokens or 4096, - system=system_content.strip() if system_content else None, - messages=messages, - temperature=request.temperature, - top_p=request.top_p, - ) as stream: - async for text in stream.text_stream: - yield ChatCompletionChunk( - id=response_id, - model=model, - choices=[ - StreamChoice( - index=0, - delta=ChatChoiceDelta(content=text), - finish_reason=None, - ) - ], - ) - - # Final chunk with finish_reason - yield ChatCompletionChunk( - id=response_id, - model=model, - choices=[ - StreamChoice( - index=0, - delta=ChatChoiceDelta(), - finish_reason="stop", - ) - ], - ) - async def complete(self, request: ChatCompletionRequest) -> ChatCompletionResponse: - """ - Führt Chat Completion durch (non-streaming). 
- """ + """Führt Chat Completion durch (non-streaming).""" actual_model, backend = self._map_model_to_backend(request.model) + logger.info(f"Inference request: model={request.model} -> {actual_model} via {backend.name}") - logger.info(f"Inference request: model={request.model} → {actual_model} via {backend.name}") + client = await self.get_client() if backend.name == "ollama": - result = await self._call_ollama(backend, actual_model, request) + result = await call_ollama(client, backend, actual_model, request) elif backend.name == "anthropic": - result = await self._call_anthropic(backend, actual_model, request) + result = await call_anthropic(backend, actual_model, request) else: - result = await self._call_openai_compatible(backend, actual_model, request) + result = await call_openai_compatible(client, backend, actual_model, request) return ChatCompletionResponse( - model=request.model, # Original requested model name - choices=[ - ChatChoice( - index=0, - message=ChatMessage(role="assistant", content=result.content), - finish_reason=result.finish_reason, - ) - ], + model=request.model, + choices=[ChatChoice(index=0, message=ChatMessage(role="assistant", content=result.content), finish_reason=result.finish_reason)], usage=result.usage, ) async def stream(self, request: ChatCompletionRequest) -> AsyncIterator[ChatCompletionChunk]: - """ - Führt Chat Completion mit Streaming durch. 
- """ + """Führt Chat Completion mit Streaming durch.""" import uuid response_id = f"chatcmpl-{uuid.uuid4().hex[:12]}" actual_model, backend = self._map_model_to_backend(request.model) + logger.info(f"Streaming request: model={request.model} -> {actual_model} via {backend.name}") - logger.info(f"Streaming request: model={request.model} → {actual_model} via {backend.name}") + client = await self.get_client() if backend.name == "ollama": - async for chunk in self._stream_ollama(backend, actual_model, request, response_id): + async for chunk in stream_ollama(client, backend, actual_model, request, response_id): yield chunk elif backend.name == "anthropic": - async for chunk in self._stream_anthropic(backend, actual_model, request, response_id): + async for chunk in stream_anthropic(backend, actual_model, request, response_id): yield chunk else: - async for chunk in self._stream_openai_compatible(backend, actual_model, request, response_id): + async for chunk in stream_openai_compatible(client, backend, actual_model, request, response_id): yield chunk async def list_models(self) -> ModelListResponse: """Listet verfügbare Modelle.""" models = [] - # BreakPilot Modelle (mapped zu verfügbaren Backends) backend = self._get_available_backend() if backend: models.extend([ - ModelInfo( - id="breakpilot-teacher-8b", - owned_by="breakpilot", - description="Llama 3.1 8B optimiert für Schulkontext", - context_length=8192, - ), - ModelInfo( - id="breakpilot-teacher-70b", - owned_by="breakpilot", - description="Llama 3.1 70B für komplexe Aufgaben", - context_length=8192, - ), + ModelInfo(id="breakpilot-teacher-8b", owned_by="breakpilot", description="Llama 3.1 8B optimiert für Schulkontext", context_length=8192), + ModelInfo(id="breakpilot-teacher-70b", owned_by="breakpilot", description="Llama 3.1 70B für komplexe Aufgaben", context_length=8192), ]) - # Claude Modelle (wenn Anthropic konfiguriert) if self.config.anthropic and self.config.anthropic.enabled: - models.append( - 
ModelInfo( - id="claude-3-5-sonnet", - owned_by="anthropic", - description="Claude 3.5 Sonnet - Fallback für höchste Qualität", - context_length=200000, - ) - ) + models.append(ModelInfo(id="claude-3-5-sonnet", owned_by="anthropic", description="Claude 3.5 Sonnet - Fallback für höchste Qualität", context_length=200000)) return ModelListResponse(data=models) diff --git a/backend-lehrer/llm_gateway/services/inference_backends.py b/backend-lehrer/llm_gateway/services/inference_backends.py new file mode 100644 index 0000000..90de01c --- /dev/null +++ b/backend-lehrer/llm_gateway/services/inference_backends.py @@ -0,0 +1,230 @@ +""" +Inference Backends - Kommunikation mit einzelnen LLM-Providern. + +Unterstützt Ollama, OpenAI-kompatible APIs und Anthropic Claude. +""" + +import json +import logging +from typing import AsyncIterator, Optional +from dataclasses import dataclass + +from ..config import LLMBackendConfig +from ..models.chat import ( + ChatCompletionRequest, + ChatCompletionChunk, + ChatMessage, + StreamChoice, + ChatChoiceDelta, + Usage, +) + +logger = logging.getLogger(__name__) + + +@dataclass +class InferenceResult: + """Ergebnis einer Inference-Anfrage.""" + content: str + model: str + backend: str + usage: Optional[Usage] = None + finish_reason: str = "stop" + + +async def call_ollama(client, backend: LLMBackendConfig, model: str, request: ChatCompletionRequest) -> InferenceResult: + """Ruft Ollama API auf (nicht OpenAI-kompatibel).""" + messages = [{"role": m.role, "content": m.content or ""} for m in request.messages] + + payload = { + "model": model, + "messages": messages, + "stream": False, + "options": {"temperature": request.temperature, "top_p": request.top_p}, + } + if request.max_tokens: + payload["options"]["num_predict"] = request.max_tokens + + response = await client.post(f"{backend.base_url}/api/chat", json=payload, timeout=backend.timeout) + response.raise_for_status() + data = response.json() + + return InferenceResult( + 
content=data.get("message", {}).get("content", ""), + model=model, backend="ollama", + usage=Usage( + prompt_tokens=data.get("prompt_eval_count", 0), + completion_tokens=data.get("eval_count", 0), + total_tokens=data.get("prompt_eval_count", 0) + data.get("eval_count", 0), + ), + finish_reason="stop" if data.get("done") else "length", + ) + + +async def stream_ollama(client, backend, model, request, response_id) -> AsyncIterator[ChatCompletionChunk]: + """Streamt von Ollama.""" + messages = [{"role": m.role, "content": m.content or ""} for m in request.messages] + + payload = { + "model": model, "messages": messages, "stream": True, + "options": {"temperature": request.temperature, "top_p": request.top_p}, + } + if request.max_tokens: + payload["options"]["num_predict"] = request.max_tokens + + async with client.stream("POST", f"{backend.base_url}/api/chat", json=payload, timeout=backend.timeout) as response: + response.raise_for_status() + async for line in response.aiter_lines(): + if not line: + continue + try: + data = json.loads(line) + content = data.get("message", {}).get("content", "") + done = data.get("done", False) + yield ChatCompletionChunk( + id=response_id, model=model, + choices=[StreamChoice(index=0, delta=ChatChoiceDelta(content=content), finish_reason="stop" if done else None)], + ) + except json.JSONDecodeError: + continue + + +async def call_openai_compatible(client, backend, model, request) -> InferenceResult: + """Ruft OpenAI-kompatible API auf (vLLM, etc.).""" + headers = {"Content-Type": "application/json"} + if backend.api_key: + headers["Authorization"] = f"Bearer {backend.api_key}" + + payload = { + "model": model, + "messages": [m.model_dump(exclude_none=True) for m in request.messages], + "stream": False, "temperature": request.temperature, "top_p": request.top_p, + } + if request.max_tokens: + payload["max_tokens"] = request.max_tokens + if request.stop: + payload["stop"] = request.stop + + response = await 
client.post(f"{backend.base_url}/v1/chat/completions", json=payload, headers=headers, timeout=backend.timeout) + response.raise_for_status() + data = response.json() + + choice = data.get("choices", [{}])[0] + usage_data = data.get("usage", {}) + + return InferenceResult( + content=choice.get("message", {}).get("content", ""), + model=model, backend=backend.name, + usage=Usage( + prompt_tokens=usage_data.get("prompt_tokens", 0), + completion_tokens=usage_data.get("completion_tokens", 0), + total_tokens=usage_data.get("total_tokens", 0), + ), + finish_reason=choice.get("finish_reason", "stop"), + ) + + +async def stream_openai_compatible(client, backend, model, request, response_id) -> AsyncIterator[ChatCompletionChunk]: + """Streamt von OpenAI-kompatibler API.""" + headers = {"Content-Type": "application/json"} + if backend.api_key: + headers["Authorization"] = f"Bearer {backend.api_key}" + + payload = { + "model": model, + "messages": [m.model_dump(exclude_none=True) for m in request.messages], + "stream": True, "temperature": request.temperature, "top_p": request.top_p, + } + if request.max_tokens: + payload["max_tokens"] = request.max_tokens + + async with client.stream("POST", f"{backend.base_url}/v1/chat/completions", json=payload, headers=headers, timeout=backend.timeout) as response: + response.raise_for_status() + async for line in response.aiter_lines(): + if not line or not line.startswith("data: "): + continue + data_str = line[6:] + if data_str == "[DONE]": + break + try: + data = json.loads(data_str) + choice = data.get("choices", [{}])[0] + delta = choice.get("delta", {}) + yield ChatCompletionChunk( + id=response_id, model=model, + choices=[StreamChoice(index=0, delta=ChatChoiceDelta(role=delta.get("role"), content=delta.get("content")), finish_reason=choice.get("finish_reason"))], + ) + except json.JSONDecodeError: + continue + + +async def call_anthropic(backend, model, request) -> InferenceResult: + """Ruft Anthropic Claude API auf.""" + try: + 
import anthropic + except ImportError: + raise ImportError("anthropic package required for Claude API") + + client = anthropic.AsyncAnthropic(api_key=backend.api_key) + + system_content = "" + messages = [] + for msg in request.messages: + if msg.role == "system": + system_content += (msg.content or "") + "\n" + else: + messages.append({"role": msg.role, "content": msg.content or ""}) + + response = await client.messages.create( + model=model, max_tokens=request.max_tokens or 4096, + system=system_content.strip() if system_content else None, + messages=messages, temperature=request.temperature, top_p=request.top_p, + ) + + content = "" + if response.content: + content = response.content[0].text if response.content[0].type == "text" else "" + + return InferenceResult( + content=content, model=model, backend="anthropic", + usage=Usage( + prompt_tokens=response.usage.input_tokens, + completion_tokens=response.usage.output_tokens, + total_tokens=response.usage.input_tokens + response.usage.output_tokens, + ), + finish_reason="stop" if response.stop_reason == "end_turn" else response.stop_reason or "stop", + ) + + +async def stream_anthropic(backend, model, request, response_id) -> AsyncIterator[ChatCompletionChunk]: + """Streamt von Anthropic Claude API.""" + try: + import anthropic + except ImportError: + raise ImportError("anthropic package required for Claude API") + + client = anthropic.AsyncAnthropic(api_key=backend.api_key) + + system_content = "" + messages = [] + for msg in request.messages: + if msg.role == "system": + system_content += (msg.content or "") + "\n" + else: + messages.append({"role": msg.role, "content": msg.content or ""}) + + async with client.messages.stream( + model=model, max_tokens=request.max_tokens or 4096, + system=system_content.strip() if system_content else None, + messages=messages, temperature=request.temperature, top_p=request.top_p, + ) as stream: + async for text in stream.text_stream: + yield ChatCompletionChunk( + 
id=response_id, model=model, + choices=[StreamChoice(index=0, delta=ChatChoiceDelta(content=text), finish_reason=None)], + ) + + yield ChatCompletionChunk( + id=response_id, model=model, + choices=[StreamChoice(index=0, delta=ChatChoiceDelta(), finish_reason="stop")], + ) diff --git a/backend-lehrer/services/file_processor.py b/backend-lehrer/services/file_processor.py index 438c220..17ad3dd 100644 --- a/backend-lehrer/services/file_processor.py +++ b/backend-lehrer/services/file_processor.py @@ -15,60 +15,24 @@ Verwendet: """ import logging -import os import io -import base64 from pathlib import Path -from typing import Optional, List, Dict, Any, Tuple, Union -from dataclasses import dataclass -from enum import Enum +from typing import Optional, List, Dict, Any import cv2 import numpy as np from PIL import Image +from .file_processor_models import ( + FileType, + ProcessingMode, + ProcessedRegion, + ProcessingResult, +) + logger = logging.getLogger(__name__) -class FileType(str, Enum): - """Unterstützte Dateitypen.""" - PDF = "pdf" - IMAGE = "image" - DOCX = "docx" - DOC = "doc" - TXT = "txt" - UNKNOWN = "unknown" - - -class ProcessingMode(str, Enum): - """Verarbeitungsmodi.""" - OCR_HANDWRITING = "ocr_handwriting" # Handschrifterkennung - OCR_PRINTED = "ocr_printed" # Gedruckter Text - TEXT_EXTRACT = "text_extract" # Textextraktion (PDF/DOCX) - MIXED = "mixed" # Kombiniert OCR + Textextraktion - - -@dataclass -class ProcessedRegion: - """Ein erkannter Textbereich.""" - text: str - confidence: float - bbox: Tuple[int, int, int, int] # x1, y1, x2, y2 - page: int = 1 - - -@dataclass -class ProcessingResult: - """Ergebnis der Dokumentenverarbeitung.""" - text: str - confidence: float - regions: List[ProcessedRegion] - page_count: int - file_type: FileType - processing_mode: ProcessingMode - metadata: Dict[str, Any] - - class FileProcessor: """ Zentrale Dokumentenverarbeitung für BreakPilot. 
@@ -81,17 +45,9 @@ class FileProcessor: """ def __init__(self, ocr_lang: str = "de", use_gpu: bool = False): - """ - Initialisiert den File Processor. - - Args: - ocr_lang: Sprache für OCR (default: "de" für Deutsch) - use_gpu: GPU für OCR nutzen (beschleunigt Verarbeitung) - """ self.ocr_lang = ocr_lang self.use_gpu = use_gpu self._ocr_engine = None - logger.info(f"FileProcessor initialized (lang={ocr_lang}, gpu={use_gpu})") @property @@ -107,7 +63,7 @@ class FileProcessor: from paddleocr import PaddleOCR return PaddleOCR( use_angle_cls=True, - lang='german', # Deutsch + lang='german', use_gpu=self.use_gpu, show_log=False ) @@ -116,16 +72,7 @@ class FileProcessor: return None def detect_file_type(self, file_path: str = None, file_bytes: bytes = None) -> FileType: - """ - Erkennt den Dateityp. - - Args: - file_path: Pfad zur Datei - file_bytes: Dateiinhalt als Bytes - - Returns: - FileType enum - """ + """Erkennt den Dateityp.""" if file_path: ext = Path(file_path).suffix.lower() if ext == ".pdf": @@ -140,14 +87,13 @@ class FileProcessor: return FileType.TXT if file_bytes: - # Magic number detection if file_bytes[:4] == b'%PDF': return FileType.PDF elif file_bytes[:8] == b'\x89PNG\r\n\x1a\n': return FileType.IMAGE - elif file_bytes[:2] in [b'\xff\xd8', b'BM']: # JPEG, BMP + elif file_bytes[:2] in [b'\xff\xd8', b'BM']: return FileType.IMAGE - elif file_bytes[:4] == b'PK\x03\x04': # ZIP (DOCX) + elif file_bytes[:4] == b'PK\x03\x04': return FileType.DOCX return FileType.UNKNOWN @@ -158,17 +104,7 @@ class FileProcessor: file_bytes: bytes = None, mode: ProcessingMode = ProcessingMode.MIXED ) -> ProcessingResult: - """ - Verarbeitet ein Dokument. 
- - Args: - file_path: Pfad zur Datei - file_bytes: Dateiinhalt als Bytes - mode: Verarbeitungsmodus - - Returns: - ProcessingResult mit extrahiertem Text und Metadaten - """ + """Verarbeitet ein Dokument.""" if not file_path and not file_bytes: raise ValueError("Entweder file_path oder file_bytes muss angegeben werden") @@ -186,18 +122,12 @@ class FileProcessor: else: raise ValueError(f"Nicht unterstützter Dateityp: {file_type}") - def _process_pdf( - self, - file_path: str = None, - file_bytes: bytes = None, - mode: ProcessingMode = ProcessingMode.MIXED - ) -> ProcessingResult: + def _process_pdf(self, file_path=None, file_bytes=None, mode=ProcessingMode.MIXED): """Verarbeitet PDF-Dateien.""" try: - import fitz # PyMuPDF + import fitz except ImportError: logger.warning("PyMuPDF nicht installiert - versuche Fallback") - # Fallback: PDF als Bild behandeln return self._process_image(file_path, file_bytes, mode) if file_bytes: @@ -205,35 +135,27 @@ class FileProcessor: else: doc = fitz.open(file_path) - all_text = [] - all_regions = [] - total_confidence = 0.0 - region_count = 0 + all_text, all_regions = [], [] + total_confidence, region_count = 0.0, 0 for page_num, page in enumerate(doc, start=1): - # Erst versuchen Text direkt zu extrahieren page_text = page.get_text() if page_text.strip() and mode != ProcessingMode.OCR_HANDWRITING: - # PDF enthält Text (nicht nur Bilder) all_text.append(page_text) all_regions.append(ProcessedRegion( - text=page_text, - confidence=1.0, + text=page_text, confidence=1.0, bbox=(0, 0, int(page.rect.width), int(page.rect.height)), page=page_num )) total_confidence += 1.0 region_count += 1 else: - # Seite als Bild rendern und OCR anwenden - pix = page.get_pixmap(matrix=fitz.Matrix(2, 2)) # 2x Auflösung + pix = page.get_pixmap(matrix=fitz.Matrix(2, 2)) img_bytes = pix.tobytes("png") img = Image.open(io.BytesIO(img_bytes)) - ocr_result = self._ocr_image(img) all_text.append(ocr_result["text"]) - for region in ocr_result["regions"]: 
region.page = page_num all_regions.append(region) @@ -241,55 +163,34 @@ class FileProcessor: region_count += 1 doc.close() - avg_confidence = total_confidence / region_count if region_count > 0 else 0.0 return ProcessingResult( - text="\n\n".join(all_text), - confidence=avg_confidence, + text="\n\n".join(all_text), confidence=avg_confidence, regions=all_regions, page_count=len(doc) if hasattr(doc, '__len__') else 1, - file_type=FileType.PDF, - processing_mode=mode, + file_type=FileType.PDF, processing_mode=mode, metadata={"source": file_path or "bytes"} ) - def _process_image( - self, - file_path: str = None, - file_bytes: bytes = None, - mode: ProcessingMode = ProcessingMode.MIXED - ) -> ProcessingResult: + def _process_image(self, file_path=None, file_bytes=None, mode=ProcessingMode.MIXED): """Verarbeitet Bilddateien.""" if file_bytes: img = Image.open(io.BytesIO(file_bytes)) else: img = Image.open(file_path) - # Bildvorverarbeitung processed_img = self._preprocess_image(img) - - # OCR ocr_result = self._ocr_image(processed_img) return ProcessingResult( - text=ocr_result["text"], - confidence=ocr_result["confidence"], - regions=ocr_result["regions"], - page_count=1, - file_type=FileType.IMAGE, - processing_mode=mode, - metadata={ - "source": file_path or "bytes", - "image_size": img.size - } + text=ocr_result["text"], confidence=ocr_result["confidence"], + regions=ocr_result["regions"], page_count=1, + file_type=FileType.IMAGE, processing_mode=mode, + metadata={"source": file_path or "bytes", "image_size": img.size} ) - def _process_docx( - self, - file_path: str = None, - file_bytes: bytes = None - ) -> ProcessingResult: + def _process_docx(self, file_path=None, file_bytes=None): """Verarbeitet DOCX-Dateien.""" try: from docx import Document @@ -306,7 +207,6 @@ class FileProcessor: if para.text.strip(): paragraphs.append(para.text) - # Auch Tabellen extrahieren for table in doc.tables: for row in table.rows: row_text = " | ".join(cell.text for cell in row.cells) 
@@ -316,25 +216,14 @@ class FileProcessor: text = "\n\n".join(paragraphs) return ProcessingResult( - text=text, - confidence=1.0, # Direkte Textextraktion - regions=[ProcessedRegion( - text=text, - confidence=1.0, - bbox=(0, 0, 0, 0), - page=1 - )], - page_count=1, - file_type=FileType.DOCX, + text=text, confidence=1.0, + regions=[ProcessedRegion(text=text, confidence=1.0, bbox=(0, 0, 0, 0), page=1)], + page_count=1, file_type=FileType.DOCX, processing_mode=ProcessingMode.TEXT_EXTRACT, metadata={"source": file_path or "bytes"} ) - def _process_txt( - self, - file_path: str = None, - file_bytes: bytes = None - ) -> ProcessingResult: + def _process_txt(self, file_path=None, file_bytes=None): """Verarbeitet Textdateien.""" if file_bytes: text = file_bytes.decode('utf-8', errors='ignore') @@ -343,146 +232,65 @@ class FileProcessor: text = f.read() return ProcessingResult( - text=text, - confidence=1.0, - regions=[ProcessedRegion( - text=text, - confidence=1.0, - bbox=(0, 0, 0, 0), - page=1 - )], - page_count=1, - file_type=FileType.TXT, + text=text, confidence=1.0, + regions=[ProcessedRegion(text=text, confidence=1.0, bbox=(0, 0, 0, 0), page=1)], + page_count=1, file_type=FileType.TXT, processing_mode=ProcessingMode.TEXT_EXTRACT, metadata={"source": file_path or "bytes"} ) def _preprocess_image(self, img: Image.Image) -> Image.Image: - """ - Vorverarbeitung des Bildes für bessere OCR-Ergebnisse. 
- - - Konvertierung zu Graustufen - - Kontrastverstärkung - - Rauschunterdrückung - - Binarisierung - """ - # PIL zu OpenCV + """Vorverarbeitung des Bildes für bessere OCR-Ergebnisse.""" cv_img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR) - - # Zu Graustufen konvertieren gray = cv2.cvtColor(cv_img, cv2.COLOR_BGR2GRAY) - - # Rauschunterdrückung denoised = cv2.fastNlMeansDenoising(gray, None, 10, 7, 21) - - # Kontrastverstärkung (CLAHE) clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8)) enhanced = clahe.apply(denoised) - - # Adaptive Binarisierung binary = cv2.adaptiveThreshold( - enhanced, - 255, - cv2.ADAPTIVE_THRESH_GAUSSIAN_C, - cv2.THRESH_BINARY, - 11, - 2 + enhanced, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, + cv2.THRESH_BINARY, 11, 2 ) - - # Zurück zu PIL return Image.fromarray(binary) def _ocr_image(self, img: Image.Image) -> Dict[str, Any]: - """ - Führt OCR auf einem Bild aus. - - Returns: - Dict mit text, confidence und regions - """ + """Führt OCR auf einem Bild aus.""" if self.ocr_engine is None: - # Fallback wenn kein OCR-Engine verfügbar - return { - "text": "[OCR nicht verfügbar - bitte PaddleOCR installieren]", - "confidence": 0.0, - "regions": [] - } + return {"text": "[OCR nicht verfügbar - bitte PaddleOCR installieren]", + "confidence": 0.0, "regions": []} - # PIL zu numpy array img_array = np.array(img) - - # Wenn Graustufen, zu RGB konvertieren (PaddleOCR erwartet RGB) if len(img_array.shape) == 2: img_array = cv2.cvtColor(img_array, cv2.COLOR_GRAY2RGB) - # OCR ausführen result = self.ocr_engine.ocr(img_array, cls=True) if not result or not result[0]: return {"text": "", "confidence": 0.0, "regions": []} - all_text = [] - all_regions = [] + all_text, all_regions = [], [] total_confidence = 0.0 for line in result[0]: - bbox_points = line[0] # [[x1,y1], [x2,y2], [x3,y3], [x4,y4]] + bbox_points = line[0] text, confidence = line[1] - - # Bounding Box zu x1, y1, x2, y2 konvertieren x_coords = [p[0] for p in bbox_points] y_coords = [p[1] for p 
in bbox_points] - bbox = ( - int(min(x_coords)), - int(min(y_coords)), - int(max(x_coords)), - int(max(y_coords)) - ) - + bbox = (int(min(x_coords)), int(min(y_coords)), + int(max(x_coords)), int(max(y_coords))) all_text.append(text) - all_regions.append(ProcessedRegion( - text=text, - confidence=confidence, - bbox=bbox - )) + all_regions.append(ProcessedRegion(text=text, confidence=confidence, bbox=bbox)) total_confidence += confidence avg_confidence = total_confidence / len(all_regions) if all_regions else 0.0 + return {"text": "\n".join(all_text), "confidence": avg_confidence, "regions": all_regions} - return { - "text": "\n".join(all_text), - "confidence": avg_confidence, - "regions": all_regions - } - - def extract_handwriting_regions( - self, - img: Image.Image, - min_area: int = 500 - ) -> List[Dict[str, Any]]: - """ - Erkennt und extrahiert handschriftliche Bereiche aus einem Bild. - - Nützlich für Klausuren mit gedruckten Fragen und handschriftlichen Antworten. - - Args: - img: Eingabebild - min_area: Minimale Fläche für erkannte Regionen - - Returns: - Liste von Regionen mit Koordinaten und erkanntem Text - """ - # Bildvorverarbeitung + def extract_handwriting_regions(self, img: Image.Image, min_area: int = 500) -> List[Dict[str, Any]]: + """Erkennt und extrahiert handschriftliche Bereiche aus einem Bild.""" cv_img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR) gray = cv2.cvtColor(cv_img, cv2.COLOR_BGR2GRAY) - - # Kanten erkennen edges = cv2.Canny(gray, 50, 150) - - # Morphologische Operationen zum Verbinden kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (15, 5)) dilated = cv2.dilate(edges, kernel, iterations=2) - - # Konturen finden contours, _ = cv2.findContours(dilated, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) regions = [] @@ -490,25 +298,15 @@ class FileProcessor: area = cv2.contourArea(contour) if area < min_area: continue - x, y, w, h = cv2.boundingRect(contour) - - # Region ausschneiden region_img = img.crop((x, y, x + w, y + h)) - - # OCR 
auf Region anwenden ocr_result = self._ocr_image(region_img) - regions.append({ - "bbox": (x, y, x + w, y + h), - "area": area, - "text": ocr_result["text"], - "confidence": ocr_result["confidence"] + "bbox": (x, y, x + w, y + h), "area": area, + "text": ocr_result["text"], "confidence": ocr_result["confidence"] }) - # Nach Y-Position sortieren (oben nach unten) regions.sort(key=lambda r: r["bbox"][1]) - return regions @@ -525,39 +323,25 @@ def get_file_processor() -> FileProcessor: # Convenience functions -def process_file( - file_path: str = None, - file_bytes: bytes = None, - mode: ProcessingMode = ProcessingMode.MIXED -) -> ProcessingResult: - """ - Convenience function zum Verarbeiten einer Datei. - - Args: - file_path: Pfad zur Datei - file_bytes: Dateiinhalt als Bytes - mode: Verarbeitungsmodus - - Returns: - ProcessingResult - """ +def process_file(file_path=None, file_bytes=None, mode=ProcessingMode.MIXED) -> ProcessingResult: + """Convenience function zum Verarbeiten einer Datei.""" processor = get_file_processor() return processor.process(file_path, file_bytes, mode) -def extract_text_from_pdf(file_path: str = None, file_bytes: bytes = None) -> str: +def extract_text_from_pdf(file_path=None, file_bytes=None) -> str: """Extrahiert Text aus einer PDF-Datei.""" result = process_file(file_path, file_bytes, ProcessingMode.TEXT_EXTRACT) return result.text -def ocr_image(file_path: str = None, file_bytes: bytes = None) -> str: +def ocr_image(file_path=None, file_bytes=None) -> str: """Führt OCR auf einem Bild aus.""" result = process_file(file_path, file_bytes, ProcessingMode.OCR_PRINTED) return result.text -def ocr_handwriting(file_path: str = None, file_bytes: bytes = None) -> str: +def ocr_handwriting(file_path=None, file_bytes=None) -> str: """Führt Handschrift-OCR auf einem Bild aus.""" result = process_file(file_path, file_bytes, ProcessingMode.OCR_HANDWRITING) return result.text diff --git a/backend-lehrer/services/file_processor_models.py 
b/backend-lehrer/services/file_processor_models.py new file mode 100644 index 0000000..dc5f084 --- /dev/null +++ b/backend-lehrer/services/file_processor_models.py @@ -0,0 +1,48 @@ +""" +File Processor - Datenmodelle und Enums. + +Typen fuer Dokumentenverarbeitung: Dateitypen, Modi, Ergebnisse. +""" + +from typing import List, Dict, Any, Tuple +from dataclasses import dataclass +from enum import Enum + + +class FileType(str, Enum): + """Unterstützte Dateitypen.""" + PDF = "pdf" + IMAGE = "image" + DOCX = "docx" + DOC = "doc" + TXT = "txt" + UNKNOWN = "unknown" + + +class ProcessingMode(str, Enum): + """Verarbeitungsmodi.""" + OCR_HANDWRITING = "ocr_handwriting" # Handschrifterkennung + OCR_PRINTED = "ocr_printed" # Gedruckter Text + TEXT_EXTRACT = "text_extract" # Textextraktion (PDF/DOCX) + MIXED = "mixed" # Kombiniert OCR + Textextraktion + + +@dataclass +class ProcessedRegion: + """Ein erkannter Textbereich.""" + text: str + confidence: float + bbox: Tuple[int, int, int, int] # x1, y1, x2, y2 + page: int = 1 + + +@dataclass +class ProcessingResult: + """Ergebnis der Dokumentenverarbeitung.""" + text: str + confidence: float + regions: List[ProcessedRegion] + page_count: int + file_type: FileType + processing_mode: ProcessingMode + metadata: Dict[str, Any] diff --git a/backend-lehrer/state_engine_api.py b/backend-lehrer/state_engine_api.py index fe669d6..3f43ca8 100644 --- a/backend-lehrer/state_engine_api.py +++ b/backend-lehrer/state_engine_api.py @@ -12,21 +12,29 @@ Endpoints: import logging import uuid from datetime import datetime, timedelta -from typing import Dict, Any, List, Optional +from typing import Dict, Any, List from fastapi import APIRouter, HTTPException, Query -from pydantic import BaseModel, Field from state_engine import ( AnticipationEngine, PhaseService, - TeacherContext, SchoolYearPhase, ClassSummary, Event, - TeacherStats, get_phase_info, - PHASE_INFO +) +from state_engine_models import ( + MilestoneRequest, + TransitionRequest, + 
ContextResponse, + SuggestionsResponse, + DashboardResponse, + _teacher_contexts, + _milestones, + get_or_create_context, + update_context_from_services, + get_phase_display_name, ) logger = logging.getLogger(__name__) @@ -41,157 +49,15 @@ _engine = AnticipationEngine() _phase_service = PhaseService() -# ============================================================================ -# In-Memory Storage (später durch DB ersetzen) -# ============================================================================ - -# Simulierter Lehrer-Kontext (in Produktion aus DB) -_teacher_contexts: Dict[str, TeacherContext] = {} -_milestones: Dict[str, List[str]] = {} # teacher_id -> milestones - - -# ============================================================================ -# Pydantic Models -# ============================================================================ - -class MilestoneRequest(BaseModel): - """Request zum Abschließen eines Meilensteins.""" - milestone: str = Field(..., description="Name des Meilensteins") - - -class TransitionRequest(BaseModel): - """Request für Phasen-Übergang.""" - target_phase: str = Field(..., description="Zielphase") - - -class ContextResponse(BaseModel): - """Response mit TeacherContext.""" - context: Dict[str, Any] - phase_info: Dict[str, Any] - - -class SuggestionsResponse(BaseModel): - """Response mit Vorschlägen.""" - suggestions: List[Dict[str, Any]] - current_phase: str - phase_display_name: str - priority_counts: Dict[str, int] - - -class DashboardResponse(BaseModel): - """Response mit Dashboard-Daten.""" - context: Dict[str, Any] - suggestions: List[Dict[str, Any]] - stats: Dict[str, Any] - upcoming_events: List[Dict[str, Any]] - progress: Dict[str, Any] - phases: List[Dict[str, Any]] - - -# ============================================================================ -# Helper Functions -# ============================================================================ - -def _get_or_create_context(teacher_id: str) -> TeacherContext: - 
""" - Holt oder erstellt TeacherContext. - - In Produktion würde dies aus der Datenbank geladen. - """ - if teacher_id not in _teacher_contexts: - # Erstelle Demo-Kontext - now = datetime.now() - school_year_start = datetime(now.year if now.month >= 8 else now.year - 1, 8, 1) - weeks_since_start = (now - school_year_start).days // 7 - - # Bestimme Phase basierend auf Monat - month = now.month - if month in [8, 9]: - phase = SchoolYearPhase.SCHOOL_YEAR_START - elif month in [10, 11]: - phase = SchoolYearPhase.TEACHING_SETUP - elif month == 12: - phase = SchoolYearPhase.PERFORMANCE_1 - elif month in [1, 2]: - phase = SchoolYearPhase.SEMESTER_END - elif month in [3, 4]: - phase = SchoolYearPhase.TEACHING_2 - elif month in [5, 6]: - phase = SchoolYearPhase.PERFORMANCE_2 - else: - phase = SchoolYearPhase.YEAR_END - - _teacher_contexts[teacher_id] = TeacherContext( - teacher_id=teacher_id, - school_id=str(uuid.uuid4()), - school_year_id=str(uuid.uuid4()), - federal_state="niedersachsen", - school_type="gymnasium", - school_year_start=school_year_start, - current_phase=phase, - phase_entered_at=now - timedelta(days=7), - weeks_since_start=weeks_since_start, - days_in_phase=7, - classes=[], - total_students=0, - upcoming_events=[], - completed_milestones=_milestones.get(teacher_id, []), - pending_milestones=[], - stats=TeacherStats(), - ) - - return _teacher_contexts[teacher_id] - - -def _update_context_from_services(ctx: TeacherContext) -> TeacherContext: - """ - Aktualisiert Kontext mit Daten aus anderen Services. - - In Produktion würde dies von school-service, gradebook etc. laden. 
- """ - # Simulierte Daten - in Produktion API-Calls - # Hier könnten wir den Kontext mit echten Daten anreichern - - # Berechne days_in_phase - ctx.days_in_phase = (datetime.now() - ctx.phase_entered_at).days - - # Lade abgeschlossene Meilensteine - ctx.completed_milestones = _milestones.get(ctx.teacher_id, []) - - # Berechne pending milestones - phase_info = get_phase_info(ctx.current_phase) - ctx.pending_milestones = [ - m for m in phase_info.required_actions - if m not in ctx.completed_milestones - ] - - return ctx - - -def _get_phase_display_name(phase: str) -> str: - """Gibt Display-Name für Phase zurück.""" - try: - return get_phase_info(SchoolYearPhase(phase)).display_name - except (ValueError, KeyError): - return phase - - # ============================================================================ # API Endpoints # ============================================================================ @router.get("/context", response_model=ContextResponse) async def get_teacher_context(teacher_id: str = Query("demo-teacher")): - """ - Gibt den aggregierten TeacherContext zurück. - - Enthält alle relevanten Informationen für: - - Phasen-Anzeige - - Antizipations-Engine - - Dashboard - """ - ctx = _get_or_create_context(teacher_id) - ctx = _update_context_from_services(ctx) + """Gibt den aggregierten TeacherContext zurück.""" + ctx = get_or_create_context(teacher_id) + ctx = update_context_from_services(ctx) phase_info = get_phase_info(ctx.current_phase) @@ -210,10 +76,8 @@ async def get_teacher_context(teacher_id: str = Query("demo-teacher")): @router.get("/phase") async def get_current_phase(teacher_id: str = Query("demo-teacher")): - """ - Gibt die aktuelle Phase mit Details zurück. 
- """ - ctx = _get_or_create_context(teacher_id) + """Gibt die aktuelle Phase mit Details zurück.""" + ctx = get_or_create_context(teacher_id) phase_info = get_phase_info(ctx.current_phase) return { @@ -230,11 +94,7 @@ async def get_current_phase(teacher_id: str = Query("demo-teacher")): @router.get("/phases") async def get_all_phases(): - """ - Gibt alle Phasen mit Metadaten zurück. - - Nützlich für die Phasen-Anzeige im Dashboard. - """ + """Gibt alle Phasen mit Metadaten zurück.""" return { "phases": _phase_service.get_all_phases() } @@ -242,13 +102,9 @@ async def get_all_phases(): @router.get("/suggestions", response_model=SuggestionsResponse) async def get_suggestions(teacher_id: str = Query("demo-teacher")): - """ - Gibt Vorschläge basierend auf dem aktuellen Kontext zurück. - - Die Vorschläge sind priorisiert und auf max. 5 limitiert. - """ - ctx = _get_or_create_context(teacher_id) - ctx = _update_context_from_services(ctx) + """Gibt Vorschläge basierend auf dem aktuellen Kontext zurück.""" + ctx = get_or_create_context(teacher_id) + ctx = update_context_from_services(ctx) suggestions = _engine.get_suggestions(ctx) priority_counts = _engine.count_by_priority(ctx) @@ -256,18 +112,16 @@ async def get_suggestions(teacher_id: str = Query("demo-teacher")): return SuggestionsResponse( suggestions=[s.to_dict() for s in suggestions], current_phase=ctx.current_phase.value, - phase_display_name=_get_phase_display_name(ctx.current_phase.value), + phase_display_name=get_phase_display_name(ctx.current_phase.value), priority_counts=priority_counts, ) @router.get("/suggestions/top") async def get_top_suggestion(teacher_id: str = Query("demo-teacher")): - """ - Gibt den wichtigsten einzelnen Vorschlag zurück. 
- """ - ctx = _get_or_create_context(teacher_id) - ctx = _update_context_from_services(ctx) + """Gibt den wichtigsten einzelnen Vorschlag zurück.""" + ctx = get_or_create_context(teacher_id) + ctx = update_context_from_services(ctx) suggestion = _engine.get_top_suggestion(ctx) @@ -284,28 +138,17 @@ async def get_top_suggestion(teacher_id: str = Query("demo-teacher")): @router.get("/dashboard", response_model=DashboardResponse) async def get_dashboard_data(teacher_id: str = Query("demo-teacher")): - """ - Gibt alle Daten für das Begleiter-Dashboard zurück. - - Kombiniert: - - TeacherContext - - Vorschläge - - Statistiken - - Termine - - Fortschritt - """ - ctx = _get_or_create_context(teacher_id) - ctx = _update_context_from_services(ctx) + """Gibt alle Daten für das Begleiter-Dashboard zurück.""" + ctx = get_or_create_context(teacher_id) + ctx = update_context_from_services(ctx) suggestions = _engine.get_suggestions(ctx) phase_info = get_phase_info(ctx.current_phase) - # Berechne Fortschritt required = set(phase_info.required_actions) completed = set(ctx.completed_milestones) completed_in_phase = len(required.intersection(completed)) - # Alle Phasen für Anzeige all_phases = [] phase_order = [ SchoolYearPhase.ONBOARDING, @@ -376,14 +219,9 @@ async def complete_milestone( request: MilestoneRequest, teacher_id: str = Query("demo-teacher") ): - """ - Markiert einen Meilenstein als erledigt. - - Prüft automatisch ob ein Phasen-Übergang möglich ist. 
- """ + """Markiert einen Meilenstein als erledigt.""" milestone = request.milestone - # Speichere Meilenstein if teacher_id not in _milestones: _milestones[teacher_id] = [] @@ -391,12 +229,10 @@ async def complete_milestone( _milestones[teacher_id].append(milestone) logger.info(f"Milestone '{milestone}' completed for teacher {teacher_id}") - # Aktualisiere Kontext - ctx = _get_or_create_context(teacher_id) + ctx = get_or_create_context(teacher_id) ctx.completed_milestones = _milestones[teacher_id] _teacher_contexts[teacher_id] = ctx - # Prüfe automatischen Phasen-Übergang new_phase = _phase_service.check_and_transition(ctx) if new_phase: @@ -420,9 +256,7 @@ async def transition_phase( request: TransitionRequest, teacher_id: str = Query("demo-teacher") ): - """ - Führt einen manuellen Phasen-Übergang durch. - """ + """Führt einen manuellen Phasen-Übergang durch.""" try: target_phase = SchoolYearPhase(request.target_phase) except ValueError: @@ -431,16 +265,14 @@ async def transition_phase( detail=f"Ungültige Phase: {request.target_phase}" ) - ctx = _get_or_create_context(teacher_id) + ctx = get_or_create_context(teacher_id) - # Prüfe ob Übergang erlaubt if not _phase_service.can_transition_to(ctx, target_phase): raise HTTPException( status_code=400, detail=f"Übergang von {ctx.current_phase.value} zu {target_phase.value} nicht erlaubt" ) - # Führe Übergang durch old_phase = ctx.current_phase ctx.current_phase = target_phase ctx.phase_entered_at = datetime.now() @@ -459,10 +291,8 @@ async def transition_phase( @router.get("/next-phase") async def get_next_phase(teacher_id: str = Query("demo-teacher")): - """ - Gibt die nächste Phase und Anforderungen zurück. 
- """ - ctx = _get_or_create_context(teacher_id) + """Gibt die nächste Phase und Anforderungen zurück.""" + ctx = get_or_create_context(teacher_id) next_phase = _phase_service.get_next_phase(ctx.current_phase) if not next_phase: @@ -475,7 +305,6 @@ async def get_next_phase(teacher_id: str = Query("demo-teacher")): next_info = get_phase_info(next_phase) current_info = get_phase_info(ctx.current_phase) - # Fehlende Anforderungen missing = [ m for m in current_info.required_actions if m not in ctx.completed_milestones @@ -505,7 +334,7 @@ async def demo_add_class( teacher_id: str = Query("demo-teacher") ): """Demo: Fügt eine Klasse zum Kontext hinzu.""" - ctx = _get_or_create_context(teacher_id) + ctx = get_or_create_context(teacher_id) ctx.classes.append(ClassSummary( class_id=str(uuid.uuid4()), @@ -515,7 +344,6 @@ async def demo_add_class( subject="Deutsch" )) ctx.total_students += student_count - _teacher_contexts[teacher_id] = ctx return {"success": True, "classes": len(ctx.classes)} @@ -529,7 +357,7 @@ async def demo_add_event( teacher_id: str = Query("demo-teacher") ): """Demo: Fügt ein Event zum Kontext hinzu.""" - ctx = _get_or_create_context(teacher_id) + ctx = get_or_create_context(teacher_id) ctx.upcoming_events.append(Event( type=event_type, @@ -538,7 +366,6 @@ async def demo_add_event( in_days=in_days, priority="high" if in_days <= 3 else "medium" )) - _teacher_contexts[teacher_id] = ctx return {"success": True, "events": len(ctx.upcoming_events)} @@ -554,7 +381,7 @@ async def demo_update_stats( teacher_id: str = Query("demo-teacher") ): """Demo: Aktualisiert Statistiken.""" - ctx = _get_or_create_context(teacher_id) + ctx = get_or_create_context(teacher_id) if learning_units: ctx.stats.learning_units_created = learning_units diff --git a/backend-lehrer/state_engine_models.py b/backend-lehrer/state_engine_models.py new file mode 100644 index 0000000..778b2f7 --- /dev/null +++ b/backend-lehrer/state_engine_models.py @@ -0,0 +1,143 @@ +""" +State Engine API 
- Pydantic Models und Helper Functions. +""" + +import uuid +from datetime import datetime, timedelta +from typing import Dict, Any, List, Optional + +from pydantic import BaseModel, Field + +from state_engine import ( + SchoolYearPhase, + ClassSummary, + Event, + TeacherContext, + TeacherStats, + get_phase_info, +) + + +# ============================================================================ +# In-Memory Storage (später durch DB ersetzen) +# ============================================================================ + +_teacher_contexts: Dict[str, TeacherContext] = {} +_milestones: Dict[str, List[str]] = {} # teacher_id -> milestones + + +# ============================================================================ +# Pydantic Models +# ============================================================================ + +class MilestoneRequest(BaseModel): + """Request zum Abschließen eines Meilensteins.""" + milestone: str = Field(..., description="Name des Meilensteins") + + +class TransitionRequest(BaseModel): + """Request für Phasen-Übergang.""" + target_phase: str = Field(..., description="Zielphase") + + +class ContextResponse(BaseModel): + """Response mit TeacherContext.""" + context: Dict[str, Any] + phase_info: Dict[str, Any] + + +class SuggestionsResponse(BaseModel): + """Response mit Vorschlägen.""" + suggestions: List[Dict[str, Any]] + current_phase: str + phase_display_name: str + priority_counts: Dict[str, int] + + +class DashboardResponse(BaseModel): + """Response mit Dashboard-Daten.""" + context: Dict[str, Any] + suggestions: List[Dict[str, Any]] + stats: Dict[str, Any] + upcoming_events: List[Dict[str, Any]] + progress: Dict[str, Any] + phases: List[Dict[str, Any]] + + +# ============================================================================ +# Helper Functions +# ============================================================================ + +def get_or_create_context(teacher_id: str) -> TeacherContext: + """ + Holt oder erstellt 
TeacherContext. + + In Produktion würde dies aus der Datenbank geladen. + """ + if teacher_id not in _teacher_contexts: + now = datetime.now() + school_year_start = datetime(now.year if now.month >= 8 else now.year - 1, 8, 1) + weeks_since_start = (now - school_year_start).days // 7 + + month = now.month + if month in [8, 9]: + phase = SchoolYearPhase.SCHOOL_YEAR_START + elif month in [10, 11]: + phase = SchoolYearPhase.TEACHING_SETUP + elif month == 12: + phase = SchoolYearPhase.PERFORMANCE_1 + elif month in [1, 2]: + phase = SchoolYearPhase.SEMESTER_END + elif month in [3, 4]: + phase = SchoolYearPhase.TEACHING_2 + elif month in [5, 6]: + phase = SchoolYearPhase.PERFORMANCE_2 + else: + phase = SchoolYearPhase.YEAR_END + + _teacher_contexts[teacher_id] = TeacherContext( + teacher_id=teacher_id, + school_id=str(uuid.uuid4()), + school_year_id=str(uuid.uuid4()), + federal_state="niedersachsen", + school_type="gymnasium", + school_year_start=school_year_start, + current_phase=phase, + phase_entered_at=now - timedelta(days=7), + weeks_since_start=weeks_since_start, + days_in_phase=7, + classes=[], + total_students=0, + upcoming_events=[], + completed_milestones=_milestones.get(teacher_id, []), + pending_milestones=[], + stats=TeacherStats(), + ) + + return _teacher_contexts[teacher_id] + + +def update_context_from_services(ctx: TeacherContext) -> TeacherContext: + """ + Aktualisiert Kontext mit Daten aus anderen Services. + + In Produktion würde dies von school-service, gradebook etc. laden. 
+ """ + ctx.days_in_phase = (datetime.now() - ctx.phase_entered_at).days + ctx.completed_milestones = _milestones.get(ctx.teacher_id, []) + + phase_info = get_phase_info(ctx.current_phase) + ctx.pending_milestones = [ + m for m in phase_info.required_actions + if m not in ctx.completed_milestones + ] + + return ctx + + +def get_phase_display_name(phase: str) -> str: + """Gibt Display-Name für Phase zurück.""" + try: + return get_phase_info(SchoolYearPhase(phase)).display_name + except (ValueError, KeyError): + return phase diff --git a/backend-lehrer/worksheets_api.py b/backend-lehrer/worksheets_api.py index 527bf95..d6f9f8d 100644 --- a/backend-lehrer/worksheets_api.py +++ b/backend-lehrer/worksheets_api.py @@ -16,11 +16,9 @@ Unterstützt: import logging import uuid from datetime import datetime -from typing import List, Dict, Any, Optional -from enum import Enum +from typing import Dict -from fastapi import APIRouter, HTTPException, UploadFile, File, Form -from pydantic import BaseModel, Field +from fastapi import APIRouter, HTTPException from generators import ( MultipleChoiceGenerator, @@ -28,9 +26,22 @@ from generators import ( MindmapGenerator, QuizGenerator ) -from generators.mc_generator import Difficulty -from generators.cloze_generator import ClozeType -from generators.quiz_generator import QuizType + +from worksheets_models import ( + ContentType, + GenerateRequest, + MCGenerateRequest, + ClozeGenerateRequest, + MindmapGenerateRequest, + QuizGenerateRequest, + BatchGenerateRequest, + WorksheetContent, + GenerateResponse, + BatchGenerateResponse, + parse_difficulty, + parse_cloze_type, + parse_quiz_types, +) logger = logging.getLogger(__name__) @@ -40,89 +51,6 @@ router = APIRouter( ) -# ============================================================================ -# Pydantic Models -# ============================================================================ - -class ContentType(str, Enum): - """Verfügbare Content-Typen.""" - MULTIPLE_CHOICE = 
"multiple_choice" - CLOZE = "cloze" - MINDMAP = "mindmap" - QUIZ = "quiz" - - -class GenerateRequest(BaseModel): - """Basis-Request für Generierung.""" - source_text: str = Field(..., min_length=50, description="Quelltext für Generierung") - topic: Optional[str] = Field(None, description="Thema/Titel") - subject: Optional[str] = Field(None, description="Fach") - grade_level: Optional[str] = Field(None, description="Klassenstufe") - - -class MCGenerateRequest(GenerateRequest): - """Request für Multiple-Choice-Generierung.""" - num_questions: int = Field(5, ge=1, le=20, description="Anzahl Fragen") - difficulty: str = Field("medium", description="easy, medium, hard") - - -class ClozeGenerateRequest(GenerateRequest): - """Request für Lückentext-Generierung.""" - num_gaps: int = Field(5, ge=1, le=15, description="Anzahl Lücken") - difficulty: str = Field("medium", description="easy, medium, hard") - cloze_type: str = Field("fill_in", description="fill_in, drag_drop, dropdown") - - -class MindmapGenerateRequest(GenerateRequest): - """Request für Mindmap-Generierung.""" - max_depth: int = Field(3, ge=2, le=5, description="Maximale Tiefe") - - -class QuizGenerateRequest(GenerateRequest): - """Request für Quiz-Generierung.""" - quiz_types: List[str] = Field( - ["true_false", "matching"], - description="Typen: true_false, matching, sorting, open_ended" - ) - num_items: int = Field(5, ge=1, le=10, description="Items pro Typ") - difficulty: str = Field("medium", description="easy, medium, hard") - - -class BatchGenerateRequest(BaseModel): - """Request für Batch-Generierung mehrerer Content-Typen.""" - source_text: str = Field(..., min_length=50) - content_types: List[str] = Field(..., description="Liste von Content-Typen") - topic: Optional[str] = None - subject: Optional[str] = None - grade_level: Optional[str] = None - difficulty: str = "medium" - - -class WorksheetContent(BaseModel): - """Generierter Content.""" - id: str - content_type: str - data: Dict[str, Any] - 
h5p_format: Optional[Dict[str, Any]] = None - created_at: datetime - topic: Optional[str] = None - difficulty: Optional[str] = None - - -class GenerateResponse(BaseModel): - """Response mit generiertem Content.""" - success: bool - content: Optional[WorksheetContent] = None - error: Optional[str] = None - - -class BatchGenerateResponse(BaseModel): - """Response für Batch-Generierung.""" - success: bool - contents: List[WorksheetContent] = [] - errors: List[str] = [] - - # ============================================================================ # In-Memory Storage (später durch DB ersetzen) # ============================================================================ @@ -134,49 +62,12 @@ _generated_content: Dict[str, WorksheetContent] = {} # Generator Instances # ============================================================================ -# Generatoren ohne LLM-Client (automatische Generierung) -# In Produktion würde hier der LLM-Client injiziert mc_generator = MultipleChoiceGenerator() cloze_generator = ClozeGenerator() mindmap_generator = MindmapGenerator() quiz_generator = QuizGenerator() -# ============================================================================ -# Helper Functions -# ============================================================================ - -def _parse_difficulty(difficulty_str: str) -> Difficulty: - """Konvertiert String zu Difficulty Enum.""" - mapping = { - "easy": Difficulty.EASY, - "medium": Difficulty.MEDIUM, - "hard": Difficulty.HARD - } - return mapping.get(difficulty_str.lower(), Difficulty.MEDIUM) - - -def _parse_cloze_type(type_str: str) -> ClozeType: - """Konvertiert String zu ClozeType Enum.""" - mapping = { - "fill_in": ClozeType.FILL_IN, - "drag_drop": ClozeType.DRAG_DROP, - "dropdown": ClozeType.DROPDOWN - } - return mapping.get(type_str.lower(), ClozeType.FILL_IN) - - -def _parse_quiz_types(type_strs: List[str]) -> List[QuizType]: - """Konvertiert String-Liste zu QuizType Enums.""" - mapping = { - "true_false": 
QuizType.TRUE_FALSE, - "matching": QuizType.MATCHING, - "sorting": QuizType.SORTING, - "open_ended": QuizType.OPEN_ENDED - } - return [mapping.get(t.lower(), QuizType.TRUE_FALSE) for t in type_strs] - - def _store_content(content: WorksheetContent) -> None: """Speichert generierten Content.""" _generated_content[content.id] = content @@ -188,15 +79,9 @@ def _store_content(content: WorksheetContent) -> None: @router.post("/generate/multiple-choice", response_model=GenerateResponse) async def generate_multiple_choice(request: MCGenerateRequest): - """ - Generiert Multiple-Choice-Fragen aus Quelltext. - - - **source_text**: Text mit mind. 50 Zeichen - - **num_questions**: Anzahl Fragen (1-20) - - **difficulty**: easy, medium, hard - """ + """Generiert Multiple-Choice-Fragen aus Quelltext.""" try: - difficulty = _parse_difficulty(request.difficulty) + difficulty = parse_difficulty(request.difficulty) questions = mc_generator.generate( source_text=request.source_text, @@ -212,7 +97,6 @@ async def generate_multiple_choice(request: MCGenerateRequest): error="Keine Fragen generiert. Text möglicherweise zu kurz." ) - # Konvertiere zu Dict questions_dict = mc_generator.to_dict(questions) h5p_format = mc_generator.to_h5p_format(questions) @@ -227,7 +111,6 @@ async def generate_multiple_choice(request: MCGenerateRequest): ) _store_content(content) - return GenerateResponse(success=True, content=content) except Exception as e: @@ -237,15 +120,9 @@ async def generate_multiple_choice(request: MCGenerateRequest): @router.post("/generate/cloze", response_model=GenerateResponse) async def generate_cloze(request: ClozeGenerateRequest): - """ - Generiert Lückentext aus Quelltext. - - - **source_text**: Text mit mind. 
50 Zeichen - - **num_gaps**: Anzahl Lücken (1-15) - - **cloze_type**: fill_in, drag_drop, dropdown - """ + """Generiert Lückentext aus Quelltext.""" try: - cloze_type = _parse_cloze_type(request.cloze_type) + cloze_type = parse_cloze_type(request.cloze_type) cloze = cloze_generator.generate( source_text=request.source_text, @@ -275,7 +152,6 @@ async def generate_cloze(request: ClozeGenerateRequest): ) _store_content(content) - return GenerateResponse(success=True, content=content) except Exception as e: @@ -285,12 +161,7 @@ async def generate_cloze(request: ClozeGenerateRequest): @router.post("/generate/mindmap", response_model=GenerateResponse) async def generate_mindmap(request: MindmapGenerateRequest): - """ - Generiert Mindmap aus Quelltext. - - - **source_text**: Text mit mind. 50 Zeichen - - **max_depth**: Maximale Tiefe (2-5) - """ + """Generiert Mindmap aus Quelltext.""" try: mindmap = mindmap_generator.generate( source_text=request.source_text, @@ -317,14 +188,13 @@ async def generate_mindmap(request: MindmapGenerateRequest): "mermaid": mermaid, "json_tree": json_tree }, - h5p_format=None, # Mindmaps haben kein H5P-Format + h5p_format=None, created_at=datetime.utcnow(), topic=request.topic, difficulty=None ) _store_content(content) - return GenerateResponse(success=True, content=content) except Exception as e: @@ -334,17 +204,10 @@ async def generate_mindmap(request: MindmapGenerateRequest): @router.post("/generate/quiz", response_model=GenerateResponse) async def generate_quiz(request: QuizGenerateRequest): - """ - Generiert Quiz mit verschiedenen Fragetypen. - - - **source_text**: Text mit mind. 
50 Zeichen - - **quiz_types**: Liste von true_false, matching, sorting, open_ended - - **num_items**: Items pro Typ (1-10) - """ + """Generiert Quiz mit verschiedenen Fragetypen.""" try: - quiz_types = _parse_quiz_types(request.quiz_types) + quiz_types = parse_quiz_types(request.quiz_types) - # Generate quiz for each type and combine results all_questions = [] quizzes = [] @@ -365,7 +228,6 @@ async def generate_quiz(request: QuizGenerateRequest): error="Quiz konnte nicht generiert werden. Text möglicherweise zu kurz." ) - # Combine all quizzes into a single dict combined_quiz_dict = { "quiz_types": [qt.value for qt in quiz_types], "title": f"Combined Quiz - {request.topic or 'Various Topics'}", @@ -374,12 +236,10 @@ async def generate_quiz(request: QuizGenerateRequest): "questions": [] } - # Add questions from each quiz for quiz in quizzes: quiz_dict = quiz_generator.to_dict(quiz) combined_quiz_dict["questions"].extend(quiz_dict.get("questions", [])) - # Use first quiz's H5P format as base (or empty if none) h5p_format = quiz_generator.to_h5p_format(quizzes[0]) if quizzes else {} content = WorksheetContent( @@ -393,7 +253,6 @@ async def generate_quiz(request: QuizGenerateRequest): ) _store_content(content) - return GenerateResponse(success=True, content=content) except Exception as e: @@ -403,22 +262,10 @@ async def generate_quiz(request: QuizGenerateRequest): @router.post("/generate/batch", response_model=BatchGenerateResponse) async def generate_batch(request: BatchGenerateRequest): - """ - Generiert mehrere Content-Typen aus einem Quelltext. - - Ideal für die Erstellung kompletter Arbeitsblätter mit - verschiedenen Übungstypen. 
- """ + """Generiert mehrere Content-Typen aus einem Quelltext.""" contents = [] errors = [] - type_mapping = { - "multiple_choice": MCGenerateRequest, - "cloze": ClozeGenerateRequest, - "mindmap": MindmapGenerateRequest, - "quiz": QuizGenerateRequest - } - for content_type in request.content_types: try: if content_type == "multiple_choice": diff --git a/backend-lehrer/worksheets_models.py b/backend-lehrer/worksheets_models.py new file mode 100644 index 0000000..87c6b25 --- /dev/null +++ b/backend-lehrer/worksheets_models.py @@ -0,0 +1,135 @@ +""" +Worksheets API - Pydantic Models und Helpers. + +Request-/Response-Models und Hilfsfunktionen fuer die +Arbeitsblatt-Generierungs-API. +""" + +import uuid +from datetime import datetime +from typing import List, Dict, Any, Optional +from enum import Enum + +from pydantic import BaseModel, Field + +from generators.mc_generator import Difficulty +from generators.cloze_generator import ClozeType +from generators.quiz_generator import QuizType + + +# ============================================================================ +# Pydantic Models +# ============================================================================ + +class ContentType(str, Enum): + """Verfügbare Content-Typen.""" + MULTIPLE_CHOICE = "multiple_choice" + CLOZE = "cloze" + MINDMAP = "mindmap" + QUIZ = "quiz" + + +class GenerateRequest(BaseModel): + """Basis-Request für Generierung.""" + source_text: str = Field(..., min_length=50, description="Quelltext für Generierung") + topic: Optional[str] = Field(None, description="Thema/Titel") + subject: Optional[str] = Field(None, description="Fach") + grade_level: Optional[str] = Field(None, description="Klassenstufe") + + +class MCGenerateRequest(GenerateRequest): + """Request für Multiple-Choice-Generierung.""" + num_questions: int = Field(5, ge=1, le=20, description="Anzahl Fragen") + difficulty: str = Field("medium", description="easy, medium, hard") + + +class ClozeGenerateRequest(GenerateRequest): + 
"""Request für Lückentext-Generierung.""" + num_gaps: int = Field(5, ge=1, le=15, description="Anzahl Lücken") + difficulty: str = Field("medium", description="easy, medium, hard") + cloze_type: str = Field("fill_in", description="fill_in, drag_drop, dropdown") + + +class MindmapGenerateRequest(GenerateRequest): + """Request für Mindmap-Generierung.""" + max_depth: int = Field(3, ge=2, le=5, description="Maximale Tiefe") + + +class QuizGenerateRequest(GenerateRequest): + """Request für Quiz-Generierung.""" + quiz_types: List[str] = Field( + ["true_false", "matching"], + description="Typen: true_false, matching, sorting, open_ended" + ) + num_items: int = Field(5, ge=1, le=10, description="Items pro Typ") + difficulty: str = Field("medium", description="easy, medium, hard") + + +class BatchGenerateRequest(BaseModel): + """Request für Batch-Generierung mehrerer Content-Typen.""" + source_text: str = Field(..., min_length=50) + content_types: List[str] = Field(..., description="Liste von Content-Typen") + topic: Optional[str] = None + subject: Optional[str] = None + grade_level: Optional[str] = None + difficulty: str = "medium" + + +class WorksheetContent(BaseModel): + """Generierter Content.""" + id: str + content_type: str + data: Dict[str, Any] + h5p_format: Optional[Dict[str, Any]] = None + created_at: datetime + topic: Optional[str] = None + difficulty: Optional[str] = None + + +class GenerateResponse(BaseModel): + """Response mit generiertem Content.""" + success: bool + content: Optional[WorksheetContent] = None + error: Optional[str] = None + + +class BatchGenerateResponse(BaseModel): + """Response für Batch-Generierung.""" + success: bool + contents: List[WorksheetContent] = [] + errors: List[str] = [] + + +# ============================================================================ +# Helper Functions +# ============================================================================ + +def parse_difficulty(difficulty_str: str) -> Difficulty: + """Konvertiert 
String zu Difficulty Enum.""" + mapping = { + "easy": Difficulty.EASY, + "medium": Difficulty.MEDIUM, + "hard": Difficulty.HARD + } + return mapping.get(difficulty_str.lower(), Difficulty.MEDIUM) + + +def parse_cloze_type(type_str: str) -> ClozeType: + """Konvertiert String zu ClozeType Enum.""" + mapping = { + "fill_in": ClozeType.FILL_IN, + "drag_drop": ClozeType.DRAG_DROP, + "dropdown": ClozeType.DROPDOWN + } + return mapping.get(type_str.lower(), ClozeType.FILL_IN) + + +def parse_quiz_types(type_strs: List[str]) -> List[QuizType]: + """Konvertiert String-Liste zu QuizType Enums.""" + mapping = { + "true_false": QuizType.TRUE_FALSE, + "matching": QuizType.MATCHING, + "sorting": QuizType.SORTING, + "open_ended": QuizType.OPEN_ENDED + } + return [mapping.get(t.lower(), QuizType.TRUE_FALSE) for t in type_strs] diff --git a/klausur-service/backend/cv_gutter_repair.py b/klausur-service/backend/cv_gutter_repair.py index 03c7bd1..fc6fc6c 100644 --- a/klausur-service/backend/cv_gutter_repair.py +++ b/klausur-service/backend/cv_gutter_repair.py @@ -1,610 +1,35 @@ """ -Gutter Repair — detects and fixes words truncated or blurred at the book gutter. +Gutter Repair — barrel re-export. -When scanning double-page spreads, the binding area (gutter) causes: - 1. Blurry/garbled trailing characters ("stammeli" → "stammeln") - 2. Words split across lines with a hyphen lost in the gutter - ("ve" + "künden" → "verkünden") - -This module analyses grid cells, identifies gutter-edge candidates, and -proposes corrections using pyspellchecker (DE + EN). +All implementation split into: + cv_gutter_repair_core — spellchecker setup, data types, single-word repair + cv_gutter_repair_grid — grid analysis, suggestion application Lizenz: Apache 2.0 (kommerziell nutzbar) DATENSCHUTZ: Alle Verarbeitung erfolgt lokal. 
""" -import itertools -import logging -import re -import time -import uuid -from dataclasses import dataclass, field, asdict -from typing import Any, Dict, List, Optional, Tuple - -logger = logging.getLogger(__name__) - -# --------------------------------------------------------------------------- -# Spellchecker setup (lazy, cached) -# --------------------------------------------------------------------------- - -_spell_de = None -_spell_en = None -_SPELL_AVAILABLE = False - -def _init_spellcheckers(): - """Lazy-load DE + EN spellcheckers (cached across calls).""" - global _spell_de, _spell_en, _SPELL_AVAILABLE - if _spell_de is not None: - return - try: - from spellchecker import SpellChecker - _spell_de = SpellChecker(language='de', distance=1) - _spell_en = SpellChecker(language='en', distance=1) - _SPELL_AVAILABLE = True - logger.info("Gutter repair: spellcheckers loaded (DE + EN)") - except ImportError: - logger.warning("pyspellchecker not installed — gutter repair unavailable") - - -def _is_known(word: str) -> bool: - """Check if a word is known in DE or EN dictionary.""" - _init_spellcheckers() - if not _SPELL_AVAILABLE: - return False - w = word.lower() - return bool(_spell_de.known([w])) or bool(_spell_en.known([w])) - - -def _spell_candidates(word: str, lang: str = "both") -> List[str]: - """Get all plausible spellchecker candidates for a word (deduplicated).""" - _init_spellcheckers() - if not _SPELL_AVAILABLE: - return [] - w = word.lower() - seen: set = set() - results: List[str] = [] - - for checker in ([_spell_de, _spell_en] if lang == "both" - else [_spell_de] if lang == "de" - else [_spell_en]): - if checker is None: - continue - cands = checker.candidates(w) - if cands: - for c in cands: - if c and c != w and c not in seen: - seen.add(c) - results.append(c) - - return results - - -# --------------------------------------------------------------------------- -# Gutter position detection -# 
--------------------------------------------------------------------------- - -# Minimum word length for spell-fix (very short words are often legitimate) -_MIN_WORD_LEN_SPELL = 3 - -# Minimum word length for hyphen-join candidates (fragments at the gutter -# can be as short as 1-2 chars, e.g. "ve" from "ver-künden") -_MIN_WORD_LEN_HYPHEN = 2 - -# How close to the right column edge a word must be to count as "gutter-adjacent". -# Expressed as fraction of column width (e.g. 0.75 = rightmost 25%). -_GUTTER_EDGE_THRESHOLD = 0.70 - -# Small common words / abbreviations that should NOT be repaired -_STOPWORDS = frozenset([ - # German - "ab", "an", "am", "da", "er", "es", "im", "in", "ja", "ob", "so", "um", - "zu", "wo", "du", "eh", "ei", "je", "na", "nu", "oh", - # English - "a", "am", "an", "as", "at", "be", "by", "do", "go", "he", "if", "in", - "is", "it", "me", "my", "no", "of", "on", "or", "so", "to", "up", "us", - "we", -]) - -# IPA / phonetic patterns — skip these cells -_IPA_RE = re.compile(r'[\[\]/ˈˌːʃʒθðŋɑɒæɔəɛɪʊʌ]') - - -def _is_ipa_text(text: str) -> bool: - """True if text looks like IPA transcription.""" - return bool(_IPA_RE.search(text)) - - -def _word_is_at_gutter_edge(word_bbox: Dict, col_x: float, col_width: float) -> bool: - """Check if a word's right edge is near the right boundary of its column.""" - if col_width <= 0: - return False - word_right = word_bbox.get("left", 0) + word_bbox.get("width", 0) - col_right = col_x + col_width - # Word's right edge within the rightmost portion of the column - relative_pos = (word_right - col_x) / col_width - return relative_pos >= _GUTTER_EDGE_THRESHOLD - - -# --------------------------------------------------------------------------- -# Suggestion types -# --------------------------------------------------------------------------- - -@dataclass -class GutterSuggestion: - """A single correction suggestion.""" - id: str = field(default_factory=lambda: str(uuid.uuid4())[:8]) - type: str = "" # "hyphen_join" | 
"spell_fix" - zone_index: int = 0 - row_index: int = 0 - col_index: int = 0 - col_type: str = "" - cell_id: str = "" - original_text: str = "" - suggested_text: str = "" - # For hyphen_join: - next_row_index: int = -1 - next_row_cell_id: str = "" - next_row_text: str = "" - missing_chars: str = "" - display_parts: List[str] = field(default_factory=list) - # Alternatives (other plausible corrections the user can pick from) - alternatives: List[str] = field(default_factory=list) - # Meta: - confidence: float = 0.0 - reason: str = "" # "gutter_truncation" | "gutter_blur" | "hyphen_continuation" - - def to_dict(self) -> Dict[str, Any]: - return asdict(self) - - -# --------------------------------------------------------------------------- -# Core repair logic -# --------------------------------------------------------------------------- - -_TRAILING_PUNCT_RE = re.compile(r'[.,;:!?\)\]]+$') - - -def _try_hyphen_join( - word_text: str, - next_word_text: str, - max_missing: int = 3, -) -> Optional[Tuple[str, str, float]]: - """Try joining two fragments with 0..max_missing interpolated chars. - - Strips trailing punctuation from the continuation word before testing - (e.g. "künden," → "künden") so dictionary lookup succeeds. - - Returns (joined_word, missing_chars, confidence) or None. - """ - base = word_text.rstrip("-").rstrip() - # Strip trailing punctuation from continuation (commas, periods, etc.) - raw_continuation = next_word_text.lstrip() - continuation = _TRAILING_PUNCT_RE.sub('', raw_continuation) - - if not base or not continuation: - return None - - # 1. Direct join (no missing chars) - direct = base + continuation - if _is_known(direct): - return (direct, "", 0.95) - - # 2. 
Try with 1..max_missing missing characters - # Use common letters, weighted by frequency in German/English - _COMMON_CHARS = "enristaldhgcmobwfkzpvjyxqu" - - for n_missing in range(1, max_missing + 1): - for chars in itertools.product(_COMMON_CHARS[:15], repeat=n_missing): - candidate = base + "".join(chars) + continuation - if _is_known(candidate): - missing = "".join(chars) - # Confidence decreases with more missing chars - conf = 0.90 - (n_missing - 1) * 0.10 - return (candidate, missing, conf) - - return None - - -def _try_spell_fix( - word_text: str, col_type: str = "", -) -> Optional[Tuple[str, float, List[str]]]: - """Try to fix a single garbled gutter word via spellchecker. - - Returns (best_correction, confidence, alternatives_list) or None. - The alternatives list contains other plausible corrections the user - can choose from (e.g. "stammelt" vs "stammeln"). - """ - if len(word_text) < _MIN_WORD_LEN_SPELL: - return None - - # Strip trailing/leading parentheses and check if the bare word is valid. - # Words like "probieren)" or "(Englisch" are valid words with punctuation, - # not OCR errors. Don't suggest corrections for them. 
- stripped = word_text.strip("()") - if stripped and _is_known(stripped): - return None - - # Determine language priority from column type - if "en" in col_type: - lang = "en" - elif "de" in col_type: - lang = "de" - else: - lang = "both" - - candidates = _spell_candidates(word_text, lang=lang) - if not candidates and lang != "both": - candidates = _spell_candidates(word_text, lang="both") - - if not candidates: - return None - - # Preserve original casing - is_upper = word_text[0].isupper() - - def _preserve_case(w: str) -> str: - if is_upper and w: - return w[0].upper() + w[1:] - return w - - # Sort candidates by edit distance (closest first) - scored = [] - for c in candidates: - dist = _edit_distance(word_text.lower(), c.lower()) - scored.append((dist, c)) - scored.sort(key=lambda x: x[0]) - - best_dist, best = scored[0] - best = _preserve_case(best) - conf = max(0.5, 1.0 - best_dist * 0.15) - - # Build alternatives (all other candidates, also case-preserved) - alts = [_preserve_case(c) for _, c in scored[1:] if c.lower() != best.lower()] - # Limit to top 5 alternatives - alts = alts[:5] - - return (best, conf, alts) - - -def _edit_distance(a: str, b: str) -> int: - """Simple Levenshtein distance.""" - if len(a) < len(b): - return _edit_distance(b, a) - if len(b) == 0: - return len(a) - prev = list(range(len(b) + 1)) - for i, ca in enumerate(a): - curr = [i + 1] - for j, cb in enumerate(b): - cost = 0 if ca == cb else 1 - curr.append(min(curr[j] + 1, prev[j + 1] + 1, prev[j] + cost)) - prev = curr - return prev[len(b)] - - -# --------------------------------------------------------------------------- -# Grid analysis -# --------------------------------------------------------------------------- - -def analyse_grid_for_gutter_repair( - grid_data: Dict[str, Any], - image_width: int = 0, -) -> Dict[str, Any]: - """Analyse a structured grid and return gutter repair suggestions. - - Args: - grid_data: The grid_editor_result from the session (zones→cells structure). 
- image_width: Image width in pixels (for determining gutter side). - - Returns: - Dict with "suggestions" list and "stats". - """ - t0 = time.time() - _init_spellcheckers() - - if not _SPELL_AVAILABLE: - return { - "suggestions": [], - "stats": {"error": "pyspellchecker not installed"}, - "duration_seconds": 0, - } - - zones = grid_data.get("zones", []) - suggestions: List[GutterSuggestion] = [] - words_checked = 0 - gutter_candidates = 0 - - for zi, zone in enumerate(zones): - columns = zone.get("columns", []) - cells = zone.get("cells", []) - if not columns or not cells: - continue - - # Build column lookup: col_index → {x, width, type} - col_info: Dict[int, Dict] = {} - for col in columns: - ci = col.get("index", col.get("col_index", -1)) - col_info[ci] = { - "x": col.get("x_min_px", col.get("x", 0)), - "width": col.get("x_max_px", col.get("width", 0)) - col.get("x_min_px", col.get("x", 0)), - "type": col.get("type", col.get("col_type", "")), - } - - # Build row→col→cell lookup - cell_map: Dict[Tuple[int, int], Dict] = {} - max_row = 0 - for cell in cells: - ri = cell.get("row_index", 0) - ci = cell.get("col_index", 0) - cell_map[(ri, ci)] = cell - if ri > max_row: - max_row = ri - - # Determine which columns are at the gutter edge. - # For a left page: rightmost content columns. - # For now, check ALL columns — a word is a candidate if it's at the - # right edge of its column AND not a known word. 
- for (ri, ci), cell in cell_map.items(): - text = (cell.get("text") or "").strip() - if not text: - continue - if _is_ipa_text(text): - continue - - words_checked += 1 - col = col_info.get(ci, {}) - col_type = col.get("type", "") - - # Get word boxes to check position - word_boxes = cell.get("word_boxes", []) - - # Check the LAST word in the cell (rightmost, closest to gutter) - cell_words = text.split() - if not cell_words: - continue - - last_word = cell_words[-1] - - # Skip stopwords - if last_word.lower().rstrip(".,;:!?-") in _STOPWORDS: - continue - - last_word_clean = last_word.rstrip(".,;:!?)(") - if len(last_word_clean) < _MIN_WORD_LEN_HYPHEN: - continue - - # Check if the last word is at the gutter edge - is_at_edge = False - if word_boxes: - last_wb = word_boxes[-1] - is_at_edge = _word_is_at_gutter_edge( - last_wb, col.get("x", 0), col.get("width", 1) - ) - else: - # No word boxes — use cell bbox - bbox = cell.get("bbox_px", {}) - is_at_edge = _word_is_at_gutter_edge( - {"left": bbox.get("x", 0), "width": bbox.get("w", 0)}, - col.get("x", 0), col.get("width", 1) - ) - - if not is_at_edge: - continue - - # Word is at gutter edge — check if it's a known word - if _is_known(last_word_clean): - continue - - # Check if the word ends with "-" (explicit hyphen break) - ends_with_hyphen = last_word.endswith("-") - - # If the word already ends with "-" and the stem (without - # the hyphen) is a known word, this is a VALID line-break - # hyphenation — not a gutter error. Gutter problems cause - # the hyphen to be LOST ("ve" instead of "ver-"), so a - # visible hyphen + known stem = intentional word-wrap. - # Example: "wunder-" → "wunder" is known → skip. 
- if ends_with_hyphen: - stem = last_word_clean.rstrip("-") - if stem and _is_known(stem): - continue - - gutter_candidates += 1 - - # --- Strategy 1: Hyphen join with next row --- - next_cell = cell_map.get((ri + 1, ci)) - if next_cell: - next_text = (next_cell.get("text") or "").strip() - next_words = next_text.split() - if next_words: - first_next = next_words[0] - first_next_clean = _TRAILING_PUNCT_RE.sub('', first_next) - first_alpha = next((c for c in first_next if c.isalpha()), "") - - # Also skip if the joined word is known (covers compound - # words where the stem alone might not be in the dictionary) - if ends_with_hyphen and first_next_clean: - direct = last_word_clean.rstrip("-") + first_next_clean - if _is_known(direct): - continue - - # Continuation likely if: - # - explicit hyphen, OR - # - next row starts lowercase (= not a new entry) - if ends_with_hyphen or (first_alpha and first_alpha.islower()): - result = _try_hyphen_join(last_word_clean, first_next) - if result: - joined, missing, conf = result - # Build display parts: show hyphenation for original layout - if ends_with_hyphen: - display_p1 = last_word_clean.rstrip("-") - if missing: - display_p1 += missing - display_p1 += "-" - else: - display_p1 = last_word_clean - if missing: - display_p1 += missing + "-" - else: - display_p1 += "-" - - suggestion = GutterSuggestion( - type="hyphen_join", - zone_index=zi, - row_index=ri, - col_index=ci, - col_type=col_type, - cell_id=cell.get("cell_id", f"R{ri:02d}_C{ci}"), - original_text=last_word, - suggested_text=joined, - next_row_index=ri + 1, - next_row_cell_id=next_cell.get("cell_id", f"R{ri+1:02d}_C{ci}"), - next_row_text=next_text, - missing_chars=missing, - display_parts=[display_p1, first_next], - confidence=conf, - reason="gutter_truncation" if missing else "hyphen_continuation", - ) - suggestions.append(suggestion) - continue # skip spell_fix if hyphen_join found - - # --- Strategy 2: Single-word spell fix (only for longer words) --- - 
fix_result = _try_spell_fix(last_word_clean, col_type) - if fix_result: - corrected, conf, alts = fix_result - suggestion = GutterSuggestion( - type="spell_fix", - zone_index=zi, - row_index=ri, - col_index=ci, - col_type=col_type, - cell_id=cell.get("cell_id", f"R{ri:02d}_C{ci}"), - original_text=last_word, - suggested_text=corrected, - alternatives=alts, - confidence=conf, - reason="gutter_blur", - ) - suggestions.append(suggestion) - - duration = round(time.time() - t0, 3) - - logger.info( - "Gutter repair: checked %d words, %d gutter candidates, %d suggestions (%.2fs)", - words_checked, gutter_candidates, len(suggestions), duration, - ) - - return { - "suggestions": [s.to_dict() for s in suggestions], - "stats": { - "words_checked": words_checked, - "gutter_candidates": gutter_candidates, - "suggestions_found": len(suggestions), - }, - "duration_seconds": duration, - } - - -def apply_gutter_suggestions( - grid_data: Dict[str, Any], - accepted_ids: List[str], - suggestions: List[Dict[str, Any]], -) -> Dict[str, Any]: - """Apply accepted gutter repair suggestions to the grid data. - - Modifies cells in-place and returns summary of changes. - - Args: - grid_data: The grid_editor_result (zones→cells). - accepted_ids: List of suggestion IDs the user accepted. - suggestions: The full suggestions list (from analyse_grid_for_gutter_repair). - - Returns: - Dict with "applied_count" and "changes" list. 
- """ - accepted_set = set(accepted_ids) - accepted_suggestions = [s for s in suggestions if s.get("id") in accepted_set] - - zones = grid_data.get("zones", []) - changes: List[Dict[str, Any]] = [] - - for s in accepted_suggestions: - zi = s.get("zone_index", 0) - ri = s.get("row_index", 0) - ci = s.get("col_index", 0) - stype = s.get("type", "") - - if zi >= len(zones): - continue - zone_cells = zones[zi].get("cells", []) - - # Find the target cell - target_cell = None - for cell in zone_cells: - if cell.get("row_index") == ri and cell.get("col_index") == ci: - target_cell = cell - break - - if not target_cell: - continue - - old_text = target_cell.get("text", "") - - if stype == "spell_fix": - # Replace the last word in the cell text - original_word = s.get("original_text", "") - corrected = s.get("suggested_text", "") - if original_word and corrected: - # Replace from the right (last occurrence) - idx = old_text.rfind(original_word) - if idx >= 0: - new_text = old_text[:idx] + corrected + old_text[idx + len(original_word):] - target_cell["text"] = new_text - changes.append({ - "type": "spell_fix", - "zone_index": zi, - "row_index": ri, - "col_index": ci, - "cell_id": target_cell.get("cell_id", ""), - "old_text": old_text, - "new_text": new_text, - }) - - elif stype == "hyphen_join": - # Current cell: replace last word with the hyphenated first part - original_word = s.get("original_text", "") - joined = s.get("suggested_text", "") - display_parts = s.get("display_parts", []) - next_ri = s.get("next_row_index", -1) - - if not original_word or not joined or not display_parts: - continue - - # The first display part is what goes in the current row - first_part = display_parts[0] if display_parts else "" - - # Replace the last word in current cell with the restored form. - # The next row is NOT modified — "künden" stays in its row - # because the original book layout has it there. We only fix - # the truncated word in the current row (e.g. "ve" → "ver-"). 
- idx = old_text.rfind(original_word) - if idx >= 0: - new_text = old_text[:idx] + first_part + old_text[idx + len(original_word):] - target_cell["text"] = new_text - changes.append({ - "type": "hyphen_join", - "zone_index": zi, - "row_index": ri, - "col_index": ci, - "cell_id": target_cell.get("cell_id", ""), - "old_text": old_text, - "new_text": new_text, - "joined_word": joined, - }) - - logger.info("Gutter repair applied: %d/%d suggestions", len(changes), len(accepted_suggestions)) - - return { - "applied_count": len(accepted_suggestions), - "changes": changes, - } +# Core: spellchecker, data types, repair helpers +from cv_gutter_repair_core import ( # noqa: F401 + _init_spellcheckers, + _is_known, + _spell_candidates, + _MIN_WORD_LEN_SPELL, + _MIN_WORD_LEN_HYPHEN, + _GUTTER_EDGE_THRESHOLD, + _STOPWORDS, + _IPA_RE, + _is_ipa_text, + _word_is_at_gutter_edge, + GutterSuggestion, + _TRAILING_PUNCT_RE, + _try_hyphen_join, + _try_spell_fix, + _edit_distance, +) + +# Grid: analysis and application +from cv_gutter_repair_grid import ( # noqa: F401 + analyse_grid_for_gutter_repair, + apply_gutter_suggestions, +) diff --git a/klausur-service/backend/cv_gutter_repair_core.py b/klausur-service/backend/cv_gutter_repair_core.py new file mode 100644 index 0000000..4387e88 --- /dev/null +++ b/klausur-service/backend/cv_gutter_repair_core.py @@ -0,0 +1,275 @@ +""" +Gutter Repair Core — spellchecker setup, data types, and single-word repair logic. + +Extracted from cv_gutter_repair.py for modularity. + +Lizenz: Apache 2.0 (kommerziell nutzbar) +DATENSCHUTZ: Alle Verarbeitung erfolgt lokal. 
+""" + +import itertools +import logging +import re +import uuid +from dataclasses import dataclass, field, asdict +from typing import Any, Dict, List, Optional, Tuple + +logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# Spellchecker setup (lazy, cached) +# --------------------------------------------------------------------------- + +_spell_de = None +_spell_en = None +_SPELL_AVAILABLE = False + +def _init_spellcheckers(): + """Lazy-load DE + EN spellcheckers (cached across calls).""" + global _spell_de, _spell_en, _SPELL_AVAILABLE + if _spell_de is not None: + return + try: + from spellchecker import SpellChecker + _spell_de = SpellChecker(language='de', distance=1) + _spell_en = SpellChecker(language='en', distance=1) + _SPELL_AVAILABLE = True + logger.info("Gutter repair: spellcheckers loaded (DE + EN)") + except ImportError: + logger.warning("pyspellchecker not installed — gutter repair unavailable") + + +def _is_known(word: str) -> bool: + """Check if a word is known in DE or EN dictionary.""" + _init_spellcheckers() + if not _SPELL_AVAILABLE: + return False + w = word.lower() + return bool(_spell_de.known([w])) or bool(_spell_en.known([w])) + + +def _spell_candidates(word: str, lang: str = "both") -> List[str]: + """Get all plausible spellchecker candidates for a word (deduplicated).""" + _init_spellcheckers() + if not _SPELL_AVAILABLE: + return [] + w = word.lower() + seen: set = set() + results: List[str] = [] + + for checker in ([_spell_de, _spell_en] if lang == "both" + else [_spell_de] if lang == "de" + else [_spell_en]): + if checker is None: + continue + cands = checker.candidates(w) + if cands: + for c in cands: + if c and c != w and c not in seen: + seen.add(c) + results.append(c) + + return results + + +# --------------------------------------------------------------------------- +# Gutter position detection +# 
--------------------------------------------------------------------------- + +# Minimum word length for spell-fix (very short words are often legitimate) +_MIN_WORD_LEN_SPELL = 3 + +# Minimum word length for hyphen-join candidates (fragments at the gutter +# can be as short as 1-2 chars, e.g. "ve" from "ver-künden") +_MIN_WORD_LEN_HYPHEN = 2 + +# How close to the right column edge a word must be to count as "gutter-adjacent". +# Expressed as fraction of column width (e.g. 0.70 = rightmost 30%). +_GUTTER_EDGE_THRESHOLD = 0.70 + +# Small common words / abbreviations that should NOT be repaired +_STOPWORDS = frozenset([ + # German + "ab", "an", "am", "da", "er", "es", "im", "in", "ja", "ob", "so", "um", + "zu", "wo", "du", "eh", "ei", "je", "na", "nu", "oh", + # English + "a", "am", "an", "as", "at", "be", "by", "do", "go", "he", "if", "in", + "is", "it", "me", "my", "no", "of", "on", "or", "so", "to", "up", "us", + "we", +]) + +# IPA / phonetic patterns — skip these cells +_IPA_RE = re.compile(r'[\[\]/ˈˌːʃʒθðŋɑɒæɔəɛɪʊʌ]') + + +def _is_ipa_text(text: str) -> bool: + """True if text looks like IPA transcription.""" + return bool(_IPA_RE.search(text)) + + +def _word_is_at_gutter_edge(word_bbox: Dict, col_x: float, col_width: float) -> bool: + """Check if a word's right edge is near the right boundary of its column.""" + if col_width <= 0: + return False + word_right = word_bbox.get("left", 0) + word_bbox.get("width", 0) + col_right = col_x + col_width + # Word's right edge within the rightmost portion of the column + relative_pos = (word_right - col_x) / col_width + return relative_pos >= _GUTTER_EDGE_THRESHOLD + + +# --------------------------------------------------------------------------- +# Suggestion types +# --------------------------------------------------------------------------- + +@dataclass +class GutterSuggestion: + """A single correction suggestion.""" + id: str = field(default_factory=lambda: str(uuid.uuid4())[:8]) + type: str = "" # "hyphen_join" | 
"spell_fix" + zone_index: int = 0 + row_index: int = 0 + col_index: int = 0 + col_type: str = "" + cell_id: str = "" + original_text: str = "" + suggested_text: str = "" + # For hyphen_join: + next_row_index: int = -1 + next_row_cell_id: str = "" + next_row_text: str = "" + missing_chars: str = "" + display_parts: List[str] = field(default_factory=list) + # Alternatives (other plausible corrections the user can pick from) + alternatives: List[str] = field(default_factory=list) + # Meta: + confidence: float = 0.0 + reason: str = "" # "gutter_truncation" | "gutter_blur" | "hyphen_continuation" + + def to_dict(self) -> Dict[str, Any]: + return asdict(self) + + +# --------------------------------------------------------------------------- +# Core repair logic +# --------------------------------------------------------------------------- + +_TRAILING_PUNCT_RE = re.compile(r'[.,;:!?\)\]]+$') + + +def _try_hyphen_join( + word_text: str, + next_word_text: str, + max_missing: int = 3, +) -> Optional[Tuple[str, str, float]]: + """Try joining two fragments with 0..max_missing interpolated chars. + + Strips trailing punctuation from the continuation word before testing + (e.g. "künden," → "künden") so dictionary lookup succeeds. + + Returns (joined_word, missing_chars, confidence) or None. + """ + base = word_text.rstrip("-").rstrip() + # Strip trailing punctuation from continuation (commas, periods, etc.) + raw_continuation = next_word_text.lstrip() + continuation = _TRAILING_PUNCT_RE.sub('', raw_continuation) + + if not base or not continuation: + return None + + # 1. Direct join (no missing chars) + direct = base + continuation + if _is_known(direct): + return (direct, "", 0.95) + + # 2. 
Try with 1..max_missing missing characters + # Use common letters, weighted by frequency in German/English + _COMMON_CHARS = "enristaldhgcmobwfkzpvjyxqu" + + for n_missing in range(1, max_missing + 1): + for chars in itertools.product(_COMMON_CHARS[:15], repeat=n_missing): + candidate = base + "".join(chars) + continuation + if _is_known(candidate): + missing = "".join(chars) + # Confidence decreases with more missing chars + conf = 0.90 - (n_missing - 1) * 0.10 + return (candidate, missing, conf) + + return None + + +def _try_spell_fix( + word_text: str, col_type: str = "", +) -> Optional[Tuple[str, float, List[str]]]: + """Try to fix a single garbled gutter word via spellchecker. + + Returns (best_correction, confidence, alternatives_list) or None. + The alternatives list contains other plausible corrections the user + can choose from (e.g. "stammelt" vs "stammeln"). + """ + if len(word_text) < _MIN_WORD_LEN_SPELL: + return None + + # Strip trailing/leading parentheses and check if the bare word is valid. + # Words like "probieren)" or "(Englisch" are valid words with punctuation, + # not OCR errors. Don't suggest corrections for them. 
+ stripped = word_text.strip("()") + if stripped and _is_known(stripped): + return None + + # Determine language priority from column type + if "en" in col_type: + lang = "en" + elif "de" in col_type: + lang = "de" + else: + lang = "both" + + candidates = _spell_candidates(word_text, lang=lang) + if not candidates and lang != "both": + candidates = _spell_candidates(word_text, lang="both") + + if not candidates: + return None + + # Preserve original casing + is_upper = word_text[0].isupper() + + def _preserve_case(w: str) -> str: + if is_upper and w: + return w[0].upper() + w[1:] + return w + + # Sort candidates by edit distance (closest first) + scored = [] + for c in candidates: + dist = _edit_distance(word_text.lower(), c.lower()) + scored.append((dist, c)) + scored.sort(key=lambda x: x[0]) + + best_dist, best = scored[0] + best = _preserve_case(best) + conf = max(0.5, 1.0 - best_dist * 0.15) + + # Build alternatives (all other candidates, also case-preserved) + alts = [_preserve_case(c) for _, c in scored[1:] if c.lower() != best.lower()] + # Limit to top 5 alternatives + alts = alts[:5] + + return (best, conf, alts) + + +def _edit_distance(a: str, b: str) -> int: + """Simple Levenshtein distance.""" + if len(a) < len(b): + return _edit_distance(b, a) + if len(b) == 0: + return len(a) + prev = list(range(len(b) + 1)) + for i, ca in enumerate(a): + curr = [i + 1] + for j, cb in enumerate(b): + cost = 0 if ca == cb else 1 + curr.append(min(curr[j] + 1, prev[j + 1] + 1, prev[j] + cost)) + prev = curr + return prev[len(b)] diff --git a/klausur-service/backend/cv_gutter_repair_grid.py b/klausur-service/backend/cv_gutter_repair_grid.py new file mode 100644 index 0000000..caf7c0f --- /dev/null +++ b/klausur-service/backend/cv_gutter_repair_grid.py @@ -0,0 +1,356 @@ +""" +Gutter Repair Grid — grid analysis and suggestion application. + +Extracted from cv_gutter_repair.py for modularity. 
"""
Gutter Repair Grid — grid analysis and suggestion application.

Extracted from cv_gutter_repair.py for modularity.

License: Apache 2.0 (commercial use permitted)
PRIVACY: all processing happens locally.
"""

import logging
import time
from typing import Any, Dict, List, Optional, Tuple

import cv_gutter_repair_core as _core
from cv_gutter_repair_core import (
    _init_spellcheckers,
    _is_ipa_text,
    _is_known,
    _MIN_WORD_LEN_HYPHEN,
    _SPELL_AVAILABLE,
    _STOPWORDS,
    _TRAILING_PUNCT_RE,
    _try_hyphen_join,
    _try_spell_fix,
    _word_is_at_gutter_edge,
    GutterSuggestion,
)

logger = logging.getLogger(__name__)


# ---------------------------------------------------------------------------
# Grid analysis
# ---------------------------------------------------------------------------

def _column_lookup(columns: List[Dict[str, Any]]) -> Dict[int, Dict[str, Any]]:
    """Build ``col_index → {"x", "width", "type"}`` from a zone's column list.

    Two column schemas are accepted: ``x_min_px``/``x_max_px`` (width is the
    difference) and ``x``/``width`` (width is used as given).
    """
    lookup: Dict[int, Dict[str, Any]] = {}
    for col in columns:
        ci = col.get("index", col.get("col_index", -1))
        x_min = col.get("x_min_px", col.get("x", 0))
        if "x_max_px" in col:
            width = col["x_max_px"] - x_min
        else:
            # "width" already IS a width — subtracting x from it would shrink
            # (or even negate) the column and break the gutter-edge test.
            width = col.get("width", 0)
        lookup[ci] = {
            "x": x_min,
            "width": width,
            "type": col.get("type", col.get("col_type", "")),
        }
    return lookup


def _last_word_at_gutter_edge(cell: Dict[str, Any], col: Dict[str, Any]) -> bool:
    """Check whether the cell's last word box (or the cell bbox) touches the gutter edge."""
    col_x = col.get("x", 0)
    col_width = col.get("width", 1)
    word_boxes = cell.get("word_boxes") or []
    if word_boxes:
        # Last word box = last word of the cell, i.e. the one closest to the gutter.
        return _word_is_at_gutter_edge(word_boxes[-1], col_x, col_width)
    # No word boxes — fall back to the cell bounding box.
    bbox = cell.get("bbox_px") or {}
    return _word_is_at_gutter_edge(
        {"left": bbox.get("x", 0), "width": bbox.get("w", 0)},
        col_x, col_width,
    )


def analyse_grid_for_gutter_repair(
    grid_data: Dict[str, Any],
    image_width: int = 0,
) -> Dict[str, Any]:
    """Analyse a structured grid and return gutter repair suggestions.

    For every non-empty, non-IPA cell the LAST word is inspected.  It becomes
    a gutter candidate when it sits at the right edge of its column, is not a
    stopword and is not a known word.  Candidates are repaired either by
    joining with the first word of the cell below (``hyphen_join``) or by a
    single-word spell fix (``spell_fix``).

    Args:
        grid_data: The grid_editor_result from the session (zones→cells structure).
        image_width: Image width in pixels.  Currently not used by this function.

    Returns:
        Dict with "suggestions" (list of suggestion dicts), "stats" and
        "duration_seconds".
    """
    t0 = time.time()
    _init_spellcheckers()

    # Read the flag from the module at call time.  The name imported via
    # ``from ... import _SPELL_AVAILABLE`` is bound once at import.
    # NOTE(review): if _init_spellcheckers() updates the flag lazily, the
    # from-imported copy would be stale — confirm against cv_gutter_repair_core.
    if not _core._SPELL_AVAILABLE:
        return {
            "suggestions": [],
            "stats": {"error": "pyspellchecker not installed"},
            "duration_seconds": 0,
        }

    zones = grid_data.get("zones", [])
    suggestions: List[GutterSuggestion] = []
    words_checked = 0
    gutter_candidates = 0

    for zi, zone in enumerate(zones):
        columns = zone.get("columns", [])
        cells = zone.get("cells", [])
        if not columns or not cells:
            continue

        col_info = _column_lookup(columns)

        # Build (row, col) → cell lookup
        cell_map: Dict[Tuple[int, int], Dict] = {}
        for cell in cells:
            cell_map[(cell.get("row_index", 0), cell.get("col_index", 0))] = cell

        # All columns are checked — a word is a candidate if it is at the
        # right edge of its column AND not a known word.
        for (ri, ci), cell in cell_map.items():
            text = (cell.get("text") or "").strip()
            if not text or _is_ipa_text(text):
                continue

            words_checked += 1
            col = col_info.get(ci, {})
            col_type = col.get("type", "")

            # text is non-empty after strip(), so split() yields >= 1 token.
            last_word = text.split()[-1]

            # Skip stopwords
            if last_word.lower().rstrip(".,;:!?-") in _STOPWORDS:
                continue

            last_word_clean = last_word.rstrip(".,;:!?)(")
            if len(last_word_clean) < _MIN_WORD_LEN_HYPHEN:
                continue

            if not _last_word_at_gutter_edge(cell, col):
                continue

            # Word is at gutter edge — a known word needs no repair
            if _is_known(last_word_clean):
                continue

            # Explicit hyphen break?
            ends_with_hyphen = last_word.endswith("-")
            stem = last_word_clean.rstrip("-")

            # A visible hyphen plus a known stem is a VALID line-break
            # hyphenation, not a gutter error: gutter problems cause the
            # hyphen to be LOST ("ve" instead of "ver-").
            # Example: "wunder-" → "wunder" is known → skip.
            if ends_with_hyphen and stem and _is_known(stem):
                continue

            gutter_candidates += 1
            cell_id = cell.get("cell_id", f"R{ri:02d}_C{ci}")

            # --- Strategy 1: Hyphen join with next row ---
            next_cell = cell_map.get((ri + 1, ci))
            if next_cell:
                next_text = (next_cell.get("text") or "").strip()
                next_words = next_text.split()
                if next_words:
                    first_next = next_words[0]
                    first_next_clean = _TRAILING_PUNCT_RE.sub('', first_next)
                    first_alpha = next((c for c in first_next if c.isalpha()), "")

                    # Also skip if the joined word is known (covers compound
                    # words where the stem alone might not be in the dictionary)
                    if ends_with_hyphen and first_next_clean:
                        if _is_known(stem + first_next_clean):
                            continue

                    # Continuation likely if:
                    #   - explicit hyphen, OR
                    #   - next row starts lowercase (= not a new entry)
                    if ends_with_hyphen or (first_alpha and first_alpha.islower()):
                        result = _try_hyphen_join(last_word_clean, first_next)
                        if result:
                            joined, missing, conf = result
                            # Restored first part as it appears in the
                            # original layout: stem + lost chars + one hyphen
                            # (never a doubled "--").
                            display_p1 = stem + missing + "-"

                            suggestions.append(GutterSuggestion(
                                type="hyphen_join",
                                zone_index=zi,
                                row_index=ri,
                                col_index=ci,
                                col_type=col_type,
                                cell_id=cell_id,
                                original_text=last_word,
                                suggested_text=joined,
                                next_row_index=ri + 1,
                                next_row_cell_id=next_cell.get("cell_id", f"R{ri+1:02d}_C{ci}"),
                                next_row_text=next_text,
                                missing_chars=missing,
                                display_parts=[display_p1, first_next],
                                confidence=conf,
                                reason="gutter_truncation" if missing else "hyphen_continuation",
                            ))
                            continue  # skip spell_fix if hyphen_join found

            # --- Strategy 2: Single-word spell fix (only for longer words) ---
            fix_result = _try_spell_fix(last_word_clean, col_type)
            if fix_result:
                corrected, conf, alts = fix_result
                suggestions.append(GutterSuggestion(
                    type="spell_fix",
                    zone_index=zi,
                    row_index=ri,
                    col_index=ci,
                    col_type=col_type,
                    cell_id=cell_id,
                    original_text=last_word,
                    suggested_text=corrected,
                    alternatives=alts,
                    confidence=conf,
                    reason="gutter_blur",
                ))

    duration = round(time.time() - t0, 3)

    logger.info(
        "Gutter repair: checked %d words, %d gutter candidates, %d suggestions (%.2fs)",
        words_checked, gutter_candidates, len(suggestions), duration,
    )

    return {
        "suggestions": [s.to_dict() for s in suggestions],
        "stats": {
            "words_checked": words_checked,
            "gutter_candidates": gutter_candidates,
            "suggestions_found": len(suggestions),
        },
        "duration_seconds": duration,
    }


def _replace_last_occurrence(text: str, old: str, new: str) -> Optional[str]:
    """Replace the right-most occurrence of ``old`` in ``text``; None if absent."""
    idx = text.rfind(old)
    if idx < 0:
        return None
    return text[:idx] + new + text[idx + len(old):]


def apply_gutter_suggestions(
    grid_data: Dict[str, Any],
    accepted_ids: List[str],
    suggestions: List[Dict[str, Any]],
) -> Dict[str, Any]:
    """Apply accepted gutter repair suggestions to the grid data.

    Modifies cell ``"text"`` in-place and returns a summary of the changes.
    Only the cell of the current row is touched; for ``hyphen_join`` the next
    row keeps its text because the original book layout has the continuation
    there (only the truncated fragment is restored, e.g. "ve" → "ver-").

    Args:
        grid_data: The grid_editor_result (zones→cells).
        accepted_ids: List of suggestion IDs the user accepted.
        suggestions: The full suggestions list (from analyse_grid_for_gutter_repair).

    Returns:
        Dict with "applied_count" (number of suggestions that actually changed
        a cell) and the "changes" list.
    """
    accepted_set = set(accepted_ids)
    accepted_suggestions = [s for s in suggestions if s.get("id") in accepted_set]

    zones = grid_data.get("zones", [])
    changes: List[Dict[str, Any]] = []

    for s in accepted_suggestions:
        zi = s.get("zone_index", 0)
        ri = s.get("row_index", 0)
        ci = s.get("col_index", 0)
        stype = s.get("type", "")

        # Reject out-of-range indices, including negatives which would
        # otherwise silently address a zone from the end of the list.
        if not 0 <= zi < len(zones):
            continue

        # Find the target cell
        target_cell = next(
            (
                cell for cell in zones[zi].get("cells", [])
                if cell.get("row_index") == ri and cell.get("col_index") == ci
            ),
            None,
        )
        if not target_cell:
            continue

        old_text = target_cell.get("text") or ""
        original_word = s.get("original_text", "")
        extra: Dict[str, Any] = {}

        if stype == "spell_fix":
            replacement = s.get("suggested_text", "")
            if not original_word or not replacement:
                continue
        elif stype == "hyphen_join":
            joined = s.get("suggested_text", "")
            display_parts = s.get("display_parts", [])
            if not original_word or not joined or not display_parts:
                continue
            # The first display part is what goes in the current row.
            replacement = display_parts[0]
            extra["joined_word"] = joined
        else:
            continue

        # Replace from the right: the suggestion refers to the cell's last word.
        new_text = _replace_last_occurrence(old_text, original_word, replacement)
        if new_text is None:
            continue

        target_cell["text"] = new_text
        changes.append({
            "type": stype,
            "zone_index": zi,
            "row_index": ri,
            "col_index": ci,
            "cell_id": target_cell.get("cell_id", ""),
            "old_text": old_text,
            "new_text": new_text,
            **extra,
        })

    logger.info("Gutter repair applied: %d/%d suggestions", len(changes), len(accepted_suggestions))

    return {
        # Count what was really applied, matching the log line above; accepted
        # suggestions whose cell or word could not be found are not "applied".
        "applied_count": len(changes),
        "changes": changes,
    }


# ---------------------------------------------------------------------------
# The patch continues with the new file
# klausur-service/backend/cv_syllable_core.py; its header follows.
# ---------------------------------------------------------------------------

"""
Syllable Core — hyphenator init, word validation, pipe autocorrect.

Extracted from cv_syllable_detect.py for modularity.

License: Apache 2.0 (commercial use permitted)
PRIVACY: all processing happens locally.
"""

import logging
import re
from typing import Any, Dict, List, Optional, Tuple

logger = logging.getLogger(__name__)

# IPA/phonetic characters -- skip cells containing these
_IPA_RE = re.compile(r'[\[\]\u02c8\u02cc\u02d0\u0283\u0292\u03b8\u00f0\u014b\u0251\u0252\u00e6\u0254\u0259\u025b\u025c\u026a\u028a\u028c]')

# Common German words that should NOT be merged with adjacent tokens.
_STOP_WORDS = frozenset([
    # Articles
    'der', 'die', 'das', 'dem', 'den', 'des',
    'ein', 'eine', 'einem', 'einen', 'einer',
    # Pronouns
    'du', 'er', 'es', 'sie', 'wir', 'ihr', 'ich', 'man', 'sich',
    'dich', 'dir', 'mich', 'mir', 'uns', 'euch', 'ihm', 'ihn',
    # Prepositions
    'mit', 'von', 'zu', 'f\u00fcr', 'auf', 'in', 'an', 'um', 'am', 'im',
    'aus', 'bei', 'nach', 'vor', 'bis', 'durch', '\u00fcber', 'unter',
    'zwischen', 'ohne', 'gegen',
    # Conjunctions
    'und', 'oder', 'als', 'wie', 'wenn', 'dass', 'weil', 'aber',
    # Adverbs
    'auch', 'noch', 'nur', 'schon', 'sehr', 'nicht',
    # Verbs
    'ist', 'hat', 'wird', 'kann', 'soll', 'muss', 'darf',
    'sein', 'haben',
    # Other
    'kein', 'keine', 'keinem', 'keinen', 'keiner',
])

# Cached hyphenators
_hyph_de = None
_hyph_en = None

# Cached spellchecker (for autocorrect_pipe_artifacts)
_spell_de = None


def _get_hyphenators():
    """Lazy-load pyphen hyphenators (cached across calls).

    Returns:
        ``(hyph_de, hyph_en)``, or ``(None, None)`` when pyphen cannot be
        imported.  Errors raised while constructing a hyphenator propagate
        to the caller and leave the cache untouched.
    """
    global _hyph_de, _hyph_en
    if _hyph_de is not None and _hyph_en is not None:
        return _hyph_de, _hyph_en
    try:
        import pyphen
    except ImportError:
        return None, None
    # Build both objects first and publish them together.  Assigning the
    # German hyphenator before the English one is constructed would leave a
    # half-initialised cache (de set, en None) if the second constructor
    # raises, and every later call would then hand out ``None`` for English.
    hyph_de = pyphen.Pyphen(lang='de_DE')
    hyph_en = pyphen.Pyphen(lang='en_US')
    _hyph_de, _hyph_en = hyph_de, hyph_en
    return _hyph_de, _hyph_en


def _get_spellchecker():
    """Lazy-load German spellchecker (cached across calls).

    Returns None when the ``spellchecker`` package cannot be imported.
    """
    global _spell_de
    if _spell_de is not None:
        return _spell_de
    try:
        from spellchecker import SpellChecker
    except ImportError:
        return None
    _spell_de = SpellChecker(language='de')
    return _spell_de


def _is_known_word(word: str, hyph_de, hyph_en) -> bool:
    """Check whether pyphen recognises a word (DE or EN).

    A word counts as recognised when either hyphenator inserts at least one
    break into it; words shorter than two characters are never recognised.
    Both hyphenators must be non-None.
    """
    if len(word) < 2:
        return False
    return ('|' in hyph_de.inserted(word, hyphen='|')
            or '|' in hyph_en.inserted(word, hyphen='|'))


def _is_real_word(word: str) -> bool:
    """Check whether spellchecker knows this word (case-insensitive).

    Returns False when no spellchecker is available.
    """
    spell = _get_spellchecker()
    if spell is None:
        return False
    return word.lower() in spell
def _hyphenate_word(word: str, hyph_de, hyph_en) -> Optional[str]:
    """Try to hyphenate a word using DE then EN dictionary.

    Returns word with | separators, or None if neither hyphenator inserts
    a break.
    """
    for hyphenator in (hyph_de, hyph_en):
        hyph = hyphenator.inserted(word, hyphen='|')
        if '|' in hyph:
            return hyph
    return None


# Characters OCR tends to produce for a printed vertical divider stroke.
_PIPE_LIKE_CHARS = frozenset('lI1it')

# Splits a token into (leading non-letters, core, trailing non-letters).
_WORD_CORE_RE = re.compile(
    r'^([^a-zA-Z\u00e4\u00f6\u00fc\u00c4\u00d6\u00dc\u00df\u1e9e]*)'
    r'(.*?)'
    r'([^a-zA-Z\u00e4\u00f6\u00fc\u00c4\u00d6\u00dc\u00df\u1e9e]*)$'
)


def _autocorrect_piped_word(word_with_pipes: str) -> Optional[str]:
    """Try to correct a word that has OCR pipe artifacts.

    Printed syllable divider lines on dictionary pages confuse OCR:
    the vertical stroke is often read as an extra character (commonly
    ``l``, ``I``, ``1``, ``i``) adjacent to where the pipe appears.

    Uses ``spellchecker`` (frequency-based word list) for validation.

    Strategy:
      1. Strip ``|`` -- if spellchecker knows the result, done.
      2. Try deleting each pipe-like character (l, I, 1, i, t).
      3. Fall back to spellchecker's own ``correction()`` method.
      4. Preserve the original casing of the first letter.

    Returns:
        The pipe-free word (unchanged apart from the pipes when it is shorter
        than three characters or already known), a corrected word, or None
        when no correction could be found.
    """
    stripped = word_with_pipes.replace('|', '')
    if not stripped or len(stripped) < 3:
        return stripped  # too short to validate

    # Step 1: if the stripped word is already a real word, done
    if _is_real_word(stripped):
        return stripped

    # Step 2: try deleting pipe-like characters (most likely artifacts)
    for idx, ch in enumerate(stripped):
        if ch not in _PIPE_LIKE_CHARS:
            continue
        candidate = stripped[:idx] + stripped[idx + 1:]
        if len(candidate) >= 3 and _is_real_word(candidate):
            return candidate

    # Step 3: use spellchecker's built-in correction
    spell = _get_spellchecker()
    if spell is not None:
        suggestion = spell.correction(stripped.lower())
        if suggestion and suggestion != stripped.lower():
            # Preserve original first-letter case
            if stripped[0].isupper():
                suggestion = suggestion[0].upper() + suggestion[1:]
            return suggestion

    return None  # could not fix


def autocorrect_pipe_artifacts(
    zones_data: List[Dict], session_id: str,
) -> int:
    """Strip OCR pipe artifacts and correct garbled words in-place.

    Printed syllable divider lines on dictionary scans are read by OCR
    as ``|`` characters embedded in words (e.g. ``Zel|le``, ``Ze|plpe|lin``).
    For every cell whose ``col_type`` starts with ``"column_"`` this function:

    1. Strips ``|`` from the letter core of every word box.
    2. Validates with spellchecker (real dictionary lookup).
    3. If not recognised, tries deleting pipe-like characters or uses
       spellchecker's correction (e.g. ``Zeplpelin`` -> ``Zeppelin``).
    4. Updates both word-box texts and cell text.  Words that cannot be
       corrected still get their pipes removed, so word boxes and cell text
       stay consistent.

    Args:
        zones_data: List of zone dicts, each with a ``"cells"`` list.
        session_id: Only used in the log message.

    Returns:
        The number of cells modified.
    """
    if _get_spellchecker() is None:
        # Without a spellchecker no word can be validated or corrected, but
        # pipes are still stripped from word boxes and cell text below.
        logger.warning("spellchecker not available -- pipe autocorrect limited")

    modified = 0
    for z in zones_data:
        for cell in z.get("cells", []):
            ct = cell.get("col_type") or ""
            if not ct.startswith("column_"):
                continue

            word_boxes = cell.get("word_boxes") or []
            corrected_any = False   # a word box received a validated correction
            stripped_any = False    # a word box only had its pipes removed

            # --- Fix word boxes ---
            for wb in word_boxes:
                wb_text = wb.get("text") or ""
                if "|" not in wb_text:
                    continue

                # Separate leading/trailing punctuation from the letter core
                m = _WORD_CORE_RE.match(wb_text)
                if not m:
                    continue
                lead, core, trail = m.group(1), m.group(2), m.group(3)
                if "|" not in core:
                    continue

                corrected = _autocorrect_piped_word(core)
                if corrected is None:
                    # No correction found: still drop the pipe so the word box
                    # does not disagree with the pipe-free cell text.
                    wb["text"] = lead + core.replace("|", "") + trail
                    stripped_any = True
                else:
                    wb["text"] = lead + corrected + trail
                    corrected_any = True

            cell_touched = stripped_any

            # --- Rebuild cell text from word boxes ---
            # Only when a real correction happened; pipe-only stripping is
            # handled by the fallback below and keeps the cell's own spacing.
            if corrected_any:
                cell["text"] = " ".join(
                    (wb.get("text") or "") for wb in word_boxes
                )
                cell_touched = True

            # --- Fallback: strip residual | from cell text ---
            text = cell.get("text") or ""
            if "|" in text:
                cell["text"] = text.replace("|", "")
                cell_touched = True

            if cell_touched:
                modified += 1

    if modified:
        logger.info(
            "build-grid session %s: autocorrected pipe artifacts in %d cells",
            session_id, modified,
        )
    return modified


# ---------------------------------------------------------------------------
# The patch continues with the modified file
# klausur-service/backend/cv_syllable_detect.py.
# ---------------------------------------------------------------------------
-For confirmed dictionary pages (is_dictionary=True), processes all content -column cells: - 1. Strips existing | dividers for clean normalization - 2. Merges pipe-gap spaces (where OCR split a word at a divider position) - 3. Applies pyphen syllabification to each word >= 3 alpha chars (DE then EN) - 4. Only modifies words that pyphen recognizes — garbled OCR stays as-is - -No CV gate needed — the dictionary detection confidence is sufficient. -pyphen uses Hunspell/TeX hyphenation dictionaries and is very reliable. +All implementation split into: + cv_syllable_core — hyphenator init, word validation, pipe autocorrect + cv_syllable_merge — word gap merging, syllabification, divider insertion Lizenz: Apache 2.0 (kommerziell nutzbar) DATENSCHUTZ: Alle Verarbeitung erfolgt lokal. """ -import logging -import re -from typing import Any, Dict, List, Optional, Tuple - -import numpy as np - -logger = logging.getLogger(__name__) - -# IPA/phonetic characters — skip cells containing these -_IPA_RE = re.compile(r'[\[\]ˈˌːʃʒθðŋɑɒæɔəɛɜɪʊʌ]') - -# Common German words that should NOT be merged with adjacent tokens. -# These are function words that appear as standalone words between -# headwords/definitions on dictionary pages. 
-_STOP_WORDS = frozenset([ - # Articles - 'der', 'die', 'das', 'dem', 'den', 'des', - 'ein', 'eine', 'einem', 'einen', 'einer', - # Pronouns - 'du', 'er', 'es', 'sie', 'wir', 'ihr', 'ich', 'man', 'sich', - 'dich', 'dir', 'mich', 'mir', 'uns', 'euch', 'ihm', 'ihn', - # Prepositions - 'mit', 'von', 'zu', 'für', 'auf', 'in', 'an', 'um', 'am', 'im', - 'aus', 'bei', 'nach', 'vor', 'bis', 'durch', 'über', 'unter', - 'zwischen', 'ohne', 'gegen', - # Conjunctions - 'und', 'oder', 'als', 'wie', 'wenn', 'dass', 'weil', 'aber', - # Adverbs - 'auch', 'noch', 'nur', 'schon', 'sehr', 'nicht', - # Verbs - 'ist', 'hat', 'wird', 'kann', 'soll', 'muss', 'darf', - 'sein', 'haben', - # Other - 'kein', 'keine', 'keinem', 'keinen', 'keiner', -]) - -# Cached hyphenators -_hyph_de = None -_hyph_en = None - -# Cached spellchecker (for autocorrect_pipe_artifacts) -_spell_de = None - - -def _get_hyphenators(): - """Lazy-load pyphen hyphenators (cached across calls).""" - global _hyph_de, _hyph_en - if _hyph_de is not None: - return _hyph_de, _hyph_en - try: - import pyphen - except ImportError: - return None, None - _hyph_de = pyphen.Pyphen(lang='de_DE') - _hyph_en = pyphen.Pyphen(lang='en_US') - return _hyph_de, _hyph_en - - -def _get_spellchecker(): - """Lazy-load German spellchecker (cached across calls).""" - global _spell_de - if _spell_de is not None: - return _spell_de - try: - from spellchecker import SpellChecker - except ImportError: - return None - _spell_de = SpellChecker(language='de') - return _spell_de - - -def _is_known_word(word: str, hyph_de, hyph_en) -> bool: - """Check whether pyphen recognises a word (DE or EN).""" - if len(word) < 2: - return False - return ('|' in hyph_de.inserted(word, hyphen='|') - or '|' in hyph_en.inserted(word, hyphen='|')) - - -def _is_real_word(word: str) -> bool: - """Check whether spellchecker knows this word (case-insensitive).""" - spell = _get_spellchecker() - if spell is None: - return False - return word.lower() in spell - - -def 
_hyphenate_word(word: str, hyph_de, hyph_en) -> Optional[str]: - """Try to hyphenate a word using DE then EN dictionary. - - Returns word with | separators, or None if not recognized. - """ - hyph = hyph_de.inserted(word, hyphen='|') - if '|' in hyph: - return hyph - hyph = hyph_en.inserted(word, hyphen='|') - if '|' in hyph: - return hyph - return None - - -def _autocorrect_piped_word(word_with_pipes: str) -> Optional[str]: - """Try to correct a word that has OCR pipe artifacts. - - Printed syllable divider lines on dictionary pages confuse OCR: - the vertical stroke is often read as an extra character (commonly - ``l``, ``I``, ``1``, ``i``) adjacent to where the pipe appears. - Sometimes OCR reads one divider as ``|`` and another as a letter, - so the garbled character may be far from any detected pipe. - - Uses ``spellchecker`` (frequency-based word list) for validation — - unlike pyphen which is a pattern-based hyphenator and accepts - nonsense strings like "Zeplpelin". - - Strategy: - 1. Strip ``|`` — if spellchecker knows the result, done. - 2. Try deleting each pipe-like character (l, I, 1, i, t). - OCR inserts extra chars that resemble vertical strokes. - 3. Fall back to spellchecker's own ``correction()`` method. - 4. Preserve the original casing of the first letter. 
- """ - stripped = word_with_pipes.replace('|', '') - if not stripped or len(stripped) < 3: - return stripped # too short to validate - - # Step 1: if the stripped word is already a real word, done - if _is_real_word(stripped): - return stripped - - # Step 2: try deleting pipe-like characters (most likely artifacts) - _PIPE_LIKE = frozenset('lI1it') - for idx in range(len(stripped)): - if stripped[idx] not in _PIPE_LIKE: - continue - candidate = stripped[:idx] + stripped[idx + 1:] - if len(candidate) >= 3 and _is_real_word(candidate): - return candidate - - # Step 3: use spellchecker's built-in correction - spell = _get_spellchecker() - if spell is not None: - suggestion = spell.correction(stripped.lower()) - if suggestion and suggestion != stripped.lower(): - # Preserve original first-letter case - if stripped[0].isupper(): - suggestion = suggestion[0].upper() + suggestion[1:] - return suggestion - - return None # could not fix - - -def autocorrect_pipe_artifacts( - zones_data: List[Dict], session_id: str, -) -> int: - """Strip OCR pipe artifacts and correct garbled words in-place. - - Printed syllable divider lines on dictionary scans are read by OCR - as ``|`` characters embedded in words (e.g. ``Zel|le``, ``Ze|plpe|lin``). - This function: - - 1. Strips ``|`` from every word in content cells. - 2. Validates with spellchecker (real dictionary lookup). - 3. If not recognised, tries deleting pipe-like characters or uses - spellchecker's correction (e.g. ``Zeplpelin`` → ``Zeppelin``). - 4. Updates both word-box texts and cell text. - - Returns the number of cells modified. 
- """ - spell = _get_spellchecker() - if spell is None: - logger.warning("spellchecker not available — pipe autocorrect limited") - # Fall back: still strip pipes even without spellchecker - pass - - modified = 0 - for z in zones_data: - for cell in z.get("cells", []): - ct = cell.get("col_type", "") - if not ct.startswith("column_"): - continue - - cell_changed = False - - # --- Fix word boxes --- - for wb in cell.get("word_boxes", []): - wb_text = wb.get("text", "") - if "|" not in wb_text: - continue - - # Separate trailing punctuation - m = re.match( - r'^([^a-zA-ZäöüÄÖÜßẞ]*)' - r'(.*?)' - r'([^a-zA-ZäöüÄÖÜßẞ]*)$', - wb_text, - ) - if not m: - continue - lead, core, trail = m.group(1), m.group(2), m.group(3) - if "|" not in core: - continue - - corrected = _autocorrect_piped_word(core) - if corrected is not None and corrected != core: - wb["text"] = lead + corrected + trail - cell_changed = True - - # --- Rebuild cell text from word boxes --- - if cell_changed: - wbs = cell.get("word_boxes", []) - if wbs: - cell["text"] = " ".join( - (wb.get("text") or "") for wb in wbs - ) - modified += 1 - - # --- Fallback: strip residual | from cell text --- - # (covers cases where word_boxes don't exist or weren't fixed) - text = cell.get("text", "") - if "|" in text: - clean = text.replace("|", "") - if clean != text: - cell["text"] = clean - if not cell_changed: - modified += 1 - - if modified: - logger.info( - "build-grid session %s: autocorrected pipe artifacts in %d cells", - session_id, modified, - ) - return modified - - -def _try_merge_pipe_gaps(text: str, hyph_de) -> str: - """Merge fragments separated by single spaces where OCR split at a pipe. - - Example: "Kaf fee" -> "Kaffee" (pyphen recognizes the merged word). - Multi-step: "Ka bel jau" -> "Kabel jau" -> "Kabeljau". 
- - Guards against false merges: - - The FIRST token must be pure alpha (word start — no attached punctuation) - - The second token may have trailing punctuation (comma, period) which - stays attached to the merged word: "Kä" + "fer," -> "Käfer," - - Common German function words (der, die, das, ...) are never merged - - At least one fragment must be very short (<=3 alpha chars) - """ - parts = text.split(' ') - if len(parts) < 2: - return text - - result = [parts[0]] - i = 1 - while i < len(parts): - prev = result[-1] - curr = parts[i] - - # Extract alpha-only core for lookup - prev_alpha = re.sub(r'[^a-zA-ZäöüÄÖÜßẞ]', '', prev) - curr_alpha = re.sub(r'[^a-zA-ZäöüÄÖÜßẞ]', '', curr) - - # Guard 1: first token must be pure alpha (word-start fragment) - # second token may have trailing punctuation - # Guard 2: neither alpha core can be a common German function word - # Guard 3: the shorter fragment must be <= 3 chars (pipe-gap signal) - # Guard 4: combined length must be >= 4 - should_try = ( - prev == prev_alpha # first token: pure alpha (word start) - and prev_alpha and curr_alpha - and prev_alpha.lower() not in _STOP_WORDS - and curr_alpha.lower() not in _STOP_WORDS - and min(len(prev_alpha), len(curr_alpha)) <= 3 - and len(prev_alpha) + len(curr_alpha) >= 4 - ) - - if should_try: - merged_alpha = prev_alpha + curr_alpha - hyph = hyph_de.inserted(merged_alpha, hyphen='-') - if '-' in hyph: - # pyphen recognizes merged word — collapse the space - result[-1] = prev + curr - i += 1 - continue - - result.append(curr) - i += 1 - - return ' '.join(result) - - -def merge_word_gaps_in_zones(zones_data: List[Dict], session_id: str) -> int: - """Merge OCR word-gap fragments in cell texts using pyphen validation. - - OCR often splits words at syllable boundaries into separate word_boxes, - producing text like "zerknit tert" instead of "zerknittert". This - function tries to merge adjacent fragments in every content cell. 
- - More permissive than ``_try_merge_pipe_gaps`` (threshold 5 instead of 3) - but still guarded by pyphen dictionary lookup and stop-word exclusion. - - Returns the number of cells modified. - """ - hyph_de, _ = _get_hyphenators() - if hyph_de is None: - return 0 - - modified = 0 - for z in zones_data: - for cell in z.get("cells", []): - ct = cell.get("col_type", "") - if not ct.startswith("column_"): - continue - text = cell.get("text", "") - if not text or " " not in text: - continue - - # Skip IPA cells - text_no_brackets = re.sub(r'\[[^\]]*\]', '', text) - if _IPA_RE.search(text_no_brackets): - continue - - new_text = _try_merge_word_gaps(text, hyph_de) - if new_text != text: - cell["text"] = new_text - modified += 1 - - if modified: - logger.info( - "build-grid session %s: merged word gaps in %d cells", - session_id, modified, - ) - return modified - - -def _try_merge_word_gaps(text: str, hyph_de) -> str: - """Merge OCR word fragments with relaxed threshold (max_short=5). - - Similar to ``_try_merge_pipe_gaps`` but allows slightly longer fragments - (max_short=5 instead of 3). Still requires pyphen to recognize the - merged word. 
- """ - parts = text.split(' ') - if len(parts) < 2: - return text - - result = [parts[0]] - i = 1 - while i < len(parts): - prev = result[-1] - curr = parts[i] - - prev_alpha = re.sub(r'[^a-zA-ZäöüÄÖÜßẞ]', '', prev) - curr_alpha = re.sub(r'[^a-zA-ZäöüÄÖÜßẞ]', '', curr) - - should_try = ( - prev == prev_alpha - and prev_alpha and curr_alpha - and prev_alpha.lower() not in _STOP_WORDS - and curr_alpha.lower() not in _STOP_WORDS - and min(len(prev_alpha), len(curr_alpha)) <= 5 - and len(prev_alpha) + len(curr_alpha) >= 4 - ) - - if should_try: - merged_alpha = prev_alpha + curr_alpha - hyph = hyph_de.inserted(merged_alpha, hyphen='-') - if '-' in hyph: - result[-1] = prev + curr - i += 1 - continue - - result.append(curr) - i += 1 - - return ' '.join(result) - - -def _syllabify_text(text: str, hyph_de, hyph_en) -> str: - """Syllabify all significant words in a text string. - - 1. Strip existing | dividers - 2. Merge pipe-gap spaces where possible - 3. Apply pyphen to each word >= 3 alphabetic chars - 4. Words pyphen doesn't recognize stay as-is (no bad guesses) - """ - if not text: - return text - - # Skip cells that contain IPA transcription characters outside brackets. - # Bracket content like [bɪltʃøn] is programmatically inserted and should - # not block syllabification of the surrounding text. 
- text_no_brackets = re.sub(r'\[[^\]]*\]', '', text) - if _IPA_RE.search(text_no_brackets): - return text - - # Phase 1: strip existing pipe dividers for clean normalization - clean = text.replace('|', '') - - # Phase 2: merge pipe-gap spaces (OCR fragments from pipe splitting) - clean = _try_merge_pipe_gaps(clean, hyph_de) - - # Phase 3: tokenize and syllabify each word - # Split on whitespace and comma/semicolon sequences, keeping separators - tokens = re.split(r'(\s+|[,;:]+\s*)', clean) - - result = [] - for tok in tokens: - if not tok or re.match(r'^[\s,;:]+$', tok): - result.append(tok) - continue - - # Strip trailing/leading punctuation for pyphen lookup - m = re.match(r'^([^a-zA-ZäöüÄÖÜßẞ]*)(.*?)([^a-zA-ZäöüÄÖÜßẞ]*)$', tok) - if not m: - result.append(tok) - continue - lead, word, trail = m.group(1), m.group(2), m.group(3) - - if len(word) < 3 or not re.search(r'[a-zA-ZäöüÄÖÜß]', word): - result.append(tok) - continue - - hyph = _hyphenate_word(word, hyph_de, hyph_en) - if hyph: - result.append(lead + hyph + trail) - else: - result.append(tok) - - return ''.join(result) - - -def insert_syllable_dividers( - zones_data: List[Dict], - img_bgr: np.ndarray, - session_id: str, - *, - force: bool = False, - col_filter: Optional[set] = None, -) -> int: - """Insert pipe syllable dividers into dictionary cells. - - For dictionary pages: process all content column cells, strip existing - pipes, merge pipe-gap spaces, and re-syllabify using pyphen. - - Pre-check: at least 1% of content cells must already contain ``|`` from - OCR. This guards against pages with zero pipe characters (the primary - guard — article_col_index — is checked at the call site). - - Args: - force: If True, skip the pipe-ratio pre-check and syllabify all - content words regardless of whether the original has pipe dividers. - col_filter: If set, only process cells whose col_type is in this set. - None means process all content columns. - - Returns the number of cells modified. 
- """ - hyph_de, hyph_en = _get_hyphenators() - if hyph_de is None: - logger.warning("pyphen not installed — skipping syllable insertion") - return 0 - - # Pre-check: count cells that already have | from OCR. - # Real dictionary pages with printed syllable dividers will have OCR- - # detected pipes in many cells. Pages without syllable dividers will - # have zero — skip those to avoid false syllabification. - if not force: - total_col_cells = 0 - cells_with_pipes = 0 - for z in zones_data: - for cell in z.get("cells", []): - if cell.get("col_type", "").startswith("column_"): - total_col_cells += 1 - if "|" in cell.get("text", ""): - cells_with_pipes += 1 - - if total_col_cells > 0: - pipe_ratio = cells_with_pipes / total_col_cells - if pipe_ratio < 0.01: - logger.info( - "build-grid session %s: skipping syllable insertion — " - "only %.1f%% of cells have existing pipes (need >=1%%)", - session_id, pipe_ratio * 100, - ) - return 0 - - insertions = 0 - for z in zones_data: - for cell in z.get("cells", []): - ct = cell.get("col_type", "") - if not ct.startswith("column_"): - continue - if col_filter is not None and ct not in col_filter: - continue - text = cell.get("text", "") - if not text: - continue - - # In auto mode (force=False), only normalize cells that already - # have | from OCR (i.e. printed syllable dividers on the original - # scan). Don't add new syllable marks to other words. 
- if not force and "|" not in text: - continue - - new_text = _syllabify_text(text, hyph_de, hyph_en) - if new_text != text: - cell["text"] = new_text - insertions += 1 - - if insertions: - logger.info( - "build-grid session %s: syllable dividers inserted/normalized " - "in %d cells (pyphen)", - session_id, insertions, - ) - return insertions +# Core: init, validation, autocorrect +from cv_syllable_core import ( # noqa: F401 + _IPA_RE, + _STOP_WORDS, + _get_hyphenators, + _get_spellchecker, + _is_known_word, + _is_real_word, + _hyphenate_word, + _autocorrect_piped_word, + autocorrect_pipe_artifacts, +) + +# Merge: gap merging, syllabify, insert +from cv_syllable_merge import ( # noqa: F401 + _try_merge_pipe_gaps, + merge_word_gaps_in_zones, + _try_merge_word_gaps, + _syllabify_text, + insert_syllable_dividers, +) diff --git a/klausur-service/backend/cv_syllable_merge.py b/klausur-service/backend/cv_syllable_merge.py new file mode 100644 index 0000000..3684210 --- /dev/null +++ b/klausur-service/backend/cv_syllable_merge.py @@ -0,0 +1,300 @@ +""" +Syllable Merge — word gap merging, syllabification, divider insertion. + +Extracted from cv_syllable_detect.py for modularity. + +Lizenz: Apache 2.0 (kommerziell nutzbar) +DATENSCHUTZ: Alle Verarbeitung erfolgt lokal. +""" + +import logging +import re +from typing import Any, Dict, List, Optional + +import numpy as np + +from cv_syllable_core import ( + _get_hyphenators, + _hyphenate_word, + _IPA_RE, + _STOP_WORDS, +) + +logger = logging.getLogger(__name__) + + +def _try_merge_pipe_gaps(text: str, hyph_de) -> str: + """Merge fragments separated by single spaces where OCR split at a pipe. + + Example: "Kaf fee" -> "Kaffee" (pyphen recognizes the merged word). + Multi-step: "Ka bel jau" -> "Kabel jau" -> "Kabeljau". 
+ + Guards against false merges: + - The FIRST token must be pure alpha (word start -- no attached punctuation) + - The second token may have trailing punctuation (comma, period) which + stays attached to the merged word: "Ka" + "fer," -> "Kafer," + - Common German function words (der, die, das, ...) are never merged + - At least one fragment must be very short (<=3 alpha chars) + """ + parts = text.split(' ') + if len(parts) < 2: + return text + + result = [parts[0]] + i = 1 + while i < len(parts): + prev = result[-1] + curr = parts[i] + + # Extract alpha-only core for lookup + prev_alpha = re.sub(r'[^a-zA-Z\u00e4\u00f6\u00fc\u00c4\u00d6\u00dc\u00df\u1e9e]', '', prev) + curr_alpha = re.sub(r'[^a-zA-Z\u00e4\u00f6\u00fc\u00c4\u00d6\u00dc\u00df\u1e9e]', '', curr) + + # Guard 1: first token must be pure alpha (word-start fragment) + # second token may have trailing punctuation + # Guard 2: neither alpha core can be a common German function word + # Guard 3: the shorter fragment must be <= 3 chars (pipe-gap signal) + # Guard 4: combined length must be >= 4 + should_try = ( + prev == prev_alpha # first token: pure alpha (word start) + and prev_alpha and curr_alpha + and prev_alpha.lower() not in _STOP_WORDS + and curr_alpha.lower() not in _STOP_WORDS + and min(len(prev_alpha), len(curr_alpha)) <= 3 + and len(prev_alpha) + len(curr_alpha) >= 4 + ) + + if should_try: + merged_alpha = prev_alpha + curr_alpha + hyph = hyph_de.inserted(merged_alpha, hyphen='-') + if '-' in hyph: + # pyphen recognizes merged word -- collapse the space + result[-1] = prev + curr + i += 1 + continue + + result.append(curr) + i += 1 + + return ' '.join(result) + + +def merge_word_gaps_in_zones(zones_data: List[Dict], session_id: str) -> int: + """Merge OCR word-gap fragments in cell texts using pyphen validation. + + OCR often splits words at syllable boundaries into separate word_boxes, + producing text like "zerknit tert" instead of "zerknittert". 
This + function tries to merge adjacent fragments in every content cell. + + More permissive than ``_try_merge_pipe_gaps`` (threshold 5 instead of 3) + but still guarded by pyphen dictionary lookup and stop-word exclusion. + + Returns the number of cells modified. + """ + hyph_de, _ = _get_hyphenators() + if hyph_de is None: + return 0 + + modified = 0 + for z in zones_data: + for cell in z.get("cells", []): + ct = cell.get("col_type", "") + if not ct.startswith("column_"): + continue + text = cell.get("text", "") + if not text or " " not in text: + continue + + # Skip IPA cells + text_no_brackets = re.sub(r'\[[^\]]*\]', '', text) + if _IPA_RE.search(text_no_brackets): + continue + + new_text = _try_merge_word_gaps(text, hyph_de) + if new_text != text: + cell["text"] = new_text + modified += 1 + + if modified: + logger.info( + "build-grid session %s: merged word gaps in %d cells", + session_id, modified, + ) + return modified + + +def _try_merge_word_gaps(text: str, hyph_de) -> str: + """Merge OCR word fragments with relaxed threshold (max_short=5). + + Similar to ``_try_merge_pipe_gaps`` but allows slightly longer fragments + (max_short=5 instead of 3). Still requires pyphen to recognize the + merged word. 
+ """ + parts = text.split(' ') + if len(parts) < 2: + return text + + result = [parts[0]] + i = 1 + while i < len(parts): + prev = result[-1] + curr = parts[i] + + prev_alpha = re.sub(r'[^a-zA-Z\u00e4\u00f6\u00fc\u00c4\u00d6\u00dc\u00df\u1e9e]', '', prev) + curr_alpha = re.sub(r'[^a-zA-Z\u00e4\u00f6\u00fc\u00c4\u00d6\u00dc\u00df\u1e9e]', '', curr) + + should_try = ( + prev == prev_alpha + and prev_alpha and curr_alpha + and prev_alpha.lower() not in _STOP_WORDS + and curr_alpha.lower() not in _STOP_WORDS + and min(len(prev_alpha), len(curr_alpha)) <= 5 + and len(prev_alpha) + len(curr_alpha) >= 4 + ) + + if should_try: + merged_alpha = prev_alpha + curr_alpha + hyph = hyph_de.inserted(merged_alpha, hyphen='-') + if '-' in hyph: + result[-1] = prev + curr + i += 1 + continue + + result.append(curr) + i += 1 + + return ' '.join(result) + + +def _syllabify_text(text: str, hyph_de, hyph_en) -> str: + """Syllabify all significant words in a text string. + + 1. Strip existing | dividers + 2. Merge pipe-gap spaces where possible + 3. Apply pyphen to each word >= 3 alphabetic chars + 4. Words pyphen doesn't recognize stay as-is (no bad guesses) + """ + if not text: + return text + + # Skip cells that contain IPA transcription characters outside brackets. 
+ text_no_brackets = re.sub(r'\[[^\]]*\]', '', text) + if _IPA_RE.search(text_no_brackets): + return text + + # Phase 1: strip existing pipe dividers for clean normalization + clean = text.replace('|', '') + + # Phase 2: merge pipe-gap spaces (OCR fragments from pipe splitting) + clean = _try_merge_pipe_gaps(clean, hyph_de) + + # Phase 3: tokenize and syllabify each word + # Split on whitespace and comma/semicolon sequences, keeping separators + tokens = re.split(r'(\s+|[,;:]+\s*)', clean) + + result = [] + for tok in tokens: + if not tok or re.match(r'^[\s,;:]+$', tok): + result.append(tok) + continue + + # Strip trailing/leading punctuation for pyphen lookup + m = re.match(r'^([^a-zA-Z\u00e4\u00f6\u00fc\u00c4\u00d6\u00dc\u00df\u1e9e]*)(.*?)([^a-zA-Z\u00e4\u00f6\u00fc\u00c4\u00d6\u00dc\u00df\u1e9e]*)$', tok) + if not m: + result.append(tok) + continue + lead, word, trail = m.group(1), m.group(2), m.group(3) + + if len(word) < 3 or not re.search(r'[a-zA-Z\u00e4\u00f6\u00fc\u00c4\u00d6\u00dc\u00df]', word): + result.append(tok) + continue + + hyph = _hyphenate_word(word, hyph_de, hyph_en) + if hyph: + result.append(lead + hyph + trail) + else: + result.append(tok) + + return ''.join(result) + + +def insert_syllable_dividers( + zones_data: List[Dict], + img_bgr: np.ndarray, + session_id: str, + *, + force: bool = False, + col_filter: Optional[set] = None, +) -> int: + """Insert pipe syllable dividers into dictionary cells. + + For dictionary pages: process all content column cells, strip existing + pipes, merge pipe-gap spaces, and re-syllabify using pyphen. + + Pre-check: at least 1% of content cells must already contain ``|`` from + OCR. This guards against pages with zero pipe characters. + + Args: + force: If True, skip the pipe-ratio pre-check and syllabify all + content words regardless of whether the original has pipe dividers. + col_filter: If set, only process cells whose col_type is in this set. + None means process all content columns. 
+ + Returns the number of cells modified. + """ + hyph_de, hyph_en = _get_hyphenators() + if hyph_de is None: + logger.warning("pyphen not installed -- skipping syllable insertion") + return 0 + + # Pre-check: count cells that already have | from OCR. + if not force: + total_col_cells = 0 + cells_with_pipes = 0 + for z in zones_data: + for cell in z.get("cells", []): + if cell.get("col_type", "").startswith("column_"): + total_col_cells += 1 + if "|" in cell.get("text", ""): + cells_with_pipes += 1 + + if total_col_cells > 0: + pipe_ratio = cells_with_pipes / total_col_cells + if pipe_ratio < 0.01: + logger.info( + "build-grid session %s: skipping syllable insertion -- " + "only %.1f%% of cells have existing pipes (need >=1%%)", + session_id, pipe_ratio * 100, + ) + return 0 + + insertions = 0 + for z in zones_data: + for cell in z.get("cells", []): + ct = cell.get("col_type", "") + if not ct.startswith("column_"): + continue + if col_filter is not None and ct not in col_filter: + continue + text = cell.get("text", "") + if not text: + continue + + # In auto mode (force=False), only normalize cells that already + # have | from OCR (i.e. printed syllable dividers on the original + # scan). Don't add new syllable marks to other words. + if not force and "|" not in text: + continue + + new_text = _syllabify_text(text, hyph_de, hyph_en) + if new_text != text: + cell["text"] = new_text + insertions += 1 + + if insertions: + logger.info( + "build-grid session %s: syllable dividers inserted/normalized " + "in %d cells (pyphen)", + session_id, insertions, + ) + return insertions diff --git a/klausur-service/backend/mail/aggregator.py b/klausur-service/backend/mail/aggregator.py index 081a61c..01756b7 100644 --- a/klausur-service/backend/mail/aggregator.py +++ b/klausur-service/backend/mail/aggregator.py @@ -1,52 +1,27 @@ """ -Mail Aggregator Service +Mail Aggregator Service — barrel re-export. 
+ +All implementation split into: + aggregator_imap — IMAP connection, sync, email parsing + aggregator_smtp — SMTP connection, email sending Multi-account IMAP aggregation with async support. """ -import os -import ssl -import email import asyncio import logging -import smtplib -from typing import Optional, List, Dict, Any, Tuple -from datetime import datetime, timezone -from email.mime.text import MIMEText -from email.mime.multipart import MIMEMultipart -from email.header import decode_header, make_header -from email.utils import parsedate_to_datetime, parseaddr +from typing import Optional, List, Dict, Any -from .credentials import get_credentials_service, MailCredentials -from .mail_db import ( - get_email_accounts, - get_email_account, - update_account_status, - upsert_email, - get_unified_inbox, -) -from .models import ( - AccountStatus, - AccountTestResult, - AggregatedEmail, - EmailComposeRequest, - EmailSendResult, -) +from .credentials import get_credentials_service +from .mail_db import get_email_accounts, get_unified_inbox +from .models import AccountTestResult +from .aggregator_imap import IMAPMixin, IMAPConnectionError +from .aggregator_smtp import SMTPMixin, SMTPConnectionError logger = logging.getLogger(__name__) -class IMAPConnectionError(Exception): - """Raised when IMAP connection fails.""" - pass - - -class SMTPConnectionError(Exception): - """Raised when SMTP connection fails.""" - pass - - -class MailAggregator: +class MailAggregator(IMAPMixin, SMTPMixin): """ Aggregates emails from multiple IMAP accounts into a unified inbox. 
@@ -86,390 +61,29 @@ class MailAggregator: ) # Test IMAP - try: - import imaplib - - if imap_ssl: - imap = imaplib.IMAP4_SSL(imap_host, imap_port) - else: - imap = imaplib.IMAP4(imap_host, imap_port) - - imap.login(email_address, password) - result.imap_connected = True - - # List folders - status, folders = imap.list() - if status == "OK": - result.folders_found = [ - self._parse_folder_name(f) for f in folders if f - ] - - imap.logout() - - except Exception as e: - result.error_message = f"IMAP Error: {str(e)}" - logger.warning(f"IMAP test failed for {email_address}: {e}") + imap_ok, imap_err, folders = await self.test_imap_connection( + imap_host, imap_port, imap_ssl, email_address, password + ) + result.imap_connected = imap_ok + if folders: + result.folders_found = folders + if imap_err: + result.error_message = imap_err # Test SMTP - try: - if smtp_ssl: - smtp = smtplib.SMTP_SSL(smtp_host, smtp_port) - else: - smtp = smtplib.SMTP(smtp_host, smtp_port) - smtp.starttls() - - smtp.login(email_address, password) - result.smtp_connected = True - smtp.quit() - - except Exception as e: - smtp_error = f"SMTP Error: {str(e)}" + smtp_ok, smtp_err = await self.test_smtp_connection( + smtp_host, smtp_port, smtp_ssl, email_address, password + ) + result.smtp_connected = smtp_ok + if smtp_err: if result.error_message: - result.error_message += f"; {smtp_error}" + result.error_message += f"; {smtp_err}" else: - result.error_message = smtp_error - logger.warning(f"SMTP test failed for {email_address}: {e}") + result.error_message = smtp_err result.success = result.imap_connected and result.smtp_connected return result - def _parse_folder_name(self, folder_response: bytes) -> str: - """Parse folder name from IMAP LIST response.""" - try: - # Format: '(\\HasNoChildren) "/" "INBOX"' - decoded = folder_response.decode("utf-8") if isinstance(folder_response, bytes) else folder_response - parts = decoded.rsplit('" "', 1) - if len(parts) == 2: - return parts[1].rstrip('"') - return 
decoded - except Exception: - return str(folder_response) - - async def sync_account( - self, - account_id: str, - user_id: str, - max_emails: int = 100, - folders: Optional[List[str]] = None, - ) -> Tuple[int, int]: - """ - Sync emails from an IMAP account. - - Args: - account_id: The account ID - user_id: The user ID - max_emails: Maximum emails to fetch - folders: Specific folders to sync (default: INBOX) - - Returns: - Tuple of (new_emails, total_emails) - """ - import imaplib - - account = await get_email_account(account_id, user_id) - if not account: - raise ValueError(f"Account not found: {account_id}") - - # Get credentials - vault_path = account.get("vault_path", "") - creds = await self._credentials_service.get_credentials(account_id, vault_path) - if not creds: - await update_account_status(account_id, "error", "Credentials not found") - raise IMAPConnectionError("Credentials not found") - - new_count = 0 - total_count = 0 - - try: - # Connect to IMAP - if account["imap_ssl"]: - imap = imaplib.IMAP4_SSL(account["imap_host"], account["imap_port"]) - else: - imap = imaplib.IMAP4(account["imap_host"], account["imap_port"]) - - imap.login(creds.email, creds.password) - - # Sync specified folders or just INBOX - sync_folders = folders or ["INBOX"] - - for folder in sync_folders: - try: - status, _ = imap.select(folder) - if status != "OK": - continue - - # Search for recent emails - status, messages = imap.search(None, "ALL") - if status != "OK": - continue - - message_ids = messages[0].split() - total_count += len(message_ids) - - # Fetch most recent emails - recent_ids = message_ids[-max_emails:] if len(message_ids) > max_emails else message_ids - - for msg_id in recent_ids: - try: - email_data = await self._fetch_and_store_email( - imap, msg_id, account_id, user_id, account["tenant_id"], folder - ) - if email_data: - new_count += 1 - except Exception as e: - logger.warning(f"Failed to fetch email {msg_id}: {e}") - - except Exception as e: - 
logger.warning(f"Failed to sync folder {folder}: {e}") - - imap.logout() - - # Update account status - await update_account_status( - account_id, - "active", - email_count=total_count, - unread_count=new_count, # Will be recalculated - ) - - return new_count, total_count - - except Exception as e: - logger.error(f"Account sync failed: {e}") - await update_account_status(account_id, "error", str(e)) - raise IMAPConnectionError(str(e)) - - async def _fetch_and_store_email( - self, - imap, - msg_id: bytes, - account_id: str, - user_id: str, - tenant_id: str, - folder: str, - ) -> Optional[str]: - """Fetch a single email and store it in the database.""" - try: - status, msg_data = imap.fetch(msg_id, "(RFC822)") - if status != "OK" or not msg_data or not msg_data[0]: - return None - - raw_email = msg_data[0][1] - msg = email.message_from_bytes(raw_email) - - # Parse headers - message_id = msg.get("Message-ID", str(msg_id)) - subject = self._decode_header(msg.get("Subject", "")) - from_header = msg.get("From", "") - sender_name, sender_email = parseaddr(from_header) - sender_name = self._decode_header(sender_name) - - # Parse recipients - to_header = msg.get("To", "") - recipients = [addr[1] for addr in email.utils.getaddresses([to_header])] - - cc_header = msg.get("Cc", "") - cc = [addr[1] for addr in email.utils.getaddresses([cc_header])] - - # Parse dates - date_str = msg.get("Date") - try: - date_sent = parsedate_to_datetime(date_str) if date_str else datetime.now(timezone.utc) - except Exception: - date_sent = datetime.now(timezone.utc) - - date_received = datetime.now(timezone.utc) - - # Parse body - body_text, body_html, attachments = self._parse_body(msg) - - # Create preview - body_preview = (body_text[:200] + "...") if body_text and len(body_text) > 200 else body_text - - # Get headers dict - headers = {k: self._decode_header(v) for k, v in msg.items() if k not in ["Body"]} - - # Store in database - email_id = await upsert_email( - account_id=account_id, - 
user_id=user_id, - tenant_id=tenant_id, - message_id=message_id, - subject=subject, - sender_email=sender_email, - sender_name=sender_name, - recipients=recipients, - cc=cc, - body_preview=body_preview, - body_text=body_text, - body_html=body_html, - has_attachments=len(attachments) > 0, - attachments=attachments, - headers=headers, - folder=folder, - date_sent=date_sent, - date_received=date_received, - ) - - return email_id - - except Exception as e: - logger.error(f"Failed to parse email: {e}") - return None - - def _decode_header(self, header_value: str) -> str: - """Decode email header value.""" - if not header_value: - return "" - try: - decoded = decode_header(header_value) - return str(make_header(decoded)) - except Exception: - return str(header_value) - - def _parse_body(self, msg) -> Tuple[Optional[str], Optional[str], List[Dict]]: - """ - Parse email body and attachments. - - Returns: - Tuple of (body_text, body_html, attachments) - """ - body_text = None - body_html = None - attachments = [] - - if msg.is_multipart(): - for part in msg.walk(): - content_type = part.get_content_type() - content_disposition = str(part.get("Content-Disposition", "")) - - # Skip multipart containers - if content_type.startswith("multipart/"): - continue - - # Check for attachments - if "attachment" in content_disposition: - filename = part.get_filename() - if filename: - attachments.append({ - "filename": self._decode_header(filename), - "content_type": content_type, - "size": len(part.get_payload(decode=True) or b""), - }) - continue - - # Get body content - try: - payload = part.get_payload(decode=True) - charset = part.get_content_charset() or "utf-8" - - if payload: - text = payload.decode(charset, errors="replace") - - if content_type == "text/plain" and not body_text: - body_text = text - elif content_type == "text/html" and not body_html: - body_html = text - except Exception as e: - logger.debug(f"Failed to decode body part: {e}") - - else: - # Single part message 
- content_type = msg.get_content_type() - try: - payload = msg.get_payload(decode=True) - charset = msg.get_content_charset() or "utf-8" - - if payload: - text = payload.decode(charset, errors="replace") - - if content_type == "text/plain": - body_text = text - elif content_type == "text/html": - body_html = text - except Exception as e: - logger.debug(f"Failed to decode body: {e}") - - return body_text, body_html, attachments - - async def send_email( - self, - account_id: str, - user_id: str, - request: EmailComposeRequest, - ) -> EmailSendResult: - """ - Send an email via SMTP. - - Args: - account_id: The account to send from - user_id: The user ID - request: The compose request with recipients and content - - Returns: - EmailSendResult with success status - """ - account = await get_email_account(account_id, user_id) - if not account: - return EmailSendResult(success=False, error="Account not found") - - # Verify the account_id matches - if request.account_id != account_id: - return EmailSendResult(success=False, error="Account mismatch") - - # Get credentials - vault_path = account.get("vault_path", "") - creds = await self._credentials_service.get_credentials(account_id, vault_path) - if not creds: - return EmailSendResult(success=False, error="Credentials not found") - - try: - # Create message - if request.is_html: - msg = MIMEMultipart("alternative") - msg.attach(MIMEText(request.body, "html")) - else: - msg = MIMEText(request.body, "plain") - - msg["Subject"] = request.subject - msg["From"] = account["email"] - msg["To"] = ", ".join(request.to) - - if request.cc: - msg["Cc"] = ", ".join(request.cc) - - if request.reply_to_message_id: - msg["In-Reply-To"] = request.reply_to_message_id - msg["References"] = request.reply_to_message_id - - # Send via SMTP - if account["smtp_ssl"]: - smtp = smtplib.SMTP_SSL(account["smtp_host"], account["smtp_port"]) - else: - smtp = smtplib.SMTP(account["smtp_host"], account["smtp_port"]) - smtp.starttls() - - 
smtp.login(creds.email, creds.password) - - # All recipients - all_recipients = list(request.to) - if request.cc: - all_recipients.extend(request.cc) - if request.bcc: - all_recipients.extend(request.bcc) - - smtp.sendmail(account["email"], all_recipients, msg.as_string()) - smtp.quit() - - return EmailSendResult( - success=True, - message_id=msg.get("Message-ID"), - ) - - except Exception as e: - logger.error(f"Failed to send email: {e}") - return EmailSendResult(success=False, error=str(e)) - async def sync_all_accounts(self, user_id: str, tenant_id: Optional[str] = None) -> Dict[str, Any]: """ Sync all accounts for a user. diff --git a/klausur-service/backend/mail/aggregator_imap.py b/klausur-service/backend/mail/aggregator_imap.py new file mode 100644 index 0000000..9b5e259 --- /dev/null +++ b/klausur-service/backend/mail/aggregator_imap.py @@ -0,0 +1,322 @@ +""" +Mail Aggregator IMAP — IMAP connection, sync, email parsing. + +Extracted from aggregator.py for modularity. +""" + +import email +import logging +from typing import Optional, List, Dict, Any, Tuple +from datetime import datetime, timezone +from email.header import decode_header, make_header +from email.utils import parsedate_to_datetime, parseaddr + +from .mail_db import upsert_email, update_account_status, get_email_account + +logger = logging.getLogger(__name__) + + +class IMAPConnectionError(Exception): + """Raised when IMAP connection fails.""" + pass + + +class IMAPMixin: + """IMAP-related methods for MailAggregator. + + Provides connection testing, syncing, and email parsing. + Must be mixed into a class that has ``_credentials_service``. 
+ """ + + def _parse_folder_name(self, folder_response: bytes) -> str: + """Parse folder name from IMAP LIST response.""" + try: + # Format: '(\\HasNoChildren) "/" "INBOX"' + decoded = folder_response.decode("utf-8") if isinstance(folder_response, bytes) else folder_response + parts = decoded.rsplit('" "', 1) + if len(parts) == 2: + return parts[1].rstrip('"') + return decoded + except Exception: + return str(folder_response) + + async def test_imap_connection( + self, + imap_host: str, + imap_port: int, + imap_ssl: bool, + email_address: str, + password: str, + ) -> Tuple[bool, Optional[str], Optional[List[str]]]: + """Test IMAP connection. Returns (success, error, folders).""" + try: + import imaplib + + if imap_ssl: + imap = imaplib.IMAP4_SSL(imap_host, imap_port) + else: + imap = imaplib.IMAP4(imap_host, imap_port) + + imap.login(email_address, password) + + # List folders + folders_found = None + status, folders = imap.list() + if status == "OK": + folders_found = [ + self._parse_folder_name(f) for f in folders if f + ] + + imap.logout() + return True, None, folders_found + + except Exception as e: + logger.warning(f"IMAP test failed for {email_address}: {e}") + return False, f"IMAP Error: {str(e)}", None + + async def sync_account( + self, + account_id: str, + user_id: str, + max_emails: int = 100, + folders: Optional[List[str]] = None, + ) -> Tuple[int, int]: + """ + Sync emails from an IMAP account. 
+ + Args: + account_id: The account ID + user_id: The user ID + max_emails: Maximum emails to fetch + folders: Specific folders to sync (default: INBOX) + + Returns: + Tuple of (new_emails, total_emails) + """ + import imaplib + + account = await get_email_account(account_id, user_id) + if not account: + raise ValueError(f"Account not found: {account_id}") + + # Get credentials + vault_path = account.get("vault_path", "") + creds = await self._credentials_service.get_credentials(account_id, vault_path) + if not creds: + await update_account_status(account_id, "error", "Credentials not found") + raise IMAPConnectionError("Credentials not found") + + new_count = 0 + total_count = 0 + + try: + # Connect to IMAP + if account["imap_ssl"]: + imap = imaplib.IMAP4_SSL(account["imap_host"], account["imap_port"]) + else: + imap = imaplib.IMAP4(account["imap_host"], account["imap_port"]) + + imap.login(creds.email, creds.password) + + # Sync specified folders or just INBOX + sync_folders = folders or ["INBOX"] + + for folder in sync_folders: + try: + status, _ = imap.select(folder) + if status != "OK": + continue + + # Search for recent emails + status, messages = imap.search(None, "ALL") + if status != "OK": + continue + + message_ids = messages[0].split() + total_count += len(message_ids) + + # Fetch most recent emails + recent_ids = message_ids[-max_emails:] if len(message_ids) > max_emails else message_ids + + for msg_id in recent_ids: + try: + email_data = await self._fetch_and_store_email( + imap, msg_id, account_id, user_id, account["tenant_id"], folder + ) + if email_data: + new_count += 1 + except Exception as e: + logger.warning(f"Failed to fetch email {msg_id}: {e}") + + except Exception as e: + logger.warning(f"Failed to sync folder {folder}: {e}") + + imap.logout() + + # Update account status + await update_account_status( + account_id, + "active", + email_count=total_count, + unread_count=new_count, # Will be recalculated + ) + + return new_count, total_count + 
+ except Exception as e: + logger.error(f"Account sync failed: {e}") + await update_account_status(account_id, "error", str(e)) + raise IMAPConnectionError(str(e)) + + async def _fetch_and_store_email( + self, + imap, + msg_id: bytes, + account_id: str, + user_id: str, + tenant_id: str, + folder: str, + ) -> Optional[str]: + """Fetch a single email and store it in the database.""" + try: + status, msg_data = imap.fetch(msg_id, "(RFC822)") + if status != "OK" or not msg_data or not msg_data[0]: + return None + + raw_email = msg_data[0][1] + msg = email.message_from_bytes(raw_email) + + # Parse headers + message_id = msg.get("Message-ID", str(msg_id)) + subject = self._decode_header(msg.get("Subject", "")) + from_header = msg.get("From", "") + sender_name, sender_email = parseaddr(from_header) + sender_name = self._decode_header(sender_name) + + # Parse recipients + to_header = msg.get("To", "") + recipients = [addr[1] for addr in email.utils.getaddresses([to_header])] + + cc_header = msg.get("Cc", "") + cc = [addr[1] for addr in email.utils.getaddresses([cc_header])] + + # Parse dates + date_str = msg.get("Date") + try: + date_sent = parsedate_to_datetime(date_str) if date_str else datetime.now(timezone.utc) + except Exception: + date_sent = datetime.now(timezone.utc) + + date_received = datetime.now(timezone.utc) + + # Parse body + body_text, body_html, attachments = self._parse_body(msg) + + # Create preview + body_preview = (body_text[:200] + "...") if body_text and len(body_text) > 200 else body_text + + # Get headers dict + headers = {k: self._decode_header(v) for k, v in msg.items() if k not in ["Body"]} + + # Store in database + email_id = await upsert_email( + account_id=account_id, + user_id=user_id, + tenant_id=tenant_id, + message_id=message_id, + subject=subject, + sender_email=sender_email, + sender_name=sender_name, + recipients=recipients, + cc=cc, + body_preview=body_preview, + body_text=body_text, + body_html=body_html, + 
has_attachments=len(attachments) > 0, + attachments=attachments, + headers=headers, + folder=folder, + date_sent=date_sent, + date_received=date_received, + ) + + return email_id + + except Exception as e: + logger.error(f"Failed to parse email: {e}") + return None + + def _decode_header(self, header_value: str) -> str: + """Decode email header value.""" + if not header_value: + return "" + try: + decoded = decode_header(header_value) + return str(make_header(decoded)) + except Exception: + return str(header_value) + + def _parse_body(self, msg) -> Tuple[Optional[str], Optional[str], List[Dict]]: + """ + Parse email body and attachments. + + Returns: + Tuple of (body_text, body_html, attachments) + """ + body_text = None + body_html = None + attachments = [] + + if msg.is_multipart(): + for part in msg.walk(): + content_type = part.get_content_type() + content_disposition = str(part.get("Content-Disposition", "")) + + # Skip multipart containers + if content_type.startswith("multipart/"): + continue + + # Check for attachments + if "attachment" in content_disposition: + filename = part.get_filename() + if filename: + attachments.append({ + "filename": self._decode_header(filename), + "content_type": content_type, + "size": len(part.get_payload(decode=True) or b""), + }) + continue + + # Get body content + try: + payload = part.get_payload(decode=True) + charset = part.get_content_charset() or "utf-8" + + if payload: + text = payload.decode(charset, errors="replace") + + if content_type == "text/plain" and not body_text: + body_text = text + elif content_type == "text/html" and not body_html: + body_html = text + except Exception as e: + logger.debug(f"Failed to decode body part: {e}") + + else: + # Single part message + content_type = msg.get_content_type() + try: + payload = msg.get_payload(decode=True) + charset = msg.get_content_charset() or "utf-8" + + if payload: + text = payload.decode(charset, errors="replace") + + if content_type == "text/plain": + 
body_text = text + elif content_type == "text/html": + body_html = text + except Exception as e: + logger.debug(f"Failed to decode body: {e}") + + return body_text, body_html, attachments diff --git a/klausur-service/backend/mail/aggregator_smtp.py b/klausur-service/backend/mail/aggregator_smtp.py new file mode 100644 index 0000000..2f67a3e --- /dev/null +++ b/klausur-service/backend/mail/aggregator_smtp.py @@ -0,0 +1,131 @@ +""" +Mail Aggregator SMTP — email sending via SMTP. + +Extracted from aggregator.py for modularity. +""" + +import logging +import smtplib +from typing import Optional, List, Dict, Any +from email.mime.text import MIMEText +from email.mime.multipart import MIMEMultipart + +from .mail_db import get_email_account +from .models import EmailComposeRequest, EmailSendResult + +logger = logging.getLogger(__name__) + + +class SMTPConnectionError(Exception): + """Raised when SMTP connection fails.""" + pass + + +class SMTPMixin: + """SMTP-related methods for MailAggregator. + + Provides SMTP connection testing and email sending. + Must be mixed into a class that has ``_credentials_service``. + """ + + async def test_smtp_connection( + self, + smtp_host: str, + smtp_port: int, + smtp_ssl: bool, + email_address: str, + password: str, + ) -> tuple: + """Test SMTP connection. Returns (success, error).""" + try: + if smtp_ssl: + smtp = smtplib.SMTP_SSL(smtp_host, smtp_port) + else: + smtp = smtplib.SMTP(smtp_host, smtp_port) + smtp.starttls() + + smtp.login(email_address, password) + smtp.quit() + return True, None + + except Exception as e: + logger.warning(f"SMTP test failed for {email_address}: {e}") + return False, f"SMTP Error: {str(e)}" + + async def send_email( + self, + account_id: str, + user_id: str, + request: EmailComposeRequest, + ) -> EmailSendResult: + """ + Send an email via SMTP. 
+ + Args: + account_id: The account to send from + user_id: The user ID + request: The compose request with recipients and content + + Returns: + EmailSendResult with success status + """ + account = await get_email_account(account_id, user_id) + if not account: + return EmailSendResult(success=False, error="Account not found") + + # Verify the account_id matches + if request.account_id != account_id: + return EmailSendResult(success=False, error="Account mismatch") + + # Get credentials + vault_path = account.get("vault_path", "") + creds = await self._credentials_service.get_credentials(account_id, vault_path) + if not creds: + return EmailSendResult(success=False, error="Credentials not found") + + try: + # Create message + if request.is_html: + msg = MIMEMultipart("alternative") + msg.attach(MIMEText(request.body, "html")) + else: + msg = MIMEText(request.body, "plain") + + msg["Subject"] = request.subject + msg["From"] = account["email"] + msg["To"] = ", ".join(request.to) + + if request.cc: + msg["Cc"] = ", ".join(request.cc) + + if request.reply_to_message_id: + msg["In-Reply-To"] = request.reply_to_message_id + msg["References"] = request.reply_to_message_id + + # Send via SMTP + if account["smtp_ssl"]: + smtp = smtplib.SMTP_SSL(account["smtp_host"], account["smtp_port"]) + else: + smtp = smtplib.SMTP(account["smtp_host"], account["smtp_port"]) + smtp.starttls() + + smtp.login(creds.email, creds.password) + + # All recipients + all_recipients = list(request.to) + if request.cc: + all_recipients.extend(request.cc) + if request.bcc: + all_recipients.extend(request.bcc) + + smtp.sendmail(account["email"], all_recipients, msg.as_string()) + smtp.quit() + + return EmailSendResult( + success=True, + message_id=msg.get("Message-ID"), + ) + + except Exception as e: + logger.error(f"Failed to send email: {e}") + return EmailSendResult(success=False, error=str(e)) diff --git a/klausur-service/backend/nibis_ingestion.py b/klausur-service/backend/nibis_ingestion.py 
index 3fe22e0..63b2f23 100644 --- a/klausur-service/backend/nibis_ingestion.py +++ b/klausur-service/backend/nibis_ingestion.py @@ -10,12 +10,11 @@ Unterstützt: """ import os -import re import zipfile import hashlib import json from pathlib import Path -from typing import List, Dict, Optional, Tuple +from typing import List, Dict, Optional from dataclasses import dataclass, asdict from datetime import datetime import asyncio @@ -23,6 +22,7 @@ import asyncio # Local imports from eh_pipeline import chunk_text, generate_embeddings, extract_text_from_pdf, get_vector_size, EMBEDDING_BACKEND from qdrant_service import QdrantService +from nibis_parsers import parse_filename_old_format, parse_filename_new_format # Configuration DOCS_BASE_PATH = Path("/Users/benjaminadmin/projekte/breakpilot-pwa/docs") @@ -87,15 +87,6 @@ SUBJECT_MAPPING = { "gespfl": "Gesundheit-Pflege", } -# Niveau-Mapping -NIVEAU_MAPPING = { - "ea": "eA", # erhöhtes Anforderungsniveau - "ga": "gA", # grundlegendes Anforderungsniveau - "neuga": "gA (neu einsetzend)", - "neuea": "eA (neu einsetzend)", -} - - def compute_file_hash(file_path: Path) -> str: """Berechnet SHA-256 Hash einer Datei.""" sha256 = hashlib.sha256() @@ -135,103 +126,6 @@ def extract_zip_files(base_path: Path) -> List[Path]: return extracted -def parse_filename_old_format(filename: str, file_path: Path) -> Optional[Dict]: - """ - Parst alte Namenskonvention (2016, 2017): - - {Jahr}{Fach}{Niveau}Lehrer/{Jahr}{Fach}{Niveau}A{Nr}L.pdf - - Beispiel: 2016DeutschEALehrer/2016DeutschEAA1L.pdf - """ - # Pattern für Lehrer-Dateien - pattern = r"(\d{4})([A-Za-zäöüÄÖÜ]+)(EA|GA|NeuGA|NeuEA)(?:Lehrer)?.*?(?:A(\d+)|Aufg(\d+))?L?\.pdf$" - - match = re.search(pattern, filename, re.IGNORECASE) - if not match: - return None - - year = int(match.group(1)) - subject_raw = match.group(2).lower() - niveau = match.group(3).upper() - task_num = match.group(4) or match.group(5) - - # Prüfe ob es ein Lehrer-Dokument ist (EWH) - is_ewh = "lehrer" in 
str(file_path).lower() or filename.endswith("L.pdf") - - # Extrahiere Variante (Tech, Wirt, CAS, GTR, etc.) - variant = None - variant_patterns = ["Tech", "Wirt", "CAS", "GTR", "Pflicht", "BG", "mitExp", "ohneExp"] - for v in variant_patterns: - if v.lower() in str(file_path).lower(): - variant = v - break - - return { - "year": year, - "subject": subject_raw, - "niveau": NIVEAU_MAPPING.get(niveau.lower(), niveau), - "task_number": int(task_num) if task_num else None, - "doc_type": "EWH" if is_ewh else "Aufgabe", - "variant": variant, - } - - -def parse_filename_new_format(filename: str, file_path: Path) -> Optional[Dict]: - """ - Parst neue Namenskonvention (2024, 2025): - - {Jahr}_{Fach}_{niveau}_{Nr}_EWH.pdf - - Beispiel: 2025_Deutsch_eA_I_EWH.pdf - """ - # Pattern für neue Dateien - pattern = r"(\d{4})_([A-Za-zäöüÄÖÜ]+)(?:BG)?_(eA|gA)(?:_([IVX\d]+))?(?:_(.+))?\.pdf$" - - match = re.search(pattern, filename, re.IGNORECASE) - if not match: - return None - - year = int(match.group(1)) - subject_raw = match.group(2).lower() - niveau = match.group(3) - task_id = match.group(4) - suffix = match.group(5) or "" - - # Task-Nummer aus römischen Zahlen - task_num = None - if task_id: - roman_map = {"I": 1, "II": 2, "III": 3, "IV": 4, "V": 5} - task_num = roman_map.get(task_id) or (int(task_id) if task_id.isdigit() else None) - - # Dokumenttyp - is_ewh = "EWH" in filename or "ewh" in filename.lower() - - # Spezielle Dokumenttypen - doc_type = "EWH" if is_ewh else "Aufgabe" - if "Material" in suffix: - doc_type = "Material" - elif "GBU" in suffix: - doc_type = "GBU" - elif "Ergebnis" in suffix: - doc_type = "Ergebnis" - elif "Bewertungsbogen" in suffix: - doc_type = "Bewertungsbogen" - elif "HV" in suffix: - doc_type = "Hörverstehen" - elif "ME" in suffix: - doc_type = "Mediation" - - # BG Variante - variant = "BG" if "BG" in filename else None - if "mitExp" in str(file_path): - variant = "mitExp" - - return { - "year": year, - "subject": subject_raw, - "niveau": 
NIVEAU_MAPPING.get(niveau.lower(), niveau), - "task_number": task_num, - "doc_type": doc_type, - "variant": variant, - } - - def discover_documents(base_path: Path, ewh_only: bool = True) -> List[NiBiSDocument]: """ Findet alle relevanten Dokumente in den za-download Verzeichnissen. diff --git a/klausur-service/backend/nibis_parsers.py b/klausur-service/backend/nibis_parsers.py new file mode 100644 index 0000000..f65adff --- /dev/null +++ b/klausur-service/backend/nibis_parsers.py @@ -0,0 +1,113 @@ +""" +NiBiS Filename Parsers + +Parses old and new naming conventions for NiBiS Abitur documents. +""" + +import re +from typing import Dict, Optional + +# Niveau-Mapping +NIVEAU_MAPPING = { + "ea": "eA", # erhoehtes Anforderungsniveau + "ga": "gA", # grundlegendes Anforderungsniveau + "neuga": "gA (neu einsetzend)", + "neuea": "eA (neu einsetzend)", +} + + +def parse_filename_old_format(filename: str, file_path) -> Optional[Dict]: + """ + Parst alte Namenskonvention (2016, 2017): + - {Jahr}{Fach}{Niveau}Lehrer/{Jahr}{Fach}{Niveau}A{Nr}L.pdf + - Beispiel: 2016DeutschEALehrer/2016DeutschEAA1L.pdf + """ + # Pattern fuer Lehrer-Dateien + pattern = r"(\d{4})([A-Za-z\u00e4\u00f6\u00fc\u00c4\u00d6\u00dc]+)(EA|GA|NeuGA|NeuEA)(?:Lehrer)?.*?(?:A(\d+)|Aufg(\d+))?L?\.pdf$" + + match = re.search(pattern, filename, re.IGNORECASE) + if not match: + return None + + year = int(match.group(1)) + subject_raw = match.group(2).lower() + niveau = match.group(3).upper() + task_num = match.group(4) or match.group(5) + + # Pruefe ob es ein Lehrer-Dokument ist (EWH) + is_ewh = "lehrer" in str(file_path).lower() or filename.endswith("L.pdf") + + # Extrahiere Variante (Tech, Wirt, CAS, GTR, etc.) 
+ variant = None + variant_patterns = ["Tech", "Wirt", "CAS", "GTR", "Pflicht", "BG", "mitExp", "ohneExp"] + for v in variant_patterns: + if v.lower() in str(file_path).lower(): + variant = v + break + + return { + "year": year, + "subject": subject_raw, + "niveau": NIVEAU_MAPPING.get(niveau.lower(), niveau), + "task_number": int(task_num) if task_num else None, + "doc_type": "EWH" if is_ewh else "Aufgabe", + "variant": variant, + } + + +def parse_filename_new_format(filename: str, file_path) -> Optional[Dict]: + """ + Parst neue Namenskonvention (2024, 2025): + - {Jahr}_{Fach}_{niveau}_{Nr}_EWH.pdf + - Beispiel: 2025_Deutsch_eA_I_EWH.pdf + """ + # Pattern fuer neue Dateien + pattern = r"(\d{4})_([A-Za-z\u00e4\u00f6\u00fc\u00c4\u00d6\u00dc]+)(?:BG)?_(eA|gA)(?:_([IVX\d]+))?(?:_(.+))?\.pdf$" + + match = re.search(pattern, filename, re.IGNORECASE) + if not match: + return None + + year = int(match.group(1)) + subject_raw = match.group(2).lower() + niveau = match.group(3) + task_id = match.group(4) + suffix = match.group(5) or "" + + # Task-Nummer aus roemischen Zahlen + task_num = None + if task_id: + roman_map = {"I": 1, "II": 2, "III": 3, "IV": 4, "V": 5} + task_num = roman_map.get(task_id) or (int(task_id) if task_id.isdigit() else None) + + # Dokumenttyp + is_ewh = "EWH" in filename or "ewh" in filename.lower() + + # Spezielle Dokumenttypen + doc_type = "EWH" if is_ewh else "Aufgabe" + if "Material" in suffix: + doc_type = "Material" + elif "GBU" in suffix: + doc_type = "GBU" + elif "Ergebnis" in suffix: + doc_type = "Ergebnis" + elif "Bewertungsbogen" in suffix: + doc_type = "Bewertungsbogen" + elif "HV" in suffix: + doc_type = "Hoerverstehen" + elif "ME" in suffix: + doc_type = "Mediation" + + # BG Variante + variant = "BG" if "BG" in filename else None + if "mitExp" in str(file_path): + variant = "mitExp" + + return { + "year": year, + "subject": subject_raw, + "niveau": NIVEAU_MAPPING.get(niveau.lower(), niveau), + "task_number": task_num, + "doc_type": 
doc_type, + "variant": variant, + } diff --git a/klausur-service/backend/nru_worksheet_generator.py b/klausur-service/backend/nru_worksheet_generator.py index d75715b..e3a79ef 100644 --- a/klausur-service/backend/nru_worksheet_generator.py +++ b/klausur-service/backend/nru_worksheet_generator.py @@ -1,557 +1,26 @@ """ -NRU Worksheet Generator - Generate vocabulary worksheets in NRU format. +NRU Worksheet Generator — barrel re-export. -Format: -- Page 1 (Vokabeln): 3-column table - - Column 1: English vocabulary - - Column 2: Empty (child writes German translation) - - Column 3: Empty (child writes corrected English after parent review) - -- Page 2 (Lernsätze): Full-width table - - Row 1: German sentence (pre-filled) - - Row 2-3: Empty lines (child writes English translation) +All implementation split into: + nru_worksheet_models — data classes, entry separation + nru_worksheet_html — HTML generation + nru_worksheet_pdf — PDF generation Per scanned page, we generate 2 worksheet pages. """ -import io -import logging -from typing import List, Dict, Tuple -from dataclasses import dataclass +# Models +from nru_worksheet_models import ( # noqa: F401 + VocabEntry, + SentenceEntry, + separate_vocab_and_sentences, +) -logger = logging.getLogger(__name__) +# HTML generation +from nru_worksheet_html import ( # noqa: F401 + generate_nru_html, + generate_nru_worksheet_html, +) - -@dataclass -class VocabEntry: - english: str - german: str - source_page: int = 1 - - -@dataclass -class SentenceEntry: - german: str - english: str # For solution sheet - source_page: int = 1 - - -def separate_vocab_and_sentences(entries: List[Dict]) -> Tuple[List[VocabEntry], List[SentenceEntry]]: - """ - Separate vocabulary entries into single words/phrases and full sentences. - - Sentences are identified by: - - Ending with punctuation (. ! ?) 
- - Being longer than 40 characters - - Containing multiple words with capital letters mid-sentence - """ - vocab_list = [] - sentence_list = [] - - for entry in entries: - english = entry.get("english", "").strip() - german = entry.get("german", "").strip() - source_page = entry.get("source_page", 1) - - if not english or not german: - continue - - # Detect if this is a sentence - is_sentence = ( - english.endswith('.') or - english.endswith('!') or - english.endswith('?') or - len(english) > 50 or - (len(english.split()) > 5 and any(w[0].isupper() for w in english.split()[1:] if w)) - ) - - if is_sentence: - sentence_list.append(SentenceEntry( - german=german, - english=english, - source_page=source_page - )) - else: - vocab_list.append(VocabEntry( - english=english, - german=german, - source_page=source_page - )) - - return vocab_list, sentence_list - - -def generate_nru_html( - vocab_list: List[VocabEntry], - sentence_list: List[SentenceEntry], - page_number: int, - title: str = "Vokabeltest", - show_solutions: bool = False, - line_height_px: int = 28 -) -> str: - """ - Generate HTML for NRU-format worksheet. - - Returns HTML for 2 pages: - - Page 1: Vocabulary table (3 columns) - - Page 2: Sentence practice (full width) - """ - - # Filter by page - page_vocab = [v for v in vocab_list if v.source_page == page_number] - page_sentences = [s for s in sentence_list if s.source_page == page_number] - - html = f""" - - - - - - -""" - - # ========== PAGE 1: VOCABULARY TABLE ========== - if page_vocab: - html += f""" -
-
-

{title} - Vokabeln (Seite {page_number})

-
Name: _________________________ Datum: _____________
-
- - - - - - - - - - -""" - for v in page_vocab: - if show_solutions: - html += f""" - - - - - -""" - else: - html += f""" - - - - - -""" - - html += """ - -
EnglischDeutschKorrektur
{v.english}{v.german}
{v.english}
-
Vokabeln aus Unit
-
-""" - - # ========== PAGE 2: SENTENCE PRACTICE ========== - if page_sentences: - html += f""" -
-
-

{title} - Lernsaetze (Seite {page_number})

-
Name: _________________________ Datum: _____________
-
-""" - for s in page_sentences: - html += f""" - - - - -""" - if show_solutions: - html += f""" - - - - - - -""" - else: - html += """ - - - - - - -""" - html += """ -
{s.german}
{s.english}
-""" - - html += """ -
Lernsaetze aus Unit
-
-""" - - html += """ - - -""" - return html - - -def generate_nru_worksheet_html( - entries: List[Dict], - title: str = "Vokabeltest", - show_solutions: bool = False, - specific_pages: List[int] = None -) -> str: - """ - Generate complete NRU worksheet HTML for all pages. - - Args: - entries: List of vocabulary entries with source_page - title: Worksheet title - show_solutions: Whether to show answers - specific_pages: List of specific page numbers to include (1-indexed) - - Returns: - Complete HTML document - """ - # Separate into vocab and sentences - vocab_list, sentence_list = separate_vocab_and_sentences(entries) - - # Get unique page numbers - all_pages = set() - for v in vocab_list: - all_pages.add(v.source_page) - for s in sentence_list: - all_pages.add(s.source_page) - - # Filter to specific pages if requested - if specific_pages: - all_pages = all_pages.intersection(set(specific_pages)) - - pages_sorted = sorted(all_pages) - - logger.info(f"Generating NRU worksheet for pages {pages_sorted}") - logger.info(f"Total vocab: {len(vocab_list)}, Total sentences: {len(sentence_list)}") - - # Generate HTML for each page - combined_html = """ - - - - - - -""" - - for page_num in pages_sorted: - page_vocab = [v for v in vocab_list if v.source_page == page_num] - page_sentences = [s for s in sentence_list if s.source_page == page_num] - - # PAGE 1: VOCABULARY TABLE - if page_vocab: - combined_html += f""" -
-
-

{title} - Vokabeln (Seite {page_num})

-
Name: _________________________ Datum: _____________
-
- - - - - - - - - - -""" - for v in page_vocab: - if show_solutions: - combined_html += f""" - - - - - -""" - else: - combined_html += f""" - - - - - -""" - - combined_html += f""" - -
EnglischDeutschKorrektur
{v.english}{v.german}
{v.english}
-
{title} - Seite {page_num}
-
-""" - - # PAGE 2: SENTENCE PRACTICE - if page_sentences: - combined_html += f""" -
-
-

{title} - Lernsaetze (Seite {page_num})

-
Name: _________________________ Datum: _____________
-
-""" - for s in page_sentences: - combined_html += f""" - - - - -""" - if show_solutions: - combined_html += f""" - - - - - - -""" - else: - combined_html += """ - - - - - - -""" - combined_html += """ -
{s.german}
{s.english}
-""" - - combined_html += f""" -
{title} - Seite {page_num}
-
-""" - - combined_html += """ - - -""" - return combined_html - - -async def generate_nru_pdf(entries: List[Dict], title: str = "Vokabeltest", include_solutions: bool = True) -> Tuple[bytes, bytes]: - """ - Generate NRU worksheet PDFs. - - Returns: - Tuple of (worksheet_pdf_bytes, solution_pdf_bytes) - """ - from weasyprint import HTML - - # Generate worksheet HTML - worksheet_html = generate_nru_worksheet_html(entries, title, show_solutions=False) - worksheet_pdf = HTML(string=worksheet_html).write_pdf() - - # Generate solution HTML - solution_pdf = None - if include_solutions: - solution_html = generate_nru_worksheet_html(entries, title, show_solutions=True) - solution_pdf = HTML(string=solution_html).write_pdf() - - return worksheet_pdf, solution_pdf +# PDF generation +from nru_worksheet_pdf import generate_nru_pdf # noqa: F401 diff --git a/klausur-service/backend/nru_worksheet_html.py b/klausur-service/backend/nru_worksheet_html.py new file mode 100644 index 0000000..8d881de --- /dev/null +++ b/klausur-service/backend/nru_worksheet_html.py @@ -0,0 +1,466 @@ +""" +NRU Worksheet HTML — HTML generation for vocabulary worksheets. + +Extracted from nru_worksheet_generator.py for modularity. +""" + +import logging +from typing import List, Dict + +from nru_worksheet_models import VocabEntry, SentenceEntry, separate_vocab_and_sentences + +logger = logging.getLogger(__name__) + + +def generate_nru_html( + vocab_list: List[VocabEntry], + sentence_list: List[SentenceEntry], + page_number: int, + title: str = "Vokabeltest", + show_solutions: bool = False, + line_height_px: int = 28 +) -> str: + """ + Generate HTML for NRU-format worksheet. 
+ + Returns HTML for 2 pages: + - Page 1: Vocabulary table (3 columns) + - Page 2: Sentence practice (full width) + """ + + # Filter by page + page_vocab = [v for v in vocab_list if v.source_page == page_number] + page_sentences = [s for s in sentence_list if s.source_page == page_number] + + html = f""" + + + + + + +""" + + # ========== PAGE 1: VOCABULARY TABLE ========== + if page_vocab: + html += f""" +
+
+

{title} - Vokabeln (Seite {page_number})

+
Name: _________________________ Datum: _____________
+
+ + + + + + + + + + +""" + for v in page_vocab: + if show_solutions: + html += f""" + + + + + +""" + else: + html += f""" + + + + + +""" + + html += """ + +
EnglischDeutschKorrektur
{v.english}{v.german}
{v.english}
+
Vokabeln aus Unit
+
+""" + + # ========== PAGE 2: SENTENCE PRACTICE ========== + if page_sentences: + html += f""" +
+
+

{title} - Lernsaetze (Seite {page_number})

+
Name: _________________________ Datum: _____________
+
+""" + for s in page_sentences: + html += f""" + + + + +""" + if show_solutions: + html += f""" + + + + + + +""" + else: + html += """ + + + + + + +""" + html += """ +
{s.german}
{s.english}
+""" + + html += """ +
Lernsaetze aus Unit
+
+""" + + html += """ + + +""" + return html + + +def generate_nru_worksheet_html( + entries: List[Dict], + title: str = "Vokabeltest", + show_solutions: bool = False, + specific_pages: List[int] = None +) -> str: + """ + Generate complete NRU worksheet HTML for all pages. + + Args: + entries: List of vocabulary entries with source_page + title: Worksheet title + show_solutions: Whether to show answers + specific_pages: List of specific page numbers to include (1-indexed) + + Returns: + Complete HTML document + """ + # Separate into vocab and sentences + vocab_list, sentence_list = separate_vocab_and_sentences(entries) + + # Get unique page numbers + all_pages = set() + for v in vocab_list: + all_pages.add(v.source_page) + for s in sentence_list: + all_pages.add(s.source_page) + + # Filter to specific pages if requested + if specific_pages: + all_pages = all_pages.intersection(set(specific_pages)) + + pages_sorted = sorted(all_pages) + + logger.info(f"Generating NRU worksheet for pages {pages_sorted}") + logger.info(f"Total vocab: {len(vocab_list)}, Total sentences: {len(sentence_list)}") + + # Generate HTML for each page + combined_html = """ + + + + + + +""" + + for page_num in pages_sorted: + page_vocab = [v for v in vocab_list if v.source_page == page_num] + page_sentences = [s for s in sentence_list if s.source_page == page_num] + + # PAGE 1: VOCABULARY TABLE + if page_vocab: + combined_html += f""" +
+
+

{title} - Vokabeln (Seite {page_num})

+
Name: _________________________ Datum: _____________
+
+ + + + + + + + + + +""" + for v in page_vocab: + if show_solutions: + combined_html += f""" + + + + + +""" + else: + combined_html += f""" + + + + + +""" + + combined_html += f""" + +
EnglischDeutschKorrektur
{v.english}{v.german}
{v.english}
+
{title} - Seite {page_num}
+
+""" + + # PAGE 2: SENTENCE PRACTICE + if page_sentences: + combined_html += f""" +
+
+

{title} - Lernsaetze (Seite {page_num})

+
Name: _________________________ Datum: _____________
+
+""" + for s in page_sentences: + combined_html += f""" + + + + +""" + if show_solutions: + combined_html += f""" + + + + + + +""" + else: + combined_html += """ + + + + + + +""" + combined_html += """ +
{s.german}
{s.english}
+""" + + combined_html += f""" +
{title} - Seite {page_num}
+
+""" + + combined_html += """ + + +""" + return combined_html diff --git a/klausur-service/backend/nru_worksheet_models.py b/klausur-service/backend/nru_worksheet_models.py new file mode 100644 index 0000000..1276bfe --- /dev/null +++ b/klausur-service/backend/nru_worksheet_models.py @@ -0,0 +1,70 @@ +""" +NRU Worksheet Models — data classes and entry separation logic. + +Extracted from nru_worksheet_generator.py for modularity. +""" + +import logging +from typing import List, Dict, Tuple +from dataclasses import dataclass + +logger = logging.getLogger(__name__) + + +@dataclass +class VocabEntry: + english: str + german: str + source_page: int = 1 + + +@dataclass +class SentenceEntry: + german: str + english: str # For solution sheet + source_page: int = 1 + + +def separate_vocab_and_sentences(entries: List[Dict]) -> Tuple[List[VocabEntry], List[SentenceEntry]]: + """ + Separate vocabulary entries into single words/phrases and full sentences. + + Sentences are identified by: + - Ending with punctuation (. ! ?) 
+ - Being longer than 40 characters + - Containing multiple words with capital letters mid-sentence + """ + vocab_list = [] + sentence_list = [] + + for entry in entries: + english = entry.get("english", "").strip() + german = entry.get("german", "").strip() + source_page = entry.get("source_page", 1) + + if not english or not german: + continue + + # Detect if this is a sentence + is_sentence = ( + english.endswith('.') or + english.endswith('!') or + english.endswith('?') or + len(english) > 50 or + (len(english.split()) > 5 and any(w[0].isupper() for w in english.split()[1:] if w)) + ) + + if is_sentence: + sentence_list.append(SentenceEntry( + german=german, + english=english, + source_page=source_page + )) + else: + vocab_list.append(VocabEntry( + english=english, + german=german, + source_page=source_page + )) + + return vocab_list, sentence_list diff --git a/klausur-service/backend/nru_worksheet_pdf.py b/klausur-service/backend/nru_worksheet_pdf.py new file mode 100644 index 0000000..ceebc1a --- /dev/null +++ b/klausur-service/backend/nru_worksheet_pdf.py @@ -0,0 +1,31 @@ +""" +NRU Worksheet PDF — PDF generation using weasyprint. + +Extracted from nru_worksheet_generator.py for modularity. +""" + +from typing import List, Dict, Tuple + +from nru_worksheet_html import generate_nru_worksheet_html + + +async def generate_nru_pdf(entries: List[Dict], title: str = "Vokabeltest", include_solutions: bool = True) -> Tuple[bytes, bytes]: + """ + Generate NRU worksheet PDFs. 
+ + Returns: + Tuple of (worksheet_pdf_bytes, solution_pdf_bytes) + """ + from weasyprint import HTML + + # Generate worksheet HTML + worksheet_html = generate_nru_worksheet_html(entries, title, show_solutions=False) + worksheet_pdf = HTML(string=worksheet_html).write_pdf() + + # Generate solution HTML + solution_pdf = None + if include_solutions: + solution_html = generate_nru_worksheet_html(entries, title, show_solutions=True) + solution_pdf = HTML(string=solution_html).write_pdf() + + return worksheet_pdf, solution_pdf diff --git a/klausur-service/backend/ocr_pipeline_overlay_grid.py b/klausur-service/backend/ocr_pipeline_overlay_grid.py new file mode 100644 index 0000000..769ef0f --- /dev/null +++ b/klausur-service/backend/ocr_pipeline_overlay_grid.py @@ -0,0 +1,333 @@ +""" +Overlay rendering for columns, rows, and words (grid-based overlays). + +Extracted from ocr_pipeline_overlays.py for modularity. + +Lizenz: Apache 2.0 +DATENSCHUTZ: Alle Verarbeitung erfolgt lokal. +""" + +import logging +from typing import Any, Dict, List + +import cv2 +import numpy as np +from fastapi import HTTPException +from fastapi.responses import Response + +from ocr_pipeline_common import _get_base_image_png +from ocr_pipeline_session_store import get_session_db +from ocr_pipeline_rows import _draw_box_exclusion_overlay + +logger = logging.getLogger(__name__) + + +async def _get_columns_overlay(session_id: str) -> Response: + """Generate cropped (or dewarped) image with column borders drawn on it.""" + session = await get_session_db(session_id) + if not session: + raise HTTPException(status_code=404, detail=f"Session {session_id} not found") + + column_result = session.get("column_result") + if not column_result or not column_result.get("columns"): + raise HTTPException(status_code=404, detail="No column data available") + + # Load best available base image (cropped > dewarped > original) + base_png = await _get_base_image_png(session_id) + if not base_png: + raise 
HTTPException(status_code=404, detail="No base image available") + + arr = np.frombuffer(base_png, dtype=np.uint8) + img = cv2.imdecode(arr, cv2.IMREAD_COLOR) + if img is None: + raise HTTPException(status_code=500, detail="Failed to decode image") + + # Color map for region types (BGR) + colors = { + "column_en": (255, 180, 0), # Blue + "column_de": (0, 200, 0), # Green + "column_example": (0, 140, 255), # Orange + "column_text": (200, 200, 0), # Cyan/Turquoise + "page_ref": (200, 0, 200), # Purple + "column_marker": (0, 0, 220), # Red + "column_ignore": (180, 180, 180), # Light Gray + "header": (128, 128, 128), # Gray + "footer": (128, 128, 128), # Gray + "margin_top": (100, 100, 100), # Dark Gray + "margin_bottom": (100, 100, 100), # Dark Gray + } + + overlay = img.copy() + for col in column_result["columns"]: + x, y = col["x"], col["y"] + w, h = col["width"], col["height"] + color = colors.get(col.get("type", ""), (200, 200, 200)) + + # Semi-transparent fill + cv2.rectangle(overlay, (x, y), (x + w, y + h), color, -1) + + # Solid border + cv2.rectangle(img, (x, y), (x + w, y + h), color, 3) + + # Label with confidence + label = col.get("type", "unknown").replace("column_", "").upper() + conf = col.get("classification_confidence") + if conf is not None and conf < 1.0: + label = f"{label} {int(conf * 100)}%" + cv2.putText(img, label, (x + 10, y + 30), + cv2.FONT_HERSHEY_SIMPLEX, 0.8, color, 2) + + # Blend overlay at 20% opacity + cv2.addWeighted(overlay, 0.2, img, 0.8, 0, img) + + # Draw detected box boundaries as dashed rectangles + zones = column_result.get("zones") or [] + for zone in zones: + if zone.get("zone_type") == "box" and zone.get("box"): + box = zone["box"] + bx, by = box["x"], box["y"] + bw, bh = box["width"], box["height"] + box_color = (0, 200, 255) # Yellow (BGR) + # Draw dashed rectangle by drawing short line segments + dash_len = 15 + for edge_x in range(bx, bx + bw, dash_len * 2): + end_x = min(edge_x + dash_len, bx + bw) + cv2.line(img, 
(edge_x, by), (end_x, by), box_color, 2) + cv2.line(img, (edge_x, by + bh), (end_x, by + bh), box_color, 2) + for edge_y in range(by, by + bh, dash_len * 2): + end_y = min(edge_y + dash_len, by + bh) + cv2.line(img, (bx, edge_y), (bx, end_y), box_color, 2) + cv2.line(img, (bx + bw, edge_y), (bx + bw, end_y), box_color, 2) + cv2.putText(img, "BOX", (bx + 10, by + bh - 10), + cv2.FONT_HERSHEY_SIMPLEX, 0.7, box_color, 2) + + # Red semi-transparent overlay for box zones + _draw_box_exclusion_overlay(img, zones) + + success, result_png = cv2.imencode(".png", img) + if not success: + raise HTTPException(status_code=500, detail="Failed to encode overlay image") + + return Response(content=result_png.tobytes(), media_type="image/png") + + +async def _get_rows_overlay(session_id: str) -> Response: + """Generate cropped (or dewarped) image with row bands drawn on it.""" + session = await get_session_db(session_id) + if not session: + raise HTTPException(status_code=404, detail=f"Session {session_id} not found") + + row_result = session.get("row_result") + if not row_result or not row_result.get("rows"): + raise HTTPException(status_code=404, detail="No row data available") + + # Load best available base image (cropped > dewarped > original) + base_png = await _get_base_image_png(session_id) + if not base_png: + raise HTTPException(status_code=404, detail="No base image available") + + arr = np.frombuffer(base_png, dtype=np.uint8) + img = cv2.imdecode(arr, cv2.IMREAD_COLOR) + if img is None: + raise HTTPException(status_code=500, detail="Failed to decode image") + + # Color map for row types (BGR) + row_colors = { + "content": (255, 180, 0), # Blue + "header": (128, 128, 128), # Gray + "footer": (128, 128, 128), # Gray + "margin_top": (100, 100, 100), # Dark Gray + "margin_bottom": (100, 100, 100), # Dark Gray + } + + overlay = img.copy() + for row in row_result["rows"]: + x, y = row["x"], row["y"] + w, h = row["width"], row["height"] + row_type = row.get("row_type", 
"content") + color = row_colors.get(row_type, (200, 200, 200)) + + # Semi-transparent fill + cv2.rectangle(overlay, (x, y), (x + w, y + h), color, -1) + + # Solid border + cv2.rectangle(img, (x, y), (x + w, y + h), color, 2) + + # Label + idx = row.get("index", 0) + label = f"R{idx} {row_type.upper()}" + wc = row.get("word_count", 0) + if wc: + label = f"{label} ({wc}w)" + cv2.putText(img, label, (x + 5, y + 18), + cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1) + + # Blend overlay at 15% opacity + cv2.addWeighted(overlay, 0.15, img, 0.85, 0, img) + + # Draw zone separator lines if zones exist + column_result = session.get("column_result") or {} + zones = column_result.get("zones") or [] + if zones: + img_w_px = img.shape[1] + zone_color = (0, 200, 255) # Yellow (BGR) + dash_len = 20 + for zone in zones: + if zone.get("zone_type") == "box": + zy = zone["y"] + zh = zone["height"] + for line_y in [zy, zy + zh]: + for sx in range(0, img_w_px, dash_len * 2): + ex = min(sx + dash_len, img_w_px) + cv2.line(img, (sx, line_y), (ex, line_y), zone_color, 2) + + # Red semi-transparent overlay for box zones + _draw_box_exclusion_overlay(img, zones) + + success, result_png = cv2.imencode(".png", img) + if not success: + raise HTTPException(status_code=500, detail="Failed to encode overlay image") + + return Response(content=result_png.tobytes(), media_type="image/png") + + +async def _get_words_overlay(session_id: str) -> Response: + """Generate cropped (or dewarped) image with cell grid drawn on it.""" + session = await get_session_db(session_id) + if not session: + raise HTTPException(status_code=404, detail=f"Session {session_id} not found") + + word_result = session.get("word_result") + if not word_result: + raise HTTPException(status_code=404, detail="No word data available") + + # Support both new cell-based and legacy entry-based formats + cells = word_result.get("cells") + if not cells and not word_result.get("entries"): + raise HTTPException(status_code=404, detail="No word 
data available") + + # Load best available base image (cropped > dewarped > original) + base_png = await _get_base_image_png(session_id) + if not base_png: + raise HTTPException(status_code=404, detail="No base image available") + + arr = np.frombuffer(base_png, dtype=np.uint8) + img = cv2.imdecode(arr, cv2.IMREAD_COLOR) + if img is None: + raise HTTPException(status_code=500, detail="Failed to decode image") + + img_h, img_w = img.shape[:2] + + overlay = img.copy() + + if cells: + # New cell-based overlay: color by column index + col_palette = [ + (255, 180, 0), # Blue (BGR) + (0, 200, 0), # Green + (0, 140, 255), # Orange + (200, 100, 200), # Purple + (200, 200, 0), # Cyan + (100, 200, 200), # Yellow-ish + ] + + for cell in cells: + bbox = cell.get("bbox_px", {}) + cx = bbox.get("x", 0) + cy = bbox.get("y", 0) + cw = bbox.get("w", 0) + ch = bbox.get("h", 0) + if cw <= 0 or ch <= 0: + continue + + col_idx = cell.get("col_index", 0) + color = col_palette[col_idx % len(col_palette)] + + # Cell rectangle border + cv2.rectangle(img, (cx, cy), (cx + cw, cy + ch), color, 1) + # Semi-transparent fill + cv2.rectangle(overlay, (cx, cy), (cx + cw, cy + ch), color, -1) + + # Cell-ID label (top-left corner) + cell_id = cell.get("cell_id", "") + cv2.putText(img, cell_id, (cx + 2, cy + 10), + cv2.FONT_HERSHEY_SIMPLEX, 0.28, color, 1) + + # Text label (bottom of cell) + text = cell.get("text", "") + if text: + conf = cell.get("confidence", 0) + if conf >= 70: + text_color = (0, 180, 0) + elif conf >= 50: + text_color = (0, 180, 220) + else: + text_color = (0, 0, 220) + + label = text.replace('\n', ' ')[:30] + cv2.putText(img, label, (cx + 3, cy + ch - 4), + cv2.FONT_HERSHEY_SIMPLEX, 0.35, text_color, 1) + else: + # Legacy fallback: entry-based overlay (for old sessions) + column_result = session.get("column_result") + row_result = session.get("row_result") + col_colors = { + "column_en": (255, 180, 0), + "column_de": (0, 200, 0), + "column_example": (0, 140, 255), + } + + 
columns = [] + if column_result and column_result.get("columns"): + columns = [c for c in column_result["columns"] + if c.get("type", "").startswith("column_")] + + content_rows_data = [] + if row_result and row_result.get("rows"): + content_rows_data = [r for r in row_result["rows"] + if r.get("row_type") == "content"] + + for col in columns: + col_type = col.get("type", "") + color = col_colors.get(col_type, (200, 200, 200)) + cx, cw = col["x"], col["width"] + for row in content_rows_data: + ry, rh = row["y"], row["height"] + cv2.rectangle(img, (cx, ry), (cx + cw, ry + rh), color, 1) + cv2.rectangle(overlay, (cx, ry), (cx + cw, ry + rh), color, -1) + + entries = word_result["entries"] + entry_by_row: Dict[int, Dict] = {} + for entry in entries: + entry_by_row[entry.get("row_index", -1)] = entry + + for row_idx, row in enumerate(content_rows_data): + entry = entry_by_row.get(row_idx) + if not entry: + continue + conf = entry.get("confidence", 0) + text_color = (0, 180, 0) if conf >= 70 else (0, 180, 220) if conf >= 50 else (0, 0, 220) + ry, rh = row["y"], row["height"] + for col in columns: + col_type = col.get("type", "") + cx, cw = col["x"], col["width"] + field = {"column_en": "english", "column_de": "german", "column_example": "example"}.get(col_type, "") + text = entry.get(field, "") if field else "" + if text: + label = text.replace('\n', ' ')[:30] + cv2.putText(img, label, (cx + 3, ry + rh - 4), + cv2.FONT_HERSHEY_SIMPLEX, 0.35, text_color, 1) + + # Blend overlay at 10% opacity + cv2.addWeighted(overlay, 0.1, img, 0.9, 0, img) + + # Red semi-transparent overlay for box zones + column_result = session.get("column_result") or {} + zones = column_result.get("zones") or [] + _draw_box_exclusion_overlay(img, zones) + + success, result_png = cv2.imencode(".png", img) + if not success: + raise HTTPException(status_code=500, detail="Failed to encode overlay image") + + return Response(content=result_png.tobytes(), media_type="image/png") diff --git 
a/klausur-service/backend/ocr_pipeline_overlay_structure.py b/klausur-service/backend/ocr_pipeline_overlay_structure.py new file mode 100644 index 0000000..ad48382 --- /dev/null +++ b/klausur-service/backend/ocr_pipeline_overlay_structure.py @@ -0,0 +1,205 @@ +""" +Overlay rendering for structure detection (boxes, zones, colors, graphics). + +Extracted from ocr_pipeline_overlays.py for modularity. + +Lizenz: Apache 2.0 +DATENSCHUTZ: Alle Verarbeitung erfolgt lokal. +""" + +import logging +from typing import Any, Dict, List + +import cv2 +import numpy as np +from fastapi import HTTPException +from fastapi.responses import Response + +from ocr_pipeline_common import _get_base_image_png +from ocr_pipeline_session_store import get_session_db +from cv_color_detect import _COLOR_HEX, _COLOR_RANGES +from cv_box_detect import detect_boxes, split_page_into_zones + +logger = logging.getLogger(__name__) + + +async def _get_structure_overlay(session_id: str) -> Response: + """Generate overlay image showing detected boxes, zones, and color regions.""" + base_png = await _get_base_image_png(session_id) + if not base_png: + raise HTTPException(status_code=404, detail="No base image available") + + arr = np.frombuffer(base_png, dtype=np.uint8) + img = cv2.imdecode(arr, cv2.IMREAD_COLOR) + if img is None: + raise HTTPException(status_code=500, detail="Failed to decode image") + + h, w = img.shape[:2] + + # Get structure result (run detection if not cached) + session = await get_session_db(session_id) + structure = (session or {}).get("structure_result") + + if not structure: + # Run detection on-the-fly + margin = int(min(w, h) * 0.03) + content_x, content_y = margin, margin + content_w_px = w - 2 * margin + content_h_px = h - 2 * margin + boxes = detect_boxes(img, content_x, content_w_px, content_y, content_h_px) + zones = split_page_into_zones(content_x, content_y, content_w_px, content_h_px, boxes) + structure = { + "boxes": [ + {"x": b.x, "y": b.y, "w": b.width, "h": b.height, 
+ "confidence": b.confidence, "border_thickness": b.border_thickness} + for b in boxes + ], + "zones": [ + {"index": z.index, "zone_type": z.zone_type, + "y": z.y, "h": z.height, "x": z.x, "w": z.width} + for z in zones + ], + } + + overlay = img.copy() + + # --- Draw zone boundaries --- + zone_colors = { + "content": (200, 200, 200), # light gray + "box": (255, 180, 0), # blue-ish (BGR) + } + for zone in structure.get("zones", []): + zx = zone["x"] + zy = zone["y"] + zw = zone["w"] + zh = zone["h"] + color = zone_colors.get(zone["zone_type"], (200, 200, 200)) + + # Draw zone boundary as dashed line + dash_len = 12 + for edge_x in range(zx, zx + zw, dash_len * 2): + end_x = min(edge_x + dash_len, zx + zw) + cv2.line(img, (edge_x, zy), (end_x, zy), color, 1) + cv2.line(img, (edge_x, zy + zh), (end_x, zy + zh), color, 1) + + # Zone label + zone_label = f"Zone {zone['index']} ({zone['zone_type']})" + cv2.putText(img, zone_label, (zx + 5, zy + 15), + cv2.FONT_HERSHEY_SIMPLEX, 0.45, color, 1) + + # --- Draw detected boxes --- + # Color map for box backgrounds (BGR) + bg_hex_to_bgr = { + "#dc2626": (38, 38, 220), # red + "#2563eb": (235, 99, 37), # blue + "#16a34a": (74, 163, 22), # green + "#ea580c": (12, 88, 234), # orange + "#9333ea": (234, 51, 147), # purple + "#ca8a04": (4, 138, 202), # yellow + "#6b7280": (128, 114, 107), # gray + } + + for box_data in structure.get("boxes", []): + bx = box_data["x"] + by = box_data["y"] + bw = box_data["w"] + bh = box_data["h"] + conf = box_data.get("confidence", 0) + thickness = box_data.get("border_thickness", 0) + bg_hex = box_data.get("bg_color_hex", "#6b7280") + bg_name = box_data.get("bg_color_name", "") + + # Box fill color + fill_bgr = bg_hex_to_bgr.get(bg_hex, (128, 114, 107)) + + # Semi-transparent fill + cv2.rectangle(overlay, (bx, by), (bx + bw, by + bh), fill_bgr, -1) + + # Solid border + border_color = fill_bgr + cv2.rectangle(img, (bx, by), (bx + bw, by + bh), border_color, 3) + + # Label + label = f"BOX" + if 
bg_name and bg_name not in ("unknown", "white"): + label += f" ({bg_name})" + if thickness > 0: + label += f" border={thickness}px" + label += f" {int(conf * 100)}%" + cv2.putText(img, label, (bx + 8, by + 22), + cv2.FONT_HERSHEY_SIMPLEX, 0.55, (255, 255, 255), 2) + cv2.putText(img, label, (bx + 8, by + 22), + cv2.FONT_HERSHEY_SIMPLEX, 0.55, border_color, 1) + + # Blend overlay at 15% opacity + cv2.addWeighted(overlay, 0.15, img, 0.85, 0, img) + + # --- Draw color regions (HSV masks) --- + hsv = cv2.cvtColor( + cv2.imdecode(np.frombuffer(base_png, dtype=np.uint8), cv2.IMREAD_COLOR), + cv2.COLOR_BGR2HSV, + ) + color_bgr_map = { + "red": (0, 0, 255), + "orange": (0, 140, 255), + "yellow": (0, 200, 255), + "green": (0, 200, 0), + "blue": (255, 150, 0), + "purple": (200, 0, 200), + } + for color_name, ranges in _COLOR_RANGES.items(): + mask = np.zeros((h, w), dtype=np.uint8) + for lower, upper in ranges: + mask = cv2.bitwise_or(mask, cv2.inRange(hsv, lower, upper)) + # Only draw if there are significant colored pixels + if np.sum(mask > 0) < 100: + continue + # Draw colored contours + contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) + draw_color = color_bgr_map.get(color_name, (200, 200, 200)) + for cnt in contours: + area = cv2.contourArea(cnt) + if area < 20: + continue + cv2.drawContours(img, [cnt], -1, draw_color, 2) + + # --- Draw graphic elements --- + graphics_data = structure.get("graphics", []) + shape_icons = { + "image": "IMAGE", + "illustration": "ILLUST", + } + for gfx in graphics_data: + gx, gy = gfx["x"], gfx["y"] + gw, gh = gfx["w"], gfx["h"] + shape = gfx.get("shape", "icon") + color_hex = gfx.get("color_hex", "#6b7280") + conf = gfx.get("confidence", 0) + + # Pick draw color based on element color (BGR) + gfx_bgr = bg_hex_to_bgr.get(color_hex, (128, 114, 107)) + + # Draw bounding box (dashed style via short segments) + dash = 6 + for seg_x in range(gx, gx + gw, dash * 2): + end_x = min(seg_x + dash, gx + gw) + 
cv2.line(img, (seg_x, gy), (end_x, gy), gfx_bgr, 2) + cv2.line(img, (seg_x, gy + gh), (end_x, gy + gh), gfx_bgr, 2) + for seg_y in range(gy, gy + gh, dash * 2): + end_y = min(seg_y + dash, gy + gh) + cv2.line(img, (gx, seg_y), (gx, end_y), gfx_bgr, 2) + cv2.line(img, (gx + gw, seg_y), (gx + gw, end_y), gfx_bgr, 2) + + # Label + icon = shape_icons.get(shape, shape.upper()[:5]) + label = f"{icon} {int(conf * 100)}%" + # White background for readability + (tw, th), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.4, 1) + lx = gx + 2 + ly = max(gy - 4, th + 4) + cv2.rectangle(img, (lx - 1, ly - th - 2), (lx + tw + 2, ly + 3), (255, 255, 255), -1) + cv2.putText(img, label, (lx, ly), cv2.FONT_HERSHEY_SIMPLEX, 0.4, gfx_bgr, 1) + + # Encode result + _, png_buf = cv2.imencode(".png", img) + return Response(content=png_buf.tobytes(), media_type="image/png") diff --git a/klausur-service/backend/ocr_pipeline_overlays.py b/klausur-service/backend/ocr_pipeline_overlays.py index 2789557..7a30f9b 100644 --- a/klausur-service/backend/ocr_pipeline_overlays.py +++ b/klausur-service/backend/ocr_pipeline_overlays.py @@ -1,34 +1,23 @@ """ -Overlay image rendering for OCR pipeline. +Overlay image rendering for OCR pipeline — barrel re-export. -Generates visual overlays for structure, columns, rows, and words -detection results. +All implementation split into: + ocr_pipeline_overlay_structure — structure overlay (boxes, zones, colors, graphics) + ocr_pipeline_overlay_grid — columns, rows, words overlays Lizenz: Apache 2.0 DATENSCHUTZ: Alle Verarbeitung erfolgt lokal. 
""" -import logging -from dataclasses import asdict -from typing import Any, Dict, List, Optional - -import cv2 -import numpy as np from fastapi import HTTPException from fastapi.responses import Response -from ocr_pipeline_common import ( - _cache, - _get_base_image_png, - _load_session_to_cache, - _get_cached, +from ocr_pipeline_overlay_structure import _get_structure_overlay # noqa: F401 +from ocr_pipeline_overlay_grid import ( # noqa: F401 + _get_columns_overlay, + _get_rows_overlay, + _get_words_overlay, ) -from ocr_pipeline_session_store import get_session_db, get_session_image -from cv_color_detect import _COLOR_HEX, _COLOR_RANGES -from cv_box_detect import detect_boxes, split_page_into_zones -from ocr_pipeline_rows import _draw_box_exclusion_overlay - -logger = logging.getLogger(__name__) async def render_overlay(overlay_type: str, session_id: str) -> Response: @@ -43,505 +32,3 @@ async def render_overlay(overlay_type: str, session_id: str) -> Response: return await _get_words_overlay(session_id) else: raise HTTPException(status_code=400, detail=f"Unknown overlay type: {overlay_type}") - - -async def _get_structure_overlay(session_id: str) -> Response: - """Generate overlay image showing detected boxes, zones, and color regions.""" - base_png = await _get_base_image_png(session_id) - if not base_png: - raise HTTPException(status_code=404, detail="No base image available") - - arr = np.frombuffer(base_png, dtype=np.uint8) - img = cv2.imdecode(arr, cv2.IMREAD_COLOR) - if img is None: - raise HTTPException(status_code=500, detail="Failed to decode image") - - h, w = img.shape[:2] - - # Get structure result (run detection if not cached) - session = await get_session_db(session_id) - structure = (session or {}).get("structure_result") - - if not structure: - # Run detection on-the-fly - margin = int(min(w, h) * 0.03) - content_x, content_y = margin, margin - content_w_px = w - 2 * margin - content_h_px = h - 2 * margin - boxes = detect_boxes(img, content_x, 
content_w_px, content_y, content_h_px) - zones = split_page_into_zones(content_x, content_y, content_w_px, content_h_px, boxes) - structure = { - "boxes": [ - {"x": b.x, "y": b.y, "w": b.width, "h": b.height, - "confidence": b.confidence, "border_thickness": b.border_thickness} - for b in boxes - ], - "zones": [ - {"index": z.index, "zone_type": z.zone_type, - "y": z.y, "h": z.height, "x": z.x, "w": z.width} - for z in zones - ], - } - - overlay = img.copy() - - # --- Draw zone boundaries --- - zone_colors = { - "content": (200, 200, 200), # light gray - "box": (255, 180, 0), # blue-ish (BGR) - } - for zone in structure.get("zones", []): - zx = zone["x"] - zy = zone["y"] - zw = zone["w"] - zh = zone["h"] - color = zone_colors.get(zone["zone_type"], (200, 200, 200)) - - # Draw zone boundary as dashed line - dash_len = 12 - for edge_x in range(zx, zx + zw, dash_len * 2): - end_x = min(edge_x + dash_len, zx + zw) - cv2.line(img, (edge_x, zy), (end_x, zy), color, 1) - cv2.line(img, (edge_x, zy + zh), (end_x, zy + zh), color, 1) - - # Zone label - zone_label = f"Zone {zone['index']} ({zone['zone_type']})" - cv2.putText(img, zone_label, (zx + 5, zy + 15), - cv2.FONT_HERSHEY_SIMPLEX, 0.45, color, 1) - - # --- Draw detected boxes --- - # Color map for box backgrounds (BGR) - bg_hex_to_bgr = { - "#dc2626": (38, 38, 220), # red - "#2563eb": (235, 99, 37), # blue - "#16a34a": (74, 163, 22), # green - "#ea580c": (12, 88, 234), # orange - "#9333ea": (234, 51, 147), # purple - "#ca8a04": (4, 138, 202), # yellow - "#6b7280": (128, 114, 107), # gray - } - - for box_data in structure.get("boxes", []): - bx = box_data["x"] - by = box_data["y"] - bw = box_data["w"] - bh = box_data["h"] - conf = box_data.get("confidence", 0) - thickness = box_data.get("border_thickness", 0) - bg_hex = box_data.get("bg_color_hex", "#6b7280") - bg_name = box_data.get("bg_color_name", "") - - # Box fill color - fill_bgr = bg_hex_to_bgr.get(bg_hex, (128, 114, 107)) - - # Semi-transparent fill - 
cv2.rectangle(overlay, (bx, by), (bx + bw, by + bh), fill_bgr, -1) - - # Solid border - border_color = fill_bgr - cv2.rectangle(img, (bx, by), (bx + bw, by + bh), border_color, 3) - - # Label - label = f"BOX" - if bg_name and bg_name not in ("unknown", "white"): - label += f" ({bg_name})" - if thickness > 0: - label += f" border={thickness}px" - label += f" {int(conf * 100)}%" - cv2.putText(img, label, (bx + 8, by + 22), - cv2.FONT_HERSHEY_SIMPLEX, 0.55, (255, 255, 255), 2) - cv2.putText(img, label, (bx + 8, by + 22), - cv2.FONT_HERSHEY_SIMPLEX, 0.55, border_color, 1) - - # Blend overlay at 15% opacity - cv2.addWeighted(overlay, 0.15, img, 0.85, 0, img) - - # --- Draw color regions (HSV masks) --- - hsv = cv2.cvtColor( - cv2.imdecode(np.frombuffer(base_png, dtype=np.uint8), cv2.IMREAD_COLOR), - cv2.COLOR_BGR2HSV, - ) - color_bgr_map = { - "red": (0, 0, 255), - "orange": (0, 140, 255), - "yellow": (0, 200, 255), - "green": (0, 200, 0), - "blue": (255, 150, 0), - "purple": (200, 0, 200), - } - for color_name, ranges in _COLOR_RANGES.items(): - mask = np.zeros((h, w), dtype=np.uint8) - for lower, upper in ranges: - mask = cv2.bitwise_or(mask, cv2.inRange(hsv, lower, upper)) - # Only draw if there are significant colored pixels - if np.sum(mask > 0) < 100: - continue - # Draw colored contours - contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) - draw_color = color_bgr_map.get(color_name, (200, 200, 200)) - for cnt in contours: - area = cv2.contourArea(cnt) - if area < 20: - continue - cv2.drawContours(img, [cnt], -1, draw_color, 2) - - # --- Draw graphic elements --- - graphics_data = structure.get("graphics", []) - shape_icons = { - "image": "IMAGE", - "illustration": "ILLUST", - } - for gfx in graphics_data: - gx, gy = gfx["x"], gfx["y"] - gw, gh = gfx["w"], gfx["h"] - shape = gfx.get("shape", "icon") - color_hex = gfx.get("color_hex", "#6b7280") - conf = gfx.get("confidence", 0) - - # Pick draw color based on element color (BGR) - 
gfx_bgr = bg_hex_to_bgr.get(color_hex, (128, 114, 107)) - - # Draw bounding box (dashed style via short segments) - dash = 6 - for seg_x in range(gx, gx + gw, dash * 2): - end_x = min(seg_x + dash, gx + gw) - cv2.line(img, (seg_x, gy), (end_x, gy), gfx_bgr, 2) - cv2.line(img, (seg_x, gy + gh), (end_x, gy + gh), gfx_bgr, 2) - for seg_y in range(gy, gy + gh, dash * 2): - end_y = min(seg_y + dash, gy + gh) - cv2.line(img, (gx, seg_y), (gx, end_y), gfx_bgr, 2) - cv2.line(img, (gx + gw, seg_y), (gx + gw, end_y), gfx_bgr, 2) - - # Label - icon = shape_icons.get(shape, shape.upper()[:5]) - label = f"{icon} {int(conf * 100)}%" - # White background for readability - (tw, th), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.4, 1) - lx = gx + 2 - ly = max(gy - 4, th + 4) - cv2.rectangle(img, (lx - 1, ly - th - 2), (lx + tw + 2, ly + 3), (255, 255, 255), -1) - cv2.putText(img, label, (lx, ly), cv2.FONT_HERSHEY_SIMPLEX, 0.4, gfx_bgr, 1) - - # Encode result - _, png_buf = cv2.imencode(".png", img) - return Response(content=png_buf.tobytes(), media_type="image/png") - - - -async def _get_columns_overlay(session_id: str) -> Response: - """Generate cropped (or dewarped) image with column borders drawn on it.""" - session = await get_session_db(session_id) - if not session: - raise HTTPException(status_code=404, detail=f"Session {session_id} not found") - - column_result = session.get("column_result") - if not column_result or not column_result.get("columns"): - raise HTTPException(status_code=404, detail="No column data available") - - # Load best available base image (cropped > dewarped > original) - base_png = await _get_base_image_png(session_id) - if not base_png: - raise HTTPException(status_code=404, detail="No base image available") - - arr = np.frombuffer(base_png, dtype=np.uint8) - img = cv2.imdecode(arr, cv2.IMREAD_COLOR) - if img is None: - raise HTTPException(status_code=500, detail="Failed to decode image") - - # Color map for region types (BGR) - colors = { - 
"column_en": (255, 180, 0), # Blue - "column_de": (0, 200, 0), # Green - "column_example": (0, 140, 255), # Orange - "column_text": (200, 200, 0), # Cyan/Turquoise - "page_ref": (200, 0, 200), # Purple - "column_marker": (0, 0, 220), # Red - "column_ignore": (180, 180, 180), # Light Gray - "header": (128, 128, 128), # Gray - "footer": (128, 128, 128), # Gray - "margin_top": (100, 100, 100), # Dark Gray - "margin_bottom": (100, 100, 100), # Dark Gray - } - - overlay = img.copy() - for col in column_result["columns"]: - x, y = col["x"], col["y"] - w, h = col["width"], col["height"] - color = colors.get(col.get("type", ""), (200, 200, 200)) - - # Semi-transparent fill - cv2.rectangle(overlay, (x, y), (x + w, y + h), color, -1) - - # Solid border - cv2.rectangle(img, (x, y), (x + w, y + h), color, 3) - - # Label with confidence - label = col.get("type", "unknown").replace("column_", "").upper() - conf = col.get("classification_confidence") - if conf is not None and conf < 1.0: - label = f"{label} {int(conf * 100)}%" - cv2.putText(img, label, (x + 10, y + 30), - cv2.FONT_HERSHEY_SIMPLEX, 0.8, color, 2) - - # Blend overlay at 20% opacity - cv2.addWeighted(overlay, 0.2, img, 0.8, 0, img) - - # Draw detected box boundaries as dashed rectangles - zones = column_result.get("zones") or [] - for zone in zones: - if zone.get("zone_type") == "box" and zone.get("box"): - box = zone["box"] - bx, by = box["x"], box["y"] - bw, bh = box["width"], box["height"] - box_color = (0, 200, 255) # Yellow (BGR) - # Draw dashed rectangle by drawing short line segments - dash_len = 15 - for edge_x in range(bx, bx + bw, dash_len * 2): - end_x = min(edge_x + dash_len, bx + bw) - cv2.line(img, (edge_x, by), (end_x, by), box_color, 2) - cv2.line(img, (edge_x, by + bh), (end_x, by + bh), box_color, 2) - for edge_y in range(by, by + bh, dash_len * 2): - end_y = min(edge_y + dash_len, by + bh) - cv2.line(img, (bx, edge_y), (bx, end_y), box_color, 2) - cv2.line(img, (bx + bw, edge_y), (bx + bw, end_y), 
box_color, 2) - cv2.putText(img, "BOX", (bx + 10, by + bh - 10), - cv2.FONT_HERSHEY_SIMPLEX, 0.7, box_color, 2) - - # Red semi-transparent overlay for box zones - _draw_box_exclusion_overlay(img, zones) - - success, result_png = cv2.imencode(".png", img) - if not success: - raise HTTPException(status_code=500, detail="Failed to encode overlay image") - - return Response(content=result_png.tobytes(), media_type="image/png") - - -# --------------------------------------------------------------------------- -# Row Detection Endpoints -# --------------------------------------------------------------------------- - - - -async def _get_rows_overlay(session_id: str) -> Response: - """Generate cropped (or dewarped) image with row bands drawn on it.""" - session = await get_session_db(session_id) - if not session: - raise HTTPException(status_code=404, detail=f"Session {session_id} not found") - - row_result = session.get("row_result") - if not row_result or not row_result.get("rows"): - raise HTTPException(status_code=404, detail="No row data available") - - # Load best available base image (cropped > dewarped > original) - base_png = await _get_base_image_png(session_id) - if not base_png: - raise HTTPException(status_code=404, detail="No base image available") - - arr = np.frombuffer(base_png, dtype=np.uint8) - img = cv2.imdecode(arr, cv2.IMREAD_COLOR) - if img is None: - raise HTTPException(status_code=500, detail="Failed to decode image") - - # Color map for row types (BGR) - row_colors = { - "content": (255, 180, 0), # Blue - "header": (128, 128, 128), # Gray - "footer": (128, 128, 128), # Gray - "margin_top": (100, 100, 100), # Dark Gray - "margin_bottom": (100, 100, 100), # Dark Gray - } - - overlay = img.copy() - for row in row_result["rows"]: - x, y = row["x"], row["y"] - w, h = row["width"], row["height"] - row_type = row.get("row_type", "content") - color = row_colors.get(row_type, (200, 200, 200)) - - # Semi-transparent fill - cv2.rectangle(overlay, (x, y), (x 
+ w, y + h), color, -1) - - # Solid border - cv2.rectangle(img, (x, y), (x + w, y + h), color, 2) - - # Label - idx = row.get("index", 0) - label = f"R{idx} {row_type.upper()}" - wc = row.get("word_count", 0) - if wc: - label = f"{label} ({wc}w)" - cv2.putText(img, label, (x + 5, y + 18), - cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1) - - # Blend overlay at 15% opacity - cv2.addWeighted(overlay, 0.15, img, 0.85, 0, img) - - # Draw zone separator lines if zones exist - column_result = session.get("column_result") or {} - zones = column_result.get("zones") or [] - if zones: - img_w_px = img.shape[1] - zone_color = (0, 200, 255) # Yellow (BGR) - dash_len = 20 - for zone in zones: - if zone.get("zone_type") == "box": - zy = zone["y"] - zh = zone["height"] - for line_y in [zy, zy + zh]: - for sx in range(0, img_w_px, dash_len * 2): - ex = min(sx + dash_len, img_w_px) - cv2.line(img, (sx, line_y), (ex, line_y), zone_color, 2) - - # Red semi-transparent overlay for box zones - _draw_box_exclusion_overlay(img, zones) - - success, result_png = cv2.imencode(".png", img) - if not success: - raise HTTPException(status_code=500, detail="Failed to encode overlay image") - - return Response(content=result_png.tobytes(), media_type="image/png") - - - -async def _get_words_overlay(session_id: str) -> Response: - """Generate cropped (or dewarped) image with cell grid drawn on it.""" - session = await get_session_db(session_id) - if not session: - raise HTTPException(status_code=404, detail=f"Session {session_id} not found") - - word_result = session.get("word_result") - if not word_result: - raise HTTPException(status_code=404, detail="No word data available") - - # Support both new cell-based and legacy entry-based formats - cells = word_result.get("cells") - if not cells and not word_result.get("entries"): - raise HTTPException(status_code=404, detail="No word data available") - - # Load best available base image (cropped > dewarped > original) - base_png = await 
_get_base_image_png(session_id) - if not base_png: - raise HTTPException(status_code=404, detail="No base image available") - - arr = np.frombuffer(base_png, dtype=np.uint8) - img = cv2.imdecode(arr, cv2.IMREAD_COLOR) - if img is None: - raise HTTPException(status_code=500, detail="Failed to decode image") - - img_h, img_w = img.shape[:2] - - overlay = img.copy() - - if cells: - # New cell-based overlay: color by column index - col_palette = [ - (255, 180, 0), # Blue (BGR) - (0, 200, 0), # Green - (0, 140, 255), # Orange - (200, 100, 200), # Purple - (200, 200, 0), # Cyan - (100, 200, 200), # Yellow-ish - ] - - for cell in cells: - bbox = cell.get("bbox_px", {}) - cx = bbox.get("x", 0) - cy = bbox.get("y", 0) - cw = bbox.get("w", 0) - ch = bbox.get("h", 0) - if cw <= 0 or ch <= 0: - continue - - col_idx = cell.get("col_index", 0) - color = col_palette[col_idx % len(col_palette)] - - # Cell rectangle border - cv2.rectangle(img, (cx, cy), (cx + cw, cy + ch), color, 1) - # Semi-transparent fill - cv2.rectangle(overlay, (cx, cy), (cx + cw, cy + ch), color, -1) - - # Cell-ID label (top-left corner) - cell_id = cell.get("cell_id", "") - cv2.putText(img, cell_id, (cx + 2, cy + 10), - cv2.FONT_HERSHEY_SIMPLEX, 0.28, color, 1) - - # Text label (bottom of cell) - text = cell.get("text", "") - if text: - conf = cell.get("confidence", 0) - if conf >= 70: - text_color = (0, 180, 0) - elif conf >= 50: - text_color = (0, 180, 220) - else: - text_color = (0, 0, 220) - - label = text.replace('\n', ' ')[:30] - cv2.putText(img, label, (cx + 3, cy + ch - 4), - cv2.FONT_HERSHEY_SIMPLEX, 0.35, text_color, 1) - else: - # Legacy fallback: entry-based overlay (for old sessions) - column_result = session.get("column_result") - row_result = session.get("row_result") - col_colors = { - "column_en": (255, 180, 0), - "column_de": (0, 200, 0), - "column_example": (0, 140, 255), - } - - columns = [] - if column_result and column_result.get("columns"): - columns = [c for c in 
column_result["columns"] - if c.get("type", "").startswith("column_")] - - content_rows_data = [] - if row_result and row_result.get("rows"): - content_rows_data = [r for r in row_result["rows"] - if r.get("row_type") == "content"] - - for col in columns: - col_type = col.get("type", "") - color = col_colors.get(col_type, (200, 200, 200)) - cx, cw = col["x"], col["width"] - for row in content_rows_data: - ry, rh = row["y"], row["height"] - cv2.rectangle(img, (cx, ry), (cx + cw, ry + rh), color, 1) - cv2.rectangle(overlay, (cx, ry), (cx + cw, ry + rh), color, -1) - - entries = word_result["entries"] - entry_by_row: Dict[int, Dict] = {} - for entry in entries: - entry_by_row[entry.get("row_index", -1)] = entry - - for row_idx, row in enumerate(content_rows_data): - entry = entry_by_row.get(row_idx) - if not entry: - continue - conf = entry.get("confidence", 0) - text_color = (0, 180, 0) if conf >= 70 else (0, 180, 220) if conf >= 50 else (0, 0, 220) - ry, rh = row["y"], row["height"] - for col in columns: - col_type = col.get("type", "") - cx, cw = col["x"], col["width"] - field = {"column_en": "english", "column_de": "german", "column_example": "example"}.get(col_type, "") - text = entry.get(field, "") if field else "" - if text: - label = text.replace('\n', ' ')[:30] - cv2.putText(img, label, (cx + 3, ry + rh - 4), - cv2.FONT_HERSHEY_SIMPLEX, 0.35, text_color, 1) - - # Blend overlay at 10% opacity - cv2.addWeighted(overlay, 0.1, img, 0.9, 0, img) - - # Red semi-transparent overlay for box zones - column_result = session.get("column_result") or {} - zones = column_result.get("zones") or [] - _draw_box_exclusion_overlay(img, zones) - - success, result_png = cv2.imencode(".png", img) - if not success: - raise HTTPException(status_code=500, detail="Failed to encode overlay image") - - return Response(content=result_png.tobytes(), media_type="image/png") - diff --git a/klausur-service/backend/ocr_pipeline_regression.py 
b/klausur-service/backend/ocr_pipeline_regression.py index b6e09a0..5c8ff89 100644 --- a/klausur-service/backend/ocr_pipeline_regression.py +++ b/klausur-service/backend/ocr_pipeline_regression.py @@ -1,607 +1,22 @@ """ -OCR Pipeline Regression Tests — Ground Truth comparison system. +OCR Pipeline Regression Tests — barrel re-export. -Allows marking sessions as "ground truth" and re-running build_grid() -to detect regressions after code changes. +All implementation split into: + ocr_pipeline_regression_helpers — DB persistence, snapshot, comparison + ocr_pipeline_regression_endpoints — FastAPI routes Lizenz: Apache 2.0 DATENSCHUTZ: Alle Verarbeitung erfolgt lokal. """ -import json -import logging -import os -import time -import uuid -from datetime import datetime, timezone -from typing import Any, Dict, List, Optional - -from fastapi import APIRouter, HTTPException, Query - -from grid_editor_api import _build_grid_core -from ocr_pipeline_session_store import ( - get_pool, - get_session_db, - list_ground_truth_sessions_db, - update_session_db, +# Helpers (used by grid_editor_api_grid.py) +from ocr_pipeline_regression_helpers import ( # noqa: F401 + _init_regression_table, + _persist_regression_run, + _extract_cells_for_comparison, + _build_reference_snapshot, + compare_grids, ) -logger = logging.getLogger(__name__) - -router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["regression"]) - - -# --------------------------------------------------------------------------- -# DB persistence for regression runs -# --------------------------------------------------------------------------- - -async def _init_regression_table(): - """Ensure regression_runs table exists (idempotent).""" - pool = await get_pool() - async with pool.acquire() as conn: - migration_path = os.path.join( - os.path.dirname(__file__), - "migrations/008_regression_runs.sql", - ) - if os.path.exists(migration_path): - with open(migration_path, "r") as f: - sql = f.read() - await conn.execute(sql) - - 
-async def _persist_regression_run( - status: str, - summary: dict, - results: list, - duration_ms: int, - triggered_by: str = "manual", -) -> str: - """Save a regression run to the database. Returns the run ID.""" - try: - await _init_regression_table() - pool = await get_pool() - run_id = str(uuid.uuid4()) - async with pool.acquire() as conn: - await conn.execute( - """ - INSERT INTO regression_runs - (id, status, total, passed, failed, errors, duration_ms, results, triggered_by) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8::jsonb, $9) - """, - run_id, - status, - summary.get("total", 0), - summary.get("passed", 0), - summary.get("failed", 0), - summary.get("errors", 0), - duration_ms, - json.dumps(results), - triggered_by, - ) - logger.info("Regression run %s persisted: %s", run_id, status) - return run_id - except Exception as e: - logger.warning("Failed to persist regression run: %s", e) - return "" - - -# --------------------------------------------------------------------------- -# Helpers -# --------------------------------------------------------------------------- - -def _extract_cells_for_comparison(grid_result: dict) -> List[Dict[str, Any]]: - """Extract a flat list of cells from a grid_editor_result for comparison. - - Only keeps fields relevant for comparison: cell_id, row_index, col_index, - col_type, text. Ignores confidence, bbox, word_boxes, duration, is_bold. 
- """ - cells = [] - for zone in grid_result.get("zones", []): - for cell in zone.get("cells", []): - cells.append({ - "cell_id": cell.get("cell_id", ""), - "row_index": cell.get("row_index"), - "col_index": cell.get("col_index"), - "col_type": cell.get("col_type", ""), - "text": cell.get("text", ""), - }) - return cells - - -def _build_reference_snapshot( - grid_result: dict, - pipeline: Optional[str] = None, -) -> dict: - """Build a ground-truth reference snapshot from a grid_editor_result.""" - cells = _extract_cells_for_comparison(grid_result) - - total_zones = len(grid_result.get("zones", [])) - total_columns = sum(len(z.get("columns", [])) for z in grid_result.get("zones", [])) - total_rows = sum(len(z.get("rows", [])) for z in grid_result.get("zones", [])) - - snapshot = { - "saved_at": datetime.now(timezone.utc).isoformat(), - "version": 1, - "pipeline": pipeline, - "summary": { - "total_zones": total_zones, - "total_columns": total_columns, - "total_rows": total_rows, - "total_cells": len(cells), - }, - "cells": cells, - } - return snapshot - - -def compare_grids(reference: dict, current: dict) -> dict: - """Compare a reference grid snapshot with a newly computed one. 
- - Returns a diff report with: - - status: "pass" or "fail" - - structural_diffs: changes in zone/row/column counts - - cell_diffs: list of individual cell changes - """ - ref_summary = reference.get("summary", {}) - cur_summary = current.get("summary", {}) - - structural_diffs = [] - for key in ("total_zones", "total_columns", "total_rows", "total_cells"): - ref_val = ref_summary.get(key, 0) - cur_val = cur_summary.get(key, 0) - if ref_val != cur_val: - structural_diffs.append({ - "field": key, - "reference": ref_val, - "current": cur_val, - }) - - # Build cell lookup by cell_id - ref_cells = {c["cell_id"]: c for c in reference.get("cells", [])} - cur_cells = {c["cell_id"]: c for c in current.get("cells", [])} - - cell_diffs: List[Dict[str, Any]] = [] - - # Check for missing cells (in reference but not in current) - for cell_id in ref_cells: - if cell_id not in cur_cells: - cell_diffs.append({ - "type": "cell_missing", - "cell_id": cell_id, - "reference_text": ref_cells[cell_id].get("text", ""), - }) - - # Check for added cells (in current but not in reference) - for cell_id in cur_cells: - if cell_id not in ref_cells: - cell_diffs.append({ - "type": "cell_added", - "cell_id": cell_id, - "current_text": cur_cells[cell_id].get("text", ""), - }) - - # Check for changes in shared cells - for cell_id in ref_cells: - if cell_id not in cur_cells: - continue - ref_cell = ref_cells[cell_id] - cur_cell = cur_cells[cell_id] - - if ref_cell.get("text", "") != cur_cell.get("text", ""): - cell_diffs.append({ - "type": "text_change", - "cell_id": cell_id, - "reference": ref_cell.get("text", ""), - "current": cur_cell.get("text", ""), - }) - - if ref_cell.get("col_type", "") != cur_cell.get("col_type", ""): - cell_diffs.append({ - "type": "col_type_change", - "cell_id": cell_id, - "reference": ref_cell.get("col_type", ""), - "current": cur_cell.get("col_type", ""), - }) - - status = "pass" if not structural_diffs and not cell_diffs else "fail" - - return { - "status": status, - 
"structural_diffs": structural_diffs, - "cell_diffs": cell_diffs, - "summary": { - "structural_changes": len(structural_diffs), - "cells_missing": sum(1 for d in cell_diffs if d["type"] == "cell_missing"), - "cells_added": sum(1 for d in cell_diffs if d["type"] == "cell_added"), - "text_changes": sum(1 for d in cell_diffs if d["type"] == "text_change"), - "col_type_changes": sum(1 for d in cell_diffs if d["type"] == "col_type_change"), - }, - } - - -# --------------------------------------------------------------------------- -# Endpoints -# --------------------------------------------------------------------------- - -@router.post("/sessions/{session_id}/mark-ground-truth") -async def mark_ground_truth( - session_id: str, - pipeline: Optional[str] = Query(None, description="Pipeline used: kombi, pipeline, paddle-direct"), -): - """Save the current build-grid result as ground-truth reference.""" - session = await get_session_db(session_id) - if not session: - raise HTTPException(status_code=404, detail=f"Session {session_id} not found") - - grid_result = session.get("grid_editor_result") - if not grid_result or not grid_result.get("zones"): - raise HTTPException( - status_code=400, - detail="No grid_editor_result found. 
Run build-grid first.", - ) - - # Auto-detect pipeline from word_result if not provided - if not pipeline: - wr = session.get("word_result") or {} - engine = wr.get("ocr_engine", "") - if engine in ("kombi", "rapid_kombi"): - pipeline = "kombi" - elif engine == "paddle_direct": - pipeline = "paddle-direct" - else: - pipeline = "pipeline" - - reference = _build_reference_snapshot(grid_result, pipeline=pipeline) - - # Merge into existing ground_truth JSONB - gt = session.get("ground_truth") or {} - gt["build_grid_reference"] = reference - await update_session_db(session_id, ground_truth=gt, current_step=11) - - # Compare with auto-snapshot if available (shows what the user corrected) - auto_snapshot = gt.get("auto_grid_snapshot") - correction_diff = None - if auto_snapshot: - correction_diff = compare_grids(auto_snapshot, reference) - - logger.info( - "Ground truth marked for session %s: %d cells (corrections: %s)", - session_id, - len(reference["cells"]), - correction_diff["summary"] if correction_diff else "no auto-snapshot", - ) - - return { - "status": "ok", - "session_id": session_id, - "cells_saved": len(reference["cells"]), - "summary": reference["summary"], - "correction_diff": correction_diff, - } - - -@router.delete("/sessions/{session_id}/mark-ground-truth") -async def unmark_ground_truth(session_id: str): - """Remove the ground-truth reference from a session.""" - session = await get_session_db(session_id) - if not session: - raise HTTPException(status_code=404, detail=f"Session {session_id} not found") - - gt = session.get("ground_truth") or {} - if "build_grid_reference" not in gt: - raise HTTPException(status_code=404, detail="No ground truth reference found") - - del gt["build_grid_reference"] - await update_session_db(session_id, ground_truth=gt) - - logger.info("Ground truth removed for session %s", session_id) - return {"status": "ok", "session_id": session_id} - - -@router.get("/sessions/{session_id}/correction-diff") -async def 
get_correction_diff(session_id: str): - """Compare automatic OCR grid with manually corrected ground truth. - - Returns a diff showing exactly which cells the user corrected, - broken down by col_type (english, german, ipa, etc.). - """ - session = await get_session_db(session_id) - if not session: - raise HTTPException(status_code=404, detail=f"Session {session_id} not found") - - gt = session.get("ground_truth") or {} - auto_snapshot = gt.get("auto_grid_snapshot") - reference = gt.get("build_grid_reference") - - if not auto_snapshot: - raise HTTPException( - status_code=404, - detail="No auto_grid_snapshot found. Re-run build-grid to create one.", - ) - if not reference: - raise HTTPException( - status_code=404, - detail="No ground truth reference found. Mark as ground truth first.", - ) - - diff = compare_grids(auto_snapshot, reference) - - # Enrich with per-col_type breakdown - col_type_stats: Dict[str, Dict[str, int]] = {} - for cell_diff in diff.get("cell_diffs", []): - if cell_diff["type"] != "text_change": - continue - # Find col_type from reference cells - cell_id = cell_diff["cell_id"] - ref_cell = next( - (c for c in reference.get("cells", []) if c["cell_id"] == cell_id), - None, - ) - ct = ref_cell.get("col_type", "unknown") if ref_cell else "unknown" - if ct not in col_type_stats: - col_type_stats[ct] = {"total": 0, "corrected": 0} - col_type_stats[ct]["corrected"] += 1 - - # Count total cells per col_type from reference - for cell in reference.get("cells", []): - ct = cell.get("col_type", "unknown") - if ct not in col_type_stats: - col_type_stats[ct] = {"total": 0, "corrected": 0} - col_type_stats[ct]["total"] += 1 - - # Calculate accuracy per col_type - for ct, stats in col_type_stats.items(): - total = stats["total"] - corrected = stats["corrected"] - stats["accuracy_pct"] = round((total - corrected) / total * 100, 1) if total > 0 else 100.0 - - diff["col_type_breakdown"] = col_type_stats - - return diff - - -@router.get("/ground-truth-sessions") 
-async def list_ground_truth_sessions(): - """List all sessions that have a ground-truth reference.""" - sessions = await list_ground_truth_sessions_db() - - result = [] - for s in sessions: - gt = s.get("ground_truth") or {} - ref = gt.get("build_grid_reference", {}) - result.append({ - "session_id": s["id"], - "name": s.get("name", ""), - "filename": s.get("filename", ""), - "document_category": s.get("document_category"), - "pipeline": ref.get("pipeline"), - "saved_at": ref.get("saved_at"), - "summary": ref.get("summary", {}), - }) - - return {"sessions": result, "count": len(result)} - - -@router.post("/sessions/{session_id}/regression/run") -async def run_single_regression(session_id: str): - """Re-run build_grid for a single session and compare to ground truth.""" - session = await get_session_db(session_id) - if not session: - raise HTTPException(status_code=404, detail=f"Session {session_id} not found") - - gt = session.get("ground_truth") or {} - reference = gt.get("build_grid_reference") - if not reference: - raise HTTPException( - status_code=400, - detail="No ground truth reference found for this session", - ) - - # Re-compute grid without persisting - try: - new_result = await _build_grid_core(session_id, session) - except (ValueError, Exception) as e: - return { - "session_id": session_id, - "name": session.get("name", ""), - "status": "error", - "error": str(e), - } - - new_snapshot = _build_reference_snapshot(new_result) - diff = compare_grids(reference, new_snapshot) - - logger.info( - "Regression test session %s: %s (%d structural, %d cell diffs)", - session_id, diff["status"], - diff["summary"]["structural_changes"], - sum(v for k, v in diff["summary"].items() if k != "structural_changes"), - ) - - return { - "session_id": session_id, - "name": session.get("name", ""), - "status": diff["status"], - "diff": diff, - "reference_summary": reference.get("summary", {}), - "current_summary": new_snapshot.get("summary", {}), - } - - 
-@router.post("/regression/run") -async def run_all_regressions( - triggered_by: str = Query("manual", description="Who triggered: manual, script, ci"), -): - """Re-run build_grid for ALL ground-truth sessions and compare.""" - start_time = time.monotonic() - sessions = await list_ground_truth_sessions_db() - - if not sessions: - return { - "status": "pass", - "message": "No ground truth sessions found", - "results": [], - "summary": {"total": 0, "passed": 0, "failed": 0, "errors": 0}, - } - - results = [] - passed = 0 - failed = 0 - errors = 0 - - for s in sessions: - session_id = s["id"] - gt = s.get("ground_truth") or {} - reference = gt.get("build_grid_reference") - if not reference: - continue - - # Re-load full session (list query may not include all JSONB fields) - full_session = await get_session_db(session_id) - if not full_session: - results.append({ - "session_id": session_id, - "name": s.get("name", ""), - "status": "error", - "error": "Session not found during re-load", - }) - errors += 1 - continue - - try: - new_result = await _build_grid_core(session_id, full_session) - except (ValueError, Exception) as e: - results.append({ - "session_id": session_id, - "name": s.get("name", ""), - "status": "error", - "error": str(e), - }) - errors += 1 - continue - - new_snapshot = _build_reference_snapshot(new_result) - diff = compare_grids(reference, new_snapshot) - - entry = { - "session_id": session_id, - "name": s.get("name", ""), - "status": diff["status"], - "diff_summary": diff["summary"], - "reference_summary": reference.get("summary", {}), - "current_summary": new_snapshot.get("summary", {}), - } - - # Include full diffs only for failures (keep response compact) - if diff["status"] == "fail": - entry["structural_diffs"] = diff["structural_diffs"] - entry["cell_diffs"] = diff["cell_diffs"] - failed += 1 - else: - passed += 1 - - results.append(entry) - - overall = "pass" if failed == 0 and errors == 0 else "fail" - duration_ms = int((time.monotonic() - 
start_time) * 1000) - - summary = { - "total": len(results), - "passed": passed, - "failed": failed, - "errors": errors, - } - - logger.info( - "Regression suite: %s — %d passed, %d failed, %d errors (of %d) in %dms", - overall, passed, failed, errors, len(results), duration_ms, - ) - - # Persist to DB - run_id = await _persist_regression_run( - status=overall, - summary=summary, - results=results, - duration_ms=duration_ms, - triggered_by=triggered_by, - ) - - return { - "status": overall, - "run_id": run_id, - "duration_ms": duration_ms, - "results": results, - "summary": summary, - } - - -@router.get("/regression/history") -async def get_regression_history( - limit: int = Query(20, ge=1, le=100), -): - """Get recent regression run history from the database.""" - try: - await _init_regression_table() - pool = await get_pool() - async with pool.acquire() as conn: - rows = await conn.fetch( - """ - SELECT id, run_at, status, total, passed, failed, errors, - duration_ms, triggered_by - FROM regression_runs - ORDER BY run_at DESC - LIMIT $1 - """, - limit, - ) - return { - "runs": [ - { - "id": str(row["id"]), - "run_at": row["run_at"].isoformat() if row["run_at"] else None, - "status": row["status"], - "total": row["total"], - "passed": row["passed"], - "failed": row["failed"], - "errors": row["errors"], - "duration_ms": row["duration_ms"], - "triggered_by": row["triggered_by"], - } - for row in rows - ], - "count": len(rows), - } - except Exception as e: - logger.warning("Failed to fetch regression history: %s", e) - return {"runs": [], "count": 0, "error": str(e)} - - -@router.get("/regression/history/{run_id}") -async def get_regression_run_detail(run_id: str): - """Get detailed results of a specific regression run.""" - try: - await _init_regression_table() - pool = await get_pool() - async with pool.acquire() as conn: - row = await conn.fetchrow( - "SELECT * FROM regression_runs WHERE id = $1", - run_id, - ) - if not row: - raise HTTPException(status_code=404, 
detail="Run not found") - return { - "id": str(row["id"]), - "run_at": row["run_at"].isoformat() if row["run_at"] else None, - "status": row["status"], - "total": row["total"], - "passed": row["passed"], - "failed": row["failed"], - "errors": row["errors"], - "duration_ms": row["duration_ms"], - "triggered_by": row["triggered_by"], - "results": json.loads(row["results"]) if row["results"] else [], - } - except HTTPException: - raise - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) +# Endpoints (router used by ocr_pipeline_api.py) +from ocr_pipeline_regression_endpoints import router # noqa: F401 diff --git a/klausur-service/backend/ocr_pipeline_regression_endpoints.py b/klausur-service/backend/ocr_pipeline_regression_endpoints.py new file mode 100644 index 0000000..a91d6d6 --- /dev/null +++ b/klausur-service/backend/ocr_pipeline_regression_endpoints.py @@ -0,0 +1,421 @@ +""" +OCR Pipeline Regression Endpoints — FastAPI routes for ground truth and regression. + +Extracted from ocr_pipeline_regression.py for modularity. + +Lizenz: Apache 2.0 +DATENSCHUTZ: Alle Verarbeitung erfolgt lokal. 
+""" + +import json +import logging +import time +from typing import Any, Dict, Optional + +from fastapi import APIRouter, HTTPException, Query + +from grid_editor_api import _build_grid_core +from ocr_pipeline_session_store import ( + get_session_db, + list_ground_truth_sessions_db, + update_session_db, +) +from ocr_pipeline_regression_helpers import ( + _build_reference_snapshot, + _init_regression_table, + _persist_regression_run, + compare_grids, + get_pool, +) + +logger = logging.getLogger(__name__) + +router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["regression"]) + + +# --------------------------------------------------------------------------- +# Endpoints +# --------------------------------------------------------------------------- + +@router.post("/sessions/{session_id}/mark-ground-truth") +async def mark_ground_truth( + session_id: str, + pipeline: Optional[str] = Query(None, description="Pipeline used: kombi, pipeline, paddle-direct"), +): + """Save the current build-grid result as ground-truth reference.""" + session = await get_session_db(session_id) + if not session: + raise HTTPException(status_code=404, detail=f"Session {session_id} not found") + + grid_result = session.get("grid_editor_result") + if not grid_result or not grid_result.get("zones"): + raise HTTPException( + status_code=400, + detail="No grid_editor_result found. 
Run build-grid first.", + ) + + # Auto-detect pipeline from word_result if not provided + if not pipeline: + wr = session.get("word_result") or {} + engine = wr.get("ocr_engine", "") + if engine in ("kombi", "rapid_kombi"): + pipeline = "kombi" + elif engine == "paddle_direct": + pipeline = "paddle-direct" + else: + pipeline = "pipeline" + + reference = _build_reference_snapshot(grid_result, pipeline=pipeline) + + # Merge into existing ground_truth JSONB + gt = session.get("ground_truth") or {} + gt["build_grid_reference"] = reference + await update_session_db(session_id, ground_truth=gt, current_step=11) + + # Compare with auto-snapshot if available (shows what the user corrected) + auto_snapshot = gt.get("auto_grid_snapshot") + correction_diff = None + if auto_snapshot: + correction_diff = compare_grids(auto_snapshot, reference) + + logger.info( + "Ground truth marked for session %s: %d cells (corrections: %s)", + session_id, + len(reference["cells"]), + correction_diff["summary"] if correction_diff else "no auto-snapshot", + ) + + return { + "status": "ok", + "session_id": session_id, + "cells_saved": len(reference["cells"]), + "summary": reference["summary"], + "correction_diff": correction_diff, + } + + +@router.delete("/sessions/{session_id}/mark-ground-truth") +async def unmark_ground_truth(session_id: str): + """Remove the ground-truth reference from a session.""" + session = await get_session_db(session_id) + if not session: + raise HTTPException(status_code=404, detail=f"Session {session_id} not found") + + gt = session.get("ground_truth") or {} + if "build_grid_reference" not in gt: + raise HTTPException(status_code=404, detail="No ground truth reference found") + + del gt["build_grid_reference"] + await update_session_db(session_id, ground_truth=gt) + + logger.info("Ground truth removed for session %s", session_id) + return {"status": "ok", "session_id": session_id} + + +@router.get("/sessions/{session_id}/correction-diff") +async def 
get_correction_diff(session_id: str): + """Compare automatic OCR grid with manually corrected ground truth. + + Returns a diff showing exactly which cells the user corrected, + broken down by col_type (english, german, ipa, etc.). + """ + session = await get_session_db(session_id) + if not session: + raise HTTPException(status_code=404, detail=f"Session {session_id} not found") + + gt = session.get("ground_truth") or {} + auto_snapshot = gt.get("auto_grid_snapshot") + reference = gt.get("build_grid_reference") + + if not auto_snapshot: + raise HTTPException( + status_code=404, + detail="No auto_grid_snapshot found. Re-run build-grid to create one.", + ) + if not reference: + raise HTTPException( + status_code=404, + detail="No ground truth reference found. Mark as ground truth first.", + ) + + diff = compare_grids(auto_snapshot, reference) + + # Enrich with per-col_type breakdown + col_type_stats: Dict[str, Dict[str, int]] = {} + for cell_diff in diff.get("cell_diffs", []): + if cell_diff["type"] != "text_change": + continue + # Find col_type from reference cells + cell_id = cell_diff["cell_id"] + ref_cell = next( + (c for c in reference.get("cells", []) if c["cell_id"] == cell_id), + None, + ) + ct = ref_cell.get("col_type", "unknown") if ref_cell else "unknown" + if ct not in col_type_stats: + col_type_stats[ct] = {"total": 0, "corrected": 0} + col_type_stats[ct]["corrected"] += 1 + + # Count total cells per col_type from reference + for cell in reference.get("cells", []): + ct = cell.get("col_type", "unknown") + if ct not in col_type_stats: + col_type_stats[ct] = {"total": 0, "corrected": 0} + col_type_stats[ct]["total"] += 1 + + # Calculate accuracy per col_type + for ct, stats in col_type_stats.items(): + total = stats["total"] + corrected = stats["corrected"] + stats["accuracy_pct"] = round((total - corrected) / total * 100, 1) if total > 0 else 100.0 + + diff["col_type_breakdown"] = col_type_stats + + return diff + + +@router.get("/ground-truth-sessions") 
+async def list_ground_truth_sessions(): + """List all sessions that have a ground-truth reference.""" + sessions = await list_ground_truth_sessions_db() + + result = [] + for s in sessions: + gt = s.get("ground_truth") or {} + ref = gt.get("build_grid_reference", {}) + result.append({ + "session_id": s["id"], + "name": s.get("name", ""), + "filename": s.get("filename", ""), + "document_category": s.get("document_category"), + "pipeline": ref.get("pipeline"), + "saved_at": ref.get("saved_at"), + "summary": ref.get("summary", {}), + }) + + return {"sessions": result, "count": len(result)} + + +@router.post("/sessions/{session_id}/regression/run") +async def run_single_regression(session_id: str): + """Re-run build_grid for a single session and compare to ground truth.""" + session = await get_session_db(session_id) + if not session: + raise HTTPException(status_code=404, detail=f"Session {session_id} not found") + + gt = session.get("ground_truth") or {} + reference = gt.get("build_grid_reference") + if not reference: + raise HTTPException( + status_code=400, + detail="No ground truth reference found for this session", + ) + + # Re-compute grid without persisting + try: + new_result = await _build_grid_core(session_id, session) + except (ValueError, Exception) as e: + return { + "session_id": session_id, + "name": session.get("name", ""), + "status": "error", + "error": str(e), + } + + new_snapshot = _build_reference_snapshot(new_result) + diff = compare_grids(reference, new_snapshot) + + logger.info( + "Regression test session %s: %s (%d structural, %d cell diffs)", + session_id, diff["status"], + diff["summary"]["structural_changes"], + sum(v for k, v in diff["summary"].items() if k != "structural_changes"), + ) + + return { + "session_id": session_id, + "name": session.get("name", ""), + "status": diff["status"], + "diff": diff, + "reference_summary": reference.get("summary", {}), + "current_summary": new_snapshot.get("summary", {}), + } + + 
+@router.post("/regression/run") +async def run_all_regressions( + triggered_by: str = Query("manual", description="Who triggered: manual, script, ci"), +): + """Re-run build_grid for ALL ground-truth sessions and compare.""" + start_time = time.monotonic() + sessions = await list_ground_truth_sessions_db() + + if not sessions: + return { + "status": "pass", + "message": "No ground truth sessions found", + "results": [], + "summary": {"total": 0, "passed": 0, "failed": 0, "errors": 0}, + } + + results = [] + passed = 0 + failed = 0 + errors = 0 + + for s in sessions: + session_id = s["id"] + gt = s.get("ground_truth") or {} + reference = gt.get("build_grid_reference") + if not reference: + continue + + # Re-load full session (list query may not include all JSONB fields) + full_session = await get_session_db(session_id) + if not full_session: + results.append({ + "session_id": session_id, + "name": s.get("name", ""), + "status": "error", + "error": "Session not found during re-load", + }) + errors += 1 + continue + + try: + new_result = await _build_grid_core(session_id, full_session) + except (ValueError, Exception) as e: + results.append({ + "session_id": session_id, + "name": s.get("name", ""), + "status": "error", + "error": str(e), + }) + errors += 1 + continue + + new_snapshot = _build_reference_snapshot(new_result) + diff = compare_grids(reference, new_snapshot) + + entry = { + "session_id": session_id, + "name": s.get("name", ""), + "status": diff["status"], + "diff_summary": diff["summary"], + "reference_summary": reference.get("summary", {}), + "current_summary": new_snapshot.get("summary", {}), + } + + # Include full diffs only for failures (keep response compact) + if diff["status"] == "fail": + entry["structural_diffs"] = diff["structural_diffs"] + entry["cell_diffs"] = diff["cell_diffs"] + failed += 1 + else: + passed += 1 + + results.append(entry) + + overall = "pass" if failed == 0 and errors == 0 else "fail" + duration_ms = int((time.monotonic() - 
start_time) * 1000) + + summary = { + "total": len(results), + "passed": passed, + "failed": failed, + "errors": errors, + } + + logger.info( + "Regression suite: %s — %d passed, %d failed, %d errors (of %d) in %dms", + overall, passed, failed, errors, len(results), duration_ms, + ) + + # Persist to DB + run_id = await _persist_regression_run( + status=overall, + summary=summary, + results=results, + duration_ms=duration_ms, + triggered_by=triggered_by, + ) + + return { + "status": overall, + "run_id": run_id, + "duration_ms": duration_ms, + "results": results, + "summary": summary, + } + + +@router.get("/regression/history") +async def get_regression_history( + limit: int = Query(20, ge=1, le=100), +): + """Get recent regression run history from the database.""" + try: + await _init_regression_table() + pool = await get_pool() + async with pool.acquire() as conn: + rows = await conn.fetch( + """ + SELECT id, run_at, status, total, passed, failed, errors, + duration_ms, triggered_by + FROM regression_runs + ORDER BY run_at DESC + LIMIT $1 + """, + limit, + ) + return { + "runs": [ + { + "id": str(row["id"]), + "run_at": row["run_at"].isoformat() if row["run_at"] else None, + "status": row["status"], + "total": row["total"], + "passed": row["passed"], + "failed": row["failed"], + "errors": row["errors"], + "duration_ms": row["duration_ms"], + "triggered_by": row["triggered_by"], + } + for row in rows + ], + "count": len(rows), + } + except Exception as e: + logger.warning("Failed to fetch regression history: %s", e) + return {"runs": [], "count": 0, "error": str(e)} + + +@router.get("/regression/history/{run_id}") +async def get_regression_run_detail(run_id: str): + """Get detailed results of a specific regression run.""" + try: + await _init_regression_table() + pool = await get_pool() + async with pool.acquire() as conn: + row = await conn.fetchrow( + "SELECT * FROM regression_runs WHERE id = $1", + run_id, + ) + if not row: + raise HTTPException(status_code=404, 
detail="Run not found") + return { + "id": str(row["id"]), + "run_at": row["run_at"].isoformat() if row["run_at"] else None, + "status": row["status"], + "total": row["total"], + "passed": row["passed"], + "failed": row["failed"], + "errors": row["errors"], + "duration_ms": row["duration_ms"], + "triggered_by": row["triggered_by"], + "results": json.loads(row["results"]) if row["results"] else [], + } + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) diff --git a/klausur-service/backend/ocr_pipeline_regression_helpers.py b/klausur-service/backend/ocr_pipeline_regression_helpers.py new file mode 100644 index 0000000..b8e0a57 --- /dev/null +++ b/klausur-service/backend/ocr_pipeline_regression_helpers.py @@ -0,0 +1,207 @@ +""" +OCR Pipeline Regression Helpers — DB persistence, snapshot building, comparison. + +Extracted from ocr_pipeline_regression.py for modularity. + +Lizenz: Apache 2.0 +DATENSCHUTZ: Alle Verarbeitung erfolgt lokal. 
+""" + +import json +import logging +import os +import uuid +from datetime import datetime, timezone +from typing import Any, Dict, List, Optional + +from ocr_pipeline_session_store import get_pool + +logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# DB persistence for regression runs +# --------------------------------------------------------------------------- + +async def _init_regression_table(): + """Ensure regression_runs table exists (idempotent).""" + pool = await get_pool() + async with pool.acquire() as conn: + migration_path = os.path.join( + os.path.dirname(__file__), + "migrations/008_regression_runs.sql", + ) + if os.path.exists(migration_path): + with open(migration_path, "r") as f: + sql = f.read() + await conn.execute(sql) + + +async def _persist_regression_run( + status: str, + summary: dict, + results: list, + duration_ms: int, + triggered_by: str = "manual", +) -> str: + """Save a regression run to the database. 
Returns the run ID.""" + try: + await _init_regression_table() + pool = await get_pool() + run_id = str(uuid.uuid4()) + async with pool.acquire() as conn: + await conn.execute( + """ + INSERT INTO regression_runs + (id, status, total, passed, failed, errors, duration_ms, results, triggered_by) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8::jsonb, $9) + """, + run_id, + status, + summary.get("total", 0), + summary.get("passed", 0), + summary.get("failed", 0), + summary.get("errors", 0), + duration_ms, + json.dumps(results), + triggered_by, + ) + logger.info("Regression run %s persisted: %s", run_id, status) + return run_id + except Exception as e: + logger.warning("Failed to persist regression run: %s", e) + return "" + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _extract_cells_for_comparison(grid_result: dict) -> List[Dict[str, Any]]: + """Extract a flat list of cells from a grid_editor_result for comparison. + + Only keeps fields relevant for comparison: cell_id, row_index, col_index, + col_type, text. Ignores confidence, bbox, word_boxes, duration, is_bold. 
+ """ + cells = [] + for zone in grid_result.get("zones", []): + for cell in zone.get("cells", []): + cells.append({ + "cell_id": cell.get("cell_id", ""), + "row_index": cell.get("row_index"), + "col_index": cell.get("col_index"), + "col_type": cell.get("col_type", ""), + "text": cell.get("text", ""), + }) + return cells + + +def _build_reference_snapshot( + grid_result: dict, + pipeline: Optional[str] = None, +) -> dict: + """Build a ground-truth reference snapshot from a grid_editor_result.""" + cells = _extract_cells_for_comparison(grid_result) + + total_zones = len(grid_result.get("zones", [])) + total_columns = sum(len(z.get("columns", [])) for z in grid_result.get("zones", [])) + total_rows = sum(len(z.get("rows", [])) for z in grid_result.get("zones", [])) + + snapshot = { + "saved_at": datetime.now(timezone.utc).isoformat(), + "version": 1, + "pipeline": pipeline, + "summary": { + "total_zones": total_zones, + "total_columns": total_columns, + "total_rows": total_rows, + "total_cells": len(cells), + }, + "cells": cells, + } + return snapshot + + +def compare_grids(reference: dict, current: dict) -> dict: + """Compare a reference grid snapshot with a newly computed one. 
+ + Returns a diff report with: + - status: "pass" or "fail" + - structural_diffs: changes in zone/row/column counts + - cell_diffs: list of individual cell changes + """ + ref_summary = reference.get("summary", {}) + cur_summary = current.get("summary", {}) + + structural_diffs = [] + for key in ("total_zones", "total_columns", "total_rows", "total_cells"): + ref_val = ref_summary.get(key, 0) + cur_val = cur_summary.get(key, 0) + if ref_val != cur_val: + structural_diffs.append({ + "field": key, + "reference": ref_val, + "current": cur_val, + }) + + # Build cell lookup by cell_id + ref_cells = {c["cell_id"]: c for c in reference.get("cells", [])} + cur_cells = {c["cell_id"]: c for c in current.get("cells", [])} + + cell_diffs: List[Dict[str, Any]] = [] + + # Check for missing cells (in reference but not in current) + for cell_id in ref_cells: + if cell_id not in cur_cells: + cell_diffs.append({ + "type": "cell_missing", + "cell_id": cell_id, + "reference_text": ref_cells[cell_id].get("text", ""), + }) + + # Check for added cells (in current but not in reference) + for cell_id in cur_cells: + if cell_id not in ref_cells: + cell_diffs.append({ + "type": "cell_added", + "cell_id": cell_id, + "current_text": cur_cells[cell_id].get("text", ""), + }) + + # Check for changes in shared cells + for cell_id in ref_cells: + if cell_id not in cur_cells: + continue + ref_cell = ref_cells[cell_id] + cur_cell = cur_cells[cell_id] + + if ref_cell.get("text", "") != cur_cell.get("text", ""): + cell_diffs.append({ + "type": "text_change", + "cell_id": cell_id, + "reference": ref_cell.get("text", ""), + "current": cur_cell.get("text", ""), + }) + + if ref_cell.get("col_type", "") != cur_cell.get("col_type", ""): + cell_diffs.append({ + "type": "col_type_change", + "cell_id": cell_id, + "reference": ref_cell.get("col_type", ""), + "current": cur_cell.get("col_type", ""), + }) + + status = "pass" if not structural_diffs and not cell_diffs else "fail" + + return { + "status": status, + 
"structural_diffs": structural_diffs, + "cell_diffs": cell_diffs, + "summary": { + "structural_changes": len(structural_diffs), + "cells_missing": sum(1 for d in cell_diffs if d["type"] == "cell_missing"), + "cells_added": sum(1 for d in cell_diffs if d["type"] == "cell_added"), + "text_changes": sum(1 for d in cell_diffs if d["type"] == "text_change"), + "col_type_changes": sum(1 for d in cell_diffs if d["type"] == "col_type_change"), + }, + } diff --git a/klausur-service/backend/ocr_pipeline_sessions.py b/klausur-service/backend/ocr_pipeline_sessions.py index 22e00d2..ae3f771 100644 --- a/klausur-service/backend/ocr_pipeline_sessions.py +++ b/klausur-service/backend/ocr_pipeline_sessions.py @@ -1,597 +1,20 @@ """ -OCR Pipeline Sessions API - Session management and image serving endpoints. +OCR Pipeline Sessions API — barrel re-export. -Extracted from ocr_pipeline_api.py for modularity. -Handles: CRUD for sessions, thumbnails, pipeline logs, categories, -image serving (with overlay dispatch), and document type detection. +All implementation split into: + ocr_pipeline_sessions_crud — session CRUD, box sessions + ocr_pipeline_sessions_images — image serving, thumbnails, doc-type detection Lizenz: Apache 2.0 DATENSCHUTZ: Alle Verarbeitung erfolgt lokal. 
""" -import logging -import time -import uuid -from typing import Any, Dict, Optional +from fastapi import APIRouter -import cv2 -import numpy as np -from fastapi import APIRouter, File, Form, HTTPException, Query, UploadFile -from fastapi.responses import Response +from ocr_pipeline_sessions_crud import router as _crud_router # noqa: F401 +from ocr_pipeline_sessions_images import router as _images_router # noqa: F401 -from cv_vocab_pipeline import ( - create_ocr_image, - detect_document_type, - render_image_high_res, - render_pdf_high_res, -) -from ocr_pipeline_common import ( - VALID_DOCUMENT_CATEGORIES, - UpdateSessionRequest, - _append_pipeline_log, - _cache, - _get_base_image_png, - _get_cached, - _load_session_to_cache, -) -from ocr_pipeline_overlays import render_overlay -from ocr_pipeline_session_store import ( - create_session_db, - delete_all_sessions_db, - delete_session_db, - get_session_db, - get_session_image, - get_sub_sessions, - list_sessions_db, - update_session_db, -) - -logger = logging.getLogger(__name__) - -router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"]) - - -# --------------------------------------------------------------------------- -# Session Management Endpoints -# --------------------------------------------------------------------------- - -@router.get("/sessions") -async def list_sessions(include_sub_sessions: bool = False): - """List OCR pipeline sessions. - - By default, sub-sessions (box regions) are hidden. - Pass ?include_sub_sessions=true to show them. - """ - sessions = await list_sessions_db(include_sub_sessions=include_sub_sessions) - return {"sessions": sessions} - - -@router.post("/sessions") -async def create_session( - file: UploadFile = File(...), - name: Optional[str] = Form(None), -): - """Upload a PDF or image file and create a pipeline session. - - For multi-page PDFs (> 1 page), each page becomes its own session - grouped under a ``document_group_id``. 
The response includes a - ``pages`` array with one entry per page/session. - """ - file_data = await file.read() - filename = file.filename or "upload" - content_type = file.content_type or "" - - is_pdf = content_type == "application/pdf" or filename.lower().endswith(".pdf") - session_name = name or filename - - # --- Multi-page PDF handling --- - if is_pdf: - try: - import fitz # PyMuPDF - pdf_doc = fitz.open(stream=file_data, filetype="pdf") - page_count = pdf_doc.page_count - pdf_doc.close() - except Exception as e: - raise HTTPException(status_code=400, detail=f"Could not read PDF: {e}") - - if page_count > 1: - return await _create_multi_page_sessions( - file_data, filename, session_name, page_count, - ) - - # --- Single page (image or 1-page PDF) --- - session_id = str(uuid.uuid4()) - - try: - if is_pdf: - img_bgr = render_pdf_high_res(file_data, page_number=0, zoom=3.0) - else: - img_bgr = render_image_high_res(file_data) - except Exception as e: - raise HTTPException(status_code=400, detail=f"Could not process file: {e}") - - # Encode original as PNG bytes - success, png_buf = cv2.imencode(".png", img_bgr) - if not success: - raise HTTPException(status_code=500, detail="Failed to encode image") - - original_png = png_buf.tobytes() - - # Persist to DB - await create_session_db( - session_id=session_id, - name=session_name, - filename=filename, - original_png=original_png, - ) - - # Cache BGR array for immediate processing - _cache[session_id] = { - "id": session_id, - "filename": filename, - "name": session_name, - "original_bgr": img_bgr, - "oriented_bgr": None, - "cropped_bgr": None, - "deskewed_bgr": None, - "dewarped_bgr": None, - "orientation_result": None, - "crop_result": None, - "deskew_result": None, - "dewarp_result": None, - "ground_truth": {}, - "current_step": 1, - } - - logger.info(f"OCR Pipeline: created session {session_id} from {filename} " - f"({img_bgr.shape[1]}x{img_bgr.shape[0]})") - - return { - "session_id": session_id, - "filename": 
filename, - "name": session_name, - "image_width": img_bgr.shape[1], - "image_height": img_bgr.shape[0], - "original_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/original", - } - - -async def _create_multi_page_sessions( - pdf_data: bytes, - filename: str, - base_name: str, - page_count: int, -) -> dict: - """Create one session per PDF page, grouped by document_group_id.""" - document_group_id = str(uuid.uuid4()) - pages = [] - - for page_idx in range(page_count): - session_id = str(uuid.uuid4()) - page_name = f"{base_name} — Seite {page_idx + 1}" - - try: - img_bgr = render_pdf_high_res(pdf_data, page_number=page_idx, zoom=3.0) - except Exception as e: - logger.warning(f"Failed to render PDF page {page_idx + 1}: {e}") - continue - - ok, png_buf = cv2.imencode(".png", img_bgr) - if not ok: - continue - page_png = png_buf.tobytes() - - await create_session_db( - session_id=session_id, - name=page_name, - filename=filename, - original_png=page_png, - document_group_id=document_group_id, - page_number=page_idx + 1, - ) - - _cache[session_id] = { - "id": session_id, - "filename": filename, - "name": page_name, - "original_bgr": img_bgr, - "oriented_bgr": None, - "cropped_bgr": None, - "deskewed_bgr": None, - "dewarped_bgr": None, - "orientation_result": None, - "crop_result": None, - "deskew_result": None, - "dewarp_result": None, - "ground_truth": {}, - "current_step": 1, - } - - h, w = img_bgr.shape[:2] - pages.append({ - "session_id": session_id, - "name": page_name, - "page_number": page_idx + 1, - "image_width": w, - "image_height": h, - "original_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/original", - }) - - logger.info( - f"OCR Pipeline: created page session {session_id} " - f"(page {page_idx + 1}/{page_count}) from {filename} ({w}x{h})" - ) - - # Include session_id pointing to first page for backwards compatibility - # (frontends that expect a single session_id will navigate to page 1) - first_session_id = 
pages[0]["session_id"] if pages else None - - return { - "session_id": first_session_id, - "document_group_id": document_group_id, - "filename": filename, - "name": base_name, - "page_count": page_count, - "pages": pages, - } - - -@router.get("/sessions/{session_id}") -async def get_session_info(session_id: str): - """Get session info including deskew/dewarp/column results for step navigation.""" - session = await get_session_db(session_id) - if not session: - raise HTTPException(status_code=404, detail=f"Session {session_id} not found") - - # Get image dimensions from original PNG - original_png = await get_session_image(session_id, "original") - if original_png: - arr = np.frombuffer(original_png, dtype=np.uint8) - img = cv2.imdecode(arr, cv2.IMREAD_COLOR) - img_w, img_h = img.shape[1], img.shape[0] if img is not None else (0, 0) - else: - img_w, img_h = 0, 0 - - result = { - "session_id": session["id"], - "filename": session.get("filename", ""), - "name": session.get("name", ""), - "image_width": img_w, - "image_height": img_h, - "original_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/original", - "current_step": session.get("current_step", 1), - "document_category": session.get("document_category"), - "doc_type": session.get("doc_type"), - } - - if session.get("orientation_result"): - result["orientation_result"] = session["orientation_result"] - if session.get("crop_result"): - result["crop_result"] = session["crop_result"] - if session.get("deskew_result"): - result["deskew_result"] = session["deskew_result"] - if session.get("dewarp_result"): - result["dewarp_result"] = session["dewarp_result"] - if session.get("column_result"): - result["column_result"] = session["column_result"] - if session.get("row_result"): - result["row_result"] = session["row_result"] - if session.get("word_result"): - result["word_result"] = session["word_result"] - if session.get("doc_type_result"): - result["doc_type_result"] = session["doc_type_result"] - if 
session.get("structure_result"): - result["structure_result"] = session["structure_result"] - if session.get("grid_editor_result"): - # Include summary only to keep response small - gr = session["grid_editor_result"] - result["grid_editor_result"] = { - "summary": gr.get("summary", {}), - "zones_count": len(gr.get("zones", [])), - "edited": gr.get("edited", False), - } - if session.get("ground_truth"): - result["ground_truth"] = session["ground_truth"] - - # Box sub-session info (zone_type='box' from column detection — NOT page-split) - if session.get("parent_session_id"): - result["parent_session_id"] = session["parent_session_id"] - result["box_index"] = session.get("box_index") - else: - # Check for box sub-sessions (column detection creates these) - subs = await get_sub_sessions(session_id) - if subs: - result["sub_sessions"] = [ - {"id": s["id"], "name": s.get("name"), "box_index": s.get("box_index")} - for s in subs - ] - - return result - - -@router.put("/sessions/{session_id}") -async def update_session(session_id: str, req: UpdateSessionRequest): - """Update session name and/or document category.""" - kwargs: Dict[str, Any] = {} - if req.name is not None: - kwargs["name"] = req.name - if req.document_category is not None: - if req.document_category not in VALID_DOCUMENT_CATEGORIES: - raise HTTPException( - status_code=400, - detail=f"Invalid category '{req.document_category}'. 
Valid: {sorted(VALID_DOCUMENT_CATEGORIES)}", - ) - kwargs["document_category"] = req.document_category - if not kwargs: - raise HTTPException(status_code=400, detail="Nothing to update") - updated = await update_session_db(session_id, **kwargs) - if not updated: - raise HTTPException(status_code=404, detail=f"Session {session_id} not found") - return {"session_id": session_id, **kwargs} - - -@router.delete("/sessions/{session_id}") -async def delete_session(session_id: str): - """Delete a session.""" - _cache.pop(session_id, None) - deleted = await delete_session_db(session_id) - if not deleted: - raise HTTPException(status_code=404, detail=f"Session {session_id} not found") - return {"session_id": session_id, "deleted": True} - - -@router.delete("/sessions") -async def delete_all_sessions(): - """Delete ALL sessions (cleanup).""" - _cache.clear() - count = await delete_all_sessions_db() - return {"deleted_count": count} - - -@router.post("/sessions/{session_id}/create-box-sessions") -async def create_box_sessions(session_id: str): - """Create sub-sessions for each detected box region. - - Crops box regions from the cropped/dewarped image and creates - independent sub-sessions that can be processed through the pipeline. 
- """ - session = await get_session_db(session_id) - if not session: - raise HTTPException(status_code=404, detail=f"Session {session_id} not found") - - column_result = session.get("column_result") - if not column_result: - raise HTTPException(status_code=400, detail="Column detection must be completed first") - - zones = column_result.get("zones") or [] - box_zones = [z for z in zones if z.get("zone_type") == "box" and z.get("box")] - if not box_zones: - return {"session_id": session_id, "sub_sessions": [], "message": "No boxes detected"} - - # Check for existing sub-sessions - existing = await get_sub_sessions(session_id) - if existing: - return { - "session_id": session_id, - "sub_sessions": [{"id": s["id"], "box_index": s.get("box_index")} for s in existing], - "message": f"{len(existing)} sub-session(s) already exist", - } - - # Load base image - base_png = await get_session_image(session_id, "cropped") - if not base_png: - base_png = await get_session_image(session_id, "dewarped") - if not base_png: - raise HTTPException(status_code=400, detail="No base image available") - - arr = np.frombuffer(base_png, dtype=np.uint8) - img = cv2.imdecode(arr, cv2.IMREAD_COLOR) - if img is None: - raise HTTPException(status_code=500, detail="Failed to decode image") - - parent_name = session.get("name", "Session") - created = [] - - for i, zone in enumerate(box_zones): - box = zone["box"] - bx, by = box["x"], box["y"] - bw, bh = box["width"], box["height"] - - # Crop box region with small padding - pad = 5 - y1 = max(0, by - pad) - y2 = min(img.shape[0], by + bh + pad) - x1 = max(0, bx - pad) - x2 = min(img.shape[1], bx + bw + pad) - crop = img[y1:y2, x1:x2] - - # Encode as PNG - success, png_buf = cv2.imencode(".png", crop) - if not success: - logger.warning(f"Failed to encode box {i} crop for session {session_id}") - continue - - sub_id = str(uuid.uuid4()) - sub_name = f"{parent_name} — Box {i + 1}" - - await create_session_db( - session_id=sub_id, - name=sub_name, - 
filename=session.get("filename", "box-crop.png"), - original_png=png_buf.tobytes(), - parent_session_id=session_id, - box_index=i, - ) - - # Cache the BGR for immediate processing - # Promote original to cropped so column/row/word detection finds it - box_bgr = crop.copy() - _cache[sub_id] = { - "id": sub_id, - "filename": session.get("filename", "box-crop.png"), - "name": sub_name, - "parent_session_id": session_id, - "original_bgr": box_bgr, - "oriented_bgr": None, - "cropped_bgr": box_bgr, - "deskewed_bgr": None, - "dewarped_bgr": None, - "orientation_result": None, - "crop_result": None, - "deskew_result": None, - "dewarp_result": None, - "ground_truth": {}, - "current_step": 1, - } - - created.append({ - "id": sub_id, - "name": sub_name, - "box_index": i, - "box": box, - "image_width": crop.shape[1], - "image_height": crop.shape[0], - }) - - logger.info(f"Created box sub-session {sub_id} for session {session_id} " - f"(box {i}, {crop.shape[1]}x{crop.shape[0]})") - - return { - "session_id": session_id, - "sub_sessions": created, - "total": len(created), - } - - -@router.get("/sessions/{session_id}/thumbnail") -async def get_session_thumbnail(session_id: str, size: int = Query(default=80, ge=16, le=400)): - """Return a small thumbnail of the original image.""" - original_png = await get_session_image(session_id, "original") - if not original_png: - raise HTTPException(status_code=404, detail=f"Session {session_id} not found or no image") - arr = np.frombuffer(original_png, dtype=np.uint8) - img = cv2.imdecode(arr, cv2.IMREAD_COLOR) - if img is None: - raise HTTPException(status_code=500, detail="Failed to decode image") - h, w = img.shape[:2] - scale = size / max(h, w) - new_w, new_h = int(w * scale), int(h * scale) - thumb = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_AREA) - _, png_bytes = cv2.imencode(".png", thumb) - return Response(content=png_bytes.tobytes(), media_type="image/png", - headers={"Cache-Control": "public, max-age=3600"}) - - 
-@router.get("/sessions/{session_id}/pipeline-log") -async def get_pipeline_log(session_id: str): - """Get the pipeline execution log for a session.""" - session = await get_session_db(session_id) - if not session: - raise HTTPException(status_code=404, detail=f"Session {session_id} not found") - return {"session_id": session_id, "pipeline_log": session.get("pipeline_log") or {"steps": []}} - - -@router.get("/categories") -async def list_categories(): - """List valid document categories.""" - return {"categories": sorted(VALID_DOCUMENT_CATEGORIES)} - - -# --------------------------------------------------------------------------- -# Image Endpoints -# --------------------------------------------------------------------------- - -@router.get("/sessions/{session_id}/image/{image_type}") -async def get_image(session_id: str, image_type: str): - """Serve session images: original, deskewed, dewarped, binarized, structure-overlay, columns-overlay, or rows-overlay.""" - valid_types = {"original", "oriented", "cropped", "deskewed", "dewarped", "binarized", "structure-overlay", "columns-overlay", "rows-overlay", "words-overlay", "clean"} - if image_type not in valid_types: - raise HTTPException(status_code=400, detail=f"Unknown image type: {image_type}") - - if image_type == "structure-overlay": - return await render_overlay("structure", session_id) - - if image_type == "columns-overlay": - return await render_overlay("columns", session_id) - - if image_type == "rows-overlay": - return await render_overlay("rows", session_id) - - if image_type == "words-overlay": - return await render_overlay("words", session_id) - - # Try cache first for fast serving - cached = _cache.get(session_id) - if cached: - png_key = f"{image_type}_png" if image_type != "original" else None - bgr_key = f"{image_type}_bgr" if image_type != "binarized" else None - - # For binarized, check if we have it cached as PNG - if image_type == "binarized" and cached.get("binarized_png"): - return 
Response(content=cached["binarized_png"], media_type="image/png") - - # Load from DB — for cropped/dewarped, fall back through the chain - if image_type in ("cropped", "dewarped"): - data = await _get_base_image_png(session_id) - else: - data = await get_session_image(session_id, image_type) - if not data: - raise HTTPException(status_code=404, detail=f"Image '{image_type}' not available yet") - - return Response(content=data, media_type="image/png") - - -# --------------------------------------------------------------------------- -# Document Type Detection (between Dewarp and Columns) -# --------------------------------------------------------------------------- - -@router.post("/sessions/{session_id}/detect-type") -async def detect_type(session_id: str): - """Detect document type (vocab_table, full_text, generic_table). - - Should be called after crop (clean image available). - Falls back to dewarped if crop was skipped. - Stores result in session for frontend to decide pipeline flow. 
- """ - if session_id not in _cache: - await _load_session_to_cache(session_id) - cached = _get_cached(session_id) - - img_bgr = cached.get("cropped_bgr") if cached.get("cropped_bgr") is not None else cached.get("dewarped_bgr") - if img_bgr is None: - raise HTTPException(status_code=400, detail="Crop or dewarp must be completed first") - - t0 = time.time() - ocr_img = create_ocr_image(img_bgr) - result = detect_document_type(ocr_img, img_bgr) - duration = time.time() - t0 - - result_dict = { - "doc_type": result.doc_type, - "confidence": result.confidence, - "pipeline": result.pipeline, - "skip_steps": result.skip_steps, - "features": result.features, - "duration_seconds": round(duration, 2), - } - - # Persist to DB - await update_session_db( - session_id, - doc_type=result.doc_type, - doc_type_result=result_dict, - ) - - cached["doc_type_result"] = result_dict - - logger.info(f"OCR Pipeline: detect-type session {session_id}: " - f"{result.doc_type} (confidence={result.confidence}, {duration:.2f}s)") - - await _append_pipeline_log(session_id, "detect_type", { - "doc_type": result.doc_type, - "pipeline": result.pipeline, - "confidence": result.confidence, - **{k: v for k, v in (result.features or {}).items() if isinstance(v, (int, float, str, bool))}, - }, duration_ms=int(duration * 1000)) - - return {"session_id": session_id, **result_dict} +# Composite router (used by ocr_pipeline_api.py) +router = APIRouter() +router.include_router(_crud_router) +router.include_router(_images_router) diff --git a/klausur-service/backend/ocr_pipeline_sessions_crud.py b/klausur-service/backend/ocr_pipeline_sessions_crud.py new file mode 100644 index 0000000..19343d7 --- /dev/null +++ b/klausur-service/backend/ocr_pipeline_sessions_crud.py @@ -0,0 +1,449 @@ +""" +OCR Pipeline Sessions CRUD — session create, read, update, delete, box sessions. + +Extracted from ocr_pipeline_sessions.py for modularity. + +Lizenz: Apache 2.0 +DATENSCHUTZ: Alle Verarbeitung erfolgt lokal. 
+""" + +import logging +import uuid +from typing import Any, Dict, Optional + +import cv2 +import numpy as np +from fastapi import APIRouter, File, Form, HTTPException, Query, UploadFile + +from cv_vocab_pipeline import render_image_high_res, render_pdf_high_res +from ocr_pipeline_common import ( + VALID_DOCUMENT_CATEGORIES, + UpdateSessionRequest, + _cache, +) +from ocr_pipeline_session_store import ( + create_session_db, + delete_all_sessions_db, + delete_session_db, + get_session_db, + get_session_image, + get_sub_sessions, + list_sessions_db, + update_session_db, +) + +logger = logging.getLogger(__name__) + +router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"]) + + +# --------------------------------------------------------------------------- +# Session Management Endpoints +# --------------------------------------------------------------------------- + +@router.get("/sessions") +async def list_sessions(include_sub_sessions: bool = False): + """List OCR pipeline sessions. + + By default, sub-sessions (box regions) are hidden. + Pass ?include_sub_sessions=true to show them. + """ + sessions = await list_sessions_db(include_sub_sessions=include_sub_sessions) + return {"sessions": sessions} + + +@router.post("/sessions") +async def create_session( + file: UploadFile = File(...), + name: Optional[str] = Form(None), +): + """Upload a PDF or image file and create a pipeline session. + + For multi-page PDFs (> 1 page), each page becomes its own session + grouped under a ``document_group_id``. The response includes a + ``pages`` array with one entry per page/session. 
+ """ + file_data = await file.read() + filename = file.filename or "upload" + content_type = file.content_type or "" + + is_pdf = content_type == "application/pdf" or filename.lower().endswith(".pdf") + session_name = name or filename + + # --- Multi-page PDF handling --- + if is_pdf: + try: + import fitz # PyMuPDF + pdf_doc = fitz.open(stream=file_data, filetype="pdf") + page_count = pdf_doc.page_count + pdf_doc.close() + except Exception as e: + raise HTTPException(status_code=400, detail=f"Could not read PDF: {e}") + + if page_count > 1: + return await _create_multi_page_sessions( + file_data, filename, session_name, page_count, + ) + + # --- Single page (image or 1-page PDF) --- + session_id = str(uuid.uuid4()) + + try: + if is_pdf: + img_bgr = render_pdf_high_res(file_data, page_number=0, zoom=3.0) + else: + img_bgr = render_image_high_res(file_data) + except Exception as e: + raise HTTPException(status_code=400, detail=f"Could not process file: {e}") + + # Encode original as PNG bytes + success, png_buf = cv2.imencode(".png", img_bgr) + if not success: + raise HTTPException(status_code=500, detail="Failed to encode image") + + original_png = png_buf.tobytes() + + # Persist to DB + await create_session_db( + session_id=session_id, + name=session_name, + filename=filename, + original_png=original_png, + ) + + # Cache BGR array for immediate processing + _cache[session_id] = { + "id": session_id, + "filename": filename, + "name": session_name, + "original_bgr": img_bgr, + "oriented_bgr": None, + "cropped_bgr": None, + "deskewed_bgr": None, + "dewarped_bgr": None, + "orientation_result": None, + "crop_result": None, + "deskew_result": None, + "dewarp_result": None, + "ground_truth": {}, + "current_step": 1, + } + + logger.info(f"OCR Pipeline: created session {session_id} from {filename} " + f"({img_bgr.shape[1]}x{img_bgr.shape[0]})") + + return { + "session_id": session_id, + "filename": filename, + "name": session_name, + "image_width": img_bgr.shape[1], + 
"image_height": img_bgr.shape[0], + "original_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/original", + } + + +async def _create_multi_page_sessions( + pdf_data: bytes, + filename: str, + base_name: str, + page_count: int, +) -> dict: + """Create one session per PDF page, grouped by document_group_id.""" + document_group_id = str(uuid.uuid4()) + pages = [] + + for page_idx in range(page_count): + session_id = str(uuid.uuid4()) + page_name = f"{base_name} — Seite {page_idx + 1}" + + try: + img_bgr = render_pdf_high_res(pdf_data, page_number=page_idx, zoom=3.0) + except Exception as e: + logger.warning(f"Failed to render PDF page {page_idx + 1}: {e}") + continue + + ok, png_buf = cv2.imencode(".png", img_bgr) + if not ok: + continue + page_png = png_buf.tobytes() + + await create_session_db( + session_id=session_id, + name=page_name, + filename=filename, + original_png=page_png, + document_group_id=document_group_id, + page_number=page_idx + 1, + ) + + _cache[session_id] = { + "id": session_id, + "filename": filename, + "name": page_name, + "original_bgr": img_bgr, + "oriented_bgr": None, + "cropped_bgr": None, + "deskewed_bgr": None, + "dewarped_bgr": None, + "orientation_result": None, + "crop_result": None, + "deskew_result": None, + "dewarp_result": None, + "ground_truth": {}, + "current_step": 1, + } + + h, w = img_bgr.shape[:2] + pages.append({ + "session_id": session_id, + "name": page_name, + "page_number": page_idx + 1, + "image_width": w, + "image_height": h, + "original_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/original", + }) + + logger.info( + f"OCR Pipeline: created page session {session_id} " + f"(page {page_idx + 1}/{page_count}) from {filename} ({w}x{h})" + ) + + # Include session_id pointing to first page for backwards compatibility + # (frontends that expect a single session_id will navigate to page 1) + first_session_id = pages[0]["session_id"] if pages else None + + return { + "session_id": first_session_id, 
+ "document_group_id": document_group_id, + "filename": filename, + "name": base_name, + "page_count": page_count, + "pages": pages, + } + + +@router.get("/sessions/{session_id}") +async def get_session_info(session_id: str): + """Get session info including deskew/dewarp/column results for step navigation.""" + session = await get_session_db(session_id) + if not session: + raise HTTPException(status_code=404, detail=f"Session {session_id} not found") + + # Get image dimensions from original PNG + original_png = await get_session_image(session_id, "original") + if original_png: + arr = np.frombuffer(original_png, dtype=np.uint8) + img = cv2.imdecode(arr, cv2.IMREAD_COLOR) + img_w, img_h = img.shape[1], img.shape[0] if img is not None else (0, 0) + else: + img_w, img_h = 0, 0 + + result = { + "session_id": session["id"], + "filename": session.get("filename", ""), + "name": session.get("name", ""), + "image_width": img_w, + "image_height": img_h, + "original_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/original", + "current_step": session.get("current_step", 1), + "document_category": session.get("document_category"), + "doc_type": session.get("doc_type"), + } + + if session.get("orientation_result"): + result["orientation_result"] = session["orientation_result"] + if session.get("crop_result"): + result["crop_result"] = session["crop_result"] + if session.get("deskew_result"): + result["deskew_result"] = session["deskew_result"] + if session.get("dewarp_result"): + result["dewarp_result"] = session["dewarp_result"] + if session.get("column_result"): + result["column_result"] = session["column_result"] + if session.get("row_result"): + result["row_result"] = session["row_result"] + if session.get("word_result"): + result["word_result"] = session["word_result"] + if session.get("doc_type_result"): + result["doc_type_result"] = session["doc_type_result"] + if session.get("structure_result"): + result["structure_result"] = session["structure_result"] + 
if session.get("grid_editor_result"): + # Include summary only to keep response small + gr = session["grid_editor_result"] + result["grid_editor_result"] = { + "summary": gr.get("summary", {}), + "zones_count": len(gr.get("zones", [])), + "edited": gr.get("edited", False), + } + if session.get("ground_truth"): + result["ground_truth"] = session["ground_truth"] + + # Box sub-session info (zone_type='box' from column detection — NOT page-split) + if session.get("parent_session_id"): + result["parent_session_id"] = session["parent_session_id"] + result["box_index"] = session.get("box_index") + else: + # Check for box sub-sessions (column detection creates these) + subs = await get_sub_sessions(session_id) + if subs: + result["sub_sessions"] = [ + {"id": s["id"], "name": s.get("name"), "box_index": s.get("box_index")} + for s in subs + ] + + return result + + +@router.put("/sessions/{session_id}") +async def update_session(session_id: str, req: UpdateSessionRequest): + """Update session name and/or document category.""" + kwargs: Dict[str, Any] = {} + if req.name is not None: + kwargs["name"] = req.name + if req.document_category is not None: + if req.document_category not in VALID_DOCUMENT_CATEGORIES: + raise HTTPException( + status_code=400, + detail=f"Invalid category '{req.document_category}'. 
Valid: {sorted(VALID_DOCUMENT_CATEGORIES)}", + ) + kwargs["document_category"] = req.document_category + if not kwargs: + raise HTTPException(status_code=400, detail="Nothing to update") + updated = await update_session_db(session_id, **kwargs) + if not updated: + raise HTTPException(status_code=404, detail=f"Session {session_id} not found") + return {"session_id": session_id, **kwargs} + + +@router.delete("/sessions/{session_id}") +async def delete_session(session_id: str): + """Delete a session.""" + _cache.pop(session_id, None) + deleted = await delete_session_db(session_id) + if not deleted: + raise HTTPException(status_code=404, detail=f"Session {session_id} not found") + return {"session_id": session_id, "deleted": True} + + +@router.delete("/sessions") +async def delete_all_sessions(): + """Delete ALL sessions (cleanup).""" + _cache.clear() + count = await delete_all_sessions_db() + return {"deleted_count": count} + + +@router.post("/sessions/{session_id}/create-box-sessions") +async def create_box_sessions(session_id: str): + """Create sub-sessions for each detected box region. + + Crops box regions from the cropped/dewarped image and creates + independent sub-sessions that can be processed through the pipeline. 
+ """ + session = await get_session_db(session_id) + if not session: + raise HTTPException(status_code=404, detail=f"Session {session_id} not found") + + column_result = session.get("column_result") + if not column_result: + raise HTTPException(status_code=400, detail="Column detection must be completed first") + + zones = column_result.get("zones") or [] + box_zones = [z for z in zones if z.get("zone_type") == "box" and z.get("box")] + if not box_zones: + return {"session_id": session_id, "sub_sessions": [], "message": "No boxes detected"} + + # Check for existing sub-sessions + existing = await get_sub_sessions(session_id) + if existing: + return { + "session_id": session_id, + "sub_sessions": [{"id": s["id"], "box_index": s.get("box_index")} for s in existing], + "message": f"{len(existing)} sub-session(s) already exist", + } + + # Load base image + base_png = await get_session_image(session_id, "cropped") + if not base_png: + base_png = await get_session_image(session_id, "dewarped") + if not base_png: + raise HTTPException(status_code=400, detail="No base image available") + + arr = np.frombuffer(base_png, dtype=np.uint8) + img = cv2.imdecode(arr, cv2.IMREAD_COLOR) + if img is None: + raise HTTPException(status_code=500, detail="Failed to decode image") + + parent_name = session.get("name", "Session") + created = [] + + for i, zone in enumerate(box_zones): + box = zone["box"] + bx, by = box["x"], box["y"] + bw, bh = box["width"], box["height"] + + # Crop box region with small padding + pad = 5 + y1 = max(0, by - pad) + y2 = min(img.shape[0], by + bh + pad) + x1 = max(0, bx - pad) + x2 = min(img.shape[1], bx + bw + pad) + crop = img[y1:y2, x1:x2] + + # Encode as PNG + success, png_buf = cv2.imencode(".png", crop) + if not success: + logger.warning(f"Failed to encode box {i} crop for session {session_id}") + continue + + sub_id = str(uuid.uuid4()) + sub_name = f"{parent_name} — Box {i + 1}" + + await create_session_db( + session_id=sub_id, + name=sub_name, + 
filename=session.get("filename", "box-crop.png"), + original_png=png_buf.tobytes(), + parent_session_id=session_id, + box_index=i, + ) + + # Cache the BGR for immediate processing + # Promote original to cropped so column/row/word detection finds it + box_bgr = crop.copy() + _cache[sub_id] = { + "id": sub_id, + "filename": session.get("filename", "box-crop.png"), + "name": sub_name, + "parent_session_id": session_id, + "original_bgr": box_bgr, + "oriented_bgr": None, + "cropped_bgr": box_bgr, + "deskewed_bgr": None, + "dewarped_bgr": None, + "orientation_result": None, + "crop_result": None, + "deskew_result": None, + "dewarp_result": None, + "ground_truth": {}, + "current_step": 1, + } + + created.append({ + "id": sub_id, + "name": sub_name, + "box_index": i, + "box": box, + "image_width": crop.shape[1], + "image_height": crop.shape[0], + }) + + logger.info(f"Created box sub-session {sub_id} for session {session_id} " + f"(box {i}, {crop.shape[1]}x{crop.shape[0]})") + + return { + "session_id": session_id, + "sub_sessions": created, + "total": len(created), + } diff --git a/klausur-service/backend/ocr_pipeline_sessions_images.py b/klausur-service/backend/ocr_pipeline_sessions_images.py new file mode 100644 index 0000000..79da448 --- /dev/null +++ b/klausur-service/backend/ocr_pipeline_sessions_images.py @@ -0,0 +1,176 @@ +""" +OCR Pipeline Sessions Images — image serving, thumbnails, pipeline log, +categories, and document type detection. + +Extracted from ocr_pipeline_sessions.py for modularity. + +Lizenz: Apache 2.0 +DATENSCHUTZ: Alle Verarbeitung erfolgt lokal. 
+""" + +import logging +import time +from typing import Any, Dict + +import cv2 +import numpy as np +from fastapi import APIRouter, HTTPException, Query +from fastapi.responses import Response + +from cv_vocab_pipeline import create_ocr_image, detect_document_type +from ocr_pipeline_common import ( + VALID_DOCUMENT_CATEGORIES, + _append_pipeline_log, + _cache, + _get_base_image_png, + _get_cached, + _load_session_to_cache, +) +from ocr_pipeline_overlays import render_overlay +from ocr_pipeline_session_store import ( + get_session_db, + get_session_image, + update_session_db, +) + +logger = logging.getLogger(__name__) + +router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"]) + + +# --------------------------------------------------------------------------- +# Thumbnail & Log Endpoints +# --------------------------------------------------------------------------- + +@router.get("/sessions/{session_id}/thumbnail") +async def get_session_thumbnail(session_id: str, size: int = Query(default=80, ge=16, le=400)): + """Return a small thumbnail of the original image.""" + original_png = await get_session_image(session_id, "original") + if not original_png: + raise HTTPException(status_code=404, detail=f"Session {session_id} not found or no image") + arr = np.frombuffer(original_png, dtype=np.uint8) + img = cv2.imdecode(arr, cv2.IMREAD_COLOR) + if img is None: + raise HTTPException(status_code=500, detail="Failed to decode image") + h, w = img.shape[:2] + scale = size / max(h, w) + new_w, new_h = int(w * scale), int(h * scale) + thumb = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_AREA) + _, png_bytes = cv2.imencode(".png", thumb) + return Response(content=png_bytes.tobytes(), media_type="image/png", + headers={"Cache-Control": "public, max-age=3600"}) + + +@router.get("/sessions/{session_id}/pipeline-log") +async def get_pipeline_log(session_id: str): + """Get the pipeline execution log for a session.""" + session = await 
get_session_db(session_id) + if not session: + raise HTTPException(status_code=404, detail=f"Session {session_id} not found") + return {"session_id": session_id, "pipeline_log": session.get("pipeline_log") or {"steps": []}} + + +@router.get("/categories") +async def list_categories(): + """List valid document categories.""" + return {"categories": sorted(VALID_DOCUMENT_CATEGORIES)} + + +# --------------------------------------------------------------------------- +# Image Endpoints +# --------------------------------------------------------------------------- + +@router.get("/sessions/{session_id}/image/{image_type}") +async def get_image(session_id: str, image_type: str): + """Serve session images: original, deskewed, dewarped, binarized, structure-overlay, columns-overlay, or rows-overlay.""" + valid_types = {"original", "oriented", "cropped", "deskewed", "dewarped", "binarized", "structure-overlay", "columns-overlay", "rows-overlay", "words-overlay", "clean"} + if image_type not in valid_types: + raise HTTPException(status_code=400, detail=f"Unknown image type: {image_type}") + + if image_type == "structure-overlay": + return await render_overlay("structure", session_id) + + if image_type == "columns-overlay": + return await render_overlay("columns", session_id) + + if image_type == "rows-overlay": + return await render_overlay("rows", session_id) + + if image_type == "words-overlay": + return await render_overlay("words", session_id) + + # Try cache first for fast serving + cached = _cache.get(session_id) + if cached: + png_key = f"{image_type}_png" if image_type != "original" else None + bgr_key = f"{image_type}_bgr" if image_type != "binarized" else None + + # For binarized, check if we have it cached as PNG + if image_type == "binarized" and cached.get("binarized_png"): + return Response(content=cached["binarized_png"], media_type="image/png") + + # Load from DB — for cropped/dewarped, fall back through the chain + if image_type in ("cropped", "dewarped"): + 
data = await _get_base_image_png(session_id) + else: + data = await get_session_image(session_id, image_type) + if not data: + raise HTTPException(status_code=404, detail=f"Image '{image_type}' not available yet") + + return Response(content=data, media_type="image/png") + + +# --------------------------------------------------------------------------- +# Document Type Detection (between Dewarp and Columns) +# --------------------------------------------------------------------------- + +@router.post("/sessions/{session_id}/detect-type") +async def detect_type(session_id: str): + """Detect document type (vocab_table, full_text, generic_table). + + Should be called after crop (clean image available). + Falls back to dewarped if crop was skipped. + Stores result in session for frontend to decide pipeline flow. + """ + if session_id not in _cache: + await _load_session_to_cache(session_id) + cached = _get_cached(session_id) + + img_bgr = cached.get("cropped_bgr") if cached.get("cropped_bgr") is not None else cached.get("dewarped_bgr") + if img_bgr is None: + raise HTTPException(status_code=400, detail="Crop or dewarp must be completed first") + + t0 = time.time() + ocr_img = create_ocr_image(img_bgr) + result = detect_document_type(ocr_img, img_bgr) + duration = time.time() - t0 + + result_dict = { + "doc_type": result.doc_type, + "confidence": result.confidence, + "pipeline": result.pipeline, + "skip_steps": result.skip_steps, + "features": result.features, + "duration_seconds": round(duration, 2), + } + + # Persist to DB + await update_session_db( + session_id, + doc_type=result.doc_type, + doc_type_result=result_dict, + ) + + cached["doc_type_result"] = result_dict + + logger.info(f"OCR Pipeline: detect-type session {session_id}: " + f"{result.doc_type} (confidence={result.confidence}, {duration:.2f}s)") + + await _append_pipeline_log(session_id, "detect_type", { + "doc_type": result.doc_type, + "pipeline": result.pipeline, + "confidence": result.confidence, + 
**{k: v for k, v in (result.features or {}).items() if isinstance(v, (int, float, str, bool))}, + }, duration_ms=int(duration * 1000)) + + return {"session_id": session_id, **result_dict} diff --git a/klausur-service/backend/self_rag.py b/klausur-service/backend/self_rag.py index 61de02c..9bcf871 100644 --- a/klausur-service/backend/self_rag.py +++ b/klausur-service/backend/self_rag.py @@ -1,529 +1,38 @@ """ -Self-RAG / Corrective RAG Module +Self-RAG / Corrective RAG Module — barrel re-export. -Implements self-reflective RAG that can: -1. Grade retrieved documents for relevance -2. Decide if more retrieval is needed -3. Reformulate queries if initial retrieval fails -4. Filter irrelevant passages before generation -5. Grade answers for groundedness and hallucination +All implementation split into: + self_rag_grading — document relevance grading, filtering, decisions + self_rag_retrieval — query reformulation, retrieval loop, info + +IMPORTANT: Self-RAG is DISABLED by default for privacy reasons! +When enabled, search queries and retrieved documents are sent to OpenAI API. Based on research: -- Self-RAG (Asai et al., 2023): Learning to retrieve, generate, and critique -- Corrective RAG (Yan et al., 2024): Self-correcting retrieval augmented generation - -This is especially useful for German educational documents where: -- Queries may use informal language -- Documents use formal/technical terminology -- Context must be precisely matched to scoring criteria +- Self-RAG (Asai et al., 2023) +- Corrective RAG (Yan et al., 2024) """ -import os -from typing import List, Dict, Optional, Tuple -from enum import Enum -import httpx - -# Configuration -# IMPORTANT: Self-RAG is DISABLED by default for privacy reasons! -# When enabled, search queries and retrieved documents are sent to OpenAI API -# for relevance grading and query reformulation. This exposes user data to third parties. -# Only enable if you have explicit user consent for data processing. 
-SELF_RAG_ENABLED = os.getenv("SELF_RAG_ENABLED", "false").lower() == "true" -OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "") -SELF_RAG_MODEL = os.getenv("SELF_RAG_MODEL", "gpt-4o-mini") - -# Thresholds for self-reflection -RELEVANCE_THRESHOLD = float(os.getenv("SELF_RAG_RELEVANCE_THRESHOLD", "0.6")) -GROUNDING_THRESHOLD = float(os.getenv("SELF_RAG_GROUNDING_THRESHOLD", "0.7")) -MAX_RETRIEVAL_ATTEMPTS = int(os.getenv("SELF_RAG_MAX_ATTEMPTS", "2")) - - -class RetrievalDecision(Enum): - """Decision after grading retrieval.""" - SUFFICIENT = "sufficient" # Context is good, proceed to generation - NEEDS_MORE = "needs_more" # Need to retrieve more documents - REFORMULATE = "reformulate" # Query needs reformulation - FALLBACK = "fallback" # Use fallback (no good context found) - - -class SelfRAGError(Exception): - """Error during Self-RAG processing.""" - pass - - -async def grade_document_relevance( - query: str, - document: str, -) -> Tuple[float, str]: - """ - Grade whether a document is relevant to the query. - - Returns a score between 0 (irrelevant) and 1 (highly relevant) - along with an explanation. - """ - if not OPENAI_API_KEY: - # Fallback: simple keyword overlap - query_words = set(query.lower().split()) - doc_words = set(document.lower().split()) - overlap = len(query_words & doc_words) / max(len(query_words), 1) - return min(overlap * 2, 1.0), "Keyword-based relevance (no LLM)" - - prompt = f"""Bewerte, ob das folgende Dokument relevant für die Suchanfrage ist. - -SUCHANFRAGE: {query} - -DOKUMENT: -{document[:2000]} - -Ist dieses Dokument relevant, um die Anfrage zu beantworten? -Berücksichtige: -- Thematische Übereinstimmung -- Enthält das Dokument spezifische Informationen zur Anfrage? -- Würde dieses Dokument bei der Beantwortung helfen? 
- -Antworte im Format: -SCORE: [0.0-1.0] -BEGRÜNDUNG: [Kurze Erklärung]""" - - try: - async with httpx.AsyncClient() as client: - response = await client.post( - "https://api.openai.com/v1/chat/completions", - headers={ - "Authorization": f"Bearer {OPENAI_API_KEY}", - "Content-Type": "application/json" - }, - json={ - "model": SELF_RAG_MODEL, - "messages": [{"role": "user", "content": prompt}], - "max_tokens": 150, - "temperature": 0.0, - }, - timeout=30.0 - ) - - if response.status_code != 200: - return 0.5, f"API error: {response.status_code}" - - result = response.json()["choices"][0]["message"]["content"] - - import re - score_match = re.search(r'SCORE:\s*([\d.]+)', result) - score = float(score_match.group(1)) if score_match else 0.5 - - reason_match = re.search(r'BEGRÜNDUNG:\s*(.+)', result, re.DOTALL) - reason = reason_match.group(1).strip() if reason_match else result - - return min(max(score, 0.0), 1.0), reason - - except Exception as e: - return 0.5, f"Grading error: {str(e)}" - - -async def grade_documents_batch( - query: str, - documents: List[str], -) -> List[Tuple[float, str]]: - """ - Grade multiple documents for relevance. - - Returns list of (score, reason) tuples. - """ - results = [] - for doc in documents: - score, reason = await grade_document_relevance(query, doc) - results.append((score, reason)) - return results - - -async def filter_relevant_documents( - query: str, - documents: List[Dict], - threshold: float = RELEVANCE_THRESHOLD, -) -> Tuple[List[Dict], List[Dict]]: - """ - Filter documents by relevance, separating relevant from irrelevant. 
- - Args: - query: The search query - documents: List of document dicts with 'text' field - threshold: Minimum relevance score to keep - - Returns: - Tuple of (relevant_docs, filtered_out_docs) - """ - relevant = [] - filtered = [] - - for doc in documents: - text = doc.get("text", "") - score, reason = await grade_document_relevance(query, text) - - doc_with_grade = doc.copy() - doc_with_grade["relevance_score"] = score - doc_with_grade["relevance_reason"] = reason - - if score >= threshold: - relevant.append(doc_with_grade) - else: - filtered.append(doc_with_grade) - - # Sort relevant by score - relevant.sort(key=lambda x: x.get("relevance_score", 0), reverse=True) - - return relevant, filtered - - -async def decide_retrieval_strategy( - query: str, - documents: List[Dict], - attempt: int = 1, -) -> Tuple[RetrievalDecision, Dict]: - """ - Decide what to do based on current retrieval results. - - Args: - query: The search query - documents: Retrieved documents with relevance scores - attempt: Current retrieval attempt number - - Returns: - Tuple of (decision, metadata) - """ - if not documents: - if attempt >= MAX_RETRIEVAL_ATTEMPTS: - return RetrievalDecision.FALLBACK, {"reason": "No documents found after max attempts"} - return RetrievalDecision.REFORMULATE, {"reason": "No documents retrieved"} - - # Check average relevance - scores = [doc.get("relevance_score", 0.5) for doc in documents] - avg_score = sum(scores) / len(scores) - max_score = max(scores) - - if max_score >= RELEVANCE_THRESHOLD and avg_score >= RELEVANCE_THRESHOLD * 0.7: - return RetrievalDecision.SUFFICIENT, { - "avg_relevance": avg_score, - "max_relevance": max_score, - "doc_count": len(documents), - } - - if attempt >= MAX_RETRIEVAL_ATTEMPTS: - if max_score >= RELEVANCE_THRESHOLD * 0.5: - # At least some relevant context, proceed with caution - return RetrievalDecision.SUFFICIENT, { - "avg_relevance": avg_score, - "warning": "Low relevance after max attempts", - } - return 
RetrievalDecision.FALLBACK, {"reason": "Max attempts reached, low relevance"} - - if avg_score < 0.3: - return RetrievalDecision.REFORMULATE, { - "reason": "Very low relevance, query reformulation needed", - "avg_relevance": avg_score, - } - - return RetrievalDecision.NEEDS_MORE, { - "reason": "Moderate relevance, retrieving more documents", - "avg_relevance": avg_score, - } - - -async def reformulate_query( - original_query: str, - context: Optional[str] = None, - previous_results_summary: Optional[str] = None, -) -> str: - """ - Reformulate a query to improve retrieval. - - Uses LLM to generate a better query based on: - - Original query - - Optional context (subject, niveau, etc.) - - Summary of why previous retrieval failed - """ - if not OPENAI_API_KEY: - # Simple reformulation: expand abbreviations, add synonyms - reformulated = original_query - expansions = { - "EA": "erhöhtes Anforderungsniveau", - "eA": "erhöhtes Anforderungsniveau", - "GA": "grundlegendes Anforderungsniveau", - "gA": "grundlegendes Anforderungsniveau", - "AFB": "Anforderungsbereich", - "Abi": "Abitur", - } - for abbr, expansion in expansions.items(): - if abbr in original_query: - reformulated = reformulated.replace(abbr, f"{abbr} ({expansion})") - return reformulated - - prompt = f"""Du bist ein Experte für deutsche Bildungsstandards und Prüfungsanforderungen. - -Die folgende Suchanfrage hat keine guten Ergebnisse geliefert: -ORIGINAL: {original_query} - -{f"KONTEXT: {context}" if context else ""} -{f"PROBLEM MIT VORHERIGEN ERGEBNISSEN: {previous_results_summary}" if previous_results_summary else ""} - -Formuliere die Anfrage so um, dass sie: -1. Formellere/technischere Begriffe verwendet (wie in offiziellen Dokumenten) -2. Relevante Synonyme oder verwandte Begriffe einschließt -3. 
Spezifischer auf Erwartungshorizonte/Bewertungskriterien ausgerichtet ist - -Antworte NUR mit der umformulierten Suchanfrage, ohne Erklärung.""" - - try: - async with httpx.AsyncClient() as client: - response = await client.post( - "https://api.openai.com/v1/chat/completions", - headers={ - "Authorization": f"Bearer {OPENAI_API_KEY}", - "Content-Type": "application/json" - }, - json={ - "model": SELF_RAG_MODEL, - "messages": [{"role": "user", "content": prompt}], - "max_tokens": 100, - "temperature": 0.3, - }, - timeout=30.0 - ) - - if response.status_code != 200: - return original_query - - return response.json()["choices"][0]["message"]["content"].strip() - - except Exception: - return original_query - - -async def grade_answer_groundedness( - answer: str, - contexts: List[str], -) -> Tuple[float, List[str]]: - """ - Grade whether an answer is grounded in the provided contexts. - - Returns: - Tuple of (grounding_score, list of unsupported claims) - """ - if not OPENAI_API_KEY: - return 0.5, ["LLM not configured for grounding check"] - - context_text = "\n---\n".join(contexts[:5]) - - prompt = f"""Analysiere, ob die folgende Antwort vollständig durch die Kontexte gestützt wird. - -KONTEXTE: -{context_text} - -ANTWORT: -{answer} - -Identifiziere: -1. Welche Aussagen sind durch die Kontexte belegt? -2. Welche Aussagen sind NICHT belegt (potenzielle Halluzinationen)? 
- -Antworte im Format: -SCORE: [0.0-1.0] (1.0 = vollständig belegt) -NICHT_BELEGT: [Liste der nicht belegten Aussagen, eine pro Zeile, oder "Keine"]""" - - try: - async with httpx.AsyncClient() as client: - response = await client.post( - "https://api.openai.com/v1/chat/completions", - headers={ - "Authorization": f"Bearer {OPENAI_API_KEY}", - "Content-Type": "application/json" - }, - json={ - "model": SELF_RAG_MODEL, - "messages": [{"role": "user", "content": prompt}], - "max_tokens": 300, - "temperature": 0.0, - }, - timeout=30.0 - ) - - if response.status_code != 200: - return 0.5, [f"API error: {response.status_code}"] - - result = response.json()["choices"][0]["message"]["content"] - - import re - score_match = re.search(r'SCORE:\s*([\d.]+)', result) - score = float(score_match.group(1)) if score_match else 0.5 - - unsupported_match = re.search(r'NICHT_BELEGT:\s*(.+)', result, re.DOTALL) - unsupported_text = unsupported_match.group(1).strip() if unsupported_match else "" - - if unsupported_text.lower() == "keine": - unsupported = [] - else: - unsupported = [line.strip() for line in unsupported_text.split("\n") if line.strip()] - - return min(max(score, 0.0), 1.0), unsupported - - except Exception as e: - return 0.5, [f"Grounding check error: {str(e)}"] - - -async def self_rag_retrieve( - query: str, - search_func, - subject: Optional[str] = None, - niveau: Optional[str] = None, - initial_top_k: int = 10, - final_top_k: int = 5, - **search_kwargs -) -> Dict: - """ - Perform Self-RAG enhanced retrieval with reflection and correction. - - This implements a retrieval loop that: - 1. Retrieves initial documents - 2. Grades them for relevance - 3. Decides if more retrieval is needed - 4. Reformulates query if necessary - 5. 
Returns filtered, high-quality context - - Args: - query: The search query - search_func: Async function to perform the actual search - subject: Optional subject context - niveau: Optional niveau context - initial_top_k: Number of documents for initial retrieval - final_top_k: Maximum documents to return - **search_kwargs: Additional args for search_func - - Returns: - Dict with results, metadata, and reflection trace - """ - if not SELF_RAG_ENABLED: - # Fall back to simple search - results = await search_func(query=query, limit=final_top_k, **search_kwargs) - return { - "results": results, - "self_rag_enabled": False, - "query_used": query, - } - - trace = [] - current_query = query - attempt = 1 - - while attempt <= MAX_RETRIEVAL_ATTEMPTS: - # Step 1: Retrieve documents - results = await search_func(query=current_query, limit=initial_top_k, **search_kwargs) - - trace.append({ - "attempt": attempt, - "query": current_query, - "retrieved_count": len(results) if results else 0, - }) - - if not results: - attempt += 1 - if attempt <= MAX_RETRIEVAL_ATTEMPTS: - current_query = await reformulate_query( - query, - context=f"Fach: {subject}" if subject else None, - previous_results_summary="Keine Dokumente gefunden" - ) - trace[-1]["action"] = "reformulate" - trace[-1]["new_query"] = current_query - continue - - # Step 2: Grade documents for relevance - relevant, filtered = await filter_relevant_documents(current_query, results) - - trace[-1]["relevant_count"] = len(relevant) - trace[-1]["filtered_count"] = len(filtered) - - # Step 3: Decide what to do - decision, decision_meta = await decide_retrieval_strategy( - current_query, relevant, attempt - ) - - trace[-1]["decision"] = decision.value - trace[-1]["decision_meta"] = decision_meta - - if decision == RetrievalDecision.SUFFICIENT: - # We have good context, return it - return { - "results": relevant[:final_top_k], - "self_rag_enabled": True, - "query_used": current_query, - "original_query": query if current_query != 
query else None, - "attempts": attempt, - "decision": decision.value, - "trace": trace, - "filtered_out_count": len(filtered), - } - - elif decision == RetrievalDecision.REFORMULATE: - # Reformulate and try again - avg_score = decision_meta.get("avg_relevance", 0) - current_query = await reformulate_query( - query, - context=f"Fach: {subject}" if subject else None, - previous_results_summary=f"Durchschnittliche Relevanz: {avg_score:.2f}" - ) - trace[-1]["action"] = "reformulate" - trace[-1]["new_query"] = current_query - - elif decision == RetrievalDecision.NEEDS_MORE: - # Retrieve more with expanded query - current_query = f"{current_query} Bewertungskriterien Anforderungen" - trace[-1]["action"] = "expand_query" - trace[-1]["new_query"] = current_query - - elif decision == RetrievalDecision.FALLBACK: - # Return what we have, even if not ideal - return { - "results": (relevant or results)[:final_top_k], - "self_rag_enabled": True, - "query_used": current_query, - "original_query": query if current_query != query else None, - "attempts": attempt, - "decision": decision.value, - "warning": "Fallback mode - low relevance context", - "trace": trace, - } - - attempt += 1 - - # Max attempts reached - return { - "results": results[:final_top_k] if results else [], - "self_rag_enabled": True, - "query_used": current_query, - "original_query": query if current_query != query else None, - "attempts": attempt - 1, - "decision": "max_attempts", - "warning": "Max retrieval attempts reached", - "trace": trace, - } - - -def get_self_rag_info() -> dict: - """Get information about Self-RAG configuration.""" - return { - "enabled": SELF_RAG_ENABLED, - "llm_configured": bool(OPENAI_API_KEY), - "model": SELF_RAG_MODEL, - "relevance_threshold": RELEVANCE_THRESHOLD, - "grounding_threshold": GROUNDING_THRESHOLD, - "max_retrieval_attempts": MAX_RETRIEVAL_ATTEMPTS, - "features": [ - "document_grading", - "relevance_filtering", - "query_reformulation", - "answer_grounding_check", - 
"retrieval_decision", - ], - "sends_data_externally": True, # ALWAYS true when enabled - documents sent to OpenAI - "privacy_warning": "When enabled, queries and documents are sent to OpenAI API for grading", - "default_enabled": False, # Disabled by default for privacy - } +# Grading: relevance, filtering, decisions, groundedness +from self_rag_grading import ( # noqa: F401 + SELF_RAG_ENABLED, + OPENAI_API_KEY, + SELF_RAG_MODEL, + RELEVANCE_THRESHOLD, + GROUNDING_THRESHOLD, + MAX_RETRIEVAL_ATTEMPTS, + RetrievalDecision, + SelfRAGError, + grade_document_relevance, + grade_documents_batch, + filter_relevant_documents, + decide_retrieval_strategy, + grade_answer_groundedness, +) + +# Retrieval: reformulation, loop, info +from self_rag_retrieval import ( # noqa: F401 + reformulate_query, + self_rag_retrieve, + get_self_rag_info, +) diff --git a/klausur-service/backend/self_rag_grading.py b/klausur-service/backend/self_rag_grading.py new file mode 100644 index 0000000..be6b096 --- /dev/null +++ b/klausur-service/backend/self_rag_grading.py @@ -0,0 +1,285 @@ +""" +Self-RAG Grading — document relevance grading, filtering, retrieval decisions. + +Extracted from self_rag.py for modularity. 
+ +Based on research: +- Self-RAG (Asai et al., 2023) +- Corrective RAG (Yan et al., 2024) +""" + +import os +from typing import List, Dict, Optional, Tuple +from enum import Enum +import httpx + +# Configuration +SELF_RAG_ENABLED = os.getenv("SELF_RAG_ENABLED", "false").lower() == "true" +OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "") +SELF_RAG_MODEL = os.getenv("SELF_RAG_MODEL", "gpt-4o-mini") + +# Thresholds for self-reflection +RELEVANCE_THRESHOLD = float(os.getenv("SELF_RAG_RELEVANCE_THRESHOLD", "0.6")) +GROUNDING_THRESHOLD = float(os.getenv("SELF_RAG_GROUNDING_THRESHOLD", "0.7")) +MAX_RETRIEVAL_ATTEMPTS = int(os.getenv("SELF_RAG_MAX_ATTEMPTS", "2")) + + +class RetrievalDecision(Enum): + """Decision after grading retrieval.""" + SUFFICIENT = "sufficient" # Context is good, proceed to generation + NEEDS_MORE = "needs_more" # Need to retrieve more documents + REFORMULATE = "reformulate" # Query needs reformulation + FALLBACK = "fallback" # Use fallback (no good context found) + + +class SelfRAGError(Exception): + """Error during Self-RAG processing.""" + pass + + +async def grade_document_relevance( + query: str, + document: str, +) -> Tuple[float, str]: + """ + Grade whether a document is relevant to the query. + + Returns a score between 0 (irrelevant) and 1 (highly relevant) + along with an explanation. + """ + if not OPENAI_API_KEY: + # Fallback: simple keyword overlap + query_words = set(query.lower().split()) + doc_words = set(document.lower().split()) + overlap = len(query_words & doc_words) / max(len(query_words), 1) + return min(overlap * 2, 1.0), "Keyword-based relevance (no LLM)" + + prompt = f"""Bewerte, ob das folgende Dokument relevant fuer die Suchanfrage ist. + +SUCHANFRAGE: {query} + +DOKUMENT: +{document[:2000]} + +Ist dieses Dokument relevant, um die Anfrage zu beantworten? +Beruecksichtige: +- Thematische Uebereinstimmung +- Enthaelt das Dokument spezifische Informationen zur Anfrage? +- Wuerde dieses Dokument bei der Beantwortung helfen? 
+ +Antworte im Format: +SCORE: [0.0-1.0] +BEGRUENDUNG: [Kurze Erklaerung]""" + + try: + async with httpx.AsyncClient() as client: + response = await client.post( + "https://api.openai.com/v1/chat/completions", + headers={ + "Authorization": f"Bearer {OPENAI_API_KEY}", + "Content-Type": "application/json" + }, + json={ + "model": SELF_RAG_MODEL, + "messages": [{"role": "user", "content": prompt}], + "max_tokens": 150, + "temperature": 0.0, + }, + timeout=30.0 + ) + + if response.status_code != 200: + return 0.5, f"API error: {response.status_code}" + + result = response.json()["choices"][0]["message"]["content"] + + import re + score_match = re.search(r'SCORE:\s*([\d.]+)', result) + score = float(score_match.group(1)) if score_match else 0.5 + + reason_match = re.search(r'BEGRUENDUNG:\s*(.+)', result, re.DOTALL) + reason = reason_match.group(1).strip() if reason_match else result + + return min(max(score, 0.0), 1.0), reason + + except Exception as e: + return 0.5, f"Grading error: {str(e)}" + + +async def grade_documents_batch( + query: str, + documents: List[str], +) -> List[Tuple[float, str]]: + """ + Grade multiple documents for relevance. + + Returns list of (score, reason) tuples. + """ + results = [] + for doc in documents: + score, reason = await grade_document_relevance(query, doc) + results.append((score, reason)) + return results + + +async def filter_relevant_documents( + query: str, + documents: List[Dict], + threshold: float = RELEVANCE_THRESHOLD, +) -> Tuple[List[Dict], List[Dict]]: + """ + Filter documents by relevance, separating relevant from irrelevant. 
+ + Args: + query: The search query + documents: List of document dicts with 'text' field + threshold: Minimum relevance score to keep + + Returns: + Tuple of (relevant_docs, filtered_out_docs) + """ + relevant = [] + filtered = [] + + for doc in documents: + text = doc.get("text", "") + score, reason = await grade_document_relevance(query, text) + + doc_with_grade = doc.copy() + doc_with_grade["relevance_score"] = score + doc_with_grade["relevance_reason"] = reason + + if score >= threshold: + relevant.append(doc_with_grade) + else: + filtered.append(doc_with_grade) + + # Sort relevant by score + relevant.sort(key=lambda x: x.get("relevance_score", 0), reverse=True) + + return relevant, filtered + + +async def decide_retrieval_strategy( + query: str, + documents: List[Dict], + attempt: int = 1, +) -> Tuple[RetrievalDecision, Dict]: + """ + Decide what to do based on current retrieval results. + + Args: + query: The search query + documents: Retrieved documents with relevance scores + attempt: Current retrieval attempt number + + Returns: + Tuple of (decision, metadata) + """ + if not documents: + if attempt >= MAX_RETRIEVAL_ATTEMPTS: + return RetrievalDecision.FALLBACK, {"reason": "No documents found after max attempts"} + return RetrievalDecision.REFORMULATE, {"reason": "No documents retrieved"} + + # Check average relevance + scores = [doc.get("relevance_score", 0.5) for doc in documents] + avg_score = sum(scores) / len(scores) + max_score = max(scores) + + if max_score >= RELEVANCE_THRESHOLD and avg_score >= RELEVANCE_THRESHOLD * 0.7: + return RetrievalDecision.SUFFICIENT, { + "avg_relevance": avg_score, + "max_relevance": max_score, + "doc_count": len(documents), + } + + if attempt >= MAX_RETRIEVAL_ATTEMPTS: + if max_score >= RELEVANCE_THRESHOLD * 0.5: + # At least some relevant context, proceed with caution + return RetrievalDecision.SUFFICIENT, { + "avg_relevance": avg_score, + "warning": "Low relevance after max attempts", + } + return 
RetrievalDecision.FALLBACK, {"reason": "Max attempts reached, low relevance"} + + if avg_score < 0.3: + return RetrievalDecision.REFORMULATE, { + "reason": "Very low relevance, query reformulation needed", + "avg_relevance": avg_score, + } + + return RetrievalDecision.NEEDS_MORE, { + "reason": "Moderate relevance, retrieving more documents", + "avg_relevance": avg_score, + } + + +async def grade_answer_groundedness( + answer: str, + contexts: List[str], +) -> Tuple[float, List[str]]: + """ + Grade whether an answer is grounded in the provided contexts. + + Returns: + Tuple of (grounding_score, list of unsupported claims) + """ + if not OPENAI_API_KEY: + return 0.5, ["LLM not configured for grounding check"] + + context_text = "\n---\n".join(contexts[:5]) + + prompt = f"""Analysiere, ob die folgende Antwort vollstaendig durch die Kontexte gestuetzt wird. + +KONTEXTE: +{context_text} + +ANTWORT: +{answer} + +Identifiziere: +1. Welche Aussagen sind durch die Kontexte belegt? +2. Welche Aussagen sind NICHT belegt (potenzielle Halluzinationen)? 
+ +Antworte im Format: +SCORE: [0.0-1.0] (1.0 = vollstaendig belegt) +NICHT_BELEGT: [Liste der nicht belegten Aussagen, eine pro Zeile, oder "Keine"]""" + + try: + async with httpx.AsyncClient() as client: + response = await client.post( + "https://api.openai.com/v1/chat/completions", + headers={ + "Authorization": f"Bearer {OPENAI_API_KEY}", + "Content-Type": "application/json" + }, + json={ + "model": SELF_RAG_MODEL, + "messages": [{"role": "user", "content": prompt}], + "max_tokens": 300, + "temperature": 0.0, + }, + timeout=30.0 + ) + + if response.status_code != 200: + return 0.5, [f"API error: {response.status_code}"] + + result = response.json()["choices"][0]["message"]["content"] + + import re + score_match = re.search(r'SCORE:\s*([\d.]+)', result) + score = float(score_match.group(1)) if score_match else 0.5 + + unsupported_match = re.search(r'NICHT_BELEGT:\s*(.+)', result, re.DOTALL) + unsupported_text = unsupported_match.group(1).strip() if unsupported_match else "" + + if unsupported_text.lower() == "keine": + unsupported = [] + else: + unsupported = [line.strip() for line in unsupported_text.split("\n") if line.strip()] + + return min(max(score, 0.0), 1.0), unsupported + + except Exception as e: + return 0.5, [f"Grounding check error: {str(e)}"] diff --git a/klausur-service/backend/self_rag_retrieval.py b/klausur-service/backend/self_rag_retrieval.py new file mode 100644 index 0000000..ac989d7 --- /dev/null +++ b/klausur-service/backend/self_rag_retrieval.py @@ -0,0 +1,255 @@ +""" +Self-RAG Retrieval — query reformulation, retrieval loop, info. + +Extracted from self_rag.py for modularity. + +IMPORTANT: Self-RAG is DISABLED by default for privacy reasons! +When enabled, search queries and retrieved documents are sent to OpenAI API +for relevance grading and query reformulation. 
+""" + +import os +from typing import List, Dict, Optional +import httpx + +from self_rag_grading import ( + SELF_RAG_ENABLED, + OPENAI_API_KEY, + SELF_RAG_MODEL, + RELEVANCE_THRESHOLD, + GROUNDING_THRESHOLD, + MAX_RETRIEVAL_ATTEMPTS, + RetrievalDecision, + filter_relevant_documents, + decide_retrieval_strategy, +) + + +async def reformulate_query( + original_query: str, + context: Optional[str] = None, + previous_results_summary: Optional[str] = None, +) -> str: + """ + Reformulate a query to improve retrieval. + + Uses LLM to generate a better query based on: + - Original query + - Optional context (subject, niveau, etc.) + - Summary of why previous retrieval failed + """ + if not OPENAI_API_KEY: + # Simple reformulation: expand abbreviations, add synonyms + reformulated = original_query + expansions = { + "EA": "erhoehtes Anforderungsniveau", + "eA": "erhoehtes Anforderungsniveau", + "GA": "grundlegendes Anforderungsniveau", + "gA": "grundlegendes Anforderungsniveau", + "AFB": "Anforderungsbereich", + "Abi": "Abitur", + } + for abbr, expansion in expansions.items(): + if abbr in original_query: + reformulated = reformulated.replace(abbr, f"{abbr} ({expansion})") + return reformulated + + prompt = f"""Du bist ein Experte fuer deutsche Bildungsstandards und Pruefungsanforderungen. + +Die folgende Suchanfrage hat keine guten Ergebnisse geliefert: +ORIGINAL: {original_query} + +{f"KONTEXT: {context}" if context else ""} +{f"PROBLEM MIT VORHERIGEN ERGEBNISSEN: {previous_results_summary}" if previous_results_summary else ""} + +Formuliere die Anfrage so um, dass sie: +1. Formellere/technischere Begriffe verwendet (wie in offiziellen Dokumenten) +2. Relevante Synonyme oder verwandte Begriffe einschliesst +3. 
Spezifischer auf Erwartungshorizonte/Bewertungskriterien ausgerichtet ist + +Antworte NUR mit der umformulierten Suchanfrage, ohne Erklaerung.""" + + try: + async with httpx.AsyncClient() as client: + response = await client.post( + "https://api.openai.com/v1/chat/completions", + headers={ + "Authorization": f"Bearer {OPENAI_API_KEY}", + "Content-Type": "application/json" + }, + json={ + "model": SELF_RAG_MODEL, + "messages": [{"role": "user", "content": prompt}], + "max_tokens": 100, + "temperature": 0.3, + }, + timeout=30.0 + ) + + if response.status_code != 200: + return original_query + + return response.json()["choices"][0]["message"]["content"].strip() + + except Exception: + return original_query + + +async def self_rag_retrieve( + query: str, + search_func, + subject: Optional[str] = None, + niveau: Optional[str] = None, + initial_top_k: int = 10, + final_top_k: int = 5, + **search_kwargs +) -> Dict: + """ + Perform Self-RAG enhanced retrieval with reflection and correction. + + This implements a retrieval loop that: + 1. Retrieves initial documents + 2. Grades them for relevance + 3. Decides if more retrieval is needed + 4. Reformulates query if necessary + 5. 
Returns filtered, high-quality context + + Args: + query: The search query + search_func: Async function to perform the actual search + subject: Optional subject context + niveau: Optional niveau context + initial_top_k: Number of documents for initial retrieval + final_top_k: Maximum documents to return + **search_kwargs: Additional args for search_func + + Returns: + Dict with results, metadata, and reflection trace + """ + if not SELF_RAG_ENABLED: + # Fall back to simple search + results = await search_func(query=query, limit=final_top_k, **search_kwargs) + return { + "results": results, + "self_rag_enabled": False, + "query_used": query, + } + + trace = [] + current_query = query + attempt = 1 + + while attempt <= MAX_RETRIEVAL_ATTEMPTS: + # Step 1: Retrieve documents + results = await search_func(query=current_query, limit=initial_top_k, **search_kwargs) + + trace.append({ + "attempt": attempt, + "query": current_query, + "retrieved_count": len(results) if results else 0, + }) + + if not results: + attempt += 1 + if attempt <= MAX_RETRIEVAL_ATTEMPTS: + current_query = await reformulate_query( + query, + context=f"Fach: {subject}" if subject else None, + previous_results_summary="Keine Dokumente gefunden" + ) + trace[-1]["action"] = "reformulate" + trace[-1]["new_query"] = current_query + continue + + # Step 2: Grade documents for relevance + relevant, filtered = await filter_relevant_documents(current_query, results) + + trace[-1]["relevant_count"] = len(relevant) + trace[-1]["filtered_count"] = len(filtered) + + # Step 3: Decide what to do + decision, decision_meta = await decide_retrieval_strategy( + current_query, relevant, attempt + ) + + trace[-1]["decision"] = decision.value + trace[-1]["decision_meta"] = decision_meta + + if decision == RetrievalDecision.SUFFICIENT: + # We have good context, return it + return { + "results": relevant[:final_top_k], + "self_rag_enabled": True, + "query_used": current_query, + "original_query": query if current_query != 
query else None, + "attempts": attempt, + "decision": decision.value, + "trace": trace, + "filtered_out_count": len(filtered), + } + + elif decision == RetrievalDecision.REFORMULATE: + # Reformulate and try again + avg_score = decision_meta.get("avg_relevance", 0) + current_query = await reformulate_query( + query, + context=f"Fach: {subject}" if subject else None, + previous_results_summary=f"Durchschnittliche Relevanz: {avg_score:.2f}" + ) + trace[-1]["action"] = "reformulate" + trace[-1]["new_query"] = current_query + + elif decision == RetrievalDecision.NEEDS_MORE: + # Retrieve more with expanded query + current_query = f"{current_query} Bewertungskriterien Anforderungen" + trace[-1]["action"] = "expand_query" + trace[-1]["new_query"] = current_query + + elif decision == RetrievalDecision.FALLBACK: + # Return what we have, even if not ideal + return { + "results": (relevant or results)[:final_top_k], + "self_rag_enabled": True, + "query_used": current_query, + "original_query": query if current_query != query else None, + "attempts": attempt, + "decision": decision.value, + "warning": "Fallback mode - low relevance context", + "trace": trace, + } + + attempt += 1 + + # Max attempts reached + return { + "results": results[:final_top_k] if results else [], + "self_rag_enabled": True, + "query_used": current_query, + "original_query": query if current_query != query else None, + "attempts": attempt - 1, + "decision": "max_attempts", + "warning": "Max retrieval attempts reached", + "trace": trace, + } + + +def get_self_rag_info() -> dict: + """Get information about Self-RAG configuration.""" + return { + "enabled": SELF_RAG_ENABLED, + "llm_configured": bool(OPENAI_API_KEY), + "model": SELF_RAG_MODEL, + "relevance_threshold": RELEVANCE_THRESHOLD, + "grounding_threshold": GROUNDING_THRESHOLD, + "max_retrieval_attempts": MAX_RETRIEVAL_ATTEMPTS, + "features": [ + "document_grading", + "relevance_filtering", + "query_reformulation", + "answer_grounding_check", + 
"retrieval_decision", + ], + "sends_data_externally": True, # ALWAYS true when enabled + "privacy_warning": "When enabled, queries and documents are sent to OpenAI API for grading", + "default_enabled": False, # Disabled by default for privacy + } diff --git a/klausur-service/backend/services/grid_detection_models.py b/klausur-service/backend/services/grid_detection_models.py new file mode 100644 index 0000000..dcd2bf2 --- /dev/null +++ b/klausur-service/backend/services/grid_detection_models.py @@ -0,0 +1,164 @@ +""" +Grid Detection Models v4 + +Data classes for OCR grid detection results. +Coordinates use percentage (0-100) and mm (A4 format). +""" + +from enum import Enum +from dataclasses import dataclass, field +from typing import List, Dict, Any + +# A4 dimensions +A4_WIDTH_MM = 210.0 +A4_HEIGHT_MM = 297.0 + +# Column margin (1mm) +COLUMN_MARGIN_MM = 1.0 +COLUMN_MARGIN_PCT = (COLUMN_MARGIN_MM / A4_WIDTH_MM) * 100 + + +class CellStatus(str, Enum): + EMPTY = "empty" + RECOGNIZED = "recognized" + PROBLEMATIC = "problematic" + MANUAL = "manual" + + +class ColumnType(str, Enum): + ENGLISH = "english" + GERMAN = "german" + EXAMPLE = "example" + UNKNOWN = "unknown" + + +@dataclass +class OCRRegion: + """A word/phrase detected by OCR with bounding box coordinates in percentage (0-100).""" + text: str + confidence: float + x: float # X position as percentage of page width + y: float # Y position as percentage of page height + width: float # Width as percentage of page width + height: float # Height as percentage of page height + + @property + def x_mm(self) -> float: + return round(self.x / 100 * A4_WIDTH_MM, 1) + + @property + def y_mm(self) -> float: + return round(self.y / 100 * A4_HEIGHT_MM, 1) + + @property + def width_mm(self) -> float: + return round(self.width / 100 * A4_WIDTH_MM, 1) + + @property + def height_mm(self) -> float: + return round(self.height / 100 * A4_HEIGHT_MM, 2) + + @property + def center_x(self) -> float: + return self.x + self.width / 2 + + 
@property + def center_y(self) -> float: + return self.y + self.height / 2 + + @property + def right(self) -> float: + return self.x + self.width + + @property + def bottom(self) -> float: + return self.y + self.height + + +@dataclass +class GridCell: + """A cell in the detected grid with coordinates in percentage (0-100).""" + row: int + col: int + x: float + y: float + width: float + height: float + text: str = "" + confidence: float = 0.0 + status: CellStatus = CellStatus.EMPTY + column_type: ColumnType = ColumnType.UNKNOWN + logical_row: int = 0 + logical_col: int = 0 + is_continuation: bool = False + + @property + def x_mm(self) -> float: + return round(self.x / 100 * A4_WIDTH_MM, 1) + + @property + def y_mm(self) -> float: + return round(self.y / 100 * A4_HEIGHT_MM, 1) + + @property + def width_mm(self) -> float: + return round(self.width / 100 * A4_WIDTH_MM, 1) + + @property + def height_mm(self) -> float: + return round(self.height / 100 * A4_HEIGHT_MM, 2) + + def to_dict(self) -> dict: + return { + "row": self.row, + "col": self.col, + "x": round(self.x, 2), + "y": round(self.y, 2), + "width": round(self.width, 2), + "height": round(self.height, 2), + "x_mm": self.x_mm, + "y_mm": self.y_mm, + "width_mm": self.width_mm, + "height_mm": self.height_mm, + "text": self.text, + "confidence": self.confidence, + "status": self.status.value, + "column_type": self.column_type.value, + "logical_row": self.logical_row, + "logical_col": self.logical_col, + "is_continuation": self.is_continuation, + } + + +@dataclass +class GridResult: + """Result of grid detection.""" + rows: int = 0 + columns: int = 0 + cells: List[List[GridCell]] = field(default_factory=list) + column_types: List[str] = field(default_factory=list) + column_boundaries: List[float] = field(default_factory=list) + row_boundaries: List[float] = field(default_factory=list) + deskew_angle: float = 0.0 + stats: Dict[str, Any] = field(default_factory=dict) + + def to_dict(self) -> dict: + cells_dicts = [] + 
for row_cells in self.cells: + cells_dicts.append([c.to_dict() for c in row_cells]) + + return { + "rows": self.rows, + "columns": self.columns, + "cells": cells_dicts, + "column_types": self.column_types, + "column_boundaries": [round(b, 2) for b in self.column_boundaries], + "row_boundaries": [round(b, 2) for b in self.row_boundaries], + "deskew_angle": round(self.deskew_angle, 2), + "stats": self.stats, + "page_dimensions": { + "width_mm": A4_WIDTH_MM, + "height_mm": A4_HEIGHT_MM, + "format": "A4", + }, + } diff --git a/klausur-service/backend/services/grid_detection_service.py b/klausur-service/backend/services/grid_detection_service.py index 4f6c4c2..c544275 100644 --- a/klausur-service/backend/services/grid_detection_service.py +++ b/klausur-service/backend/services/grid_detection_service.py @@ -10,166 +10,21 @@ Lizenz: Apache 2.0 (kommerziell nutzbar) import math import logging -from enum import Enum -from dataclasses import dataclass, field -from typing import List, Optional, Dict, Any, Tuple +from typing import List + +from .grid_detection_models import ( + A4_WIDTH_MM, + A4_HEIGHT_MM, + COLUMN_MARGIN_MM, + CellStatus, + ColumnType, + OCRRegion, + GridCell, + GridResult, +) logger = logging.getLogger(__name__) -# A4 dimensions -A4_WIDTH_MM = 210.0 -A4_HEIGHT_MM = 297.0 - -# Column margin (1mm) -COLUMN_MARGIN_MM = 1.0 -COLUMN_MARGIN_PCT = (COLUMN_MARGIN_MM / A4_WIDTH_MM) * 100 - - -class CellStatus(str, Enum): - EMPTY = "empty" - RECOGNIZED = "recognized" - PROBLEMATIC = "problematic" - MANUAL = "manual" - - -class ColumnType(str, Enum): - ENGLISH = "english" - GERMAN = "german" - EXAMPLE = "example" - UNKNOWN = "unknown" - - -@dataclass -class OCRRegion: - """A word/phrase detected by OCR with bounding box coordinates in percentage (0-100).""" - text: str - confidence: float - x: float # X position as percentage of page width - y: float # Y position as percentage of page height - width: float # Width as percentage of page width - height: float # Height as 
percentage of page height - - @property - def x_mm(self) -> float: - return round(self.x / 100 * A4_WIDTH_MM, 1) - - @property - def y_mm(self) -> float: - return round(self.y / 100 * A4_HEIGHT_MM, 1) - - @property - def width_mm(self) -> float: - return round(self.width / 100 * A4_WIDTH_MM, 1) - - @property - def height_mm(self) -> float: - return round(self.height / 100 * A4_HEIGHT_MM, 2) - - @property - def center_x(self) -> float: - return self.x + self.width / 2 - - @property - def center_y(self) -> float: - return self.y + self.height / 2 - - @property - def right(self) -> float: - return self.x + self.width - - @property - def bottom(self) -> float: - return self.y + self.height - - -@dataclass -class GridCell: - """A cell in the detected grid with coordinates in percentage (0-100).""" - row: int - col: int - x: float - y: float - width: float - height: float - text: str = "" - confidence: float = 0.0 - status: CellStatus = CellStatus.EMPTY - column_type: ColumnType = ColumnType.UNKNOWN - logical_row: int = 0 - logical_col: int = 0 - is_continuation: bool = False - - @property - def x_mm(self) -> float: - return round(self.x / 100 * A4_WIDTH_MM, 1) - - @property - def y_mm(self) -> float: - return round(self.y / 100 * A4_HEIGHT_MM, 1) - - @property - def width_mm(self) -> float: - return round(self.width / 100 * A4_WIDTH_MM, 1) - - @property - def height_mm(self) -> float: - return round(self.height / 100 * A4_HEIGHT_MM, 2) - - def to_dict(self) -> dict: - return { - "row": self.row, - "col": self.col, - "x": round(self.x, 2), - "y": round(self.y, 2), - "width": round(self.width, 2), - "height": round(self.height, 2), - "x_mm": self.x_mm, - "y_mm": self.y_mm, - "width_mm": self.width_mm, - "height_mm": self.height_mm, - "text": self.text, - "confidence": self.confidence, - "status": self.status.value, - "column_type": self.column_type.value, - "logical_row": self.logical_row, - "logical_col": self.logical_col, - "is_continuation": self.is_continuation, - } - 
- -@dataclass -class GridResult: - """Result of grid detection.""" - rows: int = 0 - columns: int = 0 - cells: List[List[GridCell]] = field(default_factory=list) - column_types: List[str] = field(default_factory=list) - column_boundaries: List[float] = field(default_factory=list) - row_boundaries: List[float] = field(default_factory=list) - deskew_angle: float = 0.0 - stats: Dict[str, Any] = field(default_factory=dict) - - def to_dict(self) -> dict: - cells_dicts = [] - for row_cells in self.cells: - cells_dicts.append([c.to_dict() for c in row_cells]) - - return { - "rows": self.rows, - "columns": self.columns, - "cells": cells_dicts, - "column_types": self.column_types, - "column_boundaries": [round(b, 2) for b in self.column_boundaries], - "row_boundaries": [round(b, 2) for b in self.row_boundaries], - "deskew_angle": round(self.deskew_angle, 2), - "stats": self.stats, - "page_dimensions": { - "width_mm": A4_WIDTH_MM, - "height_mm": A4_HEIGHT_MM, - "format": "A4", - }, - } - class GridDetectionService: """Detect grid/table structure from OCR bounding-box regions.""" @@ -184,7 +39,7 @@ class GridDetectionService: """Calculate page skew angle from OCR region positions. Uses left-edge alignment of regions to detect consistent tilt. - Returns angle in degrees, clamped to ±5°. + Returns angle in degrees, clamped to +/-5 degrees. 
""" if len(regions) < 3: return 0.0 @@ -229,12 +84,12 @@ class GridDetectionService: slope = (n * sum_xy - sum_y * sum_x) / denom # Convert slope to angle (slope is dx/dy in percent space) - # Adjust for aspect ratio: A4 is 210/297 ≈ 0.707 + # Adjust for aspect ratio: A4 is 210/297 ~ 0.707 aspect = A4_WIDTH_MM / A4_HEIGHT_MM angle_rad = math.atan(slope * aspect) angle_deg = math.degrees(angle_rad) - # Clamp to ±5° + # Clamp to +/-5 degrees return max(-5.0, min(5.0, round(angle_deg, 2))) def apply_deskew_to_regions(self, regions: List[OCRRegion], angle: float) -> List[OCRRegion]: diff --git a/klausur-service/backend/smart_spell.py b/klausur-service/backend/smart_spell.py index e400474..1926500 100644 --- a/klausur-service/backend/smart_spell.py +++ b/klausur-service/backend/smart_spell.py @@ -1,594 +1,25 @@ """ -SmartSpellChecker — Language-aware OCR post-correction without LLMs. +SmartSpellChecker — barrel re-export. -Uses pyspellchecker (MIT) with dual EN+DE dictionaries for: -- Automatic language detection per word (dual-dictionary heuristic) -- OCR error correction (digit↔letter, umlauts, transpositions) -- Context-based disambiguation (a/I, l/I) via bigram lookup -- Mixed-language support for example sentences +All implementation split into: + smart_spell_core — init, data types, language detection, word correction + smart_spell_text — full text correction, boundary repair, context split Lizenz: Apache 2.0 (kommerziell nutzbar) """ -import logging -import re -from dataclasses import dataclass, field -from typing import Dict, List, Literal, Optional, Set, Tuple - -logger = logging.getLogger(__name__) - -# --------------------------------------------------------------------------- -# Init -# --------------------------------------------------------------------------- - -try: - from spellchecker import SpellChecker as _SpellChecker - _en_spell = _SpellChecker(language='en', distance=1) - _de_spell = _SpellChecker(language='de', distance=1) - _AVAILABLE = True 
-except ImportError: - _AVAILABLE = False - logger.warning("pyspellchecker not installed — SmartSpellChecker disabled") - -Lang = Literal["en", "de", "both", "unknown"] - -# --------------------------------------------------------------------------- -# Bigram context for a/I disambiguation -# --------------------------------------------------------------------------- - -# Words that commonly follow "I" (subject pronoun → verb/modal) -_I_FOLLOWERS: frozenset = frozenset({ - "am", "was", "have", "had", "do", "did", "will", "would", "can", - "could", "should", "shall", "may", "might", "must", - "think", "know", "see", "want", "need", "like", "love", "hate", - "go", "went", "come", "came", "say", "said", "get", "got", - "make", "made", "take", "took", "give", "gave", "tell", "told", - "feel", "felt", "find", "found", "believe", "hope", "wish", - "remember", "forget", "understand", "mean", "meant", - "don't", "didn't", "can't", "won't", "couldn't", "wouldn't", - "shouldn't", "haven't", "hadn't", "isn't", "wasn't", - "really", "just", "also", "always", "never", "often", "sometimes", -}) - -# Words that commonly follow "a" (article → noun/adjective) -_A_FOLLOWERS: frozenset = frozenset({ - "lot", "few", "little", "bit", "good", "bad", "great", "new", "old", - "long", "short", "big", "small", "large", "huge", "tiny", - "nice", "beautiful", "wonderful", "terrible", "horrible", - "man", "woman", "boy", "girl", "child", "dog", "cat", "bird", - "book", "car", "house", "room", "school", "teacher", "student", - "day", "week", "month", "year", "time", "place", "way", - "friend", "family", "person", "problem", "question", "story", - "very", "really", "quite", "rather", "pretty", "single", -}) - -# Digit→letter substitutions (OCR confusion) -_DIGIT_SUBS: Dict[str, List[str]] = { - '0': ['o', 'O'], - '1': ['l', 'I'], - '5': ['s', 'S'], - '6': ['g', 'G'], - '8': ['b', 'B'], - '|': ['I', 'l'], - '/': ['l'], # italic 'l' misread as slash (e.g. 
"p/" → "pl") -} -_SUSPICIOUS_CHARS = frozenset(_DIGIT_SUBS.keys()) - -# Umlaut confusion: OCR drops dots (ü→u, ä→a, ö→o) -_UMLAUT_MAP = { - 'a': 'ä', 'o': 'ö', 'u': 'ü', 'i': 'ü', - 'A': 'Ä', 'O': 'Ö', 'U': 'Ü', 'I': 'Ü', -} - -# Tokenizer — includes | and / so OCR artifacts like "p/" are treated as words -_TOKEN_RE = re.compile(r"([A-Za-zÄÖÜäöüß'|/]+)([^A-Za-zÄÖÜäöüß'|/]*)") - - -# --------------------------------------------------------------------------- -# Data types -# --------------------------------------------------------------------------- - -@dataclass -class CorrectionResult: - original: str - corrected: str - lang_detected: Lang - changed: bool - changes: List[str] = field(default_factory=list) - - -# --------------------------------------------------------------------------- -# Core class -# --------------------------------------------------------------------------- - -class SmartSpellChecker: - """Language-aware OCR spell checker using pyspellchecker (no LLM).""" - - def __init__(self): - if not _AVAILABLE: - raise RuntimeError("pyspellchecker not installed") - self.en = _en_spell - self.de = _de_spell - - # --- Language detection --- - - def detect_word_lang(self, word: str) -> Lang: - """Detect language of a single word using dual-dict heuristic.""" - w = word.lower().strip(".,;:!?\"'()") - if not w: - return "unknown" - in_en = bool(self.en.known([w])) - in_de = bool(self.de.known([w])) - if in_en and in_de: - return "both" - if in_en: - return "en" - if in_de: - return "de" - return "unknown" - - def detect_text_lang(self, text: str) -> Lang: - """Detect dominant language of a text string (sentence/phrase).""" - words = re.findall(r"[A-Za-zÄÖÜäöüß]+", text) - if not words: - return "unknown" - - en_count = 0 - de_count = 0 - for w in words: - lang = self.detect_word_lang(w) - if lang == "en": - en_count += 1 - elif lang == "de": - de_count += 1 - # "both" doesn't count for either - - if en_count > de_count: - return "en" - if de_count > en_count: 
- return "de" - if en_count == de_count and en_count > 0: - return "both" - return "unknown" - - # --- Single-word correction --- - - def _known(self, word: str) -> bool: - """True if word is known in EN or DE dictionary, or is a known abbreviation.""" - w = word.lower() - if bool(self.en.known([w])) or bool(self.de.known([w])): - return True - # Also accept known abbreviations (sth, sb, adj, etc.) - try: - from cv_ocr_engines import _KNOWN_ABBREVIATIONS - if w in _KNOWN_ABBREVIATIONS: - return True - except ImportError: - pass - return False - - def _word_freq(self, word: str) -> float: - """Get word frequency (max of EN and DE).""" - w = word.lower() - return max(self.en.word_usage_frequency(w), self.de.word_usage_frequency(w)) - - def _known_in(self, word: str, lang: str) -> bool: - """True if word is known in a specific language dictionary.""" - w = word.lower() - spell = self.en if lang == "en" else self.de - return bool(spell.known([w])) - - def correct_word(self, word: str, lang: str = "en", - prev_word: str = "", next_word: str = "") -> Optional[str]: - """Correct a single word for the given language. - - Returns None if no correction needed, or the corrected string. - - Args: - word: The word to check/correct - lang: Expected language ("en" or "de") - prev_word: Previous word (for context) - next_word: Next word (for context) - """ - if not word or not word.strip(): - return None - - # Skip numbers, abbreviations with dots, very short tokens - if word.isdigit() or '.' in word: - return None - - # Skip IPA/phonetic content in brackets - if '[' in word or ']' in word: - return None - - has_suspicious = any(ch in _SUSPICIOUS_CHARS for ch in word) - - # 1. Already known → no fix - if self._known(word): - # But check a/I disambiguation for single-char words - if word.lower() in ('l', '|') and next_word: - return self._disambiguate_a_I(word, next_word) - return None - - # 2. 
Digit/pipe substitution - if has_suspicious: - if word == '|': - return 'I' - # Try single-char substitutions - for i, ch in enumerate(word): - if ch not in _DIGIT_SUBS: - continue - for replacement in _DIGIT_SUBS[ch]: - candidate = word[:i] + replacement + word[i + 1:] - if self._known(candidate): - return candidate - # Try multi-char substitution (e.g., "sch00l" → "school") - multi = self._try_multi_digit_sub(word) - if multi: - return multi - - # 3. Umlaut correction (German) - if lang == "de" and len(word) >= 3 and word.isalpha(): - umlaut_fix = self._try_umlaut_fix(word) - if umlaut_fix: - return umlaut_fix - - # 4. General spell correction - if not has_suspicious and len(word) >= 3 and word.isalpha(): - # Safety: don't correct if the word is valid in the OTHER language - # (either directly or via umlaut fix) - other_lang = "de" if lang == "en" else "en" - if self._known_in(word, other_lang): - return None - if other_lang == "de" and self._try_umlaut_fix(word): - return None # has a valid DE umlaut variant → don't touch - - spell = self.en if lang == "en" else self.de - correction = spell.correction(word.lower()) - if correction and correction != word.lower(): - if word[0].isupper(): - correction = correction[0].upper() + correction[1:] - if self._known(correction): - return correction - - return None - - # --- Multi-digit substitution --- - - def _try_multi_digit_sub(self, word: str) -> Optional[str]: - """Try replacing multiple digits simultaneously.""" - positions = [(i, ch) for i, ch in enumerate(word) if ch in _DIGIT_SUBS] - if len(positions) < 1 or len(positions) > 4: - return None - - # Try all combinations (max 2^4 = 16 for 4 positions) - chars = list(word) - best = None - self._multi_sub_recurse(chars, positions, 0, best_result=[None]) - return self._multi_sub_recurse_result - - _multi_sub_recurse_result: Optional[str] = None - - def _try_multi_digit_sub(self, word: str) -> Optional[str]: - """Try replacing multiple digits simultaneously using BFS.""" 
- positions = [(i, ch) for i, ch in enumerate(word) if ch in _DIGIT_SUBS] - if not positions or len(positions) > 4: - return None - - # BFS over substitution combinations - queue = [list(word)] - for pos, ch in positions: - next_queue = [] - for current in queue: - # Keep original - next_queue.append(current[:]) - # Try each substitution - for repl in _DIGIT_SUBS[ch]: - variant = current[:] - variant[pos] = repl - next_queue.append(variant) - queue = next_queue - - # Check which combinations produce known words - for combo in queue: - candidate = "".join(combo) - if candidate != word and self._known(candidate): - return candidate - - return None - - # --- Umlaut fix --- - - def _try_umlaut_fix(self, word: str) -> Optional[str]: - """Try single-char umlaut substitutions for German words.""" - for i, ch in enumerate(word): - if ch in _UMLAUT_MAP: - candidate = word[:i] + _UMLAUT_MAP[ch] + word[i + 1:] - if self._known(candidate): - return candidate - return None - - # --- Boundary repair (shifted word boundaries) --- - - def _try_boundary_repair(self, word1: str, word2: str) -> Optional[Tuple[str, str]]: - """Fix shifted word boundaries between adjacent tokens. - - OCR sometimes shifts the boundary: "at sth." → "ats th." - Try moving 1-2 chars from end of word1 to start of word2 and vice versa. - Returns (fixed_word1, fixed_word2) or None. 
- """ - # Import known abbreviations for vocabulary context - try: - from cv_ocr_engines import _KNOWN_ABBREVIATIONS - except ImportError: - _KNOWN_ABBREVIATIONS = set() - - # Strip trailing punctuation for checking, preserve for result - w2_stripped = word2.rstrip(".,;:!?") - w2_punct = word2[len(w2_stripped):] - - # Try shifting 1-2 chars from word1 → word2 - for shift in (1, 2): - if len(word1) <= shift: - continue - new_w1 = word1[:-shift] - new_w2_base = word1[-shift:] + w2_stripped - - w1_ok = self._known(new_w1) or new_w1.lower() in _KNOWN_ABBREVIATIONS - w2_ok = self._known(new_w2_base) or new_w2_base.lower() in _KNOWN_ABBREVIATIONS - - if w1_ok and w2_ok: - return (new_w1, new_w2_base + w2_punct) - - # Try shifting 1-2 chars from word2 → word1 - for shift in (1, 2): - if len(w2_stripped) <= shift: - continue - new_w1 = word1 + w2_stripped[:shift] - new_w2_base = w2_stripped[shift:] - - w1_ok = self._known(new_w1) or new_w1.lower() in _KNOWN_ABBREVIATIONS - w2_ok = self._known(new_w2_base) or new_w2_base.lower() in _KNOWN_ABBREVIATIONS - - if w1_ok and w2_ok: - return (new_w1, new_w2_base + w2_punct) - - return None - - # --- Context-based word split for ambiguous merges --- - - # Patterns where a valid word is actually "a" + adjective/noun - _ARTICLE_SPLIT_CANDIDATES = { - # word → (article, remainder) — only when followed by a compatible word - "anew": ("a", "new"), - "areal": ("a", "real"), - "alive": None, # genuinely one word, never split - "alone": None, - "aware": None, - "alike": None, - "apart": None, - "aside": None, - "above": None, - "about": None, - "among": None, - "along": None, - } - - def _try_context_split(self, word: str, next_word: str, - prev_word: str) -> Optional[str]: - """Split words like 'anew' → 'a new' when context indicates a merge. 
- - Only splits when: - - The word is in the split candidates list - - The following word makes sense as a noun (for "a + adj + noun" pattern) - - OR the word is unknown and can be split into article + known word - """ - w_lower = word.lower() - - # Check explicit candidates - if w_lower in self._ARTICLE_SPLIT_CANDIDATES: - split = self._ARTICLE_SPLIT_CANDIDATES[w_lower] - if split is None: - return None # explicitly marked as "don't split" - article, remainder = split - # Only split if followed by a word (noun pattern) - if next_word and next_word[0].islower(): - return f"{article} {remainder}" - # Also split if remainder + next_word makes a common phrase - if next_word and self._known(next_word): - return f"{article} {remainder}" - - # Generic: if word starts with 'a' and rest is a known adjective/word - if (len(word) >= 4 and word[0].lower() == 'a' - and not self._known(word) # only for UNKNOWN words - and self._known(word[1:])): - return f"a {word[1:]}" - - return None - - # --- a/I disambiguation --- - - def _disambiguate_a_I(self, token: str, next_word: str) -> Optional[str]: - """Disambiguate 'a' vs 'I' (and OCR variants like 'l', '|').""" - nw = next_word.lower().strip(".,;:!?") - if nw in _I_FOLLOWERS: - return "I" - if nw in _A_FOLLOWERS: - return "a" - # Fallback: check if next word is more commonly a verb (→I) or noun/adj (→a) - # Simple heuristic: if next word starts with uppercase (and isn't first in sentence) - # it's likely a German noun following "I"... but in English context, uppercase - # after "I" is unusual. - return None # uncertain, don't change - - # --- Full text correction --- - - def correct_text(self, text: str, lang: str = "en") -> CorrectionResult: - """Correct a full text string (field value). - - Three passes: - 1. Boundary repair — fix shifted word boundaries between adjacent tokens - 2. Context split — split ambiguous merges (anew → a new) - 3. 
Per-word correction — spell check individual words - - Args: - text: The text to correct - lang: Expected language ("en" or "de") - """ - if not text or not text.strip(): - return CorrectionResult(text, text, "unknown", False) - - detected = self.detect_text_lang(text) if lang == "auto" else lang - effective_lang = detected if detected in ("en", "de") else "en" - - changes: List[str] = [] - tokens = list(_TOKEN_RE.finditer(text)) - - # Extract token list: [(word, separator), ...] - token_list: List[List[str]] = [] # [[word, sep], ...] - for m in tokens: - token_list.append([m.group(1), m.group(2)]) - - # --- Pass 1: Boundary repair between adjacent unknown words --- - # Import abbreviations for the heuristic below - try: - from cv_ocr_engines import _KNOWN_ABBREVIATIONS as _ABBREVS - except ImportError: - _ABBREVS = set() - - for i in range(len(token_list) - 1): - w1 = token_list[i][0] - w2_raw = token_list[i + 1][0] - - # Skip boundary repair for IPA/bracket content - # Brackets may be in the token OR in the adjacent separators - sep_before_w1 = token_list[i - 1][1] if i > 0 else "" - sep_after_w1 = token_list[i][1] - sep_after_w2 = token_list[i + 1][1] - has_bracket = ( - '[' in w1 or ']' in w1 or '[' in w2_raw or ']' in w2_raw - or ']' in sep_after_w1 # w1 text was inside [brackets] - or '[' in sep_after_w1 # w2 starts a bracket - or ']' in sep_after_w2 # w2 text was inside [brackets] - or '[' in sep_before_w1 # w1 starts a bracket - ) - if has_bracket: - continue - - # Include trailing punct from separator in w2 for abbreviation matching - w2_with_punct = w2_raw + token_list[i + 1][1].rstrip(" ") - - # Try boundary repair — always, even if both words are valid. - # Use word-frequency scoring to decide if repair is better. 
- repair = self._try_boundary_repair(w1, w2_with_punct) - if not repair and w2_with_punct != w2_raw: - repair = self._try_boundary_repair(w1, w2_raw) - if repair: - new_w1, new_w2_full = repair - new_w2_base = new_w2_full.rstrip(".,;:!?") - - # Frequency-based scoring: product of word frequencies - # Higher product = more common word pair = better - old_freq = self._word_freq(w1) * self._word_freq(w2_raw) - new_freq = self._word_freq(new_w1) * self._word_freq(new_w2_base) - - # Abbreviation bonus: if repair produces a known abbreviation - has_abbrev = new_w1.lower() in _ABBREVS or new_w2_base.lower() in _ABBREVS - if has_abbrev: - # Accept abbreviation repair ONLY if at least one of the - # original words is rare/unknown (prevents "Can I" → "Ca nI" - # where both original words are common and correct). - # "Rare" = frequency < 1e-6 (covers "ats", "th" but not "Can", "I") - RARE_THRESHOLD = 1e-6 - orig_both_common = ( - self._word_freq(w1) > RARE_THRESHOLD - and self._word_freq(w2_raw) > RARE_THRESHOLD - ) - if not orig_both_common: - new_freq = max(new_freq, old_freq * 10) - else: - has_abbrev = False # both originals common → don't trust - - # Accept if repair produces a more frequent word pair - # (threshold: at least 5x more frequent to avoid false positives) - if new_freq > old_freq * 5: - new_w2_punct = new_w2_full[len(new_w2_base):] - changes.append(f"{w1} {w2_raw}→{new_w1} {new_w2_base}") - token_list[i][0] = new_w1 - token_list[i + 1][0] = new_w2_base - if new_w2_punct: - token_list[i + 1][1] = new_w2_punct + token_list[i + 1][1].lstrip(".,;:!?") - - # --- Pass 2: Context split (anew → a new) --- - expanded: List[List[str]] = [] - for i, (word, sep) in enumerate(token_list): - next_word = token_list[i + 1][0] if i + 1 < len(token_list) else "" - prev_word = token_list[i - 1][0] if i > 0 else "" - split = self._try_context_split(word, next_word, prev_word) - if split and split != word: - changes.append(f"{word}→{split}") - expanded.append([split, sep]) - 
else: - expanded.append([word, sep]) - token_list = expanded - - # --- Pass 3: Per-word correction --- - parts: List[str] = [] - - # Preserve any leading text before the first token match - # (e.g., "(= " before "I won and he lost.") - first_start = tokens[0].start() if tokens else 0 - if first_start > 0: - parts.append(text[:first_start]) - - for i, (word, sep) in enumerate(token_list): - # Skip words inside IPA brackets (brackets land in separators) - prev_sep = token_list[i - 1][1] if i > 0 else "" - if '[' in prev_sep or ']' in sep: - parts.append(word) - parts.append(sep) - continue - - next_word = token_list[i + 1][0] if i + 1 < len(token_list) else "" - prev_word = token_list[i - 1][0] if i > 0 else "" - - correction = self.correct_word( - word, lang=effective_lang, - prev_word=prev_word, next_word=next_word, - ) - if correction and correction != word: - changes.append(f"{word}→{correction}") - parts.append(correction) - else: - parts.append(word) - parts.append(sep) - - # Append any trailing text - last_end = tokens[-1].end() if tokens else 0 - if last_end < len(text): - parts.append(text[last_end:]) - - corrected = "".join(parts) - return CorrectionResult( - original=text, - corrected=corrected, - lang_detected=detected, - changed=corrected != text, - changes=changes, - ) - - # --- Vocabulary entry correction --- - - def correct_vocab_entry(self, english: str, german: str, - example: str = "") -> Dict[str, CorrectionResult]: - """Correct a full vocabulary entry (EN + DE + example). - - Uses column position to determine language — the most reliable signal. 
- """ - results = {} - results["english"] = self.correct_text(english, lang="en") - results["german"] = self.correct_text(german, lang="de") - if example: - # For examples, auto-detect language - results["example"] = self.correct_text(example, lang="auto") - return results +# Core: data types, lang detection (re-exported for tests) +from smart_spell_core import ( # noqa: F401 + _AVAILABLE, + _DIGIT_SUBS, + _SUSPICIOUS_CHARS, + _UMLAUT_MAP, + _TOKEN_RE, + _I_FOLLOWERS, + _A_FOLLOWERS, + CorrectionResult, + Lang, +) + +# Text: SmartSpellChecker class (the main public API) +from smart_spell_text import SmartSpellChecker # noqa: F401 diff --git a/klausur-service/backend/smart_spell_core.py b/klausur-service/backend/smart_spell_core.py new file mode 100644 index 0000000..9f2fa7d --- /dev/null +++ b/klausur-service/backend/smart_spell_core.py @@ -0,0 +1,298 @@ +""" +SmartSpellChecker Core — init, data types, language detection, word correction. + +Extracted from smart_spell.py for modularity. 
+ +Lizenz: Apache 2.0 (kommerziell nutzbar) +""" + +import logging +import re +from dataclasses import dataclass, field +from typing import Dict, List, Literal, Optional, Set, Tuple + +logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# Init +# --------------------------------------------------------------------------- + +try: + from spellchecker import SpellChecker as _SpellChecker + _en_spell = _SpellChecker(language='en', distance=1) + _de_spell = _SpellChecker(language='de', distance=1) + _AVAILABLE = True +except ImportError: + _AVAILABLE = False + logger.warning("pyspellchecker not installed — SmartSpellChecker disabled") + +Lang = Literal["en", "de", "both", "unknown"] + +# --------------------------------------------------------------------------- +# Bigram context for a/I disambiguation +# --------------------------------------------------------------------------- + +# Words that commonly follow "I" (subject pronoun -> verb/modal) +_I_FOLLOWERS: frozenset = frozenset({ + "am", "was", "have", "had", "do", "did", "will", "would", "can", + "could", "should", "shall", "may", "might", "must", + "think", "know", "see", "want", "need", "like", "love", "hate", + "go", "went", "come", "came", "say", "said", "get", "got", + "make", "made", "take", "took", "give", "gave", "tell", "told", + "feel", "felt", "find", "found", "believe", "hope", "wish", + "remember", "forget", "understand", "mean", "meant", + "don't", "didn't", "can't", "won't", "couldn't", "wouldn't", + "shouldn't", "haven't", "hadn't", "isn't", "wasn't", + "really", "just", "also", "always", "never", "often", "sometimes", +}) + +# Words that commonly follow "a" (article -> noun/adjective) +_A_FOLLOWERS: frozenset = frozenset({ + "lot", "few", "little", "bit", "good", "bad", "great", "new", "old", + "long", "short", "big", "small", "large", "huge", "tiny", + "nice", "beautiful", "wonderful", "terrible", "horrible", + "man", "woman", 
"boy", "girl", "child", "dog", "cat", "bird", + "book", "car", "house", "room", "school", "teacher", "student", + "day", "week", "month", "year", "time", "place", "way", + "friend", "family", "person", "problem", "question", "story", + "very", "really", "quite", "rather", "pretty", "single", +}) + +# Digit->letter substitutions (OCR confusion) +_DIGIT_SUBS: Dict[str, List[str]] = { + '0': ['o', 'O'], + '1': ['l', 'I'], + '5': ['s', 'S'], + '6': ['g', 'G'], + '8': ['b', 'B'], + '|': ['I', 'l'], + '/': ['l'], # italic 'l' misread as slash (e.g. "p/" -> "pl") +} +_SUSPICIOUS_CHARS = frozenset(_DIGIT_SUBS.keys()) + +# Umlaut confusion: OCR drops dots (u->u, a->a, o->o) +_UMLAUT_MAP = { + 'a': '\u00e4', 'o': '\u00f6', 'u': '\u00fc', 'i': '\u00fc', + 'A': '\u00c4', 'O': '\u00d6', 'U': '\u00dc', 'I': '\u00dc', +} + +# Tokenizer -- includes | and / so OCR artifacts like "p/" are treated as words +_TOKEN_RE = re.compile(r"([A-Za-z\u00c4\u00d6\u00dc\u00e4\u00f6\u00fc\u00df'|/]+)([^A-Za-z\u00c4\u00d6\u00dc\u00e4\u00f6\u00fc\u00df'|/]*)") + + +# --------------------------------------------------------------------------- +# Data types +# --------------------------------------------------------------------------- + +@dataclass +class CorrectionResult: + original: str + corrected: str + lang_detected: Lang + changed: bool + changes: List[str] = field(default_factory=list) + + +# --------------------------------------------------------------------------- +# Core class — language detection and word-level correction +# --------------------------------------------------------------------------- + +class _SmartSpellCoreBase: + """Base class with language detection and single-word correction. + + Not intended for direct use — SmartSpellChecker inherits from this. 
+ """ + + def __init__(self): + if not _AVAILABLE: + raise RuntimeError("pyspellchecker not installed") + self.en = _en_spell + self.de = _de_spell + + # --- Language detection --- + + def detect_word_lang(self, word: str) -> Lang: + """Detect language of a single word using dual-dict heuristic.""" + w = word.lower().strip(".,;:!?\"'()") + if not w: + return "unknown" + in_en = bool(self.en.known([w])) + in_de = bool(self.de.known([w])) + if in_en and in_de: + return "both" + if in_en: + return "en" + if in_de: + return "de" + return "unknown" + + def detect_text_lang(self, text: str) -> Lang: + """Detect dominant language of a text string (sentence/phrase).""" + words = re.findall(r"[A-Za-z\u00c4\u00d6\u00dc\u00e4\u00f6\u00fc\u00df]+", text) + if not words: + return "unknown" + + en_count = 0 + de_count = 0 + for w in words: + lang = self.detect_word_lang(w) + if lang == "en": + en_count += 1 + elif lang == "de": + de_count += 1 + # "both" doesn't count for either + + if en_count > de_count: + return "en" + if de_count > en_count: + return "de" + if en_count == de_count and en_count > 0: + return "both" + return "unknown" + + # --- Single-word correction --- + + def _known(self, word: str) -> bool: + """True if word is known in EN or DE dictionary, or is a known abbreviation.""" + w = word.lower() + if bool(self.en.known([w])) or bool(self.de.known([w])): + return True + # Also accept known abbreviations (sth, sb, adj, etc.) 
+ try: + from cv_ocr_engines import _KNOWN_ABBREVIATIONS + if w in _KNOWN_ABBREVIATIONS: + return True + except ImportError: + pass + return False + + def _word_freq(self, word: str) -> float: + """Get word frequency (max of EN and DE).""" + w = word.lower() + return max(self.en.word_usage_frequency(w), self.de.word_usage_frequency(w)) + + def _known_in(self, word: str, lang: str) -> bool: + """True if word is known in a specific language dictionary.""" + w = word.lower() + spell = self.en if lang == "en" else self.de + return bool(spell.known([w])) + + def correct_word(self, word: str, lang: str = "en", + prev_word: str = "", next_word: str = "") -> Optional[str]: + """Correct a single word for the given language. + + Returns None if no correction needed, or the corrected string. + """ + if not word or not word.strip(): + return None + + # Skip numbers, abbreviations with dots, very short tokens + if word.isdigit() or '.' in word: + return None + + # Skip IPA/phonetic content in brackets + if '[' in word or ']' in word: + return None + + has_suspicious = any(ch in _SUSPICIOUS_CHARS for ch in word) + + # 1. Already known -> no fix + if self._known(word): + # But check a/I disambiguation for single-char words + if word.lower() in ('l', '|') and next_word: + return self._disambiguate_a_I(word, next_word) + return None + + # 2. Digit/pipe substitution + if has_suspicious: + if word == '|': + return 'I' + # Try single-char substitutions + for i, ch in enumerate(word): + if ch not in _DIGIT_SUBS: + continue + for replacement in _DIGIT_SUBS[ch]: + candidate = word[:i] + replacement + word[i + 1:] + if self._known(candidate): + return candidate + # Try multi-char substitution (e.g., "sch00l" -> "school") + multi = self._try_multi_digit_sub(word) + if multi: + return multi + + # 3. Umlaut correction (German) + if lang == "de" and len(word) >= 3 and word.isalpha(): + umlaut_fix = self._try_umlaut_fix(word) + if umlaut_fix: + return umlaut_fix + + # 4. 
General spell correction + if not has_suspicious and len(word) >= 3 and word.isalpha(): + # Safety: don't correct if the word is valid in the OTHER language + other_lang = "de" if lang == "en" else "en" + if self._known_in(word, other_lang): + return None + if other_lang == "de" and self._try_umlaut_fix(word): + return None # has a valid DE umlaut variant -> don't touch + + spell = self.en if lang == "en" else self.de + correction = spell.correction(word.lower()) + if correction and correction != word.lower(): + if word[0].isupper(): + correction = correction[0].upper() + correction[1:] + if self._known(correction): + return correction + + return None + + # --- Multi-digit substitution --- + + def _try_multi_digit_sub(self, word: str) -> Optional[str]: + """Try replacing multiple digits simultaneously using BFS.""" + positions = [(i, ch) for i, ch in enumerate(word) if ch in _DIGIT_SUBS] + if not positions or len(positions) > 4: + return None + + # BFS over substitution combinations + queue = [list(word)] + for pos, ch in positions: + next_queue = [] + for current in queue: + # Keep original + next_queue.append(current[:]) + # Try each substitution + for repl in _DIGIT_SUBS[ch]: + variant = current[:] + variant[pos] = repl + next_queue.append(variant) + queue = next_queue + + # Check which combinations produce known words + for combo in queue: + candidate = "".join(combo) + if candidate != word and self._known(candidate): + return candidate + + return None + + # --- Umlaut fix --- + + def _try_umlaut_fix(self, word: str) -> Optional[str]: + """Try single-char umlaut substitutions for German words.""" + for i, ch in enumerate(word): + if ch in _UMLAUT_MAP: + candidate = word[:i] + _UMLAUT_MAP[ch] + word[i + 1:] + if self._known(candidate): + return candidate + return None + + # --- a/I disambiguation --- + + def _disambiguate_a_I(self, token: str, next_word: str) -> Optional[str]: + """Disambiguate 'a' vs 'I' (and OCR variants like 'l', '|').""" + nw = 
next_word.lower().strip(".,;:!?") + if nw in _I_FOLLOWERS: + return "I" + if nw in _A_FOLLOWERS: + return "a" + return None # uncertain, don't change diff --git a/klausur-service/backend/smart_spell_text.py b/klausur-service/backend/smart_spell_text.py new file mode 100644 index 0000000..7628e61 --- /dev/null +++ b/klausur-service/backend/smart_spell_text.py @@ -0,0 +1,289 @@ +""" +SmartSpellChecker Text — full text correction, boundary repair, context split. + +Extracted from smart_spell.py for modularity. + +Lizenz: Apache 2.0 (kommerziell nutzbar) +""" + +import re +from typing import Dict, List, Optional, Tuple + +from smart_spell_core import ( + _SmartSpellCoreBase, + _TOKEN_RE, + CorrectionResult, + Lang, +) + + +class SmartSpellChecker(_SmartSpellCoreBase): + """Language-aware OCR spell checker using pyspellchecker (no LLM). + + Inherits single-word correction from _SmartSpellCoreBase. + Adds text-level passes: boundary repair, context split, full correction. + """ + + # --- Boundary repair (shifted word boundaries) --- + + def _try_boundary_repair(self, word1: str, word2: str) -> Optional[Tuple[str, str]]: + """Fix shifted word boundaries between adjacent tokens. + + OCR sometimes shifts the boundary: "at sth." -> "ats th." + Try moving 1-2 chars from end of word1 to start of word2 and vice versa. + Returns (fixed_word1, fixed_word2) or None. 
+ """ + # Import known abbreviations for vocabulary context + try: + from cv_ocr_engines import _KNOWN_ABBREVIATIONS + except ImportError: + _KNOWN_ABBREVIATIONS = set() + + # Strip trailing punctuation for checking, preserve for result + w2_stripped = word2.rstrip(".,;:!?") + w2_punct = word2[len(w2_stripped):] + + # Try shifting 1-2 chars from word1 -> word2 + for shift in (1, 2): + if len(word1) <= shift: + continue + new_w1 = word1[:-shift] + new_w2_base = word1[-shift:] + w2_stripped + + w1_ok = self._known(new_w1) or new_w1.lower() in _KNOWN_ABBREVIATIONS + w2_ok = self._known(new_w2_base) or new_w2_base.lower() in _KNOWN_ABBREVIATIONS + + if w1_ok and w2_ok: + return (new_w1, new_w2_base + w2_punct) + + # Try shifting 1-2 chars from word2 -> word1 + for shift in (1, 2): + if len(w2_stripped) <= shift: + continue + new_w1 = word1 + w2_stripped[:shift] + new_w2_base = w2_stripped[shift:] + + w1_ok = self._known(new_w1) or new_w1.lower() in _KNOWN_ABBREVIATIONS + w2_ok = self._known(new_w2_base) or new_w2_base.lower() in _KNOWN_ABBREVIATIONS + + if w1_ok and w2_ok: + return (new_w1, new_w2_base + w2_punct) + + return None + + # --- Context-based word split for ambiguous merges --- + + # Patterns where a valid word is actually "a" + adjective/noun + _ARTICLE_SPLIT_CANDIDATES = { + # word -> (article, remainder) -- only when followed by a compatible word + "anew": ("a", "new"), + "areal": ("a", "real"), + "alive": None, # genuinely one word, never split + "alone": None, + "aware": None, + "alike": None, + "apart": None, + "aside": None, + "above": None, + "about": None, + "among": None, + "along": None, + } + + def _try_context_split(self, word: str, next_word: str, + prev_word: str) -> Optional[str]: + """Split words like 'anew' -> 'a new' when context indicates a merge. 
+ + Only splits when: + - The word is in the split candidates list + - The following word makes sense as a noun (for "a + adj + noun" pattern) + - OR the word is unknown and can be split into article + known word + """ + w_lower = word.lower() + + # Check explicit candidates + if w_lower in self._ARTICLE_SPLIT_CANDIDATES: + split = self._ARTICLE_SPLIT_CANDIDATES[w_lower] + if split is None: + return None # explicitly marked as "don't split" + article, remainder = split + # Only split if followed by a word (noun pattern) + if next_word and next_word[0].islower(): + return f"{article} {remainder}" + # Also split if remainder + next_word makes a common phrase + if next_word and self._known(next_word): + return f"{article} {remainder}" + + # Generic: if word starts with 'a' and rest is a known adjective/word + if (len(word) >= 4 and word[0].lower() == 'a' + and not self._known(word) # only for UNKNOWN words + and self._known(word[1:])): + return f"a {word[1:]}" + + return None + + # --- Full text correction --- + + def correct_text(self, text: str, lang: str = "en") -> CorrectionResult: + """Correct a full text string (field value). + + Three passes: + 1. Boundary repair -- fix shifted word boundaries between adjacent tokens + 2. Context split -- split ambiguous merges (anew -> a new) + 3. Per-word correction -- spell check individual words + """ + if not text or not text.strip(): + return CorrectionResult(text, text, "unknown", False) + + detected = self.detect_text_lang(text) if lang == "auto" else lang + effective_lang = detected if detected in ("en", "de") else "en" + + changes: List[str] = [] + tokens = list(_TOKEN_RE.finditer(text)) + + # Extract token list: [(word, separator), ...] + token_list: List[List[str]] = [] # [[word, sep], ...] 
+ for m in tokens: + token_list.append([m.group(1), m.group(2)]) + + # --- Pass 1: Boundary repair between adjacent unknown words --- + # Import abbreviations for the heuristic below + try: + from cv_ocr_engines import _KNOWN_ABBREVIATIONS as _ABBREVS + except ImportError: + _ABBREVS = set() + + for i in range(len(token_list) - 1): + w1 = token_list[i][0] + w2_raw = token_list[i + 1][0] + + # Skip boundary repair for IPA/bracket content + # Brackets may be in the token OR in the adjacent separators + sep_before_w1 = token_list[i - 1][1] if i > 0 else "" + sep_after_w1 = token_list[i][1] + sep_after_w2 = token_list[i + 1][1] + has_bracket = ( + '[' in w1 or ']' in w1 or '[' in w2_raw or ']' in w2_raw + or ']' in sep_after_w1 # w1 text was inside [brackets] + or '[' in sep_after_w1 # w2 starts a bracket + or ']' in sep_after_w2 # w2 text was inside [brackets] + or '[' in sep_before_w1 # w1 starts a bracket + ) + if has_bracket: + continue + + # Include trailing punct from separator in w2 for abbreviation matching + w2_with_punct = w2_raw + token_list[i + 1][1].rstrip(" ") + + # Try boundary repair -- always, even if both words are valid. + # Use word-frequency scoring to decide if repair is better. 
+ repair = self._try_boundary_repair(w1, w2_with_punct) + if not repair and w2_with_punct != w2_raw: + repair = self._try_boundary_repair(w1, w2_raw) + if repair: + new_w1, new_w2_full = repair + new_w2_base = new_w2_full.rstrip(".,;:!?") + + # Frequency-based scoring: product of word frequencies + # Higher product = more common word pair = better + old_freq = self._word_freq(w1) * self._word_freq(w2_raw) + new_freq = self._word_freq(new_w1) * self._word_freq(new_w2_base) + + # Abbreviation bonus: if repair produces a known abbreviation + has_abbrev = new_w1.lower() in _ABBREVS or new_w2_base.lower() in _ABBREVS + if has_abbrev: + # Accept abbreviation repair ONLY if at least one of the + # original words is rare/unknown (prevents "Can I" -> "Ca nI" + # where both original words are common and correct). + RARE_THRESHOLD = 1e-6 + orig_both_common = ( + self._word_freq(w1) > RARE_THRESHOLD + and self._word_freq(w2_raw) > RARE_THRESHOLD + ) + if not orig_both_common: + new_freq = max(new_freq, old_freq * 10) + else: + has_abbrev = False # both originals common -> don't trust + + # Accept if repair produces a more frequent word pair + # (threshold: at least 5x more frequent to avoid false positives) + if new_freq > old_freq * 5: + new_w2_punct = new_w2_full[len(new_w2_base):] + changes.append(f"{w1} {w2_raw}\u2192{new_w1} {new_w2_base}") + token_list[i][0] = new_w1 + token_list[i + 1][0] = new_w2_base + if new_w2_punct: + token_list[i + 1][1] = new_w2_punct + token_list[i + 1][1].lstrip(".,;:!?") + + # --- Pass 2: Context split (anew -> a new) --- + expanded: List[List[str]] = [] + for i, (word, sep) in enumerate(token_list): + next_word = token_list[i + 1][0] if i + 1 < len(token_list) else "" + prev_word = token_list[i - 1][0] if i > 0 else "" + split = self._try_context_split(word, next_word, prev_word) + if split and split != word: + changes.append(f"{word}\u2192{split}") + expanded.append([split, sep]) + else: + expanded.append([word, sep]) + token_list = expanded 
+ + # --- Pass 3: Per-word correction --- + parts: List[str] = [] + + # Preserve any leading text before the first token match + first_start = tokens[0].start() if tokens else 0 + if first_start > 0: + parts.append(text[:first_start]) + + for i, (word, sep) in enumerate(token_list): + # Skip words inside IPA brackets (brackets land in separators) + prev_sep = token_list[i - 1][1] if i > 0 else "" + if '[' in prev_sep or ']' in sep: + parts.append(word) + parts.append(sep) + continue + + next_word = token_list[i + 1][0] if i + 1 < len(token_list) else "" + prev_word = token_list[i - 1][0] if i > 0 else "" + + correction = self.correct_word( + word, lang=effective_lang, + prev_word=prev_word, next_word=next_word, + ) + if correction and correction != word: + changes.append(f"{word}\u2192{correction}") + parts.append(correction) + else: + parts.append(word) + parts.append(sep) + + # Append any trailing text + last_end = tokens[-1].end() if tokens else 0 + if last_end < len(text): + parts.append(text[last_end:]) + + corrected = "".join(parts) + return CorrectionResult( + original=text, + corrected=corrected, + lang_detected=detected, + changed=corrected != text, + changes=changes, + ) + + # --- Vocabulary entry correction --- + + def correct_vocab_entry(self, english: str, german: str, + example: str = "") -> Dict[str, CorrectionResult]: + """Correct a full vocabulary entry (EN + DE + example). + + Uses column position to determine language -- the most reliable signal. 
+ """ + results = {} + results["english"] = self.correct_text(english, lang="en") + results["german"] = self.correct_text(german, lang="de") + if example: + # For examples, auto-detect language + results["example"] = self.correct_text(example, lang="auto") + return results diff --git a/klausur-service/backend/upload_api.py b/klausur-service/backend/upload_api.py index 6846192..98e6f1a 100644 --- a/klausur-service/backend/upload_api.py +++ b/klausur-service/backend/upload_api.py @@ -1,602 +1,29 @@ """ -Mobile Upload API for Klausur-Service +Mobile Upload API — barrel re-export. + +All implementation split into: + upload_api_chunked — chunked upload endpoints (init, chunk, finalize, simple, status, cancel, list) + upload_api_mobile — mobile HTML upload page -Provides chunked upload endpoints for large PDF files (100MB+) from mobile devices. DSGVO-konform: Data stays local in WLAN, no external transmission. """ -import os -import uuid -import shutil -import hashlib -from pathlib import Path -from datetime import datetime, timezone -from typing import Dict, Optional - -from fastapi import APIRouter, HTTPException, UploadFile, File, Form -from fastapi.responses import HTMLResponse -from pydantic import BaseModel - -# Configuration -UPLOAD_DIR = Path(os.getenv("UPLOAD_DIR", "/app/uploads")) -CHUNK_DIR = Path(os.getenv("CHUNK_DIR", "/app/chunks")) -EH_UPLOAD_DIR = Path(os.getenv("EH_UPLOAD_DIR", "/app/eh-uploads")) - -# Ensure directories exist -UPLOAD_DIR.mkdir(parents=True, exist_ok=True) -CHUNK_DIR.mkdir(parents=True, exist_ok=True) -EH_UPLOAD_DIR.mkdir(parents=True, exist_ok=True) - -# In-memory storage for upload sessions (for simplicity) -# In production, use Redis or database -_upload_sessions: Dict[str, dict] = {} - -router = APIRouter(prefix="/api/v1/upload", tags=["Mobile Upload"]) - - -class InitUploadRequest(BaseModel): - filename: str - filesize: int - chunks: int - destination: str = "klausur" # "klausur" or "rag" - - -class InitUploadResponse(BaseModel): - 
upload_id: str - chunk_size: int - total_chunks: int - message: str - - -class ChunkUploadResponse(BaseModel): - upload_id: str - chunk_index: int - received: bool - chunks_received: int - total_chunks: int - - -class FinalizeResponse(BaseModel): - upload_id: str - filename: str - filepath: str - filesize: int - checksum: str - message: str - - -@router.post("/init", response_model=InitUploadResponse) -async def init_upload(request: InitUploadRequest): - """ - Initialize a chunked upload session. - - Returns an upload_id that must be used for subsequent chunk uploads. - """ - upload_id = str(uuid.uuid4()) - - # Create session directory - session_dir = CHUNK_DIR / upload_id - session_dir.mkdir(parents=True, exist_ok=True) - - # Store session info - _upload_sessions[upload_id] = { - "filename": request.filename, - "filesize": request.filesize, - "total_chunks": request.chunks, - "received_chunks": set(), - "destination": request.destination, - "session_dir": str(session_dir), - "created_at": datetime.now(timezone.utc).isoformat(), - } - - return InitUploadResponse( - upload_id=upload_id, - chunk_size=5 * 1024 * 1024, # 5 MB - total_chunks=request.chunks, - message="Upload-Session erstellt" - ) - - -@router.post("/chunk", response_model=ChunkUploadResponse) -async def upload_chunk( - chunk: UploadFile = File(...), - upload_id: str = Form(...), - chunk_index: int = Form(...) -): - """ - Upload a single chunk of a file. - - Chunks are stored temporarily until finalize is called. 
- """ - if upload_id not in _upload_sessions: - raise HTTPException(status_code=404, detail="Upload-Session nicht gefunden") - - session = _upload_sessions[upload_id] - - if chunk_index < 0 or chunk_index >= session["total_chunks"]: - raise HTTPException( - status_code=400, - detail=f"Ungueltiger Chunk-Index: {chunk_index}" - ) - - # Save chunk - chunk_path = Path(session["session_dir"]) / f"chunk_{chunk_index:05d}" - - with open(chunk_path, "wb") as f: - content = await chunk.read() - f.write(content) - - # Track received chunks - session["received_chunks"].add(chunk_index) - - return ChunkUploadResponse( - upload_id=upload_id, - chunk_index=chunk_index, - received=True, - chunks_received=len(session["received_chunks"]), - total_chunks=session["total_chunks"] - ) - - -@router.post("/finalize", response_model=FinalizeResponse) -async def finalize_upload(upload_id: str = Form(...)): - """ - Finalize the upload by combining all chunks into a single file. - - Validates that all chunks were received and calculates checksum. - """ - if upload_id not in _upload_sessions: - raise HTTPException(status_code=404, detail="Upload-Session nicht gefunden") - - session = _upload_sessions[upload_id] - - # Check if all chunks received - if len(session["received_chunks"]) != session["total_chunks"]: - missing = session["total_chunks"] - len(session["received_chunks"]) - raise HTTPException( - status_code=400, - detail=f"Nicht alle Chunks empfangen. 
Fehlend: {missing}" - ) - - # Determine destination directory - if session["destination"] == "rag": - dest_dir = EH_UPLOAD_DIR - else: - dest_dir = UPLOAD_DIR - - # Generate unique filename - timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") - safe_filename = session["filename"].replace(" ", "_") - final_filename = f"{timestamp}_{safe_filename}" - final_path = dest_dir / final_filename - - # Combine chunks - hasher = hashlib.sha256() - total_size = 0 - - with open(final_path, "wb") as outfile: - for i in range(session["total_chunks"]): - chunk_path = Path(session["session_dir"]) / f"chunk_{i:05d}" - - if not chunk_path.exists(): - raise HTTPException( - status_code=500, - detail=f"Chunk {i} nicht gefunden" - ) - - with open(chunk_path, "rb") as infile: - data = infile.read() - outfile.write(data) - hasher.update(data) - total_size += len(data) - - # Clean up chunks - shutil.rmtree(session["session_dir"], ignore_errors=True) - del _upload_sessions[upload_id] - - checksum = hasher.hexdigest() - - return FinalizeResponse( - upload_id=upload_id, - filename=final_filename, - filepath=str(final_path), - filesize=total_size, - checksum=checksum, - message="Upload erfolgreich abgeschlossen" - ) - - -@router.post("/simple") -async def simple_upload( - file: UploadFile = File(...), - destination: str = Form("klausur") -): - """ - Simple single-request upload for smaller files (<10MB). - - For larger files, use the chunked upload endpoints. 
- """ - # Determine destination directory - if destination == "rag": - dest_dir = EH_UPLOAD_DIR - else: - dest_dir = UPLOAD_DIR - - # Generate unique filename - timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") - safe_filename = file.filename.replace(" ", "_") if file.filename else "upload.pdf" - final_filename = f"{timestamp}_{safe_filename}" - final_path = dest_dir / final_filename - - # Calculate checksum while writing - hasher = hashlib.sha256() - total_size = 0 - - with open(final_path, "wb") as f: - while True: - chunk = await file.read(1024 * 1024) # Read 1MB at a time - if not chunk: - break - f.write(chunk) - hasher.update(chunk) - total_size += len(chunk) - - return { - "filename": final_filename, - "filepath": str(final_path), - "filesize": total_size, - "checksum": hasher.hexdigest(), - "message": "Upload erfolgreich" - } - - -@router.get("/status/{upload_id}") -async def get_upload_status(upload_id: str): - """ - Get the status of an ongoing upload. - """ - if upload_id not in _upload_sessions: - raise HTTPException(status_code=404, detail="Upload-Session nicht gefunden") - - session = _upload_sessions[upload_id] - - return { - "upload_id": upload_id, - "filename": session["filename"], - "total_chunks": session["total_chunks"], - "received_chunks": len(session["received_chunks"]), - "progress_percent": round( - len(session["received_chunks"]) / session["total_chunks"] * 100, 1 - ), - "destination": session["destination"], - "created_at": session["created_at"] - } - - -@router.delete("/cancel/{upload_id}") -async def cancel_upload(upload_id: str): - """ - Cancel an ongoing upload and clean up temporary files. 
- """ - if upload_id not in _upload_sessions: - raise HTTPException(status_code=404, detail="Upload-Session nicht gefunden") - - session = _upload_sessions[upload_id] - - # Clean up chunks - shutil.rmtree(session["session_dir"], ignore_errors=True) - del _upload_sessions[upload_id] - - return {"message": "Upload abgebrochen", "upload_id": upload_id} - - -@router.get("/list") -async def list_uploads(destination: str = "klausur"): - """ - List all uploaded files in the specified destination. - """ - if destination == "rag": - dest_dir = EH_UPLOAD_DIR - else: - dest_dir = UPLOAD_DIR - - files = [] - - for f in dest_dir.iterdir(): - if f.is_file() and f.suffix.lower() == ".pdf": - stat = f.stat() - files.append({ - "filename": f.name, - "size": stat.st_size, - "modified": datetime.fromtimestamp(stat.st_mtime).isoformat(), - }) - - files.sort(key=lambda x: x["modified"], reverse=True) - - return { - "destination": destination, - "count": len(files), - "files": files[:50] # Limit to 50 most recent - } - - -@router.get("/mobile", response_class=HTMLResponse) -async def mobile_upload_page(): - """ - Serve the mobile upload page directly from the klausur-service. - This allows mobile devices to upload without needing the Next.js website. - """ - from fastapi.responses import HTMLResponse - - html_content = ''' - - - - - - BreakPilot Upload - - - -
-

BreakPilot Upload

- DSGVO-konform -
- -
- - -
- -
- -
-
PDF-Dateien hochladen
-
Tippen zum Auswaehlen oder hierher ziehen
-
Grosse Dateien bis 200 MB werden automatisch in Teilen hochgeladen
-
- - - -
- -
-

Hinweise:

-
    -
  • Die Dateien werden lokal im WLAN uebertragen
  • -
  • Keine Daten werden ins Internet gesendet
  • -
  • Unterstuetzte Formate: PDF
  • -
-
- -
Server: wird ermittelt...
- - - -''' - return HTMLResponse(content=html_content) +from fastapi import APIRouter + +from upload_api_chunked import ( # noqa: F401 + router as _chunked_router, + UPLOAD_DIR, + CHUNK_DIR, + EH_UPLOAD_DIR, + _upload_sessions, + InitUploadRequest, + InitUploadResponse, + ChunkUploadResponse, + FinalizeResponse, +) +from upload_api_mobile import router as _mobile_router # noqa: F401 + +# Composite router that includes both sub-routers +router = APIRouter() +router.include_router(_chunked_router) +router.include_router(_mobile_router) diff --git a/klausur-service/backend/upload_api_chunked.py b/klausur-service/backend/upload_api_chunked.py new file mode 100644 index 0000000..13ddfff --- /dev/null +++ b/klausur-service/backend/upload_api_chunked.py @@ -0,0 +1,320 @@ +""" +Chunked Upload API — init, chunk, finalize, simple upload, status, cancel, list. + +Extracted from upload_api.py for modularity. + +DSGVO-konform: Data stays local in WLAN, no external transmission. +""" + +import os +import uuid +import shutil +import hashlib +from pathlib import Path +from datetime import datetime, timezone +from typing import Dict, Optional + +from fastapi import APIRouter, HTTPException, UploadFile, File, Form +from pydantic import BaseModel + +# Configuration +UPLOAD_DIR = Path(os.getenv("UPLOAD_DIR", "/app/uploads")) +CHUNK_DIR = Path(os.getenv("CHUNK_DIR", "/app/chunks")) +EH_UPLOAD_DIR = Path(os.getenv("EH_UPLOAD_DIR", "/app/eh-uploads")) + +# Ensure directories exist +UPLOAD_DIR.mkdir(parents=True, exist_ok=True) +CHUNK_DIR.mkdir(parents=True, exist_ok=True) +EH_UPLOAD_DIR.mkdir(parents=True, exist_ok=True) + +# In-memory storage for upload sessions (for simplicity) +# In production, use Redis or database +_upload_sessions: Dict[str, dict] = {} + +router = APIRouter(prefix="/api/v1/upload", tags=["Mobile Upload"]) + + +class InitUploadRequest(BaseModel): + filename: str + filesize: int + chunks: int + destination: str = "klausur" # "klausur" or "rag" + + +class 
InitUploadResponse(BaseModel): + upload_id: str + chunk_size: int + total_chunks: int + message: str + + +class ChunkUploadResponse(BaseModel): + upload_id: str + chunk_index: int + received: bool + chunks_received: int + total_chunks: int + + +class FinalizeResponse(BaseModel): + upload_id: str + filename: str + filepath: str + filesize: int + checksum: str + message: str + + +@router.post("/init", response_model=InitUploadResponse) +async def init_upload(request: InitUploadRequest): + """ + Initialize a chunked upload session. + + Returns an upload_id that must be used for subsequent chunk uploads. + """ + upload_id = str(uuid.uuid4()) + + # Create session directory + session_dir = CHUNK_DIR / upload_id + session_dir.mkdir(parents=True, exist_ok=True) + + # Store session info + _upload_sessions[upload_id] = { + "filename": request.filename, + "filesize": request.filesize, + "total_chunks": request.chunks, + "received_chunks": set(), + "destination": request.destination, + "session_dir": str(session_dir), + "created_at": datetime.now(timezone.utc).isoformat(), + } + + return InitUploadResponse( + upload_id=upload_id, + chunk_size=5 * 1024 * 1024, # 5 MB + total_chunks=request.chunks, + message="Upload-Session erstellt" + ) + + +@router.post("/chunk", response_model=ChunkUploadResponse) +async def upload_chunk( + chunk: UploadFile = File(...), + upload_id: str = Form(...), + chunk_index: int = Form(...) +): + """ + Upload a single chunk of a file. + + Chunks are stored temporarily until finalize is called. 
+ """ + if upload_id not in _upload_sessions: + raise HTTPException(status_code=404, detail="Upload-Session nicht gefunden") + + session = _upload_sessions[upload_id] + + if chunk_index < 0 or chunk_index >= session["total_chunks"]: + raise HTTPException( + status_code=400, + detail=f"Ungueltiger Chunk-Index: {chunk_index}" + ) + + # Save chunk + chunk_path = Path(session["session_dir"]) / f"chunk_{chunk_index:05d}" + + with open(chunk_path, "wb") as f: + content = await chunk.read() + f.write(content) + + # Track received chunks + session["received_chunks"].add(chunk_index) + + return ChunkUploadResponse( + upload_id=upload_id, + chunk_index=chunk_index, + received=True, + chunks_received=len(session["received_chunks"]), + total_chunks=session["total_chunks"] + ) + + +@router.post("/finalize", response_model=FinalizeResponse) +async def finalize_upload(upload_id: str = Form(...)): + """ + Finalize the upload by combining all chunks into a single file. + + Validates that all chunks were received and calculates checksum. + """ + if upload_id not in _upload_sessions: + raise HTTPException(status_code=404, detail="Upload-Session nicht gefunden") + + session = _upload_sessions[upload_id] + + # Check if all chunks received + if len(session["received_chunks"]) != session["total_chunks"]: + missing = session["total_chunks"] - len(session["received_chunks"]) + raise HTTPException( + status_code=400, + detail=f"Nicht alle Chunks empfangen. 
Fehlend: {missing}" + ) + + # Determine destination directory + if session["destination"] == "rag": + dest_dir = EH_UPLOAD_DIR + else: + dest_dir = UPLOAD_DIR + + # Generate unique filename + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + safe_filename = session["filename"].replace(" ", "_") + final_filename = f"{timestamp}_{safe_filename}" + final_path = dest_dir / final_filename + + # Combine chunks + hasher = hashlib.sha256() + total_size = 0 + + with open(final_path, "wb") as outfile: + for i in range(session["total_chunks"]): + chunk_path = Path(session["session_dir"]) / f"chunk_{i:05d}" + + if not chunk_path.exists(): + raise HTTPException( + status_code=500, + detail=f"Chunk {i} nicht gefunden" + ) + + with open(chunk_path, "rb") as infile: + data = infile.read() + outfile.write(data) + hasher.update(data) + total_size += len(data) + + # Clean up chunks + shutil.rmtree(session["session_dir"], ignore_errors=True) + del _upload_sessions[upload_id] + + checksum = hasher.hexdigest() + + return FinalizeResponse( + upload_id=upload_id, + filename=final_filename, + filepath=str(final_path), + filesize=total_size, + checksum=checksum, + message="Upload erfolgreich abgeschlossen" + ) + + +@router.post("/simple") +async def simple_upload( + file: UploadFile = File(...), + destination: str = Form("klausur") +): + """ + Simple single-request upload for smaller files (<10MB). + + For larger files, use the chunked upload endpoints. 
+ """ + # Determine destination directory + if destination == "rag": + dest_dir = EH_UPLOAD_DIR + else: + dest_dir = UPLOAD_DIR + + # Generate unique filename + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + safe_filename = file.filename.replace(" ", "_") if file.filename else "upload.pdf" + final_filename = f"{timestamp}_{safe_filename}" + final_path = dest_dir / final_filename + + # Calculate checksum while writing + hasher = hashlib.sha256() + total_size = 0 + + with open(final_path, "wb") as f: + while True: + chunk = await file.read(1024 * 1024) # Read 1MB at a time + if not chunk: + break + f.write(chunk) + hasher.update(chunk) + total_size += len(chunk) + + return { + "filename": final_filename, + "filepath": str(final_path), + "filesize": total_size, + "checksum": hasher.hexdigest(), + "message": "Upload erfolgreich" + } + + +@router.get("/status/{upload_id}") +async def get_upload_status(upload_id: str): + """ + Get the status of an ongoing upload. + """ + if upload_id not in _upload_sessions: + raise HTTPException(status_code=404, detail="Upload-Session nicht gefunden") + + session = _upload_sessions[upload_id] + + return { + "upload_id": upload_id, + "filename": session["filename"], + "total_chunks": session["total_chunks"], + "received_chunks": len(session["received_chunks"]), + "progress_percent": round( + len(session["received_chunks"]) / session["total_chunks"] * 100, 1 + ), + "destination": session["destination"], + "created_at": session["created_at"] + } + + +@router.delete("/cancel/{upload_id}") +async def cancel_upload(upload_id: str): + """ + Cancel an ongoing upload and clean up temporary files. 
+ """ + if upload_id not in _upload_sessions: + raise HTTPException(status_code=404, detail="Upload-Session nicht gefunden") + + session = _upload_sessions[upload_id] + + # Clean up chunks + shutil.rmtree(session["session_dir"], ignore_errors=True) + del _upload_sessions[upload_id] + + return {"message": "Upload abgebrochen", "upload_id": upload_id} + + +@router.get("/list") +async def list_uploads(destination: str = "klausur"): + """ + List all uploaded files in the specified destination. + """ + if destination == "rag": + dest_dir = EH_UPLOAD_DIR + else: + dest_dir = UPLOAD_DIR + + files = [] + + for f in dest_dir.iterdir(): + if f.is_file() and f.suffix.lower() == ".pdf": + stat = f.stat() + files.append({ + "filename": f.name, + "size": stat.st_size, + "modified": datetime.fromtimestamp(stat.st_mtime).isoformat(), + }) + + files.sort(key=lambda x: x["modified"], reverse=True) + + return { + "destination": destination, + "count": len(files), + "files": files[:50] # Limit to 50 most recent + } diff --git a/klausur-service/backend/upload_api_mobile.py b/klausur-service/backend/upload_api_mobile.py new file mode 100644 index 0000000..8ddd423 --- /dev/null +++ b/klausur-service/backend/upload_api_mobile.py @@ -0,0 +1,292 @@ +""" +Mobile Upload HTML Page — serves the mobile upload UI directly from klausur-service. + +Extracted from upload_api.py for modularity. + +DSGVO-konform: Data stays local in WLAN, no external transmission. +""" + +from fastapi import APIRouter +from fastapi.responses import HTMLResponse + +router = APIRouter(prefix="/api/v1/upload", tags=["Mobile Upload"]) + + +@router.get("/mobile", response_class=HTMLResponse) +async def mobile_upload_page(): + """ + Serve the mobile upload page directly from the klausur-service. + This allows mobile devices to upload without needing the Next.js website. + """ + html_content = ''' + + + + + + BreakPilot Upload + + + +
+

BreakPilot Upload

+ DSGVO-konform +
+ +
+ + +
+ +
+ +
+
PDF-Dateien hochladen
+
Tippen zum Auswaehlen oder hierher ziehen
+
Grosse Dateien bis 200 MB werden automatisch in Teilen hochgeladen
+
+ + + +
+ +
+

Hinweise:

+
    +
  • Die Dateien werden lokal im WLAN uebertragen
  • +
  • Keine Daten werden ins Internet gesendet
  • +
  • Unterstuetzte Formate: PDF
  • +
+
+ +
Server: wird ermittelt...
+ + + +''' + return HTMLResponse(content=html_content) diff --git a/klausur-service/backend/zeugnis_api.py b/klausur-service/backend/zeugnis_api.py index 4d1618d..53e2ca2 100644 --- a/klausur-service/backend/zeugnis_api.py +++ b/klausur-service/backend/zeugnis_api.py @@ -1,537 +1,19 @@ """ -Zeugnis Rights-Aware Crawler - API Endpoints +Zeugnis Rights-Aware Crawler — barrel re-export. + +All implementation split into: + zeugnis_api_sources — sources, seed URLs, initialization + zeugnis_api_docs — documents, crawler, statistics, audit FastAPI router for managing zeugnis sources, documents, and crawler operations. """ -from datetime import datetime, timedelta -from typing import Optional, List -from fastapi import APIRouter, HTTPException, BackgroundTasks, Query -from pydantic import BaseModel +from fastapi import APIRouter -from zeugnis_models import ( - ZeugnisSource, ZeugnisSourceCreate, ZeugnisSourceVerify, - SeedUrl, SeedUrlCreate, - ZeugnisDocument, ZeugnisStats, - CrawlerStatus, CrawlRequest, CrawlQueueItem, - UsageEvent, AuditExport, - LicenseType, CrawlStatus, DocType, EventType, - BUNDESLAENDER, TRAINING_PERMISSIONS, - generate_id, get_training_allowed, get_bundesland_name, get_license_for_bundesland, -) -from zeugnis_crawler import ( - start_crawler, stop_crawler, get_crawler_status, -) -from metrics_db import ( - get_zeugnis_sources, upsert_zeugnis_source, - get_zeugnis_documents, get_zeugnis_stats, - log_zeugnis_event, get_pool, -) +from zeugnis_api_sources import router as _sources_router # noqa: F401 +from zeugnis_api_docs import router as _docs_router # noqa: F401 - -router = APIRouter(prefix="/api/v1/admin/zeugnis", tags=["Zeugnis Crawler"]) - - -# ============================================================================= -# Sources Endpoints -# ============================================================================= - -@router.get("/sources", response_model=List[dict]) -async def list_sources(): - """Get all zeugnis sources (Bundesländer).""" 
- sources = await get_zeugnis_sources() - if not sources: - # Return default sources if none exist - return [ - { - "id": None, - "bundesland": code, - "name": info["name"], - "base_url": None, - "license_type": str(get_license_for_bundesland(code).value), - "training_allowed": get_training_allowed(code), - "verified_by": None, - "verified_at": None, - "created_at": None, - "updated_at": None, - } - for code, info in BUNDESLAENDER.items() - ] - return sources - - -@router.post("/sources", response_model=dict) -async def create_source(source: ZeugnisSourceCreate): - """Create or update a zeugnis source.""" - source_id = generate_id() - success = await upsert_zeugnis_source( - id=source_id, - bundesland=source.bundesland, - name=source.name, - license_type=source.license_type.value, - training_allowed=source.training_allowed, - base_url=source.base_url, - ) - if not success: - raise HTTPException(status_code=500, detail="Failed to create source") - return {"id": source_id, "success": True} - - -@router.put("/sources/{source_id}/verify", response_model=dict) -async def verify_source(source_id: str, verification: ZeugnisSourceVerify): - """Verify a source's license status.""" - pool = await get_pool() - if not pool: - raise HTTPException(status_code=503, detail="Database not available") - - try: - async with pool.acquire() as conn: - await conn.execute( - """ - UPDATE zeugnis_sources - SET license_type = $2, - training_allowed = $3, - verified_by = $4, - verified_at = NOW(), - updated_at = NOW() - WHERE id = $1 - """, - source_id, verification.license_type.value, - verification.training_allowed, verification.verified_by - ) - return {"success": True, "source_id": source_id} - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - - -@router.get("/sources/{bundesland}", response_model=dict) -async def get_source_by_bundesland(bundesland: str): - """Get source details for a specific Bundesland.""" - pool = await get_pool() - if not pool: - # 
Return default info - if bundesland not in BUNDESLAENDER: - raise HTTPException(status_code=404, detail=f"Bundesland not found: {bundesland}") - return { - "bundesland": bundesland, - "name": get_bundesland_name(bundesland), - "training_allowed": get_training_allowed(bundesland), - "license_type": get_license_for_bundesland(bundesland).value, - "document_count": 0, - } - - try: - async with pool.acquire() as conn: - source = await conn.fetchrow( - "SELECT * FROM zeugnis_sources WHERE bundesland = $1", - bundesland - ) - if source: - doc_count = await conn.fetchval( - """ - SELECT COUNT(*) FROM zeugnis_documents d - JOIN zeugnis_seed_urls u ON d.seed_url_id = u.id - WHERE u.source_id = $1 - """, - source["id"] - ) - return {**dict(source), "document_count": doc_count or 0} - - # Return default - return { - "bundesland": bundesland, - "name": get_bundesland_name(bundesland), - "training_allowed": get_training_allowed(bundesland), - "license_type": get_license_for_bundesland(bundesland).value, - "document_count": 0, - } - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - - -# ============================================================================= -# Seed URLs Endpoints -# ============================================================================= - -@router.get("/sources/{source_id}/urls", response_model=List[dict]) -async def list_seed_urls(source_id: str): - """Get all seed URLs for a source.""" - pool = await get_pool() - if not pool: - return [] - - try: - async with pool.acquire() as conn: - rows = await conn.fetch( - "SELECT * FROM zeugnis_seed_urls WHERE source_id = $1 ORDER BY created_at", - source_id - ) - return [dict(r) for r in rows] - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - - -@router.post("/sources/{source_id}/urls", response_model=dict) -async def add_seed_url(source_id: str, seed_url: SeedUrlCreate): - """Add a new seed URL to a source.""" - pool = await get_pool() - if not 
pool: - raise HTTPException(status_code=503, detail="Database not available") - - url_id = generate_id() - try: - async with pool.acquire() as conn: - await conn.execute( - """ - INSERT INTO zeugnis_seed_urls (id, source_id, url, doc_type, status) - VALUES ($1, $2, $3, $4, 'pending') - """, - url_id, source_id, seed_url.url, seed_url.doc_type.value - ) - return {"id": url_id, "success": True} - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - - -@router.delete("/urls/{url_id}", response_model=dict) -async def delete_seed_url(url_id: str): - """Delete a seed URL.""" - pool = await get_pool() - if not pool: - raise HTTPException(status_code=503, detail="Database not available") - - try: - async with pool.acquire() as conn: - await conn.execute( - "DELETE FROM zeugnis_seed_urls WHERE id = $1", - url_id - ) - return {"success": True} - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - - -# ============================================================================= -# Documents Endpoints -# ============================================================================= - -@router.get("/documents", response_model=List[dict]) -async def list_documents( - bundesland: Optional[str] = None, - limit: int = Query(100, le=500), - offset: int = 0, -): - """Get all zeugnis documents with optional filtering.""" - documents = await get_zeugnis_documents(bundesland=bundesland, limit=limit, offset=offset) - return documents - - -@router.get("/documents/{document_id}", response_model=dict) -async def get_document(document_id: str): - """Get details for a specific document.""" - pool = await get_pool() - if not pool: - raise HTTPException(status_code=503, detail="Database not available") - - try: - async with pool.acquire() as conn: - doc = await conn.fetchrow( - """ - SELECT d.*, s.bundesland, s.name as source_name - FROM zeugnis_documents d - JOIN zeugnis_seed_urls u ON d.seed_url_id = u.id - JOIN zeugnis_sources s ON 
u.source_id = s.id - WHERE d.id = $1 - """, - document_id - ) - if not doc: - raise HTTPException(status_code=404, detail="Document not found") - - # Log view event - await log_zeugnis_event(document_id, EventType.VIEWED.value) - - return dict(doc) - except HTTPException: - raise - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - - -@router.get("/documents/{document_id}/versions", response_model=List[dict]) -async def get_document_versions(document_id: str): - """Get version history for a document.""" - pool = await get_pool() - if not pool: - raise HTTPException(status_code=503, detail="Database not available") - - try: - async with pool.acquire() as conn: - rows = await conn.fetch( - """ - SELECT * FROM zeugnis_document_versions - WHERE document_id = $1 - ORDER BY version DESC - """, - document_id - ) - return [dict(r) for r in rows] - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - - -# ============================================================================= -# Crawler Control Endpoints -# ============================================================================= - -@router.get("/crawler/status", response_model=dict) -async def crawler_status(): - """Get current crawler status.""" - return get_crawler_status() - - -@router.post("/crawler/start", response_model=dict) -async def start_crawl(request: CrawlRequest, background_tasks: BackgroundTasks): - """Start the crawler.""" - success = await start_crawler( - bundesland=request.bundesland, - source_id=request.source_id, - ) - if not success: - raise HTTPException(status_code=409, detail="Crawler already running") - return {"success": True, "message": "Crawler started"} - - -@router.post("/crawler/stop", response_model=dict) -async def stop_crawl(): - """Stop the crawler.""" - success = await stop_crawler() - if not success: - raise HTTPException(status_code=409, detail="Crawler not running") - return {"success": True, "message": "Crawler 
stopped"} - - -@router.get("/crawler/queue", response_model=List[dict]) -async def get_queue(): - """Get the crawler queue.""" - pool = await get_pool() - if not pool: - return [] - - try: - async with pool.acquire() as conn: - rows = await conn.fetch( - """ - SELECT q.*, s.bundesland, s.name as source_name - FROM zeugnis_crawler_queue q - JOIN zeugnis_sources s ON q.source_id = s.id - ORDER BY q.priority DESC, q.created_at - """ - ) - return [dict(r) for r in rows] - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - - -@router.post("/crawler/queue", response_model=dict) -async def add_to_queue(request: CrawlRequest): - """Add a source to the crawler queue.""" - pool = await get_pool() - if not pool: - raise HTTPException(status_code=503, detail="Database not available") - - queue_id = generate_id() - try: - async with pool.acquire() as conn: - # Get source ID if bundesland provided - source_id = request.source_id - if not source_id and request.bundesland: - source = await conn.fetchrow( - "SELECT id FROM zeugnis_sources WHERE bundesland = $1", - request.bundesland - ) - if source: - source_id = source["id"] - - if not source_id: - raise HTTPException(status_code=400, detail="Source not found") - - await conn.execute( - """ - INSERT INTO zeugnis_crawler_queue (id, source_id, priority, status) - VALUES ($1, $2, $3, 'pending') - """, - queue_id, source_id, request.priority - ) - return {"id": queue_id, "success": True} - except HTTPException: - raise - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - - -# ============================================================================= -# Statistics Endpoints -# ============================================================================= - -@router.get("/stats", response_model=dict) -async def get_stats(): - """Get zeugnis crawler statistics.""" - stats = await get_zeugnis_stats() - return stats - - -@router.get("/stats/bundesland", response_model=List[dict]) 
-async def get_bundesland_stats(): - """Get statistics per Bundesland.""" - pool = await get_pool() - - # Build stats from BUNDESLAENDER with DB data if available - stats = [] - for code, info in BUNDESLAENDER.items(): - stat = { - "bundesland": code, - "name": info["name"], - "training_allowed": get_training_allowed(code), - "document_count": 0, - "indexed_count": 0, - "last_crawled": None, - } - - if pool: - try: - async with pool.acquire() as conn: - row = await conn.fetchrow( - """ - SELECT - COUNT(d.id) as doc_count, - COUNT(CASE WHEN d.indexed_in_qdrant THEN 1 END) as indexed_count, - MAX(u.last_crawled) as last_crawled - FROM zeugnis_sources s - LEFT JOIN zeugnis_seed_urls u ON s.id = u.source_id - LEFT JOIN zeugnis_documents d ON u.id = d.seed_url_id - WHERE s.bundesland = $1 - GROUP BY s.id - """, - code - ) - if row: - stat["document_count"] = row["doc_count"] or 0 - stat["indexed_count"] = row["indexed_count"] or 0 - stat["last_crawled"] = row["last_crawled"].isoformat() if row["last_crawled"] else None - except Exception: - pass - - stats.append(stat) - - return stats - - -# ============================================================================= -# Audit Endpoints -# ============================================================================= - -@router.get("/audit/events", response_model=List[dict]) -async def get_audit_events( - document_id: Optional[str] = None, - event_type: Optional[str] = None, - limit: int = Query(100, le=1000), - days: int = Query(30, le=365), -): - """Get audit events with optional filtering.""" - pool = await get_pool() - if not pool: - return [] - - try: - since = datetime.now() - timedelta(days=days) - async with pool.acquire() as conn: - query = """ - SELECT * FROM zeugnis_usage_events - WHERE created_at >= $1 - """ - params = [since] - - if document_id: - query += " AND document_id = $2" - params.append(document_id) - if event_type: - query += f" AND event_type = ${len(params) + 1}" - params.append(event_type) - - 
query += f" ORDER BY created_at DESC LIMIT ${len(params) + 1}" - params.append(limit) - - rows = await conn.fetch(query, *params) - return [dict(r) for r in rows] - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - - -@router.get("/audit/export", response_model=dict) -async def export_audit( - days: int = Query(30, le=365), - requested_by: str = Query(..., description="User requesting the export"), -): - """Export audit data for GDPR compliance.""" - pool = await get_pool() - if not pool: - raise HTTPException(status_code=503, detail="Database not available") - - try: - since = datetime.now() - timedelta(days=days) - async with pool.acquire() as conn: - rows = await conn.fetch( - """ - SELECT * FROM zeugnis_usage_events - WHERE created_at >= $1 - ORDER BY created_at DESC - """, - since - ) - - doc_count = await conn.fetchval( - "SELECT COUNT(DISTINCT document_id) FROM zeugnis_usage_events WHERE created_at >= $1", - since - ) - - return { - "export_date": datetime.now().isoformat(), - "requested_by": requested_by, - "events": [dict(r) for r in rows], - "document_count": doc_count or 0, - "date_range_start": since.isoformat(), - "date_range_end": datetime.now().isoformat(), - } - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - - -# ============================================================================= -# Initialization Endpoint -# ============================================================================= - -@router.post("/init", response_model=dict) -async def initialize_sources(): - """Initialize default sources from BUNDESLAENDER.""" - pool = await get_pool() - if not pool: - raise HTTPException(status_code=503, detail="Database not available") - - created = 0 - try: - for code, info in BUNDESLAENDER.items(): - source_id = generate_id() - success = await upsert_zeugnis_source( - id=source_id, - bundesland=code, - name=info["name"], - license_type=get_license_for_bundesland(code).value, - 
training_allowed=get_training_allowed(code), - ) - if success: - created += 1 - - return {"success": True, "sources_created": created} - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) +# Composite router (used by main.py) +router = APIRouter() +router.include_router(_sources_router) +router.include_router(_docs_router) diff --git a/klausur-service/backend/zeugnis_api_docs.py b/klausur-service/backend/zeugnis_api_docs.py new file mode 100644 index 0000000..0800380 --- /dev/null +++ b/klausur-service/backend/zeugnis_api_docs.py @@ -0,0 +1,321 @@ +""" +Zeugnis API Docs — documents, crawler control, statistics, audit endpoints. + +Extracted from zeugnis_api.py for modularity. +""" + +from datetime import datetime, timedelta +from typing import Optional, List +from fastapi import APIRouter, HTTPException, BackgroundTasks, Query + +from zeugnis_models import ( + CrawlRequest, EventType, + BUNDESLAENDER, + generate_id, get_training_allowed, get_license_for_bundesland, +) +from zeugnis_crawler import ( + start_crawler, stop_crawler, get_crawler_status, +) +from metrics_db import ( + get_zeugnis_documents, get_zeugnis_stats, + log_zeugnis_event, get_pool, +) + + +router = APIRouter(prefix="/api/v1/admin/zeugnis", tags=["Zeugnis Crawler"]) + + +# ============================================================================= +# Documents Endpoints +# ============================================================================= + +@router.get("/documents", response_model=List[dict]) +async def list_documents( + bundesland: Optional[str] = None, + limit: int = Query(100, le=500), + offset: int = 0, +): + """Get all zeugnis documents with optional filtering.""" + documents = await get_zeugnis_documents(bundesland=bundesland, limit=limit, offset=offset) + return documents + + +@router.get("/documents/{document_id}", response_model=dict) +async def get_document(document_id: str): + """Get details for a specific document.""" + pool = await get_pool() + 
if not pool: + raise HTTPException(status_code=503, detail="Database not available") + + try: + async with pool.acquire() as conn: + doc = await conn.fetchrow( + """ + SELECT d.*, s.bundesland, s.name as source_name + FROM zeugnis_documents d + JOIN zeugnis_seed_urls u ON d.seed_url_id = u.id + JOIN zeugnis_sources s ON u.source_id = s.id + WHERE d.id = $1 + """, + document_id + ) + if not doc: + raise HTTPException(status_code=404, detail="Document not found") + + # Log view event + await log_zeugnis_event(document_id, EventType.VIEWED.value) + + return dict(doc) + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + +@router.get("/documents/{document_id}/versions", response_model=List[dict]) +async def get_document_versions(document_id: str): + """Get version history for a document.""" + pool = await get_pool() + if not pool: + raise HTTPException(status_code=503, detail="Database not available") + + try: + async with pool.acquire() as conn: + rows = await conn.fetch( + """ + SELECT * FROM zeugnis_document_versions + WHERE document_id = $1 + ORDER BY version DESC + """, + document_id + ) + return [dict(r) for r in rows] + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + +# ============================================================================= +# Crawler Control Endpoints +# ============================================================================= + +@router.get("/crawler/status", response_model=dict) +async def crawler_status(): + """Get current crawler status.""" + return get_crawler_status() + + +@router.post("/crawler/start", response_model=dict) +async def start_crawl(request: CrawlRequest, background_tasks: BackgroundTasks): + """Start the crawler.""" + success = await start_crawler( + bundesland=request.bundesland, + source_id=request.source_id, + ) + if not success: + raise HTTPException(status_code=409, detail="Crawler already running") + return 
{"success": True, "message": "Crawler started"} + + +@router.post("/crawler/stop", response_model=dict) +async def stop_crawl(): + """Stop the crawler.""" + success = await stop_crawler() + if not success: + raise HTTPException(status_code=409, detail="Crawler not running") + return {"success": True, "message": "Crawler stopped"} + + +@router.get("/crawler/queue", response_model=List[dict]) +async def get_queue(): + """Get the crawler queue.""" + pool = await get_pool() + if not pool: + return [] + + try: + async with pool.acquire() as conn: + rows = await conn.fetch( + """ + SELECT q.*, s.bundesland, s.name as source_name + FROM zeugnis_crawler_queue q + JOIN zeugnis_sources s ON q.source_id = s.id + ORDER BY q.priority DESC, q.created_at + """ + ) + return [dict(r) for r in rows] + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + +@router.post("/crawler/queue", response_model=dict) +async def add_to_queue(request: CrawlRequest): + """Add a source to the crawler queue.""" + pool = await get_pool() + if not pool: + raise HTTPException(status_code=503, detail="Database not available") + + queue_id = generate_id() + try: + async with pool.acquire() as conn: + # Get source ID if bundesland provided + source_id = request.source_id + if not source_id and request.bundesland: + source = await conn.fetchrow( + "SELECT id FROM zeugnis_sources WHERE bundesland = $1", + request.bundesland + ) + if source: + source_id = source["id"] + + if not source_id: + raise HTTPException(status_code=400, detail="Source not found") + + await conn.execute( + """ + INSERT INTO zeugnis_crawler_queue (id, source_id, priority, status) + VALUES ($1, $2, $3, 'pending') + """, + queue_id, source_id, request.priority + ) + return {"id": queue_id, "success": True} + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + +# ============================================================================= +# 
Statistics Endpoints +# ============================================================================= + +@router.get("/stats", response_model=dict) +async def get_stats(): + """Get zeugnis crawler statistics.""" + stats = await get_zeugnis_stats() + return stats + + +@router.get("/stats/bundesland", response_model=List[dict]) +async def get_bundesland_stats(): + """Get statistics per Bundesland.""" + pool = await get_pool() + + # Build stats from BUNDESLAENDER with DB data if available + stats = [] + for code, info in BUNDESLAENDER.items(): + stat = { + "bundesland": code, + "name": info["name"], + "training_allowed": get_training_allowed(code), + "document_count": 0, + "indexed_count": 0, + "last_crawled": None, + } + + if pool: + try: + async with pool.acquire() as conn: + row = await conn.fetchrow( + """ + SELECT + COUNT(d.id) as doc_count, + COUNT(CASE WHEN d.indexed_in_qdrant THEN 1 END) as indexed_count, + MAX(u.last_crawled) as last_crawled + FROM zeugnis_sources s + LEFT JOIN zeugnis_seed_urls u ON s.id = u.source_id + LEFT JOIN zeugnis_documents d ON u.id = d.seed_url_id + WHERE s.bundesland = $1 + GROUP BY s.id + """, + code + ) + if row: + stat["document_count"] = row["doc_count"] or 0 + stat["indexed_count"] = row["indexed_count"] or 0 + stat["last_crawled"] = row["last_crawled"].isoformat() if row["last_crawled"] else None + except Exception: + pass + + stats.append(stat) + + return stats + + +# ============================================================================= +# Audit Endpoints +# ============================================================================= + +@router.get("/audit/events", response_model=List[dict]) +async def get_audit_events( + document_id: Optional[str] = None, + event_type: Optional[str] = None, + limit: int = Query(100, le=1000), + days: int = Query(30, le=365), +): + """Get audit events with optional filtering.""" + pool = await get_pool() + if not pool: + return [] + + try: + since = datetime.now() - 
timedelta(days=days) + async with pool.acquire() as conn: + query = """ + SELECT * FROM zeugnis_usage_events + WHERE created_at >= $1 + """ + params = [since] + + if document_id: + query += " AND document_id = $2" + params.append(document_id) + if event_type: + query += f" AND event_type = ${len(params) + 1}" + params.append(event_type) + + query += f" ORDER BY created_at DESC LIMIT ${len(params) + 1}" + params.append(limit) + + rows = await conn.fetch(query, *params) + return [dict(r) for r in rows] + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + +@router.get("/audit/export", response_model=dict) +async def export_audit( + days: int = Query(30, le=365), + requested_by: str = Query(..., description="User requesting the export"), +): + """Export audit data for GDPR compliance.""" + pool = await get_pool() + if not pool: + raise HTTPException(status_code=503, detail="Database not available") + + try: + since = datetime.now() - timedelta(days=days) + async with pool.acquire() as conn: + rows = await conn.fetch( + """ + SELECT * FROM zeugnis_usage_events + WHERE created_at >= $1 + ORDER BY created_at DESC + """, + since + ) + + doc_count = await conn.fetchval( + "SELECT COUNT(DISTINCT document_id) FROM zeugnis_usage_events WHERE created_at >= $1", + since + ) + + return { + "export_date": datetime.now().isoformat(), + "requested_by": requested_by, + "events": [dict(r) for r in rows], + "document_count": doc_count or 0, + "date_range_start": since.isoformat(), + "date_range_end": datetime.now().isoformat(), + } + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) diff --git a/klausur-service/backend/zeugnis_api_sources.py b/klausur-service/backend/zeugnis_api_sources.py new file mode 100644 index 0000000..3eecf28 --- /dev/null +++ b/klausur-service/backend/zeugnis_api_sources.py @@ -0,0 +1,232 @@ +""" +Zeugnis API Sources — source and seed URL management endpoints. 
+ +Extracted from zeugnis_api.py for modularity. +""" + +from typing import Optional, List +from fastapi import APIRouter, HTTPException +from pydantic import BaseModel + +from zeugnis_models import ( + ZeugnisSourceCreate, ZeugnisSourceVerify, + SeedUrlCreate, + LicenseType, DocType, + BUNDESLAENDER, + generate_id, get_training_allowed, get_bundesland_name, get_license_for_bundesland, +) +from metrics_db import ( + get_zeugnis_sources, upsert_zeugnis_source, get_pool, +) + + +router = APIRouter(prefix="/api/v1/admin/zeugnis", tags=["Zeugnis Crawler"]) + + +# ============================================================================= +# Sources Endpoints +# ============================================================================= + +@router.get("/sources", response_model=List[dict]) +async def list_sources(): + """Get all zeugnis sources (Bundeslaender).""" + sources = await get_zeugnis_sources() + if not sources: + # Return default sources if none exist + return [ + { + "id": None, + "bundesland": code, + "name": info["name"], + "base_url": None, + "license_type": str(get_license_for_bundesland(code).value), + "training_allowed": get_training_allowed(code), + "verified_by": None, + "verified_at": None, + "created_at": None, + "updated_at": None, + } + for code, info in BUNDESLAENDER.items() + ] + return sources + + +@router.post("/sources", response_model=dict) +async def create_source(source: ZeugnisSourceCreate): + """Create or update a zeugnis source.""" + source_id = generate_id() + success = await upsert_zeugnis_source( + id=source_id, + bundesland=source.bundesland, + name=source.name, + license_type=source.license_type.value, + training_allowed=source.training_allowed, + base_url=source.base_url, + ) + if not success: + raise HTTPException(status_code=500, detail="Failed to create source") + return {"id": source_id, "success": True} + + +@router.put("/sources/{source_id}/verify", response_model=dict) +async def verify_source(source_id: str, 
verification: ZeugnisSourceVerify): + """Verify a source's license status.""" + pool = await get_pool() + if not pool: + raise HTTPException(status_code=503, detail="Database not available") + + try: + async with pool.acquire() as conn: + await conn.execute( + """ + UPDATE zeugnis_sources + SET license_type = $2, + training_allowed = $3, + verified_by = $4, + verified_at = NOW(), + updated_at = NOW() + WHERE id = $1 + """, + source_id, verification.license_type.value, + verification.training_allowed, verification.verified_by + ) + return {"success": True, "source_id": source_id} + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + +@router.get("/sources/{bundesland}", response_model=dict) +async def get_source_by_bundesland(bundesland: str): + """Get source details for a specific Bundesland.""" + pool = await get_pool() + if not pool: + # Return default info + if bundesland not in BUNDESLAENDER: + raise HTTPException(status_code=404, detail=f"Bundesland not found: {bundesland}") + return { + "bundesland": bundesland, + "name": get_bundesland_name(bundesland), + "training_allowed": get_training_allowed(bundesland), + "license_type": get_license_for_bundesland(bundesland).value, + "document_count": 0, + } + + try: + async with pool.acquire() as conn: + source = await conn.fetchrow( + "SELECT * FROM zeugnis_sources WHERE bundesland = $1", + bundesland + ) + if source: + doc_count = await conn.fetchval( + """ + SELECT COUNT(*) FROM zeugnis_documents d + JOIN zeugnis_seed_urls u ON d.seed_url_id = u.id + WHERE u.source_id = $1 + """, + source["id"] + ) + return {**dict(source), "document_count": doc_count or 0} + + # Return default + return { + "bundesland": bundesland, + "name": get_bundesland_name(bundesland), + "training_allowed": get_training_allowed(bundesland), + "license_type": get_license_for_bundesland(bundesland).value, + "document_count": 0, + } + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + +# 
============================================================================= +# Seed URLs Endpoints +# ============================================================================= + +@router.get("/sources/{source_id}/urls", response_model=List[dict]) +async def list_seed_urls(source_id: str): + """Get all seed URLs for a source.""" + pool = await get_pool() + if not pool: + return [] + + try: + async with pool.acquire() as conn: + rows = await conn.fetch( + "SELECT * FROM zeugnis_seed_urls WHERE source_id = $1 ORDER BY created_at", + source_id + ) + return [dict(r) for r in rows] + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + +@router.post("/sources/{source_id}/urls", response_model=dict) +async def add_seed_url(source_id: str, seed_url: SeedUrlCreate): + """Add a new seed URL to a source.""" + pool = await get_pool() + if not pool: + raise HTTPException(status_code=503, detail="Database not available") + + url_id = generate_id() + try: + async with pool.acquire() as conn: + await conn.execute( + """ + INSERT INTO zeugnis_seed_urls (id, source_id, url, doc_type, status) + VALUES ($1, $2, $3, $4, 'pending') + """, + url_id, source_id, seed_url.url, seed_url.doc_type.value + ) + return {"id": url_id, "success": True} + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + +@router.delete("/urls/{url_id}", response_model=dict) +async def delete_seed_url(url_id: str): + """Delete a seed URL.""" + pool = await get_pool() + if not pool: + raise HTTPException(status_code=503, detail="Database not available") + + try: + async with pool.acquire() as conn: + await conn.execute( + "DELETE FROM zeugnis_seed_urls WHERE id = $1", + url_id + ) + return {"success": True} + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + +# ============================================================================= +# Initialization Endpoint +# 
============================================================================= + +@router.post("/init", response_model=dict) +async def initialize_sources(): + """Initialize default sources from BUNDESLAENDER.""" + pool = await get_pool() + if not pool: + raise HTTPException(status_code=503, detail="Database not available") + + created = 0 + try: + for code, info in BUNDESLAENDER.items(): + source_id = generate_id() + success = await upsert_zeugnis_source( + id=source_id, + bundesland=code, + name=info["name"], + license_type=get_license_for_bundesland(code).value, + training_allowed=get_training_allowed(code), + ) + if success: + created += 1 + + return {"success": True, "sources_created": created} + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) diff --git a/scripts/doclayout_export_methods.py b/scripts/doclayout_export_methods.py new file mode 100644 index 0000000..2a93224 --- /dev/null +++ b/scripts/doclayout_export_methods.py @@ -0,0 +1,311 @@ +""" +PP-DocLayout ONNX Export Methods + +Download and Docker-based conversion methods for PP-DocLayout model. +Extracted from export-doclayout-onnx.py. +""" + +import hashlib +import json +import logging +import shutil +import subprocess +import tempfile +import urllib.request +from pathlib import Path + +log = logging.getLogger("export-doclayout") + +# Known download sources for pre-exported ONNX models. +DOWNLOAD_SOURCES = [ + { + "name": "PaddleOCR PP-DocLayout (ppyoloe_plus_sod, HuggingFace)", + "url": "https://huggingface.co/SWHL/PP-DocLayout/resolve/main/pp_doclayout_onnx/model.onnx", + "filename": "model.onnx", + "sha256": None, + }, + { + "name": "PaddleOCR PP-DocLayout (RapidOCR mirror)", + "url": "https://huggingface.co/SWHL/PP-DocLayout/resolve/main/pp_doclayout_onnx/model.onnx", + "filename": "model.onnx", + "sha256": None, + }, +] + +# Paddle inference model URLs (for Docker-based conversion). 
+PADDLE_MODEL_URL = ( + "https://paddleocr.bj.bcebos.com/PP-DocLayout/PP-DocLayout_plus.tar" +) + +# Docker image name used for conversion. +DOCKER_IMAGE_TAG = "breakpilot/paddle2onnx-converter:latest" + + +def sha256_file(path: Path) -> str: + """Compute SHA-256 hex digest for a file.""" + h = hashlib.sha256() + with open(path, "rb") as f: + for chunk in iter(lambda: f.read(1 << 20), b""): + h.update(chunk) + return h.hexdigest() + + +def download_file(url: str, dest: Path, desc: str = "") -> bool: + """Download a file with progress reporting. Returns True on success.""" + label = desc or url.split("/")[-1] + log.info("Downloading %s ...", label) + log.info(" URL: %s", url) + + try: + req = urllib.request.Request(url, headers={"User-Agent": "breakpilot-export/1.0"}) + with urllib.request.urlopen(req, timeout=120) as resp: + total = resp.headers.get("Content-Length") + total = int(total) if total else None + downloaded = 0 + + dest.parent.mkdir(parents=True, exist_ok=True) + with open(dest, "wb") as f: + while True: + chunk = resp.read(1 << 18) # 256 KB + if not chunk: + break + f.write(chunk) + downloaded += len(chunk) + if total: + pct = downloaded * 100 / total + mb = downloaded / (1 << 20) + total_mb = total / (1 << 20) + print( + f"\r {mb:.1f}/{total_mb:.1f} MB ({pct:.0f}%)", + end="", + flush=True, + ) + if total: + print() # newline after progress + + size_mb = dest.stat().st_size / (1 << 20) + log.info(" Downloaded %.1f MB -> %s", size_mb, dest) + return True + + except Exception as exc: + log.warning(" Download failed: %s", exc) + if dest.exists(): + dest.unlink() + return False + + +def try_download(output_dir: Path) -> bool: + """Attempt to download a pre-exported ONNX model. 
Returns True on success.""" + log.info("=== Method: DOWNLOAD ===") + + output_dir.mkdir(parents=True, exist_ok=True) + model_path = output_dir / "model.onnx" + + for source in DOWNLOAD_SOURCES: + log.info("Trying source: %s", source["name"]) + tmp_path = output_dir / f".{source['filename']}.tmp" + + if not download_file(source["url"], tmp_path, desc=source["name"]): + continue + + # Check SHA-256 if known. + if source["sha256"]: + actual_hash = sha256_file(tmp_path) + if actual_hash != source["sha256"]: + log.warning( + " SHA-256 mismatch: expected %s, got %s", + source["sha256"], + actual_hash, + ) + tmp_path.unlink() + continue + + # Basic sanity: file should be > 1 MB. + size = tmp_path.stat().st_size + if size < 1 << 20: + log.warning(" File too small (%.1f KB) — probably not a valid model.", size / 1024) + tmp_path.unlink() + continue + + # Move into place. + shutil.move(str(tmp_path), str(model_path)) + log.info("Model saved to %s (%.1f MB)", model_path, model_path.stat().st_size / (1 << 20)) + return True + + log.warning("All download sources failed.") + return False + + +DOCKERFILE_CONTENT = r""" +FROM --platform=linux/amd64 python:3.11-slim + +RUN pip install --no-cache-dir \ + paddlepaddle==3.0.0 \ + paddle2onnx==1.3.1 \ + onnx==1.17.0 \ + requests + +WORKDIR /work + +# Download + extract the PP-DocLayout Paddle inference model. +RUN python3 -c " +import urllib.request, tarfile, os +url = 'PADDLE_MODEL_URL_PLACEHOLDER' +print(f'Downloading {url} ...') +dest = '/work/pp_doclayout.tar' +urllib.request.urlretrieve(url, dest) +print('Extracting ...') +with tarfile.open(dest) as t: + t.extractall('/work/paddle_model') +os.remove(dest) +# List what we extracted +for root, dirs, files in os.walk('/work/paddle_model'): + for f in files: + fp = os.path.join(root, f) + sz = os.path.getsize(fp) + print(f' {fp} ({sz} bytes)') +" + +# Convert Paddle model to ONNX. 
+RUN python3 -c " +import os, glob, subprocess + +# Find the inference model files +model_dir = '/work/paddle_model' +pdmodel_files = glob.glob(os.path.join(model_dir, '**', '*.pdmodel'), recursive=True) +pdiparams_files = glob.glob(os.path.join(model_dir, '**', '*.pdiparams'), recursive=True) + +if not pdmodel_files: + raise FileNotFoundError('No .pdmodel file found in extracted archive') + +pdmodel = pdmodel_files[0] +pdiparams = pdiparams_files[0] if pdiparams_files else None +model_dir_actual = os.path.dirname(pdmodel) +pdmodel_name = os.path.basename(pdmodel).replace('.pdmodel', '') + +print(f'Found model: {pdmodel}') +print(f'Found params: {pdiparams}') +print(f'Model dir: {model_dir_actual}') +print(f'Model name prefix: {pdmodel_name}') + +cmd = [ + 'paddle2onnx', + '--model_dir', model_dir_actual, + '--model_filename', os.path.basename(pdmodel), +] +if pdiparams: + cmd += ['--params_filename', os.path.basename(pdiparams)] +cmd += [ + '--save_file', '/work/output/model.onnx', + '--opset_version', '14', + '--enable_onnx_checker', 'True', +] + +os.makedirs('/work/output', exist_ok=True) +print(f'Running: {\" \".join(cmd)}') +subprocess.run(cmd, check=True) + +out_size = os.path.getsize('/work/output/model.onnx') +print(f'Conversion done: /work/output/model.onnx ({out_size} bytes)') +" + +CMD ["cp", "-v", "/work/output/model.onnx", "/output/model.onnx"] +""".replace( + "PADDLE_MODEL_URL_PLACEHOLDER", PADDLE_MODEL_URL +) + + +def try_docker(output_dir: Path) -> bool: + """Build a Docker image to convert the Paddle model to ONNX. Returns True on success.""" + log.info("=== Method: DOCKER (linux/amd64) ===") + + # Check Docker is available. 
+ docker_bin = shutil.which("docker") or "/usr/local/bin/docker" + try: + subprocess.run( + [docker_bin, "version"], + capture_output=True, + check=True, + timeout=15, + ) + except (subprocess.CalledProcessError, FileNotFoundError, subprocess.TimeoutExpired) as exc: + log.error("Docker is not available: %s", exc) + return False + + output_dir.mkdir(parents=True, exist_ok=True) + + with tempfile.TemporaryDirectory(prefix="doclayout-export-") as tmpdir: + tmpdir = Path(tmpdir) + + # Write Dockerfile. + dockerfile_path = tmpdir / "Dockerfile" + dockerfile_path.write_text(DOCKERFILE_CONTENT) + log.info("Wrote Dockerfile to %s", dockerfile_path) + + # Build image. + log.info("Building Docker image (this downloads ~2 GB, may take a while) ...") + build_cmd = [ + docker_bin, "build", + "--platform", "linux/amd64", + "-t", DOCKER_IMAGE_TAG, + "-f", str(dockerfile_path), + str(tmpdir), + ] + log.info(" %s", " ".join(build_cmd)) + build_result = subprocess.run( + build_cmd, + capture_output=False, + timeout=1200, + ) + if build_result.returncode != 0: + log.error("Docker build failed (exit code %d).", build_result.returncode) + return False + + # Run container. 
+ log.info("Running conversion container ...") + run_cmd = [ + docker_bin, "run", + "--rm", + "--platform", "linux/amd64", + "-v", f"{output_dir.resolve()}:/output", + DOCKER_IMAGE_TAG, + ] + log.info(" %s", " ".join(run_cmd)) + run_result = subprocess.run( + run_cmd, + capture_output=False, + timeout=300, + ) + if run_result.returncode != 0: + log.error("Docker run failed (exit code %d).", run_result.returncode) + return False + + model_path = output_dir / "model.onnx" + if model_path.exists(): + size_mb = model_path.stat().st_size / (1 << 20) + log.info("Model exported: %s (%.1f MB)", model_path, size_mb) + return True + else: + log.error("Expected output file not found: %s", model_path) + return False + + +def write_metadata(output_dir: Path, method: str, class_labels: list, model_input_shape: tuple) -> None: + """Write a metadata JSON next to the model for provenance tracking.""" + model_path = output_dir / "model.onnx" + if not model_path.exists(): + return + + meta = { + "model": "PP-DocLayout", + "format": "ONNX", + "export_method": method, + "class_labels": class_labels, + "input_shape": list(model_input_shape), + "file_size_bytes": model_path.stat().st_size, + "sha256": sha256_file(model_path), + } + meta_path = output_dir / "metadata.json" + with open(meta_path, "w") as f: + json.dump(meta, f, indent=2) + log.info("Metadata written to %s", meta_path) diff --git a/scripts/export-doclayout-onnx.py b/scripts/export-doclayout-onnx.py index 0d76271..08d9479 100755 --- a/scripts/export-doclayout-onnx.py +++ b/scripts/export-doclayout-onnx.py @@ -13,15 +13,8 @@ Usage: """ import argparse -import hashlib -import json import logging -import os -import shutil -import subprocess import sys -import tempfile -import urllib.request from pathlib import Path logging.basicConfig( @@ -49,92 +42,23 @@ CLASS_LABELS = [ "abstract", ] -# Known download sources for pre-exported ONNX models. -# Ordered by preference — first successful download wins. 
-DOWNLOAD_SOURCES = [ - { - "name": "PaddleOCR PP-DocLayout (ppyoloe_plus_sod, HuggingFace)", - "url": "https://huggingface.co/SWHL/PP-DocLayout/resolve/main/pp_doclayout_onnx/model.onnx", - "filename": "model.onnx", - "sha256": None, # populated once a known-good hash is available - }, - { - "name": "PaddleOCR PP-DocLayout (RapidOCR mirror)", - "url": "https://huggingface.co/SWHL/PP-DocLayout/resolve/main/pp_doclayout_onnx/model.onnx", - "filename": "model.onnx", - "sha256": None, - }, -] - -# Paddle inference model URLs (for Docker-based conversion). -PADDLE_MODEL_URL = ( - "https://paddleocr.bj.bcebos.com/PP-DocLayout/PP-DocLayout_plus.tar" -) - # Expected input shape for the model (batch, channels, height, width). MODEL_INPUT_SHAPE = (1, 3, 800, 800) -# Docker image name used for conversion. -DOCKER_IMAGE_TAG = "breakpilot/paddle2onnx-converter:latest" +# Import methods from sibling module +from doclayout_export_methods import ( + try_download, + try_docker, + write_metadata, + sha256_file, +) + # --------------------------------------------------------------------------- -# Helpers +# Verification # --------------------------------------------------------------------------- -def sha256_file(path: Path) -> str: - """Compute SHA-256 hex digest for a file.""" - h = hashlib.sha256() - with open(path, "rb") as f: - for chunk in iter(lambda: f.read(1 << 20), b""): - h.update(chunk) - return h.hexdigest() - - -def download_file(url: str, dest: Path, desc: str = "") -> bool: - """Download a file with progress reporting. 
Returns True on success.""" - label = desc or url.split("/")[-1] - log.info("Downloading %s ...", label) - log.info(" URL: %s", url) - - try: - req = urllib.request.Request(url, headers={"User-Agent": "breakpilot-export/1.0"}) - with urllib.request.urlopen(req, timeout=120) as resp: - total = resp.headers.get("Content-Length") - total = int(total) if total else None - downloaded = 0 - - dest.parent.mkdir(parents=True, exist_ok=True) - with open(dest, "wb") as f: - while True: - chunk = resp.read(1 << 18) # 256 KB - if not chunk: - break - f.write(chunk) - downloaded += len(chunk) - if total: - pct = downloaded * 100 / total - mb = downloaded / (1 << 20) - total_mb = total / (1 << 20) - print( - f"\r {mb:.1f}/{total_mb:.1f} MB ({pct:.0f}%)", - end="", - flush=True, - ) - if total: - print() # newline after progress - - size_mb = dest.stat().st_size / (1 << 20) - log.info(" Downloaded %.1f MB -> %s", size_mb, dest) - return True - - except Exception as exc: - log.warning(" Download failed: %s", exc) - if dest.exists(): - dest.unlink() - return False - - def verify_onnx(model_path: Path) -> bool: """Load the ONNX model with onnxruntime, run a dummy inference, check outputs.""" log.info("Verifying ONNX model: %s", model_path) @@ -169,24 +93,23 @@ def verify_onnx(model_path: Path) -> bool: for out in outputs: log.info(" %s: shape=%s dtype=%s", out.name, out.shape, out.type) - # Build dummy input — use the first input's name and expected shape. + # Build dummy input input_name = inputs[0].name input_shape = inputs[0].shape - # Replace dynamic dims (strings or None) with concrete sizes. + # Replace dynamic dims with concrete sizes. 
concrete_shape = [] for i, dim in enumerate(input_shape): if isinstance(dim, (int,)) and dim > 0: concrete_shape.append(dim) elif i == 0: - concrete_shape.append(1) # batch + concrete_shape.append(1) elif i == 1: - concrete_shape.append(3) # channels + concrete_shape.append(3) else: - concrete_shape.append(800) # spatial + concrete_shape.append(800) concrete_shape = tuple(concrete_shape) - # Fallback if shape looks wrong — use standard MODEL_INPUT_SHAPE. if len(concrete_shape) != 4: concrete_shape = MODEL_INPUT_SHAPE @@ -199,20 +122,15 @@ def verify_onnx(model_path: Path) -> bool: arr = np.asarray(r) log.info(" output[%d]: shape=%s dtype=%s", i, arr.shape, arr.dtype) - # Basic sanity checks if len(result) == 0: log.error(" Model produced no outputs!") return False - # Check for at least one output with a bounding-box-like shape (N, 4) or - # a detection-like structure. Be lenient — different ONNX exports vary. has_plausible_output = False for r in result: arr = np.asarray(r) - # Common detection output shapes: (1, N, 6), (N, 4), (N, 6), (1, N, 5+C), etc. if arr.ndim >= 2 and any(d >= 4 for d in arr.shape): has_plausible_output = True - # Some models output (N,) labels or scores if arr.ndim >= 1 and arr.size > 0: has_plausible_output = True @@ -229,238 +147,6 @@ def verify_onnx(model_path: Path) -> bool: return False -# --------------------------------------------------------------------------- -# Method: Download -# --------------------------------------------------------------------------- - - -def try_download(output_dir: Path) -> bool: - """Attempt to download a pre-exported ONNX model. 
Returns True on success.""" - log.info("=== Method: DOWNLOAD ===") - - output_dir.mkdir(parents=True, exist_ok=True) - model_path = output_dir / "model.onnx" - - for source in DOWNLOAD_SOURCES: - log.info("Trying source: %s", source["name"]) - tmp_path = output_dir / f".{source['filename']}.tmp" - - if not download_file(source["url"], tmp_path, desc=source["name"]): - continue - - # Check SHA-256 if known. - if source["sha256"]: - actual_hash = sha256_file(tmp_path) - if actual_hash != source["sha256"]: - log.warning( - " SHA-256 mismatch: expected %s, got %s", - source["sha256"], - actual_hash, - ) - tmp_path.unlink() - continue - - # Basic sanity: file should be > 1 MB (a real ONNX model, not an error page). - size = tmp_path.stat().st_size - if size < 1 << 20: - log.warning(" File too small (%.1f KB) — probably not a valid model.", size / 1024) - tmp_path.unlink() - continue - - # Move into place. - shutil.move(str(tmp_path), str(model_path)) - log.info("Model saved to %s (%.1f MB)", model_path, model_path.stat().st_size / (1 << 20)) - return True - - log.warning("All download sources failed.") - return False - - -# --------------------------------------------------------------------------- -# Method: Docker -# --------------------------------------------------------------------------- - -DOCKERFILE_CONTENT = r""" -FROM --platform=linux/amd64 python:3.11-slim - -RUN pip install --no-cache-dir \ - paddlepaddle==3.0.0 \ - paddle2onnx==1.3.1 \ - onnx==1.17.0 \ - requests - -WORKDIR /work - -# Download + extract the PP-DocLayout Paddle inference model. 
-RUN python3 -c " -import urllib.request, tarfile, os -url = 'PADDLE_MODEL_URL_PLACEHOLDER' -print(f'Downloading {url} ...') -dest = '/work/pp_doclayout.tar' -urllib.request.urlretrieve(url, dest) -print('Extracting ...') -with tarfile.open(dest) as t: - t.extractall('/work/paddle_model') -os.remove(dest) -# List what we extracted -for root, dirs, files in os.walk('/work/paddle_model'): - for f in files: - fp = os.path.join(root, f) - sz = os.path.getsize(fp) - print(f' {fp} ({sz} bytes)') -" - -# Convert Paddle model to ONNX. -# paddle2onnx expects model_dir with model.pdmodel + model.pdiparams -RUN python3 -c " -import os, glob, subprocess - -# Find the inference model files -model_dir = '/work/paddle_model' -pdmodel_files = glob.glob(os.path.join(model_dir, '**', '*.pdmodel'), recursive=True) -pdiparams_files = glob.glob(os.path.join(model_dir, '**', '*.pdiparams'), recursive=True) - -if not pdmodel_files: - raise FileNotFoundError('No .pdmodel file found in extracted archive') - -pdmodel = pdmodel_files[0] -pdiparams = pdiparams_files[0] if pdiparams_files else None -model_dir_actual = os.path.dirname(pdmodel) -pdmodel_name = os.path.basename(pdmodel).replace('.pdmodel', '') - -print(f'Found model: {pdmodel}') -print(f'Found params: {pdiparams}') -print(f'Model dir: {model_dir_actual}') -print(f'Model name prefix: {pdmodel_name}') - -cmd = [ - 'paddle2onnx', - '--model_dir', model_dir_actual, - '--model_filename', os.path.basename(pdmodel), -] -if pdiparams: - cmd += ['--params_filename', os.path.basename(pdiparams)] -cmd += [ - '--save_file', '/work/output/model.onnx', - '--opset_version', '14', - '--enable_onnx_checker', 'True', -] - -os.makedirs('/work/output', exist_ok=True) -print(f'Running: {\" \".join(cmd)}') -subprocess.run(cmd, check=True) - -out_size = os.path.getsize('/work/output/model.onnx') -print(f'Conversion done: /work/output/model.onnx ({out_size} bytes)') -" - -CMD ["cp", "-v", "/work/output/model.onnx", "/output/model.onnx"] -""".replace( - 
"PADDLE_MODEL_URL_PLACEHOLDER", PADDLE_MODEL_URL -) - - -def try_docker(output_dir: Path) -> bool: - """Build a Docker image to convert the Paddle model to ONNX. Returns True on success.""" - log.info("=== Method: DOCKER (linux/amd64) ===") - - # Check Docker is available. - docker_bin = shutil.which("docker") or "/usr/local/bin/docker" - try: - subprocess.run( - [docker_bin, "version"], - capture_output=True, - check=True, - timeout=15, - ) - except (subprocess.CalledProcessError, FileNotFoundError, subprocess.TimeoutExpired) as exc: - log.error("Docker is not available: %s", exc) - return False - - output_dir.mkdir(parents=True, exist_ok=True) - - with tempfile.TemporaryDirectory(prefix="doclayout-export-") as tmpdir: - tmpdir = Path(tmpdir) - - # Write Dockerfile. - dockerfile_path = tmpdir / "Dockerfile" - dockerfile_path.write_text(DOCKERFILE_CONTENT) - log.info("Wrote Dockerfile to %s", dockerfile_path) - - # Build image. - log.info("Building Docker image (this downloads ~2 GB, may take a while) ...") - build_cmd = [ - docker_bin, "build", - "--platform", "linux/amd64", - "-t", DOCKER_IMAGE_TAG, - "-f", str(dockerfile_path), - str(tmpdir), - ] - log.info(" %s", " ".join(build_cmd)) - build_result = subprocess.run( - build_cmd, - capture_output=False, # stream output to terminal - timeout=1200, # 20 min - ) - if build_result.returncode != 0: - log.error("Docker build failed (exit code %d).", build_result.returncode) - return False - - # Run container — mount output_dir as /output, the CMD copies model.onnx there. 
- log.info("Running conversion container ...") - run_cmd = [ - docker_bin, "run", - "--rm", - "--platform", "linux/amd64", - "-v", f"{output_dir.resolve()}:/output", - DOCKER_IMAGE_TAG, - ] - log.info(" %s", " ".join(run_cmd)) - run_result = subprocess.run( - run_cmd, - capture_output=False, - timeout=300, - ) - if run_result.returncode != 0: - log.error("Docker run failed (exit code %d).", run_result.returncode) - return False - - model_path = output_dir / "model.onnx" - if model_path.exists(): - size_mb = model_path.stat().st_size / (1 << 20) - log.info("Model exported: %s (%.1f MB)", model_path, size_mb) - return True - else: - log.error("Expected output file not found: %s", model_path) - return False - - -# --------------------------------------------------------------------------- -# Write metadata -# --------------------------------------------------------------------------- - - -def write_metadata(output_dir: Path, method: str) -> None: - """Write a metadata JSON next to the model for provenance tracking.""" - model_path = output_dir / "model.onnx" - if not model_path.exists(): - return - - meta = { - "model": "PP-DocLayout", - "format": "ONNX", - "export_method": method, - "class_labels": CLASS_LABELS, - "input_shape": list(MODEL_INPUT_SHAPE), - "file_size_bytes": model_path.stat().st_size, - "sha256": sha256_file(model_path), - } - meta_path = output_dir / "metadata.json" - with open(meta_path, "w") as f: - json.dump(meta, f, indent=2) - log.info("Metadata written to %s", meta_path) - - # --------------------------------------------------------------------------- # Main # --------------------------------------------------------------------------- @@ -527,7 +213,7 @@ def main() -> int: return 1 # Write metadata. - write_metadata(output_dir, used_method) + write_metadata(output_dir, used_method, CLASS_LABELS, MODEL_INPUT_SHAPE) # Verify. 
if not args.skip_verify: diff --git a/website/app/lehrer/abitur-archiv/_components/archiv-constants.ts b/website/app/lehrer/abitur-archiv/_components/archiv-constants.ts new file mode 100644 index 0000000..672da84 --- /dev/null +++ b/website/app/lehrer/abitur-archiv/_components/archiv-constants.ts @@ -0,0 +1,84 @@ +/** + * Abitur-Archiv Constants & Mock Data + * + * Extracted from page.tsx. + */ + +// API Base URL +export const API_BASE = '/api/education/abitur-archiv' + +// Filter constants +export const FAECHER = [ + { id: 'deutsch', label: 'Deutsch' }, + { id: 'englisch', label: 'Englisch' }, + { id: 'mathematik', label: 'Mathematik' }, + { id: 'biologie', label: 'Biologie' }, + { id: 'physik', label: 'Physik' }, + { id: 'chemie', label: 'Chemie' }, + { id: 'geschichte', label: 'Geschichte' }, +] + +export const JAHRE = [2025, 2024, 2023, 2022, 2021] + +export const NIVEAUS = [ + { id: 'eA', label: 'Erhoehtes Niveau (eA)' }, + { id: 'gA', label: 'Grundlegendes Niveau (gA)' }, +] + +export const TYPEN = [ + { id: 'aufgabe', label: 'Aufgabe' }, + { id: 'erwartungshorizont', label: 'Erwartungshorizont' }, +] + +export interface AbiturDokument { + id: string + dateiname: string + fach: string + jahr: number + niveau: 'eA' | 'gA' + typ: 'aufgabe' | 'erwartungshorizont' + aufgaben_nummer: string + status: string + file_path: string + file_size: number +} + +export function getMockDocuments(): AbiturDokument[] { + const docs: AbiturDokument[] = [] + const faecher = ['deutsch', 'englisch'] + const jahre = [2024, 2023, 2022] + const niveaus: Array<'eA' | 'gA'> = ['eA', 'gA'] + const typen: Array<'aufgabe' | 'erwartungshorizont'> = ['aufgabe', 'erwartungshorizont'] + const nummern = ['I', 'II', 'III'] + + let id = 1 + for (const jahr of jahre) { + for (const fach of faecher) { + for (const niveau of niveaus) { + for (const nummer of nummern) { + for (const typ of typen) { + docs.push({ + id: `doc-${id++}`, + dateiname: `${jahr}_${fach}_${niveau}_${nummer}${typ === 
'erwartungshorizont' ? '_EWH' : ''}.pdf`, + fach, + jahr, + niveau, + typ, + aufgaben_nummer: nummer, + status: 'indexed', + file_path: '#', + file_size: 250000 + Math.random() * 500000 + }) + } + } + } + } + } + return docs +} + +export function formatFileSize(bytes: number): string { + if (bytes < 1024) return bytes + ' B' + if (bytes < 1024 * 1024) return (bytes / 1024).toFixed(1) + ' KB' + return (bytes / (1024 * 1024)).toFixed(1) + ' MB' +} diff --git a/website/app/lehrer/abitur-archiv/page.tsx b/website/app/lehrer/abitur-archiv/page.tsx index 351d82d..c2e6576 100644 --- a/website/app/lehrer/abitur-archiv/page.tsx +++ b/website/app/lehrer/abitur-archiv/page.tsx @@ -11,45 +11,10 @@ import { FileText, Filter, ChevronLeft, ChevronRight, Eye, Download, Search, X, Loader2, Grid, List, LayoutGrid, Plus, Archive, BookOpen } from 'lucide-react' - -// API Base URL -const API_BASE = '/api/education/abitur-archiv' - -// Filter constants -const FAECHER = [ - { id: 'deutsch', label: 'Deutsch' }, - { id: 'englisch', label: 'Englisch' }, - { id: 'mathematik', label: 'Mathematik' }, - { id: 'biologie', label: 'Biologie' }, - { id: 'physik', label: 'Physik' }, - { id: 'chemie', label: 'Chemie' }, - { id: 'geschichte', label: 'Geschichte' }, -] - -const JAHRE = [2025, 2024, 2023, 2022, 2021] - -const NIVEAUS = [ - { id: 'eA', label: 'Erhoehtes Niveau (eA)' }, - { id: 'gA', label: 'Grundlegendes Niveau (gA)' }, -] - -const TYPEN = [ - { id: 'aufgabe', label: 'Aufgabe' }, - { id: 'erwartungshorizont', label: 'Erwartungshorizont' }, -] - -interface AbiturDokument { - id: string - dateiname: string - fach: string - jahr: number - niveau: 'eA' | 'gA' - typ: 'aufgabe' | 'erwartungshorizont' - aufgaben_nummer: string - status: string - file_path: string - file_size: number -} +import { + API_BASE, FAECHER, JAHRE, NIVEAUS, TYPEN, + type AbiturDokument, getMockDocuments, formatFileSize, +} from './_components/archiv-constants' export default function AbiturArchivPage() { const 
[documents, setDocuments] = useState([]) @@ -140,12 +105,6 @@ export default function AbiturArchivPage() { const hasActiveFilters = filterFach || filterJahr || filterNiveau || filterTyp || searchQuery - const formatFileSize = (bytes: number) => { - if (bytes < 1024) return bytes + ' B' - if (bytes < 1024 * 1024) return (bytes / 1024).toFixed(1) + ' KB' - return (bytes / (1024 * 1024)).toFixed(1) + ' MB' - } - return (
{/* Header */} @@ -469,36 +428,3 @@ export default function AbiturArchivPage() { ) } -function getMockDocuments(): AbiturDokument[] { - const docs: AbiturDokument[] = [] - const faecher = ['deutsch', 'englisch'] - const jahre = [2024, 2023, 2022] - const niveaus: Array<'eA' | 'gA'> = ['eA', 'gA'] - const typen: Array<'aufgabe' | 'erwartungshorizont'> = ['aufgabe', 'erwartungshorizont'] - const nummern = ['I', 'II', 'III'] - - let id = 1 - for (const jahr of jahre) { - for (const fach of faecher) { - for (const niveau of niveaus) { - for (const nummer of nummern) { - for (const typ of typen) { - docs.push({ - id: `doc-${id++}`, - dateiname: `${jahr}_${fach}_${niveau}_${nummer}${typ === 'erwartungshorizont' ? '_EWH' : ''}.pdf`, - fach, - jahr, - niveau, - typ, - aufgaben_nummer: nummer, - status: 'indexed', - file_path: '#', - file_size: 250000 + Math.random() * 500000 - }) - } - } - } - } - } - return docs -} diff --git a/website/components/compliance/charts/DependencyMap.tsx b/website/components/compliance/charts/DependencyMap.tsx index a70718a..6f1cadc 100644 --- a/website/components/compliance/charts/DependencyMap.tsx +++ b/website/components/compliance/charts/DependencyMap.tsx @@ -10,67 +10,9 @@ */ import { useState, useMemo } from 'react' -import { Language, getTerm } from '@/lib/compliance-i18n' - -interface Requirement { - id: string - article: string - title: string - regulation_code: string -} - -interface Control { - id: string - control_id: string - title: string - domain: string - status: string -} - -interface Mapping { - requirement_id: string - control_id: string - coverage_level: 'full' | 'partial' | 'planned' -} - -interface DependencyMapProps { - requirements: Requirement[] - controls: Control[] - mappings: Mapping[] - lang?: Language - onControlClick?: (control: Control) => void - onRequirementClick?: (requirement: Requirement) => void -} - -const DOMAIN_COLORS: Record = { - gov: '#64748b', - priv: '#3b82f6', - iam: '#a855f7', - crypto: '#eab308', 
- sdlc: '#22c55e', - ops: '#f97316', - ai: '#ec4899', - cra: '#06b6d4', - aud: '#6366f1', -} - -const DOMAIN_LABELS: Record = { - gov: 'Governance', - priv: 'Datenschutz', - iam: 'Identity & Access', - crypto: 'Kryptografie', - sdlc: 'Secure Dev', - ops: 'Operations', - ai: 'KI-spezifisch', - cra: 'Supply Chain', - aud: 'Audit', -} - -const COVERAGE_COLORS: Record = { - full: { bg: 'bg-green-100', border: 'border-green-500', text: 'text-green-700' }, - partial: { bg: 'bg-yellow-100', border: 'border-yellow-500', text: 'text-yellow-700' }, - planned: { bg: 'bg-slate-100', border: 'border-slate-400', text: 'text-slate-600' }, -} +import type { Control, Requirement, DependencyMapProps } from './DependencyMapTypes' +import { DOMAIN_COLORS, COVERAGE_COLORS } from './DependencyMapTypes' +import { DependencyMapSankey } from './DependencyMapSankey' export default function DependencyMap({ requirements, @@ -115,7 +57,7 @@ export default function DependencyMap({ // Build mapping lookup const mappingLookup = useMemo(() => { - const lookup: Record> = {} + const lookup: Record> = {} mappings.forEach((m) => { if (!lookup[m.control_id]) lookup[m.control_id] = {} lookup[m.control_id][m.requirement_id] = m @@ -123,11 +65,6 @@ export default function DependencyMap({ return lookup }, [mappings]) - // Get connected requirements for a control - const getConnectedRequirements = (controlId: string) => { - return Object.keys(mappingLookup[controlId] || {}) - } - // Get connected controls for a requirement const getConnectedControls = (requirementId: string) => { return Object.keys(mappingLookup) @@ -201,81 +138,42 @@ export default function DependencyMap({ {/* Statistics Header */}
-

- {lang === 'de' ? 'Abdeckung' : 'Coverage'} -

+

{lang === 'de' ? 'Abdeckung' : 'Coverage'}

{stats.coveragePercent}%

-

- {stats.coveredRequirements}/{stats.totalRequirements} {lang === 'de' ? 'Anforderungen' : 'Requirements'} -

+

{stats.coveredRequirements}/{stats.totalRequirements} {lang === 'de' ? 'Anforderungen' : 'Requirements'}

-

- {lang === 'de' ? 'Vollstaendig' : 'Full'} -

+

{lang === 'de' ? 'Vollstaendig' : 'Full'}

{stats.fullMappings}

-

{lang === 'de' ? 'Mappings' : 'Mappings'}

+

Mappings

-

- {lang === 'de' ? 'Teilweise' : 'Partial'} -

+

{lang === 'de' ? 'Teilweise' : 'Partial'}

{stats.partialMappings}

-

{lang === 'de' ? 'Mappings' : 'Mappings'}

+

Mappings

-

- {lang === 'de' ? 'Geplant' : 'Planned'} -

+

{lang === 'de' ? 'Geplant' : 'Planned'}

{stats.plannedMappings}

-

{lang === 'de' ? 'Mappings' : 'Mappings'}

+

Mappings

{/* Filters */}
- setFilterRegulation(e.target.value)} className="px-3 py-2 border rounded-lg focus:ring-2 focus:ring-primary-500"> - {regulations.map((reg) => ( - - ))} + {regulations.map((reg) => ())} - - setFilterDomain(e.target.value)} className="px-3 py-2 border rounded-lg focus:ring-2 focus:ring-primary-500"> - {domains.map((dom) => ( - - ))} + {domains.map((dom) => ())} -
-
- - + +
@@ -284,96 +182,38 @@ export default function DependencyMap({ {viewMode === 'matrix' ? (
- {/* Matrix Header */}
{filteredControls.map((control) => ( -
handleControlClick(control)} - className={` - w-20 flex-shrink-0 text-center p-2 cursor-pointer transition-colors - ${selectedControl === control.control_id ? 'bg-primary-100' : 'hover:bg-slate-50'} - `} - > -
-

- {control.control_id} -

+
handleControlClick(control)} className={`w-20 flex-shrink-0 text-center p-2 cursor-pointer transition-colors ${selectedControl === control.control_id ? 'bg-primary-100' : 'hover:bg-slate-50'}`}> +
+

{control.control_id}

))}
- - {/* Matrix Body */} {filteredRequirements.map((req) => { const connectedControls = getConnectedControls(req.id) - const isHighlighted = selectedRequirement === req.id || - (selectedControl && connectedControls.some((c) => c.controlId === selectedControl)) - + const isHighlighted = selectedRequirement === req.id || (selectedControl && connectedControls.some((c) => c.controlId === selectedControl)) return ( -
-
handleRequirementClick(req)} - className={` - w-48 flex-shrink-0 p-2 cursor-pointer transition-colors - ${selectedRequirement === req.id ? 'bg-primary-100' : 'hover:bg-slate-50'} - `} - > -

- {req.regulation_code} {req.article} -

-

- {req.title} -

+
+
handleRequirementClick(req)} className={`w-48 flex-shrink-0 p-2 cursor-pointer transition-colors ${selectedRequirement === req.id ? 'bg-primary-100' : 'hover:bg-slate-50'}`}> +

{req.regulation_code} {req.article}

+

{req.title}

{filteredControls.map((control) => { const mapping = mappingLookup[control.control_id]?.[req.id] const isControlHighlighted = selectedControl === control.control_id const isConnected = selectedControl && mapping - return ( -
+
{mapping && ( -
- {mapping.coverage_level === 'full' && ( - - - - )} - {mapping.coverage_level === 'partial' && ( - - - - )} - {mapping.coverage_level === 'planned' && ( - - - - )} +
+ {mapping.coverage_level === 'full' && ()} + {mapping.coverage_level === 'partial' && ()} + {mapping.coverage_level === 'planned' && ()}
)}
@@ -386,179 +226,35 @@ export default function DependencyMap({
) : ( - /* Sankey/Connection View */ -
-
- {/* Controls Column */} -
-

- Controls ({filteredControls.length}) -

- {filteredControls.map((control) => { - const connectedReqs = getConnectedRequirements(control.control_id) - const isSelected = selectedControl === control.control_id - - return ( - - ) - })} -
- - {/* Connection Lines (simplified) */} -
-
- {selectedControl && ( -
- {getConnectedRequirements(selectedControl).slice(0, 10).map((reqId, idx) => { - const req = requirements.find((r) => r.id === reqId) - const mapping = mappingLookup[selectedControl][reqId] - if (!req) return null - - return ( -
- {req.regulation_code} {req.article} -
- ) - })} - {getConnectedRequirements(selectedControl).length > 10 && ( - - +{getConnectedRequirements(selectedControl).length - 10} {lang === 'de' ? 'weitere' : 'more'} - - )} -
- )} - {selectedRequirement && ( -
- {getConnectedControls(selectedRequirement).slice(0, 10).map(({ controlId, coverage }) => { - const control = controls.find((c) => c.control_id === controlId) - if (!control) return null - - return ( -
- {control.control_id} -
- ) - })} -
- )} - {!selectedControl && !selectedRequirement && ( -
-

- {lang === 'de' - ? 'Waehlen Sie ein Control oder eine Anforderung aus' - : 'Select a control or requirement'} -

-
- )} -
-
- - {/* Requirements Column */} -
-

- {lang === 'de' ? 'Anforderungen' : 'Requirements'} ({filteredRequirements.length}) -

- {filteredRequirements.slice(0, 15).map((req) => { - const connectedCtrls = getConnectedControls(req.id) - const isSelected = selectedRequirement === req.id - const isHighlighted = selectedControl && connectedCtrls.some((c) => c.controlId === selectedControl) - - return ( - - ) - })} - {filteredRequirements.length > 15 && ( -

- +{filteredRequirements.length - 15} {lang === 'de' ? 'weitere' : 'more'} -

- )} -
-
-
+ )} {/* Legend */}
-
-
- - - + {(['full', 'partial', 'planned'] as const).map((level) => ( +
+
+ {level === 'full' && ()} + {level === 'partial' && ()} + {level === 'planned' && ()} +
+ + {lang === 'de' ? (level === 'full' ? 'Vollstaendig abgedeckt' : level === 'partial' ? 'Teilweise abgedeckt' : 'Geplant') : (level === 'full' ? 'Fully covered' : level === 'partial' ? 'Partially covered' : 'Planned')} +
- - {lang === 'de' ? 'Vollstaendig abgedeckt' : 'Fully covered'} - -
-
-
- - - -
- - {lang === 'de' ? 'Teilweise abgedeckt' : 'Partially covered'} - -
-
-
- - - -
- - {lang === 'de' ? 'Geplant' : 'Planned'} - -
+ ))}
diff --git a/website/components/compliance/charts/DependencyMapSankey.tsx b/website/components/compliance/charts/DependencyMapSankey.tsx new file mode 100644 index 0000000..1e639e9 --- /dev/null +++ b/website/components/compliance/charts/DependencyMapSankey.tsx @@ -0,0 +1,190 @@ +'use client' + +/** + * DependencyMap Sankey/Connection View + * + * Extracted from DependencyMap to keep each file under 500 LOC. + */ + +import type { Language } from '@/lib/compliance-i18n' +import type { Requirement, Control, Mapping } from './DependencyMapTypes' +import { DOMAIN_COLORS, COVERAGE_COLORS } from './DependencyMapTypes' + +interface DependencyMapSankeyProps { + filteredControls: Control[] + filteredRequirements: Requirement[] + requirements: Requirement[] + controls: Control[] + mappingLookup: Record> + selectedControl: string | null + selectedRequirement: string | null + onControlClick: (control: Control) => void + onRequirementClick: (requirement: Requirement) => void + lang: Language +} + +export function DependencyMapSankey({ + filteredControls, + filteredRequirements, + requirements, + controls, + mappingLookup, + selectedControl, + selectedRequirement, + onControlClick, + onRequirementClick, + lang, +}: DependencyMapSankeyProps) { + const getConnectedRequirements = (controlId: string) => { + return Object.keys(mappingLookup[controlId] || {}) + } + + const getConnectedControls = (requirementId: string) => { + return Object.keys(mappingLookup) + .filter((controlId) => mappingLookup[controlId][requirementId]) + .map((controlId) => ({ + controlId, + coverage: mappingLookup[controlId][requirementId].coverage_level, + })) + } + + return ( +
+
+ {/* Controls Column */} +
+

+ Controls ({filteredControls.length}) +

+ {filteredControls.map((control) => { + const connectedReqs = getConnectedRequirements(control.control_id) + const isSelected = selectedControl === control.control_id + + return ( + + ) + })} +
+ + {/* Connection Lines (simplified) */} +
+
+ {selectedControl && ( +
+ {getConnectedRequirements(selectedControl).slice(0, 10).map((reqId) => { + const req = requirements.find((r) => r.id === reqId) + const mapping = mappingLookup[selectedControl][reqId] + if (!req) return null + + return ( +
+ {req.regulation_code} {req.article} +
+ ) + })} + {getConnectedRequirements(selectedControl).length > 10 && ( + + +{getConnectedRequirements(selectedControl).length - 10} {lang === 'de' ? 'weitere' : 'more'} + + )} +
+ )} + {selectedRequirement && ( +
+ {getConnectedControls(selectedRequirement).slice(0, 10).map(({ controlId, coverage }) => { + const control = controls.find((c) => c.control_id === controlId) + if (!control) return null + + return ( +
+ {control.control_id} +
+ ) + })} +
+ )} + {!selectedControl && !selectedRequirement && ( +
+

+ {lang === 'de' + ? 'Waehlen Sie ein Control oder eine Anforderung aus' + : 'Select a control or requirement'} +

+
+ )} +
+
+ + {/* Requirements Column */} +
+

+ {lang === 'de' ? 'Anforderungen' : 'Requirements'} ({filteredRequirements.length}) +

+ {filteredRequirements.slice(0, 15).map((req) => { + const connectedCtrls = getConnectedControls(req.id) + const isSelected = selectedRequirement === req.id + const isHighlighted = selectedControl && connectedCtrls.some((c) => c.controlId === selectedControl) + + return ( + + ) + })} + {filteredRequirements.length > 15 && ( +

+ +{filteredRequirements.length - 15} {lang === 'de' ? 'weitere' : 'more'} +

+ )} +
+
+
+ ) +} diff --git a/website/components/compliance/charts/DependencyMapTypes.ts b/website/components/compliance/charts/DependencyMapTypes.ts new file mode 100644 index 0000000..000e2d4 --- /dev/null +++ b/website/components/compliance/charts/DependencyMapTypes.ts @@ -0,0 +1,65 @@ +/** + * Types and constants for DependencyMap component. + */ + +import type { Language } from '@/lib/compliance-i18n' + +export interface Requirement { + id: string + article: string + title: string + regulation_code: string +} + +export interface Control { + id: string + control_id: string + title: string + domain: string + status: string +} + +export interface Mapping { + requirement_id: string + control_id: string + coverage_level: 'full' | 'partial' | 'planned' +} + +export interface DependencyMapProps { + requirements: Requirement[] + controls: Control[] + mappings: Mapping[] + lang?: Language + onControlClick?: (control: Control) => void + onRequirementClick?: (requirement: Requirement) => void +} + +export const DOMAIN_COLORS: Record = { + gov: '#64748b', + priv: '#3b82f6', + iam: '#a855f7', + crypto: '#eab308', + sdlc: '#22c55e', + ops: '#f97316', + ai: '#ec4899', + cra: '#06b6d4', + aud: '#6366f1', +} + +export const DOMAIN_LABELS: Record = { + gov: 'Governance', + priv: 'Datenschutz', + iam: 'Identity & Access', + crypto: 'Kryptografie', + sdlc: 'Secure Dev', + ops: 'Operations', + ai: 'KI-spezifisch', + cra: 'Supply Chain', + aud: 'Audit', +} + +export const COVERAGE_COLORS: Record = { + full: { bg: 'bg-green-100', border: 'border-green-500', text: 'text-green-700' }, + partial: { bg: 'bg-yellow-100', border: 'border-yellow-500', text: 'text-yellow-700' }, + planned: { bg: 'bg-slate-100', border: 'border-slate-400', text: 'text-slate-600' }, +}