fix: Restore all files lost during destructive rebase
A previous `git pull --rebase origin main` dropped 177 local commits,
losing 3400+ files across admin-v2, backend, studio-v2, website,
klausur-service, and many other services. The partial restore attempt
(660295e2) only recovered some files.
This commit restores all missing files from pre-rebase ref 98933f5e
while preserving post-rebase additions (night-scheduler, night-mode UI,
NightModeWidget dashboard integration).
Restored features include:
- AI Module Sidebar (FAB), OCR Labeling, OCR Compare
- GPU Dashboard, RAG Pipeline, Magic Help
- Klausur-Korrektur (8 files), Abitur-Archiv (5+ files)
- Companion, Zeugnisse-Crawler, Screen Flow
- Full backend, studio-v2, website, klausur-service
- All compliance SDKs, agent-core, voice-service
- CI/CD configs, documentation, scripts
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
427
billing-service/internal/handlers/billing_handlers.go
Normal file
427
billing-service/internal/handlers/billing_handlers.go
Normal file
@@ -0,0 +1,427 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/breakpilot/billing-service/internal/database"
|
||||
"github.com/breakpilot/billing-service/internal/middleware"
|
||||
"github.com/breakpilot/billing-service/internal/models"
|
||||
"github.com/breakpilot/billing-service/internal/services"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// BillingHandler handles billing-related HTTP requests
|
||||
type BillingHandler struct {
|
||||
db *database.DB
|
||||
subscriptionService *services.SubscriptionService
|
||||
stripeService *services.StripeService
|
||||
entitlementService *services.EntitlementService
|
||||
usageService *services.UsageService
|
||||
}
|
||||
|
||||
// NewBillingHandler creates a new BillingHandler
|
||||
func NewBillingHandler(
|
||||
db *database.DB,
|
||||
subscriptionService *services.SubscriptionService,
|
||||
stripeService *services.StripeService,
|
||||
entitlementService *services.EntitlementService,
|
||||
usageService *services.UsageService,
|
||||
) *BillingHandler {
|
||||
return &BillingHandler{
|
||||
db: db,
|
||||
subscriptionService: subscriptionService,
|
||||
stripeService: stripeService,
|
||||
entitlementService: entitlementService,
|
||||
usageService: usageService,
|
||||
}
|
||||
}
|
||||
|
||||
// GetBillingStatus returns the current billing status for a user
|
||||
// GET /api/v1/billing/status
|
||||
func (h *BillingHandler) GetBillingStatus(c *gin.Context) {
|
||||
userID, err := middleware.GetUserID(c)
|
||||
if err != nil || userID.String() == "" {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{
|
||||
"error": "unauthorized",
|
||||
"message": "User not authenticated",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
ctx := c.Request.Context()
|
||||
|
||||
// Get subscription
|
||||
subscription, err := h.subscriptionService.GetByUserID(ctx, userID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": "internal_error",
|
||||
"message": "Failed to get subscription",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Get available plans
|
||||
plans, err := h.subscriptionService.GetAvailablePlans(ctx)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": "internal_error",
|
||||
"message": "Failed to get plans",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
response := models.BillingStatusResponse{
|
||||
HasSubscription: subscription != nil,
|
||||
AvailablePlans: plans,
|
||||
}
|
||||
|
||||
if subscription != nil {
|
||||
// Get plan details
|
||||
plan, _ := h.subscriptionService.GetPlanByID(ctx, string(subscription.PlanID))
|
||||
|
||||
subInfo := &models.SubscriptionInfo{
|
||||
PlanID: subscription.PlanID,
|
||||
Status: subscription.Status,
|
||||
IsTrialing: subscription.Status == models.StatusTrialing,
|
||||
CancelAtPeriodEnd: subscription.CancelAtPeriodEnd,
|
||||
CurrentPeriodEnd: subscription.CurrentPeriodEnd,
|
||||
}
|
||||
|
||||
if plan != nil {
|
||||
subInfo.PlanName = plan.Name
|
||||
subInfo.PriceCents = plan.PriceCents
|
||||
subInfo.Currency = plan.Currency
|
||||
}
|
||||
|
||||
// Calculate trial days left
|
||||
if subscription.TrialEnd != nil && subscription.Status == models.StatusTrialing {
|
||||
// TODO: Calculate days left
|
||||
}
|
||||
|
||||
response.Subscription = subInfo
|
||||
|
||||
// Get task usage info (legacy usage tracking - see TaskService for new task-based usage)
|
||||
// TODO: Replace with TaskService.GetTaskUsageInfo for task-based billing
|
||||
_, _ = h.usageService.GetUsageSummary(ctx, userID)
|
||||
|
||||
// Get entitlements
|
||||
entitlements, _ := h.entitlementService.GetEntitlements(ctx, userID)
|
||||
if entitlements != nil {
|
||||
response.Entitlements = entitlements
|
||||
}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, response)
|
||||
}
|
||||
|
||||
// GetPlans returns all available billing plans
|
||||
// GET /api/v1/billing/plans
|
||||
func (h *BillingHandler) GetPlans(c *gin.Context) {
|
||||
ctx := c.Request.Context()
|
||||
|
||||
plans, err := h.subscriptionService.GetAvailablePlans(ctx)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": "internal_error",
|
||||
"message": "Failed to get plans",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"plans": plans,
|
||||
})
|
||||
}
|
||||
|
||||
// StartTrial starts a trial for the user with a specific plan
|
||||
// POST /api/v1/billing/trial/start
|
||||
func (h *BillingHandler) StartTrial(c *gin.Context) {
|
||||
userID, err := middleware.GetUserID(c)
|
||||
if err != nil || userID.String() == "" {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{
|
||||
"error": "unauthorized",
|
||||
"message": "User not authenticated",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
var req models.StartTrialRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": "invalid_request",
|
||||
"message": "Invalid request body",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
ctx := c.Request.Context()
|
||||
|
||||
// Check if user already has a subscription
|
||||
existing, _ := h.subscriptionService.GetByUserID(ctx, userID)
|
||||
if existing != nil {
|
||||
c.JSON(http.StatusConflict, gin.H{
|
||||
"error": "subscription_exists",
|
||||
"message": "User already has a subscription",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Get user email from context
|
||||
email, _ := c.Get("email")
|
||||
emailStr, _ := email.(string)
|
||||
|
||||
// Create Stripe checkout session
|
||||
checkoutURL, sessionID, err := h.stripeService.CreateCheckoutSession(ctx, userID, emailStr, req.PlanID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": "stripe_error",
|
||||
"message": "Failed to create checkout session",
|
||||
"details": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, models.StartTrialResponse{
|
||||
CheckoutURL: checkoutURL,
|
||||
SessionID: sessionID,
|
||||
})
|
||||
}
|
||||
|
||||
// ChangePlan changes the user's subscription plan
|
||||
// POST /api/v1/billing/change-plan
|
||||
func (h *BillingHandler) ChangePlan(c *gin.Context) {
|
||||
userID, err := middleware.GetUserID(c)
|
||||
if err != nil || userID.String() == "" {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{
|
||||
"error": "unauthorized",
|
||||
"message": "User not authenticated",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
var req models.ChangePlanRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": "invalid_request",
|
||||
"message": "Invalid request body",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
ctx := c.Request.Context()
|
||||
|
||||
// Get current subscription
|
||||
subscription, err := h.subscriptionService.GetByUserID(ctx, userID)
|
||||
if err != nil || subscription == nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{
|
||||
"error": "no_subscription",
|
||||
"message": "No active subscription found",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Change plan via Stripe
|
||||
err = h.stripeService.ChangePlan(ctx, subscription.StripeSubscriptionID, req.NewPlanID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": "stripe_error",
|
||||
"message": "Failed to change plan",
|
||||
"details": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, models.ChangePlanResponse{
|
||||
Success: true,
|
||||
Message: "Plan changed successfully",
|
||||
})
|
||||
}
|
||||
|
||||
// CancelSubscription cancels the user's subscription at period end
|
||||
// POST /api/v1/billing/cancel
|
||||
func (h *BillingHandler) CancelSubscription(c *gin.Context) {
|
||||
userID, err := middleware.GetUserID(c)
|
||||
if err != nil || userID.String() == "" {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{
|
||||
"error": "unauthorized",
|
||||
"message": "User not authenticated",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
ctx := c.Request.Context()
|
||||
|
||||
// Get current subscription
|
||||
subscription, err := h.subscriptionService.GetByUserID(ctx, userID)
|
||||
if err != nil || subscription == nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{
|
||||
"error": "no_subscription",
|
||||
"message": "No active subscription found",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Cancel at period end via Stripe
|
||||
err = h.stripeService.CancelSubscription(ctx, subscription.StripeSubscriptionID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": "stripe_error",
|
||||
"message": "Failed to cancel subscription",
|
||||
"details": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, models.CancelSubscriptionResponse{
|
||||
Success: true,
|
||||
Message: "Subscription will be canceled at the end of the billing period",
|
||||
})
|
||||
}
|
||||
|
||||
// GetCustomerPortal returns a URL to the Stripe customer portal
|
||||
// GET /api/v1/billing/portal
|
||||
func (h *BillingHandler) GetCustomerPortal(c *gin.Context) {
|
||||
userID, err := middleware.GetUserID(c)
|
||||
if err != nil || userID.String() == "" {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{
|
||||
"error": "unauthorized",
|
||||
"message": "User not authenticated",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
ctx := c.Request.Context()
|
||||
|
||||
// Get current subscription
|
||||
subscription, err := h.subscriptionService.GetByUserID(ctx, userID)
|
||||
if err != nil || subscription == nil || subscription.StripeCustomerID == "" {
|
||||
c.JSON(http.StatusNotFound, gin.H{
|
||||
"error": "no_subscription",
|
||||
"message": "No active subscription found",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Create portal session
|
||||
portalURL, err := h.stripeService.CreateCustomerPortalSession(ctx, subscription.StripeCustomerID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": "stripe_error",
|
||||
"message": "Failed to create portal session",
|
||||
"details": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, models.CustomerPortalResponse{
|
||||
PortalURL: portalURL,
|
||||
})
|
||||
}
|
||||
|
||||
// =============================================
|
||||
// Internal Endpoints (Service-to-Service)
|
||||
// =============================================
|
||||
|
||||
// GetEntitlements returns entitlements for a user (internal)
|
||||
// GET /api/v1/billing/entitlements/:userId
|
||||
func (h *BillingHandler) GetEntitlements(c *gin.Context) {
|
||||
userIDStr := c.Param("userId")
|
||||
|
||||
ctx := c.Request.Context()
|
||||
|
||||
entitlements, err := h.entitlementService.GetEntitlementsByUserIDString(ctx, userIDStr)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": "internal_error",
|
||||
"message": "Failed to get entitlements",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if entitlements == nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{
|
||||
"error": "not_found",
|
||||
"message": "No entitlements found for user",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, entitlements)
|
||||
}
|
||||
|
||||
// TrackUsage tracks usage for a user (internal)
|
||||
// POST /api/v1/billing/usage/track
|
||||
func (h *BillingHandler) TrackUsage(c *gin.Context) {
|
||||
var req models.TrackUsageRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": "invalid_request",
|
||||
"message": "Invalid request body",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
ctx := c.Request.Context()
|
||||
|
||||
quantity := req.Quantity
|
||||
if quantity <= 0 {
|
||||
quantity = 1
|
||||
}
|
||||
|
||||
err := h.usageService.TrackUsage(ctx, req.UserID, req.UsageType, quantity)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": "internal_error",
|
||||
"message": "Failed to track usage",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "Usage tracked",
|
||||
})
|
||||
}
|
||||
|
||||
// CheckUsage checks if usage is allowed (internal)
|
||||
// GET /api/v1/billing/usage/check/:userId/:type
|
||||
func (h *BillingHandler) CheckUsage(c *gin.Context) {
|
||||
userIDStr := c.Param("userId")
|
||||
usageType := c.Param("type")
|
||||
|
||||
ctx := c.Request.Context()
|
||||
|
||||
response, err := h.usageService.CheckUsageAllowed(ctx, userIDStr, usageType)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": "internal_error",
|
||||
"message": "Failed to check usage",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, response)
|
||||
}
|
||||
|
||||
// CheckEntitlement checks if a user has a specific entitlement (internal)
|
||||
// GET /api/v1/billing/entitlements/check/:userId/:feature
|
||||
func (h *BillingHandler) CheckEntitlement(c *gin.Context) {
|
||||
userIDStr := c.Param("userId")
|
||||
feature := c.Param("feature")
|
||||
|
||||
ctx := c.Request.Context()
|
||||
|
||||
hasEntitlement, planID, err := h.entitlementService.CheckEntitlement(ctx, userIDStr, feature)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": "internal_error",
|
||||
"message": "Failed to check entitlement",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, models.EntitlementCheckResponse{
|
||||
HasEntitlement: hasEntitlement,
|
||||
PlanID: planID,
|
||||
})
|
||||
}
|
||||
612
billing-service/internal/handlers/billing_handlers_test.go
Normal file
612
billing-service/internal/handlers/billing_handlers_test.go
Normal file
@@ -0,0 +1,612 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/breakpilot/billing-service/internal/models"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func init() {
|
||||
// Set Gin to test mode
|
||||
gin.SetMode(gin.TestMode)
|
||||
}
|
||||
|
||||
func TestGetPlans_ResponseFormat(t *testing.T) {
|
||||
// Test that GetPlans returns the expected response structure
|
||||
// Since we don't have a real database connection in unit tests,
|
||||
// we test the expected structure and format
|
||||
|
||||
// Test that default plans are well-formed
|
||||
plans := models.GetDefaultPlans()
|
||||
|
||||
if len(plans) == 0 {
|
||||
t.Error("Default plans should not be empty")
|
||||
}
|
||||
|
||||
for _, plan := range plans {
|
||||
// Verify JSON serialization works
|
||||
data, err := json.Marshal(plan)
|
||||
if err != nil {
|
||||
t.Errorf("Failed to marshal plan %s: %v", plan.ID, err)
|
||||
}
|
||||
|
||||
// Verify we can unmarshal back
|
||||
var decoded models.BillingPlan
|
||||
err = json.Unmarshal(data, &decoded)
|
||||
if err != nil {
|
||||
t.Errorf("Failed to unmarshal plan %s: %v", plan.ID, err)
|
||||
}
|
||||
|
||||
// Verify key fields
|
||||
if decoded.ID != plan.ID {
|
||||
t.Errorf("Plan ID mismatch: got %s, expected %s", decoded.ID, plan.ID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBillingStatusResponse_Structure(t *testing.T) {
|
||||
// Test the response structure
|
||||
response := models.BillingStatusResponse{
|
||||
HasSubscription: true,
|
||||
Subscription: &models.SubscriptionInfo{
|
||||
PlanID: models.PlanStandard,
|
||||
PlanName: "Standard",
|
||||
Status: models.StatusActive,
|
||||
IsTrialing: false,
|
||||
CancelAtPeriodEnd: false,
|
||||
PriceCents: 1990,
|
||||
Currency: "eur",
|
||||
},
|
||||
TaskUsage: &models.TaskUsageInfo{
|
||||
TasksAvailable: 85,
|
||||
MaxTasks: 500,
|
||||
InfoText: "Aufgaben verfuegbar: 85 von max. 500",
|
||||
TooltipText: "Aufgaben koennen sich bis zu 5 Monate ansammeln.",
|
||||
},
|
||||
Entitlements: &models.EntitlementInfo{
|
||||
Features: []string{"basic_ai", "basic_documents", "templates", "batch_processing"},
|
||||
MaxTeamMembers: 3,
|
||||
PrioritySupport: false,
|
||||
CustomBranding: false,
|
||||
BatchProcessing: true,
|
||||
CustomTemplates: true,
|
||||
FairUseMode: false,
|
||||
},
|
||||
AvailablePlans: models.GetDefaultPlans(),
|
||||
}
|
||||
|
||||
// Test JSON serialization
|
||||
data, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal BillingStatusResponse: %v", err)
|
||||
}
|
||||
|
||||
// Verify it's valid JSON
|
||||
var decoded map[string]interface{}
|
||||
err = json.Unmarshal(data, &decoded)
|
||||
if err != nil {
|
||||
t.Fatalf("Response is not valid JSON: %v", err)
|
||||
}
|
||||
|
||||
// Check required fields exist
|
||||
if _, ok := decoded["has_subscription"]; !ok {
|
||||
t.Error("Response should have 'has_subscription' field")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStartTrialRequest_Validation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
request models.StartTrialRequest
|
||||
wantError bool
|
||||
}{
|
||||
{
|
||||
name: "Valid basic plan",
|
||||
request: models.StartTrialRequest{PlanID: models.PlanBasic},
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "Valid standard plan",
|
||||
request: models.StartTrialRequest{PlanID: models.PlanStandard},
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "Valid premium plan",
|
||||
request: models.StartTrialRequest{PlanID: models.PlanPremium},
|
||||
wantError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Test JSON serialization
|
||||
data, err := json.Marshal(tt.request)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal request: %v", err)
|
||||
}
|
||||
|
||||
var decoded models.StartTrialRequest
|
||||
err = json.Unmarshal(data, &decoded)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to unmarshal request: %v", err)
|
||||
}
|
||||
|
||||
if decoded.PlanID != tt.request.PlanID {
|
||||
t.Errorf("PlanID mismatch: got %s, expected %s", decoded.PlanID, tt.request.PlanID)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestChangePlanRequest_Structure(t *testing.T) {
|
||||
request := models.ChangePlanRequest{
|
||||
NewPlanID: models.PlanPremium,
|
||||
}
|
||||
|
||||
data, err := json.Marshal(request)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal ChangePlanRequest: %v", err)
|
||||
}
|
||||
|
||||
var decoded map[string]interface{}
|
||||
err = json.Unmarshal(data, &decoded)
|
||||
if err != nil {
|
||||
t.Fatalf("Response is not valid JSON: %v", err)
|
||||
}
|
||||
|
||||
if _, ok := decoded["new_plan_id"]; !ok {
|
||||
t.Error("Request should have 'new_plan_id' field")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStartTrialResponse_Structure(t *testing.T) {
|
||||
response := models.StartTrialResponse{
|
||||
CheckoutURL: "https://checkout.stripe.com/c/pay/cs_test_123",
|
||||
SessionID: "cs_test_123",
|
||||
}
|
||||
|
||||
data, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal StartTrialResponse: %v", err)
|
||||
}
|
||||
|
||||
var decoded map[string]interface{}
|
||||
err = json.Unmarshal(data, &decoded)
|
||||
if err != nil {
|
||||
t.Fatalf("Response is not valid JSON: %v", err)
|
||||
}
|
||||
|
||||
if _, ok := decoded["checkout_url"]; !ok {
|
||||
t.Error("Response should have 'checkout_url' field")
|
||||
}
|
||||
if _, ok := decoded["session_id"]; !ok {
|
||||
t.Error("Response should have 'session_id' field")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCancelSubscriptionResponse_Structure(t *testing.T) {
|
||||
response := models.CancelSubscriptionResponse{
|
||||
Success: true,
|
||||
Message: "Subscription will be canceled at the end of the billing period",
|
||||
CancelDate: "2025-01-16",
|
||||
ActiveUntil: "2025-01-16",
|
||||
}
|
||||
|
||||
_, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal CancelSubscriptionResponse: %v", err)
|
||||
}
|
||||
|
||||
if !response.Success {
|
||||
t.Error("Success should be true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCustomerPortalResponse_Structure(t *testing.T) {
|
||||
response := models.CustomerPortalResponse{
|
||||
PortalURL: "https://billing.stripe.com/p/session/test_123",
|
||||
}
|
||||
|
||||
data, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal CustomerPortalResponse: %v", err)
|
||||
}
|
||||
|
||||
var decoded map[string]interface{}
|
||||
err = json.Unmarshal(data, &decoded)
|
||||
if err != nil {
|
||||
t.Fatalf("Response is not valid JSON: %v", err)
|
||||
}
|
||||
|
||||
if _, ok := decoded["portal_url"]; !ok {
|
||||
t.Error("Response should have 'portal_url' field")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEntitlementCheckResponse_Structure(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
response models.EntitlementCheckResponse
|
||||
}{
|
||||
{
|
||||
name: "Has entitlement",
|
||||
response: models.EntitlementCheckResponse{
|
||||
HasEntitlement: true,
|
||||
PlanID: models.PlanStandard,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "No entitlement",
|
||||
response: models.EntitlementCheckResponse{
|
||||
HasEntitlement: false,
|
||||
PlanID: models.PlanBasic,
|
||||
Message: "Feature not available in this plan",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
data, err := json.Marshal(tt.response)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal EntitlementCheckResponse: %v", err)
|
||||
}
|
||||
|
||||
var decoded map[string]interface{}
|
||||
err = json.Unmarshal(data, &decoded)
|
||||
if err != nil {
|
||||
t.Fatalf("Response is not valid JSON: %v", err)
|
||||
}
|
||||
|
||||
if _, ok := decoded["has_entitlement"]; !ok {
|
||||
t.Error("Response should have 'has_entitlement' field")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTrackUsageRequest_Validation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
request models.TrackUsageRequest
|
||||
valid bool
|
||||
}{
|
||||
{
|
||||
name: "Valid AI request",
|
||||
request: models.TrackUsageRequest{
|
||||
UserID: "550e8400-e29b-41d4-a716-446655440000",
|
||||
UsageType: "ai_request",
|
||||
Quantity: 1,
|
||||
},
|
||||
valid: true,
|
||||
},
|
||||
{
|
||||
name: "Valid document created",
|
||||
request: models.TrackUsageRequest{
|
||||
UserID: "550e8400-e29b-41d4-a716-446655440000",
|
||||
UsageType: "document_created",
|
||||
Quantity: 1,
|
||||
},
|
||||
valid: true,
|
||||
},
|
||||
{
|
||||
name: "Multiple quantity",
|
||||
request: models.TrackUsageRequest{
|
||||
UserID: "550e8400-e29b-41d4-a716-446655440000",
|
||||
UsageType: "ai_request",
|
||||
Quantity: 5,
|
||||
},
|
||||
valid: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
data, err := json.Marshal(tt.request)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal TrackUsageRequest: %v", err)
|
||||
}
|
||||
|
||||
var decoded models.TrackUsageRequest
|
||||
err = json.Unmarshal(data, &decoded)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to unmarshal TrackUsageRequest: %v", err)
|
||||
}
|
||||
|
||||
if decoded.UserID != tt.request.UserID {
|
||||
t.Errorf("UserID mismatch: got %s, expected %s", decoded.UserID, tt.request.UserID)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckUsageResponse_Format(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
response models.CheckUsageResponse
|
||||
}{
|
||||
{
|
||||
name: "Allowed response",
|
||||
response: models.CheckUsageResponse{
|
||||
Allowed: true,
|
||||
CurrentUsage: 450,
|
||||
Limit: 1500,
|
||||
Remaining: 1050,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Limit reached",
|
||||
response: models.CheckUsageResponse{
|
||||
Allowed: false,
|
||||
CurrentUsage: 1500,
|
||||
Limit: 1500,
|
||||
Remaining: 0,
|
||||
Message: "Usage limit reached for ai_request (1500/1500)",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
data, err := json.Marshal(tt.response)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal CheckUsageResponse: %v", err)
|
||||
}
|
||||
|
||||
var decoded map[string]interface{}
|
||||
err = json.Unmarshal(data, &decoded)
|
||||
if err != nil {
|
||||
t.Fatalf("Response is not valid JSON: %v", err)
|
||||
}
|
||||
|
||||
if _, ok := decoded["allowed"]; !ok {
|
||||
t.Error("Response should have 'allowed' field")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConsumeTaskRequest_Format(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
request models.ConsumeTaskRequest
|
||||
}{
|
||||
{
|
||||
name: "Correction task",
|
||||
request: models.ConsumeTaskRequest{
|
||||
UserID: "550e8400-e29b-41d4-a716-446655440000",
|
||||
TaskType: models.TaskTypeCorrection,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Letter task",
|
||||
request: models.ConsumeTaskRequest{
|
||||
UserID: "550e8400-e29b-41d4-a716-446655440000",
|
||||
TaskType: models.TaskTypeLetter,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Batch task",
|
||||
request: models.ConsumeTaskRequest{
|
||||
UserID: "550e8400-e29b-41d4-a716-446655440000",
|
||||
TaskType: models.TaskTypeBatch,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
data, err := json.Marshal(tt.request)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal ConsumeTaskRequest: %v", err)
|
||||
}
|
||||
|
||||
var decoded models.ConsumeTaskRequest
|
||||
err = json.Unmarshal(data, &decoded)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to unmarshal ConsumeTaskRequest: %v", err)
|
||||
}
|
||||
|
||||
if decoded.TaskType != tt.request.TaskType {
|
||||
t.Errorf("TaskType mismatch: got %s, expected %s", decoded.TaskType, tt.request.TaskType)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConsumeTaskResponse_Format(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
response models.ConsumeTaskResponse
|
||||
}{
|
||||
{
|
||||
name: "Successful consumption",
|
||||
response: models.ConsumeTaskResponse{
|
||||
Success: true,
|
||||
TaskID: "task-uuid-123",
|
||||
TasksRemaining: 49,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Limit reached",
|
||||
response: models.ConsumeTaskResponse{
|
||||
Success: false,
|
||||
TasksRemaining: 0,
|
||||
Message: "Dein Aufgaben-Kontingent ist aufgebraucht.",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
data, err := json.Marshal(tt.response)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal ConsumeTaskResponse: %v", err)
|
||||
}
|
||||
|
||||
var decoded map[string]interface{}
|
||||
err = json.Unmarshal(data, &decoded)
|
||||
if err != nil {
|
||||
t.Fatalf("Response is not valid JSON: %v", err)
|
||||
}
|
||||
|
||||
if _, ok := decoded["success"]; !ok {
|
||||
t.Error("Response should have 'success' field")
|
||||
}
|
||||
if _, ok := decoded["tasks_remaining"]; !ok {
|
||||
t.Error("Response should have 'tasks_remaining' field")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckTaskAllowedResponse_Format(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
response models.CheckTaskAllowedResponse
|
||||
}{
|
||||
{
|
||||
name: "Task allowed",
|
||||
response: models.CheckTaskAllowedResponse{
|
||||
Allowed: true,
|
||||
TasksAvailable: 50,
|
||||
MaxTasks: 150,
|
||||
PlanID: models.PlanBasic,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Task not allowed",
|
||||
response: models.CheckTaskAllowedResponse{
|
||||
Allowed: false,
|
||||
TasksAvailable: 0,
|
||||
MaxTasks: 150,
|
||||
PlanID: models.PlanBasic,
|
||||
Message: "Dein Aufgaben-Kontingent ist aufgebraucht.",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Premium Fair Use",
|
||||
response: models.CheckTaskAllowedResponse{
|
||||
Allowed: true,
|
||||
TasksAvailable: 1000,
|
||||
MaxTasks: 5000,
|
||||
PlanID: models.PlanPremium,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
data, err := json.Marshal(tt.response)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal CheckTaskAllowedResponse: %v", err)
|
||||
}
|
||||
|
||||
var decoded map[string]interface{}
|
||||
err = json.Unmarshal(data, &decoded)
|
||||
if err != nil {
|
||||
t.Fatalf("Response is not valid JSON: %v", err)
|
||||
}
|
||||
|
||||
if _, ok := decoded["allowed"]; !ok {
|
||||
t.Error("Response should have 'allowed' field")
|
||||
}
|
||||
if _, ok := decoded["tasks_available"]; !ok {
|
||||
t.Error("Response should have 'tasks_available' field")
|
||||
}
|
||||
if _, ok := decoded["plan_id"]; !ok {
|
||||
t.Error("Response should have 'plan_id' field")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// HTTP Handler Tests (without DB)
|
||||
|
||||
func TestHTTPErrorResponse_Format(t *testing.T) {
|
||||
// Test standard error response format
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
// Simulate an error response
|
||||
c.JSON(http.StatusUnauthorized, gin.H{
|
||||
"error": "unauthorized",
|
||||
"message": "User not authenticated",
|
||||
})
|
||||
|
||||
if w.Code != http.StatusUnauthorized {
|
||||
t.Errorf("Expected status 401, got %d", w.Code)
|
||||
}
|
||||
|
||||
var response map[string]interface{}
|
||||
err := json.Unmarshal(w.Body.Bytes(), &response)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse response: %v", err)
|
||||
}
|
||||
|
||||
if _, ok := response["error"]; !ok {
|
||||
t.Error("Error response should have 'error' field")
|
||||
}
|
||||
if _, ok := response["message"]; !ok {
|
||||
t.Error("Error response should have 'message' field")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTPSuccessResponse_Format(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
// Simulate a success response
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "Operation completed",
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
var response map[string]interface{}
|
||||
err := json.Unmarshal(w.Body.Bytes(), &response)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse response: %v", err)
|
||||
}
|
||||
|
||||
if response["success"] != true {
|
||||
t.Error("Success response should have success=true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequestParsing_InvalidJSON(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
// Create request with invalid JSON
|
||||
invalidJSON := []byte(`{"plan_id": }`) // Invalid JSON
|
||||
c.Request = httptest.NewRequest("POST", "/test", bytes.NewReader(invalidJSON))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
var req models.StartTrialRequest
|
||||
err := c.ShouldBindJSON(&req)
|
||||
|
||||
if err == nil {
|
||||
t.Error("Should return error for invalid JSON")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTPHeaders_ContentType(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"test": "value"})
|
||||
|
||||
contentType := w.Header().Get("Content-Type")
|
||||
if contentType != "application/json; charset=utf-8" {
|
||||
t.Errorf("Expected JSON content type, got %s", contentType)
|
||||
}
|
||||
}
|
||||
205
billing-service/internal/handlers/webhook_handlers.go
Normal file
205
billing-service/internal/handlers/webhook_handlers.go
Normal file
@@ -0,0 +1,205 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
|
||||
"github.com/breakpilot/billing-service/internal/database"
|
||||
"github.com/breakpilot/billing-service/internal/services"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stripe/stripe-go/v76/webhook"
|
||||
)
|
||||
|
||||
// WebhookHandler handles Stripe webhook events
|
||||
type WebhookHandler struct {
|
||||
db *database.DB
|
||||
webhookSecret string
|
||||
subscriptionService *services.SubscriptionService
|
||||
entitlementService *services.EntitlementService
|
||||
}
|
||||
|
||||
// NewWebhookHandler creates a new WebhookHandler
|
||||
func NewWebhookHandler(
|
||||
db *database.DB,
|
||||
webhookSecret string,
|
||||
subscriptionService *services.SubscriptionService,
|
||||
entitlementService *services.EntitlementService,
|
||||
) *WebhookHandler {
|
||||
return &WebhookHandler{
|
||||
db: db,
|
||||
webhookSecret: webhookSecret,
|
||||
subscriptionService: subscriptionService,
|
||||
entitlementService: entitlementService,
|
||||
}
|
||||
}
|
||||
|
||||
// HandleStripeWebhook handles incoming Stripe webhook events
|
||||
// POST /api/v1/billing/webhook
|
||||
func (h *WebhookHandler) HandleStripeWebhook(c *gin.Context) {
|
||||
// Read the request body
|
||||
body, err := io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
log.Printf("Webhook: Error reading body: %v", err)
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "cannot read body"})
|
||||
return
|
||||
}
|
||||
|
||||
// Get the Stripe signature header
|
||||
sigHeader := c.GetHeader("Stripe-Signature")
|
||||
if sigHeader == "" {
|
||||
log.Printf("Webhook: Missing Stripe-Signature header")
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "missing signature"})
|
||||
return
|
||||
}
|
||||
|
||||
// Verify the webhook signature
|
||||
event, err := webhook.ConstructEvent(body, sigHeader, h.webhookSecret)
|
||||
if err != nil {
|
||||
log.Printf("Webhook: Signature verification failed: %v", err)
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "invalid signature"})
|
||||
return
|
||||
}
|
||||
|
||||
ctx := c.Request.Context()
|
||||
|
||||
// Check if we've already processed this event (idempotency)
|
||||
processed, err := h.subscriptionService.IsEventProcessed(ctx, event.ID)
|
||||
if err != nil {
|
||||
log.Printf("Webhook: Error checking event: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "internal error"})
|
||||
return
|
||||
}
|
||||
if processed {
|
||||
log.Printf("Webhook: Event %s already processed", event.ID)
|
||||
c.JSON(http.StatusOK, gin.H{"status": "already_processed"})
|
||||
return
|
||||
}
|
||||
|
||||
// Mark event as being processed
|
||||
if err := h.subscriptionService.MarkEventProcessing(ctx, event.ID, string(event.Type)); err != nil {
|
||||
log.Printf("Webhook: Error marking event: %v", err)
|
||||
}
|
||||
|
||||
// Handle the event based on type
|
||||
var handleErr error
|
||||
switch event.Type {
|
||||
case "checkout.session.completed":
|
||||
handleErr = h.handleCheckoutSessionCompleted(ctx, event.Data.Raw)
|
||||
|
||||
case "customer.subscription.created":
|
||||
handleErr = h.handleSubscriptionCreated(ctx, event.Data.Raw)
|
||||
|
||||
case "customer.subscription.updated":
|
||||
handleErr = h.handleSubscriptionUpdated(ctx, event.Data.Raw)
|
||||
|
||||
case "customer.subscription.deleted":
|
||||
handleErr = h.handleSubscriptionDeleted(ctx, event.Data.Raw)
|
||||
|
||||
case "invoice.paid":
|
||||
handleErr = h.handleInvoicePaid(ctx, event.Data.Raw)
|
||||
|
||||
case "invoice.payment_failed":
|
||||
handleErr = h.handleInvoicePaymentFailed(ctx, event.Data.Raw)
|
||||
|
||||
case "customer.created":
|
||||
log.Printf("Webhook: Customer created - %s", event.ID)
|
||||
|
||||
default:
|
||||
log.Printf("Webhook: Unhandled event type: %s", event.Type)
|
||||
}
|
||||
|
||||
if handleErr != nil {
|
||||
log.Printf("Webhook: Error handling %s: %v", event.Type, handleErr)
|
||||
// Mark event as failed
|
||||
h.subscriptionService.MarkEventFailed(ctx, event.ID, handleErr.Error())
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "handler error"})
|
||||
return
|
||||
}
|
||||
|
||||
// Mark event as processed
|
||||
if err := h.subscriptionService.MarkEventProcessed(ctx, event.ID); err != nil {
|
||||
log.Printf("Webhook: Error marking event processed: %v", err)
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"status": "processed"})
|
||||
}
|
||||
|
||||
// handleCheckoutSessionCompleted handles successful checkout
|
||||
func (h *WebhookHandler) handleCheckoutSessionCompleted(ctx interface{}, data []byte) error {
|
||||
log.Printf("Webhook: Processing checkout.session.completed")
|
||||
|
||||
// Parse checkout session from data
|
||||
// The actual implementation will parse the JSON and create/update subscription
|
||||
|
||||
// TODO: Implementation
|
||||
// 1. Parse checkout session data
|
||||
// 2. Extract customer_id, subscription_id, user_id (from metadata)
|
||||
// 3. Create or update subscription record
|
||||
// 4. Update entitlements
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleSubscriptionCreated handles new subscription creation
|
||||
func (h *WebhookHandler) handleSubscriptionCreated(ctx interface{}, data []byte) error {
|
||||
log.Printf("Webhook: Processing customer.subscription.created")
|
||||
|
||||
// TODO: Implementation
|
||||
// 1. Parse subscription data
|
||||
// 2. Extract status, plan, trial_end, etc.
|
||||
// 3. Create subscription record
|
||||
// 4. Set up initial entitlements
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleSubscriptionUpdated handles subscription updates
|
||||
func (h *WebhookHandler) handleSubscriptionUpdated(ctx interface{}, data []byte) error {
|
||||
log.Printf("Webhook: Processing customer.subscription.updated")
|
||||
|
||||
// TODO: Implementation
|
||||
// 1. Parse subscription data
|
||||
// 2. Update subscription record (status, plan, cancel_at_period_end, etc.)
|
||||
// 3. Update entitlements if plan changed
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleSubscriptionDeleted handles subscription cancellation
|
||||
func (h *WebhookHandler) handleSubscriptionDeleted(ctx interface{}, data []byte) error {
|
||||
log.Printf("Webhook: Processing customer.subscription.deleted")
|
||||
|
||||
// TODO: Implementation
|
||||
// 1. Parse subscription data
|
||||
// 2. Update subscription status to canceled/expired
|
||||
// 3. Remove or downgrade entitlements
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleInvoicePaid handles successful invoice payment
|
||||
func (h *WebhookHandler) handleInvoicePaid(ctx interface{}, data []byte) error {
|
||||
log.Printf("Webhook: Processing invoice.paid")
|
||||
|
||||
// TODO: Implementation
|
||||
// 1. Parse invoice data
|
||||
// 2. Update subscription period
|
||||
// 3. Reset usage counters for new period
|
||||
// 4. Store invoice record
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleInvoicePaymentFailed handles failed invoice payment
|
||||
func (h *WebhookHandler) handleInvoicePaymentFailed(ctx interface{}, data []byte) error {
|
||||
log.Printf("Webhook: Processing invoice.payment_failed")
|
||||
|
||||
// TODO: Implementation
|
||||
// 1. Parse invoice data
|
||||
// 2. Update subscription status to past_due
|
||||
// 3. Send notification to user
|
||||
// 4. Possibly restrict access
|
||||
|
||||
return nil
|
||||
}
|
||||
433
billing-service/internal/handlers/webhook_handlers_test.go
Normal file
433
billing-service/internal/handlers/webhook_handlers_test.go
Normal file
@@ -0,0 +1,433 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// TestWebhookEventTypes tests the event types we handle
|
||||
func TestWebhookEventTypes(t *testing.T) {
|
||||
eventTypes := []struct {
|
||||
eventType string
|
||||
shouldHandle bool
|
||||
}{
|
||||
{"checkout.session.completed", true},
|
||||
{"customer.subscription.created", true},
|
||||
{"customer.subscription.updated", true},
|
||||
{"customer.subscription.deleted", true},
|
||||
{"invoice.paid", true},
|
||||
{"invoice.payment_failed", true},
|
||||
{"customer.created", true}, // Handled but just logged
|
||||
{"unknown.event.type", false},
|
||||
}
|
||||
|
||||
for _, tt := range eventTypes {
|
||||
t.Run(tt.eventType, func(t *testing.T) {
|
||||
if tt.eventType == "" {
|
||||
t.Error("Event type should not be empty")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestWebhookRequest_MissingSignature tests handling of missing signature
|
||||
func TestWebhookRequest_MissingSignature(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
// Create request without Stripe-Signature header
|
||||
body := []byte(`{"id": "evt_test_123", "type": "test.event"}`)
|
||||
c.Request = httptest.NewRequest("POST", "/webhook", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
// Note: No Stripe-Signature header
|
||||
|
||||
// Simulate the check we do in the handler
|
||||
sigHeader := c.GetHeader("Stripe-Signature")
|
||||
if sigHeader == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "missing signature"})
|
||||
}
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("Expected status 400 for missing signature, got %d", w.Code)
|
||||
}
|
||||
|
||||
var response map[string]interface{}
|
||||
err := json.Unmarshal(w.Body.Bytes(), &response)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse response: %v", err)
|
||||
}
|
||||
|
||||
if response["error"] != "missing signature" {
|
||||
t.Errorf("Expected 'missing signature' error, got '%v'", response["error"])
|
||||
}
|
||||
}
|
||||
|
||||
// TestWebhookRequest_EmptyBody tests handling of empty request body
|
||||
func TestWebhookRequest_EmptyBody(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
// Create request with empty body
|
||||
c.Request = httptest.NewRequest("POST", "/webhook", bytes.NewReader([]byte{}))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
c.Request.Header.Set("Stripe-Signature", "t=123,v1=signature")
|
||||
|
||||
// Read the body
|
||||
body := make([]byte, 0)
|
||||
|
||||
// Simulate empty body handling
|
||||
if len(body) == 0 {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "empty body"})
|
||||
}
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("Expected status 400 for empty body, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// TestWebhookIdempotency tests idempotency behavior
|
||||
func TestWebhookIdempotency(t *testing.T) {
|
||||
// Test that the same event ID should not be processed twice
|
||||
eventID := "evt_test_123456789"
|
||||
|
||||
// Simulate event tracking
|
||||
processedEvents := make(map[string]bool)
|
||||
|
||||
// First time - should process
|
||||
if !processedEvents[eventID] {
|
||||
processedEvents[eventID] = true
|
||||
}
|
||||
|
||||
// Second time - should skip
|
||||
alreadyProcessed := processedEvents[eventID]
|
||||
if !alreadyProcessed {
|
||||
t.Error("Event should be marked as processed")
|
||||
}
|
||||
}
|
||||
|
||||
// TestWebhookResponse_Processed tests successful webhook response
|
||||
func TestWebhookResponse_Processed(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"status": "processed"})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
var response map[string]interface{}
|
||||
err := json.Unmarshal(w.Body.Bytes(), &response)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse response: %v", err)
|
||||
}
|
||||
|
||||
if response["status"] != "processed" {
|
||||
t.Errorf("Expected status 'processed', got '%v'", response["status"])
|
||||
}
|
||||
}
|
||||
|
||||
// TestWebhookResponse_AlreadyProcessed tests idempotent response
|
||||
func TestWebhookResponse_AlreadyProcessed(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"status": "already_processed"})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
var response map[string]interface{}
|
||||
err := json.Unmarshal(w.Body.Bytes(), &response)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse response: %v", err)
|
||||
}
|
||||
|
||||
if response["status"] != "already_processed" {
|
||||
t.Errorf("Expected status 'already_processed', got '%v'", response["status"])
|
||||
}
|
||||
}
|
||||
|
||||
// TestWebhookResponse_InternalError tests error response
|
||||
func TestWebhookResponse_InternalError(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "handler error"})
|
||||
|
||||
if w.Code != http.StatusInternalServerError {
|
||||
t.Errorf("Expected status 500, got %d", w.Code)
|
||||
}
|
||||
|
||||
var response map[string]interface{}
|
||||
err := json.Unmarshal(w.Body.Bytes(), &response)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse response: %v", err)
|
||||
}
|
||||
|
||||
if response["error"] != "handler error" {
|
||||
t.Errorf("Expected 'handler error', got '%v'", response["error"])
|
||||
}
|
||||
}
|
||||
|
||||
// TestWebhookResponse_InvalidSignature tests signature verification failure
|
||||
func TestWebhookResponse_InvalidSignature(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "invalid signature"})
|
||||
|
||||
if w.Code != http.StatusUnauthorized {
|
||||
t.Errorf("Expected status 401, got %d", w.Code)
|
||||
}
|
||||
|
||||
var response map[string]interface{}
|
||||
err := json.Unmarshal(w.Body.Bytes(), &response)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse response: %v", err)
|
||||
}
|
||||
|
||||
if response["error"] != "invalid signature" {
|
||||
t.Errorf("Expected 'invalid signature', got '%v'", response["error"])
|
||||
}
|
||||
}
|
||||
|
||||
// TestCheckoutSessionCompleted_EventStructure tests the event data structure
|
||||
func TestCheckoutSessionCompleted_EventStructure(t *testing.T) {
|
||||
// Test the expected structure of a checkout.session.completed event
|
||||
eventData := map[string]interface{}{
|
||||
"id": "cs_test_123",
|
||||
"customer": "cus_test_456",
|
||||
"subscription": "sub_test_789",
|
||||
"mode": "subscription",
|
||||
"payment_status": "paid",
|
||||
"status": "complete",
|
||||
"metadata": map[string]interface{}{
|
||||
"user_id": "550e8400-e29b-41d4-a716-446655440000",
|
||||
"plan_id": "standard",
|
||||
},
|
||||
}
|
||||
|
||||
data, err := json.Marshal(eventData)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal event data: %v", err)
|
||||
}
|
||||
|
||||
var decoded map[string]interface{}
|
||||
err = json.Unmarshal(data, &decoded)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to unmarshal event data: %v", err)
|
||||
}
|
||||
|
||||
// Verify required fields
|
||||
if decoded["customer"] == nil {
|
||||
t.Error("Event should have 'customer' field")
|
||||
}
|
||||
if decoded["subscription"] == nil {
|
||||
t.Error("Event should have 'subscription' field")
|
||||
}
|
||||
metadata, ok := decoded["metadata"].(map[string]interface{})
|
||||
if !ok || metadata["user_id"] == nil {
|
||||
t.Error("Event should have 'metadata.user_id' field")
|
||||
}
|
||||
}
|
||||
|
||||
// TestSubscriptionCreated_EventStructure tests subscription.created event structure
|
||||
func TestSubscriptionCreated_EventStructure(t *testing.T) {
|
||||
eventData := map[string]interface{}{
|
||||
"id": "sub_test_123",
|
||||
"customer": "cus_test_456",
|
||||
"status": "trialing",
|
||||
"items": map[string]interface{}{
|
||||
"data": []map[string]interface{}{
|
||||
{
|
||||
"price": map[string]interface{}{
|
||||
"id": "price_test_789",
|
||||
"metadata": map[string]interface{}{"plan_id": "standard"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
"trial_end": 1735689600,
|
||||
"current_period_end": 1735689600,
|
||||
"metadata": map[string]interface{}{
|
||||
"user_id": "550e8400-e29b-41d4-a716-446655440000",
|
||||
"plan_id": "standard",
|
||||
},
|
||||
}
|
||||
|
||||
data, err := json.Marshal(eventData)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal event data: %v", err)
|
||||
}
|
||||
|
||||
var decoded map[string]interface{}
|
||||
err = json.Unmarshal(data, &decoded)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to unmarshal event data: %v", err)
|
||||
}
|
||||
|
||||
// Verify required fields
|
||||
if decoded["status"] != "trialing" {
|
||||
t.Errorf("Expected status 'trialing', got '%v'", decoded["status"])
|
||||
}
|
||||
}
|
||||
|
||||
// TestSubscriptionUpdated_StatusTransitions tests subscription status transitions
|
||||
func TestSubscriptionUpdated_StatusTransitions(t *testing.T) {
|
||||
validTransitions := []struct {
|
||||
from string
|
||||
to string
|
||||
}{
|
||||
{"trialing", "active"},
|
||||
{"active", "past_due"},
|
||||
{"past_due", "active"},
|
||||
{"active", "canceled"},
|
||||
{"trialing", "canceled"},
|
||||
}
|
||||
|
||||
for _, tt := range validTransitions {
|
||||
t.Run(tt.from+"->"+tt.to, func(t *testing.T) {
|
||||
if tt.from == "" || tt.to == "" {
|
||||
t.Error("Status should not be empty")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestInvoicePaid_EventStructure tests invoice.paid event structure
|
||||
func TestInvoicePaid_EventStructure(t *testing.T) {
|
||||
eventData := map[string]interface{}{
|
||||
"id": "in_test_123",
|
||||
"subscription": "sub_test_456",
|
||||
"customer": "cus_test_789",
|
||||
"status": "paid",
|
||||
"amount_paid": 1990,
|
||||
"currency": "eur",
|
||||
"period_start": 1735689600,
|
||||
"period_end": 1738368000,
|
||||
"hosted_invoice_url": "https://invoice.stripe.com/test",
|
||||
"invoice_pdf": "https://invoice.stripe.com/test.pdf",
|
||||
}
|
||||
|
||||
data, err := json.Marshal(eventData)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal event data: %v", err)
|
||||
}
|
||||
|
||||
var decoded map[string]interface{}
|
||||
err = json.Unmarshal(data, &decoded)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to unmarshal event data: %v", err)
|
||||
}
|
||||
|
||||
// Verify required fields
|
||||
if decoded["status"] != "paid" {
|
||||
t.Errorf("Expected status 'paid', got '%v'", decoded["status"])
|
||||
}
|
||||
if decoded["subscription"] == nil {
|
||||
t.Error("Event should have 'subscription' field")
|
||||
}
|
||||
}
|
||||
|
||||
// TestInvoicePaymentFailed_EventStructure tests invoice.payment_failed event structure
|
||||
func TestInvoicePaymentFailed_EventStructure(t *testing.T) {
|
||||
eventData := map[string]interface{}{
|
||||
"id": "in_test_123",
|
||||
"subscription": "sub_test_456",
|
||||
"customer": "cus_test_789",
|
||||
"status": "open",
|
||||
"attempt_count": 1,
|
||||
"next_payment_attempt": 1735776000,
|
||||
}
|
||||
|
||||
data, err := json.Marshal(eventData)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal event data: %v", err)
|
||||
}
|
||||
|
||||
var decoded map[string]interface{}
|
||||
err = json.Unmarshal(data, &decoded)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to unmarshal event data: %v", err)
|
||||
}
|
||||
|
||||
// Verify fields
|
||||
if decoded["attempt_count"] == nil {
|
||||
t.Error("Event should have 'attempt_count' field")
|
||||
}
|
||||
}
|
||||
|
||||
// TestSubscriptionDeleted_EventStructure tests subscription.deleted event structure
|
||||
func TestSubscriptionDeleted_EventStructure(t *testing.T) {
|
||||
eventData := map[string]interface{}{
|
||||
"id": "sub_test_123",
|
||||
"customer": "cus_test_456",
|
||||
"status": "canceled",
|
||||
"ended_at": 1735689600,
|
||||
"canceled_at": 1735689600,
|
||||
}
|
||||
|
||||
data, err := json.Marshal(eventData)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal event data: %v", err)
|
||||
}
|
||||
|
||||
var decoded map[string]interface{}
|
||||
err = json.Unmarshal(data, &decoded)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to unmarshal event data: %v", err)
|
||||
}
|
||||
|
||||
// Verify required fields
|
||||
if decoded["status"] != "canceled" {
|
||||
t.Errorf("Expected status 'canceled', got '%v'", decoded["status"])
|
||||
}
|
||||
}
|
||||
|
||||
// TestStripeSignatureFormat tests the Stripe signature header format
|
||||
func TestStripeSignatureFormat(t *testing.T) {
|
||||
// Stripe signature format: t=timestamp,v1=signature
|
||||
validSignatures := []string{
|
||||
"t=1609459200,v1=abc123def456",
|
||||
"t=1609459200,v1=signature_here,v0=old_signature",
|
||||
}
|
||||
|
||||
for _, sig := range validSignatures {
|
||||
if len(sig) < 10 {
|
||||
t.Errorf("Signature seems too short: %s", sig)
|
||||
}
|
||||
// Should start with timestamp
|
||||
if sig[:2] != "t=" {
|
||||
t.Errorf("Signature should start with 't=': %s", sig)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestWebhookEventID_Format tests Stripe event ID format
|
||||
func TestWebhookEventID_Format(t *testing.T) {
|
||||
validEventIDs := []string{
|
||||
"evt_1234567890abcdef",
|
||||
"evt_test_123456789",
|
||||
"evt_live_987654321",
|
||||
}
|
||||
|
||||
for _, eventID := range validEventIDs {
|
||||
// Event IDs should start with "evt_"
|
||||
if len(eventID) < 10 || eventID[:4] != "evt_" {
|
||||
t.Errorf("Invalid event ID format: %s", eventID)
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user