"""
SDK Protection Middleware.

Protects SDK endpoints against systematic enumeration by competitors or
automated crawlers. Detects anomalies such as sequential (alphabetically
sorted) access, high category diversity and bursts of requests against the
same category.

Features:
- Multi-window quota (minute, hour, day, month)
- Anomaly score with progressive throttling
- Diversity tracking (max distinct categories per hour)
- Burst detection (same category within a short time span)
- Sequential-enumeration detection (alphabetically sorted access)
- Unusual-hours detection (00:00-05:00 UTC by default)
- Multi-tenant detection (more than 3 tenants per hour by default)
- HMAC-based watermarking (X-BP-Trace response header)
- Graceful fallback to in-memory tracking when Valkey is unavailable

Usage:
    from middleware import SDKProtectionMiddleware
    app.add_middleware(SDKProtectionMiddleware)
"""

from __future__ import annotations

import asyncio
import contextlib
import datetime
import hashlib
import hmac
import logging
import os
import random
import time
from collections import deque
from dataclasses import dataclass, field
from typing import Any, Deque, Dict, List, Optional, Set, Tuple

from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import JSONResponse, Response

try:
    import redis.asyncio as redis

    REDIS_AVAILABLE = True
except ImportError:
    REDIS_AVAILABLE = False
    redis = None

logger = logging.getLogger("sdk_protection")

# Secret used for watermarks when neither the config nor the environment
# provides one. It is visible in this file, so watermarks signed with it can
# be forged; the middleware logs a warning when it is in effect.
_DEFAULT_WATERMARK_SECRET = "bp-sdk-watermark-default"

# Minimum number of recent categories kept per user for sequential detection.
_DEFAULT_SEQUENCE_LENGTH = 10


# ==============================================
# Configuration
# ==============================================


@dataclass
class QuotaTier:
    """Rate limits for a specific tier."""

    name: str
    per_minute: int
    per_hour: int
    per_day: int
    per_month: int


DEFAULT_TIERS: Dict[str, QuotaTier] = {
    "free": QuotaTier("free", 30, 500, 3_000, 50_000),
    "standard": QuotaTier("standard", 60, 1_500, 10_000, 200_000),
    "enterprise": QuotaTier("enterprise", 120, 5_000, 50_000, 1_000_000),
}


def _quota_windows(tier: QuotaTier) -> Dict[str, Tuple[int, int]]:
    """Return ``{window_name: (window_seconds, limit)}`` for *tier*.

    Shared by the Valkey and in-memory backends so both enforce the same
    windows in the same order. "month" is a fixed 30-day (2,592,000 s) window.
    """
    return {
        "minute": (60, tier.per_minute),
        "hour": (3600, tier.per_hour),
        "day": (86400, tier.per_day),
        "month": (2592000, tier.per_month),
    }


@dataclass
class SDKProtectionConfig:
    """Configuration for SDK protection middleware."""

    # Valkey connection
    valkey_url: str = "redis://localhost:6379"

    # Quota tiers
    tiers: Dict[str, QuotaTier] = field(default_factory=lambda: dict(DEFAULT_TIERS))
    default_tier: str = "free"

    # Diversity threshold: max unique categories per hour
    diversity_threshold: int = 40

    # Burst threshold: max requests to the same category within burst_window
    burst_threshold: int = 15
    burst_window: int = 120  # seconds

    # Sequential detection. sequential_window is the inactivity period after
    # which the Valkey-stored sequence expires (the in-memory fallback keeps
    # it indefinitely).
    sequential_window: int = 300  # seconds
    sequential_min_entries: int = 10
    sequential_sorted_ratio: float = 0.7  # 70% sorted = suspicious

    # Throttle levels based on anomaly score
    throttle_level_1_score: int = 30  # delay 1-3s
    throttle_level_2_score: int = 60  # delay 5-10s, reduced detail
    throttle_level_3_score: int = 85  # block (429)

    # Score decay: multiply by factor every interval (interval <= 0 disables decay)
    score_decay_factor: float = 0.95
    score_decay_interval: int = 300  # 5 minutes

    # Score increments for anomalies
    score_diversity_increment: int = 15
    score_burst_increment: int = 20
    score_sequential_increment: int = 25
    score_unusual_hours_increment: int = 10
    score_multi_tenant_increment: int = 15

    # Multi-tenant threshold
    multi_tenant_threshold: int = 3

    # Unusual hours (UTC, end exclusive; start > end wraps past midnight)
    unusual_hours_start: int = 0
    unusual_hours_end: int = 5

    # Watermark secret for HMAC
    watermark_secret: str = ""

    # Protected path prefixes
    protected_paths: List[str] = field(default_factory=lambda: [
        "/api/sdk/",
        "/api/compliance/",
        "/api/v1/tom/",
        "/api/v1/dsfa/",
        "/api/v1/vvt/",
        "/api/v1/controls/",
        "/api/v1/assessment/",
        "/api/v1/eh/",
        "/api/v1/namespace/",
    ])

    # Excluded paths (never protected)
    excluded_paths: List[str] = field(default_factory=lambda: [
        "/health",
        "/metrics",
        "/api/health",
    ])

    # Fallback to in-memory when Valkey unavailable.
    # NOTE(review): this flag is not consulted anywhere in this module; the
    # in-memory fallback is always used. Confirm the intended behavior for False.
    fallback_enabled: bool = True

    # Key prefix
    key_prefix: str = "sdk_protect"

    # Minimum number of *distinct* categories in the tracked sequence before it
    # can count as sequential enumeration. A run of identical categories is
    # trivially "sorted" and must not be flagged. (Appended at the end so
    # positional construction of this dataclass keeps working.)
    sequential_min_distinct: int = 5

    # Seconds to wait before reconnecting to Valkey after a failure; requests
    # use the in-memory fallback in the meantime.
    valkey_retry_interval: float = 30.0


# ==============================================
# Category Map
# ==============================================

CATEGORY_MAP: Dict[str, str] = {
    # TOM categories
    "/api/v1/tom/access-control": "tom_access_control",
    "/api/v1/tom/encryption": "tom_encryption",
    "/api/v1/tom/pseudonymization": "tom_pseudonymization",
    "/api/v1/tom/integrity": "tom_integrity",
    "/api/v1/tom/availability": "tom_availability",
    "/api/v1/tom/resilience": "tom_resilience",
    "/api/v1/tom/recoverability": "tom_recoverability",
    "/api/v1/tom/testing": "tom_testing",
    "/api/v1/tom/data-separation": "tom_data_separation",
    "/api/v1/tom/input-control": "tom_input_control",
    "/api/v1/tom/transport-control": "tom_transport_control",
    "/api/v1/tom/output-control": "tom_output_control",
    "/api/v1/tom/order-control": "tom_order_control",
    # DSFA categories
    "/api/v1/dsfa/threshold": "dsfa_threshold",
    "/api/v1/dsfa/necessity": "dsfa_necessity",
    "/api/v1/dsfa/risks": "dsfa_risks",
    "/api/v1/dsfa/measures": "dsfa_measures",
    "/api/v1/dsfa/residual": "dsfa_residual",
    "/api/v1/dsfa/consultation": "dsfa_consultation",
    # VVT categories
    "/api/v1/vvt/processing": "vvt_processing",
    "/api/v1/vvt/purposes": "vvt_purposes",
    "/api/v1/vvt/categories": "vvt_categories",
    "/api/v1/vvt/recipients": "vvt_recipients",
    "/api/v1/vvt/transfers": "vvt_transfers",
    "/api/v1/vvt/retention": "vvt_retention",
    "/api/v1/vvt/security": "vvt_security",
    # Controls
    "/api/v1/controls/": "controls_general",
    # Assessment
    "/api/v1/assessment/": "assessment_general",
    # SDK
    "/api/sdk/": "sdk_general",
    # Compliance
    "/api/compliance/": "compliance_general",
    # EH
    "/api/v1/eh/": "eh_general",
    # Namespace
    "/api/v1/namespace/": "namespace_general",
}


def _extract_category(path: str) -> str:
    """Extract category from request path using longest-prefix match.

    Returns "unknown" when no prefix in CATEGORY_MAP matches.
    """
    best_match = ""
    best_category = "unknown"
    for prefix, category in CATEGORY_MAP.items():
        if path.startswith(prefix) and len(prefix) > len(best_match):
            best_match = prefix
            best_category = category
    return best_category


# ==============================================
# In-Memory Fallback
# ==============================================


class InMemorySDKProtection:
    """Process-local fallback tracking used when Valkey is unavailable.

    All state lives in plain dicts guarded by one asyncio.Lock, so it is only
    shared between requests handled by the same process.
    """

    def __init__(self):
        # "<user>:<window>" -> request timestamps, oldest first
        self._quotas: Dict[str, Deque[float]] = {}
        # user -> (hour key, distinct categories seen during that hour)
        self._diversity: Dict[str, Tuple[str, Set[str]]] = {}
        # "<user>:<category>" -> request timestamps inside the burst window
        self._bursts: Dict[str, Deque[float]] = {}
        # user -> most recent categories, oldest first
        self._sequences: Dict[str, List[str]] = {}
        # user -> {"score": float, "last_decay": float}
        self._scores: Dict[str, Dict[str, Any]] = {}
        # user -> (hour key, distinct tenant ids seen during that hour)
        self._tenants: Dict[str, Tuple[str, Set[str]]] = {}
        self._lock = asyncio.Lock()

    @staticmethod
    def _prune(timestamps: Deque[float], cutoff: float) -> None:
        """Drop timestamps <= *cutoff* from the front of *timestamps*.

        Timestamps are appended in arrival order, so expired entries sit at the
        front; this keeps pruning amortized O(1) instead of re-filtering the
        whole (potentially month-long) history on every request. If the wall
        clock steps backwards a few stale entries may linger slightly longer.
        """
        while timestamps and timestamps[0] <= cutoff:
            timestamps.popleft()

    @staticmethod
    def _hour_bucket(
        store: Dict[str, Tuple[str, Set[str]]], user: str, hour_key: str
    ) -> Set[str]:
        """Return *user*'s set for *hour_key*, discarding any older hour's set.

        Only the current hour is ever queried, so keeping one bucket per user
        avoids accumulating a new set per user per hour forever.
        """
        entry = store.get(user)
        if entry is None or entry[0] != hour_key:
            entry = (hour_key, set())
            store[user] = entry
        return entry[1]

    async def check_quota(self, user: str, tier: QuotaTier) -> Tuple[bool, Dict[str, int]]:
        """Check multi-window quotas and record the request if allowed.

        Returns (allowed, remaining_per_window). Every window is evaluated, so
        ``remaining`` always has all four keys. When allowed, ``remaining`` is
        the quota left *after* this request.
        """
        async with self._lock:
            now = time.time()
            buckets: Dict[str, Deque[float]] = {}
            remaining: Dict[str, int] = {}
            allowed = True

            for window_name, (window_size, limit) in _quota_windows(tier).items():
                bucket = self._quotas.setdefault(f"{user}:{window_name}", deque())
                self._prune(bucket, now - window_size)
                buckets[window_name] = bucket
                count = len(bucket)
                remaining[window_name] = max(0, limit - count)
                if count >= limit:
                    allowed = False

            if allowed:
                # Record the request in all windows; every window had room.
                for window_name, bucket in buckets.items():
                    bucket.append(now)
                    remaining[window_name] -= 1

            return allowed, remaining

    async def track_diversity(self, user: str, category: str, hour_key: str) -> int:
        """Track category diversity. Returns current unique count for the hour."""
        async with self._lock:
            categories = self._hour_bucket(self._diversity, user, hour_key)
            categories.add(category)
            return len(categories)

    async def track_burst(self, user: str, category: str, window: int) -> int:
        """Track burst access to same category. Returns count in window (incl. this one)."""
        async with self._lock:
            now = time.time()
            bucket = self._bursts.setdefault(f"{user}:{category}", deque())
            self._prune(bucket, now - window)
            bucket.append(now)
            return len(bucket)

    async def track_sequence(self, user: str, category: str, max_len: int = 10) -> List[str]:
        """Append *category* to the user's sequence. Returns the last *max_len* categories."""
        async with self._lock:
            sequence = self._sequences.setdefault(user, [])
            sequence.append(category)
            if max_len > 0 and len(sequence) > max_len:
                del sequence[:-max_len]
            return list(sequence)

    async def get_score(self, user: str) -> Tuple[float, float]:
        """Get anomaly score and last decay timestamp (initialized to 0 / now)."""
        async with self._lock:
            if user not in self._scores:
                self._scores[user] = {"score": 0.0, "last_decay": time.time()}
            entry = self._scores[user]
            return entry["score"], entry["last_decay"]

    async def set_score(self, user: str, score: float, last_decay: float):
        """Set anomaly score."""
        async with self._lock:
            self._scores[user] = {"score": score, "last_decay": last_decay}

    async def track_tenant(self, user: str, tenant_id: str, hour_key: str) -> int:
        """Track tenant access. Returns unique tenant count for the hour."""
        async with self._lock:
            tenants = self._hour_bucket(self._tenants, user, hour_key)
            tenants.add(tenant_id)
            return len(tenants)


# ==============================================
# SDK Protection Middleware
# ==============================================


class SDKProtectionMiddleware(BaseHTTPMiddleware):
    """
    Middleware to protect SDK endpoints from systematic enumeration.

    Tracks anomaly scores and applies progressive throttling based on
    detected patterns like sequential access, burst requests, and high
    category diversity.
    """

    def __init__(
        self,
        app,
        config: Optional[SDKProtectionConfig] = None,
    ):
        super().__init__(app)
        self.config = config or SDKProtectionConfig()

        # Auto-configure from environment (environment wins over config values)
        self.config.valkey_url = os.getenv("VALKEY_URL", self.config.valkey_url)
        self.config.watermark_secret = os.getenv(
            "SDK_WATERMARK_SECRET",
            self.config.watermark_secret or _DEFAULT_WATERMARK_SECRET,
        )
        if self.config.watermark_secret == _DEFAULT_WATERMARK_SECRET:
            logger.warning(
                "SDK_WATERMARK_SECRET is not set; X-BP-Trace watermarks use a "
                "built-in default secret and can be forged"
            )

        self._redis: Optional[redis.Redis] = None
        self._fallback = InMemorySDKProtection()
        self._valkey_available = False
        # time.monotonic() value before which no reconnect is attempted
        self._retry_at = 0.0

    # ------------------------------------------
    # Valkey connection management
    # ------------------------------------------

    async def _get_redis(self) -> Optional[redis.Redis]:
        """Get or create the Redis/Valkey connection.

        Returns None when the client library is missing, when connecting fails,
        or while a reconnect cooldown is active, so callers use the fallback
        without paying a connect timeout on every call.
        """
        if not REDIS_AVAILABLE:
            return None
        if self._redis is not None:
            return self._redis
        if time.monotonic() < self._retry_at:
            return None

        client = None
        try:
            client = redis.from_url(
                self.config.valkey_url,
                decode_responses=True,
                socket_timeout=1.0,
                socket_connect_timeout=1.0,
            )
            await client.ping()
        except Exception as exc:
            if client is not None:
                await self._close_client(client)
            self._mark_valkey_unavailable(exc)
            return None

        self._redis = client
        self._valkey_available = True
        return client

    @staticmethod
    async def _close_client(client: Any) -> None:
        """Best-effort close of a client that is being abandoned.

        Prefers ``aclose()`` and falls back to ``close()`` for older redis-py
        releases. Errors are suppressed because the client is already unusable.
        """
        closer = getattr(client, "aclose", None) or getattr(client, "close", None)
        if closer is None:
            return
        with contextlib.suppress(Exception):
            await closer()

    def _mark_valkey_unavailable(self, exc: BaseException) -> None:
        """Switch to the in-memory fallback and schedule the next reconnect attempt."""
        retry_in = max(0.0, float(self.config.valkey_retry_interval))
        logger.warning(
            "Valkey unavailable for SDK protection (%s); using in-memory fallback, "
            "retrying in %.0fs",
            exc,
            retry_in,
        )
        self._valkey_available = False
        self._retry_at = time.monotonic() + retry_in

    async def _handle_valkey_error(self, exc: Exception) -> None:
        """Drop the current client after a failed command and enter cooldown."""
        client, self._redis = self._redis, None
        if client is not None:
            await self._close_client(client)
        self._mark_valkey_unavailable(exc)

    # ------------------------------------------
    # Request inspection helpers
    # ------------------------------------------

    def _is_protected_path(self, path: str) -> bool:
        """Check if the path is protected by SDK middleware."""
        if path in self.config.excluded_paths:
            return False
        return any(path.startswith(p) for p in self.config.protected_paths)

    def _get_user_id(self, request: Request) -> Optional[str]:
        """Extract a user ID from the session, else from a hashed X-API-Key.

        The ID is always returned as ``str`` (it is sliced for logging). A
        session without a user_id falls through to the API key so the request
        is still tracked.
        """
        session = getattr(request.state, "session", None)
        if session:
            session_user = getattr(session, "user_id", None)
            if session_user is not None and str(session_user):
                return str(session_user)

        api_key = request.headers.get("X-API-Key")
        if api_key:
            return f"apikey:{hashlib.sha256(api_key.encode()).hexdigest()[:16]}"
        return None

    def _get_user_tier(self, request: Request) -> str:
        """Determine user's tier from request.

        Precedence: ``request.state.tier`` (set by upstream middleware), then
        the X-SDK-Tier header if it names a configured tier, then the default.
        """
        if hasattr(request.state, "tier"):
            return request.state.tier
        # NOTE(review): X-SDK-Tier is client-supplied; unless a trusted proxy
        # strips or sets it, any caller can claim "enterprise" quotas. Confirm.
        tier_header = request.headers.get("X-SDK-Tier")
        if tier_header and tier_header in self.config.tiers:
            return tier_header
        return self.config.default_tier

    def _get_tenant_id(self, request: Request) -> Optional[str]:
        """Extract tenant ID from request."""
        return request.headers.get("X-Tenant-ID")

    def _generate_watermark(self, user_id: str, timestamp: float) -> str:
        """Generate HMAC-SHA256 watermark (first 32 hex chars) for response tracing."""
        message = f"{user_id}:{timestamp:.0f}"
        return hmac.new(
            self.config.watermark_secret.encode(),
            message.encode(),
            hashlib.sha256,
        ).hexdigest()[:32]

    def _get_hour_key(self) -> str:
        """Get current hour key (hours since the Unix epoch) for grouping."""
        return str(int(time.time() / 3600))

    def _sequence_length(self) -> int:
        """Number of recent categories to keep for sequential detection.

        At least sequential_min_entries, otherwise detection could never fire.
        """
        return max(_DEFAULT_SEQUENCE_LENGTH, self.config.sequential_min_entries)

    # ------------------------------------------
    # Valkey-based tracking
    # ------------------------------------------

    async def _check_quota_valkey(
        self, user: str, tier: QuotaTier
    ) -> Tuple[bool, Dict[str, int]]:
        """Check multi-window quotas using Valkey sorted sets.

        Same contract as InMemorySDKProtection.check_quota. The check and the
        record are two separate pipelines, so concurrent requests can overshoot
        a limit slightly.
        """
        r = await self._get_redis()
        if not r:
            return await self._fallback.check_quota(user, tier)

        try:
            now = time.time()
            windows = _quota_windows(tier)
            keys = {
                window_name: f"{self.config.key_prefix}:quota:{user}:{window_name}"
                for window_name in windows
            }

            pipe = r.pipeline()
            for window_name, (window_size, _) in windows.items():
                pipe.zremrangebyscore(keys[window_name], "-inf", now - window_size)
                pipe.zcard(keys[window_name])
            results = await pipe.execute()
            counts = results[1::2]  # every second result is a zcard

            remaining: Dict[str, int] = {}
            allowed = True
            for (window_name, (_, limit)), count in zip(windows.items(), counts):
                remaining[window_name] = max(0, limit - count)
                if count >= limit:
                    allowed = False

            if allowed:
                # Random suffix keeps members unique for same-timestamp requests.
                member = f"{now}:{random.randint(0, 999999)}"
                pipe = r.pipeline()
                for window_name, (window_size, _) in windows.items():
                    pipe.zadd(keys[window_name], {member: now})
                    pipe.expire(keys[window_name], window_size + 10)
                await pipe.execute()
                # Report quota left after this request.
                remaining = {name: left - 1 for name, left in remaining.items()}

            return allowed, remaining

        except Exception as exc:
            await self._handle_valkey_error(exc)
            return await self._fallback.check_quota(user, tier)

    async def _track_diversity_valkey(
        self, user: str, category: str
    ) -> int:
        """Track category diversity using a per-hour Valkey set."""
        r = await self._get_redis()
        hour_key = self._get_hour_key()
        if not r:
            return await self._fallback.track_diversity(user, category, hour_key)

        try:
            key = f"{self.config.key_prefix}:diversity:{user}:{hour_key}"
            pipe = r.pipeline()
            pipe.sadd(key, category)
            pipe.scard(key)
            pipe.expire(key, 3660)
            results = await pipe.execute()
            return results[1]
        except Exception as exc:
            await self._handle_valkey_error(exc)
            return await self._fallback.track_diversity(user, category, hour_key)

    async def _track_burst_valkey(
        self, user: str, category: str
    ) -> int:
        """Track burst access using a Valkey sorted set; returns count in window."""
        r = await self._get_redis()
        if not r:
            return await self._fallback.track_burst(
                user, category, self.config.burst_window
            )

        try:
            now = time.time()
            key = f"{self.config.key_prefix}:burst:{user}:{category}"
            # Unique member: str(now) alone collides for same-timestamp requests
            # and would under-count the burst.
            member = f"{now}:{random.randint(0, 999999)}"
            pipe = r.pipeline()
            pipe.zremrangebyscore(key, "-inf", now - self.config.burst_window)
            pipe.zadd(key, {member: now})
            pipe.zcard(key)
            pipe.expire(key, self.config.burst_window + 10)
            results = await pipe.execute()
            return results[2]
        except Exception as exc:
            await self._handle_valkey_error(exc)
            return await self._fallback.track_burst(
                user, category, self.config.burst_window
            )

    async def _track_sequence_valkey(
        self, user: str, category: str
    ) -> List[str]:
        """Track the recent category sequence using a capped Valkey list."""
        r = await self._get_redis()
        length = self._sequence_length()
        if not r:
            return await self._fallback.track_sequence(user, category, length)

        try:
            key = f"{self.config.key_prefix}:seq:{user}"
            pipe = r.pipeline()
            pipe.rpush(key, category)
            pipe.ltrim(key, -length, -1)  # keep the last `length` entries
            pipe.lrange(key, 0, -1)
            pipe.expire(key, self.config.sequential_window + 10)
            results = await pipe.execute()
            return results[2]
        except Exception as exc:
            await self._handle_valkey_error(exc)
            return await self._fallback.track_sequence(user, category, length)

    async def _get_score_valkey(self, user: str) -> Tuple[float, float]:
        """Get (score, last_decay) from a Valkey hash; (0.0, now) if absent."""
        r = await self._get_redis()
        if not r:
            return await self._fallback.get_score(user)

        try:
            key = f"{self.config.key_prefix}:score:{user}"
            data = await r.hgetall(key)
            if not data:
                return 0.0, time.time()
            return float(data.get("score", 0)), float(data.get("last_decay", time.time()))
        except Exception as exc:
            await self._handle_valkey_error(exc)
            return await self._fallback.get_score(user)

    async def _set_score_valkey(self, user: str, score: float, last_decay: float):
        """Persist the anomaly score in a Valkey hash (expires after 24h idle)."""
        r = await self._get_redis()
        if not r:
            await self._fallback.set_score(user, score, last_decay)
            return

        try:
            key = f"{self.config.key_prefix}:score:{user}"
            pipe = r.pipeline()
            pipe.hset(key, mapping={"score": str(score), "last_decay": str(last_decay)})
            pipe.expire(key, 86400)
            await pipe.execute()
        except Exception as exc:
            await self._handle_valkey_error(exc)
            await self._fallback.set_score(user, score, last_decay)

    async def _track_tenant_valkey(
        self, user: str, tenant_id: str
    ) -> int:
        """Track tenant access using a per-hour Valkey set."""
        r = await self._get_redis()
        hour_key = self._get_hour_key()
        if not r:
            return await self._fallback.track_tenant(user, tenant_id, hour_key)

        try:
            key = f"{self.config.key_prefix}:tenants:{user}:{hour_key}"
            pipe = r.pipeline()
            pipe.sadd(key, tenant_id)
            pipe.scard(key)
            pipe.expire(key, 3660)
            results = await pipe.execute()
            return results[1]
        except Exception as exc:
            await self._handle_valkey_error(exc)
            return await self._fallback.track_tenant(user, tenant_id, hour_key)

    # ------------------------------------------
    # Anomaly Detection
    # ------------------------------------------

    def _check_sequential(self, sequence: List[str]) -> bool:
        """Check if the recent category sequence looks like sorted enumeration.

        Requires at least sequential_min_entries entries and
        sequential_min_distinct distinct categories. Repeated access to one
        category (e.g. every /api/sdk/ path maps to "sdk_general") is trivially
        sorted and must not be flagged. Then flags when the fraction of positions
        that already match the sorted order reaches sequential_sorted_ratio.
        """
        if not sequence or len(sequence) < self.config.sequential_min_entries:
            return False
        if len(set(sequence)) < self.config.sequential_min_distinct:
            return False

        in_place = sum(a == b for a, b in zip(sequence, sorted(sequence)))
        return in_place / len(sequence) >= self.config.sequential_sorted_ratio

    def _is_unusual_hour(self, now: Optional[float] = None) -> bool:
        """Check if the UTC hour of *now* (default: current time) is unusual.

        The window is [unusual_hours_start, unusual_hours_end); a start greater
        than the end wraps past midnight (e.g. 22 -> 5).
        """
        timestamp = time.time() if now is None else now
        hour = datetime.datetime.fromtimestamp(timestamp, tz=datetime.timezone.utc).hour
        start = self.config.unusual_hours_start
        end = self.config.unusual_hours_end
        if start <= end:
            return start <= hour < end
        return hour >= start or hour < end

    def _apply_decay(
        self, score: float, last_decay: float, now: float
    ) -> Tuple[float, float]:
        """Apply time-based score decay.

        Multiplies the score by score_decay_factor once per full
        score_decay_interval elapsed since last_decay. last_decay advances by
        whole intervals only, so partial intervals carry over to the next
        request instead of being lost. A non-positive interval disables decay.
        """
        interval = self.config.score_decay_interval
        if interval <= 0:
            return max(0.0, score), last_decay

        intervals = int((now - last_decay) // interval)
        if intervals > 0:
            score *= self.config.score_decay_factor ** intervals
            last_decay += intervals * interval
        return max(0.0, score), last_decay

    def _get_throttle_level(self, score: float) -> int:
        """Determine throttle level (0-3) from anomaly score."""
        if score >= self.config.throttle_level_3_score:
            return 3
        if score >= self.config.throttle_level_2_score:
            return 2
        if score >= self.config.throttle_level_1_score:
            return 1
        return 0

    async def _score_request(
        self, request: Request, user_id: str, category: str, now: float
    ) -> Tuple[float, List[str]]:
        """Run all anomaly rules for this request and persist the new score.

        Returns (updated score, names of the rules that fired).
        """
        score, last_decay = await self._get_score_valkey(user_id)
        score, last_decay = self._apply_decay(score, last_decay, now)
        triggered_rules: List[str] = []

        # --- Diversity Tracking ---
        diversity_count = await self._track_diversity_valkey(user_id, category)
        if diversity_count > self.config.diversity_threshold:
            score += self.config.score_diversity_increment
            triggered_rules.append("high_diversity")

        # --- Burst Detection ---
        burst_count = await self._track_burst_valkey(user_id, category)
        if burst_count > self.config.burst_threshold:
            score += self.config.score_burst_increment
            triggered_rules.append("burst_detected")

        # --- Sequential Enumeration Detection ---
        sequence = await self._track_sequence_valkey(user_id, category)
        if self._check_sequential(sequence):
            score += self.config.score_sequential_increment
            triggered_rules.append("sequential_enumeration")

        # --- Unusual Hours ---
        # NOTE(review): this adds to the score on *every* request in the
        # window, so ordinary traffic from e.g. Asia-Pacific time zones reaches
        # the block threshold after a handful of requests. Confirm intent.
        if self._is_unusual_hour(now):
            score += self.config.score_unusual_hours_increment
            triggered_rules.append("unusual_hours")

        # --- Multi-Tenant Detection ---
        tenant_id = self._get_tenant_id(request)
        if tenant_id:
            tenant_count = await self._track_tenant_valkey(user_id, tenant_id)
            if tenant_count > self.config.multi_tenant_threshold:
                score += self.config.score_multi_tenant_increment
                triggered_rules.append("multi_tenant")

        # Persist updated score
        await self._set_score_valkey(user_id, score, last_decay)
        return score, triggered_rules

    @staticmethod
    def _quota_exceeded_response(
        tier_name: str, tier: QuotaTier, remaining: Dict[str, int]
    ) -> JSONResponse:
        """Build the 429 response returned when a quota window is exhausted."""
        return JSONResponse(
            status_code=429,
            content={
                "error": "sdk_quota_exceeded",
                "message": "SDK request quota exceeded. Please try again later.",
                "tier": tier_name,
                "limits": {
                    "per_minute": tier.per_minute,
                    "per_hour": tier.per_hour,
                    "per_day": tier.per_day,
                    "per_month": tier.per_month,
                },
            },
            headers={
                "Retry-After": "60",
                "X-SDK-Quota-Remaining-Minute": str(remaining.get("minute", 0)),
                "X-SDK-Quota-Remaining-Hour": str(remaining.get("hour", 0)),
            },
        )

    @staticmethod
    def _blocked_response(throttle_level: int) -> JSONResponse:
        """Build the 429 response returned at throttle level 3."""
        return JSONResponse(
            status_code=429,
            content={
                "error": "sdk_protection_triggered",
                "message": "Anomalous access pattern detected. Request blocked.",
            },
            headers={
                "Retry-After": "300",
                "X-SDK-Throttle-Level": str(throttle_level),
            },
        )

    # ------------------------------------------
    # Main dispatch
    # ------------------------------------------

    async def dispatch(self, request: Request, call_next) -> Response:
        """Apply quota, anomaly scoring and throttling to protected SDK paths.

        Unprotected paths and requests without an identifiable user pass
        through untouched. Protected requests may be rejected (429), delayed,
        or forwarded with quota, throttle and watermark headers added.
        """
        path = request.url.path

        # Skip non-protected paths
        if not self._is_protected_path(path):
            return await call_next(request)

        # Extract user
        user_id = self._get_user_id(request)
        if not user_id:
            # No user identified - let other middleware handle auth
            return await call_next(request)

        # Get tier and category
        tier_name = self._get_user_tier(request)
        tier = self.config.tiers.get(tier_name, self.config.tiers[self.config.default_tier])
        category = _extract_category(path)

        # --- Multi-Window Quota Check ---
        allowed, remaining = await self._check_quota_valkey(user_id, tier)
        if not allowed:
            logger.warning(
                "SDK quota exceeded for user=%s tier=%s path=%s",
                user_id[:16], tier_name, path,
            )
            return self._quota_exceeded_response(tier_name, tier, remaining)

        # --- Anomaly Score (decayed, updated, persisted) ---
        now = time.time()
        score, triggered_rules = await self._score_request(request, user_id, category, now)

        # --- Determine Throttle Level ---
        throttle_level = self._get_throttle_level(score)

        if triggered_rules:
            logger.info(
                "SDK anomaly: user=%s score=%.1f level=%d rules=%s",
                user_id[:16], score, throttle_level, ",".join(triggered_rules),
            )

        # Level 3: Block
        if throttle_level >= 3:
            logger.warning(
                "SDK protection blocking user=%s score=%.1f",
                user_id[:16], score,
            )
            return self._blocked_response(throttle_level)

        # Level 2: Heavy delay (reduced detail is only signalled via header below)
        if throttle_level == 2:
            await asyncio.sleep(random.uniform(5.0, 10.0))
        # Level 1: Light delay
        elif throttle_level == 1:
            await asyncio.sleep(random.uniform(1.0, 3.0))

        # --- Forward Request ---
        response = await call_next(request)

        # --- Response Headers ---
        response.headers["X-SDK-Quota-Remaining-Minute"] = str(remaining.get("minute", 0))
        response.headers["X-SDK-Quota-Remaining-Hour"] = str(remaining.get("hour", 0))
        response.headers["X-SDK-Throttle-Level"] = str(throttle_level)
        if throttle_level >= 2:
            response.headers["X-SDK-Detail-Reduced"] = "true"

        # Watermark
        response.headers["X-BP-Trace"] = self._generate_watermark(user_id, now)

        return response