""" Rate Limiter Middleware Implements distributed rate limiting using Valkey (Redis-fork). Supports IP-based, user-based, and endpoint-specific rate limits. Features: - Sliding window rate limiting - IP-based limits for unauthenticated requests - User-based limits for authenticated requests - Stricter limits for auth endpoints (anti-brute-force) - IP whitelist/blacklist support - Graceful fallback when Valkey is unavailable Usage: from middleware import RateLimiterMiddleware app.add_middleware( RateLimiterMiddleware, valkey_url="redis://localhost:6379", ip_limit=100, user_limit=500, ) """ from __future__ import annotations import asyncio import hashlib import os import time from dataclasses import dataclass, field from typing import Dict, List, Optional, Set from starlette.middleware.base import BaseHTTPMiddleware from starlette.requests import Request from starlette.responses import JSONResponse, Response # Try to import redis (valkey-compatible) try: import redis.asyncio as redis REDIS_AVAILABLE = True except ImportError: REDIS_AVAILABLE = False redis = None @dataclass class RateLimitConfig: """Configuration for rate limiting.""" # Valkey/Redis connection valkey_url: str = "redis://localhost:6379" # Default limits (requests per minute) ip_limit: int = 100 user_limit: int = 500 # Stricter limits for auth endpoints auth_limit: int = 20 auth_endpoints: List[str] = field(default_factory=lambda: [ "/api/auth/login", "/api/auth/register", "/api/auth/password-reset", "/api/auth/forgot-password", ]) # Window size in seconds window_size: int = 60 # IP whitelist (never rate limited) ip_whitelist: Set[str] = field(default_factory=lambda: { "127.0.0.1", "::1", }) # IP blacklist (always blocked) ip_blacklist: Set[str] = field(default_factory=set) # Skip internal Docker network skip_internal_network: bool = True # Excluded paths excluded_paths: List[str] = field(default_factory=lambda: [ "/health", "/metrics", "/api/health", ]) # Fallback to in-memory when Valkey is unavailable fallback_enabled: bool = True # Key prefix for rate limit keys key_prefix: str = "ratelimit" class InMemoryRateLimiter: """Fallback in-memory rate limiter when Valkey is unavailable.""" def __init__(self): self._counts: Dict[str, List[float]] = {} self._lock = asyncio.Lock() async def check_rate_limit(self, key: str, limit: int, window: int) -> tuple[bool, int]: """ Check if rate limit is exceeded. Returns: Tuple of (is_allowed, remaining_requests) """ async with self._lock: now = time.time() window_start = now - window # Get or create entry if key not in self._counts: self._counts[key] = [] # Remove old entries self._counts[key] = [t for t in self._counts[key] if t > window_start] # Check limit current_count = len(self._counts[key]) if current_count >= limit: return False, 0 # Add new request self._counts[key].append(now) return True, limit - current_count - 1 async def cleanup(self): """Remove expired entries.""" async with self._lock: now = time.time() for key in list(self._counts.keys()): self._counts[key] = [t for t in self._counts[key] if t > now - 3600] if not self._counts[key]: del self._counts[key] class RateLimiterMiddleware(BaseHTTPMiddleware): """ Middleware that implements distributed rate limiting. Uses Valkey (Redis-fork) for distributed state, with fallback to in-memory rate limiting when Valkey is unavailable. """ def __init__( self, app, config: Optional[RateLimitConfig] = None, # Individual overrides valkey_url: Optional[str] = None, ip_limit: Optional[int] = None, user_limit: Optional[int] = None, auth_limit: Optional[int] = None, ): super().__init__(app) self.config = config or RateLimitConfig() # Apply overrides if valkey_url is not None: self.config.valkey_url = valkey_url if ip_limit is not None: self.config.ip_limit = ip_limit if user_limit is not None: self.config.user_limit = user_limit if auth_limit is not None: self.config.auth_limit = auth_limit # Auto-configure from environment self.config.valkey_url = os.getenv("VALKEY_URL", self.config.valkey_url) # Initialize Valkey client self._redis: Optional[redis.Redis] = None self._fallback = InMemoryRateLimiter() self._valkey_available = False async def _get_redis(self) -> Optional[redis.Redis]: """Get or create Redis/Valkey connection.""" if not REDIS_AVAILABLE: return None if self._redis is None: try: self._redis = redis.from_url( self.config.valkey_url, decode_responses=True, socket_timeout=1.0, socket_connect_timeout=1.0, ) await self._redis.ping() self._valkey_available = True except Exception: self._valkey_available = False self._redis = None return self._redis def _get_client_ip(self, request: Request) -> str: """Extract client IP from request.""" # Check X-Forwarded-For header xff = request.headers.get("X-Forwarded-For") if xff: return xff.split(",")[0].strip() # Check X-Real-IP header xri = request.headers.get("X-Real-IP") if xri: return xri # Fall back to direct client IP if request.client: return request.client.host return "unknown" def _get_user_id(self, request: Request) -> Optional[str]: """Extract user ID from request state (set by session middleware).""" if hasattr(request.state, "session") and request.state.session: return getattr(request.state.session, "user_id", None) return None def _is_internal_network(self, ip: str) -> bool: """Check if IP is from internal Docker network.""" return ( ip.startswith("172.") or ip.startswith("10.") or ip.startswith("192.168.") ) def _get_rate_limit(self, request: Request) -> int: """Determine the rate limit for this request.""" path = request.url.path # Auth endpoints get stricter limits for auth_path in self.config.auth_endpoints: if path.startswith(auth_path): return self.config.auth_limit # Authenticated users get higher limits if self._get_user_id(request): return self.config.user_limit # Default IP-based limit return self.config.ip_limit def _get_rate_limit_key(self, request: Request) -> str: """Generate the rate limit key for this request.""" # Use user ID if authenticated user_id = self._get_user_id(request) if user_id: identifier = f"user:{user_id}" else: ip = self._get_client_ip(request) # Hash IP for privacy ip_hash = hashlib.sha256(ip.encode()).hexdigest()[:16] identifier = f"ip:{ip_hash}" # Include path for endpoint-specific limits path = request.url.path for auth_path in self.config.auth_endpoints: if path.startswith(auth_path): return f"{self.config.key_prefix}:auth:{identifier}" return f"{self.config.key_prefix}:{identifier}" async def _check_rate_limit_valkey( self, key: str, limit: int, window: int ) -> tuple[bool, int]: """Check rate limit using Valkey.""" r = await self._get_redis() if not r: return await self._fallback.check_rate_limit(key, limit, window) try: # Use sliding window with sorted set now = time.time() window_start = now - window pipe = r.pipeline() # Remove old entries pipe.zremrangebyscore(key, "-inf", window_start) # Count current entries pipe.zcard(key) # Add new entry pipe.zadd(key, {str(now): now}) # Set expiry pipe.expire(key, window + 10) results = await pipe.execute() current_count = results[1] if current_count >= limit: return False, 0 return True, limit - current_count - 1 except Exception: # Fallback to in-memory self._valkey_available = False return await self._fallback.check_rate_limit(key, limit, window) async def dispatch(self, request: Request, call_next) -> Response: # Skip excluded paths if request.url.path in self.config.excluded_paths: return await call_next(request) # Get client IP ip = self._get_client_ip(request) # Check blacklist if ip in self.config.ip_blacklist: return JSONResponse( status_code=403, content={ "error": "ip_blocked", "message": "Your IP address has been blocked.", }, ) # Skip whitelist if ip in self.config.ip_whitelist: return await call_next(request) # Skip internal network if self.config.skip_internal_network and self._is_internal_network(ip): return await call_next(request) # Get rate limit parameters limit = self._get_rate_limit(request) key = self._get_rate_limit_key(request) window = self.config.window_size # Check rate limit allowed, remaining = await self._check_rate_limit_valkey(key, limit, window) if not allowed: return JSONResponse( status_code=429, content={ "error": "rate_limit_exceeded", "message": "Too many requests. Please try again later.", "retry_after": window, }, headers={ "Retry-After": str(window), "X-RateLimit-Limit": str(limit), "X-RateLimit-Remaining": "0", "X-RateLimit-Reset": str(int(time.time()) + window), }, ) # Process request response = await call_next(request) # Add rate limit headers response.headers["X-RateLimit-Limit"] = str(limit) response.headers["X-RateLimit-Remaining"] = str(remaining) response.headers["X-RateLimit-Reset"] = str(int(time.time()) + window) return response