package rbac

import (
	"context"
	"net/http"
	"strings"

	"github.com/gin-gonic/gin"
	"github.com/google/uuid"
)

// contextKey is the type of the keys this package stores in a standard
// context.Context. A dedicated named type keeps them from colliding with keys
// set by other packages.
type contextKey string

// Context keys for RBAC data. The middleware in this file also uses the string
// form of each key (e.g. "user_id") as a gin.Context key.
const (
	ContextKeyUserID      contextKey = "user_id"
	ContextKeyTenantID    contextKey = "tenant_id"
	ContextKeyNamespaceID contextKey = "namespace_id"
	ContextKeyPermissions contextKey = "permissions"
	ContextKeyRoles       contextKey = "roles"
	ContextKeyUserContext contextKey = "user_context"
)

// Middleware provides RBAC middleware for Gin.
type Middleware struct {
	service      *Service
	policyEngine *PolicyEngine
}

// NewMiddleware creates a new RBAC middleware. The service is used for tenant
// lookup and permission checks; the policy engine evaluates namespace access.
func NewMiddleware(service *Service, policyEngine *PolicyEngine) *Middleware {
	return &Middleware{
		service:      service,
		policyEngine: policyEngine,
	}
}

// ExtractUserContext reads the caller's identity from the request and stores it
// in the Gin context for the Require* middlewares and the Get* helpers.
//
// It must run after authentication: X-User-ID is trusted as-is.
// NOTE(review): the auth layer is assumed to overwrite/strip any
// client-supplied X-User-ID, X-Tenant-ID and X-Namespace-ID — confirm upstream.
//
// Resolution order:
//   - user: X-User-ID header; if missing or malformed the request continues
//     anonymously with nothing stored.
//   - tenant: X-Tenant-ID header, then the tenant_id query parameter, then the
//     X-Tenant-Slug header resolved through the store. Unresolvable values
//     yield uuid.Nil.
//   - namespace (optional): X-Namespace-ID header, then the namespace_id query
//     parameter.
//
// The IDs are also attached to c.Request's context.Context so code that only
// receives that context can use UserIDFromContext and friends. Effective
// permissions and roles are loaded best-effort when a tenant is known; if the
// lookup fails they are left unset and the Require* checks fail closed.
func (m *Middleware) ExtractUserContext() gin.HandlerFunc {
	return func(c *gin.Context) {
		// uuid.Parse rejects "", so a missing header takes this path too.
		userID, err := uuid.Parse(c.GetHeader("X-User-ID"))
		if err != nil {
			c.Next()
			return
		}

		tenantID := m.resolveTenantID(c)

		var namespaceID *uuid.UUID
		nsRaw := c.GetHeader("X-Namespace-ID")
		if nsRaw == "" {
			nsRaw = c.Query("namespace_id")
		}
		if nsID, err := uuid.Parse(nsRaw); err == nil {
			namespaceID = &nsID
		}

		// Store in both the Gin context and the request context so the two
		// families of accessors in this file agree.
		c.Set(string(ContextKeyUserID), userID)
		c.Set(string(ContextKeyTenantID), tenantID)
		ctx := ContextWithTenantID(ContextWithUserID(c.Request.Context(), userID), tenantID)
		if namespaceID != nil {
			c.Set(string(ContextKeyNamespaceID), *namespaceID)
			ctx = ContextWithNamespaceID(ctx, *namespaceID)
		}
		c.Request = c.Request.WithContext(ctx)

		if tenantID != uuid.Nil {
			perms, err := m.service.GetEffectivePermissions(ctx, userID, tenantID, namespaceID)
			if err == nil {
				c.Set(string(ContextKeyPermissions), perms.Permissions)
				c.Set(string(ContextKeyRoles), perms.Roles)
			}
		}

		c.Next()
	}
}

// resolveTenantID determines the request's tenant. The first source that is
// present decides the outcome (X-Tenant-ID header, tenant_id query parameter,
// X-Tenant-Slug header); a present-but-invalid value, or a slug the store
// cannot resolve, yields uuid.Nil rather than falling through.
func (m *Middleware) resolveTenantID(c *gin.Context) uuid.UUID {
	raw := c.GetHeader("X-Tenant-ID")
	if raw == "" {
		raw = c.Query("tenant_id")
	}
	if raw != "" {
		id, err := uuid.Parse(raw)
		if err != nil {
			return uuid.Nil
		}
		return id
	}

	slug := c.GetHeader("X-Tenant-Slug")
	if slug == "" {
		return uuid.Nil
	}
	tenant, err := m.service.store.GetTenantBySlug(c.Request.Context(), slug)
	if err != nil {
		// Do not fall back to interpreting the slug itself as a tenant ID.
		return uuid.Nil
	}
	// Round-trip through String() so this does not depend on the concrete
	// type of tenant.ID.
	id, err := uuid.Parse(tenant.ID.String())
	if err != nil {
		return uuid.Nil
	}
	return id
}

// authenticatedIDs returns the caller's user, tenant and optional namespace IDs
// from the Gin context. If the user or tenant is missing it responds 401,
// aborts the handler chain and returns ok == false.
func (m *Middleware) authenticatedIDs(c *gin.Context) (userID, tenantID uuid.UUID, namespaceID *uuid.UUID, ok bool) {
	userID, tenantID, namespaceID = m.extractIDs(c)
	if userID == uuid.Nil || tenantID == uuid.Nil {
		c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
			"error":   "unauthorized",
			"message": "Authentication required",
		})
		return uuid.Nil, uuid.Nil, nil, false
	}
	return userID, tenantID, namespaceID, true
}

// RequirePermission requires the caller to hold permission. Missing identity
// yields 401; a denial or a lookup error yields 403 (fail closed).
func (m *Middleware) RequirePermission(permission string) gin.HandlerFunc {
	return func(c *gin.Context) {
		userID, tenantID, namespaceID, ok := m.authenticatedIDs(c)
		if !ok {
			return
		}

		allowed, err := m.service.HasPermission(c.Request.Context(), userID, tenantID, namespaceID, permission)
		if err != nil || !allowed {
			c.AbortWithStatusJSON(http.StatusForbidden, gin.H{
				"error":    "forbidden",
				"message":  "Insufficient permissions",
				"required": permission,
			})
			return
		}

		c.Next()
	}
}

// RequireAnyPermission requires the caller to hold at least one of permissions.
// Missing identity yields 401; a denial or a lookup error yields 403.
func (m *Middleware) RequireAnyPermission(permissions ...string) gin.HandlerFunc {
	return func(c *gin.Context) {
		userID, tenantID, namespaceID, ok := m.authenticatedIDs(c)
		if !ok {
			return
		}

		allowed, err := m.service.HasAnyPermission(c.Request.Context(), userID, tenantID, namespaceID, permissions)
		if err != nil || !allowed {
			c.AbortWithStatusJSON(http.StatusForbidden, gin.H{
				"error":    "forbidden",
				"message":  "Insufficient permissions",
				"required": permissions,
			})
			return
		}

		c.Next()
	}
}

// RequireAllPermissions requires the caller to hold every one of permissions.
// Missing identity yields 401; a denial or a lookup error yields 403.
func (m *Middleware) RequireAllPermissions(permissions ...string) gin.HandlerFunc {
	return func(c *gin.Context) {
		userID, tenantID, namespaceID, ok := m.authenticatedIDs(c)
		if !ok {
			return
		}

		allowed, err := m.service.HasAllPermissions(c.Request.Context(), userID, tenantID, namespaceID, permissions)
		if err != nil || !allowed {
			c.AbortWithStatusJSON(http.StatusForbidden, gin.H{
				"error":    "forbidden",
				"message":  "Insufficient permissions - all required",
				"required": permissions,
			})
			return
		}

		c.Next()
	}
}

// RequireNamespaceAccess requires the policy engine to allow operation on the
// target namespace.
//
// The target is the namespace_id (or namespaceId) URL path parameter when
// present; otherwise it is the namespace stored by ExtractUserContext. On
// success the authorized namespace is written back to the Gin and request
// contexts, and the evaluation result is stored under "namespace_access".
func (m *Middleware) RequireNamespaceAccess(operation string) gin.HandlerFunc {
	return func(c *gin.Context) {
		userID, tenantID, namespaceID, ok := m.authenticatedIDs(c)
		if !ok {
			return
		}

		// The path identifies the resource the handler will act on, so it must
		// win over the header/query namespace; otherwise a caller could be
		// authorized for one namespace while operating on another.
		nsParam := c.Param("namespace_id")
		if nsParam == "" {
			nsParam = c.Param("namespaceId")
		}
		if nsParam != "" {
			nsID, err := uuid.Parse(nsParam)
			if err != nil {
				c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{
					"error":   "bad_request",
					"message": "Invalid namespace ID",
				})
				return
			}
			namespaceID = &nsID
		}

		if namespaceID == nil {
			c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{
				"error":   "bad_request",
				"message": "Namespace ID required",
			})
			return
		}

		result, err := m.policyEngine.EvaluateNamespaceAccess(c.Request.Context(), &NamespaceAccessRequest{
			UserID:      userID,
			TenantID:    tenantID,
			NamespaceID: *namespaceID,
			Operation:   operation,
		})
		// Guard against a nil result with a nil error as well as a denial.
		if err != nil || result == nil || !result.Allowed {
			reason := "access denied"
			if result != nil {
				reason = result.Reason
			}
			c.AbortWithStatusJSON(http.StatusForbidden, gin.H{
				"error":     "forbidden",
				"message":   "Namespace access denied",
				"reason":    reason,
				"namespace": namespaceID.String(),
			})
			return
		}

		// Keep downstream lookups consistent with the namespace just authorized.
		c.Set(string(ContextKeyNamespaceID), *namespaceID)
		c.Request = c.Request.WithContext(ContextWithNamespaceID(c.Request.Context(), *namespaceID))
		c.Set("namespace_access", result)
		c.Next()
	}
}

// RequireLLMAccess requires the caller to hold at least one of
// PermissionLLMAll, PermissionLLMQuery or PermissionLLMOwnQuery. This is a plain
// permission check; the policy engine is not consulted here.
func (m *Middleware) RequireLLMAccess() gin.HandlerFunc {
	return func(c *gin.Context) {
		userID, tenantID, namespaceID, ok := m.authenticatedIDs(c)
		if !ok {
			return
		}

		allowed, err := m.service.HasAnyPermission(c.Request.Context(), userID, tenantID, namespaceID, []string{
			PermissionLLMAll,
			PermissionLLMQuery,
			PermissionLLMOwnQuery,
		})
		if err != nil || !allowed {
			c.AbortWithStatusJSON(http.StatusForbidden, gin.H{
				"error":   "forbidden",
				"message": "LLM access denied",
			})
			return
		}

		c.Next()
	}
}

// RequireRole requires the caller's roles, as stored by ExtractUserContext, to
// include role.
//
// Roles are only stored when the effective-permission lookup succeeded, so
// their absence is reported as 401. A stored value that is not a []string
// yields 500; a missing role yields 403.
func (m *Middleware) RequireRole(role string) gin.HandlerFunc {
	return func(c *gin.Context) {
		raw, exists := c.Get(string(ContextKeyRoles))
		if !exists {
			c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
				"error":   "unauthorized",
				"message": "Authentication required",
			})
			return
		}

		roles, ok := raw.([]string)
		if !ok {
			c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{
				"error":   "internal_error",
				"message": "Invalid role data",
			})
			return
		}

		for _, r := range roles {
			if r == role {
				c.Next()
				return
			}
		}

		c.AbortWithStatusJSON(http.StatusForbidden, gin.H{
			"error":    "forbidden",
			"message":  "Required role missing",
			"required": role,
		})
	}
}

// extractIDs returns the user, tenant and namespace IDs stored in the Gin
// context. Missing or wrongly-typed values yield uuid.Nil / nil.
func (m *Middleware) extractIDs(c *gin.Context) (uuid.UUID, uuid.UUID, *uuid.UUID) {
	return GetUserID(c), GetTenantID(c), GetNamespaceID(c)
}

// GetUserID returns the user ID stored in the Gin context, or uuid.Nil if it is
// absent or not a uuid.UUID (e.g. another middleware reused the "user_id" key).
func GetUserID(c *gin.Context) uuid.UUID {
	v, _ := c.Get(string(ContextKeyUserID))
	id, _ := v.(uuid.UUID) // zero value is uuid.Nil
	return id
}

// GetTenantID returns the tenant ID stored in the Gin context, or uuid.Nil if
// it is absent or not a uuid.UUID.
func GetTenantID(c *gin.Context) uuid.UUID {
	v, _ := c.Get(string(ContextKeyTenantID))
	id, _ := v.(uuid.UUID) // zero value is uuid.Nil
	return id
}

// GetNamespaceID returns a copy of the namespace ID stored in the Gin context,
// or nil if it is absent or not a uuid.UUID.
func GetNamespaceID(c *gin.Context) *uuid.UUID {
	v, _ := c.Get(string(ContextKeyNamespaceID))
	if nsID, ok := v.(uuid.UUID); ok {
		return &nsID
	}
	return nil
}

// GetPermissions returns the permissions stored in the Gin context. It returns
// an empty (non-nil) slice when none are stored or the stored value is not a
// []string.
func GetPermissions(c *gin.Context) []string {
	v, _ := c.Get(string(ContextKeyPermissions))
	if perms, ok := v.([]string); ok {
		return perms
	}
	return []string{}
}

// GetRoles returns the roles stored in the Gin context. It returns an empty
// (non-nil) slice when none are stored or the stored value is not a []string.
func GetRoles(c *gin.Context) []string {
	v, _ := c.Get(string(ContextKeyRoles))
	if roles, ok := v.([]string); ok {
		return roles
	}
	return []string{}
}

// ContextWithUserID returns a copy of ctx carrying userID.
func ContextWithUserID(ctx context.Context, userID uuid.UUID) context.Context {
	return context.WithValue(ctx, ContextKeyUserID, userID)
}

// ContextWithTenantID returns a copy of ctx carrying tenantID.
func ContextWithTenantID(ctx context.Context, tenantID uuid.UUID) context.Context {
	return context.WithValue(ctx, ContextKeyTenantID, tenantID)
}

// ContextWithNamespaceID returns a copy of ctx carrying namespaceID.
func ContextWithNamespaceID(ctx context.Context, namespaceID uuid.UUID) context.Context {
	return context.WithValue(ctx, ContextKeyNamespaceID, namespaceID)
}

// UserIDFromContext returns the user ID carried by ctx, or uuid.Nil.
func UserIDFromContext(ctx context.Context) uuid.UUID {
	if id, ok := ctx.Value(ContextKeyUserID).(uuid.UUID); ok {
		return id
	}
	return uuid.Nil
}

// TenantIDFromContext returns the tenant ID carried by ctx, or uuid.Nil.
func TenantIDFromContext(ctx context.Context) uuid.UUID {
	if id, ok := ctx.Value(ContextKeyTenantID).(uuid.UUID); ok {
		return id
	}
	return uuid.Nil
}

// NamespaceIDFromContext returns the namespace ID carried by ctx, or nil.
func NamespaceIDFromContext(ctx context.Context) *uuid.UUID {
	if id, ok := ctx.Value(ContextKeyNamespaceID).(uuid.UUID); ok {
		return &id
	}
	return nil
}

// HasPermissionFromHeader reports whether the comma-separated X-API-Permissions
// header grants permission. Entries are whitespace-trimmed. An entry ending in
// ":*" grants every permission starting with the text before the "*"
// (e.g. "llm:*" grants "llm:query"). An empty permission is never granted.
//
// NOTE(review): the header is trusted as-is; the API-key auth layer is assumed
// to strip any client-supplied value before this is consulted — confirm.
func (m *Middleware) HasPermissionFromHeader(c *gin.Context, permission string) bool {
	if permission == "" {
		return false
	}
	header := c.GetHeader("X-API-Permissions")
	if header == "" {
		return false
	}
	for _, granted := range strings.Split(header, ",") {
		granted = strings.TrimSpace(granted)
		if granted == permission {
			return true
		}
		if strings.HasSuffix(granted, ":*") && strings.HasPrefix(permission, strings.TrimSuffix(granted, "*")) {
			return true
		}
	}
	return false
}