package rbac import ( "context" "encoding/json" "fmt" "time" "github.com/google/uuid" "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgxpool" ) // Store provides database operations for RBAC entities type Store struct { pool *pgxpool.Pool } // NewStore creates a new RBAC store func NewStore(pool *pgxpool.Pool) *Store { return &Store{pool: pool} } // ============================================================================ // Tenant Operations // ============================================================================ // CreateTenant creates a new tenant func (s *Store) CreateTenant(ctx context.Context, tenant *Tenant) error { tenant.ID = uuid.New() tenant.CreatedAt = time.Now().UTC() tenant.UpdatedAt = tenant.CreatedAt if tenant.Status == "" { tenant.Status = TenantStatusActive } if tenant.Settings == nil { tenant.Settings = make(map[string]any) } settingsJSON, err := json.Marshal(tenant.Settings) if err != nil { return fmt.Errorf("failed to marshal settings: %w", err) } _, err = s.pool.Exec(ctx, ` INSERT INTO compliance_tenants (id, name, slug, settings, max_users, llm_quota_monthly, status, created_at, updated_at) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) `, tenant.ID, tenant.Name, tenant.Slug, settingsJSON, tenant.MaxUsers, tenant.LLMQuotaMonthly, tenant.Status, tenant.CreatedAt, tenant.UpdatedAt) return err } // GetTenant retrieves a tenant by ID func (s *Store) GetTenant(ctx context.Context, id uuid.UUID) (*Tenant, error) { var tenant Tenant var settingsJSON []byte err := s.pool.QueryRow(ctx, ` SELECT id, name, slug, settings, max_users, llm_quota_monthly, status, created_at, updated_at FROM compliance_tenants WHERE id = $1 `, id).Scan( &tenant.ID, &tenant.Name, &tenant.Slug, &settingsJSON, &tenant.MaxUsers, &tenant.LLMQuotaMonthly, &tenant.Status, &tenant.CreatedAt, &tenant.UpdatedAt, ) if err != nil { return nil, err } if err := json.Unmarshal(settingsJSON, &tenant.Settings); err != nil { tenant.Settings = make(map[string]any) } return &tenant, nil } // GetTenantBySlug retrieves a tenant by slug func (s *Store) GetTenantBySlug(ctx context.Context, slug string) (*Tenant, error) { var tenant Tenant var settingsJSON []byte err := s.pool.QueryRow(ctx, ` SELECT id, name, slug, settings, max_users, llm_quota_monthly, status, created_at, updated_at FROM compliance_tenants WHERE slug = $1 `, slug).Scan( &tenant.ID, &tenant.Name, &tenant.Slug, &settingsJSON, &tenant.MaxUsers, &tenant.LLMQuotaMonthly, &tenant.Status, &tenant.CreatedAt, &tenant.UpdatedAt, ) if err != nil { return nil, err } if err := json.Unmarshal(settingsJSON, &tenant.Settings); err != nil { tenant.Settings = make(map[string]any) } return &tenant, nil } // ListTenants lists all tenants func (s *Store) ListTenants(ctx context.Context) ([]*Tenant, error) { rows, err := s.pool.Query(ctx, ` SELECT id, name, slug, settings, max_users, llm_quota_monthly, status, created_at, updated_at FROM compliance_tenants ORDER BY name `) if err != nil { return nil, err } defer rows.Close() var tenants []*Tenant for rows.Next() { var tenant Tenant var settingsJSON []byte err := rows.Scan( &tenant.ID, &tenant.Name, &tenant.Slug, &settingsJSON, &tenant.MaxUsers, &tenant.LLMQuotaMonthly, &tenant.Status, &tenant.CreatedAt, &tenant.UpdatedAt, ) if err != nil { continue } if err := json.Unmarshal(settingsJSON, &tenant.Settings); err != nil { tenant.Settings = make(map[string]any) } tenants = append(tenants, &tenant) } return tenants, nil } // UpdateTenant updates a tenant func (s *Store) UpdateTenant(ctx context.Context, tenant *Tenant) error { tenant.UpdatedAt = time.Now().UTC() settingsJSON, err := json.Marshal(tenant.Settings) if err != nil { return fmt.Errorf("failed to marshal settings: %w", err) } _, err = s.pool.Exec(ctx, ` UPDATE compliance_tenants SET name = $2, slug = $3, settings = $4, max_users = $5, llm_quota_monthly = $6, status = $7, updated_at = $8 WHERE id = $1 `, tenant.ID, tenant.Name, tenant.Slug, settingsJSON, tenant.MaxUsers, tenant.LLMQuotaMonthly, tenant.Status, tenant.UpdatedAt) return err } // ============================================================================ // Namespace Operations // ============================================================================ // CreateNamespace creates a new namespace func (s *Store) CreateNamespace(ctx context.Context, ns *Namespace) error { ns.ID = uuid.New() ns.CreatedAt = time.Now().UTC() ns.UpdatedAt = ns.CreatedAt if ns.IsolationLevel == "" { ns.IsolationLevel = IsolationStrict } if ns.DataClassification == "" { ns.DataClassification = ClassificationInternal } if ns.Metadata == nil { ns.Metadata = make(map[string]any) } metadataJSON, err := json.Marshal(ns.Metadata) if err != nil { return fmt.Errorf("failed to marshal metadata: %w", err) } _, err = s.pool.Exec(ctx, ` INSERT INTO compliance_namespaces (id, tenant_id, name, slug, parent_namespace_id, isolation_level, data_classification, metadata, created_at, updated_at) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) `, ns.ID, ns.TenantID, ns.Name, ns.Slug, ns.ParentNamespaceID, ns.IsolationLevel, ns.DataClassification, metadataJSON, ns.CreatedAt, ns.UpdatedAt) return err } // GetNamespace retrieves a namespace by ID func (s *Store) GetNamespace(ctx context.Context, id uuid.UUID) (*Namespace, error) { var ns Namespace var metadataJSON []byte err := s.pool.QueryRow(ctx, ` SELECT id, tenant_id, name, slug, parent_namespace_id, isolation_level, data_classification, metadata, created_at, updated_at FROM compliance_namespaces WHERE id = $1 `, id).Scan( &ns.ID, &ns.TenantID, &ns.Name, &ns.Slug, &ns.ParentNamespaceID, &ns.IsolationLevel, &ns.DataClassification, &metadataJSON, &ns.CreatedAt, &ns.UpdatedAt, ) if err != nil { return nil, err } if err := json.Unmarshal(metadataJSON, &ns.Metadata); err != nil { ns.Metadata = make(map[string]any) } return &ns, nil } // GetNamespaceBySlug retrieves a namespace by tenant and slug func (s *Store) GetNamespaceBySlug(ctx context.Context, tenantID uuid.UUID, slug string) (*Namespace, error) { var ns Namespace var metadataJSON []byte err := s.pool.QueryRow(ctx, ` SELECT id, tenant_id, name, slug, parent_namespace_id, isolation_level, data_classification, metadata, created_at, updated_at FROM compliance_namespaces WHERE tenant_id = $1 AND slug = $2 `, tenantID, slug).Scan( &ns.ID, &ns.TenantID, &ns.Name, &ns.Slug, &ns.ParentNamespaceID, &ns.IsolationLevel, &ns.DataClassification, &metadataJSON, &ns.CreatedAt, &ns.UpdatedAt, ) if err != nil { return nil, err } if err := json.Unmarshal(metadataJSON, &ns.Metadata); err != nil { ns.Metadata = make(map[string]any) } return &ns, nil } // ListNamespaces lists namespaces for a tenant func (s *Store) ListNamespaces(ctx context.Context, tenantID uuid.UUID) ([]*Namespace, error) { rows, err := s.pool.Query(ctx, ` SELECT id, tenant_id, name, slug, parent_namespace_id, isolation_level, data_classification, metadata, created_at, updated_at FROM compliance_namespaces WHERE tenant_id = $1 ORDER BY name `, tenantID) if err != nil { return nil, err } defer rows.Close() var namespaces []*Namespace for rows.Next() { var ns Namespace var metadataJSON []byte err := rows.Scan( &ns.ID, &ns.TenantID, &ns.Name, &ns.Slug, &ns.ParentNamespaceID, &ns.IsolationLevel, &ns.DataClassification, &metadataJSON, &ns.CreatedAt, &ns.UpdatedAt, ) if err != nil { continue } if err := json.Unmarshal(metadataJSON, &ns.Metadata); err != nil { ns.Metadata = make(map[string]any) } namespaces = append(namespaces, &ns) } return namespaces, nil } // ============================================================================ // Role Operations // ============================================================================ // CreateRole creates a new role func (s *Store) CreateRole(ctx context.Context, role *Role) error { role.ID = uuid.New() role.CreatedAt = time.Now().UTC() role.UpdatedAt = role.CreatedAt _, err := s.pool.Exec(ctx, ` INSERT INTO compliance_roles (id, tenant_id, name, description, permissions, is_system_role, hierarchy_level, created_at, updated_at) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) `, role.ID, role.TenantID, role.Name, role.Description, role.Permissions, role.IsSystemRole, role.HierarchyLevel, role.CreatedAt, role.UpdatedAt) return err } // GetRole retrieves a role by ID func (s *Store) GetRole(ctx context.Context, id uuid.UUID) (*Role, error) { var role Role err := s.pool.QueryRow(ctx, ` SELECT id, tenant_id, name, description, permissions, is_system_role, hierarchy_level, created_at, updated_at FROM compliance_roles WHERE id = $1 `, id).Scan( &role.ID, &role.TenantID, &role.Name, &role.Description, &role.Permissions, &role.IsSystemRole, &role.HierarchyLevel, &role.CreatedAt, &role.UpdatedAt, ) return &role, err } // GetRoleByName retrieves a role by tenant and name func (s *Store) GetRoleByName(ctx context.Context, tenantID *uuid.UUID, name string) (*Role, error) { var role Role query := ` SELECT id, tenant_id, name, description, permissions, is_system_role, hierarchy_level, created_at, updated_at FROM compliance_roles WHERE name = $1 AND (tenant_id = $2 OR (tenant_id IS NULL AND is_system_role = TRUE)) ` err := s.pool.QueryRow(ctx, query, name, tenantID).Scan( &role.ID, &role.TenantID, &role.Name, &role.Description, &role.Permissions, &role.IsSystemRole, &role.HierarchyLevel, &role.CreatedAt, &role.UpdatedAt, ) return &role, err } // ListRoles lists roles for a tenant (including system roles) func (s *Store) ListRoles(ctx context.Context, tenantID *uuid.UUID) ([]*Role, error) { rows, err := s.pool.Query(ctx, ` SELECT id, tenant_id, name, description, permissions, is_system_role, hierarchy_level, created_at, updated_at FROM compliance_roles WHERE tenant_id = $1 OR is_system_role = TRUE ORDER BY hierarchy_level, name `, tenantID) if err != nil { return nil, err } defer rows.Close() var roles []*Role for rows.Next() { var role Role err := rows.Scan( &role.ID, &role.TenantID, &role.Name, &role.Description, &role.Permissions, &role.IsSystemRole, &role.HierarchyLevel, &role.CreatedAt, &role.UpdatedAt, ) if err != nil { continue } roles = append(roles, &role) } return roles, nil } // ListSystemRoles lists all system roles func (s *Store) ListSystemRoles(ctx context.Context) ([]*Role, error) { rows, err := s.pool.Query(ctx, ` SELECT id, tenant_id, name, description, permissions, is_system_role, hierarchy_level, created_at, updated_at FROM compliance_roles WHERE is_system_role = TRUE ORDER BY hierarchy_level, name `) if err != nil { return nil, err } defer rows.Close() var roles []*Role for rows.Next() { var role Role err := rows.Scan( &role.ID, &role.TenantID, &role.Name, &role.Description, &role.Permissions, &role.IsSystemRole, &role.HierarchyLevel, &role.CreatedAt, &role.UpdatedAt, ) if err != nil { continue } roles = append(roles, &role) } return roles, nil } // ============================================================================ // User Role Operations // ============================================================================ // AssignRole assigns a role to a user func (s *Store) AssignRole(ctx context.Context, ur *UserRole) error { ur.ID = uuid.New() ur.CreatedAt = time.Now().UTC() _, err := s.pool.Exec(ctx, ` INSERT INTO compliance_user_roles (id, user_id, role_id, tenant_id, namespace_id, granted_by, expires_at, created_at) VALUES ($1, $2, $3, $4, $5, $6, $7, $8) ON CONFLICT (user_id, role_id, tenant_id, namespace_id) DO UPDATE SET granted_by = EXCLUDED.granted_by, expires_at = EXCLUDED.expires_at `, ur.ID, ur.UserID, ur.RoleID, ur.TenantID, ur.NamespaceID, ur.GrantedBy, ur.ExpiresAt, ur.CreatedAt) return err } // RevokeRole revokes a role from a user func (s *Store) RevokeRole(ctx context.Context, userID, roleID, tenantID uuid.UUID, namespaceID *uuid.UUID) error { _, err := s.pool.Exec(ctx, ` DELETE FROM compliance_user_roles WHERE user_id = $1 AND role_id = $2 AND tenant_id = $3 AND (namespace_id = $4 OR (namespace_id IS NULL AND $4 IS NULL)) `, userID, roleID, tenantID, namespaceID) return err } // GetUserRoles retrieves all roles for a user in a tenant func (s *Store) GetUserRoles(ctx context.Context, userID, tenantID uuid.UUID) ([]*UserRole, error) { rows, err := s.pool.Query(ctx, ` SELECT ur.id, ur.user_id, ur.role_id, ur.tenant_id, ur.namespace_id, ur.granted_by, ur.expires_at, ur.created_at, r.name as role_name, r.permissions as role_permissions, n.name as namespace_name FROM compliance_user_roles ur JOIN compliance_roles r ON ur.role_id = r.id LEFT JOIN compliance_namespaces n ON ur.namespace_id = n.id WHERE ur.user_id = $1 AND ur.tenant_id = $2 AND (ur.expires_at IS NULL OR ur.expires_at > NOW()) ORDER BY r.hierarchy_level, r.name `, userID, tenantID) if err != nil { return nil, err } defer rows.Close() var userRoles []*UserRole for rows.Next() { var ur UserRole var namespaceName *string err := rows.Scan( &ur.ID, &ur.UserID, &ur.RoleID, &ur.TenantID, &ur.NamespaceID, &ur.GrantedBy, &ur.ExpiresAt, &ur.CreatedAt, &ur.RoleName, &ur.RolePermissions, &namespaceName, ) if err != nil { continue } if namespaceName != nil { ur.NamespaceName = *namespaceName } userRoles = append(userRoles, &ur) } return userRoles, nil } // GetUserRolesForNamespace retrieves roles for a user in a specific namespace func (s *Store) GetUserRolesForNamespace(ctx context.Context, userID, tenantID uuid.UUID, namespaceID *uuid.UUID) ([]*UserRole, error) { rows, err := s.pool.Query(ctx, ` SELECT ur.id, ur.user_id, ur.role_id, ur.tenant_id, ur.namespace_id, ur.granted_by, ur.expires_at, ur.created_at, r.name as role_name, r.permissions as role_permissions FROM compliance_user_roles ur JOIN compliance_roles r ON ur.role_id = r.id WHERE ur.user_id = $1 AND ur.tenant_id = $2 AND (ur.namespace_id = $3 OR ur.namespace_id IS NULL) AND (ur.expires_at IS NULL OR ur.expires_at > NOW()) ORDER BY r.hierarchy_level, r.name `, userID, tenantID, namespaceID) if err != nil { return nil, err } defer rows.Close() var userRoles []*UserRole for rows.Next() { var ur UserRole err := rows.Scan( &ur.ID, &ur.UserID, &ur.RoleID, &ur.TenantID, &ur.NamespaceID, &ur.GrantedBy, &ur.ExpiresAt, &ur.CreatedAt, &ur.RoleName, &ur.RolePermissions, ) if err != nil { continue } userRoles = append(userRoles, &ur) } return userRoles, nil } // ============================================================================ // LLM Policy Operations // ============================================================================ // CreateLLMPolicy creates a new LLM policy func (s *Store) CreateLLMPolicy(ctx context.Context, policy *LLMPolicy) error { policy.ID = uuid.New() policy.CreatedAt = time.Now().UTC() policy.UpdatedAt = policy.CreatedAt _, err := s.pool.Exec(ctx, ` INSERT INTO compliance_llm_policies ( id, tenant_id, namespace_id, name, description, allowed_data_categories, blocked_data_categories, require_pii_redaction, pii_redaction_level, allowed_models, max_tokens_per_request, max_requests_per_day, max_requests_per_hour, is_active, priority, created_at, updated_at ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17) `, policy.ID, policy.TenantID, policy.NamespaceID, policy.Name, policy.Description, policy.AllowedDataCategories, policy.BlockedDataCategories, policy.RequirePIIRedaction, policy.PIIRedactionLevel, policy.AllowedModels, policy.MaxTokensPerRequest, policy.MaxRequestsPerDay, policy.MaxRequestsPerHour, policy.IsActive, policy.Priority, policy.CreatedAt, policy.UpdatedAt, ) return err } // GetLLMPolicy retrieves an LLM policy by ID func (s *Store) GetLLMPolicy(ctx context.Context, id uuid.UUID) (*LLMPolicy, error) { var policy LLMPolicy err := s.pool.QueryRow(ctx, ` SELECT id, tenant_id, namespace_id, name, description, allowed_data_categories, blocked_data_categories, require_pii_redaction, pii_redaction_level, allowed_models, max_tokens_per_request, max_requests_per_day, max_requests_per_hour, is_active, priority, created_at, updated_at FROM compliance_llm_policies WHERE id = $1 `, id).Scan( &policy.ID, &policy.TenantID, &policy.NamespaceID, &policy.Name, &policy.Description, &policy.AllowedDataCategories, &policy.BlockedDataCategories, &policy.RequirePIIRedaction, &policy.PIIRedactionLevel, &policy.AllowedModels, &policy.MaxTokensPerRequest, &policy.MaxRequestsPerDay, &policy.MaxRequestsPerHour, &policy.IsActive, &policy.Priority, &policy.CreatedAt, &policy.UpdatedAt, ) return &policy, err } // GetEffectiveLLMPolicy retrieves the effective LLM policy for a namespace func (s *Store) GetEffectiveLLMPolicy(ctx context.Context, tenantID uuid.UUID, namespaceID *uuid.UUID) (*LLMPolicy, error) { var policy LLMPolicy // Get most specific active policy (namespace-specific or tenant-wide) err := s.pool.QueryRow(ctx, ` SELECT id, tenant_id, namespace_id, name, description, allowed_data_categories, blocked_data_categories, require_pii_redaction, pii_redaction_level, allowed_models, max_tokens_per_request, max_requests_per_day, max_requests_per_hour, is_active, priority, created_at, updated_at FROM compliance_llm_policies WHERE tenant_id = $1 AND is_active = TRUE AND (namespace_id = $2 OR namespace_id IS NULL) ORDER BY CASE WHEN namespace_id = $2 THEN 0 ELSE 1 END, priority ASC LIMIT 1 `, tenantID, namespaceID).Scan( &policy.ID, &policy.TenantID, &policy.NamespaceID, &policy.Name, &policy.Description, &policy.AllowedDataCategories, &policy.BlockedDataCategories, &policy.RequirePIIRedaction, &policy.PIIRedactionLevel, &policy.AllowedModels, &policy.MaxTokensPerRequest, &policy.MaxRequestsPerDay, &policy.MaxRequestsPerHour, &policy.IsActive, &policy.Priority, &policy.CreatedAt, &policy.UpdatedAt, ) if err == pgx.ErrNoRows { return nil, nil // No policy = allow all } return &policy, err } // ListLLMPolicies lists LLM policies for a tenant func (s *Store) ListLLMPolicies(ctx context.Context, tenantID uuid.UUID) ([]*LLMPolicy, error) { rows, err := s.pool.Query(ctx, ` SELECT id, tenant_id, namespace_id, name, description, allowed_data_categories, blocked_data_categories, require_pii_redaction, pii_redaction_level, allowed_models, max_tokens_per_request, max_requests_per_day, max_requests_per_hour, is_active, priority, created_at, updated_at FROM compliance_llm_policies WHERE tenant_id = $1 ORDER BY priority, name `, tenantID) if err != nil { return nil, err } defer rows.Close() var policies []*LLMPolicy for rows.Next() { var policy LLMPolicy err := rows.Scan( &policy.ID, &policy.TenantID, &policy.NamespaceID, &policy.Name, &policy.Description, &policy.AllowedDataCategories, &policy.BlockedDataCategories, &policy.RequirePIIRedaction, &policy.PIIRedactionLevel, &policy.AllowedModels, &policy.MaxTokensPerRequest, &policy.MaxRequestsPerDay, &policy.MaxRequestsPerHour, &policy.IsActive, &policy.Priority, &policy.CreatedAt, &policy.UpdatedAt, ) if err != nil { continue } policies = append(policies, &policy) } return policies, nil } // UpdateLLMPolicy updates an LLM policy func (s *Store) UpdateLLMPolicy(ctx context.Context, policy *LLMPolicy) error { policy.UpdatedAt = time.Now().UTC() _, err := s.pool.Exec(ctx, ` UPDATE compliance_llm_policies SET name = $2, description = $3, allowed_data_categories = $4, blocked_data_categories = $5, require_pii_redaction = $6, pii_redaction_level = $7, allowed_models = $8, max_tokens_per_request = $9, max_requests_per_day = $10, max_requests_per_hour = $11, is_active = $12, priority = $13, updated_at = $14 WHERE id = $1 `, policy.ID, policy.Name, policy.Description, policy.AllowedDataCategories, policy.BlockedDataCategories, policy.RequirePIIRedaction, policy.PIIRedactionLevel, policy.AllowedModels, policy.MaxTokensPerRequest, policy.MaxRequestsPerDay, policy.MaxRequestsPerHour, policy.IsActive, policy.Priority, policy.UpdatedAt, ) return err } // DeleteLLMPolicy deletes an LLM policy func (s *Store) DeleteLLMPolicy(ctx context.Context, id uuid.UUID) error { _, err := s.pool.Exec(ctx, `DELETE FROM compliance_llm_policies WHERE id = $1`, id) return err }