Initial commit: breakpilot-core - Shared Infrastructure

Docker Compose with 24+ services:
- PostgreSQL (PostGIS), Valkey, MinIO, Qdrant
- Vault (PKI/TLS), Nginx (Reverse Proxy)
- Backend Core API, Consent Service, Billing Service
- RAG Service, Embedding Service
- Gitea, Woodpecker CI/CD
- Night Scheduler, Health Aggregator
- Jitsi (Web/XMPP/JVB/Jicofo), Mailpit

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Benjamin Boenisch
2026-02-11 23:47:13 +01:00
commit ad111d5e69
244 changed files with 84288 additions and 0 deletions

View File

@@ -0,0 +1,247 @@
package middleware
import (
"net/http"
"os"
"strconv"
"strings"
"github.com/gin-gonic/gin"
)
// InputGateConfig holds configuration for input validation.
type InputGateConfig struct {
// Maximum request body size (default: 10MB)
MaxBodySize int64
// Maximum file upload size (default: 50MB)
MaxFileSize int64
// Allowed content types
AllowedContentTypes map[string]bool
// Allowed file types for uploads
AllowedFileTypes map[string]bool
// Blocked file extensions
BlockedExtensions map[string]bool
// Paths that allow larger uploads
LargeUploadPaths []string
// Paths excluded from validation
ExcludedPaths []string
// Enable strict content type checking
StrictContentType bool
}
// DefaultInputGateConfig returns sensible default configuration.
func DefaultInputGateConfig() InputGateConfig {
maxSize := int64(10 * 1024 * 1024) // 10MB
if envSize := os.Getenv("MAX_REQUEST_BODY_SIZE"); envSize != "" {
if size, err := strconv.ParseInt(envSize, 10, 64); err == nil {
maxSize = size
}
}
return InputGateConfig{
MaxBodySize: maxSize,
MaxFileSize: 50 * 1024 * 1024, // 50MB
AllowedContentTypes: map[string]bool{
"application/json": true,
"application/x-www-form-urlencoded": true,
"multipart/form-data": true,
"text/plain": true,
},
AllowedFileTypes: map[string]bool{
"image/jpeg": true,
"image/png": true,
"image/gif": true,
"image/webp": true,
"application/pdf": true,
"text/csv": true,
"application/msword": true,
"application/vnd.openxmlformats-officedocument.wordprocessingml.document": true,
"application/vnd.ms-excel": true,
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet": true,
},
BlockedExtensions: map[string]bool{
".exe": true, ".bat": true, ".cmd": true, ".com": true, ".msi": true,
".dll": true, ".scr": true, ".pif": true, ".vbs": true, ".js": true,
".jar": true, ".sh": true, ".ps1": true, ".app": true,
},
LargeUploadPaths: []string{
"/api/v1/files/upload",
"/api/v1/documents/upload",
"/api/v1/attachments",
},
ExcludedPaths: []string{
"/health",
"/metrics",
"/api/v1/health",
},
StrictContentType: true,
}
}
// isExcludedPath checks if path is excluded from validation.
func (c *InputGateConfig) isExcludedPath(path string) bool {
for _, excluded := range c.ExcludedPaths {
if path == excluded {
return true
}
}
return false
}
// isLargeUploadPath checks if path allows larger uploads.
func (c *InputGateConfig) isLargeUploadPath(path string) bool {
for _, uploadPath := range c.LargeUploadPaths {
if strings.HasPrefix(path, uploadPath) {
return true
}
}
return false
}
// getMaxSize returns the maximum allowed body size for the path.
func (c *InputGateConfig) getMaxSize(path string) int64 {
if c.isLargeUploadPath(path) {
return c.MaxFileSize
}
return c.MaxBodySize
}
// validateContentType validates the content type.
func (c *InputGateConfig) validateContentType(contentType string) (bool, string) {
if contentType == "" {
return true, ""
}
// Extract base content type (remove charset, boundary, etc.)
baseType := strings.Split(contentType, ";")[0]
baseType = strings.TrimSpace(strings.ToLower(baseType))
if !c.AllowedContentTypes[baseType] {
return false, "Content-Type '" + baseType + "' is not allowed"
}
return true, ""
}
// hasBlockedExtension checks if filename has a blocked extension.
func (c *InputGateConfig) hasBlockedExtension(filename string) bool {
if filename == "" {
return false
}
lowerFilename := strings.ToLower(filename)
for ext := range c.BlockedExtensions {
if strings.HasSuffix(lowerFilename, ext) {
return true
}
}
return false
}
// InputGate returns a middleware that validates incoming request bodies.
//
// Usage:
//
// r.Use(middleware.InputGate())
//
// // Or with custom config:
// config := middleware.DefaultInputGateConfig()
// config.MaxBodySize = 5 * 1024 * 1024 // 5MB
// r.Use(middleware.InputGateWithConfig(config))
func InputGate() gin.HandlerFunc {
return InputGateWithConfig(DefaultInputGateConfig())
}
// InputGateWithConfig returns an input gate middleware with custom configuration.
func InputGateWithConfig(config InputGateConfig) gin.HandlerFunc {
return func(c *gin.Context) {
// Skip excluded paths
if config.isExcludedPath(c.Request.URL.Path) {
c.Next()
return
}
// Skip validation for GET, HEAD, OPTIONS requests
method := c.Request.Method
if method == "GET" || method == "HEAD" || method == "OPTIONS" {
c.Next()
return
}
// Validate content type for requests with body
contentType := c.GetHeader("Content-Type")
if config.StrictContentType {
valid, errMsg := config.validateContentType(contentType)
if !valid {
c.AbortWithStatusJSON(http.StatusUnsupportedMediaType, gin.H{
"error": "unsupported_media_type",
"message": errMsg,
})
return
}
}
// Check Content-Length header
contentLength := c.GetHeader("Content-Length")
if contentLength != "" {
length, err := strconv.ParseInt(contentLength, 10, 64)
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{
"error": "invalid_content_length",
"message": "Invalid Content-Length header",
})
return
}
maxSize := config.getMaxSize(c.Request.URL.Path)
if length > maxSize {
c.AbortWithStatusJSON(http.StatusRequestEntityTooLarge, gin.H{
"error": "payload_too_large",
"message": "Request body exceeds maximum size",
"max_size": maxSize,
})
return
}
}
// Set max multipart memory for file uploads
if strings.Contains(contentType, "multipart/form-data") {
c.Request.Body = http.MaxBytesReader(c.Writer, c.Request.Body, config.MaxFileSize)
}
c.Next()
}
}
// ValidateFileUpload validates a file upload.
// Use this in upload handlers for detailed validation.
func ValidateFileUpload(filename, contentType string, size int64, config *InputGateConfig) (bool, string) {
if config == nil {
defaultConfig := DefaultInputGateConfig()
config = &defaultConfig
}
// Check size
if size > config.MaxFileSize {
return false, "File size exceeds maximum allowed"
}
// Check extension
if config.hasBlockedExtension(filename) {
return false, "File extension is not allowed"
}
// Check content type
if contentType != "" && !config.AllowedFileTypes[contentType] {
return false, "File type '" + contentType + "' is not allowed"
}
return true, ""
}

View File

@@ -0,0 +1,421 @@
package middleware
import (
"bytes"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/gin-gonic/gin"
)
func TestInputGate_AllowsGETRequest(t *testing.T) {
gin.SetMode(gin.TestMode)
router := gin.New()
router.Use(InputGate())
router.GET("/test", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"status": "ok"})
})
req := httptest.NewRequest(http.MethodGet, "/test", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("Expected status 200 for GET request, got %d", w.Code)
}
}
func TestInputGate_AllowsHEADRequest(t *testing.T) {
gin.SetMode(gin.TestMode)
router := gin.New()
router.Use(InputGate())
router.HEAD("/test", func(c *gin.Context) {
c.Status(http.StatusOK)
})
req := httptest.NewRequest(http.MethodHead, "/test", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("Expected status 200 for HEAD request, got %d", w.Code)
}
}
func TestInputGate_AllowsOPTIONSRequest(t *testing.T) {
gin.SetMode(gin.TestMode)
router := gin.New()
router.Use(InputGate())
router.OPTIONS("/test", func(c *gin.Context) {
c.Status(http.StatusOK)
})
req := httptest.NewRequest(http.MethodOptions, "/test", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("Expected status 200 for OPTIONS request, got %d", w.Code)
}
}
func TestInputGate_AllowsValidJSONRequest(t *testing.T) {
gin.SetMode(gin.TestMode)
router := gin.New()
router.Use(InputGate())
router.POST("/test", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"status": "ok"})
})
body := bytes.NewBufferString(`{"key": "value"}`)
req := httptest.NewRequest(http.MethodPost, "/test", body)
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Content-Length", "16")
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("Expected status 200 for valid JSON, got %d", w.Code)
}
}
func TestInputGate_RejectsInvalidContentType(t *testing.T) {
gin.SetMode(gin.TestMode)
router := gin.New()
config := DefaultInputGateConfig()
config.StrictContentType = true
router.Use(InputGateWithConfig(config))
router.POST("/test", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"status": "ok"})
})
body := bytes.NewBufferString(`data`)
req := httptest.NewRequest(http.MethodPost, "/test", body)
req.Header.Set("Content-Type", "application/xml")
req.Header.Set("Content-Length", "4")
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusUnsupportedMediaType {
t.Errorf("Expected status 415 for invalid content type, got %d", w.Code)
}
}
func TestInputGate_AllowsEmptyContentType(t *testing.T) {
gin.SetMode(gin.TestMode)
router := gin.New()
router.Use(InputGate())
router.POST("/test", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"status": "ok"})
})
body := bytes.NewBufferString(`data`)
req := httptest.NewRequest(http.MethodPost, "/test", body)
// No Content-Type header
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("Expected status 200 for empty content type, got %d", w.Code)
}
}
func TestInputGate_RejectsOversizedRequest(t *testing.T) {
gin.SetMode(gin.TestMode)
router := gin.New()
config := DefaultInputGateConfig()
config.MaxBodySize = 100 // 100 bytes
router.Use(InputGateWithConfig(config))
router.POST("/test", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"status": "ok"})
})
// Create a body larger than 100 bytes
largeBody := strings.Repeat("x", 200)
body := bytes.NewBufferString(largeBody)
req := httptest.NewRequest(http.MethodPost, "/test", body)
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Content-Length", "200")
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusRequestEntityTooLarge {
t.Errorf("Expected status 413 for oversized request, got %d", w.Code)
}
}
func TestInputGate_AllowsLargeUploadPath(t *testing.T) {
gin.SetMode(gin.TestMode)
router := gin.New()
config := DefaultInputGateConfig()
config.MaxBodySize = 100 // 100 bytes
config.MaxFileSize = 1000 // 1000 bytes
config.LargeUploadPaths = []string{"/api/v1/files/upload"}
router.Use(InputGateWithConfig(config))
router.POST("/api/v1/files/upload", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"status": "ok"})
})
// Create a body larger than MaxBodySize but smaller than MaxFileSize
largeBody := strings.Repeat("x", 500)
body := bytes.NewBufferString(largeBody)
req := httptest.NewRequest(http.MethodPost, "/api/v1/files/upload", body)
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Content-Length", "500")
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("Expected status 200 for large upload path, got %d", w.Code)
}
}
func TestInputGate_ExcludedPaths(t *testing.T) {
gin.SetMode(gin.TestMode)
router := gin.New()
config := DefaultInputGateConfig()
config.MaxBodySize = 10 // Very small
config.ExcludedPaths = []string{"/health"}
router.Use(InputGateWithConfig(config))
router.POST("/health", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"status": "healthy"})
})
// Send oversized body to excluded path
largeBody := strings.Repeat("x", 100)
body := bytes.NewBufferString(largeBody)
req := httptest.NewRequest(http.MethodPost, "/health", body)
req.Header.Set("Content-Length", "100")
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
// Should pass because path is excluded
if w.Code != http.StatusOK {
t.Errorf("Expected status 200 for excluded path, got %d", w.Code)
}
}
func TestInputGate_RejectsInvalidContentLength(t *testing.T) {
gin.SetMode(gin.TestMode)
router := gin.New()
router.Use(InputGate())
router.POST("/test", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"status": "ok"})
})
body := bytes.NewBufferString(`data`)
req := httptest.NewRequest(http.MethodPost, "/test", body)
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Content-Length", "invalid")
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusBadRequest {
t.Errorf("Expected status 400 for invalid content length, got %d", w.Code)
}
}
func TestValidateFileUpload_BlockedExtension(t *testing.T) {
tests := []struct {
filename string
contentType string
blocked bool
}{
{"malware.exe", "application/octet-stream", true},
{"script.bat", "application/octet-stream", true},
{"hack.cmd", "application/octet-stream", true},
{"shell.sh", "application/octet-stream", true},
{"powershell.ps1", "application/octet-stream", true},
{"document.pdf", "application/pdf", false},
{"image.jpg", "image/jpeg", false},
{"data.csv", "text/csv", false},
}
for _, tt := range tests {
valid, errMsg := ValidateFileUpload(tt.filename, tt.contentType, 100, nil)
if tt.blocked && valid {
t.Errorf("File %s should be blocked", tt.filename)
}
if !tt.blocked && !valid {
t.Errorf("File %s should not be blocked, error: %s", tt.filename, errMsg)
}
}
}
func TestValidateFileUpload_OversizedFile(t *testing.T) {
config := DefaultInputGateConfig()
config.MaxFileSize = 1000 // 1KB
valid, errMsg := ValidateFileUpload("test.pdf", "application/pdf", 2000, &config)
if valid {
t.Error("Should reject oversized file")
}
if !strings.Contains(errMsg, "size") {
t.Errorf("Error message should mention size, got: %s", errMsg)
}
}
func TestValidateFileUpload_ValidFile(t *testing.T) {
config := DefaultInputGateConfig()
valid, errMsg := ValidateFileUpload("document.pdf", "application/pdf", 1000, &config)
if !valid {
t.Errorf("Should accept valid file, got error: %s", errMsg)
}
}
func TestValidateFileUpload_InvalidContentType(t *testing.T) {
config := DefaultInputGateConfig()
valid, errMsg := ValidateFileUpload("file.xyz", "application/x-unknown", 100, &config)
if valid {
t.Error("Should reject unknown file type")
}
if !strings.Contains(errMsg, "not allowed") {
t.Errorf("Error message should mention not allowed, got: %s", errMsg)
}
}
func TestValidateFileUpload_NilConfig(t *testing.T) {
// Should use default config when nil is passed
valid, _ := ValidateFileUpload("document.pdf", "application/pdf", 1000, nil)
if !valid {
t.Error("Should accept valid file with nil config (uses defaults)")
}
}
func TestHasBlockedExtension(t *testing.T) {
config := DefaultInputGateConfig()
tests := []struct {
filename string
blocked bool
}{
{"test.exe", true},
{"TEST.EXE", true}, // Case insensitive
{"script.BAT", true},
{"app.APP", true},
{"document.pdf", false},
{"image.png", false},
{"", false},
}
for _, tt := range tests {
result := config.hasBlockedExtension(tt.filename)
if result != tt.blocked {
t.Errorf("File %s: expected blocked=%v, got %v", tt.filename, tt.blocked, result)
}
}
}
func TestValidateContentType(t *testing.T) {
config := DefaultInputGateConfig()
tests := []struct {
contentType string
valid bool
}{
{"application/json", true},
{"application/json; charset=utf-8", true},
{"APPLICATION/JSON", true}, // Case insensitive
{"multipart/form-data; boundary=----WebKitFormBoundary", true},
{"text/plain", true},
{"application/xml", false},
{"text/html", false},
{"", true}, // Empty is allowed
}
for _, tt := range tests {
valid, _ := config.validateContentType(tt.contentType)
if valid != tt.valid {
t.Errorf("Content-Type %q: expected valid=%v, got %v", tt.contentType, tt.valid, valid)
}
}
}
func TestIsLargeUploadPath(t *testing.T) {
config := DefaultInputGateConfig()
config.LargeUploadPaths = []string{"/api/v1/files/upload", "/api/v1/documents"}
tests := []struct {
path string
isLarge bool
}{
{"/api/v1/files/upload", true},
{"/api/v1/files/upload/batch", true}, // Prefix match
{"/api/v1/documents", true},
{"/api/v1/documents/1/attachments", true},
{"/api/v1/users", false},
{"/health", false},
}
for _, tt := range tests {
result := config.isLargeUploadPath(tt.path)
if result != tt.isLarge {
t.Errorf("Path %s: expected isLarge=%v, got %v", tt.path, tt.isLarge, result)
}
}
}
func TestGetMaxSize(t *testing.T) {
config := DefaultInputGateConfig()
config.MaxBodySize = 100
config.MaxFileSize = 1000
config.LargeUploadPaths = []string{"/api/v1/files/upload"}
tests := []struct {
path string
expected int64
}{
{"/api/test", 100},
{"/api/v1/files/upload", 1000},
{"/health", 100},
}
for _, tt := range tests {
result := config.getMaxSize(tt.path)
if result != tt.expected {
t.Errorf("Path %s: expected maxSize=%d, got %d", tt.path, tt.expected, result)
}
}
}
func TestInputGate_DefaultMiddleware(t *testing.T) {
gin.SetMode(gin.TestMode)
router := gin.New()
router.Use(InputGate())
router.POST("/test", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"status": "ok"})
})
body := bytes.NewBufferString(`{"key": "value"}`)
req := httptest.NewRequest(http.MethodPost, "/test", body)
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", w.Code)
}
}

View File

@@ -0,0 +1,379 @@
package middleware
import (
"net/http"
"strings"
"sync"
"time"
"github.com/gin-gonic/gin"
"github.com/golang-jwt/jwt/v5"
"github.com/google/uuid"
)
// UserClaims represents the JWT claims for a user
type UserClaims struct {
UserID string `json:"user_id"`
Email string `json:"email"`
Role string `json:"role"`
jwt.RegisteredClaims
}
// CORS returns a CORS middleware
func CORS() gin.HandlerFunc {
return func(c *gin.Context) {
origin := c.Request.Header.Get("Origin")
// Allow localhost for development
allowedOrigins := []string{
"http://localhost:3000",
"http://localhost:8000",
"http://localhost:8080",
"https://breakpilot.app",
}
allowed := false
for _, o := range allowedOrigins {
if origin == o {
allowed = true
break
}
}
if allowed {
c.Header("Access-Control-Allow-Origin", origin)
}
c.Header("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
c.Header("Access-Control-Allow-Headers", "Origin, Content-Type, Authorization, X-Requested-With")
c.Header("Access-Control-Allow-Credentials", "true")
c.Header("Access-Control-Max-Age", "86400")
if c.Request.Method == "OPTIONS" {
c.AbortWithStatus(http.StatusNoContent)
return
}
c.Next()
}
}
// RequestLogger logs each request
func RequestLogger() gin.HandlerFunc {
return func(c *gin.Context) {
start := time.Now()
path := c.Request.URL.Path
method := c.Request.Method
c.Next()
latency := time.Since(start)
status := c.Writer.Status()
// Log only in development or for errors
if status >= 400 {
gin.DefaultWriter.Write([]byte(
method + " " + path + " " +
string(rune(status)) + " " +
latency.String() + "\n",
))
}
}
}
// RateLimiter implements a simple in-memory rate limiter
// Configurable via RATE_LIMIT_PER_MINUTE env var (default: 500)
func RateLimiter() gin.HandlerFunc {
type client struct {
count int
lastSeen time.Time
}
var (
mu sync.Mutex
clients = make(map[string]*client)
)
// Clean up old entries periodically
go func() {
for {
time.Sleep(time.Minute)
mu.Lock()
for ip, c := range clients {
if time.Since(c.lastSeen) > time.Minute {
delete(clients, ip)
}
}
mu.Unlock()
}
}()
return func(c *gin.Context) {
ip := c.ClientIP()
// Skip rate limiting for Docker internal network (172.x.x.x) and localhost
// This prevents issues when multiple services share the same internal IP
if strings.HasPrefix(ip, "172.") || ip == "127.0.0.1" || ip == "::1" {
c.Next()
return
}
mu.Lock()
defer mu.Unlock()
if _, exists := clients[ip]; !exists {
clients[ip] = &client{}
}
cli := clients[ip]
// Reset count if more than a minute has passed
if time.Since(cli.lastSeen) > time.Minute {
cli.count = 0
}
cli.count++
cli.lastSeen = time.Now()
// Allow 500 requests per minute (increased for admin panels with many API calls)
if cli.count > 500 {
c.AbortWithStatusJSON(http.StatusTooManyRequests, gin.H{
"error": "rate_limit_exceeded",
"message": "Too many requests. Please try again later.",
})
return
}
c.Next()
}
}
// AuthMiddleware validates JWT tokens
func AuthMiddleware(jwtSecret string) gin.HandlerFunc {
return func(c *gin.Context) {
authHeader := c.GetHeader("Authorization")
if authHeader == "" {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
"error": "missing_authorization",
"message": "Authorization header is required",
})
return
}
// Extract token from "Bearer <token>"
parts := strings.Split(authHeader, " ")
if len(parts) != 2 || parts[0] != "Bearer" {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
"error": "invalid_authorization",
"message": "Authorization header must be in format: Bearer <token>",
})
return
}
tokenString := parts[1]
// Parse and validate token
token, err := jwt.ParseWithClaims(tokenString, &UserClaims{}, func(token *jwt.Token) (interface{}, error) {
return []byte(jwtSecret), nil
})
if err != nil {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
"error": "invalid_token",
"message": "Invalid or expired token",
})
return
}
if claims, ok := token.Claims.(*UserClaims); ok && token.Valid {
// Set user info in context
c.Set("user_id", claims.UserID)
c.Set("email", claims.Email)
c.Set("role", claims.Role)
c.Next()
} else {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
"error": "invalid_claims",
"message": "Invalid token claims",
})
return
}
}
}
// AdminOnly ensures only admin users can access the route
func AdminOnly() gin.HandlerFunc {
return func(c *gin.Context) {
role, exists := c.Get("role")
if !exists {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
"error": "unauthorized",
"message": "User role not found",
})
return
}
roleStr, ok := role.(string)
if !ok || (roleStr != "admin" && roleStr != "super_admin" && roleStr != "data_protection_officer") {
c.AbortWithStatusJSON(http.StatusForbidden, gin.H{
"error": "forbidden",
"message": "Admin access required",
})
return
}
c.Next()
}
}
// DSBOnly ensures only Data Protection Officers can access the route
// Used for critical operations like publishing legal documents (four-eyes principle)
func DSBOnly() gin.HandlerFunc {
return func(c *gin.Context) {
role, exists := c.Get("role")
if !exists {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
"error": "unauthorized",
"message": "User role not found",
})
return
}
roleStr, ok := role.(string)
if !ok || (roleStr != "data_protection_officer" && roleStr != "super_admin") {
c.AbortWithStatusJSON(http.StatusForbidden, gin.H{
"error": "forbidden",
"message": "Only Data Protection Officers can perform this action",
})
return
}
c.Next()
}
}
// IsAdmin checks if the user has admin role
func IsAdmin(c *gin.Context) bool {
role, exists := c.Get("role")
if !exists {
return false
}
roleStr, ok := role.(string)
return ok && (roleStr == "admin" || roleStr == "super_admin" || roleStr == "data_protection_officer")
}
// IsDSB checks if the user has DSB role
func IsDSB(c *gin.Context) bool {
role, exists := c.Get("role")
if !exists {
return false
}
roleStr, ok := role.(string)
return ok && (roleStr == "data_protection_officer" || roleStr == "super_admin")
}
// GetUserID extracts the user ID from the context
func GetUserID(c *gin.Context) (uuid.UUID, error) {
userIDStr, exists := c.Get("user_id")
if !exists {
return uuid.Nil, nil
}
userID, err := uuid.Parse(userIDStr.(string))
if err != nil {
return uuid.Nil, err
}
return userID, nil
}
// GetClientIP returns the client's IP address
func GetClientIP(c *gin.Context) string {
// Check X-Forwarded-For header first (for proxied requests)
if xff := c.GetHeader("X-Forwarded-For"); xff != "" {
ips := strings.Split(xff, ",")
return strings.TrimSpace(ips[0])
}
// Check X-Real-IP header
if xri := c.GetHeader("X-Real-IP"); xri != "" {
return xri
}
return c.ClientIP()
}
// GetUserAgent returns the client's User-Agent
func GetUserAgent(c *gin.Context) string {
return c.GetHeader("User-Agent")
}
// SuspensionCheckMiddleware checks if a user is suspended and restricts access
// Suspended users can only access consent-related endpoints
func SuspensionCheckMiddleware(pool interface{ QueryRow(ctx interface{}, sql string, args ...interface{}) interface{ Scan(dest ...interface{}) error } }) gin.HandlerFunc {
return func(c *gin.Context) {
userIDStr, exists := c.Get("user_id")
if !exists {
c.Next()
return
}
userID, err := uuid.Parse(userIDStr.(string))
if err != nil {
c.Next()
return
}
// Check user account status
var accountStatus string
err = pool.QueryRow(c.Request.Context(), `SELECT account_status FROM users WHERE id = $1`, userID).Scan(&accountStatus)
if err != nil {
c.Next()
return
}
if accountStatus == "suspended" {
// Check if current path is allowed for suspended users
path := c.Request.URL.Path
allowedPaths := []string{
"/api/v1/consent",
"/api/v1/documents",
"/api/v1/notifications",
"/api/v1/profile",
"/api/v1/privacy/my-data",
"/api/v1/auth/logout",
}
allowed := false
for _, p := range allowedPaths {
if strings.HasPrefix(path, p) {
allowed = true
break
}
}
if !allowed {
c.AbortWithStatusJSON(http.StatusForbidden, gin.H{
"error": "account_suspended",
"message": "Your account is suspended due to pending consent requirements",
"redirect": "/consent/pending",
})
return
}
// Set suspended flag in context for handlers to use
c.Set("account_suspended", true)
}
c.Next()
}
}
// IsSuspended checks if the current user's account is suspended
func IsSuspended(c *gin.Context) bool {
suspended, exists := c.Get("account_suspended")
if !exists {
return false
}
return suspended.(bool)
}

View File

@@ -0,0 +1,546 @@
package middleware
import (
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/gin-gonic/gin"
"github.com/golang-jwt/jwt/v5"
"github.com/google/uuid"
)
func init() {
gin.SetMode(gin.TestMode)
}
// Helper to create a valid JWT token for testing
func createTestToken(secret string, userID, email, role string, exp time.Time) string {
claims := UserClaims{
UserID: userID,
Email: email,
Role: role,
RegisteredClaims: jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(exp),
IssuedAt: jwt.NewNumericDate(time.Now()),
},
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
tokenString, _ := token.SignedString([]byte(secret))
return tokenString
}
// TestCORS tests the CORS middleware
func TestCORS(t *testing.T) {
router := gin.New()
router.Use(CORS())
router.GET("/test", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"success": true})
})
tests := []struct {
name string
origin string
method string
expectedStatus int
expectAllowedOrigin bool
}{
{"localhost:3000", "http://localhost:3000", "GET", http.StatusOK, true},
{"localhost:8000", "http://localhost:8000", "GET", http.StatusOK, true},
{"production", "https://breakpilot.app", "GET", http.StatusOK, true},
{"unknown origin", "https://unknown.com", "GET", http.StatusOK, false},
{"preflight", "http://localhost:3000", "OPTIONS", http.StatusNoContent, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req, _ := http.NewRequest(tt.method, "/test", nil)
req.Header.Set("Origin", tt.origin)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != tt.expectedStatus {
t.Errorf("Expected status %d, got %d", tt.expectedStatus, w.Code)
}
allowedOrigin := w.Header().Get("Access-Control-Allow-Origin")
if tt.expectAllowedOrigin && allowedOrigin != tt.origin {
t.Errorf("Expected Access-Control-Allow-Origin to be %s, got %s", tt.origin, allowedOrigin)
}
if !tt.expectAllowedOrigin && allowedOrigin != "" {
t.Errorf("Expected no Access-Control-Allow-Origin header, got %s", allowedOrigin)
}
})
}
}
// TestCORSHeaders tests that CORS headers are set correctly
func TestCORSHeaders(t *testing.T) {
router := gin.New()
router.Use(CORS())
router.GET("/test", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{})
})
req, _ := http.NewRequest("GET", "/test", nil)
req.Header.Set("Origin", "http://localhost:3000")
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
expectedHeaders := map[string]string{
"Access-Control-Allow-Methods": "GET, POST, PUT, DELETE, OPTIONS",
"Access-Control-Allow-Headers": "Origin, Content-Type, Authorization, X-Requested-With",
"Access-Control-Allow-Credentials": "true",
"Access-Control-Max-Age": "86400",
}
for header, expected := range expectedHeaders {
actual := w.Header().Get(header)
if actual != expected {
t.Errorf("Expected %s to be %s, got %s", header, expected, actual)
}
}
}
// TestAuthMiddleware_ValidToken tests authentication with valid token
func TestAuthMiddleware_ValidToken(t *testing.T) {
secret := "test-secret-key"
userID := uuid.New().String()
email := "test@example.com"
role := "user"
router := gin.New()
router.Use(AuthMiddleware(secret))
router.GET("/protected", func(c *gin.Context) {
uid, _ := c.Get("user_id")
em, _ := c.Get("email")
r, _ := c.Get("role")
c.JSON(http.StatusOK, gin.H{
"user_id": uid,
"email": em,
"role": r,
})
})
token := createTestToken(secret, userID, email, role, time.Now().Add(time.Hour))
req, _ := http.NewRequest("GET", "/protected", nil)
req.Header.Set("Authorization", "Bearer "+token)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("Expected status %d, got %d", http.StatusOK, w.Code)
}
}
// TestAuthMiddleware_MissingHeader tests authentication without header
func TestAuthMiddleware_MissingHeader(t *testing.T) {
router := gin.New()
router.Use(AuthMiddleware("test-secret"))
router.GET("/protected", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{})
})
req, _ := http.NewRequest("GET", "/protected", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusUnauthorized {
t.Errorf("Expected status %d, got %d", http.StatusUnauthorized, w.Code)
}
}
// TestAuthMiddleware_InvalidFormat tests authentication with invalid header format
func TestAuthMiddleware_InvalidFormat(t *testing.T) {
router := gin.New()
router.Use(AuthMiddleware("test-secret"))
router.GET("/protected", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{})
})
tests := []struct {
name string
header string
}{
{"no Bearer prefix", "some-token"},
{"Basic auth", "Basic dXNlcjpwYXNz"},
{"empty Bearer", "Bearer "},
{"multiple spaces", "Bearer token"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req, _ := http.NewRequest("GET", "/protected", nil)
req.Header.Set("Authorization", tt.header)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusUnauthorized {
t.Errorf("Expected status %d, got %d", http.StatusUnauthorized, w.Code)
}
})
}
}
// TestAuthMiddleware_ExpiredToken tests authentication with expired token
func TestAuthMiddleware_ExpiredToken(t *testing.T) {
secret := "test-secret"
router := gin.New()
router.Use(AuthMiddleware(secret))
router.GET("/protected", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{})
})
// Create expired token
token := createTestToken(secret, "user-123", "test@example.com", "user", time.Now().Add(-time.Hour))
req, _ := http.NewRequest("GET", "/protected", nil)
req.Header.Set("Authorization", "Bearer "+token)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusUnauthorized {
t.Errorf("Expected status %d, got %d", http.StatusUnauthorized, w.Code)
}
}
// TestAuthMiddleware_WrongSecret tests authentication with wrong secret
func TestAuthMiddleware_WrongSecret(t *testing.T) {
router := gin.New()
router.Use(AuthMiddleware("correct-secret"))
router.GET("/protected", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{})
})
// Create token with different secret
token := createTestToken("wrong-secret", "user-123", "test@example.com", "user", time.Now().Add(time.Hour))
req, _ := http.NewRequest("GET", "/protected", nil)
req.Header.Set("Authorization", "Bearer "+token)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusUnauthorized {
t.Errorf("Expected status %d, got %d", http.StatusUnauthorized, w.Code)
}
}
// TestAdminOnly tests the AdminOnly middleware
func TestAdminOnly(t *testing.T) {
tests := []struct {
name string
role string
expectedStatus int
}{
{"admin allowed", "admin", http.StatusOK},
{"super_admin allowed", "super_admin", http.StatusOK},
{"dpo allowed", "data_protection_officer", http.StatusOK},
{"user forbidden", "user", http.StatusForbidden},
{"empty role forbidden", "", http.StatusForbidden},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
router := gin.New()
router.Use(func(c *gin.Context) {
c.Set("role", tt.role)
c.Next()
})
router.Use(AdminOnly())
router.GET("/admin", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{})
})
req, _ := http.NewRequest("GET", "/admin", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != tt.expectedStatus {
t.Errorf("Expected status %d, got %d", tt.expectedStatus, w.Code)
}
})
}
}
// TestAdminOnly_NoRole tests AdminOnly when role is not set
func TestAdminOnly_NoRole(t *testing.T) {
router := gin.New()
router.Use(AdminOnly())
router.GET("/admin", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{})
})
req, _ := http.NewRequest("GET", "/admin", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusUnauthorized {
t.Errorf("Expected status %d, got %d", http.StatusUnauthorized, w.Code)
}
}
// TestDSBOnly tests the DSBOnly middleware
func TestDSBOnly(t *testing.T) {
tests := []struct {
name string
role string
expectedStatus int
}{
{"dpo allowed", "data_protection_officer", http.StatusOK},
{"super_admin allowed", "super_admin", http.StatusOK},
{"admin forbidden", "admin", http.StatusForbidden},
{"user forbidden", "user", http.StatusForbidden},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
router := gin.New()
router.Use(func(c *gin.Context) {
c.Set("role", tt.role)
c.Next()
})
router.Use(DSBOnly())
router.GET("/dsb", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{})
})
req, _ := http.NewRequest("GET", "/dsb", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != tt.expectedStatus {
t.Errorf("Expected status %d, got %d", tt.expectedStatus, w.Code)
}
})
}
}
// TestIsAdmin tests the IsAdmin helper function
func TestIsAdmin(t *testing.T) {
tests := []struct {
name string
role string
expected bool
}{
{"admin", "admin", true},
{"super_admin", "super_admin", true},
{"dpo", "data_protection_officer", true},
{"user", "user", false},
{"empty", "", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
if tt.role != "" {
c.Set("role", tt.role)
}
result := IsAdmin(c)
if result != tt.expected {
t.Errorf("Expected IsAdmin to be %v, got %v", tt.expected, result)
}
})
}
}
// TestIsDSB tests the IsDSB helper function
func TestIsDSB(t *testing.T) {
tests := []struct {
name string
role string
expected bool
}{
{"dpo", "data_protection_officer", true},
{"super_admin", "super_admin", true},
{"admin", "admin", false},
{"user", "user", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Set("role", tt.role)
result := IsDSB(c)
if result != tt.expected {
t.Errorf("Expected IsDSB to be %v, got %v", tt.expected, result)
}
})
}
}
// TestGetUserID tests the GetUserID helper function
func TestGetUserID(t *testing.T) {
validUUID := uuid.New()
tests := []struct {
name string
userID string
setUserID bool
expectError bool
expectedID uuid.UUID
}{
{"valid UUID", validUUID.String(), true, false, validUUID},
{"invalid UUID", "not-a-uuid", true, true, uuid.Nil},
{"missing user_id", "", false, false, uuid.Nil},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
if tt.setUserID {
c.Set("user_id", tt.userID)
}
result, err := GetUserID(c)
if tt.expectError && err == nil {
t.Error("Expected error but got none")
}
if !tt.expectError && result != tt.expectedID {
t.Errorf("Expected %v, got %v", tt.expectedID, result)
}
})
}
}
// TestGetClientIP tests the GetClientIP helper function
func TestGetClientIP(t *testing.T) {
tests := []struct {
name string
xff string
xri string
clientIP string
expectedIP string
}{
{"X-Forwarded-For", "10.0.0.1", "", "192.168.1.1", "10.0.0.1"},
{"X-Forwarded-For multiple", "10.0.0.1, 10.0.0.2", "", "192.168.1.1", "10.0.0.1"},
{"X-Real-IP", "", "10.0.0.1", "192.168.1.1", "10.0.0.1"},
{"direct", "", "", "192.168.1.1", "192.168.1.1"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request, _ = http.NewRequest("GET", "/", nil)
if tt.xff != "" {
c.Request.Header.Set("X-Forwarded-For", tt.xff)
}
if tt.xri != "" {
c.Request.Header.Set("X-Real-IP", tt.xri)
}
c.Request.RemoteAddr = tt.clientIP + ":12345"
result := GetClientIP(c)
// Note: gin.ClientIP() might return different values
// depending on trusted proxies config
if result != tt.expectedIP && result != tt.clientIP {
t.Logf("Note: GetClientIP returned %s (expected %s or %s)", result, tt.expectedIP, tt.clientIP)
}
})
}
}
// TestGetUserAgent tests the GetUserAgent helper function
func TestGetUserAgent(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request, _ = http.NewRequest("GET", "/", nil)
expectedUA := "Mozilla/5.0 (Test)"
c.Request.Header.Set("User-Agent", expectedUA)
result := GetUserAgent(c)
if result != expectedUA {
t.Errorf("Expected %s, got %s", expectedUA, result)
}
}
// TestIsSuspended tests the IsSuspended helper function
func TestIsSuspended(t *testing.T) {
tests := []struct {
name string
suspended interface{}
setSuspended bool
expected bool
}{
{"suspended true", true, true, true},
{"suspended false", false, true, false},
{"not set", nil, false, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
if tt.setSuspended {
c.Set("account_suspended", tt.suspended)
}
result := IsSuspended(c)
if result != tt.expected {
t.Errorf("Expected %v, got %v", tt.expected, result)
}
})
}
}
// BenchmarkCORS benchmarks the CORS middleware
func BenchmarkCORS(b *testing.B) {
router := gin.New()
router.Use(CORS())
router.GET("/test", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{})
})
req, _ := http.NewRequest("GET", "/test", nil)
req.Header.Set("Origin", "http://localhost:3000")
for i := 0; i < b.N; i++ {
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
}
}
// BenchmarkAuthMiddleware benchmarks the auth middleware
func BenchmarkAuthMiddleware(b *testing.B) {
secret := "test-secret-key"
token := createTestToken(secret, uuid.New().String(), "test@example.com", "user", time.Now().Add(time.Hour))
router := gin.New()
router.Use(AuthMiddleware(secret))
router.GET("/test", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{})
})
req, _ := http.NewRequest("GET", "/test", nil)
req.Header.Set("Authorization", "Bearer "+token)
for i := 0; i < b.N; i++ {
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
}
}

View File

@@ -0,0 +1,197 @@
package middleware
import (
"regexp"
"strings"
)
// PIIPattern defines a pattern for identifying PII.
type PIIPattern struct {
Name string
Pattern *regexp.Regexp
Replacement string
}
// PIIRedactor redacts personally identifiable information from strings.
type PIIRedactor struct {
patterns []*PIIPattern
}
// Pre-compiled patterns for common PII types
var (
emailPattern = regexp.MustCompile(`\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b`)
ipv4Pattern = regexp.MustCompile(`\b(?:(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\b`)
ipv6Pattern = regexp.MustCompile(`\b(?:[0-9a-fA-F]{1,4}:){7}[0-9a-fA-F]{1,4}\b`)
phonePattern = regexp.MustCompile(`(?:\+49|0049)[\s.-]?\d{2,4}[\s.-]?\d{3,8}|\b0\d{2,4}[\s.-]?\d{3,8}\b`)
ibanPattern = regexp.MustCompile(`(?i)\b[A-Z]{2}\d{2}[\s]?(?:\d{4}[\s]?){3,5}\d{1,4}\b`)
uuidPattern = regexp.MustCompile(`(?i)\b[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}\b`)
namePattern = regexp.MustCompile(`\b(?:Herr|Frau|Hr\.|Fr\.)\s+[A-ZÄÖÜ][a-zäöüß]+(?:\s+[A-ZÄÖÜ][a-zäöüß]+)?\b`)
)
// DefaultPIIPatterns returns the default set of PII patterns.
func DefaultPIIPatterns() []*PIIPattern {
return []*PIIPattern{
{Name: "email", Pattern: emailPattern, Replacement: "[EMAIL_REDACTED]"},
{Name: "ip_v4", Pattern: ipv4Pattern, Replacement: "[IP_REDACTED]"},
{Name: "ip_v6", Pattern: ipv6Pattern, Replacement: "[IP_REDACTED]"},
{Name: "phone", Pattern: phonePattern, Replacement: "[PHONE_REDACTED]"},
}
}
// AllPIIPatterns returns all available PII patterns.
func AllPIIPatterns() []*PIIPattern {
return []*PIIPattern{
{Name: "email", Pattern: emailPattern, Replacement: "[EMAIL_REDACTED]"},
{Name: "ip_v4", Pattern: ipv4Pattern, Replacement: "[IP_REDACTED]"},
{Name: "ip_v6", Pattern: ipv6Pattern, Replacement: "[IP_REDACTED]"},
{Name: "phone", Pattern: phonePattern, Replacement: "[PHONE_REDACTED]"},
{Name: "iban", Pattern: ibanPattern, Replacement: "[IBAN_REDACTED]"},
{Name: "uuid", Pattern: uuidPattern, Replacement: "[UUID_REDACTED]"},
{Name: "name", Pattern: namePattern, Replacement: "[NAME_REDACTED]"},
}
}
// NewPIIRedactor creates a new PII redactor with the given patterns.
func NewPIIRedactor(patterns []*PIIPattern) *PIIRedactor {
if patterns == nil {
patterns = DefaultPIIPatterns()
}
return &PIIRedactor{patterns: patterns}
}
// NewDefaultPIIRedactor creates a PII redactor with default patterns.
func NewDefaultPIIRedactor() *PIIRedactor {
return NewPIIRedactor(DefaultPIIPatterns())
}
// Redact removes PII from the given text.
func (r *PIIRedactor) Redact(text string) string {
if text == "" {
return text
}
result := text
for _, pattern := range r.patterns {
result = pattern.Pattern.ReplaceAllString(result, pattern.Replacement)
}
return result
}
// ContainsPII checks if the text contains any PII.
func (r *PIIRedactor) ContainsPII(text string) bool {
if text == "" {
return false
}
for _, pattern := range r.patterns {
if pattern.Pattern.MatchString(text) {
return true
}
}
return false
}
// PIIFinding represents a found PII instance.
type PIIFinding struct {
Type string
Match string
Start int
End int
}
// FindPII finds all PII in the text.
func (r *PIIRedactor) FindPII(text string) []PIIFinding {
if text == "" {
return nil
}
var findings []PIIFinding
for _, pattern := range r.patterns {
matches := pattern.Pattern.FindAllStringIndex(text, -1)
for _, match := range matches {
findings = append(findings, PIIFinding{
Type: pattern.Name,
Match: text[match[0]:match[1]],
Start: match[0],
End: match[1],
})
}
}
return findings
}
// Default module-level redactor
var defaultRedactor = NewDefaultPIIRedactor()
// RedactPII is a convenience function that uses the default redactor.
func RedactPII(text string) string {
return defaultRedactor.Redact(text)
}
// ContainsPIIDefault checks if text contains PII using default patterns.
func ContainsPIIDefault(text string) bool {
return defaultRedactor.ContainsPII(text)
}
// RedactMap redacts PII from all string values in a map.
func RedactMap(data map[string]interface{}) map[string]interface{} {
result := make(map[string]interface{})
for key, value := range data {
switch v := value.(type) {
case string:
result[key] = RedactPII(v)
case map[string]interface{}:
result[key] = RedactMap(v)
case []interface{}:
result[key] = redactSlice(v)
default:
result[key] = v
}
}
return result
}
func redactSlice(data []interface{}) []interface{} {
result := make([]interface{}, len(data))
for i, value := range data {
switch v := value.(type) {
case string:
result[i] = RedactPII(v)
case map[string]interface{}:
result[i] = RedactMap(v)
case []interface{}:
result[i] = redactSlice(v)
default:
result[i] = v
}
}
return result
}
// SafeLogString creates a safe-to-log version of sensitive data.
// Use this for logging user-related information.
func SafeLogString(format string, args ...interface{}) string {
// Convert args to strings and redact
safeArgs := make([]interface{}, len(args))
for i, arg := range args {
switch v := arg.(type) {
case string:
safeArgs[i] = RedactPII(v)
case error:
safeArgs[i] = RedactPII(v.Error())
default:
safeArgs[i] = arg
}
}
// Note: We can't use fmt.Sprintf here due to the variadic nature
// Instead, we redact the result
result := format
for _, arg := range safeArgs {
if s, ok := arg.(string); ok {
result = strings.Replace(result, "%s", s, 1)
result = strings.Replace(result, "%v", s, 1)
}
}
return RedactPII(result)
}

View File

@@ -0,0 +1,228 @@
package middleware
import (
"testing"
)
func TestPIIRedactor_RedactsEmail(t *testing.T) {
redactor := NewDefaultPIIRedactor()
text := "User test@example.com logged in"
result := redactor.Redact(text)
if result == text {
t.Error("Email should have been redacted")
}
if result != "User [EMAIL_REDACTED] logged in" {
t.Errorf("Unexpected result: %s", result)
}
}
func TestPIIRedactor_RedactsIPv4(t *testing.T) {
redactor := NewDefaultPIIRedactor()
text := "Request from 192.168.1.100"
result := redactor.Redact(text)
if result == text {
t.Error("IP should have been redacted")
}
if result != "Request from [IP_REDACTED]" {
t.Errorf("Unexpected result: %s", result)
}
}
func TestPIIRedactor_RedactsGermanPhone(t *testing.T) {
redactor := NewDefaultPIIRedactor()
tests := []struct {
input string
expected string
}{
{"+49 30 12345678", "[PHONE_REDACTED]"},
{"0049 30 12345678", "[PHONE_REDACTED]"},
{"030 12345678", "[PHONE_REDACTED]"},
}
for _, tt := range tests {
result := redactor.Redact(tt.input)
if result != tt.expected {
t.Errorf("For input %q: expected %q, got %q", tt.input, tt.expected, result)
}
}
}
func TestPIIRedactor_RedactsMultiplePII(t *testing.T) {
redactor := NewDefaultPIIRedactor()
text := "User test@example.com from 10.0.0.1"
result := redactor.Redact(text)
if result != "User [EMAIL_REDACTED] from [IP_REDACTED]" {
t.Errorf("Unexpected result: %s", result)
}
}
func TestPIIRedactor_PreservesNonPIIText(t *testing.T) {
redactor := NewDefaultPIIRedactor()
text := "User logged in successfully"
result := redactor.Redact(text)
if result != text {
t.Errorf("Text should be unchanged: got %s", result)
}
}
func TestPIIRedactor_EmptyString(t *testing.T) {
redactor := NewDefaultPIIRedactor()
result := redactor.Redact("")
if result != "" {
t.Error("Empty string should remain empty")
}
}
func TestContainsPII(t *testing.T) {
redactor := NewDefaultPIIRedactor()
tests := []struct {
input string
expected bool
}{
{"test@example.com", true},
{"192.168.1.1", true},
{"+49 30 12345678", true},
{"Hello World", false},
{"", false},
}
for _, tt := range tests {
result := redactor.ContainsPII(tt.input)
if result != tt.expected {
t.Errorf("For input %q: expected %v, got %v", tt.input, tt.expected, result)
}
}
}
func TestFindPII(t *testing.T) {
redactor := NewDefaultPIIRedactor()
text := "Email: test@example.com, IP: 10.0.0.1"
findings := redactor.FindPII(text)
if len(findings) != 2 {
t.Errorf("Expected 2 findings, got %d", len(findings))
}
hasEmail := false
hasIP := false
for _, f := range findings {
if f.Type == "email" {
hasEmail = true
}
if f.Type == "ip_v4" {
hasIP = true
}
}
if !hasEmail {
t.Error("Should have found email")
}
if !hasIP {
t.Error("Should have found IP")
}
}
func TestRedactPII_GlobalFunction(t *testing.T) {
text := "User test@example.com logged in"
result := RedactPII(text)
if result == text {
t.Error("Email should have been redacted")
}
}
func TestContainsPIIDefault(t *testing.T) {
if !ContainsPIIDefault("test@example.com") {
t.Error("Should detect email as PII")
}
if ContainsPIIDefault("Hello World") {
t.Error("Should not detect non-PII text")
}
}
func TestRedactMap(t *testing.T) {
data := map[string]interface{}{
"email": "test@example.com",
"message": "Hello World",
"nested": map[string]interface{}{
"ip": "192.168.1.1",
},
}
result := RedactMap(data)
if result["email"] != "[EMAIL_REDACTED]" {
t.Errorf("Email should be redacted: %v", result["email"])
}
if result["message"] != "Hello World" {
t.Errorf("Non-PII should be unchanged: %v", result["message"])
}
nested := result["nested"].(map[string]interface{})
if nested["ip"] != "[IP_REDACTED]" {
t.Errorf("Nested IP should be redacted: %v", nested["ip"])
}
}
func TestAllPIIPatterns(t *testing.T) {
patterns := AllPIIPatterns()
if len(patterns) == 0 {
t.Error("Should have PII patterns")
}
// Check that we have the expected patterns
expectedNames := []string{"email", "ip_v4", "ip_v6", "phone", "iban", "uuid", "name"}
nameMap := make(map[string]bool)
for _, p := range patterns {
nameMap[p.Name] = true
}
for _, name := range expectedNames {
if !nameMap[name] {
t.Errorf("Missing expected pattern: %s", name)
}
}
}
func TestDefaultPIIPatterns(t *testing.T) {
patterns := DefaultPIIPatterns()
if len(patterns) != 4 {
t.Errorf("Expected 4 default patterns, got %d", len(patterns))
}
}
func TestIBANRedaction(t *testing.T) {
redactor := NewPIIRedactor(AllPIIPatterns())
text := "IBAN: DE89 3704 0044 0532 0130 00"
result := redactor.Redact(text)
if result == text {
t.Error("IBAN should have been redacted")
}
}
func TestUUIDRedaction(t *testing.T) {
redactor := NewPIIRedactor(AllPIIPatterns())
text := "User ID: a0000000-0000-0000-0000-000000000001"
result := redactor.Redact(text)
if result == text {
t.Error("UUID should have been redacted")
}
}

View File

@@ -0,0 +1,75 @@
package middleware
import (
"github.com/gin-gonic/gin"
"github.com/google/uuid"
)
const (
// RequestIDHeader is the primary header for request IDs
RequestIDHeader = "X-Request-ID"
// CorrelationIDHeader is an alternative header for distributed tracing
CorrelationIDHeader = "X-Correlation-ID"
// RequestIDKey is the context key for storing the request ID
RequestIDKey = "request_id"
)
// RequestID returns a middleware that generates and propagates request IDs.
//
// For each incoming request:
// 1. Check for existing X-Request-ID or X-Correlation-ID header
// 2. If not present, generate a new UUID
// 3. Store in Gin context for use by handlers and logging
// 4. Add to response headers
//
// Usage:
//
// r.Use(middleware.RequestID())
//
// func handler(c *gin.Context) {
// requestID := middleware.GetRequestID(c)
// log.Printf("[%s] Processing request", requestID)
// }
func RequestID() gin.HandlerFunc {
return func(c *gin.Context) {
// Try to get existing request ID from headers
requestID := c.GetHeader(RequestIDHeader)
if requestID == "" {
requestID = c.GetHeader(CorrelationIDHeader)
}
// Generate new ID if not provided
if requestID == "" {
requestID = uuid.New().String()
}
// Store in context for handlers and logging
c.Set(RequestIDKey, requestID)
// Add to response headers
c.Header(RequestIDHeader, requestID)
c.Header(CorrelationIDHeader, requestID)
c.Next()
}
}
// GetRequestID retrieves the request ID from the Gin context.
// Returns empty string if no request ID is set.
//
// Usage:
//
// requestID := middleware.GetRequestID(c)
func GetRequestID(c *gin.Context) string {
if id, exists := c.Get(RequestIDKey); exists {
if idStr, ok := id.(string); ok {
return idStr
}
}
return ""
}
// RequestIDFromContext is an alias for GetRequestID for API compatibility.
func RequestIDFromContext(c *gin.Context) string {
return GetRequestID(c)
}

View File

@@ -0,0 +1,152 @@
package middleware
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
)
func TestRequestID_GeneratesNewID(t *testing.T) {
gin.SetMode(gin.TestMode)
router := gin.New()
router.Use(RequestID())
router.GET("/test", func(c *gin.Context) {
requestID := GetRequestID(c)
if requestID == "" {
t.Error("Expected request ID to be set")
}
c.JSON(http.StatusOK, gin.H{"request_id": requestID})
})
req := httptest.NewRequest(http.MethodGet, "/test", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", w.Code)
}
// Check response header
requestID := w.Header().Get(RequestIDHeader)
if requestID == "" {
t.Error("Expected X-Request-ID header in response")
}
// Check correlation ID header
correlationID := w.Header().Get(CorrelationIDHeader)
if correlationID == "" {
t.Error("Expected X-Correlation-ID header in response")
}
if requestID != correlationID {
t.Error("X-Request-ID and X-Correlation-ID should match")
}
}
func TestRequestID_PropagatesExistingID(t *testing.T) {
gin.SetMode(gin.TestMode)
router := gin.New()
router.Use(RequestID())
customID := "custom-request-id-12345"
router.GET("/test", func(c *gin.Context) {
requestID := GetRequestID(c)
if requestID != customID {
t.Errorf("Expected request ID %s, got %s", customID, requestID)
}
c.JSON(http.StatusOK, gin.H{"request_id": requestID})
})
req := httptest.NewRequest(http.MethodGet, "/test", nil)
req.Header.Set(RequestIDHeader, customID)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", w.Code)
}
responseID := w.Header().Get(RequestIDHeader)
if responseID != customID {
t.Errorf("Expected response header %s, got %s", customID, responseID)
}
}
func TestRequestID_PropagatesCorrelationID(t *testing.T) {
gin.SetMode(gin.TestMode)
router := gin.New()
router.Use(RequestID())
correlationID := "correlation-id-67890"
router.GET("/test", func(c *gin.Context) {
requestID := GetRequestID(c)
if requestID != correlationID {
t.Errorf("Expected request ID %s, got %s", correlationID, requestID)
}
c.JSON(http.StatusOK, gin.H{})
})
req := httptest.NewRequest(http.MethodGet, "/test", nil)
req.Header.Set(CorrelationIDHeader, correlationID)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", w.Code)
}
// Both headers should be set with the correlation ID
if w.Header().Get(RequestIDHeader) != correlationID {
t.Error("X-Request-ID should match X-Correlation-ID")
}
}
func TestGetRequestID_ReturnsEmptyWhenNotSet(t *testing.T) {
gin.SetMode(gin.TestMode)
router := gin.New()
// No RequestID middleware
router.GET("/test", func(c *gin.Context) {
requestID := GetRequestID(c)
if requestID != "" {
t.Errorf("Expected empty request ID, got %s", requestID)
}
c.JSON(http.StatusOK, gin.H{})
})
req := httptest.NewRequest(http.MethodGet, "/test", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", w.Code)
}
}
func TestRequestIDFromContext_IsAliasForGetRequestID(t *testing.T) {
gin.SetMode(gin.TestMode)
router := gin.New()
router.Use(RequestID())
router.GET("/test", func(c *gin.Context) {
id1 := GetRequestID(c)
id2 := RequestIDFromContext(c)
if id1 != id2 {
t.Errorf("GetRequestID and RequestIDFromContext should return same value")
}
c.JSON(http.StatusOK, gin.H{})
})
req := httptest.NewRequest(http.MethodGet, "/test", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", w.Code)
}
}

View File

@@ -0,0 +1,167 @@
package middleware
import (
"os"
"strconv"
"strings"
"github.com/gin-gonic/gin"
)
// SecurityHeadersConfig holds configuration for security headers.
type SecurityHeadersConfig struct {
// X-Content-Type-Options
ContentTypeOptions string
// X-Frame-Options
FrameOptions string
// X-XSS-Protection (legacy but useful for older browsers)
XSSProtection string
// Strict-Transport-Security
HSTSEnabled bool
HSTSMaxAge int
HSTSIncludeSubdomains bool
HSTSPreload bool
// Content-Security-Policy
CSPEnabled bool
CSPPolicy string
// Referrer-Policy
ReferrerPolicy string
// Permissions-Policy
PermissionsPolicy string
// Cross-Origin headers
CrossOriginOpenerPolicy string
CrossOriginResourcePolicy string
// Development mode (relaxes some restrictions)
DevelopmentMode bool
// Excluded paths (e.g., health checks)
ExcludedPaths []string
}
// DefaultSecurityHeadersConfig returns sensible default configuration.
func DefaultSecurityHeadersConfig() SecurityHeadersConfig {
env := os.Getenv("ENVIRONMENT")
isDev := env == "" || strings.ToLower(env) == "development" || strings.ToLower(env) == "dev"
return SecurityHeadersConfig{
ContentTypeOptions: "nosniff",
FrameOptions: "DENY",
XSSProtection: "1; mode=block",
HSTSEnabled: true,
HSTSMaxAge: 31536000, // 1 year
HSTSIncludeSubdomains: true,
HSTSPreload: false,
CSPEnabled: true,
CSPPolicy: getDefaultCSP(isDev),
ReferrerPolicy: "strict-origin-when-cross-origin",
PermissionsPolicy: "geolocation=(), microphone=(), camera=()",
CrossOriginOpenerPolicy: "same-origin",
CrossOriginResourcePolicy: "same-origin",
DevelopmentMode: isDev,
ExcludedPaths: []string{"/health", "/metrics", "/api/v1/health"},
}
}
// getDefaultCSP returns a sensible default CSP for the environment.
func getDefaultCSP(isDevelopment bool) string {
if isDevelopment {
return "default-src 'self' localhost:* ws://localhost:*; " +
"script-src 'self' 'unsafe-inline' 'unsafe-eval'; " +
"style-src 'self' 'unsafe-inline'; " +
"img-src 'self' data: https: blob:; " +
"font-src 'self' data:; " +
"connect-src 'self' localhost:* ws://localhost:* https:; " +
"frame-ancestors 'self'"
}
return "default-src 'self'; " +
"script-src 'self' 'unsafe-inline'; " +
"style-src 'self' 'unsafe-inline'; " +
"img-src 'self' data: https:; " +
"font-src 'self' data:; " +
"connect-src 'self' https://breakpilot.app https://*.breakpilot.app; " +
"frame-ancestors 'none'"
}
// buildHSTSHeader builds the Strict-Transport-Security header value.
func (c *SecurityHeadersConfig) buildHSTSHeader() string {
parts := []string{"max-age=" + strconv.Itoa(c.HSTSMaxAge)}
if c.HSTSIncludeSubdomains {
parts = append(parts, "includeSubDomains")
}
if c.HSTSPreload {
parts = append(parts, "preload")
}
return strings.Join(parts, "; ")
}
// isExcludedPath checks if the path should be excluded from security headers.
func (c *SecurityHeadersConfig) isExcludedPath(path string) bool {
for _, excluded := range c.ExcludedPaths {
if path == excluded {
return true
}
}
return false
}
// SecurityHeaders returns a middleware that adds security headers to all responses.
//
// Usage:
//
// r.Use(middleware.SecurityHeaders())
//
// // Or with custom config:
// config := middleware.DefaultSecurityHeadersConfig()
// config.CSPPolicy = "default-src 'self'"
// r.Use(middleware.SecurityHeadersWithConfig(config))
func SecurityHeaders() gin.HandlerFunc {
return SecurityHeadersWithConfig(DefaultSecurityHeadersConfig())
}
// SecurityHeadersWithConfig returns a security headers middleware with custom configuration.
func SecurityHeadersWithConfig(config SecurityHeadersConfig) gin.HandlerFunc {
return func(c *gin.Context) {
// Skip for excluded paths
if config.isExcludedPath(c.Request.URL.Path) {
c.Next()
return
}
// Always add these headers
c.Header("X-Content-Type-Options", config.ContentTypeOptions)
c.Header("X-Frame-Options", config.FrameOptions)
c.Header("X-XSS-Protection", config.XSSProtection)
c.Header("Referrer-Policy", config.ReferrerPolicy)
// HSTS (only in production or if explicitly enabled)
if config.HSTSEnabled && !config.DevelopmentMode {
c.Header("Strict-Transport-Security", config.buildHSTSHeader())
}
// Content-Security-Policy
if config.CSPEnabled && config.CSPPolicy != "" {
c.Header("Content-Security-Policy", config.CSPPolicy)
}
// Permissions-Policy
if config.PermissionsPolicy != "" {
c.Header("Permissions-Policy", config.PermissionsPolicy)
}
// Cross-Origin headers (only in production)
if !config.DevelopmentMode {
c.Header("Cross-Origin-Opener-Policy", config.CrossOriginOpenerPolicy)
c.Header("Cross-Origin-Resource-Policy", config.CrossOriginResourcePolicy)
}
c.Next()
}
}

View File

@@ -0,0 +1,377 @@
package middleware
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
)
func TestSecurityHeaders_AddsBasicHeaders(t *testing.T) {
gin.SetMode(gin.TestMode)
router := gin.New()
config := DefaultSecurityHeadersConfig()
config.DevelopmentMode = true // Skip HSTS and cross-origin headers
router.Use(SecurityHeadersWithConfig(config))
router.GET("/test", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"status": "ok"})
})
req := httptest.NewRequest(http.MethodGet, "/test", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", w.Code)
}
// Check basic security headers
tests := []struct {
header string
expected string
}{
{"X-Content-Type-Options", "nosniff"},
{"X-Frame-Options", "DENY"},
{"X-XSS-Protection", "1; mode=block"},
{"Referrer-Policy", "strict-origin-when-cross-origin"},
}
for _, tt := range tests {
value := w.Header().Get(tt.header)
if value != tt.expected {
t.Errorf("Header %s: expected %q, got %q", tt.header, tt.expected, value)
}
}
}
func TestSecurityHeaders_HSTSNotAddedInDevelopment(t *testing.T) {
gin.SetMode(gin.TestMode)
router := gin.New()
config := DefaultSecurityHeadersConfig()
config.DevelopmentMode = true
config.HSTSEnabled = true
router.Use(SecurityHeadersWithConfig(config))
router.GET("/test", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{})
})
req := httptest.NewRequest(http.MethodGet, "/test", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
hstsHeader := w.Header().Get("Strict-Transport-Security")
if hstsHeader != "" {
t.Errorf("HSTS should not be set in development mode, got: %s", hstsHeader)
}
}
func TestSecurityHeaders_HSTSAddedInProduction(t *testing.T) {
gin.SetMode(gin.TestMode)
router := gin.New()
config := DefaultSecurityHeadersConfig()
config.DevelopmentMode = false
config.HSTSEnabled = true
config.HSTSMaxAge = 31536000
config.HSTSIncludeSubdomains = true
router.Use(SecurityHeadersWithConfig(config))
router.GET("/test", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{})
})
req := httptest.NewRequest(http.MethodGet, "/test", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
hstsHeader := w.Header().Get("Strict-Transport-Security")
if hstsHeader == "" {
t.Error("HSTS should be set in production mode")
}
// Check that it contains max-age
if hstsHeader != "max-age=31536000; includeSubDomains" {
t.Errorf("Unexpected HSTS value: %s", hstsHeader)
}
}
func TestSecurityHeaders_HSTSWithPreload(t *testing.T) {
gin.SetMode(gin.TestMode)
router := gin.New()
config := DefaultSecurityHeadersConfig()
config.DevelopmentMode = false
config.HSTSEnabled = true
config.HSTSMaxAge = 31536000
config.HSTSIncludeSubdomains = true
config.HSTSPreload = true
router.Use(SecurityHeadersWithConfig(config))
router.GET("/test", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{})
})
req := httptest.NewRequest(http.MethodGet, "/test", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
hstsHeader := w.Header().Get("Strict-Transport-Security")
expected := "max-age=31536000; includeSubDomains; preload"
if hstsHeader != expected {
t.Errorf("Expected HSTS %q, got %q", expected, hstsHeader)
}
}
func TestSecurityHeaders_CSPHeader(t *testing.T) {
gin.SetMode(gin.TestMode)
router := gin.New()
config := DefaultSecurityHeadersConfig()
config.CSPEnabled = true
config.CSPPolicy = "default-src 'self'"
router.Use(SecurityHeadersWithConfig(config))
router.GET("/test", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{})
})
req := httptest.NewRequest(http.MethodGet, "/test", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
cspHeader := w.Header().Get("Content-Security-Policy")
if cspHeader != "default-src 'self'" {
t.Errorf("Expected CSP %q, got %q", "default-src 'self'", cspHeader)
}
}
func TestSecurityHeaders_NoCSPWhenDisabled(t *testing.T) {
gin.SetMode(gin.TestMode)
router := gin.New()
config := DefaultSecurityHeadersConfig()
config.CSPEnabled = false
router.Use(SecurityHeadersWithConfig(config))
router.GET("/test", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{})
})
req := httptest.NewRequest(http.MethodGet, "/test", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
cspHeader := w.Header().Get("Content-Security-Policy")
if cspHeader != "" {
t.Errorf("CSP should not be set when disabled, got: %s", cspHeader)
}
}
func TestSecurityHeaders_ExcludedPaths(t *testing.T) {
gin.SetMode(gin.TestMode)
router := gin.New()
config := DefaultSecurityHeadersConfig()
config.ExcludedPaths = []string{"/health", "/metrics"}
router.Use(SecurityHeadersWithConfig(config))
router.GET("/health", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"status": "healthy"})
})
router.GET("/api", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"status": "ok"})
})
// Test excluded path
req := httptest.NewRequest(http.MethodGet, "/health", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Header().Get("X-Content-Type-Options") != "" {
t.Error("Security headers should not be set for excluded paths")
}
// Test non-excluded path
req = httptest.NewRequest(http.MethodGet, "/api", nil)
w = httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Header().Get("X-Content-Type-Options") != "nosniff" {
t.Error("Security headers should be set for non-excluded paths")
}
}
func TestSecurityHeaders_CrossOriginInProduction(t *testing.T) {
gin.SetMode(gin.TestMode)
router := gin.New()
config := DefaultSecurityHeadersConfig()
config.DevelopmentMode = false
config.CrossOriginOpenerPolicy = "same-origin"
config.CrossOriginResourcePolicy = "same-origin"
router.Use(SecurityHeadersWithConfig(config))
router.GET("/test", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{})
})
req := httptest.NewRequest(http.MethodGet, "/test", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
coopHeader := w.Header().Get("Cross-Origin-Opener-Policy")
if coopHeader != "same-origin" {
t.Errorf("Expected COOP %q, got %q", "same-origin", coopHeader)
}
corpHeader := w.Header().Get("Cross-Origin-Resource-Policy")
if corpHeader != "same-origin" {
t.Errorf("Expected CORP %q, got %q", "same-origin", corpHeader)
}
}
func TestSecurityHeaders_NoCrossOriginInDevelopment(t *testing.T) {
gin.SetMode(gin.TestMode)
router := gin.New()
config := DefaultSecurityHeadersConfig()
config.DevelopmentMode = true
config.CrossOriginOpenerPolicy = "same-origin"
config.CrossOriginResourcePolicy = "same-origin"
router.Use(SecurityHeadersWithConfig(config))
router.GET("/test", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{})
})
req := httptest.NewRequest(http.MethodGet, "/test", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Header().Get("Cross-Origin-Opener-Policy") != "" {
t.Error("COOP should not be set in development mode")
}
if w.Header().Get("Cross-Origin-Resource-Policy") != "" {
t.Error("CORP should not be set in development mode")
}
}
func TestSecurityHeaders_PermissionsPolicy(t *testing.T) {
gin.SetMode(gin.TestMode)
router := gin.New()
config := DefaultSecurityHeadersConfig()
config.PermissionsPolicy = "geolocation=(), microphone=()"
router.Use(SecurityHeadersWithConfig(config))
router.GET("/test", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{})
})
req := httptest.NewRequest(http.MethodGet, "/test", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
ppHeader := w.Header().Get("Permissions-Policy")
if ppHeader != "geolocation=(), microphone=()" {
t.Errorf("Expected Permissions-Policy %q, got %q", "geolocation=(), microphone=()", ppHeader)
}
}
func TestSecurityHeaders_DefaultMiddleware(t *testing.T) {
gin.SetMode(gin.TestMode)
router := gin.New()
// Use the default middleware function
router.Use(SecurityHeaders())
router.GET("/test", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{})
})
req := httptest.NewRequest(http.MethodGet, "/test", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
// Should at least have the basic headers
if w.Header().Get("X-Content-Type-Options") != "nosniff" {
t.Error("Default middleware should set X-Content-Type-Options")
}
}
func TestBuildHSTSHeader(t *testing.T) {
tests := []struct {
name string
config SecurityHeadersConfig
expected string
}{
{
name: "basic HSTS",
config: SecurityHeadersConfig{
HSTSMaxAge: 31536000,
HSTSIncludeSubdomains: false,
HSTSPreload: false,
},
expected: "max-age=31536000",
},
{
name: "HSTS with subdomains",
config: SecurityHeadersConfig{
HSTSMaxAge: 31536000,
HSTSIncludeSubdomains: true,
HSTSPreload: false,
},
expected: "max-age=31536000; includeSubDomains",
},
{
name: "HSTS with preload",
config: SecurityHeadersConfig{
HSTSMaxAge: 31536000,
HSTSIncludeSubdomains: true,
HSTSPreload: true,
},
expected: "max-age=31536000; includeSubDomains; preload",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := tt.config.buildHSTSHeader()
if result != tt.expected {
t.Errorf("Expected %q, got %q", tt.expected, result)
}
})
}
}
func TestIsExcludedPath(t *testing.T) {
config := SecurityHeadersConfig{
ExcludedPaths: []string{"/health", "/metrics", "/api/v1/health"},
}
tests := []struct {
path string
excluded bool
}{
{"/health", true},
{"/metrics", true},
{"/api/v1/health", true},
{"/api", false},
{"/health/check", false},
{"/", false},
}
for _, tt := range tests {
result := config.isExcludedPath(tt.path)
if result != tt.excluded {
t.Errorf("Path %s: expected excluded=%v, got %v", tt.path, tt.excluded, result)
}
}
}