package services import ( "context" "crypto/rand" "crypto/sha256" "encoding/base64" "encoding/hex" "encoding/json" "errors" "fmt" "strings" "time" "github.com/google/uuid" "github.com/jackc/pgx/v5/pgxpool" "github.com/breakpilot/consent-service/internal/models" ) var ( ErrInvalidClient = errors.New("invalid_client") ErrInvalidGrant = errors.New("invalid_grant") ErrInvalidScope = errors.New("invalid_scope") ErrInvalidRequest = errors.New("invalid_request") ErrUnauthorizedClient = errors.New("unauthorized_client") ErrAccessDenied = errors.New("access_denied") ErrInvalidRedirectURI = errors.New("invalid redirect_uri") ErrCodeExpired = errors.New("authorization code expired") ErrCodeUsed = errors.New("authorization code already used") ErrPKCERequired = errors.New("PKCE code_challenge required for public clients") ErrPKCEVerifyFailed = errors.New("PKCE verification failed") ) // OAuthService handles OAuth 2.0 Authorization Code Flow with PKCE type OAuthService struct { db *pgxpool.Pool jwtSecret string authCodeExpiration time.Duration accessTokenExpiration time.Duration refreshTokenExpiration time.Duration } // NewOAuthService creates a new OAuthService func NewOAuthService(db *pgxpool.Pool, jwtSecret string) *OAuthService { return &OAuthService{ db: db, jwtSecret: jwtSecret, authCodeExpiration: 10 * time.Minute, // Authorization codes expire quickly accessTokenExpiration: time.Hour, // 1 hour refreshTokenExpiration: 30 * 24 * time.Hour, // 30 days } } // ======================================== // Client Validation // ======================================== // ValidateClient validates an OAuth client func (s *OAuthService) ValidateClient(ctx context.Context, clientID string) (*models.OAuthClient, error) { var client models.OAuthClient var redirectURIsJSON, scopesJSON, grantTypesJSON []byte err := s.db.QueryRow(ctx, ` SELECT id, client_id, client_secret, name, description, redirect_uris, scopes, grant_types, is_public, is_active, created_at FROM oauth_clients WHERE client_id = $1 `, clientID).Scan( &client.ID, &client.ClientID, &client.ClientSecret, &client.Name, &client.Description, &redirectURIsJSON, &scopesJSON, &grantTypesJSON, &client.IsPublic, &client.IsActive, &client.CreatedAt, ) if err != nil { return nil, ErrInvalidClient } if !client.IsActive { return nil, ErrInvalidClient } // Parse JSON arrays json.Unmarshal(redirectURIsJSON, &client.RedirectURIs) json.Unmarshal(scopesJSON, &client.Scopes) json.Unmarshal(grantTypesJSON, &client.GrantTypes) return &client, nil } // ValidateClientSecret validates client credentials for confidential clients func (s *OAuthService) ValidateClientSecret(client *models.OAuthClient, clientSecret string) error { if client.IsPublic { // Public clients don't have a secret return nil } if client.ClientSecret != clientSecret { return ErrInvalidClient } return nil } // ValidateRedirectURI validates the redirect URI against registered URIs func (s *OAuthService) ValidateRedirectURI(client *models.OAuthClient, redirectURI string) error { for _, uri := range client.RedirectURIs { if uri == redirectURI { return nil } } return ErrInvalidRedirectURI } // ValidateScopes validates requested scopes against client's allowed scopes func (s *OAuthService) ValidateScopes(client *models.OAuthClient, requestedScopes string) ([]string, error) { if requestedScopes == "" { // Return default scopes return []string{"openid", "profile", "email"}, nil } requested := strings.Split(requestedScopes, " ") allowedMap := make(map[string]bool) for _, scope := range client.Scopes { allowedMap[scope] = true } var validScopes []string for _, scope := range requested { if allowedMap[scope] { validScopes = append(validScopes, scope) } } if len(validScopes) == 0 { return nil, ErrInvalidScope } return validScopes, nil } // ======================================== // Authorization Code // ======================================== // GenerateAuthorizationCode generates a new authorization code func (s *OAuthService) GenerateAuthorizationCode( ctx context.Context, client *models.OAuthClient, userID uuid.UUID, redirectURI string, scopes []string, codeChallenge, codeChallengeMethod string, ) (string, error) { // For public clients, PKCE is required if client.IsPublic && codeChallenge == "" { return "", ErrPKCERequired } // Generate a secure random code codeBytes := make([]byte, 32) if _, err := rand.Read(codeBytes); err != nil { return "", fmt.Errorf("failed to generate code: %w", err) } code := base64.URLEncoding.EncodeToString(codeBytes) // Hash the code for storage codeHash := sha256.Sum256([]byte(code)) hashedCode := hex.EncodeToString(codeHash[:]) scopesJSON, _ := json.Marshal(scopes) var challengePtr, methodPtr *string if codeChallenge != "" { challengePtr = &codeChallenge if codeChallengeMethod == "" { codeChallengeMethod = "plain" } methodPtr = &codeChallengeMethod } _, err := s.db.Exec(ctx, ` INSERT INTO oauth_authorization_codes (code, client_id, user_id, redirect_uri, scopes, code_challenge, code_challenge_method, expires_at) VALUES ($1, $2, $3, $4, $5, $6, $7, $8) `, hashedCode, client.ClientID, userID, redirectURI, scopesJSON, challengePtr, methodPtr, time.Now().Add(s.authCodeExpiration)) if err != nil { return "", fmt.Errorf("failed to store authorization code: %w", err) } return code, nil } // ======================================== // Client Management (Admin) // ======================================== // GetClientByID retrieves an OAuth client by its client_id func (s *OAuthService) GetClientByID(ctx context.Context, clientID string) (*models.OAuthClient, error) { return s.ValidateClient(ctx, clientID) } // CreateClient creates a new OAuth client (admin only) func (s *OAuthService) CreateClient( ctx context.Context, name, description string, redirectURIs, scopes, grantTypes []string, isPublic bool, createdBy *uuid.UUID, ) (*models.OAuthClient, string, error) { // Generate client_id clientIDBytes := make([]byte, 16) rand.Read(clientIDBytes) clientID := hex.EncodeToString(clientIDBytes) // Generate client_secret for confidential clients var clientSecret string var clientSecretPtr *string if !isPublic { secretBytes := make([]byte, 32) rand.Read(secretBytes) clientSecret = base64.URLEncoding.EncodeToString(secretBytes) clientSecretPtr = &clientSecret } redirectURIsJSON, _ := json.Marshal(redirectURIs) scopesJSON, _ := json.Marshal(scopes) grantTypesJSON, _ := json.Marshal(grantTypes) var client models.OAuthClient err := s.db.QueryRow(ctx, ` INSERT INTO oauth_clients (client_id, client_secret, name, description, redirect_uris, scopes, grant_types, is_public, created_by) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) RETURNING id, client_id, name, is_public, is_active, created_at `, clientID, clientSecretPtr, name, description, redirectURIsJSON, scopesJSON, grantTypesJSON, isPublic, createdBy).Scan( &client.ID, &client.ClientID, &client.Name, &client.IsPublic, &client.IsActive, &client.CreatedAt, ) if err != nil { return nil, "", fmt.Errorf("failed to create client: %w", err) } client.RedirectURIs = redirectURIs client.Scopes = scopes client.GrantTypes = grantTypes return &client, clientSecret, nil }