Initial commit: breakpilot-compliance - Compliance SDK Platform

Services: Admin-Compliance, Backend-Compliance,
AI-Compliance-SDK, Consent-SDK, Developer-Portal,
PCA-Platform, DSMS

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Benjamin Boenisch
2026-02-11 23:47:28 +01:00
commit 4435e7ea0a
734 changed files with 251369 additions and 0 deletions

View File

@@ -0,0 +1,26 @@
"""
BreakPilot Middleware Stack
This module provides middleware components for the FastAPI backend:
- Request-ID: Adds unique request identifiers for tracing
- Security Headers: Adds security headers to all responses
- Rate Limiter: Protects against abuse (Valkey-based)
- PII Redactor: Redacts sensitive data from logs
- Input Gate: Validates request body size and content types
"""
from .request_id import RequestIDMiddleware, get_request_id
from .security_headers import SecurityHeadersMiddleware
from .rate_limiter import RateLimiterMiddleware
from .pii_redactor import PIIRedactor, redact_pii
from .input_gate import InputGateMiddleware
# Public API of the middleware package: exactly the names imported above,
# i.e. what `from middleware import *` re-exports.
__all__ = [
"RequestIDMiddleware",
"get_request_id",
"SecurityHeadersMiddleware",
"RateLimiterMiddleware",
"PIIRedactor",
"redact_pii",
"InputGateMiddleware",
]

View File

@@ -0,0 +1,260 @@
"""
Input Validation Gate Middleware
Validates incoming requests for:
- Request body size limits
- Content-Type validation
- File upload limits
- Malicious content detection
Usage:
from middleware import InputGateMiddleware
app.add_middleware(
InputGateMiddleware,
max_body_size=10 * 1024 * 1024, # 10MB
allowed_content_types=["application/json", "multipart/form-data"],
)
"""
import os
from dataclasses import dataclass, field
from typing import List, Optional, Set
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import JSONResponse, Response
def _default_request_content_types() -> Set[str]:
    """Return a fresh set of request media types accepted by default."""
    return {
        "application/json",
        "application/x-www-form-urlencoded",
        "multipart/form-data",
        "text/plain",
    }


def _default_upload_file_types() -> Set[str]:
    """Return a fresh set of MIME types accepted for uploaded files."""
    images = {"image/jpeg", "image/png", "image/gif", "image/webp"}
    documents = {
        "application/pdf",
        "application/msword",
        "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
        "application/vnd.ms-excel",
        "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
        "text/csv",
    }
    return images | documents


def _default_blocked_extensions() -> Set[str]:
    """Return a fresh set of file extensions that are always refused."""
    return {
        ".exe", ".bat", ".cmd", ".com", ".msi",
        ".dll", ".scr", ".pif", ".vbs", ".js",
        ".jar", ".sh", ".ps1", ".app",
    }


def _default_large_upload_paths() -> List[str]:
    """Return a fresh list of path prefixes that get the file-size limit."""
    return ["/api/files/upload", "/api/documents/upload", "/api/attachments"]


def _default_excluded_paths() -> List[str]:
    """Return a fresh list of paths that bypass input validation."""
    return ["/health", "/metrics"]


@dataclass
class InputGateConfig:
    """Tunable limits and allow/deny lists used by the input gate.

    Every collection field is built by a factory, so each instance owns
    its own mutable set/list.
    """

    # Upper bound for ordinary request bodies, in bytes (10 MiB)
    max_body_size: int = 10 * 1024 * 1024
    # Base media types accepted in the Content-Type request header
    allowed_content_types: Set[str] = field(
        default_factory=_default_request_content_types
    )
    # Upper bound for bodies on upload paths, in bytes (50 MiB)
    max_file_size: int = 50 * 1024 * 1024
    # MIME types accepted for individual uploaded files
    allowed_file_types: Set[str] = field(default_factory=_default_upload_file_types)
    # Filename suffixes that are refused (potential malware)
    blocked_extensions: Set[str] = field(default_factory=_default_blocked_extensions)
    # Path prefixes for which max_file_size applies instead of max_body_size
    large_upload_paths: List[str] = field(default_factory=_default_large_upload_paths)
    # Exact paths that skip validation entirely
    excluded_paths: List[str] = field(default_factory=_default_excluded_paths)
    # When True, requests with a disallowed Content-Type are rejected
    strict_content_type: bool = True
class InputGateMiddleware(BaseHTTPMiddleware):
    """
    Middleware that validates request metadata before the handler runs.

    Protects against:
    - Oversized request bodies (based on the declared Content-Length)
    - Content types outside the configured allow-list

    NOTE: only the Content-Length header is inspected; the body itself is
    never read here, so a request that sends no Content-Length is not
    size-limited by this middleware. File-level checks (extension, MIME
    type, real size) belong in the upload handler via validate_file_upload().
    """

    def __init__(
        self,
        app,
        config: Optional[InputGateConfig] = None,
        max_body_size: Optional[int] = None,
        allowed_content_types: Optional[Set[str]] = None,
    ):
        """
        Args:
            app: The wrapped ASGI application.
            config: Base configuration; a default one is created if omitted.
            max_body_size: Overrides config.max_body_size when given.
            allowed_content_types: Overrides config.allowed_content_types
                when given.

        The MAX_REQUEST_BODY_SIZE environment variable, when set to an
        integer, takes precedence over both of the above.
        """
        super().__init__(app)
        self.config = config or InputGateConfig()
        # Apply explicit overrides
        if max_body_size is not None:
            self.config.max_body_size = max_body_size
        if allowed_content_types is not None:
            self.config.allowed_content_types = allowed_content_types
        # Environment override is best-effort: a non-integer value is
        # ignored on purpose so a typo cannot prevent the app from starting.
        env_max_size = os.getenv("MAX_REQUEST_BODY_SIZE")
        if env_max_size:
            try:
                self.config.max_body_size = int(env_max_size)
            except ValueError:
                pass

    def _is_excluded_path(self, path: str) -> bool:
        """Return True when *path* exactly matches an excluded path."""
        return path in self.config.excluded_paths

    def _is_large_upload_path(self, path: str) -> bool:
        """Return True when *path* starts with one of the large-upload prefixes."""
        return any(path.startswith(prefix) for prefix in self.config.large_upload_paths)

    def _get_max_size(self, path: str) -> int:
        """Return the body size limit in bytes that applies to *path*."""
        if self._is_large_upload_path(path):
            return self.config.max_file_size
        return self.config.max_body_size

    def _validate_content_type(self, content_type: Optional[str]) -> tuple[bool, str]:
        """
        Validate a Content-Type header value against the allow-list.

        Returns:
            Tuple of (is_valid, error_message); the message is empty on success.
        """
        if not content_type:
            # A missing Content-Type is accepted
            return True, ""
        # Compare only the media type; drop parameters such as charset/boundary
        base_type = content_type.split(";")[0].strip().lower()
        if base_type not in self.config.allowed_content_types:
            return False, f"Content-Type '{base_type}' is not allowed"
        return True, ""

    def _check_blocked_extension(self, filename: str) -> bool:
        """Return True when *filename* ends with a blocked extension (case-insensitive)."""
        if not filename:
            return False
        # str.endswith accepts a tuple of suffixes: one call instead of a loop
        return filename.lower().endswith(tuple(self.config.blocked_extensions))

    async def dispatch(self, request: Request, call_next) -> Response:
        """Reject the request with 415/413/400 or pass it on to the application."""
        # Skip excluded paths
        if self._is_excluded_path(request.url.path):
            return await call_next(request)
        # GET, HEAD and OPTIONS are not validated
        if request.method in ("GET", "HEAD", "OPTIONS"):
            return await call_next(request)

        if self.config.strict_content_type:
            is_valid, error_msg = self._validate_content_type(
                request.headers.get("Content-Type")
            )
            if not is_valid:
                return JSONResponse(
                    status_code=415,
                    content={
                        "error": "unsupported_media_type",
                        "message": error_msg,
                    },
                )

        content_length = request.headers.get("Content-Length")
        if content_length:
            try:
                length = int(content_length)
            except ValueError:
                length = -1
            # A negative length is as malformed as a non-numeric one and
            # must not slip past the size comparison below.
            if length < 0:
                return JSONResponse(
                    status_code=400,
                    content={
                        "error": "invalid_content_length",
                        "message": "Invalid Content-Length header",
                    },
                )
            max_size = self._get_max_size(request.url.path)
            if length > max_size:
                return JSONResponse(
                    status_code=413,
                    content={
                        "error": "payload_too_large",
                        "message": f"Request body exceeds maximum size of {max_size} bytes",
                        "max_size": max_size,
                    },
                )

        # Multipart bodies are deliberately not parsed here: checking file
        # names would require reading the body, which is left to the handler.
        return await call_next(request)
def validate_file_upload(
    filename: str,
    content_type: str,
    size: int,
    config: Optional[InputGateConfig] = None,
) -> tuple[bool, str]:
    """
    Validate a single uploaded file; intended for use inside upload handlers.

    Checks, in order: size against ``max_file_size``, filename against
    ``blocked_extensions`` (case-insensitive suffix match) and the MIME type
    against ``allowed_file_types``.

    Args:
        filename: Original filename; skipped when empty.
        content_type: MIME type of the file; skipped when empty. Parameters
            (``; charset=...``) and letter case are ignored.
        size: File size in bytes.
        config: Optional custom configuration; defaults to InputGateConfig().

    Returns:
        Tuple of (is_valid, error_message); the message is empty on success.
    """
    cfg = config or InputGateConfig()

    if size > cfg.max_file_size:
        return False, f"File size exceeds maximum of {cfg.max_file_size} bytes"

    if filename:
        lower_filename = filename.lower()
        # sorted() keeps the reported extension deterministic for a set
        for ext in sorted(cfg.blocked_extensions):
            if lower_filename.endswith(ext):
                return False, f"File extension '{ext}' is not allowed"

    if content_type:
        # Normalise the same way the middleware does for request headers, so
        # "Image/PNG" or "image/png; charset=binary" is not wrongly rejected.
        base_type = content_type.split(";")[0].strip().lower()
        if base_type not in cfg.allowed_file_types:
            return False, f"File type '{base_type}' is not allowed"

    return True, ""

View File

@@ -0,0 +1,316 @@
"""
PII Redactor
Redacts Personally Identifiable Information (PII) from logs and responses.
Essential for DSGVO/GDPR compliance in BreakPilot.
Redacted data types:
- Email addresses
- IP addresses
- German phone numbers
- Names (when identified)
- Student IDs
- Credit card numbers
- IBAN numbers
Usage:
from middleware import PIIRedactor, redact_pii
# Use in logging
logger.info(redact_pii(f"User {email} logged in from {ip}"))
# Configure redactor
redactor = PIIRedactor(patterns=["email", "ip_v4", "phone_de"])
safe_message = redactor.redact(sensitive_message)
"""
import re
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Pattern, Set
@dataclass
class PIIPattern:
    """One kind of PII: a compiled regex plus the marker that replaces matches."""

    # Identifier reported as the "type" of a finding
    name: str
    # Compiled regular expression that locates this kind of PII
    pattern: Pattern
    # Text substituted for every match when format is not preserved
    replacement: str
# Pre-compiled regex patterns for common PII, keyed by the pattern name that
# PIIRedactor accepts in its `patterns` list.
PII_PATTERNS: Dict[str, PIIPattern] = {
    "email": PIIPattern(
        name="email",
        pattern=re.compile(
            # The TLD class is letters only. A "|" inside [...] is a literal
            # pipe, which let a single match swallow "a@x.com|b" out of
            # "a@x.com|b@y.org" and leave "@y.org" unredacted.
            r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b',
            re.IGNORECASE
        ),
        replacement="[EMAIL_REDACTED]",
    ),
    "ip_v4": PIIPattern(
        name="ip_v4",
        pattern=re.compile(
            r'\b(?:(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\b'
        ),
        replacement="[IP_REDACTED]",
    ),
    # NOTE: matches only the full eight-group notation; compressed forms
    # that use "::" are not covered by this pattern.
    "ip_v6": PIIPattern(
        name="ip_v6",
        pattern=re.compile(
            r'\b(?:[0-9a-fA-F]{1,4}:){7}[0-9a-fA-F]{1,4}\b'
        ),
        replacement="[IP_REDACTED]",
    ),
    "phone_de": PIIPattern(
        name="phone_de",
        pattern=re.compile(
            r'(?<!\w)(?:\+49|0049|0)[\s.-]?(?:\d{2,4})[\s.-]?(?:\d{3,4})[\s.-]?(?:\d{3,4})(?!\d)'
        ),
        replacement="[PHONE_REDACTED]",
    ),
    # Broad: any run of 10-15 digits (optionally separated) counts as a phone
    "phone_intl": PIIPattern(
        name="phone_intl",
        pattern=re.compile(
            r'(?<!\w)\+?(?:\d[\s.-]?){10,15}(?!\d)'
        ),
        replacement="[PHONE_REDACTED]",
    ),
    "credit_card": PIIPattern(
        name="credit_card",
        pattern=re.compile(
            r'\b(?:\d{4}[\s.-]?){3}\d{4}\b'
        ),
        replacement="[CC_REDACTED]",
    ),
    "iban": PIIPattern(
        name="iban",
        pattern=re.compile(
            r'\b[A-Z]{2}\d{2}[\s]?(?:\d{4}[\s]?){3,5}\d{1,4}\b',
            re.IGNORECASE
        ),
        replacement="[IBAN_REDACTED]",
    ),
    "student_id": PIIPattern(
        name="student_id",
        pattern=re.compile(
            r'\b(?:student|schueler|schüler)[-_]?(?:id|nr)?[:\s]?\d{4,10}\b',
            re.IGNORECASE
        ),
        replacement="[STUDENT_ID_REDACTED]",
    ),
    "uuid": PIIPattern(
        name="uuid",
        pattern=re.compile(
            r'\b[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}\b',
            re.IGNORECASE
        ),
        replacement="[UUID_REDACTED]",
    ),
    # German names are harder to detect; this only catches a salutation
    # (Herr/Frau/Hr./Fr.) followed by one or two capitalised words.
    "name_prefix": PIIPattern(
        name="name_prefix",
        pattern=re.compile(
            r'\b(?:Herr|Frau|Hr\.|Fr\.)\s+[A-ZÄÖÜ][a-zäöüß]+(?:\s+[A-ZÄÖÜ][a-zäöüß]+)?\b'
        ),
        replacement="[NAME_REDACTED]",
    ),
}
# Pattern names PIIRedactor enables when it is constructed without an
# explicit `patterns` list. Names that are not keys of PII_PATTERNS are
# skipped by PIIRedactor.
DEFAULT_PATTERNS = ["email", "ip_v4", "ip_v6", "phone_de"]
class PIIRedactor:
    """
    Redacts PII from strings using a configurable set of regex patterns.

    Attributes:
        patterns: Names of the built-in patterns requested (e.g. ["email", "ip_v4"])
        custom_patterns: Additional caller-supplied PIIPattern objects
        preserve_format: Whether matches are masked with "*" of equal length
    """

    def __init__(
        self,
        patterns: Optional[List[str]] = None,
        custom_patterns: Optional[List[PIIPattern]] = None,
        preserve_format: bool = False,
    ):
        """
        Initialize the PII redactor.

        Args:
            patterns: Names of built-in patterns to enable; None or an empty
                list selects DEFAULT_PATTERNS.
            custom_patterns: Additional custom PIIPattern objects, applied
                after the built-in ones.
            preserve_format: If True, replace each match with "*" repeated
                to the match length instead of the pattern's marker.

        Warns:
            UserWarning: if a requested name is not a key of PII_PATTERNS.
                The name is skipped, so that category is NOT redacted.
        """
        # Copy so later mutation of self.patterns cannot alter the shared
        # module-level default list or the caller's list.
        self.patterns = list(patterns or DEFAULT_PATTERNS)
        self.custom_patterns = list(custom_patterns or [])
        self.preserve_format = preserve_format

        # Build the active pattern list: built-ins first, then custom ones
        self._active_patterns: List[PIIPattern] = []
        unknown_names = []
        for pattern_name in self.patterns:
            builtin = PII_PATTERNS.get(pattern_name)
            if builtin is None:
                unknown_names.append(pattern_name)
            else:
                self._active_patterns.append(builtin)
        if unknown_names:
            # Dropping a name silently would disable redaction for that
            # category without anyone noticing; make the typo visible.
            import warnings

            warnings.warn(
                f"Unknown PII pattern name(s) ignored: {unknown_names!r}; "
                f"valid names are {sorted(PII_PATTERNS)!r}",
                stacklevel=2,
            )
        self._active_patterns.extend(self.custom_patterns)

    @staticmethod
    def _mask_match(match) -> str:
        """Return a run of "*" as long as the matched text."""
        return "*" * len(match.group())

    def redact(self, text: str) -> str:
        """
        Redact PII from the given text.

        Patterns are applied one after another, each on the output of the
        previous one.

        Args:
            text: The text to redact PII from

        Returns:
            Text with PII replaced; falsy input is returned unchanged.
        """
        if not text:
            return text
        result = text
        for pattern in self._active_patterns:
            substitute = self._mask_match if self.preserve_format else pattern.replacement
            result = pattern.pattern.sub(substitute, result)
        return result

    def contains_pii(self, text: str) -> bool:
        """
        Check if text contains any PII.

        Args:
            text: The text to check

        Returns:
            True if any active pattern matches somewhere in the text
        """
        if not text:
            return False
        return any(pattern.pattern.search(text) for pattern in self._active_patterns)

    def find_pii(self, text: str) -> List[Dict[str, object]]:
        """
        Find all PII in text with their types.

        Args:
            text: The text to search

        Returns:
            List of dicts with 'type' and 'match' (str) and 'start' and
            'end' (int offsets into *text*), grouped by pattern in
            activation order.
        """
        if not text:
            return []
        return [
            {
                "type": pattern.name,
                "match": match.group(),
                "start": match.start(),
                "end": match.end(),
            }
            for pattern in self._active_patterns
            for match in pattern.pattern.finditer(text)
        ]
# Module-level default redactor instance; stays None until the first call to
# get_default_redactor(), which creates it and stores it here.
_default_redactor: Optional[PIIRedactor] = None
def get_default_redactor() -> PIIRedactor:
    """Return the shared module-level redactor, building it on first use."""
    global _default_redactor
    shared = _default_redactor
    if shared is not None:
        return shared
    # First call: construct with default settings and remember it
    shared = PIIRedactor()
    _default_redactor = shared
    return shared
def redact_pii(text: str) -> str:
    """
    Redact PII from *text* with the shared default redactor.

    Args:
        text: Text to redact

    Returns:
        Redacted text

    Example:
        logger.info(redact_pii(f"User {email} logged in"))
    """
    shared_redactor = get_default_redactor()
    return shared_redactor.redact(text)
class PIIRedactingLogFilter:
    """
    Logging filter that automatically redacts PII from log messages.

    The record is rendered first (message % args) and the rendered text is
    redacted, so PII carried by non-string arguments such as exceptions or
    address objects is covered as well. The record is modified in place and
    is never dropped.

    Usage:
        import logging
        handler = logging.StreamHandler()
        handler.addFilter(PIIRedactingLogFilter())
        logger = logging.getLogger()
        logger.addHandler(handler)
    """

    def __init__(self, redactor: Optional[PIIRedactor] = None):
        """
        Args:
            redactor: Redactor to use; defaults to the shared default redactor.
        """
        self.redactor = redactor or get_default_redactor()

    def _redact_unformatted(self, record) -> None:
        """Redact msg and string args separately, without rendering the record."""
        if record.msg:
            record.msg = self.redactor.redact(str(record.msg))
        args = record.args
        if isinstance(args, dict):
            record.args = {
                key: self.redactor.redact(value) if isinstance(value, str) else value
                for key, value in args.items()
            }
        elif isinstance(args, tuple):
            record.args = tuple(
                self.redactor.redact(value) if isinstance(value, str) else value
                for value in args
            )

    def filter(self, record):
        """Redact PII in *record* and return True so the record is emitted."""
        if not record.args:
            if record.msg:
                record.msg = self.redactor.redact(str(record.msg))
            return True
        try:
            rendered = record.getMessage()
        except Exception:
            # Format string and args do not fit together. Do not raise from
            # a filter (that would break the caller's logging call); redact
            # the pieces individually instead.
            self._redact_unformatted(record)
        else:
            record.msg = self.redactor.redact(rendered)
            # Already interpolated: empty args stop a second %-formatting pass
            record.args = ()
        return True
def create_safe_dict(data: dict, redactor: Optional[PIIRedactor] = None) -> dict:
    """
    Create a copy of a dictionary with PII redacted from its string values.

    Nested dicts, lists and tuples are walked recursively. Keys and
    non-string leaf values are copied unchanged; *data* itself is not
    modified.

    Args:
        data: Dictionary to redact
        redactor: Optional custom redactor; defaults to the shared default

    Returns:
        New dictionary with redacted values
    """
    r = redactor or get_default_redactor()

    def redact_value(value):
        if isinstance(value, str):
            return r.redact(value)
        if isinstance(value, dict):
            return create_safe_dict(value, r)
        if isinstance(value, list):
            return [redact_value(item) for item in value]
        if isinstance(value, tuple):
            # Tuples used to pass through untouched, leaking any strings inside
            cleaned = [redact_value(item) for item in value]
            if hasattr(value, "_fields"):
                # namedtuple: rebuild the same type so field access still works
                return type(value)(*cleaned)
            return tuple(cleaned)
        return value

    return {key: redact_value(value) for key, value in data.items()}

View File

@@ -0,0 +1,363 @@
"""
Rate Limiter Middleware
Implements distributed rate limiting using Valkey (Redis-fork).
Supports IP-based, user-based, and endpoint-specific rate limits.
Features:
- Sliding window rate limiting
- IP-based limits for unauthenticated requests
- User-based limits for authenticated requests
- Stricter limits for auth endpoints (anti-brute-force)
- IP whitelist/blacklist support
- Graceful fallback when Valkey is unavailable
Usage:
from middleware import RateLimiterMiddleware
app.add_middleware(
RateLimiterMiddleware,
valkey_url="redis://localhost:6379",
ip_limit=100,
user_limit=500,
)
"""
from __future__ import annotations
import asyncio
import hashlib
import os
import time
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Set
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import JSONResponse, Response
# Try to import redis (valkey-compatible)
try:
import redis.asyncio as redis
REDIS_AVAILABLE = True
except ImportError:
REDIS_AVAILABLE = False
redis = None
def _default_auth_endpoints() -> List[str]:
    """Return a fresh list of path prefixes that get the stricter auth limit."""
    return [
        "/api/auth/login",
        "/api/auth/register",
        "/api/auth/password-reset",
        "/api/auth/forgot-password",
    ]


def _default_ip_whitelist() -> Set[str]:
    """Return a fresh set holding the IPv4 and IPv6 loopback addresses."""
    return {"127.0.0.1", "::1"}


def _default_rate_limit_excluded_paths() -> List[str]:
    """Return a fresh list of paths that are never rate limited."""
    return ["/health", "/metrics", "/api/health"]


@dataclass
class RateLimitConfig:
    """Settings consumed by the rate limiter middleware.

    Collection fields are produced by factories, so instances never share
    a mutable list or set.
    """

    # Connection URL of the Valkey/Redis server
    valkey_url: str = "redis://localhost:6379"
    # Requests allowed per window for unauthenticated clients (keyed by IP)
    ip_limit: int = 100
    # Requests allowed per window for authenticated users
    user_limit: int = 500
    # Requests allowed per window on auth endpoints (anti-brute-force)
    auth_limit: int = 20
    # Path prefixes treated as auth endpoints
    auth_endpoints: List[str] = field(default_factory=_default_auth_endpoints)
    # Length of the rate limit window in seconds
    window_size: int = 60
    # Client IPs that are never rate limited
    ip_whitelist: Set[str] = field(default_factory=_default_ip_whitelist)
    # Client IPs that are always rejected
    ip_blacklist: Set[str] = field(default_factory=set)
    # When True, requests from the internal (Docker) network are not limited
    skip_internal_network: bool = True
    # Exact paths that bypass rate limiting
    excluded_paths: List[str] = field(
        default_factory=_default_rate_limit_excluded_paths
    )
    # Fall back to in-memory limiting when Valkey is unavailable
    fallback_enabled: bool = True
    # Prefix prepended to every rate limit key
    key_prefix: str = "ratelimit"
class InMemoryRateLimiter:
    """
    Fallback in-memory sliding-window rate limiter for when Valkey is unavailable.

    State lives in this process only. Each key maps to the timestamps of
    its accepted requests; denied requests are not recorded.
    """

    # Minimum seconds between automatic sweeps of stale keys
    _SWEEP_INTERVAL = 60.0
    # Automatic sweeps never drop timestamps younger than this many seconds
    _MIN_RETENTION = 3600.0

    def __init__(self):
        self._counts: Dict[str, List[float]] = {}
        self._lock = asyncio.Lock()
        self._last_sweep = time.time()
        # Largest window seen so far; a sweep must not cut into it
        self._max_window = 0.0

    def _purge(self, cutoff: float) -> None:
        """Drop timestamps <= cutoff and delete keys left empty. Caller holds the lock."""
        for key in list(self._counts):
            kept = [t for t in self._counts[key] if t > cutoff]
            if kept:
                self._counts[key] = kept
            else:
                del self._counts[key]

    async def check_rate_limit(self, key: str, limit: int, window: int) -> tuple[bool, int]:
        """
        Record a request for *key* unless its limit is already reached.

        Args:
            key: Identifier of the client/bucket
            limit: Maximum accepted requests within the window
            window: Window length in seconds

        Returns:
            Tuple of (is_allowed, remaining_requests)
        """
        async with self._lock:
            now = time.time()

            # Nothing else calls cleanup(), so keys of clients that went
            # away would otherwise accumulate forever. Sweep occasionally.
            self._max_window = max(self._max_window, window)
            if now - self._last_sweep >= self._SWEEP_INTERVAL:
                self._purge(now - max(self._MIN_RETENTION, self._max_window))
                self._last_sweep = now

            window_start = now - window
            recent = [t for t in self._counts.get(key, ()) if t > window_start]
            allowed = len(recent) < limit
            if allowed:
                recent.append(now)

            # Never keep an empty entry (e.g. a denied request with limit 0)
            if recent:
                self._counts[key] = recent
            else:
                self._counts.pop(key, None)

            if not allowed:
                return False, 0
            return True, limit - len(recent)

    async def cleanup(self, max_age: float = 3600):
        """
        Remove expired entries.

        Args:
            max_age: Timestamps older than this many seconds are dropped;
                keys left without timestamps are deleted.
        """
        async with self._lock:
            self._purge(time.time() - max_age)
class RateLimiterMiddleware(BaseHTTPMiddleware):
    """
    Middleware that implements distributed sliding-window rate limiting.

    Uses Valkey (Redis-fork) for distributed state, with fallback to
    in-memory rate limiting when Valkey is unavailable.

    Client IP resolution: X-Forwarded-For / X-Real-IP are client-controlled
    unless a proxy of ours set them. They are therefore honoured only when
    the directly connected peer is a trusted proxy (loopback, an RFC 1918
    private range, or an entry of ``trusted_proxies``). Within
    X-Forwarded-For the right-most address that is not itself a trusted
    proxy is used, because entries further left can be forged by the client.
    """

    # Private IPv4 ranges regarded as the internal (Docker) network
    _INTERNAL_CIDRS = ("10.0.0.0/8", "172.16.0.0/12", "192.168.0.0/16")
    # Seconds to leave Valkey alone after a connection/command failure
    _VALKEY_RETRY_DELAY = 30.0

    def __init__(
        self,
        app,
        config: Optional[RateLimitConfig] = None,
        # Individual overrides
        valkey_url: Optional[str] = None,
        ip_limit: Optional[int] = None,
        user_limit: Optional[int] = None,
        auth_limit: Optional[int] = None,
        trusted_proxies: Optional[Set[str]] = None,
    ):
        """
        Args:
            app: The wrapped ASGI application.
            config: Base configuration; a default one is created if omitted.
            valkey_url, ip_limit, user_limit, auth_limit: Override the
                corresponding config field when given.
            trusted_proxies: Extra proxy addresses or CIDR ranges (e.g. a
                CDN) whose forwarding headers may be trusted.

        The VALKEY_URL environment variable, when set, overrides the URL.

        Raises:
            ValueError: if an entry of trusted_proxies is not an IP/CIDR.
        """
        import ipaddress

        super().__init__(app)
        self.config = config or RateLimitConfig()
        # Apply overrides
        if valkey_url is not None:
            self.config.valkey_url = valkey_url
        if ip_limit is not None:
            self.config.ip_limit = ip_limit
        if user_limit is not None:
            self.config.user_limit = user_limit
        if auth_limit is not None:
            self.config.auth_limit = auth_limit
        # Auto-configure from environment
        self.config.valkey_url = os.getenv("VALKEY_URL", self.config.valkey_url)

        # Parsed once here so the per-request checks stay cheap
        self._internal_networks = tuple(
            ipaddress.ip_network(cidr) for cidr in self._INTERNAL_CIDRS
        )
        self._trusted_proxy_networks = tuple(
            ipaddress.ip_network(entry, strict=False)
            for entry in (trusted_proxies or ())
        )

        # Valkey client is created lazily on first use
        self._redis: Optional[redis.Redis] = None
        self._fallback = InMemoryRateLimiter()
        self._valkey_available = False
        # time.monotonic() value before which Valkey is not contacted
        self._valkey_retry_at = 0.0

    async def _get_redis(self) -> Optional[redis.Redis]:
        """Return the Valkey client, or None when unavailable or backing off."""
        if not REDIS_AVAILABLE:
            return None
        # After a failure, stay on the fallback for a while; otherwise every
        # request would wait for the socket timeout during an outage.
        if time.monotonic() < self._valkey_retry_at:
            return None
        if self._redis is None:
            try:
                client = redis.from_url(
                    self.config.valkey_url,
                    decode_responses=True,
                    socket_timeout=1.0,
                    socket_connect_timeout=1.0,
                )
                await client.ping()
            except Exception:
                self._mark_valkey_down()
                return None
            self._redis = client
            self._valkey_available = True
        return self._redis

    def _mark_valkey_down(self) -> None:
        """Record a Valkey failure and start the retry back-off."""
        self._valkey_available = False
        self._valkey_retry_at = time.monotonic() + self._VALKEY_RETRY_DELAY

    @staticmethod
    def _parse_ip(value: str):
        """Return an ipaddress object for *value*, or None if it is not an IP."""
        import ipaddress

        try:
            addr = ipaddress.ip_address(value.strip())
        except ValueError:
            return None
        # "::ffff:10.0.0.1" should be judged as the IPv4 address it carries
        return getattr(addr, "ipv4_mapped", None) or addr

    def _is_internal_network(self, ip: str) -> bool:
        """Check if IP lies in 10.0.0.0/8, 172.16.0.0/12 or 192.168.0.0/16."""
        addr = self._parse_ip(ip)
        if addr is None:
            return False
        return any(addr in net for net in self._internal_networks)

    def _is_trusted_proxy(self, ip: str) -> bool:
        """Return True if *ip* is loopback, internal, or a configured proxy."""
        addr = self._parse_ip(ip)
        if addr is None:
            return False
        if addr.is_loopback:
            return True
        networks = self._internal_networks + self._trusted_proxy_networks
        return any(addr in net for net in networks)

    def _get_client_ip(self, request: Request) -> str:
        """Extract the client IP, trusting forwarding headers only from proxies."""
        peer = request.client.host if request.client else None
        # Without peer information the headers are the only source available
        if peer is None or self._is_trusted_proxy(peer):
            xff = request.headers.get("X-Forwarded-For")
            if xff:
                hops = [hop.strip() for hop in xff.split(",") if hop.strip()]
                # Proxies append on the right, so walk right-to-left past
                # our own hops; anything left of the first foreign address
                # may have been injected by the client.
                for hop in reversed(hops):
                    if not self._is_trusted_proxy(hop):
                        return hop
                if hops:
                    return hops[0]
            xri = request.headers.get("X-Real-IP")
            if xri:
                return xri.strip()
        return peer or "unknown"

    def _get_user_id(self, request: Request) -> Optional[str]:
        """Extract user ID from request state (set by session middleware)."""
        session = getattr(request.state, "session", None)
        if session:
            return getattr(session, "user_id", None)
        return None

    def _is_auth_path(self, path: str) -> bool:
        """Return True when *path* starts with a configured auth endpoint."""
        return any(path.startswith(prefix) for prefix in self.config.auth_endpoints)

    def _get_rate_limit(self, request: Request) -> int:
        """Determine the rate limit for this request."""
        # Auth endpoints get stricter limits, authenticated or not
        if self._is_auth_path(request.url.path):
            return self.config.auth_limit
        # Authenticated users get higher limits
        if self._get_user_id(request):
            return self.config.user_limit
        # Default IP-based limit
        return self.config.ip_limit

    def _get_rate_limit_key(self, request: Request) -> str:
        """Generate the rate limit key for this request."""
        user_id = self._get_user_id(request)
        if user_id:
            identifier = f"user:{user_id}"
        else:
            ip = self._get_client_ip(request)
            # Hash IP for privacy: the raw address is not stored in the key
            ip_hash = hashlib.sha256(ip.encode()).hexdigest()[:16]
            identifier = f"ip:{ip_hash}"
        # Auth endpoints are counted in their own bucket
        if self._is_auth_path(request.url.path):
            return f"{self.config.key_prefix}:auth:{identifier}"
        return f"{self.config.key_prefix}:{identifier}"

    async def _check_rate_limit_valkey(
        self, key: str, limit: int, window: int
    ) -> tuple[bool, int]:
        """
        Check the rate limit in Valkey, falling back to memory on any failure.

        Returns:
            Tuple of (is_allowed, remaining_requests)
        """
        r = await self._get_redis()
        if not r:
            return await self._fallback.check_rate_limit(key, limit, window)
        try:
            # Sliding window kept in a sorted set scored by timestamp
            now = time.time()
            # The random suffix keeps two requests with the same timestamp
            # from collapsing into a single sorted-set member.
            member = f"{now}:{os.urandom(4).hex()}"
            pipe = r.pipeline()
            pipe.zremrangebyscore(key, "-inf", now - window)  # drop expired
            pipe.zcard(key)                                   # count before adding
            pipe.zadd(key, {member: now})                     # record this request
            pipe.expire(key, window + 10)                     # let idle keys vanish
            results = await pipe.execute()
            current_count = results[1]
            if current_count >= limit:
                # Denied requests must not occupy the window (matching the
                # in-memory limiter); otherwise a client that keeps retrying
                # while blocked would never be let back in.
                await r.zrem(key, member)
                return False, 0
            return True, limit - current_count - 1
        except Exception:
            # Any Valkey error: back off and serve from the in-memory limiter
            self._mark_valkey_down()
            return await self._fallback.check_rate_limit(key, limit, window)

    async def dispatch(self, request: Request, call_next) -> Response:
        """Apply blacklist, exemptions and the rate limit, then call the app."""
        # Skip excluded paths
        if request.url.path in self.config.excluded_paths:
            return await call_next(request)

        ip = self._get_client_ip(request)

        # Check blacklist
        if ip in self.config.ip_blacklist:
            return JSONResponse(
                status_code=403,
                content={
                    "error": "ip_blocked",
                    "message": "Your IP address has been blocked.",
                },
            )
        # Skip whitelist
        if ip in self.config.ip_whitelist:
            return await call_next(request)
        # Skip internal network
        if self.config.skip_internal_network and self._is_internal_network(ip):
            return await call_next(request)

        # Get rate limit parameters
        limit = self._get_rate_limit(request)
        key = self._get_rate_limit_key(request)
        window = self.config.window_size

        allowed, remaining = await self._check_rate_limit_valkey(key, limit, window)
        if not allowed:
            return JSONResponse(
                status_code=429,
                content={
                    "error": "rate_limit_exceeded",
                    "message": "Too many requests. Please try again later.",
                    "retry_after": window,
                },
                headers={
                    "Retry-After": str(window),
                    "X-RateLimit-Limit": str(limit),
                    "X-RateLimit-Remaining": "0",
                    "X-RateLimit-Reset": str(int(time.time()) + window),
                },
            )

        # Process request
        response = await call_next(request)
        # Add rate limit headers
        response.headers["X-RateLimit-Limit"] = str(limit)
        response.headers["X-RateLimit-Remaining"] = str(remaining)
        response.headers["X-RateLimit-Reset"] = str(int(time.time()) + window)
        return response

View File

@@ -0,0 +1,138 @@
"""
Request-ID Middleware
Generates and propagates unique request identifiers for distributed tracing.
Supports both X-Request-ID and X-Correlation-ID headers.
Usage:
from middleware import RequestIDMiddleware, get_request_id
app.add_middleware(RequestIDMiddleware)
@app.get("/api/example")
async def example():
request_id = get_request_id()
logger.info("Processing request", extra={"request_id": request_id})
"""
import uuid
from contextvars import ContextVar
from typing import Optional
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import Response
# Context variable to store request ID across async calls; its default is
# None, which is what get_request_id() returns when no ID has been set.
_request_id_ctx: ContextVar[Optional[str]] = ContextVar("request_id", default=None)
# Header names read from incoming requests and written to responses
REQUEST_ID_HEADER = "X-Request-ID"
CORRELATION_ID_HEADER = "X-Correlation-ID"
def get_request_id() -> Optional[str]:
    """
    Look up the request ID stored in the current context.

    Returns:
        The request ID string, or None when no ID has been set in this context.

    Example:
        request_id = get_request_id()
        logger.info("Processing", extra={"request_id": request_id})
    """
    current_id = _request_id_ctx.get()
    return current_id
def set_request_id(request_id: str) -> None:
    """
    Store *request_id* as the request ID of the current context.

    Args:
        request_id: The request ID to set
    """
    # The reset token returned by ContextVar.set() is intentionally discarded
    _ = _request_id_ctx.set(request_id)
    return None
def generate_request_id() -> str:
    """
    Create a fresh request ID.

    Returns:
        The string form of a newly generated UUID4
    """
    fresh = uuid.uuid4()
    return str(fresh)
class RequestIDMiddleware(BaseHTTPMiddleware):
    """
    Middleware that generates and propagates request IDs.

    For each incoming request:
    1. Look for an ID in the configured header, then X-Request-ID, then
       X-Correlation-ID
    2. If none is present (or the value is unusable), generate a new ID
    3. Store it in the context variable and on request.state.request_id
    4. Add it to the response headers

    Attributes:
        header_name: The primary header name to use (default: X-Request-ID)
        generator: Function to generate new IDs (default: uuid4)
    """

    # Longest client-supplied ID that is accepted as-is
    _MAX_ID_LENGTH = 128

    def __init__(
        self,
        app,
        header_name: str = REQUEST_ID_HEADER,
        generator=generate_request_id,
    ):
        super().__init__(app)
        self.header_name = header_name
        self.generator = generator

    @classmethod
    def _clean_incoming_id(cls, value: Optional[str]) -> Optional[str]:
        """
        Return a client-supplied ID if it is safe to log and echo, else None.

        The value comes straight from a request header, so it is accepted
        only when it is non-empty, at most _MAX_ID_LENGTH characters long
        and made of visible ASCII characters (no spaces or control chars).
        """
        if not value:
            return None
        candidate = value.strip()
        if not candidate or len(candidate) > cls._MAX_ID_LENGTH:
            return None
        if any(not ("!" <= ch <= "~") for ch in candidate):
            return None
        return candidate

    async def dispatch(self, request: Request, call_next) -> Response:
        """Resolve the request ID, expose it, and copy it onto the response."""
        # The configured header wins; the two standard names remain fallbacks
        request_id = None
        for name in (self.header_name, REQUEST_ID_HEADER, CORRELATION_ID_HEADER):
            request_id = self._clean_incoming_id(request.headers.get(name))
            if request_id:
                break
        # Generate new ID if none was provided or the provided one was rejected
        if not request_id:
            request_id = self.generator()

        # Store in context for logging and handlers
        set_request_id(request_id)
        # Store in request state for direct access
        request.state.request_id = request_id

        response = await call_next(request)

        # Echo under the configured name as well as both standard names
        response.headers[self.header_name] = request_id
        response.headers[REQUEST_ID_HEADER] = request_id
        response.headers[CORRELATION_ID_HEADER] = request_id
        return response
class RequestIDLogFilter:
    """
    Logging filter that stamps every log record with a ``request_id`` attribute.

    Usage:
        import logging
        handler = logging.StreamHandler()
        handler.addFilter(RequestIDLogFilter())
        formatter = logging.Formatter(
            '%(asctime)s [%(request_id)s] %(levelname)s %(message)s'
        )
        handler.setFormatter(formatter)
    """

    def filter(self, record):
        """Attach the current request ID (or a placeholder) and keep the record."""
        current_id = get_request_id()
        # Falsy IDs (None or empty) are replaced by the placeholder
        record.request_id = current_id if current_id else "no-request-id"
        return True

View File

@@ -0,0 +1,202 @@
"""
Security Headers Middleware
Adds security headers to all HTTP responses to protect against common attacks.
Headers added:
- X-Content-Type-Options: nosniff
- X-Frame-Options: DENY
- X-XSS-Protection: 1; mode=block
- Strict-Transport-Security (HSTS)
- Content-Security-Policy
- Referrer-Policy
- Permissions-Policy
Usage:
from middleware import SecurityHeadersMiddleware
app.add_middleware(SecurityHeadersMiddleware)
# Or with custom configuration:
app.add_middleware(
SecurityHeadersMiddleware,
hsts_enabled=True,
csp_policy="default-src 'self'",
)
"""
import os
from dataclasses import dataclass, field
from typing import Dict, List, Optional
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import Response
@dataclass
class SecurityHeadersConfig:
    """Values and switches for the headers set by the security headers middleware."""

    # Value of X-Content-Type-Options
    content_type_options: str = "nosniff"
    # Value of X-Frame-Options
    frame_options: str = "DENY"
    # Value of X-XSS-Protection (legacy header, kept for older browsers)
    xss_protection: str = "1; mode=block"

    # Strict-Transport-Security switches and parts
    hsts_enabled: bool = True
    hsts_max_age: int = 31536000  # seconds (1 year)
    hsts_include_subdomains: bool = True
    hsts_preload: bool = False

    # Content-Security-Policy switch and policy text
    csp_enabled: bool = True
    csp_policy: str = (
        "default-src 'self'; "
        "script-src 'self' 'unsafe-inline' 'unsafe-eval'; "
        "style-src 'self' 'unsafe-inline'; "
        "img-src 'self' data: https:; "
        "font-src 'self' data:; "
        "connect-src 'self' https:; "
        "frame-ancestors 'none'"
    )

    # Value of Referrer-Policy
    referrer_policy: str = "strict-origin-when-cross-origin"
    # Value of Permissions-Policy (formerly Feature-Policy)
    permissions_policy: str = "geolocation=(), microphone=(), camera=()"

    # Values of the Cross-Origin-* headers
    cross_origin_opener_policy: str = "same-origin"
    cross_origin_embedder_policy: str = "require-corp"
    cross_origin_resource_policy: str = "same-origin"

    # Development mode (relaxes some restrictions)
    development_mode: bool = False
    # Exact paths that receive no security headers (e.g. health checks)
    excluded_paths: List[str] = field(
        default_factory=lambda: list(("/health", "/metrics"))
    )
class SecurityHeadersMiddleware(BaseHTTPMiddleware):
    """
    Middleware that adds security headers to all responses, except those for
    the configured excluded paths.

    Attributes:
        config: SecurityHeadersConfig instance
    """

    def __init__(
        self,
        app,
        config: Optional[SecurityHeadersConfig] = None,
        # Individual overrides for convenience
        hsts_enabled: Optional[bool] = None,
        csp_policy: Optional[str] = None,
        csp_enabled: Optional[bool] = None,
        development_mode: Optional[bool] = None,
    ):
        """
        Args:
            app: The wrapped ASGI application.
            config: Base configuration; a default one is created if omitted.
            hsts_enabled, csp_policy, csp_enabled: Override the
                corresponding config field when given.
            development_mode: Explicit development-mode switch. When
                omitted, the ENVIRONMENT variable decides if it is set. If
                it is unset, a supplied config keeps its own value, and
                without a config development mode is assumed.
        """
        super().__init__(app)
        config_supplied = config is not None
        # Use provided config or create default
        self.config = config or SecurityHeadersConfig()
        # Apply individual overrides
        if hsts_enabled is not None:
            self.config.hsts_enabled = hsts_enabled
        if csp_policy is not None:
            self.config.csp_policy = csp_policy
        if csp_enabled is not None:
            self.config.csp_enabled = csp_enabled

        if development_mode is not None:
            self.config.development_mode = development_mode
        else:
            env = os.getenv("ENVIRONMENT")
            if env is not None:
                self.config.development_mode = env.lower() in ("development", "dev", "local")
            elif not config_supplied:
                # Historical default: no ENVIRONMENT and no config means development
                self.config.development_mode = True
            # Otherwise keep the caller's config.development_mode. An unset
            # ENVIRONMENT used to force it to True, which silently turned off
            # HSTS and the Cross-Origin headers.

    def _build_hsts_header(self) -> str:
        """Build the Strict-Transport-Security header value."""
        parts = [f"max-age={self.config.hsts_max_age}"]
        if self.config.hsts_include_subdomains:
            parts.append("includeSubDomains")
        if self.config.hsts_preload:
            parts.append("preload")
        return "; ".join(parts)

    def _get_headers(self) -> Dict[str, str]:
        """Build the security headers dictionary from the current config."""
        cfg = self.config
        # Always sent
        headers = {
            "X-Content-Type-Options": cfg.content_type_options,
            "X-Frame-Options": cfg.frame_options,
            "X-XSS-Protection": cfg.xss_protection,
            "Referrer-Policy": cfg.referrer_policy,
        }
        # HSTS only outside development mode and when enabled
        if cfg.hsts_enabled and not cfg.development_mode:
            headers["Strict-Transport-Security"] = self._build_hsts_header()
        if cfg.csp_enabled:
            headers["Content-Security-Policy"] = cfg.csp_policy
        # An empty permissions policy suppresses the header
        if cfg.permissions_policy:
            headers["Permissions-Policy"] = cfg.permissions_policy
        # Cross-Origin headers are omitted in development mode
        if not cfg.development_mode:
            headers["Cross-Origin-Opener-Policy"] = cfg.cross_origin_opener_policy
            # NOTE: Cross-Origin-Embedder-Policy is deliberately not sent even
            # though the config carries a value: COEP can break loading of
            # external resources.
            headers["Cross-Origin-Resource-Policy"] = cfg.cross_origin_resource_policy
        return headers

    async def dispatch(self, request: Request, call_next) -> Response:
        """Run the request and stamp the security headers onto the response."""
        # Skip security headers for excluded paths
        if request.url.path in self.config.excluded_paths:
            return await call_next(request)
        response = await call_next(request)
        # Values set here replace same-named headers set by the handler
        for header_name, header_value in self._get_headers().items():
            response.headers[header_name] = header_value
        return response
def get_default_csp_for_environment(environment: str) -> str:
    """
    Return the default Content-Security-Policy for an environment name.

    Args:
        environment: Environment name, compared case-insensitively.
            "development", "dev" and "local" select the relaxed policy;
            every other value (e.g. "staging", "production") selects the
            strict one.

    Returns:
        CSP policy string
    """
    is_development = environment.lower() in ("development", "dev", "local")
    if is_development:
        # Relaxed: allows localhost, websockets, eval and same-origin framing
        directives = (
            "default-src 'self' localhost:* ws://localhost:*",
            "script-src 'self' 'unsafe-inline' 'unsafe-eval'",
            "style-src 'self' 'unsafe-inline'",
            "img-src 'self' data: https: blob:",
            "font-src 'self' data:",
            "connect-src 'self' localhost:* ws://localhost:* https:",
            "frame-ancestors 'self'",
        )
    else:
        # Strict: no eval, connections limited to breakpilot.app, no framing
        directives = (
            "default-src 'self'",
            "script-src 'self' 'unsafe-inline'",
            "style-src 'self' 'unsafe-inline'",
            "img-src 'self' data: https:",
            "font-src 'self' data:",
            "connect-src 'self' https://breakpilot.app https://*.breakpilot.app",
            "frame-ancestors 'none'",
        )
    return "; ".join(directives)