[split-required] Split final 43 files (500-668 LOC) to complete refactoring
klausur-service (11 files): - cv_gutter_repair, ocr_pipeline_regression, upload_api - ocr_pipeline_sessions, smart_spell, nru_worksheet_generator - ocr_pipeline_overlays, mail/aggregator, zeugnis_api - cv_syllable_detect, self_rag backend-lehrer (17 files): - classroom_engine/suggestions, generators/quiz_generator - worksheets_api, llm_gateway/comparison, state_engine_api - classroom/models (→ 4 submodules), services/file_processor - alerts_agent/api/wizard+digests+routes, content_generators/pdf - classroom/routes/sessions, llm_gateway/inference - classroom_engine/analytics, auth/keycloak_auth - alerts_agent/processing/rule_engine, ai_processor/print_versions agent-core (5 files): - brain/memory_store, brain/knowledge_graph, brain/context_manager - orchestrator/supervisor, sessions/session_manager admin-lehrer (5 components): - GridOverlay, StepGridReview, DevOpsPipelineSidebar - DataFlowDiagram, sbom/wizard/page website (2 files): - DependencyMap, lehrer/abitur-archiv Other: nibis_ingestion, grid_detection_service, export-doclayout-onnx Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -7,15 +7,31 @@ Provides:
|
||||
- KnowledgeGraph: Entity relationships and semantic connections
|
||||
"""
|
||||
|
||||
from agent_core.brain.memory_store import MemoryStore, Memory
|
||||
from agent_core.brain.context_manager import ConversationContext, ContextManager
|
||||
from agent_core.brain.knowledge_graph import KnowledgeGraph, Entity, Relationship
|
||||
from agent_core.brain.memory_models import Memory
|
||||
from agent_core.brain.memory_store import MemoryStore
|
||||
from agent_core.brain.context_models import (
|
||||
MessageRole,
|
||||
Message,
|
||||
ConversationContext,
|
||||
)
|
||||
from agent_core.brain.context_manager import ContextManager
|
||||
from agent_core.brain.knowledge_models import (
|
||||
EntityType,
|
||||
RelationshipType,
|
||||
Entity,
|
||||
Relationship,
|
||||
)
|
||||
from agent_core.brain.knowledge_graph import KnowledgeGraph
|
||||
|
||||
__all__ = [
|
||||
"MemoryStore",
|
||||
"Memory",
|
||||
"MessageRole",
|
||||
"Message",
|
||||
"ConversationContext",
|
||||
"ContextManager",
|
||||
"EntityType",
|
||||
"RelationshipType",
|
||||
"KnowledgeGraph",
|
||||
"Entity",
|
||||
"Relationship",
|
||||
|
||||
@@ -1,317 +1,22 @@
|
||||
"""
|
||||
Context Management for Breakpilot Agents
|
||||
|
||||
Provides conversation context with:
|
||||
- Message history with compression
|
||||
- Entity extraction and tracking
|
||||
- Intent history
|
||||
- Context summarization
|
||||
Manages conversation contexts for multiple sessions with persistence.
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, List, Optional, Callable, Awaitable
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
from enum import Enum
|
||||
import json
|
||||
import logging
|
||||
|
||||
from agent_core.brain.context_models import (
|
||||
MessageRole,
|
||||
Message,
|
||||
ConversationContext,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MessageRole(Enum):
|
||||
"""Message roles in a conversation"""
|
||||
SYSTEM = "system"
|
||||
USER = "user"
|
||||
ASSISTANT = "assistant"
|
||||
TOOL = "tool"
|
||||
|
||||
|
||||
@dataclass
|
||||
class Message:
|
||||
"""Represents a message in a conversation"""
|
||||
role: MessageRole
|
||||
content: str
|
||||
timestamp: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"role": self.role.value,
|
||||
"content": self.content,
|
||||
"timestamp": self.timestamp.isoformat(),
|
||||
"metadata": self.metadata
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "Message":
|
||||
return cls(
|
||||
role=MessageRole(data["role"]),
|
||||
content=data["content"],
|
||||
timestamp=datetime.fromisoformat(data["timestamp"]) if "timestamp" in data else datetime.now(timezone.utc),
|
||||
metadata=data.get("metadata", {})
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConversationContext:
|
||||
"""
|
||||
Context for a running conversation.
|
||||
|
||||
Maintains:
|
||||
- Message history with automatic compression
|
||||
- Extracted entities
|
||||
- Intent history
|
||||
- Conversation summary
|
||||
"""
|
||||
messages: List[Message] = field(default_factory=list)
|
||||
entities: Dict[str, Any] = field(default_factory=dict)
|
||||
intent_history: List[str] = field(default_factory=list)
|
||||
summary: Optional[str] = None
|
||||
max_messages: int = 50
|
||||
system_prompt: Optional[str] = None
|
||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def add_message(
|
||||
self,
|
||||
role: MessageRole,
|
||||
content: str,
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
) -> Message:
|
||||
"""
|
||||
Adds a message to the conversation.
|
||||
|
||||
Args:
|
||||
role: Message role
|
||||
content: Message content
|
||||
metadata: Optional message metadata
|
||||
|
||||
Returns:
|
||||
The created Message
|
||||
"""
|
||||
message = Message(
|
||||
role=role,
|
||||
content=content,
|
||||
metadata=metadata or {}
|
||||
)
|
||||
self.messages.append(message)
|
||||
|
||||
# Compress if needed
|
||||
if len(self.messages) > self.max_messages:
|
||||
self._compress_history()
|
||||
|
||||
return message
|
||||
|
||||
def add_user_message(
|
||||
self,
|
||||
content: str,
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
) -> Message:
|
||||
"""Convenience method to add a user message"""
|
||||
return self.add_message(MessageRole.USER, content, metadata)
|
||||
|
||||
def add_assistant_message(
|
||||
self,
|
||||
content: str,
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
) -> Message:
|
||||
"""Convenience method to add an assistant message"""
|
||||
return self.add_message(MessageRole.ASSISTANT, content, metadata)
|
||||
|
||||
def add_system_message(
|
||||
self,
|
||||
content: str,
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
) -> Message:
|
||||
"""Convenience method to add a system message"""
|
||||
return self.add_message(MessageRole.SYSTEM, content, metadata)
|
||||
|
||||
def add_intent(self, intent: str) -> None:
|
||||
"""
|
||||
Records an intent in the history.
|
||||
|
||||
Args:
|
||||
intent: The detected intent
|
||||
"""
|
||||
self.intent_history.append(intent)
|
||||
# Keep last 20 intents
|
||||
if len(self.intent_history) > 20:
|
||||
self.intent_history = self.intent_history[-20:]
|
||||
|
||||
def set_entity(self, name: str, value: Any) -> None:
|
||||
"""
|
||||
Sets an entity value.
|
||||
|
||||
Args:
|
||||
name: Entity name
|
||||
value: Entity value
|
||||
"""
|
||||
self.entities[name] = value
|
||||
|
||||
def get_entity(self, name: str, default: Any = None) -> Any:
|
||||
"""
|
||||
Gets an entity value.
|
||||
|
||||
Args:
|
||||
name: Entity name
|
||||
default: Default value if not found
|
||||
|
||||
Returns:
|
||||
Entity value or default
|
||||
"""
|
||||
return self.entities.get(name, default)
|
||||
|
||||
def get_last_message(self, role: Optional[MessageRole] = None) -> Optional[Message]:
|
||||
"""
|
||||
Gets the last message, optionally filtered by role.
|
||||
|
||||
Args:
|
||||
role: Optional role filter
|
||||
|
||||
Returns:
|
||||
The last matching message or None
|
||||
"""
|
||||
if not self.messages:
|
||||
return None
|
||||
|
||||
if role is None:
|
||||
return self.messages[-1]
|
||||
|
||||
for msg in reversed(self.messages):
|
||||
if msg.role == role:
|
||||
return msg
|
||||
|
||||
return None
|
||||
|
||||
def get_messages_for_llm(self) -> List[Dict[str, str]]:
|
||||
"""
|
||||
Gets messages formatted for LLM API calls.
|
||||
|
||||
Returns:
|
||||
List of message dicts with role and content
|
||||
"""
|
||||
result = []
|
||||
|
||||
# Add system prompt first
|
||||
if self.system_prompt:
|
||||
result.append({
|
||||
"role": "system",
|
||||
"content": self.system_prompt
|
||||
})
|
||||
|
||||
# Add summary if we have one and history was compressed
|
||||
if self.summary:
|
||||
result.append({
|
||||
"role": "system",
|
||||
"content": f"Previous conversation summary: {self.summary}"
|
||||
})
|
||||
|
||||
# Add recent messages
|
||||
for msg in self.messages:
|
||||
result.append({
|
||||
"role": msg.role.value,
|
||||
"content": msg.content
|
||||
})
|
||||
|
||||
return result
|
||||
|
||||
def _compress_history(self) -> None:
|
||||
"""
|
||||
Compresses older messages to save context window space.
|
||||
|
||||
Keeps:
|
||||
- System messages
|
||||
- Last 20 messages
|
||||
- Creates summary of compressed middle messages
|
||||
"""
|
||||
# Keep system messages
|
||||
system_msgs = [m for m in self.messages if m.role == MessageRole.SYSTEM]
|
||||
|
||||
# Keep last 20 messages
|
||||
recent_msgs = self.messages[-20:]
|
||||
|
||||
# Middle messages to summarize
|
||||
middle_start = len(system_msgs)
|
||||
middle_end = len(self.messages) - 20
|
||||
middle_msgs = self.messages[middle_start:middle_end]
|
||||
|
||||
if middle_msgs:
|
||||
# Create a basic summary (can be enhanced with LLM-based summarization)
|
||||
self.summary = self._create_summary(middle_msgs)
|
||||
|
||||
# Combine
|
||||
self.messages = system_msgs + recent_msgs
|
||||
|
||||
logger.debug(
|
||||
f"Compressed conversation: {middle_end - middle_start} messages summarized"
|
||||
)
|
||||
|
||||
def _create_summary(self, messages: List[Message]) -> str:
|
||||
"""
|
||||
Creates a summary of messages.
|
||||
|
||||
This is a basic implementation - can be enhanced with LLM-based summarization.
|
||||
|
||||
Args:
|
||||
messages: Messages to summarize
|
||||
|
||||
Returns:
|
||||
Summary string
|
||||
"""
|
||||
# Count message types
|
||||
user_count = sum(1 for m in messages if m.role == MessageRole.USER)
|
||||
assistant_count = sum(1 for m in messages if m.role == MessageRole.ASSISTANT)
|
||||
|
||||
# Extract key topics (simplified - could use NLP)
|
||||
topics = set()
|
||||
for msg in messages:
|
||||
# Simple keyword extraction
|
||||
words = msg.content.lower().split()
|
||||
# Filter common words
|
||||
keywords = [w for w in words if len(w) > 5][:3]
|
||||
topics.update(keywords)
|
||||
|
||||
topics_str = ", ".join(list(topics)[:5])
|
||||
|
||||
return (
|
||||
f"Earlier conversation: {user_count} user messages, "
|
||||
f"{assistant_count} assistant responses. "
|
||||
f"Topics discussed: {topics_str}"
|
||||
)
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Clears all context"""
|
||||
self.messages.clear()
|
||||
self.entities.clear()
|
||||
self.intent_history.clear()
|
||||
self.summary = None
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Serializes context to dict"""
|
||||
return {
|
||||
"messages": [m.to_dict() for m in self.messages],
|
||||
"entities": self.entities,
|
||||
"intent_history": self.intent_history,
|
||||
"summary": self.summary,
|
||||
"max_messages": self.max_messages,
|
||||
"system_prompt": self.system_prompt,
|
||||
"metadata": self.metadata
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "ConversationContext":
|
||||
"""Deserializes context from dict"""
|
||||
ctx = cls(
|
||||
messages=[Message.from_dict(m) for m in data.get("messages", [])],
|
||||
entities=data.get("entities", {}),
|
||||
intent_history=data.get("intent_history", []),
|
||||
summary=data.get("summary"),
|
||||
max_messages=data.get("max_messages", 50),
|
||||
system_prompt=data.get("system_prompt"),
|
||||
metadata=data.get("metadata", {})
|
||||
)
|
||||
return ctx
|
||||
|
||||
|
||||
class ContextManager:
|
||||
"""
|
||||
Manages conversation contexts for multiple sessions.
|
||||
|
||||
307
agent-core/brain/context_models.py
Normal file
307
agent-core/brain/context_models.py
Normal file
@@ -0,0 +1,307 @@
|
||||
"""
|
||||
Context Models for Breakpilot Agents
|
||||
|
||||
Data classes for conversation messages and context management.
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, List, Optional
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
from enum import Enum
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MessageRole(Enum):
|
||||
"""Message roles in a conversation"""
|
||||
SYSTEM = "system"
|
||||
USER = "user"
|
||||
ASSISTANT = "assistant"
|
||||
TOOL = "tool"
|
||||
|
||||
|
||||
@dataclass
|
||||
class Message:
|
||||
"""Represents a message in a conversation"""
|
||||
role: MessageRole
|
||||
content: str
|
||||
timestamp: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"role": self.role.value,
|
||||
"content": self.content,
|
||||
"timestamp": self.timestamp.isoformat(),
|
||||
"metadata": self.metadata
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "Message":
|
||||
return cls(
|
||||
role=MessageRole(data["role"]),
|
||||
content=data["content"],
|
||||
timestamp=datetime.fromisoformat(data["timestamp"]) if "timestamp" in data else datetime.now(timezone.utc),
|
||||
metadata=data.get("metadata", {})
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConversationContext:
|
||||
"""
|
||||
Context for a running conversation.
|
||||
|
||||
Maintains:
|
||||
- Message history with automatic compression
|
||||
- Extracted entities
|
||||
- Intent history
|
||||
- Conversation summary
|
||||
"""
|
||||
messages: List[Message] = field(default_factory=list)
|
||||
entities: Dict[str, Any] = field(default_factory=dict)
|
||||
intent_history: List[str] = field(default_factory=list)
|
||||
summary: Optional[str] = None
|
||||
max_messages: int = 50
|
||||
system_prompt: Optional[str] = None
|
||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def add_message(
|
||||
self,
|
||||
role: MessageRole,
|
||||
content: str,
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
) -> Message:
|
||||
"""
|
||||
Adds a message to the conversation.
|
||||
|
||||
Args:
|
||||
role: Message role
|
||||
content: Message content
|
||||
metadata: Optional message metadata
|
||||
|
||||
Returns:
|
||||
The created Message
|
||||
"""
|
||||
message = Message(
|
||||
role=role,
|
||||
content=content,
|
||||
metadata=metadata or {}
|
||||
)
|
||||
self.messages.append(message)
|
||||
|
||||
# Compress if needed
|
||||
if len(self.messages) > self.max_messages:
|
||||
self._compress_history()
|
||||
|
||||
return message
|
||||
|
||||
def add_user_message(
|
||||
self,
|
||||
content: str,
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
) -> Message:
|
||||
"""Convenience method to add a user message"""
|
||||
return self.add_message(MessageRole.USER, content, metadata)
|
||||
|
||||
def add_assistant_message(
|
||||
self,
|
||||
content: str,
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
) -> Message:
|
||||
"""Convenience method to add an assistant message"""
|
||||
return self.add_message(MessageRole.ASSISTANT, content, metadata)
|
||||
|
||||
def add_system_message(
|
||||
self,
|
||||
content: str,
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
) -> Message:
|
||||
"""Convenience method to add a system message"""
|
||||
return self.add_message(MessageRole.SYSTEM, content, metadata)
|
||||
|
||||
def add_intent(self, intent: str) -> None:
|
||||
"""
|
||||
Records an intent in the history.
|
||||
|
||||
Args:
|
||||
intent: The detected intent
|
||||
"""
|
||||
self.intent_history.append(intent)
|
||||
# Keep last 20 intents
|
||||
if len(self.intent_history) > 20:
|
||||
self.intent_history = self.intent_history[-20:]
|
||||
|
||||
def set_entity(self, name: str, value: Any) -> None:
|
||||
"""
|
||||
Sets an entity value.
|
||||
|
||||
Args:
|
||||
name: Entity name
|
||||
value: Entity value
|
||||
"""
|
||||
self.entities[name] = value
|
||||
|
||||
def get_entity(self, name: str, default: Any = None) -> Any:
|
||||
"""
|
||||
Gets an entity value.
|
||||
|
||||
Args:
|
||||
name: Entity name
|
||||
default: Default value if not found
|
||||
|
||||
Returns:
|
||||
Entity value or default
|
||||
"""
|
||||
return self.entities.get(name, default)
|
||||
|
||||
def get_last_message(self, role: Optional[MessageRole] = None) -> Optional[Message]:
|
||||
"""
|
||||
Gets the last message, optionally filtered by role.
|
||||
|
||||
Args:
|
||||
role: Optional role filter
|
||||
|
||||
Returns:
|
||||
The last matching message or None
|
||||
"""
|
||||
if not self.messages:
|
||||
return None
|
||||
|
||||
if role is None:
|
||||
return self.messages[-1]
|
||||
|
||||
for msg in reversed(self.messages):
|
||||
if msg.role == role:
|
||||
return msg
|
||||
|
||||
return None
|
||||
|
||||
def get_messages_for_llm(self) -> List[Dict[str, str]]:
|
||||
"""
|
||||
Gets messages formatted for LLM API calls.
|
||||
|
||||
Returns:
|
||||
List of message dicts with role and content
|
||||
"""
|
||||
result = []
|
||||
|
||||
# Add system prompt first
|
||||
if self.system_prompt:
|
||||
result.append({
|
||||
"role": "system",
|
||||
"content": self.system_prompt
|
||||
})
|
||||
|
||||
# Add summary if we have one and history was compressed
|
||||
if self.summary:
|
||||
result.append({
|
||||
"role": "system",
|
||||
"content": f"Previous conversation summary: {self.summary}"
|
||||
})
|
||||
|
||||
# Add recent messages
|
||||
for msg in self.messages:
|
||||
result.append({
|
||||
"role": msg.role.value,
|
||||
"content": msg.content
|
||||
})
|
||||
|
||||
return result
|
||||
|
||||
def _compress_history(self) -> None:
|
||||
"""
|
||||
Compresses older messages to save context window space.
|
||||
|
||||
Keeps:
|
||||
- System messages
|
||||
- Last 20 messages
|
||||
- Creates summary of compressed middle messages
|
||||
"""
|
||||
# Keep system messages
|
||||
system_msgs = [m for m in self.messages if m.role == MessageRole.SYSTEM]
|
||||
|
||||
# Keep last 20 messages
|
||||
recent_msgs = self.messages[-20:]
|
||||
|
||||
# Middle messages to summarize
|
||||
middle_start = len(system_msgs)
|
||||
middle_end = len(self.messages) - 20
|
||||
middle_msgs = self.messages[middle_start:middle_end]
|
||||
|
||||
if middle_msgs:
|
||||
# Create a basic summary (can be enhanced with LLM-based summarization)
|
||||
self.summary = self._create_summary(middle_msgs)
|
||||
|
||||
# Combine
|
||||
self.messages = system_msgs + recent_msgs
|
||||
|
||||
logger.debug(
|
||||
f"Compressed conversation: {middle_end - middle_start} messages summarized"
|
||||
)
|
||||
|
||||
def _create_summary(self, messages: List[Message]) -> str:
|
||||
"""
|
||||
Creates a summary of messages.
|
||||
|
||||
This is a basic implementation - can be enhanced with LLM-based summarization.
|
||||
|
||||
Args:
|
||||
messages: Messages to summarize
|
||||
|
||||
Returns:
|
||||
Summary string
|
||||
"""
|
||||
# Count message types
|
||||
user_count = sum(1 for m in messages if m.role == MessageRole.USER)
|
||||
assistant_count = sum(1 for m in messages if m.role == MessageRole.ASSISTANT)
|
||||
|
||||
# Extract key topics (simplified - could use NLP)
|
||||
topics = set()
|
||||
for msg in messages:
|
||||
# Simple keyword extraction
|
||||
words = msg.content.lower().split()
|
||||
# Filter common words
|
||||
keywords = [w for w in words if len(w) > 5][:3]
|
||||
topics.update(keywords)
|
||||
|
||||
topics_str = ", ".join(list(topics)[:5])
|
||||
|
||||
return (
|
||||
f"Earlier conversation: {user_count} user messages, "
|
||||
f"{assistant_count} assistant responses. "
|
||||
f"Topics discussed: {topics_str}"
|
||||
)
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Clears all context"""
|
||||
self.messages.clear()
|
||||
self.entities.clear()
|
||||
self.intent_history.clear()
|
||||
self.summary = None
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Serializes context to dict"""
|
||||
return {
|
||||
"messages": [m.to_dict() for m in self.messages],
|
||||
"entities": self.entities,
|
||||
"intent_history": self.intent_history,
|
||||
"summary": self.summary,
|
||||
"max_messages": self.max_messages,
|
||||
"system_prompt": self.system_prompt,
|
||||
"metadata": self.metadata
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "ConversationContext":
|
||||
"""Deserializes context from dict"""
|
||||
ctx = cls(
|
||||
messages=[Message.from_dict(m) for m in data.get("messages", [])],
|
||||
entities=data.get("entities", {}),
|
||||
intent_history=data.get("intent_history", []),
|
||||
summary=data.get("summary"),
|
||||
max_messages=data.get("max_messages", 50),
|
||||
system_prompt=data.get("system_prompt"),
|
||||
metadata=data.get("metadata", {})
|
||||
)
|
||||
return ctx
|
||||
@@ -9,109 +9,20 @@ Provides entity and relationship management:
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, List, Optional, Set, Tuple
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
from enum import Enum
|
||||
import json
|
||||
import logging
|
||||
|
||||
from agent_core.brain.knowledge_models import (
|
||||
EntityType,
|
||||
RelationshipType,
|
||||
Entity,
|
||||
Relationship,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EntityType(Enum):
|
||||
"""Types of entities in the knowledge graph"""
|
||||
STUDENT = "student"
|
||||
TEACHER = "teacher"
|
||||
CLASS = "class"
|
||||
SUBJECT = "subject"
|
||||
ASSIGNMENT = "assignment"
|
||||
EXAM = "exam"
|
||||
TOPIC = "topic"
|
||||
CONCEPT = "concept"
|
||||
RESOURCE = "resource"
|
||||
CUSTOM = "custom"
|
||||
|
||||
|
||||
class RelationshipType(Enum):
|
||||
"""Types of relationships between entities"""
|
||||
BELONGS_TO = "belongs_to" # Student belongs to class
|
||||
TEACHES = "teaches" # Teacher teaches subject
|
||||
ASSIGNED_TO = "assigned_to" # Assignment assigned to student
|
||||
COVERS = "covers" # Exam covers topic
|
||||
REQUIRES = "requires" # Topic requires concept
|
||||
RELATED_TO = "related_to" # General relationship
|
||||
PARENT_OF = "parent_of" # Hierarchical relationship
|
||||
CREATED_BY = "created_by" # Creator relationship
|
||||
GRADED_BY = "graded_by" # Grading relationship
|
||||
|
||||
|
||||
@dataclass
|
||||
class Entity:
|
||||
"""Represents an entity in the knowledge graph"""
|
||||
id: str
|
||||
entity_type: EntityType
|
||||
name: str
|
||||
properties: Dict[str, Any] = field(default_factory=dict)
|
||||
created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
updated_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"id": self.id,
|
||||
"entity_type": self.entity_type.value,
|
||||
"name": self.name,
|
||||
"properties": self.properties,
|
||||
"created_at": self.created_at.isoformat(),
|
||||
"updated_at": self.updated_at.isoformat()
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "Entity":
|
||||
return cls(
|
||||
id=data["id"],
|
||||
entity_type=EntityType(data["entity_type"]),
|
||||
name=data["name"],
|
||||
properties=data.get("properties", {}),
|
||||
created_at=datetime.fromisoformat(data["created_at"]),
|
||||
updated_at=datetime.fromisoformat(data["updated_at"])
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Relationship:
|
||||
"""Represents a relationship between two entities"""
|
||||
id: str
|
||||
source_id: str
|
||||
target_id: str
|
||||
relationship_type: RelationshipType
|
||||
properties: Dict[str, Any] = field(default_factory=dict)
|
||||
weight: float = 1.0
|
||||
created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"id": self.id,
|
||||
"source_id": self.source_id,
|
||||
"target_id": self.target_id,
|
||||
"relationship_type": self.relationship_type.value,
|
||||
"properties": self.properties,
|
||||
"weight": self.weight,
|
||||
"created_at": self.created_at.isoformat()
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "Relationship":
|
||||
return cls(
|
||||
id=data["id"],
|
||||
source_id=data["source_id"],
|
||||
target_id=data["target_id"],
|
||||
relationship_type=RelationshipType(data["relationship_type"]),
|
||||
properties=data.get("properties", {}),
|
||||
weight=data.get("weight", 1.0),
|
||||
created_at=datetime.fromisoformat(data["created_at"])
|
||||
)
|
||||
|
||||
|
||||
class KnowledgeGraph:
|
||||
"""
|
||||
Knowledge graph for managing entity relationships.
|
||||
|
||||
104
agent-core/brain/knowledge_models.py
Normal file
104
agent-core/brain/knowledge_models.py
Normal file
@@ -0,0 +1,104 @@
|
||||
"""
|
||||
Knowledge Graph Models for Breakpilot Agents
|
||||
|
||||
Entity and relationship data classes, plus type enumerations.
|
||||
"""
|
||||
|
||||
from typing import Dict, Any
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class EntityType(Enum):
|
||||
"""Types of entities in the knowledge graph"""
|
||||
STUDENT = "student"
|
||||
TEACHER = "teacher"
|
||||
CLASS = "class"
|
||||
SUBJECT = "subject"
|
||||
ASSIGNMENT = "assignment"
|
||||
EXAM = "exam"
|
||||
TOPIC = "topic"
|
||||
CONCEPT = "concept"
|
||||
RESOURCE = "resource"
|
||||
CUSTOM = "custom"
|
||||
|
||||
|
||||
class RelationshipType(Enum):
|
||||
"""Types of relationships between entities"""
|
||||
BELONGS_TO = "belongs_to" # Student belongs to class
|
||||
TEACHES = "teaches" # Teacher teaches subject
|
||||
ASSIGNED_TO = "assigned_to" # Assignment assigned to student
|
||||
COVERS = "covers" # Exam covers topic
|
||||
REQUIRES = "requires" # Topic requires concept
|
||||
RELATED_TO = "related_to" # General relationship
|
||||
PARENT_OF = "parent_of" # Hierarchical relationship
|
||||
CREATED_BY = "created_by" # Creator relationship
|
||||
GRADED_BY = "graded_by" # Grading relationship
|
||||
|
||||
|
||||
@dataclass
|
||||
class Entity:
|
||||
"""Represents an entity in the knowledge graph"""
|
||||
id: str
|
||||
entity_type: EntityType
|
||||
name: str
|
||||
properties: Dict[str, Any] = field(default_factory=dict)
|
||||
created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
updated_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"id": self.id,
|
||||
"entity_type": self.entity_type.value,
|
||||
"name": self.name,
|
||||
"properties": self.properties,
|
||||
"created_at": self.created_at.isoformat(),
|
||||
"updated_at": self.updated_at.isoformat()
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "Entity":
|
||||
return cls(
|
||||
id=data["id"],
|
||||
entity_type=EntityType(data["entity_type"]),
|
||||
name=data["name"],
|
||||
properties=data.get("properties", {}),
|
||||
created_at=datetime.fromisoformat(data["created_at"]),
|
||||
updated_at=datetime.fromisoformat(data["updated_at"])
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Relationship:
|
||||
"""Represents a relationship between two entities"""
|
||||
id: str
|
||||
source_id: str
|
||||
target_id: str
|
||||
relationship_type: RelationshipType
|
||||
properties: Dict[str, Any] = field(default_factory=dict)
|
||||
weight: float = 1.0
|
||||
created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"id": self.id,
|
||||
"source_id": self.source_id,
|
||||
"target_id": self.target_id,
|
||||
"relationship_type": self.relationship_type.value,
|
||||
"properties": self.properties,
|
||||
"weight": self.weight,
|
||||
"created_at": self.created_at.isoformat()
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "Relationship":
|
||||
return cls(
|
||||
id=data["id"],
|
||||
source_id=data["source_id"],
|
||||
target_id=data["target_id"],
|
||||
relationship_type=RelationshipType(data["relationship_type"]),
|
||||
properties=data.get("properties", {}),
|
||||
weight=data.get("weight", 1.0),
|
||||
created_at=datetime.fromisoformat(data["created_at"])
|
||||
)
|
||||
53
agent-core/brain/memory_models.py
Normal file
53
agent-core/brain/memory_models.py
Normal file
@@ -0,0 +1,53 @@
|
||||
"""
|
||||
Memory Models for Breakpilot Agents
|
||||
|
||||
Data classes for memory items used by MemoryStore.
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, Optional
|
||||
from datetime import datetime, timezone
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
|
||||
@dataclass
|
||||
class Memory:
|
||||
"""Represents a stored memory item"""
|
||||
key: str
|
||||
value: Any
|
||||
agent_id: str
|
||||
created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
expires_at: Optional[datetime] = None
|
||||
access_count: int = 0
|
||||
last_accessed: Optional[datetime] = None
|
||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"key": self.key,
|
||||
"value": self.value,
|
||||
"agent_id": self.agent_id,
|
||||
"created_at": self.created_at.isoformat(),
|
||||
"expires_at": self.expires_at.isoformat() if self.expires_at else None,
|
||||
"access_count": self.access_count,
|
||||
"last_accessed": self.last_accessed.isoformat() if self.last_accessed else None,
|
||||
"metadata": self.metadata
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "Memory":
|
||||
return cls(
|
||||
key=data["key"],
|
||||
value=data["value"],
|
||||
agent_id=data["agent_id"],
|
||||
created_at=datetime.fromisoformat(data["created_at"]),
|
||||
expires_at=datetime.fromisoformat(data["expires_at"]) if data.get("expires_at") else None,
|
||||
access_count=data.get("access_count", 0),
|
||||
last_accessed=datetime.fromisoformat(data["last_accessed"]) if data.get("last_accessed") else None,
|
||||
metadata=data.get("metadata", {})
|
||||
)
|
||||
|
||||
def is_expired(self) -> bool:
|
||||
"""Check if the memory has expired"""
|
||||
if not self.expires_at:
|
||||
return False
|
||||
return datetime.now(timezone.utc) > self.expires_at
|
||||
@@ -1,92 +1,24 @@
|
||||
"""
|
||||
Memory Store for Breakpilot Agents
|
||||
|
||||
Provides long-term memory with:
|
||||
- TTL-based expiration
|
||||
- Access count tracking
|
||||
- Pattern-based search
|
||||
- Hybrid Valkey + PostgreSQL persistence
|
||||
Hybrid Valkey + PostgreSQL persistence with TTL, access tracking, and pattern search.
|
||||
"""
|
||||
|
||||
from typing import List, Dict, Any, Optional
|
||||
from datetime import datetime, timezone, timedelta
|
||||
from dataclasses import dataclass, field
|
||||
import json
|
||||
import logging
|
||||
import hashlib
|
||||
|
||||
from agent_core.brain.memory_models import Memory
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Memory:
|
||||
"""Represents a stored memory item"""
|
||||
key: str
|
||||
value: Any
|
||||
agent_id: str
|
||||
created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
expires_at: Optional[datetime] = None
|
||||
access_count: int = 0
|
||||
last_accessed: Optional[datetime] = None
|
||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"key": self.key,
|
||||
"value": self.value,
|
||||
"agent_id": self.agent_id,
|
||||
"created_at": self.created_at.isoformat(),
|
||||
"expires_at": self.expires_at.isoformat() if self.expires_at else None,
|
||||
"access_count": self.access_count,
|
||||
"last_accessed": self.last_accessed.isoformat() if self.last_accessed else None,
|
||||
"metadata": self.metadata
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "Memory":
|
||||
return cls(
|
||||
key=data["key"],
|
||||
value=data["value"],
|
||||
agent_id=data["agent_id"],
|
||||
created_at=datetime.fromisoformat(data["created_at"]),
|
||||
expires_at=datetime.fromisoformat(data["expires_at"]) if data.get("expires_at") else None,
|
||||
access_count=data.get("access_count", 0),
|
||||
last_accessed=datetime.fromisoformat(data["last_accessed"]) if data.get("last_accessed") else None,
|
||||
metadata=data.get("metadata", {})
|
||||
)
|
||||
|
||||
def is_expired(self) -> bool:
|
||||
"""Check if the memory has expired"""
|
||||
if not self.expires_at:
|
||||
return False
|
||||
return datetime.now(timezone.utc) > self.expires_at
|
||||
|
||||
|
||||
class MemoryStore:
|
||||
"""
|
||||
Long-term memory store for agents.
|
||||
"""Long-term memory store with TTL, access tracking, and hybrid persistence."""
|
||||
|
||||
Stores facts, decisions, and learning progress with:
|
||||
- TTL-based expiration
|
||||
- Access tracking for importance scoring
|
||||
- Pattern-based retrieval
|
||||
- Hybrid persistence (Valkey for fast access, PostgreSQL for durability)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
redis_client=None,
|
||||
db_pool=None,
|
||||
namespace: str = "breakpilot"
|
||||
):
|
||||
"""
|
||||
Initialize the memory store.
|
||||
|
||||
Args:
|
||||
redis_client: Async Redis/Valkey client
|
||||
db_pool: Async PostgreSQL connection pool
|
||||
namespace: Key namespace for isolation
|
||||
"""
|
||||
def __init__(self, redis_client=None, db_pool=None, namespace: str = "breakpilot"):
|
||||
self.redis = redis_client
|
||||
self.db_pool = db_pool
|
||||
self.namespace = namespace
|
||||
@@ -103,26 +35,10 @@ class MemoryStore:
|
||||
return key
|
||||
|
||||
async def remember(
|
||||
self,
|
||||
key: str,
|
||||
value: Any,
|
||||
agent_id: str,
|
||||
ttl_days: int = 30,
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
self, key: str, value: Any, agent_id: str,
|
||||
ttl_days: int = 30, metadata: Optional[Dict[str, Any]] = None
|
||||
) -> Memory:
|
||||
"""
|
||||
Stores a memory.
|
||||
|
||||
Args:
|
||||
key: Unique key for the memory
|
||||
value: Value to store (must be JSON-serializable)
|
||||
agent_id: ID of the agent storing the memory
|
||||
ttl_days: Time to live in days (0 = no expiration)
|
||||
metadata: Optional additional metadata
|
||||
|
||||
Returns:
|
||||
The created Memory object
|
||||
"""
|
||||
"""Stores a memory with optional TTL and metadata."""
|
||||
expires_at = None
|
||||
if ttl_days > 0:
|
||||
expires_at = datetime.now(timezone.utc) + timedelta(days=ttl_days)
|
||||
@@ -143,32 +59,14 @@ class MemoryStore:
|
||||
return memory
|
||||
|
||||
async def recall(self, key: str) -> Optional[Any]:
|
||||
"""
|
||||
Retrieves a memory value by key.
|
||||
|
||||
Args:
|
||||
key: The memory key
|
||||
|
||||
Returns:
|
||||
The stored value or None if not found/expired
|
||||
"""
|
||||
"""Retrieves a memory value by key, or None if not found/expired."""
|
||||
memory = await self.get_memory(key)
|
||||
if memory:
|
||||
return memory.value
|
||||
return None
|
||||
|
||||
async def get_memory(self, key: str) -> Optional[Memory]:
|
||||
"""
|
||||
Retrieves a full Memory object by key.
|
||||
|
||||
Updates access count and last_accessed timestamp.
|
||||
|
||||
Args:
|
||||
key: The memory key
|
||||
|
||||
Returns:
|
||||
Memory object or None if not found/expired
|
||||
"""
|
||||
"""Retrieves a full Memory object by key, updating access count."""
|
||||
# Check local cache
|
||||
if key in self._local_cache:
|
||||
memory = self._local_cache[key]
|
||||
|
||||
@@ -12,11 +12,13 @@ from agent_core.orchestrator.message_bus import (
|
||||
AgentMessage,
|
||||
MessagePriority,
|
||||
)
|
||||
from agent_core.orchestrator.supervisor import (
|
||||
AgentSupervisor,
|
||||
AgentInfo,
|
||||
from agent_core.orchestrator.supervisor_models import (
|
||||
AgentStatus,
|
||||
RestartPolicy,
|
||||
AgentInfo,
|
||||
AgentFactory,
|
||||
)
|
||||
from agent_core.orchestrator.supervisor import AgentSupervisor
|
||||
from agent_core.orchestrator.task_router import (
|
||||
TaskRouter,
|
||||
RoutingResult,
|
||||
@@ -30,6 +32,8 @@ __all__ = [
|
||||
"AgentSupervisor",
|
||||
"AgentInfo",
|
||||
"AgentStatus",
|
||||
"RestartPolicy",
|
||||
"AgentFactory",
|
||||
"TaskRouter",
|
||||
"RoutingResult",
|
||||
"RoutingStrategy",
|
||||
|
||||
@@ -1,17 +1,11 @@
|
||||
"""
|
||||
Agent Supervisor for Breakpilot
|
||||
|
||||
Provides:
|
||||
- Agent lifecycle management
|
||||
- Health monitoring
|
||||
- Restart policies
|
||||
- Load balancing
|
||||
Agent lifecycle management, health monitoring, restart policies, load balancing.
|
||||
"""
|
||||
|
||||
from typing import Dict, Optional, Callable, Awaitable, List, Any
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, Optional, List, Any
|
||||
from datetime import datetime, timezone, timedelta
|
||||
from enum import Enum
|
||||
import asyncio
|
||||
import logging
|
||||
|
||||
@@ -21,91 +15,24 @@ from agent_core.orchestrator.message_bus import (
|
||||
AgentMessage,
|
||||
MessagePriority,
|
||||
)
|
||||
from agent_core.orchestrator.supervisor_models import (
|
||||
AgentStatus,
|
||||
RestartPolicy,
|
||||
AgentInfo,
|
||||
AgentFactory,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AgentStatus(Enum):
|
||||
"""Agent lifecycle states"""
|
||||
INITIALIZING = "initializing"
|
||||
STARTING = "starting"
|
||||
RUNNING = "running"
|
||||
PAUSED = "paused"
|
||||
STOPPING = "stopping"
|
||||
STOPPED = "stopped"
|
||||
ERROR = "error"
|
||||
RESTARTING = "restarting"
|
||||
|
||||
|
||||
class RestartPolicy(Enum):
|
||||
"""Agent restart policies"""
|
||||
NEVER = "never"
|
||||
ON_FAILURE = "on_failure"
|
||||
ALWAYS = "always"
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentInfo:
|
||||
"""Information about a registered agent"""
|
||||
agent_id: str
|
||||
agent_type: str
|
||||
status: AgentStatus = AgentStatus.INITIALIZING
|
||||
current_task: Optional[str] = None
|
||||
started_at: Optional[datetime] = None
|
||||
last_activity: Optional[datetime] = None
|
||||
error_count: int = 0
|
||||
restart_count: int = 0
|
||||
max_restarts: int = 3
|
||||
restart_policy: RestartPolicy = RestartPolicy.ON_FAILURE
|
||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
capacity: int = 10 # Max concurrent tasks
|
||||
current_load: int = 0
|
||||
|
||||
def is_healthy(self) -> bool:
|
||||
"""Check if agent is healthy"""
|
||||
return self.status == AgentStatus.RUNNING and self.error_count < 3
|
||||
|
||||
def is_available(self) -> bool:
|
||||
"""Check if agent can accept new tasks"""
|
||||
return (
|
||||
self.status == AgentStatus.RUNNING and
|
||||
self.current_load < self.capacity
|
||||
)
|
||||
|
||||
def utilization(self) -> float:
|
||||
"""Returns agent utilization (0-1)"""
|
||||
return self.current_load / max(self.capacity, 1)
|
||||
|
||||
|
||||
AgentFactory = Callable[[str], Awaitable[Any]]
|
||||
|
||||
|
||||
class AgentSupervisor:
|
||||
"""
|
||||
Supervises and coordinates all agents.
|
||||
|
||||
Responsibilities:
|
||||
- Agent registration and lifecycle
|
||||
- Health monitoring via heartbeat
|
||||
- Restart policies
|
||||
- Load balancing
|
||||
- Alert escalation
|
||||
"""
|
||||
"""Supervises agents: lifecycle, health monitoring, restart policies, load balancing."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message_bus: MessageBus,
|
||||
self, message_bus: MessageBus,
|
||||
heartbeat_monitor: Optional[HeartbeatMonitor] = None,
|
||||
check_interval_seconds: int = 10
|
||||
):
|
||||
"""
|
||||
Initialize the supervisor.
|
||||
|
||||
Args:
|
||||
message_bus: Message bus for inter-agent communication
|
||||
heartbeat_monitor: Heartbeat monitor for liveness checks
|
||||
check_interval_seconds: How often to run health checks
|
||||
"""
|
||||
self.bus = message_bus
|
||||
self.heartbeat = heartbeat_monitor or HeartbeatMonitor()
|
||||
self.check_interval = check_interval_seconds
|
||||
|
||||
65
agent-core/orchestrator/supervisor_models.py
Normal file
65
agent-core/orchestrator/supervisor_models.py
Normal file
@@ -0,0 +1,65 @@
|
||||
"""
|
||||
Supervisor Models for Breakpilot Agents
|
||||
|
||||
Data classes and enumerations for agent lifecycle management.
|
||||
"""
|
||||
|
||||
from typing import Dict, Optional, Any, Callable, Awaitable
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class AgentStatus(Enum):
|
||||
"""Agent lifecycle states"""
|
||||
INITIALIZING = "initializing"
|
||||
STARTING = "starting"
|
||||
RUNNING = "running"
|
||||
PAUSED = "paused"
|
||||
STOPPING = "stopping"
|
||||
STOPPED = "stopped"
|
||||
ERROR = "error"
|
||||
RESTARTING = "restarting"
|
||||
|
||||
|
||||
class RestartPolicy(Enum):
|
||||
"""Agent restart policies"""
|
||||
NEVER = "never"
|
||||
ON_FAILURE = "on_failure"
|
||||
ALWAYS = "always"
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentInfo:
|
||||
"""Information about a registered agent"""
|
||||
agent_id: str
|
||||
agent_type: str
|
||||
status: AgentStatus = AgentStatus.INITIALIZING
|
||||
current_task: Optional[str] = None
|
||||
started_at: Optional[datetime] = None
|
||||
last_activity: Optional[datetime] = None
|
||||
error_count: int = 0
|
||||
restart_count: int = 0
|
||||
max_restarts: int = 3
|
||||
restart_policy: RestartPolicy = RestartPolicy.ON_FAILURE
|
||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
capacity: int = 10 # Max concurrent tasks
|
||||
current_load: int = 0
|
||||
|
||||
def is_healthy(self) -> bool:
|
||||
"""Check if agent is healthy"""
|
||||
return self.status == AgentStatus.RUNNING and self.error_count < 3
|
||||
|
||||
def is_available(self) -> bool:
|
||||
"""Check if agent can accept new tasks"""
|
||||
return (
|
||||
self.status == AgentStatus.RUNNING and
|
||||
self.current_load < self.capacity
|
||||
)
|
||||
|
||||
def utilization(self) -> float:
|
||||
"""Returns agent utilization (0-1)"""
|
||||
return self.current_load / max(self.capacity, 1)
|
||||
|
||||
|
||||
AgentFactory = Callable[[str], Awaitable[Any]]
|
||||
@@ -8,11 +8,12 @@ Provides:
|
||||
- SessionState: Session state enumeration
|
||||
"""
|
||||
|
||||
from agent_core.sessions.session_manager import (
|
||||
from agent_core.sessions.session_models import (
|
||||
AgentSession,
|
||||
SessionManager,
|
||||
SessionState,
|
||||
SessionCheckpoint,
|
||||
)
|
||||
from agent_core.sessions.session_manager import SessionManager
|
||||
from agent_core.sessions.heartbeat import HeartbeatMonitor
|
||||
from agent_core.sessions.checkpoint import CheckpointManager
|
||||
|
||||
@@ -20,6 +21,7 @@ __all__ = [
|
||||
"AgentSession",
|
||||
"SessionManager",
|
||||
"SessionState",
|
||||
"SessionCheckpoint",
|
||||
"HeartbeatMonitor",
|
||||
"CheckpointManager",
|
||||
]
|
||||
|
||||
@@ -2,189 +2,25 @@
|
||||
Session Management for Breakpilot Agents
|
||||
|
||||
Provides session lifecycle management with:
|
||||
- State tracking (ACTIVE, PAUSED, COMPLETED, FAILED)
|
||||
- Checkpoint-based recovery
|
||||
- Heartbeat integration
|
||||
- Hybrid Valkey + PostgreSQL persistence
|
||||
- Session CRUD operations
|
||||
- Stale session cleanup
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone, timedelta
|
||||
from typing import Dict, Any, Optional, List
|
||||
from enum import Enum
|
||||
import uuid
|
||||
import json
|
||||
import logging
|
||||
|
||||
from agent_core.sessions.session_models import (
|
||||
SessionState,
|
||||
SessionCheckpoint,
|
||||
AgentSession,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SessionState(Enum):
|
||||
"""Agent session states"""
|
||||
ACTIVE = "active"
|
||||
PAUSED = "paused"
|
||||
COMPLETED = "completed"
|
||||
FAILED = "failed"
|
||||
|
||||
|
||||
@dataclass
|
||||
class SessionCheckpoint:
|
||||
"""Represents a checkpoint in an agent session"""
|
||||
name: str
|
||||
timestamp: datetime
|
||||
data: Dict[str, Any]
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"name": self.name,
|
||||
"timestamp": self.timestamp.isoformat(),
|
||||
"data": self.data
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "SessionCheckpoint":
|
||||
return cls(
|
||||
name=data["name"],
|
||||
timestamp=datetime.fromisoformat(data["timestamp"]),
|
||||
data=data["data"]
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentSession:
|
||||
"""
|
||||
Represents an active agent session.
|
||||
|
||||
Attributes:
|
||||
session_id: Unique session identifier
|
||||
agent_type: Type of agent (tutor, grader, quality-judge, alert, orchestrator)
|
||||
user_id: Associated user ID
|
||||
state: Current session state
|
||||
created_at: Session creation timestamp
|
||||
last_heartbeat: Last heartbeat timestamp
|
||||
context: Session context data
|
||||
checkpoints: List of session checkpoints for recovery
|
||||
"""
|
||||
session_id: str = field(default_factory=lambda: str(uuid.uuid4()))
|
||||
agent_type: str = ""
|
||||
user_id: str = ""
|
||||
state: SessionState = SessionState.ACTIVE
|
||||
created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
last_heartbeat: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
context: Dict[str, Any] = field(default_factory=dict)
|
||||
checkpoints: List[SessionCheckpoint] = field(default_factory=list)
|
||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def checkpoint(self, name: str, data: Dict[str, Any]) -> SessionCheckpoint:
|
||||
"""
|
||||
Creates a checkpoint for recovery.
|
||||
|
||||
Args:
|
||||
name: Checkpoint name (e.g., "task_received", "processing_complete")
|
||||
data: Checkpoint data to store
|
||||
|
||||
Returns:
|
||||
The created checkpoint
|
||||
"""
|
||||
checkpoint = SessionCheckpoint(
|
||||
name=name,
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
data=data
|
||||
)
|
||||
self.checkpoints.append(checkpoint)
|
||||
logger.debug(f"Session {self.session_id}: Checkpoint '{name}' created")
|
||||
return checkpoint
|
||||
|
||||
def heartbeat(self) -> None:
|
||||
"""Updates the heartbeat timestamp"""
|
||||
self.last_heartbeat = datetime.now(timezone.utc)
|
||||
|
||||
def pause(self) -> None:
|
||||
"""Pauses the session"""
|
||||
self.state = SessionState.PAUSED
|
||||
self.checkpoint("session_paused", {"previous_state": "active"})
|
||||
|
||||
def resume(self) -> None:
|
||||
"""Resumes a paused session"""
|
||||
if self.state == SessionState.PAUSED:
|
||||
self.state = SessionState.ACTIVE
|
||||
self.heartbeat()
|
||||
self.checkpoint("session_resumed", {})
|
||||
|
||||
def complete(self, result: Optional[Dict[str, Any]] = None) -> None:
|
||||
"""Marks the session as completed"""
|
||||
self.state = SessionState.COMPLETED
|
||||
self.checkpoint("session_completed", {"result": result or {}})
|
||||
|
||||
def fail(self, error: str, error_details: Optional[Dict[str, Any]] = None) -> None:
|
||||
"""Marks the session as failed"""
|
||||
self.state = SessionState.FAILED
|
||||
self.checkpoint("session_failed", {
|
||||
"error": error,
|
||||
"details": error_details or {}
|
||||
})
|
||||
|
||||
def get_last_checkpoint(self, name: Optional[str] = None) -> Optional[SessionCheckpoint]:
|
||||
"""
|
||||
Gets the last checkpoint, optionally filtered by name.
|
||||
|
||||
Args:
|
||||
name: Optional checkpoint name to filter by
|
||||
|
||||
Returns:
|
||||
The last matching checkpoint or None
|
||||
"""
|
||||
if not self.checkpoints:
|
||||
return None
|
||||
|
||||
if name:
|
||||
matching = [cp for cp in self.checkpoints if cp.name == name]
|
||||
return matching[-1] if matching else None
|
||||
|
||||
return self.checkpoints[-1]
|
||||
|
||||
def get_duration(self) -> timedelta:
|
||||
"""Returns the session duration"""
|
||||
end_time = datetime.now(timezone.utc)
|
||||
if self.state in (SessionState.COMPLETED, SessionState.FAILED):
|
||||
last_cp = self.get_last_checkpoint()
|
||||
if last_cp:
|
||||
end_time = last_cp.timestamp
|
||||
return end_time - self.created_at
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Serializes the session to a dictionary"""
|
||||
return {
|
||||
"session_id": self.session_id,
|
||||
"agent_type": self.agent_type,
|
||||
"user_id": self.user_id,
|
||||
"state": self.state.value,
|
||||
"created_at": self.created_at.isoformat(),
|
||||
"last_heartbeat": self.last_heartbeat.isoformat(),
|
||||
"context": self.context,
|
||||
"checkpoints": [cp.to_dict() for cp in self.checkpoints],
|
||||
"metadata": self.metadata
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "AgentSession":
|
||||
"""Deserializes a session from a dictionary"""
|
||||
return cls(
|
||||
session_id=data["session_id"],
|
||||
agent_type=data["agent_type"],
|
||||
user_id=data["user_id"],
|
||||
state=SessionState(data["state"]),
|
||||
created_at=datetime.fromisoformat(data["created_at"]),
|
||||
last_heartbeat=datetime.fromisoformat(data["last_heartbeat"]),
|
||||
context=data.get("context", {}),
|
||||
checkpoints=[
|
||||
SessionCheckpoint.from_dict(cp)
|
||||
for cp in data.get("checkpoints", [])
|
||||
],
|
||||
metadata=data.get("metadata", {})
|
||||
)
|
||||
|
||||
|
||||
class SessionManager:
|
||||
"""
|
||||
Manages agent sessions with hybrid Valkey + PostgreSQL persistence.
|
||||
@@ -303,7 +139,6 @@ class SessionManager:
|
||||
"""
|
||||
session.heartbeat()
|
||||
self._local_cache[session.session_id] = session
|
||||
self._local_cache[session.session_id] = session
|
||||
await self._persist_session(session)
|
||||
|
||||
async def delete_session(self, session_id: str) -> bool:
|
||||
|
||||
180
agent-core/sessions/session_models.py
Normal file
180
agent-core/sessions/session_models.py
Normal file
@@ -0,0 +1,180 @@
|
||||
"""
|
||||
Session Models for Breakpilot Agents
|
||||
|
||||
Data classes for agent sessions, checkpoints, and state tracking.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone, timedelta
|
||||
from typing import Dict, Any, Optional, List
|
||||
from enum import Enum
|
||||
import uuid
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SessionState(Enum):
|
||||
"""Agent session states"""
|
||||
ACTIVE = "active"
|
||||
PAUSED = "paused"
|
||||
COMPLETED = "completed"
|
||||
FAILED = "failed"
|
||||
|
||||
|
||||
@dataclass
|
||||
class SessionCheckpoint:
|
||||
"""Represents a checkpoint in an agent session"""
|
||||
name: str
|
||||
timestamp: datetime
|
||||
data: Dict[str, Any]
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"name": self.name,
|
||||
"timestamp": self.timestamp.isoformat(),
|
||||
"data": self.data
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "SessionCheckpoint":
|
||||
return cls(
|
||||
name=data["name"],
|
||||
timestamp=datetime.fromisoformat(data["timestamp"]),
|
||||
data=data["data"]
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentSession:
|
||||
"""
|
||||
Represents an active agent session.
|
||||
|
||||
Attributes:
|
||||
session_id: Unique session identifier
|
||||
agent_type: Type of agent (tutor, grader, quality-judge, alert, orchestrator)
|
||||
user_id: Associated user ID
|
||||
state: Current session state
|
||||
created_at: Session creation timestamp
|
||||
last_heartbeat: Last heartbeat timestamp
|
||||
context: Session context data
|
||||
checkpoints: List of session checkpoints for recovery
|
||||
"""
|
||||
session_id: str = field(default_factory=lambda: str(uuid.uuid4()))
|
||||
agent_type: str = ""
|
||||
user_id: str = ""
|
||||
state: SessionState = SessionState.ACTIVE
|
||||
created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
last_heartbeat: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
context: Dict[str, Any] = field(default_factory=dict)
|
||||
checkpoints: List[SessionCheckpoint] = field(default_factory=list)
|
||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def checkpoint(self, name: str, data: Dict[str, Any]) -> SessionCheckpoint:
|
||||
"""
|
||||
Creates a checkpoint for recovery.
|
||||
|
||||
Args:
|
||||
name: Checkpoint name (e.g., "task_received", "processing_complete")
|
||||
data: Checkpoint data to store
|
||||
|
||||
Returns:
|
||||
The created checkpoint
|
||||
"""
|
||||
checkpoint = SessionCheckpoint(
|
||||
name=name,
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
data=data
|
||||
)
|
||||
self.checkpoints.append(checkpoint)
|
||||
logger.debug(f"Session {self.session_id}: Checkpoint '{name}' created")
|
||||
return checkpoint
|
||||
|
||||
def heartbeat(self) -> None:
|
||||
"""Updates the heartbeat timestamp"""
|
||||
self.last_heartbeat = datetime.now(timezone.utc)
|
||||
|
||||
def pause(self) -> None:
|
||||
"""Pauses the session"""
|
||||
self.state = SessionState.PAUSED
|
||||
self.checkpoint("session_paused", {"previous_state": "active"})
|
||||
|
||||
def resume(self) -> None:
|
||||
"""Resumes a paused session"""
|
||||
if self.state == SessionState.PAUSED:
|
||||
self.state = SessionState.ACTIVE
|
||||
self.heartbeat()
|
||||
self.checkpoint("session_resumed", {})
|
||||
|
||||
def complete(self, result: Optional[Dict[str, Any]] = None) -> None:
|
||||
"""Marks the session as completed"""
|
||||
self.state = SessionState.COMPLETED
|
||||
self.checkpoint("session_completed", {"result": result or {}})
|
||||
|
||||
def fail(self, error: str, error_details: Optional[Dict[str, Any]] = None) -> None:
|
||||
"""Marks the session as failed"""
|
||||
self.state = SessionState.FAILED
|
||||
self.checkpoint("session_failed", {
|
||||
"error": error,
|
||||
"details": error_details or {}
|
||||
})
|
||||
|
||||
def get_last_checkpoint(self, name: Optional[str] = None) -> Optional[SessionCheckpoint]:
|
||||
"""
|
||||
Gets the last checkpoint, optionally filtered by name.
|
||||
|
||||
Args:
|
||||
name: Optional checkpoint name to filter by
|
||||
|
||||
Returns:
|
||||
The last matching checkpoint or None
|
||||
"""
|
||||
if not self.checkpoints:
|
||||
return None
|
||||
|
||||
if name:
|
||||
matching = [cp for cp in self.checkpoints if cp.name == name]
|
||||
return matching[-1] if matching else None
|
||||
|
||||
return self.checkpoints[-1]
|
||||
|
||||
def get_duration(self) -> timedelta:
|
||||
"""Returns the session duration"""
|
||||
end_time = datetime.now(timezone.utc)
|
||||
if self.state in (SessionState.COMPLETED, SessionState.FAILED):
|
||||
last_cp = self.get_last_checkpoint()
|
||||
if last_cp:
|
||||
end_time = last_cp.timestamp
|
||||
return end_time - self.created_at
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Serializes the session to a dictionary"""
|
||||
return {
|
||||
"session_id": self.session_id,
|
||||
"agent_type": self.agent_type,
|
||||
"user_id": self.user_id,
|
||||
"state": self.state.value,
|
||||
"created_at": self.created_at.isoformat(),
|
||||
"last_heartbeat": self.last_heartbeat.isoformat(),
|
||||
"context": self.context,
|
||||
"checkpoints": [cp.to_dict() for cp in self.checkpoints],
|
||||
"metadata": self.metadata
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "AgentSession":
|
||||
"""Deserializes a session from a dictionary"""
|
||||
return cls(
|
||||
session_id=data["session_id"],
|
||||
agent_type=data["agent_type"],
|
||||
user_id=data["user_id"],
|
||||
state=SessionState(data["state"]),
|
||||
created_at=datetime.fromisoformat(data["created_at"]),
|
||||
last_heartbeat=datetime.fromisoformat(data["last_heartbeat"]),
|
||||
context=data.get("context", {}),
|
||||
checkpoints=[
|
||||
SessionCheckpoint.from_dict(cp)
|
||||
for cp in data.get("checkpoints", [])
|
||||
],
|
||||
metadata=data.get("metadata", {})
|
||||
)
|
||||
Reference in New Issue
Block a user