package services

import (
	"bytes"
	"context"
	"crypto/hmac"
	"crypto/rand"
	"crypto/sha1"
	"crypto/sha256"
	"encoding/base32"
	"encoding/base64"
	"encoding/binary"
	"encoding/hex"
	"encoding/json"
	"errors"
	"fmt"
	"image/png"
	"net/url"
	"strings"
	"time"

	"github.com/google/uuid"
	"github.com/jackc/pgx/v5/pgxpool"
	qrcode "github.com/skip2/go-qrcode"

	"github.com/breakpilot/consent-service/internal/models"
)

// Errors returned by TOTPService.
var (
	ErrTOTPNotEnabled       = errors.New("2FA is not enabled for this user")
	ErrTOTPAlreadyEnabled   = errors.New("2FA is already enabled for this user")
	ErrTOTPInvalidCode      = errors.New("invalid 2FA code")
	ErrTOTPChallengeExpired = errors.New("2FA challenge expired")
	ErrRecoveryCodeInvalid  = errors.New("invalid recovery code")
	ErrRecoveryCodeUsed     = errors.New("recovery code already used")
)

const (
	TOTPPeriod        = 30              // TOTP period in seconds
	TOTPDigits        = 6               // Number of digits in TOTP code
	TOTPSecretLen     = 20              // Length of TOTP secret in bytes
	RecoveryCodeCount = 10              // Number of recovery codes to generate
	RecoveryCodeLen   = 8               // Length of each recovery code in hex characters
	ChallengeExpiry   = 5 * time.Minute // 2FA challenge expiry
)

// selectChallengeSQL loads the two_factor_challenges columns that the
// challenge verification paths inspect.
const selectChallengeSQL = `
	SELECT id, user_id, expires_at, used_at
	FROM two_factor_challenges WHERE challenge_id = $1`

// TOTPService handles Two-Factor Authentication using TOTP (HMAC-SHA1),
// persisting state in the user_totp, two_factor_challenges and users tables.
type TOTPService struct {
	db     *pgxpool.Pool
	issuer string
}

// NewTOTPService creates a TOTPService that stores state in db and uses
// issuer as the issuer name in generated otpauth URIs.
func NewTOTPService(db *pgxpool.Pool, issuer string) *TOTPService {
	return &TOTPService{
		db:     db,
		issuer: issuer,
	}
}

// GenerateSecret returns a new random TOTPSecretLen-byte secret encoded as
// unpadded standard base32.
func (s *TOTPService) GenerateSecret() (string, error) {
	secret := make([]byte, TOTPSecretLen)
	if _, err := rand.Read(secret); err != nil {
		return "", fmt.Errorf("failed to generate secret: %w", err)
	}
	return base32.StdEncoding.WithPadding(base32.NoPadding).EncodeToString(secret), nil
}

// GenerateRecoveryCodes returns RecoveryCodeCount random recovery codes, each
// RecoveryCodeLen upper-case hex characters. Callers in this service persist
// only their hashes (see hashRecoveryCode).
func (s *TOTPService) GenerateRecoveryCodes() ([]string, error) {
	codes := make([]string, RecoveryCodeCount)
	for i := range codes {
		codeBytes := make([]byte, RecoveryCodeLen/2)
		if _, err := rand.Read(codeBytes); err != nil {
			return nil, fmt.Errorf("failed to generate recovery code: %w", err)
		}
		codes[i] = strings.ToUpper(hex.EncodeToString(codeBytes))
	}
	return codes, nil
}

// GenerateQRCode builds the otpauth://totp key URI for secret and email,
// renders it as a PNG QR code (go-qrcode, size 256) and returns it as a
// base64 "data:image/png" URL.
func (s *TOTPService) GenerateQRCode(secret, email string) (string, error) {
	// Issuer and email come from configuration / user input; escape them so
	// characters such as '&', '#', '?' or spaces cannot corrupt the URI.
	label := url.PathEscape(s.issuer) + ":" + url.PathEscape(email)
	otpauthURL := fmt.Sprintf("otpauth://totp/%s?secret=%s&issuer=%s&algorithm=SHA1&digits=%d&period=%d",
		label, escapeQueryValue(secret), escapeQueryValue(s.issuer), TOTPDigits, TOTPPeriod)

	qr, err := qrcode.New(otpauthURL, qrcode.Medium)
	if err != nil {
		return "", fmt.Errorf("failed to generate QR code: %w", err)
	}

	var buf bytes.Buffer
	if err := png.Encode(&buf, qr.Image(256)); err != nil {
		return "", fmt.Errorf("failed to encode QR code: %w", err)
	}

	return "data:image/png;base64," + base64.StdEncoding.EncodeToString(buf.Bytes()), nil
}

// escapeQueryValue percent-encodes v for use as a URI query value.
// url.QueryEscape uses form encoding (space -> '+'); spaces are rewritten to
// %20 so the value decodes the same under plain percent-decoding. A literal
// '+' is already emitted as %2B, so the replacement never touches it.
func escapeQueryValue(v string) string {
	return strings.ReplaceAll(url.QueryEscape(v), "+", "%20")
}

// GenerateTOTP returns the TOTP code for secret at the current time.
func (s *TOTPService) GenerateTOTP(secret string) (string, error) {
	return s.GenerateTOTPAt(secret, time.Now())
}

// GenerateTOTPAt returns the TOTPDigits-digit TOTP code for the base32 secret
// at time t. It returns an error if secret is not valid unpadded base32.
func (s *TOTPService) GenerateTOTPAt(secret string, t time.Time) (string, error) {
	secretBytes, err := base32.StdEncoding.WithPadding(base32.NoPadding).DecodeString(strings.ToUpper(secret))
	if err != nil {
		return "", fmt.Errorf("invalid secret: %w", err)
	}

	// Moving factor: number of whole TOTPPeriod intervals since the Unix epoch.
	counter := uint64(t.Unix()) / TOTPPeriod

	// HOTP: HMAC-SHA1 over the big-endian 8-byte counter.
	var msg [8]byte
	binary.BigEndian.PutUint64(msg[:], counter)
	mac := hmac.New(sha1.New, secretBytes)
	mac.Write(msg[:])
	sum := mac.Sum(nil)

	// Dynamic truncation: the low nibble of the last byte picks a 4-byte
	// window; the top bit is masked to get a non-negative 31-bit value.
	offset := sum[len(sum)-1] & 0x0f
	code := binary.BigEndian.Uint32(sum[offset:offset+4]) & 0x7fffffff

	// The modulus is derived from TOTPDigits so the digit count and the
	// reduction cannot drift apart.
	return fmt.Sprintf("%0*d", TOTPDigits, code%totpModulus()), nil
}

// totpModulus returns 10^TOTPDigits, the divisor that reduces a truncated
// HOTP value to TOTPDigits decimal digits.
func totpModulus() uint32 {
	m := uint32(1)
	for i := 0; i < TOTPDigits; i++ {
		m *= 10
	}
	return m
}

// ValidateTOTP reports whether code matches the TOTP for secret in the
// current period or one period before/after (to tolerate clock drift).
//
// NOTE(review): nothing here records which time step was accepted, so a code
// stays valid for its whole window and can be accepted more than once.
func (s *TOTPService) ValidateTOTP(secret, code string) bool {
	if len(code) != TOTPDigits {
		return false
	}
	now := time.Now()
	for _, step := range []int{0, -1, 1} {
		expected, err := s.GenerateTOTPAt(secret, now.Add(time.Duration(step*TOTPPeriod)*time.Second))
		if err != nil {
			// The secret failed to decode; it will fail for every step.
			return false
		}
		// Constant-time comparison to avoid leaking matching prefixes via timing.
		if hmac.Equal([]byte(expected), []byte(code)) {
			return true
		}
	}
	return false
}

// hashRecoveryCode returns the hex SHA-256 digest under which a recovery code
// is stored. Input is trimmed and upper-cased so user-entered codes match the
// upper-case hex codes produced by GenerateRecoveryCodes.
//
// NOTE(review): codes carry only RecoveryCodeLen/2 random bytes (32 bits); an
// unsalted fast hash offers little protection if the table leaks. Changing the
// scheme would invalidate existing stored hashes.
func hashRecoveryCode(code string) string {
	sum := sha256.Sum256([]byte(strings.ToUpper(strings.TrimSpace(code))))
	return hex.EncodeToString(sum[:])
}

// marshalRecoveryCodeHashes hashes each plaintext code and returns the JSON
// array stored in user_totp.recovery_codes.
func marshalRecoveryCodeHashes(codes []string) ([]byte, error) {
	hashes := make([]string, len(codes))
	for i, c := range codes {
		hashes[i] = hashRecoveryCode(c)
	}
	data, err := json.Marshal(hashes)
	if err != nil {
		return nil, fmt.Errorf("failed to encode recovery codes: %w", err)
	}
	return data, nil
}

// Setup2FA starts (or restarts) 2FA enrollment for userID. It generates a
// new secret and recovery codes, stores them in user_totp as unverified
// (replacing any earlier unverified enrollment), and returns the secret, a QR
// code data URL and the plaintext recovery codes. It returns
// ErrTOTPAlreadyEnabled if the user already has a verified enrollment.
func (s *TOTPService) Setup2FA(ctx context.Context, userID uuid.UUID, email string) (*models.Setup2FAResponse, error) {
	var exists bool
	err := s.db.QueryRow(ctx,
		`SELECT EXISTS(SELECT 1 FROM user_totp WHERE user_id = $1 AND verified = true)`,
		userID).Scan(&exists)
	if err != nil {
		// Don't fall through on a failed check: the upsert below would
		// otherwise risk resetting a verified enrollment.
		return nil, fmt.Errorf("failed to check 2FA status: %w", err)
	}
	if exists {
		return nil, ErrTOTPAlreadyEnabled
	}

	secret, err := s.GenerateSecret()
	if err != nil {
		return nil, err
	}

	recoveryCodes, err := s.GenerateRecoveryCodes()
	if err != nil {
		return nil, err
	}

	qrCode, err := s.GenerateQRCode(secret, email)
	if err != nil {
		return nil, err
	}

	recoveryCodesJSON, err := marshalRecoveryCodeHashes(recoveryCodes)
	if err != nil {
		return nil, err
	}

	// The WHERE on the conflict branch makes the "not yet verified" guard
	// atomic: a row verified concurrently after the check above is left alone.
	tag, err := s.db.Exec(ctx, `
		INSERT INTO user_totp (user_id, secret, verified, recovery_codes, created_at, updated_at)
		VALUES ($1, $2, false, $3, NOW(), NOW())
		ON CONFLICT (user_id) DO UPDATE SET
			secret = $2,
			verified = false,
			recovery_codes = $3,
			updated_at = NOW()
		WHERE user_totp.verified = false
	`, userID, secret, recoveryCodesJSON)
	if err != nil {
		return nil, fmt.Errorf("failed to store TOTP: %w", err)
	}
	if tag.RowsAffected() == 0 {
		return nil, ErrTOTPAlreadyEnabled
	}

	return &models.Setup2FAResponse{
		Secret:        secret,
		QRCodeDataURL: qrCode,
		RecoveryCodes: recoveryCodes,
	}, nil
}

// Verify2FASetup completes enrollment: if code is valid for the user's
// pending secret, it marks user_totp as verified and sets
// users.two_factor_enabled. Both updates happen in one transaction.
// Errors: ErrTOTPNotEnabled (no enrollment row or lookup failed),
// ErrTOTPAlreadyEnabled, ErrTOTPInvalidCode.
func (s *TOTPService) Verify2FASetup(ctx context.Context, userID uuid.UUID, code string) error {
	var secret string
	var verified bool
	err := s.db.QueryRow(ctx,
		`SELECT secret, verified FROM user_totp WHERE user_id = $1`,
		userID).Scan(&secret, &verified)
	if err != nil {
		return ErrTOTPNotEnabled
	}
	if verified {
		return ErrTOTPAlreadyEnabled
	}

	if !s.ValidateTOTP(secret, code) {
		return ErrTOTPInvalidCode
	}

	tx, err := s.db.Begin(ctx)
	if err != nil {
		return fmt.Errorf("failed to begin transaction: %w", err)
	}
	// Rollback after a successful Commit is harmless; its error is ignored.
	defer func() { _ = tx.Rollback(ctx) }()

	// Matching on secret ensures we verify the secret the code was checked
	// against, not one installed by a concurrent Setup2FA call.
	tag, err := tx.Exec(ctx, `
		UPDATE user_totp
		SET verified = true, enabled_at = NOW(), updated_at = NOW()
		WHERE user_id = $1 AND secret = $2
	`, userID, secret)
	if err != nil {
		return fmt.Errorf("failed to verify TOTP: %w", err)
	}
	if tag.RowsAffected() == 0 {
		return ErrTOTPInvalidCode
	}

	if _, err := tx.Exec(ctx, `
		UPDATE users
		SET two_factor_enabled = true, two_factor_verified_at = NOW(), updated_at = NOW()
		WHERE id = $1
	`, userID); err != nil {
		return fmt.Errorf("failed to update user: %w", err)
	}

	if err := tx.Commit(ctx); err != nil {
		return fmt.Errorf("failed to commit 2FA verification: %w", err)
	}
	return nil
}

// CreateChallenge records a login 2FA challenge for userID that expires after
// ChallengeExpiry and returns its random, URL-safe base64 challenge ID.
func (s *TOTPService) CreateChallenge(ctx context.Context, userID uuid.UUID, ipAddress, userAgent string) (string, error) {
	challengeBytes := make([]byte, 32)
	if _, err := rand.Read(challengeBytes); err != nil {
		return "", fmt.Errorf("failed to generate challenge: %w", err)
	}
	challengeID := base64.URLEncoding.EncodeToString(challengeBytes)

	_, err := s.db.Exec(ctx, `
		INSERT INTO two_factor_challenges (user_id, challenge_id, ip_address, user_agent, expires_at, created_at)
		VALUES ($1, $2, $3, $4, $5, NOW())
	`, userID, challengeID, ipAddress, userAgent, time.Now().Add(ChallengeExpiry))
	if err != nil {
		return "", fmt.Errorf("failed to create challenge: %w", err)
	}

	return challengeID, nil
}

// challengeError maps the result of loading a two_factor_challenges row to the
// error the caller should return, or nil if the challenge is still usable.
func challengeError(c *models.TwoFactorChallenge, loadErr error) error {
	switch {
	case loadErr != nil, c.UsedAt != nil:
		return ErrInvalidToken
	case time.Now().After(c.ExpiresAt):
		return ErrTOTPChallengeExpired
	}
	return nil
}

// VerifyChallenge completes a login challenge with a TOTP code and returns the
// challenged user's ID. The challenge is consumed only on success.
// Errors: ErrInvalidToken (unknown, already used, or consumed concurrently),
// ErrTOTPChallengeExpired, ErrTOTPNotEnabled, ErrTOTPInvalidCode.
//
// NOTE(review): a wrong code does not consume the challenge, so guessing is
// bounded only by ChallengeExpiry and any rate limiting done by callers.
func (s *TOTPService) VerifyChallenge(ctx context.Context, challengeID, code string) (*uuid.UUID, error) {
	var challenge models.TwoFactorChallenge
	scanErr := s.db.QueryRow(ctx, selectChallengeSQL, challengeID).Scan(
		&challenge.ID, &challenge.UserID, &challenge.ExpiresAt, &challenge.UsedAt)
	if err := challengeError(&challenge, scanErr); err != nil {
		return nil, err
	}

	var secret string
	err := s.db.QueryRow(ctx,
		`SELECT secret FROM user_totp WHERE user_id = $1 AND verified = true`,
		challenge.UserID).Scan(&secret)
	if err != nil {
		return nil, ErrTOTPNotEnabled
	}

	if !s.ValidateTOTP(secret, code) {
		return nil, ErrTOTPInvalidCode
	}

	// Consume the challenge atomically: only one concurrent request can flip
	// used_at from NULL, so a challenge cannot be redeemed twice.
	tag, err := s.db.Exec(ctx,
		`UPDATE two_factor_challenges SET used_at = NOW() WHERE id = $1 AND used_at IS NULL`,
		challenge.ID)
	if err != nil {
		return nil, fmt.Errorf("failed to mark challenge as used: %w", err)
	}
	if tag.RowsAffected() == 0 {
		return nil, ErrInvalidToken
	}

	// Best-effort bookkeeping; login has already succeeded at this point.
	_, _ = s.db.Exec(ctx, `UPDATE user_totp SET last_used_at = NOW() WHERE user_id = $1`, challenge.UserID)

	return &challenge.UserID, nil
}

// VerifyChallengeWithRecoveryCode completes a login challenge with a one-time
// recovery code, removing that code from the user's stored set, and returns
// the challenged user's ID. Removing the code and consuming the challenge
// happen in one transaction with both rows locked, so a code or challenge
// cannot be spent twice by concurrent requests.
// Errors: ErrInvalidToken, ErrTOTPChallengeExpired, ErrTOTPNotEnabled,
// ErrRecoveryCodeInvalid.
func (s *TOTPService) VerifyChallengeWithRecoveryCode(ctx context.Context, challengeID, recoveryCode string) (*uuid.UUID, error) {
	tx, err := s.db.Begin(ctx)
	if err != nil {
		return nil, fmt.Errorf("failed to begin transaction: %w", err)
	}
	// Rollback after a successful Commit is harmless; its error is ignored.
	defer func() { _ = tx.Rollback(ctx) }()

	var challenge models.TwoFactorChallenge
	scanErr := tx.QueryRow(ctx, selectChallengeSQL+" FOR UPDATE", challengeID).Scan(
		&challenge.ID, &challenge.UserID, &challenge.ExpiresAt, &challenge.UsedAt)
	if err := challengeError(&challenge, scanErr); err != nil {
		return nil, err
	}

	var recoveryCodesJSON []byte
	err = tx.QueryRow(ctx,
		`SELECT recovery_codes FROM user_totp WHERE user_id = $1 AND verified = true FOR UPDATE`,
		challenge.UserID).Scan(&recoveryCodesJSON)
	if err != nil {
		return nil, ErrTOTPNotEnabled
	}

	// A NULL/empty column means no codes remain.
	var hashedCodes []string
	if len(recoveryCodesJSON) > 0 {
		if err := json.Unmarshal(recoveryCodesJSON, &hashedCodes); err != nil {
			return nil, fmt.Errorf("failed to decode recovery codes: %w", err)
		}
	}

	target := hashRecoveryCode(recoveryCode)
	idx := -1
	for i, h := range hashedCodes {
		if h == target {
			idx = i
			break
		}
	}
	if idx < 0 {
		return nil, ErrRecoveryCodeInvalid
	}

	// idx >= 0 implies len(hashedCodes) >= 1, so the capacity is non-negative.
	remaining := make([]string, 0, len(hashedCodes)-1)
	remaining = append(remaining, hashedCodes[:idx]...)
	remaining = append(remaining, hashedCodes[idx+1:]...)

	newCodesJSON, err := json.Marshal(remaining)
	if err != nil {
		return nil, fmt.Errorf("failed to encode recovery codes: %w", err)
	}

	if _, err := tx.Exec(ctx,
		`UPDATE user_totp SET recovery_codes = $1, updated_at = NOW() WHERE user_id = $2`,
		newCodesJSON, challenge.UserID); err != nil {
		return nil, fmt.Errorf("failed to update recovery codes: %w", err)
	}

	if _, err := tx.Exec(ctx,
		`UPDATE two_factor_challenges SET used_at = NOW() WHERE id = $1`,
		challenge.ID); err != nil {
		return nil, fmt.Errorf("failed to mark challenge as used: %w", err)
	}

	if err := tx.Commit(ctx); err != nil {
		return nil, fmt.Errorf("failed to commit recovery code use: %w", err)
	}

	return &challenge.UserID, nil
}

// Disable2FA turns off 2FA for userID after checking a current TOTP code: it
// deletes the user_totp row and clears the 2FA flags on users, in one
// transaction. Errors: ErrTOTPNotEnabled, ErrTOTPInvalidCode.
func (s *TOTPService) Disable2FA(ctx context.Context, userID uuid.UUID, code string) error {
	var secret string
	err := s.db.QueryRow(ctx,
		`SELECT secret FROM user_totp WHERE user_id = $1 AND verified = true`,
		userID).Scan(&secret)
	if err != nil {
		return ErrTOTPNotEnabled
	}

	if !s.ValidateTOTP(secret, code) {
		return ErrTOTPInvalidCode
	}

	tx, err := s.db.Begin(ctx)
	if err != nil {
		return fmt.Errorf("failed to begin transaction: %w", err)
	}
	// Rollback after a successful Commit is harmless; its error is ignored.
	defer func() { _ = tx.Rollback(ctx) }()

	if _, err := tx.Exec(ctx, `DELETE FROM user_totp WHERE user_id = $1`, userID); err != nil {
		return fmt.Errorf("failed to delete TOTP: %w", err)
	}

	if _, err := tx.Exec(ctx, `
		UPDATE users
		SET two_factor_enabled = false, two_factor_verified_at = NULL, updated_at = NOW()
		WHERE id = $1
	`, userID); err != nil {
		return fmt.Errorf("failed to update user: %w", err)
	}

	if err := tx.Commit(ctx); err != nil {
		return fmt.Errorf("failed to commit 2FA disable: %w", err)
	}
	return nil
}

// GetStatus reports the user's 2FA state and how many recovery codes remain.
// If the user_totp lookup fails for any reason, it reports 2FA as not set up.
func (s *TOTPService) GetStatus(ctx context.Context, userID uuid.UUID) (*models.TwoFactorStatusResponse, error) {
	var totp models.UserTOTP
	var recoveryCodesJSON []byte

	err := s.db.QueryRow(ctx, `
		SELECT id, verified, enabled_at, recovery_codes
		FROM user_totp WHERE user_id = $1
	`, userID).Scan(&totp.ID, &totp.Verified, &totp.EnabledAt, &recoveryCodesJSON)
	if err != nil {
		return &models.TwoFactorStatusResponse{
			Enabled:            false,
			Verified:           false,
			RecoveryCodesCount: 0,
		}, nil
	}

	// A NULL/empty column means no codes remain.
	var hashedCodes []string
	if len(recoveryCodesJSON) > 0 {
		if err := json.Unmarshal(recoveryCodesJSON, &hashedCodes); err != nil {
			return nil, fmt.Errorf("failed to decode recovery codes: %w", err)
		}
	}

	return &models.TwoFactorStatusResponse{
		Enabled:            totp.Verified,
		Verified:           totp.Verified,
		EnabledAt:          totp.EnabledAt,
		RecoveryCodesCount: len(hashedCodes),
	}, nil
}

// RegenerateRecoveryCodes replaces the user's recovery codes with a fresh set
// after checking a current TOTP code, and returns the new plaintext codes.
// Errors: ErrTOTPNotEnabled, ErrTOTPInvalidCode.
func (s *TOTPService) RegenerateRecoveryCodes(ctx context.Context, userID uuid.UUID, code string) ([]string, error) {
	var secret string
	err := s.db.QueryRow(ctx,
		`SELECT secret FROM user_totp WHERE user_id = $1 AND verified = true`,
		userID).Scan(&secret)
	if err != nil {
		return nil, ErrTOTPNotEnabled
	}

	if !s.ValidateTOTP(secret, code) {
		return nil, ErrTOTPInvalidCode
	}

	recoveryCodes, err := s.GenerateRecoveryCodes()
	if err != nil {
		return nil, err
	}

	recoveryCodesJSON, err := marshalRecoveryCodeHashes(recoveryCodes)
	if err != nil {
		return nil, err
	}

	_, err = s.db.Exec(ctx,
		`UPDATE user_totp SET recovery_codes = $1, updated_at = NOW() WHERE user_id = $2`,
		recoveryCodesJSON, userID)
	if err != nil {
		return nil, fmt.Errorf("failed to update recovery codes: %w", err)
	}

	return recoveryCodes, nil
}

// IsTwoFactorEnabled returns users.two_factor_enabled for userID.
func (s *TOTPService) IsTwoFactorEnabled(ctx context.Context, userID uuid.UUID) (bool, error) {
	var enabled bool
	err := s.db.QueryRow(ctx, `SELECT two_factor_enabled FROM users WHERE id = $1`, userID).Scan(&enabled)
	if err != nil {
		return false, err
	}
	return enabled, nil
}