package middleware

import (
	"net/http"
	"net/http/httptest"
	"testing"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/golang-jwt/jwt/v5"
	"github.com/google/uuid"
)

func init() {
	gin.SetMode(gin.TestMode)
}

// createTestToken returns an HS256-signed JWT whose claims are a UserClaims
// built from userID, email and role, with ExpiresAt set to exp and IssuedAt
// set to the current time. The token is signed with secret.
//
// It panics if signing fails. Returning an empty string instead would let
// tests that expect 401 (expired token, wrong secret) pass without exercising
// the rejection path they are meant to cover, because an empty bearer token is
// itself invalid.
func createTestToken(secret string, userID, email, role string, exp time.Time) string {
	claims := UserClaims{
		UserID: userID,
		Email:  email,
		Role:   role,
		RegisteredClaims: jwt.RegisteredClaims{
			ExpiresAt: jwt.NewNumericDate(exp),
			IssuedAt:  jwt.NewNumericDate(time.Now()),
		},
	}
	token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
	tokenString, err := token.SignedString([]byte(secret))
	if err != nil {
		panic("createTestToken: signing test JWT: " + err.Error())
	}
	return tokenString
}

// TestCORS checks, per origin, the response status and whether the CORS
// middleware echoes the request Origin back in Access-Control-Allow-Origin.
// A preflight OPTIONS request is expected to short-circuit with 204.
func TestCORS(t *testing.T) {
	router := gin.New()
	router.Use(CORS())
	router.GET("/test", func(c *gin.Context) {
		c.JSON(http.StatusOK, gin.H{"success": true})
	})

	tests := []struct {
		name                string
		origin              string
		method              string
		expectedStatus      int
		expectAllowedOrigin bool
	}{
		{"localhost:3000", "http://localhost:3000", http.MethodGet, http.StatusOK, true},
		{"localhost:8000", "http://localhost:8000", http.MethodGet, http.StatusOK, true},
		{"production", "https://breakpilot.app", http.MethodGet, http.StatusOK, true},
		{"unknown origin", "https://unknown.com", http.MethodGet, http.StatusOK, false},
		{"preflight", "http://localhost:3000", http.MethodOptions, http.StatusNoContent, true},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// httptest.NewRequest panics on a malformed request instead of
			// returning an error that would otherwise be silently dropped.
			req := httptest.NewRequest(tt.method, "/test", nil)
			req.Header.Set("Origin", tt.origin)
			w := httptest.NewRecorder()
			router.ServeHTTP(w, req)

			if w.Code != tt.expectedStatus {
				t.Errorf("Expected status %d, got %d", tt.expectedStatus, w.Code)
			}

			allowedOrigin := w.Header().Get("Access-Control-Allow-Origin")
			if tt.expectAllowedOrigin && allowedOrigin != tt.origin {
				t.Errorf("Expected Access-Control-Allow-Origin to be %s, got %s", tt.origin, allowedOrigin)
			}
			if !tt.expectAllowedOrigin && allowedOrigin != "" {
				t.Errorf("Expected no Access-Control-Allow-Origin header, got %s", allowedOrigin)
			}
		})
	}
}
// TestCORSHeaders tests that CORS headers are set correctly func TestCORSHeaders(t *testing.T) { router := gin.New() router.Use(CORS()) router.GET("/test", func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{}) }) req, _ := http.NewRequest("GET", "/test", nil) req.Header.Set("Origin", "http://localhost:3000") w := httptest.NewRecorder() router.ServeHTTP(w, req) expectedHeaders := map[string]string{ "Access-Control-Allow-Methods": "GET, POST, PUT, DELETE, OPTIONS", "Access-Control-Allow-Headers": "Origin, Content-Type, Authorization, X-Requested-With", "Access-Control-Allow-Credentials": "true", "Access-Control-Max-Age": "86400", } for header, expected := range expectedHeaders { actual := w.Header().Get(header) if actual != expected { t.Errorf("Expected %s to be %s, got %s", header, expected, actual) } } } // TestAuthMiddleware_ValidToken tests authentication with valid token func TestAuthMiddleware_ValidToken(t *testing.T) { secret := "test-secret-key" userID := uuid.New().String() email := "test@example.com" role := "user" router := gin.New() router.Use(AuthMiddleware(secret)) router.GET("/protected", func(c *gin.Context) { uid, _ := c.Get("user_id") em, _ := c.Get("email") r, _ := c.Get("role") c.JSON(http.StatusOK, gin.H{ "user_id": uid, "email": em, "role": r, }) }) token := createTestToken(secret, userID, email, role, time.Now().Add(time.Hour)) req, _ := http.NewRequest("GET", "/protected", nil) req.Header.Set("Authorization", "Bearer "+token) w := httptest.NewRecorder() router.ServeHTTP(w, req) if w.Code != http.StatusOK { t.Errorf("Expected status %d, got %d", http.StatusOK, w.Code) } } // TestAuthMiddleware_MissingHeader tests authentication without header func TestAuthMiddleware_MissingHeader(t *testing.T) { router := gin.New() router.Use(AuthMiddleware("test-secret")) router.GET("/protected", func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{}) }) req, _ := http.NewRequest("GET", "/protected", nil) w := httptest.NewRecorder() router.ServeHTTP(w, req) if 
w.Code != http.StatusUnauthorized { t.Errorf("Expected status %d, got %d", http.StatusUnauthorized, w.Code) } } // TestAuthMiddleware_InvalidFormat tests authentication with invalid header format func TestAuthMiddleware_InvalidFormat(t *testing.T) { router := gin.New() router.Use(AuthMiddleware("test-secret")) router.GET("/protected", func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{}) }) tests := []struct { name string header string }{ {"no Bearer prefix", "some-token"}, {"Basic auth", "Basic dXNlcjpwYXNz"}, {"empty Bearer", "Bearer "}, {"multiple spaces", "Bearer token"}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { req, _ := http.NewRequest("GET", "/protected", nil) req.Header.Set("Authorization", tt.header) w := httptest.NewRecorder() router.ServeHTTP(w, req) if w.Code != http.StatusUnauthorized { t.Errorf("Expected status %d, got %d", http.StatusUnauthorized, w.Code) } }) } } // TestAuthMiddleware_ExpiredToken tests authentication with expired token func TestAuthMiddleware_ExpiredToken(t *testing.T) { secret := "test-secret" router := gin.New() router.Use(AuthMiddleware(secret)) router.GET("/protected", func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{}) }) // Create expired token token := createTestToken(secret, "user-123", "test@example.com", "user", time.Now().Add(-time.Hour)) req, _ := http.NewRequest("GET", "/protected", nil) req.Header.Set("Authorization", "Bearer "+token) w := httptest.NewRecorder() router.ServeHTTP(w, req) if w.Code != http.StatusUnauthorized { t.Errorf("Expected status %d, got %d", http.StatusUnauthorized, w.Code) } } // TestAuthMiddleware_WrongSecret tests authentication with wrong secret func TestAuthMiddleware_WrongSecret(t *testing.T) { router := gin.New() router.Use(AuthMiddleware("correct-secret")) router.GET("/protected", func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{}) }) // Create token with different secret token := createTestToken("wrong-secret", "user-123", "test@example.com", "user", 
time.Now().Add(time.Hour)) req, _ := http.NewRequest("GET", "/protected", nil) req.Header.Set("Authorization", "Bearer "+token) w := httptest.NewRecorder() router.ServeHTTP(w, req) if w.Code != http.StatusUnauthorized { t.Errorf("Expected status %d, got %d", http.StatusUnauthorized, w.Code) } } // TestAdminOnly tests the AdminOnly middleware func TestAdminOnly(t *testing.T) { tests := []struct { name string role string expectedStatus int }{ {"admin allowed", "admin", http.StatusOK}, {"super_admin allowed", "super_admin", http.StatusOK}, {"dpo allowed", "data_protection_officer", http.StatusOK}, {"user forbidden", "user", http.StatusForbidden}, {"empty role forbidden", "", http.StatusForbidden}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { router := gin.New() router.Use(func(c *gin.Context) { c.Set("role", tt.role) c.Next() }) router.Use(AdminOnly()) router.GET("/admin", func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{}) }) req, _ := http.NewRequest("GET", "/admin", nil) w := httptest.NewRecorder() router.ServeHTTP(w, req) if w.Code != tt.expectedStatus { t.Errorf("Expected status %d, got %d", tt.expectedStatus, w.Code) } }) } } // TestAdminOnly_NoRole tests AdminOnly when role is not set func TestAdminOnly_NoRole(t *testing.T) { router := gin.New() router.Use(AdminOnly()) router.GET("/admin", func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{}) }) req, _ := http.NewRequest("GET", "/admin", nil) w := httptest.NewRecorder() router.ServeHTTP(w, req) if w.Code != http.StatusUnauthorized { t.Errorf("Expected status %d, got %d", http.StatusUnauthorized, w.Code) } } // TestDSBOnly tests the DSBOnly middleware func TestDSBOnly(t *testing.T) { tests := []struct { name string role string expectedStatus int }{ {"dpo allowed", "data_protection_officer", http.StatusOK}, {"super_admin allowed", "super_admin", http.StatusOK}, {"admin forbidden", "admin", http.StatusForbidden}, {"user forbidden", "user", http.StatusForbidden}, } for _, tt := range tests { 
t.Run(tt.name, func(t *testing.T) { router := gin.New() router.Use(func(c *gin.Context) { c.Set("role", tt.role) c.Next() }) router.Use(DSBOnly()) router.GET("/dsb", func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{}) }) req, _ := http.NewRequest("GET", "/dsb", nil) w := httptest.NewRecorder() router.ServeHTTP(w, req) if w.Code != tt.expectedStatus { t.Errorf("Expected status %d, got %d", tt.expectedStatus, w.Code) } }) } } // TestIsAdmin tests the IsAdmin helper function func TestIsAdmin(t *testing.T) { tests := []struct { name string role string expected bool }{ {"admin", "admin", true}, {"super_admin", "super_admin", true}, {"dpo", "data_protection_officer", true}, {"user", "user", false}, {"empty", "", false}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) if tt.role != "" { c.Set("role", tt.role) } result := IsAdmin(c) if result != tt.expected { t.Errorf("Expected IsAdmin to be %v, got %v", tt.expected, result) } }) } } // TestIsDSB tests the IsDSB helper function func TestIsDSB(t *testing.T) { tests := []struct { name string role string expected bool }{ {"dpo", "data_protection_officer", true}, {"super_admin", "super_admin", true}, {"admin", "admin", false}, {"user", "user", false}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) c.Set("role", tt.role) result := IsDSB(c) if result != tt.expected { t.Errorf("Expected IsDSB to be %v, got %v", tt.expected, result) } }) } } // TestGetUserID tests the GetUserID helper function func TestGetUserID(t *testing.T) { validUUID := uuid.New() tests := []struct { name string userID string setUserID bool expectError bool expectedID uuid.UUID }{ {"valid UUID", validUUID.String(), true, false, validUUID}, {"invalid UUID", "not-a-uuid", true, true, uuid.Nil}, {"missing user_id", "", false, false, uuid.Nil}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { 
w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) if tt.setUserID { c.Set("user_id", tt.userID) } result, err := GetUserID(c) if tt.expectError && err == nil { t.Error("Expected error but got none") } if !tt.expectError && result != tt.expectedID { t.Errorf("Expected %v, got %v", tt.expectedID, result) } }) } } // TestGetClientIP tests the GetClientIP helper function func TestGetClientIP(t *testing.T) { tests := []struct { name string xff string xri string clientIP string expectedIP string }{ {"X-Forwarded-For", "10.0.0.1", "", "192.168.1.1", "10.0.0.1"}, {"X-Forwarded-For multiple", "10.0.0.1, 10.0.0.2", "", "192.168.1.1", "10.0.0.1"}, {"X-Real-IP", "", "10.0.0.1", "192.168.1.1", "10.0.0.1"}, {"direct", "", "", "192.168.1.1", "192.168.1.1"}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) c.Request, _ = http.NewRequest("GET", "/", nil) if tt.xff != "" { c.Request.Header.Set("X-Forwarded-For", tt.xff) } if tt.xri != "" { c.Request.Header.Set("X-Real-IP", tt.xri) } c.Request.RemoteAddr = tt.clientIP + ":12345" result := GetClientIP(c) // Note: gin.ClientIP() might return different values // depending on trusted proxies config if result != tt.expectedIP && result != tt.clientIP { t.Logf("Note: GetClientIP returned %s (expected %s or %s)", result, tt.expectedIP, tt.clientIP) } }) } } // TestGetUserAgent tests the GetUserAgent helper function func TestGetUserAgent(t *testing.T) { w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) c.Request, _ = http.NewRequest("GET", "/", nil) expectedUA := "Mozilla/5.0 (Test)" c.Request.Header.Set("User-Agent", expectedUA) result := GetUserAgent(c) if result != expectedUA { t.Errorf("Expected %s, got %s", expectedUA, result) } } // TestIsSuspended tests the IsSuspended helper function func TestIsSuspended(t *testing.T) { tests := []struct { name string suspended interface{} setSuspended bool expected bool }{ {"suspended true", true, 
true, true}, {"suspended false", false, true, false}, {"not set", nil, false, false}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) if tt.setSuspended { c.Set("account_suspended", tt.suspended) } result := IsSuspended(c) if result != tt.expected { t.Errorf("Expected %v, got %v", tt.expected, result) } }) } } // BenchmarkCORS benchmarks the CORS middleware func BenchmarkCORS(b *testing.B) { router := gin.New() router.Use(CORS()) router.GET("/test", func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{}) }) req, _ := http.NewRequest("GET", "/test", nil) req.Header.Set("Origin", "http://localhost:3000") for i := 0; i < b.N; i++ { w := httptest.NewRecorder() router.ServeHTTP(w, req) } } // BenchmarkAuthMiddleware benchmarks the auth middleware func BenchmarkAuthMiddleware(b *testing.B) { secret := "test-secret-key" token := createTestToken(secret, uuid.New().String(), "test@example.com", "user", time.Now().Add(time.Hour)) router := gin.New() router.Use(AuthMiddleware(secret)) router.GET("/test", func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{}) }) req, _ := http.NewRequest("GET", "/test", nil) req.Header.Set("Authorization", "Bearer "+token) for i := 0; i < b.N; i++ { w := httptest.NewRecorder() router.ServeHTTP(w, req) } }