package services

import (
	"bytes"
	"crypto/hmac"
	"crypto/sha1"
	"crypto/sha256"
	"encoding/base32"
	"encoding/binary"
	"encoding/hex"
	"fmt"
	"net/url"
	"strconv"
	"strings"
	"testing"
	"time"
)

// decodeBase32Secret decodes an unpadded Base32 TOTP secret, failing the
// calling test immediately if the secret is malformed.
func decodeBase32Secret(t *testing.T, secret string) []byte {
	t.Helper()
	secretBytes, err := base32.StdEncoding.WithPadding(base32.NoPadding).DecodeString(secret)
	if err != nil {
		t.Fatalf("Failed to decode secret: %v", err)
	}
	return secretBytes
}

// TestTOTPGeneration_ValidSecret checks that a Base32 secret decodes to the
// expected key and yields a well-formed six-digit code for the current time.
func TestTOTPGeneration_ValidSecret(t *testing.T) {
	// Base32 for the 10 bytes "Hello!" followed by 0xDE 0xAD 0xBE 0xEF.
	secret := "JBSWY3DPEHPK3PXP"

	secretBytes := decodeBase32Secret(t, secret)
	if len(secretBytes) != 10 {
		t.Fatalf("Decoded secret should be 10 bytes, got %d", len(secretBytes))
	}

	code := generateTOTPAt(secretBytes, time.Now())

	// Check the rendered string itself: exactly six characters, all digits.
	if len(code) != 6 {
		t.Fatalf("TOTP code should be 6 digits, got %q", code)
	}
	for _, c := range code {
		if c < '0' || c > '9' {
			t.Errorf("TOTP code should only contain digits, got %q", code)
			break
		}
	}
}

// TestTOTPGeneration_RFC6238Vectors pins generateTOTPAt to the SHA-1 test
// vectors published in RFC 6238 Appendix B. The RFC lists 8-digit codes; the
// expected values here are their last six digits.
func TestTOTPGeneration_RFC6238Vectors(t *testing.T) {
	// RFC 6238 uses this ASCII string directly as the HMAC key (not Base32).
	secretBytes := []byte("12345678901234567890")

	tests := []struct {
		unix int64
		want string
	}{
		{59, "287082"},
		{1111111109, "081804"},
		{1111111111, "050471"},
		{1234567890, "005924"},
		{2000000000, "279037"},
		{20000000000, "353130"},
	}

	for _, tt := range tests {
		got := generateTOTPAt(secretBytes, time.Unix(tt.unix, 0))
		if got != tt.want {
			t.Errorf("TOTP at unix time %d: got %s, want %s", tt.unix, got, tt.want)
		}
	}
}

// TestTOTPGeneration_Deterministic checks that generating a code twice for
// the same instant yields the same result.
func TestTOTPGeneration_Deterministic(t *testing.T) {
	secretBytes := decodeBase32Secret(t, "JBSWY3DPEHPK3PXP")

	fixedTime := time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC)

	code1 := generateTOTPAt(secretBytes, fixedTime)
	code2 := generateTOTPAt(secretBytes, fixedTime)

	if code1 != code2 {
		t.Errorf("Same time should produce same code: got %s and %s", code1, code2)
	}
}

// TestTOTPGeneration_TimeSensitive checks that two consecutive 30-second
// periods produce different codes for the same secret.
func TestTOTPGeneration_TimeSensitive(t *testing.T) {
	secretBytes := decodeBase32Secret(t, "JBSWY3DPEHPK3PXP")

	time1 := time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC)
	time2 := time1.Add(30 * time.Second) // Next TOTP period

	code1 := generateTOTPAt(secretBytes, time1)
	code2 := generateTOTPAt(secretBytes, time2)

	if code1 == code2 {
		t.Error("Different TOTP periods should produce different codes")
	}
}

// generateTOTPAt returns the six-digit TOTP code for secretBytes at instant t
// as a zero-padded decimal string. It uses HMAC-SHA1 with a 30-second time
// step counted from the Unix epoch.
//
// NOTE(review): uint64(t.Unix()) wraps for instants before 1970; callers are
// assumed to pass post-epoch times.
func generateTOTPAt(secretBytes []byte, t time.Time) string {
	counter := uint64(t.Unix()) / 30

	buf := make([]byte, 8)
	binary.BigEndian.PutUint64(buf, counter)

	mac := hmac.New(sha1.New, secretBytes)
	mac.Write(buf) // hash.Hash.Write never returns an error
	hash := mac.Sum(nil)

	// Dynamic truncation: the low nibble of the last byte selects a 4-byte
	// window, and the top bit is masked off to leave a 31-bit value.
	offset := hash[len(hash)-1] & 0x0f
	code := binary.BigEndian.Uint32(hash[offset:offset+4]) & 0x7fffffff
	return padCode(code % 1000000)
}

// padCode renders the last six decimal digits of code, zero-padded on the
// left (e.g. 42 -> "000042", 1234567 -> "234567").
func padCode(code uint32) string {
	return fmt.Sprintf("%06d", code%1000000)
}

// TestTOTPValidation_WithDrift checks that codes from the previous, current
// and next 30-second periods are accepted, and that codes two periods away
// are rejected.
func TestTOTPValidation_WithDrift(t *testing.T) {
	// RFC 4226 Appendix D key: its HOTP values for counters 0-9 are all
	// distinct, so the rejection checks cannot pass by coincidence.
	secretBytes := []byte("12345678901234567890")

	// Fixed instant in the middle of period 5 keeps the test deterministic.
	now := time.Unix(165, 0)

	for _, offset := range []time.Duration{-30 * time.Second, 0, 30 * time.Second} {
		code := generateTOTPAt(secretBytes, now.Add(offset))
		if !validateTOTPWithDrift(secretBytes, code, now) {
			t.Errorf("Code %s (offset %v) should be valid with drift allowance", code, offset)
		}
	}

	for _, offset := range []time.Duration{-60 * time.Second, 60 * time.Second} {
		code := generateTOTPAt(secretBytes, now.Add(offset))
		if validateTOTPWithDrift(secretBytes, code, now) {
			t.Errorf("Code %s (offset %v) should be rejected outside the drift window", code, offset)
		}
	}
}

// validateTOTPWithDrift reports whether code matches the TOTP for the period
// containing now, or for the immediately preceding or following period.
// The comparison is a plain string compare; a production validator would
// typically compare in constant time.
func validateTOTPWithDrift(secretBytes []byte, code string, now time.Time) bool {
	for _, step := range []int{0, -1, 1} {
		at := now.Add(time.Duration(step*30) * time.Second)
		if generateTOTPAt(secretBytes, at) == code {
			return true
		}
	}
	return false
}

// TestRecoveryCodeGeneration_Format checks that a recovery code built from 4
// bytes is rendered as 8 uppercase hexadecimal characters.
func TestRecoveryCodeGeneration_Format(t *testing.T) {
	// 4 bytes -> 8 hex characters. Fixed bytes keep the test deterministic;
	// they include hex letters so the uppercase requirement is exercised.
	codeBytes := []byte{0xAB, 0xCD, 0xEF, 0x12}

	code := strings.ToUpper(hex.EncodeToString(codeBytes))

	if len(code) != 8 {
		t.Errorf("Recovery code should be 8 characters, got %d", len(code))
	}
	if code != "ABCDEF12" {
		t.Errorf("Recovery code: got %q, want %q", code, "ABCDEF12")
	}

	// Only 0-9 and A-F are allowed, so lowercase a-f would also fail here.
	for _, c := range code {
		if !((c >= '0' && c <= '9') || (c >= 'A' && c <= 'F')) {
			t.Errorf("Recovery code should only contain hex characters, found '%c'", c)
		}
	}
}

// TestRecoveryCodeHashing_Consistency checks that hashing the same recovery
// code twice yields the same SHA-256 digest.
func TestRecoveryCodeHashing_Consistency(t *testing.T) {
	code := "ABCD1234"

	hash1 := sha256.Sum256([]byte(code))
	hash2 := sha256.Sum256([]byte(code))

	if hash1 != hash2 {
		t.Error("Recovery code hashing should be consistent")
	}
}

// TestRecoveryCodeHashing_CaseInsensitive checks that upper- and lowercase
// spellings of a code hash identically once normalized to uppercase.
func TestRecoveryCodeHashing_CaseInsensitive(t *testing.T) {
	code1 := "ABCD1234"
	code2 := "abcd1234"

	hash1 := sha256.Sum256([]byte(strings.ToUpper(code1)))
	hash2 := sha256.Sum256([]byte(strings.ToUpper(code2)))

	if hash1 != hash2 {
		t.Error("Recovery codes should be case-insensitive when normalized to uppercase")
	}
}

// TestSecretGeneration_ValidBase32 checks that a 20-byte secret encodes to
// unpadded Base32 and round-trips back to the same bytes.
func TestSecretGeneration_ValidBase32(t *testing.T) {
	// Simulate secret generation (20 bytes -> Base32 without padding).
	secretBytes := make([]byte, 20)
	for i := range secretBytes {
		secretBytes[i] = byte(i * 13) // Deterministic for testing
	}

	enc := base32.StdEncoding.WithPadding(base32.NoPadding)
	secret := enc.EncodeToString(secretBytes)

	// 160 bits / 5 bits per character = exactly 32 characters, no padding.
	if len(secret) != 32 {
		t.Errorf("Encoded secret should be 32 characters, got %d", len(secret))
	}
	if strings.Contains(secret, "=") {
		t.Errorf("Encoded secret should not contain padding, got %q", secret)
	}

	decoded, err := enc.DecodeString(secret)
	if err != nil {
		t.Fatalf("Generated secret should be valid Base32: %v", err)
	}
	if !bytes.Equal(decoded, secretBytes) {
		t.Errorf("Decoded secret should round-trip: got %x, want %x", decoded, secretBytes)
	}
}

// TestQRCodeOtpauthURL_Format builds an otpauth:// key URI, as embedded in a
// 2FA enrollment QR code, and checks that the label and every parameter
// survive a parse round-trip.
func TestQRCodeOtpauthURL_Format(t *testing.T) {
	issuer := "BreakPilot"
	email := "test@example.com"
	secret := "JBSWY3DPEHPK3PXP"
	period := 30
	digits := 6

	// url.Values escapes each value; strconv formats numbers of any width.
	params := url.Values{}
	params.Set("secret", secret)
	params.Set("issuer", issuer)
	params.Set("algorithm", "SHA1")
	params.Set("digits", strconv.Itoa(digits))
	params.Set("period", strconv.Itoa(period))

	label := url.PathEscape(issuer + ":" + email)
	otpauthURL := "otpauth://totp/" + label + "?" + params.Encode()

	if !strings.HasPrefix(otpauthURL, "otpauth://totp/") {
		t.Error("OTP auth URL should start with otpauth://totp/")
	}

	parsed, err := url.Parse(otpauthURL)
	if err != nil {
		t.Fatalf("OTP auth URL should parse: %v", err)
	}
	if got, want := strings.TrimPrefix(parsed.Path, "/"), issuer+":"+email; got != want {
		t.Errorf("OTP auth URL label: got %q, want %q", got, want)
	}

	query := parsed.Query()
	wantParams := map[string]string{
		"secret":    secret,
		"issuer":    issuer,
		"algorithm": "SHA1",
		"digits":    strconv.Itoa(digits),
		"period":    strconv.Itoa(period),
	}
	for key, want := range wantParams {
		if got := query.Get(key); got != want {
			t.Errorf("OTP auth URL parameter %s: got %q, want %q", key, got, want)
		}
	}
}

// TestChallengeExpiry_Logic checks the expiry rule for 2FA challenges: a
// challenge is usable only strictly before creation time + expiry.
func TestChallengeExpiry_Logic(t *testing.T) {
	tests := []struct {
		name        string
		expiryMins  int
		usedAfter   int // minutes after creation
		shouldAllow bool
	}{
		{
			name:        "challenge used within expiry",
			expiryMins:  5,
			usedAfter:   2,
			shouldAllow: true,
		},
		{
			name:        "challenge used at expiry",
			expiryMins:  5,
			usedAfter:   5,
			shouldAllow: false, // Expired
		},
		{
			name:        "challenge used after expiry",
			expiryMins:  5,
			usedAfter:   10,
			shouldAllow: false,
		},
		{
			name:        "challenge used immediately",
			expiryMins:  5,
			usedAfter:   0,
			shouldAllow: true,
		},
	}

	createdAt := time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC)

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			expiresAt := createdAt.Add(time.Duration(tt.expiryMins) * time.Minute)
			usedAt := createdAt.Add(time.Duration(tt.usedAfter) * time.Minute)

			// Using the challenge exactly at expiresAt counts as expired.
			isValid := usedAt.Before(expiresAt)

			if isValid != tt.shouldAllow {
				t.Errorf("Expected allow=%v for challenge used after %d mins (expiry: %d mins)",
					tt.shouldAllow, tt.usedAfter, tt.expiryMins)
			}
		})
	}
}

// TestRecoveryCodeOneTimeUse checks that removing a used recovery code's
// hash leaves only the unused hashes behind.
func TestRecoveryCodeOneTimeUse(t *testing.T) {
	initialCodes := []string{
		sha256Hash("CODE0001"),
		sha256Hash("CODE0002"),
		sha256Hash("CODE0003"),
	}

	// Use CODE0002
	usedCodeHash := sha256Hash("CODE0002")

	// Remove used code from list
	remainingCodes := make([]string, 0, len(initialCodes))
	for _, code := range initialCodes {
		if code != usedCodeHash {
			remainingCodes = append(remainingCodes, code)
		}
	}

	if len(remainingCodes) != 2 {
		t.Errorf("Should have 2 remaining codes after using one, got %d", len(remainingCodes))
	}

	// Verify used code is not in remaining
	for _, code := range remainingCodes {
		if code == usedCodeHash {
			t.Error("Used recovery code should be removed from list")
		}
	}
}

// sha256Hash returns the lowercase hex encoding of the SHA-256 digest of s.
func sha256Hash(s string) string {
	h := sha256.Sum256([]byte(s))
	return hex.EncodeToString(h[:])
}

// TestTwoFactorEnableFlow_States checks the verified flag transitions for
// the "none", "verify" and "disable" actions.
func TestTwoFactorEnableFlow_States(t *testing.T) {
	tests := []struct {
		name          string
		initialState  bool // verified
		action        string
		expectedState bool
	}{
		{
			name:          "fresh user - not verified",
			initialState:  false,
			action:        "none",
			expectedState: false,
		},
		{
			name:          "user verifies 2FA",
			initialState:  false,
			action:        "verify",
			expectedState: true,
		},
		{
			name:          "already verified - stays verified",
			initialState:  true,
			action:        "verify",
			expectedState: true,
		},
		{
			name:          "user disables 2FA",
			initialState:  true,
			action:        "disable",
			expectedState: false,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			state := tt.initialState

			switch tt.action {
			case "verify":
				state = true
			case "disable":
				state = false
			}

			if state != tt.expectedState {
				t.Errorf("Expected state=%v after action '%s', got state=%v",
					tt.expectedState, tt.action, state)
			}
		})
	}
}