package handlers import ( "crypto/rand" "encoding/base64" "encoding/json" "fmt" "io" "net/http" "net/url" "strings" "time" "github.com/breakpilot/ai-compliance-sdk/internal/rbac" "github.com/breakpilot/ai-compliance-sdk/internal/sso" "github.com/gin-gonic/gin" "github.com/golang-jwt/jwt/v5" "github.com/google/uuid" ) // SSOHandlers handles SSO-related HTTP requests. type SSOHandlers struct { store *sso.Store jwtSecret string } // NewSSOHandlers creates new SSO handlers. func NewSSOHandlers(store *sso.Store, jwtSecret string) *SSOHandlers { return &SSOHandlers{store: store, jwtSecret: jwtSecret} } // ============================================================================ // SSO Configuration CRUD // ============================================================================ // CreateConfig creates a new SSO configuration for the tenant. // POST /sdk/v1/sso/configs func (h *SSOHandlers) CreateConfig(c *gin.Context) { var req sso.CreateSSOConfigRequest if err := c.ShouldBindJSON(&req); err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } tenantID := rbac.GetTenantID(c) cfg, err := h.store.CreateConfig(c.Request.Context(), tenantID, &req) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } c.JSON(http.StatusCreated, gin.H{"config": cfg}) } // ListConfigs lists all SSO configurations for the tenant. // GET /sdk/v1/sso/configs func (h *SSOHandlers) ListConfigs(c *gin.Context) { tenantID := rbac.GetTenantID(c) configs, err := h.store.ListConfigs(c.Request.Context(), tenantID) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } c.JSON(http.StatusOK, gin.H{ "configs": configs, "total": len(configs), }) } // GetConfig retrieves an SSO configuration by ID. // GET /sdk/v1/sso/configs/:id func (h *SSOHandlers) GetConfig(c *gin.Context) { tenantID := rbac.GetTenantID(c) configID, err := uuid.Parse(c.Param("id")) if err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": "invalid config ID"}) return } cfg, err := h.store.GetConfig(c.Request.Context(), tenantID, configID) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } if cfg == nil { c.JSON(http.StatusNotFound, gin.H{"error": "sso configuration not found"}) return } c.JSON(http.StatusOK, gin.H{"config": cfg}) } // UpdateConfig updates an SSO configuration. // PUT /sdk/v1/sso/configs/:id func (h *SSOHandlers) UpdateConfig(c *gin.Context) { tenantID := rbac.GetTenantID(c) configID, err := uuid.Parse(c.Param("id")) if err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": "invalid config ID"}) return } var req sso.UpdateSSOConfigRequest if err := c.ShouldBindJSON(&req); err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } cfg, err := h.store.UpdateConfig(c.Request.Context(), tenantID, configID, &req) if err != nil { if err.Error() == "sso configuration not found" { c.JSON(http.StatusNotFound, gin.H{"error": err.Error()}) return } c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } c.JSON(http.StatusOK, gin.H{"config": cfg}) } // DeleteConfig deletes an SSO configuration. // DELETE /sdk/v1/sso/configs/:id func (h *SSOHandlers) DeleteConfig(c *gin.Context) { tenantID := rbac.GetTenantID(c) configID, err := uuid.Parse(c.Param("id")) if err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": "invalid config ID"}) return } if err := h.store.DeleteConfig(c.Request.Context(), tenantID, configID); err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } c.JSON(http.StatusOK, gin.H{"message": "sso configuration deleted"}) } // ============================================================================ // SSO Users // ============================================================================ // ListUsers lists all SSO-provisioned users for the tenant. // GET /sdk/v1/sso/users func (h *SSOHandlers) ListUsers(c *gin.Context) { tenantID := rbac.GetTenantID(c) users, err := h.store.ListUsers(c.Request.Context(), tenantID) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } c.JSON(http.StatusOK, gin.H{ "users": users, "total": len(users), }) } // ============================================================================ // OIDC Flow // ============================================================================ // InitiateOIDCLogin initiates the OIDC authorization code flow. // It looks up the enabled SSO config for the tenant, builds the authorization // URL, sets a state cookie, and redirects the user to the IdP. // GET /sdk/v1/sso/oidc/login func (h *SSOHandlers) InitiateOIDCLogin(c *gin.Context) { // Resolve tenant ID from query param tenantIDStr := c.Query("tenant_id") if tenantIDStr == "" { c.JSON(http.StatusBadRequest, gin.H{"error": "tenant_id query parameter is required"}) return } tenantID, err := uuid.Parse(tenantIDStr) if err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": "invalid tenant_id"}) return } // Look up the enabled SSO config cfg, err := h.store.GetEnabledConfig(c.Request.Context(), tenantID) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } if cfg == nil { c.JSON(http.StatusNotFound, gin.H{"error": "no enabled SSO configuration found for this tenant"}) return } if cfg.ProviderType != sso.ProviderTypeOIDC { c.JSON(http.StatusBadRequest, gin.H{"error": "SSO configuration is not OIDC"}) return } // Discover the authorization endpoint discoveryURL := strings.TrimSuffix(cfg.OIDCIssuerURL, "/") + "/.well-known/openid-configuration" authEndpoint, _, _, err := discoverOIDCEndpoints(discoveryURL) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("OIDC discovery failed: %v", err)}) return } // Generate state parameter (random bytes + tenant_id for correlation) stateBytes := make([]byte, 32) if _, err := rand.Read(stateBytes); err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate state"}) return } state := base64.URLEncoding.EncodeToString(stateBytes) + "." + tenantID.String() // Generate nonce nonceBytes := make([]byte, 16) if _, err := rand.Read(nonceBytes); err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate nonce"}) return } nonce := base64.URLEncoding.EncodeToString(nonceBytes) // Build authorization URL scopes := cfg.OIDCScopes if len(scopes) == 0 { scopes = []string{"openid", "profile", "email"} } params := url.Values{ "client_id": {cfg.OIDCClientID}, "redirect_uri": {cfg.OIDCRedirectURI}, "response_type": {"code"}, "scope": {strings.Join(scopes, " ")}, "state": {state}, "nonce": {nonce}, } authURL := authEndpoint + "?" + params.Encode() // Set state cookie for CSRF protection (HttpOnly, 10 min expiry) c.SetCookie("sso_state", state, 600, "/", "", true, true) c.SetCookie("sso_nonce", nonce, 600, "/", "", true, true) c.Redirect(http.StatusFound, authURL) } // HandleOIDCCallback handles the OIDC authorization code callback from the IdP. // It validates the state, exchanges the code for tokens, extracts user info, // performs JIT user provisioning, and issues a JWT. // GET /sdk/v1/sso/oidc/callback func (h *SSOHandlers) HandleOIDCCallback(c *gin.Context) { // Check for errors from the IdP if errParam := c.Query("error"); errParam != "" { errDesc := c.Query("error_description") c.JSON(http.StatusBadRequest, gin.H{ "error": errParam, "description": errDesc, }) return } code := c.Query("code") stateParam := c.Query("state") if code == "" || stateParam == "" { c.JSON(http.StatusBadRequest, gin.H{"error": "missing code or state parameter"}) return } // Validate state cookie stateCookie, err := c.Cookie("sso_state") if err != nil || stateCookie != stateParam { c.JSON(http.StatusBadRequest, gin.H{"error": "invalid state parameter (CSRF check failed)"}) return } // Extract tenant ID from state parts := strings.SplitN(stateParam, ".", 2) if len(parts) != 2 { c.JSON(http.StatusBadRequest, gin.H{"error": "malformed state parameter"}) return } tenantID, err := uuid.Parse(parts[1]) if err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": "invalid tenant_id in state"}) return } // Look up the enabled SSO config cfg, err := h.store.GetEnabledConfig(c.Request.Context(), tenantID) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } if cfg == nil { c.JSON(http.StatusNotFound, gin.H{"error": "no enabled SSO configuration found"}) return } // Discover OIDC endpoints discoveryURL := strings.TrimSuffix(cfg.OIDCIssuerURL, "/") + "/.well-known/openid-configuration" _, tokenEndpoint, userInfoEndpoint, err := discoverOIDCEndpoints(discoveryURL) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("OIDC discovery failed: %v", err)}) return } // Exchange authorization code for tokens tokenResp, err := exchangeCodeForTokens(tokenEndpoint, code, cfg.OIDCClientID, cfg.OIDCClientSecret, cfg.OIDCRedirectURI) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("token exchange failed: %v", err)}) return } // Extract user claims from ID token or UserInfo endpoint claims, err := extractUserClaims(tokenResp, userInfoEndpoint) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("failed to extract user claims: %v", err)}) return } sub := getStringClaim(claims, "sub") email := getStringClaim(claims, "email") name := getStringClaim(claims, "name") groups := getStringSliceClaim(claims, "groups") if sub == "" { c.JSON(http.StatusBadRequest, gin.H{"error": "ID token missing 'sub' claim"}) return } if email == "" { email = sub } if name == "" { name = email } // JIT provision the user user, err := h.store.UpsertUser(c.Request.Context(), tenantID, cfg.ID, sub, email, name, groups) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("user provisioning failed: %v", err)}) return } // Determine roles from role mapping roles := resolveRoles(cfg, groups) // Generate JWT ssoClaims := sso.SSOClaims{ UserID: user.ID, TenantID: tenantID, Email: user.Email, DisplayName: user.DisplayName, Roles: roles, SSOConfigID: cfg.ID, } jwtToken, err := h.generateJWT(ssoClaims) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("JWT generation failed: %v", err)}) return } // Clear state cookies c.SetCookie("sso_state", "", -1, "/", "", true, true) c.SetCookie("sso_nonce", "", -1, "/", "", true, true) // Return JWT as JSON (the frontend can also handle redirect) c.JSON(http.StatusOK, gin.H{ "token": jwtToken, "user": user, "roles": roles, }) } // ============================================================================ // JWT Generation // ============================================================================ // generateJWT creates a signed JWT token containing the SSO claims. func (h *SSOHandlers) generateJWT(claims sso.SSOClaims) (string, error) { now := time.Now().UTC() expiry := now.Add(24 * time.Hour) token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ "user_id": claims.UserID.String(), "tenant_id": claims.TenantID.String(), "email": claims.Email, "display_name": claims.DisplayName, "roles": claims.Roles, "sso_config_id": claims.SSOConfigID.String(), "iss": "ai-compliance-sdk", "iat": now.Unix(), "exp": expiry.Unix(), }) tokenString, err := token.SignedString([]byte(h.jwtSecret)) if err != nil { return "", fmt.Errorf("failed to sign JWT: %w", err) } return tokenString, nil } // ============================================================================ // OIDC Discovery & Token Exchange (manual HTTP, no external OIDC library) // ============================================================================ // oidcDiscoveryResponse holds the relevant fields from the OIDC discovery document. type oidcDiscoveryResponse struct { AuthorizationEndpoint string `json:"authorization_endpoint"` TokenEndpoint string `json:"token_endpoint"` UserinfoEndpoint string `json:"userinfo_endpoint"` JwksURI string `json:"jwks_uri"` Issuer string `json:"issuer"` } // discoverOIDCEndpoints fetches the OIDC discovery document and returns // the authorization, token, and userinfo endpoints. func discoverOIDCEndpoints(discoveryURL string) (authEndpoint, tokenEndpoint, userInfoEndpoint string, err error) { client := &http.Client{Timeout: 10 * time.Second} resp, err := client.Get(discoveryURL) if err != nil { return "", "", "", fmt.Errorf("failed to fetch discovery document: %w", err) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { body, _ := io.ReadAll(resp.Body) return "", "", "", fmt.Errorf("discovery endpoint returned %d: %s", resp.StatusCode, string(body)) } var discovery oidcDiscoveryResponse if err := json.NewDecoder(resp.Body).Decode(&discovery); err != nil { return "", "", "", fmt.Errorf("failed to decode discovery document: %w", err) } if discovery.AuthorizationEndpoint == "" { return "", "", "", fmt.Errorf("discovery document missing authorization_endpoint") } if discovery.TokenEndpoint == "" { return "", "", "", fmt.Errorf("discovery document missing token_endpoint") } return discovery.AuthorizationEndpoint, discovery.TokenEndpoint, discovery.UserinfoEndpoint, nil } // oidcTokenResponse holds the response from the OIDC token endpoint. type oidcTokenResponse struct { AccessToken string `json:"access_token"` IDToken string `json:"id_token"` TokenType string `json:"token_type"` ExpiresIn int `json:"expires_in"` RefreshToken string `json:"refresh_token,omitempty"` } // exchangeCodeForTokens exchanges an authorization code for tokens at the token endpoint. func exchangeCodeForTokens(tokenEndpoint, code, clientID, clientSecret, redirectURI string) (*oidcTokenResponse, error) { client := &http.Client{Timeout: 10 * time.Second} data := url.Values{ "grant_type": {"authorization_code"}, "code": {code}, "client_id": {clientID}, "redirect_uri": {redirectURI}, } req, err := http.NewRequest("POST", tokenEndpoint, strings.NewReader(data.Encode())) if err != nil { return nil, fmt.Errorf("failed to create token request: %w", err) } req.Header.Set("Content-Type", "application/x-www-form-urlencoded") // Use client_secret_basic if provided if clientSecret != "" { req.SetBasicAuth(clientID, clientSecret) } resp, err := client.Do(req) if err != nil { return nil, fmt.Errorf("token request failed: %w", err) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { body, _ := io.ReadAll(resp.Body) return nil, fmt.Errorf("token endpoint returned %d: %s", resp.StatusCode, string(body)) } var tokenResp oidcTokenResponse if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil { return nil, fmt.Errorf("failed to decode token response: %w", err) } return &tokenResp, nil } // extractUserClaims extracts user claims from the ID token payload. // If the ID token is unavailable or incomplete, it falls back to the UserInfo endpoint. func extractUserClaims(tokenResp *oidcTokenResponse, userInfoEndpoint string) (map[string]interface{}, error) { claims := make(map[string]interface{}) // Try to decode ID token payload (without signature verification for claims extraction; // in production, you should verify the signature using the JWKS endpoint) if tokenResp.IDToken != "" { parts := strings.Split(tokenResp.IDToken, ".") if len(parts) == 3 { payload, err := base64.RawURLEncoding.DecodeString(parts[1]) if err == nil { if err := json.Unmarshal(payload, &claims); err == nil && claims["sub"] != nil { return claims, nil } } } } // Fallback to UserInfo endpoint if userInfoEndpoint != "" && tokenResp.AccessToken != "" { userClaims, err := fetchUserInfo(userInfoEndpoint, tokenResp.AccessToken) if err == nil && userClaims["sub"] != nil { return userClaims, nil } } if claims["sub"] != nil { return claims, nil } return nil, fmt.Errorf("could not extract user claims from ID token or UserInfo endpoint") } // fetchUserInfo calls the OIDC UserInfo endpoint with the access token. func fetchUserInfo(userInfoEndpoint, accessToken string) (map[string]interface{}, error) { client := &http.Client{Timeout: 10 * time.Second} req, err := http.NewRequest("GET", userInfoEndpoint, nil) if err != nil { return nil, err } req.Header.Set("Authorization", "Bearer "+accessToken) resp, err := client.Do(req) if err != nil { return nil, err } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { return nil, fmt.Errorf("userinfo endpoint returned %d", resp.StatusCode) } var claims map[string]interface{} if err := json.NewDecoder(resp.Body).Decode(&claims); err != nil { return nil, err } return claims, nil } // ============================================================================ // Claim Extraction Helpers // ============================================================================ // getStringClaim extracts a string claim from a claims map. func getStringClaim(claims map[string]interface{}, key string) string { if v, ok := claims[key]; ok { if s, ok := v.(string); ok { return s } } return "" } // getStringSliceClaim extracts a string slice claim from a claims map. func getStringSliceClaim(claims map[string]interface{}, key string) []string { v, ok := claims[key] if !ok { return nil } switch val := v.(type) { case []interface{}: result := make([]string, 0, len(val)) for _, item := range val { if s, ok := item.(string); ok { result = append(result, s) } } return result case []string: return val default: return nil } } // resolveRoles maps SSO groups to internal roles using the config's role mapping. // If no groups match, the default role is returned. func resolveRoles(cfg *sso.SSOConfig, groups []string) []string { if cfg.RoleMapping == nil || len(cfg.RoleMapping) == 0 { if cfg.DefaultRoleID != nil { return []string{cfg.DefaultRoleID.String()} } return []string{"compliance_user"} } roleSet := make(map[string]bool) for _, group := range groups { if role, ok := cfg.RoleMapping[group]; ok { roleSet[role] = true } } if len(roleSet) == 0 { if cfg.DefaultRoleID != nil { return []string{cfg.DefaultRoleID.String()} } return []string{"compliance_user"} } roles := make([]string, 0, len(roleSet)) for role := range roleSet { roles = append(roles, role) } return roles }