Some checks failed
ci/woodpecker/push/integration Pipeline failed
ci/woodpecker/push/main Pipeline failed
CI/CD Pipeline / Docker Build & Push (push) Has been cancelled
CI/CD Pipeline / Linting (push) Has been cancelled
CI/CD Pipeline / Go Tests (push) Has been cancelled
CI/CD Pipeline / Python Tests (push) Has been cancelled
CI/CD Pipeline / Website Tests (push) Has been cancelled
CI/CD Pipeline / Security Scan (push) Has been cancelled
CI/CD Pipeline / Integration Tests (push) Has been cancelled
CI/CD Pipeline / Deploy to Staging (push) Has been cancelled
CI/CD Pipeline / Deploy to Production (push) Has been cancelled
CI/CD Pipeline / CI Summary (push) Has been cancelled
Security Scanning / Python Security Scan (push) Has been cancelled
Security Scanning / Node.js Security Scan (push) Has been cancelled
Security Scanning / Secret Scanning (push) Has been cancelled
Security Scanning / Dependency Vulnerability Scan (push) Has been cancelled
Security Scanning / Go Security Scan (push) Has been cancelled
Security Scanning / Docker Image Security (push) Has been cancelled
Security Scanning / Security Summary (push) Has been cancelled
Tests / Go Tests (push) Has been cancelled
Tests / Python Tests (push) Has been cancelled
Tests / Integration Tests (push) Has been cancelled
Tests / Go Lint (push) Has been cancelled
Tests / Python Lint (push) Has been cancelled
Tests / Security Scan (push) Has been cancelled
Tests / All Checks Passed (push) Has been cancelled
Implements anomaly-score-based middleware to protect SDK/Compliance endpoints from systematic data harvesting. Includes 5 detection mechanisms (diversity, burst, sequential enumeration, unusual hours, multi-tenant), multi-window quota system, progressive throttling, HMAC watermarking, and graceful Valkey fallback. - backend/middleware/sdk_protection.py: Core middleware (~750 lines) - Admin API endpoints for score management and tier configuration - 14 new tests (all passing) - MkDocs documentation with clear explanations - Screen flow and middleware dashboard updates Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
890 lines
29 KiB
Python
890 lines
29 KiB
Python
"""
|
|
Integration Tests for Middleware Components
|
|
|
|
Tests the middleware stack:
|
|
- Request-ID generation and propagation
|
|
- Security headers
|
|
- Rate limiting
|
|
- PII redaction
|
|
- Input validation
|
|
"""
|
|
|
|
import pytest
|
|
from unittest.mock import MagicMock, AsyncMock, patch
|
|
from starlette.requests import Request
|
|
from starlette.responses import Response, JSONResponse
|
|
from starlette.testclient import TestClient
|
|
from fastapi import FastAPI
|
|
import time
|
|
|
|
|
|
# ==============================================
|
|
# Request-ID Middleware Tests
|
|
# ==============================================
|
|
|
|
|
|
class TestRequestIDMiddleware:
    """Tests for RequestIDMiddleware."""

    def test_generates_request_id_when_not_provided(self):
        """Should generate a UUID when no X-Request-ID header is provided.

        The endpoint echoes ``get_request_id()``, so the test also checks that
        the ID visible to the handler is the same one sent back in the
        response header.
        """
        import uuid

        from middleware.request_id import RequestIDMiddleware, get_request_id

        app = FastAPI()
        app.add_middleware(RequestIDMiddleware)

        @app.get("/test")
        async def test_endpoint():
            return {"request_id": get_request_id()}

        client = TestClient(app)
        response = client.get("/test")

        assert response.status_code == 200
        assert "X-Request-ID" in response.headers

        generated = response.headers["X-Request-ID"]
        assert len(generated) == 36  # UUID format
        # A length check alone accepts any 36-char string; parsing raises
        # ValueError if the value is not a well-formed UUID.
        uuid.UUID(generated)
        # The handler-side accessor must agree with the response header;
        # previously the echoed value was returned but never checked.
        assert response.json()["request_id"] == generated

    def test_propagates_existing_request_id(self):
        """Should propagate existing X-Request-ID header."""
        from middleware.request_id import RequestIDMiddleware

        app = FastAPI()
        app.add_middleware(RequestIDMiddleware)

        @app.get("/test")
        async def test_endpoint():
            return {"status": "ok"}

        client = TestClient(app)
        custom_id = "custom-request-id-12345"
        response = client.get("/test", headers={"X-Request-ID": custom_id})

        assert response.status_code == 200
        assert response.headers["X-Request-ID"] == custom_id

    def test_propagates_correlation_id(self):
        """Should propagate X-Correlation-ID header.

        An incoming X-Correlation-ID is expected to be used as the request ID
        and also echoed back under its own header name.
        """
        from middleware.request_id import RequestIDMiddleware

        app = FastAPI()
        app.add_middleware(RequestIDMiddleware)

        @app.get("/test")
        async def test_endpoint():
            return {"status": "ok"}

        client = TestClient(app)
        custom_id = "correlation-id-12345"
        response = client.get("/test", headers={"X-Correlation-ID": custom_id})

        assert response.status_code == 200
        assert response.headers["X-Request-ID"] == custom_id
        assert response.headers["X-Correlation-ID"] == custom_id
|
|
|
|
|
|
# ==============================================
|
|
# Security Headers Middleware Tests
|
|
# ==============================================
|
|
|
|
|
|
class TestSecurityHeadersMiddleware:
    """Tests for SecurityHeadersMiddleware."""

    def test_adds_security_headers(self):
        """Should add security headers to all responses."""
        from middleware.security_headers import SecurityHeadersMiddleware

        app = FastAPI()
        app.add_middleware(SecurityHeadersMiddleware, development_mode=False)

        @app.get("/test")
        async def test_endpoint():
            return {"status": "ok"}

        client = TestClient(app)
        response = client.get("/test")

        assert response.status_code == 200
        assert response.headers["X-Content-Type-Options"] == "nosniff"
        assert response.headers["X-Frame-Options"] == "DENY"
        assert response.headers["X-XSS-Protection"] == "1; mode=block"
        assert "Referrer-Policy" in response.headers

    def test_hsts_in_production(self):
        """Should add HSTS header in production mode."""
        from middleware.security_headers import SecurityHeadersMiddleware

        app = FastAPI()
        app.add_middleware(SecurityHeadersMiddleware, development_mode=False, hsts_enabled=True)

        @app.get("/test")
        async def test_endpoint():
            return {"status": "ok"}

        client = TestClient(app)
        response = client.get("/test")

        assert response.status_code == 200
        assert "Strict-Transport-Security" in response.headers

    def test_no_hsts_in_development(self):
        """Should not add HSTS header in development mode."""
        from middleware.security_headers import SecurityHeadersMiddleware

        app = FastAPI()
        app.add_middleware(SecurityHeadersMiddleware, development_mode=True)

        @app.get("/test")
        async def test_endpoint():
            return {"status": "ok"}

        client = TestClient(app)
        response = client.get("/test")

        assert response.status_code == 200
        assert "Strict-Transport-Security" not in response.headers

    def test_csp_header(self):
        """Should add CSP header when enabled."""
        from middleware.security_headers import SecurityHeadersMiddleware

        app = FastAPI()
        app.add_middleware(
            SecurityHeadersMiddleware,
            csp_enabled=True,
            csp_policy="default-src 'self'",
        )

        @app.get("/test")
        async def test_endpoint():
            return {"status": "ok"}

        client = TestClient(app)
        response = client.get("/test")

        assert response.status_code == 200
        assert response.headers["Content-Security-Policy"] == "default-src 'self'"

    def test_excludes_health_endpoint(self):
        """Should not add security headers to excluded paths.

        CSP appears to be opt-in (test_csp_header enables it explicitly), so
        asserting only that CSP is absent could pass even if exclusion were
        broken. This test also checks headers that are set without extra
        options. A control request to a non-excluded path shows the
        middleware is active on the same app.
        """
        from middleware.security_headers import SecurityHeadersMiddleware, SecurityHeadersConfig

        config = SecurityHeadersConfig(excluded_paths=["/health"])
        app = FastAPI()
        app.add_middleware(SecurityHeadersMiddleware, config=config)

        @app.get("/health")
        async def health():
            return {"status": "healthy"}

        @app.get("/test")
        async def test_endpoint():
            return {"status": "ok"}

        client = TestClient(app)

        response = client.get("/health")
        assert response.status_code == 200
        # Security headers should not be present on the excluded path
        for header in ("Content-Security-Policy", "X-Content-Type-Options", "X-Frame-Options"):
            assert header not in response.headers, f"{header} set on excluded path"

        # Control: a non-excluded path on the same app still gets headers.
        control = client.get("/test")
        assert control.status_code == 200
        assert "X-Content-Type-Options" in control.headers
|
|
|
|
|
|
# ==============================================
|
|
# PII Redactor Tests
|
|
# ==============================================
|
|
|
|
|
|
class TestPIIRedactor:
    """PII redaction: placeholder substitution and detection helpers."""

    def test_redacts_email(self):
        """An email address is swapped for the email placeholder."""
        from middleware.pii_redactor import redact_pii

        redacted = redact_pii("User test@example.com logged in")

        assert "test@example.com" not in redacted
        assert "[EMAIL_REDACTED]" in redacted

    def test_redacts_ip_v4(self):
        """A dotted-quad IPv4 address is swapped for the IP placeholder."""
        from middleware.pii_redactor import redact_pii

        redacted = redact_pii("Request from 192.168.1.100")

        assert "192.168.1.100" not in redacted
        assert "[IP_REDACTED]" in redacted

    def test_redacts_german_phone(self):
        """A +49 phone number is swapped for the phone placeholder."""
        from middleware.pii_redactor import redact_pii

        redacted = redact_pii("Call +49 30 12345678")

        assert "+49 30 12345678" not in redacted
        assert "[PHONE_REDACTED]" in redacted

    def test_redacts_multiple_pii(self):
        """Email and IP in one string are both redacted."""
        from middleware.pii_redactor import redact_pii

        redacted = redact_pii("User test@example.com from 10.0.0.1")

        for raw in ("test@example.com", "10.0.0.1"):
            assert raw not in redacted
        for placeholder in ("[EMAIL_REDACTED]", "[IP_REDACTED]"):
            assert placeholder in redacted

    def test_preserves_non_pii_text(self):
        """Text with nothing to redact comes back unchanged."""
        from middleware.pii_redactor import redact_pii

        clean = "User logged in successfully"

        assert redact_pii(clean) == clean

    def test_contains_pii_detection(self):
        """contains_pii flags emails and IPs but not ordinary words."""
        from middleware.pii_redactor import PIIRedactor

        redactor = PIIRedactor()

        for sample in ("test@example.com", "192.168.1.1"):
            assert redactor.contains_pii(sample)
        assert not redactor.contains_pii("Hello World")

    def test_find_pii_locations(self):
        """find_pii reports one finding per PII item, tagged by type."""
        from middleware.pii_redactor import PIIRedactor

        redactor = PIIRedactor()
        findings = redactor.find_pii("Email: test@example.com, IP: 10.0.0.1")

        assert len(findings) == 2
        found_types = {finding["type"] for finding in findings}
        assert "email" in found_types
        assert "ip_v4" in found_types
|
|
|
|
|
|
# ==============================================
|
|
# Input Gate Middleware Tests
|
|
# ==============================================
|
|
|
|
|
|
class TestInputGateMiddleware:
    """Tests for InputGateMiddleware."""

    def test_allows_valid_json_request(self):
        """Should allow valid JSON request within size limit."""
        from middleware.input_gate import InputGateMiddleware

        app = FastAPI()
        app.add_middleware(InputGateMiddleware, max_body_size=1024)

        @app.post("/test")
        async def test_endpoint(data: dict):
            return {"received": True}

        client = TestClient(app)
        response = client.post(
            "/test",
            json={"key": "value"},
            headers={"Content-Type": "application/json"},
        )

        assert response.status_code == 200

    def test_rejects_invalid_content_type(self):
        """Should reject request with invalid content type."""
        from middleware.input_gate import InputGateMiddleware, InputGateConfig

        config = InputGateConfig(
            allowed_content_types={"application/json"},
            strict_content_type=True,
        )
        app = FastAPI()
        app.add_middleware(InputGateMiddleware, config=config)

        @app.post("/test")
        async def test_endpoint():
            return {"status": "ok"}

        client = TestClient(app)
        response = client.post(
            "/test",
            content="data",
            headers={"Content-Type": "text/xml"},
        )

        assert response.status_code == 415  # Unsupported Media Type

    def test_allows_get_requests_without_body(self):
        """Should allow GET requests without validation."""
        from middleware.input_gate import InputGateMiddleware

        app = FastAPI()
        app.add_middleware(InputGateMiddleware)

        @app.get("/test")
        async def test_endpoint():
            return {"status": "ok"}

        client = TestClient(app)
        response = client.get("/test")

        assert response.status_code == 200

    def test_excludes_health_endpoint(self):
        """Should not validate excluded paths.

        GET requests pass without validation anyway
        (test_allows_get_requests_without_body), so a GET alone cannot tell
        whether the exclusion works. The test also POSTs a content type that
        the same strict config rejects with 415 on a normal path
        (test_rejects_invalid_content_type). On the excluded path it must
        still succeed.
        """
        from middleware.input_gate import InputGateMiddleware, InputGateConfig

        config = InputGateConfig(
            excluded_paths=["/health"],
            allowed_content_types={"application/json"},
            strict_content_type=True,
        )
        app = FastAPI()
        app.add_middleware(InputGateMiddleware, config=config)

        @app.get("/health")
        async def health():
            return {"status": "healthy"}

        @app.post("/health")
        async def health_post():
            return {"status": "healthy"}

        client = TestClient(app)

        response = client.get("/health")
        assert response.status_code == 200

        response = client.post(
            "/health",
            content="data",
            headers={"Content-Type": "text/xml"},
        )
        assert response.status_code == 200
|
|
|
|
|
|
# ==============================================
|
|
# File Upload Validation Tests
|
|
# ==============================================
|
|
|
|
|
|
class TestFileUploadValidation:
    """validate_file_upload: size limits and extension filtering."""

    def test_validates_file_size(self):
        """A file is accepted below max_file_size and rejected above it."""
        from middleware.input_gate import validate_file_upload, InputGateConfig

        one_kib_limit = InputGateConfig(max_file_size=1024)  # 1KB

        def check(size):
            return validate_file_upload(
                filename="test.pdf",
                content_type="application/pdf",
                size=size,
                config=one_kib_limit,
            )

        accepted, _ = check(512)  # 512 bytes
        assert accepted

        accepted, message = check(2048)  # 2KB - exceeds limit
        assert not accepted
        assert "size" in message.lower()

    def test_rejects_blocked_extensions(self):
        """Executable/script extensions are refused under the default config."""
        from middleware.input_gate import validate_file_upload

        exe_ok, exe_error = validate_file_upload(
            filename="malware.exe",
            content_type="application/octet-stream",
            size=100,
        )
        assert not exe_ok
        assert "extension" in exe_error.lower()

        bat_ok, _ = validate_file_upload(
            filename="script.bat",
            content_type="application/octet-stream",
            size=100,
        )
        assert not bat_ok

    def test_allows_safe_file_types(self):
        """Ordinary document and image uploads pass the default checks."""
        from middleware.input_gate import validate_file_upload

        safe_uploads = (
            ("document.pdf", "application/pdf"),
            ("image.png", "image/png"),
        )
        for name, mime in safe_uploads:
            accepted, _ = validate_file_upload(filename=name, content_type=mime, size=1024)
            assert accepted
|
|
|
|
|
|
# ==============================================
|
|
# Rate Limiter Tests
|
|
# ==============================================
|
|
|
|
|
|
class TestRateLimiterMiddleware:
    """Tests for RateLimiterMiddleware."""

    def test_allows_requests_under_limit(self):
        """Should allow requests under the rate limit."""
        from middleware.rate_limiter import RateLimiterMiddleware, RateLimitConfig

        config = RateLimitConfig(
            ip_limit=100,
            window_size=60,
            fallback_enabled=True,
            skip_internal_network=True,
        )
        app = FastAPI()
        app.add_middleware(RateLimiterMiddleware, config=config)

        @app.get("/test")
        async def test_endpoint():
            return {"status": "ok"}

        client = TestClient(app)

        # Make a few requests - should all succeed
        for _ in range(5):
            response = client.get("/test")
            assert response.status_code == 200

    def test_rate_limit_headers(self):
        """Should include rate limit headers in response."""
        from middleware.rate_limiter import RateLimiterMiddleware, RateLimitConfig

        # Use a client IP that won't be skipped
        config = RateLimitConfig(
            ip_limit=100,
            window_size=60,
            fallback_enabled=True,
            skip_internal_network=False,  # Don't skip internal IPs for this test
        )
        app = FastAPI()
        app.add_middleware(RateLimiterMiddleware, config=config)

        @app.get("/test")
        async def test_endpoint():
            return {"status": "ok"}

        client = TestClient(app)
        response = client.get("/test")

        assert response.status_code == 200
        assert "X-RateLimit-Limit" in response.headers
        assert "X-RateLimit-Remaining" in response.headers
        assert "X-RateLimit-Reset" in response.headers

    def test_skips_whitelisted_ips(self):
        """Should skip rate limiting for whitelisted IPs.

        10.0.0.1 is a private address. If internal-network skipping were left
        at its default (not visible here; may be enabled), the requests could
        succeed as "internal" traffic without consulting the whitelist. It is
        disabled so that only the whitelist can explain ten successes against
        ip_limit=1.
        """
        from middleware.rate_limiter import RateLimiterMiddleware, RateLimitConfig

        config = RateLimitConfig(
            ip_limit=1,  # Very low limit
            window_size=60,
            ip_whitelist={"127.0.0.1", "::1", "10.0.0.1"},
            fallback_enabled=True,
            skip_internal_network=False,
        )
        app = FastAPI()
        app.add_middleware(RateLimiterMiddleware, config=config)

        @app.get("/test")
        async def test_endpoint():
            return {"status": "ok"}

        client = TestClient(app)

        # Multiple requests should succeed because the IP is whitelisted.
        # Use X-Forwarded-For header to simulate whitelisted IP
        for _ in range(10):
            response = client.get("/test", headers={"X-Forwarded-For": "10.0.0.1"})
            assert response.status_code == 200

    def test_excludes_health_endpoint(self):
        """Should not rate limit excluded paths.

        Internal-network skipping is disabled so that path exclusion, not a
        client-IP shortcut, is what lets ten requests through with
        ip_limit=1.
        """
        from middleware.rate_limiter import RateLimiterMiddleware, RateLimitConfig

        config = RateLimitConfig(
            ip_limit=1,
            window_size=60,
            excluded_paths=["/health"],
            fallback_enabled=True,
            skip_internal_network=False,
        )
        app = FastAPI()
        app.add_middleware(RateLimiterMiddleware, config=config)

        @app.get("/health")
        async def health():
            return {"status": "healthy"}

        client = TestClient(app)

        # Multiple requests to health should succeed
        for _ in range(10):
            response = client.get("/health")
            assert response.status_code == 200
|
|
|
|
|
|
# ==============================================
|
|
# Middleware Stack Integration Test
|
|
# ==============================================
|
|
|
|
|
|
class TestMiddlewareStackIntegration:
    """Exercises several middlewares registered on one application."""

    def test_full_middleware_stack(self):
        """Request ID, security headers and input gate cooperate on GET and POST."""
        from middleware.request_id import RequestIDMiddleware
        from middleware.security_headers import SecurityHeadersMiddleware
        from middleware.input_gate import InputGateMiddleware

        app = FastAPI()

        # Registration order matters: the last middleware added is the
        # outermost one, so RequestIDMiddleware sees the request first.
        stack = (
            (InputGateMiddleware, {}),
            (SecurityHeadersMiddleware, {"development_mode": True}),
            (RequestIDMiddleware, {}),
        )
        for middleware_cls, options in stack:
            app.add_middleware(middleware_cls, **options)

        @app.get("/test")
        async def test_endpoint():
            return {"status": "ok"}

        @app.post("/data")
        async def data_endpoint(data: dict):
            return {"received": data}

        client = TestClient(app)

        # Plain GET: both header-producing middlewares must have run.
        get_response = client.get("/test")
        assert get_response.status_code == 200
        for header in ("X-Request-ID", "X-Content-Type-Options"):
            assert header in get_response.headers

        # JSON POST: the body must survive the input gate untouched.
        payload = {"key": "value"}
        post_response = client.post(
            "/data",
            json=payload,
            headers={"Content-Type": "application/json"},
        )
        assert post_response.status_code == 200
        assert post_response.json()["received"] == payload
|
|
|
|
|
|
# ==============================================
|
|
# SDK Protection Middleware Tests
|
|
# ==============================================
|
|
|
|
|
|
class TestSDKProtectionMiddleware:
    """Tests for SDKProtectionMiddleware."""

    @staticmethod
    def _unique_key(prefix):
        """Return an API key unique to this test run.

        Quota and score state is presumably tracked per X-API-Key. If the
        default config reaches a shared Valkey instance (not visible here;
        confirm), fixed keys would carry counters over between reruns inside
        the same quota window. A random suffix isolates each run.
        """
        import uuid

        return f"{prefix}-{uuid.uuid4().hex}"

    def _create_app(self, **config_overrides):
        """Helper to create test app with SDK protection.

        ``config_overrides`` are passed to SDKProtectionConfig on top of
        ``fallback_enabled=True`` and a fixed watermark secret. The app
        exposes several SDK category routes, plus ``/health`` and
        ``/api/public`` as non-SDK paths.
        """
        from middleware.sdk_protection import SDKProtectionMiddleware, SDKProtectionConfig

        config_kwargs = {
            "fallback_enabled": True,
            "watermark_secret": "test-secret",
        }
        config_kwargs.update(config_overrides)
        config = SDKProtectionConfig(**config_kwargs)

        app = FastAPI()
        app.add_middleware(SDKProtectionMiddleware, config=config)

        @app.get("/api/v1/tom/access-control")
        async def tom_access_control():
            return {"data": "access-control"}

        @app.get("/api/v1/tom/encryption")
        async def tom_encryption():
            return {"data": "encryption"}

        @app.get("/api/v1/dsfa/threshold")
        async def dsfa_threshold():
            return {"data": "threshold"}

        @app.get("/api/v1/dsfa/necessity")
        async def dsfa_necessity():
            return {"data": "necessity"}

        @app.get("/api/v1/vvt/processing")
        async def vvt_processing():
            return {"data": "processing"}

        @app.get("/api/v1/vvt/purposes")
        async def vvt_purposes():
            return {"data": "purposes"}

        @app.get("/api/v1/vvt/categories")
        async def vvt_categories():
            return {"data": "categories"}

        @app.get("/api/v1/vvt/recipients")
        async def vvt_recipients():
            return {"data": "recipients"}

        @app.get("/api/v1/controls/list")
        async def controls_list():
            return {"data": "controls"}

        @app.get("/api/v1/assessment/run")
        async def assessment_run():
            return {"data": "assessment"}

        @app.get("/health")
        async def health():
            return {"status": "healthy"}

        @app.get("/api/public")
        async def public():
            return {"data": "public"}

        return app

    def test_allows_normal_request(self):
        """Should allow normal requests under all limits."""
        app = self._create_app()
        client = TestClient(app)

        response = client.get(
            "/api/v1/tom/access-control",
            headers={"X-API-Key": "test-user-key-123"},
        )
        assert response.status_code == 200
        assert response.json() == {"data": "access-control"}

    def test_quota_headers_present(self):
        """Should include quota headers in response."""
        app = self._create_app()
        client = TestClient(app)

        response = client.get(
            "/api/v1/tom/access-control",
            headers={"X-API-Key": "test-user-key-456"},
        )
        assert response.status_code == 200
        assert "X-SDK-Quota-Remaining-Minute" in response.headers
        assert "X-SDK-Quota-Remaining-Hour" in response.headers
        assert "X-SDK-Throttle-Level" in response.headers

    def test_blocks_after_quota_exceeded(self):
        """Should return 429 when minute quota is exceeded."""
        from middleware.sdk_protection import QuotaTier

        tiers = {
            "free": QuotaTier("free", 3, 500, 3000, 50000),  # Very low minute limit
        }
        app = self._create_app(tiers=tiers)
        client = TestClient(app)

        headers = {"X-API-Key": self._unique_key("quota-test-user")}

        # Make requests up to the limit
        for i in range(3):
            response = client.get("/api/v1/tom/access-control", headers=headers)
            assert response.status_code == 200, f"Request {i+1} should succeed"

        # Next request should be blocked
        response = client.get("/api/v1/tom/access-control", headers=headers)
        assert response.status_code == 429
        assert response.json()["error"] == "sdk_quota_exceeded"

    def test_diversity_tracking_increments_score(self):
        """Score should increase when accessing many different categories.

        The per-detector score weights are not pinned here. The test checks
        a weaker but non-vacuous property instead: after touching many
        categories, the reported throttle level (0-3) is never below the
        level reported on the first request, unless the caller is blocked.
        TODO: tighten once the diversity score contribution is fixed.
        """
        app = self._create_app(diversity_threshold=3)  # Low threshold for test
        client = TestClient(app)
        headers = {"X-API-Key": self._unique_key("diversity-test-user")}

        # Access many different categories
        endpoints = [
            "/api/v1/tom/access-control",
            "/api/v1/tom/encryption",
            "/api/v1/dsfa/threshold",
            "/api/v1/dsfa/necessity",
            "/api/v1/vvt/processing",
            "/api/v1/vvt/purposes",
        ]

        first = client.get(endpoints[0], headers=headers)
        assert first.status_code == 200
        baseline_level = int(first.headers["X-SDK-Throttle-Level"])

        for endpoint in endpoints[1:]:
            response = client.get(endpoint, headers=headers)
            assert response.status_code in (200, 429)

        # After exceeding diversity, throttle level must not have dropped
        response = client.get("/api/v1/vvt/categories", headers=headers)
        assert response.status_code in (200, 429)
        if response.status_code == 200:
            # Read the header strictly: a "0" default used to mask a missing header.
            level = int(response.headers["X-SDK-Throttle-Level"])
            assert baseline_level <= level <= 3

    def test_burst_detection(self):
        """Score should increase for rapid same-category requests.

        As with the diversity test, the exact burst weight is not pinned. The
        throttle level after the burst must be a valid level (0-3) not below
        the first request's level, unless the caller is blocked.
        TODO: tighten once the burst score contribution is fixed.
        """
        app = self._create_app(burst_threshold=3)  # Low threshold for test
        client = TestClient(app)
        headers = {"X-API-Key": self._unique_key("burst-test-user")}
        burst_path = "/api/v1/tom/access-control"

        first = client.get(burst_path, headers=headers)
        assert first.status_code == 200
        baseline_level = int(first.headers["X-SDK-Throttle-Level"])

        # Burst access to same endpoint (five requests in total)
        for _ in range(4):
            response = client.get(burst_path, headers=headers)
            if response.status_code == 429:
                break

        # After burst, throttle level must not have dropped
        response = client.get("/api/v1/tom/encryption", headers=headers)
        assert response.status_code in (200, 429)
        if response.status_code == 200:
            level = int(response.headers["X-SDK-Throttle-Level"])
            assert baseline_level <= level <= 3

    def test_sequential_enumeration_detection(self):
        """Score should increase for alphabetically sorted access patterns."""
        from middleware.sdk_protection import SDKProtectionMiddleware, SDKProtectionConfig

        config = SDKProtectionConfig(
            sequential_min_entries=5,
            sequential_sorted_ratio=0.6,
        )
        # Bypass __init__ (which needs an ASGI app) and only supply the config
        mw = SDKProtectionMiddleware.__new__(SDKProtectionMiddleware)
        mw.config = config

        # Sorted sequence should be detected
        sorted_seq = ["a_cat", "b_cat", "c_cat", "d_cat", "e_cat", "f_cat"]
        assert mw._check_sequential(sorted_seq) is True

        # Random sequence should not be detected
        random_seq = ["d_cat", "a_cat", "f_cat", "b_cat", "e_cat", "c_cat"]
        assert mw._check_sequential(random_seq) is False

        # Too short sequence should not be detected
        short_seq = ["a_cat", "b_cat"]
        assert mw._check_sequential(short_seq) is False

    def test_progressive_throttling_level_1(self):
        """Throttle level 1 should be set at score >= 30."""
        from middleware.sdk_protection import SDKProtectionMiddleware, SDKProtectionConfig

        config = SDKProtectionConfig()
        mw = SDKProtectionMiddleware.__new__(SDKProtectionMiddleware)
        mw.config = config

        assert mw._get_throttle_level(0) == 0
        assert mw._get_throttle_level(29) == 0
        assert mw._get_throttle_level(30) == 1
        assert mw._get_throttle_level(50) == 1
        assert mw._get_throttle_level(59) == 1

    def test_progressive_throttling_level_3_blocks(self):
        """Level 2 covers scores 60-84; level 3 starts at score >= 85."""
        from middleware.sdk_protection import SDKProtectionMiddleware, SDKProtectionConfig

        config = SDKProtectionConfig()
        mw = SDKProtectionMiddleware.__new__(SDKProtectionMiddleware)
        mw.config = config

        assert mw._get_throttle_level(60) == 2
        assert mw._get_throttle_level(84) == 2
        assert mw._get_throttle_level(85) == 3
        assert mw._get_throttle_level(100) == 3

    def test_score_decay_over_time(self):
        """Score should decay over time using decay factor."""
        from middleware.sdk_protection import SDKProtectionMiddleware, SDKProtectionConfig

        config = SDKProtectionConfig(
            score_decay_factor=0.5,  # Aggressive decay for test
            score_decay_interval=60,  # 1 minute intervals
        )
        mw = SDKProtectionMiddleware.__new__(SDKProtectionMiddleware)
        mw.config = config

        now = time.time()
        # Score 100, last decay 2 intervals ago
        score, last_decay = mw._apply_decay(100.0, now - 120, now)
        # 2 intervals: 100 * 0.5 * 0.5 = 25
        assert score == pytest.approx(25.0)

        # No decay if within same interval
        score2, _ = mw._apply_decay(100.0, now - 30, now)
        assert score2 == pytest.approx(100.0)

    def test_skips_non_protected_paths(self):
        """Should not apply protection to non-SDK paths."""
        app = self._create_app()
        client = TestClient(app)

        # Health endpoint should not be protected
        response = client.get("/health")
        assert response.status_code == 200
        assert "X-SDK-Throttle-Level" not in response.headers

        # Non-SDK path should not be protected
        response = client.get("/api/public")
        assert response.status_code == 200
        assert "X-SDK-Throttle-Level" not in response.headers

    def test_watermark_header_present(self):
        """Response should include X-BP-Trace watermark header."""
        app = self._create_app()
        client = TestClient(app)

        response = client.get(
            "/api/v1/tom/access-control",
            headers={"X-API-Key": "watermark-test-user"},
        )
        assert response.status_code == 200
        assert "X-BP-Trace" in response.headers
        assert len(response.headers["X-BP-Trace"]) == 32

    def test_fallback_to_inmemory(self):
        """Should work with in-memory fallback when Valkey is unavailable."""
        # Point to an unreachable Valkey. A loopback port with no listener is
        # refused immediately. An unresolvable hostname instead waits on DNS
        # resolution, which can stall in some CI networks.
        app = self._create_app(valkey_url="redis://127.0.0.1:1")
        client = TestClient(app)

        response = client.get(
            "/api/v1/tom/access-control",
            headers={"X-API-Key": "fallback-test-user"},
        )
        assert response.status_code == 200
        assert response.json() == {"data": "access-control"}

    def test_no_user_passes_through(self):
        """Requests without user identification should pass through."""
        app = self._create_app()
        client = TestClient(app)

        # No API key and no session
        response = client.get("/api/v1/tom/access-control")
        assert response.status_code == 200

    def test_category_extraction(self):
        """Category extraction should use longest prefix match."""
        from middleware.sdk_protection import _extract_category

        assert _extract_category("/api/v1/tom/access-control") == "tom_access_control"
        assert _extract_category("/api/v1/tom/encryption") == "tom_encryption"
        assert _extract_category("/api/v1/dsfa/threshold") == "dsfa_threshold"
        assert _extract_category("/api/v1/vvt/processing") == "vvt_processing"
        assert _extract_category("/api/v1/controls/anything") == "controls_general"
        assert _extract_category("/api/unknown/path") == "unknown"
|