Services: Admin-Compliance, Backend-Compliance, AI-Compliance-SDK, Consent-SDK, Developer-Portal, PCA-Platform, DSMS Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
346 lines
9.1 KiB
Go
346 lines
9.1 KiB
Go
package handlers
|
|
|
|
import (
	"context"
	"net/http"
	"time"

	"github.com/breakpilot/ai-compliance-sdk/internal/audit"
	"github.com/breakpilot/ai-compliance-sdk/internal/llm"
	"github.com/breakpilot/ai-compliance-sdk/internal/rbac"
	"github.com/gin-gonic/gin"
	"github.com/google/uuid"
)
|
|
|
|
// LLMHandlers handles LLM-related API endpoints
|
|
type LLMHandlers struct {
|
|
accessGate *llm.AccessGate
|
|
registry *llm.ProviderRegistry
|
|
piiDetector *llm.PIIDetector
|
|
auditStore *audit.Store
|
|
trailBuilder *audit.TrailBuilder
|
|
}
|
|
|
|
// NewLLMHandlers creates new LLM handlers
|
|
func NewLLMHandlers(
|
|
accessGate *llm.AccessGate,
|
|
registry *llm.ProviderRegistry,
|
|
piiDetector *llm.PIIDetector,
|
|
auditStore *audit.Store,
|
|
trailBuilder *audit.TrailBuilder,
|
|
) *LLMHandlers {
|
|
return &LLMHandlers{
|
|
accessGate: accessGate,
|
|
registry: registry,
|
|
piiDetector: piiDetector,
|
|
auditStore: auditStore,
|
|
trailBuilder: trailBuilder,
|
|
}
|
|
}
|
|
|
|
// ChatRequest represents a chat completion request
|
|
type ChatRequest struct {
|
|
Model string `json:"model"`
|
|
Messages []llm.Message `json:"messages" binding:"required"`
|
|
MaxTokens int `json:"max_tokens"`
|
|
Temperature float64 `json:"temperature"`
|
|
DataCategories []string `json:"data_categories"` // Optional hint about data types
|
|
}
|
|
|
|
// Chat handles chat completion requests
|
|
func (h *LLMHandlers) Chat(c *gin.Context) {
|
|
var req ChatRequest
|
|
if err := c.ShouldBindJSON(&req); err != nil {
|
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
|
return
|
|
}
|
|
|
|
userID := rbac.GetUserID(c)
|
|
tenantID := rbac.GetTenantID(c)
|
|
namespaceID := rbac.GetNamespaceID(c)
|
|
|
|
if userID == uuid.Nil || tenantID == uuid.Nil {
|
|
c.JSON(http.StatusUnauthorized, gin.H{"error": "authentication required"})
|
|
return
|
|
}
|
|
|
|
// Detect data categories from messages if not provided
|
|
dataCategories := req.DataCategories
|
|
if len(dataCategories) == 0 {
|
|
for _, msg := range req.Messages {
|
|
detected := h.piiDetector.DetectDataCategories(msg.Content)
|
|
dataCategories = append(dataCategories, detected...)
|
|
}
|
|
}
|
|
|
|
// Process through access gate
|
|
chatReq := &llm.ChatRequest{
|
|
Model: req.Model,
|
|
Messages: req.Messages,
|
|
MaxTokens: req.MaxTokens,
|
|
Temperature: req.Temperature,
|
|
}
|
|
|
|
gatedReq, err := h.accessGate.ProcessChatRequest(
|
|
c.Request.Context(),
|
|
userID, tenantID, namespaceID,
|
|
chatReq, dataCategories,
|
|
)
|
|
if err != nil {
|
|
// Log denied request
|
|
h.logDeniedRequest(c, userID, tenantID, namespaceID, "chat", req.Model, err.Error())
|
|
c.JSON(http.StatusForbidden, gin.H{
|
|
"error": "access_denied",
|
|
"message": err.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
// Execute the request
|
|
resp, err := h.accessGate.ExecuteChat(c.Request.Context(), gatedReq)
|
|
|
|
// Log the request
|
|
h.logLLMRequest(c, gatedReq.GatedRequest, "chat", resp, err)
|
|
|
|
if err != nil {
|
|
c.JSON(http.StatusInternalServerError, gin.H{
|
|
"error": "llm_error",
|
|
"message": err.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
c.JSON(http.StatusOK, gin.H{
|
|
"id": resp.ID,
|
|
"model": resp.Model,
|
|
"provider": resp.Provider,
|
|
"message": resp.Message,
|
|
"finish_reason": resp.FinishReason,
|
|
"usage": resp.Usage,
|
|
"pii_detected": gatedReq.PIIDetected,
|
|
"pii_redacted": gatedReq.PromptRedacted,
|
|
})
|
|
}
|
|
|
|
// CompletionRequest is the JSON body accepted by the Complete endpoint.
type CompletionRequest struct {
	// Model names the model to use; it is forwarded unchanged to the gate.
	Model string `json:"model"`
	// Prompt is the text to complete (required, must be non-empty).
	Prompt string `json:"prompt" binding:"required"`
	// MaxTokens and Temperature are passed through to the provider request.
	MaxTokens   int     `json:"max_tokens"`
	Temperature float64 `json:"temperature"`
	// DataCategories optionally declares the kinds of data in Prompt; when
	// omitted, Complete detects them from the prompt text.
	DataCategories []string `json:"data_categories"`
}
|
|
|
|
// Complete handles text completion requests
|
|
func (h *LLMHandlers) Complete(c *gin.Context) {
|
|
var req CompletionRequest
|
|
if err := c.ShouldBindJSON(&req); err != nil {
|
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
|
return
|
|
}
|
|
|
|
userID := rbac.GetUserID(c)
|
|
tenantID := rbac.GetTenantID(c)
|
|
namespaceID := rbac.GetNamespaceID(c)
|
|
|
|
if userID == uuid.Nil || tenantID == uuid.Nil {
|
|
c.JSON(http.StatusUnauthorized, gin.H{"error": "authentication required"})
|
|
return
|
|
}
|
|
|
|
// Detect data categories from prompt if not provided
|
|
dataCategories := req.DataCategories
|
|
if len(dataCategories) == 0 {
|
|
dataCategories = h.piiDetector.DetectDataCategories(req.Prompt)
|
|
}
|
|
|
|
// Process through access gate
|
|
completionReq := &llm.CompletionRequest{
|
|
Model: req.Model,
|
|
Prompt: req.Prompt,
|
|
MaxTokens: req.MaxTokens,
|
|
Temperature: req.Temperature,
|
|
}
|
|
|
|
gatedReq, err := h.accessGate.ProcessCompletionRequest(
|
|
c.Request.Context(),
|
|
userID, tenantID, namespaceID,
|
|
completionReq, dataCategories,
|
|
)
|
|
if err != nil {
|
|
h.logDeniedRequest(c, userID, tenantID, namespaceID, "completion", req.Model, err.Error())
|
|
c.JSON(http.StatusForbidden, gin.H{
|
|
"error": "access_denied",
|
|
"message": err.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
// Execute the request
|
|
resp, err := h.accessGate.ExecuteCompletion(c.Request.Context(), gatedReq)
|
|
|
|
// Log the request
|
|
h.logLLMRequest(c, gatedReq.GatedRequest, "completion", resp, err)
|
|
|
|
if err != nil {
|
|
c.JSON(http.StatusInternalServerError, gin.H{
|
|
"error": "llm_error",
|
|
"message": err.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
c.JSON(http.StatusOK, gin.H{
|
|
"id": resp.ID,
|
|
"model": resp.Model,
|
|
"provider": resp.Provider,
|
|
"text": resp.Text,
|
|
"finish_reason": resp.FinishReason,
|
|
"usage": resp.Usage,
|
|
"pii_detected": gatedReq.PIIDetected,
|
|
"pii_redacted": gatedReq.PromptRedacted,
|
|
})
|
|
}
|
|
|
|
// ListModels returns available models
|
|
func (h *LLMHandlers) ListModels(c *gin.Context) {
|
|
models, err := h.registry.ListAllModels(c.Request.Context())
|
|
if err != nil {
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
return
|
|
}
|
|
|
|
c.JSON(http.StatusOK, gin.H{"models": models})
|
|
}
|
|
|
|
// GetProviderStatus returns the status of LLM providers
|
|
func (h *LLMHandlers) GetProviderStatus(c *gin.Context) {
|
|
ctx := c.Request.Context()
|
|
|
|
statuses := make(map[string]bool)
|
|
|
|
if p, ok := h.registry.GetPrimary(); ok {
|
|
statuses[p.Name()] = p.IsAvailable(ctx)
|
|
}
|
|
|
|
if p, ok := h.registry.GetFallback(); ok {
|
|
statuses[p.Name()] = p.IsAvailable(ctx)
|
|
}
|
|
|
|
c.JSON(http.StatusOK, gin.H{"providers": statuses})
|
|
}
|
|
|
|
// AnalyzeText analyzes text for PII without making an LLM call
|
|
func (h *LLMHandlers) AnalyzeText(c *gin.Context) {
|
|
var req struct {
|
|
Text string `json:"text" binding:"required"`
|
|
}
|
|
|
|
if err := c.ShouldBindJSON(&req); err != nil {
|
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
|
return
|
|
}
|
|
|
|
findings := h.piiDetector.FindPII(req.Text)
|
|
categories := h.piiDetector.DetectDataCategories(req.Text)
|
|
containsPII := len(findings) > 0
|
|
|
|
c.JSON(http.StatusOK, gin.H{
|
|
"contains_pii": containsPII,
|
|
"pii_findings": findings,
|
|
"data_categories": categories,
|
|
})
|
|
}
|
|
|
|
// RedactText redacts PII from text
|
|
func (h *LLMHandlers) RedactText(c *gin.Context) {
|
|
var req struct {
|
|
Text string `json:"text" binding:"required"`
|
|
Level string `json:"level"` // strict, moderate, minimal
|
|
}
|
|
|
|
if err := c.ShouldBindJSON(&req); err != nil {
|
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
|
return
|
|
}
|
|
|
|
level := rbac.PIIRedactionStrict
|
|
switch req.Level {
|
|
case "moderate":
|
|
level = rbac.PIIRedactionModerate
|
|
case "minimal":
|
|
level = rbac.PIIRedactionMinimal
|
|
case "none":
|
|
level = rbac.PIIRedactionNone
|
|
}
|
|
|
|
redacted := h.piiDetector.Redact(req.Text, level)
|
|
|
|
c.JSON(http.StatusOK, gin.H{
|
|
"original": req.Text,
|
|
"redacted": redacted,
|
|
"level": level,
|
|
})
|
|
}
|
|
|
|
// logLLMRequest logs an LLM request to the audit trail
|
|
func (h *LLMHandlers) logLLMRequest(c *gin.Context, gatedReq *llm.GatedRequest, operation string, resp any, err error) {
|
|
entry := h.trailBuilder.NewLLMEntry().
|
|
WithTenant(gatedReq.TenantID).
|
|
WithUser(gatedReq.UserID).
|
|
WithOperation(operation).
|
|
WithPrompt(gatedReq.PromptHash, 0). // Length calculated below
|
|
WithPII(gatedReq.PIIDetected, gatedReq.PIITypes, gatedReq.PromptRedacted)
|
|
|
|
if gatedReq.NamespaceID != nil {
|
|
entry.WithNamespace(*gatedReq.NamespaceID)
|
|
}
|
|
|
|
if gatedReq.Policy != nil {
|
|
entry.WithPolicy(&gatedReq.Policy.ID, gatedReq.AccessResult.BlockedCategories)
|
|
}
|
|
|
|
// Add response data if available
|
|
switch r := resp.(type) {
|
|
case *llm.ChatResponse:
|
|
entry.WithModel(r.Model, r.Provider).
|
|
WithResponse(len(r.Message.Content)).
|
|
WithUsage(r.Usage.TotalTokens, int(r.Duration.Milliseconds()))
|
|
case *llm.CompletionResponse:
|
|
entry.WithModel(r.Model, r.Provider).
|
|
WithResponse(len(r.Text)).
|
|
WithUsage(r.Usage.TotalTokens, int(r.Duration.Milliseconds()))
|
|
}
|
|
|
|
if err != nil {
|
|
entry.WithError(err.Error())
|
|
}
|
|
|
|
// Add client info
|
|
entry.AddMetadata("ip_address", c.ClientIP()).
|
|
AddMetadata("user_agent", c.GetHeader("User-Agent"))
|
|
|
|
// Save asynchronously
|
|
go func() {
|
|
entry.Save(c.Request.Context())
|
|
}()
|
|
}
|
|
|
|
// logDeniedRequest logs a denied LLM request
|
|
func (h *LLMHandlers) logDeniedRequest(c *gin.Context, userID, tenantID uuid.UUID, namespaceID *uuid.UUID, operation, model, reason string) {
|
|
entry := h.trailBuilder.NewLLMEntry().
|
|
WithTenant(tenantID).
|
|
WithUser(userID).
|
|
WithOperation(operation).
|
|
WithModel(model, "denied").
|
|
WithError("access_denied: " + reason).
|
|
AddMetadata("ip_address", c.ClientIP()).
|
|
AddMetadata("user_agent", c.GetHeader("User-Agent"))
|
|
|
|
if namespaceID != nil {
|
|
entry.WithNamespace(*namespaceID)
|
|
}
|
|
|
|
go func() {
|
|
entry.Save(c.Request.Context())
|
|
}()
|
|
}
|