Services: Admin-Lehrer, Backend-Lehrer, Studio v2, Website, Klausur-Service, School-Service, Voice-Service, Geo-Service, BreakPilot Drive, Agent-Core
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
271 lines
8.7 KiB
Python
271 lines
8.7 KiB
Python
"""
|
|
Tests for Session Management
|
|
|
|
Tests cover:
|
|
- Session creation and retrieval
|
|
- State transitions
|
|
- Checkpoint management
|
|
- Heartbeat integration
|
|
"""
|
|
|
|
import pytest
|
|
import asyncio
|
|
from datetime import datetime, timezone, timedelta
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import sys
|
|
sys.path.insert(0, str(__file__).rsplit('/tests/', 1)[0])
|
|
|
|
from sessions.session_manager import (
|
|
AgentSession,
|
|
SessionManager,
|
|
SessionState,
|
|
SessionCheckpoint,
|
|
)
|
|
|
|
|
|
class TestAgentSession:
    """Unit tests exercising the AgentSession dataclass."""

    def test_create_session_defaults(self):
        """A session built without arguments exposes the expected defaults."""
        fresh = AgentSession()

        assert fresh.session_id is not None
        assert fresh.agent_type == ""
        assert fresh.state == SessionState.ACTIVE
        assert fresh.checkpoints == []
        assert fresh.context == {}

    def test_create_session_with_values(self):
        """Constructor keyword arguments are stored on the session."""
        tutor = AgentSession(
            agent_type="tutor-agent",
            user_id="user-123",
            context={"subject": "math"},
        )

        assert tutor.agent_type == "tutor-agent"
        assert tutor.user_id == "user-123"
        assert tutor.context["subject"] == "math"

    def test_checkpoint_creation(self):
        """checkpoint() appends to the session and returns the new entry."""
        agent_session = AgentSession()
        created = agent_session.checkpoint("task_received", {"task_id": "123"})

        assert len(agent_session.checkpoints) == 1
        assert created.name == "task_received"
        assert created.data["task_id"] == "123"
        assert created.timestamp is not None

    def test_heartbeat_updates_timestamp(self):
        """heartbeat() moves last_heartbeat forward in time."""
        agent_session = AgentSession()
        before = agent_session.last_heartbeat

        # Sleep briefly so the two timestamps cannot compare equal.
        import time
        time.sleep(0.01)

        agent_session.heartbeat()

        assert agent_session.last_heartbeat > before

    def test_pause_and_resume(self):
        """pause() and resume() flip the state and each add one checkpoint."""
        agent_session = AgentSession()

        # (transition, state expected afterwards, total checkpoints afterwards)
        transitions = (
            (agent_session.pause, SessionState.PAUSED, 1),
            (agent_session.resume, SessionState.ACTIVE, 2),
        )
        for transition, expected_state, expected_total in transitions:
            transition()
            assert agent_session.state == expected_state
            assert len(agent_session.checkpoints) == expected_total

    def test_complete_session(self):
        """complete() sets COMPLETED and records the result in a checkpoint."""
        agent_session = AgentSession()
        agent_session.complete({"output": "success"})

        assert agent_session.state == SessionState.COMPLETED
        final = agent_session.get_last_checkpoint()
        assert final.name == "session_completed"
        assert final.data["result"]["output"] == "success"

    def test_fail_session(self):
        """fail() sets FAILED and records error and details in a checkpoint."""
        agent_session = AgentSession()
        agent_session.fail("Connection timeout", {"code": 504})

        assert agent_session.state == SessionState.FAILED
        final = agent_session.get_last_checkpoint()
        assert final.name == "session_failed"
        assert final.data["error"] == "Connection timeout"
        assert final.data["details"]["code"] == 504

    def test_get_last_checkpoint_by_name(self):
        """get_last_checkpoint(name) returns the newest checkpoint so named."""
        agent_session = AgentSession()
        for step_name, payload in (("step_1", 1), ("step_2", 2), ("step_1", 3)):
            agent_session.checkpoint(step_name, {"data": payload})

        newest_step_1 = agent_session.get_last_checkpoint("step_1")
        assert newest_step_1.data["data"] == 3

        newest_step_2 = agent_session.get_last_checkpoint("step_2")
        assert newest_step_2.data["data"] == 2

    def test_get_duration(self):
        """get_duration() on a brand-new session is non-negative and under 1s."""
        elapsed = AgentSession().get_duration()

        assert elapsed.total_seconds() >= 0
        assert elapsed.total_seconds() < 1  # a fresh session has barely run

    def test_serialization(self):
        """A to_dict() / from_dict() round trip keeps the identifying fields."""
        original = AgentSession(
            agent_type="grader-agent",
            user_id="user-456",
            context={"exam_id": "exam-1"},
        )
        original.checkpoint("grading_started", {"questions": 5})

        # Round trip through the dict representation.
        restored = AgentSession.from_dict(original.to_dict())

        for field_name in ("session_id", "agent_type", "user_id", "context"):
            assert getattr(restored, field_name) == getattr(original, field_name)
        assert len(restored.checkpoints) == 1
|
|
|
|
|
|
class TestSessionManager:
    """Unit tests exercising SessionManager with no Redis or DB backend."""

    @pytest.fixture
    def manager(self):
        """Provide a SessionManager constructed without persistence clients."""
        return SessionManager(redis_client=None, db_pool=None, namespace="test")

    @pytest.mark.asyncio
    async def test_create_session(self, manager):
        """create_session() returns a session carrying the given values."""
        created = await manager.create_session(
            agent_type="tutor-agent",
            user_id="user-789",
            context={"grade": 10},
        )

        assert created.agent_type == "tutor-agent"
        assert created.user_id == "user-789"
        assert created.context["grade"] == 10
        assert len(created.checkpoints) == 1  # the session_created checkpoint

    @pytest.mark.asyncio
    async def test_get_session_from_cache(self, manager):
        """A session just created can be fetched again by its id."""
        created = await manager.create_session(agent_type="grader-agent")

        fetched = await manager.get_session(created.session_id)

        assert fetched is not None
        assert fetched.session_id == created.session_id

    @pytest.mark.asyncio
    async def test_get_nonexistent_session(self, manager):
        """Looking up an unknown id yields None."""
        missing = await manager.get_session("nonexistent-id")
        assert missing is None

    @pytest.mark.asyncio
    async def test_update_session(self, manager):
        """Changes passed to update_session() are visible on the next get."""
        alert_session = await manager.create_session(agent_type="alert-agent")

        alert_session.context["alert_count"] = 5
        await manager.update_session(alert_session)

        reloaded = await manager.get_session(alert_session.session_id)
        assert reloaded.context["alert_count"] == 5

    @pytest.mark.asyncio
    async def test_delete_session(self, manager):
        """delete_session() returns True and the session is gone afterwards."""
        doomed = await manager.create_session(agent_type="test-agent")

        was_deleted = await manager.delete_session(doomed.session_id)
        assert was_deleted is True

        after_delete = await manager.get_session(doomed.session_id)
        assert after_delete is None

    @pytest.mark.asyncio
    async def test_get_active_sessions(self, manager):
        """get_active_sessions(agent_type=...) only returns matching sessions."""
        for kind in ("tutor-agent", "tutor-agent", "grader-agent"):
            await manager.create_session(agent_type=kind)

        tutors = await manager.get_active_sessions(agent_type="tutor-agent")

        assert len(tutors) == 2

    @pytest.mark.asyncio
    async def test_cleanup_stale_sessions(self, manager):
        """A session whose heartbeat exceeds max_age is counted and FAILED."""
        stale = await manager.create_session(agent_type="test-agent")
        # Backdate the heartbeat past the 48h limit used below.
        stale.last_heartbeat = datetime.now(timezone.utc) - timedelta(hours=50)
        manager._local_cache[stale.session_id] = stale

        cleaned = await manager.cleanup_stale_sessions(max_age=timedelta(hours=48))

        assert cleaned == 1
        assert stale.state == SessionState.FAILED
|
|
|
|
|
|
class TestSessionCheckpoint:
    """Unit tests exercising SessionCheckpoint."""

    def test_checkpoint_creation(self):
        """Constructor arguments are exposed as attributes."""
        created = SessionCheckpoint(
            name="test_checkpoint",
            timestamp=datetime.now(timezone.utc),
            data={"key": "value"},
        )

        assert created.name == "test_checkpoint"
        assert created.data["key"] == "value"

    def test_checkpoint_serialization(self):
        """A to_dict() / from_dict() round trip keeps name and data."""
        original = SessionCheckpoint(
            name="test",
            timestamp=datetime.now(timezone.utc),
            data={"count": 42},
        )

        as_dict = original.to_dict()
        rebuilt = SessionCheckpoint.from_dict(as_dict)

        assert rebuilt.name == original.name
        assert rebuilt.data == original.data
|