package middleware import ( "net/http" "strings" "sync" "time" "github.com/gin-gonic/gin" "github.com/golang-jwt/jwt/v5" "github.com/google/uuid" ) // UserClaims represents the JWT claims for a user type UserClaims struct { UserID string `json:"user_id"` Email string `json:"email"` Role string `json:"role"` jwt.RegisteredClaims } // CORS returns a CORS middleware func CORS() gin.HandlerFunc { return func(c *gin.Context) { origin := c.Request.Header.Get("Origin") // Allow localhost for development allowedOrigins := []string{ "http://localhost:3000", "http://localhost:8000", "http://localhost:8080", "https://breakpilot.app", } allowed := false for _, o := range allowedOrigins { if origin == o { allowed = true break } } if allowed { c.Header("Access-Control-Allow-Origin", origin) } c.Header("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS") c.Header("Access-Control-Allow-Headers", "Origin, Content-Type, Authorization, X-Requested-With") c.Header("Access-Control-Allow-Credentials", "true") c.Header("Access-Control-Max-Age", "86400") if c.Request.Method == "OPTIONS" { c.AbortWithStatus(http.StatusNoContent) return } c.Next() } } // RequestLogger logs each request func RequestLogger() gin.HandlerFunc { return func(c *gin.Context) { start := time.Now() path := c.Request.URL.Path method := c.Request.Method c.Next() latency := time.Since(start) status := c.Writer.Status() // Log only in development or for errors if status >= 400 { gin.DefaultWriter.Write([]byte( method + " " + path + " " + string(rune(status)) + " " + latency.String() + "\n", )) } } } // RateLimiter implements a simple in-memory rate limiter // Configurable via RATE_LIMIT_PER_MINUTE env var (default: 500) func RateLimiter() gin.HandlerFunc { type client struct { count int lastSeen time.Time } var ( mu sync.Mutex clients = make(map[string]*client) ) // Clean up old entries periodically go func() { for { time.Sleep(time.Minute) mu.Lock() for ip, c := range clients { if time.Since(c.lastSeen) > time.Minute { delete(clients, ip) } } mu.Unlock() } }() return func(c *gin.Context) { ip := c.ClientIP() // Skip rate limiting for Docker internal network (172.x.x.x) and localhost // This prevents issues when multiple services share the same internal IP if strings.HasPrefix(ip, "172.") || ip == "127.0.0.1" || ip == "::1" { c.Next() return } mu.Lock() defer mu.Unlock() if _, exists := clients[ip]; !exists { clients[ip] = &client{} } cli := clients[ip] // Reset count if more than a minute has passed if time.Since(cli.lastSeen) > time.Minute { cli.count = 0 } cli.count++ cli.lastSeen = time.Now() // Allow 500 requests per minute (increased for admin panels with many API calls) if cli.count > 500 { c.AbortWithStatusJSON(http.StatusTooManyRequests, gin.H{ "error": "rate_limit_exceeded", "message": "Too many requests. Please try again later.", }) return } c.Next() } } // AuthMiddleware validates JWT tokens func AuthMiddleware(jwtSecret string) gin.HandlerFunc { return func(c *gin.Context) { authHeader := c.GetHeader("Authorization") if authHeader == "" { c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{ "error": "missing_authorization", "message": "Authorization header is required", }) return } // Extract token from "Bearer " parts := strings.Split(authHeader, " ") if len(parts) != 2 || parts[0] != "Bearer" { c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{ "error": "invalid_authorization", "message": "Authorization header must be in format: Bearer ", }) return } tokenString := parts[1] // Parse and validate token token, err := jwt.ParseWithClaims(tokenString, &UserClaims{}, func(token *jwt.Token) (interface{}, error) { return []byte(jwtSecret), nil }) if err != nil { c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{ "error": "invalid_token", "message": "Invalid or expired token", }) return } if claims, ok := token.Claims.(*UserClaims); ok && token.Valid { // Set user info in context c.Set("user_id", claims.UserID) c.Set("email", claims.Email) c.Set("role", claims.Role) c.Next() } else { c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{ "error": "invalid_claims", "message": "Invalid token claims", }) return } } } // AdminOnly ensures only admin users can access the route func AdminOnly() gin.HandlerFunc { return func(c *gin.Context) { role, exists := c.Get("role") if !exists { c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{ "error": "unauthorized", "message": "User role not found", }) return } roleStr, ok := role.(string) if !ok || (roleStr != "admin" && roleStr != "super_admin" && roleStr != "data_protection_officer") { c.AbortWithStatusJSON(http.StatusForbidden, gin.H{ "error": "forbidden", "message": "Admin access required", }) return } c.Next() } } // DSBOnly ensures only Data Protection Officers can access the route // Used for critical operations like publishing legal documents (four-eyes principle) func DSBOnly() gin.HandlerFunc { return func(c *gin.Context) { role, exists := c.Get("role") if !exists { c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{ "error": "unauthorized", "message": "User role not found", }) return } roleStr, ok := role.(string) if !ok || (roleStr != "data_protection_officer" && roleStr != "super_admin") { c.AbortWithStatusJSON(http.StatusForbidden, gin.H{ "error": "forbidden", "message": "Only Data Protection Officers can perform this action", }) return } c.Next() } } // IsAdmin checks if the user has admin role func IsAdmin(c *gin.Context) bool { role, exists := c.Get("role") if !exists { return false } roleStr, ok := role.(string) return ok && (roleStr == "admin" || roleStr == "super_admin" || roleStr == "data_protection_officer") } // IsDSB checks if the user has DSB role func IsDSB(c *gin.Context) bool { role, exists := c.Get("role") if !exists { return false } roleStr, ok := role.(string) return ok && (roleStr == "data_protection_officer" || roleStr == "super_admin") } // GetUserID extracts the user ID from the context func GetUserID(c *gin.Context) (uuid.UUID, error) { userIDStr, exists := c.Get("user_id") if !exists { return uuid.Nil, nil } userID, err := uuid.Parse(userIDStr.(string)) if err != nil { return uuid.Nil, err } return userID, nil } // GetClientIP returns the client's IP address func GetClientIP(c *gin.Context) string { // Check X-Forwarded-For header first (for proxied requests) if xff := c.GetHeader("X-Forwarded-For"); xff != "" { ips := strings.Split(xff, ",") return strings.TrimSpace(ips[0]) } // Check X-Real-IP header if xri := c.GetHeader("X-Real-IP"); xri != "" { return xri } return c.ClientIP() } // GetUserAgent returns the client's User-Agent func GetUserAgent(c *gin.Context) string { return c.GetHeader("User-Agent") } // SuspensionCheckMiddleware checks if a user is suspended and restricts access // Suspended users can only access consent-related endpoints func SuspensionCheckMiddleware(pool interface{ QueryRow(ctx interface{}, sql string, args ...interface{}) interface{ Scan(dest ...interface{}) error } }) gin.HandlerFunc { return func(c *gin.Context) { userIDStr, exists := c.Get("user_id") if !exists { c.Next() return } userID, err := uuid.Parse(userIDStr.(string)) if err != nil { c.Next() return } // Check user account status var accountStatus string err = pool.QueryRow(c.Request.Context(), `SELECT account_status FROM users WHERE id = $1`, userID).Scan(&accountStatus) if err != nil { c.Next() return } if accountStatus == "suspended" { // Check if current path is allowed for suspended users path := c.Request.URL.Path allowedPaths := []string{ "/api/v1/consent", "/api/v1/documents", "/api/v1/notifications", "/api/v1/profile", "/api/v1/privacy/my-data", "/api/v1/auth/logout", } allowed := false for _, p := range allowedPaths { if strings.HasPrefix(path, p) { allowed = true break } } if !allowed { c.AbortWithStatusJSON(http.StatusForbidden, gin.H{ "error": "account_suspended", "message": "Your account is suspended due to pending consent requirements", "redirect": "/consent/pending", }) return } // Set suspended flag in context for handlers to use c.Set("account_suspended", true) } c.Next() } } // IsSuspended checks if the current user's account is suspended func IsSuspended(c *gin.Context) bool { suspended, exists := c.Get("account_suspended") if !exists { return false } return suspended.(bool) }