Initial commit: breakpilot-compliance - Compliance SDK Platform
Services: Admin-Compliance, Backend-Compliance, AI-Compliance-SDK, Consent-SDK, Developer-Portal, PCA-Platform, DSMS Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
350
ai-compliance-sdk/internal/llm/ollama_adapter.go
Normal file
350
ai-compliance-sdk/internal/llm/ollama_adapter.go
Normal file
@@ -0,0 +1,350 @@
|
||||
package llm
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// OllamaAdapter implements the Provider interface for Ollama
type OllamaAdapter struct {
	baseURL      string       // root URL of the Ollama HTTP API (no trailing slash expected; paths like /api/tags are appended)
	defaultModel string       // model used when a request leaves its Model field empty
	httpClient   *http.Client // shared client; NewOllamaAdapter configures a 5-minute timeout for slow LLM requests
}
|
||||
|
||||
// NewOllamaAdapter creates a new Ollama adapter
|
||||
func NewOllamaAdapter(baseURL, defaultModel string) *OllamaAdapter {
|
||||
return &OllamaAdapter{
|
||||
baseURL: baseURL,
|
||||
defaultModel: defaultModel,
|
||||
httpClient: &http.Client{
|
||||
Timeout: 5 * time.Minute, // LLM requests can be slow
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Name returns the provider name constant ProviderOllama, identifying
// this adapter to callers that dispatch on provider.
func (o *OllamaAdapter) Name() string {
	return ProviderOllama
}
|
||||
|
||||
// IsAvailable checks if Ollama is reachable
|
||||
func (o *OllamaAdapter) IsAvailable(ctx context.Context) bool {
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", o.baseURL+"/api/tags", nil)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
req = req.WithContext(ctx)
|
||||
resp, err := o.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
return resp.StatusCode == http.StatusOK
|
||||
}
|
||||
|
||||
// ListModels returns available Ollama models
|
||||
func (o *OllamaAdapter) ListModels(ctx context.Context) ([]Model, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", o.baseURL+"/api/tags", nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp, err := o.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to list models: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("unexpected status: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
var result struct {
|
||||
Models []struct {
|
||||
Name string `json:"name"`
|
||||
ModifiedAt string `json:"modified_at"`
|
||||
Size int64 `json:"size"`
|
||||
Details struct {
|
||||
Format string `json:"format"`
|
||||
Family string `json:"family"`
|
||||
ParameterSize string `json:"parameter_size"`
|
||||
QuantizationLevel string `json:"quantization_level"`
|
||||
} `json:"details"`
|
||||
} `json:"models"`
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode response: %w", err)
|
||||
}
|
||||
|
||||
models := make([]Model, len(result.Models))
|
||||
for i, m := range result.Models {
|
||||
models[i] = Model{
|
||||
ID: m.Name,
|
||||
Name: m.Name,
|
||||
Provider: ProviderOllama,
|
||||
Description: fmt.Sprintf("%s (%s)", m.Details.Family, m.Details.ParameterSize),
|
||||
ContextSize: 4096, // Default, actual varies by model
|
||||
Capabilities: []string{"chat", "completion"},
|
||||
}
|
||||
}
|
||||
|
||||
return models, nil
|
||||
}
|
||||
|
||||
// Complete performs text completion
|
||||
func (o *OllamaAdapter) Complete(ctx context.Context, req *CompletionRequest) (*CompletionResponse, error) {
|
||||
model := req.Model
|
||||
if model == "" {
|
||||
model = o.defaultModel
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
|
||||
ollamaReq := map[string]any{
|
||||
"model": model,
|
||||
"prompt": req.Prompt,
|
||||
"stream": false,
|
||||
}
|
||||
|
||||
if req.MaxTokens > 0 {
|
||||
if ollamaReq["options"] == nil {
|
||||
ollamaReq["options"] = make(map[string]any)
|
||||
}
|
||||
ollamaReq["options"].(map[string]any)["num_predict"] = req.MaxTokens
|
||||
}
|
||||
|
||||
if req.Temperature > 0 {
|
||||
if ollamaReq["options"] == nil {
|
||||
ollamaReq["options"] = make(map[string]any)
|
||||
}
|
||||
ollamaReq["options"].(map[string]any)["temperature"] = req.Temperature
|
||||
}
|
||||
|
||||
body, err := json.Marshal(ollamaReq)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
httpReq, err := http.NewRequestWithContext(ctx, "POST", o.baseURL+"/api/generate", bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := o.httpClient.Do(httpReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("ollama request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||
return nil, fmt.Errorf("ollama error: %s", string(bodyBytes))
|
||||
}
|
||||
|
||||
var result struct {
|
||||
Model string `json:"model"`
|
||||
Response string `json:"response"`
|
||||
Done bool `json:"done"`
|
||||
TotalDuration int64 `json:"total_duration"`
|
||||
LoadDuration int64 `json:"load_duration"`
|
||||
PromptEvalCount int `json:"prompt_eval_count"`
|
||||
PromptEvalDuration int64 `json:"prompt_eval_duration"`
|
||||
EvalCount int `json:"eval_count"`
|
||||
EvalDuration int64 `json:"eval_duration"`
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode response: %w", err)
|
||||
}
|
||||
|
||||
duration := time.Since(start)
|
||||
|
||||
return &CompletionResponse{
|
||||
ID: uuid.New().String(),
|
||||
Model: result.Model,
|
||||
Provider: ProviderOllama,
|
||||
Text: result.Response,
|
||||
FinishReason: "stop",
|
||||
Usage: UsageStats{
|
||||
PromptTokens: result.PromptEvalCount,
|
||||
CompletionTokens: result.EvalCount,
|
||||
TotalTokens: result.PromptEvalCount + result.EvalCount,
|
||||
},
|
||||
Duration: duration,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Chat performs chat completion
|
||||
func (o *OllamaAdapter) Chat(ctx context.Context, req *ChatRequest) (*ChatResponse, error) {
|
||||
model := req.Model
|
||||
if model == "" {
|
||||
model = o.defaultModel
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
|
||||
// Convert messages to Ollama format
|
||||
messages := make([]map[string]string, len(req.Messages))
|
||||
for i, m := range req.Messages {
|
||||
messages[i] = map[string]string{
|
||||
"role": m.Role,
|
||||
"content": m.Content,
|
||||
}
|
||||
}
|
||||
|
||||
ollamaReq := map[string]any{
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
"stream": false,
|
||||
}
|
||||
|
||||
if req.MaxTokens > 0 {
|
||||
if ollamaReq["options"] == nil {
|
||||
ollamaReq["options"] = make(map[string]any)
|
||||
}
|
||||
ollamaReq["options"].(map[string]any)["num_predict"] = req.MaxTokens
|
||||
}
|
||||
|
||||
if req.Temperature > 0 {
|
||||
if ollamaReq["options"] == nil {
|
||||
ollamaReq["options"] = make(map[string]any)
|
||||
}
|
||||
ollamaReq["options"].(map[string]any)["temperature"] = req.Temperature
|
||||
}
|
||||
|
||||
body, err := json.Marshal(ollamaReq)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
httpReq, err := http.NewRequestWithContext(ctx, "POST", o.baseURL+"/api/chat", bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := o.httpClient.Do(httpReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("ollama chat request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||
return nil, fmt.Errorf("ollama chat error: %s", string(bodyBytes))
|
||||
}
|
||||
|
||||
var result struct {
|
||||
Model string `json:"model"`
|
||||
Message struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
} `json:"message"`
|
||||
Done bool `json:"done"`
|
||||
TotalDuration int64 `json:"total_duration"`
|
||||
PromptEvalCount int `json:"prompt_eval_count"`
|
||||
EvalCount int `json:"eval_count"`
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode chat response: %w", err)
|
||||
}
|
||||
|
||||
duration := time.Since(start)
|
||||
|
||||
return &ChatResponse{
|
||||
ID: uuid.New().String(),
|
||||
Model: result.Model,
|
||||
Provider: ProviderOllama,
|
||||
Message: Message{
|
||||
Role: result.Message.Role,
|
||||
Content: result.Message.Content,
|
||||
},
|
||||
FinishReason: "stop",
|
||||
Usage: UsageStats{
|
||||
PromptTokens: result.PromptEvalCount,
|
||||
CompletionTokens: result.EvalCount,
|
||||
TotalTokens: result.PromptEvalCount + result.EvalCount,
|
||||
},
|
||||
Duration: duration,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Embed creates embeddings
|
||||
func (o *OllamaAdapter) Embed(ctx context.Context, req *EmbedRequest) (*EmbedResponse, error) {
|
||||
model := req.Model
|
||||
if model == "" {
|
||||
model = "nomic-embed-text" // Default embedding model
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
|
||||
var embeddings [][]float64
|
||||
|
||||
for _, input := range req.Input {
|
||||
ollamaReq := map[string]any{
|
||||
"model": model,
|
||||
"prompt": input,
|
||||
}
|
||||
|
||||
body, err := json.Marshal(ollamaReq)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
httpReq, err := http.NewRequestWithContext(ctx, "POST", o.baseURL+"/api/embeddings", bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := o.httpClient.Do(httpReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("ollama embedding request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||
return nil, fmt.Errorf("ollama embedding error: %s", string(bodyBytes))
|
||||
}
|
||||
|
||||
var result struct {
|
||||
Embedding []float64 `json:"embedding"`
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode embedding response: %w", err)
|
||||
}
|
||||
|
||||
embeddings = append(embeddings, result.Embedding)
|
||||
}
|
||||
|
||||
duration := time.Since(start)
|
||||
|
||||
return &EmbedResponse{
|
||||
ID: uuid.New().String(),
|
||||
Model: model,
|
||||
Provider: ProviderOllama,
|
||||
Embeddings: embeddings,
|
||||
Usage: UsageStats{
|
||||
TotalTokens: len(req.Input) * 256, // Approximate
|
||||
},
|
||||
Duration: duration,
|
||||
}, nil
|
||||
}
|
||||
Reference in New Issue
Block a user