Initial commit: breakpilot-core - Shared Infrastructure
Docker Compose with 24+ services: - PostgreSQL (PostGIS), Valkey, MinIO, Qdrant - Vault (PKI/TLS), Nginx (Reverse Proxy) - Backend Core API, Consent Service, Billing Service - RAG Service, Embedding Service - Gitea, Woodpecker CI/CD - Night Scheduler, Health Aggregator - Jitsi (Web/XMPP/JVB/Jicofo), Mailpit Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
247
consent-service/internal/middleware/input_gate.go
Normal file
247
consent-service/internal/middleware/input_gate.go
Normal file
@@ -0,0 +1,247 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// InputGateConfig holds configuration for input validation.
|
||||
type InputGateConfig struct {
|
||||
// Maximum request body size (default: 10MB)
|
||||
MaxBodySize int64
|
||||
|
||||
// Maximum file upload size (default: 50MB)
|
||||
MaxFileSize int64
|
||||
|
||||
// Allowed content types
|
||||
AllowedContentTypes map[string]bool
|
||||
|
||||
// Allowed file types for uploads
|
||||
AllowedFileTypes map[string]bool
|
||||
|
||||
// Blocked file extensions
|
||||
BlockedExtensions map[string]bool
|
||||
|
||||
// Paths that allow larger uploads
|
||||
LargeUploadPaths []string
|
||||
|
||||
// Paths excluded from validation
|
||||
ExcludedPaths []string
|
||||
|
||||
// Enable strict content type checking
|
||||
StrictContentType bool
|
||||
}
|
||||
|
||||
// DefaultInputGateConfig returns sensible default configuration.
|
||||
func DefaultInputGateConfig() InputGateConfig {
|
||||
maxSize := int64(10 * 1024 * 1024) // 10MB
|
||||
if envSize := os.Getenv("MAX_REQUEST_BODY_SIZE"); envSize != "" {
|
||||
if size, err := strconv.ParseInt(envSize, 10, 64); err == nil {
|
||||
maxSize = size
|
||||
}
|
||||
}
|
||||
|
||||
return InputGateConfig{
|
||||
MaxBodySize: maxSize,
|
||||
MaxFileSize: 50 * 1024 * 1024, // 50MB
|
||||
AllowedContentTypes: map[string]bool{
|
||||
"application/json": true,
|
||||
"application/x-www-form-urlencoded": true,
|
||||
"multipart/form-data": true,
|
||||
"text/plain": true,
|
||||
},
|
||||
AllowedFileTypes: map[string]bool{
|
||||
"image/jpeg": true,
|
||||
"image/png": true,
|
||||
"image/gif": true,
|
||||
"image/webp": true,
|
||||
"application/pdf": true,
|
||||
"text/csv": true,
|
||||
"application/msword": true,
|
||||
"application/vnd.openxmlformats-officedocument.wordprocessingml.document": true,
|
||||
"application/vnd.ms-excel": true,
|
||||
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet": true,
|
||||
},
|
||||
BlockedExtensions: map[string]bool{
|
||||
".exe": true, ".bat": true, ".cmd": true, ".com": true, ".msi": true,
|
||||
".dll": true, ".scr": true, ".pif": true, ".vbs": true, ".js": true,
|
||||
".jar": true, ".sh": true, ".ps1": true, ".app": true,
|
||||
},
|
||||
LargeUploadPaths: []string{
|
||||
"/api/v1/files/upload",
|
||||
"/api/v1/documents/upload",
|
||||
"/api/v1/attachments",
|
||||
},
|
||||
ExcludedPaths: []string{
|
||||
"/health",
|
||||
"/metrics",
|
||||
"/api/v1/health",
|
||||
},
|
||||
StrictContentType: true,
|
||||
}
|
||||
}
|
||||
|
||||
// isExcludedPath checks if path is excluded from validation.
|
||||
func (c *InputGateConfig) isExcludedPath(path string) bool {
|
||||
for _, excluded := range c.ExcludedPaths {
|
||||
if path == excluded {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// isLargeUploadPath checks if path allows larger uploads.
|
||||
func (c *InputGateConfig) isLargeUploadPath(path string) bool {
|
||||
for _, uploadPath := range c.LargeUploadPaths {
|
||||
if strings.HasPrefix(path, uploadPath) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// getMaxSize returns the maximum allowed body size for the path.
|
||||
func (c *InputGateConfig) getMaxSize(path string) int64 {
|
||||
if c.isLargeUploadPath(path) {
|
||||
return c.MaxFileSize
|
||||
}
|
||||
return c.MaxBodySize
|
||||
}
|
||||
|
||||
// validateContentType validates the content type.
|
||||
func (c *InputGateConfig) validateContentType(contentType string) (bool, string) {
|
||||
if contentType == "" {
|
||||
return true, ""
|
||||
}
|
||||
|
||||
// Extract base content type (remove charset, boundary, etc.)
|
||||
baseType := strings.Split(contentType, ";")[0]
|
||||
baseType = strings.TrimSpace(strings.ToLower(baseType))
|
||||
|
||||
if !c.AllowedContentTypes[baseType] {
|
||||
return false, "Content-Type '" + baseType + "' is not allowed"
|
||||
}
|
||||
|
||||
return true, ""
|
||||
}
|
||||
|
||||
// hasBlockedExtension checks if filename has a blocked extension.
|
||||
func (c *InputGateConfig) hasBlockedExtension(filename string) bool {
|
||||
if filename == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
lowerFilename := strings.ToLower(filename)
|
||||
for ext := range c.BlockedExtensions {
|
||||
if strings.HasSuffix(lowerFilename, ext) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// InputGate returns a middleware that validates incoming request bodies.
|
||||
//
|
||||
// Usage:
|
||||
//
|
||||
// r.Use(middleware.InputGate())
|
||||
//
|
||||
// // Or with custom config:
|
||||
// config := middleware.DefaultInputGateConfig()
|
||||
// config.MaxBodySize = 5 * 1024 * 1024 // 5MB
|
||||
// r.Use(middleware.InputGateWithConfig(config))
|
||||
func InputGate() gin.HandlerFunc {
|
||||
return InputGateWithConfig(DefaultInputGateConfig())
|
||||
}
|
||||
|
||||
// InputGateWithConfig returns an input gate middleware with custom configuration.
|
||||
func InputGateWithConfig(config InputGateConfig) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// Skip excluded paths
|
||||
if config.isExcludedPath(c.Request.URL.Path) {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
// Skip validation for GET, HEAD, OPTIONS requests
|
||||
method := c.Request.Method
|
||||
if method == "GET" || method == "HEAD" || method == "OPTIONS" {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
// Validate content type for requests with body
|
||||
contentType := c.GetHeader("Content-Type")
|
||||
if config.StrictContentType {
|
||||
valid, errMsg := config.validateContentType(contentType)
|
||||
if !valid {
|
||||
c.AbortWithStatusJSON(http.StatusUnsupportedMediaType, gin.H{
|
||||
"error": "unsupported_media_type",
|
||||
"message": errMsg,
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Check Content-Length header
|
||||
contentLength := c.GetHeader("Content-Length")
|
||||
if contentLength != "" {
|
||||
length, err := strconv.ParseInt(contentLength, 10, 64)
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{
|
||||
"error": "invalid_content_length",
|
||||
"message": "Invalid Content-Length header",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
maxSize := config.getMaxSize(c.Request.URL.Path)
|
||||
if length > maxSize {
|
||||
c.AbortWithStatusJSON(http.StatusRequestEntityTooLarge, gin.H{
|
||||
"error": "payload_too_large",
|
||||
"message": "Request body exceeds maximum size",
|
||||
"max_size": maxSize,
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Set max multipart memory for file uploads
|
||||
if strings.Contains(contentType, "multipart/form-data") {
|
||||
c.Request.Body = http.MaxBytesReader(c.Writer, c.Request.Body, config.MaxFileSize)
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// ValidateFileUpload validates a file upload.
|
||||
// Use this in upload handlers for detailed validation.
|
||||
func ValidateFileUpload(filename, contentType string, size int64, config *InputGateConfig) (bool, string) {
|
||||
if config == nil {
|
||||
defaultConfig := DefaultInputGateConfig()
|
||||
config = &defaultConfig
|
||||
}
|
||||
|
||||
// Check size
|
||||
if size > config.MaxFileSize {
|
||||
return false, "File size exceeds maximum allowed"
|
||||
}
|
||||
|
||||
// Check extension
|
||||
if config.hasBlockedExtension(filename) {
|
||||
return false, "File extension is not allowed"
|
||||
}
|
||||
|
||||
// Check content type
|
||||
if contentType != "" && !config.AllowedFileTypes[contentType] {
|
||||
return false, "File type '" + contentType + "' is not allowed"
|
||||
}
|
||||
|
||||
return true, ""
|
||||
}
|
||||
421
consent-service/internal/middleware/input_gate_test.go
Normal file
421
consent-service/internal/middleware/input_gate_test.go
Normal file
@@ -0,0 +1,421 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func TestInputGate_AllowsGETRequest(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
router.Use(InputGate())
|
||||
|
||||
router.GET("/test", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"status": "ok"})
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Expected status 200 for GET request, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInputGate_AllowsHEADRequest(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
router.Use(InputGate())
|
||||
|
||||
router.HEAD("/test", func(c *gin.Context) {
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodHead, "/test", nil)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Expected status 200 for HEAD request, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInputGate_AllowsOPTIONSRequest(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
router.Use(InputGate())
|
||||
|
||||
router.OPTIONS("/test", func(c *gin.Context) {
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodOptions, "/test", nil)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Expected status 200 for OPTIONS request, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInputGate_AllowsValidJSONRequest(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
router.Use(InputGate())
|
||||
|
||||
router.POST("/test", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"status": "ok"})
|
||||
})
|
||||
|
||||
body := bytes.NewBufferString(`{"key": "value"}`)
|
||||
req := httptest.NewRequest(http.MethodPost, "/test", body)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Content-Length", "16")
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Expected status 200 for valid JSON, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInputGate_RejectsInvalidContentType(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
|
||||
config := DefaultInputGateConfig()
|
||||
config.StrictContentType = true
|
||||
router.Use(InputGateWithConfig(config))
|
||||
|
||||
router.POST("/test", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"status": "ok"})
|
||||
})
|
||||
|
||||
body := bytes.NewBufferString(`data`)
|
||||
req := httptest.NewRequest(http.MethodPost, "/test", body)
|
||||
req.Header.Set("Content-Type", "application/xml")
|
||||
req.Header.Set("Content-Length", "4")
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusUnsupportedMediaType {
|
||||
t.Errorf("Expected status 415 for invalid content type, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInputGate_AllowsEmptyContentType(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
router.Use(InputGate())
|
||||
|
||||
router.POST("/test", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"status": "ok"})
|
||||
})
|
||||
|
||||
body := bytes.NewBufferString(`data`)
|
||||
req := httptest.NewRequest(http.MethodPost, "/test", body)
|
||||
// No Content-Type header
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Expected status 200 for empty content type, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInputGate_RejectsOversizedRequest(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
|
||||
config := DefaultInputGateConfig()
|
||||
config.MaxBodySize = 100 // 100 bytes
|
||||
router.Use(InputGateWithConfig(config))
|
||||
|
||||
router.POST("/test", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"status": "ok"})
|
||||
})
|
||||
|
||||
// Create a body larger than 100 bytes
|
||||
largeBody := strings.Repeat("x", 200)
|
||||
body := bytes.NewBufferString(largeBody)
|
||||
req := httptest.NewRequest(http.MethodPost, "/test", body)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Content-Length", "200")
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusRequestEntityTooLarge {
|
||||
t.Errorf("Expected status 413 for oversized request, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInputGate_AllowsLargeUploadPath(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
|
||||
config := DefaultInputGateConfig()
|
||||
config.MaxBodySize = 100 // 100 bytes
|
||||
config.MaxFileSize = 1000 // 1000 bytes
|
||||
config.LargeUploadPaths = []string{"/api/v1/files/upload"}
|
||||
router.Use(InputGateWithConfig(config))
|
||||
|
||||
router.POST("/api/v1/files/upload", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"status": "ok"})
|
||||
})
|
||||
|
||||
// Create a body larger than MaxBodySize but smaller than MaxFileSize
|
||||
largeBody := strings.Repeat("x", 500)
|
||||
body := bytes.NewBufferString(largeBody)
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/files/upload", body)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Content-Length", "500")
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Expected status 200 for large upload path, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInputGate_ExcludedPaths(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
|
||||
config := DefaultInputGateConfig()
|
||||
config.MaxBodySize = 10 // Very small
|
||||
config.ExcludedPaths = []string{"/health"}
|
||||
router.Use(InputGateWithConfig(config))
|
||||
|
||||
router.POST("/health", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"status": "healthy"})
|
||||
})
|
||||
|
||||
// Send oversized body to excluded path
|
||||
largeBody := strings.Repeat("x", 100)
|
||||
body := bytes.NewBufferString(largeBody)
|
||||
req := httptest.NewRequest(http.MethodPost, "/health", body)
|
||||
req.Header.Set("Content-Length", "100")
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
// Should pass because path is excluded
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Expected status 200 for excluded path, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInputGate_RejectsInvalidContentLength(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
router.Use(InputGate())
|
||||
|
||||
router.POST("/test", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"status": "ok"})
|
||||
})
|
||||
|
||||
body := bytes.NewBufferString(`data`)
|
||||
req := httptest.NewRequest(http.MethodPost, "/test", body)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Content-Length", "invalid")
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("Expected status 400 for invalid content length, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateFileUpload_BlockedExtension(t *testing.T) {
|
||||
tests := []struct {
|
||||
filename string
|
||||
contentType string
|
||||
blocked bool
|
||||
}{
|
||||
{"malware.exe", "application/octet-stream", true},
|
||||
{"script.bat", "application/octet-stream", true},
|
||||
{"hack.cmd", "application/octet-stream", true},
|
||||
{"shell.sh", "application/octet-stream", true},
|
||||
{"powershell.ps1", "application/octet-stream", true},
|
||||
{"document.pdf", "application/pdf", false},
|
||||
{"image.jpg", "image/jpeg", false},
|
||||
{"data.csv", "text/csv", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
valid, errMsg := ValidateFileUpload(tt.filename, tt.contentType, 100, nil)
|
||||
if tt.blocked && valid {
|
||||
t.Errorf("File %s should be blocked", tt.filename)
|
||||
}
|
||||
if !tt.blocked && !valid {
|
||||
t.Errorf("File %s should not be blocked, error: %s", tt.filename, errMsg)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateFileUpload_OversizedFile(t *testing.T) {
|
||||
config := DefaultInputGateConfig()
|
||||
config.MaxFileSize = 1000 // 1KB
|
||||
|
||||
valid, errMsg := ValidateFileUpload("test.pdf", "application/pdf", 2000, &config)
|
||||
|
||||
if valid {
|
||||
t.Error("Should reject oversized file")
|
||||
}
|
||||
if !strings.Contains(errMsg, "size") {
|
||||
t.Errorf("Error message should mention size, got: %s", errMsg)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateFileUpload_ValidFile(t *testing.T) {
|
||||
config := DefaultInputGateConfig()
|
||||
|
||||
valid, errMsg := ValidateFileUpload("document.pdf", "application/pdf", 1000, &config)
|
||||
|
||||
if !valid {
|
||||
t.Errorf("Should accept valid file, got error: %s", errMsg)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateFileUpload_InvalidContentType(t *testing.T) {
|
||||
config := DefaultInputGateConfig()
|
||||
|
||||
valid, errMsg := ValidateFileUpload("file.xyz", "application/x-unknown", 100, &config)
|
||||
|
||||
if valid {
|
||||
t.Error("Should reject unknown file type")
|
||||
}
|
||||
if !strings.Contains(errMsg, "not allowed") {
|
||||
t.Errorf("Error message should mention not allowed, got: %s", errMsg)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateFileUpload_NilConfig(t *testing.T) {
|
||||
// Should use default config when nil is passed
|
||||
valid, _ := ValidateFileUpload("document.pdf", "application/pdf", 1000, nil)
|
||||
|
||||
if !valid {
|
||||
t.Error("Should accept valid file with nil config (uses defaults)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHasBlockedExtension(t *testing.T) {
|
||||
config := DefaultInputGateConfig()
|
||||
|
||||
tests := []struct {
|
||||
filename string
|
||||
blocked bool
|
||||
}{
|
||||
{"test.exe", true},
|
||||
{"TEST.EXE", true}, // Case insensitive
|
||||
{"script.BAT", true},
|
||||
{"app.APP", true},
|
||||
{"document.pdf", false},
|
||||
{"image.png", false},
|
||||
{"", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
result := config.hasBlockedExtension(tt.filename)
|
||||
if result != tt.blocked {
|
||||
t.Errorf("File %s: expected blocked=%v, got %v", tt.filename, tt.blocked, result)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateContentType(t *testing.T) {
|
||||
config := DefaultInputGateConfig()
|
||||
|
||||
tests := []struct {
|
||||
contentType string
|
||||
valid bool
|
||||
}{
|
||||
{"application/json", true},
|
||||
{"application/json; charset=utf-8", true},
|
||||
{"APPLICATION/JSON", true}, // Case insensitive
|
||||
{"multipart/form-data; boundary=----WebKitFormBoundary", true},
|
||||
{"text/plain", true},
|
||||
{"application/xml", false},
|
||||
{"text/html", false},
|
||||
{"", true}, // Empty is allowed
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
valid, _ := config.validateContentType(tt.contentType)
|
||||
if valid != tt.valid {
|
||||
t.Errorf("Content-Type %q: expected valid=%v, got %v", tt.contentType, tt.valid, valid)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsLargeUploadPath(t *testing.T) {
|
||||
config := DefaultInputGateConfig()
|
||||
config.LargeUploadPaths = []string{"/api/v1/files/upload", "/api/v1/documents"}
|
||||
|
||||
tests := []struct {
|
||||
path string
|
||||
isLarge bool
|
||||
}{
|
||||
{"/api/v1/files/upload", true},
|
||||
{"/api/v1/files/upload/batch", true}, // Prefix match
|
||||
{"/api/v1/documents", true},
|
||||
{"/api/v1/documents/1/attachments", true},
|
||||
{"/api/v1/users", false},
|
||||
{"/health", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
result := config.isLargeUploadPath(tt.path)
|
||||
if result != tt.isLarge {
|
||||
t.Errorf("Path %s: expected isLarge=%v, got %v", tt.path, tt.isLarge, result)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetMaxSize(t *testing.T) {
|
||||
config := DefaultInputGateConfig()
|
||||
config.MaxBodySize = 100
|
||||
config.MaxFileSize = 1000
|
||||
config.LargeUploadPaths = []string{"/api/v1/files/upload"}
|
||||
|
||||
tests := []struct {
|
||||
path string
|
||||
expected int64
|
||||
}{
|
||||
{"/api/test", 100},
|
||||
{"/api/v1/files/upload", 1000},
|
||||
{"/health", 100},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
result := config.getMaxSize(tt.path)
|
||||
if result != tt.expected {
|
||||
t.Errorf("Path %s: expected maxSize=%d, got %d", tt.path, tt.expected, result)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestInputGate_DefaultMiddleware(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
router.Use(InputGate())
|
||||
|
||||
router.POST("/test", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"status": "ok"})
|
||||
})
|
||||
|
||||
body := bytes.NewBufferString(`{"key": "value"}`)
|
||||
req := httptest.NewRequest(http.MethodPost, "/test", body)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Expected status 200, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
379
consent-service/internal/middleware/middleware.go
Normal file
379
consent-service/internal/middleware/middleware.go
Normal file
@@ -0,0 +1,379 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// UserClaims represents the JWT claims for a user
|
||||
type UserClaims struct {
|
||||
UserID string `json:"user_id"`
|
||||
Email string `json:"email"`
|
||||
Role string `json:"role"`
|
||||
jwt.RegisteredClaims
|
||||
}
|
||||
|
||||
// CORS returns a CORS middleware
|
||||
func CORS() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
origin := c.Request.Header.Get("Origin")
|
||||
|
||||
// Allow localhost for development
|
||||
allowedOrigins := []string{
|
||||
"http://localhost:3000",
|
||||
"http://localhost:8000",
|
||||
"http://localhost:8080",
|
||||
"https://breakpilot.app",
|
||||
}
|
||||
|
||||
allowed := false
|
||||
for _, o := range allowedOrigins {
|
||||
if origin == o {
|
||||
allowed = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if allowed {
|
||||
c.Header("Access-Control-Allow-Origin", origin)
|
||||
}
|
||||
|
||||
c.Header("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
|
||||
c.Header("Access-Control-Allow-Headers", "Origin, Content-Type, Authorization, X-Requested-With")
|
||||
c.Header("Access-Control-Allow-Credentials", "true")
|
||||
c.Header("Access-Control-Max-Age", "86400")
|
||||
|
||||
if c.Request.Method == "OPTIONS" {
|
||||
c.AbortWithStatus(http.StatusNoContent)
|
||||
return
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// RequestLogger logs each request
|
||||
func RequestLogger() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
start := time.Now()
|
||||
path := c.Request.URL.Path
|
||||
method := c.Request.Method
|
||||
|
||||
c.Next()
|
||||
|
||||
latency := time.Since(start)
|
||||
status := c.Writer.Status()
|
||||
|
||||
// Log only in development or for errors
|
||||
if status >= 400 {
|
||||
gin.DefaultWriter.Write([]byte(
|
||||
method + " " + path + " " +
|
||||
string(rune(status)) + " " +
|
||||
latency.String() + "\n",
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// RateLimiter implements a simple in-memory rate limiter
|
||||
// Configurable via RATE_LIMIT_PER_MINUTE env var (default: 500)
|
||||
func RateLimiter() gin.HandlerFunc {
|
||||
type client struct {
|
||||
count int
|
||||
lastSeen time.Time
|
||||
}
|
||||
|
||||
var (
|
||||
mu sync.Mutex
|
||||
clients = make(map[string]*client)
|
||||
)
|
||||
|
||||
// Clean up old entries periodically
|
||||
go func() {
|
||||
for {
|
||||
time.Sleep(time.Minute)
|
||||
mu.Lock()
|
||||
for ip, c := range clients {
|
||||
if time.Since(c.lastSeen) > time.Minute {
|
||||
delete(clients, ip)
|
||||
}
|
||||
}
|
||||
mu.Unlock()
|
||||
}
|
||||
}()
|
||||
|
||||
return func(c *gin.Context) {
|
||||
ip := c.ClientIP()
|
||||
|
||||
// Skip rate limiting for Docker internal network (172.x.x.x) and localhost
|
||||
// This prevents issues when multiple services share the same internal IP
|
||||
if strings.HasPrefix(ip, "172.") || ip == "127.0.0.1" || ip == "::1" {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
if _, exists := clients[ip]; !exists {
|
||||
clients[ip] = &client{}
|
||||
}
|
||||
|
||||
cli := clients[ip]
|
||||
|
||||
// Reset count if more than a minute has passed
|
||||
if time.Since(cli.lastSeen) > time.Minute {
|
||||
cli.count = 0
|
||||
}
|
||||
|
||||
cli.count++
|
||||
cli.lastSeen = time.Now()
|
||||
|
||||
// Allow 500 requests per minute (increased for admin panels with many API calls)
|
||||
if cli.count > 500 {
|
||||
c.AbortWithStatusJSON(http.StatusTooManyRequests, gin.H{
|
||||
"error": "rate_limit_exceeded",
|
||||
"message": "Too many requests. Please try again later.",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// AuthMiddleware validates JWT tokens
|
||||
func AuthMiddleware(jwtSecret string) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
if authHeader == "" {
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
|
||||
"error": "missing_authorization",
|
||||
"message": "Authorization header is required",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Extract token from "Bearer <token>"
|
||||
parts := strings.Split(authHeader, " ")
|
||||
if len(parts) != 2 || parts[0] != "Bearer" {
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
|
||||
"error": "invalid_authorization",
|
||||
"message": "Authorization header must be in format: Bearer <token>",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
tokenString := parts[1]
|
||||
|
||||
// Parse and validate token
|
||||
token, err := jwt.ParseWithClaims(tokenString, &UserClaims{}, func(token *jwt.Token) (interface{}, error) {
|
||||
return []byte(jwtSecret), nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
|
||||
"error": "invalid_token",
|
||||
"message": "Invalid or expired token",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if claims, ok := token.Claims.(*UserClaims); ok && token.Valid {
|
||||
// Set user info in context
|
||||
c.Set("user_id", claims.UserID)
|
||||
c.Set("email", claims.Email)
|
||||
c.Set("role", claims.Role)
|
||||
c.Next()
|
||||
} else {
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
|
||||
"error": "invalid_claims",
|
||||
"message": "Invalid token claims",
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// AdminOnly ensures only admin users can access the route
|
||||
func AdminOnly() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
role, exists := c.Get("role")
|
||||
if !exists {
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
|
||||
"error": "unauthorized",
|
||||
"message": "User role not found",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
roleStr, ok := role.(string)
|
||||
if !ok || (roleStr != "admin" && roleStr != "super_admin" && roleStr != "data_protection_officer") {
|
||||
c.AbortWithStatusJSON(http.StatusForbidden, gin.H{
|
||||
"error": "forbidden",
|
||||
"message": "Admin access required",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// DSBOnly ensures only Data Protection Officers can access the route
|
||||
// Used for critical operations like publishing legal documents (four-eyes principle)
|
||||
func DSBOnly() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
role, exists := c.Get("role")
|
||||
if !exists {
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
|
||||
"error": "unauthorized",
|
||||
"message": "User role not found",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
roleStr, ok := role.(string)
|
||||
if !ok || (roleStr != "data_protection_officer" && roleStr != "super_admin") {
|
||||
c.AbortWithStatusJSON(http.StatusForbidden, gin.H{
|
||||
"error": "forbidden",
|
||||
"message": "Only Data Protection Officers can perform this action",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// IsAdmin checks if the user has admin role
|
||||
func IsAdmin(c *gin.Context) bool {
|
||||
role, exists := c.Get("role")
|
||||
if !exists {
|
||||
return false
|
||||
}
|
||||
roleStr, ok := role.(string)
|
||||
return ok && (roleStr == "admin" || roleStr == "super_admin" || roleStr == "data_protection_officer")
|
||||
}
|
||||
|
||||
// IsDSB checks if the user has DSB role
|
||||
func IsDSB(c *gin.Context) bool {
|
||||
role, exists := c.Get("role")
|
||||
if !exists {
|
||||
return false
|
||||
}
|
||||
roleStr, ok := role.(string)
|
||||
return ok && (roleStr == "data_protection_officer" || roleStr == "super_admin")
|
||||
}
|
||||
|
||||
// GetUserID extracts the user ID from the context
|
||||
func GetUserID(c *gin.Context) (uuid.UUID, error) {
|
||||
userIDStr, exists := c.Get("user_id")
|
||||
if !exists {
|
||||
return uuid.Nil, nil
|
||||
}
|
||||
|
||||
userID, err := uuid.Parse(userIDStr.(string))
|
||||
if err != nil {
|
||||
return uuid.Nil, err
|
||||
}
|
||||
|
||||
return userID, nil
|
||||
}
|
||||
|
||||
// GetClientIP returns the client's IP address
|
||||
func GetClientIP(c *gin.Context) string {
|
||||
// Check X-Forwarded-For header first (for proxied requests)
|
||||
if xff := c.GetHeader("X-Forwarded-For"); xff != "" {
|
||||
ips := strings.Split(xff, ",")
|
||||
return strings.TrimSpace(ips[0])
|
||||
}
|
||||
|
||||
// Check X-Real-IP header
|
||||
if xri := c.GetHeader("X-Real-IP"); xri != "" {
|
||||
return xri
|
||||
}
|
||||
|
||||
return c.ClientIP()
|
||||
}
|
||||
|
||||
// GetUserAgent returns the client's User-Agent
|
||||
func GetUserAgent(c *gin.Context) string {
|
||||
return c.GetHeader("User-Agent")
|
||||
}
|
||||
|
||||
// SuspensionCheckMiddleware checks if a user is suspended and restricts access
|
||||
// Suspended users can only access consent-related endpoints
|
||||
func SuspensionCheckMiddleware(pool interface{ QueryRow(ctx interface{}, sql string, args ...interface{}) interface{ Scan(dest ...interface{}) error } }) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
userIDStr, exists := c.Get("user_id")
|
||||
if !exists {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
userID, err := uuid.Parse(userIDStr.(string))
|
||||
if err != nil {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
// Check user account status
|
||||
var accountStatus string
|
||||
err = pool.QueryRow(c.Request.Context(), `SELECT account_status FROM users WHERE id = $1`, userID).Scan(&accountStatus)
|
||||
if err != nil {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
if accountStatus == "suspended" {
|
||||
// Check if current path is allowed for suspended users
|
||||
path := c.Request.URL.Path
|
||||
allowedPaths := []string{
|
||||
"/api/v1/consent",
|
||||
"/api/v1/documents",
|
||||
"/api/v1/notifications",
|
||||
"/api/v1/profile",
|
||||
"/api/v1/privacy/my-data",
|
||||
"/api/v1/auth/logout",
|
||||
}
|
||||
|
||||
allowed := false
|
||||
for _, p := range allowedPaths {
|
||||
if strings.HasPrefix(path, p) {
|
||||
allowed = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !allowed {
|
||||
c.AbortWithStatusJSON(http.StatusForbidden, gin.H{
|
||||
"error": "account_suspended",
|
||||
"message": "Your account is suspended due to pending consent requirements",
|
||||
"redirect": "/consent/pending",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Set suspended flag in context for handlers to use
|
||||
c.Set("account_suspended", true)
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// IsSuspended checks if the current user's account is suspended
|
||||
func IsSuspended(c *gin.Context) bool {
|
||||
suspended, exists := c.Get("account_suspended")
|
||||
if !exists {
|
||||
return false
|
||||
}
|
||||
return suspended.(bool)
|
||||
}
|
||||
546
consent-service/internal/middleware/middleware_test.go
Normal file
546
consent-service/internal/middleware/middleware_test.go
Normal file
@@ -0,0 +1,546 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
func init() {
|
||||
gin.SetMode(gin.TestMode)
|
||||
}
|
||||
|
||||
// Helper to create a valid JWT token for testing
|
||||
func createTestToken(secret string, userID, email, role string, exp time.Time) string {
|
||||
claims := UserClaims{
|
||||
UserID: userID,
|
||||
Email: email,
|
||||
Role: role,
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
ExpiresAt: jwt.NewNumericDate(exp),
|
||||
IssuedAt: jwt.NewNumericDate(time.Now()),
|
||||
},
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
tokenString, _ := token.SignedString([]byte(secret))
|
||||
return tokenString
|
||||
}
|
||||
|
||||
// TestCORS tests the CORS middleware
|
||||
func TestCORS(t *testing.T) {
|
||||
router := gin.New()
|
||||
router.Use(CORS())
|
||||
router.GET("/test", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"success": true})
|
||||
})
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
origin string
|
||||
method string
|
||||
expectedStatus int
|
||||
expectAllowedOrigin bool
|
||||
}{
|
||||
{"localhost:3000", "http://localhost:3000", "GET", http.StatusOK, true},
|
||||
{"localhost:8000", "http://localhost:8000", "GET", http.StatusOK, true},
|
||||
{"production", "https://breakpilot.app", "GET", http.StatusOK, true},
|
||||
{"unknown origin", "https://unknown.com", "GET", http.StatusOK, false},
|
||||
{"preflight", "http://localhost:3000", "OPTIONS", http.StatusNoContent, true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req, _ := http.NewRequest(tt.method, "/test", nil)
|
||||
req.Header.Set("Origin", tt.origin)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != tt.expectedStatus {
|
||||
t.Errorf("Expected status %d, got %d", tt.expectedStatus, w.Code)
|
||||
}
|
||||
|
||||
allowedOrigin := w.Header().Get("Access-Control-Allow-Origin")
|
||||
if tt.expectAllowedOrigin && allowedOrigin != tt.origin {
|
||||
t.Errorf("Expected Access-Control-Allow-Origin to be %s, got %s", tt.origin, allowedOrigin)
|
||||
}
|
||||
if !tt.expectAllowedOrigin && allowedOrigin != "" {
|
||||
t.Errorf("Expected no Access-Control-Allow-Origin header, got %s", allowedOrigin)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestCORSHeaders tests that CORS headers are set correctly
|
||||
func TestCORSHeaders(t *testing.T) {
|
||||
router := gin.New()
|
||||
router.Use(CORS())
|
||||
router.GET("/test", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{})
|
||||
})
|
||||
|
||||
req, _ := http.NewRequest("GET", "/test", nil)
|
||||
req.Header.Set("Origin", "http://localhost:3000")
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
expectedHeaders := map[string]string{
|
||||
"Access-Control-Allow-Methods": "GET, POST, PUT, DELETE, OPTIONS",
|
||||
"Access-Control-Allow-Headers": "Origin, Content-Type, Authorization, X-Requested-With",
|
||||
"Access-Control-Allow-Credentials": "true",
|
||||
"Access-Control-Max-Age": "86400",
|
||||
}
|
||||
|
||||
for header, expected := range expectedHeaders {
|
||||
actual := w.Header().Get(header)
|
||||
if actual != expected {
|
||||
t.Errorf("Expected %s to be %s, got %s", header, expected, actual)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestAuthMiddleware_ValidToken tests authentication with valid token
|
||||
func TestAuthMiddleware_ValidToken(t *testing.T) {
|
||||
secret := "test-secret-key"
|
||||
userID := uuid.New().String()
|
||||
email := "test@example.com"
|
||||
role := "user"
|
||||
|
||||
router := gin.New()
|
||||
router.Use(AuthMiddleware(secret))
|
||||
router.GET("/protected", func(c *gin.Context) {
|
||||
uid, _ := c.Get("user_id")
|
||||
em, _ := c.Get("email")
|
||||
r, _ := c.Get("role")
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"user_id": uid,
|
||||
"email": em,
|
||||
"role": r,
|
||||
})
|
||||
})
|
||||
|
||||
token := createTestToken(secret, userID, email, role, time.Now().Add(time.Hour))
|
||||
|
||||
req, _ := http.NewRequest("GET", "/protected", nil)
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Expected status %d, got %d", http.StatusOK, w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// TestAuthMiddleware_MissingHeader tests authentication without header
|
||||
func TestAuthMiddleware_MissingHeader(t *testing.T) {
|
||||
router := gin.New()
|
||||
router.Use(AuthMiddleware("test-secret"))
|
||||
router.GET("/protected", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{})
|
||||
})
|
||||
|
||||
req, _ := http.NewRequest("GET", "/protected", nil)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusUnauthorized {
|
||||
t.Errorf("Expected status %d, got %d", http.StatusUnauthorized, w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// TestAuthMiddleware_InvalidFormat tests authentication with invalid header format
|
||||
func TestAuthMiddleware_InvalidFormat(t *testing.T) {
|
||||
router := gin.New()
|
||||
router.Use(AuthMiddleware("test-secret"))
|
||||
router.GET("/protected", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{})
|
||||
})
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
header string
|
||||
}{
|
||||
{"no Bearer prefix", "some-token"},
|
||||
{"Basic auth", "Basic dXNlcjpwYXNz"},
|
||||
{"empty Bearer", "Bearer "},
|
||||
{"multiple spaces", "Bearer token"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req, _ := http.NewRequest("GET", "/protected", nil)
|
||||
req.Header.Set("Authorization", tt.header)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusUnauthorized {
|
||||
t.Errorf("Expected status %d, got %d", http.StatusUnauthorized, w.Code)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestAuthMiddleware_ExpiredToken tests authentication with expired token
|
||||
func TestAuthMiddleware_ExpiredToken(t *testing.T) {
|
||||
secret := "test-secret"
|
||||
|
||||
router := gin.New()
|
||||
router.Use(AuthMiddleware(secret))
|
||||
router.GET("/protected", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{})
|
||||
})
|
||||
|
||||
// Create expired token
|
||||
token := createTestToken(secret, "user-123", "test@example.com", "user", time.Now().Add(-time.Hour))
|
||||
|
||||
req, _ := http.NewRequest("GET", "/protected", nil)
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusUnauthorized {
|
||||
t.Errorf("Expected status %d, got %d", http.StatusUnauthorized, w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// TestAuthMiddleware_WrongSecret tests authentication with wrong secret
|
||||
func TestAuthMiddleware_WrongSecret(t *testing.T) {
|
||||
router := gin.New()
|
||||
router.Use(AuthMiddleware("correct-secret"))
|
||||
router.GET("/protected", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{})
|
||||
})
|
||||
|
||||
// Create token with different secret
|
||||
token := createTestToken("wrong-secret", "user-123", "test@example.com", "user", time.Now().Add(time.Hour))
|
||||
|
||||
req, _ := http.NewRequest("GET", "/protected", nil)
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusUnauthorized {
|
||||
t.Errorf("Expected status %d, got %d", http.StatusUnauthorized, w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// TestAdminOnly tests the AdminOnly middleware
|
||||
func TestAdminOnly(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
role string
|
||||
expectedStatus int
|
||||
}{
|
||||
{"admin allowed", "admin", http.StatusOK},
|
||||
{"super_admin allowed", "super_admin", http.StatusOK},
|
||||
{"dpo allowed", "data_protection_officer", http.StatusOK},
|
||||
{"user forbidden", "user", http.StatusForbidden},
|
||||
{"empty role forbidden", "", http.StatusForbidden},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
router := gin.New()
|
||||
router.Use(func(c *gin.Context) {
|
||||
c.Set("role", tt.role)
|
||||
c.Next()
|
||||
})
|
||||
router.Use(AdminOnly())
|
||||
router.GET("/admin", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{})
|
||||
})
|
||||
|
||||
req, _ := http.NewRequest("GET", "/admin", nil)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != tt.expectedStatus {
|
||||
t.Errorf("Expected status %d, got %d", tt.expectedStatus, w.Code)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestAdminOnly_NoRole tests AdminOnly when role is not set
|
||||
func TestAdminOnly_NoRole(t *testing.T) {
|
||||
router := gin.New()
|
||||
router.Use(AdminOnly())
|
||||
router.GET("/admin", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{})
|
||||
})
|
||||
|
||||
req, _ := http.NewRequest("GET", "/admin", nil)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusUnauthorized {
|
||||
t.Errorf("Expected status %d, got %d", http.StatusUnauthorized, w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// TestDSBOnly tests the DSBOnly middleware
|
||||
func TestDSBOnly(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
role string
|
||||
expectedStatus int
|
||||
}{
|
||||
{"dpo allowed", "data_protection_officer", http.StatusOK},
|
||||
{"super_admin allowed", "super_admin", http.StatusOK},
|
||||
{"admin forbidden", "admin", http.StatusForbidden},
|
||||
{"user forbidden", "user", http.StatusForbidden},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
router := gin.New()
|
||||
router.Use(func(c *gin.Context) {
|
||||
c.Set("role", tt.role)
|
||||
c.Next()
|
||||
})
|
||||
router.Use(DSBOnly())
|
||||
router.GET("/dsb", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{})
|
||||
})
|
||||
|
||||
req, _ := http.NewRequest("GET", "/dsb", nil)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != tt.expectedStatus {
|
||||
t.Errorf("Expected status %d, got %d", tt.expectedStatus, w.Code)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestIsAdmin tests the IsAdmin helper function
|
||||
func TestIsAdmin(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
role string
|
||||
expected bool
|
||||
}{
|
||||
{"admin", "admin", true},
|
||||
{"super_admin", "super_admin", true},
|
||||
{"dpo", "data_protection_officer", true},
|
||||
{"user", "user", false},
|
||||
{"empty", "", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
if tt.role != "" {
|
||||
c.Set("role", tt.role)
|
||||
}
|
||||
|
||||
result := IsAdmin(c)
|
||||
if result != tt.expected {
|
||||
t.Errorf("Expected IsAdmin to be %v, got %v", tt.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestIsDSB tests the IsDSB helper function
|
||||
func TestIsDSB(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
role string
|
||||
expected bool
|
||||
}{
|
||||
{"dpo", "data_protection_officer", true},
|
||||
{"super_admin", "super_admin", true},
|
||||
{"admin", "admin", false},
|
||||
{"user", "user", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Set("role", tt.role)
|
||||
|
||||
result := IsDSB(c)
|
||||
if result != tt.expected {
|
||||
t.Errorf("Expected IsDSB to be %v, got %v", tt.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestGetUserID tests the GetUserID helper function
|
||||
func TestGetUserID(t *testing.T) {
|
||||
validUUID := uuid.New()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
userID string
|
||||
setUserID bool
|
||||
expectError bool
|
||||
expectedID uuid.UUID
|
||||
}{
|
||||
{"valid UUID", validUUID.String(), true, false, validUUID},
|
||||
{"invalid UUID", "not-a-uuid", true, true, uuid.Nil},
|
||||
{"missing user_id", "", false, false, uuid.Nil},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
if tt.setUserID {
|
||||
c.Set("user_id", tt.userID)
|
||||
}
|
||||
|
||||
result, err := GetUserID(c)
|
||||
|
||||
if tt.expectError && err == nil {
|
||||
t.Error("Expected error but got none")
|
||||
}
|
||||
|
||||
if !tt.expectError && result != tt.expectedID {
|
||||
t.Errorf("Expected %v, got %v", tt.expectedID, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestGetClientIP tests the GetClientIP helper function
|
||||
func TestGetClientIP(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
xff string
|
||||
xri string
|
||||
clientIP string
|
||||
expectedIP string
|
||||
}{
|
||||
{"X-Forwarded-For", "10.0.0.1", "", "192.168.1.1", "10.0.0.1"},
|
||||
{"X-Forwarded-For multiple", "10.0.0.1, 10.0.0.2", "", "192.168.1.1", "10.0.0.1"},
|
||||
{"X-Real-IP", "", "10.0.0.1", "192.168.1.1", "10.0.0.1"},
|
||||
{"direct", "", "", "192.168.1.1", "192.168.1.1"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
c.Request, _ = http.NewRequest("GET", "/", nil)
|
||||
if tt.xff != "" {
|
||||
c.Request.Header.Set("X-Forwarded-For", tt.xff)
|
||||
}
|
||||
if tt.xri != "" {
|
||||
c.Request.Header.Set("X-Real-IP", tt.xri)
|
||||
}
|
||||
c.Request.RemoteAddr = tt.clientIP + ":12345"
|
||||
|
||||
result := GetClientIP(c)
|
||||
|
||||
// Note: gin.ClientIP() might return different values
|
||||
// depending on trusted proxies config
|
||||
if result != tt.expectedIP && result != tt.clientIP {
|
||||
t.Logf("Note: GetClientIP returned %s (expected %s or %s)", result, tt.expectedIP, tt.clientIP)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestGetUserAgent tests the GetUserAgent helper function
|
||||
func TestGetUserAgent(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request, _ = http.NewRequest("GET", "/", nil)
|
||||
|
||||
expectedUA := "Mozilla/5.0 (Test)"
|
||||
c.Request.Header.Set("User-Agent", expectedUA)
|
||||
|
||||
result := GetUserAgent(c)
|
||||
if result != expectedUA {
|
||||
t.Errorf("Expected %s, got %s", expectedUA, result)
|
||||
}
|
||||
}
|
||||
|
||||
// TestIsSuspended tests the IsSuspended helper function
|
||||
func TestIsSuspended(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
suspended interface{}
|
||||
setSuspended bool
|
||||
expected bool
|
||||
}{
|
||||
{"suspended true", true, true, true},
|
||||
{"suspended false", false, true, false},
|
||||
{"not set", nil, false, false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
if tt.setSuspended {
|
||||
c.Set("account_suspended", tt.suspended)
|
||||
}
|
||||
|
||||
result := IsSuspended(c)
|
||||
if result != tt.expected {
|
||||
t.Errorf("Expected %v, got %v", tt.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkCORS benchmarks the CORS middleware
|
||||
func BenchmarkCORS(b *testing.B) {
|
||||
router := gin.New()
|
||||
router.Use(CORS())
|
||||
router.GET("/test", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{})
|
||||
})
|
||||
|
||||
req, _ := http.NewRequest("GET", "/test", nil)
|
||||
req.Header.Set("Origin", "http://localhost:3000")
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkAuthMiddleware benchmarks the auth middleware
|
||||
func BenchmarkAuthMiddleware(b *testing.B) {
|
||||
secret := "test-secret-key"
|
||||
token := createTestToken(secret, uuid.New().String(), "test@example.com", "user", time.Now().Add(time.Hour))
|
||||
|
||||
router := gin.New()
|
||||
router.Use(AuthMiddleware(secret))
|
||||
router.GET("/test", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{})
|
||||
})
|
||||
|
||||
req, _ := http.NewRequest("GET", "/test", nil)
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
}
|
||||
}
|
||||
197
consent-service/internal/middleware/pii_redactor.go
Normal file
197
consent-service/internal/middleware/pii_redactor.go
Normal file
@@ -0,0 +1,197 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// PIIPattern defines a pattern for identifying PII.
|
||||
type PIIPattern struct {
|
||||
Name string
|
||||
Pattern *regexp.Regexp
|
||||
Replacement string
|
||||
}
|
||||
|
||||
// PIIRedactor redacts personally identifiable information from strings.
|
||||
type PIIRedactor struct {
|
||||
patterns []*PIIPattern
|
||||
}
|
||||
|
||||
// Pre-compiled patterns for common PII types
|
||||
var (
|
||||
emailPattern = regexp.MustCompile(`\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b`)
|
||||
ipv4Pattern = regexp.MustCompile(`\b(?:(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\b`)
|
||||
ipv6Pattern = regexp.MustCompile(`\b(?:[0-9a-fA-F]{1,4}:){7}[0-9a-fA-F]{1,4}\b`)
|
||||
phonePattern = regexp.MustCompile(`(?:\+49|0049)[\s.-]?\d{2,4}[\s.-]?\d{3,8}|\b0\d{2,4}[\s.-]?\d{3,8}\b`)
|
||||
ibanPattern = regexp.MustCompile(`(?i)\b[A-Z]{2}\d{2}[\s]?(?:\d{4}[\s]?){3,5}\d{1,4}\b`)
|
||||
uuidPattern = regexp.MustCompile(`(?i)\b[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}\b`)
|
||||
namePattern = regexp.MustCompile(`\b(?:Herr|Frau|Hr\.|Fr\.)\s+[A-ZÄÖÜ][a-zäöüß]+(?:\s+[A-ZÄÖÜ][a-zäöüß]+)?\b`)
|
||||
)
|
||||
|
||||
// DefaultPIIPatterns returns the default set of PII patterns.
|
||||
func DefaultPIIPatterns() []*PIIPattern {
|
||||
return []*PIIPattern{
|
||||
{Name: "email", Pattern: emailPattern, Replacement: "[EMAIL_REDACTED]"},
|
||||
{Name: "ip_v4", Pattern: ipv4Pattern, Replacement: "[IP_REDACTED]"},
|
||||
{Name: "ip_v6", Pattern: ipv6Pattern, Replacement: "[IP_REDACTED]"},
|
||||
{Name: "phone", Pattern: phonePattern, Replacement: "[PHONE_REDACTED]"},
|
||||
}
|
||||
}
|
||||
|
||||
// AllPIIPatterns returns all available PII patterns.
|
||||
func AllPIIPatterns() []*PIIPattern {
|
||||
return []*PIIPattern{
|
||||
{Name: "email", Pattern: emailPattern, Replacement: "[EMAIL_REDACTED]"},
|
||||
{Name: "ip_v4", Pattern: ipv4Pattern, Replacement: "[IP_REDACTED]"},
|
||||
{Name: "ip_v6", Pattern: ipv6Pattern, Replacement: "[IP_REDACTED]"},
|
||||
{Name: "phone", Pattern: phonePattern, Replacement: "[PHONE_REDACTED]"},
|
||||
{Name: "iban", Pattern: ibanPattern, Replacement: "[IBAN_REDACTED]"},
|
||||
{Name: "uuid", Pattern: uuidPattern, Replacement: "[UUID_REDACTED]"},
|
||||
{Name: "name", Pattern: namePattern, Replacement: "[NAME_REDACTED]"},
|
||||
}
|
||||
}
|
||||
|
||||
// NewPIIRedactor creates a new PII redactor with the given patterns.
|
||||
func NewPIIRedactor(patterns []*PIIPattern) *PIIRedactor {
|
||||
if patterns == nil {
|
||||
patterns = DefaultPIIPatterns()
|
||||
}
|
||||
return &PIIRedactor{patterns: patterns}
|
||||
}
|
||||
|
||||
// NewDefaultPIIRedactor creates a PII redactor with default patterns.
|
||||
func NewDefaultPIIRedactor() *PIIRedactor {
|
||||
return NewPIIRedactor(DefaultPIIPatterns())
|
||||
}
|
||||
|
||||
// Redact removes PII from the given text.
|
||||
func (r *PIIRedactor) Redact(text string) string {
|
||||
if text == "" {
|
||||
return text
|
||||
}
|
||||
|
||||
result := text
|
||||
for _, pattern := range r.patterns {
|
||||
result = pattern.Pattern.ReplaceAllString(result, pattern.Replacement)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// ContainsPII checks if the text contains any PII.
|
||||
func (r *PIIRedactor) ContainsPII(text string) bool {
|
||||
if text == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, pattern := range r.patterns {
|
||||
if pattern.Pattern.MatchString(text) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// PIIFinding represents a found PII instance.
|
||||
type PIIFinding struct {
|
||||
Type string
|
||||
Match string
|
||||
Start int
|
||||
End int
|
||||
}
|
||||
|
||||
// FindPII finds all PII in the text.
|
||||
func (r *PIIRedactor) FindPII(text string) []PIIFinding {
|
||||
if text == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
var findings []PIIFinding
|
||||
for _, pattern := range r.patterns {
|
||||
matches := pattern.Pattern.FindAllStringIndex(text, -1)
|
||||
for _, match := range matches {
|
||||
findings = append(findings, PIIFinding{
|
||||
Type: pattern.Name,
|
||||
Match: text[match[0]:match[1]],
|
||||
Start: match[0],
|
||||
End: match[1],
|
||||
})
|
||||
}
|
||||
}
|
||||
return findings
|
||||
}
|
||||
|
||||
// Default module-level redactor
|
||||
var defaultRedactor = NewDefaultPIIRedactor()
|
||||
|
||||
// RedactPII is a convenience function that uses the default redactor.
|
||||
func RedactPII(text string) string {
|
||||
return defaultRedactor.Redact(text)
|
||||
}
|
||||
|
||||
// ContainsPIIDefault checks if text contains PII using default patterns.
|
||||
func ContainsPIIDefault(text string) bool {
|
||||
return defaultRedactor.ContainsPII(text)
|
||||
}
|
||||
|
||||
// RedactMap redacts PII from all string values in a map.
|
||||
func RedactMap(data map[string]interface{}) map[string]interface{} {
|
||||
result := make(map[string]interface{})
|
||||
for key, value := range data {
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
result[key] = RedactPII(v)
|
||||
case map[string]interface{}:
|
||||
result[key] = RedactMap(v)
|
||||
case []interface{}:
|
||||
result[key] = redactSlice(v)
|
||||
default:
|
||||
result[key] = v
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func redactSlice(data []interface{}) []interface{} {
|
||||
result := make([]interface{}, len(data))
|
||||
for i, value := range data {
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
result[i] = RedactPII(v)
|
||||
case map[string]interface{}:
|
||||
result[i] = RedactMap(v)
|
||||
case []interface{}:
|
||||
result[i] = redactSlice(v)
|
||||
default:
|
||||
result[i] = v
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// SafeLogString creates a safe-to-log version of sensitive data.
|
||||
// Use this for logging user-related information.
|
||||
func SafeLogString(format string, args ...interface{}) string {
|
||||
// Convert args to strings and redact
|
||||
safeArgs := make([]interface{}, len(args))
|
||||
for i, arg := range args {
|
||||
switch v := arg.(type) {
|
||||
case string:
|
||||
safeArgs[i] = RedactPII(v)
|
||||
case error:
|
||||
safeArgs[i] = RedactPII(v.Error())
|
||||
default:
|
||||
safeArgs[i] = arg
|
||||
}
|
||||
}
|
||||
|
||||
// Note: We can't use fmt.Sprintf here due to the variadic nature
|
||||
// Instead, we redact the result
|
||||
result := format
|
||||
for _, arg := range safeArgs {
|
||||
if s, ok := arg.(string); ok {
|
||||
result = strings.Replace(result, "%s", s, 1)
|
||||
result = strings.Replace(result, "%v", s, 1)
|
||||
}
|
||||
}
|
||||
return RedactPII(result)
|
||||
}
|
||||
228
consent-service/internal/middleware/pii_redactor_test.go
Normal file
228
consent-service/internal/middleware/pii_redactor_test.go
Normal file
@@ -0,0 +1,228 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestPIIRedactor_RedactsEmail(t *testing.T) {
|
||||
redactor := NewDefaultPIIRedactor()
|
||||
|
||||
text := "User test@example.com logged in"
|
||||
result := redactor.Redact(text)
|
||||
|
||||
if result == text {
|
||||
t.Error("Email should have been redacted")
|
||||
}
|
||||
if result != "User [EMAIL_REDACTED] logged in" {
|
||||
t.Errorf("Unexpected result: %s", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPIIRedactor_RedactsIPv4(t *testing.T) {
|
||||
redactor := NewDefaultPIIRedactor()
|
||||
|
||||
text := "Request from 192.168.1.100"
|
||||
result := redactor.Redact(text)
|
||||
|
||||
if result == text {
|
||||
t.Error("IP should have been redacted")
|
||||
}
|
||||
if result != "Request from [IP_REDACTED]" {
|
||||
t.Errorf("Unexpected result: %s", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPIIRedactor_RedactsGermanPhone(t *testing.T) {
|
||||
redactor := NewDefaultPIIRedactor()
|
||||
|
||||
tests := []struct {
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{"+49 30 12345678", "[PHONE_REDACTED]"},
|
||||
{"0049 30 12345678", "[PHONE_REDACTED]"},
|
||||
{"030 12345678", "[PHONE_REDACTED]"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
result := redactor.Redact(tt.input)
|
||||
if result != tt.expected {
|
||||
t.Errorf("For input %q: expected %q, got %q", tt.input, tt.expected, result)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestPIIRedactor_RedactsMultiplePII(t *testing.T) {
|
||||
redactor := NewDefaultPIIRedactor()
|
||||
|
||||
text := "User test@example.com from 10.0.0.1"
|
||||
result := redactor.Redact(text)
|
||||
|
||||
if result != "User [EMAIL_REDACTED] from [IP_REDACTED]" {
|
||||
t.Errorf("Unexpected result: %s", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPIIRedactor_PreservesNonPIIText(t *testing.T) {
|
||||
redactor := NewDefaultPIIRedactor()
|
||||
|
||||
text := "User logged in successfully"
|
||||
result := redactor.Redact(text)
|
||||
|
||||
if result != text {
|
||||
t.Errorf("Text should be unchanged: got %s", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPIIRedactor_EmptyString(t *testing.T) {
|
||||
redactor := NewDefaultPIIRedactor()
|
||||
|
||||
result := redactor.Redact("")
|
||||
if result != "" {
|
||||
t.Error("Empty string should remain empty")
|
||||
}
|
||||
}
|
||||
|
||||
func TestContainsPII(t *testing.T) {
|
||||
redactor := NewDefaultPIIRedactor()
|
||||
|
||||
tests := []struct {
|
||||
input string
|
||||
expected bool
|
||||
}{
|
||||
{"test@example.com", true},
|
||||
{"192.168.1.1", true},
|
||||
{"+49 30 12345678", true},
|
||||
{"Hello World", false},
|
||||
{"", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
result := redactor.ContainsPII(tt.input)
|
||||
if result != tt.expected {
|
||||
t.Errorf("For input %q: expected %v, got %v", tt.input, tt.expected, result)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindPII(t *testing.T) {
|
||||
redactor := NewDefaultPIIRedactor()
|
||||
|
||||
text := "Email: test@example.com, IP: 10.0.0.1"
|
||||
findings := redactor.FindPII(text)
|
||||
|
||||
if len(findings) != 2 {
|
||||
t.Errorf("Expected 2 findings, got %d", len(findings))
|
||||
}
|
||||
|
||||
hasEmail := false
|
||||
hasIP := false
|
||||
for _, f := range findings {
|
||||
if f.Type == "email" {
|
||||
hasEmail = true
|
||||
}
|
||||
if f.Type == "ip_v4" {
|
||||
hasIP = true
|
||||
}
|
||||
}
|
||||
|
||||
if !hasEmail {
|
||||
t.Error("Should have found email")
|
||||
}
|
||||
if !hasIP {
|
||||
t.Error("Should have found IP")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRedactPII_GlobalFunction(t *testing.T) {
|
||||
text := "User test@example.com logged in"
|
||||
result := RedactPII(text)
|
||||
|
||||
if result == text {
|
||||
t.Error("Email should have been redacted")
|
||||
}
|
||||
}
|
||||
|
||||
func TestContainsPIIDefault(t *testing.T) {
|
||||
if !ContainsPIIDefault("test@example.com") {
|
||||
t.Error("Should detect email as PII")
|
||||
}
|
||||
if ContainsPIIDefault("Hello World") {
|
||||
t.Error("Should not detect non-PII text")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRedactMap(t *testing.T) {
|
||||
data := map[string]interface{}{
|
||||
"email": "test@example.com",
|
||||
"message": "Hello World",
|
||||
"nested": map[string]interface{}{
|
||||
"ip": "192.168.1.1",
|
||||
},
|
||||
}
|
||||
|
||||
result := RedactMap(data)
|
||||
|
||||
if result["email"] != "[EMAIL_REDACTED]" {
|
||||
t.Errorf("Email should be redacted: %v", result["email"])
|
||||
}
|
||||
if result["message"] != "Hello World" {
|
||||
t.Errorf("Non-PII should be unchanged: %v", result["message"])
|
||||
}
|
||||
|
||||
nested := result["nested"].(map[string]interface{})
|
||||
if nested["ip"] != "[IP_REDACTED]" {
|
||||
t.Errorf("Nested IP should be redacted: %v", nested["ip"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestAllPIIPatterns(t *testing.T) {
|
||||
patterns := AllPIIPatterns()
|
||||
|
||||
if len(patterns) == 0 {
|
||||
t.Error("Should have PII patterns")
|
||||
}
|
||||
|
||||
// Check that we have the expected patterns
|
||||
expectedNames := []string{"email", "ip_v4", "ip_v6", "phone", "iban", "uuid", "name"}
|
||||
nameMap := make(map[string]bool)
|
||||
for _, p := range patterns {
|
||||
nameMap[p.Name] = true
|
||||
}
|
||||
|
||||
for _, name := range expectedNames {
|
||||
if !nameMap[name] {
|
||||
t.Errorf("Missing expected pattern: %s", name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultPIIPatterns(t *testing.T) {
|
||||
patterns := DefaultPIIPatterns()
|
||||
|
||||
if len(patterns) != 4 {
|
||||
t.Errorf("Expected 4 default patterns, got %d", len(patterns))
|
||||
}
|
||||
}
|
||||
|
||||
func TestIBANRedaction(t *testing.T) {
|
||||
redactor := NewPIIRedactor(AllPIIPatterns())
|
||||
|
||||
text := "IBAN: DE89 3704 0044 0532 0130 00"
|
||||
result := redactor.Redact(text)
|
||||
|
||||
if result == text {
|
||||
t.Error("IBAN should have been redacted")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUUIDRedaction(t *testing.T) {
|
||||
redactor := NewPIIRedactor(AllPIIPatterns())
|
||||
|
||||
text := "User ID: a0000000-0000-0000-0000-000000000001"
|
||||
result := redactor.Redact(text)
|
||||
|
||||
if result == text {
|
||||
t.Error("UUID should have been redacted")
|
||||
}
|
||||
}
|
||||
75
consent-service/internal/middleware/request_id.go
Normal file
75
consent-service/internal/middleware/request_id.go
Normal file
@@ -0,0 +1,75 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
const (
|
||||
// RequestIDHeader is the primary header for request IDs
|
||||
RequestIDHeader = "X-Request-ID"
|
||||
// CorrelationIDHeader is an alternative header for distributed tracing
|
||||
CorrelationIDHeader = "X-Correlation-ID"
|
||||
// RequestIDKey is the context key for storing the request ID
|
||||
RequestIDKey = "request_id"
|
||||
)
|
||||
|
||||
// RequestID returns a middleware that generates and propagates request IDs.
|
||||
//
|
||||
// For each incoming request:
|
||||
// 1. Check for existing X-Request-ID or X-Correlation-ID header
|
||||
// 2. If not present, generate a new UUID
|
||||
// 3. Store in Gin context for use by handlers and logging
|
||||
// 4. Add to response headers
|
||||
//
|
||||
// Usage:
|
||||
//
|
||||
// r.Use(middleware.RequestID())
|
||||
//
|
||||
// func handler(c *gin.Context) {
|
||||
// requestID := middleware.GetRequestID(c)
|
||||
// log.Printf("[%s] Processing request", requestID)
|
||||
// }
|
||||
func RequestID() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// Try to get existing request ID from headers
|
||||
requestID := c.GetHeader(RequestIDHeader)
|
||||
if requestID == "" {
|
||||
requestID = c.GetHeader(CorrelationIDHeader)
|
||||
}
|
||||
|
||||
// Generate new ID if not provided
|
||||
if requestID == "" {
|
||||
requestID = uuid.New().String()
|
||||
}
|
||||
|
||||
// Store in context for handlers and logging
|
||||
c.Set(RequestIDKey, requestID)
|
||||
|
||||
// Add to response headers
|
||||
c.Header(RequestIDHeader, requestID)
|
||||
c.Header(CorrelationIDHeader, requestID)
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// GetRequestID retrieves the request ID from the Gin context.
|
||||
// Returns empty string if no request ID is set.
|
||||
//
|
||||
// Usage:
|
||||
//
|
||||
// requestID := middleware.GetRequestID(c)
|
||||
func GetRequestID(c *gin.Context) string {
|
||||
if id, exists := c.Get(RequestIDKey); exists {
|
||||
if idStr, ok := id.(string); ok {
|
||||
return idStr
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// RequestIDFromContext is an alias for GetRequestID for API compatibility.
|
||||
func RequestIDFromContext(c *gin.Context) string {
|
||||
return GetRequestID(c)
|
||||
}
|
||||
152
consent-service/internal/middleware/request_id_test.go
Normal file
152
consent-service/internal/middleware/request_id_test.go
Normal file
@@ -0,0 +1,152 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func TestRequestID_GeneratesNewID(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
router.Use(RequestID())
|
||||
|
||||
router.GET("/test", func(c *gin.Context) {
|
||||
requestID := GetRequestID(c)
|
||||
if requestID == "" {
|
||||
t.Error("Expected request ID to be set")
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"request_id": requestID})
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
// Check response header
|
||||
requestID := w.Header().Get(RequestIDHeader)
|
||||
if requestID == "" {
|
||||
t.Error("Expected X-Request-ID header in response")
|
||||
}
|
||||
|
||||
// Check correlation ID header
|
||||
correlationID := w.Header().Get(CorrelationIDHeader)
|
||||
if correlationID == "" {
|
||||
t.Error("Expected X-Correlation-ID header in response")
|
||||
}
|
||||
|
||||
if requestID != correlationID {
|
||||
t.Error("X-Request-ID and X-Correlation-ID should match")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequestID_PropagatesExistingID(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
router.Use(RequestID())
|
||||
|
||||
customID := "custom-request-id-12345"
|
||||
|
||||
router.GET("/test", func(c *gin.Context) {
|
||||
requestID := GetRequestID(c)
|
||||
if requestID != customID {
|
||||
t.Errorf("Expected request ID %s, got %s", customID, requestID)
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"request_id": requestID})
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
req.Header.Set(RequestIDHeader, customID)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
responseID := w.Header().Get(RequestIDHeader)
|
||||
if responseID != customID {
|
||||
t.Errorf("Expected response header %s, got %s", customID, responseID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequestID_PropagatesCorrelationID(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
router.Use(RequestID())
|
||||
|
||||
correlationID := "correlation-id-67890"
|
||||
|
||||
router.GET("/test", func(c *gin.Context) {
|
||||
requestID := GetRequestID(c)
|
||||
if requestID != correlationID {
|
||||
t.Errorf("Expected request ID %s, got %s", correlationID, requestID)
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{})
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
req.Header.Set(CorrelationIDHeader, correlationID)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
// Both headers should be set with the correlation ID
|
||||
if w.Header().Get(RequestIDHeader) != correlationID {
|
||||
t.Error("X-Request-ID should match X-Correlation-ID")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetRequestID_ReturnsEmptyWhenNotSet(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
|
||||
// No RequestID middleware
|
||||
router.GET("/test", func(c *gin.Context) {
|
||||
requestID := GetRequestID(c)
|
||||
if requestID != "" {
|
||||
t.Errorf("Expected empty request ID, got %s", requestID)
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{})
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Expected status 200, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequestIDFromContext_IsAliasForGetRequestID(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
router.Use(RequestID())
|
||||
|
||||
router.GET("/test", func(c *gin.Context) {
|
||||
id1 := GetRequestID(c)
|
||||
id2 := RequestIDFromContext(c)
|
||||
if id1 != id2 {
|
||||
t.Errorf("GetRequestID and RequestIDFromContext should return same value")
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{})
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Expected status 200, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
167
consent-service/internal/middleware/security_headers.go
Normal file
167
consent-service/internal/middleware/security_headers.go
Normal file
@@ -0,0 +1,167 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// SecurityHeadersConfig holds configuration for security headers.
|
||||
type SecurityHeadersConfig struct {
|
||||
// X-Content-Type-Options
|
||||
ContentTypeOptions string
|
||||
|
||||
// X-Frame-Options
|
||||
FrameOptions string
|
||||
|
||||
// X-XSS-Protection (legacy but useful for older browsers)
|
||||
XSSProtection string
|
||||
|
||||
// Strict-Transport-Security
|
||||
HSTSEnabled bool
|
||||
HSTSMaxAge int
|
||||
HSTSIncludeSubdomains bool
|
||||
HSTSPreload bool
|
||||
|
||||
// Content-Security-Policy
|
||||
CSPEnabled bool
|
||||
CSPPolicy string
|
||||
|
||||
// Referrer-Policy
|
||||
ReferrerPolicy string
|
||||
|
||||
// Permissions-Policy
|
||||
PermissionsPolicy string
|
||||
|
||||
// Cross-Origin headers
|
||||
CrossOriginOpenerPolicy string
|
||||
CrossOriginResourcePolicy string
|
||||
|
||||
// Development mode (relaxes some restrictions)
|
||||
DevelopmentMode bool
|
||||
|
||||
// Excluded paths (e.g., health checks)
|
||||
ExcludedPaths []string
|
||||
}
|
||||
|
||||
// DefaultSecurityHeadersConfig returns sensible default configuration.
|
||||
func DefaultSecurityHeadersConfig() SecurityHeadersConfig {
|
||||
env := os.Getenv("ENVIRONMENT")
|
||||
isDev := env == "" || strings.ToLower(env) == "development" || strings.ToLower(env) == "dev"
|
||||
|
||||
return SecurityHeadersConfig{
|
||||
ContentTypeOptions: "nosniff",
|
||||
FrameOptions: "DENY",
|
||||
XSSProtection: "1; mode=block",
|
||||
HSTSEnabled: true,
|
||||
HSTSMaxAge: 31536000, // 1 year
|
||||
HSTSIncludeSubdomains: true,
|
||||
HSTSPreload: false,
|
||||
CSPEnabled: true,
|
||||
CSPPolicy: getDefaultCSP(isDev),
|
||||
ReferrerPolicy: "strict-origin-when-cross-origin",
|
||||
PermissionsPolicy: "geolocation=(), microphone=(), camera=()",
|
||||
CrossOriginOpenerPolicy: "same-origin",
|
||||
CrossOriginResourcePolicy: "same-origin",
|
||||
DevelopmentMode: isDev,
|
||||
ExcludedPaths: []string{"/health", "/metrics", "/api/v1/health"},
|
||||
}
|
||||
}
|
||||
|
||||
// getDefaultCSP returns a sensible default CSP for the environment.
|
||||
func getDefaultCSP(isDevelopment bool) string {
|
||||
if isDevelopment {
|
||||
return "default-src 'self' localhost:* ws://localhost:*; " +
|
||||
"script-src 'self' 'unsafe-inline' 'unsafe-eval'; " +
|
||||
"style-src 'self' 'unsafe-inline'; " +
|
||||
"img-src 'self' data: https: blob:; " +
|
||||
"font-src 'self' data:; " +
|
||||
"connect-src 'self' localhost:* ws://localhost:* https:; " +
|
||||
"frame-ancestors 'self'"
|
||||
}
|
||||
return "default-src 'self'; " +
|
||||
"script-src 'self' 'unsafe-inline'; " +
|
||||
"style-src 'self' 'unsafe-inline'; " +
|
||||
"img-src 'self' data: https:; " +
|
||||
"font-src 'self' data:; " +
|
||||
"connect-src 'self' https://breakpilot.app https://*.breakpilot.app; " +
|
||||
"frame-ancestors 'none'"
|
||||
}
|
||||
|
||||
// buildHSTSHeader builds the Strict-Transport-Security header value.
|
||||
func (c *SecurityHeadersConfig) buildHSTSHeader() string {
|
||||
parts := []string{"max-age=" + strconv.Itoa(c.HSTSMaxAge)}
|
||||
if c.HSTSIncludeSubdomains {
|
||||
parts = append(parts, "includeSubDomains")
|
||||
}
|
||||
if c.HSTSPreload {
|
||||
parts = append(parts, "preload")
|
||||
}
|
||||
return strings.Join(parts, "; ")
|
||||
}
|
||||
|
||||
// isExcludedPath checks if the path should be excluded from security headers.
|
||||
func (c *SecurityHeadersConfig) isExcludedPath(path string) bool {
|
||||
for _, excluded := range c.ExcludedPaths {
|
||||
if path == excluded {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// SecurityHeaders returns a middleware that adds security headers to all responses.
|
||||
//
|
||||
// Usage:
|
||||
//
|
||||
// r.Use(middleware.SecurityHeaders())
|
||||
//
|
||||
// // Or with custom config:
|
||||
// config := middleware.DefaultSecurityHeadersConfig()
|
||||
// config.CSPPolicy = "default-src 'self'"
|
||||
// r.Use(middleware.SecurityHeadersWithConfig(config))
|
||||
func SecurityHeaders() gin.HandlerFunc {
|
||||
return SecurityHeadersWithConfig(DefaultSecurityHeadersConfig())
|
||||
}
|
||||
|
||||
// SecurityHeadersWithConfig returns a security headers middleware with custom configuration.
|
||||
func SecurityHeadersWithConfig(config SecurityHeadersConfig) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// Skip for excluded paths
|
||||
if config.isExcludedPath(c.Request.URL.Path) {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
// Always add these headers
|
||||
c.Header("X-Content-Type-Options", config.ContentTypeOptions)
|
||||
c.Header("X-Frame-Options", config.FrameOptions)
|
||||
c.Header("X-XSS-Protection", config.XSSProtection)
|
||||
c.Header("Referrer-Policy", config.ReferrerPolicy)
|
||||
|
||||
// HSTS (only in production or if explicitly enabled)
|
||||
if config.HSTSEnabled && !config.DevelopmentMode {
|
||||
c.Header("Strict-Transport-Security", config.buildHSTSHeader())
|
||||
}
|
||||
|
||||
// Content-Security-Policy
|
||||
if config.CSPEnabled && config.CSPPolicy != "" {
|
||||
c.Header("Content-Security-Policy", config.CSPPolicy)
|
||||
}
|
||||
|
||||
// Permissions-Policy
|
||||
if config.PermissionsPolicy != "" {
|
||||
c.Header("Permissions-Policy", config.PermissionsPolicy)
|
||||
}
|
||||
|
||||
// Cross-Origin headers (only in production)
|
||||
if !config.DevelopmentMode {
|
||||
c.Header("Cross-Origin-Opener-Policy", config.CrossOriginOpenerPolicy)
|
||||
c.Header("Cross-Origin-Resource-Policy", config.CrossOriginResourcePolicy)
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
377
consent-service/internal/middleware/security_headers_test.go
Normal file
377
consent-service/internal/middleware/security_headers_test.go
Normal file
@@ -0,0 +1,377 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func TestSecurityHeaders_AddsBasicHeaders(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
|
||||
config := DefaultSecurityHeadersConfig()
|
||||
config.DevelopmentMode = true // Skip HSTS and cross-origin headers
|
||||
router.Use(SecurityHeadersWithConfig(config))
|
||||
|
||||
router.GET("/test", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"status": "ok"})
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
// Check basic security headers
|
||||
tests := []struct {
|
||||
header string
|
||||
expected string
|
||||
}{
|
||||
{"X-Content-Type-Options", "nosniff"},
|
||||
{"X-Frame-Options", "DENY"},
|
||||
{"X-XSS-Protection", "1; mode=block"},
|
||||
{"Referrer-Policy", "strict-origin-when-cross-origin"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
value := w.Header().Get(tt.header)
|
||||
if value != tt.expected {
|
||||
t.Errorf("Header %s: expected %q, got %q", tt.header, tt.expected, value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSecurityHeaders_HSTSNotAddedInDevelopment(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
|
||||
config := DefaultSecurityHeadersConfig()
|
||||
config.DevelopmentMode = true
|
||||
config.HSTSEnabled = true
|
||||
router.Use(SecurityHeadersWithConfig(config))
|
||||
|
||||
router.GET("/test", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{})
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
hstsHeader := w.Header().Get("Strict-Transport-Security")
|
||||
if hstsHeader != "" {
|
||||
t.Errorf("HSTS should not be set in development mode, got: %s", hstsHeader)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSecurityHeaders_HSTSAddedInProduction(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
|
||||
config := DefaultSecurityHeadersConfig()
|
||||
config.DevelopmentMode = false
|
||||
config.HSTSEnabled = true
|
||||
config.HSTSMaxAge = 31536000
|
||||
config.HSTSIncludeSubdomains = true
|
||||
router.Use(SecurityHeadersWithConfig(config))
|
||||
|
||||
router.GET("/test", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{})
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
hstsHeader := w.Header().Get("Strict-Transport-Security")
|
||||
if hstsHeader == "" {
|
||||
t.Error("HSTS should be set in production mode")
|
||||
}
|
||||
|
||||
// Check that it contains max-age
|
||||
if hstsHeader != "max-age=31536000; includeSubDomains" {
|
||||
t.Errorf("Unexpected HSTS value: %s", hstsHeader)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSecurityHeaders_HSTSWithPreload(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
|
||||
config := DefaultSecurityHeadersConfig()
|
||||
config.DevelopmentMode = false
|
||||
config.HSTSEnabled = true
|
||||
config.HSTSMaxAge = 31536000
|
||||
config.HSTSIncludeSubdomains = true
|
||||
config.HSTSPreload = true
|
||||
router.Use(SecurityHeadersWithConfig(config))
|
||||
|
||||
router.GET("/test", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{})
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
hstsHeader := w.Header().Get("Strict-Transport-Security")
|
||||
expected := "max-age=31536000; includeSubDomains; preload"
|
||||
if hstsHeader != expected {
|
||||
t.Errorf("Expected HSTS %q, got %q", expected, hstsHeader)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSecurityHeaders_CSPHeader(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
|
||||
config := DefaultSecurityHeadersConfig()
|
||||
config.CSPEnabled = true
|
||||
config.CSPPolicy = "default-src 'self'"
|
||||
router.Use(SecurityHeadersWithConfig(config))
|
||||
|
||||
router.GET("/test", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{})
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
cspHeader := w.Header().Get("Content-Security-Policy")
|
||||
if cspHeader != "default-src 'self'" {
|
||||
t.Errorf("Expected CSP %q, got %q", "default-src 'self'", cspHeader)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSecurityHeaders_NoCSPWhenDisabled(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
|
||||
config := DefaultSecurityHeadersConfig()
|
||||
config.CSPEnabled = false
|
||||
router.Use(SecurityHeadersWithConfig(config))
|
||||
|
||||
router.GET("/test", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{})
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
cspHeader := w.Header().Get("Content-Security-Policy")
|
||||
if cspHeader != "" {
|
||||
t.Errorf("CSP should not be set when disabled, got: %s", cspHeader)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSecurityHeaders_ExcludedPaths(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
|
||||
config := DefaultSecurityHeadersConfig()
|
||||
config.ExcludedPaths = []string{"/health", "/metrics"}
|
||||
router.Use(SecurityHeadersWithConfig(config))
|
||||
|
||||
router.GET("/health", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"status": "healthy"})
|
||||
})
|
||||
|
||||
router.GET("/api", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"status": "ok"})
|
||||
})
|
||||
|
||||
// Test excluded path
|
||||
req := httptest.NewRequest(http.MethodGet, "/health", nil)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Header().Get("X-Content-Type-Options") != "" {
|
||||
t.Error("Security headers should not be set for excluded paths")
|
||||
}
|
||||
|
||||
// Test non-excluded path
|
||||
req = httptest.NewRequest(http.MethodGet, "/api", nil)
|
||||
w = httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Header().Get("X-Content-Type-Options") != "nosniff" {
|
||||
t.Error("Security headers should be set for non-excluded paths")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSecurityHeaders_CrossOriginInProduction(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
|
||||
config := DefaultSecurityHeadersConfig()
|
||||
config.DevelopmentMode = false
|
||||
config.CrossOriginOpenerPolicy = "same-origin"
|
||||
config.CrossOriginResourcePolicy = "same-origin"
|
||||
router.Use(SecurityHeadersWithConfig(config))
|
||||
|
||||
router.GET("/test", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{})
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
coopHeader := w.Header().Get("Cross-Origin-Opener-Policy")
|
||||
if coopHeader != "same-origin" {
|
||||
t.Errorf("Expected COOP %q, got %q", "same-origin", coopHeader)
|
||||
}
|
||||
|
||||
corpHeader := w.Header().Get("Cross-Origin-Resource-Policy")
|
||||
if corpHeader != "same-origin" {
|
||||
t.Errorf("Expected CORP %q, got %q", "same-origin", corpHeader)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSecurityHeaders_NoCrossOriginInDevelopment(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
|
||||
config := DefaultSecurityHeadersConfig()
|
||||
config.DevelopmentMode = true
|
||||
config.CrossOriginOpenerPolicy = "same-origin"
|
||||
config.CrossOriginResourcePolicy = "same-origin"
|
||||
router.Use(SecurityHeadersWithConfig(config))
|
||||
|
||||
router.GET("/test", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{})
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Header().Get("Cross-Origin-Opener-Policy") != "" {
|
||||
t.Error("COOP should not be set in development mode")
|
||||
}
|
||||
|
||||
if w.Header().Get("Cross-Origin-Resource-Policy") != "" {
|
||||
t.Error("CORP should not be set in development mode")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSecurityHeaders_PermissionsPolicy(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
|
||||
config := DefaultSecurityHeadersConfig()
|
||||
config.PermissionsPolicy = "geolocation=(), microphone=()"
|
||||
router.Use(SecurityHeadersWithConfig(config))
|
||||
|
||||
router.GET("/test", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{})
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
ppHeader := w.Header().Get("Permissions-Policy")
|
||||
if ppHeader != "geolocation=(), microphone=()" {
|
||||
t.Errorf("Expected Permissions-Policy %q, got %q", "geolocation=(), microphone=()", ppHeader)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSecurityHeaders_DefaultMiddleware(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
|
||||
// Use the default middleware function
|
||||
router.Use(SecurityHeaders())
|
||||
|
||||
router.GET("/test", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{})
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
// Should at least have the basic headers
|
||||
if w.Header().Get("X-Content-Type-Options") != "nosniff" {
|
||||
t.Error("Default middleware should set X-Content-Type-Options")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildHSTSHeader(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
config SecurityHeadersConfig
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "basic HSTS",
|
||||
config: SecurityHeadersConfig{
|
||||
HSTSMaxAge: 31536000,
|
||||
HSTSIncludeSubdomains: false,
|
||||
HSTSPreload: false,
|
||||
},
|
||||
expected: "max-age=31536000",
|
||||
},
|
||||
{
|
||||
name: "HSTS with subdomains",
|
||||
config: SecurityHeadersConfig{
|
||||
HSTSMaxAge: 31536000,
|
||||
HSTSIncludeSubdomains: true,
|
||||
HSTSPreload: false,
|
||||
},
|
||||
expected: "max-age=31536000; includeSubDomains",
|
||||
},
|
||||
{
|
||||
name: "HSTS with preload",
|
||||
config: SecurityHeadersConfig{
|
||||
HSTSMaxAge: 31536000,
|
||||
HSTSIncludeSubdomains: true,
|
||||
HSTSPreload: true,
|
||||
},
|
||||
expected: "max-age=31536000; includeSubDomains; preload",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := tt.config.buildHSTSHeader()
|
||||
if result != tt.expected {
|
||||
t.Errorf("Expected %q, got %q", tt.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsExcludedPath(t *testing.T) {
|
||||
config := SecurityHeadersConfig{
|
||||
ExcludedPaths: []string{"/health", "/metrics", "/api/v1/health"},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
path string
|
||||
excluded bool
|
||||
}{
|
||||
{"/health", true},
|
||||
{"/metrics", true},
|
||||
{"/api/v1/health", true},
|
||||
{"/api", false},
|
||||
{"/health/check", false},
|
||||
{"/", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
result := config.isExcludedPath(tt.path)
|
||||
if result != tt.excluded {
|
||||
t.Errorf("Path %s: expected excluded=%v, got %v", tt.path, tt.excluded, result)
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user