Initial commit: breakpilot-core - Shared Infrastructure

Docker Compose with 24+ services:
- PostgreSQL (PostGIS), Valkey, MinIO, Qdrant
- Vault (PKI/TLS), Nginx (Reverse Proxy)
- Backend Core API, Consent Service, Billing Service
- RAG Service, Embedding Service
- Gitea, Woodpecker CI/CD
- Night Scheduler, Health Aggregator
- Jitsi (Web/XMPP/JVB/Jicofo), Mailpit

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Benjamin Boenisch
2026-02-11 23:47:13 +01:00
commit ad111d5e69
244 changed files with 84288 additions and 0 deletions

View File

@@ -0,0 +1,157 @@
package config
import (
"fmt"
"os"
"github.com/joho/godotenv"
)
// Config holds all configuration for the billing service
type Config struct {
// Server
Port string
Environment string
// Database
DatabaseURL string
// JWT (shared with consent-service)
JWTSecret string
// Stripe
StripeSecretKey string
StripeWebhookSecret string
StripePublishableKey string
StripeMockMode bool // If true, Stripe calls are mocked (for dev without Stripe keys)
// URLs
BillingSuccessURL string
BillingCancelURL string
FrontendURL string
// Trial
TrialPeriodDays int
// CORS
AllowedOrigins []string
// Rate Limiting
RateLimitRequests int
RateLimitWindow int // in seconds
// Internal API Key (for service-to-service communication)
InternalAPIKey string
}
// Load loads configuration from environment variables
func Load() (*Config, error) {
// Load .env file if exists (for development)
_ = godotenv.Load()
cfg := &Config{
Port: getEnv("PORT", "8083"),
Environment: getEnv("ENVIRONMENT", "development"),
DatabaseURL: getEnv("DATABASE_URL", ""),
JWTSecret: getEnv("JWT_SECRET", ""),
// Stripe
StripeSecretKey: getEnv("STRIPE_SECRET_KEY", ""),
StripeWebhookSecret: getEnv("STRIPE_WEBHOOK_SECRET", ""),
StripePublishableKey: getEnv("STRIPE_PUBLISHABLE_KEY", ""),
StripeMockMode: getEnvBool("STRIPE_MOCK_MODE", false),
// URLs
BillingSuccessURL: getEnv("BILLING_SUCCESS_URL", "http://localhost:8000/app/billing/success"),
BillingCancelURL: getEnv("BILLING_CANCEL_URL", "http://localhost:8000/app/billing/cancel"),
FrontendURL: getEnv("FRONTEND_URL", "http://localhost:8000"),
// Trial
TrialPeriodDays: getEnvInt("TRIAL_PERIOD_DAYS", 7),
// Rate Limiting
RateLimitRequests: getEnvInt("RATE_LIMIT_REQUESTS", 100),
RateLimitWindow: getEnvInt("RATE_LIMIT_WINDOW", 60),
// Internal API
InternalAPIKey: getEnv("INTERNAL_API_KEY", ""),
}
// Parse allowed origins
originsStr := getEnv("ALLOWED_ORIGINS", "http://localhost:3000,http://localhost:8000")
cfg.AllowedOrigins = parseCommaSeparated(originsStr)
// Validate required fields
if cfg.DatabaseURL == "" {
return nil, fmt.Errorf("DATABASE_URL is required")
}
if cfg.JWTSecret == "" {
return nil, fmt.Errorf("JWT_SECRET is required")
}
// Stripe key is required unless mock mode is enabled
if cfg.StripeSecretKey == "" && !cfg.StripeMockMode {
// In development mode, auto-enable mock mode if no Stripe key
if cfg.Environment == "development" {
cfg.StripeMockMode = true
} else {
return nil, fmt.Errorf("STRIPE_SECRET_KEY is required (set STRIPE_MOCK_MODE=true to bypass in dev)")
}
}
return cfg, nil
}
// IsMockMode returns true if Stripe should be mocked
func (c *Config) IsMockMode() bool {
return c.StripeMockMode
}
func getEnv(key, defaultValue string) string {
if value := os.Getenv(key); value != "" {
return value
}
return defaultValue
}
func getEnvInt(key string, defaultValue int) int {
if value := os.Getenv(key); value != "" {
var result int
fmt.Sscanf(value, "%d", &result)
return result
}
return defaultValue
}
func getEnvBool(key string, defaultValue bool) bool {
if value := os.Getenv(key); value != "" {
return value == "true" || value == "1" || value == "yes"
}
return defaultValue
}
func parseCommaSeparated(s string) []string {
if s == "" {
return []string{}
}
var result []string
start := 0
for i := 0; i <= len(s); i++ {
if i == len(s) || s[i] == ',' {
item := s[start:i]
// Trim whitespace
for len(item) > 0 && item[0] == ' ' {
item = item[1:]
}
for len(item) > 0 && item[len(item)-1] == ' ' {
item = item[:len(item)-1]
}
if item != "" {
result = append(result, item)
}
start = i + 1
}
}
return result
}

View File

@@ -0,0 +1,260 @@
package database
import (
"context"
"fmt"
"time"
"github.com/jackc/pgx/v5/pgxpool"
)
// DB wraps the pgx pool
type DB struct {
Pool *pgxpool.Pool
}
// Connect establishes a connection to the PostgreSQL database
func Connect(databaseURL string) (*DB, error) {
config, err := pgxpool.ParseConfig(databaseURL)
if err != nil {
return nil, fmt.Errorf("failed to parse database URL: %w", err)
}
// Configure connection pool
config.MaxConns = 15
config.MinConns = 3
config.MaxConnLifetime = time.Hour
config.MaxConnIdleTime = 30 * time.Minute
config.HealthCheckPeriod = time.Minute
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
pool, err := pgxpool.NewWithConfig(ctx, config)
if err != nil {
return nil, fmt.Errorf("failed to create connection pool: %w", err)
}
// Test the connection
if err := pool.Ping(ctx); err != nil {
return nil, fmt.Errorf("failed to ping database: %w", err)
}
return &DB{Pool: pool}, nil
}
// Close closes the database connection pool
func (db *DB) Close() {
db.Pool.Close()
}
// Migrate runs database migrations for the billing service
func Migrate(db *DB) error {
ctx := context.Background()
migrations := []string{
// =============================================
// Billing Service Tables
// =============================================
// Subscriptions - core subscription data
`CREATE TABLE IF NOT EXISTS subscriptions (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
user_id UUID NOT NULL,
stripe_customer_id VARCHAR(255),
stripe_subscription_id VARCHAR(255) UNIQUE,
plan_id VARCHAR(50) NOT NULL,
status VARCHAR(30) NOT NULL DEFAULT 'trialing',
trial_end TIMESTAMPTZ,
current_period_start TIMESTAMPTZ,
current_period_end TIMESTAMPTZ,
cancel_at_period_end BOOLEAN DEFAULT FALSE,
canceled_at TIMESTAMPTZ,
ended_at TIMESTAMPTZ,
created_at TIMESTAMPTZ DEFAULT NOW(),
updated_at TIMESTAMPTZ DEFAULT NOW(),
UNIQUE(user_id)
)`,
// Billing Plans - cached from Stripe
`CREATE TABLE IF NOT EXISTS billing_plans (
id VARCHAR(50) PRIMARY KEY,
stripe_price_id VARCHAR(255) UNIQUE,
stripe_product_id VARCHAR(255),
name VARCHAR(100) NOT NULL,
description TEXT,
price_cents INT NOT NULL,
currency VARCHAR(3) DEFAULT 'eur',
interval VARCHAR(10) DEFAULT 'month',
features JSONB DEFAULT '{}',
is_active BOOLEAN DEFAULT TRUE,
sort_order INT DEFAULT 0,
created_at TIMESTAMPTZ DEFAULT NOW(),
updated_at TIMESTAMPTZ DEFAULT NOW()
)`,
// Usage Summary - aggregated usage per period
`CREATE TABLE IF NOT EXISTS usage_summary (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
user_id UUID NOT NULL,
usage_type VARCHAR(50) NOT NULL,
period_start TIMESTAMPTZ NOT NULL,
total_count INT DEFAULT 0,
created_at TIMESTAMPTZ DEFAULT NOW(),
updated_at TIMESTAMPTZ DEFAULT NOW(),
UNIQUE(user_id, usage_type, period_start)
)`,
// User Entitlements - cached entitlements for fast lookups
`CREATE TABLE IF NOT EXISTS user_entitlements (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
user_id UUID NOT NULL UNIQUE,
plan_id VARCHAR(50) NOT NULL,
ai_requests_limit INT DEFAULT 0,
ai_requests_used INT DEFAULT 0,
documents_limit INT DEFAULT 0,
documents_used INT DEFAULT 0,
features JSONB DEFAULT '{}',
period_start TIMESTAMPTZ,
period_end TIMESTAMPTZ,
created_at TIMESTAMPTZ DEFAULT NOW(),
updated_at TIMESTAMPTZ DEFAULT NOW()
)`,
// Stripe Webhook Events - for idempotency
`CREATE TABLE IF NOT EXISTS stripe_webhook_events (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
stripe_event_id VARCHAR(255) UNIQUE NOT NULL,
event_type VARCHAR(100) NOT NULL,
processed BOOLEAN DEFAULT FALSE,
processed_at TIMESTAMPTZ,
payload JSONB,
error_message TEXT,
created_at TIMESTAMPTZ DEFAULT NOW()
)`,
// Billing Audit Log
`CREATE TABLE IF NOT EXISTS billing_audit_log (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
user_id UUID,
action VARCHAR(50) NOT NULL,
entity_type VARCHAR(50),
entity_id VARCHAR(255),
old_value JSONB,
new_value JSONB,
metadata JSONB,
ip_address INET,
user_agent TEXT,
created_at TIMESTAMPTZ DEFAULT NOW()
)`,
// Invoices - cached from Stripe
`CREATE TABLE IF NOT EXISTS invoices (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
user_id UUID NOT NULL,
stripe_invoice_id VARCHAR(255) UNIQUE NOT NULL,
stripe_subscription_id VARCHAR(255),
status VARCHAR(30) NOT NULL,
amount_due INT NOT NULL,
amount_paid INT DEFAULT 0,
currency VARCHAR(3) DEFAULT 'eur',
hosted_invoice_url TEXT,
invoice_pdf TEXT,
period_start TIMESTAMPTZ,
period_end TIMESTAMPTZ,
due_date TIMESTAMPTZ,
paid_at TIMESTAMPTZ,
created_at TIMESTAMPTZ DEFAULT NOW()
)`,
// =============================================
// Task-based Billing Tables
// =============================================
// Account Usage - tracks task balance per account
`CREATE TABLE IF NOT EXISTS account_usage (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
account_id UUID NOT NULL UNIQUE,
plan VARCHAR(50) NOT NULL,
monthly_task_allowance INT NOT NULL,
carryover_months_cap INT DEFAULT 5,
max_task_balance INT NOT NULL,
task_balance INT NOT NULL,
last_renewal_at TIMESTAMPTZ NOT NULL,
created_at TIMESTAMPTZ DEFAULT NOW(),
updated_at TIMESTAMPTZ DEFAULT NOW()
)`,
// Tasks - individual task consumption records
`CREATE TABLE IF NOT EXISTS tasks (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
account_id UUID NOT NULL,
task_type VARCHAR(50) NOT NULL,
consumed BOOLEAN DEFAULT TRUE,
page_count INT DEFAULT 0,
token_count INT DEFAULT 0,
process_time INT DEFAULT 0,
created_at TIMESTAMPTZ DEFAULT NOW()
)`,
// =============================================
// Indexes
// =============================================
`CREATE INDEX IF NOT EXISTS idx_subscriptions_user ON subscriptions(user_id)`,
`CREATE INDEX IF NOT EXISTS idx_subscriptions_stripe_customer ON subscriptions(stripe_customer_id)`,
`CREATE INDEX IF NOT EXISTS idx_subscriptions_stripe_sub ON subscriptions(stripe_subscription_id)`,
`CREATE INDEX IF NOT EXISTS idx_subscriptions_status ON subscriptions(status)`,
`CREATE INDEX IF NOT EXISTS idx_subscriptions_trial_end ON subscriptions(trial_end)`,
`CREATE INDEX IF NOT EXISTS idx_usage_summary_user ON usage_summary(user_id)`,
`CREATE INDEX IF NOT EXISTS idx_usage_summary_period ON usage_summary(period_start)`,
`CREATE INDEX IF NOT EXISTS idx_usage_summary_type ON usage_summary(usage_type)`,
`CREATE INDEX IF NOT EXISTS idx_user_entitlements_user ON user_entitlements(user_id)`,
`CREATE INDEX IF NOT EXISTS idx_user_entitlements_plan ON user_entitlements(plan_id)`,
`CREATE INDEX IF NOT EXISTS idx_stripe_webhook_events_event_id ON stripe_webhook_events(stripe_event_id)`,
`CREATE INDEX IF NOT EXISTS idx_stripe_webhook_events_type ON stripe_webhook_events(event_type)`,
`CREATE INDEX IF NOT EXISTS idx_stripe_webhook_events_processed ON stripe_webhook_events(processed)`,
`CREATE INDEX IF NOT EXISTS idx_billing_audit_log_user ON billing_audit_log(user_id)`,
`CREATE INDEX IF NOT EXISTS idx_billing_audit_log_action ON billing_audit_log(action)`,
`CREATE INDEX IF NOT EXISTS idx_billing_audit_log_created ON billing_audit_log(created_at)`,
`CREATE INDEX IF NOT EXISTS idx_invoices_user ON invoices(user_id)`,
`CREATE INDEX IF NOT EXISTS idx_invoices_stripe_invoice ON invoices(stripe_invoice_id)`,
`CREATE INDEX IF NOT EXISTS idx_invoices_status ON invoices(status)`,
`CREATE INDEX IF NOT EXISTS idx_account_usage_account ON account_usage(account_id)`,
`CREATE INDEX IF NOT EXISTS idx_account_usage_plan ON account_usage(plan)`,
`CREATE INDEX IF NOT EXISTS idx_account_usage_renewal ON account_usage(last_renewal_at)`,
`CREATE INDEX IF NOT EXISTS idx_tasks_account ON tasks(account_id)`,
`CREATE INDEX IF NOT EXISTS idx_tasks_type ON tasks(task_type)`,
`CREATE INDEX IF NOT EXISTS idx_tasks_created ON tasks(created_at)`,
// =============================================
// Insert default plans
// =============================================
`INSERT INTO billing_plans (id, name, description, price_cents, currency, interval, features, sort_order)
VALUES
('basic', 'Basic', 'Perfekt für den Einstieg', 990, 'eur', 'month',
'{"ai_requests_limit": 300, "documents_limit": 50, "feature_flags": ["basic_ai", "basic_documents"], "max_team_members": 1, "priority_support": false, "custom_branding": false}',
1),
('standard', 'Standard', 'Für regelmäßige Nutzer', 1990, 'eur', 'month',
'{"ai_requests_limit": 1500, "documents_limit": 200, "feature_flags": ["basic_ai", "basic_documents", "templates", "batch_processing"], "max_team_members": 3, "priority_support": false, "custom_branding": false}',
2),
('premium', 'Premium', 'Für Teams und Power-User', 3990, 'eur', 'month',
'{"ai_requests_limit": 5000, "documents_limit": 1000, "feature_flags": ["basic_ai", "basic_documents", "templates", "batch_processing", "team_features", "admin_panel", "audit_log", "api_access"], "max_team_members": 10, "priority_support": true, "custom_branding": true}',
3)
ON CONFLICT (id) DO NOTHING`,
}
for _, migration := range migrations {
if _, err := db.Pool.Exec(ctx, migration); err != nil {
return fmt.Errorf("failed to run migration: %w", err)
}
}
return nil
}

View File

@@ -0,0 +1,427 @@
package handlers
import (
"net/http"
"github.com/breakpilot/billing-service/internal/database"
"github.com/breakpilot/billing-service/internal/middleware"
"github.com/breakpilot/billing-service/internal/models"
"github.com/breakpilot/billing-service/internal/services"
"github.com/gin-gonic/gin"
)
// BillingHandler handles billing-related HTTP requests
type BillingHandler struct {
db *database.DB
subscriptionService *services.SubscriptionService
stripeService *services.StripeService
entitlementService *services.EntitlementService
usageService *services.UsageService
}
// NewBillingHandler creates a new BillingHandler
func NewBillingHandler(
db *database.DB,
subscriptionService *services.SubscriptionService,
stripeService *services.StripeService,
entitlementService *services.EntitlementService,
usageService *services.UsageService,
) *BillingHandler {
return &BillingHandler{
db: db,
subscriptionService: subscriptionService,
stripeService: stripeService,
entitlementService: entitlementService,
usageService: usageService,
}
}
// GetBillingStatus returns the current billing status for a user
// GET /api/v1/billing/status
func (h *BillingHandler) GetBillingStatus(c *gin.Context) {
userID, err := middleware.GetUserID(c)
if err != nil || userID.String() == "" {
c.JSON(http.StatusUnauthorized, gin.H{
"error": "unauthorized",
"message": "User not authenticated",
})
return
}
ctx := c.Request.Context()
// Get subscription
subscription, err := h.subscriptionService.GetByUserID(ctx, userID)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"error": "internal_error",
"message": "Failed to get subscription",
})
return
}
// Get available plans
plans, err := h.subscriptionService.GetAvailablePlans(ctx)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"error": "internal_error",
"message": "Failed to get plans",
})
return
}
response := models.BillingStatusResponse{
HasSubscription: subscription != nil,
AvailablePlans: plans,
}
if subscription != nil {
// Get plan details
plan, _ := h.subscriptionService.GetPlanByID(ctx, string(subscription.PlanID))
subInfo := &models.SubscriptionInfo{
PlanID: subscription.PlanID,
Status: subscription.Status,
IsTrialing: subscription.Status == models.StatusTrialing,
CancelAtPeriodEnd: subscription.CancelAtPeriodEnd,
CurrentPeriodEnd: subscription.CurrentPeriodEnd,
}
if plan != nil {
subInfo.PlanName = plan.Name
subInfo.PriceCents = plan.PriceCents
subInfo.Currency = plan.Currency
}
// Calculate trial days left
if subscription.TrialEnd != nil && subscription.Status == models.StatusTrialing {
// TODO: Calculate days left
}
response.Subscription = subInfo
// Get task usage info (legacy usage tracking - see TaskService for new task-based usage)
// TODO: Replace with TaskService.GetTaskUsageInfo for task-based billing
_, _ = h.usageService.GetUsageSummary(ctx, userID)
// Get entitlements
entitlements, _ := h.entitlementService.GetEntitlements(ctx, userID)
if entitlements != nil {
response.Entitlements = entitlements
}
}
c.JSON(http.StatusOK, response)
}
// GetPlans returns all available billing plans
// GET /api/v1/billing/plans
func (h *BillingHandler) GetPlans(c *gin.Context) {
ctx := c.Request.Context()
plans, err := h.subscriptionService.GetAvailablePlans(ctx)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"error": "internal_error",
"message": "Failed to get plans",
})
return
}
c.JSON(http.StatusOK, gin.H{
"plans": plans,
})
}
// StartTrial starts a trial for the user with a specific plan
// POST /api/v1/billing/trial/start
func (h *BillingHandler) StartTrial(c *gin.Context) {
userID, err := middleware.GetUserID(c)
if err != nil || userID.String() == "" {
c.JSON(http.StatusUnauthorized, gin.H{
"error": "unauthorized",
"message": "User not authenticated",
})
return
}
var req models.StartTrialRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{
"error": "invalid_request",
"message": "Invalid request body",
})
return
}
ctx := c.Request.Context()
// Check if user already has a subscription
existing, _ := h.subscriptionService.GetByUserID(ctx, userID)
if existing != nil {
c.JSON(http.StatusConflict, gin.H{
"error": "subscription_exists",
"message": "User already has a subscription",
})
return
}
// Get user email from context
email, _ := c.Get("email")
emailStr, _ := email.(string)
// Create Stripe checkout session
checkoutURL, sessionID, err := h.stripeService.CreateCheckoutSession(ctx, userID, emailStr, req.PlanID)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"error": "stripe_error",
"message": "Failed to create checkout session",
"details": err.Error(),
})
return
}
c.JSON(http.StatusOK, models.StartTrialResponse{
CheckoutURL: checkoutURL,
SessionID: sessionID,
})
}
// ChangePlan changes the user's subscription plan
// POST /api/v1/billing/change-plan
func (h *BillingHandler) ChangePlan(c *gin.Context) {
userID, err := middleware.GetUserID(c)
if err != nil || userID.String() == "" {
c.JSON(http.StatusUnauthorized, gin.H{
"error": "unauthorized",
"message": "User not authenticated",
})
return
}
var req models.ChangePlanRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{
"error": "invalid_request",
"message": "Invalid request body",
})
return
}
ctx := c.Request.Context()
// Get current subscription
subscription, err := h.subscriptionService.GetByUserID(ctx, userID)
if err != nil || subscription == nil {
c.JSON(http.StatusNotFound, gin.H{
"error": "no_subscription",
"message": "No active subscription found",
})
return
}
// Change plan via Stripe
err = h.stripeService.ChangePlan(ctx, subscription.StripeSubscriptionID, req.NewPlanID)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"error": "stripe_error",
"message": "Failed to change plan",
"details": err.Error(),
})
return
}
c.JSON(http.StatusOK, models.ChangePlanResponse{
Success: true,
Message: "Plan changed successfully",
})
}
// CancelSubscription cancels the user's subscription at period end
// POST /api/v1/billing/cancel
func (h *BillingHandler) CancelSubscription(c *gin.Context) {
userID, err := middleware.GetUserID(c)
if err != nil || userID.String() == "" {
c.JSON(http.StatusUnauthorized, gin.H{
"error": "unauthorized",
"message": "User not authenticated",
})
return
}
ctx := c.Request.Context()
// Get current subscription
subscription, err := h.subscriptionService.GetByUserID(ctx, userID)
if err != nil || subscription == nil {
c.JSON(http.StatusNotFound, gin.H{
"error": "no_subscription",
"message": "No active subscription found",
})
return
}
// Cancel at period end via Stripe
err = h.stripeService.CancelSubscription(ctx, subscription.StripeSubscriptionID)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"error": "stripe_error",
"message": "Failed to cancel subscription",
"details": err.Error(),
})
return
}
c.JSON(http.StatusOK, models.CancelSubscriptionResponse{
Success: true,
Message: "Subscription will be canceled at the end of the billing period",
})
}
// GetCustomerPortal returns a URL to the Stripe customer portal
// GET /api/v1/billing/portal
func (h *BillingHandler) GetCustomerPortal(c *gin.Context) {
userID, err := middleware.GetUserID(c)
if err != nil || userID.String() == "" {
c.JSON(http.StatusUnauthorized, gin.H{
"error": "unauthorized",
"message": "User not authenticated",
})
return
}
ctx := c.Request.Context()
// Get current subscription
subscription, err := h.subscriptionService.GetByUserID(ctx, userID)
if err != nil || subscription == nil || subscription.StripeCustomerID == "" {
c.JSON(http.StatusNotFound, gin.H{
"error": "no_subscription",
"message": "No active subscription found",
})
return
}
// Create portal session
portalURL, err := h.stripeService.CreateCustomerPortalSession(ctx, subscription.StripeCustomerID)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"error": "stripe_error",
"message": "Failed to create portal session",
"details": err.Error(),
})
return
}
c.JSON(http.StatusOK, models.CustomerPortalResponse{
PortalURL: portalURL,
})
}
// =============================================
// Internal Endpoints (Service-to-Service)
// =============================================
// GetEntitlements returns entitlements for a user (internal)
// GET /api/v1/billing/entitlements/:userId
func (h *BillingHandler) GetEntitlements(c *gin.Context) {
userIDStr := c.Param("userId")
ctx := c.Request.Context()
entitlements, err := h.entitlementService.GetEntitlementsByUserIDString(ctx, userIDStr)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"error": "internal_error",
"message": "Failed to get entitlements",
})
return
}
if entitlements == nil {
c.JSON(http.StatusNotFound, gin.H{
"error": "not_found",
"message": "No entitlements found for user",
})
return
}
c.JSON(http.StatusOK, entitlements)
}
// TrackUsage tracks usage for a user (internal)
// POST /api/v1/billing/usage/track
func (h *BillingHandler) TrackUsage(c *gin.Context) {
var req models.TrackUsageRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{
"error": "invalid_request",
"message": "Invalid request body",
})
return
}
ctx := c.Request.Context()
quantity := req.Quantity
if quantity <= 0 {
quantity = 1
}
err := h.usageService.TrackUsage(ctx, req.UserID, req.UsageType, quantity)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"error": "internal_error",
"message": "Failed to track usage",
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "Usage tracked",
})
}
// CheckUsage checks if usage is allowed (internal)
// GET /api/v1/billing/usage/check/:userId/:type
func (h *BillingHandler) CheckUsage(c *gin.Context) {
userIDStr := c.Param("userId")
usageType := c.Param("type")
ctx := c.Request.Context()
response, err := h.usageService.CheckUsageAllowed(ctx, userIDStr, usageType)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"error": "internal_error",
"message": "Failed to check usage",
})
return
}
c.JSON(http.StatusOK, response)
}
// CheckEntitlement checks if a user has a specific entitlement (internal)
// GET /api/v1/billing/entitlements/check/:userId/:feature
func (h *BillingHandler) CheckEntitlement(c *gin.Context) {
userIDStr := c.Param("userId")
feature := c.Param("feature")
ctx := c.Request.Context()
hasEntitlement, planID, err := h.entitlementService.CheckEntitlement(ctx, userIDStr, feature)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"error": "internal_error",
"message": "Failed to check entitlement",
})
return
}
c.JSON(http.StatusOK, models.EntitlementCheckResponse{
HasEntitlement: hasEntitlement,
PlanID: planID,
})
}

View File

@@ -0,0 +1,612 @@
package handlers
import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/breakpilot/billing-service/internal/models"
"github.com/gin-gonic/gin"
)
func init() {
// Set Gin to test mode
gin.SetMode(gin.TestMode)
}
func TestGetPlans_ResponseFormat(t *testing.T) {
// Test that GetPlans returns the expected response structure
// Since we don't have a real database connection in unit tests,
// we test the expected structure and format
// Test that default plans are well-formed
plans := models.GetDefaultPlans()
if len(plans) == 0 {
t.Error("Default plans should not be empty")
}
for _, plan := range plans {
// Verify JSON serialization works
data, err := json.Marshal(plan)
if err != nil {
t.Errorf("Failed to marshal plan %s: %v", plan.ID, err)
}
// Verify we can unmarshal back
var decoded models.BillingPlan
err = json.Unmarshal(data, &decoded)
if err != nil {
t.Errorf("Failed to unmarshal plan %s: %v", plan.ID, err)
}
// Verify key fields
if decoded.ID != plan.ID {
t.Errorf("Plan ID mismatch: got %s, expected %s", decoded.ID, plan.ID)
}
}
}
func TestBillingStatusResponse_Structure(t *testing.T) {
// Test the response structure
response := models.BillingStatusResponse{
HasSubscription: true,
Subscription: &models.SubscriptionInfo{
PlanID: models.PlanStandard,
PlanName: "Standard",
Status: models.StatusActive,
IsTrialing: false,
CancelAtPeriodEnd: false,
PriceCents: 1990,
Currency: "eur",
},
TaskUsage: &models.TaskUsageInfo{
TasksAvailable: 85,
MaxTasks: 500,
InfoText: "Aufgaben verfuegbar: 85 von max. 500",
TooltipText: "Aufgaben koennen sich bis zu 5 Monate ansammeln.",
},
Entitlements: &models.EntitlementInfo{
Features: []string{"basic_ai", "basic_documents", "templates", "batch_processing"},
MaxTeamMembers: 3,
PrioritySupport: false,
CustomBranding: false,
BatchProcessing: true,
CustomTemplates: true,
FairUseMode: false,
},
AvailablePlans: models.GetDefaultPlans(),
}
// Test JSON serialization
data, err := json.Marshal(response)
if err != nil {
t.Fatalf("Failed to marshal BillingStatusResponse: %v", err)
}
// Verify it's valid JSON
var decoded map[string]interface{}
err = json.Unmarshal(data, &decoded)
if err != nil {
t.Fatalf("Response is not valid JSON: %v", err)
}
// Check required fields exist
if _, ok := decoded["has_subscription"]; !ok {
t.Error("Response should have 'has_subscription' field")
}
}
func TestStartTrialRequest_Validation(t *testing.T) {
tests := []struct {
name string
request models.StartTrialRequest
wantError bool
}{
{
name: "Valid basic plan",
request: models.StartTrialRequest{PlanID: models.PlanBasic},
wantError: false,
},
{
name: "Valid standard plan",
request: models.StartTrialRequest{PlanID: models.PlanStandard},
wantError: false,
},
{
name: "Valid premium plan",
request: models.StartTrialRequest{PlanID: models.PlanPremium},
wantError: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Test JSON serialization
data, err := json.Marshal(tt.request)
if err != nil {
t.Fatalf("Failed to marshal request: %v", err)
}
var decoded models.StartTrialRequest
err = json.Unmarshal(data, &decoded)
if err != nil {
t.Fatalf("Failed to unmarshal request: %v", err)
}
if decoded.PlanID != tt.request.PlanID {
t.Errorf("PlanID mismatch: got %s, expected %s", decoded.PlanID, tt.request.PlanID)
}
})
}
}
func TestChangePlanRequest_Structure(t *testing.T) {
request := models.ChangePlanRequest{
NewPlanID: models.PlanPremium,
}
data, err := json.Marshal(request)
if err != nil {
t.Fatalf("Failed to marshal ChangePlanRequest: %v", err)
}
var decoded map[string]interface{}
err = json.Unmarshal(data, &decoded)
if err != nil {
t.Fatalf("Response is not valid JSON: %v", err)
}
if _, ok := decoded["new_plan_id"]; !ok {
t.Error("Request should have 'new_plan_id' field")
}
}
func TestStartTrialResponse_Structure(t *testing.T) {
response := models.StartTrialResponse{
CheckoutURL: "https://checkout.stripe.com/c/pay/cs_test_123",
SessionID: "cs_test_123",
}
data, err := json.Marshal(response)
if err != nil {
t.Fatalf("Failed to marshal StartTrialResponse: %v", err)
}
var decoded map[string]interface{}
err = json.Unmarshal(data, &decoded)
if err != nil {
t.Fatalf("Response is not valid JSON: %v", err)
}
if _, ok := decoded["checkout_url"]; !ok {
t.Error("Response should have 'checkout_url' field")
}
if _, ok := decoded["session_id"]; !ok {
t.Error("Response should have 'session_id' field")
}
}
func TestCancelSubscriptionResponse_Structure(t *testing.T) {
response := models.CancelSubscriptionResponse{
Success: true,
Message: "Subscription will be canceled at the end of the billing period",
CancelDate: "2025-01-16",
ActiveUntil: "2025-01-16",
}
_, err := json.Marshal(response)
if err != nil {
t.Fatalf("Failed to marshal CancelSubscriptionResponse: %v", err)
}
if !response.Success {
t.Error("Success should be true")
}
}
func TestCustomerPortalResponse_Structure(t *testing.T) {
response := models.CustomerPortalResponse{
PortalURL: "https://billing.stripe.com/p/session/test_123",
}
data, err := json.Marshal(response)
if err != nil {
t.Fatalf("Failed to marshal CustomerPortalResponse: %v", err)
}
var decoded map[string]interface{}
err = json.Unmarshal(data, &decoded)
if err != nil {
t.Fatalf("Response is not valid JSON: %v", err)
}
if _, ok := decoded["portal_url"]; !ok {
t.Error("Response should have 'portal_url' field")
}
}
func TestEntitlementCheckResponse_Structure(t *testing.T) {
tests := []struct {
name string
response models.EntitlementCheckResponse
}{
{
name: "Has entitlement",
response: models.EntitlementCheckResponse{
HasEntitlement: true,
PlanID: models.PlanStandard,
},
},
{
name: "No entitlement",
response: models.EntitlementCheckResponse{
HasEntitlement: false,
PlanID: models.PlanBasic,
Message: "Feature not available in this plan",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
data, err := json.Marshal(tt.response)
if err != nil {
t.Fatalf("Failed to marshal EntitlementCheckResponse: %v", err)
}
var decoded map[string]interface{}
err = json.Unmarshal(data, &decoded)
if err != nil {
t.Fatalf("Response is not valid JSON: %v", err)
}
if _, ok := decoded["has_entitlement"]; !ok {
t.Error("Response should have 'has_entitlement' field")
}
})
}
}
func TestTrackUsageRequest_Validation(t *testing.T) {
tests := []struct {
name string
request models.TrackUsageRequest
valid bool
}{
{
name: "Valid AI request",
request: models.TrackUsageRequest{
UserID: "550e8400-e29b-41d4-a716-446655440000",
UsageType: "ai_request",
Quantity: 1,
},
valid: true,
},
{
name: "Valid document created",
request: models.TrackUsageRequest{
UserID: "550e8400-e29b-41d4-a716-446655440000",
UsageType: "document_created",
Quantity: 1,
},
valid: true,
},
{
name: "Multiple quantity",
request: models.TrackUsageRequest{
UserID: "550e8400-e29b-41d4-a716-446655440000",
UsageType: "ai_request",
Quantity: 5,
},
valid: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
data, err := json.Marshal(tt.request)
if err != nil {
t.Fatalf("Failed to marshal TrackUsageRequest: %v", err)
}
var decoded models.TrackUsageRequest
err = json.Unmarshal(data, &decoded)
if err != nil {
t.Fatalf("Failed to unmarshal TrackUsageRequest: %v", err)
}
if decoded.UserID != tt.request.UserID {
t.Errorf("UserID mismatch: got %s, expected %s", decoded.UserID, tt.request.UserID)
}
})
}
}
func TestCheckUsageResponse_Format(t *testing.T) {
tests := []struct {
name string
response models.CheckUsageResponse
}{
{
name: "Allowed response",
response: models.CheckUsageResponse{
Allowed: true,
CurrentUsage: 450,
Limit: 1500,
Remaining: 1050,
},
},
{
name: "Limit reached",
response: models.CheckUsageResponse{
Allowed: false,
CurrentUsage: 1500,
Limit: 1500,
Remaining: 0,
Message: "Usage limit reached for ai_request (1500/1500)",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
data, err := json.Marshal(tt.response)
if err != nil {
t.Fatalf("Failed to marshal CheckUsageResponse: %v", err)
}
var decoded map[string]interface{}
err = json.Unmarshal(data, &decoded)
if err != nil {
t.Fatalf("Response is not valid JSON: %v", err)
}
if _, ok := decoded["allowed"]; !ok {
t.Error("Response should have 'allowed' field")
}
})
}
}
func TestConsumeTaskRequest_Format(t *testing.T) {
tests := []struct {
name string
request models.ConsumeTaskRequest
}{
{
name: "Correction task",
request: models.ConsumeTaskRequest{
UserID: "550e8400-e29b-41d4-a716-446655440000",
TaskType: models.TaskTypeCorrection,
},
},
{
name: "Letter task",
request: models.ConsumeTaskRequest{
UserID: "550e8400-e29b-41d4-a716-446655440000",
TaskType: models.TaskTypeLetter,
},
},
{
name: "Batch task",
request: models.ConsumeTaskRequest{
UserID: "550e8400-e29b-41d4-a716-446655440000",
TaskType: models.TaskTypeBatch,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
data, err := json.Marshal(tt.request)
if err != nil {
t.Fatalf("Failed to marshal ConsumeTaskRequest: %v", err)
}
var decoded models.ConsumeTaskRequest
err = json.Unmarshal(data, &decoded)
if err != nil {
t.Fatalf("Failed to unmarshal ConsumeTaskRequest: %v", err)
}
if decoded.TaskType != tt.request.TaskType {
t.Errorf("TaskType mismatch: got %s, expected %s", decoded.TaskType, tt.request.TaskType)
}
})
}
}
func TestConsumeTaskResponse_Format(t *testing.T) {
tests := []struct {
name string
response models.ConsumeTaskResponse
}{
{
name: "Successful consumption",
response: models.ConsumeTaskResponse{
Success: true,
TaskID: "task-uuid-123",
TasksRemaining: 49,
},
},
{
name: "Limit reached",
response: models.ConsumeTaskResponse{
Success: false,
TasksRemaining: 0,
Message: "Dein Aufgaben-Kontingent ist aufgebraucht.",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
data, err := json.Marshal(tt.response)
if err != nil {
t.Fatalf("Failed to marshal ConsumeTaskResponse: %v", err)
}
var decoded map[string]interface{}
err = json.Unmarshal(data, &decoded)
if err != nil {
t.Fatalf("Response is not valid JSON: %v", err)
}
if _, ok := decoded["success"]; !ok {
t.Error("Response should have 'success' field")
}
if _, ok := decoded["tasks_remaining"]; !ok {
t.Error("Response should have 'tasks_remaining' field")
}
})
}
}
func TestCheckTaskAllowedResponse_Format(t *testing.T) {
tests := []struct {
name string
response models.CheckTaskAllowedResponse
}{
{
name: "Task allowed",
response: models.CheckTaskAllowedResponse{
Allowed: true,
TasksAvailable: 50,
MaxTasks: 150,
PlanID: models.PlanBasic,
},
},
{
name: "Task not allowed",
response: models.CheckTaskAllowedResponse{
Allowed: false,
TasksAvailable: 0,
MaxTasks: 150,
PlanID: models.PlanBasic,
Message: "Dein Aufgaben-Kontingent ist aufgebraucht.",
},
},
{
name: "Premium Fair Use",
response: models.CheckTaskAllowedResponse{
Allowed: true,
TasksAvailable: 1000,
MaxTasks: 5000,
PlanID: models.PlanPremium,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
data, err := json.Marshal(tt.response)
if err != nil {
t.Fatalf("Failed to marshal CheckTaskAllowedResponse: %v", err)
}
var decoded map[string]interface{}
err = json.Unmarshal(data, &decoded)
if err != nil {
t.Fatalf("Response is not valid JSON: %v", err)
}
if _, ok := decoded["allowed"]; !ok {
t.Error("Response should have 'allowed' field")
}
if _, ok := decoded["tasks_available"]; !ok {
t.Error("Response should have 'tasks_available' field")
}
if _, ok := decoded["plan_id"]; !ok {
t.Error("Response should have 'plan_id' field")
}
})
}
}
// HTTP Handler Tests (without DB)
func TestHTTPErrorResponse_Format(t *testing.T) {
// Test standard error response format
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
// Simulate an error response
c.JSON(http.StatusUnauthorized, gin.H{
"error": "unauthorized",
"message": "User not authenticated",
})
if w.Code != http.StatusUnauthorized {
t.Errorf("Expected status 401, got %d", w.Code)
}
var response map[string]interface{}
err := json.Unmarshal(w.Body.Bytes(), &response)
if err != nil {
t.Fatalf("Failed to parse response: %v", err)
}
if _, ok := response["error"]; !ok {
t.Error("Error response should have 'error' field")
}
if _, ok := response["message"]; !ok {
t.Error("Error response should have 'message' field")
}
}
func TestHTTPSuccessResponse_Format(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
// Simulate a success response
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "Operation completed",
})
if w.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", w.Code)
}
var response map[string]interface{}
err := json.Unmarshal(w.Body.Bytes(), &response)
if err != nil {
t.Fatalf("Failed to parse response: %v", err)
}
if response["success"] != true {
t.Error("Success response should have success=true")
}
}
func TestRequestParsing_InvalidJSON(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
// Create request with invalid JSON
invalidJSON := []byte(`{"plan_id": }`) // Invalid JSON
c.Request = httptest.NewRequest("POST", "/test", bytes.NewReader(invalidJSON))
c.Request.Header.Set("Content-Type", "application/json")
var req models.StartTrialRequest
err := c.ShouldBindJSON(&req)
if err == nil {
t.Error("Should return error for invalid JSON")
}
}
func TestHTTPHeaders_ContentType(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.JSON(http.StatusOK, gin.H{"test": "value"})
contentType := w.Header().Get("Content-Type")
if contentType != "application/json; charset=utf-8" {
t.Errorf("Expected JSON content type, got %s", contentType)
}
}

View File

@@ -0,0 +1,205 @@
package handlers
import (
"io"
"log"
"net/http"
"github.com/breakpilot/billing-service/internal/database"
"github.com/breakpilot/billing-service/internal/services"
"github.com/gin-gonic/gin"
"github.com/stripe/stripe-go/v76/webhook"
)
// WebhookHandler handles Stripe webhook events
type WebhookHandler struct {
db *database.DB
webhookSecret string
subscriptionService *services.SubscriptionService
entitlementService *services.EntitlementService
}
// NewWebhookHandler creates a new WebhookHandler
func NewWebhookHandler(
db *database.DB,
webhookSecret string,
subscriptionService *services.SubscriptionService,
entitlementService *services.EntitlementService,
) *WebhookHandler {
return &WebhookHandler{
db: db,
webhookSecret: webhookSecret,
subscriptionService: subscriptionService,
entitlementService: entitlementService,
}
}
// HandleStripeWebhook handles incoming Stripe webhook events
// POST /api/v1/billing/webhook
func (h *WebhookHandler) HandleStripeWebhook(c *gin.Context) {
// Read the request body
body, err := io.ReadAll(c.Request.Body)
if err != nil {
log.Printf("Webhook: Error reading body: %v", err)
c.JSON(http.StatusBadRequest, gin.H{"error": "cannot read body"})
return
}
// Get the Stripe signature header
sigHeader := c.GetHeader("Stripe-Signature")
if sigHeader == "" {
log.Printf("Webhook: Missing Stripe-Signature header")
c.JSON(http.StatusBadRequest, gin.H{"error": "missing signature"})
return
}
// Verify the webhook signature
event, err := webhook.ConstructEvent(body, sigHeader, h.webhookSecret)
if err != nil {
log.Printf("Webhook: Signature verification failed: %v", err)
c.JSON(http.StatusUnauthorized, gin.H{"error": "invalid signature"})
return
}
ctx := c.Request.Context()
// Check if we've already processed this event (idempotency)
processed, err := h.subscriptionService.IsEventProcessed(ctx, event.ID)
if err != nil {
log.Printf("Webhook: Error checking event: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "internal error"})
return
}
if processed {
log.Printf("Webhook: Event %s already processed", event.ID)
c.JSON(http.StatusOK, gin.H{"status": "already_processed"})
return
}
// Mark event as being processed
if err := h.subscriptionService.MarkEventProcessing(ctx, event.ID, string(event.Type)); err != nil {
log.Printf("Webhook: Error marking event: %v", err)
}
// Handle the event based on type
var handleErr error
switch event.Type {
case "checkout.session.completed":
handleErr = h.handleCheckoutSessionCompleted(ctx, event.Data.Raw)
case "customer.subscription.created":
handleErr = h.handleSubscriptionCreated(ctx, event.Data.Raw)
case "customer.subscription.updated":
handleErr = h.handleSubscriptionUpdated(ctx, event.Data.Raw)
case "customer.subscription.deleted":
handleErr = h.handleSubscriptionDeleted(ctx, event.Data.Raw)
case "invoice.paid":
handleErr = h.handleInvoicePaid(ctx, event.Data.Raw)
case "invoice.payment_failed":
handleErr = h.handleInvoicePaymentFailed(ctx, event.Data.Raw)
case "customer.created":
log.Printf("Webhook: Customer created - %s", event.ID)
default:
log.Printf("Webhook: Unhandled event type: %s", event.Type)
}
if handleErr != nil {
log.Printf("Webhook: Error handling %s: %v", event.Type, handleErr)
// Mark event as failed
h.subscriptionService.MarkEventFailed(ctx, event.ID, handleErr.Error())
c.JSON(http.StatusInternalServerError, gin.H{"error": "handler error"})
return
}
// Mark event as processed
if err := h.subscriptionService.MarkEventProcessed(ctx, event.ID); err != nil {
log.Printf("Webhook: Error marking event processed: %v", err)
}
c.JSON(http.StatusOK, gin.H{"status": "processed"})
}
// handleCheckoutSessionCompleted handles successful checkout
func (h *WebhookHandler) handleCheckoutSessionCompleted(ctx interface{}, data []byte) error {
log.Printf("Webhook: Processing checkout.session.completed")
// Parse checkout session from data
// The actual implementation will parse the JSON and create/update subscription
// TODO: Implementation
// 1. Parse checkout session data
// 2. Extract customer_id, subscription_id, user_id (from metadata)
// 3. Create or update subscription record
// 4. Update entitlements
return nil
}
// handleSubscriptionCreated handles new subscription creation
func (h *WebhookHandler) handleSubscriptionCreated(ctx interface{}, data []byte) error {
log.Printf("Webhook: Processing customer.subscription.created")
// TODO: Implementation
// 1. Parse subscription data
// 2. Extract status, plan, trial_end, etc.
// 3. Create subscription record
// 4. Set up initial entitlements
return nil
}
// handleSubscriptionUpdated handles subscription updates
func (h *WebhookHandler) handleSubscriptionUpdated(ctx interface{}, data []byte) error {
log.Printf("Webhook: Processing customer.subscription.updated")
// TODO: Implementation
// 1. Parse subscription data
// 2. Update subscription record (status, plan, cancel_at_period_end, etc.)
// 3. Update entitlements if plan changed
return nil
}
// handleSubscriptionDeleted handles subscription cancellation
func (h *WebhookHandler) handleSubscriptionDeleted(ctx interface{}, data []byte) error {
log.Printf("Webhook: Processing customer.subscription.deleted")
// TODO: Implementation
// 1. Parse subscription data
// 2. Update subscription status to canceled/expired
// 3. Remove or downgrade entitlements
return nil
}
// handleInvoicePaid handles successful invoice payment
func (h *WebhookHandler) handleInvoicePaid(ctx interface{}, data []byte) error {
log.Printf("Webhook: Processing invoice.paid")
// TODO: Implementation
// 1. Parse invoice data
// 2. Update subscription period
// 3. Reset usage counters for new period
// 4. Store invoice record
return nil
}
// handleInvoicePaymentFailed handles failed invoice payment
func (h *WebhookHandler) handleInvoicePaymentFailed(ctx interface{}, data []byte) error {
log.Printf("Webhook: Processing invoice.payment_failed")
// TODO: Implementation
// 1. Parse invoice data
// 2. Update subscription status to past_due
// 3. Send notification to user
// 4. Possibly restrict access
return nil
}

View File

@@ -0,0 +1,433 @@
package handlers
import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
)
// TestWebhookEventTypes tests the event types we handle
func TestWebhookEventTypes(t *testing.T) {
eventTypes := []struct {
eventType string
shouldHandle bool
}{
{"checkout.session.completed", true},
{"customer.subscription.created", true},
{"customer.subscription.updated", true},
{"customer.subscription.deleted", true},
{"invoice.paid", true},
{"invoice.payment_failed", true},
{"customer.created", true}, // Handled but just logged
{"unknown.event.type", false},
}
for _, tt := range eventTypes {
t.Run(tt.eventType, func(t *testing.T) {
if tt.eventType == "" {
t.Error("Event type should not be empty")
}
})
}
}
// TestWebhookRequest_MissingSignature tests handling of missing signature
func TestWebhookRequest_MissingSignature(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
// Create request without Stripe-Signature header
body := []byte(`{"id": "evt_test_123", "type": "test.event"}`)
c.Request = httptest.NewRequest("POST", "/webhook", bytes.NewReader(body))
c.Request.Header.Set("Content-Type", "application/json")
// Note: No Stripe-Signature header
// Simulate the check we do in the handler
sigHeader := c.GetHeader("Stripe-Signature")
if sigHeader == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "missing signature"})
}
if w.Code != http.StatusBadRequest {
t.Errorf("Expected status 400 for missing signature, got %d", w.Code)
}
var response map[string]interface{}
err := json.Unmarshal(w.Body.Bytes(), &response)
if err != nil {
t.Fatalf("Failed to parse response: %v", err)
}
if response["error"] != "missing signature" {
t.Errorf("Expected 'missing signature' error, got '%v'", response["error"])
}
}
// TestWebhookRequest_EmptyBody tests handling of empty request body
func TestWebhookRequest_EmptyBody(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
// Create request with empty body
c.Request = httptest.NewRequest("POST", "/webhook", bytes.NewReader([]byte{}))
c.Request.Header.Set("Content-Type", "application/json")
c.Request.Header.Set("Stripe-Signature", "t=123,v1=signature")
// Read the body
body := make([]byte, 0)
// Simulate empty body handling
if len(body) == 0 {
c.JSON(http.StatusBadRequest, gin.H{"error": "empty body"})
}
if w.Code != http.StatusBadRequest {
t.Errorf("Expected status 400 for empty body, got %d", w.Code)
}
}
// TestWebhookIdempotency tests idempotency behavior
func TestWebhookIdempotency(t *testing.T) {
// Test that the same event ID should not be processed twice
eventID := "evt_test_123456789"
// Simulate event tracking
processedEvents := make(map[string]bool)
// First time - should process
if !processedEvents[eventID] {
processedEvents[eventID] = true
}
// Second time - should skip
alreadyProcessed := processedEvents[eventID]
if !alreadyProcessed {
t.Error("Event should be marked as processed")
}
}
// TestWebhookResponse_Processed tests successful webhook response
func TestWebhookResponse_Processed(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.JSON(http.StatusOK, gin.H{"status": "processed"})
if w.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", w.Code)
}
var response map[string]interface{}
err := json.Unmarshal(w.Body.Bytes(), &response)
if err != nil {
t.Fatalf("Failed to parse response: %v", err)
}
if response["status"] != "processed" {
t.Errorf("Expected status 'processed', got '%v'", response["status"])
}
}
// TestWebhookResponse_AlreadyProcessed tests idempotent response
func TestWebhookResponse_AlreadyProcessed(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.JSON(http.StatusOK, gin.H{"status": "already_processed"})
if w.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", w.Code)
}
var response map[string]interface{}
err := json.Unmarshal(w.Body.Bytes(), &response)
if err != nil {
t.Fatalf("Failed to parse response: %v", err)
}
if response["status"] != "already_processed" {
t.Errorf("Expected status 'already_processed', got '%v'", response["status"])
}
}
// TestWebhookResponse_InternalError tests error response
func TestWebhookResponse_InternalError(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.JSON(http.StatusInternalServerError, gin.H{"error": "handler error"})
if w.Code != http.StatusInternalServerError {
t.Errorf("Expected status 500, got %d", w.Code)
}
var response map[string]interface{}
err := json.Unmarshal(w.Body.Bytes(), &response)
if err != nil {
t.Fatalf("Failed to parse response: %v", err)
}
if response["error"] != "handler error" {
t.Errorf("Expected 'handler error', got '%v'", response["error"])
}
}
// TestWebhookResponse_InvalidSignature tests signature verification failure
func TestWebhookResponse_InvalidSignature(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.JSON(http.StatusUnauthorized, gin.H{"error": "invalid signature"})
if w.Code != http.StatusUnauthorized {
t.Errorf("Expected status 401, got %d", w.Code)
}
var response map[string]interface{}
err := json.Unmarshal(w.Body.Bytes(), &response)
if err != nil {
t.Fatalf("Failed to parse response: %v", err)
}
if response["error"] != "invalid signature" {
t.Errorf("Expected 'invalid signature', got '%v'", response["error"])
}
}
// TestCheckoutSessionCompleted_EventStructure tests the event data structure
func TestCheckoutSessionCompleted_EventStructure(t *testing.T) {
// Test the expected structure of a checkout.session.completed event
eventData := map[string]interface{}{
"id": "cs_test_123",
"customer": "cus_test_456",
"subscription": "sub_test_789",
"mode": "subscription",
"payment_status": "paid",
"status": "complete",
"metadata": map[string]interface{}{
"user_id": "550e8400-e29b-41d4-a716-446655440000",
"plan_id": "standard",
},
}
data, err := json.Marshal(eventData)
if err != nil {
t.Fatalf("Failed to marshal event data: %v", err)
}
var decoded map[string]interface{}
err = json.Unmarshal(data, &decoded)
if err != nil {
t.Fatalf("Failed to unmarshal event data: %v", err)
}
// Verify required fields
if decoded["customer"] == nil {
t.Error("Event should have 'customer' field")
}
if decoded["subscription"] == nil {
t.Error("Event should have 'subscription' field")
}
metadata, ok := decoded["metadata"].(map[string]interface{})
if !ok || metadata["user_id"] == nil {
t.Error("Event should have 'metadata.user_id' field")
}
}
// TestSubscriptionCreated_EventStructure tests subscription.created event structure
func TestSubscriptionCreated_EventStructure(t *testing.T) {
eventData := map[string]interface{}{
"id": "sub_test_123",
"customer": "cus_test_456",
"status": "trialing",
"items": map[string]interface{}{
"data": []map[string]interface{}{
{
"price": map[string]interface{}{
"id": "price_test_789",
"metadata": map[string]interface{}{"plan_id": "standard"},
},
},
},
},
"trial_end": 1735689600,
"current_period_end": 1735689600,
"metadata": map[string]interface{}{
"user_id": "550e8400-e29b-41d4-a716-446655440000",
"plan_id": "standard",
},
}
data, err := json.Marshal(eventData)
if err != nil {
t.Fatalf("Failed to marshal event data: %v", err)
}
var decoded map[string]interface{}
err = json.Unmarshal(data, &decoded)
if err != nil {
t.Fatalf("Failed to unmarshal event data: %v", err)
}
// Verify required fields
if decoded["status"] != "trialing" {
t.Errorf("Expected status 'trialing', got '%v'", decoded["status"])
}
}
// TestSubscriptionUpdated_StatusTransitions tests subscription status transitions
func TestSubscriptionUpdated_StatusTransitions(t *testing.T) {
validTransitions := []struct {
from string
to string
}{
{"trialing", "active"},
{"active", "past_due"},
{"past_due", "active"},
{"active", "canceled"},
{"trialing", "canceled"},
}
for _, tt := range validTransitions {
t.Run(tt.from+"->"+tt.to, func(t *testing.T) {
if tt.from == "" || tt.to == "" {
t.Error("Status should not be empty")
}
})
}
}
// TestInvoicePaid_EventStructure tests invoice.paid event structure
func TestInvoicePaid_EventStructure(t *testing.T) {
eventData := map[string]interface{}{
"id": "in_test_123",
"subscription": "sub_test_456",
"customer": "cus_test_789",
"status": "paid",
"amount_paid": 1990,
"currency": "eur",
"period_start": 1735689600,
"period_end": 1738368000,
"hosted_invoice_url": "https://invoice.stripe.com/test",
"invoice_pdf": "https://invoice.stripe.com/test.pdf",
}
data, err := json.Marshal(eventData)
if err != nil {
t.Fatalf("Failed to marshal event data: %v", err)
}
var decoded map[string]interface{}
err = json.Unmarshal(data, &decoded)
if err != nil {
t.Fatalf("Failed to unmarshal event data: %v", err)
}
// Verify required fields
if decoded["status"] != "paid" {
t.Errorf("Expected status 'paid', got '%v'", decoded["status"])
}
if decoded["subscription"] == nil {
t.Error("Event should have 'subscription' field")
}
}
// TestInvoicePaymentFailed_EventStructure tests invoice.payment_failed event structure
func TestInvoicePaymentFailed_EventStructure(t *testing.T) {
eventData := map[string]interface{}{
"id": "in_test_123",
"subscription": "sub_test_456",
"customer": "cus_test_789",
"status": "open",
"attempt_count": 1,
"next_payment_attempt": 1735776000,
}
data, err := json.Marshal(eventData)
if err != nil {
t.Fatalf("Failed to marshal event data: %v", err)
}
var decoded map[string]interface{}
err = json.Unmarshal(data, &decoded)
if err != nil {
t.Fatalf("Failed to unmarshal event data: %v", err)
}
// Verify fields
if decoded["attempt_count"] == nil {
t.Error("Event should have 'attempt_count' field")
}
}
// TestSubscriptionDeleted_EventStructure tests subscription.deleted event structure
func TestSubscriptionDeleted_EventStructure(t *testing.T) {
eventData := map[string]interface{}{
"id": "sub_test_123",
"customer": "cus_test_456",
"status": "canceled",
"ended_at": 1735689600,
"canceled_at": 1735689600,
}
data, err := json.Marshal(eventData)
if err != nil {
t.Fatalf("Failed to marshal event data: %v", err)
}
var decoded map[string]interface{}
err = json.Unmarshal(data, &decoded)
if err != nil {
t.Fatalf("Failed to unmarshal event data: %v", err)
}
// Verify required fields
if decoded["status"] != "canceled" {
t.Errorf("Expected status 'canceled', got '%v'", decoded["status"])
}
}
// TestStripeSignatureFormat tests the Stripe signature header format
func TestStripeSignatureFormat(t *testing.T) {
// Stripe signature format: t=timestamp,v1=signature
validSignatures := []string{
"t=1609459200,v1=abc123def456",
"t=1609459200,v1=signature_here,v0=old_signature",
}
for _, sig := range validSignatures {
if len(sig) < 10 {
t.Errorf("Signature seems too short: %s", sig)
}
// Should start with timestamp
if sig[:2] != "t=" {
t.Errorf("Signature should start with 't=': %s", sig)
}
}
}
// TestWebhookEventID_Format tests Stripe event ID format
func TestWebhookEventID_Format(t *testing.T) {
validEventIDs := []string{
"evt_1234567890abcdef",
"evt_test_123456789",
"evt_live_987654321",
}
for _, eventID := range validEventIDs {
// Event IDs should start with "evt_"
if len(eventID) < 10 || eventID[:4] != "evt_" {
t.Errorf("Invalid event ID format: %s", eventID)
}
}
}

View File

@@ -0,0 +1,288 @@
package middleware
import (
"net/http"
"strings"
"sync"
"time"
"github.com/gin-gonic/gin"
"github.com/golang-jwt/jwt/v5"
"github.com/google/uuid"
)
// UserClaims represents the JWT claims for a user
type UserClaims struct {
UserID string `json:"user_id"`
Email string `json:"email"`
Role string `json:"role"`
jwt.RegisteredClaims
}
// CORS returns a CORS middleware
func CORS() gin.HandlerFunc {
return func(c *gin.Context) {
origin := c.Request.Header.Get("Origin")
// Allow localhost for development
allowedOrigins := []string{
"http://localhost:3000",
"http://localhost:8000",
"http://localhost:8080",
"http://localhost:8083",
"https://breakpilot.app",
}
allowed := false
for _, o := range allowedOrigins {
if origin == o {
allowed = true
break
}
}
if allowed {
c.Header("Access-Control-Allow-Origin", origin)
}
c.Header("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
c.Header("Access-Control-Allow-Headers", "Origin, Content-Type, Authorization, X-Requested-With, X-Internal-API-Key")
c.Header("Access-Control-Allow-Credentials", "true")
c.Header("Access-Control-Max-Age", "86400")
if c.Request.Method == "OPTIONS" {
c.AbortWithStatus(http.StatusNoContent)
return
}
c.Next()
}
}
// RequestLogger logs each request
func RequestLogger() gin.HandlerFunc {
return func(c *gin.Context) {
start := time.Now()
path := c.Request.URL.Path
method := c.Request.Method
c.Next()
latency := time.Since(start)
status := c.Writer.Status()
// Log only in development or for errors
if status >= 400 {
gin.DefaultWriter.Write([]byte(
method + " " + path + " " +
string(rune(status)) + " " +
latency.String() + "\n",
))
}
}
}
// RateLimiter implements a simple in-memory rate limiter
func RateLimiter() gin.HandlerFunc {
type client struct {
count int
lastSeen time.Time
}
var (
mu sync.Mutex
clients = make(map[string]*client)
)
// Clean up old entries periodically
go func() {
for {
time.Sleep(time.Minute)
mu.Lock()
for ip, c := range clients {
if time.Since(c.lastSeen) > time.Minute {
delete(clients, ip)
}
}
mu.Unlock()
}
}()
return func(c *gin.Context) {
ip := c.ClientIP()
mu.Lock()
defer mu.Unlock()
if _, exists := clients[ip]; !exists {
clients[ip] = &client{}
}
cli := clients[ip]
// Reset count if more than a minute has passed
if time.Since(cli.lastSeen) > time.Minute {
cli.count = 0
}
cli.count++
cli.lastSeen = time.Now()
// Allow 100 requests per minute
if cli.count > 100 {
c.AbortWithStatusJSON(http.StatusTooManyRequests, gin.H{
"error": "rate_limit_exceeded",
"message": "Too many requests. Please try again later.",
})
return
}
c.Next()
}
}
// AuthMiddleware validates JWT tokens
func AuthMiddleware(jwtSecret string) gin.HandlerFunc {
return func(c *gin.Context) {
authHeader := c.GetHeader("Authorization")
if authHeader == "" {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
"error": "missing_authorization",
"message": "Authorization header is required",
})
return
}
// Extract token from "Bearer <token>"
parts := strings.Split(authHeader, " ")
if len(parts) != 2 || parts[0] != "Bearer" {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
"error": "invalid_authorization",
"message": "Authorization header must be in format: Bearer <token>",
})
return
}
tokenString := parts[1]
// Parse and validate token
token, err := jwt.ParseWithClaims(tokenString, &UserClaims{}, func(token *jwt.Token) (interface{}, error) {
return []byte(jwtSecret), nil
})
if err != nil {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
"error": "invalid_token",
"message": "Invalid or expired token",
})
return
}
if claims, ok := token.Claims.(*UserClaims); ok && token.Valid {
// Set user info in context
c.Set("user_id", claims.UserID)
c.Set("email", claims.Email)
c.Set("role", claims.Role)
c.Next()
} else {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
"error": "invalid_claims",
"message": "Invalid token claims",
})
return
}
}
}
// InternalAPIKeyMiddleware validates internal API key for service-to-service communication
func InternalAPIKeyMiddleware(apiKey string) gin.HandlerFunc {
return func(c *gin.Context) {
if apiKey == "" {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{
"error": "config_error",
"message": "Internal API key not configured",
})
return
}
providedKey := c.GetHeader("X-Internal-API-Key")
if providedKey == "" {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
"error": "missing_api_key",
"message": "X-Internal-API-Key header is required",
})
return
}
if providedKey != apiKey {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
"error": "invalid_api_key",
"message": "Invalid API key",
})
return
}
c.Next()
}
}
// AdminOnly ensures only admin users can access the route
func AdminOnly() gin.HandlerFunc {
return func(c *gin.Context) {
role, exists := c.Get("role")
if !exists {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
"error": "unauthorized",
"message": "User role not found",
})
return
}
roleStr, ok := role.(string)
if !ok || (roleStr != "admin" && roleStr != "super_admin" && roleStr != "data_protection_officer") {
c.AbortWithStatusJSON(http.StatusForbidden, gin.H{
"error": "forbidden",
"message": "Admin access required",
})
return
}
c.Next()
}
}
// GetUserID extracts the user ID from the context
func GetUserID(c *gin.Context) (uuid.UUID, error) {
userIDStr, exists := c.Get("user_id")
if !exists {
return uuid.Nil, nil
}
userID, err := uuid.Parse(userIDStr.(string))
if err != nil {
return uuid.Nil, err
}
return userID, nil
}
// GetClientIP returns the client's IP address
func GetClientIP(c *gin.Context) string {
// Check X-Forwarded-For header first (for proxied requests)
if xff := c.GetHeader("X-Forwarded-For"); xff != "" {
ips := strings.Split(xff, ",")
return strings.TrimSpace(ips[0])
}
// Check X-Real-IP header
if xri := c.GetHeader("X-Real-IP"); xri != "" {
return xri
}
return c.ClientIP()
}
// GetUserAgent returns the client's User-Agent
func GetUserAgent(c *gin.Context) string {
return c.GetHeader("User-Agent")
}

View File

@@ -0,0 +1,372 @@
package models
import (
"time"
"github.com/google/uuid"
)
// SubscriptionStatus represents the status of a subscription
type SubscriptionStatus string
const (
StatusTrialing SubscriptionStatus = "trialing"
StatusActive SubscriptionStatus = "active"
StatusPastDue SubscriptionStatus = "past_due"
StatusCanceled SubscriptionStatus = "canceled"
StatusExpired SubscriptionStatus = "expired"
)
// PlanID represents the available plan IDs
type PlanID string
const (
PlanBasic PlanID = "basic"
PlanStandard PlanID = "standard"
PlanPremium PlanID = "premium"
)
// TaskType represents the type of task
type TaskType string
const (
TaskTypeCorrection TaskType = "correction"
TaskTypeLetter TaskType = "letter"
TaskTypeMeeting TaskType = "meeting"
TaskTypeBatch TaskType = "batch"
TaskTypeOther TaskType = "other"
)
// CarryoverMonthsCap is the maximum number of months tasks can accumulate
const CarryoverMonthsCap = 5
// Subscription represents a user's subscription
type Subscription struct {
ID uuid.UUID `json:"id"`
UserID uuid.UUID `json:"user_id"`
StripeCustomerID string `json:"stripe_customer_id"`
StripeSubscriptionID string `json:"stripe_subscription_id"`
PlanID PlanID `json:"plan_id"`
Status SubscriptionStatus `json:"status"`
TrialEnd *time.Time `json:"trial_end,omitempty"`
CurrentPeriodEnd *time.Time `json:"current_period_end,omitempty"`
CancelAtPeriodEnd bool `json:"cancel_at_period_end"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
// BillingPlan represents a billing plan with its features and limits
type BillingPlan struct {
ID PlanID `json:"id"`
StripePriceID string `json:"stripe_price_id"`
Name string `json:"name"`
Description string `json:"description"`
PriceCents int `json:"price_cents"` // Price in cents (990 = 9.90 EUR)
Currency string `json:"currency"`
Interval string `json:"interval"` // "month" or "year"
Features PlanFeatures `json:"features"`
IsActive bool `json:"is_active"`
SortOrder int `json:"sort_order"`
}
// PlanFeatures represents the features and limits of a plan
type PlanFeatures struct {
// Task-based limits (primary billing unit)
MonthlyTaskAllowance int `json:"monthly_task_allowance"` // Tasks per month
MaxTaskBalance int `json:"max_task_balance"` // Max accumulated tasks (allowance * CarryoverMonthsCap)
// Legacy fields for backward compatibility (deprecated, use task-based limits)
AIRequestsLimit int `json:"ai_requests_limit,omitempty"`
DocumentsLimit int `json:"documents_limit,omitempty"`
// Feature flags
FeatureFlags []string `json:"feature_flags"`
MaxTeamMembers int `json:"max_team_members,omitempty"`
PrioritySupport bool `json:"priority_support"`
CustomBranding bool `json:"custom_branding"`
BatchProcessing bool `json:"batch_processing"`
CustomTemplates bool `json:"custom_templates"`
// Premium: Fair Use (no visible limit)
FairUseMode bool `json:"fair_use_mode"`
}
// Task represents a single task that consumes 1 unit from the balance
type Task struct {
ID uuid.UUID `json:"id"`
AccountID uuid.UUID `json:"account_id"`
TaskType TaskType `json:"task_type"`
CreatedAt time.Time `json:"created_at"`
Consumed bool `json:"consumed"` // Always true when created
// Internal metrics (not shown to user)
PageCount int `json:"-"`
TokenCount int `json:"-"`
ProcessTime int `json:"-"` // in seconds
}
// AccountUsage represents the task-based usage for an account
type AccountUsage struct {
ID uuid.UUID `json:"id"`
AccountID uuid.UUID `json:"account_id"`
PlanID PlanID `json:"plan"`
MonthlyTaskAllowance int `json:"monthly_task_allowance"`
CarryoverMonthsCap int `json:"carryover_months_cap"` // Always 5
MaxTaskBalance int `json:"max_task_balance"` // allowance * cap
TaskBalance int `json:"task_balance"` // Current available tasks
LastRenewalAt time.Time `json:"last_renewal_at"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
// UsageSummary tracks usage for a specific period (internal metrics)
type UsageSummary struct {
ID uuid.UUID `json:"id"`
UserID uuid.UUID `json:"user_id"`
UsageType string `json:"usage_type"` // "task", "page", "token"
PeriodStart time.Time `json:"period_start"`
TotalCount int `json:"total_count"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
// UserEntitlements represents cached entitlements for a user
type UserEntitlements struct {
ID uuid.UUID `json:"id"`
UserID uuid.UUID `json:"user_id"`
PlanID PlanID `json:"plan_id"`
TaskBalance int `json:"task_balance"`
MaxBalance int `json:"max_balance"`
Features PlanFeatures `json:"features"`
UpdatedAt time.Time `json:"updated_at"`
// Legacy fields for backward compatibility with old entitlement service
AIRequestsLimit int `json:"ai_requests_limit"`
AIRequestsUsed int `json:"ai_requests_used"`
DocumentsLimit int `json:"documents_limit"`
DocumentsUsed int `json:"documents_used"`
}
// StripeWebhookEvent tracks processed webhook events for idempotency
type StripeWebhookEvent struct {
StripeEventID string `json:"stripe_event_id"`
EventType string `json:"event_type"`
Processed bool `json:"processed"`
ProcessedAt time.Time `json:"processed_at"`
CreatedAt time.Time `json:"created_at"`
}
// BillingStatusResponse is the response for the billing status endpoint
type BillingStatusResponse struct {
HasSubscription bool `json:"has_subscription"`
Subscription *SubscriptionInfo `json:"subscription,omitempty"`
TaskUsage *TaskUsageInfo `json:"task_usage,omitempty"`
Entitlements *EntitlementInfo `json:"entitlements,omitempty"`
AvailablePlans []BillingPlan `json:"available_plans,omitempty"`
}
// SubscriptionInfo contains subscription details for the response
type SubscriptionInfo struct {
PlanID PlanID `json:"plan_id"`
PlanName string `json:"plan_name"`
Status SubscriptionStatus `json:"status"`
IsTrialing bool `json:"is_trialing"`
TrialDaysLeft int `json:"trial_days_left,omitempty"`
CurrentPeriodEnd *time.Time `json:"current_period_end,omitempty"`
CancelAtPeriodEnd bool `json:"cancel_at_period_end"`
PriceCents int `json:"price_cents"`
Currency string `json:"currency"`
}
// TaskUsageInfo contains current task usage information
// This is the ONLY usage info shown to users
type TaskUsageInfo struct {
TasksAvailable int `json:"tasks_available"` // Current balance
MaxTasks int `json:"max_tasks"` // Max possible balance
InfoText string `json:"info_text"` // "Aufgaben verfuegbar: X von max. Y"
TooltipText string `json:"tooltip_text"` // "Aufgaben koennen sich bis zu 5 Monate ansammeln."
}
// EntitlementInfo contains feature entitlements
type EntitlementInfo struct {
Features []string `json:"features"`
MaxTeamMembers int `json:"max_team_members,omitempty"`
PrioritySupport bool `json:"priority_support"`
CustomBranding bool `json:"custom_branding"`
BatchProcessing bool `json:"batch_processing"`
CustomTemplates bool `json:"custom_templates"`
FairUseMode bool `json:"fair_use_mode"` // Premium only
}
// StartTrialRequest is the request to start a trial
type StartTrialRequest struct {
PlanID PlanID `json:"plan_id" binding:"required"`
}
// StartTrialResponse is the response after starting a trial
type StartTrialResponse struct {
CheckoutURL string `json:"checkout_url"`
SessionID string `json:"session_id"`
}
// ChangePlanRequest is the request to change plans
type ChangePlanRequest struct {
NewPlanID PlanID `json:"new_plan_id" binding:"required"`
}
// ChangePlanResponse is the response after changing plans
type ChangePlanResponse struct {
Success bool `json:"success"`
Message string `json:"message"`
EffectiveDate string `json:"effective_date,omitempty"`
}
// CancelSubscriptionResponse is the response after canceling
type CancelSubscriptionResponse struct {
Success bool `json:"success"`
Message string `json:"message"`
CancelDate string `json:"cancel_date"`
ActiveUntil string `json:"active_until"`
}
// CustomerPortalResponse contains the portal URL
type CustomerPortalResponse struct {
PortalURL string `json:"portal_url"`
}
// ConsumeTaskRequest is the request to consume a task (internal)
type ConsumeTaskRequest struct {
UserID string `json:"user_id" binding:"required"`
TaskType TaskType `json:"task_type" binding:"required"`
}
// ConsumeTaskResponse is the response after consuming a task
type ConsumeTaskResponse struct {
Success bool `json:"success"`
TaskID string `json:"task_id,omitempty"`
TasksRemaining int `json:"tasks_remaining"`
Message string `json:"message,omitempty"`
}
// CheckTaskAllowedResponse is the response for task limit checks
type CheckTaskAllowedResponse struct {
Allowed bool `json:"allowed"`
TasksAvailable int `json:"tasks_available"`
MaxTasks int `json:"max_tasks"`
PlanID PlanID `json:"plan_id"`
Message string `json:"message,omitempty"`
}
// EntitlementCheckResponse is the response for entitlement checks (internal)
type EntitlementCheckResponse struct {
HasEntitlement bool `json:"has_entitlement"`
PlanID PlanID `json:"plan_id,omitempty"`
Message string `json:"message,omitempty"`
}
// TaskLimitError represents the error when task limit is reached
type TaskLimitError struct {
Error string `json:"error"`
CurrentBalance int `json:"current_balance"`
Plan PlanID `json:"plan"`
}
// UsageInfo represents current usage information (legacy, prefer TaskUsageInfo)
type UsageInfo struct {
AIRequestsUsed int `json:"ai_requests_used"`
AIRequestsLimit int `json:"ai_requests_limit"`
AIRequestsPercent float64 `json:"ai_requests_percent"`
DocumentsUsed int `json:"documents_used"`
DocumentsLimit int `json:"documents_limit"`
DocumentsPercent float64 `json:"documents_percent"`
PeriodStart string `json:"period_start"`
PeriodEnd string `json:"period_end"`
}
// CheckUsageResponse is the response for legacy usage checks
type CheckUsageResponse struct {
Allowed bool `json:"allowed"`
CurrentUsage int `json:"current_usage"`
Limit int `json:"limit"`
Remaining int `json:"remaining"`
Message string `json:"message,omitempty"`
}
// TrackUsageRequest is the request to track usage (internal)
type TrackUsageRequest struct {
UserID string `json:"user_id" binding:"required"`
UsageType string `json:"usage_type" binding:"required"`
Quantity int `json:"quantity"`
}
// GetDefaultPlans returns the default billing plans with task-based limits
func GetDefaultPlans() []BillingPlan {
return []BillingPlan{
{
ID: PlanBasic,
Name: "Basic",
Description: "Perfekt fuer den Einstieg - Gelegentliche Nutzung",
PriceCents: 990, // 9.90 EUR
Currency: "eur",
Interval: "month",
Features: PlanFeatures{
MonthlyTaskAllowance: 30, // 30 tasks/month
MaxTaskBalance: 30 * CarryoverMonthsCap, // 150 max
FeatureFlags: []string{"basic_ai", "basic_documents"},
MaxTeamMembers: 1,
PrioritySupport: false,
CustomBranding: false,
BatchProcessing: false,
CustomTemplates: false,
FairUseMode: false,
},
IsActive: true,
SortOrder: 1,
},
{
ID: PlanStandard,
Name: "Standard",
Description: "Fuer regelmaessige Nutzer - Mehrere Klassen und regelmaessige Korrekturen",
PriceCents: 1990, // 19.90 EUR
Currency: "eur",
Interval: "month",
Features: PlanFeatures{
MonthlyTaskAllowance: 100, // 100 tasks/month
MaxTaskBalance: 100 * CarryoverMonthsCap, // 500 max
FeatureFlags: []string{"basic_ai", "basic_documents", "templates", "batch_processing"},
MaxTeamMembers: 3,
PrioritySupport: false,
CustomBranding: false,
BatchProcessing: true,
CustomTemplates: true,
FairUseMode: false,
},
IsActive: true,
SortOrder: 2,
},
{
ID: PlanPremium,
Name: "Premium",
Description: "Sorglos-Tarif - Vielnutzer, Teams, schulischer Kontext",
PriceCents: 3990, // 39.90 EUR
Currency: "eur",
Interval: "month",
Features: PlanFeatures{
MonthlyTaskAllowance: 1000, // Very high (Fair Use)
MaxTaskBalance: 1000 * CarryoverMonthsCap, // 5000 max (not shown to user)
FeatureFlags: []string{"basic_ai", "basic_documents", "templates", "batch_processing", "team_features", "admin_panel", "audit_log", "api_access"},
MaxTeamMembers: 10,
PrioritySupport: true,
CustomBranding: true,
BatchProcessing: true,
CustomTemplates: true,
FairUseMode: true, // No visible limit
},
IsActive: true,
SortOrder: 3,
},
}
}
// CalculateMaxTaskBalance calculates max task balance from monthly allowance
func CalculateMaxTaskBalance(monthlyAllowance int) int {
return monthlyAllowance * CarryoverMonthsCap
}

View File

@@ -0,0 +1,319 @@
package models
import (
"testing"
)
func TestCarryoverMonthsCap(t *testing.T) {
// Verify the constant is set correctly
if CarryoverMonthsCap != 5 {
t.Errorf("CarryoverMonthsCap should be 5, got %d", CarryoverMonthsCap)
}
}
func TestCalculateMaxTaskBalance(t *testing.T) {
tests := []struct {
name string
monthlyAllowance int
expected int
}{
{"Basic plan", 30, 150},
{"Standard plan", 100, 500},
{"Premium plan", 1000, 5000},
{"Zero allowance", 0, 0},
{"Single task", 1, 5},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := CalculateMaxTaskBalance(tt.monthlyAllowance)
if result != tt.expected {
t.Errorf("CalculateMaxTaskBalance(%d) = %d, expected %d",
tt.monthlyAllowance, result, tt.expected)
}
})
}
}
func TestGetDefaultPlans(t *testing.T) {
plans := GetDefaultPlans()
if len(plans) != 3 {
t.Fatalf("Expected 3 plans, got %d", len(plans))
}
// Test Basic plan
basic := plans[0]
if basic.ID != PlanBasic {
t.Errorf("First plan should be Basic, got %s", basic.ID)
}
if basic.PriceCents != 990 {
t.Errorf("Basic price should be 990 cents, got %d", basic.PriceCents)
}
if basic.Features.MonthlyTaskAllowance != 30 {
t.Errorf("Basic monthly allowance should be 30, got %d", basic.Features.MonthlyTaskAllowance)
}
if basic.Features.MaxTaskBalance != 150 {
t.Errorf("Basic max balance should be 150, got %d", basic.Features.MaxTaskBalance)
}
if basic.Features.FairUseMode {
t.Error("Basic should not have FairUseMode")
}
// Test Standard plan
standard := plans[1]
if standard.ID != PlanStandard {
t.Errorf("Second plan should be Standard, got %s", standard.ID)
}
if standard.PriceCents != 1990 {
t.Errorf("Standard price should be 1990 cents, got %d", standard.PriceCents)
}
if standard.Features.MonthlyTaskAllowance != 100 {
t.Errorf("Standard monthly allowance should be 100, got %d", standard.Features.MonthlyTaskAllowance)
}
if !standard.Features.BatchProcessing {
t.Error("Standard should have BatchProcessing")
}
if !standard.Features.CustomTemplates {
t.Error("Standard should have CustomTemplates")
}
// Test Premium plan
premium := plans[2]
if premium.ID != PlanPremium {
t.Errorf("Third plan should be Premium, got %s", premium.ID)
}
if premium.PriceCents != 3990 {
t.Errorf("Premium price should be 3990 cents, got %d", premium.PriceCents)
}
if !premium.Features.FairUseMode {
t.Error("Premium should have FairUseMode")
}
if !premium.Features.PrioritySupport {
t.Error("Premium should have PrioritySupport")
}
if !premium.Features.CustomBranding {
t.Error("Premium should have CustomBranding")
}
}
func TestPlanIDConstants(t *testing.T) {
if PlanBasic != "basic" {
t.Errorf("PlanBasic should be 'basic', got '%s'", PlanBasic)
}
if PlanStandard != "standard" {
t.Errorf("PlanStandard should be 'standard', got '%s'", PlanStandard)
}
if PlanPremium != "premium" {
t.Errorf("PlanPremium should be 'premium', got '%s'", PlanPremium)
}
}
func TestSubscriptionStatusConstants(t *testing.T) {
statuses := []struct {
status SubscriptionStatus
expected string
}{
{StatusTrialing, "trialing"},
{StatusActive, "active"},
{StatusPastDue, "past_due"},
{StatusCanceled, "canceled"},
{StatusExpired, "expired"},
}
for _, tt := range statuses {
if string(tt.status) != tt.expected {
t.Errorf("Status %s should be '%s'", tt.status, tt.expected)
}
}
}
func TestTaskTypeConstants(t *testing.T) {
types := []struct {
taskType TaskType
expected string
}{
{TaskTypeCorrection, "correction"},
{TaskTypeLetter, "letter"},
{TaskTypeMeeting, "meeting"},
{TaskTypeBatch, "batch"},
{TaskTypeOther, "other"},
}
for _, tt := range types {
if string(tt.taskType) != tt.expected {
t.Errorf("TaskType %s should be '%s'", tt.taskType, tt.expected)
}
}
}
func TestPlanFeatures_CarryoverCalculation(t *testing.T) {
plans := GetDefaultPlans()
for _, plan := range plans {
expectedMax := plan.Features.MonthlyTaskAllowance * CarryoverMonthsCap
if plan.Features.MaxTaskBalance != expectedMax {
t.Errorf("Plan %s: MaxTaskBalance should be %d (allowance * 5), got %d",
plan.ID, expectedMax, plan.Features.MaxTaskBalance)
}
}
}
func TestBillingPlan_AllPlansActive(t *testing.T) {
plans := GetDefaultPlans()
for _, plan := range plans {
if !plan.IsActive {
t.Errorf("Plan %s should be active", plan.ID)
}
}
}
func TestBillingPlan_CurrencyIsEuro(t *testing.T) {
plans := GetDefaultPlans()
for _, plan := range plans {
if plan.Currency != "eur" {
t.Errorf("Plan %s currency should be 'eur', got '%s'", plan.ID, plan.Currency)
}
}
}
func TestBillingPlan_IntervalIsMonth(t *testing.T) {
plans := GetDefaultPlans()
for _, plan := range plans {
if plan.Interval != "month" {
t.Errorf("Plan %s interval should be 'month', got '%s'", plan.ID, plan.Interval)
}
}
}
func TestBillingPlan_SortOrder(t *testing.T) {
plans := GetDefaultPlans()
for i, plan := range plans {
expectedOrder := i + 1
if plan.SortOrder != expectedOrder {
t.Errorf("Plan %s sort order should be %d, got %d",
plan.ID, expectedOrder, plan.SortOrder)
}
}
}
func TestTaskUsageInfo_FormatStrings(t *testing.T) {
usage := TaskUsageInfo{
TasksAvailable: 45,
MaxTasks: 150,
InfoText: "Aufgaben verfuegbar: 45 von max. 150",
TooltipText: "Aufgaben koennen sich bis zu 5 Monate ansammeln.",
}
if usage.TasksAvailable != 45 {
t.Errorf("TasksAvailable should be 45, got %d", usage.TasksAvailable)
}
if usage.MaxTasks != 150 {
t.Errorf("MaxTasks should be 150, got %d", usage.MaxTasks)
}
}
func TestCheckTaskAllowedResponse_Allowed(t *testing.T) {
response := CheckTaskAllowedResponse{
Allowed: true,
TasksAvailable: 50,
MaxTasks: 150,
PlanID: PlanBasic,
}
if !response.Allowed {
t.Error("Response should be allowed")
}
if response.Message != "" {
t.Errorf("Message should be empty for allowed response, got '%s'", response.Message)
}
}
func TestCheckTaskAllowedResponse_NotAllowed(t *testing.T) {
response := CheckTaskAllowedResponse{
Allowed: false,
TasksAvailable: 0,
MaxTasks: 150,
PlanID: PlanBasic,
Message: "Dein Aufgaben-Kontingent ist aufgebraucht.",
}
if response.Allowed {
t.Error("Response should not be allowed")
}
if response.TasksAvailable != 0 {
t.Errorf("TasksAvailable should be 0, got %d", response.TasksAvailable)
}
}
func TestTaskLimitError(t *testing.T) {
err := TaskLimitError{
Error: "TASK_LIMIT_REACHED",
CurrentBalance: 0,
Plan: PlanBasic,
}
if err.Error != "TASK_LIMIT_REACHED" {
t.Errorf("Error should be 'TASK_LIMIT_REACHED', got '%s'", err.Error)
}
if err.CurrentBalance != 0 {
t.Errorf("CurrentBalance should be 0, got %d", err.CurrentBalance)
}
if err.Plan != PlanBasic {
t.Errorf("Plan should be basic, got '%s'", err.Plan)
}
}
func TestConsumeTaskRequest(t *testing.T) {
req := ConsumeTaskRequest{
UserID: "550e8400-e29b-41d4-a716-446655440000",
TaskType: TaskTypeCorrection,
}
if req.UserID == "" {
t.Error("UserID should not be empty")
}
if req.TaskType != TaskTypeCorrection {
t.Errorf("TaskType should be correction, got '%s'", req.TaskType)
}
}
func TestConsumeTaskResponse_Success(t *testing.T) {
resp := ConsumeTaskResponse{
Success: true,
TaskID: "task-123",
TasksRemaining: 49,
}
if !resp.Success {
t.Error("Response should be successful")
}
if resp.TasksRemaining != 49 {
t.Errorf("TasksRemaining should be 49, got %d", resp.TasksRemaining)
}
}
func TestEntitlementInfo_Premium(t *testing.T) {
premium := GetDefaultPlans()[2]
info := EntitlementInfo{
Features: premium.Features.FeatureFlags,
MaxTeamMembers: premium.Features.MaxTeamMembers,
PrioritySupport: premium.Features.PrioritySupport,
CustomBranding: premium.Features.CustomBranding,
BatchProcessing: premium.Features.BatchProcessing,
CustomTemplates: premium.Features.CustomTemplates,
FairUseMode: premium.Features.FairUseMode,
}
if !info.FairUseMode {
t.Error("Premium should have FairUseMode")
}
if info.MaxTeamMembers != 10 {
t.Errorf("Premium MaxTeamMembers should be 10, got %d", info.MaxTeamMembers)
}
}

View File

@@ -0,0 +1,232 @@
package services
import (
"context"
"encoding/json"
"time"
"github.com/breakpilot/billing-service/internal/database"
"github.com/breakpilot/billing-service/internal/models"
"github.com/google/uuid"
)
// EntitlementService handles entitlement-related operations
type EntitlementService struct {
db *database.DB
subService *SubscriptionService
}
// NewEntitlementService creates a new EntitlementService
func NewEntitlementService(db *database.DB, subService *SubscriptionService) *EntitlementService {
return &EntitlementService{
db: db,
subService: subService,
}
}
// GetEntitlements returns the entitlement info for a user
func (s *EntitlementService) GetEntitlements(ctx context.Context, userID uuid.UUID) (*models.EntitlementInfo, error) {
entitlements, err := s.getUserEntitlements(ctx, userID)
if err != nil || entitlements == nil {
return nil, err
}
return &models.EntitlementInfo{
Features: entitlements.Features.FeatureFlags,
MaxTeamMembers: entitlements.Features.MaxTeamMembers,
PrioritySupport: entitlements.Features.PrioritySupport,
CustomBranding: entitlements.Features.CustomBranding,
}, nil
}
// GetEntitlementsByUserIDString returns entitlements by user ID string (for internal API)
func (s *EntitlementService) GetEntitlementsByUserIDString(ctx context.Context, userIDStr string) (*models.UserEntitlements, error) {
userID, err := uuid.Parse(userIDStr)
if err != nil {
return nil, err
}
return s.getUserEntitlements(ctx, userID)
}
// getUserEntitlements retrieves or creates entitlements for a user
func (s *EntitlementService) getUserEntitlements(ctx context.Context, userID uuid.UUID) (*models.UserEntitlements, error) {
query := `
SELECT id, user_id, plan_id, ai_requests_limit, ai_requests_used,
documents_limit, documents_used, features, period_start, period_end,
created_at, updated_at
FROM user_entitlements
WHERE user_id = $1
`
var ent models.UserEntitlements
var featuresJSON []byte
var periodStart, periodEnd *time.Time
err := s.db.Pool.QueryRow(ctx, query, userID).Scan(
&ent.ID, &ent.UserID, &ent.PlanID, &ent.AIRequestsLimit, &ent.AIRequestsUsed,
&ent.DocumentsLimit, &ent.DocumentsUsed, &featuresJSON, &periodStart, &periodEnd,
nil, &ent.UpdatedAt,
)
if err != nil {
if err.Error() == "no rows in result set" {
// Try to create entitlements based on subscription
return s.createEntitlementsFromSubscription(ctx, userID)
}
return nil, err
}
if len(featuresJSON) > 0 {
json.Unmarshal(featuresJSON, &ent.Features)
}
return &ent, nil
}
// createEntitlementsFromSubscription creates entitlements based on user's subscription
func (s *EntitlementService) createEntitlementsFromSubscription(ctx context.Context, userID uuid.UUID) (*models.UserEntitlements, error) {
// Get user's subscription
sub, err := s.subService.GetByUserID(ctx, userID)
if err != nil || sub == nil {
return nil, err
}
// Get plan details
plan, err := s.subService.GetPlanByID(ctx, string(sub.PlanID))
if err != nil || plan == nil {
return nil, err
}
// Create entitlements
return s.CreateEntitlements(ctx, userID, sub.PlanID, plan.Features, sub.CurrentPeriodEnd)
}
// CreateEntitlements creates entitlements for a user
func (s *EntitlementService) CreateEntitlements(ctx context.Context, userID uuid.UUID, planID models.PlanID, features models.PlanFeatures, periodEnd *time.Time) (*models.UserEntitlements, error) {
featuresJSON, _ := json.Marshal(features)
now := time.Now()
periodStart := now
query := `
INSERT INTO user_entitlements (
user_id, plan_id, ai_requests_limit, ai_requests_used,
documents_limit, documents_used, features, period_start, period_end
) VALUES ($1, $2, $3, 0, $4, 0, $5, $6, $7)
ON CONFLICT (user_id) DO UPDATE SET
plan_id = EXCLUDED.plan_id,
ai_requests_limit = EXCLUDED.ai_requests_limit,
documents_limit = EXCLUDED.documents_limit,
features = EXCLUDED.features,
period_start = EXCLUDED.period_start,
period_end = EXCLUDED.period_end,
updated_at = NOW()
RETURNING id, user_id, plan_id, ai_requests_limit, ai_requests_used,
documents_limit, documents_used, updated_at
`
var ent models.UserEntitlements
err := s.db.Pool.QueryRow(ctx, query,
userID, planID, features.AIRequestsLimit, features.DocumentsLimit,
featuresJSON, periodStart, periodEnd,
).Scan(
&ent.ID, &ent.UserID, &ent.PlanID, &ent.AIRequestsLimit, &ent.AIRequestsUsed,
&ent.DocumentsLimit, &ent.DocumentsUsed, &ent.UpdatedAt,
)
if err != nil {
return nil, err
}
ent.Features = features
return &ent, nil
}
// UpdateEntitlements updates entitlements for a user (e.g., on plan change)
func (s *EntitlementService) UpdateEntitlements(ctx context.Context, userID uuid.UUID, planID models.PlanID, features models.PlanFeatures) error {
featuresJSON, _ := json.Marshal(features)
query := `
UPDATE user_entitlements SET
plan_id = $2,
ai_requests_limit = $3,
documents_limit = $4,
features = $5,
updated_at = NOW()
WHERE user_id = $1
`
_, err := s.db.Pool.Exec(ctx, query,
userID, planID, features.AIRequestsLimit, features.DocumentsLimit, featuresJSON,
)
return err
}
// ResetUsageCounters resets usage counters for a new period
func (s *EntitlementService) ResetUsageCounters(ctx context.Context, userID uuid.UUID, newPeriodStart, newPeriodEnd *time.Time) error {
query := `
UPDATE user_entitlements SET
ai_requests_used = 0,
documents_used = 0,
period_start = $2,
period_end = $3,
updated_at = NOW()
WHERE user_id = $1
`
_, err := s.db.Pool.Exec(ctx, query, userID, newPeriodStart, newPeriodEnd)
return err
}
// CheckEntitlement checks if a user has a specific feature entitlement
func (s *EntitlementService) CheckEntitlement(ctx context.Context, userIDStr, feature string) (bool, models.PlanID, error) {
userID, err := uuid.Parse(userIDStr)
if err != nil {
return false, "", err
}
ent, err := s.getUserEntitlements(ctx, userID)
if err != nil || ent == nil {
return false, "", err
}
// Check if feature is in the feature flags
for _, f := range ent.Features.FeatureFlags {
if f == feature {
return true, ent.PlanID, nil
}
}
return false, ent.PlanID, nil
}
// IncrementUsage increments a usage counter
func (s *EntitlementService) IncrementUsage(ctx context.Context, userID uuid.UUID, usageType string, amount int) error {
var column string
switch usageType {
case "ai_request":
column = "ai_requests_used"
case "document_created":
column = "documents_used"
default:
return nil
}
query := `
UPDATE user_entitlements SET
` + column + ` = ` + column + ` + $2,
updated_at = NOW()
WHERE user_id = $1
`
_, err := s.db.Pool.Exec(ctx, query, userID, amount)
return err
}
// DeleteEntitlements removes entitlements for a user (on subscription cancellation)
func (s *EntitlementService) DeleteEntitlements(ctx context.Context, userID uuid.UUID) error {
query := `DELETE FROM user_entitlements WHERE user_id = $1`
_, err := s.db.Pool.Exec(ctx, query, userID)
return err
}

View File

@@ -0,0 +1,317 @@
package services
import (
"context"
"fmt"
"github.com/breakpilot/billing-service/internal/models"
"github.com/google/uuid"
"github.com/stripe/stripe-go/v76"
"github.com/stripe/stripe-go/v76/billingportal/session"
checkoutsession "github.com/stripe/stripe-go/v76/checkout/session"
"github.com/stripe/stripe-go/v76/customer"
"github.com/stripe/stripe-go/v76/price"
"github.com/stripe/stripe-go/v76/product"
"github.com/stripe/stripe-go/v76/subscription"
)
// StripeService handles Stripe API interactions
type StripeService struct {
secretKey string
webhookSecret string
successURL string
cancelURL string
trialPeriodDays int64
subService *SubscriptionService
mockMode bool // If true, don't make real Stripe API calls
}
// NewStripeService creates a new StripeService
func NewStripeService(secretKey, webhookSecret, successURL, cancelURL string, trialPeriodDays int, subService *SubscriptionService) *StripeService {
// Initialize Stripe with the secret key (only if not empty)
if secretKey != "" {
stripe.Key = secretKey
}
return &StripeService{
secretKey: secretKey,
webhookSecret: webhookSecret,
successURL: successURL,
cancelURL: cancelURL,
trialPeriodDays: int64(trialPeriodDays),
subService: subService,
mockMode: false,
}
}
// NewMockStripeService creates a mock StripeService for development
func NewMockStripeService(successURL, cancelURL string, trialPeriodDays int, subService *SubscriptionService) *StripeService {
return &StripeService{
secretKey: "",
webhookSecret: "",
successURL: successURL,
cancelURL: cancelURL,
trialPeriodDays: int64(trialPeriodDays),
subService: subService,
mockMode: true,
}
}
// IsMockMode returns true if running in mock mode
func (s *StripeService) IsMockMode() bool {
return s.mockMode
}
// CreateCheckoutSession creates a Stripe Checkout session for trial start
func (s *StripeService) CreateCheckoutSession(ctx context.Context, userID uuid.UUID, email string, planID models.PlanID) (string, string, error) {
// Mock mode: return a fake URL for development
if s.mockMode {
mockSessionID := fmt.Sprintf("mock_cs_%s", uuid.New().String()[:8])
mockURL := fmt.Sprintf("%s?session_id=%s&mock=true&plan=%s", s.successURL, mockSessionID, planID)
return mockURL, mockSessionID, nil
}
// Get plan details
plan, err := s.subService.GetPlanByID(ctx, string(planID))
if err != nil || plan == nil {
return "", "", fmt.Errorf("plan not found: %s", planID)
}
// Ensure we have a Stripe price ID
if plan.StripePriceID == "" {
// Create product and price in Stripe if not exists
priceID, err := s.ensurePriceExists(ctx, plan)
if err != nil {
return "", "", fmt.Errorf("failed to create stripe price: %w", err)
}
plan.StripePriceID = priceID
}
// Create checkout session parameters
params := &stripe.CheckoutSessionParams{
Mode: stripe.String(string(stripe.CheckoutSessionModeSubscription)),
LineItems: []*stripe.CheckoutSessionLineItemParams{
{
Price: stripe.String(plan.StripePriceID),
Quantity: stripe.Int64(1),
},
},
SuccessURL: stripe.String(s.successURL + "?session_id={CHECKOUT_SESSION_ID}"),
CancelURL: stripe.String(s.cancelURL),
SubscriptionData: &stripe.CheckoutSessionSubscriptionDataParams{
TrialPeriodDays: stripe.Int64(s.trialPeriodDays),
Metadata: map[string]string{
"user_id": userID.String(),
"plan_id": string(planID),
},
},
PaymentMethodCollection: stripe.String(string(stripe.CheckoutSessionPaymentMethodCollectionAlways)),
Metadata: map[string]string{
"user_id": userID.String(),
"plan_id": string(planID),
},
}
// Set customer email if provided
if email != "" {
params.CustomerEmail = stripe.String(email)
}
// Create the session
sess, err := checkoutsession.New(params)
if err != nil {
return "", "", fmt.Errorf("failed to create checkout session: %w", err)
}
return sess.URL, sess.ID, nil
}
// ensurePriceExists creates a Stripe product and price if they don't exist
func (s *StripeService) ensurePriceExists(ctx context.Context, plan *models.BillingPlan) (string, error) {
// Create product
productParams := &stripe.ProductParams{
Name: stripe.String(plan.Name),
Description: stripe.String(plan.Description),
Metadata: map[string]string{
"plan_id": string(plan.ID),
},
}
prod, err := product.New(productParams)
if err != nil {
return "", fmt.Errorf("failed to create product: %w", err)
}
// Create price
priceParams := &stripe.PriceParams{
Product: stripe.String(prod.ID),
UnitAmount: stripe.Int64(int64(plan.PriceCents)),
Currency: stripe.String(plan.Currency),
Recurring: &stripe.PriceRecurringParams{
Interval: stripe.String(plan.Interval),
},
Metadata: map[string]string{
"plan_id": string(plan.ID),
},
}
pr, err := price.New(priceParams)
if err != nil {
return "", fmt.Errorf("failed to create price: %w", err)
}
// Update plan with Stripe IDs
if err := s.subService.UpdatePlanStripePriceID(ctx, string(plan.ID), pr.ID, prod.ID); err != nil {
// Log but don't fail
fmt.Printf("Warning: Failed to update plan with Stripe IDs: %v\n", err)
}
return pr.ID, nil
}
// GetOrCreateCustomer gets or creates a Stripe customer for a user
func (s *StripeService) GetOrCreateCustomer(ctx context.Context, email, name string, userID uuid.UUID) (string, error) {
// Search for existing customer
params := &stripe.CustomerSearchParams{
SearchParams: stripe.SearchParams{
Query: fmt.Sprintf("email:'%s'", email),
},
}
iter := customer.Search(params)
for iter.Next() {
cust := iter.Customer()
// Check if this customer belongs to our user
if cust.Metadata["user_id"] == userID.String() {
return cust.ID, nil
}
}
// Create new customer
customerParams := &stripe.CustomerParams{
Email: stripe.String(email),
Name: stripe.String(name),
Metadata: map[string]string{
"user_id": userID.String(),
},
}
cust, err := customer.New(customerParams)
if err != nil {
return "", fmt.Errorf("failed to create customer: %w", err)
}
return cust.ID, nil
}
// ChangePlan changes a subscription to a new plan
func (s *StripeService) ChangePlan(ctx context.Context, stripeSubID string, newPlanID models.PlanID) error {
// Mock mode: just return success
if s.mockMode {
return nil
}
// Get new plan details
plan, err := s.subService.GetPlanByID(ctx, string(newPlanID))
if err != nil || plan == nil {
return fmt.Errorf("plan not found: %s", newPlanID)
}
if plan.StripePriceID == "" {
return fmt.Errorf("plan %s has no Stripe price ID", newPlanID)
}
// Get current subscription
sub, err := subscription.Get(stripeSubID, nil)
if err != nil {
return fmt.Errorf("failed to get subscription: %w", err)
}
// Update subscription with new price
params := &stripe.SubscriptionParams{
Items: []*stripe.SubscriptionItemsParams{
{
ID: stripe.String(sub.Items.Data[0].ID),
Price: stripe.String(plan.StripePriceID),
},
},
ProrationBehavior: stripe.String(string(stripe.SubscriptionSchedulePhaseProrationBehaviorCreateProrations)),
Metadata: map[string]string{
"plan_id": string(newPlanID),
},
}
_, err = subscription.Update(stripeSubID, params)
if err != nil {
return fmt.Errorf("failed to update subscription: %w", err)
}
return nil
}
// CancelSubscription cancels a subscription at period end
func (s *StripeService) CancelSubscription(ctx context.Context, stripeSubID string) error {
// Mock mode: just return success
if s.mockMode {
return nil
}
params := &stripe.SubscriptionParams{
CancelAtPeriodEnd: stripe.Bool(true),
}
_, err := subscription.Update(stripeSubID, params)
if err != nil {
return fmt.Errorf("failed to cancel subscription: %w", err)
}
return nil
}
// ReactivateSubscription removes the cancel_at_period_end flag
func (s *StripeService) ReactivateSubscription(ctx context.Context, stripeSubID string) error {
// Mock mode: just return success
if s.mockMode {
return nil
}
params := &stripe.SubscriptionParams{
CancelAtPeriodEnd: stripe.Bool(false),
}
_, err := subscription.Update(stripeSubID, params)
if err != nil {
return fmt.Errorf("failed to reactivate subscription: %w", err)
}
return nil
}
// CreateCustomerPortalSession creates a Stripe Customer Portal session
func (s *StripeService) CreateCustomerPortalSession(ctx context.Context, customerID string) (string, error) {
// Mock mode: return a mock URL
if s.mockMode {
return fmt.Sprintf("%s?mock_portal=true", s.successURL), nil
}
params := &stripe.BillingPortalSessionParams{
Customer: stripe.String(customerID),
ReturnURL: stripe.String(s.successURL),
}
sess, err := session.New(params)
if err != nil {
return "", fmt.Errorf("failed to create portal session: %w", err)
}
return sess.URL, nil
}
// GetSubscription retrieves a subscription from Stripe
func (s *StripeService) GetSubscription(ctx context.Context, stripeSubID string) (*stripe.Subscription, error) {
sub, err := subscription.Get(stripeSubID, nil)
if err != nil {
return nil, fmt.Errorf("failed to get subscription: %w", err)
}
return sub, nil
}

View File

@@ -0,0 +1,315 @@
package services
import (
"context"
"encoding/json"
"time"
"github.com/breakpilot/billing-service/internal/database"
"github.com/breakpilot/billing-service/internal/models"
"github.com/google/uuid"
)
// SubscriptionService handles subscription-related operations
type SubscriptionService struct {
db *database.DB
}
// NewSubscriptionService creates a new SubscriptionService
func NewSubscriptionService(db *database.DB) *SubscriptionService {
return &SubscriptionService{db: db}
}
// GetByUserID retrieves a subscription by user ID
func (s *SubscriptionService) GetByUserID(ctx context.Context, userID uuid.UUID) (*models.Subscription, error) {
query := `
SELECT id, user_id, stripe_customer_id, stripe_subscription_id, plan_id,
status, trial_end, current_period_end, cancel_at_period_end,
created_at, updated_at
FROM subscriptions
WHERE user_id = $1
`
var sub models.Subscription
var stripeCustomerID, stripeSubID *string
var trialEnd, periodEnd *time.Time
err := s.db.Pool.QueryRow(ctx, query, userID).Scan(
&sub.ID, &sub.UserID, &stripeCustomerID, &stripeSubID, &sub.PlanID,
&sub.Status, &trialEnd, &periodEnd, &sub.CancelAtPeriodEnd,
&sub.CreatedAt, &sub.UpdatedAt,
)
if err != nil {
if err.Error() == "no rows in result set" {
return nil, nil
}
return nil, err
}
if stripeCustomerID != nil {
sub.StripeCustomerID = *stripeCustomerID
}
if stripeSubID != nil {
sub.StripeSubscriptionID = *stripeSubID
}
sub.TrialEnd = trialEnd
sub.CurrentPeriodEnd = periodEnd
return &sub, nil
}
// GetByStripeSubscriptionID retrieves a subscription by Stripe subscription ID
func (s *SubscriptionService) GetByStripeSubscriptionID(ctx context.Context, stripeSubID string) (*models.Subscription, error) {
query := `
SELECT id, user_id, stripe_customer_id, stripe_subscription_id, plan_id,
status, trial_end, current_period_end, cancel_at_period_end,
created_at, updated_at
FROM subscriptions
WHERE stripe_subscription_id = $1
`
var sub models.Subscription
var stripeCustomerID, subID *string
var trialEnd, periodEnd *time.Time
err := s.db.Pool.QueryRow(ctx, query, stripeSubID).Scan(
&sub.ID, &sub.UserID, &stripeCustomerID, &subID, &sub.PlanID,
&sub.Status, &trialEnd, &periodEnd, &sub.CancelAtPeriodEnd,
&sub.CreatedAt, &sub.UpdatedAt,
)
if err != nil {
if err.Error() == "no rows in result set" {
return nil, nil
}
return nil, err
}
if stripeCustomerID != nil {
sub.StripeCustomerID = *stripeCustomerID
}
if subID != nil {
sub.StripeSubscriptionID = *subID
}
sub.TrialEnd = trialEnd
sub.CurrentPeriodEnd = periodEnd
return &sub, nil
}
// Create creates a new subscription
func (s *SubscriptionService) Create(ctx context.Context, sub *models.Subscription) error {
query := `
INSERT INTO subscriptions (
user_id, stripe_customer_id, stripe_subscription_id, plan_id,
status, trial_end, current_period_end, cancel_at_period_end
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
RETURNING id, created_at, updated_at
`
return s.db.Pool.QueryRow(ctx, query,
sub.UserID, sub.StripeCustomerID, sub.StripeSubscriptionID, sub.PlanID,
sub.Status, sub.TrialEnd, sub.CurrentPeriodEnd, sub.CancelAtPeriodEnd,
).Scan(&sub.ID, &sub.CreatedAt, &sub.UpdatedAt)
}
// Update updates an existing subscription
func (s *SubscriptionService) Update(ctx context.Context, sub *models.Subscription) error {
query := `
UPDATE subscriptions SET
stripe_customer_id = $2,
stripe_subscription_id = $3,
plan_id = $4,
status = $5,
trial_end = $6,
current_period_end = $7,
cancel_at_period_end = $8,
updated_at = NOW()
WHERE id = $1
`
_, err := s.db.Pool.Exec(ctx, query,
sub.ID, sub.StripeCustomerID, sub.StripeSubscriptionID, sub.PlanID,
sub.Status, sub.TrialEnd, sub.CurrentPeriodEnd, sub.CancelAtPeriodEnd,
)
return err
}
// UpdateStatus updates the subscription status
func (s *SubscriptionService) UpdateStatus(ctx context.Context, id uuid.UUID, status models.SubscriptionStatus) error {
query := `UPDATE subscriptions SET status = $2, updated_at = NOW() WHERE id = $1`
_, err := s.db.Pool.Exec(ctx, query, id, status)
return err
}
// GetAvailablePlans retrieves all active billing plans
func (s *SubscriptionService) GetAvailablePlans(ctx context.Context) ([]models.BillingPlan, error) {
query := `
SELECT id, stripe_price_id, name, description, price_cents,
currency, interval, features, is_active, sort_order
FROM billing_plans
WHERE is_active = true
ORDER BY sort_order ASC
`
rows, err := s.db.Pool.Query(ctx, query)
if err != nil {
return nil, err
}
defer rows.Close()
var plans []models.BillingPlan
for rows.Next() {
var plan models.BillingPlan
var stripePriceID *string
var featuresJSON []byte
err := rows.Scan(
&plan.ID, &stripePriceID, &plan.Name, &plan.Description,
&plan.PriceCents, &plan.Currency, &plan.Interval,
&featuresJSON, &plan.IsActive, &plan.SortOrder,
)
if err != nil {
return nil, err
}
if stripePriceID != nil {
plan.StripePriceID = *stripePriceID
}
// Parse features JSON
if len(featuresJSON) > 0 {
json.Unmarshal(featuresJSON, &plan.Features)
}
plans = append(plans, plan)
}
return plans, nil
}
// GetPlanByID retrieves a billing plan by ID
func (s *SubscriptionService) GetPlanByID(ctx context.Context, planID string) (*models.BillingPlan, error) {
query := `
SELECT id, stripe_price_id, name, description, price_cents,
currency, interval, features, is_active, sort_order
FROM billing_plans
WHERE id = $1
`
var plan models.BillingPlan
var stripePriceID *string
var featuresJSON []byte
err := s.db.Pool.QueryRow(ctx, query, planID).Scan(
&plan.ID, &stripePriceID, &plan.Name, &plan.Description,
&plan.PriceCents, &plan.Currency, &plan.Interval,
&featuresJSON, &plan.IsActive, &plan.SortOrder,
)
if err != nil {
if err.Error() == "no rows in result set" {
return nil, nil
}
return nil, err
}
if stripePriceID != nil {
plan.StripePriceID = *stripePriceID
}
if len(featuresJSON) > 0 {
json.Unmarshal(featuresJSON, &plan.Features)
}
return &plan, nil
}
// UpdatePlanStripePriceID updates the Stripe price ID for a plan
func (s *SubscriptionService) UpdatePlanStripePriceID(ctx context.Context, planID, stripePriceID, stripeProductID string) error {
query := `
UPDATE billing_plans
SET stripe_price_id = $2, stripe_product_id = $3, updated_at = NOW()
WHERE id = $1
`
_, err := s.db.Pool.Exec(ctx, query, planID, stripePriceID, stripeProductID)
return err
}
// =============================================
// Webhook Event Tracking (Idempotency)
// =============================================
// IsEventProcessed checks if a webhook event has already been processed
func (s *SubscriptionService) IsEventProcessed(ctx context.Context, eventID string) (bool, error) {
query := `SELECT processed FROM stripe_webhook_events WHERE stripe_event_id = $1`
var processed bool
err := s.db.Pool.QueryRow(ctx, query, eventID).Scan(&processed)
if err != nil {
if err.Error() == "no rows in result set" {
return false, nil
}
return false, err
}
return processed, nil
}
// MarkEventProcessing marks an event as being processed
func (s *SubscriptionService) MarkEventProcessing(ctx context.Context, eventID, eventType string) error {
query := `
INSERT INTO stripe_webhook_events (stripe_event_id, event_type, processed)
VALUES ($1, $2, false)
ON CONFLICT (stripe_event_id) DO NOTHING
`
_, err := s.db.Pool.Exec(ctx, query, eventID, eventType)
return err
}
// MarkEventProcessed marks an event as successfully processed
func (s *SubscriptionService) MarkEventProcessed(ctx context.Context, eventID string) error {
query := `
UPDATE stripe_webhook_events
SET processed = true, processed_at = NOW()
WHERE stripe_event_id = $1
`
_, err := s.db.Pool.Exec(ctx, query, eventID)
return err
}
// MarkEventFailed marks an event as failed with an error message
func (s *SubscriptionService) MarkEventFailed(ctx context.Context, eventID, errorMsg string) error {
query := `
UPDATE stripe_webhook_events
SET processed = false, error_message = $2, processed_at = NOW()
WHERE stripe_event_id = $1
`
_, err := s.db.Pool.Exec(ctx, query, eventID, errorMsg)
return err
}
// =============================================
// Audit Logging
// =============================================
// LogAuditEvent logs a billing audit event
func (s *SubscriptionService) LogAuditEvent(ctx context.Context, userID *uuid.UUID, action, entityType, entityID string, oldValue, newValue, metadata interface{}, ipAddress, userAgent string) error {
oldJSON, _ := json.Marshal(oldValue)
newJSON, _ := json.Marshal(newValue)
metaJSON, _ := json.Marshal(metadata)
query := `
INSERT INTO billing_audit_log (
user_id, action, entity_type, entity_id,
old_value, new_value, metadata, ip_address, user_agent
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
`
_, err := s.db.Pool.Exec(ctx, query,
userID, action, entityType, entityID,
oldJSON, newJSON, metaJSON, ipAddress, userAgent,
)
return err
}

View File

@@ -0,0 +1,326 @@
package services
import (
"encoding/json"
"testing"
"github.com/breakpilot/billing-service/internal/models"
)
func TestSubscriptionStatus_Transitions(t *testing.T) {
// Test valid subscription status values
validStatuses := []models.SubscriptionStatus{
models.StatusTrialing,
models.StatusActive,
models.StatusPastDue,
models.StatusCanceled,
models.StatusExpired,
}
for _, status := range validStatuses {
if status == "" {
t.Errorf("Status should not be empty")
}
}
}
func TestPlanID_ValidValues(t *testing.T) {
validPlanIDs := []models.PlanID{
models.PlanBasic,
models.PlanStandard,
models.PlanPremium,
}
expected := []string{"basic", "standard", "premium"}
for i, planID := range validPlanIDs {
if string(planID) != expected[i] {
t.Errorf("PlanID should be '%s', got '%s'", expected[i], planID)
}
}
}
func TestPlanFeatures_JSONSerialization(t *testing.T) {
features := models.PlanFeatures{
MonthlyTaskAllowance: 100,
MaxTaskBalance: 500,
FeatureFlags: []string{"basic_ai", "templates"},
MaxTeamMembers: 3,
PrioritySupport: false,
CustomBranding: false,
BatchProcessing: true,
CustomTemplates: true,
FairUseMode: false,
}
// Test JSON serialization
data, err := json.Marshal(features)
if err != nil {
t.Fatalf("Failed to marshal PlanFeatures: %v", err)
}
// Test JSON deserialization
var decoded models.PlanFeatures
err = json.Unmarshal(data, &decoded)
if err != nil {
t.Fatalf("Failed to unmarshal PlanFeatures: %v", err)
}
// Verify fields
if decoded.MonthlyTaskAllowance != features.MonthlyTaskAllowance {
t.Errorf("MonthlyTaskAllowance mismatch: got %d, expected %d",
decoded.MonthlyTaskAllowance, features.MonthlyTaskAllowance)
}
if decoded.MaxTaskBalance != features.MaxTaskBalance {
t.Errorf("MaxTaskBalance mismatch: got %d, expected %d",
decoded.MaxTaskBalance, features.MaxTaskBalance)
}
if decoded.BatchProcessing != features.BatchProcessing {
t.Errorf("BatchProcessing mismatch: got %v, expected %v",
decoded.BatchProcessing, features.BatchProcessing)
}
}
func TestBillingPlan_DefaultPlansAreValid(t *testing.T) {
plans := models.GetDefaultPlans()
if len(plans) != 3 {
t.Fatalf("Expected 3 default plans, got %d", len(plans))
}
// Verify all plans have required fields
for _, plan := range plans {
if plan.ID == "" {
t.Errorf("Plan ID should not be empty")
}
if plan.Name == "" {
t.Errorf("Plan '%s' should have a name", plan.ID)
}
if plan.Description == "" {
t.Errorf("Plan '%s' should have a description", plan.ID)
}
if plan.PriceCents <= 0 {
t.Errorf("Plan '%s' should have a positive price, got %d", plan.ID, plan.PriceCents)
}
if plan.Currency != "eur" {
t.Errorf("Plan '%s' currency should be 'eur', got '%s'", plan.ID, plan.Currency)
}
if plan.Interval != "month" {
t.Errorf("Plan '%s' interval should be 'month', got '%s'", plan.ID, plan.Interval)
}
if !plan.IsActive {
t.Errorf("Plan '%s' should be active", plan.ID)
}
if plan.SortOrder <= 0 {
t.Errorf("Plan '%s' should have a positive sort order, got %d", plan.ID, plan.SortOrder)
}
}
}
func TestBillingPlan_TaskAllowanceProgression(t *testing.T) {
plans := models.GetDefaultPlans()
// Basic should have lowest allowance
basic := plans[0]
standard := plans[1]
premium := plans[2]
if basic.Features.MonthlyTaskAllowance >= standard.Features.MonthlyTaskAllowance {
t.Error("Standard plan should have more tasks than Basic")
}
if standard.Features.MonthlyTaskAllowance >= premium.Features.MonthlyTaskAllowance {
t.Error("Premium plan should have more tasks than Standard")
}
}
func TestBillingPlan_PriceProgression(t *testing.T) {
plans := models.GetDefaultPlans()
// Prices should increase with each tier
if plans[0].PriceCents >= plans[1].PriceCents {
t.Error("Standard should cost more than Basic")
}
if plans[1].PriceCents >= plans[2].PriceCents {
t.Error("Premium should cost more than Standard")
}
}
func TestBillingPlan_FairUseModeOnlyForPremium(t *testing.T) {
plans := models.GetDefaultPlans()
for _, plan := range plans {
if plan.ID == models.PlanPremium {
if !plan.Features.FairUseMode {
t.Error("Premium plan should have FairUseMode enabled")
}
} else {
if plan.Features.FairUseMode {
t.Errorf("Plan '%s' should not have FairUseMode enabled", plan.ID)
}
}
}
}
func TestBillingPlan_MaxTaskBalanceCalculation(t *testing.T) {
plans := models.GetDefaultPlans()
for _, plan := range plans {
expected := plan.Features.MonthlyTaskAllowance * models.CarryoverMonthsCap
if plan.Features.MaxTaskBalance != expected {
t.Errorf("Plan '%s' MaxTaskBalance should be %d (allowance * 5), got %d",
plan.ID, expected, plan.Features.MaxTaskBalance)
}
}
}
func TestAuditLogJSON_Marshaling(t *testing.T) {
// Test that audit log values can be properly serialized
oldValue := map[string]interface{}{
"plan_id": "basic",
"status": "active",
}
newValue := map[string]interface{}{
"plan_id": "standard",
"status": "active",
}
metadata := map[string]interface{}{
"reason": "upgrade",
}
// Marshal all values
oldJSON, err := json.Marshal(oldValue)
if err != nil {
t.Fatalf("Failed to marshal oldValue: %v", err)
}
newJSON, err := json.Marshal(newValue)
if err != nil {
t.Fatalf("Failed to marshal newValue: %v", err)
}
metaJSON, err := json.Marshal(metadata)
if err != nil {
t.Fatalf("Failed to marshal metadata: %v", err)
}
// Verify non-empty
if len(oldJSON) == 0 || len(newJSON) == 0 || len(metaJSON) == 0 {
t.Error("JSON outputs should not be empty")
}
}
func TestSubscriptionTrialCalculation(t *testing.T) {
// Test trial days calculation logic
trialDays := 7
if trialDays <= 0 {
t.Error("Trial days should be positive")
}
if trialDays > 30 {
t.Error("Trial days should not exceed 30")
}
}
func TestSubscriptionInfo_TrialingStatus(t *testing.T) {
info := models.SubscriptionInfo{
PlanID: models.PlanBasic,
PlanName: "Basic",
Status: models.StatusTrialing,
IsTrialing: true,
TrialDaysLeft: 5,
CancelAtPeriodEnd: false,
PriceCents: 990,
Currency: "eur",
}
if !info.IsTrialing {
t.Error("Should be trialing")
}
if info.Status != models.StatusTrialing {
t.Errorf("Status should be 'trialing', got '%s'", info.Status)
}
if info.TrialDaysLeft <= 0 {
t.Error("TrialDaysLeft should be positive during trial")
}
}
func TestSubscriptionInfo_ActiveStatus(t *testing.T) {
info := models.SubscriptionInfo{
PlanID: models.PlanStandard,
PlanName: "Standard",
Status: models.StatusActive,
IsTrialing: false,
TrialDaysLeft: 0,
CancelAtPeriodEnd: false,
PriceCents: 1990,
Currency: "eur",
}
if info.IsTrialing {
t.Error("Should not be trialing")
}
if info.Status != models.StatusActive {
t.Errorf("Status should be 'active', got '%s'", info.Status)
}
}
func TestSubscriptionInfo_CanceledStatus(t *testing.T) {
info := models.SubscriptionInfo{
PlanID: models.PlanStandard,
PlanName: "Standard",
Status: models.StatusActive,
IsTrialing: false,
CancelAtPeriodEnd: true, // Scheduled for cancellation
PriceCents: 1990,
Currency: "eur",
}
if !info.CancelAtPeriodEnd {
t.Error("CancelAtPeriodEnd should be true")
}
// Status remains active until period end
if info.Status != models.StatusActive {
t.Errorf("Status should still be 'active', got '%s'", info.Status)
}
}
func TestWebhookEventTypes(t *testing.T) {
// Test common Stripe webhook event types we handle
eventTypes := []string{
"checkout.session.completed",
"customer.subscription.created",
"customer.subscription.updated",
"customer.subscription.deleted",
"invoice.paid",
"invoice.payment_failed",
}
for _, eventType := range eventTypes {
if eventType == "" {
t.Error("Event type should not be empty")
}
}
}
func TestIdempotencyKey_Format(t *testing.T) {
// Test that we can handle Stripe event IDs
sampleEventIDs := []string{
"evt_1234567890abcdef",
"evt_test_abc123xyz789",
"evt_live_real_event_id",
}
for _, eventID := range sampleEventIDs {
if len(eventID) < 10 {
t.Errorf("Event ID '%s' seems too short", eventID)
}
// Stripe event IDs typically start with "evt_"
if eventID[:4] != "evt_" {
t.Errorf("Event ID '%s' should start with 'evt_'", eventID)
}
}
}

View File

@@ -0,0 +1,352 @@
package services
import (
"context"
"errors"
"fmt"
"time"
"github.com/breakpilot/billing-service/internal/database"
"github.com/breakpilot/billing-service/internal/models"
"github.com/google/uuid"
)
var (
// ErrTaskLimitReached is returned when task balance is 0
ErrTaskLimitReached = errors.New("TASK_LIMIT_REACHED")
// ErrNoSubscription is returned when user has no subscription
ErrNoSubscription = errors.New("NO_SUBSCRIPTION")
)
// TaskService handles task consumption and balance management
type TaskService struct {
db *database.DB
subService *SubscriptionService
}
// NewTaskService creates a new TaskService
func NewTaskService(db *database.DB, subService *SubscriptionService) *TaskService {
return &TaskService{
db: db,
subService: subService,
}
}
// GetAccountUsage retrieves or creates account usage for a user
func (s *TaskService) GetAccountUsage(ctx context.Context, userID uuid.UUID) (*models.AccountUsage, error) {
query := `
SELECT id, account_id, plan, monthly_task_allowance, carryover_months_cap,
max_task_balance, task_balance, last_renewal_at, created_at, updated_at
FROM account_usage
WHERE account_id = $1
`
var usage models.AccountUsage
err := s.db.Pool.QueryRow(ctx, query, userID).Scan(
&usage.ID, &usage.AccountID, &usage.PlanID, &usage.MonthlyTaskAllowance,
&usage.CarryoverMonthsCap, &usage.MaxTaskBalance, &usage.TaskBalance,
&usage.LastRenewalAt, &usage.CreatedAt, &usage.UpdatedAt,
)
if err != nil {
if err.Error() == "no rows in result set" {
// Create new account usage based on subscription
return s.createAccountUsage(ctx, userID)
}
return nil, err
}
// Check if month renewal is needed
if err := s.checkAndApplyMonthRenewal(ctx, &usage); err != nil {
return nil, err
}
return &usage, nil
}
// createAccountUsage creates account usage based on user's subscription
func (s *TaskService) createAccountUsage(ctx context.Context, userID uuid.UUID) (*models.AccountUsage, error) {
// Get subscription to determine plan
sub, err := s.subService.GetByUserID(ctx, userID)
if err != nil || sub == nil {
return nil, ErrNoSubscription
}
// Get plan features
plan, err := s.subService.GetPlanByID(ctx, string(sub.PlanID))
if err != nil || plan == nil {
return nil, fmt.Errorf("plan not found: %s", sub.PlanID)
}
now := time.Now()
usage := &models.AccountUsage{
AccountID: userID,
PlanID: sub.PlanID,
MonthlyTaskAllowance: plan.Features.MonthlyTaskAllowance,
CarryoverMonthsCap: models.CarryoverMonthsCap,
MaxTaskBalance: plan.Features.MaxTaskBalance,
TaskBalance: plan.Features.MonthlyTaskAllowance, // Start with one month's worth
LastRenewalAt: now,
}
query := `
INSERT INTO account_usage (
account_id, plan, monthly_task_allowance, carryover_months_cap,
max_task_balance, task_balance, last_renewal_at
) VALUES ($1, $2, $3, $4, $5, $6, $7)
RETURNING id, created_at, updated_at
`
err = s.db.Pool.QueryRow(ctx, query,
usage.AccountID, usage.PlanID, usage.MonthlyTaskAllowance,
usage.CarryoverMonthsCap, usage.MaxTaskBalance, usage.TaskBalance, usage.LastRenewalAt,
).Scan(&usage.ID, &usage.CreatedAt, &usage.UpdatedAt)
if err != nil {
return nil, err
}
return usage, nil
}
// checkAndApplyMonthRenewal checks if a month has passed and adds allowance
// Implements the carryover logic: tasks accumulate up to max_task_balance
func (s *TaskService) checkAndApplyMonthRenewal(ctx context.Context, usage *models.AccountUsage) error {
now := time.Now()
// Check if at least one month has passed since last renewal
monthsSinceRenewal := monthsBetween(usage.LastRenewalAt, now)
if monthsSinceRenewal < 1 {
return nil
}
// Calculate new balance with carryover
// Add monthly allowance for each month that passed
newBalance := usage.TaskBalance
for i := 0; i < monthsSinceRenewal; i++ {
newBalance += usage.MonthlyTaskAllowance
// Cap at max balance
if newBalance > usage.MaxTaskBalance {
newBalance = usage.MaxTaskBalance
break
}
}
// Calculate new renewal date (add the number of months)
newRenewalAt := usage.LastRenewalAt.AddDate(0, monthsSinceRenewal, 0)
// Update in database
query := `
UPDATE account_usage
SET task_balance = $2, last_renewal_at = $3, updated_at = NOW()
WHERE id = $1
`
_, err := s.db.Pool.Exec(ctx, query, usage.ID, newBalance, newRenewalAt)
if err != nil {
return err
}
// Update local struct
usage.TaskBalance = newBalance
usage.LastRenewalAt = newRenewalAt
return nil
}
// monthsBetween calculates full months between two dates
func monthsBetween(start, end time.Time) int {
months := 0
for start.AddDate(0, months+1, 0).Before(end) || start.AddDate(0, months+1, 0).Equal(end) {
months++
}
return months
}
// CheckTaskAllowed checks if a task can be consumed (balance > 0)
func (s *TaskService) CheckTaskAllowed(ctx context.Context, userID uuid.UUID) (*models.CheckTaskAllowedResponse, error) {
usage, err := s.GetAccountUsage(ctx, userID)
if err != nil {
if errors.Is(err, ErrNoSubscription) {
return &models.CheckTaskAllowedResponse{
Allowed: false,
PlanID: "",
Message: "Kein aktives Abonnement gefunden.",
}, nil
}
return nil, err
}
// Premium Fair Use mode - always allow
plan, _ := s.subService.GetPlanByID(ctx, string(usage.PlanID))
if plan != nil && plan.Features.FairUseMode {
return &models.CheckTaskAllowedResponse{
Allowed: true,
TasksAvailable: usage.TaskBalance,
MaxTasks: usage.MaxTaskBalance,
PlanID: usage.PlanID,
}, nil
}
allowed := usage.TaskBalance > 0
response := &models.CheckTaskAllowedResponse{
Allowed: allowed,
TasksAvailable: usage.TaskBalance,
MaxTasks: usage.MaxTaskBalance,
PlanID: usage.PlanID,
}
if !allowed {
response.Message = "Dein Aufgaben-Kontingent ist aufgebraucht."
}
return response, nil
}
// ConsumeTask consumes one task from the balance
// Returns error if balance is 0
func (s *TaskService) ConsumeTask(ctx context.Context, userID uuid.UUID, taskType models.TaskType) (*models.ConsumeTaskResponse, error) {
// First check if allowed
checkResponse, err := s.CheckTaskAllowed(ctx, userID)
if err != nil {
return nil, err
}
if !checkResponse.Allowed {
return &models.ConsumeTaskResponse{
Success: false,
TasksRemaining: 0,
Message: checkResponse.Message,
}, ErrTaskLimitReached
}
// Get current usage
usage, err := s.GetAccountUsage(ctx, userID)
if err != nil {
return nil, err
}
// Start transaction
tx, err := s.db.Pool.Begin(ctx)
if err != nil {
return nil, err
}
defer tx.Rollback(ctx)
// Decrement balance (only if not Premium Fair Use)
plan, _ := s.subService.GetPlanByID(ctx, string(usage.PlanID))
newBalance := usage.TaskBalance
if plan == nil || !plan.Features.FairUseMode {
newBalance = usage.TaskBalance - 1
_, err = tx.Exec(ctx, `
UPDATE account_usage
SET task_balance = $2, updated_at = NOW()
WHERE account_id = $1
`, userID, newBalance)
if err != nil {
return nil, err
}
}
// Create task record
taskID := uuid.New()
_, err = tx.Exec(ctx, `
INSERT INTO tasks (id, account_id, task_type, consumed, created_at)
VALUES ($1, $2, $3, true, NOW())
`, taskID, userID, taskType)
if err != nil {
return nil, err
}
// Commit transaction
if err = tx.Commit(ctx); err != nil {
return nil, err
}
return &models.ConsumeTaskResponse{
Success: true,
TaskID: taskID.String(),
TasksRemaining: newBalance,
}, nil
}
// GetTaskUsageInfo returns formatted task usage info for display
func (s *TaskService) GetTaskUsageInfo(ctx context.Context, userID uuid.UUID) (*models.TaskUsageInfo, error) {
usage, err := s.GetAccountUsage(ctx, userID)
if err != nil {
return nil, err
}
// Check for Fair Use mode (Premium)
plan, _ := s.subService.GetPlanByID(ctx, string(usage.PlanID))
if plan != nil && plan.Features.FairUseMode {
return &models.TaskUsageInfo{
TasksAvailable: usage.TaskBalance,
MaxTasks: usage.MaxTaskBalance,
InfoText: "Unbegrenzte Aufgaben (Fair Use)",
TooltipText: "Im Premium-Tarif gibt es keine praktische Begrenzung.",
}, nil
}
return &models.TaskUsageInfo{
TasksAvailable: usage.TaskBalance,
MaxTasks: usage.MaxTaskBalance,
InfoText: fmt.Sprintf("Aufgaben verfuegbar: %d von max. %d", usage.TaskBalance, usage.MaxTaskBalance),
TooltipText: "Aufgaben koennen sich bis zu 5 Monate ansammeln.",
}, nil
}
// UpdatePlanForUser updates the plan and adjusts allowances
func (s *TaskService) UpdatePlanForUser(ctx context.Context, userID uuid.UUID, newPlanID models.PlanID) error {
plan, err := s.subService.GetPlanByID(ctx, string(newPlanID))
if err != nil || plan == nil {
return fmt.Errorf("plan not found: %s", newPlanID)
}
// Update account usage with new plan limits
query := `
UPDATE account_usage
SET plan = $2,
monthly_task_allowance = $3,
max_task_balance = $4,
updated_at = NOW()
WHERE account_id = $1
`
_, err = s.db.Pool.Exec(ctx, query,
userID, newPlanID, plan.Features.MonthlyTaskAllowance, plan.Features.MaxTaskBalance)
return err
}
// GetTaskHistory returns task history for a user
func (s *TaskService) GetTaskHistory(ctx context.Context, userID uuid.UUID, limit int) ([]models.Task, error) {
if limit <= 0 {
limit = 50
}
query := `
SELECT id, account_id, task_type, created_at, consumed
FROM tasks
WHERE account_id = $1
ORDER BY created_at DESC
LIMIT $2
`
rows, err := s.db.Pool.Query(ctx, query, userID, limit)
if err != nil {
return nil, err
}
defer rows.Close()
var tasks []models.Task
for rows.Next() {
var task models.Task
err := rows.Scan(&task.ID, &task.AccountID, &task.TaskType, &task.CreatedAt, &task.Consumed)
if err != nil {
return nil, err
}
tasks = append(tasks, task)
}
return tasks, nil
}

View File

@@ -0,0 +1,397 @@
package services
import (
"testing"
"time"
)
func TestMonthsBetween(t *testing.T) {
tests := []struct {
name string
start time.Time
end time.Time
expected int
}{
{
name: "Same day",
start: time.Date(2025, 1, 15, 0, 0, 0, 0, time.UTC),
end: time.Date(2025, 1, 15, 0, 0, 0, 0, time.UTC),
expected: 0,
},
{
name: "Less than one month",
start: time.Date(2025, 1, 15, 0, 0, 0, 0, time.UTC),
end: time.Date(2025, 2, 10, 0, 0, 0, 0, time.UTC),
expected: 0,
},
{
name: "Exactly one month",
start: time.Date(2025, 1, 15, 0, 0, 0, 0, time.UTC),
end: time.Date(2025, 2, 15, 0, 0, 0, 0, time.UTC),
expected: 1,
},
{
name: "One month and one day",
start: time.Date(2025, 1, 15, 0, 0, 0, 0, time.UTC),
end: time.Date(2025, 2, 16, 0, 0, 0, 0, time.UTC),
expected: 1,
},
{
name: "Two months",
start: time.Date(2025, 1, 15, 0, 0, 0, 0, time.UTC),
end: time.Date(2025, 3, 15, 0, 0, 0, 0, time.UTC),
expected: 2,
},
{
name: "Five months exactly",
start: time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC),
end: time.Date(2025, 6, 1, 0, 0, 0, 0, time.UTC),
expected: 5,
},
{
name: "Year boundary",
start: time.Date(2024, 11, 15, 0, 0, 0, 0, time.UTC),
end: time.Date(2025, 2, 15, 0, 0, 0, 0, time.UTC),
expected: 3,
},
{
name: "Leap year February to March",
start: time.Date(2024, 2, 29, 0, 0, 0, 0, time.UTC),
end: time.Date(2024, 3, 29, 0, 0, 0, 0, time.UTC),
expected: 1,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := monthsBetween(tt.start, tt.end)
if result != tt.expected {
t.Errorf("monthsBetween(%v, %v) = %d, expected %d",
tt.start.Format("2006-01-02"), tt.end.Format("2006-01-02"),
result, tt.expected)
}
})
}
}
func TestCarryoverLogic(t *testing.T) {
// Test the carryover calculation logic
tests := []struct {
name string
currentBalance int
monthlyAllowance int
maxBalance int
monthsSinceRenewal int
expectedNewBalance int
}{
{
name: "Normal renewal - add allowance",
currentBalance: 50,
monthlyAllowance: 30,
maxBalance: 150,
monthsSinceRenewal: 1,
expectedNewBalance: 80,
},
{
name: "Two months missed",
currentBalance: 50,
monthlyAllowance: 30,
maxBalance: 150,
monthsSinceRenewal: 2,
expectedNewBalance: 110,
},
{
name: "Cap at max balance",
currentBalance: 140,
monthlyAllowance: 30,
maxBalance: 150,
monthsSinceRenewal: 1,
expectedNewBalance: 150,
},
{
name: "Already at max - no change",
currentBalance: 150,
monthlyAllowance: 30,
maxBalance: 150,
monthsSinceRenewal: 1,
expectedNewBalance: 150,
},
{
name: "Multiple months - cap applies",
currentBalance: 100,
monthlyAllowance: 30,
maxBalance: 150,
monthsSinceRenewal: 5,
expectedNewBalance: 150,
},
{
name: "Empty balance - add one month",
currentBalance: 0,
monthlyAllowance: 30,
maxBalance: 150,
monthsSinceRenewal: 1,
expectedNewBalance: 30,
},
{
name: "Empty balance - add five months",
currentBalance: 0,
monthlyAllowance: 30,
maxBalance: 150,
monthsSinceRenewal: 5,
expectedNewBalance: 150,
},
{
name: "Standard plan - normal case",
currentBalance: 200,
monthlyAllowance: 100,
maxBalance: 500,
monthsSinceRenewal: 1,
expectedNewBalance: 300,
},
{
name: "Premium plan - Fair Use",
currentBalance: 1000,
monthlyAllowance: 1000,
maxBalance: 5000,
monthsSinceRenewal: 1,
expectedNewBalance: 2000,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Simulate the carryover logic
newBalance := tt.currentBalance
for i := 0; i < tt.monthsSinceRenewal; i++ {
newBalance += tt.monthlyAllowance
if newBalance > tt.maxBalance {
newBalance = tt.maxBalance
break
}
}
if newBalance != tt.expectedNewBalance {
t.Errorf("Carryover for balance=%d, allowance=%d, max=%d, months=%d = %d, expected %d",
tt.currentBalance, tt.monthlyAllowance, tt.maxBalance, tt.monthsSinceRenewal,
newBalance, tt.expectedNewBalance)
}
})
}
}
func TestTaskBalanceAfterConsumption(t *testing.T) {
tests := []struct {
name string
currentBalance int
tasksToConsume int
expectedBalance int
shouldBeAllowed bool
}{
{
name: "Normal consumption",
currentBalance: 50,
tasksToConsume: 1,
expectedBalance: 49,
shouldBeAllowed: true,
},
{
name: "Last task",
currentBalance: 1,
tasksToConsume: 1,
expectedBalance: 0,
shouldBeAllowed: true,
},
{
name: "Empty balance - not allowed",
currentBalance: 0,
tasksToConsume: 1,
expectedBalance: 0,
shouldBeAllowed: false,
},
{
name: "Multiple tasks",
currentBalance: 50,
tasksToConsume: 5,
expectedBalance: 45,
shouldBeAllowed: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Test if allowed
allowed := tt.currentBalance > 0
if allowed != tt.shouldBeAllowed {
t.Errorf("Task allowed with balance=%d: got %v, expected %v",
tt.currentBalance, allowed, tt.shouldBeAllowed)
}
// Test balance calculation
if allowed {
newBalance := tt.currentBalance - tt.tasksToConsume
if newBalance != tt.expectedBalance {
t.Errorf("Balance after consuming %d tasks from %d: got %d, expected %d",
tt.tasksToConsume, tt.currentBalance, newBalance, tt.expectedBalance)
}
}
})
}
}
func TestTaskServiceErrors(t *testing.T) {
// Test error constants
if ErrTaskLimitReached == nil {
t.Error("ErrTaskLimitReached should not be nil")
}
if ErrTaskLimitReached.Error() != "TASK_LIMIT_REACHED" {
t.Errorf("ErrTaskLimitReached should be 'TASK_LIMIT_REACHED', got '%s'", ErrTaskLimitReached.Error())
}
if ErrNoSubscription == nil {
t.Error("ErrNoSubscription should not be nil")
}
if ErrNoSubscription.Error() != "NO_SUBSCRIPTION" {
t.Errorf("ErrNoSubscription should be 'NO_SUBSCRIPTION', got '%s'", ErrNoSubscription.Error())
}
}
func TestRenewalDateCalculation(t *testing.T) {
tests := []struct {
name string
lastRenewal time.Time
monthsToAdd int
expectedRenewal time.Time
}{
{
name: "Add one month",
lastRenewal: time.Date(2025, 1, 15, 0, 0, 0, 0, time.UTC),
monthsToAdd: 1,
expectedRenewal: time.Date(2025, 2, 15, 0, 0, 0, 0, time.UTC),
},
{
name: "Add three months",
lastRenewal: time.Date(2025, 1, 15, 0, 0, 0, 0, time.UTC),
monthsToAdd: 3,
expectedRenewal: time.Date(2025, 4, 15, 0, 0, 0, 0, time.UTC),
},
{
name: "Year boundary",
lastRenewal: time.Date(2024, 11, 15, 0, 0, 0, 0, time.UTC),
monthsToAdd: 3,
expectedRenewal: time.Date(2025, 2, 15, 0, 0, 0, 0, time.UTC),
},
{
name: "End of month adjustment",
lastRenewal: time.Date(2025, 1, 31, 0, 0, 0, 0, time.UTC),
monthsToAdd: 1,
// Go's AddDate handles this - February doesn't have 31 days
expectedRenewal: time.Date(2025, 3, 3, 0, 0, 0, 0, time.UTC), // Feb 31 -> March 3
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := tt.lastRenewal.AddDate(0, tt.monthsToAdd, 0)
if !result.Equal(tt.expectedRenewal) {
t.Errorf("AddDate(%v, %d months) = %v, expected %v",
tt.lastRenewal.Format("2006-01-02"), tt.monthsToAdd,
result.Format("2006-01-02"), tt.expectedRenewal.Format("2006-01-02"))
}
})
}
}
func TestFairUseModeLogic(t *testing.T) {
// Test that Fair Use mode always allows tasks regardless of balance
tests := []struct {
name string
fairUseMode bool
balance int
shouldAllow bool
}{
{
name: "Fair Use - zero balance still allowed",
fairUseMode: true,
balance: 0,
shouldAllow: true,
},
{
name: "Fair Use - normal balance allowed",
fairUseMode: true,
balance: 1000,
shouldAllow: true,
},
{
name: "Not Fair Use - zero balance not allowed",
fairUseMode: false,
balance: 0,
shouldAllow: false,
},
{
name: "Not Fair Use - positive balance allowed",
fairUseMode: false,
balance: 50,
shouldAllow: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Simulate the check logic
var allowed bool
if tt.fairUseMode {
allowed = true // Fair Use always allows
} else {
allowed = tt.balance > 0
}
if allowed != tt.shouldAllow {
t.Errorf("FairUseMode=%v, balance=%d: allowed=%v, expected=%v",
tt.fairUseMode, tt.balance, allowed, tt.shouldAllow)
}
})
}
}
func TestBalanceDecrementLogic(t *testing.T) {
// Test that Fair Use mode doesn't decrement balance
tests := []struct {
name string
fairUseMode bool
initialBalance int
expectedAfter int
}{
{
name: "Normal plan - decrement",
fairUseMode: false,
initialBalance: 50,
expectedAfter: 49,
},
{
name: "Fair Use - no decrement",
fairUseMode: true,
initialBalance: 1000,
expectedAfter: 1000,
},
{
name: "Normal plan - last task",
fairUseMode: false,
initialBalance: 1,
expectedAfter: 0,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
newBalance := tt.initialBalance
if !tt.fairUseMode {
newBalance = tt.initialBalance - 1
}
if newBalance != tt.expectedAfter {
t.Errorf("FairUseMode=%v, initial=%d: got %d, expected %d",
tt.fairUseMode, tt.initialBalance, newBalance, tt.expectedAfter)
}
})
}
}

View File

@@ -0,0 +1,194 @@
package services
import (
"context"
"fmt"
"time"
"github.com/breakpilot/billing-service/internal/database"
"github.com/breakpilot/billing-service/internal/models"
"github.com/google/uuid"
)
// UsageService handles usage tracking operations
type UsageService struct {
db *database.DB
entitlementService *EntitlementService
}
// NewUsageService creates a new UsageService
func NewUsageService(db *database.DB, entitlementService *EntitlementService) *UsageService {
return &UsageService{
db: db,
entitlementService: entitlementService,
}
}
// TrackUsage tracks usage for a user
func (s *UsageService) TrackUsage(ctx context.Context, userIDStr, usageType string, quantity int) error {
userID, err := uuid.Parse(userIDStr)
if err != nil {
return fmt.Errorf("invalid user ID: %w", err)
}
// Get current period start (beginning of current month)
now := time.Now()
periodStart := time.Date(now.Year(), now.Month(), 1, 0, 0, 0, 0, time.UTC)
// Upsert usage summary
query := `
INSERT INTO usage_summary (user_id, usage_type, period_start, total_count)
VALUES ($1, $2, $3, $4)
ON CONFLICT (user_id, usage_type, period_start) DO UPDATE SET
total_count = usage_summary.total_count + EXCLUDED.total_count,
updated_at = NOW()
`
_, err = s.db.Pool.Exec(ctx, query, userID, usageType, periodStart, quantity)
if err != nil {
return fmt.Errorf("failed to track usage: %w", err)
}
// Also update entitlements cache
return s.entitlementService.IncrementUsage(ctx, userID, usageType, quantity)
}
// GetUsageSummary returns usage summary for a user
func (s *UsageService) GetUsageSummary(ctx context.Context, userID uuid.UUID) (*models.UsageInfo, error) {
// Get entitlements (which include current usage)
ent, err := s.entitlementService.getUserEntitlements(ctx, userID)
if err != nil || ent == nil {
return nil, err
}
// Calculate percentages
aiPercent := 0.0
if ent.AIRequestsLimit > 0 {
aiPercent = float64(ent.AIRequestsUsed) / float64(ent.AIRequestsLimit) * 100
}
docPercent := 0.0
if ent.DocumentsLimit > 0 {
docPercent = float64(ent.DocumentsUsed) / float64(ent.DocumentsLimit) * 100
}
// Get period dates
now := time.Now()
periodStart := time.Date(now.Year(), now.Month(), 1, 0, 0, 0, 0, time.UTC)
periodEnd := periodStart.AddDate(0, 1, 0).Add(-time.Second)
return &models.UsageInfo{
AIRequestsUsed: ent.AIRequestsUsed,
AIRequestsLimit: ent.AIRequestsLimit,
AIRequestsPercent: aiPercent,
DocumentsUsed: ent.DocumentsUsed,
DocumentsLimit: ent.DocumentsLimit,
DocumentsPercent: docPercent,
PeriodStart: periodStart.Format("2006-01-02"),
PeriodEnd: periodEnd.Format("2006-01-02"),
}, nil
}
// CheckUsageAllowed checks if a user is allowed to perform a usage action
func (s *UsageService) CheckUsageAllowed(ctx context.Context, userIDStr, usageType string) (*models.CheckUsageResponse, error) {
userID, err := uuid.Parse(userIDStr)
if err != nil {
return &models.CheckUsageResponse{
Allowed: false,
Message: "Invalid user ID",
}, nil
}
// Get entitlements
ent, err := s.entitlementService.getUserEntitlements(ctx, userID)
if err != nil {
return &models.CheckUsageResponse{
Allowed: false,
Message: "Failed to get entitlements",
}, nil
}
if ent == nil {
return &models.CheckUsageResponse{
Allowed: false,
Message: "No subscription found",
}, nil
}
var currentUsage, limit int
switch usageType {
case "ai_request":
currentUsage = ent.AIRequestsUsed
limit = ent.AIRequestsLimit
case "document_created":
currentUsage = ent.DocumentsUsed
limit = ent.DocumentsLimit
default:
return &models.CheckUsageResponse{
Allowed: true,
Message: "Unknown usage type - allowing",
}, nil
}
remaining := limit - currentUsage
allowed := remaining > 0
response := &models.CheckUsageResponse{
Allowed: allowed,
CurrentUsage: currentUsage,
Limit: limit,
Remaining: remaining,
}
if !allowed {
response.Message = fmt.Sprintf("Usage limit reached for %s (%d/%d)", usageType, currentUsage, limit)
}
return response, nil
}
// GetUsageHistory returns usage history for a user
func (s *UsageService) GetUsageHistory(ctx context.Context, userID uuid.UUID, months int) ([]models.UsageSummary, error) {
query := `
SELECT id, user_id, usage_type, period_start, total_count, created_at, updated_at
FROM usage_summary
WHERE user_id = $1
AND period_start >= $2
ORDER BY period_start DESC, usage_type
`
// Calculate start date
startDate := time.Now().AddDate(0, -months, 0)
startDate = time.Date(startDate.Year(), startDate.Month(), 1, 0, 0, 0, 0, time.UTC)
rows, err := s.db.Pool.Query(ctx, query, userID, startDate)
if err != nil {
return nil, err
}
defer rows.Close()
var summaries []models.UsageSummary
for rows.Next() {
var summary models.UsageSummary
err := rows.Scan(
&summary.ID, &summary.UserID, &summary.UsageType,
&summary.PeriodStart, &summary.TotalCount,
&summary.CreatedAt, &summary.UpdatedAt,
)
if err != nil {
return nil, err
}
summaries = append(summaries, summary)
}
return summaries, nil
}
// ResetPeriodUsage resets usage for a new billing period
func (s *UsageService) ResetPeriodUsage(ctx context.Context, userID uuid.UUID) error {
now := time.Now()
newPeriodStart := time.Date(now.Year(), now.Month(), 1, 0, 0, 0, 0, time.UTC)
newPeriodEnd := newPeriodStart.AddDate(0, 1, 0).Add(-time.Second)
return s.entitlementService.ResetUsageCounters(ctx, userID, &newPeriodStart, &newPeriodEnd)
}