""" Session Models for Breakpilot Agents Data classes for agent sessions, checkpoints, and state tracking. """ from dataclasses import dataclass, field from datetime import datetime, timezone, timedelta from typing import Dict, Any, Optional, List from enum import Enum import uuid import logging logger = logging.getLogger(__name__) class SessionState(Enum): """Agent session states""" ACTIVE = "active" PAUSED = "paused" COMPLETED = "completed" FAILED = "failed" @dataclass class SessionCheckpoint: """Represents a checkpoint in an agent session""" name: str timestamp: datetime data: Dict[str, Any] def to_dict(self) -> Dict[str, Any]: return { "name": self.name, "timestamp": self.timestamp.isoformat(), "data": self.data } @classmethod def from_dict(cls, data: Dict[str, Any]) -> "SessionCheckpoint": return cls( name=data["name"], timestamp=datetime.fromisoformat(data["timestamp"]), data=data["data"] ) @dataclass class AgentSession: """ Represents an active agent session. Attributes: session_id: Unique session identifier agent_type: Type of agent (tutor, grader, quality-judge, alert, orchestrator) user_id: Associated user ID state: Current session state created_at: Session creation timestamp last_heartbeat: Last heartbeat timestamp context: Session context data checkpoints: List of session checkpoints for recovery """ session_id: str = field(default_factory=lambda: str(uuid.uuid4())) agent_type: str = "" user_id: str = "" state: SessionState = SessionState.ACTIVE created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) last_heartbeat: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) context: Dict[str, Any] = field(default_factory=dict) checkpoints: List[SessionCheckpoint] = field(default_factory=list) metadata: Dict[str, Any] = field(default_factory=dict) def checkpoint(self, name: str, data: Dict[str, Any]) -> SessionCheckpoint: """ Creates a checkpoint for recovery. Args: name: Checkpoint name (e.g., "task_received", "processing_complete") data: Checkpoint data to store Returns: The created checkpoint """ checkpoint = SessionCheckpoint( name=name, timestamp=datetime.now(timezone.utc), data=data ) self.checkpoints.append(checkpoint) logger.debug(f"Session {self.session_id}: Checkpoint '{name}' created") return checkpoint def heartbeat(self) -> None: """Updates the heartbeat timestamp""" self.last_heartbeat = datetime.now(timezone.utc) def pause(self) -> None: """Pauses the session""" self.state = SessionState.PAUSED self.checkpoint("session_paused", {"previous_state": "active"}) def resume(self) -> None: """Resumes a paused session""" if self.state == SessionState.PAUSED: self.state = SessionState.ACTIVE self.heartbeat() self.checkpoint("session_resumed", {}) def complete(self, result: Optional[Dict[str, Any]] = None) -> None: """Marks the session as completed""" self.state = SessionState.COMPLETED self.checkpoint("session_completed", {"result": result or {}}) def fail(self, error: str, error_details: Optional[Dict[str, Any]] = None) -> None: """Marks the session as failed""" self.state = SessionState.FAILED self.checkpoint("session_failed", { "error": error, "details": error_details or {} }) def get_last_checkpoint(self, name: Optional[str] = None) -> Optional[SessionCheckpoint]: """ Gets the last checkpoint, optionally filtered by name. Args: name: Optional checkpoint name to filter by Returns: The last matching checkpoint or None """ if not self.checkpoints: return None if name: matching = [cp for cp in self.checkpoints if cp.name == name] return matching[-1] if matching else None return self.checkpoints[-1] def get_duration(self) -> timedelta: """Returns the session duration""" end_time = datetime.now(timezone.utc) if self.state in (SessionState.COMPLETED, SessionState.FAILED): last_cp = self.get_last_checkpoint() if last_cp: end_time = last_cp.timestamp return end_time - self.created_at def to_dict(self) -> Dict[str, Any]: """Serializes the session to a dictionary""" return { "session_id": self.session_id, "agent_type": self.agent_type, "user_id": self.user_id, "state": self.state.value, "created_at": self.created_at.isoformat(), "last_heartbeat": self.last_heartbeat.isoformat(), "context": self.context, "checkpoints": [cp.to_dict() for cp in self.checkpoints], "metadata": self.metadata } @classmethod def from_dict(cls, data: Dict[str, Any]) -> "AgentSession": """Deserializes a session from a dictionary""" return cls( session_id=data["session_id"], agent_type=data["agent_type"], user_id=data["user_id"], state=SessionState(data["state"]), created_at=datetime.fromisoformat(data["created_at"]), last_heartbeat=datetime.fromisoformat(data["last_heartbeat"]), context=data.get("context", {}), checkpoints=[ SessionCheckpoint.from_dict(cp) for cp in data.get("checkpoints", []) ], metadata=data.get("metadata", {}) )