package services import ( "context" "encoding/json" "time" "github.com/breakpilot/billing-service/internal/database" "github.com/breakpilot/billing-service/internal/models" "github.com/google/uuid" ) // SubscriptionService handles subscription-related operations type SubscriptionService struct { db *database.DB } // NewSubscriptionService creates a new SubscriptionService func NewSubscriptionService(db *database.DB) *SubscriptionService { return &SubscriptionService{db: db} } // GetByUserID retrieves a subscription by user ID func (s *SubscriptionService) GetByUserID(ctx context.Context, userID uuid.UUID) (*models.Subscription, error) { query := ` SELECT id, user_id, stripe_customer_id, stripe_subscription_id, plan_id, status, trial_end, current_period_end, cancel_at_period_end, created_at, updated_at FROM subscriptions WHERE user_id = $1 ` var sub models.Subscription var stripeCustomerID, stripeSubID *string var trialEnd, periodEnd *time.Time err := s.db.Pool.QueryRow(ctx, query, userID).Scan( &sub.ID, &sub.UserID, &stripeCustomerID, &stripeSubID, &sub.PlanID, &sub.Status, &trialEnd, &periodEnd, &sub.CancelAtPeriodEnd, &sub.CreatedAt, &sub.UpdatedAt, ) if err != nil { if err.Error() == "no rows in result set" { return nil, nil } return nil, err } if stripeCustomerID != nil { sub.StripeCustomerID = *stripeCustomerID } if stripeSubID != nil { sub.StripeSubscriptionID = *stripeSubID } sub.TrialEnd = trialEnd sub.CurrentPeriodEnd = periodEnd return &sub, nil } // GetByStripeSubscriptionID retrieves a subscription by Stripe subscription ID func (s *SubscriptionService) GetByStripeSubscriptionID(ctx context.Context, stripeSubID string) (*models.Subscription, error) { query := ` SELECT id, user_id, stripe_customer_id, stripe_subscription_id, plan_id, status, trial_end, current_period_end, cancel_at_period_end, created_at, updated_at FROM subscriptions WHERE stripe_subscription_id = $1 ` var sub models.Subscription var stripeCustomerID, subID *string var trialEnd, periodEnd *time.Time err := s.db.Pool.QueryRow(ctx, query, stripeSubID).Scan( &sub.ID, &sub.UserID, &stripeCustomerID, &subID, &sub.PlanID, &sub.Status, &trialEnd, &periodEnd, &sub.CancelAtPeriodEnd, &sub.CreatedAt, &sub.UpdatedAt, ) if err != nil { if err.Error() == "no rows in result set" { return nil, nil } return nil, err } if stripeCustomerID != nil { sub.StripeCustomerID = *stripeCustomerID } if subID != nil { sub.StripeSubscriptionID = *subID } sub.TrialEnd = trialEnd sub.CurrentPeriodEnd = periodEnd return &sub, nil } // Create creates a new subscription func (s *SubscriptionService) Create(ctx context.Context, sub *models.Subscription) error { query := ` INSERT INTO subscriptions ( user_id, stripe_customer_id, stripe_subscription_id, plan_id, status, trial_end, current_period_end, cancel_at_period_end ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8) RETURNING id, created_at, updated_at ` return s.db.Pool.QueryRow(ctx, query, sub.UserID, sub.StripeCustomerID, sub.StripeSubscriptionID, sub.PlanID, sub.Status, sub.TrialEnd, sub.CurrentPeriodEnd, sub.CancelAtPeriodEnd, ).Scan(&sub.ID, &sub.CreatedAt, &sub.UpdatedAt) } // Update updates an existing subscription func (s *SubscriptionService) Update(ctx context.Context, sub *models.Subscription) error { query := ` UPDATE subscriptions SET stripe_customer_id = $2, stripe_subscription_id = $3, plan_id = $4, status = $5, trial_end = $6, current_period_end = $7, cancel_at_period_end = $8, updated_at = NOW() WHERE id = $1 ` _, err := s.db.Pool.Exec(ctx, query, sub.ID, sub.StripeCustomerID, sub.StripeSubscriptionID, sub.PlanID, sub.Status, sub.TrialEnd, sub.CurrentPeriodEnd, sub.CancelAtPeriodEnd, ) return err } // UpdateStatus updates the subscription status func (s *SubscriptionService) UpdateStatus(ctx context.Context, id uuid.UUID, status models.SubscriptionStatus) error { query := `UPDATE subscriptions SET status = $2, updated_at = NOW() WHERE id = $1` _, err := s.db.Pool.Exec(ctx, query, id, status) return err } // GetAvailablePlans retrieves all active billing plans func (s *SubscriptionService) GetAvailablePlans(ctx context.Context) ([]models.BillingPlan, error) { query := ` SELECT id, stripe_price_id, name, description, price_cents, currency, interval, features, is_active, sort_order FROM billing_plans WHERE is_active = true ORDER BY sort_order ASC ` rows, err := s.db.Pool.Query(ctx, query) if err != nil { return nil, err } defer rows.Close() var plans []models.BillingPlan for rows.Next() { var plan models.BillingPlan var stripePriceID *string var featuresJSON []byte err := rows.Scan( &plan.ID, &stripePriceID, &plan.Name, &plan.Description, &plan.PriceCents, &plan.Currency, &plan.Interval, &featuresJSON, &plan.IsActive, &plan.SortOrder, ) if err != nil { return nil, err } if stripePriceID != nil { plan.StripePriceID = *stripePriceID } // Parse features JSON if len(featuresJSON) > 0 { json.Unmarshal(featuresJSON, &plan.Features) } plans = append(plans, plan) } return plans, nil } // GetPlanByID retrieves a billing plan by ID func (s *SubscriptionService) GetPlanByID(ctx context.Context, planID string) (*models.BillingPlan, error) { query := ` SELECT id, stripe_price_id, name, description, price_cents, currency, interval, features, is_active, sort_order FROM billing_plans WHERE id = $1 ` var plan models.BillingPlan var stripePriceID *string var featuresJSON []byte err := s.db.Pool.QueryRow(ctx, query, planID).Scan( &plan.ID, &stripePriceID, &plan.Name, &plan.Description, &plan.PriceCents, &plan.Currency, &plan.Interval, &featuresJSON, &plan.IsActive, &plan.SortOrder, ) if err != nil { if err.Error() == "no rows in result set" { return nil, nil } return nil, err } if stripePriceID != nil { plan.StripePriceID = *stripePriceID } if len(featuresJSON) > 0 { json.Unmarshal(featuresJSON, &plan.Features) } return &plan, nil } // UpdatePlanStripePriceID updates the Stripe price ID for a plan func (s *SubscriptionService) UpdatePlanStripePriceID(ctx context.Context, planID, stripePriceID, stripeProductID string) error { query := ` UPDATE billing_plans SET stripe_price_id = $2, stripe_product_id = $3, updated_at = NOW() WHERE id = $1 ` _, err := s.db.Pool.Exec(ctx, query, planID, stripePriceID, stripeProductID) return err } // ============================================= // Webhook Event Tracking (Idempotency) // ============================================= // IsEventProcessed checks if a webhook event has already been processed func (s *SubscriptionService) IsEventProcessed(ctx context.Context, eventID string) (bool, error) { query := `SELECT processed FROM stripe_webhook_events WHERE stripe_event_id = $1` var processed bool err := s.db.Pool.QueryRow(ctx, query, eventID).Scan(&processed) if err != nil { if err.Error() == "no rows in result set" { return false, nil } return false, err } return processed, nil } // MarkEventProcessing marks an event as being processed func (s *SubscriptionService) MarkEventProcessing(ctx context.Context, eventID, eventType string) error { query := ` INSERT INTO stripe_webhook_events (stripe_event_id, event_type, processed) VALUES ($1, $2, false) ON CONFLICT (stripe_event_id) DO NOTHING ` _, err := s.db.Pool.Exec(ctx, query, eventID, eventType) return err } // MarkEventProcessed marks an event as successfully processed func (s *SubscriptionService) MarkEventProcessed(ctx context.Context, eventID string) error { query := ` UPDATE stripe_webhook_events SET processed = true, processed_at = NOW() WHERE stripe_event_id = $1 ` _, err := s.db.Pool.Exec(ctx, query, eventID) return err } // MarkEventFailed marks an event as failed with an error message func (s *SubscriptionService) MarkEventFailed(ctx context.Context, eventID, errorMsg string) error { query := ` UPDATE stripe_webhook_events SET processed = false, error_message = $2, processed_at = NOW() WHERE stripe_event_id = $1 ` _, err := s.db.Pool.Exec(ctx, query, eventID, errorMsg) return err } // ============================================= // Audit Logging // ============================================= // LogAuditEvent logs a billing audit event func (s *SubscriptionService) LogAuditEvent(ctx context.Context, userID *uuid.UUID, action, entityType, entityID string, oldValue, newValue, metadata interface{}, ipAddress, userAgent string) error { oldJSON, _ := json.Marshal(oldValue) newJSON, _ := json.Marshal(newValue) metaJSON, _ := json.Marshal(metadata) query := ` INSERT INTO billing_audit_log ( user_id, action, entity_type, entity_id, old_value, new_value, metadata, ip_address, user_agent ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) ` _, err := s.db.Pool.Exec(ctx, query, userID, action, entityType, entityID, oldJSON, newJSON, metaJSON, ipAddress, userAgent, ) return err }