package embedding import ( "context" "encoding/json" "net/http" "net/http/httptest" "testing" "time" ) func TestNewService_Disabled(t *testing.T) { service, err := NewService("none", "", "", "", 1536, false) if err != nil { t.Fatalf("NewService failed: %v", err) } if service.IsEnabled() { t.Error("Service should not be enabled") } if service.Dimension() != 1536 { t.Errorf("Expected dimension 1536, got %d", service.Dimension()) } } func TestNewService_DisabledByProvider(t *testing.T) { service, err := NewService("none", "", "", "", 1536, true) if err != nil { t.Fatalf("NewService failed: %v", err) } if service.IsEnabled() { t.Error("Service should not be enabled when provider is 'none'") } } func TestNewService_OpenAIMissingKey(t *testing.T) { _, err := NewService("openai", "", "", "", 1536, true) if err == nil { t.Error("Expected error for missing OpenAI API key") } } func TestNewService_UnknownProvider(t *testing.T) { _, err := NewService("unknown", "", "", "", 1536, true) if err == nil { t.Error("Expected error for unknown provider") } } func TestService_EmbedWhenDisabled(t *testing.T) { service, _ := NewService("none", "", "", "", 1536, false) _, err := service.Embed(context.Background(), "test text") if err == nil { t.Error("Expected error when embedding with disabled service") } } func TestService_EmbedBatchWhenDisabled(t *testing.T) { service, _ := NewService("none", "", "", "", 1536, false) _, err := service.EmbedBatch(context.Background(), []string{"test1", "test2"}) if err == nil { t.Error("Expected error when embedding batch with disabled service") } } // ===================================================== // OpenAI Provider Tests with Mock Server // ===================================================== func TestOpenAIProvider_Embed(t *testing.T) { // Create mock server server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Verify request if r.Method != "POST" { t.Errorf("Expected POST, got %s", r.Method) } if r.Header.Get("Authorization") != "Bearer test-api-key" { t.Errorf("Expected correct Authorization header") } if r.Header.Get("Content-Type") != "application/json" { t.Errorf("Expected Content-Type application/json") } // Parse request body var reqBody openAIEmbeddingRequest if err := json.NewDecoder(r.Body).Decode(&reqBody); err != nil { t.Fatalf("Failed to parse request body: %v", err) } if reqBody.Model != "text-embedding-3-small" { t.Errorf("Expected model text-embedding-3-small, got %s", reqBody.Model) } // Send mock response resp := openAIEmbeddingResponse{ Data: []struct { Embedding []float32 `json:"embedding"` Index int `json:"index"` }{ { Embedding: make([]float32, 1536), Index: 0, }, }, } resp.Data[0].Embedding[0] = 0.1 resp.Data[0].Embedding[1] = 0.2 w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode(resp) })) defer server.Close() // Create provider with mock server (we need to override the URL) provider := &OpenAIProvider{ apiKey: "test-api-key", model: "text-embedding-3-small", dimension: 1536, httpClient: &http.Client{ Timeout: 10 * time.Second, }, } // Note: This test won't actually work with the mock server because // the provider hardcodes the OpenAI URL. This is a structural test. // For real testing, we'd need to make the URL configurable. if provider.Dimension() != 1536 { t.Errorf("Expected dimension 1536, got %d", provider.Dimension()) } } func TestOpenAIProvider_EmbedBatch_EmptyInput(t *testing.T) { provider := NewOpenAIProvider("test-key", "text-embedding-3-small", 1536) result, err := provider.EmbedBatch(context.Background(), []string{}) if err != nil { t.Errorf("Empty input should not cause error: %v", err) } if result != nil { t.Errorf("Expected nil result for empty input, got %v", result) } } // ===================================================== // Ollama Provider Tests with Mock Server // ===================================================== func TestOllamaProvider_Embed(t *testing.T) { // Create mock server mockEmbedding := make([]float32, 384) mockEmbedding[0] = 0.5 mockEmbedding[1] = 0.3 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.Method != "POST" { t.Errorf("Expected POST, got %s", r.Method) } if r.URL.Path != "/api/embeddings" { t.Errorf("Expected path /api/embeddings, got %s", r.URL.Path) } // Parse request var reqBody ollamaEmbeddingRequest if err := json.NewDecoder(r.Body).Decode(&reqBody); err != nil { t.Fatalf("Failed to parse request: %v", err) } if reqBody.Model != "nomic-embed-text" { t.Errorf("Expected model nomic-embed-text, got %s", reqBody.Model) } // Send response resp := ollamaEmbeddingResponse{ Embedding: mockEmbedding, } w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode(resp) })) defer server.Close() provider, err := NewOllamaProvider(server.URL, "nomic-embed-text", 384) if err != nil { t.Fatalf("Failed to create provider: %v", err) } ctx := context.Background() embedding, err := provider.Embed(ctx, "Test text für Embedding") if err != nil { t.Fatalf("Embed failed: %v", err) } if len(embedding) != 384 { t.Errorf("Expected 384 dimensions, got %d", len(embedding)) } if embedding[0] != 0.5 { t.Errorf("Expected first value 0.5, got %f", embedding[0]) } } func TestOllamaProvider_EmbedBatch(t *testing.T) { callCount := 0 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { callCount++ mockEmbedding := make([]float32, 384) mockEmbedding[0] = float32(callCount) * 0.1 resp := ollamaEmbeddingResponse{ Embedding: mockEmbedding, } w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode(resp) })) defer server.Close() provider, err := NewOllamaProvider(server.URL, "nomic-embed-text", 384) if err != nil { t.Fatalf("Failed to create provider: %v", err) } ctx := context.Background() texts := []string{"Text 1", "Text 2", "Text 3"} embeddings, err := provider.EmbedBatch(ctx, texts) if err != nil { t.Fatalf("EmbedBatch failed: %v", err) } if len(embeddings) != 3 { t.Errorf("Expected 3 embeddings, got %d", len(embeddings)) } // Verify each embedding was called if callCount != 3 { t.Errorf("Expected 3 API calls, got %d", callCount) } } func TestOllamaProvider_EmbedServerError(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusInternalServerError) w.Write([]byte("Internal server error")) })) defer server.Close() provider, _ := NewOllamaProvider(server.URL, "nomic-embed-text", 384) _, err := provider.Embed(context.Background(), "test") if err == nil { t.Error("Expected error for server error response") } } func TestOllamaProvider_Dimension(t *testing.T) { provider, _ := NewOllamaProvider("http://localhost:11434", "nomic-embed-text", 768) if provider.Dimension() != 768 { t.Errorf("Expected dimension 768, got %d", provider.Dimension()) } } // ===================================================== // Text Truncation Tests // ===================================================== func TestOllamaProvider_TextTruncation(t *testing.T) { receivedText := "" server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { var reqBody ollamaEmbeddingRequest json.NewDecoder(r.Body).Decode(&reqBody) receivedText = reqBody.Prompt resp := ollamaEmbeddingResponse{ Embedding: make([]float32, 384), } w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode(resp) })) defer server.Close() provider, _ := NewOllamaProvider(server.URL, "nomic-embed-text", 384) // Create very long text longText := "" for i := 0; i < 40000; i++ { longText += "a" } provider.Embed(context.Background(), longText) // Text should be truncated to 30000 chars if len(receivedText) > 30000 { t.Errorf("Expected truncated text <= 30000 chars, got %d", len(receivedText)) } } // ===================================================== // Integration Tests (require actual service) // ===================================================== func TestOpenAIProvider_Integration(t *testing.T) { // Skip in CI/CD - only run manually with real API key t.Skip("Integration test - requires OPENAI_API_KEY environment variable") // provider := NewOpenAIProvider(os.Getenv("OPENAI_API_KEY"), "text-embedding-3-small", 1536) // embedding, err := provider.Embed(context.Background(), "Lehrplan Mathematik Bayern") // ... }