package llm

import (
	"context"
	"crypto/sha256"
	"encoding/hex"
	"fmt"
	"time"

	"github.com/breakpilot/ai-compliance-sdk/internal/rbac"
	"github.com/google/uuid"
)

// AccessGate controls access to LLM operations based on RBAC policies.
//
// It evaluates each request against the policy engine, scans prompts for PII
// (redacting when the access result requires it), and dispatches accepted
// requests to a provider obtained from the registry.
type AccessGate struct {
	policyEngine *rbac.PolicyEngine
	piiDetector  *PIIDetector
	registry     *ProviderRegistry
}

// NewAccessGate creates a new access gate from its three collaborators.
// The arguments are stored as given; none of them is validated here.
func NewAccessGate(policyEngine *rbac.PolicyEngine, piiDetector *PIIDetector, registry *ProviderRegistry) *AccessGate {
	return &AccessGate{
		policyEngine: policyEngine,
		piiDetector:  piiDetector,
		registry:     registry,
	}
}

// GatedRequest represents a request that has passed through the access gate.
//
// OriginalRequest holds the *ChatRequest or *CompletionRequest handed to the
// gate. PIITypes lists each detected PII type once, in first-seen order.
// PromptHash is the hex-encoded SHA-256 of the (possibly redacted) prompt.
type GatedRequest struct {
	OriginalRequest any
	UserID          uuid.UUID
	TenantID        uuid.UUID
	NamespaceID     *uuid.UUID
	Model           string
	PIIDetected     bool
	PIITypes        []string
	PromptRedacted  bool
	PromptHash      string
	Policy          *rbac.LLMPolicy
	AccessResult    *rbac.LLMAccessResult
}

// GatedChatRequest represents a chat request that has passed through the gate.
// Messages holds the processed (possibly redacted) copy of the chat messages.
type GatedChatRequest struct {
	*GatedRequest
	Messages []Message
}

// GatedCompletionRequest represents a completion request that has passed
// through the gate. Prompt holds the processed (possibly redacted) prompt.
type GatedCompletionRequest struct {
	*GatedRequest
	Prompt string
}

// evaluateAccess asks the policy engine for a decision and converts both an
// evaluation failure and a denial into an error.
func (g *AccessGate) evaluateAccess(ctx context.Context, accessReq *rbac.LLMAccessRequest) (*rbac.LLMAccessResult, error) {
	accessResult, err := g.policyEngine.EvaluateLLMAccess(ctx, accessReq)
	if err != nil {
		return nil, fmt.Errorf("access evaluation failed: %w", err)
	}
	if !accessResult.Allowed {
		return nil, fmt.Errorf("access denied: %s", accessResult.Reason)
	}
	return accessResult, nil
}

// ProcessChatRequest validates and processes a chat request.
//
// Steps:
//  1. Evaluate LLM access for the "chat" operation.
//  2. Scan every "user" and "system" message for PII; messages with other
//     roles are not scanned. When the access result requires redaction the
//     content is redacted in a copy of the messages, so req.Messages itself
//     is left untouched.
//  3. Hash the processed messages for the audit trail.
//  4. Cap req.MaxTokens at the policy limit. This modifies req in place.
//
// It returns an error when req is nil, when the access evaluation fails, or
// when access is denied.
func (g *AccessGate) ProcessChatRequest(
	ctx context.Context,
	userID, tenantID uuid.UUID,
	namespaceID *uuid.UUID,
	req *ChatRequest,
	dataCategories []string,
) (*GatedChatRequest, error) {
	if req == nil {
		return nil, fmt.Errorf("chat request is nil")
	}

	// 1. Evaluate LLM access
	accessResult, err := g.evaluateAccess(ctx, &rbac.LLMAccessRequest{
		UserID:          userID,
		TenantID:        tenantID,
		NamespaceID:     namespaceID,
		Model:           req.Model,
		DataCategories:  dataCategories,
		TokensRequested: req.MaxTokens,
		Operation:       "chat",
	})
	if err != nil {
		return nil, err
	}

	// 2. Process messages for PII
	processedMessages := make([]Message, len(req.Messages))
	copy(processedMessages, req.Messages)

	var allPIITypes []string
	piiDetected := false
	redacted := false

	for i, msg := range processedMessages {
		if msg.Role != "user" && msg.Role != "system" {
			continue
		}

		findings := g.piiDetector.FindPII(msg.Content)
		if len(findings) == 0 {
			continue
		}

		piiDetected = true
		for _, f := range findings {
			allPIITypes = append(allPIITypes, f.Type)
		}

		// Redact if required by policy
		if accessResult.RequirePIIRedaction {
			processedMessages[i].Content = g.piiDetector.Redact(msg.Content, accessResult.PIIRedactionLevel)
			redacted = true
		}
	}

	// 3. Generate prompt hash for audit
	promptHash := g.hashMessages(processedMessages)

	// 4. Apply token limits from policy
	if accessResult.MaxTokens > 0 && req.MaxTokens > accessResult.MaxTokens {
		req.MaxTokens = accessResult.MaxTokens
	}

	return &GatedChatRequest{
		GatedRequest: &GatedRequest{
			OriginalRequest: req,
			UserID:          userID,
			TenantID:        tenantID,
			NamespaceID:     namespaceID,
			Model:           req.Model,
			PIIDetected:     piiDetected,
			PIITypes:        uniqueStrings(allPIITypes),
			PromptRedacted:  redacted,
			PromptHash:      promptHash,
			Policy:          accessResult.Policy,
			AccessResult:    accessResult,
		},
		Messages: processedMessages,
	}, nil
}

// ProcessCompletionRequest validates and processes a completion request.
//
// Steps:
//  1. Evaluate LLM access for the "completion" operation.
//  2. Scan the prompt for PII and, when the access result requires it,
//     redact the prompt. req.Prompt itself is left untouched.
//  3. Hash the processed prompt for the audit trail.
//  4. Cap req.MaxTokens at the policy limit. This modifies req in place.
//
// It returns an error when req is nil, when the access evaluation fails, or
// when access is denied.
func (g *AccessGate) ProcessCompletionRequest(
	ctx context.Context,
	userID, tenantID uuid.UUID,
	namespaceID *uuid.UUID,
	req *CompletionRequest,
	dataCategories []string,
) (*GatedCompletionRequest, error) {
	if req == nil {
		return nil, fmt.Errorf("completion request is nil")
	}

	// 1. Evaluate LLM access
	accessResult, err := g.evaluateAccess(ctx, &rbac.LLMAccessRequest{
		UserID:          userID,
		TenantID:        tenantID,
		NamespaceID:     namespaceID,
		Model:           req.Model,
		DataCategories:  dataCategories,
		TokensRequested: req.MaxTokens,
		Operation:       "completion",
	})
	if err != nil {
		return nil, err
	}

	// 2. Process prompt for PII
	processedPrompt := req.Prompt
	var allPIITypes []string
	piiDetected := false
	redacted := false

	if findings := g.piiDetector.FindPII(req.Prompt); len(findings) > 0 {
		piiDetected = true
		for _, f := range findings {
			allPIITypes = append(allPIITypes, f.Type)
		}

		// Redact if required by policy
		if accessResult.RequirePIIRedaction {
			processedPrompt = g.piiDetector.Redact(req.Prompt, accessResult.PIIRedactionLevel)
			redacted = true
		}
	}

	// 3. Generate prompt hash for audit
	promptHash := g.hashPrompt(processedPrompt)

	// 4. Apply token limits from policy
	if accessResult.MaxTokens > 0 && req.MaxTokens > accessResult.MaxTokens {
		req.MaxTokens = accessResult.MaxTokens
	}

	return &GatedCompletionRequest{
		GatedRequest: &GatedRequest{
			OriginalRequest: req,
			UserID:          userID,
			TenantID:        tenantID,
			NamespaceID:     namespaceID,
			Model:           req.Model,
			PIIDetected:     piiDetected,
			PIITypes:        uniqueStrings(allPIITypes),
			PromptRedacted:  redacted,
			PromptHash:      promptHash,
			Policy:          accessResult.Policy,
			AccessResult:    accessResult,
		},
		Prompt: processedPrompt,
	}, nil
}

// ExecuteChat executes a gated chat request on an available provider.
//
// The outgoing request uses the gated model and messages. MaxTokens starts at
// the policy limit from the access result; when the original request asked
// for a positive amount that is smaller than that limit (or the policy sets
// no limit), the requested amount is used instead, so the policy acts as a
// ceiling rather than as the value sent to the provider.
//
// Temperature defaults to 0.7 and, like TopP, Stop and Options, is taken from
// the original request whenever OriginalRequest is a non-nil *ChatRequest.
//
// It returns an error when gatedReq (or its embedded GatedRequest) is nil,
// when no provider is available, or when the provider call fails.
func (g *AccessGate) ExecuteChat(ctx context.Context, gatedReq *GatedChatRequest) (*ChatResponse, error) {
	if gatedReq == nil || gatedReq.GatedRequest == nil {
		return nil, fmt.Errorf("gated chat request is nil")
	}

	provider, err := g.registry.GetAvailable(ctx)
	if err != nil {
		return nil, err
	}

	req := &ChatRequest{
		Model:       gatedReq.Model,
		Messages:    gatedReq.Messages,
		Temperature: 0.7,
	}
	if gatedReq.AccessResult != nil {
		req.MaxTokens = gatedReq.AccessResult.MaxTokens
	}

	// The type assertion also succeeds for a nil *ChatRequest stored in the
	// interface, hence the explicit nil check.
	if orig, ok := gatedReq.OriginalRequest.(*ChatRequest); ok && orig != nil {
		if orig.MaxTokens > 0 && (req.MaxTokens <= 0 || orig.MaxTokens < req.MaxTokens) {
			req.MaxTokens = orig.MaxTokens
		}
		req.Temperature = orig.Temperature
		req.TopP = orig.TopP
		req.Stop = orig.Stop
		req.Options = orig.Options
	}

	return provider.Chat(ctx, req)
}

// ExecuteCompletion executes a gated completion request on an available
// provider.
//
// The outgoing request uses the gated model and prompt. MaxTokens starts at
// the policy limit from the access result; when the original request asked
// for a positive amount that is smaller than that limit (or the policy sets
// no limit), the requested amount is used instead.
//
// Temperature, TopP, Stop and Options are taken from the original request
// whenever OriginalRequest is a non-nil *CompletionRequest.
//
// It returns an error when gatedReq (or its embedded GatedRequest) is nil,
// when no provider is available, or when the provider call fails.
func (g *AccessGate) ExecuteCompletion(ctx context.Context, gatedReq *GatedCompletionRequest) (*CompletionResponse, error) {
	if gatedReq == nil || gatedReq.GatedRequest == nil {
		return nil, fmt.Errorf("gated completion request is nil")
	}

	provider, err := g.registry.GetAvailable(ctx)
	if err != nil {
		return nil, err
	}

	req := &CompletionRequest{
		Model:  gatedReq.Model,
		Prompt: gatedReq.Prompt,
	}
	if gatedReq.AccessResult != nil {
		req.MaxTokens = gatedReq.AccessResult.MaxTokens
	}

	// The type assertion also succeeds for a nil *CompletionRequest stored in
	// the interface, hence the explicit nil check.
	if orig, ok := gatedReq.OriginalRequest.(*CompletionRequest); ok && orig != nil {
		if orig.MaxTokens > 0 && (req.MaxTokens <= 0 || orig.MaxTokens < req.MaxTokens) {
			req.MaxTokens = orig.MaxTokens
		}
		req.Temperature = orig.Temperature
		req.TopP = orig.TopP
		req.Stop = orig.Stop
		req.Options = orig.Options
	}

	return provider.Complete(ctx, req)
}

// hashMessages creates a hex-encoded SHA-256 hash of chat messages (for audit
// without storing PII).
//
// The role and content of each message are written to the hash in order with
// no separator between them, so the digest depends only on the concatenated
// bytes. The format is kept as is so previously recorded hashes stay
// comparable.
func (g *AccessGate) hashMessages(messages []Message) string {
	hasher := sha256.New()
	for _, msg := range messages {
		// hash.Hash documents that Write never returns an error.
		hasher.Write([]byte(msg.Role))
		hasher.Write([]byte(msg.Content))
	}
	return hex.EncodeToString(hasher.Sum(nil))
}

// hashPrompt creates a hex-encoded SHA-256 hash of a prompt.
func (g *AccessGate) hashPrompt(prompt string) string {
	sum := sha256.Sum256([]byte(prompt))
	return hex.EncodeToString(sum[:])
}

// uniqueStrings returns the distinct strings from slice, keeping the order in
// which each value first appears. It returns nil for an empty input.
func uniqueStrings(slice []string) []string {
	if len(slice) == 0 {
		return nil
	}

	seen := make(map[string]struct{}, len(slice))
	result := make([]string, 0, len(slice))
	for _, s := range slice {
		if _, dup := seen[s]; dup {
			continue
		}
		seen[s] = struct{}{}
		result = append(result, s)
	}
	return result
}

// AuditEntry represents an entry for the audit log.
//
// It carries the prompt hash and length rather than the prompt text.
// CreateAuditEntry leaves DataCategoriesAccessed and RequestMetadata unset.
type AuditEntry struct {
	ID                     uuid.UUID      `json:"id"`
	TenantID               uuid.UUID      `json:"tenant_id"`
	NamespaceID            *uuid.UUID     `json:"namespace_id,omitempty"`
	UserID                 uuid.UUID      `json:"user_id"`
	SessionID              string         `json:"session_id,omitempty"`
	Operation              string         `json:"operation"`
	ModelUsed              string         `json:"model_used"`
	Provider               string         `json:"provider"`
	PromptHash             string         `json:"prompt_hash"`
	PromptLength           int            `json:"prompt_length"`
	ResponseLength         int            `json:"response_length,omitempty"`
	TokensUsed             int            `json:"tokens_used"`
	DurationMS             int            `json:"duration_ms"`
	PIIDetected            bool           `json:"pii_detected"`
	PIITypesDetected       []string       `json:"pii_types_detected,omitempty"`
	PIIRedacted            bool           `json:"pii_redacted"`
	PolicyID               *uuid.UUID     `json:"policy_id,omitempty"`
	PolicyViolations       []string       `json:"policy_violations,omitempty"`
	DataCategoriesAccessed []string       `json:"data_categories_accessed,omitempty"`
	ErrorMessage           string         `json:"error_message,omitempty"`
	RequestMetadata        map[string]any `json:"request_metadata,omitempty"`
	CreatedAt              time.Time      `json:"created_at"`
}

// CreateAuditEntry creates an audit entry from a gated request and response.
//
// gatedReq must be non-nil. resp may be a *ChatResponse, a
// *CompletionResponse, or nil; response length, token usage and duration are
// filled in only for a non-nil response of one of those two types. A nil
// pointer of either type is treated like no response, which lets callers
// pass the result of a failed Execute call straight through together with
// err. When err is non-nil its message is recorded in ErrorMessage.
//
// The entry gets a freshly generated ID and a CreatedAt timestamp in UTC.
func (g *AccessGate) CreateAuditEntry(
	gatedReq *GatedRequest,
	operation string,
	provider string,
	resp any,
	err error,
	promptLength int,
	sessionID string,
) *AuditEntry {
	entry := &AuditEntry{
		ID:               uuid.New(),
		TenantID:         gatedReq.TenantID,
		NamespaceID:      gatedReq.NamespaceID,
		UserID:           gatedReq.UserID,
		SessionID:        sessionID,
		Operation:        operation,
		ModelUsed:        gatedReq.Model,
		Provider:         provider,
		PromptHash:       gatedReq.PromptHash,
		PromptLength:     promptLength,
		PIIDetected:      gatedReq.PIIDetected,
		PIITypesDetected: gatedReq.PIITypes,
		PIIRedacted:      gatedReq.PromptRedacted,
		CreatedAt:        time.Now().UTC(),
	}

	if gatedReq.Policy != nil {
		// Point at a copy so the entry does not alias the policy struct.
		policyID := gatedReq.Policy.ID
		entry.PolicyID = &policyID
	}

	if gatedReq.AccessResult != nil && len(gatedReq.AccessResult.BlockedCategories) > 0 {
		entry.PolicyViolations = gatedReq.AccessResult.BlockedCategories
	}

	if err != nil {
		entry.ErrorMessage = err.Error()
	}

	// Extract usage from response. A typed nil pointer stored in resp still
	// matches its case, so each case checks for nil before dereferencing.
	switch r := resp.(type) {
	case *ChatResponse:
		if r != nil {
			entry.ResponseLength = len(r.Message.Content)
			entry.TokensUsed = r.Usage.TotalTokens
			entry.DurationMS = int(r.Duration.Milliseconds())
		}
	case *CompletionResponse:
		if r != nil {
			entry.ResponseLength = len(r.Text)
			entry.TokensUsed = r.Usage.TotalTokens
			entry.DurationMS = int(r.Duration.Milliseconds())
		}
	}

	return entry
}