package session

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"os"
	"strconv"
	"sync"
	"time"

	"github.com/google/uuid"
	"github.com/jackc/pgx/v5/pgxpool"
	"github.com/redis/go-redis/v9"
)

// UserType distinguishes between internal employees and external customers.
type UserType string

const (
	UserTypeEmployee UserType = "employee"
	UserTypeCustomer UserType = "customer"
)

// Session represents a user session with RBAC data. Its JSON encoding is the
// value cached in Valkey.
type Session struct {
	SessionID      string    `json:"session_id"`
	UserID         string    `json:"user_id"`
	Email          string    `json:"email"`
	UserType       UserType  `json:"user_type"`
	Roles          []string  `json:"roles"`
	Permissions    []string  `json:"permissions"`
	TenantID       *string   `json:"tenant_id,omitempty"`
	IPAddress      *string   `json:"ip_address,omitempty"`
	UserAgent      *string   `json:"user_agent,omitempty"`
	CreatedAt      time.Time `json:"created_at"`
	LastActivityAt time.Time `json:"last_activity_at"`
}

// HasPermission reports whether the session grants permission.
func (s *Session) HasPermission(permission string) bool {
	for _, p := range s.Permissions {
		if p == permission {
			return true
		}
	}
	return false
}

// HasAnyPermission reports whether the session grants at least one of
// permissions. It returns false for an empty list.
func (s *Session) HasAnyPermission(permissions []string) bool {
	for _, needed := range permissions {
		if s.HasPermission(needed) {
			return true
		}
	}
	return false
}

// HasAllPermissions reports whether the session grants every entry of
// permissions. It returns true for an empty list.
func (s *Session) HasAllPermissions(permissions []string) bool {
	for _, needed := range permissions {
		if !s.HasPermission(needed) {
			return false
		}
	}
	return true
}

// HasRole reports whether the session carries role.
func (s *Session) HasRole(role string) bool {
	for _, r := range s.Roles {
		if r == role {
			return true
		}
	}
	return false
}

// IsEmployee reports whether the user is an employee (internal staff).
func (s *Session) IsEmployee() bool {
	return s.UserType == UserTypeEmployee
}

// IsCustomer reports whether the user is a customer (external user).
func (s *Session) IsCustomer() bool {
	return s.UserType == UserTypeCustomer
}

// SessionStore provides hybrid session storage: Valkey as an optional cache
// and PostgreSQL as the persistent store.
type SessionStore struct {
	valkeyClient  *redis.Client
	pgPool        *pgxpool.Pool
	sessionTTL    time.Duration
	valkeyEnabled bool
	// NOTE(review): mu is not used by any method in this file; kept in case
	// other files of the package reference it.
	mu sync.RWMutex
}

const (
	// defaultSessionTTLHours applies when SESSION_TTL_HOURS is unset,
	// unparsable, or not positive.
	defaultSessionTTLHours = 24
	// activityUpdateTimeout bounds the background last-activity refresh so the
	// goroutines spawned by GetSession cannot block forever on a stalled backend.
	activityUpdateTimeout = 5 * time.Second
)

// NewSessionStore creates a session store backed by pgPool (which may be nil).
//
// The session TTL is read from SESSION_TTL_HOURS (default 24). Valkey is
// contacted at VALKEY_URL (default redis://localhost:6379). If that URL cannot
// be parsed or the server does not answer a ping within 5 seconds, the store
// runs without the cache.
func NewSessionStore(pgPool *pgxpool.Pool) *SessionStore {
	ttlHours := defaultSessionTTLHours
	if ttlStr := os.Getenv("SESSION_TTL_HOURS"); ttlStr != "" {
		// A zero or negative TTL would make every SetEx fail and every new
		// session expire immediately, so only positive values are accepted.
		if val, err := strconv.Atoi(ttlStr); err == nil && val > 0 {
			ttlHours = val
		}
	}

	store := &SessionStore{
		pgPool:     pgPool,
		sessionTTL: time.Duration(ttlHours) * time.Hour,
	}

	valkeyURL := os.Getenv("VALKEY_URL")
	if valkeyURL == "" {
		valkeyURL = "redis://localhost:6379"
	}
	opt, err := redis.ParseURL(valkeyURL)
	if err != nil {
		return store
	}

	client := redis.NewClient(opt)
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()
	if err := client.Ping(ctx).Err(); err != nil {
		// Valkey is optional: release the unusable client and run without cache.
		_ = client.Close()
		return store
	}
	store.valkeyClient = client
	store.valkeyEnabled = true
	return store
}

// Close releases the Valkey client, if any. The PostgreSQL pool is owned by
// the caller and is not closed here.
func (s *SessionStore) Close() {
	if s.valkeyClient != nil {
		_ = s.valkeyClient.Close()
	}
}

// getValkeyKey returns the Valkey key under which a session is cached.
func (s *SessionStore) getValkeyKey(sessionID string) string {
	return "session:" + sessionID
}

// cacheSession writes session to Valkey with the store's TTL. It is best
// effort: on failure GetSession falls back to PostgreSQL when one is configured.
// Callers must check s.valkeyEnabled first.
func (s *SessionStore) cacheSession(ctx context.Context, session *Session) {
	data, err := json.Marshal(session)
	if err != nil {
		return
	}
	_ = s.valkeyClient.SetEx(ctx, s.getValkeyKey(session.SessionID), data, s.sessionTTL).Err()
}

// CreateSession creates a new session with a random UUID as its ID. It stores
// the session in PostgreSQL (when configured) and then caches it in Valkey
// (when enabled).
//
// It returns an error only if the PostgreSQL insert fails; in that case
// nothing is cached.
func (s *SessionStore) CreateSession(ctx context.Context, userID, email string, userType UserType, roles, permissions []string, tenantID, ipAddress, userAgent *string) (*Session, error) {
	// One timestamp so created_at, last_activity_at and expires_at agree.
	now := time.Now().UTC()
	session := &Session{
		SessionID:      uuid.New().String(),
		UserID:         userID,
		Email:          email,
		UserType:       userType,
		Roles:          roles,
		Permissions:    permissions,
		TenantID:       tenantID,
		IPAddress:      ipAddress,
		UserAgent:      userAgent,
		CreatedAt:      now,
		LastActivityAt: now,
	}

	// Persist first: caching before a failed insert would leave a session in
	// Valkey that GetSession accepts although creation was reported as failed.
	if s.pgPool != nil {
		// Marshalling a []string cannot fail.
		rolesJSON, _ := json.Marshal(roles)
		permsJSON, _ := json.Marshal(permissions)

		_, err := s.pgPool.Exec(ctx, `
			INSERT INTO user_sessions (
				id, user_id, token_hash, email, user_type, roles, permissions,
				tenant_id, ip_address, user_agent, expires_at, created_at, last_activity_at
			) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)
		`,
			session.SessionID,
			session.UserID,
			// token_hash = session_id for session-based auth.
			// NOTE(review): stored unhashed despite the column name.
			session.SessionID,
			session.Email,
			string(session.UserType),
			rolesJSON,
			permsJSON,
			session.TenantID,
			session.IPAddress,
			session.UserAgent,
			now.Add(s.sessionTTL),
			session.CreatedAt,
			session.LastActivityAt,
		)
		if err != nil {
			return nil, fmt.Errorf("failed to store session in PostgreSQL: %w", err)
		}
	}

	if s.valkeyEnabled {
		s.cacheSession(ctx, session)
	}
	return session, nil
}

// sessionSelectColumns is the column list, in order, that scanSessionRow expects.
const sessionSelectColumns = `id, user_id, email, user_type, roles, permissions,
	tenant_id, ip_address, user_agent, created_at, last_activity_at`

// sessionRowScanner is the subset of pgx.Row / pgx.Rows used by scanSessionRow.
type sessionRowScanner interface {
	Scan(dest ...any) error
}

// scanSessionRow reads one row selected with sessionSelectColumns into a
// Session. A NULL roles/permissions column yields a nil slice; malformed JSON
// is reported as an error.
func scanSessionRow(row sessionRowScanner) (*Session, error) {
	var session Session
	var rolesJSON, permsJSON []byte
	if err := row.Scan(
		&session.SessionID, &session.UserID, &session.Email, &session.UserType,
		&rolesJSON, &permsJSON,
		&session.TenantID, &session.IPAddress, &session.UserAgent,
		&session.CreatedAt, &session.LastActivityAt,
	); err != nil {
		return nil, err
	}
	if len(rolesJSON) > 0 {
		if err := json.Unmarshal(rolesJSON, &session.Roles); err != nil {
			return nil, fmt.Errorf("decode roles of session %s: %w", session.SessionID, err)
		}
	}
	if len(permsJSON) > 0 {
		if err := json.Unmarshal(permsJSON, &session.Permissions); err != nil {
			return nil, fmt.Errorf("decode permissions of session %s: %w", session.SessionID, err)
		}
	}
	return &session, nil
}

// GetSession returns the session with the given ID.
//
// It first tries Valkey, then falls back to PostgreSQL, where only sessions
// that are not revoked and not expired are returned. A PostgreSQL hit is
// re-cached in Valkey. Either way the session's last activity (and thus its
// expiry) is refreshed in the background.
func (s *SessionStore) GetSession(ctx context.Context, sessionID string) (*Session, error) {
	if s.valkeyEnabled {
		data, err := s.valkeyClient.Get(ctx, s.getValkeyKey(sessionID)).Bytes()
		if err == nil {
			var session Session
			if err := json.Unmarshal(data, &session); err == nil {
				go s.updateLastActivity(sessionID)
				return &session, nil
			}
		}
		// Cache miss, Valkey error, or undecodable entry: use PostgreSQL.
	}

	if s.pgPool == nil {
		return nil, errors.New("session not found")
	}

	row := s.pgPool.QueryRow(ctx, `
		SELECT `+sessionSelectColumns+`
		FROM user_sessions
		WHERE id = $1 AND revoked_at IS NULL AND expires_at > NOW()
	`, sessionID)
	session, err := scanSessionRow(row)
	if err != nil {
		// NOTE(review): database failures are also reported as "not found".
		return nil, errors.New("session not found or expired")
	}

	if s.valkeyEnabled {
		s.cacheSession(ctx, session)
	}
	// Refresh activity here too, so sliding expiry also works when Valkey is
	// disabled or the cache entry was evicted.
	go s.updateLastActivity(sessionID)
	return session, nil
}

// updateLastActivity refreshes the Valkey TTL and the PostgreSQL
// last_activity_at/expires_at of a session. It is best effort, runs detached
// from the request, and is bounded by activityUpdateTimeout.
func (s *SessionStore) updateLastActivity(sessionID string) {
	ctx, cancel := context.WithTimeout(context.Background(), activityUpdateTimeout)
	defer cancel()
	now := time.Now().UTC()

	if s.valkeyEnabled {
		_ = s.valkeyClient.Expire(ctx, s.getValkeyKey(sessionID), s.sessionTTL).Err()
	}

	if s.pgPool != nil {
		// Revoked rows are left untouched.
		_, _ = s.pgPool.Exec(ctx, `
			UPDATE user_sessions
			SET last_activity_at = $1, expires_at = $2
			WHERE id = $3 AND revoked_at IS NULL
		`, now, now.Add(s.sessionTTL), sessionID)
	}
}

// RevokeSession revokes a session (logout). It marks the row revoked in
// PostgreSQL and evicts the Valkey entry; only the PostgreSQL error is returned.
func (s *SessionStore) RevokeSession(ctx context.Context, sessionID string) error {
	// Revoke in PostgreSQL before evicting the cache. Evicting first lets a
	// concurrent GetSession read the still-active row and re-cache it,
	// resurrecting the session; this order shrinks that window.
	var revokeErr error
	if s.pgPool != nil {
		_, err := s.pgPool.Exec(ctx, `
			UPDATE user_sessions SET revoked_at = NOW() WHERE id = $1
		`, sessionID)
		if err != nil {
			revokeErr = fmt.Errorf("failed to revoke session: %w", err)
		}
	}

	if s.valkeyEnabled {
		_ = s.valkeyClient.Del(ctx, s.getValkeyKey(sessionID)).Err()
	}
	return revokeErr
}

// RevokeAllUserSessions revokes every not-yet-revoked session of userID and
// evicts them from Valkey. It returns the number of rows revoked in
// PostgreSQL, or 0 when no PostgreSQL pool is configured.
func (s *SessionStore) RevokeAllUserSessions(ctx context.Context, userID string) (int, error) {
	if s.pgPool == nil {
		return 0, nil
	}

	// A single UPDATE ... RETURNING reports exactly which sessions were
	// revoked. A separate SELECT beforehand would miss sessions created in
	// between: revoked in PostgreSQL but still usable through the cache.
	rows, err := s.pgPool.Query(ctx, `
		UPDATE user_sessions
		SET revoked_at = NOW()
		WHERE user_id = $1 AND revoked_at IS NULL
		RETURNING id
	`, userID)
	if err != nil {
		return 0, err
	}
	defer rows.Close()

	var keys []string
	for rows.Next() {
		var id string
		if err := rows.Scan(&id); err != nil {
			continue
		}
		keys = append(keys, s.getValkeyKey(id))
	}
	// pgx only fills in the command tag once the rows are closed.
	rows.Close()

	// Evict whatever was collected, even if reading stopped part-way.
	if s.valkeyEnabled && len(keys) > 0 {
		_ = s.valkeyClient.Del(ctx, keys...).Err()
	}

	if err := rows.Err(); err != nil {
		return 0, err
	}
	return int(rows.CommandTag().RowsAffected()), nil
}

// GetActiveSessions returns the non-revoked, non-expired sessions of userID,
// most recently active first. Rows that cannot be read are skipped. It
// returns nil, nil when no PostgreSQL pool is configured.
func (s *SessionStore) GetActiveSessions(ctx context.Context, userID string) ([]*Session, error) {
	if s.pgPool == nil {
		return nil, nil
	}

	rows, err := s.pgPool.Query(ctx, `
		SELECT `+sessionSelectColumns+`
		FROM user_sessions
		WHERE user_id = $1 AND revoked_at IS NULL AND expires_at > NOW()
		ORDER BY last_activity_at DESC
	`, userID)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var sessions []*Session
	for rows.Next() {
		session, err := scanSessionRow(rows)
		if err != nil {
			continue
		}
		sessions = append(sessions, session)
	}
	// Without this a failure mid-iteration would look like a short, successful list.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return sessions, nil
}

// CleanupExpiredSessions deletes sessions from PostgreSQL whose expiry is more
// than 7 days in the past, returning how many rows were removed.
func (s *SessionStore) CleanupExpiredSessions(ctx context.Context) (int, error) {
	if s.pgPool == nil {
		return 0, nil
	}

	result, err := s.pgPool.Exec(ctx, `
		DELETE FROM user_sessions
		WHERE expires_at < NOW() - INTERVAL '7 days'
	`)
	if err != nil {
		return 0, err
	}
	return int(result.RowsAffected()), nil
}

// Global session store instance.
var (
	globalStore   *SessionStore
	globalStoreMu sync.Mutex
	// NOTE(review): globalStoreOnce is not used in this file.
	globalStoreOnce sync.Once
)

// GetSessionStore returns the process-wide SessionStore, creating it on the
// first call. pgPool is only used by that first call; later calls ignore it.
func GetSessionStore(pgPool *pgxpool.Pool) *SessionStore {
	globalStoreMu.Lock()
	defer globalStoreMu.Unlock()

	if globalStore == nil {
		globalStore = NewSessionStore(pgPool)
	}
	return globalStore
}