This repository has been archived on 2026-02-15. You can view files and clone it. You cannot open issues or pull requests or push a commit.
Files
breakpilot-pwa/ai-compliance-sdk/internal/llm/access_gate.go
Benjamin Admin 21a844cb8a fix: Restore all files lost during destructive rebase
A previous `git pull --rebase origin main` dropped 177 local commits,
losing 3400+ files across admin-v2, backend, studio-v2, website,
klausur-service, and many other services. The partial restore attempt
(660295e2) only recovered some files.

This commit restores all missing files from pre-rebase ref 98933f5e
while preserving post-rebase additions (night-scheduler, night-mode UI,
NightModeWidget dashboard integration).

Restored features include:
- AI Module Sidebar (FAB), OCR Labeling, OCR Compare
- GPU Dashboard, RAG Pipeline, Magic Help
- Klausur-Korrektur (8 files), Abitur-Archiv (5+ files)
- Companion, Zeugnisse-Crawler, Screen Flow
- Full backend, studio-v2, website, klausur-service
- All compliance SDKs, agent-core, voice-service
- CI/CD configs, documentation, scripts

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-09 09:51:32 +01:00

368 lines
10 KiB
Go

package llm
import (
"context"
"crypto/sha256"
"encoding/hex"
"fmt"
"time"
"github.com/breakpilot/ai-compliance-sdk/internal/rbac"
"github.com/google/uuid"
)
// AccessGate is the checkpoint every LLM call goes through before it reaches
// a provider. It bundles the three collaborators used by the methods in this
// file: a policy engine for access decisions, a PII detector for scanning and
// redacting prompts, and a provider registry for dispatching approved calls.
type AccessGate struct {
	// policyEngine is queried through EvaluateLLMAccess for every request.
	policyEngine *rbac.PolicyEngine
	// piiDetector finds PII in prompt text and redacts it when the policy
	// result asks for redaction.
	piiDetector *PIIDetector
	// registry hands out the provider that executes gated requests.
	registry *ProviderRegistry
}
// NewAccessGate builds an AccessGate around the supplied policy engine, PII
// detector and provider registry. The arguments are stored as given; nothing
// is validated or copied here.
func NewAccessGate(policyEngine *rbac.PolicyEngine, piiDetector *PIIDetector, registry *ProviderRegistry) *AccessGate {
	gate := new(AccessGate)
	gate.policyEngine = policyEngine
	gate.piiDetector = piiDetector
	gate.registry = registry
	return gate
}
// GatedRequest carries the outcome of an access-gate check. It is the part
// shared by chat and completion requests and is what CreateAuditEntry reads
// when building an audit record.
type GatedRequest struct {
	// OriginalRequest is the request pointer the caller handed to the gate
	// (a *ChatRequest or *CompletionRequest). Its prompt text is NOT redacted.
	OriginalRequest any

	// Who issued the request and in which scope.
	UserID      uuid.UUID
	TenantID    uuid.UUID
	NamespaceID *uuid.UUID

	// Model is the model name taken from the original request.
	Model string

	// PII scan outcome: whether anything was found, the distinct finding
	// types, and whether the prompt text was actually rewritten.
	PIIDetected    bool
	PIITypes       []string
	PromptRedacted bool

	// PromptHash is the hex SHA-256 of the processed (possibly redacted) prompt.
	PromptHash string

	// Policy and AccessResult are what the policy engine returned for this
	// request; Policy is copied from AccessResult.Policy.
	Policy       *rbac.LLMPolicy
	AccessResult *rbac.LLMAccessResult
}
// GatedChatRequest is a chat request that has been approved by the gate.
// It embeds the shared GatedRequest data and adds the message list that
// should be sent to the provider.
type GatedChatRequest struct {
	*GatedRequest

	// Messages is the gate's own copy of the conversation; user and system
	// messages have their content redacted when the policy required it.
	Messages []Message
}
// GatedCompletionRequest is a completion request that has been approved by
// the gate. It embeds the shared GatedRequest data and adds the prompt text
// that should be sent to the provider.
type GatedCompletionRequest struct {
	*GatedRequest

	// Prompt is the prompt after PII handling: redacted when the policy
	// required it, otherwise identical to the original prompt.
	Prompt string
}
// ProcessChatRequest authorises a chat request and prepares it for execution.
//
// Steps, in order:
//  1. Ask the policy engine whether the user may run a "chat" operation with
//     the requested model, data categories and token budget.
//  2. Scan every "user" and "system" message for PII. When the policy result
//     requires redaction, the content is replaced in a copy of the message
//     slice; the elements of req.Messages are left untouched. Messages with
//     any other role are passed through unscanned.
//  3. Hash the processed messages so an audit record can reference the
//     prompt without storing it.
//  4. Clamp req.MaxTokens to the policy limit (a limit of 0 or less means
//     "no clamp"). Note that this writes to the caller's request.
//
// An error is returned when req is nil, when policy evaluation fails or
// yields no result, or when access is denied.
func (g *AccessGate) ProcessChatRequest(
	ctx context.Context,
	userID, tenantID uuid.UUID,
	namespaceID *uuid.UUID,
	req *ChatRequest,
	dataCategories []string,
) (*GatedChatRequest, error) {
	// A nil request would otherwise panic on the first field access below.
	if req == nil {
		return nil, fmt.Errorf("chat request is nil")
	}

	// 1. Evaluate LLM access
	accessReq := &rbac.LLMAccessRequest{
		UserID:          userID,
		TenantID:        tenantID,
		NamespaceID:     namespaceID,
		Model:           req.Model,
		DataCategories:  dataCategories,
		TokensRequested: req.MaxTokens,
		Operation:       "chat",
	}
	accessResult, err := g.policyEngine.EvaluateLLMAccess(ctx, accessReq)
	if err != nil {
		return nil, fmt.Errorf("access evaluation failed: %w", err)
	}
	// Fail closed: without a decision there is nothing that grants access.
	if accessResult == nil {
		return nil, fmt.Errorf("access evaluation failed: policy engine returned no result")
	}
	if !accessResult.Allowed {
		return nil, fmt.Errorf("access denied: %s", accessResult.Reason)
	}

	// 2. Process messages for PII (on a copy, so the caller's slice elements
	// keep their original content).
	processedMessages := make([]Message, len(req.Messages))
	copy(processedMessages, req.Messages)

	var allPIITypes []string
	piiDetected := false
	redacted := false
	for i, msg := range processedMessages {
		if msg.Role != "user" && msg.Role != "system" {
			continue
		}
		findings := g.piiDetector.FindPII(msg.Content)
		if len(findings) == 0 {
			continue
		}
		piiDetected = true
		for _, f := range findings {
			allPIITypes = append(allPIITypes, f.Type)
		}
		// Redact only if required by policy; detection alone never rewrites.
		if accessResult.RequirePIIRedaction {
			processedMessages[i].Content = g.piiDetector.Redact(msg.Content, accessResult.PIIRedactionLevel)
			redacted = true
		}
	}

	// 3. Generate prompt hash for audit (over the processed messages).
	promptHash := g.hashMessages(processedMessages)

	// 4. Apply token limits from policy
	if accessResult.MaxTokens > 0 && req.MaxTokens > accessResult.MaxTokens {
		req.MaxTokens = accessResult.MaxTokens
	}

	return &GatedChatRequest{
		GatedRequest: &GatedRequest{
			OriginalRequest: req,
			UserID:          userID,
			TenantID:        tenantID,
			NamespaceID:     namespaceID,
			Model:           req.Model,
			PIIDetected:     piiDetected,
			PIITypes:        uniqueStrings(allPIITypes),
			PromptRedacted:  redacted,
			PromptHash:      promptHash,
			Policy:          accessResult.Policy,
			AccessResult:    accessResult,
		},
		Messages: processedMessages,
	}, nil
}
// ProcessCompletionRequest authorises a completion request and prepares it
// for execution.
//
// Steps, in order:
//  1. Ask the policy engine whether the user may run a "completion"
//     operation with the requested model, data categories and token budget.
//  2. Scan the prompt for PII. When the policy result requires redaction the
//     returned prompt is the redacted text; req.Prompt itself is not changed.
//  3. Hash the processed prompt so an audit record can reference it without
//     storing it.
//  4. Clamp req.MaxTokens to the policy limit (a limit of 0 or less means
//     "no clamp"). Note that this writes to the caller's request.
//
// An error is returned when req is nil, when policy evaluation fails or
// yields no result, or when access is denied.
func (g *AccessGate) ProcessCompletionRequest(
	ctx context.Context,
	userID, tenantID uuid.UUID,
	namespaceID *uuid.UUID,
	req *CompletionRequest,
	dataCategories []string,
) (*GatedCompletionRequest, error) {
	// A nil request would otherwise panic on the first field access below.
	if req == nil {
		return nil, fmt.Errorf("completion request is nil")
	}

	// 1. Evaluate LLM access
	accessReq := &rbac.LLMAccessRequest{
		UserID:          userID,
		TenantID:        tenantID,
		NamespaceID:     namespaceID,
		Model:           req.Model,
		DataCategories:  dataCategories,
		TokensRequested: req.MaxTokens,
		Operation:       "completion",
	}
	accessResult, err := g.policyEngine.EvaluateLLMAccess(ctx, accessReq)
	if err != nil {
		return nil, fmt.Errorf("access evaluation failed: %w", err)
	}
	// Fail closed: without a decision there is nothing that grants access.
	if accessResult == nil {
		return nil, fmt.Errorf("access evaluation failed: policy engine returned no result")
	}
	if !accessResult.Allowed {
		return nil, fmt.Errorf("access denied: %s", accessResult.Reason)
	}

	// 2. Process prompt for PII
	processedPrompt := req.Prompt
	var allPIITypes []string
	piiDetected := false
	redacted := false
	if findings := g.piiDetector.FindPII(req.Prompt); len(findings) > 0 {
		piiDetected = true
		for _, f := range findings {
			allPIITypes = append(allPIITypes, f.Type)
		}
		// Redact only if required by policy; detection alone never rewrites.
		if accessResult.RequirePIIRedaction {
			processedPrompt = g.piiDetector.Redact(req.Prompt, accessResult.PIIRedactionLevel)
			redacted = true
		}
	}

	// 3. Generate prompt hash for audit (over the processed prompt).
	promptHash := g.hashPrompt(processedPrompt)

	// 4. Apply token limits from policy
	if accessResult.MaxTokens > 0 && req.MaxTokens > accessResult.MaxTokens {
		req.MaxTokens = accessResult.MaxTokens
	}

	return &GatedCompletionRequest{
		GatedRequest: &GatedRequest{
			OriginalRequest: req,
			UserID:          userID,
			TenantID:        tenantID,
			NamespaceID:     namespaceID,
			Model:           req.Model,
			PIIDetected:     piiDetected,
			PIITypes:        uniqueStrings(allPIITypes),
			PromptRedacted:  redacted,
			PromptHash:      promptHash,
			Policy:          accessResult.Policy,
			AccessResult:    accessResult,
		},
		Prompt: processedPrompt,
	}, nil
}
// ExecuteChat sends a gated chat request to the first available provider.
//
// The outgoing request uses the gated model and the gated (possibly
// redacted) messages. Sampling settings (Temperature, TopP, Stop, Options)
// are copied from the original *ChatRequest when one is attached; the 0.7
// temperature default only applies when OriginalRequest is not a
// *ChatRequest.
//
// MaxTokens is the caller's requested budget capped by the policy limit:
//   - a positive request below the limit (or with no limit, i.e. limit <= 0)
//     is passed through unchanged;
//   - otherwise the policy limit is used.
//
// An error is returned when gatedReq (or its embedded GatedRequest) is nil,
// when it carries no access result, when no provider is available, or when
// the provider call fails.
func (g *AccessGate) ExecuteChat(ctx context.Context, gatedReq *GatedChatRequest) (*ChatResponse, error) {
	if gatedReq == nil || gatedReq.GatedRequest == nil {
		return nil, fmt.Errorf("gated chat request is nil")
	}
	// A request without an access result never went through the gate.
	if gatedReq.AccessResult == nil {
		return nil, fmt.Errorf("gated chat request has no access result")
	}

	provider, err := g.registry.GetAvailable(ctx)
	if err != nil {
		return nil, err
	}

	limit := gatedReq.AccessResult.MaxTokens
	req := &ChatRequest{
		Model:       gatedReq.Model,
		Messages:    gatedReq.Messages,
		MaxTokens:   limit,
		Temperature: 0.7,
	}
	if orig, ok := gatedReq.OriginalRequest.(*ChatRequest); ok && orig != nil {
		// Honour the caller's own budget; the policy limit is only a ceiling.
		if orig.MaxTokens > 0 && (limit <= 0 || orig.MaxTokens < limit) {
			req.MaxTokens = orig.MaxTokens
		}
		req.Temperature = orig.Temperature
		req.TopP = orig.TopP
		req.Stop = orig.Stop
		req.Options = orig.Options
	}
	return provider.Chat(ctx, req)
}
// ExecuteCompletion sends a gated completion request to the first available
// provider.
//
// The outgoing request uses the gated model and the gated (possibly
// redacted) prompt. Sampling settings (Temperature, TopP, Stop, Options) are
// copied from the original *CompletionRequest when one is attached.
//
// MaxTokens is the caller's requested budget capped by the policy limit:
//   - a positive request below the limit (or with no limit, i.e. limit <= 0)
//     is passed through unchanged;
//   - otherwise the policy limit is used.
//
// An error is returned when gatedReq (or its embedded GatedRequest) is nil,
// when it carries no access result, when no provider is available, or when
// the provider call fails.
func (g *AccessGate) ExecuteCompletion(ctx context.Context, gatedReq *GatedCompletionRequest) (*CompletionResponse, error) {
	if gatedReq == nil || gatedReq.GatedRequest == nil {
		return nil, fmt.Errorf("gated completion request is nil")
	}
	// A request without an access result never went through the gate.
	if gatedReq.AccessResult == nil {
		return nil, fmt.Errorf("gated completion request has no access result")
	}

	provider, err := g.registry.GetAvailable(ctx)
	if err != nil {
		return nil, err
	}

	limit := gatedReq.AccessResult.MaxTokens
	req := &CompletionRequest{
		Model:     gatedReq.Model,
		Prompt:    gatedReq.Prompt,
		MaxTokens: limit,
	}
	if orig, ok := gatedReq.OriginalRequest.(*CompletionRequest); ok && orig != nil {
		// Honour the caller's own budget; the policy limit is only a ceiling.
		if orig.MaxTokens > 0 && (limit <= 0 || orig.MaxTokens < limit) {
			req.MaxTokens = orig.MaxTokens
		}
		req.Temperature = orig.Temperature
		req.TopP = orig.TopP
		req.Stop = orig.Stop
		req.Options = orig.Options
	}
	return provider.Complete(ctx, req)
}
// hashMessages returns the hex-encoded SHA-256 digest of a chat transcript,
// so an audit record can identify a prompt without storing its text.
// For every message, in order, the role bytes and then the content bytes are
// fed into the hash; no separator is written between fields or messages.
// A nil or empty slice yields the digest of the empty input.
func (g *AccessGate) hashMessages(messages []Message) string {
	digest := sha256.New()
	for idx := range messages {
		entry := &messages[idx]
		digest.Write([]byte(entry.Role))
		digest.Write([]byte(entry.Content))
	}
	sum := digest.Sum(nil)
	return hex.EncodeToString(sum)
}
// hashPrompt returns the hex-encoded SHA-256 digest of the prompt's bytes.
// The empty prompt yields the digest of the empty input.
func (g *AccessGate) hashPrompt(prompt string) string {
	digest := sha256.Sum256([]byte(prompt))
	return hex.EncodeToString(digest[:])
}
// uniqueStrings drops repeated values from slice, keeping the first
// occurrence of each string in its original position. The input is not
// modified. When there is nothing to return (nil or empty input) the result
// is a nil slice.
func uniqueStrings(slice []string) []string {
	var (
		distinct []string
		known    = map[string]struct{}{}
	)
	for idx := range slice {
		item := slice[idx]
		if _, dup := known[item]; dup {
			continue
		}
		known[item] = struct{}{}
		distinct = append(distinct, item)
	}
	return distinct
}
// AuditEntry is a single record of the LLM audit log, serialisable as JSON.
// It identifies a prompt by hash and length and has no field for the prompt
// text itself. CreateAuditEntry fills most fields; it does not set
// DataCategoriesAccessed or RequestMetadata.
type AuditEntry struct {
	// Identity of the record, the caller and the caller's session.
	ID          uuid.UUID  `json:"id"`
	TenantID    uuid.UUID  `json:"tenant_id"`
	NamespaceID *uuid.UUID `json:"namespace_id,omitempty"`
	UserID      uuid.UUID  `json:"user_id"`
	SessionID   string     `json:"session_id,omitempty"`

	// What was run and by which provider.
	Operation string `json:"operation"`
	ModelUsed string `json:"model_used"`
	Provider  string `json:"provider"`

	// Prompt fingerprint and request/response metrics.
	PromptHash     string `json:"prompt_hash"`
	PromptLength   int    `json:"prompt_length"`
	ResponseLength int    `json:"response_length,omitempty"`
	TokensUsed     int    `json:"tokens_used"`
	DurationMS     int    `json:"duration_ms"`

	// Outcome of the PII scan performed by the access gate.
	PIIDetected      bool     `json:"pii_detected"`
	PIITypesDetected []string `json:"pii_types_detected,omitempty"`
	PIIRedacted      bool     `json:"pii_redacted"`

	// Policy that governed the request and what it blocked.
	PolicyID               *uuid.UUID `json:"policy_id,omitempty"`
	PolicyViolations       []string   `json:"policy_violations,omitempty"`
	DataCategoriesAccessed []string   `json:"data_categories_accessed,omitempty"`

	// Error text (if the call failed), free-form metadata and the creation
	// timestamp of the record.
	ErrorMessage    string         `json:"error_message,omitempty"`
	RequestMetadata map[string]any `json:"request_metadata,omitempty"`
	CreatedAt       time.Time      `json:"created_at"`
}
// CreateAuditEntry builds the audit record for one gated LLM call.
//
// Parameters:
//   - gatedReq: the gate's result for the request; must be non-nil.
//   - operation, provider, sessionID: copied verbatim into the entry.
//   - resp: the provider response. Usage data is extracted when it is a
//     non-nil *ChatResponse or *CompletionResponse; any other value
//     (including a nil pointer of either type, as is typical after a failed
//     call) leaves the response fields at zero.
//   - err: the call error, if any; its text is stored in ErrorMessage.
//   - promptLength: stored as-is in PromptLength.
//
// The entry gets a fresh ID and a UTC creation timestamp. PolicyViolations is
// filled from the access result's blocked categories when there are any.
func (g *AccessGate) CreateAuditEntry(
	gatedReq *GatedRequest,
	operation string,
	provider string,
	resp any,
	err error,
	promptLength int,
	sessionID string,
) *AuditEntry {
	entry := &AuditEntry{
		ID:               uuid.New(),
		TenantID:         gatedReq.TenantID,
		NamespaceID:      gatedReq.NamespaceID,
		UserID:           gatedReq.UserID,
		SessionID:        sessionID,
		Operation:        operation,
		ModelUsed:        gatedReq.Model,
		Provider:         provider,
		PromptHash:       gatedReq.PromptHash,
		PromptLength:     promptLength,
		PIIDetected:      gatedReq.PIIDetected,
		PIITypesDetected: gatedReq.PIITypes,
		PIIRedacted:      gatedReq.PromptRedacted,
		CreatedAt:        time.Now().UTC(),
	}
	if gatedReq.Policy != nil {
		entry.PolicyID = &gatedReq.Policy.ID
	}
	if gatedReq.AccessResult != nil && len(gatedReq.AccessResult.BlockedCategories) > 0 {
		entry.PolicyViolations = gatedReq.AccessResult.BlockedCategories
	}
	if err != nil {
		entry.ErrorMessage = err.Error()
	}

	// Extract usage from the response. A typed nil pointer stored in resp
	// still matches its case in the type switch, so each case must check for
	// nil before dereferencing.
	switch r := resp.(type) {
	case *ChatResponse:
		if r != nil {
			entry.ResponseLength = len(r.Message.Content)
			entry.TokensUsed = r.Usage.TotalTokens
			entry.DurationMS = int(r.Duration.Milliseconds())
		}
	case *CompletionResponse:
		if r != nil {
			entry.ResponseLength = len(r.Text)
			entry.TokensUsed = r.Usage.TotalTokens
			entry.DurationMS = int(r.Duration.Milliseconds())
		}
	}
	return entry
}