package services import ( "context" "fmt" "time" "github.com/breakpilot/billing-service/internal/database" "github.com/breakpilot/billing-service/internal/models" "github.com/google/uuid" ) // UsageService handles usage tracking operations type UsageService struct { db *database.DB entitlementService *EntitlementService } // NewUsageService creates a new UsageService func NewUsageService(db *database.DB, entitlementService *EntitlementService) *UsageService { return &UsageService{ db: db, entitlementService: entitlementService, } } // TrackUsage tracks usage for a user func (s *UsageService) TrackUsage(ctx context.Context, userIDStr, usageType string, quantity int) error { userID, err := uuid.Parse(userIDStr) if err != nil { return fmt.Errorf("invalid user ID: %w", err) } // Get current period start (beginning of current month) now := time.Now() periodStart := time.Date(now.Year(), now.Month(), 1, 0, 0, 0, 0, time.UTC) // Upsert usage summary query := ` INSERT INTO usage_summary (user_id, usage_type, period_start, total_count) VALUES ($1, $2, $3, $4) ON CONFLICT (user_id, usage_type, period_start) DO UPDATE SET total_count = usage_summary.total_count + EXCLUDED.total_count, updated_at = NOW() ` _, err = s.db.Pool.Exec(ctx, query, userID, usageType, periodStart, quantity) if err != nil { return fmt.Errorf("failed to track usage: %w", err) } // Also update entitlements cache return s.entitlementService.IncrementUsage(ctx, userID, usageType, quantity) } // GetUsageSummary returns usage summary for a user func (s *UsageService) GetUsageSummary(ctx context.Context, userID uuid.UUID) (*models.UsageInfo, error) { // Get entitlements (which include current usage) ent, err := s.entitlementService.getUserEntitlements(ctx, userID) if err != nil || ent == nil { return nil, err } // Calculate percentages aiPercent := 0.0 if ent.AIRequestsLimit > 0 { aiPercent = float64(ent.AIRequestsUsed) / float64(ent.AIRequestsLimit) * 100 } docPercent := 0.0 if ent.DocumentsLimit > 0 { docPercent = float64(ent.DocumentsUsed) / float64(ent.DocumentsLimit) * 100 } // Get period dates now := time.Now() periodStart := time.Date(now.Year(), now.Month(), 1, 0, 0, 0, 0, time.UTC) periodEnd := periodStart.AddDate(0, 1, 0).Add(-time.Second) return &models.UsageInfo{ AIRequestsUsed: ent.AIRequestsUsed, AIRequestsLimit: ent.AIRequestsLimit, AIRequestsPercent: aiPercent, DocumentsUsed: ent.DocumentsUsed, DocumentsLimit: ent.DocumentsLimit, DocumentsPercent: docPercent, PeriodStart: periodStart.Format("2006-01-02"), PeriodEnd: periodEnd.Format("2006-01-02"), }, nil } // CheckUsageAllowed checks if a user is allowed to perform a usage action func (s *UsageService) CheckUsageAllowed(ctx context.Context, userIDStr, usageType string) (*models.CheckUsageResponse, error) { userID, err := uuid.Parse(userIDStr) if err != nil { return &models.CheckUsageResponse{ Allowed: false, Message: "Invalid user ID", }, nil } // Get entitlements ent, err := s.entitlementService.getUserEntitlements(ctx, userID) if err != nil { return &models.CheckUsageResponse{ Allowed: false, Message: "Failed to get entitlements", }, nil } if ent == nil { return &models.CheckUsageResponse{ Allowed: false, Message: "No subscription found", }, nil } var currentUsage, limit int switch usageType { case "ai_request": currentUsage = ent.AIRequestsUsed limit = ent.AIRequestsLimit case "document_created": currentUsage = ent.DocumentsUsed limit = ent.DocumentsLimit default: return &models.CheckUsageResponse{ Allowed: true, Message: "Unknown usage type - allowing", }, nil } remaining := limit - currentUsage allowed := remaining > 0 response := &models.CheckUsageResponse{ Allowed: allowed, CurrentUsage: currentUsage, Limit: limit, Remaining: remaining, } if !allowed { response.Message = fmt.Sprintf("Usage limit reached for %s (%d/%d)", usageType, currentUsage, limit) } return response, nil } // GetUsageHistory returns usage history for a user func (s *UsageService) GetUsageHistory(ctx context.Context, userID uuid.UUID, months int) ([]models.UsageSummary, error) { query := ` SELECT id, user_id, usage_type, period_start, total_count, created_at, updated_at FROM usage_summary WHERE user_id = $1 AND period_start >= $2 ORDER BY period_start DESC, usage_type ` // Calculate start date startDate := time.Now().AddDate(0, -months, 0) startDate = time.Date(startDate.Year(), startDate.Month(), 1, 0, 0, 0, 0, time.UTC) rows, err := s.db.Pool.Query(ctx, query, userID, startDate) if err != nil { return nil, err } defer rows.Close() var summaries []models.UsageSummary for rows.Next() { var summary models.UsageSummary err := rows.Scan( &summary.ID, &summary.UserID, &summary.UsageType, &summary.PeriodStart, &summary.TotalCount, &summary.CreatedAt, &summary.UpdatedAt, ) if err != nil { return nil, err } summaries = append(summaries, summary) } return summaries, nil } // ResetPeriodUsage resets usage for a new billing period func (s *UsageService) ResetPeriodUsage(ctx context.Context, userID uuid.UUID) error { now := time.Now() newPeriodStart := time.Date(now.Year(), now.Month(), 1, 0, 0, 0, 0, time.UTC) newPeriodEnd := newPeriodStart.AddDate(0, 1, 0).Add(-time.Second) return s.entitlementService.ResetUsageCounters(ctx, userID, &newPeriodStart, &newPeriodEnd) }