"""
Tests for Session Management

Tests cover:
- Session creation and retrieval
- State transitions
- Checkpoint management
- Heartbeat integration
"""

import asyncio
import sys
from datetime import datetime, timedelta, timezone
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch

import pytest


def _project_root(test_file: str) -> str:
    """Return the parent of the nearest enclosing ``tests`` directory.

    Equivalent to splitting the path on its last ``/tests/`` segment, but
    separator-agnostic (works on Windows) and tolerant of a relative
    ``__file__``. Falls back to the test file's own directory when no
    ``tests`` ancestor exists.
    """
    here = Path(test_file).absolute()
    # ``parents`` is ordered nearest-first, so the first match is the
    # innermost ``tests`` directory (same as ``rsplit(..., 1)``).
    for parent in here.parents:
        if parent.name == "tests":
            return str(parent.parent)
    return str(here.parent)


# Make the project package importable when running this file directly.
sys.path.insert(0, _project_root(__file__))

from sessions.session_manager import (  # noqa: E402  (needs sys.path above)
    AgentSession,
    SessionManager,
    SessionState,
    SessionCheckpoint,
)


class TestAgentSession:
    """Tests for AgentSession dataclass"""

    def test_create_session_defaults(self):
        """Session should have default values"""
        session = AgentSession()

        assert session.session_id is not None
        assert session.agent_type == ""
        assert session.state == SessionState.ACTIVE
        assert session.checkpoints == []
        assert session.context == {}

    def test_create_session_with_values(self):
        """Session should accept custom values"""
        session = AgentSession(
            agent_type="tutor-agent",
            user_id="user-123",
            context={"subject": "math"},
        )

        assert session.agent_type == "tutor-agent"
        assert session.user_id == "user-123"
        assert session.context["subject"] == "math"

    def test_checkpoint_creation(self):
        """Session should create checkpoints correctly"""
        session = AgentSession()

        checkpoint = session.checkpoint("task_received", {"task_id": "123"})

        assert len(session.checkpoints) == 1
        assert checkpoint.name == "task_received"
        assert checkpoint.data["task_id"] == "123"
        assert checkpoint.timestamp is not None

    def test_heartbeat_updates_timestamp(self):
        """Heartbeat should update last_heartbeat"""
        session = AgentSession()
        # Backdate the heartbeat instead of sleeping: a 10 ms sleep can be
        # below the wall-clock resolution on some platforms, which made the
        # strict ``>`` comparison flaky. A tz-aware value matches what
        # test_cleanup_stale_sessions assigns to this attribute.
        stale = datetime.now(timezone.utc) - timedelta(seconds=1)
        session.last_heartbeat = stale

        session.heartbeat()

        assert session.last_heartbeat > stale

    def test_pause_and_resume(self):
        """Session should pause and resume correctly"""
        session = AgentSession()

        session.pause()
        assert session.state == SessionState.PAUSED
        assert len(session.checkpoints) == 1  # Pause creates checkpoint

        session.resume()
        assert session.state == SessionState.ACTIVE
        assert len(session.checkpoints) == 2  # Resume creates checkpoint

    def test_complete_session(self):
        """Session should complete with result"""
        session = AgentSession()

        session.complete({"output": "success"})

        assert session.state == SessionState.COMPLETED
        last_cp = session.get_last_checkpoint()
        assert last_cp.name == "session_completed"
        assert last_cp.data["result"]["output"] == "success"

    def test_fail_session(self):
        """Session should fail with error"""
        session = AgentSession()

        session.fail("Connection timeout", {"code": 504})

        assert session.state == SessionState.FAILED
        last_cp = session.get_last_checkpoint()
        assert last_cp.name == "session_failed"
        assert last_cp.data["error"] == "Connection timeout"
        assert last_cp.data["details"]["code"] == 504

    def test_get_last_checkpoint_by_name(self):
        """Should filter checkpoints by name"""
        session = AgentSession()
        session.checkpoint("step_1", {"data": 1})
        session.checkpoint("step_2", {"data": 2})
        session.checkpoint("step_1", {"data": 3})

        last_step_1 = session.get_last_checkpoint("step_1")
        assert last_step_1.data["data"] == 3

        last_step_2 = session.get_last_checkpoint("step_2")
        assert last_step_2.data["data"] == 2

    def test_get_duration(self):
        """Should calculate session duration"""
        session = AgentSession()

        duration = session.get_duration()

        assert duration.total_seconds() >= 0
        assert duration.total_seconds() < 1  # Should be very fast

    def test_serialization(self):
        """Session should serialize and deserialize correctly"""
        session = AgentSession(
            agent_type="grader-agent",
            user_id="user-456",
            context={"exam_id": "exam-1"},
        )
        session.checkpoint("grading_started", {"questions": 5})

        # Serialize
        data = session.to_dict()

        # Deserialize
        restored = AgentSession.from_dict(data)

        assert restored.session_id == session.session_id
        assert restored.agent_type == session.agent_type
        assert restored.user_id == session.user_id
        assert restored.context == session.context
        assert len(restored.checkpoints) == 1


class TestSessionManager:
    """Tests for SessionManager"""

    @pytest.fixture
    def manager(self):
        """Create a session manager without persistence"""
        return SessionManager(
            redis_client=None,
            db_pool=None,
            namespace="test",
        )

    @pytest.mark.asyncio
    async def test_create_session(self, manager):
        """Should create new sessions"""
        session = await manager.create_session(
            agent_type="tutor-agent",
            user_id="user-789",
            context={"grade": 10},
        )

        assert session.agent_type == "tutor-agent"
        assert session.user_id == "user-789"
        assert session.context["grade"] == 10
        assert len(session.checkpoints) == 1  # session_created

    @pytest.mark.asyncio
    async def test_get_session_from_cache(self, manager):
        """Should retrieve session from local cache"""
        created = await manager.create_session(agent_type="grader-agent")

        retrieved = await manager.get_session(created.session_id)

        assert retrieved is not None
        assert retrieved.session_id == created.session_id

    @pytest.mark.asyncio
    async def test_get_nonexistent_session(self, manager):
        """Should return None for nonexistent session"""
        result = await manager.get_session("nonexistent-id")
        assert result is None

    @pytest.mark.asyncio
    async def test_update_session(self, manager):
        """Should update session in cache"""
        session = await manager.create_session(agent_type="alert-agent")
        session.context["alert_count"] = 5

        await manager.update_session(session)

        retrieved = await manager.get_session(session.session_id)
        assert retrieved.context["alert_count"] == 5

    @pytest.mark.asyncio
    async def test_delete_session(self, manager):
        """Should delete session from cache"""
        session = await manager.create_session(agent_type="test-agent")

        result = await manager.delete_session(session.session_id)

        assert result is True
        retrieved = await manager.get_session(session.session_id)
        assert retrieved is None

    @pytest.mark.asyncio
    async def test_get_active_sessions(self, manager):
        """Should return active sessions filtered by type"""
        await manager.create_session(agent_type="tutor-agent")
        await manager.create_session(agent_type="tutor-agent")
        await manager.create_session(agent_type="grader-agent")

        tutor_sessions = await manager.get_active_sessions(
            agent_type="tutor-agent"
        )

        assert len(tutor_sessions) == 2

    @pytest.mark.asyncio
    async def test_cleanup_stale_sessions(self, manager):
        """Should mark stale sessions as failed"""
        # Create a session with old heartbeat
        session = await manager.create_session(agent_type="test-agent")
        session.last_heartbeat = datetime.now(timezone.utc) - timedelta(hours=50)
        manager._local_cache[session.session_id] = session

        count = await manager.cleanup_stale_sessions(max_age=timedelta(hours=48))

        assert count == 1
        assert session.state == SessionState.FAILED


class TestSessionCheckpoint:
    """Tests for SessionCheckpoint"""

    def test_checkpoint_creation(self):
        """Checkpoint should store data correctly"""
        checkpoint = SessionCheckpoint(
            name="test_checkpoint",
            timestamp=datetime.now(timezone.utc),
            data={"key": "value"},
        )

        assert checkpoint.name == "test_checkpoint"
        assert checkpoint.data["key"] == "value"

    def test_checkpoint_serialization(self):
        """Checkpoint should serialize correctly"""
        checkpoint = SessionCheckpoint(
            name="test",
            timestamp=datetime.now(timezone.utc),
            data={"count": 42},
        )

        data = checkpoint.to_dict()
        restored = SessionCheckpoint.from_dict(data)

        assert restored.name == checkpoint.name
        assert restored.data == checkpoint.data