package embedding

import (
	"bytes"
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net/http"
	"time"
)

// EmbeddingProvider defines the interface for embedding services.
type EmbeddingProvider interface {
	// Embed generates an embedding for the given text.
	Embed(ctx context.Context, text string) ([]float32, error)
	// EmbedBatch generates embeddings for multiple texts.
	EmbedBatch(ctx context.Context, texts []string) ([][]float32, error)
	// Dimension returns the embedding vector dimension.
	Dimension() int
}

// Service wraps an embedding provider together with the configured vector
// dimension and an enabled flag. A Service with a nil provider reports itself
// as disabled and rejects embedding calls.
type Service struct {
	provider  EmbeddingProvider
	dimension int
	enabled   bool
}

// NewService creates a new embedding service based on configuration.
//
// When enabled is false, or provider is "none" or empty, a disabled Service
// (nil provider) is returned without error. The "openai" provider requires a
// non-empty apiKey; the "ollama" provider is built from ollamaURL. Any other
// provider name yields an error.
func NewService(provider, apiKey, model, ollamaURL string, dimension int, enabled bool) (*Service, error) {
	if !enabled {
		return &Service{dimension: dimension}, nil
	}

	var p EmbeddingProvider
	switch provider {
	case "openai":
		if apiKey == "" {
			return nil, errors.New("OpenAI API key required for openai provider")
		}
		p = NewOpenAIProvider(apiKey, model, dimension)
	case "ollama":
		op, err := NewOllamaProvider(ollamaURL, model, dimension)
		if err != nil {
			return nil, err
		}
		p = op
	case "none", "":
		return &Service{dimension: dimension}, nil
	default:
		return nil, fmt.Errorf("unknown embedding provider: %s", provider)
	}

	return &Service{
		provider:  p,
		dimension: dimension,
		enabled:   true,
	}, nil
}

// IsEnabled reports whether the service is enabled and has a provider.
func (s *Service) IsEnabled() bool {
	return s.enabled && s.provider != nil
}

// Embed generates an embedding for a single text. It returns an error when the
// service is not enabled.
func (s *Service) Embed(ctx context.Context, text string) ([]float32, error) {
	if !s.IsEnabled() {
		return nil, errors.New("embedding service not enabled")
	}
	return s.provider.Embed(ctx, text)
}

// EmbedBatch generates embeddings for multiple texts. It returns an error when
// the service is not enabled.
func (s *Service) EmbedBatch(ctx context.Context, texts []string) ([][]float32, error) {
	if !s.IsEnabled() {
		return nil, errors.New("embedding service not enabled")
	}
	return s.provider.EmbedBatch(ctx, texts)
}

// Dimension returns the configured embedding dimension.
func (s *Service) Dimension() int {
	return s.dimension
}

// =====================================================
// OpenAI Embedding Provider
// =====================================================

// OpenAIProvider implements EmbeddingProvider using OpenAI's API.
type OpenAIProvider struct {
	apiKey     string
	model      string
	dimension  int
	httpClient *http.Client
}

// NewOpenAIProvider creates a new OpenAI embedding provider whose HTTP client
// uses a 60 second timeout.
func NewOpenAIProvider(apiKey, model string, dimension int) *OpenAIProvider {
	return &OpenAIProvider{
		apiKey:    apiKey,
		model:     model,
		dimension: dimension,
		httpClient: &http.Client{
			Timeout: 60 * time.Second,
		},
	}
}

// openAIEmbeddingRequest represents the OpenAI API request.
type openAIEmbeddingRequest struct {
	Model      string   `json:"model"`
	Input      []string `json:"input"`
	Dimensions int      `json:"dimensions,omitempty"`
}

// openAIEmbeddingResponse represents the OpenAI API response.
type openAIEmbeddingResponse struct {
	Data []struct {
		Embedding []float32 `json:"embedding"`
		Index     int       `json:"index"`
	} `json:"data"`
	Usage struct {
		PromptTokens int `json:"prompt_tokens"`
		TotalTokens  int `json:"total_tokens"`
	} `json:"usage"`
	Error *struct {
		Message string `json:"message"`
		Type    string `json:"type"`
	} `json:"error,omitempty"`
}

// Embed generates an embedding for a single text by delegating to EmbedBatch.
// It returns an error if the batch call yields no embeddings.
func (p *OpenAIProvider) Embed(ctx context.Context, text string) ([]float32, error) {
	embeddings, err := p.EmbedBatch(ctx, []string{text})
	if err != nil {
		return nil, err
	}
	if len(embeddings) == 0 {
		return nil, errors.New("no embedding returned")
	}
	return embeddings[0], nil
}

// EmbedBatch generates embeddings for multiple texts with a single request to
// the OpenAI embeddings endpoint.
//
// An empty texts slice returns (nil, nil) without calling the API. Each text
// longer than 30000 bytes is truncated (on a UTF-8 rune boundary) before being
// sent. The returned slice has one entry per input, placed according to the
// index reported by the API. An error is returned when the request fails, the
// API reports an error, the HTTP status is not 2xx, or the response does not
// contain exactly one embedding per input with valid, distinct indexes.
func (p *OpenAIProvider) EmbedBatch(ctx context.Context, texts []string) ([][]float32, error) {
	if len(texts) == 0 {
		return nil, nil
	}

	// Truncate texts to avoid token limits (max ~8000 tokens per text).
	// Rough estimate: ~4 chars per token.
	const maxInputBytes = 30000
	truncatedTexts := make([]string, len(texts))
	for i, text := range texts {
		if len(text) > maxInputBytes {
			cut := maxInputBytes
			// Back up (at most 3 bytes, the longest run of UTF-8 continuation
			// bytes) so a multi-byte character is never split; a split rune
			// would be sent as U+FFFD by the JSON encoder.
			for back := 0; back < 3 && text[cut]&0xC0 == 0x80; back++ {
				cut--
			}
			text = text[:cut]
		}
		truncatedTexts[i] = text
	}

	reqBody := openAIEmbeddingRequest{
		Model: p.model,
		Input: truncatedTexts,
	}
	// Only set dimensions for models that support it (text-embedding-3-*).
	if p.model == "text-embedding-3-small" || p.model == "text-embedding-3-large" {
		reqBody.Dimensions = p.dimension
	}

	body, err := json.Marshal(reqBody)
	if err != nil {
		return nil, fmt.Errorf("failed to marshal request: %w", err)
	}

	req, err := http.NewRequestWithContext(ctx, "POST", "https://api.openai.com/v1/embeddings", bytes.NewReader(body))
	if err != nil {
		return nil, fmt.Errorf("failed to create request: %w", err)
	}
	req.Header.Set("Authorization", "Bearer "+p.apiKey)
	req.Header.Set("Content-Type", "application/json")

	resp, err := p.httpClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("failed to call OpenAI API: %w", err)
	}
	defer resp.Body.Close()

	respBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("failed to read response: %w", err)
	}
	statusOK := resp.StatusCode >= 200 && resp.StatusCode < 300

	var apiResp openAIEmbeddingResponse
	if err := json.Unmarshal(respBody, &apiResp); err != nil {
		// A non-JSON body on a failed request (e.g. a proxy error page) is
		// reported with its status instead of a misleading parse error.
		if !statusOK {
			return nil, fmt.Errorf("OpenAI API error (status %d): %s", resp.StatusCode, string(respBody))
		}
		return nil, fmt.Errorf("failed to parse response: %w", err)
	}
	if apiResp.Error != nil {
		return nil, fmt.Errorf("OpenAI API error: %s", apiResp.Error.Message)
	}
	if !statusOK {
		return nil, fmt.Errorf("OpenAI API error (status %d): %s", resp.StatusCode, string(respBody))
	}

	if len(apiResp.Data) != len(texts) {
		return nil, fmt.Errorf("expected %d embeddings, got %d", len(texts), len(apiResp.Data))
	}

	// Place each embedding at the index reported by the API so the result
	// order matches the input order. Indexes are validated first: an
	// out-of-range value would otherwise panic, and a duplicate would leave
	// another slot silently empty.
	result := make([][]float32, len(texts))
	seen := make([]bool, len(texts))
	for _, item := range apiResp.Data {
		if item.Index < 0 || item.Index >= len(result) {
			return nil, fmt.Errorf("embedding index %d out of range for %d inputs", item.Index, len(result))
		}
		if seen[item.Index] {
			return nil, fmt.Errorf("duplicate embedding index %d in response", item.Index)
		}
		seen[item.Index] = true
		result[item.Index] = item.Embedding
	}

	return result, nil
}

// Dimension returns the embedding dimension.
func (p *OpenAIProvider) Dimension() int {
	return p.dimension
}

// =====================================================
// Ollama Embedding Provider (for local models)
// =====================================================
// OllamaProvider implements EmbeddingProvider using Ollama's API.
type OllamaProvider struct {
	baseURL    string
	model      string
	dimension  int
	httpClient *http.Client
}

// NewOllamaProvider creates a new Ollama embedding provider.
//
// Trailing slashes on baseURL are removed so that joining it with the
// "/api/embeddings" path never produces a double slash. The returned error is
// currently always nil.
func NewOllamaProvider(baseURL, model string, dimension int) (*OllamaProvider, error) {
	for len(baseURL) > 0 && baseURL[len(baseURL)-1] == '/' {
		baseURL = baseURL[:len(baseURL)-1]
	}
	return &OllamaProvider{
		baseURL:   baseURL,
		model:     model,
		dimension: dimension,
		httpClient: &http.Client{
			Timeout: 120 * time.Second, // Ollama can be slow on first inference
		},
	}, nil
}

// ollamaEmbeddingRequest represents the Ollama API request.
type ollamaEmbeddingRequest struct {
	Model  string `json:"model"`
	Prompt string `json:"prompt"`
}

// ollamaEmbeddingResponse represents the Ollama API response.
type ollamaEmbeddingResponse struct {
	Embedding []float32 `json:"embedding"`
}

// Embed generates an embedding for a single text by POSTing it to
// baseURL + "/api/embeddings".
//
// Text longer than 30000 bytes is truncated (on a UTF-8 rune boundary) before
// being sent. An error is returned when the request cannot be built or sent,
// when the server answers with a status other than 200, or when the response
// body cannot be decoded.
func (p *OllamaProvider) Embed(ctx context.Context, text string) ([]float32, error) {
	const maxPromptBytes = 30000
	if len(text) > maxPromptBytes {
		cut := maxPromptBytes
		// Back up (at most 3 bytes, the longest run of UTF-8 continuation
		// bytes) so a multi-byte character is never split; a split rune would
		// be sent as U+FFFD by the JSON encoder.
		for back := 0; back < 3 && text[cut]&0xC0 == 0x80; back++ {
			cut--
		}
		text = text[:cut]
	}

	reqBody := ollamaEmbeddingRequest{
		Model:  p.model,
		Prompt: text,
	}

	body, err := json.Marshal(reqBody)
	if err != nil {
		return nil, fmt.Errorf("failed to marshal request: %w", err)
	}

	req, err := http.NewRequestWithContext(ctx, "POST", p.baseURL+"/api/embeddings", bytes.NewReader(body))
	if err != nil {
		return nil, fmt.Errorf("failed to create request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")

	resp, err := p.httpClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("failed to call Ollama API: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		// The body is only used to enrich the error message, so a read
		// failure here is deliberately ignored.
		respBody, _ := io.ReadAll(resp.Body)
		return nil, fmt.Errorf("Ollama API error (status %d): %s", resp.StatusCode, string(respBody))
	}

	var apiResp ollamaEmbeddingResponse
	if err := json.NewDecoder(resp.Body).Decode(&apiResp); err != nil {
		return nil, fmt.Errorf("failed to parse response: %w", err)
	}

	return apiResp.Embedding, nil
}

// EmbedBatch generates embeddings for multiple texts by calling Embed once per
// text, sequentially. It stops at the first failure and returns an error that
// identifies the index of the failing text.
func (p *OllamaProvider) EmbedBatch(ctx context.Context, texts []string) ([][]float32, error) {
	result := make([][]float32, len(texts))
	for i, text := range texts {
		embedding, err := p.Embed(ctx, text)
		if err != nil {
			return nil, fmt.Errorf("failed to embed text %d: %w", i, err)
		}
		result[i] = embedding
	}
	return result, nil
}

// Dimension returns the embedding dimension.
func (p *OllamaProvider) Dimension() int {
	return p.dimension
}