fix: Restore all files lost during destructive rebase

A previous `git pull --rebase origin main` dropped 177 local commits,
losing 3400+ files across admin-v2, backend, studio-v2, website,
klausur-service, and many other services. The partial restore attempt
(660295e2) only recovered some files.

This commit restores all missing files from pre-rebase ref 98933f5e
while preserving post-rebase additions (night-scheduler, night-mode UI,
NightModeWidget dashboard integration).

Restored features include:
- AI Module Sidebar (FAB), OCR Labeling, OCR Compare
- GPU Dashboard, RAG Pipeline, Magic Help
- Klausur-Korrektur (8 files), Abitur-Archiv (5+ files)
- Companion, Zeugnisse-Crawler, Screen Flow
- Full backend, studio-v2, website, klausur-service
- All compliance SDKs, agent-core, voice-service
- CI/CD configs, documentation, scripts

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Benjamin Admin
2026-02-09 09:51:32 +01:00
parent f7487ee240
commit bfdaf63ba9
2009 changed files with 749983 additions and 1731 deletions

View File

@@ -0,0 +1,25 @@
"""
Session Management for Breakpilot Agents
Provides:
- AgentSession: Individual agent session with context and checkpoints
- SessionManager: Create, retrieve, and manage agent sessions
- HeartbeatMonitor: Monitor agent liveness
- SessionState: Session state enumeration
"""
from agent_core.sessions.session_manager import (
AgentSession,
SessionManager,
SessionState,
)
from agent_core.sessions.heartbeat import HeartbeatMonitor
from agent_core.sessions.checkpoint import CheckpointManager
__all__ = [
"AgentSession",
"SessionManager",
"SessionState",
"HeartbeatMonitor",
"CheckpointManager",
]

View File

@@ -0,0 +1,362 @@
"""
Checkpoint Management for Breakpilot Agents
Provides checkpoint-based recovery with:
- Named checkpoints for semantic recovery points
- Automatic checkpoint compression
- Recovery from specific checkpoints
- Checkpoint analytics
"""
from typing import Dict, Any, Optional, List, Callable
from datetime import datetime, timezone
from dataclasses import dataclass, field
import json
import logging
logger = logging.getLogger(__name__)


@dataclass
class Checkpoint:
    """A single recovery point captured during an agent session."""
    # Unique ID, formatted "<session_id>:<sequence_number>".
    id: str
    # Semantic label, e.g. "task_started".
    name: str
    # UTC creation time.
    timestamp: datetime
    # Arbitrary state payload stored at checkpoint time.
    data: Dict[str, Any]
    # Optional extra annotations.
    metadata: Dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a JSON-compatible dictionary."""
        return {
            "id": self.id,
            "name": self.name,
            "timestamp": self.timestamp.isoformat(),
            "data": self.data,
            "metadata": self.metadata
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "Checkpoint":
        """Rebuild a Checkpoint from a to_dict() payload."""
        return cls(
            id=data["id"],
            name=data["name"],
            timestamp=datetime.fromisoformat(data["timestamp"]),
            data=data["data"],
            metadata=data.get("metadata", {})
        )


class CheckpointManager:
    """
    Manages checkpoints for agent sessions.

    Provides:
    - Named checkpoints for semantic recovery
    - Automatic compression of old checkpoints
    - Recovery to specific checkpoint states
    - Analytics on checkpoint patterns
    """

    # Number of newest checkpoints always kept verbatim by compression.
    _RECENT_KEEP = 20

    def __init__(
        self,
        session_id: str,
        max_checkpoints: int = 100,
        compress_after: int = 50
    ):
        """
        Initialize the checkpoint manager.

        Args:
            session_id: The session ID this manager belongs to
            max_checkpoints: Maximum number of checkpoints to retain.
                Enforced after each create(); the oldest entries after
                the initial checkpoint are dropped first.
            compress_after: Compress checkpoints after this count
        """
        self.session_id = session_id
        self.max_checkpoints = max_checkpoints
        self.compress_after = compress_after
        self._checkpoints: List[Checkpoint] = []
        # Monotonic sequence number used for checkpoint IDs; never reset,
        # even when old checkpoints are compressed away.
        self._checkpoint_count = 0
        self._on_checkpoint: Optional[Callable[[Checkpoint], None]] = None

    def create(
        self,
        name: str,
        data: Dict[str, Any],
        metadata: Optional[Dict[str, Any]] = None
    ) -> Checkpoint:
        """
        Creates a new checkpoint.

        Args:
            name: Semantic name for the checkpoint (e.g., "task_started")
            data: Checkpoint data to store
            metadata: Optional additional metadata

        Returns:
            The created checkpoint
        """
        self._checkpoint_count += 1
        checkpoint = Checkpoint(
            id=f"{self.session_id}:{self._checkpoint_count}",
            name=name,
            timestamp=datetime.now(timezone.utc),
            data=data,
            metadata=metadata or {}
        )
        self._checkpoints.append(checkpoint)
        # Compress if needed
        if len(self._checkpoints) > self.compress_after:
            self._compress_checkpoints()
        # Bug fix: max_checkpoints was previously accepted but never applied.
        self._enforce_max_checkpoints()
        # Trigger callback
        if self._on_checkpoint:
            self._on_checkpoint(checkpoint)
        logger.debug(
            f"Session {self.session_id}: Created checkpoint '{name}' "
            f"(#{self._checkpoint_count})"
        )
        return checkpoint

    def get(self, checkpoint_id: str) -> Optional[Checkpoint]:
        """
        Gets a checkpoint by ID.

        Args:
            checkpoint_id: The checkpoint ID

        Returns:
            The checkpoint or None if not found
        """
        for cp in self._checkpoints:
            if cp.id == checkpoint_id:
                return cp
        return None

    def get_by_name(self, name: str) -> List[Checkpoint]:
        """
        Gets all checkpoints with a given name.

        Args:
            name: The checkpoint name

        Returns:
            List of matching checkpoints (newest first)
        """
        return [
            cp for cp in reversed(self._checkpoints)
            if cp.name == name
        ]

    def get_latest(self, name: Optional[str] = None) -> Optional[Checkpoint]:
        """
        Gets the latest checkpoint, optionally filtered by name.

        Args:
            name: Optional name filter

        Returns:
            The latest matching checkpoint or None
        """
        if not self._checkpoints:
            return None
        if name:
            matching = self.get_by_name(name)
            return matching[0] if matching else None
        return self._checkpoints[-1]

    def get_all(self) -> List[Checkpoint]:
        """Returns a shallow copy of all retained checkpoints."""
        return list(self._checkpoints)

    def get_since(self, timestamp: datetime) -> List[Checkpoint]:
        """
        Gets all checkpoints strictly after a given timestamp.

        Args:
            timestamp: The starting timestamp (exclusive)

        Returns:
            List of checkpoints after the timestamp
        """
        return [
            cp for cp in self._checkpoints
            if cp.timestamp > timestamp
        ]

    def get_between(
        self,
        start: datetime,
        end: datetime
    ) -> List[Checkpoint]:
        """
        Gets checkpoints between two timestamps (both ends inclusive).

        Args:
            start: Start timestamp
            end: End timestamp

        Returns:
            List of checkpoints in the range
        """
        return [
            cp for cp in self._checkpoints
            if start <= cp.timestamp <= end
        ]

    def rollback_to(self, checkpoint_id: str) -> Optional[Dict[str, Any]]:
        """
        Gets data needed to rollback to a checkpoint.

        Note: This doesn't actually rollback - it returns the checkpoint
        data for the caller to use for recovery.

        Args:
            checkpoint_id: The checkpoint to rollback to

        Returns:
            The checkpoint data or None if not found
        """
        checkpoint = self.get(checkpoint_id)
        if checkpoint:
            logger.info(
                f"Session {self.session_id}: Rollback to checkpoint "
                f"'{checkpoint.name}' ({checkpoint_id})"
            )
            return checkpoint.data
        return None

    def clear(self) -> int:
        """
        Clears all checkpoints.

        Returns:
            Number of checkpoints cleared
        """
        count = len(self._checkpoints)
        self._checkpoints.clear()
        logger.info(f"Session {self.session_id}: Cleared {count} checkpoints")
        return count

    def _enforce_max_checkpoints(self) -> None:
        """
        Drops the oldest checkpoints when more than max_checkpoints are
        retained. The initial checkpoint (session start) is always kept.
        """
        excess = len(self._checkpoints) - self.max_checkpoints
        if excess > 0 and len(self._checkpoints) > 1:
            # Delete the oldest entries *after* the first checkpoint.
            del self._checkpoints[1:1 + excess]

    def _compress_checkpoints(self) -> None:
        """
        Compresses old checkpoints to save memory.

        Keeps:
        - First checkpoint (session start)
        - Last _RECENT_KEEP checkpoints (recent history)
        - One checkpoint per unique name from the middle (latest)
        """
        if len(self._checkpoints) <= self.compress_after:
            return
        # Keep first checkpoint
        first = self._checkpoints[0]
        # Keep the most recent checkpoints
        recent = self._checkpoints[-self._RECENT_KEEP:]
        # Keep one of each unique name from the middle
        middle = self._checkpoints[1:-self._RECENT_KEEP]
        by_name: Dict[str, Checkpoint] = {}
        for cp in middle:
            # Keep the latest of each name
            if cp.name not in by_name or cp.timestamp > by_name[cp.name].timestamp:
                by_name[cp.name] = cp
        # Bug fix: when the list is short, `recent` overlaps index 0, so the
        # first checkpoint used to be duplicated. Filter by identity.
        compressed = [first] + list(by_name.values()) + [
            cp for cp in recent if cp is not first
        ]
        compressed.sort(key=lambda cp: cp.timestamp)
        old_count = len(self._checkpoints)
        self._checkpoints = compressed
        logger.debug(
            f"Session {self.session_id}: Compressed checkpoints "
            f"from {old_count} to {len(self._checkpoints)}"
        )

    def get_summary(self) -> Dict[str, Any]:
        """
        Gets a summary of checkpoint activity.

        Returns:
            Summary dict with counts and timing info
        """
        if not self._checkpoints:
            return {
                "total_count": 0,
                "unique_names": 0,
                "names": {},
                "first_checkpoint": None,
                "last_checkpoint": None,
                "duration_seconds": 0
            }
        name_counts: Dict[str, int] = {}
        for cp in self._checkpoints:
            name_counts[cp.name] = name_counts.get(cp.name, 0) + 1
        first = self._checkpoints[0]
        last = self._checkpoints[-1]
        return {
            "total_count": len(self._checkpoints),
            "unique_names": len(name_counts),
            "names": name_counts,
            "first_checkpoint": first.to_dict(),
            "last_checkpoint": last.to_dict(),
            "duration_seconds": (last.timestamp - first.timestamp).total_seconds()
        }

    def on_checkpoint(self, callback: Callable[[Checkpoint], None]) -> None:
        """
        Sets a callback to be called on each checkpoint.

        Args:
            callback: Function to call with each checkpoint
        """
        self._on_checkpoint = callback

    def export(self) -> str:
        """
        Exports all checkpoints to JSON.

        Returns:
            JSON string of all checkpoints
        """
        return json.dumps(
            [cp.to_dict() for cp in self._checkpoints],
            indent=2
        )

    def import_checkpoints(self, json_data: str) -> int:
        """
        Imports checkpoints from JSON, appending them to the current list.

        Args:
            json_data: JSON string of checkpoints

        Returns:
            Number of checkpoints imported
        """
        data = json.loads(json_data)
        imported = [Checkpoint.from_dict(cp) for cp in data]
        self._checkpoints.extend(imported)
        # Best-effort guard against ID collisions with future create() calls.
        self._checkpoint_count = max(
            self._checkpoint_count,
            len(self._checkpoints)
        )
        return len(imported)

    def __len__(self) -> int:
        return len(self._checkpoints)

    def __iter__(self):
        return iter(self._checkpoints)

View File

@@ -0,0 +1,361 @@
"""
Heartbeat Monitoring for Breakpilot Agents
Provides liveness monitoring for agents with:
- Configurable timeout thresholds
- Async background monitoring
- Callback-based timeout handling
- Integration with SessionManager
"""
import asyncio
from typing import Dict, Callable, Optional, Awaitable, Set
from datetime import datetime, timezone, timedelta
from dataclasses import dataclass, field
import logging
logger = logging.getLogger(__name__)
@dataclass
class HeartbeatEntry:
    """Per-session liveness record tracked by HeartbeatMonitor."""
    # ID of the monitored session.
    session_id: str
    # Agent type string; passed through to the timeout callback for context.
    agent_type: str
    # Timestamp (UTC) of the most recent heartbeat received.
    last_beat: datetime
    # Consecutive monitor passes that found this entry stale; reset by beat().
    missed_beats: int = 0
class HeartbeatMonitor:
    """
    Monitors agent heartbeats and triggers callbacks on timeout.

    Sessions are registered via register(); agents report liveness via
    beat(). A background task (start_monitoring) periodically scans for
    entries whose last beat is older than the timeout and, after
    max_missed_beats consecutive stale checks, awaits `on_timeout` and
    drops the session from tracking.

    Usage:
        monitor = HeartbeatMonitor(timeout_seconds=30)
        monitor.on_timeout = handle_timeout
        await monitor.start_monitoring()
    """
    def __init__(
        self,
        timeout_seconds: int = 30,
        check_interval_seconds: int = 5,
        max_missed_beats: int = 3
    ):
        """
        Initialize the heartbeat monitor.

        Args:
            timeout_seconds: Time without heartbeat before considered stale
            check_interval_seconds: How often to check for stale sessions
            max_missed_beats: Number of missed beats before triggering timeout
        """
        self.sessions: Dict[str, HeartbeatEntry] = {}
        self.timeout = timedelta(seconds=timeout_seconds)
        self.check_interval = check_interval_seconds
        self.max_missed_beats = max_missed_beats
        # Async callback invoked as on_timeout(session_id, agent_type).
        self.on_timeout: Optional[Callable[[str, str], Awaitable[None]]] = None
        # Async callback invoked as on_warning(session_id, missed_beats);
        # only fired on the first missed beat (see _check_heartbeats).
        self.on_warning: Optional[Callable[[str, int], Awaitable[None]]] = None
        self._running = False
        self._task: Optional[asyncio.Task] = None
        # Session IDs temporarily excluded from staleness checks.
        self._paused_sessions: Set[str] = set()
    async def start_monitoring(self) -> None:
        """
        Starts the background heartbeat monitoring task.
        This runs indefinitely until stop_monitoring() is called.
        Idempotent: a second call while running only logs a warning.
        """
        if self._running:
            logger.warning("Heartbeat monitor already running")
            return
        self._running = True
        self._task = asyncio.create_task(self._monitoring_loop())
        # NOTE(review): `.seconds` is the seconds *component* of the
        # timedelta; fine for timeouts under a day — confirm that holds.
        logger.info(
            f"Heartbeat monitor started (timeout={self.timeout.seconds}s, "
            f"interval={self.check_interval}s)"
        )
    async def stop_monitoring(self) -> None:
        """Stops the heartbeat monitoring task"""
        self._running = False
        if self._task:
            self._task.cancel()
            try:
                await self._task
            except asyncio.CancelledError:
                pass
            self._task = None
        logger.info("Heartbeat monitor stopped")
    async def _monitoring_loop(self) -> None:
        """Main monitoring loop: sleep one interval, then scan all entries."""
        while self._running:
            try:
                await asyncio.sleep(self.check_interval)
                await self._check_heartbeats()
            except asyncio.CancelledError:
                break
            except Exception as e:
                # Keep the loop alive on unexpected scan errors.
                logger.error(f"Error in heartbeat monitoring: {e}")
    async def _check_heartbeats(self) -> None:
        """
        Checks all registered sessions for stale heartbeats.

        NOTE(review): missed_beats is incremented once per monitoring pass
        while an entry stays stale (last_beat is not touched here), so the
        effective time-to-timeout is roughly
        timeout + (max_missed_beats - 1) * check_interval — confirm this
        matches the intended semantics.
        """
        now = datetime.now(timezone.utc)
        timed_out = []
        # Iterate over a snapshot so entries can be deleted below.
        for session_id, entry in list(self.sessions.items()):
            # Skip paused sessions
            if session_id in self._paused_sessions:
                continue
            time_since_beat = now - entry.last_beat
            if time_since_beat > self.timeout:
                entry.missed_beats += 1
                # Warn on first missed beat
                if entry.missed_beats == 1 and self.on_warning:
                    await self.on_warning(session_id, entry.missed_beats)
                logger.warning(
                    f"Session {session_id} missed heartbeat "
                    f"({entry.missed_beats}/{self.max_missed_beats})"
                )
                # Timeout after max missed beats
                if entry.missed_beats >= self.max_missed_beats:
                    timed_out.append((session_id, entry.agent_type))
        # Handle timeouts
        for session_id, agent_type in timed_out:
            logger.error(
                f"Session {session_id} ({agent_type}) timed out after "
                f"{self.max_missed_beats} missed heartbeats"
            )
            if self.on_timeout:
                try:
                    await self.on_timeout(session_id, agent_type)
                except Exception as e:
                    # A failing handler must not stop cleanup of the entry.
                    logger.error(f"Error in timeout handler: {e}")
            # Remove from tracking
            del self.sessions[session_id]
            self._paused_sessions.discard(session_id)
    def register(self, session_id: str, agent_type: str) -> None:
        """
        Registers a session for heartbeat monitoring.

        Re-registering an existing session_id replaces its entry and
        resets its beat timer.

        Args:
            session_id: The session ID to monitor
            agent_type: The type of agent
        """
        self.sessions[session_id] = HeartbeatEntry(
            session_id=session_id,
            agent_type=agent_type,
            last_beat=datetime.now(timezone.utc)
        )
        logger.debug(f"Registered session {session_id} for heartbeat monitoring")
    def beat(self, session_id: str) -> bool:
        """
        Records a heartbeat for a session, resetting its missed-beat count.

        Args:
            session_id: The session ID

        Returns:
            True if the session is registered, False otherwise
        """
        if session_id in self.sessions:
            self.sessions[session_id].last_beat = datetime.now(timezone.utc)
            self.sessions[session_id].missed_beats = 0
            return True
        return False
    def unregister(self, session_id: str) -> bool:
        """
        Unregisters a session from heartbeat monitoring.

        Args:
            session_id: The session ID to unregister

        Returns:
            True if the session was registered, False otherwise
        """
        self._paused_sessions.discard(session_id)
        if session_id in self.sessions:
            del self.sessions[session_id]
            logger.debug(f"Unregistered session {session_id} from heartbeat monitoring")
            return True
        return False
    def pause(self, session_id: str) -> bool:
        """
        Pauses heartbeat monitoring for a session.
        Useful when a session is intentionally idle (e.g., waiting for user input).

        Args:
            session_id: The session ID to pause

        Returns:
            True if the session was registered, False otherwise
        """
        if session_id in self.sessions:
            self._paused_sessions.add(session_id)
            logger.debug(f"Paused heartbeat monitoring for session {session_id}")
            return True
        return False
    def resume(self, session_id: str) -> bool:
        """
        Resumes heartbeat monitoring for a paused session.

        Args:
            session_id: The session ID to resume

        Returns:
            True if the session was paused, False otherwise
        """
        if session_id in self._paused_sessions:
            self._paused_sessions.discard(session_id)
            # Reset the heartbeat timer so the pause itself doesn't count
            # as missed time and trigger an immediate timeout.
            self.beat(session_id)
            logger.debug(f"Resumed heartbeat monitoring for session {session_id}")
            return True
        return False
    def get_status(self, session_id: str) -> Optional[Dict]:
        """
        Gets the heartbeat status for a session.

        Args:
            session_id: The session ID

        Returns:
            Status dict or None if not registered
        """
        if session_id not in self.sessions:
            return None
        entry = self.sessions[session_id]
        now = datetime.now(timezone.utc)
        return {
            "session_id": session_id,
            "agent_type": entry.agent_type,
            "last_beat": entry.last_beat.isoformat(),
            "seconds_since_beat": (now - entry.last_beat).total_seconds(),
            "missed_beats": entry.missed_beats,
            "is_paused": session_id in self._paused_sessions,
            "is_healthy": entry.missed_beats == 0
        }
    def get_all_status(self) -> Dict[str, Dict]:
        """
        Gets heartbeat status for all registered sessions.

        Values are never None here since iteration covers only
        registered session IDs.

        Returns:
            Dict mapping session_id to status dict
        """
        return {
            session_id: self.get_status(session_id)
            for session_id in self.sessions
        }
    @property
    def registered_count(self) -> int:
        """Returns the number of registered sessions"""
        return len(self.sessions)
    @property
    def healthy_count(self) -> int:
        """Returns the number of healthy sessions (no missed beats)"""
        return sum(
            1 for entry in self.sessions.values()
            if entry.missed_beats == 0
        )
class HeartbeatClient:
    """
    Client-side heartbeat sender for agents.

    Periodically reports liveness for one session to a local
    HeartbeatMonitor (in-process agents). HTTP delivery for distributed
    agents is a planned extension.

    Usage:
        client = HeartbeatClient(session_id, monitor)
        await client.start()
        # ... agent work ...
        await client.stop()

    Also usable as an async context manager (``async with client: ...``).
    """
    def __init__(
        self,
        session_id: str,
        monitor: Optional["HeartbeatMonitor"] = None,
        interval_seconds: int = 10
    ):
        """
        Initialize the heartbeat client.

        Args:
            session_id: The session ID to send heartbeats for
            monitor: Optional local HeartbeatMonitor (for in-process agents)
            interval_seconds: How often to send heartbeats
        """
        self.session_id = session_id
        self.monitor = monitor
        self.interval = interval_seconds
        self._task: Optional[asyncio.Task] = None
        self._running = False

    async def start(self) -> None:
        """Starts sending heartbeats (no-op if already running)."""
        if self._running:
            return
        self._running = True
        self._task = asyncio.create_task(self._heartbeat_loop())
        logging.getLogger(__name__).debug(
            f"Heartbeat client started for session {self.session_id}"
        )

    async def stop(self) -> None:
        """Stops sending heartbeats and waits for the sender task to end."""
        self._running = False
        task, self._task = self._task, None
        if task is not None:
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass
        logging.getLogger(__name__).debug(
            f"Heartbeat client stopped for session {self.session_id}"
        )

    async def _heartbeat_loop(self) -> None:
        """Send one heartbeat per interval until cancelled or stopped."""
        while self._running:
            try:
                await self._send_heartbeat()
                await asyncio.sleep(self.interval)
            except asyncio.CancelledError:
                return
            except Exception as exc:
                logging.getLogger(__name__).error(
                    f"Error sending heartbeat: {exc}"
                )
                await asyncio.sleep(self.interval)

    async def _send_heartbeat(self) -> None:
        """Deliver a single heartbeat to the configured sink."""
        if self.monitor:
            # In-process delivery to the local monitor.
            self.monitor.beat(self.session_id)
        # Future: Add HTTP-based heartbeat for distributed agents

    async def __aenter__(self):
        """Start heartbeats on context entry."""
        await self.start()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Stop heartbeats on context exit."""
        await self.stop()

View File

@@ -0,0 +1,540 @@
"""
Session Management for Breakpilot Agents
Provides session lifecycle management with:
- State tracking (ACTIVE, PAUSED, COMPLETED, FAILED)
- Checkpoint-based recovery
- Heartbeat integration
- Hybrid Valkey + PostgreSQL persistence
"""
from dataclasses import dataclass, field
from datetime import datetime, timezone, timedelta
from typing import Dict, Any, Optional, List
from enum import Enum
import uuid
import json
import logging
logger = logging.getLogger(__name__)


class SessionState(Enum):
    """Agent session lifecycle states."""
    ACTIVE = "active"
    PAUSED = "paused"
    COMPLETED = "completed"
    FAILED = "failed"


@dataclass
class SessionCheckpoint:
    """A named recovery point within an agent session."""
    # Semantic name, e.g. "task_received".
    name: str
    # UTC creation time.
    timestamp: datetime
    # Arbitrary state payload stored at checkpoint time.
    data: Dict[str, Any]

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a JSON-compatible dict."""
        return {
            "name": self.name,
            "timestamp": self.timestamp.isoformat(),
            "data": self.data
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "SessionCheckpoint":
        """Rebuild a checkpoint from a to_dict() payload."""
        return cls(
            name=data["name"],
            timestamp=datetime.fromisoformat(data["timestamp"]),
            data=data["data"]
        )


@dataclass
class AgentSession:
    """
    Represents an active agent session.

    Attributes:
        session_id: Unique session identifier
        agent_type: Type of agent (tutor, grader, quality-judge, alert, orchestrator)
        user_id: Associated user ID
        state: Current session state
        created_at: Session creation timestamp
        last_heartbeat: Last heartbeat timestamp
        context: Session context data
        checkpoints: List of session checkpoints for recovery
        metadata: Additional free-form metadata
    """
    session_id: str = field(default_factory=lambda: str(uuid.uuid4()))
    agent_type: str = ""
    user_id: str = ""
    state: SessionState = SessionState.ACTIVE
    created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
    last_heartbeat: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
    context: Dict[str, Any] = field(default_factory=dict)
    checkpoints: List[SessionCheckpoint] = field(default_factory=list)
    metadata: Dict[str, Any] = field(default_factory=dict)

    def checkpoint(self, name: str, data: Dict[str, Any]) -> SessionCheckpoint:
        """
        Creates a checkpoint for recovery.

        Args:
            name: Checkpoint name (e.g., "task_received", "processing_complete")
            data: Checkpoint data to store

        Returns:
            The created checkpoint
        """
        checkpoint = SessionCheckpoint(
            name=name,
            timestamp=datetime.now(timezone.utc),
            data=data
        )
        self.checkpoints.append(checkpoint)
        logger.debug(f"Session {self.session_id}: Checkpoint '{name}' created")
        return checkpoint

    def heartbeat(self) -> None:
        """Updates the heartbeat timestamp"""
        self.last_heartbeat = datetime.now(timezone.utc)

    def pause(self) -> None:
        """
        Pauses the session.

        Bug fix: the "session_paused" checkpoint now records the state
        the session was actually in before pausing (previously it was
        hard-coded to "active" regardless of the real state).
        """
        previous_state = self.state.value
        self.state = SessionState.PAUSED
        self.checkpoint("session_paused", {"previous_state": previous_state})

    def resume(self) -> None:
        """Resumes a paused session; no-op for any other state."""
        if self.state == SessionState.PAUSED:
            self.state = SessionState.ACTIVE
            self.heartbeat()
            self.checkpoint("session_resumed", {})

    def complete(self, result: Optional[Dict[str, Any]] = None) -> None:
        """Marks the session as completed, storing an optional result."""
        self.state = SessionState.COMPLETED
        self.checkpoint("session_completed", {"result": result or {}})

    def fail(self, error: str, error_details: Optional[Dict[str, Any]] = None) -> None:
        """Marks the session as failed, storing the error and details."""
        self.state = SessionState.FAILED
        self.checkpoint("session_failed", {
            "error": error,
            "details": error_details or {}
        })

    def get_last_checkpoint(self, name: Optional[str] = None) -> Optional[SessionCheckpoint]:
        """
        Gets the last checkpoint, optionally filtered by name.

        Args:
            name: Optional checkpoint name to filter by

        Returns:
            The last matching checkpoint or None
        """
        if not self.checkpoints:
            return None
        if name:
            matching = [cp for cp in self.checkpoints if cp.name == name]
            return matching[-1] if matching else None
        return self.checkpoints[-1]

    def get_duration(self) -> timedelta:
        """
        Returns the session duration.

        For COMPLETED/FAILED sessions the end time is the timestamp of
        the last checkpoint (written by complete()/fail()); otherwise
        the duration runs up to now.
        """
        end_time = datetime.now(timezone.utc)
        if self.state in (SessionState.COMPLETED, SessionState.FAILED):
            last_cp = self.get_last_checkpoint()
            if last_cp:
                end_time = last_cp.timestamp
        return end_time - self.created_at

    def to_dict(self) -> Dict[str, Any]:
        """Serializes the session to a dictionary"""
        return {
            "session_id": self.session_id,
            "agent_type": self.agent_type,
            "user_id": self.user_id,
            "state": self.state.value,
            "created_at": self.created_at.isoformat(),
            "last_heartbeat": self.last_heartbeat.isoformat(),
            "context": self.context,
            "checkpoints": [cp.to_dict() for cp in self.checkpoints],
            "metadata": self.metadata
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "AgentSession":
        """Deserializes a session from a dictionary produced by to_dict()"""
        return cls(
            session_id=data["session_id"],
            agent_type=data["agent_type"],
            user_id=data["user_id"],
            state=SessionState(data["state"]),
            created_at=datetime.fromisoformat(data["created_at"]),
            last_heartbeat=datetime.fromisoformat(data["last_heartbeat"]),
            context=data.get("context", {}),
            checkpoints=[
                SessionCheckpoint.from_dict(cp)
                for cp in data.get("checkpoints", [])
            ],
            metadata=data.get("metadata", {})
        )
class SessionManager:
    """
    Manages agent sessions with hybrid Valkey + PostgreSQL persistence.
    - Valkey: Fast access for active sessions (24h TTL)
    - PostgreSQL: Persistent storage for audit trail and recovery

    Both backends are optional (None): storage calls are best-effort —
    backend errors are logged and swallowed, and with neither backend the
    in-process dict is the only store.
    """
    def __init__(
        self,
        redis_client=None,
        db_pool=None,
        namespace: str = "breakpilot",
        session_ttl_hours: int = 24
    ):
        """
        Initialize the session manager.
        Args:
            redis_client: Async Redis/Valkey client
            db_pool: Async PostgreSQL connection pool
            namespace: Redis key namespace
            session_ttl_hours: Session TTL in Valkey
        """
        self.redis = redis_client
        self.db_pool = db_pool
        self.namespace = namespace
        self.session_ttl = timedelta(hours=session_ttl_hours)
        # NOTE(review): this in-process cache has no TTL or size bound; it
        # only shrinks via delete_session() — confirm that is acceptable.
        self._local_cache: Dict[str, AgentSession] = {}
    def _redis_key(self, session_id: str) -> str:
        """Generate Redis key for session"""
        return f"{self.namespace}:session:{session_id}"
    async def create_session(
        self,
        agent_type: str,
        user_id: str = "",
        context: Optional[Dict[str, Any]] = None,
        metadata: Optional[Dict[str, Any]] = None
    ) -> AgentSession:
        """
        Creates a new agent session, records a "session_created"
        checkpoint, and persists it to all configured storage layers.
        Args:
            agent_type: Type of agent
            user_id: Associated user ID
            context: Initial session context
            metadata: Additional metadata
        Returns:
            The created AgentSession
        """
        session = AgentSession(
            agent_type=agent_type,
            user_id=user_id,
            context=context or {},
            metadata=metadata or {}
        )
        session.checkpoint("session_created", {
            "agent_type": agent_type,
            "user_id": user_id
        })
        await self._persist_session(session)
        logger.info(
            f"Created session {session.session_id} for agent '{agent_type}'"
        )
        return session
    async def get_session(self, session_id: str) -> Optional[AgentSession]:
        """
        Retrieves a session by ID.
        Lookup order:
        1. Local cache
        2. Valkey
        3. PostgreSQL
        Args:
            session_id: Session ID to retrieve
        Returns:
            AgentSession if found, None otherwise
        """
        # Check local cache first
        if session_id in self._local_cache:
            return self._local_cache[session_id]
        # Try Valkey
        session = await self._get_from_valkey(session_id)
        if session:
            self._local_cache[session_id] = session
            return session
        # Fall back to PostgreSQL
        session = await self._get_from_postgres(session_id)
        if session:
            # Re-cache in Valkey for faster access
            await self._cache_in_valkey(session)
            self._local_cache[session_id] = session
            return session
        return None
    async def update_session(self, session: AgentSession) -> None:
        """
        Updates a session in all storage layers.

        Note: implicitly refreshes the session's heartbeat timestamp
        before persisting.
        Args:
            session: The session to update
        """
        session.heartbeat()
        self._local_cache[session.session_id] = session
        await self._persist_session(session)
    async def delete_session(self, session_id: str) -> bool:
        """
        Deletes a session (soft delete in PostgreSQL).

        NOTE(review): without a db_pool this returns True even if the ID
        was never stored — confirm callers don't rely on the return value
        for existence checks in that configuration.
        Args:
            session_id: Session ID to delete
        Returns:
            True if deleted, False if not found
        """
        # Remove from local cache
        self._local_cache.pop(session_id, None)
        # Remove from Valkey
        if self.redis:
            await self.redis.delete(self._redis_key(session_id))
        # Soft delete in PostgreSQL
        if self.db_pool:
            async with self.db_pool.acquire() as conn:
                result = await conn.execute(
                    """
                    UPDATE agent_sessions
                    SET state = 'deleted', updated_at = NOW()
                    WHERE id = $1
                    """,
                    session_id
                )
                # Assumes asyncpg-style status tags ("UPDATE <n>") — confirm.
                return result == "UPDATE 1"
        return True
    async def get_active_sessions(
        self,
        agent_type: Optional[str] = None,
        user_id: Optional[str] = None
    ) -> List[AgentSession]:
        """
        Gets all active sessions, optionally filtered.

        Without a db_pool only the local cache is consulted.
        Args:
            agent_type: Filter by agent type
            user_id: Filter by user ID
        Returns:
            List of active sessions
        """
        if not self.db_pool:
            # Fall back to local cache
            sessions = [
                s for s in self._local_cache.values()
                if s.state == SessionState.ACTIVE
            ]
            if agent_type:
                sessions = [s for s in sessions if s.agent_type == agent_type]
            if user_id:
                sessions = [s for s in sessions if s.user_id == user_id]
            return sessions
        async with self.db_pool.acquire() as conn:
            query = """
                SELECT id, agent_type, user_id, state, context, checkpoints,
                       created_at, updated_at, last_heartbeat
                FROM agent_sessions
                WHERE state = 'active'
            """
            params = []
            if agent_type:
                query += " AND agent_type = $1"
                params.append(agent_type)
            if user_id:
                # Placeholder number depends on whether agent_type was
                # appended first: $1 alone, $2 when combined.
                param_num = len(params) + 1
                query += f" AND user_id = ${param_num}"
                params.append(user_id)
            rows = await conn.fetch(query, *params)
            return [self._row_to_session(row) for row in rows]
    async def cleanup_stale_sessions(
        self,
        max_age: timedelta = timedelta(hours=48)
    ) -> int:
        """
        Cleans up stale sessions (no heartbeat for max_age).

        DB path marks them failed and merges a failure_reason into the
        context jsonb; local path calls fail() on the cached objects
        (they stay in the cache, just no longer ACTIVE).
        Args:
            max_age: Maximum age for stale sessions
        Returns:
            Number of sessions cleaned up
        """
        cutoff = datetime.now(timezone.utc) - max_age
        if not self.db_pool:
            # Clean local cache
            stale = [
                sid for sid, s in self._local_cache.items()
                if s.last_heartbeat < cutoff and s.state == SessionState.ACTIVE
            ]
            for sid in stale:
                session = self._local_cache[sid]
                session.fail("Session timeout - no heartbeat")
            return len(stale)
        async with self.db_pool.acquire() as conn:
            result = await conn.execute(
                """
                UPDATE agent_sessions
                SET state = 'failed',
                    updated_at = NOW(),
                    context = context || '{"failure_reason": "heartbeat_timeout"}'::jsonb
                WHERE state = 'active' AND last_heartbeat < $1
                """,
                cutoff
            )
            # Parse the row count out of the status tag ("UPDATE <n>").
            count = int(result.split()[-1]) if result else 0
            logger.info(f"Cleaned up {count} stale sessions")
            return count
    async def _persist_session(self, session: AgentSession) -> None:
        """Persists session to both Valkey and PostgreSQL"""
        await self._cache_in_valkey(session)
        await self._save_to_postgres(session)
    async def _cache_in_valkey(self, session: AgentSession) -> None:
        """Caches session in Valkey (best-effort; errors only logged)"""
        if not self.redis:
            return
        try:
            await self.redis.setex(
                self._redis_key(session.session_id),
                int(self.session_ttl.total_seconds()),
                json.dumps(session.to_dict())
            )
        except Exception as e:
            logger.warning(f"Failed to cache session in Valkey: {e}")
    async def _get_from_valkey(self, session_id: str) -> Optional[AgentSession]:
        """Retrieves session from Valkey (best-effort; errors only logged)"""
        if not self.redis:
            return None
        try:
            data = await self.redis.get(self._redis_key(session_id))
            if data:
                return AgentSession.from_dict(json.loads(data))
        except Exception as e:
            logger.warning(f"Failed to get session from Valkey: {e}")
        return None
    async def _save_to_postgres(self, session: AgentSession) -> None:
        """
        Upserts the session row in PostgreSQL.

        Column/value mapping: $1..$6 = id..checkpoints, $7 = created_at,
        NOW() = updated_at, $8 = last_heartbeat.
        NOTE(review): `metadata` has no column and is not persisted —
        it is lost on a PostgreSQL round-trip (see _row_to_session).
        """
        if not self.db_pool:
            return
        try:
            async with self.db_pool.acquire() as conn:
                await conn.execute(
                    """
                    INSERT INTO agent_sessions
                        (id, agent_type, user_id, state, context, checkpoints,
                         created_at, updated_at, last_heartbeat)
                    VALUES ($1, $2, $3, $4, $5, $6, $7, NOW(), $8)
                    ON CONFLICT (id) DO UPDATE SET
                        state = EXCLUDED.state,
                        context = EXCLUDED.context,
                        checkpoints = EXCLUDED.checkpoints,
                        updated_at = NOW(),
                        last_heartbeat = EXCLUDED.last_heartbeat
                    """,
                    session.session_id,
                    session.agent_type,
                    session.user_id,
                    session.state.value,
                    json.dumps(session.context),
                    json.dumps([cp.to_dict() for cp in session.checkpoints]),
                    session.created_at,
                    session.last_heartbeat
                )
        except Exception as e:
            logger.error(f"Failed to save session to PostgreSQL: {e}")
    async def _get_from_postgres(self, session_id: str) -> Optional[AgentSession]:
        """Retrieves session from PostgreSQL, excluding soft-deleted rows"""
        if not self.db_pool:
            return None
        try:
            async with self.db_pool.acquire() as conn:
                row = await conn.fetchrow(
                    """
                    SELECT id, agent_type, user_id, state, context, checkpoints,
                           created_at, updated_at, last_heartbeat
                    FROM agent_sessions
                    WHERE id = $1 AND state != 'deleted'
                    """,
                    session_id
                )
                if row:
                    return self._row_to_session(row)
        except Exception as e:
            logger.error(f"Failed to get session from PostgreSQL: {e}")
        return None
    def _row_to_session(self, row) -> AgentSession:
        """
        Converts a database row to AgentSession.

        Handles json/jsonb columns arriving either decoded or as strings.
        NOTE(review): no metadata column is selected, so `metadata`
        always resets to {} on load. Timestamps are assumed tz-aware
        (timestamptz) — confirm against the schema.
        """
        checkpoints_data = row["checkpoints"]
        if isinstance(checkpoints_data, str):
            checkpoints_data = json.loads(checkpoints_data)
        context_data = row["context"]
        if isinstance(context_data, str):
            context_data = json.loads(context_data)
        return AgentSession(
            session_id=str(row["id"]),
            agent_type=row["agent_type"],
            user_id=row["user_id"] or "",
            state=SessionState(row["state"]),
            created_at=row["created_at"],
            last_heartbeat=row["last_heartbeat"],
            context=context_data,
            checkpoints=[
                SessionCheckpoint.from_dict(cp) for cp in checkpoints_data
            ]
        )