A previous `git pull --rebase origin main` dropped 177 local commits,
losing 3400+ files across admin-v2, backend, studio-v2, website,
klausur-service, and many other services. The partial restore attempt
(660295e2) only recovered some files.
This commit restores all missing files from pre-rebase ref 98933f5e
while preserving post-rebase additions (night-scheduler, night-mode UI,
NightModeWidget dashboard integration).
Restored features include:
- AI Module Sidebar (FAB), OCR Labeling, OCR Compare
- GPU Dashboard, RAG Pipeline, Magic Help
- Klausur-Korrektur (8 files), Abitur-Archiv (5+ files)
- Companion, Zeugnisse-Crawler, Screen Flow
- Full backend, studio-v2, website, klausur-service
- All compliance SDKs, agent-core, voice-service
- CI/CD configs, documentation, scripts
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
541 lines
17 KiB
Python
"""
|
|
Session Management for Breakpilot Agents
|
|
|
|
Provides session lifecycle management with:
|
|
- State tracking (ACTIVE, PAUSED, COMPLETED, FAILED)
|
|
- Checkpoint-based recovery
|
|
- Heartbeat integration
|
|
- Hybrid Valkey + PostgreSQL persistence
|
|
"""
|
|
|
|
from dataclasses import dataclass, field
|
|
from datetime import datetime, timezone, timedelta
|
|
from typing import Dict, Any, Optional, List
|
|
from enum import Enum
|
|
import uuid
|
|
import json
|
|
import logging
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class SessionState(Enum):
    """Lifecycle states an agent session can be in."""

    ACTIVE = "active"        # running; expected to send heartbeats
    PAUSED = "paused"        # temporarily suspended; may be resumed
    COMPLETED = "completed"  # finished successfully (terminal)
    FAILED = "failed"        # aborted with an error (terminal)
|
|
|
|
|
|
@dataclass
class SessionCheckpoint:
    """A named, timestamped snapshot of session state used for recovery."""

    name: str              # checkpoint label, e.g. "task_received"
    timestamp: datetime    # when the checkpoint was taken
    data: Dict[str, Any]   # arbitrary payload captured at this point

    def to_dict(self) -> Dict[str, Any]:
        """Serialize this checkpoint to a JSON-compatible dictionary."""
        serialized = {
            "name": self.name,
            "timestamp": self.timestamp.isoformat(),
            "data": self.data,
        }
        return serialized

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "SessionCheckpoint":
        """Rebuild a checkpoint from its dictionary representation."""
        restored_ts = datetime.fromisoformat(data["timestamp"])
        return cls(name=data["name"], timestamp=restored_ts, data=data["data"])
|
|
|
|
|
|
@dataclass
class AgentSession:
    """
    Represents an active agent session.

    Attributes:
        session_id: Unique session identifier
        agent_type: Type of agent (tutor, grader, quality-judge, alert, orchestrator)
        user_id: Associated user ID
        state: Current session state
        created_at: Session creation timestamp
        last_heartbeat: Last heartbeat timestamp
        context: Session context data
        checkpoints: List of session checkpoints for recovery
        metadata: Additional free-form metadata
    """
    session_id: str = field(default_factory=lambda: str(uuid.uuid4()))
    agent_type: str = ""
    user_id: str = ""
    state: SessionState = SessionState.ACTIVE
    created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
    last_heartbeat: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
    context: Dict[str, Any] = field(default_factory=dict)
    checkpoints: List[SessionCheckpoint] = field(default_factory=list)
    metadata: Dict[str, Any] = field(default_factory=dict)

    def checkpoint(self, name: str, data: Dict[str, Any]) -> SessionCheckpoint:
        """
        Creates a checkpoint for recovery.

        Args:
            name: Checkpoint name (e.g., "task_received", "processing_complete")
            data: Checkpoint data to store

        Returns:
            The created checkpoint
        """
        checkpoint = SessionCheckpoint(
            name=name,
            timestamp=datetime.now(timezone.utc),
            data=data
        )
        self.checkpoints.append(checkpoint)
        logger.debug(f"Session {self.session_id}: Checkpoint '{name}' created")
        return checkpoint

    def heartbeat(self) -> None:
        """Updates the heartbeat timestamp"""
        self.last_heartbeat = datetime.now(timezone.utc)

    def pause(self) -> None:
        """Pauses the session"""
        # Fix: record the session's actual prior state. The previous
        # implementation hard-coded "active", which mislabeled pauses
        # taken from any other state in the recovery checkpoint trail.
        previous_state = self.state.value
        self.state = SessionState.PAUSED
        self.checkpoint("session_paused", {"previous_state": previous_state})

    def resume(self) -> None:
        """Resumes a paused session (no-op unless currently PAUSED)"""
        if self.state == SessionState.PAUSED:
            self.state = SessionState.ACTIVE
            self.heartbeat()
            self.checkpoint("session_resumed", {})

    def complete(self, result: Optional[Dict[str, Any]] = None) -> None:
        """Marks the session as completed"""
        self.state = SessionState.COMPLETED
        self.checkpoint("session_completed", {"result": result or {}})

    def fail(self, error: str, error_details: Optional[Dict[str, Any]] = None) -> None:
        """Marks the session as failed"""
        self.state = SessionState.FAILED
        self.checkpoint("session_failed", {
            "error": error,
            "details": error_details or {}
        })

    def get_last_checkpoint(self, name: Optional[str] = None) -> Optional[SessionCheckpoint]:
        """
        Gets the last checkpoint, optionally filtered by name.

        Args:
            name: Optional checkpoint name to filter by

        Returns:
            The last matching checkpoint or None
        """
        if not self.checkpoints:
            return None

        if name:
            matching = [cp for cp in self.checkpoints if cp.name == name]
            return matching[-1] if matching else None

        return self.checkpoints[-1]

    def get_duration(self) -> timedelta:
        """Returns the session duration (up to now, or up to the final
        checkpoint for terminal sessions)"""
        end_time = datetime.now(timezone.utc)
        if self.state in (SessionState.COMPLETED, SessionState.FAILED):
            # Terminal sessions end at their last checkpoint, not "now".
            last_cp = self.get_last_checkpoint()
            if last_cp:
                end_time = last_cp.timestamp
        return end_time - self.created_at

    def to_dict(self) -> Dict[str, Any]:
        """Serializes the session to a JSON-compatible dictionary"""
        return {
            "session_id": self.session_id,
            "agent_type": self.agent_type,
            "user_id": self.user_id,
            "state": self.state.value,
            "created_at": self.created_at.isoformat(),
            "last_heartbeat": self.last_heartbeat.isoformat(),
            "context": self.context,
            "checkpoints": [cp.to_dict() for cp in self.checkpoints],
            "metadata": self.metadata
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "AgentSession":
        """Deserializes a session from a dictionary"""
        return cls(
            session_id=data["session_id"],
            agent_type=data["agent_type"],
            user_id=data["user_id"],
            state=SessionState(data["state"]),
            created_at=datetime.fromisoformat(data["created_at"]),
            last_heartbeat=datetime.fromisoformat(data["last_heartbeat"]),
            context=data.get("context", {}),
            checkpoints=[
                SessionCheckpoint.from_dict(cp)
                for cp in data.get("checkpoints", [])
            ],
            metadata=data.get("metadata", {})
        )
|
|
|
|
|
|
class SessionManager:
    """
    Manages agent sessions with hybrid Valkey + PostgreSQL persistence.

    - Valkey: Fast access for active sessions (24h TTL)
    - PostgreSQL: Persistent storage for audit trail and recovery
    """

    def __init__(
        self,
        redis_client=None,
        db_pool=None,
        namespace: str = "breakpilot",
        session_ttl_hours: int = 24
    ):
        """
        Initialize the session manager.

        Args:
            redis_client: Async Redis/Valkey client
            db_pool: Async PostgreSQL connection pool (asyncpg-style;
                `execute` status tags like "UPDATE 1" are assumed)
            namespace: Redis key namespace
            session_ttl_hours: Session TTL in Valkey
        """
        self.redis = redis_client
        self.db_pool = db_pool
        self.namespace = namespace
        self.session_ttl = timedelta(hours=session_ttl_hours)
        # Per-process cache for fastest lookups. NOTE(review): not shared
        # across workers, so entries can go stale if another process
        # mutates the same session — Valkey/PostgreSQL are authoritative.
        self._local_cache: Dict[str, AgentSession] = {}

    def _redis_key(self, session_id: str) -> str:
        """Generate Redis key for session"""
        return f"{self.namespace}:session:{session_id}"

    async def create_session(
        self,
        agent_type: str,
        user_id: str = "",
        context: Optional[Dict[str, Any]] = None,
        metadata: Optional[Dict[str, Any]] = None
    ) -> AgentSession:
        """
        Creates a new agent session.

        Args:
            agent_type: Type of agent
            user_id: Associated user ID
            context: Initial session context
            metadata: Additional metadata

        Returns:
            The created AgentSession
        """
        session = AgentSession(
            agent_type=agent_type,
            user_id=user_id,
            context=context or {},
            metadata=metadata or {}
        )

        # First checkpoint marks creation so recovery always has an anchor.
        session.checkpoint("session_created", {
            "agent_type": agent_type,
            "user_id": user_id
        })

        await self._persist_session(session)

        logger.info(
            f"Created session {session.session_id} for agent '{agent_type}'"
        )

        return session

    async def get_session(self, session_id: str) -> Optional[AgentSession]:
        """
        Retrieves a session by ID.

        Lookup order:
        1. Local cache
        2. Valkey
        3. PostgreSQL

        Args:
            session_id: Session ID to retrieve

        Returns:
            AgentSession if found, None otherwise
        """
        # Check local cache first
        if session_id in self._local_cache:
            return self._local_cache[session_id]

        # Try Valkey
        session = await self._get_from_valkey(session_id)
        if session:
            self._local_cache[session_id] = session
            return session

        # Fall back to PostgreSQL
        session = await self._get_from_postgres(session_id)
        if session:
            # Re-cache in Valkey for faster access
            await self._cache_in_valkey(session)
            self._local_cache[session_id] = session
            return session

        return None

    async def update_session(self, session: AgentSession) -> None:
        """
        Updates a session in all storage layers.

        Side effect: refreshes the session's heartbeat timestamp.

        Args:
            session: The session to update
        """
        session.heartbeat()
        self._local_cache[session.session_id] = session
        await self._persist_session(session)

    async def delete_session(self, session_id: str) -> bool:
        """
        Deletes a session (soft delete in PostgreSQL).

        Args:
            session_id: Session ID to delete

        Returns:
            True if deleted, False if not found
        """
        # Fix: track whether any storage layer actually held the session,
        # so the return value matches the documented "False if not found"
        # contract (previously this returned True unconditionally when no
        # DB pool was configured, and ignored cache/Valkey hits otherwise).
        found = self._local_cache.pop(session_id, None) is not None

        # Remove from Valkey (DEL returns the number of keys removed)
        if self.redis:
            removed = await self.redis.delete(self._redis_key(session_id))
            found = found or bool(removed)

        # Soft delete in PostgreSQL
        if self.db_pool:
            async with self.db_pool.acquire() as conn:
                result = await conn.execute(
                    """
                    UPDATE agent_sessions
                    SET state = 'deleted', updated_at = NOW()
                    WHERE id = $1
                    """,
                    session_id
                )
                return result == "UPDATE 1" or found

        return found

    async def get_active_sessions(
        self,
        agent_type: Optional[str] = None,
        user_id: Optional[str] = None
    ) -> List[AgentSession]:
        """
        Gets all active sessions, optionally filtered.

        Args:
            agent_type: Filter by agent type
            user_id: Filter by user ID

        Returns:
            List of active sessions
        """
        if not self.db_pool:
            # Fall back to local cache
            sessions = [
                s for s in self._local_cache.values()
                if s.state == SessionState.ACTIVE
            ]
            if agent_type:
                sessions = [s for s in sessions if s.agent_type == agent_type]
            if user_id:
                sessions = [s for s in sessions if s.user_id == user_id]
            return sessions

        async with self.db_pool.acquire() as conn:
            query = """
                SELECT id, agent_type, user_id, state, context, checkpoints,
                       created_at, updated_at, last_heartbeat
                FROM agent_sessions
                WHERE state = 'active'
            """
            params = []

            # Placeholder numbers are derived from the params built so far,
            # so the two optional filters compose correctly.
            if agent_type:
                query += " AND agent_type = $1"
                params.append(agent_type)

            if user_id:
                param_num = len(params) + 1
                query += f" AND user_id = ${param_num}"
                params.append(user_id)

            rows = await conn.fetch(query, *params)

            return [self._row_to_session(row) for row in rows]

    async def cleanup_stale_sessions(
        self,
        max_age: timedelta = timedelta(hours=48)
    ) -> int:
        """
        Cleans up stale sessions (no heartbeat for max_age).

        Stale ACTIVE sessions are transitioned to FAILED rather than removed,
        preserving them for the audit trail.

        Args:
            max_age: Maximum age for stale sessions

        Returns:
            Number of sessions cleaned up
        """
        cutoff = datetime.now(timezone.utc) - max_age

        if not self.db_pool:
            # Clean local cache
            stale = [
                sid for sid, s in self._local_cache.items()
                if s.last_heartbeat < cutoff and s.state == SessionState.ACTIVE
            ]
            for sid in stale:
                session = self._local_cache[sid]
                session.fail("Session timeout - no heartbeat")
            return len(stale)

        async with self.db_pool.acquire() as conn:
            result = await conn.execute(
                """
                UPDATE agent_sessions
                SET state = 'failed',
                    updated_at = NOW(),
                    context = context || '{"failure_reason": "heartbeat_timeout"}'::jsonb
                WHERE state = 'active' AND last_heartbeat < $1
                """,
                cutoff
            )

            # asyncpg status tag is e.g. "UPDATE 7" — last token is the count.
            count = int(result.split()[-1]) if result else 0
            logger.info(f"Cleaned up {count} stale sessions")
            return count

    async def _persist_session(self, session: AgentSession) -> None:
        """Persists session to both Valkey and PostgreSQL"""
        await self._cache_in_valkey(session)
        await self._save_to_postgres(session)

    async def _cache_in_valkey(self, session: AgentSession) -> None:
        """Caches session in Valkey (best-effort; failures only logged)"""
        if not self.redis:
            return

        try:
            await self.redis.setex(
                self._redis_key(session.session_id),
                int(self.session_ttl.total_seconds()),
                json.dumps(session.to_dict())
            )
        except Exception as e:
            # Cache write failure is non-fatal: PostgreSQL still has the data.
            logger.warning(f"Failed to cache session in Valkey: {e}")

    async def _get_from_valkey(self, session_id: str) -> Optional[AgentSession]:
        """Retrieves session from Valkey (best-effort; failures only logged)"""
        if not self.redis:
            return None

        try:
            data = await self.redis.get(self._redis_key(session_id))
            if data:
                return AgentSession.from_dict(json.loads(data))
        except Exception as e:
            logger.warning(f"Failed to get session from Valkey: {e}")

        return None

    async def _save_to_postgres(self, session: AgentSession) -> None:
        """Saves session to PostgreSQL via upsert (best-effort; errors logged)"""
        if not self.db_pool:
            return

        try:
            async with self.db_pool.acquire() as conn:
                await conn.execute(
                    """
                    INSERT INTO agent_sessions
                        (id, agent_type, user_id, state, context, checkpoints,
                         created_at, updated_at, last_heartbeat)
                    VALUES ($1, $2, $3, $4, $5, $6, $7, NOW(), $8)
                    ON CONFLICT (id) DO UPDATE SET
                        state = EXCLUDED.state,
                        context = EXCLUDED.context,
                        checkpoints = EXCLUDED.checkpoints,
                        updated_at = NOW(),
                        last_heartbeat = EXCLUDED.last_heartbeat
                    """,
                    session.session_id,
                    session.agent_type,
                    session.user_id,
                    session.state.value,
                    json.dumps(session.context),
                    json.dumps([cp.to_dict() for cp in session.checkpoints]),
                    session.created_at,
                    session.last_heartbeat
                )
        except Exception as e:
            logger.error(f"Failed to save session to PostgreSQL: {e}")

    async def _get_from_postgres(self, session_id: str) -> Optional[AgentSession]:
        """Retrieves session from PostgreSQL, skipping soft-deleted rows"""
        if not self.db_pool:
            return None

        try:
            async with self.db_pool.acquire() as conn:
                row = await conn.fetchrow(
                    """
                    SELECT id, agent_type, user_id, state, context, checkpoints,
                           created_at, updated_at, last_heartbeat
                    FROM agent_sessions
                    WHERE id = $1 AND state != 'deleted'
                    """,
                    session_id
                )

                if row:
                    return self._row_to_session(row)
        except Exception as e:
            logger.error(f"Failed to get session from PostgreSQL: {e}")

        return None

    def _row_to_session(self, row) -> AgentSession:
        """Converts a database row to AgentSession"""
        # JSONB columns may come back as already-parsed objects or as raw
        # JSON strings depending on driver codecs — handle both.
        checkpoints_data = row["checkpoints"]
        if isinstance(checkpoints_data, str):
            checkpoints_data = json.loads(checkpoints_data)

        context_data = row["context"]
        if isinstance(context_data, str):
            context_data = json.loads(context_data)

        return AgentSession(
            session_id=str(row["id"]),
            agent_type=row["agent_type"],
            user_id=row["user_id"] or "",
            state=SessionState(row["state"]),
            created_at=row["created_at"],
            last_heartbeat=row["last_heartbeat"],
            context=context_data,
            checkpoints=[
                SessionCheckpoint.from_dict(cp) for cp in checkpoints_data
            ]
        )
|