feat: Add Academy, Whistleblower, Incidents, Vendor, DSB, SSO, Reporting, Multi-Tenant and Industry backends
Go handlers, models, stores and migrations for all SDK modules. Updates developer portal navigation and BYOEH page. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
631
ai-compliance-sdk/internal/api/handlers/sso_handlers.go
Normal file
631
ai-compliance-sdk/internal/api/handlers/sso_handlers.go
Normal file
@@ -0,0 +1,631 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/breakpilot/ai-compliance-sdk/internal/rbac"
|
||||
"github.com/breakpilot/ai-compliance-sdk/internal/sso"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// SSOHandlers handles SSO-related HTTP requests.
|
||||
type SSOHandlers struct {
|
||||
store *sso.Store
|
||||
jwtSecret string
|
||||
}
|
||||
|
||||
// NewSSOHandlers creates new SSO handlers.
|
||||
func NewSSOHandlers(store *sso.Store, jwtSecret string) *SSOHandlers {
|
||||
return &SSOHandlers{store: store, jwtSecret: jwtSecret}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// SSO Configuration CRUD
|
||||
// ============================================================================
|
||||
|
||||
// CreateConfig creates a new SSO configuration for the tenant.
|
||||
// POST /sdk/v1/sso/configs
|
||||
func (h *SSOHandlers) CreateConfig(c *gin.Context) {
|
||||
var req sso.CreateSSOConfigRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
tenantID := rbac.GetTenantID(c)
|
||||
|
||||
cfg, err := h.store.CreateConfig(c.Request.Context(), tenantID, &req)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusCreated, gin.H{"config": cfg})
|
||||
}
|
||||
|
||||
// ListConfigs lists all SSO configurations for the tenant.
|
||||
// GET /sdk/v1/sso/configs
|
||||
func (h *SSOHandlers) ListConfigs(c *gin.Context) {
|
||||
tenantID := rbac.GetTenantID(c)
|
||||
|
||||
configs, err := h.store.ListConfigs(c.Request.Context(), tenantID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"configs": configs,
|
||||
"total": len(configs),
|
||||
})
|
||||
}
|
||||
|
||||
// GetConfig retrieves an SSO configuration by ID.
|
||||
// GET /sdk/v1/sso/configs/:id
|
||||
func (h *SSOHandlers) GetConfig(c *gin.Context) {
|
||||
tenantID := rbac.GetTenantID(c)
|
||||
|
||||
configID, err := uuid.Parse(c.Param("id"))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid config ID"})
|
||||
return
|
||||
}
|
||||
|
||||
cfg, err := h.store.GetConfig(c.Request.Context(), tenantID, configID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if cfg == nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "sso configuration not found"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"config": cfg})
|
||||
}
|
||||
|
||||
// UpdateConfig updates an SSO configuration.
|
||||
// PUT /sdk/v1/sso/configs/:id
|
||||
func (h *SSOHandlers) UpdateConfig(c *gin.Context) {
|
||||
tenantID := rbac.GetTenantID(c)
|
||||
|
||||
configID, err := uuid.Parse(c.Param("id"))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid config ID"})
|
||||
return
|
||||
}
|
||||
|
||||
var req sso.UpdateSSOConfigRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
cfg, err := h.store.UpdateConfig(c.Request.Context(), tenantID, configID, &req)
|
||||
if err != nil {
|
||||
if err.Error() == "sso configuration not found" {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"config": cfg})
|
||||
}
|
||||
|
||||
// DeleteConfig deletes an SSO configuration.
|
||||
// DELETE /sdk/v1/sso/configs/:id
|
||||
func (h *SSOHandlers) DeleteConfig(c *gin.Context) {
|
||||
tenantID := rbac.GetTenantID(c)
|
||||
|
||||
configID, err := uuid.Parse(c.Param("id"))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid config ID"})
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.store.DeleteConfig(c.Request.Context(), tenantID, configID); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"message": "sso configuration deleted"})
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// SSO Users
|
||||
// ============================================================================
|
||||
|
||||
// ListUsers lists all SSO-provisioned users for the tenant.
|
||||
// GET /sdk/v1/sso/users
|
||||
func (h *SSOHandlers) ListUsers(c *gin.Context) {
|
||||
tenantID := rbac.GetTenantID(c)
|
||||
|
||||
users, err := h.store.ListUsers(c.Request.Context(), tenantID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"users": users,
|
||||
"total": len(users),
|
||||
})
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// OIDC Flow
|
||||
// ============================================================================
|
||||
|
||||
// InitiateOIDCLogin initiates the OIDC authorization code flow.
|
||||
// It looks up the enabled SSO config for the tenant, builds the authorization
|
||||
// URL, sets a state cookie, and redirects the user to the IdP.
|
||||
// GET /sdk/v1/sso/oidc/login
|
||||
func (h *SSOHandlers) InitiateOIDCLogin(c *gin.Context) {
|
||||
// Resolve tenant ID from query param
|
||||
tenantIDStr := c.Query("tenant_id")
|
||||
if tenantIDStr == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "tenant_id query parameter is required"})
|
||||
return
|
||||
}
|
||||
|
||||
tenantID, err := uuid.Parse(tenantIDStr)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid tenant_id"})
|
||||
return
|
||||
}
|
||||
|
||||
// Look up the enabled SSO config
|
||||
cfg, err := h.store.GetEnabledConfig(c.Request.Context(), tenantID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if cfg == nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "no enabled SSO configuration found for this tenant"})
|
||||
return
|
||||
}
|
||||
if cfg.ProviderType != sso.ProviderTypeOIDC {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "SSO configuration is not OIDC"})
|
||||
return
|
||||
}
|
||||
|
||||
// Discover the authorization endpoint
|
||||
discoveryURL := strings.TrimSuffix(cfg.OIDCIssuerURL, "/") + "/.well-known/openid-configuration"
|
||||
authEndpoint, _, _, err := discoverOIDCEndpoints(discoveryURL)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("OIDC discovery failed: %v", err)})
|
||||
return
|
||||
}
|
||||
|
||||
// Generate state parameter (random bytes + tenant_id for correlation)
|
||||
stateBytes := make([]byte, 32)
|
||||
if _, err := rand.Read(stateBytes); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate state"})
|
||||
return
|
||||
}
|
||||
state := base64.URLEncoding.EncodeToString(stateBytes) + "." + tenantID.String()
|
||||
|
||||
// Generate nonce
|
||||
nonceBytes := make([]byte, 16)
|
||||
if _, err := rand.Read(nonceBytes); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate nonce"})
|
||||
return
|
||||
}
|
||||
nonce := base64.URLEncoding.EncodeToString(nonceBytes)
|
||||
|
||||
// Build authorization URL
|
||||
scopes := cfg.OIDCScopes
|
||||
if len(scopes) == 0 {
|
||||
scopes = []string{"openid", "profile", "email"}
|
||||
}
|
||||
|
||||
params := url.Values{
|
||||
"client_id": {cfg.OIDCClientID},
|
||||
"redirect_uri": {cfg.OIDCRedirectURI},
|
||||
"response_type": {"code"},
|
||||
"scope": {strings.Join(scopes, " ")},
|
||||
"state": {state},
|
||||
"nonce": {nonce},
|
||||
}
|
||||
|
||||
authURL := authEndpoint + "?" + params.Encode()
|
||||
|
||||
// Set state cookie for CSRF protection (HttpOnly, 10 min expiry)
|
||||
c.SetCookie("sso_state", state, 600, "/", "", true, true)
|
||||
c.SetCookie("sso_nonce", nonce, 600, "/", "", true, true)
|
||||
|
||||
c.Redirect(http.StatusFound, authURL)
|
||||
}
|
||||
|
||||
// HandleOIDCCallback handles the OIDC authorization code callback from the IdP.
|
||||
// It validates the state, exchanges the code for tokens, extracts user info,
|
||||
// performs JIT user provisioning, and issues a JWT.
|
||||
// GET /sdk/v1/sso/oidc/callback
|
||||
func (h *SSOHandlers) HandleOIDCCallback(c *gin.Context) {
|
||||
// Check for errors from the IdP
|
||||
if errParam := c.Query("error"); errParam != "" {
|
||||
errDesc := c.Query("error_description")
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": errParam,
|
||||
"description": errDesc,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
code := c.Query("code")
|
||||
stateParam := c.Query("state")
|
||||
if code == "" || stateParam == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "missing code or state parameter"})
|
||||
return
|
||||
}
|
||||
|
||||
// Validate state cookie
|
||||
stateCookie, err := c.Cookie("sso_state")
|
||||
if err != nil || stateCookie != stateParam {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid state parameter (CSRF check failed)"})
|
||||
return
|
||||
}
|
||||
|
||||
// Extract tenant ID from state
|
||||
parts := strings.SplitN(stateParam, ".", 2)
|
||||
if len(parts) != 2 {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "malformed state parameter"})
|
||||
return
|
||||
}
|
||||
tenantID, err := uuid.Parse(parts[1])
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid tenant_id in state"})
|
||||
return
|
||||
}
|
||||
|
||||
// Look up the enabled SSO config
|
||||
cfg, err := h.store.GetEnabledConfig(c.Request.Context(), tenantID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if cfg == nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "no enabled SSO configuration found"})
|
||||
return
|
||||
}
|
||||
|
||||
// Discover OIDC endpoints
|
||||
discoveryURL := strings.TrimSuffix(cfg.OIDCIssuerURL, "/") + "/.well-known/openid-configuration"
|
||||
_, tokenEndpoint, userInfoEndpoint, err := discoverOIDCEndpoints(discoveryURL)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("OIDC discovery failed: %v", err)})
|
||||
return
|
||||
}
|
||||
|
||||
// Exchange authorization code for tokens
|
||||
tokenResp, err := exchangeCodeForTokens(tokenEndpoint, code, cfg.OIDCClientID, cfg.OIDCClientSecret, cfg.OIDCRedirectURI)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("token exchange failed: %v", err)})
|
||||
return
|
||||
}
|
||||
|
||||
// Extract user claims from ID token or UserInfo endpoint
|
||||
claims, err := extractUserClaims(tokenResp, userInfoEndpoint)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("failed to extract user claims: %v", err)})
|
||||
return
|
||||
}
|
||||
|
||||
sub := getStringClaim(claims, "sub")
|
||||
email := getStringClaim(claims, "email")
|
||||
name := getStringClaim(claims, "name")
|
||||
groups := getStringSliceClaim(claims, "groups")
|
||||
|
||||
if sub == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "ID token missing 'sub' claim"})
|
||||
return
|
||||
}
|
||||
if email == "" {
|
||||
email = sub
|
||||
}
|
||||
if name == "" {
|
||||
name = email
|
||||
}
|
||||
|
||||
// JIT provision the user
|
||||
user, err := h.store.UpsertUser(c.Request.Context(), tenantID, cfg.ID, sub, email, name, groups)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("user provisioning failed: %v", err)})
|
||||
return
|
||||
}
|
||||
|
||||
// Determine roles from role mapping
|
||||
roles := resolveRoles(cfg, groups)
|
||||
|
||||
// Generate JWT
|
||||
ssoClaims := sso.SSOClaims{
|
||||
UserID: user.ID,
|
||||
TenantID: tenantID,
|
||||
Email: user.Email,
|
||||
DisplayName: user.DisplayName,
|
||||
Roles: roles,
|
||||
SSOConfigID: cfg.ID,
|
||||
}
|
||||
|
||||
jwtToken, err := h.generateJWT(ssoClaims)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("JWT generation failed: %v", err)})
|
||||
return
|
||||
}
|
||||
|
||||
// Clear state cookies
|
||||
c.SetCookie("sso_state", "", -1, "/", "", true, true)
|
||||
c.SetCookie("sso_nonce", "", -1, "/", "", true, true)
|
||||
|
||||
// Return JWT as JSON (the frontend can also handle redirect)
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"token": jwtToken,
|
||||
"user": user,
|
||||
"roles": roles,
|
||||
})
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// JWT Generation
|
||||
// ============================================================================
|
||||
|
||||
// generateJWT creates a signed JWT token containing the SSO claims.
|
||||
func (h *SSOHandlers) generateJWT(claims sso.SSOClaims) (string, error) {
|
||||
now := time.Now().UTC()
|
||||
expiry := now.Add(24 * time.Hour)
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
|
||||
"user_id": claims.UserID.String(),
|
||||
"tenant_id": claims.TenantID.String(),
|
||||
"email": claims.Email,
|
||||
"display_name": claims.DisplayName,
|
||||
"roles": claims.Roles,
|
||||
"sso_config_id": claims.SSOConfigID.String(),
|
||||
"iss": "ai-compliance-sdk",
|
||||
"iat": now.Unix(),
|
||||
"exp": expiry.Unix(),
|
||||
})
|
||||
|
||||
tokenString, err := token.SignedString([]byte(h.jwtSecret))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to sign JWT: %w", err)
|
||||
}
|
||||
|
||||
return tokenString, nil
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// OIDC Discovery & Token Exchange (manual HTTP, no external OIDC library)
|
||||
// ============================================================================
|
||||
|
||||
// oidcDiscoveryResponse holds the relevant fields from the OIDC discovery document.
|
||||
type oidcDiscoveryResponse struct {
|
||||
AuthorizationEndpoint string `json:"authorization_endpoint"`
|
||||
TokenEndpoint string `json:"token_endpoint"`
|
||||
UserinfoEndpoint string `json:"userinfo_endpoint"`
|
||||
JwksURI string `json:"jwks_uri"`
|
||||
Issuer string `json:"issuer"`
|
||||
}
|
||||
|
||||
// discoverOIDCEndpoints fetches the OIDC discovery document and returns
|
||||
// the authorization, token, and userinfo endpoints.
|
||||
func discoverOIDCEndpoints(discoveryURL string) (authEndpoint, tokenEndpoint, userInfoEndpoint string, err error) {
|
||||
client := &http.Client{Timeout: 10 * time.Second}
|
||||
|
||||
resp, err := client.Get(discoveryURL)
|
||||
if err != nil {
|
||||
return "", "", "", fmt.Errorf("failed to fetch discovery document: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return "", "", "", fmt.Errorf("discovery endpoint returned %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var discovery oidcDiscoveryResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&discovery); err != nil {
|
||||
return "", "", "", fmt.Errorf("failed to decode discovery document: %w", err)
|
||||
}
|
||||
|
||||
if discovery.AuthorizationEndpoint == "" {
|
||||
return "", "", "", fmt.Errorf("discovery document missing authorization_endpoint")
|
||||
}
|
||||
if discovery.TokenEndpoint == "" {
|
||||
return "", "", "", fmt.Errorf("discovery document missing token_endpoint")
|
||||
}
|
||||
|
||||
return discovery.AuthorizationEndpoint, discovery.TokenEndpoint, discovery.UserinfoEndpoint, nil
|
||||
}
|
||||
|
||||
// oidcTokenResponse holds the response from the OIDC token endpoint.
|
||||
type oidcTokenResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
IDToken string `json:"id_token"`
|
||||
TokenType string `json:"token_type"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
RefreshToken string `json:"refresh_token,omitempty"`
|
||||
}
|
||||
|
||||
// exchangeCodeForTokens exchanges an authorization code for tokens at the token endpoint.
|
||||
func exchangeCodeForTokens(tokenEndpoint, code, clientID, clientSecret, redirectURI string) (*oidcTokenResponse, error) {
|
||||
client := &http.Client{Timeout: 10 * time.Second}
|
||||
|
||||
data := url.Values{
|
||||
"grant_type": {"authorization_code"},
|
||||
"code": {code},
|
||||
"client_id": {clientID},
|
||||
"redirect_uri": {redirectURI},
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("POST", tokenEndpoint, strings.NewReader(data.Encode()))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create token request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
|
||||
// Use client_secret_basic if provided
|
||||
if clientSecret != "" {
|
||||
req.SetBasicAuth(clientID, clientSecret)
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("token request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return nil, fmt.Errorf("token endpoint returned %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var tokenResp oidcTokenResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode token response: %w", err)
|
||||
}
|
||||
|
||||
return &tokenResp, nil
|
||||
}
|
||||
|
||||
// extractUserClaims extracts user claims from the ID token payload.
|
||||
// If the ID token is unavailable or incomplete, it falls back to the UserInfo endpoint.
|
||||
func extractUserClaims(tokenResp *oidcTokenResponse, userInfoEndpoint string) (map[string]interface{}, error) {
|
||||
claims := make(map[string]interface{})
|
||||
|
||||
// Try to decode ID token payload (without signature verification for claims extraction;
|
||||
// in production, you should verify the signature using the JWKS endpoint)
|
||||
if tokenResp.IDToken != "" {
|
||||
parts := strings.Split(tokenResp.IDToken, ".")
|
||||
if len(parts) == 3 {
|
||||
payload, err := base64.RawURLEncoding.DecodeString(parts[1])
|
||||
if err == nil {
|
||||
if err := json.Unmarshal(payload, &claims); err == nil && claims["sub"] != nil {
|
||||
return claims, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback to UserInfo endpoint
|
||||
if userInfoEndpoint != "" && tokenResp.AccessToken != "" {
|
||||
userClaims, err := fetchUserInfo(userInfoEndpoint, tokenResp.AccessToken)
|
||||
if err == nil && userClaims["sub"] != nil {
|
||||
return userClaims, nil
|
||||
}
|
||||
}
|
||||
|
||||
if claims["sub"] != nil {
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("could not extract user claims from ID token or UserInfo endpoint")
|
||||
}
|
||||
|
||||
// fetchUserInfo calls the OIDC UserInfo endpoint with the access token.
|
||||
func fetchUserInfo(userInfoEndpoint, accessToken string) (map[string]interface{}, error) {
|
||||
client := &http.Client{Timeout: 10 * time.Second}
|
||||
|
||||
req, err := http.NewRequest("GET", userInfoEndpoint, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("userinfo endpoint returned %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
var claims map[string]interface{}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&claims); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Claim Extraction Helpers
|
||||
// ============================================================================
|
||||
|
||||
// getStringClaim extracts a string claim from a claims map.
|
||||
func getStringClaim(claims map[string]interface{}, key string) string {
|
||||
if v, ok := claims[key]; ok {
|
||||
if s, ok := v.(string); ok {
|
||||
return s
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// getStringSliceClaim extracts a string slice claim from a claims map.
|
||||
func getStringSliceClaim(claims map[string]interface{}, key string) []string {
|
||||
v, ok := claims[key]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
switch val := v.(type) {
|
||||
case []interface{}:
|
||||
result := make([]string, 0, len(val))
|
||||
for _, item := range val {
|
||||
if s, ok := item.(string); ok {
|
||||
result = append(result, s)
|
||||
}
|
||||
}
|
||||
return result
|
||||
case []string:
|
||||
return val
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// resolveRoles maps SSO groups to internal roles using the config's role mapping.
|
||||
// If no groups match, the default role is returned.
|
||||
func resolveRoles(cfg *sso.SSOConfig, groups []string) []string {
|
||||
if cfg.RoleMapping == nil || len(cfg.RoleMapping) == 0 {
|
||||
if cfg.DefaultRoleID != nil {
|
||||
return []string{cfg.DefaultRoleID.String()}
|
||||
}
|
||||
return []string{"compliance_user"}
|
||||
}
|
||||
|
||||
roleSet := make(map[string]bool)
|
||||
for _, group := range groups {
|
||||
if role, ok := cfg.RoleMapping[group]; ok {
|
||||
roleSet[role] = true
|
||||
}
|
||||
}
|
||||
|
||||
if len(roleSet) == 0 {
|
||||
if cfg.DefaultRoleID != nil {
|
||||
return []string{cfg.DefaultRoleID.String()}
|
||||
}
|
||||
return []string{"compliance_user"}
|
||||
}
|
||||
|
||||
roles := make([]string, 0, len(roleSet))
|
||||
for role := range roleSet {
|
||||
roles = append(roles, role)
|
||||
}
|
||||
return roles
|
||||
}
|
||||
Reference in New Issue
Block a user