Initial commit: breakpilot-compliance - Compliance SDK Platform
Services: Admin-Compliance, Backend-Compliance, AI-Compliance-SDK, Consent-SDK, Developer-Portal, PCA-Platform, DSMS Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
367
ai-compliance-sdk/internal/llm/access_gate.go
Normal file
367
ai-compliance-sdk/internal/llm/access_gate.go
Normal file
@@ -0,0 +1,367 @@
|
||||
package llm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/breakpilot/ai-compliance-sdk/internal/rbac"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// AccessGate controls access to LLM operations based on RBAC policies
|
||||
type AccessGate struct {
|
||||
policyEngine *rbac.PolicyEngine
|
||||
piiDetector *PIIDetector
|
||||
registry *ProviderRegistry
|
||||
}
|
||||
|
||||
// NewAccessGate creates a new access gate
|
||||
func NewAccessGate(policyEngine *rbac.PolicyEngine, piiDetector *PIIDetector, registry *ProviderRegistry) *AccessGate {
|
||||
return &AccessGate{
|
||||
policyEngine: policyEngine,
|
||||
piiDetector: piiDetector,
|
||||
registry: registry,
|
||||
}
|
||||
}
|
||||
|
||||
// GatedRequest represents a request that has passed through the access gate
|
||||
type GatedRequest struct {
|
||||
OriginalRequest any
|
||||
UserID uuid.UUID
|
||||
TenantID uuid.UUID
|
||||
NamespaceID *uuid.UUID
|
||||
Model string
|
||||
PIIDetected bool
|
||||
PIITypes []string
|
||||
PromptRedacted bool
|
||||
PromptHash string
|
||||
Policy *rbac.LLMPolicy
|
||||
AccessResult *rbac.LLMAccessResult
|
||||
}
|
||||
|
||||
// GatedChatRequest represents a chat request that has passed through the gate
|
||||
type GatedChatRequest struct {
|
||||
*GatedRequest
|
||||
Messages []Message
|
||||
}
|
||||
|
||||
// GatedCompletionRequest represents a completion request that has passed through the gate
|
||||
type GatedCompletionRequest struct {
|
||||
*GatedRequest
|
||||
Prompt string
|
||||
}
|
||||
|
||||
// ProcessChatRequest validates and processes a chat request
|
||||
func (g *AccessGate) ProcessChatRequest(
|
||||
ctx context.Context,
|
||||
userID, tenantID uuid.UUID,
|
||||
namespaceID *uuid.UUID,
|
||||
req *ChatRequest,
|
||||
dataCategories []string,
|
||||
) (*GatedChatRequest, error) {
|
||||
// 1. Evaluate LLM access
|
||||
accessReq := &rbac.LLMAccessRequest{
|
||||
UserID: userID,
|
||||
TenantID: tenantID,
|
||||
NamespaceID: namespaceID,
|
||||
Model: req.Model,
|
||||
DataCategories: dataCategories,
|
||||
TokensRequested: req.MaxTokens,
|
||||
Operation: "chat",
|
||||
}
|
||||
|
||||
accessResult, err := g.policyEngine.EvaluateLLMAccess(ctx, accessReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("access evaluation failed: %w", err)
|
||||
}
|
||||
|
||||
if !accessResult.Allowed {
|
||||
return nil, fmt.Errorf("access denied: %s", accessResult.Reason)
|
||||
}
|
||||
|
||||
// 2. Process messages for PII
|
||||
processedMessages := make([]Message, len(req.Messages))
|
||||
copy(processedMessages, req.Messages)
|
||||
|
||||
var allPIITypes []string
|
||||
piiDetected := false
|
||||
redacted := false
|
||||
|
||||
for i, msg := range processedMessages {
|
||||
if msg.Role == "user" || msg.Role == "system" {
|
||||
// Check for PII
|
||||
findings := g.piiDetector.FindPII(msg.Content)
|
||||
if len(findings) > 0 {
|
||||
piiDetected = true
|
||||
for _, f := range findings {
|
||||
allPIITypes = append(allPIITypes, f.Type)
|
||||
}
|
||||
|
||||
// Redact if required by policy
|
||||
if accessResult.RequirePIIRedaction {
|
||||
processedMessages[i].Content = g.piiDetector.Redact(msg.Content, accessResult.PIIRedactionLevel)
|
||||
redacted = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 3. Generate prompt hash for audit
|
||||
promptHash := g.hashMessages(processedMessages)
|
||||
|
||||
// 4. Apply token limits from policy
|
||||
if accessResult.MaxTokens > 0 && req.MaxTokens > accessResult.MaxTokens {
|
||||
req.MaxTokens = accessResult.MaxTokens
|
||||
}
|
||||
|
||||
return &GatedChatRequest{
|
||||
GatedRequest: &GatedRequest{
|
||||
OriginalRequest: req,
|
||||
UserID: userID,
|
||||
TenantID: tenantID,
|
||||
NamespaceID: namespaceID,
|
||||
Model: req.Model,
|
||||
PIIDetected: piiDetected,
|
||||
PIITypes: uniqueStrings(allPIITypes),
|
||||
PromptRedacted: redacted,
|
||||
PromptHash: promptHash,
|
||||
Policy: accessResult.Policy,
|
||||
AccessResult: accessResult,
|
||||
},
|
||||
Messages: processedMessages,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ProcessCompletionRequest validates and processes a completion request
|
||||
func (g *AccessGate) ProcessCompletionRequest(
|
||||
ctx context.Context,
|
||||
userID, tenantID uuid.UUID,
|
||||
namespaceID *uuid.UUID,
|
||||
req *CompletionRequest,
|
||||
dataCategories []string,
|
||||
) (*GatedCompletionRequest, error) {
|
||||
// 1. Evaluate LLM access
|
||||
accessReq := &rbac.LLMAccessRequest{
|
||||
UserID: userID,
|
||||
TenantID: tenantID,
|
||||
NamespaceID: namespaceID,
|
||||
Model: req.Model,
|
||||
DataCategories: dataCategories,
|
||||
TokensRequested: req.MaxTokens,
|
||||
Operation: "completion",
|
||||
}
|
||||
|
||||
accessResult, err := g.policyEngine.EvaluateLLMAccess(ctx, accessReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("access evaluation failed: %w", err)
|
||||
}
|
||||
|
||||
if !accessResult.Allowed {
|
||||
return nil, fmt.Errorf("access denied: %s", accessResult.Reason)
|
||||
}
|
||||
|
||||
// 2. Process prompt for PII
|
||||
processedPrompt := req.Prompt
|
||||
var allPIITypes []string
|
||||
piiDetected := false
|
||||
redacted := false
|
||||
|
||||
findings := g.piiDetector.FindPII(req.Prompt)
|
||||
if len(findings) > 0 {
|
||||
piiDetected = true
|
||||
for _, f := range findings {
|
||||
allPIITypes = append(allPIITypes, f.Type)
|
||||
}
|
||||
|
||||
// Redact if required by policy
|
||||
if accessResult.RequirePIIRedaction {
|
||||
processedPrompt = g.piiDetector.Redact(req.Prompt, accessResult.PIIRedactionLevel)
|
||||
redacted = true
|
||||
}
|
||||
}
|
||||
|
||||
// 3. Generate prompt hash for audit
|
||||
promptHash := g.hashPrompt(processedPrompt)
|
||||
|
||||
// 4. Apply token limits from policy
|
||||
if accessResult.MaxTokens > 0 && req.MaxTokens > accessResult.MaxTokens {
|
||||
req.MaxTokens = accessResult.MaxTokens
|
||||
}
|
||||
|
||||
return &GatedCompletionRequest{
|
||||
GatedRequest: &GatedRequest{
|
||||
OriginalRequest: req,
|
||||
UserID: userID,
|
||||
TenantID: tenantID,
|
||||
NamespaceID: namespaceID,
|
||||
Model: req.Model,
|
||||
PIIDetected: piiDetected,
|
||||
PIITypes: uniqueStrings(allPIITypes),
|
||||
PromptRedacted: redacted,
|
||||
PromptHash: promptHash,
|
||||
Policy: accessResult.Policy,
|
||||
AccessResult: accessResult,
|
||||
},
|
||||
Prompt: processedPrompt,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ExecuteChat executes a gated chat request
|
||||
func (g *AccessGate) ExecuteChat(ctx context.Context, gatedReq *GatedChatRequest) (*ChatResponse, error) {
|
||||
provider, err := g.registry.GetAvailable(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req := &ChatRequest{
|
||||
Model: gatedReq.Model,
|
||||
Messages: gatedReq.Messages,
|
||||
MaxTokens: gatedReq.AccessResult.MaxTokens,
|
||||
Temperature: 0.7,
|
||||
}
|
||||
|
||||
if orig, ok := gatedReq.OriginalRequest.(*ChatRequest); ok {
|
||||
req.Temperature = orig.Temperature
|
||||
req.TopP = orig.TopP
|
||||
req.Stop = orig.Stop
|
||||
req.Options = orig.Options
|
||||
}
|
||||
|
||||
return provider.Chat(ctx, req)
|
||||
}
|
||||
|
||||
// ExecuteCompletion executes a gated completion request
|
||||
func (g *AccessGate) ExecuteCompletion(ctx context.Context, gatedReq *GatedCompletionRequest) (*CompletionResponse, error) {
|
||||
provider, err := g.registry.GetAvailable(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req := &CompletionRequest{
|
||||
Model: gatedReq.Model,
|
||||
Prompt: gatedReq.Prompt,
|
||||
MaxTokens: gatedReq.AccessResult.MaxTokens,
|
||||
}
|
||||
|
||||
if orig, ok := gatedReq.OriginalRequest.(*CompletionRequest); ok {
|
||||
req.Temperature = orig.Temperature
|
||||
req.TopP = orig.TopP
|
||||
req.Stop = orig.Stop
|
||||
req.Options = orig.Options
|
||||
}
|
||||
|
||||
return provider.Complete(ctx, req)
|
||||
}
|
||||
|
||||
// hashMessages creates a SHA-256 hash of chat messages (for audit without storing PII)
|
||||
func (g *AccessGate) hashMessages(messages []Message) string {
|
||||
hasher := sha256.New()
|
||||
for _, msg := range messages {
|
||||
hasher.Write([]byte(msg.Role))
|
||||
hasher.Write([]byte(msg.Content))
|
||||
}
|
||||
return hex.EncodeToString(hasher.Sum(nil))
|
||||
}
|
||||
|
||||
// hashPrompt creates a SHA-256 hash of a prompt
|
||||
func (g *AccessGate) hashPrompt(prompt string) string {
|
||||
hasher := sha256.New()
|
||||
hasher.Write([]byte(prompt))
|
||||
return hex.EncodeToString(hasher.Sum(nil))
|
||||
}
|
||||
|
||||
// uniqueStrings returns unique strings from a slice
|
||||
func uniqueStrings(slice []string) []string {
|
||||
seen := make(map[string]bool)
|
||||
var result []string
|
||||
for _, s := range slice {
|
||||
if !seen[s] {
|
||||
seen[s] = true
|
||||
result = append(result, s)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// AuditEntry represents an entry for the audit log
|
||||
type AuditEntry struct {
|
||||
ID uuid.UUID `json:"id"`
|
||||
TenantID uuid.UUID `json:"tenant_id"`
|
||||
NamespaceID *uuid.UUID `json:"namespace_id,omitempty"`
|
||||
UserID uuid.UUID `json:"user_id"`
|
||||
SessionID string `json:"session_id,omitempty"`
|
||||
Operation string `json:"operation"`
|
||||
ModelUsed string `json:"model_used"`
|
||||
Provider string `json:"provider"`
|
||||
PromptHash string `json:"prompt_hash"`
|
||||
PromptLength int `json:"prompt_length"`
|
||||
ResponseLength int `json:"response_length,omitempty"`
|
||||
TokensUsed int `json:"tokens_used"`
|
||||
DurationMS int `json:"duration_ms"`
|
||||
PIIDetected bool `json:"pii_detected"`
|
||||
PIITypesDetected []string `json:"pii_types_detected,omitempty"`
|
||||
PIIRedacted bool `json:"pii_redacted"`
|
||||
PolicyID *uuid.UUID `json:"policy_id,omitempty"`
|
||||
PolicyViolations []string `json:"policy_violations,omitempty"`
|
||||
DataCategoriesAccessed []string `json:"data_categories_accessed,omitempty"`
|
||||
ErrorMessage string `json:"error_message,omitempty"`
|
||||
RequestMetadata map[string]any `json:"request_metadata,omitempty"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
// CreateAuditEntry creates an audit entry from a gated request and response
|
||||
func (g *AccessGate) CreateAuditEntry(
|
||||
gatedReq *GatedRequest,
|
||||
operation string,
|
||||
provider string,
|
||||
resp any,
|
||||
err error,
|
||||
promptLength int,
|
||||
sessionID string,
|
||||
) *AuditEntry {
|
||||
entry := &AuditEntry{
|
||||
ID: uuid.New(),
|
||||
TenantID: gatedReq.TenantID,
|
||||
NamespaceID: gatedReq.NamespaceID,
|
||||
UserID: gatedReq.UserID,
|
||||
SessionID: sessionID,
|
||||
Operation: operation,
|
||||
ModelUsed: gatedReq.Model,
|
||||
Provider: provider,
|
||||
PromptHash: gatedReq.PromptHash,
|
||||
PromptLength: promptLength,
|
||||
PIIDetected: gatedReq.PIIDetected,
|
||||
PIITypesDetected: gatedReq.PIITypes,
|
||||
PIIRedacted: gatedReq.PromptRedacted,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
|
||||
if gatedReq.Policy != nil {
|
||||
entry.PolicyID = &gatedReq.Policy.ID
|
||||
}
|
||||
|
||||
if gatedReq.AccessResult != nil && len(gatedReq.AccessResult.BlockedCategories) > 0 {
|
||||
entry.PolicyViolations = gatedReq.AccessResult.BlockedCategories
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
entry.ErrorMessage = err.Error()
|
||||
}
|
||||
|
||||
// Extract usage from response
|
||||
switch r := resp.(type) {
|
||||
case *ChatResponse:
|
||||
entry.ResponseLength = len(r.Message.Content)
|
||||
entry.TokensUsed = r.Usage.TotalTokens
|
||||
entry.DurationMS = int(r.Duration.Milliseconds())
|
||||
case *CompletionResponse:
|
||||
entry.ResponseLength = len(r.Text)
|
||||
entry.TokensUsed = r.Usage.TotalTokens
|
||||
entry.DurationMS = int(r.Duration.Milliseconds())
|
||||
}
|
||||
|
||||
return entry
|
||||
}
|
||||
Reference in New Issue
Block a user