fix: Restore all files lost during destructive rebase
A previous `git pull --rebase origin main` dropped 177 local commits,
losing 3400+ files across admin-v2, backend, studio-v2, website,
klausur-service, and many other services. The partial restore attempt
(660295e2) only recovered some files.
This commit restores all missing files from pre-rebase ref 98933f5e
while preserving post-rebase additions (night-scheduler, night-mode UI,
NightModeWidget dashboard integration).
Restored features include:
- AI Module Sidebar (FAB), OCR Labeling, OCR Compare
- GPU Dashboard, RAG Pipeline, Magic Help
- Klausur-Korrektur (8 files), Abitur-Archiv (5+ files)
- Companion, Zeugnisse-Crawler, Screen Flow
- Full backend, studio-v2, website, klausur-service
- All compliance SDKs, agent-core, voice-service
- CI/CD configs, documentation, scripts
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
540
agent-core/sessions/session_manager.py
Normal file
540
agent-core/sessions/session_manager.py
Normal file
@@ -0,0 +1,540 @@
|
||||
"""
|
||||
Session Management for Breakpilot Agents
|
||||
|
||||
Provides session lifecycle management with:
|
||||
- State tracking (ACTIVE, PAUSED, COMPLETED, FAILED)
|
||||
- Checkpoint-based recovery
|
||||
- Heartbeat integration
|
||||
- Hybrid Valkey + PostgreSQL persistence
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone, timedelta
|
||||
from typing import Dict, Any, Optional, List
|
||||
from enum import Enum
|
||||
import uuid
|
||||
import json
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SessionState(Enum):
    """
    Lifecycle states of an agent session.

    Each member's value is the lowercase string written out when a session
    is serialized (see ``AgentSession.to_dict``).
    """

    ACTIVE = "active"
    PAUSED = "paused"
    COMPLETED = "completed"
    FAILED = "failed"
|
||||
|
||||
|
||||
@dataclass
class SessionCheckpoint:
    """A named, timestamped snapshot of data recorded during an agent session"""

    name: str
    timestamp: datetime
    data: Dict[str, Any]

    def to_dict(self) -> Dict[str, Any]:
        """
        Serializes the checkpoint to a dictionary.

        Returns:
            A dict with "name", "timestamp" (ISO 8601 string) and "data"
        """
        serialized: Dict[str, Any] = {}
        serialized["name"] = self.name
        serialized["timestamp"] = self.timestamp.isoformat()
        serialized["data"] = self.data
        return serialized

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "SessionCheckpoint":
        """
        Rebuilds a checkpoint from the dictionary form produced by to_dict.

        Args:
            data: Mapping with "name", "timestamp" (ISO 8601 string) and "data"

        Returns:
            The reconstructed checkpoint
        """
        # Keys are read in the same order as the fields are declared.
        checkpoint_name = data["name"]
        recorded_at = datetime.fromisoformat(data["timestamp"])
        payload = data["data"]
        return cls(name=checkpoint_name, timestamp=recorded_at, data=payload)
|
||||
|
||||
|
||||
@dataclass
class AgentSession:
    """
    Represents an agent session together with its recovery checkpoints.

    Attributes:
        session_id: Unique session identifier (a random UUID4 string by default)
        agent_type: Type of agent (tutor, grader, quality-judge, alert, orchestrator)
        user_id: Associated user ID
        state: Current session state
        created_at: Session creation timestamp (UTC by default)
        last_heartbeat: Last heartbeat timestamp (UTC by default)
        context: Session context data
        checkpoints: List of session checkpoints for recovery
        metadata: Additional metadata attached to the session
    """

    session_id: str = field(default_factory=lambda: str(uuid.uuid4()))
    agent_type: str = ""
    user_id: str = ""
    state: SessionState = SessionState.ACTIVE
    created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
    last_heartbeat: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
    context: Dict[str, Any] = field(default_factory=dict)
    checkpoints: List[SessionCheckpoint] = field(default_factory=list)
    metadata: Dict[str, Any] = field(default_factory=dict)

    def checkpoint(self, name: str, data: Dict[str, Any]) -> SessionCheckpoint:
        """
        Creates a checkpoint for recovery and appends it to the session.

        Args:
            name: Checkpoint name (e.g., "task_received", "processing_complete")
            data: Checkpoint data to store

        Returns:
            The created checkpoint
        """
        checkpoint = SessionCheckpoint(
            name=name,
            timestamp=datetime.now(timezone.utc),
            data=data,
        )
        self.checkpoints.append(checkpoint)
        # Lazy %-formatting: the message is only built when DEBUG is enabled.
        logger.debug("Session %s: Checkpoint '%s' created", self.session_id, name)
        return checkpoint

    def heartbeat(self) -> None:
        """Updates the heartbeat timestamp to the current UTC time"""
        self.last_heartbeat = datetime.now(timezone.utc)

    def pause(self) -> None:
        """Pauses the session and records the state it was paused from"""
        # Capture the real prior state before overwriting it; the checkpoint
        # previously always claimed "active", even for sessions that were not.
        previous_state = self.state
        self.state = SessionState.PAUSED
        self.checkpoint("session_paused", {"previous_state": previous_state.value})

    def resume(self) -> None:
        """Resumes a paused session; does nothing in any other state"""
        if self.state == SessionState.PAUSED:
            self.state = SessionState.ACTIVE
            self.heartbeat()
            self.checkpoint("session_resumed", {})

    def complete(self, result: Optional[Dict[str, Any]] = None) -> None:
        """
        Marks the session as completed.

        Args:
            result: Optional result payload stored in the final checkpoint
        """
        self.state = SessionState.COMPLETED
        self.checkpoint("session_completed", {"result": result or {}})

    def fail(self, error: str, error_details: Optional[Dict[str, Any]] = None) -> None:
        """
        Marks the session as failed.

        Args:
            error: Error message stored in the final checkpoint
            error_details: Optional structured details about the error
        """
        self.state = SessionState.FAILED
        self.checkpoint("session_failed", {
            "error": error,
            "details": error_details or {},
        })

    def get_last_checkpoint(self, name: Optional[str] = None) -> Optional[SessionCheckpoint]:
        """
        Gets the last checkpoint, optionally filtered by name.

        Args:
            name: Optional checkpoint name to filter by; a falsy value
                (None or "") means no filtering

        Returns:
            The last matching checkpoint or None
        """
        if not self.checkpoints:
            return None

        if not name:
            return self.checkpoints[-1]

        # Walk backwards and stop at the first hit instead of materializing
        # every match just to take the last one.
        for cp in reversed(self.checkpoints):
            if cp.name == name:
                return cp
        return None

    def get_duration(self) -> timedelta:
        """
        Returns the session duration.

        For completed or failed sessions the duration ends at the timestamp
        of the most recent checkpoint; otherwise it ends now (UTC).
        """
        end_time = datetime.now(timezone.utc)
        if self.state in (SessionState.COMPLETED, SessionState.FAILED):
            last_cp = self.get_last_checkpoint()
            if last_cp:
                end_time = last_cp.timestamp
        return end_time - self.created_at

    def to_dict(self) -> Dict[str, Any]:
        """Serializes the session to a dictionary with ISO 8601 timestamps"""
        return {
            "session_id": self.session_id,
            "agent_type": self.agent_type,
            "user_id": self.user_id,
            "state": self.state.value,
            "created_at": self.created_at.isoformat(),
            "last_heartbeat": self.last_heartbeat.isoformat(),
            "context": self.context,
            "checkpoints": [cp.to_dict() for cp in self.checkpoints],
            "metadata": self.metadata,
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "AgentSession":
        """
        Deserializes a session from a dictionary.

        "context", "checkpoints" and "metadata" are optional and default to
        empty containers; all other keys written by to_dict are required.
        """
        return cls(
            session_id=data["session_id"],
            agent_type=data["agent_type"],
            user_id=data["user_id"],
            state=SessionState(data["state"]),
            created_at=datetime.fromisoformat(data["created_at"]),
            last_heartbeat=datetime.fromisoformat(data["last_heartbeat"]),
            context=data.get("context", {}),
            checkpoints=[
                SessionCheckpoint.from_dict(cp)
                for cp in data.get("checkpoints", [])
            ],
            metadata=data.get("metadata", {}),
        )
|
||||
|
||||
|
||||
class SessionManager:
    """
    Manages agent sessions with hybrid Valkey + PostgreSQL persistence.

    - Valkey: Fast access for active sessions (24h TTL)
    - PostgreSQL: Persistent storage for audit trail and recovery

    Sessions fetched or updated through this manager are additionally kept
    in an in-process dictionary cache. Both backends are optional: every
    storage helper is a no-op when its client is not configured.
    """

    def __init__(
        self,
        redis_client=None,
        db_pool=None,
        namespace: str = "breakpilot",
        session_ttl_hours: int = 24
    ):
        """
        Initialize the session manager.

        Args:
            redis_client: Async Redis/Valkey client
            db_pool: Async PostgreSQL connection pool
            namespace: Redis key namespace
            session_ttl_hours: Session TTL in Valkey
        """
        self.redis = redis_client
        self.db_pool = db_pool
        self.namespace = namespace
        self.session_ttl = timedelta(hours=session_ttl_hours)
        self._local_cache: Dict[str, AgentSession] = {}

    def _redis_key(self, session_id: str) -> str:
        """Generate Redis key for session"""
        return f"{self.namespace}:session:{session_id}"

    async def create_session(
        self,
        agent_type: str,
        user_id: str = "",
        context: Optional[Dict[str, Any]] = None,
        metadata: Optional[Dict[str, Any]] = None
    ) -> AgentSession:
        """
        Creates a new agent session and persists it.

        Args:
            agent_type: Type of agent
            user_id: Associated user ID
            context: Initial session context
            metadata: Additional metadata

        Returns:
            The created AgentSession
        """
        session = AgentSession(
            agent_type=agent_type,
            user_id=user_id,
            context=context or {},
            metadata=metadata or {}
        )

        session.checkpoint("session_created", {
            "agent_type": agent_type,
            "user_id": user_id
        })

        await self._persist_session(session)

        logger.info(
            "Created session %s for agent '%s'", session.session_id, agent_type
        )

        return session

    async def get_session(self, session_id: str) -> Optional[AgentSession]:
        """
        Retrieves a session by ID.

        Lookup order:
        1. Local cache
        2. Valkey
        3. PostgreSQL

        Args:
            session_id: Session ID to retrieve

        Returns:
            AgentSession if found, None otherwise
        """
        # Check local cache first
        if session_id in self._local_cache:
            return self._local_cache[session_id]

        # Try Valkey
        session = await self._get_from_valkey(session_id)
        if session:
            self._local_cache[session_id] = session
            return session

        # Fall back to PostgreSQL
        session = await self._get_from_postgres(session_id)
        if session:
            # Re-cache in Valkey for faster access
            await self._cache_in_valkey(session)
            self._local_cache[session_id] = session
            return session

        return None

    async def update_session(self, session: AgentSession) -> None:
        """
        Updates a session in all storage layers.

        The session's heartbeat is refreshed as part of the update.

        Args:
            session: The session to update
        """
        session.heartbeat()
        self._local_cache[session.session_id] = session
        await self._persist_session(session)

    async def delete_session(self, session_id: str) -> bool:
        """
        Deletes a session (soft delete in PostgreSQL).

        Args:
            session_id: Session ID to delete

        Returns:
            With a database configured: True if exactly one row was
            updated, False otherwise. Without a database: always True.
        """
        # Remove from local cache
        self._local_cache.pop(session_id, None)

        # Remove from Valkey
        if self.redis:
            await self.redis.delete(self._redis_key(session_id))

        # Soft delete in PostgreSQL
        if self.db_pool:
            async with self.db_pool.acquire() as conn:
                result = await conn.execute(
                    """
                    UPDATE agent_sessions
                    SET state = 'deleted', updated_at = NOW()
                    WHERE id = $1
                    """,
                    session_id
                )
                return result == "UPDATE 1"

        return True

    async def get_active_sessions(
        self,
        agent_type: Optional[str] = None,
        user_id: Optional[str] = None
    ) -> List[AgentSession]:
        """
        Gets all active sessions, optionally filtered.

        Without a database only the sessions held in the local cache are
        considered.

        Args:
            agent_type: Filter by agent type
            user_id: Filter by user ID

        Returns:
            List of active sessions
        """
        if not self.db_pool:
            # Fall back to local cache
            return [
                s for s in self._local_cache.values()
                if s.state == SessionState.ACTIVE
                and (not agent_type or s.agent_type == agent_type)
                and (not user_id or s.user_id == user_id)
            ]

        async with self.db_pool.acquire() as conn:
            query = """
                SELECT id, agent_type, user_id, state, context, checkpoints,
                       created_at, updated_at, last_heartbeat
                FROM agent_sessions
                WHERE state = 'active'
            """
            params: List[Any] = []

            # Placeholders are numbered by position so either filter can be
            # used alone or both together.
            if agent_type:
                params.append(agent_type)
                query += f" AND agent_type = ${len(params)}"

            if user_id:
                params.append(user_id)
                query += f" AND user_id = ${len(params)}"

            rows = await conn.fetch(query, *params)

            return [self._row_to_session(row) for row in rows]

    async def cleanup_stale_sessions(
        self,
        max_age: timedelta = timedelta(hours=48)
    ) -> int:
        """
        Cleans up stale sessions (no heartbeat for max_age).

        Active sessions whose last heartbeat is older than max_age are
        marked as failed. Cached copies are kept consistent with that:
        without a database the failed state is written through to Valkey,
        with a database the stale cached copies are evicted so the next
        lookup reads the updated row.

        Args:
            max_age: Maximum age for stale sessions

        Returns:
            Number of sessions cleaned up
        """
        cutoff = datetime.now(timezone.utc) - max_age

        # Snapshot the ids first; the cache is modified below.
        stale = [
            sid for sid, s in self._local_cache.items()
            if s.last_heartbeat < cutoff and s.state == SessionState.ACTIVE
        ]

        if not self.db_pool:
            # Clean local cache
            for sid in stale:
                session = self._local_cache[sid]
                session.fail("Session timeout - no heartbeat")
                # Write the failed state through to Valkey (no-op without a
                # client) so it does not keep serving an ACTIVE copy.
                await self._cache_in_valkey(session)
            return len(stale)

        async with self.db_pool.acquire() as conn:
            result = await conn.execute(
                """
                UPDATE agent_sessions
                SET state = 'failed',
                    updated_at = NOW(),
                    context = context || '{"failure_reason": "heartbeat_timeout"}'::jsonb
                WHERE state = 'active' AND last_heartbeat < $1
                """,
                cutoff
            )

        # PostgreSQL is the source of truth here: drop cached copies that
        # look stale so get_session re-reads the row instead of returning a
        # session that still claims to be ACTIVE.
        await self._evict_cached(stale)

        count = int(result.split()[-1]) if result else 0
        logger.info("Cleaned up %s stale sessions", count)
        return count

    async def _evict_cached(self, session_ids: List[str]) -> None:
        """Drops the given sessions from the local cache and from Valkey"""
        for sid in session_ids:
            self._local_cache.pop(sid, None)
            if not self.redis:
                continue
            try:
                await self.redis.delete(self._redis_key(sid))
            except Exception as e:
                # Best effort, like the other Valkey helpers: a cache outage
                # must not fail the cleanup that already succeeded.
                logger.warning("Failed to evict session from Valkey: %s", e)

    async def _persist_session(self, session: AgentSession) -> None:
        """Persists session to both Valkey and PostgreSQL"""
        await self._cache_in_valkey(session)
        await self._save_to_postgres(session)

    async def _cache_in_valkey(self, session: AgentSession) -> None:
        """Caches session in Valkey with the configured TTL (best effort)"""
        if not self.redis:
            return

        try:
            await self.redis.setex(
                self._redis_key(session.session_id),
                int(self.session_ttl.total_seconds()),
                json.dumps(session.to_dict())
            )
        except Exception as e:
            logger.warning("Failed to cache session in Valkey: %s", e)

    async def _get_from_valkey(self, session_id: str) -> Optional[AgentSession]:
        """Retrieves session from Valkey; returns None on a miss or error"""
        if not self.redis:
            return None

        try:
            data = await self.redis.get(self._redis_key(session_id))
            if data:
                return AgentSession.from_dict(json.loads(data))
        except Exception as e:
            logger.warning("Failed to get session from Valkey: %s", e)

        return None

    async def _save_to_postgres(self, session: AgentSession) -> None:
        """
        Saves session to PostgreSQL (insert or update by id).

        Errors are logged and not raised. Session metadata is not part of
        the statement, so it is not stored by this method.
        """
        if not self.db_pool:
            return

        try:
            async with self.db_pool.acquire() as conn:
                await conn.execute(
                    """
                    INSERT INTO agent_sessions
                        (id, agent_type, user_id, state, context, checkpoints,
                         created_at, updated_at, last_heartbeat)
                    VALUES ($1, $2, $3, $4, $5, $6, $7, NOW(), $8)
                    ON CONFLICT (id) DO UPDATE SET
                        state = EXCLUDED.state,
                        context = EXCLUDED.context,
                        checkpoints = EXCLUDED.checkpoints,
                        updated_at = NOW(),
                        last_heartbeat = EXCLUDED.last_heartbeat
                    """,
                    session.session_id,
                    session.agent_type,
                    session.user_id,
                    session.state.value,
                    json.dumps(session.context),
                    json.dumps([cp.to_dict() for cp in session.checkpoints]),
                    session.created_at,
                    session.last_heartbeat
                )
        except Exception as e:
            logger.error("Failed to save session to PostgreSQL: %s", e)

    async def _get_from_postgres(self, session_id: str) -> Optional[AgentSession]:
        """
        Retrieves session from PostgreSQL.

        Soft-deleted rows are ignored. Returns None on a miss or error.
        """
        if not self.db_pool:
            return None

        try:
            async with self.db_pool.acquire() as conn:
                row = await conn.fetchrow(
                    """
                    SELECT id, agent_type, user_id, state, context, checkpoints,
                           created_at, updated_at, last_heartbeat
                    FROM agent_sessions
                    WHERE id = $1 AND state != 'deleted'
                    """,
                    session_id
                )

                if row:
                    return self._row_to_session(row)
        except Exception as e:
            logger.error("Failed to get session from PostgreSQL: %s", e)

        return None

    def _row_to_session(self, row) -> AgentSession:
        """
        Converts a database row to AgentSession.

        JSON columns may arrive as text or already decoded; NULL columns are
        treated as empty. Metadata is not read from the row, so the returned
        session has empty metadata.
        """
        checkpoints_data = row["checkpoints"]
        if isinstance(checkpoints_data, str):
            checkpoints_data = json.loads(checkpoints_data)

        context_data = row["context"]
        if isinstance(context_data, str):
            context_data = json.loads(context_data)

        return AgentSession(
            session_id=str(row["id"]),
            agent_type=row["agent_type"],
            user_id=row["user_id"] or "",
            state=SessionState(row["state"]),
            created_at=row["created_at"],
            last_heartbeat=row["last_heartbeat"],
            # A NULL column (or JSON null) must not leak None into the
            # session or break the checkpoint iteration.
            context=context_data or {},
            checkpoints=[
                SessionCheckpoint.from_dict(cp) for cp in (checkpoints_data or [])
            ]
        )
|
||||
Reference in New Issue
Block a user