Initial commit: breakpilot-core - Shared Infrastructure
Docker Compose with 24+ services: - PostgreSQL (PostGIS), Valkey, MinIO, Qdrant - Vault (PKI/TLS), Nginx (Reverse Proxy) - Backend Core API, Consent Service, Billing Service - RAG Service, Embedding Service - Gitea, Woodpecker CI/CD - Night Scheduler, Health Aggregator - Jitsi (Web/XMPP/JVB/Jicofo), Mailpit Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
363
backend-core/middleware/rate_limiter.py
Normal file
363
backend-core/middleware/rate_limiter.py
Normal file
@@ -0,0 +1,363 @@
|
||||
"""
|
||||
Rate Limiter Middleware
|
||||
|
||||
Implements distributed rate limiting using Valkey (Redis-fork).
|
||||
Supports IP-based, user-based, and endpoint-specific rate limits.
|
||||
|
||||
Features:
|
||||
- Sliding window rate limiting
|
||||
- IP-based limits for unauthenticated requests
|
||||
- User-based limits for authenticated requests
|
||||
- Stricter limits for auth endpoints (anti-brute-force)
|
||||
- IP whitelist/blacklist support
|
||||
- Graceful fallback when Valkey is unavailable
|
||||
|
||||
Usage:
|
||||
from middleware import RateLimiterMiddleware
|
||||
|
||||
app.add_middleware(
|
||||
RateLimiterMiddleware,
|
||||
valkey_url="redis://localhost:6379",
|
||||
ip_limit=100,
|
||||
user_limit=500,
|
||||
)
|
||||
"""
|
||||
|
||||
from __future__ import annotations

import asyncio
import hashlib
import ipaddress
import os
import time
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Set

from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import JSONResponse, Response
|
||||
|
||||
# Optional dependency: the "redis" client library also speaks the Valkey
# protocol. When it is missing, the middleware degrades to the in-memory
# fallback limiter.
try:
    import redis.asyncio as redis

    REDIS_AVAILABLE = True
except ImportError:
    redis = None
    REDIS_AVAILABLE = False
|
||||
|
||||
|
||||
@dataclass
class RateLimitConfig:
    """Tunable settings for the rate limiter middleware.

    All limits are expressed as requests per ``window_size`` seconds
    (one minute by default).
    """

    # Connection URL for the Valkey/Redis instance holding shared counters.
    valkey_url: str = "redis://localhost:6379"

    # Per-window ceiling for unauthenticated clients (keyed by IP).
    ip_limit: int = 100
    # Per-window ceiling for authenticated clients (keyed by user ID).
    user_limit: int = 500

    # Tighter ceiling applied to the auth endpoints below (anti-brute-force).
    auth_limit: int = 20
    auth_endpoints: List[str] = field(
        default_factory=lambda: [
            "/api/auth/login",
            "/api/auth/register",
            "/api/auth/password-reset",
            "/api/auth/forgot-password",
        ]
    )

    # Length of the sliding window, in seconds.
    window_size: int = 60

    # Addresses that bypass rate limiting entirely.
    ip_whitelist: Set[str] = field(default_factory=lambda: {"127.0.0.1", "::1"})

    # Addresses rejected outright (HTTP 403).
    ip_blacklist: Set[str] = field(default_factory=set)

    # When True, requests from the internal Docker network are not limited.
    skip_internal_network: bool = True

    # Paths (exact match) that are never rate limited.
    excluded_paths: List[str] = field(
        default_factory=lambda: ["/health", "/metrics", "/api/health"]
    )

    # Use the in-process limiter when Valkey cannot be reached.
    fallback_enabled: bool = True

    # Namespace prefix for keys written to Valkey.
    key_prefix: str = "ratelimit"
|
||||
|
||||
|
||||
class InMemoryRateLimiter:
    """Process-local fallback limiter used when Valkey cannot be reached.

    Keeps a list of request timestamps per key and applies a sliding
    window, mirroring the semantics of the Valkey-backed limiter. State
    lives only in this process, so limits are per-worker rather than
    distributed.
    """

    def __init__(self):
        # key -> timestamps (time.time() seconds) of requests in the window
        self._counts: Dict[str, List[float]] = {}
        # Serializes all access to _counts across coroutines.
        self._lock = asyncio.Lock()

    async def check_rate_limit(self, key: str, limit: int, window: int) -> tuple[bool, int]:
        """Record one request for *key* and report whether it is allowed.

        Returns:
            Tuple of (is_allowed, remaining_requests).
        """
        async with self._lock:
            now = time.time()
            cutoff = now - window

            # Keep only timestamps still inside the sliding window.
            recent = [ts for ts in self._counts.get(key, []) if ts > cutoff]
            self._counts[key] = recent

            if len(recent) >= limit:
                return False, 0

            recent.append(now)
            return True, limit - len(recent)

    async def cleanup(self):
        """Evict timestamps older than an hour and drop empty keys."""
        async with self._lock:
            cutoff = time.time() - 3600
            stale = []
            for key, timestamps in self._counts.items():
                kept = [ts for ts in timestamps if ts > cutoff]
                if kept:
                    self._counts[key] = kept
                else:
                    stale.append(key)
            for key in stale:
                del self._counts[key]
|
||||
|
||||
|
||||
class RateLimiterMiddleware(BaseHTTPMiddleware):
    """
    Middleware that implements distributed rate limiting.

    Uses Valkey (Redis-fork) for distributed state, with fallback
    to in-memory (per-process) rate limiting when Valkey is unavailable.

    Per request the limit is resolved as follows: configured auth
    endpoints get the strict ``auth_limit``; authenticated requests get
    ``user_limit`` (keyed by user ID); everything else gets ``ip_limit``
    (keyed by a hashed client IP). ``X-RateLimit-*`` headers are attached
    to every response that passed through the limiter.
    """

    # RFC 1918 private address blocks used to recognise traffic from the
    # internal Docker network. (Bug fix: the previous prefix test treated
    # all of 172.0.0.0/8 as internal, which includes public addresses.)
    _PRIVATE_NETWORKS = (
        ipaddress.ip_network("10.0.0.0/8"),
        ipaddress.ip_network("172.16.0.0/12"),
        ipaddress.ip_network("192.168.0.0/16"),
    )

    def __init__(
        self,
        app,
        config: Optional[RateLimitConfig] = None,
        # Individual overrides
        valkey_url: Optional[str] = None,
        ip_limit: Optional[int] = None,
        user_limit: Optional[int] = None,
        auth_limit: Optional[int] = None,
    ):
        super().__init__(app)

        self.config = config or RateLimitConfig()

        # Apply keyword overrides on top of the supplied config object.
        if valkey_url is not None:
            self.config.valkey_url = valkey_url
        if ip_limit is not None:
            self.config.ip_limit = ip_limit
        if user_limit is not None:
            self.config.user_limit = user_limit
        if auth_limit is not None:
            self.config.auth_limit = auth_limit

        # NOTE: the VALKEY_URL environment variable takes precedence over
        # both the config object and the explicit keyword override, since
        # it is applied last (matches the original behavior — deployments
        # may depend on it).
        self.config.valkey_url = os.getenv("VALKEY_URL", self.config.valkey_url)

        # The Valkey client is created lazily on first use; until then
        # (and whenever Valkey is down) the in-memory limiter takes over.
        # NOTE(review): config.fallback_enabled is currently never
        # consulted — the fallback is always active. Confirm intent.
        self._redis: Optional[redis.Redis] = None
        self._fallback = InMemoryRateLimiter()
        self._valkey_available = False

    async def _get_redis(self) -> Optional[redis.Redis]:
        """Lazily create and cache the Valkey connection.

        Returns None when the redis package is missing or the server is
        unreachable; the next call retries the connection.
        """
        if not REDIS_AVAILABLE:
            return None

        if self._redis is None:
            try:
                client = redis.from_url(
                    self.config.valkey_url,
                    decode_responses=True,
                    socket_timeout=1.0,
                    socket_connect_timeout=1.0,
                )
                # Ping up front so a dead server is detected immediately
                # rather than on the first pipelined command.
                await client.ping()
                self._redis = client
                self._valkey_available = True
            except Exception:
                # Leave _redis unset so the next request retries; callers
                # fall back to the in-memory limiter meanwhile.
                self._valkey_available = False
                self._redis = None

        return self._redis

    def _get_client_ip(self, request: Request) -> str:
        """Best-effort extraction of the originating client IP.

        NOTE(review): X-Forwarded-For and X-Real-IP are client-supplied
        and spoofable unless the reverse proxy (Nginx) overwrites them —
        confirm the proxy config strips inbound values, otherwise clients
        can evade IP-based limits.
        """
        # Check X-Forwarded-For header; first hop is the original client.
        xff = request.headers.get("X-Forwarded-For")
        if xff:
            return xff.split(",")[0].strip()

        # Check X-Real-IP header
        xri = request.headers.get("X-Real-IP")
        if xri:
            return xri

        # Fall back to the direct connection's peer address.
        if request.client:
            return request.client.host
        return "unknown"

    def _get_user_id(self, request: Request) -> Optional[str]:
        """Return the authenticated user's ID, if session middleware set one."""
        if hasattr(request.state, "session") and request.state.session:
            return getattr(request.state.session, "user_id", None)
        return None

    def _is_internal_network(self, ip: str) -> bool:
        """Return True if *ip* lies in an RFC 1918 private range.

        Fix: the previous implementation used ``ip.startswith("172.")``,
        which matched all of 172.0.0.0/8 — including public internet
        addresses — and let such clients bypass rate limiting entirely.
        """
        try:
            addr = ipaddress.ip_address(ip)
        except ValueError:
            # Unparseable (e.g. "unknown") -> treat as external.
            return False
        return any(addr in net for net in self._PRIVATE_NETWORKS)

    def _get_rate_limit(self, request: Request) -> int:
        """Resolve the applicable requests-per-window limit for a request."""
        path = request.url.path

        # Auth endpoints get stricter limits (anti-brute-force).
        for auth_path in self.config.auth_endpoints:
            if path.startswith(auth_path):
                return self.config.auth_limit

        # Authenticated users get higher limits.
        if self._get_user_id(request):
            return self.config.user_limit

        # Default IP-based limit for anonymous traffic.
        return self.config.ip_limit

    def _get_rate_limit_key(self, request: Request) -> str:
        """Build the storage key identifying the caller (and auth bucket)."""
        # Prefer the user identity when authenticated.
        user_id = self._get_user_id(request)
        if user_id:
            identifier = f"user:{user_id}"
        else:
            ip = self._get_client_ip(request)
            # Store a truncated SHA-256 of the IP instead of the raw
            # address (privacy / data minimisation).
            ip_hash = hashlib.sha256(ip.encode()).hexdigest()[:16]
            identifier = f"ip:{ip_hash}"

        # Auth endpoints get a separate bucket so their strict limit is
        # tracked independently of general traffic.
        path = request.url.path
        for auth_path in self.config.auth_endpoints:
            if path.startswith(auth_path):
                return f"{self.config.key_prefix}:auth:{identifier}"

        return f"{self.config.key_prefix}:{identifier}"

    async def _check_rate_limit_valkey(
        self, key: str, limit: int, window: int
    ) -> tuple[bool, int]:
        """Sliding-window check against Valkey.

        Returns:
            Tuple of (is_allowed, remaining_requests). Degrades to the
            in-memory limiter when Valkey is unreachable or errors out.
        """
        r = await self._get_redis()
        if not r:
            return await self._fallback.check_rate_limit(key, limit, window)

        try:
            # Sorted set of arrival timestamps; score == member == time.
            now = time.time()
            window_start = now - window

            pipe = r.pipeline()
            # Drop entries that fell out of the window.
            pipe.zremrangebyscore(key, "-inf", window_start)
            # Count what is still inside the window (before this request).
            pipe.zcard(key)
            # Record this request (added even when the request ends up
            # blocked, so sustained abuse keeps the window full).
            pipe.zadd(key, {str(now): now})
            # Expire idle keys shortly after the window closes.
            pipe.expire(key, window + 10)

            results = await pipe.execute()
            current_count = results[1]  # zcard result

            if current_count >= limit:
                return False, 0

            return True, limit - current_count - 1

        except Exception:
            # Any Valkey failure degrades to per-process limiting rather
            # than failing the request (graceful degradation).
            self._valkey_available = False
            return await self._fallback.check_rate_limit(key, limit, window)

    async def dispatch(self, request: Request, call_next) -> Response:
        """Apply blacklist/whitelist checks, then enforce the rate limit."""
        # Health/metrics style endpoints are never limited.
        if request.url.path in self.config.excluded_paths:
            return await call_next(request)

        ip = self._get_client_ip(request)

        # Blacklisted IPs are rejected outright.
        if ip in self.config.ip_blacklist:
            return JSONResponse(
                status_code=403,
                content={
                    "error": "ip_blocked",
                    "message": "Your IP address has been blocked.",
                },
            )

        # Whitelisted IPs bypass limiting.
        if ip in self.config.ip_whitelist:
            return await call_next(request)

        # Internal (Docker network) traffic bypasses limiting.
        if self.config.skip_internal_network and self._is_internal_network(ip):
            return await call_next(request)

        # Resolve the limit, key, and window for this request.
        limit = self._get_rate_limit(request)
        key = self._get_rate_limit_key(request)
        window = self.config.window_size

        allowed, remaining = await self._check_rate_limit_valkey(key, limit, window)

        if not allowed:
            return JSONResponse(
                status_code=429,
                content={
                    "error": "rate_limit_exceeded",
                    "message": "Too many requests. Please try again later.",
                    "retry_after": window,
                },
                headers={
                    "Retry-After": str(window),
                    "X-RateLimit-Limit": str(limit),
                    "X-RateLimit-Remaining": "0",
                    "X-RateLimit-Reset": str(int(time.time()) + window),
                },
            )

        # Process the request, then advertise the remaining budget.
        response = await call_next(request)

        response.headers["X-RateLimit-Limit"] = str(limit)
        response.headers["X-RateLimit-Remaining"] = str(remaining)
        response.headers["X-RateLimit-Reset"] = str(int(time.time()) + window)

        return response
|
||||
Reference in New Issue
Block a user