# Services: Admin-Lehrer, Backend-Lehrer, Studio v2, Website, Klausur-Service, School-Service, Voice-Service, Geo-Service, BreakPilot Drive, Agent-Core
# Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
# Listing metadata from the repository viewer: 363 lines, 9.8 KiB, Python
"""
Checkpoint Management for Breakpilot Agents

Provides checkpoint-based recovery with:
- Named checkpoints for semantic recovery points
- Automatic checkpoint compression
- Recovery from specific checkpoints
- Checkpoint analytics
"""

import json
import logging
from collections import Counter
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Dict, Any, Optional, List, Callable

# Module-level logger named after this module (via __name__); the
# CheckpointManager methods below emit their debug/info messages through it.
logger = logging.getLogger(__name__)


@dataclass
class Checkpoint:
    """
    Represents a recovery checkpoint.

    A checkpoint is a named snapshot record: an identifier, a semantic
    name, the moment it was taken, the payload to recover from and optional
    free-form metadata. It can be converted to and from a plain dict via
    to_dict() / from_dict().
    """

    # Identifier of the checkpoint (looked up by equality elsewhere).
    id: str
    # Semantic name, e.g. "task_started"; several checkpoints may share it.
    name: str
    # Moment the checkpoint was taken; serialized with isoformat().
    timestamp: datetime
    # Recovery payload; stored by reference, it is not copied.
    data: Dict[str, Any]
    # Optional extra information; each instance gets its own empty dict.
    metadata: Dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> Dict[str, Any]:
        """
        Converts the checkpoint into a plain dict.

        Returns:
            Dict with the keys "id", "name", "timestamp" (ISO 8601 string),
            "data" and "metadata". The data and metadata values are the
            same dict objects held by the checkpoint, not copies.
        """
        return {
            "id": self.id,
            "name": self.name,
            "timestamp": self.timestamp.isoformat(),
            "data": self.data,
            "metadata": self.metadata,
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "Checkpoint":
        """
        Builds a checkpoint from a dict as produced by to_dict().

        Args:
            data: Mapping with "id", "name", "timestamp" (ISO 8601 string)
                and "data"; "metadata" is optional

        Returns:
            The reconstructed checkpoint

        Raises:
            KeyError: If a required key is missing
            ValueError: If the timestamp is not a valid ISO 8601 string
        """
        return cls(
            id=data["id"],
            name=data["name"],
            timestamp=datetime.fromisoformat(data["timestamp"]),
            data=data["data"],
            # "or {}" covers both a missing key and an explicit null, so
            # metadata is always a dict as the field annotation promises.
            metadata=data.get("metadata") or {},
        )


class CheckpointManager:
    """
    Manages checkpoints for agent sessions.

    Provides:
    - Named checkpoints for semantic recovery
    - Automatic compression of old checkpoints
    - Recovery to specific checkpoint states
    - Analytics on checkpoint patterns

    Checkpoints are held in an in-memory list in creation order; use
    export() and import_checkpoints() to move them in and out as JSON.
    """

    def __init__(
        self,
        session_id: str,
        max_checkpoints: int = 100,
        compress_after: int = 50,
        keep_recent: int = 20,
    ):
        """
        Initialize the checkpoint manager.

        Args:
            session_id: The session ID this manager belongs to; it prefixes
                the ID of every checkpoint created here
            max_checkpoints: Maximum number of checkpoints to retain;
                a value below 1 disables the limit
            compress_after: Compress checkpoints after this count
            keep_recent: Number of newest checkpoints that compression
                always preserves
        """
        self.session_id = session_id
        self.max_checkpoints = max_checkpoints
        self.compress_after = compress_after
        self.keep_recent = keep_recent
        self._checkpoints: List[Checkpoint] = []
        # Monotonic counter used to build checkpoint IDs; it is never
        # decremented, so IDs stay unique even after compression or clear().
        self._checkpoint_count = 0
        self._on_checkpoint: Optional[Callable[[Checkpoint], None]] = None

    def create(
        self,
        name: str,
        data: Dict[str, Any],
        metadata: Optional[Dict[str, Any]] = None,
    ) -> Checkpoint:
        """
        Creates a new checkpoint.

        The checkpoint is timestamped with the current UTC time and gets
        the ID "<session_id>:<n>", where n is a per-manager counter.

        Args:
            name: Semantic name for the checkpoint (e.g., "task_started")
            data: Checkpoint data to store (kept by reference, not copied)
            metadata: Optional additional metadata

        Returns:
            The created checkpoint
        """
        self._checkpoint_count += 1

        checkpoint = Checkpoint(
            id=f"{self.session_id}:{self._checkpoint_count}",
            name=name,
            timestamp=datetime.now(timezone.utc),
            data=data,
            metadata=metadata or {},
        )

        self._checkpoints.append(checkpoint)

        # Compress if needed
        if len(self._checkpoints) > self.compress_after:
            self._compress_checkpoints()

        # Compression cannot shrink a history of all-unique names, so the
        # hard limit is applied separately.
        self._enforce_max_checkpoints()

        # Trigger callback
        if self._on_checkpoint:
            self._on_checkpoint(checkpoint)

        logger.debug(
            f"Session {self.session_id}: Created checkpoint '{name}' "
            f"(#{self._checkpoint_count})"
        )

        return checkpoint

    def get(self, checkpoint_id: str) -> Optional[Checkpoint]:
        """
        Gets a checkpoint by ID.

        Args:
            checkpoint_id: The checkpoint ID

        Returns:
            The first checkpoint with that ID, or None if not found
        """
        return next(
            (cp for cp in self._checkpoints if cp.id == checkpoint_id),
            None,
        )

    def get_by_name(self, name: str) -> List[Checkpoint]:
        """
        Gets all checkpoints with a given name.

        Args:
            name: The checkpoint name

        Returns:
            List of matching checkpoints (newest first)
        """
        return [cp for cp in reversed(self._checkpoints) if cp.name == name]

    def get_latest(self, name: Optional[str] = None) -> Optional[Checkpoint]:
        """
        Gets the latest checkpoint, optionally filtered by name.

        Args:
            name: Optional name filter; None means "any name"

        Returns:
            The latest matching checkpoint or None
        """
        if not self._checkpoints:
            return None

        # Compare against None so an empty-string name is still a filter.
        if name is not None:
            return next(
                (cp for cp in reversed(self._checkpoints) if cp.name == name),
                None,
            )

        return self._checkpoints[-1]

    def get_all(self) -> List[Checkpoint]:
        """Returns all checkpoints as a new list (oldest first)."""
        return list(self._checkpoints)

    def get_since(self, timestamp: datetime) -> List[Checkpoint]:
        """
        Gets all checkpoints since a given timestamp.

        Args:
            timestamp: The starting timestamp (exclusive)

        Returns:
            List of checkpoints strictly after the timestamp
        """
        return [cp for cp in self._checkpoints if cp.timestamp > timestamp]

    def get_between(
        self,
        start: datetime,
        end: datetime,
    ) -> List[Checkpoint]:
        """
        Gets checkpoints between two timestamps.

        Args:
            start: Start timestamp (inclusive)
            end: End timestamp (inclusive)

        Returns:
            List of checkpoints in the range
        """
        return [
            cp for cp in self._checkpoints
            if start <= cp.timestamp <= end
        ]

    def rollback_to(self, checkpoint_id: str) -> Optional[Dict[str, Any]]:
        """
        Gets data needed to rollback to a checkpoint.

        Note: This doesn't actually rollback - it returns the checkpoint
        data for the caller to use for recovery. The returned dict is the
        one stored in the checkpoint, not a copy.

        Args:
            checkpoint_id: The checkpoint to rollback to

        Returns:
            The checkpoint data or None if not found
        """
        checkpoint = self.get(checkpoint_id)
        if checkpoint is None:
            return None

        logger.info(
            f"Session {self.session_id}: Rollback to checkpoint "
            f"'{checkpoint.name}' ({checkpoint_id})"
        )
        return checkpoint.data

    def clear(self) -> int:
        """
        Clears all checkpoints.

        The ID counter is left untouched, so checkpoints created afterwards
        do not reuse earlier IDs.

        Returns:
            Number of checkpoints cleared
        """
        count = len(self._checkpoints)
        self._checkpoints.clear()
        logger.info(f"Session {self.session_id}: Cleared {count} checkpoints")
        return count

    def _compress_checkpoints(self) -> None:
        """
        Compresses old checkpoints to save memory.

        Keeps:
        - First checkpoint (session start)
        - Last keep_recent checkpoints (recent history)
        - One checkpoint per unique name (latest) from the entries between

        Does nothing while the history holds at most compress_after entries.
        """
        total = len(self._checkpoints)
        if total <= self.compress_after:
            return

        # Keep first checkpoint
        first = self._checkpoints[0]

        # Clamp the recent window so it never reaches back to index 0;
        # otherwise the first checkpoint would be kept twice whenever the
        # history is no longer than the window.
        recent_count = max(0, min(self.keep_recent, total - 1))
        recent_start = total - recent_count
        recent = self._checkpoints[recent_start:]

        # Keep one of each unique name from the middle
        middle = self._checkpoints[1:recent_start]
        by_name: Dict[str, Checkpoint] = {}
        for cp in middle:
            kept = by_name.get(cp.name)
            # ">=" lets the later entry win when timestamps are equal.
            if kept is None or cp.timestamp >= kept.timestamp:
                by_name[cp.name] = cp

        # Combine and sort (stable, so equal timestamps keep their order)
        compressed = [first, *by_name.values(), *recent]
        compressed.sort(key=lambda cp: cp.timestamp)

        self._checkpoints = compressed

        logger.debug(
            f"Session {self.session_id}: Compressed checkpoints "
            f"from {total} to {len(self._checkpoints)}"
        )

    def _enforce_max_checkpoints(self) -> None:
        """
        Trims the history so it never exceeds max_checkpoints.

        Keeps the first checkpoint (session start) plus the newest entries.
        With a limit of 1 only the newest checkpoint survives. A limit
        below 1 is treated as "no limit".
        """
        limit = self.max_checkpoints
        total = len(self._checkpoints)
        if limit < 1 or total <= limit:
            return

        if limit == 1:
            kept = self._checkpoints[-1:]
        else:
            # total > limit >= 2, so the tail slice starts after index 0.
            kept = self._checkpoints[:1] + self._checkpoints[total - (limit - 1):]

        self._checkpoints = kept

        logger.debug(
            f"Session {self.session_id}: Trimmed checkpoints "
            f"from {total} to {len(kept)} (max_checkpoints={limit})"
        )

    def _sequence_number(self, checkpoint_id: Any) -> int:
        """
        Extracts the counter part of an ID of the form "<session_id>:<n>".

        Returns 0 for IDs that belong to another session or do not end in
        a plain decimal number.
        """
        text = str(checkpoint_id)
        prefix = f"{self.session_id}:"
        if not text.startswith(prefix):
            return 0
        tail = text[len(prefix):]
        return int(tail) if tail.isascii() and tail.isdecimal() else 0

    def get_summary(self) -> Dict[str, Any]:
        """
        Gets a summary of checkpoint activity.

        Returns:
            Summary dict with counts and timing info: "total_count",
            "unique_names", "names" (name -> occurrences),
            "first_checkpoint" / "last_checkpoint" (as dicts, or None when
            empty) and "duration_seconds" between those two
        """
        if not self._checkpoints:
            return {
                "total_count": 0,
                "unique_names": 0,
                "names": {},
                "first_checkpoint": None,
                "last_checkpoint": None,
                "duration_seconds": 0,
            }

        # Converted back to a plain dict so the summary stays JSON-friendly.
        name_counts: Dict[str, int] = dict(
            Counter(cp.name for cp in self._checkpoints)
        )

        first = self._checkpoints[0]
        last = self._checkpoints[-1]

        return {
            "total_count": len(self._checkpoints),
            "unique_names": len(name_counts),
            "names": name_counts,
            "first_checkpoint": first.to_dict(),
            "last_checkpoint": last.to_dict(),
            "duration_seconds": (last.timestamp - first.timestamp).total_seconds(),
        }

    def on_checkpoint(self, callback: Callable[[Checkpoint], None]) -> None:
        """
        Sets a callback to be called on each checkpoint.

        Only one callback is stored; setting a new one replaces the
        previous one.

        Args:
            callback: Function to call with each checkpoint
        """
        self._on_checkpoint = callback

    def export(self) -> str:
        """
        Exports all checkpoints to JSON.

        Returns:
            JSON string (indented by 2) of all checkpoints
        """
        return json.dumps(
            [cp.to_dict() for cp in self._checkpoints],
            indent=2,
        )

    def import_checkpoints(self, json_data: str) -> int:
        """
        Imports checkpoints from JSON.

        The imported checkpoints are appended to the existing ones. Nothing
        is added if any entry fails to parse.

        Args:
            json_data: JSON string of checkpoints, as produced by export()

        Returns:
            Number of checkpoints imported
        """
        payload = json.loads(json_data)
        imported = [Checkpoint.from_dict(entry) for entry in payload]
        self._checkpoints.extend(imported)

        # Advance the counter past every imported ID of this session. The
        # list length alone is not enough: a compressed export holds fewer
        # entries than its highest ID number, so create() could otherwise
        # hand out an ID that already exists.
        highest_imported = max(
            (self._sequence_number(cp.id) for cp in imported),
            default=0,
        )
        self._checkpoint_count = max(
            self._checkpoint_count,
            len(self._checkpoints),
            highest_imported,
        )
        return len(imported)

    def __len__(self) -> int:
        """Returns the number of checkpoints currently held."""
        return len(self._checkpoints)

    def __iter__(self):
        """Iterates over the held checkpoints, oldest first."""
        return iter(self._checkpoints)