"""
Request-ID Middleware

Generates and propagates unique request identifiers for distributed tracing.
Supports both X-Request-ID and X-Correlation-ID headers, plus an optional
custom primary header.

Usage:
    from middleware import RequestIDMiddleware, get_request_id

    app.add_middleware(RequestIDMiddleware)

    @app.get("/api/example")
    async def example():
        request_id = get_request_id()
        logger.info("Processing request", extra={"request_id": request_id})
"""

import logging
import uuid
from contextvars import ContextVar, Token
from typing import Awaitable, Callable, Optional

from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import Response

__all__ = [
    "CORRELATION_ID_HEADER",
    "REQUEST_ID_HEADER",
    "RequestIDLogFilter",
    "RequestIDMiddleware",
    "generate_request_id",
    "get_request_id",
    "reset_request_id",
    "set_request_id",
]

# Context variable to store request ID across async calls
_request_id_ctx: ContextVar[Optional[str]] = ContextVar("request_id", default=None)

# Header names
REQUEST_ID_HEADER = "X-Request-ID"
CORRELATION_ID_HEADER = "X-Correlation-ID"


def get_request_id() -> Optional[str]:
    """
    Get the current request ID from context.

    Returns:
        The request ID string or None if not in a request context.

    Example:
        request_id = get_request_id()
        logger.info("Processing", extra={"request_id": request_id})
    """
    return _request_id_ctx.get()


def set_request_id(request_id: str) -> Token:
    """
    Set the request ID in the current context.

    Args:
        request_id: The request ID to set

    Returns:
        A contextvars Token that can be passed to ``reset_request_id`` to
        restore the value that was active before this call. Callers that
        do not need to restore the previous value may ignore it.
    """
    return _request_id_ctx.set(request_id)


def reset_request_id(token: Token) -> None:
    """
    Restore the request ID that was active before ``set_request_id``.

    Args:
        token: The Token returned by the matching ``set_request_id`` call.
            It must be used in the same context in which it was created
            (a contextvars requirement).
    """
    _request_id_ctx.reset(token)


def generate_request_id() -> str:
    """
    Generate a new unique request ID.

    Returns:
        A UUID4 string
    """
    return str(uuid.uuid4())


class RequestIDMiddleware(BaseHTTPMiddleware):
    """
    Middleware that generates and propagates request IDs.

    For each incoming request:
    1. Check for an existing ID in ``header_name``, then X-Request-ID, then
       X-Correlation-ID (the first non-empty value wins)
    2. If none is present, generate a new ID with ``generator``
    3. Store it in context (and ``request.state.request_id``) for handlers
       and logging; the previous context value is restored afterwards
    4. Add it to the response under ``header_name``, X-Request-ID and
       X-Correlation-ID

    Attributes:
        header_name: The primary header name to use (default: X-Request-ID)
        generator: Zero-argument function returning a new ID (default: uuid4)
    """

    def __init__(
        self,
        app,
        header_name: str = REQUEST_ID_HEADER,
        generator: Callable[[], str] = generate_request_id,
    ):
        super().__init__(app)
        self.header_name = header_name
        self.generator = generator

    def _header_names(self) -> tuple:
        """Return the header names to read from / write to, primary first."""
        # dict.fromkeys de-duplicates while preserving order, so the default
        # configuration yields exactly (X-Request-ID, X-Correlation-ID).
        return tuple(
            dict.fromkeys((self.header_name, REQUEST_ID_HEADER, CORRELATION_ID_HEADER))
        )

    async def dispatch(
        self,
        request: Request,
        call_next: Callable[[Request], Awaitable[Response]],
    ) -> Response:
        header_names = self._header_names()

        # Use the first non-empty incoming header (Starlette header lookup
        # is case-insensitive); fall back to a freshly generated ID.
        request_id = next(
            (value for value in map(request.headers.get, header_names) if value),
            None,
        )
        if not request_id:
            request_id = self.generator()

        # Store in context for logging and handlers
        token = set_request_id(request_id)
        # Store in request state for direct access
        request.state.request_id = request_id

        try:
            response = await call_next(request)
        finally:
            # Restore the previous value so the ID does not leak into
            # unrelated work that later runs in this same context.
            reset_request_id(token)

        # Echo the request ID on every supported header
        for name in header_names:
            response.headers[name] = request_id
        return response


class RequestIDLogFilter(logging.Filter):
    """
    Logging filter that adds request_id to log records.

    Records logged outside a request context get "no-request-id".

    Usage:
        import logging
        handler = logging.StreamHandler()
        handler.addFilter(RequestIDLogFilter())
        formatter = logging.Formatter(
            '%(asctime)s [%(request_id)s] %(levelname)s %(message)s'
        )
        handler.setFormatter(formatter)
    """

    def filter(self, record: logging.LogRecord) -> bool:
        # Always overwrite so the formatter field is present on every record.
        record.request_id = get_request_id() or "no-request-id"
        return True