package middleware import ( "net/http" "strings" "sync" "time" "github.com/gin-gonic/gin" "github.com/golang-jwt/jwt/v5" "github.com/google/uuid" ) // UserClaims represents the JWT claims for a user type UserClaims struct { UserID string `json:"user_id"` Email string `json:"email"` Role string `json:"role"` jwt.RegisteredClaims } // CORS returns a CORS middleware func CORS() gin.HandlerFunc { return func(c *gin.Context) { origin := c.Request.Header.Get("Origin") // Allow localhost for development allowedOrigins := []string{ "http://localhost:3000", "http://localhost:8000", "http://localhost:8080", "http://localhost:8083", "https://breakpilot.app", } allowed := false for _, o := range allowedOrigins { if origin == o { allowed = true break } } if allowed { c.Header("Access-Control-Allow-Origin", origin) } c.Header("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS") c.Header("Access-Control-Allow-Headers", "Origin, Content-Type, Authorization, X-Requested-With, X-Internal-API-Key") c.Header("Access-Control-Allow-Credentials", "true") c.Header("Access-Control-Max-Age", "86400") if c.Request.Method == "OPTIONS" { c.AbortWithStatus(http.StatusNoContent) return } c.Next() } } // RequestLogger logs each request func RequestLogger() gin.HandlerFunc { return func(c *gin.Context) { start := time.Now() path := c.Request.URL.Path method := c.Request.Method c.Next() latency := time.Since(start) status := c.Writer.Status() // Log only in development or for errors if status >= 400 { gin.DefaultWriter.Write([]byte( method + " " + path + " " + string(rune(status)) + " " + latency.String() + "\n", )) } } } // RateLimiter implements a simple in-memory rate limiter func RateLimiter() gin.HandlerFunc { type client struct { count int lastSeen time.Time } var ( mu sync.Mutex clients = make(map[string]*client) ) // Clean up old entries periodically go func() { for { time.Sleep(time.Minute) mu.Lock() for ip, c := range clients { if time.Since(c.lastSeen) > time.Minute { delete(clients, ip) } } mu.Unlock() } }() return func(c *gin.Context) { ip := c.ClientIP() mu.Lock() defer mu.Unlock() if _, exists := clients[ip]; !exists { clients[ip] = &client{} } cli := clients[ip] // Reset count if more than a minute has passed if time.Since(cli.lastSeen) > time.Minute { cli.count = 0 } cli.count++ cli.lastSeen = time.Now() // Allow 100 requests per minute if cli.count > 100 { c.AbortWithStatusJSON(http.StatusTooManyRequests, gin.H{ "error": "rate_limit_exceeded", "message": "Too many requests. Please try again later.", }) return } c.Next() } } // AuthMiddleware validates JWT tokens func AuthMiddleware(jwtSecret string) gin.HandlerFunc { return func(c *gin.Context) { authHeader := c.GetHeader("Authorization") if authHeader == "" { c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{ "error": "missing_authorization", "message": "Authorization header is required", }) return } // Extract token from "Bearer " parts := strings.Split(authHeader, " ") if len(parts) != 2 || parts[0] != "Bearer" { c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{ "error": "invalid_authorization", "message": "Authorization header must be in format: Bearer ", }) return } tokenString := parts[1] // Parse and validate token token, err := jwt.ParseWithClaims(tokenString, &UserClaims{}, func(token *jwt.Token) (interface{}, error) { return []byte(jwtSecret), nil }) if err != nil { c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{ "error": "invalid_token", "message": "Invalid or expired token", }) return } if claims, ok := token.Claims.(*UserClaims); ok && token.Valid { // Set user info in context c.Set("user_id", claims.UserID) c.Set("email", claims.Email) c.Set("role", claims.Role) c.Next() } else { c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{ "error": "invalid_claims", "message": "Invalid token claims", }) return } } } // InternalAPIKeyMiddleware validates internal API key for service-to-service communication func InternalAPIKeyMiddleware(apiKey string) gin.HandlerFunc { return func(c *gin.Context) { if apiKey == "" { c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{ "error": "config_error", "message": "Internal API key not configured", }) return } providedKey := c.GetHeader("X-Internal-API-Key") if providedKey == "" { c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{ "error": "missing_api_key", "message": "X-Internal-API-Key header is required", }) return } if providedKey != apiKey { c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{ "error": "invalid_api_key", "message": "Invalid API key", }) return } c.Next() } } // AdminOnly ensures only admin users can access the route func AdminOnly() gin.HandlerFunc { return func(c *gin.Context) { role, exists := c.Get("role") if !exists { c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{ "error": "unauthorized", "message": "User role not found", }) return } roleStr, ok := role.(string) if !ok || (roleStr != "admin" && roleStr != "super_admin" && roleStr != "data_protection_officer") { c.AbortWithStatusJSON(http.StatusForbidden, gin.H{ "error": "forbidden", "message": "Admin access required", }) return } c.Next() } } // GetUserID extracts the user ID from the context func GetUserID(c *gin.Context) (uuid.UUID, error) { userIDStr, exists := c.Get("user_id") if !exists { return uuid.Nil, nil } userID, err := uuid.Parse(userIDStr.(string)) if err != nil { return uuid.Nil, err } return userID, nil } // GetClientIP returns the client's IP address func GetClientIP(c *gin.Context) string { // Check X-Forwarded-For header first (for proxied requests) if xff := c.GetHeader("X-Forwarded-For"); xff != "" { ips := strings.Split(xff, ",") return strings.TrimSpace(ips[0]) } // Check X-Real-IP header if xri := c.GetHeader("X-Real-IP"); xri != "" { return xri } return c.ClientIP() } // GetUserAgent returns the client's User-Agent func GetUserAgent(c *gin.Context) string { return c.GetHeader("User-Agent") }