diff --git a/admin-lehrer/components/ocr-pipeline/StepGridReviewStats.tsx b/admin-lehrer/components/ocr-pipeline/StepGridReviewStats.tsx
new file mode 100644
index 0000000..a711827
--- /dev/null
+++ b/admin-lehrer/components/ocr-pipeline/StepGridReviewStats.tsx
@@ -0,0 +1,180 @@
+'use client'
+
+/**
+ * StepGridReview Stats Bar & OCR Quality Controls
+ *
+ * Extracted from StepGridReview.tsx to stay under 500 LOC.
+ */
+
+import type { GridZone } from '@/components/grid-editor/types'
+
+/** Aggregate grid counts (zones, columns, rows, cells) rendered as the first entry of the stats bar. */
+interface GridSummary {
+ total_zones: number
+ total_columns: number
+ total_rows: number
+ total_cells: number
+}
+
+/**
+ * Result of the dictionary-page detection.
+ *
+ * `confidence` is multiplied by 100 and rounded for display, so it is
+ * presumably a 0..1 fraction - confirm against the producer.
+ */
+interface DictionaryDetection {
+ is_dictionary: boolean
+ confidence: number
+}
+
+/**
+ * Detected page number.
+ *
+ * The page badge is rendered only when `text` is truthy; it then shows
+ * `number` unless that is null/undefined, in which case `text` is shown.
+ */
+interface PageNumber {
+ text?: string
+ number?: number | null
+}
+
+/**
+ * Props for ReviewStatsBar.
+ *
+ * Current setting values are passed in next to an `on...Change` callback per
+ * setting; the callbacks receive the new value taken from the input event.
+ */
+interface ReviewStatsBarProps {
+ summary: GridSummary
+ dictionaryDetection?: DictionaryDetection | null
+ pageNumber?: PageNumber | null
+ /** A low-confidence badge is shown only when this is greater than 0. */
+ lowConfCount: number
+ /** Compared with `totalRows`: "Alle akzeptieren" is offered while this is smaller. */
+ acceptedCount: number
+ totalRows: number
+ ocrEnhance: boolean
+ ocrMaxCols: number
+ ocrMinConf: number
+ visionFusion: boolean
+ documentCategory: string
+ /** Rendered with one decimal place followed by "s". */
+ durationSeconds: number
+ /** Selects the toggle label: 'Bild ausblenden' when true, 'Bild einblenden' otherwise. */
+ showImage: boolean
+ onOcrEnhanceChange: (v: boolean) => void
+ /** Receives `Number(e.target.value)` of the MaxCol select. */
+ onOcrMaxColsChange: (v: number) => void
+ /** Receives `Number(e.target.value)` of the MinConf select. */
+ onOcrMinConfChange: (v: number) => void
+ onVisionFusionChange: (v: boolean) => void
+ onDocumentCategoryChange: (v: string) => void
+ onAcceptAll: () => void
+ /** The returned number is reported to the user as the count of corrected cells; 0 yields a "nothing found" message. */
+ onAutoCorrect: () => number
+ onToggleImage: () => void
+}
+
+/**
+ * Stats bar for the grid review step: summary counts, optional badges, the
+ * accepted-rows counter, OCR quality controls and action buttons.
+ *
+ * Conditional parts:
+ * - the dictionary badge appears only when `dictionaryDetection.is_dictionary` is truthy
+ * - the page badge appears only when `pageNumber.text` is truthy
+ * - the low-confidence badge appears only when `lowConfCount > 0`
+ * - "Alle akzeptieren" is offered only while `acceptedCount < totalRows`
+ *
+ * The Auto-Korrektur handler calls `onAutoCorrect()` and reports the returned
+ * count through a blocking `alert`.
+ */
+export function ReviewStatsBar({
+ summary,
+ dictionaryDetection,
+ pageNumber,
+ lowConfCount,
+ acceptedCount,
+ totalRows,
+ ocrEnhance,
+ ocrMaxCols,
+ ocrMinConf,
+ visionFusion,
+ documentCategory,
+ durationSeconds,
+ showImage,
+ onOcrEnhanceChange,
+ onOcrMaxColsChange,
+ onOcrMinConfChange,
+ onVisionFusionChange,
+ onDocumentCategoryChange,
+ onAcceptAll,
+ onAutoCorrect,
+ onToggleImage,
+}: ReviewStatsBarProps) {
+ return (
+
+
+ {summary.total_zones} Zone(n), {summary.total_columns} Spalten,{' '}
+ {summary.total_rows} Zeilen, {summary.total_cells} Zellen
+
+ {dictionaryDetection?.is_dictionary && (
+
+ Woerterbuch ({Math.round(dictionaryDetection.confidence * 100)}%)
+
+ )}
+ {pageNumber?.text && (
+
+ S. {pageNumber.number ?? pageNumber.text}
+
+ )}
+ {lowConfCount > 0 && (
+
+ {lowConfCount} niedrige Konfidenz
+
+ )}
+
+ {acceptedCount}/{totalRows} Zeilen akzeptiert
+
+ {acceptedCount < totalRows && (
+
+ Alle akzeptieren
+
+ )}
+
+ {/* OCR Quality Steps */}
+
|
+
+ onOcrEnhanceChange(e.target.checked)} className="rounded w-3 h-3" />
+ CLAHE
+
+
+ MaxCol:
+ onOcrMaxColsChange(Number(e.target.value))} className="px-1 py-0.5 text-xs rounded border border-gray-200 dark:border-gray-600 bg-white dark:bg-gray-700 text-gray-700 dark:text-gray-300">
+ off
+ 2
+ 3
+ 4
+ 5
+
+
+
+ MinConf:
+ onOcrMinConfChange(Number(e.target.value))} className="px-1 py-0.5 text-xs rounded border border-gray-200 dark:border-gray-600 bg-white dark:bg-gray-700 text-gray-700 dark:text-gray-300">
+ auto
+ 20
+ 30
+ 40
+ 50
+ 60
+
+
+
+
|
+
+ onVisionFusionChange(e.target.checked)} className="rounded w-3 h-3 accent-orange-500" />
+ Vision-LLM
+
+
+ Typ:
+ onDocumentCategoryChange(e.target.value)} className="px-1 py-0.5 text-xs rounded border border-gray-200 dark:border-gray-600 bg-white dark:bg-gray-700 text-gray-700 dark:text-gray-300">
+ Vokabelseite
+ Woerterbuch
+ Arbeitsblatt
+ Buchseite
+ Sonstiges
+
+
+
+
+ {
+ const n = onAutoCorrect()
+ if (n === 0) alert('Keine Muster-Korrekturen gefunden.')
+ else alert(`${n} Zelle(n) korrigiert (Muster-Vervollstaendigung).`)
+ }}
+ className="px-2.5 py-1 rounded text-xs border border-purple-200 dark:border-purple-700 bg-purple-50 dark:bg-purple-900/20 text-purple-700 dark:text-purple-300 hover:bg-purple-100 dark:hover:bg-purple-900/40 transition-colors"
+ title="Erkennt Muster wie p.70, p.71 und vervollstaendigt partielle Eintraege wie .65 zu p.65"
+ >
+ Auto-Korrektur
+
+
+ {showImage ? 'Bild ausblenden' : 'Bild einblenden'}
+
+
+ {durationSeconds.toFixed(1)}s
+
+
+
+ )
+}
diff --git a/admin-lehrer/components/ocr/GridOverlay.tsx b/admin-lehrer/components/ocr/GridOverlay.tsx
index 0014bc9..23fbe61 100644
--- a/admin-lehrer/components/ocr/GridOverlay.tsx
+++ b/admin-lehrer/components/ocr/GridOverlay.tsx
@@ -474,76 +474,5 @@ export function GridOverlay({
)
}
-/**
- * GridStats Component
- */
-interface GridStatsProps {
- stats: GridData['stats']
- deskewAngle?: number
- source?: string
- className?: string
-}
-
-export function GridStats({ stats, deskewAngle, source, className }: GridStatsProps) {
- const coveragePercent = Math.round(stats.coverage * 100)
-
- return (
-
-
- Erkannt: {stats.recognized}
-
- {(stats.manual ?? 0) > 0 && (
-
- Manuell: {stats.manual}
-
- )}
- {stats.problematic > 0 && (
-
- Problematisch: {stats.problematic}
-
- )}
-
- Leer: {stats.empty}
-
-
- Abdeckung: {coveragePercent}%
-
- {deskewAngle !== undefined && deskewAngle !== 0 && (
-
- Begradigt: {deskewAngle.toFixed(1)}
-
- )}
- {source && (
-
- Quelle: {source === 'tesseract+grid_service' ? 'Tesseract' : 'Vision LLM'}
-
- )}
-
- )
-}
-
-/**
- * Legend Component for GridOverlay
- */
-export function GridLegend({ className }: { className?: string }) {
- return (
-
- )
-}
+// Re-export widgets from sibling file for backwards compatibility
+export { GridStats, GridLegend } from './GridOverlayWidgets'
diff --git a/admin-lehrer/components/ocr/GridOverlayWidgets.tsx b/admin-lehrer/components/ocr/GridOverlayWidgets.tsx
new file mode 100644
index 0000000..2c20856
--- /dev/null
+++ b/admin-lehrer/components/ocr/GridOverlayWidgets.tsx
@@ -0,0 +1,84 @@
+'use client'
+
+/**
+ * GridOverlay Widgets - GridStats and GridLegend
+ *
+ * Extracted from GridOverlay.tsx to keep each file under 500 LOC.
+ */
+
+import { cn } from '@/lib/utils'
+import type { GridData } from './GridOverlay'
+
+/**
+ * Props accepted by the GridStats component
+ */
+interface GridStatsProps {
+ stats: GridData['stats']
+ deskewAngle?: number
+ source?: string
+ className?: string
+}
+
+/**
+ * Summary line for a recognized grid: recognized / manual / problematic /
+ * empty counts, coverage, and optionally the deskew angle and the OCR source.
+ *
+ * - `stats.coverage` is multiplied by 100 and rounded, so it is presumably a
+ *   0..1 fraction - confirm against the producer.
+ * - "Manuell" and "Problematisch" are shown only when their count is above 0.
+ * - "Begradigt" is shown only for a defined, non-zero `deskewAngle`
+ *   (formatted with one decimal place).
+ * - Any truthy `source` other than 'tesseract+grid_service' is labelled
+ *   "Vision LLM".
+ */
+export function GridStats({ stats, deskewAngle, source, className }: GridStatsProps) {
+ const coveragePercent = Math.round(stats.coverage * 100)
+
+ return (
+
+
+ Erkannt: {stats.recognized}
+
+ {(stats.manual ?? 0) > 0 && (
+
+ Manuell: {stats.manual}
+
+ )}
+ {stats.problematic > 0 && (
+
+ Problematisch: {stats.problematic}
+
+ )}
+
+ Leer: {stats.empty}
+
+
+ Abdeckung: {coveragePercent}%
+
+ {deskewAngle !== undefined && deskewAngle !== 0 && (
+
+ Begradigt: {deskewAngle.toFixed(1)}
+
+ )}
+ {source && (
+
+ Quelle: {source === 'tesseract+grid_service' ? 'Tesseract' : 'Vision LLM'}
+
+ )}
+
+ )
+}
+
+/**
+ * Legend component for GridOverlay.
+ *
+ * @param className - optional CSS class names passed in by the caller
+ */
+export function GridLegend({ className }: { className?: string }) {
+ return (
+
+ )
+}
diff --git a/agent-core/brain/__init__.py b/agent-core/brain/__init__.py
index 789ecff..4bb6e7a 100644
--- a/agent-core/brain/__init__.py
+++ b/agent-core/brain/__init__.py
@@ -7,15 +7,31 @@ Provides:
- KnowledgeGraph: Entity relationships and semantic connections
"""
-from agent_core.brain.memory_store import MemoryStore, Memory
-from agent_core.brain.context_manager import ConversationContext, ContextManager
-from agent_core.brain.knowledge_graph import KnowledgeGraph, Entity, Relationship
+from agent_core.brain.memory_models import Memory
+from agent_core.brain.memory_store import MemoryStore
+from agent_core.brain.context_models import (
+ MessageRole,
+ Message,
+ ConversationContext,
+)
+from agent_core.brain.context_manager import ContextManager
+from agent_core.brain.knowledge_models import (
+ EntityType,
+ RelationshipType,
+ Entity,
+ Relationship,
+)
+from agent_core.brain.knowledge_graph import KnowledgeGraph
__all__ = [
"MemoryStore",
"Memory",
+ "MessageRole",
+ "Message",
"ConversationContext",
"ContextManager",
+ "EntityType",
+ "RelationshipType",
"KnowledgeGraph",
"Entity",
"Relationship",
diff --git a/agent-core/brain/context_manager.py b/agent-core/brain/context_manager.py
index e142d33..6d7dfca 100644
--- a/agent-core/brain/context_manager.py
+++ b/agent-core/brain/context_manager.py
@@ -1,317 +1,22 @@
"""
Context Management for Breakpilot Agents
-Provides conversation context with:
-- Message history with compression
-- Entity extraction and tracking
-- Intent history
-- Context summarization
+Manages conversation contexts for multiple sessions with persistence.
"""
from typing import Dict, Any, List, Optional, Callable, Awaitable
-from dataclasses import dataclass, field
-from datetime import datetime, timezone
-from enum import Enum
import json
import logging
+from agent_core.brain.context_models import (
+ MessageRole,
+ Message,
+ ConversationContext,
+)
+
logger = logging.getLogger(__name__)
-class MessageRole(Enum):
- """Message roles in a conversation"""
- SYSTEM = "system"
- USER = "user"
- ASSISTANT = "assistant"
- TOOL = "tool"
-
-
-@dataclass
-class Message:
- """Represents a message in a conversation"""
- role: MessageRole
- content: str
- timestamp: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
- metadata: Dict[str, Any] = field(default_factory=dict)
-
- def to_dict(self) -> Dict[str, Any]:
- return {
- "role": self.role.value,
- "content": self.content,
- "timestamp": self.timestamp.isoformat(),
- "metadata": self.metadata
- }
-
- @classmethod
- def from_dict(cls, data: Dict[str, Any]) -> "Message":
- return cls(
- role=MessageRole(data["role"]),
- content=data["content"],
- timestamp=datetime.fromisoformat(data["timestamp"]) if "timestamp" in data else datetime.now(timezone.utc),
- metadata=data.get("metadata", {})
- )
-
-
-@dataclass
-class ConversationContext:
- """
- Context for a running conversation.
-
- Maintains:
- - Message history with automatic compression
- - Extracted entities
- - Intent history
- - Conversation summary
- """
- messages: List[Message] = field(default_factory=list)
- entities: Dict[str, Any] = field(default_factory=dict)
- intent_history: List[str] = field(default_factory=list)
- summary: Optional[str] = None
- max_messages: int = 50
- system_prompt: Optional[str] = None
- metadata: Dict[str, Any] = field(default_factory=dict)
-
- def add_message(
- self,
- role: MessageRole,
- content: str,
- metadata: Optional[Dict[str, Any]] = None
- ) -> Message:
- """
- Adds a message to the conversation.
-
- Args:
- role: Message role
- content: Message content
- metadata: Optional message metadata
-
- Returns:
- The created Message
- """
- message = Message(
- role=role,
- content=content,
- metadata=metadata or {}
- )
- self.messages.append(message)
-
- # Compress if needed
- if len(self.messages) > self.max_messages:
- self._compress_history()
-
- return message
-
- def add_user_message(
- self,
- content: str,
- metadata: Optional[Dict[str, Any]] = None
- ) -> Message:
- """Convenience method to add a user message"""
- return self.add_message(MessageRole.USER, content, metadata)
-
- def add_assistant_message(
- self,
- content: str,
- metadata: Optional[Dict[str, Any]] = None
- ) -> Message:
- """Convenience method to add an assistant message"""
- return self.add_message(MessageRole.ASSISTANT, content, metadata)
-
- def add_system_message(
- self,
- content: str,
- metadata: Optional[Dict[str, Any]] = None
- ) -> Message:
- """Convenience method to add a system message"""
- return self.add_message(MessageRole.SYSTEM, content, metadata)
-
- def add_intent(self, intent: str) -> None:
- """
- Records an intent in the history.
-
- Args:
- intent: The detected intent
- """
- self.intent_history.append(intent)
- # Keep last 20 intents
- if len(self.intent_history) > 20:
- self.intent_history = self.intent_history[-20:]
-
- def set_entity(self, name: str, value: Any) -> None:
- """
- Sets an entity value.
-
- Args:
- name: Entity name
- value: Entity value
- """
- self.entities[name] = value
-
- def get_entity(self, name: str, default: Any = None) -> Any:
- """
- Gets an entity value.
-
- Args:
- name: Entity name
- default: Default value if not found
-
- Returns:
- Entity value or default
- """
- return self.entities.get(name, default)
-
- def get_last_message(self, role: Optional[MessageRole] = None) -> Optional[Message]:
- """
- Gets the last message, optionally filtered by role.
-
- Args:
- role: Optional role filter
-
- Returns:
- The last matching message or None
- """
- if not self.messages:
- return None
-
- if role is None:
- return self.messages[-1]
-
- for msg in reversed(self.messages):
- if msg.role == role:
- return msg
-
- return None
-
- def get_messages_for_llm(self) -> List[Dict[str, str]]:
- """
- Gets messages formatted for LLM API calls.
-
- Returns:
- List of message dicts with role and content
- """
- result = []
-
- # Add system prompt first
- if self.system_prompt:
- result.append({
- "role": "system",
- "content": self.system_prompt
- })
-
- # Add summary if we have one and history was compressed
- if self.summary:
- result.append({
- "role": "system",
- "content": f"Previous conversation summary: {self.summary}"
- })
-
- # Add recent messages
- for msg in self.messages:
- result.append({
- "role": msg.role.value,
- "content": msg.content
- })
-
- return result
-
- def _compress_history(self) -> None:
- """
- Compresses older messages to save context window space.
-
- Keeps:
- - System messages
- - Last 20 messages
- - Creates summary of compressed middle messages
- """
- # Keep system messages
- system_msgs = [m for m in self.messages if m.role == MessageRole.SYSTEM]
-
- # Keep last 20 messages
- recent_msgs = self.messages[-20:]
-
- # Middle messages to summarize
- middle_start = len(system_msgs)
- middle_end = len(self.messages) - 20
- middle_msgs = self.messages[middle_start:middle_end]
-
- if middle_msgs:
- # Create a basic summary (can be enhanced with LLM-based summarization)
- self.summary = self._create_summary(middle_msgs)
-
- # Combine
- self.messages = system_msgs + recent_msgs
-
- logger.debug(
- f"Compressed conversation: {middle_end - middle_start} messages summarized"
- )
-
- def _create_summary(self, messages: List[Message]) -> str:
- """
- Creates a summary of messages.
-
- This is a basic implementation - can be enhanced with LLM-based summarization.
-
- Args:
- messages: Messages to summarize
-
- Returns:
- Summary string
- """
- # Count message types
- user_count = sum(1 for m in messages if m.role == MessageRole.USER)
- assistant_count = sum(1 for m in messages if m.role == MessageRole.ASSISTANT)
-
- # Extract key topics (simplified - could use NLP)
- topics = set()
- for msg in messages:
- # Simple keyword extraction
- words = msg.content.lower().split()
- # Filter common words
- keywords = [w for w in words if len(w) > 5][:3]
- topics.update(keywords)
-
- topics_str = ", ".join(list(topics)[:5])
-
- return (
- f"Earlier conversation: {user_count} user messages, "
- f"{assistant_count} assistant responses. "
- f"Topics discussed: {topics_str}"
- )
-
- def clear(self) -> None:
- """Clears all context"""
- self.messages.clear()
- self.entities.clear()
- self.intent_history.clear()
- self.summary = None
-
- def to_dict(self) -> Dict[str, Any]:
- """Serializes context to dict"""
- return {
- "messages": [m.to_dict() for m in self.messages],
- "entities": self.entities,
- "intent_history": self.intent_history,
- "summary": self.summary,
- "max_messages": self.max_messages,
- "system_prompt": self.system_prompt,
- "metadata": self.metadata
- }
-
- @classmethod
- def from_dict(cls, data: Dict[str, Any]) -> "ConversationContext":
- """Deserializes context from dict"""
- ctx = cls(
- messages=[Message.from_dict(m) for m in data.get("messages", [])],
- entities=data.get("entities", {}),
- intent_history=data.get("intent_history", []),
- summary=data.get("summary"),
- max_messages=data.get("max_messages", 50),
- system_prompt=data.get("system_prompt"),
- metadata=data.get("metadata", {})
- )
- return ctx
-
-
class ContextManager:
"""
Manages conversation contexts for multiple sessions.
diff --git a/agent-core/brain/context_models.py b/agent-core/brain/context_models.py
new file mode 100644
index 0000000..1a38a86
--- /dev/null
+++ b/agent-core/brain/context_models.py
@@ -0,0 +1,307 @@
+"""
+Context Models for Breakpilot Agents
+
+Data classes for conversation messages and context management.
+"""
+
+from typing import Dict, Any, List, Optional
+from dataclasses import dataclass, field
+from datetime import datetime, timezone
+from enum import Enum
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+class MessageRole(Enum):
+ """Message roles in a conversation.
+
+ The string values are the serialized form: Message.to_dict() writes them
+ under "role" and ConversationContext.get_messages_for_llm() emits them as
+ the "role" of each message dict.
+ """
+ SYSTEM = "system"
+ USER = "user"
+ ASSISTANT = "assistant"
+ TOOL = "tool"
+
+
+@dataclass
+class Message:
+    """One conversation turn: who said it, what was said, when, plus free-form metadata."""
+    role: MessageRole
+    content: str
+    timestamp: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
+    metadata: Dict[str, Any] = field(default_factory=dict)
+
+    def to_dict(self) -> Dict[str, Any]:
+        """Serialize to plain types: role as its string value, timestamp as ISO 8601."""
+        return dict(
+            role=self.role.value,
+            content=self.content,
+            timestamp=self.timestamp.isoformat(),
+            metadata=self.metadata,
+        )
+
+    @classmethod
+    def from_dict(cls, data: Dict[str, Any]) -> "Message":
+        """Rebuild a Message from to_dict() output.
+
+        "role" and "content" are required; a missing "timestamp" key falls
+        back to the current UTC time and missing "metadata" to an empty dict.
+        """
+        # Fields are resolved in declaration order so lookup errors surface
+        # in the same sequence as the constructor arguments.
+        parsed_role = MessageRole(data["role"])
+        text = data["content"]
+        if "timestamp" in data:
+            when = datetime.fromisoformat(data["timestamp"])
+        else:
+            when = datetime.now(timezone.utc)
+        extra = data.get("metadata", {})
+        return cls(role=parsed_role, content=text, timestamp=when, metadata=extra)
+
+
+@dataclass
+class ConversationContext:
+    """
+    Context for a running conversation.
+
+    Maintains:
+    - Message history with automatic compression
+    - Extracted entities
+    - Intent history
+    - Conversation summary
+
+    Compression runs whenever the history grows beyond ``max_messages``. It
+    always preserves the 20 most recent messages, so the history never
+    shrinks below that window even when ``max_messages`` is smaller.
+    """
+    messages: List[Message] = field(default_factory=list)
+    entities: Dict[str, Any] = field(default_factory=dict)
+    intent_history: List[str] = field(default_factory=list)
+    summary: Optional[str] = None
+    max_messages: int = 50
+    system_prompt: Optional[str] = None
+    metadata: Dict[str, Any] = field(default_factory=dict)
+
+    def add_message(
+        self,
+        role: MessageRole,
+        content: str,
+        metadata: Optional[Dict[str, Any]] = None
+    ) -> Message:
+        """
+        Adds a message to the conversation.
+
+        Args:
+            role: Message role
+            content: Message content
+            metadata: Optional message metadata
+
+        Returns:
+            The created Message
+        """
+        message = Message(
+            role=role,
+            content=content,
+            metadata=metadata or {}
+        )
+        self.messages.append(message)
+
+        # Compress if needed
+        if len(self.messages) > self.max_messages:
+            self._compress_history()
+
+        return message
+
+    def add_user_message(
+        self,
+        content: str,
+        metadata: Optional[Dict[str, Any]] = None
+    ) -> Message:
+        """Convenience method to add a user message"""
+        return self.add_message(MessageRole.USER, content, metadata)
+
+    def add_assistant_message(
+        self,
+        content: str,
+        metadata: Optional[Dict[str, Any]] = None
+    ) -> Message:
+        """Convenience method to add an assistant message"""
+        return self.add_message(MessageRole.ASSISTANT, content, metadata)
+
+    def add_system_message(
+        self,
+        content: str,
+        metadata: Optional[Dict[str, Any]] = None
+    ) -> Message:
+        """Convenience method to add a system message"""
+        return self.add_message(MessageRole.SYSTEM, content, metadata)
+
+    def add_intent(self, intent: str) -> None:
+        """
+        Records an intent in the history.
+
+        Args:
+            intent: The detected intent
+        """
+        self.intent_history.append(intent)
+        # Keep last 20 intents
+        if len(self.intent_history) > 20:
+            self.intent_history = self.intent_history[-20:]
+
+    def set_entity(self, name: str, value: Any) -> None:
+        """
+        Sets an entity value.
+
+        Args:
+            name: Entity name
+            value: Entity value
+        """
+        self.entities[name] = value
+
+    def get_entity(self, name: str, default: Any = None) -> Any:
+        """
+        Gets an entity value.
+
+        Args:
+            name: Entity name
+            default: Default value if not found
+
+        Returns:
+            Entity value or default
+        """
+        return self.entities.get(name, default)
+
+    def get_last_message(self, role: Optional[MessageRole] = None) -> Optional[Message]:
+        """
+        Gets the last message, optionally filtered by role.
+
+        Args:
+            role: Optional role filter
+
+        Returns:
+            The last matching message or None
+        """
+        if not self.messages:
+            return None
+
+        if role is None:
+            return self.messages[-1]
+
+        for msg in reversed(self.messages):
+            if msg.role == role:
+                return msg
+
+        return None
+
+    def get_messages_for_llm(self) -> List[Dict[str, str]]:
+        """
+        Gets messages formatted for LLM API calls.
+
+        Order: the system prompt (if set), then the compression summary
+        (if any) as a second system message, then the stored messages.
+
+        Returns:
+            List of message dicts with role and content
+        """
+        result = []
+
+        # Add system prompt first
+        if self.system_prompt:
+            result.append({
+                "role": "system",
+                "content": self.system_prompt
+            })
+
+        # Add summary if we have one and history was compressed
+        if self.summary:
+            result.append({
+                "role": "system",
+                "content": f"Previous conversation summary: {self.summary}"
+            })
+
+        # Add recent messages
+        for msg in self.messages:
+            result.append({
+                "role": msg.role.value,
+                "content": msg.content
+            })
+
+        return result
+
+    def _compress_history(self) -> None:
+        """
+        Compresses older messages to save context window space.
+
+        Keeps:
+        - The last 20 messages, untouched
+        - System messages that are older than that window
+        - A summary of the remaining older messages (it replaces any
+          previous summary)
+
+        Does nothing while 20 or fewer messages exist.
+        """
+        keep_recent = 20
+        cutoff = len(self.messages) - keep_recent
+        if cutoff <= 0:
+            # Everything is inside the recent window: nothing to drop and
+            # nothing to summarize.
+            return
+
+        older_msgs = self.messages[:cutoff]
+        recent_msgs = self.messages[cutoff:]
+
+        # Partition only the older part. System messages inside the recent
+        # window are already kept via recent_msgs; collecting them again
+        # would duplicate them on every compression.
+        system_msgs = [m for m in older_msgs if m.role == MessageRole.SYSTEM]
+        middle_msgs = [m for m in older_msgs if m.role != MessageRole.SYSTEM]
+
+        if middle_msgs:
+            # Create a basic summary (can be enhanced with LLM-based summarization)
+            self.summary = self._create_summary(middle_msgs)
+
+        # Combine
+        self.messages = system_msgs + recent_msgs
+
+        logger.debug(
+            "Compressed conversation: %d messages summarized", len(middle_msgs)
+        )
+
+    def _create_summary(self, messages: List[Message]) -> str:
+        """
+        Creates a summary of messages.
+
+        This is a basic implementation - can be enhanced with LLM-based summarization.
+
+        Args:
+            messages: Messages to summarize
+
+        Returns:
+            Summary string
+        """
+        # Count message types
+        user_count = sum(1 for m in messages if m.role == MessageRole.USER)
+        assistant_count = sum(1 for m in messages if m.role == MessageRole.ASSISTANT)
+
+        # Extract key topics (simplified - could use NLP). A dict is used as
+        # an insertion-ordered set so the selected topics are the first ones
+        # seen and the result is identical across runs.
+        topics: Dict[str, None] = {}
+        for msg in messages:
+            # Simple keyword extraction
+            words = msg.content.lower().split()
+            # Filter common words: keep up to three longer words per message
+            for keyword in [w for w in words if len(w) > 5][:3]:
+                topics.setdefault(keyword, None)
+
+        topics_str = ", ".join(list(topics)[:5])
+
+        return (
+            f"Earlier conversation: {user_count} user messages, "
+            f"{assistant_count} assistant responses. "
+            f"Topics discussed: {topics_str}"
+        )
+
+    def clear(self) -> None:
+        """Clears messages, entities, intent history and summary (system prompt and metadata are kept)"""
+        self.messages.clear()
+        self.entities.clear()
+        self.intent_history.clear()
+        self.summary = None
+
+    def to_dict(self) -> Dict[str, Any]:
+        """Serializes context to dict"""
+        return {
+            "messages": [m.to_dict() for m in self.messages],
+            "entities": self.entities,
+            "intent_history": self.intent_history,
+            "summary": self.summary,
+            "max_messages": self.max_messages,
+            "system_prompt": self.system_prompt,
+            "metadata": self.metadata
+        }
+
+    @classmethod
+    def from_dict(cls, data: Dict[str, Any]) -> "ConversationContext":
+        """Deserializes context from dict; every key is optional and falls back to the field default"""
+        ctx = cls(
+            messages=[Message.from_dict(m) for m in data.get("messages", [])],
+            entities=data.get("entities", {}),
+            intent_history=data.get("intent_history", []),
+            summary=data.get("summary"),
+            max_messages=data.get("max_messages", 50),
+            system_prompt=data.get("system_prompt"),
+            metadata=data.get("metadata", {})
+        )
+        return ctx
diff --git a/agent-core/brain/knowledge_graph.py b/agent-core/brain/knowledge_graph.py
index 6696b5a..6376b14 100644
--- a/agent-core/brain/knowledge_graph.py
+++ b/agent-core/brain/knowledge_graph.py
@@ -9,109 +9,20 @@ Provides entity and relationship management:
"""
from typing import Dict, Any, List, Optional, Set, Tuple
-from dataclasses import dataclass, field
from datetime import datetime, timezone
-from enum import Enum
import json
import logging
+from agent_core.brain.knowledge_models import (
+ EntityType,
+ RelationshipType,
+ Entity,
+ Relationship,
+)
+
logger = logging.getLogger(__name__)
-class EntityType(Enum):
- """Types of entities in the knowledge graph"""
- STUDENT = "student"
- TEACHER = "teacher"
- CLASS = "class"
- SUBJECT = "subject"
- ASSIGNMENT = "assignment"
- EXAM = "exam"
- TOPIC = "topic"
- CONCEPT = "concept"
- RESOURCE = "resource"
- CUSTOM = "custom"
-
-
-class RelationshipType(Enum):
- """Types of relationships between entities"""
- BELONGS_TO = "belongs_to" # Student belongs to class
- TEACHES = "teaches" # Teacher teaches subject
- ASSIGNED_TO = "assigned_to" # Assignment assigned to student
- COVERS = "covers" # Exam covers topic
- REQUIRES = "requires" # Topic requires concept
- RELATED_TO = "related_to" # General relationship
- PARENT_OF = "parent_of" # Hierarchical relationship
- CREATED_BY = "created_by" # Creator relationship
- GRADED_BY = "graded_by" # Grading relationship
-
-
-@dataclass
-class Entity:
- """Represents an entity in the knowledge graph"""
- id: str
- entity_type: EntityType
- name: str
- properties: Dict[str, Any] = field(default_factory=dict)
- created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
- updated_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
-
- def to_dict(self) -> Dict[str, Any]:
- return {
- "id": self.id,
- "entity_type": self.entity_type.value,
- "name": self.name,
- "properties": self.properties,
- "created_at": self.created_at.isoformat(),
- "updated_at": self.updated_at.isoformat()
- }
-
- @classmethod
- def from_dict(cls, data: Dict[str, Any]) -> "Entity":
- return cls(
- id=data["id"],
- entity_type=EntityType(data["entity_type"]),
- name=data["name"],
- properties=data.get("properties", {}),
- created_at=datetime.fromisoformat(data["created_at"]),
- updated_at=datetime.fromisoformat(data["updated_at"])
- )
-
-
-@dataclass
-class Relationship:
- """Represents a relationship between two entities"""
- id: str
- source_id: str
- target_id: str
- relationship_type: RelationshipType
- properties: Dict[str, Any] = field(default_factory=dict)
- weight: float = 1.0
- created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
-
- def to_dict(self) -> Dict[str, Any]:
- return {
- "id": self.id,
- "source_id": self.source_id,
- "target_id": self.target_id,
- "relationship_type": self.relationship_type.value,
- "properties": self.properties,
- "weight": self.weight,
- "created_at": self.created_at.isoformat()
- }
-
- @classmethod
- def from_dict(cls, data: Dict[str, Any]) -> "Relationship":
- return cls(
- id=data["id"],
- source_id=data["source_id"],
- target_id=data["target_id"],
- relationship_type=RelationshipType(data["relationship_type"]),
- properties=data.get("properties", {}),
- weight=data.get("weight", 1.0),
- created_at=datetime.fromisoformat(data["created_at"])
- )
-
-
class KnowledgeGraph:
"""
Knowledge graph for managing entity relationships.
diff --git a/agent-core/brain/knowledge_models.py b/agent-core/brain/knowledge_models.py
new file mode 100644
index 0000000..063ef7f
--- /dev/null
+++ b/agent-core/brain/knowledge_models.py
@@ -0,0 +1,104 @@
+"""
+Knowledge Graph Models for Breakpilot Agents
+
+Entity and relationship data classes, plus type enumerations.
+"""
+
+from typing import Dict, Any
+from dataclasses import dataclass, field
+from datetime import datetime, timezone
+from enum import Enum
+
+
+class EntityType(Enum):
+ """Types of entities in the knowledge graph.
+
+ The string values are the serialized form written by Entity.to_dict()
+ and parsed back by Entity.from_dict().
+ """
+ STUDENT = "student"
+ TEACHER = "teacher"
+ CLASS = "class"
+ SUBJECT = "subject"
+ ASSIGNMENT = "assignment"
+ EXAM = "exam"
+ TOPIC = "topic"
+ CONCEPT = "concept"
+ RESOURCE = "resource"
+ CUSTOM = "custom"
+
+
+class RelationshipType(Enum):
+ """Types of relationships between entities.
+
+ The string values are the serialized form written by
+ Relationship.to_dict() and parsed back by Relationship.from_dict().
+ """
+ BELONGS_TO = "belongs_to" # Student belongs to class
+ TEACHES = "teaches" # Teacher teaches subject
+ ASSIGNED_TO = "assigned_to" # Assignment assigned to student
+ COVERS = "covers" # Exam covers topic
+ REQUIRES = "requires" # Topic requires concept
+ RELATED_TO = "related_to" # General relationship
+ PARENT_OF = "parent_of" # Hierarchical relationship
+ CREATED_BY = "created_by" # Creator relationship
+ GRADED_BY = "graded_by" # Grading relationship
+
+
+@dataclass
+class Entity:
+    """A knowledge-graph node: typed and named, with free-form properties and UTC default timestamps."""
+    id: str
+    entity_type: EntityType
+    name: str
+    properties: Dict[str, Any] = field(default_factory=dict)
+    created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
+    updated_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
+
+    def to_dict(self) -> Dict[str, Any]:
+        """Serialize to plain types: the enum as its value, datetimes as ISO 8601 strings."""
+        return dict(
+            id=self.id,
+            entity_type=self.entity_type.value,
+            name=self.name,
+            properties=self.properties,
+            created_at=self.created_at.isoformat(),
+            updated_at=self.updated_at.isoformat(),
+        )
+
+    @classmethod
+    def from_dict(cls, data: Dict[str, Any]) -> "Entity":
+        """Rebuild an Entity from to_dict() output; only "properties" is optional."""
+        # Resolved in field order so the first failing lookup or parse is
+        # the same one the constructor call would hit.
+        kwargs: Dict[str, Any] = {"id": data["id"]}
+        kwargs["entity_type"] = EntityType(data["entity_type"])
+        kwargs["name"] = data["name"]
+        kwargs["properties"] = data.get("properties", {})
+        kwargs["created_at"] = datetime.fromisoformat(data["created_at"])
+        kwargs["updated_at"] = datetime.fromisoformat(data["updated_at"])
+        return cls(**kwargs)
+
+
+@dataclass
+class Relationship:
+    """A directed, typed and weighted edge from one entity id to another."""
+    id: str
+    source_id: str
+    target_id: str
+    relationship_type: RelationshipType
+    properties: Dict[str, Any] = field(default_factory=dict)
+    weight: float = 1.0
+    created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
+
+    def to_dict(self) -> Dict[str, Any]:
+        """Serialize to plain types: the enum as its value, created_at as an ISO 8601 string."""
+        return dict(
+            id=self.id,
+            source_id=self.source_id,
+            target_id=self.target_id,
+            relationship_type=self.relationship_type.value,
+            properties=self.properties,
+            weight=self.weight,
+            created_at=self.created_at.isoformat(),
+        )
+
+    @classmethod
+    def from_dict(cls, data: Dict[str, Any]) -> "Relationship":
+        """Rebuild a Relationship from to_dict() output; "properties" and "weight" are optional."""
+        # Resolved in field order so the first failing lookup or parse is
+        # the same one the constructor call would hit.
+        kwargs: Dict[str, Any] = {"id": data["id"]}
+        kwargs["source_id"] = data["source_id"]
+        kwargs["target_id"] = data["target_id"]
+        kwargs["relationship_type"] = RelationshipType(data["relationship_type"])
+        kwargs["properties"] = data.get("properties", {})
+        kwargs["weight"] = data.get("weight", 1.0)
+        kwargs["created_at"] = datetime.fromisoformat(data["created_at"])
+        return cls(**kwargs)
diff --git a/agent-core/brain/memory_models.py b/agent-core/brain/memory_models.py
new file mode 100644
index 0000000..6283df7
--- /dev/null
+++ b/agent-core/brain/memory_models.py
@@ -0,0 +1,53 @@
+"""
+Memory Models for Breakpilot Agents
+
+Data classes for memory items used by MemoryStore.
+"""
+
+from typing import Dict, Any, Optional
+from datetime import datetime, timezone
+from dataclasses import dataclass, field
+
+
+@dataclass
+class Memory:
+    """Represents a stored memory item.
+
+    Timestamps default to timezone-aware UTC. ``expires_at`` of None means
+    the item never expires.
+    """
+    key: str
+    value: Any
+    agent_id: str
+    created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
+    expires_at: Optional[datetime] = None
+    access_count: int = 0
+    last_accessed: Optional[datetime] = None
+    metadata: Dict[str, Any] = field(default_factory=dict)
+
+    def to_dict(self) -> Dict[str, Any]:
+        """Serialize to plain types; datetimes become ISO 8601 strings, unset ones None."""
+        return {
+            "key": self.key,
+            "value": self.value,
+            "agent_id": self.agent_id,
+            "created_at": self.created_at.isoformat(),
+            "expires_at": self.expires_at.isoformat() if self.expires_at else None,
+            "access_count": self.access_count,
+            "last_accessed": self.last_accessed.isoformat() if self.last_accessed else None,
+            "metadata": self.metadata
+        }
+
+    @classmethod
+    def from_dict(cls, data: Dict[str, Any]) -> "Memory":
+        """Rebuild a Memory from to_dict() output.
+
+        "key", "value", "agent_id" and "created_at" are required; the
+        remaining keys fall back to the field defaults.
+        """
+        return cls(
+            key=data["key"],
+            value=data["value"],
+            agent_id=data["agent_id"],
+            created_at=datetime.fromisoformat(data["created_at"]),
+            expires_at=datetime.fromisoformat(data["expires_at"]) if data.get("expires_at") else None,
+            access_count=data.get("access_count", 0),
+            last_accessed=datetime.fromisoformat(data["last_accessed"]) if data.get("last_accessed") else None,
+            metadata=data.get("metadata", {})
+        )
+
+    def is_expired(self) -> bool:
+        """Check if the memory has expired.
+
+        Returns:
+            False when no expiry is set, otherwise whether the current UTC
+            time is strictly later than ``expires_at``.
+        """
+        if not self.expires_at:
+            return False
+        deadline = self.expires_at
+        if deadline.tzinfo is None or deadline.utcoffset() is None:
+            # A naive value (e.g. parsed from an ISO string without offset)
+            # cannot be compared with an aware "now" - Python raises
+            # TypeError. Interpret it as UTC, matching this class's defaults.
+            deadline = deadline.replace(tzinfo=timezone.utc)
+        return datetime.now(timezone.utc) > deadline
diff --git a/agent-core/brain/memory_store.py b/agent-core/brain/memory_store.py
index f41afbf..ec351de 100644
--- a/agent-core/brain/memory_store.py
+++ b/agent-core/brain/memory_store.py
@@ -1,92 +1,24 @@
"""
Memory Store for Breakpilot Agents
-Provides long-term memory with:
-- TTL-based expiration
-- Access count tracking
-- Pattern-based search
-- Hybrid Valkey + PostgreSQL persistence
+Hybrid Valkey + PostgreSQL persistence with TTL, access tracking, and pattern search.
"""
from typing import List, Dict, Any, Optional
from datetime import datetime, timezone, timedelta
-from dataclasses import dataclass, field
import json
import logging
import hashlib
+from agent_core.brain.memory_models import Memory
+
logger = logging.getLogger(__name__)
-@dataclass
-class Memory:
- """Represents a stored memory item"""
- key: str
- value: Any
- agent_id: str
- created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
- expires_at: Optional[datetime] = None
- access_count: int = 0
- last_accessed: Optional[datetime] = None
- metadata: Dict[str, Any] = field(default_factory=dict)
-
- def to_dict(self) -> Dict[str, Any]:
- return {
- "key": self.key,
- "value": self.value,
- "agent_id": self.agent_id,
- "created_at": self.created_at.isoformat(),
- "expires_at": self.expires_at.isoformat() if self.expires_at else None,
- "access_count": self.access_count,
- "last_accessed": self.last_accessed.isoformat() if self.last_accessed else None,
- "metadata": self.metadata
- }
-
- @classmethod
- def from_dict(cls, data: Dict[str, Any]) -> "Memory":
- return cls(
- key=data["key"],
- value=data["value"],
- agent_id=data["agent_id"],
- created_at=datetime.fromisoformat(data["created_at"]),
- expires_at=datetime.fromisoformat(data["expires_at"]) if data.get("expires_at") else None,
- access_count=data.get("access_count", 0),
- last_accessed=datetime.fromisoformat(data["last_accessed"]) if data.get("last_accessed") else None,
- metadata=data.get("metadata", {})
- )
-
- def is_expired(self) -> bool:
- """Check if the memory has expired"""
- if not self.expires_at:
- return False
- return datetime.now(timezone.utc) > self.expires_at
-
-
class MemoryStore:
- """
- Long-term memory store for agents.
+ """Long-term memory store with TTL, access tracking, and hybrid persistence."""
- Stores facts, decisions, and learning progress with:
- - TTL-based expiration
- - Access tracking for importance scoring
- - Pattern-based retrieval
- - Hybrid persistence (Valkey for fast access, PostgreSQL for durability)
- """
-
- def __init__(
- self,
- redis_client=None,
- db_pool=None,
- namespace: str = "breakpilot"
- ):
- """
- Initialize the memory store.
-
- Args:
- redis_client: Async Redis/Valkey client
- db_pool: Async PostgreSQL connection pool
- namespace: Key namespace for isolation
- """
+ def __init__(self, redis_client=None, db_pool=None, namespace: str = "breakpilot"):
self.redis = redis_client
self.db_pool = db_pool
self.namespace = namespace
@@ -103,26 +35,10 @@ class MemoryStore:
return key
async def remember(
- self,
- key: str,
- value: Any,
- agent_id: str,
- ttl_days: int = 30,
- metadata: Optional[Dict[str, Any]] = None
+ self, key: str, value: Any, agent_id: str,
+ ttl_days: int = 30, metadata: Optional[Dict[str, Any]] = None
) -> Memory:
- """
- Stores a memory.
-
- Args:
- key: Unique key for the memory
- value: Value to store (must be JSON-serializable)
- agent_id: ID of the agent storing the memory
- ttl_days: Time to live in days (0 = no expiration)
- metadata: Optional additional metadata
-
- Returns:
- The created Memory object
- """
+ """Stores a memory with optional TTL and metadata."""
expires_at = None
if ttl_days > 0:
expires_at = datetime.now(timezone.utc) + timedelta(days=ttl_days)
@@ -143,32 +59,14 @@ class MemoryStore:
return memory
async def recall(self, key: str) -> Optional[Any]:
- """
- Retrieves a memory value by key.
-
- Args:
- key: The memory key
-
- Returns:
- The stored value or None if not found/expired
- """
+ """Retrieves a memory value by key, or None if not found/expired."""
memory = await self.get_memory(key)
if memory:
return memory.value
return None
async def get_memory(self, key: str) -> Optional[Memory]:
- """
- Retrieves a full Memory object by key.
-
- Updates access count and last_accessed timestamp.
-
- Args:
- key: The memory key
-
- Returns:
- Memory object or None if not found/expired
- """
+ """Retrieves a full Memory object by key, updating access count."""
# Check local cache
if key in self._local_cache:
memory = self._local_cache[key]
diff --git a/agent-core/orchestrator/__init__.py b/agent-core/orchestrator/__init__.py
index 8a7784a..b6289c3 100644
--- a/agent-core/orchestrator/__init__.py
+++ b/agent-core/orchestrator/__init__.py
@@ -12,11 +12,13 @@ from agent_core.orchestrator.message_bus import (
AgentMessage,
MessagePriority,
)
-from agent_core.orchestrator.supervisor import (
- AgentSupervisor,
- AgentInfo,
+from agent_core.orchestrator.supervisor_models import (
AgentStatus,
+ RestartPolicy,
+ AgentInfo,
+ AgentFactory,
)
+from agent_core.orchestrator.supervisor import AgentSupervisor
from agent_core.orchestrator.task_router import (
TaskRouter,
RoutingResult,
@@ -30,6 +32,8 @@ __all__ = [
"AgentSupervisor",
"AgentInfo",
"AgentStatus",
+ "RestartPolicy",
+ "AgentFactory",
"TaskRouter",
"RoutingResult",
"RoutingStrategy",
diff --git a/agent-core/orchestrator/supervisor.py b/agent-core/orchestrator/supervisor.py
index 72ecef4..2d98008 100644
--- a/agent-core/orchestrator/supervisor.py
+++ b/agent-core/orchestrator/supervisor.py
@@ -1,17 +1,11 @@
"""
Agent Supervisor for Breakpilot
-Provides:
-- Agent lifecycle management
-- Health monitoring
-- Restart policies
-- Load balancing
+Agent lifecycle management, health monitoring, restart policies, load balancing.
"""
-from typing import Dict, Optional, Callable, Awaitable, List, Any
-from dataclasses import dataclass, field
+from typing import Dict, Optional, List, Any
from datetime import datetime, timezone, timedelta
-from enum import Enum
import asyncio
import logging
@@ -21,91 +15,24 @@ from agent_core.orchestrator.message_bus import (
AgentMessage,
MessagePriority,
)
+from agent_core.orchestrator.supervisor_models import (
+ AgentStatus,
+ RestartPolicy,
+ AgentInfo,
+ AgentFactory,
+)
logger = logging.getLogger(__name__)
-class AgentStatus(Enum):
- """Agent lifecycle states"""
- INITIALIZING = "initializing"
- STARTING = "starting"
- RUNNING = "running"
- PAUSED = "paused"
- STOPPING = "stopping"
- STOPPED = "stopped"
- ERROR = "error"
- RESTARTING = "restarting"
-
-
-class RestartPolicy(Enum):
- """Agent restart policies"""
- NEVER = "never"
- ON_FAILURE = "on_failure"
- ALWAYS = "always"
-
-
-@dataclass
-class AgentInfo:
- """Information about a registered agent"""
- agent_id: str
- agent_type: str
- status: AgentStatus = AgentStatus.INITIALIZING
- current_task: Optional[str] = None
- started_at: Optional[datetime] = None
- last_activity: Optional[datetime] = None
- error_count: int = 0
- restart_count: int = 0
- max_restarts: int = 3
- restart_policy: RestartPolicy = RestartPolicy.ON_FAILURE
- metadata: Dict[str, Any] = field(default_factory=dict)
- capacity: int = 10 # Max concurrent tasks
- current_load: int = 0
-
- def is_healthy(self) -> bool:
- """Check if agent is healthy"""
- return self.status == AgentStatus.RUNNING and self.error_count < 3
-
- def is_available(self) -> bool:
- """Check if agent can accept new tasks"""
- return (
- self.status == AgentStatus.RUNNING and
- self.current_load < self.capacity
- )
-
- def utilization(self) -> float:
- """Returns agent utilization (0-1)"""
- return self.current_load / max(self.capacity, 1)
-
-
-AgentFactory = Callable[[str], Awaitable[Any]]
-
-
class AgentSupervisor:
- """
- Supervises and coordinates all agents.
-
- Responsibilities:
- - Agent registration and lifecycle
- - Health monitoring via heartbeat
- - Restart policies
- - Load balancing
- - Alert escalation
- """
+ """Supervises agents: lifecycle, health monitoring, restart policies, load balancing."""
def __init__(
- self,
- message_bus: MessageBus,
+ self, message_bus: MessageBus,
heartbeat_monitor: Optional[HeartbeatMonitor] = None,
check_interval_seconds: int = 10
):
- """
- Initialize the supervisor.
-
- Args:
- message_bus: Message bus for inter-agent communication
- heartbeat_monitor: Heartbeat monitor for liveness checks
- check_interval_seconds: How often to run health checks
- """
self.bus = message_bus
self.heartbeat = heartbeat_monitor or HeartbeatMonitor()
self.check_interval = check_interval_seconds
diff --git a/agent-core/orchestrator/supervisor_models.py b/agent-core/orchestrator/supervisor_models.py
new file mode 100644
index 0000000..68048f1
--- /dev/null
+++ b/agent-core/orchestrator/supervisor_models.py
@@ -0,0 +1,65 @@
+"""
+Supervisor Models for Breakpilot Agents
+
+Data classes and enumerations for agent lifecycle management.
+"""
+
+from typing import Dict, Optional, Any, Callable, Awaitable
+from dataclasses import dataclass, field
+from datetime import datetime
+from enum import Enum
+
+
+class AgentStatus(Enum):
+    """Lifecycle states an agent can be in; each member's value is its lowercase string name."""
+    INITIALIZING = "initializing"
+    STARTING = "starting"
+    RUNNING = "running"
+    PAUSED = "paused"
+    STOPPING = "stopping"
+    STOPPED = "stopped"
+    ERROR = "error"
+    RESTARTING = "restarting"
+
+
+class RestartPolicy(Enum):
+    """Restart policies that can be assigned to an agent via AgentInfo.restart_policy."""
+    NEVER = "never"
+    ON_FAILURE = "on_failure"
+    ALWAYS = "always"
+
+
+@dataclass
+class AgentInfo:
+    """Registry record for one agent: identity, lifecycle status, restart and load counters."""
+    agent_id: str
+    agent_type: str
+    status: AgentStatus = AgentStatus.INITIALIZING
+    current_task: Optional[str] = None
+    started_at: Optional[datetime] = None
+    last_activity: Optional[datetime] = None
+    error_count: int = 0
+    restart_count: int = 0
+    max_restarts: int = 3
+    restart_policy: RestartPolicy = RestartPolicy.ON_FAILURE
+    metadata: Dict[str, Any] = field(default_factory=dict)
+    capacity: int = 10  # upper bound on concurrently handled tasks
+    current_load: int = 0
+
+    def is_healthy(self) -> bool:
+        """True while the agent is RUNNING and has recorded fewer than three errors."""
+        running = self.status == AgentStatus.RUNNING
+        return running and self.error_count < 3
+
+    def is_available(self) -> bool:
+        """True when the agent is RUNNING and its load is still below capacity."""
+        if self.status != AgentStatus.RUNNING:
+            return False
+        return self.current_load < self.capacity
+
+    def utilization(self) -> float:
+        """Current load divided by capacity; the divisor is clamped to at least 1."""
+        return self.current_load / max(self.capacity, 1)
+
+
+AgentFactory = Callable[[str], Awaitable[Any]]  # async callable taking one str; presumably the agent id (TODO confirm)
diff --git a/agent-core/sessions/__init__.py b/agent-core/sessions/__init__.py
index 3d26c3b..bbe1132 100644
--- a/agent-core/sessions/__init__.py
+++ b/agent-core/sessions/__init__.py
@@ -8,11 +8,12 @@ Provides:
- SessionState: Session state enumeration
"""
-from agent_core.sessions.session_manager import (
+from agent_core.sessions.session_models import (
AgentSession,
- SessionManager,
SessionState,
+ SessionCheckpoint,
)
+from agent_core.sessions.session_manager import SessionManager
from agent_core.sessions.heartbeat import HeartbeatMonitor
from agent_core.sessions.checkpoint import CheckpointManager
@@ -20,6 +21,7 @@ __all__ = [
"AgentSession",
"SessionManager",
"SessionState",
+ "SessionCheckpoint",
"HeartbeatMonitor",
"CheckpointManager",
]
diff --git a/agent-core/sessions/session_manager.py b/agent-core/sessions/session_manager.py
index 0e12e74..2bc6b74 100644
--- a/agent-core/sessions/session_manager.py
+++ b/agent-core/sessions/session_manager.py
@@ -2,189 +2,25 @@
Session Management for Breakpilot Agents
Provides session lifecycle management with:
-- State tracking (ACTIVE, PAUSED, COMPLETED, FAILED)
-- Checkpoint-based recovery
-- Heartbeat integration
- Hybrid Valkey + PostgreSQL persistence
+- Session CRUD operations
+- Stale session cleanup
"""
-from dataclasses import dataclass, field
from datetime import datetime, timezone, timedelta
from typing import Dict, Any, Optional, List
-from enum import Enum
-import uuid
import json
import logging
+from agent_core.sessions.session_models import (
+ SessionState,
+ SessionCheckpoint,
+ AgentSession,
+)
+
logger = logging.getLogger(__name__)
-class SessionState(Enum):
- """Agent session states"""
- ACTIVE = "active"
- PAUSED = "paused"
- COMPLETED = "completed"
- FAILED = "failed"
-
-
-@dataclass
-class SessionCheckpoint:
- """Represents a checkpoint in an agent session"""
- name: str
- timestamp: datetime
- data: Dict[str, Any]
-
- def to_dict(self) -> Dict[str, Any]:
- return {
- "name": self.name,
- "timestamp": self.timestamp.isoformat(),
- "data": self.data
- }
-
- @classmethod
- def from_dict(cls, data: Dict[str, Any]) -> "SessionCheckpoint":
- return cls(
- name=data["name"],
- timestamp=datetime.fromisoformat(data["timestamp"]),
- data=data["data"]
- )
-
-
-@dataclass
-class AgentSession:
- """
- Represents an active agent session.
-
- Attributes:
- session_id: Unique session identifier
- agent_type: Type of agent (tutor, grader, quality-judge, alert, orchestrator)
- user_id: Associated user ID
- state: Current session state
- created_at: Session creation timestamp
- last_heartbeat: Last heartbeat timestamp
- context: Session context data
- checkpoints: List of session checkpoints for recovery
- """
- session_id: str = field(default_factory=lambda: str(uuid.uuid4()))
- agent_type: str = ""
- user_id: str = ""
- state: SessionState = SessionState.ACTIVE
- created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
- last_heartbeat: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
- context: Dict[str, Any] = field(default_factory=dict)
- checkpoints: List[SessionCheckpoint] = field(default_factory=list)
- metadata: Dict[str, Any] = field(default_factory=dict)
-
- def checkpoint(self, name: str, data: Dict[str, Any]) -> SessionCheckpoint:
- """
- Creates a checkpoint for recovery.
-
- Args:
- name: Checkpoint name (e.g., "task_received", "processing_complete")
- data: Checkpoint data to store
-
- Returns:
- The created checkpoint
- """
- checkpoint = SessionCheckpoint(
- name=name,
- timestamp=datetime.now(timezone.utc),
- data=data
- )
- self.checkpoints.append(checkpoint)
- logger.debug(f"Session {self.session_id}: Checkpoint '{name}' created")
- return checkpoint
-
- def heartbeat(self) -> None:
- """Updates the heartbeat timestamp"""
- self.last_heartbeat = datetime.now(timezone.utc)
-
- def pause(self) -> None:
- """Pauses the session"""
- self.state = SessionState.PAUSED
- self.checkpoint("session_paused", {"previous_state": "active"})
-
- def resume(self) -> None:
- """Resumes a paused session"""
- if self.state == SessionState.PAUSED:
- self.state = SessionState.ACTIVE
- self.heartbeat()
- self.checkpoint("session_resumed", {})
-
- def complete(self, result: Optional[Dict[str, Any]] = None) -> None:
- """Marks the session as completed"""
- self.state = SessionState.COMPLETED
- self.checkpoint("session_completed", {"result": result or {}})
-
- def fail(self, error: str, error_details: Optional[Dict[str, Any]] = None) -> None:
- """Marks the session as failed"""
- self.state = SessionState.FAILED
- self.checkpoint("session_failed", {
- "error": error,
- "details": error_details or {}
- })
-
- def get_last_checkpoint(self, name: Optional[str] = None) -> Optional[SessionCheckpoint]:
- """
- Gets the last checkpoint, optionally filtered by name.
-
- Args:
- name: Optional checkpoint name to filter by
-
- Returns:
- The last matching checkpoint or None
- """
- if not self.checkpoints:
- return None
-
- if name:
- matching = [cp for cp in self.checkpoints if cp.name == name]
- return matching[-1] if matching else None
-
- return self.checkpoints[-1]
-
- def get_duration(self) -> timedelta:
- """Returns the session duration"""
- end_time = datetime.now(timezone.utc)
- if self.state in (SessionState.COMPLETED, SessionState.FAILED):
- last_cp = self.get_last_checkpoint()
- if last_cp:
- end_time = last_cp.timestamp
- return end_time - self.created_at
-
- def to_dict(self) -> Dict[str, Any]:
- """Serializes the session to a dictionary"""
- return {
- "session_id": self.session_id,
- "agent_type": self.agent_type,
- "user_id": self.user_id,
- "state": self.state.value,
- "created_at": self.created_at.isoformat(),
- "last_heartbeat": self.last_heartbeat.isoformat(),
- "context": self.context,
- "checkpoints": [cp.to_dict() for cp in self.checkpoints],
- "metadata": self.metadata
- }
-
- @classmethod
- def from_dict(cls, data: Dict[str, Any]) -> "AgentSession":
- """Deserializes a session from a dictionary"""
- return cls(
- session_id=data["session_id"],
- agent_type=data["agent_type"],
- user_id=data["user_id"],
- state=SessionState(data["state"]),
- created_at=datetime.fromisoformat(data["created_at"]),
- last_heartbeat=datetime.fromisoformat(data["last_heartbeat"]),
- context=data.get("context", {}),
- checkpoints=[
- SessionCheckpoint.from_dict(cp)
- for cp in data.get("checkpoints", [])
- ],
- metadata=data.get("metadata", {})
- )
-
-
class SessionManager:
"""
Manages agent sessions with hybrid Valkey + PostgreSQL persistence.
@@ -303,7 +139,6 @@ class SessionManager:
"""
session.heartbeat()
self._local_cache[session.session_id] = session
- self._local_cache[session.session_id] = session
await self._persist_session(session)
async def delete_session(self, session_id: str) -> bool:
diff --git a/agent-core/sessions/session_models.py b/agent-core/sessions/session_models.py
new file mode 100644
index 0000000..514b4c9
--- /dev/null
+++ b/agent-core/sessions/session_models.py
@@ -0,0 +1,180 @@
+"""
+Session Models for Breakpilot Agents
+
+Data classes for agent sessions, checkpoints, and state tracking.
+"""
+
+from dataclasses import dataclass, field
+from datetime import datetime, timezone, timedelta
+from typing import Dict, Any, Optional, List
+from enum import Enum
+import uuid
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+class SessionState(Enum):
+    """States of an agent session; member values are the strings written by AgentSession.to_dict()."""
+    ACTIVE = "active"
+    PAUSED = "paused"
+    COMPLETED = "completed"
+    FAILED = "failed"
+
+
+@dataclass
+class SessionCheckpoint:
+    """A named, timestamped snapshot of data recorded during an agent session."""
+    name: str
+    timestamp: datetime
+    data: Dict[str, Any]
+
+    def to_dict(self) -> Dict[str, Any]:
+        """Serialize to a plain dict with the timestamp as an ISO-8601 string."""
+        snapshot: Dict[str, Any] = {}
+        snapshot["name"] = self.name
+        snapshot["timestamp"] = self.timestamp.isoformat()
+        snapshot["data"] = self.data
+        return snapshot
+
+    @classmethod
+    def from_dict(cls, data: Dict[str, Any]) -> "SessionCheckpoint":
+        """Inverse of to_dict(); raises KeyError if one of the three keys is missing."""
+        label = data["name"]
+        when = datetime.fromisoformat(data["timestamp"])
+        return cls(label, when, data["data"])
+
+
+@dataclass
+class AgentSession:
+    """
+    Represents an active agent session.
+
+    Attributes:
+        session_id: Unique session identifier
+        agent_type: Type of agent (tutor, grader, quality-judge, alert, orchestrator)
+        user_id: Associated user ID
+        state: Current session state
+        created_at: Session creation timestamp
+        last_heartbeat: Last heartbeat timestamp
+        context: Session context data
+        checkpoints: List of session checkpoints for recovery
+        metadata: Free-form additional data, serialized alongside the session
+    """
+    session_id: str = field(default_factory=lambda: str(uuid.uuid4()))
+    agent_type: str = ""
+    user_id: str = ""
+    state: SessionState = SessionState.ACTIVE
+    created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
+    last_heartbeat: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
+    context: Dict[str, Any] = field(default_factory=dict)
+    checkpoints: List[SessionCheckpoint] = field(default_factory=list)
+    metadata: Dict[str, Any] = field(default_factory=dict)
+
+    def checkpoint(self, name: str, data: Dict[str, Any]) -> SessionCheckpoint:
+        """
+        Creates a checkpoint for recovery.
+
+        Args:
+            name: Checkpoint name (e.g., "task_received", "processing_complete")
+            data: Checkpoint data to store
+
+        Returns:
+            The created checkpoint
+        """
+        checkpoint = SessionCheckpoint(
+            name=name,
+            timestamp=datetime.now(timezone.utc),
+            data=data
+        )
+        self.checkpoints.append(checkpoint)
+        logger.debug(f"Session {self.session_id}: Checkpoint '{name}' created")
+        return checkpoint
+
+    def heartbeat(self) -> None:
+        """Updates the heartbeat timestamp"""
+        self.last_heartbeat = datetime.now(timezone.utc)
+
+    def pause(self) -> None:
+        """Pauses the session and checkpoints the state it was actually in beforehand"""
+        previous_state, self.state = self.state.value, SessionState.PAUSED
+        self.checkpoint("session_paused", {"previous_state": previous_state})
+
+    def resume(self) -> None:
+        """Resumes a paused session; does nothing unless the session is PAUSED"""
+        if self.state == SessionState.PAUSED:
+            self.state = SessionState.ACTIVE
+            self.heartbeat()
+            self.checkpoint("session_resumed", {})
+
+    def complete(self, result: Optional[Dict[str, Any]] = None) -> None:
+        """Marks the session as completed and checkpoints the result (empty dict if none)"""
+        self.state = SessionState.COMPLETED
+        self.checkpoint("session_completed", {"result": result or {}})
+
+    def fail(self, error: str, error_details: Optional[Dict[str, Any]] = None) -> None:
+        """Marks the session as failed and checkpoints the error with optional details"""
+        self.state = SessionState.FAILED
+        details = error_details or {}
+        self.checkpoint("session_failed", {"error": error, "details": details})
+
+    def get_last_checkpoint(self, name: Optional[str] = None) -> Optional[SessionCheckpoint]:
+        """
+        Gets the last checkpoint, optionally filtered by name.
+
+        Args:
+            name: Optional checkpoint name to filter by
+
+        Returns:
+            The last matching checkpoint or None
+        """
+        if not self.checkpoints:
+            return None
+
+        if name:
+            matching = [cp for cp in self.checkpoints if cp.name == name]
+            return matching[-1] if matching else None
+
+        return self.checkpoints[-1]
+
+    def get_duration(self) -> timedelta:
+        """Returns the session duration, measured from created_at"""
+        end_time = datetime.now(timezone.utc)
+        # A finished session ends at its final checkpoint rather than "now".
+        if self.state in (SessionState.COMPLETED, SessionState.FAILED):
+            last_cp = self.get_last_checkpoint()
+            if last_cp:
+                end_time = last_cp.timestamp
+        return end_time - self.created_at
+
+    def to_dict(self) -> Dict[str, Any]:
+        """Serializes the session to a dictionary"""
+        return {
+            "session_id": self.session_id,
+            "agent_type": self.agent_type,
+            "user_id": self.user_id,
+            "state": self.state.value,
+            "created_at": self.created_at.isoformat(),
+            "last_heartbeat": self.last_heartbeat.isoformat(),
+            "context": self.context,
+            "checkpoints": [cp.to_dict() for cp in self.checkpoints],
+            "metadata": self.metadata
+        }
+
+    @classmethod
+    def from_dict(cls, data: Dict[str, Any]) -> "AgentSession":
+        """Deserializes a session from a dictionary produced by to_dict()"""
+        return cls(
+            session_id=data["session_id"],
+            agent_type=data["agent_type"],
+            user_id=data["user_id"],
+            state=SessionState(data["state"]),
+            created_at=datetime.fromisoformat(data["created_at"]),
+            last_heartbeat=datetime.fromisoformat(data["last_heartbeat"]),
+            context=data.get("context", {}),
+            checkpoints=[
+                SessionCheckpoint.from_dict(cp)
+                for cp in data.get("checkpoints", [])
+            ],
+            metadata=data.get("metadata", {})
+        )
diff --git a/backend-lehrer/ai_processor/export/print_templates.py b/backend-lehrer/ai_processor/export/print_templates.py
new file mode 100644
index 0000000..b3df69a
--- /dev/null
+++ b/backend-lehrer/ai_processor/export/print_templates.py
@@ -0,0 +1,244 @@
+"""
+AI Processor - HTML Templates for Print Versions
+
+Contains HTML/CSS header templates for Q&A, Cloze, and Multiple Choice print output.
+"""
+
+
+def get_qa_html_header(title: str) -> str:
+ """Get HTML header for Q&A print version."""
+ return f"""
+
+
+
+
{title} - Fragen
+
+
+
+"""
+
+
+def get_cloze_html_header(title: str) -> str:
+ """Get HTML header for cloze print version."""
+ return f"""
+
+
+
+
{title} - Lueckentext
+
+
+
+"""
+
+
+def get_mc_html_header(title: str) -> str:
+ """Get HTML header for MC print version."""
+ return f"""
+
+
+
+
{title} - Multiple Choice
+
+
+
+"""
diff --git a/backend-lehrer/ai_processor/export/print_versions.py b/backend-lehrer/ai_processor/export/print_versions.py
index 1619a87..5f06009 100644
--- a/backend-lehrer/ai_processor/export/print_versions.py
+++ b/backend-lehrer/ai_processor/export/print_versions.py
@@ -10,6 +10,7 @@ import logging
import random
from ..config import BEREINIGT_DIR
+from .print_templates import get_qa_html_header, get_cloze_html_header, get_mc_html_header
logger = logging.getLogger(__name__)
@@ -37,7 +38,7 @@ def generate_print_version_qa(qa_path: Path, include_answers: bool = False) -> P
grade = metadata.get("grade_level", "")
html_parts = []
- html_parts.append(_get_qa_html_header(title))
+ html_parts.append(get_qa_html_header(title))
# Header
version_text = "Loesungsblatt" if include_answers else "Fragenblatt"
@@ -106,7 +107,7 @@ def generate_print_version_cloze(cloze_path: Path, include_answers: bool = False
total_gaps = metadata.get("total_gaps", 0)
html_parts = []
- html_parts.append(_get_cloze_html_header(title))
+ html_parts.append(get_cloze_html_header(title))
# Header
version_text = "Loesungsblatt" if include_answers else "Lueckentext"
@@ -200,7 +201,7 @@ def generate_print_version_mc(mc_path: Path, include_answers: bool = False) -> s
grade = metadata.get("grade_level", "")
html_parts = []
- html_parts.append(_get_mc_html_header(title))
+ html_parts.append(get_mc_html_header(title))
# Header
version_text = "Loesungsblatt" if include_answers else "Multiple Choice Test"
@@ -267,242 +268,3 @@ def generate_print_version_mc(mc_path: Path, include_answers: bool = False) -> s
html_parts.append("")
return "\n".join(html_parts)
-
-
-def _get_qa_html_header(title: str) -> str:
- """Get HTML header for Q&A print version."""
- return f"""
-
-
-
-
{title} - Fragen
-
-
-
-"""
-
-
-def _get_cloze_html_header(title: str) -> str:
- """Get HTML header for cloze print version."""
- return f"""
-
-
-
-
{title} - Lueckentext
-
-
-
-"""
-
-
-def _get_mc_html_header(title: str) -> str:
- """Get HTML header for MC print version."""
- return f"""
-
-
-
-
{title} - Multiple Choice
-
-
-
-"""
diff --git a/backend-lehrer/alerts_agent/api/digests.py b/backend-lehrer/alerts_agent/api/digests.py
index 5d1d762..31d7ca7 100644
--- a/backend-lehrer/alerts_agent/api/digests.py
+++ b/backend-lehrer/alerts_agent/api/digests.py
@@ -9,13 +9,10 @@ Endpoints:
- POST /digests/{id}/send-email - Digest per E-Mail versenden
"""
-import uuid
import io
-from typing import Optional, List
from datetime import datetime, timedelta
-from fastapi import APIRouter, Depends, HTTPException, Query, Response
+from fastapi import APIRouter, Depends, HTTPException, Query
from fastapi.responses import StreamingResponse
-from pydantic import BaseModel, Field
from sqlalchemy.orm import Session as DBSession
from ..db.database import get_db
@@ -23,126 +20,27 @@ from ..db.models import (
AlertDigestDB, UserAlertSubscriptionDB, DigestStatusEnum
)
from ..processing.digest_generator import DigestGenerator
+from .digests_models import (
+ DigestDetail,
+ DigestListResponse,
+ GenerateDigestRequest,
+ GenerateDigestResponse,
+ SendEmailRequest,
+ SendEmailResponse,
+ digest_to_list_item,
+ digest_to_detail,
+)
+from .digests_email import generate_pdf_from_html, send_digest_by_email
router = APIRouter(prefix="/digests", tags=["digests"])
-# ============================================================================
-# Request/Response Models
-# ============================================================================
-
-class DigestListItem(BaseModel):
- """Kurze Digest-Info fuer Liste."""
- id: str
- period_start: datetime
- period_end: datetime
- total_alerts: int
- critical_count: int
- urgent_count: int
- status: str
- created_at: datetime
-
-
-class DigestDetail(BaseModel):
- """Vollstaendige Digest-Details."""
- id: str
- subscription_id: Optional[str]
- user_id: str
- period_start: datetime
- period_end: datetime
- summary_html: str
- summary_pdf_url: Optional[str]
- total_alerts: int
- critical_count: int
- urgent_count: int
- important_count: int
- review_count: int
- info_count: int
- status: str
- sent_at: Optional[datetime]
- created_at: datetime
-
-
-class DigestListResponse(BaseModel):
- """Response fuer Digest-Liste."""
- digests: List[DigestListItem]
- total: int
-
-
-class GenerateDigestRequest(BaseModel):
- """Request fuer manuelle Digest-Generierung."""
- weeks_back: int = Field(default=1, ge=1, le=4, description="Wochen zurueck")
- force_regenerate: bool = Field(default=False, description="Vorhandenen Digest ueberschreiben")
-
-
-class GenerateDigestResponse(BaseModel):
- """Response fuer Digest-Generierung."""
- status: str
- digest_id: Optional[str]
- message: str
-
-
-class SendEmailRequest(BaseModel):
- """Request fuer E-Mail-Versand."""
- email: Optional[str] = Field(default=None, description="E-Mail-Adresse (optional, sonst aus Subscription)")
-
-
-class SendEmailResponse(BaseModel):
- """Response fuer E-Mail-Versand."""
- status: str
- sent_to: str
- message: str
-
-
-# ============================================================================
-# Helper Functions
-# ============================================================================
-
def get_user_id_from_request() -> str:
- """
- Extrahiert User-ID aus Request.
- TODO: JWT-Token auswerten, aktuell Dummy.
- """
+ """Extrahiert User-ID aus Request. TODO: JWT-Token auswerten."""
return "demo-user"
-def _digest_to_list_item(digest: AlertDigestDB) -> DigestListItem:
- """Konvertiere DB-Model zu List-Item."""
- return DigestListItem(
- id=digest.id,
- period_start=digest.period_start,
- period_end=digest.period_end,
- total_alerts=digest.total_alerts or 0,
- critical_count=digest.critical_count or 0,
- urgent_count=digest.urgent_count or 0,
- status=digest.status.value if digest.status else "pending",
- created_at=digest.created_at
- )
-
-
-def _digest_to_detail(digest: AlertDigestDB) -> DigestDetail:
- """Konvertiere DB-Model zu Detail."""
- return DigestDetail(
- id=digest.id,
- subscription_id=digest.subscription_id,
- user_id=digest.user_id,
- period_start=digest.period_start,
- period_end=digest.period_end,
- summary_html=digest.summary_html or "",
- summary_pdf_url=digest.summary_pdf_url,
- total_alerts=digest.total_alerts or 0,
- critical_count=digest.critical_count or 0,
- urgent_count=digest.urgent_count or 0,
- important_count=digest.important_count or 0,
- review_count=digest.review_count or 0,
- info_count=digest.info_count or 0,
- status=digest.status.value if digest.status else "pending",
- sent_at=digest.sent_at,
- created_at=digest.created_at
- )
-
-
# ============================================================================
# Endpoints
# ============================================================================
@@ -153,11 +51,7 @@ async def list_digests(
offset: int = Query(0, ge=0),
db: DBSession = Depends(get_db)
):
- """
- Liste alle Digests des aktuellen Users.
-
- Sortiert nach Erstellungsdatum (neueste zuerst).
- """
+ """Liste alle Digests des aktuellen Users."""
user_id = get_user_id_from_request()
query = db.query(AlertDigestDB).filter(
@@ -168,18 +62,14 @@ async def list_digests(
digests = query.offset(offset).limit(limit).all()
return DigestListResponse(
- digests=[_digest_to_list_item(d) for d in digests],
+ digests=[digest_to_list_item(d) for d in digests],
total=total
)
@router.get("/latest", response_model=DigestDetail)
-async def get_latest_digest(
- db: DBSession = Depends(get_db)
-):
- """
- Hole den neuesten Digest des Users.
- """
+async def get_latest_digest(db: DBSession = Depends(get_db)):
+ """Hole den neuesten Digest des Users."""
user_id = get_user_id_from_request()
digest = db.query(AlertDigestDB).filter(
@@ -189,17 +79,12 @@ async def get_latest_digest(
if not digest:
raise HTTPException(status_code=404, detail="Kein Digest vorhanden")
- return _digest_to_detail(digest)
+ return digest_to_detail(digest)
@router.get("/{digest_id}", response_model=DigestDetail)
-async def get_digest(
- digest_id: str,
- db: DBSession = Depends(get_db)
-):
- """
- Hole Details eines spezifischen Digests.
- """
+async def get_digest(digest_id: str, db: DBSession = Depends(get_db)):
+ """Hole Details eines spezifischen Digests."""
user_id = get_user_id_from_request()
digest = db.query(AlertDigestDB).filter(
@@ -210,17 +95,12 @@ async def get_digest(
if not digest:
raise HTTPException(status_code=404, detail="Digest nicht gefunden")
- return _digest_to_detail(digest)
+ return digest_to_detail(digest)
@router.get("/{digest_id}/pdf")
-async def get_digest_pdf(
- digest_id: str,
- db: DBSession = Depends(get_db)
-):
- """
- Generiere und lade PDF-Version des Digests herunter.
- """
+async def get_digest_pdf(digest_id: str, db: DBSession = Depends(get_db)):
+ """Generiere und lade PDF-Version des Digests herunter."""
user_id = get_user_id_from_request()
digest = db.query(AlertDigestDB).filter(
@@ -230,35 +110,26 @@ async def get_digest_pdf(
if not digest:
raise HTTPException(status_code=404, detail="Digest nicht gefunden")
-
if not digest.summary_html:
raise HTTPException(status_code=400, detail="Digest hat keinen Inhalt")
- # PDF generieren
try:
pdf_bytes = await generate_pdf_from_html(digest.summary_html)
except Exception as e:
raise HTTPException(status_code=500, detail=f"PDF-Generierung fehlgeschlagen: {str(e)}")
- # Dateiname
filename = f"wochenbericht_{digest.period_start.strftime('%Y%m%d')}_{digest.period_end.strftime('%Y%m%d')}.pdf"
return StreamingResponse(
io.BytesIO(pdf_bytes),
media_type="application/pdf",
- headers={
- "Content-Disposition": f"attachment; filename={filename}"
- }
+ headers={"Content-Disposition": f"attachment; filename={filename}"}
)
@router.get("/latest/pdf")
-async def get_latest_digest_pdf(
- db: DBSession = Depends(get_db)
-):
- """
- PDF des neuesten Digests herunterladen.
- """
+async def get_latest_digest_pdf(db: DBSession = Depends(get_db)):
+ """PDF des neuesten Digests herunterladen."""
user_id = get_user_id_from_request()
digest = db.query(AlertDigestDB).filter(
@@ -267,11 +138,9 @@ async def get_latest_digest_pdf(
if not digest:
raise HTTPException(status_code=404, detail="Kein Digest vorhanden")
-
if not digest.summary_html:
raise HTTPException(status_code=400, detail="Digest hat keinen Inhalt")
- # PDF generieren
try:
pdf_bytes = await generate_pdf_from_html(digest.summary_html)
except Exception as e:
@@ -282,9 +151,7 @@ async def get_latest_digest_pdf(
return StreamingResponse(
io.BytesIO(pdf_bytes),
media_type="application/pdf",
- headers={
- "Content-Disposition": f"attachment; filename={filename}"
- }
+ headers={"Content-Disposition": f"attachment; filename={filename}"}
)
@@ -293,16 +160,10 @@ async def generate_digest(
request: GenerateDigestRequest = None,
db: DBSession = Depends(get_db)
):
- """
- Generiere einen neuen Digest manuell.
-
- Normalerweise werden Digests automatisch woechentlich generiert.
- Diese Route erlaubt manuelle Generierung fuer Tests oder On-Demand.
- """
+ """Generiere einen neuen Digest manuell."""
user_id = get_user_id_from_request()
weeks_back = request.weeks_back if request else 1
- # Pruefe ob bereits ein Digest fuer diesen Zeitraum existiert
now = datetime.utcnow()
period_end = now - timedelta(days=now.weekday())
period_start = period_end - timedelta(weeks=weeks_back)
@@ -315,12 +176,10 @@ async def generate_digest(
if existing and not (request and request.force_regenerate):
return GenerateDigestResponse(
- status="exists",
- digest_id=existing.id,
+ status="exists", digest_id=existing.id,
message="Digest fuer diesen Zeitraum existiert bereits"
)
- # Generiere neuen Digest
generator = DigestGenerator(db)
try:
@@ -328,14 +187,12 @@ async def generate_digest(
if digest:
return GenerateDigestResponse(
- status="success",
- digest_id=digest.id,
+ status="success", digest_id=digest.id,
message="Digest erfolgreich generiert"
)
else:
return GenerateDigestResponse(
- status="empty",
- digest_id=None,
+ status="empty", digest_id=None,
message="Keine Alerts fuer diesen Zeitraum vorhanden"
)
except Exception as e:
@@ -348,9 +205,7 @@ async def send_digest_email(
request: SendEmailRequest = None,
db: DBSession = Depends(get_db)
):
- """
- Versende Digest per E-Mail.
- """
+ """Versende Digest per E-Mail."""
user_id = get_user_id_from_request()
digest = db.query(AlertDigestDB).filter(
@@ -361,12 +216,10 @@ async def send_digest_email(
if not digest:
raise HTTPException(status_code=404, detail="Digest nicht gefunden")
- # E-Mail-Adresse ermitteln
email = None
if request and request.email:
email = request.email
else:
- # Aus Subscription holen
subscription = db.query(UserAlertSubscriptionDB).filter(
UserAlertSubscriptionDB.id == digest.subscription_id
).first()
@@ -376,176 +229,18 @@ async def send_digest_email(
if not email:
raise HTTPException(status_code=400, detail="Keine E-Mail-Adresse angegeben")
- # E-Mail versenden
try:
await send_digest_by_email(digest, email)
- # Status aktualisieren
digest.status = DigestStatusEnum.SENT
digest.sent_at = datetime.utcnow()
db.commit()
return SendEmailResponse(
- status="success",
- sent_to=email,
+ status="success", sent_to=email,
message="E-Mail erfolgreich versendet"
)
except Exception as e:
digest.status = DigestStatusEnum.FAILED
db.commit()
raise HTTPException(status_code=500, detail=f"E-Mail-Versand fehlgeschlagen: {str(e)}")
-
-
-# ============================================================================
-# PDF Generation
-# ============================================================================
-
-async def generate_pdf_from_html(html_content: str) -> bytes:
- """
- Generiere PDF aus HTML.
-
- Verwendet WeasyPrint oder wkhtmltopdf als Fallback.
- """
- try:
- # Versuche WeasyPrint (bevorzugt)
- from weasyprint import HTML
- pdf_bytes = HTML(string=html_content).write_pdf()
- return pdf_bytes
- except ImportError:
- pass
-
- try:
- # Fallback: wkhtmltopdf via pdfkit
- import pdfkit
- pdf_bytes = pdfkit.from_string(html_content, False)
- return pdf_bytes
- except ImportError:
- pass
-
- try:
- # Fallback: xhtml2pdf
- from xhtml2pdf import pisa
- result = io.BytesIO()
- pisa.CreatePDF(io.StringIO(html_content), dest=result)
- return result.getvalue()
- except ImportError:
- pass
-
- # Letzter Fallback: Einfache Text-Konvertierung
- raise ImportError(
- "Keine PDF-Bibliothek verfuegbar. "
- "Installieren Sie: pip install weasyprint oder pip install pdfkit oder pip install xhtml2pdf"
- )
-
-
-# ============================================================================
-# Email Sending
-# ============================================================================
-
-async def send_digest_by_email(digest: AlertDigestDB, recipient_email: str):
- """
- Versende Digest per E-Mail.
-
- Verwendet:
- - Lokalen SMTP-Server (Postfix/Sendmail)
- - SMTP-Relay (z.B. SES, Mailgun)
- - SendGrid API
- """
- import os
- import smtplib
- from email.mime.text import MIMEText
- from email.mime.multipart import MIMEMultipart
- from email.mime.application import MIMEApplication
-
- # E-Mail zusammenstellen
- msg = MIMEMultipart('alternative')
- msg['Subject'] = f"Wochenbericht: {digest.period_start.strftime('%d.%m.%Y')} - {digest.period_end.strftime('%d.%m.%Y')}"
- msg['From'] = os.getenv('SMTP_FROM', 'alerts@breakpilot.app')
- msg['To'] = recipient_email
-
- # Text-Version
- text_content = f"""
-BreakPilot Alerts - Wochenbericht
-
-Zeitraum: {digest.period_start.strftime('%d.%m.%Y')} - {digest.period_end.strftime('%d.%m.%Y')}
-Gesamt: {digest.total_alerts} Meldungen
-Kritisch: {digest.critical_count}
-Dringend: {digest.urgent_count}
-
-Oeffnen Sie die HTML-Version fuer die vollstaendige Uebersicht.
-
----
-Diese E-Mail wurde automatisch von BreakPilot Alerts generiert.
- """
- msg.attach(MIMEText(text_content, 'plain', 'utf-8'))
-
- # HTML-Version
- if digest.summary_html:
- msg.attach(MIMEText(digest.summary_html, 'html', 'utf-8'))
-
- # PDF-Anhang (optional)
- try:
- pdf_bytes = await generate_pdf_from_html(digest.summary_html)
- pdf_attachment = MIMEApplication(pdf_bytes, _subtype='pdf')
- pdf_attachment.add_header(
- 'Content-Disposition', 'attachment',
- filename=f"wochenbericht_{digest.period_start.strftime('%Y%m%d')}.pdf"
- )
- msg.attach(pdf_attachment)
- except Exception:
- pass # PDF-Anhang ist optional
-
- # Senden
- smtp_host = os.getenv('SMTP_HOST', 'localhost')
- smtp_port = int(os.getenv('SMTP_PORT', '25'))
- smtp_user = os.getenv('SMTP_USER', '')
- smtp_pass = os.getenv('SMTP_PASS', '')
-
- try:
- if smtp_port == 465:
- # SSL
- server = smtplib.SMTP_SSL(smtp_host, smtp_port)
- else:
- server = smtplib.SMTP(smtp_host, smtp_port)
- if smtp_port == 587:
- server.starttls()
-
- if smtp_user and smtp_pass:
- server.login(smtp_user, smtp_pass)
-
- server.send_message(msg)
- server.quit()
-
- except Exception as e:
- # Fallback: SendGrid API
- sendgrid_key = os.getenv('SENDGRID_API_KEY')
- if sendgrid_key:
- await send_via_sendgrid(msg, sendgrid_key)
- else:
- raise e
-
-
-async def send_via_sendgrid(msg, api_key: str):
- """Fallback: SendGrid API."""
- import httpx
-
- async with httpx.AsyncClient() as client:
- response = await client.post(
- "https://api.sendgrid.com/v3/mail/send",
- headers={
- "Authorization": f"Bearer {api_key}",
- "Content-Type": "application/json"
- },
- json={
- "personalizations": [{"to": [{"email": msg['To']}]}],
- "from": {"email": msg['From']},
- "subject": msg['Subject'],
- "content": [
- {"type": "text/plain", "value": msg.get_payload(0).get_payload()},
- {"type": "text/html", "value": msg.get_payload(1).get_payload() if len(msg.get_payload()) > 1 else ""}
- ]
- }
- )
-
- if response.status_code >= 400:
- raise Exception(f"SendGrid error: {response.status_code}")
diff --git a/backend-lehrer/alerts_agent/api/digests_email.py b/backend-lehrer/alerts_agent/api/digests_email.py
new file mode 100644
index 0000000..b4ed763
--- /dev/null
+++ b/backend-lehrer/alerts_agent/api/digests_email.py
@@ -0,0 +1,146 @@
+"""
+Alert Digests - PDF-Generierung und E-Mail-Versand.
+"""
+
+import io
+import logging
+
+from ..db.models import AlertDigestDB
+
+logger = logging.getLogger(__name__)
+
+
+async def generate_pdf_from_html(html_content: str) -> bytes:
+    """
+    Generate a PDF from HTML.
+
+    Tries WeasyPrint, then pdfkit, then xhtml2pdf; raises ImportError if none can be imported.
+    """
+    try:
+        from weasyprint import HTML  # preferred backend
+        pdf_bytes = HTML(string=html_content).write_pdf()
+        return pdf_bytes
+    except ImportError:  # only a missing package triggers the next fallback
+        pass
+
+    try:
+        import pdfkit  # first fallback
+        pdf_bytes = pdfkit.from_string(html_content, False)
+        return pdf_bytes
+    except ImportError:
+        pass
+
+    try:
+        from xhtml2pdf import pisa  # second fallback
+        result = io.BytesIO()
+        pisa.CreatePDF(io.StringIO(html_content), dest=result)
+        return result.getvalue()
+    except ImportError:
+        pass
+
+    raise ImportError(
+        "Keine PDF-Bibliothek verfuegbar. "
+        "Installieren Sie: pip install weasyprint oder pip install pdfkit oder pip install xhtml2pdf"
+    )
+
+
+async def send_digest_by_email(digest: AlertDigestDB, recipient_email: str):
+    """
+    Send a digest by e-mail.
+
+    Delivery goes through the SMTP server configured via SMTP_HOST, SMTP_PORT,
+    SMTP_USER and SMTP_PASS (implicit TLS on port 465, STARTTLS on port 587).
+    If SMTP fails and SENDGRID_API_KEY is set, the message is sent through the
+    SendGrid API instead; otherwise the SMTP error is re-raised.
+
+    NOTE(review): smtplib is blocking, so this coroutine blocks the event loop
+    while the SMTP conversation runs.
+    """
+    import os
+    import smtplib
+    from email.mime.text import MIMEText
+    from email.mime.multipart import MIMEMultipart
+    from email.mime.application import MIMEApplication
+
+    msg = MIMEMultipart('alternative')
+    msg['Subject'] = f"Wochenbericht: {digest.period_start.strftime('%d.%m.%Y')} - {digest.period_end.strftime('%d.%m.%Y')}"
+    msg['From'] = os.getenv('SMTP_FROM', 'alerts@breakpilot.app')
+    msg['To'] = recipient_email
+
+    text_content = f"""
+BreakPilot Alerts - Wochenbericht
+
+Zeitraum: {digest.period_start.strftime('%d.%m.%Y')} - {digest.period_end.strftime('%d.%m.%Y')}
+Gesamt: {digest.total_alerts} Meldungen
+Kritisch: {digest.critical_count}
+Dringend: {digest.urgent_count}
+
+Oeffnen Sie die HTML-Version fuer die vollstaendige Uebersicht.
+
+---
+Diese E-Mail wurde automatisch von BreakPilot Alerts generiert.
+    """
+    msg.attach(MIMEText(text_content, 'plain', 'utf-8'))
+
+    if digest.summary_html:
+        msg.attach(MIMEText(digest.summary_html, 'html', 'utf-8'))
+        # The PDF is rendered from the HTML, so only try when HTML exists.
+        try:
+            pdf_bytes = await generate_pdf_from_html(digest.summary_html)
+            pdf_attachment = MIMEApplication(pdf_bytes, _subtype='pdf')
+            pdf_attachment.add_header(
+                'Content-Disposition', 'attachment',
+                filename=f"wochenbericht_{digest.period_start.strftime('%Y%m%d')}.pdf"
+            )
+            msg.attach(pdf_attachment)
+        except Exception as exc:
+            # The PDF attachment is optional: log the reason and send without it.
+            logger.warning("PDF attachment skipped: %s", exc)
+
+    # SMTP settings come from the environment; defaults target localhost:25.
+    smtp_host = os.getenv('SMTP_HOST', 'localhost')
+    smtp_port = int(os.getenv('SMTP_PORT', '25'))
+    smtp_user = os.getenv('SMTP_USER', '')
+    smtp_pass = os.getenv('SMTP_PASS', '')
+
+    try:
+        smtp_cls = smtplib.SMTP_SSL if smtp_port == 465 else smtplib.SMTP
+        # The context manager closes the connection even if sending fails.
+        with smtp_cls(smtp_host, smtp_port) as server:
+            if smtp_port == 587:
+                server.starttls()
+            if smtp_user and smtp_pass:
+                server.login(smtp_user, smtp_pass)
+            server.send_message(msg)
+    except Exception:
+        # SMTP failed: fall back to SendGrid when an API key is configured.
+        sendgrid_key = os.getenv('SENDGRID_API_KEY')
+        if not sendgrid_key:
+            raise
+        await send_via_sendgrid(msg, sendgrid_key)
+
+
+async def send_via_sendgrid(msg, api_key: str):
+    """Fallback transport: deliver ``msg`` through the SendGrid v3 mail API.
+
+    Raises Exception if SendGrid responds with an HTTP status >= 400."""
+    import httpx
+
+    # utf-8 MIME text parts are transfer-encoded, so decode them; skip non-text
+    # parts (the PDF attachment) and never send an empty text/html entry.
+    content = [
+        {"type": part.get_content_type(), "value": part.get_payload(decode=True).decode("utf-8")}
+        for part in msg.get_payload()
+        if part.get_content_type() in ("text/plain", "text/html")
+    ]
+    payload = {
+        "personalizations": [{"to": [{"email": msg['To']}]}],
+        "from": {"email": msg['From']},
+        "subject": msg['Subject'],
+        "content": content,
+    }
+    headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
+    async with httpx.AsyncClient() as client:
+        response = await client.post("https://api.sendgrid.com/v3/mail/send", headers=headers, json=payload)
+    if response.status_code >= 400:
+        raise Exception(f"SendGrid error: {response.status_code}")
diff --git a/backend-lehrer/alerts_agent/api/digests_models.py b/backend-lehrer/alerts_agent/api/digests_models.py
new file mode 100644
index 0000000..9c94d5b
--- /dev/null
+++ b/backend-lehrer/alerts_agent/api/digests_models.py
@@ -0,0 +1,116 @@
+"""
+Alert Digests - Request/Response Models und Konverter.
+"""
+
+from typing import Optional, List
+from datetime import datetime
+from pydantic import BaseModel, Field
+
+from ..db.models import AlertDigestDB
+
+
+# ============================================================================
+# Request/Response Models
+# ============================================================================
+
+class DigestListItem(BaseModel):
+    """Short digest summary used in list responses."""
+    id: str
+    period_start: datetime
+    period_end: datetime
+    total_alerts: int
+    critical_count: int
+    urgent_count: int
+    status: str  # enum value of the digest status, "pending" when unset
+    created_at: datetime
+
+
+class DigestDetail(BaseModel):
+    """Full digest details."""
+    id: str
+    subscription_id: Optional[str]
+    user_id: str
+    period_start: datetime
+    period_end: datetime
+    summary_html: str  # empty string when the digest has no HTML summary
+    summary_pdf_url: Optional[str]
+    total_alerts: int
+    critical_count: int
+    urgent_count: int
+    important_count: int
+    review_count: int
+    info_count: int
+    status: str  # enum value of the digest status, "pending" when unset
+    sent_at: Optional[datetime]
+    created_at: datetime
+
+
+class DigestListResponse(BaseModel):
+    """Response for the digest list."""
+    digests: List[DigestListItem]
+    total: int
+
+
+class GenerateDigestRequest(BaseModel):
+    """Request for manual digest generation."""
+    weeks_back: int = Field(default=1, ge=1, le=4, description="Wochen zurueck")  # 1 to 4 weeks
+    force_regenerate: bool = Field(default=False, description="Vorhandenen Digest ueberschreiben")
+
+
+class GenerateDigestResponse(BaseModel):
+    """Response for digest generation."""
+    status: str  # the generate endpoint returns "exists", "success" or "empty"
+    digest_id: Optional[str]
+    message: str
+
+
+class SendEmailRequest(BaseModel):
+    """Request for sending a digest by e-mail."""
+    email: Optional[str] = Field(default=None, description="E-Mail-Adresse (optional)")
+
+
+class SendEmailResponse(BaseModel):
+    """Response for sending a digest by e-mail."""
+    status: str
+    sent_to: str  # address the digest was sent to
+    message: str
+
+
+# ============================================================================
+# Converter Functions
+# ============================================================================
+
+def digest_to_list_item(digest: AlertDigestDB) -> DigestListItem:
+    """Convert a digest DB model into a list item; falsy counts become 0."""
+    return DigestListItem(
+        id=digest.id,
+        period_start=digest.period_start,
+        period_end=digest.period_end,
+        total_alerts=digest.total_alerts or 0,
+        critical_count=digest.critical_count or 0,
+        urgent_count=digest.urgent_count or 0,
+        status=digest.status.value if digest.status else "pending",  # unset status reads as "pending"
+        created_at=digest.created_at
+    )
+
+
+def digest_to_detail(digest: AlertDigestDB) -> DigestDetail:
+    """Convert a digest DB model into the detail response; falsy counts become 0."""
+    return DigestDetail(
+        id=digest.id,
+        subscription_id=digest.subscription_id,
+        user_id=digest.user_id,
+        period_start=digest.period_start,
+        period_end=digest.period_end,
+        summary_html=digest.summary_html or "",  # response field is a non-optional str
+        summary_pdf_url=digest.summary_pdf_url,
+        total_alerts=digest.total_alerts or 0,
+        critical_count=digest.critical_count or 0,
+        urgent_count=digest.urgent_count or 0,
+        important_count=digest.important_count or 0,
+        review_count=digest.review_count or 0,
+        info_count=digest.info_count or 0,
+        status=digest.status.value if digest.status else "pending",  # unset status reads as "pending"
+        sent_at=digest.sent_at,
+        created_at=digest.created_at
+    )
diff --git a/backend-lehrer/alerts_agent/api/routes.py b/backend-lehrer/alerts_agent/api/routes.py
index 8f52761..630bba4 100644
--- a/backend-lehrer/alerts_agent/api/routes.py
+++ b/backend-lehrer/alerts_agent/api/routes.py
@@ -1,5 +1,5 @@
"""
-API Routes für Alerts Agent.
+API Routes fuer Alerts Agent.
Endpoints:
- POST /alerts/ingest - Manuell Alerts importieren
@@ -13,12 +13,18 @@ Endpoints:
import os
from datetime import datetime
from typing import Optional
-from fastapi import APIRouter, Depends, HTTPException, Query
-from pydantic import BaseModel, Field
+from fastapi import APIRouter, HTTPException, Query
from ..models.alert_item import AlertItem, AlertStatus
from ..models.relevance_profile import RelevanceProfile, PriorityItem
from ..processing.relevance_scorer import RelevanceDecision, RelevanceScorer
+from .schemas import (
+ AlertIngestRequest, AlertIngestResponse,
+ AlertRunRequest, AlertRunResponse,
+ InboxItem, InboxResponse,
+ FeedbackRequest, FeedbackResponse,
+ ProfilePriorityRequest, ProfileUpdateRequest, ProfileResponse,
+)
router = APIRouter(prefix="/alerts", tags=["alerts"])
@@ -30,113 +36,13 @@ ALERTS_USE_LLM = os.getenv("ALERTS_USE_LLM", "false").lower() == "true"
# ============================================================================
-# In-Memory Storage (später durch DB ersetzen)
+# In-Memory Storage (spaeter durch DB ersetzen)
# ============================================================================
_alerts_store: dict[str, AlertItem] = {}
_profile_store: dict[str, RelevanceProfile] = {}
-# ============================================================================
-# Request/Response Models
-# ============================================================================
-
-class AlertIngestRequest(BaseModel):
- """Request für manuelles Alert-Import."""
- title: str = Field(..., min_length=1, max_length=500)
- url: str = Field(..., min_length=1)
- snippet: Optional[str] = Field(default=None, max_length=2000)
- topic_label: str = Field(default="Manual Import")
- published_at: Optional[datetime] = None
-
-
-class AlertIngestResponse(BaseModel):
- """Response für Alert-Import."""
- id: str
- status: str
- message: str
-
-
-class AlertRunRequest(BaseModel):
- """Request für Scoring-Pipeline."""
- limit: int = Field(default=50, ge=1, le=200)
- skip_scored: bool = Field(default=True)
-
-
-class AlertRunResponse(BaseModel):
- """Response für Scoring-Pipeline."""
- processed: int
- keep: int
- drop: int
- review: int
- errors: int
- duration_ms: int
-
-
-class InboxItem(BaseModel):
- """Ein Item in der Inbox."""
- id: str
- title: str
- url: str
- snippet: Optional[str]
- topic_label: str
- published_at: Optional[datetime]
- relevance_score: Optional[float]
- relevance_decision: Optional[str]
- relevance_summary: Optional[str]
- status: str
-
-
-class InboxResponse(BaseModel):
- """Response für Inbox-Abfrage."""
- items: list[InboxItem]
- total: int
- page: int
- page_size: int
-
-
-class FeedbackRequest(BaseModel):
- """Request für Relevanz-Feedback."""
- alert_id: str
- is_relevant: bool
- reason: Optional[str] = None
- tags: list[str] = Field(default_factory=list)
-
-
-class FeedbackResponse(BaseModel):
- """Response für Feedback."""
- success: bool
- message: str
- profile_updated: bool
-
-
-class ProfilePriorityRequest(BaseModel):
- """Priority für Profile-Update."""
- label: str
- weight: float = Field(default=0.5, ge=0.0, le=1.0)
- keywords: list[str] = Field(default_factory=list)
- description: Optional[str] = None
-
-
-class ProfileUpdateRequest(BaseModel):
- """Request für Profile-Update."""
- priorities: Optional[list[ProfilePriorityRequest]] = None
- exclusions: Optional[list[str]] = None
- policies: Optional[dict] = None
-
-
-class ProfileResponse(BaseModel):
- """Response für Profile."""
- id: str
- priorities: list[dict]
- exclusions: list[str]
- policies: dict
- total_scored: int
- total_kept: int
- total_dropped: int
- accuracy_estimate: Optional[float]
-
-
# ============================================================================
# Endpoints
# ============================================================================
@@ -146,7 +52,7 @@ async def ingest_alert(request: AlertIngestRequest):
"""
Manuell einen Alert importieren.
- Nützlich für Tests oder manuelles Hinzufügen von Artikeln.
+ Nuetzlich fuer Tests oder manuelles Hinzufuegen von Artikeln.
"""
alert = AlertItem(
title=request.title,
@@ -168,13 +74,13 @@ async def ingest_alert(request: AlertIngestRequest):
@router.post("/run", response_model=AlertRunResponse)
async def run_scoring_pipeline(request: AlertRunRequest):
"""
- Scoring-Pipeline für neue Alerts starten.
+ Scoring-Pipeline fuer neue Alerts starten.
Bewertet alle unbewerteten Alerts und klassifiziert sie
in KEEP, DROP oder REVIEW.
- Wenn ALERTS_USE_LLM=true, wird das LLM Gateway für Scoring verwendet.
- Sonst wird ein schnelles Keyword-basiertes Scoring durchgeführt.
+ Wenn ALERTS_USE_LLM=true, wird das LLM Gateway fuer Scoring verwendet.
+ Sonst wird ein schnelles Keyword-basiertes Scoring durchgefuehrt.
"""
import time
start = time.time()
@@ -193,7 +99,7 @@ async def run_scoring_pipeline(request: AlertRunRequest):
keep = drop = review = errors = 0
- # Profil für Scoring laden
+ # Profil fuer Scoring laden
profile = _profile_store.get("default")
if not profile:
profile = RelevanceProfile.create_default_education_profile()
@@ -201,7 +107,7 @@ async def run_scoring_pipeline(request: AlertRunRequest):
_profile_store["default"] = profile
if ALERTS_USE_LLM and LLM_API_KEY:
- # LLM-basiertes Scoring über Gateway
+ # LLM-basiertes Scoring ueber Gateway
scorer = RelevanceScorer(
gateway_url=LLM_GATEWAY_URL,
api_key=LLM_API_KEY,
@@ -227,12 +133,12 @@ async def run_scoring_pipeline(request: AlertRunRequest):
snippet_lower = (alert.snippet or "").lower()
combined = title_lower + " " + snippet_lower
- # Ausschlüsse aus Profil prüfen
+ # Ausschluesse aus Profil pruefen
if any(excl.lower() in combined for excl in profile.exclusions):
alert.relevance_score = 0.15
alert.relevance_decision = RelevanceDecision.DROP.value
drop += 1
- # Prioritäten aus Profil prüfen
+ # Prioritaeten aus Profil pruefen
elif any(
p.label.lower() in combined or
any(kw.lower() in combined for kw in (p.keywords if hasattr(p, 'keywords') else []))
@@ -285,9 +191,9 @@ async def get_inbox(
# Pagination
total = len(alerts)
- start = (page - 1) * page_size
- end = start + page_size
- page_alerts = alerts[start:end]
+ start_idx = (page - 1) * page_size
+ end_idx = start_idx + page_size
+ page_alerts = alerts[start_idx:end_idx]
items = [
InboxItem(
@@ -327,7 +233,7 @@ async def submit_feedback(request: FeedbackRequest):
# Alert Status aktualisieren
alert.status = AlertStatus.REVIEWED
- # Profile aktualisieren (Default-Profile für Demo)
+ # Profile aktualisieren (Default-Profile fuer Demo)
profile = _profile_store.get("default")
if not profile:
profile = RelevanceProfile.create_default_education_profile()
@@ -353,7 +259,7 @@ async def get_profile(user_id: Optional[str] = Query(default=None)):
"""
Relevanz-Profil abrufen.
- Ohne user_id wird das Default-Profil zurückgegeben.
+ Ohne user_id wird das Default-Profil zurueckgegeben.
"""
profile_id = user_id or "default"
profile = _profile_store.get(profile_id)
@@ -385,7 +291,7 @@ async def update_profile(
"""
Relevanz-Profil aktualisieren.
- Erlaubt Anpassung von Prioritäten, Ausschlüssen und Policies.
+ Erlaubt Anpassung von Prioritaeten, Ausschluessen und Policies.
"""
profile_id = user_id or "default"
profile = _profile_store.get(profile_id)
@@ -431,34 +337,24 @@ async def update_profile(
@router.get("/stats")
async def get_stats():
"""
- Statistiken über Alerts und Scoring.
-
- Gibt Statistiken im Format zurück, das das Frontend erwartet:
- - total_alerts, new_alerts, kept_alerts, review_alerts, dropped_alerts
- - total_topics, active_topics, total_rules
+ Statistiken ueber Alerts und Scoring.
"""
alerts = list(_alerts_store.values())
total = len(alerts)
- # Zähle nach Status und Decision
new_alerts = sum(1 for a in alerts if a.status == AlertStatus.NEW)
kept_alerts = sum(1 for a in alerts if a.relevance_decision == "KEEP")
review_alerts = sum(1 for a in alerts if a.relevance_decision == "REVIEW")
dropped_alerts = sum(1 for a in alerts if a.relevance_decision == "DROP")
- # Topics und Rules (In-Memory hat diese nicht, aber wir geben 0 zurück)
- # Bei DB-Implementierung würden wir hier die Repositories nutzen
total_topics = 0
active_topics = 0
total_rules = 0
- # Versuche DB-Statistiken zu laden wenn verfügbar
try:
from alerts_agent.db import get_db
from alerts_agent.db.repository import TopicRepository, RuleRepository
- from contextlib import contextmanager
- # Versuche eine DB-Session zu bekommen
db_gen = get_db()
db = next(db_gen, None)
if db:
@@ -478,15 +374,12 @@ async def get_stats():
except StopIteration:
pass
except Exception:
- # DB nicht verfügbar, nutze In-Memory Defaults
pass
- # Berechne Durchschnittsscore
scored_alerts = [a for a in alerts if a.relevance_score is not None]
avg_score = sum(a.relevance_score for a in scored_alerts) / len(scored_alerts) if scored_alerts else 0.0
return {
- # Frontend-kompatibles Format
"total_alerts": total,
"new_alerts": new_alerts,
"kept_alerts": kept_alerts,
@@ -496,7 +389,6 @@ async def get_stats():
"active_topics": active_topics,
"total_rules": total_rules,
"avg_score": avg_score,
- # Zusätzliche Details (Abwärtskompatibilität)
"by_status": {
"new": new_alerts,
"scored": sum(1 for a in alerts if a.status == AlertStatus.SCORED),
diff --git a/backend-lehrer/alerts_agent/api/schemas.py b/backend-lehrer/alerts_agent/api/schemas.py
new file mode 100644
index 0000000..3f0b302
--- /dev/null
+++ b/backend-lehrer/alerts_agent/api/schemas.py
@@ -0,0 +1,111 @@
+"""
+Request/Response Schemas fuer Alerts Agent API.
+"""
+
+from datetime import datetime
+from typing import Optional
+from pydantic import BaseModel, Field
+
+
+# ============================================================================
+# Request Models
+# ============================================================================
+
+class AlertIngestRequest(BaseModel):
+    """Request for manually importing an alert."""
+    title: str = Field(..., min_length=1, max_length=500)
+    url: str = Field(..., min_length=1)
+    snippet: Optional[str] = Field(default=None, max_length=2000)
+    topic_label: str = Field(default="Manual Import")
+    published_at: Optional[datetime] = None
+
+
+class AlertRunRequest(BaseModel):
+    """Request for running the scoring pipeline."""
+    limit: int = Field(default=50, ge=1, le=200)  # between 1 and 200
+    skip_scored: bool = Field(default=True)
+
+
+class FeedbackRequest(BaseModel):
+    """Request for relevance feedback."""
+    alert_id: str
+    is_relevant: bool
+    reason: Optional[str] = None
+    tags: list[str] = Field(default_factory=list)
+
+
+class ProfilePriorityRequest(BaseModel):
+    """Priority entry for a profile update."""
+    label: str
+    weight: float = Field(default=0.5, ge=0.0, le=1.0)  # between 0.0 and 1.0
+    keywords: list[str] = Field(default_factory=list)
+    description: Optional[str] = None
+
+
+class ProfileUpdateRequest(BaseModel):
+    """Request for a profile update; every field is optional."""
+    priorities: Optional[list[ProfilePriorityRequest]] = None
+    exclusions: Optional[list[str]] = None
+    policies: Optional[dict] = None
+
+
+# ============================================================================
+# Response Models
+# ============================================================================
+
+class AlertIngestResponse(BaseModel):
+    """Response for an alert import."""
+    id: str
+    status: str
+    message: str
+
+
+class AlertRunResponse(BaseModel):
+    """Response for a scoring pipeline run."""
+    processed: int
+    keep: int
+    drop: int
+    review: int
+    errors: int
+    duration_ms: int
+
+
+class InboxItem(BaseModel):
+    """A single item in the inbox."""
+    id: str
+    title: str
+    url: str
+    snippet: Optional[str]
+    topic_label: str
+    published_at: Optional[datetime]
+    relevance_score: Optional[float]
+    relevance_decision: Optional[str]
+    relevance_summary: Optional[str]
+    status: str
+
+
+class InboxResponse(BaseModel):
+    """Response for an inbox query."""
+    items: list[InboxItem]
+    total: int
+    page: int
+    page_size: int
+
+
+class FeedbackResponse(BaseModel):
+    """Response for submitted feedback."""
+    success: bool
+    message: str
+    profile_updated: bool
+
+
+class ProfileResponse(BaseModel):
+    """Response describing a relevance profile."""
+    id: str
+    priorities: list[dict]
+    exclusions: list[str]
+    policies: dict
+    total_scored: int
+    total_kept: int
+    total_dropped: int
+    accuracy_estimate: Optional[float]
diff --git a/backend-lehrer/alerts_agent/api/wizard.py b/backend-lehrer/alerts_agent/api/wizard.py
index c4010ca..ea4a059 100644
--- a/backend-lehrer/alerts_agent/api/wizard.py
+++ b/backend-lehrer/alerts_agent/api/wizard.py
@@ -7,21 +7,12 @@ Verwaltet den 3-Schritt Setup-Wizard:
3. Bestätigung und Aktivierung
Zusätzlich: Migration-Wizard für bestehende Google Alerts.
-
-Endpoints:
-- GET /wizard/state - Aktuellen Wizard-Status abrufen
-- PUT /wizard/step/{step} - Schritt speichern
-- POST /wizard/complete - Wizard abschließen
-- POST /wizard/reset - Wizard zurücksetzen
-- POST /wizard/migrate/email - E-Mail-Migration starten
-- POST /wizard/migrate/rss - RSS-Import
"""
import uuid
-from typing import Optional, List, Dict, Any
+from typing import List, Dict, Any
from datetime import datetime
-from fastapi import APIRouter, Depends, HTTPException, Query
-from pydantic import BaseModel, Field
+from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.orm import Session as DBSession
from ..db.database import get_db
@@ -29,77 +20,22 @@ from ..db.models import (
UserAlertSubscriptionDB, AlertTemplateDB, AlertSourceDB,
AlertModeEnum, UserRoleEnum, MigrationModeEnum, FeedTypeEnum
)
+from .wizard_models import (
+ WizardState,
+ Step1Data,
+ Step2Data,
+ Step3Data,
+ StepResponse,
+ MigrateEmailRequest,
+ MigrateEmailResponse,
+ MigrateRssRequest,
+ MigrateRssResponse,
+)
router = APIRouter(prefix="/wizard", tags=["wizard"])
-# ============================================================================
-# Request/Response Models
-# ============================================================================
-
-class WizardState(BaseModel):
- """Aktueller Wizard-Status."""
- subscription_id: Optional[str] = None
- current_step: int = 0 # 0=nicht gestartet, 1-3=Schritte, 4=abgeschlossen
- is_completed: bool = False
- step_data: Dict[str, Any] = {}
- recommended_templates: List[Dict[str, Any]] = []
-
-
-class Step1Data(BaseModel):
- """Daten für Schritt 1: Rollenwahl."""
- role: str = Field(..., description="lehrkraft, schulleitung, it_beauftragte")
-
-
-class Step2Data(BaseModel):
- """Daten für Schritt 2: Template-Auswahl."""
- template_ids: List[str] = Field(..., min_length=1, max_length=3)
-
-
-class Step3Data(BaseModel):
- """Daten für Schritt 3: Bestätigung."""
- notification_email: Optional[str] = None
- digest_enabled: bool = True
- digest_frequency: str = "weekly"
-
-
-class StepResponse(BaseModel):
- """Response für Schritt-Update."""
- status: str
- current_step: int
- next_step: int
- message: str
- recommended_templates: List[Dict[str, Any]] = []
-
-
-class MigrateEmailRequest(BaseModel):
- """Request für E-Mail-Migration."""
- original_label: Optional[str] = Field(default=None, description="Beschreibung des Alerts")
-
-
-class MigrateEmailResponse(BaseModel):
- """Response für E-Mail-Migration."""
- status: str
- inbound_address: str
- instructions: List[str]
- source_id: str
-
-
-class MigrateRssRequest(BaseModel):
- """Request für RSS-Import."""
- rss_urls: List[str] = Field(..., min_length=1, max_length=20)
- labels: Optional[List[str]] = None
-
-
-class MigrateRssResponse(BaseModel):
- """Response für RSS-Import."""
- status: str
- sources_created: int
- topics_created: int
- message: str
-
-
# ============================================================================
# Helper Functions
# ============================================================================
@@ -144,13 +80,9 @@ def _get_recommended_templates(db: DBSession, role: str) -> List[Dict[str, Any]]
for t in templates:
if role in (t.target_roles or []):
result.append({
- "id": t.id,
- "slug": t.slug,
- "name": t.name,
- "description": t.description,
- "icon": t.icon,
- "category": t.category,
- "recommended": True,
+ "id": t.id, "slug": t.slug, "name": t.name,
+ "description": t.description, "icon": t.icon,
+ "category": t.category, "recommended": True,
})
return result
@@ -167,14 +99,8 @@ def _generate_inbound_address(user_id: str, source_id: str) -> str:
# ============================================================================
@router.get("/state", response_model=WizardState)
-async def get_wizard_state(
- db: DBSession = Depends(get_db)
-):
- """
- Hole aktuellen Wizard-Status.
-
- Gibt Schritt, gespeicherte Daten und empfohlene Templates zurück.
- """
+async def get_wizard_state(db: DBSession = Depends(get_db)):
+ """Hole aktuellen Wizard-Status."""
user_id = get_user_id_from_request()
subscription = db.query(UserAlertSubscriptionDB).filter(
@@ -182,15 +108,8 @@ async def get_wizard_state(
).order_by(UserAlertSubscriptionDB.created_at.desc()).first()
if not subscription:
- return WizardState(
- subscription_id=None,
- current_step=0,
- is_completed=False,
- step_data={},
- recommended_templates=[],
- )
+ return WizardState()
- # Empfohlene Templates basierend auf Rolle
role = subscription.user_role.value if subscription.user_role else None
recommended = _get_recommended_templates(db, role) if role else []
@@ -204,61 +123,37 @@ async def get_wizard_state(
@router.put("/step/1", response_model=StepResponse)
-async def save_step_1(
- data: Step1Data,
- db: DBSession = Depends(get_db)
-):
- """
- Schritt 1: Rolle speichern.
-
- Wählt die Rolle des Nutzers und gibt passende Template-Empfehlungen.
- """
+async def save_step_1(data: Step1Data, db: DBSession = Depends(get_db)):
+ """Schritt 1: Rolle speichern."""
user_id = get_user_id_from_request()
- # Validiere Rolle
try:
role = UserRoleEnum(data.role)
except ValueError:
- raise HTTPException(
- status_code=400,
- detail="Ungültige Rolle. Erlaubt: 'lehrkraft', 'schulleitung', 'it_beauftragte'"
- )
+ raise HTTPException(status_code=400, detail="Ungültige Rolle. Erlaubt: 'lehrkraft', 'schulleitung', 'it_beauftragte'")
subscription = _get_or_create_subscription(db, user_id)
-
- # Update
subscription.user_role = role
subscription.wizard_step = 1
wizard_state = subscription.wizard_state or {}
wizard_state["step1"] = {"role": data.role}
subscription.wizard_state = wizard_state
subscription.updated_at = datetime.utcnow()
-
db.commit()
db.refresh(subscription)
- # Empfohlene Templates
recommended = _get_recommended_templates(db, data.role)
return StepResponse(
- status="success",
- current_step=1,
- next_step=2,
+ status="success", current_step=1, next_step=2,
message=f"Rolle '{data.role}' gespeichert. Bitte wählen Sie jetzt Ihre Themen.",
recommended_templates=recommended,
)
@router.put("/step/2", response_model=StepResponse)
-async def save_step_2(
- data: Step2Data,
- db: DBSession = Depends(get_db)
-):
- """
- Schritt 2: Templates auswählen.
-
- Speichert die ausgewählten Templates (1-3).
- """
+async def save_step_2(data: Step2Data, db: DBSession = Depends(get_db)):
+ """Schritt 2: Templates auswählen."""
user_id = get_user_id_from_request()
subscription = db.query(UserAlertSubscriptionDB).filter(
@@ -269,46 +164,28 @@ async def save_step_2(
if not subscription:
raise HTTPException(status_code=400, detail="Bitte zuerst Schritt 1 abschließen")
- # Validiere Template-IDs
- templates = db.query(AlertTemplateDB).filter(
- AlertTemplateDB.id.in_(data.template_ids)
- ).all()
+ templates = db.query(AlertTemplateDB).filter(AlertTemplateDB.id.in_(data.template_ids)).all()
if len(templates) != len(data.template_ids):
raise HTTPException(status_code=400, detail="Eine oder mehrere Template-IDs sind ungültig")
- # Update
subscription.selected_template_ids = data.template_ids
subscription.wizard_step = 2
wizard_state = subscription.wizard_state or {}
- wizard_state["step2"] = {
- "template_ids": data.template_ids,
- "template_names": [t.name for t in templates],
- }
+ wizard_state["step2"] = {"template_ids": data.template_ids, "template_names": [t.name for t in templates]}
subscription.wizard_state = wizard_state
subscription.updated_at = datetime.utcnow()
-
db.commit()
return StepResponse(
- status="success",
- current_step=2,
- next_step=3,
+ status="success", current_step=2, next_step=3,
message=f"{len(templates)} Themen ausgewählt. Bitte bestätigen Sie Ihre Auswahl.",
- recommended_templates=[],
)
@router.put("/step/3", response_model=StepResponse)
-async def save_step_3(
- data: Step3Data,
- db: DBSession = Depends(get_db)
-):
- """
- Schritt 3: Digest-Einstellungen und Bestätigung.
-
- Speichert E-Mail und Digest-Präferenzen.
- """
+async def save_step_3(data: Step3Data, db: DBSession = Depends(get_db)):
+ """Schritt 3: Digest-Einstellungen und Bestätigung."""
user_id = get_user_id_from_request()
subscription = db.query(UserAlertSubscriptionDB).filter(
@@ -318,16 +195,13 @@ async def save_step_3(
if not subscription:
raise HTTPException(status_code=400, detail="Bitte zuerst Schritte 1 und 2 abschließen")
-
if not subscription.selected_template_ids:
raise HTTPException(status_code=400, detail="Bitte zuerst Templates auswählen (Schritt 2)")
- # Update
subscription.notification_email = data.notification_email
subscription.digest_enabled = data.digest_enabled
subscription.digest_frequency = data.digest_frequency
subscription.wizard_step = 3
-
wizard_state = subscription.wizard_state or {}
wizard_state["step3"] = {
"notification_email": data.notification_email,
@@ -336,27 +210,17 @@ async def save_step_3(
}
subscription.wizard_state = wizard_state
subscription.updated_at = datetime.utcnow()
-
db.commit()
return StepResponse(
- status="success",
- current_step=3,
- next_step=4,
+ status="success", current_step=3, next_step=4,
message="Einstellungen gespeichert. Klicken Sie auf 'Jetzt starten' um den Wizard abzuschließen.",
- recommended_templates=[],
)
@router.post("/complete")
-async def complete_wizard(
- db: DBSession = Depends(get_db)
-):
- """
- Wizard abschließen und Templates aktivieren.
-
- Erstellt Topics, Rules und Profile basierend auf den gewählten Templates.
- """
+async def complete_wizard(db: DBSession = Depends(get_db)):
+ """Wizard abschließen und Templates aktivieren."""
user_id = get_user_id_from_request()
subscription = db.query(UserAlertSubscriptionDB).filter(
@@ -366,18 +230,14 @@ async def complete_wizard(
if not subscription:
raise HTTPException(status_code=400, detail="Kein aktiver Wizard gefunden")
-
if not subscription.selected_template_ids:
raise HTTPException(status_code=400, detail="Bitte zuerst Templates auswählen")
- # Aktiviere Templates (über Subscription-Endpoint)
from .subscriptions import activate_template, ActivateTemplateRequest
- # Markiere als abgeschlossen
subscription.wizard_completed = True
subscription.wizard_step = 4
subscription.updated_at = datetime.utcnow()
-
db.commit()
return {
@@ -390,9 +250,7 @@ async def complete_wizard(
@router.post("/reset")
-async def reset_wizard(
- db: DBSession = Depends(get_db)
-):
+async def reset_wizard(db: DBSession = Depends(get_db)):
"""Wizard zurücksetzen (für Neustart)."""
user_id = get_user_id_from_request()
@@ -405,10 +263,7 @@ async def reset_wizard(
db.delete(subscription)
db.commit()
- return {
- "status": "success",
- "message": "Wizard zurückgesetzt. Sie können neu beginnen.",
- }
+ return {"status": "success", "message": "Wizard zurückgesetzt. Sie können neu beginnen."}
# ============================================================================
@@ -416,29 +271,16 @@ async def reset_wizard(
# ============================================================================
@router.post("/migrate/email", response_model=MigrateEmailResponse)
-async def start_email_migration(
- request: MigrateEmailRequest = None,
- db: DBSession = Depends(get_db)
-):
- """
- Starte E-Mail-Migration für bestehende Google Alerts.
-
- Generiert eine eindeutige Inbound-E-Mail-Adresse, an die der Nutzer
- seine Google Alerts weiterleiten kann.
- """
+async def start_email_migration(request: MigrateEmailRequest = None, db: DBSession = Depends(get_db)):
+ """Starte E-Mail-Migration für bestehende Google Alerts."""
user_id = get_user_id_from_request()
- # Erstelle AlertSource
source = AlertSourceDB(
- id=str(uuid.uuid4()),
- user_id=user_id,
+ id=str(uuid.uuid4()), user_id=user_id,
source_type=FeedTypeEnum.EMAIL,
original_label=request.original_label if request else "Google Alert Migration",
- migration_mode=MigrationModeEnum.FORWARD,
- is_active=True,
+ migration_mode=MigrationModeEnum.FORWARD, is_active=True,
)
-
- # Generiere Inbound-Adresse
source.inbound_address = _generate_inbound_address(user_id, source.id)
db.add(source)
@@ -446,9 +288,7 @@ async def start_email_migration(
db.refresh(source)
return MigrateEmailResponse(
- status="success",
- inbound_address=source.inbound_address,
- source_id=source.id,
+ status="success", inbound_address=source.inbound_address, source_id=source.id,
instructions=[
"1. Öffnen Sie Google Alerts (google.com/alerts)",
"2. Klicken Sie auf das Bearbeiten-Symbol bei Ihrem Alert",
@@ -460,74 +300,49 @@ async def start_email_migration(
@router.post("/migrate/rss", response_model=MigrateRssResponse)
-async def import_rss_feeds(
- request: MigrateRssRequest,
- db: DBSession = Depends(get_db)
-):
- """
- Importiere bestehende Google Alert RSS-Feeds.
-
- Erstellt für jede RSS-URL einen AlertSource und Topic.
- """
+async def import_rss_feeds(request: MigrateRssRequest, db: DBSession = Depends(get_db)):
+ """Importiere bestehende Google Alert RSS-Feeds."""
user_id = get_user_id_from_request()
from ..db.models import AlertTopicDB
- sources_created = 0
- topics_created = 0
+ sources_created, topics_created = 0, 0
for i, url in enumerate(request.rss_urls):
- # Label aus Request oder generieren
label = None
if request.labels and i < len(request.labels):
label = request.labels[i]
if not label:
label = f"RSS Feed {i + 1}"
- # Erstelle AlertSource
source = AlertSourceDB(
- id=str(uuid.uuid4()),
- user_id=user_id,
- source_type=FeedTypeEnum.RSS,
- original_label=label,
- rss_url=url,
- migration_mode=MigrationModeEnum.IMPORT,
- is_active=True,
+ id=str(uuid.uuid4()), user_id=user_id,
+ source_type=FeedTypeEnum.RSS, original_label=label,
+ rss_url=url, migration_mode=MigrationModeEnum.IMPORT, is_active=True,
)
db.add(source)
sources_created += 1
- # Erstelle Topic
topic = AlertTopicDB(
- id=str(uuid.uuid4()),
- user_id=user_id,
- name=label,
- description=f"Importiert aus RSS: {url[:50]}...",
- feed_url=url,
- feed_type=FeedTypeEnum.RSS,
- is_active=True,
- fetch_interval_minutes=60,
+ id=str(uuid.uuid4()), user_id=user_id,
+ name=label, description=f"Importiert aus RSS: {url[:50]}...",
+ feed_url=url, feed_type=FeedTypeEnum.RSS,
+ is_active=True, fetch_interval_minutes=60,
)
db.add(topic)
-
- # Verknüpfe Source mit Topic
source.topic_id = topic.id
topics_created += 1
db.commit()
return MigrateRssResponse(
- status="success",
- sources_created=sources_created,
- topics_created=topics_created,
+ status="success", sources_created=sources_created, topics_created=topics_created,
message=f"{sources_created} RSS-Feeds importiert. Die Alerts werden automatisch abgerufen.",
)
@router.get("/migrate/sources")
-async def list_migration_sources(
- db: DBSession = Depends(get_db)
-):
+async def list_migration_sources(db: DBSession = Depends(get_db)):
"""Liste alle Migration-Quellen des Users."""
user_id = get_user_id_from_request()
diff --git a/backend-lehrer/alerts_agent/api/wizard_models.py b/backend-lehrer/alerts_agent/api/wizard_models.py
new file mode 100644
index 0000000..6cfea14
--- /dev/null
+++ b/backend-lehrer/alerts_agent/api/wizard_models.py
@@ -0,0 +1,68 @@
+"""
+Wizard API - Request/Response Models.
+"""
+
+from typing import Optional, List, Dict, Any
+from pydantic import BaseModel, Field
+
+
+class WizardState(BaseModel):
+ """Current wizard state (response model of the wizard state endpoint)."""
+ subscription_id: Optional[str] = None
+ current_step: int = 0 # 0=not started, 1-3=steps, 4=completed
+ is_completed: bool = False
+ step_data: Dict[str, Any] = {}
+ recommended_templates: List[Dict[str, Any]] = []
+
+
+class Step1Data(BaseModel):
+ """Payload for step 1: role selection."""
+ role: str = Field(..., description="lehrkraft, schulleitung, it_beauftragte")
+
+
+class Step2Data(BaseModel):
+ """Payload for step 2: template selection (between 1 and 3 template IDs)."""
+ template_ids: List[str] = Field(..., min_length=1, max_length=3)
+
+
+class Step3Data(BaseModel):
+ """Payload for step 3: confirmation and digest settings."""
+ notification_email: Optional[str] = None
+ digest_enabled: bool = True
+ digest_frequency: str = "weekly"
+
+
+class StepResponse(BaseModel):
+ """Response returned after a wizard step has been saved."""
+ status: str
+ current_step: int
+ next_step: int
+ message: str
+ recommended_templates: List[Dict[str, Any]] = []
+
+
+class MigrateEmailRequest(BaseModel):
+ """Request body for the e-mail migration."""
+ original_label: Optional[str] = Field(default=None, description="Beschreibung des Alerts")
+
+
+class MigrateEmailResponse(BaseModel):
+ """Response of the e-mail migration."""
+ status: str
+ inbound_address: str
+ instructions: List[str]
+ source_id: str
+
+
+class MigrateRssRequest(BaseModel):
+ """Request body for the RSS import (between 1 and 20 feed URLs)."""
+ rss_urls: List[str] = Field(..., min_length=1, max_length=20)
+ labels: Optional[List[str]] = None # optional; presumably one label per URL by index - confirm against the import endpoint
+
+
+class MigrateRssResponse(BaseModel):
+ """Response of the RSS import."""
+ status: str
+ sources_created: int
+ topics_created: int
+ message: str
diff --git a/backend-lehrer/alerts_agent/processing/rule_engine.py b/backend-lehrer/alerts_agent/processing/rule_engine.py
index 1eede0e..01dc3a4 100644
--- a/backend-lehrer/alerts_agent/processing/rule_engine.py
+++ b/backend-lehrer/alerts_agent/processing/rule_engine.py
@@ -2,277 +2,49 @@
Rule Engine für Alerts Agent.
Evaluiert Regeln gegen Alert-Items und führt Aktionen aus.
-
-Regel-Struktur:
-- Bedingungen: [{field, operator, value}, ...] (AND-verknüpft)
-- Aktion: keep, drop, tag, email, webhook, slack
-- Priorität: Höhere Priorität wird zuerst evaluiert
+Batch-Verarbeitung und Action-Anwendung.
"""
-import re
+
import logging
-from dataclasses import dataclass
-from typing import List, Dict, Any, Optional, Callable
-from enum import Enum
+from typing import List, Dict, Any, Optional
from alerts_agent.db.models import AlertItemDB, AlertRuleDB, RuleActionEnum
+from .rule_models import (
+ ConditionOperator,
+ RuleCondition,
+ RuleMatch,
+ get_field_value,
+ evaluate_condition,
+ evaluate_rule,
+ evaluate_rules_for_alert,
+ create_keyword_rule,
+ create_exclusion_rule,
+ create_score_threshold_rule,
+)
+
logger = logging.getLogger(__name__)
-
-class ConditionOperator(str, Enum):
- """Operatoren für Regel-Bedingungen."""
- CONTAINS = "contains"
- NOT_CONTAINS = "not_contains"
- EQUALS = "equals"
- NOT_EQUALS = "not_equals"
- STARTS_WITH = "starts_with"
- ENDS_WITH = "ends_with"
- REGEX = "regex"
- GREATER_THAN = "gt"
- LESS_THAN = "lt"
- GREATER_EQUAL = "gte"
- LESS_EQUAL = "lte"
- IN_LIST = "in"
- NOT_IN_LIST = "not_in"
-
-
-@dataclass
-class RuleCondition:
- """Eine einzelne Regel-Bedingung."""
- field: str # "title", "snippet", "url", "source", "relevance_score"
- operator: ConditionOperator
- value: Any # str, float, list
-
- @classmethod
- def from_dict(cls, data: Dict) -> "RuleCondition":
- """Erstellt eine Bedingung aus einem Dict."""
- return cls(
- field=data.get("field", ""),
- operator=ConditionOperator(data.get("operator", data.get("op", "contains"))),
- value=data.get("value", ""),
- )
-
-
-@dataclass
-class RuleMatch:
- """Ergebnis einer Regel-Evaluierung."""
- rule_id: str
- rule_name: str
- matched: bool
- action: RuleActionEnum
- action_config: Dict[str, Any]
- conditions_met: List[str] # Welche Bedingungen haben gematched
-
-
-def get_field_value(alert: AlertItemDB, field: str) -> Any:
- """
- Extrahiert einen Feldwert aus einem Alert.
-
- Args:
- alert: Alert-Item
- field: Feldname
-
- Returns:
- Feldwert oder None
- """
- field_map = {
- "title": alert.title,
- "snippet": alert.snippet,
- "url": alert.url,
- "source": alert.source.value if alert.source else "",
- "status": alert.status.value if alert.status else "",
- "relevance_score": alert.relevance_score,
- "relevance_decision": alert.relevance_decision.value if alert.relevance_decision else "",
- "lang": alert.lang,
- "topic_id": alert.topic_id,
- }
-
- return field_map.get(field)
-
-
-def evaluate_condition(
- alert: AlertItemDB,
- condition: RuleCondition,
-) -> bool:
- """
- Evaluiert eine einzelne Bedingung gegen einen Alert.
-
- Args:
- alert: Alert-Item
- condition: Zu evaluierende Bedingung
-
- Returns:
- True wenn Bedingung erfüllt
- """
- field_value = get_field_value(alert, condition.field)
-
- if field_value is None:
- return False
-
- op = condition.operator
- target = condition.value
-
- try:
- # String-Operationen (case-insensitive)
- if isinstance(field_value, str):
- field_lower = field_value.lower()
- target_lower = str(target).lower() if isinstance(target, str) else target
-
- if op == ConditionOperator.CONTAINS:
- return target_lower in field_lower
-
- elif op == ConditionOperator.NOT_CONTAINS:
- return target_lower not in field_lower
-
- elif op == ConditionOperator.EQUALS:
- return field_lower == target_lower
-
- elif op == ConditionOperator.NOT_EQUALS:
- return field_lower != target_lower
-
- elif op == ConditionOperator.STARTS_WITH:
- return field_lower.startswith(target_lower)
-
- elif op == ConditionOperator.ENDS_WITH:
- return field_lower.endswith(target_lower)
-
- elif op == ConditionOperator.REGEX:
- try:
- return bool(re.search(str(target), field_value, re.IGNORECASE))
- except re.error:
- logger.warning(f"Invalid regex pattern: {target}")
- return False
-
- elif op == ConditionOperator.IN_LIST:
- if isinstance(target, list):
- return any(t.lower() in field_lower for t in target if isinstance(t, str))
- return False
-
- elif op == ConditionOperator.NOT_IN_LIST:
- if isinstance(target, list):
- return not any(t.lower() in field_lower for t in target if isinstance(t, str))
- return True
-
- # Numerische Operationen
- elif isinstance(field_value, (int, float)):
- target_num = float(target) if target else 0
-
- if op == ConditionOperator.EQUALS:
- return field_value == target_num
-
- elif op == ConditionOperator.NOT_EQUALS:
- return field_value != target_num
-
- elif op == ConditionOperator.GREATER_THAN:
- return field_value > target_num
-
- elif op == ConditionOperator.LESS_THAN:
- return field_value < target_num
-
- elif op == ConditionOperator.GREATER_EQUAL:
- return field_value >= target_num
-
- elif op == ConditionOperator.LESS_EQUAL:
- return field_value <= target_num
-
- except Exception as e:
- logger.error(f"Error evaluating condition: {e}")
- return False
-
- return False
-
-
-def evaluate_rule(
- alert: AlertItemDB,
- rule: AlertRuleDB,
-) -> RuleMatch:
- """
- Evaluiert eine Regel gegen einen Alert.
-
- Alle Bedingungen müssen erfüllt sein (AND-Verknüpfung).
-
- Args:
- alert: Alert-Item
- rule: Zu evaluierende Regel
-
- Returns:
- RuleMatch-Ergebnis
- """
- conditions = rule.conditions or []
- conditions_met = []
- all_matched = True
-
- for cond_dict in conditions:
- condition = RuleCondition.from_dict(cond_dict)
- if evaluate_condition(alert, condition):
- conditions_met.append(f"{condition.field} {condition.operator.value} {condition.value}")
- else:
- all_matched = False
-
- # Wenn keine Bedingungen definiert sind, matcht die Regel immer
- if not conditions:
- all_matched = True
-
- return RuleMatch(
- rule_id=rule.id,
- rule_name=rule.name,
- matched=all_matched,
- action=rule.action_type,
- action_config=rule.action_config or {},
- conditions_met=conditions_met,
- )
-
-
-def evaluate_rules_for_alert(
- alert: AlertItemDB,
- rules: List[AlertRuleDB],
-) -> Optional[RuleMatch]:
- """
- Evaluiert alle Regeln gegen einen Alert und gibt den ersten Match zurück.
-
- Regeln werden nach Priorität (absteigend) evaluiert.
-
- Args:
- alert: Alert-Item
- rules: Liste von Regeln (sollte bereits nach Priorität sortiert sein)
-
- Returns:
- Erster RuleMatch oder None
- """
- for rule in rules:
- if not rule.is_active:
- continue
-
- # Topic-Filter: Regel gilt nur für bestimmtes Topic
- if rule.topic_id and rule.topic_id != alert.topic_id:
- continue
-
- match = evaluate_rule(alert, rule)
-
- if match.matched:
- logger.debug(
- f"Rule '{rule.name}' matched alert '{alert.id[:8]}': "
- f"{match.conditions_met}"
- )
- return match
-
- return None
+# Re-export for backward compatibility
+__all__ = [
+ "ConditionOperator",
+ "RuleCondition",
+ "RuleMatch",
+ "get_field_value",
+ "evaluate_condition",
+ "evaluate_rule",
+ "evaluate_rules_for_alert",
+ "RuleEngine",
+ "create_keyword_rule",
+ "create_exclusion_rule",
+ "create_score_threshold_rule",
+]
class RuleEngine:
- """
- Rule Engine für Batch-Verarbeitung von Alerts.
-
- Verwendet für das Scoring von mehreren Alerts gleichzeitig.
- """
+ """Rule Engine für Batch-Verarbeitung von Alerts."""
def __init__(self, db_session):
- """
- Initialisiert die Rule Engine.
-
- Args:
- db_session: SQLAlchemy Session
- """
self.db = db_session
self._rules_cache: Optional[List[AlertRuleDB]] = None
@@ -282,42 +54,19 @@ class RuleEngine:
from alerts_agent.db.repository import RuleRepository
repo = RuleRepository(self.db)
self._rules_cache = repo.get_active()
-
return self._rules_cache
def clear_cache(self) -> None:
"""Leert den Regel-Cache."""
self._rules_cache = None
- def process_alert(
- self,
- alert: AlertItemDB,
- ) -> Optional[RuleMatch]:
- """
- Verarbeitet einen Alert mit allen aktiven Regeln.
-
- Args:
- alert: Alert-Item
-
- Returns:
- RuleMatch wenn eine Regel matcht, sonst None
- """
+ def process_alert(self, alert: AlertItemDB) -> Optional[RuleMatch]:
+ """Verarbeitet einen Alert mit allen aktiven Regeln."""
rules = self._get_active_rules()
return evaluate_rules_for_alert(alert, rules)
- def process_alerts(
- self,
- alerts: List[AlertItemDB],
- ) -> Dict[str, RuleMatch]:
- """
- Verarbeitet mehrere Alerts mit allen aktiven Regeln.
-
- Args:
- alerts: Liste von Alert-Items
-
- Returns:
- Dict von alert_id -> RuleMatch (nur für gematschte Alerts)
- """
+ def process_alerts(self, alerts: List[AlertItemDB]) -> Dict[str, RuleMatch]:
+ """Verarbeitet mehrere Alerts mit allen aktiven Regeln."""
rules = self._get_active_rules()
results = {}
@@ -328,21 +77,8 @@ class RuleEngine:
return results
- def apply_rule_actions(
- self,
- alert: AlertItemDB,
- match: RuleMatch,
- ) -> Dict[str, Any]:
- """
- Wendet die Regel-Aktion auf einen Alert an.
-
- Args:
- alert: Alert-Item
- match: RuleMatch mit Aktionsinformationen
-
- Returns:
- Dict mit Ergebnis der Aktion
- """
+ def apply_rule_actions(self, alert: AlertItemDB, match: RuleMatch) -> Dict[str, Any]:
+ """Wendet die Regel-Aktion auf einen Alert an."""
from alerts_agent.db.repository import AlertItemRepository, RuleRepository
alert_repo = AlertItemRepository(self.db)
@@ -350,36 +86,26 @@ class RuleEngine:
action = match.action
config = match.action_config
-
result = {"action": action.value, "success": False}
try:
if action == RuleActionEnum.KEEP:
- # Alert als KEEP markieren
alert_repo.update_scoring(
- alert_id=alert.id,
- score=1.0,
- decision="KEEP",
- reasons=["rule_match"],
- summary=f"Matched rule: {match.rule_name}",
+ alert_id=alert.id, score=1.0, decision="KEEP",
+ reasons=["rule_match"], summary=f"Matched rule: {match.rule_name}",
model="rule_engine",
)
result["success"] = True
elif action == RuleActionEnum.DROP:
- # Alert als DROP markieren
alert_repo.update_scoring(
- alert_id=alert.id,
- score=0.0,
- decision="DROP",
- reasons=["rule_match"],
- summary=f"Dropped by rule: {match.rule_name}",
+ alert_id=alert.id, score=0.0, decision="DROP",
+ reasons=["rule_match"], summary=f"Dropped by rule: {match.rule_name}",
model="rule_engine",
)
result["success"] = True
elif action == RuleActionEnum.TAG:
- # Tags hinzufügen
tags = config.get("tags", [])
if tags:
existing_tags = alert.user_tags or []
@@ -389,27 +115,20 @@ class RuleEngine:
result["success"] = True
elif action == RuleActionEnum.EMAIL:
- # E-Mail-Benachrichtigung senden
- # Wird von Actions-Modul behandelt
result["email_config"] = config
result["success"] = True
- result["deferred"] = True # Wird später gesendet
+ result["deferred"] = True
elif action == RuleActionEnum.WEBHOOK:
- # Webhook aufrufen
- # Wird von Actions-Modul behandelt
result["webhook_config"] = config
result["success"] = True
result["deferred"] = True
elif action == RuleActionEnum.SLACK:
- # Slack-Nachricht senden
- # Wird von Actions-Modul behandelt
result["slack_config"] = config
result["success"] = True
result["deferred"] = True
- # Match-Count erhöhen
rule_repo.increment_match_count(match.rule_id)
except Exception as e:
@@ -417,96 +136,3 @@ class RuleEngine:
result["error"] = str(e)
return result
-
-
-# Convenience-Funktionen für einfache Nutzung
-def create_keyword_rule(
- name: str,
- keywords: List[str],
- action: str = "keep",
- field: str = "title",
-) -> Dict:
- """
- Erstellt eine Keyword-basierte Regel.
-
- Args:
- name: Regelname
- keywords: Liste von Keywords (OR-verknüpft über IN_LIST)
- action: Aktion (keep, drop, tag)
- field: Feld zum Prüfen (title, snippet, url)
-
- Returns:
- Regel-Definition als Dict
- """
- return {
- "name": name,
- "conditions": [
- {
- "field": field,
- "operator": "in",
- "value": keywords,
- }
- ],
- "action_type": action,
- "action_config": {},
- }
-
-
-def create_exclusion_rule(
- name: str,
- excluded_terms: List[str],
- field: str = "title",
-) -> Dict:
- """
- Erstellt eine Ausschluss-Regel.
-
- Args:
- name: Regelname
- excluded_terms: Liste von auszuschließenden Begriffen
- field: Feld zum Prüfen
-
- Returns:
- Regel-Definition als Dict
- """
- return {
- "name": name,
- "conditions": [
- {
- "field": field,
- "operator": "in",
- "value": excluded_terms,
- }
- ],
- "action_type": "drop",
- "action_config": {},
- }
-
-
-def create_score_threshold_rule(
- name: str,
- min_score: float,
- action: str = "keep",
-) -> Dict:
- """
- Erstellt eine Score-basierte Regel.
-
- Args:
- name: Regelname
- min_score: Mindest-Score
- action: Aktion bei Erreichen des Scores
-
- Returns:
- Regel-Definition als Dict
- """
- return {
- "name": name,
- "conditions": [
- {
- "field": "relevance_score",
- "operator": "gte",
- "value": min_score,
- }
- ],
- "action_type": action,
- "action_config": {},
- }
diff --git a/backend-lehrer/alerts_agent/processing/rule_models.py b/backend-lehrer/alerts_agent/processing/rule_models.py
new file mode 100644
index 0000000..974af96
--- /dev/null
+++ b/backend-lehrer/alerts_agent/processing/rule_models.py
@@ -0,0 +1,206 @@
+"""
+Rule Engine - Models, Condition Evaluation, and Convenience Functions.
+
+Datenmodelle und Evaluierungs-Logik fuer Alert-Regeln.
+"""
+
+import re
+import logging
+from dataclasses import dataclass
+from typing import List, Dict, Any, Optional
+from enum import Enum
+
+from alerts_agent.db.models import AlertItemDB, AlertRuleDB, RuleActionEnum
+
+logger = logging.getLogger(__name__)
+
+
+class ConditionOperator(str, Enum):
+ """Operators available in rule conditions; the values are the names used in condition dicts."""
+ CONTAINS = "contains"
+ NOT_CONTAINS = "not_contains"
+ EQUALS = "equals"
+ NOT_EQUALS = "not_equals"
+ STARTS_WITH = "starts_with"
+ ENDS_WITH = "ends_with"
+ REGEX = "regex"
+ GREATER_THAN = "gt"
+ LESS_THAN = "lt"
+ GREATER_EQUAL = "gte"
+ LESS_EQUAL = "lte"
+ IN_LIST = "in" # evaluate_condition: any string entry of the list is a substring of the field
+ NOT_IN_LIST = "not_in" # evaluate_condition: no string entry of the list is a substring of the field
+
+
+@dataclass
+class RuleCondition:
+ """A single rule condition: one alert field compared against a value."""
+ field: str # name resolved by get_field_value, e.g. "title", "snippet", "url", "relevance_score"
+ operator: ConditionOperator
+ value: Any # comparison target; evaluate_condition handles str, numbers and lists
+
+ @classmethod
+ def from_dict(cls, data: Dict) -> "RuleCondition": # builds a condition from a plain dict
+ return cls(
+ field=data.get("field", ""),
+ operator=ConditionOperator(data.get("operator", data.get("op", "contains"))), # "op" is an alias; unknown names raise ValueError
+ value=data.get("value", ""),
+ )
+
+
+@dataclass
+class RuleMatch:
+ """Result of evaluating one rule against one alert."""
+ rule_id: str
+ rule_name: str
+ matched: bool
+ action: RuleActionEnum
+ action_config: Dict[str, Any]
+ conditions_met: List[str] # "<field> <operator> <value>" text for each condition that held
+
+
+def get_field_value(alert: AlertItemDB, field: str) -> Any:
+    """Return the value of the named alert field, or None for an unknown name."""
+    def enum_text(member: Any) -> str:
+        # Enum-backed attributes are exposed through .value; unset ones as "".
+        return member.value if member else ""
+
+    exposed = dict(
+        title=alert.title, snippet=alert.snippet, url=alert.url,
+        source=enum_text(alert.source), status=enum_text(alert.status),
+        relevance_score=alert.relevance_score,
+        relevance_decision=enum_text(alert.relevance_decision),
+        lang=alert.lang, topic_id=alert.topic_id,
+    )
+    return exposed.get(field)
+
+
+def evaluate_condition(alert: AlertItemDB, condition: RuleCondition) -> bool:
+    """Return True when `condition` holds for `alert`; every failure yields False.
+
+    String fields compare case-insensitively; a numeric target (e.g. a JSON number)
+    is compared by its text form. IN_LIST/NOT_IN_LIST test whether any string entry
+    of the target list is a substring of the field. Numeric fields support only
+    the equality and ordering operators. A field value of None never matches.
+    """
+    field_value = get_field_value(alert, condition.field)
+    if field_value is None:
+        return False
+    op = condition.operator
+    target = condition.value
+    try:
+        if isinstance(field_value, str):
+            field_lower = field_value.lower()
+            # Compare plain numbers by their text; an int target used to raise TypeError -> False.
+            if isinstance(target, (int, float)) and not isinstance(target, bool):
+                target = str(target)
+            target_lower = target.lower() if isinstance(target, str) else target
+            if op == ConditionOperator.CONTAINS:
+                return target_lower in field_lower
+            elif op == ConditionOperator.NOT_CONTAINS:
+                return target_lower not in field_lower
+            elif op == ConditionOperator.EQUALS:
+                return field_lower == target_lower
+            elif op == ConditionOperator.NOT_EQUALS:
+                return field_lower != target_lower
+            elif op == ConditionOperator.STARTS_WITH:
+                return field_lower.startswith(target_lower)
+            elif op == ConditionOperator.ENDS_WITH:
+                return field_lower.endswith(target_lower)
+            elif op == ConditionOperator.REGEX:
+                try:
+                    return bool(re.search(str(target), field_value, re.IGNORECASE))
+                except re.error:
+                    logger.warning(f"Invalid regex pattern: {target}")
+                    return False
+            elif op == ConditionOperator.IN_LIST:
+                # A non-list target never matches.
+                return isinstance(target, list) and any(t.lower() in field_lower for t in target if isinstance(t, str))
+            elif op == ConditionOperator.NOT_IN_LIST:
+                return not (isinstance(target, list) and any(t.lower() in field_lower for t in target if isinstance(t, str)))
+        elif isinstance(field_value, (int, float)):
+            target_num = float(target) if target else 0  # a falsy target compares as 0
+            if op == ConditionOperator.EQUALS:
+                return field_value == target_num
+            elif op == ConditionOperator.NOT_EQUALS:
+                return field_value != target_num
+            elif op == ConditionOperator.GREATER_THAN:
+                return field_value > target_num
+            elif op == ConditionOperator.LESS_THAN:
+                return field_value < target_num
+            elif op == ConditionOperator.GREATER_EQUAL:
+                return field_value >= target_num
+            elif op == ConditionOperator.LESS_EQUAL:
+                return field_value <= target_num
+    except Exception as e:
+        logger.error(f"Error evaluating condition: {e}")
+        return False
+    return False
+
+
+def evaluate_rule(alert: AlertItemDB, rule: AlertRuleDB) -> RuleMatch:
+    """Evaluate all conditions of `rule` against `alert` (AND semantics).
+
+    A rule without conditions always matches. A malformed condition (unknown operator,
+    non-dict entry) counts as not met instead of raising out of the rule evaluation.
+    """
+    conditions_met = []
+    all_matched = True
+    for cond_dict in rule.conditions or []:
+        try:
+            condition = RuleCondition.from_dict(cond_dict)
+        except (ValueError, AttributeError, TypeError) as e:
+            logger.warning(f"Skipping invalid condition in rule '{rule.name}': {e}")
+            all_matched = False
+            continue
+        if evaluate_condition(alert, condition):
+            conditions_met.append(f"{condition.field} {condition.operator.value} {condition.value}")
+        else:
+            all_matched = False
+    return RuleMatch(rule_id=rule.id, rule_name=rule.name, matched=all_matched, action=rule.action_type,
+                     action_config=rule.action_config or {}, conditions_met=conditions_met)
+
+
+def evaluate_rules_for_alert(alert: AlertItemDB, rules: List[AlertRuleDB]) -> Optional[RuleMatch]:
+    """Return the match of the first applicable rule that fires, or None.
+
+    Rules are tried in the given order; inactive rules and rules bound to a
+    different topic than the alert are skipped.
+    """
+    for candidate in rules:
+        if not candidate.is_active or (candidate.topic_id and candidate.topic_id != alert.topic_id):
+            continue
+        outcome = evaluate_rule(alert, candidate)
+        if outcome.matched:
+            logger.debug(f"Rule '{candidate.name}' matched alert '{alert.id[:8]}': {outcome.conditions_met}")
+            return outcome
+    return None
+
+
+# Convenience-Funktionen
+
+def create_keyword_rule(name: str, keywords: List[str], action: str = "keep", field: str = "title") -> Dict:
+    """Build a rule definition whose single "in" condition checks `field` against `keywords`."""
+    keyword_condition = {"field": field, "operator": "in", "value": keywords}
+    rule: Dict = {"name": name, "conditions": [keyword_condition]}
+    rule["action_type"] = action
+    rule["action_config"] = {}
+    return rule
+
+
+def create_exclusion_rule(name: str, excluded_terms: List[str], field: str = "title") -> Dict:
+    """Build a "drop" rule definition whose single "in" condition checks `field` against `excluded_terms`."""
+    exclusion_condition = {"field": field, "operator": "in", "value": excluded_terms}
+    rule: Dict = {"name": name, "conditions": [exclusion_condition]}
+    rule["action_type"] = "drop"
+    rule["action_config"] = {}
+    return rule
+
+
+def create_score_threshold_rule(name: str, min_score: float, action: str = "keep") -> Dict:
+    """Build a rule definition with a single relevance_score >= `min_score` condition."""
+    threshold_condition = {"field": "relevance_score", "operator": "gte", "value": min_score}
+    rule: Dict = {"name": name, "conditions": [threshold_condition]}
+    rule["action_type"] = action
+    rule["action_config"] = {}
+    return rule
diff --git a/backend-lehrer/auth/__init__.py b/backend-lehrer/auth/__init__.py
index b56b38b..a3778a4 100644
--- a/backend-lehrer/auth/__init__.py
+++ b/backend-lehrer/auth/__init__.py
@@ -4,15 +4,11 @@ BreakPilot Authentication Module
Hybrid authentication supporting both Keycloak and local JWT tokens.
"""
-from .keycloak_auth import (
+from .keycloak_models import (
# Config
KeycloakConfig,
KeycloakUser,
- # Authenticators
- KeycloakAuthenticator,
- HybridAuthenticator,
-
# Exceptions
KeycloakAuthError,
TokenExpiredError,
@@ -21,6 +17,14 @@ from .keycloak_auth import (
# Factory functions
get_keycloak_config_from_env,
+)
+
+from .keycloak_auth import (
+ # Authenticators
+ KeycloakAuthenticator,
+ HybridAuthenticator,
+
+ # Factory functions
get_authenticator,
get_auth,
diff --git a/backend-lehrer/auth/keycloak_auth.py b/backend-lehrer/auth/keycloak_auth.py
index 3449169..a8d8e71 100644
--- a/backend-lehrer/auth/keycloak_auth.py
+++ b/backend-lehrer/auth/keycloak_auth.py
@@ -14,110 +14,24 @@ import os
import httpx
import jwt
from jwt import PyJWKClient
-from datetime import datetime, timezone
-from typing import Optional, Dict, Any, List
-from dataclasses import dataclass
-from functools import lru_cache
import logging
+from typing import Optional, Dict, Any
+
+from .keycloak_models import (
+ KeycloakConfig,
+ KeycloakUser,
+ KeycloakAuthError,
+ TokenExpiredError,
+ TokenInvalidError,
+ KeycloakConfigError,
+ get_keycloak_config_from_env,
+)
logger = logging.getLogger(__name__)
-@dataclass
-class KeycloakConfig:
- """Keycloak connection configuration."""
- server_url: str
- realm: str
- client_id: str
- client_secret: Optional[str] = None
- verify_ssl: bool = True
-
- @property
- def issuer_url(self) -> str:
- return f"{self.server_url}/realms/{self.realm}"
-
- @property
- def jwks_url(self) -> str:
- return f"{self.issuer_url}/protocol/openid-connect/certs"
-
- @property
- def token_url(self) -> str:
- return f"{self.issuer_url}/protocol/openid-connect/token"
-
- @property
- def userinfo_url(self) -> str:
- return f"{self.issuer_url}/protocol/openid-connect/userinfo"
-
-
-@dataclass
-class KeycloakUser:
- """User information extracted from Keycloak token."""
- user_id: str # Keycloak subject (sub)
- email: str
- email_verified: bool
- name: Optional[str]
- given_name: Optional[str]
- family_name: Optional[str]
- realm_roles: List[str] # Keycloak realm roles
- client_roles: Dict[str, List[str]] # Client-specific roles
- groups: List[str] # Keycloak groups
- tenant_id: Optional[str] # Custom claim for school/tenant
- raw_claims: Dict[str, Any] # All claims for debugging
-
- def has_realm_role(self, role: str) -> bool:
- """Check if user has a specific realm role."""
- return role in self.realm_roles
-
- def has_client_role(self, client_id: str, role: str) -> bool:
- """Check if user has a specific client role."""
- client_roles = self.client_roles.get(client_id, [])
- return role in client_roles
-
- def is_admin(self) -> bool:
- """Check if user has admin role."""
- return self.has_realm_role("admin") or self.has_realm_role("schul_admin")
-
- def is_teacher(self) -> bool:
- """Check if user is a teacher."""
- return self.has_realm_role("teacher") or self.has_realm_role("lehrer")
-
-
-class KeycloakAuthError(Exception):
- """Base exception for Keycloak authentication errors."""
- pass
-
-
-class TokenExpiredError(KeycloakAuthError):
- """Token has expired."""
- pass
-
-
-class TokenInvalidError(KeycloakAuthError):
- """Token is invalid."""
- pass
-
-
-class KeycloakConfigError(KeycloakAuthError):
- """Keycloak configuration error."""
- pass
-
-
class KeycloakAuthenticator:
- """
- Validates JWT tokens against Keycloak.
-
- Usage:
- config = KeycloakConfig(
- server_url="https://keycloak.example.com",
- realm="breakpilot",
- client_id="breakpilot-backend"
- )
- auth = KeycloakAuthenticator(config)
-
- user = await auth.validate_token(token)
- if user.is_teacher():
- # Grant access
- """
+ """Validates JWT tokens against Keycloak."""
def __init__(self, config: KeycloakConfig):
self.config = config
@@ -126,64 +40,29 @@ class KeycloakAuthenticator:
@property
def jwks_client(self) -> PyJWKClient:
- """Lazy-load JWKS client."""
if self._jwks_client is None:
- self._jwks_client = PyJWKClient(
- self.config.jwks_url,
- cache_keys=True,
- lifespan=3600 # Cache keys for 1 hour
- )
+ self._jwks_client = PyJWKClient(self.config.jwks_url, cache_keys=True, lifespan=3600)
return self._jwks_client
async def get_http_client(self) -> httpx.AsyncClient:
- """Get or create async HTTP client."""
if self._http_client is None or self._http_client.is_closed:
- self._http_client = httpx.AsyncClient(
- verify=self.config.verify_ssl,
- timeout=30.0
- )
+ self._http_client = httpx.AsyncClient(verify=self.config.verify_ssl, timeout=30.0)
return self._http_client
async def close(self):
- """Close HTTP client."""
if self._http_client and not self._http_client.is_closed:
await self._http_client.aclose()
def validate_token_sync(self, token: str) -> KeycloakUser:
- """
- Synchronously validate a JWT token against Keycloak JWKS.
-
- Args:
- token: The JWT access token
-
- Returns:
- KeycloakUser with extracted claims
-
- Raises:
- TokenExpiredError: If token has expired
- TokenInvalidError: If token signature is invalid
- """
+ """Synchronously validate a JWT token against Keycloak JWKS."""
try:
- # Get signing key from JWKS
signing_key = self.jwks_client.get_signing_key_from_jwt(token)
-
- # Decode and validate token
payload = jwt.decode(
- token,
- signing_key.key,
- algorithms=["RS256"],
- audience=self.config.client_id,
- issuer=self.config.issuer_url,
- options={
- "verify_exp": True,
- "verify_iat": True,
- "verify_aud": True,
- "verify_iss": True
- }
+ token, signing_key.key, algorithms=["RS256"],
+ audience=self.config.client_id, issuer=self.config.issuer_url,
+ options={"verify_exp": True, "verify_iat": True, "verify_aud": True, "verify_iss": True}
)
-
return self._extract_user(payload)
-
except jwt.ExpiredSignatureError:
raise TokenExpiredError("Token has expired")
except jwt.InvalidAudienceError:
@@ -197,27 +76,14 @@ class KeycloakAuthenticator:
raise TokenInvalidError(f"Token validation failed: {e}")
async def validate_token(self, token: str) -> KeycloakUser:
- """
- Asynchronously validate a JWT token.
-
- Note: JWKS fetching is synchronous due to PyJWKClient limitations,
- but this wrapper allows async context usage.
- """
+ """Asynchronously validate a JWT token."""
return self.validate_token_sync(token)
async def get_userinfo(self, token: str) -> Dict[str, Any]:
- """
- Fetch user info from Keycloak userinfo endpoint.
-
- This provides additional user claims not in the access token.
- """
+ """Fetch user info from Keycloak userinfo endpoint."""
client = await self.get_http_client()
-
try:
- response = await client.get(
- self.config.userinfo_url,
- headers={"Authorization": f"Bearer {token}"}
- )
+ response = await client.get(self.config.userinfo_url, headers={"Authorization": f"Bearer {token}"})
response.raise_for_status()
return response.json()
except httpx.HTTPStatusError as e:
@@ -227,94 +93,51 @@ class KeycloakAuthenticator:
def _extract_user(self, payload: Dict[str, Any]) -> KeycloakUser:
"""Extract KeycloakUser from JWT payload."""
-
- # Extract realm roles
realm_access = payload.get("realm_access", {})
realm_roles = realm_access.get("roles", [])
- # Extract client roles
resource_access = payload.get("resource_access", {})
client_roles = {}
for client_id, access in resource_access.items():
client_roles[client_id] = access.get("roles", [])
- # Extract groups
groups = payload.get("groups", [])
-
- # Extract custom tenant claim (if configured in Keycloak)
tenant_id = payload.get("tenant_id") or payload.get("school_id")
return KeycloakUser(
- user_id=payload.get("sub", ""),
- email=payload.get("email", ""),
+ user_id=payload.get("sub", ""), email=payload.get("email", ""),
email_verified=payload.get("email_verified", False),
- name=payload.get("name"),
- given_name=payload.get("given_name"),
+ name=payload.get("name"), given_name=payload.get("given_name"),
family_name=payload.get("family_name"),
- realm_roles=realm_roles,
- client_roles=client_roles,
- groups=groups,
- tenant_id=tenant_id,
- raw_claims=payload
+ realm_roles=realm_roles, client_roles=client_roles,
+ groups=groups, tenant_id=tenant_id, raw_claims=payload
)
-# =============================================
-# HYBRID AUTH: Keycloak + Local JWT
-# =============================================
-
class HybridAuthenticator:
- """
- Hybrid authenticator supporting both Keycloak and local JWT tokens.
+ """Hybrid authenticator supporting both Keycloak and local JWT tokens."""
- This allows gradual migration from local JWT to Keycloak:
- 1. Development: Use local JWT (fast, no external dependencies)
- 2. Production: Use Keycloak for full IAM capabilities
-
- Token type detection:
- - Keycloak tokens: Have 'iss' claim matching Keycloak URL
- - Local tokens: Have 'iss' claim as 'breakpilot' or no 'iss'
- """
-
- def __init__(
- self,
- keycloak_config: Optional[KeycloakConfig] = None,
- local_jwt_secret: Optional[str] = None,
- environment: str = "development"
- ):
+ def __init__(self, keycloak_config=None, local_jwt_secret=None, environment="development"):
self.environment = environment
self.keycloak_enabled = keycloak_config is not None
self.local_jwt_secret = local_jwt_secret
-
- if keycloak_config:
- self.keycloak_auth = KeycloakAuthenticator(keycloak_config)
- else:
- self.keycloak_auth = None
+ self.keycloak_auth = KeycloakAuthenticator(keycloak_config) if keycloak_config else None
async def validate_token(self, token: str) -> Dict[str, Any]:
- """
- Validate token using appropriate method.
-
- Returns a unified user dict compatible with existing code.
- """
+ """Validate token using appropriate method."""
if not token:
raise TokenInvalidError("No token provided")
- # Try to peek at the token to determine type
try:
- # Decode without verification to check issuer
unverified = jwt.decode(token, options={"verify_signature": False})
issuer = unverified.get("iss", "")
except jwt.InvalidTokenError:
raise TokenInvalidError("Cannot decode token")
- # Check if it's a Keycloak token
if self.keycloak_auth and self.keycloak_auth.config.issuer_url in issuer:
- # Validate with Keycloak
kc_user = await self.keycloak_auth.validate_token(token)
return self._keycloak_user_to_dict(kc_user)
- # Fall back to local JWT validation
if self.local_jwt_secret:
return self._validate_local_token(token)
@@ -326,13 +149,7 @@ class HybridAuthenticator:
raise KeycloakConfigError("Local JWT secret not configured")
try:
- payload = jwt.decode(
- token,
- self.local_jwt_secret,
- algorithms=["HS256"]
- )
-
- # Map local token claims to unified format
+ payload = jwt.decode(token, self.local_jwt_secret, algorithms=["HS256"])
return {
"user_id": payload.get("user_id", payload.get("sub", "")),
"email": payload.get("email", ""),
@@ -349,7 +166,6 @@ class HybridAuthenticator:
def _keycloak_user_to_dict(self, user: KeycloakUser) -> Dict[str, Any]:
"""Convert KeycloakUser to dict compatible with existing code."""
- # Map Keycloak roles to our role system
role = "user"
if user.is_admin():
role = "admin"
@@ -357,20 +173,15 @@ class HybridAuthenticator:
role = "teacher"
return {
- "user_id": user.user_id,
- "email": user.email,
+ "user_id": user.user_id, "email": user.email,
"name": user.name or f"{user.given_name or ''} {user.family_name or ''}".strip(),
- "role": role,
- "realm_roles": user.realm_roles,
- "client_roles": user.client_roles,
- "groups": user.groups,
- "tenant_id": user.tenant_id,
- "email_verified": user.email_verified,
+ "role": role, "realm_roles": user.realm_roles,
+ "client_roles": user.client_roles, "groups": user.groups,
+ "tenant_id": user.tenant_id, "email_verified": user.email_verified,
"auth_method": "keycloak"
}
async def close(self):
- """Cleanup resources."""
if self.keycloak_auth:
await self.keycloak_auth.close()
@@ -379,57 +190,17 @@ class HybridAuthenticator:
# FACTORY FUNCTIONS
# =============================================
-def get_keycloak_config_from_env() -> Optional[KeycloakConfig]:
- """
- Create KeycloakConfig from environment variables.
-
- Required env vars:
- - KEYCLOAK_SERVER_URL: e.g., https://keycloak.breakpilot.app
- - KEYCLOAK_REALM: e.g., breakpilot
- - KEYCLOAK_CLIENT_ID: e.g., breakpilot-backend
-
- Optional:
- - KEYCLOAK_CLIENT_SECRET: For confidential clients
- - KEYCLOAK_VERIFY_SSL: Default true
- """
- server_url = os.environ.get("KEYCLOAK_SERVER_URL")
- realm = os.environ.get("KEYCLOAK_REALM")
- client_id = os.environ.get("KEYCLOAK_CLIENT_ID")
-
- if not all([server_url, realm, client_id]):
- logger.info("Keycloak not configured, using local JWT only")
- return None
-
- return KeycloakConfig(
- server_url=server_url,
- realm=realm,
- client_id=client_id,
- client_secret=os.environ.get("KEYCLOAK_CLIENT_SECRET"),
- verify_ssl=os.environ.get("KEYCLOAK_VERIFY_SSL", "true").lower() == "true"
- )
-
-
def get_authenticator() -> HybridAuthenticator:
- """
- Get configured authenticator instance.
-
- Uses environment variables to determine configuration.
- """
+ """Get configured authenticator instance."""
keycloak_config = get_keycloak_config_from_env()
-
- # JWT_SECRET is required - no default fallback in production
jwt_secret = os.environ.get("JWT_SECRET")
environment = os.environ.get("ENVIRONMENT", "development")
if not jwt_secret and environment == "production":
- raise KeycloakConfigError(
- "JWT_SECRET environment variable is required in production"
- )
+ raise KeycloakConfigError("JWT_SECRET environment variable is required in production")
return HybridAuthenticator(
- keycloak_config=keycloak_config,
- local_jwt_secret=jwt_secret,
- environment=environment
+ keycloak_config=keycloak_config, local_jwt_secret=jwt_secret, environment=environment
)
@@ -439,7 +210,6 @@ def get_authenticator() -> HybridAuthenticator:
from fastapi import Request, HTTPException, Depends
-# Global authenticator instance (lazy-initialized)
_authenticator: Optional[HybridAuthenticator] = None
@@ -452,26 +222,16 @@ def get_auth() -> HybridAuthenticator:
async def get_current_user(request: Request) -> Dict[str, Any]:
- """
- FastAPI dependency to get current authenticated user.
-
- Usage:
- @app.get("/api/protected")
- async def protected_endpoint(user: dict = Depends(get_current_user)):
- return {"user_id": user["user_id"]}
- """
+ """FastAPI dependency to get current authenticated user."""
auth_header = request.headers.get("authorization", "")
if not auth_header.startswith("Bearer "):
- # Check for development mode
environment = os.environ.get("ENVIRONMENT", "development")
if environment == "development":
- # Return demo user in development without token
return {
"user_id": "10000000-0000-0000-0000-000000000024",
"email": "demo@breakpilot.app",
- "role": "admin",
- "realm_roles": ["admin"],
+ "role": "admin", "realm_roles": ["admin"],
"tenant_id": "a0000000-0000-0000-0000-000000000001",
"auth_method": "development_bypass"
}
@@ -492,24 +252,11 @@ async def get_current_user(request: Request) -> Dict[str, Any]:
async def require_role(required_role: str):
- """
- FastAPI dependency factory for role-based access.
-
- Usage:
- @app.get("/api/admin-only")
- async def admin_endpoint(user: dict = Depends(require_role("admin"))):
- return {"message": "Admin access granted"}
- """
+ """FastAPI dependency factory for role-based access."""
async def role_checker(user: dict = Depends(get_current_user)) -> dict:
user_role = user.get("role", "user")
realm_roles = user.get("realm_roles", [])
-
if user_role == required_role or required_role in realm_roles:
return user
-
- raise HTTPException(
- status_code=403,
- detail=f"Role '{required_role}' required"
- )
-
+ raise HTTPException(status_code=403, detail=f"Role '{required_role}' required")
return role_checker
diff --git a/backend-lehrer/auth/keycloak_models.py b/backend-lehrer/auth/keycloak_models.py
new file mode 100644
index 0000000..31efe9d
--- /dev/null
+++ b/backend-lehrer/auth/keycloak_models.py
@@ -0,0 +1,104 @@
+"""
+Keycloak Authentication - Models, Config, and Exceptions.
+"""
+
+import os
+import logging
+from typing import Optional, Dict, Any, List
+from dataclasses import dataclass
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class KeycloakConfig:
+    """Keycloak connection configuration.
+
+    Stores the server URL, realm and client credentials, and derives the
+    realm issuer URL plus the OpenID Connect endpoint URLs from them.
+    """
+    server_url: str
+    realm: str
+    client_id: str
+    client_secret: Optional[str] = None
+    verify_ssl: bool = True
+
+    @property
+    def issuer_url(self) -> str:
+        """Realm issuer URL: ``<server_url>/realms/<realm>``."""
+        # Strip trailing slashes so a server URL such as "https://host/" does
+        # not produce "https://host//realms/...", which would never equal the
+        # issuer claim of a token and would break every derived endpoint URL.
+        return f"{self.server_url.rstrip('/')}/realms/{self.realm}"
+
+    @property
+    def jwks_url(self) -> str:
+        """URL of the realm's signing-key (JWKS) endpoint."""
+        return f"{self.issuer_url}/protocol/openid-connect/certs"
+
+    @property
+    def token_url(self) -> str:
+        """URL of the realm's token endpoint."""
+        return f"{self.issuer_url}/protocol/openid-connect/token"
+
+    @property
+    def userinfo_url(self) -> str:
+        """URL of the realm's userinfo endpoint."""
+        return f"{self.issuer_url}/protocol/openid-connect/userinfo"
+
+
+@dataclass
+class KeycloakUser:
+    """User information extracted from a Keycloak token."""
+    user_id: str  # Keycloak subject (sub)
+    email: str
+    email_verified: bool
+    name: Optional[str]
+    given_name: Optional[str]
+    family_name: Optional[str]
+    realm_roles: List[str]  # Keycloak realm roles
+    client_roles: Dict[str, List[str]]  # Client-specific roles, keyed by client id
+    groups: List[str]  # Keycloak groups
+    tenant_id: Optional[str]  # Custom claim for school/tenant
+    raw_claims: Dict[str, Any]  # All claims, kept for debugging
+
+    def has_realm_role(self, role: str) -> bool:
+        """Check whether ``role`` is one of the user's realm roles."""
+        return role in self.realm_roles
+
+    def has_client_role(self, client_id: str, role: str) -> bool:
+        """Check whether ``role`` is granted for the given client; unknown clients have no roles."""
+        return role in self.client_roles.get(client_id, [])
+
+    def is_admin(self) -> bool:
+        """True when the user holds the "admin" or "schul_admin" realm role."""
+        return any(self.has_realm_role(candidate) for candidate in ("admin", "schul_admin"))
+
+    def is_teacher(self) -> bool:
+        """True when the user holds the "teacher" or "lehrer" realm role."""
+        return any(self.has_realm_role(candidate) for candidate in ("teacher", "lehrer"))
+
+
+class KeycloakAuthError(Exception):
+    """Base exception for Keycloak authentication errors."""
+
+
+class TokenExpiredError(KeycloakAuthError):
+    """Raised when a token has expired."""
+
+
+class TokenInvalidError(KeycloakAuthError):
+    """Raised when a token is invalid."""
+
+
+class KeycloakConfigError(KeycloakAuthError):
+    """Raised for Keycloak configuration errors."""
+
+
+def get_keycloak_config_from_env() -> Optional[KeycloakConfig]:
+    """Create a KeycloakConfig from ``KEYCLOAK_*`` environment variables.
+
+    Returns None (after an info log line) unless KEYCLOAK_SERVER_URL,
+    KEYCLOAK_REALM and KEYCLOAK_CLIENT_ID are all set to non-empty values.
+    KEYCLOAK_CLIENT_SECRET is optional. KEYCLOAK_VERIFY_SSL defaults to
+    "true"; verification is enabled only when the value equals "true"
+    case-insensitively.
+    """
+    env = os.environ
+    server_url, realm, client_id = (
+        env.get(key) for key in ("KEYCLOAK_SERVER_URL", "KEYCLOAK_REALM", "KEYCLOAK_CLIENT_ID")
+    )
+
+    if not (server_url and realm and client_id):
+        logger.info("Keycloak not configured, using local JWT only")
+        return None
+
+    ssl_flag = env.get("KEYCLOAK_VERIFY_SSL", "true")
+    return KeycloakConfig(
+        server_url=server_url,
+        realm=realm,
+        client_id=client_id,
+        client_secret=env.get("KEYCLOAK_CLIENT_SECRET"),
+        verify_ssl=ssl_flag.lower() == "true",
+    )
diff --git a/backend-lehrer/classroom/models.py b/backend-lehrer/classroom/models.py
index d8332bf..9dd024e 100644
--- a/backend-lehrer/classroom/models.py
+++ b/backend-lehrer/classroom/models.py
@@ -2,567 +2,68 @@
Classroom API - Pydantic Models
Alle Request- und Response-Models fuer die Classroom API.
+Barrel re-export aus aufgeteilten Modulen.
"""
-from typing import Dict, List, Optional, Any
-from pydantic import BaseModel, Field
-
-
-# === Session Models ===
-
-class CreateSessionRequest(BaseModel):
- """Request zum Erstellen einer neuen Session."""
- teacher_id: str = Field(..., description="ID des Lehrers")
- class_id: str = Field(..., description="ID der Klasse")
- subject: str = Field(..., description="Unterrichtsfach")
- topic: Optional[str] = Field(None, description="Thema der Stunde")
- phase_durations: Optional[Dict[str, int]] = Field(
- None,
- description="Optionale individuelle Phasendauern in Minuten"
- )
-
-
-class NotesRequest(BaseModel):
- """Request zum Aktualisieren von Notizen."""
- notes: str = Field("", description="Stundennotizen")
- homework: str = Field("", description="Hausaufgaben")
-
-
-class ExtendTimeRequest(BaseModel):
- """Request zum Verlaengern der aktuellen Phase (Feature f28)."""
- minutes: int = Field(5, ge=1, le=30, description="Zusaetzliche Minuten (1-30)")
-
-
-class PhaseInfo(BaseModel):
- """Informationen zu einer Phase."""
- phase: str
- display_name: str
- icon: str
- duration_minutes: int
- is_completed: bool
- is_current: bool
- is_future: bool
-
-
-class TimerStatus(BaseModel):
- """Timer-Status einer Phase."""
- remaining_seconds: int
- remaining_formatted: str
- total_seconds: int
- total_formatted: str
- elapsed_seconds: int
- elapsed_formatted: str
- percentage_remaining: int
- percentage_elapsed: int
- percentage: int = Field(description="Alias fuer percentage_remaining (Visual Timer)")
- warning: bool
- overtime: bool
- overtime_seconds: int
- overtime_formatted: Optional[str]
- is_paused: bool = Field(False, description="Ist der Timer pausiert?")
-
-
-class SuggestionItem(BaseModel):
- """Ein Aktivitaets-Vorschlag."""
- id: str
- title: str
- description: str
- activity_type: str
- estimated_minutes: int
- icon: str
- content_url: Optional[str]
-
-
-class SessionResponse(BaseModel):
- """Vollstaendige Session-Response."""
- session_id: str
- teacher_id: str
- class_id: str
- subject: str
- topic: Optional[str]
- current_phase: str
- phase_display_name: str
- phase_started_at: Optional[str]
- lesson_started_at: Optional[str]
- lesson_ended_at: Optional[str]
- timer: TimerStatus
- phases: List[PhaseInfo]
- phase_history: List[Dict[str, Any]]
- notes: str
- homework: str
- is_active: bool
- is_ended: bool
- is_paused: bool = Field(False, description="Ist die Stunde pausiert?")
-
-
-class SuggestionsResponse(BaseModel):
- """Response fuer Vorschlaege."""
- suggestions: List[SuggestionItem]
- current_phase: str
- phase_display_name: str
- total_available: int
-
-
-class PhasesListResponse(BaseModel):
- """Liste aller verfuegbaren Phasen."""
- phases: List[Dict[str, Any]]
-
-
-class ActiveSessionsResponse(BaseModel):
- """Liste aktiver Sessions."""
- sessions: List[Dict[str, Any]]
- count: int
-
-
-# === Session History Models ===
-
-class SessionHistoryItem(BaseModel):
- """Einzelner Eintrag in der Session-History."""
- session_id: str
- teacher_id: str
- class_id: str
- subject: str
- topic: Optional[str]
- lesson_started_at: Optional[str]
- lesson_ended_at: Optional[str]
- total_duration_minutes: Optional[int]
- phases_completed: int
- notes: str
- homework: str
-
-
-class SessionHistoryResponse(BaseModel):
- """Response fuer Session-History."""
- sessions: List[SessionHistoryItem]
- total_count: int
- limit: int
- offset: int
-
-
-# === Template Models ===
-
-class TemplateCreate(BaseModel):
- """Request zum Erstellen einer Vorlage."""
- name: str = Field(..., min_length=1, max_length=200, description="Name der Vorlage")
- description: str = Field("", max_length=1000, description="Beschreibung")
- subject: str = Field("", max_length=100, description="Fach")
- grade_level: str = Field("", max_length=50, description="Klassenstufe (z.B. '7', '10')")
- phase_durations: Optional[Dict[str, int]] = Field(
- None,
- description="Phasendauern in Minuten"
- )
- default_topic: str = Field("", max_length=500, description="Vorausgefuelltes Thema")
- default_notes: str = Field("", description="Vorausgefuellte Notizen")
- is_public: bool = Field(False, description="Vorlage fuer alle sichtbar?")
-
-
-class TemplateUpdate(BaseModel):
- """Request zum Aktualisieren einer Vorlage."""
- name: Optional[str] = Field(None, min_length=1, max_length=200)
- description: Optional[str] = Field(None, max_length=1000)
- subject: Optional[str] = Field(None, max_length=100)
- grade_level: Optional[str] = Field(None, max_length=50)
- phase_durations: Optional[Dict[str, int]] = None
- default_topic: Optional[str] = Field(None, max_length=500)
- default_notes: Optional[str] = None
- is_public: Optional[bool] = None
-
-
-class TemplateResponse(BaseModel):
- """Response fuer eine einzelne Vorlage."""
- template_id: str
- teacher_id: str
- name: str
- description: str
- subject: str
- grade_level: str
- phase_durations: Dict[str, int]
- default_topic: str
- default_notes: str
- is_public: bool
- usage_count: int
- total_duration_minutes: int
- created_at: Optional[str]
- updated_at: Optional[str]
- is_system_template: bool = False
-
-
-class TemplateListResponse(BaseModel):
- """Response fuer Template-Liste."""
- templates: List[TemplateResponse]
- total_count: int
-
-
-# === Homework Models ===
-
-class CreateHomeworkRequest(BaseModel):
- """Request zum Erstellen einer Hausaufgabe."""
- teacher_id: str
- class_id: str
- subject: str
- title: str = Field(..., max_length=300)
- description: str = ""
- session_id: Optional[str] = None
- due_date: Optional[str] = Field(None, description="ISO-Format Datum")
-
-
-class UpdateHomeworkRequest(BaseModel):
- """Request zum Aktualisieren einer Hausaufgabe."""
- title: Optional[str] = Field(None, max_length=300)
- description: Optional[str] = None
- due_date: Optional[str] = Field(None, description="ISO-Format Datum")
- status: Optional[str] = Field(None, description="assigned, in_progress, completed")
-
-
-class HomeworkResponse(BaseModel):
- """Response fuer eine Hausaufgabe."""
- homework_id: str
- teacher_id: str
- class_id: str
- subject: str
- title: str
- description: str
- session_id: Optional[str]
- due_date: Optional[str]
- status: str
- is_overdue: bool
- created_at: Optional[str]
- updated_at: Optional[str]
-
-
-class HomeworkListResponse(BaseModel):
- """Response fuer Liste von Hausaufgaben."""
- homework: List[HomeworkResponse]
- total: int
-
-
-# === Material Models ===
-
-class CreateMaterialRequest(BaseModel):
- """Request zum Erstellen eines Materials."""
- teacher_id: str
- title: str = Field(..., max_length=300)
- material_type: str = Field("document", description="document, link, video, image, worksheet, presentation, other")
- url: Optional[str] = Field(None, max_length=2000)
- description: str = ""
- phase: Optional[str] = Field(None, description="einstieg, erarbeitung, sicherung, transfer, reflexion")
- subject: str = ""
- grade_level: str = ""
- tags: List[str] = []
- is_public: bool = False
- session_id: Optional[str] = None
-
-
-class UpdateMaterialRequest(BaseModel):
- """Request zum Aktualisieren eines Materials."""
- title: Optional[str] = Field(None, max_length=300)
- material_type: Optional[str] = None
- url: Optional[str] = Field(None, max_length=2000)
- description: Optional[str] = None
- phase: Optional[str] = None
- subject: Optional[str] = None
- grade_level: Optional[str] = None
- tags: Optional[List[str]] = None
- is_public: Optional[bool] = None
-
-
-class MaterialResponse(BaseModel):
- """Response fuer ein Material."""
- material_id: str
- teacher_id: str
- title: str
- material_type: str
- url: Optional[str]
- description: str
- phase: Optional[str]
- subject: str
- grade_level: str
- tags: List[str]
- is_public: bool
- usage_count: int
- session_id: Optional[str]
- created_at: Optional[str]
- updated_at: Optional[str]
-
-
-class MaterialListResponse(BaseModel):
- """Response fuer Liste von Materialien."""
- materials: List[MaterialResponse]
- total: int
-
-
-# === Analytics Models ===
-
-class SessionSummaryResponse(BaseModel):
- """Response fuer Session-Summary."""
- session_id: str
- teacher_id: str
- class_id: str
- subject: str
- topic: Optional[str]
- date: Optional[str]
- date_formatted: str
- total_duration_seconds: int
- total_duration_formatted: str
- planned_duration_seconds: int
- planned_duration_formatted: str
- phases_completed: int
- total_phases: int
- completion_percentage: int
- phase_statistics: List[Dict[str, Any]]
- total_overtime_seconds: int
- total_overtime_formatted: str
- phases_with_overtime: int
- total_pause_count: int
- total_pause_seconds: int
- reflection_notes: str = ""
- reflection_rating: Optional[int] = None
- key_learnings: List[str] = []
-
-
-class TeacherAnalyticsResponse(BaseModel):
- """Response fuer Lehrer-Analytics."""
- teacher_id: str
- period_start: Optional[str]
- period_end: Optional[str]
- total_sessions: int
- completed_sessions: int
- total_teaching_minutes: int
- total_teaching_hours: float
- avg_phase_durations: Dict[str, int]
- sessions_with_overtime: int
- overtime_percentage: int
- avg_overtime_seconds: int
- avg_overtime_formatted: str
- most_overtime_phase: Optional[str]
- avg_pause_count: float
- avg_pause_duration_seconds: int
- subjects_taught: Dict[str, int]
- classes_taught: Dict[str, int]
-
-
-class ReflectionCreate(BaseModel):
- """Request-Body fuer Reflection-Erstellung."""
- session_id: str
- teacher_id: str
- notes: str = ""
- overall_rating: Optional[int] = Field(None, ge=1, le=5)
- what_worked: List[str] = []
- improvements: List[str] = []
- notes_for_next_lesson: str = ""
-
-
-class ReflectionUpdate(BaseModel):
- """Request-Body fuer Reflection-Update."""
- notes: Optional[str] = None
- overall_rating: Optional[int] = Field(None, ge=1, le=5)
- what_worked: Optional[List[str]] = None
- improvements: Optional[List[str]] = None
- notes_for_next_lesson: Optional[str] = None
-
-
-class ReflectionResponse(BaseModel):
- """Response fuer eine einzelne Reflection."""
- reflection_id: str
- session_id: str
- teacher_id: str
- notes: str
- overall_rating: Optional[int]
- what_worked: List[str]
- improvements: List[str]
- notes_for_next_lesson: str
- created_at: Optional[str]
- updated_at: Optional[str]
-
-
-# === Feedback Models ===
-
-class FeedbackCreate(BaseModel):
- """Request zum Erstellen von Feedback."""
- title: str = Field(..., min_length=3, max_length=500, description="Kurzer Titel")
- description: str = Field(..., min_length=10, description="Beschreibung")
- feedback_type: str = Field("improvement", description="bug, feature_request, improvement, praise, question")
- priority: str = Field("medium", description="critical, high, medium, low")
- teacher_name: str = Field("", description="Name des Lehrers")
- teacher_email: str = Field("", description="E-Mail fuer Rueckfragen")
- context_url: str = Field("", description="URL wo Feedback gegeben wurde")
- context_phase: str = Field("", description="Aktuelle Phase")
- context_session_id: Optional[str] = Field(None, description="Session-ID falls aktiv")
- related_feature: Optional[str] = Field(None, description="Verwandtes Feature")
-
-
-class FeedbackResponse(BaseModel):
- """Response fuer Feedback."""
- id: str
- teacher_id: str
- teacher_name: str
- title: str
- description: str
- feedback_type: str
- priority: str
- status: str
- created_at: str
- response: Optional[str] = None
-
-
-class FeedbackListResponse(BaseModel):
- """Liste von Feedbacks."""
- feedbacks: List[Dict[str, Any]]
- total: int
-
-
-class FeedbackStatsResponse(BaseModel):
- """Feedback-Statistiken."""
- total: int
- by_status: Dict[str, int]
- by_type: Dict[str, int]
- by_priority: Dict[str, int]
-
-
-# === Settings Models ===
-
-class TeacherSettingsResponse(BaseModel):
- """Response fuer Lehrer-Einstellungen."""
- teacher_id: str
- default_phase_durations: Dict[str, int]
- audio_enabled: bool = True
- high_contrast: bool = False
- show_statistics: bool = True
-
-
-class UpdatePhaseDurationsRequest(BaseModel):
- """Request zum Aktualisieren der Phasen-Dauern."""
- durations: Dict[str, int] = Field(
- ...,
- description="Phasen-Dauern in Minuten, z.B. {'einstieg': 10, 'erarbeitung': 25}",
- examples=[{"einstieg": 10, "erarbeitung": 25, "sicherung": 10, "transfer": 8, "reflexion": 5}]
- )
-
-
-class UpdatePreferencesRequest(BaseModel):
- """Request zum Aktualisieren der UI-Praeferenzen."""
- audio_enabled: Optional[bool] = None
- high_contrast: Optional[bool] = None
- show_statistics: Optional[bool] = None
-
-
-# === Context Models ===
-
-class SchoolInfo(BaseModel):
- """Schul-Informationen."""
- federal_state: str
- federal_state_name: str = ""
- school_type: str
- school_type_name: str = ""
-
-
-class SchoolYearInfo(BaseModel):
- """Schuljahr-Informationen."""
- id: str
- start: Optional[str] = None
- current_week: int = 1
-
-
-class MacroPhaseInfo(BaseModel):
- """Makro-Phase Informationen."""
- id: str
- label: str
- confidence: float = 1.0
-
-
-class CoreCounts(BaseModel):
- """Kern-Zaehler fuer den Kontext."""
- classes: int = 0
- exams_scheduled: int = 0
- corrections_pending: int = 0
-
-
-class ContextFlags(BaseModel):
- """Status-Flags des Kontexts."""
- onboarding_completed: bool = False
- has_classes: bool = False
- has_schedule: bool = False
- is_exam_period: bool = False
- is_before_holidays: bool = False
-
-
-class TeacherContextResponse(BaseModel):
- """Response fuer GET /v1/context."""
- schema_version: str = "1.0"
- teacher_id: str
- school: SchoolInfo
- school_year: SchoolYearInfo
- macro_phase: MacroPhaseInfo
- core_counts: CoreCounts
- flags: ContextFlags
-
-
-class UpdateContextRequest(BaseModel):
- """Request zum Aktualisieren des Kontexts."""
- federal_state: Optional[str] = None
- school_type: Optional[str] = None
- schoolyear: Optional[str] = None
- schoolyear_start: Optional[str] = None
- macro_phase: Optional[str] = None
- current_week: Optional[int] = None
-
-
-# === Event Models ===
-
-class CreateEventRequest(BaseModel):
- """Request zum Erstellen eines Events."""
- title: str
- event_type: str = "other"
- start_date: str
- end_date: Optional[str] = None
- class_id: Optional[str] = None
- subject: Optional[str] = None
- description: str = ""
- needs_preparation: bool = True
- reminder_days_before: int = 7
-
-
-class EventResponse(BaseModel):
- """Response fuer ein Event."""
- id: str
- teacher_id: str
- event_type: str
- title: str
- description: str
- start_date: str
- end_date: Optional[str]
- class_id: Optional[str]
- subject: Optional[str]
- status: str
- needs_preparation: bool
- preparation_done: bool
- reminder_days_before: int
-
-
-# === Routine Models ===
-
-class CreateRoutineRequest(BaseModel):
- """Request zum Erstellen einer Routine."""
- title: str
- routine_type: str = "other"
- recurrence_pattern: str = "weekly"
- day_of_week: Optional[int] = None
- day_of_month: Optional[int] = None
- time_of_day: Optional[str] = None
- duration_minutes: int = 60
- description: str = ""
-
-
-class RoutineResponse(BaseModel):
- """Response fuer eine Routine."""
- id: str
- teacher_id: str
- routine_type: str
- title: str
- description: str
- recurrence_pattern: str
- day_of_week: Optional[int]
- day_of_month: Optional[int]
- time_of_day: Optional[str]
- duration_minutes: int
- is_active: bool
+# Session & Phase Models
+from .models_session import (
+ CreateSessionRequest,
+ NotesRequest,
+ ExtendTimeRequest,
+ PhaseInfo,
+ TimerStatus,
+ SuggestionItem,
+ SessionResponse,
+ SuggestionsResponse,
+ PhasesListResponse,
+ ActiveSessionsResponse,
+ SessionHistoryItem,
+ SessionHistoryResponse,
+)
+
+# Template, Homework, Material Models
+from .models_templates import (
+ TemplateCreate,
+ TemplateUpdate,
+ TemplateResponse,
+ TemplateListResponse,
+ CreateHomeworkRequest,
+ UpdateHomeworkRequest,
+ HomeworkResponse,
+ HomeworkListResponse,
+ CreateMaterialRequest,
+ UpdateMaterialRequest,
+ MaterialResponse,
+ MaterialListResponse,
+)
+
+# Analytics, Reflection, Feedback, Settings Models
+from .models_analytics import (
+ SessionSummaryResponse,
+ TeacherAnalyticsResponse,
+ ReflectionCreate,
+ ReflectionUpdate,
+ ReflectionResponse,
+ FeedbackCreate,
+ FeedbackResponse,
+ FeedbackListResponse,
+ FeedbackStatsResponse,
+ TeacherSettingsResponse,
+ UpdatePhaseDurationsRequest,
+ UpdatePreferencesRequest,
+)
+
+# Context, Event, Routine Models
+from .models_context import (
+ SchoolInfo,
+ SchoolYearInfo,
+ MacroPhaseInfo,
+ CoreCounts,
+ ContextFlags,
+ TeacherContextResponse,
+ UpdateContextRequest,
+ CreateEventRequest,
+ EventResponse,
+ CreateRoutineRequest,
+ RoutineResponse,
+)
diff --git a/backend-lehrer/classroom/models_analytics.py b/backend-lehrer/classroom/models_analytics.py
new file mode 100644
index 0000000..6c1673d
--- /dev/null
+++ b/backend-lehrer/classroom/models_analytics.py
@@ -0,0 +1,161 @@
+"""
+Classroom API - Analytics, Reflection, Feedback, Settings Pydantic Models.
+"""
+
+from typing import Dict, List, Optional, Any
+from pydantic import BaseModel, Field
+
+
+# === Analytics Models ===
+
+class SessionSummaryResponse(BaseModel):
+ """Response fuer Session-Summary."""
+ session_id: str
+ teacher_id: str
+ class_id: str
+ subject: str
+ topic: Optional[str]
+ date: Optional[str]
+ date_formatted: str
+ total_duration_seconds: int
+ total_duration_formatted: str
+ planned_duration_seconds: int
+ planned_duration_formatted: str
+ phases_completed: int
+ total_phases: int
+ completion_percentage: int
+ phase_statistics: List[Dict[str, Any]]
+ total_overtime_seconds: int
+ total_overtime_formatted: str
+ phases_with_overtime: int
+ total_pause_count: int
+ total_pause_seconds: int
+ reflection_notes: str = ""
+ reflection_rating: Optional[int] = None
+ key_learnings: List[str] = []
+
+
+class TeacherAnalyticsResponse(BaseModel):
+ """Response fuer Lehrer-Analytics."""
+ teacher_id: str
+ period_start: Optional[str]
+ period_end: Optional[str]
+ total_sessions: int
+ completed_sessions: int
+ total_teaching_minutes: int
+ total_teaching_hours: float
+ avg_phase_durations: Dict[str, int]
+ sessions_with_overtime: int
+ overtime_percentage: int
+ avg_overtime_seconds: int
+ avg_overtime_formatted: str
+ most_overtime_phase: Optional[str]
+ avg_pause_count: float
+ avg_pause_duration_seconds: int
+ subjects_taught: Dict[str, int]
+ classes_taught: Dict[str, int]
+
+
+class ReflectionCreate(BaseModel):
+ """Request-Body fuer Reflection-Erstellung."""
+ session_id: str
+ teacher_id: str
+ notes: str = ""
+ overall_rating: Optional[int] = Field(None, ge=1, le=5)
+ what_worked: List[str] = []
+ improvements: List[str] = []
+ notes_for_next_lesson: str = ""
+
+
+class ReflectionUpdate(BaseModel):
+ """Request-Body fuer Reflection-Update."""
+ notes: Optional[str] = None
+ overall_rating: Optional[int] = Field(None, ge=1, le=5)
+ what_worked: Optional[List[str]] = None
+ improvements: Optional[List[str]] = None
+ notes_for_next_lesson: Optional[str] = None
+
+
+class ReflectionResponse(BaseModel):
+ """Response fuer eine einzelne Reflection."""
+ reflection_id: str
+ session_id: str
+ teacher_id: str
+ notes: str
+ overall_rating: Optional[int]
+ what_worked: List[str]
+ improvements: List[str]
+ notes_for_next_lesson: str
+ created_at: Optional[str]
+ updated_at: Optional[str]
+
+
+# === Feedback Models ===
+
+class FeedbackCreate(BaseModel):
+ """Request zum Erstellen von Feedback."""
+ title: str = Field(..., min_length=3, max_length=500, description="Kurzer Titel")
+ description: str = Field(..., min_length=10, description="Beschreibung")
+ feedback_type: str = Field("improvement", description="bug, feature_request, improvement, praise, question")
+ priority: str = Field("medium", description="critical, high, medium, low")
+ teacher_name: str = Field("", description="Name des Lehrers")
+ teacher_email: str = Field("", description="E-Mail fuer Rueckfragen")
+ context_url: str = Field("", description="URL wo Feedback gegeben wurde")
+ context_phase: str = Field("", description="Aktuelle Phase")
+ context_session_id: Optional[str] = Field(None, description="Session-ID falls aktiv")
+ related_feature: Optional[str] = Field(None, description="Verwandtes Feature")
+
+
+class FeedbackResponse(BaseModel):
+ """Response fuer Feedback."""
+ id: str
+ teacher_id: str
+ teacher_name: str
+ title: str
+ description: str
+ feedback_type: str
+ priority: str
+ status: str
+ created_at: str
+ response: Optional[str] = None
+
+
+class FeedbackListResponse(BaseModel):
+ """Liste von Feedbacks."""
+ feedbacks: List[Dict[str, Any]]
+ total: int
+
+
+class FeedbackStatsResponse(BaseModel):
+ """Feedback-Statistiken."""
+ total: int
+ by_status: Dict[str, int]
+ by_type: Dict[str, int]
+ by_priority: Dict[str, int]
+
+
+# === Settings Models ===
+
+class TeacherSettingsResponse(BaseModel):
+ """Response fuer Lehrer-Einstellungen."""
+ teacher_id: str
+ default_phase_durations: Dict[str, int]
+ audio_enabled: bool = True
+ high_contrast: bool = False
+ show_statistics: bool = True
+
+
+class UpdatePhaseDurationsRequest(BaseModel):
+ """Request zum Aktualisieren der Phasen-Dauern."""
+ durations: Dict[str, int] = Field(
+ ...,
+ description="Phasen-Dauern in Minuten, z.B. {'einstieg': 10, 'erarbeitung': 25}",
+ examples=[{"einstieg": 10, "erarbeitung": 25, "sicherung": 10, "transfer": 8, "reflexion": 5}]
+ )
+
+
+class UpdatePreferencesRequest(BaseModel):
+ """Request zum Aktualisieren der UI-Praeferenzen."""
+ audio_enabled: Optional[bool] = None
+ high_contrast: Optional[bool] = None
+ show_statistics: Optional[bool] = None
diff --git a/backend-lehrer/classroom/models_context.py b/backend-lehrer/classroom/models_context.py
new file mode 100644
index 0000000..4f743fc
--- /dev/null
+++ b/backend-lehrer/classroom/models_context.py
@@ -0,0 +1,128 @@
+"""
+Classroom API - Context, Event, Routine Pydantic Models.
+"""
+
+from typing import Optional
+from pydantic import BaseModel, Field
+
+
+# === Context Models ===
+
+class SchoolInfo(BaseModel):
+ """Schul-Informationen."""
+ federal_state: str
+ federal_state_name: str = ""
+ school_type: str
+ school_type_name: str = ""
+
+
+class SchoolYearInfo(BaseModel):
+ """Schuljahr-Informationen."""
+ id: str
+ start: Optional[str] = None
+ current_week: int = 1
+
+
+class MacroPhaseInfo(BaseModel):
+ """Makro-Phase Informationen."""
+ id: str
+ label: str
+ confidence: float = 1.0
+
+
+class CoreCounts(BaseModel):
+ """Kern-Zaehler fuer den Kontext."""
+ classes: int = 0
+ exams_scheduled: int = 0
+ corrections_pending: int = 0
+
+
+class ContextFlags(BaseModel):
+ """Status-Flags des Kontexts."""
+ onboarding_completed: bool = False
+ has_classes: bool = False
+ has_schedule: bool = False
+ is_exam_period: bool = False
+ is_before_holidays: bool = False
+
+
+class TeacherContextResponse(BaseModel):
+ """Response fuer GET /v1/context."""
+ schema_version: str = "1.0"
+ teacher_id: str
+ school: SchoolInfo
+ school_year: SchoolYearInfo
+ macro_phase: MacroPhaseInfo
+ core_counts: CoreCounts
+ flags: ContextFlags
+
+
+class UpdateContextRequest(BaseModel):
+ """Request zum Aktualisieren des Kontexts."""
+ federal_state: Optional[str] = None
+ school_type: Optional[str] = None
+ schoolyear: Optional[str] = None
+ schoolyear_start: Optional[str] = None
+ macro_phase: Optional[str] = None
+ current_week: Optional[int] = None
+
+
+# === Event Models ===
+
+class CreateEventRequest(BaseModel):
+ """Request zum Erstellen eines Events."""
+ title: str
+ event_type: str = "other"
+ start_date: str
+ end_date: Optional[str] = None
+ class_id: Optional[str] = None
+ subject: Optional[str] = None
+ description: str = ""
+ needs_preparation: bool = True
+ reminder_days_before: int = 7
+
+
+class EventResponse(BaseModel):
+ """Response fuer ein Event."""
+ id: str
+ teacher_id: str
+ event_type: str
+ title: str
+ description: str
+ start_date: str
+ end_date: Optional[str]
+ class_id: Optional[str]
+ subject: Optional[str]
+ status: str
+ needs_preparation: bool
+ preparation_done: bool
+ reminder_days_before: int
+
+
+# === Routine Models ===
+
+class CreateRoutineRequest(BaseModel):
+ """Request zum Erstellen einer Routine."""
+ title: str
+ routine_type: str = "other"
+ recurrence_pattern: str = "weekly"
+ day_of_week: Optional[int] = None
+ day_of_month: Optional[int] = None
+ time_of_day: Optional[str] = None
+ duration_minutes: int = 60
+ description: str = ""
+
+
+class RoutineResponse(BaseModel):
+ """Response fuer eine Routine."""
+ id: str
+ teacher_id: str
+ routine_type: str
+ title: str
+ description: str
+ recurrence_pattern: str
+ day_of_week: Optional[int]
+ day_of_month: Optional[int]
+ time_of_day: Optional[str]
+ duration_minutes: int
+ is_active: bool
diff --git a/backend-lehrer/classroom/models_session.py b/backend-lehrer/classroom/models_session.py
new file mode 100644
index 0000000..938256b
--- /dev/null
+++ b/backend-lehrer/classroom/models_session.py
@@ -0,0 +1,137 @@
+"""
+Classroom API - Session & Phase Pydantic Models.
+"""
+
+from typing import Dict, List, Optional, Any
+from pydantic import BaseModel, Field
+
+
+# === Session Models ===
+
+class CreateSessionRequest(BaseModel):
+ """Request zum Erstellen einer neuen Session."""
+ teacher_id: str = Field(..., description="ID des Lehrers")
+ class_id: str = Field(..., description="ID der Klasse")
+ subject: str = Field(..., description="Unterrichtsfach")
+ topic: Optional[str] = Field(None, description="Thema der Stunde")
+ phase_durations: Optional[Dict[str, int]] = Field(
+ None,
+ description="Optionale individuelle Phasendauern in Minuten"
+ )
+
+
+class NotesRequest(BaseModel):
+ """Request zum Aktualisieren von Notizen."""
+ notes: str = Field("", description="Stundennotizen")
+ homework: str = Field("", description="Hausaufgaben")
+
+
+class ExtendTimeRequest(BaseModel):
+ """Request zum Verlaengern der aktuellen Phase (Feature f28)."""
+ minutes: int = Field(5, ge=1, le=30, description="Zusaetzliche Minuten (1-30)")
+
+
+class PhaseInfo(BaseModel):
+ """Informationen zu einer Phase."""
+ phase: str
+ display_name: str
+ icon: str
+ duration_minutes: int
+ is_completed: bool
+ is_current: bool
+ is_future: bool
+
+
+class TimerStatus(BaseModel):
+ """Timer-Status einer Phase."""
+ remaining_seconds: int
+ remaining_formatted: str
+ total_seconds: int
+ total_formatted: str
+ elapsed_seconds: int
+ elapsed_formatted: str
+ percentage_remaining: int
+ percentage_elapsed: int
+ percentage: int = Field(description="Alias fuer percentage_remaining (Visual Timer)")
+ warning: bool
+ overtime: bool
+ overtime_seconds: int
+ overtime_formatted: Optional[str]
+ is_paused: bool = Field(False, description="Ist der Timer pausiert?")
+
+
+class SuggestionItem(BaseModel):
+ """Ein Aktivitaets-Vorschlag."""
+ id: str
+ title: str
+ description: str
+ activity_type: str
+ estimated_minutes: int
+ icon: str
+ content_url: Optional[str]
+
+
+class SessionResponse(BaseModel):
+ """Vollstaendige Session-Response."""
+ session_id: str
+ teacher_id: str
+ class_id: str
+ subject: str
+ topic: Optional[str]
+ current_phase: str
+ phase_display_name: str
+ phase_started_at: Optional[str]
+ lesson_started_at: Optional[str]
+ lesson_ended_at: Optional[str]
+ timer: TimerStatus
+ phases: List[PhaseInfo]
+ phase_history: List[Dict[str, Any]]
+ notes: str
+ homework: str
+ is_active: bool
+ is_ended: bool
+ is_paused: bool = Field(False, description="Ist die Stunde pausiert?")
+
+
+class SuggestionsResponse(BaseModel):
+ """Response fuer Vorschlaege."""
+ suggestions: List[SuggestionItem]
+ current_phase: str
+ phase_display_name: str
+ total_available: int
+
+
+class PhasesListResponse(BaseModel):
+ """Liste aller verfuegbaren Phasen."""
+ phases: List[Dict[str, Any]]
+
+
+class ActiveSessionsResponse(BaseModel):
+ """Liste aktiver Sessions."""
+ sessions: List[Dict[str, Any]]
+ count: int
+
+
+# === Session History Models ===
+
+class SessionHistoryItem(BaseModel):
+ """Einzelner Eintrag in der Session-History."""
+ session_id: str
+ teacher_id: str
+ class_id: str
+ subject: str
+ topic: Optional[str]
+ lesson_started_at: Optional[str]
+ lesson_ended_at: Optional[str]
+ total_duration_minutes: Optional[int]
+ phases_completed: int
+ notes: str
+ homework: str
+
+
+class SessionHistoryResponse(BaseModel):
+ """Response fuer Session-History."""
+ sessions: List[SessionHistoryItem]
+ total_count: int
+ limit: int
+ offset: int
diff --git a/backend-lehrer/classroom/models_templates.py b/backend-lehrer/classroom/models_templates.py
new file mode 100644
index 0000000..64f188b
--- /dev/null
+++ b/backend-lehrer/classroom/models_templates.py
@@ -0,0 +1,158 @@
+"""
+Classroom API - Template, Homework, Material Pydantic Models.
+"""
+
+from typing import Dict, List, Optional
+from pydantic import BaseModel, Field
+
+
+# === Template Models ===
+
+class TemplateCreate(BaseModel):
+ """Request zum Erstellen einer Vorlage."""
+ name: str = Field(..., min_length=1, max_length=200, description="Name der Vorlage")
+ description: str = Field("", max_length=1000, description="Beschreibung")
+ subject: str = Field("", max_length=100, description="Fach")
+ grade_level: str = Field("", max_length=50, description="Klassenstufe (z.B. '7', '10')")
+ phase_durations: Optional[Dict[str, int]] = Field(
+ None,
+ description="Phasendauern in Minuten"
+ )
+ default_topic: str = Field("", max_length=500, description="Vorausgefuelltes Thema")
+ default_notes: str = Field("", description="Vorausgefuellte Notizen")
+ is_public: bool = Field(False, description="Vorlage fuer alle sichtbar?")
+
+
+class TemplateUpdate(BaseModel):
+ """Request zum Aktualisieren einer Vorlage."""
+ name: Optional[str] = Field(None, min_length=1, max_length=200)
+ description: Optional[str] = Field(None, max_length=1000)
+ subject: Optional[str] = Field(None, max_length=100)
+ grade_level: Optional[str] = Field(None, max_length=50)
+ phase_durations: Optional[Dict[str, int]] = None
+ default_topic: Optional[str] = Field(None, max_length=500)
+ default_notes: Optional[str] = None
+ is_public: Optional[bool] = None
+
+
+class TemplateResponse(BaseModel):
+ """Response fuer eine einzelne Vorlage."""
+ template_id: str
+ teacher_id: str
+ name: str
+ description: str
+ subject: str
+ grade_level: str
+ phase_durations: Dict[str, int]
+ default_topic: str
+ default_notes: str
+ is_public: bool
+ usage_count: int
+ total_duration_minutes: int
+ created_at: Optional[str]
+ updated_at: Optional[str]
+ is_system_template: bool = False
+
+
+class TemplateListResponse(BaseModel):
+ """Response fuer Template-Liste."""
+ templates: List[TemplateResponse]
+ total_count: int
+
+
+# === Homework Models ===
+
+class CreateHomeworkRequest(BaseModel):
+ """Request zum Erstellen einer Hausaufgabe."""
+ teacher_id: str
+ class_id: str
+ subject: str
+ title: str = Field(..., max_length=300)
+ description: str = ""
+ session_id: Optional[str] = None
+ due_date: Optional[str] = Field(None, description="ISO-Format Datum")
+
+
+class UpdateHomeworkRequest(BaseModel):
+ """Request zum Aktualisieren einer Hausaufgabe."""
+ title: Optional[str] = Field(None, max_length=300)
+ description: Optional[str] = None
+ due_date: Optional[str] = Field(None, description="ISO-Format Datum")
+ status: Optional[str] = Field(None, description="assigned, in_progress, completed")
+
+
+class HomeworkResponse(BaseModel):
+ """Response fuer eine Hausaufgabe."""
+ homework_id: str
+ teacher_id: str
+ class_id: str
+ subject: str
+ title: str
+ description: str
+ session_id: Optional[str]
+ due_date: Optional[str]
+ status: str
+ is_overdue: bool
+ created_at: Optional[str]
+ updated_at: Optional[str]
+
+
+class HomeworkListResponse(BaseModel):
+ """Response fuer Liste von Hausaufgaben."""
+ homework: List[HomeworkResponse]
+ total: int
+
+
+# === Material Models ===
+
+class CreateMaterialRequest(BaseModel):
+ """Request zum Erstellen eines Materials."""
+ teacher_id: str
+ title: str = Field(..., max_length=300)
+ material_type: str = Field("document", description="document, link, video, image, worksheet, presentation, other")
+ url: Optional[str] = Field(None, max_length=2000)
+ description: str = ""
+ phase: Optional[str] = Field(None, description="einstieg, erarbeitung, sicherung, transfer, reflexion")
+ subject: str = ""
+ grade_level: str = ""
+ tags: List[str] = []
+ is_public: bool = False
+ session_id: Optional[str] = None
+
+
+class UpdateMaterialRequest(BaseModel):
+ """Request zum Aktualisieren eines Materials."""
+ title: Optional[str] = Field(None, max_length=300)
+ material_type: Optional[str] = None
+ url: Optional[str] = Field(None, max_length=2000)
+ description: Optional[str] = None
+ phase: Optional[str] = None
+ subject: Optional[str] = None
+ grade_level: Optional[str] = None
+ tags: Optional[List[str]] = None
+ is_public: Optional[bool] = None
+
+
+class MaterialResponse(BaseModel):
+ """Response fuer ein Material."""
+ material_id: str
+ teacher_id: str
+ title: str
+ material_type: str
+ url: Optional[str]
+ description: str
+ phase: Optional[str]
+ subject: str
+ grade_level: str
+ tags: List[str]
+ is_public: bool
+ usage_count: int
+ session_id: Optional[str]
+ created_at: Optional[str]
+ updated_at: Optional[str]
+
+
+class MaterialListResponse(BaseModel):
+ """Response fuer Liste von Materialien."""
+ materials: List[MaterialResponse]
+ total: int
diff --git a/backend-lehrer/classroom/routes/sessions.py b/backend-lehrer/classroom/routes/sessions.py
index ef3286a..55e26c6 100644
--- a/backend-lehrer/classroom/routes/sessions.py
+++ b/backend-lehrer/classroom/routes/sessions.py
@@ -1,525 +1,17 @@
"""
-Classroom API - Session Routes
+Classroom API - Session Routes (barrel re-export)
-Session management endpoints: create, get, start, next-phase, end, etc.
+Combines core session routes and action routes into a single router.
"""
-from uuid import uuid4
-from typing import Dict, Optional, Any
-from datetime import datetime
-import logging
+from fastapi import APIRouter
-from fastapi import APIRouter, HTTPException, Query
+from .sessions_core import router as core_router, build_session_response
+from .sessions_actions import router as actions_router
-from classroom_engine import (
- LessonPhase,
- LessonSession,
- LessonStateMachine,
- PhaseTimer,
- SuggestionEngine,
- LESSON_PHASES,
-)
+router = APIRouter()
+router.include_router(core_router)
+router.include_router(actions_router)
-from ..models import (
- CreateSessionRequest,
- NotesRequest,
- ExtendTimeRequest,
- PhaseInfo,
- TimerStatus,
- SuggestionItem,
- SessionResponse,
- SuggestionsResponse,
- PhasesListResponse,
- ActiveSessionsResponse,
- SessionHistoryItem,
- SessionHistoryResponse,
-)
-from ..services.persistence import (
- sessions,
- init_db_if_needed,
- persist_session,
- get_session_or_404,
- DB_ENABLED,
- SessionLocal,
-)
-from ..websocket_manager import notify_phase_change, notify_session_ended
-
-logger = logging.getLogger(__name__)
-
-router = APIRouter(tags=["Sessions"])
-
-
-def build_session_response(session: LessonSession) -> SessionResponse:
- """Baut die vollstaendige Session-Response."""
- fsm = LessonStateMachine()
- timer = PhaseTimer()
-
- timer_status = timer.get_phase_status(session)
- phases_info = fsm.get_phases_info(session)
-
- return SessionResponse(
- session_id=session.session_id,
- teacher_id=session.teacher_id,
- class_id=session.class_id,
- subject=session.subject,
- topic=session.topic,
- current_phase=session.current_phase.value,
- phase_display_name=session.get_phase_display_name(),
- phase_started_at=session.phase_started_at.isoformat() if session.phase_started_at else None,
- lesson_started_at=session.lesson_started_at.isoformat() if session.lesson_started_at else None,
- lesson_ended_at=session.lesson_ended_at.isoformat() if session.lesson_ended_at else None,
- timer=TimerStatus(**timer_status),
- phases=[PhaseInfo(**p) for p in phases_info],
- phase_history=session.phase_history,
- notes=session.notes,
- homework=session.homework,
- is_active=fsm.is_lesson_active(session),
- is_ended=fsm.is_lesson_ended(session),
- is_paused=session.is_paused,
- )
-
-
-# === Session CRUD Endpoints ===
-
-@router.post("/sessions", response_model=SessionResponse)
-async def create_session(request: CreateSessionRequest) -> SessionResponse:
- """
- Erstellt eine neue Unterrichtsstunde (Session).
-
- Die Stunde ist nach Erstellung im Status NOT_STARTED.
- Zum Starten muss /sessions/{id}/start aufgerufen werden.
- """
- init_db_if_needed()
-
- # Default-Dauern mit uebergebenen Werten mergen
- phase_durations = {
- "einstieg": 8,
- "erarbeitung": 20,
- "sicherung": 10,
- "transfer": 7,
- "reflexion": 5,
- }
- if request.phase_durations:
- phase_durations.update(request.phase_durations)
-
- session = LessonSession(
- session_id=str(uuid4()),
- teacher_id=request.teacher_id,
- class_id=request.class_id,
- subject=request.subject,
- topic=request.topic,
- phase_durations=phase_durations,
- )
-
- sessions[session.session_id] = session
- persist_session(session)
- return build_session_response(session)
-
-
-@router.get("/sessions/{session_id}", response_model=SessionResponse)
-async def get_session(session_id: str) -> SessionResponse:
- """
- Ruft den aktuellen Status einer Session ab.
-
- Enthaelt alle Informationen inkl. Timer-Status und Phasen-Timeline.
- """
- session = get_session_or_404(session_id)
- return build_session_response(session)
-
-
-@router.post("/sessions/{session_id}/start", response_model=SessionResponse)
-async def start_lesson(session_id: str) -> SessionResponse:
- """
- Startet die Unterrichtsstunde.
-
- Wechselt von NOT_STARTED zur ersten Phase (EINSTIEG).
- """
- session = get_session_or_404(session_id)
-
- if session.current_phase != LessonPhase.NOT_STARTED:
- raise HTTPException(
- status_code=400,
- detail=f"Stunde bereits gestartet (aktuelle Phase: {session.current_phase.value})"
- )
-
- fsm = LessonStateMachine()
- session = fsm.transition(session, LessonPhase.EINSTIEG)
-
- persist_session(session)
- return build_session_response(session)
-
-
-@router.post("/sessions/{session_id}/next-phase", response_model=SessionResponse)
-async def next_phase(session_id: str) -> SessionResponse:
- """
- Wechselt zur naechsten Phase.
-
- Wirft 400 wenn keine naechste Phase verfuegbar (z.B. bei ENDED).
- """
- session = get_session_or_404(session_id)
-
- fsm = LessonStateMachine()
- next_p = fsm.next_phase(session.current_phase)
-
- if not next_p:
- raise HTTPException(
- status_code=400,
- detail=f"Keine naechste Phase verfuegbar (aktuelle Phase: {session.current_phase.value})"
- )
-
- session = fsm.transition(session, next_p)
- persist_session(session)
-
- # WebSocket-Benachrichtigung
- response = build_session_response(session)
- await notify_phase_change(session_id, session.current_phase.value, {
- "phase_display_name": session.get_phase_display_name(),
- "is_ended": session.current_phase == LessonPhase.ENDED
- })
- return response
-
-
-@router.post("/sessions/{session_id}/end", response_model=SessionResponse)
-async def end_lesson(session_id: str) -> SessionResponse:
- """
- Beendet die Unterrichtsstunde sofort.
-
- Kann von jeder aktiven Phase aus aufgerufen werden.
- """
- session = get_session_or_404(session_id)
-
- if session.current_phase == LessonPhase.ENDED:
- raise HTTPException(status_code=400, detail="Stunde bereits beendet")
-
- if session.current_phase == LessonPhase.NOT_STARTED:
- raise HTTPException(status_code=400, detail="Stunde noch nicht gestartet")
-
- # Direkt zur Endphase springen (ueberspringt evtl. Phasen)
- fsm = LessonStateMachine()
-
- # Phasen bis zum Ende durchlaufen
- while session.current_phase != LessonPhase.ENDED:
- next_p = fsm.next_phase(session.current_phase)
- if next_p:
- session = fsm.transition(session, next_p)
- else:
- break
-
- persist_session(session)
-
- # WebSocket-Benachrichtigung
- await notify_session_ended(session_id)
- return build_session_response(session)
-
-
-# === Quick Actions (Feature f26/f27/f28) ===
-
-@router.post("/sessions/{session_id}/pause", response_model=SessionResponse)
-async def toggle_pause(session_id: str) -> SessionResponse:
- """
- Pausiert oder setzt die laufende Stunde fort (Feature f27).
-
- Toggle-Funktion: Wenn pausiert -> fortsetzen, wenn laufend -> pausieren.
- Die Pause-Zeit wird nicht auf die Phasendauer angerechnet.
- """
- session = get_session_or_404(session_id)
-
- # Nur aktive Phasen koennen pausiert werden
- if session.current_phase in [LessonPhase.NOT_STARTED, LessonPhase.ENDED]:
- raise HTTPException(
- status_code=400,
- detail="Stunde ist nicht aktiv"
- )
-
- if session.is_paused:
- # Fortsetzen: Pause-Zeit zur Gesamt-Pause addieren
- if session.pause_started_at:
- pause_duration = (datetime.utcnow() - session.pause_started_at).total_seconds()
- session.total_paused_seconds += int(pause_duration)
-
- session.is_paused = False
- session.pause_started_at = None
- else:
- # Pausieren
- session.is_paused = True
- session.pause_started_at = datetime.utcnow()
-
- persist_session(session)
- return build_session_response(session)
-
-
-@router.post("/sessions/{session_id}/extend", response_model=SessionResponse)
-async def extend_phase(session_id: str, request: ExtendTimeRequest) -> SessionResponse:
- """
- Verlaengert die aktuelle Phase um zusaetzliche Minuten (Feature f28).
-
- Nuetzlich wenn mehr Zeit benoetigt wird, z.B. fuer vertiefte Diskussionen.
- """
- session = get_session_or_404(session_id)
-
- # Nur aktive Phasen koennen verlaengert werden
- if session.current_phase in [LessonPhase.NOT_STARTED, LessonPhase.ENDED]:
- raise HTTPException(
- status_code=400,
- detail="Stunde ist nicht aktiv"
- )
-
- # Aktuelle Phasendauer erhoehen
- phase_id = session.current_phase.value
- current_duration = session.phase_durations.get(phase_id, 10)
- session.phase_durations[phase_id] = current_duration + request.minutes
-
- persist_session(session)
- return build_session_response(session)
-
-
-@router.get("/sessions/{session_id}/timer", response_model=TimerStatus)
-async def get_timer(session_id: str) -> TimerStatus:
- """
- Ruft den Timer-Status der aktuellen Phase ab.
-
- Enthaelt verbleibende Zeit, Warnung und Overtime-Status.
- Sollte alle 5 Sekunden gepollt werden.
- """
- session = get_session_or_404(session_id)
- timer = PhaseTimer()
- status = timer.get_phase_status(session)
- return TimerStatus(**status)
-
-
-@router.get("/sessions/{session_id}/suggestions", response_model=SuggestionsResponse)
-async def get_suggestions(
- session_id: str,
- limit: int = Query(3, ge=1, le=10, description="Anzahl Vorschlaege")
-) -> SuggestionsResponse:
- """
- Ruft phasenspezifische Aktivitaets-Vorschlaege ab.
-
- Die Vorschlaege aendern sich je nach aktueller Phase.
- """
- session = get_session_or_404(session_id)
- engine = SuggestionEngine()
- response = engine.get_suggestions_response(session, limit)
-
- return SuggestionsResponse(
- suggestions=[SuggestionItem(**s) for s in response["suggestions"]],
- current_phase=response["current_phase"],
- phase_display_name=response["phase_display_name"],
- total_available=response["total_available"],
- )
-
-
-@router.put("/sessions/{session_id}/notes", response_model=SessionResponse)
-async def update_notes(session_id: str, request: NotesRequest) -> SessionResponse:
- """
- Aktualisiert Notizen und Hausaufgaben der Stunde.
- """
- session = get_session_or_404(session_id)
- session.notes = request.notes
- session.homework = request.homework
- persist_session(session)
- return build_session_response(session)
-
-
-@router.delete("/sessions/{session_id}")
-async def delete_session(session_id: str) -> Dict[str, str]:
- """
- Loescht eine Session.
- """
- if session_id not in sessions:
- raise HTTPException(status_code=404, detail="Session nicht gefunden")
-
- del sessions[session_id]
-
- # Auch aus DB loeschen
- if DB_ENABLED:
- try:
- from ..services.persistence import delete_session_from_db
- delete_session_from_db(session_id)
- except Exception as e:
- logger.error(f"Failed to delete session {session_id} from DB: {e}")
-
- return {"status": "deleted", "session_id": session_id}
-
-
-# === Session History (Feature f17) ===
-
-@router.get("/history/{teacher_id}", response_model=SessionHistoryResponse)
-async def get_session_history(
- teacher_id: str,
- limit: int = Query(20, ge=1, le=100, description="Max. Anzahl Eintraege"),
- offset: int = Query(0, ge=0, description="Offset fuer Pagination")
-) -> SessionHistoryResponse:
- """
- Ruft die Session-History eines Lehrers ab (Feature f17).
-
- Zeigt abgeschlossene Unterrichtsstunden mit Statistiken.
- Nur verfuegbar wenn DB aktiviert ist.
- """
- init_db_if_needed()
-
- if not DB_ENABLED:
- # Fallback: In-Memory Sessions filtern
- ended_sessions = [
- s for s in sessions.values()
- if s.teacher_id == teacher_id and s.current_phase == LessonPhase.ENDED
- ]
- ended_sessions.sort(
- key=lambda x: x.lesson_ended_at or datetime.min,
- reverse=True
- )
- paginated = ended_sessions[offset:offset + limit]
-
- items = []
- for s in paginated:
- duration = None
- if s.lesson_started_at and s.lesson_ended_at:
- duration = int((s.lesson_ended_at - s.lesson_started_at).total_seconds() / 60)
-
- items.append(SessionHistoryItem(
- session_id=s.session_id,
- teacher_id=s.teacher_id,
- class_id=s.class_id,
- subject=s.subject,
- topic=s.topic,
- lesson_started_at=s.lesson_started_at.isoformat() if s.lesson_started_at else None,
- lesson_ended_at=s.lesson_ended_at.isoformat() if s.lesson_ended_at else None,
- total_duration_minutes=duration,
- phases_completed=len(s.phase_history),
- notes=s.notes,
- homework=s.homework,
- ))
-
- return SessionHistoryResponse(
- sessions=items,
- total_count=len(ended_sessions),
- limit=limit,
- offset=offset,
- )
-
- # DB-basierte History
- try:
- from classroom_engine.repository import SessionRepository
- db = SessionLocal()
- repo = SessionRepository(db)
-
- # Beendete Sessions abrufen
- db_sessions = repo.get_history_by_teacher(teacher_id, limit, offset)
-
- # Gesamtanzahl ermitteln
- from classroom_engine.db_models import LessonSessionDB, LessonPhaseEnum
- total_count = db.query(LessonSessionDB).filter(
- LessonSessionDB.teacher_id == teacher_id,
- LessonSessionDB.current_phase == LessonPhaseEnum.ENDED
- ).count()
-
- items = []
- for db_session in db_sessions:
- duration = None
- if db_session.lesson_started_at and db_session.lesson_ended_at:
- duration = int((db_session.lesson_ended_at - db_session.lesson_started_at).total_seconds() / 60)
-
- phase_history = db_session.phase_history or []
-
- items.append(SessionHistoryItem(
- session_id=db_session.id,
- teacher_id=db_session.teacher_id,
- class_id=db_session.class_id,
- subject=db_session.subject,
- topic=db_session.topic,
- lesson_started_at=db_session.lesson_started_at.isoformat() if db_session.lesson_started_at else None,
- lesson_ended_at=db_session.lesson_ended_at.isoformat() if db_session.lesson_ended_at else None,
- total_duration_minutes=duration,
- phases_completed=len(phase_history),
- notes=db_session.notes or "",
- homework=db_session.homework or "",
- ))
-
- db.close()
-
- return SessionHistoryResponse(
- sessions=items,
- total_count=total_count,
- limit=limit,
- offset=offset,
- )
-
- except Exception as e:
- logger.error(f"Failed to get session history: {e}")
- raise HTTPException(status_code=500, detail="Fehler beim Laden der History")
-
-
-# === Utility Endpoints ===
-
-@router.get("/phases", response_model=PhasesListResponse)
-async def list_phases() -> PhasesListResponse:
- """
- Listet alle verfuegbaren Unterrichtsphasen mit Metadaten.
- """
- phases = []
- for phase_id, config in LESSON_PHASES.items():
- phases.append({
- "phase": phase_id,
- "display_name": config["display_name"],
- "default_duration_minutes": config["default_duration_minutes"],
- "activities": config["activities"],
- "icon": config["icon"],
- "description": config.get("description", ""),
- })
- return PhasesListResponse(phases=phases)
-
-
-@router.get("/sessions", response_model=ActiveSessionsResponse)
-async def list_active_sessions(
- teacher_id: Optional[str] = Query(None, description="Filter nach Lehrer")
-) -> ActiveSessionsResponse:
- """
- Listet alle (optionally gefilterten) Sessions.
- """
- sessions_list = []
- for session in sessions.values():
- if teacher_id and session.teacher_id != teacher_id:
- continue
-
- fsm = LessonStateMachine()
- sessions_list.append({
- "session_id": session.session_id,
- "teacher_id": session.teacher_id,
- "class_id": session.class_id,
- "subject": session.subject,
- "current_phase": session.current_phase.value,
- "is_active": fsm.is_lesson_active(session),
- "lesson_started_at": session.lesson_started_at.isoformat() if session.lesson_started_at else None,
- })
-
- return ActiveSessionsResponse(
- sessions=sessions_list,
- count=len(sessions_list)
- )
-
-
-@router.get("/health")
-async def health_check() -> Dict[str, Any]:
- """
- Health-Check fuer den Classroom Service.
- """
- from sqlalchemy import text
-
- db_status = "disabled"
- if DB_ENABLED:
- try:
- db = SessionLocal()
- db.execute(text("SELECT 1"))
- db.close()
- db_status = "connected"
- except Exception as e:
- db_status = f"error: {str(e)}"
-
- return {
- "status": "healthy",
- "service": "classroom-engine",
- "active_sessions": len(sessions),
- "db_enabled": DB_ENABLED,
- "db_status": db_status,
- "timestamp": datetime.utcnow().isoformat(),
- }
+# Re-export for backward compatibility
+__all__ = ["router", "build_session_response"]
diff --git a/backend-lehrer/classroom/routes/sessions_actions.py b/backend-lehrer/classroom/routes/sessions_actions.py
new file mode 100644
index 0000000..3f23693
--- /dev/null
+++ b/backend-lehrer/classroom/routes/sessions_actions.py
@@ -0,0 +1,173 @@
+"""
+Classroom API - Session Actions Routes
+
+Quick actions (pause, extend, timer), suggestions, utility endpoints.
+"""
+
+from typing import Dict, Optional, Any
+from datetime import datetime
+import logging
+
+from fastapi import APIRouter, HTTPException, Query
+from sqlalchemy import text
+
+from classroom_engine import (
+ LessonPhase,
+ LessonStateMachine,
+ PhaseTimer,
+ SuggestionEngine,
+ LESSON_PHASES,
+)
+
+from ..models import (
+ ExtendTimeRequest,
+ TimerStatus,
+ SuggestionItem,
+ SuggestionsResponse,
+ PhasesListResponse,
+ ActiveSessionsResponse,
+)
+from ..services.persistence import (
+ sessions,
+ persist_session,
+ get_session_or_404,
+ DB_ENABLED,
+ SessionLocal,
+)
+from .sessions_core import build_session_response, SessionResponse
+
+logger = logging.getLogger(__name__)
+
+router = APIRouter(tags=["Sessions"])
+
+
+# === Quick Actions (Feature f26/f27/f28) ===
+
+@router.post("/sessions/{session_id}/pause", response_model=SessionResponse)
+async def toggle_pause(session_id: str) -> SessionResponse:
+    """
+    Toggle the pause state of a running lesson (feature f27).
+
+    Pausing stores the current UTC time as the pause start. Resuming adds the
+    elapsed pause time (truncated to whole seconds) to the session's
+    ``total_paused_seconds`` and clears the pause marker.
+
+    Raises:
+        HTTPException: 400 if the lesson has not started or has already ended.
+    """
+    session = get_session_or_404(session_id)
+
+    inactive_phases = (LessonPhase.NOT_STARTED, LessonPhase.ENDED)
+    if session.current_phase in inactive_phases:
+        raise HTTPException(status_code=400, detail="Stunde ist nicht aktiv")
+
+    if not session.is_paused:
+        # Entering pause: remember when it began.
+        session.is_paused = True
+        session.pause_started_at = datetime.utcnow()
+    else:
+        # Leaving pause: book the elapsed time, if a start marker exists.
+        if session.pause_started_at:
+            elapsed = datetime.utcnow() - session.pause_started_at
+            session.total_paused_seconds += int(elapsed.total_seconds())
+        session.is_paused = False
+        session.pause_started_at = None
+
+    persist_session(session)
+    return build_session_response(session)
+
+
+@router.post("/sessions/{session_id}/extend", response_model=SessionResponse)
+async def extend_phase(session_id: str, request: ExtendTimeRequest) -> SessionResponse:
+    """
+    Extend the current phase by ``request.minutes`` (feature f28).
+
+    A phase without a stored duration is treated as 10 minutes before the
+    extension is added.
+
+    Raises:
+        HTTPException: 400 if the lesson has not started or has already ended.
+    """
+    session = get_session_or_404(session_id)
+
+    active_phase = session.current_phase
+    if active_phase in (LessonPhase.NOT_STARTED, LessonPhase.ENDED):
+        raise HTTPException(status_code=400, detail="Stunde ist nicht aktiv")
+
+    key = active_phase.value
+    session.phase_durations[key] = session.phase_durations.get(key, 10) + request.minutes
+
+    persist_session(session)
+    return build_session_response(session)
+
+
+@router.get("/sessions/{session_id}/timer", response_model=TimerStatus)
+async def get_timer(session_id: str) -> TimerStatus:
+    """Return the timer status of the session's current phase."""
+    session = get_session_or_404(session_id)
+    return TimerStatus(**PhaseTimer().get_phase_status(session))
+
+
+@router.get("/sessions/{session_id}/suggestions", response_model=SuggestionsResponse)
+async def get_suggestions(
+    session_id: str,
+    limit: int = Query(3, ge=1, le=10)
+) -> SuggestionsResponse:
+    """Return up to ``limit`` activity suggestions for the session's current phase."""
+    session = get_session_or_404(session_id)
+    payload = SuggestionEngine().get_suggestions_response(session, limit)
+
+    # Convert the raw suggestion dicts into response models first.
+    items = [SuggestionItem(**entry) for entry in payload["suggestions"]]
+
+    return SuggestionsResponse(
+        suggestions=items,
+        current_phase=payload["current_phase"],
+        phase_display_name=payload["phase_display_name"],
+        total_available=payload["total_available"],
+    )
+
+
+# === Utility Endpoints ===
+
+@router.get("/phases", response_model=PhasesListResponse)
+async def list_phases() -> PhasesListResponse:
+    """List every configured lesson phase together with its metadata."""
+    catalogue = [
+        {
+            "phase": phase_id,
+            "display_name": config["display_name"],
+            "default_duration_minutes": config["default_duration_minutes"],
+            "activities": config["activities"],
+            "icon": config["icon"],
+            # The description is optional in the phase configuration.
+            "description": config.get("description", ""),
+        }
+        for phase_id, config in LESSON_PHASES.items()
+    ]
+    return PhasesListResponse(phases=catalogue)
+
+
+@router.get("/sessions", response_model=ActiveSessionsResponse)
+async def list_active_sessions(
+    teacher_id: Optional[str] = Query(None)
+) -> ActiveSessionsResponse:
+    """
+    List all in-memory sessions, optionally restricted to one teacher.
+
+    When ``teacher_id`` is empty or omitted, no filtering is applied.
+    """
+    overview = []
+    for session in sessions.values():
+        if not teacher_id or session.teacher_id == teacher_id:
+            fsm = LessonStateMachine()
+            started_at = session.lesson_started_at
+            overview.append({
+                "session_id": session.session_id,
+                "teacher_id": session.teacher_id,
+                "class_id": session.class_id,
+                "subject": session.subject,
+                "current_phase": session.current_phase.value,
+                "is_active": fsm.is_lesson_active(session),
+                "lesson_started_at": started_at.isoformat() if started_at else None,
+            })
+
+    return ActiveSessionsResponse(sessions=overview, count=len(overview))
+
+
+@router.get("/health")
+async def health_check() -> Dict[str, Any]:
+    """
+    Report the health of the classroom service.
+
+    When the database is enabled, connectivity is probed with ``SELECT 1``.
+    A failing probe is reported through ``db_status`` ("error: ...") instead
+    of being raised, so the endpoint itself always answers.
+
+    Returns:
+        A dict with the service status, the number of in-memory sessions,
+        the database flag/status and a UTC timestamp in ISO format.
+    """
+    db_status = "disabled"
+    if DB_ENABLED:
+        try:
+            db = SessionLocal()
+            try:
+                db.execute(text("SELECT 1"))
+            finally:
+                # Release the session even when the probe query fails;
+                # previously a failing execute() leaked the session.
+                db.close()
+            db_status = "connected"
+        except Exception as e:
+            db_status = f"error: {str(e)}"
+
+    return {
+        "status": "healthy",
+        "service": "classroom-engine",
+        "active_sessions": len(sessions),
+        "db_enabled": DB_ENABLED,
+        "db_status": db_status,
+        "timestamp": datetime.utcnow().isoformat(),
+    }
diff --git a/backend-lehrer/classroom/routes/sessions_core.py b/backend-lehrer/classroom/routes/sessions_core.py
new file mode 100644
index 0000000..1f4e529
--- /dev/null
+++ b/backend-lehrer/classroom/routes/sessions_core.py
@@ -0,0 +1,283 @@
+"""
+Classroom API - Session Core Routes
+
+Session CRUD, lifecycle, and history endpoints.
+"""
+
+from uuid import uuid4
+from typing import Dict, Optional, Any
+from datetime import datetime
+import logging
+
+from fastapi import APIRouter, HTTPException, Query
+
+from classroom_engine import (
+ LessonPhase,
+ LessonSession,
+ LessonStateMachine,
+ PhaseTimer,
+ LESSON_PHASES,
+)
+
+from ..models import (
+ CreateSessionRequest,
+ NotesRequest,
+ PhaseInfo,
+ TimerStatus,
+ SessionResponse,
+ PhasesListResponse,
+ ActiveSessionsResponse,
+ SessionHistoryItem,
+ SessionHistoryResponse,
+)
+from ..services.persistence import (
+ sessions,
+ init_db_if_needed,
+ persist_session,
+ get_session_or_404,
+ DB_ENABLED,
+ SessionLocal,
+)
+from ..websocket_manager import notify_phase_change, notify_session_ended
+
+logger = logging.getLogger(__name__)
+
+router = APIRouter(tags=["Sessions"])
+
+
+def build_session_response(session: LessonSession) -> SessionResponse:
+    """
+    Assemble the full API response for a lesson session.
+
+    Combines the session's own fields with the timer status from
+    ``PhaseTimer`` and the phase overview and activity flags from
+    ``LessonStateMachine``.
+    """
+    state_machine = LessonStateMachine()
+    phase_timer = PhaseTimer()
+
+    timer_payload = phase_timer.get_phase_status(session)
+    phase_payloads = state_machine.get_phases_info(session)
+
+    def _iso(moment):
+        # Timestamps are optional; unset ones are serialized as None.
+        return moment.isoformat() if moment else None
+
+    return SessionResponse(
+        session_id=session.session_id,
+        teacher_id=session.teacher_id,
+        class_id=session.class_id,
+        subject=session.subject,
+        topic=session.topic,
+        current_phase=session.current_phase.value,
+        phase_display_name=session.get_phase_display_name(),
+        phase_started_at=_iso(session.phase_started_at),
+        lesson_started_at=_iso(session.lesson_started_at),
+        lesson_ended_at=_iso(session.lesson_ended_at),
+        timer=TimerStatus(**timer_payload),
+        phases=[PhaseInfo(**info) for info in phase_payloads],
+        phase_history=session.phase_history,
+        notes=session.notes,
+        homework=session.homework,
+        is_active=state_machine.is_lesson_active(session),
+        is_ended=state_machine.is_lesson_ended(session),
+        is_paused=session.is_paused,
+    )
+
+
+# === Session CRUD Endpoints ===
+
+@router.post("/sessions", response_model=SessionResponse)
+async def create_session(request: CreateSessionRequest) -> SessionResponse:
+    """
+    Create a new lesson session.
+
+    Starts from the default phase durations (in minutes) and overlays any
+    durations supplied in the request, then registers and persists the
+    session under a freshly generated UUID.
+    """
+    init_db_if_needed()
+
+    durations = {
+        "einstieg": 8,
+        "erarbeitung": 20,
+        "sicherung": 10,
+        "transfer": 7,
+        "reflexion": 5,
+    }
+    overrides = request.phase_durations
+    if overrides:
+        durations.update(overrides)
+
+    new_session = LessonSession(
+        session_id=str(uuid4()),
+        teacher_id=request.teacher_id,
+        class_id=request.class_id,
+        subject=request.subject,
+        topic=request.topic,
+        phase_durations=durations,
+    )
+
+    sessions[new_session.session_id] = new_session
+    persist_session(new_session)
+    return build_session_response(new_session)
+
+
+@router.get("/sessions/{session_id}", response_model=SessionResponse)
+async def get_session(session_id: str) -> SessionResponse:
+    """Return the current state of the session with the given id."""
+    return build_session_response(get_session_or_404(session_id))
+
+
+@router.post("/sessions/{session_id}/start", response_model=SessionResponse)
+async def start_lesson(session_id: str) -> SessionResponse:
+    """
+    Start the lesson by moving it into the EINSTIEG phase.
+
+    Raises:
+        HTTPException: 400 if the lesson is in any phase other than NOT_STARTED.
+    """
+    session = get_session_or_404(session_id)
+
+    already_started = session.current_phase != LessonPhase.NOT_STARTED
+    if already_started:
+        message = f"Stunde bereits gestartet (aktuelle Phase: {session.current_phase.value})"
+        raise HTTPException(status_code=400, detail=message)
+
+    session = LessonStateMachine().transition(session, LessonPhase.EINSTIEG)
+    persist_session(session)
+    return build_session_response(session)
+
+
+@router.post("/sessions/{session_id}/next-phase", response_model=SessionResponse)
+async def next_phase(session_id: str) -> SessionResponse:
+    """
+    Advance the session to the phase following the current one.
+
+    After persisting the new state, a phase-change notification is sent
+    through the websocket manager.
+
+    Raises:
+        HTTPException: 400 if the state machine offers no next phase.
+    """
+    session = get_session_or_404(session_id)
+
+    fsm = LessonStateMachine()
+    upcoming = fsm.next_phase(session.current_phase)
+    if not upcoming:
+        message = f"Keine naechste Phase verfuegbar (aktuelle Phase: {session.current_phase.value})"
+        raise HTTPException(status_code=400, detail=message)
+
+    session = fsm.transition(session, upcoming)
+    persist_session(session)
+
+    # Build the response before notifying, as the original flow does.
+    response = build_session_response(session)
+
+    phase_value = session.current_phase.value
+    details = {
+        "phase_display_name": session.get_phase_display_name(),
+        "is_ended": session.current_phase == LessonPhase.ENDED
+    }
+    await notify_phase_change(session_id, phase_value, details)
+    return response
+
+
+@router.post("/sessions/{session_id}/end", response_model=SessionResponse)
+async def end_lesson(session_id: str) -> SessionResponse:
+    """
+    End the lesson immediately.
+
+    Steps through the remaining phases one transition at a time until the
+    ENDED phase is reached (or the state machine offers no further phase),
+    then persists the session and sends the session-ended notification.
+
+    Raises:
+        HTTPException: 400 if the lesson already ended or was never started.
+    """
+    session = get_session_or_404(session_id)
+
+    phase = session.current_phase
+    if phase == LessonPhase.ENDED:
+        raise HTTPException(status_code=400, detail="Stunde bereits beendet")
+    if phase == LessonPhase.NOT_STARTED:
+        raise HTTPException(status_code=400, detail="Stunde noch nicht gestartet")
+
+    fsm = LessonStateMachine()
+    while session.current_phase != LessonPhase.ENDED:
+        following = fsm.next_phase(session.current_phase)
+        if not following:
+            # No further phase available; stop instead of looping forever.
+            break
+        session = fsm.transition(session, following)
+
+    persist_session(session)
+    await notify_session_ended(session_id)
+    return build_session_response(session)
+
+
+@router.put("/sessions/{session_id}/notes", response_model=SessionResponse)
+async def update_notes(session_id: str, request: NotesRequest) -> SessionResponse:
+    """Replace the lesson's notes and homework with the values from the request."""
+    session = get_session_or_404(session_id)
+
+    session.notes, session.homework = request.notes, request.homework
+
+    persist_session(session)
+    return build_session_response(session)
+
+
+@router.delete("/sessions/{session_id}")
+async def delete_session(session_id: str) -> Dict[str, str]:
+    """
+    Delete a session from the in-memory store and, if enabled, from the DB.
+
+    A failing database delete is logged but does not fail the request.
+
+    Raises:
+        HTTPException: 404 if the session is not in the in-memory store.
+    """
+    if session_id not in sessions:
+        raise HTTPException(status_code=404, detail="Session nicht gefunden")
+
+    del sessions[session_id]
+    outcome = {"status": "deleted", "session_id": session_id}
+
+    if not DB_ENABLED:
+        return outcome
+
+    try:
+        from ..services.persistence import delete_session_from_db
+        delete_session_from_db(session_id)
+    except Exception as e:
+        # Best effort: the in-memory delete has already succeeded.
+        logger.error(f"Failed to delete session {session_id} from DB: {e}")
+
+    return outcome
+
+
+# === Session History (Feature f17) ===
+
+def _duration_minutes(
+    started_at: Optional[datetime],
+    ended_at: Optional[datetime],
+) -> Optional[int]:
+    """Return the whole minutes between two timestamps, or None if either is missing."""
+    if started_at and ended_at:
+        return int((ended_at - started_at).total_seconds() / 60)
+    return None
+
+
+@router.get("/history/{teacher_id}", response_model=SessionHistoryResponse)
+async def get_session_history(
+    teacher_id: str,
+    limit: int = Query(20, ge=1, le=100),
+    offset: int = Query(0, ge=0)
+) -> SessionHistoryResponse:
+    """
+    Return the paginated history of ended sessions for a teacher (feature f17).
+
+    Without a database, the history is built from the in-memory sessions,
+    newest ``lesson_ended_at`` first. With a database, the rows come from
+    ``SessionRepository.get_history_by_teacher`` and the total is counted
+    separately.
+
+    Raises:
+        HTTPException: 500 if loading the history from the database fails.
+    """
+    init_db_if_needed()
+
+    if not DB_ENABLED:
+        ended_sessions = [
+            s for s in sessions.values()
+            if s.teacher_id == teacher_id and s.current_phase == LessonPhase.ENDED
+        ]
+        # Sessions without an end timestamp sort last.
+        ended_sessions.sort(key=lambda x: x.lesson_ended_at or datetime.min, reverse=True)
+        paginated = ended_sessions[offset:offset + limit]
+
+        items = [
+            SessionHistoryItem(
+                session_id=s.session_id, teacher_id=s.teacher_id,
+                class_id=s.class_id, subject=s.subject, topic=s.topic,
+                lesson_started_at=s.lesson_started_at.isoformat() if s.lesson_started_at else None,
+                lesson_ended_at=s.lesson_ended_at.isoformat() if s.lesson_ended_at else None,
+                total_duration_minutes=_duration_minutes(s.lesson_started_at, s.lesson_ended_at),
+                phases_completed=len(s.phase_history),
+                notes=s.notes, homework=s.homework,
+            )
+            for s in paginated
+        ]
+
+        return SessionHistoryResponse(
+            sessions=items, total_count=len(ended_sessions), limit=limit, offset=offset,
+        )
+
+    try:
+        from classroom_engine.repository import SessionRepository
+        from classroom_engine.db_models import LessonSessionDB, LessonPhaseEnum
+
+        db = SessionLocal()
+        try:
+            repo = SessionRepository(db)
+            db_sessions = repo.get_history_by_teacher(teacher_id, limit, offset)
+
+            total_count = db.query(LessonSessionDB).filter(
+                LessonSessionDB.teacher_id == teacher_id,
+                LessonSessionDB.current_phase == LessonPhaseEnum.ENDED
+            ).count()
+
+            items = []
+            for db_session in db_sessions:
+                phase_history = db_session.phase_history or []
+                items.append(SessionHistoryItem(
+                    session_id=db_session.id, teacher_id=db_session.teacher_id,
+                    class_id=db_session.class_id, subject=db_session.subject, topic=db_session.topic,
+                    lesson_started_at=db_session.lesson_started_at.isoformat() if db_session.lesson_started_at else None,
+                    lesson_ended_at=db_session.lesson_ended_at.isoformat() if db_session.lesson_ended_at else None,
+                    total_duration_minutes=_duration_minutes(
+                        db_session.lesson_started_at, db_session.lesson_ended_at
+                    ),
+                    phases_completed=len(phase_history),
+                    notes=db_session.notes or "", homework=db_session.homework or "",
+                ))
+        finally:
+            # Close the session on every path; previously any error in the
+            # queries above skipped db.close() and leaked the session.
+            db.close()
+
+        return SessionHistoryResponse(
+            sessions=items, total_count=total_count, limit=limit, offset=offset,
+        )
+
+    except Exception as e:
+        logger.error(f"Failed to get session history: {e}")
+        raise HTTPException(status_code=500, detail="Fehler beim Laden der History")
diff --git a/backend-lehrer/classroom_engine/__init__.py b/backend-lehrer/classroom_engine/__init__.py
index 23284bb..4362c38 100644
--- a/backend-lehrer/classroom_engine/__init__.py
+++ b/backend-lehrer/classroom_engine/__init__.py
@@ -32,7 +32,8 @@ from .models import (
)
from .fsm import LessonStateMachine
from .timer import PhaseTimer
-from .suggestions import SuggestionEngine, PHASE_SUGGESTIONS, SUBJECT_SUGGESTIONS
+from .suggestions import SuggestionEngine
+from .suggestion_data import PHASE_SUGGESTIONS, SUBJECT_SUGGESTIONS
from .context_models import (
MacroPhaseEnum,
EventTypeEnum,
diff --git a/backend-lehrer/classroom_engine/analytics.py b/backend-lehrer/classroom_engine/analytics.py
index f0d4fa1..ca4dad1 100644
--- a/backend-lehrer/classroom_engine/analytics.py
+++ b/backend-lehrer/classroom_engine/analytics.py
@@ -11,256 +11,28 @@ WICHTIG: Keine wertenden Metriken (z.B. "Sie haben 70% geredet").
Fokus auf neutrale, hilfreiche Statistiken.
"""
-from dataclasses import dataclass, field
-from datetime import datetime, timedelta
+from datetime import datetime
from typing import Optional, List, Dict, Any
-from enum import Enum
+from .analytics_models import (
+ PhaseStatistics,
+ SessionSummary,
+ TeacherAnalytics,
+ LessonReflection,
+)
-# ==================== Analytics Models ====================
+# Re-export models for backward compatibility
+__all__ = [
+ "PhaseStatistics",
+ "SessionSummary",
+ "TeacherAnalytics",
+ "LessonReflection",
+ "AnalyticsCalculator",
+]
-@dataclass
-class PhaseStatistics:
- """Statistik fuer eine einzelne Phase."""
- phase: str
- display_name: str
-
- # Dauer-Metriken
- planned_duration_seconds: int
- actual_duration_seconds: int
- difference_seconds: int # positiv = laenger als geplant
-
- # Overtime
- had_overtime: bool
- overtime_seconds: int = 0
-
- # Erweiterungen
- was_extended: bool = False
- extension_minutes: int = 0
-
- # Pausen
- pause_count: int = 0
- total_pause_seconds: int = 0
-
- def to_dict(self) -> Dict[str, Any]:
- return {
- "phase": self.phase,
- "display_name": self.display_name,
- "planned_duration_seconds": self.planned_duration_seconds,
- "actual_duration_seconds": self.actual_duration_seconds,
- "difference_seconds": self.difference_seconds,
- "difference_formatted": self._format_difference(),
- "had_overtime": self.had_overtime,
- "overtime_seconds": self.overtime_seconds,
- "overtime_formatted": self._format_seconds(self.overtime_seconds),
- "was_extended": self.was_extended,
- "extension_minutes": self.extension_minutes,
- "pause_count": self.pause_count,
- "total_pause_seconds": self.total_pause_seconds,
- }
-
- def _format_difference(self) -> str:
- """Formatiert die Differenz als +/-MM:SS."""
- prefix = "+" if self.difference_seconds >= 0 else ""
- return f"{prefix}{self._format_seconds(abs(self.difference_seconds))}"
-
- def _format_seconds(self, seconds: int) -> str:
- """Formatiert Sekunden als MM:SS."""
- mins = seconds // 60
- secs = seconds % 60
- return f"{mins:02d}:{secs:02d}"
-
-
-@dataclass
-class SessionSummary:
- """
- Zusammenfassung einer Unterrichtsstunde.
-
- Wird nach Stundenende generiert und fuer das Lehrer-Dashboard verwendet.
- """
- session_id: str
- teacher_id: str
- class_id: str
- subject: str
- topic: Optional[str]
- date: datetime
-
- # Dauer
- total_duration_seconds: int
- planned_duration_seconds: int
-
- # Phasen-Statistiken
- phases_completed: int
- total_phases: int = 5
- phase_statistics: List[PhaseStatistics] = field(default_factory=list)
-
- # Overtime-Zusammenfassung
- total_overtime_seconds: int = 0
- phases_with_overtime: int = 0
-
- # Pausen-Zusammenfassung
- total_pause_count: int = 0
- total_pause_seconds: int = 0
-
- # Post-Lesson Reflection
- reflection_notes: str = ""
- reflection_rating: Optional[int] = None # 1-5 Sterne (optional)
- key_learnings: List[str] = field(default_factory=list)
-
- def to_dict(self) -> Dict[str, Any]:
- return {
- "session_id": self.session_id,
- "teacher_id": self.teacher_id,
- "class_id": self.class_id,
- "subject": self.subject,
- "topic": self.topic,
- "date": self.date.isoformat() if self.date else None,
- "date_formatted": self._format_date(),
- "total_duration_seconds": self.total_duration_seconds,
- "total_duration_formatted": self._format_seconds(self.total_duration_seconds),
- "planned_duration_seconds": self.planned_duration_seconds,
- "planned_duration_formatted": self._format_seconds(self.planned_duration_seconds),
- "phases_completed": self.phases_completed,
- "total_phases": self.total_phases,
- "completion_percentage": round(self.phases_completed / self.total_phases * 100),
- "phase_statistics": [p.to_dict() for p in self.phase_statistics],
- "total_overtime_seconds": self.total_overtime_seconds,
- "total_overtime_formatted": self._format_seconds(self.total_overtime_seconds),
- "phases_with_overtime": self.phases_with_overtime,
- "total_pause_count": self.total_pause_count,
- "total_pause_seconds": self.total_pause_seconds,
- "reflection_notes": self.reflection_notes,
- "reflection_rating": self.reflection_rating,
- "key_learnings": self.key_learnings,
- }
-
- def _format_seconds(self, seconds: int) -> str:
- mins = seconds // 60
- secs = seconds % 60
- return f"{mins:02d}:{secs:02d}"
-
- def _format_date(self) -> str:
- if not self.date:
- return ""
- return self.date.strftime("%d.%m.%Y %H:%M")
-
-
-@dataclass
-class TeacherAnalytics:
- """
- Aggregierte Statistiken fuer einen Lehrer.
-
- Zeigt Trends und Muster ueber mehrere Stunden.
- """
- teacher_id: str
- period_start: datetime
- period_end: datetime
-
- # Stunden-Uebersicht
- total_sessions: int = 0
- completed_sessions: int = 0
- total_teaching_minutes: int = 0
-
- # Durchschnittliche Phasendauern
- avg_phase_durations: Dict[str, float] = field(default_factory=dict)
-
- # Overtime-Trends
- sessions_with_overtime: int = 0
- avg_overtime_seconds: float = 0
- most_overtime_phase: Optional[str] = None
-
- # Pausen-Statistik
- avg_pause_count: float = 0
- avg_pause_duration_seconds: float = 0
-
- # Faecher-Verteilung
- subjects_taught: Dict[str, int] = field(default_factory=dict)
-
- # Klassen-Verteilung
- classes_taught: Dict[str, int] = field(default_factory=dict)
-
- def to_dict(self) -> Dict[str, Any]:
- return {
- "teacher_id": self.teacher_id,
- "period_start": self.period_start.isoformat() if self.period_start else None,
- "period_end": self.period_end.isoformat() if self.period_end else None,
- "total_sessions": self.total_sessions,
- "completed_sessions": self.completed_sessions,
- "total_teaching_minutes": self.total_teaching_minutes,
- "total_teaching_hours": round(self.total_teaching_minutes / 60, 1),
- "avg_phase_durations": self.avg_phase_durations,
- "sessions_with_overtime": self.sessions_with_overtime,
- "overtime_percentage": round(self.sessions_with_overtime / max(self.total_sessions, 1) * 100),
- "avg_overtime_seconds": round(self.avg_overtime_seconds),
- "avg_overtime_formatted": self._format_seconds(int(self.avg_overtime_seconds)),
- "most_overtime_phase": self.most_overtime_phase,
- "avg_pause_count": round(self.avg_pause_count, 1),
- "avg_pause_duration_seconds": round(self.avg_pause_duration_seconds),
- "subjects_taught": self.subjects_taught,
- "classes_taught": self.classes_taught,
- }
-
- def _format_seconds(self, seconds: int) -> str:
- mins = seconds // 60
- secs = seconds % 60
- return f"{mins:02d}:{secs:02d}"
-
-
-# ==================== Reflection Model ====================
-
-@dataclass
-class LessonReflection:
- """
- Post-Lesson Reflection (Feature).
-
- Ermoeglicht Lehrern, nach der Stunde Notizen zu machen.
- Keine Bewertung, nur Reflexion.
- """
- reflection_id: str
- session_id: str
- teacher_id: str
-
- # Reflexionsnotizen
- notes: str = ""
-
- # Optional: Sterne-Bewertung (selbst-eingeschaetzt)
- overall_rating: Optional[int] = None # 1-5
-
- # Was hat gut funktioniert?
- what_worked: List[str] = field(default_factory=list)
-
- # Was wuerde ich naechstes Mal anders machen?
- improvements: List[str] = field(default_factory=list)
-
- # Notizen fuer naechste Stunde
- notes_for_next_lesson: str = ""
-
- created_at: Optional[datetime] = None
- updated_at: Optional[datetime] = None
-
- def to_dict(self) -> Dict[str, Any]:
- return {
- "reflection_id": self.reflection_id,
- "session_id": self.session_id,
- "teacher_id": self.teacher_id,
- "notes": self.notes,
- "overall_rating": self.overall_rating,
- "what_worked": self.what_worked,
- "improvements": self.improvements,
- "notes_for_next_lesson": self.notes_for_next_lesson,
- "created_at": self.created_at.isoformat() if self.created_at else None,
- "updated_at": self.updated_at.isoformat() if self.updated_at else None,
- }
-
-
-# ==================== Analytics Calculator ====================
class AnalyticsCalculator:
- """
- Berechnet Analytics aus Session-Daten.
-
- Verwendet In-Memory-Daten oder DB-Daten.
- """
+ """Berechnet Analytics aus Session-Daten."""
PHASE_DISPLAY_NAMES = {
"einstieg": "Einstieg",
@@ -276,24 +48,13 @@ class AnalyticsCalculator:
session_data: Dict[str, Any],
phase_history: List[Dict[str, Any]]
) -> SessionSummary:
- """
- Berechnet die Zusammenfassung einer Session.
-
- Args:
- session_data: Session-Dictionary (aus LessonSession.to_dict())
- phase_history: Liste der Phasen-History-Eintraege
-
- Returns:
- SessionSummary mit allen berechneten Statistiken
- """
- # Basis-Daten
+ """Berechnet die Zusammenfassung einer Session."""
session_id = session_data.get("session_id", "")
teacher_id = session_data.get("teacher_id", "")
class_id = session_data.get("class_id", "")
subject = session_data.get("subject", "")
topic = session_data.get("topic")
- # Timestamps
lesson_started = session_data.get("lesson_started_at")
lesson_ended = session_data.get("lesson_ended_at")
@@ -302,16 +63,13 @@ class AnalyticsCalculator:
if isinstance(lesson_ended, str):
lesson_ended = datetime.fromisoformat(lesson_ended.replace("Z", "+00:00"))
- # Dauer berechnen
total_duration = 0
if lesson_started and lesson_ended:
total_duration = int((lesson_ended - lesson_started).total_seconds())
- # Geplante Dauer
phase_durations = session_data.get("phase_durations", {})
- planned_duration = sum(phase_durations.values()) * 60 # Minuten zu Sekunden
+ planned_duration = sum(phase_durations.values()) * 60
- # Phasen-Statistiken berechnen
phase_stats = []
total_overtime = 0
phases_with_overtime = 0
@@ -324,18 +82,10 @@ class AnalyticsCalculator:
if phase in ["not_started", "ended"]:
continue
- # Geplante Dauer fuer diese Phase
planned_seconds = phase_durations.get(phase, 0) * 60
-
- # Tatsaechliche Dauer
- actual_seconds = entry.get("duration_seconds", 0)
- if actual_seconds is None:
- actual_seconds = 0
-
- # Differenz
+ actual_seconds = entry.get("duration_seconds", 0) or 0
difference = actual_seconds - planned_seconds
- # Overtime (nur positive Differenz zaehlt)
had_overtime = difference > 0
overtime_seconds = max(0, difference)
@@ -343,13 +93,11 @@ class AnalyticsCalculator:
total_overtime += overtime_seconds
phases_with_overtime += 1
- # Pausen
pause_count = entry.get("pause_count", 0) or 0
pause_seconds = entry.get("total_pause_seconds", 0) or 0
total_pause_count += pause_count
total_pause_seconds += pause_seconds
- # Phase als abgeschlossen zaehlen
if entry.get("ended_at"):
phases_completed += 1
@@ -368,16 +116,12 @@ class AnalyticsCalculator:
))
return SessionSummary(
- session_id=session_id,
- teacher_id=teacher_id,
- class_id=class_id,
- subject=subject,
- topic=topic,
+ session_id=session_id, teacher_id=teacher_id,
+ class_id=class_id, subject=subject, topic=topic,
date=lesson_started or datetime.now(),
total_duration_seconds=total_duration,
planned_duration_seconds=planned_duration,
- phases_completed=phases_completed,
- total_phases=5,
+ phases_completed=phases_completed, total_phases=5,
phase_statistics=phase_stats,
total_overtime_seconds=total_overtime,
phases_with_overtime=phases_with_overtime,
@@ -392,31 +136,15 @@ class AnalyticsCalculator:
period_start: datetime,
period_end: datetime
) -> TeacherAnalytics:
- """
- Berechnet aggregierte Statistiken fuer einen Lehrer.
-
- Args:
- sessions: Liste von Session-Dictionaries
- period_start: Beginn des Zeitraums
- period_end: Ende des Zeitraums
-
- Returns:
- TeacherAnalytics mit aggregierten Statistiken
- """
+ """Berechnet aggregierte Statistiken fuer einen Lehrer."""
if not sessions:
- return TeacherAnalytics(
- teacher_id="",
- period_start=period_start,
- period_end=period_end,
- )
+ return TeacherAnalytics(teacher_id="", period_start=period_start, period_end=period_end)
teacher_id = sessions[0].get("teacher_id", "")
- # Basis-Zaehler
total_sessions = len(sessions)
completed_sessions = sum(1 for s in sessions if s.get("lesson_ended_at"))
- # Gesamtdauer berechnen
total_minutes = 0
for session in sessions:
started = session.get("lesson_started_at")
@@ -428,41 +156,29 @@ class AnalyticsCalculator:
ended = datetime.fromisoformat(ended.replace("Z", "+00:00"))
total_minutes += (ended - started).total_seconds() / 60
- # Durchschnittliche Phasendauern
phase_durations_sum: Dict[str, List[int]] = {
- "einstieg": [],
- "erarbeitung": [],
- "sicherung": [],
- "transfer": [],
- "reflexion": [],
+ "einstieg": [], "erarbeitung": [], "sicherung": [],
+ "transfer": [], "reflexion": [],
}
- # Overtime-Tracking
overtime_count = 0
overtime_seconds_total = 0
phase_overtime: Dict[str, int] = {}
-
- # Pausen-Tracking
pause_counts = []
pause_durations = []
-
- # Faecher und Klassen
subjects: Dict[str, int] = {}
classes: Dict[str, int] = {}
for session in sessions:
- # Fach und Klasse zaehlen
subject = session.get("subject", "")
class_id = session.get("class_id", "")
subjects[subject] = subjects.get(subject, 0) + 1
classes[class_id] = classes.get(class_id, 0) + 1
- # Phase History analysieren
history = session.get("phase_history", [])
session_has_overtime = False
session_pause_count = 0
session_pause_duration = 0
-
phase_durations_dict = session.get("phase_durations", {})
for entry in history:
@@ -471,7 +187,6 @@ class AnalyticsCalculator:
duration = entry.get("duration_seconds", 0) or 0
phase_durations_sum[phase].append(duration)
- # Overtime berechnen
planned = phase_durations_dict.get(phase, 0) * 60
if duration > planned:
overtime = duration - planned
@@ -479,35 +194,25 @@ class AnalyticsCalculator:
session_has_overtime = True
phase_overtime[phase] = phase_overtime.get(phase, 0) + overtime
- # Pausen zaehlen
session_pause_count += entry.get("pause_count", 0) or 0
session_pause_duration += entry.get("total_pause_seconds", 0) or 0
if session_has_overtime:
overtime_count += 1
-
pause_counts.append(session_pause_count)
pause_durations.append(session_pause_duration)
- # Durchschnitte berechnen
avg_durations = {}
for phase, durations in phase_durations_sum.items():
- if durations:
- avg_durations[phase] = round(sum(durations) / len(durations))
- else:
- avg_durations[phase] = 0
+ avg_durations[phase] = round(sum(durations) / len(durations)) if durations else 0
- # Phase mit meistem Overtime finden
most_overtime_phase = None
if phase_overtime:
most_overtime_phase = max(phase_overtime, key=phase_overtime.get)
return TeacherAnalytics(
- teacher_id=teacher_id,
- period_start=period_start,
- period_end=period_end,
- total_sessions=total_sessions,
- completed_sessions=completed_sessions,
+ teacher_id=teacher_id, period_start=period_start, period_end=period_end,
+ total_sessions=total_sessions, completed_sessions=completed_sessions,
total_teaching_minutes=int(total_minutes),
avg_phase_durations=avg_durations,
sessions_with_overtime=overtime_count,
@@ -515,6 +220,5 @@ class AnalyticsCalculator:
most_overtime_phase=most_overtime_phase,
avg_pause_count=sum(pause_counts) / max(len(pause_counts), 1),
avg_pause_duration_seconds=sum(pause_durations) / max(len(pause_durations), 1),
- subjects_taught=subjects,
- classes_taught=classes,
+ subjects_taught=subjects, classes_taught=classes,
)
diff --git a/backend-lehrer/classroom_engine/analytics_models.py b/backend-lehrer/classroom_engine/analytics_models.py
new file mode 100644
index 0000000..c18008e
--- /dev/null
+++ b/backend-lehrer/classroom_engine/analytics_models.py
@@ -0,0 +1,205 @@
+"""
+Analytics Models - Datenstrukturen fuer Classroom Analytics.
+
+Enthaelt PhaseStatistics, SessionSummary, TeacherAnalytics, LessonReflection.
+"""
+
+from dataclasses import dataclass, field
+from datetime import datetime
+from typing import Optional, List, Dict, Any
+
+
+@dataclass
+class PhaseStatistics:
+    """Timing statistics for a single lesson phase."""
+    phase: str
+    display_name: str
+
+    # Duration metrics
+    planned_duration_seconds: int
+    actual_duration_seconds: int
+    difference_seconds: int  # positive = longer than planned
+
+    # Overtime
+    had_overtime: bool
+    overtime_seconds: int = 0
+
+    # Extensions
+    was_extended: bool = False
+    extension_minutes: int = 0
+
+    # Pauses
+    pause_count: int = 0
+    total_pause_seconds: int = 0
+
+    def to_dict(self) -> Dict[str, Any]:
+        """Return all fields as a plain dict, plus MM:SS strings for difference and overtime."""
+        return {
+            "phase": self.phase, "display_name": self.display_name,
+            "planned_duration_seconds": self.planned_duration_seconds,
+            "actual_duration_seconds": self.actual_duration_seconds,
+            "difference_seconds": self.difference_seconds,
+            "difference_formatted": self._format_difference(),
+            "had_overtime": self.had_overtime,
+            "overtime_seconds": self.overtime_seconds,
+            "overtime_formatted": self._format_seconds(self.overtime_seconds),
+            "was_extended": self.was_extended, "extension_minutes": self.extension_minutes,
+            "pause_count": self.pause_count,
+            "total_pause_seconds": self.total_pause_seconds,
+        }
+
+    def _format_difference(self) -> str:
+        """Format difference_seconds as signed MM:SS; abs() drops the sign, so "-" is added back."""
+        sign = "+" if self.difference_seconds >= 0 else "-"
+        return f"{sign}{self._format_seconds(abs(self.difference_seconds))}"
+
+    def _format_seconds(self, seconds: int) -> str:
+        """Format a number of seconds as zero-padded MM:SS (minutes may exceed 59)."""
+        mins, secs = divmod(seconds, 60)
+        return f"{mins:02d}:{secs:02d}"
+
+
+@dataclass
+class SessionSummary:
+    """Summary of a single lesson (session), including per-phase statistics."""
+    session_id: str
+    teacher_id: str
+    class_id: str
+    subject: str
+    topic: Optional[str]
+    date: datetime
+
+    total_duration_seconds: int
+    planned_duration_seconds: int
+
+    phases_completed: int
+    total_phases: int = 5
+    phase_statistics: List[PhaseStatistics] = field(default_factory=list)
+
+    total_overtime_seconds: int = 0
+    phases_with_overtime: int = 0
+
+    total_pause_count: int = 0
+    total_pause_seconds: int = 0
+
+    reflection_notes: str = ""
+    reflection_rating: Optional[int] = None
+    key_learnings: List[str] = field(default_factory=list)
+
+    def to_dict(self) -> Dict[str, Any]:
+        """Return the summary as a plain dict with formatted durations, date and completion percentage."""
+        return {
+            "session_id": self.session_id, "teacher_id": self.teacher_id,
+            "class_id": self.class_id,
+            "subject": self.subject,
+            "topic": self.topic,
+            "date": self.date.isoformat() if self.date else None,
+            "date_formatted": self._format_date(),
+            "total_duration_seconds": self.total_duration_seconds,
+            "total_duration_formatted": self._format_seconds(self.total_duration_seconds),
+            "planned_duration_seconds": self.planned_duration_seconds,
+            "planned_duration_formatted": self._format_seconds(self.planned_duration_seconds),
+            "phases_completed": self.phases_completed,
+            "total_phases": self.total_phases,
+            # max(..., 1) avoids a ZeroDivisionError when total_phases is 0.
+            "completion_percentage": round(self.phases_completed / max(self.total_phases, 1) * 100),
+            "phase_statistics": [p.to_dict() for p in self.phase_statistics],
+            "total_overtime_seconds": self.total_overtime_seconds,
+            "total_overtime_formatted": self._format_seconds(self.total_overtime_seconds),
+            "phases_with_overtime": self.phases_with_overtime,
+            "total_pause_count": self.total_pause_count,
+            "total_pause_seconds": self.total_pause_seconds,
+            "reflection_notes": self.reflection_notes,
+            "reflection_rating": self.reflection_rating,
+            "key_learnings": self.key_learnings,
+        }
+
+    def _format_seconds(self, seconds: int) -> str:
+        """Format a number of seconds as zero-padded MM:SS (minutes may exceed 59)."""
+        mins, secs = divmod(seconds, 60)
+        return f"{mins:02d}:{secs:02d}"
+
+    def _format_date(self) -> str:
+        """Format date as DD.MM.YYYY HH:MM, or return "" when no date is set."""
+        return self.date.strftime("%d.%m.%Y %H:%M") if self.date else ""
+
+
+@dataclass
+class TeacherAnalytics:
+    """Aggregated statistics for one teacher over a reporting period."""
+    teacher_id: str
+    period_start: datetime
+    period_end: datetime
+
+    total_sessions: int = 0
+    completed_sessions: int = 0
+    total_teaching_minutes: int = 0
+
+    avg_phase_durations: Dict[str, float] = field(default_factory=dict)
+
+    sessions_with_overtime: int = 0
+    avg_overtime_seconds: float = 0
+    most_overtime_phase: Optional[str] = None
+
+    avg_pause_count: float = 0
+    avg_pause_duration_seconds: float = 0
+
+    subjects_taught: Dict[str, int] = field(default_factory=dict)
+    classes_taught: Dict[str, int] = field(default_factory=dict)
+
+    def to_dict(self) -> Dict[str, Any]:
+        """Return the aggregates as a plain dict with derived hours, percentage and MM:SS values."""
+        return {
+            "teacher_id": self.teacher_id,
+            "period_start": self.period_start.isoformat() if self.period_start else None,
+            "period_end": self.period_end.isoformat() if self.period_end else None,
+            "total_sessions": self.total_sessions, "completed_sessions": self.completed_sessions,
+            "total_teaching_minutes": self.total_teaching_minutes,
+            "total_teaching_hours": round(self.total_teaching_minutes / 60, 1),
+            "avg_phase_durations": self.avg_phase_durations,
+            "sessions_with_overtime": self.sessions_with_overtime,
+            "overtime_percentage": round(self.sessions_with_overtime / max(self.total_sessions, 1) * 100),
+            "avg_overtime_seconds": round(self.avg_overtime_seconds),
+            # Format the rounded value so this string always agrees with avg_overtime_seconds above.
+            "avg_overtime_formatted": self._format_seconds(round(self.avg_overtime_seconds)),
+            "most_overtime_phase": self.most_overtime_phase,
+            "avg_pause_count": round(self.avg_pause_count, 1),
+            "avg_pause_duration_seconds": round(self.avg_pause_duration_seconds),
+            "subjects_taught": self.subjects_taught, "classes_taught": self.classes_taught,
+        }
+
+    def _format_seconds(self, seconds: int) -> str:
+        """Format a number of seconds as zero-padded MM:SS (minutes may exceed 59)."""
+        mins, secs = divmod(seconds, 60)
+        return f"{mins:02d}:{secs:02d}"
+
+
+@dataclass
+class LessonReflection:
+    """Post-lesson reflection record that can be exported as a plain dict."""
+    reflection_id: str
+    session_id: str
+    teacher_id: str
+
+    notes: str = ""
+    overall_rating: Optional[int] = None
+    what_worked: List[str] = field(default_factory=list)
+    improvements: List[str] = field(default_factory=list)
+    notes_for_next_lesson: str = ""
+
+    created_at: Optional[datetime] = None
+    updated_at: Optional[datetime] = None
+
+    def to_dict(self) -> Dict[str, Any]:
+        """Return all fields as a plain dict; timestamps become ISO strings, or None when unset."""
+        created = self.created_at.isoformat() if self.created_at else None
+        updated = self.updated_at.isoformat() if self.updated_at else None
+        return dict(
+            reflection_id=self.reflection_id,
+            session_id=self.session_id, teacher_id=self.teacher_id,
+            notes=self.notes, overall_rating=self.overall_rating,
+            what_worked=self.what_worked, improvements=self.improvements,
+            notes_for_next_lesson=self.notes_for_next_lesson,
+            created_at=created,
+            updated_at=updated,
+        )
diff --git a/backend-lehrer/classroom_engine/suggestion_data.py b/backend-lehrer/classroom_engine/suggestion_data.py
new file mode 100644
index 0000000..7a9e245
--- /dev/null
+++ b/backend-lehrer/classroom_engine/suggestion_data.py
@@ -0,0 +1,494 @@
+"""
+Phasenspezifische und fachspezifische Vorschlags-Daten (Feature f18).
+
+Enthaelt die vordefinierten Vorschlaege fuer allgemeine Phasen
+und fachspezifische Aktivitaeten.
+"""
+
+from typing import List, Dict, Any
+
+from .models import LessonPhase
+
+
+# Supported subjects for subject-specific suggestions (several appear with short or English aliases)
+SUPPORTED_SUBJECTS = [
+ "mathematik", "mathe", "math",
+ "deutsch",
+ "englisch", "english",
+ "biologie", "bio",
+ "physik",
+ "chemie",  # NOTE(review): no SUBJECT_SUGGESTIONS entry in this file - confirm intended
+ "geschichte",  # NOTE(review): no SUBJECT_SUGGESTIONS entry in this file
+ "geografie", "erdkunde",  # NOTE(review): no SUBJECT_SUGGESTIONS entry in this file
+ "kunst",  # NOTE(review): no SUBJECT_SUGGESTIONS entry in this file
+ "musik",  # NOTE(review): no SUBJECT_SUGGESTIONS entry in this file
+ "sport",  # NOTE(review): no SUBJECT_SUGGESTIONS entry in this file
+ "informatik",
+]
+
+
+# Fachspezifische Vorschlaege (Feature f18)
+SUBJECT_SUGGESTIONS: Dict[str, Dict[LessonPhase, List[Dict[str, Any]]]] = {
+ "mathematik": {
+ LessonPhase.EINSTIEG: [
+ {
+ "id": "math_warm_up",
+ "title": "Kopfrechnen-Challenge",
+ "description": "5 schnelle Kopfrechenaufgaben zum Aufwaermen",
+ "activity_type": "warmup",
+ "estimated_minutes": 3,
+ "icon": "calculate",
+ "subjects": ["mathematik", "mathe"],
+ },
+ {
+ "id": "math_puzzle",
+ "title": "Mathematisches Raetsel",
+ "description": "Ein kniffliges Zahlenraetsel als Einstieg",
+ "activity_type": "motivation",
+ "estimated_minutes": 5,
+ "icon": "extension",
+ "subjects": ["mathematik", "mathe"],
+ },
+ ],
+ LessonPhase.ERARBEITUNG: [
+ {
+ "id": "math_geogebra",
+ "title": "GeoGebra-Exploration",
+ "description": "Interaktive Visualisierung mit GeoGebra",
+ "activity_type": "individual_work",
+ "estimated_minutes": 15,
+ "icon": "functions",
+ "subjects": ["mathematik", "mathe"],
+ },
+ {
+ "id": "math_peer_explain",
+ "title": "Rechenweg erklaeren",
+ "description": "Schueler erklaeren sich gegenseitig ihre Loesungswege",
+ "activity_type": "partner_work",
+ "estimated_minutes": 10,
+ "icon": "groups",
+ "subjects": ["mathematik", "mathe"],
+ },
+ ],
+ LessonPhase.SICHERUNG: [
+ {
+ "id": "math_formula_card",
+ "title": "Formelkarte erstellen",
+ "description": "Wichtigste Formeln auf einer Karte festhalten",
+ "activity_type": "documentation",
+ "estimated_minutes": 5,
+ "icon": "note_alt",
+ "subjects": ["mathematik", "mathe"],
+ },
+ ],
+ },
+ "deutsch": {
+ LessonPhase.EINSTIEG: [
+ {
+ "id": "deutsch_wordle",
+ "title": "Wordle-Variante",
+ "description": "Wort des Tages erraten",
+ "activity_type": "warmup",
+ "estimated_minutes": 4,
+ "icon": "abc",
+ "subjects": ["deutsch"],
+ },
+ {
+ "id": "deutsch_zitat",
+ "title": "Zitat-Interpretation",
+ "description": "Ein literarisches Zitat gemeinsam deuten",
+ "activity_type": "motivation",
+ "estimated_minutes": 5,
+ "icon": "format_quote",
+ "subjects": ["deutsch"],
+ },
+ ],
+ LessonPhase.ERARBEITUNG: [
+ {
+ "id": "deutsch_textarbeit",
+ "title": "Textanalyse in Gruppen",
+ "description": "Gruppenarbeit zu verschiedenen Textabschnitten",
+ "activity_type": "group_work",
+ "estimated_minutes": 15,
+ "icon": "menu_book",
+ "subjects": ["deutsch"],
+ },
+ {
+ "id": "deutsch_schreibworkshop",
+ "title": "Schreibwerkstatt",
+ "description": "Kreatives Schreiben mit Peer-Feedback",
+ "activity_type": "individual_work",
+ "estimated_minutes": 20,
+ "icon": "edit_note",
+ "subjects": ["deutsch"],
+ },
+ ],
+ LessonPhase.SICHERUNG: [
+ {
+ "id": "deutsch_zusammenfassung",
+ "title": "Text-Zusammenfassung",
+ "description": "Die wichtigsten Punkte in 3 Saetzen formulieren",
+ "activity_type": "summary",
+ "estimated_minutes": 5,
+ "icon": "summarize",
+ "subjects": ["deutsch"],
+ },
+ ],
+ },
+ "englisch": {
+ LessonPhase.EINSTIEG: [
+ {
+ "id": "english_smalltalk",
+ "title": "Small Talk Warm-Up",
+ "description": "2-Minuten Gespraeche zu einem Alltagsthema",
+ "activity_type": "warmup",
+ "estimated_minutes": 4,
+ "icon": "chat",
+ "subjects": ["englisch", "english"],
+ },
+ {
+ "id": "english_video",
+ "title": "Authentic Video Clip",
+ "description": "Kurzer Clip aus einer englischen Serie oder Nachricht",
+ "activity_type": "motivation",
+ "estimated_minutes": 5,
+ "icon": "movie",
+ "subjects": ["englisch", "english"],
+ },
+ ],
+ LessonPhase.ERARBEITUNG: [
+ {
+ "id": "english_role_play",
+ "title": "Role Play Activity",
+ "description": "Dialoguebung in authentischen Situationen",
+ "activity_type": "partner_work",
+ "estimated_minutes": 12,
+ "icon": "theater_comedy",
+ "subjects": ["englisch", "english"],
+ },
+ {
+ "id": "english_reading_circle",
+ "title": "Reading Circle",
+ "description": "Gemeinsames Lesen mit verteilten Rollen",
+ "activity_type": "group_work",
+ "estimated_minutes": 15,
+ "icon": "auto_stories",
+ "subjects": ["englisch", "english"],
+ },
+ ],
+ },
+ "biologie": {
+ LessonPhase.EINSTIEG: [
+ {
+ "id": "bio_nature_question",
+ "title": "Naturfrage",
+ "description": "Eine spannende Frage aus der Natur diskutieren",
+ "activity_type": "motivation",
+ "estimated_minutes": 5,
+ "icon": "eco",
+ "subjects": ["biologie", "bio"],
+ },
+ ],
+ LessonPhase.ERARBEITUNG: [
+ {
+ "id": "bio_experiment",
+ "title": "Mini-Experiment",
+ "description": "Einfaches Experiment zum Thema durchfuehren",
+ "activity_type": "group_work",
+ "estimated_minutes": 20,
+ "icon": "science",
+ "subjects": ["biologie", "bio"],
+ },
+ {
+ "id": "bio_diagram",
+ "title": "Biologische Zeichnung",
+ "description": "Beschriftete Zeichnung eines Organismus",
+ "activity_type": "individual_work",
+ "estimated_minutes": 15,
+ "icon": "draw",
+ "subjects": ["biologie", "bio"],
+ },
+ ],
+ },
+ "physik": {
+ LessonPhase.EINSTIEG: [
+ {
+ "id": "physik_demo",
+ "title": "Phaenomen-Demo",
+ "description": "Ein physikalisches Phaenomen vorfuehren",
+ "activity_type": "motivation",
+ "estimated_minutes": 5,
+ "icon": "bolt",
+ "subjects": ["physik"],
+ },
+ ],
+ LessonPhase.ERARBEITUNG: [
+ {
+ "id": "physik_simulation",
+ "title": "PhET-Simulation",
+ "description": "Interaktive Simulation von phet.colorado.edu",
+ "activity_type": "individual_work",
+ "estimated_minutes": 15,
+ "icon": "smart_toy",
+ "subjects": ["physik"],
+ },
+ {
+ "id": "physik_rechnung",
+ "title": "Physikalische Rechnung",
+ "description": "Rechenaufgabe mit physikalischem Kontext",
+ "activity_type": "partner_work",
+ "estimated_minutes": 12,
+ "icon": "calculate",
+ "subjects": ["physik"],
+ },
+ ],
+ },
+ "informatik": {
+ LessonPhase.EINSTIEG: [
+ {
+ "id": "info_code_puzzle",
+ "title": "Code-Puzzle",
+ "description": "Kurzen Code-Schnipsel analysieren - was macht er?",
+ "activity_type": "warmup",
+ "estimated_minutes": 4,
+ "icon": "code",
+ "subjects": ["informatik"],
+ },
+ ],
+ LessonPhase.ERARBEITUNG: [
+ {
+ "id": "info_live_coding",
+ "title": "Live Coding",
+ "description": "Gemeinsam Code entwickeln mit Erklaerungen",
+ "activity_type": "instruction",
+ "estimated_minutes": 15,
+ "icon": "terminal",
+ "subjects": ["informatik"],
+ },
+ {
+ "id": "info_pair_programming",
+ "title": "Pair Programming",
+ "description": "Zu zweit programmieren - Driver und Navigator",
+ "activity_type": "partner_work",
+ "estimated_minutes": 20,
+ "icon": "computer",
+ "subjects": ["informatik"],
+ },
+ ],
+ },
+}
+
+
+# Vordefinierte allgemeine Vorschlaege pro Phase
+PHASE_SUGGESTIONS: Dict[LessonPhase, List[Dict[str, Any]]] = {
+ LessonPhase.EINSTIEG: [
+ {
+ "id": "warmup_quiz",
+ "title": "Kurzes Quiz zum Einstieg",
+ "description": "Aktivieren Sie das Vorwissen der Schueler mit 3-5 Fragen zum Thema",
+ "activity_type": "warmup",
+ "estimated_minutes": 3,
+ "icon": "quiz"
+ },
+ {
+ "id": "problem_story",
+ "title": "Problemgeschichte erzaehlen",
+ "description": "Stellen Sie ein alltagsnahes Problem vor, das zum Thema fuehrt",
+ "activity_type": "motivation",
+ "estimated_minutes": 5,
+ "icon": "auto_stories"
+ },
+ {
+ "id": "video_intro",
+ "title": "Kurzes Erklaervideo",
+ "description": "Zeigen Sie ein 2-3 Minuten Video zur Einfuehrung ins Thema",
+ "activity_type": "motivation",
+ "estimated_minutes": 4,
+ "icon": "play_circle"
+ },
+ {
+ "id": "brainstorming",
+ "title": "Brainstorming",
+ "description": "Sammeln Sie Ideen und Vorkenntnisse der Schueler an der Tafel",
+ "activity_type": "warmup",
+ "estimated_minutes": 5,
+ "icon": "psychology"
+ },
+ {
+ "id": "daily_challenge",
+ "title": "Tagesaufgabe vorstellen",
+ "description": "Praesentieren Sie die zentrale Frage oder Aufgabe der Stunde",
+ "activity_type": "problem_introduction",
+ "estimated_minutes": 3,
+ "icon": "flag"
+ }
+ ],
+ LessonPhase.ERARBEITUNG: [
+ {
+ "id": "think_pair_share",
+ "title": "Think-Pair-Share",
+ "description": "Schueler denken erst einzeln nach, tauschen sich dann zu zweit aus und praesentieren im Plenum",
+ "activity_type": "partner_work",
+ "estimated_minutes": 10,
+ "icon": "groups"
+ },
+ {
+ "id": "worksheet_digital",
+ "title": "Digitales Arbeitsblatt",
+ "description": "Schueler bearbeiten ein interaktives Arbeitsblatt am Tablet oder Computer",
+ "activity_type": "individual_work",
+ "estimated_minutes": 15,
+ "icon": "description"
+ },
+ {
+ "id": "station_learning",
+ "title": "Stationenlernen",
+ "description": "Verschiedene Stationen mit unterschiedlichen Aufgaben und Materialien",
+ "activity_type": "group_work",
+ "estimated_minutes": 20,
+ "icon": "hub"
+ },
+ {
+ "id": "expert_puzzle",
+ "title": "Expertenrunde (Jigsaw)",
+ "description": "Schueler werden Experten fuer ein Teilthema und lehren es anderen",
+ "activity_type": "group_work",
+ "estimated_minutes": 15,
+ "icon": "extension"
+ },
+ {
+ "id": "guided_instruction",
+ "title": "Geleitete Instruktion",
+ "description": "Schrittweise Erklaerung mit Uebungsphasen zwischendurch",
+ "activity_type": "instruction",
+ "estimated_minutes": 12,
+ "icon": "school"
+ },
+ {
+ "id": "pair_programming",
+ "title": "Partnerarbeit",
+ "description": "Zwei Schueler loesen gemeinsam eine Aufgabe",
+ "activity_type": "partner_work",
+ "estimated_minutes": 10,
+ "icon": "people"
+ }
+ ],
+ LessonPhase.SICHERUNG: [
+ {
+ "id": "mindmap_class",
+ "title": "Gemeinsame Mindmap",
+ "description": "Ergebnisse als Mindmap an der Tafel oder digital sammeln und strukturieren",
+ "activity_type": "visualization",
+ "estimated_minutes": 8,
+ "icon": "account_tree"
+ },
+ {
+ "id": "exit_ticket",
+ "title": "Exit Ticket",
+ "description": "Schueler notieren 3 Dinge die sie gelernt haben und 1 offene Frage",
+ "activity_type": "summary",
+ "estimated_minutes": 5,
+ "icon": "sticky_note_2"
+ },
+ {
+ "id": "gallery_walk",
+ "title": "Galerie-Rundgang",
+ "description": "Schueler praesentieren ihre Ergebnisse und geben sich Feedback",
+ "activity_type": "presentation",
+ "estimated_minutes": 10,
+ "icon": "photo_library"
+ },
+ {
+ "id": "key_points",
+ "title": "Kernpunkte zusammenfassen",
+ "description": "Gemeinsam die wichtigsten Erkenntnisse der Stunde formulieren",
+ "activity_type": "summary",
+ "estimated_minutes": 5,
+ "icon": "format_list_bulleted"
+ },
+ {
+ "id": "quick_check",
+ "title": "Schneller Wissenscheck",
+ "description": "5 kurze Fragen zur Ueberpruefung des Verstaendnisses",
+ "activity_type": "documentation",
+ "estimated_minutes": 5,
+ "icon": "fact_check"
+ }
+ ],
+ LessonPhase.TRANSFER: [
+ {
+ "id": "real_world_example",
+ "title": "Alltagsbeispiele finden",
+ "description": "Schueler suchen Beispiele aus ihrem Alltag, wo das Gelernte vorkommt",
+ "activity_type": "application",
+ "estimated_minutes": 5,
+ "icon": "public"
+ },
+ {
+ "id": "challenge_task",
+ "title": "Knobelaufgabe",
+ "description": "Eine anspruchsvollere Aufgabe fuer schnelle Schueler oder als Bonus",
+ "activity_type": "differentiation",
+ "estimated_minutes": 7,
+ "icon": "psychology"
+ },
+ {
+ "id": "creative_application",
+ "title": "Kreative Anwendung",
+ "description": "Schueler wenden das Gelernte in einem kreativen Projekt an",
+ "activity_type": "application",
+ "estimated_minutes": 10,
+ "icon": "palette"
+ },
+ {
+ "id": "peer_teaching",
+ "title": "Peer-Teaching",
+ "description": "Schueler erklaeren sich gegenseitig das Gelernte",
+ "activity_type": "real_world_connection",
+ "estimated_minutes": 5,
+ "icon": "supervisor_account"
+ }
+ ],
+ LessonPhase.REFLEXION: [
+ {
+ "id": "thumbs_feedback",
+ "title": "Daumen-Feedback",
+ "description": "Schnelle Stimmungsabfrage: Daumen hoch/mitte/runter",
+ "activity_type": "feedback",
+ "estimated_minutes": 2,
+ "icon": "thumb_up"
+ },
+ {
+ "id": "homework_assign",
+ "title": "Hausaufgabe vergeben",
+ "description": "Passende Hausaufgabe zur Vertiefung des Gelernten",
+ "activity_type": "homework",
+ "estimated_minutes": 3,
+ "icon": "home_work"
+ },
+ {
+ "id": "one_word",
+ "title": "Ein-Wort-Reflexion",
+ "description": "Jeder Schueler nennt ein Wort, das die Stunde beschreibt",
+ "activity_type": "feedback",
+ "estimated_minutes": 3,
+ "icon": "chat"
+ },
+ {
+ "id": "preview_next",
+ "title": "Ausblick naechste Stunde",
+ "description": "Kurzer Ausblick auf das Thema der naechsten Stunde",
+ "activity_type": "preview",
+ "estimated_minutes": 2,
+ "icon": "event"
+ },
+ {
+ "id": "learning_log",
+ "title": "Lerntagebuch",
+ "description": "Schueler notieren ihre wichtigsten Erkenntnisse im Lerntagebuch",
+ "activity_type": "feedback",
+ "estimated_minutes": 4,
+ "icon": "menu_book"
+ }
+ ]
+}
diff --git a/backend-lehrer/classroom_engine/suggestions.py b/backend-lehrer/classroom_engine/suggestions.py
index 3aa148b..b8e36d2 100644
--- a/backend-lehrer/classroom_engine/suggestions.py
+++ b/backend-lehrer/classroom_engine/suggestions.py
@@ -8,490 +8,11 @@ und optional dem Fach.
from typing import List, Dict, Any, Optional
from .models import LessonPhase, LessonSession, PhaseSuggestion
-
-
-# Unterstuetzte Faecher fuer fachspezifische Vorschlaege
-SUPPORTED_SUBJECTS = [
- "mathematik", "mathe", "math",
- "deutsch",
- "englisch", "english",
- "biologie", "bio",
- "physik",
- "chemie",
- "geschichte",
- "geografie", "erdkunde",
- "kunst",
- "musik",
- "sport",
- "informatik",
-]
-
-
-# Fachspezifische Vorschlaege (Feature f18)
-SUBJECT_SUGGESTIONS: Dict[str, Dict[LessonPhase, List[Dict[str, Any]]]] = {
- "mathematik": {
- LessonPhase.EINSTIEG: [
- {
- "id": "math_warm_up",
- "title": "Kopfrechnen-Challenge",
- "description": "5 schnelle Kopfrechenaufgaben zum Aufwaermen",
- "activity_type": "warmup",
- "estimated_minutes": 3,
- "icon": "calculate",
- "subjects": ["mathematik", "mathe"],
- },
- {
- "id": "math_puzzle",
- "title": "Mathematisches Raetsel",
- "description": "Ein kniffliges Zahlenraetsel als Einstieg",
- "activity_type": "motivation",
- "estimated_minutes": 5,
- "icon": "extension",
- "subjects": ["mathematik", "mathe"],
- },
- ],
- LessonPhase.ERARBEITUNG: [
- {
- "id": "math_geogebra",
- "title": "GeoGebra-Exploration",
- "description": "Interaktive Visualisierung mit GeoGebra",
- "activity_type": "individual_work",
- "estimated_minutes": 15,
- "icon": "functions",
- "subjects": ["mathematik", "mathe"],
- },
- {
- "id": "math_peer_explain",
- "title": "Rechenweg erklaeren",
- "description": "Schueler erklaeren sich gegenseitig ihre Loesungswege",
- "activity_type": "partner_work",
- "estimated_minutes": 10,
- "icon": "groups",
- "subjects": ["mathematik", "mathe"],
- },
- ],
- LessonPhase.SICHERUNG: [
- {
- "id": "math_formula_card",
- "title": "Formelkarte erstellen",
- "description": "Wichtigste Formeln auf einer Karte festhalten",
- "activity_type": "documentation",
- "estimated_minutes": 5,
- "icon": "note_alt",
- "subjects": ["mathematik", "mathe"],
- },
- ],
- },
- "deutsch": {
- LessonPhase.EINSTIEG: [
- {
- "id": "deutsch_wordle",
- "title": "Wordle-Variante",
- "description": "Wort des Tages erraten",
- "activity_type": "warmup",
- "estimated_minutes": 4,
- "icon": "abc",
- "subjects": ["deutsch"],
- },
- {
- "id": "deutsch_zitat",
- "title": "Zitat-Interpretation",
- "description": "Ein literarisches Zitat gemeinsam deuten",
- "activity_type": "motivation",
- "estimated_minutes": 5,
- "icon": "format_quote",
- "subjects": ["deutsch"],
- },
- ],
- LessonPhase.ERARBEITUNG: [
- {
- "id": "deutsch_textarbeit",
- "title": "Textanalyse in Gruppen",
- "description": "Gruppenarbeit zu verschiedenen Textabschnitten",
- "activity_type": "group_work",
- "estimated_minutes": 15,
- "icon": "menu_book",
- "subjects": ["deutsch"],
- },
- {
- "id": "deutsch_schreibworkshop",
- "title": "Schreibwerkstatt",
- "description": "Kreatives Schreiben mit Peer-Feedback",
- "activity_type": "individual_work",
- "estimated_minutes": 20,
- "icon": "edit_note",
- "subjects": ["deutsch"],
- },
- ],
- LessonPhase.SICHERUNG: [
- {
- "id": "deutsch_zusammenfassung",
- "title": "Text-Zusammenfassung",
- "description": "Die wichtigsten Punkte in 3 Saetzen formulieren",
- "activity_type": "summary",
- "estimated_minutes": 5,
- "icon": "summarize",
- "subjects": ["deutsch"],
- },
- ],
- },
- "englisch": {
- LessonPhase.EINSTIEG: [
- {
- "id": "english_smalltalk",
- "title": "Small Talk Warm-Up",
- "description": "2-Minuten Gespraeche zu einem Alltagsthema",
- "activity_type": "warmup",
- "estimated_minutes": 4,
- "icon": "chat",
- "subjects": ["englisch", "english"],
- },
- {
- "id": "english_video",
- "title": "Authentic Video Clip",
- "description": "Kurzer Clip aus einer englischen Serie oder Nachricht",
- "activity_type": "motivation",
- "estimated_minutes": 5,
- "icon": "movie",
- "subjects": ["englisch", "english"],
- },
- ],
- LessonPhase.ERARBEITUNG: [
- {
- "id": "english_role_play",
- "title": "Role Play Activity",
- "description": "Dialoguebung in authentischen Situationen",
- "activity_type": "partner_work",
- "estimated_minutes": 12,
- "icon": "theater_comedy",
- "subjects": ["englisch", "english"],
- },
- {
- "id": "english_reading_circle",
- "title": "Reading Circle",
- "description": "Gemeinsames Lesen mit verteilten Rollen",
- "activity_type": "group_work",
- "estimated_minutes": 15,
- "icon": "auto_stories",
- "subjects": ["englisch", "english"],
- },
- ],
- },
- "biologie": {
- LessonPhase.EINSTIEG: [
- {
- "id": "bio_nature_question",
- "title": "Naturfrage",
- "description": "Eine spannende Frage aus der Natur diskutieren",
- "activity_type": "motivation",
- "estimated_minutes": 5,
- "icon": "eco",
- "subjects": ["biologie", "bio"],
- },
- ],
- LessonPhase.ERARBEITUNG: [
- {
- "id": "bio_experiment",
- "title": "Mini-Experiment",
- "description": "Einfaches Experiment zum Thema durchfuehren",
- "activity_type": "group_work",
- "estimated_minutes": 20,
- "icon": "science",
- "subjects": ["biologie", "bio"],
- },
- {
- "id": "bio_diagram",
- "title": "Biologische Zeichnung",
- "description": "Beschriftete Zeichnung eines Organismus",
- "activity_type": "individual_work",
- "estimated_minutes": 15,
- "icon": "draw",
- "subjects": ["biologie", "bio"],
- },
- ],
- },
- "physik": {
- LessonPhase.EINSTIEG: [
- {
- "id": "physik_demo",
- "title": "Phaenomen-Demo",
- "description": "Ein physikalisches Phaenomen vorfuehren",
- "activity_type": "motivation",
- "estimated_minutes": 5,
- "icon": "bolt",
- "subjects": ["physik"],
- },
- ],
- LessonPhase.ERARBEITUNG: [
- {
- "id": "physik_simulation",
- "title": "PhET-Simulation",
- "description": "Interaktive Simulation von phet.colorado.edu",
- "activity_type": "individual_work",
- "estimated_minutes": 15,
- "icon": "smart_toy",
- "subjects": ["physik"],
- },
- {
- "id": "physik_rechnung",
- "title": "Physikalische Rechnung",
- "description": "Rechenaufgabe mit physikalischem Kontext",
- "activity_type": "partner_work",
- "estimated_minutes": 12,
- "icon": "calculate",
- "subjects": ["physik"],
- },
- ],
- },
- "informatik": {
- LessonPhase.EINSTIEG: [
- {
- "id": "info_code_puzzle",
- "title": "Code-Puzzle",
- "description": "Kurzen Code-Schnipsel analysieren - was macht er?",
- "activity_type": "warmup",
- "estimated_minutes": 4,
- "icon": "code",
- "subjects": ["informatik"],
- },
- ],
- LessonPhase.ERARBEITUNG: [
- {
- "id": "info_live_coding",
- "title": "Live Coding",
- "description": "Gemeinsam Code entwickeln mit Erklaerungen",
- "activity_type": "instruction",
- "estimated_minutes": 15,
- "icon": "terminal",
- "subjects": ["informatik"],
- },
- {
- "id": "info_pair_programming",
- "title": "Pair Programming",
- "description": "Zu zweit programmieren - Driver und Navigator",
- "activity_type": "partner_work",
- "estimated_minutes": 20,
- "icon": "computer",
- "subjects": ["informatik"],
- },
- ],
- },
-}
-
-
-# Vordefinierte allgemeine Vorschlaege pro Phase
-PHASE_SUGGESTIONS: Dict[LessonPhase, List[Dict[str, Any]]] = {
- LessonPhase.EINSTIEG: [
- {
- "id": "warmup_quiz",
- "title": "Kurzes Quiz zum Einstieg",
- "description": "Aktivieren Sie das Vorwissen der Schueler mit 3-5 Fragen zum Thema",
- "activity_type": "warmup",
- "estimated_minutes": 3,
- "icon": "quiz"
- },
- {
- "id": "problem_story",
- "title": "Problemgeschichte erzaehlen",
- "description": "Stellen Sie ein alltagsnahes Problem vor, das zum Thema fuehrt",
- "activity_type": "motivation",
- "estimated_minutes": 5,
- "icon": "auto_stories"
- },
- {
- "id": "video_intro",
- "title": "Kurzes Erklaervideo",
- "description": "Zeigen Sie ein 2-3 Minuten Video zur Einfuehrung ins Thema",
- "activity_type": "motivation",
- "estimated_minutes": 4,
- "icon": "play_circle"
- },
- {
- "id": "brainstorming",
- "title": "Brainstorming",
- "description": "Sammeln Sie Ideen und Vorkenntnisse der Schueler an der Tafel",
- "activity_type": "warmup",
- "estimated_minutes": 5,
- "icon": "psychology"
- },
- {
- "id": "daily_challenge",
- "title": "Tagesaufgabe vorstellen",
- "description": "Praesentieren Sie die zentrale Frage oder Aufgabe der Stunde",
- "activity_type": "problem_introduction",
- "estimated_minutes": 3,
- "icon": "flag"
- }
- ],
- LessonPhase.ERARBEITUNG: [
- {
- "id": "think_pair_share",
- "title": "Think-Pair-Share",
- "description": "Schueler denken erst einzeln nach, tauschen sich dann zu zweit aus und praesentieren im Plenum",
- "activity_type": "partner_work",
- "estimated_minutes": 10,
- "icon": "groups"
- },
- {
- "id": "worksheet_digital",
- "title": "Digitales Arbeitsblatt",
- "description": "Schueler bearbeiten ein interaktives Arbeitsblatt am Tablet oder Computer",
- "activity_type": "individual_work",
- "estimated_minutes": 15,
- "icon": "description"
- },
- {
- "id": "station_learning",
- "title": "Stationenlernen",
- "description": "Verschiedene Stationen mit unterschiedlichen Aufgaben und Materialien",
- "activity_type": "group_work",
- "estimated_minutes": 20,
- "icon": "hub"
- },
- {
- "id": "expert_puzzle",
- "title": "Expertenrunde (Jigsaw)",
- "description": "Schueler werden Experten fuer ein Teilthema und lehren es anderen",
- "activity_type": "group_work",
- "estimated_minutes": 15,
- "icon": "extension"
- },
- {
- "id": "guided_instruction",
- "title": "Geleitete Instruktion",
- "description": "Schrittweise Erklaerung mit Uebungsphasen zwischendurch",
- "activity_type": "instruction",
- "estimated_minutes": 12,
- "icon": "school"
- },
- {
- "id": "pair_programming",
- "title": "Partnerarbeit",
- "description": "Zwei Schueler loesen gemeinsam eine Aufgabe",
- "activity_type": "partner_work",
- "estimated_minutes": 10,
- "icon": "people"
- }
- ],
- LessonPhase.SICHERUNG: [
- {
- "id": "mindmap_class",
- "title": "Gemeinsame Mindmap",
- "description": "Ergebnisse als Mindmap an der Tafel oder digital sammeln und strukturieren",
- "activity_type": "visualization",
- "estimated_minutes": 8,
- "icon": "account_tree"
- },
- {
- "id": "exit_ticket",
- "title": "Exit Ticket",
- "description": "Schueler notieren 3 Dinge die sie gelernt haben und 1 offene Frage",
- "activity_type": "summary",
- "estimated_minutes": 5,
- "icon": "sticky_note_2"
- },
- {
- "id": "gallery_walk",
- "title": "Galerie-Rundgang",
- "description": "Schueler praesentieren ihre Ergebnisse und geben sich Feedback",
- "activity_type": "presentation",
- "estimated_minutes": 10,
- "icon": "photo_library"
- },
- {
- "id": "key_points",
- "title": "Kernpunkte zusammenfassen",
- "description": "Gemeinsam die wichtigsten Erkenntnisse der Stunde formulieren",
- "activity_type": "summary",
- "estimated_minutes": 5,
- "icon": "format_list_bulleted"
- },
- {
- "id": "quick_check",
- "title": "Schneller Wissenscheck",
- "description": "5 kurze Fragen zur Ueberpruefung des Verstaendnisses",
- "activity_type": "documentation",
- "estimated_minutes": 5,
- "icon": "fact_check"
- }
- ],
- LessonPhase.TRANSFER: [
- {
- "id": "real_world_example",
- "title": "Alltagsbeispiele finden",
- "description": "Schueler suchen Beispiele aus ihrem Alltag, wo das Gelernte vorkommt",
- "activity_type": "application",
- "estimated_minutes": 5,
- "icon": "public"
- },
- {
- "id": "challenge_task",
- "title": "Knobelaufgabe",
- "description": "Eine anspruchsvollere Aufgabe fuer schnelle Schueler oder als Bonus",
- "activity_type": "differentiation",
- "estimated_minutes": 7,
- "icon": "psychology"
- },
- {
- "id": "creative_application",
- "title": "Kreative Anwendung",
- "description": "Schueler wenden das Gelernte in einem kreativen Projekt an",
- "activity_type": "application",
- "estimated_minutes": 10,
- "icon": "palette"
- },
- {
- "id": "peer_teaching",
- "title": "Peer-Teaching",
- "description": "Schueler erklaeren sich gegenseitig das Gelernte",
- "activity_type": "real_world_connection",
- "estimated_minutes": 5,
- "icon": "supervisor_account"
- }
- ],
- LessonPhase.REFLEXION: [
- {
- "id": "thumbs_feedback",
- "title": "Daumen-Feedback",
- "description": "Schnelle Stimmungsabfrage: Daumen hoch/mitte/runter",
- "activity_type": "feedback",
- "estimated_minutes": 2,
- "icon": "thumb_up"
- },
- {
- "id": "homework_assign",
- "title": "Hausaufgabe vergeben",
- "description": "Passende Hausaufgabe zur Vertiefung des Gelernten",
- "activity_type": "homework",
- "estimated_minutes": 3,
- "icon": "home_work"
- },
- {
- "id": "one_word",
- "title": "Ein-Wort-Reflexion",
- "description": "Jeder Schueler nennt ein Wort, das die Stunde beschreibt",
- "activity_type": "feedback",
- "estimated_minutes": 3,
- "icon": "chat"
- },
- {
- "id": "preview_next",
- "title": "Ausblick naechste Stunde",
- "description": "Kurzer Ausblick auf das Thema der naechsten Stunde",
- "activity_type": "preview",
- "estimated_minutes": 2,
- "icon": "event"
- },
- {
- "id": "learning_log",
- "title": "Lerntagebuch",
- "description": "Schueler notieren ihre wichtigsten Erkenntnisse im Lerntagebuch",
- "activity_type": "feedback",
- "estimated_minutes": 4,
- "icon": "menu_book"
- }
- ]
-}
+from .suggestion_data import (
+ SUPPORTED_SUBJECTS,
+ SUBJECT_SUGGESTIONS,
+ PHASE_SUGGESTIONS,
+)
class SuggestionEngine:
diff --git a/backend-lehrer/content_generators/__init__.py b/backend-lehrer/content_generators/__init__.py
index 2313c04..2292f89 100644
--- a/backend-lehrer/content_generators/__init__.py
+++ b/backend-lehrer/content_generators/__init__.py
@@ -11,10 +11,12 @@ from .h5p_generator import (
generate_h5p_manifest,
)
-from .pdf_generator import (
- PDFGenerator,
+from .worksheet_models import (
Worksheet,
WorksheetSection,
+)
+from .pdf_generator import (
+ PDFGenerator,
generate_worksheet_html,
generate_worksheet_pdf,
)
diff --git a/backend-lehrer/content_generators/pdf_generator.py b/backend-lehrer/content_generators/pdf_generator.py
index d41c442..c036732 100644
--- a/backend-lehrer/content_generators/pdf_generator.py
+++ b/backend-lehrer/content_generators/pdf_generator.py
@@ -12,252 +12,9 @@ Structure:
6. Reflection Questions
"""
-import io
-from dataclasses import dataclass
-from typing import Any, Optional, Union
+from typing import Optional, Union
-# Note: In production, use reportlab or weasyprint for actual PDF generation
-# This module generates an intermediate format that can be converted to PDF
-
-
-@dataclass
-class WorksheetSection:
- """A section of the worksheet"""
- title: str
- content_type: str # "text", "table", "exercises", "blanks"
- content: Any
- difficulty: int = 1 # 1-4
-
-
-@dataclass
-class Worksheet:
- """Complete worksheet structure"""
- title: str
- subtitle: str
- unit_id: str
- locale: str
- sections: list[WorksheetSection]
- footer: str = ""
-
- def to_html(self) -> str:
- """Convert worksheet to HTML (for PDF conversion via weasyprint)"""
- html_parts = [
- "",
- "",
- "",
- "
",
- "",
- "",
- "",
- f"
{self.title} ",
- f"{self.subtitle}
",
- ]
-
- for section in self.sections:
- html_parts.append(self._render_section(section))
-
- html_parts.extend([
- f"
",
- "",
- ""
- ])
-
- return "\n".join(html_parts)
-
- def _get_styles(self) -> str:
- return """
- @page {
- size: A4;
- margin: 2cm;
- }
- body {
- font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
- font-size: 11pt;
- line-height: 1.5;
- color: #333;
- }
- header {
- text-align: center;
- margin-bottom: 1.5em;
- border-bottom: 2px solid #2c5282;
- padding-bottom: 1em;
- }
- h1 {
- color: #2c5282;
- margin-bottom: 0.25em;
- font-size: 20pt;
- }
- .subtitle {
- color: #666;
- font-style: italic;
- }
- h2 {
- color: #2c5282;
- border-bottom: 1px solid #e2e8f0;
- padding-bottom: 0.25em;
- margin-top: 1.5em;
- font-size: 14pt;
- }
- h3 {
- color: #4a5568;
- font-size: 12pt;
- }
- table {
- width: 100%;
- border-collapse: collapse;
- margin: 1em 0;
- }
- th, td {
- border: 1px solid #e2e8f0;
- padding: 0.5em;
- text-align: left;
- }
- th {
- background-color: #edf2f7;
- font-weight: bold;
- }
- .exercise {
- margin: 1em 0;
- padding: 1em;
- background-color: #f7fafc;
- border-left: 4px solid #4299e1;
- }
- .exercise-number {
- font-weight: bold;
- color: #2c5282;
- }
- .blank {
- display: inline-block;
- min-width: 100px;
- border-bottom: 1px solid #333;
- margin: 0 0.25em;
- }
- .difficulty {
- font-size: 9pt;
- color: #718096;
- }
- .difficulty-1 { color: #48bb78; }
- .difficulty-2 { color: #4299e1; }
- .difficulty-3 { color: #ed8936; }
- .difficulty-4 { color: #f56565; }
- .reflection {
- margin-top: 2em;
- padding: 1em;
- background-color: #fffaf0;
- border: 1px dashed #ed8936;
- }
- .write-area {
- min-height: 80px;
- border: 1px solid #e2e8f0;
- margin: 0.5em 0;
- background-color: #fff;
- }
- footer {
- margin-top: 2em;
- padding-top: 1em;
- border-top: 1px solid #e2e8f0;
- font-size: 9pt;
- color: #718096;
- text-align: center;
- }
- ul, ol {
- margin: 0.5em 0;
- padding-left: 1.5em;
- }
- .objectives {
- background-color: #ebf8ff;
- padding: 1em;
- border-radius: 4px;
- }
- """
-
- def _render_section(self, section: WorksheetSection) -> str:
- parts = [f"
{section.title} "]
-
- if section.content_type == "text":
- parts.append(f"{section.content}
")
-
- elif section.content_type == "objectives":
- parts.append("")
- for obj in section.content:
- parts.append(f"{obj} ")
- parts.append(" ")
-
- elif section.content_type == "table":
- parts.append("")
- for header in section.content.get("headers", []):
- parts.append(f"{header} ")
- parts.append(" ")
- for row in section.content.get("rows", []):
- parts.append("")
- for cell in row:
- parts.append(f"{cell} ")
- parts.append(" ")
- parts.append("
")
-
- elif section.content_type == "exercises":
- for i, ex in enumerate(section.content, 1):
- diff_class = f"difficulty-{ex.get('difficulty', 1)}"
- diff_stars = "*" * ex.get("difficulty", 1)
- parts.append(f"""
-
-
Aufgabe {i}
-
({diff_stars})
-
{ex.get('question', '')}
- {self._render_exercise_input(ex)}
-
- """)
-
- elif section.content_type == "blanks":
- text = section.content
- # Replace *word* with blank
- import re
- text = re.sub(r'\*([^*]+)\*', r" ", text)
- parts.append(f"{text}
")
-
- elif section.content_type == "reflection":
- parts.append("")
- parts.append(f"
{section.content.get('prompt', '')}
")
- parts.append("
")
- parts.append("
")
-
- parts.append("")
- return "\n".join(parts)
-
- def _render_exercise_input(self, exercise: dict) -> str:
- ex_type = exercise.get("type", "text")
-
- if ex_type == "multiple_choice":
- options = exercise.get("options", [])
- parts = ["
"]
- for opt in options:
- parts.append(f"□ {opt} ")
- parts.append(" ")
- return "\n".join(parts)
-
- elif ex_type == "matching":
- left = exercise.get("left", [])
- right = exercise.get("right", [])
- parts = ["
Begriff Zuordnung "]
- for i, item in enumerate(left):
- right_item = right[i] if i < len(right) else ""
- parts.append(f"{item} ")
- parts.append("
")
- return "\n".join(parts)
-
- elif ex_type == "sequence":
- items = exercise.get("items", [])
- parts = ["
Bringe in die richtige Reihenfolge:
"]
- for item in items:
- parts.append(f" ")
- parts.append(" ")
- parts.append(f"
Begriffe: {', '.join(items)}
")
- return "\n".join(parts)
-
- else:
- return "
"
+from .worksheet_models import Worksheet, WorksheetSection
class PDFGenerator:
@@ -267,15 +24,7 @@ class PDFGenerator:
self.locale = locale
def generate_from_unit(self, unit: dict) -> Worksheet:
- """
- Generate a worksheet from a unit definition.
-
- Args:
- unit: Unit definition dictionary
-
- Returns:
- Worksheet object
- """
+ """Generate a worksheet from a unit definition."""
unit_id = unit.get("unit_id", "unknown")
title = self._get_localized(unit.get("title"), "Arbeitsblatt")
objectives = unit.get("learning_objectives", [])
@@ -283,51 +32,36 @@ class PDFGenerator:
sections = []
- # Learning Objectives
if objectives:
sections.append(WorksheetSection(
- title="Lernziele",
- content_type="objectives",
- content=objectives
+ title="Lernziele", content_type="objectives", content=objectives
))
- # Vocabulary Table
vocab_section = self._create_vocabulary_section(stops)
if vocab_section:
sections.append(vocab_section)
- # Key Concepts Summary
concepts_section = self._create_concepts_section(stops)
if concepts_section:
sections.append(concepts_section)
- # Basic Exercises
basic_exercises = self._create_basic_exercises(stops)
if basic_exercises:
sections.append(WorksheetSection(
- title="Ubungen - Basis",
- content_type="exercises",
- content=basic_exercises,
- difficulty=1
+ title="Ubungen - Basis", content_type="exercises",
+ content=basic_exercises, difficulty=1
))
- # Challenge Exercises
challenge_exercises = self._create_challenge_exercises(stops, unit)
if challenge_exercises:
sections.append(WorksheetSection(
- title="Ubungen - Herausforderung",
- content_type="exercises",
- content=challenge_exercises,
- difficulty=3
+ title="Ubungen - Herausforderung", content_type="exercises",
+ content=challenge_exercises, difficulty=3
))
- # Reflection
sections.append(WorksheetSection(
- title="Reflexion",
- content_type="reflection",
- content={
- "prompt": "Erklaere in eigenen Worten, was du heute gelernt hast:"
- }
+ title="Reflexion", content_type="reflection",
+ content={"prompt": "Erklaere in eigenen Worten, was du heute gelernt hast:"}
))
return Worksheet(
@@ -370,12 +104,8 @@ class PDFGenerator:
return None
return WorksheetSection(
- title="Wichtige Begriffe",
- content_type="table",
- content={
- "headers": ["Begriff", "Erklarung"],
- "rows": rows
- }
+ title="Wichtige Begriffe", content_type="table",
+ content={"headers": ["Begriff", "Erklarung"], "rows": rows}
)
def _create_concepts_section(self, stops: list) -> Optional[WorksheetSection]:
@@ -392,19 +122,14 @@ class PDFGenerator:
return None
return WorksheetSection(
- title="Zusammenfassung",
- content_type="table",
- content={
- "headers": ["Station", "Was hast du gelernt?"],
- "rows": rows
- }
+ title="Zusammenfassung", content_type="table",
+ content={"headers": ["Station", "Was hast du gelernt?"], "rows": rows}
)
def _create_basic_exercises(self, stops: list) -> list[dict]:
"""Create basic difficulty exercises"""
exercises = []
- # Vocabulary matching
vocab_items = []
for stop in stops:
for v in stop.get("vocab", []):
@@ -422,7 +147,6 @@ class PDFGenerator:
"difficulty": 1
})
- # True/False from concepts
for stop in stops[:3]:
concept = stop.get("concept", {})
why = self._get_localized(concept.get("why"))
@@ -435,7 +159,6 @@ class PDFGenerator:
})
break
- # Sequence ordering (for FlightPath)
if len(stops) >= 4:
labels = [self._get_localized(s.get("label")) for s in stops[:6] if self._get_localized(s.get("label"))]
if len(labels) >= 4:
@@ -455,7 +178,6 @@ class PDFGenerator:
"""Create challenging exercises"""
exercises = []
- # Misconception identification
for stop in stops:
concept = stop.get("concept", {})
misconception = self._get_localized(concept.get("common_misconception"))
@@ -472,14 +194,12 @@ class PDFGenerator:
if len(exercises) >= 2:
break
- # Transfer/Application question
exercises.append({
"type": "text",
"question": "Erklaere einem Freund in 2-3 Satzen, was du gelernt hast:",
"difficulty": 3
})
- # Critical thinking
exercises.append({
"type": "text",
"question": "Was moechtest du noch mehr uber dieses Thema erfahren?",
@@ -490,35 +210,14 @@ class PDFGenerator:
def generate_worksheet_html(unit_definition: dict, locale: str = "de-DE") -> str:
- """
- Generate HTML worksheet from unit definition.
-
- Args:
- unit_definition: The unit JSON definition
- locale: Target locale for content
-
- Returns:
- HTML string ready for PDF conversion
- """
+ """Generate HTML worksheet from unit definition."""
generator = PDFGenerator(locale=locale)
worksheet = generator.generate_from_unit(unit_definition)
return worksheet.to_html()
def generate_worksheet_pdf(unit_definition: dict, locale: str = "de-DE") -> bytes:
- """
- Generate PDF worksheet from unit definition.
-
- Requires weasyprint to be installed:
- pip install weasyprint
-
- Args:
- unit_definition: The unit JSON definition
- locale: Target locale for content
-
- Returns:
- PDF bytes
- """
+ """Generate PDF worksheet from unit definition."""
try:
from weasyprint import HTML
except ImportError:
diff --git a/backend-lehrer/content_generators/worksheet_models.py b/backend-lehrer/content_generators/worksheet_models.py
new file mode 100644
index 0000000..e245f5b
--- /dev/null
+++ b/backend-lehrer/content_generators/worksheet_models.py
@@ -0,0 +1,247 @@
+"""
+Worksheet Models - Datenstrukturen und HTML-Rendering fuer Arbeitsblaetter.
+"""
+
+import re
+from dataclasses import dataclass
+from typing import Any
+
+
+@dataclass
+class WorksheetSection:
+ """A section of the worksheet"""
+ title: str
+ content_type: str # "text", "table", "exercises", "blanks"
+ content: Any
+ difficulty: int = 1 # 1-4
+
+
+@dataclass
+class Worksheet:
+ """Complete worksheet structure"""
+ title: str
+ subtitle: str
+ unit_id: str
+ locale: str
+ sections: list[WorksheetSection]
+ footer: str = ""
+
+ def to_html(self) -> str:
+ """Convert worksheet to HTML (for PDF conversion via weasyprint)"""
+ html_parts = [
+ "",
+ "",
+ "",
+ "
",
+ "",
+ "",
+ "",
+ f"
{self.title} ",
+ f"{self.subtitle}
",
+ ]
+
+ for section in self.sections:
+ html_parts.append(_render_section(section))
+
+ html_parts.extend([
+ f"
",
+ "",
+ ""
+ ])
+
+ return "\n".join(html_parts)
+
+
+def _get_styles() -> str:
+ return """
+ @page {
+ size: A4;
+ margin: 2cm;
+ }
+ body {
+ font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
+ font-size: 11pt;
+ line-height: 1.5;
+ color: #333;
+ }
+ header {
+ text-align: center;
+ margin-bottom: 1.5em;
+ border-bottom: 2px solid #2c5282;
+ padding-bottom: 1em;
+ }
+ h1 {
+ color: #2c5282;
+ margin-bottom: 0.25em;
+ font-size: 20pt;
+ }
+ .subtitle {
+ color: #666;
+ font-style: italic;
+ }
+ h2 {
+ color: #2c5282;
+ border-bottom: 1px solid #e2e8f0;
+ padding-bottom: 0.25em;
+ margin-top: 1.5em;
+ font-size: 14pt;
+ }
+ h3 {
+ color: #4a5568;
+ font-size: 12pt;
+ }
+ table {
+ width: 100%;
+ border-collapse: collapse;
+ margin: 1em 0;
+ }
+ th, td {
+ border: 1px solid #e2e8f0;
+ padding: 0.5em;
+ text-align: left;
+ }
+ th {
+ background-color: #edf2f7;
+ font-weight: bold;
+ }
+ .exercise {
+ margin: 1em 0;
+ padding: 1em;
+ background-color: #f7fafc;
+ border-left: 4px solid #4299e1;
+ }
+ .exercise-number {
+ font-weight: bold;
+ color: #2c5282;
+ }
+ .blank {
+ display: inline-block;
+ min-width: 100px;
+ border-bottom: 1px solid #333;
+ margin: 0 0.25em;
+ }
+ .difficulty {
+ font-size: 9pt;
+ color: #718096;
+ }
+ .difficulty-1 { color: #48bb78; }
+ .difficulty-2 { color: #4299e1; }
+ .difficulty-3 { color: #ed8936; }
+ .difficulty-4 { color: #f56565; }
+ .reflection {
+ margin-top: 2em;
+ padding: 1em;
+ background-color: #fffaf0;
+ border: 1px dashed #ed8936;
+ }
+ .write-area {
+ min-height: 80px;
+ border: 1px solid #e2e8f0;
+ margin: 0.5em 0;
+ background-color: #fff;
+ }
+ footer {
+ margin-top: 2em;
+ padding-top: 1em;
+ border-top: 1px solid #e2e8f0;
+ font-size: 9pt;
+ color: #718096;
+ text-align: center;
+ }
+ ul, ol {
+ margin: 0.5em 0;
+ padding-left: 1.5em;
+ }
+ .objectives {
+ background-color: #ebf8ff;
+ padding: 1em;
+ border-radius: 4px;
+ }
+ """
+
+
+def _render_section(section: WorksheetSection) -> str:
+ parts = [f"
{section.title} "]
+
+ if section.content_type == "text":
+ parts.append(f"{section.content}
")
+
+ elif section.content_type == "objectives":
+ parts.append("")
+ for obj in section.content:
+ parts.append(f"{obj} ")
+ parts.append(" ")
+
+ elif section.content_type == "table":
+ parts.append("")
+ for header in section.content.get("headers", []):
+ parts.append(f"{header} ")
+ parts.append(" ")
+ for row in section.content.get("rows", []):
+ parts.append("")
+ for cell in row:
+ parts.append(f"{cell} ")
+ parts.append(" ")
+ parts.append("
")
+
+ elif section.content_type == "exercises":
+ for i, ex in enumerate(section.content, 1):
+ diff_class = f"difficulty-{ex.get('difficulty', 1)}"
+ diff_stars = "*" * ex.get("difficulty", 1)
+ parts.append(f"""
+
+
Aufgabe {i}
+
({diff_stars})
+
{ex.get('question', '')}
+ {_render_exercise_input(ex)}
+
+ """)
+
+ elif section.content_type == "blanks":
+ text = section.content
+ text = re.sub(r'\*([^*]+)\*', r" ", text)
+ parts.append(f"{text}
")
+
+ elif section.content_type == "reflection":
+ parts.append("")
+ parts.append(f"
{section.content.get('prompt', '')}
")
+ parts.append("
")
+ parts.append("
")
+
+ parts.append("")
+ return "\n".join(parts)
+
+
+def _render_exercise_input(exercise: dict) -> str:
+ ex_type = exercise.get("type", "text")
+
+ if ex_type == "multiple_choice":
+ options = exercise.get("options", [])
+ parts = ["
"]
+ for opt in options:
+ parts.append(f"□ {opt} ")
+ parts.append(" ")
+ return "\n".join(parts)
+
+ elif ex_type == "matching":
+ left = exercise.get("left", [])
+ right = exercise.get("right", [])
+ parts = ["
Begriff Zuordnung "]
+ for i, item in enumerate(left):
+ parts.append(f"{item} ")
+ parts.append("
")
+ return "\n".join(parts)
+
+ elif ex_type == "sequence":
+ items = exercise.get("items", [])
+ parts = ["
Bringe in die richtige Reihenfolge:
"]
+ for item in items:
+ parts.append(f" ")
+ parts.append(" ")
+ parts.append(f"
Begriffe: {', '.join(items)}
")
+ return "\n".join(parts)
+
+ else:
+ return "
"
diff --git a/backend-lehrer/generators/quiz_generator.py b/backend-lehrer/generators/quiz_generator.py
index 3f4b6e5..4732cd1 100644
--- a/backend-lehrer/generators/quiz_generator.py
+++ b/backend-lehrer/generators/quiz_generator.py
@@ -10,66 +10,27 @@ Generiert:
import logging
import json
-import re
-from typing import List, Dict, Any, Optional, Tuple
-from dataclasses import dataclass
-from enum import Enum
+from typing import List, Dict, Any, Optional
+
+from .quiz_models import (
+ QuizType,
+ TrueFalseQuestion,
+ MatchingPair,
+ SortingItem,
+ OpenQuestion,
+ Quiz,
+)
+from .quiz_helpers import (
+ extract_factual_sentences,
+ negate_sentence,
+ extract_definitions,
+ extract_sequence,
+ extract_keywords,
+)
logger = logging.getLogger(__name__)
-class QuizType(str, Enum):
- """Typen von Quiz-Aufgaben."""
- TRUE_FALSE = "true_false"
- MATCHING = "matching"
- SORTING = "sorting"
- OPEN_ENDED = "open_ended"
-
-
-@dataclass
-class TrueFalseQuestion:
- """Eine Wahr/Falsch-Frage."""
- statement: str
- is_true: bool
- explanation: str
- source_reference: Optional[str] = None
-
-
-@dataclass
-class MatchingPair:
- """Ein Zuordnungspaar."""
- left: str
- right: str
- hint: Optional[str] = None
-
-
-@dataclass
-class SortingItem:
- """Ein Element zum Sortieren."""
- text: str
- correct_position: int
- category: Optional[str] = None
-
-
-@dataclass
-class OpenQuestion:
- """Eine offene Frage."""
- question: str
- model_answer: str
- keywords: List[str]
- points: int = 1
-
-
-@dataclass
-class Quiz:
- """Ein komplettes Quiz."""
- quiz_type: QuizType
- title: str
- questions: List[Any] # Je nach Typ unterschiedlich
- topic: Optional[str] = None
- difficulty: str = "medium"
-
-
class QuizGenerator:
"""
Generiert verschiedene Quiz-Typen aus Quelltexten.
@@ -146,13 +107,12 @@ class QuizGenerator:
return self._generate_true_false_llm(source_text, num_questions, difficulty)
# Automatische Generierung
- sentences = self._extract_factual_sentences(source_text)
+ sentences = extract_factual_sentences(source_text)
questions = []
for i, sentence in enumerate(sentences[:num_questions]):
# Abwechselnd wahre und falsche Aussagen
if i % 2 == 0:
- # Wahre Aussage
questions.append(TrueFalseQuestion(
statement=sentence,
is_true=True,
@@ -160,8 +120,7 @@ class QuizGenerator:
source_reference=sentence[:50]
))
else:
- # Falsche Aussage (Negation)
- false_statement = self._negate_sentence(sentence)
+ false_statement = negate_sentence(sentence)
questions.append(TrueFalseQuestion(
statement=false_statement,
is_true=False,
@@ -222,9 +181,8 @@ Antworte im JSON-Format:
if self.llm_client:
return self._generate_matching_llm(source_text, num_pairs, difficulty)
- # Automatische Generierung: Begriff -> Definition
pairs = []
- definitions = self._extract_definitions(source_text)
+ definitions = extract_definitions(source_text)
for term, definition in definitions[:num_pairs]:
pairs.append(MatchingPair(
@@ -286,9 +244,8 @@ Antworte im JSON-Format:
if self.llm_client:
return self._generate_sorting_llm(source_text, num_items, difficulty)
- # Automatische Generierung: Chronologische Reihenfolge
items = []
- steps = self._extract_sequence(source_text)
+ steps = extract_sequence(source_text)
for i, step in enumerate(steps[:num_items]):
items.append(SortingItem(
@@ -349,9 +306,8 @@ Antworte im JSON-Format:
if self.llm_client:
return self._generate_open_ended_llm(source_text, num_questions, difficulty)
- # Automatische Generierung
questions = []
- sentences = self._extract_factual_sentences(source_text)
+ sentences = extract_factual_sentences(source_text)
question_starters = [
"Was bedeutet",
@@ -362,8 +318,7 @@ Antworte im JSON-Format:
]
for i, sentence in enumerate(sentences[:num_questions]):
- # Extrahiere Schlüsselwort
- keywords = self._extract_keywords(sentence)
+ keywords = extract_keywords(sentence)
if keywords:
keyword = keywords[0]
starter = question_starters[i % len(question_starters)]
@@ -421,76 +376,6 @@ Antworte im JSON-Format:
logger.error(f"LLM error: {e}")
return self._generate_open_ended(source_text, num_questions, difficulty)
- # Hilfsmethoden
-
- def _extract_factual_sentences(self, text: str) -> List[str]:
- """Extrahiert Fakten-Sätze aus dem Text."""
- sentences = re.split(r'[.!?]+', text)
- factual = []
-
- for sentence in sentences:
- sentence = sentence.strip()
- # Filtere zu kurze oder fragende Sätze
- if len(sentence) > 20 and '?' not in sentence:
- factual.append(sentence)
-
- return factual
-
- def _negate_sentence(self, sentence: str) -> str:
- """Negiert eine Aussage einfach."""
- # Einfache Negation durch Einfügen von "nicht"
- words = sentence.split()
- if len(words) > 2:
- # Nach erstem Verb "nicht" einfügen
- for i, word in enumerate(words):
- if word.endswith(('t', 'en', 'st')) and i > 0:
- words.insert(i + 1, 'nicht')
- break
- return ' '.join(words)
-
- def _extract_definitions(self, text: str) -> List[Tuple[str, str]]:
- """Extrahiert Begriff-Definition-Paare."""
- definitions = []
-
- # Suche nach Mustern wie "X ist Y" oder "X bezeichnet Y"
- patterns = [
- r'(\w+)\s+ist\s+(.+?)[.]',
- r'(\w+)\s+bezeichnet\s+(.+?)[.]',
- r'(\w+)\s+bedeutet\s+(.+?)[.]',
- r'(\w+):\s+(.+?)[.]',
- ]
-
- for pattern in patterns:
- matches = re.findall(pattern, text)
- for term, definition in matches:
- if len(definition) > 10:
- definitions.append((term, definition.strip()))
-
- return definitions
-
- def _extract_sequence(self, text: str) -> List[str]:
- """Extrahiert eine Sequenz von Schritten."""
- steps = []
-
- # Suche nach nummerierten Schritten
- numbered = re.findall(r'\d+[.)]\s*([^.]+)', text)
- steps.extend(numbered)
-
- # Suche nach Signalwörtern
- signal_words = ['zuerst', 'dann', 'danach', 'anschließend', 'schließlich']
- for word in signal_words:
- pattern = rf'{word}\s+([^.]+)'
- matches = re.findall(pattern, text, re.IGNORECASE)
- steps.extend(matches)
-
- return steps
-
- def _extract_keywords(self, text: str) -> List[str]:
- """Extrahiert Schlüsselwörter."""
- # Längere Wörter mit Großbuchstaben (meist Substantive)
- words = re.findall(r'\b[A-ZÄÖÜ][a-zäöüß]+\b', text)
- return list(set(words))[:5]
-
def _empty_quiz(self, quiz_type: QuizType, title: str) -> Quiz:
"""Erstellt leeres Quiz bei Fehler."""
return Quiz(
@@ -549,7 +434,6 @@ Antworte im JSON-Format:
return self._true_false_to_h5p(quiz)
elif quiz.quiz_type == QuizType.MATCHING:
return self._matching_to_h5p(quiz)
- # Weitere Typen...
return {}
def _true_false_to_h5p(self, quiz: Quiz) -> Dict[str, Any]:
diff --git a/backend-lehrer/generators/quiz_helpers.py b/backend-lehrer/generators/quiz_helpers.py
new file mode 100644
index 0000000..650dc76
--- /dev/null
+++ b/backend-lehrer/generators/quiz_helpers.py
@@ -0,0 +1,70 @@
+"""
+Quiz Helpers - Text-Verarbeitungs-Hilfsfunktionen fuer Quiz-Generierung.
+"""
+
+import re
+from typing import List, Tuple
+
+
+def extract_factual_sentences(text: str) -> List[str]:
+ """Extrahiert Fakten-Sätze aus dem Text."""
+ sentences = re.split(r'[.!?]+', text)
+ factual = []
+
+ for sentence in sentences:
+ sentence = sentence.strip()
+ if len(sentence) > 20 and '?' not in sentence:
+ factual.append(sentence)
+
+ return factual
+
+
+def negate_sentence(sentence: str) -> str:
+ """Negiert eine Aussage einfach."""
+ words = sentence.split()
+ if len(words) > 2:
+ for i, word in enumerate(words):
+ if word.endswith(('t', 'en', 'st')) and i > 0:
+ words.insert(i + 1, 'nicht')
+ break
+ return ' '.join(words)
+
+
+def extract_definitions(text: str) -> List[Tuple[str, str]]:
+ """Extrahiert Begriff-Definition-Paare."""
+ definitions = []
+ patterns = [
+ r'(\w+)\s+ist\s+(.+?)[.]',
+ r'(\w+)\s+bezeichnet\s+(.+?)[.]',
+ r'(\w+)\s+bedeutet\s+(.+?)[.]',
+ r'(\w+):\s+(.+?)[.]',
+ ]
+
+ for pattern in patterns:
+ matches = re.findall(pattern, text)
+ for term, definition in matches:
+ if len(definition) > 10:
+ definitions.append((term, definition.strip()))
+
+ return definitions
+
+
+def extract_sequence(text: str) -> List[str]:
+ """Extrahiert eine Sequenz von Schritten."""
+ steps = []
+ numbered = re.findall(r'\d+[.)]\s*([^.]+)', text)
+ steps.extend(numbered)
+
+ signal_words = ['zuerst', 'dann', 'danach', 'anschließend', 'schließlich']
+ for word in signal_words:
+ pattern = rf'{word}\s+([^.]+)'
+ matches = re.findall(pattern, text, re.IGNORECASE)
+ steps.extend(matches)
+
+ return steps
+
+
+def extract_keywords(text: str) -> List[str]:
+ """Extrahiert Schlüsselwörter."""
+ words = re.findall(r'\b[A-ZÄÖÜ][a-zäöüß]+\b', text)
+ return list(set(words))[:5]
diff --git a/backend-lehrer/generators/quiz_models.py b/backend-lehrer/generators/quiz_models.py
new file mode 100644
index 0000000..d466811
--- /dev/null
+++ b/backend-lehrer/generators/quiz_models.py
@@ -0,0 +1,65 @@
+"""
+Quiz Models - Datenmodelle fuer Quiz-Generierung.
+
+Enthaelt alle Dataclasses und Enums fuer Quiz-Typen:
+- True/False Fragen
+- Zuordnungsaufgaben (Matching)
+- Sortieraufgaben
+- Offene Fragen
+"""
+
+from typing import List, Any, Optional
+from dataclasses import dataclass
+from enum import Enum
+
+
+class QuizType(str, Enum):
+ """Typen von Quiz-Aufgaben."""
+ TRUE_FALSE = "true_false"
+ MATCHING = "matching"
+ SORTING = "sorting"
+ OPEN_ENDED = "open_ended"
+
+
+@dataclass
+class TrueFalseQuestion:
+ """Eine Wahr/Falsch-Frage."""
+ statement: str
+ is_true: bool
+ explanation: str
+ source_reference: Optional[str] = None
+
+
+@dataclass
+class MatchingPair:
+ """Ein Zuordnungspaar."""
+ left: str
+ right: str
+ hint: Optional[str] = None
+
+
+@dataclass
+class SortingItem:
+ """Ein Element zum Sortieren."""
+ text: str
+ correct_position: int
+ category: Optional[str] = None
+
+
+@dataclass
+class OpenQuestion:
+ """Eine offene Frage."""
+ question: str
+ model_answer: str
+ keywords: List[str]
+ points: int = 1
+
+
+@dataclass
+class Quiz:
+ """Ein komplettes Quiz."""
+ quiz_type: QuizType
+ title: str
+ questions: List[Any] # Je nach Typ unterschiedlich
+ topic: Optional[str] = None
+ difficulty: str = "medium"
diff --git a/backend-lehrer/llm_gateway/routes/comparison.py b/backend-lehrer/llm_gateway/routes/comparison.py
index b662d40..b4ab076 100644
--- a/backend-lehrer/llm_gateway/routes/comparison.py
+++ b/backend-lehrer/llm_gateway/routes/comparison.py
@@ -9,378 +9,33 @@ Dieses Modul ermoeglicht:
import asyncio
import logging
-import time
import uuid
from datetime import datetime, timezone
from typing import Optional
-from pydantic import BaseModel, Field
from fastapi import APIRouter, HTTPException, Depends
-from ..models.chat import ChatMessage
from ..middleware.auth import verify_api_key
+from .comparison_models import (
+ ComparisonRequest,
+ LLMResponse,
+ ComparisonResponse,
+ SavedComparison,
+ _comparisons_store,
+ _system_prompts_store,
+)
+from .comparison_providers import (
+ call_openai,
+ call_claude,
+ search_tavily,
+ search_edusearch,
+ call_selfhosted_with_search,
+)
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/comparison", tags=["LLM Comparison"])
-class ComparisonRequest(BaseModel):
- """Request fuer LLM-Vergleich."""
- prompt: str = Field(..., description="User prompt (z.B. Lehrer-Frage)")
- system_prompt: Optional[str] = Field(None, description="Optionaler System Prompt")
- enable_openai: bool = Field(True, description="OpenAI/ChatGPT aktivieren")
- enable_claude: bool = Field(True, description="Claude aktivieren")
- enable_selfhosted_tavily: bool = Field(True, description="Self-hosted + Tavily aktivieren")
- enable_selfhosted_edusearch: bool = Field(True, description="Self-hosted + EduSearch aktivieren")
-
- # Parameter fuer Self-hosted Modelle
- selfhosted_model: str = Field("llama3.2:3b", description="Self-hosted Modell")
- temperature: float = Field(0.7, ge=0.0, le=2.0, description="Temperature")
- top_p: float = Field(0.9, ge=0.0, le=1.0, description="Top-p Sampling")
- max_tokens: int = Field(2048, ge=1, le=8192, description="Max Tokens")
-
- # Search Parameter
- search_results_count: int = Field(5, ge=1, le=20, description="Anzahl Suchergebnisse")
- edu_search_filters: Optional[dict] = Field(None, description="Filter fuer EduSearch")
-
-
-class LLMResponse(BaseModel):
- """Antwort eines einzelnen LLM."""
- provider: str
- model: str
- response: str
- latency_ms: int
- tokens_used: Optional[int] = None
- search_results: Optional[list] = None
- error: Optional[str] = None
- timestamp: datetime = Field(default_factory=datetime.utcnow)
-
-
-class ComparisonResponse(BaseModel):
- """Gesamt-Antwort des Vergleichs."""
- comparison_id: str
- prompt: str
- system_prompt: Optional[str]
- responses: list[LLMResponse]
- created_at: datetime = Field(default_factory=datetime.utcnow)
-
-
-class SavedComparison(BaseModel):
- """Gespeicherter Vergleich fuer QA."""
- comparison_id: str
- prompt: str
- system_prompt: Optional[str]
- responses: list[LLMResponse]
- notes: Optional[str] = None
- rating: Optional[dict] = None # {"openai": 4, "claude": 5, ...}
- created_at: datetime
- created_by: Optional[str] = None
-
-
-# In-Memory Storage (in Production: Database)
-_comparisons_store: dict[str, SavedComparison] = {}
-_system_prompts_store: dict[str, dict] = {
- "default": {
- "id": "default",
- "name": "Standard Lehrer-Assistent",
- "prompt": """Du bist ein hilfreicher Assistent fuer Lehrkraefte in Deutschland.
-Deine Aufgaben:
-- Hilfe bei der Unterrichtsplanung
-- Erklaerung von Fachinhalten
-- Erstellung von Arbeitsblaettern und Pruefungen
-- Beratung zu paedagogischen Methoden
-
-Antworte immer auf Deutsch und beachte den deutschen Lehrplankontext.""",
- "created_at": datetime.now(timezone.utc).isoformat(),
- },
- "curriculum": {
- "id": "curriculum",
- "name": "Lehrplan-Experte",
- "prompt": """Du bist ein Experte fuer deutsche Lehrplaene und Bildungsstandards.
-Du kennst:
-- Lehrplaene aller 16 Bundeslaender
-- KMK Bildungsstandards
-- Kompetenzorientierung im deutschen Bildungssystem
-
-Beziehe dich immer auf konkrete Lehrplanvorgaben wenn moeglich.""",
- "created_at": datetime.now(timezone.utc).isoformat(),
- },
- "worksheet": {
- "id": "worksheet",
- "name": "Arbeitsblatt-Generator",
- "prompt": """Du bist ein spezialisierter Assistent fuer die Erstellung von Arbeitsblaettern.
-Erstelle didaktisch sinnvolle Aufgaben mit:
-- Klaren Arbeitsanweisungen
-- Differenzierungsmoeglichkeiten
-- Loesungshinweisen
-
-Format: Markdown mit klarer Struktur.""",
- "created_at": datetime.now(timezone.utc).isoformat(),
- },
-}
-
-
-async def _call_openai(prompt: str, system_prompt: Optional[str]) -> LLMResponse:
- """Ruft OpenAI ChatGPT auf."""
- import os
- import httpx
-
- start_time = time.time()
- api_key = os.getenv("OPENAI_API_KEY")
-
- if not api_key:
- return LLMResponse(
- provider="openai",
- model="gpt-4o-mini",
- response="",
- latency_ms=0,
- error="OPENAI_API_KEY nicht konfiguriert"
- )
-
- messages = []
- if system_prompt:
- messages.append({"role": "system", "content": system_prompt})
- messages.append({"role": "user", "content": prompt})
-
- try:
- async with httpx.AsyncClient(timeout=60.0) as client:
- response = await client.post(
- "https://api.openai.com/v1/chat/completions",
- headers={
- "Authorization": f"Bearer {api_key}",
- "Content-Type": "application/json",
- },
- json={
- "model": "gpt-4o-mini",
- "messages": messages,
- "temperature": 0.7,
- "max_tokens": 2048,
- },
- )
- response.raise_for_status()
- data = response.json()
-
- latency_ms = int((time.time() - start_time) * 1000)
- content = data["choices"][0]["message"]["content"]
- tokens = data.get("usage", {}).get("total_tokens")
-
- return LLMResponse(
- provider="openai",
- model="gpt-4o-mini",
- response=content,
- latency_ms=latency_ms,
- tokens_used=tokens,
- )
- except Exception as e:
- return LLMResponse(
- provider="openai",
- model="gpt-4o-mini",
- response="",
- latency_ms=int((time.time() - start_time) * 1000),
- error=str(e),
- )
-
-
-async def _call_claude(prompt: str, system_prompt: Optional[str]) -> LLMResponse:
- """Ruft Anthropic Claude auf."""
- import os
-
- start_time = time.time()
- api_key = os.getenv("ANTHROPIC_API_KEY")
-
- if not api_key:
- return LLMResponse(
- provider="claude",
- model="claude-3-5-sonnet-20241022",
- response="",
- latency_ms=0,
- error="ANTHROPIC_API_KEY nicht konfiguriert"
- )
-
- try:
- import anthropic
- client = anthropic.AsyncAnthropic(api_key=api_key)
-
- response = await client.messages.create(
- model="claude-3-5-sonnet-20241022",
- max_tokens=2048,
- system=system_prompt or "",
- messages=[{"role": "user", "content": prompt}],
- )
-
- latency_ms = int((time.time() - start_time) * 1000)
- content = response.content[0].text if response.content else ""
- tokens = response.usage.input_tokens + response.usage.output_tokens
-
- return LLMResponse(
- provider="claude",
- model="claude-3-5-sonnet-20241022",
- response=content,
- latency_ms=latency_ms,
- tokens_used=tokens,
- )
- except Exception as e:
- return LLMResponse(
- provider="claude",
- model="claude-3-5-sonnet-20241022",
- response="",
- latency_ms=int((time.time() - start_time) * 1000),
- error=str(e),
- )
-
-
-async def _search_tavily(query: str, count: int = 5) -> list[dict]:
- """Sucht mit Tavily API."""
- import os
- import httpx
-
- api_key = os.getenv("TAVILY_API_KEY")
- if not api_key:
- return []
-
- try:
- async with httpx.AsyncClient(timeout=30.0) as client:
- response = await client.post(
- "https://api.tavily.com/search",
- json={
- "api_key": api_key,
- "query": query,
- "max_results": count,
- "include_domains": [
- "kmk.org", "bildungsserver.de", "bpb.de",
- "bayern.de", "nrw.de", "berlin.de",
- ],
- },
- )
- response.raise_for_status()
- data = response.json()
- return data.get("results", [])
- except Exception as e:
- logger.error(f"Tavily search error: {e}")
- return []
-
-
-async def _search_edusearch(query: str, count: int = 5, filters: Optional[dict] = None) -> list[dict]:
- """Sucht mit EduSearch API."""
- import os
- import httpx
-
- edu_search_url = os.getenv("EDU_SEARCH_URL", "http://edu-search-service:8084")
-
- try:
- async with httpx.AsyncClient(timeout=30.0) as client:
- payload = {
- "q": query,
- "limit": count,
- "mode": "keyword",
- }
- if filters:
- payload["filters"] = filters
-
- response = await client.post(
- f"{edu_search_url}/v1/search",
- json=payload,
- )
- response.raise_for_status()
- data = response.json()
-
- # Formatiere Ergebnisse
- results = []
- for r in data.get("results", []):
- results.append({
- "title": r.get("title", ""),
- "url": r.get("url", ""),
- "content": r.get("snippet", ""),
- "score": r.get("scores", {}).get("final", 0),
- })
- return results
- except Exception as e:
- logger.error(f"EduSearch error: {e}")
- return []
-
-
-async def _call_selfhosted_with_search(
- prompt: str,
- system_prompt: Optional[str],
- search_provider: str,
- search_results: list[dict],
- model: str,
- temperature: float,
- top_p: float,
- max_tokens: int,
-) -> LLMResponse:
- """Ruft Self-hosted LLM mit Suchergebnissen auf."""
- import os
- import httpx
-
- start_time = time.time()
- ollama_url = os.getenv("OLLAMA_URL", "http://localhost:11434")
-
- # Baue Kontext aus Suchergebnissen
- context_parts = []
- for i, result in enumerate(search_results, 1):
- context_parts.append(f"[{i}] {result.get('title', 'Untitled')}")
- context_parts.append(f" URL: {result.get('url', '')}")
- context_parts.append(f" {result.get('content', '')[:500]}")
- context_parts.append("")
-
- search_context = "\n".join(context_parts)
-
- # Erweitere System Prompt mit Suchergebnissen
- augmented_system = f"""{system_prompt or ''}
-
-Du hast Zugriff auf folgende Suchergebnisse aus {"Tavily" if search_provider == "tavily" else "EduSearch (deutsche Bildungsquellen)"}:
-
-{search_context}
-
-Nutze diese Quellen um deine Antwort zu unterstuetzen. Zitiere relevante Quellen mit [Nummer]."""
-
- messages = [
- {"role": "system", "content": augmented_system},
- {"role": "user", "content": prompt},
- ]
-
- try:
- async with httpx.AsyncClient(timeout=120.0) as client:
- response = await client.post(
- f"{ollama_url}/api/chat",
- json={
- "model": model,
- "messages": messages,
- "stream": False,
- "options": {
- "temperature": temperature,
- "top_p": top_p,
- "num_predict": max_tokens,
- },
- },
- )
- response.raise_for_status()
- data = response.json()
-
- latency_ms = int((time.time() - start_time) * 1000)
- content = data.get("message", {}).get("content", "")
- tokens = data.get("prompt_eval_count", 0) + data.get("eval_count", 0)
-
- return LLMResponse(
- provider=f"selfhosted_{search_provider}",
- model=model,
- response=content,
- latency_ms=latency_ms,
- tokens_used=tokens,
- search_results=search_results,
- )
- except Exception as e:
- return LLMResponse(
- provider=f"selfhosted_{search_provider}",
- model=model,
- response="",
- latency_ms=int((time.time() - start_time) * 1000),
- error=str(e),
- search_results=search_results,
- )
-
-
@router.post("/run", response_model=ComparisonResponse)
async def run_comparison(
request: ComparisonRequest,
@@ -395,23 +50,19 @@ async def run_comparison(
comparison_id = f"cmp-{uuid.uuid4().hex[:12]}"
tasks = []
- # System Prompt vorbereiten
system_prompt = request.system_prompt
- # OpenAI
if request.enable_openai:
- tasks.append(("openai", _call_openai(request.prompt, system_prompt)))
+ tasks.append(("openai", call_openai(request.prompt, system_prompt)))
- # Claude
if request.enable_claude:
- tasks.append(("claude", _call_claude(request.prompt, system_prompt)))
+ tasks.append(("claude", call_claude(request.prompt, system_prompt)))
- # Self-hosted + Tavily
if request.enable_selfhosted_tavily:
- tavily_results = await _search_tavily(request.prompt, request.search_results_count)
+ tavily_results = await search_tavily(request.prompt, request.search_results_count)
tasks.append((
"selfhosted_tavily",
- _call_selfhosted_with_search(
+ call_selfhosted_with_search(
request.prompt,
system_prompt,
"tavily",
@@ -423,16 +74,15 @@ async def run_comparison(
)
))
- # Self-hosted + EduSearch
if request.enable_selfhosted_edusearch:
- edu_results = await _search_edusearch(
+ edu_results = await search_edusearch(
request.prompt,
request.search_results_count,
request.edu_search_filters,
)
tasks.append((
"selfhosted_edusearch",
- _call_selfhosted_with_search(
+ call_selfhosted_with_search(
request.prompt,
system_prompt,
"edusearch",
@@ -444,7 +94,6 @@ async def run_comparison(
)
))
- # Parallele Ausfuehrung
responses = []
if tasks:
results = await asyncio.gather(*[t[1] for t in tasks], return_exceptions=True)
diff --git a/backend-lehrer/llm_gateway/routes/comparison_models.py b/backend-lehrer/llm_gateway/routes/comparison_models.py
new file mode 100644
index 0000000..3652a57
--- /dev/null
+++ b/backend-lehrer/llm_gateway/routes/comparison_models.py
@@ -0,0 +1,103 @@
+"""
+LLM Comparison - Pydantic Models und In-Memory Storage.
+"""
+
+from datetime import datetime, timezone
+from typing import Optional
+from pydantic import BaseModel, Field
+
+
+class ComparisonRequest(BaseModel):
+ """Request body for an LLM comparison run."""
+ prompt: str = Field(..., description="User prompt (z.B. Lehrer-Frage)")  # required
+ system_prompt: Optional[str] = Field(None, description="Optionaler System Prompt")
+ enable_openai: bool = Field(True, description="OpenAI/ChatGPT aktivieren")  # all four providers default to enabled
+ enable_claude: bool = Field(True, description="Claude aktivieren")
+ enable_selfhosted_tavily: bool = Field(True, description="Self-hosted + Tavily aktivieren")
+ enable_selfhosted_edusearch: bool = Field(True, description="Self-hosted + EduSearch aktivieren")
+
+ # Parameters for the self-hosted models
+ selfhosted_model: str = Field("llama3.2:3b", description="Self-hosted Modell")
+ temperature: float = Field(0.7, ge=0.0, le=2.0, description="Temperature")  # validated to 0.0..2.0
+ top_p: float = Field(0.9, ge=0.0, le=1.0, description="Top-p Sampling")  # validated to 0.0..1.0
+ max_tokens: int = Field(2048, ge=1, le=8192, description="Max Tokens")  # validated to 1..8192
+
+ # Search parameters
+ search_results_count: int = Field(5, ge=1, le=20, description="Anzahl Suchergebnisse")  # validated to 1..20
+ edu_search_filters: Optional[dict] = Field(None, description="Filter fuer EduSearch")
+
+
+class LLMResponse(BaseModel):
+ """Antwort eines einzelnen LLM."""
+ provider: str
+ model: str
+ response: str
+ latency_ms: int
+ tokens_used: Optional[int] = None
+ search_results: Optional[list] = None
+ error: Optional[str] = None
+ timestamp: datetime = Field(default_factory=datetime.utcnow)
+
+
+class ComparisonResponse(BaseModel):
+ """Gesamt-Antwort des Vergleichs."""
+ comparison_id: str
+ prompt: str
+ system_prompt: Optional[str]
+ responses: list[LLMResponse]
+ created_at: datetime = Field(default_factory=datetime.utcnow)
+
+
+class SavedComparison(BaseModel):
+ """A stored comparison kept for QA."""
+ comparison_id: str
+ prompt: str
+ system_prompt: Optional[str]  # required field, but may be None
+ responses: list[LLMResponse]
+ notes: Optional[str] = None
+ rating: Optional[dict] = None  # e.g. {"openai": 4, "claude": 5, ...}
+ created_at: datetime  # required, no default
+ created_by: Optional[str] = None
+
+
+# In-memory storage (in production: database); module-level dicts, so contents are lost on process restart
+_comparisons_store: dict[str, SavedComparison] = {}  # presumably keyed by comparison_id - confirm in the routes
+_system_prompts_store: dict[str, dict] = {  # pre-seeded with three prompts; each key equals the entry's "id"
+ "default": {
+ "id": "default",
+ "name": "Standard Lehrer-Assistent",
+ "prompt": """Du bist ein hilfreicher Assistent fuer Lehrkraefte in Deutschland.
+Deine Aufgaben:
+- Hilfe bei der Unterrichtsplanung
+- Erklaerung von Fachinhalten
+- Erstellung von Arbeitsblaettern und Pruefungen
+- Beratung zu paedagogischen Methoden
+
+Antworte immer auf Deutsch und beachte den deutschen Lehrplankontext.""",
+ "created_at": datetime.now(timezone.utc).isoformat(),  # evaluated once, when the module is imported
+ },
+ "curriculum": {
+ "id": "curriculum",
+ "name": "Lehrplan-Experte",
+ "prompt": """Du bist ein Experte fuer deutsche Lehrplaene und Bildungsstandards.
+Du kennst:
+- Lehrplaene aller 16 Bundeslaender
+- KMK Bildungsstandards
+- Kompetenzorientierung im deutschen Bildungssystem
+
+Beziehe dich immer auf konkrete Lehrplanvorgaben wenn moeglich.""",
+ "created_at": datetime.now(timezone.utc).isoformat(),  # evaluated once, when the module is imported
+ },
+ "worksheet": {
+ "id": "worksheet",
+ "name": "Arbeitsblatt-Generator",
+ "prompt": """Du bist ein spezialisierter Assistent fuer die Erstellung von Arbeitsblaettern.
+Erstelle didaktisch sinnvolle Aufgaben mit:
+- Klaren Arbeitsanweisungen
+- Differenzierungsmoeglichkeiten
+- Loesungshinweisen
+
+Format: Markdown mit klarer Struktur.""",
+ "created_at": datetime.now(timezone.utc).isoformat(),  # evaluated once, when the module is imported
+ },
+}
diff --git a/backend-lehrer/llm_gateway/routes/comparison_providers.py b/backend-lehrer/llm_gateway/routes/comparison_providers.py
new file mode 100644
index 0000000..36237c2
--- /dev/null
+++ b/backend-lehrer/llm_gateway/routes/comparison_providers.py
@@ -0,0 +1,270 @@
+"""
+LLM Comparison - Provider-Aufrufe (OpenAI, Claude, Self-hosted, Search).
+"""
+
+import logging
+import time
+from typing import Optional
+
+from .comparison_models import LLMResponse
+
+logger = logging.getLogger(__name__)
+
+
+async def call_openai(prompt: str, system_prompt: Optional[str]) -> LLMResponse:
+ """Call the OpenAI chat-completions API with gpt-4o-mini; a missing key or a failed request is returned as an LLMResponse with `error` set."""
+ import os
+ import httpx
+
+ start_time = time.time()
+ api_key = os.getenv("OPENAI_API_KEY")
+
+ if not api_key:
+ return LLMResponse(
+ provider="openai",
+ model="gpt-4o-mini",
+ response="",
+ latency_ms=0,  # no request was made
+ error="OPENAI_API_KEY nicht konfiguriert"
+ )
+
+ messages = []
+ if system_prompt:  # the system message is only sent for a non-empty prompt
+ messages.append({"role": "system", "content": system_prompt})
+ messages.append({"role": "user", "content": prompt})
+
+ try:
+ async with httpx.AsyncClient(timeout=60.0) as client:
+ response = await client.post(
+ "https://api.openai.com/v1/chat/completions",
+ headers={
+ "Authorization": f"Bearer {api_key}",
+ "Content-Type": "application/json",
+ },
+ json={
+ "model": "gpt-4o-mini",
+ "messages": messages,
+ "temperature": 0.7,  # fixed here; not taken from the request
+ "max_tokens": 2048,  # fixed here; not taken from the request
+ },
+ )
+ response.raise_for_status()  # HTTP error status -> exception, handled by the except below
+ data = response.json()
+
+ latency_ms = int((time.time() - start_time) * 1000)
+ content = data["choices"][0]["message"]["content"]  # only the first choice is used
+ tokens = data.get("usage", {}).get("total_tokens")  # None when no usage is reported
+
+ return LLMResponse(
+ provider="openai",
+ model="gpt-4o-mini",
+ response=content,
+ latency_ms=latency_ms,
+ tokens_used=tokens,
+ )
+ except Exception as e:
+ return LLMResponse(
+ provider="openai",
+ model="gpt-4o-mini",
+ response="",
+ latency_ms=int((time.time() - start_time) * 1000),  # time elapsed until the failure
+ error=str(e),
+ )
+
+
+async def call_claude(prompt: str, system_prompt: Optional[str]) -> LLMResponse:
+ """Call Anthropic Claude via the anthropic SDK; a missing key or a failed call is returned as an LLMResponse with `error` set."""
+ import os
+
+ start_time = time.time()
+ api_key = os.getenv("ANTHROPIC_API_KEY")
+
+ if not api_key:
+ return LLMResponse(
+ provider="claude",
+ model="claude-3-5-sonnet-20241022",
+ response="",
+ latency_ms=0,  # no request was made
+ error="ANTHROPIC_API_KEY nicht konfiguriert"
+ )
+
+ try:
+ import anthropic  # imported inside the try, so a missing package ends up in the error field
+ client = anthropic.AsyncAnthropic(api_key=api_key)
+
+ response = await client.messages.create(
+ model="claude-3-5-sonnet-20241022",
+ max_tokens=2048,  # fixed here; not taken from the request
+ system=system_prompt or "",  # None is sent as an empty system prompt
+ messages=[{"role": "user", "content": prompt}],
+ )
+
+ latency_ms = int((time.time() - start_time) * 1000)
+ content = response.content[0].text if response.content else ""  # only the first content block is used
+ tokens = response.usage.input_tokens + response.usage.output_tokens  # input + output tokens
+
+ return LLMResponse(
+ provider="claude",
+ model="claude-3-5-sonnet-20241022",
+ response=content,
+ latency_ms=latency_ms,
+ tokens_used=tokens,
+ )
+ except Exception as e:
+ return LLMResponse(
+ provider="claude",
+ model="claude-3-5-sonnet-20241022",
+ response="",
+ latency_ms=int((time.time() - start_time) * 1000),  # time elapsed until the failure
+ error=str(e),
+ )
+
+
+async def search_tavily(query: str, count: int = 5) -> list[dict]:
+ """Search via the Tavily API; returns [] when TAVILY_API_KEY is unset or the request fails."""
+ import os
+ import httpx
+
+ api_key = os.getenv("TAVILY_API_KEY")
+ if not api_key:
+ return []  # silently skipped, nothing is logged
+
+ try:
+ async with httpx.AsyncClient(timeout=30.0) as client:
+ response = await client.post(
+ "https://api.tavily.com/search",
+ json={
+ "api_key": api_key,  # the key is sent in the JSON body
+ "query": query,
+ "max_results": count,
+ "include_domains": [  # hard-coded domain list sent with every query
+ "kmk.org", "bildungsserver.de", "bpb.de",
+ "bayern.de", "nrw.de", "berlin.de",
+ ],
+ },
+ )
+ response.raise_for_status()
+ data = response.json()
+ return data.get("results", [])  # result dicts are passed through unchanged
+ except Exception as e:
+ logger.error(f"Tavily search error: {e}")
+ return []  # errors are logged and reported as "no results"
+
+
+async def search_edusearch(query: str, count: int = 5, filters: Optional[dict] = None) -> list[dict]:
+ """Search via the EduSearch service and map its hits to title/url/content/score dicts; returns [] on any error."""
+ import os
+ import httpx
+
+ edu_search_url = os.getenv("EDU_SEARCH_URL", "http://edu-search-service:8084")
+
+ try:
+ async with httpx.AsyncClient(timeout=30.0) as client:
+ payload = {
+ "q": query,
+ "limit": count,
+ "mode": "keyword",  # fixed here; not configurable by the caller
+ }
+ if filters:  # None or empty filters are left out of the payload
+ payload["filters"] = filters
+
+ response = await client.post(
+ f"{edu_search_url}/v1/search",
+ json=payload,
+ )
+ response.raise_for_status()
+ data = response.json()
+
+ results = []
+ for r in data.get("results", []):
+ results.append({
+ "title": r.get("title", ""),
+ "url": r.get("url", ""),
+ "content": r.get("snippet", ""),  # the hit's snippet is exposed under the "content" key
+ "score": r.get("scores", {}).get("final", 0),
+ })
+ return results
+ except Exception as e:
+ logger.error(f"EduSearch error: {e}")
+ return []  # errors are logged and reported as "no results"
+
+
+async def call_selfhosted_with_search(
+ prompt: str,
+ system_prompt: Optional[str],
+ search_provider: str,
+ search_results: list[dict],
+ model: str,
+ temperature: float,
+ top_p: float,
+ max_tokens: int,
+) -> LLMResponse:
+ """Ruft Self-hosted LLM mit Suchergebnissen auf."""
+ import os
+ import httpx
+
+ start_time = time.time()
+ ollama_url = os.getenv("OLLAMA_URL", "http://localhost:11434")
+
+ # Baue Kontext aus Suchergebnissen
+ context_parts = []
+ for i, result in enumerate(search_results, 1):
+ context_parts.append(f"[{i}] {result.get('title', 'Untitled')}")
+ context_parts.append(f" URL: {result.get('url', '')}")
+ context_parts.append(f" {result.get('content', '')[:500]}")
+ context_parts.append("")
+
+ search_context = "\n".join(context_parts)
+
+ augmented_system = f"""{system_prompt or ''}
+
+Du hast Zugriff auf folgende Suchergebnisse aus {"Tavily" if search_provider == "tavily" else "EduSearch (deutsche Bildungsquellen)"}:
+
+{search_context}
+
+Nutze diese Quellen um deine Antwort zu unterstuetzen. Zitiere relevante Quellen mit [Nummer]."""
+
+ messages = [
+ {"role": "system", "content": augmented_system},
+ {"role": "user", "content": prompt},
+ ]
+
+ try:
+ async with httpx.AsyncClient(timeout=120.0) as client:
+ response = await client.post(
+ f"{ollama_url}/api/chat",
+ json={
+ "model": model,
+ "messages": messages,
+ "stream": False,
+ "options": {
+ "temperature": temperature,
+ "top_p": top_p,
+ "num_predict": max_tokens,
+ },
+ },
+ )
+ response.raise_for_status()
+ data = response.json()
+
+ latency_ms = int((time.time() - start_time) * 1000)
+ content = data.get("message", {}).get("content", "")
+ tokens = data.get("prompt_eval_count", 0) + data.get("eval_count", 0)
+
+ return LLMResponse(
+ provider=f"selfhosted_{search_provider}",
+ model=model,
+ response=content,
+ latency_ms=latency_ms,
+ tokens_used=tokens,
+ search_results=search_results,
+ )
+ except Exception as e:
+ return LLMResponse(
+ provider=f"selfhosted_{search_provider}",
+ model=model,
+ response="",
+ latency_ms=int((time.time() - start_time) * 1000),
+ error=str(e),
+ search_results=search_results,
+ )
diff --git a/backend-lehrer/llm_gateway/services/inference.py b/backend-lehrer/llm_gateway/services/inference.py
index 756afc5..e39f68e 100644
--- a/backend-lehrer/llm_gateway/services/inference.py
+++ b/backend-lehrer/llm_gateway/services/inference.py
@@ -8,10 +8,8 @@ Unterstützt:
"""
import httpx
-import json
import logging
from typing import AsyncIterator, Optional
-from dataclasses import dataclass
from ..config import get_config, LLMBackendConfig
from ..models.chat import (
@@ -20,26 +18,23 @@ from ..models.chat import (
ChatCompletionChunk,
ChatMessage,
ChatChoice,
- StreamChoice,
- ChatChoiceDelta,
Usage,
ModelInfo,
ModelListResponse,
)
+from .inference_backends import (
+ InferenceResult,
+ call_ollama,
+ stream_ollama,
+ call_openai_compatible,
+ stream_openai_compatible,
+ call_anthropic,
+ stream_anthropic,
+)
logger = logging.getLogger(__name__)
-@dataclass
-class InferenceResult:
- """Ergebnis einer Inference-Anfrage."""
- content: str
- model: str
- backend: str
- usage: Optional[Usage] = None
- finish_reason: str = "stop"
-
-
class InferenceService:
"""Service für LLM Inference über verschiedene Backends."""
@@ -68,26 +63,17 @@ class InferenceService:
return None
def _map_model_to_backend(self, model: str) -> tuple[str, LLMBackendConfig]:
- """
- Mapped ein Modell-Name zum entsprechenden Backend.
-
- Beispiele:
- - "breakpilot-teacher-8b" → Ollama/vLLM mit llama3.1:8b
- - "claude-3-5-sonnet" → Anthropic
- """
+ """Mapped ein Modell-Name zum entsprechenden Backend."""
model_lower = model.lower()
- # Explizite Claude-Modelle → Anthropic
if "claude" in model_lower:
if self.config.anthropic and self.config.anthropic.enabled:
return self.config.anthropic.default_model, self.config.anthropic
raise ValueError("Anthropic backend not configured")
- # BreakPilot Modelle → primäres Backend
if "breakpilot" in model_lower or "teacher" in model_lower:
backend = self._get_available_backend()
if backend:
- # Map zu tatsächlichem Modell-Namen
if "70b" in model_lower:
actual_model = "llama3.1:70b" if backend.name == "ollama" else "meta-llama/Meta-Llama-3.1-70B-Instruct"
else:
@@ -95,7 +81,6 @@ class InferenceService:
return actual_model, backend
raise ValueError("No LLM backend available")
- # Mistral Modelle
if "mistral" in model_lower:
backend = self._get_available_backend()
if backend:
@@ -103,409 +88,64 @@ class InferenceService:
return actual_model, backend
raise ValueError("No LLM backend available")
- # Fallback: verwende Modell-Name direkt
backend = self._get_available_backend()
if backend:
return model, backend
raise ValueError("No LLM backend available")
- async def _call_ollama(
- self,
- backend: LLMBackendConfig,
- model: str,
- request: ChatCompletionRequest,
- ) -> InferenceResult:
- """Ruft Ollama API auf (nicht OpenAI-kompatibel)."""
- client = await self.get_client()
-
- # Ollama verwendet eigenes Format
- messages = [{"role": m.role, "content": m.content or ""} for m in request.messages]
-
- payload = {
- "model": model,
- "messages": messages,
- "stream": False,
- "options": {
- "temperature": request.temperature,
- "top_p": request.top_p,
- },
- }
-
- if request.max_tokens:
- payload["options"]["num_predict"] = request.max_tokens
-
- response = await client.post(
- f"{backend.base_url}/api/chat",
- json=payload,
- timeout=backend.timeout,
- )
- response.raise_for_status()
- data = response.json()
-
- return InferenceResult(
- content=data.get("message", {}).get("content", ""),
- model=model,
- backend="ollama",
- usage=Usage(
- prompt_tokens=data.get("prompt_eval_count", 0),
- completion_tokens=data.get("eval_count", 0),
- total_tokens=data.get("prompt_eval_count", 0) + data.get("eval_count", 0),
- ),
- finish_reason="stop" if data.get("done") else "length",
- )
-
- async def _stream_ollama(
- self,
- backend: LLMBackendConfig,
- model: str,
- request: ChatCompletionRequest,
- response_id: str,
- ) -> AsyncIterator[ChatCompletionChunk]:
- """Streamt von Ollama."""
- client = await self.get_client()
-
- messages = [{"role": m.role, "content": m.content or ""} for m in request.messages]
-
- payload = {
- "model": model,
- "messages": messages,
- "stream": True,
- "options": {
- "temperature": request.temperature,
- "top_p": request.top_p,
- },
- }
-
- if request.max_tokens:
- payload["options"]["num_predict"] = request.max_tokens
-
- async with client.stream(
- "POST",
- f"{backend.base_url}/api/chat",
- json=payload,
- timeout=backend.timeout,
- ) as response:
- response.raise_for_status()
- async for line in response.aiter_lines():
- if not line:
- continue
- try:
- data = json.loads(line)
- content = data.get("message", {}).get("content", "")
- done = data.get("done", False)
-
- yield ChatCompletionChunk(
- id=response_id,
- model=model,
- choices=[
- StreamChoice(
- index=0,
- delta=ChatChoiceDelta(content=content),
- finish_reason="stop" if done else None,
- )
- ],
- )
- except json.JSONDecodeError:
- continue
-
- async def _call_openai_compatible(
- self,
- backend: LLMBackendConfig,
- model: str,
- request: ChatCompletionRequest,
- ) -> InferenceResult:
- """Ruft OpenAI-kompatible API auf (vLLM, etc.)."""
- client = await self.get_client()
-
- headers = {"Content-Type": "application/json"}
- if backend.api_key:
- headers["Authorization"] = f"Bearer {backend.api_key}"
-
- payload = {
- "model": model,
- "messages": [m.model_dump(exclude_none=True) for m in request.messages],
- "stream": False,
- "temperature": request.temperature,
- "top_p": request.top_p,
- }
-
- if request.max_tokens:
- payload["max_tokens"] = request.max_tokens
- if request.stop:
- payload["stop"] = request.stop
-
- response = await client.post(
- f"{backend.base_url}/v1/chat/completions",
- json=payload,
- headers=headers,
- timeout=backend.timeout,
- )
- response.raise_for_status()
- data = response.json()
-
- choice = data.get("choices", [{}])[0]
- usage_data = data.get("usage", {})
-
- return InferenceResult(
- content=choice.get("message", {}).get("content", ""),
- model=model,
- backend=backend.name,
- usage=Usage(
- prompt_tokens=usage_data.get("prompt_tokens", 0),
- completion_tokens=usage_data.get("completion_tokens", 0),
- total_tokens=usage_data.get("total_tokens", 0),
- ),
- finish_reason=choice.get("finish_reason", "stop"),
- )
-
- async def _stream_openai_compatible(
- self,
- backend: LLMBackendConfig,
- model: str,
- request: ChatCompletionRequest,
- response_id: str,
- ) -> AsyncIterator[ChatCompletionChunk]:
- """Streamt von OpenAI-kompatibler API."""
- client = await self.get_client()
-
- headers = {"Content-Type": "application/json"}
- if backend.api_key:
- headers["Authorization"] = f"Bearer {backend.api_key}"
-
- payload = {
- "model": model,
- "messages": [m.model_dump(exclude_none=True) for m in request.messages],
- "stream": True,
- "temperature": request.temperature,
- "top_p": request.top_p,
- }
-
- if request.max_tokens:
- payload["max_tokens"] = request.max_tokens
-
- async with client.stream(
- "POST",
- f"{backend.base_url}/v1/chat/completions",
- json=payload,
- headers=headers,
- timeout=backend.timeout,
- ) as response:
- response.raise_for_status()
- async for line in response.aiter_lines():
- if not line or not line.startswith("data: "):
- continue
- data_str = line[6:] # Remove "data: " prefix
- if data_str == "[DONE]":
- break
- try:
- data = json.loads(data_str)
- choice = data.get("choices", [{}])[0]
- delta = choice.get("delta", {})
-
- yield ChatCompletionChunk(
- id=response_id,
- model=model,
- choices=[
- StreamChoice(
- index=0,
- delta=ChatChoiceDelta(
- role=delta.get("role"),
- content=delta.get("content"),
- ),
- finish_reason=choice.get("finish_reason"),
- )
- ],
- )
- except json.JSONDecodeError:
- continue
-
- async def _call_anthropic(
- self,
- backend: LLMBackendConfig,
- model: str,
- request: ChatCompletionRequest,
- ) -> InferenceResult:
- """Ruft Anthropic Claude API auf."""
- # Anthropic SDK verwenden (bereits installiert)
- try:
- import anthropic
- except ImportError:
- raise ImportError("anthropic package required for Claude API")
-
- client = anthropic.AsyncAnthropic(api_key=backend.api_key)
-
- # System message extrahieren
- system_content = ""
- messages = []
- for msg in request.messages:
- if msg.role == "system":
- system_content += (msg.content or "") + "\n"
- else:
- messages.append({"role": msg.role, "content": msg.content or ""})
-
- response = await client.messages.create(
- model=model,
- max_tokens=request.max_tokens or 4096,
- system=system_content.strip() if system_content else None,
- messages=messages,
- temperature=request.temperature,
- top_p=request.top_p,
- )
-
- content = ""
- if response.content:
- content = response.content[0].text if response.content[0].type == "text" else ""
-
- return InferenceResult(
- content=content,
- model=model,
- backend="anthropic",
- usage=Usage(
- prompt_tokens=response.usage.input_tokens,
- completion_tokens=response.usage.output_tokens,
- total_tokens=response.usage.input_tokens + response.usage.output_tokens,
- ),
- finish_reason="stop" if response.stop_reason == "end_turn" else response.stop_reason or "stop",
- )
-
- async def _stream_anthropic(
- self,
- backend: LLMBackendConfig,
- model: str,
- request: ChatCompletionRequest,
- response_id: str,
- ) -> AsyncIterator[ChatCompletionChunk]:
- """Streamt von Anthropic Claude API."""
- try:
- import anthropic
- except ImportError:
- raise ImportError("anthropic package required for Claude API")
-
- client = anthropic.AsyncAnthropic(api_key=backend.api_key)
-
- # System message extrahieren
- system_content = ""
- messages = []
- for msg in request.messages:
- if msg.role == "system":
- system_content += (msg.content or "") + "\n"
- else:
- messages.append({"role": msg.role, "content": msg.content or ""})
-
- async with client.messages.stream(
- model=model,
- max_tokens=request.max_tokens or 4096,
- system=system_content.strip() if system_content else None,
- messages=messages,
- temperature=request.temperature,
- top_p=request.top_p,
- ) as stream:
- async for text in stream.text_stream:
- yield ChatCompletionChunk(
- id=response_id,
- model=model,
- choices=[
- StreamChoice(
- index=0,
- delta=ChatChoiceDelta(content=text),
- finish_reason=None,
- )
- ],
- )
-
- # Final chunk with finish_reason
- yield ChatCompletionChunk(
- id=response_id,
- model=model,
- choices=[
- StreamChoice(
- index=0,
- delta=ChatChoiceDelta(),
- finish_reason="stop",
- )
- ],
- )
-
async def complete(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
- """
- Führt Chat Completion durch (non-streaming).
- """
+ """Führt Chat Completion durch (non-streaming)."""
actual_model, backend = self._map_model_to_backend(request.model)
+ logger.info(f"Inference request: model={request.model} -> {actual_model} via {backend.name}")
- logger.info(f"Inference request: model={request.model} → {actual_model} via {backend.name}")
+ client = await self.get_client()
if backend.name == "ollama":
- result = await self._call_ollama(backend, actual_model, request)
+ result = await call_ollama(client, backend, actual_model, request)
elif backend.name == "anthropic":
- result = await self._call_anthropic(backend, actual_model, request)
+ result = await call_anthropic(backend, actual_model, request)
else:
- result = await self._call_openai_compatible(backend, actual_model, request)
+ result = await call_openai_compatible(client, backend, actual_model, request)
return ChatCompletionResponse(
- model=request.model, # Original requested model name
- choices=[
- ChatChoice(
- index=0,
- message=ChatMessage(role="assistant", content=result.content),
- finish_reason=result.finish_reason,
- )
- ],
+ model=request.model,
+ choices=[ChatChoice(index=0, message=ChatMessage(role="assistant", content=result.content), finish_reason=result.finish_reason)],
usage=result.usage,
)
async def stream(self, request: ChatCompletionRequest) -> AsyncIterator[ChatCompletionChunk]:
- """
- Führt Chat Completion mit Streaming durch.
- """
+ """Führt Chat Completion mit Streaming durch."""
import uuid
response_id = f"chatcmpl-{uuid.uuid4().hex[:12]}"
actual_model, backend = self._map_model_to_backend(request.model)
+ logger.info(f"Streaming request: model={request.model} -> {actual_model} via {backend.name}")
- logger.info(f"Streaming request: model={request.model} → {actual_model} via {backend.name}")
+ client = await self.get_client()
if backend.name == "ollama":
- async for chunk in self._stream_ollama(backend, actual_model, request, response_id):
+ async for chunk in stream_ollama(client, backend, actual_model, request, response_id):
yield chunk
elif backend.name == "anthropic":
- async for chunk in self._stream_anthropic(backend, actual_model, request, response_id):
+ async for chunk in stream_anthropic(backend, actual_model, request, response_id):
yield chunk
else:
- async for chunk in self._stream_openai_compatible(backend, actual_model, request, response_id):
+ async for chunk in stream_openai_compatible(client, backend, actual_model, request, response_id):
yield chunk
async def list_models(self) -> ModelListResponse:
"""Listet verfügbare Modelle."""
models = []
- # BreakPilot Modelle (mapped zu verfügbaren Backends)
backend = self._get_available_backend()
if backend:
models.extend([
- ModelInfo(
- id="breakpilot-teacher-8b",
- owned_by="breakpilot",
- description="Llama 3.1 8B optimiert für Schulkontext",
- context_length=8192,
- ),
- ModelInfo(
- id="breakpilot-teacher-70b",
- owned_by="breakpilot",
- description="Llama 3.1 70B für komplexe Aufgaben",
- context_length=8192,
- ),
+ ModelInfo(id="breakpilot-teacher-8b", owned_by="breakpilot", description="Llama 3.1 8B optimiert für Schulkontext", context_length=8192),
+ ModelInfo(id="breakpilot-teacher-70b", owned_by="breakpilot", description="Llama 3.1 70B für komplexe Aufgaben", context_length=8192),
])
- # Claude Modelle (wenn Anthropic konfiguriert)
if self.config.anthropic and self.config.anthropic.enabled:
- models.append(
- ModelInfo(
- id="claude-3-5-sonnet",
- owned_by="anthropic",
- description="Claude 3.5 Sonnet - Fallback für höchste Qualität",
- context_length=200000,
- )
- )
+ models.append(ModelInfo(id="claude-3-5-sonnet", owned_by="anthropic", description="Claude 3.5 Sonnet - Fallback für höchste Qualität", context_length=200000))
return ModelListResponse(data=models)
diff --git a/backend-lehrer/llm_gateway/services/inference_backends.py b/backend-lehrer/llm_gateway/services/inference_backends.py
new file mode 100644
index 0000000..90de01c
--- /dev/null
+++ b/backend-lehrer/llm_gateway/services/inference_backends.py
@@ -0,0 +1,230 @@
+"""
+Inference Backends - Kommunikation mit einzelnen LLM-Providern.
+
+Unterstützt Ollama, OpenAI-kompatible APIs und Anthropic Claude.
+"""
+
+import json
+import logging
+from typing import AsyncIterator, Optional
+from dataclasses import dataclass
+
+from ..config import LLMBackendConfig
+from ..models.chat import (
+ ChatCompletionRequest,
+ ChatCompletionChunk,
+ ChatMessage,
+ StreamChoice,
+ ChatChoiceDelta,
+ Usage,
+)
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class InferenceResult:
+ """Backend-independent result of one non-streaming inference request."""
+ content: str  # generated assistant text
+ model: str  # model name that was sent to the backend
+ backend: str  # backend label, e.g. "ollama" or "anthropic"
+ usage: Optional[Usage] = None  # token counts reported by the backend, if any
+ finish_reason: str = "stop"
+
+
+async def call_ollama(client, backend: LLMBackendConfig, model: str, request: ChatCompletionRequest) -> InferenceResult:
+    """Send a non-streaming chat request to Ollama's native /api/chat endpoint and wrap the reply."""
+    sampling = {"temperature": request.temperature, "top_p": request.top_p}
+    if request.max_tokens:
+        sampling["num_predict"] = request.max_tokens  # max_tokens is forwarded as the `num_predict` option
+
+    body = {
+        "model": model,
+        "messages": [{"role": msg.role, "content": msg.content or ""} for msg in request.messages],
+        "stream": False,
+        "options": sampling,
+    }
+    reply = await client.post(f"{backend.base_url}/api/chat", json=body, timeout=backend.timeout)
+    reply.raise_for_status()
+    parsed = reply.json()
+
+    prompt_count = parsed.get("prompt_eval_count", 0)
+    completion_count = parsed.get("eval_count", 0)
+    text = parsed.get("message", {}).get("content", "")
+    reason = "stop" if parsed.get("done") else "length"
+    return InferenceResult(
+        content=text,
+        model=model,
+        backend="ollama",
+        usage=Usage(prompt_tokens=prompt_count, completion_tokens=completion_count, total_tokens=prompt_count + completion_count),
+        finish_reason=reason,
+    )
+
+
+async def stream_ollama(client, backend, model, request, response_id) -> AsyncIterator[ChatCompletionChunk]:
+    """Stream a chat completion from Ollama's /api/chat, yielding one chunk per JSON line received."""
+    sampling = {"temperature": request.temperature, "top_p": request.top_p}
+    if request.max_tokens:
+        sampling["num_predict"] = request.max_tokens
+    body = {
+        "model": model,
+        "messages": [{"role": msg.role, "content": msg.content or ""} for msg in request.messages],
+        "stream": True,
+        "options": sampling,
+    }
+    url = f"{backend.base_url}/api/chat"
+
+    async with client.stream("POST", url, json=body, timeout=backend.timeout) as reply:
+        reply.raise_for_status()
+        async for raw_line in reply.aiter_lines():
+            if raw_line:
+                try:
+                    event = json.loads(raw_line)
+                    piece = event.get("message", {}).get("content", "")
+                    reason = "stop" if event.get("done", False) else None
+                    # The chunk whose `done` flag is set carries finish_reason "stop"; all others carry None.
+                    delta = ChatChoiceDelta(content=piece)
+                    yield ChatCompletionChunk(id=response_id, model=model, choices=[StreamChoice(index=0, delta=delta, finish_reason=reason)])
+                except json.JSONDecodeError:
+                    continue
+
+
+async def call_openai_compatible(client, backend, model, request) -> InferenceResult:
+    """Call an OpenAI-compatible /v1/chat/completions endpoint (non-streaming) and wrap the first choice."""
+    headers = {"Content-Type": "application/json"}
+    if backend.api_key:
+        headers["Authorization"] = f"Bearer {backend.api_key}"
+
+    payload = {
+        "model": model,
+        "messages": [m.model_dump(exclude_none=True) for m in request.messages],
+        "stream": False, "temperature": request.temperature, "top_p": request.top_p,
+    }
+    if request.max_tokens:
+        payload["max_tokens"] = request.max_tokens
+    if request.stop:
+        payload["stop"] = request.stop
+
+    response = await client.post(f"{backend.base_url}/v1/chat/completions", json=payload, headers=headers, timeout=backend.timeout)
+    response.raise_for_status()
+    data = response.json()
+
+    choice = (data.get("choices") or [{}])[0]  # `or` also covers an empty list / null, which would raise IndexError/TypeError
+    usage_data = data.get("usage") or {}  # tolerate "usage": null
+
+    return InferenceResult(
+        content=(choice.get("message") or {}).get("content") or "",  # content may be null; the field is a str
+        model=model, backend=backend.name,
+        usage=Usage(
+            prompt_tokens=usage_data.get("prompt_tokens") or 0,
+            completion_tokens=usage_data.get("completion_tokens") or 0,
+            total_tokens=usage_data.get("total_tokens") or 0,
+        ),
+        finish_reason=choice.get("finish_reason") or "stop",  # a present-but-null value falls back to "stop"
+    )
+
+
+async def stream_openai_compatible(client, backend, model, request, response_id) -> AsyncIterator[ChatCompletionChunk]:
+    """Stream chunks from an OpenAI-compatible /v1/chat/completions endpoint by parsing its `data: ` lines until `[DONE]`."""
+    headers = {"Content-Type": "application/json"}
+    if backend.api_key:
+        headers["Authorization"] = f"Bearer {backend.api_key}"
+
+    payload = {
+        "model": model, "stream": True, "temperature": request.temperature, "top_p": request.top_p,
+        "messages": [m.model_dump(exclude_none=True) for m in request.messages],
+    }
+    if request.max_tokens:
+        payload["max_tokens"] = request.max_tokens
+    if request.stop:  # forwarded as in call_openai_compatible so streaming honors stop sequences too
+        payload["stop"] = request.stop
+
+    async with client.stream("POST", f"{backend.base_url}/v1/chat/completions", json=payload, headers=headers, timeout=backend.timeout) as response:
+        response.raise_for_status()
+        async for line in response.aiter_lines():
+            if not line or not line.startswith("data: "):
+                continue
+            data_str = line[6:]
+            if data_str == "[DONE]":
+                break
+            try:
+                data = json.loads(data_str)
+                choice = (data.get("choices") or [{}])[0]  # empty/missing `choices` yields an empty delta instead of IndexError
+                delta = choice.get("delta") or {}
+                yield ChatCompletionChunk(id=response_id, model=model, choices=[
+                    StreamChoice(index=0, delta=ChatChoiceDelta(role=delta.get("role"), content=delta.get("content")), finish_reason=choice.get("finish_reason")),
+                ])
+            except json.JSONDecodeError:
+                continue
+
+
+async def call_anthropic(backend, model, request) -> InferenceResult:
+    """Call the Anthropic Messages API (non-streaming); system-role messages are merged into the `system` parameter."""
+    try:
+        import anthropic
+    except ImportError:
+        raise ImportError("anthropic package required for Claude API")
+
+    client = anthropic.AsyncAnthropic(api_key=backend.api_key)
+
+    system_content = ""
+    messages = []
+    for msg in request.messages:
+        if msg.role == "system":
+            system_content += (msg.content or "") + "\n"
+        else:
+            messages.append({"role": msg.role, "content": msg.content or ""})
+
+    # Pass `system` only when there is text for it, rather than sending an explicit None to the SDK.
+    system_kwargs = {"system": system_content.strip()} if system_content.strip() else {}
+    response = await client.messages.create(
+        model=model, max_tokens=request.max_tokens or 4096, messages=messages,
+        temperature=request.temperature, top_p=request.top_p, **system_kwargs,
+    )
+
+    # Concatenate every text block instead of reading only the first content block.
+    content = "".join(block.text for block in (response.content or []) if block.type == "text")
+
+    return InferenceResult(
+        content=content, model=model, backend="anthropic",
+        usage=Usage(
+            prompt_tokens=response.usage.input_tokens,
+            completion_tokens=response.usage.output_tokens,
+            total_tokens=response.usage.input_tokens + response.usage.output_tokens,
+        ),
+        finish_reason="stop" if response.stop_reason == "end_turn" else response.stop_reason or "stop",
+    )
+
+
+async def stream_anthropic(backend, model, request, response_id) -> AsyncIterator[ChatCompletionChunk]:
+    """Stream text deltas from the Anthropic Messages API, then emit a final empty chunk with finish_reason "stop"."""
+    try:
+        import anthropic
+    except ImportError:
+        raise ImportError("anthropic package required for Claude API")
+
+    client = anthropic.AsyncAnthropic(api_key=backend.api_key)
+
+    system_content = ""
+    messages = []
+    for msg in request.messages:
+        if msg.role == "system":
+            system_content += (msg.content or "") + "\n"
+        else:
+            messages.append({"role": msg.role, "content": msg.content or ""})
+    # Pass `system` only when there is text for it, rather than sending an explicit None to the SDK.
+    system_kwargs = {"system": system_content.strip()} if system_content.strip() else {}
+    async with client.messages.stream(
+        model=model, max_tokens=request.max_tokens or 4096, messages=messages,
+        temperature=request.temperature, top_p=request.top_p, **system_kwargs,
+    ) as stream:
+        async for text in stream.text_stream:
+            yield ChatCompletionChunk(
+                id=response_id, model=model,
+                choices=[StreamChoice(index=0, delta=ChatChoiceDelta(content=text), finish_reason=None)],
+            )
+
+    yield ChatCompletionChunk(
+        id=response_id, model=model,
+        choices=[StreamChoice(index=0, delta=ChatChoiceDelta(), finish_reason="stop")],
+    )
diff --git a/backend-lehrer/services/file_processor.py b/backend-lehrer/services/file_processor.py
index 438c220..17ad3dd 100644
--- a/backend-lehrer/services/file_processor.py
+++ b/backend-lehrer/services/file_processor.py
@@ -15,60 +15,24 @@ Verwendet:
"""
import logging
-import os
import io
-import base64
from pathlib import Path
-from typing import Optional, List, Dict, Any, Tuple, Union
-from dataclasses import dataclass
-from enum import Enum
+from typing import Optional, List, Dict, Any
import cv2
import numpy as np
from PIL import Image
+from .file_processor_models import (
+ FileType,
+ ProcessingMode,
+ ProcessedRegion,
+ ProcessingResult,
+)
+
logger = logging.getLogger(__name__)
-class FileType(str, Enum):
- """Unterstützte Dateitypen."""
- PDF = "pdf"
- IMAGE = "image"
- DOCX = "docx"
- DOC = "doc"
- TXT = "txt"
- UNKNOWN = "unknown"
-
-
-class ProcessingMode(str, Enum):
- """Verarbeitungsmodi."""
- OCR_HANDWRITING = "ocr_handwriting" # Handschrifterkennung
- OCR_PRINTED = "ocr_printed" # Gedruckter Text
- TEXT_EXTRACT = "text_extract" # Textextraktion (PDF/DOCX)
- MIXED = "mixed" # Kombiniert OCR + Textextraktion
-
-
-@dataclass
-class ProcessedRegion:
- """Ein erkannter Textbereich."""
- text: str
- confidence: float
- bbox: Tuple[int, int, int, int] # x1, y1, x2, y2
- page: int = 1
-
-
-@dataclass
-class ProcessingResult:
- """Ergebnis der Dokumentenverarbeitung."""
- text: str
- confidence: float
- regions: List[ProcessedRegion]
- page_count: int
- file_type: FileType
- processing_mode: ProcessingMode
- metadata: Dict[str, Any]
-
-
class FileProcessor:
"""
Zentrale Dokumentenverarbeitung für BreakPilot.
@@ -81,17 +45,9 @@ class FileProcessor:
"""
def __init__(self, ocr_lang: str = "de", use_gpu: bool = False):
- """
- Initialisiert den File Processor.
-
- Args:
- ocr_lang: Sprache für OCR (default: "de" für Deutsch)
- use_gpu: GPU für OCR nutzen (beschleunigt Verarbeitung)
- """
self.ocr_lang = ocr_lang
self.use_gpu = use_gpu
self._ocr_engine = None
-
logger.info(f"FileProcessor initialized (lang={ocr_lang}, gpu={use_gpu})")
@property
@@ -107,7 +63,7 @@ class FileProcessor:
from paddleocr import PaddleOCR
return PaddleOCR(
use_angle_cls=True,
- lang='german', # Deutsch
+ lang='german',
use_gpu=self.use_gpu,
show_log=False
)
@@ -116,16 +72,7 @@ class FileProcessor:
return None
def detect_file_type(self, file_path: str = None, file_bytes: bytes = None) -> FileType:
- """
- Erkennt den Dateityp.
-
- Args:
- file_path: Pfad zur Datei
- file_bytes: Dateiinhalt als Bytes
-
- Returns:
- FileType enum
- """
+ """Erkennt den Dateityp."""
if file_path:
ext = Path(file_path).suffix.lower()
if ext == ".pdf":
@@ -140,14 +87,13 @@ class FileProcessor:
return FileType.TXT
if file_bytes:
- # Magic number detection
if file_bytes[:4] == b'%PDF':
return FileType.PDF
elif file_bytes[:8] == b'\x89PNG\r\n\x1a\n':
return FileType.IMAGE
- elif file_bytes[:2] in [b'\xff\xd8', b'BM']: # JPEG, BMP
+ elif file_bytes[:2] in [b'\xff\xd8', b'BM']:
return FileType.IMAGE
- elif file_bytes[:4] == b'PK\x03\x04': # ZIP (DOCX)
+ elif file_bytes[:4] == b'PK\x03\x04':
return FileType.DOCX
return FileType.UNKNOWN
@@ -158,17 +104,7 @@ class FileProcessor:
file_bytes: bytes = None,
mode: ProcessingMode = ProcessingMode.MIXED
) -> ProcessingResult:
- """
- Verarbeitet ein Dokument.
-
- Args:
- file_path: Pfad zur Datei
- file_bytes: Dateiinhalt als Bytes
- mode: Verarbeitungsmodus
-
- Returns:
- ProcessingResult mit extrahiertem Text und Metadaten
- """
+ """Verarbeitet ein Dokument."""
if not file_path and not file_bytes:
raise ValueError("Entweder file_path oder file_bytes muss angegeben werden")
@@ -186,18 +122,12 @@ class FileProcessor:
else:
raise ValueError(f"Nicht unterstützter Dateityp: {file_type}")
- def _process_pdf(
- self,
- file_path: str = None,
- file_bytes: bytes = None,
- mode: ProcessingMode = ProcessingMode.MIXED
- ) -> ProcessingResult:
+ def _process_pdf(self, file_path=None, file_bytes=None, mode=ProcessingMode.MIXED):
"""Verarbeitet PDF-Dateien."""
try:
- import fitz # PyMuPDF
+ import fitz
except ImportError:
logger.warning("PyMuPDF nicht installiert - versuche Fallback")
- # Fallback: PDF als Bild behandeln
return self._process_image(file_path, file_bytes, mode)
if file_bytes:
@@ -205,35 +135,27 @@ class FileProcessor:
else:
doc = fitz.open(file_path)
- all_text = []
- all_regions = []
- total_confidence = 0.0
- region_count = 0
+ all_text, all_regions = [], []
+ total_confidence, region_count = 0.0, 0
for page_num, page in enumerate(doc, start=1):
- # Erst versuchen Text direkt zu extrahieren
page_text = page.get_text()
if page_text.strip() and mode != ProcessingMode.OCR_HANDWRITING:
- # PDF enthält Text (nicht nur Bilder)
all_text.append(page_text)
all_regions.append(ProcessedRegion(
- text=page_text,
- confidence=1.0,
+ text=page_text, confidence=1.0,
bbox=(0, 0, int(page.rect.width), int(page.rect.height)),
page=page_num
))
total_confidence += 1.0
region_count += 1
else:
- # Seite als Bild rendern und OCR anwenden
- pix = page.get_pixmap(matrix=fitz.Matrix(2, 2)) # 2x Auflösung
+ pix = page.get_pixmap(matrix=fitz.Matrix(2, 2))
img_bytes = pix.tobytes("png")
img = Image.open(io.BytesIO(img_bytes))
-
ocr_result = self._ocr_image(img)
all_text.append(ocr_result["text"])
-
for region in ocr_result["regions"]:
region.page = page_num
all_regions.append(region)
@@ -241,55 +163,34 @@ class FileProcessor:
region_count += 1
doc.close()
-
avg_confidence = total_confidence / region_count if region_count > 0 else 0.0
return ProcessingResult(
- text="\n\n".join(all_text),
- confidence=avg_confidence,
+ text="\n\n".join(all_text), confidence=avg_confidence,
regions=all_regions,
page_count=len(doc) if hasattr(doc, '__len__') else 1,
- file_type=FileType.PDF,
- processing_mode=mode,
+ file_type=FileType.PDF, processing_mode=mode,
metadata={"source": file_path or "bytes"}
)
- def _process_image(
- self,
- file_path: str = None,
- file_bytes: bytes = None,
- mode: ProcessingMode = ProcessingMode.MIXED
- ) -> ProcessingResult:
+ def _process_image(self, file_path=None, file_bytes=None, mode=ProcessingMode.MIXED):
"""Verarbeitet Bilddateien."""
if file_bytes:
img = Image.open(io.BytesIO(file_bytes))
else:
img = Image.open(file_path)
- # Bildvorverarbeitung
processed_img = self._preprocess_image(img)
-
- # OCR
ocr_result = self._ocr_image(processed_img)
return ProcessingResult(
- text=ocr_result["text"],
- confidence=ocr_result["confidence"],
- regions=ocr_result["regions"],
- page_count=1,
- file_type=FileType.IMAGE,
- processing_mode=mode,
- metadata={
- "source": file_path or "bytes",
- "image_size": img.size
- }
+ text=ocr_result["text"], confidence=ocr_result["confidence"],
+ regions=ocr_result["regions"], page_count=1,
+ file_type=FileType.IMAGE, processing_mode=mode,
+ metadata={"source": file_path or "bytes", "image_size": img.size}
)
- def _process_docx(
- self,
- file_path: str = None,
- file_bytes: bytes = None
- ) -> ProcessingResult:
+ def _process_docx(self, file_path=None, file_bytes=None):
"""Verarbeitet DOCX-Dateien."""
try:
from docx import Document
@@ -306,7 +207,6 @@ class FileProcessor:
if para.text.strip():
paragraphs.append(para.text)
- # Auch Tabellen extrahieren
for table in doc.tables:
for row in table.rows:
row_text = " | ".join(cell.text for cell in row.cells)
@@ -316,25 +216,14 @@ class FileProcessor:
text = "\n\n".join(paragraphs)
return ProcessingResult(
- text=text,
- confidence=1.0, # Direkte Textextraktion
- regions=[ProcessedRegion(
- text=text,
- confidence=1.0,
- bbox=(0, 0, 0, 0),
- page=1
- )],
- page_count=1,
- file_type=FileType.DOCX,
+ text=text, confidence=1.0,
+ regions=[ProcessedRegion(text=text, confidence=1.0, bbox=(0, 0, 0, 0), page=1)],
+ page_count=1, file_type=FileType.DOCX,
processing_mode=ProcessingMode.TEXT_EXTRACT,
metadata={"source": file_path or "bytes"}
)
- def _process_txt(
- self,
- file_path: str = None,
- file_bytes: bytes = None
- ) -> ProcessingResult:
+ def _process_txt(self, file_path=None, file_bytes=None):
"""Verarbeitet Textdateien."""
if file_bytes:
text = file_bytes.decode('utf-8', errors='ignore')
@@ -343,146 +232,65 @@ class FileProcessor:
text = f.read()
return ProcessingResult(
- text=text,
- confidence=1.0,
- regions=[ProcessedRegion(
- text=text,
- confidence=1.0,
- bbox=(0, 0, 0, 0),
- page=1
- )],
- page_count=1,
- file_type=FileType.TXT,
+ text=text, confidence=1.0,
+ regions=[ProcessedRegion(text=text, confidence=1.0, bbox=(0, 0, 0, 0), page=1)],
+ page_count=1, file_type=FileType.TXT,
processing_mode=ProcessingMode.TEXT_EXTRACT,
metadata={"source": file_path or "bytes"}
)
def _preprocess_image(self, img: Image.Image) -> Image.Image:
- """
- Vorverarbeitung des Bildes für bessere OCR-Ergebnisse.
-
- - Konvertierung zu Graustufen
- - Kontrastverstärkung
- - Rauschunterdrückung
- - Binarisierung
- """
- # PIL zu OpenCV
+ """Vorverarbeitung des Bildes für bessere OCR-Ergebnisse."""
cv_img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
-
- # Zu Graustufen konvertieren
gray = cv2.cvtColor(cv_img, cv2.COLOR_BGR2GRAY)
-
- # Rauschunterdrückung
denoised = cv2.fastNlMeansDenoising(gray, None, 10, 7, 21)
-
- # Kontrastverstärkung (CLAHE)
clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
enhanced = clahe.apply(denoised)
-
- # Adaptive Binarisierung
binary = cv2.adaptiveThreshold(
- enhanced,
- 255,
- cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
- cv2.THRESH_BINARY,
- 11,
- 2
+ enhanced, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
+ cv2.THRESH_BINARY, 11, 2
)
-
- # Zurück zu PIL
return Image.fromarray(binary)
def _ocr_image(self, img: Image.Image) -> Dict[str, Any]:
- """
- Führt OCR auf einem Bild aus.
-
- Returns:
- Dict mit text, confidence und regions
- """
+ """Führt OCR auf einem Bild aus."""
if self.ocr_engine is None:
- # Fallback wenn kein OCR-Engine verfügbar
- return {
- "text": "[OCR nicht verfügbar - bitte PaddleOCR installieren]",
- "confidence": 0.0,
- "regions": []
- }
+ return {"text": "[OCR nicht verfügbar - bitte PaddleOCR installieren]",
+ "confidence": 0.0, "regions": []}
- # PIL zu numpy array
img_array = np.array(img)
-
- # Wenn Graustufen, zu RGB konvertieren (PaddleOCR erwartet RGB)
if len(img_array.shape) == 2:
img_array = cv2.cvtColor(img_array, cv2.COLOR_GRAY2RGB)
- # OCR ausführen
result = self.ocr_engine.ocr(img_array, cls=True)
if not result or not result[0]:
return {"text": "", "confidence": 0.0, "regions": []}
- all_text = []
- all_regions = []
+ all_text, all_regions = [], []
total_confidence = 0.0
for line in result[0]:
- bbox_points = line[0] # [[x1,y1], [x2,y2], [x3,y3], [x4,y4]]
+ bbox_points = line[0]
text, confidence = line[1]
-
- # Bounding Box zu x1, y1, x2, y2 konvertieren
x_coords = [p[0] for p in bbox_points]
y_coords = [p[1] for p in bbox_points]
- bbox = (
- int(min(x_coords)),
- int(min(y_coords)),
- int(max(x_coords)),
- int(max(y_coords))
- )
-
+ bbox = (int(min(x_coords)), int(min(y_coords)),
+ int(max(x_coords)), int(max(y_coords)))
all_text.append(text)
- all_regions.append(ProcessedRegion(
- text=text,
- confidence=confidence,
- bbox=bbox
- ))
+ all_regions.append(ProcessedRegion(text=text, confidence=confidence, bbox=bbox))
total_confidence += confidence
avg_confidence = total_confidence / len(all_regions) if all_regions else 0.0
+ return {"text": "\n".join(all_text), "confidence": avg_confidence, "regions": all_regions}
- return {
- "text": "\n".join(all_text),
- "confidence": avg_confidence,
- "regions": all_regions
- }
-
- def extract_handwriting_regions(
- self,
- img: Image.Image,
- min_area: int = 500
- ) -> List[Dict[str, Any]]:
- """
- Erkennt und extrahiert handschriftliche Bereiche aus einem Bild.
-
- Nützlich für Klausuren mit gedruckten Fragen und handschriftlichen Antworten.
-
- Args:
- img: Eingabebild
- min_area: Minimale Fläche für erkannte Regionen
-
- Returns:
- Liste von Regionen mit Koordinaten und erkanntem Text
- """
- # Bildvorverarbeitung
+ def extract_handwriting_regions(self, img: Image.Image, min_area: int = 500) -> List[Dict[str, Any]]:
+ """Erkennt und extrahiert handschriftliche Bereiche aus einem Bild."""
cv_img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
gray = cv2.cvtColor(cv_img, cv2.COLOR_BGR2GRAY)
-
- # Kanten erkennen
edges = cv2.Canny(gray, 50, 150)
-
- # Morphologische Operationen zum Verbinden
kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (15, 5))
dilated = cv2.dilate(edges, kernel, iterations=2)
-
- # Konturen finden
contours, _ = cv2.findContours(dilated, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
regions = []
@@ -490,25 +298,15 @@ class FileProcessor:
area = cv2.contourArea(contour)
if area < min_area:
continue
-
x, y, w, h = cv2.boundingRect(contour)
-
- # Region ausschneiden
region_img = img.crop((x, y, x + w, y + h))
-
- # OCR auf Region anwenden
ocr_result = self._ocr_image(region_img)
-
regions.append({
- "bbox": (x, y, x + w, y + h),
- "area": area,
- "text": ocr_result["text"],
- "confidence": ocr_result["confidence"]
+ "bbox": (x, y, x + w, y + h), "area": area,
+ "text": ocr_result["text"], "confidence": ocr_result["confidence"]
})
- # Nach Y-Position sortieren (oben nach unten)
regions.sort(key=lambda r: r["bbox"][1])
-
return regions
@@ -525,39 +323,25 @@ def get_file_processor() -> FileProcessor:
# Convenience functions
-def process_file(
- file_path: str = None,
- file_bytes: bytes = None,
- mode: ProcessingMode = ProcessingMode.MIXED
-) -> ProcessingResult:
- """
- Convenience function zum Verarbeiten einer Datei.
-
- Args:
- file_path: Pfad zur Datei
- file_bytes: Dateiinhalt als Bytes
- mode: Verarbeitungsmodus
-
- Returns:
- ProcessingResult
- """
+def process_file(file_path=None, file_bytes=None, mode=ProcessingMode.MIXED) -> ProcessingResult:
+ """Convenience function zum Verarbeiten einer Datei."""
processor = get_file_processor()
return processor.process(file_path, file_bytes, mode)
-def extract_text_from_pdf(file_path: str = None, file_bytes: bytes = None) -> str:
+def extract_text_from_pdf(file_path=None, file_bytes=None) -> str:
"""Extrahiert Text aus einer PDF-Datei."""
result = process_file(file_path, file_bytes, ProcessingMode.TEXT_EXTRACT)
return result.text
-def ocr_image(file_path: str = None, file_bytes: bytes = None) -> str:
+def ocr_image(file_path=None, file_bytes=None) -> str:
"""Führt OCR auf einem Bild aus."""
result = process_file(file_path, file_bytes, ProcessingMode.OCR_PRINTED)
return result.text
-def ocr_handwriting(file_path: str = None, file_bytes: bytes = None) -> str:
+def ocr_handwriting(file_path=None, file_bytes=None) -> str:
"""Führt Handschrift-OCR auf einem Bild aus."""
result = process_file(file_path, file_bytes, ProcessingMode.OCR_HANDWRITING)
return result.text
diff --git a/backend-lehrer/services/file_processor_models.py b/backend-lehrer/services/file_processor_models.py
new file mode 100644
index 0000000..dc5f084
--- /dev/null
+++ b/backend-lehrer/services/file_processor_models.py
@@ -0,0 +1,48 @@
+"""
+File Processor - Datenmodelle und Enums.
+
+Typen fuer Dokumentenverarbeitung: Dateitypen, Modi, Ergebnisse.
+"""
+
+from typing import List, Dict, Any, Tuple
+from dataclasses import dataclass
+from enum import Enum
+
+
+class FileType(str, Enum):
+ """Supported file types; members subclass str, so each compares equal to its plain string value."""
+ PDF = "pdf"
+ IMAGE = "image"
+ DOCX = "docx"
+ DOC = "doc"
+ TXT = "txt"
+ UNKNOWN = "unknown"
+
+
+class ProcessingMode(str, Enum):
+ """Processing modes."""
+ OCR_HANDWRITING = "ocr_handwriting" # handwriting recognition
+ OCR_PRINTED = "ocr_printed" # printed text
+ TEXT_EXTRACT = "text_extract" # text extraction (PDF/DOCX)
+ MIXED = "mixed" # combines OCR + text extraction
+
+
+@dataclass
+class ProcessedRegion:
+ """A single recognized text region."""
+ text: str
+ confidence: float
+ bbox: Tuple[int, int, int, int] # x1, y1, x2, y2
+ page: int = 1  # page the region belongs to; defaults to 1
+
+
+@dataclass
+class ProcessingResult:
+ """Result of processing one document."""
+ text: str  # full extracted text
+ confidence: float
+ regions: List[ProcessedRegion]  # per-region details
+ page_count: int
+ file_type: FileType
+ processing_mode: ProcessingMode
+ metadata: Dict[str, Any]  # free-form extras, e.g. a "source" entry
diff --git a/backend-lehrer/state_engine_api.py b/backend-lehrer/state_engine_api.py
index fe669d6..3f43ca8 100644
--- a/backend-lehrer/state_engine_api.py
+++ b/backend-lehrer/state_engine_api.py
@@ -12,21 +12,29 @@ Endpoints:
import logging
import uuid
from datetime import datetime, timedelta
-from typing import Dict, Any, List, Optional
+from typing import Dict, Any, List
from fastapi import APIRouter, HTTPException, Query
-from pydantic import BaseModel, Field
from state_engine import (
AnticipationEngine,
PhaseService,
- TeacherContext,
SchoolYearPhase,
ClassSummary,
Event,
- TeacherStats,
get_phase_info,
- PHASE_INFO
+)
+from state_engine_models import (
+ MilestoneRequest,
+ TransitionRequest,
+ ContextResponse,
+ SuggestionsResponse,
+ DashboardResponse,
+ _teacher_contexts,
+ _milestones,
+ get_or_create_context,
+ update_context_from_services,
+ get_phase_display_name,
)
logger = logging.getLogger(__name__)
@@ -41,157 +49,15 @@ _engine = AnticipationEngine()
_phase_service = PhaseService()
-# ============================================================================
-# In-Memory Storage (später durch DB ersetzen)
-# ============================================================================
-
-# Simulierter Lehrer-Kontext (in Produktion aus DB)
-_teacher_contexts: Dict[str, TeacherContext] = {}
-_milestones: Dict[str, List[str]] = {} # teacher_id -> milestones
-
-
-# ============================================================================
-# Pydantic Models
-# ============================================================================
-
-class MilestoneRequest(BaseModel):
- """Request zum Abschließen eines Meilensteins."""
- milestone: str = Field(..., description="Name des Meilensteins")
-
-
-class TransitionRequest(BaseModel):
- """Request für Phasen-Übergang."""
- target_phase: str = Field(..., description="Zielphase")
-
-
-class ContextResponse(BaseModel):
- """Response mit TeacherContext."""
- context: Dict[str, Any]
- phase_info: Dict[str, Any]
-
-
-class SuggestionsResponse(BaseModel):
- """Response mit Vorschlägen."""
- suggestions: List[Dict[str, Any]]
- current_phase: str
- phase_display_name: str
- priority_counts: Dict[str, int]
-
-
-class DashboardResponse(BaseModel):
- """Response mit Dashboard-Daten."""
- context: Dict[str, Any]
- suggestions: List[Dict[str, Any]]
- stats: Dict[str, Any]
- upcoming_events: List[Dict[str, Any]]
- progress: Dict[str, Any]
- phases: List[Dict[str, Any]]
-
-
-# ============================================================================
-# Helper Functions
-# ============================================================================
-
-def _get_or_create_context(teacher_id: str) -> TeacherContext:
- """
- Holt oder erstellt TeacherContext.
-
- In Produktion würde dies aus der Datenbank geladen.
- """
- if teacher_id not in _teacher_contexts:
- # Erstelle Demo-Kontext
- now = datetime.now()
- school_year_start = datetime(now.year if now.month >= 8 else now.year - 1, 8, 1)
- weeks_since_start = (now - school_year_start).days // 7
-
- # Bestimme Phase basierend auf Monat
- month = now.month
- if month in [8, 9]:
- phase = SchoolYearPhase.SCHOOL_YEAR_START
- elif month in [10, 11]:
- phase = SchoolYearPhase.TEACHING_SETUP
- elif month == 12:
- phase = SchoolYearPhase.PERFORMANCE_1
- elif month in [1, 2]:
- phase = SchoolYearPhase.SEMESTER_END
- elif month in [3, 4]:
- phase = SchoolYearPhase.TEACHING_2
- elif month in [5, 6]:
- phase = SchoolYearPhase.PERFORMANCE_2
- else:
- phase = SchoolYearPhase.YEAR_END
-
- _teacher_contexts[teacher_id] = TeacherContext(
- teacher_id=teacher_id,
- school_id=str(uuid.uuid4()),
- school_year_id=str(uuid.uuid4()),
- federal_state="niedersachsen",
- school_type="gymnasium",
- school_year_start=school_year_start,
- current_phase=phase,
- phase_entered_at=now - timedelta(days=7),
- weeks_since_start=weeks_since_start,
- days_in_phase=7,
- classes=[],
- total_students=0,
- upcoming_events=[],
- completed_milestones=_milestones.get(teacher_id, []),
- pending_milestones=[],
- stats=TeacherStats(),
- )
-
- return _teacher_contexts[teacher_id]
-
-
-def _update_context_from_services(ctx: TeacherContext) -> TeacherContext:
- """
- Aktualisiert Kontext mit Daten aus anderen Services.
-
- In Produktion würde dies von school-service, gradebook etc. laden.
- """
- # Simulierte Daten - in Produktion API-Calls
- # Hier könnten wir den Kontext mit echten Daten anreichern
-
- # Berechne days_in_phase
- ctx.days_in_phase = (datetime.now() - ctx.phase_entered_at).days
-
- # Lade abgeschlossene Meilensteine
- ctx.completed_milestones = _milestones.get(ctx.teacher_id, [])
-
- # Berechne pending milestones
- phase_info = get_phase_info(ctx.current_phase)
- ctx.pending_milestones = [
- m for m in phase_info.required_actions
- if m not in ctx.completed_milestones
- ]
-
- return ctx
-
-
-def _get_phase_display_name(phase: str) -> str:
- """Gibt Display-Name für Phase zurück."""
- try:
- return get_phase_info(SchoolYearPhase(phase)).display_name
- except (ValueError, KeyError):
- return phase
-
-
# ============================================================================
# API Endpoints
# ============================================================================
@router.get("/context", response_model=ContextResponse)
async def get_teacher_context(teacher_id: str = Query("demo-teacher")):
- """
- Gibt den aggregierten TeacherContext zurück.
-
- Enthält alle relevanten Informationen für:
- - Phasen-Anzeige
- - Antizipations-Engine
- - Dashboard
- """
- ctx = _get_or_create_context(teacher_id)
- ctx = _update_context_from_services(ctx)
+ """Gibt den aggregierten TeacherContext zurück."""
+ ctx = get_or_create_context(teacher_id)
+ ctx = update_context_from_services(ctx)
phase_info = get_phase_info(ctx.current_phase)
@@ -210,10 +76,8 @@ async def get_teacher_context(teacher_id: str = Query("demo-teacher")):
@router.get("/phase")
async def get_current_phase(teacher_id: str = Query("demo-teacher")):
- """
- Gibt die aktuelle Phase mit Details zurück.
- """
- ctx = _get_or_create_context(teacher_id)
+ """Gibt die aktuelle Phase mit Details zurück."""
+ ctx = get_or_create_context(teacher_id)
phase_info = get_phase_info(ctx.current_phase)
return {
@@ -230,11 +94,7 @@ async def get_current_phase(teacher_id: str = Query("demo-teacher")):
@router.get("/phases")
async def get_all_phases():
- """
- Gibt alle Phasen mit Metadaten zurück.
-
- Nützlich für die Phasen-Anzeige im Dashboard.
- """
+ """Gibt alle Phasen mit Metadaten zurück."""
return {
"phases": _phase_service.get_all_phases()
}
@@ -242,13 +102,9 @@ async def get_all_phases():
@router.get("/suggestions", response_model=SuggestionsResponse)
async def get_suggestions(teacher_id: str = Query("demo-teacher")):
- """
- Gibt Vorschläge basierend auf dem aktuellen Kontext zurück.
-
- Die Vorschläge sind priorisiert und auf max. 5 limitiert.
- """
- ctx = _get_or_create_context(teacher_id)
- ctx = _update_context_from_services(ctx)
+ """Gibt Vorschläge basierend auf dem aktuellen Kontext zurück."""
+ ctx = get_or_create_context(teacher_id)
+ ctx = update_context_from_services(ctx)
suggestions = _engine.get_suggestions(ctx)
priority_counts = _engine.count_by_priority(ctx)
@@ -256,18 +112,16 @@ async def get_suggestions(teacher_id: str = Query("demo-teacher")):
return SuggestionsResponse(
suggestions=[s.to_dict() for s in suggestions],
current_phase=ctx.current_phase.value,
- phase_display_name=_get_phase_display_name(ctx.current_phase.value),
+ phase_display_name=get_phase_display_name(ctx.current_phase.value),
priority_counts=priority_counts,
)
@router.get("/suggestions/top")
async def get_top_suggestion(teacher_id: str = Query("demo-teacher")):
- """
- Gibt den wichtigsten einzelnen Vorschlag zurück.
- """
- ctx = _get_or_create_context(teacher_id)
- ctx = _update_context_from_services(ctx)
+ """Gibt den wichtigsten einzelnen Vorschlag zurück."""
+ ctx = get_or_create_context(teacher_id)
+ ctx = update_context_from_services(ctx)
suggestion = _engine.get_top_suggestion(ctx)
@@ -284,28 +138,17 @@ async def get_top_suggestion(teacher_id: str = Query("demo-teacher")):
@router.get("/dashboard", response_model=DashboardResponse)
async def get_dashboard_data(teacher_id: str = Query("demo-teacher")):
- """
- Gibt alle Daten für das Begleiter-Dashboard zurück.
-
- Kombiniert:
- - TeacherContext
- - Vorschläge
- - Statistiken
- - Termine
- - Fortschritt
- """
- ctx = _get_or_create_context(teacher_id)
- ctx = _update_context_from_services(ctx)
+ """Gibt alle Daten für das Begleiter-Dashboard zurück."""
+ ctx = get_or_create_context(teacher_id)
+ ctx = update_context_from_services(ctx)
suggestions = _engine.get_suggestions(ctx)
phase_info = get_phase_info(ctx.current_phase)
- # Berechne Fortschritt
required = set(phase_info.required_actions)
completed = set(ctx.completed_milestones)
completed_in_phase = len(required.intersection(completed))
- # Alle Phasen für Anzeige
all_phases = []
phase_order = [
SchoolYearPhase.ONBOARDING,
@@ -376,14 +219,9 @@ async def complete_milestone(
request: MilestoneRequest,
teacher_id: str = Query("demo-teacher")
):
- """
- Markiert einen Meilenstein als erledigt.
-
- Prüft automatisch ob ein Phasen-Übergang möglich ist.
- """
+ """Markiert einen Meilenstein als erledigt."""
milestone = request.milestone
- # Speichere Meilenstein
if teacher_id not in _milestones:
_milestones[teacher_id] = []
@@ -391,12 +229,10 @@ async def complete_milestone(
_milestones[teacher_id].append(milestone)
logger.info(f"Milestone '{milestone}' completed for teacher {teacher_id}")
- # Aktualisiere Kontext
- ctx = _get_or_create_context(teacher_id)
+ ctx = get_or_create_context(teacher_id)
ctx.completed_milestones = _milestones[teacher_id]
_teacher_contexts[teacher_id] = ctx
- # Prüfe automatischen Phasen-Übergang
new_phase = _phase_service.check_and_transition(ctx)
if new_phase:
@@ -420,9 +256,7 @@ async def transition_phase(
request: TransitionRequest,
teacher_id: str = Query("demo-teacher")
):
- """
- Führt einen manuellen Phasen-Übergang durch.
- """
+ """Führt einen manuellen Phasen-Übergang durch."""
try:
target_phase = SchoolYearPhase(request.target_phase)
except ValueError:
@@ -431,16 +265,14 @@ async def transition_phase(
detail=f"Ungültige Phase: {request.target_phase}"
)
- ctx = _get_or_create_context(teacher_id)
+ ctx = get_or_create_context(teacher_id)
- # Prüfe ob Übergang erlaubt
if not _phase_service.can_transition_to(ctx, target_phase):
raise HTTPException(
status_code=400,
detail=f"Übergang von {ctx.current_phase.value} zu {target_phase.value} nicht erlaubt"
)
- # Führe Übergang durch
old_phase = ctx.current_phase
ctx.current_phase = target_phase
ctx.phase_entered_at = datetime.now()
@@ -459,10 +291,8 @@ async def transition_phase(
@router.get("/next-phase")
async def get_next_phase(teacher_id: str = Query("demo-teacher")):
- """
- Gibt die nächste Phase und Anforderungen zurück.
- """
- ctx = _get_or_create_context(teacher_id)
+ """Gibt die nächste Phase und Anforderungen zurück."""
+ ctx = get_or_create_context(teacher_id)
next_phase = _phase_service.get_next_phase(ctx.current_phase)
if not next_phase:
@@ -475,7 +305,6 @@ async def get_next_phase(teacher_id: str = Query("demo-teacher")):
next_info = get_phase_info(next_phase)
current_info = get_phase_info(ctx.current_phase)
- # Fehlende Anforderungen
missing = [
m for m in current_info.required_actions
if m not in ctx.completed_milestones
@@ -505,7 +334,7 @@ async def demo_add_class(
teacher_id: str = Query("demo-teacher")
):
"""Demo: Fügt eine Klasse zum Kontext hinzu."""
- ctx = _get_or_create_context(teacher_id)
+ ctx = get_or_create_context(teacher_id)
ctx.classes.append(ClassSummary(
class_id=str(uuid.uuid4()),
@@ -515,7 +344,6 @@ async def demo_add_class(
subject="Deutsch"
))
ctx.total_students += student_count
-
_teacher_contexts[teacher_id] = ctx
return {"success": True, "classes": len(ctx.classes)}
@@ -529,7 +357,7 @@ async def demo_add_event(
teacher_id: str = Query("demo-teacher")
):
"""Demo: Fügt ein Event zum Kontext hinzu."""
- ctx = _get_or_create_context(teacher_id)
+ ctx = get_or_create_context(teacher_id)
ctx.upcoming_events.append(Event(
type=event_type,
@@ -538,7 +366,6 @@ async def demo_add_event(
in_days=in_days,
priority="high" if in_days <= 3 else "medium"
))
-
_teacher_contexts[teacher_id] = ctx
return {"success": True, "events": len(ctx.upcoming_events)}
@@ -554,7 +381,7 @@ async def demo_update_stats(
teacher_id: str = Query("demo-teacher")
):
"""Demo: Aktualisiert Statistiken."""
- ctx = _get_or_create_context(teacher_id)
+ ctx = get_or_create_context(teacher_id)
if learning_units:
ctx.stats.learning_units_created = learning_units
diff --git a/backend-lehrer/state_engine_models.py b/backend-lehrer/state_engine_models.py
new file mode 100644
index 0000000..778b2f7
--- /dev/null
+++ b/backend-lehrer/state_engine_models.py
@@ -0,0 +1,143 @@
+"""
+State Engine API - Pydantic Models und Helper Functions.
+"""
+
+import uuid
+from datetime import datetime, timedelta
+from typing import Dict, Any, List, Optional
+
+from pydantic import BaseModel, Field
+
+from state_engine import (
+ SchoolYearPhase,
+ ClassSummary,
+ Event,
+ TeacherContext,
+ TeacherStats,
+ get_phase_info,
+)
+
+
+# ============================================================================
+# In-Memory Storage (später durch DB ersetzen)
+# ============================================================================
+
+_teacher_contexts: Dict[str, TeacherContext] = {}
+_milestones: Dict[str, List[str]] = {} # teacher_id -> milestones
+
+
+# ============================================================================
+# Pydantic Models
+# ============================================================================
+
+class MilestoneRequest(BaseModel):
+ """Request zum Abschließen eines Meilensteins."""
+ milestone: str = Field(..., description="Name des Meilensteins")
+
+
+class TransitionRequest(BaseModel):
+ """Request für Phasen-Übergang."""
+ target_phase: str = Field(..., description="Zielphase")
+
+
+class ContextResponse(BaseModel):
+ """Response mit TeacherContext."""
+ context: Dict[str, Any]
+ phase_info: Dict[str, Any]
+
+
+class SuggestionsResponse(BaseModel):
+ """Response mit Vorschlägen."""
+ suggestions: List[Dict[str, Any]]
+ current_phase: str
+ phase_display_name: str
+ priority_counts: Dict[str, int]
+
+
+class DashboardResponse(BaseModel):
+ """Response mit Dashboard-Daten."""
+ context: Dict[str, Any]
+ suggestions: List[Dict[str, Any]]
+ stats: Dict[str, Any]
+ upcoming_events: List[Dict[str, Any]]
+ progress: Dict[str, Any]
+ phases: List[Dict[str, Any]]
+
+
+# ============================================================================
+# Helper Functions
+# ============================================================================
+
+def get_or_create_context(teacher_id: str) -> TeacherContext:
+ """
+ Holt oder erstellt TeacherContext.
+
+ In Produktion würde dies aus der Datenbank geladen.
+ """
+ if teacher_id not in _teacher_contexts:
+ now = datetime.now()
+ school_year_start = datetime(now.year if now.month >= 8 else now.year - 1, 8, 1)
+ weeks_since_start = (now - school_year_start).days // 7
+
+ month = now.month
+ if month in [8, 9]:
+ phase = SchoolYearPhase.SCHOOL_YEAR_START
+ elif month in [10, 11]:
+ phase = SchoolYearPhase.TEACHING_SETUP
+ elif month == 12:
+ phase = SchoolYearPhase.PERFORMANCE_1
+ elif month in [1, 2]:
+ phase = SchoolYearPhase.SEMESTER_END
+ elif month in [3, 4]:
+ phase = SchoolYearPhase.TEACHING_2
+ elif month in [5, 6]:
+ phase = SchoolYearPhase.PERFORMANCE_2
+ else:
+ phase = SchoolYearPhase.YEAR_END
+
+ _teacher_contexts[teacher_id] = TeacherContext(
+ teacher_id=teacher_id,
+ school_id=str(uuid.uuid4()),
+ school_year_id=str(uuid.uuid4()),
+ federal_state="niedersachsen",
+ school_type="gymnasium",
+ school_year_start=school_year_start,
+ current_phase=phase,
+ phase_entered_at=now - timedelta(days=7),
+ weeks_since_start=weeks_since_start,
+ days_in_phase=7,
+ classes=[],
+ total_students=0,
+ upcoming_events=[],
+ completed_milestones=_milestones.get(teacher_id, []),
+ pending_milestones=[],
+ stats=TeacherStats(),
+ )
+
+ return _teacher_contexts[teacher_id]
+
+
+def update_context_from_services(ctx: TeacherContext) -> TeacherContext:
+ """
+ Aktualisiert Kontext mit Daten aus anderen Services.
+
+ In Produktion würde dies von school-service, gradebook etc. laden.
+ """
+ ctx.days_in_phase = (datetime.now() - ctx.phase_entered_at).days
+ ctx.completed_milestones = _milestones.get(ctx.teacher_id, [])
+
+ phase_info = get_phase_info(ctx.current_phase)
+ ctx.pending_milestones = [
+ m for m in phase_info.required_actions
+ if m not in ctx.completed_milestones
+ ]
+
+ return ctx
+
+
+def get_phase_display_name(phase: str) -> str:
+ """Gibt Display-Name für Phase zurück."""
+ try:
+ return get_phase_info(SchoolYearPhase(phase)).display_name
+ except (ValueError, KeyError):
+ return phase
diff --git a/backend-lehrer/worksheets_api.py b/backend-lehrer/worksheets_api.py
index 527bf95..d6f9f8d 100644
--- a/backend-lehrer/worksheets_api.py
+++ b/backend-lehrer/worksheets_api.py
@@ -16,11 +16,9 @@ Unterstützt:
import logging
import uuid
from datetime import datetime
-from typing import List, Dict, Any, Optional
-from enum import Enum
+from typing import Dict
-from fastapi import APIRouter, HTTPException, UploadFile, File, Form
-from pydantic import BaseModel, Field
+from fastapi import APIRouter, HTTPException
from generators import (
MultipleChoiceGenerator,
@@ -28,9 +26,22 @@ from generators import (
MindmapGenerator,
QuizGenerator
)
-from generators.mc_generator import Difficulty
-from generators.cloze_generator import ClozeType
-from generators.quiz_generator import QuizType
+
+from worksheets_models import (
+ ContentType,
+ GenerateRequest,
+ MCGenerateRequest,
+ ClozeGenerateRequest,
+ MindmapGenerateRequest,
+ QuizGenerateRequest,
+ BatchGenerateRequest,
+ WorksheetContent,
+ GenerateResponse,
+ BatchGenerateResponse,
+ parse_difficulty,
+ parse_cloze_type,
+ parse_quiz_types,
+)
logger = logging.getLogger(__name__)
@@ -40,89 +51,6 @@ router = APIRouter(
)
-# ============================================================================
-# Pydantic Models
-# ============================================================================
-
-class ContentType(str, Enum):
- """Verfügbare Content-Typen."""
- MULTIPLE_CHOICE = "multiple_choice"
- CLOZE = "cloze"
- MINDMAP = "mindmap"
- QUIZ = "quiz"
-
-
-class GenerateRequest(BaseModel):
- """Basis-Request für Generierung."""
- source_text: str = Field(..., min_length=50, description="Quelltext für Generierung")
- topic: Optional[str] = Field(None, description="Thema/Titel")
- subject: Optional[str] = Field(None, description="Fach")
- grade_level: Optional[str] = Field(None, description="Klassenstufe")
-
-
-class MCGenerateRequest(GenerateRequest):
- """Request für Multiple-Choice-Generierung."""
- num_questions: int = Field(5, ge=1, le=20, description="Anzahl Fragen")
- difficulty: str = Field("medium", description="easy, medium, hard")
-
-
-class ClozeGenerateRequest(GenerateRequest):
- """Request für Lückentext-Generierung."""
- num_gaps: int = Field(5, ge=1, le=15, description="Anzahl Lücken")
- difficulty: str = Field("medium", description="easy, medium, hard")
- cloze_type: str = Field("fill_in", description="fill_in, drag_drop, dropdown")
-
-
-class MindmapGenerateRequest(GenerateRequest):
- """Request für Mindmap-Generierung."""
- max_depth: int = Field(3, ge=2, le=5, description="Maximale Tiefe")
-
-
-class QuizGenerateRequest(GenerateRequest):
- """Request für Quiz-Generierung."""
- quiz_types: List[str] = Field(
- ["true_false", "matching"],
- description="Typen: true_false, matching, sorting, open_ended"
- )
- num_items: int = Field(5, ge=1, le=10, description="Items pro Typ")
- difficulty: str = Field("medium", description="easy, medium, hard")
-
-
-class BatchGenerateRequest(BaseModel):
- """Request für Batch-Generierung mehrerer Content-Typen."""
- source_text: str = Field(..., min_length=50)
- content_types: List[str] = Field(..., description="Liste von Content-Typen")
- topic: Optional[str] = None
- subject: Optional[str] = None
- grade_level: Optional[str] = None
- difficulty: str = "medium"
-
-
-class WorksheetContent(BaseModel):
- """Generierter Content."""
- id: str
- content_type: str
- data: Dict[str, Any]
- h5p_format: Optional[Dict[str, Any]] = None
- created_at: datetime
- topic: Optional[str] = None
- difficulty: Optional[str] = None
-
-
-class GenerateResponse(BaseModel):
- """Response mit generiertem Content."""
- success: bool
- content: Optional[WorksheetContent] = None
- error: Optional[str] = None
-
-
-class BatchGenerateResponse(BaseModel):
- """Response für Batch-Generierung."""
- success: bool
- contents: List[WorksheetContent] = []
- errors: List[str] = []
-
-
# ============================================================================
# In-Memory Storage (später durch DB ersetzen)
# ============================================================================
@@ -134,49 +62,12 @@ _generated_content: Dict[str, WorksheetContent] = {}
# Generator Instances
# ============================================================================
-# Generatoren ohne LLM-Client (automatische Generierung)
-# In Produktion würde hier der LLM-Client injiziert
mc_generator = MultipleChoiceGenerator()
cloze_generator = ClozeGenerator()
mindmap_generator = MindmapGenerator()
quiz_generator = QuizGenerator()
-# ============================================================================
-# Helper Functions
-# ============================================================================
-
-def _parse_difficulty(difficulty_str: str) -> Difficulty:
- """Konvertiert String zu Difficulty Enum."""
- mapping = {
- "easy": Difficulty.EASY,
- "medium": Difficulty.MEDIUM,
- "hard": Difficulty.HARD
- }
- return mapping.get(difficulty_str.lower(), Difficulty.MEDIUM)
-
-
-def _parse_cloze_type(type_str: str) -> ClozeType:
- """Konvertiert String zu ClozeType Enum."""
- mapping = {
- "fill_in": ClozeType.FILL_IN,
- "drag_drop": ClozeType.DRAG_DROP,
- "dropdown": ClozeType.DROPDOWN
- }
- return mapping.get(type_str.lower(), ClozeType.FILL_IN)
-
-
-def _parse_quiz_types(type_strs: List[str]) -> List[QuizType]:
- """Konvertiert String-Liste zu QuizType Enums."""
- mapping = {
- "true_false": QuizType.TRUE_FALSE,
- "matching": QuizType.MATCHING,
- "sorting": QuizType.SORTING,
- "open_ended": QuizType.OPEN_ENDED
- }
- return [mapping.get(t.lower(), QuizType.TRUE_FALSE) for t in type_strs]
-
-
def _store_content(content: WorksheetContent) -> None:
"""Speichert generierten Content."""
_generated_content[content.id] = content
@@ -188,15 +79,9 @@ def _store_content(content: WorksheetContent) -> None:
@router.post("/generate/multiple-choice", response_model=GenerateResponse)
async def generate_multiple_choice(request: MCGenerateRequest):
- """
- Generiert Multiple-Choice-Fragen aus Quelltext.
-
- - **source_text**: Text mit mind. 50 Zeichen
- - **num_questions**: Anzahl Fragen (1-20)
- - **difficulty**: easy, medium, hard
- """
+ """Generiert Multiple-Choice-Fragen aus Quelltext."""
try:
- difficulty = _parse_difficulty(request.difficulty)
+ difficulty = parse_difficulty(request.difficulty)
questions = mc_generator.generate(
source_text=request.source_text,
@@ -212,7 +97,6 @@ async def generate_multiple_choice(request: MCGenerateRequest):
error="Keine Fragen generiert. Text möglicherweise zu kurz."
)
- # Konvertiere zu Dict
questions_dict = mc_generator.to_dict(questions)
h5p_format = mc_generator.to_h5p_format(questions)
@@ -227,7 +111,6 @@ async def generate_multiple_choice(request: MCGenerateRequest):
)
_store_content(content)
-
return GenerateResponse(success=True, content=content)
except Exception as e:
@@ -237,15 +120,9 @@ async def generate_multiple_choice(request: MCGenerateRequest):
@router.post("/generate/cloze", response_model=GenerateResponse)
async def generate_cloze(request: ClozeGenerateRequest):
- """
- Generiert Lückentext aus Quelltext.
-
- - **source_text**: Text mit mind. 50 Zeichen
- - **num_gaps**: Anzahl Lücken (1-15)
- - **cloze_type**: fill_in, drag_drop, dropdown
- """
+ """Generiert Lückentext aus Quelltext."""
try:
- cloze_type = _parse_cloze_type(request.cloze_type)
+ cloze_type = parse_cloze_type(request.cloze_type)
cloze = cloze_generator.generate(
source_text=request.source_text,
@@ -275,7 +152,6 @@ async def generate_cloze(request: ClozeGenerateRequest):
)
_store_content(content)
-
return GenerateResponse(success=True, content=content)
except Exception as e:
@@ -285,12 +161,7 @@ async def generate_cloze(request: ClozeGenerateRequest):
@router.post("/generate/mindmap", response_model=GenerateResponse)
async def generate_mindmap(request: MindmapGenerateRequest):
- """
- Generiert Mindmap aus Quelltext.
-
- - **source_text**: Text mit mind. 50 Zeichen
- - **max_depth**: Maximale Tiefe (2-5)
- """
+ """Generiert Mindmap aus Quelltext."""
try:
mindmap = mindmap_generator.generate(
source_text=request.source_text,
@@ -317,14 +188,13 @@ async def generate_mindmap(request: MindmapGenerateRequest):
"mermaid": mermaid,
"json_tree": json_tree
},
- h5p_format=None, # Mindmaps haben kein H5P-Format
+ h5p_format=None,
created_at=datetime.utcnow(),
topic=request.topic,
difficulty=None
)
_store_content(content)
-
return GenerateResponse(success=True, content=content)
except Exception as e:
@@ -334,17 +204,10 @@ async def generate_mindmap(request: MindmapGenerateRequest):
@router.post("/generate/quiz", response_model=GenerateResponse)
async def generate_quiz(request: QuizGenerateRequest):
- """
- Generiert Quiz mit verschiedenen Fragetypen.
-
- - **source_text**: Text mit mind. 50 Zeichen
- - **quiz_types**: Liste von true_false, matching, sorting, open_ended
- - **num_items**: Items pro Typ (1-10)
- """
+ """Generiert Quiz mit verschiedenen Fragetypen."""
try:
- quiz_types = _parse_quiz_types(request.quiz_types)
+ quiz_types = parse_quiz_types(request.quiz_types)
- # Generate quiz for each type and combine results
all_questions = []
quizzes = []
@@ -365,7 +228,6 @@ async def generate_quiz(request: QuizGenerateRequest):
error="Quiz konnte nicht generiert werden. Text möglicherweise zu kurz."
)
- # Combine all quizzes into a single dict
combined_quiz_dict = {
"quiz_types": [qt.value for qt in quiz_types],
"title": f"Combined Quiz - {request.topic or 'Various Topics'}",
@@ -374,12 +236,10 @@ async def generate_quiz(request: QuizGenerateRequest):
"questions": []
}
- # Add questions from each quiz
for quiz in quizzes:
quiz_dict = quiz_generator.to_dict(quiz)
combined_quiz_dict["questions"].extend(quiz_dict.get("questions", []))
- # Use first quiz's H5P format as base (or empty if none)
h5p_format = quiz_generator.to_h5p_format(quizzes[0]) if quizzes else {}
content = WorksheetContent(
@@ -393,7 +253,6 @@ async def generate_quiz(request: QuizGenerateRequest):
)
_store_content(content)
-
return GenerateResponse(success=True, content=content)
except Exception as e:
@@ -403,22 +262,10 @@ async def generate_quiz(request: QuizGenerateRequest):
@router.post("/generate/batch", response_model=BatchGenerateResponse)
async def generate_batch(request: BatchGenerateRequest):
- """
- Generiert mehrere Content-Typen aus einem Quelltext.
-
- Ideal für die Erstellung kompletter Arbeitsblätter mit
- verschiedenen Übungstypen.
- """
+ """Generiert mehrere Content-Typen aus einem Quelltext."""
contents = []
errors = []
- type_mapping = {
- "multiple_choice": MCGenerateRequest,
- "cloze": ClozeGenerateRequest,
- "mindmap": MindmapGenerateRequest,
- "quiz": QuizGenerateRequest
- }
-
for content_type in request.content_types:
try:
if content_type == "multiple_choice":
diff --git a/backend-lehrer/worksheets_models.py b/backend-lehrer/worksheets_models.py
new file mode 100644
index 0000000..87c6b25
--- /dev/null
+++ b/backend-lehrer/worksheets_models.py
@@ -0,0 +1,135 @@
+"""
+Worksheets API - Pydantic Models und Helpers.
+
+Request-/Response-Models und Hilfsfunktionen fuer die
+Arbeitsblatt-Generierungs-API.
+"""
+
+import uuid
+from datetime import datetime
+from typing import List, Dict, Any, Optional
+from enum import Enum
+
+from pydantic import BaseModel, Field
+
+from generators.mc_generator import Difficulty
+from generators.cloze_generator import ClozeType
+from generators.quiz_generator import QuizType
+
+
+# ============================================================================
+# Pydantic Models
+# ============================================================================
+
+class ContentType(str, Enum):
+ """Verfügbare Content-Typen."""
+ MULTIPLE_CHOICE = "multiple_choice"
+ CLOZE = "cloze"
+ MINDMAP = "mindmap"
+ QUIZ = "quiz"
+
+
+class GenerateRequest(BaseModel):
+ """Basis-Request für Generierung."""
+ source_text: str = Field(..., min_length=50, description="Quelltext für Generierung")
+ topic: Optional[str] = Field(None, description="Thema/Titel")
+ subject: Optional[str] = Field(None, description="Fach")
+ grade_level: Optional[str] = Field(None, description="Klassenstufe")
+
+
+class MCGenerateRequest(GenerateRequest):
+ """Request für Multiple-Choice-Generierung."""
+ num_questions: int = Field(5, ge=1, le=20, description="Anzahl Fragen")
+ difficulty: str = Field("medium", description="easy, medium, hard")
+
+
+class ClozeGenerateRequest(GenerateRequest):
+ """Request für Lückentext-Generierung."""
+ num_gaps: int = Field(5, ge=1, le=15, description="Anzahl Lücken")
+ difficulty: str = Field("medium", description="easy, medium, hard")
+ cloze_type: str = Field("fill_in", description="fill_in, drag_drop, dropdown")
+
+
+class MindmapGenerateRequest(GenerateRequest):
+ """Request für Mindmap-Generierung."""
+ max_depth: int = Field(3, ge=2, le=5, description="Maximale Tiefe")
+
+
+class QuizGenerateRequest(GenerateRequest):
+ """Request für Quiz-Generierung."""
+ quiz_types: List[str] = Field(
+ ["true_false", "matching"],
+ description="Typen: true_false, matching, sorting, open_ended"
+ )
+ num_items: int = Field(5, ge=1, le=10, description="Items pro Typ")
+ difficulty: str = Field("medium", description="easy, medium, hard")
+
+
+class BatchGenerateRequest(BaseModel):
+ """Request für Batch-Generierung mehrerer Content-Typen."""
+ source_text: str = Field(..., min_length=50)
+ content_types: List[str] = Field(..., description="Liste von Content-Typen")
+ topic: Optional[str] = None
+ subject: Optional[str] = None
+ grade_level: Optional[str] = None
+ difficulty: str = "medium"
+
+
+class WorksheetContent(BaseModel):
+ """Generierter Content."""
+ id: str
+ content_type: str
+ data: Dict[str, Any]
+ h5p_format: Optional[Dict[str, Any]] = None
+ created_at: datetime
+ topic: Optional[str] = None
+ difficulty: Optional[str] = None
+
+
+class GenerateResponse(BaseModel):
+ """Response mit generiertem Content."""
+ success: bool
+ content: Optional[WorksheetContent] = None
+ error: Optional[str] = None
+
+
+class BatchGenerateResponse(BaseModel):
+ """Response für Batch-Generierung."""
+ success: bool
+ contents: List[WorksheetContent] = []
+ errors: List[str] = []
+
+
+# ============================================================================
+# Helper Functions
+# ============================================================================
+
+def parse_difficulty(difficulty_str: str) -> Difficulty:
+ """Konvertiert String zu Difficulty Enum."""
+ mapping = {
+ "easy": Difficulty.EASY,
+ "medium": Difficulty.MEDIUM,
+ "hard": Difficulty.HARD
+ }
+ return mapping.get(difficulty_str.lower(), Difficulty.MEDIUM)
+
+
+def parse_cloze_type(type_str: str) -> ClozeType:
+ """Konvertiert String zu ClozeType Enum."""
+ mapping = {
+ "fill_in": ClozeType.FILL_IN,
+ "drag_drop": ClozeType.DRAG_DROP,
+ "dropdown": ClozeType.DROPDOWN
+ }
+ return mapping.get(type_str.lower(), ClozeType.FILL_IN)
+
+
+def parse_quiz_types(type_strs: List[str]) -> List[QuizType]:
+ """Konvertiert String-Liste zu QuizType Enums."""
+ mapping = {
+ "true_false": QuizType.TRUE_FALSE,
+ "matching": QuizType.MATCHING,
+ "sorting": QuizType.SORTING,
+ "open_ended": QuizType.OPEN_ENDED
+ }
+ return [mapping.get(t.lower(), QuizType.TRUE_FALSE) for t in type_strs]
diff --git a/klausur-service/backend/cv_gutter_repair.py b/klausur-service/backend/cv_gutter_repair.py
index 03c7bd1..fc6fc6c 100644
--- a/klausur-service/backend/cv_gutter_repair.py
+++ b/klausur-service/backend/cv_gutter_repair.py
@@ -1,610 +1,35 @@
"""
-Gutter Repair — detects and fixes words truncated or blurred at the book gutter.
+Gutter Repair — barrel re-export.
-When scanning double-page spreads, the binding area (gutter) causes:
- 1. Blurry/garbled trailing characters ("stammeli" → "stammeln")
- 2. Words split across lines with a hyphen lost in the gutter
- ("ve" + "künden" → "verkünden")
-
-This module analyses grid cells, identifies gutter-edge candidates, and
-proposes corrections using pyspellchecker (DE + EN).
+All implementation split into:
+ cv_gutter_repair_core — spellchecker setup, data types, single-word repair
+ cv_gutter_repair_grid — grid analysis, suggestion application
Lizenz: Apache 2.0 (kommerziell nutzbar)
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
"""
-import itertools
-import logging
-import re
-import time
-import uuid
-from dataclasses import dataclass, field, asdict
-from typing import Any, Dict, List, Optional, Tuple
-
-logger = logging.getLogger(__name__)
-
-# ---------------------------------------------------------------------------
-# Spellchecker setup (lazy, cached)
-# ---------------------------------------------------------------------------
-
-_spell_de = None
-_spell_en = None
-_SPELL_AVAILABLE = False
-
-def _init_spellcheckers():
- """Lazy-load DE + EN spellcheckers (cached across calls)."""
- global _spell_de, _spell_en, _SPELL_AVAILABLE
- if _spell_de is not None:
- return
- try:
- from spellchecker import SpellChecker
- _spell_de = SpellChecker(language='de', distance=1)
- _spell_en = SpellChecker(language='en', distance=1)
- _SPELL_AVAILABLE = True
- logger.info("Gutter repair: spellcheckers loaded (DE + EN)")
- except ImportError:
- logger.warning("pyspellchecker not installed — gutter repair unavailable")
-
-
-def _is_known(word: str) -> bool:
- """Check if a word is known in DE or EN dictionary."""
- _init_spellcheckers()
- if not _SPELL_AVAILABLE:
- return False
- w = word.lower()
- return bool(_spell_de.known([w])) or bool(_spell_en.known([w]))
-
-
-def _spell_candidates(word: str, lang: str = "both") -> List[str]:
- """Get all plausible spellchecker candidates for a word (deduplicated)."""
- _init_spellcheckers()
- if not _SPELL_AVAILABLE:
- return []
- w = word.lower()
- seen: set = set()
- results: List[str] = []
-
- for checker in ([_spell_de, _spell_en] if lang == "both"
- else [_spell_de] if lang == "de"
- else [_spell_en]):
- if checker is None:
- continue
- cands = checker.candidates(w)
- if cands:
- for c in cands:
- if c and c != w and c not in seen:
- seen.add(c)
- results.append(c)
-
- return results
-
-
-# ---------------------------------------------------------------------------
-# Gutter position detection
-# ---------------------------------------------------------------------------
-
-# Minimum word length for spell-fix (very short words are often legitimate)
-_MIN_WORD_LEN_SPELL = 3
-
-# Minimum word length for hyphen-join candidates (fragments at the gutter
-# can be as short as 1-2 chars, e.g. "ve" from "ver-künden")
-_MIN_WORD_LEN_HYPHEN = 2
-
-# How close to the right column edge a word must be to count as "gutter-adjacent".
-# Expressed as fraction of column width (e.g. 0.75 = rightmost 25%).
-_GUTTER_EDGE_THRESHOLD = 0.70
-
-# Small common words / abbreviations that should NOT be repaired
-_STOPWORDS = frozenset([
- # German
- "ab", "an", "am", "da", "er", "es", "im", "in", "ja", "ob", "so", "um",
- "zu", "wo", "du", "eh", "ei", "je", "na", "nu", "oh",
- # English
- "a", "am", "an", "as", "at", "be", "by", "do", "go", "he", "if", "in",
- "is", "it", "me", "my", "no", "of", "on", "or", "so", "to", "up", "us",
- "we",
-])
-
-# IPA / phonetic patterns — skip these cells
-_IPA_RE = re.compile(r'[\[\]/ˈˌːʃʒθðŋɑɒæɔəɛɪʊʌ]')
-
-
-def _is_ipa_text(text: str) -> bool:
- """True if text looks like IPA transcription."""
- return bool(_IPA_RE.search(text))
-
-
-def _word_is_at_gutter_edge(word_bbox: Dict, col_x: float, col_width: float) -> bool:
- """Check if a word's right edge is near the right boundary of its column."""
- if col_width <= 0:
- return False
- word_right = word_bbox.get("left", 0) + word_bbox.get("width", 0)
- col_right = col_x + col_width
- # Word's right edge within the rightmost portion of the column
- relative_pos = (word_right - col_x) / col_width
- return relative_pos >= _GUTTER_EDGE_THRESHOLD
-
-
-# ---------------------------------------------------------------------------
-# Suggestion types
-# ---------------------------------------------------------------------------
-
-@dataclass
-class GutterSuggestion:
- """A single correction suggestion."""
- id: str = field(default_factory=lambda: str(uuid.uuid4())[:8])
- type: str = "" # "hyphen_join" | "spell_fix"
- zone_index: int = 0
- row_index: int = 0
- col_index: int = 0
- col_type: str = ""
- cell_id: str = ""
- original_text: str = ""
- suggested_text: str = ""
- # For hyphen_join:
- next_row_index: int = -1
- next_row_cell_id: str = ""
- next_row_text: str = ""
- missing_chars: str = ""
- display_parts: List[str] = field(default_factory=list)
- # Alternatives (other plausible corrections the user can pick from)
- alternatives: List[str] = field(default_factory=list)
- # Meta:
- confidence: float = 0.0
- reason: str = "" # "gutter_truncation" | "gutter_blur" | "hyphen_continuation"
-
- def to_dict(self) -> Dict[str, Any]:
- return asdict(self)
-
-
-# ---------------------------------------------------------------------------
-# Core repair logic
-# ---------------------------------------------------------------------------
-
-_TRAILING_PUNCT_RE = re.compile(r'[.,;:!?\)\]]+$')
-
-
-def _try_hyphen_join(
- word_text: str,
- next_word_text: str,
- max_missing: int = 3,
-) -> Optional[Tuple[str, str, float]]:
- """Try joining two fragments with 0..max_missing interpolated chars.
-
- Strips trailing punctuation from the continuation word before testing
- (e.g. "künden," → "künden") so dictionary lookup succeeds.
-
- Returns (joined_word, missing_chars, confidence) or None.
- """
- base = word_text.rstrip("-").rstrip()
- # Strip trailing punctuation from continuation (commas, periods, etc.)
- raw_continuation = next_word_text.lstrip()
- continuation = _TRAILING_PUNCT_RE.sub('', raw_continuation)
-
- if not base or not continuation:
- return None
-
- # 1. Direct join (no missing chars)
- direct = base + continuation
- if _is_known(direct):
- return (direct, "", 0.95)
-
- # 2. Try with 1..max_missing missing characters
- # Use common letters, weighted by frequency in German/English
- _COMMON_CHARS = "enristaldhgcmobwfkzpvjyxqu"
-
- for n_missing in range(1, max_missing + 1):
- for chars in itertools.product(_COMMON_CHARS[:15], repeat=n_missing):
- candidate = base + "".join(chars) + continuation
- if _is_known(candidate):
- missing = "".join(chars)
- # Confidence decreases with more missing chars
- conf = 0.90 - (n_missing - 1) * 0.10
- return (candidate, missing, conf)
-
- return None
-
-
-def _try_spell_fix(
- word_text: str, col_type: str = "",
-) -> Optional[Tuple[str, float, List[str]]]:
- """Try to fix a single garbled gutter word via spellchecker.
-
- Returns (best_correction, confidence, alternatives_list) or None.
- The alternatives list contains other plausible corrections the user
- can choose from (e.g. "stammelt" vs "stammeln").
- """
- if len(word_text) < _MIN_WORD_LEN_SPELL:
- return None
-
- # Strip trailing/leading parentheses and check if the bare word is valid.
- # Words like "probieren)" or "(Englisch" are valid words with punctuation,
- # not OCR errors. Don't suggest corrections for them.
- stripped = word_text.strip("()")
- if stripped and _is_known(stripped):
- return None
-
- # Determine language priority from column type
- if "en" in col_type:
- lang = "en"
- elif "de" in col_type:
- lang = "de"
- else:
- lang = "both"
-
- candidates = _spell_candidates(word_text, lang=lang)
- if not candidates and lang != "both":
- candidates = _spell_candidates(word_text, lang="both")
-
- if not candidates:
- return None
-
- # Preserve original casing
- is_upper = word_text[0].isupper()
-
- def _preserve_case(w: str) -> str:
- if is_upper and w:
- return w[0].upper() + w[1:]
- return w
-
- # Sort candidates by edit distance (closest first)
- scored = []
- for c in candidates:
- dist = _edit_distance(word_text.lower(), c.lower())
- scored.append((dist, c))
- scored.sort(key=lambda x: x[0])
-
- best_dist, best = scored[0]
- best = _preserve_case(best)
- conf = max(0.5, 1.0 - best_dist * 0.15)
-
- # Build alternatives (all other candidates, also case-preserved)
- alts = [_preserve_case(c) for _, c in scored[1:] if c.lower() != best.lower()]
- # Limit to top 5 alternatives
- alts = alts[:5]
-
- return (best, conf, alts)
-
-
-def _edit_distance(a: str, b: str) -> int:
- """Simple Levenshtein distance."""
- if len(a) < len(b):
- return _edit_distance(b, a)
- if len(b) == 0:
- return len(a)
- prev = list(range(len(b) + 1))
- for i, ca in enumerate(a):
- curr = [i + 1]
- for j, cb in enumerate(b):
- cost = 0 if ca == cb else 1
- curr.append(min(curr[j] + 1, prev[j + 1] + 1, prev[j] + cost))
- prev = curr
- return prev[len(b)]
-
-
-# ---------------------------------------------------------------------------
-# Grid analysis
-# ---------------------------------------------------------------------------
-
-def analyse_grid_for_gutter_repair(
- grid_data: Dict[str, Any],
- image_width: int = 0,
-) -> Dict[str, Any]:
- """Analyse a structured grid and return gutter repair suggestions.
-
- Args:
- grid_data: The grid_editor_result from the session (zones→cells structure).
- image_width: Image width in pixels (for determining gutter side).
-
- Returns:
- Dict with "suggestions" list and "stats".
- """
- t0 = time.time()
- _init_spellcheckers()
-
- if not _SPELL_AVAILABLE:
- return {
- "suggestions": [],
- "stats": {"error": "pyspellchecker not installed"},
- "duration_seconds": 0,
- }
-
- zones = grid_data.get("zones", [])
- suggestions: List[GutterSuggestion] = []
- words_checked = 0
- gutter_candidates = 0
-
- for zi, zone in enumerate(zones):
- columns = zone.get("columns", [])
- cells = zone.get("cells", [])
- if not columns or not cells:
- continue
-
- # Build column lookup: col_index → {x, width, type}
- col_info: Dict[int, Dict] = {}
- for col in columns:
- ci = col.get("index", col.get("col_index", -1))
- col_info[ci] = {
- "x": col.get("x_min_px", col.get("x", 0)),
- "width": col.get("x_max_px", col.get("width", 0)) - col.get("x_min_px", col.get("x", 0)),
- "type": col.get("type", col.get("col_type", "")),
- }
-
- # Build row→col→cell lookup
- cell_map: Dict[Tuple[int, int], Dict] = {}
- max_row = 0
- for cell in cells:
- ri = cell.get("row_index", 0)
- ci = cell.get("col_index", 0)
- cell_map[(ri, ci)] = cell
- if ri > max_row:
- max_row = ri
-
- # Determine which columns are at the gutter edge.
- # For a left page: rightmost content columns.
- # For now, check ALL columns — a word is a candidate if it's at the
- # right edge of its column AND not a known word.
- for (ri, ci), cell in cell_map.items():
- text = (cell.get("text") or "").strip()
- if not text:
- continue
- if _is_ipa_text(text):
- continue
-
- words_checked += 1
- col = col_info.get(ci, {})
- col_type = col.get("type", "")
-
- # Get word boxes to check position
- word_boxes = cell.get("word_boxes", [])
-
- # Check the LAST word in the cell (rightmost, closest to gutter)
- cell_words = text.split()
- if not cell_words:
- continue
-
- last_word = cell_words[-1]
-
- # Skip stopwords
- if last_word.lower().rstrip(".,;:!?-") in _STOPWORDS:
- continue
-
- last_word_clean = last_word.rstrip(".,;:!?)(")
- if len(last_word_clean) < _MIN_WORD_LEN_HYPHEN:
- continue
-
- # Check if the last word is at the gutter edge
- is_at_edge = False
- if word_boxes:
- last_wb = word_boxes[-1]
- is_at_edge = _word_is_at_gutter_edge(
- last_wb, col.get("x", 0), col.get("width", 1)
- )
- else:
- # No word boxes — use cell bbox
- bbox = cell.get("bbox_px", {})
- is_at_edge = _word_is_at_gutter_edge(
- {"left": bbox.get("x", 0), "width": bbox.get("w", 0)},
- col.get("x", 0), col.get("width", 1)
- )
-
- if not is_at_edge:
- continue
-
- # Word is at gutter edge — check if it's a known word
- if _is_known(last_word_clean):
- continue
-
- # Check if the word ends with "-" (explicit hyphen break)
- ends_with_hyphen = last_word.endswith("-")
-
- # If the word already ends with "-" and the stem (without
- # the hyphen) is a known word, this is a VALID line-break
- # hyphenation — not a gutter error. Gutter problems cause
- # the hyphen to be LOST ("ve" instead of "ver-"), so a
- # visible hyphen + known stem = intentional word-wrap.
- # Example: "wunder-" → "wunder" is known → skip.
- if ends_with_hyphen:
- stem = last_word_clean.rstrip("-")
- if stem and _is_known(stem):
- continue
-
- gutter_candidates += 1
-
- # --- Strategy 1: Hyphen join with next row ---
- next_cell = cell_map.get((ri + 1, ci))
- if next_cell:
- next_text = (next_cell.get("text") or "").strip()
- next_words = next_text.split()
- if next_words:
- first_next = next_words[0]
- first_next_clean = _TRAILING_PUNCT_RE.sub('', first_next)
- first_alpha = next((c for c in first_next if c.isalpha()), "")
-
- # Also skip if the joined word is known (covers compound
- # words where the stem alone might not be in the dictionary)
- if ends_with_hyphen and first_next_clean:
- direct = last_word_clean.rstrip("-") + first_next_clean
- if _is_known(direct):
- continue
-
- # Continuation likely if:
- # - explicit hyphen, OR
- # - next row starts lowercase (= not a new entry)
- if ends_with_hyphen or (first_alpha and first_alpha.islower()):
- result = _try_hyphen_join(last_word_clean, first_next)
- if result:
- joined, missing, conf = result
- # Build display parts: show hyphenation for original layout
- if ends_with_hyphen:
- display_p1 = last_word_clean.rstrip("-")
- if missing:
- display_p1 += missing
- display_p1 += "-"
- else:
- display_p1 = last_word_clean
- if missing:
- display_p1 += missing + "-"
- else:
- display_p1 += "-"
-
- suggestion = GutterSuggestion(
- type="hyphen_join",
- zone_index=zi,
- row_index=ri,
- col_index=ci,
- col_type=col_type,
- cell_id=cell.get("cell_id", f"R{ri:02d}_C{ci}"),
- original_text=last_word,
- suggested_text=joined,
- next_row_index=ri + 1,
- next_row_cell_id=next_cell.get("cell_id", f"R{ri+1:02d}_C{ci}"),
- next_row_text=next_text,
- missing_chars=missing,
- display_parts=[display_p1, first_next],
- confidence=conf,
- reason="gutter_truncation" if missing else "hyphen_continuation",
- )
- suggestions.append(suggestion)
- continue # skip spell_fix if hyphen_join found
-
- # --- Strategy 2: Single-word spell fix (only for longer words) ---
- fix_result = _try_spell_fix(last_word_clean, col_type)
- if fix_result:
- corrected, conf, alts = fix_result
- suggestion = GutterSuggestion(
- type="spell_fix",
- zone_index=zi,
- row_index=ri,
- col_index=ci,
- col_type=col_type,
- cell_id=cell.get("cell_id", f"R{ri:02d}_C{ci}"),
- original_text=last_word,
- suggested_text=corrected,
- alternatives=alts,
- confidence=conf,
- reason="gutter_blur",
- )
- suggestions.append(suggestion)
-
- duration = round(time.time() - t0, 3)
-
- logger.info(
- "Gutter repair: checked %d words, %d gutter candidates, %d suggestions (%.2fs)",
- words_checked, gutter_candidates, len(suggestions), duration,
- )
-
- return {
- "suggestions": [s.to_dict() for s in suggestions],
- "stats": {
- "words_checked": words_checked,
- "gutter_candidates": gutter_candidates,
- "suggestions_found": len(suggestions),
- },
- "duration_seconds": duration,
- }
-
-
-def apply_gutter_suggestions(
- grid_data: Dict[str, Any],
- accepted_ids: List[str],
- suggestions: List[Dict[str, Any]],
-) -> Dict[str, Any]:
- """Apply accepted gutter repair suggestions to the grid data.
-
- Modifies cells in-place and returns summary of changes.
-
- Args:
- grid_data: The grid_editor_result (zones→cells).
- accepted_ids: List of suggestion IDs the user accepted.
- suggestions: The full suggestions list (from analyse_grid_for_gutter_repair).
-
- Returns:
- Dict with "applied_count" and "changes" list.
- """
- accepted_set = set(accepted_ids)
- accepted_suggestions = [s for s in suggestions if s.get("id") in accepted_set]
-
- zones = grid_data.get("zones", [])
- changes: List[Dict[str, Any]] = []
-
- for s in accepted_suggestions:
- zi = s.get("zone_index", 0)
- ri = s.get("row_index", 0)
- ci = s.get("col_index", 0)
- stype = s.get("type", "")
-
- if zi >= len(zones):
- continue
- zone_cells = zones[zi].get("cells", [])
-
- # Find the target cell
- target_cell = None
- for cell in zone_cells:
- if cell.get("row_index") == ri and cell.get("col_index") == ci:
- target_cell = cell
- break
-
- if not target_cell:
- continue
-
- old_text = target_cell.get("text", "")
-
- if stype == "spell_fix":
- # Replace the last word in the cell text
- original_word = s.get("original_text", "")
- corrected = s.get("suggested_text", "")
- if original_word and corrected:
- # Replace from the right (last occurrence)
- idx = old_text.rfind(original_word)
- if idx >= 0:
- new_text = old_text[:idx] + corrected + old_text[idx + len(original_word):]
- target_cell["text"] = new_text
- changes.append({
- "type": "spell_fix",
- "zone_index": zi,
- "row_index": ri,
- "col_index": ci,
- "cell_id": target_cell.get("cell_id", ""),
- "old_text": old_text,
- "new_text": new_text,
- })
-
- elif stype == "hyphen_join":
- # Current cell: replace last word with the hyphenated first part
- original_word = s.get("original_text", "")
- joined = s.get("suggested_text", "")
- display_parts = s.get("display_parts", [])
- next_ri = s.get("next_row_index", -1)
-
- if not original_word or not joined or not display_parts:
- continue
-
- # The first display part is what goes in the current row
- first_part = display_parts[0] if display_parts else ""
-
- # Replace the last word in current cell with the restored form.
- # The next row is NOT modified — "künden" stays in its row
- # because the original book layout has it there. We only fix
- # the truncated word in the current row (e.g. "ve" → "ver-").
- idx = old_text.rfind(original_word)
- if idx >= 0:
- new_text = old_text[:idx] + first_part + old_text[idx + len(original_word):]
- target_cell["text"] = new_text
- changes.append({
- "type": "hyphen_join",
- "zone_index": zi,
- "row_index": ri,
- "col_index": ci,
- "cell_id": target_cell.get("cell_id", ""),
- "old_text": old_text,
- "new_text": new_text,
- "joined_word": joined,
- })
-
- logger.info("Gutter repair applied: %d/%d suggestions", len(changes), len(accepted_suggestions))
-
- return {
- "applied_count": len(accepted_suggestions),
- "changes": changes,
- }
+# Core: spellchecker, data types, repair helpers
+from cv_gutter_repair_core import ( # noqa: F401
+ _init_spellcheckers,
+ _is_known,
+ _spell_candidates,
+ _MIN_WORD_LEN_SPELL,
+ _MIN_WORD_LEN_HYPHEN,
+ _GUTTER_EDGE_THRESHOLD,
+ _STOPWORDS,
+ _IPA_RE,
+ _is_ipa_text,
+ _word_is_at_gutter_edge,
+ GutterSuggestion,
+ _TRAILING_PUNCT_RE,
+ _try_hyphen_join,
+ _try_spell_fix,
+ _edit_distance,
+)
+
+# Grid: analysis and application
+from cv_gutter_repair_grid import ( # noqa: F401
+ analyse_grid_for_gutter_repair,
+ apply_gutter_suggestions,
+)
diff --git a/klausur-service/backend/cv_gutter_repair_core.py b/klausur-service/backend/cv_gutter_repair_core.py
new file mode 100644
index 0000000..4387e88
--- /dev/null
+++ b/klausur-service/backend/cv_gutter_repair_core.py
@@ -0,0 +1,275 @@
+"""
+Gutter Repair Core — spellchecker setup, data types, and single-word repair logic.
+
+Extracted from cv_gutter_repair.py for modularity.
+
+Lizenz: Apache 2.0 (kommerziell nutzbar)
+DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
+"""
+
+import itertools
+import logging
+import re
+import uuid
+from dataclasses import dataclass, field, asdict
+from typing import Any, Dict, List, Optional, Tuple
+
+logger = logging.getLogger(__name__)
+
+# ---------------------------------------------------------------------------
+# Spellchecker setup (lazy, cached)
+# ---------------------------------------------------------------------------
+
+_spell_de = None
+_spell_en = None
+_SPELL_AVAILABLE = False
+# Latched once an initialisation attempt has failed, so the import is not
+# retried (and the warning not re-logged) on every dictionary lookup.
+_SPELL_INIT_FAILED = False
+
+
+def _init_spellcheckers():
+    """Lazy-load the DE + EN spellcheckers (cached across calls).
+
+    On success the module globals ``_spell_de`` / ``_spell_en`` are set and
+    ``_SPELL_AVAILABLE`` becomes True. If pyspellchecker cannot be imported,
+    a warning is logged once, the globals stay untouched, and every later
+    call returns immediately.
+    """
+    global _spell_de, _spell_en, _SPELL_AVAILABLE, _SPELL_INIT_FAILED
+    if _spell_de is not None or _SPELL_INIT_FAILED:
+        return
+    try:
+        from spellchecker import SpellChecker
+        # Build both checkers before publishing either one, so a failure
+        # cannot leave the module half-initialised.
+        checker_de = SpellChecker(language='de', distance=1)
+        checker_en = SpellChecker(language='en', distance=1)
+    except ImportError:
+        _SPELL_INIT_FAILED = True
+        logger.warning("pyspellchecker not installed — gutter repair unavailable")
+        return
+    _spell_de, _spell_en = checker_de, checker_en
+    _SPELL_AVAILABLE = True
+    logger.info("Gutter repair: spellcheckers loaded (DE + EN)")
+
+
+def _is_known(word: str) -> bool:
+    """Tell whether the lower-cased ``word`` is in the DE or the EN dictionary.
+
+    Returns False when no spellchecker is available.
+    """
+    _init_spellcheckers()
+    if _SPELL_AVAILABLE:
+        lowered = word.lower()
+        # DE is consulted first; EN only if DE does not know the word.
+        return any(bool(chk.known([lowered])) for chk in (_spell_de, _spell_en))
+    return False
+
+
+def _spell_candidates(word: str, lang: str = "both") -> List[str]:
+    """Collect the distinct correction candidates the spellcheckers offer.
+
+    ``lang`` selects the dictionaries: "both" consults DE then EN, "de" only
+    DE, and any other value only EN. The lower-cased input itself is never
+    part of the result. An empty list is returned when no spellchecker is
+    available.
+    """
+    _init_spellcheckers()
+    if not _SPELL_AVAILABLE:
+        return []
+
+    if lang == "both":
+        checkers = [_spell_de, _spell_en]
+    elif lang == "de":
+        checkers = [_spell_de]
+    else:
+        checkers = [_spell_en]
+
+    lowered = word.lower()
+    collected: List[str] = []
+    already: set = set()
+    for checker in checkers:
+        if checker is None:
+            continue
+        # A falsy result from candidates() is treated as "no candidates".
+        for cand in checker.candidates(lowered) or ():
+            if not cand or cand == lowered or cand in already:
+                continue
+            already.add(cand)
+            collected.append(cand)
+    return collected
+
+
+# ---------------------------------------------------------------------------
+# Gutter position detection
+# ---------------------------------------------------------------------------
+
+# Minimum word length for spell-fix (very short words are often legitimate)
+_MIN_WORD_LEN_SPELL = 3
+
+# Minimum word length for hyphen-join candidates (fragments at the gutter
+# can be as short as 1-2 chars, e.g. "ve" from "ver-künden")
+_MIN_WORD_LEN_HYPHEN = 2
+
+# How close to the right column edge a word must be to count as "gutter-adjacent".
+# Expressed as a fraction of the column width (e.g. 0.70 = rightmost 30%).
+_GUTTER_EDGE_THRESHOLD = 0.70
+
+# Small common words / abbreviations that should NOT be repaired
+_STOPWORDS = frozenset([
+ # German
+ "ab", "an", "am", "da", "er", "es", "im", "in", "ja", "ob", "so", "um",
+ "zu", "wo", "du", "eh", "ei", "je", "na", "nu", "oh",
+ # English
+ "a", "am", "an", "as", "at", "be", "by", "do", "go", "he", "if", "in",
+ "is", "it", "me", "my", "no", "of", "on", "or", "so", "to", "up", "us",
+ "we",
+])
+
+# IPA / phonetic patterns — skip these cells
+_IPA_RE = re.compile(r'[\[\]/ˈˌːʃʒθðŋɑɒæɔəɛɪʊʌ]')
+
+
+def _is_ipa_text(text: str) -> bool:
+    """Report whether ``text`` contains any character matched by ``_IPA_RE``."""
+    return _IPA_RE.search(text) is not None
+
+
+def _word_is_at_gutter_edge(word_bbox: Dict, col_x: float, col_width: float) -> bool:
+    """Check if a word's right edge is near the right boundary of its column.
+
+    Args:
+        word_bbox: Box with optional "left" and "width" entries (missing
+            entries count as 0).
+        col_x: Left edge of the column.
+        col_width: Width of the column.
+
+    Returns:
+        True when the word's right edge, measured from the column's left edge
+        as a fraction of the column width, reaches ``_GUTTER_EDGE_THRESHOLD``.
+        A column with non-positive width never matches.
+    """
+    if col_width <= 0:
+        return False
+    word_right = word_bbox.get("left", 0) + word_bbox.get("width", 0)
+    # Position of the word's right edge as a fraction of the column width.
+    relative_pos = (word_right - col_x) / col_width
+    return relative_pos >= _GUTTER_EDGE_THRESHOLD
+
+
+# ---------------------------------------------------------------------------
+# Suggestion types
+# ---------------------------------------------------------------------------
+
+@dataclass
+class GutterSuggestion:
+    """A single correction suggestion for one grid cell.
+
+    Every field has a default, so instances can be built from keyword
+    arguments alone. ``id`` defaults to the first 8 characters of a fresh
+    UUID4 string.
+    """
+    id: str = field(default_factory=lambda: str(uuid.uuid4())[:8])
+    type: str = ""  # "hyphen_join" | "spell_fix"
+    zone_index: int = 0
+    row_index: int = 0
+    col_index: int = 0
+    col_type: str = ""
+    cell_id: str = ""
+    original_text: str = ""
+    suggested_text: str = ""
+    # For hyphen_join (the defaults below mean "not applicable"):
+    next_row_index: int = -1
+    next_row_cell_id: str = ""
+    next_row_text: str = ""
+    missing_chars: str = ""
+    display_parts: List[str] = field(default_factory=list)
+    # Alternatives (other plausible corrections the user can pick from)
+    alternatives: List[str] = field(default_factory=list)
+    # Meta:
+    confidence: float = 0.0
+    reason: str = ""  # "gutter_truncation" | "gutter_blur" | "hyphen_continuation"
+
+    def to_dict(self) -> Dict[str, Any]:
+        """Return all fields as a plain dict (via ``dataclasses.asdict``)."""
+        return asdict(self)
+
+
+# ---------------------------------------------------------------------------
+# Core repair logic
+# ---------------------------------------------------------------------------
+
+_TRAILING_PUNCT_RE = re.compile(r'[.,;:!?\)\]]+$')
+
+
+def _try_hyphen_join(
+    word_text: str,
+    next_word_text: str,
+    max_missing: int = 3,
+) -> Optional[Tuple[str, str, float]]:
+    """Try to merge a gutter fragment with the word that continues it.
+
+    Trailing hyphens/whitespace are removed from ``word_text`` and trailing
+    punctuation from ``next_word_text`` (e.g. "künden," → "künden") before
+    any dictionary lookup. The plain concatenation is tested first; after
+    that, 1..max_missing filler letters are inserted between the two parts.
+
+    Returns (joined_word, missing_chars, confidence) for the first known
+    word found, or None if there is none or either part is empty.
+    """
+    head = word_text.rstrip("-").rstrip()
+    tail = _TRAILING_PUNCT_RE.sub('', next_word_text.lstrip())
+    if not (head and tail):
+        return None
+
+    # Step 1: nothing was lost; the two parts already form a known word.
+    if _is_known(head + tail):
+        return (head + tail, "", 0.95)
+
+    # Step 2: brute-force the lost letters, shortest gap first. Only the
+    # first 15 letters of the string below are tried, in the order listed.
+    _COMMON_CHARS = "enristaldhgcmobwfkzpvjyxqu"
+    alphabet = _COMMON_CHARS[:15]
+    n_missing = 1
+    while n_missing <= max_missing:
+        for combo in itertools.product(alphabet, repeat=n_missing):
+            filler = "".join(combo)
+            if _is_known(head + filler + tail):
+                # Confidence drops by 0.10 for each additional missing char.
+                return (head + filler + tail, filler, 0.90 - (n_missing - 1) * 0.10)
+        n_missing += 1
+    return None
+
+
+def _try_spell_fix(
+    word_text: str, col_type: str = "",
+) -> Optional[Tuple[str, float, List[str]]]:
+    """Try to fix a single garbled gutter word via spellchecker.
+
+    Args:
+        word_text: The suspicious word; anything shorter than
+            ``_MIN_WORD_LEN_SPELL`` is rejected.
+        col_type: Column type. A value containing "en" or "de" restricts the
+            first lookup to that language; if that yields nothing, both
+            languages are consulted.
+
+    Returns:
+        (best_correction, confidence, alternatives_list) or None.
+        The alternatives list holds up to five other plausible corrections
+        the user can choose from (e.g. "stammelt" vs "stammeln").
+    """
+    if len(word_text) < _MIN_WORD_LEN_SPELL:
+        return None
+
+    # Strip trailing/leading parentheses and check if the bare word is valid.
+    # Words like "probieren)" or "(Englisch" are valid words with punctuation,
+    # not OCR errors. Don't suggest corrections for them.
+    stripped = word_text.strip("()")
+    if stripped and _is_known(stripped):
+        return None
+
+    # Determine language priority from column type
+    if "en" in col_type:
+        lang = "en"
+    elif "de" in col_type:
+        lang = "de"
+    else:
+        lang = "both"
+
+    candidates = _spell_candidates(word_text, lang=lang)
+    if not candidates and lang != "both":
+        candidates = _spell_candidates(word_text, lang="both")
+
+    if not candidates:
+        return None
+
+    # Preserve original casing
+    is_upper = word_text[0].isupper()
+
+    def _preserve_case(w: str) -> str:
+        if is_upper and w:
+            return w[0].upper() + w[1:]
+        return w
+
+    # Rank by edit distance (closest first), then alphabetically.
+    # pyspellchecker returns candidates as a set, so without the secondary
+    # key, equally distant candidates would come out in a different order
+    # from one process to the next and the "best" pick would not be
+    # reproducible.
+    lowered = word_text.lower()
+    scored = sorted((_edit_distance(lowered, c.lower()), c) for c in candidates)
+
+    best_dist, best = scored[0]
+    best = _preserve_case(best)
+    conf = max(0.5, 1.0 - best_dist * 0.15)
+
+    # Build alternatives (all other candidates, also case-preserved)
+    alts = [_preserve_case(c) for _, c in scored[1:] if c.lower() != best.lower()]
+    # Limit to top 5 alternatives
+    alts = alts[:5]
+
+    return (best, conf, alts)
+
+
+def _edit_distance(a: str, b: str) -> int:
+    """Levenshtein distance between ``a`` and ``b`` (unit-cost edits)."""
+    # Keep the shorter string on the inner axis; the row buffer is sized by it.
+    longer, shorter = (a, b) if len(a) >= len(b) else (b, a)
+    if not shorter:
+        return len(longer)
+    row = list(range(len(shorter) + 1))
+    for i, ch_long in enumerate(longer, start=1):
+        nxt = [i]
+        for j, ch_short in enumerate(shorter, start=1):
+            substitute = row[j - 1] + (0 if ch_long == ch_short else 1)
+            nxt.append(min(nxt[j - 1] + 1, row[j] + 1, substitute))
+        row = nxt
+    return row[-1]
diff --git a/klausur-service/backend/cv_gutter_repair_grid.py b/klausur-service/backend/cv_gutter_repair_grid.py
new file mode 100644
index 0000000..caf7c0f
--- /dev/null
+++ b/klausur-service/backend/cv_gutter_repair_grid.py
@@ -0,0 +1,356 @@
+"""
+Gutter Repair Grid — grid analysis and suggestion application.
+
+Extracted from cv_gutter_repair.py for modularity.
+
+Lizenz: Apache 2.0 (kommerziell nutzbar)
+DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
+"""
+
+import logging
+import time
+from typing import Any, Dict, List, Tuple
+
+from cv_gutter_repair_core import (
+ _init_spellcheckers,
+ _is_ipa_text,
+ _is_known,
+ _MIN_WORD_LEN_HYPHEN,
+ _SPELL_AVAILABLE,
+ _STOPWORDS,
+ _TRAILING_PUNCT_RE,
+ _try_hyphen_join,
+ _try_spell_fix,
+ _word_is_at_gutter_edge,
+ GutterSuggestion,
+)
+
+logger = logging.getLogger(__name__)
+
+
+# ---------------------------------------------------------------------------
+# Grid analysis
+# ---------------------------------------------------------------------------
+
+def _build_column_lookup(columns: List[Dict[str, Any]]) -> Dict[int, Dict[str, Any]]:
+    """Map column index -> {"x", "width", "type"} for the columns of one zone.
+
+    Two key styles are read: ``x_min_px``/``x_max_px`` (left and right edge)
+    and ``x``/``width`` (left edge plus extent).
+    """
+    lookup: Dict[int, Dict[str, Any]] = {}
+    for col in columns:
+        ci = col.get("index", col.get("col_index", -1))
+        x_left = col.get("x_min_px", col.get("x", 0))
+        if "x_max_px" in col:
+            # Edge-style column: width is the distance between both edges.
+            width = col["x_max_px"] - x_left
+        else:
+            # Extent-style column: "width" already is the width. Subtracting
+            # the left edge from it would yield a wrong (often negative) value.
+            width = col.get("width", 0)
+        lookup[ci] = {
+            "x": x_left,
+            "width": width,
+            "type": col.get("type", col.get("col_type", "")),
+        }
+    return lookup
+
+
+def analyse_grid_for_gutter_repair(
+    grid_data: Dict[str, Any],
+    image_width: int = 0,
+) -> Dict[str, Any]:
+    """Analyse a structured grid and return gutter repair suggestions.
+
+    For every non-empty, non-IPA cell the LAST word is inspected. It becomes
+    a gutter candidate when it sits at the right edge of its column and is
+    not a known word. For each candidate two strategies are tried in order:
+
+    1. ``hyphen_join`` -- join with the first word of the cell directly below
+       (same column, next row) if that looks like a continuation.
+    2. ``spell_fix``   -- single-word spelling correction.
+
+    Args:
+        grid_data: The grid_editor_result from the session (zones -> cells structure).
+        image_width: Image width in pixels. Currently not read by this
+            function; kept for interface compatibility.
+
+    Returns:
+        Dict with:
+          * "suggestions": list of ``GutterSuggestion.to_dict()`` results,
+          * "stats": ``words_checked``, ``gutter_candidates`` and
+            ``suggestions_found`` (or a single ``error`` entry when the
+            spellchecker is unavailable),
+          * "duration_seconds": wall-clock time of the analysis.
+    """
+    t0 = time.time()
+    _init_spellcheckers()
+
+    # NOTE(review): _SPELL_AVAILABLE is bound here via ``from ... import``.
+    # If the core module assigns it inside _init_spellcheckers() rather than
+    # at import time, this check sees the stale import-time value -- confirm.
+    if not _SPELL_AVAILABLE:
+        return {
+            "suggestions": [],
+            "stats": {"error": "pyspellchecker not installed"},
+            "duration_seconds": 0,
+        }
+
+    zones = grid_data.get("zones", [])
+    suggestions: List[GutterSuggestion] = []
+    words_checked = 0
+    gutter_candidates = 0
+
+    for zi, zone in enumerate(zones):
+        columns = zone.get("columns", [])
+        cells = zone.get("cells", [])
+        if not columns or not cells:
+            continue
+
+        # Column lookup: col_index -> {x, width, type}
+        col_info = _build_column_lookup(columns)
+
+        # (row, col) -> cell lookup, used to find the cell directly below.
+        cell_map: Dict[Tuple[int, int], Dict] = {}
+        for cell in cells:
+            cell_map[(cell.get("row_index", 0), cell.get("col_index", 0))] = cell
+
+        # All columns are checked: a word is a candidate if it is at the
+        # right edge of its column AND not a known word.
+        for (ri, ci), cell in cell_map.items():
+            text = (cell.get("text") or "").strip()
+            if not text:
+                continue
+            if _is_ipa_text(text):
+                continue
+
+            words_checked += 1
+            col = col_info.get(ci, {})
+            col_type = col.get("type", "")
+
+            # Word boxes (if any) give the position of the last word.
+            word_boxes = cell.get("word_boxes") or []
+
+            # Check the LAST word in the cell (rightmost, closest to gutter)
+            cell_words = text.split()
+            if not cell_words:
+                continue
+
+            last_word = cell_words[-1]
+
+            # Skip stopwords
+            if last_word.lower().rstrip(".,;:!?-") in _STOPWORDS:
+                continue
+
+            last_word_clean = last_word.rstrip(".,;:!?)(")
+            if len(last_word_clean) < _MIN_WORD_LEN_HYPHEN:
+                continue
+
+            # Check if the last word is at the gutter edge
+            if word_boxes:
+                is_at_edge = _word_is_at_gutter_edge(
+                    word_boxes[-1], col.get("x", 0), col.get("width", 1)
+                )
+            else:
+                # No word boxes -- fall back to the cell bounding box.
+                bbox = cell.get("bbox_px") or {}
+                is_at_edge = _word_is_at_gutter_edge(
+                    {"left": bbox.get("x", 0), "width": bbox.get("w", 0)},
+                    col.get("x", 0), col.get("width", 1)
+                )
+
+            if not is_at_edge:
+                continue
+
+            # Word is at gutter edge -- a known word needs no repair.
+            if _is_known(last_word_clean):
+                continue
+
+            # Check if the word ends with "-" (explicit hyphen break)
+            ends_with_hyphen = last_word.endswith("-")
+
+            # A visible hyphen plus a known stem is a VALID line-break
+            # hyphenation, not a gutter error: gutter problems cause the
+            # hyphen to be LOST ("ve" instead of "ver-").
+            # Example: "wunder-" -> "wunder" is known -> skip.
+            if ends_with_hyphen:
+                stem = last_word_clean.rstrip("-")
+                if stem and _is_known(stem):
+                    continue
+
+            gutter_candidates += 1
+
+            # --- Strategy 1: Hyphen join with next row ---
+            next_cell = cell_map.get((ri + 1, ci))
+            if next_cell:
+                next_text = (next_cell.get("text") or "").strip()
+                next_words = next_text.split()
+                if next_words:
+                    first_next = next_words[0]
+                    first_next_clean = _TRAILING_PUNCT_RE.sub('', first_next)
+                    first_alpha = next((c for c in first_next if c.isalpha()), "")
+
+                    # Also skip if the joined word is known (covers compound
+                    # words where the stem alone might not be in the dictionary)
+                    if ends_with_hyphen and first_next_clean:
+                        direct = last_word_clean.rstrip("-") + first_next_clean
+                        if _is_known(direct):
+                            continue
+
+                    # Continuation likely if:
+                    #   - explicit hyphen, OR
+                    #   - next row starts lowercase (= not a new entry)
+                    if ends_with_hyphen or (first_alpha and first_alpha.islower()):
+                        result = _try_hyphen_join(last_word_clean, first_next)
+                        if result:
+                            joined, missing, conf = result
+                            # Display form of the current row: the word stem,
+                            # any restored characters, then exactly one hyphen.
+                            if ends_with_hyphen:
+                                display_p1 = last_word_clean.rstrip("-")
+                            else:
+                                display_p1 = last_word_clean
+                            if missing:
+                                display_p1 += missing
+                            display_p1 += "-"
+
+                            suggestions.append(GutterSuggestion(
+                                type="hyphen_join",
+                                zone_index=zi,
+                                row_index=ri,
+                                col_index=ci,
+                                col_type=col_type,
+                                cell_id=cell.get("cell_id", f"R{ri:02d}_C{ci}"),
+                                original_text=last_word,
+                                suggested_text=joined,
+                                next_row_index=ri + 1,
+                                next_row_cell_id=next_cell.get("cell_id", f"R{ri+1:02d}_C{ci}"),
+                                next_row_text=next_text,
+                                missing_chars=missing,
+                                display_parts=[display_p1, first_next],
+                                confidence=conf,
+                                reason="gutter_truncation" if missing else "hyphen_continuation",
+                            ))
+                            continue  # skip spell_fix if hyphen_join found
+
+            # --- Strategy 2: Single-word spell fix ---
+            fix_result = _try_spell_fix(last_word_clean, col_type)
+            if fix_result:
+                corrected, conf, alts = fix_result
+                suggestions.append(GutterSuggestion(
+                    type="spell_fix",
+                    zone_index=zi,
+                    row_index=ri,
+                    col_index=ci,
+                    col_type=col_type,
+                    cell_id=cell.get("cell_id", f"R{ri:02d}_C{ci}"),
+                    original_text=last_word,
+                    suggested_text=corrected,
+                    alternatives=alts,
+                    confidence=conf,
+                    reason="gutter_blur",
+                ))
+
+    duration = round(time.time() - t0, 3)
+
+    logger.info(
+        "Gutter repair: checked %d words, %d gutter candidates, %d suggestions (%.2fs)",
+        words_checked, gutter_candidates, len(suggestions), duration,
+    )
+
+    return {
+        "suggestions": [s.to_dict() for s in suggestions],
+        "stats": {
+            "words_checked": words_checked,
+            "gutter_candidates": gutter_candidates,
+            "suggestions_found": len(suggestions),
+        },
+        "duration_seconds": duration,
+    }
+
+
+def apply_gutter_suggestions(
+    grid_data: Dict[str, Any],
+    accepted_ids: List[str],
+    suggestions: List[Dict[str, Any]],
+) -> Dict[str, Any]:
+    """Apply accepted gutter repair suggestions to the grid data.
+
+    Cells are modified in place. For both suggestion types only the LAST
+    occurrence of ``original_text`` in the target cell is replaced:
+
+    * ``spell_fix``   -- replaced with ``suggested_text``.
+    * ``hyphen_join`` -- replaced with ``display_parts[0]`` (the restored,
+      hyphenated first part, e.g. "ve" -> "ver-"). The next row is NOT
+      modified, so the original line layout is kept.
+
+    Suggestions that cannot be applied (zone index out of range, target cell
+    missing, required fields empty, original word not found in the cell,
+    unknown type) are skipped silently.
+
+    Args:
+        grid_data: The grid_editor_result (zones -> cells).
+        accepted_ids: List of suggestion IDs the user accepted.
+        suggestions: The full suggestions list (from analyse_grid_for_gutter_repair).
+
+    Returns:
+        Dict with "applied_count" (number of suggestions that actually
+        changed a cell) and "changes" (one entry per applied suggestion).
+    """
+    accepted_set = set(accepted_ids)
+    accepted_suggestions = [s for s in suggestions if s.get("id") in accepted_set]
+
+    zones = grid_data.get("zones", [])
+    changes: List[Dict[str, Any]] = []
+
+    for s in accepted_suggestions:
+        zi = s.get("zone_index", 0)
+        ri = s.get("row_index", 0)
+        ci = s.get("col_index", 0)
+        stype = s.get("type", "")
+        original_word = s.get("original_text", "")
+        suggested = s.get("suggested_text", "")
+
+        # Work out what replaces the last word, depending on suggestion type.
+        extra: Dict[str, Any] = {}
+        if stype == "spell_fix":
+            replacement = suggested
+        elif stype == "hyphen_join":
+            display_parts = s.get("display_parts", [])
+            if not display_parts:
+                continue
+            # The first display part is what goes in the current row.
+            replacement = display_parts[0]
+            extra["joined_word"] = suggested
+        else:
+            continue
+
+        if not original_word or not suggested:
+            continue
+
+        # Reject out-of-range indices explicitly; a negative index would
+        # otherwise silently address a zone counted from the end.
+        if not 0 <= zi < len(zones):
+            continue
+        zone_cells = zones[zi].get("cells", [])
+
+        # Find the target cell
+        target_cell = next(
+            (
+                c for c in zone_cells
+                if c.get("row_index") == ri and c.get("col_index") == ci
+            ),
+            None,
+        )
+        if target_cell is None:
+            continue
+
+        # Tolerate an explicit None text on the cell.
+        old_text = target_cell.get("text") or ""
+
+        # Replace from the right (last occurrence)
+        idx = old_text.rfind(original_word)
+        if idx < 0:
+            continue
+
+        new_text = old_text[:idx] + replacement + old_text[idx + len(original_word):]
+        target_cell["text"] = new_text
+
+        change: Dict[str, Any] = {
+            "type": stype,
+            "zone_index": zi,
+            "row_index": ri,
+            "col_index": ci,
+            "cell_id": target_cell.get("cell_id", ""),
+            "old_text": old_text,
+            "new_text": new_text,
+        }
+        change.update(extra)
+        changes.append(change)
+
+    logger.info("Gutter repair applied: %d/%d suggestions", len(changes), len(accepted_suggestions))
+
+    return {
+        # Count what was really applied, matching the log line above.
+        "applied_count": len(changes),
+        "changes": changes,
+    }
diff --git a/klausur-service/backend/cv_syllable_core.py b/klausur-service/backend/cv_syllable_core.py
new file mode 100644
index 0000000..4a4dca8
--- /dev/null
+++ b/klausur-service/backend/cv_syllable_core.py
@@ -0,0 +1,231 @@
+"""
+Syllable Core — hyphenator init, word validation, pipe autocorrect.
+
+Extracted from cv_syllable_detect.py for modularity.
+
+Lizenz: Apache 2.0 (kommerziell nutzbar)
+DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
+"""
+
+import logging
+import re
+from typing import Any, Dict, List, Optional, Tuple
+
+logger = logging.getLogger(__name__)
+
+# IPA/phonetic characters -- skip cells containing these
+_IPA_RE = re.compile(r'[\[\]\u02c8\u02cc\u02d0\u0283\u0292\u03b8\u00f0\u014b\u0251\u0252\u00e6\u0254\u0259\u025b\u025c\u026a\u028a\u028c]')
+
+# Common German words that should NOT be merged with adjacent tokens.
+_STOP_WORDS = frozenset([
+ # Articles
+ 'der', 'die', 'das', 'dem', 'den', 'des',
+ 'ein', 'eine', 'einem', 'einen', 'einer',
+ # Pronouns
+ 'du', 'er', 'es', 'sie', 'wir', 'ihr', 'ich', 'man', 'sich',
+ 'dich', 'dir', 'mich', 'mir', 'uns', 'euch', 'ihm', 'ihn',
+ # Prepositions
+ 'mit', 'von', 'zu', 'f\u00fcr', 'auf', 'in', 'an', 'um', 'am', 'im',
+ 'aus', 'bei', 'nach', 'vor', 'bis', 'durch', '\u00fcber', 'unter',
+ 'zwischen', 'ohne', 'gegen',
+ # Conjunctions
+ 'und', 'oder', 'als', 'wie', 'wenn', 'dass', 'weil', 'aber',
+ # Adverbs
+ 'auch', 'noch', 'nur', 'schon', 'sehr', 'nicht',
+ # Verbs
+ 'ist', 'hat', 'wird', 'kann', 'soll', 'muss', 'darf',
+ 'sein', 'haben',
+ # Other
+ 'kein', 'keine', 'keinem', 'keinen', 'keiner',
+])
+
+# Cached hyphenators
+_hyph_de = None
+_hyph_en = None
+
+# Cached spellchecker (for autocorrect_pipe_artifacts)
+_spell_de = None
+
+
+def _get_hyphenators():
+    """Return the cached ``(German, English)`` pyphen hyphenators.
+
+    The hyphenators are created on first use and stored in the module
+    globals ``_hyph_de`` / ``_hyph_en``. Returns ``(None, None)`` when
+    ``pyphen`` cannot be imported.
+    """
+    global _hyph_de, _hyph_en
+    if _hyph_de is None:
+        try:
+            import pyphen
+        except ImportError:
+            return None, None
+        _hyph_de = pyphen.Pyphen(lang='de_DE')
+        _hyph_en = pyphen.Pyphen(lang='en_US')
+    return _hyph_de, _hyph_en
+
+
+def _get_spellchecker():
+    """Return the cached German ``SpellChecker`` instance.
+
+    The instance is created on first use and stored in the module global
+    ``_spell_de``. Returns ``None`` when ``spellchecker`` cannot be imported.
+    """
+    global _spell_de
+    if _spell_de is None:
+        try:
+            from spellchecker import SpellChecker
+        except ImportError:
+            return None
+        _spell_de = SpellChecker(language='de')
+    return _spell_de
+
+
+def _is_known_word(word: str, hyph_de, hyph_en) -> bool:
+    """Tell whether pyphen finds a hyphenation point in ``word`` (DE or EN).
+
+    Words shorter than two characters are never considered known. The
+    German hyphenator is asked first; the English one only if German
+    yields no split.
+    """
+    if len(word) < 2:
+        return False
+    for hyphenator in (hyph_de, hyph_en):
+        if '|' in hyphenator.inserted(word, hyphen='|'):
+            return True
+    return False
+
+
+def _is_real_word(word: str) -> bool:
+    """Tell whether the spellchecker knows ``word`` (compared in lower case).
+
+    Returns ``False`` when no spellchecker is available.
+    """
+    checker = _get_spellchecker()
+    if checker is not None:
+        return word.lower() in checker
+    return False
+
+
+def _hyphenate_word(word: str, hyph_de, hyph_en) -> Optional[str]:
+    """Hyphenate ``word`` with the German, then the English hyphenator.
+
+    Returns the word with ``|`` separators from the first hyphenator that
+    produces at least one split, or ``None`` if neither does.
+    """
+    for hyphenator in (hyph_de, hyph_en):
+        split = hyphenator.inserted(word, hyphen='|')
+        if '|' in split:
+            return split
+    return None
+
+
+def _autocorrect_piped_word(word_with_pipes: str) -> Optional[str]:
+    """Try to repair a word that contains OCR pipe artifacts.
+
+    Printed syllable divider lines on dictionary pages confuse OCR: the
+    vertical stroke shows up as ``|`` and is often also read as an extra
+    stroke-like character (``l``, ``I``, ``1``, ``i``, ``t``).
+
+    Steps, first hit wins:
+      1. Remove every ``|``. Results shorter than 3 characters are returned
+         as-is (too short to validate); a result the spellchecker knows is
+         returned unchanged.
+      2. Delete one stroke-like character at a time and return the first
+         variant (length >= 3) the spellchecker knows.
+      3. Ask the spellchecker's ``correction()`` for the lower-cased word;
+         a differing suggestion is returned with the original first-letter
+         capitalisation restored.
+
+    Returns ``None`` when no step produced a fix.
+    """
+    bare = word_with_pipes.replace('|', '')
+    if len(bare) < 3:
+        return bare  # too short to validate
+
+    # Step 1: already a real word once the pipes are gone
+    if _is_real_word(bare):
+        return bare
+
+    # Step 2: drop one pipe-like character at a time, left to right
+    pipe_like = frozenset('lI1it')
+    for pos, ch in enumerate(bare):
+        if ch in pipe_like:
+            shorter = bare[:pos] + bare[pos + 1:]
+            if len(shorter) >= 3 and _is_real_word(shorter):
+                return shorter
+
+    # Step 3: spellchecker's built-in correction
+    checker = _get_spellchecker()
+    if checker is None:
+        return None
+    lowered = bare.lower()
+    guess = checker.correction(lowered)
+    if not guess or guess == lowered:
+        return None  # could not fix
+    if bare[0].isupper():
+        # Preserve original first-letter case
+        guess = guess[0].upper() + guess[1:]
+    return guess
+
+
+def autocorrect_pipe_artifacts(
+    zones_data: List[Dict], session_id: str,
+) -> int:
+    """Strip OCR pipe artifacts and correct garbled words in-place.
+
+    Printed syllable divider lines on dictionary scans are read by OCR
+    as ``|`` characters embedded in words (e.g. ``Zel|le``, ``Ze|plpe|lin``).
+    Only cells whose ``col_type`` starts with ``column_`` are processed.
+    For each such cell:
+
+    1. Every word box containing ``|`` is split into leading non-letters,
+       core and trailing non-letters; the core is passed to
+       ``_autocorrect_piped_word`` and the word box text is updated when
+       that returns a different string.
+    2. If any word box changed, the cell text is rebuilt by joining all
+       word-box texts with single spaces.
+    3. Any ``|`` still left in the cell text is removed (covers cells
+       without word boxes or with words that could not be corrected).
+
+    Args:
+        zones_data: Zone dicts, each with a "cells" list; modified in place.
+        session_id: Only used in the log message.
+
+    Returns:
+        The number of cells modified.
+    """
+    if _get_spellchecker() is None:
+        # Pipes are still stripped from cell text below; only the
+        # dictionary-based word correction is unavailable.
+        logger.warning("spellchecker not available -- pipe autocorrect limited")
+
+    # Compiled once instead of once per word box:
+    # leading non-letters, core, trailing non-letters.
+    word_parts_re = re.compile(
+        r'^([^a-zA-Z\u00e4\u00f6\u00fc\u00c4\u00d6\u00dc\u00df\u1e9e]*)'
+        r'(.*?)'
+        r'([^a-zA-Z\u00e4\u00f6\u00fc\u00c4\u00d6\u00dc\u00df\u1e9e]*)$'
+    )
+
+    modified = 0
+    for z in zones_data:
+        for cell in z.get("cells", []):
+            # ``or ""`` / ``or []`` guard against explicit None values, as the
+            # text rebuild below already does for word-box text.
+            ct = cell.get("col_type") or ""
+            if not ct.startswith("column_"):
+                continue
+
+            word_boxes = cell.get("word_boxes") or []
+            cell_changed = False
+
+            # --- Fix word boxes ---
+            for wb in word_boxes:
+                wb_text = wb.get("text") or ""
+                if "|" not in wb_text:
+                    continue
+
+                # Separate leading/trailing punctuation from the word core
+                m = word_parts_re.match(wb_text)
+                if not m:
+                    continue
+                lead, core, trail = m.group(1), m.group(2), m.group(3)
+                if "|" not in core:
+                    continue
+
+                corrected = _autocorrect_piped_word(core)
+                if corrected is not None and corrected != core:
+                    wb["text"] = lead + corrected + trail
+                    cell_changed = True
+
+            # --- Rebuild cell text from word boxes ---
+            if cell_changed:
+                if word_boxes:
+                    cell["text"] = " ".join(
+                        (wb.get("text") or "") for wb in word_boxes
+                    )
+                modified += 1
+
+            # --- Fallback: strip residual | from cell text ---
+            text = cell.get("text") or ""
+            if "|" in text:
+                cell["text"] = text.replace("|", "")
+                # Count the cell only once, even if both steps changed it.
+                if not cell_changed:
+                    modified += 1
+
+    if modified:
+        logger.info(
+            "build-grid session %s: autocorrected pipe artifacts in %d cells",
+            session_id, modified,
+        )
+    return modified
diff --git a/klausur-service/backend/cv_syllable_detect.py b/klausur-service/backend/cv_syllable_detect.py
index 65e0ae9..fe2b003 100644
--- a/klausur-service/backend/cv_syllable_detect.py
+++ b/klausur-service/backend/cv_syllable_detect.py
@@ -1,532 +1,32 @@
"""
-Syllable divider insertion for dictionary pages.
+Syllable divider insertion for dictionary pages — barrel re-export.
-For confirmed dictionary pages (is_dictionary=True), processes all content
-column cells:
- 1. Strips existing | dividers for clean normalization
- 2. Merges pipe-gap spaces (where OCR split a word at a divider position)
- 3. Applies pyphen syllabification to each word >= 3 alpha chars (DE then EN)
- 4. Only modifies words that pyphen recognizes — garbled OCR stays as-is
-
-No CV gate needed — the dictionary detection confidence is sufficient.
-pyphen uses Hunspell/TeX hyphenation dictionaries and is very reliable.
+All implementation split into:
+ cv_syllable_core — hyphenator init, word validation, pipe autocorrect
+ cv_syllable_merge — word gap merging, syllabification, divider insertion
Lizenz: Apache 2.0 (kommerziell nutzbar)
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
"""
-import logging
-import re
-from typing import Any, Dict, List, Optional, Tuple
-
-import numpy as np
-
-logger = logging.getLogger(__name__)
-
-# IPA/phonetic characters — skip cells containing these
-_IPA_RE = re.compile(r'[\[\]ˈˌːʃʒθðŋɑɒæɔəɛɜɪʊʌ]')
-
-# Common German words that should NOT be merged with adjacent tokens.
-# These are function words that appear as standalone words between
-# headwords/definitions on dictionary pages.
-_STOP_WORDS = frozenset([
- # Articles
- 'der', 'die', 'das', 'dem', 'den', 'des',
- 'ein', 'eine', 'einem', 'einen', 'einer',
- # Pronouns
- 'du', 'er', 'es', 'sie', 'wir', 'ihr', 'ich', 'man', 'sich',
- 'dich', 'dir', 'mich', 'mir', 'uns', 'euch', 'ihm', 'ihn',
- # Prepositions
- 'mit', 'von', 'zu', 'für', 'auf', 'in', 'an', 'um', 'am', 'im',
- 'aus', 'bei', 'nach', 'vor', 'bis', 'durch', 'über', 'unter',
- 'zwischen', 'ohne', 'gegen',
- # Conjunctions
- 'und', 'oder', 'als', 'wie', 'wenn', 'dass', 'weil', 'aber',
- # Adverbs
- 'auch', 'noch', 'nur', 'schon', 'sehr', 'nicht',
- # Verbs
- 'ist', 'hat', 'wird', 'kann', 'soll', 'muss', 'darf',
- 'sein', 'haben',
- # Other
- 'kein', 'keine', 'keinem', 'keinen', 'keiner',
-])
-
-# Cached hyphenators
-_hyph_de = None
-_hyph_en = None
-
-# Cached spellchecker (for autocorrect_pipe_artifacts)
-_spell_de = None
-
-
-def _get_hyphenators():
- """Lazy-load pyphen hyphenators (cached across calls)."""
- global _hyph_de, _hyph_en
- if _hyph_de is not None:
- return _hyph_de, _hyph_en
- try:
- import pyphen
- except ImportError:
- return None, None
- _hyph_de = pyphen.Pyphen(lang='de_DE')
- _hyph_en = pyphen.Pyphen(lang='en_US')
- return _hyph_de, _hyph_en
-
-
-def _get_spellchecker():
- """Lazy-load German spellchecker (cached across calls)."""
- global _spell_de
- if _spell_de is not None:
- return _spell_de
- try:
- from spellchecker import SpellChecker
- except ImportError:
- return None
- _spell_de = SpellChecker(language='de')
- return _spell_de
-
-
-def _is_known_word(word: str, hyph_de, hyph_en) -> bool:
- """Check whether pyphen recognises a word (DE or EN)."""
- if len(word) < 2:
- return False
- return ('|' in hyph_de.inserted(word, hyphen='|')
- or '|' in hyph_en.inserted(word, hyphen='|'))
-
-
-def _is_real_word(word: str) -> bool:
- """Check whether spellchecker knows this word (case-insensitive)."""
- spell = _get_spellchecker()
- if spell is None:
- return False
- return word.lower() in spell
-
-
-def _hyphenate_word(word: str, hyph_de, hyph_en) -> Optional[str]:
- """Try to hyphenate a word using DE then EN dictionary.
-
- Returns word with | separators, or None if not recognized.
- """
- hyph = hyph_de.inserted(word, hyphen='|')
- if '|' in hyph:
- return hyph
- hyph = hyph_en.inserted(word, hyphen='|')
- if '|' in hyph:
- return hyph
- return None
-
-
-def _autocorrect_piped_word(word_with_pipes: str) -> Optional[str]:
- """Try to correct a word that has OCR pipe artifacts.
-
- Printed syllable divider lines on dictionary pages confuse OCR:
- the vertical stroke is often read as an extra character (commonly
- ``l``, ``I``, ``1``, ``i``) adjacent to where the pipe appears.
- Sometimes OCR reads one divider as ``|`` and another as a letter,
- so the garbled character may be far from any detected pipe.
-
- Uses ``spellchecker`` (frequency-based word list) for validation —
- unlike pyphen which is a pattern-based hyphenator and accepts
- nonsense strings like "Zeplpelin".
-
- Strategy:
- 1. Strip ``|`` — if spellchecker knows the result, done.
- 2. Try deleting each pipe-like character (l, I, 1, i, t).
- OCR inserts extra chars that resemble vertical strokes.
- 3. Fall back to spellchecker's own ``correction()`` method.
- 4. Preserve the original casing of the first letter.
- """
- stripped = word_with_pipes.replace('|', '')
- if not stripped or len(stripped) < 3:
- return stripped # too short to validate
-
- # Step 1: if the stripped word is already a real word, done
- if _is_real_word(stripped):
- return stripped
-
- # Step 2: try deleting pipe-like characters (most likely artifacts)
- _PIPE_LIKE = frozenset('lI1it')
- for idx in range(len(stripped)):
- if stripped[idx] not in _PIPE_LIKE:
- continue
- candidate = stripped[:idx] + stripped[idx + 1:]
- if len(candidate) >= 3 and _is_real_word(candidate):
- return candidate
-
- # Step 3: use spellchecker's built-in correction
- spell = _get_spellchecker()
- if spell is not None:
- suggestion = spell.correction(stripped.lower())
- if suggestion and suggestion != stripped.lower():
- # Preserve original first-letter case
- if stripped[0].isupper():
- suggestion = suggestion[0].upper() + suggestion[1:]
- return suggestion
-
- return None # could not fix
-
-
-def autocorrect_pipe_artifacts(
- zones_data: List[Dict], session_id: str,
-) -> int:
- """Strip OCR pipe artifacts and correct garbled words in-place.
-
- Printed syllable divider lines on dictionary scans are read by OCR
- as ``|`` characters embedded in words (e.g. ``Zel|le``, ``Ze|plpe|lin``).
- This function:
-
- 1. Strips ``|`` from every word in content cells.
- 2. Validates with spellchecker (real dictionary lookup).
- 3. If not recognised, tries deleting pipe-like characters or uses
- spellchecker's correction (e.g. ``Zeplpelin`` → ``Zeppelin``).
- 4. Updates both word-box texts and cell text.
-
- Returns the number of cells modified.
- """
- spell = _get_spellchecker()
- if spell is None:
- logger.warning("spellchecker not available — pipe autocorrect limited")
- # Fall back: still strip pipes even without spellchecker
- pass
-
- modified = 0
- for z in zones_data:
- for cell in z.get("cells", []):
- ct = cell.get("col_type", "")
- if not ct.startswith("column_"):
- continue
-
- cell_changed = False
-
- # --- Fix word boxes ---
- for wb in cell.get("word_boxes", []):
- wb_text = wb.get("text", "")
- if "|" not in wb_text:
- continue
-
- # Separate trailing punctuation
- m = re.match(
- r'^([^a-zA-ZäöüÄÖÜßẞ]*)'
- r'(.*?)'
- r'([^a-zA-ZäöüÄÖÜßẞ]*)$',
- wb_text,
- )
- if not m:
- continue
- lead, core, trail = m.group(1), m.group(2), m.group(3)
- if "|" not in core:
- continue
-
- corrected = _autocorrect_piped_word(core)
- if corrected is not None and corrected != core:
- wb["text"] = lead + corrected + trail
- cell_changed = True
-
- # --- Rebuild cell text from word boxes ---
- if cell_changed:
- wbs = cell.get("word_boxes", [])
- if wbs:
- cell["text"] = " ".join(
- (wb.get("text") or "") for wb in wbs
- )
- modified += 1
-
- # --- Fallback: strip residual | from cell text ---
- # (covers cases where word_boxes don't exist or weren't fixed)
- text = cell.get("text", "")
- if "|" in text:
- clean = text.replace("|", "")
- if clean != text:
- cell["text"] = clean
- if not cell_changed:
- modified += 1
-
- if modified:
- logger.info(
- "build-grid session %s: autocorrected pipe artifacts in %d cells",
- session_id, modified,
- )
- return modified
-
-
-def _try_merge_pipe_gaps(text: str, hyph_de) -> str:
- """Merge fragments separated by single spaces where OCR split at a pipe.
-
- Example: "Kaf fee" -> "Kaffee" (pyphen recognizes the merged word).
- Multi-step: "Ka bel jau" -> "Kabel jau" -> "Kabeljau".
-
- Guards against false merges:
- - The FIRST token must be pure alpha (word start — no attached punctuation)
- - The second token may have trailing punctuation (comma, period) which
- stays attached to the merged word: "Kä" + "fer," -> "Käfer,"
- - Common German function words (der, die, das, ...) are never merged
- - At least one fragment must be very short (<=3 alpha chars)
- """
- parts = text.split(' ')
- if len(parts) < 2:
- return text
-
- result = [parts[0]]
- i = 1
- while i < len(parts):
- prev = result[-1]
- curr = parts[i]
-
- # Extract alpha-only core for lookup
- prev_alpha = re.sub(r'[^a-zA-ZäöüÄÖÜßẞ]', '', prev)
- curr_alpha = re.sub(r'[^a-zA-ZäöüÄÖÜßẞ]', '', curr)
-
- # Guard 1: first token must be pure alpha (word-start fragment)
- # second token may have trailing punctuation
- # Guard 2: neither alpha core can be a common German function word
- # Guard 3: the shorter fragment must be <= 3 chars (pipe-gap signal)
- # Guard 4: combined length must be >= 4
- should_try = (
- prev == prev_alpha # first token: pure alpha (word start)
- and prev_alpha and curr_alpha
- and prev_alpha.lower() not in _STOP_WORDS
- and curr_alpha.lower() not in _STOP_WORDS
- and min(len(prev_alpha), len(curr_alpha)) <= 3
- and len(prev_alpha) + len(curr_alpha) >= 4
- )
-
- if should_try:
- merged_alpha = prev_alpha + curr_alpha
- hyph = hyph_de.inserted(merged_alpha, hyphen='-')
- if '-' in hyph:
- # pyphen recognizes merged word — collapse the space
- result[-1] = prev + curr
- i += 1
- continue
-
- result.append(curr)
- i += 1
-
- return ' '.join(result)
-
-
-def merge_word_gaps_in_zones(zones_data: List[Dict], session_id: str) -> int:
- """Merge OCR word-gap fragments in cell texts using pyphen validation.
-
- OCR often splits words at syllable boundaries into separate word_boxes,
- producing text like "zerknit tert" instead of "zerknittert". This
- function tries to merge adjacent fragments in every content cell.
-
- More permissive than ``_try_merge_pipe_gaps`` (threshold 5 instead of 3)
- but still guarded by pyphen dictionary lookup and stop-word exclusion.
-
- Returns the number of cells modified.
- """
- hyph_de, _ = _get_hyphenators()
- if hyph_de is None:
- return 0
-
- modified = 0
- for z in zones_data:
- for cell in z.get("cells", []):
- ct = cell.get("col_type", "")
- if not ct.startswith("column_"):
- continue
- text = cell.get("text", "")
- if not text or " " not in text:
- continue
-
- # Skip IPA cells
- text_no_brackets = re.sub(r'\[[^\]]*\]', '', text)
- if _IPA_RE.search(text_no_brackets):
- continue
-
- new_text = _try_merge_word_gaps(text, hyph_de)
- if new_text != text:
- cell["text"] = new_text
- modified += 1
-
- if modified:
- logger.info(
- "build-grid session %s: merged word gaps in %d cells",
- session_id, modified,
- )
- return modified
-
-
-def _try_merge_word_gaps(text: str, hyph_de) -> str:
- """Merge OCR word fragments with relaxed threshold (max_short=5).
-
- Similar to ``_try_merge_pipe_gaps`` but allows slightly longer fragments
- (max_short=5 instead of 3). Still requires pyphen to recognize the
- merged word.
- """
- parts = text.split(' ')
- if len(parts) < 2:
- return text
-
- result = [parts[0]]
- i = 1
- while i < len(parts):
- prev = result[-1]
- curr = parts[i]
-
- prev_alpha = re.sub(r'[^a-zA-ZäöüÄÖÜßẞ]', '', prev)
- curr_alpha = re.sub(r'[^a-zA-ZäöüÄÖÜßẞ]', '', curr)
-
- should_try = (
- prev == prev_alpha
- and prev_alpha and curr_alpha
- and prev_alpha.lower() not in _STOP_WORDS
- and curr_alpha.lower() not in _STOP_WORDS
- and min(len(prev_alpha), len(curr_alpha)) <= 5
- and len(prev_alpha) + len(curr_alpha) >= 4
- )
-
- if should_try:
- merged_alpha = prev_alpha + curr_alpha
- hyph = hyph_de.inserted(merged_alpha, hyphen='-')
- if '-' in hyph:
- result[-1] = prev + curr
- i += 1
- continue
-
- result.append(curr)
- i += 1
-
- return ' '.join(result)
-
-
-def _syllabify_text(text: str, hyph_de, hyph_en) -> str:
- """Syllabify all significant words in a text string.
-
- 1. Strip existing | dividers
- 2. Merge pipe-gap spaces where possible
- 3. Apply pyphen to each word >= 3 alphabetic chars
- 4. Words pyphen doesn't recognize stay as-is (no bad guesses)
- """
- if not text:
- return text
-
- # Skip cells that contain IPA transcription characters outside brackets.
- # Bracket content like [bɪltʃøn] is programmatically inserted and should
- # not block syllabification of the surrounding text.
- text_no_brackets = re.sub(r'\[[^\]]*\]', '', text)
- if _IPA_RE.search(text_no_brackets):
- return text
-
- # Phase 1: strip existing pipe dividers for clean normalization
- clean = text.replace('|', '')
-
- # Phase 2: merge pipe-gap spaces (OCR fragments from pipe splitting)
- clean = _try_merge_pipe_gaps(clean, hyph_de)
-
- # Phase 3: tokenize and syllabify each word
- # Split on whitespace and comma/semicolon sequences, keeping separators
- tokens = re.split(r'(\s+|[,;:]+\s*)', clean)
-
- result = []
- for tok in tokens:
- if not tok or re.match(r'^[\s,;:]+$', tok):
- result.append(tok)
- continue
-
- # Strip trailing/leading punctuation for pyphen lookup
- m = re.match(r'^([^a-zA-ZäöüÄÖÜßẞ]*)(.*?)([^a-zA-ZäöüÄÖÜßẞ]*)$', tok)
- if not m:
- result.append(tok)
- continue
- lead, word, trail = m.group(1), m.group(2), m.group(3)
-
- if len(word) < 3 or not re.search(r'[a-zA-ZäöüÄÖÜß]', word):
- result.append(tok)
- continue
-
- hyph = _hyphenate_word(word, hyph_de, hyph_en)
- if hyph:
- result.append(lead + hyph + trail)
- else:
- result.append(tok)
-
- return ''.join(result)
-
-
-def insert_syllable_dividers(
- zones_data: List[Dict],
- img_bgr: np.ndarray,
- session_id: str,
- *,
- force: bool = False,
- col_filter: Optional[set] = None,
-) -> int:
- """Insert pipe syllable dividers into dictionary cells.
-
- For dictionary pages: process all content column cells, strip existing
- pipes, merge pipe-gap spaces, and re-syllabify using pyphen.
-
- Pre-check: at least 1% of content cells must already contain ``|`` from
- OCR. This guards against pages with zero pipe characters (the primary
- guard — article_col_index — is checked at the call site).
-
- Args:
- force: If True, skip the pipe-ratio pre-check and syllabify all
- content words regardless of whether the original has pipe dividers.
- col_filter: If set, only process cells whose col_type is in this set.
- None means process all content columns.
-
- Returns the number of cells modified.
- """
- hyph_de, hyph_en = _get_hyphenators()
- if hyph_de is None:
- logger.warning("pyphen not installed — skipping syllable insertion")
- return 0
-
- # Pre-check: count cells that already have | from OCR.
- # Real dictionary pages with printed syllable dividers will have OCR-
- # detected pipes in many cells. Pages without syllable dividers will
- # have zero — skip those to avoid false syllabification.
- if not force:
- total_col_cells = 0
- cells_with_pipes = 0
- for z in zones_data:
- for cell in z.get("cells", []):
- if cell.get("col_type", "").startswith("column_"):
- total_col_cells += 1
- if "|" in cell.get("text", ""):
- cells_with_pipes += 1
-
- if total_col_cells > 0:
- pipe_ratio = cells_with_pipes / total_col_cells
- if pipe_ratio < 0.01:
- logger.info(
- "build-grid session %s: skipping syllable insertion — "
- "only %.1f%% of cells have existing pipes (need >=1%%)",
- session_id, pipe_ratio * 100,
- )
- return 0
-
- insertions = 0
- for z in zones_data:
- for cell in z.get("cells", []):
- ct = cell.get("col_type", "")
- if not ct.startswith("column_"):
- continue
- if col_filter is not None and ct not in col_filter:
- continue
- text = cell.get("text", "")
- if not text:
- continue
-
- # In auto mode (force=False), only normalize cells that already
- # have | from OCR (i.e. printed syllable dividers on the original
- # scan). Don't add new syllable marks to other words.
- if not force and "|" not in text:
- continue
-
- new_text = _syllabify_text(text, hyph_de, hyph_en)
- if new_text != text:
- cell["text"] = new_text
- insertions += 1
-
- if insertions:
- logger.info(
- "build-grid session %s: syllable dividers inserted/normalized "
- "in %d cells (pyphen)",
- session_id, insertions,
- )
- return insertions
+# Core: init, validation, autocorrect
+from cv_syllable_core import ( # noqa: F401
+ _IPA_RE,
+ _STOP_WORDS,
+ _get_hyphenators,
+ _get_spellchecker,
+ _is_known_word,
+ _is_real_word,
+ _hyphenate_word,
+ _autocorrect_piped_word,
+ autocorrect_pipe_artifacts,
+)
+
+# Merge: gap merging, syllabify, insert
+from cv_syllable_merge import ( # noqa: F401
+ _try_merge_pipe_gaps,
+ merge_word_gaps_in_zones,
+ _try_merge_word_gaps,
+ _syllabify_text,
+ insert_syllable_dividers,
+)
diff --git a/klausur-service/backend/cv_syllable_merge.py b/klausur-service/backend/cv_syllable_merge.py
new file mode 100644
index 0000000..3684210
--- /dev/null
+++ b/klausur-service/backend/cv_syllable_merge.py
@@ -0,0 +1,300 @@
+"""
+Syllable Merge — word gap merging, syllabification, divider insertion.
+
+Extracted from cv_syllable_detect.py for modularity.
+
+Lizenz: Apache 2.0 (kommerziell nutzbar)
+DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
+"""
+
+import logging
+import re
+from typing import Any, Dict, List, Optional
+
+import numpy as np
+
+from cv_syllable_core import (
+ _get_hyphenators,
+ _hyphenate_word,
+ _IPA_RE,
+ _STOP_WORDS,
+)
+
+logger = logging.getLogger(__name__)
+
+
+def _try_merge_pipe_gaps(text: str, hyph_de) -> str:
+    """Rejoin word fragments that OCR split apart at a printed syllable pipe.
+
+    Walks the space-separated tokens left to right and glues a token onto
+    the previous one whenever pyphen finds at least one hyphenation point
+    in the combined letters. The comparison always uses the already-merged
+    previous token, so chains collapse step by step:
+    "Ka bel jau" -> "Kabel jau" -> "Kabeljau".
+
+    A pair is only considered when all of these hold:
+    - the left token consists of letters only (no punctuation attached);
+    - both tokens contain at least one letter;
+    - neither letter core is listed in ``_STOP_WORDS``;
+    - the shorter letter core has at most 3 characters;
+    - both letter cores together have at least 4 characters.
+
+    Non-letter characters on the right token are carried over unchanged
+    ("Kaf" + "fee," -> "Kaffee,").
+
+    Args:
+        text: Text whose single-space gaps should be examined.
+        hyph_de: Object exposing ``inserted(word, hyphen=...)``.
+
+    Returns:
+        The text with qualifying gaps removed; unchanged if nothing merges.
+    """
+    # Everything except ASCII letters, German umlauts and sharp s is stripped
+    non_letters = r'[^a-zA-Z\u00e4\u00f6\u00fc\u00c4\u00d6\u00dc\u00df\u1e9e]'
+
+    first, *remaining = text.split(' ')
+    if not remaining:
+        return text
+
+    merged = [first]
+    for fragment in remaining:
+        left = merged[-1]
+        left_core = re.sub(non_letters, '', left)
+        right_core = re.sub(non_letters, '', fragment)
+
+        candidate = (
+            left == left_core
+            and left_core and right_core
+            and left_core.lower() not in _STOP_WORDS
+            and right_core.lower() not in _STOP_WORDS
+            and min(len(left_core), len(right_core)) <= 3
+            and len(left_core) + len(right_core) >= 4
+        )
+        # A hyphenation point in the joined letters is the merge signal
+        if candidate and '-' in hyph_de.inserted(left_core + right_core, hyphen='-'):
+            merged[-1] = left + fragment
+        else:
+            merged.append(fragment)
+
+    return ' '.join(merged)
+
+
+def merge_word_gaps_in_zones(zones_data: List[Dict], session_id: str) -> int:
+    """Repair OCR word splits ("zerknit tert" -> "zerknittert") in place.
+
+    Every cell whose ``col_type`` starts with ``column_`` and whose text
+    contains a space is passed through ``_try_merge_word_gaps``. Cells
+    that match ``_IPA_RE`` outside of square brackets are left untouched.
+
+    Args:
+        zones_data: Zone dicts; each may carry a ``cells`` list of cell dicts
+            with ``col_type`` and ``text`` keys. Cell texts are mutated.
+        session_id: Only used for the summary log line.
+
+    Returns:
+        Number of cells whose text was changed (0 when no German
+        hyphenator is available).
+    """
+    german_hyph, _unused = _get_hyphenators()
+    if german_hyph is None:
+        return 0
+
+    changed = 0
+    for zone in zones_data:
+        for cell in zone.get("cells", []):
+            if not cell.get("col_type", "").startswith("column_"):
+                continue
+
+            original = cell.get("text", "")
+            if not original or " " not in original:
+                continue
+
+            # Bracketed parts are ignored when looking for IPA characters
+            outside_brackets = re.sub(r'\[[^\]]*\]', '', original)
+            if _IPA_RE.search(outside_brackets):
+                continue
+
+            candidate = _try_merge_word_gaps(original, german_hyph)
+            if candidate == original:
+                continue
+            cell["text"] = candidate
+            changed += 1
+
+    if changed:
+        logger.info(
+            "build-grid session %s: merged word gaps in %d cells",
+            session_id, changed,
+        )
+    return changed
+
+
+def _try_merge_word_gaps(text: str, hyph_de) -> str:
+    """Join adjacent space-separated fragments that form a hyphenatable word.
+
+    Same procedure as ``_try_merge_pipe_gaps`` with a looser length rule:
+    the shorter fragment may have up to 5 letters instead of 3. A merge
+    still requires that
+    - the left token is letters only,
+    - both tokens contain letters and neither core is in ``_STOP_WORDS``,
+    - the combined core has at least 4 letters, and
+    - ``hyph_de.inserted`` places at least one hyphen in the combined core.
+
+    Args:
+        text: Text whose single-space gaps should be examined.
+        hyph_de: Object exposing ``inserted(word, hyphen=...)``.
+
+    Returns:
+        The text with qualifying gaps removed; unchanged if nothing merges.
+    """
+    tokens = text.split(' ')
+    if len(tokens) < 2:
+        return text
+
+    # Everything except ASCII letters, German umlauts and sharp s is stripped
+    strip_pattern = r'[^a-zA-Z\u00e4\u00f6\u00fc\u00c4\u00d6\u00dc\u00df\u1e9e]'
+
+    out = tokens[:1]
+    for token in tokens[1:]:
+        previous = out[-1]
+        previous_letters = re.sub(strip_pattern, '', previous)
+        token_letters = re.sub(strip_pattern, '', token)
+
+        eligible = (
+            previous == previous_letters
+            and previous_letters and token_letters
+            and previous_letters.lower() not in _STOP_WORDS
+            and token_letters.lower() not in _STOP_WORDS
+            and min(len(previous_letters), len(token_letters)) <= 5
+            and len(previous_letters) + len(token_letters) >= 4
+        )
+        if eligible:
+            hyphenated = hyph_de.inserted(previous_letters + token_letters, hyphen='-')
+            if '-' in hyphenated:
+                # Later tokens are compared against this merged token
+                out[-1] = previous + token
+                continue
+
+        out.append(token)
+
+    return ' '.join(out)
+
+
+def _syllabify_text(text: str, hyph_de, hyph_en) -> str:
+    """Normalize the syllable dividers of every eligible word in ``text``.
+
+    Processing order:
+    1. Texts matching ``_IPA_RE`` outside square brackets are returned as-is.
+    2. All existing ``|`` characters are removed.
+    3. Pipe-gap spaces are merged via ``_try_merge_pipe_gaps``.
+    4. The text is split on whitespace and ``,;:`` runs (separators are
+       kept); each remaining token has its non-letter prefix and suffix
+       set aside and the core is handed to ``_hyphenate_word``.
+
+    Tokens whose core is shorter than 3 characters, or for which
+    ``_hyphenate_word`` returns a falsy value, are emitted unchanged.
+
+    Args:
+        text: Cell text, possibly containing ``|`` dividers.
+        hyph_de: German hyphenator passed through to the helpers.
+        hyph_en: English hyphenator passed through to ``_hyphenate_word``.
+
+    Returns:
+        The rebuilt text; falsy input is returned unchanged.
+    """
+    if not text:
+        return text
+
+    # IPA transcriptions outside [...] mean the cell must not be touched
+    if _IPA_RE.search(re.sub(r'\[[^\]]*\]', '', text)):
+        return text
+
+    # Drop old dividers, then close gaps the dividers caused in OCR output
+    normalized = _try_merge_pipe_gaps(text.replace('|', ''), hyph_de)
+
+    pieces = []
+    # The capturing group keeps separators in the split result
+    for tok in re.split(r'(\s+|[,;:]+\s*)', normalized):
+        rebuilt = tok
+        if tok and not re.match(r'^[\s,;:]+$', tok):
+            # Peel non-letter prefix/suffix off the word core
+            parts = re.match(r'^([^a-zA-Z\u00e4\u00f6\u00fc\u00c4\u00d6\u00dc\u00df\u1e9e]*)(.*?)([^a-zA-Z\u00e4\u00f6\u00fc\u00c4\u00d6\u00dc\u00df\u1e9e]*)$', tok)
+            if parts:
+                lead, word, trail = parts.groups()
+                if len(word) >= 3 and re.search(r'[a-zA-Z\u00e4\u00f6\u00fc\u00c4\u00d6\u00dc\u00df]', word):
+                    hyph = _hyphenate_word(word, hyph_de, hyph_en)
+                    if hyph:
+                        rebuilt = lead + hyph + trail
+        pieces.append(rebuilt)
+
+    return ''.join(pieces)
+
+
+def insert_syllable_dividers(
+    zones_data: List[Dict],
+    img_bgr: np.ndarray,
+    session_id: str,
+    *,
+    force: bool = False,
+    col_filter: Optional[set] = None,
+) -> int:
+    """Re-syllabify the text of content-column cells in place.
+
+    Content cells are those whose ``col_type`` starts with ``column_``.
+    Each processed cell is run through ``_syllabify_text``, which strips
+    existing pipes, merges pipe-gap spaces and re-inserts dividers.
+
+    Without ``force`` two extra rules apply:
+    - If fewer than 1% of all content cells contain ``|``, nothing is
+      changed and 0 is returned (the ratio ignores ``col_filter``).
+    - Only cells that already contain ``|`` are rewritten.
+
+    Args:
+        zones_data: Zone dicts with a ``cells`` list; cell texts are mutated.
+        img_bgr: Not read by this function.
+        session_id: Only used in log messages.
+        force: Skip the 1% pre-check and process cells without pipes too.
+        col_filter: If given, restrict processing to these ``col_type``
+            values; ``None`` processes every content column.
+
+    Returns:
+        Number of cells whose text was changed.
+    """
+    hyph_de, hyph_en = _get_hyphenators()
+    if hyph_de is None:
+        logger.warning("pyphen not installed -- skipping syllable insertion")
+        return 0
+
+    def _content_cells():
+        # Lazily yield (col_type, cell) for every content-column cell
+        for zone in zones_data:
+            for entry in zone.get("cells", []):
+                kind = entry.get("col_type", "")
+                if kind.startswith("column_"):
+                    yield kind, entry
+
+    if not force:
+        # Pre-check: how many content cells already carry a pipe from OCR?
+        seen = 0
+        piped = 0
+        for _kind, entry in _content_cells():
+            seen += 1
+            if "|" in entry.get("text", ""):
+                piped += 1
+
+        if seen > 0:
+            pipe_ratio = piped / seen
+            if pipe_ratio < 0.01:
+                logger.info(
+                    "build-grid session %s: skipping syllable insertion -- "
+                    "only %.1f%% of cells have existing pipes (need >=1%%)",
+                    session_id, pipe_ratio * 100,
+                )
+                return 0
+
+    changed = 0
+    for kind, entry in _content_cells():
+        if col_filter is not None and kind not in col_filter:
+            continue
+
+        current = entry.get("text", "")
+        if not current:
+            continue
+
+        # Auto mode only normalizes cells that already show printed
+        # dividers; it never adds dividers to pipe-free text.
+        if not force and "|" not in current:
+            continue
+
+        rewritten = _syllabify_text(current, hyph_de, hyph_en)
+        if rewritten != current:
+            entry["text"] = rewritten
+            changed += 1
+
+    if changed:
+        logger.info(
+            "build-grid session %s: syllable dividers inserted/normalized "
+            "in %d cells (pyphen)",
+            session_id, changed,
+        )
+    return changed
diff --git a/klausur-service/backend/mail/aggregator.py b/klausur-service/backend/mail/aggregator.py
index 081a61c..01756b7 100644
--- a/klausur-service/backend/mail/aggregator.py
+++ b/klausur-service/backend/mail/aggregator.py
@@ -1,52 +1,27 @@
"""
-Mail Aggregator Service
+Mail Aggregator Service — barrel re-export.
+
+All implementation split into:
+ aggregator_imap — IMAP connection, sync, email parsing
+ aggregator_smtp — SMTP connection, email sending
Multi-account IMAP aggregation with async support.
"""
-import os
-import ssl
-import email
import asyncio
import logging
-import smtplib
-from typing import Optional, List, Dict, Any, Tuple
-from datetime import datetime, timezone
-from email.mime.text import MIMEText
-from email.mime.multipart import MIMEMultipart
-from email.header import decode_header, make_header
-from email.utils import parsedate_to_datetime, parseaddr
+from typing import Optional, List, Dict, Any
-from .credentials import get_credentials_service, MailCredentials
-from .mail_db import (
- get_email_accounts,
- get_email_account,
- update_account_status,
- upsert_email,
- get_unified_inbox,
-)
-from .models import (
- AccountStatus,
- AccountTestResult,
- AggregatedEmail,
- EmailComposeRequest,
- EmailSendResult,
-)
+from .credentials import get_credentials_service
+from .mail_db import get_email_accounts, get_unified_inbox
+from .models import AccountTestResult
+from .aggregator_imap import IMAPMixin, IMAPConnectionError
+from .aggregator_smtp import SMTPMixin, SMTPConnectionError
logger = logging.getLogger(__name__)
-class IMAPConnectionError(Exception):
- """Raised when IMAP connection fails."""
- pass
-
-
-class SMTPConnectionError(Exception):
- """Raised when SMTP connection fails."""
- pass
-
-
-class MailAggregator:
+class MailAggregator(IMAPMixin, SMTPMixin):
"""
Aggregates emails from multiple IMAP accounts into a unified inbox.
@@ -86,390 +61,29 @@ class MailAggregator:
)
# Test IMAP
- try:
- import imaplib
-
- if imap_ssl:
- imap = imaplib.IMAP4_SSL(imap_host, imap_port)
- else:
- imap = imaplib.IMAP4(imap_host, imap_port)
-
- imap.login(email_address, password)
- result.imap_connected = True
-
- # List folders
- status, folders = imap.list()
- if status == "OK":
- result.folders_found = [
- self._parse_folder_name(f) for f in folders if f
- ]
-
- imap.logout()
-
- except Exception as e:
- result.error_message = f"IMAP Error: {str(e)}"
- logger.warning(f"IMAP test failed for {email_address}: {e}")
+ imap_ok, imap_err, folders = await self.test_imap_connection(
+ imap_host, imap_port, imap_ssl, email_address, password
+ )
+ result.imap_connected = imap_ok
+ if folders:
+ result.folders_found = folders
+ if imap_err:
+ result.error_message = imap_err
# Test SMTP
- try:
- if smtp_ssl:
- smtp = smtplib.SMTP_SSL(smtp_host, smtp_port)
- else:
- smtp = smtplib.SMTP(smtp_host, smtp_port)
- smtp.starttls()
-
- smtp.login(email_address, password)
- result.smtp_connected = True
- smtp.quit()
-
- except Exception as e:
- smtp_error = f"SMTP Error: {str(e)}"
+ smtp_ok, smtp_err = await self.test_smtp_connection(
+ smtp_host, smtp_port, smtp_ssl, email_address, password
+ )
+ result.smtp_connected = smtp_ok
+ if smtp_err:
if result.error_message:
- result.error_message += f"; {smtp_error}"
+ result.error_message += f"; {smtp_err}"
else:
- result.error_message = smtp_error
- logger.warning(f"SMTP test failed for {email_address}: {e}")
+ result.error_message = smtp_err
result.success = result.imap_connected and result.smtp_connected
return result
- def _parse_folder_name(self, folder_response: bytes) -> str:
- """Parse folder name from IMAP LIST response."""
- try:
- # Format: '(\\HasNoChildren) "/" "INBOX"'
- decoded = folder_response.decode("utf-8") if isinstance(folder_response, bytes) else folder_response
- parts = decoded.rsplit('" "', 1)
- if len(parts) == 2:
- return parts[1].rstrip('"')
- return decoded
- except Exception:
- return str(folder_response)
-
- async def sync_account(
- self,
- account_id: str,
- user_id: str,
- max_emails: int = 100,
- folders: Optional[List[str]] = None,
- ) -> Tuple[int, int]:
- """
- Sync emails from an IMAP account.
-
- Args:
- account_id: The account ID
- user_id: The user ID
- max_emails: Maximum emails to fetch
- folders: Specific folders to sync (default: INBOX)
-
- Returns:
- Tuple of (new_emails, total_emails)
- """
- import imaplib
-
- account = await get_email_account(account_id, user_id)
- if not account:
- raise ValueError(f"Account not found: {account_id}")
-
- # Get credentials
- vault_path = account.get("vault_path", "")
- creds = await self._credentials_service.get_credentials(account_id, vault_path)
- if not creds:
- await update_account_status(account_id, "error", "Credentials not found")
- raise IMAPConnectionError("Credentials not found")
-
- new_count = 0
- total_count = 0
-
- try:
- # Connect to IMAP
- if account["imap_ssl"]:
- imap = imaplib.IMAP4_SSL(account["imap_host"], account["imap_port"])
- else:
- imap = imaplib.IMAP4(account["imap_host"], account["imap_port"])
-
- imap.login(creds.email, creds.password)
-
- # Sync specified folders or just INBOX
- sync_folders = folders or ["INBOX"]
-
- for folder in sync_folders:
- try:
- status, _ = imap.select(folder)
- if status != "OK":
- continue
-
- # Search for recent emails
- status, messages = imap.search(None, "ALL")
- if status != "OK":
- continue
-
- message_ids = messages[0].split()
- total_count += len(message_ids)
-
- # Fetch most recent emails
- recent_ids = message_ids[-max_emails:] if len(message_ids) > max_emails else message_ids
-
- for msg_id in recent_ids:
- try:
- email_data = await self._fetch_and_store_email(
- imap, msg_id, account_id, user_id, account["tenant_id"], folder
- )
- if email_data:
- new_count += 1
- except Exception as e:
- logger.warning(f"Failed to fetch email {msg_id}: {e}")
-
- except Exception as e:
- logger.warning(f"Failed to sync folder {folder}: {e}")
-
- imap.logout()
-
- # Update account status
- await update_account_status(
- account_id,
- "active",
- email_count=total_count,
- unread_count=new_count, # Will be recalculated
- )
-
- return new_count, total_count
-
- except Exception as e:
- logger.error(f"Account sync failed: {e}")
- await update_account_status(account_id, "error", str(e))
- raise IMAPConnectionError(str(e))
-
- async def _fetch_and_store_email(
- self,
- imap,
- msg_id: bytes,
- account_id: str,
- user_id: str,
- tenant_id: str,
- folder: str,
- ) -> Optional[str]:
- """Fetch a single email and store it in the database."""
- try:
- status, msg_data = imap.fetch(msg_id, "(RFC822)")
- if status != "OK" or not msg_data or not msg_data[0]:
- return None
-
- raw_email = msg_data[0][1]
- msg = email.message_from_bytes(raw_email)
-
- # Parse headers
- message_id = msg.get("Message-ID", str(msg_id))
- subject = self._decode_header(msg.get("Subject", ""))
- from_header = msg.get("From", "")
- sender_name, sender_email = parseaddr(from_header)
- sender_name = self._decode_header(sender_name)
-
- # Parse recipients
- to_header = msg.get("To", "")
- recipients = [addr[1] for addr in email.utils.getaddresses([to_header])]
-
- cc_header = msg.get("Cc", "")
- cc = [addr[1] for addr in email.utils.getaddresses([cc_header])]
-
- # Parse dates
- date_str = msg.get("Date")
- try:
- date_sent = parsedate_to_datetime(date_str) if date_str else datetime.now(timezone.utc)
- except Exception:
- date_sent = datetime.now(timezone.utc)
-
- date_received = datetime.now(timezone.utc)
-
- # Parse body
- body_text, body_html, attachments = self._parse_body(msg)
-
- # Create preview
- body_preview = (body_text[:200] + "...") if body_text and len(body_text) > 200 else body_text
-
- # Get headers dict
- headers = {k: self._decode_header(v) for k, v in msg.items() if k not in ["Body"]}
-
- # Store in database
- email_id = await upsert_email(
- account_id=account_id,
- user_id=user_id,
- tenant_id=tenant_id,
- message_id=message_id,
- subject=subject,
- sender_email=sender_email,
- sender_name=sender_name,
- recipients=recipients,
- cc=cc,
- body_preview=body_preview,
- body_text=body_text,
- body_html=body_html,
- has_attachments=len(attachments) > 0,
- attachments=attachments,
- headers=headers,
- folder=folder,
- date_sent=date_sent,
- date_received=date_received,
- )
-
- return email_id
-
- except Exception as e:
- logger.error(f"Failed to parse email: {e}")
- return None
-
- def _decode_header(self, header_value: str) -> str:
- """Decode email header value."""
- if not header_value:
- return ""
- try:
- decoded = decode_header(header_value)
- return str(make_header(decoded))
- except Exception:
- return str(header_value)
-
- def _parse_body(self, msg) -> Tuple[Optional[str], Optional[str], List[Dict]]:
- """
- Parse email body and attachments.
-
- Returns:
- Tuple of (body_text, body_html, attachments)
- """
- body_text = None
- body_html = None
- attachments = []
-
- if msg.is_multipart():
- for part in msg.walk():
- content_type = part.get_content_type()
- content_disposition = str(part.get("Content-Disposition", ""))
-
- # Skip multipart containers
- if content_type.startswith("multipart/"):
- continue
-
- # Check for attachments
- if "attachment" in content_disposition:
- filename = part.get_filename()
- if filename:
- attachments.append({
- "filename": self._decode_header(filename),
- "content_type": content_type,
- "size": len(part.get_payload(decode=True) or b""),
- })
- continue
-
- # Get body content
- try:
- payload = part.get_payload(decode=True)
- charset = part.get_content_charset() or "utf-8"
-
- if payload:
- text = payload.decode(charset, errors="replace")
-
- if content_type == "text/plain" and not body_text:
- body_text = text
- elif content_type == "text/html" and not body_html:
- body_html = text
- except Exception as e:
- logger.debug(f"Failed to decode body part: {e}")
-
- else:
- # Single part message
- content_type = msg.get_content_type()
- try:
- payload = msg.get_payload(decode=True)
- charset = msg.get_content_charset() or "utf-8"
-
- if payload:
- text = payload.decode(charset, errors="replace")
-
- if content_type == "text/plain":
- body_text = text
- elif content_type == "text/html":
- body_html = text
- except Exception as e:
- logger.debug(f"Failed to decode body: {e}")
-
- return body_text, body_html, attachments
-
- async def send_email(
- self,
- account_id: str,
- user_id: str,
- request: EmailComposeRequest,
- ) -> EmailSendResult:
- """
- Send an email via SMTP.
-
- Args:
- account_id: The account to send from
- user_id: The user ID
- request: The compose request with recipients and content
-
- Returns:
- EmailSendResult with success status
- """
- account = await get_email_account(account_id, user_id)
- if not account:
- return EmailSendResult(success=False, error="Account not found")
-
- # Verify the account_id matches
- if request.account_id != account_id:
- return EmailSendResult(success=False, error="Account mismatch")
-
- # Get credentials
- vault_path = account.get("vault_path", "")
- creds = await self._credentials_service.get_credentials(account_id, vault_path)
- if not creds:
- return EmailSendResult(success=False, error="Credentials not found")
-
- try:
- # Create message
- if request.is_html:
- msg = MIMEMultipart("alternative")
- msg.attach(MIMEText(request.body, "html"))
- else:
- msg = MIMEText(request.body, "plain")
-
- msg["Subject"] = request.subject
- msg["From"] = account["email"]
- msg["To"] = ", ".join(request.to)
-
- if request.cc:
- msg["Cc"] = ", ".join(request.cc)
-
- if request.reply_to_message_id:
- msg["In-Reply-To"] = request.reply_to_message_id
- msg["References"] = request.reply_to_message_id
-
- # Send via SMTP
- if account["smtp_ssl"]:
- smtp = smtplib.SMTP_SSL(account["smtp_host"], account["smtp_port"])
- else:
- smtp = smtplib.SMTP(account["smtp_host"], account["smtp_port"])
- smtp.starttls()
-
- smtp.login(creds.email, creds.password)
-
- # All recipients
- all_recipients = list(request.to)
- if request.cc:
- all_recipients.extend(request.cc)
- if request.bcc:
- all_recipients.extend(request.bcc)
-
- smtp.sendmail(account["email"], all_recipients, msg.as_string())
- smtp.quit()
-
- return EmailSendResult(
- success=True,
- message_id=msg.get("Message-ID"),
- )
-
- except Exception as e:
- logger.error(f"Failed to send email: {e}")
- return EmailSendResult(success=False, error=str(e))
-
async def sync_all_accounts(self, user_id: str, tenant_id: Optional[str] = None) -> Dict[str, Any]:
"""
Sync all accounts for a user.
diff --git a/klausur-service/backend/mail/aggregator_imap.py b/klausur-service/backend/mail/aggregator_imap.py
new file mode 100644
index 0000000..9b5e259
--- /dev/null
+++ b/klausur-service/backend/mail/aggregator_imap.py
@@ -0,0 +1,322 @@
+"""
+Mail Aggregator IMAP — IMAP connection, sync, email parsing.
+
+Extracted from aggregator.py for modularity.
+"""
+
+import email
+import logging
+from typing import Optional, List, Dict, Any, Tuple
+from datetime import datetime, timezone
+from email.header import decode_header, make_header
+from email.utils import parsedate_to_datetime, parseaddr
+
+from .mail_db import upsert_email, update_account_status, get_email_account
+
+logger = logging.getLogger(__name__)
+
+
+class IMAPConnectionError(Exception):
+    """Signals that connecting to, logging in to, or syncing an IMAP account failed."""
+
+
+class IMAPMixin:
+    """IMAP-related methods for MailAggregator.
+
+    Provides connection testing, syncing, and email parsing.
+    Must be mixed into a class that has ``_credentials_service``.
+
+    NOTE(review): the ``imaplib`` calls are synchronous but run inside
+    ``async def`` methods, so they appear to block the event loop for the
+    duration of each network round-trip -- confirm whether callers offload
+    these methods to a worker thread.
+    """
+
+    @staticmethod
+    def _abort_imap(imap) -> None:
+        """Best-effort socket teardown for a connection that failed mid-session."""
+        try:
+            imap.shutdown()
+        except Exception:
+            # The original failure is what matters to the caller; only note this one
+            logger.debug("IMAP shutdown after failure also failed", exc_info=True)
+
+    @staticmethod
+    def _decode_text_payload(part) -> Optional[str]:
+        """Return the decoded text of a message part, or None if it has no payload.
+
+        An unknown charset label falls back to UTF-8 with replacement
+        characters instead of discarding the body.
+        """
+        payload = part.get_payload(decode=True)
+        if not payload:
+            return None
+        charset = part.get_content_charset() or "utf-8"
+        try:
+            return payload.decode(charset, errors="replace")
+        except LookupError:
+            return payload.decode("utf-8", errors="replace")
+
+    def _parse_folder_name(self, folder_response: bytes) -> str:
+        """Extract the mailbox name from one IMAP LIST response line.
+
+        Handles both quoted names (``(\\HasNoChildren) "/" "INBOX"``) and
+        bare atoms (``(\\HasNoChildren) "." INBOX``), with a quoted or
+        ``NIL`` hierarchy delimiter. Lines that do not look like a LIST
+        response are returned unchanged.
+
+        Args:
+            folder_response: Raw line as bytes (a ``str`` is accepted too).
+
+        Returns:
+            The mailbox name without surrounding quotes; ``str()`` of the
+            input if it cannot be processed at all.
+        """
+        import re  # local import: this module has no top-level ``re`` import
+
+        try:
+            decoded = folder_response.decode("utf-8") if isinstance(folder_response, bytes) else folder_response
+            # (flags) <delimiter: quoted string or NIL> <mailbox name>
+            match = re.match(r'\s*\([^)]*\)\s+(?:"(?:[^"\\]|\\.)*"|NIL)\s+(.*)$', decoded)
+            if not match:
+                # Keep the previous heuristic for lines without a flag list
+                parts = decoded.rsplit('" "', 1)
+                return parts[1].rstrip('"') if len(parts) == 2 else decoded
+            name = match.group(1).strip()
+            if len(name) >= 2 and name[0] == '"' and name[-1] == '"':
+                return name[1:-1]
+            return name
+        except Exception:
+            return str(folder_response)
+
+    async def test_imap_connection(
+        self,
+        imap_host: str,
+        imap_port: int,
+        imap_ssl: bool,
+        email_address: str,
+        password: str,
+    ) -> Tuple[bool, Optional[str], Optional[List[str]]]:
+        """Try to log in to an IMAP server and list its folders.
+
+        Never raises: every failure is logged and reported in the result.
+
+        Returns:
+            ``(True, None, folders)`` on success, where ``folders`` is None
+            if the LIST command did not answer "OK";
+            ``(False, "IMAP Error: ...", None)`` on any failure.
+        """
+        try:
+            import imaplib
+
+            if imap_ssl:
+                imap = imaplib.IMAP4_SSL(imap_host, imap_port)
+            else:
+                imap = imaplib.IMAP4(imap_host, imap_port)
+
+            try:
+                imap.login(email_address, password)
+
+                # List folders
+                folders_found = None
+                status, folders = imap.list()
+                if status == "OK":
+                    folders_found = [
+                        self._parse_folder_name(f) for f in folders if f
+                    ]
+            except Exception:
+                # Do not leave the socket open when login or LIST fails
+                self._abort_imap(imap)
+                raise
+
+            imap.logout()
+            return True, None, folders_found
+
+        except Exception as e:
+            logger.warning(f"IMAP test failed for {email_address}: {e}")
+            return False, f"IMAP Error: {str(e)}", None
+
+    async def sync_account(
+        self,
+        account_id: str,
+        user_id: str,
+        max_emails: int = 100,
+        folders: Optional[List[str]] = None,
+    ) -> Tuple[int, int]:
+        """
+        Sync emails from an IMAP account.
+
+        Failures of a single folder or message are logged and skipped; any
+        other failure marks the account as "error" and is re-raised.
+
+        Args:
+            account_id: The account ID
+            user_id: The user ID
+            max_emails: Maximum emails to fetch per folder
+            folders: Specific folders to sync (default: INBOX)
+
+        Returns:
+            Tuple of (stored_emails, total_emails)
+
+        Raises:
+            ValueError: The account does not exist for this user.
+            IMAPConnectionError: Credentials are missing or the sync failed;
+                the underlying exception is chained as ``__cause__``.
+        """
+        import imaplib
+
+        account = await get_email_account(account_id, user_id)
+        if not account:
+            raise ValueError(f"Account not found: {account_id}")
+
+        # Get credentials
+        vault_path = account.get("vault_path", "")
+        creds = await self._credentials_service.get_credentials(account_id, vault_path)
+        if not creds:
+            await update_account_status(account_id, "error", "Credentials not found")
+            raise IMAPConnectionError("Credentials not found")
+
+        new_count = 0
+        total_count = 0
+
+        try:
+            # Connect to IMAP
+            if account["imap_ssl"]:
+                imap = imaplib.IMAP4_SSL(account["imap_host"], account["imap_port"])
+            else:
+                imap = imaplib.IMAP4(account["imap_host"], account["imap_port"])
+
+            try:
+                imap.login(creds.email, creds.password)
+
+                # Sync specified folders or just INBOX
+                sync_folders = folders or ["INBOX"]
+
+                for folder in sync_folders:
+                    try:
+                        status, _ = imap.select(folder)
+                        if status != "OK":
+                            continue
+
+                        # SEARCH ALL returns every message number in the folder
+                        status, messages = imap.search(None, "ALL")
+                        if status != "OK":
+                            continue
+
+                        message_ids = messages[0].split()
+                        total_count += len(message_ids)
+
+                        # Only the tail (most recent) of the list is fetched
+                        recent_ids = message_ids[-max_emails:] if len(message_ids) > max_emails else message_ids
+
+                        for msg_id in recent_ids:
+                            try:
+                                email_data = await self._fetch_and_store_email(
+                                    imap, msg_id, account_id, user_id, account["tenant_id"], folder
+                                )
+                                if email_data:
+                                    new_count += 1
+                            except Exception as e:
+                                logger.warning(f"Failed to fetch email {msg_id}: {e}")
+
+                    except Exception as e:
+                        logger.warning(f"Failed to sync folder {folder}: {e}")
+            except Exception:
+                # Do not leave the socket open when the session breaks
+                self._abort_imap(imap)
+                raise
+
+            imap.logout()
+
+            # Update account status
+            await update_account_status(
+                account_id,
+                "active",
+                email_count=total_count,
+                unread_count=new_count,  # Will be recalculated
+            )
+
+            return new_count, total_count
+
+        except Exception as e:
+            logger.error(f"Account sync failed: {e}")
+            await update_account_status(account_id, "error", str(e))
+            raise IMAPConnectionError(str(e)) from e
+
+    async def _fetch_and_store_email(
+        self,
+        imap,
+        msg_id: bytes,
+        account_id: str,
+        user_id: str,
+        tenant_id: str,
+        folder: str,
+    ) -> Optional[str]:
+        """Fetch one message, parse it and upsert it into the database.
+
+        Returns:
+            The value returned by ``upsert_email``, or None when the fetch
+            did not answer "OK" or anything went wrong (errors are logged,
+            never raised).
+        """
+        try:
+            status, msg_data = imap.fetch(msg_id, "(RFC822)")
+            if status != "OK" or not msg_data or not msg_data[0]:
+                return None
+
+            raw_email = msg_data[0][1]
+            msg = email.message_from_bytes(raw_email)
+
+            # Parse headers; fall back to the IMAP message number as ID
+            message_id = msg.get("Message-ID", str(msg_id))
+            subject = self._decode_header(msg.get("Subject", ""))
+            from_header = msg.get("From", "")
+            sender_name, sender_email = parseaddr(from_header)
+            sender_name = self._decode_header(sender_name)
+
+            # Parse recipients (address part only)
+            to_header = msg.get("To", "")
+            recipients = [addr[1] for addr in email.utils.getaddresses([to_header])]
+
+            cc_header = msg.get("Cc", "")
+            cc = [addr[1] for addr in email.utils.getaddresses([cc_header])]
+
+            # Parse dates; a missing or unparseable Date becomes "now" (UTC)
+            date_str = msg.get("Date")
+            try:
+                date_sent = parsedate_to_datetime(date_str) if date_str else datetime.now(timezone.utc)
+            except Exception:
+                date_sent = datetime.now(timezone.utc)
+
+            date_received = datetime.now(timezone.utc)
+
+            # Parse body
+            body_text, body_html, attachments = self._parse_body(msg)
+
+            # Create preview: first 200 characters of the plain-text body
+            body_preview = (body_text[:200] + "...") if body_text and len(body_text) > 200 else body_text
+
+            # Get headers dict
+            headers = {k: self._decode_header(v) for k, v in msg.items() if k not in ["Body"]}
+
+            # Store in database
+            email_id = await upsert_email(
+                account_id=account_id,
+                user_id=user_id,
+                tenant_id=tenant_id,
+                message_id=message_id,
+                subject=subject,
+                sender_email=sender_email,
+                sender_name=sender_name,
+                recipients=recipients,
+                cc=cc,
+                body_preview=body_preview,
+                body_text=body_text,
+                body_html=body_html,
+                has_attachments=len(attachments) > 0,
+                attachments=attachments,
+                headers=headers,
+                folder=folder,
+                date_sent=date_sent,
+                date_received=date_received,
+            )
+
+            return email_id
+
+        except Exception as e:
+            logger.error(f"Failed to parse email: {e}")
+            return None
+
+    def _decode_header(self, header_value: str) -> str:
+        """Decode an RFC 2047 encoded header value to plain text.
+
+        Returns "" for falsy input and ``str(header_value)`` if decoding fails.
+        """
+        if not header_value:
+            return ""
+        try:
+            decoded = decode_header(header_value)
+            return str(make_header(decoded))
+        except Exception:
+            return str(header_value)
+
+    def _parse_body(self, msg) -> Tuple[Optional[str], Optional[str], List[Dict]]:
+        """
+        Parse email body and attachments.
+
+        For multipart messages the first ``text/plain`` and the first
+        ``text/html`` part win; parts whose Content-Disposition contains
+        "attachment" are listed as metadata (filename, content type, size)
+        and never used as body.
+
+        Returns:
+            Tuple of (body_text, body_html, attachments)
+        """
+        body_text = None
+        body_html = None
+        attachments = []
+
+        if msg.is_multipart():
+            for part in msg.walk():
+                content_type = part.get_content_type()
+                content_disposition = str(part.get("Content-Disposition", ""))
+
+                # Skip multipart containers
+                if content_type.startswith("multipart/"):
+                    continue
+
+                # Check for attachments; those without a filename are ignored
+                if "attachment" in content_disposition:
+                    filename = part.get_filename()
+                    if filename:
+                        attachments.append({
+                            "filename": self._decode_header(filename),
+                            "content_type": content_type,
+                            "size": len(part.get_payload(decode=True) or b""),
+                        })
+                    continue
+
+                # Get body content
+                try:
+                    text = self._decode_text_payload(part)
+                    if text:
+                        if content_type == "text/plain" and not body_text:
+                            body_text = text
+                        elif content_type == "text/html" and not body_html:
+                            body_html = text
+                except Exception as e:
+                    logger.debug(f"Failed to decode body part: {e}")
+
+        else:
+            # Single part message
+            content_type = msg.get_content_type()
+            try:
+                text = self._decode_text_payload(msg)
+                if text:
+                    if content_type == "text/plain":
+                        body_text = text
+                    elif content_type == "text/html":
+                        body_html = text
+            except Exception as e:
+                logger.debug(f"Failed to decode body: {e}")
+
+        return body_text, body_html, attachments
diff --git a/klausur-service/backend/mail/aggregator_smtp.py b/klausur-service/backend/mail/aggregator_smtp.py
new file mode 100644
index 0000000..2f67a3e
--- /dev/null
+++ b/klausur-service/backend/mail/aggregator_smtp.py
@@ -0,0 +1,131 @@
+"""
+Mail Aggregator SMTP — email sending via SMTP.
+
+Extracted from aggregator.py for modularity.
+"""
+
+import logging
+import smtplib
+from typing import Optional, List, Dict, Any
+from email.mime.text import MIMEText
+from email.mime.multipart import MIMEMultipart
+
+from .mail_db import get_email_account
+from .models import EmailComposeRequest, EmailSendResult
+
+logger = logging.getLogger(__name__)
+
+
+class SMTPConnectionError(Exception):
+    """Signals that connecting to or authenticating with an SMTP server failed."""
+
+
+class SMTPMixin:
+    """SMTP-related methods for MailAggregator.
+
+    Provides SMTP connection testing and email sending.
+    Must be mixed into a class that has ``_credentials_service``.
+
+    NOTE(review): the ``smtplib`` calls are synchronous but run inside
+    ``async def`` methods, so they appear to block the event loop for the
+    duration of each network round-trip -- confirm whether callers offload
+    these methods to a worker thread.
+    """
+
+    async def test_smtp_connection(
+        self,
+        smtp_host: str,
+        smtp_port: int,
+        smtp_ssl: bool,
+        email_address: str,
+        password: str,
+    ) -> tuple:
+        """Try to connect and authenticate against an SMTP server.
+
+        With ``smtp_ssl`` the connection is opened via ``SMTP_SSL``;
+        otherwise a plain connection is upgraded with STARTTLS before login.
+        Never raises: every failure is logged and reported in the result.
+
+        Returns:
+            ``(True, None)`` on success, ``(False, "SMTP Error: ...")`` otherwise.
+        """
+        try:
+            if smtp_ssl:
+                smtp = smtplib.SMTP_SSL(smtp_host, smtp_port)
+            else:
+                smtp = smtplib.SMTP(smtp_host, smtp_port)
+
+            # The context manager sends QUIT and closes the socket even when
+            # STARTTLS or login raises, so a failed test leaks nothing.
+            with smtp:
+                if not smtp_ssl:
+                    smtp.starttls()
+                smtp.login(email_address, password)
+            return True, None
+
+        except Exception as e:
+            logger.warning(f"SMTP test failed for {email_address}: {e}")
+            return False, f"SMTP Error: {str(e)}"
+
+    async def send_email(
+        self,
+        account_id: str,
+        user_id: str,
+        request: EmailComposeRequest,
+    ) -> EmailSendResult:
+        """
+        Send an email via SMTP.
+
+        A Message-ID header is generated before sending so that the result
+        carries the ID of the message that was actually sent.
+
+        Args:
+            account_id: The account to send from
+            user_id: The user ID
+            request: The compose request with recipients and content
+
+        Returns:
+            EmailSendResult with success status; failures are reported via
+            ``success=False`` and ``error`` instead of being raised.
+        """
+        from email.utils import make_msgid  # local: not imported at module level
+
+        account = await get_email_account(account_id, user_id)
+        if not account:
+            return EmailSendResult(success=False, error="Account not found")
+
+        # Verify the account_id matches
+        if request.account_id != account_id:
+            return EmailSendResult(success=False, error="Account mismatch")
+
+        # Get credentials
+        vault_path = account.get("vault_path", "")
+        creds = await self._credentials_service.get_credentials(account_id, vault_path)
+        if not creds:
+            return EmailSendResult(success=False, error="Credentials not found")
+
+        try:
+            # Create message
+            if request.is_html:
+                msg = MIMEMultipart("alternative")
+                msg.attach(MIMEText(request.body, "html"))
+            else:
+                msg = MIMEText(request.body, "plain")
+
+            msg["Subject"] = request.subject
+            msg["From"] = account["email"]
+            msg["To"] = ", ".join(request.to)
+
+            if request.cc:
+                msg["Cc"] = ", ".join(request.cc)
+
+            if request.reply_to_message_id:
+                msg["In-Reply-To"] = request.reply_to_message_id
+                msg["References"] = request.reply_to_message_id
+
+            # Use the sender's domain for the ID when the address has one;
+            # otherwise make_msgid falls back to the local host name.
+            _local, at_sign, sender_domain = account["email"].rpartition("@")
+            msg["Message-ID"] = make_msgid(domain=sender_domain if at_sign else None)
+
+            # Envelope recipients: Bcc addresses are delivered to but never
+            # written into a header.
+            all_recipients = list(request.to)
+            if request.cc:
+                all_recipients.extend(request.cc)
+            if request.bcc:
+                all_recipients.extend(request.bcc)
+
+            # Send via SMTP
+            if account["smtp_ssl"]:
+                smtp = smtplib.SMTP_SSL(account["smtp_host"], account["smtp_port"])
+            else:
+                smtp = smtplib.SMTP(account["smtp_host"], account["smtp_port"])
+
+            # The context manager closes the connection on every exit path
+            with smtp:
+                if not account["smtp_ssl"]:
+                    smtp.starttls()
+                smtp.login(creds.email, creds.password)
+                smtp.sendmail(account["email"], all_recipients, msg.as_string())
+
+            return EmailSendResult(
+                success=True,
+                message_id=msg.get("Message-ID"),
+            )
+
+        except Exception as e:
+            logger.error(f"Failed to send email: {e}")
+            return EmailSendResult(success=False, error=str(e))
diff --git a/klausur-service/backend/nibis_ingestion.py b/klausur-service/backend/nibis_ingestion.py
index 3fe22e0..63b2f23 100644
--- a/klausur-service/backend/nibis_ingestion.py
+++ b/klausur-service/backend/nibis_ingestion.py
@@ -10,12 +10,11 @@ Unterstützt:
"""
import os
-import re
import zipfile
import hashlib
import json
from pathlib import Path
-from typing import List, Dict, Optional, Tuple
+from typing import List, Dict, Optional
from dataclasses import dataclass, asdict
from datetime import datetime
import asyncio
@@ -23,6 +22,7 @@ import asyncio
# Local imports
from eh_pipeline import chunk_text, generate_embeddings, extract_text_from_pdf, get_vector_size, EMBEDDING_BACKEND
from qdrant_service import QdrantService
+from nibis_parsers import parse_filename_old_format, parse_filename_new_format
# Configuration
DOCS_BASE_PATH = Path("/Users/benjaminadmin/projekte/breakpilot-pwa/docs")
@@ -87,15 +87,6 @@ SUBJECT_MAPPING = {
"gespfl": "Gesundheit-Pflege",
}
-# Niveau-Mapping
-NIVEAU_MAPPING = {
- "ea": "eA", # erhöhtes Anforderungsniveau
- "ga": "gA", # grundlegendes Anforderungsniveau
- "neuga": "gA (neu einsetzend)",
- "neuea": "eA (neu einsetzend)",
-}
-
-
def compute_file_hash(file_path: Path) -> str:
"""Berechnet SHA-256 Hash einer Datei."""
sha256 = hashlib.sha256()
@@ -135,103 +126,6 @@ def extract_zip_files(base_path: Path) -> List[Path]:
return extracted
-def parse_filename_old_format(filename: str, file_path: Path) -> Optional[Dict]:
- """
- Parst alte Namenskonvention (2016, 2017):
- - {Jahr}{Fach}{Niveau}Lehrer/{Jahr}{Fach}{Niveau}A{Nr}L.pdf
- - Beispiel: 2016DeutschEALehrer/2016DeutschEAA1L.pdf
- """
- # Pattern für Lehrer-Dateien
- pattern = r"(\d{4})([A-Za-zäöüÄÖÜ]+)(EA|GA|NeuGA|NeuEA)(?:Lehrer)?.*?(?:A(\d+)|Aufg(\d+))?L?\.pdf$"
-
- match = re.search(pattern, filename, re.IGNORECASE)
- if not match:
- return None
-
- year = int(match.group(1))
- subject_raw = match.group(2).lower()
- niveau = match.group(3).upper()
- task_num = match.group(4) or match.group(5)
-
- # Prüfe ob es ein Lehrer-Dokument ist (EWH)
- is_ewh = "lehrer" in str(file_path).lower() or filename.endswith("L.pdf")
-
- # Extrahiere Variante (Tech, Wirt, CAS, GTR, etc.)
- variant = None
- variant_patterns = ["Tech", "Wirt", "CAS", "GTR", "Pflicht", "BG", "mitExp", "ohneExp"]
- for v in variant_patterns:
- if v.lower() in str(file_path).lower():
- variant = v
- break
-
- return {
- "year": year,
- "subject": subject_raw,
- "niveau": NIVEAU_MAPPING.get(niveau.lower(), niveau),
- "task_number": int(task_num) if task_num else None,
- "doc_type": "EWH" if is_ewh else "Aufgabe",
- "variant": variant,
- }
-
-
-def parse_filename_new_format(filename: str, file_path: Path) -> Optional[Dict]:
- """
- Parst neue Namenskonvention (2024, 2025):
- - {Jahr}_{Fach}_{niveau}_{Nr}_EWH.pdf
- - Beispiel: 2025_Deutsch_eA_I_EWH.pdf
- """
- # Pattern für neue Dateien
- pattern = r"(\d{4})_([A-Za-zäöüÄÖÜ]+)(?:BG)?_(eA|gA)(?:_([IVX\d]+))?(?:_(.+))?\.pdf$"
-
- match = re.search(pattern, filename, re.IGNORECASE)
- if not match:
- return None
-
- year = int(match.group(1))
- subject_raw = match.group(2).lower()
- niveau = match.group(3)
- task_id = match.group(4)
- suffix = match.group(5) or ""
-
- # Task-Nummer aus römischen Zahlen
- task_num = None
- if task_id:
- roman_map = {"I": 1, "II": 2, "III": 3, "IV": 4, "V": 5}
- task_num = roman_map.get(task_id) or (int(task_id) if task_id.isdigit() else None)
-
- # Dokumenttyp
- is_ewh = "EWH" in filename or "ewh" in filename.lower()
-
- # Spezielle Dokumenttypen
- doc_type = "EWH" if is_ewh else "Aufgabe"
- if "Material" in suffix:
- doc_type = "Material"
- elif "GBU" in suffix:
- doc_type = "GBU"
- elif "Ergebnis" in suffix:
- doc_type = "Ergebnis"
- elif "Bewertungsbogen" in suffix:
- doc_type = "Bewertungsbogen"
- elif "HV" in suffix:
- doc_type = "Hörverstehen"
- elif "ME" in suffix:
- doc_type = "Mediation"
-
- # BG Variante
- variant = "BG" if "BG" in filename else None
- if "mitExp" in str(file_path):
- variant = "mitExp"
-
- return {
- "year": year,
- "subject": subject_raw,
- "niveau": NIVEAU_MAPPING.get(niveau.lower(), niveau),
- "task_number": task_num,
- "doc_type": doc_type,
- "variant": variant,
- }
-
-
def discover_documents(base_path: Path, ewh_only: bool = True) -> List[NiBiSDocument]:
"""
Findet alle relevanten Dokumente in den za-download Verzeichnissen.
diff --git a/klausur-service/backend/nibis_parsers.py b/klausur-service/backend/nibis_parsers.py
new file mode 100644
index 0000000..f65adff
--- /dev/null
+++ b/klausur-service/backend/nibis_parsers.py
@@ -0,0 +1,113 @@
+"""
+NiBiS Filename Parsers
+
+Parses old and new naming conventions for NiBiS Abitur documents.
+"""
+
+import re
+from typing import Dict, Optional
+
+# Canonical labels for the requirement-level ("Niveau") token found in file
+# names, keyed by the lower-cased token.
+NIVEAU_MAPPING = {
+    "ea": "eA",  # advanced requirement level
+    "ga": "gA",  # basic requirement level
+    "neuga": "gA (neu einsetzend)",
+    "neuea": "eA (neu einsetzend)",
+}
+
+# Old convention (2016, 2017), e.g. "2016DeutschEAA1L.pdf".
+_OLD_FORMAT_RE = re.compile(
+    r"(\d{4})([A-Za-z\u00e4\u00f6\u00fc\u00c4\u00d6\u00dc]+)(EA|GA|NeuGA|NeuEA)(?:Lehrer)?.*?(?:A(\d+)|Aufg(\d+))?L?\.pdf$",
+    re.IGNORECASE,
+)
+
+# New convention (2024, 2025), e.g. "2025_Deutsch_eA_I_EWH.pdf".
+_NEW_FORMAT_RE = re.compile(
+    r"(\d{4})_([A-Za-z\u00e4\u00f6\u00fc\u00c4\u00d6\u00dc]+)(?:BG)?_(eA|gA)(?:_([IVX\d]+))?(?:_(.+))?\.pdf$",
+    re.IGNORECASE,
+)
+
+# Variant markers searched (case-insensitively) in the old-format path; the
+# first hit in this order wins.
+_VARIANT_PATTERNS = ("Tech", "Wirt", "CAS", "GTR", "Pflicht", "BG", "mitExp", "ohneExp")
+
+_ROMAN_TASK_NUMBERS = {"I": 1, "II": 2, "III": 3, "IV": 4, "V": 5}
+
+# (marker in filename suffix, resulting doc_type); checked in this order.
+_SUFFIX_DOC_TYPES = (
+    ("Material", "Material"),
+    ("GBU", "GBU"),
+    ("Ergebnis", "Ergebnis"),
+    ("Bewertungsbogen", "Bewertungsbogen"),
+    # Same value the parser emitted before it was moved into this module.
+    ("HV", "H\u00f6rverstehen"),
+    ("ME", "Mediation"),
+)
+
+
+def parse_filename_old_format(filename: str, file_path) -> Optional[Dict]:
+    """
+    Parse the old naming convention (2016, 2017).
+
+    Layout: {year}{subject}{level}Lehrer/{year}{subject}{level}A{no}L.pdf
+    Example: 2016DeutschEALehrer/2016DeutschEAA1L.pdf
+
+    Args:
+        filename: Bare file name that is matched against the pattern.
+        file_path: Path of the file (anything ``str()`` accepts); only used
+            for the "lehrer" check and the variant lookup.
+
+    Returns:
+        Dict with the keys ``year``, ``subject`` (lower-cased), ``niveau``,
+        ``task_number``, ``doc_type`` ("EWH" or "Aufgabe") and ``variant``,
+        or ``None`` if the file name does not match.
+    """
+    match = _OLD_FORMAT_RE.search(filename)
+    if not match:
+        return None
+
+    year = int(match.group(1))
+    subject_raw = match.group(2).lower()
+    niveau = match.group(3).upper()
+    task_num = match.group(4) or match.group(5)
+
+    # The subject group is greedy, so for "...NeuGA"/"...NeuEA" it swallows
+    # the "Neu" prefix and only "GA"/"EA" is left for the level group. Move
+    # the prefix back so the "neu einsetzend" levels are actually reported.
+    if niveau in ("EA", "GA") and len(subject_raw) > 3 and subject_raw.endswith("neu"):
+        subject_raw = subject_raw[:-3]
+        niveau = "NEU" + niveau
+
+    path_lower = str(file_path).lower()
+
+    # Teacher documents (EWH) live in a "...Lehrer" folder or end in "L.pdf".
+    is_ewh = "lehrer" in path_lower or filename.endswith("L.pdf")
+
+    # First matching variant marker (Tech, Wirt, CAS, GTR, ...), if any.
+    variant = next((v for v in _VARIANT_PATTERNS if v.lower() in path_lower), None)
+
+    return {
+        "year": year,
+        "subject": subject_raw,
+        "niveau": NIVEAU_MAPPING.get(niveau.lower(), niveau),
+        "task_number": int(task_num) if task_num else None,
+        "doc_type": "EWH" if is_ewh else "Aufgabe",
+        "variant": variant,
+    }
+
+
+def parse_filename_new_format(filename: str, file_path) -> Optional[Dict]:
+    """
+    Parse the new naming convention (2024, 2025).
+
+    Layout: {year}_{subject}_{level}_{no}_EWH.pdf
+    Example: 2025_Deutsch_eA_I_EWH.pdf
+
+    Args:
+        filename: Bare file name that is matched against the pattern.
+        file_path: Path of the file (anything ``str()`` accepts); only used
+            for the "mitExp" variant check.
+
+    Returns:
+        Dict with the keys ``year``, ``subject`` (lower-cased), ``niveau``,
+        ``task_number``, ``doc_type`` and ``variant``, or ``None`` if the
+        file name does not match.
+    """
+    match = _NEW_FORMAT_RE.search(filename)
+    if not match:
+        return None
+
+    year = int(match.group(1))
+    subject_raw = match.group(2).lower()
+    niveau = match.group(3)
+    task_id = match.group(4)
+    suffix = match.group(5) or ""
+
+    # Task number from a roman numeral (I-V) or plain digits. The pattern is
+    # case-insensitive, so normalise before the roman lookup.
+    task_num = None
+    if task_id:
+        task_num = _ROMAN_TASK_NUMBERS.get(task_id.upper()) or (
+            int(task_id) if task_id.isdigit() else None
+        )
+
+    # Default document type; a marker in the suffix overrides it.
+    doc_type = "EWH" if "ewh" in filename.lower() else "Aufgabe"
+    for marker, marker_doc_type in _SUFFIX_DOC_TYPES:
+        if marker in suffix:
+            doc_type = marker_doc_type
+            break
+
+    # "mitExp" in the path takes precedence over "BG" in the file name.
+    variant = "BG" if "BG" in filename else None
+    if "mitExp" in str(file_path):
+        variant = "mitExp"
+
+    return {
+        "year": year,
+        "subject": subject_raw,
+        "niveau": NIVEAU_MAPPING.get(niveau.lower(), niveau),
+        "task_number": task_num,
+        "doc_type": doc_type,
+        "variant": variant,
+    }
diff --git a/klausur-service/backend/nru_worksheet_generator.py b/klausur-service/backend/nru_worksheet_generator.py
index d75715b..e3a79ef 100644
--- a/klausur-service/backend/nru_worksheet_generator.py
+++ b/klausur-service/backend/nru_worksheet_generator.py
@@ -1,557 +1,26 @@
"""
-NRU Worksheet Generator - Generate vocabulary worksheets in NRU format.
+NRU Worksheet Generator — barrel re-export.
-Format:
-- Page 1 (Vokabeln): 3-column table
- - Column 1: English vocabulary
- - Column 2: Empty (child writes German translation)
- - Column 3: Empty (child writes corrected English after parent review)
-
-- Page 2 (Lernsätze): Full-width table
- - Row 1: German sentence (pre-filled)
- - Row 2-3: Empty lines (child writes English translation)
+All implementation split into:
+ nru_worksheet_models — data classes, entry separation
+ nru_worksheet_html — HTML generation
+ nru_worksheet_pdf — PDF generation
Per scanned page, we generate 2 worksheet pages.
"""
-import io
-import logging
-from typing import List, Dict, Tuple
-from dataclasses import dataclass
+# Models
+from nru_worksheet_models import ( # noqa: F401
+ VocabEntry,
+ SentenceEntry,
+ separate_vocab_and_sentences,
+)
-logger = logging.getLogger(__name__)
+# HTML generation
+from nru_worksheet_html import ( # noqa: F401
+ generate_nru_html,
+ generate_nru_worksheet_html,
+)
-
-@dataclass
-class VocabEntry:
- english: str
- german: str
- source_page: int = 1
-
-
-@dataclass
-class SentenceEntry:
- german: str
- english: str # For solution sheet
- source_page: int = 1
-
-
-def separate_vocab_and_sentences(entries: List[Dict]) -> Tuple[List[VocabEntry], List[SentenceEntry]]:
- """
- Separate vocabulary entries into single words/phrases and full sentences.
-
- Sentences are identified by:
- - Ending with punctuation (. ! ?)
- - Being longer than 40 characters
- - Containing multiple words with capital letters mid-sentence
- """
- vocab_list = []
- sentence_list = []
-
- for entry in entries:
- english = entry.get("english", "").strip()
- german = entry.get("german", "").strip()
- source_page = entry.get("source_page", 1)
-
- if not english or not german:
- continue
-
- # Detect if this is a sentence
- is_sentence = (
- english.endswith('.') or
- english.endswith('!') or
- english.endswith('?') or
- len(english) > 50 or
- (len(english.split()) > 5 and any(w[0].isupper() for w in english.split()[1:] if w))
- )
-
- if is_sentence:
- sentence_list.append(SentenceEntry(
- german=german,
- english=english,
- source_page=source_page
- ))
- else:
- vocab_list.append(VocabEntry(
- english=english,
- german=german,
- source_page=source_page
- ))
-
- return vocab_list, sentence_list
-
-
-def generate_nru_html(
- vocab_list: List[VocabEntry],
- sentence_list: List[SentenceEntry],
- page_number: int,
- title: str = "Vokabeltest",
- show_solutions: bool = False,
- line_height_px: int = 28
-) -> str:
- """
- Generate HTML for NRU-format worksheet.
-
- Returns HTML for 2 pages:
- - Page 1: Vocabulary table (3 columns)
- - Page 2: Sentence practice (full width)
- """
-
- # Filter by page
- page_vocab = [v for v in vocab_list if v.source_page == page_number]
- page_sentences = [s for s in sentence_list if s.source_page == page_number]
-
- html = f"""
-
-
-
-
-
-
-"""
-
- # ========== PAGE 1: VOCABULARY TABLE ==========
- if page_vocab:
- html += f"""
-
-
-
-
-
-
- Englisch
- Deutsch
- Korrektur
-
-
-
-"""
- for v in page_vocab:
- if show_solutions:
- html += f"""
-
- {v.english}
- {v.german}
-
-
-"""
- else:
- html += f"""
-
- {v.english}
-
-
-
-"""
-
- html += """
-
-
-
Vokabeln aus Unit
-
-"""
-
- # ========== PAGE 2: SENTENCE PRACTICE ==========
- if page_sentences:
- html += f"""
-
-
-"""
- for s in page_sentences:
- html += f"""
-
-
-
-
-"""
- if show_solutions:
- html += f"""
-
- {s.english}
-
-
-
-
-"""
- else:
- html += """
-
-
-
-
-
-
-"""
- html += """
-
-"""
-
- html += """
-
Lernsaetze aus Unit
-
-"""
-
- html += """
-
-
-"""
- return html
-
-
-def generate_nru_worksheet_html(
- entries: List[Dict],
- title: str = "Vokabeltest",
- show_solutions: bool = False,
- specific_pages: List[int] = None
-) -> str:
- """
- Generate complete NRU worksheet HTML for all pages.
-
- Args:
- entries: List of vocabulary entries with source_page
- title: Worksheet title
- show_solutions: Whether to show answers
- specific_pages: List of specific page numbers to include (1-indexed)
-
- Returns:
- Complete HTML document
- """
- # Separate into vocab and sentences
- vocab_list, sentence_list = separate_vocab_and_sentences(entries)
-
- # Get unique page numbers
- all_pages = set()
- for v in vocab_list:
- all_pages.add(v.source_page)
- for s in sentence_list:
- all_pages.add(s.source_page)
-
- # Filter to specific pages if requested
- if specific_pages:
- all_pages = all_pages.intersection(set(specific_pages))
-
- pages_sorted = sorted(all_pages)
-
- logger.info(f"Generating NRU worksheet for pages {pages_sorted}")
- logger.info(f"Total vocab: {len(vocab_list)}, Total sentences: {len(sentence_list)}")
-
- # Generate HTML for each page
- combined_html = """
-
-
-
-
-
-
-"""
-
- for page_num in pages_sorted:
- page_vocab = [v for v in vocab_list if v.source_page == page_num]
- page_sentences = [s for s in sentence_list if s.source_page == page_num]
-
- # PAGE 1: VOCABULARY TABLE
- if page_vocab:
- combined_html += f"""
-
-
-
-
-
-
- Englisch
- Deutsch
- Korrektur
-
-
-
-"""
- for v in page_vocab:
- if show_solutions:
- combined_html += f"""
-
- {v.english}
- {v.german}
-
-
-"""
- else:
- combined_html += f"""
-
- {v.english}
-
-
-
-"""
-
- combined_html += f"""
-
-
-
{title} - Seite {page_num}
-
-"""
-
- # PAGE 2: SENTENCE PRACTICE
- if page_sentences:
- combined_html += f"""
-
-
-"""
- for s in page_sentences:
- combined_html += f"""
-
-
-
-
-"""
- if show_solutions:
- combined_html += f"""
-
- {s.english}
-
-
-
-
-"""
- else:
- combined_html += """
-
-
-
-
-
-
-"""
- combined_html += """
-
-"""
-
- combined_html += f"""
-
{title} - Seite {page_num}
-
-"""
-
- combined_html += """
-
-
-"""
- return combined_html
-
-
-async def generate_nru_pdf(entries: List[Dict], title: str = "Vokabeltest", include_solutions: bool = True) -> Tuple[bytes, bytes]:
- """
- Generate NRU worksheet PDFs.
-
- Returns:
- Tuple of (worksheet_pdf_bytes, solution_pdf_bytes)
- """
- from weasyprint import HTML
-
- # Generate worksheet HTML
- worksheet_html = generate_nru_worksheet_html(entries, title, show_solutions=False)
- worksheet_pdf = HTML(string=worksheet_html).write_pdf()
-
- # Generate solution HTML
- solution_pdf = None
- if include_solutions:
- solution_html = generate_nru_worksheet_html(entries, title, show_solutions=True)
- solution_pdf = HTML(string=solution_html).write_pdf()
-
- return worksheet_pdf, solution_pdf
+# PDF generation
+from nru_worksheet_pdf import generate_nru_pdf # noqa: F401
diff --git a/klausur-service/backend/nru_worksheet_html.py b/klausur-service/backend/nru_worksheet_html.py
new file mode 100644
index 0000000..8d881de
--- /dev/null
+++ b/klausur-service/backend/nru_worksheet_html.py
@@ -0,0 +1,466 @@
+"""
+NRU Worksheet HTML — HTML generation for vocabulary worksheets.
+
+Extracted from nru_worksheet_generator.py for modularity.
+"""
+
+import logging
+from typing import List, Dict
+
+from nru_worksheet_models import VocabEntry, SentenceEntry, separate_vocab_and_sentences
+
+logger = logging.getLogger(__name__)
+
+
+def generate_nru_html(
+ vocab_list: List[VocabEntry],
+ sentence_list: List[SentenceEntry],
+ page_number: int,
+ title: str = "Vokabeltest",
+ show_solutions: bool = False,
+ line_height_px: int = 28
+) -> str:
+ """
+ Generate HTML for NRU-format worksheet.
+
+ Only entries whose ``source_page`` equals ``page_number`` are rendered.
+
+ Args:
+ vocab_list: Vocabulary entries; used for the vocabulary table.
+ sentence_list: Sentence entries; used for the sentence practice page.
+ page_number: Source page whose entries are rendered.
+ title: Worksheet title.
+ show_solutions: If True, the German translation (vocabulary) and the
+ English sentence (sentence practice) are filled in.
+ line_height_px: Line height in pixels.
+
+ Returns HTML for 2 pages:
+ - Page 1: Vocabulary table (3 columns)
+ - Page 2: Sentence practice (full width)
+ A section is emitted only if the page has entries of that kind.
+
+ NOTE(review): entry texts are interpolated into the template as-is;
+ presumably they should be HTML-escaped (e.g. html.escape) if they can
+ contain markup characters - confirm against the template.
+ """
+
+ # Filter by page
+ page_vocab = [v for v in vocab_list if v.source_page == page_number]
+ page_sentences = [s for s in sentence_list if s.source_page == page_number]
+
+ # Document head / opening markup
+ html = f"""
+
+
+
+
+
+
+"""
+
+ # ========== PAGE 1: VOCABULARY TABLE ==========
+ if page_vocab:
+ html += f"""
+
+
+
+
+
+
+ Englisch
+ Deutsch
+ Korrektur
+
+
+
+"""
+ # One table row per vocabulary entry; the German text is only
+ # included on the solution sheet.
+ for v in page_vocab:
+ if show_solutions:
+ html += f"""
+
+ {v.english}
+ {v.german}
+
+
+"""
+ else:
+ html += f"""
+
+ {v.english}
+
+
+
+"""
+
+ # Close the vocabulary page
+ html += """
+
+
+
Vokabeln aus Unit
+
+"""
+
+ # ========== PAGE 2: SENTENCE PRACTICE ==========
+ if page_sentences:
+ html += f"""
+
+
+"""
+ # One block per sentence; the English text is only included on the
+ # solution sheet, otherwise the block without it is emitted.
+ for s in page_sentences:
+ html += f"""
+
+
+
+
+"""
+ if show_solutions:
+ html += f"""
+
+ {s.english}
+
+
+
+
+"""
+ else:
+ html += """
+
+
+
+
+
+
+"""
+ html += """
+
+"""
+
+ # Close the sentence page
+ html += """
+
Lernsaetze aus Unit
+
+"""
+
+ # Closing markup
+ html += """
+
+
+"""
+ return html
+
+
+def generate_nru_worksheet_html(
+ entries: List[Dict],
+ title: str = "Vokabeltest",
+ show_solutions: bool = False,
+ specific_pages: List[int] = None
+) -> str:
+ """
+ Generate complete NRU worksheet HTML for all pages.
+
+ The entries are first split into vocabulary and sentences via
+ separate_vocab_and_sentences(); each source page is then rendered in
+ ascending page order.
+
+ Args:
+ entries: List of vocabulary entries with source_page
+ title: Worksheet title
+ show_solutions: Whether to show answers
+ specific_pages: List of specific page numbers to include (1-indexed).
+ None or an empty list means "all pages".
+
+ Returns:
+ Complete HTML document
+
+ NOTE(review): entry texts and the title are interpolated into the
+ template as-is; presumably they should be HTML-escaped if they can
+ contain markup characters - confirm against the template.
+ """
+ # Separate into vocab and sentences
+ vocab_list, sentence_list = separate_vocab_and_sentences(entries)
+
+ # Get unique page numbers
+ all_pages = set()
+ for v in vocab_list:
+ all_pages.add(v.source_page)
+ for s in sentence_list:
+ all_pages.add(s.source_page)
+
+ # Filter to specific pages if requested
+ if specific_pages:
+ all_pages = all_pages.intersection(set(specific_pages))
+
+ pages_sorted = sorted(all_pages)
+
+ logger.info(f"Generating NRU worksheet for pages {pages_sorted}")
+ logger.info(f"Total vocab: {len(vocab_list)}, Total sentences: {len(sentence_list)}")
+
+ # Generate HTML for each page
+ combined_html = """
+
+
+
+
+
+
+"""
+
+ for page_num in pages_sorted:
+ page_vocab = [v for v in vocab_list if v.source_page == page_num]
+ page_sentences = [s for s in sentence_list if s.source_page == page_num]
+
+ # PAGE 1: VOCABULARY TABLE
+ if page_vocab:
+ combined_html += f"""
+
+
+
+
+
+
+ Englisch
+ Deutsch
+ Korrektur
+
+
+
+"""
+ # One table row per vocabulary entry; the German text is only
+ # included on the solution sheet.
+ for v in page_vocab:
+ if show_solutions:
+ combined_html += f"""
+
+ {v.english}
+ {v.german}
+
+
+"""
+ else:
+ combined_html += f"""
+
+ {v.english}
+
+
+
+"""
+
+ # Close the vocabulary page (title and page number appear here)
+ combined_html += f"""
+
+
+
{title} - Seite {page_num}
+
+"""
+
+ # PAGE 2: SENTENCE PRACTICE
+ if page_sentences:
+ combined_html += f"""
+
+
+"""
+ # One block per sentence; the English text is only included on
+ # the solution sheet.
+ for s in page_sentences:
+ combined_html += f"""
+
+
+
+
+"""
+ if show_solutions:
+ combined_html += f"""
+
+ {s.english}
+
+
+
+
+"""
+ else:
+ combined_html += """
+
+
+
+
+
+
+"""
+ combined_html += """
+
+"""
+
+ # Close the sentence page (title and page number appear here)
+ combined_html += f"""
+
{title} - Seite {page_num}
+
+"""
+
+ # Closing markup
+ combined_html += """
+
+
+"""
+ return combined_html
diff --git a/klausur-service/backend/nru_worksheet_models.py b/klausur-service/backend/nru_worksheet_models.py
new file mode 100644
index 0000000..1276bfe
--- /dev/null
+++ b/klausur-service/backend/nru_worksheet_models.py
@@ -0,0 +1,70 @@
+"""
+NRU Worksheet Models — data classes and entry separation logic.
+
+Extracted from nru_worksheet_generator.py for modularity.
+"""
+
+import logging
+from typing import List, Dict, Tuple
+from dataclasses import dataclass
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class VocabEntry:
+    """A single word or short phrase for the vocabulary table."""
+
+    english: str
+    german: str
+    source_page: int = 1
+
+
+@dataclass
+class SentenceEntry:
+    """A full sentence for the sentence-practice page."""
+
+    german: str
+    english: str  # For solution sheet
+    source_page: int = 1
+
+
+def separate_vocab_and_sentences(entries: List[Dict]) -> Tuple[List[VocabEntry], List[SentenceEntry]]:
+    """
+    Separate vocabulary entries into single words/phrases and full sentences.
+
+    An entry is treated as a sentence if its English text:
+    - ends with punctuation (. ! ?), or
+    - is longer than 50 characters, or
+    - has more than 5 words and a capitalised word after the first one.
+
+    Entries whose English or German text is missing, None or blank are
+    skipped. A missing or None ``source_page`` defaults to 1.
+
+    Args:
+        entries: Dicts with the keys "english", "german" and optionally
+            "source_page".
+
+    Returns:
+        Tuple ``(vocab_list, sentence_list)`` in input order.
+    """
+    vocab_list: List[VocabEntry] = []
+    sentence_list: List[SentenceEntry] = []
+
+    for entry in entries:
+        # "or ''" tolerates keys that are present but hold None.
+        english = (entry.get("english") or "").strip()
+        german = (entry.get("german") or "").strip()
+        source_page = entry.get("source_page")
+        if source_page is None:
+            source_page = 1
+
+        if not english or not german:
+            continue
+
+        words = english.split()
+        is_sentence = (
+            english.endswith((".", "!", "?"))
+            or len(english) > 50
+            or (len(words) > 5 and any(w[0].isupper() for w in words[1:]))
+        )
+
+        if is_sentence:
+            sentence_list.append(SentenceEntry(
+                german=german,
+                english=english,
+                source_page=source_page,
+            ))
+        else:
+            vocab_list.append(VocabEntry(
+                english=english,
+                german=german,
+                source_page=source_page,
+            ))
+
+    return vocab_list, sentence_list
diff --git a/klausur-service/backend/nru_worksheet_pdf.py b/klausur-service/backend/nru_worksheet_pdf.py
new file mode 100644
index 0000000..ceebc1a
--- /dev/null
+++ b/klausur-service/backend/nru_worksheet_pdf.py
@@ -0,0 +1,31 @@
+"""
+NRU Worksheet PDF — PDF generation using weasyprint.
+
+Extracted from nru_worksheet_generator.py for modularity.
+"""
+
+from typing import List, Dict, Tuple
+
+from nru_worksheet_html import generate_nru_worksheet_html
+
+
+async def generate_nru_pdf(entries: List[Dict], title: str = "Vokabeltest", include_solutions: bool = True) -> Tuple[bytes, bytes]:
+    """
+    Render the NRU worksheet, and optionally its solution sheet, as PDF.
+
+    Args:
+        entries: Vocabulary entries passed on to generate_nru_worksheet_html().
+        title: Worksheet title.
+        include_solutions: Whether the solution PDF is rendered as well.
+
+    Returns:
+        Tuple of (worksheet_pdf_bytes, solution_pdf_bytes); the second item
+        is None when ``include_solutions`` is False.
+    """
+    # Imported lazily so the module can be loaded without weasyprint.
+    from weasyprint import HTML
+
+    def _render(with_solutions: bool):
+        # Build the HTML for one variant and convert it to PDF bytes.
+        markup = generate_nru_worksheet_html(entries, title, show_solutions=with_solutions)
+        return HTML(string=markup).write_pdf()
+
+    worksheet_pdf = _render(False)
+    solution_pdf = _render(True) if include_solutions else None
+
+    return worksheet_pdf, solution_pdf
diff --git a/klausur-service/backend/ocr_pipeline_overlay_grid.py b/klausur-service/backend/ocr_pipeline_overlay_grid.py
new file mode 100644
index 0000000..769ef0f
--- /dev/null
+++ b/klausur-service/backend/ocr_pipeline_overlay_grid.py
@@ -0,0 +1,333 @@
+"""
+Overlay rendering for columns, rows, and words (grid-based overlays).
+
+Extracted from ocr_pipeline_overlays.py for modularity.
+
+Lizenz: Apache 2.0
+DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
+"""
+
+import logging
+from typing import Any, Dict, List
+
+import cv2
+import numpy as np
+from fastapi import HTTPException
+from fastapi.responses import Response
+
+from ocr_pipeline_common import _get_base_image_png
+from ocr_pipeline_session_store import get_session_db
+from ocr_pipeline_rows import _draw_box_exclusion_overlay
+
+logger = logging.getLogger(__name__)
+
+
+async def _get_columns_overlay(session_id: str) -> Response:
+ """Generate cropped (or dewarped) image with column borders drawn on it.
+
+ Reads ``column_result`` from the session, draws every column as a
+ tinted rectangle with a type/confidence label, outlines detected box
+ zones with dashed lines and returns the result as PNG.
+
+ Args:
+ session_id: ID of the OCR pipeline session.
+
+ Returns:
+ Response with media type "image/png".
+
+ Raises:
+ HTTPException: 404 if the session, the column data or the base image
+ is missing; 500 if the image cannot be decoded or encoded.
+ """
+ session = await get_session_db(session_id)
+ if not session:
+ raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
+
+ column_result = session.get("column_result")
+ if not column_result or not column_result.get("columns"):
+ raise HTTPException(status_code=404, detail="No column data available")
+
+ # Load best available base image (cropped > dewarped > original)
+ base_png = await _get_base_image_png(session_id)
+ if not base_png:
+ raise HTTPException(status_code=404, detail="No base image available")
+
+ arr = np.frombuffer(base_png, dtype=np.uint8)
+ img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
+ if img is None:
+ raise HTTPException(status_code=500, detail="Failed to decode image")
+
+ # Color map for region types (BGR)
+ colors = {
+ "column_en": (255, 180, 0), # Blue
+ "column_de": (0, 200, 0), # Green
+ "column_example": (0, 140, 255), # Orange
+ "column_text": (200, 200, 0), # Cyan/Turquoise
+ "page_ref": (200, 0, 200), # Purple
+ "column_marker": (0, 0, 220), # Red
+ "column_ignore": (180, 180, 180), # Light Gray
+ "header": (128, 128, 128), # Gray
+ "footer": (128, 128, 128), # Gray
+ "margin_top": (100, 100, 100), # Dark Gray
+ "margin_bottom": (100, 100, 100), # Dark Gray
+ }
+
+ # Fills go onto a copy, borders and labels onto the image itself; both
+ # are blended afterwards. Unknown types fall back to light gray.
+ overlay = img.copy()
+ for col in column_result["columns"]:
+ x, y = col["x"], col["y"]
+ w, h = col["width"], col["height"]
+ color = colors.get(col.get("type", ""), (200, 200, 200))
+
+ # Semi-transparent fill
+ cv2.rectangle(overlay, (x, y), (x + w, y + h), color, -1)
+
+ # Solid border
+ cv2.rectangle(img, (x, y), (x + w, y + h), color, 3)
+
+ # Label with confidence
+ label = col.get("type", "unknown").replace("column_", "").upper()
+ conf = col.get("classification_confidence")
+ # The percentage is only appended for confidences below 1.0
+ if conf is not None and conf < 1.0:
+ label = f"{label} {int(conf * 100)}%"
+ cv2.putText(img, label, (x + 10, y + 30),
+ cv2.FONT_HERSHEY_SIMPLEX, 0.8, color, 2)
+
+ # Blend overlay at 20% opacity
+ cv2.addWeighted(overlay, 0.2, img, 0.8, 0, img)
+
+ # Draw detected box boundaries as dashed rectangles
+ zones = column_result.get("zones") or []
+ for zone in zones:
+ if zone.get("zone_type") == "box" and zone.get("box"):
+ box = zone["box"]
+ bx, by = box["x"], box["y"]
+ bw, bh = box["width"], box["height"]
+ box_color = (0, 200, 255) # Yellow (BGR)
+ # Draw dashed rectangle by drawing short line segments
+ dash_len = 15
+ # Top and bottom edges: one dash, then a gap of the same length
+ for edge_x in range(bx, bx + bw, dash_len * 2):
+ end_x = min(edge_x + dash_len, bx + bw)
+ cv2.line(img, (edge_x, by), (end_x, by), box_color, 2)
+ cv2.line(img, (edge_x, by + bh), (end_x, by + bh), box_color, 2)
+ # Left and right edges
+ for edge_y in range(by, by + bh, dash_len * 2):
+ end_y = min(edge_y + dash_len, by + bh)
+ cv2.line(img, (bx, edge_y), (bx, end_y), box_color, 2)
+ cv2.line(img, (bx + bw, edge_y), (bx + bw, end_y), box_color, 2)
+ cv2.putText(img, "BOX", (bx + 10, by + bh - 10),
+ cv2.FONT_HERSHEY_SIMPLEX, 0.7, box_color, 2)
+
+ # Red semi-transparent overlay for box zones
+ _draw_box_exclusion_overlay(img, zones)
+
+ success, result_png = cv2.imencode(".png", img)
+ if not success:
+ raise HTTPException(status_code=500, detail="Failed to encode overlay image")
+
+ return Response(content=result_png.tobytes(), media_type="image/png")
+
+
+async def _get_rows_overlay(session_id: str) -> Response:
+ """Generate cropped (or dewarped) image with row bands drawn on it.
+
+ Reads ``row_result`` from the session, draws every row as a tinted band
+ with an index/type/word-count label, marks the top and bottom of box
+ zones with dashed lines and returns the result as PNG.
+
+ Args:
+ session_id: ID of the OCR pipeline session.
+
+ Returns:
+ Response with media type "image/png".
+
+ Raises:
+ HTTPException: 404 if the session, the row data or the base image is
+ missing; 500 if the image cannot be decoded or encoded.
+ """
+ session = await get_session_db(session_id)
+ if not session:
+ raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
+
+ row_result = session.get("row_result")
+ if not row_result or not row_result.get("rows"):
+ raise HTTPException(status_code=404, detail="No row data available")
+
+ # Load best available base image (cropped > dewarped > original)
+ base_png = await _get_base_image_png(session_id)
+ if not base_png:
+ raise HTTPException(status_code=404, detail="No base image available")
+
+ arr = np.frombuffer(base_png, dtype=np.uint8)
+ img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
+ if img is None:
+ raise HTTPException(status_code=500, detail="Failed to decode image")
+
+ # Color map for row types (BGR)
+ row_colors = {
+ "content": (255, 180, 0), # Blue
+ "header": (128, 128, 128), # Gray
+ "footer": (128, 128, 128), # Gray
+ "margin_top": (100, 100, 100), # Dark Gray
+ "margin_bottom": (100, 100, 100), # Dark Gray
+ }
+
+ # Fills go onto a copy, borders and labels onto the image itself; both
+ # are blended afterwards. Unknown row types fall back to light gray.
+ overlay = img.copy()
+ for row in row_result["rows"]:
+ x, y = row["x"], row["y"]
+ w, h = row["width"], row["height"]
+ row_type = row.get("row_type", "content")
+ color = row_colors.get(row_type, (200, 200, 200))
+
+ # Semi-transparent fill
+ cv2.rectangle(overlay, (x, y), (x + w, y + h), color, -1)
+
+ # Solid border
+ cv2.rectangle(img, (x, y), (x + w, y + h), color, 2)
+
+ # Label
+ idx = row.get("index", 0)
+ label = f"R{idx} {row_type.upper()}"
+ wc = row.get("word_count", 0)
+ # The word count is only appended when it is non-zero
+ if wc:
+ label = f"{label} ({wc}w)"
+ cv2.putText(img, label, (x + 5, y + 18),
+ cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1)
+
+ # Blend overlay at 15% opacity
+ cv2.addWeighted(overlay, 0.15, img, 0.85, 0, img)
+
+ # Draw zone separator lines if zones exist
+ column_result = session.get("column_result") or {}
+ zones = column_result.get("zones") or []
+ if zones:
+ img_w_px = img.shape[1]
+ zone_color = (0, 200, 255) # Yellow (BGR)
+ dash_len = 20
+ for zone in zones:
+ if zone.get("zone_type") == "box":
+ zy = zone["y"]
+ zh = zone["height"]
+ # Dashed line across the full image width at the top and
+ # bottom of the zone
+ for line_y in [zy, zy + zh]:
+ for sx in range(0, img_w_px, dash_len * 2):
+ ex = min(sx + dash_len, img_w_px)
+ cv2.line(img, (sx, line_y), (ex, line_y), zone_color, 2)
+
+ # Red semi-transparent overlay for box zones
+ _draw_box_exclusion_overlay(img, zones)
+
+ success, result_png = cv2.imencode(".png", img)
+ if not success:
+ raise HTTPException(status_code=500, detail="Failed to encode overlay image")
+
+ return Response(content=result_png.tobytes(), media_type="image/png")
+
+
+async def _get_words_overlay(session_id: str) -> Response:
+ """Generate cropped (or dewarped) image with cell grid drawn on it.
+
+ Uses ``word_result["cells"]`` when present (one rectangle per cell,
+ colored by column index). Otherwise falls back to the legacy
+ ``word_result["entries"]`` format, where the grid is rebuilt from the
+ session's column and row results.
+
+ Args:
+ session_id: ID of the OCR pipeline session.
+
+ Returns:
+ Response with media type "image/png".
+
+ Raises:
+ HTTPException: 404 if the session, the word data or the base image is
+ missing; 500 if the image cannot be decoded or encoded.
+ """
+ session = await get_session_db(session_id)
+ if not session:
+ raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
+
+ word_result = session.get("word_result")
+ if not word_result:
+ raise HTTPException(status_code=404, detail="No word data available")
+
+ # Support both new cell-based and legacy entry-based formats
+ cells = word_result.get("cells")
+ if not cells and not word_result.get("entries"):
+ raise HTTPException(status_code=404, detail="No word data available")
+
+ # Load best available base image (cropped > dewarped > original)
+ base_png = await _get_base_image_png(session_id)
+ if not base_png:
+ raise HTTPException(status_code=404, detail="No base image available")
+
+ arr = np.frombuffer(base_png, dtype=np.uint8)
+ img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
+ if img is None:
+ raise HTTPException(status_code=500, detail="Failed to decode image")
+
+ # NOTE(review): img_h / img_w are not used further down in this function
+ img_h, img_w = img.shape[:2]
+
+ # Fills go onto a copy, borders and text onto the image itself; both are
+ # blended after the branch below.
+ overlay = img.copy()
+
+ if cells:
+ # New cell-based overlay: color by column index
+ col_palette = [
+ (255, 180, 0), # Blue (BGR)
+ (0, 200, 0), # Green
+ (0, 140, 255), # Orange
+ (200, 100, 200), # Purple
+ (200, 200, 0), # Cyan
+ (100, 200, 200), # Yellow-ish
+ ]
+
+ for cell in cells:
+ bbox = cell.get("bbox_px", {})
+ cx = bbox.get("x", 0)
+ cy = bbox.get("y", 0)
+ cw = bbox.get("w", 0)
+ ch = bbox.get("h", 0)
+ # Skip cells without a usable bounding box
+ if cw <= 0 or ch <= 0:
+ continue
+
+ # The palette repeats for column indices beyond its length
+ col_idx = cell.get("col_index", 0)
+ color = col_palette[col_idx % len(col_palette)]
+
+ # Cell rectangle border
+ cv2.rectangle(img, (cx, cy), (cx + cw, cy + ch), color, 1)
+ # Semi-transparent fill
+ cv2.rectangle(overlay, (cx, cy), (cx + cw, cy + ch), color, -1)
+
+ # Cell-ID label (top-left corner)
+ cell_id = cell.get("cell_id", "")
+ cv2.putText(img, cell_id, (cx + 2, cy + 10),
+ cv2.FONT_HERSHEY_SIMPLEX, 0.28, color, 1)
+
+ # Text label (bottom of cell)
+ text = cell.get("text", "")
+ if text:
+ # Text color by confidence (BGR): >= 70, >= 50, below 50
+ conf = cell.get("confidence", 0)
+ if conf >= 70:
+ text_color = (0, 180, 0)
+ elif conf >= 50:
+ text_color = (0, 180, 220)
+ else:
+ text_color = (0, 0, 220)
+
+ # Single line, at most 30 characters
+ label = text.replace('\n', ' ')[:30]
+ cv2.putText(img, label, (cx + 3, cy + ch - 4),
+ cv2.FONT_HERSHEY_SIMPLEX, 0.35, text_color, 1)
+ else:
+ # Legacy fallback: entry-based overlay (for old sessions)
+ column_result = session.get("column_result")
+ row_result = session.get("row_result")
+ col_colors = {
+ "column_en": (255, 180, 0),
+ "column_de": (0, 200, 0),
+ "column_example": (0, 140, 255),
+ }
+
+ # Only regions whose type starts with "column_" form the grid
+ columns = []
+ if column_result and column_result.get("columns"):
+ columns = [c for c in column_result["columns"]
+ if c.get("type", "").startswith("column_")]
+
+ # Only rows of type "content" form the grid
+ content_rows_data = []
+ if row_result and row_result.get("rows"):
+ content_rows_data = [r for r in row_result["rows"]
+ if r.get("row_type") == "content"]
+
+ # Grid: one rectangle per (column, content row) pair
+ for col in columns:
+ col_type = col.get("type", "")
+ color = col_colors.get(col_type, (200, 200, 200))
+ cx, cw = col["x"], col["width"]
+ for row in content_rows_data:
+ ry, rh = row["y"], row["height"]
+ cv2.rectangle(img, (cx, ry), (cx + cw, ry + rh), color, 1)
+ cv2.rectangle(overlay, (cx, ry), (cx + cw, ry + rh), color, -1)
+
+ # Index entries by row_index; entries are matched to content rows by
+ # the row's position in content_rows_data.
+ entries = word_result["entries"]
+ entry_by_row: Dict[int, Dict] = {}
+ for entry in entries:
+ entry_by_row[entry.get("row_index", -1)] = entry
+
+ for row_idx, row in enumerate(content_rows_data):
+ entry = entry_by_row.get(row_idx)
+ if not entry:
+ continue
+ conf = entry.get("confidence", 0)
+ text_color = (0, 180, 0) if conf >= 70 else (0, 180, 220) if conf >= 50 else (0, 0, 220)
+ ry, rh = row["y"], row["height"]
+ for col in columns:
+ col_type = col.get("type", "")
+ cx, cw = col["x"], col["width"]
+ # Column type -> entry field; other column types get no text
+ field = {"column_en": "english", "column_de": "german", "column_example": "example"}.get(col_type, "")
+ text = entry.get(field, "") if field else ""
+ if text:
+ label = text.replace('\n', ' ')[:30]
+ cv2.putText(img, label, (cx + 3, ry + rh - 4),
+ cv2.FONT_HERSHEY_SIMPLEX, 0.35, text_color, 1)
+
+ # Blend overlay at 10% opacity
+ cv2.addWeighted(overlay, 0.1, img, 0.9, 0, img)
+
+ # Red semi-transparent overlay for box zones
+ column_result = session.get("column_result") or {}
+ zones = column_result.get("zones") or []
+ _draw_box_exclusion_overlay(img, zones)
+
+ success, result_png = cv2.imencode(".png", img)
+ if not success:
+ raise HTTPException(status_code=500, detail="Failed to encode overlay image")
+
+ return Response(content=result_png.tobytes(), media_type="image/png")
diff --git a/klausur-service/backend/ocr_pipeline_overlay_structure.py b/klausur-service/backend/ocr_pipeline_overlay_structure.py
new file mode 100644
index 0000000..ad48382
--- /dev/null
+++ b/klausur-service/backend/ocr_pipeline_overlay_structure.py
@@ -0,0 +1,205 @@
+"""
+Overlay rendering for structure detection (boxes, zones, colors, graphics).
+
+Extracted from ocr_pipeline_overlays.py for modularity.
+
+Lizenz: Apache 2.0
+DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
+"""
+
+import logging
+from typing import Any, Dict, List
+
+import cv2
+import numpy as np
+from fastapi import HTTPException
+from fastapi.responses import Response
+
+from ocr_pipeline_common import _get_base_image_png
+from ocr_pipeline_session_store import get_session_db
+from cv_color_detect import _COLOR_HEX, _COLOR_RANGES
+from cv_box_detect import detect_boxes, split_page_into_zones
+
+logger = logging.getLogger(__name__)
+
+
+async def _get_structure_overlay(session_id: str) -> Response:
+ """Generate overlay image showing detected boxes, zones, and color regions."""
+ base_png = await _get_base_image_png(session_id)
+ if not base_png:
+ raise HTTPException(status_code=404, detail="No base image available")
+
+ arr = np.frombuffer(base_png, dtype=np.uint8)
+ img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
+ if img is None:
+ raise HTTPException(status_code=500, detail="Failed to decode image")
+
+ h, w = img.shape[:2]
+
+ # Get structure result (run detection if not cached)
+ session = await get_session_db(session_id)
+ structure = (session or {}).get("structure_result")
+
+ if not structure:
+ # Run detection on-the-fly
+ margin = int(min(w, h) * 0.03)
+ content_x, content_y = margin, margin
+ content_w_px = w - 2 * margin
+ content_h_px = h - 2 * margin
+ boxes = detect_boxes(img, content_x, content_w_px, content_y, content_h_px)
+ zones = split_page_into_zones(content_x, content_y, content_w_px, content_h_px, boxes)
+ structure = {
+ "boxes": [
+ {"x": b.x, "y": b.y, "w": b.width, "h": b.height,
+ "confidence": b.confidence, "border_thickness": b.border_thickness}
+ for b in boxes
+ ],
+ "zones": [
+ {"index": z.index, "zone_type": z.zone_type,
+ "y": z.y, "h": z.height, "x": z.x, "w": z.width}
+ for z in zones
+ ],
+ }
+
+ overlay = img.copy()
+
+ # --- Draw zone boundaries ---
+ zone_colors = {
+ "content": (200, 200, 200), # light gray
+ "box": (255, 180, 0), # blue-ish (BGR)
+ }
+ for zone in structure.get("zones", []):
+ zx = zone["x"]
+ zy = zone["y"]
+ zw = zone["w"]
+ zh = zone["h"]
+ color = zone_colors.get(zone["zone_type"], (200, 200, 200))
+
+ # Draw zone boundary as dashed line
+ dash_len = 12
+ for edge_x in range(zx, zx + zw, dash_len * 2):
+ end_x = min(edge_x + dash_len, zx + zw)
+ cv2.line(img, (edge_x, zy), (end_x, zy), color, 1)
+ cv2.line(img, (edge_x, zy + zh), (end_x, zy + zh), color, 1)
+
+ # Zone label
+ zone_label = f"Zone {zone['index']} ({zone['zone_type']})"
+ cv2.putText(img, zone_label, (zx + 5, zy + 15),
+ cv2.FONT_HERSHEY_SIMPLEX, 0.45, color, 1)
+
+ # --- Draw detected boxes ---
+ # Color map for box backgrounds (BGR)
+ bg_hex_to_bgr = {
+ "#dc2626": (38, 38, 220), # red
+ "#2563eb": (235, 99, 37), # blue
+ "#16a34a": (74, 163, 22), # green
+ "#ea580c": (12, 88, 234), # orange
+ "#9333ea": (234, 51, 147), # purple
+ "#ca8a04": (4, 138, 202), # yellow
+ "#6b7280": (128, 114, 107), # gray
+ }
+
+ for box_data in structure.get("boxes", []):
+ bx = box_data["x"]
+ by = box_data["y"]
+ bw = box_data["w"]
+ bh = box_data["h"]
+ conf = box_data.get("confidence", 0)
+ thickness = box_data.get("border_thickness", 0)
+ bg_hex = box_data.get("bg_color_hex", "#6b7280")
+ bg_name = box_data.get("bg_color_name", "")
+
+ # Box fill color
+ fill_bgr = bg_hex_to_bgr.get(bg_hex, (128, 114, 107))
+
+ # Semi-transparent fill
+ cv2.rectangle(overlay, (bx, by), (bx + bw, by + bh), fill_bgr, -1)
+
+ # Solid border
+ border_color = fill_bgr
+ cv2.rectangle(img, (bx, by), (bx + bw, by + bh), border_color, 3)
+
+ # Label
+ label = f"BOX"
+ if bg_name and bg_name not in ("unknown", "white"):
+ label += f" ({bg_name})"
+ if thickness > 0:
+ label += f" border={thickness}px"
+ label += f" {int(conf * 100)}%"
+ cv2.putText(img, label, (bx + 8, by + 22),
+ cv2.FONT_HERSHEY_SIMPLEX, 0.55, (255, 255, 255), 2)
+ cv2.putText(img, label, (bx + 8, by + 22),
+ cv2.FONT_HERSHEY_SIMPLEX, 0.55, border_color, 1)
+
+ # Blend overlay at 15% opacity
+ cv2.addWeighted(overlay, 0.15, img, 0.85, 0, img)
+
+ # --- Draw color regions (HSV masks) ---
+ hsv = cv2.cvtColor(
+ cv2.imdecode(np.frombuffer(base_png, dtype=np.uint8), cv2.IMREAD_COLOR),
+ cv2.COLOR_BGR2HSV,
+ )
+ color_bgr_map = {
+ "red": (0, 0, 255),
+ "orange": (0, 140, 255),
+ "yellow": (0, 200, 255),
+ "green": (0, 200, 0),
+ "blue": (255, 150, 0),
+ "purple": (200, 0, 200),
+ }
+ for color_name, ranges in _COLOR_RANGES.items():
+ mask = np.zeros((h, w), dtype=np.uint8)
+ for lower, upper in ranges:
+ mask = cv2.bitwise_or(mask, cv2.inRange(hsv, lower, upper))
+ # Only draw if there are significant colored pixels
+ if np.sum(mask > 0) < 100:
+ continue
+ # Draw colored contours
+ contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
+ draw_color = color_bgr_map.get(color_name, (200, 200, 200))
+ for cnt in contours:
+ area = cv2.contourArea(cnt)
+ if area < 20:
+ continue
+ cv2.drawContours(img, [cnt], -1, draw_color, 2)
+
+ # --- Draw graphic elements ---
+ graphics_data = structure.get("graphics", [])
+ shape_icons = {
+ "image": "IMAGE",
+ "illustration": "ILLUST",
+ }
+ for gfx in graphics_data:
+ gx, gy = gfx["x"], gfx["y"]
+ gw, gh = gfx["w"], gfx["h"]
+ shape = gfx.get("shape", "icon")
+ color_hex = gfx.get("color_hex", "#6b7280")
+ conf = gfx.get("confidence", 0)
+
+ # Pick draw color based on element color (BGR)
+ gfx_bgr = bg_hex_to_bgr.get(color_hex, (128, 114, 107))
+
+ # Draw bounding box (dashed style via short segments)
+ dash = 6
+ for seg_x in range(gx, gx + gw, dash * 2):
+ end_x = min(seg_x + dash, gx + gw)
+ cv2.line(img, (seg_x, gy), (end_x, gy), gfx_bgr, 2)
+ cv2.line(img, (seg_x, gy + gh), (end_x, gy + gh), gfx_bgr, 2)
+ for seg_y in range(gy, gy + gh, dash * 2):
+ end_y = min(seg_y + dash, gy + gh)
+ cv2.line(img, (gx, seg_y), (gx, end_y), gfx_bgr, 2)
+ cv2.line(img, (gx + gw, seg_y), (gx + gw, end_y), gfx_bgr, 2)
+
+ # Label
+ icon = shape_icons.get(shape, shape.upper()[:5])
+ label = f"{icon} {int(conf * 100)}%"
+ # White background for readability
+ (tw, th), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.4, 1)
+ lx = gx + 2
+ ly = max(gy - 4, th + 4)
+ cv2.rectangle(img, (lx - 1, ly - th - 2), (lx + tw + 2, ly + 3), (255, 255, 255), -1)
+ cv2.putText(img, label, (lx, ly), cv2.FONT_HERSHEY_SIMPLEX, 0.4, gfx_bgr, 1)
+
+ # Encode result
+ _, png_buf = cv2.imencode(".png", img)
+ return Response(content=png_buf.tobytes(), media_type="image/png")
diff --git a/klausur-service/backend/ocr_pipeline_overlays.py b/klausur-service/backend/ocr_pipeline_overlays.py
index 2789557..7a30f9b 100644
--- a/klausur-service/backend/ocr_pipeline_overlays.py
+++ b/klausur-service/backend/ocr_pipeline_overlays.py
@@ -1,34 +1,23 @@
"""
-Overlay image rendering for OCR pipeline.
+Overlay image rendering for OCR pipeline — barrel re-export.
-Generates visual overlays for structure, columns, rows, and words
-detection results.
+All implementation split into:
+ ocr_pipeline_overlay_structure — structure overlay (boxes, zones, colors, graphics)
+ ocr_pipeline_overlay_grid — columns, rows, words overlays
Lizenz: Apache 2.0
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
"""
-import logging
-from dataclasses import asdict
-from typing import Any, Dict, List, Optional
-
-import cv2
-import numpy as np
from fastapi import HTTPException
from fastapi.responses import Response
-from ocr_pipeline_common import (
- _cache,
- _get_base_image_png,
- _load_session_to_cache,
- _get_cached,
+from ocr_pipeline_overlay_structure import _get_structure_overlay # noqa: F401
+from ocr_pipeline_overlay_grid import ( # noqa: F401
+ _get_columns_overlay,
+ _get_rows_overlay,
+ _get_words_overlay,
)
-from ocr_pipeline_session_store import get_session_db, get_session_image
-from cv_color_detect import _COLOR_HEX, _COLOR_RANGES
-from cv_box_detect import detect_boxes, split_page_into_zones
-from ocr_pipeline_rows import _draw_box_exclusion_overlay
-
-logger = logging.getLogger(__name__)
async def render_overlay(overlay_type: str, session_id: str) -> Response:
@@ -43,505 +32,3 @@ async def render_overlay(overlay_type: str, session_id: str) -> Response:
return await _get_words_overlay(session_id)
else:
raise HTTPException(status_code=400, detail=f"Unknown overlay type: {overlay_type}")
-
-
-async def _get_structure_overlay(session_id: str) -> Response:
- """Generate overlay image showing detected boxes, zones, and color regions."""
- base_png = await _get_base_image_png(session_id)
- if not base_png:
- raise HTTPException(status_code=404, detail="No base image available")
-
- arr = np.frombuffer(base_png, dtype=np.uint8)
- img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
- if img is None:
- raise HTTPException(status_code=500, detail="Failed to decode image")
-
- h, w = img.shape[:2]
-
- # Get structure result (run detection if not cached)
- session = await get_session_db(session_id)
- structure = (session or {}).get("structure_result")
-
- if not structure:
- # Run detection on-the-fly
- margin = int(min(w, h) * 0.03)
- content_x, content_y = margin, margin
- content_w_px = w - 2 * margin
- content_h_px = h - 2 * margin
- boxes = detect_boxes(img, content_x, content_w_px, content_y, content_h_px)
- zones = split_page_into_zones(content_x, content_y, content_w_px, content_h_px, boxes)
- structure = {
- "boxes": [
- {"x": b.x, "y": b.y, "w": b.width, "h": b.height,
- "confidence": b.confidence, "border_thickness": b.border_thickness}
- for b in boxes
- ],
- "zones": [
- {"index": z.index, "zone_type": z.zone_type,
- "y": z.y, "h": z.height, "x": z.x, "w": z.width}
- for z in zones
- ],
- }
-
- overlay = img.copy()
-
- # --- Draw zone boundaries ---
- zone_colors = {
- "content": (200, 200, 200), # light gray
- "box": (255, 180, 0), # blue-ish (BGR)
- }
- for zone in structure.get("zones", []):
- zx = zone["x"]
- zy = zone["y"]
- zw = zone["w"]
- zh = zone["h"]
- color = zone_colors.get(zone["zone_type"], (200, 200, 200))
-
- # Draw zone boundary as dashed line
- dash_len = 12
- for edge_x in range(zx, zx + zw, dash_len * 2):
- end_x = min(edge_x + dash_len, zx + zw)
- cv2.line(img, (edge_x, zy), (end_x, zy), color, 1)
- cv2.line(img, (edge_x, zy + zh), (end_x, zy + zh), color, 1)
-
- # Zone label
- zone_label = f"Zone {zone['index']} ({zone['zone_type']})"
- cv2.putText(img, zone_label, (zx + 5, zy + 15),
- cv2.FONT_HERSHEY_SIMPLEX, 0.45, color, 1)
-
- # --- Draw detected boxes ---
- # Color map for box backgrounds (BGR)
- bg_hex_to_bgr = {
- "#dc2626": (38, 38, 220), # red
- "#2563eb": (235, 99, 37), # blue
- "#16a34a": (74, 163, 22), # green
- "#ea580c": (12, 88, 234), # orange
- "#9333ea": (234, 51, 147), # purple
- "#ca8a04": (4, 138, 202), # yellow
- "#6b7280": (128, 114, 107), # gray
- }
-
- for box_data in structure.get("boxes", []):
- bx = box_data["x"]
- by = box_data["y"]
- bw = box_data["w"]
- bh = box_data["h"]
- conf = box_data.get("confidence", 0)
- thickness = box_data.get("border_thickness", 0)
- bg_hex = box_data.get("bg_color_hex", "#6b7280")
- bg_name = box_data.get("bg_color_name", "")
-
- # Box fill color
- fill_bgr = bg_hex_to_bgr.get(bg_hex, (128, 114, 107))
-
- # Semi-transparent fill
- cv2.rectangle(overlay, (bx, by), (bx + bw, by + bh), fill_bgr, -1)
-
- # Solid border
- border_color = fill_bgr
- cv2.rectangle(img, (bx, by), (bx + bw, by + bh), border_color, 3)
-
- # Label
- label = f"BOX"
- if bg_name and bg_name not in ("unknown", "white"):
- label += f" ({bg_name})"
- if thickness > 0:
- label += f" border={thickness}px"
- label += f" {int(conf * 100)}%"
- cv2.putText(img, label, (bx + 8, by + 22),
- cv2.FONT_HERSHEY_SIMPLEX, 0.55, (255, 255, 255), 2)
- cv2.putText(img, label, (bx + 8, by + 22),
- cv2.FONT_HERSHEY_SIMPLEX, 0.55, border_color, 1)
-
- # Blend overlay at 15% opacity
- cv2.addWeighted(overlay, 0.15, img, 0.85, 0, img)
-
- # --- Draw color regions (HSV masks) ---
- hsv = cv2.cvtColor(
- cv2.imdecode(np.frombuffer(base_png, dtype=np.uint8), cv2.IMREAD_COLOR),
- cv2.COLOR_BGR2HSV,
- )
- color_bgr_map = {
- "red": (0, 0, 255),
- "orange": (0, 140, 255),
- "yellow": (0, 200, 255),
- "green": (0, 200, 0),
- "blue": (255, 150, 0),
- "purple": (200, 0, 200),
- }
- for color_name, ranges in _COLOR_RANGES.items():
- mask = np.zeros((h, w), dtype=np.uint8)
- for lower, upper in ranges:
- mask = cv2.bitwise_or(mask, cv2.inRange(hsv, lower, upper))
- # Only draw if there are significant colored pixels
- if np.sum(mask > 0) < 100:
- continue
- # Draw colored contours
- contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
- draw_color = color_bgr_map.get(color_name, (200, 200, 200))
- for cnt in contours:
- area = cv2.contourArea(cnt)
- if area < 20:
- continue
- cv2.drawContours(img, [cnt], -1, draw_color, 2)
-
- # --- Draw graphic elements ---
- graphics_data = structure.get("graphics", [])
- shape_icons = {
- "image": "IMAGE",
- "illustration": "ILLUST",
- }
- for gfx in graphics_data:
- gx, gy = gfx["x"], gfx["y"]
- gw, gh = gfx["w"], gfx["h"]
- shape = gfx.get("shape", "icon")
- color_hex = gfx.get("color_hex", "#6b7280")
- conf = gfx.get("confidence", 0)
-
- # Pick draw color based on element color (BGR)
- gfx_bgr = bg_hex_to_bgr.get(color_hex, (128, 114, 107))
-
- # Draw bounding box (dashed style via short segments)
- dash = 6
- for seg_x in range(gx, gx + gw, dash * 2):
- end_x = min(seg_x + dash, gx + gw)
- cv2.line(img, (seg_x, gy), (end_x, gy), gfx_bgr, 2)
- cv2.line(img, (seg_x, gy + gh), (end_x, gy + gh), gfx_bgr, 2)
- for seg_y in range(gy, gy + gh, dash * 2):
- end_y = min(seg_y + dash, gy + gh)
- cv2.line(img, (gx, seg_y), (gx, end_y), gfx_bgr, 2)
- cv2.line(img, (gx + gw, seg_y), (gx + gw, end_y), gfx_bgr, 2)
-
- # Label
- icon = shape_icons.get(shape, shape.upper()[:5])
- label = f"{icon} {int(conf * 100)}%"
- # White background for readability
- (tw, th), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.4, 1)
- lx = gx + 2
- ly = max(gy - 4, th + 4)
- cv2.rectangle(img, (lx - 1, ly - th - 2), (lx + tw + 2, ly + 3), (255, 255, 255), -1)
- cv2.putText(img, label, (lx, ly), cv2.FONT_HERSHEY_SIMPLEX, 0.4, gfx_bgr, 1)
-
- # Encode result
- _, png_buf = cv2.imencode(".png", img)
- return Response(content=png_buf.tobytes(), media_type="image/png")
-
-
-
-async def _get_columns_overlay(session_id: str) -> Response:
- """Generate cropped (or dewarped) image with column borders drawn on it."""
- session = await get_session_db(session_id)
- if not session:
- raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
-
- column_result = session.get("column_result")
- if not column_result or not column_result.get("columns"):
- raise HTTPException(status_code=404, detail="No column data available")
-
- # Load best available base image (cropped > dewarped > original)
- base_png = await _get_base_image_png(session_id)
- if not base_png:
- raise HTTPException(status_code=404, detail="No base image available")
-
- arr = np.frombuffer(base_png, dtype=np.uint8)
- img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
- if img is None:
- raise HTTPException(status_code=500, detail="Failed to decode image")
-
- # Color map for region types (BGR)
- colors = {
- "column_en": (255, 180, 0), # Blue
- "column_de": (0, 200, 0), # Green
- "column_example": (0, 140, 255), # Orange
- "column_text": (200, 200, 0), # Cyan/Turquoise
- "page_ref": (200, 0, 200), # Purple
- "column_marker": (0, 0, 220), # Red
- "column_ignore": (180, 180, 180), # Light Gray
- "header": (128, 128, 128), # Gray
- "footer": (128, 128, 128), # Gray
- "margin_top": (100, 100, 100), # Dark Gray
- "margin_bottom": (100, 100, 100), # Dark Gray
- }
-
- overlay = img.copy()
- for col in column_result["columns"]:
- x, y = col["x"], col["y"]
- w, h = col["width"], col["height"]
- color = colors.get(col.get("type", ""), (200, 200, 200))
-
- # Semi-transparent fill
- cv2.rectangle(overlay, (x, y), (x + w, y + h), color, -1)
-
- # Solid border
- cv2.rectangle(img, (x, y), (x + w, y + h), color, 3)
-
- # Label with confidence
- label = col.get("type", "unknown").replace("column_", "").upper()
- conf = col.get("classification_confidence")
- if conf is not None and conf < 1.0:
- label = f"{label} {int(conf * 100)}%"
- cv2.putText(img, label, (x + 10, y + 30),
- cv2.FONT_HERSHEY_SIMPLEX, 0.8, color, 2)
-
- # Blend overlay at 20% opacity
- cv2.addWeighted(overlay, 0.2, img, 0.8, 0, img)
-
- # Draw detected box boundaries as dashed rectangles
- zones = column_result.get("zones") or []
- for zone in zones:
- if zone.get("zone_type") == "box" and zone.get("box"):
- box = zone["box"]
- bx, by = box["x"], box["y"]
- bw, bh = box["width"], box["height"]
- box_color = (0, 200, 255) # Yellow (BGR)
- # Draw dashed rectangle by drawing short line segments
- dash_len = 15
- for edge_x in range(bx, bx + bw, dash_len * 2):
- end_x = min(edge_x + dash_len, bx + bw)
- cv2.line(img, (edge_x, by), (end_x, by), box_color, 2)
- cv2.line(img, (edge_x, by + bh), (end_x, by + bh), box_color, 2)
- for edge_y in range(by, by + bh, dash_len * 2):
- end_y = min(edge_y + dash_len, by + bh)
- cv2.line(img, (bx, edge_y), (bx, end_y), box_color, 2)
- cv2.line(img, (bx + bw, edge_y), (bx + bw, end_y), box_color, 2)
- cv2.putText(img, "BOX", (bx + 10, by + bh - 10),
- cv2.FONT_HERSHEY_SIMPLEX, 0.7, box_color, 2)
-
- # Red semi-transparent overlay for box zones
- _draw_box_exclusion_overlay(img, zones)
-
- success, result_png = cv2.imencode(".png", img)
- if not success:
- raise HTTPException(status_code=500, detail="Failed to encode overlay image")
-
- return Response(content=result_png.tobytes(), media_type="image/png")
-
-
-# ---------------------------------------------------------------------------
-# Row Detection Endpoints
-# ---------------------------------------------------------------------------
-
-
-
-async def _get_rows_overlay(session_id: str) -> Response:
- """Generate cropped (or dewarped) image with row bands drawn on it."""
- session = await get_session_db(session_id)
- if not session:
- raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
-
- row_result = session.get("row_result")
- if not row_result or not row_result.get("rows"):
- raise HTTPException(status_code=404, detail="No row data available")
-
- # Load best available base image (cropped > dewarped > original)
- base_png = await _get_base_image_png(session_id)
- if not base_png:
- raise HTTPException(status_code=404, detail="No base image available")
-
- arr = np.frombuffer(base_png, dtype=np.uint8)
- img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
- if img is None:
- raise HTTPException(status_code=500, detail="Failed to decode image")
-
- # Color map for row types (BGR)
- row_colors = {
- "content": (255, 180, 0), # Blue
- "header": (128, 128, 128), # Gray
- "footer": (128, 128, 128), # Gray
- "margin_top": (100, 100, 100), # Dark Gray
- "margin_bottom": (100, 100, 100), # Dark Gray
- }
-
- overlay = img.copy()
- for row in row_result["rows"]:
- x, y = row["x"], row["y"]
- w, h = row["width"], row["height"]
- row_type = row.get("row_type", "content")
- color = row_colors.get(row_type, (200, 200, 200))
-
- # Semi-transparent fill
- cv2.rectangle(overlay, (x, y), (x + w, y + h), color, -1)
-
- # Solid border
- cv2.rectangle(img, (x, y), (x + w, y + h), color, 2)
-
- # Label
- idx = row.get("index", 0)
- label = f"R{idx} {row_type.upper()}"
- wc = row.get("word_count", 0)
- if wc:
- label = f"{label} ({wc}w)"
- cv2.putText(img, label, (x + 5, y + 18),
- cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1)
-
- # Blend overlay at 15% opacity
- cv2.addWeighted(overlay, 0.15, img, 0.85, 0, img)
-
- # Draw zone separator lines if zones exist
- column_result = session.get("column_result") or {}
- zones = column_result.get("zones") or []
- if zones:
- img_w_px = img.shape[1]
- zone_color = (0, 200, 255) # Yellow (BGR)
- dash_len = 20
- for zone in zones:
- if zone.get("zone_type") == "box":
- zy = zone["y"]
- zh = zone["height"]
- for line_y in [zy, zy + zh]:
- for sx in range(0, img_w_px, dash_len * 2):
- ex = min(sx + dash_len, img_w_px)
- cv2.line(img, (sx, line_y), (ex, line_y), zone_color, 2)
-
- # Red semi-transparent overlay for box zones
- _draw_box_exclusion_overlay(img, zones)
-
- success, result_png = cv2.imencode(".png", img)
- if not success:
- raise HTTPException(status_code=500, detail="Failed to encode overlay image")
-
- return Response(content=result_png.tobytes(), media_type="image/png")
-
-
-
-async def _get_words_overlay(session_id: str) -> Response:
- """Generate cropped (or dewarped) image with cell grid drawn on it."""
- session = await get_session_db(session_id)
- if not session:
- raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
-
- word_result = session.get("word_result")
- if not word_result:
- raise HTTPException(status_code=404, detail="No word data available")
-
- # Support both new cell-based and legacy entry-based formats
- cells = word_result.get("cells")
- if not cells and not word_result.get("entries"):
- raise HTTPException(status_code=404, detail="No word data available")
-
- # Load best available base image (cropped > dewarped > original)
- base_png = await _get_base_image_png(session_id)
- if not base_png:
- raise HTTPException(status_code=404, detail="No base image available")
-
- arr = np.frombuffer(base_png, dtype=np.uint8)
- img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
- if img is None:
- raise HTTPException(status_code=500, detail="Failed to decode image")
-
- img_h, img_w = img.shape[:2]
-
- overlay = img.copy()
-
- if cells:
- # New cell-based overlay: color by column index
- col_palette = [
- (255, 180, 0), # Blue (BGR)
- (0, 200, 0), # Green
- (0, 140, 255), # Orange
- (200, 100, 200), # Purple
- (200, 200, 0), # Cyan
- (100, 200, 200), # Yellow-ish
- ]
-
- for cell in cells:
- bbox = cell.get("bbox_px", {})
- cx = bbox.get("x", 0)
- cy = bbox.get("y", 0)
- cw = bbox.get("w", 0)
- ch = bbox.get("h", 0)
- if cw <= 0 or ch <= 0:
- continue
-
- col_idx = cell.get("col_index", 0)
- color = col_palette[col_idx % len(col_palette)]
-
- # Cell rectangle border
- cv2.rectangle(img, (cx, cy), (cx + cw, cy + ch), color, 1)
- # Semi-transparent fill
- cv2.rectangle(overlay, (cx, cy), (cx + cw, cy + ch), color, -1)
-
- # Cell-ID label (top-left corner)
- cell_id = cell.get("cell_id", "")
- cv2.putText(img, cell_id, (cx + 2, cy + 10),
- cv2.FONT_HERSHEY_SIMPLEX, 0.28, color, 1)
-
- # Text label (bottom of cell)
- text = cell.get("text", "")
- if text:
- conf = cell.get("confidence", 0)
- if conf >= 70:
- text_color = (0, 180, 0)
- elif conf >= 50:
- text_color = (0, 180, 220)
- else:
- text_color = (0, 0, 220)
-
- label = text.replace('\n', ' ')[:30]
- cv2.putText(img, label, (cx + 3, cy + ch - 4),
- cv2.FONT_HERSHEY_SIMPLEX, 0.35, text_color, 1)
- else:
- # Legacy fallback: entry-based overlay (for old sessions)
- column_result = session.get("column_result")
- row_result = session.get("row_result")
- col_colors = {
- "column_en": (255, 180, 0),
- "column_de": (0, 200, 0),
- "column_example": (0, 140, 255),
- }
-
- columns = []
- if column_result and column_result.get("columns"):
- columns = [c for c in column_result["columns"]
- if c.get("type", "").startswith("column_")]
-
- content_rows_data = []
- if row_result and row_result.get("rows"):
- content_rows_data = [r for r in row_result["rows"]
- if r.get("row_type") == "content"]
-
- for col in columns:
- col_type = col.get("type", "")
- color = col_colors.get(col_type, (200, 200, 200))
- cx, cw = col["x"], col["width"]
- for row in content_rows_data:
- ry, rh = row["y"], row["height"]
- cv2.rectangle(img, (cx, ry), (cx + cw, ry + rh), color, 1)
- cv2.rectangle(overlay, (cx, ry), (cx + cw, ry + rh), color, -1)
-
- entries = word_result["entries"]
- entry_by_row: Dict[int, Dict] = {}
- for entry in entries:
- entry_by_row[entry.get("row_index", -1)] = entry
-
- for row_idx, row in enumerate(content_rows_data):
- entry = entry_by_row.get(row_idx)
- if not entry:
- continue
- conf = entry.get("confidence", 0)
- text_color = (0, 180, 0) if conf >= 70 else (0, 180, 220) if conf >= 50 else (0, 0, 220)
- ry, rh = row["y"], row["height"]
- for col in columns:
- col_type = col.get("type", "")
- cx, cw = col["x"], col["width"]
- field = {"column_en": "english", "column_de": "german", "column_example": "example"}.get(col_type, "")
- text = entry.get(field, "") if field else ""
- if text:
- label = text.replace('\n', ' ')[:30]
- cv2.putText(img, label, (cx + 3, ry + rh - 4),
- cv2.FONT_HERSHEY_SIMPLEX, 0.35, text_color, 1)
-
- # Blend overlay at 10% opacity
- cv2.addWeighted(overlay, 0.1, img, 0.9, 0, img)
-
- # Red semi-transparent overlay for box zones
- column_result = session.get("column_result") or {}
- zones = column_result.get("zones") or []
- _draw_box_exclusion_overlay(img, zones)
-
- success, result_png = cv2.imencode(".png", img)
- if not success:
- raise HTTPException(status_code=500, detail="Failed to encode overlay image")
-
- return Response(content=result_png.tobytes(), media_type="image/png")
-
diff --git a/klausur-service/backend/ocr_pipeline_regression.py b/klausur-service/backend/ocr_pipeline_regression.py
index b6e09a0..5c8ff89 100644
--- a/klausur-service/backend/ocr_pipeline_regression.py
+++ b/klausur-service/backend/ocr_pipeline_regression.py
@@ -1,607 +1,22 @@
"""
-OCR Pipeline Regression Tests — Ground Truth comparison system.
+OCR Pipeline Regression Tests — barrel re-export.
-Allows marking sessions as "ground truth" and re-running build_grid()
-to detect regressions after code changes.
+All implementation split into:
+ ocr_pipeline_regression_helpers — DB persistence, snapshot, comparison
+ ocr_pipeline_regression_endpoints — FastAPI routes
Lizenz: Apache 2.0
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
"""
-import json
-import logging
-import os
-import time
-import uuid
-from datetime import datetime, timezone
-from typing import Any, Dict, List, Optional
-
-from fastapi import APIRouter, HTTPException, Query
-
-from grid_editor_api import _build_grid_core
-from ocr_pipeline_session_store import (
- get_pool,
- get_session_db,
- list_ground_truth_sessions_db,
- update_session_db,
+# Helpers (used by grid_editor_api_grid.py)
+from ocr_pipeline_regression_helpers import ( # noqa: F401
+ _init_regression_table,
+ _persist_regression_run,
+ _extract_cells_for_comparison,
+ _build_reference_snapshot,
+ compare_grids,
)
-logger = logging.getLogger(__name__)
-
-router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["regression"])
-
-
-# ---------------------------------------------------------------------------
-# DB persistence for regression runs
-# ---------------------------------------------------------------------------
-
-async def _init_regression_table():
- """Ensure regression_runs table exists (idempotent)."""
- pool = await get_pool()
- async with pool.acquire() as conn:
- migration_path = os.path.join(
- os.path.dirname(__file__),
- "migrations/008_regression_runs.sql",
- )
- if os.path.exists(migration_path):
- with open(migration_path, "r") as f:
- sql = f.read()
- await conn.execute(sql)
-
-
-async def _persist_regression_run(
- status: str,
- summary: dict,
- results: list,
- duration_ms: int,
- triggered_by: str = "manual",
-) -> str:
- """Save a regression run to the database. Returns the run ID."""
- try:
- await _init_regression_table()
- pool = await get_pool()
- run_id = str(uuid.uuid4())
- async with pool.acquire() as conn:
- await conn.execute(
- """
- INSERT INTO regression_runs
- (id, status, total, passed, failed, errors, duration_ms, results, triggered_by)
- VALUES ($1, $2, $3, $4, $5, $6, $7, $8::jsonb, $9)
- """,
- run_id,
- status,
- summary.get("total", 0),
- summary.get("passed", 0),
- summary.get("failed", 0),
- summary.get("errors", 0),
- duration_ms,
- json.dumps(results),
- triggered_by,
- )
- logger.info("Regression run %s persisted: %s", run_id, status)
- return run_id
- except Exception as e:
- logger.warning("Failed to persist regression run: %s", e)
- return ""
-
-
-# ---------------------------------------------------------------------------
-# Helpers
-# ---------------------------------------------------------------------------
-
-def _extract_cells_for_comparison(grid_result: dict) -> List[Dict[str, Any]]:
- """Extract a flat list of cells from a grid_editor_result for comparison.
-
- Only keeps fields relevant for comparison: cell_id, row_index, col_index,
- col_type, text. Ignores confidence, bbox, word_boxes, duration, is_bold.
- """
- cells = []
- for zone in grid_result.get("zones", []):
- for cell in zone.get("cells", []):
- cells.append({
- "cell_id": cell.get("cell_id", ""),
- "row_index": cell.get("row_index"),
- "col_index": cell.get("col_index"),
- "col_type": cell.get("col_type", ""),
- "text": cell.get("text", ""),
- })
- return cells
-
-
-def _build_reference_snapshot(
- grid_result: dict,
- pipeline: Optional[str] = None,
-) -> dict:
- """Build a ground-truth reference snapshot from a grid_editor_result."""
- cells = _extract_cells_for_comparison(grid_result)
-
- total_zones = len(grid_result.get("zones", []))
- total_columns = sum(len(z.get("columns", [])) for z in grid_result.get("zones", []))
- total_rows = sum(len(z.get("rows", [])) for z in grid_result.get("zones", []))
-
- snapshot = {
- "saved_at": datetime.now(timezone.utc).isoformat(),
- "version": 1,
- "pipeline": pipeline,
- "summary": {
- "total_zones": total_zones,
- "total_columns": total_columns,
- "total_rows": total_rows,
- "total_cells": len(cells),
- },
- "cells": cells,
- }
- return snapshot
-
-
-def compare_grids(reference: dict, current: dict) -> dict:
- """Compare a reference grid snapshot with a newly computed one.
-
- Returns a diff report with:
- - status: "pass" or "fail"
- - structural_diffs: changes in zone/row/column counts
- - cell_diffs: list of individual cell changes
- """
- ref_summary = reference.get("summary", {})
- cur_summary = current.get("summary", {})
-
- structural_diffs = []
- for key in ("total_zones", "total_columns", "total_rows", "total_cells"):
- ref_val = ref_summary.get(key, 0)
- cur_val = cur_summary.get(key, 0)
- if ref_val != cur_val:
- structural_diffs.append({
- "field": key,
- "reference": ref_val,
- "current": cur_val,
- })
-
- # Build cell lookup by cell_id
- ref_cells = {c["cell_id"]: c for c in reference.get("cells", [])}
- cur_cells = {c["cell_id"]: c for c in current.get("cells", [])}
-
- cell_diffs: List[Dict[str, Any]] = []
-
- # Check for missing cells (in reference but not in current)
- for cell_id in ref_cells:
- if cell_id not in cur_cells:
- cell_diffs.append({
- "type": "cell_missing",
- "cell_id": cell_id,
- "reference_text": ref_cells[cell_id].get("text", ""),
- })
-
- # Check for added cells (in current but not in reference)
- for cell_id in cur_cells:
- if cell_id not in ref_cells:
- cell_diffs.append({
- "type": "cell_added",
- "cell_id": cell_id,
- "current_text": cur_cells[cell_id].get("text", ""),
- })
-
- # Check for changes in shared cells
- for cell_id in ref_cells:
- if cell_id not in cur_cells:
- continue
- ref_cell = ref_cells[cell_id]
- cur_cell = cur_cells[cell_id]
-
- if ref_cell.get("text", "") != cur_cell.get("text", ""):
- cell_diffs.append({
- "type": "text_change",
- "cell_id": cell_id,
- "reference": ref_cell.get("text", ""),
- "current": cur_cell.get("text", ""),
- })
-
- if ref_cell.get("col_type", "") != cur_cell.get("col_type", ""):
- cell_diffs.append({
- "type": "col_type_change",
- "cell_id": cell_id,
- "reference": ref_cell.get("col_type", ""),
- "current": cur_cell.get("col_type", ""),
- })
-
- status = "pass" if not structural_diffs and not cell_diffs else "fail"
-
- return {
- "status": status,
- "structural_diffs": structural_diffs,
- "cell_diffs": cell_diffs,
- "summary": {
- "structural_changes": len(structural_diffs),
- "cells_missing": sum(1 for d in cell_diffs if d["type"] == "cell_missing"),
- "cells_added": sum(1 for d in cell_diffs if d["type"] == "cell_added"),
- "text_changes": sum(1 for d in cell_diffs if d["type"] == "text_change"),
- "col_type_changes": sum(1 for d in cell_diffs if d["type"] == "col_type_change"),
- },
- }
-
-
-# ---------------------------------------------------------------------------
-# Endpoints
-# ---------------------------------------------------------------------------
-
-@router.post("/sessions/{session_id}/mark-ground-truth")
-async def mark_ground_truth(
- session_id: str,
- pipeline: Optional[str] = Query(None, description="Pipeline used: kombi, pipeline, paddle-direct"),
-):
- """Save the current build-grid result as ground-truth reference."""
- session = await get_session_db(session_id)
- if not session:
- raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
-
- grid_result = session.get("grid_editor_result")
- if not grid_result or not grid_result.get("zones"):
- raise HTTPException(
- status_code=400,
- detail="No grid_editor_result found. Run build-grid first.",
- )
-
- # Auto-detect pipeline from word_result if not provided
- if not pipeline:
- wr = session.get("word_result") or {}
- engine = wr.get("ocr_engine", "")
- if engine in ("kombi", "rapid_kombi"):
- pipeline = "kombi"
- elif engine == "paddle_direct":
- pipeline = "paddle-direct"
- else:
- pipeline = "pipeline"
-
- reference = _build_reference_snapshot(grid_result, pipeline=pipeline)
-
- # Merge into existing ground_truth JSONB
- gt = session.get("ground_truth") or {}
- gt["build_grid_reference"] = reference
- await update_session_db(session_id, ground_truth=gt, current_step=11)
-
- # Compare with auto-snapshot if available (shows what the user corrected)
- auto_snapshot = gt.get("auto_grid_snapshot")
- correction_diff = None
- if auto_snapshot:
- correction_diff = compare_grids(auto_snapshot, reference)
-
- logger.info(
- "Ground truth marked for session %s: %d cells (corrections: %s)",
- session_id,
- len(reference["cells"]),
- correction_diff["summary"] if correction_diff else "no auto-snapshot",
- )
-
- return {
- "status": "ok",
- "session_id": session_id,
- "cells_saved": len(reference["cells"]),
- "summary": reference["summary"],
- "correction_diff": correction_diff,
- }
-
-
-@router.delete("/sessions/{session_id}/mark-ground-truth")
-async def unmark_ground_truth(session_id: str):
- """Remove the ground-truth reference from a session."""
- session = await get_session_db(session_id)
- if not session:
- raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
-
- gt = session.get("ground_truth") or {}
- if "build_grid_reference" not in gt:
- raise HTTPException(status_code=404, detail="No ground truth reference found")
-
- del gt["build_grid_reference"]
- await update_session_db(session_id, ground_truth=gt)
-
- logger.info("Ground truth removed for session %s", session_id)
- return {"status": "ok", "session_id": session_id}
-
-
-@router.get("/sessions/{session_id}/correction-diff")
-async def get_correction_diff(session_id: str):
- """Compare automatic OCR grid with manually corrected ground truth.
-
- Returns a diff showing exactly which cells the user corrected,
- broken down by col_type (english, german, ipa, etc.).
- """
- session = await get_session_db(session_id)
- if not session:
- raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
-
- gt = session.get("ground_truth") or {}
- auto_snapshot = gt.get("auto_grid_snapshot")
- reference = gt.get("build_grid_reference")
-
- if not auto_snapshot:
- raise HTTPException(
- status_code=404,
- detail="No auto_grid_snapshot found. Re-run build-grid to create one.",
- )
- if not reference:
- raise HTTPException(
- status_code=404,
- detail="No ground truth reference found. Mark as ground truth first.",
- )
-
- diff = compare_grids(auto_snapshot, reference)
-
- # Enrich with per-col_type breakdown
- col_type_stats: Dict[str, Dict[str, int]] = {}
- for cell_diff in diff.get("cell_diffs", []):
- if cell_diff["type"] != "text_change":
- continue
- # Find col_type from reference cells
- cell_id = cell_diff["cell_id"]
- ref_cell = next(
- (c for c in reference.get("cells", []) if c["cell_id"] == cell_id),
- None,
- )
- ct = ref_cell.get("col_type", "unknown") if ref_cell else "unknown"
- if ct not in col_type_stats:
- col_type_stats[ct] = {"total": 0, "corrected": 0}
- col_type_stats[ct]["corrected"] += 1
-
- # Count total cells per col_type from reference
- for cell in reference.get("cells", []):
- ct = cell.get("col_type", "unknown")
- if ct not in col_type_stats:
- col_type_stats[ct] = {"total": 0, "corrected": 0}
- col_type_stats[ct]["total"] += 1
-
- # Calculate accuracy per col_type
- for ct, stats in col_type_stats.items():
- total = stats["total"]
- corrected = stats["corrected"]
- stats["accuracy_pct"] = round((total - corrected) / total * 100, 1) if total > 0 else 100.0
-
- diff["col_type_breakdown"] = col_type_stats
-
- return diff
-
-
-@router.get("/ground-truth-sessions")
-async def list_ground_truth_sessions():
- """List all sessions that have a ground-truth reference."""
- sessions = await list_ground_truth_sessions_db()
-
- result = []
- for s in sessions:
- gt = s.get("ground_truth") or {}
- ref = gt.get("build_grid_reference", {})
- result.append({
- "session_id": s["id"],
- "name": s.get("name", ""),
- "filename": s.get("filename", ""),
- "document_category": s.get("document_category"),
- "pipeline": ref.get("pipeline"),
- "saved_at": ref.get("saved_at"),
- "summary": ref.get("summary", {}),
- })
-
- return {"sessions": result, "count": len(result)}
-
-
-@router.post("/sessions/{session_id}/regression/run")
-async def run_single_regression(session_id: str):
- """Re-run build_grid for a single session and compare to ground truth."""
- session = await get_session_db(session_id)
- if not session:
- raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
-
- gt = session.get("ground_truth") or {}
- reference = gt.get("build_grid_reference")
- if not reference:
- raise HTTPException(
- status_code=400,
- detail="No ground truth reference found for this session",
- )
-
- # Re-compute grid without persisting
- try:
- new_result = await _build_grid_core(session_id, session)
- except (ValueError, Exception) as e:
- return {
- "session_id": session_id,
- "name": session.get("name", ""),
- "status": "error",
- "error": str(e),
- }
-
- new_snapshot = _build_reference_snapshot(new_result)
- diff = compare_grids(reference, new_snapshot)
-
- logger.info(
- "Regression test session %s: %s (%d structural, %d cell diffs)",
- session_id, diff["status"],
- diff["summary"]["structural_changes"],
- sum(v for k, v in diff["summary"].items() if k != "structural_changes"),
- )
-
- return {
- "session_id": session_id,
- "name": session.get("name", ""),
- "status": diff["status"],
- "diff": diff,
- "reference_summary": reference.get("summary", {}),
- "current_summary": new_snapshot.get("summary", {}),
- }
-
-
-@router.post("/regression/run")
-async def run_all_regressions(
- triggered_by: str = Query("manual", description="Who triggered: manual, script, ci"),
-):
- """Re-run build_grid for ALL ground-truth sessions and compare."""
- start_time = time.monotonic()
- sessions = await list_ground_truth_sessions_db()
-
- if not sessions:
- return {
- "status": "pass",
- "message": "No ground truth sessions found",
- "results": [],
- "summary": {"total": 0, "passed": 0, "failed": 0, "errors": 0},
- }
-
- results = []
- passed = 0
- failed = 0
- errors = 0
-
- for s in sessions:
- session_id = s["id"]
- gt = s.get("ground_truth") or {}
- reference = gt.get("build_grid_reference")
- if not reference:
- continue
-
- # Re-load full session (list query may not include all JSONB fields)
- full_session = await get_session_db(session_id)
- if not full_session:
- results.append({
- "session_id": session_id,
- "name": s.get("name", ""),
- "status": "error",
- "error": "Session not found during re-load",
- })
- errors += 1
- continue
-
- try:
- new_result = await _build_grid_core(session_id, full_session)
- except (ValueError, Exception) as e:
- results.append({
- "session_id": session_id,
- "name": s.get("name", ""),
- "status": "error",
- "error": str(e),
- })
- errors += 1
- continue
-
- new_snapshot = _build_reference_snapshot(new_result)
- diff = compare_grids(reference, new_snapshot)
-
- entry = {
- "session_id": session_id,
- "name": s.get("name", ""),
- "status": diff["status"],
- "diff_summary": diff["summary"],
- "reference_summary": reference.get("summary", {}),
- "current_summary": new_snapshot.get("summary", {}),
- }
-
- # Include full diffs only for failures (keep response compact)
- if diff["status"] == "fail":
- entry["structural_diffs"] = diff["structural_diffs"]
- entry["cell_diffs"] = diff["cell_diffs"]
- failed += 1
- else:
- passed += 1
-
- results.append(entry)
-
- overall = "pass" if failed == 0 and errors == 0 else "fail"
- duration_ms = int((time.monotonic() - start_time) * 1000)
-
- summary = {
- "total": len(results),
- "passed": passed,
- "failed": failed,
- "errors": errors,
- }
-
- logger.info(
- "Regression suite: %s — %d passed, %d failed, %d errors (of %d) in %dms",
- overall, passed, failed, errors, len(results), duration_ms,
- )
-
- # Persist to DB
- run_id = await _persist_regression_run(
- status=overall,
- summary=summary,
- results=results,
- duration_ms=duration_ms,
- triggered_by=triggered_by,
- )
-
- return {
- "status": overall,
- "run_id": run_id,
- "duration_ms": duration_ms,
- "results": results,
- "summary": summary,
- }
-
-
-@router.get("/regression/history")
-async def get_regression_history(
- limit: int = Query(20, ge=1, le=100),
-):
- """Get recent regression run history from the database."""
- try:
- await _init_regression_table()
- pool = await get_pool()
- async with pool.acquire() as conn:
- rows = await conn.fetch(
- """
- SELECT id, run_at, status, total, passed, failed, errors,
- duration_ms, triggered_by
- FROM regression_runs
- ORDER BY run_at DESC
- LIMIT $1
- """,
- limit,
- )
- return {
- "runs": [
- {
- "id": str(row["id"]),
- "run_at": row["run_at"].isoformat() if row["run_at"] else None,
- "status": row["status"],
- "total": row["total"],
- "passed": row["passed"],
- "failed": row["failed"],
- "errors": row["errors"],
- "duration_ms": row["duration_ms"],
- "triggered_by": row["triggered_by"],
- }
- for row in rows
- ],
- "count": len(rows),
- }
- except Exception as e:
- logger.warning("Failed to fetch regression history: %s", e)
- return {"runs": [], "count": 0, "error": str(e)}
-
-
-@router.get("/regression/history/{run_id}")
-async def get_regression_run_detail(run_id: str):
- """Get detailed results of a specific regression run."""
- try:
- await _init_regression_table()
- pool = await get_pool()
- async with pool.acquire() as conn:
- row = await conn.fetchrow(
- "SELECT * FROM regression_runs WHERE id = $1",
- run_id,
- )
- if not row:
- raise HTTPException(status_code=404, detail="Run not found")
- return {
- "id": str(row["id"]),
- "run_at": row["run_at"].isoformat() if row["run_at"] else None,
- "status": row["status"],
- "total": row["total"],
- "passed": row["passed"],
- "failed": row["failed"],
- "errors": row["errors"],
- "duration_ms": row["duration_ms"],
- "triggered_by": row["triggered_by"],
- "results": json.loads(row["results"]) if row["results"] else [],
- }
- except HTTPException:
- raise
- except Exception as e:
- raise HTTPException(status_code=500, detail=str(e))
+# Endpoints (router used by ocr_pipeline_api.py)
+from ocr_pipeline_regression_endpoints import router # noqa: F401
diff --git a/klausur-service/backend/ocr_pipeline_regression_endpoints.py b/klausur-service/backend/ocr_pipeline_regression_endpoints.py
new file mode 100644
index 0000000..a91d6d6
--- /dev/null
+++ b/klausur-service/backend/ocr_pipeline_regression_endpoints.py
@@ -0,0 +1,421 @@
+"""
+OCR Pipeline Regression Endpoints — FastAPI routes for ground truth and regression.
+
+Extracted from ocr_pipeline_regression.py for modularity.
+
+Lizenz: Apache 2.0
+DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
+"""
+
+import json
+import logging
+import time
+from typing import Any, Dict, Optional
+
+from fastapi import APIRouter, HTTPException, Query
+
+from grid_editor_api import _build_grid_core
+from ocr_pipeline_session_store import (
+ get_session_db,
+ list_ground_truth_sessions_db,
+ update_session_db,
+)
+from ocr_pipeline_regression_helpers import (
+ _build_reference_snapshot,
+ _init_regression_table,
+ _persist_regression_run,
+ compare_grids,
+ get_pool,
+)
+
+logger = logging.getLogger(__name__)
+
+router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["regression"])
+
+
+# ---------------------------------------------------------------------------
+# Endpoints
+# ---------------------------------------------------------------------------
+
+@router.post("/sessions/{session_id}/mark-ground-truth")
+async def mark_ground_truth(
+    session_id: str,
+    pipeline: Optional[str] = Query(None, description="Pipeline used: kombi, pipeline, paddle-direct"),
+):
+    """Snapshot the session's grid_editor_result as ground truth; 404 if no session, 400 if no grid zones."""
+    session = await get_session_db(session_id)
+    if not session:
+        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
+
+    grid_result = session.get("grid_editor_result")
+    if not grid_result or not grid_result.get("zones"):  # missing result or empty zone list
+        raise HTTPException(
+            status_code=400,
+            detail="No grid_editor_result found. Run build-grid first.",
+        )
+
+    # Auto-detect pipeline from word_result["ocr_engine"] when the caller did not pass one
+    if not pipeline:
+        wr = session.get("word_result") or {}
+        engine = wr.get("ocr_engine", "")
+        if engine in ("kombi", "rapid_kombi"):
+            pipeline = "kombi"
+        elif engine == "paddle_direct":
+            pipeline = "paddle-direct"
+        else:
+            pipeline = "pipeline"  # fallback label for any other or missing engine
+
+    reference = _build_reference_snapshot(grid_result, pipeline=pipeline)
+
+    # Merge into existing ground_truth JSONB so other keys (e.g. auto_grid_snapshot) survive
+    gt = session.get("ground_truth") or {}
+    gt["build_grid_reference"] = reference
+    await update_session_db(session_id, ground_truth=gt, current_step=11)
+
+    # Compare with auto-snapshot if available (shows what the user corrected)
+    auto_snapshot = gt.get("auto_grid_snapshot")
+    correction_diff = None
+    if auto_snapshot:
+        correction_diff = compare_grids(auto_snapshot, reference)
+
+    logger.info(
+        "Ground truth marked for session %s: %d cells (corrections: %s)",
+        session_id,
+        len(reference["cells"]),
+        correction_diff["summary"] if correction_diff else "no auto-snapshot",
+    )
+
+    return {
+        "status": "ok",
+        "session_id": session_id,
+        "cells_saved": len(reference["cells"]),
+        "summary": reference["summary"],
+        "correction_diff": correction_diff,  # None when no auto_grid_snapshot is stored
+    }
+
+
+@router.delete("/sessions/{session_id}/mark-ground-truth")
+async def unmark_ground_truth(session_id: str):
+    """Delete 'build_grid_reference' from the session's ground_truth; 404 if session or reference is absent."""
+    session = await get_session_db(session_id)
+    if not session:
+        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
+
+    gt = session.get("ground_truth") or {}
+    if "build_grid_reference" not in gt:
+        raise HTTPException(status_code=404, detail="No ground truth reference found")
+
+    del gt["build_grid_reference"]  # remaining ground_truth keys are written back unchanged
+    await update_session_db(session_id, ground_truth=gt)
+
+    logger.info("Ground truth removed for session %s", session_id)
+    return {"status": "ok", "session_id": session_id}
+
+
+@router.get("/sessions/{session_id}/correction-diff")
+async def get_correction_diff(session_id: str):
+    """Diff the stored auto_grid_snapshot against the build_grid_reference.
+
+    Returns the compare_grids() report plus 'col_type_breakdown': per col_type the
+    total / corrected cell counts and accuracy_pct. 404 if the session or a snapshot is missing.
+    """
+    session = await get_session_db(session_id)
+    if not session:
+        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
+
+    gt = session.get("ground_truth") or {}
+    auto_snapshot = gt.get("auto_grid_snapshot")
+    reference = gt.get("build_grid_reference")
+
+    if not auto_snapshot:
+        raise HTTPException(
+            status_code=404,
+            detail="No auto_grid_snapshot found. Re-run build-grid to create one.",
+        )
+    if not reference:
+        raise HTTPException(
+            status_code=404,
+            detail="No ground truth reference found. Mark as ground truth first.",
+        )
+
+    diff = compare_grids(auto_snapshot, reference)
+
+    # Enrich with per-col_type breakdown (only text_change diffs count as corrections)
+    col_type_stats: Dict[str, Dict[str, int]] = {}  # NOTE: accuracy_pct added below is a float
+    for cell_diff in diff.get("cell_diffs", []):
+        if cell_diff["type"] != "text_change":
+            continue
+        # Find col_type from reference cells (linear scan; first cell with this id wins)
+        cell_id = cell_diff["cell_id"]
+        ref_cell = next(
+            (c for c in reference.get("cells", []) if c["cell_id"] == cell_id),
+            None,
+        )
+        ct = ref_cell.get("col_type", "unknown") if ref_cell else "unknown"
+        if ct not in col_type_stats:
+            col_type_stats[ct] = {"total": 0, "corrected": 0}
+        col_type_stats[ct]["corrected"] += 1
+
+    # Count total cells per col_type from reference
+    for cell in reference.get("cells", []):
+        ct = cell.get("col_type", "unknown")
+        if ct not in col_type_stats:
+            col_type_stats[ct] = {"total": 0, "corrected": 0}
+        col_type_stats[ct]["total"] += 1
+
+    # Calculate accuracy per col_type; a col_type with zero cells reports 100.0
+    for ct, stats in col_type_stats.items():
+        total = stats["total"]
+        corrected = stats["corrected"]
+        stats["accuracy_pct"] = round((total - corrected) / total * 100, 1) if total > 0 else 100.0
+
+    diff["col_type_breakdown"] = col_type_stats
+
+    return diff
+
+
+@router.get("/ground-truth-sessions")
+async def list_ground_truth_sessions():
+    """Return {'sessions': [...], 'count': n} describing each row from list_ground_truth_sessions_db()."""
+    sessions = await list_ground_truth_sessions_db()
+
+    result = []
+    for s in sessions:
+        gt = s.get("ground_truth") or {}
+        ref = gt.get("build_grid_reference", {})  # empty dict -> pipeline/saved_at become None
+        result.append({
+            "session_id": s["id"],
+            "name": s.get("name", ""),
+            "filename": s.get("filename", ""),
+            "document_category": s.get("document_category"),
+            "pipeline": ref.get("pipeline"),
+            "saved_at": ref.get("saved_at"),
+            "summary": ref.get("summary", {}),
+        })
+
+    return {"sessions": result, "count": len(result)}
+
+
+@router.post("/sessions/{session_id}/regression/run")
+async def run_single_regression(session_id: str):
+    """Rebuild one session's grid and diff it against its ground truth; 404 no session, 400 no reference."""
+    session = await get_session_db(session_id)
+    if not session:
+        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
+
+    gt = session.get("ground_truth") or {}
+    reference = gt.get("build_grid_reference")
+    if not reference:
+        raise HTTPException(
+            status_code=400,
+            detail="No ground truth reference found for this session",
+        )
+
+    # Re-compute grid without persisting
+    try:
+        new_result = await _build_grid_core(session_id, session)
+    except Exception as e:  # build failures are reported in the body with status "error", not raised
+        return {
+            "session_id": session_id,
+            "name": session.get("name", ""),
+            "status": "error",
+            "error": str(e),
+        }
+
+    new_snapshot = _build_reference_snapshot(new_result)
+    diff = compare_grids(reference, new_snapshot)
+
+    logger.info(
+        "Regression test session %s: %s (%d structural, %d cell diffs)",
+        session_id, diff["status"],
+        diff["summary"]["structural_changes"],
+        sum(v for k, v in diff["summary"].items() if k != "structural_changes"),
+    )
+
+    return {
+        "session_id": session_id,
+        "name": session.get("name", ""),
+        "status": diff["status"],
+        "diff": diff,
+        "reference_summary": reference.get("summary", {}),
+        "current_summary": new_snapshot.get("summary", {}),
+    }
+
+
+@router.post("/regression/run")
+async def run_all_regressions(
+    triggered_by: str = Query("manual", description="Who triggered: manual, script, ci"),
+):
+    """Rebuild and diff every ground-truth session, persist the run, and return the aggregate report."""
+    start_time = time.monotonic()
+    sessions = await list_ground_truth_sessions_db()
+
+    if not sessions:
+        return {  # nothing to test: reported as a pass and not persisted
+            "status": "pass",
+            "message": "No ground truth sessions found",
+            "results": [],
+            "summary": {"total": 0, "passed": 0, "failed": 0, "errors": 0},
+        }
+
+    results = []
+    passed = 0
+    failed = 0
+    errors = 0
+
+    for s in sessions:
+        session_id = s["id"]
+        gt = s.get("ground_truth") or {}
+        reference = gt.get("build_grid_reference")
+        if not reference:
+            continue  # skipped silently; not counted in any total
+
+        # Re-load full session (list query may not include all JSONB fields)
+        full_session = await get_session_db(session_id)
+        if not full_session:
+            results.append({
+                "session_id": session_id,
+                "name": s.get("name", ""),
+                "status": "error",
+                "error": "Session not found during re-load",
+            })
+            errors += 1
+            continue
+
+        try:
+            new_result = await _build_grid_core(session_id, full_session)
+        except Exception as e:  # one failing session must not abort the whole suite
+            results.append({
+                "session_id": session_id,
+                "name": s.get("name", ""),
+                "status": "error",
+                "error": str(e),
+            })
+            errors += 1
+            continue
+
+        new_snapshot = _build_reference_snapshot(new_result)
+        diff = compare_grids(reference, new_snapshot)
+
+        entry = {
+            "session_id": session_id,
+            "name": s.get("name", ""),
+            "status": diff["status"],
+            "diff_summary": diff["summary"],
+            "reference_summary": reference.get("summary", {}),
+            "current_summary": new_snapshot.get("summary", {}),
+        }
+
+        # Include full diffs only for failures (keep response compact)
+        if diff["status"] == "fail":
+            entry["structural_diffs"] = diff["structural_diffs"]
+            entry["cell_diffs"] = diff["cell_diffs"]
+            failed += 1
+        else:
+            passed += 1
+
+        results.append(entry)
+
+    overall = "pass" if failed == 0 and errors == 0 else "fail"  # any error fails the suite
+    duration_ms = int((time.monotonic() - start_time) * 1000)
+
+    summary = {
+        "total": len(results),
+        "passed": passed,
+        "failed": failed,
+        "errors": errors,
+    }
+
+    logger.info(
+        "Regression suite: %s — %d passed, %d failed, %d errors (of %d) in %dms",
+        overall, passed, failed, errors, len(results), duration_ms,
+    )
+
+    # Persist to DB
+    run_id = await _persist_regression_run(
+        status=overall,
+        summary=summary,
+        results=results,
+        duration_ms=duration_ms,
+        triggered_by=triggered_by,
+    )
+
+    return {
+        "status": overall,
+        "run_id": run_id,
+        "duration_ms": duration_ms,
+        "results": results,
+        "summary": summary,
+    }
+
+
+@router.get("/regression/history")
+async def get_regression_history(
+    limit: int = Query(20, ge=1, le=100),
+):
+    """List the newest regression runs (summary columns only); on any error return an empty list plus 'error'."""
+    try:
+        await _init_regression_table()
+        pool = await get_pool()
+        async with pool.acquire() as conn:
+            rows = await conn.fetch(
+                """
+                SELECT id, run_at, status, total, passed, failed, errors,
+                       duration_ms, triggered_by
+                FROM regression_runs
+                ORDER BY run_at DESC
+                LIMIT $1
+                """,
+                limit,
+            )
+        return {
+            "runs": [
+                {
+                    "id": str(row["id"]),
+                    "run_at": row["run_at"].isoformat() if row["run_at"] else None,
+                    "status": row["status"],
+                    "total": row["total"],
+                    "passed": row["passed"],
+                    "failed": row["failed"],
+                    "errors": row["errors"],
+                    "duration_ms": row["duration_ms"],
+                    "triggered_by": row["triggered_by"],
+                }
+                for row in rows
+            ],
+            "count": len(rows),
+        }
+    except Exception as e:  # best-effort endpoint: failures are logged and reported in the body
+        logger.warning("Failed to fetch regression history: %s", e)
+        return {"runs": [], "count": 0, "error": str(e)}
+
+
+@router.get("/regression/history/{run_id}")
+async def get_regression_run_detail(run_id: str):
+    """Return one regression_runs row including its decoded 'results'; 404 if not found, 500 on other errors."""
+    try:
+        await _init_regression_table()
+        pool = await get_pool()
+        async with pool.acquire() as conn:
+            row = await conn.fetchrow(
+                "SELECT * FROM regression_runs WHERE id = $1",
+                run_id,
+            )
+        if not row:
+            raise HTTPException(status_code=404, detail="Run not found")
+        return {
+            "id": str(row["id"]),
+            "run_at": row["run_at"].isoformat() if row["run_at"] else None,
+            "status": row["status"],
+            "total": row["total"],
+            "passed": row["passed"],
+            "failed": row["failed"],
+            "errors": row["errors"],
+            "duration_ms": row["duration_ms"],
+            "triggered_by": row["triggered_by"],
+            "results": json.loads(row["results"]) if row["results"] else [],  # assumes results arrives as a JSON string
+        }
+    except HTTPException:  # let the 404 above through instead of turning it into a 500
+        raise
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=str(e))
diff --git a/klausur-service/backend/ocr_pipeline_regression_helpers.py b/klausur-service/backend/ocr_pipeline_regression_helpers.py
new file mode 100644
index 0000000..b8e0a57
--- /dev/null
+++ b/klausur-service/backend/ocr_pipeline_regression_helpers.py
@@ -0,0 +1,207 @@
+"""
+OCR Pipeline Regression Helpers — DB persistence, snapshot building, comparison.
+
+Extracted from ocr_pipeline_regression.py for modularity.
+
+Lizenz: Apache 2.0
+DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
+"""
+
+import json
+import logging
+import os
+import uuid
+from datetime import datetime, timezone
+from typing import Any, Dict, List, Optional
+
+from ocr_pipeline_session_store import get_pool
+
+logger = logging.getLogger(__name__)
+
+
+# ---------------------------------------------------------------------------
+# DB persistence for regression runs
+# ---------------------------------------------------------------------------
+
+async def _init_regression_table():
+    """Apply migrations/008_regression_runs.sql if the file exists (else no-op)."""
+    pool = await get_pool()
+    migration_path = os.path.join(
+        os.path.dirname(__file__), "migrations/008_regression_runs.sql")
+    if os.path.exists(migration_path):
+        # Decode with a fixed encoding rather than the host locale default.
+        with open(migration_path, "r", encoding="utf-8") as f:
+            sql = f.read()
+        # Borrow a pooled connection only once there is SQL to execute.
+        async with pool.acquire() as conn:
+            await conn.execute(sql)
+
+
+async def _persist_regression_run(
+    status: str,
+    summary: dict,
+    results: list,
+    duration_ms: int,
+    triggered_by: str = "manual",
+) -> str:
+    """Store one regression run as a row in ``regression_runs``.
+
+    ``results`` is written as JSON text cast to jsonb. Persistence is
+    best-effort: on any exception a warning is logged and "" is returned
+    instead of the newly generated UUID string.
+    """
+    try:
+        await _init_regression_table()
+        pool = await get_pool()
+        new_id = str(uuid.uuid4())
+        async with pool.acquire() as db:
+            # Counters missing from the summary are stored as 0.
+            tallies = [summary.get(k, 0) for k in ("total", "passed", "failed", "errors")]
+            await db.execute(
+                """
+                INSERT INTO regression_runs
+                (id, status, total, passed, failed, errors, duration_ms, results, triggered_by)
+                VALUES ($1, $2, $3, $4, $5, $6, $7, $8::jsonb, $9)
+                """,
+                new_id, status, *tallies,
+                duration_ms, json.dumps(results), triggered_by,
+            )
+        logger.info("Regression run %s persisted: %s", new_id, status)
+        return new_id
+    except Exception as exc:
+        logger.warning("Failed to persist regression run: %s", exc)
+        return ""
+
+
+# ---------------------------------------------------------------------------
+# Helpers
+# ---------------------------------------------------------------------------
+
+def _extract_cells_for_comparison(grid_result: dict) -> List[Dict[str, Any]]:
+    """Flatten every zone's cells into one list of comparison records.
+
+    Keeps cell_id, col_type, text (default "") and row_index, col_index
+    (default None); a missing or None "zones"/"cells" entry counts as empty.
+    """
+    return [
+        {
+            "cell_id": cell.get("cell_id", ""),
+            "row_index": cell.get("row_index"),
+            "col_index": cell.get("col_index"),
+            "col_type": cell.get("col_type", ""),
+            "text": cell.get("text", ""),
+        }
+        for zone in grid_result.get("zones") or []
+        for cell in zone.get("cells") or []
+    ]
+
+
+def _build_reference_snapshot(
+    grid_result: dict,
+    pipeline: Optional[str] = None,
+) -> dict:
+    """Assemble a version-1 reference snapshot from a grid_editor_result.
+
+    The snapshot holds the flattened comparison cells, a summary (zone
+    count, column and row counts summed over all zones, cell count), the
+    given pipeline label and a UTC ISO-8601 "saved_at" timestamp.
+    """
+    flat_cells = _extract_cells_for_comparison(grid_result)
+    zone_list = grid_result.get("zones", [])
+
+    def _count(field: str) -> int:
+        return sum(len(zone.get(field, [])) for zone in zone_list)
+
+    counts = {
+        "total_zones": len(zone_list),
+        "total_columns": _count("columns"),
+        "total_rows": _count("rows"),
+        "total_cells": len(flat_cells),
+    }
+    return {"saved_at": datetime.now(timezone.utc).isoformat(), "version": 1,
+            "pipeline": pipeline, "summary": counts, "cells": flat_cells}
+
+
+def compare_grids(reference: dict, current: dict) -> dict:
+    """Diff a freshly computed grid snapshot against its stored reference.
+
+    Both arguments are snapshot dicts carrying a "summary" of counts and a
+    flat "cells" list whose entries are matched by "cell_id".
+
+    The returned report contains:
+    - status: "pass" when nothing differs, otherwise "fail"
+    - structural_diffs: one entry per differing summary count
+    - cell_diffs: missing cells, then added cells, then text / col_type
+      changes of cells present on both sides
+    - summary: how many diffs of each kind were found
+    """
+    expected_counts = reference.get("summary", {})
+    actual_counts = current.get("summary", {})
+
+    # Structural level: compare the aggregate counts; an absent count is 0.
+    structural_diffs = [
+        {
+            "field": field,
+            "reference": expected_counts.get(field, 0),
+            "current": actual_counts.get(field, 0),
+        }
+        for field in ("total_zones", "total_columns", "total_rows", "total_cells")
+        if expected_counts.get(field, 0) != actual_counts.get(field, 0)
+    ]
+
+    # Index both sides by cell_id (a repeated id keeps its last occurrence).
+    expected = {cell["cell_id"]: cell for cell in reference.get("cells", [])}
+    actual = {cell["cell_id"]: cell for cell in current.get("cells", [])}
+
+    # Cells that disappeared from the current grid.
+    cell_diffs: List[Dict[str, Any]] = [
+        {
+            "type": "cell_missing",
+            "cell_id": cid,
+            "reference_text": cell.get("text", ""),
+        }
+        for cid, cell in expected.items()
+        if cid not in actual
+    ]
+
+    # Cells that only exist in the current grid.
+    cell_diffs.extend(
+        {
+            "type": "cell_added",
+            "cell_id": cid,
+            "current_text": cell.get("text", ""),
+        }
+        for cid, cell in actual.items()
+        if cid not in expected
+    )
+
+    # Cells on both sides: report text first, then column type, per cell.
+    for cid, old in expected.items():
+        if cid not in actual:
+            continue
+        new = actual[cid]
+        for key, kind in (("text", "text_change"), ("col_type", "col_type_change")):
+            before, after = old.get(key, ""), new.get(key, "")
+            if before != after:
+                cell_diffs.append({
+                    "type": kind,
+                    "cell_id": cid,
+                    "reference": before,
+                    "current": after,
+                })
+
+    def _tally(kind: str) -> int:
+        return sum(1 for diff in cell_diffs if diff["type"] == kind)
+
+    return {
+        "status": "fail" if structural_diffs or cell_diffs else "pass",
+        "structural_diffs": structural_diffs,
+        "cell_diffs": cell_diffs,
+        "summary": {
+            "structural_changes": len(structural_diffs),
+            "cells_missing": _tally("cell_missing"),
+            "cells_added": _tally("cell_added"),
+            "text_changes": _tally("text_change"),
+            "col_type_changes": _tally("col_type_change"),
+        },
+    }
diff --git a/klausur-service/backend/ocr_pipeline_sessions.py b/klausur-service/backend/ocr_pipeline_sessions.py
index 22e00d2..ae3f771 100644
--- a/klausur-service/backend/ocr_pipeline_sessions.py
+++ b/klausur-service/backend/ocr_pipeline_sessions.py
@@ -1,597 +1,20 @@
"""
-OCR Pipeline Sessions API - Session management and image serving endpoints.
+OCR Pipeline Sessions API — barrel re-export.
-Extracted from ocr_pipeline_api.py for modularity.
-Handles: CRUD for sessions, thumbnails, pipeline logs, categories,
-image serving (with overlay dispatch), and document type detection.
+All implementation split into:
+ ocr_pipeline_sessions_crud — session CRUD, box sessions
+ ocr_pipeline_sessions_images — image serving, thumbnails, doc-type detection
Lizenz: Apache 2.0
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
"""
-import logging
-import time
-import uuid
-from typing import Any, Dict, Optional
+from fastapi import APIRouter
-import cv2
-import numpy as np
-from fastapi import APIRouter, File, Form, HTTPException, Query, UploadFile
-from fastapi.responses import Response
+from ocr_pipeline_sessions_crud import router as _crud_router # noqa: F401
+from ocr_pipeline_sessions_images import router as _images_router # noqa: F401
-from cv_vocab_pipeline import (
- create_ocr_image,
- detect_document_type,
- render_image_high_res,
- render_pdf_high_res,
-)
-from ocr_pipeline_common import (
- VALID_DOCUMENT_CATEGORIES,
- UpdateSessionRequest,
- _append_pipeline_log,
- _cache,
- _get_base_image_png,
- _get_cached,
- _load_session_to_cache,
-)
-from ocr_pipeline_overlays import render_overlay
-from ocr_pipeline_session_store import (
- create_session_db,
- delete_all_sessions_db,
- delete_session_db,
- get_session_db,
- get_session_image,
- get_sub_sessions,
- list_sessions_db,
- update_session_db,
-)
-
-logger = logging.getLogger(__name__)
-
-router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"])
-
-
-# ---------------------------------------------------------------------------
-# Session Management Endpoints
-# ---------------------------------------------------------------------------
-
-@router.get("/sessions")
-async def list_sessions(include_sub_sessions: bool = False):
- """List OCR pipeline sessions.
-
- By default, sub-sessions (box regions) are hidden.
- Pass ?include_sub_sessions=true to show them.
- """
- sessions = await list_sessions_db(include_sub_sessions=include_sub_sessions)
- return {"sessions": sessions}
-
-
-@router.post("/sessions")
-async def create_session(
- file: UploadFile = File(...),
- name: Optional[str] = Form(None),
-):
- """Upload a PDF or image file and create a pipeline session.
-
- For multi-page PDFs (> 1 page), each page becomes its own session
- grouped under a ``document_group_id``. The response includes a
- ``pages`` array with one entry per page/session.
- """
- file_data = await file.read()
- filename = file.filename or "upload"
- content_type = file.content_type or ""
-
- is_pdf = content_type == "application/pdf" or filename.lower().endswith(".pdf")
- session_name = name or filename
-
- # --- Multi-page PDF handling ---
- if is_pdf:
- try:
- import fitz # PyMuPDF
- pdf_doc = fitz.open(stream=file_data, filetype="pdf")
- page_count = pdf_doc.page_count
- pdf_doc.close()
- except Exception as e:
- raise HTTPException(status_code=400, detail=f"Could not read PDF: {e}")
-
- if page_count > 1:
- return await _create_multi_page_sessions(
- file_data, filename, session_name, page_count,
- )
-
- # --- Single page (image or 1-page PDF) ---
- session_id = str(uuid.uuid4())
-
- try:
- if is_pdf:
- img_bgr = render_pdf_high_res(file_data, page_number=0, zoom=3.0)
- else:
- img_bgr = render_image_high_res(file_data)
- except Exception as e:
- raise HTTPException(status_code=400, detail=f"Could not process file: {e}")
-
- # Encode original as PNG bytes
- success, png_buf = cv2.imencode(".png", img_bgr)
- if not success:
- raise HTTPException(status_code=500, detail="Failed to encode image")
-
- original_png = png_buf.tobytes()
-
- # Persist to DB
- await create_session_db(
- session_id=session_id,
- name=session_name,
- filename=filename,
- original_png=original_png,
- )
-
- # Cache BGR array for immediate processing
- _cache[session_id] = {
- "id": session_id,
- "filename": filename,
- "name": session_name,
- "original_bgr": img_bgr,
- "oriented_bgr": None,
- "cropped_bgr": None,
- "deskewed_bgr": None,
- "dewarped_bgr": None,
- "orientation_result": None,
- "crop_result": None,
- "deskew_result": None,
- "dewarp_result": None,
- "ground_truth": {},
- "current_step": 1,
- }
-
- logger.info(f"OCR Pipeline: created session {session_id} from {filename} "
- f"({img_bgr.shape[1]}x{img_bgr.shape[0]})")
-
- return {
- "session_id": session_id,
- "filename": filename,
- "name": session_name,
- "image_width": img_bgr.shape[1],
- "image_height": img_bgr.shape[0],
- "original_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/original",
- }
-
-
-async def _create_multi_page_sessions(
- pdf_data: bytes,
- filename: str,
- base_name: str,
- page_count: int,
-) -> dict:
- """Create one session per PDF page, grouped by document_group_id."""
- document_group_id = str(uuid.uuid4())
- pages = []
-
- for page_idx in range(page_count):
- session_id = str(uuid.uuid4())
- page_name = f"{base_name} — Seite {page_idx + 1}"
-
- try:
- img_bgr = render_pdf_high_res(pdf_data, page_number=page_idx, zoom=3.0)
- except Exception as e:
- logger.warning(f"Failed to render PDF page {page_idx + 1}: {e}")
- continue
-
- ok, png_buf = cv2.imencode(".png", img_bgr)
- if not ok:
- continue
- page_png = png_buf.tobytes()
-
- await create_session_db(
- session_id=session_id,
- name=page_name,
- filename=filename,
- original_png=page_png,
- document_group_id=document_group_id,
- page_number=page_idx + 1,
- )
-
- _cache[session_id] = {
- "id": session_id,
- "filename": filename,
- "name": page_name,
- "original_bgr": img_bgr,
- "oriented_bgr": None,
- "cropped_bgr": None,
- "deskewed_bgr": None,
- "dewarped_bgr": None,
- "orientation_result": None,
- "crop_result": None,
- "deskew_result": None,
- "dewarp_result": None,
- "ground_truth": {},
- "current_step": 1,
- }
-
- h, w = img_bgr.shape[:2]
- pages.append({
- "session_id": session_id,
- "name": page_name,
- "page_number": page_idx + 1,
- "image_width": w,
- "image_height": h,
- "original_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/original",
- })
-
- logger.info(
- f"OCR Pipeline: created page session {session_id} "
- f"(page {page_idx + 1}/{page_count}) from {filename} ({w}x{h})"
- )
-
- # Include session_id pointing to first page for backwards compatibility
- # (frontends that expect a single session_id will navigate to page 1)
- first_session_id = pages[0]["session_id"] if pages else None
-
- return {
- "session_id": first_session_id,
- "document_group_id": document_group_id,
- "filename": filename,
- "name": base_name,
- "page_count": page_count,
- "pages": pages,
- }
-
-
-@router.get("/sessions/{session_id}")
-async def get_session_info(session_id: str):
- """Get session info including deskew/dewarp/column results for step navigation."""
- session = await get_session_db(session_id)
- if not session:
- raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
-
- # Get image dimensions from original PNG
- original_png = await get_session_image(session_id, "original")
- if original_png:
- arr = np.frombuffer(original_png, dtype=np.uint8)
- img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
- img_w, img_h = img.shape[1], img.shape[0] if img is not None else (0, 0)
- else:
- img_w, img_h = 0, 0
-
- result = {
- "session_id": session["id"],
- "filename": session.get("filename", ""),
- "name": session.get("name", ""),
- "image_width": img_w,
- "image_height": img_h,
- "original_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/original",
- "current_step": session.get("current_step", 1),
- "document_category": session.get("document_category"),
- "doc_type": session.get("doc_type"),
- }
-
- if session.get("orientation_result"):
- result["orientation_result"] = session["orientation_result"]
- if session.get("crop_result"):
- result["crop_result"] = session["crop_result"]
- if session.get("deskew_result"):
- result["deskew_result"] = session["deskew_result"]
- if session.get("dewarp_result"):
- result["dewarp_result"] = session["dewarp_result"]
- if session.get("column_result"):
- result["column_result"] = session["column_result"]
- if session.get("row_result"):
- result["row_result"] = session["row_result"]
- if session.get("word_result"):
- result["word_result"] = session["word_result"]
- if session.get("doc_type_result"):
- result["doc_type_result"] = session["doc_type_result"]
- if session.get("structure_result"):
- result["structure_result"] = session["structure_result"]
- if session.get("grid_editor_result"):
- # Include summary only to keep response small
- gr = session["grid_editor_result"]
- result["grid_editor_result"] = {
- "summary": gr.get("summary", {}),
- "zones_count": len(gr.get("zones", [])),
- "edited": gr.get("edited", False),
- }
- if session.get("ground_truth"):
- result["ground_truth"] = session["ground_truth"]
-
- # Box sub-session info (zone_type='box' from column detection — NOT page-split)
- if session.get("parent_session_id"):
- result["parent_session_id"] = session["parent_session_id"]
- result["box_index"] = session.get("box_index")
- else:
- # Check for box sub-sessions (column detection creates these)
- subs = await get_sub_sessions(session_id)
- if subs:
- result["sub_sessions"] = [
- {"id": s["id"], "name": s.get("name"), "box_index": s.get("box_index")}
- for s in subs
- ]
-
- return result
-
-
-@router.put("/sessions/{session_id}")
-async def update_session(session_id: str, req: UpdateSessionRequest):
- """Update session name and/or document category."""
- kwargs: Dict[str, Any] = {}
- if req.name is not None:
- kwargs["name"] = req.name
- if req.document_category is not None:
- if req.document_category not in VALID_DOCUMENT_CATEGORIES:
- raise HTTPException(
- status_code=400,
- detail=f"Invalid category '{req.document_category}'. Valid: {sorted(VALID_DOCUMENT_CATEGORIES)}",
- )
- kwargs["document_category"] = req.document_category
- if not kwargs:
- raise HTTPException(status_code=400, detail="Nothing to update")
- updated = await update_session_db(session_id, **kwargs)
- if not updated:
- raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
- return {"session_id": session_id, **kwargs}
-
-
-@router.delete("/sessions/{session_id}")
-async def delete_session(session_id: str):
- """Delete a session."""
- _cache.pop(session_id, None)
- deleted = await delete_session_db(session_id)
- if not deleted:
- raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
- return {"session_id": session_id, "deleted": True}
-
-
-@router.delete("/sessions")
-async def delete_all_sessions():
- """Delete ALL sessions (cleanup)."""
- _cache.clear()
- count = await delete_all_sessions_db()
- return {"deleted_count": count}
-
-
-@router.post("/sessions/{session_id}/create-box-sessions")
-async def create_box_sessions(session_id: str):
- """Create sub-sessions for each detected box region.
-
- Crops box regions from the cropped/dewarped image and creates
- independent sub-sessions that can be processed through the pipeline.
- """
- session = await get_session_db(session_id)
- if not session:
- raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
-
- column_result = session.get("column_result")
- if not column_result:
- raise HTTPException(status_code=400, detail="Column detection must be completed first")
-
- zones = column_result.get("zones") or []
- box_zones = [z for z in zones if z.get("zone_type") == "box" and z.get("box")]
- if not box_zones:
- return {"session_id": session_id, "sub_sessions": [], "message": "No boxes detected"}
-
- # Check for existing sub-sessions
- existing = await get_sub_sessions(session_id)
- if existing:
- return {
- "session_id": session_id,
- "sub_sessions": [{"id": s["id"], "box_index": s.get("box_index")} for s in existing],
- "message": f"{len(existing)} sub-session(s) already exist",
- }
-
- # Load base image
- base_png = await get_session_image(session_id, "cropped")
- if not base_png:
- base_png = await get_session_image(session_id, "dewarped")
- if not base_png:
- raise HTTPException(status_code=400, detail="No base image available")
-
- arr = np.frombuffer(base_png, dtype=np.uint8)
- img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
- if img is None:
- raise HTTPException(status_code=500, detail="Failed to decode image")
-
- parent_name = session.get("name", "Session")
- created = []
-
- for i, zone in enumerate(box_zones):
- box = zone["box"]
- bx, by = box["x"], box["y"]
- bw, bh = box["width"], box["height"]
-
- # Crop box region with small padding
- pad = 5
- y1 = max(0, by - pad)
- y2 = min(img.shape[0], by + bh + pad)
- x1 = max(0, bx - pad)
- x2 = min(img.shape[1], bx + bw + pad)
- crop = img[y1:y2, x1:x2]
-
- # Encode as PNG
- success, png_buf = cv2.imencode(".png", crop)
- if not success:
- logger.warning(f"Failed to encode box {i} crop for session {session_id}")
- continue
-
- sub_id = str(uuid.uuid4())
- sub_name = f"{parent_name} — Box {i + 1}"
-
- await create_session_db(
- session_id=sub_id,
- name=sub_name,
- filename=session.get("filename", "box-crop.png"),
- original_png=png_buf.tobytes(),
- parent_session_id=session_id,
- box_index=i,
- )
-
- # Cache the BGR for immediate processing
- # Promote original to cropped so column/row/word detection finds it
- box_bgr = crop.copy()
- _cache[sub_id] = {
- "id": sub_id,
- "filename": session.get("filename", "box-crop.png"),
- "name": sub_name,
- "parent_session_id": session_id,
- "original_bgr": box_bgr,
- "oriented_bgr": None,
- "cropped_bgr": box_bgr,
- "deskewed_bgr": None,
- "dewarped_bgr": None,
- "orientation_result": None,
- "crop_result": None,
- "deskew_result": None,
- "dewarp_result": None,
- "ground_truth": {},
- "current_step": 1,
- }
-
- created.append({
- "id": sub_id,
- "name": sub_name,
- "box_index": i,
- "box": box,
- "image_width": crop.shape[1],
- "image_height": crop.shape[0],
- })
-
- logger.info(f"Created box sub-session {sub_id} for session {session_id} "
- f"(box {i}, {crop.shape[1]}x{crop.shape[0]})")
-
- return {
- "session_id": session_id,
- "sub_sessions": created,
- "total": len(created),
- }
-
-
-@router.get("/sessions/{session_id}/thumbnail")
-async def get_session_thumbnail(session_id: str, size: int = Query(default=80, ge=16, le=400)):
- """Return a small thumbnail of the original image."""
- original_png = await get_session_image(session_id, "original")
- if not original_png:
- raise HTTPException(status_code=404, detail=f"Session {session_id} not found or no image")
- arr = np.frombuffer(original_png, dtype=np.uint8)
- img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
- if img is None:
- raise HTTPException(status_code=500, detail="Failed to decode image")
- h, w = img.shape[:2]
- scale = size / max(h, w)
- new_w, new_h = int(w * scale), int(h * scale)
- thumb = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_AREA)
- _, png_bytes = cv2.imencode(".png", thumb)
- return Response(content=png_bytes.tobytes(), media_type="image/png",
- headers={"Cache-Control": "public, max-age=3600"})
-
-
-@router.get("/sessions/{session_id}/pipeline-log")
-async def get_pipeline_log(session_id: str):
- """Get the pipeline execution log for a session."""
- session = await get_session_db(session_id)
- if not session:
- raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
- return {"session_id": session_id, "pipeline_log": session.get("pipeline_log") or {"steps": []}}
-
-
-@router.get("/categories")
-async def list_categories():
- """List valid document categories."""
- return {"categories": sorted(VALID_DOCUMENT_CATEGORIES)}
-
-
-# ---------------------------------------------------------------------------
-# Image Endpoints
-# ---------------------------------------------------------------------------
-
-@router.get("/sessions/{session_id}/image/{image_type}")
-async def get_image(session_id: str, image_type: str):
- """Serve session images: original, deskewed, dewarped, binarized, structure-overlay, columns-overlay, or rows-overlay."""
- valid_types = {"original", "oriented", "cropped", "deskewed", "dewarped", "binarized", "structure-overlay", "columns-overlay", "rows-overlay", "words-overlay", "clean"}
- if image_type not in valid_types:
- raise HTTPException(status_code=400, detail=f"Unknown image type: {image_type}")
-
- if image_type == "structure-overlay":
- return await render_overlay("structure", session_id)
-
- if image_type == "columns-overlay":
- return await render_overlay("columns", session_id)
-
- if image_type == "rows-overlay":
- return await render_overlay("rows", session_id)
-
- if image_type == "words-overlay":
- return await render_overlay("words", session_id)
-
- # Try cache first for fast serving
- cached = _cache.get(session_id)
- if cached:
- png_key = f"{image_type}_png" if image_type != "original" else None
- bgr_key = f"{image_type}_bgr" if image_type != "binarized" else None
-
- # For binarized, check if we have it cached as PNG
- if image_type == "binarized" and cached.get("binarized_png"):
- return Response(content=cached["binarized_png"], media_type="image/png")
-
- # Load from DB — for cropped/dewarped, fall back through the chain
- if image_type in ("cropped", "dewarped"):
- data = await _get_base_image_png(session_id)
- else:
- data = await get_session_image(session_id, image_type)
- if not data:
- raise HTTPException(status_code=404, detail=f"Image '{image_type}' not available yet")
-
- return Response(content=data, media_type="image/png")
-
-
-# ---------------------------------------------------------------------------
-# Document Type Detection (between Dewarp and Columns)
-# ---------------------------------------------------------------------------
-
-@router.post("/sessions/{session_id}/detect-type")
-async def detect_type(session_id: str):
- """Detect document type (vocab_table, full_text, generic_table).
-
- Should be called after crop (clean image available).
- Falls back to dewarped if crop was skipped.
- Stores result in session for frontend to decide pipeline flow.
- """
- if session_id not in _cache:
- await _load_session_to_cache(session_id)
- cached = _get_cached(session_id)
-
- img_bgr = cached.get("cropped_bgr") if cached.get("cropped_bgr") is not None else cached.get("dewarped_bgr")
- if img_bgr is None:
- raise HTTPException(status_code=400, detail="Crop or dewarp must be completed first")
-
- t0 = time.time()
- ocr_img = create_ocr_image(img_bgr)
- result = detect_document_type(ocr_img, img_bgr)
- duration = time.time() - t0
-
- result_dict = {
- "doc_type": result.doc_type,
- "confidence": result.confidence,
- "pipeline": result.pipeline,
- "skip_steps": result.skip_steps,
- "features": result.features,
- "duration_seconds": round(duration, 2),
- }
-
- # Persist to DB
- await update_session_db(
- session_id,
- doc_type=result.doc_type,
- doc_type_result=result_dict,
- )
-
- cached["doc_type_result"] = result_dict
-
- logger.info(f"OCR Pipeline: detect-type session {session_id}: "
- f"{result.doc_type} (confidence={result.confidence}, {duration:.2f}s)")
-
- await _append_pipeline_log(session_id, "detect_type", {
- "doc_type": result.doc_type,
- "pipeline": result.pipeline,
- "confidence": result.confidence,
- **{k: v for k, v in (result.features or {}).items() if isinstance(v, (int, float, str, bool))},
- }, duration_ms=int(duration * 1000))
-
- return {"session_id": session_id, **result_dict}
+# Composite router (used by ocr_pipeline_api.py): merges the CRUD and image routers.
+router = APIRouter()
+router.include_router(_crud_router)
+router.include_router(_images_router)
diff --git a/klausur-service/backend/ocr_pipeline_sessions_crud.py b/klausur-service/backend/ocr_pipeline_sessions_crud.py
new file mode 100644
index 0000000..19343d7
--- /dev/null
+++ b/klausur-service/backend/ocr_pipeline_sessions_crud.py
@@ -0,0 +1,449 @@
+"""
+OCR Pipeline Sessions CRUD — session create, read, update, delete, box sessions.
+
+Extracted from ocr_pipeline_sessions.py for modularity.
+
+Lizenz: Apache 2.0
+DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
+"""
+
+import logging
+import uuid
+from typing import Any, Dict, Optional
+
+import cv2
+import numpy as np
+from fastapi import APIRouter, File, Form, HTTPException, Query, UploadFile
+
+from cv_vocab_pipeline import render_image_high_res, render_pdf_high_res
+from ocr_pipeline_common import (
+ VALID_DOCUMENT_CATEGORIES,
+ UpdateSessionRequest,
+ _cache,
+)
+from ocr_pipeline_session_store import (
+ create_session_db,
+ delete_all_sessions_db,
+ delete_session_db,
+ get_session_db,
+ get_session_image,
+ get_sub_sessions,
+ list_sessions_db,
+ update_session_db,
+)
+
+logger = logging.getLogger(__name__)
+
+router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"])
+
+
+# ---------------------------------------------------------------------------
+# Session Management Endpoints
+# ---------------------------------------------------------------------------
+
+@router.get("/sessions")
+async def list_sessions(include_sub_sessions: bool = False):
+    """Return the stored OCR pipeline sessions.
+
+    Box sub-sessions are left out unless the caller passes
+    ``?include_sub_sessions=true``.
+    """
+    rows = await list_sessions_db(include_sub_sessions=include_sub_sessions)
+    return {"sessions": rows}
+
+
+@router.post("/sessions")
+async def create_session(
+    file: UploadFile = File(...),
+    name: Optional[str] = Form(None),
+):
+    """Create a pipeline session from an uploaded PDF or image.
+
+    A PDF with more than one page is delegated to
+    ``_create_multi_page_sessions`` (one session per page, shared
+    ``document_group_id``, response carries a ``pages`` array).
+    """
+    file_data = await file.read()
+    filename = file.filename or "upload"
+    content_type = file.content_type or ""
+
+    session_name = name or filename
+    is_pdf = content_type == "application/pdf" or filename.lower().endswith(".pdf")
+
+    # PDFs: count pages first so multi-page uploads can be fanned out
+    if is_pdf:
+        try:
+            import fitz  # PyMuPDF
+            pdf_doc = fitz.open(stream=file_data, filetype="pdf")
+            page_count = pdf_doc.page_count
+            pdf_doc.close()
+        except Exception as e:
+            raise HTTPException(status_code=400, detail=f"Could not read PDF: {e}")
+
+        if page_count > 1:
+            return await _create_multi_page_sessions(
+                file_data, filename, session_name, page_count,
+            )
+
+    # Everything below handles exactly one page (image or 1-page PDF)
+    session_id = str(uuid.uuid4())
+
+    try:
+        img_bgr = (
+            render_pdf_high_res(file_data, page_number=0, zoom=3.0)
+            if is_pdf else render_image_high_res(file_data)
+        )
+    except Exception as e:
+        raise HTTPException(status_code=400, detail=f"Could not process file: {e}")
+
+    # The stored original is always PNG, regardless of the upload format
+    success, png_buf = cv2.imencode(".png", img_bgr)
+    if not success:
+        raise HTTPException(status_code=500, detail="Failed to encode image")
+
+    # Write the DB row first; the in-memory cache entry is added afterwards
+    await create_session_db(
+        session_id=session_id,
+        name=session_name,
+        filename=filename,
+        original_png=png_buf.tobytes(),
+    )
+
+    # Cache BGR array for immediate processing
+    _cache[session_id] = {
+        "id": session_id,
+        "filename": filename,
+        "name": session_name,
+        "original_bgr": img_bgr,
+        "oriented_bgr": None,
+        "cropped_bgr": None,
+        "deskewed_bgr": None,
+        "dewarped_bgr": None,
+        "orientation_result": None,
+        "crop_result": None,
+        "deskew_result": None,
+        "dewarp_result": None,
+        "ground_truth": {},
+        "current_step": 1,
+    }
+
+    # Pixel size of the rendered page, used for the log line and the response
+    img_h, img_w = img_bgr.shape[:2]
+    logger.info(f"OCR Pipeline: created session {session_id} from {filename} "
+                f"({img_w}x{img_h})")
+
+    return {
+        "session_id": session_id,
+        "filename": filename,
+        "name": session_name,
+        "image_width": img_w,
+        "image_height": img_h,
+        "original_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/original",
+    }
+
+
+async def _create_multi_page_sessions(
+    pdf_data: bytes,
+    filename: str,
+    base_name: str,
+    page_count: int,
+) -> dict:
+    """Create one session per PDF page, grouped by document_group_id.
+
+    Pages that fail to render or encode are skipped with a warning.
+    Raises HTTPException(400) when not a single page could be created.
+    """
+    document_group_id = str(uuid.uuid4())
+    pages = []
+
+    for page_idx in range(page_count):
+        session_id = str(uuid.uuid4())
+        page_name = f"{base_name} — Seite {page_idx + 1}"
+
+        try:
+            img_bgr = render_pdf_high_res(pdf_data, page_number=page_idx, zoom=3.0)
+        except Exception as e:
+            logger.warning(f"Failed to render PDF page {page_idx + 1}: {e}")
+            continue
+
+        ok, png_buf = cv2.imencode(".png", img_bgr)
+        if not ok:
+            logger.warning(f"Failed to encode PDF page {page_idx + 1} as PNG")
+            continue
+
+        await create_session_db(
+            session_id=session_id,
+            name=page_name,
+            filename=filename,
+            original_png=png_buf.tobytes(),
+            document_group_id=document_group_id,
+            page_number=page_idx + 1,
+        )
+
+        _cache[session_id] = {
+            "id": session_id,
+            "filename": filename,
+            "name": page_name,
+            "original_bgr": img_bgr,
+            "oriented_bgr": None,
+            "cropped_bgr": None,
+            "deskewed_bgr": None,
+            "dewarped_bgr": None,
+            "orientation_result": None,
+            "crop_result": None,
+            "deskew_result": None,
+            "dewarp_result": None,
+            "ground_truth": {},
+            "current_step": 1,
+        }
+
+        h, w = img_bgr.shape[:2]
+        pages.append({
+            "session_id": session_id,
+            "name": page_name,
+            "page_number": page_idx + 1,
+            "image_width": w,
+            "image_height": h,
+            "original_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/original",
+        })
+
+        logger.info(f"OCR Pipeline: created page session {session_id} "
+                    f"(page {page_idx + 1}/{page_count}) from {filename} ({w}x{h})")
+
+    if not pages:
+        raise HTTPException(status_code=400, detail="Could not render any page of the PDF")
+    return {
+        "session_id": pages[0]["session_id"],  # first page, for frontends expecting one session
+        "document_group_id": document_group_id,
+        "filename": filename,
+        "name": base_name,
+        "page_count": page_count,
+        "pages": pages,
+    }
+
+
+@router.get("/sessions/{session_id}")
+async def get_session_info(session_id: str):
+    """Return session metadata plus the stored step results.
+
+    Optional step results are only included when present, and
+    ``grid_editor_result`` is reduced to a summary to keep the response small.
+    ``image_width``/``image_height`` are 0 when the original PNG is missing
+    or cannot be decoded. Raises HTTPException(404) for an unknown session.
+    """
+    session = await get_session_db(session_id)
+    if not session:
+        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
+
+    # Image dimensions come from the stored original PNG (0x0 if unavailable).
+    img_w, img_h = 0, 0
+    original_png = await get_session_image(session_id, "original")
+    if original_png:
+        arr = np.frombuffer(original_png, dtype=np.uint8)
+        img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
+        # imdecode yields None for undecodable data; guard explicitly so
+        # ``img.shape`` is never touched in that case.
+        if img is not None:
+            img_h, img_w = img.shape[:2]
+
+    result = {
+        "session_id": session["id"],
+        "filename": session.get("filename", ""),
+        "name": session.get("name", ""),
+        "image_width": img_w,
+        "image_height": img_h,
+        "original_image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/original",
+        "current_step": session.get("current_step", 1),
+        "document_category": session.get("document_category"),
+        "doc_type": session.get("doc_type"),
+    }
+
+    # Step results are optional: copy only the ones that are present.
+    for key in (
+        "orientation_result", "crop_result", "deskew_result", "dewarp_result",
+        "column_result", "row_result", "word_result", "doc_type_result",
+        "structure_result",
+    ):
+        if session.get(key):
+            result[key] = session[key]
+
+    if session.get("grid_editor_result"):
+        # Include summary only to keep response small
+        gr = session["grid_editor_result"]
+        result["grid_editor_result"] = {
+            "summary": gr.get("summary", {}),
+            "zones_count": len(gr.get("zones", [])),
+            "edited": gr.get("edited", False),
+        }
+
+    if session.get("ground_truth"):
+        result["ground_truth"] = session["ground_truth"]
+
+    # Box sub-session info (zone_type='box' from column detection — NOT page-split)
+    if session.get("parent_session_id"):
+        result["parent_session_id"] = session["parent_session_id"]
+        result["box_index"] = session.get("box_index")
+    else:
+        # Check for box sub-sessions (column detection creates these)
+        subs = await get_sub_sessions(session_id)
+        if subs:
+            result["sub_sessions"] = [
+                {"id": s["id"], "name": s.get("name"), "box_index": s.get("box_index")}
+                for s in subs
+            ]
+
+    return result
+
+
+@router.put("/sessions/{session_id}")
+async def update_session(session_id: str, req: UpdateSessionRequest):
+    """Rename a session and/or change its document category."""
+    category = req.document_category
+    if category is not None and category not in VALID_DOCUMENT_CATEGORIES:
+        raise HTTPException(
+            status_code=400,
+            detail=f"Invalid category '{category}'. Valid: {sorted(VALID_DOCUMENT_CATEGORIES)}",
+        )
+    changes: Dict[str, Any] = {}
+    if req.name is not None:
+        changes["name"] = req.name
+    if category is not None:
+        changes["document_category"] = category
+    if not changes:
+        raise HTTPException(status_code=400, detail="Nothing to update")
+    if not await update_session_db(session_id, **changes):
+        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
+    return {"session_id": session_id, **changes}
+
+
+@router.delete("/sessions/{session_id}")
+async def delete_session(session_id: str):
+    """Drop one session from the cache and the database (404 if unknown)."""
+    _cache.pop(session_id, None)
+    if not await delete_session_db(session_id):
+        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
+    # Reached only when delete_session_db returned a truthy value
+    return {"session_id": session_id, "deleted": True}
+
+
+@router.delete("/sessions")
+async def delete_all_sessions():
+    """Clear the in-memory cache and delete every session from the database."""
+    _cache.clear()
+    deleted_count = await delete_all_sessions_db()
+    return {"deleted_count": deleted_count}
+
+
+@router.post("/sessions/{session_id}/create-box-sessions")
+async def create_box_sessions(session_id: str):
+    """Create sub-sessions for each detected box region.
+
+    Crops every ``zone_type == "box"`` zone (5 px padding) from the cropped or
+    dewarped image; boxes whose crop is empty or cannot be encoded are skipped.
+    """
+    session = await get_session_db(session_id)
+    if not session:
+        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
+
+    column_result = session.get("column_result")
+    if not column_result:
+        raise HTTPException(status_code=400, detail="Column detection must be completed first")
+
+    zones = column_result.get("zones") or []
+    box_zones = [z for z in zones if z.get("zone_type") == "box" and z.get("box")]
+    if not box_zones:
+        return {"session_id": session_id, "sub_sessions": [], "message": "No boxes detected"}
+
+    # Existing sub-sessions are returned as-is; nothing is created twice
+    existing = await get_sub_sessions(session_id)
+    if existing:
+        return {
+            "session_id": session_id,
+            "sub_sessions": [{"id": s["id"], "box_index": s.get("box_index")} for s in existing],
+            "message": f"{len(existing)} sub-session(s) already exist",
+        }
+
+    # Load base image: prefer the cropped one, otherwise the dewarped one
+    base_png = await get_session_image(session_id, "cropped")
+    if not base_png:
+        base_png = await get_session_image(session_id, "dewarped")
+    if not base_png:
+        raise HTTPException(status_code=400, detail="No base image available")
+
+    arr = np.frombuffer(base_png, dtype=np.uint8)
+    img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
+    if img is None:
+        raise HTTPException(status_code=500, detail="Failed to decode image")
+
+    parent_name = session.get("name", "Session")
+    created = []
+
+    for i, zone in enumerate(box_zones):
+        box = zone["box"]
+        bx, by = int(box["x"]), int(box["y"])  # slice bounds must be ints
+        bw, bh = int(box["width"]), int(box["height"])
+
+        # Crop box region with small padding, clamped to the image bounds
+        pad = 5
+        y1 = max(0, by - pad)
+        y2 = min(img.shape[0], by + bh + pad)
+        x1 = max(0, bx - pad)
+        x2 = min(img.shape[1], bx + bw + pad)
+        crop = img[y1:y2, x1:x2]
+
+        # cv2.imencode raises on an empty image, so degenerate crops are rejected first
+        success, png_buf = cv2.imencode(".png", crop) if crop.size else (False, None)
+        if not success:
+            logger.warning(f"Skipping box {i} for session {session_id}: empty or unencodable crop")
+            continue
+
+        sub_id = str(uuid.uuid4())
+        sub_name = f"{parent_name} — Box {i + 1}"
+
+        await create_session_db(
+            session_id=sub_id,
+            name=sub_name,
+            filename=session.get("filename", "box-crop.png"),
+            original_png=png_buf.tobytes(),
+            parent_session_id=session_id,
+            box_index=i,
+        )
+
+        # Cache the BGR for immediate processing
+        # Promote original to cropped so column/row/word detection finds it
+        box_bgr = crop.copy()
+        _cache[sub_id] = {
+            "id": sub_id,
+            "filename": session.get("filename", "box-crop.png"),
+            "name": sub_name,
+            "parent_session_id": session_id,
+            "original_bgr": box_bgr,
+            "oriented_bgr": None,
+            "cropped_bgr": box_bgr,
+            "deskewed_bgr": None,
+            "dewarped_bgr": None,
+            "orientation_result": None,
+            "crop_result": None,
+            "deskew_result": None,
+            "dewarp_result": None,
+            "ground_truth": {},
+            "current_step": 1,
+        }
+
+        created.append({
+            "id": sub_id,
+            "name": sub_name,
+            "box_index": i,
+            "box": box,
+            "image_width": crop.shape[1],
+            "image_height": crop.shape[0],
+        })
+
+        logger.info(f"Created box sub-session {sub_id} for session {session_id} "
+                    f"(box {i}, {crop.shape[1]}x{crop.shape[0]})")
+
+    return {
+        "session_id": session_id,
+        "sub_sessions": created,
+        "total": len(created),
+    }
diff --git a/klausur-service/backend/ocr_pipeline_sessions_images.py b/klausur-service/backend/ocr_pipeline_sessions_images.py
new file mode 100644
index 0000000..79da448
--- /dev/null
+++ b/klausur-service/backend/ocr_pipeline_sessions_images.py
@@ -0,0 +1,176 @@
+"""
+OCR Pipeline Sessions Images — image serving, thumbnails, pipeline log,
+categories, and document type detection.
+
+Extracted from ocr_pipeline_sessions.py for modularity.
+
+Lizenz: Apache 2.0
+DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
+"""
+
+import logging
+import time
+from typing import Any, Dict
+
+import cv2
+import numpy as np
+from fastapi import APIRouter, HTTPException, Query
+from fastapi.responses import Response
+
+from cv_vocab_pipeline import create_ocr_image, detect_document_type
+from ocr_pipeline_common import (
+ VALID_DOCUMENT_CATEGORIES,
+ _append_pipeline_log,
+ _cache,
+ _get_base_image_png,
+ _get_cached,
+ _load_session_to_cache,
+)
+from ocr_pipeline_overlays import render_overlay
+from ocr_pipeline_session_store import (
+ get_session_db,
+ get_session_image,
+ update_session_db,
+)
+
+logger = logging.getLogger(__name__)
+
+router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"])
+
+
+# ---------------------------------------------------------------------------
+# Thumbnail & Log Endpoints
+# ---------------------------------------------------------------------------
+
+@router.get("/sessions/{session_id}/thumbnail")
+async def get_session_thumbnail(session_id: str, size: int = Query(default=80, ge=16, le=400)):
+    """Return a PNG thumbnail of the original image whose longer side is about ``size`` px."""
+    original_png = await get_session_image(session_id, "original")
+    if not original_png:
+        raise HTTPException(status_code=404, detail=f"Session {session_id} not found or no image")
+    img = cv2.imdecode(np.frombuffer(original_png, dtype=np.uint8), cv2.IMREAD_COLOR)
+    if img is None:
+        raise HTTPException(status_code=500, detail="Failed to decode image")
+    h, w = img.shape[:2]
+    scale = size / max(h, w)
+    new_w, new_h = max(1, int(w * scale)), max(1, int(h * scale))  # never 0: resize rejects empty sizes
+    ok, png_bytes = cv2.imencode(".png", cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_AREA))
+    if not ok:
+        raise HTTPException(status_code=500, detail="Failed to encode thumbnail")
+    return Response(content=png_bytes.tobytes(), media_type="image/png",
+                    headers={"Cache-Control": "public, max-age=3600"})
+
+
+@router.get("/sessions/{session_id}/pipeline-log")
+async def get_pipeline_log(session_id: str):
+    """Return the stored pipeline log, or an empty step list if none is stored."""
+    session = await get_session_db(session_id)
+    if session:
+        return {"session_id": session_id, "pipeline_log": session.get("pipeline_log") or {"steps": []}}
+    raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
+
+
+@router.get("/categories")
+async def list_categories():
+    """Return the valid document categories in sorted order."""
+    return {"categories": sorted(VALID_DOCUMENT_CATEGORIES)}
+
+
+# ---------------------------------------------------------------------------
+# Image Endpoints
+# ---------------------------------------------------------------------------
+
+@router.get("/sessions/{session_id}/image/{image_type}")
+async def get_image(session_id: str, image_type: str):
+    """Serve a session image or overlay as PNG.
+
+    ``*-overlay`` types are rendered on demand; every other type is read
+    from storage, with ``binarized`` looked up in the cache first.
+    Raises HTTPException 400 for an unknown type, 404 if not available yet.
+    """
+    stored_types = {"original", "oriented", "cropped", "deskewed", "dewarped", "binarized", "clean"}
+    # Overlay image types mapped to the kind passed to render_overlay
+    overlay_kinds = {
+        "structure-overlay": "structure",
+        "columns-overlay": "columns",
+        "rows-overlay": "rows",
+        "words-overlay": "words",
+    }
+
+    if image_type in overlay_kinds:
+        return await render_overlay(overlay_kinds[image_type], session_id)
+    if image_type not in stored_types:
+        raise HTTPException(status_code=400, detail=f"Unknown image type: {image_type}")
+
+    # The in-memory cache is consulted for the binarized PNG only;
+    # everything else is loaded below.
+    if image_type == "binarized":
+        cached = _cache.get(session_id)
+        if cached and cached.get("binarized_png"):
+            return Response(content=cached["binarized_png"], media_type="image/png")
+
+    # Load from DB — for cropped/dewarped, fall back through the chain
+    if image_type in ("cropped", "dewarped"):
+        data = await _get_base_image_png(session_id)
+    else:
+        data = await get_session_image(session_id, image_type)
+    if not data:
+        raise HTTPException(status_code=404, detail=f"Image '{image_type}' not available yet")
+
+    return Response(content=data, media_type="image/png")
+
+
+# ---------------------------------------------------------------------------
+# Document Type Detection (between Dewarp and Columns)
+# ---------------------------------------------------------------------------
+
+@router.post("/sessions/{session_id}/detect-type")
+async def detect_type(session_id: str):
+    """Classify the document and store the result on the session.
+
+    Uses the cropped image when present, otherwise the dewarped one, and
+    responds with HTTP 400 if neither exists. The outcome is written to the
+    DB, the cached session dict and the pipeline log (scalar features only).
+    """
+    if session_id not in _cache:
+        await _load_session_to_cache(session_id)
+    cached = _get_cached(session_id)
+
+    img_bgr = cached.get("cropped_bgr")
+    if img_bgr is None:
+        img_bgr = cached.get("dewarped_bgr")
+    if img_bgr is None:
+        raise HTTPException(status_code=400, detail="Crop or dewarp must be completed first")
+
+    started = time.time()
+    result = detect_document_type(create_ocr_image(img_bgr), img_bgr)
+    duration = time.time() - started
+
+    result_dict = {
+        "doc_type": result.doc_type,
+        "confidence": result.confidence,
+        "pipeline": result.pipeline,
+        "skip_steps": result.skip_steps,
+        "features": result.features,
+        "duration_seconds": round(duration, 2),
+    }
+
+    # Persist to DB, then mirror the result into the cached session dict
+    await update_session_db(session_id, doc_type=result.doc_type, doc_type_result=result_dict)
+    cached["doc_type_result"] = result_dict
+
+    logger.info(f"OCR Pipeline: detect-type session {session_id}: "
+                f"{result.doc_type} (confidence={result.confidence}, {duration:.2f}s)")
+
+    scalar_features = {
+        k: v for k, v in (result.features or {}).items()
+        if isinstance(v, (int, float, str, bool))
+    }
+    await _append_pipeline_log(session_id, "detect_type", {
+        "doc_type": result.doc_type,
+        "pipeline": result.pipeline,
+        "confidence": result.confidence,
+        **scalar_features,
+    }, duration_ms=int(duration * 1000))
+
+    return {"session_id": session_id, **result_dict}
diff --git a/klausur-service/backend/self_rag.py b/klausur-service/backend/self_rag.py
index 61de02c..9bcf871 100644
--- a/klausur-service/backend/self_rag.py
+++ b/klausur-service/backend/self_rag.py
@@ -1,529 +1,38 @@
"""
-Self-RAG / Corrective RAG Module
+Self-RAG / Corrective RAG Module — barrel re-export.
-Implements self-reflective RAG that can:
-1. Grade retrieved documents for relevance
-2. Decide if more retrieval is needed
-3. Reformulate queries if initial retrieval fails
-4. Filter irrelevant passages before generation
-5. Grade answers for groundedness and hallucination
+All implementation split into:
+ self_rag_grading — document relevance grading, filtering, decisions
+ self_rag_retrieval — query reformulation, retrieval loop, info
+
+IMPORTANT: Self-RAG is DISABLED by default for privacy reasons!
+When enabled, search queries and retrieved documents are sent to OpenAI API.
Based on research:
-- Self-RAG (Asai et al., 2023): Learning to retrieve, generate, and critique
-- Corrective RAG (Yan et al., 2024): Self-correcting retrieval augmented generation
-
-This is especially useful for German educational documents where:
-- Queries may use informal language
-- Documents use formal/technical terminology
-- Context must be precisely matched to scoring criteria
+- Self-RAG (Asai et al., 2023)
+- Corrective RAG (Yan et al., 2024)
"""
-import os
-from typing import List, Dict, Optional, Tuple
-from enum import Enum
-import httpx
-
-# Configuration
-# IMPORTANT: Self-RAG is DISABLED by default for privacy reasons!
-# When enabled, search queries and retrieved documents are sent to OpenAI API
-# for relevance grading and query reformulation. This exposes user data to third parties.
-# Only enable if you have explicit user consent for data processing.
-SELF_RAG_ENABLED = os.getenv("SELF_RAG_ENABLED", "false").lower() == "true"
-OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "")
-SELF_RAG_MODEL = os.getenv("SELF_RAG_MODEL", "gpt-4o-mini")
-
-# Thresholds for self-reflection
-RELEVANCE_THRESHOLD = float(os.getenv("SELF_RAG_RELEVANCE_THRESHOLD", "0.6"))
-GROUNDING_THRESHOLD = float(os.getenv("SELF_RAG_GROUNDING_THRESHOLD", "0.7"))
-MAX_RETRIEVAL_ATTEMPTS = int(os.getenv("SELF_RAG_MAX_ATTEMPTS", "2"))
-
-
-class RetrievalDecision(Enum):
- """Decision after grading retrieval."""
- SUFFICIENT = "sufficient" # Context is good, proceed to generation
- NEEDS_MORE = "needs_more" # Need to retrieve more documents
- REFORMULATE = "reformulate" # Query needs reformulation
- FALLBACK = "fallback" # Use fallback (no good context found)
-
-
-class SelfRAGError(Exception):
- """Error during Self-RAG processing."""
- pass
-
-
-async def grade_document_relevance(
- query: str,
- document: str,
-) -> Tuple[float, str]:
- """
- Grade whether a document is relevant to the query.
-
- Returns a score between 0 (irrelevant) and 1 (highly relevant)
- along with an explanation.
- """
- if not OPENAI_API_KEY:
- # Fallback: simple keyword overlap
- query_words = set(query.lower().split())
- doc_words = set(document.lower().split())
- overlap = len(query_words & doc_words) / max(len(query_words), 1)
- return min(overlap * 2, 1.0), "Keyword-based relevance (no LLM)"
-
- prompt = f"""Bewerte, ob das folgende Dokument relevant für die Suchanfrage ist.
-
-SUCHANFRAGE: {query}
-
-DOKUMENT:
-{document[:2000]}
-
-Ist dieses Dokument relevant, um die Anfrage zu beantworten?
-Berücksichtige:
-- Thematische Übereinstimmung
-- Enthält das Dokument spezifische Informationen zur Anfrage?
-- Würde dieses Dokument bei der Beantwortung helfen?
-
-Antworte im Format:
-SCORE: [0.0-1.0]
-BEGRÜNDUNG: [Kurze Erklärung]"""
-
- try:
- async with httpx.AsyncClient() as client:
- response = await client.post(
- "https://api.openai.com/v1/chat/completions",
- headers={
- "Authorization": f"Bearer {OPENAI_API_KEY}",
- "Content-Type": "application/json"
- },
- json={
- "model": SELF_RAG_MODEL,
- "messages": [{"role": "user", "content": prompt}],
- "max_tokens": 150,
- "temperature": 0.0,
- },
- timeout=30.0
- )
-
- if response.status_code != 200:
- return 0.5, f"API error: {response.status_code}"
-
- result = response.json()["choices"][0]["message"]["content"]
-
- import re
- score_match = re.search(r'SCORE:\s*([\d.]+)', result)
- score = float(score_match.group(1)) if score_match else 0.5
-
- reason_match = re.search(r'BEGRÜNDUNG:\s*(.+)', result, re.DOTALL)
- reason = reason_match.group(1).strip() if reason_match else result
-
- return min(max(score, 0.0), 1.0), reason
-
- except Exception as e:
- return 0.5, f"Grading error: {str(e)}"
-
-
-async def grade_documents_batch(
- query: str,
- documents: List[str],
-) -> List[Tuple[float, str]]:
- """
- Grade multiple documents for relevance.
-
- Returns list of (score, reason) tuples.
- """
- results = []
- for doc in documents:
- score, reason = await grade_document_relevance(query, doc)
- results.append((score, reason))
- return results
-
-
-async def filter_relevant_documents(
- query: str,
- documents: List[Dict],
- threshold: float = RELEVANCE_THRESHOLD,
-) -> Tuple[List[Dict], List[Dict]]:
- """
- Filter documents by relevance, separating relevant from irrelevant.
-
- Args:
- query: The search query
- documents: List of document dicts with 'text' field
- threshold: Minimum relevance score to keep
-
- Returns:
- Tuple of (relevant_docs, filtered_out_docs)
- """
- relevant = []
- filtered = []
-
- for doc in documents:
- text = doc.get("text", "")
- score, reason = await grade_document_relevance(query, text)
-
- doc_with_grade = doc.copy()
- doc_with_grade["relevance_score"] = score
- doc_with_grade["relevance_reason"] = reason
-
- if score >= threshold:
- relevant.append(doc_with_grade)
- else:
- filtered.append(doc_with_grade)
-
- # Sort relevant by score
- relevant.sort(key=lambda x: x.get("relevance_score", 0), reverse=True)
-
- return relevant, filtered
-
-
-async def decide_retrieval_strategy(
- query: str,
- documents: List[Dict],
- attempt: int = 1,
-) -> Tuple[RetrievalDecision, Dict]:
- """
- Decide what to do based on current retrieval results.
-
- Args:
- query: The search query
- documents: Retrieved documents with relevance scores
- attempt: Current retrieval attempt number
-
- Returns:
- Tuple of (decision, metadata)
- """
- if not documents:
- if attempt >= MAX_RETRIEVAL_ATTEMPTS:
- return RetrievalDecision.FALLBACK, {"reason": "No documents found after max attempts"}
- return RetrievalDecision.REFORMULATE, {"reason": "No documents retrieved"}
-
- # Check average relevance
- scores = [doc.get("relevance_score", 0.5) for doc in documents]
- avg_score = sum(scores) / len(scores)
- max_score = max(scores)
-
- if max_score >= RELEVANCE_THRESHOLD and avg_score >= RELEVANCE_THRESHOLD * 0.7:
- return RetrievalDecision.SUFFICIENT, {
- "avg_relevance": avg_score,
- "max_relevance": max_score,
- "doc_count": len(documents),
- }
-
- if attempt >= MAX_RETRIEVAL_ATTEMPTS:
- if max_score >= RELEVANCE_THRESHOLD * 0.5:
- # At least some relevant context, proceed with caution
- return RetrievalDecision.SUFFICIENT, {
- "avg_relevance": avg_score,
- "warning": "Low relevance after max attempts",
- }
- return RetrievalDecision.FALLBACK, {"reason": "Max attempts reached, low relevance"}
-
- if avg_score < 0.3:
- return RetrievalDecision.REFORMULATE, {
- "reason": "Very low relevance, query reformulation needed",
- "avg_relevance": avg_score,
- }
-
- return RetrievalDecision.NEEDS_MORE, {
- "reason": "Moderate relevance, retrieving more documents",
- "avg_relevance": avg_score,
- }
-
-
-async def reformulate_query(
- original_query: str,
- context: Optional[str] = None,
- previous_results_summary: Optional[str] = None,
-) -> str:
- """
- Reformulate a query to improve retrieval.
-
- Uses LLM to generate a better query based on:
- - Original query
- - Optional context (subject, niveau, etc.)
- - Summary of why previous retrieval failed
- """
- if not OPENAI_API_KEY:
- # Simple reformulation: expand abbreviations, add synonyms
- reformulated = original_query
- expansions = {
- "EA": "erhöhtes Anforderungsniveau",
- "eA": "erhöhtes Anforderungsniveau",
- "GA": "grundlegendes Anforderungsniveau",
- "gA": "grundlegendes Anforderungsniveau",
- "AFB": "Anforderungsbereich",
- "Abi": "Abitur",
- }
- for abbr, expansion in expansions.items():
- if abbr in original_query:
- reformulated = reformulated.replace(abbr, f"{abbr} ({expansion})")
- return reformulated
-
- prompt = f"""Du bist ein Experte für deutsche Bildungsstandards und Prüfungsanforderungen.
-
-Die folgende Suchanfrage hat keine guten Ergebnisse geliefert:
-ORIGINAL: {original_query}
-
-{f"KONTEXT: {context}" if context else ""}
-{f"PROBLEM MIT VORHERIGEN ERGEBNISSEN: {previous_results_summary}" if previous_results_summary else ""}
-
-Formuliere die Anfrage so um, dass sie:
-1. Formellere/technischere Begriffe verwendet (wie in offiziellen Dokumenten)
-2. Relevante Synonyme oder verwandte Begriffe einschließt
-3. Spezifischer auf Erwartungshorizonte/Bewertungskriterien ausgerichtet ist
-
-Antworte NUR mit der umformulierten Suchanfrage, ohne Erklärung."""
-
- try:
- async with httpx.AsyncClient() as client:
- response = await client.post(
- "https://api.openai.com/v1/chat/completions",
- headers={
- "Authorization": f"Bearer {OPENAI_API_KEY}",
- "Content-Type": "application/json"
- },
- json={
- "model": SELF_RAG_MODEL,
- "messages": [{"role": "user", "content": prompt}],
- "max_tokens": 100,
- "temperature": 0.3,
- },
- timeout=30.0
- )
-
- if response.status_code != 200:
- return original_query
-
- return response.json()["choices"][0]["message"]["content"].strip()
-
- except Exception:
- return original_query
-
-
-async def grade_answer_groundedness(
- answer: str,
- contexts: List[str],
-) -> Tuple[float, List[str]]:
- """
- Grade whether an answer is grounded in the provided contexts.
-
- Returns:
- Tuple of (grounding_score, list of unsupported claims)
- """
- if not OPENAI_API_KEY:
- return 0.5, ["LLM not configured for grounding check"]
-
- context_text = "\n---\n".join(contexts[:5])
-
- prompt = f"""Analysiere, ob die folgende Antwort vollständig durch die Kontexte gestützt wird.
-
-KONTEXTE:
-{context_text}
-
-ANTWORT:
-{answer}
-
-Identifiziere:
-1. Welche Aussagen sind durch die Kontexte belegt?
-2. Welche Aussagen sind NICHT belegt (potenzielle Halluzinationen)?
-
-Antworte im Format:
-SCORE: [0.0-1.0] (1.0 = vollständig belegt)
-NICHT_BELEGT: [Liste der nicht belegten Aussagen, eine pro Zeile, oder "Keine"]"""
-
- try:
- async with httpx.AsyncClient() as client:
- response = await client.post(
- "https://api.openai.com/v1/chat/completions",
- headers={
- "Authorization": f"Bearer {OPENAI_API_KEY}",
- "Content-Type": "application/json"
- },
- json={
- "model": SELF_RAG_MODEL,
- "messages": [{"role": "user", "content": prompt}],
- "max_tokens": 300,
- "temperature": 0.0,
- },
- timeout=30.0
- )
-
- if response.status_code != 200:
- return 0.5, [f"API error: {response.status_code}"]
-
- result = response.json()["choices"][0]["message"]["content"]
-
- import re
- score_match = re.search(r'SCORE:\s*([\d.]+)', result)
- score = float(score_match.group(1)) if score_match else 0.5
-
- unsupported_match = re.search(r'NICHT_BELEGT:\s*(.+)', result, re.DOTALL)
- unsupported_text = unsupported_match.group(1).strip() if unsupported_match else ""
-
- if unsupported_text.lower() == "keine":
- unsupported = []
- else:
- unsupported = [line.strip() for line in unsupported_text.split("\n") if line.strip()]
-
- return min(max(score, 0.0), 1.0), unsupported
-
- except Exception as e:
- return 0.5, [f"Grounding check error: {str(e)}"]
-
-
-async def self_rag_retrieve(
- query: str,
- search_func,
- subject: Optional[str] = None,
- niveau: Optional[str] = None,
- initial_top_k: int = 10,
- final_top_k: int = 5,
- **search_kwargs
-) -> Dict:
- """
- Perform Self-RAG enhanced retrieval with reflection and correction.
-
- This implements a retrieval loop that:
- 1. Retrieves initial documents
- 2. Grades them for relevance
- 3. Decides if more retrieval is needed
- 4. Reformulates query if necessary
- 5. Returns filtered, high-quality context
-
- Args:
- query: The search query
- search_func: Async function to perform the actual search
- subject: Optional subject context
- niveau: Optional niveau context
- initial_top_k: Number of documents for initial retrieval
- final_top_k: Maximum documents to return
- **search_kwargs: Additional args for search_func
-
- Returns:
- Dict with results, metadata, and reflection trace
- """
- if not SELF_RAG_ENABLED:
- # Fall back to simple search
- results = await search_func(query=query, limit=final_top_k, **search_kwargs)
- return {
- "results": results,
- "self_rag_enabled": False,
- "query_used": query,
- }
-
- trace = []
- current_query = query
- attempt = 1
-
- while attempt <= MAX_RETRIEVAL_ATTEMPTS:
- # Step 1: Retrieve documents
- results = await search_func(query=current_query, limit=initial_top_k, **search_kwargs)
-
- trace.append({
- "attempt": attempt,
- "query": current_query,
- "retrieved_count": len(results) if results else 0,
- })
-
- if not results:
- attempt += 1
- if attempt <= MAX_RETRIEVAL_ATTEMPTS:
- current_query = await reformulate_query(
- query,
- context=f"Fach: {subject}" if subject else None,
- previous_results_summary="Keine Dokumente gefunden"
- )
- trace[-1]["action"] = "reformulate"
- trace[-1]["new_query"] = current_query
- continue
-
- # Step 2: Grade documents for relevance
- relevant, filtered = await filter_relevant_documents(current_query, results)
-
- trace[-1]["relevant_count"] = len(relevant)
- trace[-1]["filtered_count"] = len(filtered)
-
- # Step 3: Decide what to do
- decision, decision_meta = await decide_retrieval_strategy(
- current_query, relevant, attempt
- )
-
- trace[-1]["decision"] = decision.value
- trace[-1]["decision_meta"] = decision_meta
-
- if decision == RetrievalDecision.SUFFICIENT:
- # We have good context, return it
- return {
- "results": relevant[:final_top_k],
- "self_rag_enabled": True,
- "query_used": current_query,
- "original_query": query if current_query != query else None,
- "attempts": attempt,
- "decision": decision.value,
- "trace": trace,
- "filtered_out_count": len(filtered),
- }
-
- elif decision == RetrievalDecision.REFORMULATE:
- # Reformulate and try again
- avg_score = decision_meta.get("avg_relevance", 0)
- current_query = await reformulate_query(
- query,
- context=f"Fach: {subject}" if subject else None,
- previous_results_summary=f"Durchschnittliche Relevanz: {avg_score:.2f}"
- )
- trace[-1]["action"] = "reformulate"
- trace[-1]["new_query"] = current_query
-
- elif decision == RetrievalDecision.NEEDS_MORE:
- # Retrieve more with expanded query
- current_query = f"{current_query} Bewertungskriterien Anforderungen"
- trace[-1]["action"] = "expand_query"
- trace[-1]["new_query"] = current_query
-
- elif decision == RetrievalDecision.FALLBACK:
- # Return what we have, even if not ideal
- return {
- "results": (relevant or results)[:final_top_k],
- "self_rag_enabled": True,
- "query_used": current_query,
- "original_query": query if current_query != query else None,
- "attempts": attempt,
- "decision": decision.value,
- "warning": "Fallback mode - low relevance context",
- "trace": trace,
- }
-
- attempt += 1
-
- # Max attempts reached
- return {
- "results": results[:final_top_k] if results else [],
- "self_rag_enabled": True,
- "query_used": current_query,
- "original_query": query if current_query != query else None,
- "attempts": attempt - 1,
- "decision": "max_attempts",
- "warning": "Max retrieval attempts reached",
- "trace": trace,
- }
-
-
-def get_self_rag_info() -> dict:
- """Get information about Self-RAG configuration."""
- return {
- "enabled": SELF_RAG_ENABLED,
- "llm_configured": bool(OPENAI_API_KEY),
- "model": SELF_RAG_MODEL,
- "relevance_threshold": RELEVANCE_THRESHOLD,
- "grounding_threshold": GROUNDING_THRESHOLD,
- "max_retrieval_attempts": MAX_RETRIEVAL_ATTEMPTS,
- "features": [
- "document_grading",
- "relevance_filtering",
- "query_reformulation",
- "answer_grounding_check",
- "retrieval_decision",
- ],
- "sends_data_externally": True, # ALWAYS true when enabled - documents sent to OpenAI
- "privacy_warning": "When enabled, queries and documents are sent to OpenAI API for grading",
- "default_enabled": False, # Disabled by default for privacy
- }
+# Grading: relevance, filtering, decisions, groundedness
+from self_rag_grading import ( # noqa: F401
+ SELF_RAG_ENABLED,
+ OPENAI_API_KEY,
+ SELF_RAG_MODEL,
+ RELEVANCE_THRESHOLD,
+ GROUNDING_THRESHOLD,
+ MAX_RETRIEVAL_ATTEMPTS,
+ RetrievalDecision,
+ SelfRAGError,
+ grade_document_relevance,
+ grade_documents_batch,
+ filter_relevant_documents,
+ decide_retrieval_strategy,
+ grade_answer_groundedness,
+)
+
+# Retrieval: reformulation, loop, info
+from self_rag_retrieval import ( # noqa: F401
+ reformulate_query,
+ self_rag_retrieve,
+ get_self_rag_info,
+)
diff --git a/klausur-service/backend/self_rag_grading.py b/klausur-service/backend/self_rag_grading.py
new file mode 100644
index 0000000..be6b096
--- /dev/null
+++ b/klausur-service/backend/self_rag_grading.py
@@ -0,0 +1,285 @@
+"""
+Self-RAG Grading — document relevance grading, filtering, retrieval decisions.
+
+Extracted from self_rag.py for modularity.
+
+Based on research:
+- Self-RAG (Asai et al., 2023)
+- Corrective RAG (Yan et al., 2024)
+"""
+
+import os
+from typing import List, Dict, Optional, Tuple
+from enum import Enum
+import httpx
+
+# Configuration
+# Master switch: only the value "true" (compared case-insensitively) enables Self-RAG.
+SELF_RAG_ENABLED = os.getenv("SELF_RAG_ENABLED", "false").lower() == "true"
+# An empty key means "no LLM configured"; the grading functions below test it
+# before issuing any API request.
+OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "")
+SELF_RAG_MODEL = os.getenv("SELF_RAG_MODEL", "gpt-4o-mini")
+
+# Thresholds for self-reflection
+# NOTE: float()/int() raise ValueError at import time if an env value is not numeric.
+RELEVANCE_THRESHOLD = float(os.getenv("SELF_RAG_RELEVANCE_THRESHOLD", "0.6"))
+GROUNDING_THRESHOLD = float(os.getenv("SELF_RAG_GROUNDING_THRESHOLD", "0.7"))
+MAX_RETRIEVAL_ATTEMPTS = int(os.getenv("SELF_RAG_MAX_ATTEMPTS", "2"))
+
+
+class RetrievalDecision(Enum):
+    """
+    Decision after grading retrieval.
+
+    Produced by decide_retrieval_strategy(); the string values are what ends
+    up in serialized output via ``decision.value``.
+    """
+    SUFFICIENT = "sufficient"    # Context is good, proceed to generation
+    NEEDS_MORE = "needs_more"    # Need to retrieve more documents
+    REFORMULATE = "reformulate"  # Query needs reformulation
+    FALLBACK = "fallback"        # Use fallback (no good context found)
+
+
+class SelfRAGError(Exception):
+    """
+    Error during Self-RAG processing.
+
+    NOTE(review): nothing in this module raises it; the grading functions
+    report failures through their return values instead.
+    """
+    pass
+
+
+async def grade_document_relevance(
+    query: str,
+    document: str,
+) -> Tuple[float, str]:
+    """
+    Grade whether a document is relevant to the query.
+
+    Without an API key the score is a keyword heuristic: the share of query
+    words that also occur in the document, doubled and capped at 1.0. With a
+    key, the configured chat model is asked for a score and a justification.
+
+    Args:
+        query: The search query.
+        document: Document text; only its first 2000 characters are sent to the LLM.
+
+    Returns:
+        Tuple of (score, reason) with the score clamped to [0.0, 1.0].
+        A neutral 0.5 is returned when the API answers with a non-200 status,
+        when the reply contains no parsable score, or when the request or the
+        parsing raises; the reason then describes the problem.
+    """
+    if not OPENAI_API_KEY:
+        # Fallback: simple keyword overlap
+        query_words = set(query.lower().split())
+        doc_words = set(document.lower().split())
+        overlap = len(query_words & doc_words) / max(len(query_words), 1)
+        return min(overlap * 2, 1.0), "Keyword-based relevance (no LLM)"
+
+    prompt = f"""Bewerte, ob das folgende Dokument relevant fuer die Suchanfrage ist.
+
+SUCHANFRAGE: {query}
+
+DOKUMENT:
+{document[:2000]}
+
+Ist dieses Dokument relevant, um die Anfrage zu beantworten?
+Beruecksichtige:
+- Thematische Uebereinstimmung
+- Enthaelt das Dokument spezifische Informationen zur Anfrage?
+- Wuerde dieses Dokument bei der Beantwortung helfen?
+
+Antworte im Format:
+SCORE: [0.0-1.0]
+BEGRUENDUNG: [Kurze Erklaerung]"""
+
+    try:
+        async with httpx.AsyncClient() as client:
+            response = await client.post(
+                "https://api.openai.com/v1/chat/completions",
+                headers={
+                    "Authorization": f"Bearer {OPENAI_API_KEY}",
+                    "Content-Type": "application/json"
+                },
+                json={
+                    "model": SELF_RAG_MODEL,
+                    "messages": [{"role": "user", "content": prompt}],
+                    "max_tokens": 150,
+                    "temperature": 0.0,
+                },
+                timeout=30.0
+            )
+
+        if response.status_code != 200:
+            return 0.5, f"API error: {response.status_code}"
+
+        result = response.json()["choices"][0]["message"]["content"]
+
+        import re
+        # The prompt is German, so the model may answer with a decimal comma
+        # ("0,8"). Capture exactly one number (dot or comma as separator) so a
+        # comma is not read as "0" and a sentence-ending period ("0.8.") does
+        # not make float() fail.
+        score_match = re.search(r'SCORE:\s*(\d+(?:[.,]\d+)?|[.,]\d+)', result)
+        score = float(score_match.group(1).replace(",", ".")) if score_match else 0.5
+
+        reason_match = re.search(r'BEGRUENDUNG:\s*(.+)', result, re.DOTALL)
+        # Without the expected marker, hand back the raw reply as the reason.
+        reason = reason_match.group(1).strip() if reason_match else result
+
+        return min(max(score, 0.0), 1.0), reason
+
+    except Exception as e:
+        # Best effort by design: a failed grading yields a neutral score
+        # instead of aborting the retrieval loop.
+        return 0.5, f"Grading error: {str(e)}"
+
+
+async def grade_documents_batch(
+    query: str,
+    documents: List[str],
+) -> List[Tuple[float, str]]:
+    """
+    Grade every document against the query, one after another.
+
+    Returns:
+        One (score, reason) tuple per document, in input order.
+    """
+    # Deliberately sequential: each grading is awaited before the next starts.
+    return [
+        tuple(await grade_document_relevance(query, text))
+        for text in documents
+    ]
+
+
+async def filter_relevant_documents(
+    query: str,
+    documents: List[Dict],
+    threshold: float = RELEVANCE_THRESHOLD,
+) -> Tuple[List[Dict], List[Dict]]:
+    """
+    Split documents into those that pass the relevance threshold and the rest.
+
+    Each document is graded via grade_document_relevance() on its "text"
+    field (empty string when the key is missing). The input dicts are not
+    modified; graded copies carry the extra keys "relevance_score" and
+    "relevance_reason".
+
+    Args:
+        query: The search query
+        documents: List of document dicts with 'text' field
+        threshold: Minimum relevance score to keep
+
+    Returns:
+        Tuple of (relevant_docs, filtered_out_docs); relevant_docs is ordered
+        by descending score, filtered_out_docs keeps the input order.
+    """
+    graded = []
+    for original in documents:
+        score, reason = await grade_document_relevance(query, original.get("text", ""))
+        annotated = original.copy()
+        annotated.update(relevance_score=score, relevance_reason=reason)
+        graded.append(annotated)
+
+    kept = [doc for doc in graded if doc["relevance_score"] >= threshold]
+    dropped = [doc for doc in graded if not doc["relevance_score"] >= threshold]
+
+    # Best match first; sorted() is stable, so ties keep their input order.
+    kept = sorted(kept, key=lambda doc: doc.get("relevance_score", 0), reverse=True)
+
+    return kept, dropped
+
+
+async def decide_retrieval_strategy(
+    query: str,
+    documents: List[Dict],
+    attempt: int = 1,
+) -> Tuple[RetrievalDecision, Dict]:
+    """
+    Choose the next step of the retrieval loop from the graded documents.
+
+    Args:
+        query: The search query (not used by the current rules)
+        documents: Retrieved documents; a missing "relevance_score" counts as 0.5
+        attempt: Current retrieval attempt number
+
+    Returns:
+        Tuple of (decision, metadata)
+    """
+    attempts_exhausted = attempt >= MAX_RETRIEVAL_ATTEMPTS
+
+    # Nothing to judge: retry with a new query unless the budget is spent.
+    if not documents:
+        if attempts_exhausted:
+            return RetrievalDecision.FALLBACK, {"reason": "No documents found after max attempts"}
+        return RetrievalDecision.REFORMULATE, {"reason": "No documents retrieved"}
+
+    relevance = [entry.get("relevance_score", 0.5) for entry in documents]
+    mean_relevance = sum(relevance) / len(relevance)
+    top_relevance = max(relevance)
+
+    # Good context: at least one strong hit and a decent average.
+    strong_hit = top_relevance >= RELEVANCE_THRESHOLD
+    decent_average = mean_relevance >= RELEVANCE_THRESHOLD * 0.7
+    if strong_hit and decent_average:
+        return RetrievalDecision.SUFFICIENT, {
+            "avg_relevance": mean_relevance,
+            "max_relevance": top_relevance,
+            "doc_count": len(documents),
+        }
+
+    if attempts_exhausted:
+        if top_relevance >= RELEVANCE_THRESHOLD * 0.5:
+            # Some usable context is left; accept it, flagged with a warning.
+            return RetrievalDecision.SUFFICIENT, {
+                "avg_relevance": mean_relevance,
+                "warning": "Low relevance after max attempts",
+            }
+        return RetrievalDecision.FALLBACK, {"reason": "Max attempts reached, low relevance"}
+
+    # Attempts remain: very poor results call for a new query, mediocre ones
+    # for more documents.
+    if mean_relevance < 0.3:
+        verdict = RetrievalDecision.REFORMULATE
+        explanation = "Very low relevance, query reformulation needed"
+    else:
+        verdict = RetrievalDecision.NEEDS_MORE
+        explanation = "Moderate relevance, retrieving more documents"
+
+    return verdict, {
+        "reason": explanation,
+        "avg_relevance": mean_relevance,
+    }
+
+
+async def grade_answer_groundedness(
+    answer: str,
+    contexts: List[str],
+) -> Tuple[float, List[str]]:
+    """
+    Grade whether an answer is grounded in the provided contexts.
+
+    Args:
+        answer: The generated answer to check.
+        contexts: Context passages; only the first five are sent to the LLM.
+
+    Returns:
+        Tuple of (grounding_score, list of unsupported claims). The score is
+        clamped to [0.0, 1.0]. When no API key is configured, the API answers
+        with a non-200 status, or the request or parsing raises, the score is
+        a neutral 0.5 and the list holds a single message describing why.
+    """
+    if not OPENAI_API_KEY:
+        return 0.5, ["LLM not configured for grounding check"]
+
+    context_text = "\n---\n".join(contexts[:5])
+
+    prompt = f"""Analysiere, ob die folgende Antwort vollstaendig durch die Kontexte gestuetzt wird.
+
+KONTEXTE:
+{context_text}
+
+ANTWORT:
+{answer}
+
+Identifiziere:
+1. Welche Aussagen sind durch die Kontexte belegt?
+2. Welche Aussagen sind NICHT belegt (potenzielle Halluzinationen)?
+
+Antworte im Format:
+SCORE: [0.0-1.0] (1.0 = vollstaendig belegt)
+NICHT_BELEGT: [Liste der nicht belegten Aussagen, eine pro Zeile, oder "Keine"]"""
+
+    try:
+        async with httpx.AsyncClient() as client:
+            response = await client.post(
+                "https://api.openai.com/v1/chat/completions",
+                headers={
+                    "Authorization": f"Bearer {OPENAI_API_KEY}",
+                    "Content-Type": "application/json"
+                },
+                json={
+                    "model": SELF_RAG_MODEL,
+                    "messages": [{"role": "user", "content": prompt}],
+                    "max_tokens": 300,
+                    "temperature": 0.0,
+                },
+                timeout=30.0
+            )
+
+        if response.status_code != 200:
+            return 0.5, [f"API error: {response.status_code}"]
+
+        result = response.json()["choices"][0]["message"]["content"]
+
+        import re
+        # The prompt is German, so accept a decimal comma ("0,8") and capture
+        # exactly one number so a trailing period cannot break float().
+        score_match = re.search(r'SCORE:\s*(\d+(?:[.,]\d+)?|[.,]\d+)', result)
+        score = float(score_match.group(1).replace(",", ".")) if score_match else 0.5
+
+        unsupported_match = re.search(r'NICHT_BELEGT:\s*(.+)', result, re.DOTALL)
+        unsupported_text = unsupported_match.group(1).strip() if unsupported_match else ""
+
+        # Models often decorate the "nothing unsupported" marker ('"Keine"',
+        # 'Keine.', '- Keine', '[Keine]'). Compare without that decoration so
+        # the marker itself is never reported as an unsupported claim.
+        marker = unsupported_text.strip(" \t\r\n\"'`.-*[]").lower()
+        if marker == "keine":
+            unsupported = []
+        else:
+            unsupported = [line.strip() for line in unsupported_text.split("\n") if line.strip()]
+
+        return min(max(score, 0.0), 1.0), unsupported
+
+    except Exception as e:
+        # Best effort by design: report the failure in-band with a neutral score.
+        return 0.5, [f"Grounding check error: {str(e)}"]
diff --git a/klausur-service/backend/self_rag_retrieval.py b/klausur-service/backend/self_rag_retrieval.py
new file mode 100644
index 0000000..ac989d7
--- /dev/null
+++ b/klausur-service/backend/self_rag_retrieval.py
@@ -0,0 +1,255 @@
+"""
+Self-RAG Retrieval — query reformulation, retrieval loop, info.
+
+Extracted from self_rag.py for modularity.
+
+IMPORTANT: Self-RAG is DISABLED by default for privacy reasons!
+When enabled, search queries and retrieved documents are sent to OpenAI API
+for relevance grading and query reformulation.
+"""
+
+import os
+from typing import List, Dict, Optional
+import httpx
+
+from self_rag_grading import (
+ SELF_RAG_ENABLED,
+ OPENAI_API_KEY,
+ SELF_RAG_MODEL,
+ RELEVANCE_THRESHOLD,
+ GROUNDING_THRESHOLD,
+ MAX_RETRIEVAL_ATTEMPTS,
+ RetrievalDecision,
+ filter_relevant_documents,
+ decide_retrieval_strategy,
+)
+
+
+async def reformulate_query(
+    original_query: str,
+    context: Optional[str] = None,
+    previous_results_summary: Optional[str] = None,
+) -> str:
+    """
+    Reformulate a query to improve retrieval.
+
+    Without an API key, known abbreviations that appear as whole words are
+    annotated with their long form, e.g. "EA" becomes
+    "EA (erhoehtes Anforderungsniveau)". With a key, the configured chat model
+    rewrites the query.
+
+    Args:
+        original_query: The query that produced poor results.
+        context: Optional context (subject, niveau, etc.) added to the prompt.
+        previous_results_summary: Optional summary of why the previous
+            retrieval failed, added to the prompt.
+
+    Returns:
+        The reformulated query. The original query is returned unchanged when
+        the API answers with a non-200 status, returns an empty text, or the
+        request raises.
+    """
+    if not OPENAI_API_KEY:
+        # Simple reformulation: expand abbreviations
+        import re
+        expansions = {
+            "EA": "erhoehtes Anforderungsniveau",
+            "eA": "erhoehtes Anforderungsniveau",
+            "GA": "grundlegendes Anforderungsniveau",
+            "gA": "grundlegendes Anforderungsniveau",
+            "AFB": "Anforderungsbereich",
+            "Abi": "Abitur",
+        }
+        # Match whole words only, in a single pass. Plain substring replacement
+        # would turn "Abitur" into "Abi (Abitur)tur" and could rewrite text
+        # that an earlier expansion had just inserted.
+        abbreviation = re.compile(
+            r"\b(" + "|".join(re.escape(abbr) for abbr in expansions) + r")\b"
+        )
+        return abbreviation.sub(
+            lambda hit: f"{hit.group(1)} ({expansions[hit.group(1)]})",
+            original_query,
+        )
+
+    prompt = f"""Du bist ein Experte fuer deutsche Bildungsstandards und Pruefungsanforderungen.
+
+Die folgende Suchanfrage hat keine guten Ergebnisse geliefert:
+ORIGINAL: {original_query}
+
+{f"KONTEXT: {context}" if context else ""}
+{f"PROBLEM MIT VORHERIGEN ERGEBNISSEN: {previous_results_summary}" if previous_results_summary else ""}
+
+Formuliere die Anfrage so um, dass sie:
+1. Formellere/technischere Begriffe verwendet (wie in offiziellen Dokumenten)
+2. Relevante Synonyme oder verwandte Begriffe einschliesst
+3. Spezifischer auf Erwartungshorizonte/Bewertungskriterien ausgerichtet ist
+
+Antworte NUR mit der umformulierten Suchanfrage, ohne Erklaerung."""
+
+    try:
+        async with httpx.AsyncClient() as client:
+            response = await client.post(
+                "https://api.openai.com/v1/chat/completions",
+                headers={
+                    "Authorization": f"Bearer {OPENAI_API_KEY}",
+                    "Content-Type": "application/json"
+                },
+                json={
+                    "model": SELF_RAG_MODEL,
+                    "messages": [{"role": "user", "content": prompt}],
+                    "max_tokens": 100,
+                    "temperature": 0.3,
+                },
+                timeout=30.0
+            )
+
+        if response.status_code != 200:
+            return original_query
+
+        rewritten = response.json()["choices"][0]["message"]["content"].strip()
+        # An empty reply would make the next search run with an empty query.
+        return rewritten or original_query
+
+    except Exception:
+        # Best effort by design: any failure keeps the original query.
+        return original_query
+
+
+async def self_rag_retrieve(
+    query: str,
+    search_func,
+    subject: Optional[str] = None,
+    niveau: Optional[str] = None,
+    initial_top_k: int = 10,
+    final_top_k: int = 5,
+    **search_kwargs
+) -> Dict:
+    """
+    Perform Self-RAG enhanced retrieval with reflection and correction.
+
+    This implements a retrieval loop that:
+    1. Retrieves initial documents
+    2. Grades them for relevance
+    3. Decides if more retrieval is needed
+    4. Reformulates query if necessary
+    5. Returns filtered, high-quality context
+
+    When SELF_RAG_ENABLED is false, a single plain search with
+    ``limit=final_top_k`` is performed and nothing is graded.
+
+    Args:
+        query: The search query
+        search_func: Async function to perform the actual search; it is called
+            with ``query=``, ``limit=`` and ``**search_kwargs``
+        subject: Optional subject, passed to query reformulation as context
+        niveau: Optional niveau, passed to query reformulation as context
+        initial_top_k: Number of documents for initial retrieval
+        final_top_k: Maximum documents to return
+        **search_kwargs: Additional args for search_func
+
+    Returns:
+        Dict with "results", "self_rag_enabled" and "query_used". When
+        Self-RAG is enabled it also carries "original_query" (None if the
+        query was never changed), "attempts", "decision" and "trace", plus
+        "filtered_out_count" or "warning" depending on how the loop ended.
+    """
+    if not SELF_RAG_ENABLED:
+        # Fall back to simple search
+        results = await search_func(query=query, limit=final_top_k, **search_kwargs)
+        return {
+            "results": results,
+            "self_rag_enabled": False,
+            "query_used": query,
+        }
+
+    # Context for the reformulation prompt. With only a subject this is the
+    # same "Fach: ..." string as before; niveau is appended when given.
+    context_parts = []
+    if subject:
+        context_parts.append(f"Fach: {subject}")
+    if niveau:
+        context_parts.append(f"Niveau: {niveau}")
+    reformulation_context = ", ".join(context_parts) or None
+
+    trace = []
+    current_query = query
+    attempt = 1
+    # Defined up front so the final return also works when
+    # MAX_RETRIEVAL_ATTEMPTS is configured below 1 and the loop never runs.
+    results = []
+
+    while attempt <= MAX_RETRIEVAL_ATTEMPTS:
+        # Step 1: Retrieve documents
+        results = await search_func(query=current_query, limit=initial_top_k, **search_kwargs)
+
+        trace.append({
+            "attempt": attempt,
+            "query": current_query,
+            "retrieved_count": len(results) if results else 0,
+        })
+
+        if not results:
+            attempt += 1
+            if attempt <= MAX_RETRIEVAL_ATTEMPTS:
+                # Reformulation always starts from the caller's query, not
+                # from an earlier rewrite.
+                current_query = await reformulate_query(
+                    query,
+                    context=reformulation_context,
+                    previous_results_summary="Keine Dokumente gefunden"
+                )
+                trace[-1]["action"] = "reformulate"
+                trace[-1]["new_query"] = current_query
+            continue
+
+        # Step 2: Grade documents for relevance
+        relevant, filtered = await filter_relevant_documents(current_query, results)
+
+        trace[-1]["relevant_count"] = len(relevant)
+        trace[-1]["filtered_count"] = len(filtered)
+
+        # Step 3: Decide what to do
+        decision, decision_meta = await decide_retrieval_strategy(
+            current_query, relevant, attempt
+        )
+
+        trace[-1]["decision"] = decision.value
+        trace[-1]["decision_meta"] = decision_meta
+
+        if decision == RetrievalDecision.SUFFICIENT:
+            # We have good context, return it
+            return {
+                "results": relevant[:final_top_k],
+                "self_rag_enabled": True,
+                "query_used": current_query,
+                "original_query": query if current_query != query else None,
+                "attempts": attempt,
+                "decision": decision.value,
+                "trace": trace,
+                "filtered_out_count": len(filtered),
+            }
+
+        elif decision == RetrievalDecision.REFORMULATE:
+            # Reformulate and try again
+            avg_score = decision_meta.get("avg_relevance", 0)
+            current_query = await reformulate_query(
+                query,
+                context=reformulation_context,
+                previous_results_summary=f"Durchschnittliche Relevanz: {avg_score:.2f}"
+            )
+            trace[-1]["action"] = "reformulate"
+            trace[-1]["new_query"] = current_query
+
+        elif decision == RetrievalDecision.NEEDS_MORE:
+            # Retrieve more with expanded query
+            current_query = f"{current_query} Bewertungskriterien Anforderungen"
+            trace[-1]["action"] = "expand_query"
+            trace[-1]["new_query"] = current_query
+
+        elif decision == RetrievalDecision.FALLBACK:
+            # Return what we have, even if not ideal; ungraded raw results
+            # are used when no document passed the relevance filter.
+            return {
+                "results": (relevant or results)[:final_top_k],
+                "self_rag_enabled": True,
+                "query_used": current_query,
+                "original_query": query if current_query != query else None,
+                "attempts": attempt,
+                "decision": decision.value,
+                "warning": "Fallback mode - low relevance context",
+                "trace": trace,
+            }
+
+        attempt += 1
+
+    # Max attempts reached
+    return {
+        "results": results[:final_top_k] if results else [],
+        "self_rag_enabled": True,
+        "query_used": current_query,
+        "original_query": query if current_query != query else None,
+        "attempts": attempt - 1,
+        "decision": "max_attempts",
+        "warning": "Max retrieval attempts reached",
+        "trace": trace,
+    }
+
+
+def get_self_rag_info() -> dict:
+    """Report the active Self-RAG settings together with the privacy notice."""
+    capabilities = [
+        "document_grading",
+        "relevance_filtering",
+        "query_reformulation",
+        "answer_grounding_check",
+        "retrieval_decision",
+    ]
+
+    info = dict(
+        enabled=SELF_RAG_ENABLED,
+        llm_configured=bool(OPENAI_API_KEY),
+        model=SELF_RAG_MODEL,
+        relevance_threshold=RELEVANCE_THRESHOLD,
+        grounding_threshold=GROUNDING_THRESHOLD,
+        max_retrieval_attempts=MAX_RETRIEVAL_ATTEMPTS,
+        features=capabilities,
+    )
+
+    # Privacy-related fields are constants: this flag is ALWAYS true, and the
+    # feature is disabled by default for privacy.
+    info["sends_data_externally"] = True
+    info["privacy_warning"] = "When enabled, queries and documents are sent to OpenAI API for grading"
+    info["default_enabled"] = False
+    return info
diff --git a/klausur-service/backend/services/grid_detection_models.py b/klausur-service/backend/services/grid_detection_models.py
new file mode 100644
index 0000000..dcd2bf2
--- /dev/null
+++ b/klausur-service/backend/services/grid_detection_models.py
@@ -0,0 +1,164 @@
+"""
+Grid Detection Models v4
+
+Data classes for OCR grid detection results.
+Coordinates use percentage (0-100) and mm (A4 format).
+"""
+
+from enum import Enum
+from dataclasses import dataclass, field
+from typing import List, Dict, Any
+
+# A4 page dimensions in millimetres; used to convert percentage coordinates to mm
+A4_WIDTH_MM = 210.0
+A4_HEIGHT_MM = 297.0
+
+# Column margin (1mm)
+COLUMN_MARGIN_MM = 1.0
+# The same margin expressed as a percentage of the page width
+COLUMN_MARGIN_PCT = (COLUMN_MARGIN_MM / A4_WIDTH_MM) * 100
+
+
+class CellStatus(str, Enum):
+    """
+    Status of a grid cell.
+
+    Members are also ``str`` instances; GridCell.to_dict() emits ``.value``.
+    NOTE(review): the rules that assign each status live outside this module.
+    """
+    EMPTY = "empty"
+    RECOGNIZED = "recognized"
+    PROBLEMATIC = "problematic"
+    MANUAL = "manual"
+
+
+class ColumnType(str, Enum):
+    """
+    Type tag of a grid column; UNKNOWN is the default on GridCell.
+
+    Members are also ``str`` instances; GridCell.to_dict() emits ``.value``.
+    NOTE(review): presumably the kind of content a column holds - confirm
+    against the code that classifies columns.
+    """
+    ENGLISH = "english"
+    GERMAN = "german"
+    EXAMPLE = "example"
+    UNKNOWN = "unknown"
+
+
+@dataclass
+class OCRRegion:
+    """OCR-detected word or phrase whose bounding box is given in percent (0-100) of the page."""
+    text: str
+    confidence: float
+    x: float       # X position as percentage of page width
+    y: float       # Y position as percentage of page height
+    width: float   # Width as percentage of page width
+    height: float  # Height as percentage of page height
+
+    @staticmethod
+    def _to_mm(percent: float, page_extent_mm: float, digits: int) -> float:
+        """Scale a page percentage to millimetres on an A4 page and round it."""
+        return round(percent / 100 * page_extent_mm, digits)
+
+    # --- millimetre views of the box (A4 page assumed) ---
+
+    @property
+    def x_mm(self) -> float:
+        return self._to_mm(self.x, A4_WIDTH_MM, 1)
+
+    @property
+    def y_mm(self) -> float:
+        return self._to_mm(self.y, A4_HEIGHT_MM, 1)
+
+    @property
+    def width_mm(self) -> float:
+        return self._to_mm(self.width, A4_WIDTH_MM, 1)
+
+    @property
+    def height_mm(self) -> float:
+        # Heights keep two decimals, unlike the other mm values.
+        return self._to_mm(self.height, A4_HEIGHT_MM, 2)
+
+    # --- derived percentage coordinates ---
+
+    @property
+    def center_x(self) -> float:
+        half_width = self.width / 2
+        return self.x + half_width
+
+    @property
+    def center_y(self) -> float:
+        half_height = self.height / 2
+        return self.y + half_height
+
+    @property
+    def right(self) -> float:
+        left, span = self.x, self.width
+        return left + span
+
+    @property
+    def bottom(self) -> float:
+        top, span = self.y, self.height
+        return top + span
+
+
+@dataclass
+class GridCell:
+    """One cell of the detected grid; position and size are in percent (0-100) of the page."""
+    row: int
+    col: int
+    x: float
+    y: float
+    width: float
+    height: float
+    text: str = ""
+    confidence: float = 0.0
+    status: CellStatus = CellStatus.EMPTY
+    column_type: ColumnType = ColumnType.UNKNOWN
+    logical_row: int = 0
+    logical_col: int = 0
+    is_continuation: bool = False
+
+    @staticmethod
+    def _to_mm(percent: float, page_extent_mm: float, digits: int) -> float:
+        """Scale a page percentage to millimetres on an A4 page and round it."""
+        return round(percent / 100 * page_extent_mm, digits)
+
+    @property
+    def x_mm(self) -> float:
+        return self._to_mm(self.x, A4_WIDTH_MM, 1)
+
+    @property
+    def y_mm(self) -> float:
+        return self._to_mm(self.y, A4_HEIGHT_MM, 1)
+
+    @property
+    def width_mm(self) -> float:
+        return self._to_mm(self.width, A4_WIDTH_MM, 1)
+
+    @property
+    def height_mm(self) -> float:
+        # Heights keep two decimals, unlike the other mm values.
+        return self._to_mm(self.height, A4_HEIGHT_MM, 2)
+
+    def to_dict(self) -> dict:
+        """Return a plain dict of the cell: percentages rounded to 2 decimals, mm values, enum values as strings."""
+        payload = {"row": self.row, "col": self.col}
+
+        # Percentage geometry, rounded for output.
+        for name in ("x", "y", "width", "height"):
+            payload[name] = round(getattr(self, name), 2)
+
+        # Millimetre geometry, already rounded by the properties.
+        for name in ("x_mm", "y_mm", "width_mm", "height_mm"):
+            payload[name] = getattr(self, name)
+
+        payload["text"] = self.text
+        payload["confidence"] = self.confidence
+        payload["status"] = self.status.value
+        payload["column_type"] = self.column_type.value
+        payload["logical_row"] = self.logical_row
+        payload["logical_col"] = self.logical_col
+        payload["is_continuation"] = self.is_continuation
+        return payload
+
+
+@dataclass
+class GridResult:
+    """Outcome of a grid detection run: the cell matrix plus boundaries, deskew angle and stats."""
+    rows: int = 0
+    columns: int = 0
+    cells: List[List[GridCell]] = field(default_factory=list)
+    column_types: List[str] = field(default_factory=list)
+    column_boundaries: List[float] = field(default_factory=list)
+    row_boundaries: List[float] = field(default_factory=list)
+    deskew_angle: float = 0.0
+    stats: Dict[str, Any] = field(default_factory=dict)
+
+    def to_dict(self) -> dict:
+        """Return a plain dict of the result; boundaries and angle are rounded to 2 decimals."""
+        serialized_cells = [
+            [cell.to_dict() for cell in grid_row]
+            for grid_row in self.cells
+        ]
+        # The page is always reported as A4.
+        page = {
+            "width_mm": A4_WIDTH_MM,
+            "height_mm": A4_HEIGHT_MM,
+            "format": "A4",
+        }
+
+        return dict(
+            rows=self.rows,
+            columns=self.columns,
+            cells=serialized_cells,
+            column_types=self.column_types,
+            column_boundaries=[round(edge, 2) for edge in self.column_boundaries],
+            row_boundaries=[round(edge, 2) for edge in self.row_boundaries],
+            deskew_angle=round(self.deskew_angle, 2),
+            stats=self.stats,
+            page_dimensions=page,
+        )
diff --git a/klausur-service/backend/services/grid_detection_service.py b/klausur-service/backend/services/grid_detection_service.py
index 4f6c4c2..c544275 100644
--- a/klausur-service/backend/services/grid_detection_service.py
+++ b/klausur-service/backend/services/grid_detection_service.py
@@ -10,166 +10,21 @@ Lizenz: Apache 2.0 (kommerziell nutzbar)
import math
import logging
-from enum import Enum
-from dataclasses import dataclass, field
-from typing import List, Optional, Dict, Any, Tuple
+from typing import List
+
+from .grid_detection_models import (
+ A4_WIDTH_MM,
+ A4_HEIGHT_MM,
+ COLUMN_MARGIN_MM,
+ CellStatus,
+ ColumnType,
+ OCRRegion,
+ GridCell,
+ GridResult,
+)
logger = logging.getLogger(__name__)
-# A4 dimensions
-A4_WIDTH_MM = 210.0
-A4_HEIGHT_MM = 297.0
-
-# Column margin (1mm)
-COLUMN_MARGIN_MM = 1.0
-COLUMN_MARGIN_PCT = (COLUMN_MARGIN_MM / A4_WIDTH_MM) * 100
-
-
-class CellStatus(str, Enum):
- EMPTY = "empty"
- RECOGNIZED = "recognized"
- PROBLEMATIC = "problematic"
- MANUAL = "manual"
-
-
-class ColumnType(str, Enum):
- ENGLISH = "english"
- GERMAN = "german"
- EXAMPLE = "example"
- UNKNOWN = "unknown"
-
-
-@dataclass
-class OCRRegion:
- """A word/phrase detected by OCR with bounding box coordinates in percentage (0-100)."""
- text: str
- confidence: float
- x: float # X position as percentage of page width
- y: float # Y position as percentage of page height
- width: float # Width as percentage of page width
- height: float # Height as percentage of page height
-
- @property
- def x_mm(self) -> float:
- return round(self.x / 100 * A4_WIDTH_MM, 1)
-
- @property
- def y_mm(self) -> float:
- return round(self.y / 100 * A4_HEIGHT_MM, 1)
-
- @property
- def width_mm(self) -> float:
- return round(self.width / 100 * A4_WIDTH_MM, 1)
-
- @property
- def height_mm(self) -> float:
- return round(self.height / 100 * A4_HEIGHT_MM, 2)
-
- @property
- def center_x(self) -> float:
- return self.x + self.width / 2
-
- @property
- def center_y(self) -> float:
- return self.y + self.height / 2
-
- @property
- def right(self) -> float:
- return self.x + self.width
-
- @property
- def bottom(self) -> float:
- return self.y + self.height
-
-
-@dataclass
-class GridCell:
- """A cell in the detected grid with coordinates in percentage (0-100)."""
- row: int
- col: int
- x: float
- y: float
- width: float
- height: float
- text: str = ""
- confidence: float = 0.0
- status: CellStatus = CellStatus.EMPTY
- column_type: ColumnType = ColumnType.UNKNOWN
- logical_row: int = 0
- logical_col: int = 0
- is_continuation: bool = False
-
- @property
- def x_mm(self) -> float:
- return round(self.x / 100 * A4_WIDTH_MM, 1)
-
- @property
- def y_mm(self) -> float:
- return round(self.y / 100 * A4_HEIGHT_MM, 1)
-
- @property
- def width_mm(self) -> float:
- return round(self.width / 100 * A4_WIDTH_MM, 1)
-
- @property
- def height_mm(self) -> float:
- return round(self.height / 100 * A4_HEIGHT_MM, 2)
-
- def to_dict(self) -> dict:
- return {
- "row": self.row,
- "col": self.col,
- "x": round(self.x, 2),
- "y": round(self.y, 2),
- "width": round(self.width, 2),
- "height": round(self.height, 2),
- "x_mm": self.x_mm,
- "y_mm": self.y_mm,
- "width_mm": self.width_mm,
- "height_mm": self.height_mm,
- "text": self.text,
- "confidence": self.confidence,
- "status": self.status.value,
- "column_type": self.column_type.value,
- "logical_row": self.logical_row,
- "logical_col": self.logical_col,
- "is_continuation": self.is_continuation,
- }
-
-
-@dataclass
-class GridResult:
- """Result of grid detection."""
- rows: int = 0
- columns: int = 0
- cells: List[List[GridCell]] = field(default_factory=list)
- column_types: List[str] = field(default_factory=list)
- column_boundaries: List[float] = field(default_factory=list)
- row_boundaries: List[float] = field(default_factory=list)
- deskew_angle: float = 0.0
- stats: Dict[str, Any] = field(default_factory=dict)
-
- def to_dict(self) -> dict:
- cells_dicts = []
- for row_cells in self.cells:
- cells_dicts.append([c.to_dict() for c in row_cells])
-
- return {
- "rows": self.rows,
- "columns": self.columns,
- "cells": cells_dicts,
- "column_types": self.column_types,
- "column_boundaries": [round(b, 2) for b in self.column_boundaries],
- "row_boundaries": [round(b, 2) for b in self.row_boundaries],
- "deskew_angle": round(self.deskew_angle, 2),
- "stats": self.stats,
- "page_dimensions": {
- "width_mm": A4_WIDTH_MM,
- "height_mm": A4_HEIGHT_MM,
- "format": "A4",
- },
- }
-
class GridDetectionService:
"""Detect grid/table structure from OCR bounding-box regions."""
@@ -184,7 +39,7 @@ class GridDetectionService:
"""Calculate page skew angle from OCR region positions.
Uses left-edge alignment of regions to detect consistent tilt.
- Returns angle in degrees, clamped to ±5°.
+ Returns angle in degrees, clamped to +/-5 degrees.
"""
if len(regions) < 3:
return 0.0
@@ -229,12 +84,12 @@ class GridDetectionService:
slope = (n * sum_xy - sum_y * sum_x) / denom
# Convert slope to angle (slope is dx/dy in percent space)
- # Adjust for aspect ratio: A4 is 210/297 ≈ 0.707
+ # Adjust for aspect ratio: A4 is 210/297 ~ 0.707
aspect = A4_WIDTH_MM / A4_HEIGHT_MM
angle_rad = math.atan(slope * aspect)
angle_deg = math.degrees(angle_rad)
- # Clamp to ±5°
+ # Clamp to +/-5 degrees
return max(-5.0, min(5.0, round(angle_deg, 2)))
def apply_deskew_to_regions(self, regions: List[OCRRegion], angle: float) -> List[OCRRegion]:
diff --git a/klausur-service/backend/smart_spell.py b/klausur-service/backend/smart_spell.py
index e400474..1926500 100644
--- a/klausur-service/backend/smart_spell.py
+++ b/klausur-service/backend/smart_spell.py
@@ -1,594 +1,25 @@
"""
-SmartSpellChecker — Language-aware OCR post-correction without LLMs.
+SmartSpellChecker — barrel re-export.
-Uses pyspellchecker (MIT) with dual EN+DE dictionaries for:
-- Automatic language detection per word (dual-dictionary heuristic)
-- OCR error correction (digit↔letter, umlauts, transpositions)
-- Context-based disambiguation (a/I, l/I) via bigram lookup
-- Mixed-language support for example sentences
+All implementation split into:
+ smart_spell_core — init, data types, language detection, word correction
+ smart_spell_text — full text correction, boundary repair, context split
Lizenz: Apache 2.0 (kommerziell nutzbar)
"""
-import logging
-import re
-from dataclasses import dataclass, field
-from typing import Dict, List, Literal, Optional, Set, Tuple
-
-logger = logging.getLogger(__name__)
-
-# ---------------------------------------------------------------------------
-# Init
-# ---------------------------------------------------------------------------
-
-try:
- from spellchecker import SpellChecker as _SpellChecker
- _en_spell = _SpellChecker(language='en', distance=1)
- _de_spell = _SpellChecker(language='de', distance=1)
- _AVAILABLE = True
-except ImportError:
- _AVAILABLE = False
- logger.warning("pyspellchecker not installed — SmartSpellChecker disabled")
-
-Lang = Literal["en", "de", "both", "unknown"]
-
-# ---------------------------------------------------------------------------
-# Bigram context for a/I disambiguation
-# ---------------------------------------------------------------------------
-
-# Words that commonly follow "I" (subject pronoun → verb/modal)
-_I_FOLLOWERS: frozenset = frozenset({
- "am", "was", "have", "had", "do", "did", "will", "would", "can",
- "could", "should", "shall", "may", "might", "must",
- "think", "know", "see", "want", "need", "like", "love", "hate",
- "go", "went", "come", "came", "say", "said", "get", "got",
- "make", "made", "take", "took", "give", "gave", "tell", "told",
- "feel", "felt", "find", "found", "believe", "hope", "wish",
- "remember", "forget", "understand", "mean", "meant",
- "don't", "didn't", "can't", "won't", "couldn't", "wouldn't",
- "shouldn't", "haven't", "hadn't", "isn't", "wasn't",
- "really", "just", "also", "always", "never", "often", "sometimes",
-})
-
-# Words that commonly follow "a" (article → noun/adjective)
-_A_FOLLOWERS: frozenset = frozenset({
- "lot", "few", "little", "bit", "good", "bad", "great", "new", "old",
- "long", "short", "big", "small", "large", "huge", "tiny",
- "nice", "beautiful", "wonderful", "terrible", "horrible",
- "man", "woman", "boy", "girl", "child", "dog", "cat", "bird",
- "book", "car", "house", "room", "school", "teacher", "student",
- "day", "week", "month", "year", "time", "place", "way",
- "friend", "family", "person", "problem", "question", "story",
- "very", "really", "quite", "rather", "pretty", "single",
-})
-
-# Digit→letter substitutions (OCR confusion)
-_DIGIT_SUBS: Dict[str, List[str]] = {
- '0': ['o', 'O'],
- '1': ['l', 'I'],
- '5': ['s', 'S'],
- '6': ['g', 'G'],
- '8': ['b', 'B'],
- '|': ['I', 'l'],
- '/': ['l'], # italic 'l' misread as slash (e.g. "p/" → "pl")
-}
-_SUSPICIOUS_CHARS = frozenset(_DIGIT_SUBS.keys())
-
-# Umlaut confusion: OCR drops dots (ü→u, ä→a, ö→o)
-_UMLAUT_MAP = {
- 'a': 'ä', 'o': 'ö', 'u': 'ü', 'i': 'ü',
- 'A': 'Ä', 'O': 'Ö', 'U': 'Ü', 'I': 'Ü',
-}
-
-# Tokenizer — includes | and / so OCR artifacts like "p/" are treated as words
-_TOKEN_RE = re.compile(r"([A-Za-zÄÖÜäöüß'|/]+)([^A-Za-zÄÖÜäöüß'|/]*)")
-
-
-# ---------------------------------------------------------------------------
-# Data types
-# ---------------------------------------------------------------------------
-
-@dataclass
-class CorrectionResult:
- original: str
- corrected: str
- lang_detected: Lang
- changed: bool
- changes: List[str] = field(default_factory=list)
-
-
-# ---------------------------------------------------------------------------
-# Core class
-# ---------------------------------------------------------------------------
-
-class SmartSpellChecker:
- """Language-aware OCR spell checker using pyspellchecker (no LLM)."""
-
- def __init__(self):
- if not _AVAILABLE:
- raise RuntimeError("pyspellchecker not installed")
- self.en = _en_spell
- self.de = _de_spell
-
- # --- Language detection ---
-
- def detect_word_lang(self, word: str) -> Lang:
- """Detect language of a single word using dual-dict heuristic."""
- w = word.lower().strip(".,;:!?\"'()")
- if not w:
- return "unknown"
- in_en = bool(self.en.known([w]))
- in_de = bool(self.de.known([w]))
- if in_en and in_de:
- return "both"
- if in_en:
- return "en"
- if in_de:
- return "de"
- return "unknown"
-
- def detect_text_lang(self, text: str) -> Lang:
- """Detect dominant language of a text string (sentence/phrase)."""
- words = re.findall(r"[A-Za-zÄÖÜäöüß]+", text)
- if not words:
- return "unknown"
-
- en_count = 0
- de_count = 0
- for w in words:
- lang = self.detect_word_lang(w)
- if lang == "en":
- en_count += 1
- elif lang == "de":
- de_count += 1
- # "both" doesn't count for either
-
- if en_count > de_count:
- return "en"
- if de_count > en_count:
- return "de"
- if en_count == de_count and en_count > 0:
- return "both"
- return "unknown"
-
- # --- Single-word correction ---
-
- def _known(self, word: str) -> bool:
- """True if word is known in EN or DE dictionary, or is a known abbreviation."""
- w = word.lower()
- if bool(self.en.known([w])) or bool(self.de.known([w])):
- return True
- # Also accept known abbreviations (sth, sb, adj, etc.)
- try:
- from cv_ocr_engines import _KNOWN_ABBREVIATIONS
- if w in _KNOWN_ABBREVIATIONS:
- return True
- except ImportError:
- pass
- return False
-
- def _word_freq(self, word: str) -> float:
- """Get word frequency (max of EN and DE)."""
- w = word.lower()
- return max(self.en.word_usage_frequency(w), self.de.word_usage_frequency(w))
-
- def _known_in(self, word: str, lang: str) -> bool:
- """True if word is known in a specific language dictionary."""
- w = word.lower()
- spell = self.en if lang == "en" else self.de
- return bool(spell.known([w]))
-
- def correct_word(self, word: str, lang: str = "en",
- prev_word: str = "", next_word: str = "") -> Optional[str]:
- """Correct a single word for the given language.
-
- Returns None if no correction needed, or the corrected string.
-
- Args:
- word: The word to check/correct
- lang: Expected language ("en" or "de")
- prev_word: Previous word (for context)
- next_word: Next word (for context)
- """
- if not word or not word.strip():
- return None
-
- # Skip numbers, abbreviations with dots, very short tokens
- if word.isdigit() or '.' in word:
- return None
-
- # Skip IPA/phonetic content in brackets
- if '[' in word or ']' in word:
- return None
-
- has_suspicious = any(ch in _SUSPICIOUS_CHARS for ch in word)
-
- # 1. Already known → no fix
- if self._known(word):
- # But check a/I disambiguation for single-char words
- if word.lower() in ('l', '|') and next_word:
- return self._disambiguate_a_I(word, next_word)
- return None
-
- # 2. Digit/pipe substitution
- if has_suspicious:
- if word == '|':
- return 'I'
- # Try single-char substitutions
- for i, ch in enumerate(word):
- if ch not in _DIGIT_SUBS:
- continue
- for replacement in _DIGIT_SUBS[ch]:
- candidate = word[:i] + replacement + word[i + 1:]
- if self._known(candidate):
- return candidate
- # Try multi-char substitution (e.g., "sch00l" → "school")
- multi = self._try_multi_digit_sub(word)
- if multi:
- return multi
-
- # 3. Umlaut correction (German)
- if lang == "de" and len(word) >= 3 and word.isalpha():
- umlaut_fix = self._try_umlaut_fix(word)
- if umlaut_fix:
- return umlaut_fix
-
- # 4. General spell correction
- if not has_suspicious and len(word) >= 3 and word.isalpha():
- # Safety: don't correct if the word is valid in the OTHER language
- # (either directly or via umlaut fix)
- other_lang = "de" if lang == "en" else "en"
- if self._known_in(word, other_lang):
- return None
- if other_lang == "de" and self._try_umlaut_fix(word):
- return None # has a valid DE umlaut variant → don't touch
-
- spell = self.en if lang == "en" else self.de
- correction = spell.correction(word.lower())
- if correction and correction != word.lower():
- if word[0].isupper():
- correction = correction[0].upper() + correction[1:]
- if self._known(correction):
- return correction
-
- return None
-
- # --- Multi-digit substitution ---
-
- def _try_multi_digit_sub(self, word: str) -> Optional[str]:
- """Try replacing multiple digits simultaneously."""
- positions = [(i, ch) for i, ch in enumerate(word) if ch in _DIGIT_SUBS]
- if len(positions) < 1 or len(positions) > 4:
- return None
-
- # Try all combinations (max 2^4 = 16 for 4 positions)
- chars = list(word)
- best = None
- self._multi_sub_recurse(chars, positions, 0, best_result=[None])
- return self._multi_sub_recurse_result
-
- _multi_sub_recurse_result: Optional[str] = None
-
- def _try_multi_digit_sub(self, word: str) -> Optional[str]:
- """Try replacing multiple digits simultaneously using BFS."""
- positions = [(i, ch) for i, ch in enumerate(word) if ch in _DIGIT_SUBS]
- if not positions or len(positions) > 4:
- return None
-
- # BFS over substitution combinations
- queue = [list(word)]
- for pos, ch in positions:
- next_queue = []
- for current in queue:
- # Keep original
- next_queue.append(current[:])
- # Try each substitution
- for repl in _DIGIT_SUBS[ch]:
- variant = current[:]
- variant[pos] = repl
- next_queue.append(variant)
- queue = next_queue
-
- # Check which combinations produce known words
- for combo in queue:
- candidate = "".join(combo)
- if candidate != word and self._known(candidate):
- return candidate
-
- return None
-
- # --- Umlaut fix ---
-
- def _try_umlaut_fix(self, word: str) -> Optional[str]:
- """Try single-char umlaut substitutions for German words."""
- for i, ch in enumerate(word):
- if ch in _UMLAUT_MAP:
- candidate = word[:i] + _UMLAUT_MAP[ch] + word[i + 1:]
- if self._known(candidate):
- return candidate
- return None
-
- # --- Boundary repair (shifted word boundaries) ---
-
- def _try_boundary_repair(self, word1: str, word2: str) -> Optional[Tuple[str, str]]:
- """Fix shifted word boundaries between adjacent tokens.
-
- OCR sometimes shifts the boundary: "at sth." → "ats th."
- Try moving 1-2 chars from end of word1 to start of word2 and vice versa.
- Returns (fixed_word1, fixed_word2) or None.
- """
- # Import known abbreviations for vocabulary context
- try:
- from cv_ocr_engines import _KNOWN_ABBREVIATIONS
- except ImportError:
- _KNOWN_ABBREVIATIONS = set()
-
- # Strip trailing punctuation for checking, preserve for result
- w2_stripped = word2.rstrip(".,;:!?")
- w2_punct = word2[len(w2_stripped):]
-
- # Try shifting 1-2 chars from word1 → word2
- for shift in (1, 2):
- if len(word1) <= shift:
- continue
- new_w1 = word1[:-shift]
- new_w2_base = word1[-shift:] + w2_stripped
-
- w1_ok = self._known(new_w1) or new_w1.lower() in _KNOWN_ABBREVIATIONS
- w2_ok = self._known(new_w2_base) or new_w2_base.lower() in _KNOWN_ABBREVIATIONS
-
- if w1_ok and w2_ok:
- return (new_w1, new_w2_base + w2_punct)
-
- # Try shifting 1-2 chars from word2 → word1
- for shift in (1, 2):
- if len(w2_stripped) <= shift:
- continue
- new_w1 = word1 + w2_stripped[:shift]
- new_w2_base = w2_stripped[shift:]
-
- w1_ok = self._known(new_w1) or new_w1.lower() in _KNOWN_ABBREVIATIONS
- w2_ok = self._known(new_w2_base) or new_w2_base.lower() in _KNOWN_ABBREVIATIONS
-
- if w1_ok and w2_ok:
- return (new_w1, new_w2_base + w2_punct)
-
- return None
-
- # --- Context-based word split for ambiguous merges ---
-
- # Patterns where a valid word is actually "a" + adjective/noun
- _ARTICLE_SPLIT_CANDIDATES = {
- # word → (article, remainder) — only when followed by a compatible word
- "anew": ("a", "new"),
- "areal": ("a", "real"),
- "alive": None, # genuinely one word, never split
- "alone": None,
- "aware": None,
- "alike": None,
- "apart": None,
- "aside": None,
- "above": None,
- "about": None,
- "among": None,
- "along": None,
- }
-
- def _try_context_split(self, word: str, next_word: str,
- prev_word: str) -> Optional[str]:
- """Split words like 'anew' → 'a new' when context indicates a merge.
-
- Only splits when:
- - The word is in the split candidates list
- - The following word makes sense as a noun (for "a + adj + noun" pattern)
- - OR the word is unknown and can be split into article + known word
- """
- w_lower = word.lower()
-
- # Check explicit candidates
- if w_lower in self._ARTICLE_SPLIT_CANDIDATES:
- split = self._ARTICLE_SPLIT_CANDIDATES[w_lower]
- if split is None:
- return None # explicitly marked as "don't split"
- article, remainder = split
- # Only split if followed by a word (noun pattern)
- if next_word and next_word[0].islower():
- return f"{article} {remainder}"
- # Also split if remainder + next_word makes a common phrase
- if next_word and self._known(next_word):
- return f"{article} {remainder}"
-
- # Generic: if word starts with 'a' and rest is a known adjective/word
- if (len(word) >= 4 and word[0].lower() == 'a'
- and not self._known(word) # only for UNKNOWN words
- and self._known(word[1:])):
- return f"a {word[1:]}"
-
- return None
-
- # --- a/I disambiguation ---
-
- def _disambiguate_a_I(self, token: str, next_word: str) -> Optional[str]:
- """Disambiguate 'a' vs 'I' (and OCR variants like 'l', '|')."""
- nw = next_word.lower().strip(".,;:!?")
- if nw in _I_FOLLOWERS:
- return "I"
- if nw in _A_FOLLOWERS:
- return "a"
- # Fallback: check if next word is more commonly a verb (→I) or noun/adj (→a)
- # Simple heuristic: if next word starts with uppercase (and isn't first in sentence)
- # it's likely a German noun following "I"... but in English context, uppercase
- # after "I" is unusual.
- return None # uncertain, don't change
-
- # --- Full text correction ---
-
- def correct_text(self, text: str, lang: str = "en") -> CorrectionResult:
- """Correct a full text string (field value).
-
- Three passes:
- 1. Boundary repair — fix shifted word boundaries between adjacent tokens
- 2. Context split — split ambiguous merges (anew → a new)
- 3. Per-word correction — spell check individual words
-
- Args:
- text: The text to correct
- lang: Expected language ("en" or "de")
- """
- if not text or not text.strip():
- return CorrectionResult(text, text, "unknown", False)
-
- detected = self.detect_text_lang(text) if lang == "auto" else lang
- effective_lang = detected if detected in ("en", "de") else "en"
-
- changes: List[str] = []
- tokens = list(_TOKEN_RE.finditer(text))
-
- # Extract token list: [(word, separator), ...]
- token_list: List[List[str]] = [] # [[word, sep], ...]
- for m in tokens:
- token_list.append([m.group(1), m.group(2)])
-
- # --- Pass 1: Boundary repair between adjacent unknown words ---
- # Import abbreviations for the heuristic below
- try:
- from cv_ocr_engines import _KNOWN_ABBREVIATIONS as _ABBREVS
- except ImportError:
- _ABBREVS = set()
-
- for i in range(len(token_list) - 1):
- w1 = token_list[i][0]
- w2_raw = token_list[i + 1][0]
-
- # Skip boundary repair for IPA/bracket content
- # Brackets may be in the token OR in the adjacent separators
- sep_before_w1 = token_list[i - 1][1] if i > 0 else ""
- sep_after_w1 = token_list[i][1]
- sep_after_w2 = token_list[i + 1][1]
- has_bracket = (
- '[' in w1 or ']' in w1 or '[' in w2_raw or ']' in w2_raw
- or ']' in sep_after_w1 # w1 text was inside [brackets]
- or '[' in sep_after_w1 # w2 starts a bracket
- or ']' in sep_after_w2 # w2 text was inside [brackets]
- or '[' in sep_before_w1 # w1 starts a bracket
- )
- if has_bracket:
- continue
-
- # Include trailing punct from separator in w2 for abbreviation matching
- w2_with_punct = w2_raw + token_list[i + 1][1].rstrip(" ")
-
- # Try boundary repair — always, even if both words are valid.
- # Use word-frequency scoring to decide if repair is better.
- repair = self._try_boundary_repair(w1, w2_with_punct)
- if not repair and w2_with_punct != w2_raw:
- repair = self._try_boundary_repair(w1, w2_raw)
- if repair:
- new_w1, new_w2_full = repair
- new_w2_base = new_w2_full.rstrip(".,;:!?")
-
- # Frequency-based scoring: product of word frequencies
- # Higher product = more common word pair = better
- old_freq = self._word_freq(w1) * self._word_freq(w2_raw)
- new_freq = self._word_freq(new_w1) * self._word_freq(new_w2_base)
-
- # Abbreviation bonus: if repair produces a known abbreviation
- has_abbrev = new_w1.lower() in _ABBREVS or new_w2_base.lower() in _ABBREVS
- if has_abbrev:
- # Accept abbreviation repair ONLY if at least one of the
- # original words is rare/unknown (prevents "Can I" → "Ca nI"
- # where both original words are common and correct).
- # "Rare" = frequency < 1e-6 (covers "ats", "th" but not "Can", "I")
- RARE_THRESHOLD = 1e-6
- orig_both_common = (
- self._word_freq(w1) > RARE_THRESHOLD
- and self._word_freq(w2_raw) > RARE_THRESHOLD
- )
- if not orig_both_common:
- new_freq = max(new_freq, old_freq * 10)
- else:
- has_abbrev = False # both originals common → don't trust
-
- # Accept if repair produces a more frequent word pair
- # (threshold: at least 5x more frequent to avoid false positives)
- if new_freq > old_freq * 5:
- new_w2_punct = new_w2_full[len(new_w2_base):]
- changes.append(f"{w1} {w2_raw}→{new_w1} {new_w2_base}")
- token_list[i][0] = new_w1
- token_list[i + 1][0] = new_w2_base
- if new_w2_punct:
- token_list[i + 1][1] = new_w2_punct + token_list[i + 1][1].lstrip(".,;:!?")
-
- # --- Pass 2: Context split (anew → a new) ---
- expanded: List[List[str]] = []
- for i, (word, sep) in enumerate(token_list):
- next_word = token_list[i + 1][0] if i + 1 < len(token_list) else ""
- prev_word = token_list[i - 1][0] if i > 0 else ""
- split = self._try_context_split(word, next_word, prev_word)
- if split and split != word:
- changes.append(f"{word}→{split}")
- expanded.append([split, sep])
- else:
- expanded.append([word, sep])
- token_list = expanded
-
- # --- Pass 3: Per-word correction ---
- parts: List[str] = []
-
- # Preserve any leading text before the first token match
- # (e.g., "(= " before "I won and he lost.")
- first_start = tokens[0].start() if tokens else 0
- if first_start > 0:
- parts.append(text[:first_start])
-
- for i, (word, sep) in enumerate(token_list):
- # Skip words inside IPA brackets (brackets land in separators)
- prev_sep = token_list[i - 1][1] if i > 0 else ""
- if '[' in prev_sep or ']' in sep:
- parts.append(word)
- parts.append(sep)
- continue
-
- next_word = token_list[i + 1][0] if i + 1 < len(token_list) else ""
- prev_word = token_list[i - 1][0] if i > 0 else ""
-
- correction = self.correct_word(
- word, lang=effective_lang,
- prev_word=prev_word, next_word=next_word,
- )
- if correction and correction != word:
- changes.append(f"{word}→{correction}")
- parts.append(correction)
- else:
- parts.append(word)
- parts.append(sep)
-
- # Append any trailing text
- last_end = tokens[-1].end() if tokens else 0
- if last_end < len(text):
- parts.append(text[last_end:])
-
- corrected = "".join(parts)
- return CorrectionResult(
- original=text,
- corrected=corrected,
- lang_detected=detected,
- changed=corrected != text,
- changes=changes,
- )
-
- # --- Vocabulary entry correction ---
-
- def correct_vocab_entry(self, english: str, german: str,
- example: str = "") -> Dict[str, CorrectionResult]:
- """Correct a full vocabulary entry (EN + DE + example).
-
- Uses column position to determine language — the most reliable signal.
- """
- results = {}
- results["english"] = self.correct_text(english, lang="en")
- results["german"] = self.correct_text(german, lang="de")
- if example:
- # For examples, auto-detect language
- results["example"] = self.correct_text(example, lang="auto")
- return results
+# Core: data types, lang detection (re-exported for tests)
+from smart_spell_core import ( # noqa: F401
+ _AVAILABLE,
+ _DIGIT_SUBS,
+ _SUSPICIOUS_CHARS,
+ _UMLAUT_MAP,
+ _TOKEN_RE,
+ _I_FOLLOWERS,
+ _A_FOLLOWERS,
+ CorrectionResult,
+ Lang,
+)
+
+# Text: SmartSpellChecker class (the main public API)
+from smart_spell_text import SmartSpellChecker # noqa: F401
diff --git a/klausur-service/backend/smart_spell_core.py b/klausur-service/backend/smart_spell_core.py
new file mode 100644
index 0000000..9f2fa7d
--- /dev/null
+++ b/klausur-service/backend/smart_spell_core.py
@@ -0,0 +1,298 @@
+"""
+SmartSpellChecker Core — init, data types, language detection, word correction.
+
+Extracted from smart_spell.py for modularity.
+
+Lizenz: Apache 2.0 (kommerziell nutzbar)
+"""
+
+import logging
+import re
+from dataclasses import dataclass, field
+from typing import Dict, List, Literal, Optional, Set, Tuple
+
+logger = logging.getLogger(__name__)
+
+# ---------------------------------------------------------------------------
+# Init
+# ---------------------------------------------------------------------------
+
+try:
+ from spellchecker import SpellChecker as _SpellChecker
+ _en_spell = _SpellChecker(language='en', distance=1)  # NOTE(review): distance=1 presumably caps suggestions at edit distance 1 — confirm in pyspellchecker docs
+ _de_spell = _SpellChecker(language='de', distance=1)
+ _AVAILABLE = True
+except ImportError:
+ _AVAILABLE = False  # _en_spell/_de_spell stay undefined here; _SmartSpellCoreBase.__init__ checks this flag before using them
+ logger.warning("pyspellchecker not installed — SmartSpellChecker disabled")
+
+Lang = Literal["en", "de", "both", "unknown"]
+
+# ---------------------------------------------------------------------------
+# Bigram context for a/I disambiguation
+# ---------------------------------------------------------------------------
+
+# Words that commonly follow "I" (subject pronoun -> verb/modal)
+_I_FOLLOWERS: frozenset = frozenset({
+ "am", "was", "have", "had", "do", "did", "will", "would", "can",
+ "could", "should", "shall", "may", "might", "must",
+ "think", "know", "see", "want", "need", "like", "love", "hate",
+ "go", "went", "come", "came", "say", "said", "get", "got",
+ "make", "made", "take", "took", "give", "gave", "tell", "told",
+ "feel", "felt", "find", "found", "believe", "hope", "wish",
+ "remember", "forget", "understand", "mean", "meant",
+ "don't", "didn't", "can't", "won't", "couldn't", "wouldn't",
+ "shouldn't", "haven't", "hadn't", "isn't", "wasn't",
+ "really", "just", "also", "always", "never", "often", "sometimes",
+})
+
+# Words that commonly follow "a" (article -> noun/adjective)
+_A_FOLLOWERS: frozenset = frozenset({
+ "lot", "few", "little", "bit", "good", "bad", "great", "new", "old",
+ "long", "short", "big", "small", "large", "huge", "tiny",
+ "nice", "beautiful", "wonderful", "terrible", "horrible",
+ "man", "woman", "boy", "girl", "child", "dog", "cat", "bird",
+ "book", "car", "house", "room", "school", "teacher", "student",
+ "day", "week", "month", "year", "time", "place", "way",
+ "friend", "family", "person", "problem", "question", "story",
+ "very", "really", "quite", "rather", "pretty", "single",
+})
+
+# Digit->letter substitutions (OCR confusion)
+_DIGIT_SUBS: Dict[str, List[str]] = {
+ '0': ['o', 'O'],
+ '1': ['l', 'I'],
+ '5': ['s', 'S'],
+ '6': ['g', 'G'],
+ '8': ['b', 'B'],
+ '|': ['I', 'l'],
+ '/': ['l'], # italic 'l' misread as slash (e.g. "p/" -> "pl")
+}
+_SUSPICIOUS_CHARS = frozenset(_DIGIT_SUBS.keys())
+
+# Umlaut confusion: OCR drops the dots (ü->u, ä->a, ö->o)
+_UMLAUT_MAP = {
+ 'a': '\u00e4', 'o': '\u00f6', 'u': '\u00fc', 'i': '\u00fc',
+ 'A': '\u00c4', 'O': '\u00d6', 'U': '\u00dc', 'I': '\u00dc',
+}
+
+# Tokenizer -- includes | and / so OCR artifacts like "p/" are treated as words
+_TOKEN_RE = re.compile(r"([A-Za-z\u00c4\u00d6\u00dc\u00e4\u00f6\u00fc\u00df'|/]+)([^A-Za-z\u00c4\u00d6\u00dc\u00e4\u00f6\u00fc\u00df'|/]*)")
+
+
+# ---------------------------------------------------------------------------
+# Data types
+# ---------------------------------------------------------------------------
+
+@dataclass
+class CorrectionResult:
+ original: str  # input text, unmodified
+ corrected: str  # text after corrections were applied
+ lang_detected: Lang  # NOTE(review): presumably the language used/detected for the text — confirm against correct_text
+ changed: bool  # NOTE(review): presumably True when corrected differs from original — confirm against correct_text
+ changes: List[str] = field(default_factory=list)  # one human-readable note per applied change; fresh list per instance
+
+
+# ---------------------------------------------------------------------------
+# Core class — language detection and word-level correction
+# ---------------------------------------------------------------------------
+
+class _SmartSpellCoreBase:
+ """Base class with language detection and single-word correction.
+
+ Not intended for direct use — SmartSpellChecker inherits from this.
+ """
+
+ def __init__(self):
+ if not _AVAILABLE:
+ raise RuntimeError("pyspellchecker not installed")
+ self.en = _en_spell  # module-level English dictionary, shared by all instances
+ self.de = _de_spell  # module-level German dictionary, shared by all instances
+
+ # --- Language detection ---
+
+ def detect_word_lang(self, word: str) -> Lang:
+ """Detect language of a single word using dual-dict heuristic."""
+ w = word.lower().strip(".,;:!?\"'()")  # case-fold, trim surrounding punctuation
+ if not w:
+ return "unknown"
+ in_en = bool(self.en.known([w]))
+ in_de = bool(self.de.known([w]))
+ if in_en and in_de:
+ return "both"
+ if in_en:
+ return "en"
+ if in_de:
+ return "de"
+ return "unknown"
+
+ def detect_text_lang(self, text: str) -> Lang:
+ """Detect dominant language of a text string (sentence/phrase)."""
+ words = re.findall(r"[A-Za-z\u00c4\u00d6\u00dc\u00e4\u00f6\u00fc\u00df]+", text)
+ if not words:
+ return "unknown"
+
+ en_count = 0
+ de_count = 0
+ for w in words:
+ lang = self.detect_word_lang(w)
+ if lang == "en":
+ en_count += 1
+ elif lang == "de":
+ de_count += 1
+ # "both" and "unknown" words are not counted for either language
+
+ if en_count > de_count:
+ return "en"
+ if de_count > en_count:
+ return "de"
+ if en_count == de_count and en_count > 0:  # tie with at least one unambiguous word per language
+ return "both"
+ return "unknown"
+
+ # --- Single-word correction ---
+
+ def _known(self, word: str) -> bool:
+ """True if word is known in EN or DE dictionary, or is a known abbreviation."""
+ w = word.lower()
+ if bool(self.en.known([w])) or bool(self.de.known([w])):
+ return True
+ # Also accept known abbreviations (sth, sb, adj, etc.); cv_ocr_engines is optional
+ try:
+ from cv_ocr_engines import _KNOWN_ABBREVIATIONS
+ if w in _KNOWN_ABBREVIATIONS:
+ return True
+ except ImportError:
+ pass  # cv_ocr_engines not importable: the dictionaries are the only source
+ return False
+
+ def _word_freq(self, word: str) -> float:
+ """Get word frequency (max of EN and DE)."""
+ w = word.lower()
+ return max(self.en.word_usage_frequency(w), self.de.word_usage_frequency(w))
+
+ def _known_in(self, word: str, lang: str) -> bool:
+ """True if word is known in a specific language dictionary."""
+ w = word.lower()
+ spell = self.en if lang == "en" else self.de  # anything but "en" selects the German dictionary
+ return bool(spell.known([w]))
+
+ def correct_word(self, word: str, lang: str = "en",
+ prev_word: str = "", next_word: str = "") -> Optional[str]:  # NOTE: prev_word is currently unused
+ """Correct a single word for the given language.
+
+ Returns None if no correction needed, or the corrected string.
+ """
+ if not word or not word.strip():
+ return None
+
+ # Skip all-digit tokens and anything containing a dot (e.g. abbreviations)
+ if word.isdigit() or '.' in word:
+ return None
+
+ # Skip IPA/phonetic content in brackets
+ if '[' in word or ']' in word:
+ return None
+
+ has_suspicious = any(ch in _SUSPICIOUS_CHARS for ch in word)  # contains a key of _DIGIT_SUBS (digit, '|' or '/')
+
+ # 1. Already known -> no fix
+ if self._known(word):
+ # Exception: a lone 'l' or '|' may be a misread 'a'/'I'; let the next word decide
+ if word.lower() in ('l', '|') and next_word:
+ return self._disambiguate_a_I(word, next_word)
+ return None
+
+ # 2. Digit/pipe/slash substitution
+ if has_suspicious:
+ if word == '|':
+ return 'I'
+ # Try single-char substitutions left to right; the first known candidate wins
+ for i, ch in enumerate(word):
+ if ch not in _DIGIT_SUBS:
+ continue
+ for replacement in _DIGIT_SUBS[ch]:
+ candidate = word[:i] + replacement + word[i + 1:]
+ if self._known(candidate):
+ return candidate
+ # Try multi-char substitution (e.g., "sch00l" -> "school")
+ multi = self._try_multi_digit_sub(word)
+ if multi:
+ return multi
+
+ # 3. Umlaut correction (German only, alphabetic words of 3+ chars)
+ if lang == "de" and len(word) >= 3 and word.isalpha():
+ umlaut_fix = self._try_umlaut_fix(word)
+ if umlaut_fix:
+ return umlaut_fix
+
+ # 4. General spell correction (alphabetic words of 3+ chars, no suspicious chars)
+ if not has_suspicious and len(word) >= 3 and word.isalpha():
+ # Safety: don't correct if the word is valid in the OTHER language
+ other_lang = "de" if lang == "en" else "en"
+ if self._known_in(word, other_lang):
+ return None
+ if other_lang == "de" and self._try_umlaut_fix(word):
+ return None # has a valid DE umlaut variant -> don't touch
+
+ spell = self.en if lang == "en" else self.de
+ correction = spell.correction(word.lower())
+ if correction and correction != word.lower():
+ if word[0].isupper():  # restore the leading capital lost by lower()
+ correction = correction[0].upper() + correction[1:]
+ if self._known(correction):  # only accept suggestions that are themselves known
+ return correction
+
+ return None
+
+ # --- Multi-digit substitution ---
+
+ def _try_multi_digit_sub(self, word: str) -> Optional[str]:
+ """Try replacing multiple digits simultaneously using BFS."""
+ positions = [(i, ch) for i, ch in enumerate(word) if ch in _DIGIT_SUBS]
+ if not positions or len(positions) > 4:  # cap: at most 3**4 = 81 combinations
+ return None
+
+ # Expand level by level: each position either keeps its char or takes a substitute
+ queue = [list(word)]
+ for pos, ch in positions:
+ next_queue = []
+ for current in queue:
+ # Keep the character at this position
+ next_queue.append(current[:])
+ # Try each substitution
+ for repl in _DIGIT_SUBS[ch]:
+ variant = current[:]
+ variant[pos] = repl
+ next_queue.append(variant)
+ queue = next_queue
+
+ # Return the first combination (other than the input itself) that is a known word
+ for combo in queue:
+ candidate = "".join(combo)
+ if candidate != word and self._known(candidate):
+ return candidate
+
+ return None
+
+ # --- Umlaut fix ---
+
+ def _try_umlaut_fix(self, word: str) -> Optional[str]:
+ """Try single-char umlaut substitutions for German words."""
+ for i, ch in enumerate(word):  # one substitution at a time; the first known result wins
+ if ch in _UMLAUT_MAP:
+ candidate = word[:i] + _UMLAUT_MAP[ch] + word[i + 1:]
+ if self._known(candidate):
+ return candidate
+ return None
+
+ # --- a/I disambiguation ---
+
+ def _disambiguate_a_I(self, token: str, next_word: str) -> Optional[str]:  # NOTE: token is currently unused
+ """Disambiguate 'a' vs 'I' (and OCR variants like 'l', '|')."""
+ nw = next_word.lower().strip(".,;:!?")  # case-fold, trim punctuation
+ if nw in _I_FOLLOWERS:
+ return "I"
+ if nw in _A_FOLLOWERS:
+ return "a"
+ return None # uncertain, don't change
diff --git a/klausur-service/backend/smart_spell_text.py b/klausur-service/backend/smart_spell_text.py
new file mode 100644
index 0000000..7628e61
--- /dev/null
+++ b/klausur-service/backend/smart_spell_text.py
@@ -0,0 +1,289 @@
+"""
+SmartSpellChecker Text — full text correction, boundary repair, context split.
+
+Extracted from smart_spell.py for modularity.
+
+Lizenz: Apache 2.0 (kommerziell nutzbar)
+"""
+
+import re
+from typing import Dict, List, Optional, Tuple
+
+from smart_spell_core import (
+ _SmartSpellCoreBase,
+ _TOKEN_RE,
+ CorrectionResult,
+ Lang,
+)
+
+
+class SmartSpellChecker(_SmartSpellCoreBase):
+ """Language-aware OCR spell checker using pyspellchecker (no LLM).
+
+ Inherits single-word correction from _SmartSpellCoreBase.
+ Adds text-level passes: boundary repair, context split, full correction.
+ """
+
+ # --- Boundary repair (shifted word boundaries) ---
+
+ def _try_boundary_repair(self, word1: str, word2: str) -> Optional[Tuple[str, str]]:
+ """Fix shifted word boundaries between adjacent tokens.
+
+ OCR sometimes shifts the boundary: "at sth." -> "ats th."
+ Try moving 1-2 chars from end of word1 to start of word2 and vice versa.
+ Returns (fixed_word1, fixed_word2) or None.
+ """
+ # Import known abbreviations for vocabulary context
+ try:
+ from cv_ocr_engines import _KNOWN_ABBREVIATIONS
+ except ImportError:
+ _KNOWN_ABBREVIATIONS = set()
+
+ # Strip trailing punctuation for checking, preserve for result
+ w2_stripped = word2.rstrip(".,;:!?")
+ w2_punct = word2[len(w2_stripped):]
+
+ # Try shifting 1-2 chars from word1 -> word2
+ for shift in (1, 2):
+ if len(word1) <= shift:
+ continue
+ new_w1 = word1[:-shift]
+ new_w2_base = word1[-shift:] + w2_stripped
+
+ w1_ok = self._known(new_w1) or new_w1.lower() in _KNOWN_ABBREVIATIONS
+ w2_ok = self._known(new_w2_base) or new_w2_base.lower() in _KNOWN_ABBREVIATIONS
+
+ if w1_ok and w2_ok:
+ return (new_w1, new_w2_base + w2_punct)
+
+ # Try shifting 1-2 chars from word2 -> word1
+ for shift in (1, 2):
+ if len(w2_stripped) <= shift:
+ continue
+ new_w1 = word1 + w2_stripped[:shift]
+ new_w2_base = w2_stripped[shift:]
+
+ w1_ok = self._known(new_w1) or new_w1.lower() in _KNOWN_ABBREVIATIONS
+ w2_ok = self._known(new_w2_base) or new_w2_base.lower() in _KNOWN_ABBREVIATIONS
+
+ if w1_ok and w2_ok:
+ return (new_w1, new_w2_base + w2_punct)
+
+ return None
+
+ # --- Context-based word split for ambiguous merges ---
+
+ # Patterns where a valid word is actually "a" + adjective/noun
+ _ARTICLE_SPLIT_CANDIDATES = {
+ # word -> (article, remainder) -- only when followed by a compatible word
+ "anew": ("a", "new"),
+ "areal": ("a", "real"),
+ "alive": None, # genuinely one word, never split
+ "alone": None,
+ "aware": None,
+ "alike": None,
+ "apart": None,
+ "aside": None,
+ "above": None,
+ "about": None,
+ "among": None,
+ "along": None,
+ }
+
+ def _try_context_split(self, word: str, next_word: str,
+ prev_word: str) -> Optional[str]:
+ """Split words like 'anew' -> 'a new' when context indicates a merge.
+
+ Only splits when:
+ - The word is in the split candidates list
+ - The following word makes sense as a noun (for "a + adj + noun" pattern)
+ - OR the word is unknown and can be split into article + known word
+ """
+ w_lower = word.lower()
+
+ # Check explicit candidates
+ if w_lower in self._ARTICLE_SPLIT_CANDIDATES:
+ split = self._ARTICLE_SPLIT_CANDIDATES[w_lower]
+ if split is None:
+ return None # explicitly marked as "don't split"
+ article, remainder = split
+ # Only split if followed by a word (noun pattern)
+ if next_word and next_word[0].islower():
+ return f"{article} {remainder}"
+ # Also split if remainder + next_word makes a common phrase
+ if next_word and self._known(next_word):
+ return f"{article} {remainder}"
+
+ # Generic: if word starts with 'a' and rest is a known adjective/word
+ if (len(word) >= 4 and word[0].lower() == 'a'
+ and not self._known(word) # only for UNKNOWN words
+ and self._known(word[1:])):
+ return f"a {word[1:]}"
+
+ return None
+
+ # --- Full text correction ---
+
+ def correct_text(self, text: str, lang: str = "en") -> CorrectionResult:
+ """Correct a full text string (field value).
+
+ Three passes:
+ 1. Boundary repair -- fix shifted word boundaries between adjacent tokens
+ 2. Context split -- split ambiguous merges (anew -> a new)
+ 3. Per-word correction -- spell check individual words
+ """
+ if not text or not text.strip():
+ return CorrectionResult(text, text, "unknown", False)
+
+ detected = self.detect_text_lang(text) if lang == "auto" else lang
+ effective_lang = detected if detected in ("en", "de") else "en"
+
+ changes: List[str] = []
+ tokens = list(_TOKEN_RE.finditer(text))
+
+ # Extract token list: [(word, separator), ...]
+ token_list: List[List[str]] = [] # [[word, sep], ...]
+ for m in tokens:
+ token_list.append([m.group(1), m.group(2)])
+
+ # --- Pass 1: Boundary repair between adjacent unknown words ---
+ # Import abbreviations for the heuristic below
+ try:
+ from cv_ocr_engines import _KNOWN_ABBREVIATIONS as _ABBREVS
+ except ImportError:
+ _ABBREVS = set()
+
+ for i in range(len(token_list) - 1):
+ w1 = token_list[i][0]
+ w2_raw = token_list[i + 1][0]
+
+ # Skip boundary repair for IPA/bracket content
+ # Brackets may be in the token OR in the adjacent separators
+ sep_before_w1 = token_list[i - 1][1] if i > 0 else ""
+ sep_after_w1 = token_list[i][1]
+ sep_after_w2 = token_list[i + 1][1]
+ has_bracket = (
+ '[' in w1 or ']' in w1 or '[' in w2_raw or ']' in w2_raw
+ or ']' in sep_after_w1 # w1 text was inside [brackets]
+ or '[' in sep_after_w1 # w2 starts a bracket
+ or ']' in sep_after_w2 # w2 text was inside [brackets]
+ or '[' in sep_before_w1 # w1 starts a bracket
+ )
+ if has_bracket:
+ continue
+
+ # Include trailing punct from separator in w2 for abbreviation matching
+ w2_with_punct = w2_raw + token_list[i + 1][1].rstrip(" ")
+
+ # Try boundary repair -- always, even if both words are valid.
+ # Use word-frequency scoring to decide if repair is better.
+ repair = self._try_boundary_repair(w1, w2_with_punct)
+ if not repair and w2_with_punct != w2_raw:
+ repair = self._try_boundary_repair(w1, w2_raw)
+ if repair:
+ new_w1, new_w2_full = repair
+ new_w2_base = new_w2_full.rstrip(".,;:!?")
+
+ # Frequency-based scoring: product of word frequencies
+ # Higher product = more common word pair = better
+ old_freq = self._word_freq(w1) * self._word_freq(w2_raw)
+ new_freq = self._word_freq(new_w1) * self._word_freq(new_w2_base)
+
+ # Abbreviation bonus: if repair produces a known abbreviation
+ has_abbrev = new_w1.lower() in _ABBREVS or new_w2_base.lower() in _ABBREVS
+ if has_abbrev:
+ # Accept abbreviation repair ONLY if at least one of the
+ # original words is rare/unknown (prevents "Can I" -> "Ca nI"
+ # where both original words are common and correct).
+ RARE_THRESHOLD = 1e-6
+ orig_both_common = (
+ self._word_freq(w1) > RARE_THRESHOLD
+ and self._word_freq(w2_raw) > RARE_THRESHOLD
+ )
+ if not orig_both_common:
+ new_freq = max(new_freq, old_freq * 10)
+ else:
+ has_abbrev = False # both originals common -> don't trust
+
+ # Accept if repair produces a more frequent word pair
+ # (threshold: at least 5x more frequent to avoid false positives)
+ if new_freq > old_freq * 5:
+ new_w2_punct = new_w2_full[len(new_w2_base):]
+ changes.append(f"{w1} {w2_raw}\u2192{new_w1} {new_w2_base}")
+ token_list[i][0] = new_w1
+ token_list[i + 1][0] = new_w2_base
+ if new_w2_punct:
+ token_list[i + 1][1] = new_w2_punct + token_list[i + 1][1].lstrip(".,;:!?")
+
+ # --- Pass 2: Context split (anew -> a new) ---
+ expanded: List[List[str]] = []
+ for i, (word, sep) in enumerate(token_list):
+ next_word = token_list[i + 1][0] if i + 1 < len(token_list) else ""
+ prev_word = token_list[i - 1][0] if i > 0 else ""
+ split = self._try_context_split(word, next_word, prev_word)
+ if split and split != word:
+ changes.append(f"{word}\u2192{split}")
+ expanded.append([split, sep])
+ else:
+ expanded.append([word, sep])
+ token_list = expanded
+
+ # --- Pass 3: Per-word correction ---
+ parts: List[str] = []
+
+ # Preserve any leading text before the first token match
+ first_start = tokens[0].start() if tokens else 0
+ if first_start > 0:
+ parts.append(text[:first_start])
+
+ for i, (word, sep) in enumerate(token_list):
+ # Skip words inside IPA brackets (brackets land in separators)
+ prev_sep = token_list[i - 1][1] if i > 0 else ""
+ if '[' in prev_sep or ']' in sep:
+ parts.append(word)
+ parts.append(sep)
+ continue
+
+ next_word = token_list[i + 1][0] if i + 1 < len(token_list) else ""
+ prev_word = token_list[i - 1][0] if i > 0 else ""
+
+ correction = self.correct_word(
+ word, lang=effective_lang,
+ prev_word=prev_word, next_word=next_word,
+ )
+ if correction and correction != word:
+ changes.append(f"{word}\u2192{correction}")
+ parts.append(correction)
+ else:
+ parts.append(word)
+ parts.append(sep)
+
+ # Append any trailing text
+ last_end = tokens[-1].end() if tokens else 0
+ if last_end < len(text):
+ parts.append(text[last_end:])
+
+ corrected = "".join(parts)
+ return CorrectionResult(
+ original=text,
+ corrected=corrected,
+ lang_detected=detected,
+ changed=corrected != text,
+ changes=changes,
+ )
+
+ # --- Vocabulary entry correction ---
+
+ def correct_vocab_entry(self, english: str, german: str,
+ example: str = "") -> Dict[str, CorrectionResult]:
+ """Correct a full vocabulary entry (EN + DE + example).
+
+ Uses column position to determine language -- the most reliable signal.
+ """
+ results = {}
+ results["english"] = self.correct_text(english, lang="en")
+ results["german"] = self.correct_text(german, lang="de")
+ if example:
+ # For examples, auto-detect language
+ results["example"] = self.correct_text(example, lang="auto")
+ return results
diff --git a/klausur-service/backend/upload_api.py b/klausur-service/backend/upload_api.py
index 6846192..98e6f1a 100644
--- a/klausur-service/backend/upload_api.py
+++ b/klausur-service/backend/upload_api.py
@@ -1,602 +1,29 @@
"""
-Mobile Upload API for Klausur-Service
+Mobile Upload API — barrel re-export.
+
+All implementation split into:
+ upload_api_chunked — chunked upload endpoints (init, chunk, finalize, simple, status, cancel, list)
+ upload_api_mobile — mobile HTML upload page
-Provides chunked upload endpoints for large PDF files (100MB+) from mobile devices.
DSGVO-konform: Data stays local in WLAN, no external transmission.
"""
-import os
-import uuid
-import shutil
-import hashlib
-from pathlib import Path
-from datetime import datetime, timezone
-from typing import Dict, Optional
-
-from fastapi import APIRouter, HTTPException, UploadFile, File, Form
-from fastapi.responses import HTMLResponse
-from pydantic import BaseModel
-
-# Configuration
-UPLOAD_DIR = Path(os.getenv("UPLOAD_DIR", "/app/uploads"))
-CHUNK_DIR = Path(os.getenv("CHUNK_DIR", "/app/chunks"))
-EH_UPLOAD_DIR = Path(os.getenv("EH_UPLOAD_DIR", "/app/eh-uploads"))
-
-# Ensure directories exist
-UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
-CHUNK_DIR.mkdir(parents=True, exist_ok=True)
-EH_UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
-
-# In-memory storage for upload sessions (for simplicity)
-# In production, use Redis or database
-_upload_sessions: Dict[str, dict] = {}
-
-router = APIRouter(prefix="/api/v1/upload", tags=["Mobile Upload"])
-
-
-class InitUploadRequest(BaseModel):
- filename: str
- filesize: int
- chunks: int
- destination: str = "klausur" # "klausur" or "rag"
-
-
-class InitUploadResponse(BaseModel):
- upload_id: str
- chunk_size: int
- total_chunks: int
- message: str
-
-
-class ChunkUploadResponse(BaseModel):
- upload_id: str
- chunk_index: int
- received: bool
- chunks_received: int
- total_chunks: int
-
-
-class FinalizeResponse(BaseModel):
- upload_id: str
- filename: str
- filepath: str
- filesize: int
- checksum: str
- message: str
-
-
-@router.post("/init", response_model=InitUploadResponse)
-async def init_upload(request: InitUploadRequest):
- """
- Initialize a chunked upload session.
-
- Returns an upload_id that must be used for subsequent chunk uploads.
- """
- upload_id = str(uuid.uuid4())
-
- # Create session directory
- session_dir = CHUNK_DIR / upload_id
- session_dir.mkdir(parents=True, exist_ok=True)
-
- # Store session info
- _upload_sessions[upload_id] = {
- "filename": request.filename,
- "filesize": request.filesize,
- "total_chunks": request.chunks,
- "received_chunks": set(),
- "destination": request.destination,
- "session_dir": str(session_dir),
- "created_at": datetime.now(timezone.utc).isoformat(),
- }
-
- return InitUploadResponse(
- upload_id=upload_id,
- chunk_size=5 * 1024 * 1024, # 5 MB
- total_chunks=request.chunks,
- message="Upload-Session erstellt"
- )
-
-
-@router.post("/chunk", response_model=ChunkUploadResponse)
-async def upload_chunk(
- chunk: UploadFile = File(...),
- upload_id: str = Form(...),
- chunk_index: int = Form(...)
-):
- """
- Upload a single chunk of a file.
-
- Chunks are stored temporarily until finalize is called.
- """
- if upload_id not in _upload_sessions:
- raise HTTPException(status_code=404, detail="Upload-Session nicht gefunden")
-
- session = _upload_sessions[upload_id]
-
- if chunk_index < 0 or chunk_index >= session["total_chunks"]:
- raise HTTPException(
- status_code=400,
- detail=f"Ungueltiger Chunk-Index: {chunk_index}"
- )
-
- # Save chunk
- chunk_path = Path(session["session_dir"]) / f"chunk_{chunk_index:05d}"
-
- with open(chunk_path, "wb") as f:
- content = await chunk.read()
- f.write(content)
-
- # Track received chunks
- session["received_chunks"].add(chunk_index)
-
- return ChunkUploadResponse(
- upload_id=upload_id,
- chunk_index=chunk_index,
- received=True,
- chunks_received=len(session["received_chunks"]),
- total_chunks=session["total_chunks"]
- )
-
-
-@router.post("/finalize", response_model=FinalizeResponse)
-async def finalize_upload(upload_id: str = Form(...)):
- """
- Finalize the upload by combining all chunks into a single file.
-
- Validates that all chunks were received and calculates checksum.
- """
- if upload_id not in _upload_sessions:
- raise HTTPException(status_code=404, detail="Upload-Session nicht gefunden")
-
- session = _upload_sessions[upload_id]
-
- # Check if all chunks received
- if len(session["received_chunks"]) != session["total_chunks"]:
- missing = session["total_chunks"] - len(session["received_chunks"])
- raise HTTPException(
- status_code=400,
- detail=f"Nicht alle Chunks empfangen. Fehlend: {missing}"
- )
-
- # Determine destination directory
- if session["destination"] == "rag":
- dest_dir = EH_UPLOAD_DIR
- else:
- dest_dir = UPLOAD_DIR
-
- # Generate unique filename
- timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
- safe_filename = session["filename"].replace(" ", "_")
- final_filename = f"{timestamp}_{safe_filename}"
- final_path = dest_dir / final_filename
-
- # Combine chunks
- hasher = hashlib.sha256()
- total_size = 0
-
- with open(final_path, "wb") as outfile:
- for i in range(session["total_chunks"]):
- chunk_path = Path(session["session_dir"]) / f"chunk_{i:05d}"
-
- if not chunk_path.exists():
- raise HTTPException(
- status_code=500,
- detail=f"Chunk {i} nicht gefunden"
- )
-
- with open(chunk_path, "rb") as infile:
- data = infile.read()
- outfile.write(data)
- hasher.update(data)
- total_size += len(data)
-
- # Clean up chunks
- shutil.rmtree(session["session_dir"], ignore_errors=True)
- del _upload_sessions[upload_id]
-
- checksum = hasher.hexdigest()
-
- return FinalizeResponse(
- upload_id=upload_id,
- filename=final_filename,
- filepath=str(final_path),
- filesize=total_size,
- checksum=checksum,
- message="Upload erfolgreich abgeschlossen"
- )
-
-
-@router.post("/simple")
-async def simple_upload(
- file: UploadFile = File(...),
- destination: str = Form("klausur")
-):
- """
- Simple single-request upload for smaller files (<10MB).
-
- For larger files, use the chunked upload endpoints.
- """
- # Determine destination directory
- if destination == "rag":
- dest_dir = EH_UPLOAD_DIR
- else:
- dest_dir = UPLOAD_DIR
-
- # Generate unique filename
- timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
- safe_filename = file.filename.replace(" ", "_") if file.filename else "upload.pdf"
- final_filename = f"{timestamp}_{safe_filename}"
- final_path = dest_dir / final_filename
-
- # Calculate checksum while writing
- hasher = hashlib.sha256()
- total_size = 0
-
- with open(final_path, "wb") as f:
- while True:
- chunk = await file.read(1024 * 1024) # Read 1MB at a time
- if not chunk:
- break
- f.write(chunk)
- hasher.update(chunk)
- total_size += len(chunk)
-
- return {
- "filename": final_filename,
- "filepath": str(final_path),
- "filesize": total_size,
- "checksum": hasher.hexdigest(),
- "message": "Upload erfolgreich"
- }
-
-
-@router.get("/status/{upload_id}")
-async def get_upload_status(upload_id: str):
- """
- Get the status of an ongoing upload.
- """
- if upload_id not in _upload_sessions:
- raise HTTPException(status_code=404, detail="Upload-Session nicht gefunden")
-
- session = _upload_sessions[upload_id]
-
- return {
- "upload_id": upload_id,
- "filename": session["filename"],
- "total_chunks": session["total_chunks"],
- "received_chunks": len(session["received_chunks"]),
- "progress_percent": round(
- len(session["received_chunks"]) / session["total_chunks"] * 100, 1
- ),
- "destination": session["destination"],
- "created_at": session["created_at"]
- }
-
-
-@router.delete("/cancel/{upload_id}")
-async def cancel_upload(upload_id: str):
- """
- Cancel an ongoing upload and clean up temporary files.
- """
- if upload_id not in _upload_sessions:
- raise HTTPException(status_code=404, detail="Upload-Session nicht gefunden")
-
- session = _upload_sessions[upload_id]
-
- # Clean up chunks
- shutil.rmtree(session["session_dir"], ignore_errors=True)
- del _upload_sessions[upload_id]
-
- return {"message": "Upload abgebrochen", "upload_id": upload_id}
-
-
-@router.get("/list")
-async def list_uploads(destination: str = "klausur"):
- """
- List all uploaded files in the specified destination.
- """
- if destination == "rag":
- dest_dir = EH_UPLOAD_DIR
- else:
- dest_dir = UPLOAD_DIR
-
- files = []
-
- for f in dest_dir.iterdir():
- if f.is_file() and f.suffix.lower() == ".pdf":
- stat = f.stat()
- files.append({
- "filename": f.name,
- "size": stat.st_size,
- "modified": datetime.fromtimestamp(stat.st_mtime).isoformat(),
- })
-
- files.sort(key=lambda x: x["modified"], reverse=True)
-
- return {
- "destination": destination,
- "count": len(files),
- "files": files[:50] # Limit to 50 most recent
- }
-
-
-@router.get("/mobile", response_class=HTMLResponse)
-async def mobile_upload_page():
- """
- Serve the mobile upload page directly from the klausur-service.
- This allows mobile devices to upload without needing the Next.js website.
- """
- from fastapi.responses import HTMLResponse
-
- html_content = '''
-
-
-
-
-
-
BreakPilot Upload
-
-
-
-
-
-
- Klausuren
- Erwartungshorizonte
-
-
-
-
-
☁
-
PDF-Dateien hochladen
-
Tippen zum Auswaehlen oder hierher ziehen
-
Grosse Dateien bis 200 MB werden automatisch in Teilen hochgeladen
-
-
-
- 0 von 0 fertig
- 0 B gesamt
-
-
-
-
-
-
Hinweise:
-
- Die Dateien werden lokal im WLAN uebertragen
- Keine Daten werden ins Internet gesendet
- Unterstuetzte Formate: PDF
-
-
-
-
Server: wird ermittelt...
-
-
-
-'''
- return HTMLResponse(content=html_content)
+from fastapi import APIRouter
+
+from upload_api_chunked import ( # noqa: F401
+ router as _chunked_router,
+ UPLOAD_DIR,
+ CHUNK_DIR,
+ EH_UPLOAD_DIR,
+ _upload_sessions,
+ InitUploadRequest,
+ InitUploadResponse,
+ ChunkUploadResponse,
+ FinalizeResponse,
+)
+from upload_api_mobile import router as _mobile_router # noqa: F401
+
+# Composite router that includes both sub-routers.
+# Each sub-router is declared with its own "/api/v1/upload" prefix, so the
+# composite is created without a prefix and simply exposes both route sets
+# under the single name `router`.
+router = APIRouter()
+router.include_router(_chunked_router)
+router.include_router(_mobile_router)
diff --git a/klausur-service/backend/upload_api_chunked.py b/klausur-service/backend/upload_api_chunked.py
new file mode 100644
index 0000000..13ddfff
--- /dev/null
+++ b/klausur-service/backend/upload_api_chunked.py
@@ -0,0 +1,320 @@
+"""
+Chunked Upload API — init, chunk, finalize, simple upload, status, cancel, list.
+
+Extracted from upload_api.py for modularity.
+
+DSGVO-konform: Data stays local in WLAN, no external transmission.
+"""
+
+import os
+import uuid
+import shutil
+import hashlib
+from pathlib import Path
+from datetime import datetime, timezone
+from typing import Dict, Optional
+
+from fastapi import APIRouter, HTTPException, UploadFile, File, Form
+from pydantic import BaseModel
+
+# Configuration: each directory can be overridden through an environment
+# variable of the same name; the defaults point below /app.
+UPLOAD_DIR = Path(os.getenv("UPLOAD_DIR", "/app/uploads"))
+CHUNK_DIR = Path(os.getenv("CHUNK_DIR", "/app/chunks"))
+EH_UPLOAD_DIR = Path(os.getenv("EH_UPLOAD_DIR", "/app/eh-uploads"))
+
+# Ensure directories exist (runs at import time)
+UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
+CHUNK_DIR.mkdir(parents=True, exist_ok=True)
+EH_UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
+
+# In-memory storage for upload sessions (for simplicity), keyed by upload_id.
+# Being a plain module-level dict, it is empty again after every restart.
+# In production, use Redis or database
+_upload_sessions: Dict[str, dict] = {}
+
+# All endpoints defined in this module are mounted below /api/v1/upload.
+router = APIRouter(prefix="/api/v1/upload", tags=["Mobile Upload"])
+
+
+# Request body of POST /init: announces the file a client is about to send.
+# `chunks` becomes the session's total_chunks; `destination` "rag" selects
+# EH_UPLOAD_DIR at finalize time, any other value selects UPLOAD_DIR.
+class InitUploadRequest(BaseModel):
+ filename: str
+ filesize: int
+ chunks: int
+ destination: str = "klausur" # "klausur" or "rag"
+
+
+# Response of POST /init: the generated session id plus the chunk size in
+# bytes that init_upload reports back to the client.
+class InitUploadResponse(BaseModel):
+ upload_id: str
+ chunk_size: int
+ total_chunks: int
+ message: str
+
+
+# Response of POST /chunk: acknowledges one stored chunk and reports how many
+# distinct chunk indices the session has received so far.
+class ChunkUploadResponse(BaseModel):
+ upload_id: str
+ chunk_index: int
+ received: bool
+ chunks_received: int
+ total_chunks: int
+
+
+# Response of POST /finalize: describes the assembled file. `filesize` is the
+# number of bytes written and `checksum` the SHA-256 hex digest of those bytes.
+class FinalizeResponse(BaseModel):
+ upload_id: str
+ filename: str
+ filepath: str
+ filesize: int
+ checksum: str
+ message: str
+
+
+@router.post("/init", response_model=InitUploadResponse)
+async def init_upload(request: InitUploadRequest):
+ """
+ Initialize a chunked upload session.
+
+ Returns an upload_id that must be used for subsequent chunk uploads.
+ """
+ upload_id = str(uuid.uuid4())
+
+ # Create session directory
+ session_dir = CHUNK_DIR / upload_id
+ session_dir.mkdir(parents=True, exist_ok=True)
+
+ # Store session info
+ _upload_sessions[upload_id] = {
+ "filename": request.filename,
+ "filesize": request.filesize,
+ "total_chunks": request.chunks,
+ "received_chunks": set(),
+ "destination": request.destination,
+ "session_dir": str(session_dir),
+ "created_at": datetime.now(timezone.utc).isoformat(),
+ }
+
+ return InitUploadResponse(
+ upload_id=upload_id,
+ chunk_size=5 * 1024 * 1024, # 5 MB
+ total_chunks=request.chunks,
+ message="Upload-Session erstellt"
+ )
+
+
+@router.post("/chunk", response_model=ChunkUploadResponse)
+async def upload_chunk(
+ chunk: UploadFile = File(...),
+ upload_id: str = Form(...),
+ chunk_index: int = Form(...)
+):
+ """
+ Upload a single chunk of a file.
+
+ Chunks are stored temporarily until finalize is called.
+ """
+ if upload_id not in _upload_sessions:
+ raise HTTPException(status_code=404, detail="Upload-Session nicht gefunden")
+
+ session = _upload_sessions[upload_id]
+
+ if chunk_index < 0 or chunk_index >= session["total_chunks"]:
+ raise HTTPException(
+ status_code=400,
+ detail=f"Ungueltiger Chunk-Index: {chunk_index}"
+ )
+
+ # Save chunk
+ chunk_path = Path(session["session_dir"]) / f"chunk_{chunk_index:05d}"
+
+ with open(chunk_path, "wb") as f:
+ content = await chunk.read()
+ f.write(content)
+
+ # Track received chunks
+ session["received_chunks"].add(chunk_index)
+
+ return ChunkUploadResponse(
+ upload_id=upload_id,
+ chunk_index=chunk_index,
+ received=True,
+ chunks_received=len(session["received_chunks"]),
+ total_chunks=session["total_chunks"]
+ )
+
+
+@router.post("/finalize", response_model=FinalizeResponse)
+async def finalize_upload(upload_id: str = Form(...)):
+ """
+ Finalize the upload by combining all chunks into a single file.
+
+ Validates that all chunks were received and calculates checksum.
+ """
+ if upload_id not in _upload_sessions:
+ raise HTTPException(status_code=404, detail="Upload-Session nicht gefunden")
+
+ session = _upload_sessions[upload_id]
+
+ # Check if all chunks received
+ if len(session["received_chunks"]) != session["total_chunks"]:
+ missing = session["total_chunks"] - len(session["received_chunks"])
+ raise HTTPException(
+ status_code=400,
+ detail=f"Nicht alle Chunks empfangen. Fehlend: {missing}"
+ )
+
+ # Determine destination directory
+ if session["destination"] == "rag":
+ dest_dir = EH_UPLOAD_DIR
+ else:
+ dest_dir = UPLOAD_DIR
+
+ # Generate unique filename
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+ safe_filename = session["filename"].replace(" ", "_")
+ final_filename = f"{timestamp}_{safe_filename}"
+ final_path = dest_dir / final_filename
+
+ # Combine chunks
+ hasher = hashlib.sha256()
+ total_size = 0
+
+ with open(final_path, "wb") as outfile:
+ for i in range(session["total_chunks"]):
+ chunk_path = Path(session["session_dir"]) / f"chunk_{i:05d}"
+
+ if not chunk_path.exists():
+ raise HTTPException(
+ status_code=500,
+ detail=f"Chunk {i} nicht gefunden"
+ )
+
+ with open(chunk_path, "rb") as infile:
+ data = infile.read()
+ outfile.write(data)
+ hasher.update(data)
+ total_size += len(data)
+
+ # Clean up chunks
+ shutil.rmtree(session["session_dir"], ignore_errors=True)
+ del _upload_sessions[upload_id]
+
+ checksum = hasher.hexdigest()
+
+ return FinalizeResponse(
+ upload_id=upload_id,
+ filename=final_filename,
+ filepath=str(final_path),
+ filesize=total_size,
+ checksum=checksum,
+ message="Upload erfolgreich abgeschlossen"
+ )
+
+
+@router.post("/simple")
+async def simple_upload(
+ file: UploadFile = File(...),
+ destination: str = Form("klausur")
+):
+ """
+ Simple single-request upload for smaller files (<10MB).
+
+ For larger files, use the chunked upload endpoints.
+ """
+ # Determine destination directory
+ if destination == "rag":
+ dest_dir = EH_UPLOAD_DIR
+ else:
+ dest_dir = UPLOAD_DIR
+
+ # Generate unique filename
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+ safe_filename = file.filename.replace(" ", "_") if file.filename else "upload.pdf"
+ final_filename = f"{timestamp}_{safe_filename}"
+ final_path = dest_dir / final_filename
+
+ # Calculate checksum while writing
+ hasher = hashlib.sha256()
+ total_size = 0
+
+ with open(final_path, "wb") as f:
+ while True:
+ chunk = await file.read(1024 * 1024) # Read 1MB at a time
+ if not chunk:
+ break
+ f.write(chunk)
+ hasher.update(chunk)
+ total_size += len(chunk)
+
+ return {
+ "filename": final_filename,
+ "filepath": str(final_path),
+ "filesize": total_size,
+ "checksum": hasher.hexdigest(),
+ "message": "Upload erfolgreich"
+ }
+
+
+@router.get("/status/{upload_id}")
+async def get_upload_status(upload_id: str):
+ """
+ Get the status of an ongoing upload.
+ """
+ if upload_id not in _upload_sessions:
+ raise HTTPException(status_code=404, detail="Upload-Session nicht gefunden")
+
+ session = _upload_sessions[upload_id]
+
+ return {
+ "upload_id": upload_id,
+ "filename": session["filename"],
+ "total_chunks": session["total_chunks"],
+ "received_chunks": len(session["received_chunks"]),
+ "progress_percent": round(
+ len(session["received_chunks"]) / session["total_chunks"] * 100, 1
+ ),
+ "destination": session["destination"],
+ "created_at": session["created_at"]
+ }
+
+
+@router.delete("/cancel/{upload_id}")
+async def cancel_upload(upload_id: str):
+ """
+ Cancel an ongoing upload and clean up temporary files.
+ """
+ if upload_id not in _upload_sessions:
+ raise HTTPException(status_code=404, detail="Upload-Session nicht gefunden")
+
+ session = _upload_sessions[upload_id]
+
+ # Clean up chunks
+ shutil.rmtree(session["session_dir"], ignore_errors=True)
+ del _upload_sessions[upload_id]
+
+ return {"message": "Upload abgebrochen", "upload_id": upload_id}
+
+
+@router.get("/list")
+async def list_uploads(destination: str = "klausur"):
+ """
+ List all uploaded files in the specified destination.
+ """
+ if destination == "rag":
+ dest_dir = EH_UPLOAD_DIR
+ else:
+ dest_dir = UPLOAD_DIR
+
+ files = []
+
+ for f in dest_dir.iterdir():
+ if f.is_file() and f.suffix.lower() == ".pdf":
+ stat = f.stat()
+ files.append({
+ "filename": f.name,
+ "size": stat.st_size,
+ "modified": datetime.fromtimestamp(stat.st_mtime).isoformat(),
+ })
+
+ files.sort(key=lambda x: x["modified"], reverse=True)
+
+ return {
+ "destination": destination,
+ "count": len(files),
+ "files": files[:50] # Limit to 50 most recent
+ }
diff --git a/klausur-service/backend/upload_api_mobile.py b/klausur-service/backend/upload_api_mobile.py
new file mode 100644
index 0000000..8ddd423
--- /dev/null
+++ b/klausur-service/backend/upload_api_mobile.py
@@ -0,0 +1,292 @@
+"""
+Mobile Upload HTML Page — serves the mobile upload UI directly from klausur-service.
+
+Extracted from upload_api.py for modularity.
+
+DSGVO-konform: Data stays local in WLAN, no external transmission.
+"""
+
+from fastapi import APIRouter
+from fastapi.responses import HTMLResponse
+
+router = APIRouter(prefix="/api/v1/upload", tags=["Mobile Upload"])
+
+
+@router.get("/mobile", response_class=HTMLResponse)
+async def mobile_upload_page():
+ """
+ Serve the mobile upload page directly from the klausur-service.
+ This allows mobile devices to upload without needing the Next.js website.
+ """
+ html_content = '''
+
+
+
+
+
+
BreakPilot Upload
+
+
+
+
+
+
+ Klausuren
+ Erwartungshorizonte
+
+
+
+
+
☁
+
PDF-Dateien hochladen
+
Tippen zum Auswaehlen oder hierher ziehen
+
Grosse Dateien bis 200 MB werden automatisch in Teilen hochgeladen
+
+
+
+ 0 von 0 fertig
+ 0 B gesamt
+
+
+
+
+
+
Hinweise:
+
+ Die Dateien werden lokal im WLAN uebertragen
+ Keine Daten werden ins Internet gesendet
+ Unterstuetzte Formate: PDF
+
+
+
+
Server: wird ermittelt...
+
+
+
+'''
+ return HTMLResponse(content=html_content)
diff --git a/klausur-service/backend/zeugnis_api.py b/klausur-service/backend/zeugnis_api.py
index 4d1618d..53e2ca2 100644
--- a/klausur-service/backend/zeugnis_api.py
+++ b/klausur-service/backend/zeugnis_api.py
@@ -1,537 +1,19 @@
"""
-Zeugnis Rights-Aware Crawler - API Endpoints
+Zeugnis Rights-Aware Crawler — barrel re-export.
+
+All implementation split into:
+ zeugnis_api_sources — sources, seed URLs, initialization
+ zeugnis_api_docs — documents, crawler, statistics, audit
FastAPI router for managing zeugnis sources, documents, and crawler operations.
"""
-from datetime import datetime, timedelta
-from typing import Optional, List
-from fastapi import APIRouter, HTTPException, BackgroundTasks, Query
-from pydantic import BaseModel
+from fastapi import APIRouter
-from zeugnis_models import (
- ZeugnisSource, ZeugnisSourceCreate, ZeugnisSourceVerify,
- SeedUrl, SeedUrlCreate,
- ZeugnisDocument, ZeugnisStats,
- CrawlerStatus, CrawlRequest, CrawlQueueItem,
- UsageEvent, AuditExport,
- LicenseType, CrawlStatus, DocType, EventType,
- BUNDESLAENDER, TRAINING_PERMISSIONS,
- generate_id, get_training_allowed, get_bundesland_name, get_license_for_bundesland,
-)
-from zeugnis_crawler import (
- start_crawler, stop_crawler, get_crawler_status,
-)
-from metrics_db import (
- get_zeugnis_sources, upsert_zeugnis_source,
- get_zeugnis_documents, get_zeugnis_stats,
- log_zeugnis_event, get_pool,
-)
+from zeugnis_api_sources import router as _sources_router # noqa: F401
+from zeugnis_api_docs import router as _docs_router # noqa: F401
-
-router = APIRouter(prefix="/api/v1/admin/zeugnis", tags=["Zeugnis Crawler"])
-
-
-# =============================================================================
-# Sources Endpoints
-# =============================================================================
-
-@router.get("/sources", response_model=List[dict])
-async def list_sources():
- """Get all zeugnis sources (Bundesländer)."""
- sources = await get_zeugnis_sources()
- if not sources:
- # Return default sources if none exist
- return [
- {
- "id": None,
- "bundesland": code,
- "name": info["name"],
- "base_url": None,
- "license_type": str(get_license_for_bundesland(code).value),
- "training_allowed": get_training_allowed(code),
- "verified_by": None,
- "verified_at": None,
- "created_at": None,
- "updated_at": None,
- }
- for code, info in BUNDESLAENDER.items()
- ]
- return sources
-
-
-@router.post("/sources", response_model=dict)
-async def create_source(source: ZeugnisSourceCreate):
- """Create or update a zeugnis source."""
- source_id = generate_id()
- success = await upsert_zeugnis_source(
- id=source_id,
- bundesland=source.bundesland,
- name=source.name,
- license_type=source.license_type.value,
- training_allowed=source.training_allowed,
- base_url=source.base_url,
- )
- if not success:
- raise HTTPException(status_code=500, detail="Failed to create source")
- return {"id": source_id, "success": True}
-
-
-@router.put("/sources/{source_id}/verify", response_model=dict)
-async def verify_source(source_id: str, verification: ZeugnisSourceVerify):
- """Verify a source's license status."""
- pool = await get_pool()
- if not pool:
- raise HTTPException(status_code=503, detail="Database not available")
-
- try:
- async with pool.acquire() as conn:
- await conn.execute(
- """
- UPDATE zeugnis_sources
- SET license_type = $2,
- training_allowed = $3,
- verified_by = $4,
- verified_at = NOW(),
- updated_at = NOW()
- WHERE id = $1
- """,
- source_id, verification.license_type.value,
- verification.training_allowed, verification.verified_by
- )
- return {"success": True, "source_id": source_id}
- except Exception as e:
- raise HTTPException(status_code=500, detail=str(e))
-
-
-@router.get("/sources/{bundesland}", response_model=dict)
-async def get_source_by_bundesland(bundesland: str):
- """Get source details for a specific Bundesland."""
- pool = await get_pool()
- if not pool:
- # Return default info
- if bundesland not in BUNDESLAENDER:
- raise HTTPException(status_code=404, detail=f"Bundesland not found: {bundesland}")
- return {
- "bundesland": bundesland,
- "name": get_bundesland_name(bundesland),
- "training_allowed": get_training_allowed(bundesland),
- "license_type": get_license_for_bundesland(bundesland).value,
- "document_count": 0,
- }
-
- try:
- async with pool.acquire() as conn:
- source = await conn.fetchrow(
- "SELECT * FROM zeugnis_sources WHERE bundesland = $1",
- bundesland
- )
- if source:
- doc_count = await conn.fetchval(
- """
- SELECT COUNT(*) FROM zeugnis_documents d
- JOIN zeugnis_seed_urls u ON d.seed_url_id = u.id
- WHERE u.source_id = $1
- """,
- source["id"]
- )
- return {**dict(source), "document_count": doc_count or 0}
-
- # Return default
- return {
- "bundesland": bundesland,
- "name": get_bundesland_name(bundesland),
- "training_allowed": get_training_allowed(bundesland),
- "license_type": get_license_for_bundesland(bundesland).value,
- "document_count": 0,
- }
- except Exception as e:
- raise HTTPException(status_code=500, detail=str(e))
-
-
-# =============================================================================
-# Seed URLs Endpoints
-# =============================================================================
-
-@router.get("/sources/{source_id}/urls", response_model=List[dict])
-async def list_seed_urls(source_id: str):
- """Get all seed URLs for a source."""
- pool = await get_pool()
- if not pool:
- return []
-
- try:
- async with pool.acquire() as conn:
- rows = await conn.fetch(
- "SELECT * FROM zeugnis_seed_urls WHERE source_id = $1 ORDER BY created_at",
- source_id
- )
- return [dict(r) for r in rows]
- except Exception as e:
- raise HTTPException(status_code=500, detail=str(e))
-
-
-@router.post("/sources/{source_id}/urls", response_model=dict)
-async def add_seed_url(source_id: str, seed_url: SeedUrlCreate):
- """Add a new seed URL to a source."""
- pool = await get_pool()
- if not pool:
- raise HTTPException(status_code=503, detail="Database not available")
-
- url_id = generate_id()
- try:
- async with pool.acquire() as conn:
- await conn.execute(
- """
- INSERT INTO zeugnis_seed_urls (id, source_id, url, doc_type, status)
- VALUES ($1, $2, $3, $4, 'pending')
- """,
- url_id, source_id, seed_url.url, seed_url.doc_type.value
- )
- return {"id": url_id, "success": True}
- except Exception as e:
- raise HTTPException(status_code=500, detail=str(e))
-
-
-@router.delete("/urls/{url_id}", response_model=dict)
-async def delete_seed_url(url_id: str):
- """Delete a seed URL."""
- pool = await get_pool()
- if not pool:
- raise HTTPException(status_code=503, detail="Database not available")
-
- try:
- async with pool.acquire() as conn:
- await conn.execute(
- "DELETE FROM zeugnis_seed_urls WHERE id = $1",
- url_id
- )
- return {"success": True}
- except Exception as e:
- raise HTTPException(status_code=500, detail=str(e))
-
-
-# =============================================================================
-# Documents Endpoints
-# =============================================================================
-
-@router.get("/documents", response_model=List[dict])
-async def list_documents(
- bundesland: Optional[str] = None,
- limit: int = Query(100, le=500),
- offset: int = 0,
-):
- """Get all zeugnis documents with optional filtering."""
- documents = await get_zeugnis_documents(bundesland=bundesland, limit=limit, offset=offset)
- return documents
-
-
-@router.get("/documents/{document_id}", response_model=dict)
-async def get_document(document_id: str):
- """Get details for a specific document."""
- pool = await get_pool()
- if not pool:
- raise HTTPException(status_code=503, detail="Database not available")
-
- try:
- async with pool.acquire() as conn:
- doc = await conn.fetchrow(
- """
- SELECT d.*, s.bundesland, s.name as source_name
- FROM zeugnis_documents d
- JOIN zeugnis_seed_urls u ON d.seed_url_id = u.id
- JOIN zeugnis_sources s ON u.source_id = s.id
- WHERE d.id = $1
- """,
- document_id
- )
- if not doc:
- raise HTTPException(status_code=404, detail="Document not found")
-
- # Log view event
- await log_zeugnis_event(document_id, EventType.VIEWED.value)
-
- return dict(doc)
- except HTTPException:
- raise
- except Exception as e:
- raise HTTPException(status_code=500, detail=str(e))
-
-
-@router.get("/documents/{document_id}/versions", response_model=List[dict])
-async def get_document_versions(document_id: str):
- """Get version history for a document."""
- pool = await get_pool()
- if not pool:
- raise HTTPException(status_code=503, detail="Database not available")
-
- try:
- async with pool.acquire() as conn:
- rows = await conn.fetch(
- """
- SELECT * FROM zeugnis_document_versions
- WHERE document_id = $1
- ORDER BY version DESC
- """,
- document_id
- )
- return [dict(r) for r in rows]
- except Exception as e:
- raise HTTPException(status_code=500, detail=str(e))
-
-
-# =============================================================================
-# Crawler Control Endpoints
-# =============================================================================
-
-@router.get("/crawler/status", response_model=dict)
-async def crawler_status():
- """Get current crawler status."""
- return get_crawler_status()
-
-
-@router.post("/crawler/start", response_model=dict)
-async def start_crawl(request: CrawlRequest, background_tasks: BackgroundTasks):
- """Start the crawler."""
- success = await start_crawler(
- bundesland=request.bundesland,
- source_id=request.source_id,
- )
- if not success:
- raise HTTPException(status_code=409, detail="Crawler already running")
- return {"success": True, "message": "Crawler started"}
-
-
-@router.post("/crawler/stop", response_model=dict)
-async def stop_crawl():
- """Stop the crawler."""
- success = await stop_crawler()
- if not success:
- raise HTTPException(status_code=409, detail="Crawler not running")
- return {"success": True, "message": "Crawler stopped"}
-
-
-@router.get("/crawler/queue", response_model=List[dict])
-async def get_queue():
- """Get the crawler queue."""
- pool = await get_pool()
- if not pool:
- return []
-
- try:
- async with pool.acquire() as conn:
- rows = await conn.fetch(
- """
- SELECT q.*, s.bundesland, s.name as source_name
- FROM zeugnis_crawler_queue q
- JOIN zeugnis_sources s ON q.source_id = s.id
- ORDER BY q.priority DESC, q.created_at
- """
- )
- return [dict(r) for r in rows]
- except Exception as e:
- raise HTTPException(status_code=500, detail=str(e))
-
-
-@router.post("/crawler/queue", response_model=dict)
-async def add_to_queue(request: CrawlRequest):
- """Add a source to the crawler queue."""
- pool = await get_pool()
- if not pool:
- raise HTTPException(status_code=503, detail="Database not available")
-
- queue_id = generate_id()
- try:
- async with pool.acquire() as conn:
- # Get source ID if bundesland provided
- source_id = request.source_id
- if not source_id and request.bundesland:
- source = await conn.fetchrow(
- "SELECT id FROM zeugnis_sources WHERE bundesland = $1",
- request.bundesland
- )
- if source:
- source_id = source["id"]
-
- if not source_id:
- raise HTTPException(status_code=400, detail="Source not found")
-
- await conn.execute(
- """
- INSERT INTO zeugnis_crawler_queue (id, source_id, priority, status)
- VALUES ($1, $2, $3, 'pending')
- """,
- queue_id, source_id, request.priority
- )
- return {"id": queue_id, "success": True}
- except HTTPException:
- raise
- except Exception as e:
- raise HTTPException(status_code=500, detail=str(e))
-
-
-# =============================================================================
-# Statistics Endpoints
-# =============================================================================
-
-@router.get("/stats", response_model=dict)
-async def get_stats():
- """Get zeugnis crawler statistics."""
- stats = await get_zeugnis_stats()
- return stats
-
-
-@router.get("/stats/bundesland", response_model=List[dict])
-async def get_bundesland_stats():
- """Get statistics per Bundesland."""
- pool = await get_pool()
-
- # Build stats from BUNDESLAENDER with DB data if available
- stats = []
- for code, info in BUNDESLAENDER.items():
- stat = {
- "bundesland": code,
- "name": info["name"],
- "training_allowed": get_training_allowed(code),
- "document_count": 0,
- "indexed_count": 0,
- "last_crawled": None,
- }
-
- if pool:
- try:
- async with pool.acquire() as conn:
- row = await conn.fetchrow(
- """
- SELECT
- COUNT(d.id) as doc_count,
- COUNT(CASE WHEN d.indexed_in_qdrant THEN 1 END) as indexed_count,
- MAX(u.last_crawled) as last_crawled
- FROM zeugnis_sources s
- LEFT JOIN zeugnis_seed_urls u ON s.id = u.source_id
- LEFT JOIN zeugnis_documents d ON u.id = d.seed_url_id
- WHERE s.bundesland = $1
- GROUP BY s.id
- """,
- code
- )
- if row:
- stat["document_count"] = row["doc_count"] or 0
- stat["indexed_count"] = row["indexed_count"] or 0
- stat["last_crawled"] = row["last_crawled"].isoformat() if row["last_crawled"] else None
- except Exception:
- pass
-
- stats.append(stat)
-
- return stats
-
-
-# =============================================================================
-# Audit Endpoints
-# =============================================================================
-
-@router.get("/audit/events", response_model=List[dict])
-async def get_audit_events(
- document_id: Optional[str] = None,
- event_type: Optional[str] = None,
- limit: int = Query(100, le=1000),
- days: int = Query(30, le=365),
-):
- """Get audit events with optional filtering."""
- pool = await get_pool()
- if not pool:
- return []
-
- try:
- since = datetime.now() - timedelta(days=days)
- async with pool.acquire() as conn:
- query = """
- SELECT * FROM zeugnis_usage_events
- WHERE created_at >= $1
- """
- params = [since]
-
- if document_id:
- query += " AND document_id = $2"
- params.append(document_id)
- if event_type:
- query += f" AND event_type = ${len(params) + 1}"
- params.append(event_type)
-
- query += f" ORDER BY created_at DESC LIMIT ${len(params) + 1}"
- params.append(limit)
-
- rows = await conn.fetch(query, *params)
- return [dict(r) for r in rows]
- except Exception as e:
- raise HTTPException(status_code=500, detail=str(e))
-
-
-@router.get("/audit/export", response_model=dict)
-async def export_audit(
- days: int = Query(30, le=365),
- requested_by: str = Query(..., description="User requesting the export"),
-):
- """Export audit data for GDPR compliance."""
- pool = await get_pool()
- if not pool:
- raise HTTPException(status_code=503, detail="Database not available")
-
- try:
- since = datetime.now() - timedelta(days=days)
- async with pool.acquire() as conn:
- rows = await conn.fetch(
- """
- SELECT * FROM zeugnis_usage_events
- WHERE created_at >= $1
- ORDER BY created_at DESC
- """,
- since
- )
-
- doc_count = await conn.fetchval(
- "SELECT COUNT(DISTINCT document_id) FROM zeugnis_usage_events WHERE created_at >= $1",
- since
- )
-
- return {
- "export_date": datetime.now().isoformat(),
- "requested_by": requested_by,
- "events": [dict(r) for r in rows],
- "document_count": doc_count or 0,
- "date_range_start": since.isoformat(),
- "date_range_end": datetime.now().isoformat(),
- }
- except Exception as e:
- raise HTTPException(status_code=500, detail=str(e))
-
-
-# =============================================================================
-# Initialization Endpoint
-# =============================================================================
-
-@router.post("/init", response_model=dict)
-async def initialize_sources():
- """Initialize default sources from BUNDESLAENDER."""
- pool = await get_pool()
- if not pool:
- raise HTTPException(status_code=503, detail="Database not available")
-
- created = 0
- try:
- for code, info in BUNDESLAENDER.items():
- source_id = generate_id()
- success = await upsert_zeugnis_source(
- id=source_id,
- bundesland=code,
- name=info["name"],
- license_type=get_license_for_bundesland(code).value,
- training_allowed=get_training_allowed(code),
- )
- if success:
- created += 1
-
- return {"success": True, "sources_created": created}
- except Exception as e:
- raise HTTPException(status_code=500, detail=str(e))
+# Composite router (used by main.py); it adds no prefix of its own and only merges the two sub-routers
+router = APIRouter()
+router.include_router(_sources_router)
+router.include_router(_docs_router)
diff --git a/klausur-service/backend/zeugnis_api_docs.py b/klausur-service/backend/zeugnis_api_docs.py
new file mode 100644
index 0000000..0800380
--- /dev/null
+++ b/klausur-service/backend/zeugnis_api_docs.py
@@ -0,0 +1,321 @@
+"""
+Zeugnis API Docs — documents, crawler control, statistics, audit endpoints.
+
+Extracted from zeugnis_api.py for modularity.
+"""
+
+from datetime import datetime, timedelta
+from typing import Optional, List
+from fastapi import APIRouter, HTTPException, BackgroundTasks, Query
+
+from zeugnis_models import (
+ CrawlRequest, EventType,
+ BUNDESLAENDER,
+ generate_id, get_training_allowed, get_license_for_bundesland,
+)
+from zeugnis_crawler import (
+ start_crawler, stop_crawler, get_crawler_status,
+)
+from metrics_db import (
+ get_zeugnis_documents, get_zeugnis_stats,
+ log_zeugnis_event, get_pool,
+)
+
+
+router = APIRouter(prefix="/api/v1/admin/zeugnis", tags=["Zeugnis Crawler"])
+
+
+# =============================================================================
+# Documents Endpoints
+# =============================================================================
+
+@router.get("/documents", response_model=List[dict])
+async def list_documents(
+    bundesland: Optional[str] = None,
+    limit: int = Query(100, ge=1, le=500),
+    offset: int = Query(0, ge=0),
+):
+    """List zeugnis documents, optionally for one Bundesland; paging must be limit 1-500, offset >= 0."""
+    # Out-of-range paging is rejected by request validation before the query helper runs.
+    return await get_zeugnis_documents(bundesland=bundesland, limit=limit, offset=offset)
+
+
+@router.get("/documents/{document_id}", response_model=dict)
+async def get_document(document_id: str):
+    """Fetch one document with its source's Bundesland and name; 503 without DB, 404 if unknown."""
+    pool = await get_pool()
+    if not pool:
+        raise HTTPException(status_code=503, detail="Database not available")
+
+    try:
+        async with pool.acquire() as conn:
+            doc = await conn.fetchrow(
+                """
+                SELECT d.*, s.bundesland, s.name as source_name
+                FROM zeugnis_documents d
+                JOIN zeugnis_seed_urls u ON d.seed_url_id = u.id
+                JOIN zeugnis_sources s ON u.source_id = s.id
+                WHERE d.id = $1
+                """,
+                document_id
+            )
+            if not doc:
+                raise HTTPException(status_code=404, detail="Document not found")
+
+            # Record a VIEWED event; only reached when the document exists
+            await log_zeugnis_event(document_id, EventType.VIEWED.value)
+
+            return dict(doc)
+    except HTTPException:
+        raise  # let the 404 above through instead of rewrapping it as a 500
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=str(e))
+
+
+@router.get("/documents/{document_id}/versions", response_model=List[dict])
+async def get_document_versions(document_id: str):
+    """Return a document's version rows, highest version first; no 404 check is made for unknown ids."""
+    pool = await get_pool()
+    if not pool:
+        raise HTTPException(status_code=503, detail="Database not available")
+
+    try:
+        async with pool.acquire() as conn:
+            rows = await conn.fetch(
+                """
+                SELECT * FROM zeugnis_document_versions
+                WHERE document_id = $1
+                ORDER BY version DESC
+                """,
+                document_id
+            )
+            return [dict(r) for r in rows]
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=str(e))  # any DB error surfaces as a 500 with its message
+
+
+# =============================================================================
+# Crawler Control Endpoints
+# =============================================================================
+
+@router.get("/crawler/status", response_model=dict)
+async def crawler_status():
+    """Return the result of get_crawler_status() unchanged; the helper is called without await."""
+    return get_crawler_status()
+
+
+@router.post("/crawler/start", response_model=dict)
+async def start_crawl(request: CrawlRequest, background_tasks: BackgroundTasks):
+    """Start the crawler for the requested Bundesland/source; 409 when start_crawler() returns falsy."""
+    success = await start_crawler(  # NOTE(review): background_tasks is accepted but never used here
+        bundesland=request.bundesland,
+        source_id=request.source_id,
+    )
+    if not success:
+        raise HTTPException(status_code=409, detail="Crawler already running")
+    return {"success": True, "message": "Crawler started"}
+
+
+@router.post("/crawler/stop", response_model=dict)
+async def stop_crawl():
+    """Stop the crawler; responds 409 when stop_crawler() returns a falsy result."""
+    success = await stop_crawler()
+    if not success:
+        raise HTTPException(status_code=409, detail="Crawler not running")
+    return {"success": True, "message": "Crawler stopped"}
+
+
+@router.get("/crawler/queue", response_model=List[dict])
+async def get_queue():
+    """List queue entries with their source's Bundesland and name, highest priority first."""
+    pool = await get_pool()
+    if not pool:
+        return []  # no database: report an empty queue instead of failing
+
+    try:
+        async with pool.acquire() as conn:
+            rows = await conn.fetch(
+                """
+                SELECT q.*, s.bundesland, s.name as source_name
+                FROM zeugnis_crawler_queue q
+                JOIN zeugnis_sources s ON q.source_id = s.id
+                ORDER BY q.priority DESC, q.created_at
+                """
+            )
+            return [dict(r) for r in rows]
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=str(e))
+
+
+@router.post("/crawler/queue", response_model=dict)
+async def add_to_queue(request: CrawlRequest):
+    """Insert a 'pending' queue row for a source given by source_id or looked up by Bundesland."""
+    pool = await get_pool()
+    if not pool:
+        raise HTTPException(status_code=503, detail="Database not available")
+
+    queue_id = generate_id()
+    try:
+        async with pool.acquire() as conn:
+            # Prefer the explicit source_id; otherwise resolve it from the Bundesland code
+            source_id = request.source_id
+            if not source_id and request.bundesland:
+                source = await conn.fetchrow(
+                    "SELECT id FROM zeugnis_sources WHERE bundesland = $1",
+                    request.bundesland
+                )
+                if source:
+                    source_id = source["id"]
+
+            if not source_id:  # neither an explicit id nor a matching Bundesland row
+                raise HTTPException(status_code=400, detail="Source not found")
+
+            await conn.execute(
+                """
+                INSERT INTO zeugnis_crawler_queue (id, source_id, priority, status)
+                VALUES ($1, $2, $3, 'pending')
+                """,
+                queue_id, source_id, request.priority
+            )
+            return {"id": queue_id, "success": True}
+    except HTTPException:
+        raise  # keep the 400 above from being rewrapped as a 500
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=str(e))
+
+
+# =============================================================================
+# Statistics Endpoints
+# =============================================================================
+
+@router.get("/stats", response_model=dict)
+async def get_stats():
+    """Return the statistics produced by get_zeugnis_stats() without modification."""
+    stats = await get_zeugnis_stats()
+    return stats
+
+
+@router.get("/stats/bundesland", response_model=List[dict])
+async def get_bundesland_stats():
+    """Per-Bundesland document stats; a failed DB lookup is logged and that entry keeps its zero defaults."""
+    import logging  # local import: this module has no file-level logging import
+    pool = await get_pool()
+    # One entry per BUNDESLAENDER code, enriched with DB data when a pool is available
+    stats = []
+    for code, info in BUNDESLAENDER.items():
+        stat = {
+            "bundesland": code,
+            "name": info["name"],
+            "training_allowed": get_training_allowed(code),
+            "document_count": 0,
+            "indexed_count": 0,
+            "last_crawled": None,
+        }
+
+        if pool:
+            try:
+                async with pool.acquire() as conn:
+                    row = await conn.fetchrow(
+                        """
+                        SELECT
+                            COUNT(d.id) as doc_count,
+                            COUNT(CASE WHEN d.indexed_in_qdrant THEN 1 END) as indexed_count,
+                            MAX(u.last_crawled) as last_crawled
+                        FROM zeugnis_sources s
+                        LEFT JOIN zeugnis_seed_urls u ON s.id = u.source_id
+                        LEFT JOIN zeugnis_documents d ON u.id = d.seed_url_id
+                        WHERE s.bundesland = $1
+                        GROUP BY s.id
+                        """,
+                        code
+                    )
+                    if row:
+                        stat["document_count"] = row["doc_count"] or 0
+                        stat["indexed_count"] = row["indexed_count"] or 0
+                        stat["last_crawled"] = row["last_crawled"].isoformat() if row["last_crawled"] else None
+            except Exception:
+                # Still best-effort, but no longer silent: the failure is visible in the logs.
+                logging.getLogger(__name__).warning("Bundesland stats query failed for %s", code, exc_info=True)
+        stats.append(stat)
+
+    return stats
+
+
+# =============================================================================
+# Audit Endpoints
+# =============================================================================
+
+@router.get("/audit/events", response_model=List[dict])
+async def get_audit_events(
+    document_id: Optional[str] = None,
+    event_type: Optional[str] = None,
+    limit: int = Query(100, ge=1, le=1000),
+    days: int = Query(30, ge=1, le=365),
+):
+    """Return usage events of the last `days` days, newest first; [] when no database is available."""
+    pool = await get_pool()
+    if not pool:
+        return []
+
+    try:
+        since = datetime.now() - timedelta(days=days)
+        async with pool.acquire() as conn:
+            query = """
+                SELECT * FROM zeugnis_usage_events
+                WHERE created_at >= $1
+            """
+            params = [since]
+
+            if document_id:  # every optional filter takes the next free placeholder index
+                query += f" AND document_id = ${len(params) + 1}"
+                params.append(document_id)
+            if event_type:
+                query += f" AND event_type = ${len(params) + 1}"
+                params.append(event_type)
+
+            query += f" ORDER BY created_at DESC LIMIT ${len(params) + 1}"
+            params.append(limit)  # limit is validated as 1..1000, so LIMIT never goes negative
+
+            rows = await conn.fetch(query, *params)
+            return [dict(r) for r in rows]
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=str(e))
+
+
+@router.get("/audit/export", response_model=dict)
+async def export_audit(
+    days: int = Query(30, ge=1, le=365),
+    requested_by: str = Query(..., description="User requesting the export"),
+):
+    """Export usage events of the last `days` days (audit/GDPR export) with one shared end timestamp."""
+    pool = await get_pool()
+    if not pool:
+        raise HTTPException(status_code=503, detail="Database not available")
+
+    try:
+        since = datetime.now() - timedelta(days=days)
+        async with pool.acquire() as conn:
+            rows = await conn.fetch(
+                """
+                SELECT * FROM zeugnis_usage_events
+                WHERE created_at >= $1
+                ORDER BY created_at DESC
+                """,
+                since
+            )
+            doc_count = await conn.fetchval(
+                "SELECT COUNT(DISTINCT document_id) FROM zeugnis_usage_events WHERE created_at >= $1",
+                since
+            )
+            # Taken once, after the queries, so export_date and date_range_end always agree.
+            exported_at = datetime.now().isoformat()
+            return {
+                "export_date": exported_at,
+                "requested_by": requested_by,
+                "events": [dict(r) for r in rows],
+                "document_count": doc_count or 0,
+                "date_range_start": since.isoformat(),
+                "date_range_end": exported_at,
+            }
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=str(e))
diff --git a/klausur-service/backend/zeugnis_api_sources.py b/klausur-service/backend/zeugnis_api_sources.py
new file mode 100644
index 0000000..3eecf28
--- /dev/null
+++ b/klausur-service/backend/zeugnis_api_sources.py
@@ -0,0 +1,232 @@
+"""
+Zeugnis API Sources — source and seed URL management endpoints.
+
+Extracted from zeugnis_api.py for modularity.
+"""
+
+from typing import Optional, List
+from fastapi import APIRouter, HTTPException
+from pydantic import BaseModel
+
+from zeugnis_models import (
+ ZeugnisSourceCreate, ZeugnisSourceVerify,
+ SeedUrlCreate,
+ LicenseType, DocType,
+ BUNDESLAENDER,
+ generate_id, get_training_allowed, get_bundesland_name, get_license_for_bundesland,
+)
+from metrics_db import (
+ get_zeugnis_sources, upsert_zeugnis_source, get_pool,
+)
+
+
+router = APIRouter(prefix="/api/v1/admin/zeugnis", tags=["Zeugnis Crawler"])
+
+
+# =============================================================================
+# Sources Endpoints
+# =============================================================================
+
+@router.get("/sources", response_model=List[dict])
+async def list_sources():
+    """Return stored zeugnis sources, or one placeholder entry per BUNDESLAENDER code if none exist."""
+    sources = await get_zeugnis_sources()
+    if not sources:
+        # Nothing stored yet: synthesize defaults; fields that only exist for stored rows stay None
+        return [
+            {
+                "id": None,
+                "bundesland": code,
+                "name": info["name"],
+                "base_url": None,
+                "license_type": str(get_license_for_bundesland(code).value),
+                "training_allowed": get_training_allowed(code),
+                "verified_by": None,
+                "verified_at": None,
+                "created_at": None,
+                "updated_at": None,
+            }
+            for code, info in BUNDESLAENDER.items()
+        ]
+    return sources
+
+
+@router.post("/sources", response_model=dict)
+async def create_source(source: ZeugnisSourceCreate):
+    """Upsert a zeugnis source under a freshly generated id; 500 when the upsert reports failure."""
+    source_id = generate_id()  # NOTE(review): new id on every call - confirm how the upsert matches existing rows
+    success = await upsert_zeugnis_source(
+        id=source_id,
+        bundesland=source.bundesland,
+        name=source.name,
+        license_type=source.license_type.value,
+        training_allowed=source.training_allowed,
+        base_url=source.base_url,
+    )
+    if not success:
+        raise HTTPException(status_code=500, detail="Failed to create source")
+    return {"id": source_id, "success": True}
+
+
+@router.put("/sources/{source_id}/verify", response_model=dict)
+async def verify_source(source_id: str, verification: ZeugnisSourceVerify):
+    """Store a source's verified license fields; 503 without DB, 404 if the id matches no row."""
+    pool = await get_pool()
+    if not pool:
+        raise HTTPException(status_code=503, detail="Database not available")
+    try:
+        async with pool.acquire() as conn:
+            updated_id = await conn.fetchval(
+                """
+                UPDATE zeugnis_sources
+                SET license_type = $2,
+                    training_allowed = $3,
+                    verified_by = $4,
+                    verified_at = NOW(), updated_at = NOW()
+                WHERE id = $1
+                RETURNING id
+                """,
+                source_id, verification.license_type.value, verification.training_allowed, verification.verified_by
+            )
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=str(e))
+    if updated_id is None:  # RETURNING produced no row, so nothing was updated
+        raise HTTPException(status_code=404, detail="Source not found")
+    return {"success": True, "source_id": source_id}
+
+
+@router.get("/sources/{bundesland}", response_model=dict)
+async def get_source_by_bundesland(bundesland: str):
+    """
+    Get source details for a specific Bundesland.
+
+    Returns the stored zeugnis_sources row plus its document count when a
+    row exists for the code. Without a row (or without a database) the
+    built-in defaults are returned for codes listed in BUNDESLAENDER, and
+    any other code yields 404 -- the same answer in both situations.
+    """
+    pool = await get_pool()
+    # Prefer the stored row whenever the database is reachable.
+    if pool:
+        try:
+            async with pool.acquire() as conn:
+                source = await conn.fetchrow(
+                    "SELECT * FROM zeugnis_sources WHERE bundesland = $1",
+                    bundesland
+                )
+                if source:
+                    doc_count = await conn.fetchval(
+                        """
+                        SELECT COUNT(*) FROM zeugnis_documents d
+                        JOIN zeugnis_seed_urls u ON d.seed_url_id = u.id
+                        WHERE u.source_id = $1
+                        """,
+                        source["id"]
+                    )
+                    return {**dict(source), "document_count": doc_count or 0}
+        except Exception as e:
+            raise HTTPException(status_code=500, detail=str(e))
+
+    # No stored row, or no database: only known codes have defaults, so
+    # validate the code before the default helpers are called with it.
+    if bundesland not in BUNDESLAENDER:
+        raise HTTPException(status_code=404, detail=f"Bundesland not found: {bundesland}")
+    return {
+        "bundesland": bundesland,
+        "name": get_bundesland_name(bundesland),
+        "training_allowed": get_training_allowed(bundesland),
+        "license_type": get_license_for_bundesland(bundesland).value,
+        "document_count": 0,
+    }
+
+
+# =============================================================================
+# Seed URLs Endpoints
+# =============================================================================
+
+@router.get("/sources/{source_id}/urls", response_model=List[dict])
+async def list_seed_urls(source_id: str):
+    """Return a source's seed URL rows ordered by created_at; [] when no database is available."""
+    pool = await get_pool()
+    if not pool:
+        return []
+
+    try:
+        async with pool.acquire() as conn:
+            rows = await conn.fetch(
+                "SELECT * FROM zeugnis_seed_urls WHERE source_id = $1 ORDER BY created_at",
+                source_id
+            )
+            return [dict(r) for r in rows]
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=str(e))
+
+
+@router.post("/sources/{source_id}/urls", response_model=dict)
+async def add_seed_url(source_id: str, seed_url: SeedUrlCreate):
+    """Insert a seed URL with status 'pending' for the given source and return its generated id."""
+    pool = await get_pool()
+    if not pool:
+        raise HTTPException(status_code=503, detail="Database not available")
+
+    url_id = generate_id()
+    try:
+        async with pool.acquire() as conn:
+            await conn.execute(
+                """
+                INSERT INTO zeugnis_seed_urls (id, source_id, url, doc_type, status)
+                VALUES ($1, $2, $3, $4, 'pending')
+                """,
+                url_id, source_id, seed_url.url, seed_url.doc_type.value
+            )
+            return {"id": url_id, "success": True}
+    except Exception as e:  # source_id is not checked beforehand; an insert failure ends up here
+        raise HTTPException(status_code=500, detail=str(e))
+
+
+@router.delete("/urls/{url_id}", response_model=dict)
+async def delete_seed_url(url_id: str):
+    """Delete the seed URL with the given id; 503 when no database is available."""
+    pool = await get_pool()
+    if not pool:
+        raise HTTPException(status_code=503, detail="Database not available")
+
+    try:
+        async with pool.acquire() as conn:
+            await conn.execute(
+                "DELETE FROM zeugnis_seed_urls WHERE id = $1",
+                url_id
+            )
+            return {"success": True}  # NOTE(review): the result is ignored, so success is reported even if no row matched
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=str(e))
+
+
+# =============================================================================
+# Initialization Endpoint
+# =============================================================================
+
+@router.post("/init", response_model=dict)
+async def initialize_sources():
+    """Upsert one source per BUNDESLAENDER entry; `sources_created` counts the successful upserts."""
+    pool = await get_pool()
+    if not pool:
+        raise HTTPException(status_code=503, detail="Database not available")
+
+    created = 0
+    try:
+        for code, info in BUNDESLAENDER.items():
+            source_id = generate_id()  # NOTE(review): fresh id per call - confirm repeated /init does not duplicate rows
+            success = await upsert_zeugnis_source(
+                id=source_id,
+                bundesland=code,
+                name=info["name"],
+                license_type=get_license_for_bundesland(code).value,
+                training_allowed=get_training_allowed(code),
+            )
+            if success:
+                created += 1
+
+        return {"success": True, "sources_created": created}
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=str(e))
diff --git a/scripts/doclayout_export_methods.py b/scripts/doclayout_export_methods.py
new file mode 100644
index 0000000..2a93224
--- /dev/null
+++ b/scripts/doclayout_export_methods.py
@@ -0,0 +1,311 @@
+"""
+PP-DocLayout ONNX Export Methods
+
+Download and Docker-based conversion methods for PP-DocLayout model.
+Extracted from export-doclayout-onnx.py.
+"""
+
+import hashlib
+import json
+import logging
+import shutil
+import subprocess
+import tempfile
+import urllib.request
+from pathlib import Path
+
+log = logging.getLogger("export-doclayout")
+
+# Known download sources for pre-exported ONNX models.
+# try_download() walks this list in order and stops at the first source that
+# downloads and passes its checks. A "sha256" of None skips the hash check.
+# NOTE(review): both entries currently carry the same URL, so the second one
+# is a retry of the first rather than an independent mirror — confirm the
+# intended mirror URL.
+DOWNLOAD_SOURCES = [
+ {
+ "name": "PaddleOCR PP-DocLayout (ppyoloe_plus_sod, HuggingFace)",
+ "url": "https://huggingface.co/SWHL/PP-DocLayout/resolve/main/pp_doclayout_onnx/model.onnx",
+ "filename": "model.onnx",
+ "sha256": None,
+ },
+ {
+ "name": "PaddleOCR PP-DocLayout (RapidOCR mirror)",
+ "url": "https://huggingface.co/SWHL/PP-DocLayout/resolve/main/pp_doclayout_onnx/model.onnx",
+ "filename": "model.onnx",
+ "sha256": None,
+ },
+]
+
+# Paddle inference model URLs (for Docker-based conversion).
+# Substituted for PADDLE_MODEL_URL_PLACEHOLDER inside DOCKERFILE_CONTENT.
+PADDLE_MODEL_URL = (
+ "https://paddleocr.bj.bcebos.com/PP-DocLayout/PP-DocLayout_plus.tar"
+)
+
+# Docker image name used for conversion.
+# try_docker() passes it to both `docker build -t` and `docker run`.
+DOCKER_IMAGE_TAG = "breakpilot/paddle2onnx-converter:latest"
+
+
+def sha256_file(path: Path) -> str:
+    """Return the hexadecimal SHA-256 digest of the file at ``path``.
+
+    The file is read in 1 MiB pieces, so it never has to fit in memory.
+    """
+    digest = hashlib.sha256()
+    block_size = 1 << 20
+    with open(path, "rb") as stream:
+        while True:
+            block = stream.read(block_size)
+            if not block:
+                break
+            digest.update(block)
+    return digest.hexdigest()
+
+
+def download_file(url: str, dest: Path, desc: str = "") -> bool:
+    """Download ``url`` to ``dest``, printing progress when the size is known.
+
+    Args:
+        url: Source URL.
+        dest: Target path; missing parent directories are created.
+        desc: Label used in log output; defaults to the last URL segment.
+
+    Returns:
+        True when the file was written completely. False on any error,
+        including a body whose length differs from the announced
+        Content-Length; a partially written ``dest`` is removed in that case.
+    """
+    label = desc or url.split("/")[-1]
+    log.info("Downloading %s ...", label)
+    log.info(" URL: %s", url)
+
+    try:
+        req = urllib.request.Request(url, headers={"User-Agent": "breakpilot-export/1.0"})
+        with urllib.request.urlopen(req, timeout=120) as resp:
+            total = resp.headers.get("Content-Length")
+            total = int(total) if total else None
+            downloaded = 0
+
+            dest.parent.mkdir(parents=True, exist_ok=True)
+            with open(dest, "wb") as f:
+                while True:
+                    chunk = resp.read(1 << 18)  # 256 KB
+                    if not chunk:
+                        break
+                    f.write(chunk)
+                    downloaded += len(chunk)
+                    if total:
+                        pct = downloaded * 100 / total
+                        mb = downloaded / (1 << 20)
+                        total_mb = total / (1 << 20)
+                        print(
+                            f"\r {mb:.1f}/{total_mb:.1f} MB ({pct:.0f}%)",
+                            end="",
+                            flush=True,
+                        )
+            if total:
+                print()  # newline after progress
+
+        # A connection that drops early can end the read loop with an empty
+        # chunk instead of an exception. Compare against Content-Length so a
+        # truncated file is not reported as a success; raising here reuses the
+        # cleanup in the except branch below.
+        if total is not None and downloaded != total:
+            raise OSError(f"incomplete download: got {downloaded} of {total} bytes")
+
+        size_mb = dest.stat().st_size / (1 << 20)
+        log.info(" Downloaded %.1f MB -> %s", size_mb, dest)
+        return True
+
+    except Exception as exc:
+        # Best-effort by design: any failure is logged and reported as False,
+        # and a partial file is removed so later size checks never see it.
+        log.warning(" Download failed: %s", exc)
+        if dest.exists():
+            dest.unlink()
+        return False
+
+
+def try_download(output_dir: Path) -> bool:
+    """Fetch a ready-made ONNX model from the first usable download source.
+
+    Each source in ``DOWNLOAD_SOURCES`` is downloaded to a hidden temp file in
+    ``output_dir``. It is accepted when its SHA-256 matches (if one is
+    configured) and it is at least 1 MB, then moved to ``model.onnx``.
+
+    Returns:
+        True once a model has been saved, False when every source failed.
+    """
+    log.info("=== Method: DOWNLOAD ===")
+
+    output_dir.mkdir(parents=True, exist_ok=True)
+    final_path = output_dir / "model.onnx"
+    min_bytes = 1 << 20
+
+    for entry in DOWNLOAD_SOURCES:
+        log.info("Trying source: %s", entry["name"])
+        staging = output_dir / f".{entry['filename']}.tmp"
+
+        if not download_file(entry["url"], staging, desc=entry["name"]):
+            continue
+
+        # Hash check only runs when the source declares an expected digest.
+        expected = entry["sha256"]
+        if expected:
+            found = sha256_file(staging)
+            if found != expected:
+                log.warning(
+                    " SHA-256 mismatch: expected %s, got %s",
+                    expected,
+                    found,
+                )
+                staging.unlink()
+                continue
+
+        # Anything under 1 MB is rejected as not being a real model file.
+        byte_count = staging.stat().st_size
+        if byte_count < min_bytes:
+            log.warning(" File too small (%.1f KB) — probably not a valid model.", byte_count / 1024)
+            staging.unlink()
+            continue
+
+        shutil.move(str(staging), str(final_path))
+        log.info("Model saved to %s (%.1f MB)", final_path, final_path.stat().st_size / (1 << 20))
+        return True
+
+    log.warning("All download sources failed.")
+    return False
+
+
+# Dockerfile text written out by try_docker(). It installs paddlepaddle and
+# paddle2onnx, downloads and extracts the Paddle model, converts it to
+# /work/output/model.onnx at build time, and copies that file to /output when
+# the container runs. PADDLE_MODEL_URL_PLACEHOLDER is replaced with
+# PADDLE_MODEL_URL by the .replace() call at the end.
+# NOTE(review): the two `RUN python3 -c "..."` instructions span several lines
+# without backslash continuations. Dockerfile instructions normally end at the
+# newline, so this may fail to parse at build time — verify, and consider a
+# here-document (`RUN python3 <<'EOF' ... EOF`) if it does.
+DOCKERFILE_CONTENT = r"""
+FROM --platform=linux/amd64 python:3.11-slim
+
+RUN pip install --no-cache-dir \
+ paddlepaddle==3.0.0 \
+ paddle2onnx==1.3.1 \
+ onnx==1.17.0 \
+ requests
+
+WORKDIR /work
+
+# Download + extract the PP-DocLayout Paddle inference model.
+RUN python3 -c "
+import urllib.request, tarfile, os
+url = 'PADDLE_MODEL_URL_PLACEHOLDER'
+print(f'Downloading {url} ...')
+dest = '/work/pp_doclayout.tar'
+urllib.request.urlretrieve(url, dest)
+print('Extracting ...')
+with tarfile.open(dest) as t:
+ t.extractall('/work/paddle_model')
+os.remove(dest)
+# List what we extracted
+for root, dirs, files in os.walk('/work/paddle_model'):
+ for f in files:
+ fp = os.path.join(root, f)
+ sz = os.path.getsize(fp)
+ print(f' {fp} ({sz} bytes)')
+"
+
+# Convert Paddle model to ONNX.
+RUN python3 -c "
+import os, glob, subprocess
+
+# Find the inference model files
+model_dir = '/work/paddle_model'
+pdmodel_files = glob.glob(os.path.join(model_dir, '**', '*.pdmodel'), recursive=True)
+pdiparams_files = glob.glob(os.path.join(model_dir, '**', '*.pdiparams'), recursive=True)
+
+if not pdmodel_files:
+ raise FileNotFoundError('No .pdmodel file found in extracted archive')
+
+pdmodel = pdmodel_files[0]
+pdiparams = pdiparams_files[0] if pdiparams_files else None
+model_dir_actual = os.path.dirname(pdmodel)
+pdmodel_name = os.path.basename(pdmodel).replace('.pdmodel', '')
+
+print(f'Found model: {pdmodel}')
+print(f'Found params: {pdiparams}')
+print(f'Model dir: {model_dir_actual}')
+print(f'Model name prefix: {pdmodel_name}')
+
+cmd = [
+ 'paddle2onnx',
+ '--model_dir', model_dir_actual,
+ '--model_filename', os.path.basename(pdmodel),
+]
+if pdiparams:
+ cmd += ['--params_filename', os.path.basename(pdiparams)]
+cmd += [
+ '--save_file', '/work/output/model.onnx',
+ '--opset_version', '14',
+ '--enable_onnx_checker', 'True',
+]
+
+os.makedirs('/work/output', exist_ok=True)
+print(f'Running: {\" \".join(cmd)}')
+subprocess.run(cmd, check=True)
+
+out_size = os.path.getsize('/work/output/model.onnx')
+print(f'Conversion done: /work/output/model.onnx ({out_size} bytes)')
+"
+
+CMD ["cp", "-v", "/work/output/model.onnx", "/output/model.onnx"]
+""".replace(
+ "PADDLE_MODEL_URL_PLACEHOLDER", PADDLE_MODEL_URL
+)
+
+
+def try_docker(output_dir: Path) -> bool:
+    """Convert the Paddle model to ONNX inside a linux/amd64 Docker container.
+
+    Writes ``DOCKERFILE_CONTENT`` to a temporary build context, builds the
+    image ``DOCKER_IMAGE_TAG`` and runs it with ``output_dir`` mounted at
+    ``/output`` so the container can copy ``model.onnx`` there.
+
+    Args:
+        output_dir: Directory that receives ``model.onnx``; created if missing.
+
+    Returns:
+        True when ``output_dir/model.onnx`` exists after the run. False when
+        Docker is unavailable, the build or run fails or exceeds its time
+        limit, or no output file was produced.
+    """
+    log.info("=== Method: DOCKER (linux/amd64) ===")
+
+    # Check Docker is available. OSError covers a missing binary as well as
+    # one that exists but cannot be executed.
+    docker_bin = shutil.which("docker") or "/usr/local/bin/docker"
+    try:
+        subprocess.run(
+            [docker_bin, "version"],
+            capture_output=True,
+            check=True,
+            timeout=15,
+        )
+    except (subprocess.CalledProcessError, OSError, subprocess.TimeoutExpired) as exc:
+        log.error("Docker is not available: %s", exc)
+        return False
+
+    output_dir.mkdir(parents=True, exist_ok=True)
+    build_timeout = 1200  # seconds
+    run_timeout = 300  # seconds
+
+    with tempfile.TemporaryDirectory(prefix="doclayout-export-") as tmpdir:
+        tmpdir = Path(tmpdir)
+
+        # Write Dockerfile.
+        dockerfile_path = tmpdir / "Dockerfile"
+        dockerfile_path.write_text(DOCKERFILE_CONTENT)
+        log.info("Wrote Dockerfile to %s", dockerfile_path)
+
+        # Build image; output streams straight to the terminal.
+        log.info("Building Docker image (this downloads ~2 GB, may take a while) ...")
+        build_cmd = [
+            docker_bin, "build",
+            "--platform", "linux/amd64",
+            "-t", DOCKER_IMAGE_TAG,
+            "-f", str(dockerfile_path),
+            str(tmpdir),
+        ]
+        log.info(" %s", " ".join(build_cmd))
+        try:
+            build_result = subprocess.run(
+                build_cmd,
+                capture_output=False,
+                timeout=build_timeout,
+            )
+        except subprocess.TimeoutExpired:
+            # Keep the bool contract: a hung build is a failed method, not a crash.
+            log.error("Docker build timed out after %d seconds.", build_timeout)
+            return False
+        if build_result.returncode != 0:
+            log.error("Docker build failed (exit code %d).", build_result.returncode)
+            return False
+
+        # Run container; its CMD copies model.onnx into the mounted /output.
+        log.info("Running conversion container ...")
+        run_cmd = [
+            docker_bin, "run",
+            "--rm",
+            "--platform", "linux/amd64",
+            "-v", f"{output_dir.resolve()}:/output",
+            DOCKER_IMAGE_TAG,
+        ]
+        log.info(" %s", " ".join(run_cmd))
+        try:
+            run_result = subprocess.run(
+                run_cmd,
+                capture_output=False,
+                timeout=run_timeout,
+            )
+        except subprocess.TimeoutExpired:
+            log.error("Docker run timed out after %d seconds.", run_timeout)
+            return False
+        if run_result.returncode != 0:
+            log.error("Docker run failed (exit code %d).", run_result.returncode)
+            return False
+
+    model_path = output_dir / "model.onnx"
+    if model_path.exists():
+        size_mb = model_path.stat().st_size / (1 << 20)
+        log.info("Model exported: %s (%.1f MB)", model_path, size_mb)
+        return True
+    else:
+        log.error("Expected output file not found: %s", model_path)
+        return False
+
+
+def write_metadata(output_dir: Path, method: str, class_labels: list, model_input_shape: tuple) -> None:
+    """Record provenance for the exported model in ``metadata.json``.
+
+    Does nothing when ``output_dir/model.onnx`` is missing. Otherwise stores
+    the export method, class labels, input shape, file size and SHA-256 of
+    the model as indented JSON beside it.
+    """
+    onnx_file = output_dir / "model.onnx"
+    if not onnx_file.exists():
+        return
+
+    payload = {
+        "model": "PP-DocLayout",
+        "format": "ONNX",
+        "export_method": method,
+        "class_labels": class_labels,
+        "input_shape": list(model_input_shape),
+        "file_size_bytes": onnx_file.stat().st_size,
+        "sha256": sha256_file(onnx_file),
+    }
+    target = output_dir / "metadata.json"
+    with open(target, "w") as out:
+        json.dump(payload, out, indent=2)
+    log.info("Metadata written to %s", target)
diff --git a/scripts/export-doclayout-onnx.py b/scripts/export-doclayout-onnx.py
index 0d76271..08d9479 100755
--- a/scripts/export-doclayout-onnx.py
+++ b/scripts/export-doclayout-onnx.py
@@ -13,15 +13,8 @@ Usage:
"""
import argparse
-import hashlib
-import json
import logging
-import os
-import shutil
-import subprocess
import sys
-import tempfile
-import urllib.request
from pathlib import Path
logging.basicConfig(
@@ -49,92 +42,23 @@ CLASS_LABELS = [
"abstract",
]
-# Known download sources for pre-exported ONNX models.
-# Ordered by preference — first successful download wins.
-DOWNLOAD_SOURCES = [
- {
- "name": "PaddleOCR PP-DocLayout (ppyoloe_plus_sod, HuggingFace)",
- "url": "https://huggingface.co/SWHL/PP-DocLayout/resolve/main/pp_doclayout_onnx/model.onnx",
- "filename": "model.onnx",
- "sha256": None, # populated once a known-good hash is available
- },
- {
- "name": "PaddleOCR PP-DocLayout (RapidOCR mirror)",
- "url": "https://huggingface.co/SWHL/PP-DocLayout/resolve/main/pp_doclayout_onnx/model.onnx",
- "filename": "model.onnx",
- "sha256": None,
- },
-]
-
-# Paddle inference model URLs (for Docker-based conversion).
-PADDLE_MODEL_URL = (
- "https://paddleocr.bj.bcebos.com/PP-DocLayout/PP-DocLayout_plus.tar"
-)
-
# Expected input shape for the model (batch, channels, height, width).
MODEL_INPUT_SHAPE = (1, 3, 800, 800)
-# Docker image name used for conversion.
-DOCKER_IMAGE_TAG = "breakpilot/paddle2onnx-converter:latest"
+# Import methods from sibling module
+from doclayout_export_methods import (
+ try_download,
+ try_docker,
+ write_metadata,
+ sha256_file,
+)
+
# ---------------------------------------------------------------------------
-# Helpers
+# Verification
# ---------------------------------------------------------------------------
-def sha256_file(path: Path) -> str:
- """Compute SHA-256 hex digest for a file."""
- h = hashlib.sha256()
- with open(path, "rb") as f:
- for chunk in iter(lambda: f.read(1 << 20), b""):
- h.update(chunk)
- return h.hexdigest()
-
-
-def download_file(url: str, dest: Path, desc: str = "") -> bool:
- """Download a file with progress reporting. Returns True on success."""
- label = desc or url.split("/")[-1]
- log.info("Downloading %s ...", label)
- log.info(" URL: %s", url)
-
- try:
- req = urllib.request.Request(url, headers={"User-Agent": "breakpilot-export/1.0"})
- with urllib.request.urlopen(req, timeout=120) as resp:
- total = resp.headers.get("Content-Length")
- total = int(total) if total else None
- downloaded = 0
-
- dest.parent.mkdir(parents=True, exist_ok=True)
- with open(dest, "wb") as f:
- while True:
- chunk = resp.read(1 << 18) # 256 KB
- if not chunk:
- break
- f.write(chunk)
- downloaded += len(chunk)
- if total:
- pct = downloaded * 100 / total
- mb = downloaded / (1 << 20)
- total_mb = total / (1 << 20)
- print(
- f"\r {mb:.1f}/{total_mb:.1f} MB ({pct:.0f}%)",
- end="",
- flush=True,
- )
- if total:
- print() # newline after progress
-
- size_mb = dest.stat().st_size / (1 << 20)
- log.info(" Downloaded %.1f MB -> %s", size_mb, dest)
- return True
-
- except Exception as exc:
- log.warning(" Download failed: %s", exc)
- if dest.exists():
- dest.unlink()
- return False
-
-
def verify_onnx(model_path: Path) -> bool:
"""Load the ONNX model with onnxruntime, run a dummy inference, check outputs."""
log.info("Verifying ONNX model: %s", model_path)
@@ -169,24 +93,23 @@ def verify_onnx(model_path: Path) -> bool:
for out in outputs:
log.info(" %s: shape=%s dtype=%s", out.name, out.shape, out.type)
- # Build dummy input — use the first input's name and expected shape.
+ # Build dummy input
input_name = inputs[0].name
input_shape = inputs[0].shape
- # Replace dynamic dims (strings or None) with concrete sizes.
+ # Replace dynamic dims with concrete sizes.
concrete_shape = []
for i, dim in enumerate(input_shape):
if isinstance(dim, (int,)) and dim > 0:
concrete_shape.append(dim)
elif i == 0:
- concrete_shape.append(1) # batch
+ concrete_shape.append(1)
elif i == 1:
- concrete_shape.append(3) # channels
+ concrete_shape.append(3)
else:
- concrete_shape.append(800) # spatial
+ concrete_shape.append(800)
concrete_shape = tuple(concrete_shape)
- # Fallback if shape looks wrong — use standard MODEL_INPUT_SHAPE.
if len(concrete_shape) != 4:
concrete_shape = MODEL_INPUT_SHAPE
@@ -199,20 +122,15 @@ def verify_onnx(model_path: Path) -> bool:
arr = np.asarray(r)
log.info(" output[%d]: shape=%s dtype=%s", i, arr.shape, arr.dtype)
- # Basic sanity checks
if len(result) == 0:
log.error(" Model produced no outputs!")
return False
- # Check for at least one output with a bounding-box-like shape (N, 4) or
- # a detection-like structure. Be lenient — different ONNX exports vary.
has_plausible_output = False
for r in result:
arr = np.asarray(r)
- # Common detection output shapes: (1, N, 6), (N, 4), (N, 6), (1, N, 5+C), etc.
if arr.ndim >= 2 and any(d >= 4 for d in arr.shape):
has_plausible_output = True
- # Some models output (N,) labels or scores
if arr.ndim >= 1 and arr.size > 0:
has_plausible_output = True
@@ -229,238 +147,6 @@ def verify_onnx(model_path: Path) -> bool:
return False
-# ---------------------------------------------------------------------------
-# Method: Download
-# ---------------------------------------------------------------------------
-
-
-def try_download(output_dir: Path) -> bool:
- """Attempt to download a pre-exported ONNX model. Returns True on success."""
- log.info("=== Method: DOWNLOAD ===")
-
- output_dir.mkdir(parents=True, exist_ok=True)
- model_path = output_dir / "model.onnx"
-
- for source in DOWNLOAD_SOURCES:
- log.info("Trying source: %s", source["name"])
- tmp_path = output_dir / f".{source['filename']}.tmp"
-
- if not download_file(source["url"], tmp_path, desc=source["name"]):
- continue
-
- # Check SHA-256 if known.
- if source["sha256"]:
- actual_hash = sha256_file(tmp_path)
- if actual_hash != source["sha256"]:
- log.warning(
- " SHA-256 mismatch: expected %s, got %s",
- source["sha256"],
- actual_hash,
- )
- tmp_path.unlink()
- continue
-
- # Basic sanity: file should be > 1 MB (a real ONNX model, not an error page).
- size = tmp_path.stat().st_size
- if size < 1 << 20:
- log.warning(" File too small (%.1f KB) — probably not a valid model.", size / 1024)
- tmp_path.unlink()
- continue
-
- # Move into place.
- shutil.move(str(tmp_path), str(model_path))
- log.info("Model saved to %s (%.1f MB)", model_path, model_path.stat().st_size / (1 << 20))
- return True
-
- log.warning("All download sources failed.")
- return False
-
-
-# ---------------------------------------------------------------------------
-# Method: Docker
-# ---------------------------------------------------------------------------
-
-DOCKERFILE_CONTENT = r"""
-FROM --platform=linux/amd64 python:3.11-slim
-
-RUN pip install --no-cache-dir \
- paddlepaddle==3.0.0 \
- paddle2onnx==1.3.1 \
- onnx==1.17.0 \
- requests
-
-WORKDIR /work
-
-# Download + extract the PP-DocLayout Paddle inference model.
-RUN python3 -c "
-import urllib.request, tarfile, os
-url = 'PADDLE_MODEL_URL_PLACEHOLDER'
-print(f'Downloading {url} ...')
-dest = '/work/pp_doclayout.tar'
-urllib.request.urlretrieve(url, dest)
-print('Extracting ...')
-with tarfile.open(dest) as t:
- t.extractall('/work/paddle_model')
-os.remove(dest)
-# List what we extracted
-for root, dirs, files in os.walk('/work/paddle_model'):
- for f in files:
- fp = os.path.join(root, f)
- sz = os.path.getsize(fp)
- print(f' {fp} ({sz} bytes)')
-"
-
-# Convert Paddle model to ONNX.
-# paddle2onnx expects model_dir with model.pdmodel + model.pdiparams
-RUN python3 -c "
-import os, glob, subprocess
-
-# Find the inference model files
-model_dir = '/work/paddle_model'
-pdmodel_files = glob.glob(os.path.join(model_dir, '**', '*.pdmodel'), recursive=True)
-pdiparams_files = glob.glob(os.path.join(model_dir, '**', '*.pdiparams'), recursive=True)
-
-if not pdmodel_files:
- raise FileNotFoundError('No .pdmodel file found in extracted archive')
-
-pdmodel = pdmodel_files[0]
-pdiparams = pdiparams_files[0] if pdiparams_files else None
-model_dir_actual = os.path.dirname(pdmodel)
-pdmodel_name = os.path.basename(pdmodel).replace('.pdmodel', '')
-
-print(f'Found model: {pdmodel}')
-print(f'Found params: {pdiparams}')
-print(f'Model dir: {model_dir_actual}')
-print(f'Model name prefix: {pdmodel_name}')
-
-cmd = [
- 'paddle2onnx',
- '--model_dir', model_dir_actual,
- '--model_filename', os.path.basename(pdmodel),
-]
-if pdiparams:
- cmd += ['--params_filename', os.path.basename(pdiparams)]
-cmd += [
- '--save_file', '/work/output/model.onnx',
- '--opset_version', '14',
- '--enable_onnx_checker', 'True',
-]
-
-os.makedirs('/work/output', exist_ok=True)
-print(f'Running: {\" \".join(cmd)}')
-subprocess.run(cmd, check=True)
-
-out_size = os.path.getsize('/work/output/model.onnx')
-print(f'Conversion done: /work/output/model.onnx ({out_size} bytes)')
-"
-
-CMD ["cp", "-v", "/work/output/model.onnx", "/output/model.onnx"]
-""".replace(
- "PADDLE_MODEL_URL_PLACEHOLDER", PADDLE_MODEL_URL
-)
-
-
-def try_docker(output_dir: Path) -> bool:
- """Build a Docker image to convert the Paddle model to ONNX. Returns True on success."""
- log.info("=== Method: DOCKER (linux/amd64) ===")
-
- # Check Docker is available.
- docker_bin = shutil.which("docker") or "/usr/local/bin/docker"
- try:
- subprocess.run(
- [docker_bin, "version"],
- capture_output=True,
- check=True,
- timeout=15,
- )
- except (subprocess.CalledProcessError, FileNotFoundError, subprocess.TimeoutExpired) as exc:
- log.error("Docker is not available: %s", exc)
- return False
-
- output_dir.mkdir(parents=True, exist_ok=True)
-
- with tempfile.TemporaryDirectory(prefix="doclayout-export-") as tmpdir:
- tmpdir = Path(tmpdir)
-
- # Write Dockerfile.
- dockerfile_path = tmpdir / "Dockerfile"
- dockerfile_path.write_text(DOCKERFILE_CONTENT)
- log.info("Wrote Dockerfile to %s", dockerfile_path)
-
- # Build image.
- log.info("Building Docker image (this downloads ~2 GB, may take a while) ...")
- build_cmd = [
- docker_bin, "build",
- "--platform", "linux/amd64",
- "-t", DOCKER_IMAGE_TAG,
- "-f", str(dockerfile_path),
- str(tmpdir),
- ]
- log.info(" %s", " ".join(build_cmd))
- build_result = subprocess.run(
- build_cmd,
- capture_output=False, # stream output to terminal
- timeout=1200, # 20 min
- )
- if build_result.returncode != 0:
- log.error("Docker build failed (exit code %d).", build_result.returncode)
- return False
-
- # Run container — mount output_dir as /output, the CMD copies model.onnx there.
- log.info("Running conversion container ...")
- run_cmd = [
- docker_bin, "run",
- "--rm",
- "--platform", "linux/amd64",
- "-v", f"{output_dir.resolve()}:/output",
- DOCKER_IMAGE_TAG,
- ]
- log.info(" %s", " ".join(run_cmd))
- run_result = subprocess.run(
- run_cmd,
- capture_output=False,
- timeout=300,
- )
- if run_result.returncode != 0:
- log.error("Docker run failed (exit code %d).", run_result.returncode)
- return False
-
- model_path = output_dir / "model.onnx"
- if model_path.exists():
- size_mb = model_path.stat().st_size / (1 << 20)
- log.info("Model exported: %s (%.1f MB)", model_path, size_mb)
- return True
- else:
- log.error("Expected output file not found: %s", model_path)
- return False
-
-
-# ---------------------------------------------------------------------------
-# Write metadata
-# ---------------------------------------------------------------------------
-
-
-def write_metadata(output_dir: Path, method: str) -> None:
- """Write a metadata JSON next to the model for provenance tracking."""
- model_path = output_dir / "model.onnx"
- if not model_path.exists():
- return
-
- meta = {
- "model": "PP-DocLayout",
- "format": "ONNX",
- "export_method": method,
- "class_labels": CLASS_LABELS,
- "input_shape": list(MODEL_INPUT_SHAPE),
- "file_size_bytes": model_path.stat().st_size,
- "sha256": sha256_file(model_path),
- }
- meta_path = output_dir / "metadata.json"
- with open(meta_path, "w") as f:
- json.dump(meta, f, indent=2)
- log.info("Metadata written to %s", meta_path)
-
-
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
@@ -527,7 +213,7 @@ def main() -> int:
return 1
# Write metadata.
- write_metadata(output_dir, used_method)
+ write_metadata(output_dir, used_method, CLASS_LABELS, MODEL_INPUT_SHAPE)
# Verify.
if not args.skip_verify:
diff --git a/website/app/lehrer/abitur-archiv/_components/archiv-constants.ts b/website/app/lehrer/abitur-archiv/_components/archiv-constants.ts
new file mode 100644
index 0000000..672da84
--- /dev/null
+++ b/website/app/lehrer/abitur-archiv/_components/archiv-constants.ts
@@ -0,0 +1,84 @@
+/**
+ * Abitur-Archiv Constants & Mock Data
+ *
+ * Extracted from page.tsx.
+ */
+
+// API Base URL
+// Relative path; callers append endpoint segments to it.
+export const API_BASE = '/api/education/abitur-archiv'
+
+// Filter constants
+// Each option pairs a machine-readable `id` with the `label` shown to users.
+// FAECHER: subjects.
+export const FAECHER = [
+ { id: 'deutsch', label: 'Deutsch' },
+ { id: 'englisch', label: 'Englisch' },
+ { id: 'mathematik', label: 'Mathematik' },
+ { id: 'biologie', label: 'Biologie' },
+ { id: 'physik', label: 'Physik' },
+ { id: 'chemie', label: 'Chemie' },
+ { id: 'geschichte', label: 'Geschichte' },
+]
+
+// JAHRE: selectable years, newest first.
+export const JAHRE = [2025, 2024, 2023, 2022, 2021]
+
+// NIVEAUS: course levels; the ids match the `niveau` union of AbiturDokument.
+export const NIVEAUS = [
+ { id: 'eA', label: 'Erhoehtes Niveau (eA)' },
+ { id: 'gA', label: 'Grundlegendes Niveau (gA)' },
+]
+
+// TYPEN: document types; the ids match the `typ` union of AbiturDokument.
+export const TYPEN = [
+ { id: 'aufgabe', label: 'Aufgabe' },
+ { id: 'erwartungshorizont', label: 'Erwartungshorizont' },
+]
+
+/**
+ * One document entry of the Abitur archive, as produced by
+ * {@link getMockDocuments}.
+ */
+export interface AbiturDokument {
+  id: string
+  /** File name, e.g. `2024_deutsch_eA_I_EWH.pdf` in the mock data. */
+  dateiname: string
+  /** Subject id (mock data uses 'deutsch' and 'englisch'). */
+  fach: string
+  /** Year. */
+  jahr: number
+  /** Course level. */
+  niveau: 'eA' | 'gA'
+  /** Document type: task sheet or expectation horizon. */
+  typ: 'aufgabe' | 'erwartungshorizont'
+  /** Task number; the mock data uses Roman numerals 'I'..'III'. */
+  aufgaben_nummer: string
+  status: string
+  file_path: string
+  /** Presumably a size in bytes — confirm against the API. */
+  file_size: number
+}
+
+/**
+ * Build a fixed grid of mock documents: 3 years x 2 subjects x 2 levels x
+ * 3 task numbers x 2 types = 72 entries with ids `doc-1` .. `doc-72`.
+ *
+ * `file_size` is random per call, but always a whole number in
+ * [250000, 750000) — a byte count should never be fractional.
+ *
+ * @returns The generated documents, ordered by year, subject, level, task
+ *   number and type.
+ */
+export function getMockDocuments(): AbiturDokument[] {
+  const docs: AbiturDokument[] = []
+  const faecher = ['deutsch', 'englisch']
+  const jahre = [2024, 2023, 2022]
+  const niveaus: Array<'eA' | 'gA'> = ['eA', 'gA']
+  const typen: Array<'aufgabe' | 'erwartungshorizont'> = ['aufgabe', 'erwartungshorizont']
+  const nummern = ['I', 'II', 'III']
+
+  let id = 1
+  for (const jahr of jahre) {
+    for (const fach of faecher) {
+      for (const niveau of niveaus) {
+        for (const nummer of nummern) {
+          for (const typ of typen) {
+            docs.push({
+              id: `doc-${id++}`,
+              dateiname: `${jahr}_${fach}_${niveau}_${nummer}${typ === 'erwartungshorizont' ? '_EWH' : ''}.pdf`,
+              fach,
+              jahr,
+              niveau,
+              typ,
+              aufgaben_nummer: nummer,
+              status: 'indexed',
+              file_path: '#',
+              // Floor so the mock never reports a fractional number of bytes.
+              file_size: Math.floor(250000 + Math.random() * 500000),
+            })
+          }
+        }
+      }
+    }
+  }
+  return docs
+}
+
+/**
+ * Render a byte count as a short human-readable string.
+ *
+ * Values below 1024 are shown unchanged with a " B" suffix; anything larger
+ * is divided by 1024 (KB) or 1024² (MB) and shown with one decimal place.
+ */
+export function formatFileSize(bytes: number): string {
+  const KB = 1024
+  const MB = KB * KB
+  if (bytes < KB) {
+    return `${bytes} B`
+  }
+  const inKilobytes = bytes < MB
+  const divisor = inKilobytes ? KB : MB
+  const unit = inKilobytes ? 'KB' : 'MB'
+  return `${(bytes / divisor).toFixed(1)} ${unit}`
+}
diff --git a/website/app/lehrer/abitur-archiv/page.tsx b/website/app/lehrer/abitur-archiv/page.tsx
index 351d82d..c2e6576 100644
--- a/website/app/lehrer/abitur-archiv/page.tsx
+++ b/website/app/lehrer/abitur-archiv/page.tsx
@@ -11,45 +11,10 @@ import {
FileText, Filter, ChevronLeft, ChevronRight, Eye, Download, Search,
X, Loader2, Grid, List, LayoutGrid, Plus, Archive, BookOpen
} from 'lucide-react'
-
-// API Base URL
-const API_BASE = '/api/education/abitur-archiv'
-
-// Filter constants
-const FAECHER = [
- { id: 'deutsch', label: 'Deutsch' },
- { id: 'englisch', label: 'Englisch' },
- { id: 'mathematik', label: 'Mathematik' },
- { id: 'biologie', label: 'Biologie' },
- { id: 'physik', label: 'Physik' },
- { id: 'chemie', label: 'Chemie' },
- { id: 'geschichte', label: 'Geschichte' },
-]
-
-const JAHRE = [2025, 2024, 2023, 2022, 2021]
-
-const NIVEAUS = [
- { id: 'eA', label: 'Erhoehtes Niveau (eA)' },
- { id: 'gA', label: 'Grundlegendes Niveau (gA)' },
-]
-
-const TYPEN = [
- { id: 'aufgabe', label: 'Aufgabe' },
- { id: 'erwartungshorizont', label: 'Erwartungshorizont' },
-]
-
-interface AbiturDokument {
- id: string
- dateiname: string
- fach: string
- jahr: number
- niveau: 'eA' | 'gA'
- typ: 'aufgabe' | 'erwartungshorizont'
- aufgaben_nummer: string
- status: string
- file_path: string
- file_size: number
-}
+import {
+ API_BASE, FAECHER, JAHRE, NIVEAUS, TYPEN,
+ type AbiturDokument, getMockDocuments, formatFileSize,
+} from './_components/archiv-constants'
export default function AbiturArchivPage() {
const [documents, setDocuments] = useState
([])
@@ -140,12 +105,6 @@ export default function AbiturArchivPage() {
const hasActiveFilters = filterFach || filterJahr || filterNiveau || filterTyp || searchQuery
- const formatFileSize = (bytes: number) => {
- if (bytes < 1024) return bytes + ' B'
- if (bytes < 1024 * 1024) return (bytes / 1024).toFixed(1) + ' KB'
- return (bytes / (1024 * 1024)).toFixed(1) + ' MB'
- }
-
return (
{/* Header */}
@@ -469,36 +428,3 @@ export default function AbiturArchivPage() {
)
}
-function getMockDocuments(): AbiturDokument[] {
- const docs: AbiturDokument[] = []
- const faecher = ['deutsch', 'englisch']
- const jahre = [2024, 2023, 2022]
- const niveaus: Array<'eA' | 'gA'> = ['eA', 'gA']
- const typen: Array<'aufgabe' | 'erwartungshorizont'> = ['aufgabe', 'erwartungshorizont']
- const nummern = ['I', 'II', 'III']
-
- let id = 1
- for (const jahr of jahre) {
- for (const fach of faecher) {
- for (const niveau of niveaus) {
- for (const nummer of nummern) {
- for (const typ of typen) {
- docs.push({
- id: `doc-${id++}`,
- dateiname: `${jahr}_${fach}_${niveau}_${nummer}${typ === 'erwartungshorizont' ? '_EWH' : ''}.pdf`,
- fach,
- jahr,
- niveau,
- typ,
- aufgaben_nummer: nummer,
- status: 'indexed',
- file_path: '#',
- file_size: 250000 + Math.random() * 500000
- })
- }
- }
- }
- }
- }
- return docs
-}
diff --git a/website/components/compliance/charts/DependencyMap.tsx b/website/components/compliance/charts/DependencyMap.tsx
index a70718a..6f1cadc 100644
--- a/website/components/compliance/charts/DependencyMap.tsx
+++ b/website/components/compliance/charts/DependencyMap.tsx
@@ -10,67 +10,9 @@
*/
import { useState, useMemo } from 'react'
-import { Language, getTerm } from '@/lib/compliance-i18n'
-
-interface Requirement {
- id: string
- article: string
- title: string
- regulation_code: string
-}
-
-interface Control {
- id: string
- control_id: string
- title: string
- domain: string
- status: string
-}
-
-interface Mapping {
- requirement_id: string
- control_id: string
- coverage_level: 'full' | 'partial' | 'planned'
-}
-
-interface DependencyMapProps {
- requirements: Requirement[]
- controls: Control[]
- mappings: Mapping[]
- lang?: Language
- onControlClick?: (control: Control) => void
- onRequirementClick?: (requirement: Requirement) => void
-}
-
-const DOMAIN_COLORS: Record
= {
- gov: '#64748b',
- priv: '#3b82f6',
- iam: '#a855f7',
- crypto: '#eab308',
- sdlc: '#22c55e',
- ops: '#f97316',
- ai: '#ec4899',
- cra: '#06b6d4',
- aud: '#6366f1',
-}
-
-const DOMAIN_LABELS: Record = {
- gov: 'Governance',
- priv: 'Datenschutz',
- iam: 'Identity & Access',
- crypto: 'Kryptografie',
- sdlc: 'Secure Dev',
- ops: 'Operations',
- ai: 'KI-spezifisch',
- cra: 'Supply Chain',
- aud: 'Audit',
-}
-
-const COVERAGE_COLORS: Record = {
- full: { bg: 'bg-green-100', border: 'border-green-500', text: 'text-green-700' },
- partial: { bg: 'bg-yellow-100', border: 'border-yellow-500', text: 'text-yellow-700' },
- planned: { bg: 'bg-slate-100', border: 'border-slate-400', text: 'text-slate-600' },
-}
+import type { Control, Requirement, DependencyMapProps } from './DependencyMapTypes'
+import { DOMAIN_COLORS, COVERAGE_COLORS } from './DependencyMapTypes'
+import { DependencyMapSankey } from './DependencyMapSankey'
export default function DependencyMap({
requirements,
@@ -115,7 +57,7 @@ export default function DependencyMap({
// Build mapping lookup
const mappingLookup = useMemo(() => {
- const lookup: Record> = {}
+ const lookup: Record> = {}
mappings.forEach((m) => {
if (!lookup[m.control_id]) lookup[m.control_id] = {}
lookup[m.control_id][m.requirement_id] = m
@@ -123,11 +65,6 @@ export default function DependencyMap({
return lookup
}, [mappings])
- // Get connected requirements for a control
- const getConnectedRequirements = (controlId: string) => {
- return Object.keys(mappingLookup[controlId] || {})
- }
-
// Get connected controls for a requirement
const getConnectedControls = (requirementId: string) => {
return Object.keys(mappingLookup)
@@ -201,81 +138,42 @@ export default function DependencyMap({
{/* Statistics Header */}
-
- {lang === 'de' ? 'Abdeckung' : 'Coverage'}
-
+
{lang === 'de' ? 'Abdeckung' : 'Coverage'}
{stats.coveragePercent}%
-
- {stats.coveredRequirements}/{stats.totalRequirements} {lang === 'de' ? 'Anforderungen' : 'Requirements'}
-
+
{stats.coveredRequirements}/{stats.totalRequirements} {lang === 'de' ? 'Anforderungen' : 'Requirements'}
-
- {lang === 'de' ? 'Vollstaendig' : 'Full'}
-
+
{lang === 'de' ? 'Vollstaendig' : 'Full'}
{stats.fullMappings}
-
{lang === 'de' ? 'Mappings' : 'Mappings'}
+
Mappings
-
- {lang === 'de' ? 'Teilweise' : 'Partial'}
-
+
{lang === 'de' ? 'Teilweise' : 'Partial'}
{stats.partialMappings}
-
{lang === 'de' ? 'Mappings' : 'Mappings'}
+
Mappings
-
- {lang === 'de' ? 'Geplant' : 'Planned'}
-
+
{lang === 'de' ? 'Geplant' : 'Planned'}
{stats.plannedMappings}
-
{lang === 'de' ? 'Mappings' : 'Mappings'}
+
Mappings
{/* Filters */}
-
setFilterRegulation(e.target.value)}
- className="px-3 py-2 border rounded-lg focus:ring-2 focus:ring-primary-500"
- >
+ setFilterRegulation(e.target.value)} className="px-3 py-2 border rounded-lg focus:ring-2 focus:ring-primary-500">
{lang === 'de' ? 'Alle Verordnungen' : 'All Regulations'}
- {regulations.map((reg) => (
- {reg}
- ))}
+ {regulations.map((reg) => ({reg} ))}
-
- setFilterDomain(e.target.value)}
- className="px-3 py-2 border rounded-lg focus:ring-2 focus:ring-primary-500"
- >
+ setFilterDomain(e.target.value)} className="px-3 py-2 border rounded-lg focus:ring-2 focus:ring-primary-500">
{lang === 'de' ? 'Alle Domains' : 'All Domains'}
- {domains.map((dom) => (
- {DOMAIN_LABELS[dom] || dom}
- ))}
+ {domains.map((dom) => ({dom} ))}
-
-
- setViewMode('matrix')}
- className={`px-3 py-1.5 text-sm rounded-md transition-colors ${
- viewMode === 'matrix' ? 'bg-white shadow text-slate-900' : 'text-slate-600'
- }`}
- >
- Matrix
-
- setViewMode('sankey')}
- className={`px-3 py-1.5 text-sm rounded-md transition-colors ${
- viewMode === 'sankey' ? 'bg-white shadow text-slate-900' : 'text-slate-600'
- }`}
- >
- {lang === 'de' ? 'Verbindungen' : 'Connections'}
-
+ setViewMode('matrix')} className={`px-3 py-1.5 text-sm rounded-md transition-colors ${viewMode === 'matrix' ? 'bg-white shadow text-slate-900' : 'text-slate-600'}`}>Matrix
+ setViewMode('sankey')} className={`px-3 py-1.5 text-sm rounded-md transition-colors ${viewMode === 'sankey' ? 'bg-white shadow text-slate-900' : 'text-slate-600'}`}>{lang === 'de' ? 'Verbindungen' : 'Connections'}
@@ -284,96 +182,38 @@ export default function DependencyMap({
{viewMode === 'matrix' ? (
- {/* Matrix Header */}
{filteredControls.map((control) => (
-
handleControlClick(control)}
- className={`
- w-20 flex-shrink-0 text-center p-2 cursor-pointer transition-colors
- ${selectedControl === control.control_id ? 'bg-primary-100' : 'hover:bg-slate-50'}
- `}
- >
-
-
- {control.control_id}
-
+
handleControlClick(control)} className={`w-20 flex-shrink-0 text-center p-2 cursor-pointer transition-colors ${selectedControl === control.control_id ? 'bg-primary-100' : 'hover:bg-slate-50'}`}>
+
+
{control.control_id}
))}
-
- {/* Matrix Body */}
{filteredRequirements.map((req) => {
const connectedControls = getConnectedControls(req.id)
- const isHighlighted = selectedRequirement === req.id ||
- (selectedControl && connectedControls.some((c) => c.controlId === selectedControl))
-
+ const isHighlighted = selectedRequirement === req.id || (selectedControl && connectedControls.some((c) => c.controlId === selectedControl))
return (
-
-
handleRequirementClick(req)}
- className={`
- w-48 flex-shrink-0 p-2 cursor-pointer transition-colors
- ${selectedRequirement === req.id ? 'bg-primary-100' : 'hover:bg-slate-50'}
- `}
- >
-
- {req.regulation_code} {req.article}
-
-
- {req.title}
-
+
+
handleRequirementClick(req)} className={`w-48 flex-shrink-0 p-2 cursor-pointer transition-colors ${selectedRequirement === req.id ? 'bg-primary-100' : 'hover:bg-slate-50'}`}>
+
{req.regulation_code} {req.article}
+
{req.title}
{filteredControls.map((control) => {
const mapping = mappingLookup[control.control_id]?.[req.id]
const isControlHighlighted = selectedControl === control.control_id
const isConnected = selectedControl && mapping
-
return (
-
+
{mapping && (
-
- {mapping.coverage_level === 'full' && (
-
-
-
- )}
- {mapping.coverage_level === 'partial' && (
-
-
-
- )}
- {mapping.coverage_level === 'planned' && (
-
-
-
- )}
+
+ {mapping.coverage_level === 'full' && (
)}
+ {mapping.coverage_level === 'partial' && (
)}
+ {mapping.coverage_level === 'planned' && (
)}
)}
@@ -386,179 +226,35 @@ export default function DependencyMap({
) : (
- /* Sankey/Connection View */
-
-
- {/* Controls Column */}
-
-
- Controls ({filteredControls.length})
-
- {filteredControls.map((control) => {
- const connectedReqs = getConnectedRequirements(control.control_id)
- const isSelected = selectedControl === control.control_id
-
- return (
-
handleControlClick(control)}
- className={`
- w-full text-left p-3 rounded-lg border transition-all
- ${isSelected ? 'border-primary-500 bg-primary-50 shadow' : 'border-slate-200 hover:border-slate-300'}
- `}
- >
-
-
-
{control.control_id}
-
{connectedReqs.length}
-
- {control.title}
-
- )
- })}
-
-
- {/* Connection Lines (simplified) */}
-
-
- {selectedControl && (
-
- {getConnectedRequirements(selectedControl).slice(0, 10).map((reqId, idx) => {
- const req = requirements.find((r) => r.id === reqId)
- const mapping = mappingLookup[selectedControl][reqId]
- if (!req) return null
-
- return (
-
- {req.regulation_code} {req.article}
-
- )
- })}
- {getConnectedRequirements(selectedControl).length > 10 && (
-
- +{getConnectedRequirements(selectedControl).length - 10} {lang === 'de' ? 'weitere' : 'more'}
-
- )}
-
- )}
- {selectedRequirement && (
-
- {getConnectedControls(selectedRequirement).slice(0, 10).map(({ controlId, coverage }) => {
- const control = controls.find((c) => c.control_id === controlId)
- if (!control) return null
-
- return (
-
- {control.control_id}
-
- )
- })}
-
- )}
- {!selectedControl && !selectedRequirement && (
-
-
- {lang === 'de'
- ? 'Waehlen Sie ein Control oder eine Anforderung aus'
- : 'Select a control or requirement'}
-
-
- )}
-
-
-
- {/* Requirements Column */}
-
-
- {lang === 'de' ? 'Anforderungen' : 'Requirements'} ({filteredRequirements.length})
-
- {filteredRequirements.slice(0, 15).map((req) => {
- const connectedCtrls = getConnectedControls(req.id)
- const isSelected = selectedRequirement === req.id
- const isHighlighted = selectedControl && connectedCtrls.some((c) => c.controlId === selectedControl)
-
- return (
-
handleRequirementClick(req)}
- className={`
- w-full text-left p-3 rounded-lg border transition-all
- ${isSelected ? 'border-primary-500 bg-primary-50 shadow' : ''}
- ${isHighlighted && !isSelected ? 'border-primary-300 bg-primary-25' : ''}
- ${!isSelected && !isHighlighted ? 'border-slate-200 hover:border-slate-300' : ''}
- `}
- >
-
- {req.regulation_code}
- {req.article}
- {connectedCtrls.length}
-
- {req.title}
-
- )
- })}
- {filteredRequirements.length > 15 && (
-
- +{filteredRequirements.length - 15} {lang === 'de' ? 'weitere' : 'more'}
-
- )}
-
-
-
+
)}
{/* Legend */}
-
-
-
-
-
+ {(['full', 'partial', 'planned'] as const).map((level) => (
+
+
+ {level === 'full' && (
)}
+ {level === 'partial' && (
)}
+ {level === 'planned' && (
)}
+
+
+ {lang === 'de' ? (level === 'full' ? 'Vollstaendig abgedeckt' : level === 'partial' ? 'Teilweise abgedeckt' : 'Geplant') : (level === 'full' ? 'Fully covered' : level === 'partial' ? 'Partially covered' : 'Planned')}
+
-
- {lang === 'de' ? 'Vollstaendig abgedeckt' : 'Fully covered'}
-
-
-
-
-
- {lang === 'de' ? 'Teilweise abgedeckt' : 'Partially covered'}
-
-
-
-
-
- {lang === 'de' ? 'Geplant' : 'Planned'}
-
-
+ ))}
diff --git a/website/components/compliance/charts/DependencyMapSankey.tsx b/website/components/compliance/charts/DependencyMapSankey.tsx
new file mode 100644
index 0000000..1e639e9
--- /dev/null
+++ b/website/components/compliance/charts/DependencyMapSankey.tsx
@@ -0,0 +1,190 @@
+'use client'
+
+/**
+ * DependencyMap Sankey/Connection View
+ *
+ * Extracted from DependencyMap to keep each file under 500 LOC.
+ */
+
+import type { Language } from '@/lib/compliance-i18n'
+import type { Requirement, Control, Mapping } from './DependencyMapTypes'
+import { DOMAIN_COLORS, COVERAGE_COLORS } from './DependencyMapTypes'
+
+/** Props for the DependencyMapSankey connection view. */
+interface DependencyMapSankeyProps {
+ /** Controls listed in the controls column. */
+ filteredControls: Control[]
+ /** Requirements listed in the requirements column (only the first 15 are rendered). */
+ filteredRequirements: Requirement[]
+ /** Requirement list used to resolve requirement ids shown in the connection panel. */
+ requirements: Requirement[]
+ /** Control list used to resolve control ids shown in the connection panel. */
+ controls: Control[]
+ /** Mapping lookup indexed as mappingLookup[control_id][requirement id]. */
+ mappingLookup: Record<string, Record<string, Mapping>>
+ /** control_id of the currently selected control, or null. */
+ selectedControl: string | null
+ /** id of the currently selected requirement, or null. */
+ selectedRequirement: string | null
+ /** Called with the control whose entry was clicked. */
+ onControlClick: (control: Control) => void
+ /** Called with the requirement whose entry was clicked. */
+ onRequirementClick: (requirement: Requirement) => void
+ /** UI language; 'de' selects the German strings, anything else the English ones. */
+ lang: Language
+}
+
+/**
+ * Connection ("Sankey") view of the dependency map.
+ *
+ * Renders three parts: the list of filtered controls, a middle panel showing
+ * what is connected to the current selection, and the list of filtered
+ * requirements. Connections are read from `mappingLookup`, which is indexed
+ * by control_id first and requirement id second.
+ */
+export function DependencyMapSankey({
+ filteredControls,
+ filteredRequirements,
+ requirements,
+ controls,
+ mappingLookup,
+ selectedControl,
+ selectedRequirement,
+ onControlClick,
+ onRequirementClick,
+ lang,
+}: DependencyMapSankeyProps) {
+ // Requirement ids mapped to a control: the keys of that control's lookup
+ // entry, or an empty list when the control has no entry.
+ const getConnectedRequirements = (controlId: string) => {
+ return Object.keys(mappingLookup[controlId] || {})
+ }
+
+ // Controls mapped to a requirement: scans every control entry in the lookup
+ // and returns each matching control id together with its coverage level.
+ const getConnectedControls = (requirementId: string) => {
+ return Object.keys(mappingLookup)
+ .filter((controlId) => mappingLookup[controlId][requirementId])
+ .map((controlId) => ({
+ controlId,
+ coverage: mappingLookup[controlId][requirementId].coverage_level,
+ }))
+ }
+
+ return (
+
+
+ {/* Controls Column */}
+
+
+ Controls ({filteredControls.length})
+
+ {filteredControls.map((control) => {
+ // connectedReqs.length is displayed next to the control id.
+ const connectedReqs = getConnectedRequirements(control.control_id)
+ const isSelected = selectedControl === control.control_id
+
+ return (
+
onControlClick(control)}
+ className={`
+ w-full text-left p-3 rounded-lg border transition-all
+ ${isSelected ? 'border-primary-500 bg-primary-50 shadow' : 'border-slate-200 hover:border-slate-300'}
+ `}
+ >
+
+
+
{control.control_id}
+
{connectedReqs.length}
+
+ {control.title}
+
+ )
+ })}
+
+
+ {/* Connection Lines (simplified) */}
+
+
+ {selectedControl && (
+
+ {getConnectedRequirements(selectedControl).slice(0, 10).map((reqId) => {
+ // Only the first 10 connected requirements are rendered; the rest
+ // are summarized by the "+N more" line below.
+ const req = requirements.find((r) => r.id === reqId)
+ const mapping = mappingLookup[selectedControl][reqId]
+ // Skip ids that have no entry in the requirements list.
+ if (!req) return null
+
+ return (
+
+ {req.regulation_code} {req.article}
+
+ )
+ })}
+ {getConnectedRequirements(selectedControl).length > 10 && (
+
+ +{getConnectedRequirements(selectedControl).length - 10} {lang === 'de' ? 'weitere' : 'more'}
+
+ )}
+
+ )}
+ {selectedRequirement && (
+
+ {getConnectedControls(selectedRequirement).slice(0, 10).map(({ controlId, coverage }) => {
+ // NOTE(review): capped at 10 like the selectedControl branch, but
+ // no "+N more" line is rendered here when more controls exist.
+ const control = controls.find((c) => c.control_id === controlId)
+ // Skip ids that have no entry in the controls list.
+ if (!control) return null
+
+ return (
+
+ {control.control_id}
+
+ )
+ })}
+
+ )}
+ {!selectedControl && !selectedRequirement && (
+
+
+ {lang === 'de'
+ ? 'Waehlen Sie ein Control oder eine Anforderung aus'
+ : 'Select a control or requirement'}
+
+
+ )}
+
+
+
+ {/* Requirements Column */}
+
+
+ {lang === 'de' ? 'Anforderungen' : 'Requirements'} ({filteredRequirements.length})
+
+ {filteredRequirements.slice(0, 15).map((req) => {
+ // Only the first 15 requirements are rendered; the rest are
+ // summarized by the "+N more" line below.
+ const connectedCtrls = getConnectedControls(req.id)
+ const isSelected = selectedRequirement === req.id
+ // Highlighted when the selected control is mapped to this requirement.
+ const isHighlighted = selectedControl && connectedCtrls.some((c) => c.controlId === selectedControl)
+
+ return (
+
onRequirementClick(req)}
+ className={`
+ w-full text-left p-3 rounded-lg border transition-all
+ ${isSelected ? 'border-primary-500 bg-primary-50 shadow' : ''}
+ ${isHighlighted && !isSelected ? 'border-primary-300 bg-primary-25' : ''}
+ ${!isSelected && !isHighlighted ? 'border-slate-200 hover:border-slate-300' : ''}
+ `}
+ >
+
+ {req.regulation_code}
+ {req.article}
+ {connectedCtrls.length}
+
+ {req.title}
+
+ )
+ })}
+ {filteredRequirements.length > 15 && (
+
+ +{filteredRequirements.length - 15} {lang === 'de' ? 'weitere' : 'more'}
+
+ )}
+
+
+
+ )
+}
diff --git a/website/components/compliance/charts/DependencyMapTypes.ts b/website/components/compliance/charts/DependencyMapTypes.ts
new file mode 100644
index 0000000..000e2d4
--- /dev/null
+++ b/website/components/compliance/charts/DependencyMapTypes.ts
@@ -0,0 +1,72 @@
+/**
+ * Types and constants for DependencyMap component.
+ */
+
+import type { Language } from '@/lib/compliance-i18n'
+
+/** A requirement shown in the dependency map, identified by `id`. */
+export interface Requirement {
+ id: string
+ article: string
+ title: string
+ regulation_code: string
+}
+
+/** A control shown in the dependency map; the views select and look it up by `control_id`. */
+export interface Control {
+ id: string
+ control_id: string
+ title: string
+ domain: string
+ status: string
+}
+
+/** Links a requirement to a control with a coverage level (presumably Control.control_id; verify). */
+export interface Mapping {
+ requirement_id: string
+ control_id: string
+ coverage_level: 'full' | 'partial' | 'planned'
+}
+
+/** Props of the DependencyMap component; `lang` and both click callbacks are optional. */
+export interface DependencyMapProps {
+ requirements: Requirement[]
+ controls: Control[]
+ mappings: Mapping[]
+ lang?: Language
+ onControlClick?: (control: Control) => void
+ onRequirementClick?: (requirement: Requirement) => void
+}
+
+/** Hex color per domain key (presumably matched against Control.domain; verify against callers). */
+export const DOMAIN_COLORS: Record<string, string> = {
+ gov: '#64748b',
+ priv: '#3b82f6',
+ iam: '#a855f7',
+ crypto: '#eab308',
+ sdlc: '#22c55e',
+ ops: '#f97316',
+ ai: '#ec4899',
+ cra: '#06b6d4',
+ aud: '#6366f1',
+}
+
+/** Display label per domain key; unknown keys have no entry, so callers need a fallback. */
+export const DOMAIN_LABELS: Record<string, string> = {
+ gov: 'Governance',
+ priv: 'Datenschutz',
+ iam: 'Identity & Access',
+ crypto: 'Kryptografie',
+ sdlc: 'Secure Dev',
+ ops: 'Operations',
+ ai: 'KI-spezifisch',
+ cra: 'Supply Chain',
+ aud: 'Audit',
+}
+
+/** Background, border and text utility class names per coverage level ('full' | 'partial' | 'planned'). */
+export const COVERAGE_COLORS: Record<string, { bg: string; border: string; text: string }> = {
+ full: { bg: 'bg-green-100', border: 'border-green-500', text: 'text-green-700' },
+ partial: { bg: 'bg-yellow-100', border: 'border-yellow-500', text: 'text-yellow-700' },
+ planned: { bg: 'bg-slate-100', border: 'border-slate-400', text: 'text-slate-600' },
+}