"""
Session Middleware for FastAPI

Provides session-based authentication as an alternative to JWT.
Sessions are stored in Valkey with PostgreSQL fallback.

Usage:
    @app.get("/api/protected/profile")
    async def get_profile(session: Session = Depends(get_current_session)):
        return {"user_id": session.user_id}
"""

import os
import logging
from typing import Optional, Dict, Any, Callable
from functools import wraps

from fastapi import Request, HTTPException, Depends
from starlette.authentication import (
    AuthCredentials,
    AuthenticationBackend,
    BaseUser,
    UnauthenticatedUser,
)
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import Response

from .session_store import Session, SessionStore, get_session_store

logger = logging.getLogger(__name__)

# Cookie consulted for the session ID when no other name is configured.
DEFAULT_SESSION_COOKIE = "session_id"

# The only ENVIRONMENT value that enables the demo-session bypass.
_DEVELOPMENT_ENVIRONMENT = "development"


class SessionMiddleware(BaseHTTPMiddleware):
    """
    Middleware that loads the request's session and stores it on request.state.

    The session ID is looked up, in order, from:
        1. Authorization header: ``Bearer <session_id>``
        2. X-Session-ID header: ``<session_id>``
        3. Cookie: ``<session_cookie_name>=<session_id>``

    ``request.state.session`` is always set before the request continues:
    to whatever the session store returns for the ID, or to None when no ID
    was supplied or the store lookup raised.
    """

    def __init__(self, app, session_cookie_name: str = DEFAULT_SESSION_COOKIE):
        super().__init__(app)
        self.session_cookie_name = session_cookie_name

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """Attach the session (or None) to request.state, then continue the chain."""
        request.state.session = await self._load_session(request)
        return await call_next(request)

    async def _load_session(self, request: Request) -> Optional[Session]:
        """Look up the request's session; return None when absent or on store errors."""
        session_id = self._extract_session_id(request)
        if not session_id:
            return None
        try:
            store = await get_session_store()
            return await store.get_session(session_id)
        except Exception as e:
            # Best effort: a store failure must not break the request here;
            # endpoints that require a session will reject it downstream.
            logger.error("Failed to load session: %s", e)
            return None

    def _extract_session_id(self, request: Request) -> Optional[str]:
        """Extract the session ID using this middleware's configured cookie name."""
        return _extract_session_id_from_request(request, self.session_cookie_name)


def session_middleware(app, session_cookie_name: str = DEFAULT_SESSION_COOKIE):
    """Wrap *app* in a SessionMiddleware and return the wrapped app."""
    return SessionMiddleware(app, session_cookie_name)


async def get_current_session(request: Request) -> Session:
    """
    FastAPI dependency that returns the current session.

    Uses the session already placed on ``request.state`` (by SessionMiddleware
    or require_session) when present; otherwise extracts the session ID from
    the request and loads it from the session store.

    When no session ID is supplied and ENVIRONMENT is explicitly set to
    "development", a fixed demo session is returned instead of failing.

    Usage:
        @app.get("/api/protected/endpoint")
        async def protected(session: Session = Depends(get_current_session)):
            return {"user_id": session.user_id}

    Raises:
        HTTPException(401): no session ID supplied (outside development),
            the store returned no session for the ID, or the lookup failed.
    """
    # Fast path: an earlier layer already loaded the session.
    session = getattr(request.state, "session", None)
    if session is not None:
        return session

    # Middleware may not be installed (or found nothing); extract manually.
    session_id = _extract_session_id_from_request(request)
    if not session_id:
        if _dev_auth_bypass_enabled():
            return _get_demo_session()
        raise HTTPException(status_code=401, detail="Authentication required")

    try:
        store = await get_session_store()
        session = await store.get_session(session_id)
    except HTTPException:
        raise
    except Exception as e:
        logger.error("Session validation failed: %s", e)
        raise HTTPException(
            status_code=401, detail="Session validation failed"
        ) from e

    if not session:
        raise HTTPException(status_code=401, detail="Invalid or expired session")
    return session


def _dev_auth_bypass_enabled() -> bool:
    """
    Return True only when ENVIRONMENT is explicitly "development".

    An unset ENVIRONMENT is deliberately treated as non-development, so a
    deployment that forgets to set it fails closed (401) instead of giving
    every anonymous request the admin-level demo session.
    """
    return os.environ.get("ENVIRONMENT", "") == _DEVELOPMENT_ENVIRONMENT


async def get_optional_session(request: Request) -> Optional[Session]:
    """
    FastAPI dependency that returns the current session if there is one.

    Returns None instead of raising when get_current_session would raise a
    401. Useful for endpoints that behave differently for logged-in users.

    Note: this delegates to get_current_session, so when the development
    bypass is enabled a request without a session ID receives the demo
    session rather than None.

    Usage:
        @app.get("/api/public/endpoint")
        async def public(session: Optional[Session] = Depends(get_optional_session)):
            if session:
                return {"message": f"Hello, {session.email}"}
            return {"message": "Hello, anonymous"}
    """
    try:
        return await get_current_session(request)
    except HTTPException:
        return None


def require_session(func: Callable) -> Callable:
    """
    Decorator that requires a valid session for an endpoint.

    Alternative to using Depends(get_current_session). The resolved session
    is stored on ``request.state.session`` before the endpoint runs. The
    decorated endpoint must accept ``request`` as its first parameter: the
    wrapper receives it under that name and forwards it positionally.

    Usage:
        @app.get("/api/protected/endpoint")
        @require_session
        async def protected(request: Request):
            session = request.state.session
            return {"user_id": session.user_id}

    Raises:
        HTTPException(401): propagated from get_current_session.
    """

    @wraps(func)
    async def wrapper(request: Request, *args, **kwargs):
        session = await get_current_session(request)
        request.state.session = session
        return await func(request, *args, **kwargs)

    return wrapper


def _parse_bearer_token(auth_header: str) -> Optional[str]:
    """
    Return the credentials of a ``Bearer`` Authorization header, or None.

    The scheme is matched case-insensitively (RFC 7235 section 2.1). A header
    carrying the scheme but no token yields None, so callers fall back to the
    other session sources instead of using an empty ID.
    """
    scheme, _, credentials = auth_header.partition(" ")
    if scheme.lower() != "bearer":
        return None
    token = credentials.strip()
    return token or None


def _extract_session_id_from_request(
    request: Request, cookie_name: str = DEFAULT_SESSION_COOKIE
) -> Optional[str]:
    """
    Extract the session ID from request headers or cookies.

    Sources, in priority order: ``Authorization: Bearer <id>``, the
    ``X-Session-ID`` header, then the *cookie_name* cookie. Empty values are
    skipped.

    Args:
        request: Incoming request (only ``headers`` and ``cookies`` are read).
        cookie_name: Name of the cookie holding the session ID.

    Returns:
        The session ID, or None if no source provided a non-empty value.
    """
    token = _parse_bearer_token(request.headers.get("authorization", ""))
    if token:
        return token

    session_header = request.headers.get("x-session-id")
    if session_header:
        return session_header

    return request.cookies.get(cookie_name) or None


def _get_demo_session() -> Session:
    """Build the fixed demo session used by the development-mode bypass."""
    from .session_store import UserType

    return Session(
        session_id="demo-session-id",
        user_id="10000000-0000-0000-0000-000000000024",
        email="demo@breakpilot.app",
        user_type=UserType.EMPLOYEE,
        roles=["admin", "schul_admin", "teacher"],
        permissions=[
            "grades:read",
            "grades:write",
            "attendance:read",
            "attendance:write",
            "students:read",
            "students:write",
            "reports:generate",
            "consent:admin",
            "own_data:read",
            "users:manage",
        ],
        tenant_id="a0000000-0000-0000-0000-000000000001",
        ip_address="127.0.0.1",
        user_agent="Development",
    )


class SessionAuthBackend(AuthenticationBackend):
    """
    Authentication backend for Starlette AuthenticationMiddleware.

    Alternative way to integrate session auth: the session's permissions
    become the request's auth scopes and the user is a SessionUser.
    Any failure yields an unauthenticated user rather than an error.
    """

    async def authenticate(self, request: Request):
        """Authenticate the connection using its session ID."""
        anonymous = AuthCredentials([]), UnauthenticatedUser()

        session_id = _extract_session_id_from_request(request)
        if not session_id:
            return anonymous

        try:
            store = await get_session_store()
            session = await store.get_session(session_id)
        except Exception as e:
            # Deliberately non-fatal: treat the request as anonymous, but
            # leave a trace instead of silently swallowing the failure.
            logger.error("Session authentication failed: %s", e)
            return anonymous

        if not session:
            return anonymous
        return AuthCredentials(session.permissions), SessionUser(session)


class SessionUser(BaseUser):
    """Starlette user object backed by a Session."""

    def __init__(self, session: Session):
        self.session = session

    @property
    def is_authenticated(self) -> bool:
        return True

    @property
    def display_name(self) -> str:
        return self.session.email

    @property
    def identity(self) -> str:
        return self.session.user_id