This repository has been archived on 2026-02-15. You can view files and clone it. You cannot open issues or pull requests or push a commit.
Files
breakpilot-pwa/backend/tests/test_middleware.py
Benjamin Admin 70f2b0ae64 refactor: Consolidate standalone services into admin-v2, add new SDK modules
Remove standalone services (ai-compliance-sdk root, developer-portal,
dsms-gateway, dsms-node, night-scheduler) and legacy compliance/dsgvo pages.
Add new SDK pipeline modules (academy, document-crawler, dsb-portal,
incidents, whistleblower, reporting, sso, multi-tenant, industry-templates).
Add drafting engine, legal corpus files (AT/CH/DE), pitch-deck,
blog and Förderantrag pages.

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
2026-02-15 09:05:18 +01:00

890 lines
29 KiB
Python

"""
Integration Tests for Middleware Components
Tests the middleware stack:
- Request-ID generation and propagation
- Security headers
- Rate limiting
- PII redaction
- Input validation
"""
import pytest
from unittest.mock import MagicMock, AsyncMock, patch
from starlette.requests import Request
from starlette.responses import Response, JSONResponse
from starlette.testclient import TestClient
from fastapi import FastAPI
import time
# ==============================================
# Request-ID Middleware Tests
# ==============================================
class TestRequestIDMiddleware:
"""Tests for RequestIDMiddleware."""
def test_generates_request_id_when_not_provided(self):
"""Should generate a UUID when no X-Request-ID header is provided."""
from middleware.request_id import RequestIDMiddleware, get_request_id
app = FastAPI()
app.add_middleware(RequestIDMiddleware)
@app.get("/test")
async def test_endpoint():
return {"request_id": get_request_id()}
client = TestClient(app)
response = client.get("/test")
assert response.status_code == 200
assert "X-Request-ID" in response.headers
assert len(response.headers["X-Request-ID"]) == 36 # UUID format
def test_propagates_existing_request_id(self):
"""Should propagate existing X-Request-ID header."""
from middleware.request_id import RequestIDMiddleware
app = FastAPI()
app.add_middleware(RequestIDMiddleware)
@app.get("/test")
async def test_endpoint():
return {"status": "ok"}
client = TestClient(app)
custom_id = "custom-request-id-12345"
response = client.get("/test", headers={"X-Request-ID": custom_id})
assert response.status_code == 200
assert response.headers["X-Request-ID"] == custom_id
def test_propagates_correlation_id(self):
"""Should propagate X-Correlation-ID header."""
from middleware.request_id import RequestIDMiddleware
app = FastAPI()
app.add_middleware(RequestIDMiddleware)
@app.get("/test")
async def test_endpoint():
return {"status": "ok"}
client = TestClient(app)
custom_id = "correlation-id-12345"
response = client.get("/test", headers={"X-Correlation-ID": custom_id})
assert response.status_code == 200
assert response.headers["X-Request-ID"] == custom_id
assert response.headers["X-Correlation-ID"] == custom_id
# ==============================================
# Security Headers Middleware Tests
# ==============================================
class TestSecurityHeadersMiddleware:
"""Tests for SecurityHeadersMiddleware."""
def test_adds_security_headers(self):
"""Should add security headers to all responses."""
from middleware.security_headers import SecurityHeadersMiddleware
app = FastAPI()
app.add_middleware(SecurityHeadersMiddleware, development_mode=False)
@app.get("/test")
async def test_endpoint():
return {"status": "ok"}
client = TestClient(app)
response = client.get("/test")
assert response.status_code == 200
assert response.headers["X-Content-Type-Options"] == "nosniff"
assert response.headers["X-Frame-Options"] == "DENY"
assert response.headers["X-XSS-Protection"] == "1; mode=block"
assert "Referrer-Policy" in response.headers
def test_hsts_in_production(self):
"""Should add HSTS header in production mode."""
from middleware.security_headers import SecurityHeadersMiddleware
app = FastAPI()
app.add_middleware(SecurityHeadersMiddleware, development_mode=False, hsts_enabled=True)
@app.get("/test")
async def test_endpoint():
return {"status": "ok"}
client = TestClient(app)
response = client.get("/test")
assert response.status_code == 200
assert "Strict-Transport-Security" in response.headers
def test_no_hsts_in_development(self):
"""Should not add HSTS header in development mode."""
from middleware.security_headers import SecurityHeadersMiddleware
app = FastAPI()
app.add_middleware(SecurityHeadersMiddleware, development_mode=True)
@app.get("/test")
async def test_endpoint():
return {"status": "ok"}
client = TestClient(app)
response = client.get("/test")
assert response.status_code == 200
assert "Strict-Transport-Security" not in response.headers
def test_csp_header(self):
"""Should add CSP header when enabled."""
from middleware.security_headers import SecurityHeadersMiddleware
app = FastAPI()
app.add_middleware(
SecurityHeadersMiddleware,
csp_enabled=True,
csp_policy="default-src 'self'"
)
@app.get("/test")
async def test_endpoint():
return {"status": "ok"}
client = TestClient(app)
response = client.get("/test")
assert response.status_code == 200
assert response.headers["Content-Security-Policy"] == "default-src 'self'"
def test_excludes_health_endpoint(self):
"""Should not add security headers to excluded paths."""
from middleware.security_headers import SecurityHeadersMiddleware, SecurityHeadersConfig
config = SecurityHeadersConfig(excluded_paths=["/health"])
app = FastAPI()
app.add_middleware(SecurityHeadersMiddleware, config=config)
@app.get("/health")
async def health():
return {"status": "healthy"}
client = TestClient(app)
response = client.get("/health")
assert response.status_code == 200
# Security headers should not be present
assert "Content-Security-Policy" not in response.headers
# ==============================================
# PII Redactor Tests
# ==============================================
class TestPIIRedactor:
"""Tests for PII redaction."""
def test_redacts_email(self):
"""Should redact email addresses."""
from middleware.pii_redactor import redact_pii
text = "User test@example.com logged in"
result = redact_pii(text)
assert "test@example.com" not in result
assert "[EMAIL_REDACTED]" in result
def test_redacts_ip_v4(self):
"""Should redact IPv4 addresses."""
from middleware.pii_redactor import redact_pii
text = "Request from 192.168.1.100"
result = redact_pii(text)
assert "192.168.1.100" not in result
assert "[IP_REDACTED]" in result
def test_redacts_german_phone(self):
"""Should redact German phone numbers."""
from middleware.pii_redactor import redact_pii
text = "Call +49 30 12345678"
result = redact_pii(text)
assert "+49 30 12345678" not in result
assert "[PHONE_REDACTED]" in result
def test_redacts_multiple_pii(self):
"""Should redact multiple PII types in same text."""
from middleware.pii_redactor import redact_pii
text = "User test@example.com from 10.0.0.1"
result = redact_pii(text)
assert "test@example.com" not in result
assert "10.0.0.1" not in result
assert "[EMAIL_REDACTED]" in result
assert "[IP_REDACTED]" in result
def test_preserves_non_pii_text(self):
"""Should preserve text that is not PII."""
from middleware.pii_redactor import redact_pii
text = "User logged in successfully"
result = redact_pii(text)
assert result == text
def test_contains_pii_detection(self):
"""Should detect if text contains PII."""
from middleware.pii_redactor import PIIRedactor
redactor = PIIRedactor()
assert redactor.contains_pii("test@example.com")
assert redactor.contains_pii("192.168.1.1")
assert not redactor.contains_pii("Hello World")
def test_find_pii_locations(self):
"""Should find PII and return locations."""
from middleware.pii_redactor import PIIRedactor
redactor = PIIRedactor()
text = "Email: test@example.com, IP: 10.0.0.1"
findings = redactor.find_pii(text)
assert len(findings) == 2
assert any(f["type"] == "email" for f in findings)
assert any(f["type"] == "ip_v4" for f in findings)
# ==============================================
# Input Gate Middleware Tests
# ==============================================
class TestInputGateMiddleware:
"""Tests for InputGateMiddleware."""
def test_allows_valid_json_request(self):
"""Should allow valid JSON request within size limit."""
from middleware.input_gate import InputGateMiddleware
app = FastAPI()
app.add_middleware(InputGateMiddleware, max_body_size=1024)
@app.post("/test")
async def test_endpoint(data: dict):
return {"received": True}
client = TestClient(app)
response = client.post(
"/test",
json={"key": "value"},
headers={"Content-Type": "application/json"}
)
assert response.status_code == 200
def test_rejects_invalid_content_type(self):
"""Should reject request with invalid content type."""
from middleware.input_gate import InputGateMiddleware, InputGateConfig
config = InputGateConfig(
allowed_content_types={"application/json"},
strict_content_type=True
)
app = FastAPI()
app.add_middleware(InputGateMiddleware, config=config)
@app.post("/test")
async def test_endpoint():
return {"status": "ok"}
client = TestClient(app)
response = client.post(
"/test",
content="data",
headers={"Content-Type": "text/xml"}
)
assert response.status_code == 415 # Unsupported Media Type
def test_allows_get_requests_without_body(self):
"""Should allow GET requests without validation."""
from middleware.input_gate import InputGateMiddleware
app = FastAPI()
app.add_middleware(InputGateMiddleware)
@app.get("/test")
async def test_endpoint():
return {"status": "ok"}
client = TestClient(app)
response = client.get("/test")
assert response.status_code == 200
def test_excludes_health_endpoint(self):
"""Should not validate excluded paths."""
from middleware.input_gate import InputGateMiddleware, InputGateConfig
config = InputGateConfig(excluded_paths=["/health"])
app = FastAPI()
app.add_middleware(InputGateMiddleware, config=config)
@app.get("/health")
async def health():
return {"status": "healthy"}
client = TestClient(app)
response = client.get("/health")
assert response.status_code == 200
# ==============================================
# File Upload Validation Tests
# ==============================================
class TestFileUploadValidation:
"""Tests for file upload validation."""
def test_validates_file_size(self):
"""Should reject files exceeding max size."""
from middleware.input_gate import validate_file_upload, InputGateConfig
config = InputGateConfig(max_file_size=1024) # 1KB
valid, error = validate_file_upload(
filename="test.pdf",
content_type="application/pdf",
size=512, # 512 bytes
config=config
)
assert valid
valid, error = validate_file_upload(
filename="test.pdf",
content_type="application/pdf",
size=2048, # 2KB - exceeds limit
config=config
)
assert not valid
assert "size" in error.lower()
def test_rejects_blocked_extensions(self):
"""Should reject files with blocked extensions."""
from middleware.input_gate import validate_file_upload
valid, error = validate_file_upload(
filename="malware.exe",
content_type="application/octet-stream",
size=100
)
assert not valid
assert "extension" in error.lower()
valid, error = validate_file_upload(
filename="script.bat",
content_type="application/octet-stream",
size=100
)
assert not valid
def test_allows_safe_file_types(self):
"""Should allow safe file types."""
from middleware.input_gate import validate_file_upload
valid, error = validate_file_upload(
filename="document.pdf",
content_type="application/pdf",
size=1024
)
assert valid
valid, error = validate_file_upload(
filename="image.png",
content_type="image/png",
size=1024
)
assert valid
# ==============================================
# Rate Limiter Tests
# ==============================================
class TestRateLimiterMiddleware:
"""Tests for RateLimiterMiddleware."""
def test_allows_requests_under_limit(self):
"""Should allow requests under the rate limit."""
from middleware.rate_limiter import RateLimiterMiddleware, RateLimitConfig
config = RateLimitConfig(
ip_limit=100,
window_size=60,
fallback_enabled=True,
skip_internal_network=True,
)
app = FastAPI()
app.add_middleware(RateLimiterMiddleware, config=config)
@app.get("/test")
async def test_endpoint():
return {"status": "ok"}
client = TestClient(app)
# Make a few requests - should all succeed
for _ in range(5):
response = client.get("/test")
assert response.status_code == 200
def test_rate_limit_headers(self):
"""Should include rate limit headers in response."""
from middleware.rate_limiter import RateLimiterMiddleware, RateLimitConfig
# Use a client IP that won't be skipped
config = RateLimitConfig(
ip_limit=100,
window_size=60,
fallback_enabled=True,
skip_internal_network=False, # Don't skip internal IPs for this test
)
app = FastAPI()
app.add_middleware(RateLimiterMiddleware, config=config)
@app.get("/test")
async def test_endpoint():
return {"status": "ok"}
client = TestClient(app)
response = client.get("/test")
assert response.status_code == 200
assert "X-RateLimit-Limit" in response.headers
assert "X-RateLimit-Remaining" in response.headers
assert "X-RateLimit-Reset" in response.headers
def test_skips_whitelisted_ips(self):
"""Should skip rate limiting for whitelisted IPs."""
from middleware.rate_limiter import RateLimiterMiddleware, RateLimitConfig
config = RateLimitConfig(
ip_limit=1, # Very low limit
window_size=60,
ip_whitelist={"127.0.0.1", "::1", "10.0.0.1"},
fallback_enabled=True,
)
app = FastAPI()
app.add_middleware(RateLimiterMiddleware, config=config)
@app.get("/test")
async def test_endpoint():
return {"status": "ok"}
client = TestClient(app)
# Multiple requests should succeed because the IP is whitelisted
# Use X-Forwarded-For header to simulate whitelisted IP
for _ in range(10):
response = client.get("/test", headers={"X-Forwarded-For": "10.0.0.1"})
assert response.status_code == 200
def test_excludes_health_endpoint(self):
"""Should not rate limit excluded paths."""
from middleware.rate_limiter import RateLimiterMiddleware, RateLimitConfig
config = RateLimitConfig(
ip_limit=1,
window_size=60,
excluded_paths=["/health"],
fallback_enabled=True,
)
app = FastAPI()
app.add_middleware(RateLimiterMiddleware, config=config)
@app.get("/health")
async def health():
return {"status": "healthy"}
client = TestClient(app)
# Multiple requests to health should succeed
for _ in range(10):
response = client.get("/health")
assert response.status_code == 200
# ==============================================
# Middleware Stack Integration Test
# ==============================================
class TestMiddlewareStackIntegration:
"""Tests for the complete middleware stack."""
def test_full_middleware_stack(self):
"""Test all middlewares work together."""
from middleware.request_id import RequestIDMiddleware
from middleware.security_headers import SecurityHeadersMiddleware
from middleware.input_gate import InputGateMiddleware
app = FastAPI()
# Add middlewares in order (last added = first executed)
app.add_middleware(InputGateMiddleware)
app.add_middleware(SecurityHeadersMiddleware, development_mode=True)
app.add_middleware(RequestIDMiddleware)
@app.get("/test")
async def test_endpoint():
return {"status": "ok"}
@app.post("/data")
async def data_endpoint(data: dict):
return {"received": data}
client = TestClient(app)
# Test GET request
response = client.get("/test")
assert response.status_code == 200
assert "X-Request-ID" in response.headers
assert "X-Content-Type-Options" in response.headers
# Test POST request with JSON
response = client.post(
"/data",
json={"key": "value"},
headers={"Content-Type": "application/json"}
)
assert response.status_code == 200
assert response.json()["received"] == {"key": "value"}
# ==============================================
# SDK Protection Middleware Tests
# ==============================================
class TestSDKProtectionMiddleware:
"""Tests for SDKProtectionMiddleware."""
def _create_app(self, **config_overrides):
"""Helper to create test app with SDK protection."""
from middleware.sdk_protection import SDKProtectionMiddleware, SDKProtectionConfig
config_kwargs = {
"fallback_enabled": True,
"watermark_secret": "test-secret",
}
config_kwargs.update(config_overrides)
config = SDKProtectionConfig(**config_kwargs)
app = FastAPI()
app.add_middleware(SDKProtectionMiddleware, config=config)
@app.get("/api/v1/tom/access-control")
async def tom_access_control():
return {"data": "access-control"}
@app.get("/api/v1/tom/encryption")
async def tom_encryption():
return {"data": "encryption"}
@app.get("/api/v1/dsfa/threshold")
async def dsfa_threshold():
return {"data": "threshold"}
@app.get("/api/v1/dsfa/necessity")
async def dsfa_necessity():
return {"data": "necessity"}
@app.get("/api/v1/vvt/processing")
async def vvt_processing():
return {"data": "processing"}
@app.get("/api/v1/vvt/purposes")
async def vvt_purposes():
return {"data": "purposes"}
@app.get("/api/v1/vvt/categories")
async def vvt_categories():
return {"data": "categories"}
@app.get("/api/v1/vvt/recipients")
async def vvt_recipients():
return {"data": "recipients"}
@app.get("/api/v1/controls/list")
async def controls_list():
return {"data": "controls"}
@app.get("/api/v1/assessment/run")
async def assessment_run():
return {"data": "assessment"}
@app.get("/health")
async def health():
return {"status": "healthy"}
@app.get("/api/public")
async def public():
return {"data": "public"}
return app
def test_allows_normal_request(self):
"""Should allow normal requests under all limits."""
app = self._create_app()
client = TestClient(app)
response = client.get(
"/api/v1/tom/access-control",
headers={"X-API-Key": "test-user-key-123"},
)
assert response.status_code == 200
assert response.json() == {"data": "access-control"}
def test_quota_headers_present(self):
"""Should include quota headers in response."""
app = self._create_app()
client = TestClient(app)
response = client.get(
"/api/v1/tom/access-control",
headers={"X-API-Key": "test-user-key-456"},
)
assert response.status_code == 200
assert "X-SDK-Quota-Remaining-Minute" in response.headers
assert "X-SDK-Quota-Remaining-Hour" in response.headers
assert "X-SDK-Throttle-Level" in response.headers
def test_blocks_after_quota_exceeded(self):
"""Should return 429 when minute quota is exceeded."""
from middleware.sdk_protection import SDKProtectionConfig, QuotaTier
tiers = {
"free": QuotaTier("free", 3, 500, 3000, 50000), # Very low minute limit
}
app = self._create_app(tiers=tiers)
client = TestClient(app)
api_key = "quota-test-user"
headers = {"X-API-Key": api_key}
# Make requests up to the limit
for i in range(3):
response = client.get("/api/v1/tom/access-control", headers=headers)
assert response.status_code == 200, f"Request {i+1} should succeed"
# Next request should be blocked
response = client.get("/api/v1/tom/access-control", headers=headers)
assert response.status_code == 429
assert response.json()["error"] == "sdk_quota_exceeded"
def test_diversity_tracking_increments_score(self):
"""Score should increase when accessing many different categories."""
from middleware.sdk_protection import SDKProtectionConfig
app = self._create_app(diversity_threshold=3) # Low threshold for test
client = TestClient(app)
api_key = "diversity-test-user"
headers = {"X-API-Key": api_key}
# Access many different categories
endpoints = [
"/api/v1/tom/access-control",
"/api/v1/tom/encryption",
"/api/v1/dsfa/threshold",
"/api/v1/dsfa/necessity",
"/api/v1/vvt/processing",
"/api/v1/vvt/purposes",
]
for endpoint in endpoints:
response = client.get(endpoint, headers=headers)
assert response.status_code in (200, 429)
# After exceeding diversity, throttle level should increase
response = client.get("/api/v1/vvt/categories", headers=headers)
if response.status_code == 200:
level = int(response.headers.get("X-SDK-Throttle-Level", "0"))
assert level >= 0 # Score increased but may not hit threshold yet
def test_burst_detection(self):
"""Score should increase for rapid same-category requests."""
from middleware.sdk_protection import SDKProtectionConfig
app = self._create_app(burst_threshold=3) # Low threshold for test
client = TestClient(app)
api_key = "burst-test-user"
headers = {"X-API-Key": api_key}
# Burst access to same endpoint
for _ in range(5):
response = client.get("/api/v1/tom/access-control", headers=headers)
if response.status_code == 429:
break
# After burst, throttle level should have increased
response = client.get("/api/v1/tom/encryption", headers=headers)
if response.status_code == 200:
level = int(response.headers.get("X-SDK-Throttle-Level", "0"))
assert level >= 0 # Score increased
def test_sequential_enumeration_detection(self):
"""Score should increase for alphabetically sorted access patterns."""
from middleware.sdk_protection import (
SDKProtectionMiddleware,
SDKProtectionConfig,
InMemorySDKProtection,
)
config = SDKProtectionConfig(
sequential_min_entries=5,
sequential_sorted_ratio=0.6,
)
mw = SDKProtectionMiddleware.__new__(SDKProtectionMiddleware)
mw.config = config
# Sorted sequence should be detected
sorted_seq = ["a_cat", "b_cat", "c_cat", "d_cat", "e_cat", "f_cat"]
assert mw._check_sequential(sorted_seq) is True
# Random sequence should not be detected
random_seq = ["d_cat", "a_cat", "f_cat", "b_cat", "e_cat", "c_cat"]
assert mw._check_sequential(random_seq) is False
# Too short sequence should not be detected
short_seq = ["a_cat", "b_cat"]
assert mw._check_sequential(short_seq) is False
def test_progressive_throttling_level_1(self):
"""Throttle level 1 should be set at score >= 30."""
from middleware.sdk_protection import SDKProtectionMiddleware, SDKProtectionConfig
config = SDKProtectionConfig()
mw = SDKProtectionMiddleware.__new__(SDKProtectionMiddleware)
mw.config = config
assert mw._get_throttle_level(0) == 0
assert mw._get_throttle_level(29) == 0
assert mw._get_throttle_level(30) == 1
assert mw._get_throttle_level(50) == 1
assert mw._get_throttle_level(59) == 1
def test_progressive_throttling_level_3_blocks(self):
"""Throttle level 3 should be set at score >= 85."""
from middleware.sdk_protection import SDKProtectionMiddleware, SDKProtectionConfig
config = SDKProtectionConfig()
mw = SDKProtectionMiddleware.__new__(SDKProtectionMiddleware)
mw.config = config
assert mw._get_throttle_level(60) == 2
assert mw._get_throttle_level(84) == 2
assert mw._get_throttle_level(85) == 3
assert mw._get_throttle_level(100) == 3
def test_score_decay_over_time(self):
"""Score should decay over time using decay factor."""
from middleware.sdk_protection import SDKProtectionMiddleware, SDKProtectionConfig
config = SDKProtectionConfig(
score_decay_factor=0.5, # Aggressive decay for test
score_decay_interval=60, # 1 minute intervals
)
mw = SDKProtectionMiddleware.__new__(SDKProtectionMiddleware)
mw.config = config
now = time.time()
# Score 100, last decay 2 intervals ago
score, last_decay = mw._apply_decay(100.0, now - 120, now)
# 2 intervals: 100 * 0.5 * 0.5 = 25
assert score == pytest.approx(25.0)
# No decay if within same interval
score2, _ = mw._apply_decay(100.0, now - 30, now)
assert score2 == pytest.approx(100.0)
def test_skips_non_protected_paths(self):
"""Should not apply protection to non-SDK paths."""
app = self._create_app()
client = TestClient(app)
# Health endpoint should not be protected
response = client.get("/health")
assert response.status_code == 200
assert "X-SDK-Throttle-Level" not in response.headers
# Non-SDK path should not be protected
response = client.get("/api/public")
assert response.status_code == 200
assert "X-SDK-Throttle-Level" not in response.headers
def test_watermark_header_present(self):
"""Response should include X-BP-Trace watermark header."""
app = self._create_app()
client = TestClient(app)
response = client.get(
"/api/v1/tom/access-control",
headers={"X-API-Key": "watermark-test-user"},
)
assert response.status_code == 200
assert "X-BP-Trace" in response.headers
assert len(response.headers["X-BP-Trace"]) == 32
def test_fallback_to_inmemory(self):
"""Should work with in-memory fallback when Valkey is unavailable."""
from middleware.sdk_protection import SDKProtectionConfig
# Point to non-existent Valkey
app = self._create_app(valkey_url="redis://nonexistent:9999")
client = TestClient(app)
response = client.get(
"/api/v1/tom/access-control",
headers={"X-API-Key": "fallback-test-user"},
)
assert response.status_code == 200
assert response.json() == {"data": "access-control"}
def test_no_user_passes_through(self):
"""Requests without user identification should pass through."""
app = self._create_app()
client = TestClient(app)
# No API key and no session
response = client.get("/api/v1/tom/access-control")
assert response.status_code == 200
def test_category_extraction(self):
"""Category extraction should use longest prefix match."""
from middleware.sdk_protection import _extract_category
assert _extract_category("/api/v1/tom/access-control") == "tom_access_control"
assert _extract_category("/api/v1/tom/encryption") == "tom_encryption"
assert _extract_category("/api/v1/dsfa/threshold") == "dsfa_threshold"
assert _extract_category("/api/v1/vvt/processing") == "vvt_processing"
assert _extract_category("/api/v1/controls/anything") == "controls_general"
assert _extract_category("/api/unknown/path") == "unknown"