klausur-service (11 files): - cv_gutter_repair, ocr_pipeline_regression, upload_api - ocr_pipeline_sessions, smart_spell, nru_worksheet_generator - ocr_pipeline_overlays, mail/aggregator, zeugnis_api - cv_syllable_detect, self_rag backend-lehrer (17 files): - classroom_engine/suggestions, generators/quiz_generator - worksheets_api, llm_gateway/comparison, state_engine_api - classroom/models (→ 4 submodules), services/file_processor - alerts_agent/api/wizard+digests+routes, content_generators/pdf - classroom/routes/sessions, llm_gateway/inference - classroom_engine/analytics, auth/keycloak_auth - alerts_agent/processing/rule_engine, ai_processor/print_versions agent-core (5 files): - brain/memory_store, brain/knowledge_graph, brain/context_manager - orchestrator/supervisor, sessions/session_manager admin-lehrer (5 components): - GridOverlay, StepGridReview, DevOpsPipelineSidebar - DataFlowDiagram, sbom/wizard/page website (2 files): - DependencyMap, lehrer/abitur-archiv Other: nibis_ingestion, grid_detection_service, export-doclayout-onnx Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
378 lines
12 KiB
Python
378 lines
12 KiB
Python
"""
|
|
Session Management for Breakpilot Agents
|
|
|
|
Provides session lifecycle management with:
|
|
- Hybrid Valkey + PostgreSQL persistence
|
|
- Session CRUD operations
|
|
- Stale session cleanup
|
|
"""
|
|
|
|
from datetime import datetime, timezone, timedelta
|
|
from typing import Dict, Any, Optional, List
|
|
import json
|
|
import logging
|
|
|
|
from agent_core.sessions.session_models import (
|
|
SessionState,
|
|
SessionCheckpoint,
|
|
AgentSession,
|
|
)
|
|
|
|
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class SessionManager:
    """
    Manages agent sessions with hybrid Valkey + PostgreSQL persistence.

    - Valkey: Fast access for active sessions (24h TTL by default)
    - PostgreSQL: Persistent storage for audit trail and recovery

    An in-process dictionary sits in front of both stores as a first-level
    cache. Either backend may be left as ``None``; the corresponding layer
    is then skipped and the remaining layers are used on their own.
    """

    def __init__(
        self,
        redis_client=None,
        db_pool=None,
        namespace: str = "breakpilot",
        session_ttl_hours: int = 24
    ):
        """
        Initialize the session manager.

        Args:
            redis_client: Async Redis/Valkey client
            db_pool: Async PostgreSQL connection pool
            namespace: Redis key namespace
            session_ttl_hours: Session TTL in Valkey
        """
        self.redis = redis_client
        self.db_pool = db_pool
        self.namespace = namespace
        self.session_ttl = timedelta(hours=session_ttl_hours)
        # First-level cache, keyed by session ID. Entries are only removed
        # by delete_session().
        self._local_cache: Dict[str, AgentSession] = {}

    def _redis_key(self, session_id: str) -> str:
        """Generate Redis key for session"""
        return f"{self.namespace}:session:{session_id}"

    async def create_session(
        self,
        agent_type: str,
        user_id: str = "",
        context: Optional[Dict[str, Any]] = None,
        metadata: Optional[Dict[str, Any]] = None
    ) -> AgentSession:
        """
        Creates a new agent session.

        The session receives an initial "session_created" checkpoint, is
        placed in the local cache and is written to every configured
        storage layer.

        Args:
            agent_type: Type of agent
            user_id: Associated user ID
            context: Initial session context
            metadata: Additional metadata

        Returns:
            The created AgentSession
        """
        session = AgentSession(
            agent_type=agent_type,
            user_id=user_id,
            context=context or {},
            metadata=metadata or {}
        )

        session.checkpoint("session_created", {
            "agent_type": agent_type,
            "user_id": user_id
        })

        self._local_cache[session.session_id] = session
        await self._persist_session(session)

        logger.info(
            "Created session %s for agent '%s'",
            session.session_id,
            agent_type,
        )

        return session

    async def get_session(self, session_id: str) -> Optional[AgentSession]:
        """
        Retrieves a session by ID.

        Lookup order:
        1. Local cache
        2. Valkey
        3. PostgreSQL

        A hit in a slower layer is copied into the faster layers before it
        is returned.

        Args:
            session_id: Session ID to retrieve

        Returns:
            AgentSession if found, None otherwise
        """
        # Check local cache first
        if session_id in self._local_cache:
            return self._local_cache[session_id]

        # Try Valkey
        session = await self._get_from_valkey(session_id)
        if session:
            self._local_cache[session_id] = session
            return session

        # Fall back to PostgreSQL
        session = await self._get_from_postgres(session_id)
        if session:
            # Re-cache in Valkey for faster access
            await self._cache_in_valkey(session)
            self._local_cache[session_id] = session
            return session

        return None

    async def update_session(self, session: AgentSession) -> None:
        """
        Updates a session in all storage layers.

        The session's heartbeat is refreshed before it is written.

        Args:
            session: The session to update
        """
        session.heartbeat()
        self._local_cache[session.session_id] = session
        await self._persist_session(session)

    async def delete_session(self, session_id: str) -> bool:
        """
        Deletes a session (soft delete in PostgreSQL).

        The session is removed from the local cache and from Valkey. When
        a PostgreSQL pool is configured, the row is kept and its state is
        set to 'deleted'.

        Args:
            session_id: Session ID to delete

        Returns:
            True if deleted, False if not found. With PostgreSQL this
            reflects whether exactly one row was updated; without it,
            whether the local cache or Valkey held the session.
        """
        # Remove from local cache, remembering whether it was present
        removed = self._local_cache.pop(session_id, None) is not None

        # Remove from Valkey; a truthy reply means a key was deleted
        if self.redis:
            if await self.redis.delete(self._redis_key(session_id)):
                removed = True

        # Soft delete in PostgreSQL
        if self.db_pool:
            async with self.db_pool.acquire() as conn:
                result = await conn.execute(
                    """
                    UPDATE agent_sessions
                    SET state = 'deleted', updated_at = NOW()
                    WHERE id = $1
                    """,
                    session_id
                )
            return result == "UPDATE 1"

        return removed

    async def get_active_sessions(
        self,
        agent_type: Optional[str] = None,
        user_id: Optional[str] = None
    ) -> List[AgentSession]:
        """
        Gets all active sessions, optionally filtered.

        Reads from PostgreSQL when a pool is configured, otherwise from
        the local cache. Empty filter values are treated as "no filter".

        Args:
            agent_type: Filter by agent type
            user_id: Filter by user ID

        Returns:
            List of active sessions
        """
        if not self.db_pool:
            # Fall back to local cache
            sessions = [
                s for s in self._local_cache.values()
                if s.state == SessionState.ACTIVE
            ]
            if agent_type:
                sessions = [s for s in sessions if s.agent_type == agent_type]
            if user_id:
                sessions = [s for s in sessions if s.user_id == user_id]
            return sessions

        query = """
            SELECT id, agent_type, user_id, state, context, checkpoints,
                   created_at, updated_at, last_heartbeat
            FROM agent_sessions
            WHERE state = 'active'
        """
        params: List[Any] = []

        # Placeholders are numbered by position in params, so either
        # filter can be used on its own or both together.
        if agent_type:
            params.append(agent_type)
            query += f" AND agent_type = ${len(params)}"

        if user_id:
            params.append(user_id)
            query += f" AND user_id = ${len(params)}"

        async with self.db_pool.acquire() as conn:
            rows = await conn.fetch(query, *params)

        return [self._row_to_session(row) for row in rows]

    async def cleanup_stale_sessions(
        self,
        max_age: timedelta = timedelta(hours=48)
    ) -> int:
        """
        Cleans up stale sessions (no heartbeat for max_age).

        Active sessions whose last heartbeat is older than ``max_age`` are
        marked as failed: in PostgreSQL when a pool is configured,
        otherwise on the objects held in the local cache.

        Args:
            max_age: Maximum age for stale sessions

        Returns:
            Number of sessions cleaned up
        """
        cutoff = datetime.now(timezone.utc) - max_age

        if not self.db_pool:
            # Clean local cache
            stale = [
                sid for sid, s in self._local_cache.items()
                if s.last_heartbeat < cutoff and s.state == SessionState.ACTIVE
            ]
            for sid in stale:
                self._local_cache[sid].fail("Session timeout - no heartbeat")
            return len(stale)

        # NOTE(review): this path only updates PostgreSQL. Copies of the
        # same sessions in the local cache or Valkey are left untouched, so
        # get_session() may still return them as active - confirm whether
        # they should be evicted here.
        async with self.db_pool.acquire() as conn:
            result = await conn.execute(
                """
                UPDATE agent_sessions
                SET state = 'failed',
                    updated_at = NOW(),
                    context = context || '{"failure_reason": "heartbeat_timeout"}'::jsonb
                WHERE state = 'active' AND last_heartbeat < $1
                """,
                cutoff
            )

        # The status string ends with the affected row count ("UPDATE <n>")
        count = int(result.split()[-1]) if result else 0
        logger.info("Cleaned up %s stale sessions", count)
        return count

    async def _persist_session(self, session: AgentSession) -> None:
        """Persists session to both Valkey and PostgreSQL"""
        await self._cache_in_valkey(session)
        await self._save_to_postgres(session)

    async def _cache_in_valkey(self, session: AgentSession) -> None:
        """
        Caches session in Valkey.

        Best effort: does nothing without a client, and a failing write is
        logged as a warning instead of being raised.
        """
        if not self.redis:
            return

        try:
            await self.redis.setex(
                self._redis_key(session.session_id),
                int(self.session_ttl.total_seconds()),
                json.dumps(session.to_dict())
            )
        except Exception as e:
            logger.warning("Failed to cache session in Valkey: %s", e)

    async def _get_from_valkey(self, session_id: str) -> Optional[AgentSession]:
        """
        Retrieves session from Valkey.

        Returns None without a client, on a cache miss, or when reading or
        decoding fails (the failure is logged as a warning).
        """
        if not self.redis:
            return None

        try:
            data = await self.redis.get(self._redis_key(session_id))
            if data:
                return AgentSession.from_dict(json.loads(data))
        except Exception as e:
            logger.warning("Failed to get session from Valkey: %s", e)

        return None

    async def _save_to_postgres(self, session: AgentSession) -> None:
        """
        Saves session to PostgreSQL.

        Inserts the row, or on an ID conflict updates state, context,
        checkpoints and heartbeat. Errors are logged, not raised.

        NOTE(review): session metadata is not part of the written columns,
        so it is not restored by _row_to_session - confirm this is intended.
        """
        if not self.db_pool:
            return

        try:
            async with self.db_pool.acquire() as conn:
                await conn.execute(
                    """
                    INSERT INTO agent_sessions
                        (id, agent_type, user_id, state, context, checkpoints,
                         created_at, updated_at, last_heartbeat)
                    VALUES ($1, $2, $3, $4, $5, $6, $7, NOW(), $8)
                    ON CONFLICT (id) DO UPDATE SET
                        state = EXCLUDED.state,
                        context = EXCLUDED.context,
                        checkpoints = EXCLUDED.checkpoints,
                        updated_at = NOW(),
                        last_heartbeat = EXCLUDED.last_heartbeat
                    """,
                    session.session_id,
                    session.agent_type,
                    session.user_id,
                    session.state.value,
                    json.dumps(session.context),
                    json.dumps([cp.to_dict() for cp in session.checkpoints]),
                    session.created_at,
                    session.last_heartbeat
                )
        except Exception as e:
            logger.error("Failed to save session to PostgreSQL: %s", e)

    async def _get_from_postgres(self, session_id: str) -> Optional[AgentSession]:
        """
        Retrieves session from PostgreSQL.

        Soft-deleted rows are ignored. Returns None without a pool, when no
        row matches, or when the query fails (the failure is logged).
        """
        if not self.db_pool:
            return None

        try:
            async with self.db_pool.acquire() as conn:
                row = await conn.fetchrow(
                    """
                    SELECT id, agent_type, user_id, state, context, checkpoints,
                           created_at, updated_at, last_heartbeat
                    FROM agent_sessions
                    WHERE id = $1 AND state != 'deleted'
                    """,
                    session_id
                )

            if row:
                return self._row_to_session(row)
        except Exception as e:
            logger.error("Failed to get session from PostgreSQL: %s", e)

        return None

    def _row_to_session(self, row) -> AgentSession:
        """
        Converts a database row to AgentSession.

        The context and checkpoints columns may arrive as JSON text or as
        already decoded values; missing (NULL / JSON null) values are
        treated as an empty dict and an empty list respectively.
        """
        checkpoints_data = row["checkpoints"]
        if isinstance(checkpoints_data, str):
            checkpoints_data = json.loads(checkpoints_data)

        context_data = row["context"]
        if isinstance(context_data, str):
            context_data = json.loads(context_data)

        return AgentSession(
            session_id=str(row["id"]),
            agent_type=row["agent_type"],
            user_id=row["user_id"] or "",
            state=SessionState(row["state"]),
            created_at=row["created_at"],
            last_heartbeat=row["last_heartbeat"],
            context=context_data or {},
            checkpoints=[
                SessionCheckpoint.from_dict(cp)
                for cp in (checkpoints_data or [])
            ]
        )
|