"""
Session Management for Breakpilot Agents

Provides session lifecycle management with:
- Hybrid Valkey + PostgreSQL persistence
- Session CRUD operations
- Stale session cleanup
"""

from datetime import datetime, timezone, timedelta
from typing import Dict, Any, Optional, List
import json
import logging

from agent_core.sessions.session_models import (
    SessionState,
    SessionCheckpoint,
    AgentSession,
)

logger = logging.getLogger(__name__)


class SessionManager:
    """
    Manages agent sessions with hybrid Valkey + PostgreSQL persistence.

    - Local cache: in-process dict, consulted first on every read
    - Valkey: Fast access for active sessions (TTL = session_ttl_hours,
      default 24h)
    - PostgreSQL: Persistent storage for audit trail and recovery

    NOTE(review): ``metadata`` given to ``create_session`` is not written to
    or read from PostgreSQL by this class, so it does not survive a
    PostgreSQL-only reload — confirm whether the table should store it.
    NOTE(review): ``_local_cache`` is only pruned by ``delete_session`` and by
    the PostgreSQL branch of ``cleanup_stale_sessions``; in a long-lived
    process it can grow without bound.
    """

    def __init__(
        self,
        redis_client=None,
        db_pool=None,
        namespace: str = "breakpilot",
        session_ttl_hours: int = 24
    ):
        """
        Initialize the session manager.

        Args:
            redis_client: Async Redis/Valkey client (optional)
            db_pool: Async PostgreSQL connection pool (optional)
            namespace: Redis key namespace
            session_ttl_hours: Session TTL in Valkey
        """
        self.redis = redis_client
        self.db_pool = db_pool
        self.namespace = namespace
        self.session_ttl = timedelta(hours=session_ttl_hours)
        self._local_cache: Dict[str, AgentSession] = {}

    def _redis_key(self, session_id: str) -> str:
        """Generate Redis key for session"""
        return f"{self.namespace}:session:{session_id}"

    async def create_session(
        self,
        agent_type: str,
        user_id: str = "",
        context: Optional[Dict[str, Any]] = None,
        metadata: Optional[Dict[str, Any]] = None
    ) -> AgentSession:
        """
        Creates a new agent session.

        Records a "session_created" checkpoint, caches the session locally
        and persists it to every configured storage layer.

        Args:
            agent_type: Type of agent
            user_id: Associated user ID
            context: Initial session context
            metadata: Additional metadata

        Returns:
            The created AgentSession
        """
        session = AgentSession(
            agent_type=agent_type,
            user_id=user_id,
            context=context or {},
            metadata=metadata or {}
        )

        session.checkpoint("session_created", {
            "agent_type": agent_type,
            "user_id": user_id
        })

        self._local_cache[session.session_id] = session
        await self._persist_session(session)

        logger.info(
            "Created session %s for agent '%s'", session.session_id, agent_type
        )

        return session

    async def get_session(self, session_id: str) -> Optional[AgentSession]:
        """
        Retrieves a session by ID.

        Lookup order:
        1. Local cache
        2. Valkey
        3. PostgreSQL (soft-deleted rows are excluded)

        Hits in a slower layer are written back to the faster ones.

        Args:
            session_id: Session ID to retrieve

        Returns:
            AgentSession if found, None otherwise
        """
        # Check local cache first
        if session_id in self._local_cache:
            return self._local_cache[session_id]

        # Try Valkey
        session = await self._get_from_valkey(session_id)
        if session:
            self._local_cache[session_id] = session
            return session

        # Fall back to PostgreSQL
        session = await self._get_from_postgres(session_id)
        if session:
            # Re-cache in Valkey for faster access
            await self._cache_in_valkey(session)
            self._local_cache[session_id] = session
            return session

        return None

    async def update_session(self, session: AgentSession) -> None:
        """
        Updates a session in all storage layers.

        Refreshes the session heartbeat before persisting.

        Args:
            session: The session to update
        """
        session.heartbeat()
        self._local_cache[session.session_id] = session
        await self._persist_session(session)

    async def delete_session(self, session_id: str) -> bool:
        """
        Deletes a session (soft delete in PostgreSQL).

        The session is removed from the local cache and Valkey, and its
        PostgreSQL row is marked ``state = 'deleted'``.

        Args:
            session_id: Session ID to delete

        Returns:
            True if deleted, False if not found. With PostgreSQL configured
            the row is authoritative: a missing or already-deleted row yields
            False. Without PostgreSQL, True means the session was present in
            the local cache or in Valkey.
        """
        # Remove from local cache
        found = self._local_cache.pop(session_id, None) is not None

        # Remove from Valkey; DEL reports how many keys it removed.
        if self.redis:
            removed = await self.redis.delete(self._redis_key(session_id))
            found = found or bool(removed)

        # Soft delete in PostgreSQL. Skip rows that are already deleted so a
        # repeated delete reports "not found", matching _get_from_postgres.
        if self.db_pool:
            async with self.db_pool.acquire() as conn:
                result = await conn.execute(
                    """
                    UPDATE agent_sessions
                    SET state = 'deleted', updated_at = NOW()
                    WHERE id = $1 AND state != 'deleted'
                    """,
                    session_id
                )
            return result == "UPDATE 1"

        return found

    async def get_active_sessions(
        self,
        agent_type: Optional[str] = None,
        user_id: Optional[str] = None
    ) -> List[AgentSession]:
        """
        Gets all active sessions, optionally filtered.

        Uses PostgreSQL when configured, otherwise the local cache.
        Empty-string filters are treated as "no filter".

        Args:
            agent_type: Filter by agent type
            user_id: Filter by user ID

        Returns:
            List of active sessions
        """
        if not self.db_pool:
            # Fall back to local cache
            return [
                s for s in self._local_cache.values()
                if s.state == SessionState.ACTIVE
                and (not agent_type or s.agent_type == agent_type)
                and (not user_id or s.user_id == user_id)
            ]

        query = """
            SELECT id, agent_type, user_id, state, context, checkpoints,
                   created_at, updated_at, last_heartbeat
            FROM agent_sessions
            WHERE state = 'active'
        """
        params: List[Any] = []
        # Only placeholder numbers are interpolated; values stay parameterized.
        if agent_type:
            params.append(agent_type)
            query += f" AND agent_type = ${len(params)}"
        if user_id:
            params.append(user_id)
            query += f" AND user_id = ${len(params)}"

        async with self.db_pool.acquire() as conn:
            rows = await conn.fetch(query, *params)
        return [self._row_to_session(row) for row in rows]

    async def cleanup_stale_sessions(
        self,
        max_age: timedelta = timedelta(hours=48)
    ) -> int:
        """
        Cleans up stale sessions (no heartbeat for max_age).

        Active sessions whose last heartbeat is older than ``max_age`` are
        marked failed.

        Args:
            max_age: Maximum age for stale sessions

        Returns:
            Number of sessions cleaned up
        """
        cutoff = datetime.now(timezone.utc) - max_age

        if not self.db_pool:
            # Without PostgreSQL the local cache is the source of truth.
            # NOTE(review): assumes last_heartbeat is timezone-aware; a naive
            # datetime would raise TypeError when compared with cutoff.
            stale = [
                s for s in self._local_cache.values()
                if s.last_heartbeat < cutoff and s.state == SessionState.ACTIVE
            ]
            for session in stale:
                session.fail("Session timeout - no heartbeat")
                # Propagate the failed state so Valkey stops serving the old
                # ACTIVE copy.
                await self._persist_session(session)
            return len(stale)

        async with self.db_pool.acquire() as conn:
            rows = await conn.fetch(
                """
                UPDATE agent_sessions
                SET state = 'failed', updated_at = NOW(),
                    context = context || '{"failure_reason": "heartbeat_timeout"}'::jsonb
                WHERE state = 'active' AND last_heartbeat < $1
                RETURNING id
                """,
                cutoff
            )

        # PostgreSQL now holds the authoritative failed state. Drop the stale
        # ACTIVE copies from the faster layers so the next get_session()
        # reloads it instead of serving, and later re-persisting, the outdated
        # one. Other processes' local caches cannot be reached from here.
        for row in rows:
            stale_id = str(row["id"])
            self._local_cache.pop(stale_id, None)
            await self._evict_from_valkey(stale_id)

        count = len(rows)
        logger.info("Cleaned up %d stale sessions", count)
        return count

    async def _persist_session(self, session: AgentSession) -> None:
        """Persists session to both Valkey and PostgreSQL"""
        await self._cache_in_valkey(session)
        await self._save_to_postgres(session)

    async def _cache_in_valkey(self, session: AgentSession) -> None:
        """Caches session in Valkey (best effort: failures are logged)"""
        if not self.redis:
            return

        try:
            await self.redis.setex(
                self._redis_key(session.session_id),
                int(self.session_ttl.total_seconds()),
                json.dumps(session.to_dict())
            )
        except Exception as e:
            logger.warning("Failed to cache session in Valkey: %s", e)

    async def _evict_from_valkey(self, session_id: str) -> None:
        """Removes a session key from Valkey (best effort: failures are logged)"""
        if not self.redis:
            return

        try:
            await self.redis.delete(self._redis_key(session_id))
        except Exception as e:
            logger.warning("Failed to evict session from Valkey: %s", e)

    async def _get_from_valkey(self, session_id: str) -> Optional[AgentSession]:
        """Retrieves session from Valkey; returns None on miss or error"""
        if not self.redis:
            return None

        try:
            data = await self.redis.get(self._redis_key(session_id))
            if data:
                return AgentSession.from_dict(json.loads(data))
        except Exception as e:
            logger.warning("Failed to get session from Valkey: %s", e)

        return None

    async def _save_to_postgres(self, session: AgentSession) -> None:
        """Upserts session into PostgreSQL (failures are logged, not raised)"""
        if not self.db_pool:
            return

        try:
            async with self.db_pool.acquire() as conn:
                await conn.execute(
                    """
                    INSERT INTO agent_sessions
                    (id, agent_type, user_id, state, context, checkpoints,
                     created_at, updated_at, last_heartbeat)
                    VALUES ($1, $2, $3, $4, $5, $6, $7, NOW(), $8)
                    ON CONFLICT (id) DO UPDATE SET
                        state = EXCLUDED.state,
                        context = EXCLUDED.context,
                        checkpoints = EXCLUDED.checkpoints,
                        updated_at = NOW(),
                        last_heartbeat = EXCLUDED.last_heartbeat
                    """,
                    session.session_id,
                    session.agent_type,
                    session.user_id,
                    session.state.value,
                    json.dumps(session.context),
                    json.dumps([cp.to_dict() for cp in session.checkpoints]),
                    session.created_at,
                    session.last_heartbeat
                )
        except Exception as e:
            logger.error("Failed to save session to PostgreSQL: %s", e)

    async def _get_from_postgres(self, session_id: str) -> Optional[AgentSession]:
        """Retrieves a non-deleted session from PostgreSQL; None on miss or error"""
        if not self.db_pool:
            return None

        try:
            async with self.db_pool.acquire() as conn:
                row = await conn.fetchrow(
                    """
                    SELECT id, agent_type, user_id, state, context, checkpoints,
                           created_at, updated_at, last_heartbeat
                    FROM agent_sessions
                    WHERE id = $1 AND state != 'deleted'
                    """,
                    session_id
                )

                if row:
                    return self._row_to_session(row)
        except Exception as e:
            logger.error("Failed to get session from PostgreSQL: %s", e)

        return None

    @staticmethod
    def _load_json_column(value: Any, default: Any) -> Any:
        """Decodes a JSON column that may be text, already decoded, or NULL."""
        if isinstance(value, str):
            value = json.loads(value)
        return default if value is None else value

    def _row_to_session(self, row) -> AgentSession:
        """
        Converts a database row to AgentSession.

        The JSON columns may arrive as text or already decoded. NULL values
        fall back to an empty context / checkpoint list instead of crashing.
        """
        checkpoints_data = self._load_json_column(row["checkpoints"], [])
        context_data = self._load_json_column(row["context"], {})

        return AgentSession(
            session_id=str(row["id"]),
            agent_type=row["agent_type"],
            user_id=row["user_id"] or "",
            state=SessionState(row["state"]),
            created_at=row["created_at"],
            last_heartbeat=row["last_heartbeat"],
            context=context_data,
            checkpoints=[
                SessionCheckpoint.from_dict(cp) for cp in checkpoints_data
            ]
        )