Services: Admin-Compliance, Backend-Compliance, AI-Compliance-SDK, Consent-SDK, Developer-Portal, PCA-Platform, DSMS Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
364 lines
11 KiB
Python
364 lines
11 KiB
Python
"""
|
|
Rate Limiter Middleware
|
|
|
|
Implements distributed rate limiting using Valkey (Redis-fork).
|
|
Supports IP-based, user-based, and endpoint-specific rate limits.
|
|
|
|
Features:
|
|
- Sliding window rate limiting
|
|
- IP-based limits for unauthenticated requests
|
|
- User-based limits for authenticated requests
|
|
- Stricter limits for auth endpoints (anti-brute-force)
|
|
- IP whitelist/blacklist support
|
|
- Graceful fallback when Valkey is unavailable
|
|
|
|
Usage:
|
|
from middleware import RateLimiterMiddleware
|
|
|
|
app.add_middleware(
|
|
RateLimiterMiddleware,
|
|
valkey_url="redis://localhost:6379",
|
|
ip_limit=100,
|
|
user_limit=500,
|
|
)
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
import hashlib
import ipaddress
import os
import time
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Set
|
|
|
|
from starlette.middleware.base import BaseHTTPMiddleware
|
|
from starlette.requests import Request
|
|
from starlette.responses import JSONResponse, Response
|
|
|
|
# The redis client (wire-compatible with Valkey) is an optional dependency.
# When it is missing, `redis` is bound to None and REDIS_AVAILABLE is False so
# the rest of the module can test for it without raising at import time.
try:
    import redis.asyncio as redis
except ImportError:
    redis = None
    REDIS_AVAILABLE = False
else:
    REDIS_AVAILABLE = True
|
|
|
|
|
|
def _default_auth_endpoints() -> List[str]:
    """Return a fresh list of the endpoints that get the stricter auth limit."""
    return [
        "/api/auth/login",
        "/api/auth/register",
        "/api/auth/password-reset",
        "/api/auth/forgot-password",
    ]


def _default_ip_whitelist() -> Set[str]:
    """Return a fresh set holding the IPv4 and IPv6 loopback addresses."""
    return {"127.0.0.1", "::1"}


def _default_excluded_paths() -> List[str]:
    """Return a fresh list of the paths that bypass rate limiting."""
    return ["/health", "/metrics", "/api/health"]


@dataclass
class RateLimitConfig:
    """Settings consumed by the rate limiter middleware.

    Every field has a default, so ``RateLimitConfig()`` is a usable
    configuration. Mutable defaults are built per instance through
    ``default_factory`` and are therefore never shared between configs.
    """

    # Connection URL of the Valkey/Redis server.
    valkey_url: str = "redis://localhost:6379"

    # Requests allowed per window for unauthenticated (IP-keyed) clients
    # and for authenticated (user-keyed) clients respectively.
    ip_limit: int = 100
    user_limit: int = 500

    # Requests allowed per window on the auth endpoints, which are matched
    # by path prefix.
    auth_limit: int = 20
    auth_endpoints: List[str] = field(default_factory=_default_auth_endpoints)

    # Length of the rate limit window, in seconds.
    window_size: int = 60

    # Client IPs that are never rate limited.
    ip_whitelist: Set[str] = field(default_factory=_default_ip_whitelist)

    # Client IPs that are always rejected.
    ip_blacklist: Set[str] = field(default_factory=set)

    # When true, requests from the internal (Docker) network are not limited.
    skip_internal_network: bool = True

    # Exact request paths that are never rate limited.
    excluded_paths: List[str] = field(default_factory=_default_excluded_paths)

    # Fall back to the in-memory limiter when Valkey is unavailable.
    fallback_enabled: bool = True

    # Prefix prepended to every rate limit key.
    key_prefix: str = "ratelimit"
|
|
|
|
|
|
class InMemoryRateLimiter:
    """Fallback in-memory rate limiter when Valkey is unavailable.

    Keeps, per key, the timestamps of the requests that were allowed and
    counts how many of them fall inside the current window. State lives in
    this object only, so it is not shared between processes.
    """

    def __init__(self):
        # key -> timestamps (seconds since the epoch) of allowed requests.
        self._counts: Dict[str, List[float]] = {}
        # Serialises the read-modify-write of a bucket across coroutines.
        self._lock = asyncio.Lock()

    async def check_rate_limit(self, key: str, limit: int, window: int) -> tuple[bool, int]:
        """
        Check if rate limit is exceeded and, if not, record the request.

        Args:
            key: Identifier of the bucket to check.
            limit: Maximum number of requests allowed inside the window.
            window: Window length in seconds.

        Returns:
            Tuple of (is_allowed, remaining_requests). A rejected request is
            not recorded and reports 0 remaining requests.
        """
        async with self._lock:
            now = time.time()
            window_start = now - window

            # Keep only the timestamps that are still inside the window.
            recent = [t for t in self._counts.get(key, ()) if t > window_start]

            if len(recent) >= limit:
                # Never store an empty bucket (possible when limit <= 0),
                # otherwise each distinct key would leave an entry behind.
                if recent:
                    self._counts[key] = recent
                else:
                    self._counts.pop(key, None)
                return False, 0

            recent.append(now)
            self._counts[key] = recent
            return True, limit - len(recent)

    async def cleanup(self, max_age: float = 3600) -> None:
        """Remove expired entries.

        Args:
            max_age: Timestamps older than this many seconds are dropped.
                Pass a value at least as large as the biggest window in use
                so that live entries are not discarded. Defaults to one hour.
        """
        async with self._lock:
            cutoff = time.time() - max_age
            for key in list(self._counts):
                kept = [t for t in self._counts[key] if t > cutoff]
                if kept:
                    self._counts[key] = kept
                else:
                    del self._counts[key]
|
|
|
|
|
|
class RateLimiterMiddleware(BaseHTTPMiddleware):
    """
    Middleware that implements distributed rate limiting.

    Uses Valkey (Redis-fork) for distributed state, with fallback
    to in-memory rate limiting when Valkey is unavailable.

    Request flow (see ``dispatch``): excluded paths pass through, blacklisted
    IPs get a 403, whitelisted and internal-network IPs pass through, and
    every other request is counted against a sliding window and answered
    with a 429 once the limit for its key is reached.

    NOTE(review): ``RateLimitConfig.fallback_enabled`` is never read in this
    class; the in-memory fallback is always used when Valkey cannot be
    reached. Confirm the intended behaviour before wiring the flag in.
    """

    # RFC 1918 private ranges. Only 172.16.0.0/12 is private, so a plain
    # "172." prefix test would also exempt public addresses.
    _INTERNAL_NETWORKS = (
        ipaddress.ip_network("10.0.0.0/8"),
        ipaddress.ip_network("172.16.0.0/12"),
        ipaddress.ip_network("192.168.0.0/16"),
    )

    # Seconds to wait after a Valkey failure before probing it again, so an
    # outage does not add a connection timeout to every single request.
    _RECONNECT_COOLDOWN = 5.0

    # Monotonic time before which Valkey is not probed again. Defined on the
    # class as well so instances built without __init__ still have it.
    _retry_at = 0.0

    def __init__(
        self,
        app,
        config: Optional[RateLimitConfig] = None,
        # Individual overrides
        valkey_url: Optional[str] = None,
        ip_limit: Optional[int] = None,
        user_limit: Optional[int] = None,
        auth_limit: Optional[int] = None,
    ):
        """
        Args:
            app: The ASGI application to wrap.
            config: Base configuration; a default ``RateLimitConfig`` is
                created when omitted. The object is updated in place by the
                overrides below.
            valkey_url: Overrides ``config.valkey_url`` when given.
            ip_limit: Overrides ``config.ip_limit`` when given.
            user_limit: Overrides ``config.user_limit`` when given.
            auth_limit: Overrides ``config.auth_limit`` when given.
        """
        super().__init__(app)

        self.config = config or RateLimitConfig()

        # Apply overrides
        if valkey_url is not None:
            self.config.valkey_url = valkey_url
        if ip_limit is not None:
            self.config.ip_limit = ip_limit
        if user_limit is not None:
            self.config.user_limit = user_limit
        if auth_limit is not None:
            self.config.auth_limit = auth_limit

        # Auto-configure from environment. VALKEY_URL, when set, takes
        # precedence over both the config and the explicit argument.
        self.config.valkey_url = os.getenv("VALKEY_URL", self.config.valkey_url)

        # The Valkey client is created lazily on the first request.
        self._redis: Optional[redis.Redis] = None
        self._fallback = InMemoryRateLimiter()
        self._valkey_available = False
        self._retry_at = 0.0

    def _mark_valkey_down(self) -> None:
        """Record a Valkey failure and start the reconnect cooldown."""
        self._valkey_available = False
        self._retry_at = time.monotonic() + self._RECONNECT_COOLDOWN

    async def _get_redis(self) -> "Optional[redis.Redis]":
        """Get or create Redis/Valkey connection.

        Returns:
            A client that answered a ping, or None when the redis package is
            missing, the server cannot be reached, or a recent failure is
            still inside the reconnect cooldown.
        """
        if not REDIS_AVAILABLE:
            return None

        if self._redis is not None and self._valkey_available:
            return self._redis

        # A recent attempt failed; do not pay the socket timeout again yet.
        if time.monotonic() < self._retry_at:
            return None

        client = self._redis
        try:
            if client is None:
                client = redis.from_url(
                    self.config.valkey_url,
                    decode_responses=True,
                    socket_timeout=1.0,
                    socket_connect_timeout=1.0,
                )
            await client.ping()
        except Exception:
            # Best effort: any client/connection error means "use fallback".
            self._mark_valkey_down()
            return None

        self._redis = client
        self._valkey_available = True
        return client

    def _get_client_ip(self, request: Request) -> str:
        """Extract client IP from request.

        Order of precedence: first entry of ``X-Forwarded-For``, then
        ``X-Real-IP``, then the peer address of the connection, and finally
        the literal string "unknown".

        SECURITY NOTE(review): both headers are taken from the request as-is.
        Unless a trusted proxy in front of this service overwrites them, a
        client can send e.g. ``X-Forwarded-For: 127.0.0.1`` and be treated as
        whitelisted. Confirm the deployment guarantees this.
        """
        # Check X-Forwarded-For header
        forwarded = request.headers.get("X-Forwarded-For")
        if forwarded:
            first_hop = forwarded.split(",")[0].strip()
            if first_hop:
                return first_hop

        # Check X-Real-IP header
        real_ip = request.headers.get("X-Real-IP")
        if real_ip:
            real_ip = real_ip.strip()
            if real_ip:
                return real_ip

        # Fall back to direct client IP
        if request.client:
            return request.client.host
        return "unknown"

    def _get_user_id(self, request: Request) -> Optional[str]:
        """Extract user ID from request state (set by session middleware)."""
        session = getattr(request.state, "session", None)
        if session:
            return getattr(session, "user_id", None)
        return None

    def _is_internal_network(self, ip: str) -> bool:
        """Check if IP is from internal Docker network.

        True only for addresses inside 10.0.0.0/8, 172.16.0.0/12 or
        192.168.0.0/16 (also when given as an IPv4-mapped IPv6 address).
        Anything that does not parse as an IP address is not internal.
        """
        try:
            address = ipaddress.ip_address(ip)
        except ValueError:
            return False

        # Treat ::ffff:a.b.c.d like the IPv4 address it wraps.
        mapped = getattr(address, "ipv4_mapped", None)
        if mapped is not None:
            address = mapped

        return any(address in network for network in self._INTERNAL_NETWORKS)

    def _is_auth_endpoint(self, path: str) -> bool:
        """Return True when *path* starts with one of the auth endpoints."""
        return path.startswith(tuple(self.config.auth_endpoints))

    def _get_rate_limit(self, request: Request) -> int:
        """Determine the rate limit for this request."""
        # Auth endpoints get stricter limits
        if self._is_auth_endpoint(request.url.path):
            return self.config.auth_limit

        # Authenticated users get higher limits
        if self._get_user_id(request):
            return self.config.user_limit

        # Default IP-based limit
        return self.config.ip_limit

    def _get_rate_limit_key(self, request: Request) -> str:
        """Generate the rate limit key for this request.

        The key is ``<prefix>:<identifier>``, or ``<prefix>:auth:<identifier>``
        for auth endpoints so they are counted in their own bucket.
        """
        # Use user ID if authenticated
        user_id = self._get_user_id(request)
        if user_id:
            identifier = f"user:{user_id}"
        else:
            ip = self._get_client_ip(request)
            # Hash IP for privacy
            ip_hash = hashlib.sha256(ip.encode()).hexdigest()[:16]
            identifier = f"ip:{ip_hash}"

        # Include path for endpoint-specific limits
        if self._is_auth_endpoint(request.url.path):
            return f"{self.config.key_prefix}:auth:{identifier}"

        return f"{self.config.key_prefix}:{identifier}"

    async def _check_rate_limit_valkey(
        self, key: str, limit: int, window: int
    ) -> tuple[bool, int]:
        """Check rate limit using Valkey.

        Falls back to the in-memory limiter when no client is available or
        when the Valkey commands fail.

        Returns:
            Tuple of (is_allowed, remaining_requests).
        """
        r = await self._get_redis()
        if not r:
            return await self._fallback.check_rate_limit(key, limit, window)

        try:
            # Use sliding window with sorted set
            now = time.time()
            window_start = now - window

            # Sorted-set members must be unique: with the bare timestamp as
            # member, two requests at the same instant would collapse into
            # one entry and be counted once.
            member = f"{now}:{os.urandom(8).hex()}"

            pipe = r.pipeline()
            # Remove old entries
            pipe.zremrangebyscore(key, "-inf", window_start)
            # Count current entries
            pipe.zcard(key)
            # Add new entry. NOTE(review): this runs before the limit check,
            # so rejected requests also count toward the window here, unlike
            # in the in-memory fallback.
            pipe.zadd(key, {member: now})
            # Set expiry
            pipe.expire(key, window + 10)

            results = await pipe.execute()
            # Count as it was before this request was added.
            current_count = results[1]

            if current_count >= limit:
                return False, 0

            return True, limit - current_count - 1

        except Exception:
            # Fallback to in-memory
            self._mark_valkey_down()
            return await self._fallback.check_rate_limit(key, limit, window)

    async def dispatch(self, request: Request, call_next) -> Response:
        """Apply blacklist, exemptions and the rate limit to one request."""
        # Skip excluded paths
        if request.url.path in self.config.excluded_paths:
            return await call_next(request)

        # Get client IP
        ip = self._get_client_ip(request)

        # Check blacklist
        if ip in self.config.ip_blacklist:
            return JSONResponse(
                status_code=403,
                content={
                    "error": "ip_blocked",
                    "message": "Your IP address has been blocked.",
                },
            )

        # Skip whitelist
        if ip in self.config.ip_whitelist:
            return await call_next(request)

        # Skip internal network
        if self.config.skip_internal_network and self._is_internal_network(ip):
            return await call_next(request)

        # Get rate limit parameters
        limit = self._get_rate_limit(request)
        key = self._get_rate_limit_key(request)
        window = self.config.window_size

        # Check rate limit
        allowed, remaining = await self._check_rate_limit_valkey(key, limit, window)

        if not allowed:
            return JSONResponse(
                status_code=429,
                content={
                    "error": "rate_limit_exceeded",
                    "message": "Too many requests. Please try again later.",
                    "retry_after": window,
                },
                headers={
                    "Retry-After": str(window),
                    "X-RateLimit-Limit": str(limit),
                    "X-RateLimit-Remaining": "0",
                    "X-RateLimit-Reset": str(int(time.time()) + window),
                },
            )

        # Process request
        response = await call_next(request)

        # Add rate limit headers
        response.headers["X-RateLimit-Limit"] = str(limit)
        response.headers["X-RateLimit-Remaining"] = str(remaining)
        response.headers["X-RateLimit-Reset"] = str(int(time.time()) + window)

        return response
|