package services import ( "crypto/rand" "crypto/sha256" "encoding/base64" "encoding/hex" "strings" "testing" "time" ) // TestPKCEVerification tests PKCE code_challenge and code_verifier validation func TestPKCEVerification_S256_ValidVerifier(t *testing.T) { // Generate a code_verifier codeVerifier := "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk" // Calculate expected code_challenge (S256) hash := sha256.Sum256([]byte(codeVerifier)) codeChallenge := base64.RawURLEncoding.EncodeToString(hash[:]) // Verify the challenge matches verifierHash := sha256.Sum256([]byte(codeVerifier)) calculatedChallenge := base64.RawURLEncoding.EncodeToString(verifierHash[:]) if calculatedChallenge != codeChallenge { t.Errorf("PKCE verification failed: expected %s, got %s", codeChallenge, calculatedChallenge) } } func TestPKCEVerification_S256_InvalidVerifier(t *testing.T) { codeVerifier := "correct-verifier-12345678901234567890" wrongVerifier := "wrong-verifier-00000000000000000000" // Calculate code_challenge from correct verifier hash := sha256.Sum256([]byte(codeVerifier)) codeChallenge := base64.RawURLEncoding.EncodeToString(hash[:]) // Calculate challenge from wrong verifier wrongHash := sha256.Sum256([]byte(wrongVerifier)) wrongChallenge := base64.RawURLEncoding.EncodeToString(wrongHash[:]) if wrongChallenge == codeChallenge { t.Error("PKCE verification should fail for wrong verifier") } } func TestPKCEVerification_Plain_ValidVerifier(t *testing.T) { codeVerifier := "plain-text-verifier-12345" codeChallenge := codeVerifier // Plain method: challenge = verifier if codeVerifier != codeChallenge { t.Error("Plain PKCE verification failed") } } // TestTokenHashing tests that token hashing is consistent func TestTokenHashing_Consistency(t *testing.T) { token := "sample-access-token-12345" hash1 := sha256.Sum256([]byte(token)) hash2 := sha256.Sum256([]byte(token)) if hash1 != hash2 { t.Error("Token hashing should be consistent") } } func TestTokenHashing_DifferentTokens(t *testing.T) { 
token1 := "token-1-abcdefgh" token2 := "token-2-ijklmnop" hash1 := sha256.Sum256([]byte(token1)) hash2 := sha256.Sum256([]byte(token2)) if hash1 == hash2 { t.Error("Different tokens should produce different hashes") } } // TestScopeValidation tests scope parsing and validation func TestScopeValidation_ParseScopes(t *testing.T) { tests := []struct { name string requestedScope string allowedScopes []string expectedCount int }{ { name: "all scopes allowed", requestedScope: "openid profile email", allowedScopes: []string{"openid", "profile", "email", "offline_access"}, expectedCount: 3, }, { name: "some scopes allowed", requestedScope: "openid profile admin", allowedScopes: []string{"openid", "profile", "email"}, expectedCount: 2, // admin not allowed }, { name: "no scopes allowed", requestedScope: "admin superuser", allowedScopes: []string{"openid", "profile", "email"}, expectedCount: 0, }, { name: "empty request defaults", requestedScope: "", allowedScopes: []string{"openid", "profile", "email"}, expectedCount: 0, // Empty request returns 0 from this test logic }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if tt.requestedScope == "" { // Empty scope should use defaults in actual service return } allowedMap := make(map[string]bool) for _, scope := range tt.allowedScopes { allowedMap[scope] = true } var validScopes []string requestedScopes := splitScopes(tt.requestedScope) for _, scope := range requestedScopes { if allowedMap[scope] { validScopes = append(validScopes, scope) } } if len(validScopes) != tt.expectedCount { t.Errorf("Expected %d valid scopes, got %d", tt.expectedCount, len(validScopes)) } }) } } // Helper function for scope splitting func splitScopes(scopes string) []string { if scopes == "" { return nil } var result []string start := 0 for i := 0; i <= len(scopes); i++ { if i == len(scopes) || scopes[i] == ' ' { if start < i { result = append(result, scopes[start:i]) } start = i + 1 } } return result } // TestRedirectURIValidation 
// TestRedirectURIValidation tests redirect URI validation. A requested
// redirect URI is accepted only if it exactly equals one of the registered
// URIs (simple string comparison, RFC 6749 Section 3.1.2.3). Prefix and
// wildcard matches are not allowed.
func TestRedirectURIValidation(t *testing.T) {
	tests := []struct {
		name           string
		registeredURIs []string
		requestURI     string
		shouldMatch    bool
	}{
		{
			name:           "exact match",
			registeredURIs: []string{"https://example.com/callback"},
			requestURI:     "https://example.com/callback",
			shouldMatch:    true,
		},
		{
			name:           "no match different domain",
			registeredURIs: []string{"https://example.com/callback"},
			requestURI:     "https://evil.com/callback",
			shouldMatch:    false,
		},
		{
			name:           "no match different path",
			registeredURIs: []string{"https://example.com/callback"},
			requestURI:     "https://example.com/other",
			shouldMatch:    false,
		},
		{
			name:           "multiple URIs - second matches",
			registeredURIs: []string{"https://example.com/callback", "https://example.com/auth"},
			requestURI:     "https://example.com/auth",
			shouldMatch:    true,
		},
		{
			name:           "localhost for development",
			registeredURIs: []string{"http://localhost:3000/callback"},
			requestURI:     "http://localhost:3000/callback",
			shouldMatch:    true,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			matched := false
			for _, uri := range tt.registeredURIs {
				if uri == tt.requestURI {
					matched = true
					break
				}
			}

			if matched != tt.shouldMatch {
				t.Errorf("Expected match=%v, got match=%v", tt.shouldMatch, matched)
			}
		})
	}
}

// TestGrantTypeValidation tests that only grant types in the client's allowed
// list are accepted.
func TestGrantTypeValidation(t *testing.T) {
	tests := []struct {
		name           string
		allowedGrants  []string
		requestedGrant string
		shouldAllow    bool
	}{
		{
			name:           "authorization_code allowed",
			allowedGrants:  []string{"authorization_code", "refresh_token"},
			requestedGrant: "authorization_code",
			shouldAllow:    true,
		},
		{
			name:           "refresh_token allowed",
			allowedGrants:  []string{"authorization_code", "refresh_token"},
			requestedGrant: "refresh_token",
			shouldAllow:    true,
		},
		{
			name:           "password not allowed",
			allowedGrants:  []string{"authorization_code", "refresh_token"},
			requestedGrant: "password",
			shouldAllow:    false,
		},
		{
			name:           "client_credentials not allowed",
			allowedGrants:  []string{"authorization_code", "refresh_token"},
			requestedGrant: "client_credentials",
			shouldAllow:    false,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			allowed := false
			for _, grant := range tt.allowedGrants {
				if grant == tt.requestedGrant {
					allowed = true
					break
				}
			}

			if allowed != tt.shouldAllow {
				t.Errorf("Expected allow=%v, got allow=%v", tt.shouldAllow, allowed)
			}
		})
	}
}

// TestAuthorizationCodeExpiry_Logic tests that expired codes are rejected. A
// code is valid only strictly before its expiry, so a code used exactly at the
// expiry boundary is rejected.
func TestAuthorizationCodeExpiry_Logic(t *testing.T) {
	tests := []struct {
		name        string
		expiryMins  int
		usedAfter   int // minutes after creation
		shouldAllow bool
	}{
		{
			name:        "code used within expiry",
			expiryMins:  10,
			usedAfter:   5,
			shouldAllow: true,
		},
		{
			name:        "code used at expiry boundary",
			expiryMins:  10,
			usedAfter:   10,
			shouldAllow: false, // Expired at exactly 10 mins
		},
		{
			name:        "code used after expiry",
			expiryMins:  10,
			usedAfter:   15,
			shouldAllow: false,
		},
		{
			name:        "code used immediately",
			expiryMins:  10,
			usedAfter:   0,
			shouldAllow: true,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			isValid := tt.usedAfter < tt.expiryMins

			if isValid != tt.shouldAllow {
				t.Errorf("Expected allow=%v for code used after %d mins (expiry: %d mins)",
					tt.shouldAllow, tt.usedAfter, tt.expiryMins)
			}
		})
	}
}

// TestClientSecretValidation tests confidential client authentication. Public
// clients need no secret; confidential clients must present exactly the stored
// secret.
//
// NOTE(review): this simulation uses ==. The production comparison should be
// constant time (crypto/subtle.ConstantTimeCompare, or comparing hashes) to
// avoid leaking the secret through timing.
func TestClientSecretValidation(t *testing.T) {
	tests := []struct {
		name           string
		isPublic       bool
		storedSecret   string
		providedSecret string
		shouldAllow    bool
	}{
		{
			name:           "public client - no secret needed",
			isPublic:       true,
			storedSecret:   "",
			providedSecret: "",
			shouldAllow:    true,
		},
		{
			name:           "confidential client - correct secret",
			isPublic:       false,
			storedSecret:   "super-secret-123",
			providedSecret: "super-secret-123",
			shouldAllow:    true,
		},
		{
			name:           "confidential client - wrong secret",
			isPublic:       false,
			storedSecret:   "super-secret-123",
			providedSecret: "wrong-secret",
			shouldAllow:    false,
		},
		{
			name:           "confidential client - empty secret",
			isPublic:       false,
			storedSecret:   "super-secret-123",
			providedSecret: "",
			shouldAllow:    false,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			var isValid bool
			if tt.isPublic {
				isValid = true
			} else {
				isValid = tt.storedSecret == tt.providedSecret
			}

			if isValid != tt.shouldAllow {
				t.Errorf("Expected allow=%v, got allow=%v", tt.shouldAllow, isValid)
			}
		})
	}
}

// ========================================
// Extended OAuth 2.0 Tests
// ========================================

// TestCodeVerifierGeneration_RFC7636 tests that generated code verifiers meet
// the RFC 7636 requirements: length clamped to [43, 128] and only unreserved
// characters.
func TestCodeVerifierGeneration_RFC7636(t *testing.T) {
	tests := []struct {
		name           string
		length         int
		expectedLength int
		description    string
	}{
		{"minimum length (43)", 43, 43, "RFC 7636 minimum"},
		{"standard length (64)", 64, 64, "Recommended length"},
		{"maximum length (128)", 128, 128, "RFC 7636 maximum"},
		{"too short (42) - corrected to minimum", 42, 43, "Should be corrected to minimum"},
		{"too long (129) - corrected to maximum", 129, 128, "Should be corrected to maximum"},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			verifier := generateCodeVerifier(tt.length)

			// Check that length is corrected to valid range
			if len(verifier) != tt.expectedLength {
				t.Errorf("Expected length %d, got %d", tt.expectedLength, len(verifier))
			}

			// Check character set (unreserved characters only: A-Z, a-z, 0-9, -, ., _, ~)
			for _, c := range verifier {
				if !isUnreservedChar(c) {
					t.Errorf("Code verifier contains invalid character: %c", c)
				}
			}
		})
	}
}

// TestCodeVerifierLength_Validation tests the RFC 7636 length rule: a verifier
// is valid only if 43 <= length <= 128.
func TestCodeVerifierLength_Validation(t *testing.T) {
	tests := []struct {
		name    string
		length  int
		isValid bool
	}{
		{"length 42 - too short", 42, false},
		{"length 43 - minimum valid", 43, true},
		{"length 64 - recommended", 64, true},
		{"length 128 - maximum valid", 128, true},
		{"length 129 - too long", 129, false},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			isValid := tt.length >= 43 && tt.length <= 128

			if isValid != tt.isValid {
				t.Errorf("Expected valid=%v for length %d, got valid=%v", tt.isValid, tt.length, isValid)
			}
		})
	}
}

// generateCodeVerifier returns a random code verifier drawn from the RFC 3986
// unreserved character set. length is clamped to the RFC 7636 range [43, 128].
//
// Characters are chosen by rejection sampling. Random bytes at or above the
// largest multiple of the alphabet size (198 for 66 characters) are discarded,
// so every character is equally likely. Reducing every byte modulo 66 instead
// would make the first 58 characters about a third more likely than the rest.
//
// It panics if crypto/rand fails, because no safe verifier can be produced.
func generateCodeVerifier(length int) string {
	// Ensure minimum and maximum bounds
	if length < 43 {
		length = 43
	}
	if length > 128 {
		length = 128
	}

	const unreserved = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-._~"
	const limit = 256 - 256%len(unreserved) // bytes >= limit would bias the result

	result := make([]byte, 0, length)
	buf := make([]byte, length)
	for len(result) < length {
		if _, err := rand.Read(buf); err != nil {
			panic("generateCodeVerifier: crypto/rand failed: " + err.Error())
		}
		for _, b := range buf {
			if int(b) >= limit {
				continue
			}
			result = append(result, unreserved[int(b)%len(unreserved)])
			if len(result) == length {
				break
			}
		}
	}
	return string(result)
}

// isUnreservedChar reports whether c is an unreserved character per RFC 3986
// (ALPHA / DIGIT / "-" / "." / "_" / "~").
func isUnreservedChar(c rune) bool {
	return (c >= 'A' && c <= 'Z') ||
		(c >= 'a' && c <= 'z') ||
		(c >= '0' && c <= '9') ||
		c == '-' || c == '.' || c == '_' || c == '~'
}

// TestCodeChallengeGeneration_S256 tests S256 challenge generation against the
// known test vector from RFC 7636 Appendix B.
func TestCodeChallengeGeneration_S256(t *testing.T) {
	verifier := "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"
	expectedChallenge := "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM"

	hash := sha256.Sum256([]byte(verifier))
	challenge := base64.RawURLEncoding.EncodeToString(hash[:])

	if challenge != expectedChallenge {
		t.Errorf("S256 challenge mismatch: expected %s, got %s", expectedChallenge, challenge)
	}
}

// TestRefreshTokenRotation_Logic tests the token-generation side of refresh
// token rotation. A rotated token must differ from the one it replaces, and
// successive tokens must be unique. Revoking the old token (revoked_at) happens
// in the service layer and is not exercised here.
func TestRefreshTokenRotation_Logic(t *testing.T) {
	oldToken := "old-refresh-token-123"
	oldTokenHash := hashToken(oldToken)

	// Generate new token
	newToken := generateSecureToken(32)
	newTokenHash := hashToken(newToken)

	// Verify tokens are different
	if oldTokenHash == newTokenHash {
		t.Error("New refresh token should be different from old token")
	}

	// A second rotation must not reproduce the previous token.
	if nextToken := generateSecureToken(32); nextToken == newToken {
		t.Error("Successive refresh tokens should be unique")
	}
}

// hashToken returns the lowercase hex SHA-256 digest of token (the form used
// for storing tokens).
func hashToken(token string) string {
	hash := sha256.Sum256([]byte(token))
	return hex.EncodeToString(hash[:])
}

// generateSecureToken returns length random bytes from crypto/rand, encoded as
// unpadded base64url. Omitting the '=' padding keeps the token URL- and
// form-safe without percent-encoding. It panics if crypto/rand fails.
func generateSecureToken(length int) string {
	buf := make([]byte, length)
	if _, err := rand.Read(buf); err != nil {
		panic("generateSecureToken: crypto/rand failed: " + err.Error())
	}
	return base64.RawURLEncoding.EncodeToString(buf)
}

// TestAccessTokenExpiry_Scenarios tests access token expiration handling. A
// token is valid only strictly before issuedAt+duration.
func TestAccessTokenExpiry_Scenarios(t *testing.T) {
	tests := []struct {
		name          string
		tokenDuration time.Duration
		usedAfter     time.Duration
		shouldBeValid bool
	}{
		{
			name:          "token used immediately",
			tokenDuration: time.Hour,
			usedAfter:     0,
			shouldBeValid: true,
		},
		{
			name:          "token used within validity",
			tokenDuration: time.Hour,
			usedAfter:     30 * time.Minute,
			shouldBeValid: true,
		},
		{
			name:          "token used at expiry",
			tokenDuration: time.Hour,
			usedAfter:     time.Hour,
			shouldBeValid: false,
		},
		{
			name:          "token used after expiry",
			tokenDuration: time.Hour,
			usedAfter:     2 * time.Hour,
			shouldBeValid: false,
		},
		{
			name:          "short-lived token",
			tokenDuration: 5 * time.Minute,
			usedAfter:     6 * time.Minute,
			shouldBeValid: false,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			issuedAt := time.Now()
			expiresAt := issuedAt.Add(tt.tokenDuration)
			usedAt := issuedAt.Add(tt.usedAfter)

			isValid := usedAt.Before(expiresAt)

			if isValid != tt.shouldBeValid {
				t.Errorf("Expected valid=%v for token used after %v (duration: %v)",
					tt.shouldBeValid, tt.usedAfter, tt.tokenDuration)
			}
		})
	}
}

// TestOAuthErrors_RFC6749 tests that the error codes used for each scenario
// are error codes defined by RFC 6749.
func TestOAuthErrors_RFC6749(t *testing.T) {
	tests := []struct {
		scenario    string
		errorCode   string
		description string
	}{
		{"invalid client_id", "invalid_client", "Client authentication failed"},
		{"invalid grant (code)", "invalid_grant", "Authorization code invalid or expired"},
		{"invalid scope", "invalid_scope", "Requested scope is invalid"},
		{"invalid request", "invalid_request", "Request is missing required parameter"},
		{"unauthorized client", "unauthorized_client", "Client not authorized for this grant type"},
		{"access denied", "access_denied", "Resource owner denied the request"},
	}

	// Union of the authorization endpoint error codes (RFC 6749 Section
	// 4.1.2.1) and the token endpoint error codes (Section 5.2). It is built
	// once, outside the loop.
	validErrors := map[string]bool{
		"invalid_request":           true,
		"invalid_client":            true,
		"invalid_grant":             true,
		"unauthorized_client":       true,
		"unsupported_grant_type":    true,
		"invalid_scope":             true,
		"access_denied":             true,
		"unsupported_response_type": true,
		"server_error":              true,
		"temporarily_unavailable":   true,
	}

	for _, tt := range tests {
		t.Run(tt.scenario, func(t *testing.T) {
			if !validErrors[tt.errorCode] {
				t.Errorf("Error code %s is not a valid OAuth 2.0 error code", tt.errorCode)
			}
		})
	}
}

// TestStateParameter_CSRF tests state parameter handling for CSRF protection.
// The state must be non-empty and echoed back unchanged.
func TestStateParameter_CSRF(t *testing.T) {
	tests := []struct {
		name          string
		requestState  string
		responseState string
		shouldMatch   bool
	}{
		{
			name:          "matching state",
			requestState:  "abc123xyz",
			responseState: "abc123xyz",
			shouldMatch:   true,
		},
		{
			name:          "non-matching state",
			requestState:  "abc123xyz",
			responseState: "different",
			shouldMatch:   false,
		},
		{
			name:          "empty request state",
			requestState:  "",
			responseState: "abc123xyz",
			shouldMatch:   false,
		},
		{
			name:          "empty response state",
			requestState:  "abc123xyz",
			responseState: "",
			shouldMatch:   false,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			matches := tt.requestState != "" && tt.requestState == tt.responseState

			if matches != tt.shouldMatch {
				t.Errorf("Expected match=%v, got match=%v", tt.shouldMatch, matches)
			}
		})
	}
}

// TestResponseType_Validation tests response_type validation. Only the
// authorization code flow ("code") is supported.
func TestResponseType_Validation(t *testing.T) {
	tests := []struct {
		name         string
		responseType string
		isValid      bool
	}{
		{"code - valid", "code", true},
		{"token - implicit flow (disabled)", "token", false},
		{"id_token - OIDC", "id_token", false},
		{"code token - hybrid", "code token", false},
		{"empty", "", false},
		{"invalid", "password", false},
	}

	supportedResponseTypes := map[string]bool{
		"code": true, // Only authorization code flow is supported
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			isValid := supportedResponseTypes[tt.responseType]

			if isValid != tt.isValid {
				t.Errorf("Expected valid=%v for response_type=%s, got valid=%v", tt.isValid, tt.responseType, isValid)
			}
		})
	}
}

// TestCodeChallengeMethod_Validation tests code_challenge_method validation.
// "S256" and "plain" are accepted, and an empty method is treated as "plain"
// (RFC 7636 Section 4.3).
func TestCodeChallengeMethod_Validation(t *testing.T) {
	tests := []struct {
		name    string
		method  string
		isValid bool
	}{
		{"S256 - recommended", "S256", true},
		{"plain - discouraged but valid", "plain", true},
		{"empty - defaults to plain", "", true},
		{"sha512 - not supported", "sha512", false},
		{"invalid", "md5", false},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			isValid := tt.method == "S256" || tt.method == "plain" || tt.method == ""

			if isValid != tt.isValid {
				t.Errorf("Expected valid=%v for method=%s, got valid=%v", tt.isValid, tt.method, isValid)
			}
		})
	}
}

// TestTokenRevocation_RFC7009 documents the RFC 7009 contract: the revocation
// endpoint responds with 200 OK whether the token is active, already revoked,
// or unknown.
//
// NOTE(review): this only restates the contract. tokenExists and tokenRevoked
// are not used because no revocation logic is called here.
func TestTokenRevocation_RFC7009(t *testing.T) {
	tests := []struct {
		name          string
		tokenExists   bool
		tokenRevoked  bool
		expectSuccess bool
	}{
		{
			name:          "revoke existing active token",
			tokenExists:   true,
			tokenRevoked:  false,
			expectSuccess: true,
		},
		{
			name:          "revoke already revoked token",
			tokenExists:   true,
			tokenRevoked:  true,
			expectSuccess: true, // RFC 7009: Always return 200
		},
		{
			name:          "revoke non-existent token",
			tokenExists:   false,
			tokenRevoked:  false,
			expectSuccess: true, // RFC 7009: Always return 200
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// Per RFC 7009, revocation endpoint always returns 200 OK
			success := true

			if success != tt.expectSuccess {
				t.Errorf("Expected success=%v, got success=%v", tt.expectSuccess, success)
			}
		})
	}
}

// TestClientIDGeneration_Format tests that client IDs built from 16 random
// bytes are 32 lowercase hex characters and unique across 100 draws.
func TestClientIDGeneration_Format(t *testing.T) {
	clientIDs := make(map[string]bool, 100)

	for i := 0; i < 100; i++ {
		raw := make([]byte, 16)
		if _, err := rand.Read(raw); err != nil {
			t.Fatalf("crypto/rand failed: %v", err)
		}
		clientID := hex.EncodeToString(raw)

		// Check format (32 hex characters)
		if len(clientID) != 32 {
			t.Errorf("Client ID should be 32 characters, got %d", len(clientID))
		}

		// Check uniqueness
		if clientIDs[clientID] {
			t.Error("Client ID should be unique")
		}
		clientIDs[clientID] = true

		// Check only hex characters
		for _, c := range clientID {
			if !((c >= '0' && c <= '9') || (c >= 'a' && c <= 'f')) {
				t.Errorf("Client ID should only contain hex characters, found %c", c)
			}
		}
	}
}

// TestScopeNormalization tests scope string normalization. Runs of whitespace,
// including leading and trailing whitespace, collapse to single separators.
func TestScopeNormalization(t *testing.T) {
	tests := []struct {
		name     string
		input    string
		expected []string
	}{
		{
			name:     "single scope",
			input:    "openid",
			expected: []string{"openid"},
		},
		{
			name:     "multiple scopes",
			input:    "openid profile email",
			expected: []string{"openid", "profile", "email"},
		},
		{
			name:     "extra spaces",
			input:    "openid  profile   email",
			expected: []string{"openid", "profile", "email"},
		},
		{
			name:     "leading and trailing spaces",
			input:    "  openid profile  ",
			expected: []string{"openid", "profile"},
		},
		{
			name:     "empty string",
			input:    "",
			expected: []string{},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			scopes := normalizeScopes(tt.input)

			if len(scopes) != len(tt.expected) {
				t.Errorf("Expected %d scopes, got %d", len(tt.expected), len(scopes))
				return
			}

			for i, scope := range scopes {
				if scope != tt.expected[i] {
					t.Errorf("Expected scope[%d]=%s, got %s", i, tt.expected[i], scope)
				}
			}
		})
	}
}

// normalizeScopes splits scope on any run of Unicode whitespace (via
// strings.Fields). It returns an empty, non-nil slice for an empty input.
func normalizeScopes(scope string) []string {
	if scope == "" {
		return []string{}
	}
	parts := strings.Fields(scope) // Handles multiple spaces
	return parts
}

// benchmarkSink receives benchmark results so the compiler cannot treat the
// measured work as dead code and eliminate it.
var benchmarkSink string

// BenchmarkPKCEVerification_S256 benchmarks PKCE S256 verification.
func BenchmarkPKCEVerification_S256(b *testing.B) {
	verifier := "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"
	for i := 0; i < b.N; i++ {
		hash := sha256.Sum256([]byte(verifier))
		benchmarkSink = base64.RawURLEncoding.EncodeToString(hash[:])
	}
}

// BenchmarkTokenHashing benchmarks token hashing for storage.
func BenchmarkTokenHashing(b *testing.B) {
	token := "sample-access-token-12345678901234567890"
	for i := 0; i < b.N; i++ {
		hash := sha256.Sum256([]byte(token))
		benchmarkSink = hex.EncodeToString(hash[:])
	}
}

// BenchmarkCodeVerifierGeneration benchmarks code verifier generation.
func BenchmarkCodeVerifierGeneration(b *testing.B) {
	for i := 0; i < b.N; i++ {
		benchmarkSink = generateCodeVerifier(64)
	}
}