Initial commit: breakpilot-compliance - Compliance SDK Platform
Services: Admin-Compliance, Backend-Compliance, AI-Compliance-SDK, Consent-SDK, Developer-Portal, PCA-Platform, DSMS Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
367
ai-compliance-sdk/internal/llm/access_gate.go
Normal file
367
ai-compliance-sdk/internal/llm/access_gate.go
Normal file
@@ -0,0 +1,367 @@
|
||||
package llm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/breakpilot/ai-compliance-sdk/internal/rbac"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// AccessGate controls access to LLM operations based on RBAC policies.
// It sits in front of provider calls: a request is first evaluated
// against policy, then scanned (and optionally redacted) for PII, and
// finally executed against whichever provider the registry reports as
// available.
type AccessGate struct {
	policyEngine *rbac.PolicyEngine // evaluates per-user/tenant LLM access policies
	piiDetector  *PIIDetector       // scans and redacts PII in prompts and messages
	registry     *ProviderRegistry  // resolves an available LLM provider at execution time
}
|
||||
|
||||
// NewAccessGate creates a new access gate
|
||||
func NewAccessGate(policyEngine *rbac.PolicyEngine, piiDetector *PIIDetector, registry *ProviderRegistry) *AccessGate {
|
||||
return &AccessGate{
|
||||
policyEngine: policyEngine,
|
||||
piiDetector: piiDetector,
|
||||
registry: registry,
|
||||
}
|
||||
}
|
||||
|
||||
// GatedRequest represents a request that has passed through the access
// gate. It carries the policy decision plus the PII metadata needed for
// audit logging, so downstream code never has to re-scan the input.
type GatedRequest struct {
	OriginalRequest any        // the caller's *ChatRequest or *CompletionRequest, unmodified apart from token clamping
	UserID          uuid.UUID  // requesting user
	TenantID        uuid.UUID  // owning tenant
	NamespaceID     *uuid.UUID // optional namespace scope
	Model           string     // model requested by the caller
	PIIDetected     bool       // true if any PII pattern matched the input
	PIITypes        []string   // de-duplicated PII type names that were found
	PromptRedacted  bool       // true if policy forced redaction of the input
	PromptHash      string     // SHA-256 hex of the processed (possibly redacted) input, for audit
	Policy          *rbac.LLMPolicy       // policy that governed this request, if any
	AccessResult    *rbac.LLMAccessResult // full evaluation result, including token limits
}
|
||||
|
||||
// GatedChatRequest represents a chat request that has passed through the gate.
type GatedChatRequest struct {
	*GatedRequest
	Messages []Message // messages after any policy-mandated PII redaction
}
|
||||
|
||||
// GatedCompletionRequest represents a completion request that has passed through the gate.
type GatedCompletionRequest struct {
	*GatedRequest
	Prompt string // prompt text after any policy-mandated PII redaction
}
|
||||
|
||||
// ProcessChatRequest validates and processes a chat request.
//
// Pipeline: (1) evaluate RBAC/LLM policy, (2) scan user and system
// messages for PII and redact them when the policy requires it,
// (3) hash the processed messages for audit, (4) clamp the requested
// token budget to the policy maximum. Returns an error when policy
// evaluation fails or the policy denies access.
//
// NOTE(review): step 4 mutates req.MaxTokens on the caller's request,
// and the returned GatedRequest.OriginalRequest aliases req — callers
// should not assume req is left untouched.
func (g *AccessGate) ProcessChatRequest(
	ctx context.Context,
	userID, tenantID uuid.UUID,
	namespaceID *uuid.UUID,
	req *ChatRequest,
	dataCategories []string,
) (*GatedChatRequest, error) {
	// 1. Evaluate LLM access
	accessReq := &rbac.LLMAccessRequest{
		UserID:          userID,
		TenantID:        tenantID,
		NamespaceID:     namespaceID,
		Model:           req.Model,
		DataCategories:  dataCategories,
		TokensRequested: req.MaxTokens,
		Operation:       "chat",
	}

	accessResult, err := g.policyEngine.EvaluateLLMAccess(ctx, accessReq)
	if err != nil {
		return nil, fmt.Errorf("access evaluation failed: %w", err)
	}

	if !accessResult.Allowed {
		return nil, fmt.Errorf("access denied: %s", accessResult.Reason)
	}

	// 2. Process messages for PII. Only "user" and "system" messages are
	// scanned; any other role (e.g. assistant history) passes through
	// unmodified. The copy keeps the caller's message slice intact.
	processedMessages := make([]Message, len(req.Messages))
	copy(processedMessages, req.Messages)

	var allPIITypes []string
	piiDetected := false
	redacted := false

	for i, msg := range processedMessages {
		if msg.Role == "user" || msg.Role == "system" {
			// Check for PII
			findings := g.piiDetector.FindPII(msg.Content)
			if len(findings) > 0 {
				piiDetected = true
				for _, f := range findings {
					allPIITypes = append(allPIITypes, f.Type)
				}

				// Redact if required by policy
				if accessResult.RequirePIIRedaction {
					processedMessages[i].Content = g.piiDetector.Redact(msg.Content, accessResult.PIIRedactionLevel)
					redacted = true
				}
			}
		}
	}

	// 3. Generate prompt hash for audit. The hash covers the processed
	// (post-redaction, when redaction ran) message contents.
	promptHash := g.hashMessages(processedMessages)

	// 4. Apply token limits from policy; a MaxTokens of 0 means the
	// policy imposes no cap.
	if accessResult.MaxTokens > 0 && req.MaxTokens > accessResult.MaxTokens {
		req.MaxTokens = accessResult.MaxTokens
	}

	return &GatedChatRequest{
		GatedRequest: &GatedRequest{
			OriginalRequest: req,
			UserID:          userID,
			TenantID:        tenantID,
			NamespaceID:     namespaceID,
			Model:           req.Model,
			PIIDetected:     piiDetected,
			PIITypes:        uniqueStrings(allPIITypes),
			PromptRedacted:  redacted,
			PromptHash:      promptHash,
			Policy:          accessResult.Policy,
			AccessResult:    accessResult,
		},
		Messages: processedMessages,
	}, nil
}
|
||||
|
||||
// ProcessCompletionRequest validates and processes a completion request.
//
// Mirrors ProcessChatRequest for single-prompt completions: (1) evaluate
// RBAC/LLM policy, (2) scan the prompt for PII and redact when the
// policy requires it, (3) hash the processed prompt for audit,
// (4) clamp the requested token budget to the policy maximum.
//
// NOTE(review): step 4 mutates req.MaxTokens on the caller's request,
// and the returned GatedRequest.OriginalRequest aliases req.
func (g *AccessGate) ProcessCompletionRequest(
	ctx context.Context,
	userID, tenantID uuid.UUID,
	namespaceID *uuid.UUID,
	req *CompletionRequest,
	dataCategories []string,
) (*GatedCompletionRequest, error) {
	// 1. Evaluate LLM access
	accessReq := &rbac.LLMAccessRequest{
		UserID:          userID,
		TenantID:        tenantID,
		NamespaceID:     namespaceID,
		Model:           req.Model,
		DataCategories:  dataCategories,
		TokensRequested: req.MaxTokens,
		Operation:       "completion",
	}

	accessResult, err := g.policyEngine.EvaluateLLMAccess(ctx, accessReq)
	if err != nil {
		return nil, fmt.Errorf("access evaluation failed: %w", err)
	}

	if !accessResult.Allowed {
		return nil, fmt.Errorf("access denied: %s", accessResult.Reason)
	}

	// 2. Process prompt for PII
	processedPrompt := req.Prompt
	var allPIITypes []string
	piiDetected := false
	redacted := false

	findings := g.piiDetector.FindPII(req.Prompt)
	if len(findings) > 0 {
		piiDetected = true
		for _, f := range findings {
			allPIITypes = append(allPIITypes, f.Type)
		}

		// Redact if required by policy
		if accessResult.RequirePIIRedaction {
			processedPrompt = g.piiDetector.Redact(req.Prompt, accessResult.PIIRedactionLevel)
			redacted = true
		}
	}

	// 3. Generate prompt hash for audit (covers the processed, possibly
	// redacted prompt)
	promptHash := g.hashPrompt(processedPrompt)

	// 4. Apply token limits from policy; 0 means no policy cap
	if accessResult.MaxTokens > 0 && req.MaxTokens > accessResult.MaxTokens {
		req.MaxTokens = accessResult.MaxTokens
	}

	return &GatedCompletionRequest{
		GatedRequest: &GatedRequest{
			OriginalRequest: req,
			UserID:          userID,
			TenantID:        tenantID,
			NamespaceID:     namespaceID,
			Model:           req.Model,
			PIIDetected:     piiDetected,
			PIITypes:        uniqueStrings(allPIITypes),
			PromptRedacted:  redacted,
			PromptHash:      promptHash,
			Policy:          accessResult.Policy,
			AccessResult:    accessResult,
		},
		Prompt: processedPrompt,
	}, nil
}
|
||||
|
||||
// ExecuteChat executes a gated chat request
|
||||
func (g *AccessGate) ExecuteChat(ctx context.Context, gatedReq *GatedChatRequest) (*ChatResponse, error) {
|
||||
provider, err := g.registry.GetAvailable(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req := &ChatRequest{
|
||||
Model: gatedReq.Model,
|
||||
Messages: gatedReq.Messages,
|
||||
MaxTokens: gatedReq.AccessResult.MaxTokens,
|
||||
Temperature: 0.7,
|
||||
}
|
||||
|
||||
if orig, ok := gatedReq.OriginalRequest.(*ChatRequest); ok {
|
||||
req.Temperature = orig.Temperature
|
||||
req.TopP = orig.TopP
|
||||
req.Stop = orig.Stop
|
||||
req.Options = orig.Options
|
||||
}
|
||||
|
||||
return provider.Chat(ctx, req)
|
||||
}
|
||||
|
||||
// ExecuteCompletion executes a gated completion request
|
||||
func (g *AccessGate) ExecuteCompletion(ctx context.Context, gatedReq *GatedCompletionRequest) (*CompletionResponse, error) {
|
||||
provider, err := g.registry.GetAvailable(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req := &CompletionRequest{
|
||||
Model: gatedReq.Model,
|
||||
Prompt: gatedReq.Prompt,
|
||||
MaxTokens: gatedReq.AccessResult.MaxTokens,
|
||||
}
|
||||
|
||||
if orig, ok := gatedReq.OriginalRequest.(*CompletionRequest); ok {
|
||||
req.Temperature = orig.Temperature
|
||||
req.TopP = orig.TopP
|
||||
req.Stop = orig.Stop
|
||||
req.Options = orig.Options
|
||||
}
|
||||
|
||||
return provider.Complete(ctx, req)
|
||||
}
|
||||
|
||||
// hashMessages creates a SHA-256 hash of chat messages (for audit without storing PII)
|
||||
func (g *AccessGate) hashMessages(messages []Message) string {
|
||||
hasher := sha256.New()
|
||||
for _, msg := range messages {
|
||||
hasher.Write([]byte(msg.Role))
|
||||
hasher.Write([]byte(msg.Content))
|
||||
}
|
||||
return hex.EncodeToString(hasher.Sum(nil))
|
||||
}
|
||||
|
||||
// hashPrompt creates a SHA-256 hash of a prompt
|
||||
func (g *AccessGate) hashPrompt(prompt string) string {
|
||||
hasher := sha256.New()
|
||||
hasher.Write([]byte(prompt))
|
||||
return hex.EncodeToString(hasher.Sum(nil))
|
||||
}
|
||||
|
||||
// uniqueStrings returns the input strings with duplicates removed,
// preserving first-occurrence order. A nil or empty input yields nil.
func uniqueStrings(slice []string) []string {
	seen := make(map[string]struct{}, len(slice))
	var out []string
	for _, s := range slice {
		if _, dup := seen[s]; dup {
			continue
		}
		seen[s] = struct{}{}
		out = append(out, s)
	}
	return out
}
|
||||
|
||||
// AuditEntry represents an entry for the audit log. It records who
// called which model, what policy applied, and PII metadata — but never
// the prompt text itself (only its hash and length).
type AuditEntry struct {
	ID                     uuid.UUID      `json:"id"`
	TenantID               uuid.UUID      `json:"tenant_id"`
	NamespaceID            *uuid.UUID     `json:"namespace_id,omitempty"`
	UserID                 uuid.UUID      `json:"user_id"`
	SessionID              string         `json:"session_id,omitempty"`
	Operation              string         `json:"operation"`  // "chat" or "completion"
	ModelUsed              string         `json:"model_used"`
	Provider               string         `json:"provider"`
	PromptHash             string         `json:"prompt_hash"`   // SHA-256 of the processed prompt/messages
	PromptLength           int            `json:"prompt_length"`
	ResponseLength         int            `json:"response_length,omitempty"`
	TokensUsed             int            `json:"tokens_used"`
	DurationMS             int            `json:"duration_ms"`
	PIIDetected            bool           `json:"pii_detected"`
	PIITypesDetected       []string       `json:"pii_types_detected,omitempty"`
	PIIRedacted            bool           `json:"pii_redacted"`
	PolicyID               *uuid.UUID     `json:"policy_id,omitempty"`
	PolicyViolations       []string       `json:"policy_violations,omitempty"`
	DataCategoriesAccessed []string       `json:"data_categories_accessed,omitempty"`
	ErrorMessage           string         `json:"error_message,omitempty"`
	RequestMetadata        map[string]any `json:"request_metadata,omitempty"`
	CreatedAt              time.Time      `json:"created_at"`
}
|
||||
|
||||
// CreateAuditEntry creates an audit entry from a gated request and its
// outcome. resp may be a *ChatResponse, a *CompletionResponse, or nil
// (any other type leaves the usage fields at zero); err, when non-nil,
// is recorded as the entry's error message. The entry is returned for
// the caller to persist — this method does not write it anywhere.
//
// NOTE(review): DataCategoriesAccessed and RequestMetadata are never
// populated here; presumably callers fill them in — confirm.
func (g *AccessGate) CreateAuditEntry(
	gatedReq *GatedRequest,
	operation string,
	provider string,
	resp any,
	err error,
	promptLength int,
	sessionID string,
) *AuditEntry {
	entry := &AuditEntry{
		ID:               uuid.New(),
		TenantID:         gatedReq.TenantID,
		NamespaceID:      gatedReq.NamespaceID,
		UserID:           gatedReq.UserID,
		SessionID:        sessionID,
		Operation:        operation,
		ModelUsed:        gatedReq.Model,
		Provider:         provider,
		PromptHash:       gatedReq.PromptHash,
		PromptLength:     promptLength,
		PIIDetected:      gatedReq.PIIDetected,
		PIITypesDetected: gatedReq.PIITypes,
		PIIRedacted:      gatedReq.PromptRedacted,
		CreatedAt:        time.Now().UTC(),
	}

	if gatedReq.Policy != nil {
		entry.PolicyID = &gatedReq.Policy.ID
	}

	// Blocked data categories from the policy evaluation are surfaced
	// as policy violations on the audit record.
	if gatedReq.AccessResult != nil && len(gatedReq.AccessResult.BlockedCategories) > 0 {
		entry.PolicyViolations = gatedReq.AccessResult.BlockedCategories
	}

	if err != nil {
		entry.ErrorMessage = err.Error()
	}

	// Extract usage from response (response length is in bytes, not tokens)
	switch r := resp.(type) {
	case *ChatResponse:
		entry.ResponseLength = len(r.Message.Content)
		entry.TokensUsed = r.Usage.TotalTokens
		entry.DurationMS = int(r.Duration.Milliseconds())
	case *CompletionResponse:
		entry.ResponseLength = len(r.Text)
		entry.TokensUsed = r.Usage.TotalTokens
		entry.DurationMS = int(r.Duration.Milliseconds())
	}

	return entry
}
|
||||
250
ai-compliance-sdk/internal/llm/anthropic_adapter.go
Normal file
250
ai-compliance-sdk/internal/llm/anthropic_adapter.go
Normal file
@@ -0,0 +1,250 @@
|
||||
package llm
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
// AnthropicAdapter implements the Provider interface for the Anthropic
// Messages API (chat only; completion is emulated via chat, embeddings
// are unsupported).
type AnthropicAdapter struct {
	apiKey       string       // Anthropic API key; empty means the adapter is unavailable
	baseURL      string       // API root, normally https://api.anthropic.com
	defaultModel string       // model used when a request does not specify one
	httpClient   *http.Client // shared client with a generous timeout for slow generations
}
|
||||
|
||||
// NewAnthropicAdapter creates a new Anthropic adapter
|
||||
func NewAnthropicAdapter(apiKey, defaultModel string) *AnthropicAdapter {
|
||||
return &AnthropicAdapter{
|
||||
apiKey: apiKey,
|
||||
baseURL: "https://api.anthropic.com",
|
||||
defaultModel: defaultModel,
|
||||
httpClient: &http.Client{
|
||||
Timeout: 5 * time.Minute,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Name returns the provider name constant for this adapter.
func (a *AnthropicAdapter) Name() string {
	return ProviderAnthropic
}
|
||||
|
||||
// IsAvailable checks if Anthropic API is reachable
|
||||
func (a *AnthropicAdapter) IsAvailable(ctx context.Context) bool {
|
||||
if a.apiKey == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
// Simple check - we can't really ping Anthropic without making a request
|
||||
// Just verify we have an API key
|
||||
return true
|
||||
}
|
||||
|
||||
// ListModels returns available Anthropic models.
//
// Anthropic does not expose a model-listing endpoint, so this returns a
// hard-coded catalog. NOTE(review): this list goes stale as new models
// ship and old ones are retired — consider making it configurable.
func (a *AnthropicAdapter) ListModels(ctx context.Context) ([]Model, error) {
	// Anthropic doesn't have a models endpoint, return known models
	return []Model{
		{
			ID:           "claude-3-opus-20240229",
			Name:         "Claude 3 Opus",
			Provider:     ProviderAnthropic,
			Description:  "Most powerful model for complex tasks",
			ContextSize:  200000,
			Capabilities: []string{"chat"},
		},
		{
			ID:           "claude-3-sonnet-20240229",
			Name:         "Claude 3 Sonnet",
			Provider:     ProviderAnthropic,
			Description:  "Balanced performance and speed",
			ContextSize:  200000,
			Capabilities: []string{"chat"},
		},
		{
			ID:           "claude-3-haiku-20240307",
			Name:         "Claude 3 Haiku",
			Provider:     ProviderAnthropic,
			Description:  "Fast and efficient",
			ContextSize:  200000,
			Capabilities: []string{"chat"},
		},
		{
			ID:           "claude-3-5-sonnet-20240620",
			Name:         "Claude 3.5 Sonnet",
			Provider:     ProviderAnthropic,
			Description:  "Latest and most capable model",
			ContextSize:  200000,
			Capabilities: []string{"chat"},
		},
	}, nil
}
|
||||
|
||||
// Complete performs text completion (converted to chat)
|
||||
func (a *AnthropicAdapter) Complete(ctx context.Context, req *CompletionRequest) (*CompletionResponse, error) {
|
||||
// Anthropic only supports chat, so convert completion to chat
|
||||
chatReq := &ChatRequest{
|
||||
Model: req.Model,
|
||||
Messages: []Message{
|
||||
{Role: "user", Content: req.Prompt},
|
||||
},
|
||||
MaxTokens: req.MaxTokens,
|
||||
Temperature: req.Temperature,
|
||||
TopP: req.TopP,
|
||||
Stop: req.Stop,
|
||||
}
|
||||
|
||||
chatResp, err := a.Chat(ctx, chatReq)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &CompletionResponse{
|
||||
ID: chatResp.ID,
|
||||
Model: chatResp.Model,
|
||||
Provider: chatResp.Provider,
|
||||
Text: chatResp.Message.Content,
|
||||
FinishReason: chatResp.FinishReason,
|
||||
Usage: chatResp.Usage,
|
||||
Duration: chatResp.Duration,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Chat performs chat completion against Anthropic's /v1/messages
// endpoint (non-streaming).
//
// System messages are lifted out of the message list into the API's
// dedicated "system" field, as the Messages API requires. The request
// model falls back to the adapter default, and max_tokens (mandatory in
// the Anthropic API) falls back to 4096 when unset.
//
// NOTE(review): when multiple system messages are present, only the
// last one survives — confirm whether callers ever send more than one.
func (a *AnthropicAdapter) Chat(ctx context.Context, req *ChatRequest) (*ChatResponse, error) {
	if a.apiKey == "" {
		return nil, fmt.Errorf("anthropic API key not configured")
	}

	model := req.Model
	if model == "" {
		model = a.defaultModel
	}

	start := time.Now()

	// Extract system message if present; non-system messages keep their
	// relative order.
	var systemMessage string
	var messages []map[string]string

	for _, m := range req.Messages {
		if m.Role == "system" {
			systemMessage = m.Content
		} else {
			messages = append(messages, map[string]string{
				"role":    m.Role,
				"content": m.Content,
			})
		}
	}

	// max_tokens is required by the Anthropic API; default when unset
	maxTokens := req.MaxTokens
	if maxTokens == 0 {
		maxTokens = 4096
	}

	anthropicReq := map[string]any{
		"model":      model,
		"messages":   messages,
		"max_tokens": maxTokens,
	}

	if systemMessage != "" {
		anthropicReq["system"] = systemMessage
	}

	// Optional sampling parameters are omitted when zero so the API
	// applies its own defaults.
	if req.Temperature > 0 {
		anthropicReq["temperature"] = req.Temperature
	}

	if req.TopP > 0 {
		anthropicReq["top_p"] = req.TopP
	}

	if len(req.Stop) > 0 {
		anthropicReq["stop_sequences"] = req.Stop
	}

	body, err := json.Marshal(anthropicReq)
	if err != nil {
		return nil, err
	}

	httpReq, err := http.NewRequestWithContext(ctx, "POST", a.baseURL+"/v1/messages", bytes.NewReader(body))
	if err != nil {
		return nil, err
	}

	httpReq.Header.Set("Content-Type", "application/json")
	httpReq.Header.Set("x-api-key", a.apiKey)
	httpReq.Header.Set("anthropic-version", "2023-06-01")

	resp, err := a.httpClient.Do(httpReq)
	if err != nil {
		return nil, fmt.Errorf("anthropic request failed: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		bodyBytes, _ := io.ReadAll(resp.Body)
		return nil, fmt.Errorf("anthropic error (%d): %s", resp.StatusCode, string(bodyBytes))
	}

	// Minimal view of the Messages API response — only fields we use.
	var result struct {
		ID      string `json:"id"`
		Type    string `json:"type"`
		Role    string `json:"role"`
		Content []struct {
			Type string `json:"type"`
			Text string `json:"text"`
		} `json:"content"`
		Model        string `json:"model"`
		StopReason   string `json:"stop_reason"`
		StopSequence string `json:"stop_sequence,omitempty"`
		Usage        struct {
			InputTokens  int `json:"input_tokens"`
			OutputTokens int `json:"output_tokens"`
		} `json:"usage"`
	}

	if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
		return nil, fmt.Errorf("failed to decode response: %w", err)
	}

	duration := time.Since(start)

	// Concatenate all text content blocks; non-text blocks are ignored.
	var responseText string
	for _, block := range result.Content {
		if block.Type == "text" {
			responseText += block.Text
		}
	}

	return &ChatResponse{
		ID:       result.ID,
		Model:    result.Model,
		Provider: ProviderAnthropic,
		Message: Message{
			Role:    "assistant",
			Content: responseText,
		},
		FinishReason: result.StopReason,
		Usage: UsageStats{
			PromptTokens:     result.Usage.InputTokens,
			CompletionTokens: result.Usage.OutputTokens,
			TotalTokens:      result.Usage.InputTokens + result.Usage.OutputTokens,
		},
		Duration: duration,
	}, nil
}
|
||||
|
||||
// Embed always fails: the Anthropic API offers no embeddings endpoint.
// Callers should route embedding requests to a provider that supports
// them (e.g. the Ollama adapter).
func (a *AnthropicAdapter) Embed(ctx context.Context, req *EmbedRequest) (*EmbedResponse, error) {
	return nil, fmt.Errorf("anthropic does not support embeddings - use Ollama or OpenAI")
}
|
||||
350
ai-compliance-sdk/internal/llm/ollama_adapter.go
Normal file
350
ai-compliance-sdk/internal/llm/ollama_adapter.go
Normal file
@@ -0,0 +1,350 @@
|
||||
package llm
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// OllamaAdapter implements the Provider interface for a local or remote
// Ollama server (chat, completion, and embeddings).
type OllamaAdapter struct {
	baseURL      string       // Ollama server root, e.g. http://localhost:11434
	defaultModel string       // model used when a request does not specify one
	httpClient   *http.Client // shared client with a long timeout for slow generations
}
|
||||
|
||||
// NewOllamaAdapter creates a new Ollama adapter
|
||||
func NewOllamaAdapter(baseURL, defaultModel string) *OllamaAdapter {
|
||||
return &OllamaAdapter{
|
||||
baseURL: baseURL,
|
||||
defaultModel: defaultModel,
|
||||
httpClient: &http.Client{
|
||||
Timeout: 5 * time.Minute, // LLM requests can be slow
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Name returns the provider name constant for this adapter.
func (o *OllamaAdapter) Name() string {
	return ProviderOllama
}
|
||||
|
||||
// IsAvailable checks if Ollama is reachable
|
||||
func (o *OllamaAdapter) IsAvailable(ctx context.Context) bool {
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", o.baseURL+"/api/tags", nil)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
req = req.WithContext(ctx)
|
||||
resp, err := o.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
return resp.StatusCode == http.StatusOK
|
||||
}
|
||||
|
||||
// ListModels returns the models installed on the Ollama server, read
// from GET /api/tags and mapped into the generic Model shape.
//
// NOTE(review): ContextSize is hard-coded to 4096; the real window
// varies per model and /api/tags does not report it.
func (o *OllamaAdapter) ListModels(ctx context.Context) ([]Model, error) {
	req, err := http.NewRequestWithContext(ctx, "GET", o.baseURL+"/api/tags", nil)
	if err != nil {
		return nil, err
	}

	resp, err := o.httpClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("failed to list models: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("unexpected status: %d", resp.StatusCode)
	}

	// Minimal view of the /api/tags response — only fields we use.
	var result struct {
		Models []struct {
			Name       string `json:"name"`
			ModifiedAt string `json:"modified_at"`
			Size       int64  `json:"size"`
			Details    struct {
				Format            string `json:"format"`
				Family            string `json:"family"`
				ParameterSize     string `json:"parameter_size"`
				QuantizationLevel string `json:"quantization_level"`
			} `json:"details"`
		} `json:"models"`
	}

	if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
		return nil, fmt.Errorf("failed to decode response: %w", err)
	}

	models := make([]Model, len(result.Models))
	for i, m := range result.Models {
		models[i] = Model{
			ID:           m.Name,
			Name:         m.Name,
			Provider:     ProviderOllama,
			Description:  fmt.Sprintf("%s (%s)", m.Details.Family, m.Details.ParameterSize),
			ContextSize:  4096, // Default, actual varies by model
			Capabilities: []string{"chat", "completion"},
		}
	}

	return models, nil
}
|
||||
|
||||
// Complete performs text completion
|
||||
func (o *OllamaAdapter) Complete(ctx context.Context, req *CompletionRequest) (*CompletionResponse, error) {
|
||||
model := req.Model
|
||||
if model == "" {
|
||||
model = o.defaultModel
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
|
||||
ollamaReq := map[string]any{
|
||||
"model": model,
|
||||
"prompt": req.Prompt,
|
||||
"stream": false,
|
||||
}
|
||||
|
||||
if req.MaxTokens > 0 {
|
||||
if ollamaReq["options"] == nil {
|
||||
ollamaReq["options"] = make(map[string]any)
|
||||
}
|
||||
ollamaReq["options"].(map[string]any)["num_predict"] = req.MaxTokens
|
||||
}
|
||||
|
||||
if req.Temperature > 0 {
|
||||
if ollamaReq["options"] == nil {
|
||||
ollamaReq["options"] = make(map[string]any)
|
||||
}
|
||||
ollamaReq["options"].(map[string]any)["temperature"] = req.Temperature
|
||||
}
|
||||
|
||||
body, err := json.Marshal(ollamaReq)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
httpReq, err := http.NewRequestWithContext(ctx, "POST", o.baseURL+"/api/generate", bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := o.httpClient.Do(httpReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("ollama request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||
return nil, fmt.Errorf("ollama error: %s", string(bodyBytes))
|
||||
}
|
||||
|
||||
var result struct {
|
||||
Model string `json:"model"`
|
||||
Response string `json:"response"`
|
||||
Done bool `json:"done"`
|
||||
TotalDuration int64 `json:"total_duration"`
|
||||
LoadDuration int64 `json:"load_duration"`
|
||||
PromptEvalCount int `json:"prompt_eval_count"`
|
||||
PromptEvalDuration int64 `json:"prompt_eval_duration"`
|
||||
EvalCount int `json:"eval_count"`
|
||||
EvalDuration int64 `json:"eval_duration"`
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode response: %w", err)
|
||||
}
|
||||
|
||||
duration := time.Since(start)
|
||||
|
||||
return &CompletionResponse{
|
||||
ID: uuid.New().String(),
|
||||
Model: result.Model,
|
||||
Provider: ProviderOllama,
|
||||
Text: result.Response,
|
||||
FinishReason: "stop",
|
||||
Usage: UsageStats{
|
||||
PromptTokens: result.PromptEvalCount,
|
||||
CompletionTokens: result.EvalCount,
|
||||
TotalTokens: result.PromptEvalCount + result.EvalCount,
|
||||
},
|
||||
Duration: duration,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Chat performs chat completion
|
||||
func (o *OllamaAdapter) Chat(ctx context.Context, req *ChatRequest) (*ChatResponse, error) {
|
||||
model := req.Model
|
||||
if model == "" {
|
||||
model = o.defaultModel
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
|
||||
// Convert messages to Ollama format
|
||||
messages := make([]map[string]string, len(req.Messages))
|
||||
for i, m := range req.Messages {
|
||||
messages[i] = map[string]string{
|
||||
"role": m.Role,
|
||||
"content": m.Content,
|
||||
}
|
||||
}
|
||||
|
||||
ollamaReq := map[string]any{
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
"stream": false,
|
||||
}
|
||||
|
||||
if req.MaxTokens > 0 {
|
||||
if ollamaReq["options"] == nil {
|
||||
ollamaReq["options"] = make(map[string]any)
|
||||
}
|
||||
ollamaReq["options"].(map[string]any)["num_predict"] = req.MaxTokens
|
||||
}
|
||||
|
||||
if req.Temperature > 0 {
|
||||
if ollamaReq["options"] == nil {
|
||||
ollamaReq["options"] = make(map[string]any)
|
||||
}
|
||||
ollamaReq["options"].(map[string]any)["temperature"] = req.Temperature
|
||||
}
|
||||
|
||||
body, err := json.Marshal(ollamaReq)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
httpReq, err := http.NewRequestWithContext(ctx, "POST", o.baseURL+"/api/chat", bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := o.httpClient.Do(httpReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("ollama chat request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||
return nil, fmt.Errorf("ollama chat error: %s", string(bodyBytes))
|
||||
}
|
||||
|
||||
var result struct {
|
||||
Model string `json:"model"`
|
||||
Message struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
} `json:"message"`
|
||||
Done bool `json:"done"`
|
||||
TotalDuration int64 `json:"total_duration"`
|
||||
PromptEvalCount int `json:"prompt_eval_count"`
|
||||
EvalCount int `json:"eval_count"`
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode chat response: %w", err)
|
||||
}
|
||||
|
||||
duration := time.Since(start)
|
||||
|
||||
return &ChatResponse{
|
||||
ID: uuid.New().String(),
|
||||
Model: result.Model,
|
||||
Provider: ProviderOllama,
|
||||
Message: Message{
|
||||
Role: result.Message.Role,
|
||||
Content: result.Message.Content,
|
||||
},
|
||||
FinishReason: "stop",
|
||||
Usage: UsageStats{
|
||||
PromptTokens: result.PromptEvalCount,
|
||||
CompletionTokens: result.EvalCount,
|
||||
TotalTokens: result.PromptEvalCount + result.EvalCount,
|
||||
},
|
||||
Duration: duration,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Embed creates embeddings
|
||||
func (o *OllamaAdapter) Embed(ctx context.Context, req *EmbedRequest) (*EmbedResponse, error) {
|
||||
model := req.Model
|
||||
if model == "" {
|
||||
model = "nomic-embed-text" // Default embedding model
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
|
||||
var embeddings [][]float64
|
||||
|
||||
for _, input := range req.Input {
|
||||
ollamaReq := map[string]any{
|
||||
"model": model,
|
||||
"prompt": input,
|
||||
}
|
||||
|
||||
body, err := json.Marshal(ollamaReq)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
httpReq, err := http.NewRequestWithContext(ctx, "POST", o.baseURL+"/api/embeddings", bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := o.httpClient.Do(httpReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("ollama embedding request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||
return nil, fmt.Errorf("ollama embedding error: %s", string(bodyBytes))
|
||||
}
|
||||
|
||||
var result struct {
|
||||
Embedding []float64 `json:"embedding"`
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode embedding response: %w", err)
|
||||
}
|
||||
|
||||
embeddings = append(embeddings, result.Embedding)
|
||||
}
|
||||
|
||||
duration := time.Since(start)
|
||||
|
||||
return &EmbedResponse{
|
||||
ID: uuid.New().String(),
|
||||
Model: model,
|
||||
Provider: ProviderOllama,
|
||||
Embeddings: embeddings,
|
||||
Usage: UsageStats{
|
||||
TotalTokens: len(req.Input) * 256, // Approximate
|
||||
},
|
||||
Duration: duration,
|
||||
}, nil
|
||||
}
|
||||
276
ai-compliance-sdk/internal/llm/pii_detector.go
Normal file
276
ai-compliance-sdk/internal/llm/pii_detector.go
Normal file
@@ -0,0 +1,276 @@
|
||||
package llm
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/breakpilot/ai-compliance-sdk/internal/rbac"
|
||||
)
|
||||
|
||||
// PIIType represents a type of personally identifiable information
// recognized by the detector. Values double as the string labels
// reported in PIIFinding.Type.
type PIIType string

const (
	PIITypeEmail       PIIType = "email"
	PIITypePhone       PIIType = "phone"
	PIITypeIPv4        PIIType = "ip_v4"
	PIITypeIPv6        PIIType = "ip_v6"
	PIITypeIBAN        PIIType = "iban"
	PIITypeUUID        PIIType = "uuid"
	PIITypeName        PIIType = "name"  // names with a German/English honorific prefix
	PIITypeSocialSec   PIIType = "social_security"
	PIITypeCreditCard  PIIType = "credit_card"
	PIITypeDateOfBirth PIIType = "date_of_birth"
	PIITypeSalary      PIIType = "salary"
	PIITypeAddress     PIIType = "address" // German-style street + house number
)
|
||||
|
||||
// PIIPattern defines a regular-expression rule for identifying one kind
// of PII and the placeholder used when redacting matches.
type PIIPattern struct {
	Type        PIIType
	Pattern     *regexp.Regexp
	Replacement string
	Level       rbac.PIIRedactionLevel // Minimum level at which this is redacted
}
|
||||
|
||||
// PIIFinding represents a single PII match located in scanned text.
type PIIFinding struct {
	Type  string `json:"type"`  // PIIType label of the matched pattern
	Match string `json:"match"` // the matched text itself (handle with care — this IS the PII)
	Start int    `json:"start"` // byte offset of the match start
	End   int    `json:"end"`   // byte offset one past the match end
}
|
||||
|
||||
// PIIDetector detects and redacts personally identifiable information.
// It is a stateless wrapper around a fixed pattern list and is safe for
// concurrent use as long as the pattern slice is not mutated after
// construction.
type PIIDetector struct {
	// patterns are evaluated in order for both detection and redaction.
	patterns []*PIIPattern
}
|
||||
|
||||
// Pre-compiled patterns for common PII types. Compiled once at package
// init; MustCompile panics on a bad pattern, which is acceptable here
// because the patterns are constants.
var (
	emailPattern = regexp.MustCompile(`\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b`)
	ipv4Pattern  = regexp.MustCompile(`\b(?:(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\b`)
	// NOTE(review): matches only the fully-expanded 8-group IPv6 form;
	// compressed forms ("::1") are not detected — confirm acceptable.
	ipv6Pattern = regexp.MustCompile(`\b(?:[0-9a-fA-F]{1,4}:){7}[0-9a-fA-F]{1,4}\b`)
	// German-centric phone matching (+49/0049 prefixes) plus a generic
	// international form; presumably tuned for a DE user base — verify.
	phonePattern = regexp.MustCompile(`(?:\+49|0049)[\s.-]?\d{2,4}[\s.-]?\d{3,8}|\b0\d{2,4}[\s.-]?\d{3,8}\b|\b\+\d{1,3}[\s.-]?\d{2,4}[\s.-]?\d{3,8}\b`)
	ibanPattern  = regexp.MustCompile(`(?i)\b[A-Z]{2}\d{2}[\s]?(?:\d{4}[\s]?){3,5}\d{1,4}\b`)
	uuidPattern  = regexp.MustCompile(`(?i)\b[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}\b`)
	// Matches honorific + capitalized name(s), German and English titles.
	namePattern = regexp.MustCompile(`\b(?:Herr|Frau|Hr\.|Fr\.|Mr\.|Mrs\.|Ms\.)\s+[A-ZÄÖÜ][a-zäöüß]+(?:\s+[A-ZÄÖÜ][a-zäöüß]+)?\b`)
	// 16 digits in groups of 4; no Luhn check, so plain digit runs can
	// false-positive.
	creditCardPattern = regexp.MustCompile(`\b(?:\d{4}[\s-]?){3}\d{4}\b`)
	// DD.MM.YYYY (German date format), years 1900–2099 only.
	dobPattern    = regexp.MustCompile(`\b(?:0[1-9]|[12][0-9]|3[01])\.(?:0[1-9]|1[012])\.(?:19|20)\d{2}\b`)
	salaryPattern = regexp.MustCompile(`(?i)(?:gehalt|salary|lohn|vergütung|einkommen)[:\s]+(?:€|EUR|USD|\$)?\s*[\d.,]+(?:\s*(?:€|EUR|USD|\$))?`)
	// German street designators followed by a house number.
	addressPattern = regexp.MustCompile(`(?i)\b(?:str\.|straße|strasse|weg|platz|allee)\s+\d+[a-z]?\b`)
)
|
||||
|
||||
// NewPIIDetector creates a new PII detector with default patterns
|
||||
func NewPIIDetector() *PIIDetector {
|
||||
return &PIIDetector{
|
||||
patterns: DefaultPIIPatterns(),
|
||||
}
|
||||
}
|
||||
|
||||
// NewPIIDetectorWithPatterns creates a new PII detector with custom patterns
|
||||
func NewPIIDetectorWithPatterns(patterns []*PIIPattern) *PIIDetector {
|
||||
return &PIIDetector{
|
||||
patterns: patterns,
|
||||
}
|
||||
}
|
||||
|
||||
// DefaultPIIPatterns returns the default set of PII patterns
|
||||
func DefaultPIIPatterns() []*PIIPattern {
|
||||
return []*PIIPattern{
|
||||
{Type: PIITypeEmail, Pattern: emailPattern, Replacement: "[EMAIL_REDACTED]", Level: rbac.PIIRedactionMinimal},
|
||||
{Type: PIITypeIPv4, Pattern: ipv4Pattern, Replacement: "[IP_REDACTED]", Level: rbac.PIIRedactionMinimal},
|
||||
{Type: PIITypeIPv6, Pattern: ipv6Pattern, Replacement: "[IP_REDACTED]", Level: rbac.PIIRedactionMinimal},
|
||||
{Type: PIITypePhone, Pattern: phonePattern, Replacement: "[PHONE_REDACTED]", Level: rbac.PIIRedactionMinimal},
|
||||
}
|
||||
}
|
||||
|
||||
// AllPIIPatterns returns all available PII patterns
|
||||
func AllPIIPatterns() []*PIIPattern {
|
||||
return []*PIIPattern{
|
||||
{Type: PIITypeEmail, Pattern: emailPattern, Replacement: "[EMAIL_REDACTED]", Level: rbac.PIIRedactionMinimal},
|
||||
{Type: PIITypeIPv4, Pattern: ipv4Pattern, Replacement: "[IP_REDACTED]", Level: rbac.PIIRedactionMinimal},
|
||||
{Type: PIITypeIPv6, Pattern: ipv6Pattern, Replacement: "[IP_REDACTED]", Level: rbac.PIIRedactionMinimal},
|
||||
{Type: PIITypePhone, Pattern: phonePattern, Replacement: "[PHONE_REDACTED]", Level: rbac.PIIRedactionMinimal},
|
||||
{Type: PIITypeIBAN, Pattern: ibanPattern, Replacement: "[IBAN_REDACTED]", Level: rbac.PIIRedactionModerate},
|
||||
{Type: PIITypeUUID, Pattern: uuidPattern, Replacement: "[UUID_REDACTED]", Level: rbac.PIIRedactionStrict},
|
||||
{Type: PIITypeName, Pattern: namePattern, Replacement: "[NAME_REDACTED]", Level: rbac.PIIRedactionModerate},
|
||||
{Type: PIITypeCreditCard, Pattern: creditCardPattern, Replacement: "[CARD_REDACTED]", Level: rbac.PIIRedactionMinimal},
|
||||
{Type: PIITypeDateOfBirth, Pattern: dobPattern, Replacement: "[DOB_REDACTED]", Level: rbac.PIIRedactionModerate},
|
||||
{Type: PIITypeSalary, Pattern: salaryPattern, Replacement: "[SALARY_REDACTED]", Level: rbac.PIIRedactionStrict},
|
||||
{Type: PIITypeAddress, Pattern: addressPattern, Replacement: "[ADDRESS_REDACTED]", Level: rbac.PIIRedactionModerate},
|
||||
}
|
||||
}
|
||||
|
||||
// FindPII finds all PII in the text
|
||||
func (d *PIIDetector) FindPII(text string) []PIIFinding {
|
||||
if text == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
var findings []PIIFinding
|
||||
for _, pattern := range d.patterns {
|
||||
matches := pattern.Pattern.FindAllStringIndex(text, -1)
|
||||
for _, match := range matches {
|
||||
findings = append(findings, PIIFinding{
|
||||
Type: string(pattern.Type),
|
||||
Match: text[match[0]:match[1]],
|
||||
Start: match[0],
|
||||
End: match[1],
|
||||
})
|
||||
}
|
||||
}
|
||||
return findings
|
||||
}
|
||||
|
||||
// ContainsPII checks if the text contains any PII
|
||||
func (d *PIIDetector) ContainsPII(text string) bool {
|
||||
if text == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, pattern := range d.patterns {
|
||||
if pattern.Pattern.MatchString(text) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Redact removes PII from the given text based on redaction level
|
||||
func (d *PIIDetector) Redact(text string, level rbac.PIIRedactionLevel) string {
|
||||
if text == "" || level == rbac.PIIRedactionNone {
|
||||
return text
|
||||
}
|
||||
|
||||
result := text
|
||||
for _, pattern := range d.patterns {
|
||||
if d.shouldRedactAtLevel(pattern.Level, level) {
|
||||
result = pattern.Pattern.ReplaceAllString(result, pattern.Replacement)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// shouldRedactAtLevel determines if a pattern should be applied at the given level
|
||||
func (d *PIIDetector) shouldRedactAtLevel(patternLevel, requestedLevel rbac.PIIRedactionLevel) bool {
|
||||
levelOrder := map[rbac.PIIRedactionLevel]int{
|
||||
rbac.PIIRedactionNone: 0,
|
||||
rbac.PIIRedactionMinimal: 1,
|
||||
rbac.PIIRedactionModerate: 2,
|
||||
rbac.PIIRedactionStrict: 3,
|
||||
}
|
||||
|
||||
return levelOrder[requestedLevel] >= levelOrder[patternLevel]
|
||||
}
|
||||
|
||||
// RedactMap redacts PII from all string values in a map
|
||||
func (d *PIIDetector) RedactMap(data map[string]any, level rbac.PIIRedactionLevel) map[string]any {
|
||||
result := make(map[string]any)
|
||||
for key, value := range data {
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
result[key] = d.Redact(v, level)
|
||||
case map[string]any:
|
||||
result[key] = d.RedactMap(v, level)
|
||||
case []any:
|
||||
result[key] = d.redactSlice(v, level)
|
||||
default:
|
||||
result[key] = v
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (d *PIIDetector) redactSlice(data []any, level rbac.PIIRedactionLevel) []any {
|
||||
result := make([]any, len(data))
|
||||
for i, value := range data {
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
result[i] = d.Redact(v, level)
|
||||
case map[string]any:
|
||||
result[i] = d.RedactMap(v, level)
|
||||
case []any:
|
||||
result[i] = d.redactSlice(v, level)
|
||||
default:
|
||||
result[i] = v
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// SafeLogString creates a safe-to-log version of a string
|
||||
func (d *PIIDetector) SafeLogString(text string) string {
|
||||
return d.Redact(text, rbac.PIIRedactionStrict)
|
||||
}
|
||||
|
||||
// DetectDataCategories attempts to detect data categories in text
|
||||
func (d *PIIDetector) DetectDataCategories(text string) []string {
|
||||
if text == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
var categories []string
|
||||
textLower := strings.ToLower(text)
|
||||
|
||||
// Salary detection
|
||||
if salaryPattern.MatchString(text) || strings.Contains(textLower, "gehalt") || strings.Contains(textLower, "salary") {
|
||||
categories = append(categories, "salary")
|
||||
}
|
||||
|
||||
// Health detection
|
||||
healthKeywords := []string{"diagnose", "krankheit", "medikament", "therapie", "arzt", "krankenhaus",
|
||||
"health", "medical", "diagnosis", "treatment", "hospital"}
|
||||
for _, kw := range healthKeywords {
|
||||
if strings.Contains(textLower, kw) {
|
||||
categories = append(categories, "health")
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Financial detection
|
||||
if ibanPattern.MatchString(text) || creditCardPattern.MatchString(text) ||
|
||||
strings.Contains(textLower, "konto") || strings.Contains(textLower, "bank") {
|
||||
categories = append(categories, "financial")
|
||||
}
|
||||
|
||||
// Personal detection (names, addresses, DOB)
|
||||
if namePattern.MatchString(text) || addressPattern.MatchString(text) || dobPattern.MatchString(text) {
|
||||
categories = append(categories, "personal")
|
||||
}
|
||||
|
||||
// HR detection
|
||||
hrKeywords := []string{"mitarbeiter", "employee", "kündigung", "termination", "beförderung", "promotion",
|
||||
"leistungsbeurteilung", "performance review", "personalakte"}
|
||||
for _, kw := range hrKeywords {
|
||||
if strings.Contains(textLower, kw) {
|
||||
categories = append(categories, "hr")
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return categories
|
||||
}
|
||||
|
||||
// Global default detector, configured with the full pattern set (all
// redaction levels). Backs the package-level convenience functions
// RedactPII, ContainsPIIDefault, FindPIIDefault and
// DetectDataCategoriesDefault.
var defaultDetector = NewPIIDetectorWithPatterns(AllPIIPatterns())
|
||||
|
||||
// RedactPII is a convenience function using the default detector
|
||||
func RedactPII(text string, level rbac.PIIRedactionLevel) string {
|
||||
return defaultDetector.Redact(text, level)
|
||||
}
|
||||
|
||||
// ContainsPIIDefault checks if text contains PII using default patterns
|
||||
func ContainsPIIDefault(text string) bool {
|
||||
return defaultDetector.ContainsPII(text)
|
||||
}
|
||||
|
||||
// FindPIIDefault finds PII using default patterns
|
||||
func FindPIIDefault(text string) []PIIFinding {
|
||||
return defaultDetector.FindPII(text)
|
||||
}
|
||||
|
||||
// DetectDataCategoriesDefault detects data categories using default detector
|
||||
func DetectDataCategoriesDefault(text string) []string {
|
||||
return defaultDetector.DetectDataCategories(text)
|
||||
}
|
||||
239
ai-compliance-sdk/internal/llm/provider.go
Normal file
239
ai-compliance-sdk/internal/llm/provider.go
Normal file
@@ -0,0 +1,239 @@
|
||||
package llm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Provider names. These are the canonical keys under which providers
// register in a ProviderRegistry and the values reported by
// Provider.Name and in response Provider fields.
const (
	ProviderOllama    = "ollama"
	ProviderAnthropic = "anthropic"
	ProviderOpenAI    = "openai"
)
|
||||
|
||||
// Sentinel errors returned by providers and the registry; compare with
// errors.Is.
var (
	// ErrProviderUnavailable is returned by ProviderRegistry.GetAvailable
	// when neither the primary nor the fallback provider responds.
	ErrProviderUnavailable = errors.New("LLM provider unavailable")
	// ErrModelNotFound indicates the requested model is unknown to the provider.
	ErrModelNotFound = errors.New("model not found")
	// ErrContextTooLong indicates the prompt exceeds the model's context window.
	ErrContextTooLong = errors.New("context too long for model")
	// ErrRateLimited indicates the provider rejected the request due to rate limits.
	ErrRateLimited = errors.New("rate limited")
	// ErrInvalidRequest indicates a malformed or unsupported request.
	ErrInvalidRequest = errors.New("invalid request")
)
|
||||
|
||||
// Provider defines the interface for LLM providers. Implementations
// exist per backend (see ProviderOllama, ProviderAnthropic,
// ProviderOpenAI) and are multiplexed by ProviderRegistry.
// NOTE(review): implementations are presumably expected to be safe for
// concurrent use, since the registry shares them — confirm.
type Provider interface {
	// Name returns the provider name (one of the Provider* constants).
	Name() string

	// IsAvailable checks if the provider is currently available.
	// It must not block longer than the context allows.
	IsAvailable(ctx context.Context) bool

	// ListModels returns available models.
	ListModels(ctx context.Context) ([]Model, error)

	// Complete performs text completion.
	Complete(ctx context.Context, req *CompletionRequest) (*CompletionResponse, error)

	// Chat performs chat completion.
	Chat(ctx context.Context, req *ChatRequest) (*ChatResponse, error)

	// Embed creates embeddings for text.
	Embed(ctx context.Context, req *EmbedRequest) (*EmbedResponse, error)
}
|
||||
|
||||
// Model represents an available LLM model as reported by a provider.
type Model struct {
	ID          string `json:"id"`
	Name        string `json:"name"`
	// Provider is the owning provider's name (Provider* constant).
	Provider    string `json:"provider"`
	Description string `json:"description,omitempty"`
	// ContextSize is the model's context window; presumably in tokens —
	// confirm against the provider implementations.
	ContextSize int `json:"context_size"`
	// Parameters holds provider-specific model metadata.
	Parameters map[string]any `json:"parameters,omitempty"`
	Capabilities []string `json:"capabilities,omitempty"` // "chat", "completion", "embedding"
}
|
||||
|
||||
// Message represents a single chat message in a conversation.
type Message struct {
	Role    string `json:"role"` // "system", "user", "assistant"
	Content string `json:"content"`
}
|
||||
|
||||
// CompletionRequest represents a text completion request.
// Zero-valued optional fields (MaxTokens, Temperature, TopP) are
// omitted from JSON and left to provider defaults.
type CompletionRequest struct {
	Model       string  `json:"model"`
	Prompt      string  `json:"prompt"`
	MaxTokens   int     `json:"max_tokens,omitempty"`
	Temperature float64 `json:"temperature,omitempty"`
	TopP        float64 `json:"top_p,omitempty"`
	// Stop lists sequences at which generation should halt.
	Stop []string `json:"stop,omitempty"`
	// Options carries provider-specific parameters passed through verbatim.
	Options map[string]any `json:"options,omitempty"`
}
|
||||
|
||||
// CompletionResponse represents a text completion response.
type CompletionResponse struct {
	// ID uniquely identifies this response (a generated UUID).
	ID       string `json:"id"`
	Model    string `json:"model"`
	// Provider names the provider that actually served the request
	// (may be the fallback, not the one originally targeted).
	Provider     string `json:"provider"`
	Text         string `json:"text"`
	FinishReason string `json:"finish_reason,omitempty"`
	Usage        UsageStats `json:"usage"`
	// Duration is the wall-clock time the provider call took.
	Duration time.Duration `json:"duration"`
}
|
||||
|
||||
// ChatRequest represents a chat completion request.
// Zero-valued optional fields are omitted from JSON and left to
// provider defaults.
type ChatRequest struct {
	Model string `json:"model"`
	// Messages is the conversation history, in order.
	Messages    []Message `json:"messages"`
	MaxTokens   int       `json:"max_tokens,omitempty"`
	Temperature float64   `json:"temperature,omitempty"`
	TopP        float64   `json:"top_p,omitempty"`
	// Stop lists sequences at which generation should halt.
	Stop []string `json:"stop,omitempty"`
	// Options carries provider-specific parameters passed through verbatim.
	Options map[string]any `json:"options,omitempty"`
}
|
||||
|
||||
// ChatResponse represents a chat completion response.
type ChatResponse struct {
	// ID uniquely identifies this response (a generated UUID).
	ID       string `json:"id"`
	Model    string `json:"model"`
	// Provider names the provider that actually served the request.
	Provider string `json:"provider"`
	// Message is the assistant's reply.
	Message      Message    `json:"message"`
	FinishReason string     `json:"finish_reason,omitempty"`
	Usage        UsageStats `json:"usage"`
	// Duration is the wall-clock time the provider call took.
	Duration time.Duration `json:"duration"`
}
|
||||
|
||||
// EmbedRequest represents an embedding request.
// Input holds one or more texts; providers return one embedding vector
// per input element, in the same order.
type EmbedRequest struct {
	Model string   `json:"model"`
	Input []string `json:"input"`
}
|
||||
|
||||
// EmbedResponse represents an embedding response.
type EmbedResponse struct {
	// ID uniquely identifies this response (a generated UUID).
	ID       string `json:"id"`
	Model    string `json:"model"`
	Provider string `json:"provider"`
	// Embeddings holds one vector per request input, in input order.
	Embeddings [][]float64 `json:"embeddings"`
	// Usage may be approximate; the Ollama path estimates token counts.
	Usage UsageStats `json:"usage"`
	// Duration is the wall-clock time the provider call took.
	Duration time.Duration `json:"duration"`
}
|
||||
|
||||
// UsageStats represents token usage statistics for a single request.
// TOTALTokens is not necessarily PromptTokens+CompletionTokens for all
// providers; some paths only fill TotalTokens (e.g. embeddings).
type UsageStats struct {
	PromptTokens     int `json:"prompt_tokens"`
	CompletionTokens int `json:"completion_tokens"`
	TotalTokens      int `json:"total_tokens"`
}
|
||||
|
||||
// ProviderRegistry manages multiple LLM providers with a configured
// primary and optional fallback, and routes requests with automatic
// failover.
// NOTE(review): the providers map is not mutex-guarded; Register must
// complete before any concurrent use — confirm this matches the
// initialization pattern of the callers.
type ProviderRegistry struct {
	// providers maps provider name (Provider* constant) to implementation.
	providers map[string]Provider
	// primaryProvider is the name tried first; fallbackProvider is tried
	// when the primary is unavailable or errors ("" disables fallback).
	primaryProvider  string
	fallbackProvider string
}
|
||||
|
||||
// NewProviderRegistry creates a new provider registry
|
||||
func NewProviderRegistry(primary, fallback string) *ProviderRegistry {
|
||||
return &ProviderRegistry{
|
||||
providers: make(map[string]Provider),
|
||||
primaryProvider: primary,
|
||||
fallbackProvider: fallback,
|
||||
}
|
||||
}
|
||||
|
||||
// Register registers a provider
|
||||
func (r *ProviderRegistry) Register(provider Provider) {
|
||||
r.providers[provider.Name()] = provider
|
||||
}
|
||||
|
||||
// GetProvider returns a provider by name
|
||||
func (r *ProviderRegistry) GetProvider(name string) (Provider, bool) {
|
||||
p, ok := r.providers[name]
|
||||
return p, ok
|
||||
}
|
||||
|
||||
// GetPrimary returns the primary provider
|
||||
func (r *ProviderRegistry) GetPrimary() (Provider, bool) {
|
||||
return r.GetProvider(r.primaryProvider)
|
||||
}
|
||||
|
||||
// GetFallback returns the fallback provider
|
||||
func (r *ProviderRegistry) GetFallback() (Provider, bool) {
|
||||
return r.GetProvider(r.fallbackProvider)
|
||||
}
|
||||
|
||||
// GetAvailable returns the first available provider (primary, then fallback)
|
||||
func (r *ProviderRegistry) GetAvailable(ctx context.Context) (Provider, error) {
|
||||
if p, ok := r.GetPrimary(); ok && p.IsAvailable(ctx) {
|
||||
return p, nil
|
||||
}
|
||||
|
||||
if p, ok := r.GetFallback(); ok && p.IsAvailable(ctx) {
|
||||
return p, nil
|
||||
}
|
||||
|
||||
return nil, ErrProviderUnavailable
|
||||
}
|
||||
|
||||
// ListAllModels returns models from all available providers
|
||||
func (r *ProviderRegistry) ListAllModels(ctx context.Context) ([]Model, error) {
|
||||
var allModels []Model
|
||||
|
||||
for _, p := range r.providers {
|
||||
if p.IsAvailable(ctx) {
|
||||
models, err := p.ListModels(ctx)
|
||||
if err == nil {
|
||||
allModels = append(allModels, models...)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return allModels, nil
|
||||
}
|
||||
|
||||
// Complete performs completion with automatic fallback
|
||||
func (r *ProviderRegistry) Complete(ctx context.Context, req *CompletionRequest) (*CompletionResponse, error) {
|
||||
provider, err := r.GetAvailable(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp, err := provider.Complete(ctx, req)
|
||||
if err != nil && r.fallbackProvider != "" {
|
||||
// Try fallback
|
||||
if fallback, ok := r.GetFallback(); ok && fallback.Name() != provider.Name() {
|
||||
return fallback.Complete(ctx, req)
|
||||
}
|
||||
}
|
||||
|
||||
return resp, err
|
||||
}
|
||||
|
||||
// Chat performs chat completion with automatic fallback
|
||||
func (r *ProviderRegistry) Chat(ctx context.Context, req *ChatRequest) (*ChatResponse, error) {
|
||||
provider, err := r.GetAvailable(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp, err := provider.Chat(ctx, req)
|
||||
if err != nil && r.fallbackProvider != "" {
|
||||
// Try fallback
|
||||
if fallback, ok := r.GetFallback(); ok && fallback.Name() != provider.Name() {
|
||||
return fallback.Chat(ctx, req)
|
||||
}
|
||||
}
|
||||
|
||||
return resp, err
|
||||
}
|
||||
|
||||
// Embed creates embeddings with automatic fallback
|
||||
func (r *ProviderRegistry) Embed(ctx context.Context, req *EmbedRequest) (*EmbedResponse, error) {
|
||||
provider, err := r.GetAvailable(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return provider.Embed(ctx, req)
|
||||
}
|
||||
Reference in New Issue
Block a user