package llm import ( "bytes" "context" "encoding/json" "fmt" "io" "net/http" "time" "github.com/google/uuid" ) // OllamaAdapter implements the Provider interface for Ollama type OllamaAdapter struct { baseURL string defaultModel string httpClient *http.Client } // NewOllamaAdapter creates a new Ollama adapter func NewOllamaAdapter(baseURL, defaultModel string) *OllamaAdapter { return &OllamaAdapter{ baseURL: baseURL, defaultModel: defaultModel, httpClient: &http.Client{ Timeout: 5 * time.Minute, // LLM requests can be slow }, } } // Name returns the provider name func (o *OllamaAdapter) Name() string { return ProviderOllama } // IsAvailable checks if Ollama is reachable func (o *OllamaAdapter) IsAvailable(ctx context.Context) bool { req, err := http.NewRequestWithContext(ctx, "GET", o.baseURL+"/api/tags", nil) if err != nil { return false } ctx, cancel := context.WithTimeout(ctx, 5*time.Second) defer cancel() req = req.WithContext(ctx) resp, err := o.httpClient.Do(req) if err != nil { return false } defer resp.Body.Close() return resp.StatusCode == http.StatusOK } // ListModels returns available Ollama models func (o *OllamaAdapter) ListModels(ctx context.Context) ([]Model, error) { req, err := http.NewRequestWithContext(ctx, "GET", o.baseURL+"/api/tags", nil) if err != nil { return nil, err } resp, err := o.httpClient.Do(req) if err != nil { return nil, fmt.Errorf("failed to list models: %w", err) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { return nil, fmt.Errorf("unexpected status: %d", resp.StatusCode) } var result struct { Models []struct { Name string `json:"name"` ModifiedAt string `json:"modified_at"` Size int64 `json:"size"` Details struct { Format string `json:"format"` Family string `json:"family"` ParameterSize string `json:"parameter_size"` QuantizationLevel string `json:"quantization_level"` } `json:"details"` } `json:"models"` } if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { return nil, fmt.Errorf("failed to decode response: %w", err) } models := make([]Model, len(result.Models)) for i, m := range result.Models { models[i] = Model{ ID: m.Name, Name: m.Name, Provider: ProviderOllama, Description: fmt.Sprintf("%s (%s)", m.Details.Family, m.Details.ParameterSize), ContextSize: 4096, // Default, actual varies by model Capabilities: []string{"chat", "completion"}, } } return models, nil } // Complete performs text completion func (o *OllamaAdapter) Complete(ctx context.Context, req *CompletionRequest) (*CompletionResponse, error) { model := req.Model if model == "" { model = o.defaultModel } start := time.Now() ollamaReq := map[string]any{ "model": model, "prompt": req.Prompt, "stream": false, } if req.MaxTokens > 0 { if ollamaReq["options"] == nil { ollamaReq["options"] = make(map[string]any) } ollamaReq["options"].(map[string]any)["num_predict"] = req.MaxTokens } if req.Temperature > 0 { if ollamaReq["options"] == nil { ollamaReq["options"] = make(map[string]any) } ollamaReq["options"].(map[string]any)["temperature"] = req.Temperature } body, err := json.Marshal(ollamaReq) if err != nil { return nil, err } httpReq, err := http.NewRequestWithContext(ctx, "POST", o.baseURL+"/api/generate", bytes.NewReader(body)) if err != nil { return nil, err } httpReq.Header.Set("Content-Type", "application/json") resp, err := o.httpClient.Do(httpReq) if err != nil { return nil, fmt.Errorf("ollama request failed: %w", err) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { bodyBytes, _ := io.ReadAll(resp.Body) return nil, fmt.Errorf("ollama error: %s", string(bodyBytes)) } var result struct { Model string `json:"model"` Response string `json:"response"` Done bool `json:"done"` TotalDuration int64 `json:"total_duration"` LoadDuration int64 `json:"load_duration"` PromptEvalCount int `json:"prompt_eval_count"` PromptEvalDuration int64 `json:"prompt_eval_duration"` EvalCount int `json:"eval_count"` EvalDuration int64 `json:"eval_duration"` } if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { return nil, fmt.Errorf("failed to decode response: %w", err) } duration := time.Since(start) return &CompletionResponse{ ID: uuid.New().String(), Model: result.Model, Provider: ProviderOllama, Text: result.Response, FinishReason: "stop", Usage: UsageStats{ PromptTokens: result.PromptEvalCount, CompletionTokens: result.EvalCount, TotalTokens: result.PromptEvalCount + result.EvalCount, }, Duration: duration, }, nil } // Chat performs chat completion func (o *OllamaAdapter) Chat(ctx context.Context, req *ChatRequest) (*ChatResponse, error) { model := req.Model if model == "" { model = o.defaultModel } start := time.Now() // Convert messages to Ollama format messages := make([]map[string]string, len(req.Messages)) for i, m := range req.Messages { messages[i] = map[string]string{ "role": m.Role, "content": m.Content, } } ollamaReq := map[string]any{ "model": model, "messages": messages, "stream": false, } if req.MaxTokens > 0 { if ollamaReq["options"] == nil { ollamaReq["options"] = make(map[string]any) } ollamaReq["options"].(map[string]any)["num_predict"] = req.MaxTokens } if req.Temperature > 0 { if ollamaReq["options"] == nil { ollamaReq["options"] = make(map[string]any) } ollamaReq["options"].(map[string]any)["temperature"] = req.Temperature } body, err := json.Marshal(ollamaReq) if err != nil { return nil, err } httpReq, err := http.NewRequestWithContext(ctx, "POST", o.baseURL+"/api/chat", bytes.NewReader(body)) if err != nil { return nil, err } httpReq.Header.Set("Content-Type", "application/json") resp, err := o.httpClient.Do(httpReq) if err != nil { return nil, fmt.Errorf("ollama chat request failed: %w", err) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { bodyBytes, _ := io.ReadAll(resp.Body) return nil, fmt.Errorf("ollama chat error: %s", string(bodyBytes)) } var result struct { Model string `json:"model"` Message struct { Role string `json:"role"` Content string `json:"content"` } `json:"message"` Done bool `json:"done"` TotalDuration int64 `json:"total_duration"` PromptEvalCount int `json:"prompt_eval_count"` EvalCount int `json:"eval_count"` } if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { return nil, fmt.Errorf("failed to decode chat response: %w", err) } duration := time.Since(start) return &ChatResponse{ ID: uuid.New().String(), Model: result.Model, Provider: ProviderOllama, Message: Message{ Role: result.Message.Role, Content: result.Message.Content, }, FinishReason: "stop", Usage: UsageStats{ PromptTokens: result.PromptEvalCount, CompletionTokens: result.EvalCount, TotalTokens: result.PromptEvalCount + result.EvalCount, }, Duration: duration, }, nil } // Embed creates embeddings func (o *OllamaAdapter) Embed(ctx context.Context, req *EmbedRequest) (*EmbedResponse, error) { model := req.Model if model == "" { model = "nomic-embed-text" // Default embedding model } start := time.Now() var embeddings [][]float64 for _, input := range req.Input { ollamaReq := map[string]any{ "model": model, "prompt": input, } body, err := json.Marshal(ollamaReq) if err != nil { return nil, err } httpReq, err := http.NewRequestWithContext(ctx, "POST", o.baseURL+"/api/embeddings", bytes.NewReader(body)) if err != nil { return nil, err } httpReq.Header.Set("Content-Type", "application/json") resp, err := o.httpClient.Do(httpReq) if err != nil { return nil, fmt.Errorf("ollama embedding request failed: %w", err) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { bodyBytes, _ := io.ReadAll(resp.Body) return nil, fmt.Errorf("ollama embedding error: %s", string(bodyBytes)) } var result struct { Embedding []float64 `json:"embedding"` } if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { return nil, fmt.Errorf("failed to decode embedding response: %w", err) } embeddings = append(embeddings, result.Embedding) } duration := time.Since(start) return &EmbedResponse{ ID: uuid.New().String(), Model: model, Provider: ProviderOllama, Embeddings: embeddings, Usage: UsageStats{ TotalTokens: len(req.Input) * 256, // Approximate }, Duration: duration, }, nil }