// Package llm defines a provider-agnostic interface for large language
// model backends (text completion, chat, embeddings) and a registry that
// routes requests to a primary provider with automatic fallback.
package llm

import (
	"context"
	"errors"
	"time"
)

// Provider names.
const (
	ProviderOllama    = "ollama"
	ProviderAnthropic = "anthropic"
	ProviderOpenAI    = "openai"
)

// Sentinel errors shared by providers and the registry; compare with
// errors.Is.
var (
	ErrProviderUnavailable = errors.New("LLM provider unavailable")
	ErrModelNotFound       = errors.New("model not found")
	ErrContextTooLong      = errors.New("context too long for model")
	ErrRateLimited         = errors.New("rate limited")
	ErrInvalidRequest      = errors.New("invalid request")
)

// Provider defines the interface for LLM providers.
type Provider interface {
	// Name returns the provider name.
	Name() string

	// IsAvailable checks if the provider is currently available.
	IsAvailable(ctx context.Context) bool

	// ListModels returns available models.
	ListModels(ctx context.Context) ([]Model, error)

	// Complete performs text completion.
	Complete(ctx context.Context, req *CompletionRequest) (*CompletionResponse, error)

	// Chat performs chat completion.
	Chat(ctx context.Context, req *ChatRequest) (*ChatResponse, error)

	// Embed creates embeddings for text.
	Embed(ctx context.Context, req *EmbedRequest) (*EmbedResponse, error)
}

// Model represents an available LLM model.
type Model struct {
	ID           string         `json:"id"`
	Name         string         `json:"name"`
	Provider     string         `json:"provider"`
	Description  string         `json:"description,omitempty"`
	ContextSize  int            `json:"context_size"`
	Parameters   map[string]any `json:"parameters,omitempty"`
	Capabilities []string       `json:"capabilities,omitempty"` // "chat", "completion", "embedding"
}

// Message represents a chat message.
type Message struct {
	Role    string `json:"role"` // "system", "user", "assistant"
	Content string `json:"content"`
}

// CompletionRequest represents a text completion request.
type CompletionRequest struct {
	Model       string         `json:"model"`
	Prompt      string         `json:"prompt"`
	MaxTokens   int            `json:"max_tokens,omitempty"`
	Temperature float64        `json:"temperature,omitempty"`
	TopP        float64        `json:"top_p,omitempty"`
	Stop        []string       `json:"stop,omitempty"`
	Options     map[string]any `json:"options,omitempty"`
}

// CompletionResponse represents a text completion response.
type CompletionResponse struct {
	ID           string        `json:"id"`
	Model        string        `json:"model"`
	Provider     string        `json:"provider"`
	Text         string        `json:"text"`
	FinishReason string        `json:"finish_reason,omitempty"`
	Usage        UsageStats    `json:"usage"`
	Duration     time.Duration `json:"duration"`
}

// ChatRequest represents a chat completion request.
type ChatRequest struct {
	Model       string         `json:"model"`
	Messages    []Message      `json:"messages"`
	MaxTokens   int            `json:"max_tokens,omitempty"`
	Temperature float64        `json:"temperature,omitempty"`
	TopP        float64        `json:"top_p,omitempty"`
	Stop        []string       `json:"stop,omitempty"`
	Options     map[string]any `json:"options,omitempty"`
}

// ChatResponse represents a chat completion response.
type ChatResponse struct {
	ID           string        `json:"id"`
	Model        string        `json:"model"`
	Provider     string        `json:"provider"`
	Message      Message       `json:"message"`
	FinishReason string        `json:"finish_reason,omitempty"`
	Usage        UsageStats    `json:"usage"`
	Duration     time.Duration `json:"duration"`
}

// EmbedRequest represents an embedding request.
type EmbedRequest struct {
	Model string   `json:"model"`
	Input []string `json:"input"`
}

// EmbedResponse represents an embedding response.
type EmbedResponse struct {
	ID         string        `json:"id"`
	Model      string        `json:"model"`
	Provider   string        `json:"provider"`
	Embeddings [][]float64   `json:"embeddings"`
	Usage      UsageStats    `json:"usage"`
	Duration   time.Duration `json:"duration"`
}

// UsageStats represents token usage statistics.
type UsageStats struct {
	PromptTokens     int `json:"prompt_tokens"`
	CompletionTokens int `json:"completion_tokens"`
	TotalTokens      int `json:"total_tokens"`
}

// ProviderRegistry manages multiple LLM providers and routes requests to
// the primary provider, falling back to the configured fallback provider
// when the primary is unavailable or errors. It is not safe for concurrent
// mutation; register all providers before serving requests.
type ProviderRegistry struct {
	providers        map[string]Provider
	primaryProvider  string
	fallbackProvider string
}

// NewProviderRegistry creates a new provider registry with the given
// primary and fallback provider names. Either name may be empty or refer
// to a provider registered later.
func NewProviderRegistry(primary, fallback string) *ProviderRegistry {
	return &ProviderRegistry{
		providers:        make(map[string]Provider),
		primaryProvider:  primary,
		fallbackProvider: fallback,
	}
}

// Register registers a provider under its Name. Registering a second
// provider with the same name replaces the first.
func (r *ProviderRegistry) Register(provider Provider) {
	r.providers[provider.Name()] = provider
}

// GetProvider returns a provider by name. The boolean reports whether a
// provider with that name is registered.
func (r *ProviderRegistry) GetProvider(name string) (Provider, bool) {
	p, ok := r.providers[name]
	return p, ok
}

// GetPrimary returns the primary provider, if registered.
func (r *ProviderRegistry) GetPrimary() (Provider, bool) {
	return r.GetProvider(r.primaryProvider)
}

// GetFallback returns the fallback provider, if registered.
func (r *ProviderRegistry) GetFallback() (Provider, bool) {
	return r.GetProvider(r.fallbackProvider)
}

// GetAvailable returns the first available provider (primary, then
// fallback). It returns ErrProviderUnavailable when neither responds as
// available.
func (r *ProviderRegistry) GetAvailable(ctx context.Context) (Provider, error) {
	if p, ok := r.GetPrimary(); ok && p.IsAvailable(ctx) {
		return p, nil
	}
	if p, ok := r.GetFallback(); ok && p.IsAvailable(ctx) {
		return p, nil
	}
	return nil, ErrProviderUnavailable
}

// ListAllModels returns models from all available providers. Providers
// that are unavailable or whose ListModels call fails are skipped
// (best-effort aggregation); the error result is always nil and is kept
// for interface stability. Map iteration order makes the result order
// nondeterministic across calls.
func (r *ProviderRegistry) ListAllModels(ctx context.Context) ([]Model, error) {
	var allModels []Model
	for _, p := range r.providers {
		if p.IsAvailable(ctx) {
			models, err := p.ListModels(ctx)
			if err == nil {
				allModels = append(allModels, models...)
			}
		}
	}
	return allModels, nil
}

// fallbackFor reports the fallback provider to retry on, or nil when no
// retry should happen: a fallback must be configured, registered, and be
// a different provider than the one that just failed.
func (r *ProviderRegistry) fallbackFor(failed Provider) Provider {
	if r.fallbackProvider == "" {
		return nil
	}
	fallback, ok := r.GetFallback()
	if !ok || fallback.Name() == failed.Name() {
		return nil
	}
	return fallback
}

// Complete performs completion with automatic fallback: if the selected
// provider errors and a distinct fallback is registered, the request is
// retried once on the fallback.
func (r *ProviderRegistry) Complete(ctx context.Context, req *CompletionRequest) (*CompletionResponse, error) {
	provider, err := r.GetAvailable(ctx)
	if err != nil {
		return nil, err
	}
	resp, err := provider.Complete(ctx, req)
	if err != nil {
		if fallback := r.fallbackFor(provider); fallback != nil {
			return fallback.Complete(ctx, req)
		}
	}
	return resp, err
}

// Chat performs chat completion with automatic fallback: if the selected
// provider errors and a distinct fallback is registered, the request is
// retried once on the fallback.
func (r *ProviderRegistry) Chat(ctx context.Context, req *ChatRequest) (*ChatResponse, error) {
	provider, err := r.GetAvailable(ctx)
	if err != nil {
		return nil, err
	}
	resp, err := provider.Chat(ctx, req)
	if err != nil {
		if fallback := r.fallbackFor(provider); fallback != nil {
			return fallback.Chat(ctx, req)
		}
	}
	return resp, err
}

// Embed creates embeddings with automatic fallback, mirroring the retry
// behavior of Complete and Chat (previously Embed did not fall back).
func (r *ProviderRegistry) Embed(ctx context.Context, req *EmbedRequest) (*EmbedResponse, error) {
	provider, err := r.GetAvailable(ctx)
	if err != nil {
		return nil, err
	}
	resp, err := provider.Embed(ctx, req)
	if err != nil {
		if fallback := r.fallbackFor(provider); fallback != nil {
			return fallback.Embed(ctx, req)
		}
	}
	return resp, err
}