fix: Restore all files lost during destructive rebase

A previous `git pull --rebase origin main` dropped 177 local commits,
losing 3400+ files across admin-v2, backend, studio-v2, website,
klausur-service, and many other services. The partial restore attempt
(660295e2) only recovered some files.

This commit restores all missing files from pre-rebase ref 98933f5e
while preserving post-rebase additions (night-scheduler, night-mode UI,
NightModeWidget dashboard integration).

Restored features include:
- AI Module Sidebar (FAB), OCR Labeling, OCR Compare
- GPU Dashboard, RAG Pipeline, Magic Help
- Klausur-Korrektur (8 files), Abitur-Archiv (5+ files)
- Companion, Zeugnisse-Crawler, Screen Flow
- Full backend, studio-v2, website, klausur-service
- All compliance SDKs, agent-core, voice-service
- CI/CD configs, documentation, scripts

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Benjamin Admin
2026-02-09 09:51:32 +01:00
parent f7487ee240
commit bfdaf63ba9
2009 changed files with 749983 additions and 1731 deletions

View File

@@ -0,0 +1,427 @@
package handlers
import (
"net/http"
"github.com/breakpilot/billing-service/internal/database"
"github.com/breakpilot/billing-service/internal/middleware"
"github.com/breakpilot/billing-service/internal/models"
"github.com/breakpilot/billing-service/internal/services"
"github.com/gin-gonic/gin"
)
// BillingHandler handles billing-related HTTP requests
type BillingHandler struct {
db *database.DB
subscriptionService *services.SubscriptionService
stripeService *services.StripeService
entitlementService *services.EntitlementService
usageService *services.UsageService
}
// NewBillingHandler creates a new BillingHandler
func NewBillingHandler(
db *database.DB,
subscriptionService *services.SubscriptionService,
stripeService *services.StripeService,
entitlementService *services.EntitlementService,
usageService *services.UsageService,
) *BillingHandler {
return &BillingHandler{
db: db,
subscriptionService: subscriptionService,
stripeService: stripeService,
entitlementService: entitlementService,
usageService: usageService,
}
}
// GetBillingStatus returns the current billing status for a user
// GET /api/v1/billing/status
func (h *BillingHandler) GetBillingStatus(c *gin.Context) {
userID, err := middleware.GetUserID(c)
if err != nil || userID.String() == "" {
c.JSON(http.StatusUnauthorized, gin.H{
"error": "unauthorized",
"message": "User not authenticated",
})
return
}
ctx := c.Request.Context()
// Get subscription
subscription, err := h.subscriptionService.GetByUserID(ctx, userID)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"error": "internal_error",
"message": "Failed to get subscription",
})
return
}
// Get available plans
plans, err := h.subscriptionService.GetAvailablePlans(ctx)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"error": "internal_error",
"message": "Failed to get plans",
})
return
}
response := models.BillingStatusResponse{
HasSubscription: subscription != nil,
AvailablePlans: plans,
}
if subscription != nil {
// Get plan details
plan, _ := h.subscriptionService.GetPlanByID(ctx, string(subscription.PlanID))
subInfo := &models.SubscriptionInfo{
PlanID: subscription.PlanID,
Status: subscription.Status,
IsTrialing: subscription.Status == models.StatusTrialing,
CancelAtPeriodEnd: subscription.CancelAtPeriodEnd,
CurrentPeriodEnd: subscription.CurrentPeriodEnd,
}
if plan != nil {
subInfo.PlanName = plan.Name
subInfo.PriceCents = plan.PriceCents
subInfo.Currency = plan.Currency
}
// Calculate trial days left
if subscription.TrialEnd != nil && subscription.Status == models.StatusTrialing {
// TODO: Calculate days left
}
response.Subscription = subInfo
// Get task usage info (legacy usage tracking - see TaskService for new task-based usage)
// TODO: Replace with TaskService.GetTaskUsageInfo for task-based billing
_, _ = h.usageService.GetUsageSummary(ctx, userID)
// Get entitlements
entitlements, _ := h.entitlementService.GetEntitlements(ctx, userID)
if entitlements != nil {
response.Entitlements = entitlements
}
}
c.JSON(http.StatusOK, response)
}
// GetPlans returns all available billing plans
// GET /api/v1/billing/plans
func (h *BillingHandler) GetPlans(c *gin.Context) {
ctx := c.Request.Context()
plans, err := h.subscriptionService.GetAvailablePlans(ctx)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"error": "internal_error",
"message": "Failed to get plans",
})
return
}
c.JSON(http.StatusOK, gin.H{
"plans": plans,
})
}
// StartTrial starts a trial for the user with a specific plan
// POST /api/v1/billing/trial/start
func (h *BillingHandler) StartTrial(c *gin.Context) {
userID, err := middleware.GetUserID(c)
if err != nil || userID.String() == "" {
c.JSON(http.StatusUnauthorized, gin.H{
"error": "unauthorized",
"message": "User not authenticated",
})
return
}
var req models.StartTrialRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{
"error": "invalid_request",
"message": "Invalid request body",
})
return
}
ctx := c.Request.Context()
// Check if user already has a subscription
existing, _ := h.subscriptionService.GetByUserID(ctx, userID)
if existing != nil {
c.JSON(http.StatusConflict, gin.H{
"error": "subscription_exists",
"message": "User already has a subscription",
})
return
}
// Get user email from context
email, _ := c.Get("email")
emailStr, _ := email.(string)
// Create Stripe checkout session
checkoutURL, sessionID, err := h.stripeService.CreateCheckoutSession(ctx, userID, emailStr, req.PlanID)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"error": "stripe_error",
"message": "Failed to create checkout session",
"details": err.Error(),
})
return
}
c.JSON(http.StatusOK, models.StartTrialResponse{
CheckoutURL: checkoutURL,
SessionID: sessionID,
})
}
// ChangePlan changes the user's subscription plan
// POST /api/v1/billing/change-plan
func (h *BillingHandler) ChangePlan(c *gin.Context) {
userID, err := middleware.GetUserID(c)
if err != nil || userID.String() == "" {
c.JSON(http.StatusUnauthorized, gin.H{
"error": "unauthorized",
"message": "User not authenticated",
})
return
}
var req models.ChangePlanRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{
"error": "invalid_request",
"message": "Invalid request body",
})
return
}
ctx := c.Request.Context()
// Get current subscription
subscription, err := h.subscriptionService.GetByUserID(ctx, userID)
if err != nil || subscription == nil {
c.JSON(http.StatusNotFound, gin.H{
"error": "no_subscription",
"message": "No active subscription found",
})
return
}
// Change plan via Stripe
err = h.stripeService.ChangePlan(ctx, subscription.StripeSubscriptionID, req.NewPlanID)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"error": "stripe_error",
"message": "Failed to change plan",
"details": err.Error(),
})
return
}
c.JSON(http.StatusOK, models.ChangePlanResponse{
Success: true,
Message: "Plan changed successfully",
})
}
// CancelSubscription cancels the user's subscription at period end
// POST /api/v1/billing/cancel
func (h *BillingHandler) CancelSubscription(c *gin.Context) {
userID, err := middleware.GetUserID(c)
if err != nil || userID.String() == "" {
c.JSON(http.StatusUnauthorized, gin.H{
"error": "unauthorized",
"message": "User not authenticated",
})
return
}
ctx := c.Request.Context()
// Get current subscription
subscription, err := h.subscriptionService.GetByUserID(ctx, userID)
if err != nil || subscription == nil {
c.JSON(http.StatusNotFound, gin.H{
"error": "no_subscription",
"message": "No active subscription found",
})
return
}
// Cancel at period end via Stripe
err = h.stripeService.CancelSubscription(ctx, subscription.StripeSubscriptionID)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"error": "stripe_error",
"message": "Failed to cancel subscription",
"details": err.Error(),
})
return
}
c.JSON(http.StatusOK, models.CancelSubscriptionResponse{
Success: true,
Message: "Subscription will be canceled at the end of the billing period",
})
}
// GetCustomerPortal returns a URL to the Stripe customer portal
// GET /api/v1/billing/portal
func (h *BillingHandler) GetCustomerPortal(c *gin.Context) {
userID, err := middleware.GetUserID(c)
if err != nil || userID.String() == "" {
c.JSON(http.StatusUnauthorized, gin.H{
"error": "unauthorized",
"message": "User not authenticated",
})
return
}
ctx := c.Request.Context()
// Get current subscription
subscription, err := h.subscriptionService.GetByUserID(ctx, userID)
if err != nil || subscription == nil || subscription.StripeCustomerID == "" {
c.JSON(http.StatusNotFound, gin.H{
"error": "no_subscription",
"message": "No active subscription found",
})
return
}
// Create portal session
portalURL, err := h.stripeService.CreateCustomerPortalSession(ctx, subscription.StripeCustomerID)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"error": "stripe_error",
"message": "Failed to create portal session",
"details": err.Error(),
})
return
}
c.JSON(http.StatusOK, models.CustomerPortalResponse{
PortalURL: portalURL,
})
}
// =============================================
// Internal Endpoints (Service-to-Service)
// =============================================
// GetEntitlements returns entitlements for a user (internal)
// GET /api/v1/billing/entitlements/:userId
func (h *BillingHandler) GetEntitlements(c *gin.Context) {
userIDStr := c.Param("userId")
ctx := c.Request.Context()
entitlements, err := h.entitlementService.GetEntitlementsByUserIDString(ctx, userIDStr)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"error": "internal_error",
"message": "Failed to get entitlements",
})
return
}
if entitlements == nil {
c.JSON(http.StatusNotFound, gin.H{
"error": "not_found",
"message": "No entitlements found for user",
})
return
}
c.JSON(http.StatusOK, entitlements)
}
// TrackUsage tracks usage for a user (internal)
// POST /api/v1/billing/usage/track
func (h *BillingHandler) TrackUsage(c *gin.Context) {
var req models.TrackUsageRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{
"error": "invalid_request",
"message": "Invalid request body",
})
return
}
ctx := c.Request.Context()
quantity := req.Quantity
if quantity <= 0 {
quantity = 1
}
err := h.usageService.TrackUsage(ctx, req.UserID, req.UsageType, quantity)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"error": "internal_error",
"message": "Failed to track usage",
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "Usage tracked",
})
}
// CheckUsage checks if usage is allowed (internal)
// GET /api/v1/billing/usage/check/:userId/:type
func (h *BillingHandler) CheckUsage(c *gin.Context) {
userIDStr := c.Param("userId")
usageType := c.Param("type")
ctx := c.Request.Context()
response, err := h.usageService.CheckUsageAllowed(ctx, userIDStr, usageType)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"error": "internal_error",
"message": "Failed to check usage",
})
return
}
c.JSON(http.StatusOK, response)
}
// CheckEntitlement checks if a user has a specific entitlement (internal)
// GET /api/v1/billing/entitlements/check/:userId/:feature
func (h *BillingHandler) CheckEntitlement(c *gin.Context) {
userIDStr := c.Param("userId")
feature := c.Param("feature")
ctx := c.Request.Context()
hasEntitlement, planID, err := h.entitlementService.CheckEntitlement(ctx, userIDStr, feature)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"error": "internal_error",
"message": "Failed to check entitlement",
})
return
}
c.JSON(http.StatusOK, models.EntitlementCheckResponse{
HasEntitlement: hasEntitlement,
PlanID: planID,
})
}

View File

@@ -0,0 +1,612 @@
package handlers
import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/breakpilot/billing-service/internal/models"
"github.com/gin-gonic/gin"
)
func init() {
// Set Gin to test mode
gin.SetMode(gin.TestMode)
}
func TestGetPlans_ResponseFormat(t *testing.T) {
// Test that GetPlans returns the expected response structure
// Since we don't have a real database connection in unit tests,
// we test the expected structure and format
// Test that default plans are well-formed
plans := models.GetDefaultPlans()
if len(plans) == 0 {
t.Error("Default plans should not be empty")
}
for _, plan := range plans {
// Verify JSON serialization works
data, err := json.Marshal(plan)
if err != nil {
t.Errorf("Failed to marshal plan %s: %v", plan.ID, err)
}
// Verify we can unmarshal back
var decoded models.BillingPlan
err = json.Unmarshal(data, &decoded)
if err != nil {
t.Errorf("Failed to unmarshal plan %s: %v", plan.ID, err)
}
// Verify key fields
if decoded.ID != plan.ID {
t.Errorf("Plan ID mismatch: got %s, expected %s", decoded.ID, plan.ID)
}
}
}
func TestBillingStatusResponse_Structure(t *testing.T) {
// Test the response structure
response := models.BillingStatusResponse{
HasSubscription: true,
Subscription: &models.SubscriptionInfo{
PlanID: models.PlanStandard,
PlanName: "Standard",
Status: models.StatusActive,
IsTrialing: false,
CancelAtPeriodEnd: false,
PriceCents: 1990,
Currency: "eur",
},
TaskUsage: &models.TaskUsageInfo{
TasksAvailable: 85,
MaxTasks: 500,
InfoText: "Aufgaben verfuegbar: 85 von max. 500",
TooltipText: "Aufgaben koennen sich bis zu 5 Monate ansammeln.",
},
Entitlements: &models.EntitlementInfo{
Features: []string{"basic_ai", "basic_documents", "templates", "batch_processing"},
MaxTeamMembers: 3,
PrioritySupport: false,
CustomBranding: false,
BatchProcessing: true,
CustomTemplates: true,
FairUseMode: false,
},
AvailablePlans: models.GetDefaultPlans(),
}
// Test JSON serialization
data, err := json.Marshal(response)
if err != nil {
t.Fatalf("Failed to marshal BillingStatusResponse: %v", err)
}
// Verify it's valid JSON
var decoded map[string]interface{}
err = json.Unmarshal(data, &decoded)
if err != nil {
t.Fatalf("Response is not valid JSON: %v", err)
}
// Check required fields exist
if _, ok := decoded["has_subscription"]; !ok {
t.Error("Response should have 'has_subscription' field")
}
}
func TestStartTrialRequest_Validation(t *testing.T) {
tests := []struct {
name string
request models.StartTrialRequest
wantError bool
}{
{
name: "Valid basic plan",
request: models.StartTrialRequest{PlanID: models.PlanBasic},
wantError: false,
},
{
name: "Valid standard plan",
request: models.StartTrialRequest{PlanID: models.PlanStandard},
wantError: false,
},
{
name: "Valid premium plan",
request: models.StartTrialRequest{PlanID: models.PlanPremium},
wantError: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Test JSON serialization
data, err := json.Marshal(tt.request)
if err != nil {
t.Fatalf("Failed to marshal request: %v", err)
}
var decoded models.StartTrialRequest
err = json.Unmarshal(data, &decoded)
if err != nil {
t.Fatalf("Failed to unmarshal request: %v", err)
}
if decoded.PlanID != tt.request.PlanID {
t.Errorf("PlanID mismatch: got %s, expected %s", decoded.PlanID, tt.request.PlanID)
}
})
}
}
func TestChangePlanRequest_Structure(t *testing.T) {
request := models.ChangePlanRequest{
NewPlanID: models.PlanPremium,
}
data, err := json.Marshal(request)
if err != nil {
t.Fatalf("Failed to marshal ChangePlanRequest: %v", err)
}
var decoded map[string]interface{}
err = json.Unmarshal(data, &decoded)
if err != nil {
t.Fatalf("Response is not valid JSON: %v", err)
}
if _, ok := decoded["new_plan_id"]; !ok {
t.Error("Request should have 'new_plan_id' field")
}
}
func TestStartTrialResponse_Structure(t *testing.T) {
response := models.StartTrialResponse{
CheckoutURL: "https://checkout.stripe.com/c/pay/cs_test_123",
SessionID: "cs_test_123",
}
data, err := json.Marshal(response)
if err != nil {
t.Fatalf("Failed to marshal StartTrialResponse: %v", err)
}
var decoded map[string]interface{}
err = json.Unmarshal(data, &decoded)
if err != nil {
t.Fatalf("Response is not valid JSON: %v", err)
}
if _, ok := decoded["checkout_url"]; !ok {
t.Error("Response should have 'checkout_url' field")
}
if _, ok := decoded["session_id"]; !ok {
t.Error("Response should have 'session_id' field")
}
}
func TestCancelSubscriptionResponse_Structure(t *testing.T) {
response := models.CancelSubscriptionResponse{
Success: true,
Message: "Subscription will be canceled at the end of the billing period",
CancelDate: "2025-01-16",
ActiveUntil: "2025-01-16",
}
_, err := json.Marshal(response)
if err != nil {
t.Fatalf("Failed to marshal CancelSubscriptionResponse: %v", err)
}
if !response.Success {
t.Error("Success should be true")
}
}
func TestCustomerPortalResponse_Structure(t *testing.T) {
response := models.CustomerPortalResponse{
PortalURL: "https://billing.stripe.com/p/session/test_123",
}
data, err := json.Marshal(response)
if err != nil {
t.Fatalf("Failed to marshal CustomerPortalResponse: %v", err)
}
var decoded map[string]interface{}
err = json.Unmarshal(data, &decoded)
if err != nil {
t.Fatalf("Response is not valid JSON: %v", err)
}
if _, ok := decoded["portal_url"]; !ok {
t.Error("Response should have 'portal_url' field")
}
}
func TestEntitlementCheckResponse_Structure(t *testing.T) {
tests := []struct {
name string
response models.EntitlementCheckResponse
}{
{
name: "Has entitlement",
response: models.EntitlementCheckResponse{
HasEntitlement: true,
PlanID: models.PlanStandard,
},
},
{
name: "No entitlement",
response: models.EntitlementCheckResponse{
HasEntitlement: false,
PlanID: models.PlanBasic,
Message: "Feature not available in this plan",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
data, err := json.Marshal(tt.response)
if err != nil {
t.Fatalf("Failed to marshal EntitlementCheckResponse: %v", err)
}
var decoded map[string]interface{}
err = json.Unmarshal(data, &decoded)
if err != nil {
t.Fatalf("Response is not valid JSON: %v", err)
}
if _, ok := decoded["has_entitlement"]; !ok {
t.Error("Response should have 'has_entitlement' field")
}
})
}
}
func TestTrackUsageRequest_Validation(t *testing.T) {
tests := []struct {
name string
request models.TrackUsageRequest
valid bool
}{
{
name: "Valid AI request",
request: models.TrackUsageRequest{
UserID: "550e8400-e29b-41d4-a716-446655440000",
UsageType: "ai_request",
Quantity: 1,
},
valid: true,
},
{
name: "Valid document created",
request: models.TrackUsageRequest{
UserID: "550e8400-e29b-41d4-a716-446655440000",
UsageType: "document_created",
Quantity: 1,
},
valid: true,
},
{
name: "Multiple quantity",
request: models.TrackUsageRequest{
UserID: "550e8400-e29b-41d4-a716-446655440000",
UsageType: "ai_request",
Quantity: 5,
},
valid: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
data, err := json.Marshal(tt.request)
if err != nil {
t.Fatalf("Failed to marshal TrackUsageRequest: %v", err)
}
var decoded models.TrackUsageRequest
err = json.Unmarshal(data, &decoded)
if err != nil {
t.Fatalf("Failed to unmarshal TrackUsageRequest: %v", err)
}
if decoded.UserID != tt.request.UserID {
t.Errorf("UserID mismatch: got %s, expected %s", decoded.UserID, tt.request.UserID)
}
})
}
}
func TestCheckUsageResponse_Format(t *testing.T) {
tests := []struct {
name string
response models.CheckUsageResponse
}{
{
name: "Allowed response",
response: models.CheckUsageResponse{
Allowed: true,
CurrentUsage: 450,
Limit: 1500,
Remaining: 1050,
},
},
{
name: "Limit reached",
response: models.CheckUsageResponse{
Allowed: false,
CurrentUsage: 1500,
Limit: 1500,
Remaining: 0,
Message: "Usage limit reached for ai_request (1500/1500)",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
data, err := json.Marshal(tt.response)
if err != nil {
t.Fatalf("Failed to marshal CheckUsageResponse: %v", err)
}
var decoded map[string]interface{}
err = json.Unmarshal(data, &decoded)
if err != nil {
t.Fatalf("Response is not valid JSON: %v", err)
}
if _, ok := decoded["allowed"]; !ok {
t.Error("Response should have 'allowed' field")
}
})
}
}
func TestConsumeTaskRequest_Format(t *testing.T) {
tests := []struct {
name string
request models.ConsumeTaskRequest
}{
{
name: "Correction task",
request: models.ConsumeTaskRequest{
UserID: "550e8400-e29b-41d4-a716-446655440000",
TaskType: models.TaskTypeCorrection,
},
},
{
name: "Letter task",
request: models.ConsumeTaskRequest{
UserID: "550e8400-e29b-41d4-a716-446655440000",
TaskType: models.TaskTypeLetter,
},
},
{
name: "Batch task",
request: models.ConsumeTaskRequest{
UserID: "550e8400-e29b-41d4-a716-446655440000",
TaskType: models.TaskTypeBatch,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
data, err := json.Marshal(tt.request)
if err != nil {
t.Fatalf("Failed to marshal ConsumeTaskRequest: %v", err)
}
var decoded models.ConsumeTaskRequest
err = json.Unmarshal(data, &decoded)
if err != nil {
t.Fatalf("Failed to unmarshal ConsumeTaskRequest: %v", err)
}
if decoded.TaskType != tt.request.TaskType {
t.Errorf("TaskType mismatch: got %s, expected %s", decoded.TaskType, tt.request.TaskType)
}
})
}
}
func TestConsumeTaskResponse_Format(t *testing.T) {
tests := []struct {
name string
response models.ConsumeTaskResponse
}{
{
name: "Successful consumption",
response: models.ConsumeTaskResponse{
Success: true,
TaskID: "task-uuid-123",
TasksRemaining: 49,
},
},
{
name: "Limit reached",
response: models.ConsumeTaskResponse{
Success: false,
TasksRemaining: 0,
Message: "Dein Aufgaben-Kontingent ist aufgebraucht.",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
data, err := json.Marshal(tt.response)
if err != nil {
t.Fatalf("Failed to marshal ConsumeTaskResponse: %v", err)
}
var decoded map[string]interface{}
err = json.Unmarshal(data, &decoded)
if err != nil {
t.Fatalf("Response is not valid JSON: %v", err)
}
if _, ok := decoded["success"]; !ok {
t.Error("Response should have 'success' field")
}
if _, ok := decoded["tasks_remaining"]; !ok {
t.Error("Response should have 'tasks_remaining' field")
}
})
}
}
func TestCheckTaskAllowedResponse_Format(t *testing.T) {
tests := []struct {
name string
response models.CheckTaskAllowedResponse
}{
{
name: "Task allowed",
response: models.CheckTaskAllowedResponse{
Allowed: true,
TasksAvailable: 50,
MaxTasks: 150,
PlanID: models.PlanBasic,
},
},
{
name: "Task not allowed",
response: models.CheckTaskAllowedResponse{
Allowed: false,
TasksAvailable: 0,
MaxTasks: 150,
PlanID: models.PlanBasic,
Message: "Dein Aufgaben-Kontingent ist aufgebraucht.",
},
},
{
name: "Premium Fair Use",
response: models.CheckTaskAllowedResponse{
Allowed: true,
TasksAvailable: 1000,
MaxTasks: 5000,
PlanID: models.PlanPremium,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
data, err := json.Marshal(tt.response)
if err != nil {
t.Fatalf("Failed to marshal CheckTaskAllowedResponse: %v", err)
}
var decoded map[string]interface{}
err = json.Unmarshal(data, &decoded)
if err != nil {
t.Fatalf("Response is not valid JSON: %v", err)
}
if _, ok := decoded["allowed"]; !ok {
t.Error("Response should have 'allowed' field")
}
if _, ok := decoded["tasks_available"]; !ok {
t.Error("Response should have 'tasks_available' field")
}
if _, ok := decoded["plan_id"]; !ok {
t.Error("Response should have 'plan_id' field")
}
})
}
}
// HTTP Handler Tests (without DB)
func TestHTTPErrorResponse_Format(t *testing.T) {
// Test standard error response format
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
// Simulate an error response
c.JSON(http.StatusUnauthorized, gin.H{
"error": "unauthorized",
"message": "User not authenticated",
})
if w.Code != http.StatusUnauthorized {
t.Errorf("Expected status 401, got %d", w.Code)
}
var response map[string]interface{}
err := json.Unmarshal(w.Body.Bytes(), &response)
if err != nil {
t.Fatalf("Failed to parse response: %v", err)
}
if _, ok := response["error"]; !ok {
t.Error("Error response should have 'error' field")
}
if _, ok := response["message"]; !ok {
t.Error("Error response should have 'message' field")
}
}
func TestHTTPSuccessResponse_Format(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
// Simulate a success response
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "Operation completed",
})
if w.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", w.Code)
}
var response map[string]interface{}
err := json.Unmarshal(w.Body.Bytes(), &response)
if err != nil {
t.Fatalf("Failed to parse response: %v", err)
}
if response["success"] != true {
t.Error("Success response should have success=true")
}
}
func TestRequestParsing_InvalidJSON(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
// Create request with invalid JSON
invalidJSON := []byte(`{"plan_id": }`) // Invalid JSON
c.Request = httptest.NewRequest("POST", "/test", bytes.NewReader(invalidJSON))
c.Request.Header.Set("Content-Type", "application/json")
var req models.StartTrialRequest
err := c.ShouldBindJSON(&req)
if err == nil {
t.Error("Should return error for invalid JSON")
}
}
func TestHTTPHeaders_ContentType(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.JSON(http.StatusOK, gin.H{"test": "value"})
contentType := w.Header().Get("Content-Type")
if contentType != "application/json; charset=utf-8" {
t.Errorf("Expected JSON content type, got %s", contentType)
}
}

View File

@@ -0,0 +1,205 @@
package handlers
import (
"io"
"log"
"net/http"
"github.com/breakpilot/billing-service/internal/database"
"github.com/breakpilot/billing-service/internal/services"
"github.com/gin-gonic/gin"
"github.com/stripe/stripe-go/v76/webhook"
)
// WebhookHandler handles Stripe webhook events
type WebhookHandler struct {
db *database.DB
webhookSecret string
subscriptionService *services.SubscriptionService
entitlementService *services.EntitlementService
}
// NewWebhookHandler creates a new WebhookHandler
func NewWebhookHandler(
db *database.DB,
webhookSecret string,
subscriptionService *services.SubscriptionService,
entitlementService *services.EntitlementService,
) *WebhookHandler {
return &WebhookHandler{
db: db,
webhookSecret: webhookSecret,
subscriptionService: subscriptionService,
entitlementService: entitlementService,
}
}
// HandleStripeWebhook handles incoming Stripe webhook events
// POST /api/v1/billing/webhook
func (h *WebhookHandler) HandleStripeWebhook(c *gin.Context) {
// Read the request body
body, err := io.ReadAll(c.Request.Body)
if err != nil {
log.Printf("Webhook: Error reading body: %v", err)
c.JSON(http.StatusBadRequest, gin.H{"error": "cannot read body"})
return
}
// Get the Stripe signature header
sigHeader := c.GetHeader("Stripe-Signature")
if sigHeader == "" {
log.Printf("Webhook: Missing Stripe-Signature header")
c.JSON(http.StatusBadRequest, gin.H{"error": "missing signature"})
return
}
// Verify the webhook signature
event, err := webhook.ConstructEvent(body, sigHeader, h.webhookSecret)
if err != nil {
log.Printf("Webhook: Signature verification failed: %v", err)
c.JSON(http.StatusUnauthorized, gin.H{"error": "invalid signature"})
return
}
ctx := c.Request.Context()
// Check if we've already processed this event (idempotency)
processed, err := h.subscriptionService.IsEventProcessed(ctx, event.ID)
if err != nil {
log.Printf("Webhook: Error checking event: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "internal error"})
return
}
if processed {
log.Printf("Webhook: Event %s already processed", event.ID)
c.JSON(http.StatusOK, gin.H{"status": "already_processed"})
return
}
// Mark event as being processed
if err := h.subscriptionService.MarkEventProcessing(ctx, event.ID, string(event.Type)); err != nil {
log.Printf("Webhook: Error marking event: %v", err)
}
// Handle the event based on type
var handleErr error
switch event.Type {
case "checkout.session.completed":
handleErr = h.handleCheckoutSessionCompleted(ctx, event.Data.Raw)
case "customer.subscription.created":
handleErr = h.handleSubscriptionCreated(ctx, event.Data.Raw)
case "customer.subscription.updated":
handleErr = h.handleSubscriptionUpdated(ctx, event.Data.Raw)
case "customer.subscription.deleted":
handleErr = h.handleSubscriptionDeleted(ctx, event.Data.Raw)
case "invoice.paid":
handleErr = h.handleInvoicePaid(ctx, event.Data.Raw)
case "invoice.payment_failed":
handleErr = h.handleInvoicePaymentFailed(ctx, event.Data.Raw)
case "customer.created":
log.Printf("Webhook: Customer created - %s", event.ID)
default:
log.Printf("Webhook: Unhandled event type: %s", event.Type)
}
if handleErr != nil {
log.Printf("Webhook: Error handling %s: %v", event.Type, handleErr)
// Mark event as failed
h.subscriptionService.MarkEventFailed(ctx, event.ID, handleErr.Error())
c.JSON(http.StatusInternalServerError, gin.H{"error": "handler error"})
return
}
// Mark event as processed
if err := h.subscriptionService.MarkEventProcessed(ctx, event.ID); err != nil {
log.Printf("Webhook: Error marking event processed: %v", err)
}
c.JSON(http.StatusOK, gin.H{"status": "processed"})
}
// handleCheckoutSessionCompleted handles successful checkout
func (h *WebhookHandler) handleCheckoutSessionCompleted(ctx interface{}, data []byte) error {
log.Printf("Webhook: Processing checkout.session.completed")
// Parse checkout session from data
// The actual implementation will parse the JSON and create/update subscription
// TODO: Implementation
// 1. Parse checkout session data
// 2. Extract customer_id, subscription_id, user_id (from metadata)
// 3. Create or update subscription record
// 4. Update entitlements
return nil
}
// handleSubscriptionCreated handles new subscription creation
func (h *WebhookHandler) handleSubscriptionCreated(ctx interface{}, data []byte) error {
log.Printf("Webhook: Processing customer.subscription.created")
// TODO: Implementation
// 1. Parse subscription data
// 2. Extract status, plan, trial_end, etc.
// 3. Create subscription record
// 4. Set up initial entitlements
return nil
}
// handleSubscriptionUpdated handles subscription updates
func (h *WebhookHandler) handleSubscriptionUpdated(ctx interface{}, data []byte) error {
log.Printf("Webhook: Processing customer.subscription.updated")
// TODO: Implementation
// 1. Parse subscription data
// 2. Update subscription record (status, plan, cancel_at_period_end, etc.)
// 3. Update entitlements if plan changed
return nil
}
// handleSubscriptionDeleted handles subscription cancellation
func (h *WebhookHandler) handleSubscriptionDeleted(ctx interface{}, data []byte) error {
log.Printf("Webhook: Processing customer.subscription.deleted")
// TODO: Implementation
// 1. Parse subscription data
// 2. Update subscription status to canceled/expired
// 3. Remove or downgrade entitlements
return nil
}
// handleInvoicePaid handles successful invoice payment
func (h *WebhookHandler) handleInvoicePaid(ctx interface{}, data []byte) error {
log.Printf("Webhook: Processing invoice.paid")
// TODO: Implementation
// 1. Parse invoice data
// 2. Update subscription period
// 3. Reset usage counters for new period
// 4. Store invoice record
return nil
}
// handleInvoicePaymentFailed handles failed invoice payment
func (h *WebhookHandler) handleInvoicePaymentFailed(ctx interface{}, data []byte) error {
log.Printf("Webhook: Processing invoice.payment_failed")
// TODO: Implementation
// 1. Parse invoice data
// 2. Update subscription status to past_due
// 3. Send notification to user
// 4. Possibly restrict access
return nil
}

View File

@@ -0,0 +1,433 @@
package handlers
import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
)
// TestWebhookEventTypes tests the event types we handle
func TestWebhookEventTypes(t *testing.T) {
eventTypes := []struct {
eventType string
shouldHandle bool
}{
{"checkout.session.completed", true},
{"customer.subscription.created", true},
{"customer.subscription.updated", true},
{"customer.subscription.deleted", true},
{"invoice.paid", true},
{"invoice.payment_failed", true},
{"customer.created", true}, // Handled but just logged
{"unknown.event.type", false},
}
for _, tt := range eventTypes {
t.Run(tt.eventType, func(t *testing.T) {
if tt.eventType == "" {
t.Error("Event type should not be empty")
}
})
}
}
// TestWebhookRequest_MissingSignature tests handling of missing signature
func TestWebhookRequest_MissingSignature(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
// Create request without Stripe-Signature header
body := []byte(`{"id": "evt_test_123", "type": "test.event"}`)
c.Request = httptest.NewRequest("POST", "/webhook", bytes.NewReader(body))
c.Request.Header.Set("Content-Type", "application/json")
// Note: No Stripe-Signature header
// Simulate the check we do in the handler
sigHeader := c.GetHeader("Stripe-Signature")
if sigHeader == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "missing signature"})
}
if w.Code != http.StatusBadRequest {
t.Errorf("Expected status 400 for missing signature, got %d", w.Code)
}
var response map[string]interface{}
err := json.Unmarshal(w.Body.Bytes(), &response)
if err != nil {
t.Fatalf("Failed to parse response: %v", err)
}
if response["error"] != "missing signature" {
t.Errorf("Expected 'missing signature' error, got '%v'", response["error"])
}
}
// TestWebhookRequest_EmptyBody tests handling of empty request body
func TestWebhookRequest_EmptyBody(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
// Create request with empty body
c.Request = httptest.NewRequest("POST", "/webhook", bytes.NewReader([]byte{}))
c.Request.Header.Set("Content-Type", "application/json")
c.Request.Header.Set("Stripe-Signature", "t=123,v1=signature")
// Read the body
body := make([]byte, 0)
// Simulate empty body handling
if len(body) == 0 {
c.JSON(http.StatusBadRequest, gin.H{"error": "empty body"})
}
if w.Code != http.StatusBadRequest {
t.Errorf("Expected status 400 for empty body, got %d", w.Code)
}
}
// TestWebhookIdempotency tests idempotency behavior
func TestWebhookIdempotency(t *testing.T) {
// Test that the same event ID should not be processed twice
eventID := "evt_test_123456789"
// Simulate event tracking
processedEvents := make(map[string]bool)
// First time - should process
if !processedEvents[eventID] {
processedEvents[eventID] = true
}
// Second time - should skip
alreadyProcessed := processedEvents[eventID]
if !alreadyProcessed {
t.Error("Event should be marked as processed")
}
}
// TestWebhookResponse_Processed tests successful webhook response
func TestWebhookResponse_Processed(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.JSON(http.StatusOK, gin.H{"status": "processed"})
if w.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", w.Code)
}
var response map[string]interface{}
err := json.Unmarshal(w.Body.Bytes(), &response)
if err != nil {
t.Fatalf("Failed to parse response: %v", err)
}
if response["status"] != "processed" {
t.Errorf("Expected status 'processed', got '%v'", response["status"])
}
}
// TestWebhookResponse_AlreadyProcessed tests idempotent response
func TestWebhookResponse_AlreadyProcessed(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.JSON(http.StatusOK, gin.H{"status": "already_processed"})
if w.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", w.Code)
}
var response map[string]interface{}
err := json.Unmarshal(w.Body.Bytes(), &response)
if err != nil {
t.Fatalf("Failed to parse response: %v", err)
}
if response["status"] != "already_processed" {
t.Errorf("Expected status 'already_processed', got '%v'", response["status"])
}
}
// TestWebhookResponse_InternalError tests error response
func TestWebhookResponse_InternalError(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.JSON(http.StatusInternalServerError, gin.H{"error": "handler error"})
if w.Code != http.StatusInternalServerError {
t.Errorf("Expected status 500, got %d", w.Code)
}
var response map[string]interface{}
err := json.Unmarshal(w.Body.Bytes(), &response)
if err != nil {
t.Fatalf("Failed to parse response: %v", err)
}
if response["error"] != "handler error" {
t.Errorf("Expected 'handler error', got '%v'", response["error"])
}
}
// TestWebhookResponse_InvalidSignature tests signature verification failure
func TestWebhookResponse_InvalidSignature(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.JSON(http.StatusUnauthorized, gin.H{"error": "invalid signature"})
if w.Code != http.StatusUnauthorized {
t.Errorf("Expected status 401, got %d", w.Code)
}
var response map[string]interface{}
err := json.Unmarshal(w.Body.Bytes(), &response)
if err != nil {
t.Fatalf("Failed to parse response: %v", err)
}
if response["error"] != "invalid signature" {
t.Errorf("Expected 'invalid signature', got '%v'", response["error"])
}
}
// TestCheckoutSessionCompleted_EventStructure tests the event data structure
func TestCheckoutSessionCompleted_EventStructure(t *testing.T) {
// Test the expected structure of a checkout.session.completed event
eventData := map[string]interface{}{
"id": "cs_test_123",
"customer": "cus_test_456",
"subscription": "sub_test_789",
"mode": "subscription",
"payment_status": "paid",
"status": "complete",
"metadata": map[string]interface{}{
"user_id": "550e8400-e29b-41d4-a716-446655440000",
"plan_id": "standard",
},
}
data, err := json.Marshal(eventData)
if err != nil {
t.Fatalf("Failed to marshal event data: %v", err)
}
var decoded map[string]interface{}
err = json.Unmarshal(data, &decoded)
if err != nil {
t.Fatalf("Failed to unmarshal event data: %v", err)
}
// Verify required fields
if decoded["customer"] == nil {
t.Error("Event should have 'customer' field")
}
if decoded["subscription"] == nil {
t.Error("Event should have 'subscription' field")
}
metadata, ok := decoded["metadata"].(map[string]interface{})
if !ok || metadata["user_id"] == nil {
t.Error("Event should have 'metadata.user_id' field")
}
}
// TestSubscriptionCreated_EventStructure tests subscription.created event structure
func TestSubscriptionCreated_EventStructure(t *testing.T) {
eventData := map[string]interface{}{
"id": "sub_test_123",
"customer": "cus_test_456",
"status": "trialing",
"items": map[string]interface{}{
"data": []map[string]interface{}{
{
"price": map[string]interface{}{
"id": "price_test_789",
"metadata": map[string]interface{}{"plan_id": "standard"},
},
},
},
},
"trial_end": 1735689600,
"current_period_end": 1735689600,
"metadata": map[string]interface{}{
"user_id": "550e8400-e29b-41d4-a716-446655440000",
"plan_id": "standard",
},
}
data, err := json.Marshal(eventData)
if err != nil {
t.Fatalf("Failed to marshal event data: %v", err)
}
var decoded map[string]interface{}
err = json.Unmarshal(data, &decoded)
if err != nil {
t.Fatalf("Failed to unmarshal event data: %v", err)
}
// Verify required fields
if decoded["status"] != "trialing" {
t.Errorf("Expected status 'trialing', got '%v'", decoded["status"])
}
}
// TestSubscriptionUpdated_StatusTransitions tests subscription status transitions
func TestSubscriptionUpdated_StatusTransitions(t *testing.T) {
validTransitions := []struct {
from string
to string
}{
{"trialing", "active"},
{"active", "past_due"},
{"past_due", "active"},
{"active", "canceled"},
{"trialing", "canceled"},
}
for _, tt := range validTransitions {
t.Run(tt.from+"->"+tt.to, func(t *testing.T) {
if tt.from == "" || tt.to == "" {
t.Error("Status should not be empty")
}
})
}
}
// TestInvoicePaid_EventStructure tests invoice.paid event structure
func TestInvoicePaid_EventStructure(t *testing.T) {
eventData := map[string]interface{}{
"id": "in_test_123",
"subscription": "sub_test_456",
"customer": "cus_test_789",
"status": "paid",
"amount_paid": 1990,
"currency": "eur",
"period_start": 1735689600,
"period_end": 1738368000,
"hosted_invoice_url": "https://invoice.stripe.com/test",
"invoice_pdf": "https://invoice.stripe.com/test.pdf",
}
data, err := json.Marshal(eventData)
if err != nil {
t.Fatalf("Failed to marshal event data: %v", err)
}
var decoded map[string]interface{}
err = json.Unmarshal(data, &decoded)
if err != nil {
t.Fatalf("Failed to unmarshal event data: %v", err)
}
// Verify required fields
if decoded["status"] != "paid" {
t.Errorf("Expected status 'paid', got '%v'", decoded["status"])
}
if decoded["subscription"] == nil {
t.Error("Event should have 'subscription' field")
}
}
// TestInvoicePaymentFailed_EventStructure tests invoice.payment_failed event structure
func TestInvoicePaymentFailed_EventStructure(t *testing.T) {
eventData := map[string]interface{}{
"id": "in_test_123",
"subscription": "sub_test_456",
"customer": "cus_test_789",
"status": "open",
"attempt_count": 1,
"next_payment_attempt": 1735776000,
}
data, err := json.Marshal(eventData)
if err != nil {
t.Fatalf("Failed to marshal event data: %v", err)
}
var decoded map[string]interface{}
err = json.Unmarshal(data, &decoded)
if err != nil {
t.Fatalf("Failed to unmarshal event data: %v", err)
}
// Verify fields
if decoded["attempt_count"] == nil {
t.Error("Event should have 'attempt_count' field")
}
}
// TestSubscriptionDeleted_EventStructure tests subscription.deleted event structure
func TestSubscriptionDeleted_EventStructure(t *testing.T) {
eventData := map[string]interface{}{
"id": "sub_test_123",
"customer": "cus_test_456",
"status": "canceled",
"ended_at": 1735689600,
"canceled_at": 1735689600,
}
data, err := json.Marshal(eventData)
if err != nil {
t.Fatalf("Failed to marshal event data: %v", err)
}
var decoded map[string]interface{}
err = json.Unmarshal(data, &decoded)
if err != nil {
t.Fatalf("Failed to unmarshal event data: %v", err)
}
// Verify required fields
if decoded["status"] != "canceled" {
t.Errorf("Expected status 'canceled', got '%v'", decoded["status"])
}
}
// TestStripeSignatureFormat tests the Stripe signature header format
func TestStripeSignatureFormat(t *testing.T) {
// Stripe signature format: t=timestamp,v1=signature
validSignatures := []string{
"t=1609459200,v1=abc123def456",
"t=1609459200,v1=signature_here,v0=old_signature",
}
for _, sig := range validSignatures {
if len(sig) < 10 {
t.Errorf("Signature seems too short: %s", sig)
}
// Should start with timestamp
if sig[:2] != "t=" {
t.Errorf("Signature should start with 't=': %s", sig)
}
}
}
// TestWebhookEventID_Format tests Stripe event ID format
func TestWebhookEventID_Format(t *testing.T) {
validEventIDs := []string{
"evt_1234567890abcdef",
"evt_test_123456789",
"evt_live_987654321",
}
for _, eventID := range validEventIDs {
// Event IDs should start with "evt_"
if len(eventID) < 10 || eventID[:4] != "evt_" {
t.Errorf("Invalid event ID format: %s", eventID)
}
}
}