This repository has been archived on 2026-02-15. You can view files and clone it. You cannot open issues or pull requests or push a commit.
Files
breakpilot-pwa/agent-core/sessions/session_manager.py
Benjamin Admin bfdaf63ba9 fix: Restore all files lost during destructive rebase
A previous `git pull --rebase origin main` dropped 177 local commits,
losing 3400+ files across admin-v2, backend, studio-v2, website,
klausur-service, and many other services. The partial restore attempt
(660295e2) only recovered some files.

This commit restores all missing files from pre-rebase ref 98933f5e
while preserving post-rebase additions (night-scheduler, night-mode UI,
NightModeWidget dashboard integration).

Restored features include:
- AI Module Sidebar (FAB), OCR Labeling, OCR Compare
- GPU Dashboard, RAG Pipeline, Magic Help
- Klausur-Korrektur (8 files), Abitur-Archiv (5+ files)
- Companion, Zeugnisse-Crawler, Screen Flow
- Full backend, studio-v2, website, klausur-service
- All compliance SDKs, agent-core, voice-service
- CI/CD configs, documentation, scripts

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-09 09:51:32 +01:00

541 lines
17 KiB
Python

"""
Session Management for Breakpilot Agents
Provides session lifecycle management with:
- State tracking (ACTIVE, PAUSED, COMPLETED, FAILED)
- Checkpoint-based recovery
- Heartbeat integration
- Hybrid Valkey + PostgreSQL persistence
"""
from dataclasses import dataclass, field
from datetime import datetime, timezone, timedelta
from typing import Dict, Any, Optional, List
from enum import Enum
import uuid
import json
import logging
logger = logging.getLogger(__name__)
class SessionState(Enum):
    """Agent session states.

    Values mirror the lifecycle methods on AgentSession:
    pause() -> PAUSED, resume() -> ACTIVE, complete() -> COMPLETED,
    fail() -> FAILED.

    NOTE(review): SessionManager.delete_session() writes a fifth state,
    'deleted', directly into PostgreSQL (soft delete) that has no enum
    member here; such rows are filtered out by the SQL queries before
    deserialization, so SessionState('deleted') is never constructed.
    """
    ACTIVE = "active"        # session is running and expected to heartbeat
    PAUSED = "paused"        # temporarily suspended; resume() re-activates
    COMPLETED = "completed"  # finished successfully
    FAILED = "failed"        # aborted with an error (or heartbeat timeout)
@dataclass
class SessionCheckpoint:
    """A named, timestamped snapshot of agent-session progress.

    Checkpoints are appended to a session's history so that work can be
    audited or resumed after a crash; they round-trip through plain
    JSON-compatible dicts via to_dict()/from_dict().
    """
    name: str            # e.g. "task_received", "session_paused"
    timestamp: datetime  # when the checkpoint was taken
    data: Dict[str, Any] # arbitrary JSON-serializable payload

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a JSON-compatible dict (timestamp as ISO-8601)."""
        return dict(
            name=self.name,
            timestamp=self.timestamp.isoformat(),
            data=self.data,
        )

    @classmethod
    def from_dict(cls, payload: Dict[str, Any]) -> "SessionCheckpoint":
        """Inverse of to_dict(); parses the ISO-8601 timestamp back."""
        when = datetime.fromisoformat(payload["timestamp"])
        return cls(name=payload["name"], timestamp=when, data=payload["data"])
@dataclass
class AgentSession:
    """
    Represents an active agent session.

    Attributes:
        session_id: Unique session identifier (random UUID4 by default)
        agent_type: Type of agent (tutor, grader, quality-judge, alert, orchestrator)
        user_id: Associated user ID
        state: Current session state
        created_at: Session creation timestamp (UTC, timezone-aware)
        last_heartbeat: Last heartbeat timestamp (UTC, timezone-aware)
        context: Session context data
        checkpoints: List of session checkpoints for recovery
        metadata: Additional metadata (included in to_dict()/Valkey, but
            not stored in the PostgreSQL row — see SessionManager)
    """
    session_id: str = field(default_factory=lambda: str(uuid.uuid4()))
    agent_type: str = ""
    user_id: str = ""
    state: SessionState = SessionState.ACTIVE
    created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
    last_heartbeat: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
    context: Dict[str, Any] = field(default_factory=dict)
    checkpoints: List[SessionCheckpoint] = field(default_factory=list)
    metadata: Dict[str, Any] = field(default_factory=dict)

    def checkpoint(self, name: str, data: Dict[str, Any]) -> SessionCheckpoint:
        """
        Creates a checkpoint for recovery and appends it to the history.

        Args:
            name: Checkpoint name (e.g., "task_received", "processing_complete")
            data: Checkpoint data to store

        Returns:
            The created checkpoint
        """
        checkpoint = SessionCheckpoint(
            name=name,
            timestamp=datetime.now(timezone.utc),
            data=data
        )
        self.checkpoints.append(checkpoint)
        logger.debug(f"Session {self.session_id}: Checkpoint '{name}' created")
        return checkpoint

    def heartbeat(self) -> None:
        """Updates the heartbeat timestamp to the current UTC time."""
        self.last_heartbeat = datetime.now(timezone.utc)

    def pause(self) -> None:
        """Pauses the session, checkpointing the state it was paused from.

        Bug fix: the previous implementation hard-coded
        {"previous_state": "active"}, which recorded wrong data whenever
        pause() was called from any non-ACTIVE state. The actual prior
        state is now captured before it is overwritten.
        """
        previous_state = self.state.value
        self.state = SessionState.PAUSED
        self.checkpoint("session_paused", {"previous_state": previous_state})

    def resume(self) -> None:
        """Resumes a paused session; a no-op for any other state."""
        if self.state == SessionState.PAUSED:
            self.state = SessionState.ACTIVE
            self.heartbeat()
            self.checkpoint("session_resumed", {})

    def complete(self, result: Optional[Dict[str, Any]] = None) -> None:
        """Marks the session as completed, storing the optional result."""
        self.state = SessionState.COMPLETED
        self.checkpoint("session_completed", {"result": result or {}})

    def fail(self, error: str, error_details: Optional[Dict[str, Any]] = None) -> None:
        """Marks the session as failed, storing the error and details."""
        self.state = SessionState.FAILED
        self.checkpoint("session_failed", {
            "error": error,
            "details": error_details or {}
        })

    def get_last_checkpoint(self, name: Optional[str] = None) -> Optional[SessionCheckpoint]:
        """
        Gets the last checkpoint, optionally filtered by name.

        Args:
            name: Optional checkpoint name to filter by

        Returns:
            The last matching checkpoint or None
        """
        if not self.checkpoints:
            return None
        if name:
            matching = [cp for cp in self.checkpoints if cp.name == name]
            return matching[-1] if matching else None
        return self.checkpoints[-1]

    def get_duration(self) -> timedelta:
        """Returns the session duration.

        For terminal sessions (COMPLETED/FAILED) the end time is the last
        checkpoint's timestamp; otherwise "now" is used.
        """
        end_time = datetime.now(timezone.utc)
        if self.state in (SessionState.COMPLETED, SessionState.FAILED):
            last_cp = self.get_last_checkpoint()
            if last_cp:
                end_time = last_cp.timestamp
        return end_time - self.created_at

    def to_dict(self) -> Dict[str, Any]:
        """Serializes the session to a JSON-compatible dictionary."""
        return {
            "session_id": self.session_id,
            "agent_type": self.agent_type,
            "user_id": self.user_id,
            "state": self.state.value,
            "created_at": self.created_at.isoformat(),
            "last_heartbeat": self.last_heartbeat.isoformat(),
            "context": self.context,
            "checkpoints": [cp.to_dict() for cp in self.checkpoints],
            "metadata": self.metadata
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "AgentSession":
        """Deserializes a session from a dictionary (inverse of to_dict).

        Raises:
            KeyError: if a required key (session_id, agent_type, user_id,
                state, created_at, last_heartbeat) is missing.
            ValueError: if the state string is not a valid SessionState.
        """
        return cls(
            session_id=data["session_id"],
            agent_type=data["agent_type"],
            user_id=data["user_id"],
            state=SessionState(data["state"]),
            created_at=datetime.fromisoformat(data["created_at"]),
            last_heartbeat=datetime.fromisoformat(data["last_heartbeat"]),
            context=data.get("context", {}),
            checkpoints=[
                SessionCheckpoint.from_dict(cp)
                for cp in data.get("checkpoints", [])
            ],
            metadata=data.get("metadata", {})
        )
class SessionManager:
    """
    Manages agent sessions with hybrid Valkey + PostgreSQL persistence.

    - Valkey: Fast access for active sessions (24h TTL)
    - PostgreSQL: Persistent storage for audit trail and recovery

    Both backends are optional: with neither configured, sessions live
    only in the in-process `_local_cache` dict.

    NOTE(review): `_local_cache` is unbounded and per-process. Completed
    and failed sessions are never evicted, and a cached entry can go
    stale if another process updates the same session in Valkey or
    PostgreSQL — confirm this is acceptable for the deployment model.
    """

    def __init__(
        self,
        redis_client=None,
        db_pool=None,
        namespace: str = "breakpilot",
        session_ttl_hours: int = 24
    ):
        """
        Initialize the session manager.

        Args:
            redis_client: Async Redis/Valkey client; must provide async
                get/setex/delete (redis.asyncio-style). Optional.
            db_pool: Async PostgreSQL connection pool. The usage below
                (acquire() context manager, `$n` placeholders, command
                status strings like "UPDATE 1") matches asyncpg.
            namespace: Redis key namespace prefix
            session_ttl_hours: Session TTL in Valkey
        """
        self.redis = redis_client
        self.db_pool = db_pool
        self.namespace = namespace
        self.session_ttl = timedelta(hours=session_ttl_hours)
        # In-process cache; first lookup tier in get_session().
        self._local_cache: Dict[str, AgentSession] = {}

    def _redis_key(self, session_id: str) -> str:
        """Generate the namespaced Redis key for a session."""
        return f"{self.namespace}:session:{session_id}"

    async def create_session(
        self,
        agent_type: str,
        user_id: str = "",
        context: Optional[Dict[str, Any]] = None,
        metadata: Optional[Dict[str, Any]] = None
    ) -> AgentSession:
        """
        Creates a new agent session, records a "session_created"
        checkpoint, and persists it to all configured backends.

        Args:
            agent_type: Type of agent
            user_id: Associated user ID
            context: Initial session context
            metadata: Additional metadata

        Returns:
            The created AgentSession
        """
        session = AgentSession(
            agent_type=agent_type,
            user_id=user_id,
            context=context or {},
            metadata=metadata or {}
        )
        session.checkpoint("session_created", {
            "agent_type": agent_type,
            "user_id": user_id
        })
        await self._persist_session(session)
        logger.info(
            f"Created session {session.session_id} for agent '{agent_type}'"
        )
        return session

    async def get_session(self, session_id: str) -> Optional[AgentSession]:
        """
        Retrieves a session by ID.

        Lookup order:
        1. Local cache
        2. Valkey
        3. PostgreSQL

        Args:
            session_id: Session ID to retrieve

        Returns:
            AgentSession if found, None otherwise
        """
        # Check local cache first (fastest, but may be stale cross-process)
        if session_id in self._local_cache:
            return self._local_cache[session_id]
        # Try Valkey
        session = await self._get_from_valkey(session_id)
        if session:
            self._local_cache[session_id] = session
            return session
        # Fall back to PostgreSQL (e.g. after Valkey TTL expiry or restart)
        session = await self._get_from_postgres(session_id)
        if session:
            # Re-cache in Valkey for faster access
            await self._cache_in_valkey(session)
            self._local_cache[session_id] = session
            return session
        return None

    async def update_session(self, session: AgentSession) -> None:
        """
        Updates a session in all storage layers.

        Note: the heartbeat is refreshed unconditionally, even for
        sessions already in a terminal state.

        Args:
            session: The session to update
        """
        session.heartbeat()
        self._local_cache[session.session_id] = session
        await self._persist_session(session)

    async def delete_session(self, session_id: str) -> bool:
        """
        Deletes a session (hard delete in cache/Valkey, soft delete in
        PostgreSQL via state = 'deleted').

        Args:
            session_id: Session ID to delete

        Returns:
            True if deleted, False if not found. Only the PostgreSQL
            path can report "not found"; without a db_pool this always
            returns True.
        """
        # Remove from local cache
        self._local_cache.pop(session_id, None)
        # Remove from Valkey
        if self.redis:
            await self.redis.delete(self._redis_key(session_id))
        # Soft delete in PostgreSQL
        if self.db_pool:
            async with self.db_pool.acquire() as conn:
                result = await conn.execute(
                    """
                    UPDATE agent_sessions
                    SET state = 'deleted', updated_at = NOW()
                    WHERE id = $1
                    """,
                    session_id
                )
                # asyncpg returns a command-status string, e.g. "UPDATE 1"
                return result == "UPDATE 1"
        return True

    async def get_active_sessions(
        self,
        agent_type: Optional[str] = None,
        user_id: Optional[str] = None
    ) -> List[AgentSession]:
        """
        Gets all active sessions, optionally filtered.

        Without a db_pool only the in-process cache is consulted, so
        results are per-process in that mode.

        Args:
            agent_type: Filter by agent type
            user_id: Filter by user ID

        Returns:
            List of active sessions
        """
        if not self.db_pool:
            # Fall back to local cache
            sessions = [
                s for s in self._local_cache.values()
                if s.state == SessionState.ACTIVE
            ]
            if agent_type:
                sessions = [s for s in sessions if s.agent_type == agent_type]
            if user_id:
                sessions = [s for s in sessions if s.user_id == user_id]
            return sessions
        async with self.db_pool.acquire() as conn:
            query = """
                SELECT id, agent_type, user_id, state, context, checkpoints,
                       created_at, updated_at, last_heartbeat
                FROM agent_sessions
                WHERE state = 'active'
            """
            params = []
            if agent_type:
                # $1 is safe only because agent_type is always the first
                # parameter appended; user_id below computes its number.
                query += " AND agent_type = $1"
                params.append(agent_type)
            if user_id:
                param_num = len(params) + 1
                query += f" AND user_id = ${param_num}"
                params.append(user_id)
            rows = await conn.fetch(query, *params)
            return [self._row_to_session(row) for row in rows]

    async def cleanup_stale_sessions(
        self,
        max_age: timedelta = timedelta(hours=48)
    ) -> int:
        """
        Cleans up stale sessions (no heartbeat for max_age) by marking
        them FAILED.

        Args:
            max_age: Maximum age for stale sessions

        Returns:
            Number of sessions cleaned up
        """
        cutoff = datetime.now(timezone.utc) - max_age
        if not self.db_pool:
            # Clean local cache: mark stale ACTIVE sessions as failed.
            # NOTE(review): entries stay in _local_cache after failing —
            # confirm eviction is intentional/handled elsewhere.
            stale = [
                sid for sid, s in self._local_cache.items()
                if s.last_heartbeat < cutoff and s.state == SessionState.ACTIVE
            ]
            for sid in stale:
                session = self._local_cache[sid]
                session.fail("Session timeout - no heartbeat")
            return len(stale)
        async with self.db_pool.acquire() as conn:
            result = await conn.execute(
                """
                UPDATE agent_sessions
                SET state = 'failed',
                    updated_at = NOW(),
                    context = context || '{"failure_reason": "heartbeat_timeout"}'::jsonb
                WHERE state = 'active' AND last_heartbeat < $1
                """,
                cutoff
            )
            # Parse the row count out of the "UPDATE n" status string
            count = int(result.split()[-1]) if result else 0
            logger.info(f"Cleaned up {count} stale sessions")
            return count

    async def _persist_session(self, session: AgentSession) -> None:
        """Persists session to both Valkey and PostgreSQL (best-effort)."""
        await self._cache_in_valkey(session)
        await self._save_to_postgres(session)

    async def _cache_in_valkey(self, session: AgentSession) -> None:
        """Caches the JSON-serialized session in Valkey with the
        configured TTL. Failures are logged and swallowed (best-effort:
        PostgreSQL remains the durable store)."""
        if not self.redis:
            return
        try:
            await self.redis.setex(
                self._redis_key(session.session_id),
                int(self.session_ttl.total_seconds()),
                json.dumps(session.to_dict())
            )
        except Exception as e:
            logger.warning(f"Failed to cache session in Valkey: {e}")

    async def _get_from_valkey(self, session_id: str) -> Optional[AgentSession]:
        """Retrieves and deserializes a session from Valkey; returns
        None on miss, no client, or any error (logged)."""
        if not self.redis:
            return None
        try:
            data = await self.redis.get(self._redis_key(session_id))
            if data:
                return AgentSession.from_dict(json.loads(data))
        except Exception as e:
            logger.warning(f"Failed to get session from Valkey: {e}")
        return None

    async def _save_to_postgres(self, session: AgentSession) -> None:
        """Upserts the session row in PostgreSQL.

        Note: session.metadata is NOT written (no column in the insert),
        so it survives only in Valkey / the local cache. Errors are
        logged and swallowed (best-effort persistence).
        """
        if not self.db_pool:
            return
        try:
            async with self.db_pool.acquire() as conn:
                await conn.execute(
                    """
                    INSERT INTO agent_sessions
                        (id, agent_type, user_id, state, context, checkpoints,
                         created_at, updated_at, last_heartbeat)
                    VALUES ($1, $2, $3, $4, $5, $6, $7, NOW(), $8)
                    ON CONFLICT (id) DO UPDATE SET
                        state = EXCLUDED.state,
                        context = EXCLUDED.context,
                        checkpoints = EXCLUDED.checkpoints,
                        updated_at = NOW(),
                        last_heartbeat = EXCLUDED.last_heartbeat
                    """,
                    session.session_id,
                    session.agent_type,
                    session.user_id,
                    session.state.value,
                    json.dumps(session.context),
                    json.dumps([cp.to_dict() for cp in session.checkpoints]),
                    session.created_at,
                    session.last_heartbeat
                )
        except Exception as e:
            logger.error(f"Failed to save session to PostgreSQL: {e}")

    async def _get_from_postgres(self, session_id: str) -> Optional[AgentSession]:
        """Retrieves a session row from PostgreSQL, skipping soft-deleted
        rows; returns None on miss, no pool, or any error (logged)."""
        if not self.db_pool:
            return None
        try:
            async with self.db_pool.acquire() as conn:
                row = await conn.fetchrow(
                    """
                    SELECT id, agent_type, user_id, state, context, checkpoints,
                           created_at, updated_at, last_heartbeat
                    FROM agent_sessions
                    WHERE id = $1 AND state != 'deleted'
                    """,
                    session_id
                )
                if row:
                    return self._row_to_session(row)
        except Exception as e:
            logger.error(f"Failed to get session from PostgreSQL: {e}")
        return None

    def _row_to_session(self, row) -> AgentSession:
        """Converts a database row to AgentSession.

        The jsonb columns may arrive either as already-decoded Python
        objects or as JSON strings depending on the connection's codec
        configuration, so both forms are handled. metadata is not
        reconstructed (it is not stored in the table) and defaults to {}.
        """
        checkpoints_data = row["checkpoints"]
        if isinstance(checkpoints_data, str):
            checkpoints_data = json.loads(checkpoints_data)
        context_data = row["context"]
        if isinstance(context_data, str):
            context_data = json.loads(context_data)
        return AgentSession(
            session_id=str(row["id"]),
            agent_type=row["agent_type"],
            user_id=row["user_id"] or "",
            state=SessionState(row["state"]),
            created_at=row["created_at"],
            last_heartbeat=row["last_heartbeat"],
            context=context_data,
            checkpoints=[
                SessionCheckpoint.from_dict(cp) for cp in checkpoints_data
            ]
        )