package services

import (
	"context"
	"encoding/json"
	"fmt"
	"time"

	"github.com/breakpilot/billing-service/internal/database"
	"github.com/breakpilot/billing-service/internal/models"
	"github.com/google/uuid"
)

// noRowsMessage is the error text that QueryRow(...).Scan returns when no row
// matches, as compared against in getUserEntitlements.
// NOTE(review): this is the pgx ErrNoRows message; if database.DB wraps pgx,
// errors.Is(err, pgx.ErrNoRows) would be sturdier than a string comparison
// (it also survives wrapping) — confirm the driver and switch if possible.
const noRowsMessage = "no rows in result set"

// EntitlementService handles entitlement-related operations: reading,
// creating, updating and deleting rows in the user_entitlements table and
// maintaining their usage counters.
type EntitlementService struct {
	db         *database.DB
	subService *SubscriptionService
}

// NewEntitlementService creates a new EntitlementService backed by db. The
// subService is used to derive entitlements from a user's subscription when
// no entitlements row exists yet.
func NewEntitlementService(db *database.DB, subService *SubscriptionService) *EntitlementService {
	return &EntitlementService{
		db:         db,
		subService: subService,
	}
}

// GetEntitlements returns the client-facing entitlement summary for a user.
//
// It returns (nil, nil) when the user has no entitlements row and none could
// be derived from a subscription.
func (s *EntitlementService) GetEntitlements(ctx context.Context, userID uuid.UUID) (*models.EntitlementInfo, error) {
	entitlements, err := s.getUserEntitlements(ctx, userID)
	if err != nil || entitlements == nil {
		return nil, err
	}

	return &models.EntitlementInfo{
		Features:        entitlements.Features.FeatureFlags,
		MaxTeamMembers:  entitlements.Features.MaxTeamMembers,
		PrioritySupport: entitlements.Features.PrioritySupport,
		CustomBranding:  entitlements.Features.CustomBranding,
	}, nil
}

// GetEntitlementsByUserIDString returns the full entitlements record for the
// user whose ID is given as a string (for the internal API). The string must
// parse as a UUID; otherwise the parse error is returned.
func (s *EntitlementService) GetEntitlementsByUserIDString(ctx context.Context, userIDStr string) (*models.UserEntitlements, error) {
	userID, err := uuid.Parse(userIDStr)
	if err != nil {
		return nil, err
	}

	return s.getUserEntitlements(ctx, userID)
}

// getUserEntitlements loads the entitlements row for userID. If no row exists,
// it attempts to create one from the user's subscription.
//
// It returns an error if the stored features JSON cannot be decoded, rather
// than silently treating the user as having no features.
func (s *EntitlementService) getUserEntitlements(ctx context.Context, userID uuid.UUID) (*models.UserEntitlements, error) {
	query := `
		SELECT id, user_id, plan_id, ai_requests_limit, ai_requests_used,
		       documents_limit, documents_used, features, period_start, period_end,
		       created_at, updated_at
		FROM user_entitlements
		WHERE user_id = $1
	`

	var ent models.UserEntitlements
	var featuresJSON []byte
	// period_start/period_end are scanned but not copied onto ent.
	// NOTE(review): confirm whether models.UserEntitlements has period fields
	// that callers expect to be populated here.
	var periodStart, periodEnd *time.Time

	err := s.db.Pool.QueryRow(ctx, query, userID).Scan(
		&ent.ID,
		&ent.UserID,
		&ent.PlanID,
		&ent.AIRequestsLimit,
		&ent.AIRequestsUsed,
		&ent.DocumentsLimit,
		&ent.DocumentsUsed,
		&featuresJSON,
		&periodStart,
		&periodEnd,
		nil, // created_at: selected but not stored
		&ent.UpdatedAt,
	)
	if err != nil {
		if err.Error() == noRowsMessage {
			// No entitlements row yet: derive one from the user's subscription.
			return s.createEntitlementsFromSubscription(ctx, userID)
		}
		return nil, err
	}

	if len(featuresJSON) > 0 {
		if err := json.Unmarshal(featuresJSON, &ent.Features); err != nil {
			return nil, fmt.Errorf("decoding entitlement features for user %s: %w", userID, err)
		}
	}

	return &ent, nil
}

// createEntitlementsFromSubscription creates entitlements based on the user's
// subscription and its plan.
//
// It returns (nil, nil) when the user has no subscription or the plan cannot
// be found (and no error was reported by the lookups).
func (s *EntitlementService) createEntitlementsFromSubscription(ctx context.Context, userID uuid.UUID) (*models.UserEntitlements, error) {
	sub, err := s.subService.GetByUserID(ctx, userID)
	if err != nil || sub == nil {
		return nil, err
	}

	plan, err := s.subService.GetPlanByID(ctx, string(sub.PlanID))
	if err != nil || plan == nil {
		return nil, err
	}

	return s.CreateEntitlements(ctx, userID, sub.PlanID, plan.Features, sub.CurrentPeriodEnd)
}

// CreateEntitlements upserts the entitlements row for a user.
//
// On insert the usage counters start at 0. On conflict (an existing row for
// user_id), it overwrites the plan, limits, features and period bounds, but
// leaves the existing usage counters unchanged. The period start is the
// current time. The returned record carries the given features.
func (s *EntitlementService) CreateEntitlements(ctx context.Context, userID uuid.UUID, planID models.PlanID, features models.PlanFeatures, periodEnd *time.Time) (*models.UserEntitlements, error) {
	featuresJSON, err := json.Marshal(features)
	if err != nil {
		return nil, fmt.Errorf("encoding plan features: %w", err)
	}

	periodStart := time.Now()

	query := `
		INSERT INTO user_entitlements (
			user_id, plan_id, ai_requests_limit, ai_requests_used,
			documents_limit, documents_used, features, period_start, period_end
		) VALUES ($1, $2, $3, 0, $4, 0, $5, $6, $7)
		ON CONFLICT (user_id) DO UPDATE SET
			plan_id = EXCLUDED.plan_id,
			ai_requests_limit = EXCLUDED.ai_requests_limit,
			documents_limit = EXCLUDED.documents_limit,
			features = EXCLUDED.features,
			period_start = EXCLUDED.period_start,
			period_end = EXCLUDED.period_end,
			updated_at = NOW()
		RETURNING id, user_id, plan_id, ai_requests_limit, ai_requests_used,
		          documents_limit, documents_used, updated_at
	`

	var ent models.UserEntitlements
	err = s.db.Pool.QueryRow(ctx, query,
		userID, planID, features.AIRequestsLimit, features.DocumentsLimit,
		featuresJSON, periodStart, periodEnd,
	).Scan(
		&ent.ID,
		&ent.UserID,
		&ent.PlanID,
		&ent.AIRequestsLimit,
		&ent.AIRequestsUsed,
		&ent.DocumentsLimit,
		&ent.DocumentsUsed,
		&ent.UpdatedAt,
	)
	if err != nil {
		return nil, err
	}

	ent.Features = features
	return &ent, nil
}

// UpdateEntitlements updates the plan, limits and features of an existing
// entitlements row (e.g. on plan change). Usage counters and period bounds
// are left untouched. If the user has no row, no row is changed and no error
// is returned.
func (s *EntitlementService) UpdateEntitlements(ctx context.Context, userID uuid.UUID, planID models.PlanID, features models.PlanFeatures) error {
	featuresJSON, err := json.Marshal(features)
	if err != nil {
		return fmt.Errorf("encoding plan features: %w", err)
	}

	query := `
		UPDATE user_entitlements SET
			plan_id = $2,
			ai_requests_limit = $3,
			documents_limit = $4,
			features = $5,
			updated_at = NOW()
		WHERE user_id = $1
	`

	_, err = s.db.Pool.Exec(ctx, query,
		userID, planID, features.AIRequestsLimit, features.DocumentsLimit, featuresJSON,
	)
	return err
}

// ResetUsageCounters sets both usage counters back to 0 and records the new
// period bounds for the user's entitlements row.
func (s *EntitlementService) ResetUsageCounters(ctx context.Context, userID uuid.UUID, newPeriodStart, newPeriodEnd *time.Time) error {
	query := `
		UPDATE user_entitlements SET
			ai_requests_used = 0,
			documents_used = 0,
			period_start = $2,
			period_end = $3,
			updated_at = NOW()
		WHERE user_id = $1
	`

	_, err := s.db.Pool.Exec(ctx, query, userID, newPeriodStart, newPeriodEnd)
	return err
}

// CheckEntitlement reports whether the user has the named feature in their
// feature flags. It also returns the user's plan ID.
//
// It returns (false, "", err) for an unparseable user ID or lookup failure,
// and (false, "", nil) when no entitlements exist for the user.
func (s *EntitlementService) CheckEntitlement(ctx context.Context, userIDStr, feature string) (bool, models.PlanID, error) {
	userID, err := uuid.Parse(userIDStr)
	if err != nil {
		return false, "", err
	}

	ent, err := s.getUserEntitlements(ctx, userID)
	if err != nil || ent == nil {
		return false, "", err
	}

	for _, f := range ent.Features.FeatureFlags {
		if f == feature {
			return true, ent.PlanID, nil
		}
	}

	return false, ent.PlanID, nil
}

// IncrementUsage adds amount to the usage counter selected by usageType:
// "ai_request" updates ai_requests_used and "document_created" updates
// documents_used. Any other usageType is ignored and returns nil.
func (s *EntitlementService) IncrementUsage(ctx context.Context, userID uuid.UUID, usageType string, amount int) error {
	// The column name is spliced into the SQL text, so it must only ever come
	// from this fixed whitelist — never from caller input directly.
	var column string
	switch usageType {
	case "ai_request":
		column = "ai_requests_used"
	case "document_created":
		column = "documents_used"
	default:
		return nil
	}

	query := `
		UPDATE user_entitlements SET
			` + column + ` = ` + column + ` + $2,
			updated_at = NOW()
		WHERE user_id = $1
	`

	_, err := s.db.Pool.Exec(ctx, query, userID, amount)
	return err
}

// DeleteEntitlements removes the user's entitlements row. It is intended for
// use on subscription cancellation.
func (s *EntitlementService) DeleteEntitlements(ctx context.Context, userID uuid.UUID) error {
	query := `DELETE FROM user_entitlements WHERE user_id = $1`
	_, err := s.db.Pool.Exec(ctx, query, userID)
	return err
}