package middleware import ( "net/http" "strings" "sync" "time" "github.com/gin-gonic/gin" "github.com/golang-jwt/jwt/v5" ) // UserClaims represents the JWT claims structure type UserClaims struct { UserID string `json:"user_id"` Email string `json:"email"` Role string `json:"role"` jwt.RegisteredClaims } // CORS middleware func CORS() gin.HandlerFunc { return func(c *gin.Context) { origin := c.Request.Header.Get("Origin") if origin == "" { origin = "*" } c.Writer.Header().Set("Access-Control-Allow-Origin", origin) c.Writer.Header().Set("Access-Control-Allow-Credentials", "true") c.Writer.Header().Set("Access-Control-Allow-Headers", "Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, accept, origin, Cache-Control, X-Requested-With") c.Writer.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS, GET, PUT, DELETE, PATCH") if c.Request.Method == "OPTIONS" { c.AbortWithStatus(204) return } c.Next() } } // RequestLogger logs HTTP requests func RequestLogger() gin.HandlerFunc { return func(c *gin.Context) { start := time.Now() path := c.Request.URL.Path c.Next() latency := time.Since(start) status := c.Writer.Status() if status >= 400 { gin.DefaultWriter.Write([]byte( c.Request.Method + " " + path + " " + http.StatusText(status) + " " + latency.String() + "\n", )) } } } // Rate limiter storage var ( rateLimitMu sync.Mutex rateLimits = make(map[string][]time.Time) ) // RateLimiter implements per-IP rate limiting func RateLimiter() gin.HandlerFunc { return func(c *gin.Context) { ip := c.ClientIP() // Skip rate limiting for internal Docker IPs if strings.HasPrefix(ip, "172.") || strings.HasPrefix(ip, "10.") || ip == "127.0.0.1" { c.Next() return } rateLimitMu.Lock() defer rateLimitMu.Unlock() now := time.Now() windowStart := now.Add(-time.Minute) // Clean old entries var recent []time.Time for _, t := range rateLimits[ip] { if t.After(windowStart) { recent = append(recent, t) } } // Check limit (500 requests per minute) if len(recent) >= 500 { c.AbortWithStatusJSON(http.StatusTooManyRequests, gin.H{ "error": "Rate limit exceeded", }) return } rateLimits[ip] = append(recent, now) c.Next() } } // AuthMiddleware validates JWT tokens func AuthMiddleware(jwtSecret string) gin.HandlerFunc { return func(c *gin.Context) { authHeader := c.GetHeader("Authorization") if authHeader == "" { c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{ "error": "Authorization header required", }) return } tokenString := strings.TrimPrefix(authHeader, "Bearer ") if tokenString == authHeader { c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{ "error": "Invalid authorization format", }) return } token, err := jwt.ParseWithClaims(tokenString, &UserClaims{}, func(token *jwt.Token) (interface{}, error) { return []byte(jwtSecret), nil }) if err != nil || !token.Valid { c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{ "error": "Invalid or expired token", }) return } claims, ok := token.Claims.(*UserClaims) if !ok { c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{ "error": "Invalid token claims", }) return } // Set user info in context c.Set("user_id", claims.UserID) c.Set("email", claims.Email) c.Set("role", claims.Role) c.Next() } } // AdminOnly restricts access to admin users func AdminOnly() gin.HandlerFunc { return func(c *gin.Context) { role := c.GetString("role") if role != "admin" && role != "super_admin" { c.AbortWithStatusJSON(http.StatusForbidden, gin.H{ "error": "Admin access required", }) return } c.Next() } }