Initial commit: breakpilot-core - Shared Infrastructure
Docker Compose with 24+ services: - PostgreSQL (PostGIS), Valkey, MinIO, Qdrant - Vault (PKI/TLS), Nginx (Reverse Proxy) - Backend Core API, Consent Service, Billing Service - RAG Service, Embedding Service - Gitea, Woodpecker CI/CD - Night Scheduler, Health Aggregator - Jitsi (Web/XMPP/JVB/Jicofo), Mailpit Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
524
consent-service/internal/services/oauth_service.go
Normal file
524
consent-service/internal/services/oauth_service.go
Normal file
@@ -0,0 +1,524 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/google/uuid"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
|
||||
"github.com/breakpilot/consent-service/internal/models"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrInvalidClient = errors.New("invalid_client")
|
||||
ErrInvalidGrant = errors.New("invalid_grant")
|
||||
ErrInvalidScope = errors.New("invalid_scope")
|
||||
ErrInvalidRequest = errors.New("invalid_request")
|
||||
ErrUnauthorizedClient = errors.New("unauthorized_client")
|
||||
ErrAccessDenied = errors.New("access_denied")
|
||||
ErrInvalidRedirectURI = errors.New("invalid redirect_uri")
|
||||
ErrCodeExpired = errors.New("authorization code expired")
|
||||
ErrCodeUsed = errors.New("authorization code already used")
|
||||
ErrPKCERequired = errors.New("PKCE code_challenge required for public clients")
|
||||
ErrPKCEVerifyFailed = errors.New("PKCE verification failed")
|
||||
)
|
||||
|
||||
// OAuthService handles OAuth 2.0 Authorization Code Flow with PKCE
|
||||
type OAuthService struct {
|
||||
db *pgxpool.Pool
|
||||
jwtSecret string
|
||||
authCodeExpiration time.Duration
|
||||
accessTokenExpiration time.Duration
|
||||
refreshTokenExpiration time.Duration
|
||||
}
|
||||
|
||||
// NewOAuthService creates a new OAuthService
|
||||
func NewOAuthService(db *pgxpool.Pool, jwtSecret string) *OAuthService {
|
||||
return &OAuthService{
|
||||
db: db,
|
||||
jwtSecret: jwtSecret,
|
||||
authCodeExpiration: 10 * time.Minute, // Authorization codes expire quickly
|
||||
accessTokenExpiration: time.Hour, // 1 hour
|
||||
refreshTokenExpiration: 30 * 24 * time.Hour, // 30 days
|
||||
}
|
||||
}
|
||||
|
||||
// ValidateClient validates an OAuth client
|
||||
func (s *OAuthService) ValidateClient(ctx context.Context, clientID string) (*models.OAuthClient, error) {
|
||||
var client models.OAuthClient
|
||||
var redirectURIsJSON, scopesJSON, grantTypesJSON []byte
|
||||
|
||||
err := s.db.QueryRow(ctx, `
|
||||
SELECT id, client_id, client_secret, name, description, redirect_uris, scopes, grant_types, is_public, is_active, created_at
|
||||
FROM oauth_clients WHERE client_id = $1
|
||||
`, clientID).Scan(
|
||||
&client.ID, &client.ClientID, &client.ClientSecret, &client.Name, &client.Description,
|
||||
&redirectURIsJSON, &scopesJSON, &grantTypesJSON, &client.IsPublic, &client.IsActive, &client.CreatedAt,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
return nil, ErrInvalidClient
|
||||
}
|
||||
|
||||
if !client.IsActive {
|
||||
return nil, ErrInvalidClient
|
||||
}
|
||||
|
||||
// Parse JSON arrays
|
||||
json.Unmarshal(redirectURIsJSON, &client.RedirectURIs)
|
||||
json.Unmarshal(scopesJSON, &client.Scopes)
|
||||
json.Unmarshal(grantTypesJSON, &client.GrantTypes)
|
||||
|
||||
return &client, nil
|
||||
}
|
||||
|
||||
// ValidateClientSecret validates client credentials for confidential clients
|
||||
func (s *OAuthService) ValidateClientSecret(client *models.OAuthClient, clientSecret string) error {
|
||||
if client.IsPublic {
|
||||
// Public clients don't have a secret
|
||||
return nil
|
||||
}
|
||||
|
||||
if client.ClientSecret != clientSecret {
|
||||
return ErrInvalidClient
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateRedirectURI validates the redirect URI against registered URIs
|
||||
func (s *OAuthService) ValidateRedirectURI(client *models.OAuthClient, redirectURI string) error {
|
||||
for _, uri := range client.RedirectURIs {
|
||||
if uri == redirectURI {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return ErrInvalidRedirectURI
|
||||
}
|
||||
|
||||
// ValidateScopes validates requested scopes against client's allowed scopes
|
||||
func (s *OAuthService) ValidateScopes(client *models.OAuthClient, requestedScopes string) ([]string, error) {
|
||||
if requestedScopes == "" {
|
||||
// Return default scopes
|
||||
return []string{"openid", "profile", "email"}, nil
|
||||
}
|
||||
|
||||
requested := strings.Split(requestedScopes, " ")
|
||||
allowedMap := make(map[string]bool)
|
||||
for _, scope := range client.Scopes {
|
||||
allowedMap[scope] = true
|
||||
}
|
||||
|
||||
var validScopes []string
|
||||
for _, scope := range requested {
|
||||
if allowedMap[scope] {
|
||||
validScopes = append(validScopes, scope)
|
||||
}
|
||||
}
|
||||
|
||||
if len(validScopes) == 0 {
|
||||
return nil, ErrInvalidScope
|
||||
}
|
||||
|
||||
return validScopes, nil
|
||||
}
|
||||
|
||||
// GenerateAuthorizationCode generates a new authorization code
|
||||
func (s *OAuthService) GenerateAuthorizationCode(
|
||||
ctx context.Context,
|
||||
client *models.OAuthClient,
|
||||
userID uuid.UUID,
|
||||
redirectURI string,
|
||||
scopes []string,
|
||||
codeChallenge, codeChallengeMethod string,
|
||||
) (string, error) {
|
||||
// For public clients, PKCE is required
|
||||
if client.IsPublic && codeChallenge == "" {
|
||||
return "", ErrPKCERequired
|
||||
}
|
||||
|
||||
// Generate a secure random code
|
||||
codeBytes := make([]byte, 32)
|
||||
if _, err := rand.Read(codeBytes); err != nil {
|
||||
return "", fmt.Errorf("failed to generate code: %w", err)
|
||||
}
|
||||
code := base64.URLEncoding.EncodeToString(codeBytes)
|
||||
|
||||
// Hash the code for storage
|
||||
codeHash := sha256.Sum256([]byte(code))
|
||||
hashedCode := hex.EncodeToString(codeHash[:])
|
||||
|
||||
scopesJSON, _ := json.Marshal(scopes)
|
||||
|
||||
var challengePtr, methodPtr *string
|
||||
if codeChallenge != "" {
|
||||
challengePtr = &codeChallenge
|
||||
if codeChallengeMethod == "" {
|
||||
codeChallengeMethod = "plain"
|
||||
}
|
||||
methodPtr = &codeChallengeMethod
|
||||
}
|
||||
|
||||
_, err := s.db.Exec(ctx, `
|
||||
INSERT INTO oauth_authorization_codes (code, client_id, user_id, redirect_uri, scopes, code_challenge, code_challenge_method, expires_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
|
||||
`, hashedCode, client.ClientID, userID, redirectURI, scopesJSON, challengePtr, methodPtr, time.Now().Add(s.authCodeExpiration))
|
||||
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to store authorization code: %w", err)
|
||||
}
|
||||
|
||||
return code, nil
|
||||
}
|
||||
|
||||
// ExchangeAuthorizationCode exchanges an authorization code for tokens
|
||||
func (s *OAuthService) ExchangeAuthorizationCode(
|
||||
ctx context.Context,
|
||||
code string,
|
||||
clientID string,
|
||||
redirectURI string,
|
||||
codeVerifier string,
|
||||
) (*models.OAuthTokenResponse, error) {
|
||||
// Hash the code to look it up
|
||||
codeHash := sha256.Sum256([]byte(code))
|
||||
hashedCode := hex.EncodeToString(codeHash[:])
|
||||
|
||||
var authCode models.OAuthAuthorizationCode
|
||||
var scopesJSON []byte
|
||||
|
||||
err := s.db.QueryRow(ctx, `
|
||||
SELECT id, client_id, user_id, redirect_uri, scopes, code_challenge, code_challenge_method, expires_at, used_at
|
||||
FROM oauth_authorization_codes WHERE code = $1
|
||||
`, hashedCode).Scan(
|
||||
&authCode.ID, &authCode.ClientID, &authCode.UserID, &authCode.RedirectURI,
|
||||
&scopesJSON, &authCode.CodeChallenge, &authCode.CodeChallengeMethod,
|
||||
&authCode.ExpiresAt, &authCode.UsedAt,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
return nil, ErrInvalidGrant
|
||||
}
|
||||
|
||||
// Check if code was already used
|
||||
if authCode.UsedAt != nil {
|
||||
return nil, ErrCodeUsed
|
||||
}
|
||||
|
||||
// Check if code is expired
|
||||
if time.Now().After(authCode.ExpiresAt) {
|
||||
return nil, ErrCodeExpired
|
||||
}
|
||||
|
||||
// Verify client_id matches
|
||||
if authCode.ClientID != clientID {
|
||||
return nil, ErrInvalidGrant
|
||||
}
|
||||
|
||||
// Verify redirect_uri matches
|
||||
if authCode.RedirectURI != redirectURI {
|
||||
return nil, ErrInvalidGrant
|
||||
}
|
||||
|
||||
// Verify PKCE if code_challenge was provided
|
||||
if authCode.CodeChallenge != nil && *authCode.CodeChallenge != "" {
|
||||
if codeVerifier == "" {
|
||||
return nil, ErrPKCEVerifyFailed
|
||||
}
|
||||
|
||||
var expectedChallenge string
|
||||
if authCode.CodeChallengeMethod != nil && *authCode.CodeChallengeMethod == "S256" {
|
||||
// SHA256 hash of verifier
|
||||
hash := sha256.Sum256([]byte(codeVerifier))
|
||||
expectedChallenge = base64.RawURLEncoding.EncodeToString(hash[:])
|
||||
} else {
|
||||
// Plain method
|
||||
expectedChallenge = codeVerifier
|
||||
}
|
||||
|
||||
if expectedChallenge != *authCode.CodeChallenge {
|
||||
return nil, ErrPKCEVerifyFailed
|
||||
}
|
||||
}
|
||||
|
||||
// Mark code as used
|
||||
_, err = s.db.Exec(ctx, `UPDATE oauth_authorization_codes SET used_at = NOW() WHERE id = $1`, authCode.ID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to mark code as used: %w", err)
|
||||
}
|
||||
|
||||
// Parse scopes
|
||||
var scopes []string
|
||||
json.Unmarshal(scopesJSON, &scopes)
|
||||
|
||||
// Generate tokens
|
||||
return s.generateTokens(ctx, clientID, authCode.UserID, scopes)
|
||||
}
|
||||
|
||||
// RefreshAccessToken refreshes an access token using a refresh token
|
||||
func (s *OAuthService) RefreshAccessToken(ctx context.Context, refreshToken, clientID string, requestedScope string) (*models.OAuthTokenResponse, error) {
|
||||
// Hash the refresh token
|
||||
tokenHash := sha256.Sum256([]byte(refreshToken))
|
||||
hashedToken := hex.EncodeToString(tokenHash[:])
|
||||
|
||||
var rt models.OAuthRefreshToken
|
||||
var scopesJSON []byte
|
||||
|
||||
err := s.db.QueryRow(ctx, `
|
||||
SELECT id, client_id, user_id, scopes, expires_at, revoked_at
|
||||
FROM oauth_refresh_tokens WHERE token_hash = $1
|
||||
`, hashedToken).Scan(
|
||||
&rt.ID, &rt.ClientID, &rt.UserID, &scopesJSON, &rt.ExpiresAt, &rt.RevokedAt,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
return nil, ErrInvalidGrant
|
||||
}
|
||||
|
||||
// Check if token is revoked
|
||||
if rt.RevokedAt != nil {
|
||||
return nil, ErrInvalidGrant
|
||||
}
|
||||
|
||||
// Check if token is expired
|
||||
if time.Now().After(rt.ExpiresAt) {
|
||||
return nil, ErrInvalidGrant
|
||||
}
|
||||
|
||||
// Verify client_id matches
|
||||
if rt.ClientID != clientID {
|
||||
return nil, ErrInvalidGrant
|
||||
}
|
||||
|
||||
// Parse original scopes
|
||||
var originalScopes []string
|
||||
json.Unmarshal(scopesJSON, &originalScopes)
|
||||
|
||||
// Determine scopes for new tokens
|
||||
var scopes []string
|
||||
if requestedScope != "" {
|
||||
// Validate that requested scopes are subset of original scopes
|
||||
originalMap := make(map[string]bool)
|
||||
for _, s := range originalScopes {
|
||||
originalMap[s] = true
|
||||
}
|
||||
|
||||
for _, s := range strings.Split(requestedScope, " ") {
|
||||
if originalMap[s] {
|
||||
scopes = append(scopes, s)
|
||||
}
|
||||
}
|
||||
|
||||
if len(scopes) == 0 {
|
||||
return nil, ErrInvalidScope
|
||||
}
|
||||
} else {
|
||||
scopes = originalScopes
|
||||
}
|
||||
|
||||
// Revoke old refresh token (rotate)
|
||||
_, _ = s.db.Exec(ctx, `UPDATE oauth_refresh_tokens SET revoked_at = NOW() WHERE id = $1`, rt.ID)
|
||||
|
||||
// Generate new tokens
|
||||
return s.generateTokens(ctx, clientID, rt.UserID, scopes)
|
||||
}
|
||||
|
||||
// generateTokens generates access and refresh tokens
|
||||
func (s *OAuthService) generateTokens(ctx context.Context, clientID string, userID uuid.UUID, scopes []string) (*models.OAuthTokenResponse, error) {
|
||||
// Get user info for JWT
|
||||
var user models.User
|
||||
err := s.db.QueryRow(ctx, `
|
||||
SELECT id, email, name, role, account_status FROM users WHERE id = $1
|
||||
`, userID).Scan(&user.ID, &user.Email, &user.Name, &user.Role, &user.AccountStatus)
|
||||
|
||||
if err != nil {
|
||||
return nil, ErrInvalidGrant
|
||||
}
|
||||
|
||||
// Generate access token (JWT)
|
||||
accessTokenClaims := jwt.MapClaims{
|
||||
"sub": userID.String(),
|
||||
"email": user.Email,
|
||||
"role": user.Role,
|
||||
"account_status": user.AccountStatus,
|
||||
"client_id": clientID,
|
||||
"scope": strings.Join(scopes, " "),
|
||||
"iat": time.Now().Unix(),
|
||||
"exp": time.Now().Add(s.accessTokenExpiration).Unix(),
|
||||
"iss": "breakpilot-consent-service",
|
||||
"aud": clientID,
|
||||
}
|
||||
|
||||
if user.Name != nil {
|
||||
accessTokenClaims["name"] = *user.Name
|
||||
}
|
||||
|
||||
accessToken := jwt.NewWithClaims(jwt.SigningMethodHS256, accessTokenClaims)
|
||||
accessTokenString, err := accessToken.SignedString([]byte(s.jwtSecret))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to sign access token: %w", err)
|
||||
}
|
||||
|
||||
// Hash access token for storage
|
||||
accessTokenHash := sha256.Sum256([]byte(accessTokenString))
|
||||
hashedAccessToken := hex.EncodeToString(accessTokenHash[:])
|
||||
|
||||
scopesJSON, _ := json.Marshal(scopes)
|
||||
|
||||
// Store access token
|
||||
var accessTokenID uuid.UUID
|
||||
err = s.db.QueryRow(ctx, `
|
||||
INSERT INTO oauth_access_tokens (token_hash, client_id, user_id, scopes, expires_at)
|
||||
VALUES ($1, $2, $3, $4, $5)
|
||||
RETURNING id
|
||||
`, hashedAccessToken, clientID, userID, scopesJSON, time.Now().Add(s.accessTokenExpiration)).Scan(&accessTokenID)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to store access token: %w", err)
|
||||
}
|
||||
|
||||
// Generate refresh token (opaque)
|
||||
refreshTokenBytes := make([]byte, 32)
|
||||
if _, err := rand.Read(refreshTokenBytes); err != nil {
|
||||
return nil, fmt.Errorf("failed to generate refresh token: %w", err)
|
||||
}
|
||||
refreshTokenString := base64.URLEncoding.EncodeToString(refreshTokenBytes)
|
||||
|
||||
// Hash refresh token for storage
|
||||
refreshTokenHash := sha256.Sum256([]byte(refreshTokenString))
|
||||
hashedRefreshToken := hex.EncodeToString(refreshTokenHash[:])
|
||||
|
||||
// Store refresh token
|
||||
_, err = s.db.Exec(ctx, `
|
||||
INSERT INTO oauth_refresh_tokens (token_hash, access_token_id, client_id, user_id, scopes, expires_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6)
|
||||
`, hashedRefreshToken, accessTokenID, clientID, userID, scopesJSON, time.Now().Add(s.refreshTokenExpiration))
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to store refresh token: %w", err)
|
||||
}
|
||||
|
||||
return &models.OAuthTokenResponse{
|
||||
AccessToken: accessTokenString,
|
||||
TokenType: "Bearer",
|
||||
ExpiresIn: int(s.accessTokenExpiration.Seconds()),
|
||||
RefreshToken: refreshTokenString,
|
||||
Scope: strings.Join(scopes, " "),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// RevokeToken revokes an access or refresh token
|
||||
func (s *OAuthService) RevokeToken(ctx context.Context, token, tokenTypeHint string) error {
|
||||
tokenHash := sha256.Sum256([]byte(token))
|
||||
hashedToken := hex.EncodeToString(tokenHash[:])
|
||||
|
||||
// Try to revoke as access token
|
||||
if tokenTypeHint == "" || tokenTypeHint == "access_token" {
|
||||
result, err := s.db.Exec(ctx, `UPDATE oauth_access_tokens SET revoked_at = NOW() WHERE token_hash = $1`, hashedToken)
|
||||
if err == nil && result.RowsAffected() > 0 {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Try to revoke as refresh token
|
||||
if tokenTypeHint == "" || tokenTypeHint == "refresh_token" {
|
||||
result, err := s.db.Exec(ctx, `UPDATE oauth_refresh_tokens SET revoked_at = NOW() WHERE token_hash = $1`, hashedToken)
|
||||
if err == nil && result.RowsAffected() > 0 {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil // RFC 7009: Always return success
|
||||
}
|
||||
|
||||
// ValidateAccessToken validates an OAuth access token
|
||||
func (s *OAuthService) ValidateAccessToken(ctx context.Context, tokenString string) (*jwt.MapClaims, error) {
|
||||
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
|
||||
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
||||
}
|
||||
return []byte(s.jwtSecret), nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, ErrInvalidToken
|
||||
}
|
||||
|
||||
claims, ok := token.Claims.(jwt.MapClaims)
|
||||
if !ok || !token.Valid {
|
||||
return nil, ErrInvalidToken
|
||||
}
|
||||
|
||||
// Check if token is revoked in database
|
||||
tokenHash := sha256.Sum256([]byte(tokenString))
|
||||
hashedToken := hex.EncodeToString(tokenHash[:])
|
||||
|
||||
var revokedAt *time.Time
|
||||
err = s.db.QueryRow(ctx, `SELECT revoked_at FROM oauth_access_tokens WHERE token_hash = $1`, hashedToken).Scan(&revokedAt)
|
||||
if err == nil && revokedAt != nil {
|
||||
return nil, ErrInvalidToken
|
||||
}
|
||||
|
||||
return &claims, nil
|
||||
}
|
||||
|
||||
// GetClientByID retrieves an OAuth client by its client_id
|
||||
func (s *OAuthService) GetClientByID(ctx context.Context, clientID string) (*models.OAuthClient, error) {
|
||||
return s.ValidateClient(ctx, clientID)
|
||||
}
|
||||
|
||||
// CreateClient creates a new OAuth client (admin only)
|
||||
func (s *OAuthService) CreateClient(
|
||||
ctx context.Context,
|
||||
name, description string,
|
||||
redirectURIs, scopes, grantTypes []string,
|
||||
isPublic bool,
|
||||
createdBy *uuid.UUID,
|
||||
) (*models.OAuthClient, string, error) {
|
||||
// Generate client_id
|
||||
clientIDBytes := make([]byte, 16)
|
||||
rand.Read(clientIDBytes)
|
||||
clientID := hex.EncodeToString(clientIDBytes)
|
||||
|
||||
// Generate client_secret for confidential clients
|
||||
var clientSecret string
|
||||
var clientSecretPtr *string
|
||||
if !isPublic {
|
||||
secretBytes := make([]byte, 32)
|
||||
rand.Read(secretBytes)
|
||||
clientSecret = base64.URLEncoding.EncodeToString(secretBytes)
|
||||
clientSecretPtr = &clientSecret
|
||||
}
|
||||
|
||||
redirectURIsJSON, _ := json.Marshal(redirectURIs)
|
||||
scopesJSON, _ := json.Marshal(scopes)
|
||||
grantTypesJSON, _ := json.Marshal(grantTypes)
|
||||
|
||||
var client models.OAuthClient
|
||||
err := s.db.QueryRow(ctx, `
|
||||
INSERT INTO oauth_clients (client_id, client_secret, name, description, redirect_uris, scopes, grant_types, is_public, created_by)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
|
||||
RETURNING id, client_id, name, is_public, is_active, created_at
|
||||
`, clientID, clientSecretPtr, name, description, redirectURIsJSON, scopesJSON, grantTypesJSON, isPublic, createdBy).Scan(
|
||||
&client.ID, &client.ClientID, &client.Name, &client.IsPublic, &client.IsActive, &client.CreatedAt,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("failed to create client: %w", err)
|
||||
}
|
||||
|
||||
client.RedirectURIs = redirectURIs
|
||||
client.Scopes = scopes
|
||||
client.GrantTypes = grantTypes
|
||||
|
||||
return &client, clientSecret, nil
|
||||
}
|
||||
Reference in New Issue
Block a user