"""
Input Validation Gate Middleware

Validates incoming requests for:
- Request body size limits (based on the Content-Length header)
- Content-Type validation
- File upload limits (larger body limit on configured upload paths)

Per-file checks (size, blocked extensions, allowed MIME types) are not done in
the middleware because that would require reading the body; call
``validate_file_upload`` from upload handlers instead.

Usage:
    from middleware import InputGateMiddleware

    app.add_middleware(
        InputGateMiddleware,
        max_body_size=10 * 1024 * 1024,  # 10MB
        allowed_content_types={"application/json", "multipart/form-data"},
    )
"""
import copy
import logging
import os
from dataclasses import dataclass, field
from typing import List, Optional, Set

from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import JSONResponse, Response

logger = logging.getLogger(__name__)


def _normalize_media_type(content_type: str) -> str:
    """Return the bare, lower-cased media type of a Content-Type value.

    Parameters such as ``charset`` or ``boundary`` are dropped, e.g.
    ``"Multipart/Form-Data; boundary=x"`` -> ``"multipart/form-data"``.
    Media types are case-insensitive (RFC 9110, section 8.3.1), so both the
    request value and configured entries are compared in this form.
    """
    return content_type.split(";", 1)[0].strip().lower()


def _find_blocked_extension(filename: str, blocked_extensions: Set[str]) -> Optional[str]:
    """Return the blocked extension that *filename* ends with, or ``None``.

    Matching is case-insensitive. Trailing dots and whitespace are ignored
    because Windows strips them when saving a file, so ``"evil.exe."`` or
    ``"evil.exe "`` would otherwise slip past a plain ``endswith`` check.
    Extensions are checked in sorted order so the reported match is
    deterministic.
    """
    if not filename:
        return None
    candidate = filename.lower().rstrip(". \t\r\n")
    for ext in sorted(blocked_extensions):
        if candidate.endswith(ext.lower()):
            return ext
    return None


@dataclass
class InputGateConfig:
    """Configuration for input validation."""

    # Maximum request body size in bytes for ordinary paths (default: 10MB).
    max_body_size: int = 10 * 1024 * 1024

    # Request Content-Types accepted when strict_content_type is enabled.
    # Compared case-insensitively, ignoring parameters such as charset.
    allowed_content_types: Set[str] = field(default_factory=lambda: {
        "application/json",
        "application/x-www-form-urlencoded",
        "multipart/form-data",
        "text/plain",
    })

    # Body size limit (bytes) applied to large_upload_paths, and the per-file
    # limit used by validate_file_upload (default: 50MB).
    max_file_size: int = 50 * 1024 * 1024

    # MIME types accepted by validate_file_upload.
    allowed_file_types: Set[str] = field(default_factory=lambda: {
        "image/jpeg",
        "image/png",
        "image/gif",
        "image/webp",
        "application/pdf",
        "application/msword",
        "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
        "application/vnd.ms-excel",
        "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
        "text/csv",
    })

    # File extensions rejected by validate_file_upload (potential malware).
    blocked_extensions: Set[str] = field(default_factory=lambda: {
        ".exe", ".bat", ".cmd", ".com", ".msi", ".dll", ".scr",
        ".pif", ".vbs", ".js", ".jar", ".sh", ".ps1", ".app",
    })

    # Path prefixes (matched on a path-segment boundary) that use
    # max_file_size instead of max_body_size as the body limit.
    large_upload_paths: List[str] = field(default_factory=lambda: [
        "/api/files/upload",
        "/api/documents/upload",
        "/api/attachments",
    ])

    # Exact paths that bypass all validation.
    excluded_paths: List[str] = field(default_factory=lambda: [
        "/health",
        "/metrics",
    ])

    # When True, requests carrying a Content-Type outside
    # allowed_content_types are rejected with 415. Requests without a
    # Content-Type header are still let through.
    strict_content_type: bool = True


class InputGateMiddleware(BaseHTTPMiddleware):
    """
    Middleware that validates incoming request bodies and content types.

    Protects against:
    - Oversized request bodies (when a Content-Length header is present)
    - Invalid content types
    - Malformed Content-Length headers

    Configuration precedence for ``max_body_size``: the
    ``MAX_REQUEST_BODY_SIZE`` environment variable overrides the keyword
    argument, which overrides ``config``.
    """

    def __init__(
        self,
        app,
        config: Optional[InputGateConfig] = None,
        max_body_size: Optional[int] = None,
        allowed_content_types: Optional[Set[str]] = None,
    ):
        """Create the middleware.

        Args:
            app: The wrapped ASGI application.
            config: Base configuration. It is copied, so the overrides below
                never modify the caller's object.
            max_body_size: Optional override for ``config.max_body_size``.
            allowed_content_types: Optional override for
                ``config.allowed_content_types``.
        """
        super().__init__(app)
        # Shallow copy: attributes are only reassigned below, never mutated in
        # place, so sharing the underlying sets/lists with the caller is safe.
        self.config = copy.copy(config) if config is not None else InputGateConfig()

        # Apply explicit overrides.
        if max_body_size is not None:
            self.config.max_body_size = max_body_size
        if allowed_content_types is not None:
            self.config.allowed_content_types = allowed_content_types

        # Environment override (takes precedence over the arguments above).
        env_max_size = os.getenv("MAX_REQUEST_BODY_SIZE")
        if env_max_size:
            try:
                parsed_size: Optional[int] = int(env_max_size)
            except ValueError:
                parsed_size = None
            if parsed_size is None or parsed_size < 0:
                # Don't crash on a bad deployment setting, but make it visible.
                logger.warning(
                    "Ignoring invalid MAX_REQUEST_BODY_SIZE=%r; keeping %d bytes",
                    env_max_size,
                    self.config.max_body_size,
                )
            else:
                self.config.max_body_size = parsed_size

    def _is_excluded_path(self, path: str) -> bool:
        """Return True if *path* exactly matches an excluded path."""
        return path in self.config.excluded_paths

    def _is_large_upload_path(self, path: str) -> bool:
        """Return True if *path* is, or is nested under, a large-upload prefix.

        Matching stops at a path-segment boundary, so ``/api/attachments``
        covers ``/api/attachments`` and ``/api/attachments/42`` but not
        ``/api/attachmentsXYZ``.
        """
        for upload_path in self.config.large_upload_paths:
            base = upload_path.rstrip("/")
            if path == base or path.startswith(base + "/"):
                return True
        return False

    def _get_max_size(self, path: str) -> int:
        """Return the maximum allowed body size in bytes for *path*."""
        if self._is_large_upload_path(path):
            return self.config.max_file_size
        return self.config.max_body_size

    def _validate_content_type(self, content_type: Optional[str]) -> tuple[bool, str]:
        """
        Validate the request Content-Type against the allowed set.

        Returns:
            Tuple of (is_valid, error_message); error_message is "" when valid.
        """
        if not content_type:
            # Allow requests without a Content-Type (e.g. empty-body POSTs).
            return True, ""

        base_type = _normalize_media_type(content_type)
        allowed = {_normalize_media_type(t) for t in self.config.allowed_content_types}
        if base_type not in allowed:
            return False, f"Content-Type '{base_type}' is not allowed"
        return True, ""

    def _check_blocked_extension(self, filename: str) -> bool:
        """Return True if *filename* ends with a configured blocked extension."""
        return _find_blocked_extension(filename, self.config.blocked_extensions) is not None

    async def dispatch(self, request: Request, call_next) -> Response:
        """Validate the request and forward it, or return a 4xx JSON error."""
        path = request.url.path

        # Skip excluded paths.
        if self._is_excluded_path(path):
            return await call_next(request)

        # Skip validation for methods that normally carry no body.
        if request.method in ("GET", "HEAD", "OPTIONS"):
            return await call_next(request)

        # Validate the content type for requests with a body.
        content_type = request.headers.get("Content-Type")
        if self.config.strict_content_type:
            is_valid, error_msg = self._validate_content_type(content_type)
            if not is_valid:
                return JSONResponse(
                    status_code=415,
                    content={
                        "error": "unsupported_media_type",
                        "message": error_msg,
                    },
                )

        # Check the Content-Length header.
        content_length = request.headers.get("Content-Length")
        if content_length:
            value = content_length.strip()
            # RFC 9110 section 8.6: Content-Length is 1*DIGIT. int() alone would
            # also accept "-1", "+5" or "1_000", letting a negative length pass
            # the size check below.
            if not (value.isascii() and value.isdigit()):
                return JSONResponse(
                    status_code=400,
                    content={
                        "error": "invalid_content_length",
                        "message": "Invalid Content-Length header",
                    },
                )
            max_size = self._get_max_size(path)
            if int(value) > max_size:
                return JSONResponse(
                    status_code=413,
                    content={
                        "error": "payload_too_large",
                        "message": f"Request body exceeds maximum size of {max_size} bytes",
                        "max_size": max_size,
                    },
                )

        # NOTE: Bodies sent without a Content-Length header (e.g. with
        # Transfer-Encoding: chunked) are not size-checked here. Enforce a hard
        # limit at the server/proxy layer or in the handler.
        # NOTE: Multipart uploads are not inspected here because that would
        # require reading the body. Handlers should call validate_file_upload()
        # for each uploaded file.
        return await call_next(request)


def validate_file_upload(
    filename: str,
    content_type: str,
    size: int,
    config: Optional[InputGateConfig] = None,
) -> tuple[bool, str]:
    """
    Validate a single uploaded file. Use this in upload handlers.

    Args:
        filename: Original filename as supplied by the client.
        content_type: MIME type of the file. Parameters and letter case are
            ignored when comparing against ``allowed_file_types``. An empty
            value skips the type check.
        size: File size in bytes.
        config: Optional custom configuration (defaults to InputGateConfig()).

    Returns:
        Tuple of (is_valid, error_message); error_message is "" when valid.
    """
    cfg = config if config is not None else InputGateConfig()

    # Check size.
    if size > cfg.max_file_size:
        return False, f"File size exceeds maximum of {cfg.max_file_size} bytes"

    # Check extension (case-insensitive, ignoring trailing dots/whitespace).
    blocked = _find_blocked_extension(filename, cfg.blocked_extensions)
    if blocked is not None:
        return False, f"File extension '{blocked}' is not allowed"

    # Check content type (media types are case-insensitive per RFC 9110).
    if content_type:
        media_type = _normalize_media_type(content_type)
        allowed = {_normalize_media_type(t) for t in cfg.allowed_file_types}
        if media_type not in allowed:
            return False, f"File type '{content_type}' is not allowed"

    return True, ""