Initial commit: breakpilot-core - Shared Infrastructure
Docker Compose with 24+ services: - PostgreSQL (PostGIS), Valkey, MinIO, Qdrant - Vault (PKI/TLS), Nginx (Reverse Proxy) - Backend Core API, Consent Service, Billing Service - RAG Service, Embedding Service - Gitea, Woodpecker CI/CD - Night Scheduler, Health Aggregator - Jitsi (Web/XMPP/JVB/Jicofo), Mailpit Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
232
billing-service/internal/services/entitlement_service.go
Normal file
232
billing-service/internal/services/entitlement_service.go
Normal file
@@ -0,0 +1,232 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"github.com/breakpilot/billing-service/internal/database"
|
||||
"github.com/breakpilot/billing-service/internal/models"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// EntitlementService handles entitlement-related operations
|
||||
type EntitlementService struct {
|
||||
db *database.DB
|
||||
subService *SubscriptionService
|
||||
}
|
||||
|
||||
// NewEntitlementService creates a new EntitlementService
|
||||
func NewEntitlementService(db *database.DB, subService *SubscriptionService) *EntitlementService {
|
||||
return &EntitlementService{
|
||||
db: db,
|
||||
subService: subService,
|
||||
}
|
||||
}
|
||||
|
||||
// GetEntitlements returns the entitlement info for a user
|
||||
func (s *EntitlementService) GetEntitlements(ctx context.Context, userID uuid.UUID) (*models.EntitlementInfo, error) {
|
||||
entitlements, err := s.getUserEntitlements(ctx, userID)
|
||||
if err != nil || entitlements == nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &models.EntitlementInfo{
|
||||
Features: entitlements.Features.FeatureFlags,
|
||||
MaxTeamMembers: entitlements.Features.MaxTeamMembers,
|
||||
PrioritySupport: entitlements.Features.PrioritySupport,
|
||||
CustomBranding: entitlements.Features.CustomBranding,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetEntitlementsByUserIDString returns entitlements by user ID string (for internal API)
|
||||
func (s *EntitlementService) GetEntitlementsByUserIDString(ctx context.Context, userIDStr string) (*models.UserEntitlements, error) {
|
||||
userID, err := uuid.Parse(userIDStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return s.getUserEntitlements(ctx, userID)
|
||||
}
|
||||
|
||||
// getUserEntitlements retrieves or creates entitlements for a user
|
||||
func (s *EntitlementService) getUserEntitlements(ctx context.Context, userID uuid.UUID) (*models.UserEntitlements, error) {
|
||||
query := `
|
||||
SELECT id, user_id, plan_id, ai_requests_limit, ai_requests_used,
|
||||
documents_limit, documents_used, features, period_start, period_end,
|
||||
created_at, updated_at
|
||||
FROM user_entitlements
|
||||
WHERE user_id = $1
|
||||
`
|
||||
|
||||
var ent models.UserEntitlements
|
||||
var featuresJSON []byte
|
||||
var periodStart, periodEnd *time.Time
|
||||
|
||||
err := s.db.Pool.QueryRow(ctx, query, userID).Scan(
|
||||
&ent.ID, &ent.UserID, &ent.PlanID, &ent.AIRequestsLimit, &ent.AIRequestsUsed,
|
||||
&ent.DocumentsLimit, &ent.DocumentsUsed, &featuresJSON, &periodStart, &periodEnd,
|
||||
nil, &ent.UpdatedAt,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
if err.Error() == "no rows in result set" {
|
||||
// Try to create entitlements based on subscription
|
||||
return s.createEntitlementsFromSubscription(ctx, userID)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(featuresJSON) > 0 {
|
||||
json.Unmarshal(featuresJSON, &ent.Features)
|
||||
}
|
||||
|
||||
return &ent, nil
|
||||
}
|
||||
|
||||
// createEntitlementsFromSubscription creates entitlements based on user's subscription
|
||||
func (s *EntitlementService) createEntitlementsFromSubscription(ctx context.Context, userID uuid.UUID) (*models.UserEntitlements, error) {
|
||||
// Get user's subscription
|
||||
sub, err := s.subService.GetByUserID(ctx, userID)
|
||||
if err != nil || sub == nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Get plan details
|
||||
plan, err := s.subService.GetPlanByID(ctx, string(sub.PlanID))
|
||||
if err != nil || plan == nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Create entitlements
|
||||
return s.CreateEntitlements(ctx, userID, sub.PlanID, plan.Features, sub.CurrentPeriodEnd)
|
||||
}
|
||||
|
||||
// CreateEntitlements creates entitlements for a user
|
||||
func (s *EntitlementService) CreateEntitlements(ctx context.Context, userID uuid.UUID, planID models.PlanID, features models.PlanFeatures, periodEnd *time.Time) (*models.UserEntitlements, error) {
|
||||
featuresJSON, _ := json.Marshal(features)
|
||||
|
||||
now := time.Now()
|
||||
periodStart := now
|
||||
|
||||
query := `
|
||||
INSERT INTO user_entitlements (
|
||||
user_id, plan_id, ai_requests_limit, ai_requests_used,
|
||||
documents_limit, documents_used, features, period_start, period_end
|
||||
) VALUES ($1, $2, $3, 0, $4, 0, $5, $6, $7)
|
||||
ON CONFLICT (user_id) DO UPDATE SET
|
||||
plan_id = EXCLUDED.plan_id,
|
||||
ai_requests_limit = EXCLUDED.ai_requests_limit,
|
||||
documents_limit = EXCLUDED.documents_limit,
|
||||
features = EXCLUDED.features,
|
||||
period_start = EXCLUDED.period_start,
|
||||
period_end = EXCLUDED.period_end,
|
||||
updated_at = NOW()
|
||||
RETURNING id, user_id, plan_id, ai_requests_limit, ai_requests_used,
|
||||
documents_limit, documents_used, updated_at
|
||||
`
|
||||
|
||||
var ent models.UserEntitlements
|
||||
err := s.db.Pool.QueryRow(ctx, query,
|
||||
userID, planID, features.AIRequestsLimit, features.DocumentsLimit,
|
||||
featuresJSON, periodStart, periodEnd,
|
||||
).Scan(
|
||||
&ent.ID, &ent.UserID, &ent.PlanID, &ent.AIRequestsLimit, &ent.AIRequestsUsed,
|
||||
&ent.DocumentsLimit, &ent.DocumentsUsed, &ent.UpdatedAt,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ent.Features = features
|
||||
return &ent, nil
|
||||
}
|
||||
|
||||
// UpdateEntitlements updates entitlements for a user (e.g., on plan change)
|
||||
func (s *EntitlementService) UpdateEntitlements(ctx context.Context, userID uuid.UUID, planID models.PlanID, features models.PlanFeatures) error {
|
||||
featuresJSON, _ := json.Marshal(features)
|
||||
|
||||
query := `
|
||||
UPDATE user_entitlements SET
|
||||
plan_id = $2,
|
||||
ai_requests_limit = $3,
|
||||
documents_limit = $4,
|
||||
features = $5,
|
||||
updated_at = NOW()
|
||||
WHERE user_id = $1
|
||||
`
|
||||
|
||||
_, err := s.db.Pool.Exec(ctx, query,
|
||||
userID, planID, features.AIRequestsLimit, features.DocumentsLimit, featuresJSON,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
// ResetUsageCounters resets usage counters for a new period
|
||||
func (s *EntitlementService) ResetUsageCounters(ctx context.Context, userID uuid.UUID, newPeriodStart, newPeriodEnd *time.Time) error {
|
||||
query := `
|
||||
UPDATE user_entitlements SET
|
||||
ai_requests_used = 0,
|
||||
documents_used = 0,
|
||||
period_start = $2,
|
||||
period_end = $3,
|
||||
updated_at = NOW()
|
||||
WHERE user_id = $1
|
||||
`
|
||||
|
||||
_, err := s.db.Pool.Exec(ctx, query, userID, newPeriodStart, newPeriodEnd)
|
||||
return err
|
||||
}
|
||||
|
||||
// CheckEntitlement checks if a user has a specific feature entitlement
|
||||
func (s *EntitlementService) CheckEntitlement(ctx context.Context, userIDStr, feature string) (bool, models.PlanID, error) {
|
||||
userID, err := uuid.Parse(userIDStr)
|
||||
if err != nil {
|
||||
return false, "", err
|
||||
}
|
||||
|
||||
ent, err := s.getUserEntitlements(ctx, userID)
|
||||
if err != nil || ent == nil {
|
||||
return false, "", err
|
||||
}
|
||||
|
||||
// Check if feature is in the feature flags
|
||||
for _, f := range ent.Features.FeatureFlags {
|
||||
if f == feature {
|
||||
return true, ent.PlanID, nil
|
||||
}
|
||||
}
|
||||
|
||||
return false, ent.PlanID, nil
|
||||
}
|
||||
|
||||
// IncrementUsage increments a usage counter
|
||||
func (s *EntitlementService) IncrementUsage(ctx context.Context, userID uuid.UUID, usageType string, amount int) error {
|
||||
var column string
|
||||
switch usageType {
|
||||
case "ai_request":
|
||||
column = "ai_requests_used"
|
||||
case "document_created":
|
||||
column = "documents_used"
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
|
||||
query := `
|
||||
UPDATE user_entitlements SET
|
||||
` + column + ` = ` + column + ` + $2,
|
||||
updated_at = NOW()
|
||||
WHERE user_id = $1
|
||||
`
|
||||
|
||||
_, err := s.db.Pool.Exec(ctx, query, userID, amount)
|
||||
return err
|
||||
}
|
||||
|
||||
// DeleteEntitlements removes entitlements for a user (on subscription cancellation)
|
||||
func (s *EntitlementService) DeleteEntitlements(ctx context.Context, userID uuid.UUID) error {
|
||||
query := `DELETE FROM user_entitlements WHERE user_id = $1`
|
||||
_, err := s.db.Pool.Exec(ctx, query, userID)
|
||||
return err
|
||||
}
|
||||
317
billing-service/internal/services/stripe_service.go
Normal file
317
billing-service/internal/services/stripe_service.go
Normal file
@@ -0,0 +1,317 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/breakpilot/billing-service/internal/models"
|
||||
"github.com/google/uuid"
|
||||
"github.com/stripe/stripe-go/v76"
|
||||
"github.com/stripe/stripe-go/v76/billingportal/session"
|
||||
checkoutsession "github.com/stripe/stripe-go/v76/checkout/session"
|
||||
"github.com/stripe/stripe-go/v76/customer"
|
||||
"github.com/stripe/stripe-go/v76/price"
|
||||
"github.com/stripe/stripe-go/v76/product"
|
||||
"github.com/stripe/stripe-go/v76/subscription"
|
||||
)
|
||||
|
||||
// StripeService handles Stripe API interactions
|
||||
type StripeService struct {
|
||||
secretKey string
|
||||
webhookSecret string
|
||||
successURL string
|
||||
cancelURL string
|
||||
trialPeriodDays int64
|
||||
subService *SubscriptionService
|
||||
mockMode bool // If true, don't make real Stripe API calls
|
||||
}
|
||||
|
||||
// NewStripeService creates a new StripeService
|
||||
func NewStripeService(secretKey, webhookSecret, successURL, cancelURL string, trialPeriodDays int, subService *SubscriptionService) *StripeService {
|
||||
// Initialize Stripe with the secret key (only if not empty)
|
||||
if secretKey != "" {
|
||||
stripe.Key = secretKey
|
||||
}
|
||||
|
||||
return &StripeService{
|
||||
secretKey: secretKey,
|
||||
webhookSecret: webhookSecret,
|
||||
successURL: successURL,
|
||||
cancelURL: cancelURL,
|
||||
trialPeriodDays: int64(trialPeriodDays),
|
||||
subService: subService,
|
||||
mockMode: false,
|
||||
}
|
||||
}
|
||||
|
||||
// NewMockStripeService creates a mock StripeService for development
|
||||
func NewMockStripeService(successURL, cancelURL string, trialPeriodDays int, subService *SubscriptionService) *StripeService {
|
||||
return &StripeService{
|
||||
secretKey: "",
|
||||
webhookSecret: "",
|
||||
successURL: successURL,
|
||||
cancelURL: cancelURL,
|
||||
trialPeriodDays: int64(trialPeriodDays),
|
||||
subService: subService,
|
||||
mockMode: true,
|
||||
}
|
||||
}
|
||||
|
||||
// IsMockMode returns true if running in mock mode
|
||||
func (s *StripeService) IsMockMode() bool {
|
||||
return s.mockMode
|
||||
}
|
||||
|
||||
// CreateCheckoutSession creates a Stripe Checkout session for trial start
|
||||
func (s *StripeService) CreateCheckoutSession(ctx context.Context, userID uuid.UUID, email string, planID models.PlanID) (string, string, error) {
|
||||
// Mock mode: return a fake URL for development
|
||||
if s.mockMode {
|
||||
mockSessionID := fmt.Sprintf("mock_cs_%s", uuid.New().String()[:8])
|
||||
mockURL := fmt.Sprintf("%s?session_id=%s&mock=true&plan=%s", s.successURL, mockSessionID, planID)
|
||||
return mockURL, mockSessionID, nil
|
||||
}
|
||||
|
||||
// Get plan details
|
||||
plan, err := s.subService.GetPlanByID(ctx, string(planID))
|
||||
if err != nil || plan == nil {
|
||||
return "", "", fmt.Errorf("plan not found: %s", planID)
|
||||
}
|
||||
|
||||
// Ensure we have a Stripe price ID
|
||||
if plan.StripePriceID == "" {
|
||||
// Create product and price in Stripe if not exists
|
||||
priceID, err := s.ensurePriceExists(ctx, plan)
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("failed to create stripe price: %w", err)
|
||||
}
|
||||
plan.StripePriceID = priceID
|
||||
}
|
||||
|
||||
// Create checkout session parameters
|
||||
params := &stripe.CheckoutSessionParams{
|
||||
Mode: stripe.String(string(stripe.CheckoutSessionModeSubscription)),
|
||||
LineItems: []*stripe.CheckoutSessionLineItemParams{
|
||||
{
|
||||
Price: stripe.String(plan.StripePriceID),
|
||||
Quantity: stripe.Int64(1),
|
||||
},
|
||||
},
|
||||
SuccessURL: stripe.String(s.successURL + "?session_id={CHECKOUT_SESSION_ID}"),
|
||||
CancelURL: stripe.String(s.cancelURL),
|
||||
SubscriptionData: &stripe.CheckoutSessionSubscriptionDataParams{
|
||||
TrialPeriodDays: stripe.Int64(s.trialPeriodDays),
|
||||
Metadata: map[string]string{
|
||||
"user_id": userID.String(),
|
||||
"plan_id": string(planID),
|
||||
},
|
||||
},
|
||||
PaymentMethodCollection: stripe.String(string(stripe.CheckoutSessionPaymentMethodCollectionAlways)),
|
||||
Metadata: map[string]string{
|
||||
"user_id": userID.String(),
|
||||
"plan_id": string(planID),
|
||||
},
|
||||
}
|
||||
|
||||
// Set customer email if provided
|
||||
if email != "" {
|
||||
params.CustomerEmail = stripe.String(email)
|
||||
}
|
||||
|
||||
// Create the session
|
||||
sess, err := checkoutsession.New(params)
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("failed to create checkout session: %w", err)
|
||||
}
|
||||
|
||||
return sess.URL, sess.ID, nil
|
||||
}
|
||||
|
||||
// ensurePriceExists creates a Stripe product and price if they don't exist
|
||||
func (s *StripeService) ensurePriceExists(ctx context.Context, plan *models.BillingPlan) (string, error) {
|
||||
// Create product
|
||||
productParams := &stripe.ProductParams{
|
||||
Name: stripe.String(plan.Name),
|
||||
Description: stripe.String(plan.Description),
|
||||
Metadata: map[string]string{
|
||||
"plan_id": string(plan.ID),
|
||||
},
|
||||
}
|
||||
|
||||
prod, err := product.New(productParams)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create product: %w", err)
|
||||
}
|
||||
|
||||
// Create price
|
||||
priceParams := &stripe.PriceParams{
|
||||
Product: stripe.String(prod.ID),
|
||||
UnitAmount: stripe.Int64(int64(plan.PriceCents)),
|
||||
Currency: stripe.String(plan.Currency),
|
||||
Recurring: &stripe.PriceRecurringParams{
|
||||
Interval: stripe.String(plan.Interval),
|
||||
},
|
||||
Metadata: map[string]string{
|
||||
"plan_id": string(plan.ID),
|
||||
},
|
||||
}
|
||||
|
||||
pr, err := price.New(priceParams)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create price: %w", err)
|
||||
}
|
||||
|
||||
// Update plan with Stripe IDs
|
||||
if err := s.subService.UpdatePlanStripePriceID(ctx, string(plan.ID), pr.ID, prod.ID); err != nil {
|
||||
// Log but don't fail
|
||||
fmt.Printf("Warning: Failed to update plan with Stripe IDs: %v\n", err)
|
||||
}
|
||||
|
||||
return pr.ID, nil
|
||||
}
|
||||
|
||||
// GetOrCreateCustomer gets or creates a Stripe customer for a user
|
||||
func (s *StripeService) GetOrCreateCustomer(ctx context.Context, email, name string, userID uuid.UUID) (string, error) {
|
||||
// Search for existing customer
|
||||
params := &stripe.CustomerSearchParams{
|
||||
SearchParams: stripe.SearchParams{
|
||||
Query: fmt.Sprintf("email:'%s'", email),
|
||||
},
|
||||
}
|
||||
|
||||
iter := customer.Search(params)
|
||||
for iter.Next() {
|
||||
cust := iter.Customer()
|
||||
// Check if this customer belongs to our user
|
||||
if cust.Metadata["user_id"] == userID.String() {
|
||||
return cust.ID, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Create new customer
|
||||
customerParams := &stripe.CustomerParams{
|
||||
Email: stripe.String(email),
|
||||
Name: stripe.String(name),
|
||||
Metadata: map[string]string{
|
||||
"user_id": userID.String(),
|
||||
},
|
||||
}
|
||||
|
||||
cust, err := customer.New(customerParams)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create customer: %w", err)
|
||||
}
|
||||
|
||||
return cust.ID, nil
|
||||
}
|
||||
|
||||
// ChangePlan changes a subscription to a new plan
|
||||
func (s *StripeService) ChangePlan(ctx context.Context, stripeSubID string, newPlanID models.PlanID) error {
|
||||
// Mock mode: just return success
|
||||
if s.mockMode {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get new plan details
|
||||
plan, err := s.subService.GetPlanByID(ctx, string(newPlanID))
|
||||
if err != nil || plan == nil {
|
||||
return fmt.Errorf("plan not found: %s", newPlanID)
|
||||
}
|
||||
|
||||
if plan.StripePriceID == "" {
|
||||
return fmt.Errorf("plan %s has no Stripe price ID", newPlanID)
|
||||
}
|
||||
|
||||
// Get current subscription
|
||||
sub, err := subscription.Get(stripeSubID, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get subscription: %w", err)
|
||||
}
|
||||
|
||||
// Update subscription with new price
|
||||
params := &stripe.SubscriptionParams{
|
||||
Items: []*stripe.SubscriptionItemsParams{
|
||||
{
|
||||
ID: stripe.String(sub.Items.Data[0].ID),
|
||||
Price: stripe.String(plan.StripePriceID),
|
||||
},
|
||||
},
|
||||
ProrationBehavior: stripe.String(string(stripe.SubscriptionSchedulePhaseProrationBehaviorCreateProrations)),
|
||||
Metadata: map[string]string{
|
||||
"plan_id": string(newPlanID),
|
||||
},
|
||||
}
|
||||
|
||||
_, err = subscription.Update(stripeSubID, params)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update subscription: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CancelSubscription cancels a subscription at period end
|
||||
func (s *StripeService) CancelSubscription(ctx context.Context, stripeSubID string) error {
|
||||
// Mock mode: just return success
|
||||
if s.mockMode {
|
||||
return nil
|
||||
}
|
||||
|
||||
params := &stripe.SubscriptionParams{
|
||||
CancelAtPeriodEnd: stripe.Bool(true),
|
||||
}
|
||||
|
||||
_, err := subscription.Update(stripeSubID, params)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to cancel subscription: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ReactivateSubscription removes the cancel_at_period_end flag
|
||||
func (s *StripeService) ReactivateSubscription(ctx context.Context, stripeSubID string) error {
|
||||
// Mock mode: just return success
|
||||
if s.mockMode {
|
||||
return nil
|
||||
}
|
||||
|
||||
params := &stripe.SubscriptionParams{
|
||||
CancelAtPeriodEnd: stripe.Bool(false),
|
||||
}
|
||||
|
||||
_, err := subscription.Update(stripeSubID, params)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to reactivate subscription: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CreateCustomerPortalSession creates a Stripe Customer Portal session
|
||||
func (s *StripeService) CreateCustomerPortalSession(ctx context.Context, customerID string) (string, error) {
|
||||
// Mock mode: return a mock URL
|
||||
if s.mockMode {
|
||||
return fmt.Sprintf("%s?mock_portal=true", s.successURL), nil
|
||||
}
|
||||
|
||||
params := &stripe.BillingPortalSessionParams{
|
||||
Customer: stripe.String(customerID),
|
||||
ReturnURL: stripe.String(s.successURL),
|
||||
}
|
||||
|
||||
sess, err := session.New(params)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create portal session: %w", err)
|
||||
}
|
||||
|
||||
return sess.URL, nil
|
||||
}
|
||||
|
||||
// GetSubscription retrieves a subscription from Stripe
|
||||
func (s *StripeService) GetSubscription(ctx context.Context, stripeSubID string) (*stripe.Subscription, error) {
|
||||
sub, err := subscription.Get(stripeSubID, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get subscription: %w", err)
|
||||
}
|
||||
return sub, nil
|
||||
}
|
||||
315
billing-service/internal/services/subscription_service.go
Normal file
315
billing-service/internal/services/subscription_service.go
Normal file
@@ -0,0 +1,315 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"github.com/breakpilot/billing-service/internal/database"
|
||||
"github.com/breakpilot/billing-service/internal/models"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// SubscriptionService handles subscription-related operations
|
||||
type SubscriptionService struct {
|
||||
db *database.DB
|
||||
}
|
||||
|
||||
// NewSubscriptionService creates a new SubscriptionService
|
||||
func NewSubscriptionService(db *database.DB) *SubscriptionService {
|
||||
return &SubscriptionService{db: db}
|
||||
}
|
||||
|
||||
// GetByUserID retrieves a subscription by user ID
|
||||
func (s *SubscriptionService) GetByUserID(ctx context.Context, userID uuid.UUID) (*models.Subscription, error) {
|
||||
query := `
|
||||
SELECT id, user_id, stripe_customer_id, stripe_subscription_id, plan_id,
|
||||
status, trial_end, current_period_end, cancel_at_period_end,
|
||||
created_at, updated_at
|
||||
FROM subscriptions
|
||||
WHERE user_id = $1
|
||||
`
|
||||
|
||||
var sub models.Subscription
|
||||
var stripeCustomerID, stripeSubID *string
|
||||
var trialEnd, periodEnd *time.Time
|
||||
|
||||
err := s.db.Pool.QueryRow(ctx, query, userID).Scan(
|
||||
&sub.ID, &sub.UserID, &stripeCustomerID, &stripeSubID, &sub.PlanID,
|
||||
&sub.Status, &trialEnd, &periodEnd, &sub.CancelAtPeriodEnd,
|
||||
&sub.CreatedAt, &sub.UpdatedAt,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
if err.Error() == "no rows in result set" {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if stripeCustomerID != nil {
|
||||
sub.StripeCustomerID = *stripeCustomerID
|
||||
}
|
||||
if stripeSubID != nil {
|
||||
sub.StripeSubscriptionID = *stripeSubID
|
||||
}
|
||||
sub.TrialEnd = trialEnd
|
||||
sub.CurrentPeriodEnd = periodEnd
|
||||
|
||||
return &sub, nil
|
||||
}
|
||||
|
||||
// GetByStripeSubscriptionID retrieves a subscription by Stripe subscription ID
|
||||
func (s *SubscriptionService) GetByStripeSubscriptionID(ctx context.Context, stripeSubID string) (*models.Subscription, error) {
|
||||
query := `
|
||||
SELECT id, user_id, stripe_customer_id, stripe_subscription_id, plan_id,
|
||||
status, trial_end, current_period_end, cancel_at_period_end,
|
||||
created_at, updated_at
|
||||
FROM subscriptions
|
||||
WHERE stripe_subscription_id = $1
|
||||
`
|
||||
|
||||
var sub models.Subscription
|
||||
var stripeCustomerID, subID *string
|
||||
var trialEnd, periodEnd *time.Time
|
||||
|
||||
err := s.db.Pool.QueryRow(ctx, query, stripeSubID).Scan(
|
||||
&sub.ID, &sub.UserID, &stripeCustomerID, &subID, &sub.PlanID,
|
||||
&sub.Status, &trialEnd, &periodEnd, &sub.CancelAtPeriodEnd,
|
||||
&sub.CreatedAt, &sub.UpdatedAt,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
if err.Error() == "no rows in result set" {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if stripeCustomerID != nil {
|
||||
sub.StripeCustomerID = *stripeCustomerID
|
||||
}
|
||||
if subID != nil {
|
||||
sub.StripeSubscriptionID = *subID
|
||||
}
|
||||
sub.TrialEnd = trialEnd
|
||||
sub.CurrentPeriodEnd = periodEnd
|
||||
|
||||
return &sub, nil
|
||||
}
|
||||
|
||||
// Create creates a new subscription
|
||||
func (s *SubscriptionService) Create(ctx context.Context, sub *models.Subscription) error {
|
||||
query := `
|
||||
INSERT INTO subscriptions (
|
||||
user_id, stripe_customer_id, stripe_subscription_id, plan_id,
|
||||
status, trial_end, current_period_end, cancel_at_period_end
|
||||
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
|
||||
RETURNING id, created_at, updated_at
|
||||
`
|
||||
|
||||
return s.db.Pool.QueryRow(ctx, query,
|
||||
sub.UserID, sub.StripeCustomerID, sub.StripeSubscriptionID, sub.PlanID,
|
||||
sub.Status, sub.TrialEnd, sub.CurrentPeriodEnd, sub.CancelAtPeriodEnd,
|
||||
).Scan(&sub.ID, &sub.CreatedAt, &sub.UpdatedAt)
|
||||
}
|
||||
|
||||
// Update updates an existing subscription
|
||||
func (s *SubscriptionService) Update(ctx context.Context, sub *models.Subscription) error {
|
||||
query := `
|
||||
UPDATE subscriptions SET
|
||||
stripe_customer_id = $2,
|
||||
stripe_subscription_id = $3,
|
||||
plan_id = $4,
|
||||
status = $5,
|
||||
trial_end = $6,
|
||||
current_period_end = $7,
|
||||
cancel_at_period_end = $8,
|
||||
updated_at = NOW()
|
||||
WHERE id = $1
|
||||
`
|
||||
|
||||
_, err := s.db.Pool.Exec(ctx, query,
|
||||
sub.ID, sub.StripeCustomerID, sub.StripeSubscriptionID, sub.PlanID,
|
||||
sub.Status, sub.TrialEnd, sub.CurrentPeriodEnd, sub.CancelAtPeriodEnd,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
// UpdateStatus updates the subscription status
|
||||
func (s *SubscriptionService) UpdateStatus(ctx context.Context, id uuid.UUID, status models.SubscriptionStatus) error {
|
||||
query := `UPDATE subscriptions SET status = $2, updated_at = NOW() WHERE id = $1`
|
||||
_, err := s.db.Pool.Exec(ctx, query, id, status)
|
||||
return err
|
||||
}
|
||||
|
||||
// GetAvailablePlans retrieves all active billing plans
|
||||
func (s *SubscriptionService) GetAvailablePlans(ctx context.Context) ([]models.BillingPlan, error) {
|
||||
query := `
|
||||
SELECT id, stripe_price_id, name, description, price_cents,
|
||||
currency, interval, features, is_active, sort_order
|
||||
FROM billing_plans
|
||||
WHERE is_active = true
|
||||
ORDER BY sort_order ASC
|
||||
`
|
||||
|
||||
rows, err := s.db.Pool.Query(ctx, query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var plans []models.BillingPlan
|
||||
for rows.Next() {
|
||||
var plan models.BillingPlan
|
||||
var stripePriceID *string
|
||||
var featuresJSON []byte
|
||||
|
||||
err := rows.Scan(
|
||||
&plan.ID, &stripePriceID, &plan.Name, &plan.Description,
|
||||
&plan.PriceCents, &plan.Currency, &plan.Interval,
|
||||
&featuresJSON, &plan.IsActive, &plan.SortOrder,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if stripePriceID != nil {
|
||||
plan.StripePriceID = *stripePriceID
|
||||
}
|
||||
|
||||
// Parse features JSON
|
||||
if len(featuresJSON) > 0 {
|
||||
json.Unmarshal(featuresJSON, &plan.Features)
|
||||
}
|
||||
|
||||
plans = append(plans, plan)
|
||||
}
|
||||
|
||||
return plans, nil
|
||||
}
|
||||
|
||||
// GetPlanByID retrieves a billing plan by ID
|
||||
func (s *SubscriptionService) GetPlanByID(ctx context.Context, planID string) (*models.BillingPlan, error) {
|
||||
query := `
|
||||
SELECT id, stripe_price_id, name, description, price_cents,
|
||||
currency, interval, features, is_active, sort_order
|
||||
FROM billing_plans
|
||||
WHERE id = $1
|
||||
`
|
||||
|
||||
var plan models.BillingPlan
|
||||
var stripePriceID *string
|
||||
var featuresJSON []byte
|
||||
|
||||
err := s.db.Pool.QueryRow(ctx, query, planID).Scan(
|
||||
&plan.ID, &stripePriceID, &plan.Name, &plan.Description,
|
||||
&plan.PriceCents, &plan.Currency, &plan.Interval,
|
||||
&featuresJSON, &plan.IsActive, &plan.SortOrder,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
if err.Error() == "no rows in result set" {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if stripePriceID != nil {
|
||||
plan.StripePriceID = *stripePriceID
|
||||
}
|
||||
|
||||
if len(featuresJSON) > 0 {
|
||||
json.Unmarshal(featuresJSON, &plan.Features)
|
||||
}
|
||||
|
||||
return &plan, nil
|
||||
}
|
||||
|
||||
// UpdatePlanStripePriceID updates the Stripe price ID for a plan
|
||||
func (s *SubscriptionService) UpdatePlanStripePriceID(ctx context.Context, planID, stripePriceID, stripeProductID string) error {
|
||||
query := `
|
||||
UPDATE billing_plans
|
||||
SET stripe_price_id = $2, stripe_product_id = $3, updated_at = NOW()
|
||||
WHERE id = $1
|
||||
`
|
||||
_, err := s.db.Pool.Exec(ctx, query, planID, stripePriceID, stripeProductID)
|
||||
return err
|
||||
}
|
||||
|
||||
// =============================================
|
||||
// Webhook Event Tracking (Idempotency)
|
||||
// =============================================
|
||||
|
||||
// IsEventProcessed checks if a webhook event has already been processed
|
||||
func (s *SubscriptionService) IsEventProcessed(ctx context.Context, eventID string) (bool, error) {
|
||||
query := `SELECT processed FROM stripe_webhook_events WHERE stripe_event_id = $1`
|
||||
|
||||
var processed bool
|
||||
err := s.db.Pool.QueryRow(ctx, query, eventID).Scan(&processed)
|
||||
if err != nil {
|
||||
if err.Error() == "no rows in result set" {
|
||||
return false, nil
|
||||
}
|
||||
return false, err
|
||||
}
|
||||
|
||||
return processed, nil
|
||||
}
|
||||
|
||||
// MarkEventProcessing marks an event as being processed
|
||||
func (s *SubscriptionService) MarkEventProcessing(ctx context.Context, eventID, eventType string) error {
|
||||
query := `
|
||||
INSERT INTO stripe_webhook_events (stripe_event_id, event_type, processed)
|
||||
VALUES ($1, $2, false)
|
||||
ON CONFLICT (stripe_event_id) DO NOTHING
|
||||
`
|
||||
_, err := s.db.Pool.Exec(ctx, query, eventID, eventType)
|
||||
return err
|
||||
}
|
||||
|
||||
// MarkEventProcessed marks an event as successfully processed
|
||||
func (s *SubscriptionService) MarkEventProcessed(ctx context.Context, eventID string) error {
|
||||
query := `
|
||||
UPDATE stripe_webhook_events
|
||||
SET processed = true, processed_at = NOW()
|
||||
WHERE stripe_event_id = $1
|
||||
`
|
||||
_, err := s.db.Pool.Exec(ctx, query, eventID)
|
||||
return err
|
||||
}
|
||||
|
||||
// MarkEventFailed marks an event as failed with an error message
|
||||
func (s *SubscriptionService) MarkEventFailed(ctx context.Context, eventID, errorMsg string) error {
|
||||
query := `
|
||||
UPDATE stripe_webhook_events
|
||||
SET processed = false, error_message = $2, processed_at = NOW()
|
||||
WHERE stripe_event_id = $1
|
||||
`
|
||||
_, err := s.db.Pool.Exec(ctx, query, eventID, errorMsg)
|
||||
return err
|
||||
}
|
||||
|
||||
// =============================================
|
||||
// Audit Logging
|
||||
// =============================================
|
||||
|
||||
// LogAuditEvent logs a billing audit event
|
||||
func (s *SubscriptionService) LogAuditEvent(ctx context.Context, userID *uuid.UUID, action, entityType, entityID string, oldValue, newValue, metadata interface{}, ipAddress, userAgent string) error {
|
||||
oldJSON, _ := json.Marshal(oldValue)
|
||||
newJSON, _ := json.Marshal(newValue)
|
||||
metaJSON, _ := json.Marshal(metadata)
|
||||
|
||||
query := `
|
||||
INSERT INTO billing_audit_log (
|
||||
user_id, action, entity_type, entity_id,
|
||||
old_value, new_value, metadata, ip_address, user_agent
|
||||
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
|
||||
`
|
||||
|
||||
_, err := s.db.Pool.Exec(ctx, query,
|
||||
userID, action, entityType, entityID,
|
||||
oldJSON, newJSON, metaJSON, ipAddress, userAgent,
|
||||
)
|
||||
return err
|
||||
}
|
||||
326
billing-service/internal/services/subscription_service_test.go
Normal file
326
billing-service/internal/services/subscription_service_test.go
Normal file
@@ -0,0 +1,326 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/breakpilot/billing-service/internal/models"
|
||||
)
|
||||
|
||||
func TestSubscriptionStatus_Transitions(t *testing.T) {
|
||||
// Test valid subscription status values
|
||||
validStatuses := []models.SubscriptionStatus{
|
||||
models.StatusTrialing,
|
||||
models.StatusActive,
|
||||
models.StatusPastDue,
|
||||
models.StatusCanceled,
|
||||
models.StatusExpired,
|
||||
}
|
||||
|
||||
for _, status := range validStatuses {
|
||||
if status == "" {
|
||||
t.Errorf("Status should not be empty")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestPlanID_ValidValues(t *testing.T) {
|
||||
validPlanIDs := []models.PlanID{
|
||||
models.PlanBasic,
|
||||
models.PlanStandard,
|
||||
models.PlanPremium,
|
||||
}
|
||||
|
||||
expected := []string{"basic", "standard", "premium"}
|
||||
|
||||
for i, planID := range validPlanIDs {
|
||||
if string(planID) != expected[i] {
|
||||
t.Errorf("PlanID should be '%s', got '%s'", expected[i], planID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestPlanFeatures_JSONSerialization(t *testing.T) {
|
||||
features := models.PlanFeatures{
|
||||
MonthlyTaskAllowance: 100,
|
||||
MaxTaskBalance: 500,
|
||||
FeatureFlags: []string{"basic_ai", "templates"},
|
||||
MaxTeamMembers: 3,
|
||||
PrioritySupport: false,
|
||||
CustomBranding: false,
|
||||
BatchProcessing: true,
|
||||
CustomTemplates: true,
|
||||
FairUseMode: false,
|
||||
}
|
||||
|
||||
// Test JSON serialization
|
||||
data, err := json.Marshal(features)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal PlanFeatures: %v", err)
|
||||
}
|
||||
|
||||
// Test JSON deserialization
|
||||
var decoded models.PlanFeatures
|
||||
err = json.Unmarshal(data, &decoded)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to unmarshal PlanFeatures: %v", err)
|
||||
}
|
||||
|
||||
// Verify fields
|
||||
if decoded.MonthlyTaskAllowance != features.MonthlyTaskAllowance {
|
||||
t.Errorf("MonthlyTaskAllowance mismatch: got %d, expected %d",
|
||||
decoded.MonthlyTaskAllowance, features.MonthlyTaskAllowance)
|
||||
}
|
||||
if decoded.MaxTaskBalance != features.MaxTaskBalance {
|
||||
t.Errorf("MaxTaskBalance mismatch: got %d, expected %d",
|
||||
decoded.MaxTaskBalance, features.MaxTaskBalance)
|
||||
}
|
||||
if decoded.BatchProcessing != features.BatchProcessing {
|
||||
t.Errorf("BatchProcessing mismatch: got %v, expected %v",
|
||||
decoded.BatchProcessing, features.BatchProcessing)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBillingPlan_DefaultPlansAreValid(t *testing.T) {
|
||||
plans := models.GetDefaultPlans()
|
||||
|
||||
if len(plans) != 3 {
|
||||
t.Fatalf("Expected 3 default plans, got %d", len(plans))
|
||||
}
|
||||
|
||||
// Verify all plans have required fields
|
||||
for _, plan := range plans {
|
||||
if plan.ID == "" {
|
||||
t.Errorf("Plan ID should not be empty")
|
||||
}
|
||||
if plan.Name == "" {
|
||||
t.Errorf("Plan '%s' should have a name", plan.ID)
|
||||
}
|
||||
if plan.Description == "" {
|
||||
t.Errorf("Plan '%s' should have a description", plan.ID)
|
||||
}
|
||||
if plan.PriceCents <= 0 {
|
||||
t.Errorf("Plan '%s' should have a positive price, got %d", plan.ID, plan.PriceCents)
|
||||
}
|
||||
if plan.Currency != "eur" {
|
||||
t.Errorf("Plan '%s' currency should be 'eur', got '%s'", plan.ID, plan.Currency)
|
||||
}
|
||||
if plan.Interval != "month" {
|
||||
t.Errorf("Plan '%s' interval should be 'month', got '%s'", plan.ID, plan.Interval)
|
||||
}
|
||||
if !plan.IsActive {
|
||||
t.Errorf("Plan '%s' should be active", plan.ID)
|
||||
}
|
||||
if plan.SortOrder <= 0 {
|
||||
t.Errorf("Plan '%s' should have a positive sort order, got %d", plan.ID, plan.SortOrder)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBillingPlan_TaskAllowanceProgression(t *testing.T) {
|
||||
plans := models.GetDefaultPlans()
|
||||
|
||||
// Basic should have lowest allowance
|
||||
basic := plans[0]
|
||||
standard := plans[1]
|
||||
premium := plans[2]
|
||||
|
||||
if basic.Features.MonthlyTaskAllowance >= standard.Features.MonthlyTaskAllowance {
|
||||
t.Error("Standard plan should have more tasks than Basic")
|
||||
}
|
||||
|
||||
if standard.Features.MonthlyTaskAllowance >= premium.Features.MonthlyTaskAllowance {
|
||||
t.Error("Premium plan should have more tasks than Standard")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBillingPlan_PriceProgression(t *testing.T) {
|
||||
plans := models.GetDefaultPlans()
|
||||
|
||||
// Prices should increase with each tier
|
||||
if plans[0].PriceCents >= plans[1].PriceCents {
|
||||
t.Error("Standard should cost more than Basic")
|
||||
}
|
||||
if plans[1].PriceCents >= plans[2].PriceCents {
|
||||
t.Error("Premium should cost more than Standard")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBillingPlan_FairUseModeOnlyForPremium(t *testing.T) {
|
||||
plans := models.GetDefaultPlans()
|
||||
|
||||
for _, plan := range plans {
|
||||
if plan.ID == models.PlanPremium {
|
||||
if !plan.Features.FairUseMode {
|
||||
t.Error("Premium plan should have FairUseMode enabled")
|
||||
}
|
||||
} else {
|
||||
if plan.Features.FairUseMode {
|
||||
t.Errorf("Plan '%s' should not have FairUseMode enabled", plan.ID)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBillingPlan_MaxTaskBalanceCalculation(t *testing.T) {
|
||||
plans := models.GetDefaultPlans()
|
||||
|
||||
for _, plan := range plans {
|
||||
expected := plan.Features.MonthlyTaskAllowance * models.CarryoverMonthsCap
|
||||
if plan.Features.MaxTaskBalance != expected {
|
||||
t.Errorf("Plan '%s' MaxTaskBalance should be %d (allowance * 5), got %d",
|
||||
plan.ID, expected, plan.Features.MaxTaskBalance)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuditLogJSON_Marshaling(t *testing.T) {
|
||||
// Test that audit log values can be properly serialized
|
||||
oldValue := map[string]interface{}{
|
||||
"plan_id": "basic",
|
||||
"status": "active",
|
||||
}
|
||||
|
||||
newValue := map[string]interface{}{
|
||||
"plan_id": "standard",
|
||||
"status": "active",
|
||||
}
|
||||
|
||||
metadata := map[string]interface{}{
|
||||
"reason": "upgrade",
|
||||
}
|
||||
|
||||
// Marshal all values
|
||||
oldJSON, err := json.Marshal(oldValue)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal oldValue: %v", err)
|
||||
}
|
||||
|
||||
newJSON, err := json.Marshal(newValue)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal newValue: %v", err)
|
||||
}
|
||||
|
||||
metaJSON, err := json.Marshal(metadata)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal metadata: %v", err)
|
||||
}
|
||||
|
||||
// Verify non-empty
|
||||
if len(oldJSON) == 0 || len(newJSON) == 0 || len(metaJSON) == 0 {
|
||||
t.Error("JSON outputs should not be empty")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSubscriptionTrialCalculation(t *testing.T) {
|
||||
// Test trial days calculation logic
|
||||
trialDays := 7
|
||||
|
||||
if trialDays <= 0 {
|
||||
t.Error("Trial days should be positive")
|
||||
}
|
||||
|
||||
if trialDays > 30 {
|
||||
t.Error("Trial days should not exceed 30")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSubscriptionInfo_TrialingStatus(t *testing.T) {
|
||||
info := models.SubscriptionInfo{
|
||||
PlanID: models.PlanBasic,
|
||||
PlanName: "Basic",
|
||||
Status: models.StatusTrialing,
|
||||
IsTrialing: true,
|
||||
TrialDaysLeft: 5,
|
||||
CancelAtPeriodEnd: false,
|
||||
PriceCents: 990,
|
||||
Currency: "eur",
|
||||
}
|
||||
|
||||
if !info.IsTrialing {
|
||||
t.Error("Should be trialing")
|
||||
}
|
||||
if info.Status != models.StatusTrialing {
|
||||
t.Errorf("Status should be 'trialing', got '%s'", info.Status)
|
||||
}
|
||||
if info.TrialDaysLeft <= 0 {
|
||||
t.Error("TrialDaysLeft should be positive during trial")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSubscriptionInfo_ActiveStatus(t *testing.T) {
|
||||
info := models.SubscriptionInfo{
|
||||
PlanID: models.PlanStandard,
|
||||
PlanName: "Standard",
|
||||
Status: models.StatusActive,
|
||||
IsTrialing: false,
|
||||
TrialDaysLeft: 0,
|
||||
CancelAtPeriodEnd: false,
|
||||
PriceCents: 1990,
|
||||
Currency: "eur",
|
||||
}
|
||||
|
||||
if info.IsTrialing {
|
||||
t.Error("Should not be trialing")
|
||||
}
|
||||
if info.Status != models.StatusActive {
|
||||
t.Errorf("Status should be 'active', got '%s'", info.Status)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSubscriptionInfo_CanceledStatus(t *testing.T) {
|
||||
info := models.SubscriptionInfo{
|
||||
PlanID: models.PlanStandard,
|
||||
PlanName: "Standard",
|
||||
Status: models.StatusActive,
|
||||
IsTrialing: false,
|
||||
CancelAtPeriodEnd: true, // Scheduled for cancellation
|
||||
PriceCents: 1990,
|
||||
Currency: "eur",
|
||||
}
|
||||
|
||||
if !info.CancelAtPeriodEnd {
|
||||
t.Error("CancelAtPeriodEnd should be true")
|
||||
}
|
||||
// Status remains active until period end
|
||||
if info.Status != models.StatusActive {
|
||||
t.Errorf("Status should still be 'active', got '%s'", info.Status)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWebhookEventTypes(t *testing.T) {
|
||||
// Test common Stripe webhook event types we handle
|
||||
eventTypes := []string{
|
||||
"checkout.session.completed",
|
||||
"customer.subscription.created",
|
||||
"customer.subscription.updated",
|
||||
"customer.subscription.deleted",
|
||||
"invoice.paid",
|
||||
"invoice.payment_failed",
|
||||
}
|
||||
|
||||
for _, eventType := range eventTypes {
|
||||
if eventType == "" {
|
||||
t.Error("Event type should not be empty")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestIdempotencyKey_Format(t *testing.T) {
|
||||
// Test that we can handle Stripe event IDs
|
||||
sampleEventIDs := []string{
|
||||
"evt_1234567890abcdef",
|
||||
"evt_test_abc123xyz789",
|
||||
"evt_live_real_event_id",
|
||||
}
|
||||
|
||||
for _, eventID := range sampleEventIDs {
|
||||
if len(eventID) < 10 {
|
||||
t.Errorf("Event ID '%s' seems too short", eventID)
|
||||
}
|
||||
// Stripe event IDs typically start with "evt_"
|
||||
if eventID[:4] != "evt_" {
|
||||
t.Errorf("Event ID '%s' should start with 'evt_'", eventID)
|
||||
}
|
||||
}
|
||||
}
|
||||
352
billing-service/internal/services/task_service.go
Normal file
352
billing-service/internal/services/task_service.go
Normal file
@@ -0,0 +1,352 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/breakpilot/billing-service/internal/database"
|
||||
"github.com/breakpilot/billing-service/internal/models"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrTaskLimitReached is returned when task balance is 0
|
||||
ErrTaskLimitReached = errors.New("TASK_LIMIT_REACHED")
|
||||
// ErrNoSubscription is returned when user has no subscription
|
||||
ErrNoSubscription = errors.New("NO_SUBSCRIPTION")
|
||||
)
|
||||
|
||||
// TaskService handles task consumption and balance management
|
||||
type TaskService struct {
|
||||
db *database.DB
|
||||
subService *SubscriptionService
|
||||
}
|
||||
|
||||
// NewTaskService creates a new TaskService
|
||||
func NewTaskService(db *database.DB, subService *SubscriptionService) *TaskService {
|
||||
return &TaskService{
|
||||
db: db,
|
||||
subService: subService,
|
||||
}
|
||||
}
|
||||
|
||||
// GetAccountUsage retrieves or creates account usage for a user
|
||||
func (s *TaskService) GetAccountUsage(ctx context.Context, userID uuid.UUID) (*models.AccountUsage, error) {
|
||||
query := `
|
||||
SELECT id, account_id, plan, monthly_task_allowance, carryover_months_cap,
|
||||
max_task_balance, task_balance, last_renewal_at, created_at, updated_at
|
||||
FROM account_usage
|
||||
WHERE account_id = $1
|
||||
`
|
||||
|
||||
var usage models.AccountUsage
|
||||
err := s.db.Pool.QueryRow(ctx, query, userID).Scan(
|
||||
&usage.ID, &usage.AccountID, &usage.PlanID, &usage.MonthlyTaskAllowance,
|
||||
&usage.CarryoverMonthsCap, &usage.MaxTaskBalance, &usage.TaskBalance,
|
||||
&usage.LastRenewalAt, &usage.CreatedAt, &usage.UpdatedAt,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
if err.Error() == "no rows in result set" {
|
||||
// Create new account usage based on subscription
|
||||
return s.createAccountUsage(ctx, userID)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Check if month renewal is needed
|
||||
if err := s.checkAndApplyMonthRenewal(ctx, &usage); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &usage, nil
|
||||
}
|
||||
|
||||
// createAccountUsage creates account usage based on user's subscription
|
||||
func (s *TaskService) createAccountUsage(ctx context.Context, userID uuid.UUID) (*models.AccountUsage, error) {
|
||||
// Get subscription to determine plan
|
||||
sub, err := s.subService.GetByUserID(ctx, userID)
|
||||
if err != nil || sub == nil {
|
||||
return nil, ErrNoSubscription
|
||||
}
|
||||
|
||||
// Get plan features
|
||||
plan, err := s.subService.GetPlanByID(ctx, string(sub.PlanID))
|
||||
if err != nil || plan == nil {
|
||||
return nil, fmt.Errorf("plan not found: %s", sub.PlanID)
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
usage := &models.AccountUsage{
|
||||
AccountID: userID,
|
||||
PlanID: sub.PlanID,
|
||||
MonthlyTaskAllowance: plan.Features.MonthlyTaskAllowance,
|
||||
CarryoverMonthsCap: models.CarryoverMonthsCap,
|
||||
MaxTaskBalance: plan.Features.MaxTaskBalance,
|
||||
TaskBalance: plan.Features.MonthlyTaskAllowance, // Start with one month's worth
|
||||
LastRenewalAt: now,
|
||||
}
|
||||
|
||||
query := `
|
||||
INSERT INTO account_usage (
|
||||
account_id, plan, monthly_task_allowance, carryover_months_cap,
|
||||
max_task_balance, task_balance, last_renewal_at
|
||||
) VALUES ($1, $2, $3, $4, $5, $6, $7)
|
||||
RETURNING id, created_at, updated_at
|
||||
`
|
||||
|
||||
err = s.db.Pool.QueryRow(ctx, query,
|
||||
usage.AccountID, usage.PlanID, usage.MonthlyTaskAllowance,
|
||||
usage.CarryoverMonthsCap, usage.MaxTaskBalance, usage.TaskBalance, usage.LastRenewalAt,
|
||||
).Scan(&usage.ID, &usage.CreatedAt, &usage.UpdatedAt)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
// checkAndApplyMonthRenewal checks if a month has passed and adds allowance
|
||||
// Implements the carryover logic: tasks accumulate up to max_task_balance
|
||||
func (s *TaskService) checkAndApplyMonthRenewal(ctx context.Context, usage *models.AccountUsage) error {
|
||||
now := time.Now()
|
||||
|
||||
// Check if at least one month has passed since last renewal
|
||||
monthsSinceRenewal := monthsBetween(usage.LastRenewalAt, now)
|
||||
if monthsSinceRenewal < 1 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Calculate new balance with carryover
|
||||
// Add monthly allowance for each month that passed
|
||||
newBalance := usage.TaskBalance
|
||||
for i := 0; i < monthsSinceRenewal; i++ {
|
||||
newBalance += usage.MonthlyTaskAllowance
|
||||
// Cap at max balance
|
||||
if newBalance > usage.MaxTaskBalance {
|
||||
newBalance = usage.MaxTaskBalance
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Calculate new renewal date (add the number of months)
|
||||
newRenewalAt := usage.LastRenewalAt.AddDate(0, monthsSinceRenewal, 0)
|
||||
|
||||
// Update in database
|
||||
query := `
|
||||
UPDATE account_usage
|
||||
SET task_balance = $2, last_renewal_at = $3, updated_at = NOW()
|
||||
WHERE id = $1
|
||||
`
|
||||
_, err := s.db.Pool.Exec(ctx, query, usage.ID, newBalance, newRenewalAt)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Update local struct
|
||||
usage.TaskBalance = newBalance
|
||||
usage.LastRenewalAt = newRenewalAt
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// monthsBetween calculates full months between two dates
|
||||
func monthsBetween(start, end time.Time) int {
|
||||
months := 0
|
||||
for start.AddDate(0, months+1, 0).Before(end) || start.AddDate(0, months+1, 0).Equal(end) {
|
||||
months++
|
||||
}
|
||||
return months
|
||||
}
|
||||
|
||||
// CheckTaskAllowed checks if a task can be consumed (balance > 0)
|
||||
func (s *TaskService) CheckTaskAllowed(ctx context.Context, userID uuid.UUID) (*models.CheckTaskAllowedResponse, error) {
|
||||
usage, err := s.GetAccountUsage(ctx, userID)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrNoSubscription) {
|
||||
return &models.CheckTaskAllowedResponse{
|
||||
Allowed: false,
|
||||
PlanID: "",
|
||||
Message: "Kein aktives Abonnement gefunden.",
|
||||
}, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Premium Fair Use mode - always allow
|
||||
plan, _ := s.subService.GetPlanByID(ctx, string(usage.PlanID))
|
||||
if plan != nil && plan.Features.FairUseMode {
|
||||
return &models.CheckTaskAllowedResponse{
|
||||
Allowed: true,
|
||||
TasksAvailable: usage.TaskBalance,
|
||||
MaxTasks: usage.MaxTaskBalance,
|
||||
PlanID: usage.PlanID,
|
||||
}, nil
|
||||
}
|
||||
|
||||
allowed := usage.TaskBalance > 0
|
||||
|
||||
response := &models.CheckTaskAllowedResponse{
|
||||
Allowed: allowed,
|
||||
TasksAvailable: usage.TaskBalance,
|
||||
MaxTasks: usage.MaxTaskBalance,
|
||||
PlanID: usage.PlanID,
|
||||
}
|
||||
|
||||
if !allowed {
|
||||
response.Message = "Dein Aufgaben-Kontingent ist aufgebraucht."
|
||||
}
|
||||
|
||||
return response, nil
|
||||
}
|
||||
|
||||
// ConsumeTask consumes one task from the balance
|
||||
// Returns error if balance is 0
|
||||
func (s *TaskService) ConsumeTask(ctx context.Context, userID uuid.UUID, taskType models.TaskType) (*models.ConsumeTaskResponse, error) {
|
||||
// First check if allowed
|
||||
checkResponse, err := s.CheckTaskAllowed(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !checkResponse.Allowed {
|
||||
return &models.ConsumeTaskResponse{
|
||||
Success: false,
|
||||
TasksRemaining: 0,
|
||||
Message: checkResponse.Message,
|
||||
}, ErrTaskLimitReached
|
||||
}
|
||||
|
||||
// Get current usage
|
||||
usage, err := s.GetAccountUsage(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Start transaction
|
||||
tx, err := s.db.Pool.Begin(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer tx.Rollback(ctx)
|
||||
|
||||
// Decrement balance (only if not Premium Fair Use)
|
||||
plan, _ := s.subService.GetPlanByID(ctx, string(usage.PlanID))
|
||||
newBalance := usage.TaskBalance
|
||||
if plan == nil || !plan.Features.FairUseMode {
|
||||
newBalance = usage.TaskBalance - 1
|
||||
_, err = tx.Exec(ctx, `
|
||||
UPDATE account_usage
|
||||
SET task_balance = $2, updated_at = NOW()
|
||||
WHERE account_id = $1
|
||||
`, userID, newBalance)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// Create task record
|
||||
taskID := uuid.New()
|
||||
_, err = tx.Exec(ctx, `
|
||||
INSERT INTO tasks (id, account_id, task_type, consumed, created_at)
|
||||
VALUES ($1, $2, $3, true, NOW())
|
||||
`, taskID, userID, taskType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Commit transaction
|
||||
if err = tx.Commit(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &models.ConsumeTaskResponse{
|
||||
Success: true,
|
||||
TaskID: taskID.String(),
|
||||
TasksRemaining: newBalance,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetTaskUsageInfo returns formatted task usage info for display
|
||||
func (s *TaskService) GetTaskUsageInfo(ctx context.Context, userID uuid.UUID) (*models.TaskUsageInfo, error) {
|
||||
usage, err := s.GetAccountUsage(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Check for Fair Use mode (Premium)
|
||||
plan, _ := s.subService.GetPlanByID(ctx, string(usage.PlanID))
|
||||
if plan != nil && plan.Features.FairUseMode {
|
||||
return &models.TaskUsageInfo{
|
||||
TasksAvailable: usage.TaskBalance,
|
||||
MaxTasks: usage.MaxTaskBalance,
|
||||
InfoText: "Unbegrenzte Aufgaben (Fair Use)",
|
||||
TooltipText: "Im Premium-Tarif gibt es keine praktische Begrenzung.",
|
||||
}, nil
|
||||
}
|
||||
|
||||
return &models.TaskUsageInfo{
|
||||
TasksAvailable: usage.TaskBalance,
|
||||
MaxTasks: usage.MaxTaskBalance,
|
||||
InfoText: fmt.Sprintf("Aufgaben verfuegbar: %d von max. %d", usage.TaskBalance, usage.MaxTaskBalance),
|
||||
TooltipText: "Aufgaben koennen sich bis zu 5 Monate ansammeln.",
|
||||
}, nil
|
||||
}
|
||||
|
||||
// UpdatePlanForUser updates the plan and adjusts allowances
|
||||
func (s *TaskService) UpdatePlanForUser(ctx context.Context, userID uuid.UUID, newPlanID models.PlanID) error {
|
||||
plan, err := s.subService.GetPlanByID(ctx, string(newPlanID))
|
||||
if err != nil || plan == nil {
|
||||
return fmt.Errorf("plan not found: %s", newPlanID)
|
||||
}
|
||||
|
||||
// Update account usage with new plan limits
|
||||
query := `
|
||||
UPDATE account_usage
|
||||
SET plan = $2,
|
||||
monthly_task_allowance = $3,
|
||||
max_task_balance = $4,
|
||||
updated_at = NOW()
|
||||
WHERE account_id = $1
|
||||
`
|
||||
|
||||
_, err = s.db.Pool.Exec(ctx, query,
|
||||
userID, newPlanID, plan.Features.MonthlyTaskAllowance, plan.Features.MaxTaskBalance)
|
||||
return err
|
||||
}
|
||||
|
||||
// GetTaskHistory returns task history for a user
|
||||
func (s *TaskService) GetTaskHistory(ctx context.Context, userID uuid.UUID, limit int) ([]models.Task, error) {
|
||||
if limit <= 0 {
|
||||
limit = 50
|
||||
}
|
||||
|
||||
query := `
|
||||
SELECT id, account_id, task_type, created_at, consumed
|
||||
FROM tasks
|
||||
WHERE account_id = $1
|
||||
ORDER BY created_at DESC
|
||||
LIMIT $2
|
||||
`
|
||||
|
||||
rows, err := s.db.Pool.Query(ctx, query, userID, limit)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var tasks []models.Task
|
||||
for rows.Next() {
|
||||
var task models.Task
|
||||
err := rows.Scan(&task.ID, &task.AccountID, &task.TaskType, &task.CreatedAt, &task.Consumed)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tasks = append(tasks, task)
|
||||
}
|
||||
|
||||
return tasks, nil
|
||||
}
|
||||
397
billing-service/internal/services/task_service_test.go
Normal file
397
billing-service/internal/services/task_service_test.go
Normal file
@@ -0,0 +1,397 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestMonthsBetween(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
start time.Time
|
||||
end time.Time
|
||||
expected int
|
||||
}{
|
||||
{
|
||||
name: "Same day",
|
||||
start: time.Date(2025, 1, 15, 0, 0, 0, 0, time.UTC),
|
||||
end: time.Date(2025, 1, 15, 0, 0, 0, 0, time.UTC),
|
||||
expected: 0,
|
||||
},
|
||||
{
|
||||
name: "Less than one month",
|
||||
start: time.Date(2025, 1, 15, 0, 0, 0, 0, time.UTC),
|
||||
end: time.Date(2025, 2, 10, 0, 0, 0, 0, time.UTC),
|
||||
expected: 0,
|
||||
},
|
||||
{
|
||||
name: "Exactly one month",
|
||||
start: time.Date(2025, 1, 15, 0, 0, 0, 0, time.UTC),
|
||||
end: time.Date(2025, 2, 15, 0, 0, 0, 0, time.UTC),
|
||||
expected: 1,
|
||||
},
|
||||
{
|
||||
name: "One month and one day",
|
||||
start: time.Date(2025, 1, 15, 0, 0, 0, 0, time.UTC),
|
||||
end: time.Date(2025, 2, 16, 0, 0, 0, 0, time.UTC),
|
||||
expected: 1,
|
||||
},
|
||||
{
|
||||
name: "Two months",
|
||||
start: time.Date(2025, 1, 15, 0, 0, 0, 0, time.UTC),
|
||||
end: time.Date(2025, 3, 15, 0, 0, 0, 0, time.UTC),
|
||||
expected: 2,
|
||||
},
|
||||
{
|
||||
name: "Five months exactly",
|
||||
start: time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC),
|
||||
end: time.Date(2025, 6, 1, 0, 0, 0, 0, time.UTC),
|
||||
expected: 5,
|
||||
},
|
||||
{
|
||||
name: "Year boundary",
|
||||
start: time.Date(2024, 11, 15, 0, 0, 0, 0, time.UTC),
|
||||
end: time.Date(2025, 2, 15, 0, 0, 0, 0, time.UTC),
|
||||
expected: 3,
|
||||
},
|
||||
{
|
||||
name: "Leap year February to March",
|
||||
start: time.Date(2024, 2, 29, 0, 0, 0, 0, time.UTC),
|
||||
end: time.Date(2024, 3, 29, 0, 0, 0, 0, time.UTC),
|
||||
expected: 1,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := monthsBetween(tt.start, tt.end)
|
||||
if result != tt.expected {
|
||||
t.Errorf("monthsBetween(%v, %v) = %d, expected %d",
|
||||
tt.start.Format("2006-01-02"), tt.end.Format("2006-01-02"),
|
||||
result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCarryoverLogic(t *testing.T) {
|
||||
// Test the carryover calculation logic
|
||||
tests := []struct {
|
||||
name string
|
||||
currentBalance int
|
||||
monthlyAllowance int
|
||||
maxBalance int
|
||||
monthsSinceRenewal int
|
||||
expectedNewBalance int
|
||||
}{
|
||||
{
|
||||
name: "Normal renewal - add allowance",
|
||||
currentBalance: 50,
|
||||
monthlyAllowance: 30,
|
||||
maxBalance: 150,
|
||||
monthsSinceRenewal: 1,
|
||||
expectedNewBalance: 80,
|
||||
},
|
||||
{
|
||||
name: "Two months missed",
|
||||
currentBalance: 50,
|
||||
monthlyAllowance: 30,
|
||||
maxBalance: 150,
|
||||
monthsSinceRenewal: 2,
|
||||
expectedNewBalance: 110,
|
||||
},
|
||||
{
|
||||
name: "Cap at max balance",
|
||||
currentBalance: 140,
|
||||
monthlyAllowance: 30,
|
||||
maxBalance: 150,
|
||||
monthsSinceRenewal: 1,
|
||||
expectedNewBalance: 150,
|
||||
},
|
||||
{
|
||||
name: "Already at max - no change",
|
||||
currentBalance: 150,
|
||||
monthlyAllowance: 30,
|
||||
maxBalance: 150,
|
||||
monthsSinceRenewal: 1,
|
||||
expectedNewBalance: 150,
|
||||
},
|
||||
{
|
||||
name: "Multiple months - cap applies",
|
||||
currentBalance: 100,
|
||||
monthlyAllowance: 30,
|
||||
maxBalance: 150,
|
||||
monthsSinceRenewal: 5,
|
||||
expectedNewBalance: 150,
|
||||
},
|
||||
{
|
||||
name: "Empty balance - add one month",
|
||||
currentBalance: 0,
|
||||
monthlyAllowance: 30,
|
||||
maxBalance: 150,
|
||||
monthsSinceRenewal: 1,
|
||||
expectedNewBalance: 30,
|
||||
},
|
||||
{
|
||||
name: "Empty balance - add five months",
|
||||
currentBalance: 0,
|
||||
monthlyAllowance: 30,
|
||||
maxBalance: 150,
|
||||
monthsSinceRenewal: 5,
|
||||
expectedNewBalance: 150,
|
||||
},
|
||||
{
|
||||
name: "Standard plan - normal case",
|
||||
currentBalance: 200,
|
||||
monthlyAllowance: 100,
|
||||
maxBalance: 500,
|
||||
monthsSinceRenewal: 1,
|
||||
expectedNewBalance: 300,
|
||||
},
|
||||
{
|
||||
name: "Premium plan - Fair Use",
|
||||
currentBalance: 1000,
|
||||
monthlyAllowance: 1000,
|
||||
maxBalance: 5000,
|
||||
monthsSinceRenewal: 1,
|
||||
expectedNewBalance: 2000,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Simulate the carryover logic
|
||||
newBalance := tt.currentBalance
|
||||
for i := 0; i < tt.monthsSinceRenewal; i++ {
|
||||
newBalance += tt.monthlyAllowance
|
||||
if newBalance > tt.maxBalance {
|
||||
newBalance = tt.maxBalance
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if newBalance != tt.expectedNewBalance {
|
||||
t.Errorf("Carryover for balance=%d, allowance=%d, max=%d, months=%d = %d, expected %d",
|
||||
tt.currentBalance, tt.monthlyAllowance, tt.maxBalance, tt.monthsSinceRenewal,
|
||||
newBalance, tt.expectedNewBalance)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTaskBalanceAfterConsumption(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
currentBalance int
|
||||
tasksToConsume int
|
||||
expectedBalance int
|
||||
shouldBeAllowed bool
|
||||
}{
|
||||
{
|
||||
name: "Normal consumption",
|
||||
currentBalance: 50,
|
||||
tasksToConsume: 1,
|
||||
expectedBalance: 49,
|
||||
shouldBeAllowed: true,
|
||||
},
|
||||
{
|
||||
name: "Last task",
|
||||
currentBalance: 1,
|
||||
tasksToConsume: 1,
|
||||
expectedBalance: 0,
|
||||
shouldBeAllowed: true,
|
||||
},
|
||||
{
|
||||
name: "Empty balance - not allowed",
|
||||
currentBalance: 0,
|
||||
tasksToConsume: 1,
|
||||
expectedBalance: 0,
|
||||
shouldBeAllowed: false,
|
||||
},
|
||||
{
|
||||
name: "Multiple tasks",
|
||||
currentBalance: 50,
|
||||
tasksToConsume: 5,
|
||||
expectedBalance: 45,
|
||||
shouldBeAllowed: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Test if allowed
|
||||
allowed := tt.currentBalance > 0
|
||||
if allowed != tt.shouldBeAllowed {
|
||||
t.Errorf("Task allowed with balance=%d: got %v, expected %v",
|
||||
tt.currentBalance, allowed, tt.shouldBeAllowed)
|
||||
}
|
||||
|
||||
// Test balance calculation
|
||||
if allowed {
|
||||
newBalance := tt.currentBalance - tt.tasksToConsume
|
||||
if newBalance != tt.expectedBalance {
|
||||
t.Errorf("Balance after consuming %d tasks from %d: got %d, expected %d",
|
||||
tt.tasksToConsume, tt.currentBalance, newBalance, tt.expectedBalance)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTaskServiceErrors(t *testing.T) {
|
||||
// Test error constants
|
||||
if ErrTaskLimitReached == nil {
|
||||
t.Error("ErrTaskLimitReached should not be nil")
|
||||
}
|
||||
if ErrTaskLimitReached.Error() != "TASK_LIMIT_REACHED" {
|
||||
t.Errorf("ErrTaskLimitReached should be 'TASK_LIMIT_REACHED', got '%s'", ErrTaskLimitReached.Error())
|
||||
}
|
||||
|
||||
if ErrNoSubscription == nil {
|
||||
t.Error("ErrNoSubscription should not be nil")
|
||||
}
|
||||
if ErrNoSubscription.Error() != "NO_SUBSCRIPTION" {
|
||||
t.Errorf("ErrNoSubscription should be 'NO_SUBSCRIPTION', got '%s'", ErrNoSubscription.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func TestRenewalDateCalculation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
lastRenewal time.Time
|
||||
monthsToAdd int
|
||||
expectedRenewal time.Time
|
||||
}{
|
||||
{
|
||||
name: "Add one month",
|
||||
lastRenewal: time.Date(2025, 1, 15, 0, 0, 0, 0, time.UTC),
|
||||
monthsToAdd: 1,
|
||||
expectedRenewal: time.Date(2025, 2, 15, 0, 0, 0, 0, time.UTC),
|
||||
},
|
||||
{
|
||||
name: "Add three months",
|
||||
lastRenewal: time.Date(2025, 1, 15, 0, 0, 0, 0, time.UTC),
|
||||
monthsToAdd: 3,
|
||||
expectedRenewal: time.Date(2025, 4, 15, 0, 0, 0, 0, time.UTC),
|
||||
},
|
||||
{
|
||||
name: "Year boundary",
|
||||
lastRenewal: time.Date(2024, 11, 15, 0, 0, 0, 0, time.UTC),
|
||||
monthsToAdd: 3,
|
||||
expectedRenewal: time.Date(2025, 2, 15, 0, 0, 0, 0, time.UTC),
|
||||
},
|
||||
{
|
||||
name: "End of month adjustment",
|
||||
lastRenewal: time.Date(2025, 1, 31, 0, 0, 0, 0, time.UTC),
|
||||
monthsToAdd: 1,
|
||||
// Go's AddDate handles this - February doesn't have 31 days
|
||||
expectedRenewal: time.Date(2025, 3, 3, 0, 0, 0, 0, time.UTC), // Feb 31 -> March 3
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := tt.lastRenewal.AddDate(0, tt.monthsToAdd, 0)
|
||||
if !result.Equal(tt.expectedRenewal) {
|
||||
t.Errorf("AddDate(%v, %d months) = %v, expected %v",
|
||||
tt.lastRenewal.Format("2006-01-02"), tt.monthsToAdd,
|
||||
result.Format("2006-01-02"), tt.expectedRenewal.Format("2006-01-02"))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFairUseModeLogic(t *testing.T) {
|
||||
// Test that Fair Use mode always allows tasks regardless of balance
|
||||
tests := []struct {
|
||||
name string
|
||||
fairUseMode bool
|
||||
balance int
|
||||
shouldAllow bool
|
||||
}{
|
||||
{
|
||||
name: "Fair Use - zero balance still allowed",
|
||||
fairUseMode: true,
|
||||
balance: 0,
|
||||
shouldAllow: true,
|
||||
},
|
||||
{
|
||||
name: "Fair Use - normal balance allowed",
|
||||
fairUseMode: true,
|
||||
balance: 1000,
|
||||
shouldAllow: true,
|
||||
},
|
||||
{
|
||||
name: "Not Fair Use - zero balance not allowed",
|
||||
fairUseMode: false,
|
||||
balance: 0,
|
||||
shouldAllow: false,
|
||||
},
|
||||
{
|
||||
name: "Not Fair Use - positive balance allowed",
|
||||
fairUseMode: false,
|
||||
balance: 50,
|
||||
shouldAllow: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Simulate the check logic
|
||||
var allowed bool
|
||||
if tt.fairUseMode {
|
||||
allowed = true // Fair Use always allows
|
||||
} else {
|
||||
allowed = tt.balance > 0
|
||||
}
|
||||
|
||||
if allowed != tt.shouldAllow {
|
||||
t.Errorf("FairUseMode=%v, balance=%d: allowed=%v, expected=%v",
|
||||
tt.fairUseMode, tt.balance, allowed, tt.shouldAllow)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBalanceDecrementLogic(t *testing.T) {
|
||||
// Test that Fair Use mode doesn't decrement balance
|
||||
tests := []struct {
|
||||
name string
|
||||
fairUseMode bool
|
||||
initialBalance int
|
||||
expectedAfter int
|
||||
}{
|
||||
{
|
||||
name: "Normal plan - decrement",
|
||||
fairUseMode: false,
|
||||
initialBalance: 50,
|
||||
expectedAfter: 49,
|
||||
},
|
||||
{
|
||||
name: "Fair Use - no decrement",
|
||||
fairUseMode: true,
|
||||
initialBalance: 1000,
|
||||
expectedAfter: 1000,
|
||||
},
|
||||
{
|
||||
name: "Normal plan - last task",
|
||||
fairUseMode: false,
|
||||
initialBalance: 1,
|
||||
expectedAfter: 0,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
newBalance := tt.initialBalance
|
||||
if !tt.fairUseMode {
|
||||
newBalance = tt.initialBalance - 1
|
||||
}
|
||||
|
||||
if newBalance != tt.expectedAfter {
|
||||
t.Errorf("FairUseMode=%v, initial=%d: got %d, expected %d",
|
||||
tt.fairUseMode, tt.initialBalance, newBalance, tt.expectedAfter)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
194
billing-service/internal/services/usage_service.go
Normal file
194
billing-service/internal/services/usage_service.go
Normal file
@@ -0,0 +1,194 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/breakpilot/billing-service/internal/database"
|
||||
"github.com/breakpilot/billing-service/internal/models"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// UsageService handles usage tracking operations
|
||||
type UsageService struct {
|
||||
db *database.DB
|
||||
entitlementService *EntitlementService
|
||||
}
|
||||
|
||||
// NewUsageService creates a new UsageService
|
||||
func NewUsageService(db *database.DB, entitlementService *EntitlementService) *UsageService {
|
||||
return &UsageService{
|
||||
db: db,
|
||||
entitlementService: entitlementService,
|
||||
}
|
||||
}
|
||||
|
||||
// TrackUsage tracks usage for a user
|
||||
func (s *UsageService) TrackUsage(ctx context.Context, userIDStr, usageType string, quantity int) error {
|
||||
userID, err := uuid.Parse(userIDStr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid user ID: %w", err)
|
||||
}
|
||||
|
||||
// Get current period start (beginning of current month)
|
||||
now := time.Now()
|
||||
periodStart := time.Date(now.Year(), now.Month(), 1, 0, 0, 0, 0, time.UTC)
|
||||
|
||||
// Upsert usage summary
|
||||
query := `
|
||||
INSERT INTO usage_summary (user_id, usage_type, period_start, total_count)
|
||||
VALUES ($1, $2, $3, $4)
|
||||
ON CONFLICT (user_id, usage_type, period_start) DO UPDATE SET
|
||||
total_count = usage_summary.total_count + EXCLUDED.total_count,
|
||||
updated_at = NOW()
|
||||
`
|
||||
|
||||
_, err = s.db.Pool.Exec(ctx, query, userID, usageType, periodStart, quantity)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to track usage: %w", err)
|
||||
}
|
||||
|
||||
// Also update entitlements cache
|
||||
return s.entitlementService.IncrementUsage(ctx, userID, usageType, quantity)
|
||||
}
|
||||
|
||||
// GetUsageSummary returns usage summary for a user
|
||||
func (s *UsageService) GetUsageSummary(ctx context.Context, userID uuid.UUID) (*models.UsageInfo, error) {
|
||||
// Get entitlements (which include current usage)
|
||||
ent, err := s.entitlementService.getUserEntitlements(ctx, userID)
|
||||
if err != nil || ent == nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Calculate percentages
|
||||
aiPercent := 0.0
|
||||
if ent.AIRequestsLimit > 0 {
|
||||
aiPercent = float64(ent.AIRequestsUsed) / float64(ent.AIRequestsLimit) * 100
|
||||
}
|
||||
|
||||
docPercent := 0.0
|
||||
if ent.DocumentsLimit > 0 {
|
||||
docPercent = float64(ent.DocumentsUsed) / float64(ent.DocumentsLimit) * 100
|
||||
}
|
||||
|
||||
// Get period dates
|
||||
now := time.Now()
|
||||
periodStart := time.Date(now.Year(), now.Month(), 1, 0, 0, 0, 0, time.UTC)
|
||||
periodEnd := periodStart.AddDate(0, 1, 0).Add(-time.Second)
|
||||
|
||||
return &models.UsageInfo{
|
||||
AIRequestsUsed: ent.AIRequestsUsed,
|
||||
AIRequestsLimit: ent.AIRequestsLimit,
|
||||
AIRequestsPercent: aiPercent,
|
||||
DocumentsUsed: ent.DocumentsUsed,
|
||||
DocumentsLimit: ent.DocumentsLimit,
|
||||
DocumentsPercent: docPercent,
|
||||
PeriodStart: periodStart.Format("2006-01-02"),
|
||||
PeriodEnd: periodEnd.Format("2006-01-02"),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// CheckUsageAllowed checks if a user is allowed to perform a usage action
|
||||
func (s *UsageService) CheckUsageAllowed(ctx context.Context, userIDStr, usageType string) (*models.CheckUsageResponse, error) {
|
||||
userID, err := uuid.Parse(userIDStr)
|
||||
if err != nil {
|
||||
return &models.CheckUsageResponse{
|
||||
Allowed: false,
|
||||
Message: "Invalid user ID",
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Get entitlements
|
||||
ent, err := s.entitlementService.getUserEntitlements(ctx, userID)
|
||||
if err != nil {
|
||||
return &models.CheckUsageResponse{
|
||||
Allowed: false,
|
||||
Message: "Failed to get entitlements",
|
||||
}, nil
|
||||
}
|
||||
|
||||
if ent == nil {
|
||||
return &models.CheckUsageResponse{
|
||||
Allowed: false,
|
||||
Message: "No subscription found",
|
||||
}, nil
|
||||
}
|
||||
|
||||
var currentUsage, limit int
|
||||
switch usageType {
|
||||
case "ai_request":
|
||||
currentUsage = ent.AIRequestsUsed
|
||||
limit = ent.AIRequestsLimit
|
||||
case "document_created":
|
||||
currentUsage = ent.DocumentsUsed
|
||||
limit = ent.DocumentsLimit
|
||||
default:
|
||||
return &models.CheckUsageResponse{
|
||||
Allowed: true,
|
||||
Message: "Unknown usage type - allowing",
|
||||
}, nil
|
||||
}
|
||||
|
||||
remaining := limit - currentUsage
|
||||
allowed := remaining > 0
|
||||
|
||||
response := &models.CheckUsageResponse{
|
||||
Allowed: allowed,
|
||||
CurrentUsage: currentUsage,
|
||||
Limit: limit,
|
||||
Remaining: remaining,
|
||||
}
|
||||
|
||||
if !allowed {
|
||||
response.Message = fmt.Sprintf("Usage limit reached for %s (%d/%d)", usageType, currentUsage, limit)
|
||||
}
|
||||
|
||||
return response, nil
|
||||
}
|
||||
|
||||
// GetUsageHistory returns usage history for a user
|
||||
func (s *UsageService) GetUsageHistory(ctx context.Context, userID uuid.UUID, months int) ([]models.UsageSummary, error) {
|
||||
query := `
|
||||
SELECT id, user_id, usage_type, period_start, total_count, created_at, updated_at
|
||||
FROM usage_summary
|
||||
WHERE user_id = $1
|
||||
AND period_start >= $2
|
||||
ORDER BY period_start DESC, usage_type
|
||||
`
|
||||
|
||||
// Calculate start date
|
||||
startDate := time.Now().AddDate(0, -months, 0)
|
||||
startDate = time.Date(startDate.Year(), startDate.Month(), 1, 0, 0, 0, 0, time.UTC)
|
||||
|
||||
rows, err := s.db.Pool.Query(ctx, query, userID, startDate)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var summaries []models.UsageSummary
|
||||
for rows.Next() {
|
||||
var summary models.UsageSummary
|
||||
err := rows.Scan(
|
||||
&summary.ID, &summary.UserID, &summary.UsageType,
|
||||
&summary.PeriodStart, &summary.TotalCount,
|
||||
&summary.CreatedAt, &summary.UpdatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
summaries = append(summaries, summary)
|
||||
}
|
||||
|
||||
return summaries, nil
|
||||
}
|
||||
|
||||
// ResetPeriodUsage resets usage for a new billing period
|
||||
func (s *UsageService) ResetPeriodUsage(ctx context.Context, userID uuid.UUID) error {
|
||||
now := time.Now()
|
||||
newPeriodStart := time.Date(now.Year(), now.Month(), 1, 0, 0, 0, 0, time.UTC)
|
||||
newPeriodEnd := newPeriodStart.AddDate(0, 1, 0).Add(-time.Second)
|
||||
|
||||
return s.entitlementService.ResetUsageCounters(ctx, userID, &newPeriodStart, &newPeriodEnd)
|
||||
}
|
||||
Reference in New Issue
Block a user