""" Session Management for Breakpilot Agents Provides session lifecycle management with: - State tracking (ACTIVE, PAUSED, COMPLETED, FAILED) - Checkpoint-based recovery - Heartbeat integration - Hybrid Valkey + PostgreSQL persistence """ from dataclasses import dataclass, field from datetime import datetime, timezone, timedelta from typing import Dict, Any, Optional, List from enum import Enum import uuid import json import logging logger = logging.getLogger(__name__) class SessionState(Enum): """Agent session states""" ACTIVE = "active" PAUSED = "paused" COMPLETED = "completed" FAILED = "failed" @dataclass class SessionCheckpoint: """Represents a checkpoint in an agent session""" name: str timestamp: datetime data: Dict[str, Any] def to_dict(self) -> Dict[str, Any]: return { "name": self.name, "timestamp": self.timestamp.isoformat(), "data": self.data } @classmethod def from_dict(cls, data: Dict[str, Any]) -> "SessionCheckpoint": return cls( name=data["name"], timestamp=datetime.fromisoformat(data["timestamp"]), data=data["data"] ) @dataclass class AgentSession: """ Represents an active agent session. Attributes: session_id: Unique session identifier agent_type: Type of agent (tutor, grader, quality-judge, alert, orchestrator) user_id: Associated user ID state: Current session state created_at: Session creation timestamp last_heartbeat: Last heartbeat timestamp context: Session context data checkpoints: List of session checkpoints for recovery """ session_id: str = field(default_factory=lambda: str(uuid.uuid4())) agent_type: str = "" user_id: str = "" state: SessionState = SessionState.ACTIVE created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) last_heartbeat: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) context: Dict[str, Any] = field(default_factory=dict) checkpoints: List[SessionCheckpoint] = field(default_factory=list) metadata: Dict[str, Any] = field(default_factory=dict) def checkpoint(self, name: str, data: Dict[str, Any]) -> SessionCheckpoint: """ Creates a checkpoint for recovery. Args: name: Checkpoint name (e.g., "task_received", "processing_complete") data: Checkpoint data to store Returns: The created checkpoint """ checkpoint = SessionCheckpoint( name=name, timestamp=datetime.now(timezone.utc), data=data ) self.checkpoints.append(checkpoint) logger.debug(f"Session {self.session_id}: Checkpoint '{name}' created") return checkpoint def heartbeat(self) -> None: """Updates the heartbeat timestamp""" self.last_heartbeat = datetime.now(timezone.utc) def pause(self) -> None: """Pauses the session""" self.state = SessionState.PAUSED self.checkpoint("session_paused", {"previous_state": "active"}) def resume(self) -> None: """Resumes a paused session""" if self.state == SessionState.PAUSED: self.state = SessionState.ACTIVE self.heartbeat() self.checkpoint("session_resumed", {}) def complete(self, result: Optional[Dict[str, Any]] = None) -> None: """Marks the session as completed""" self.state = SessionState.COMPLETED self.checkpoint("session_completed", {"result": result or {}}) def fail(self, error: str, error_details: Optional[Dict[str, Any]] = None) -> None: """Marks the session as failed""" self.state = SessionState.FAILED self.checkpoint("session_failed", { "error": error, "details": error_details or {} }) def get_last_checkpoint(self, name: Optional[str] = None) -> Optional[SessionCheckpoint]: """ Gets the last checkpoint, optionally filtered by name. Args: name: Optional checkpoint name to filter by Returns: The last matching checkpoint or None """ if not self.checkpoints: return None if name: matching = [cp for cp in self.checkpoints if cp.name == name] return matching[-1] if matching else None return self.checkpoints[-1] def get_duration(self) -> timedelta: """Returns the session duration""" end_time = datetime.now(timezone.utc) if self.state in (SessionState.COMPLETED, SessionState.FAILED): last_cp = self.get_last_checkpoint() if last_cp: end_time = last_cp.timestamp return end_time - self.created_at def to_dict(self) -> Dict[str, Any]: """Serializes the session to a dictionary""" return { "session_id": self.session_id, "agent_type": self.agent_type, "user_id": self.user_id, "state": self.state.value, "created_at": self.created_at.isoformat(), "last_heartbeat": self.last_heartbeat.isoformat(), "context": self.context, "checkpoints": [cp.to_dict() for cp in self.checkpoints], "metadata": self.metadata } @classmethod def from_dict(cls, data: Dict[str, Any]) -> "AgentSession": """Deserializes a session from a dictionary""" return cls( session_id=data["session_id"], agent_type=data["agent_type"], user_id=data["user_id"], state=SessionState(data["state"]), created_at=datetime.fromisoformat(data["created_at"]), last_heartbeat=datetime.fromisoformat(data["last_heartbeat"]), context=data.get("context", {}), checkpoints=[ SessionCheckpoint.from_dict(cp) for cp in data.get("checkpoints", []) ], metadata=data.get("metadata", {}) ) class SessionManager: """ Manages agent sessions with hybrid Valkey + PostgreSQL persistence. - Valkey: Fast access for active sessions (24h TTL) - PostgreSQL: Persistent storage for audit trail and recovery """ def __init__( self, redis_client=None, db_pool=None, namespace: str = "breakpilot", session_ttl_hours: int = 24 ): """ Initialize the session manager. Args: redis_client: Async Redis/Valkey client db_pool: Async PostgreSQL connection pool namespace: Redis key namespace session_ttl_hours: Session TTL in Valkey """ self.redis = redis_client self.db_pool = db_pool self.namespace = namespace self.session_ttl = timedelta(hours=session_ttl_hours) self._local_cache: Dict[str, AgentSession] = {} def _redis_key(self, session_id: str) -> str: """Generate Redis key for session""" return f"{self.namespace}:session:{session_id}" async def create_session( self, agent_type: str, user_id: str = "", context: Optional[Dict[str, Any]] = None, metadata: Optional[Dict[str, Any]] = None ) -> AgentSession: """ Creates a new agent session. Args: agent_type: Type of agent user_id: Associated user ID context: Initial session context metadata: Additional metadata Returns: The created AgentSession """ session = AgentSession( agent_type=agent_type, user_id=user_id, context=context or {}, metadata=metadata or {} ) session.checkpoint("session_created", { "agent_type": agent_type, "user_id": user_id }) self._local_cache[session.session_id] = session await self._persist_session(session) logger.info( f"Created session {session.session_id} for agent '{agent_type}'" ) return session async def get_session(self, session_id: str) -> Optional[AgentSession]: """ Retrieves a session by ID. Lookup order: 1. Local cache 2. Valkey 3. PostgreSQL Args: session_id: Session ID to retrieve Returns: AgentSession if found, None otherwise """ # Check local cache first if session_id in self._local_cache: return self._local_cache[session_id] # Try Valkey session = await self._get_from_valkey(session_id) if session: self._local_cache[session_id] = session return session # Fall back to PostgreSQL session = await self._get_from_postgres(session_id) if session: # Re-cache in Valkey for faster access await self._cache_in_valkey(session) self._local_cache[session_id] = session return session return None async def update_session(self, session: AgentSession) -> None: """ Updates a session in all storage layers. Args: session: The session to update """ session.heartbeat() self._local_cache[session.session_id] = session self._local_cache[session.session_id] = session await self._persist_session(session) async def delete_session(self, session_id: str) -> bool: """ Deletes a session (soft delete in PostgreSQL). Args: session_id: Session ID to delete Returns: True if deleted, False if not found """ # Remove from local cache self._local_cache.pop(session_id, None) # Remove from Valkey if self.redis: await self.redis.delete(self._redis_key(session_id)) # Soft delete in PostgreSQL if self.db_pool: async with self.db_pool.acquire() as conn: result = await conn.execute( """ UPDATE agent_sessions SET state = 'deleted', updated_at = NOW() WHERE id = $1 """, session_id ) return result == "UPDATE 1" return True async def get_active_sessions( self, agent_type: Optional[str] = None, user_id: Optional[str] = None ) -> List[AgentSession]: """ Gets all active sessions, optionally filtered. Args: agent_type: Filter by agent type user_id: Filter by user ID Returns: List of active sessions """ if not self.db_pool: # Fall back to local cache sessions = [ s for s in self._local_cache.values() if s.state == SessionState.ACTIVE ] if agent_type: sessions = [s for s in sessions if s.agent_type == agent_type] if user_id: sessions = [s for s in sessions if s.user_id == user_id] return sessions async with self.db_pool.acquire() as conn: query = """ SELECT id, agent_type, user_id, state, context, checkpoints, created_at, updated_at, last_heartbeat FROM agent_sessions WHERE state = 'active' """ params = [] if agent_type: query += " AND agent_type = $1" params.append(agent_type) if user_id: param_num = len(params) + 1 query += f" AND user_id = ${param_num}" params.append(user_id) rows = await conn.fetch(query, *params) return [self._row_to_session(row) for row in rows] async def cleanup_stale_sessions( self, max_age: timedelta = timedelta(hours=48) ) -> int: """ Cleans up stale sessions (no heartbeat for max_age). Args: max_age: Maximum age for stale sessions Returns: Number of sessions cleaned up """ cutoff = datetime.now(timezone.utc) - max_age if not self.db_pool: # Clean local cache stale = [ sid for sid, s in self._local_cache.items() if s.last_heartbeat < cutoff and s.state == SessionState.ACTIVE ] for sid in stale: session = self._local_cache[sid] session.fail("Session timeout - no heartbeat") return len(stale) async with self.db_pool.acquire() as conn: result = await conn.execute( """ UPDATE agent_sessions SET state = 'failed', updated_at = NOW(), context = context || '{"failure_reason": "heartbeat_timeout"}'::jsonb WHERE state = 'active' AND last_heartbeat < $1 """, cutoff ) count = int(result.split()[-1]) if result else 0 logger.info(f"Cleaned up {count} stale sessions") return count async def _persist_session(self, session: AgentSession) -> None: """Persists session to both Valkey and PostgreSQL""" await self._cache_in_valkey(session) await self._save_to_postgres(session) async def _cache_in_valkey(self, session: AgentSession) -> None: """Caches session in Valkey""" if not self.redis: return try: await self.redis.setex( self._redis_key(session.session_id), int(self.session_ttl.total_seconds()), json.dumps(session.to_dict()) ) except Exception as e: logger.warning(f"Failed to cache session in Valkey: {e}") async def _get_from_valkey(self, session_id: str) -> Optional[AgentSession]: """Retrieves session from Valkey""" if not self.redis: return None try: data = await self.redis.get(self._redis_key(session_id)) if data: return AgentSession.from_dict(json.loads(data)) except Exception as e: logger.warning(f"Failed to get session from Valkey: {e}") return None async def _save_to_postgres(self, session: AgentSession) -> None: """Saves session to PostgreSQL""" if not self.db_pool: return try: async with self.db_pool.acquire() as conn: await conn.execute( """ INSERT INTO agent_sessions (id, agent_type, user_id, state, context, checkpoints, created_at, updated_at, last_heartbeat) VALUES ($1, $2, $3, $4, $5, $6, $7, NOW(), $8) ON CONFLICT (id) DO UPDATE SET state = EXCLUDED.state, context = EXCLUDED.context, checkpoints = EXCLUDED.checkpoints, updated_at = NOW(), last_heartbeat = EXCLUDED.last_heartbeat """, session.session_id, session.agent_type, session.user_id, session.state.value, json.dumps(session.context), json.dumps([cp.to_dict() for cp in session.checkpoints]), session.created_at, session.last_heartbeat ) except Exception as e: logger.error(f"Failed to save session to PostgreSQL: {e}") async def _get_from_postgres(self, session_id: str) -> Optional[AgentSession]: """Retrieves session from PostgreSQL""" if not self.db_pool: return None try: async with self.db_pool.acquire() as conn: row = await conn.fetchrow( """ SELECT id, agent_type, user_id, state, context, checkpoints, created_at, updated_at, last_heartbeat FROM agent_sessions WHERE id = $1 AND state != 'deleted' """, session_id ) if row: return self._row_to_session(row) except Exception as e: logger.error(f"Failed to get session from PostgreSQL: {e}") return None def _row_to_session(self, row) -> AgentSession: """Converts a database row to AgentSession""" checkpoints_data = row["checkpoints"] if isinstance(checkpoints_data, str): checkpoints_data = json.loads(checkpoints_data) context_data = row["context"] if isinstance(context_data, str): context_data = json.loads(context_data) return AgentSession( session_id=str(row["id"]), agent_type=row["agent_type"], user_id=row["user_id"] or "", state=SessionState(row["state"]), created_at=row["created_at"], last_heartbeat=row["last_heartbeat"], context=context_data, checkpoints=[ SessionCheckpoint.from_dict(cp) for cp in checkpoints_data ] )