Files
Benjamin Boenisch ad111d5e69 Initial commit: breakpilot-core - Shared Infrastructure
Docker Compose with 24+ services:
- PostgreSQL (PostGIS), Valkey, MinIO, Qdrant
- Vault (PKI/TLS), Nginx (Reverse Proxy)
- Backend Core API, Consent Service, Billing Service
- RAG Service, Embedding Service
- Gitea, Woodpecker CI/CD
- Night Scheduler, Health Aggregator
- Jitsi (Web/XMPP/JVB/Jicofo), Mailpit

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-11 23:47:13 +01:00

364 lines
11 KiB
Python

"""
Rate Limiter Middleware
Implements distributed rate limiting using Valkey (Redis-fork).
Supports IP-based, user-based, and endpoint-specific rate limits.
Features:
- Sliding window rate limiting
- IP-based limits for unauthenticated requests
- User-based limits for authenticated requests
- Stricter limits for auth endpoints (anti-brute-force)
- IP whitelist/blacklist support
- Graceful fallback when Valkey is unavailable
Usage:
    from middleware import RateLimiterMiddleware

    app.add_middleware(
        RateLimiterMiddleware,
        valkey_url="redis://localhost:6379",
        ip_limit=100,
        user_limit=500,
    )
"""
from __future__ import annotations

import asyncio
import hashlib
import ipaddress
import logging
import os
import time
import uuid
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Set

from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import JSONResponse, Response

# Try to import redis (valkey-compatible)
try:
    import redis.asyncio as redis
    REDIS_AVAILABLE = True
except ImportError:
    REDIS_AVAILABLE = False
    redis = None
@dataclass
class RateLimitConfig:
    """Tunable settings for the rate limiter.

    Every ``*_limit`` value is a request budget per ``window_size`` seconds;
    with the default 60-second window these are requests per minute.
    """

    # Where the shared counters live (Valkey or any Redis-compatible server).
    valkey_url: str = "redis://localhost:6379"

    # Budget for anonymous traffic, bucketed by client IP ...
    ip_limit: int = 100
    # ... and the larger budget for requests with an authenticated user.
    user_limit: int = 500

    # Tighter budget for credential-handling routes (brute-force protection);
    # a request path matches when it starts with one of these prefixes.
    auth_limit: int = 20
    auth_endpoints: List[str] = field(
        default_factory=lambda: list(
            (
                "/api/auth/login",
                "/api/auth/register",
                "/api/auth/password-reset",
                "/api/auth/forgot-password",
            )
        )
    )

    # Length of the sliding window, in seconds.
    window_size: int = 60

    # Addresses that are never limited.
    ip_whitelist: Set[str] = field(default_factory=lambda: {"::1", "127.0.0.1"})
    # Addresses that are always refused.
    ip_blacklist: Set[str] = field(default_factory=set)

    # Let traffic from private (internal Docker) networks bypass limiting.
    skip_internal_network: bool = True

    # Paths that bypass rate limiting entirely (exact match).
    excluded_paths: List[str] = field(
        default_factory=lambda: list(("/health", "/metrics", "/api/health"))
    )

    # Intended switch for in-memory limiting when Valkey is unreachable.
    fallback_enabled: bool = True

    # Namespace prepended to every rate-limit key stored in Valkey.
    key_prefix: str = "ratelimit"
class InMemoryRateLimiter:
"""Fallback in-memory rate limiter when Valkey is unavailable."""
def __init__(self):
self._counts: Dict[str, List[float]] = {}
self._lock = asyncio.Lock()
async def check_rate_limit(self, key: str, limit: int, window: int) -> tuple[bool, int]:
"""
Check if rate limit is exceeded.
Returns:
Tuple of (is_allowed, remaining_requests)
"""
async with self._lock:
now = time.time()
window_start = now - window
# Get or create entry
if key not in self._counts:
self._counts[key] = []
# Remove old entries
self._counts[key] = [t for t in self._counts[key] if t > window_start]
# Check limit
current_count = len(self._counts[key])
if current_count >= limit:
return False, 0
# Add new request
self._counts[key].append(now)
return True, limit - current_count - 1
async def cleanup(self):
"""Remove expired entries."""
async with self._lock:
now = time.time()
for key in list(self._counts.keys()):
self._counts[key] = [t for t in self._counts[key] if t > now - 3600]
if not self._counts[key]:
del self._counts[key]
class RateLimiterMiddleware(BaseHTTPMiddleware):
    """
    Middleware that implements distributed rate limiting.

    Uses Valkey (Redis-fork) for distributed state. When Valkey cannot be
    reached, requests are checked by a process-local InMemoryRateLimiter
    instead, or allowed through if ``config.fallback_enabled`` is False.

    Client identity: X-Forwarded-For / X-Real-IP are honoured only when the
    direct peer is loopback or a private (RFC 1918) address, i.e. a reverse
    proxy. Otherwise any client could choose its own rate-limit bucket, or
    claim a whitelisted/internal address, simply by sending those headers.
    NOTE(review): proxies on public addresses (e.g. a CDN in front of the
    reverse proxy) are not trusted; extend _TRUSTED_PROXY_NETWORKS if needed.
    """

    # RFC 1918 private IPv4 ranges, treated as the internal (Docker) network.
    _INTERNAL_NETWORKS = tuple(
        ipaddress.ip_network(cidr)
        for cidr in ("10.0.0.0/8", "172.16.0.0/12", "192.168.0.0/16")
    )
    # Direct peers whose forwarding headers are believed: internal + loopback.
    _TRUSTED_PROXY_NETWORKS = _INTERNAL_NETWORKS + (
        ipaddress.ip_network("127.0.0.0/8"),
        ipaddress.ip_network("::1/128"),
    )
    # After a Valkey failure, wait this many seconds before reconnecting, so an
    # outage does not add a connect timeout to every single request.
    _RECONNECT_BACKOFF = 5.0

    _logger = logging.getLogger(__name__)

    def __init__(
        self,
        app,
        config: Optional[RateLimitConfig] = None,
        # Individual overrides
        valkey_url: Optional[str] = None,
        ip_limit: Optional[int] = None,
        user_limit: Optional[int] = None,
        auth_limit: Optional[int] = None,
    ):
        super().__init__(app)
        # NOTE: a caller-supplied config object is modified in place below.
        self.config = config or RateLimitConfig()
        # Apply overrides
        if valkey_url is not None:
            self.config.valkey_url = valkey_url
        if ip_limit is not None:
            self.config.ip_limit = ip_limit
        if user_limit is not None:
            self.config.user_limit = user_limit
        if auth_limit is not None:
            self.config.auth_limit = auth_limit
        # VALKEY_URL, when set, wins over both the config and valkey_url.
        self.config.valkey_url = os.getenv("VALKEY_URL", self.config.valkey_url)
        # Valkey client, created lazily by _get_redis().
        self._redis: Optional[redis.Redis] = None
        self._fallback = InMemoryRateLimiter()
        self._valkey_available = False
        # time.monotonic() value before which no reconnect is attempted.
        self._reconnect_at = 0.0

    async def _get_redis(self) -> Optional[redis.Redis]:
        """Return the Valkey client, connecting lazily; None if unusable.

        After a failure, a new connection is attempted at most once per
        ``_RECONNECT_BACKOFF`` seconds; in between, None is returned at once.
        """
        if not REDIS_AVAILABLE:
            return None
        if self._redis is not None:
            return self._redis
        if time.monotonic() < self._reconnect_at:
            return None
        try:
            # Published before the ping so concurrent requests share one client.
            self._redis = redis.from_url(
                self.config.valkey_url,
                decode_responses=True,
                socket_timeout=1.0,
                socket_connect_timeout=1.0,
            )
            await self._redis.ping()
        except Exception as exc:
            # The URL is not logged: it may embed credentials.
            self._logger.warning(
                "Valkey connection failed (%s); using fallback for %.0fs",
                exc,
                self._RECONNECT_BACKOFF,
            )
            await self._discard_redis()
            return None
        self._valkey_available = True
        return self._redis

    async def _discard_redis(self) -> None:
        """Drop the current Valkey client, close it, and start the backoff."""
        client, self._redis = self._redis, None
        self._valkey_available = False
        self._reconnect_at = time.monotonic() + self._RECONNECT_BACKOFF
        if client is None:
            return
        # redis-py >= 5.0.1 names it aclose(); older asyncio clients use close().
        close = getattr(client, "aclose", None) or getattr(client, "close", None)
        if close is None:
            return
        try:
            await close()
        except Exception:
            # Best effort: the client is already unusable and being dropped.
            pass

    @staticmethod
    def _ip_in_networks(ip: str, networks) -> bool:
        """Return True if *ip* parses as an address inside any of *networks*."""
        try:
            addr = ipaddress.ip_address(ip)
        except ValueError:
            # Not an IP literal (e.g. "unknown" or a host:port string).
            return False
        # Membership across IPv4/IPv6 versions is simply False.
        return any(addr in net for net in networks)

    def _is_trusted_proxy(self, ip: str) -> bool:
        """Check if *ip* may supply forwarding headers (internal or loopback)."""
        return self._ip_in_networks(ip, self._TRUSTED_PROXY_NETWORKS)

    def _get_client_ip(self, request: Request) -> str:
        """Return the address used to identify the client of *request*.

        Forwarding headers are consulted only when the direct peer is a
        trusted proxy. Proxies conventionally append the address they received
        the request from to X-Forwarded-For, so entries are read right to left
        and the first one that is not a trusted proxy is taken as the client;
        the left-most entries are whatever the client chose to send.
        Returns "unknown" when no peer address is available.
        """
        peer = request.client.host if request.client else None
        if peer is not None and self._is_trusted_proxy(peer):
            xff = request.headers.get("X-Forwarded-For")
            if xff:
                hops = [hop.strip() for hop in xff.split(",") if hop.strip()]
                for hop in reversed(hops):
                    if not self._is_trusted_proxy(hop):
                        return hop
                if hops:
                    # Every hop is internal: the request originated inside.
                    return hops[0]
            xri = request.headers.get("X-Real-IP", "").strip()
            if xri:
                return xri
        return peer if peer is not None else "unknown"

    def _get_user_id(self, request: Request) -> Optional[str]:
        """Extract user ID from request state.

        Reads ``request.state.session.user_id``; presumably the session is
        set by an earlier session middleware — confirm middleware ordering.
        """
        if hasattr(request.state, "session") and request.state.session:
            return getattr(request.state.session, "user_id", None)
        return None

    def _is_internal_network(self, ip: str) -> bool:
        """Check if *ip* is in a private RFC 1918 IPv4 range (internal network).

        Only 172.16.0.0/12 is private; the rest of 172.0.0.0/8 is public.
        """
        return self._ip_in_networks(ip, self._INTERNAL_NETWORKS)

    def _is_auth_endpoint(self, path: str) -> bool:
        """Return True if *path* starts with one of the configured auth paths."""
        return path.startswith(tuple(self.config.auth_endpoints))

    def _get_rate_limit(self, request: Request) -> int:
        """Determine the rate limit for this request."""
        # Auth endpoints get stricter limits
        if self._is_auth_endpoint(request.url.path):
            return self.config.auth_limit
        # Authenticated users get higher limits
        if self._get_user_id(request):
            return self.config.user_limit
        # Default IP-based limit
        return self.config.ip_limit

    def _get_rate_limit_key(self, request: Request) -> str:
        """Generate the rate limit key for this request.

        Authenticated requests are keyed by user ID; anonymous ones by a
        truncated SHA-256 of the client IP. Auth endpoints use a separate
        ``auth`` sub-namespace so they have their own counter.
        """
        user_id = self._get_user_id(request)
        if user_id:
            identifier = f"user:{user_id}"
        else:
            ip = self._get_client_ip(request)
            # Hash IP for privacy
            ip_hash = hashlib.sha256(ip.encode()).hexdigest()[:16]
            identifier = f"ip:{ip_hash}"
        if self._is_auth_endpoint(request.url.path):
            return f"{self.config.key_prefix}:auth:{identifier}"
        return f"{self.config.key_prefix}:{identifier}"

    async def _check_rate_limit_fallback(
        self, key: str, limit: int, window: int
    ) -> tuple[bool, int]:
        """Rate-limit without Valkey: in memory, or fail open if disabled."""
        if not self.config.fallback_enabled:
            # No shared state and local limiting switched off: allow.
            return True, limit
        return await self._fallback.check_rate_limit(key, limit, window)

    async def _check_rate_limit_valkey(
        self, key: str, limit: int, window: int
    ) -> tuple[bool, int]:
        """Check rate limit using a Valkey sorted-set sliding window.

        Returns (is_allowed, remaining_requests). Rejected requests are not
        left in the window, matching InMemoryRateLimiter.
        """
        r = await self._get_redis()
        if r is None:
            return await self._check_rate_limit_fallback(key, limit, window)

        now = time.time()
        # Unique member: str(now) alone collides when two requests share a
        # timestamp, which would make them count as one.
        member = f"{now}:{uuid.uuid4().hex}"
        try:
            pipe = r.pipeline()
            # Remove entries that have slid out of the window
            pipe.zremrangebyscore(key, "-inf", now - window)
            # Count entries before this request
            pipe.zcard(key)
            # Record this request
            pipe.zadd(key, {member: now})
            # Let idle keys expire on their own
            pipe.expire(key, window + 10)
            results = await pipe.execute()
        except Exception as exc:
            self._logger.warning("Valkey rate-limit check failed (%s); using fallback", exc)
            await self._discard_redis()
            return await self._check_rate_limit_fallback(key, limit, window)

        current_count = results[1]
        if current_count < limit:
            return True, limit - current_count - 1

        # Over the limit: undo the speculative add so floods of rejected
        # requests neither extend the block nor grow the set without bound.
        try:
            await r.zrem(key, member)
        except Exception as exc:
            # The stray member still disappears with the key's expiry.
            self._logger.warning("Valkey zrem failed (%s)", exc)
        return False, 0

    async def dispatch(self, request: Request, call_next) -> Response:
        """Apply blacklist, whitelist, internal-network and rate-limit checks."""
        # Skip excluded paths
        if request.url.path in self.config.excluded_paths:
            return await call_next(request)
        ip = self._get_client_ip(request)
        # Check blacklist
        if ip in self.config.ip_blacklist:
            return JSONResponse(
                status_code=403,
                content={
                    "error": "ip_blocked",
                    "message": "Your IP address has been blocked.",
                },
            )
        # Skip whitelist
        if ip in self.config.ip_whitelist:
            return await call_next(request)
        # Skip internal network
        if self.config.skip_internal_network and self._is_internal_network(ip):
            return await call_next(request)
        limit = self._get_rate_limit(request)
        key = self._get_rate_limit_key(request)
        window = self.config.window_size
        allowed, remaining = await self._check_rate_limit_valkey(key, limit, window)
        if not allowed:
            return JSONResponse(
                status_code=429,
                content={
                    "error": "rate_limit_exceeded",
                    "message": "Too many requests. Please try again later.",
                    "retry_after": window,
                },
                headers={
                    "Retry-After": str(window),
                    "X-RateLimit-Limit": str(limit),
                    "X-RateLimit-Remaining": "0",
                    "X-RateLimit-Reset": str(int(time.time()) + window),
                },
            )
        response = await call_next(request)
        # Add rate limit headers
        response.headers["X-RateLimit-Limit"] = str(limit)
        response.headers["X-RateLimit-Remaining"] = str(remaining)
        response.headers["X-RateLimit-Reset"] = str(int(time.time()) + window)
        return response