fix: Restore all files lost during destructive rebase
A previous `git pull --rebase origin main` dropped 177 local commits,
losing 3400+ files across admin-v2, backend, studio-v2, website,
klausur-service, and many other services. The partial restore attempt
(660295e2) only recovered some files.
This commit restores all missing files from pre-rebase ref 98933f5e
while preserving post-rebase additions (night-scheduler, night-mode UI,
NightModeWidget dashboard integration).
Restored features include:
- AI Module Sidebar (FAB), OCR Labeling, OCR Compare
- GPU Dashboard, RAG Pipeline, Magic Help
- Klausur-Korrektur (8 files), Abitur-Archiv (5+ files)
- Companion, Zeugnisse-Crawler, Screen Flow
- Full backend, studio-v2, website, klausur-service
- All compliance SDKs, agent-core, voice-service
- CI/CD configs, documentation, scripts
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
25
agent-core/sessions/__init__.py
Normal file
25
agent-core/sessions/__init__.py
Normal file
@@ -0,0 +1,25 @@
|
||||
"""
|
||||
Session Management for Breakpilot Agents
|
||||
|
||||
Provides:
|
||||
- AgentSession: Individual agent session with context and checkpoints
|
||||
- SessionManager: Create, retrieve, and manage agent sessions
|
||||
- HeartbeatMonitor: Monitor agent liveness
|
||||
- SessionState: Session state enumeration
|
||||
"""
|
||||
|
||||
from agent_core.sessions.session_manager import (
|
||||
AgentSession,
|
||||
SessionManager,
|
||||
SessionState,
|
||||
)
|
||||
from agent_core.sessions.heartbeat import HeartbeatMonitor
|
||||
from agent_core.sessions.checkpoint import CheckpointManager
|
||||
|
||||
__all__ = [
|
||||
"AgentSession",
|
||||
"SessionManager",
|
||||
"SessionState",
|
||||
"HeartbeatMonitor",
|
||||
"CheckpointManager",
|
||||
]
|
||||
362
agent-core/sessions/checkpoint.py
Normal file
362
agent-core/sessions/checkpoint.py
Normal file
@@ -0,0 +1,362 @@
|
||||
"""
|
||||
Checkpoint Management for Breakpilot Agents
|
||||
|
||||
Provides checkpoint-based recovery with:
|
||||
- Named checkpoints for semantic recovery points
|
||||
- Automatic checkpoint compression
|
||||
- Recovery from specific checkpoints
|
||||
- Checkpoint analytics
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, Optional, List, Callable
|
||||
from datetime import datetime, timezone
|
||||
from dataclasses import dataclass, field
|
||||
import json
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
class Checkpoint:
    """A single named recovery point for an agent session.

    Attributes:
        id: Unique checkpoint identifier ("<session_id>:<counter>").
        name: Semantic label for the recovery point.
        timestamp: When the checkpoint was taken (timezone-aware).
        data: Arbitrary payload needed to recover from this point.
        metadata: Optional extra annotations.
    """
    id: str
    name: str
    timestamp: datetime
    data: Dict[str, Any]
    metadata: Dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> Dict[str, Any]:
        """Serialize this checkpoint to a JSON-compatible dict."""
        serialized = dict(
            id=self.id,
            name=self.name,
            timestamp=self.timestamp.isoformat(),
            data=self.data,
        )
        serialized["metadata"] = self.metadata
        return serialized

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "Checkpoint":
        """Rebuild a checkpoint from its `to_dict()` representation."""
        when = datetime.fromisoformat(data["timestamp"])
        return cls(
            id=data["id"],
            name=data["name"],
            timestamp=when,
            data=data["data"],
            metadata=data.get("metadata", {}),
        )
|
||||
|
||||
|
||||
class CheckpointManager:
|
||||
"""
|
||||
Manages checkpoints for agent sessions.
|
||||
|
||||
Provides:
|
||||
- Named checkpoints for semantic recovery
|
||||
- Automatic compression of old checkpoints
|
||||
- Recovery to specific checkpoint states
|
||||
- Analytics on checkpoint patterns
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
session_id: str,
|
||||
max_checkpoints: int = 100,
|
||||
compress_after: int = 50
|
||||
):
|
||||
"""
|
||||
Initialize the checkpoint manager.
|
||||
|
||||
Args:
|
||||
session_id: The session ID this manager belongs to
|
||||
max_checkpoints: Maximum number of checkpoints to retain
|
||||
compress_after: Compress checkpoints after this count
|
||||
"""
|
||||
self.session_id = session_id
|
||||
self.max_checkpoints = max_checkpoints
|
||||
self.compress_after = compress_after
|
||||
self._checkpoints: List[Checkpoint] = []
|
||||
self._checkpoint_count = 0
|
||||
self._on_checkpoint: Optional[Callable[[Checkpoint], None]] = None
|
||||
|
||||
def create(
|
||||
self,
|
||||
name: str,
|
||||
data: Dict[str, Any],
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
) -> Checkpoint:
|
||||
"""
|
||||
Creates a new checkpoint.
|
||||
|
||||
Args:
|
||||
name: Semantic name for the checkpoint (e.g., "task_started")
|
||||
data: Checkpoint data to store
|
||||
metadata: Optional additional metadata
|
||||
|
||||
Returns:
|
||||
The created checkpoint
|
||||
"""
|
||||
self._checkpoint_count += 1
|
||||
|
||||
checkpoint = Checkpoint(
|
||||
id=f"{self.session_id}:{self._checkpoint_count}",
|
||||
name=name,
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
data=data,
|
||||
metadata=metadata or {}
|
||||
)
|
||||
|
||||
self._checkpoints.append(checkpoint)
|
||||
|
||||
# Compress if needed
|
||||
if len(self._checkpoints) > self.compress_after:
|
||||
self._compress_checkpoints()
|
||||
|
||||
# Trigger callback
|
||||
if self._on_checkpoint:
|
||||
self._on_checkpoint(checkpoint)
|
||||
|
||||
logger.debug(
|
||||
f"Session {self.session_id}: Created checkpoint '{name}' "
|
||||
f"(#{self._checkpoint_count})"
|
||||
)
|
||||
|
||||
return checkpoint
|
||||
|
||||
def get(self, checkpoint_id: str) -> Optional[Checkpoint]:
|
||||
"""
|
||||
Gets a checkpoint by ID.
|
||||
|
||||
Args:
|
||||
checkpoint_id: The checkpoint ID
|
||||
|
||||
Returns:
|
||||
The checkpoint or None if not found
|
||||
"""
|
||||
for cp in self._checkpoints:
|
||||
if cp.id == checkpoint_id:
|
||||
return cp
|
||||
return None
|
||||
|
||||
def get_by_name(self, name: str) -> List[Checkpoint]:
|
||||
"""
|
||||
Gets all checkpoints with a given name.
|
||||
|
||||
Args:
|
||||
name: The checkpoint name
|
||||
|
||||
Returns:
|
||||
List of matching checkpoints (newest first)
|
||||
"""
|
||||
return [
|
||||
cp for cp in reversed(self._checkpoints)
|
||||
if cp.name == name
|
||||
]
|
||||
|
||||
def get_latest(self, name: Optional[str] = None) -> Optional[Checkpoint]:
|
||||
"""
|
||||
Gets the latest checkpoint, optionally filtered by name.
|
||||
|
||||
Args:
|
||||
name: Optional name filter
|
||||
|
||||
Returns:
|
||||
The latest matching checkpoint or None
|
||||
"""
|
||||
if not self._checkpoints:
|
||||
return None
|
||||
|
||||
if name:
|
||||
matching = self.get_by_name(name)
|
||||
return matching[0] if matching else None
|
||||
|
||||
return self._checkpoints[-1]
|
||||
|
||||
def get_all(self) -> List[Checkpoint]:
|
||||
"""Returns all checkpoints"""
|
||||
return list(self._checkpoints)
|
||||
|
||||
def get_since(self, timestamp: datetime) -> List[Checkpoint]:
|
||||
"""
|
||||
Gets all checkpoints since a given timestamp.
|
||||
|
||||
Args:
|
||||
timestamp: The starting timestamp
|
||||
|
||||
Returns:
|
||||
List of checkpoints after the timestamp
|
||||
"""
|
||||
return [
|
||||
cp for cp in self._checkpoints
|
||||
if cp.timestamp > timestamp
|
||||
]
|
||||
|
||||
def get_between(
|
||||
self,
|
||||
start: datetime,
|
||||
end: datetime
|
||||
) -> List[Checkpoint]:
|
||||
"""
|
||||
Gets checkpoints between two timestamps.
|
||||
|
||||
Args:
|
||||
start: Start timestamp
|
||||
end: End timestamp
|
||||
|
||||
Returns:
|
||||
List of checkpoints in the range
|
||||
"""
|
||||
return [
|
||||
cp for cp in self._checkpoints
|
||||
if start <= cp.timestamp <= end
|
||||
]
|
||||
|
||||
def rollback_to(self, checkpoint_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Gets data needed to rollback to a checkpoint.
|
||||
|
||||
Note: This doesn't actually rollback - it returns the checkpoint
|
||||
data for the caller to use for recovery.
|
||||
|
||||
Args:
|
||||
checkpoint_id: The checkpoint to rollback to
|
||||
|
||||
Returns:
|
||||
The checkpoint data or None if not found
|
||||
"""
|
||||
checkpoint = self.get(checkpoint_id)
|
||||
if checkpoint:
|
||||
logger.info(
|
||||
f"Session {self.session_id}: Rollback to checkpoint "
|
||||
f"'{checkpoint.name}' ({checkpoint_id})"
|
||||
)
|
||||
return checkpoint.data
|
||||
return None
|
||||
|
||||
def clear(self) -> int:
|
||||
"""
|
||||
Clears all checkpoints.
|
||||
|
||||
Returns:
|
||||
Number of checkpoints cleared
|
||||
"""
|
||||
count = len(self._checkpoints)
|
||||
self._checkpoints.clear()
|
||||
logger.info(f"Session {self.session_id}: Cleared {count} checkpoints")
|
||||
return count
|
||||
|
||||
def _compress_checkpoints(self) -> None:
|
||||
"""
|
||||
Compresses old checkpoints to save memory.
|
||||
|
||||
Keeps:
|
||||
- First checkpoint (session start)
|
||||
- Last N checkpoints (recent history)
|
||||
- One checkpoint per unique name (latest)
|
||||
"""
|
||||
if len(self._checkpoints) <= self.compress_after:
|
||||
return
|
||||
|
||||
# Keep first checkpoint
|
||||
first = self._checkpoints[0]
|
||||
|
||||
# Keep last 20 checkpoints
|
||||
recent = self._checkpoints[-20:]
|
||||
|
||||
# Keep one of each unique name from the middle
|
||||
middle = self._checkpoints[1:-20]
|
||||
by_name: Dict[str, Checkpoint] = {}
|
||||
for cp in middle:
|
||||
# Keep the latest of each name
|
||||
if cp.name not in by_name or cp.timestamp > by_name[cp.name].timestamp:
|
||||
by_name[cp.name] = cp
|
||||
|
||||
# Combine and sort
|
||||
compressed = [first] + list(by_name.values()) + recent
|
||||
compressed.sort(key=lambda cp: cp.timestamp)
|
||||
|
||||
old_count = len(self._checkpoints)
|
||||
self._checkpoints = compressed
|
||||
|
||||
logger.debug(
|
||||
f"Session {self.session_id}: Compressed checkpoints "
|
||||
f"from {old_count} to {len(self._checkpoints)}"
|
||||
)
|
||||
|
||||
def get_summary(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Gets a summary of checkpoint activity.
|
||||
|
||||
Returns:
|
||||
Summary dict with counts and timing info
|
||||
"""
|
||||
if not self._checkpoints:
|
||||
return {
|
||||
"total_count": 0,
|
||||
"unique_names": 0,
|
||||
"names": {},
|
||||
"first_checkpoint": None,
|
||||
"last_checkpoint": None,
|
||||
"duration_seconds": 0
|
||||
}
|
||||
|
||||
name_counts: Dict[str, int] = {}
|
||||
for cp in self._checkpoints:
|
||||
name_counts[cp.name] = name_counts.get(cp.name, 0) + 1
|
||||
|
||||
first = self._checkpoints[0]
|
||||
last = self._checkpoints[-1]
|
||||
|
||||
return {
|
||||
"total_count": len(self._checkpoints),
|
||||
"unique_names": len(name_counts),
|
||||
"names": name_counts,
|
||||
"first_checkpoint": first.to_dict(),
|
||||
"last_checkpoint": last.to_dict(),
|
||||
"duration_seconds": (last.timestamp - first.timestamp).total_seconds()
|
||||
}
|
||||
|
||||
def on_checkpoint(self, callback: Callable[[Checkpoint], None]) -> None:
|
||||
"""
|
||||
Sets a callback to be called on each checkpoint.
|
||||
|
||||
Args:
|
||||
callback: Function to call with each checkpoint
|
||||
"""
|
||||
self._on_checkpoint = callback
|
||||
|
||||
def export(self) -> str:
|
||||
"""
|
||||
Exports all checkpoints to JSON.
|
||||
|
||||
Returns:
|
||||
JSON string of all checkpoints
|
||||
"""
|
||||
return json.dumps(
|
||||
[cp.to_dict() for cp in self._checkpoints],
|
||||
indent=2
|
||||
)
|
||||
|
||||
def import_checkpoints(self, json_data: str) -> int:
|
||||
"""
|
||||
Imports checkpoints from JSON.
|
||||
|
||||
Args:
|
||||
json_data: JSON string of checkpoints
|
||||
|
||||
Returns:
|
||||
Number of checkpoints imported
|
||||
"""
|
||||
data = json.loads(json_data)
|
||||
imported = [Checkpoint.from_dict(cp) for cp in data]
|
||||
self._checkpoints.extend(imported)
|
||||
self._checkpoint_count = max(
|
||||
self._checkpoint_count,
|
||||
len(self._checkpoints)
|
||||
)
|
||||
return len(imported)
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self._checkpoints)
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self._checkpoints)
|
||||
361
agent-core/sessions/heartbeat.py
Normal file
361
agent-core/sessions/heartbeat.py
Normal file
@@ -0,0 +1,361 @@
|
||||
"""
|
||||
Heartbeat Monitoring for Breakpilot Agents
|
||||
|
||||
Provides liveness monitoring for agents with:
|
||||
- Configurable timeout thresholds
|
||||
- Async background monitoring
|
||||
- Callback-based timeout handling
|
||||
- Integration with SessionManager
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import Dict, Callable, Optional, Awaitable, Set
|
||||
from datetime import datetime, timezone, timedelta
|
||||
from dataclasses import dataclass, field
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
class HeartbeatEntry:
    """Represents a heartbeat entry for an agent"""
    # ID of the monitored session.
    session_id: str
    # Agent type label; passed through to timeout callbacks and logs.
    agent_type: str
    # Timestamp of the most recent heartbeat (timezone-aware UTC).
    last_beat: datetime
    # Consecutive monitor checks that found this session stale;
    # reset to 0 whenever a heartbeat arrives.
    missed_beats: int = 0
|
||||
|
||||
|
||||
class HeartbeatMonitor:
    """
    Monitors agent heartbeats and triggers callbacks on timeout.

    Usage:
        monitor = HeartbeatMonitor(timeout_seconds=30)
        monitor.on_timeout = handle_timeout
        await monitor.start_monitoring()

    Note: sessions are considered stale only while monitoring runs; the
    effective time-to-timeout is roughly
    ``timeout_seconds + (max_missed_beats - 1) * check_interval_seconds``
    because a missed beat is counted once per check cycle.
    """

    def __init__(
        self,
        timeout_seconds: int = 30,
        check_interval_seconds: int = 5,
        max_missed_beats: int = 3
    ):
        """
        Initialize the heartbeat monitor.

        Args:
            timeout_seconds: Time without heartbeat before considered stale
            check_interval_seconds: How often to check for stale sessions
            max_missed_beats: Number of missed beats before triggering timeout
        """
        # session_id -> entry for every monitored session.
        self.sessions: Dict[str, HeartbeatEntry] = {}
        self.timeout = timedelta(seconds=timeout_seconds)
        self.check_interval = check_interval_seconds
        self.max_missed_beats = max_missed_beats
        # Async callbacks: on_timeout(session_id, agent_type) fires when a
        # session exceeds max_missed_beats; on_warning(session_id, missed)
        # fires on the first missed beat only. Both optional.
        self.on_timeout: Optional[Callable[[str, str], Awaitable[None]]] = None
        self.on_warning: Optional[Callable[[str, int], Awaitable[None]]] = None
        self._running = False
        self._task: Optional[asyncio.Task] = None
        # Sessions whose monitoring is temporarily suspended via pause().
        self._paused_sessions: Set[str] = set()

    async def start_monitoring(self) -> None:
        """
        Starts the background heartbeat monitoring task.

        This runs indefinitely until stop_monitoring() is called.
        """
        if self._running:
            logger.warning("Heartbeat monitor already running")
            return

        self._running = True
        self._task = asyncio.create_task(self._monitoring_loop())
        # NOTE(review): `.seconds` shows only the seconds component of the
        # timedelta; for timeouts >= 1 day the log would be misleading.
        logger.info(
            f"Heartbeat monitor started (timeout={self.timeout.seconds}s, "
            f"interval={self.check_interval}s)"
        )

    async def stop_monitoring(self) -> None:
        """Stops the heartbeat monitoring task"""
        self._running = False
        if self._task:
            self._task.cancel()
            try:
                # Wait for the loop to unwind before clearing the handle.
                await self._task
            except asyncio.CancelledError:
                pass
            self._task = None
        logger.info("Heartbeat monitor stopped")

    async def _monitoring_loop(self) -> None:
        """Main monitoring loop"""
        while self._running:
            try:
                # Sleep first so a freshly registered session gets a full
                # interval before its first staleness check.
                await asyncio.sleep(self.check_interval)
                await self._check_heartbeats()
            except asyncio.CancelledError:
                break
            except Exception as e:
                # Keep the monitor alive on unexpected errors.
                logger.error(f"Error in heartbeat monitoring: {e}")

    async def _check_heartbeats(self) -> None:
        """Checks all registered sessions for stale heartbeats"""
        now = datetime.now(timezone.utc)
        timed_out = []

        # Iterate over a snapshot: entries may be unregistered by callbacks
        # or removed below while we are still iterating.
        for session_id, entry in list(self.sessions.items()):
            # Skip paused sessions
            if session_id in self._paused_sessions:
                continue

            time_since_beat = now - entry.last_beat

            if time_since_beat > self.timeout:
                entry.missed_beats += 1

                # Warn on first missed beat
                if entry.missed_beats == 1 and self.on_warning:
                    await self.on_warning(session_id, entry.missed_beats)
                    # NOTE(review): this warning log is inside the callback
                    # guard, so nothing is logged when on_warning is unset.
                    logger.warning(
                        f"Session {session_id} missed heartbeat "
                        f"({entry.missed_beats}/{self.max_missed_beats})"
                    )

                # Timeout after max missed beats
                if entry.missed_beats >= self.max_missed_beats:
                    timed_out.append((session_id, entry.agent_type))

        # Handle timeouts
        for session_id, agent_type in timed_out:
            logger.error(
                f"Session {session_id} ({agent_type}) timed out after "
                f"{self.max_missed_beats} missed heartbeats"
            )

            if self.on_timeout:
                try:
                    await self.on_timeout(session_id, agent_type)
                except Exception as e:
                    # A failing handler must not stop cleanup of the session.
                    logger.error(f"Error in timeout handler: {e}")

            # Remove from tracking
            del self.sessions[session_id]
            self._paused_sessions.discard(session_id)

    def register(self, session_id: str, agent_type: str) -> None:
        """
        Registers a session for heartbeat monitoring.

        Re-registering an existing session resets its entry and timer.

        Args:
            session_id: The session ID to monitor
            agent_type: The type of agent
        """
        self.sessions[session_id] = HeartbeatEntry(
            session_id=session_id,
            agent_type=agent_type,
            last_beat=datetime.now(timezone.utc)
        )
        logger.debug(f"Registered session {session_id} for heartbeat monitoring")

    def beat(self, session_id: str) -> bool:
        """
        Records a heartbeat for a session.

        Resets the missed-beat counter as well as the timer.

        Args:
            session_id: The session ID

        Returns:
            True if the session is registered, False otherwise
        """
        if session_id in self.sessions:
            self.sessions[session_id].last_beat = datetime.now(timezone.utc)
            self.sessions[session_id].missed_beats = 0
            return True
        return False

    def unregister(self, session_id: str) -> bool:
        """
        Unregisters a session from heartbeat monitoring.

        Args:
            session_id: The session ID to unregister

        Returns:
            True if the session was registered, False otherwise
        """
        # Always clear any pause marker, even if the session is unknown.
        self._paused_sessions.discard(session_id)
        if session_id in self.sessions:
            del self.sessions[session_id]
            logger.debug(f"Unregistered session {session_id} from heartbeat monitoring")
            return True
        return False

    def pause(self, session_id: str) -> bool:
        """
        Pauses heartbeat monitoring for a session.

        Useful when a session is intentionally idle (e.g., waiting for user input).

        Args:
            session_id: The session ID to pause

        Returns:
            True if the session was registered, False otherwise
        """
        if session_id in self.sessions:
            self._paused_sessions.add(session_id)
            logger.debug(f"Paused heartbeat monitoring for session {session_id}")
            return True
        return False

    def resume(self, session_id: str) -> bool:
        """
        Resumes heartbeat monitoring for a paused session.

        Args:
            session_id: The session ID to resume

        Returns:
            True if the session was paused, False otherwise
        """
        if session_id in self._paused_sessions:
            self._paused_sessions.discard(session_id)
            # Reset the heartbeat timer so time spent paused does not
            # immediately count against the session.
            self.beat(session_id)
            logger.debug(f"Resumed heartbeat monitoring for session {session_id}")
            return True
        return False

    def get_status(self, session_id: str) -> Optional[Dict]:
        """
        Gets the heartbeat status for a session.

        Args:
            session_id: The session ID

        Returns:
            Status dict or None if not registered
        """
        if session_id not in self.sessions:
            return None

        entry = self.sessions[session_id]
        now = datetime.now(timezone.utc)

        return {
            "session_id": session_id,
            "agent_type": entry.agent_type,
            "last_beat": entry.last_beat.isoformat(),
            "seconds_since_beat": (now - entry.last_beat).total_seconds(),
            "missed_beats": entry.missed_beats,
            "is_paused": session_id in self._paused_sessions,
            "is_healthy": entry.missed_beats == 0
        }

    def get_all_status(self) -> Dict[str, Dict]:
        """
        Gets heartbeat status for all registered sessions.

        Returns:
            Dict mapping session_id to status dict
        """
        return {
            session_id: self.get_status(session_id)
            for session_id in self.sessions
        }

    @property
    def registered_count(self) -> int:
        """Returns the number of registered sessions"""
        return len(self.sessions)

    @property
    def healthy_count(self) -> int:
        """Returns the number of healthy sessions (no missed beats)"""
        return sum(
            1 for entry in self.sessions.values()
            if entry.missed_beats == 0
        )
|
||||
|
||||
|
||||
class HeartbeatClient:
    """Client-side heartbeat sender for agents.

    Usage:
        client = HeartbeatClient(session_id, heartbeat_url)
        await client.start()
        # ... agent work ...
        await client.stop()

    Can also be used as an async context manager, which calls
    start() on entry and stop() on exit.
    """

    def __init__(
        self,
        session_id: str,
        monitor: Optional[HeartbeatMonitor] = None,
        interval_seconds: int = 10
    ):
        """Initialize the heartbeat client.

        Args:
            session_id: The session ID to send heartbeats for
            monitor: Optional local HeartbeatMonitor (for in-process agents)
            interval_seconds: How often to send heartbeats
        """
        self.session_id = session_id
        self.monitor = monitor
        self.interval = interval_seconds
        self._running = False
        self._task: Optional[asyncio.Task] = None

    async def start(self) -> None:
        """Begin emitting heartbeats in a background task (idempotent)."""
        if self._running:
            return
        self._running = True
        self._task = asyncio.create_task(self._heartbeat_loop())
        logger.debug(f"Heartbeat client started for session {self.session_id}")

    async def stop(self) -> None:
        """Cancel the background task and stop emitting heartbeats."""
        self._running = False
        pending = self._task
        if pending:
            pending.cancel()
            try:
                await pending
            except asyncio.CancelledError:
                pass
            self._task = None
        logger.debug(f"Heartbeat client stopped for session {self.session_id}")

    async def _heartbeat_loop(self) -> None:
        """Periodically emit heartbeats until cancelled or stopped."""
        while self._running:
            try:
                await self._send_heartbeat()
                await asyncio.sleep(self.interval)
            except asyncio.CancelledError:
                break
            except Exception as exc:
                # Log and retry on the next interval rather than dying.
                logger.error(f"Error sending heartbeat: {exc}")
                await asyncio.sleep(self.interval)

    async def _send_heartbeat(self) -> None:
        """Deliver a single heartbeat to the configured sink."""
        monitor = self.monitor
        if monitor:
            # Local monitor
            monitor.beat(self.session_id)
        # Future: Add HTTP-based heartbeat for distributed agents

    async def __aenter__(self):
        """Context manager entry: start the heartbeat task."""
        await self.start()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Context manager exit: stop the heartbeat task."""
        await self.stop()
|
||||
540
agent-core/sessions/session_manager.py
Normal file
540
agent-core/sessions/session_manager.py
Normal file
@@ -0,0 +1,540 @@
|
||||
"""
|
||||
Session Management for Breakpilot Agents
|
||||
|
||||
Provides session lifecycle management with:
|
||||
- State tracking (ACTIVE, PAUSED, COMPLETED, FAILED)
|
||||
- Checkpoint-based recovery
|
||||
- Heartbeat integration
|
||||
- Hybrid Valkey + PostgreSQL persistence
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone, timedelta
|
||||
from typing import Dict, Any, Optional, List
|
||||
from enum import Enum
|
||||
import uuid
|
||||
import json
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SessionState(Enum):
    """Agent session states (stored as their string values)."""
    ACTIVE = "active"        # session is running and accepting work
    PAUSED = "paused"        # temporarily suspended; may be resumed
    COMPLETED = "completed"  # finished successfully (terminal)
    FAILED = "failed"        # finished with an error (terminal)
|
||||
|
||||
|
||||
@dataclass
class SessionCheckpoint:
    """A named recovery point recorded inside an agent session.

    Attributes:
        name: Semantic label (e.g., "task_received").
        timestamp: When the checkpoint was taken.
        data: Payload needed to recover from this point.
    """
    name: str
    timestamp: datetime
    data: Dict[str, Any]

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a JSON-friendly dict (timestamp as ISO-8601)."""
        iso_stamp = self.timestamp.isoformat()
        return {"name": self.name, "timestamp": iso_stamp, "data": self.data}

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "SessionCheckpoint":
        """Inverse of `to_dict`: rebuild a checkpoint from its dict form."""
        when = datetime.fromisoformat(data["timestamp"])
        return cls(name=data["name"], timestamp=when, data=data["data"])
|
||||
|
||||
|
||||
@dataclass
class AgentSession:
    """
    Represents an active agent session.

    Attributes:
        session_id: Unique session identifier
        agent_type: Type of agent (tutor, grader, quality-judge, alert, orchestrator)
        user_id: Associated user ID
        state: Current session state
        created_at: Session creation timestamp
        last_heartbeat: Last heartbeat timestamp
        context: Session context data
        checkpoints: List of session checkpoints for recovery
        metadata: Additional free-form metadata
    """
    session_id: str = field(default_factory=lambda: str(uuid.uuid4()))
    agent_type: str = ""
    user_id: str = ""
    state: SessionState = SessionState.ACTIVE
    created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
    last_heartbeat: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
    context: Dict[str, Any] = field(default_factory=dict)
    checkpoints: List[SessionCheckpoint] = field(default_factory=list)
    metadata: Dict[str, Any] = field(default_factory=dict)

    def checkpoint(self, name: str, data: Dict[str, Any]) -> SessionCheckpoint:
        """
        Creates a checkpoint for recovery.

        Args:
            name: Checkpoint name (e.g., "task_received", "processing_complete")
            data: Checkpoint data to store

        Returns:
            The created checkpoint
        """
        checkpoint = SessionCheckpoint(
            name=name,
            timestamp=datetime.now(timezone.utc),
            data=data
        )
        self.checkpoints.append(checkpoint)
        logger.debug(f"Session {self.session_id}: Checkpoint '{name}' created")
        return checkpoint

    def heartbeat(self) -> None:
        """Updates the heartbeat timestamp"""
        self.last_heartbeat = datetime.now(timezone.utc)

    def pause(self) -> None:
        """Pauses the session"""
        # Bug fix: the previous state was hardcoded as "active" in the
        # checkpoint even when pausing from another state; record the
        # actual state so the checkpoint trail is accurate.
        previous_state = self.state.value
        self.state = SessionState.PAUSED
        self.checkpoint("session_paused", {"previous_state": previous_state})

    def resume(self) -> None:
        """Resumes a paused session (no-op unless currently PAUSED)"""
        if self.state == SessionState.PAUSED:
            self.state = SessionState.ACTIVE
            self.heartbeat()
            self.checkpoint("session_resumed", {})

    def complete(self, result: Optional[Dict[str, Any]] = None) -> None:
        """Marks the session as completed"""
        self.state = SessionState.COMPLETED
        self.checkpoint("session_completed", {"result": result or {}})

    def fail(self, error: str, error_details: Optional[Dict[str, Any]] = None) -> None:
        """Marks the session as failed"""
        self.state = SessionState.FAILED
        self.checkpoint("session_failed", {
            "error": error,
            "details": error_details or {}
        })

    def get_last_checkpoint(self, name: Optional[str] = None) -> Optional[SessionCheckpoint]:
        """
        Gets the last checkpoint, optionally filtered by name.

        Args:
            name: Optional checkpoint name to filter by

        Returns:
            The last matching checkpoint or None
        """
        if not self.checkpoints:
            return None

        if name:
            matching = [cp for cp in self.checkpoints if cp.name == name]
            return matching[-1] if matching else None

        return self.checkpoints[-1]

    def get_duration(self) -> timedelta:
        """Returns the session duration.

        For terminal sessions the end time is the last checkpoint's
        timestamp (e.g., the completion/failure checkpoint); otherwise
        duration is measured up to now.
        """
        end_time = datetime.now(timezone.utc)
        if self.state in (SessionState.COMPLETED, SessionState.FAILED):
            last_cp = self.get_last_checkpoint()
            if last_cp:
                end_time = last_cp.timestamp
        return end_time - self.created_at

    def to_dict(self) -> Dict[str, Any]:
        """Serializes the session to a dictionary"""
        return {
            "session_id": self.session_id,
            "agent_type": self.agent_type,
            "user_id": self.user_id,
            "state": self.state.value,
            "created_at": self.created_at.isoformat(),
            "last_heartbeat": self.last_heartbeat.isoformat(),
            "context": self.context,
            "checkpoints": [cp.to_dict() for cp in self.checkpoints],
            "metadata": self.metadata
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "AgentSession":
        """Deserializes a session from a dictionary (inverse of to_dict)"""
        return cls(
            session_id=data["session_id"],
            agent_type=data["agent_type"],
            user_id=data["user_id"],
            state=SessionState(data["state"]),
            created_at=datetime.fromisoformat(data["created_at"]),
            last_heartbeat=datetime.fromisoformat(data["last_heartbeat"]),
            context=data.get("context", {}),
            checkpoints=[
                SessionCheckpoint.from_dict(cp)
                for cp in data.get("checkpoints", [])
            ],
            metadata=data.get("metadata", {})
        )
|
||||
|
||||
|
||||
class SessionManager:
|
||||
"""
|
||||
Manages agent sessions with hybrid Valkey + PostgreSQL persistence.
|
||||
|
||||
- Valkey: Fast access for active sessions (24h TTL)
|
||||
- PostgreSQL: Persistent storage for audit trail and recovery
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
redis_client=None,
|
||||
db_pool=None,
|
||||
namespace: str = "breakpilot",
|
||||
session_ttl_hours: int = 24
|
||||
):
|
||||
"""
|
||||
Initialize the session manager.
|
||||
|
||||
Args:
|
||||
redis_client: Async Redis/Valkey client
|
||||
db_pool: Async PostgreSQL connection pool
|
||||
namespace: Redis key namespace
|
||||
session_ttl_hours: Session TTL in Valkey
|
||||
"""
|
||||
self.redis = redis_client
|
||||
self.db_pool = db_pool
|
||||
self.namespace = namespace
|
||||
self.session_ttl = timedelta(hours=session_ttl_hours)
|
||||
self._local_cache: Dict[str, AgentSession] = {}
|
||||
|
||||
def _redis_key(self, session_id: str) -> str:
|
||||
"""Generate Redis key for session"""
|
||||
return f"{self.namespace}:session:{session_id}"
|
||||
|
||||
async def create_session(
|
||||
self,
|
||||
agent_type: str,
|
||||
user_id: str = "",
|
||||
context: Optional[Dict[str, Any]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
) -> AgentSession:
|
||||
"""
|
||||
Creates a new agent session.
|
||||
|
||||
Args:
|
||||
agent_type: Type of agent
|
||||
user_id: Associated user ID
|
||||
context: Initial session context
|
||||
metadata: Additional metadata
|
||||
|
||||
Returns:
|
||||
The created AgentSession
|
||||
"""
|
||||
session = AgentSession(
|
||||
agent_type=agent_type,
|
||||
user_id=user_id,
|
||||
context=context or {},
|
||||
metadata=metadata or {}
|
||||
)
|
||||
|
||||
session.checkpoint("session_created", {
|
||||
"agent_type": agent_type,
|
||||
"user_id": user_id
|
||||
})
|
||||
|
||||
await self._persist_session(session)
|
||||
|
||||
logger.info(
|
||||
f"Created session {session.session_id} for agent '{agent_type}'"
|
||||
)
|
||||
|
||||
return session
|
||||
|
||||
async def get_session(self, session_id: str) -> Optional[AgentSession]:
    """
    Look up a session by ID across the three storage tiers.

    Lookup order: local cache, then Valkey, then PostgreSQL. A hit in a
    slower tier is promoted back into the faster ones.

    Args:
        session_id: Session ID to retrieve.

    Returns:
        The AgentSession if found in any tier, otherwise None.
    """
    cached = self._local_cache.get(session_id)
    if cached is not None:
        return cached

    from_valkey = await self._get_from_valkey(session_id)
    if from_valkey is not None:
        self._local_cache[session_id] = from_valkey
        return from_valkey

    from_db = await self._get_from_postgres(session_id)
    if from_db is None:
        return None

    # Promote the PostgreSQL copy back into Valkey and the local cache.
    await self._cache_in_valkey(from_db)
    self._local_cache[session_id] = from_db
    return from_db
async def update_session(self, session: AgentSession) -> None:
    """
    Write an updated session through to all storage layers.

    Also refreshes the session's heartbeat timestamp before persisting.

    Args:
        session: The session to write through.
    """
    self._local_cache[session.session_id] = session
    session.heartbeat()
    await self._persist_session(session)
async def delete_session(self, session_id: str) -> bool:
    """
    Delete a session (soft delete in PostgreSQL).

    The session is evicted from the local cache and Valkey; in PostgreSQL
    it is only marked 'deleted' so the audit trail survives.

    Args:
        session_id: Session ID to delete.

    Returns:
        True if deleted, False if not found.
    """
    # Bug fix: previously this returned True unconditionally when no
    # db_pool was configured, even for sessions that never existed.
    # Track whether any tier actually held the session so the return
    # value matches the documented contract.
    existed = self._local_cache.pop(session_id, None) is not None

    # Remove from Valkey; DEL returns the number of keys removed (0 or 1).
    if self.redis:
        removed = await self.redis.delete(self._redis_key(session_id))
        existed = existed or bool(removed)

    # Soft delete in PostgreSQL (authoritative when available).
    if self.db_pool:
        async with self.db_pool.acquire() as conn:
            result = await conn.execute(
                """
                UPDATE agent_sessions
                SET state = 'deleted', updated_at = NOW()
                WHERE id = $1
                """,
                session_id
            )
            # asyncpg reports the affected row count as "UPDATE <n>".
            return result == "UPDATE 1"

    return existed
async def get_active_sessions(
    self,
    agent_type: Optional[str] = None,
    user_id: Optional[str] = None
) -> List[AgentSession]:
    """
    List all sessions currently in the 'active' state.

    Queries PostgreSQL when a pool is configured; otherwise falls back
    to a best-effort scan of the in-process cache.

    Args:
        agent_type: If given, only sessions of this agent type.
        user_id: If given, only sessions belonging to this user.

    Returns:
        The matching active sessions.
    """
    if not self.db_pool:
        # Best effort without a database: filter the local cache.
        def matches(s) -> bool:
            if s.state != SessionState.ACTIVE:
                return False
            if agent_type and s.agent_type != agent_type:
                return False
            if user_id and s.user_id != user_id:
                return False
            return True

        return [s for s in self._local_cache.values() if matches(s)]

    async with self.db_pool.acquire() as conn:
        query = """
        SELECT id, agent_type, user_id, state, context, checkpoints,
               created_at, updated_at, last_heartbeat
        FROM agent_sessions
        WHERE state = 'active'
        """
        params = []

        # Placeholders are numbered by position in `params` ($1, $2, ...).
        if agent_type:
            query += " AND agent_type = $1"
            params.append(agent_type)

        if user_id:
            query += f" AND user_id = ${len(params) + 1}"
            params.append(user_id)

        rows = await conn.fetch(query, *params)

    return [self._row_to_session(r) for r in rows]
async def cleanup_stale_sessions(
    self,
    max_age: timedelta = timedelta(hours=48)
) -> int:
    """
    Mark sessions with no recent heartbeat as failed.

    Args:
        max_age: Sessions whose last heartbeat is older than this are
            considered stale.

    Returns:
        Number of sessions cleaned up.
    """
    cutoff = datetime.now(timezone.utc) - max_age

    if not self.db_pool:
        # No database: fail stale entries directly in the local cache.
        failed = 0
        for sess in self._local_cache.values():
            if sess.last_heartbeat < cutoff and sess.state == SessionState.ACTIVE:
                sess.fail("Session timeout - no heartbeat")
                failed += 1
        return failed

    async with self.db_pool.acquire() as conn:
        result = await conn.execute(
            """
            UPDATE agent_sessions
            SET state = 'failed',
                updated_at = NOW(),
                context = context || '{"failure_reason": "heartbeat_timeout"}'::jsonb
            WHERE state = 'active' AND last_heartbeat < $1
            """,
            cutoff
        )

    # asyncpg status is "UPDATE <n>"; parse the affected-row count.
    count = int(result.split()[-1]) if result else 0
    logger.info(f"Cleaned up {count} stale sessions")
    return count
async def _persist_session(self, session: AgentSession) -> None:
    """Write-through: store the session in Valkey, then in PostgreSQL."""
    # Sequential on purpose: the cache write is cheap and the DB write
    # must not race it.
    for writer in (self._cache_in_valkey, self._save_to_postgres):
        await writer(session)
async def _cache_in_valkey(self, session: AgentSession) -> None:
    """Store a serialized copy of the session in Valkey with a TTL."""
    if not self.redis:
        return

    try:
        # Serialization stays inside the try so encoding errors are
        # also downgraded to a warning — Valkey is a cache tier only.
        key = self._redis_key(session.session_id)
        ttl_seconds = int(self.session_ttl.total_seconds())
        payload = json.dumps(session.to_dict())
        await self.redis.setex(key, ttl_seconds, payload)
    except Exception as e:
        logger.warning(f"Failed to cache session in Valkey: {e}")
async def _get_from_valkey(self, session_id: str) -> Optional[AgentSession]:
    """Load and deserialize a session from Valkey; None on miss or error."""
    if not self.redis:
        return None

    try:
        raw = await self.redis.get(self._redis_key(session_id))
        if raw:
            return AgentSession.from_dict(json.loads(raw))
    except Exception as e:
        # Cache tier: degrade gracefully and let callers fall through
        # to PostgreSQL.
        logger.warning(f"Failed to get session from Valkey: {e}")

    return None
async def _save_to_postgres(self, session: AgentSession) -> None:
    """Upsert the session row in PostgreSQL (insert, or update by id)."""
    if not self.db_pool:
        return

    upsert_sql = """
        INSERT INTO agent_sessions
        (id, agent_type, user_id, state, context, checkpoints,
         created_at, updated_at, last_heartbeat)
        VALUES ($1, $2, $3, $4, $5, $6, $7, NOW(), $8)
        ON CONFLICT (id) DO UPDATE SET
            state = EXCLUDED.state,
            context = EXCLUDED.context,
            checkpoints = EXCLUDED.checkpoints,
            updated_at = NOW(),
            last_heartbeat = EXCLUDED.last_heartbeat
    """
    try:
        # Built inside the try so serialization failures are logged
        # rather than propagated, matching the DB-error handling.
        args = (
            session.session_id,
            session.agent_type,
            session.user_id,
            session.state.value,
            json.dumps(session.context),
            json.dumps([cp.to_dict() for cp in session.checkpoints]),
            session.created_at,
            session.last_heartbeat,
        )
        async with self.db_pool.acquire() as conn:
            await conn.execute(upsert_sql, *args)
    except Exception as e:
        logger.error(f"Failed to save session to PostgreSQL: {e}")
async def _get_from_postgres(self, session_id: str) -> Optional[AgentSession]:
    """Load a non-deleted session row from PostgreSQL; None if absent."""
    if not self.db_pool:
        return None

    select_sql = """
        SELECT id, agent_type, user_id, state, context, checkpoints,
               created_at, updated_at, last_heartbeat
        FROM agent_sessions
        WHERE id = $1 AND state != 'deleted'
    """
    try:
        async with self.db_pool.acquire() as conn:
            row = await conn.fetchrow(select_sql, session_id)
        if row:
            return self._row_to_session(row)
    except Exception as e:
        logger.error(f"Failed to get session from PostgreSQL: {e}")

    return None
def _row_to_session(self, row) -> AgentSession:
    """Rehydrate an AgentSession from an agent_sessions table row."""

    def as_obj(value):
        # JSONB columns may arrive as str or as already-decoded objects,
        # depending on the driver's codec configuration.
        return json.loads(value) if isinstance(value, str) else value

    checkpoints = [
        SessionCheckpoint.from_dict(cp) for cp in as_obj(row["checkpoints"])
    ]

    return AgentSession(
        session_id=str(row["id"]),
        agent_type=row["agent_type"],
        user_id=row["user_id"] or "",
        state=SessionState(row["state"]),
        created_at=row["created_at"],
        last_heartbeat=row["last_heartbeat"],
        context=as_obj(row["context"]),
        checkpoints=checkpoints,
    )
|
||||
Reference in New Issue
Block a user