"""
Hybrid Session Store: Valkey + PostgreSQL

Architecture:
- Valkey: Fast session cache with a sliding TTL (SESSION_TTL_HOURS, default 24h)
- PostgreSQL: Persistent storage and DSGVO audit trail
- Graceful fallback: If Valkey is down, fall back to PostgreSQL

Session data model:
{
    "session_id": "uuid",
    "user_id": "uuid",
    "email": "string",
    "user_type": "employee|customer",
    "roles": ["role1", "role2"],
    "permissions": ["perm1", "perm2"],
    "tenant_id": "school-uuid",
    "ip_address": "string",
    "user_agent": "string",
    "created_at": "timestamp",
    "last_activity_at": "timestamp"
}
"""

import asyncio
import hashlib
import json
import logging
import os
import uuid
from dataclasses import asdict, dataclass, field
from datetime import datetime, timedelta, timezone
from enum import Enum
from typing import Any, Dict, List, Optional

logger = logging.getLogger(__name__)


class UserType(str, Enum):
    """User type distinction for RBAC."""

    EMPLOYEE = "employee"  # Internal staff (teachers, admins)
    CUSTOMER = "customer"  # External users (parents, students)


@dataclass
class Session:
    """Session data model; serialized shape is described in the module docstring."""

    session_id: str
    user_id: str
    email: str
    user_type: UserType
    roles: List[str] = field(default_factory=list)
    permissions: List[str] = field(default_factory=list)
    tenant_id: Optional[str] = None
    ip_address: Optional[str] = None
    user_agent: Optional[str] = None
    created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
    last_activity_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))

    def to_dict(self) -> Dict[str, Any]:
        """Convert to a JSON-serializable dictionary.

        ``user_type`` is emitted as its string value; timestamps are emitted
        as ISO-8601 strings, or ``None`` when unset.
        """
        return {
            "session_id": self.session_id,
            "user_id": self.user_id,
            "email": self.email,
            "user_type": self.user_type.value if isinstance(self.user_type, UserType) else self.user_type,
            "roles": self.roles,
            "permissions": self.permissions,
            "tenant_id": self.tenant_id,
            "ip_address": self.ip_address,
            "user_agent": self.user_agent,
            "created_at": self.created_at.isoformat() if self.created_at else None,
            "last_activity_at": self.last_activity_at.isoformat() if self.last_activity_at else None,
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "Session":
        """Create a Session from a dictionary produced by :meth:`to_dict`.

        Raises:
            KeyError: if ``session_id`` or ``user_id`` is missing.
            ValueError: if ``user_type`` or a timestamp cannot be parsed.
        """
        user_type = data.get("user_type", "customer")
        if isinstance(user_type, str):
            user_type = UserType(user_type)

        created_at = data.get("created_at")
        if isinstance(created_at, str):
            # fromisoformat() before Python 3.11 does not accept a trailing "Z".
            created_at = datetime.fromisoformat(created_at.replace("Z", "+00:00"))

        last_activity_at = data.get("last_activity_at")
        if isinstance(last_activity_at, str):
            last_activity_at = datetime.fromisoformat(last_activity_at.replace("Z", "+00:00"))

        return cls(
            session_id=data["session_id"],
            user_id=data["user_id"],
            email=data.get("email", ""),
            user_type=user_type,
            roles=data.get("roles", []),
            permissions=data.get("permissions", []),
            tenant_id=data.get("tenant_id"),
            ip_address=data.get("ip_address"),
            user_agent=data.get("user_agent"),
            created_at=created_at or datetime.now(timezone.utc),
            last_activity_at=last_activity_at or datetime.now(timezone.utc),
        )

    def has_permission(self, permission: str) -> bool:
        """Check if session has a specific permission."""
        return permission in self.permissions

    def has_any_permission(self, permissions: List[str]) -> bool:
        """Check if session has any of the specified permissions."""
        return any(p in self.permissions for p in permissions)

    def has_all_permissions(self, permissions: List[str]) -> bool:
        """Check if session has all specified permissions."""
        return all(p in self.permissions for p in permissions)

    def has_role(self, role: str) -> bool:
        """Check if session has a specific role."""
        return role in self.roles

    def is_employee(self) -> bool:
        """Check if user is an employee (internal staff)."""
        return self.user_type == UserType.EMPLOYEE

    def is_customer(self) -> bool:
        """Check if user is a customer (external user)."""
        return self.user_type == UserType.CUSTOMER


class SessionStore:
    """
    Hybrid session store using Valkey and PostgreSQL.

    Valkey: Primary storage with a sliding TTL for fast lookups
    PostgreSQL: Persistent backup and audit trail
    """

    # Upper bound for each best-effort Valkey delete during revocation, so a
    # logout cannot hang on an unreachable cache.
    VALKEY_EVICT_TIMEOUT_SECONDS: float = 2.0

    def __init__(
        self,
        valkey_url: Optional[str] = None,
        database_url: Optional[str] = None,
        session_ttl_hours: int = 24,
    ):
        self.valkey_url = valkey_url or os.environ.get("VALKEY_URL", "redis://localhost:6379")
        self.database_url = database_url or os.environ.get("DATABASE_URL")
        self.session_ttl = timedelta(hours=session_ttl_hours)
        self.session_ttl_seconds = session_ttl_hours * 3600

        self._valkey_client = None
        self._pg_pool = None
        # Cleared after any Valkey client error; only _connect_valkey() sets it
        # back to True. While False, reads/writes skip Valkey entirely.
        self._valkey_available = True

    async def connect(self):
        """Initialize connections to Valkey and PostgreSQL."""
        await self._connect_valkey()
        await self._connect_postgres()

    async def _connect_valkey(self):
        """Connect to Valkey (Redis-compatible); never raises."""
        try:
            import redis.asyncio as redis

            self._valkey_client = redis.from_url(
                self.valkey_url,
                encoding="utf-8",
                decode_responses=True,
            )
            # Test connection
            await self._valkey_client.ping()
            self._valkey_available = True
            logger.info("Connected to Valkey session cache")
        except ImportError:
            logger.warning("redis package not installed, Valkey unavailable")
            self._valkey_available = False
        except Exception as e:
            logger.warning(f"Valkey connection failed, falling back to PostgreSQL: {e}")
            self._valkey_available = False

    async def _connect_postgres(self):
        """Connect to PostgreSQL; never raises (the pool stays None on failure)."""
        if not self.database_url:
            logger.warning("DATABASE_URL not set, PostgreSQL unavailable")
            return

        try:
            import asyncpg

            self._pg_pool = await asyncpg.create_pool(
                self.database_url,
                min_size=2,
                max_size=10,
            )
            logger.info("Connected to PostgreSQL session store")
        except ImportError:
            logger.warning("asyncpg package not installed")
        except Exception as e:
            logger.error(f"PostgreSQL connection failed: {e}")

    async def close(self):
        """Close all connections and drop the client references."""
        if self._valkey_client is not None:
            # Newer redis-py releases deprecate close() in favour of aclose().
            close_fn = getattr(self._valkey_client, "aclose", None) or self._valkey_client.close
            await close_fn()
            self._valkey_client = None
        if self._pg_pool is not None:
            await self._pg_pool.close()
            self._pg_pool = None

    def _get_valkey_key(self, session_id: str) -> str:
        """Generate Valkey key for session."""
        return f"session:{session_id}"

    def _hash_token(self, token: str) -> str:
        """Hash token for PostgreSQL storage (hex SHA-256)."""
        return hashlib.sha256(token.encode()).hexdigest()

    @staticmethod
    def _session_from_row(row: Any) -> Session:
        """Build a Session from a ``user_sessions`` row (mapping-style access).

        ``roles``/``permissions`` are stored as JSON text by create_session and
        decoded here; NULL columns fall back to empty/default values.
        """
        return Session(
            session_id=str(row["id"]),
            user_id=str(row["user_id"]),
            email=row["email"] or "",
            user_type=UserType(row["user_type"]) if row["user_type"] else UserType.CUSTOMER,
            roles=json.loads(row["roles"]) if row["roles"] else [],
            permissions=json.loads(row["permissions"]) if row["permissions"] else [],
            tenant_id=str(row["tenant_id"]) if row["tenant_id"] else None,
            ip_address=row["ip_address"],
            user_agent=row["user_agent"],
            created_at=row["created_at"],
            last_activity_at=row["last_activity_at"],
        )

    async def create_session(
        self,
        user_id: str,
        email: str,
        user_type: UserType,
        roles: List[str],
        permissions: List[str],
        tenant_id: Optional[str] = None,
        ip_address: Optional[str] = None,
        user_agent: Optional[str] = None,
    ) -> Session:
        """
        Create a new session.

        Stores in both Valkey (with TTL) and PostgreSQL (persistent). Storage
        failures are logged, not raised, so the session is always returned.
        Returns the session with a generated UUID4 session_id.

        Raises:
            ValueError: if ``user_type`` is not a valid UserType value.
        """
        session = Session(
            session_id=str(uuid.uuid4()),
            user_id=user_id,
            email=email,
            # Accept plain strings too; UserType is a str-based Enum.
            user_type=UserType(user_type),
            roles=roles,
            permissions=permissions,
            tenant_id=tenant_id,
            ip_address=ip_address,
            user_agent=user_agent,
        )

        # Store in Valkey (primary)
        if self._valkey_available and self._valkey_client:
            try:
                await self._valkey_client.setex(
                    self._get_valkey_key(session.session_id),
                    self.session_ttl_seconds,
                    json.dumps(session.to_dict()),
                )
            except Exception as e:
                logger.error(f"Failed to store session in Valkey: {e}")
                self._valkey_available = False

        # Store in PostgreSQL (backup + audit)
        if self._pg_pool:
            try:
                async with self._pg_pool.acquire() as conn:
                    # NOTE(review): the raw session_id is stored in `id` next to
                    # its hash, so token_hash adds no secrecy on its own.
                    await conn.execute(
                        """
                        INSERT INTO user_sessions (
                            id, user_id, token_hash, email, user_type, roles,
                            permissions, tenant_id, ip_address, user_agent,
                            expires_at, created_at, last_activity_at
                        ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)
                        """,
                        session.session_id,
                        session.user_id,
                        self._hash_token(session.session_id),
                        session.email,
                        session.user_type.value,
                        json.dumps(session.roles),
                        json.dumps(session.permissions),
                        session.tenant_id,
                        session.ip_address,
                        session.user_agent,
                        session.created_at + self.session_ttl,
                        session.created_at,
                        session.last_activity_at,
                    )
            except Exception as e:
                logger.error(f"Failed to store session in PostgreSQL: {e}")

        return session

    async def _get_cached_session(self, session_id: str) -> Optional[Session]:
        """Look a session up in Valkey.

        Returns None on a cache miss, when Valkey is unavailable, or when the
        cached entry cannot be decoded. Only client errors (not bad data) mark
        Valkey as unavailable.
        """
        if not (self._valkey_available and self._valkey_client):
            return None

        try:
            raw = await self._valkey_client.get(self._get_valkey_key(session_id))
        except Exception as e:
            logger.warning(f"Valkey lookup failed, trying PostgreSQL: {e}")
            self._valkey_available = False
            return None

        if not raw:
            return None

        try:
            return Session.from_dict(json.loads(raw))
        except (ValueError, KeyError, TypeError, AttributeError) as e:
            # Corrupt entry: fall back to PostgreSQL, which re-caches a good copy.
            logger.warning(f"Ignoring malformed cached session entry: {e}")
            return None

    async def get_session(self, session_id: str) -> Optional[Session]:
        """
        Get session by ID.

        Tries Valkey first (fast), falls back to PostgreSQL (active, non-revoked,
        non-expired rows only). On a hit, the sliding expiry is refreshed.
        Returns None if the session is not found or both stores fail.
        """
        session = await self._get_cached_session(session_id)
        if session is not None:
            await self._update_last_activity(session_id)
            return session

        if not self._pg_pool:
            return None

        try:
            async with self._pg_pool.acquire() as conn:
                row = await conn.fetchrow(
                    """
                    SELECT id, user_id, email, user_type, roles, permissions,
                           tenant_id, ip_address, user_agent, created_at, last_activity_at
                    FROM user_sessions
                    WHERE id = $1
                      AND revoked_at IS NULL
                      AND expires_at > NOW()
                    """,
                    session_id,
                )
            if row is None:
                return None
            session = self._session_from_row(row)
        except Exception as e:
            logger.error(f"PostgreSQL session lookup failed: {e}")
            return None

        # Populate the cache if Valkey is currently considered available, and
        # keep the sliding expiry working while reads are served from PostgreSQL.
        await self._cache_in_valkey(session)
        await self._update_last_activity(session_id)
        return session

    async def _update_last_activity(self, session_id: str):
        """Best-effort refresh of the Valkey TTL and PostgreSQL activity/expiry."""
        now = datetime.now(timezone.utc)

        # Update Valkey TTL
        if self._valkey_available and self._valkey_client:
            try:
                await self._valkey_client.expire(
                    self._get_valkey_key(session_id), self.session_ttl_seconds
                )
            except Exception as e:
                # Session ids act as tokens, so they are deliberately not logged.
                logger.debug(f"Failed to refresh session TTL in Valkey: {e}")

        # Update PostgreSQL
        if self._pg_pool:
            try:
                async with self._pg_pool.acquire() as conn:
                    await conn.execute(
                        """
                        UPDATE user_sessions
                        SET last_activity_at = $1, expires_at = $2
                        WHERE id = $3
                        """,
                        now,
                        now + self.session_ttl,
                        session_id,
                    )
            except Exception as e:
                logger.debug(f"Failed to update session activity in PostgreSQL: {e}")

    async def _cache_in_valkey(self, session: Session):
        """Best-effort write of a session into Valkey (only while it is considered available)."""
        if self._valkey_available and self._valkey_client:
            try:
                await self._valkey_client.setex(
                    self._get_valkey_key(session.session_id),
                    self.session_ttl_seconds,
                    json.dumps(session.to_dict()),
                )
            except Exception as e:
                logger.debug(f"Failed to re-cache session in Valkey: {e}")

    async def _evict_from_valkey(self, session_id: str) -> None:
        """Delete a session's cached copy, bounded by VALKEY_EVICT_TIMEOUT_SECONDS.

        Raises whatever the client raises, or asyncio.TimeoutError.
        """
        await asyncio.wait_for(
            self._valkey_client.delete(self._get_valkey_key(session_id)),
            timeout=self.VALKEY_EVICT_TIMEOUT_SECONDS,
        )

    async def revoke_session(self, session_id: str) -> bool:
        """
        Revoke a session (logout).

        Removes from Valkey and marks as revoked in PostgreSQL.
        Returns True if at least one store accepted the revocation.
        """
        success = False

        # Evict even when _valkey_available is False: a stale cached copy would
        # otherwise keep the session valid for any reader that still uses Valkey.
        if self._valkey_client is not None:
            try:
                await self._evict_from_valkey(session_id)
                success = True
            except Exception as e:
                logger.error(f"Failed to revoke session in Valkey: {e}")

        # Mark as revoked in PostgreSQL
        if self._pg_pool:
            try:
                async with self._pg_pool.acquire() as conn:
                    await conn.execute(
                        """
                        UPDATE user_sessions
                        SET revoked_at = NOW()
                        WHERE id = $1
                        """,
                        session_id,
                    )
                success = True
            except Exception as e:
                logger.error(f"Failed to revoke session in PostgreSQL: {e}")

        return success

    async def revoke_all_user_sessions(self, user_id: str) -> int:
        """
        Revoke all sessions for a user (force logout from all devices).

        Returns the number of sessions revoked in PostgreSQL (0 if PostgreSQL
        is unavailable or the update fails).
        """
        if not self._pg_pool:
            return 0

        try:
            async with self._pg_pool.acquire() as conn:
                # Single statement: the rows revoked are exactly the rows whose
                # ids we evict, so no session created concurrently slips through.
                rows = await conn.fetch(
                    """
                    UPDATE user_sessions
                    SET revoked_at = NOW()
                    WHERE user_id = $1 AND revoked_at IS NULL
                    RETURNING id
                    """,
                    user_id,
                )
        except Exception as e:
            logger.error(f"Failed to revoke all user sessions: {e}")
            return 0

        session_ids = [str(row["id"]) for row in rows]

        # Evict even when _valkey_available is False (see revoke_session).
        if self._valkey_client is not None and session_ids:
            results = await asyncio.gather(
                *(self._evict_from_valkey(sid) for sid in session_ids),
                return_exceptions=True,
            )
            failed = sum(1 for r in results if isinstance(r, BaseException))
            if failed:
                logger.warning(
                    f"Failed to evict {failed} of {len(session_ids)} revoked sessions from Valkey"
                )

        return len(session_ids)

    async def get_active_sessions(self, user_id: str) -> List[Session]:
        """Get all active (non-revoked, non-expired) sessions for a user, most recent first."""
        sessions: List[Session] = []

        if self._pg_pool:
            try:
                async with self._pg_pool.acquire() as conn:
                    rows = await conn.fetch(
                        """
                        SELECT id, user_id, email, user_type, roles, permissions,
                               tenant_id, ip_address, user_agent, created_at, last_activity_at
                        FROM user_sessions
                        WHERE user_id = $1
                          AND revoked_at IS NULL
                          AND expires_at > NOW()
                        ORDER BY last_activity_at DESC
                        """,
                        user_id,
                    )
                for row in rows:
                    sessions.append(self._session_from_row(row))
            except Exception as e:
                logger.error(f"Failed to get active sessions: {e}")

        return sessions

    async def cleanup_expired_sessions(self, retention_days: int = 7) -> int:
        """
        Clean up expired sessions from PostgreSQL.

        This is meant to be called by a background job. Rows are kept for
        ``retention_days`` after expiry (audit trail) before being deleted.
        Returns the number of sessions cleaned up.

        Raises:
            ValueError: if ``retention_days`` is negative.
        """
        if retention_days < 0:
            raise ValueError("retention_days must be >= 0")

        count = 0

        if self._pg_pool:
            try:
                async with self._pg_pool.acquire() as conn:
                    result = await conn.execute(
                        """
                        DELETE FROM user_sessions
                        WHERE expires_at < NOW() - make_interval(days => $1)
                        """,
                        retention_days,
                    )
                # asyncpg returns the command status, e.g. "DELETE 3".
                count = int(result.split()[-1]) if result else 0
                logger.info(f"Cleaned up {count} expired sessions")
            except Exception as e:
                logger.error(f"Session cleanup failed: {e}")

        return count


# Global session store instance
_session_store: Optional[SessionStore] = None
# Created lazily inside a running loop to avoid loop-binding issues on older Pythons.
_session_store_lock: Optional[asyncio.Lock] = None


async def get_session_store() -> SessionStore:
    """Get or create the global session store instance.

    Concurrent first callers share one store; it is published only after
    connect() has completed.
    """
    global _session_store, _session_store_lock

    if _session_store is not None:
        return _session_store

    if _session_store_lock is None:
        _session_store_lock = asyncio.Lock()

    async with _session_store_lock:
        if _session_store is None:
            ttl_hours = int(os.environ.get("SESSION_TTL_HOURS", "24"))
            store = SessionStore(session_ttl_hours=ttl_hours)
            await store.connect()
            _session_store = store

    return _session_store