""" Integration Tests for Middleware Components Tests the middleware stack: - Request-ID generation and propagation - Security headers - Rate limiting - PII redaction - Input validation """ import pytest from unittest.mock import MagicMock, AsyncMock, patch from starlette.requests import Request from starlette.responses import Response, JSONResponse from starlette.testclient import TestClient from fastapi import FastAPI import time # ============================================== # Request-ID Middleware Tests # ============================================== class TestRequestIDMiddleware: """Tests for RequestIDMiddleware.""" def test_generates_request_id_when_not_provided(self): """Should generate a UUID when no X-Request-ID header is provided.""" from middleware.request_id import RequestIDMiddleware, get_request_id app = FastAPI() app.add_middleware(RequestIDMiddleware) @app.get("/test") async def test_endpoint(): return {"request_id": get_request_id()} client = TestClient(app) response = client.get("/test") assert response.status_code == 200 assert "X-Request-ID" in response.headers assert len(response.headers["X-Request-ID"]) == 36 # UUID format def test_propagates_existing_request_id(self): """Should propagate existing X-Request-ID header.""" from middleware.request_id import RequestIDMiddleware app = FastAPI() app.add_middleware(RequestIDMiddleware) @app.get("/test") async def test_endpoint(): return {"status": "ok"} client = TestClient(app) custom_id = "custom-request-id-12345" response = client.get("/test", headers={"X-Request-ID": custom_id}) assert response.status_code == 200 assert response.headers["X-Request-ID"] == custom_id def test_propagates_correlation_id(self): """Should propagate X-Correlation-ID header.""" from middleware.request_id import RequestIDMiddleware app = FastAPI() app.add_middleware(RequestIDMiddleware) @app.get("/test") async def test_endpoint(): return {"status": "ok"} client = TestClient(app) custom_id = "correlation-id-12345" response = client.get("/test", headers={"X-Correlation-ID": custom_id}) assert response.status_code == 200 assert response.headers["X-Request-ID"] == custom_id assert response.headers["X-Correlation-ID"] == custom_id # ============================================== # Security Headers Middleware Tests # ============================================== class TestSecurityHeadersMiddleware: """Tests for SecurityHeadersMiddleware.""" def test_adds_security_headers(self): """Should add security headers to all responses.""" from middleware.security_headers import SecurityHeadersMiddleware app = FastAPI() app.add_middleware(SecurityHeadersMiddleware, development_mode=False) @app.get("/test") async def test_endpoint(): return {"status": "ok"} client = TestClient(app) response = client.get("/test") assert response.status_code == 200 assert response.headers["X-Content-Type-Options"] == "nosniff" assert response.headers["X-Frame-Options"] == "DENY" assert response.headers["X-XSS-Protection"] == "1; mode=block" assert "Referrer-Policy" in response.headers def test_hsts_in_production(self): """Should add HSTS header in production mode.""" from middleware.security_headers import SecurityHeadersMiddleware app = FastAPI() app.add_middleware(SecurityHeadersMiddleware, development_mode=False, hsts_enabled=True) @app.get("/test") async def test_endpoint(): return {"status": "ok"} client = TestClient(app) response = client.get("/test") assert response.status_code == 200 assert "Strict-Transport-Security" in response.headers def test_no_hsts_in_development(self): """Should not add HSTS header in development mode.""" from middleware.security_headers import SecurityHeadersMiddleware app = FastAPI() app.add_middleware(SecurityHeadersMiddleware, development_mode=True) @app.get("/test") async def test_endpoint(): return {"status": "ok"} client = TestClient(app) response = client.get("/test") assert response.status_code == 200 assert "Strict-Transport-Security" not in response.headers def test_csp_header(self): """Should add CSP header when enabled.""" from middleware.security_headers import SecurityHeadersMiddleware app = FastAPI() app.add_middleware( SecurityHeadersMiddleware, csp_enabled=True, csp_policy="default-src 'self'" ) @app.get("/test") async def test_endpoint(): return {"status": "ok"} client = TestClient(app) response = client.get("/test") assert response.status_code == 200 assert response.headers["Content-Security-Policy"] == "default-src 'self'" def test_excludes_health_endpoint(self): """Should not add security headers to excluded paths.""" from middleware.security_headers import SecurityHeadersMiddleware, SecurityHeadersConfig config = SecurityHeadersConfig(excluded_paths=["/health"]) app = FastAPI() app.add_middleware(SecurityHeadersMiddleware, config=config) @app.get("/health") async def health(): return {"status": "healthy"} client = TestClient(app) response = client.get("/health") assert response.status_code == 200 # Security headers should not be present assert "Content-Security-Policy" not in response.headers # ============================================== # PII Redactor Tests # ============================================== class TestPIIRedactor: """Tests for PII redaction.""" def test_redacts_email(self): """Should redact email addresses.""" from middleware.pii_redactor import redact_pii text = "User test@example.com logged in" result = redact_pii(text) assert "test@example.com" not in result assert "[EMAIL_REDACTED]" in result def test_redacts_ip_v4(self): """Should redact IPv4 addresses.""" from middleware.pii_redactor import redact_pii text = "Request from 192.168.1.100" result = redact_pii(text) assert "192.168.1.100" not in result assert "[IP_REDACTED]" in result def test_redacts_german_phone(self): """Should redact German phone numbers.""" from middleware.pii_redactor import redact_pii text = "Call +49 30 12345678" result = redact_pii(text) assert "+49 30 12345678" not in result assert "[PHONE_REDACTED]" in result def test_redacts_multiple_pii(self): """Should redact multiple PII types in same text.""" from middleware.pii_redactor import redact_pii text = "User test@example.com from 10.0.0.1" result = redact_pii(text) assert "test@example.com" not in result assert "10.0.0.1" not in result assert "[EMAIL_REDACTED]" in result assert "[IP_REDACTED]" in result def test_preserves_non_pii_text(self): """Should preserve text that is not PII.""" from middleware.pii_redactor import redact_pii text = "User logged in successfully" result = redact_pii(text) assert result == text def test_contains_pii_detection(self): """Should detect if text contains PII.""" from middleware.pii_redactor import PIIRedactor redactor = PIIRedactor() assert redactor.contains_pii("test@example.com") assert redactor.contains_pii("192.168.1.1") assert not redactor.contains_pii("Hello World") def test_find_pii_locations(self): """Should find PII and return locations.""" from middleware.pii_redactor import PIIRedactor redactor = PIIRedactor() text = "Email: test@example.com, IP: 10.0.0.1" findings = redactor.find_pii(text) assert len(findings) == 2 assert any(f["type"] == "email" for f in findings) assert any(f["type"] == "ip_v4" for f in findings) # ============================================== # Input Gate Middleware Tests # ============================================== class TestInputGateMiddleware: """Tests for InputGateMiddleware.""" def test_allows_valid_json_request(self): """Should allow valid JSON request within size limit.""" from middleware.input_gate import InputGateMiddleware app = FastAPI() app.add_middleware(InputGateMiddleware, max_body_size=1024) @app.post("/test") async def test_endpoint(data: dict): return {"received": True} client = TestClient(app) response = client.post( "/test", json={"key": "value"}, headers={"Content-Type": "application/json"} ) assert response.status_code == 200 def test_rejects_invalid_content_type(self): """Should reject request with invalid content type.""" from middleware.input_gate import InputGateMiddleware, InputGateConfig config = InputGateConfig( allowed_content_types={"application/json"}, strict_content_type=True ) app = FastAPI() app.add_middleware(InputGateMiddleware, config=config) @app.post("/test") async def test_endpoint(): return {"status": "ok"} client = TestClient(app) response = client.post( "/test", content="data", headers={"Content-Type": "text/xml"} ) assert response.status_code == 415 # Unsupported Media Type def test_allows_get_requests_without_body(self): """Should allow GET requests without validation.""" from middleware.input_gate import InputGateMiddleware app = FastAPI() app.add_middleware(InputGateMiddleware) @app.get("/test") async def test_endpoint(): return {"status": "ok"} client = TestClient(app) response = client.get("/test") assert response.status_code == 200 def test_excludes_health_endpoint(self): """Should not validate excluded paths.""" from middleware.input_gate import InputGateMiddleware, InputGateConfig config = InputGateConfig(excluded_paths=["/health"]) app = FastAPI() app.add_middleware(InputGateMiddleware, config=config) @app.get("/health") async def health(): return {"status": "healthy"} client = TestClient(app) response = client.get("/health") assert response.status_code == 200 # ============================================== # File Upload Validation Tests # ============================================== class TestFileUploadValidation: """Tests for file upload validation.""" def test_validates_file_size(self): """Should reject files exceeding max size.""" from middleware.input_gate import validate_file_upload, InputGateConfig config = InputGateConfig(max_file_size=1024) # 1KB valid, error = validate_file_upload( filename="test.pdf", content_type="application/pdf", size=512, # 512 bytes config=config ) assert valid valid, error = validate_file_upload( filename="test.pdf", content_type="application/pdf", size=2048, # 2KB - exceeds limit config=config ) assert not valid assert "size" in error.lower() def test_rejects_blocked_extensions(self): """Should reject files with blocked extensions.""" from middleware.input_gate import validate_file_upload valid, error = validate_file_upload( filename="malware.exe", content_type="application/octet-stream", size=100 ) assert not valid assert "extension" in error.lower() valid, error = validate_file_upload( filename="script.bat", content_type="application/octet-stream", size=100 ) assert not valid def test_allows_safe_file_types(self): """Should allow safe file types.""" from middleware.input_gate import validate_file_upload valid, error = validate_file_upload( filename="document.pdf", content_type="application/pdf", size=1024 ) assert valid valid, error = validate_file_upload( filename="image.png", content_type="image/png", size=1024 ) assert valid # ============================================== # Rate Limiter Tests # ============================================== class TestRateLimiterMiddleware: """Tests for RateLimiterMiddleware.""" def test_allows_requests_under_limit(self): """Should allow requests under the rate limit.""" from middleware.rate_limiter import RateLimiterMiddleware, RateLimitConfig config = RateLimitConfig( ip_limit=100, window_size=60, fallback_enabled=True, skip_internal_network=True, ) app = FastAPI() app.add_middleware(RateLimiterMiddleware, config=config) @app.get("/test") async def test_endpoint(): return {"status": "ok"} client = TestClient(app) # Make a few requests - should all succeed for _ in range(5): response = client.get("/test") assert response.status_code == 200 def test_rate_limit_headers(self): """Should include rate limit headers in response.""" from middleware.rate_limiter import RateLimiterMiddleware, RateLimitConfig # Use a client IP that won't be skipped config = RateLimitConfig( ip_limit=100, window_size=60, fallback_enabled=True, skip_internal_network=False, # Don't skip internal IPs for this test ) app = FastAPI() app.add_middleware(RateLimiterMiddleware, config=config) @app.get("/test") async def test_endpoint(): return {"status": "ok"} client = TestClient(app) response = client.get("/test") assert response.status_code == 200 assert "X-RateLimit-Limit" in response.headers assert "X-RateLimit-Remaining" in response.headers assert "X-RateLimit-Reset" in response.headers def test_skips_whitelisted_ips(self): """Should skip rate limiting for whitelisted IPs.""" from middleware.rate_limiter import RateLimiterMiddleware, RateLimitConfig config = RateLimitConfig( ip_limit=1, # Very low limit window_size=60, ip_whitelist={"127.0.0.1", "::1", "10.0.0.1"}, fallback_enabled=True, ) app = FastAPI() app.add_middleware(RateLimiterMiddleware, config=config) @app.get("/test") async def test_endpoint(): return {"status": "ok"} client = TestClient(app) # Multiple requests should succeed because the IP is whitelisted # Use X-Forwarded-For header to simulate whitelisted IP for _ in range(10): response = client.get("/test", headers={"X-Forwarded-For": "10.0.0.1"}) assert response.status_code == 200 def test_excludes_health_endpoint(self): """Should not rate limit excluded paths.""" from middleware.rate_limiter import RateLimiterMiddleware, RateLimitConfig config = RateLimitConfig( ip_limit=1, window_size=60, excluded_paths=["/health"], fallback_enabled=True, ) app = FastAPI() app.add_middleware(RateLimiterMiddleware, config=config) @app.get("/health") async def health(): return {"status": "healthy"} client = TestClient(app) # Multiple requests to health should succeed for _ in range(10): response = client.get("/health") assert response.status_code == 200 # ============================================== # Middleware Stack Integration Test # ============================================== class TestMiddlewareStackIntegration: """Tests for the complete middleware stack.""" def test_full_middleware_stack(self): """Test all middlewares work together.""" from middleware.request_id import RequestIDMiddleware from middleware.security_headers import SecurityHeadersMiddleware from middleware.input_gate import InputGateMiddleware app = FastAPI() # Add middlewares in order (last added = first executed) app.add_middleware(InputGateMiddleware) app.add_middleware(SecurityHeadersMiddleware, development_mode=True) app.add_middleware(RequestIDMiddleware) @app.get("/test") async def test_endpoint(): return {"status": "ok"} @app.post("/data") async def data_endpoint(data: dict): return {"received": data} client = TestClient(app) # Test GET request response = client.get("/test") assert response.status_code == 200 assert "X-Request-ID" in response.headers assert "X-Content-Type-Options" in response.headers # Test POST request with JSON response = client.post( "/data", json={"key": "value"}, headers={"Content-Type": "application/json"} ) assert response.status_code == 200 assert response.json()["received"] == {"key": "value"}