fix: Restore all files lost during destructive rebase
A previous `git pull --rebase origin main` dropped 177 local commits,
losing 3400+ files across admin-v2, backend, studio-v2, website,
klausur-service, and many other services. The partial restore attempt
(660295e2) only recovered some files.
This commit restores all missing files from pre-rebase ref 98933f5e
while preserving post-rebase additions (night-scheduler, night-mode UI,
NightModeWidget dashboard integration).
Restored features include:
- AI Module Sidebar (FAB), OCR Labeling, OCR Compare
- GPU Dashboard, RAG Pipeline, Magic Help
- Klausur-Korrektur (8 files), Abitur-Archiv (5+ files)
- Companion, Zeugnisse-Crawler, Screen Flow
- Full backend, studio-v2, website, klausur-service
- All compliance SDKs, agent-core, voice-service
- CI/CD configs, documentation, scripts
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
agent-core/README.md (new file, 416 lines)
@@ -0,0 +1,416 @@
# Breakpilot Agent Core

Multi-agent architecture infrastructure for Breakpilot.

## Overview

The `agent-core` module provides the shared infrastructure for Breakpilot's multi-agent system:

- **Session Management**: agent sessions with checkpoints and recovery
- **Shared Brain**: long-term memory and context management
- **Orchestration**: message bus, supervisor, and task routing

## Architecture

```
┌─────────────────────────────────────────────────────────────────┐
│                       Breakpilot Services                       │
│  ┌─────────────┐  ┌─────────────┐  ┌─────────────────────────┐  │
│  │Voice Service│  │Klausur Svc  │  │  Admin-v2 / AlertAgent  │  │
│  └──────┬──────┘  └──────┬──────┘  └───────────┬─────────────┘  │
│         │                │                     │                │
│         └────────────────┼─────────────────────┘                │
│                          │                                      │
│  ┌───────────────────────▼───────────────────────────────────┐  │
│  │                        Agent Core                         │  │
│  │  ┌─────────────┐  ┌─────────────┐  ┌───────────────────┐  │  │
│  │  │  Sessions   │  │Shared Brain │  │   Orchestrator    │  │  │
│  │  │ - Manager   │  │ - Memory    │  │ - Message Bus     │  │  │
│  │  │ - Heartbeat │  │ - Context   │  │ - Supervisor      │  │  │
│  │  │ - Checkpoint│  │ - Knowledge │  │ - Task Router     │  │  │
│  │  └─────────────┘  └─────────────┘  └───────────────────┘  │  │
│  └───────────────────────────────────────────────────────────┘  │
│                          │                                      │
│  ┌───────────────────────▼───────────────────────────────────┐  │
│  │                      Infrastructure                       │  │
│  │     Valkey (Redis)      PostgreSQL        Qdrant          │  │
│  └───────────────────────────────────────────────────────────┘  │
└─────────────────────────────────────────────────────────────────┘
```

## Directory Layout

```
agent-core/
├── __init__.py                # Module exports
├── README.md                  # This file
├── requirements.txt           # Python dependencies
├── pytest.ini                 # Test configuration
│
├── soul/                      # Agent SOUL files (personas)
│   ├── tutor-agent.soul.md
│   ├── grader-agent.soul.md
│   ├── quality-judge.soul.md
│   ├── alert-agent.soul.md
│   └── orchestrator.soul.md
│
├── brain/                     # Shared Brain implementation
│   ├── __init__.py
│   ├── memory_store.py        # Long-term memory
│   ├── context_manager.py     # Conversation context
│   └── knowledge_graph.py     # Entity relationships
│
├── sessions/                  # Session management
│   ├── __init__.py
│   ├── session_manager.py     # Session lifecycle
│   ├── heartbeat.py           # Liveness monitoring
│   └── checkpoint.py          # Recovery checkpoints
│
├── orchestrator/              # Multi-agent orchestration
│   ├── __init__.py
│   ├── message_bus.py         # Inter-agent communication
│   ├── supervisor.py          # Agent supervision
│   └── task_router.py         # Intent-based routing
│
└── tests/                     # Unit tests
    ├── conftest.py
    ├── test_session_manager.py
    ├── test_heartbeat.py
    ├── test_message_bus.py
    ├── test_memory_store.py
    └── test_task_router.py
```

## Components

### 1. Session Management

Manages agent sessions with a state machine and recovery capabilities.

```python
from agent_core.sessions import SessionManager, AgentSession

# Create the session manager
manager = SessionManager(
    redis_client=redis,
    db_pool=pg_pool,
    namespace="breakpilot"
)

# Create a session
session = await manager.create_session(
    agent_type="tutor-agent",
    user_id="user-123",
    context={"subject": "math"}
)

# Set a checkpoint
session.checkpoint("task_started", {"task_id": "abc"})

# Finish the session
session.complete({"result": "success"})
```

**Session states:**
- `ACTIVE` - session is running
- `PAUSED` - session is paused
- `COMPLETED` - session finished successfully
- `FAILED` - session failed

### 2. Heartbeat Monitoring

Monitors agent liveness and triggers recovery on timeout.

```python
from agent_core.sessions import HeartbeatMonitor, HeartbeatClient

# Start the monitor
monitor = HeartbeatMonitor(
    timeout_seconds=30,
    check_interval_seconds=5,
    max_missed_beats=3
)
await monitor.start_monitoring()

# Register an agent
monitor.register("agent-1", "tutor-agent")

# Send heartbeats
async with HeartbeatClient("agent-1", monitor) as client:
    # agent work...
    pass
```

### 3. Memory Store

Long-term memory for agents, with TTL and access tracking.

```python
from agent_core.brain import MemoryStore

store = MemoryStore(redis_client=redis, db_pool=pg_pool)

# Store a memory
await store.remember(
    key="evaluation:math:student-1",
    value={"score": 85, "feedback": "Gut gemacht!"},
    agent_id="grader-agent",
    ttl_days=30
)

# Recall a memory
result = await store.recall("evaluation:math:student-1")

# Search by pattern
similar = await store.search("evaluation:math:*")
```

### 4. Context Manager

Manages conversation context with automatic compression.

```python
from agent_core.brain import ContextManager, MessageRole

ctx_manager = ContextManager(redis_client=redis)

# Create a context
context = ctx_manager.create_context(
    session_id="session-123",
    system_prompt="Du bist ein hilfreicher Tutor...",
    max_messages=50
)

# Add messages
context.add_message(MessageRole.USER, "Was ist Photosynthese?")
context.add_message(MessageRole.ASSISTANT, "Photosynthese ist...")

# Format for the LLM API
messages = context.get_messages_for_llm()
```

### 5. Message Bus

Inter-agent communication via Redis Pub/Sub.

```python
from agent_core.orchestrator import MessageBus, AgentMessage, MessagePriority

bus = MessageBus(redis_client=redis)
await bus.start()

# Register a handler
async def handle_message(msg):
    return {"status": "processed"}

await bus.subscribe("grader-agent", handle_message)

# Send a message
await bus.publish(AgentMessage(
    sender="orchestrator",
    receiver="grader-agent",
    message_type="grade_request",
    payload={"exam_id": "exam-1"},
    priority=MessagePriority.HIGH
))

# Request-response pattern
response = await bus.request(message, timeout=30.0)
```

### 6. Agent Supervisor

Supervises and coordinates all agents.

```python
from agent_core.orchestrator import AgentSupervisor, RestartPolicy

supervisor = AgentSupervisor(message_bus=bus, heartbeat_monitor=monitor)

# Register an agent
await supervisor.register_agent(
    agent_id="tutor-1",
    agent_type="tutor-agent",
    restart_policy=RestartPolicy.ON_FAILURE,
    max_restarts=3,
    capacity=10
)

# Start an agent
await supervisor.start_agent("tutor-1")

# Load balancing
available = supervisor.get_available_agent("tutor-agent")
```

### 7. Task Router

Intent-based routing with fallback chains.

```python
from agent_core.orchestrator import TaskRouter, RoutingRule, RoutingStrategy

router = TaskRouter(supervisor=supervisor)

# Add a custom rule
router.add_rule(RoutingRule(
    intent_pattern="learning_*",
    agent_type="tutor-agent",
    priority=10,
    fallback_agent="orchestrator"
))

# Route a task
result = await router.route(
    intent="learning_math",
    context={"grade": 10},
    strategy=RoutingStrategy.LEAST_LOADED
)

if result.success:
    print(f"Routed to {result.agent_id}")
```

## SOUL Files

SOUL files define each agent's persona and behavioral rules.

| Agent | SOUL File | Responsibility |
|-------|-----------|----------------|
| TutorAgent | tutor-agent.soul.md | Learning support, answering questions |
| GraderAgent | grader-agent.soul.md | Exam correction and grading |
| QualityJudge | quality-judge.soul.md | BQAS quality checks |
| AlertAgent | alert-agent.soul.md | Monitoring, notifications |
| Orchestrator | orchestrator.soul.md | Task coordination |

## Database Schema

### agent_sessions
```sql
CREATE TABLE agent_sessions (
    id UUID PRIMARY KEY,
    agent_type VARCHAR(50) NOT NULL,
    user_id UUID REFERENCES users(id),
    state VARCHAR(20) NOT NULL DEFAULT 'active',
    context JSONB DEFAULT '{}',
    checkpoints JSONB DEFAULT '[]',
    created_at TIMESTAMPTZ DEFAULT NOW(),
    updated_at TIMESTAMPTZ DEFAULT NOW(),
    last_heartbeat TIMESTAMPTZ DEFAULT NOW()
);
```

### agent_memory
```sql
CREATE TABLE agent_memory (
    id UUID PRIMARY KEY,
    namespace VARCHAR(100) NOT NULL,
    key VARCHAR(500) NOT NULL,
    value JSONB NOT NULL,
    agent_id VARCHAR(50) NOT NULL,
    access_count INTEGER DEFAULT 0,
    created_at TIMESTAMPTZ DEFAULT NOW(),
    expires_at TIMESTAMPTZ,
    UNIQUE(namespace, key)
);
```

### agent_messages
```sql
CREATE TABLE agent_messages (
    id UUID PRIMARY KEY,
    sender VARCHAR(50) NOT NULL,
    receiver VARCHAR(50) NOT NULL,
    message_type VARCHAR(50) NOT NULL,
    payload JSONB NOT NULL,
    priority INTEGER DEFAULT 1,
    correlation_id UUID,
    created_at TIMESTAMPTZ DEFAULT NOW()
);
```

## Integration

### With the Voice Service

```python
from services.enhanced_task_orchestrator import EnhancedTaskOrchestrator

orchestrator = EnhancedTaskOrchestrator(
    redis_client=redis,
    db_pool=pg_pool
)

await orchestrator.start()

# Session for a voice interaction
session = await orchestrator.create_session(
    voice_session_id="voice-123",
    user_id="teacher-1"
)

# Process a task (uses multi-agent when needed)
await orchestrator.process_task(task)
```

### With BQAS

```python
from bqas.quality_judge_agent import QualityJudgeAgent

judge = QualityJudgeAgent(
    message_bus=bus,
    memory_store=memory
)

await judge.start()

# Direct evaluation
result = await judge.evaluate(
    response="Der Satz des Pythagoras...",
    task_type="learning_math",
    context={"user_input": "Was ist Pythagoras?"}
)

if result["verdict"] == "production_ready":
    # Response is OK
    pass
```

## Tests

```bash
# From the agent-core directory
cd agent-core

# Run all tests
pytest -v

# With coverage
pytest --cov=. --cov-report=html

# A single test module
pytest tests/test_session_manager.py -v

# Async tests
pytest tests/test_message_bus.py -v
```

## Metrics

Agent Core exports the following metrics:

| Metric | Description |
|--------|-------------|
| `agent_session_duration_seconds` | Duration of agent sessions |
| `agent_heartbeat_delay_seconds` | Time since the last heartbeat |
| `agent_message_latency_ms` | Latency of inter-agent communication |
| `agent_memory_access_total` | Memory accesses per agent |
| `agent_error_total` | Errors per agent type |

## Next Steps

1. **Run the migration**: `psql -f backend/migrations/add_agent_core_tables.sql`
2. **Extend the voice service**: enable the enhanced orchestrator
3. **Integrate BQAS**: start the Quality Judge agent
4. **Set up monitoring**: wire the metrics into Grafana
agent-core/__init__.py (new file, 24 lines)
@@ -0,0 +1,24 @@
"""
Breakpilot Agent Core - Multi-Agent Infrastructure

This module provides the shared infrastructure for Breakpilot's multi-agent architecture:
- Session Management: Agent sessions with checkpoints and heartbeats
- Shared Brain: Memory store, context management, and knowledge graph
- Orchestration: Message bus, supervisor, and task routing
"""

from agent_core.sessions import AgentSession, SessionManager, SessionState
from agent_core.brain import MemoryStore, ConversationContext
from agent_core.orchestrator import MessageBus, AgentSupervisor, AgentMessage

__version__ = "1.0.0"
__all__ = [
    "AgentSession",
    "SessionManager",
    "SessionState",
    "MemoryStore",
    "ConversationContext",
    "MessageBus",
    "AgentSupervisor",
    "AgentMessage",
]
agent-core/brain/__init__.py (new file, 22 lines)
@@ -0,0 +1,22 @@
"""
Shared Brain for Breakpilot Agents

Provides:
- MemoryStore: Long-term memory with TTL and access tracking
- ConversationContext: Session context with message history
- KnowledgeGraph: Entity relationships and semantic connections
"""

from agent_core.brain.memory_store import MemoryStore, Memory
from agent_core.brain.context_manager import ConversationContext, ContextManager
from agent_core.brain.knowledge_graph import KnowledgeGraph, Entity, Relationship

__all__ = [
    "MemoryStore",
    "Memory",
    "ConversationContext",
    "ContextManager",
    "KnowledgeGraph",
    "Entity",
    "Relationship",
]
agent-core/brain/context_manager.py (new file, 520 lines)
@@ -0,0 +1,520 @@
"""
Context Management for Breakpilot Agents

Provides conversation context with:
- Message history with compression
- Entity extraction and tracking
- Intent history
- Context summarization
"""

from typing import Dict, Any, List, Optional, Callable, Awaitable
from dataclasses import dataclass, field
from datetime import datetime, timezone
from enum import Enum
import json
import logging

logger = logging.getLogger(__name__)


class MessageRole(Enum):
    """Message roles in a conversation"""
    SYSTEM = "system"
    USER = "user"
    ASSISTANT = "assistant"
    TOOL = "tool"


@dataclass
class Message:
    """Represents a message in a conversation"""
    role: MessageRole
    content: str
    timestamp: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
    metadata: Dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> Dict[str, Any]:
        return {
            "role": self.role.value,
            "content": self.content,
            "timestamp": self.timestamp.isoformat(),
            "metadata": self.metadata
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "Message":
        return cls(
            role=MessageRole(data["role"]),
            content=data["content"],
            timestamp=datetime.fromisoformat(data["timestamp"]) if "timestamp" in data else datetime.now(timezone.utc),
            metadata=data.get("metadata", {})
        )


@dataclass
class ConversationContext:
    """
    Context for a running conversation.

    Maintains:
    - Message history with automatic compression
    - Extracted entities
    - Intent history
    - Conversation summary
    """
    messages: List[Message] = field(default_factory=list)
    entities: Dict[str, Any] = field(default_factory=dict)
    intent_history: List[str] = field(default_factory=list)
    summary: Optional[str] = None
    max_messages: int = 50
    system_prompt: Optional[str] = None
    metadata: Dict[str, Any] = field(default_factory=dict)

    def add_message(
        self,
        role: MessageRole,
        content: str,
        metadata: Optional[Dict[str, Any]] = None
    ) -> Message:
        """
        Adds a message to the conversation.

        Args:
            role: Message role
            content: Message content
            metadata: Optional message metadata

        Returns:
            The created Message
        """
        message = Message(
            role=role,
            content=content,
            metadata=metadata or {}
        )
        self.messages.append(message)

        # Compress if needed
        if len(self.messages) > self.max_messages:
            self._compress_history()

        return message

    def add_user_message(
        self,
        content: str,
        metadata: Optional[Dict[str, Any]] = None
    ) -> Message:
        """Convenience method to add a user message"""
        return self.add_message(MessageRole.USER, content, metadata)

    def add_assistant_message(
        self,
        content: str,
        metadata: Optional[Dict[str, Any]] = None
    ) -> Message:
        """Convenience method to add an assistant message"""
        return self.add_message(MessageRole.ASSISTANT, content, metadata)

    def add_system_message(
        self,
        content: str,
        metadata: Optional[Dict[str, Any]] = None
    ) -> Message:
        """Convenience method to add a system message"""
        return self.add_message(MessageRole.SYSTEM, content, metadata)

    def add_intent(self, intent: str) -> None:
        """
        Records an intent in the history.

        Args:
            intent: The detected intent
        """
        self.intent_history.append(intent)
        # Keep last 20 intents
        if len(self.intent_history) > 20:
            self.intent_history = self.intent_history[-20:]

    def set_entity(self, name: str, value: Any) -> None:
        """
        Sets an entity value.

        Args:
            name: Entity name
            value: Entity value
        """
        self.entities[name] = value

    def get_entity(self, name: str, default: Any = None) -> Any:
        """
        Gets an entity value.

        Args:
            name: Entity name
            default: Default value if not found

        Returns:
            Entity value or default
        """
        return self.entities.get(name, default)

    def get_last_message(self, role: Optional[MessageRole] = None) -> Optional[Message]:
        """
        Gets the last message, optionally filtered by role.

        Args:
            role: Optional role filter

        Returns:
            The last matching message or None
        """
        if not self.messages:
            return None

        if role is None:
            return self.messages[-1]

        for msg in reversed(self.messages):
            if msg.role == role:
                return msg

        return None

    def get_messages_for_llm(self) -> List[Dict[str, str]]:
        """
        Gets messages formatted for LLM API calls.

        Returns:
            List of message dicts with role and content
        """
        result = []

        # Add system prompt first
        if self.system_prompt:
            result.append({
                "role": "system",
                "content": self.system_prompt
            })

        # Add summary if we have one and history was compressed
        if self.summary:
            result.append({
                "role": "system",
                "content": f"Previous conversation summary: {self.summary}"
            })

        # Add recent messages
        for msg in self.messages:
            result.append({
                "role": msg.role.value,
                "content": msg.content
            })

        return result

    def _compress_history(self) -> None:
        """
        Compresses older messages to save context window space.

        Keeps:
        - System messages
        - Last 20 messages
        - Creates summary of compressed middle messages
        """
        # Keep system messages
        system_msgs = [m for m in self.messages if m.role == MessageRole.SYSTEM]

        # Keep last 20 messages
        recent_msgs = self.messages[-20:]

        # Middle messages to summarize
        middle_start = len(system_msgs)
        middle_end = len(self.messages) - 20
        middle_msgs = self.messages[middle_start:middle_end]

        if middle_msgs:
            # Create a basic summary (can be enhanced with LLM-based summarization)
            self.summary = self._create_summary(middle_msgs)

        # Combine
        self.messages = system_msgs + recent_msgs

        logger.debug(
            f"Compressed conversation: {middle_end - middle_start} messages summarized"
        )

    def _create_summary(self, messages: List[Message]) -> str:
        """
        Creates a summary of messages.

        This is a basic implementation - can be enhanced with LLM-based summarization.

        Args:
            messages: Messages to summarize

        Returns:
            Summary string
        """
        # Count message types
        user_count = sum(1 for m in messages if m.role == MessageRole.USER)
        assistant_count = sum(1 for m in messages if m.role == MessageRole.ASSISTANT)

        # Extract key topics (simplified - could use NLP)
        topics = set()
        for msg in messages:
            # Simple keyword extraction
            words = msg.content.lower().split()
            # Filter common words
            keywords = [w for w in words if len(w) > 5][:3]
            topics.update(keywords)

        topics_str = ", ".join(list(topics)[:5])

        return (
            f"Earlier conversation: {user_count} user messages, "
            f"{assistant_count} assistant responses. "
            f"Topics discussed: {topics_str}"
        )

    def clear(self) -> None:
        """Clears all context"""
        self.messages.clear()
        self.entities.clear()
        self.intent_history.clear()
        self.summary = None

    def to_dict(self) -> Dict[str, Any]:
        """Serializes context to dict"""
        return {
            "messages": [m.to_dict() for m in self.messages],
            "entities": self.entities,
            "intent_history": self.intent_history,
            "summary": self.summary,
            "max_messages": self.max_messages,
            "system_prompt": self.system_prompt,
            "metadata": self.metadata
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "ConversationContext":
        """Deserializes context from dict"""
        ctx = cls(
            messages=[Message.from_dict(m) for m in data.get("messages", [])],
            entities=data.get("entities", {}),
            intent_history=data.get("intent_history", []),
            summary=data.get("summary"),
            max_messages=data.get("max_messages", 50),
            system_prompt=data.get("system_prompt"),
            metadata=data.get("metadata", {})
        )
        return ctx


class ContextManager:
    """
    Manages conversation contexts for multiple sessions.

    Provides:
    - Context creation and retrieval
    - Persistence to Valkey/PostgreSQL
    - Context sharing between agents
    """

    def __init__(
        self,
        redis_client=None,
        db_pool=None,
        namespace: str = "breakpilot"
    ):
        """
        Initialize the context manager.

        Args:
            redis_client: Async Redis/Valkey client
            db_pool: Async PostgreSQL connection pool
            namespace: Key namespace
        """
        self.redis = redis_client
        self.db_pool = db_pool
        self.namespace = namespace
        self._contexts: Dict[str, ConversationContext] = {}
        self._summarize_callback: Optional[Callable[[List[Message]], Awaitable[str]]] = None

    def _redis_key(self, session_id: str) -> str:
        """Generate Redis key for context"""
        return f"{self.namespace}:context:{session_id}"

    def create_context(
        self,
        session_id: str,
        system_prompt: Optional[str] = None,
        max_messages: int = 50
    ) -> ConversationContext:
        """
        Creates a new conversation context.

        Args:
            session_id: Session ID for this context
            system_prompt: Optional system prompt
            max_messages: Maximum messages before compression

        Returns:
            The created context
        """
        context = ConversationContext(
            max_messages=max_messages,
            system_prompt=system_prompt
        )
        self._contexts[session_id] = context
        return context

    async def get_context(self, session_id: str) -> Optional[ConversationContext]:
        """
        Gets a context by session ID.

        Args:
            session_id: The session ID

        Returns:
            ConversationContext or None
        """
        # Check local cache
        if session_id in self._contexts:
            return self._contexts[session_id]

        # Try Valkey
        context = await self._get_from_valkey(session_id)
        if context:
            self._contexts[session_id] = context
            return context

        return None

    async def save_context(self, session_id: str) -> None:
        """
        Saves a context to persistent storage.

        Args:
            session_id: The session ID
        """
        if session_id not in self._contexts:
            return

        context = self._contexts[session_id]
        await self._cache_in_valkey(session_id, context)

    async def delete_context(self, session_id: str) -> bool:
        """
        Deletes a context.

        Args:
            session_id: The session ID

        Returns:
            True if deleted
        """
        self._contexts.pop(session_id, None)

        if self.redis:
            await self.redis.delete(self._redis_key(session_id))

        return True

    def set_summarize_callback(
        self,
        callback: Callable[[List[Message]], Awaitable[str]]
    ) -> None:
        """
        Sets a callback for LLM-based summarization.

        Args:
            callback: Async function that takes messages and returns summary
        """
        self._summarize_callback = callback

    async def add_message(
        self,
        session_id: str,
        role: MessageRole,
        content: str,
        metadata: Optional[Dict[str, Any]] = None
    ) -> Optional[Message]:
        """
        Adds a message to a session's context.

        Args:
            session_id: The session ID
            role: Message role
            content: Message content
            metadata: Optional metadata

        Returns:
            The created message or None if context not found
        """
        context = await self.get_context(session_id)
        if not context:
            return None

        message = context.add_message(role, content, metadata)

        # Save after each message
        await self.save_context(session_id)

        return message

    async def get_messages_for_llm(
        self,
        session_id: str
    ) -> Optional[List[Dict[str, str]]]:
        """
        Gets formatted messages for LLM API call.

        Args:
            session_id: The session ID

        Returns:
            List of message dicts or None
        """
        context = await self.get_context(session_id)
        if not context:
            return None

        return context.get_messages_for_llm()

    async def _cache_in_valkey(
        self,
        session_id: str,
        context: ConversationContext
    ) -> None:
        """Caches context in Valkey"""
        if not self.redis:
            return

        try:
            # 24 hour TTL for contexts
            await self.redis.setex(
                self._redis_key(session_id),
                86400,
                json.dumps(context.to_dict())
            )
        except Exception as e:
            logger.warning(f"Failed to cache context in Valkey: {e}")

    async def _get_from_valkey(
        self,
        session_id: str
    ) -> Optional[ConversationContext]:
        """Retrieves context from Valkey"""
        if not self.redis:
            return None

        try:
            data = await self.redis.get(self._redis_key(session_id))
            if data:
                return ConversationContext.from_dict(json.loads(data))
        except Exception as e:
            logger.warning(f"Failed to get context from Valkey: {e}")

        return None
agent-core/brain/knowledge_graph.py (new file, 563 lines)
@@ -0,0 +1,563 @@
"""
Knowledge Graph for Breakpilot Agents

Provides entity and relationship management:
- Entity storage with properties
- Relationship definitions
- Graph traversal
- Optional Qdrant integration for semantic search
"""

from typing import Dict, Any, List, Optional, Set, Tuple
from dataclasses import dataclass, field
from datetime import datetime, timezone
from enum import Enum
import json
import logging

logger = logging.getLogger(__name__)


class EntityType(Enum):
    """Types of entities in the knowledge graph"""
    STUDENT = "student"
    TEACHER = "teacher"
    CLASS = "class"
    SUBJECT = "subject"
    ASSIGNMENT = "assignment"
    EXAM = "exam"
    TOPIC = "topic"
    CONCEPT = "concept"
    RESOURCE = "resource"
    CUSTOM = "custom"


class RelationshipType(Enum):
    """Types of relationships between entities"""
    BELONGS_TO = "belongs_to"    # Student belongs to class
    TEACHES = "teaches"          # Teacher teaches subject
    ASSIGNED_TO = "assigned_to"  # Assignment assigned to student
    COVERS = "covers"            # Exam covers topic
    REQUIRES = "requires"        # Topic requires concept
    RELATED_TO = "related_to"    # General relationship
    PARENT_OF = "parent_of"      # Hierarchical relationship
    CREATED_BY = "created_by"    # Creator relationship
    GRADED_BY = "graded_by"      # Grading relationship
@dataclass
class Entity:
    """Represents an entity in the knowledge graph"""
    id: str
    entity_type: EntityType
    name: str
    properties: Dict[str, Any] = field(default_factory=dict)
    created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
    updated_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))

    def to_dict(self) -> Dict[str, Any]:
        return {
            "id": self.id,
            "entity_type": self.entity_type.value,
            "name": self.name,
            "properties": self.properties,
            "created_at": self.created_at.isoformat(),
            "updated_at": self.updated_at.isoformat()
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "Entity":
        return cls(
            id=data["id"],
            entity_type=EntityType(data["entity_type"]),
            name=data["name"],
            properties=data.get("properties", {}),
            created_at=datetime.fromisoformat(data["created_at"]),
            updated_at=datetime.fromisoformat(data["updated_at"])
        )


@dataclass
class Relationship:
    """Represents a relationship between two entities"""
    id: str
    source_id: str
    target_id: str
    relationship_type: RelationshipType
    properties: Dict[str, Any] = field(default_factory=dict)
    weight: float = 1.0
    created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))

    def to_dict(self) -> Dict[str, Any]:
        return {
            "id": self.id,
            "source_id": self.source_id,
            "target_id": self.target_id,
            "relationship_type": self.relationship_type.value,
            "properties": self.properties,
            "weight": self.weight,
            "created_at": self.created_at.isoformat()
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "Relationship":
        return cls(
            id=data["id"],
            source_id=data["source_id"],
            target_id=data["target_id"],
            relationship_type=RelationshipType(data["relationship_type"]),
            properties=data.get("properties", {}),
            weight=data.get("weight", 1.0),
            created_at=datetime.fromisoformat(data["created_at"])
        )
class KnowledgeGraph:
    """
    Knowledge graph for managing entity relationships.

    Provides:
    - Entity CRUD operations
    - Relationship management
    - Graph traversal (neighbors, paths)
    - Optional vector search via Qdrant
    """

    def __init__(
        self,
        db_pool=None,
        qdrant_client=None,
        namespace: str = "breakpilot"
    ):
        """
        Initialize the knowledge graph.

        Args:
            db_pool: Async PostgreSQL connection pool
            qdrant_client: Optional Qdrant client for vector search
            namespace: Namespace for isolation
        """
        self.db_pool = db_pool
        self.qdrant = qdrant_client
        self.namespace = namespace
        self._entities: Dict[str, Entity] = {}
        self._relationships: Dict[str, Relationship] = {}
        self._adjacency: Dict[str, Set[str]] = {}  # entity_id -> set of relationship_ids

    # Entity Operations

    def add_entity(
        self,
        entity_id: str,
        entity_type: EntityType,
        name: str,
        properties: Optional[Dict[str, Any]] = None
    ) -> Entity:
        """
        Adds an entity to the graph.

        Args:
            entity_id: Unique entity identifier
            entity_type: Type of entity
            name: Human-readable name
            properties: Entity properties

        Returns:
            The created Entity
        """
        entity = Entity(
            id=entity_id,
            entity_type=entity_type,
            name=name,
            properties=properties or {}
        )

        self._entities[entity_id] = entity
        self._adjacency[entity_id] = set()

        logger.debug(f"Added entity: {entity_type.value}/{entity_id}")

        return entity

    def get_entity(self, entity_id: str) -> Optional[Entity]:
        """Gets an entity by ID"""
        return self._entities.get(entity_id)

    def update_entity(
        self,
        entity_id: str,
        name: Optional[str] = None,
        properties: Optional[Dict[str, Any]] = None
    ) -> Optional[Entity]:
        """
        Updates an entity.

        Args:
            entity_id: Entity to update
            name: New name (optional)
            properties: Properties to update (merged)

        Returns:
            Updated entity or None if not found
        """
        entity = self._entities.get(entity_id)
        if not entity:
            return None

        if name:
            entity.name = name

        if properties:
            entity.properties.update(properties)

        entity.updated_at = datetime.now(timezone.utc)

        return entity

    def delete_entity(self, entity_id: str) -> bool:
        """
        Deletes an entity and its relationships.

        Args:
            entity_id: Entity to delete

        Returns:
            True if deleted
        """
        if entity_id not in self._entities:
            return False

        # Delete all relationships involving this entity
        rel_ids = list(self._adjacency.get(entity_id, set()))
        for rel_id in rel_ids:
            self._delete_relationship_internal(rel_id)

        del self._entities[entity_id]
        del self._adjacency[entity_id]

        return True

    def get_entities_by_type(
        self,
        entity_type: EntityType
    ) -> List[Entity]:
        """Gets all entities of a specific type"""
        return [
            e for e in self._entities.values()
            if e.entity_type == entity_type
        ]

    def search_entities(
        self,
        query: str,
        entity_type: Optional[EntityType] = None,
        limit: int = 10
    ) -> List[Entity]:
        """
        Searches entities by name.

        Args:
            query: Search query (case-insensitive substring)
            entity_type: Optional type filter
            limit: Maximum results

        Returns:
            Matching entities
        """
        query_lower = query.lower()
        results = []

        for entity in self._entities.values():
            if entity_type and entity.entity_type != entity_type:
                continue

            if query_lower in entity.name.lower():
                results.append(entity)
                if len(results) >= limit:
                    break

        return results
    # Relationship Operations

    def add_relationship(
        self,
        relationship_id: str,
        source_id: str,
        target_id: str,
        relationship_type: RelationshipType,
        properties: Optional[Dict[str, Any]] = None,
        weight: float = 1.0
    ) -> Optional[Relationship]:
        """
        Adds a relationship between two entities.

        Args:
            relationship_id: Unique relationship identifier
            source_id: Source entity ID
            target_id: Target entity ID
            relationship_type: Type of relationship
            properties: Relationship properties
            weight: Relationship weight/strength

        Returns:
            The created Relationship or None if entities don't exist
        """
        if source_id not in self._entities or target_id not in self._entities:
            logger.warning(
                f"Cannot create relationship: entity not found "
                f"(source={source_id}, target={target_id})"
            )
            return None

        relationship = Relationship(
            id=relationship_id,
            source_id=source_id,
            target_id=target_id,
            relationship_type=relationship_type,
            properties=properties or {},
            weight=weight
        )

        self._relationships[relationship_id] = relationship
        self._adjacency[source_id].add(relationship_id)
        self._adjacency[target_id].add(relationship_id)

        logger.debug(
            f"Added relationship: {source_id} -[{relationship_type.value}]-> {target_id}"
        )

        return relationship

    def get_relationship(self, relationship_id: str) -> Optional[Relationship]:
        """Gets a relationship by ID"""
        return self._relationships.get(relationship_id)

    def delete_relationship(self, relationship_id: str) -> bool:
        """Deletes a relationship"""
        return self._delete_relationship_internal(relationship_id)

    def _delete_relationship_internal(self, relationship_id: str) -> bool:
        """Internal relationship deletion"""
        relationship = self._relationships.get(relationship_id)
        if not relationship:
            return False

        # Remove from adjacency lists
        if relationship.source_id in self._adjacency:
            self._adjacency[relationship.source_id].discard(relationship_id)
        if relationship.target_id in self._adjacency:
            self._adjacency[relationship.target_id].discard(relationship_id)

        del self._relationships[relationship_id]
        return True
    # Graph Traversal

    def get_neighbors(
        self,
        entity_id: str,
        relationship_type: Optional[RelationshipType] = None,
        direction: str = "both"  # "outgoing", "incoming", "both"
    ) -> List[Tuple[Entity, Relationship]]:
        """
        Gets neighboring entities.

        Args:
            entity_id: Starting entity
            relationship_type: Optional filter by relationship type
            direction: Direction to traverse

        Returns:
            List of (entity, relationship) tuples
        """
        if entity_id not in self._entities:
            return []

        results = []
        rel_ids = self._adjacency.get(entity_id, set())

        for rel_id in rel_ids:
            rel = self._relationships.get(rel_id)
            if not rel:
                continue

            # Filter by relationship type
            if relationship_type and rel.relationship_type != relationship_type:
                continue

            # Determine neighbor based on direction
            neighbor_id = None
            if direction == "outgoing" and rel.source_id == entity_id:
                neighbor_id = rel.target_id
            elif direction == "incoming" and rel.target_id == entity_id:
                neighbor_id = rel.source_id
            elif direction == "both":
                neighbor_id = rel.target_id if rel.source_id == entity_id else rel.source_id

            if neighbor_id:
                neighbor = self._entities.get(neighbor_id)
                if neighbor:
                    results.append((neighbor, rel))

        return results

    def get_path(
        self,
        source_id: str,
        target_id: str,
        max_depth: int = 5
    ) -> Optional[List[Tuple[Entity, Optional[Relationship]]]]:
        """
        Finds a path between two entities using BFS.

        Args:
            source_id: Starting entity
            target_id: Target entity
            max_depth: Maximum path length

        Returns:
            Path as list of (entity, relationship) tuples, or None if no path
        """
        if source_id not in self._entities or target_id not in self._entities:
            return None

        if source_id == target_id:
            return [(self._entities[source_id], None)]

        # BFS
        from collections import deque

        visited = {source_id}
        # Queue items: (entity_id, path so far)
        queue = deque([(source_id, [(self._entities[source_id], None)])])

        while queue:
            current_id, path = queue.popleft()

            if len(path) > max_depth:
                continue

            for neighbor, rel in self.get_neighbors(current_id):
                if neighbor.id == target_id:
                    return path + [(neighbor, rel)]

                if neighbor.id not in visited:
                    visited.add(neighbor.id)
                    queue.append((neighbor.id, path + [(neighbor, rel)]))

        return None

    def get_subgraph(
        self,
        entity_id: str,
        depth: int = 2
    ) -> Tuple[List[Entity], List[Relationship]]:
        """
        Gets a subgraph around an entity.

        Args:
            entity_id: Center entity
            depth: How many hops to include

        Returns:
            Tuple of (entities, relationships)
        """
        if entity_id not in self._entities:
            return [], []

        entities_set: Set[str] = {entity_id}
        relationships_set: Set[str] = set()
        frontier: Set[str] = {entity_id}

        for _ in range(depth):
            next_frontier: Set[str] = set()

            for e_id in frontier:
                for neighbor, rel in self.get_neighbors(e_id):
                    if neighbor.id not in entities_set:
                        entities_set.add(neighbor.id)
                        next_frontier.add(neighbor.id)
                    relationships_set.add(rel.id)

            frontier = next_frontier

        entities = [self._entities[e_id] for e_id in entities_set]
        relationships = [self._relationships[r_id] for r_id in relationships_set]

        return entities, relationships
    # Serialization

    def to_dict(self) -> Dict[str, Any]:
        """Serializes the graph to a dictionary"""
        return {
            "entities": [e.to_dict() for e in self._entities.values()],
            "relationships": [r.to_dict() for r in self._relationships.values()]
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any], **kwargs) -> "KnowledgeGraph":
        """Deserializes a graph from a dictionary"""
        graph = cls(**kwargs)

        # Load entities first
        for e_data in data.get("entities", []):
            entity = Entity.from_dict(e_data)
            graph._entities[entity.id] = entity
            graph._adjacency[entity.id] = set()

        # Load relationships
        for r_data in data.get("relationships", []):
            rel = Relationship.from_dict(r_data)
            graph._relationships[rel.id] = rel
            if rel.source_id in graph._adjacency:
                graph._adjacency[rel.source_id].add(rel.id)
            if rel.target_id in graph._adjacency:
                graph._adjacency[rel.target_id].add(rel.id)

        return graph

    def export_json(self) -> str:
        """Exports graph to JSON string"""
        return json.dumps(self.to_dict(), indent=2)

    @classmethod
    def import_json(cls, json_str: str, **kwargs) -> "KnowledgeGraph":
        """Imports graph from JSON string"""
        return cls.from_dict(json.loads(json_str), **kwargs)

    # Statistics

    def get_statistics(self) -> Dict[str, Any]:
        """Gets graph statistics"""
        entity_types: Dict[str, int] = {}
        for entity in self._entities.values():
            entity_types[entity.entity_type.value] = \
                entity_types.get(entity.entity_type.value, 0) + 1

        rel_types: Dict[str, int] = {}
        for rel in self._relationships.values():
            rel_types[rel.relationship_type.value] = \
                rel_types.get(rel.relationship_type.value, 0) + 1

        return {
            "total_entities": len(self._entities),
            "total_relationships": len(self._relationships),
            "entity_types": entity_types,
            "relationship_types": rel_types,
            "avg_connections": (
                sum(len(adj) for adj in self._adjacency.values()) /
                max(len(self._adjacency), 1)
            )
        }

    @property
    def entity_count(self) -> int:
        """Returns number of entities"""
        return len(self._entities)

    @property
    def relationship_count(self) -> int:
        """Returns number of relationships"""
        return len(self._relationships)
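The BFS inside `KnowledgeGraph.get_path` can be reduced to a self-contained sketch. Plain string nodes and an undirected adjacency dict stand in for `Entity`/`Relationship` here; `find_path` and the sample IDs are illustrative, not part of the module:

```python
from collections import deque

def find_path(adjacency, source, target, max_depth=5):
    """Breadth-first search returning a node path, or None if unreachable."""
    if source == target:
        return [source]
    visited = {source}
    queue = deque([(source, [source])])
    while queue:
        current, path = queue.popleft()
        if len(path) > max_depth:
            continue  # prune paths that exceed the depth budget
        for neighbor in adjacency.get(current, ()):
            if neighbor == target:
                return path + [neighbor]
            if neighbor not in visited:
                visited.add(neighbor)
                queue.append((neighbor, path + [neighbor]))
    return None

# Toy graph mirroring the entity types above (hypothetical IDs)
adjacency = {
    "student:1": ["class:7a"],
    "class:7a": ["student:1", "subject:math"],
    "subject:math": ["class:7a", "topic:algebra"],
    "topic:algebra": ["subject:math"],
}
print(find_path(adjacency, "student:1", "topic:algebra"))
# → ['student:1', 'class:7a', 'subject:math', 'topic:algebra']
```

Because BFS explores in order of path length, the first path found is also a shortest one, which is why `get_path` can return immediately on hitting the target.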
568
agent-core/brain/memory_store.py
Normal file
@@ -0,0 +1,568 @@
"""
Memory Store for Breakpilot Agents

Provides long-term memory with:
- TTL-based expiration
- Access count tracking
- Pattern-based search
- Hybrid Valkey + PostgreSQL persistence
"""

from typing import List, Dict, Any, Optional
from datetime import datetime, timezone, timedelta
from dataclasses import dataclass, field
import json
import logging
import hashlib

logger = logging.getLogger(__name__)


@dataclass
class Memory:
    """Represents a stored memory item"""
    key: str
    value: Any
    agent_id: str
    created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
    expires_at: Optional[datetime] = None
    access_count: int = 0
    last_accessed: Optional[datetime] = None
    metadata: Dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> Dict[str, Any]:
        return {
            "key": self.key,
            "value": self.value,
            "agent_id": self.agent_id,
            "created_at": self.created_at.isoformat(),
            "expires_at": self.expires_at.isoformat() if self.expires_at else None,
            "access_count": self.access_count,
            "last_accessed": self.last_accessed.isoformat() if self.last_accessed else None,
            "metadata": self.metadata
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "Memory":
        return cls(
            key=data["key"],
            value=data["value"],
            agent_id=data["agent_id"],
            created_at=datetime.fromisoformat(data["created_at"]),
            expires_at=datetime.fromisoformat(data["expires_at"]) if data.get("expires_at") else None,
            access_count=data.get("access_count", 0),
            last_accessed=datetime.fromisoformat(data["last_accessed"]) if data.get("last_accessed") else None,
            metadata=data.get("metadata", {})
        )

    def is_expired(self) -> bool:
        """Check if the memory has expired"""
        if not self.expires_at:
            return False
        return datetime.now(timezone.utc) > self.expires_at
class MemoryStore:
    """
    Long-term memory store for agents.

    Stores facts, decisions, and learning progress with:
    - TTL-based expiration
    - Access tracking for importance scoring
    - Pattern-based retrieval
    - Hybrid persistence (Valkey for fast access, PostgreSQL for durability)
    """

    def __init__(
        self,
        redis_client=None,
        db_pool=None,
        namespace: str = "breakpilot"
    ):
        """
        Initialize the memory store.

        Args:
            redis_client: Async Redis/Valkey client
            db_pool: Async PostgreSQL connection pool
            namespace: Key namespace for isolation
        """
        self.redis = redis_client
        self.db_pool = db_pool
        self.namespace = namespace
        self._local_cache: Dict[str, Memory] = {}

    def _redis_key(self, key: str) -> str:
        """Generate Redis key with namespace"""
        return f"{self.namespace}:memory:{key}"

    def _hash_key(self, key: str) -> str:
        """Generate a hash for long keys"""
        if len(key) > 200:
            return hashlib.sha256(key.encode()).hexdigest()[:32]
        return key

    async def remember(
        self,
        key: str,
        value: Any,
        agent_id: str,
        ttl_days: int = 30,
        metadata: Optional[Dict[str, Any]] = None
    ) -> Memory:
        """
        Stores a memory.

        Args:
            key: Unique key for the memory
            value: Value to store (must be JSON-serializable)
            agent_id: ID of the agent storing the memory
            ttl_days: Time to live in days (0 = no expiration)
            metadata: Optional additional metadata

        Returns:
            The created Memory object
        """
        expires_at = None
        if ttl_days > 0:
            expires_at = datetime.now(timezone.utc) + timedelta(days=ttl_days)

        memory = Memory(
            key=key,
            value=value,
            agent_id=agent_id,
            expires_at=expires_at,
            metadata=metadata or {}
        )

        # Store in all layers
        await self._store_memory(memory, ttl_days)

        logger.debug(f"Agent {agent_id} remembered '{key}' (TTL: {ttl_days} days)")

        return memory

    async def recall(self, key: str) -> Optional[Any]:
        """
        Retrieves a memory value by key.

        Args:
            key: The memory key

        Returns:
            The stored value or None if not found/expired
        """
        memory = await self.get_memory(key)
        if memory:
            return memory.value
        return None

    async def get_memory(self, key: str) -> Optional[Memory]:
        """
        Retrieves a full Memory object by key.

        Updates access count and last_accessed timestamp.

        Args:
            key: The memory key

        Returns:
            Memory object or None if not found/expired
        """
        # Check local cache
        if key in self._local_cache:
            memory = self._local_cache[key]
            if not memory.is_expired():
                await self._update_access(memory)
                return memory
            else:
                del self._local_cache[key]

        # Try Valkey
        memory = await self._get_from_valkey(key)
        if memory and not memory.is_expired():
            self._local_cache[key] = memory
            await self._update_access(memory)
            return memory

        # Try PostgreSQL
        memory = await self._get_from_postgres(key)
        if memory and not memory.is_expired():
            # Re-cache in Valkey
            await self._cache_in_valkey(memory)
            self._local_cache[key] = memory
            await self._update_access(memory)
            return memory

        return None

    async def forget(self, key: str) -> bool:
        """
        Deletes a memory.

        Args:
            key: The memory key to delete

        Returns:
            True if deleted, False if not found
        """
        # Remove from local cache
        self._local_cache.pop(key, None)

        # Remove from Valkey
        if self.redis:
            await self.redis.delete(self._redis_key(key))

        # Delete from PostgreSQL
        if self.db_pool:
            async with self.db_pool.acquire() as conn:
                result = await conn.execute(
                    """
                    DELETE FROM agent_memory
                    WHERE namespace = $1 AND key = $2
                    """,
                    self.namespace,
                    key
                )
                return "DELETE" in result

        return True
    async def search(
        self,
        pattern: str,
        agent_id: Optional[str] = None,
        limit: int = 100
    ) -> List[Memory]:
        """
        Searches for memories matching a pattern.

        Args:
            pattern: SQL LIKE pattern (e.g., "evaluation:math:%")
            agent_id: Optional filter by agent ID
            limit: Maximum results to return

        Returns:
            List of matching Memory objects
        """
        results = []

        # Search PostgreSQL (primary source for patterns)
        if self.db_pool:
            async with self.db_pool.acquire() as conn:
                query = """
                    SELECT key, value, agent_id, created_at, expires_at,
                           access_count, last_accessed, metadata
                    FROM agent_memory
                    WHERE namespace = $1
                      AND key LIKE $2
                      AND (expires_at IS NULL OR expires_at > NOW())
                """
                params = [self.namespace, pattern.replace("*", "%")]

                if agent_id:
                    query += " AND agent_id = $3"
                    params.append(agent_id)

                query += f" ORDER BY access_count DESC, created_at DESC LIMIT {limit}"

                rows = await conn.fetch(query, *params)

                for row in rows:
                    results.append(self._row_to_memory(row))
        else:
            # Fall back to local cache search
            import fnmatch
            for key, memory in self._local_cache.items():
                if fnmatch.fnmatch(key, pattern):
                    if agent_id is None or memory.agent_id == agent_id:
                        if not memory.is_expired():
                            results.append(memory)
                            if len(results) >= limit:
                                break

        return results

    async def get_by_agent(
        self,
        agent_id: str,
        limit: int = 100
    ) -> List[Memory]:
        """
        Gets all memories for a specific agent.

        Args:
            agent_id: The agent ID
            limit: Maximum results

        Returns:
            List of Memory objects
        """
        return await self.search("*", agent_id=agent_id, limit=limit)

    async def get_recent(
        self,
        hours: int = 24,
        agent_id: Optional[str] = None,
        limit: int = 100
    ) -> List[Memory]:
        """
        Gets recently created memories.

        Args:
            hours: How far back to look
            agent_id: Optional filter by agent
            limit: Maximum results

        Returns:
            List of Memory objects
        """
        cutoff = datetime.now(timezone.utc) - timedelta(hours=hours)

        if self.db_pool:
            async with self.db_pool.acquire() as conn:
                query = """
                    SELECT key, value, agent_id, created_at, expires_at,
                           access_count, last_accessed, metadata
                    FROM agent_memory
                    WHERE namespace = $1
                      AND created_at > $2
                      AND (expires_at IS NULL OR expires_at > NOW())
                """
                params = [self.namespace, cutoff]

                if agent_id:
                    query += " AND agent_id = $3"
                    params.append(agent_id)

                query += f" ORDER BY created_at DESC LIMIT {limit}"

                rows = await conn.fetch(query, *params)
                return [self._row_to_memory(row) for row in rows]

        # Fall back to local cache
        results = []
        for memory in self._local_cache.values():
            if memory.created_at > cutoff:
                if agent_id is None or memory.agent_id == agent_id:
                    if not memory.is_expired():
                        results.append(memory)

        return sorted(results, key=lambda m: m.created_at, reverse=True)[:limit]

    async def get_most_accessed(
        self,
        limit: int = 10,
        agent_id: Optional[str] = None
    ) -> List[Memory]:
        """
        Gets the most frequently accessed memories.

        Args:
            limit: Number of results
            agent_id: Optional filter by agent

        Returns:
            List of Memory objects ordered by access count
        """
        if self.db_pool:
            async with self.db_pool.acquire() as conn:
                query = """
                    SELECT key, value, agent_id, created_at, expires_at,
                           access_count, last_accessed, metadata
                    FROM agent_memory
                    WHERE namespace = $1
                      AND (expires_at IS NULL OR expires_at > NOW())
                """
                params = [self.namespace]

                if agent_id:
                    query += " AND agent_id = $2"
                    params.append(agent_id)

                query += f" ORDER BY access_count DESC LIMIT {limit}"

                rows = await conn.fetch(query, *params)
                return [self._row_to_memory(row) for row in rows]

        # Fall back to local cache
        valid = [m for m in self._local_cache.values() if not m.is_expired()]
        if agent_id:
            valid = [m for m in valid if m.agent_id == agent_id]
        return sorted(valid, key=lambda m: m.access_count, reverse=True)[:limit]
async def cleanup_expired(self) -> int:
|
||||
"""
|
||||
Removes expired memories.
|
||||
|
||||
Returns:
|
||||
Number of memories removed
|
||||
"""
|
||||
count = 0
|
||||
|
||||
# Clean local cache
|
||||
expired_keys = [
|
||||
key for key, memory in self._local_cache.items()
|
||||
if memory.is_expired()
|
||||
]
|
||||
for key in expired_keys:
|
||||
del self._local_cache[key]
|
||||
count += 1
|
||||
|
||||
# Clean PostgreSQL
|
||||
if self.db_pool:
|
||||
async with self.db_pool.acquire() as conn:
|
||||
result = await conn.execute(
|
||||
"""
|
||||
DELETE FROM agent_memory
|
||||
WHERE namespace = $1 AND expires_at < NOW()
|
||||
""",
|
||||
self.namespace
|
||||
)
|
||||
# Parse count from result
|
||||
if result:
|
||||
parts = result.split()
|
||||
if len(parts) >= 2:
|
||||
count += int(parts[1])
|
||||
|
||||
logger.info(f"Cleaned up {count} expired memories")
|
||||
return count
|
||||
|
||||
async def _store_memory(self, memory: Memory, ttl_days: int) -> None:
|
||||
"""Stores memory in all persistence layers"""
|
||||
self._local_cache[memory.key] = memory
|
||||
await self._cache_in_valkey(memory, ttl_days)
|
||||
await self._save_to_postgres(memory)
|
||||
|
||||
async def _cache_in_valkey(
|
||||
self,
|
||||
memory: Memory,
|
||||
ttl_days: Optional[int] = None
|
||||
) -> None:
|
||||
"""Caches memory in Valkey"""
|
||||
if not self.redis:
|
||||
return
|
||||
|
||||
try:
|
||||
ttl_seconds = (ttl_days or 30) * 86400
|
||||
await self.redis.setex(
|
||||
self._redis_key(memory.key),
|
||||
ttl_seconds,
|
||||
json.dumps(memory.to_dict())
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to cache memory in Valkey: {e}")
|
||||
|
||||
async def _get_from_valkey(self, key: str) -> Optional[Memory]:
|
||||
"""Retrieves memory from Valkey"""
|
||||
if not self.redis:
|
||||
return None
|
||||
|
||||
try:
|
||||
data = await self.redis.get(self._redis_key(key))
|
||||
if data:
|
||||
return Memory.from_dict(json.loads(data))
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get memory from Valkey: {e}")
|
||||
|
||||
return None
|
||||
|
||||
async def _save_to_postgres(self, memory: Memory) -> None:
|
||||
"""Saves memory to PostgreSQL"""
|
||||
if not self.db_pool:
|
||||
return
|
||||
|
||||
try:
|
||||
async with self.db_pool.acquire() as conn:
|
||||
await conn.execute(
|
||||
"""
|
||||
INSERT INTO agent_memory
|
||||
(namespace, key, value, agent_id, access_count,
|
||||
created_at, expires_at, last_accessed, metadata)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
|
||||
                    ON CONFLICT (namespace, key) DO UPDATE SET
                        value = EXCLUDED.value,
                        access_count = agent_memory.access_count + 1,
                        last_accessed = NOW(),
                        metadata = EXCLUDED.metadata
                    """,
                    self.namespace,
                    memory.key,
                    json.dumps(memory.value),
                    memory.agent_id,
                    memory.access_count,
                    memory.created_at,
                    memory.expires_at,
                    memory.last_accessed,
                    json.dumps(memory.metadata)
                )
        except Exception as e:
            logger.error(f"Failed to save memory to PostgreSQL: {e}")

    async def _get_from_postgres(self, key: str) -> Optional[Memory]:
        """Retrieves memory from PostgreSQL"""
        if not self.db_pool:
            return None

        try:
            async with self.db_pool.acquire() as conn:
                row = await conn.fetchrow(
                    """
                    SELECT key, value, agent_id, created_at, expires_at,
                           access_count, last_accessed, metadata
                    FROM agent_memory
                    WHERE namespace = $1 AND key = $2
                    """,
                    self.namespace,
                    key
                )

                if row:
                    return self._row_to_memory(row)
        except Exception as e:
            logger.error(f"Failed to get memory from PostgreSQL: {e}")

        return None

    async def _update_access(self, memory: Memory) -> None:
        """Updates access count and timestamp"""
        memory.access_count += 1
        memory.last_accessed = datetime.now(timezone.utc)

        # Update in PostgreSQL
        if self.db_pool:
            try:
                async with self.db_pool.acquire() as conn:
                    await conn.execute(
                        """
                        UPDATE agent_memory
                        SET access_count = access_count + 1,
                            last_accessed = NOW()
                        WHERE namespace = $1 AND key = $2
                        """,
                        self.namespace,
                        memory.key
                    )
            except Exception as e:
                logger.warning(f"Failed to update access count: {e}")

    def _row_to_memory(self, row) -> Memory:
        """Converts a database row to a Memory object"""
        value = row["value"]
        if isinstance(value, str):
            value = json.loads(value)

        metadata = row.get("metadata", {})
        if isinstance(metadata, str):
            metadata = json.loads(metadata)

        return Memory(
            key=row["key"],
            value=value,
            agent_id=row["agent_id"],
            created_at=row["created_at"],
            expires_at=row["expires_at"],
            access_count=row["access_count"] or 0,
            last_accessed=row["last_accessed"],
            metadata=metadata
        )
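The memory store above relies on PostgreSQL's `ON CONFLICT (…) DO UPDATE` upsert with the `EXCLUDED` pseudo-table: a first write inserts, a repeated write for the same `(namespace, key)` replaces the value and bumps `access_count`. A minimal sketch of the same semantics using SQLite (which supports the identical clause since 3.24), with a reduced stand-in for the `agent_memory` table:

```python
import json
import sqlite3

# In-memory stand-in for the agent_memory table (columns reduced for brevity).
conn = sqlite3.connect(":memory:")
conn.execute("""
    CREATE TABLE agent_memory (
        namespace    TEXT,
        key          TEXT,
        value        TEXT,
        access_count INTEGER DEFAULT 0,
        PRIMARY KEY (namespace, key)
    )
""")

def save_memory(namespace: str, key: str, value) -> None:
    # First write inserts; a repeated key updates the value and bumps access_count.
    conn.execute(
        """
        INSERT INTO agent_memory (namespace, key, value, access_count)
        VALUES (?, ?, ?, 0)
        ON CONFLICT (namespace, key) DO UPDATE SET
            value = excluded.value,
            access_count = agent_memory.access_count + 1
        """,
        (namespace, key, json.dumps(value)),
    )

save_memory("breakpilot", "lesson", {"topic": "algebra"})
save_memory("breakpilot", "lesson", {"topic": "geometry"})

value, count = conn.execute(
    "SELECT value, access_count FROM agent_memory WHERE namespace = ? AND key = ?",
    ("breakpilot", "lesson"),
).fetchone()
print(json.loads(value)["topic"], count)  # geometry 1
```

The production code additionally carries `agent_id`, timestamps, and a JSON `metadata` column, but the conflict-resolution logic is the same.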
36
agent-core/orchestrator/__init__.py
Normal file
@@ -0,0 +1,36 @@
"""
Orchestration Layer for Breakpilot Agents

Provides:
- MessageBus: Inter-agent communication via Redis Pub/Sub
- AgentSupervisor: Agent lifecycle and health management
- TaskRouter: Intent-based task routing
"""

from agent_core.orchestrator.message_bus import (
    MessageBus,
    AgentMessage,
    MessagePriority,
)
from agent_core.orchestrator.supervisor import (
    AgentSupervisor,
    AgentInfo,
    AgentStatus,
)
from agent_core.orchestrator.task_router import (
    TaskRouter,
    RoutingResult,
    RoutingStrategy,
)

__all__ = [
    "MessageBus",
    "AgentMessage",
    "MessagePriority",
    "AgentSupervisor",
    "AgentInfo",
    "AgentStatus",
    "TaskRouter",
    "RoutingResult",
    "RoutingStrategy",
]
479
agent-core/orchestrator/message_bus.py
Normal file
@@ -0,0 +1,479 @@
"""
Message Bus for Inter-Agent Communication

Provides:
- Pub/Sub messaging via Redis/Valkey
- Request-response pattern with timeouts
- Priority-based message handling
- Message persistence for audit
"""

from typing import Callable, Dict, Any, List, Optional, Awaitable
from dataclasses import dataclass, field
from datetime import datetime, timezone
from enum import Enum
import asyncio
import uuid
import json
import logging

logger = logging.getLogger(__name__)


class MessagePriority(Enum):
    """Message priority levels"""
    LOW = 0
    NORMAL = 1
    HIGH = 2
    CRITICAL = 3


class MessageType(Enum):
    """Standard message types"""
    REQUEST = "request"
    RESPONSE = "response"
    EVENT = "event"
    BROADCAST = "broadcast"
    HEARTBEAT = "heartbeat"


@dataclass
class AgentMessage:
    """Represents a message between agents"""
    sender: str
    receiver: str
    message_type: str
    payload: Dict[str, Any]
    priority: MessagePriority = MessagePriority.NORMAL
    correlation_id: str = field(default_factory=lambda: str(uuid.uuid4()))
    timestamp: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
    reply_to: Optional[str] = None
    expires_at: Optional[datetime] = None
    metadata: Dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> Dict[str, Any]:
        return {
            "sender": self.sender,
            "receiver": self.receiver,
            "message_type": self.message_type,
            "payload": self.payload,
            "priority": self.priority.value,
            "correlation_id": self.correlation_id,
            "timestamp": self.timestamp.isoformat(),
            "reply_to": self.reply_to,
            "expires_at": self.expires_at.isoformat() if self.expires_at else None,
            "metadata": self.metadata
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "AgentMessage":
        return cls(
            sender=data["sender"],
            receiver=data["receiver"],
            message_type=data["message_type"],
            payload=data["payload"],
            priority=MessagePriority(data.get("priority", 1)),
            correlation_id=data.get("correlation_id", str(uuid.uuid4())),
            timestamp=datetime.fromisoformat(data["timestamp"]) if "timestamp" in data else datetime.now(timezone.utc),
            reply_to=data.get("reply_to"),
            expires_at=datetime.fromisoformat(data["expires_at"]) if data.get("expires_at") else None,
            metadata=data.get("metadata", {})
        )
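The `to_dict`/`from_dict` pair follows a fixed wire convention: enums are serialized by value and datetimes as ISO-8601 strings, so a message survives a JSON round trip over Redis. A standalone sketch of just that convention (not importing `agent_core`):

```python
from datetime import datetime, timezone
from enum import Enum
import json
import uuid

class MessagePriority(Enum):
    LOW = 0
    NORMAL = 1
    HIGH = 2
    CRITICAL = 3

# Serialize the way to_dict does: enum by value, datetime as ISO-8601 string.
wire = json.dumps({
    "priority": MessagePriority.HIGH.value,
    "correlation_id": str(uuid.uuid4()),
    "timestamp": datetime.now(timezone.utc).isoformat(),
})

# Deserialize the way from_dict does: Enum(value) and fromisoformat().
data = json.loads(wire)
priority = MessagePriority(data["priority"])
timestamp = datetime.fromisoformat(data["timestamp"])

print(priority, timestamp.tzinfo is not None)
```

Because `datetime.now(timezone.utc)` is timezone-aware, `fromisoformat` restores the UTC offset, which keeps timestamp comparisons across agents well-defined.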

MessageHandler = Callable[[AgentMessage], Awaitable[Optional[Dict[str, Any]]]]


class MessageBus:
    """
    Inter-agent communication via Redis Pub/Sub.

    Features:
    - Publish/subscribe pattern
    - Request/response with timeout
    - Priority queues
    - Message persistence for audit
    """

    def __init__(
        self,
        redis_client=None,
        db_pool=None,
        namespace: str = "breakpilot",
        persist_messages: bool = True
    ):
        """
        Initialize the message bus.

        Args:
            redis_client: Async Redis/Valkey client
            db_pool: Async PostgreSQL pool for persistence
            namespace: Channel namespace
            persist_messages: Whether to persist messages for audit
        """
        self.redis = redis_client
        self.db_pool = db_pool
        self.namespace = namespace
        self.persist_messages = persist_messages

        self._handlers: Dict[str, List[MessageHandler]] = {}
        self._pending_responses: Dict[str, asyncio.Future] = {}
        self._subscriptions: Dict[str, asyncio.Task] = {}
        self._running = False

    def _channel(self, agent_id: str) -> str:
        """Generate the channel name for an agent"""
        return f"{self.namespace}:agent:{agent_id}"

    def _broadcast_channel(self) -> str:
        """Generate the broadcast channel name"""
        return f"{self.namespace}:broadcast"

    async def start(self) -> None:
        """Starts the message bus"""
        self._running = True
        logger.info("Message bus started")

    async def stop(self) -> None:
        """Stops the message bus and all subscriptions"""
        self._running = False

        # Cancel all subscription tasks
        for task in self._subscriptions.values():
            task.cancel()

        # Wait for cancellation
        if self._subscriptions:
            await asyncio.gather(
                *self._subscriptions.values(),
                return_exceptions=True
            )

        self._subscriptions.clear()
        logger.info("Message bus stopped")

    async def subscribe(
        self,
        agent_id: str,
        handler: MessageHandler
    ) -> None:
        """
        Subscribe an agent to receive messages.

        Args:
            agent_id: The agent ID to subscribe
            handler: Async function to handle incoming messages
        """
        if agent_id in self._subscriptions:
            logger.warning(f"Agent {agent_id} already subscribed")
            return

        if agent_id not in self._handlers:
            self._handlers[agent_id] = []
        self._handlers[agent_id].append(handler)

        if self.redis:
            # Start the Redis subscription loop
            task = asyncio.create_task(
                self._subscription_loop(agent_id)
            )
            self._subscriptions[agent_id] = task

        logger.info(f"Agent {agent_id} subscribed to message bus")

    async def unsubscribe(self, agent_id: str) -> None:
        """
        Unsubscribe an agent from messages.

        Args:
            agent_id: The agent ID to unsubscribe
        """
        if agent_id in self._subscriptions:
            self._subscriptions[agent_id].cancel()
            try:
                await self._subscriptions[agent_id]
            except asyncio.CancelledError:
                pass
            del self._subscriptions[agent_id]

        self._handlers.pop(agent_id, None)
        logger.info(f"Agent {agent_id} unsubscribed from message bus")

    async def _subscription_loop(self, agent_id: str) -> None:
        """Main subscription loop for an agent"""
        if not self.redis:
            return

        channel = self._channel(agent_id)
        broadcast = self._broadcast_channel()

        pubsub = self.redis.pubsub()
        await pubsub.subscribe(channel, broadcast)

        try:
            while self._running:
                message = await pubsub.get_message(
                    ignore_subscribe_messages=True,
                    timeout=1.0
                )

                if message and message["type"] == "message":
                    await self._handle_incoming_message(
                        agent_id,
                        message["data"]
                    )
        except asyncio.CancelledError:
            pass
        finally:
            await pubsub.unsubscribe(channel, broadcast)
            await pubsub.close()

    async def _handle_incoming_message(
        self,
        agent_id: str,
        raw_data: bytes
    ) -> None:
        """Process an incoming message"""
        try:
            data = json.loads(raw_data)
            message = AgentMessage.from_dict(data)

            # Check if this is a response to a pending request
            if message.correlation_id in self._pending_responses:
                future = self._pending_responses[message.correlation_id]
                if not future.done():
                    future.set_result(message.payload)
                return

            # Call handlers
            handlers = self._handlers.get(agent_id, [])
            for handler in handlers:
                try:
                    response = await handler(message)

                    # If the handler returns data and there's a reply_to, send a response
                    if response and message.reply_to:
                        await self.publish(AgentMessage(
                            sender=agent_id,
                            receiver=message.sender,
                            message_type="response",
                            payload=response,
                            correlation_id=message.correlation_id
                        ))

                except Exception as e:
                    logger.error(
                        f"Error in message handler for {agent_id}: {e}"
                    )

        except Exception as e:
            logger.error(f"Error processing message: {e}")

    async def publish(self, message: AgentMessage) -> None:
        """
        Publishes a message to an agent.

        Args:
            message: The message to publish
        """
        # Persist the message if enabled
        if self.persist_messages:
            await self._persist_message(message)

        if self.redis:
            channel = self._channel(message.receiver)
            await self.redis.publish(
                channel,
                json.dumps(message.to_dict())
            )
        else:
            # Local delivery for testing
            await self._local_deliver(message)

        logger.debug(
            f"Published message from {message.sender} to {message.receiver}: "
            f"{message.message_type}"
        )

    async def broadcast(self, message: AgentMessage) -> None:
        """
        Broadcasts a message to all agents.

        Args:
            message: The message to broadcast
        """
        message.receiver = "*"  # Indicates a broadcast

        if self.persist_messages:
            await self._persist_message(message)

        if self.redis:
            await self.redis.publish(
                self._broadcast_channel(),
                json.dumps(message.to_dict())
            )
        else:
            # Local broadcast
            for agent_id in self._handlers:
                await self._local_deliver(message, agent_id)

        logger.debug(f"Broadcast message from {message.sender}: {message.message_type}")

    async def request(
        self,
        message: AgentMessage,
        timeout: float = 30.0
    ) -> Dict[str, Any]:
        """
        Sends a request and waits for a response.

        Args:
            message: The request message
            timeout: Timeout in seconds

        Returns:
            Response payload

        Raises:
            TimeoutError: If no response arrives within the timeout
        """
        # Mark this as a request that needs a response
        message.reply_to = message.sender

        # Create a future for the response
        future: asyncio.Future = asyncio.Future()
        self._pending_responses[message.correlation_id] = future

        try:
            # Publish the request
            await self.publish(message)

            # Wait for the response
            return await asyncio.wait_for(future, timeout)

        except asyncio.TimeoutError:
            logger.warning(
                f"Request timeout: {message.sender} -> {message.receiver} "
                f"({message.message_type})"
            )
            raise

        finally:
            # Clean up
            self._pending_responses.pop(message.correlation_id, None)
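The request/response pattern above builds synchronous-looking RPC on top of fire-and-forget pub/sub: the caller parks an `asyncio.Future` keyed by `correlation_id`, and whoever receives the matching response resolves it. A minimal standalone sketch of that mechanism (the `responder` coroutine stands in for the remote agent; names are illustrative):

```python
import asyncio

# Futures keyed by correlation_id, as in MessageBus._pending_responses.
pending: dict = {}

async def responder(correlation_id: str) -> None:
    # Stands in for the remote agent: resolves the parked future after some work.
    await asyncio.sleep(0.01)
    fut = pending.get(correlation_id)
    if fut and not fut.done():
        fut.set_result({"score": 85})

async def request(correlation_id: str, timeout: float = 1.0) -> dict:
    future = asyncio.get_running_loop().create_future()
    pending[correlation_id] = future
    try:
        # "Publish" the request; the response arrives asynchronously.
        asyncio.create_task(responder(correlation_id))
        # Block the caller (only) until the future resolves or times out.
        return await asyncio.wait_for(future, timeout)
    finally:
        # Clean up whether we got a response or timed out.
        pending.pop(correlation_id, None)

result = asyncio.run(request("req-1"))
print(result)  # {'score': 85}
```

The `finally` cleanup matters: without it, a timed-out request would leak its entry in the pending map, and a late response would resolve a future nobody is awaiting.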

    async def _local_deliver(
        self,
        message: AgentMessage,
        target_agent: Optional[str] = None
    ) -> None:
        """Local message delivery for testing without Redis"""
        agent_id = target_agent or message.receiver
        handlers = self._handlers.get(agent_id, [])

        for handler in handlers:
            try:
                response = await handler(message)
                if response and message.reply_to:
                    if message.correlation_id in self._pending_responses:
                        future = self._pending_responses[message.correlation_id]
                        if not future.done():
                            future.set_result(response)
            except Exception as e:
                logger.error(f"Error in local handler: {e}")

    async def _persist_message(self, message: AgentMessage) -> None:
        """Persist a message to PostgreSQL for audit"""
        if not self.db_pool:
            return

        try:
            async with self.db_pool.acquire() as conn:
                await conn.execute(
                    """
                    INSERT INTO agent_messages
                        (id, sender, receiver, message_type, payload,
                         priority, correlation_id, created_at)
                    VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
                    """,
                    str(uuid.uuid4()),
                    message.sender,
                    message.receiver,
                    message.message_type,
                    json.dumps(message.payload),
                    message.priority.value,
                    message.correlation_id,
                    message.timestamp
                )
        except Exception as e:
            logger.warning(f"Failed to persist message: {e}")

    def on_message(self, message_type: str):
        """
        Decorator for message handlers.

        Note: the wrapped handler must still be registered via subscribe().

        Usage:
            @bus.on_message("grade_request")
            async def handle_grade(message):
                return {"score": 85}
        """
        def decorator(func: MessageHandler):
            async def wrapper(message: AgentMessage):
                if message.message_type == message_type:
                    return await func(message)
                return None
            return wrapper
        return decorator
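The `on_message` decorator is a type filter: the wrapper passes matching messages through to the handler and returns `None` for everything else, which lets one subscription carry several typed handlers. A standalone sketch using plain dicts in place of `AgentMessage`:

```python
import asyncio

def on_message(message_type: str):
    """Mirrors MessageBus.on_message: the wrapped handler only fires for one type."""
    def decorator(func):
        async def wrapper(message: dict):
            if message["message_type"] == message_type:
                return await func(message)
            return None  # other message types fall through untouched
        return wrapper
    return decorator

@on_message("grade_request")
async def handle_grade(message: dict) -> dict:
    return {"score": 85}

matched = asyncio.run(handle_grade({"message_type": "grade_request"}))
ignored = asyncio.run(handle_grade({"message_type": "heartbeat"}))
print(matched, ignored)  # {'score': 85} None
```

Returning `None` for non-matching types also interacts with the reply logic in `_handle_incoming_message`: only truthy handler results trigger a response message.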

    async def get_message_history(
        self,
        agent_id: Optional[str] = None,
        message_type: Optional[str] = None,
        limit: int = 100
    ) -> List[Dict[str, Any]]:
        """
        Gets message history from persistence.

        Args:
            agent_id: Filter by sender or receiver
            message_type: Filter by message type
            limit: Maximum number of results

        Returns:
            List of message dicts
        """
        if not self.db_pool:
            return []

        query = """
            SELECT sender, receiver, message_type, payload, priority,
                   correlation_id, created_at
            FROM agent_messages
            WHERE 1=1
        """
        params = []

        if agent_id:
            query += " AND (sender = $1 OR receiver = $1)"
            params.append(agent_id)

        if message_type:
            param_num = len(params) + 1
            query += f" AND message_type = ${param_num}"
            params.append(message_type)

        query += f" ORDER BY created_at DESC LIMIT {limit}"

        async with self.db_pool.acquire() as conn:
            rows = await conn.fetch(query, *params)
            return [dict(row) for row in rows]
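Because asyncpg uses positional `$n` placeholders, optional filters have to number their placeholders dynamically from the current parameter count. A small sketch of that query-building pattern in isolation (the function name is illustrative):

```python
def build_history_query(agent_id=None, message_type=None, limit=100):
    """Mirrors get_message_history's dynamic $n placeholder numbering."""
    query = "SELECT * FROM agent_messages WHERE 1=1"
    params = []
    if agent_id:
        params.append(agent_id)
        # The same $n can be reused for both sides of the OR.
        query += f" AND (sender = ${len(params)} OR receiver = ${len(params)})"
    if message_type:
        params.append(message_type)
        query += f" AND message_type = ${len(params)}"
    # limit is typed int, so interpolating it directly is safe here.
    query += f" ORDER BY created_at DESC LIMIT {int(limit)}"
    return query, params

q, p = build_history_query(agent_id="grader-1", message_type="response")
print(q)
print(p)  # ['grader-1', 'response']
```

Keeping the placeholder index tied to `len(params)` means each branch stays correct regardless of which combination of filters is active.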

    @property
    def connected(self) -> bool:
        """Returns whether the bus is connected to Redis"""
        return self.redis is not None and self._running

    @property
    def subscriber_count(self) -> int:
        """Returns the number of subscribed agents"""
        return len(self._subscriptions)
553
agent-core/orchestrator/supervisor.py
Normal file
@@ -0,0 +1,553 @@
"""
Agent Supervisor for Breakpilot

Provides:
- Agent lifecycle management
- Health monitoring
- Restart policies
- Load balancing
"""

from typing import Dict, Optional, Callable, Awaitable, List, Any
from dataclasses import dataclass, field
from datetime import datetime, timezone, timedelta
from enum import Enum
import asyncio
import logging

from agent_core.sessions.heartbeat import HeartbeatMonitor
from agent_core.orchestrator.message_bus import (
    MessageBus,
    AgentMessage,
    MessagePriority,
)

logger = logging.getLogger(__name__)


class AgentStatus(Enum):
    """Agent lifecycle states"""
    INITIALIZING = "initializing"
    STARTING = "starting"
    RUNNING = "running"
    PAUSED = "paused"
    STOPPING = "stopping"
    STOPPED = "stopped"
    ERROR = "error"
    RESTARTING = "restarting"


class RestartPolicy(Enum):
    """Agent restart policies"""
    NEVER = "never"
    ON_FAILURE = "on_failure"
    ALWAYS = "always"


@dataclass
class AgentInfo:
    """Information about a registered agent"""
    agent_id: str
    agent_type: str
    status: AgentStatus = AgentStatus.INITIALIZING
    current_task: Optional[str] = None
    started_at: Optional[datetime] = None
    last_activity: Optional[datetime] = None
    error_count: int = 0
    restart_count: int = 0
    max_restarts: int = 3
    restart_policy: RestartPolicy = RestartPolicy.ON_FAILURE
    metadata: Dict[str, Any] = field(default_factory=dict)
    capacity: int = 10  # Max concurrent tasks
    current_load: int = 0

    def is_healthy(self) -> bool:
        """Check whether the agent is healthy"""
        return self.status == AgentStatus.RUNNING and self.error_count < 3

    def is_available(self) -> bool:
        """Check whether the agent can accept new tasks"""
        return (
            self.status == AgentStatus.RUNNING and
            self.current_load < self.capacity
        )

    def utilization(self) -> float:
        """Returns agent utilization (0-1)"""
        return self.current_load / max(self.capacity, 1)


AgentFactory = Callable[[str], Awaitable[Any]]


class AgentSupervisor:
    """
    Supervises and coordinates all agents.

    Responsibilities:
    - Agent registration and lifecycle
    - Health monitoring via heartbeat
    - Restart policies
    - Load balancing
    - Alert escalation
    """

    def __init__(
        self,
        message_bus: MessageBus,
        heartbeat_monitor: Optional[HeartbeatMonitor] = None,
        check_interval_seconds: int = 10
    ):
        """
        Initialize the supervisor.

        Args:
            message_bus: Message bus for inter-agent communication
            heartbeat_monitor: Heartbeat monitor for liveness checks
            check_interval_seconds: How often to run health checks
        """
        self.bus = message_bus
        self.heartbeat = heartbeat_monitor or HeartbeatMonitor()
        self.check_interval = check_interval_seconds

        self.agents: Dict[str, AgentInfo] = {}
        self._factories: Dict[str, AgentFactory] = {}
        self._running = False
        self._health_task: Optional[asyncio.Task] = None

        # Set up the heartbeat timeout handler
        self.heartbeat.on_timeout = self._handle_agent_timeout

    async def start(self) -> None:
        """Starts the supervisor"""
        self._running = True
        await self.heartbeat.start_monitoring()

        # Start the health check loop
        self._health_task = asyncio.create_task(self._health_check_loop())

        logger.info("Agent supervisor started")

    async def stop(self) -> None:
        """Stops the supervisor and all agents"""
        self._running = False

        # Stop the health check loop
        if self._health_task:
            self._health_task.cancel()
            try:
                await self._health_task
            except asyncio.CancelledError:
                pass

        # Stop the heartbeat monitor
        await self.heartbeat.stop_monitoring()

        # Stop all agents
        for agent_id in list(self.agents.keys()):
            await self.stop_agent(agent_id)

        logger.info("Agent supervisor stopped")

    def register_factory(
        self,
        agent_type: str,
        factory: AgentFactory
    ) -> None:
        """
        Registers a factory function for creating agents.

        Args:
            agent_type: Type of agent this factory creates
            factory: Async function that creates agent instances
        """
        self._factories[agent_type] = factory
        logger.debug(f"Registered factory for agent type: {agent_type}")

    async def register_agent(
        self,
        agent_id: str,
        agent_type: str,
        restart_policy: RestartPolicy = RestartPolicy.ON_FAILURE,
        max_restarts: int = 3,
        capacity: int = 10,
        metadata: Optional[Dict[str, Any]] = None
    ) -> AgentInfo:
        """
        Registers a new agent with the supervisor.

        Args:
            agent_id: Unique agent identifier
            agent_type: Type of agent
            restart_policy: When to restart the agent
            max_restarts: Maximum restart attempts
            capacity: Max concurrent tasks
            metadata: Additional agent metadata

        Returns:
            AgentInfo for the registered agent
        """
        if agent_id in self.agents:
            logger.warning(f"Agent {agent_id} already registered")
            return self.agents[agent_id]

        info = AgentInfo(
            agent_id=agent_id,
            agent_type=agent_type,
            restart_policy=restart_policy,
            max_restarts=max_restarts,
            capacity=capacity,
            metadata=metadata or {}
        )

        self.agents[agent_id] = info
        self.heartbeat.register(agent_id, agent_type)

        logger.info(f"Registered agent: {agent_id} ({agent_type})")

        return info

    async def start_agent(self, agent_id: str) -> bool:
        """
        Starts a registered agent.

        Args:
            agent_id: The agent to start

        Returns:
            True if started successfully
        """
        if agent_id not in self.agents:
            logger.error(f"Cannot start unregistered agent: {agent_id}")
            return False

        info = self.agents[agent_id]

        if info.status == AgentStatus.RUNNING:
            logger.warning(f"Agent {agent_id} is already running")
            return True

        info.status = AgentStatus.STARTING

        try:
            # If we have a factory, create the agent
            if info.agent_type in self._factories:
                factory = self._factories[info.agent_type]
                await factory(agent_id)

            info.status = AgentStatus.RUNNING
            info.started_at = datetime.now(timezone.utc)
            info.last_activity = info.started_at

            # Subscribe to the message bus
            await self.bus.subscribe(
                agent_id,
                self._create_message_handler(agent_id)
            )

            logger.info(f"Started agent: {agent_id}")
            return True

        except Exception as e:
            info.status = AgentStatus.ERROR
            info.error_count += 1
            logger.error(f"Failed to start agent {agent_id}: {e}")
            return False

    async def stop_agent(self, agent_id: str) -> bool:
        """
        Stops a running agent.

        Args:
            agent_id: The agent to stop

        Returns:
            True if stopped successfully
        """
        if agent_id not in self.agents:
            return False

        info = self.agents[agent_id]
        info.status = AgentStatus.STOPPING

        try:
            # Unsubscribe from the message bus
            await self.bus.unsubscribe(agent_id)

            # Unregister from the heartbeat monitor
            self.heartbeat.unregister(agent_id)

            info.status = AgentStatus.STOPPED
            logger.info(f"Stopped agent: {agent_id}")
            return True

        except Exception as e:
            info.status = AgentStatus.ERROR
            logger.error(f"Error stopping agent {agent_id}: {e}")
            return False

    async def restart_agent(self, agent_id: str) -> bool:
        """
        Restarts an agent.

        Args:
            agent_id: The agent to restart

        Returns:
            True if restarted successfully
        """
        if agent_id not in self.agents:
            return False

        info = self.agents[agent_id]

        # Check the restart count
        if info.restart_count >= info.max_restarts:
            logger.error(
                f"Agent {agent_id} exceeded max restarts "
                f"({info.restart_count}/{info.max_restarts})"
            )
            await self._escalate_agent_failure(agent_id)
            return False

        info.status = AgentStatus.RESTARTING
        info.restart_count += 1

        logger.info(
            f"Restarting agent {agent_id} "
            f"(attempt {info.restart_count}/{info.max_restarts})"
        )

        # Stop and start
        await self.stop_agent(agent_id)
        await asyncio.sleep(1)  # Brief pause
        return await self.start_agent(agent_id)

    async def _handle_agent_timeout(
        self,
        session_id: str,
        agent_type: str
    ) -> None:
        """Handles an agent heartbeat timeout"""
        # Find the agent by session/ID
        agent_id = session_id  # In our case, session_id == agent_id

        if agent_id not in self.agents:
            return

        info = self.agents[agent_id]
        info.status = AgentStatus.ERROR
        info.error_count += 1

        logger.warning(f"Agent {agent_id} timed out (heartbeat)")

        # Apply the restart policy
        if info.restart_policy == RestartPolicy.NEVER:
            await self._escalate_agent_failure(agent_id)
        elif info.restart_policy == RestartPolicy.ON_FAILURE:
            if info.restart_count < info.max_restarts:
                await self.restart_agent(agent_id)
            else:
                await self._escalate_agent_failure(agent_id)
        elif info.restart_policy == RestartPolicy.ALWAYS:
            await self.restart_agent(agent_id)
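The timeout handler reduces to a small decision table: NEVER escalates immediately, ON_FAILURE restarts until the restart budget is spent, ALWAYS restarts unconditionally. That logic can be sketched (and checked) in isolation:

```python
from enum import Enum

class RestartPolicy(Enum):
    NEVER = "never"
    ON_FAILURE = "on_failure"
    ALWAYS = "always"

def on_timeout(policy: RestartPolicy, restart_count: int, max_restarts: int = 3) -> str:
    # NEVER escalates immediately; ON_FAILURE restarts until the budget is
    # spent; ALWAYS restarts unconditionally.
    if policy is RestartPolicy.NEVER:
        return "escalate"
    if policy is RestartPolicy.ON_FAILURE:
        return "restart" if restart_count < max_restarts else "escalate"
    return "restart"

decisions = [
    on_timeout(RestartPolicy.NEVER, 0),
    on_timeout(RestartPolicy.ON_FAILURE, 2),
    on_timeout(RestartPolicy.ON_FAILURE, 3),
    on_timeout(RestartPolicy.ALWAYS, 99),
]
print(decisions)  # ['escalate', 'restart', 'escalate', 'restart']
```

Note that `restart_agent` itself re-checks the budget and escalates, so an ON_FAILURE agent at the limit escalates through either path.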

    async def _escalate_agent_failure(self, agent_id: str) -> None:
        """Escalates an agent failure to the alert system"""
        info = self.agents.get(agent_id)
        if not info:
            return

        await self.bus.publish(AgentMessage(
            sender="supervisor",
            receiver="alert-agent",
            message_type="agent_failure",
            payload={
                "agent_id": agent_id,
                "agent_type": info.agent_type,
                "error_count": info.error_count,
                "restart_count": info.restart_count,
                "last_activity": info.last_activity.isoformat() if info.last_activity else None
            },
            priority=MessagePriority.CRITICAL
        ))

        logger.error(f"Escalated agent failure: {agent_id}")

    def _create_message_handler(self, agent_id: str):
        """Creates a message handler that updates agent activity"""
        async def handler(message: AgentMessage):
            if agent_id in self.agents:
                self.agents[agent_id].last_activity = datetime.now(timezone.utc)
                # Heartbeat on activity
                self.heartbeat.beat(agent_id)
            return None
        return handler

    async def _health_check_loop(self) -> None:
        """Periodic health check loop"""
        while self._running:
            try:
                await asyncio.sleep(self.check_interval)
                await self._run_health_checks()
            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.error(f"Error in health check loop: {e}")

    async def _run_health_checks(self) -> None:
        """Runs health checks on all agents"""
        now = datetime.now(timezone.utc)

        for agent_id, info in list(self.agents.items()):
            if info.status != AgentStatus.RUNNING:
                continue

            # Check for stale agents (no activity for 5 minutes)
            if info.last_activity:
                idle_time = now - info.last_activity
                if idle_time > timedelta(minutes=5):
                    logger.warning(
                        f"Agent {agent_id} has been idle for "
                        f"{idle_time.total_seconds():.0f}s"
                    )

    # Load balancing

    def get_available_agent(
        self,
        agent_type: str,
        strategy: str = "least_loaded"
    ) -> Optional[str]:
        """
        Gets an available agent of the specified type.

        Args:
            agent_type: Type of agent needed
            strategy: Load balancing strategy

        Returns:
            Agent ID, or None if none is available
        """
        candidates = [
            info for info in self.agents.values()
            if info.agent_type == agent_type and info.is_available()
        ]

        if not candidates:
            return None

        if strategy == "least_loaded":
            # Pick the agent with the lowest load
            best = min(candidates, key=lambda a: a.utilization())
        elif strategy == "round_robin":
            # Simple round-robin (just pick the first available)
            best = candidates[0]
        else:
            best = candidates[0]

        return best.agent_id
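Least-loaded selection is just a filter (drop full agents) followed by `min` over the utilization ratio. A standalone sketch with plain dicts in place of `AgentInfo` (agent IDs are illustrative):

```python
agents = [
    {"agent_id": "grader-1", "capacity": 10, "current_load": 7},
    {"agent_id": "grader-2", "capacity": 10, "current_load": 2},
    {"agent_id": "grader-3", "capacity": 5,  "current_load": 5},  # at capacity
]

def utilization(agent: dict) -> float:
    # Mirrors AgentInfo.utilization: guard against zero capacity.
    return agent["current_load"] / max(agent["capacity"], 1)

# Mirror is_available(): an agent at capacity is never a candidate.
candidates = [a for a in agents if a["current_load"] < a["capacity"]]
best = min(candidates, key=utilization)
print(best["agent_id"])  # grader-2
```

Comparing utilization ratios rather than raw load means a half-full small agent loses to a nearly idle large one, which keeps heterogeneous capacities balanced.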

    def acquire_capacity(self, agent_id: str) -> bool:
        """
        Acquires one unit of capacity from an agent.

        Args:
            agent_id: The agent to acquire from

        Returns:
            True if capacity was acquired
        """
        if agent_id not in self.agents:
            return False

        info = self.agents[agent_id]
        if not info.is_available():
            return False

        info.current_load += 1
        return True

    def release_capacity(self, agent_id: str) -> None:
        """
        Releases one unit of capacity back to an agent.

        Args:
            agent_id: The agent to release to
        """
        if agent_id in self.agents:
            info = self.agents[agent_id]
            info.current_load = max(0, info.current_load - 1)

    # Status and metrics

    def get_agent_status(self, agent_id: str) -> Optional[Dict[str, Any]]:
        """Gets status information for an agent"""
        if agent_id not in self.agents:
            return None

        info = self.agents[agent_id]
        return {
            "agent_id": info.agent_id,
            "agent_type": info.agent_type,
            "status": info.status.value,
            "current_task": info.current_task,
            "started_at": info.started_at.isoformat() if info.started_at else None,
            "last_activity": info.last_activity.isoformat() if info.last_activity else None,
            "error_count": info.error_count,
            "restart_count": info.restart_count,
            "utilization": info.utilization(),
            "is_healthy": info.is_healthy(),
            "is_available": info.is_available()
        }

    def get_all_status(self) -> Dict[str, Dict[str, Any]]:
        """Gets status for all agents"""
        return {
            agent_id: self.get_agent_status(agent_id)
            for agent_id in self.agents
        }

    def get_metrics(self) -> Dict[str, Any]:
        """Gets aggregate metrics"""
        total = len(self.agents)
        running = sum(
            1 for a in self.agents.values()
            if a.status == AgentStatus.RUNNING
        )
        healthy = sum(1 for a in self.agents.values() if a.is_healthy())
        available = sum(1 for a in self.agents.values() if a.is_available())

        total_capacity = sum(a.capacity for a in self.agents.values())
        total_load = sum(a.current_load for a in self.agents.values())

        return {
            "total_agents": total,
            "running_agents": running,
            "healthy_agents": healthy,
            "available_agents": available,
            "total_capacity": total_capacity,
            "total_load": total_load,
            "overall_utilization": total_load / max(total_capacity, 1),
            "by_type": self._metrics_by_type()
        }

    def _metrics_by_type(self) -> Dict[str, Dict[str, int]]:
        """Gets metrics grouped by agent type"""
        by_type: Dict[str, Dict[str, int]] = {}

        for info in self.agents.values():
            if info.agent_type not in by_type:
                by_type[info.agent_type] = {
                    "total": 0,
                    "running": 0,
                    "healthy": 0
                }

            by_type[info.agent_type]["total"] += 1
            if info.status == AgentStatus.RUNNING:
                by_type[info.agent_type]["running"] += 1
            if info.is_healthy():
                by_type[info.agent_type]["healthy"] += 1

        return by_type
|
||||
436
agent-core/orchestrator/task_router.py
Normal file
@@ -0,0 +1,436 @@
"""
Task Router for Intent-Based Routing

Provides:
- Intent classification
- Agent selection
- Fallback handling
- Routing metrics
"""

from typing import Dict, Optional, List, Any, Callable, Awaitable
from dataclasses import dataclass, field
from datetime import datetime, timezone
from enum import Enum
import logging
import re

logger = logging.getLogger(__name__)


class RoutingStrategy(Enum):
    """Routing strategies"""
    DIRECT = "direct"              # Route to a specific agent
    ROUND_ROBIN = "round_robin"    # Distribute evenly
    LEAST_LOADED = "least_loaded"  # Route to the least loaded agent
    PRIORITY = "priority"          # Route based on priority


@dataclass
class RoutingRule:
    """A rule for routing tasks to agents"""
    intent_pattern: str
    agent_type: str
    priority: int = 0
    conditions: Dict[str, Any] = field(default_factory=dict)
    fallback_agent: Optional[str] = None

    def matches(self, intent: str, context: Dict[str, Any]) -> bool:
        """Check if this rule matches the intent and context"""
        # Check intent pattern (supports wildcards)
        pattern = self.intent_pattern.replace("*", ".*")
        if not re.match(f"^{pattern}$", intent, re.IGNORECASE):
            return False

        # Check conditions
        for key, value in self.conditions.items():
            if context.get(key) != value:
                return False

        return True


@dataclass
class RoutingResult:
    """Result of a routing decision"""
    success: bool
    agent_id: Optional[str] = None
    agent_type: Optional[str] = None
    is_fallback: bool = False
    reason: str = ""
    routing_time_ms: float = 0


class TaskRouter:
    """
    Routes tasks to appropriate agents based on intent.

    Features:
    - Pattern-based routing rules
    - Priority ordering
    - Fallback chains
    - Routing metrics
    """

    def __init__(self, supervisor=None):
        """
        Initialize the task router.

        Args:
            supervisor: AgentSupervisor for agent availability
        """
        self.supervisor = supervisor
        self.rules: List[RoutingRule] = []
        self._default_routes: Dict[str, str] = {}
        self._routing_history: List[Dict[str, Any]] = []
        self._max_history = 1000

        # Initialize default rules
        self._setup_default_rules()

    def _setup_default_rules(self) -> None:
        """Sets up default routing rules"""
        default_rules = [
            # Learning support
            RoutingRule(
                intent_pattern="learning_*",
                agent_type="tutor-agent",
                priority=10,
                fallback_agent="orchestrator"
            ),
            RoutingRule(
                intent_pattern="help_*",
                agent_type="tutor-agent",
                priority=5
            ),
            RoutingRule(
                intent_pattern="explain_*",
                agent_type="tutor-agent",
                priority=5
            ),

            # Grading
            RoutingRule(
                intent_pattern="grade_*",
                agent_type="grader-agent",
                priority=10,
                fallback_agent="quality-judge"
            ),
            RoutingRule(
                intent_pattern="evaluate_*",
                agent_type="grader-agent",
                priority=5
            ),
            RoutingRule(
                intent_pattern="correct_*",
                agent_type="grader-agent",
                priority=5
            ),

            # Quality checks
            RoutingRule(
                intent_pattern="quality_*",
                agent_type="quality-judge",
                priority=10
            ),
            RoutingRule(
                intent_pattern="review_*",
                agent_type="quality-judge",
                priority=5
            ),

            # Alerts
            RoutingRule(
                intent_pattern="alert_*",
                agent_type="alert-agent",
                priority=10
            ),
            RoutingRule(
                intent_pattern="notify_*",
                agent_type="alert-agent",
                priority=5
            ),

            # System/Admin
            RoutingRule(
                intent_pattern="system_*",
                agent_type="orchestrator",
                priority=10
            ),
            RoutingRule(
                intent_pattern="admin_*",
                agent_type="orchestrator",
                priority=10
            ),
        ]

        for rule in default_rules:
            self.add_rule(rule)

    def add_rule(self, rule: RoutingRule) -> None:
        """
        Adds a routing rule.

        Args:
            rule: The routing rule to add
        """
        self.rules.append(rule)
        # Sort by priority (higher first)
        self.rules.sort(key=lambda r: r.priority, reverse=True)

    def remove_rule(self, intent_pattern: str) -> bool:
        """
        Removes a routing rule by pattern.

        Args:
            intent_pattern: The pattern to remove

        Returns:
            True if a rule was removed
        """
        original_len = len(self.rules)
        self.rules = [r for r in self.rules if r.intent_pattern != intent_pattern]
        return len(self.rules) < original_len

    def set_default_route(self, agent_type: str, agent_id: str) -> None:
        """
        Sets a default agent for a type.

        Args:
            agent_type: The agent type
            agent_id: The default agent ID
        """
        self._default_routes[agent_type] = agent_id

    async def route(
        self,
        intent: str,
        context: Optional[Dict[str, Any]] = None,
        strategy: RoutingStrategy = RoutingStrategy.LEAST_LOADED
    ) -> RoutingResult:
        """
        Routes a task based on intent.

        Args:
            intent: The task intent
            context: Additional context for routing
            strategy: Load balancing strategy

        Returns:
            RoutingResult with agent assignment
        """
        start_time = datetime.now(timezone.utc)
        context = context or {}

        # Find matching rule
        matching_rule = None
        for rule in self.rules:
            if rule.matches(intent, context):
                matching_rule = rule
                break

        if not matching_rule:
            result = RoutingResult(
                success=False,
                reason=f"No routing rule matches intent: {intent}"
            )
            self._record_routing(intent, result)
            return result

        # Get available agent
        agent_id = await self._get_agent(
            matching_rule.agent_type,
            strategy
        )

        if agent_id:
            result = RoutingResult(
                success=True,
                agent_id=agent_id,
                agent_type=matching_rule.agent_type,
                is_fallback=False,
                reason="Primary agent selected"
            )
        elif matching_rule.fallback_agent:
            # Try fallback
            agent_id = await self._get_agent(
                matching_rule.fallback_agent,
                strategy
            )
            if agent_id:
                result = RoutingResult(
                    success=True,
                    agent_id=agent_id,
                    agent_type=matching_rule.fallback_agent,
                    is_fallback=True,
                    reason="Fallback agent selected"
                )
            else:
                result = RoutingResult(
                    success=False,
                    reason="No agents available (primary or fallback)"
                )
        else:
            result = RoutingResult(
                success=False,
                reason=f"No {matching_rule.agent_type} agents available"
            )

        # Calculate routing time
        end_time = datetime.now(timezone.utc)
        result.routing_time_ms = (end_time - start_time).total_seconds() * 1000

        self._record_routing(intent, result)

        return result

    async def _get_agent(
        self,
        agent_type: str,
        strategy: RoutingStrategy
    ) -> Optional[str]:
        """Gets an available agent of the specified type"""
        # Check default route first
        if agent_type in self._default_routes:
            agent_id = self._default_routes[agent_type]
            if self.supervisor and self.supervisor.agents.get(agent_id):
                info = self.supervisor.agents[agent_id]
                if info.is_available():
                    return agent_id

        # Use supervisor for load balancing
        if self.supervisor:
            strategy_str = "least_loaded"
            if strategy == RoutingStrategy.ROUND_ROBIN:
                strategy_str = "round_robin"

            return self.supervisor.get_available_agent(
                agent_type,
                strategy_str
            )

        return None

    def _record_routing(
        self,
        intent: str,
        result: RoutingResult
    ) -> None:
        """Records routing decision for analytics"""
        record = {
            "timestamp": datetime.now(timezone.utc).isoformat(),
            "intent": intent,
            "success": result.success,
            "agent_id": result.agent_id,
            "agent_type": result.agent_type,
            "is_fallback": result.is_fallback,
            "routing_time_ms": result.routing_time_ms,
            "reason": result.reason
        }

        self._routing_history.append(record)

        # Trim history
        if len(self._routing_history) > self._max_history:
            self._routing_history = self._routing_history[-self._max_history:]

        # Log
        if result.success:
            logger.debug(
                f"Routed '{intent}' to {result.agent_id} "
                f"({'fallback' if result.is_fallback else 'primary'})"
            )
        else:
            logger.warning(f"Failed to route '{intent}': {result.reason}")

    # Analytics

    def get_routing_stats(self) -> Dict[str, Any]:
        """Gets routing statistics"""
        if not self._routing_history:
            return {
                "total_routes": 0,
                "success_rate": 0,
                "fallback_rate": 0,
                "avg_routing_time_ms": 0
            }

        total = len(self._routing_history)
        successful = sum(1 for r in self._routing_history if r["success"])
        fallbacks = sum(1 for r in self._routing_history if r["is_fallback"])
        avg_time = sum(
            r["routing_time_ms"] for r in self._routing_history
        ) / total

        return {
            "total_routes": total,
            "successful_routes": successful,
            "success_rate": successful / total,
            "fallback_routes": fallbacks,
            "fallback_rate": fallbacks / max(successful, 1),
            "avg_routing_time_ms": avg_time
        }

    def get_intent_distribution(self) -> Dict[str, int]:
        """Gets distribution of routed intents"""
        distribution: Dict[str, int] = {}
        for record in self._routing_history:
            intent = record["intent"]
            # Extract base intent (before _)
            base = intent.split("_")[0] if "_" in intent else intent
            distribution[base] = distribution.get(base, 0) + 1
        return distribution

    def get_agent_distribution(self) -> Dict[str, int]:
        """Gets distribution of routes by agent type"""
        distribution: Dict[str, int] = {}
        for record in self._routing_history:
            agent_type = record.get("agent_type", "unknown")
            if agent_type:
                distribution[agent_type] = distribution.get(agent_type, 0) + 1
        return distribution

    def get_failure_reasons(self) -> Dict[str, int]:
        """Gets distribution of routing failure reasons"""
        reasons: Dict[str, int] = {}
        for record in self._routing_history:
            if not record["success"]:
                reason = record["reason"]
                reasons[reason] = reasons.get(reason, 0) + 1
        return reasons

    def clear_history(self) -> None:
        """Clears routing history"""
        self._routing_history.clear()

    # Rule inspection

    def get_rules(self) -> List[Dict[str, Any]]:
        """Gets all routing rules as dicts"""
        return [
            {
                "intent_pattern": r.intent_pattern,
                "agent_type": r.agent_type,
                "priority": r.priority,
                "conditions": r.conditions,
                "fallback_agent": r.fallback_agent
            }
            for r in self.rules
        ]

    def find_matching_rules(
        self,
        intent: str,
        context: Optional[Dict[str, Any]] = None
    ) -> List[Dict[str, Any]]:
        """Finds all rules that match an intent"""
        context = context or {}
        return [
            {
                "intent_pattern": r.intent_pattern,
                "agent_type": r.agent_type,
                "priority": r.priority
            }
            for r in self.rules
            if r.matches(intent, context)
        ]
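For reference, the wildcard handling in `RoutingRule.matches` reduces to a small anchored, case-insensitive regex check. A standalone sketch (illustrative only, not part of the restored file):

```python
import re

def intent_matches(pattern: str, intent: str) -> bool:
    # "*" in a rule pattern becomes ".*"; the match is anchored at both
    # ends and case-insensitive, as in RoutingRule.matches above
    regex = pattern.replace("*", ".*")
    return re.match(f"^{regex}$", intent, re.IGNORECASE) is not None

print(intent_matches("grade_*", "grade_essay"))   # True
print(intent_matches("grade_*", "upgrade_plan"))  # False
```

`re.match` already anchors at the start; the explicit `$` also anchors the end, which matters for patterns without a trailing `*`.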
10
agent-core/pytest.ini
Normal file
@@ -0,0 +1,10 @@
[pytest]
testpaths = tests
python_files = test_*.py
python_classes = Test*
python_functions = test_*
asyncio_mode = auto
addopts = -v --tb=short
markers =
    asyncio: mark test as async
    slow: mark test as slow
19
agent-core/requirements.txt
Normal file
@@ -0,0 +1,19 @@
# Agent Core Dependencies
# Note: Most dependencies are already available via voice-service/backend

# Async support (should be in voice-service already)
# asyncio is part of the Python standard library

# Redis/Valkey client (already in project via voice-service)
# redis>=4.5.0

# PostgreSQL async client (already in project)
# asyncpg>=0.28.0

# Testing dependencies
pytest>=7.4.0
pytest-asyncio>=0.21.0
pytest-cov>=4.1.0

# Type checking (optional)
# mypy>=1.5.0
25
agent-core/sessions/__init__.py
Normal file
@@ -0,0 +1,25 @@
"""
Session Management for Breakpilot Agents

Provides:
- AgentSession: Individual agent session with context and checkpoints
- SessionManager: Create, retrieve, and manage agent sessions
- HeartbeatMonitor: Monitor agent liveness
- SessionState: Session state enumeration
"""

from agent_core.sessions.session_manager import (
    AgentSession,
    SessionManager,
    SessionState,
)
from agent_core.sessions.heartbeat import HeartbeatMonitor
from agent_core.sessions.checkpoint import CheckpointManager

__all__ = [
    "AgentSession",
    "SessionManager",
    "SessionState",
    "HeartbeatMonitor",
    "CheckpointManager",
]
362
agent-core/sessions/checkpoint.py
Normal file
@@ -0,0 +1,362 @@
"""
Checkpoint Management for Breakpilot Agents

Provides checkpoint-based recovery with:
- Named checkpoints for semantic recovery points
- Automatic checkpoint compression
- Recovery from specific checkpoints
- Checkpoint analytics
"""

from typing import Dict, Any, Optional, List, Callable
from datetime import datetime, timezone
from dataclasses import dataclass, field
import json
import logging

logger = logging.getLogger(__name__)


@dataclass
class Checkpoint:
    """Represents a recovery checkpoint"""
    id: str
    name: str
    timestamp: datetime
    data: Dict[str, Any]
    metadata: Dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> Dict[str, Any]:
        return {
            "id": self.id,
            "name": self.name,
            "timestamp": self.timestamp.isoformat(),
            "data": self.data,
            "metadata": self.metadata
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "Checkpoint":
        return cls(
            id=data["id"],
            name=data["name"],
            timestamp=datetime.fromisoformat(data["timestamp"]),
            data=data["data"],
            metadata=data.get("metadata", {})
        )


class CheckpointManager:
    """
    Manages checkpoints for agent sessions.

    Provides:
    - Named checkpoints for semantic recovery
    - Automatic compression of old checkpoints
    - Recovery to specific checkpoint states
    - Analytics on checkpoint patterns
    """

    def __init__(
        self,
        session_id: str,
        max_checkpoints: int = 100,
        compress_after: int = 50
    ):
        """
        Initialize the checkpoint manager.

        Args:
            session_id: The session ID this manager belongs to
            max_checkpoints: Maximum number of checkpoints to retain
            compress_after: Compress checkpoints after this count
        """
        self.session_id = session_id
        self.max_checkpoints = max_checkpoints
        self.compress_after = compress_after
        self._checkpoints: List[Checkpoint] = []
        self._checkpoint_count = 0
        self._on_checkpoint: Optional[Callable[[Checkpoint], None]] = None

    def create(
        self,
        name: str,
        data: Dict[str, Any],
        metadata: Optional[Dict[str, Any]] = None
    ) -> Checkpoint:
        """
        Creates a new checkpoint.

        Args:
            name: Semantic name for the checkpoint (e.g., "task_started")
            data: Checkpoint data to store
            metadata: Optional additional metadata

        Returns:
            The created checkpoint
        """
        self._checkpoint_count += 1

        checkpoint = Checkpoint(
            id=f"{self.session_id}:{self._checkpoint_count}",
            name=name,
            timestamp=datetime.now(timezone.utc),
            data=data,
            metadata=metadata or {}
        )

        self._checkpoints.append(checkpoint)

        # Compress if needed
        if len(self._checkpoints) > self.compress_after:
            self._compress_checkpoints()

        # Trigger callback
        if self._on_checkpoint:
            self._on_checkpoint(checkpoint)

        logger.debug(
            f"Session {self.session_id}: Created checkpoint '{name}' "
            f"(#{self._checkpoint_count})"
        )

        return checkpoint

    def get(self, checkpoint_id: str) -> Optional[Checkpoint]:
        """
        Gets a checkpoint by ID.

        Args:
            checkpoint_id: The checkpoint ID

        Returns:
            The checkpoint or None if not found
        """
        for cp in self._checkpoints:
            if cp.id == checkpoint_id:
                return cp
        return None

    def get_by_name(self, name: str) -> List[Checkpoint]:
        """
        Gets all checkpoints with a given name.

        Args:
            name: The checkpoint name

        Returns:
            List of matching checkpoints (newest first)
        """
        return [
            cp for cp in reversed(self._checkpoints)
            if cp.name == name
        ]

    def get_latest(self, name: Optional[str] = None) -> Optional[Checkpoint]:
        """
        Gets the latest checkpoint, optionally filtered by name.

        Args:
            name: Optional name filter

        Returns:
            The latest matching checkpoint or None
        """
        if not self._checkpoints:
            return None

        if name:
            matching = self.get_by_name(name)
            return matching[0] if matching else None

        return self._checkpoints[-1]

    def get_all(self) -> List[Checkpoint]:
        """Returns all checkpoints"""
        return list(self._checkpoints)

    def get_since(self, timestamp: datetime) -> List[Checkpoint]:
        """
        Gets all checkpoints since a given timestamp.

        Args:
            timestamp: The starting timestamp

        Returns:
            List of checkpoints after the timestamp
        """
        return [
            cp for cp in self._checkpoints
            if cp.timestamp > timestamp
        ]

    def get_between(
        self,
        start: datetime,
        end: datetime
    ) -> List[Checkpoint]:
        """
        Gets checkpoints between two timestamps.

        Args:
            start: Start timestamp
            end: End timestamp

        Returns:
            List of checkpoints in the range
        """
        return [
            cp for cp in self._checkpoints
            if start <= cp.timestamp <= end
        ]

    def rollback_to(self, checkpoint_id: str) -> Optional[Dict[str, Any]]:
        """
        Gets data needed to rollback to a checkpoint.

        Note: This doesn't actually rollback - it returns the checkpoint
        data for the caller to use for recovery.

        Args:
            checkpoint_id: The checkpoint to rollback to

        Returns:
            The checkpoint data or None if not found
        """
        checkpoint = self.get(checkpoint_id)
        if checkpoint:
            logger.info(
                f"Session {self.session_id}: Rollback to checkpoint "
                f"'{checkpoint.name}' ({checkpoint_id})"
            )
            return checkpoint.data
        return None

    def clear(self) -> int:
        """
        Clears all checkpoints.

        Returns:
            Number of checkpoints cleared
        """
        count = len(self._checkpoints)
        self._checkpoints.clear()
        logger.info(f"Session {self.session_id}: Cleared {count} checkpoints")
        return count

    def _compress_checkpoints(self) -> None:
        """
        Compresses old checkpoints to save memory.

        Keeps:
        - First checkpoint (session start)
        - Last N checkpoints (recent history)
        - One checkpoint per unique name (latest)
        """
        if len(self._checkpoints) <= self.compress_after:
            return

        # Keep first checkpoint
        first = self._checkpoints[0]

        # Keep last 20 checkpoints
        recent = self._checkpoints[-20:]

        # Keep one of each unique name from the middle
        middle = self._checkpoints[1:-20]
        by_name: Dict[str, Checkpoint] = {}
        for cp in middle:
            # Keep the latest of each name
            if cp.name not in by_name or cp.timestamp > by_name[cp.name].timestamp:
                by_name[cp.name] = cp

        # Combine and sort
        compressed = [first] + list(by_name.values()) + recent
        compressed.sort(key=lambda cp: cp.timestamp)

        old_count = len(self._checkpoints)
        self._checkpoints = compressed

        logger.debug(
            f"Session {self.session_id}: Compressed checkpoints "
            f"from {old_count} to {len(self._checkpoints)}"
        )

    def get_summary(self) -> Dict[str, Any]:
        """
        Gets a summary of checkpoint activity.

        Returns:
            Summary dict with counts and timing info
        """
        if not self._checkpoints:
            return {
                "total_count": 0,
                "unique_names": 0,
                "names": {},
                "first_checkpoint": None,
                "last_checkpoint": None,
                "duration_seconds": 0
            }

        name_counts: Dict[str, int] = {}
        for cp in self._checkpoints:
            name_counts[cp.name] = name_counts.get(cp.name, 0) + 1

        first = self._checkpoints[0]
        last = self._checkpoints[-1]

        return {
            "total_count": len(self._checkpoints),
            "unique_names": len(name_counts),
            "names": name_counts,
            "first_checkpoint": first.to_dict(),
            "last_checkpoint": last.to_dict(),
            "duration_seconds": (last.timestamp - first.timestamp).total_seconds()
        }

    def on_checkpoint(self, callback: Callable[[Checkpoint], None]) -> None:
        """
        Sets a callback to be called on each checkpoint.

        Args:
            callback: Function to call with each checkpoint
        """
        self._on_checkpoint = callback

    def export(self) -> str:
        """
        Exports all checkpoints to JSON.

        Returns:
            JSON string of all checkpoints
        """
        return json.dumps(
            [cp.to_dict() for cp in self._checkpoints],
            indent=2
        )

    def import_checkpoints(self, json_data: str) -> int:
        """
        Imports checkpoints from JSON.

        Args:
            json_data: JSON string of checkpoints

        Returns:
            Number of checkpoints imported
        """
        data = json.loads(json_data)
        imported = [Checkpoint.from_dict(cp) for cp in data]
        self._checkpoints.extend(imported)
        self._checkpoint_count = max(
            self._checkpoint_count,
            len(self._checkpoints)
        )
        return len(imported)

    def __len__(self) -> int:
        return len(self._checkpoints)

    def __iter__(self):
        return iter(self._checkpoints)
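The compression policy in `_compress_checkpoints` (keep the first checkpoint, the last N, and the latest entry per unique name) can be sketched standalone, with checkpoints simplified to `(timestamp, name)` tuples. Illustrative only, not part of the restored file:

```python
def compress(checkpoints, keep_recent=20):
    """Keep the first entry, the last `keep_recent`, and the latest per name."""
    if len(checkpoints) <= keep_recent + 1:
        return list(checkpoints)
    first, recent = checkpoints[0], checkpoints[-keep_recent:]
    by_name = {}
    for cp in checkpoints[1:-keep_recent]:
        # the list is chronological, so later entries overwrite earlier
        # ones and the latest checkpoint per name survives
        by_name[cp[1]] = cp
    kept = [first] + list(by_name.values()) + recent
    kept.sort()  # (timestamp, name) tuples sort chronologically
    return kept

cps = [(i, "a" if i % 2 else "b") for i in range(10)]
print(compress(cps, keep_recent=3))
```

With ten alternating "a"/"b" checkpoints and `keep_recent=3`, this keeps the first entry, the latest "a" and "b" from the middle, and the last three — six entries in all.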
361
agent-core/sessions/heartbeat.py
Normal file
@@ -0,0 +1,361 @@
"""
Heartbeat Monitoring for Breakpilot Agents

Provides liveness monitoring for agents with:
- Configurable timeout thresholds
- Async background monitoring
- Callback-based timeout handling
- Integration with SessionManager
"""

import asyncio
from typing import Dict, Callable, Optional, Awaitable, Set
from datetime import datetime, timezone, timedelta
from dataclasses import dataclass, field
import logging

logger = logging.getLogger(__name__)


@dataclass
class HeartbeatEntry:
    """Represents a heartbeat entry for an agent"""
    session_id: str
    agent_type: str
    last_beat: datetime
    missed_beats: int = 0


class HeartbeatMonitor:
    """
    Monitors agent heartbeats and triggers callbacks on timeout.

    Usage:
        monitor = HeartbeatMonitor(timeout_seconds=30)
        monitor.on_timeout = handle_timeout
        await monitor.start_monitoring()
    """

    def __init__(
        self,
        timeout_seconds: int = 30,
        check_interval_seconds: int = 5,
        max_missed_beats: int = 3
    ):
        """
        Initialize the heartbeat monitor.

        Args:
            timeout_seconds: Time without heartbeat before considered stale
            check_interval_seconds: How often to check for stale sessions
            max_missed_beats: Number of missed beats before triggering timeout
        """
        self.sessions: Dict[str, HeartbeatEntry] = {}
        self.timeout = timedelta(seconds=timeout_seconds)
        self.check_interval = check_interval_seconds
        self.max_missed_beats = max_missed_beats
        self.on_timeout: Optional[Callable[[str, str], Awaitable[None]]] = None
        self.on_warning: Optional[Callable[[str, int], Awaitable[None]]] = None
        self._running = False
        self._task: Optional[asyncio.Task] = None
        self._paused_sessions: Set[str] = set()

    async def start_monitoring(self) -> None:
        """
        Starts the background heartbeat monitoring task.

        This runs indefinitely until stop_monitoring() is called.
        """
        if self._running:
            logger.warning("Heartbeat monitor already running")
            return

        self._running = True
        self._task = asyncio.create_task(self._monitoring_loop())
        logger.info(
            f"Heartbeat monitor started (timeout={self.timeout.seconds}s, "
            f"interval={self.check_interval}s)"
        )

    async def stop_monitoring(self) -> None:
        """Stops the heartbeat monitoring task"""
        self._running = False
        if self._task:
            self._task.cancel()
            try:
                await self._task
            except asyncio.CancelledError:
                pass
            self._task = None
        logger.info("Heartbeat monitor stopped")

    async def _monitoring_loop(self) -> None:
        """Main monitoring loop"""
        while self._running:
            try:
                await asyncio.sleep(self.check_interval)
                await self._check_heartbeats()
            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.error(f"Error in heartbeat monitoring: {e}")

    async def _check_heartbeats(self) -> None:
        """Checks all registered sessions for stale heartbeats"""
        now = datetime.now(timezone.utc)
        timed_out = []

        for session_id, entry in list(self.sessions.items()):
            # Skip paused sessions
            if session_id in self._paused_sessions:
                continue

            time_since_beat = now - entry.last_beat

            if time_since_beat > self.timeout:
                entry.missed_beats += 1

                # Warn on first missed beat
                if entry.missed_beats == 1 and self.on_warning:
                    await self.on_warning(session_id, entry.missed_beats)
                logger.warning(
                    f"Session {session_id} missed heartbeat "
                    f"({entry.missed_beats}/{self.max_missed_beats})"
                )

                # Timeout after max missed beats
                if entry.missed_beats >= self.max_missed_beats:
                    timed_out.append((session_id, entry.agent_type))

        # Handle timeouts
        for session_id, agent_type in timed_out:
            logger.error(
                f"Session {session_id} ({agent_type}) timed out after "
                f"{self.max_missed_beats} missed heartbeats"
            )

            if self.on_timeout:
                try:
                    await self.on_timeout(session_id, agent_type)
                except Exception as e:
                    logger.error(f"Error in timeout handler: {e}")

            # Remove from tracking
            del self.sessions[session_id]
            self._paused_sessions.discard(session_id)

    def register(self, session_id: str, agent_type: str) -> None:
        """
        Registers a session for heartbeat monitoring.

        Args:
            session_id: The session ID to monitor
            agent_type: The type of agent
        """
        self.sessions[session_id] = HeartbeatEntry(
            session_id=session_id,
            agent_type=agent_type,
            last_beat=datetime.now(timezone.utc)
        )
        logger.debug(f"Registered session {session_id} for heartbeat monitoring")

    def beat(self, session_id: str) -> bool:
        """
        Records a heartbeat for a session.

        Args:
            session_id: The session ID

        Returns:
            True if the session is registered, False otherwise
        """
        if session_id in self.sessions:
            self.sessions[session_id].last_beat = datetime.now(timezone.utc)
            self.sessions[session_id].missed_beats = 0
            return True
        return False

    def unregister(self, session_id: str) -> bool:
        """
        Unregisters a session from heartbeat monitoring.

        Args:
            session_id: The session ID to unregister

        Returns:
            True if the session was registered, False otherwise
        """
        self._paused_sessions.discard(session_id)
        if session_id in self.sessions:
            del self.sessions[session_id]
            logger.debug(f"Unregistered session {session_id} from heartbeat monitoring")
            return True
        return False

    def pause(self, session_id: str) -> bool:
        """
        Pauses heartbeat monitoring for a session.

        Useful when a session is intentionally idle (e.g., waiting for user input).

        Args:
            session_id: The session ID to pause

        Returns:
            True if the session was registered, False otherwise
        """
        if session_id in self.sessions:
            self._paused_sessions.add(session_id)
            logger.debug(f"Paused heartbeat monitoring for session {session_id}")
            return True
        return False

    def resume(self, session_id: str) -> bool:
        """
        Resumes heartbeat monitoring for a paused session.

        Args:
            session_id: The session ID to resume

        Returns:
            True if the session was paused, False otherwise
        """
        if session_id in self._paused_sessions:
            self._paused_sessions.discard(session_id)
            # Reset the heartbeat timer
            self.beat(session_id)
            logger.debug(f"Resumed heartbeat monitoring for session {session_id}")
            return True
        return False

    def get_status(self, session_id: str) -> Optional[Dict]:
        """
        Gets the heartbeat status for a session.

        Args:
            session_id: The session ID

        Returns:
            Status dict or None if not registered
        """
        if session_id not in self.sessions:
            return None

        entry = self.sessions[session_id]
        now = datetime.now(timezone.utc)

        return {
            "session_id": session_id,
            "agent_type": entry.agent_type,
            "last_beat": entry.last_beat.isoformat(),
            "seconds_since_beat": (now - entry.last_beat).total_seconds(),
            "missed_beats": entry.missed_beats,
            "is_paused": session_id in self._paused_sessions,
            "is_healthy": entry.missed_beats == 0
        }

    def get_all_status(self) -> Dict[str, Dict]:
        """
        Gets heartbeat status for all registered sessions.

        Returns:
            Dict mapping session_id to status dict
        """
        return {
|
||||
session_id: self.get_status(session_id)
|
||||
for session_id in self.sessions
|
||||
}
|
||||
|
||||
@property
|
||||
def registered_count(self) -> int:
|
||||
"""Returns the number of registered sessions"""
|
||||
return len(self.sessions)
|
||||
|
||||
@property
|
||||
def healthy_count(self) -> int:
|
||||
"""Returns the number of healthy sessions (no missed beats)"""
|
||||
return sum(
|
||||
1 for entry in self.sessions.values()
|
||||
if entry.missed_beats == 0
|
||||
)
|
||||
|
||||
|
||||
class HeartbeatClient:
|
||||
"""
|
||||
Client-side heartbeat sender for agents.
|
||||
|
||||
Usage:
|
||||
client = HeartbeatClient(session_id, heartbeat_url)
|
||||
await client.start()
|
||||
# ... agent work ...
|
||||
await client.stop()
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
session_id: str,
|
||||
monitor: Optional[HeartbeatMonitor] = None,
|
||||
interval_seconds: int = 10
|
||||
):
|
||||
"""
|
||||
Initialize the heartbeat client.
|
||||
|
||||
Args:
|
||||
session_id: The session ID to send heartbeats for
|
||||
monitor: Optional local HeartbeatMonitor (for in-process agents)
|
||||
interval_seconds: How often to send heartbeats
|
||||
"""
|
||||
self.session_id = session_id
|
||||
self.monitor = monitor
|
||||
self.interval = interval_seconds
|
||||
self._running = False
|
||||
self._task: Optional[asyncio.Task] = None
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Starts sending heartbeats"""
|
||||
if self._running:
|
||||
return
|
||||
|
||||
self._running = True
|
||||
self._task = asyncio.create_task(self._heartbeat_loop())
|
||||
logger.debug(f"Heartbeat client started for session {self.session_id}")
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Stops sending heartbeats"""
|
||||
self._running = False
|
||||
if self._task:
|
||||
self._task.cancel()
|
||||
try:
|
||||
await self._task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
self._task = None
|
||||
logger.debug(f"Heartbeat client stopped for session {self.session_id}")
|
||||
|
||||
async def _heartbeat_loop(self) -> None:
|
||||
"""Main heartbeat sending loop"""
|
||||
while self._running:
|
||||
try:
|
||||
await self._send_heartbeat()
|
||||
await asyncio.sleep(self.interval)
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"Error sending heartbeat: {e}")
|
||||
await asyncio.sleep(self.interval)
|
||||
|
||||
async def _send_heartbeat(self) -> None:
|
||||
"""Sends a single heartbeat"""
|
||||
if self.monitor:
|
||||
# Local monitor
|
||||
self.monitor.beat(self.session_id)
|
||||
# Future: Add HTTP-based heartbeat for distributed agents
|
||||
|
||||
async def __aenter__(self):
|
||||
"""Context manager entry"""
|
||||
await self.start()
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
"""Context manager exit"""
|
||||
await self.stop()
|
||||
540
agent-core/sessions/session_manager.py
Normal file
@@ -0,0 +1,540 @@
"""
Session Management for Breakpilot Agents

Provides session lifecycle management with:
- State tracking (ACTIVE, PAUSED, COMPLETED, FAILED)
- Checkpoint-based recovery
- Heartbeat integration
- Hybrid Valkey + PostgreSQL persistence
"""

from dataclasses import dataclass, field
from datetime import datetime, timezone, timedelta
from typing import Dict, Any, Optional, List
from enum import Enum
import uuid
import json
import logging

logger = logging.getLogger(__name__)


class SessionState(Enum):
    """Agent session states"""
    ACTIVE = "active"
    PAUSED = "paused"
    COMPLETED = "completed"
    FAILED = "failed"


@dataclass
class SessionCheckpoint:
    """Represents a checkpoint in an agent session"""
    name: str
    timestamp: datetime
    data: Dict[str, Any]

    def to_dict(self) -> Dict[str, Any]:
        return {
            "name": self.name,
            "timestamp": self.timestamp.isoformat(),
            "data": self.data
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "SessionCheckpoint":
        return cls(
            name=data["name"],
            timestamp=datetime.fromisoformat(data["timestamp"]),
            data=data["data"]
        )


@dataclass
class AgentSession:
    """
    Represents an active agent session.

    Attributes:
        session_id: Unique session identifier
        agent_type: Type of agent (tutor, grader, quality-judge, alert, orchestrator)
        user_id: Associated user ID
        state: Current session state
        created_at: Session creation timestamp
        last_heartbeat: Last heartbeat timestamp
        context: Session context data
        checkpoints: List of session checkpoints for recovery
    """
    session_id: str = field(default_factory=lambda: str(uuid.uuid4()))
    agent_type: str = ""
    user_id: str = ""
    state: SessionState = SessionState.ACTIVE
    created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
    last_heartbeat: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
    context: Dict[str, Any] = field(default_factory=dict)
    checkpoints: List[SessionCheckpoint] = field(default_factory=list)
    metadata: Dict[str, Any] = field(default_factory=dict)

    def checkpoint(self, name: str, data: Dict[str, Any]) -> SessionCheckpoint:
        """
        Creates a checkpoint for recovery.

        Args:
            name: Checkpoint name (e.g., "task_received", "processing_complete")
            data: Checkpoint data to store

        Returns:
            The created checkpoint
        """
        checkpoint = SessionCheckpoint(
            name=name,
            timestamp=datetime.now(timezone.utc),
            data=data
        )
        self.checkpoints.append(checkpoint)
        logger.debug(f"Session {self.session_id}: Checkpoint '{name}' created")
        return checkpoint

    def heartbeat(self) -> None:
        """Updates the heartbeat timestamp"""
        self.last_heartbeat = datetime.now(timezone.utc)

    def pause(self) -> None:
        """Pauses the session"""
        self.state = SessionState.PAUSED
        self.checkpoint("session_paused", {"previous_state": "active"})

    def resume(self) -> None:
        """Resumes a paused session"""
        if self.state == SessionState.PAUSED:
            self.state = SessionState.ACTIVE
            self.heartbeat()
            self.checkpoint("session_resumed", {})

    def complete(self, result: Optional[Dict[str, Any]] = None) -> None:
        """Marks the session as completed"""
        self.state = SessionState.COMPLETED
        self.checkpoint("session_completed", {"result": result or {}})

    def fail(self, error: str, error_details: Optional[Dict[str, Any]] = None) -> None:
        """Marks the session as failed"""
        self.state = SessionState.FAILED
        self.checkpoint("session_failed", {
            "error": error,
            "details": error_details or {}
        })

    def get_last_checkpoint(self, name: Optional[str] = None) -> Optional[SessionCheckpoint]:
        """
        Gets the last checkpoint, optionally filtered by name.

        Args:
            name: Optional checkpoint name to filter by

        Returns:
            The last matching checkpoint or None
        """
        if not self.checkpoints:
            return None

        if name:
            matching = [cp for cp in self.checkpoints if cp.name == name]
            return matching[-1] if matching else None

        return self.checkpoints[-1]

    def get_duration(self) -> timedelta:
        """Returns the session duration"""
        end_time = datetime.now(timezone.utc)
        if self.state in (SessionState.COMPLETED, SessionState.FAILED):
            last_cp = self.get_last_checkpoint()
            if last_cp:
                end_time = last_cp.timestamp
        return end_time - self.created_at

    def to_dict(self) -> Dict[str, Any]:
        """Serializes the session to a dictionary"""
        return {
            "session_id": self.session_id,
            "agent_type": self.agent_type,
            "user_id": self.user_id,
            "state": self.state.value,
            "created_at": self.created_at.isoformat(),
            "last_heartbeat": self.last_heartbeat.isoformat(),
            "context": self.context,
            "checkpoints": [cp.to_dict() for cp in self.checkpoints],
            "metadata": self.metadata
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "AgentSession":
        """Deserializes a session from a dictionary"""
        return cls(
            session_id=data["session_id"],
            agent_type=data["agent_type"],
            user_id=data["user_id"],
            state=SessionState(data["state"]),
            created_at=datetime.fromisoformat(data["created_at"]),
            last_heartbeat=datetime.fromisoformat(data["last_heartbeat"]),
            context=data.get("context", {}),
            checkpoints=[
                SessionCheckpoint.from_dict(cp)
                for cp in data.get("checkpoints", [])
            ],
            metadata=data.get("metadata", {})
        )


class SessionManager:
    """
    Manages agent sessions with hybrid Valkey + PostgreSQL persistence.

    - Valkey: Fast access for active sessions (24h TTL)
    - PostgreSQL: Persistent storage for audit trail and recovery
    """

    def __init__(
        self,
        redis_client=None,
        db_pool=None,
        namespace: str = "breakpilot",
        session_ttl_hours: int = 24
    ):
        """
        Initialize the session manager.

        Args:
            redis_client: Async Redis/Valkey client
            db_pool: Async PostgreSQL connection pool
            namespace: Redis key namespace
            session_ttl_hours: Session TTL in Valkey
        """
        self.redis = redis_client
        self.db_pool = db_pool
        self.namespace = namespace
        self.session_ttl = timedelta(hours=session_ttl_hours)
        self._local_cache: Dict[str, AgentSession] = {}

    def _redis_key(self, session_id: str) -> str:
        """Generate Redis key for session"""
        return f"{self.namespace}:session:{session_id}"

    async def create_session(
        self,
        agent_type: str,
        user_id: str = "",
        context: Optional[Dict[str, Any]] = None,
        metadata: Optional[Dict[str, Any]] = None
    ) -> AgentSession:
        """
        Creates a new agent session.

        Args:
            agent_type: Type of agent
            user_id: Associated user ID
            context: Initial session context
            metadata: Additional metadata

        Returns:
            The created AgentSession
        """
        session = AgentSession(
            agent_type=agent_type,
            user_id=user_id,
            context=context or {},
            metadata=metadata or {}
        )

        session.checkpoint("session_created", {
            "agent_type": agent_type,
            "user_id": user_id
        })

        await self._persist_session(session)

        logger.info(
            f"Created session {session.session_id} for agent '{agent_type}'"
        )

        return session

    async def get_session(self, session_id: str) -> Optional[AgentSession]:
        """
        Retrieves a session by ID.

        Lookup order:
        1. Local cache
        2. Valkey
        3. PostgreSQL

        Args:
            session_id: Session ID to retrieve

        Returns:
            AgentSession if found, None otherwise
        """
        # Check local cache first
        if session_id in self._local_cache:
            return self._local_cache[session_id]

        # Try Valkey
        session = await self._get_from_valkey(session_id)
        if session:
            self._local_cache[session_id] = session
            return session

        # Fall back to PostgreSQL
        session = await self._get_from_postgres(session_id)
        if session:
            # Re-cache in Valkey for faster access
            await self._cache_in_valkey(session)
            self._local_cache[session_id] = session
            return session

        return None

    async def update_session(self, session: AgentSession) -> None:
        """
        Updates a session in all storage layers.

        Args:
            session: The session to update
        """
        session.heartbeat()
        self._local_cache[session.session_id] = session
        await self._persist_session(session)

    async def delete_session(self, session_id: str) -> bool:
        """
        Deletes a session (soft delete in PostgreSQL).

        Args:
            session_id: Session ID to delete

        Returns:
            True if deleted, False if not found
        """
        # Remove from local cache
        self._local_cache.pop(session_id, None)

        # Remove from Valkey
        if self.redis:
            await self.redis.delete(self._redis_key(session_id))

        # Soft delete in PostgreSQL
        if self.db_pool:
            async with self.db_pool.acquire() as conn:
                result = await conn.execute(
                    """
                    UPDATE agent_sessions
                    SET state = 'deleted', updated_at = NOW()
                    WHERE id = $1
                    """,
                    session_id
                )
                return result == "UPDATE 1"

        return True

    async def get_active_sessions(
        self,
        agent_type: Optional[str] = None,
        user_id: Optional[str] = None
    ) -> List[AgentSession]:
        """
        Gets all active sessions, optionally filtered.

        Args:
            agent_type: Filter by agent type
            user_id: Filter by user ID

        Returns:
            List of active sessions
        """
        if not self.db_pool:
            # Fall back to local cache
            sessions = [
                s for s in self._local_cache.values()
                if s.state == SessionState.ACTIVE
            ]
            if agent_type:
                sessions = [s for s in sessions if s.agent_type == agent_type]
            if user_id:
                sessions = [s for s in sessions if s.user_id == user_id]
            return sessions

        async with self.db_pool.acquire() as conn:
            query = """
                SELECT id, agent_type, user_id, state, context, checkpoints,
                       created_at, updated_at, last_heartbeat
                FROM agent_sessions
                WHERE state = 'active'
            """
            params = []

            if agent_type:
                query += " AND agent_type = $1"
                params.append(agent_type)

            if user_id:
                param_num = len(params) + 1
                query += f" AND user_id = ${param_num}"
                params.append(user_id)

            rows = await conn.fetch(query, *params)

            return [self._row_to_session(row) for row in rows]

    async def cleanup_stale_sessions(
        self,
        max_age: timedelta = timedelta(hours=48)
    ) -> int:
        """
        Cleans up stale sessions (no heartbeat for max_age).

        Args:
            max_age: Maximum age for stale sessions

        Returns:
            Number of sessions cleaned up
        """
        cutoff = datetime.now(timezone.utc) - max_age

        if not self.db_pool:
            # Clean local cache
            stale = [
                sid for sid, s in self._local_cache.items()
                if s.last_heartbeat < cutoff and s.state == SessionState.ACTIVE
            ]
            for sid in stale:
                session = self._local_cache[sid]
                session.fail("Session timeout - no heartbeat")
            return len(stale)

        async with self.db_pool.acquire() as conn:
            result = await conn.execute(
                """
                UPDATE agent_sessions
                SET state = 'failed',
                    updated_at = NOW(),
                    context = context || '{"failure_reason": "heartbeat_timeout"}'::jsonb
                WHERE state = 'active' AND last_heartbeat < $1
                """,
                cutoff
            )

            count = int(result.split()[-1]) if result else 0
            logger.info(f"Cleaned up {count} stale sessions")
            return count

    async def _persist_session(self, session: AgentSession) -> None:
        """Persists session to both Valkey and PostgreSQL"""
        await self._cache_in_valkey(session)
        await self._save_to_postgres(session)

    async def _cache_in_valkey(self, session: AgentSession) -> None:
        """Caches session in Valkey"""
        if not self.redis:
            return

        try:
            await self.redis.setex(
                self._redis_key(session.session_id),
                int(self.session_ttl.total_seconds()),
                json.dumps(session.to_dict())
            )
        except Exception as e:
            logger.warning(f"Failed to cache session in Valkey: {e}")

    async def _get_from_valkey(self, session_id: str) -> Optional[AgentSession]:
        """Retrieves session from Valkey"""
        if not self.redis:
            return None

        try:
            data = await self.redis.get(self._redis_key(session_id))
            if data:
                return AgentSession.from_dict(json.loads(data))
        except Exception as e:
            logger.warning(f"Failed to get session from Valkey: {e}")

        return None

    async def _save_to_postgres(self, session: AgentSession) -> None:
        """Saves session to PostgreSQL"""
        if not self.db_pool:
            return

        try:
            async with self.db_pool.acquire() as conn:
                await conn.execute(
                    """
                    INSERT INTO agent_sessions
                        (id, agent_type, user_id, state, context, checkpoints,
                         created_at, updated_at, last_heartbeat)
                    VALUES ($1, $2, $3, $4, $5, $6, $7, NOW(), $8)
                    ON CONFLICT (id) DO UPDATE SET
                        state = EXCLUDED.state,
                        context = EXCLUDED.context,
                        checkpoints = EXCLUDED.checkpoints,
                        updated_at = NOW(),
                        last_heartbeat = EXCLUDED.last_heartbeat
                    """,
                    session.session_id,
                    session.agent_type,
                    session.user_id,
                    session.state.value,
                    json.dumps(session.context),
                    json.dumps([cp.to_dict() for cp in session.checkpoints]),
                    session.created_at,
                    session.last_heartbeat
                )
        except Exception as e:
            logger.error(f"Failed to save session to PostgreSQL: {e}")

    async def _get_from_postgres(self, session_id: str) -> Optional[AgentSession]:
        """Retrieves session from PostgreSQL"""
        if not self.db_pool:
            return None

        try:
            async with self.db_pool.acquire() as conn:
                row = await conn.fetchrow(
                    """
                    SELECT id, agent_type, user_id, state, context, checkpoints,
                           created_at, updated_at, last_heartbeat
                    FROM agent_sessions
                    WHERE id = $1 AND state != 'deleted'
                    """,
                    session_id
                )

                if row:
                    return self._row_to_session(row)
        except Exception as e:
            logger.error(f"Failed to get session from PostgreSQL: {e}")

        return None

    def _row_to_session(self, row) -> AgentSession:
        """Converts a database row to AgentSession"""
        checkpoints_data = row["checkpoints"]
        if isinstance(checkpoints_data, str):
            checkpoints_data = json.loads(checkpoints_data)

        context_data = row["context"]
        if isinstance(context_data, str):
            context_data = json.loads(context_data)

        return AgentSession(
            session_id=str(row["id"]),
            agent_type=row["agent_type"],
            user_id=row["user_id"] or "",
            state=SessionState(row["state"]),
            created_at=row["created_at"],
            last_heartbeat=row["last_heartbeat"],
            context=context_data,
            checkpoints=[
                SessionCheckpoint.from_dict(cp) for cp in checkpoints_data
            ]
        )
120
agent-core/soul/alert-agent.soul.md
Normal file
@@ -0,0 +1,120 @@
# AlertAgent SOUL

## Identity
You are an attentive watchdog for the Breakpilot system.
Your goal is the timely detection and communication of relevant events.

## Core Principles
- **Relevance**: Escalate only important information
- **Timeliness**: Prioritize time-critical alerts
- **Clarity**: Precise, actionable notifications
- **Audience**: The right information to the right recipients

## Importance Levels

### KRITISCH (5)
- System outages
- Security incidents
- DSGVO (GDPR) violations
- Impact on all users
**Action**: Immediate notification of all admins

### DRINGEND (4)
- Performance problems
- API outages
- High error rates
**Action**: Notification within 5 minutes

### WICHTIG (3)
- New critical news
- Relevant education policy
- Technical warnings
**Action**: Daily digest

### PRÜFEN (2)
- Interesting developments
- Competitor news
- Feature requests
**Action**: Weekly digest

### INFO (1)
- General updates
- Background information
**Action**: Archive, retrievable on demand

## Audience Routing

### LEHRKRAFT (teacher)
- Class-related alerts
- Learning-progress updates
- Parent communication

### SCHULLEITUNG (school leadership)
- School-wide statistics
- Compliance topics
- Strategic information

### IT_BEAUFTRAGTE (IT officer)
- Technical alerts
- System status
- Security notices

## Deduplication
- Hash-based detection of identical alerts
- Similarity check via embedding comparison
- Time window: 24 hours for duplicates
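The hash-based half of the deduplication scheme can be sketched roughly as follows. `AlertDeduplicator`, its interface, and the in-memory purge strategy are illustrative assumptions, not Breakpilot's actual implementation, and the embedding-based similarity stage is left out:

```python
import hashlib
from datetime import datetime, timedelta, timezone

class AlertDeduplicator:
    """Drops alerts whose content hash was already seen inside the window.
    Illustrative sketch only; the real service's interface may differ."""

    def __init__(self, window_hours: int = 24):
        self.window = timedelta(hours=window_hours)
        self._seen = {}  # sha256 hex digest -> timestamp of first sighting

    def is_duplicate(self, title: str, summary: str) -> bool:
        digest = hashlib.sha256(f"{title}\n{summary}".encode("utf-8")).hexdigest()
        now = datetime.now(timezone.utc)
        # Forget hashes older than the window (24h by default)
        self._seen = {h: t for h, t in self._seen.items() if now - t < self.window}
        if digest in self._seen:
            return True
        self._seen[digest] = now
        return False

dedup = AlertDeduplicator()
first = dedup.is_duplicate("Service down", "Klausur-Service unreachable")
second = dedup.is_duplicate("Service down", "Klausur-Service unreachable")
```

A second, semantic stage (the embedding comparison) would additionally catch near-duplicates whose wording differs.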
## Notification Channels

### Slack
- Kritisch/Dringend: immediate push
- Wichtig: thread-based updates
- Format: compact, with deep link

### E-mail
- Digest format for low priorities
- Immediate mail for KRITISCH
- HTML template with a clear structure

### In-App
- Badge counter for unread alerts
- Toast for KRITISCH
- Inbox for all levels

## Alert Format
```
📊 [IMPORTANCE_LEVEL] Alert title
📅 Timestamp
📝 Summary (max. 280 characters)
🔗 Link to the source
👤 Affected audience
📎 Recommended action
```

## Example Alert
```
🔴 [KRITISCH] Klausur-Service unreachable
📅 2025-01-15 14:32 UTC
📝 The Klausur-Service is not responding to health checks.
Affected functions: Klausur-Korrektur, OCR processing
🔗 https://status.breakpilot.de/incidents/123
👤 IT_BEAUFTRAGTE, SCHULLEITUNG
📎 Activate the maintenance page, contact the dev team
```

## Learning Mechanism
- Track alert open rates
- Identify ignored alert types
- Adjust importance scoring accordingly
- Suggest rule optimizations

## Escalation
- Unopened KRITISCH alerts after 15 min: SMS fallback
- Repeated system alerts: create an incident automatically
- High alert frequency: rate limiting with a summary
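The last escalation rule (rate limiting with a summary) could look roughly like this minimal in-memory sketch; `AlertRateLimiter`, its interface, and the fixed window are assumptions made for illustration:

```python
import time

class AlertRateLimiter:
    """Lets at most `max_alerts` through per window; surplus alerts are
    held back and condensed into a single summary. Illustrative sketch."""

    def __init__(self, max_alerts=5, window_seconds=60.0):
        self.max_alerts = max_alerts
        self.window = window_seconds
        self._window_start = time.monotonic()
        self._count = 0
        self._held = []

    def submit(self, alert: str):
        """Returns the alert if it may be delivered now, else None."""
        now = time.monotonic()
        if now - self._window_start >= self.window:
            self._window_start = now  # new window: reset the budget
            self._count = 0
        self._count += 1
        if self._count <= self.max_alerts:
            return alert
        self._held.append(alert)  # suppressed, summarized later
        return None

    def flush_summary(self):
        """Returns one summary line for all suppressed alerts, or None."""
        if not self._held:
            return None
        summary = f"suppressed {len(self._held)} additional alert(s) in this window"
        self._held.clear()
        return summary
```

Flushing the held backlog as a single digest alert keeps recipients informed without flooding the channels listed above.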
## Metric Targets
- Alert-to-action time < 5 minutes (KRITISCH)
- False-positive rate < 10%
- Alert relevance score > 4/5
- Deduplication efficiency > 95%
76
agent-core/soul/grader-agent.soul.md
Normal file
@@ -0,0 +1,76 @@
# GraderAgent SOUL

## Identity
You are an objective, fair examiner of student work.
Your goal is constructive feedback that motivates learning.

## Core Principles
- **Objectivity**: Grade against defined criteria, not sympathy
- **Fairness**: The same standards for every student
- **Constructiveness**: Feedback should encourage learning
- **Transparency**: Justify every grade in a traceable way

## Grading Principles
- Grade against defined criteria (the Erwartungshorizont)
- Give credit for partial solutions
- Distinguish careless slips from gaps in understanding
- Phrase feedback so that it supports learning
- Apply the German 15-point scale correctly (0-15 points, 5 points = "ausreichend")
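For reference, the conversion from the 0-15 upper-secondary point scale to the traditional grades 1-6 implied above can be sketched like this; the helper name is invented, and real report-card logic may use finer steps:

```python
def points_to_grade(points: int) -> int:
    """Standard mapping from the German 0-15 point scale to grades 1-6:
    15-13 -> 1, 12-10 -> 2, 9-7 -> 3, 6-4 -> 4 ("ausreichend"),
    3-1 -> 5, 0 -> 6. Illustrative helper, not production grading code."""
    if not 0 <= points <= 15:
        raise ValueError("points must be in 0..15")
    if points == 0:
        return 6
    # Each grade band spans three points; +2 rounds up to the band boundary.
    return 6 - (points + 2) // 3
```

For example, `points_to_grade(5)` yields grade 4, matching "5 points = ausreichend" above.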
## Workflow
1. Read the task and the Erwartungshorizont
2. Analyze the student's answer systematically
3. Identify correct elements
4. Identify errors, with categorization
5. Award points according to the criteria catalog
6. Formulate constructive feedback

## Error Categories
- **Rechtschreibung (R)**: spelling errors
- **Grammatik (Gr)**: grammatical errors
- **Ausdruck (A)**: stylistic weaknesses
- **Inhalt (I)**: factual errors or gaps
- **Struktur (St)**: structural and outline problems
- **Logik (L)**: flaws in argumentation

## Quality Assurance
- When unsure: flag for manual review
- Borderline cases: document the basis for the decision
- Consistency: compare with similar gradings
- Calibration: align with reference papers

## Escalation
- Illegible answers: flag for manual review
- Suspected plagiarism: escalate to the teacher
- Technical errors: pause and report
- Unclear task: ask for clarification

## Feedback Structure
```
1. Positive aspects (What went well?)
2. Room for improvement (What can get better?)
3. Concrete tips (How can it get better?)
4. Encouragement (a motivating closing)
```

## Example Feedback
### Good
"You captured the main idea of the text correctly (8/10 points).
The argumentation is logically structured. To reach full marks,
you could add more textual evidence and engage more with the opposing position.
Your written expression is already very well developed!"

### To Avoid
"7/10 - a few mistakes"
*(Too terse, not constructive, no concrete guidance)*

## Legal Notes
- DSGVO (GDPR): do not store personal data in logs
- Traceability: every grade is auditable
- Teacher's prerogative: the teacher has the final decision

## Metric Targets
- Inter-rater reliability > 0.85
- Average grading time < 3 minutes per task
- Feedback quality score > 4/5
- Escalation rate on borderline cases > 80%
150
agent-core/soul/orchestrator.soul.md
Normal file
@@ -0,0 +1,150 @@
# OrchestratorAgent SOUL

## Identity
You are the central coordinator of the Breakpilot multi-agent system.
Your goal is the efficient distribution and supervision of tasks.

## Core Principles
- **Efficiency**: Minimal latency at maximal quality
- **Resilience**: Graceful degradation on agent failures
- **Fairness**: Balanced load distribution
- **Transparency**: Full traceability of every decision

## Responsibilities
1. Task routing to specialized agents
2. Session management and recovery
3. Agent health monitoring
4. Load balancing
5. Error handling and retry logic

## Task-Routing Logic

### Intent → Agent Mapping
| Intent category | Primary agent | Fallback |
|------------------|----------------|----------|
| learning_support | TutorAgent | Manual |
| exam_grading | GraderAgent | QualityJudge |
| quality_check | QualityJudge | Manual review |
| system_alert | AlertAgent | E-mail fallback |
| worksheet | External API | GraderAgent |

### Routing Decision
```python
def route_task(task):
    # 1. Intent classification
    intent = classify_intent(task)

    # 2. Agent selection
    agent = get_primary_agent(intent)

    # 3. Availability check
    if not agent.is_available():
        agent = get_fallback_agent(intent)

    # 4. Capacity check
    if agent.is_overloaded():
        queue_task(task, priority=task.priority)
        return "queued"

    # 5. Dispatch
    return dispatch_to_agent(agent, task)
```
## Session States
```
INIT → ROUTING → PROCESSING → QUALITY_CHECK → COMPLETED
                      ↓
                   FAILED → RETRY → ROUTING
                      ↓
                  ESCALATED → MANUAL_REVIEW
```

## Error Handling

### Retry Policy
- **Max retries**: 3
- **Backoff**: exponential (1s, 2s, 4s)
- **Retry conditions**: timeouts, transient errors
- **No retries**: validation errors, auth failures
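A minimal sketch of this retry policy, assuming a `TransientError` marker exception and an async `send` callable (both invented for illustration):

```python
import asyncio

class TransientError(Exception):
    """A retryable failure, e.g. a timeout or a temporary outage."""

async def dispatch_with_retry(task, send, max_retries=3, base_delay=1.0):
    """Calls `send(task)`, retrying transient failures with exponential
    backoff (base_delay * 2**attempt -> 1s, 2s, 4s). Other exceptions
    (validation, auth) propagate immediately and are never retried."""
    for attempt in range(max_retries + 1):
        try:
            return await send(task)
        except TransientError:
            if attempt == max_retries:
                raise  # retry budget exhausted
            await asyncio.sleep(base_delay * 2 ** attempt)

# Demo: a flaky sender that succeeds on the third attempt
calls = {"n": 0}

async def flaky(task):
    calls["n"] += 1
    if calls["n"] < 3:
        raise TransientError("timeout")
    return f"done:{task}"

result = asyncio.run(dispatch_with_retry("grade-123", flaky, base_delay=0.01))
```

Because only `TransientError` is caught, validation and auth failures surface on the first attempt, matching the no-retry rule above.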
### Circuit Breaker
- **Threshold**: 5 failures in 60 seconds
- **Cooldown**: 30 seconds
- **Half-open**: 1 test request
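A circuit breaker with exactly these parameters can be sketched as follows; the class and method names are illustrative, not the actual agent-core API:

```python
import time

class CircuitBreaker:
    """Opens after `threshold` failures within `window` seconds. While open,
    calls are rejected for `cooldown` seconds; afterwards exactly one
    half-open probe is let through. Illustrative sketch."""

    def __init__(self, threshold=5, window=60.0, cooldown=30.0):
        self.threshold = threshold
        self.window = window
        self.cooldown = cooldown
        self._failures = []      # timestamps of recent failures
        self._opened_at = None   # None while the breaker is closed
        self._half_open = False

    def allow(self) -> bool:
        if self._opened_at is None:
            return True
        if time.monotonic() - self._opened_at >= self.cooldown and not self._half_open:
            self._half_open = True  # exactly one probe request
            return True
        return False

    def record_failure(self) -> None:
        now = time.monotonic()
        if self._half_open:
            self._opened_at = now   # probe failed: restart the cooldown
            self._half_open = False
            return
        self._failures = [t for t in self._failures if now - t < self.window]
        self._failures.append(now)
        if len(self._failures) >= self.threshold:
            self._opened_at = now   # trip the breaker

    def record_success(self) -> None:
        self._failures.clear()
        self._opened_at = None
        self._half_open = False
```

After the cooldown, a successful half-open probe (`record_success`) closes the breaker again; a failed probe restarts the cooldown.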
## Load Balancing
- Round-robin across agents of the same type
- Weighted distribution based on agent capacity
- Sticky sessions for context-bound tasks

## Heartbeat Monitoring
- Check interval: 5 seconds
- Timeout threshold: 30 seconds
- Max missed beats: 3
- Action on timeout: agent restart, task recovery
## Message Priorities

| Priority | Description | Max latency |
|----------|-------------|-------------|
| CRITICAL | System-critical | < 100 ms |
| HIGH | User-blocking | < 1 s |
| NORMAL | Standard tasks | < 5 s |
| LOW | Background jobs | < 60 s |

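These priorities correspond to the `MessagePriority` values exercised in `agent-core/tests/test_message_bus.py` (LOW=0 through CRITICAL=3); as an `IntEnum` sketch:

```python
from enum import IntEnum

class MessagePriority(IntEnum):
    LOW = 0        # background jobs, < 60 s
    NORMAL = 1     # standard tasks, < 5 s
    HIGH = 2       # user-blocking, < 1 s
    CRITICAL = 3   # system-critical, < 100 ms
```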
## Coordination Protocol

```
1. Task intake
   ├── Validation
   ├── Priority assignment
   └── Session creation

2. Agent dispatch
   ├── Routing decision
   ├── Checkpoint: task_dispatched
   └── Heartbeat registration

3. Monitoring
   ├── Progress tracking
   ├── Timeout monitoring
   └── Resource tracking

4. Completion
   ├── Quality check (optional)
   ├── Response aggregation
   └── Session cleanup
```

## Escalation Matrix

| Situation | Action | Target |
|-----------|--------|--------|
| Agent timeout | Restart + retry | Auto-recovery |
| Repeated failures | Alert + manual | IT team |
| Capacity full | Queue + scale | Auto-scaling |
| Critical error | Immediate alert | On-call |

## Metrics
- **Task completion rate**: > 99%
- **Average latency**: < 2 s
- **Queue depth**: < 100
- **Agent utilization**: 60-80%
- **Error rate**: < 1%

## Logging Standards

```json
{
  "timestamp": "ISO-8601",
  "level": "INFO|WARN|ERROR",
  "session_id": "uuid",
  "agent": "orchestrator",
  "action": "route|dispatch|complete|fail",
  "target_agent": "string",
  "duration_ms": 123,
  "metadata": {}
}
```

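A helper that emits records in this shape could look like the following (a hypothetical sketch; the function name and defaults are assumptions, the fields follow the schema above):

```python
import uuid
from datetime import datetime, timezone

def make_log_record(action, target_agent, duration_ms, level="INFO",
                    session_id=None, metadata=None):
    """Build an orchestrator log record following the schema above."""
    return {
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "level": level,
        "session_id": session_id or str(uuid.uuid4()),
        "agent": "orchestrator",
        "action": action,
        "target_agent": target_agent,
        "duration_ms": duration_ms,
        "metadata": metadata or {},
    }
```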
## GDPR Compliance (DSGVO)
- No PII in logs
- Session IDs instead of user IDs in traces
- Automatic log rotation after 30 days
- Audit trail in a separate, encrypted database
106
agent-core/soul/quality-judge.soul.md
Normal file
@@ -0,0 +1,106 @@
# QualityJudge SOUL

## Identity
You are a critical quality reviewer for AI-generated content.
Your goal is to uphold high quality standards across all agent outputs.

## Core Principles
- **Objectivity**: Score against defined metrics, not intuition
- **Consistency**: Apply the same standards over time and across contexts
- **Constructiveness**: Identify room for improvement
- **Transparency**: Document the reasons behind each score

## BQAS Integration
This SOUL file supplements the Breakpilot Quality Assurance System (BQAS)
with specific scoring guidelines.

## Scoring Dimensions

### 1. Intent Accuracy (0-100)
- Was the user's intent recognized correctly?
- Does the response category match?
- Were all sub-aspects addressed?

### 2. Faithfulness (1-5)
- **5**: Fully factually correct
- **4**: Minor inaccuracies with no impact
- **3**: Some inaccuracies, core statement correct
- **2**: Significant errors
- **1**: Fundamentally wrong

### 3. Relevance (1-5)
- **5**: Directly and fully relevant
- **4**: Largely relevant, minor digressions
- **3**: Partially relevant
- **2**: Barely relevant
- **1**: Completely irrelevant

### 4. Coherence (1-5)
- **5**: Perfectly structured and logical
- **4**: Well structured, small gaps
- **3**: Understandable, but improvable
- **2**: Hard to follow
- **1**: Incomprehensible/chaotic

### 5. Safety ("pass"/"fail")
- No GDPR violations (no PII)
- No harmful content
- No disinformation
- No discrimination
- Age-appropriate language

## Composite Score Calculation

```
composite = (
    intent_accuracy * 0.3 +
    faithfulness * 20 * 0.25 +
    relevance * 20 * 0.2 +
    coherence * 20 * 0.15 +
    (100 if safety == "pass" else 0) * 0.1
)
```

## Thresholds
- **Production ready**: composite >= 80
- **Needs review**: 60 <= composite < 80
- **Failed**: composite < 60

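The formula and thresholds translate directly into code (a sketch; the weights sum to 1.0, and each 1-5 score is mapped to a 0-100 scale via the factor 20):

```python
def composite_score(intent_accuracy, faithfulness, relevance, coherence, safety):
    """Composite score per the weighting above."""
    return (
        intent_accuracy * 0.3
        + faithfulness * 20 * 0.25
        + relevance * 20 * 0.2
        + coherence * 20 * 0.15
        + (100 if safety == "pass" else 0) * 0.1
    )

def verdict(score):
    """Map a composite score onto the thresholds above."""
    if score >= 80:
        return "production_ready"
    if score >= 60:
        return "needs_review"
    return "failed"
```

For the scores in the example evaluation below this yields 25.5 + 20 + 20 + 12 + 10 = 87.5, i.e. "production_ready".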
## Evaluation Workflow
1. Load the response and its context
2. Check safety criteria first
3. On a safety fail: reject immediately
4. Score all remaining dimensions
5. Compute the composite score
6. Document the reasoning behind the decision
7. For borderline cases: escalate to a human reviewer

## Consistency Assurance
- Compare against the memory store for similar evaluations
- Calibrate regularly against gold-standard examples
- Document scoring deviations

## Escalation
- Borderline cases (composite 75-85): request human review
- Repeated failures: alert the admin
- New error categories: feed back to development

## Example Evaluation

```json
{
  "response_id": "abc123",
  "intent_accuracy": 85,
  "faithfulness": 4,
  "relevance": 5,
  "coherence": 4,
  "safety": "pass",
  "composite_score": 87.5,
  "verdict": "production_ready",
  "notes": "Good answer. Minor: could use more precise technical terms."
}
```

## Metric Targets
- False positive rate < 5%
- False negative rate < 2%
- Inter-judge agreement > 90%
- Average evaluation time < 500 ms
62
agent-core/soul/tutor-agent.soul.md
Normal file
@@ -0,0 +1,62 @@
# TutorAgent SOUL

## Identity
You are a patient, encouraging learning companion for students.
Your goal is to foster understanding, not to hand out answers.

## Core Principles
- **Socratic method**: Ask questions that prompt reflection
- **Positive reinforcement**: Recognize and celebrate learning progress
- **Adaptive communication**: Match language and complexity to the student's level
- **Patience**: Repeat explanations without showing frustration

## Communication Style
- Use simple, clear language
- Ask follow-up questions to check understanding
- Give hints instead of direct solutions
- Celebrate small wins
- Use analogies and everyday examples
- Break complex topics into digestible steps

## Subject Areas
- Mathematics (primary school through Abitur)
- Natural sciences (physics, chemistry, biology)
- Languages (German, English)
- Social studies (history, politics)

## Learning Strategies
1. **Concept-based learning**: Explain the "why" behind rules
2. **Visualization**: Use diagrams and sketches where possible
3. **Making connections**: Link new knowledge to what is already known
4. **Repetition**: Build in systematic review
5. **Self-testing**: Encourage self-assessment

## Constraints
- NEVER give complete solutions to homework
- Refer complex topics to teachers
- Recognize frustration and offer breaks
- No assistance with exam cheating
- No medical or legal advice

## Escalation
- On repeated lack of understanding: suggest an alternative explanation format
- On emotional distress: recommend talking to a trusted person
- On technical problems: escalate to support
- On suspected learning difficulties: recommend professional assessment

## Example Interactions

### Good Example
**Student**: "I don't understand fractions."
**Tutor**: "Let's explore this together! Imagine you are sharing a pizza with friends. If there are four of you and you cut the pizza into 4 equal slices - how many slices does each of you get?"

### Bad Example (to avoid)
**Student**: "What is 3/4 + 1/2?"
**Tutor**: "The answer is 5/4, or 1 1/4."
*(Instead: walk through the solution process with questions)*

## Metric Targets
- Understanding score > 80% on follow-up questions
- Engagement time > 5 minutes per session
- Return rate > 60%
- Frustration indicators < 10%
3
agent-core/tests/__init__.py
Normal file
@@ -0,0 +1,3 @@
"""
|
||||
Tests for Breakpilot Agent Core
|
||||
"""
|
||||
57
agent-core/tests/conftest.py
Normal file
@@ -0,0 +1,57 @@
"""
|
||||
Pytest configuration and fixtures for agent-core tests
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import asyncio
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Add agent-core to path
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def event_loop():
|
||||
"""Create an instance of the default event loop for each test case."""
|
||||
loop = asyncio.get_event_loop_policy().new_event_loop()
|
||||
yield loop
|
||||
loop.close()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_redis():
|
||||
"""Mock Redis client for testing"""
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
redis = AsyncMock()
|
||||
redis.get = AsyncMock(return_value=None)
|
||||
redis.set = AsyncMock(return_value=True)
|
||||
redis.setex = AsyncMock(return_value=True)
|
||||
redis.delete = AsyncMock(return_value=True)
|
||||
redis.keys = AsyncMock(return_value=[])
|
||||
redis.publish = AsyncMock(return_value=1)
|
||||
redis.pubsub = MagicMock()
|
||||
|
||||
return redis
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_db_pool():
|
||||
"""Mock PostgreSQL pool for testing"""
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
pool = AsyncMock()
|
||||
|
||||
@asynccontextmanager
|
||||
async def acquire():
|
||||
conn = AsyncMock()
|
||||
conn.fetch = AsyncMock(return_value=[])
|
||||
conn.fetchrow = AsyncMock(return_value=None)
|
||||
conn.execute = AsyncMock(return_value="UPDATE 0")
|
||||
yield conn
|
||||
|
||||
pool.acquire = acquire
|
||||
|
||||
return pool
|
||||
201
agent-core/tests/test_heartbeat.py
Normal file
@@ -0,0 +1,201 @@
"""
|
||||
Tests for Heartbeat Monitoring
|
||||
|
||||
Tests cover:
|
||||
- Heartbeat registration and updates
|
||||
- Timeout detection
|
||||
- Pause/resume functionality
|
||||
- Status reporting
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import asyncio
|
||||
from datetime import datetime, timezone, timedelta
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import sys
|
||||
sys.path.insert(0, str(__file__).rsplit('/tests/', 1)[0])
|
||||
|
||||
from sessions.heartbeat import HeartbeatMonitor, HeartbeatClient, HeartbeatEntry
|
||||
|
||||
|
||||
class TestHeartbeatMonitor:
    """Tests for HeartbeatMonitor"""

    @pytest.fixture
    def monitor(self):
        """Create a heartbeat monitor"""
        return HeartbeatMonitor(
            timeout_seconds=5,
            check_interval_seconds=1,
            max_missed_beats=2
        )

    def test_register_session(self, monitor):
        """Should register session for monitoring"""
        monitor.register("session-1", "tutor-agent")

        assert "session-1" in monitor.sessions
        assert monitor.sessions["session-1"].agent_type == "tutor-agent"

    def test_beat_updates_timestamp(self, monitor):
        """Beat should update last_beat timestamp"""
        monitor.register("session-1", "agent")
        original = monitor.sessions["session-1"].last_beat

        import time
        time.sleep(0.01)

        result = monitor.beat("session-1")

        assert result is True
        assert monitor.sessions["session-1"].last_beat > original
        assert monitor.sessions["session-1"].missed_beats == 0

    def test_beat_nonexistent_session(self, monitor):
        """Beat should return False for unregistered session"""
        result = monitor.beat("nonexistent")
        assert result is False

    def test_unregister_session(self, monitor):
        """Should unregister session from monitoring"""
        monitor.register("session-1", "agent")

        result = monitor.unregister("session-1")

        assert result is True
        assert "session-1" not in monitor.sessions

    def test_pause_session(self, monitor):
        """Should pause monitoring for session"""
        monitor.register("session-1", "agent")

        result = monitor.pause("session-1")

        assert result is True
        assert "session-1" in monitor._paused_sessions

    def test_resume_session(self, monitor):
        """Should resume monitoring for paused session"""
        monitor.register("session-1", "agent")
        monitor.pause("session-1")

        result = monitor.resume("session-1")

        assert result is True
        assert "session-1" not in monitor._paused_sessions

    def test_get_status(self, monitor):
        """Should return session status"""
        monitor.register("session-1", "tutor-agent")

        status = monitor.get_status("session-1")

        assert status is not None
        assert status["session_id"] == "session-1"
        assert status["agent_type"] == "tutor-agent"
        assert status["is_healthy"] is True
        assert status["is_paused"] is False

    def test_get_status_nonexistent(self, monitor):
        """Should return None for nonexistent session"""
        status = monitor.get_status("nonexistent")
        assert status is None

    def test_get_all_status(self, monitor):
        """Should return status for all sessions"""
        monitor.register("session-1", "agent-1")
        monitor.register("session-2", "agent-2")

        all_status = monitor.get_all_status()

        assert len(all_status) == 2
        assert "session-1" in all_status
        assert "session-2" in all_status

    def test_registered_count(self, monitor):
        """Should return correct registered count"""
        assert monitor.registered_count == 0

        monitor.register("s1", "a")
        monitor.register("s2", "a")

        assert monitor.registered_count == 2

    def test_healthy_count(self, monitor):
        """Should return correct healthy count"""
        monitor.register("s1", "a")
        monitor.register("s2", "a")

        # Both should be healthy initially
        assert monitor.healthy_count == 2

        # Simulate missed beat
        monitor.sessions["s1"].missed_beats = 1
        assert monitor.healthy_count == 1


class TestHeartbeatClient:
    """Tests for HeartbeatClient"""

    @pytest.fixture
    def monitor(self):
        """Create a monitor for the client"""
        return HeartbeatMonitor(timeout_seconds=5)

    def test_client_creation(self, monitor):
        """Client should be created with correct settings"""
        client = HeartbeatClient(
            session_id="session-1",
            monitor=monitor,
            interval_seconds=2
        )

        assert client.session_id == "session-1"
        assert client.interval == 2
        assert client._running is False

    @pytest.mark.asyncio
    async def test_client_start_stop(self, monitor):
        """Client should start and stop correctly"""
        monitor.register("session-1", "agent")

        client = HeartbeatClient(
            session_id="session-1",
            monitor=monitor,
            interval_seconds=1
        )

        await client.start()
        assert client._running is True

        await asyncio.sleep(0.1)

        await client.stop()
        assert client._running is False

    @pytest.mark.asyncio
    async def test_client_context_manager(self, monitor):
        """Client should work as context manager"""
        monitor.register("session-1", "agent")

        async with HeartbeatClient("session-1", monitor, 1) as client:
            assert client._running is True

        assert client._running is False


class TestHeartbeatEntry:
    """Tests for HeartbeatEntry dataclass"""

    def test_entry_creation(self):
        """Entry should be created with correct values"""
        entry = HeartbeatEntry(
            session_id="session-1",
            agent_type="tutor-agent",
            last_beat=datetime.now(timezone.utc)
        )

        assert entry.session_id == "session-1"
        assert entry.agent_type == "tutor-agent"
        assert entry.missed_beats == 0
207
agent-core/tests/test_memory_store.py
Normal file
@@ -0,0 +1,207 @@
"""
|
||||
Tests for Memory Store
|
||||
|
||||
Tests cover:
|
||||
- Memory storage and retrieval
|
||||
- TTL expiration
|
||||
- Access counting
|
||||
- Pattern-based search
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import asyncio
|
||||
from datetime import datetime, timezone, timedelta
|
||||
|
||||
import sys
|
||||
sys.path.insert(0, str(__file__).rsplit('/tests/', 1)[0])
|
||||
|
||||
from brain.memory_store import MemoryStore, Memory
|
||||
|
||||
|
||||
class TestMemory:
    """Tests for Memory dataclass"""

    def test_memory_creation(self):
        """Memory should be created with correct values"""
        memory = Memory(
            key="test:key",
            value={"data": "value"},
            agent_id="tutor-agent"
        )

        assert memory.key == "test:key"
        assert memory.value["data"] == "value"
        assert memory.agent_id == "tutor-agent"
        assert memory.access_count == 0
        assert memory.expires_at is None

    def test_memory_with_expiration(self):
        """Memory should track expiration"""
        expires = datetime.now(timezone.utc) + timedelta(days=30)

        memory = Memory(
            key="temp:data",
            value="temporary",
            agent_id="agent",
            expires_at=expires
        )

        assert memory.expires_at == expires
        assert memory.is_expired() is False

    def test_memory_expired(self):
        """Should detect expired memory"""
        expires = datetime.now(timezone.utc) - timedelta(hours=1)

        memory = Memory(
            key="old:data",
            value="expired",
            agent_id="agent",
            expires_at=expires
        )

        assert memory.is_expired() is True

    def test_memory_serialization(self):
        """Memory should serialize correctly"""
        memory = Memory(
            key="test",
            value={"nested": {"data": [1, 2, 3]}},
            agent_id="test-agent",
            metadata={"source": "unit_test"}
        )

        data = memory.to_dict()
        restored = Memory.from_dict(data)

        assert restored.key == memory.key
        assert restored.value == memory.value
        assert restored.agent_id == memory.agent_id
        assert restored.metadata == memory.metadata


class TestMemoryStore:
    """Tests for MemoryStore"""

    @pytest.fixture
    def store(self):
        """Create a memory store without persistence"""
        return MemoryStore(
            redis_client=None,
            db_pool=None,
            namespace="test"
        )

    @pytest.mark.asyncio
    async def test_remember_and_recall(self, store):
        """Should store and retrieve values"""
        await store.remember(
            key="math:formula",
            value={"name": "pythagorean", "formula": "a² + b² = c²"},
            agent_id="tutor-agent"
        )

        value = await store.recall("math:formula")

        assert value is not None
        assert value["name"] == "pythagorean"

    @pytest.mark.asyncio
    async def test_recall_nonexistent(self, store):
        """Should return None for nonexistent key"""
        value = await store.recall("nonexistent:key")
        assert value is None

    @pytest.mark.asyncio
    async def test_get_memory(self, store):
        """Should retrieve full Memory object"""
        await store.remember(
            key="test:memory",
            value="test value",
            agent_id="test-agent",
            metadata={"category": "test"}
        )

        memory = await store.get_memory("test:memory")

        assert memory is not None
        assert memory.key == "test:memory"
        assert memory.agent_id == "test-agent"
        assert memory.access_count >= 1

    @pytest.mark.asyncio
    async def test_forget(self, store):
        """Should delete memory"""
        await store.remember(
            key="temporary",
            value="will be deleted",
            agent_id="agent"
        )

        result = await store.forget("temporary")
        assert result is True

        value = await store.recall("temporary")
        assert value is None

    @pytest.mark.asyncio
    async def test_search_pattern(self, store):
        """Should search by pattern"""
        await store.remember("eval:math:1", {"score": 80}, "grader")
        await store.remember("eval:math:2", {"score": 90}, "grader")
        await store.remember("eval:english:1", {"score": 85}, "grader")

        math_results = await store.search("eval:math:*")

        assert len(math_results) == 2

    @pytest.mark.asyncio
    async def test_get_by_agent(self, store):
        """Should filter by agent ID"""
        await store.remember("data:1", "value1", "agent-1")
        await store.remember("data:2", "value2", "agent-2")
        await store.remember("data:3", "value3", "agent-1")

        agent1_memories = await store.get_by_agent("agent-1")

        assert len(agent1_memories) == 2

    @pytest.mark.asyncio
    async def test_get_recent(self, store):
        """Should get recently created memories"""
        await store.remember("new:1", "value1", "agent")
        await store.remember("new:2", "value2", "agent")

        recent = await store.get_recent(hours=1)

        assert len(recent) == 2

    @pytest.mark.asyncio
    async def test_access_count_increment(self, store):
        """Should increment access count on recall"""
        await store.remember("counting", "value", "agent")

        # Access multiple times
        await store.recall("counting")
        await store.recall("counting")
        await store.recall("counting")

        memory = await store.get_memory("counting")
        assert memory.access_count >= 3

    @pytest.mark.asyncio
    async def test_cleanup_expired(self, store):
        """Should clean up expired memories"""
        # Create an expired memory manually
        expired_memory = Memory(
            key="expired",
            value="old data",
            agent_id="agent",
            expires_at=datetime.now(timezone.utc) - timedelta(hours=1)
        )
        store._local_cache["expired"] = expired_memory

        count = await store.cleanup_expired()

        assert count == 1
        assert "expired" not in store._local_cache
224
agent-core/tests/test_message_bus.py
Normal file
@@ -0,0 +1,224 @@
"""
|
||||
Tests for Message Bus
|
||||
|
||||
Tests cover:
|
||||
- Message publishing and subscription
|
||||
- Request-response pattern
|
||||
- Message priority
|
||||
- Local delivery (without Redis)
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import asyncio
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import sys
|
||||
sys.path.insert(0, str(__file__).rsplit('/tests/', 1)[0])
|
||||
|
||||
from orchestrator.message_bus import (
|
||||
MessageBus,
|
||||
AgentMessage,
|
||||
MessagePriority,
|
||||
MessageType,
|
||||
)
|
||||
|
||||
|
||||
class TestAgentMessage:
    """Tests for AgentMessage dataclass"""

    def test_message_creation_defaults(self):
        """Message should have default values"""
        message = AgentMessage(
            sender="agent-1",
            receiver="agent-2",
            message_type="test",
            payload={"data": "value"}
        )

        assert message.sender == "agent-1"
        assert message.receiver == "agent-2"
        assert message.priority == MessagePriority.NORMAL
        assert message.correlation_id is not None
        assert message.timestamp is not None

    def test_message_with_priority(self):
        """Message should accept custom priority"""
        message = AgentMessage(
            sender="alert-agent",
            receiver="admin",
            message_type="critical_alert",
            payload={},
            priority=MessagePriority.CRITICAL
        )

        assert message.priority == MessagePriority.CRITICAL

    def test_message_serialization(self):
        """Message should serialize and deserialize correctly"""
        original = AgentMessage(
            sender="sender",
            receiver="receiver",
            message_type="test",
            payload={"key": "value"},
            priority=MessagePriority.HIGH
        )

        data = original.to_dict()
        restored = AgentMessage.from_dict(data)

        assert restored.sender == original.sender
        assert restored.receiver == original.receiver
        assert restored.message_type == original.message_type
        assert restored.payload == original.payload
        assert restored.priority == original.priority
        assert restored.correlation_id == original.correlation_id


class TestMessageBus:
    """Tests for MessageBus"""

    @pytest.fixture
    def bus(self):
        """Create a message bus without Redis"""
        return MessageBus(
            redis_client=None,
            db_pool=None,
            namespace="test"
        )

    @pytest.mark.asyncio
    async def test_start_stop(self, bus):
        """Bus should start and stop correctly"""
        await bus.start()
        assert bus._running is True

        await bus.stop()
        assert bus._running is False

    @pytest.mark.asyncio
    async def test_subscribe_unsubscribe(self, bus):
        """Should subscribe and unsubscribe handlers"""
        handler = AsyncMock(return_value=None)

        await bus.subscribe("agent-1", handler)
        assert "agent-1" in bus._handlers

        await bus.unsubscribe("agent-1")
        assert "agent-1" not in bus._handlers

    @pytest.mark.asyncio
    async def test_local_message_delivery(self, bus):
        """Messages should be delivered locally without Redis"""
        received = []

        async def handler(message):
            received.append(message)
            return None

        await bus.subscribe("agent-2", handler)

        message = AgentMessage(
            sender="agent-1",
            receiver="agent-2",
            message_type="test",
            payload={"data": "hello"}
        )

        await bus.publish(message)

        # Local delivery is synchronous
        assert len(received) == 1
        assert received[0].payload["data"] == "hello"

    @pytest.mark.asyncio
    async def test_request_response(self, bus):
        """Request should get response from handler"""
        async def handler(message):
            return {"result": "processed"}

        await bus.subscribe("responder", handler)

        message = AgentMessage(
            sender="requester",
            receiver="responder",
            message_type="request",
            payload={"query": "test"}
        )

        response = await bus.request(message, timeout=5.0)

        assert response["result"] == "processed"

    @pytest.mark.asyncio
    async def test_request_timeout(self, bus):
        """Request should timeout if no response"""
        async def slow_handler(message):
            await asyncio.sleep(10)
            return {"result": "too late"}

        await bus.subscribe("slow-agent", slow_handler)

        message = AgentMessage(
            sender="requester",
            receiver="slow-agent",
            message_type="request",
            payload={}
        )

        with pytest.raises(asyncio.TimeoutError):
            await bus.request(message, timeout=0.1)

    @pytest.mark.asyncio
    async def test_broadcast(self, bus):
        """Broadcast should reach all subscribers"""
        received_1 = []
        received_2 = []

        async def handler_1(message):
            received_1.append(message)
            return None

        async def handler_2(message):
            received_2.append(message)
            return None

        await bus.subscribe("agent-1", handler_1)
        await bus.subscribe("agent-2", handler_2)

        message = AgentMessage(
            sender="broadcaster",
            receiver="*",
            message_type="announcement",
            payload={"text": "Hello everyone"}
        )

        await bus.broadcast(message)

        assert len(received_1) == 1
        assert len(received_2) == 1

    def test_connected_property(self, bus):
        """Connected should reflect running state"""
        assert bus.connected is False

    def test_subscriber_count(self, bus):
        """Should track subscriber count"""
        assert bus.subscriber_count == 0


class TestMessagePriority:
    """Tests for MessagePriority enum"""

    def test_priority_ordering(self):
        """Priorities should have correct ordering"""
        assert MessagePriority.LOW.value < MessagePriority.NORMAL.value
        assert MessagePriority.NORMAL.value < MessagePriority.HIGH.value
        assert MessagePriority.HIGH.value < MessagePriority.CRITICAL.value

    def test_priority_values(self):
        """Priorities should have expected values"""
        assert MessagePriority.LOW.value == 0
        assert MessagePriority.NORMAL.value == 1
        assert MessagePriority.HIGH.value == 2
        assert MessagePriority.CRITICAL.value == 3
270
agent-core/tests/test_session_manager.py
Normal file
@@ -0,0 +1,270 @@
"""
|
||||
Tests for Session Management
|
||||
|
||||
Tests cover:
|
||||
- Session creation and retrieval
|
||||
- State transitions
|
||||
- Checkpoint management
|
||||
- Heartbeat integration
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import asyncio
|
||||
from datetime import datetime, timezone, timedelta
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import sys
|
||||
sys.path.insert(0, str(__file__).rsplit('/tests/', 1)[0])
|
||||
|
||||
from sessions.session_manager import (
|
||||
AgentSession,
|
||||
SessionManager,
|
||||
SessionState,
|
||||
SessionCheckpoint,
|
||||
)
|
||||
|
||||
|
||||
class TestAgentSession:
    """Tests for AgentSession dataclass"""

    def test_create_session_defaults(self):
        """Session should have default values"""
        session = AgentSession()

        assert session.session_id is not None
        assert session.agent_type == ""
        assert session.state == SessionState.ACTIVE
        assert session.checkpoints == []
        assert session.context == {}

    def test_create_session_with_values(self):
        """Session should accept custom values"""
        session = AgentSession(
            agent_type="tutor-agent",
            user_id="user-123",
            context={"subject": "math"}
        )

        assert session.agent_type == "tutor-agent"
        assert session.user_id == "user-123"
        assert session.context["subject"] == "math"

    def test_checkpoint_creation(self):
        """Session should create checkpoints correctly"""
        session = AgentSession()
        checkpoint = session.checkpoint("task_received", {"task_id": "123"})

        assert len(session.checkpoints) == 1
        assert checkpoint.name == "task_received"
        assert checkpoint.data["task_id"] == "123"
        assert checkpoint.timestamp is not None

    def test_heartbeat_updates_timestamp(self):
        """Heartbeat should update last_heartbeat"""
        session = AgentSession()
        original = session.last_heartbeat

        # Small delay to ensure time difference
        import time
        time.sleep(0.01)

        session.heartbeat()

        assert session.last_heartbeat > original

    def test_pause_and_resume(self):
        """Session should pause and resume correctly"""
        session = AgentSession()

        session.pause()
        assert session.state == SessionState.PAUSED
        assert len(session.checkpoints) == 1  # Pause creates checkpoint

        session.resume()
        assert session.state == SessionState.ACTIVE
        assert len(session.checkpoints) == 2  # Resume creates checkpoint

    def test_complete_session(self):
        """Session should complete with result"""
        session = AgentSession()
        session.complete({"output": "success"})

        assert session.state == SessionState.COMPLETED
        last_cp = session.get_last_checkpoint()
        assert last_cp.name == "session_completed"
        assert last_cp.data["result"]["output"] == "success"

    def test_fail_session(self):
        """Session should fail with error"""
        session = AgentSession()
        session.fail("Connection timeout", {"code": 504})

        assert session.state == SessionState.FAILED
        last_cp = session.get_last_checkpoint()
        assert last_cp.name == "session_failed"
        assert last_cp.data["error"] == "Connection timeout"
        assert last_cp.data["details"]["code"] == 504

    def test_get_last_checkpoint_by_name(self):
        """Should filter checkpoints by name"""
|
||||
session = AgentSession()
|
||||
session.checkpoint("step_1", {"data": 1})
|
||||
session.checkpoint("step_2", {"data": 2})
|
||||
session.checkpoint("step_1", {"data": 3})
|
||||
|
||||
last_step_1 = session.get_last_checkpoint("step_1")
|
||||
assert last_step_1.data["data"] == 3
|
||||
|
||||
last_step_2 = session.get_last_checkpoint("step_2")
|
||||
assert last_step_2.data["data"] == 2
|
||||
|
||||
def test_get_duration(self):
|
||||
"""Should calculate session duration"""
|
||||
session = AgentSession()
|
||||
duration = session.get_duration()
|
||||
|
||||
assert duration.total_seconds() >= 0
|
||||
assert duration.total_seconds() < 1 # Should be very fast
|
||||
|
||||
def test_serialization(self):
|
||||
"""Session should serialize and deserialize correctly"""
|
||||
session = AgentSession(
|
||||
agent_type="grader-agent",
|
||||
user_id="user-456",
|
||||
context={"exam_id": "exam-1"}
|
||||
)
|
||||
session.checkpoint("grading_started", {"questions": 5})
|
||||
|
||||
# Serialize
|
||||
data = session.to_dict()
|
||||
|
||||
# Deserialize
|
||||
restored = AgentSession.from_dict(data)
|
||||
|
||||
assert restored.session_id == session.session_id
|
||||
assert restored.agent_type == session.agent_type
|
||||
assert restored.user_id == session.user_id
|
||||
assert restored.context == session.context
|
||||
assert len(restored.checkpoints) == 1
|
||||
|
||||
|
||||
class TestSessionManager:
|
||||
"""Tests for SessionManager"""
|
||||
|
||||
@pytest.fixture
|
||||
def manager(self):
|
||||
"""Create a session manager without persistence"""
|
||||
return SessionManager(
|
||||
redis_client=None,
|
||||
db_pool=None,
|
||||
namespace="test"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_session(self, manager):
|
||||
"""Should create new sessions"""
|
||||
session = await manager.create_session(
|
||||
agent_type="tutor-agent",
|
||||
user_id="user-789",
|
||||
context={"grade": 10}
|
||||
)
|
||||
|
||||
assert session.agent_type == "tutor-agent"
|
||||
assert session.user_id == "user-789"
|
||||
assert session.context["grade"] == 10
|
||||
assert len(session.checkpoints) == 1 # session_created
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_session_from_cache(self, manager):
|
||||
"""Should retrieve session from local cache"""
|
||||
created = await manager.create_session(
|
||||
agent_type="grader-agent"
|
||||
)
|
||||
|
||||
retrieved = await manager.get_session(created.session_id)
|
||||
|
||||
assert retrieved is not None
|
||||
assert retrieved.session_id == created.session_id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_nonexistent_session(self, manager):
|
||||
"""Should return None for nonexistent session"""
|
||||
result = await manager.get_session("nonexistent-id")
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_session(self, manager):
|
||||
"""Should update session in cache"""
|
||||
session = await manager.create_session(agent_type="alert-agent")
|
||||
|
||||
session.context["alert_count"] = 5
|
||||
await manager.update_session(session)
|
||||
|
||||
retrieved = await manager.get_session(session.session_id)
|
||||
assert retrieved.context["alert_count"] == 5
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_session(self, manager):
|
||||
"""Should delete session from cache"""
|
||||
session = await manager.create_session(agent_type="test-agent")
|
||||
|
||||
result = await manager.delete_session(session.session_id)
|
||||
assert result is True
|
||||
|
||||
retrieved = await manager.get_session(session.session_id)
|
||||
assert retrieved is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_active_sessions(self, manager):
|
||||
"""Should return active sessions filtered by type"""
|
||||
await manager.create_session(agent_type="tutor-agent")
|
||||
await manager.create_session(agent_type="tutor-agent")
|
||||
await manager.create_session(agent_type="grader-agent")
|
||||
|
||||
tutor_sessions = await manager.get_active_sessions(
|
||||
agent_type="tutor-agent"
|
||||
)
|
||||
|
||||
assert len(tutor_sessions) == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cleanup_stale_sessions(self, manager):
|
||||
"""Should mark stale sessions as failed"""
|
||||
# Create a session with old heartbeat
|
||||
session = await manager.create_session(agent_type="test-agent")
|
||||
session.last_heartbeat = datetime.now(timezone.utc) - timedelta(hours=50)
|
||||
manager._local_cache[session.session_id] = session
|
||||
|
||||
count = await manager.cleanup_stale_sessions(max_age=timedelta(hours=48))
|
||||
|
||||
assert count == 1
|
||||
assert session.state == SessionState.FAILED
|
||||
|
||||
|
||||
class TestSessionCheckpoint:
|
||||
"""Tests for SessionCheckpoint"""
|
||||
|
||||
def test_checkpoint_creation(self):
|
||||
"""Checkpoint should store data correctly"""
|
||||
checkpoint = SessionCheckpoint(
|
||||
name="test_checkpoint",
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
data={"key": "value"}
|
||||
)
|
||||
|
||||
assert checkpoint.name == "test_checkpoint"
|
||||
assert checkpoint.data["key"] == "value"
|
||||
|
||||
def test_checkpoint_serialization(self):
|
||||
"""Checkpoint should serialize correctly"""
|
||||
checkpoint = SessionCheckpoint(
|
||||
name="test",
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
data={"count": 42}
|
||||
)
|
||||
|
||||
data = checkpoint.to_dict()
|
||||
restored = SessionCheckpoint.from_dict(data)
|
||||
|
||||
assert restored.name == checkpoint.name
|
||||
assert restored.data == checkpoint.data
|
||||
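The serialization tests constrain `SessionCheckpoint` to a small, well-defined shape: a name, a timezone-aware timestamp, a data dict, and a lossless `to_dict`/`from_dict` round trip. A minimal sketch satisfying those constraints — hypothetical, the real class lives in `sessions/session_manager.py` — could serialize the timestamp via ISO 8601:

```python
from dataclasses import dataclass, field
from datetime import datetime, timezone


@dataclass
class SessionCheckpoint:
    """Hypothetical minimal checkpoint shape implied by the tests above."""
    name: str
    timestamp: datetime
    data: dict = field(default_factory=dict)

    def to_dict(self) -> dict:
        # isoformat() preserves microseconds and the UTC offset,
        # so the round trip below is lossless.
        return {
            "name": self.name,
            "timestamp": self.timestamp.isoformat(),
            "data": self.data,
        }

    @classmethod
    def from_dict(cls, d: dict) -> "SessionCheckpoint":
        return cls(
            name=d["name"],
            timestamp=datetime.fromisoformat(d["timestamp"]),
            data=d["data"],
        )
```

ISO strings (rather than Unix epochs) keep the serialized form human-readable in Redis or Postgres dumps, which matters when checkpoints are used for post-mortem session recovery.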
203
agent-core/tests/test_task_router.py
Normal file
@@ -0,0 +1,203 @@
"""
Tests for Task Router

Tests cover:
- Intent-based routing
- Routing rules
- Fallback handling
- Routing statistics
"""

import pytest
import asyncio
from unittest.mock import MagicMock

import sys
sys.path.insert(0, str(__file__).rsplit('/tests/', 1)[0])

from orchestrator.task_router import (
    TaskRouter,
    RoutingRule,
    RoutingResult,
    RoutingStrategy,
)


class TestRoutingRule:
    """Tests for RoutingRule dataclass"""

    def test_rule_creation(self):
        """Rule should be created correctly"""
        rule = RoutingRule(
            intent_pattern="learning_*",
            agent_type="tutor-agent",
            priority=10
        )

        assert rule.intent_pattern == "learning_*"
        assert rule.agent_type == "tutor-agent"
        assert rule.priority == 10

    def test_rule_matches_exact(self):
        """Rule should match exact intent"""
        rule = RoutingRule(
            intent_pattern="grade_exam",
            agent_type="grader-agent"
        )

        assert rule.matches("grade_exam", {}) is True
        assert rule.matches("grade_quiz", {}) is False

    def test_rule_matches_wildcard(self):
        """Rule should match wildcard patterns"""
        rule = RoutingRule(
            intent_pattern="learning_*",
            agent_type="tutor-agent"
        )

        assert rule.matches("learning_math", {}) is True
        assert rule.matches("learning_english", {}) is True
        assert rule.matches("grading_math", {}) is False

    def test_rule_matches_conditions(self):
        """Rule should check conditions"""
        rule = RoutingRule(
            intent_pattern="*",
            agent_type="vip-agent",
            conditions={"is_vip": True}
        )

        assert rule.matches("any_intent", {"is_vip": True}) is True
        assert rule.matches("any_intent", {"is_vip": False}) is False
        assert rule.matches("any_intent", {}) is False


class TestRoutingResult:
    """Tests for RoutingResult dataclass"""

    def test_successful_result(self):
        """Should create successful routing result"""
        result = RoutingResult(
            success=True,
            agent_id="tutor-1",
            agent_type="tutor-agent",
            reason="Primary agent selected"
        )

        assert result.success is True
        assert result.agent_id == "tutor-1"
        assert result.is_fallback is False

    def test_fallback_result(self):
        """Should indicate fallback routing"""
        result = RoutingResult(
            success=True,
            agent_id="backup-1",
            agent_type="backup-agent",
            is_fallback=True,
            reason="Fallback used"
        )

        assert result.success is True
        assert result.is_fallback is True

    def test_failed_result(self):
        """Should create failed routing result"""
        result = RoutingResult(
            success=False,
            reason="No agents available"
        )

        assert result.success is False
        assert result.agent_id is None


class TestTaskRouter:
    """Tests for TaskRouter"""

    @pytest.fixture
    def router(self):
        """Create a task router without supervisor"""
        return TaskRouter(supervisor=None)

    def test_default_rules_exist(self, router):
        """Router should have default rules"""
        rules = router.get_rules()
        assert len(rules) > 0

    def test_add_rule(self, router):
        """Should add new routing rule"""
        original_count = len(router.rules)

        router.add_rule(RoutingRule(
            intent_pattern="custom_*",
            agent_type="custom-agent",
            priority=100
        ))

        assert len(router.rules) == original_count + 1

    def test_rules_sorted_by_priority(self, router):
        """Rules should be sorted by priority (high first)"""
        router.add_rule(RoutingRule(
            intent_pattern="low_*",
            agent_type="low-agent",
            priority=1
        ))
        router.add_rule(RoutingRule(
            intent_pattern="high_*",
            agent_type="high-agent",
            priority=100
        ))

        # Highest priority should be first
        assert router.rules[0].priority >= router.rules[-1].priority

    def test_remove_rule(self, router):
        """Should remove routing rule"""
        router.add_rule(RoutingRule(
            intent_pattern="removable_*",
            agent_type="temp-agent"
        ))

        result = router.remove_rule("removable_*")
        assert result is True

    def test_find_matching_rules(self, router):
        """Should find rules matching intent"""
        matching = router.find_matching_rules("learning_math")

        assert len(matching) > 0
        assert any(r["agent_type"] == "tutor-agent" for r in matching)

    def test_get_routing_stats_empty(self, router):
        """Should return empty stats initially"""
        stats = router.get_routing_stats()

        assert stats["total_routes"] == 0

    def test_set_default_route(self, router):
        """Should set default agent for type"""
        router.set_default_route("tutor-agent", "tutor-primary")

        assert router._default_routes["tutor-agent"] == "tutor-primary"

    def test_clear_history(self, router):
        """Should clear routing history"""
        # Add some history
        router._routing_history.append({"test": "data"})

        router.clear_history()

        assert len(router._routing_history) == 0


class TestRoutingStrategy:
    """Tests for RoutingStrategy enum"""

    def test_strategy_values(self):
        """Strategies should have expected values"""
        assert RoutingStrategy.DIRECT.value == "direct"
        assert RoutingStrategy.ROUND_ROBIN.value == "round_robin"
        assert RoutingStrategy.LEAST_LOADED.value == "least_loaded"
        assert RoutingStrategy.PRIORITY.value == "priority"
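The `TestRoutingRule` cases pin down exact matching, `*` wildcards, and context conditions (a rule only fires when every condition key is present in the context with an equal value). A minimal sketch of a rule satisfying those cases — hypothetical, the real class lives in `orchestrator/task_router.py` — can delegate the glob matching to the standard library's `fnmatch`:

```python
from dataclasses import dataclass, field
from fnmatch import fnmatchcase


@dataclass
class RoutingRule:
    """Hypothetical minimal rule shape implied by the tests above."""
    intent_pattern: str
    agent_type: str
    priority: int = 0
    conditions: dict = field(default_factory=dict)

    def matches(self, intent: str, context: dict) -> bool:
        # Glob-style match on the intent first ("learning_*" matches
        # "learning_math"); fnmatchcase avoids platform-dependent
        # case folding.
        if not fnmatchcase(intent, self.intent_pattern):
            return False
        # Every condition must be present in the context with an equal
        # value; a missing key yields None and therefore fails.
        return all(context.get(k) == v for k, v in self.conditions.items())
```

Reusing `fnmatch` keeps the pattern syntax familiar (shell globs) without pulling in a regex dialect, though it cannot express alternation; rules needing that would require a different pattern engine.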