fix: Restore all files lost during destructive rebase
A previous `git pull --rebase origin main` dropped 177 local commits,
losing 3400+ files across admin-v2, backend, studio-v2, website,
klausur-service, and many other services. The partial restore attempt
(660295e2) only recovered some files.
This commit restores all missing files from pre-rebase ref 98933f5e
while preserving post-rebase additions (night-scheduler, night-mode UI,
NightModeWidget dashboard integration).
Restored features include:
- AI Module Sidebar (FAB), OCR Labeling, OCR Compare
- GPU Dashboard, RAG Pipeline, Magic Help
- Klausur-Korrektur (8 files), Abitur-Archiv (5+ files)
- Companion, Zeugnisse-Crawler, Screen Flow
- Full backend, studio-v2, website, klausur-service
- All compliance SDKs, agent-core, voice-service
- CI/CD configs, documentation, scripts
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
52
backend/session/__init__.py
Normal file
52
backend/session/__init__.py
Normal file
@@ -0,0 +1,52 @@
|
||||
"""
|
||||
Session Management Module for BreakPilot
|
||||
|
||||
Hybrid session storage using Valkey (Redis-fork) for fast lookups
|
||||
and PostgreSQL for persistence and DSGVO audit trail.
|
||||
|
||||
Components:
|
||||
- session_store.py: Hybrid Valkey + PostgreSQL session storage
|
||||
- session_middleware.py: FastAPI middleware for session-based auth
|
||||
- rbac_middleware.py: User type and permission checking
|
||||
"""
|
||||
|
||||
from .session_store import (
|
||||
SessionStore,
|
||||
Session,
|
||||
UserType,
|
||||
get_session_store,
|
||||
)
|
||||
from .session_middleware import (
|
||||
get_current_session,
|
||||
require_session,
|
||||
session_middleware,
|
||||
)
|
||||
from .rbac_middleware import (
|
||||
require_user_type,
|
||||
require_permission,
|
||||
require_any_permission,
|
||||
require_employee,
|
||||
require_customer,
|
||||
EMPLOYEE_PERMISSIONS,
|
||||
CUSTOMER_PERMISSIONS,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# Session Store
|
||||
"SessionStore",
|
||||
"Session",
|
||||
"UserType",
|
||||
"get_session_store",
|
||||
# Session Middleware
|
||||
"get_current_session",
|
||||
"require_session",
|
||||
"session_middleware",
|
||||
# RBAC Middleware
|
||||
"require_user_type",
|
||||
"require_permission",
|
||||
"require_any_permission",
|
||||
"require_employee",
|
||||
"require_customer",
|
||||
"EMPLOYEE_PERMISSIONS",
|
||||
"CUSTOMER_PERMISSIONS",
|
||||
]
|
||||
141
backend/session/cleanup_job.py
Normal file
141
backend/session/cleanup_job.py
Normal file
@@ -0,0 +1,141 @@
|
||||
"""
|
||||
Session Cleanup Job
|
||||
|
||||
Removes expired sessions from PostgreSQL.
|
||||
Valkey handles its own expiry via TTL.
|
||||
|
||||
This job should be run periodically (e.g., via cron or APScheduler).
|
||||
|
||||
Usage:
|
||||
# Run directly
|
||||
python -m session.cleanup_job
|
||||
|
||||
# Or import and call
|
||||
from session.cleanup_job import run_cleanup
|
||||
|
||||
await run_cleanup()
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
from datetime import datetime, timezone
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def run_cleanup():
|
||||
"""Run session cleanup job."""
|
||||
from .session_store import get_session_store
|
||||
|
||||
logger.info("Starting session cleanup job...")
|
||||
|
||||
try:
|
||||
store = await get_session_store()
|
||||
count = await store.cleanup_expired_sessions()
|
||||
logger.info(f"Session cleanup completed: removed {count} expired sessions")
|
||||
return count
|
||||
except Exception as e:
|
||||
logger.error(f"Session cleanup failed: {e}")
|
||||
raise
|
||||
|
||||
|
||||
async def run_cleanup_with_pg():
|
||||
"""
|
||||
Run cleanup directly with PostgreSQL connection.
|
||||
|
||||
Useful when session store is not initialized.
|
||||
"""
|
||||
database_url = os.environ.get("DATABASE_URL")
|
||||
if not database_url:
|
||||
logger.warning("DATABASE_URL not set, skipping cleanup")
|
||||
return 0
|
||||
|
||||
try:
|
||||
import asyncpg
|
||||
|
||||
conn = await asyncpg.connect(database_url)
|
||||
try:
|
||||
# Delete sessions expired more than 7 days ago
|
||||
result = await conn.execute("""
|
||||
DELETE FROM user_sessions
|
||||
WHERE expires_at < NOW() - INTERVAL '7 days'
|
||||
""")
|
||||
count = int(result.split()[-1]) if result else 0
|
||||
logger.info(f"Session cleanup completed: removed {count} expired sessions")
|
||||
return count
|
||||
finally:
|
||||
await conn.close()
|
||||
except Exception as e:
|
||||
logger.error(f"Session cleanup failed: {e}")
|
||||
return 0
|
||||
|
||||
|
||||
def setup_scheduler():
|
||||
"""
|
||||
Setup APScheduler for periodic cleanup.
|
||||
|
||||
Runs cleanup every 6 hours.
|
||||
"""
|
||||
try:
|
||||
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
||||
from apscheduler.triggers.interval import IntervalTrigger
|
||||
|
||||
scheduler = AsyncIOScheduler()
|
||||
|
||||
scheduler.add_job(
|
||||
run_cleanup,
|
||||
trigger=IntervalTrigger(hours=6),
|
||||
id="session_cleanup",
|
||||
name="Session Cleanup Job",
|
||||
replace_existing=True,
|
||||
)
|
||||
|
||||
scheduler.start()
|
||||
logger.info("Session cleanup scheduler started (runs every 6 hours)")
|
||||
|
||||
return scheduler
|
||||
|
||||
except ImportError:
|
||||
logger.warning("APScheduler not installed, cleanup job not scheduled")
|
||||
return None
|
||||
|
||||
|
||||
def register_with_fastapi(app):
|
||||
"""
|
||||
Register cleanup job with FastAPI app lifecycle.
|
||||
|
||||
Usage:
|
||||
from session.cleanup_job import register_with_fastapi
|
||||
register_with_fastapi(app)
|
||||
"""
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
scheduler = None
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app):
|
||||
nonlocal scheduler
|
||||
# Startup
|
||||
scheduler = setup_scheduler()
|
||||
# Run initial cleanup
|
||||
asyncio.create_task(run_cleanup())
|
||||
|
||||
yield
|
||||
|
||||
# Shutdown
|
||||
if scheduler:
|
||||
scheduler.shutdown()
|
||||
|
||||
return lifespan
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
||||
)
|
||||
|
||||
# Run cleanup
|
||||
asyncio.run(run_cleanup_with_pg())
|
||||
389
backend/session/protected_routes.py
Normal file
389
backend/session/protected_routes.py
Normal file
@@ -0,0 +1,389 @@
|
||||
"""
|
||||
Protected Routes Example
|
||||
|
||||
Shows how to structure routes under /api/protected with session-based auth.
|
||||
|
||||
Route structure:
|
||||
/api/auth/ - Public (login, register, logout)
|
||||
/api/public/ - Public (health, docs)
|
||||
/api/protected/ - Authenticated (all users)
|
||||
/api/protected/employee/ - Employees only
|
||||
/api/protected/customer/ - Customers only
|
||||
/api/admin/ - Admins only
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from typing import List, Optional
|
||||
|
||||
from .session_store import Session, UserType
|
||||
from .session_middleware import get_current_session, get_optional_session
|
||||
from .rbac_middleware import (
|
||||
require_employee,
|
||||
require_customer,
|
||||
require_permission,
|
||||
require_role,
|
||||
require_any_role,
|
||||
)
|
||||
|
||||
# =============================================
|
||||
# Router Setup
|
||||
# =============================================
|
||||
|
||||
# Protected routes - require authentication
|
||||
protected_router = APIRouter(prefix="/api/protected", tags=["Protected"])
|
||||
|
||||
# Employee-only routes
|
||||
employee_router = APIRouter(prefix="/api/protected/employee", tags=["Employee"])
|
||||
|
||||
# Customer-only routes
|
||||
customer_router = APIRouter(prefix="/api/protected/customer", tags=["Customer"])
|
||||
|
||||
# Admin routes
|
||||
admin_router = APIRouter(prefix="/api/admin", tags=["Admin"])
|
||||
|
||||
|
||||
# =============================================
|
||||
# Protected Routes (All Authenticated Users)
|
||||
# =============================================
|
||||
|
||||
@protected_router.get("/profile")
|
||||
async def get_profile(session: Session = Depends(get_current_session)):
|
||||
"""Get current user's profile."""
|
||||
return {
|
||||
"user_id": session.user_id,
|
||||
"email": session.email,
|
||||
"user_type": session.user_type.value,
|
||||
"roles": session.roles,
|
||||
"permissions": session.permissions,
|
||||
"tenant_id": session.tenant_id,
|
||||
}
|
||||
|
||||
|
||||
@protected_router.get("/notifications")
|
||||
async def get_notifications(session: Session = Depends(get_current_session)):
|
||||
"""Get user's notifications."""
|
||||
# TODO: Implement actual notification fetching
|
||||
return {
|
||||
"notifications": [],
|
||||
"unread_count": 0,
|
||||
}
|
||||
|
||||
|
||||
@protected_router.post("/logout")
|
||||
async def logout(session: Session = Depends(get_current_session)):
|
||||
"""Logout current session."""
|
||||
from .session_store import get_session_store
|
||||
|
||||
store = await get_session_store()
|
||||
await store.revoke_session(session.session_id)
|
||||
|
||||
return {"message": "Logged out successfully"}
|
||||
|
||||
|
||||
@protected_router.post("/logout-all")
|
||||
async def logout_all(session: Session = Depends(get_current_session)):
|
||||
"""Logout from all devices."""
|
||||
from .session_store import get_session_store
|
||||
|
||||
store = await get_session_store()
|
||||
count = await store.revoke_all_user_sessions(session.user_id)
|
||||
|
||||
return {"message": f"Logged out from {count} sessions"}
|
||||
|
||||
|
||||
@protected_router.get("/sessions")
|
||||
async def get_active_sessions(session: Session = Depends(get_current_session)):
|
||||
"""Get all active sessions for current user."""
|
||||
from .session_store import get_session_store
|
||||
|
||||
store = await get_session_store()
|
||||
sessions = await store.get_active_sessions(session.user_id)
|
||||
|
||||
return {
|
||||
"sessions": [
|
||||
{
|
||||
"session_id": s.session_id,
|
||||
"ip_address": s.ip_address,
|
||||
"user_agent": s.user_agent,
|
||||
"created_at": s.created_at.isoformat() if s.created_at else None,
|
||||
"last_activity_at": s.last_activity_at.isoformat() if s.last_activity_at else None,
|
||||
"is_current": s.session_id == session.session_id,
|
||||
}
|
||||
for s in sessions
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
# =============================================
|
||||
# Employee Routes
|
||||
# =============================================
|
||||
|
||||
@employee_router.get("/dashboard")
|
||||
async def employee_dashboard(session: Session = Depends(require_employee)):
|
||||
"""Employee dashboard with overview data."""
|
||||
return {
|
||||
"user_type": "employee",
|
||||
"email": session.email,
|
||||
"roles": session.roles,
|
||||
"widgets": [
|
||||
{"type": "today_classes", "title": "Heutige Stunden"},
|
||||
{"type": "pending_corrections", "title": "Ausstehende Korrekturen"},
|
||||
{"type": "absent_students", "title": "Abwesende Schueler"},
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
@employee_router.get("/grades")
|
||||
async def get_grades(
|
||||
class_id: Optional[str] = None,
|
||||
session: Session = Depends(require_permission("grades:read"))
|
||||
):
|
||||
"""Get grades (employee only, requires grades:read permission)."""
|
||||
# TODO: Implement actual grade fetching
|
||||
return {
|
||||
"grades": [],
|
||||
"class_id": class_id,
|
||||
}
|
||||
|
||||
|
||||
@employee_router.post("/grades")
|
||||
async def create_grade(
|
||||
grade_data: dict,
|
||||
session: Session = Depends(require_permission("grades:write"))
|
||||
):
|
||||
"""Create a new grade (requires grades:write permission)."""
|
||||
# TODO: Implement grade creation
|
||||
return {"message": "Grade created"}
|
||||
|
||||
|
||||
@employee_router.get("/attendance")
|
||||
async def get_attendance(
|
||||
date: Optional[str] = None,
|
||||
class_id: Optional[str] = None,
|
||||
session: Session = Depends(require_permission("attendance:read"))
|
||||
):
|
||||
"""Get attendance records."""
|
||||
return {
|
||||
"attendance": [],
|
||||
"date": date,
|
||||
"class_id": class_id,
|
||||
}
|
||||
|
||||
|
||||
@employee_router.post("/attendance")
|
||||
async def mark_attendance(
|
||||
attendance_data: dict,
|
||||
session: Session = Depends(require_permission("attendance:write"))
|
||||
):
|
||||
"""Mark student attendance."""
|
||||
return {"message": "Attendance recorded"}
|
||||
|
||||
|
||||
@employee_router.get("/students")
|
||||
async def get_students(
|
||||
class_id: Optional[str] = None,
|
||||
session: Session = Depends(require_permission("students:read"))
|
||||
):
|
||||
"""Get student list."""
|
||||
return {
|
||||
"students": [],
|
||||
"class_id": class_id,
|
||||
}
|
||||
|
||||
|
||||
@employee_router.get("/corrections")
|
||||
async def get_corrections(
|
||||
session: Session = Depends(require_permission("corrections:read"))
|
||||
):
|
||||
"""Get pending corrections."""
|
||||
return {
|
||||
"corrections": [],
|
||||
"pending_count": 0,
|
||||
}
|
||||
|
||||
|
||||
# =============================================
|
||||
# Customer Routes
|
||||
# =============================================
|
||||
|
||||
@customer_router.get("/dashboard")
|
||||
async def customer_dashboard(session: Session = Depends(require_customer)):
|
||||
"""Customer dashboard."""
|
||||
return {
|
||||
"user_type": "customer",
|
||||
"email": session.email,
|
||||
"widgets": [
|
||||
{"type": "my_children", "title": "Meine Kinder"},
|
||||
{"type": "upcoming_meetings", "title": "Anstehende Termine"},
|
||||
{"type": "recent_grades", "title": "Aktuelle Noten"},
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
@customer_router.get("/my-children")
|
||||
async def get_my_children(
|
||||
session: Session = Depends(require_permission("children:read"))
|
||||
):
|
||||
"""Get parent's children."""
|
||||
# TODO: Implement actual children fetching
|
||||
return {
|
||||
"children": [],
|
||||
}
|
||||
|
||||
|
||||
@customer_router.get("/my-grades")
|
||||
async def get_my_grades(
|
||||
session: Session = Depends(require_permission("own_grades:read"))
|
||||
):
|
||||
"""Get student's own grades."""
|
||||
return {
|
||||
"grades": [],
|
||||
"average": None,
|
||||
}
|
||||
|
||||
|
||||
@customer_router.get("/my-attendance")
|
||||
async def get_my_attendance(
|
||||
session: Session = Depends(require_permission("own_attendance:read"))
|
||||
):
|
||||
"""Get student's own attendance."""
|
||||
return {
|
||||
"attendance_records": [],
|
||||
"absence_count": 0,
|
||||
}
|
||||
|
||||
|
||||
@customer_router.get("/children/{child_id}/grades")
|
||||
async def get_child_grades(
|
||||
child_id: str,
|
||||
session: Session = Depends(require_permission("children:grades:read"))
|
||||
):
|
||||
"""Get child's grades (for parents)."""
|
||||
# TODO: Verify parent-child relationship
|
||||
return {
|
||||
"child_id": child_id,
|
||||
"grades": [],
|
||||
}
|
||||
|
||||
|
||||
@customer_router.get("/appointments")
|
||||
async def get_appointments(session: Session = Depends(require_customer)):
|
||||
"""Get upcoming appointments/meetings."""
|
||||
return {
|
||||
"appointments": [],
|
||||
}
|
||||
|
||||
|
||||
@customer_router.post("/appointments/{slot_id}/book")
|
||||
async def book_appointment(
|
||||
slot_id: str,
|
||||
session: Session = Depends(require_permission("meetings:join"))
|
||||
):
|
||||
"""Book a parent meeting slot."""
|
||||
return {
|
||||
"message": "Appointment booked",
|
||||
"slot_id": slot_id,
|
||||
}
|
||||
|
||||
|
||||
# =============================================
|
||||
# Admin Routes
|
||||
# =============================================
|
||||
|
||||
@admin_router.get("/users")
|
||||
async def list_users(
|
||||
page: int = 1,
|
||||
limit: int = 50,
|
||||
session: Session = Depends(require_permission("users:read"))
|
||||
):
|
||||
"""List all users (admin only)."""
|
||||
return {
|
||||
"users": [],
|
||||
"total": 0,
|
||||
"page": page,
|
||||
"limit": limit,
|
||||
}
|
||||
|
||||
|
||||
@admin_router.get("/users/{user_id}")
|
||||
async def get_user(
|
||||
user_id: str,
|
||||
session: Session = Depends(require_permission("users:read"))
|
||||
):
|
||||
"""Get user details."""
|
||||
return {
|
||||
"user_id": user_id,
|
||||
"email": None,
|
||||
"roles": [],
|
||||
}
|
||||
|
||||
|
||||
@admin_router.put("/users/{user_id}/roles")
|
||||
async def update_user_roles(
|
||||
user_id: str,
|
||||
roles: List[str],
|
||||
session: Session = Depends(require_permission("users:manage"))
|
||||
):
|
||||
"""Update user roles (admin only)."""
|
||||
return {
|
||||
"message": "Roles updated",
|
||||
"user_id": user_id,
|
||||
"roles": roles,
|
||||
}
|
||||
|
||||
|
||||
@admin_router.get("/audit-log")
|
||||
async def get_audit_log(
|
||||
page: int = 1,
|
||||
limit: int = 100,
|
||||
session: Session = Depends(require_permission("audit:read"))
|
||||
):
|
||||
"""Get audit log entries."""
|
||||
return {
|
||||
"entries": [],
|
||||
"total": 0,
|
||||
"page": page,
|
||||
"limit": limit,
|
||||
}
|
||||
|
||||
|
||||
@admin_router.get("/rbac/roles")
|
||||
async def list_roles(
|
||||
session: Session = Depends(require_permission("rbac:read"))
|
||||
):
|
||||
"""List all RBAC roles."""
|
||||
return {
|
||||
"roles": [],
|
||||
}
|
||||
|
||||
|
||||
@admin_router.get("/rbac/permissions")
|
||||
async def list_permissions(
|
||||
session: Session = Depends(require_permission("rbac:read"))
|
||||
):
|
||||
"""List all permissions."""
|
||||
from .rbac_middleware import EMPLOYEE_PERMISSIONS, CUSTOMER_PERMISSIONS, ADMIN_PERMISSIONS
|
||||
|
||||
return {
|
||||
"employee_permissions": EMPLOYEE_PERMISSIONS,
|
||||
"customer_permissions": CUSTOMER_PERMISSIONS,
|
||||
"admin_permissions": ADMIN_PERMISSIONS,
|
||||
}
|
||||
|
||||
|
||||
# =============================================
|
||||
# Router Registration Helper
|
||||
# =============================================
|
||||
|
||||
def register_protected_routes(app):
|
||||
"""
|
||||
Register all protected route routers with FastAPI app.
|
||||
|
||||
Usage:
|
||||
from session.protected_routes import register_protected_routes
|
||||
register_protected_routes(app)
|
||||
"""
|
||||
app.include_router(protected_router)
|
||||
app.include_router(employee_router)
|
||||
app.include_router(customer_router)
|
||||
app.include_router(admin_router)
|
||||
428
backend/session/rbac_middleware.py
Normal file
428
backend/session/rbac_middleware.py
Normal file
@@ -0,0 +1,428 @@
|
||||
"""
|
||||
RBAC Middleware for Session-Based Authentication
|
||||
|
||||
Provides user type checking (Employee vs. Customer) and
|
||||
permission-based access control.
|
||||
|
||||
Employee roles: Teacher-Rollen, Admin-Rollen
|
||||
Customer roles: parent, student, user
|
||||
|
||||
Usage:
|
||||
@app.get("/api/protected/employee/grades")
|
||||
async def get_grades(session: Session = Depends(require_employee)):
|
||||
return {"grades": [...]}
|
||||
|
||||
@app.get("/api/protected/endpoint")
|
||||
async def protected(session: Session = Depends(require_permission("grades:read"))):
|
||||
return {"data": [...]}
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import List, Callable
|
||||
from functools import wraps
|
||||
|
||||
from fastapi import HTTPException, Depends
|
||||
|
||||
from .session_store import Session, UserType
|
||||
from .session_middleware import get_current_session
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# =============================================
|
||||
# Permission Constants
|
||||
# =============================================
|
||||
|
||||
EMPLOYEE_PERMISSIONS = [
|
||||
# Grades & Attendance
|
||||
"grades:read",
|
||||
"grades:write",
|
||||
"attendance:read",
|
||||
"attendance:write",
|
||||
# Student Management
|
||||
"students:read",
|
||||
"students:write",
|
||||
# Reports & Consent
|
||||
"reports:generate",
|
||||
"consent:admin",
|
||||
# Corrections
|
||||
"corrections:read",
|
||||
"corrections:write",
|
||||
# Classes
|
||||
"classes:read",
|
||||
"classes:write",
|
||||
# Timetable
|
||||
"timetable:read",
|
||||
"timetable:write",
|
||||
# Substitutions
|
||||
"substitutions:read",
|
||||
"substitutions:write",
|
||||
# Parent Communication
|
||||
"parent_communication:read",
|
||||
"parent_communication:write",
|
||||
]
|
||||
|
||||
CUSTOMER_PERMISSIONS = [
|
||||
# Own Data Access
|
||||
"own_data:read",
|
||||
"own_grades:read",
|
||||
"own_attendance:read",
|
||||
# Consent Management
|
||||
"consent:manage",
|
||||
# Meetings & Communication
|
||||
"meetings:join",
|
||||
"messages:read",
|
||||
"messages:write",
|
||||
# Children (for parents)
|
||||
"children:read",
|
||||
"children:grades:read",
|
||||
"children:attendance:read",
|
||||
]
|
||||
|
||||
ADMIN_PERMISSIONS = [
|
||||
# User Management
|
||||
"users:read",
|
||||
"users:write",
|
||||
"users:manage",
|
||||
# RBAC Management
|
||||
"rbac:read",
|
||||
"rbac:write",
|
||||
# Audit & Logs
|
||||
"audit:read",
|
||||
# System Settings
|
||||
"settings:read",
|
||||
"settings:write",
|
||||
# DSR Management
|
||||
"dsr:read",
|
||||
"dsr:process",
|
||||
]
|
||||
|
||||
# Roles that indicate employee user type
|
||||
EMPLOYEE_ROLES = {
|
||||
"admin",
|
||||
"schul_admin",
|
||||
"schulleitung",
|
||||
"pruefungsvorsitz",
|
||||
"klassenlehrer",
|
||||
"fachlehrer",
|
||||
"sekretariat",
|
||||
"erstkorrektor",
|
||||
"zweitkorrektor",
|
||||
"drittkorrektor",
|
||||
"teacher_assistant",
|
||||
"teacher",
|
||||
"lehrer",
|
||||
"data_protection_officer",
|
||||
}
|
||||
|
||||
# Roles that indicate customer user type
|
||||
CUSTOMER_ROLES = {
|
||||
"parent",
|
||||
"student",
|
||||
"user",
|
||||
"guardian",
|
||||
}
|
||||
|
||||
|
||||
# =============================================
|
||||
# User Type Dependencies
|
||||
# =============================================
|
||||
|
||||
async def require_employee(session: Session = Depends(get_current_session)) -> Session:
|
||||
"""
|
||||
Require user to be an employee (internal staff).
|
||||
|
||||
Usage:
|
||||
@app.get("/api/protected/employee/grades")
|
||||
async def employee_only(session: Session = Depends(require_employee)):
|
||||
return {"grades": [...]}
|
||||
"""
|
||||
if not session.is_employee():
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="Employee access required"
|
||||
)
|
||||
return session
|
||||
|
||||
|
||||
async def require_customer(session: Session = Depends(get_current_session)) -> Session:
|
||||
"""
|
||||
Require user to be a customer (external user).
|
||||
|
||||
Usage:
|
||||
@app.get("/api/protected/customer/my-grades")
|
||||
async def customer_only(session: Session = Depends(require_customer)):
|
||||
return {"my_grades": [...]}
|
||||
"""
|
||||
if not session.is_customer():
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="Customer access required"
|
||||
)
|
||||
return session
|
||||
|
||||
|
||||
def require_user_type(user_type: UserType):
|
||||
"""
|
||||
Factory for user type dependency.
|
||||
|
||||
Usage:
|
||||
@app.get("/api/protected/endpoint")
|
||||
async def endpoint(session: Session = Depends(require_user_type(UserType.EMPLOYEE))):
|
||||
return {"data": [...]}
|
||||
"""
|
||||
async def user_type_checker(session: Session = Depends(get_current_session)) -> Session:
|
||||
if session.user_type != user_type:
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail=f"User type '{user_type.value}' required"
|
||||
)
|
||||
return session
|
||||
|
||||
return user_type_checker
|
||||
|
||||
|
||||
# =============================================
|
||||
# Permission Dependencies
|
||||
# =============================================
|
||||
|
||||
def require_permission(permission: str):
|
||||
"""
|
||||
Factory for permission-based access control.
|
||||
|
||||
Usage:
|
||||
@app.get("/api/protected/grades")
|
||||
async def get_grades(session: Session = Depends(require_permission("grades:read"))):
|
||||
return {"grades": [...]}
|
||||
"""
|
||||
async def permission_checker(session: Session = Depends(get_current_session)) -> Session:
|
||||
if not session.has_permission(permission):
|
||||
logger.warning(
|
||||
f"Permission denied: user {session.user_id} lacks '{permission}'"
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail=f"Permission '{permission}' required"
|
||||
)
|
||||
return session
|
||||
|
||||
return permission_checker
|
||||
|
||||
|
||||
def require_any_permission(permissions: List[str]):
|
||||
"""
|
||||
Require user to have at least one of the specified permissions.
|
||||
|
||||
Usage:
|
||||
@app.get("/api/protected/data")
|
||||
async def get_data(
|
||||
session: Session = Depends(require_any_permission(["data:read", "admin"]))
|
||||
):
|
||||
return {"data": [...]}
|
||||
"""
|
||||
async def any_permission_checker(session: Session = Depends(get_current_session)) -> Session:
|
||||
if not session.has_any_permission(permissions):
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail=f"One of permissions {permissions} required"
|
||||
)
|
||||
return session
|
||||
|
||||
return any_permission_checker
|
||||
|
||||
|
||||
def require_all_permissions(permissions: List[str]):
|
||||
"""
|
||||
Require user to have all specified permissions.
|
||||
|
||||
Usage:
|
||||
@app.get("/api/protected/sensitive")
|
||||
async def sensitive(
|
||||
session: Session = Depends(require_all_permissions(["grades:read", "grades:write"]))
|
||||
):
|
||||
return {"data": [...]}
|
||||
"""
|
||||
async def all_permissions_checker(session: Session = Depends(get_current_session)) -> Session:
|
||||
if not session.has_all_permissions(permissions):
|
||||
missing = [p for p in permissions if not session.has_permission(p)]
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail=f"Missing permissions: {missing}"
|
||||
)
|
||||
return session
|
||||
|
||||
return all_permissions_checker
|
||||
|
||||
|
||||
def require_role(role: str):
|
||||
"""
|
||||
Factory for role-based access control.
|
||||
|
||||
Usage:
|
||||
@app.get("/api/admin/users")
|
||||
async def admin_users(session: Session = Depends(require_role("admin"))):
|
||||
return {"users": [...]}
|
||||
"""
|
||||
async def role_checker(session: Session = Depends(get_current_session)) -> Session:
|
||||
if not session.has_role(role):
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail=f"Role '{role}' required"
|
||||
)
|
||||
return session
|
||||
|
||||
return role_checker
|
||||
|
||||
|
||||
def require_any_role(roles: List[str]):
|
||||
"""
|
||||
Require user to have at least one of the specified roles.
|
||||
|
||||
Usage:
|
||||
@app.get("/api/management/data")
|
||||
async def management(
|
||||
session: Session = Depends(require_any_role(["admin", "schulleitung"]))
|
||||
):
|
||||
return {"data": [...]}
|
||||
"""
|
||||
async def any_role_checker(session: Session = Depends(get_current_session)) -> Session:
|
||||
if not any(session.has_role(role) for role in roles):
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail=f"One of roles {roles} required"
|
||||
)
|
||||
return session
|
||||
|
||||
return any_role_checker
|
||||
|
||||
|
||||
# =============================================
|
||||
# Tenant Isolation
|
||||
# =============================================
|
||||
|
||||
def require_same_tenant(tenant_id_param: str = "tenant_id"):
|
||||
"""
|
||||
Ensure user can only access data within their tenant.
|
||||
|
||||
Usage:
|
||||
@app.get("/api/protected/school/{tenant_id}/data")
|
||||
async def school_data(
|
||||
tenant_id: str,
|
||||
session: Session = Depends(require_same_tenant("tenant_id"))
|
||||
):
|
||||
return {"data": [...]}
|
||||
"""
|
||||
async def tenant_checker(
|
||||
session: Session = Depends(get_current_session),
|
||||
**kwargs
|
||||
) -> Session:
|
||||
request_tenant = kwargs.get(tenant_id_param)
|
||||
if request_tenant and session.tenant_id != request_tenant:
|
||||
# Check if user is super admin (can access all tenants)
|
||||
if not session.has_role("super_admin"):
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="Access denied to this tenant"
|
||||
)
|
||||
return session
|
||||
|
||||
return tenant_checker
|
||||
|
||||
|
||||
# =============================================
|
||||
# Utility Functions
|
||||
# =============================================
|
||||
|
||||
def determine_user_type(roles: List[str]) -> UserType:
|
||||
"""
|
||||
Determine user type based on roles.
|
||||
|
||||
Employee roles take precedence over customer roles.
|
||||
"""
|
||||
role_set = set(roles)
|
||||
|
||||
# Check for employee roles
|
||||
if role_set & EMPLOYEE_ROLES:
|
||||
return UserType.EMPLOYEE
|
||||
|
||||
# Check for customer roles
|
||||
if role_set & CUSTOMER_ROLES:
|
||||
return UserType.CUSTOMER
|
||||
|
||||
# Default to customer
|
||||
return UserType.CUSTOMER
|
||||
|
||||
|
||||
def get_permissions_for_roles(roles: List[str], user_type: UserType) -> List[str]:
|
||||
"""
|
||||
Get permissions based on roles and user type.
|
||||
|
||||
This is a basic implementation - in production, you'd query
|
||||
the RBAC database for role-permission mappings.
|
||||
"""
|
||||
permissions = set()
|
||||
|
||||
# Base permissions based on user type
|
||||
if user_type == UserType.EMPLOYEE:
|
||||
permissions.update(EMPLOYEE_PERMISSIONS)
|
||||
else:
|
||||
permissions.update(CUSTOMER_PERMISSIONS)
|
||||
|
||||
# Admin permissions
|
||||
role_set = set(roles)
|
||||
admin_roles = {"admin", "schul_admin", "super_admin", "data_protection_officer"}
|
||||
if role_set & admin_roles:
|
||||
permissions.update(ADMIN_PERMISSIONS)
|
||||
|
||||
return list(permissions)
|
||||
|
||||
|
||||
def check_resource_ownership(
|
||||
session: Session,
|
||||
resource_user_id: str,
|
||||
allow_admin: bool = True
|
||||
) -> bool:
|
||||
"""
|
||||
Check if user owns a resource or is admin.
|
||||
|
||||
Usage:
|
||||
if not check_resource_ownership(session, grade.student_id):
|
||||
raise HTTPException(403, "Access denied")
|
||||
"""
|
||||
# User owns the resource
|
||||
if session.user_id == resource_user_id:
|
||||
return True
|
||||
|
||||
# Admin can access all
|
||||
if allow_admin and (session.has_role("admin") or session.has_role("super_admin")):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def check_parent_child_access(
|
||||
session: Session,
|
||||
student_id: str,
|
||||
parent_student_ids: List[str]
|
||||
) -> bool:
|
||||
"""
|
||||
Check if parent has access to student's data.
|
||||
|
||||
Usage:
|
||||
parent_children = get_parent_children(session.user_id)
|
||||
if not check_parent_child_access(session, student_id, parent_children):
|
||||
raise HTTPException(403, "Access denied")
|
||||
"""
|
||||
# User is the student
|
||||
if session.user_id == student_id:
|
||||
return True
|
||||
|
||||
# User is parent of student
|
||||
if student_id in parent_student_ids:
|
||||
return True
|
||||
|
||||
# Employee can access (with appropriate permissions)
|
||||
if session.is_employee() and session.has_permission("students:read"):
|
||||
return True
|
||||
|
||||
return False
|
||||
240
backend/session/session_middleware.py
Normal file
240
backend/session/session_middleware.py
Normal file
@@ -0,0 +1,240 @@
|
||||
"""
|
||||
Session Middleware for FastAPI
|
||||
|
||||
Provides session-based authentication as an alternative to JWT.
|
||||
Sessions are stored in Valkey with PostgreSQL fallback.
|
||||
|
||||
Usage:
|
||||
@app.get("/api/protected/profile")
|
||||
async def get_profile(session: Session = Depends(get_current_session)):
|
||||
return {"user_id": session.user_id}
|
||||
"""
|
||||
|
||||
import os
|
||||
import logging
|
||||
from typing import Optional, Dict, Any, Callable
|
||||
from functools import wraps
|
||||
|
||||
from fastapi import Request, HTTPException, Depends
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.responses import Response
|
||||
|
||||
from .session_store import Session, SessionStore, get_session_store
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SessionMiddleware(BaseHTTPMiddleware):
|
||||
"""
|
||||
Middleware that extracts session from request and adds to request.state.
|
||||
|
||||
Session can be provided via:
|
||||
1. Authorization header: Bearer <session_id>
|
||||
2. Cookie: session_id=<session_id>
|
||||
"""
|
||||
|
||||
def __init__(self, app, session_cookie_name: str = "session_id"):
|
||||
super().__init__(app)
|
||||
self.session_cookie_name = session_cookie_name
|
||||
|
||||
async def dispatch(self, request: Request, call_next: Callable) -> Response:
|
||||
"""Extract session and add to request state."""
|
||||
session_id = self._extract_session_id(request)
|
||||
|
||||
if session_id:
|
||||
try:
|
||||
store = await get_session_store()
|
||||
session = await store.get_session(session_id)
|
||||
request.state.session = session
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load session: {e}")
|
||||
request.state.session = None
|
||||
else:
|
||||
request.state.session = None
|
||||
|
||||
response = await call_next(request)
|
||||
return response
|
||||
|
||||
def _extract_session_id(self, request: Request) -> Optional[str]:
|
||||
"""Extract session ID from request."""
|
||||
# Try Authorization header first
|
||||
auth_header = request.headers.get("authorization", "")
|
||||
if auth_header.startswith("Bearer "):
|
||||
return auth_header.split(" ")[1]
|
||||
|
||||
# Try cookie
|
||||
return request.cookies.get(self.session_cookie_name)
|
||||
|
||||
|
||||
def session_middleware(app, session_cookie_name: str = "session_id"):
|
||||
"""Factory function to add session middleware to app."""
|
||||
return SessionMiddleware(app, session_cookie_name)
|
||||
|
||||
|
||||
async def get_current_session(request: Request) -> Session:
|
||||
"""
|
||||
FastAPI dependency to get current session.
|
||||
|
||||
Raises 401 if no valid session found.
|
||||
|
||||
Usage:
|
||||
@app.get("/api/protected/endpoint")
|
||||
async def protected(session: Session = Depends(get_current_session)):
|
||||
return {"user_id": session.user_id}
|
||||
"""
|
||||
# Check if middleware added session to state
|
||||
session = getattr(request.state, "session", None)
|
||||
if session:
|
||||
return session
|
||||
|
||||
# Middleware might not be installed, try manual extraction
|
||||
session_id = _extract_session_id_from_request(request)
|
||||
if not session_id:
|
||||
# Check for development mode bypass
|
||||
environment = os.environ.get("ENVIRONMENT", "development")
|
||||
if environment == "development":
|
||||
# Return demo session in development
|
||||
return _get_demo_session()
|
||||
raise HTTPException(status_code=401, detail="Authentication required")
|
||||
|
||||
try:
|
||||
store = await get_session_store()
|
||||
session = await store.get_session(session_id)
|
||||
|
||||
if not session:
|
||||
raise HTTPException(status_code=401, detail="Invalid or expired session")
|
||||
|
||||
return session
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Session validation failed: {e}")
|
||||
raise HTTPException(status_code=401, detail="Session validation failed")
|
||||
|
||||
|
||||
async def get_optional_session(request: Request) -> Optional[Session]:
|
||||
"""
|
||||
FastAPI dependency to get current session if present.
|
||||
|
||||
Returns None if no session (doesn't raise exception).
|
||||
Useful for endpoints that behave differently for logged in users.
|
||||
|
||||
Usage:
|
||||
@app.get("/api/public/endpoint")
|
||||
async def public(session: Optional[Session] = Depends(get_optional_session)):
|
||||
if session:
|
||||
return {"message": f"Hello, {session.email}"}
|
||||
return {"message": "Hello, anonymous"}
|
||||
"""
|
||||
try:
|
||||
return await get_current_session(request)
|
||||
except HTTPException:
|
||||
return None
|
||||
|
||||
|
||||
def require_session(func: Callable) -> Callable:
|
||||
"""
|
||||
Decorator to require valid session for an endpoint.
|
||||
|
||||
Alternative to using Depends(get_current_session).
|
||||
|
||||
Usage:
|
||||
@app.get("/api/protected/endpoint")
|
||||
@require_session
|
||||
async def protected(request: Request):
|
||||
session = request.state.session
|
||||
return {"user_id": session.user_id}
|
||||
"""
|
||||
@wraps(func)
|
||||
async def wrapper(request: Request, *args, **kwargs):
|
||||
session = await get_current_session(request)
|
||||
request.state.session = session
|
||||
return await func(request, *args, **kwargs)
|
||||
return wrapper
|
||||
|
||||
|
||||
def _extract_session_id_from_request(request: Request) -> Optional[str]:
|
||||
"""Extract session ID from request headers or cookies."""
|
||||
# Try Authorization header
|
||||
auth_header = request.headers.get("authorization", "")
|
||||
if auth_header.startswith("Bearer "):
|
||||
return auth_header.split(" ")[1]
|
||||
|
||||
# Try X-Session-ID header
|
||||
session_header = request.headers.get("x-session-id")
|
||||
if session_header:
|
||||
return session_header
|
||||
|
||||
# Try cookie
|
||||
return request.cookies.get("session_id")
|
||||
|
||||
|
||||
def _get_demo_session() -> Session:
|
||||
"""Get demo session for development mode."""
|
||||
from .session_store import UserType
|
||||
|
||||
return Session(
|
||||
session_id="demo-session-id",
|
||||
user_id="10000000-0000-0000-0000-000000000024",
|
||||
email="demo@breakpilot.app",
|
||||
user_type=UserType.EMPLOYEE,
|
||||
roles=["admin", "schul_admin", "teacher"],
|
||||
permissions=[
|
||||
"grades:read", "grades:write",
|
||||
"attendance:read", "attendance:write",
|
||||
"students:read", "students:write",
|
||||
"reports:generate", "consent:admin",
|
||||
"own_data:read", "users:manage",
|
||||
],
|
||||
tenant_id="a0000000-0000-0000-0000-000000000001",
|
||||
ip_address="127.0.0.1",
|
||||
user_agent="Development",
|
||||
)
|
||||
|
||||
|
||||
class SessionAuthBackend:
|
||||
"""
|
||||
Authentication backend for Starlette AuthenticationMiddleware.
|
||||
|
||||
Alternative way to integrate session auth.
|
||||
"""
|
||||
|
||||
async def authenticate(self, request: Request):
|
||||
"""Authenticate request using session."""
|
||||
from starlette.authentication import (
|
||||
AuthCredentials, BaseUser, UnauthenticatedUser
|
||||
)
|
||||
|
||||
session_id = _extract_session_id_from_request(request)
|
||||
if not session_id:
|
||||
return AuthCredentials([]), UnauthenticatedUser()
|
||||
|
||||
try:
|
||||
store = await get_session_store()
|
||||
session = await store.get_session(session_id)
|
||||
|
||||
if not session:
|
||||
return AuthCredentials([]), UnauthenticatedUser()
|
||||
|
||||
return AuthCredentials(session.permissions), SessionUser(session)
|
||||
except Exception:
|
||||
return AuthCredentials([]), UnauthenticatedUser()
|
||||
|
||||
|
||||
class SessionUser:
|
||||
"""User object compatible with Starlette authentication."""
|
||||
|
||||
def __init__(self, session: Session):
|
||||
self.session = session
|
||||
|
||||
@property
|
||||
def is_authenticated(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def display_name(self) -> str:
|
||||
return self.session.email
|
||||
|
||||
@property
|
||||
def identity(self) -> str:
|
||||
return self.session.user_id
|
||||
550
backend/session/session_store.py
Normal file
550
backend/session/session_store.py
Normal file
@@ -0,0 +1,550 @@
|
||||
"""
|
||||
Hybrid Session Store: Valkey + PostgreSQL
|
||||
|
||||
Architecture:
|
||||
- Valkey: Fast session cache with 24-hour TTL
|
||||
- PostgreSQL: Persistent storage and DSGVO audit trail
|
||||
- Graceful fallback: If Valkey is down, fall back to PostgreSQL
|
||||
|
||||
Session data model:
|
||||
{
|
||||
"session_id": "uuid",
|
||||
"user_id": "uuid",
|
||||
"email": "string",
|
||||
"user_type": "employee|customer",
|
||||
"roles": ["role1", "role2"],
|
||||
"permissions": ["perm1", "perm2"],
|
||||
"tenant_id": "school-uuid",
|
||||
"ip_address": "string",
|
||||
"user_agent": "string",
|
||||
"created_at": "timestamp",
|
||||
"last_activity_at": "timestamp"
|
||||
}
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
import hashlib
|
||||
import logging
|
||||
from datetime import datetime, timezone, timedelta
|
||||
from typing import Optional, Dict, Any, List
|
||||
from dataclasses import dataclass, asdict, field
|
||||
from enum import Enum
|
||||
import asyncio
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class UserType(str, Enum):
|
||||
"""User type distinction for RBAC."""
|
||||
EMPLOYEE = "employee" # Internal staff (teachers, admins)
|
||||
CUSTOMER = "customer" # External users (parents, students)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Session:
|
||||
"""Session data model."""
|
||||
session_id: str
|
||||
user_id: str
|
||||
email: str
|
||||
user_type: UserType
|
||||
roles: List[str] = field(default_factory=list)
|
||||
permissions: List[str] = field(default_factory=list)
|
||||
tenant_id: Optional[str] = None
|
||||
ip_address: Optional[str] = None
|
||||
user_agent: Optional[str] = None
|
||||
created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
last_activity_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert to dictionary for JSON serialization."""
|
||||
return {
|
||||
"session_id": self.session_id,
|
||||
"user_id": self.user_id,
|
||||
"email": self.email,
|
||||
"user_type": self.user_type.value if isinstance(self.user_type, UserType) else self.user_type,
|
||||
"roles": self.roles,
|
||||
"permissions": self.permissions,
|
||||
"tenant_id": self.tenant_id,
|
||||
"ip_address": self.ip_address,
|
||||
"user_agent": self.user_agent,
|
||||
"created_at": self.created_at.isoformat() if self.created_at else None,
|
||||
"last_activity_at": self.last_activity_at.isoformat() if self.last_activity_at else None,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "Session":
|
||||
"""Create Session from dictionary."""
|
||||
user_type = data.get("user_type", "customer")
|
||||
if isinstance(user_type, str):
|
||||
user_type = UserType(user_type)
|
||||
|
||||
created_at = data.get("created_at")
|
||||
if isinstance(created_at, str):
|
||||
created_at = datetime.fromisoformat(created_at.replace("Z", "+00:00"))
|
||||
|
||||
last_activity_at = data.get("last_activity_at")
|
||||
if isinstance(last_activity_at, str):
|
||||
last_activity_at = datetime.fromisoformat(last_activity_at.replace("Z", "+00:00"))
|
||||
|
||||
return cls(
|
||||
session_id=data["session_id"],
|
||||
user_id=data["user_id"],
|
||||
email=data.get("email", ""),
|
||||
user_type=user_type,
|
||||
roles=data.get("roles", []),
|
||||
permissions=data.get("permissions", []),
|
||||
tenant_id=data.get("tenant_id"),
|
||||
ip_address=data.get("ip_address"),
|
||||
user_agent=data.get("user_agent"),
|
||||
created_at=created_at or datetime.now(timezone.utc),
|
||||
last_activity_at=last_activity_at or datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
def has_permission(self, permission: str) -> bool:
|
||||
"""Check if session has a specific permission."""
|
||||
return permission in self.permissions
|
||||
|
||||
def has_any_permission(self, permissions: List[str]) -> bool:
|
||||
"""Check if session has any of the specified permissions."""
|
||||
return any(p in self.permissions for p in permissions)
|
||||
|
||||
def has_all_permissions(self, permissions: List[str]) -> bool:
|
||||
"""Check if session has all specified permissions."""
|
||||
return all(p in self.permissions for p in permissions)
|
||||
|
||||
def has_role(self, role: str) -> bool:
|
||||
"""Check if session has a specific role."""
|
||||
return role in self.roles
|
||||
|
||||
def is_employee(self) -> bool:
|
||||
"""Check if user is an employee (internal staff)."""
|
||||
return self.user_type == UserType.EMPLOYEE
|
||||
|
||||
def is_customer(self) -> bool:
|
||||
"""Check if user is a customer (external user)."""
|
||||
return self.user_type == UserType.CUSTOMER
|
||||
|
||||
|
||||
class SessionStore:
|
||||
"""
|
||||
Hybrid session store using Valkey and PostgreSQL.
|
||||
|
||||
Valkey: Primary storage with 24h TTL for fast lookups
|
||||
PostgreSQL: Persistent backup and audit trail
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
valkey_url: Optional[str] = None,
|
||||
database_url: Optional[str] = None,
|
||||
session_ttl_hours: int = 24,
|
||||
):
|
||||
self.valkey_url = valkey_url or os.environ.get("VALKEY_URL", "redis://localhost:6379")
|
||||
self.database_url = database_url or os.environ.get("DATABASE_URL")
|
||||
self.session_ttl = timedelta(hours=session_ttl_hours)
|
||||
self.session_ttl_seconds = session_ttl_hours * 3600
|
||||
|
||||
self._valkey_client = None
|
||||
self._pg_pool = None
|
||||
self._valkey_available = True
|
||||
|
||||
async def connect(self):
|
||||
"""Initialize connections to Valkey and PostgreSQL."""
|
||||
await self._connect_valkey()
|
||||
await self._connect_postgres()
|
||||
|
||||
async def _connect_valkey(self):
|
||||
"""Connect to Valkey (Redis-compatible)."""
|
||||
try:
|
||||
import redis.asyncio as redis
|
||||
self._valkey_client = redis.from_url(
|
||||
self.valkey_url,
|
||||
encoding="utf-8",
|
||||
decode_responses=True,
|
||||
)
|
||||
# Test connection
|
||||
await self._valkey_client.ping()
|
||||
self._valkey_available = True
|
||||
logger.info("Connected to Valkey session cache")
|
||||
except ImportError:
|
||||
logger.warning("redis package not installed, Valkey unavailable")
|
||||
self._valkey_available = False
|
||||
except Exception as e:
|
||||
logger.warning(f"Valkey connection failed, falling back to PostgreSQL: {e}")
|
||||
self._valkey_available = False
|
||||
|
||||
async def _connect_postgres(self):
|
||||
"""Connect to PostgreSQL."""
|
||||
if not self.database_url:
|
||||
logger.warning("DATABASE_URL not set, PostgreSQL unavailable")
|
||||
return
|
||||
|
||||
try:
|
||||
import asyncpg
|
||||
self._pg_pool = await asyncpg.create_pool(
|
||||
self.database_url,
|
||||
min_size=2,
|
||||
max_size=10,
|
||||
)
|
||||
logger.info("Connected to PostgreSQL session store")
|
||||
except ImportError:
|
||||
logger.warning("asyncpg package not installed")
|
||||
except Exception as e:
|
||||
logger.error(f"PostgreSQL connection failed: {e}")
|
||||
|
||||
async def close(self):
|
||||
"""Close all connections."""
|
||||
if self._valkey_client:
|
||||
await self._valkey_client.close()
|
||||
if self._pg_pool:
|
||||
await self._pg_pool.close()
|
||||
|
||||
def _get_valkey_key(self, session_id: str) -> str:
|
||||
"""Generate Valkey key for session."""
|
||||
return f"session:{session_id}"
|
||||
|
||||
def _hash_token(self, token: str) -> str:
|
||||
"""Hash token for PostgreSQL storage."""
|
||||
return hashlib.sha256(token.encode()).hexdigest()
|
||||
|
||||
async def create_session(
|
||||
self,
|
||||
user_id: str,
|
||||
email: str,
|
||||
user_type: UserType,
|
||||
roles: List[str],
|
||||
permissions: List[str],
|
||||
tenant_id: Optional[str] = None,
|
||||
ip_address: Optional[str] = None,
|
||||
user_agent: Optional[str] = None,
|
||||
) -> Session:
|
||||
"""
|
||||
Create a new session.
|
||||
|
||||
Stores in both Valkey (with TTL) and PostgreSQL (persistent).
|
||||
Returns the session with generated session_id.
|
||||
"""
|
||||
import uuid
|
||||
|
||||
session = Session(
|
||||
session_id=str(uuid.uuid4()),
|
||||
user_id=user_id,
|
||||
email=email,
|
||||
user_type=user_type,
|
||||
roles=roles,
|
||||
permissions=permissions,
|
||||
tenant_id=tenant_id,
|
||||
ip_address=ip_address,
|
||||
user_agent=user_agent,
|
||||
)
|
||||
|
||||
# Store in Valkey (primary)
|
||||
if self._valkey_available and self._valkey_client:
|
||||
try:
|
||||
key = self._get_valkey_key(session.session_id)
|
||||
await self._valkey_client.setex(
|
||||
key,
|
||||
self.session_ttl_seconds,
|
||||
json.dumps(session.to_dict()),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to store session in Valkey: {e}")
|
||||
self._valkey_available = False
|
||||
|
||||
# Store in PostgreSQL (backup + audit)
|
||||
if self._pg_pool:
|
||||
try:
|
||||
async with self._pg_pool.acquire() as conn:
|
||||
await conn.execute(
|
||||
"""
|
||||
INSERT INTO user_sessions (
|
||||
id, user_id, token_hash, email, user_type, roles,
|
||||
permissions, tenant_id, ip_address, user_agent,
|
||||
expires_at, created_at, last_activity_at
|
||||
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)
|
||||
""",
|
||||
session.session_id,
|
||||
session.user_id,
|
||||
self._hash_token(session.session_id),
|
||||
session.email,
|
||||
session.user_type.value,
|
||||
json.dumps(session.roles),
|
||||
json.dumps(session.permissions),
|
||||
session.tenant_id,
|
||||
session.ip_address,
|
||||
session.user_agent,
|
||||
datetime.now(timezone.utc) + self.session_ttl,
|
||||
session.created_at,
|
||||
session.last_activity_at,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to store session in PostgreSQL: {e}")
|
||||
|
||||
return session
|
||||
|
||||
async def get_session(self, session_id: str) -> Optional[Session]:
|
||||
"""
|
||||
Get session by ID.
|
||||
|
||||
Tries Valkey first (fast), falls back to PostgreSQL.
|
||||
"""
|
||||
# Try Valkey first
|
||||
if self._valkey_available and self._valkey_client:
|
||||
try:
|
||||
key = self._get_valkey_key(session_id)
|
||||
data = await self._valkey_client.get(key)
|
||||
if data:
|
||||
session = Session.from_dict(json.loads(data))
|
||||
# Update last activity
|
||||
await self._update_last_activity(session_id)
|
||||
return session
|
||||
except Exception as e:
|
||||
logger.warning(f"Valkey lookup failed, trying PostgreSQL: {e}")
|
||||
self._valkey_available = False
|
||||
|
||||
# Fall back to PostgreSQL
|
||||
if self._pg_pool:
|
||||
try:
|
||||
async with self._pg_pool.acquire() as conn:
|
||||
row = await conn.fetchrow(
|
||||
"""
|
||||
SELECT id, user_id, email, user_type, roles, permissions,
|
||||
tenant_id, ip_address, user_agent, created_at, last_activity_at
|
||||
FROM user_sessions
|
||||
WHERE id = $1
|
||||
AND revoked_at IS NULL
|
||||
AND expires_at > NOW()
|
||||
""",
|
||||
session_id,
|
||||
)
|
||||
if row:
|
||||
session = Session(
|
||||
session_id=str(row["id"]),
|
||||
user_id=str(row["user_id"]),
|
||||
email=row["email"] or "",
|
||||
user_type=UserType(row["user_type"]) if row["user_type"] else UserType.CUSTOMER,
|
||||
roles=json.loads(row["roles"]) if row["roles"] else [],
|
||||
permissions=json.loads(row["permissions"]) if row["permissions"] else [],
|
||||
tenant_id=str(row["tenant_id"]) if row["tenant_id"] else None,
|
||||
ip_address=row["ip_address"],
|
||||
user_agent=row["user_agent"],
|
||||
created_at=row["created_at"],
|
||||
last_activity_at=row["last_activity_at"],
|
||||
)
|
||||
|
||||
# Re-cache in Valkey if it's back up
|
||||
await self._cache_in_valkey(session)
|
||||
|
||||
return session
|
||||
except Exception as e:
|
||||
logger.error(f"PostgreSQL session lookup failed: {e}")
|
||||
|
||||
return None
|
||||
|
||||
async def _update_last_activity(self, session_id: str):
|
||||
"""Update last activity timestamp."""
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
# Update Valkey TTL
|
||||
if self._valkey_available and self._valkey_client:
|
||||
try:
|
||||
key = self._get_valkey_key(session_id)
|
||||
# Refresh TTL
|
||||
await self._valkey_client.expire(key, self.session_ttl_seconds)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Update PostgreSQL
|
||||
if self._pg_pool:
|
||||
try:
|
||||
async with self._pg_pool.acquire() as conn:
|
||||
await conn.execute(
|
||||
"""
|
||||
UPDATE user_sessions
|
||||
SET last_activity_at = $1, expires_at = $2
|
||||
WHERE id = $3
|
||||
""",
|
||||
now,
|
||||
now + self.session_ttl,
|
||||
session_id,
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
async def _cache_in_valkey(self, session: Session):
|
||||
"""Re-cache session in Valkey after PostgreSQL fallback."""
|
||||
if self._valkey_available and self._valkey_client:
|
||||
try:
|
||||
key = self._get_valkey_key(session.session_id)
|
||||
await self._valkey_client.setex(
|
||||
key,
|
||||
self.session_ttl_seconds,
|
||||
json.dumps(session.to_dict()),
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
async def revoke_session(self, session_id: str) -> bool:
|
||||
"""
|
||||
Revoke a session (logout).
|
||||
|
||||
Removes from Valkey and marks as revoked in PostgreSQL.
|
||||
"""
|
||||
success = False
|
||||
|
||||
# Remove from Valkey
|
||||
if self._valkey_available and self._valkey_client:
|
||||
try:
|
||||
key = self._get_valkey_key(session_id)
|
||||
await self._valkey_client.delete(key)
|
||||
success = True
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to revoke session in Valkey: {e}")
|
||||
|
||||
# Mark as revoked in PostgreSQL
|
||||
if self._pg_pool:
|
||||
try:
|
||||
async with self._pg_pool.acquire() as conn:
|
||||
await conn.execute(
|
||||
"""
|
||||
UPDATE user_sessions
|
||||
SET revoked_at = NOW()
|
||||
WHERE id = $1
|
||||
""",
|
||||
session_id,
|
||||
)
|
||||
success = True
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to revoke session in PostgreSQL: {e}")
|
||||
|
||||
return success
|
||||
|
||||
async def revoke_all_user_sessions(self, user_id: str) -> int:
|
||||
"""
|
||||
Revoke all sessions for a user (force logout from all devices).
|
||||
|
||||
Returns the number of sessions revoked.
|
||||
"""
|
||||
count = 0
|
||||
|
||||
# Get all session IDs for user from PostgreSQL
|
||||
if self._pg_pool:
|
||||
try:
|
||||
async with self._pg_pool.acquire() as conn:
|
||||
rows = await conn.fetch(
|
||||
"""
|
||||
SELECT id FROM user_sessions
|
||||
WHERE user_id = $1
|
||||
AND revoked_at IS NULL
|
||||
AND expires_at > NOW()
|
||||
""",
|
||||
user_id,
|
||||
)
|
||||
|
||||
session_ids = [str(row["id"]) for row in rows]
|
||||
|
||||
# Revoke in PostgreSQL
|
||||
result = await conn.execute(
|
||||
"""
|
||||
UPDATE user_sessions
|
||||
SET revoked_at = NOW()
|
||||
WHERE user_id = $1
|
||||
AND revoked_at IS NULL
|
||||
""",
|
||||
user_id,
|
||||
)
|
||||
count = int(result.split()[-1]) if result else 0
|
||||
|
||||
# Remove from Valkey
|
||||
if self._valkey_available and self._valkey_client:
|
||||
for session_id in session_ids:
|
||||
try:
|
||||
key = self._get_valkey_key(session_id)
|
||||
await self._valkey_client.delete(key)
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to revoke all user sessions: {e}")
|
||||
|
||||
return count
|
||||
|
||||
async def get_active_sessions(self, user_id: str) -> List[Session]:
|
||||
"""Get all active sessions for a user."""
|
||||
sessions = []
|
||||
|
||||
if self._pg_pool:
|
||||
try:
|
||||
async with self._pg_pool.acquire() as conn:
|
||||
rows = await conn.fetch(
|
||||
"""
|
||||
SELECT id, user_id, email, user_type, roles, permissions,
|
||||
tenant_id, ip_address, user_agent, created_at, last_activity_at
|
||||
FROM user_sessions
|
||||
WHERE user_id = $1
|
||||
AND revoked_at IS NULL
|
||||
AND expires_at > NOW()
|
||||
ORDER BY last_activity_at DESC
|
||||
""",
|
||||
user_id,
|
||||
)
|
||||
|
||||
for row in rows:
|
||||
sessions.append(Session(
|
||||
session_id=str(row["id"]),
|
||||
user_id=str(row["user_id"]),
|
||||
email=row["email"] or "",
|
||||
user_type=UserType(row["user_type"]) if row["user_type"] else UserType.CUSTOMER,
|
||||
roles=json.loads(row["roles"]) if row["roles"] else [],
|
||||
permissions=json.loads(row["permissions"]) if row["permissions"] else [],
|
||||
tenant_id=str(row["tenant_id"]) if row["tenant_id"] else None,
|
||||
ip_address=row["ip_address"],
|
||||
user_agent=row["user_agent"],
|
||||
created_at=row["created_at"],
|
||||
last_activity_at=row["last_activity_at"],
|
||||
))
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get active sessions: {e}")
|
||||
|
||||
return sessions
|
||||
|
||||
async def cleanup_expired_sessions(self) -> int:
|
||||
"""
|
||||
Clean up expired sessions from PostgreSQL.
|
||||
|
||||
This is meant to be called by a background job.
|
||||
Returns the number of sessions cleaned up.
|
||||
"""
|
||||
count = 0
|
||||
|
||||
if self._pg_pool:
|
||||
try:
|
||||
async with self._pg_pool.acquire() as conn:
|
||||
result = await conn.execute(
|
||||
"""
|
||||
DELETE FROM user_sessions
|
||||
WHERE expires_at < NOW() - INTERVAL '7 days'
|
||||
"""
|
||||
)
|
||||
count = int(result.split()[-1]) if result else 0
|
||||
logger.info(f"Cleaned up {count} expired sessions")
|
||||
except Exception as e:
|
||||
logger.error(f"Session cleanup failed: {e}")
|
||||
|
||||
return count
|
||||
|
||||
|
||||
# Global session store instance
|
||||
_session_store: Optional[SessionStore] = None
|
||||
|
||||
|
||||
async def get_session_store() -> SessionStore:
|
||||
"""Get or create the global session store instance."""
|
||||
global _session_store
|
||||
|
||||
if _session_store is None:
|
||||
ttl_hours = int(os.environ.get("SESSION_TTL_HOURS", "24"))
|
||||
_session_store = SessionStore(session_ttl_hours=ttl_hours)
|
||||
await _session_store.connect()
|
||||
|
||||
return _session_store
|
||||
Reference in New Issue
Block a user