package services import ( "context" "crypto/rand" "crypto/sha256" "encoding/base64" "encoding/hex" "encoding/json" "errors" "fmt" "strings" "time" "github.com/golang-jwt/jwt/v5" "github.com/google/uuid" "github.com/jackc/pgx/v5/pgxpool" "github.com/breakpilot/consent-service/internal/models" ) var ( ErrInvalidClient = errors.New("invalid_client") ErrInvalidGrant = errors.New("invalid_grant") ErrInvalidScope = errors.New("invalid_scope") ErrInvalidRequest = errors.New("invalid_request") ErrUnauthorizedClient = errors.New("unauthorized_client") ErrAccessDenied = errors.New("access_denied") ErrInvalidRedirectURI = errors.New("invalid redirect_uri") ErrCodeExpired = errors.New("authorization code expired") ErrCodeUsed = errors.New("authorization code already used") ErrPKCERequired = errors.New("PKCE code_challenge required for public clients") ErrPKCEVerifyFailed = errors.New("PKCE verification failed") ) // OAuthService handles OAuth 2.0 Authorization Code Flow with PKCE type OAuthService struct { db *pgxpool.Pool jwtSecret string authCodeExpiration time.Duration accessTokenExpiration time.Duration refreshTokenExpiration time.Duration } // NewOAuthService creates a new OAuthService func NewOAuthService(db *pgxpool.Pool, jwtSecret string) *OAuthService { return &OAuthService{ db: db, jwtSecret: jwtSecret, authCodeExpiration: 10 * time.Minute, // Authorization codes expire quickly accessTokenExpiration: time.Hour, // 1 hour refreshTokenExpiration: 30 * 24 * time.Hour, // 30 days } } // ValidateClient validates an OAuth client func (s *OAuthService) ValidateClient(ctx context.Context, clientID string) (*models.OAuthClient, error) { var client models.OAuthClient var redirectURIsJSON, scopesJSON, grantTypesJSON []byte err := s.db.QueryRow(ctx, ` SELECT id, client_id, client_secret, name, description, redirect_uris, scopes, grant_types, is_public, is_active, created_at FROM oauth_clients WHERE client_id = $1 `, clientID).Scan( &client.ID, &client.ClientID, &client.ClientSecret, &client.Name, &client.Description, &redirectURIsJSON, &scopesJSON, &grantTypesJSON, &client.IsPublic, &client.IsActive, &client.CreatedAt, ) if err != nil { return nil, ErrInvalidClient } if !client.IsActive { return nil, ErrInvalidClient } // Parse JSON arrays json.Unmarshal(redirectURIsJSON, &client.RedirectURIs) json.Unmarshal(scopesJSON, &client.Scopes) json.Unmarshal(grantTypesJSON, &client.GrantTypes) return &client, nil } // ValidateClientSecret validates client credentials for confidential clients func (s *OAuthService) ValidateClientSecret(client *models.OAuthClient, clientSecret string) error { if client.IsPublic { // Public clients don't have a secret return nil } if client.ClientSecret != clientSecret { return ErrInvalidClient } return nil } // ValidateRedirectURI validates the redirect URI against registered URIs func (s *OAuthService) ValidateRedirectURI(client *models.OAuthClient, redirectURI string) error { for _, uri := range client.RedirectURIs { if uri == redirectURI { return nil } } return ErrInvalidRedirectURI } // ValidateScopes validates requested scopes against client's allowed scopes func (s *OAuthService) ValidateScopes(client *models.OAuthClient, requestedScopes string) ([]string, error) { if requestedScopes == "" { // Return default scopes return []string{"openid", "profile", "email"}, nil } requested := strings.Split(requestedScopes, " ") allowedMap := make(map[string]bool) for _, scope := range client.Scopes { allowedMap[scope] = true } var validScopes []string for _, scope := range requested { if allowedMap[scope] { validScopes = append(validScopes, scope) } } if len(validScopes) == 0 { return nil, ErrInvalidScope } return validScopes, nil } // GenerateAuthorizationCode generates a new authorization code func (s *OAuthService) GenerateAuthorizationCode( ctx context.Context, client *models.OAuthClient, userID uuid.UUID, redirectURI string, scopes []string, codeChallenge, codeChallengeMethod string, ) (string, error) { // For public clients, PKCE is required if client.IsPublic && codeChallenge == "" { return "", ErrPKCERequired } // Generate a secure random code codeBytes := make([]byte, 32) if _, err := rand.Read(codeBytes); err != nil { return "", fmt.Errorf("failed to generate code: %w", err) } code := base64.URLEncoding.EncodeToString(codeBytes) // Hash the code for storage codeHash := sha256.Sum256([]byte(code)) hashedCode := hex.EncodeToString(codeHash[:]) scopesJSON, _ := json.Marshal(scopes) var challengePtr, methodPtr *string if codeChallenge != "" { challengePtr = &codeChallenge if codeChallengeMethod == "" { codeChallengeMethod = "plain" } methodPtr = &codeChallengeMethod } _, err := s.db.Exec(ctx, ` INSERT INTO oauth_authorization_codes (code, client_id, user_id, redirect_uri, scopes, code_challenge, code_challenge_method, expires_at) VALUES ($1, $2, $3, $4, $5, $6, $7, $8) `, hashedCode, client.ClientID, userID, redirectURI, scopesJSON, challengePtr, methodPtr, time.Now().Add(s.authCodeExpiration)) if err != nil { return "", fmt.Errorf("failed to store authorization code: %w", err) } return code, nil } // ExchangeAuthorizationCode exchanges an authorization code for tokens func (s *OAuthService) ExchangeAuthorizationCode( ctx context.Context, code string, clientID string, redirectURI string, codeVerifier string, ) (*models.OAuthTokenResponse, error) { // Hash the code to look it up codeHash := sha256.Sum256([]byte(code)) hashedCode := hex.EncodeToString(codeHash[:]) var authCode models.OAuthAuthorizationCode var scopesJSON []byte err := s.db.QueryRow(ctx, ` SELECT id, client_id, user_id, redirect_uri, scopes, code_challenge, code_challenge_method, expires_at, used_at FROM oauth_authorization_codes WHERE code = $1 `, hashedCode).Scan( &authCode.ID, &authCode.ClientID, &authCode.UserID, &authCode.RedirectURI, &scopesJSON, &authCode.CodeChallenge, &authCode.CodeChallengeMethod, &authCode.ExpiresAt, &authCode.UsedAt, ) if err != nil { return nil, ErrInvalidGrant } // Check if code was already used if authCode.UsedAt != nil { return nil, ErrCodeUsed } // Check if code is expired if time.Now().After(authCode.ExpiresAt) { return nil, ErrCodeExpired } // Verify client_id matches if authCode.ClientID != clientID { return nil, ErrInvalidGrant } // Verify redirect_uri matches if authCode.RedirectURI != redirectURI { return nil, ErrInvalidGrant } // Verify PKCE if code_challenge was provided if authCode.CodeChallenge != nil && *authCode.CodeChallenge != "" { if codeVerifier == "" { return nil, ErrPKCEVerifyFailed } var expectedChallenge string if authCode.CodeChallengeMethod != nil && *authCode.CodeChallengeMethod == "S256" { // SHA256 hash of verifier hash := sha256.Sum256([]byte(codeVerifier)) expectedChallenge = base64.RawURLEncoding.EncodeToString(hash[:]) } else { // Plain method expectedChallenge = codeVerifier } if expectedChallenge != *authCode.CodeChallenge { return nil, ErrPKCEVerifyFailed } } // Mark code as used _, err = s.db.Exec(ctx, `UPDATE oauth_authorization_codes SET used_at = NOW() WHERE id = $1`, authCode.ID) if err != nil { return nil, fmt.Errorf("failed to mark code as used: %w", err) } // Parse scopes var scopes []string json.Unmarshal(scopesJSON, &scopes) // Generate tokens return s.generateTokens(ctx, clientID, authCode.UserID, scopes) } // RefreshAccessToken refreshes an access token using a refresh token func (s *OAuthService) RefreshAccessToken(ctx context.Context, refreshToken, clientID string, requestedScope string) (*models.OAuthTokenResponse, error) { // Hash the refresh token tokenHash := sha256.Sum256([]byte(refreshToken)) hashedToken := hex.EncodeToString(tokenHash[:]) var rt models.OAuthRefreshToken var scopesJSON []byte err := s.db.QueryRow(ctx, ` SELECT id, client_id, user_id, scopes, expires_at, revoked_at FROM oauth_refresh_tokens WHERE token_hash = $1 `, hashedToken).Scan( &rt.ID, &rt.ClientID, &rt.UserID, &scopesJSON, &rt.ExpiresAt, &rt.RevokedAt, ) if err != nil { return nil, ErrInvalidGrant } // Check if token is revoked if rt.RevokedAt != nil { return nil, ErrInvalidGrant } // Check if token is expired if time.Now().After(rt.ExpiresAt) { return nil, ErrInvalidGrant } // Verify client_id matches if rt.ClientID != clientID { return nil, ErrInvalidGrant } // Parse original scopes var originalScopes []string json.Unmarshal(scopesJSON, &originalScopes) // Determine scopes for new tokens var scopes []string if requestedScope != "" { // Validate that requested scopes are subset of original scopes originalMap := make(map[string]bool) for _, s := range originalScopes { originalMap[s] = true } for _, s := range strings.Split(requestedScope, " ") { if originalMap[s] { scopes = append(scopes, s) } } if len(scopes) == 0 { return nil, ErrInvalidScope } } else { scopes = originalScopes } // Revoke old refresh token (rotate) _, _ = s.db.Exec(ctx, `UPDATE oauth_refresh_tokens SET revoked_at = NOW() WHERE id = $1`, rt.ID) // Generate new tokens return s.generateTokens(ctx, clientID, rt.UserID, scopes) } // generateTokens generates access and refresh tokens func (s *OAuthService) generateTokens(ctx context.Context, clientID string, userID uuid.UUID, scopes []string) (*models.OAuthTokenResponse, error) { // Get user info for JWT var user models.User err := s.db.QueryRow(ctx, ` SELECT id, email, name, role, account_status FROM users WHERE id = $1 `, userID).Scan(&user.ID, &user.Email, &user.Name, &user.Role, &user.AccountStatus) if err != nil { return nil, ErrInvalidGrant } // Generate access token (JWT) accessTokenClaims := jwt.MapClaims{ "sub": userID.String(), "email": user.Email, "role": user.Role, "account_status": user.AccountStatus, "client_id": clientID, "scope": strings.Join(scopes, " "), "iat": time.Now().Unix(), "exp": time.Now().Add(s.accessTokenExpiration).Unix(), "iss": "breakpilot-consent-service", "aud": clientID, } if user.Name != nil { accessTokenClaims["name"] = *user.Name } accessToken := jwt.NewWithClaims(jwt.SigningMethodHS256, accessTokenClaims) accessTokenString, err := accessToken.SignedString([]byte(s.jwtSecret)) if err != nil { return nil, fmt.Errorf("failed to sign access token: %w", err) } // Hash access token for storage accessTokenHash := sha256.Sum256([]byte(accessTokenString)) hashedAccessToken := hex.EncodeToString(accessTokenHash[:]) scopesJSON, _ := json.Marshal(scopes) // Store access token var accessTokenID uuid.UUID err = s.db.QueryRow(ctx, ` INSERT INTO oauth_access_tokens (token_hash, client_id, user_id, scopes, expires_at) VALUES ($1, $2, $3, $4, $5) RETURNING id `, hashedAccessToken, clientID, userID, scopesJSON, time.Now().Add(s.accessTokenExpiration)).Scan(&accessTokenID) if err != nil { return nil, fmt.Errorf("failed to store access token: %w", err) } // Generate refresh token (opaque) refreshTokenBytes := make([]byte, 32) if _, err := rand.Read(refreshTokenBytes); err != nil { return nil, fmt.Errorf("failed to generate refresh token: %w", err) } refreshTokenString := base64.URLEncoding.EncodeToString(refreshTokenBytes) // Hash refresh token for storage refreshTokenHash := sha256.Sum256([]byte(refreshTokenString)) hashedRefreshToken := hex.EncodeToString(refreshTokenHash[:]) // Store refresh token _, err = s.db.Exec(ctx, ` INSERT INTO oauth_refresh_tokens (token_hash, access_token_id, client_id, user_id, scopes, expires_at) VALUES ($1, $2, $3, $4, $5, $6) `, hashedRefreshToken, accessTokenID, clientID, userID, scopesJSON, time.Now().Add(s.refreshTokenExpiration)) if err != nil { return nil, fmt.Errorf("failed to store refresh token: %w", err) } return &models.OAuthTokenResponse{ AccessToken: accessTokenString, TokenType: "Bearer", ExpiresIn: int(s.accessTokenExpiration.Seconds()), RefreshToken: refreshTokenString, Scope: strings.Join(scopes, " "), }, nil } // RevokeToken revokes an access or refresh token func (s *OAuthService) RevokeToken(ctx context.Context, token, tokenTypeHint string) error { tokenHash := sha256.Sum256([]byte(token)) hashedToken := hex.EncodeToString(tokenHash[:]) // Try to revoke as access token if tokenTypeHint == "" || tokenTypeHint == "access_token" { result, err := s.db.Exec(ctx, `UPDATE oauth_access_tokens SET revoked_at = NOW() WHERE token_hash = $1`, hashedToken) if err == nil && result.RowsAffected() > 0 { return nil } } // Try to revoke as refresh token if tokenTypeHint == "" || tokenTypeHint == "refresh_token" { result, err := s.db.Exec(ctx, `UPDATE oauth_refresh_tokens SET revoked_at = NOW() WHERE token_hash = $1`, hashedToken) if err == nil && result.RowsAffected() > 0 { return nil } } return nil // RFC 7009: Always return success } // ValidateAccessToken validates an OAuth access token func (s *OAuthService) ValidateAccessToken(ctx context.Context, tokenString string) (*jwt.MapClaims, error) { token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) { if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) } return []byte(s.jwtSecret), nil }) if err != nil { return nil, ErrInvalidToken } claims, ok := token.Claims.(jwt.MapClaims) if !ok || !token.Valid { return nil, ErrInvalidToken } // Check if token is revoked in database tokenHash := sha256.Sum256([]byte(tokenString)) hashedToken := hex.EncodeToString(tokenHash[:]) var revokedAt *time.Time err = s.db.QueryRow(ctx, `SELECT revoked_at FROM oauth_access_tokens WHERE token_hash = $1`, hashedToken).Scan(&revokedAt) if err == nil && revokedAt != nil { return nil, ErrInvalidToken } return &claims, nil } // GetClientByID retrieves an OAuth client by its client_id func (s *OAuthService) GetClientByID(ctx context.Context, clientID string) (*models.OAuthClient, error) { return s.ValidateClient(ctx, clientID) } // CreateClient creates a new OAuth client (admin only) func (s *OAuthService) CreateClient( ctx context.Context, name, description string, redirectURIs, scopes, grantTypes []string, isPublic bool, createdBy *uuid.UUID, ) (*models.OAuthClient, string, error) { // Generate client_id clientIDBytes := make([]byte, 16) rand.Read(clientIDBytes) clientID := hex.EncodeToString(clientIDBytes) // Generate client_secret for confidential clients var clientSecret string var clientSecretPtr *string if !isPublic { secretBytes := make([]byte, 32) rand.Read(secretBytes) clientSecret = base64.URLEncoding.EncodeToString(secretBytes) clientSecretPtr = &clientSecret } redirectURIsJSON, _ := json.Marshal(redirectURIs) scopesJSON, _ := json.Marshal(scopes) grantTypesJSON, _ := json.Marshal(grantTypes) var client models.OAuthClient err := s.db.QueryRow(ctx, ` INSERT INTO oauth_clients (client_id, client_secret, name, description, redirect_uris, scopes, grant_types, is_public, created_by) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) RETURNING id, client_id, name, is_public, is_active, created_at `, clientID, clientSecretPtr, name, description, redirectURIsJSON, scopesJSON, grantTypesJSON, isPublic, createdBy).Scan( &client.ID, &client.ClientID, &client.Name, &client.IsPublic, &client.IsActive, &client.CreatedAt, ) if err != nil { return nil, "", fmt.Errorf("failed to create client: %w", err) } client.RedirectURIs = redirectURIs client.Scopes = scopes client.GrantTypes = grantTypes return &client, clientSecret, nil }