Files
breakpilot-lehrer/agent-core/sessions/checkpoint.py
Benjamin Boenisch 5a31f52310 Initial commit: breakpilot-lehrer - Lehrer KI Platform
Services: Admin-Lehrer, Backend-Lehrer, Studio v2, Website,
Klausur-Service, School-Service, Voice-Service, Geo-Service,
BreakPilot Drive, Agent-Core

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-11 23:47:26 +01:00

363 lines
9.8 KiB
Python

"""
Checkpoint Management for Breakpilot Agents
Provides checkpoint-based recovery with:
- Named checkpoints for semantic recovery points
- Automatic checkpoint compression
- Recovery from specific checkpoints
- Checkpoint analytics
"""
from typing import Dict, Any, Optional, List, Callable
from datetime import datetime, timezone
from dataclasses import dataclass, field
import json
import logging
# Module-level logger named after this module; used for checkpoint debug/info messages.
logger = logging.getLogger(__name__)
@dataclass
class Checkpoint:
    """
    Represents a recovery checkpoint.

    Attributes:
        id: Checkpoint ID (CheckpointManager builds "<session_id>:<n>")
        name: Semantic name of the checkpoint, e.g. "task_started"
        timestamp: When the checkpoint was taken; timezone-aware when created
            by CheckpointManager.create() or restored via from_dict()
        data: Checkpoint payload
        metadata: Optional additional information
    """
    id: str
    name: str
    timestamp: datetime
    data: Dict[str, Any]
    metadata: Dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> Dict[str, Any]:
        """
        Serializes the checkpoint to a JSON-friendly dict.

        The timestamp is rendered as an ISO-8601 string. ``data`` and
        ``metadata`` are returned as-is (not copied).
        """
        return {
            "id": self.id,
            "name": self.name,
            "timestamp": self.timestamp.isoformat(),
            "data": self.data,
            "metadata": self.metadata
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "Checkpoint":
        """
        Builds a checkpoint from a dict in the format produced by to_dict().

        A trailing "Z" UTC designator is accepted on all Python versions, and
        naive timestamps are interpreted as UTC. As a result, restored
        checkpoints can be compared with the UTC-aware timestamps that
        CheckpointManager creates. A missing or null "metadata" becomes {}.

        Args:
            data: Dict with "id", "name", "timestamp", "data" and optionally
                "metadata"

        Returns:
            The reconstructed checkpoint

        Raises:
            KeyError: If a required key is missing
            ValueError: If the timestamp is not a valid ISO-8601 string
        """
        raw_timestamp = data["timestamp"]
        # datetime.fromisoformat() only understands "Z" from Python 3.11 on
        if raw_timestamp.endswith("Z"):
            raw_timestamp = raw_timestamp[:-1] + "+00:00"
        timestamp = datetime.fromisoformat(raw_timestamp)
        # Mixing naive and aware datetimes makes comparisons raise TypeError
        if timestamp.tzinfo is None:
            timestamp = timestamp.replace(tzinfo=timezone.utc)
        return cls(
            id=data["id"],
            name=data["name"],
            timestamp=timestamp,
            data=data["data"],
            metadata=data.get("metadata") or {}
        )
class CheckpointManager:
    """
    Manages checkpoints for agent sessions.

    Provides:
    - Named checkpoints for semantic recovery
    - Automatic compression of old checkpoints
    - A hard cap on the number of retained checkpoints
    - Recovery to specific checkpoint states
    - Analytics on checkpoint patterns

    Checkpoints are stored in memory in creation/import order.
    """

    # Number of most recent checkpoints that compression always keeps verbatim.
    _KEEP_RECENT = 20

    def __init__(
        self,
        session_id: str,
        max_checkpoints: int = 100,
        compress_after: int = 50
    ):
        """
        Initialize the checkpoint manager.

        Args:
            session_id: The session ID this manager belongs to
            max_checkpoints: Maximum number of checkpoints to retain. When
                exceeded, the oldest checkpoints are dropped, but the first
                (session start) checkpoint is kept whenever the limit is at
                least 2. None or values below 1 disable the cap.
            compress_after: Compress checkpoints once more than this many
                are stored
        """
        self.session_id = session_id
        self.max_checkpoints = max_checkpoints
        self.compress_after = compress_after
        self._checkpoints: List[Checkpoint] = []
        # Source of the numeric part of checkpoint IDs. clear() does not
        # reset it, so IDs are never reused by this manager.
        self._checkpoint_count = 0
        self._on_checkpoint: Optional[Callable[[Checkpoint], None]] = None

    def create(
        self,
        name: str,
        data: Dict[str, Any],
        metadata: Optional[Dict[str, Any]] = None
    ) -> Checkpoint:
        """
        Creates a new checkpoint.

        Args:
            name: Semantic name for the checkpoint (e.g., "task_started")
            data: Checkpoint data to store (stored by reference, not copied)
            metadata: Optional additional metadata

        Returns:
            The created checkpoint
        """
        self._checkpoint_count += 1
        checkpoint = Checkpoint(
            id=f"{self.session_id}:{self._checkpoint_count}",
            name=name,
            timestamp=datetime.now(timezone.utc),
            data=data,
            metadata=metadata or {}
        )
        self._checkpoints.append(checkpoint)

        # Compress if needed, then enforce the hard retention cap. Compression
        # alone cannot bound memory when most checkpoint names are unique.
        if len(self._checkpoints) > self.compress_after:
            self._compress_checkpoints()
        self._enforce_max_checkpoints()

        # Trigger callback
        if self._on_checkpoint is not None:
            self._on_checkpoint(checkpoint)

        logger.debug(
            "Session %s: Created checkpoint '%s' (#%s)",
            self.session_id, name, self._checkpoint_count
        )
        return checkpoint

    def get(self, checkpoint_id: str) -> Optional[Checkpoint]:
        """
        Gets a checkpoint by ID.

        Args:
            checkpoint_id: The checkpoint ID

        Returns:
            The checkpoint or None if not found
        """
        for cp in self._checkpoints:
            if cp.id == checkpoint_id:
                return cp
        return None

    def get_by_name(self, name: str) -> List[Checkpoint]:
        """
        Gets all checkpoints with a given name.

        Args:
            name: The checkpoint name

        Returns:
            List of matching checkpoints (newest first)
        """
        return [
            cp for cp in reversed(self._checkpoints)
            if cp.name == name
        ]

    def get_latest(self, name: Optional[str] = None) -> Optional[Checkpoint]:
        """
        Gets the latest checkpoint, optionally filtered by name.

        Args:
            name: Optional name filter. An empty string is a valid filter; only
                None means "any name".

        Returns:
            The latest matching checkpoint or None
        """
        if not self._checkpoints:
            return None
        if name is not None:
            matching = self.get_by_name(name)
            return matching[0] if matching else None
        return self._checkpoints[-1]

    def get_all(self) -> List[Checkpoint]:
        """Returns a shallow copy of the list of all checkpoints"""
        return list(self._checkpoints)

    def get_since(self, timestamp: datetime) -> List[Checkpoint]:
        """
        Gets all checkpoints since a given timestamp.

        Args:
            timestamp: The starting timestamp (exclusive)

        Returns:
            List of checkpoints strictly after the timestamp
        """
        return [
            cp for cp in self._checkpoints
            if cp.timestamp > timestamp
        ]

    def get_between(
        self,
        start: datetime,
        end: datetime
    ) -> List[Checkpoint]:
        """
        Gets checkpoints between two timestamps.

        Args:
            start: Start timestamp (inclusive)
            end: End timestamp (inclusive)

        Returns:
            List of checkpoints in the range
        """
        return [
            cp for cp in self._checkpoints
            if start <= cp.timestamp <= end
        ]

    def rollback_to(self, checkpoint_id: str) -> Optional[Dict[str, Any]]:
        """
        Gets data needed to rollback to a checkpoint.

        Note: This doesn't actually rollback - it returns the checkpoint
        data for the caller to use for recovery.

        Args:
            checkpoint_id: The checkpoint to rollback to

        Returns:
            The checkpoint data or None if not found
        """
        checkpoint = self.get(checkpoint_id)
        if checkpoint is None:
            return None
        logger.info(
            "Session %s: Rollback to checkpoint '%s' (%s)",
            self.session_id, checkpoint.name, checkpoint_id
        )
        return checkpoint.data

    def clear(self) -> int:
        """
        Clears all checkpoints.

        The ID counter is not reset, so later checkpoints get fresh IDs.

        Returns:
            Number of checkpoints cleared
        """
        count = len(self._checkpoints)
        self._checkpoints.clear()
        logger.info("Session %s: Cleared %s checkpoints", self.session_id, count)
        return count

    def _compress_checkpoints(self) -> None:
        """
        Compresses old checkpoints to save memory.

        Keeps, in their original order and without duplicates:
        - First checkpoint (session start)
        - Last _KEEP_RECENT checkpoints (recent history)
        - The latest checkpoint of each unique name among the others
        """
        total = len(self._checkpoints)
        if total <= self.compress_after:
            return

        # Index 0 is always kept on its own, so the recent window starts at
        # index 1 at the earliest (otherwise the first checkpoint would be
        # kept twice when few checkpoints are stored).
        recent_start = max(1, total - self._KEEP_RECENT)

        # Later positions overwrite earlier ones, so the latest per name wins.
        latest_by_name: Dict[str, int] = {}
        for index in range(1, recent_start):
            latest_by_name[self._checkpoints[index].name] = index

        keep = {0, *latest_by_name.values(), *range(recent_start, total)}
        # Select by position, which preserves list order. Re-sorting by
        # timestamp could misorder ties and fails on naive/aware mixes.
        self._checkpoints = [
            cp for index, cp in enumerate(self._checkpoints) if index in keep
        ]

        logger.debug(
            "Session %s: Compressed checkpoints from %s to %s",
            self.session_id, total, len(self._checkpoints)
        )

    def _enforce_max_checkpoints(self) -> None:
        """Drops the oldest checkpoints (keeping the first) beyond max_checkpoints."""
        limit = self.max_checkpoints
        if limit is None or limit < 1 or len(self._checkpoints) <= limit:
            return
        before = len(self._checkpoints)
        if limit == 1:
            # Only room for one checkpoint: keep the most recent one
            self._checkpoints = self._checkpoints[-1:]
        else:
            self._checkpoints = (
                [self._checkpoints[0]] + self._checkpoints[-(limit - 1):]
            )
        logger.debug(
            "Session %s: Dropped %s checkpoints exceeding the limit of %s",
            self.session_id, before - len(self._checkpoints), limit
        )

    def get_summary(self) -> Dict[str, Any]:
        """
        Gets a summary of checkpoint activity.

        Returns:
            Summary dict with counts and timing info. Duration is measured
            between the first and the last stored checkpoint.
        """
        if not self._checkpoints:
            return {
                "total_count": 0,
                "unique_names": 0,
                "names": {},
                "first_checkpoint": None,
                "last_checkpoint": None,
                "duration_seconds": 0
            }

        name_counts: Dict[str, int] = {}
        for cp in self._checkpoints:
            name_counts[cp.name] = name_counts.get(cp.name, 0) + 1

        first = self._checkpoints[0]
        last = self._checkpoints[-1]

        return {
            "total_count": len(self._checkpoints),
            "unique_names": len(name_counts),
            "names": name_counts,
            "first_checkpoint": first.to_dict(),
            "last_checkpoint": last.to_dict(),
            "duration_seconds": (last.timestamp - first.timestamp).total_seconds()
        }

    def on_checkpoint(self, callback: Callable[[Checkpoint], None]) -> None:
        """
        Sets a callback to be called on each checkpoint.

        Replaces any previously registered callback.

        Args:
            callback: Function to call with each checkpoint
        """
        self._on_checkpoint = callback

    def export(self) -> str:
        """
        Exports all checkpoints to JSON.

        Returns:
            JSON string of all checkpoints
        """
        return json.dumps(
            [cp.to_dict() for cp in self._checkpoints],
            indent=2
        )

    def import_checkpoints(self, json_data: str) -> int:
        """
        Imports checkpoints from JSON.

        Imported checkpoints are appended after the existing ones. The ID
        counter is advanced past every imported "<session_id>:<n>" ID so that
        checkpoints created afterwards cannot reuse an imported ID.

        Args:
            json_data: JSON string of checkpoints (as produced by export())

        Returns:
            Number of checkpoints imported
        """
        data = json.loads(json_data)
        imported = [Checkpoint.from_dict(cp) for cp in data]
        self._checkpoints.extend(imported)

        # The number of stored checkpoints is not a safe lower bound on its
        # own: after compression, exported IDs can exceed the count.
        prefix = f"{self.session_id}:"
        highest = max(self._checkpoint_count, len(self._checkpoints))
        for cp in imported:
            cp_id = str(cp.id)
            if not cp_id.startswith(prefix):
                continue
            try:
                highest = max(highest, int(cp_id[len(prefix):]))
            except ValueError:
                # Not a numeric suffix; it cannot collide with generated IDs
                continue
        self._checkpoint_count = highest
        return len(imported)

    def __len__(self) -> int:
        """Returns the number of stored checkpoints"""
        return len(self._checkpoints)

    def __iter__(self):
        """Iterates over the stored checkpoints in order"""
        return iter(self._checkpoints)