"""
Checkpoint Management for Breakpilot Agents

Provides checkpoint-based recovery with:
- Named checkpoints for semantic recovery points
- Automatic checkpoint compression
- Recovery from specific checkpoints
- Checkpoint analytics
"""

from typing import Dict, Any, Optional, List, Callable
from datetime import datetime, timezone
from dataclasses import dataclass, field
import json
import logging

logger = logging.getLogger(__name__)


@dataclass
class Checkpoint:
    """Represents a recovery checkpoint.

    Attributes:
        id: Identifier; ``"<session_id>:<sequence>"`` when created by
            :class:`CheckpointManager`.
        name: Semantic name (e.g. ``"task_started"``); several checkpoints
            may share a name.
        timestamp: Creation time (UTC-aware when created by the manager).
        data: The payload needed to recover to this point.
        metadata: Optional extra information.
    """

    id: str
    name: str
    timestamp: datetime
    data: Dict[str, Any]
    metadata: Dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> Dict[str, Any]:
        """Return a plain dict with the timestamp as an ISO-8601 string.

        The result is JSON-serializable as long as ``data`` and ``metadata``
        are.
        """
        return {
            "id": self.id,
            "name": self.name,
            "timestamp": self.timestamp.isoformat(),
            "data": self.data,
            "metadata": self.metadata,
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "Checkpoint":
        """Rebuild a checkpoint from the output of :meth:`to_dict`.

        A missing or ``null`` ``"metadata"`` entry becomes an empty dict.

        Raises:
            KeyError: If ``id``, ``name``, ``timestamp`` or ``data`` is missing.
            ValueError: If ``timestamp`` is not a valid ISO-8601 string.
        """
        return cls(
            id=data["id"],
            name=data["name"],
            timestamp=datetime.fromisoformat(data["timestamp"]),
            data=data["data"],
            metadata=data.get("metadata") or {},
        )


class CheckpointManager:
    """
    Manages checkpoints for agent sessions.

    Provides:
    - Named checkpoints for semantic recovery
    - Automatic compression of old checkpoints
    - Recovery to specific checkpoint states
    - Analytics on checkpoint patterns

    Checkpoints are stored in insertion order. "First" and "latest" always
    refer to list position, not to timestamps.
    """

    def __init__(
        self,
        session_id: str,
        max_checkpoints: int = 100,
        compress_after: int = 50,
        keep_recent: int = 20,
    ):
        """
        Initialize the checkpoint manager.

        Args:
            session_id: The session ID this manager belongs to
            max_checkpoints: Maximum number of checkpoints to retain. The
                first and the newest checkpoint are always kept, so the
                effective cap is never below 2.
            compress_after: Compress checkpoints once more than this many
                are stored
            keep_recent: How many of the newest checkpoints compression
                keeps untouched

        Raises:
            ValueError: If ``keep_recent`` is negative.
        """
        if keep_recent < 0:
            raise ValueError("keep_recent must be >= 0")
        self.session_id = session_id
        self.max_checkpoints = max_checkpoints
        self.compress_after = compress_after
        self.keep_recent = keep_recent
        self._checkpoints: List[Checkpoint] = []
        # Monotonic sequence used for IDs; never decreases, even on clear(),
        # so IDs stay unique within the manager's lifetime.
        self._checkpoint_count = 0
        self._on_checkpoint: Optional[Callable[[Checkpoint], None]] = None

    def create(
        self,
        name: str,
        data: Dict[str, Any],
        metadata: Optional[Dict[str, Any]] = None,
    ) -> Checkpoint:
        """
        Creates a new checkpoint.

        The checkpoint is stored, and the retention policy (compression and
        ``max_checkpoints``) is applied before the callback runs.

        Args:
            name: Semantic name for the checkpoint (e.g., "task_started")
            data: Checkpoint data to store
            metadata: Optional additional metadata

        Returns:
            The created checkpoint

        Raises:
            Whatever the registered ``on_checkpoint`` callback raises. The
            checkpoint has already been stored at that point.
        """
        self._checkpoint_count += 1
        checkpoint = Checkpoint(
            id=f"{self.session_id}:{self._checkpoint_count}",
            name=name,
            timestamp=datetime.now(timezone.utc),
            data=data,
            metadata=metadata or {},
        )
        self._checkpoints.append(checkpoint)
        self._apply_retention()

        if self._on_checkpoint is not None:
            self._on_checkpoint(checkpoint)

        logger.debug(
            "Session %s: Created checkpoint '%s' (#%d)",
            self.session_id,
            name,
            self._checkpoint_count,
        )
        return checkpoint

    def get(self, checkpoint_id: str) -> Optional[Checkpoint]:
        """
        Gets a checkpoint by ID.

        Args:
            checkpoint_id: The checkpoint ID

        Returns:
            The checkpoint or None if not found
        """
        return next(
            (cp for cp in self._checkpoints if cp.id == checkpoint_id), None
        )

    def get_by_name(self, name: str) -> List[Checkpoint]:
        """
        Gets all checkpoints with a given name.

        Args:
            name: The checkpoint name

        Returns:
            List of matching checkpoints (newest first)
        """
        return [cp for cp in reversed(self._checkpoints) if cp.name == name]

    def get_latest(self, name: Optional[str] = None) -> Optional[Checkpoint]:
        """
        Gets the latest checkpoint, optionally filtered by name.

        Args:
            name: Optional name filter. ``None`` means no filter. Any other
                value, including ``""``, is matched exactly.

        Returns:
            The latest matching checkpoint or None
        """
        if name is not None:
            return next(
                (cp for cp in reversed(self._checkpoints) if cp.name == name),
                None,
            )
        return self._checkpoints[-1] if self._checkpoints else None

    def get_all(self) -> List[Checkpoint]:
        """Returns a shallow copy of all stored checkpoints (oldest first)."""
        return list(self._checkpoints)

    def get_since(self, timestamp: datetime) -> List[Checkpoint]:
        """
        Gets all checkpoints strictly after a given timestamp.

        Pass a timezone-aware datetime. Checkpoints created here carry
        UTC-aware timestamps, and Python raises ``TypeError`` when comparing
        aware and naive datetimes.

        Args:
            timestamp: The starting timestamp (exclusive)

        Returns:
            List of checkpoints after the timestamp
        """
        return [cp for cp in self._checkpoints if cp.timestamp > timestamp]

    def get_between(
        self,
        start: datetime,
        end: datetime,
    ) -> List[Checkpoint]:
        """
        Gets checkpoints between two timestamps (both ends inclusive).

        The same aware/naive caveat as :meth:`get_since` applies.

        Args:
            start: Start timestamp
            end: End timestamp

        Returns:
            List of checkpoints in the range
        """
        return [
            cp for cp in self._checkpoints if start <= cp.timestamp <= end
        ]

    def rollback_to(self, checkpoint_id: str) -> Optional[Dict[str, Any]]:
        """
        Gets data needed to rollback to a checkpoint.

        Note: This doesn't actually rollback - it returns the checkpoint
        data for the caller to use for recovery.

        Args:
            checkpoint_id: The checkpoint to rollback to

        Returns:
            The checkpoint data or None if not found
        """
        checkpoint = self.get(checkpoint_id)
        if checkpoint is None:
            return None
        logger.info(
            "Session %s: Rollback to checkpoint '%s' (%s)",
            self.session_id,
            checkpoint.name,
            checkpoint_id,
        )
        return checkpoint.data

    def clear(self) -> int:
        """
        Clears all checkpoints.

        The ID sequence is not reset, so new checkpoints get fresh IDs.

        Returns:
            Number of checkpoints cleared
        """
        count = len(self._checkpoints)
        self._checkpoints.clear()
        logger.info("Session %s: Cleared %d checkpoints", self.session_id, count)
        return count

    def _apply_retention(self) -> None:
        """Compress past ``compress_after``, then enforce ``max_checkpoints``."""
        if len(self._checkpoints) > self.compress_after:
            self._compress_checkpoints()
        self._enforce_max_checkpoints()

    def _compress_checkpoints(self) -> None:
        """
        Compresses old checkpoints to save memory.

        Keeps each of the following at most once, in their original order:
        - First checkpoint (session start)
        - Last ``keep_recent`` checkpoints (recent history)
        - The latest checkpoint of each unique name among the rest
        """
        total = len(self._checkpoints)
        if total <= self.compress_after:
            return

        # The recent window never starts before index 1, so the first
        # checkpoint cannot be selected twice.
        recent_start = max(1, total - self.keep_recent)

        latest_by_name: Dict[str, int] = {}
        for idx in range(1, recent_start):
            # Later positions overwrite earlier ones: the latest per name wins.
            latest_by_name[self._checkpoints[idx].name] = idx

        keep = {0, *latest_by_name.values(), *range(recent_start, total)}
        if len(keep) == total:
            return  # Nothing to drop.

        # Filtering by index preserves insertion order without comparing
        # timestamps, which may mix naive and aware values after an import.
        self._checkpoints = [
            cp for idx, cp in enumerate(self._checkpoints) if idx in keep
        ]
        logger.debug(
            "Session %s: Compressed checkpoints from %d to %d",
            self.session_id,
            total,
            len(self._checkpoints),
        )

    def _enforce_max_checkpoints(self) -> None:
        """Drop the oldest checkpoints after the first until within the cap."""
        # A floor of 2 guarantees the session-start and newest checkpoints
        # both survive, even with a degenerate max_checkpoints value.
        limit = max(self.max_checkpoints, 2)
        excess = len(self._checkpoints) - limit
        if excess <= 0:
            return
        del self._checkpoints[1:1 + excess]
        logger.debug(
            "Session %s: Dropped %d checkpoints to respect max_checkpoints=%d",
            self.session_id,
            excess,
            self.max_checkpoints,
        )

    def _highest_sequence(self, checkpoints: List[Checkpoint]) -> int:
        """Largest ``<n>`` among IDs shaped ``"<session_id>:<n>"`` (0 if none)."""
        prefix = f"{self.session_id}:"
        highest = 0
        for cp in checkpoints:
            cp_id = str(cp.id)
            if cp_id.startswith(prefix):
                suffix = cp_id[len(prefix):]
                if suffix.isdecimal():
                    highest = max(highest, int(suffix))
        return highest

    def get_summary(self) -> Dict[str, Any]:
        """
        Gets a summary of checkpoint activity.

        Returns:
            Summary dict with counts and timing info. ``duration_seconds`` is
            the difference between the last and first stored checkpoints.
        """
        if not self._checkpoints:
            return {
                "total_count": 0,
                "unique_names": 0,
                "names": {},
                "first_checkpoint": None,
                "last_checkpoint": None,
                "duration_seconds": 0,
            }

        name_counts: Dict[str, int] = {}
        for cp in self._checkpoints:
            name_counts[cp.name] = name_counts.get(cp.name, 0) + 1

        first = self._checkpoints[0]
        last = self._checkpoints[-1]

        return {
            "total_count": len(self._checkpoints),
            "unique_names": len(name_counts),
            "names": name_counts,
            "first_checkpoint": first.to_dict(),
            "last_checkpoint": last.to_dict(),
            "duration_seconds": (last.timestamp - first.timestamp).total_seconds(),
        }

    def on_checkpoint(self, callback: Callable[[Checkpoint], None]) -> None:
        """
        Sets a callback to be called on each checkpoint created by
        :meth:`create`. Replaces any previously registered callback.

        Args:
            callback: Function to call with each checkpoint
        """
        self._on_checkpoint = callback

    def export(self) -> str:
        """
        Exports all checkpoints to JSON.

        Returns:
            JSON string of all checkpoints

        Raises:
            TypeError: If any checkpoint data/metadata is not JSON-serializable.
        """
        return json.dumps(
            [cp.to_dict() for cp in self._checkpoints],
            indent=2,
        )

    def import_checkpoints(self, json_data: str) -> int:
        """
        Imports checkpoints from JSON (as produced by :meth:`export`).

        Imported checkpoints are appended after the existing ones. The ID
        sequence is advanced past any imported ``"<session_id>:<n>"`` ID, so
        later :meth:`create` calls never reuse one. Retention is then applied
        as it is on :meth:`create`.

        Args:
            json_data: JSON string of checkpoints

        Returns:
            Number of checkpoints imported (before retention is applied)
        """
        data = json.loads(json_data)
        # Build everything first so a malformed entry leaves state untouched.
        imported = [Checkpoint.from_dict(cp) for cp in data]
        self._checkpoints.extend(imported)
        self._checkpoint_count = max(
            self._checkpoint_count,
            len(self._checkpoints),
            self._highest_sequence(imported),
        )
        self._apply_retention()
        return len(imported)

    def __len__(self) -> int:
        return len(self._checkpoints)

    def __iter__(self):
        return iter(self._checkpoints)