fix: Restore all files lost during destructive rebase
A previous `git pull --rebase origin main` dropped 177 local commits,
losing 3400+ files across admin-v2, backend, studio-v2, website,
klausur-service, and many other services. The partial restore attempt
(660295e2) only recovered some files.
This commit restores all missing files from pre-rebase ref 98933f5e
while preserving post-rebase additions (night-scheduler, night-mode UI,
NightModeWidget dashboard integration).
Restored features include:
- AI Module Sidebar (FAB), OCR Labeling, OCR Compare
- GPU Dashboard, RAG Pipeline, Magic Help
- Klausur-Korrektur (8 files), Abitur-Archiv (5+ files)
- Companion, Zeugnisse-Crawler, Screen Flow
- Full backend, studio-v2, website, klausur-service
- All compliance SDKs, agent-core, voice-service
- CI/CD configs, documentation, scripts
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
22
agent-core/brain/__init__.py
Normal file
22
agent-core/brain/__init__.py
Normal file
@@ -0,0 +1,22 @@
|
||||
"""
|
||||
Shared Brain for Breakpilot Agents
|
||||
|
||||
Provides:
|
||||
- MemoryStore: Long-term memory with TTL and access tracking
|
||||
- ConversationContext: Session context with message history
|
||||
- KnowledgeGraph: Entity relationships and semantic connections
|
||||
"""
|
||||
|
||||
from agent_core.brain.memory_store import MemoryStore, Memory
|
||||
from agent_core.brain.context_manager import ConversationContext, ContextManager
|
||||
from agent_core.brain.knowledge_graph import KnowledgeGraph, Entity, Relationship
|
||||
|
||||
__all__ = [
|
||||
"MemoryStore",
|
||||
"Memory",
|
||||
"ConversationContext",
|
||||
"ContextManager",
|
||||
"KnowledgeGraph",
|
||||
"Entity",
|
||||
"Relationship",
|
||||
]
|
||||
520
agent-core/brain/context_manager.py
Normal file
520
agent-core/brain/context_manager.py
Normal file
@@ -0,0 +1,520 @@
|
||||
"""
|
||||
Context Management for Breakpilot Agents
|
||||
|
||||
Provides conversation context with:
|
||||
- Message history with compression
|
||||
- Entity extraction and tracking
|
||||
- Intent history
|
||||
- Context summarization
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, List, Optional, Callable, Awaitable
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
from enum import Enum
|
||||
import json
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MessageRole(Enum):
    """Enumerates the role a single conversation turn can carry."""

    SYSTEM = "system"
    USER = "user"
    ASSISTANT = "assistant"
    TOOL = "tool"
|
||||
|
||||
@dataclass
class Message:
    """A single conversation turn: role, text content, timestamp, metadata."""

    role: MessageRole
    content: str
    timestamp: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
    metadata: Dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a JSON-compatible dict (enum/datetime become strings)."""
        payload: Dict[str, Any] = {
            "role": self.role.value,
            "content": self.content,
        }
        payload["timestamp"] = self.timestamp.isoformat()
        payload["metadata"] = self.metadata
        return payload

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "Message":
        """Rebuild a Message from a dict produced by ``to_dict``.

        A missing timestamp defaults to the current UTC time.
        """
        if "timestamp" in data:
            ts = datetime.fromisoformat(data["timestamp"])
        else:
            ts = datetime.now(timezone.utc)
        return cls(
            role=MessageRole(data["role"]),
            content=data["content"],
            timestamp=ts,
            metadata=data.get("metadata", {}),
        )
|
||||
|
||||
@dataclass
class ConversationContext:
    """
    Context for a running conversation.

    Maintains:
    - Message history with automatic compression (see ``_compress_history``)
    - Extracted entities keyed by name
    - Intent history (most recent 20 kept)
    - A summary of messages dropped during compression
    """

    messages: List[Message] = field(default_factory=list)
    entities: Dict[str, Any] = field(default_factory=dict)
    intent_history: List[str] = field(default_factory=list)
    summary: Optional[str] = None
    max_messages: int = 50
    system_prompt: Optional[str] = None
    metadata: Dict[str, Any] = field(default_factory=dict)

    def add_message(
        self,
        role: MessageRole,
        content: str,
        metadata: Optional[Dict[str, Any]] = None
    ) -> Message:
        """
        Adds a message to the conversation, compressing the history once
        the message count exceeds ``max_messages``.

        Args:
            role: Message role
            content: Message content
            metadata: Optional message metadata

        Returns:
            The created Message
        """
        message = Message(
            role=role,
            content=content,
            metadata=metadata or {}
        )
        self.messages.append(message)

        if len(self.messages) > self.max_messages:
            self._compress_history()

        return message

    def add_user_message(
        self,
        content: str,
        metadata: Optional[Dict[str, Any]] = None
    ) -> Message:
        """Convenience method to add a user message"""
        return self.add_message(MessageRole.USER, content, metadata)

    def add_assistant_message(
        self,
        content: str,
        metadata: Optional[Dict[str, Any]] = None
    ) -> Message:
        """Convenience method to add an assistant message"""
        return self.add_message(MessageRole.ASSISTANT, content, metadata)

    def add_system_message(
        self,
        content: str,
        metadata: Optional[Dict[str, Any]] = None
    ) -> Message:
        """Convenience method to add a system message"""
        return self.add_message(MessageRole.SYSTEM, content, metadata)

    def add_intent(self, intent: str) -> None:
        """
        Records an intent in the history, keeping only the last 20.

        Args:
            intent: The detected intent
        """
        self.intent_history.append(intent)
        if len(self.intent_history) > 20:
            self.intent_history = self.intent_history[-20:]

    def set_entity(self, name: str, value: Any) -> None:
        """
        Sets an entity value.

        Args:
            name: Entity name
            value: Entity value
        """
        self.entities[name] = value

    def get_entity(self, name: str, default: Any = None) -> Any:
        """
        Gets an entity value.

        Args:
            name: Entity name
            default: Default value if not found

        Returns:
            Entity value or default
        """
        return self.entities.get(name, default)

    def get_last_message(self, role: Optional[MessageRole] = None) -> Optional[Message]:
        """
        Gets the last message, optionally filtered by role.

        Args:
            role: Optional role filter

        Returns:
            The last matching message or None
        """
        if not self.messages:
            return None

        if role is None:
            return self.messages[-1]

        for msg in reversed(self.messages):
            if msg.role == role:
                return msg

        return None

    def get_messages_for_llm(self) -> List[Dict[str, str]]:
        """
        Gets messages formatted for LLM API calls: system prompt first,
        then the compression summary (if any), then the retained history.

        Returns:
            List of message dicts with role and content
        """
        result = []

        if self.system_prompt:
            result.append({
                "role": "system",
                "content": self.system_prompt
            })

        if self.summary:
            result.append({
                "role": "system",
                "content": f"Previous conversation summary: {self.summary}"
            })

        for msg in self.messages:
            result.append({
                "role": msg.role.value,
                "content": msg.content
            })

        return result

    def _compress_history(self) -> None:
        """
        Compresses older messages to save context window space.

        Keeps:
        - All system messages
        - The last 20 non-system messages
        - Creates a summary of the dropped middle messages

        Note: the previous implementation sliced the raw message list, so a
        system message that also fell within the last 20 entries was kept
        twice (once via ``system_msgs``, once via the tail slice), and the
        summarized-middle count assumed system messages sit at the front.
        Partitioning by role first avoids both problems.
        """
        keep_recent = 20

        system_msgs = [m for m in self.messages if m.role == MessageRole.SYSTEM]
        other_msgs = [m for m in self.messages if m.role != MessageRole.SYSTEM]

        recent_msgs = other_msgs[-keep_recent:]
        middle_msgs = other_msgs[:-keep_recent]

        if middle_msgs:
            # Basic summary (can be enhanced with LLM-based summarization)
            self.summary = self._create_summary(middle_msgs)

        self.messages = system_msgs + recent_msgs

        logger.debug(
            "Compressed conversation: %d messages summarized", len(middle_msgs)
        )

    def _create_summary(self, messages: List[Message]) -> str:
        """
        Creates a summary of messages.

        This is a basic implementation - can be enhanced with LLM-based
        summarization via ContextManager.set_summarize_callback.

        Args:
            messages: Messages to summarize

        Returns:
            Summary string
        """
        user_count = sum(1 for m in messages if m.role == MessageRole.USER)
        assistant_count = sum(1 for m in messages if m.role == MessageRole.ASSISTANT)

        # Crude keyword extraction: up to 3 words longer than 5 chars per message.
        topics = set()
        for msg in messages:
            words = msg.content.lower().split()
            topics.update([w for w in words if len(w) > 5][:3])

        # sorted() makes the summary deterministic; iterating the raw set
        # picked an arbitrary 5 topics that varied between runs.
        topics_str = ", ".join(sorted(topics)[:5])

        return (
            f"Earlier conversation: {user_count} user messages, "
            f"{assistant_count} assistant responses. "
            f"Topics discussed: {topics_str}"
        )

    def clear(self) -> None:
        """Clears all context"""
        self.messages.clear()
        self.entities.clear()
        self.intent_history.clear()
        self.summary = None

    def to_dict(self) -> Dict[str, Any]:
        """Serializes context to dict"""
        return {
            "messages": [m.to_dict() for m in self.messages],
            "entities": self.entities,
            "intent_history": self.intent_history,
            "summary": self.summary,
            "max_messages": self.max_messages,
            "system_prompt": self.system_prompt,
            "metadata": self.metadata
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "ConversationContext":
        """Deserializes context from dict"""
        return cls(
            messages=[Message.from_dict(m) for m in data.get("messages", [])],
            entities=data.get("entities", {}),
            intent_history=data.get("intent_history", []),
            summary=data.get("summary"),
            max_messages=data.get("max_messages", 50),
            system_prompt=data.get("system_prompt"),
            metadata=data.get("metadata", {})
        )
|
||||
|
||||
class ContextManager:
    """
    Manages conversation contexts for multiple sessions.

    Provides:
    - Context creation and retrieval
    - Persistence to Valkey/PostgreSQL
    - Context sharing between agents
    """

    def __init__(
        self,
        redis_client=None,
        db_pool=None,
        namespace: str = "breakpilot"
    ):
        """
        Initialize the context manager.

        Args:
            redis_client: Async Redis/Valkey client
            db_pool: Async PostgreSQL connection pool
            namespace: Key namespace
        """
        self.redis = redis_client
        self.db_pool = db_pool
        self.namespace = namespace
        # In-process cache of contexts, keyed by session id.
        self._contexts: Dict[str, ConversationContext] = {}
        # Optional async summarizer hook (stored for later use).
        self._summarize_callback: Optional[Callable[[List[Message]], Awaitable[str]]] = None

    def _redis_key(self, session_id: str) -> str:
        """Build the namespaced Valkey key for a session's context."""
        return f"{self.namespace}:context:{session_id}"

    def create_context(
        self,
        session_id: str,
        system_prompt: Optional[str] = None,
        max_messages: int = 50
    ) -> ConversationContext:
        """
        Creates a new conversation context and registers it locally.

        Args:
            session_id: Session ID for this context
            system_prompt: Optional system prompt
            max_messages: Maximum messages before compression

        Returns:
            The created context
        """
        fresh = ConversationContext(
            max_messages=max_messages,
            system_prompt=system_prompt
        )
        self._contexts[session_id] = fresh
        return fresh

    async def get_context(self, session_id: str) -> Optional[ConversationContext]:
        """
        Gets a context by session ID, checking the local cache first and
        falling back to Valkey.

        Args:
            session_id: The session ID

        Returns:
            ConversationContext or None
        """
        cached = self._contexts.get(session_id)
        if cached is not None:
            return cached

        restored = await self._get_from_valkey(session_id)
        if restored is None:
            return None

        self._contexts[session_id] = restored
        return restored

    async def save_context(self, session_id: str) -> None:
        """
        Saves a context to persistent storage (no-op for unknown sessions).

        Args:
            session_id: The session ID
        """
        known = self._contexts.get(session_id)
        if known is not None:
            await self._cache_in_valkey(session_id, known)

    async def delete_context(self, session_id: str) -> bool:
        """
        Deletes a context locally and from Valkey.

        Args:
            session_id: The session ID

        Returns:
            True if deleted
        """
        self._contexts.pop(session_id, None)
        if self.redis:
            await self.redis.delete(self._redis_key(session_id))
        return True

    def set_summarize_callback(
        self,
        callback: Callable[[List[Message]], Awaitable[str]]
    ) -> None:
        """
        Sets a callback for LLM-based summarization.

        Args:
            callback: Async function that takes messages and returns summary
        """
        self._summarize_callback = callback

    async def add_message(
        self,
        session_id: str,
        role: MessageRole,
        content: str,
        metadata: Optional[Dict[str, Any]] = None
    ) -> Optional[Message]:
        """
        Adds a message to a session's context and persists the context.

        Args:
            session_id: The session ID
            role: Message role
            content: Message content
            metadata: Optional metadata

        Returns:
            The created message or None if context not found
        """
        target = await self.get_context(session_id)
        if target is None:
            return None

        appended = target.add_message(role, content, metadata)

        # Persist after every message so restarts lose nothing.
        await self.save_context(session_id)

        return appended

    async def get_messages_for_llm(
        self,
        session_id: str
    ) -> Optional[List[Dict[str, str]]]:
        """
        Gets formatted messages for an LLM API call.

        Args:
            session_id: The session ID

        Returns:
            List of message dicts or None
        """
        target = await self.get_context(session_id)
        return None if target is None else target.get_messages_for_llm()

    async def _cache_in_valkey(
        self,
        session_id: str,
        context: ConversationContext
    ) -> None:
        """Best-effort write of the serialized context to Valkey."""
        if not self.redis:
            return

        try:
            # 24 hour TTL for contexts
            await self.redis.setex(
                self._redis_key(session_id),
                86400,
                json.dumps(context.to_dict())
            )
        except Exception as e:
            logger.warning(f"Failed to cache context in Valkey: {e}")

    async def _get_from_valkey(
        self,
        session_id: str
    ) -> Optional[ConversationContext]:
        """Best-effort read of a serialized context from Valkey."""
        if not self.redis:
            return None

        try:
            raw = await self.redis.get(self._redis_key(session_id))
            if raw:
                return ConversationContext.from_dict(json.loads(raw))
        except Exception as e:
            logger.warning(f"Failed to get context from Valkey: {e}")

        return None
563
agent-core/brain/knowledge_graph.py
Normal file
563
agent-core/brain/knowledge_graph.py
Normal file
@@ -0,0 +1,563 @@
|
||||
"""
|
||||
Knowledge Graph for Breakpilot Agents
|
||||
|
||||
Provides entity and relationship management:
|
||||
- Entity storage with properties
|
||||
- Relationship definitions
|
||||
- Graph traversal
|
||||
- Optional Qdrant integration for semantic search
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, List, Optional, Set, Tuple
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
from enum import Enum
|
||||
import json
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EntityType(Enum):
    """Closed set of entity kinds the knowledge graph can store."""

    STUDENT = "student"
    TEACHER = "teacher"
    CLASS = "class"
    SUBJECT = "subject"
    ASSIGNMENT = "assignment"
    EXAM = "exam"
    TOPIC = "topic"
    CONCEPT = "concept"
    RESOURCE = "resource"
    # Escape hatch for domain objects not covered above.
    CUSTOM = "custom"
|
||||
|
||||
class RelationshipType(Enum):
    """Closed set of edge kinds between knowledge-graph entities."""

    BELONGS_TO = "belongs_to"      # Student belongs to class
    TEACHES = "teaches"            # Teacher teaches subject
    ASSIGNED_TO = "assigned_to"    # Assignment assigned to student
    COVERS = "covers"              # Exam covers topic
    REQUIRES = "requires"          # Topic requires concept
    RELATED_TO = "related_to"      # General relationship
    PARENT_OF = "parent_of"        # Hierarchical relationship
    CREATED_BY = "created_by"      # Creator relationship
    GRADED_BY = "graded_by"        # Grading relationship
|
||||
|
||||
@dataclass
class Entity:
    """A node in the knowledge graph: typed, named, with free-form properties."""

    id: str
    entity_type: EntityType
    name: str
    properties: Dict[str, Any] = field(default_factory=dict)
    created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
    updated_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a JSON-compatible dict (enum/datetimes become strings)."""
        payload: Dict[str, Any] = {
            "id": self.id,
            "entity_type": self.entity_type.value,
            "name": self.name,
            "properties": self.properties,
        }
        payload["created_at"] = self.created_at.isoformat()
        payload["updated_at"] = self.updated_at.isoformat()
        return payload

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "Entity":
        """Inverse of ``to_dict``; timestamps are required keys here."""
        return cls(
            id=data["id"],
            entity_type=EntityType(data["entity_type"]),
            name=data["name"],
            properties=data.get("properties", {}),
            created_at=datetime.fromisoformat(data["created_at"]),
            updated_at=datetime.fromisoformat(data["updated_at"]),
        )
|
||||
|
||||
@dataclass
class Relationship:
    """A directed, weighted edge between two entities in the graph."""

    id: str
    source_id: str
    target_id: str
    relationship_type: RelationshipType
    properties: Dict[str, Any] = field(default_factory=dict)
    weight: float = 1.0
    created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a JSON-compatible dict (enum/datetime become strings)."""
        payload: Dict[str, Any] = {
            "id": self.id,
            "source_id": self.source_id,
            "target_id": self.target_id,
            "relationship_type": self.relationship_type.value,
            "properties": self.properties,
            "weight": self.weight,
        }
        payload["created_at"] = self.created_at.isoformat()
        return payload

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "Relationship":
        """Inverse of ``to_dict``; ``created_at`` is a required key here."""
        return cls(
            id=data["id"],
            source_id=data["source_id"],
            target_id=data["target_id"],
            relationship_type=RelationshipType(data["relationship_type"]),
            properties=data.get("properties", {}),
            weight=data.get("weight", 1.0),
            created_at=datetime.fromisoformat(data["created_at"]),
        )
|
||||
|
||||
class KnowledgeGraph:
    """
    Knowledge graph for managing entity relationships.

    Provides:
    - Entity CRUD operations
    - Relationship management
    - Graph traversal (neighbors, paths, subgraphs)
    - Optional vector search via Qdrant (client stored; not used in this class)
    """

    def __init__(
        self,
        db_pool=None,
        qdrant_client=None,
        namespace: str = "breakpilot"
    ):
        """
        Initialize the knowledge graph.

        Args:
            db_pool: Async PostgreSQL connection pool
            qdrant_client: Optional Qdrant client for vector search
            namespace: Namespace for isolation
        """
        self.db_pool = db_pool
        self.qdrant = qdrant_client
        self.namespace = namespace
        self._entities: Dict[str, Entity] = {}
        self._relationships: Dict[str, Relationship] = {}
        # entity_id -> set of ids of relationships touching that entity
        self._adjacency: Dict[str, Set[str]] = {}

    # Entity Operations

    def add_entity(
        self,
        entity_id: str,
        entity_type: EntityType,
        name: str,
        properties: Optional[Dict[str, Any]] = None
    ) -> Entity:
        """
        Adds an entity to the graph (replaces the Entity object if the id
        already exists, but preserves its existing relationships).

        Args:
            entity_id: Unique entity identifier
            entity_type: Type of entity
            name: Human-readable name
            properties: Entity properties

        Returns:
            The created Entity
        """
        entity = Entity(
            id=entity_id,
            entity_type=entity_type,
            name=name,
            properties=properties or {}
        )

        self._entities[entity_id] = entity
        # setdefault: re-adding an existing id must not wipe its adjacency
        # set — the old assignment orphaned all of its relationships.
        self._adjacency.setdefault(entity_id, set())

        logger.debug("Added entity: %s/%s", entity_type.value, entity_id)

        return entity

    def get_entity(self, entity_id: str) -> Optional[Entity]:
        """Gets an entity by ID"""
        return self._entities.get(entity_id)

    def update_entity(
        self,
        entity_id: str,
        name: Optional[str] = None,
        properties: Optional[Dict[str, Any]] = None
    ) -> Optional[Entity]:
        """
        Updates an entity in place.

        Args:
            entity_id: Entity to update
            name: New name (optional)
            properties: Properties to update (merged into existing)

        Returns:
            Updated entity or None if not found
        """
        entity = self._entities.get(entity_id)
        if not entity:
            return None

        if name:
            entity.name = name

        if properties:
            entity.properties.update(properties)

        entity.updated_at = datetime.now(timezone.utc)

        return entity

    def delete_entity(self, entity_id: str) -> bool:
        """
        Deletes an entity and every relationship that touches it.

        Args:
            entity_id: Entity to delete

        Returns:
            True if deleted
        """
        if entity_id not in self._entities:
            return False

        # Snapshot the ids: deletion mutates the adjacency set being read.
        rel_ids = list(self._adjacency.get(entity_id, set()))
        for rel_id in rel_ids:
            self._delete_relationship_internal(rel_id)

        del self._entities[entity_id]
        del self._adjacency[entity_id]

        return True

    def get_entities_by_type(
        self,
        entity_type: EntityType
    ) -> List[Entity]:
        """Gets all entities of a specific type"""
        return [
            e for e in self._entities.values()
            if e.entity_type == entity_type
        ]

    def search_entities(
        self,
        query: str,
        entity_type: Optional[EntityType] = None,
        limit: int = 10
    ) -> List[Entity]:
        """
        Searches entities by name.

        Args:
            query: Search query (case-insensitive substring)
            entity_type: Optional type filter
            limit: Maximum results

        Returns:
            Matching entities
        """
        query_lower = query.lower()
        results = []

        for entity in self._entities.values():
            if entity_type and entity.entity_type != entity_type:
                continue

            if query_lower in entity.name.lower():
                results.append(entity)
                if len(results) >= limit:
                    break

        return results

    # Relationship Operations

    def add_relationship(
        self,
        relationship_id: str,
        source_id: str,
        target_id: str,
        relationship_type: RelationshipType,
        properties: Optional[Dict[str, Any]] = None,
        weight: float = 1.0
    ) -> Optional[Relationship]:
        """
        Adds a relationship between two existing entities.

        Args:
            relationship_id: Unique relationship identifier
            source_id: Source entity ID
            target_id: Target entity ID
            relationship_type: Type of relationship
            properties: Relationship properties
            weight: Relationship weight/strength

        Returns:
            The created Relationship or None if entities don't exist
        """
        if source_id not in self._entities or target_id not in self._entities:
            logger.warning(
                f"Cannot create relationship: entity not found "
                f"(source={source_id}, target={target_id})"
            )
            return None

        relationship = Relationship(
            id=relationship_id,
            source_id=source_id,
            target_id=target_id,
            relationship_type=relationship_type,
            properties=properties or {},
            weight=weight
        )

        self._relationships[relationship_id] = relationship
        # The edge is indexed from both endpoints for O(1) neighbor lookup.
        self._adjacency[source_id].add(relationship_id)
        self._adjacency[target_id].add(relationship_id)

        logger.debug(
            "Added relationship: %s -[%s]-> %s",
            source_id, relationship_type.value, target_id
        )

        return relationship

    def get_relationship(self, relationship_id: str) -> Optional[Relationship]:
        """Gets a relationship by ID"""
        return self._relationships.get(relationship_id)

    def delete_relationship(self, relationship_id: str) -> bool:
        """Deletes a relationship"""
        return self._delete_relationship_internal(relationship_id)

    def _delete_relationship_internal(self, relationship_id: str) -> bool:
        """Internal relationship deletion: unindex both endpoints, then drop."""
        relationship = self._relationships.get(relationship_id)
        if not relationship:
            return False

        # discard (not remove): tolerate a missing entry on either endpoint.
        if relationship.source_id in self._adjacency:
            self._adjacency[relationship.source_id].discard(relationship_id)
        if relationship.target_id in self._adjacency:
            self._adjacency[relationship.target_id].discard(relationship_id)

        del self._relationships[relationship_id]
        return True

    # Graph Traversal

    def get_neighbors(
        self,
        entity_id: str,
        relationship_type: Optional[RelationshipType] = None,
        direction: str = "both"  # "outgoing", "incoming", "both"
    ) -> List[Tuple[Entity, Relationship]]:
        """
        Gets neighboring entities.

        Args:
            entity_id: Starting entity
            relationship_type: Optional filter by relationship type
            direction: Direction to traverse

        Returns:
            List of (entity, relationship) tuples
        """
        if entity_id not in self._entities:
            return []

        results = []
        rel_ids = self._adjacency.get(entity_id, set())

        for rel_id in rel_ids:
            rel = self._relationships.get(rel_id)
            if not rel:
                continue

            if relationship_type and rel.relationship_type != relationship_type:
                continue

            # Pick the far endpoint according to the requested direction.
            neighbor_id = None
            if direction == "outgoing" and rel.source_id == entity_id:
                neighbor_id = rel.target_id
            elif direction == "incoming" and rel.target_id == entity_id:
                neighbor_id = rel.source_id
            elif direction == "both":
                neighbor_id = rel.target_id if rel.source_id == entity_id else rel.source_id

            if neighbor_id:
                neighbor = self._entities.get(neighbor_id)
                if neighbor:
                    results.append((neighbor, rel))

        return results

    def get_path(
        self,
        source_id: str,
        target_id: str,
        max_depth: int = 5
    ) -> Optional[List[Tuple[Entity, Optional[Relationship]]]]:
        """
        Finds a shortest path between two entities using BFS.

        Args:
            source_id: Starting entity
            target_id: Target entity
            max_depth: Maximum path length

        Returns:
            Path as list of (entity, relationship) tuples, or None if no path.
            The first tuple's relationship is None (the starting node).
        """
        if source_id not in self._entities or target_id not in self._entities:
            return None

        if source_id == target_id:
            return [(self._entities[source_id], None)]

        from collections import deque

        visited = {source_id}
        # Queue items: (entity_id, path so far)
        queue = deque([(source_id, [(self._entities[source_id], None)])])

        while queue:
            current_id, path = queue.popleft()

            if len(path) > max_depth:
                continue

            for neighbor, rel in self.get_neighbors(current_id):
                if neighbor.id == target_id:
                    return path + [(neighbor, rel)]

                if neighbor.id not in visited:
                    visited.add(neighbor.id)
                    queue.append((neighbor.id, path + [(neighbor, rel)]))

        return None

    def get_subgraph(
        self,
        entity_id: str,
        depth: int = 2
    ) -> Tuple[List[Entity], List[Relationship]]:
        """
        Gets a subgraph around an entity.

        Args:
            entity_id: Center entity
            depth: How many hops to include

        Returns:
            Tuple of (entities, relationships). Every returned relationship
            connects two returned entities.
        """
        if entity_id not in self._entities:
            return [], []

        entities_set: Set[str] = {entity_id}
        relationships_set: Set[str] = set()
        frontier: Set[str] = {entity_id}

        for _ in range(depth):
            next_frontier: Set[str] = set()

            for e_id in frontier:
                for neighbor, rel in self.get_neighbors(e_id):
                    if neighbor.id not in entities_set:
                        entities_set.add(neighbor.id)
                        next_frontier.add(neighbor.id)
                    # Record the edge unconditionally: the previous version
                    # only recorded it for newly-discovered neighbors, which
                    # dropped edges between two already-included entities
                    # (cycles and cross-links) from the subgraph.
                    relationships_set.add(rel.id)

            frontier = next_frontier

        entities = [self._entities[e_id] for e_id in entities_set]
        relationships = [self._relationships[r_id] for r_id in relationships_set]

        return entities, relationships

    # Serialization

    def to_dict(self) -> Dict[str, Any]:
        """Serializes the graph to a dictionary"""
        return {
            "entities": [e.to_dict() for e in self._entities.values()],
            "relationships": [r.to_dict() for r in self._relationships.values()]
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any], **kwargs) -> "KnowledgeGraph":
        """Deserializes a graph from a dictionary.

        Entities load first so relationship adjacency can be re-indexed;
        relationships referencing unknown entities are kept but unindexed.
        """
        graph = cls(**kwargs)

        for e_data in data.get("entities", []):
            entity = Entity.from_dict(e_data)
            graph._entities[entity.id] = entity
            graph._adjacency[entity.id] = set()

        for r_data in data.get("relationships", []):
            rel = Relationship.from_dict(r_data)
            graph._relationships[rel.id] = rel
            if rel.source_id in graph._adjacency:
                graph._adjacency[rel.source_id].add(rel.id)
            if rel.target_id in graph._adjacency:
                graph._adjacency[rel.target_id].add(rel.id)

        return graph

    def export_json(self) -> str:
        """Exports graph to JSON string"""
        return json.dumps(self.to_dict(), indent=2)

    @classmethod
    def import_json(cls, json_str: str, **kwargs) -> "KnowledgeGraph":
        """Imports graph from JSON string"""
        return cls.from_dict(json.loads(json_str), **kwargs)

    # Statistics

    def get_statistics(self) -> Dict[str, Any]:
        """Gets graph statistics.

        ``avg_connections`` is the average degree: each relationship is
        indexed from both endpoints, so it contributes twice to the sum.
        """
        entity_types: Dict[str, int] = {}
        for entity in self._entities.values():
            entity_types[entity.entity_type.value] = \
                entity_types.get(entity.entity_type.value, 0) + 1

        rel_types: Dict[str, int] = {}
        for rel in self._relationships.values():
            rel_types[rel.relationship_type.value] = \
                rel_types.get(rel.relationship_type.value, 0) + 1

        return {
            "total_entities": len(self._entities),
            "total_relationships": len(self._relationships),
            "entity_types": entity_types,
            "relationship_types": rel_types,
            "avg_connections": (
                sum(len(adj) for adj in self._adjacency.values()) /
                max(len(self._adjacency), 1)
            )
        }
|
||||
@property
|
||||
def entity_count(self) -> int:
|
||||
"""Returns number of entities"""
|
||||
return len(self._entities)
|
||||
|
||||
@property
|
||||
def relationship_count(self) -> int:
|
||||
"""Returns number of relationships"""
|
||||
return len(self._relationships)
|
||||
568
agent-core/brain/memory_store.py
Normal file
568
agent-core/brain/memory_store.py
Normal file
@@ -0,0 +1,568 @@
|
||||
"""
|
||||
Memory Store for Breakpilot Agents
|
||||
|
||||
Provides long-term memory with:
|
||||
- TTL-based expiration
|
||||
- Access count tracking
|
||||
- Pattern-based search
|
||||
- Hybrid Valkey + PostgreSQL persistence
|
||||
"""
|
||||
|
||||
from typing import List, Dict, Any, Optional
|
||||
from datetime import datetime, timezone, timedelta
|
||||
from dataclasses import dataclass, field
|
||||
import json
|
||||
import logging
|
||||
import hashlib
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
class Memory:
    """A single stored memory item with expiry and access bookkeeping."""
    # Unique key identifying the memory
    key: str
    # Arbitrary JSON-serializable payload
    value: Any
    # ID of the agent that created this memory
    agent_id: str
    created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
    expires_at: Optional[datetime] = None
    access_count: int = 0
    last_accessed: Optional[datetime] = None
    metadata: Dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> Dict[str, Any]:
        """Serializes the memory to a JSON-friendly dictionary."""
        def _iso(ts: Optional[datetime]) -> Optional[str]:
            # Optional timestamps serialize to None rather than a string.
            return ts.isoformat() if ts else None

        return {
            "key": self.key,
            "value": self.value,
            "agent_id": self.agent_id,
            "created_at": self.created_at.isoformat(),
            "expires_at": _iso(self.expires_at),
            "access_count": self.access_count,
            "last_accessed": _iso(self.last_accessed),
            "metadata": self.metadata
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "Memory":
        """Rebuilds a Memory from the dictionary produced by to_dict()."""
        def _ts(raw) -> Optional[datetime]:
            return datetime.fromisoformat(raw) if raw else None

        return cls(
            key=data["key"],
            value=data["value"],
            agent_id=data["agent_id"],
            created_at=datetime.fromisoformat(data["created_at"]),
            expires_at=_ts(data.get("expires_at")),
            access_count=data.get("access_count", 0),
            last_accessed=_ts(data.get("last_accessed")),
            metadata=data.get("metadata", {})
        )

    def is_expired(self) -> bool:
        """True once the current UTC time has passed expires_at (never expires if unset)."""
        if self.expires_at is None:
            return False
        return datetime.now(timezone.utc) > self.expires_at
|
||||
|
||||
|
||||
class MemoryStore:
|
||||
"""
|
||||
Long-term memory store for agents.
|
||||
|
||||
Stores facts, decisions, and learning progress with:
|
||||
- TTL-based expiration
|
||||
- Access tracking for importance scoring
|
||||
- Pattern-based retrieval
|
||||
- Hybrid persistence (Valkey for fast access, PostgreSQL for durability)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
redis_client=None,
|
||||
db_pool=None,
|
||||
namespace: str = "breakpilot"
|
||||
):
|
||||
"""
|
||||
Initialize the memory store.
|
||||
|
||||
Args:
|
||||
redis_client: Async Redis/Valkey client
|
||||
db_pool: Async PostgreSQL connection pool
|
||||
namespace: Key namespace for isolation
|
||||
"""
|
||||
self.redis = redis_client
|
||||
self.db_pool = db_pool
|
||||
self.namespace = namespace
|
||||
self._local_cache: Dict[str, Memory] = {}
|
||||
|
||||
def _redis_key(self, key: str) -> str:
|
||||
"""Generate Redis key with namespace"""
|
||||
return f"{self.namespace}:memory:{key}"
|
||||
|
||||
def _hash_key(self, key: str) -> str:
|
||||
"""Generate a hash for long keys"""
|
||||
if len(key) > 200:
|
||||
return hashlib.sha256(key.encode()).hexdigest()[:32]
|
||||
return key
|
||||
|
||||
async def remember(
|
||||
self,
|
||||
key: str,
|
||||
value: Any,
|
||||
agent_id: str,
|
||||
ttl_days: int = 30,
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
) -> Memory:
|
||||
"""
|
||||
Stores a memory.
|
||||
|
||||
Args:
|
||||
key: Unique key for the memory
|
||||
value: Value to store (must be JSON-serializable)
|
||||
agent_id: ID of the agent storing the memory
|
||||
ttl_days: Time to live in days (0 = no expiration)
|
||||
metadata: Optional additional metadata
|
||||
|
||||
Returns:
|
||||
The created Memory object
|
||||
"""
|
||||
expires_at = None
|
||||
if ttl_days > 0:
|
||||
expires_at = datetime.now(timezone.utc) + timedelta(days=ttl_days)
|
||||
|
||||
memory = Memory(
|
||||
key=key,
|
||||
value=value,
|
||||
agent_id=agent_id,
|
||||
expires_at=expires_at,
|
||||
metadata=metadata or {}
|
||||
)
|
||||
|
||||
# Store in all layers
|
||||
await self._store_memory(memory, ttl_days)
|
||||
|
||||
logger.debug(f"Agent {agent_id} remembered '{key}' (TTL: {ttl_days} days)")
|
||||
|
||||
return memory
|
||||
|
||||
async def recall(self, key: str) -> Optional[Any]:
|
||||
"""
|
||||
Retrieves a memory value by key.
|
||||
|
||||
Args:
|
||||
key: The memory key
|
||||
|
||||
Returns:
|
||||
The stored value or None if not found/expired
|
||||
"""
|
||||
memory = await self.get_memory(key)
|
||||
if memory:
|
||||
return memory.value
|
||||
return None
|
||||
|
||||
async def get_memory(self, key: str) -> Optional[Memory]:
|
||||
"""
|
||||
Retrieves a full Memory object by key.
|
||||
|
||||
Updates access count and last_accessed timestamp.
|
||||
|
||||
Args:
|
||||
key: The memory key
|
||||
|
||||
Returns:
|
||||
Memory object or None if not found/expired
|
||||
"""
|
||||
# Check local cache
|
||||
if key in self._local_cache:
|
||||
memory = self._local_cache[key]
|
||||
if not memory.is_expired():
|
||||
await self._update_access(memory)
|
||||
return memory
|
||||
else:
|
||||
del self._local_cache[key]
|
||||
|
||||
# Try Valkey
|
||||
memory = await self._get_from_valkey(key)
|
||||
if memory and not memory.is_expired():
|
||||
self._local_cache[key] = memory
|
||||
await self._update_access(memory)
|
||||
return memory
|
||||
|
||||
# Try PostgreSQL
|
||||
memory = await self._get_from_postgres(key)
|
||||
if memory and not memory.is_expired():
|
||||
# Re-cache in Valkey
|
||||
await self._cache_in_valkey(memory)
|
||||
self._local_cache[key] = memory
|
||||
await self._update_access(memory)
|
||||
return memory
|
||||
|
||||
return None
|
||||
|
||||
async def forget(self, key: str) -> bool:
|
||||
"""
|
||||
Deletes a memory.
|
||||
|
||||
Args:
|
||||
key: The memory key to delete
|
||||
|
||||
Returns:
|
||||
True if deleted, False if not found
|
||||
"""
|
||||
# Remove from local cache
|
||||
self._local_cache.pop(key, None)
|
||||
|
||||
# Remove from Valkey
|
||||
if self.redis:
|
||||
await self.redis.delete(self._redis_key(key))
|
||||
|
||||
# Mark as deleted in PostgreSQL
|
||||
if self.db_pool:
|
||||
async with self.db_pool.acquire() as conn:
|
||||
result = await conn.execute(
|
||||
"""
|
||||
DELETE FROM agent_memory
|
||||
WHERE namespace = $1 AND key = $2
|
||||
""",
|
||||
self.namespace,
|
||||
key
|
||||
)
|
||||
return "DELETE" in result
|
||||
|
||||
return True
|
||||
|
||||
async def search(
|
||||
self,
|
||||
pattern: str,
|
||||
agent_id: Optional[str] = None,
|
||||
limit: int = 100
|
||||
) -> List[Memory]:
|
||||
"""
|
||||
Searches for memories matching a pattern.
|
||||
|
||||
Args:
|
||||
pattern: SQL LIKE pattern (e.g., "evaluation:math:%")
|
||||
agent_id: Optional filter by agent ID
|
||||
limit: Maximum results to return
|
||||
|
||||
Returns:
|
||||
List of matching Memory objects
|
||||
"""
|
||||
results = []
|
||||
|
||||
# Search PostgreSQL (primary source for patterns)
|
||||
if self.db_pool:
|
||||
async with self.db_pool.acquire() as conn:
|
||||
query = """
|
||||
SELECT key, value, agent_id, created_at, expires_at,
|
||||
access_count, last_accessed, metadata
|
||||
FROM agent_memory
|
||||
WHERE namespace = $1
|
||||
AND key LIKE $2
|
||||
AND (expires_at IS NULL OR expires_at > NOW())
|
||||
"""
|
||||
params = [self.namespace, pattern.replace("*", "%")]
|
||||
|
||||
if agent_id:
|
||||
query += " AND agent_id = $3"
|
||||
params.append(agent_id)
|
||||
|
||||
query += f" ORDER BY access_count DESC, created_at DESC LIMIT {limit}"
|
||||
|
||||
rows = await conn.fetch(query, *params)
|
||||
|
||||
for row in rows:
|
||||
results.append(self._row_to_memory(row))
|
||||
else:
|
||||
# Fall back to local cache search
|
||||
import fnmatch
|
||||
for key, memory in self._local_cache.items():
|
||||
if fnmatch.fnmatch(key, pattern):
|
||||
if agent_id is None or memory.agent_id == agent_id:
|
||||
if not memory.is_expired():
|
||||
results.append(memory)
|
||||
if len(results) >= limit:
|
||||
break
|
||||
|
||||
return results
|
||||
|
||||
async def get_by_agent(
|
||||
self,
|
||||
agent_id: str,
|
||||
limit: int = 100
|
||||
) -> List[Memory]:
|
||||
"""
|
||||
Gets all memories for a specific agent.
|
||||
|
||||
Args:
|
||||
agent_id: The agent ID
|
||||
limit: Maximum results
|
||||
|
||||
Returns:
|
||||
List of Memory objects
|
||||
"""
|
||||
return await self.search("*", agent_id=agent_id, limit=limit)
|
||||
|
||||
async def get_recent(
|
||||
self,
|
||||
hours: int = 24,
|
||||
agent_id: Optional[str] = None,
|
||||
limit: int = 100
|
||||
) -> List[Memory]:
|
||||
"""
|
||||
Gets recently created memories.
|
||||
|
||||
Args:
|
||||
hours: How far back to look
|
||||
agent_id: Optional filter by agent
|
||||
limit: Maximum results
|
||||
|
||||
Returns:
|
||||
List of Memory objects
|
||||
"""
|
||||
cutoff = datetime.now(timezone.utc) - timedelta(hours=hours)
|
||||
|
||||
if self.db_pool:
|
||||
async with self.db_pool.acquire() as conn:
|
||||
query = """
|
||||
SELECT key, value, agent_id, created_at, expires_at,
|
||||
access_count, last_accessed, metadata
|
||||
FROM agent_memory
|
||||
WHERE namespace = $1
|
||||
AND created_at > $2
|
||||
AND (expires_at IS NULL OR expires_at > NOW())
|
||||
"""
|
||||
params = [self.namespace, cutoff]
|
||||
|
||||
if agent_id:
|
||||
query += " AND agent_id = $3"
|
||||
params.append(agent_id)
|
||||
|
||||
query += f" ORDER BY created_at DESC LIMIT {limit}"
|
||||
|
||||
rows = await conn.fetch(query, *params)
|
||||
return [self._row_to_memory(row) for row in rows]
|
||||
|
||||
# Fall back to local cache
|
||||
results = []
|
||||
for memory in self._local_cache.values():
|
||||
if memory.created_at > cutoff:
|
||||
if agent_id is None or memory.agent_id == agent_id:
|
||||
if not memory.is_expired():
|
||||
results.append(memory)
|
||||
|
||||
return sorted(results, key=lambda m: m.created_at, reverse=True)[:limit]
|
||||
|
||||
async def get_most_accessed(
|
||||
self,
|
||||
limit: int = 10,
|
||||
agent_id: Optional[str] = None
|
||||
) -> List[Memory]:
|
||||
"""
|
||||
Gets the most frequently accessed memories.
|
||||
|
||||
Args:
|
||||
limit: Number of results
|
||||
agent_id: Optional filter by agent
|
||||
|
||||
Returns:
|
||||
List of Memory objects ordered by access count
|
||||
"""
|
||||
if self.db_pool:
|
||||
async with self.db_pool.acquire() as conn:
|
||||
query = """
|
||||
SELECT key, value, agent_id, created_at, expires_at,
|
||||
access_count, last_accessed, metadata
|
||||
FROM agent_memory
|
||||
WHERE namespace = $1
|
||||
AND (expires_at IS NULL OR expires_at > NOW())
|
||||
"""
|
||||
params = [self.namespace]
|
||||
|
||||
if agent_id:
|
||||
query += " AND agent_id = $2"
|
||||
params.append(agent_id)
|
||||
|
||||
query += f" ORDER BY access_count DESC LIMIT {limit}"
|
||||
|
||||
rows = await conn.fetch(query, *params)
|
||||
return [self._row_to_memory(row) for row in rows]
|
||||
|
||||
# Fall back to local cache
|
||||
valid = [m for m in self._local_cache.values() if not m.is_expired()]
|
||||
if agent_id:
|
||||
valid = [m for m in valid if m.agent_id == agent_id]
|
||||
return sorted(valid, key=lambda m: m.access_count, reverse=True)[:limit]
|
||||
|
||||
async def cleanup_expired(self) -> int:
|
||||
"""
|
||||
Removes expired memories.
|
||||
|
||||
Returns:
|
||||
Number of memories removed
|
||||
"""
|
||||
count = 0
|
||||
|
||||
# Clean local cache
|
||||
expired_keys = [
|
||||
key for key, memory in self._local_cache.items()
|
||||
if memory.is_expired()
|
||||
]
|
||||
for key in expired_keys:
|
||||
del self._local_cache[key]
|
||||
count += 1
|
||||
|
||||
# Clean PostgreSQL
|
||||
if self.db_pool:
|
||||
async with self.db_pool.acquire() as conn:
|
||||
result = await conn.execute(
|
||||
"""
|
||||
DELETE FROM agent_memory
|
||||
WHERE namespace = $1 AND expires_at < NOW()
|
||||
""",
|
||||
self.namespace
|
||||
)
|
||||
# Parse count from result
|
||||
if result:
|
||||
parts = result.split()
|
||||
if len(parts) >= 2:
|
||||
count += int(parts[1])
|
||||
|
||||
logger.info(f"Cleaned up {count} expired memories")
|
||||
return count
|
||||
|
||||
async def _store_memory(self, memory: Memory, ttl_days: int) -> None:
|
||||
"""Stores memory in all persistence layers"""
|
||||
self._local_cache[memory.key] = memory
|
||||
await self._cache_in_valkey(memory, ttl_days)
|
||||
await self._save_to_postgres(memory)
|
||||
|
||||
async def _cache_in_valkey(
|
||||
self,
|
||||
memory: Memory,
|
||||
ttl_days: Optional[int] = None
|
||||
) -> None:
|
||||
"""Caches memory in Valkey"""
|
||||
if not self.redis:
|
||||
return
|
||||
|
||||
try:
|
||||
ttl_seconds = (ttl_days or 30) * 86400
|
||||
await self.redis.setex(
|
||||
self._redis_key(memory.key),
|
||||
ttl_seconds,
|
||||
json.dumps(memory.to_dict())
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to cache memory in Valkey: {e}")
|
||||
|
||||
async def _get_from_valkey(self, key: str) -> Optional[Memory]:
|
||||
"""Retrieves memory from Valkey"""
|
||||
if not self.redis:
|
||||
return None
|
||||
|
||||
try:
|
||||
data = await self.redis.get(self._redis_key(key))
|
||||
if data:
|
||||
return Memory.from_dict(json.loads(data))
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get memory from Valkey: {e}")
|
||||
|
||||
return None
|
||||
|
||||
    async def _save_to_postgres(self, memory: Memory) -> None:
        """Saves memory to PostgreSQL.

        Upserts on (namespace, key); value and metadata are stored as JSON
        text.  NOTE(review): the ON CONFLICT branch bumps access_count by
        one on every re-save, so overwriting an existing key also counts
        as an access — presumably intentional; confirm.  Failures are
        logged and swallowed so the caller's write stays best-effort
        (the local cache and Valkey layers were already written).
        """
        if not self.db_pool:
            return

        try:
            async with self.db_pool.acquire() as conn:
                await conn.execute(
                    """
                    INSERT INTO agent_memory
                        (namespace, key, value, agent_id, access_count,
                         created_at, expires_at, last_accessed, metadata)
                    VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
                    ON CONFLICT (namespace, key) DO UPDATE SET
                        value = EXCLUDED.value,
                        access_count = agent_memory.access_count + 1,
                        last_accessed = NOW(),
                        metadata = EXCLUDED.metadata
                    """,
                    self.namespace,
                    memory.key,
                    json.dumps(memory.value),
                    memory.agent_id,
                    memory.access_count,
                    memory.created_at,
                    memory.expires_at,
                    memory.last_accessed,
                    json.dumps(memory.metadata)
                )
        except Exception as e:
            logger.error(f"Failed to save memory to PostgreSQL: {e}")
|
||||
|
||||
    async def _get_from_postgres(self, key: str) -> Optional[Memory]:
        """Retrieves memory from PostgreSQL.

        Returns None when no pool is configured, the row is absent, or the
        query fails (the failure is logged, not raised).  Expiry is NOT
        filtered here — get_memory() checks is_expired() on the result.
        """
        if not self.db_pool:
            return None

        try:
            async with self.db_pool.acquire() as conn:
                row = await conn.fetchrow(
                    """
                    SELECT key, value, agent_id, created_at, expires_at,
                           access_count, last_accessed, metadata
                    FROM agent_memory
                    WHERE namespace = $1 AND key = $2
                    """,
                    self.namespace,
                    key
                )

                if row:
                    return self._row_to_memory(row)
        except Exception as e:
            logger.error(f"Failed to get memory from PostgreSQL: {e}")

        return None
|
||||
|
||||
    async def _update_access(self, memory: Memory) -> None:
        """Updates access count and timestamp.

        Mutates the in-memory object immediately, then mirrors the bump to
        PostgreSQL best-effort — a failed UPDATE is only logged, so the
        local and database counters can drift.
        """
        memory.access_count += 1
        memory.last_accessed = datetime.now(timezone.utc)

        # Update in PostgreSQL so importance ranking survives restarts.
        if self.db_pool:
            try:
                async with self.db_pool.acquire() as conn:
                    await conn.execute(
                        """
                        UPDATE agent_memory
                        SET access_count = access_count + 1,
                            last_accessed = NOW()
                        WHERE namespace = $1 AND key = $2
                        """,
                        self.namespace,
                        memory.key
                    )
            except Exception as e:
                logger.warning(f"Failed to update access count: {e}")
|
||||
|
||||
def _row_to_memory(self, row) -> Memory:
|
||||
"""Converts a database row to Memory"""
|
||||
value = row["value"]
|
||||
if isinstance(value, str):
|
||||
value = json.loads(value)
|
||||
|
||||
metadata = row.get("metadata", {})
|
||||
if isinstance(metadata, str):
|
||||
metadata = json.loads(metadata)
|
||||
|
||||
return Memory(
|
||||
key=row["key"],
|
||||
value=value,
|
||||
agent_id=row["agent_id"],
|
||||
created_at=row["created_at"],
|
||||
expires_at=row["expires_at"],
|
||||
access_count=row["access_count"] or 0,
|
||||
last_accessed=row["last_accessed"],
|
||||
metadata=metadata
|
||||
)
|
||||
Reference in New Issue
Block a user