""" Integration Tests for Middleware Components Tests the middleware stack: - Request-ID generation and propagation - Security headers - Rate limiting - PII redaction - Input validation """ import pytest from unittest.mock import MagicMock, AsyncMock, patch from starlette.requests import Request from starlette.responses import Response, JSONResponse from starlette.testclient import TestClient from fastapi import FastAPI import time # ============================================== # Request-ID Middleware Tests # ============================================== class TestRequestIDMiddleware: """Tests for RequestIDMiddleware.""" def test_generates_request_id_when_not_provided(self): """Should generate a UUID when no X-Request-ID header is provided.""" from middleware.request_id import RequestIDMiddleware, get_request_id app = FastAPI() app.add_middleware(RequestIDMiddleware) @app.get("/test") async def test_endpoint(): return {"request_id": get_request_id()} client = TestClient(app) response = client.get("/test") assert response.status_code == 200 assert "X-Request-ID" in response.headers assert len(response.headers["X-Request-ID"]) == 36 # UUID format def test_propagates_existing_request_id(self): """Should propagate existing X-Request-ID header.""" from middleware.request_id import RequestIDMiddleware app = FastAPI() app.add_middleware(RequestIDMiddleware) @app.get("/test") async def test_endpoint(): return {"status": "ok"} client = TestClient(app) custom_id = "custom-request-id-12345" response = client.get("/test", headers={"X-Request-ID": custom_id}) assert response.status_code == 200 assert response.headers["X-Request-ID"] == custom_id def test_propagates_correlation_id(self): """Should propagate X-Correlation-ID header.""" from middleware.request_id import RequestIDMiddleware app = FastAPI() app.add_middleware(RequestIDMiddleware) @app.get("/test") async def test_endpoint(): return {"status": "ok"} client = TestClient(app) custom_id = "correlation-id-12345" response = client.get("/test", headers={"X-Correlation-ID": custom_id}) assert response.status_code == 200 assert response.headers["X-Request-ID"] == custom_id assert response.headers["X-Correlation-ID"] == custom_id # ============================================== # Security Headers Middleware Tests # ============================================== class TestSecurityHeadersMiddleware: """Tests for SecurityHeadersMiddleware.""" def test_adds_security_headers(self): """Should add security headers to all responses.""" from middleware.security_headers import SecurityHeadersMiddleware app = FastAPI() app.add_middleware(SecurityHeadersMiddleware, development_mode=False) @app.get("/test") async def test_endpoint(): return {"status": "ok"} client = TestClient(app) response = client.get("/test") assert response.status_code == 200 assert response.headers["X-Content-Type-Options"] == "nosniff" assert response.headers["X-Frame-Options"] == "DENY" assert response.headers["X-XSS-Protection"] == "1; mode=block" assert "Referrer-Policy" in response.headers def test_hsts_in_production(self): """Should add HSTS header in production mode.""" from middleware.security_headers import SecurityHeadersMiddleware app = FastAPI() app.add_middleware(SecurityHeadersMiddleware, development_mode=False, hsts_enabled=True) @app.get("/test") async def test_endpoint(): return {"status": "ok"} client = TestClient(app) response = client.get("/test") assert response.status_code == 200 assert "Strict-Transport-Security" in response.headers def test_no_hsts_in_development(self): """Should not add HSTS header in development mode.""" from middleware.security_headers import SecurityHeadersMiddleware app = FastAPI() app.add_middleware(SecurityHeadersMiddleware, development_mode=True) @app.get("/test") async def test_endpoint(): return {"status": "ok"} client = TestClient(app) response = client.get("/test") assert response.status_code == 200 assert "Strict-Transport-Security" not in response.headers def test_csp_header(self): """Should add CSP header when enabled.""" from middleware.security_headers import SecurityHeadersMiddleware app = FastAPI() app.add_middleware( SecurityHeadersMiddleware, csp_enabled=True, csp_policy="default-src 'self'" ) @app.get("/test") async def test_endpoint(): return {"status": "ok"} client = TestClient(app) response = client.get("/test") assert response.status_code == 200 assert response.headers["Content-Security-Policy"] == "default-src 'self'" def test_excludes_health_endpoint(self): """Should not add security headers to excluded paths.""" from middleware.security_headers import SecurityHeadersMiddleware, SecurityHeadersConfig config = SecurityHeadersConfig(excluded_paths=["/health"]) app = FastAPI() app.add_middleware(SecurityHeadersMiddleware, config=config) @app.get("/health") async def health(): return {"status": "healthy"} client = TestClient(app) response = client.get("/health") assert response.status_code == 200 # Security headers should not be present assert "Content-Security-Policy" not in response.headers # ============================================== # PII Redactor Tests # ============================================== class TestPIIRedactor: """Tests for PII redaction.""" def test_redacts_email(self): """Should redact email addresses.""" from middleware.pii_redactor import redact_pii text = "User test@example.com logged in" result = redact_pii(text) assert "test@example.com" not in result assert "[EMAIL_REDACTED]" in result def test_redacts_ip_v4(self): """Should redact IPv4 addresses.""" from middleware.pii_redactor import redact_pii text = "Request from 192.168.1.100" result = redact_pii(text) assert "192.168.1.100" not in result assert "[IP_REDACTED]" in result def test_redacts_german_phone(self): """Should redact German phone numbers.""" from middleware.pii_redactor import redact_pii text = "Call +49 30 12345678" result = redact_pii(text) assert "+49 30 12345678" not in result assert "[PHONE_REDACTED]" in result def test_redacts_multiple_pii(self): """Should redact multiple PII types in same text.""" from middleware.pii_redactor import redact_pii text = "User test@example.com from 10.0.0.1" result = redact_pii(text) assert "test@example.com" not in result assert "10.0.0.1" not in result assert "[EMAIL_REDACTED]" in result assert "[IP_REDACTED]" in result def test_preserves_non_pii_text(self): """Should preserve text that is not PII.""" from middleware.pii_redactor import redact_pii text = "User logged in successfully" result = redact_pii(text) assert result == text def test_contains_pii_detection(self): """Should detect if text contains PII.""" from middleware.pii_redactor import PIIRedactor redactor = PIIRedactor() assert redactor.contains_pii("test@example.com") assert redactor.contains_pii("192.168.1.1") assert not redactor.contains_pii("Hello World") def test_find_pii_locations(self): """Should find PII and return locations.""" from middleware.pii_redactor import PIIRedactor redactor = PIIRedactor() text = "Email: test@example.com, IP: 10.0.0.1" findings = redactor.find_pii(text) assert len(findings) == 2 assert any(f["type"] == "email" for f in findings) assert any(f["type"] == "ip_v4" for f in findings) # ============================================== # Input Gate Middleware Tests # ============================================== class TestInputGateMiddleware: """Tests for InputGateMiddleware.""" def test_allows_valid_json_request(self): """Should allow valid JSON request within size limit.""" from middleware.input_gate import InputGateMiddleware app = FastAPI() app.add_middleware(InputGateMiddleware, max_body_size=1024) @app.post("/test") async def test_endpoint(data: dict): return {"received": True} client = TestClient(app) response = client.post( "/test", json={"key": "value"}, headers={"Content-Type": "application/json"} ) assert response.status_code == 200 def test_rejects_invalid_content_type(self): """Should reject request with invalid content type.""" from middleware.input_gate import InputGateMiddleware, InputGateConfig config = InputGateConfig( allowed_content_types={"application/json"}, strict_content_type=True ) app = FastAPI() app.add_middleware(InputGateMiddleware, config=config) @app.post("/test") async def test_endpoint(): return {"status": "ok"} client = TestClient(app) response = client.post( "/test", content="data", headers={"Content-Type": "text/xml"} ) assert response.status_code == 415 # Unsupported Media Type def test_allows_get_requests_without_body(self): """Should allow GET requests without validation.""" from middleware.input_gate import InputGateMiddleware app = FastAPI() app.add_middleware(InputGateMiddleware) @app.get("/test") async def test_endpoint(): return {"status": "ok"} client = TestClient(app) response = client.get("/test") assert response.status_code == 200 def test_excludes_health_endpoint(self): """Should not validate excluded paths.""" from middleware.input_gate import InputGateMiddleware, InputGateConfig config = InputGateConfig(excluded_paths=["/health"]) app = FastAPI() app.add_middleware(InputGateMiddleware, config=config) @app.get("/health") async def health(): return {"status": "healthy"} client = TestClient(app) response = client.get("/health") assert response.status_code == 200 # ============================================== # File Upload Validation Tests # ============================================== class TestFileUploadValidation: """Tests for file upload validation.""" def test_validates_file_size(self): """Should reject files exceeding max size.""" from middleware.input_gate import validate_file_upload, InputGateConfig config = InputGateConfig(max_file_size=1024) # 1KB valid, error = validate_file_upload( filename="test.pdf", content_type="application/pdf", size=512, # 512 bytes config=config ) assert valid valid, error = validate_file_upload( filename="test.pdf", content_type="application/pdf", size=2048, # 2KB - exceeds limit config=config ) assert not valid assert "size" in error.lower() def test_rejects_blocked_extensions(self): """Should reject files with blocked extensions.""" from middleware.input_gate import validate_file_upload valid, error = validate_file_upload( filename="malware.exe", content_type="application/octet-stream", size=100 ) assert not valid assert "extension" in error.lower() valid, error = validate_file_upload( filename="script.bat", content_type="application/octet-stream", size=100 ) assert not valid def test_allows_safe_file_types(self): """Should allow safe file types.""" from middleware.input_gate import validate_file_upload valid, error = validate_file_upload( filename="document.pdf", content_type="application/pdf", size=1024 ) assert valid valid, error = validate_file_upload( filename="image.png", content_type="image/png", size=1024 ) assert valid # ============================================== # Rate Limiter Tests # ============================================== class TestRateLimiterMiddleware: """Tests for RateLimiterMiddleware.""" def test_allows_requests_under_limit(self): """Should allow requests under the rate limit.""" from middleware.rate_limiter import RateLimiterMiddleware, RateLimitConfig config = RateLimitConfig( ip_limit=100, window_size=60, fallback_enabled=True, skip_internal_network=True, ) app = FastAPI() app.add_middleware(RateLimiterMiddleware, config=config) @app.get("/test") async def test_endpoint(): return {"status": "ok"} client = TestClient(app) # Make a few requests - should all succeed for _ in range(5): response = client.get("/test") assert response.status_code == 200 def test_rate_limit_headers(self): """Should include rate limit headers in response.""" from middleware.rate_limiter import RateLimiterMiddleware, RateLimitConfig # Use a client IP that won't be skipped config = RateLimitConfig( ip_limit=100, window_size=60, fallback_enabled=True, skip_internal_network=False, # Don't skip internal IPs for this test ) app = FastAPI() app.add_middleware(RateLimiterMiddleware, config=config) @app.get("/test") async def test_endpoint(): return {"status": "ok"} client = TestClient(app) response = client.get("/test") assert response.status_code == 200 assert "X-RateLimit-Limit" in response.headers assert "X-RateLimit-Remaining" in response.headers assert "X-RateLimit-Reset" in response.headers def test_skips_whitelisted_ips(self): """Should skip rate limiting for whitelisted IPs.""" from middleware.rate_limiter import RateLimiterMiddleware, RateLimitConfig config = RateLimitConfig( ip_limit=1, # Very low limit window_size=60, ip_whitelist={"127.0.0.1", "::1", "10.0.0.1"}, fallback_enabled=True, ) app = FastAPI() app.add_middleware(RateLimiterMiddleware, config=config) @app.get("/test") async def test_endpoint(): return {"status": "ok"} client = TestClient(app) # Multiple requests should succeed because the IP is whitelisted # Use X-Forwarded-For header to simulate whitelisted IP for _ in range(10): response = client.get("/test", headers={"X-Forwarded-For": "10.0.0.1"}) assert response.status_code == 200 def test_excludes_health_endpoint(self): """Should not rate limit excluded paths.""" from middleware.rate_limiter import RateLimiterMiddleware, RateLimitConfig config = RateLimitConfig( ip_limit=1, window_size=60, excluded_paths=["/health"], fallback_enabled=True, ) app = FastAPI() app.add_middleware(RateLimiterMiddleware, config=config) @app.get("/health") async def health(): return {"status": "healthy"} client = TestClient(app) # Multiple requests to health should succeed for _ in range(10): response = client.get("/health") assert response.status_code == 200 # ============================================== # Middleware Stack Integration Test # ============================================== class TestMiddlewareStackIntegration: """Tests for the complete middleware stack.""" def test_full_middleware_stack(self): """Test all middlewares work together.""" from middleware.request_id import RequestIDMiddleware from middleware.security_headers import SecurityHeadersMiddleware from middleware.input_gate import InputGateMiddleware app = FastAPI() # Add middlewares in order (last added = first executed) app.add_middleware(InputGateMiddleware) app.add_middleware(SecurityHeadersMiddleware, development_mode=True) app.add_middleware(RequestIDMiddleware) @app.get("/test") async def test_endpoint(): return {"status": "ok"} @app.post("/data") async def data_endpoint(data: dict): return {"received": data} client = TestClient(app) # Test GET request response = client.get("/test") assert response.status_code == 200 assert "X-Request-ID" in response.headers assert "X-Content-Type-Options" in response.headers # Test POST request with JSON response = client.post( "/data", json={"key": "value"}, headers={"Content-Type": "application/json"} ) assert response.status_code == 200 assert response.json()["received"] == {"key": "value"} # ============================================== # SDK Protection Middleware Tests # ============================================== class TestSDKProtectionMiddleware: """Tests for SDKProtectionMiddleware.""" def _create_app(self, **config_overrides): """Helper to create test app with SDK protection.""" from middleware.sdk_protection import SDKProtectionMiddleware, SDKProtectionConfig config_kwargs = { "fallback_enabled": True, "watermark_secret": "test-secret", } config_kwargs.update(config_overrides) config = SDKProtectionConfig(**config_kwargs) app = FastAPI() app.add_middleware(SDKProtectionMiddleware, config=config) @app.get("/api/v1/tom/access-control") async def tom_access_control(): return {"data": "access-control"} @app.get("/api/v1/tom/encryption") async def tom_encryption(): return {"data": "encryption"} @app.get("/api/v1/dsfa/threshold") async def dsfa_threshold(): return {"data": "threshold"} @app.get("/api/v1/dsfa/necessity") async def dsfa_necessity(): return {"data": "necessity"} @app.get("/api/v1/vvt/processing") async def vvt_processing(): return {"data": "processing"} @app.get("/api/v1/vvt/purposes") async def vvt_purposes(): return {"data": "purposes"} @app.get("/api/v1/vvt/categories") async def vvt_categories(): return {"data": "categories"} @app.get("/api/v1/vvt/recipients") async def vvt_recipients(): return {"data": "recipients"} @app.get("/api/v1/controls/list") async def controls_list(): return {"data": "controls"} @app.get("/api/v1/assessment/run") async def assessment_run(): return {"data": "assessment"} @app.get("/health") async def health(): return {"status": "healthy"} @app.get("/api/public") async def public(): return {"data": "public"} return app def test_allows_normal_request(self): """Should allow normal requests under all limits.""" app = self._create_app() client = TestClient(app) response = client.get( "/api/v1/tom/access-control", headers={"X-API-Key": "test-user-key-123"}, ) assert response.status_code == 200 assert response.json() == {"data": "access-control"} def test_quota_headers_present(self): """Should include quota headers in response.""" app = self._create_app() client = TestClient(app) response = client.get( "/api/v1/tom/access-control", headers={"X-API-Key": "test-user-key-456"}, ) assert response.status_code == 200 assert "X-SDK-Quota-Remaining-Minute" in response.headers assert "X-SDK-Quota-Remaining-Hour" in response.headers assert "X-SDK-Throttle-Level" in response.headers def test_blocks_after_quota_exceeded(self): """Should return 429 when minute quota is exceeded.""" from middleware.sdk_protection import SDKProtectionConfig, QuotaTier tiers = { "free": QuotaTier("free", 3, 500, 3000, 50000), # Very low minute limit } app = self._create_app(tiers=tiers) client = TestClient(app) api_key = "quota-test-user" headers = {"X-API-Key": api_key} # Make requests up to the limit for i in range(3): response = client.get("/api/v1/tom/access-control", headers=headers) assert response.status_code == 200, f"Request {i+1} should succeed" # Next request should be blocked response = client.get("/api/v1/tom/access-control", headers=headers) assert response.status_code == 429 assert response.json()["error"] == "sdk_quota_exceeded" def test_diversity_tracking_increments_score(self): """Score should increase when accessing many different categories.""" from middleware.sdk_protection import SDKProtectionConfig app = self._create_app(diversity_threshold=3) # Low threshold for test client = TestClient(app) api_key = "diversity-test-user" headers = {"X-API-Key": api_key} # Access many different categories endpoints = [ "/api/v1/tom/access-control", "/api/v1/tom/encryption", "/api/v1/dsfa/threshold", "/api/v1/dsfa/necessity", "/api/v1/vvt/processing", "/api/v1/vvt/purposes", ] for endpoint in endpoints: response = client.get(endpoint, headers=headers) assert response.status_code in (200, 429) # After exceeding diversity, throttle level should increase response = client.get("/api/v1/vvt/categories", headers=headers) if response.status_code == 200: level = int(response.headers.get("X-SDK-Throttle-Level", "0")) assert level >= 0 # Score increased but may not hit threshold yet def test_burst_detection(self): """Score should increase for rapid same-category requests.""" from middleware.sdk_protection import SDKProtectionConfig app = self._create_app(burst_threshold=3) # Low threshold for test client = TestClient(app) api_key = "burst-test-user" headers = {"X-API-Key": api_key} # Burst access to same endpoint for _ in range(5): response = client.get("/api/v1/tom/access-control", headers=headers) if response.status_code == 429: break # After burst, throttle level should have increased response = client.get("/api/v1/tom/encryption", headers=headers) if response.status_code == 200: level = int(response.headers.get("X-SDK-Throttle-Level", "0")) assert level >= 0 # Score increased def test_sequential_enumeration_detection(self): """Score should increase for alphabetically sorted access patterns.""" from middleware.sdk_protection import ( SDKProtectionMiddleware, SDKProtectionConfig, InMemorySDKProtection, ) config = SDKProtectionConfig( sequential_min_entries=5, sequential_sorted_ratio=0.6, ) mw = SDKProtectionMiddleware.__new__(SDKProtectionMiddleware) mw.config = config # Sorted sequence should be detected sorted_seq = ["a_cat", "b_cat", "c_cat", "d_cat", "e_cat", "f_cat"] assert mw._check_sequential(sorted_seq) is True # Random sequence should not be detected random_seq = ["d_cat", "a_cat", "f_cat", "b_cat", "e_cat", "c_cat"] assert mw._check_sequential(random_seq) is False # Too short sequence should not be detected short_seq = ["a_cat", "b_cat"] assert mw._check_sequential(short_seq) is False def test_progressive_throttling_level_1(self): """Throttle level 1 should be set at score >= 30.""" from middleware.sdk_protection import SDKProtectionMiddleware, SDKProtectionConfig config = SDKProtectionConfig() mw = SDKProtectionMiddleware.__new__(SDKProtectionMiddleware) mw.config = config assert mw._get_throttle_level(0) == 0 assert mw._get_throttle_level(29) == 0 assert mw._get_throttle_level(30) == 1 assert mw._get_throttle_level(50) == 1 assert mw._get_throttle_level(59) == 1 def test_progressive_throttling_level_3_blocks(self): """Throttle level 3 should be set at score >= 85.""" from middleware.sdk_protection import SDKProtectionMiddleware, SDKProtectionConfig config = SDKProtectionConfig() mw = SDKProtectionMiddleware.__new__(SDKProtectionMiddleware) mw.config = config assert mw._get_throttle_level(60) == 2 assert mw._get_throttle_level(84) == 2 assert mw._get_throttle_level(85) == 3 assert mw._get_throttle_level(100) == 3 def test_score_decay_over_time(self): """Score should decay over time using decay factor.""" from middleware.sdk_protection import SDKProtectionMiddleware, SDKProtectionConfig config = SDKProtectionConfig( score_decay_factor=0.5, # Aggressive decay for test score_decay_interval=60, # 1 minute intervals ) mw = SDKProtectionMiddleware.__new__(SDKProtectionMiddleware) mw.config = config now = time.time() # Score 100, last decay 2 intervals ago score, last_decay = mw._apply_decay(100.0, now - 120, now) # 2 intervals: 100 * 0.5 * 0.5 = 25 assert score == pytest.approx(25.0) # No decay if within same interval score2, _ = mw._apply_decay(100.0, now - 30, now) assert score2 == pytest.approx(100.0) def test_skips_non_protected_paths(self): """Should not apply protection to non-SDK paths.""" app = self._create_app() client = TestClient(app) # Health endpoint should not be protected response = client.get("/health") assert response.status_code == 200 assert "X-SDK-Throttle-Level" not in response.headers # Non-SDK path should not be protected response = client.get("/api/public") assert response.status_code == 200 assert "X-SDK-Throttle-Level" not in response.headers def test_watermark_header_present(self): """Response should include X-BP-Trace watermark header.""" app = self._create_app() client = TestClient(app) response = client.get( "/api/v1/tom/access-control", headers={"X-API-Key": "watermark-test-user"}, ) assert response.status_code == 200 assert "X-BP-Trace" in response.headers assert len(response.headers["X-BP-Trace"]) == 32 def test_fallback_to_inmemory(self): """Should work with in-memory fallback when Valkey is unavailable.""" from middleware.sdk_protection import SDKProtectionConfig # Point to non-existent Valkey app = self._create_app(valkey_url="redis://nonexistent:9999") client = TestClient(app) response = client.get( "/api/v1/tom/access-control", headers={"X-API-Key": "fallback-test-user"}, ) assert response.status_code == 200 assert response.json() == {"data": "access-control"} def test_no_user_passes_through(self): """Requests without user identification should pass through.""" app = self._create_app() client = TestClient(app) # No API key and no session response = client.get("/api/v1/tom/access-control") assert response.status_code == 200 def test_category_extraction(self): """Category extraction should use longest prefix match.""" from middleware.sdk_protection import _extract_category assert _extract_category("/api/v1/tom/access-control") == "tom_access_control" assert _extract_category("/api/v1/tom/encryption") == "tom_encryption" assert _extract_category("/api/v1/dsfa/threshold") == "dsfa_threshold" assert _extract_category("/api/v1/vvt/processing") == "vvt_processing" assert _extract_category("/api/v1/controls/anything") == "controls_general" assert _extract_category("/api/unknown/path") == "unknown"