package services import ( "context" "fmt" "github.com/breakpilot/billing-service/internal/models" "github.com/google/uuid" "github.com/stripe/stripe-go/v76" "github.com/stripe/stripe-go/v76/billingportal/session" checkoutsession "github.com/stripe/stripe-go/v76/checkout/session" "github.com/stripe/stripe-go/v76/customer" "github.com/stripe/stripe-go/v76/price" "github.com/stripe/stripe-go/v76/product" "github.com/stripe/stripe-go/v76/subscription" ) // StripeService handles Stripe API interactions type StripeService struct { secretKey string webhookSecret string successURL string cancelURL string trialPeriodDays int64 subService *SubscriptionService mockMode bool // If true, don't make real Stripe API calls } // NewStripeService creates a new StripeService func NewStripeService(secretKey, webhookSecret, successURL, cancelURL string, trialPeriodDays int, subService *SubscriptionService) *StripeService { // Initialize Stripe with the secret key (only if not empty) if secretKey != "" { stripe.Key = secretKey } return &StripeService{ secretKey: secretKey, webhookSecret: webhookSecret, successURL: successURL, cancelURL: cancelURL, trialPeriodDays: int64(trialPeriodDays), subService: subService, mockMode: false, } } // NewMockStripeService creates a mock StripeService for development func NewMockStripeService(successURL, cancelURL string, trialPeriodDays int, subService *SubscriptionService) *StripeService { return &StripeService{ secretKey: "", webhookSecret: "", successURL: successURL, cancelURL: cancelURL, trialPeriodDays: int64(trialPeriodDays), subService: subService, mockMode: true, } } // IsMockMode returns true if running in mock mode func (s *StripeService) IsMockMode() bool { return s.mockMode } // CreateCheckoutSession creates a Stripe Checkout session for trial start func (s *StripeService) CreateCheckoutSession(ctx context.Context, userID uuid.UUID, email string, planID models.PlanID) (string, string, error) { // Mock mode: return a fake URL for development if s.mockMode { mockSessionID := fmt.Sprintf("mock_cs_%s", uuid.New().String()[:8]) mockURL := fmt.Sprintf("%s?session_id=%s&mock=true&plan=%s", s.successURL, mockSessionID, planID) return mockURL, mockSessionID, nil } // Get plan details plan, err := s.subService.GetPlanByID(ctx, string(planID)) if err != nil || plan == nil { return "", "", fmt.Errorf("plan not found: %s", planID) } // Ensure we have a Stripe price ID if plan.StripePriceID == "" { // Create product and price in Stripe if not exists priceID, err := s.ensurePriceExists(ctx, plan) if err != nil { return "", "", fmt.Errorf("failed to create stripe price: %w", err) } plan.StripePriceID = priceID } // Create checkout session parameters params := &stripe.CheckoutSessionParams{ Mode: stripe.String(string(stripe.CheckoutSessionModeSubscription)), LineItems: []*stripe.CheckoutSessionLineItemParams{ { Price: stripe.String(plan.StripePriceID), Quantity: stripe.Int64(1), }, }, SuccessURL: stripe.String(s.successURL + "?session_id={CHECKOUT_SESSION_ID}"), CancelURL: stripe.String(s.cancelURL), SubscriptionData: &stripe.CheckoutSessionSubscriptionDataParams{ TrialPeriodDays: stripe.Int64(s.trialPeriodDays), Metadata: map[string]string{ "user_id": userID.String(), "plan_id": string(planID), }, }, PaymentMethodCollection: stripe.String(string(stripe.CheckoutSessionPaymentMethodCollectionAlways)), Metadata: map[string]string{ "user_id": userID.String(), "plan_id": string(planID), }, } // Set customer email if provided if email != "" { params.CustomerEmail = stripe.String(email) } // Create the session sess, err := checkoutsession.New(params) if err != nil { return "", "", fmt.Errorf("failed to create checkout session: %w", err) } return sess.URL, sess.ID, nil } // ensurePriceExists creates a Stripe product and price if they don't exist func (s *StripeService) ensurePriceExists(ctx context.Context, plan *models.BillingPlan) (string, error) { // Create product productParams := &stripe.ProductParams{ Name: stripe.String(plan.Name), Description: stripe.String(plan.Description), Metadata: map[string]string{ "plan_id": string(plan.ID), }, } prod, err := product.New(productParams) if err != nil { return "", fmt.Errorf("failed to create product: %w", err) } // Create price priceParams := &stripe.PriceParams{ Product: stripe.String(prod.ID), UnitAmount: stripe.Int64(int64(plan.PriceCents)), Currency: stripe.String(plan.Currency), Recurring: &stripe.PriceRecurringParams{ Interval: stripe.String(plan.Interval), }, Metadata: map[string]string{ "plan_id": string(plan.ID), }, } pr, err := price.New(priceParams) if err != nil { return "", fmt.Errorf("failed to create price: %w", err) } // Update plan with Stripe IDs if err := s.subService.UpdatePlanStripePriceID(ctx, string(plan.ID), pr.ID, prod.ID); err != nil { // Log but don't fail fmt.Printf("Warning: Failed to update plan with Stripe IDs: %v\n", err) } return pr.ID, nil } // GetOrCreateCustomer gets or creates a Stripe customer for a user func (s *StripeService) GetOrCreateCustomer(ctx context.Context, email, name string, userID uuid.UUID) (string, error) { // Search for existing customer params := &stripe.CustomerSearchParams{ SearchParams: stripe.SearchParams{ Query: fmt.Sprintf("email:'%s'", email), }, } iter := customer.Search(params) for iter.Next() { cust := iter.Customer() // Check if this customer belongs to our user if cust.Metadata["user_id"] == userID.String() { return cust.ID, nil } } // Create new customer customerParams := &stripe.CustomerParams{ Email: stripe.String(email), Name: stripe.String(name), Metadata: map[string]string{ "user_id": userID.String(), }, } cust, err := customer.New(customerParams) if err != nil { return "", fmt.Errorf("failed to create customer: %w", err) } return cust.ID, nil } // ChangePlan changes a subscription to a new plan func (s *StripeService) ChangePlan(ctx context.Context, stripeSubID string, newPlanID models.PlanID) error { // Mock mode: just return success if s.mockMode { return nil } // Get new plan details plan, err := s.subService.GetPlanByID(ctx, string(newPlanID)) if err != nil || plan == nil { return fmt.Errorf("plan not found: %s", newPlanID) } if plan.StripePriceID == "" { return fmt.Errorf("plan %s has no Stripe price ID", newPlanID) } // Get current subscription sub, err := subscription.Get(stripeSubID, nil) if err != nil { return fmt.Errorf("failed to get subscription: %w", err) } // Update subscription with new price params := &stripe.SubscriptionParams{ Items: []*stripe.SubscriptionItemsParams{ { ID: stripe.String(sub.Items.Data[0].ID), Price: stripe.String(plan.StripePriceID), }, }, ProrationBehavior: stripe.String(string(stripe.SubscriptionSchedulePhaseProrationBehaviorCreateProrations)), Metadata: map[string]string{ "plan_id": string(newPlanID), }, } _, err = subscription.Update(stripeSubID, params) if err != nil { return fmt.Errorf("failed to update subscription: %w", err) } return nil } // CancelSubscription cancels a subscription at period end func (s *StripeService) CancelSubscription(ctx context.Context, stripeSubID string) error { // Mock mode: just return success if s.mockMode { return nil } params := &stripe.SubscriptionParams{ CancelAtPeriodEnd: stripe.Bool(true), } _, err := subscription.Update(stripeSubID, params) if err != nil { return fmt.Errorf("failed to cancel subscription: %w", err) } return nil } // ReactivateSubscription removes the cancel_at_period_end flag func (s *StripeService) ReactivateSubscription(ctx context.Context, stripeSubID string) error { // Mock mode: just return success if s.mockMode { return nil } params := &stripe.SubscriptionParams{ CancelAtPeriodEnd: stripe.Bool(false), } _, err := subscription.Update(stripeSubID, params) if err != nil { return fmt.Errorf("failed to reactivate subscription: %w", err) } return nil } // CreateCustomerPortalSession creates a Stripe Customer Portal session func (s *StripeService) CreateCustomerPortalSession(ctx context.Context, customerID string) (string, error) { // Mock mode: return a mock URL if s.mockMode { return fmt.Sprintf("%s?mock_portal=true", s.successURL), nil } params := &stripe.BillingPortalSessionParams{ Customer: stripe.String(customerID), ReturnURL: stripe.String(s.successURL), } sess, err := session.New(params) if err != nil { return "", fmt.Errorf("failed to create portal session: %w", err) } return sess.URL, nil } // GetSubscription retrieves a subscription from Stripe func (s *StripeService) GetSubscription(ctx context.Context, stripeSubID string) (*stripe.Subscription, error) { sub, err := subscription.Get(stripeSubID, nil) if err != nil { return nil, fmt.Errorf("failed to get subscription: %w", err) } return sub, nil }