Initial commit: breakpilot-compliance - Compliance SDK Platform
Services: Admin-Compliance, Backend-Compliance, AI-Compliance-SDK, Consent-SDK, Developer-Portal, PCA-Platform, DSMS Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
345
ai-compliance-sdk/internal/api/handlers/llm_handlers.go
Normal file
345
ai-compliance-sdk/internal/api/handlers/llm_handlers.go
Normal file
@@ -0,0 +1,345 @@
|
||||
package handlers
|
||||
|
||||
import (
	"context"
	"net/http"

	"github.com/breakpilot/ai-compliance-sdk/internal/audit"
	"github.com/breakpilot/ai-compliance-sdk/internal/llm"
	"github.com/breakpilot/ai-compliance-sdk/internal/rbac"
	"github.com/gin-gonic/gin"
	"github.com/google/uuid"
)
|
||||
|
||||
// LLMHandlers handles LLM-related API endpoints. It bundles the access gate
// (policy enforcement), provider registry, PII detector, and audit plumbing
// shared by every LLM route in this file.
type LLMHandlers struct {
	accessGate   *llm.AccessGate       // policy gate each request is processed and executed through
	registry     *llm.ProviderRegistry // primary/fallback LLM providers (see ListModels, GetProviderStatus)
	piiDetector  *llm.PIIDetector      // PII detection, category inference, and redaction
	auditStore   *audit.Store          // audit persistence; not referenced by the handlers in this file — TODO confirm it is still needed
	trailBuilder *audit.TrailBuilder   // builds the audit-trail entries written by logLLMRequest/logDeniedRequest
}
|
||||
|
||||
// NewLLMHandlers creates new LLM handlers
|
||||
func NewLLMHandlers(
|
||||
accessGate *llm.AccessGate,
|
||||
registry *llm.ProviderRegistry,
|
||||
piiDetector *llm.PIIDetector,
|
||||
auditStore *audit.Store,
|
||||
trailBuilder *audit.TrailBuilder,
|
||||
) *LLMHandlers {
|
||||
return &LLMHandlers{
|
||||
accessGate: accessGate,
|
||||
registry: registry,
|
||||
piiDetector: piiDetector,
|
||||
auditStore: auditStore,
|
||||
trailBuilder: trailBuilder,
|
||||
}
|
||||
}
|
||||
|
||||
// ChatRequest represents a chat completion request as received over the API.
// Only Messages is mandatory; all other fields are passed through to the
// gated llm.ChatRequest unchanged.
type ChatRequest struct {
	Model          string        `json:"model"`                       // requested model identifier; may be empty — TODO confirm provider-side default
	Messages       []llm.Message `json:"messages" binding:"required"` // conversation history; each message is PII-scanned when DataCategories is absent
	MaxTokens      int           `json:"max_tokens"`                  // response token cap (0 = provider default — TODO confirm)
	Temperature    float64       `json:"temperature"`                 // sampling temperature, forwarded as-is
	DataCategories []string      `json:"data_categories"`             // Optional hint about data types
}
|
||||
|
||||
// Chat handles chat completion requests
|
||||
func (h *LLMHandlers) Chat(c *gin.Context) {
|
||||
var req ChatRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
userID := rbac.GetUserID(c)
|
||||
tenantID := rbac.GetTenantID(c)
|
||||
namespaceID := rbac.GetNamespaceID(c)
|
||||
|
||||
if userID == uuid.Nil || tenantID == uuid.Nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "authentication required"})
|
||||
return
|
||||
}
|
||||
|
||||
// Detect data categories from messages if not provided
|
||||
dataCategories := req.DataCategories
|
||||
if len(dataCategories) == 0 {
|
||||
for _, msg := range req.Messages {
|
||||
detected := h.piiDetector.DetectDataCategories(msg.Content)
|
||||
dataCategories = append(dataCategories, detected...)
|
||||
}
|
||||
}
|
||||
|
||||
// Process through access gate
|
||||
chatReq := &llm.ChatRequest{
|
||||
Model: req.Model,
|
||||
Messages: req.Messages,
|
||||
MaxTokens: req.MaxTokens,
|
||||
Temperature: req.Temperature,
|
||||
}
|
||||
|
||||
gatedReq, err := h.accessGate.ProcessChatRequest(
|
||||
c.Request.Context(),
|
||||
userID, tenantID, namespaceID,
|
||||
chatReq, dataCategories,
|
||||
)
|
||||
if err != nil {
|
||||
// Log denied request
|
||||
h.logDeniedRequest(c, userID, tenantID, namespaceID, "chat", req.Model, err.Error())
|
||||
c.JSON(http.StatusForbidden, gin.H{
|
||||
"error": "access_denied",
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Execute the request
|
||||
resp, err := h.accessGate.ExecuteChat(c.Request.Context(), gatedReq)
|
||||
|
||||
// Log the request
|
||||
h.logLLMRequest(c, gatedReq.GatedRequest, "chat", resp, err)
|
||||
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": "llm_error",
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"id": resp.ID,
|
||||
"model": resp.Model,
|
||||
"provider": resp.Provider,
|
||||
"message": resp.Message,
|
||||
"finish_reason": resp.FinishReason,
|
||||
"usage": resp.Usage,
|
||||
"pii_detected": gatedReq.PIIDetected,
|
||||
"pii_redacted": gatedReq.PromptRedacted,
|
||||
})
|
||||
}
|
||||
|
||||
// CompletionRequest represents a text completion request as received over the
// API. Only Prompt is mandatory; the remaining fields are forwarded to the
// gated llm.CompletionRequest unchanged.
type CompletionRequest struct {
	Model          string   `json:"model"`                     // requested model identifier; may be empty — TODO confirm provider-side default
	Prompt         string   `json:"prompt" binding:"required"` // prompt text; PII-scanned when DataCategories is absent
	MaxTokens      int      `json:"max_tokens"`                // response token cap (0 = provider default — TODO confirm)
	Temperature    float64  `json:"temperature"`               // sampling temperature, forwarded as-is
	DataCategories []string `json:"data_categories"`           // optional hint about the data types in the prompt
}
|
||||
|
||||
// Complete handles text completion requests
|
||||
func (h *LLMHandlers) Complete(c *gin.Context) {
|
||||
var req CompletionRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
userID := rbac.GetUserID(c)
|
||||
tenantID := rbac.GetTenantID(c)
|
||||
namespaceID := rbac.GetNamespaceID(c)
|
||||
|
||||
if userID == uuid.Nil || tenantID == uuid.Nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "authentication required"})
|
||||
return
|
||||
}
|
||||
|
||||
// Detect data categories from prompt if not provided
|
||||
dataCategories := req.DataCategories
|
||||
if len(dataCategories) == 0 {
|
||||
dataCategories = h.piiDetector.DetectDataCategories(req.Prompt)
|
||||
}
|
||||
|
||||
// Process through access gate
|
||||
completionReq := &llm.CompletionRequest{
|
||||
Model: req.Model,
|
||||
Prompt: req.Prompt,
|
||||
MaxTokens: req.MaxTokens,
|
||||
Temperature: req.Temperature,
|
||||
}
|
||||
|
||||
gatedReq, err := h.accessGate.ProcessCompletionRequest(
|
||||
c.Request.Context(),
|
||||
userID, tenantID, namespaceID,
|
||||
completionReq, dataCategories,
|
||||
)
|
||||
if err != nil {
|
||||
h.logDeniedRequest(c, userID, tenantID, namespaceID, "completion", req.Model, err.Error())
|
||||
c.JSON(http.StatusForbidden, gin.H{
|
||||
"error": "access_denied",
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Execute the request
|
||||
resp, err := h.accessGate.ExecuteCompletion(c.Request.Context(), gatedReq)
|
||||
|
||||
// Log the request
|
||||
h.logLLMRequest(c, gatedReq.GatedRequest, "completion", resp, err)
|
||||
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": "llm_error",
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"id": resp.ID,
|
||||
"model": resp.Model,
|
||||
"provider": resp.Provider,
|
||||
"text": resp.Text,
|
||||
"finish_reason": resp.FinishReason,
|
||||
"usage": resp.Usage,
|
||||
"pii_detected": gatedReq.PIIDetected,
|
||||
"pii_redacted": gatedReq.PromptRedacted,
|
||||
})
|
||||
}
|
||||
|
||||
// ListModels returns available models
|
||||
func (h *LLMHandlers) ListModels(c *gin.Context) {
|
||||
models, err := h.registry.ListAllModels(c.Request.Context())
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"models": models})
|
||||
}
|
||||
|
||||
// GetProviderStatus returns the status of LLM providers
|
||||
func (h *LLMHandlers) GetProviderStatus(c *gin.Context) {
|
||||
ctx := c.Request.Context()
|
||||
|
||||
statuses := make(map[string]bool)
|
||||
|
||||
if p, ok := h.registry.GetPrimary(); ok {
|
||||
statuses[p.Name()] = p.IsAvailable(ctx)
|
||||
}
|
||||
|
||||
if p, ok := h.registry.GetFallback(); ok {
|
||||
statuses[p.Name()] = p.IsAvailable(ctx)
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"providers": statuses})
|
||||
}
|
||||
|
||||
// AnalyzeText analyzes text for PII without making an LLM call
|
||||
func (h *LLMHandlers) AnalyzeText(c *gin.Context) {
|
||||
var req struct {
|
||||
Text string `json:"text" binding:"required"`
|
||||
}
|
||||
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
findings := h.piiDetector.FindPII(req.Text)
|
||||
categories := h.piiDetector.DetectDataCategories(req.Text)
|
||||
containsPII := len(findings) > 0
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"contains_pii": containsPII,
|
||||
"pii_findings": findings,
|
||||
"data_categories": categories,
|
||||
})
|
||||
}
|
||||
|
||||
// RedactText redacts PII from text
|
||||
func (h *LLMHandlers) RedactText(c *gin.Context) {
|
||||
var req struct {
|
||||
Text string `json:"text" binding:"required"`
|
||||
Level string `json:"level"` // strict, moderate, minimal
|
||||
}
|
||||
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
level := rbac.PIIRedactionStrict
|
||||
switch req.Level {
|
||||
case "moderate":
|
||||
level = rbac.PIIRedactionModerate
|
||||
case "minimal":
|
||||
level = rbac.PIIRedactionMinimal
|
||||
case "none":
|
||||
level = rbac.PIIRedactionNone
|
||||
}
|
||||
|
||||
redacted := h.piiDetector.Redact(req.Text, level)
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"original": req.Text,
|
||||
"redacted": redacted,
|
||||
"level": level,
|
||||
})
|
||||
}
|
||||
|
||||
// logLLMRequest logs an LLM request to the audit trail
|
||||
func (h *LLMHandlers) logLLMRequest(c *gin.Context, gatedReq *llm.GatedRequest, operation string, resp any, err error) {
|
||||
entry := h.trailBuilder.NewLLMEntry().
|
||||
WithTenant(gatedReq.TenantID).
|
||||
WithUser(gatedReq.UserID).
|
||||
WithOperation(operation).
|
||||
WithPrompt(gatedReq.PromptHash, 0). // Length calculated below
|
||||
WithPII(gatedReq.PIIDetected, gatedReq.PIITypes, gatedReq.PromptRedacted)
|
||||
|
||||
if gatedReq.NamespaceID != nil {
|
||||
entry.WithNamespace(*gatedReq.NamespaceID)
|
||||
}
|
||||
|
||||
if gatedReq.Policy != nil {
|
||||
entry.WithPolicy(&gatedReq.Policy.ID, gatedReq.AccessResult.BlockedCategories)
|
||||
}
|
||||
|
||||
// Add response data if available
|
||||
switch r := resp.(type) {
|
||||
case *llm.ChatResponse:
|
||||
entry.WithModel(r.Model, r.Provider).
|
||||
WithResponse(len(r.Message.Content)).
|
||||
WithUsage(r.Usage.TotalTokens, int(r.Duration.Milliseconds()))
|
||||
case *llm.CompletionResponse:
|
||||
entry.WithModel(r.Model, r.Provider).
|
||||
WithResponse(len(r.Text)).
|
||||
WithUsage(r.Usage.TotalTokens, int(r.Duration.Milliseconds()))
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
entry.WithError(err.Error())
|
||||
}
|
||||
|
||||
// Add client info
|
||||
entry.AddMetadata("ip_address", c.ClientIP()).
|
||||
AddMetadata("user_agent", c.GetHeader("User-Agent"))
|
||||
|
||||
// Save asynchronously
|
||||
go func() {
|
||||
entry.Save(c.Request.Context())
|
||||
}()
|
||||
}
|
||||
|
||||
// logDeniedRequest logs a denied LLM request
|
||||
func (h *LLMHandlers) logDeniedRequest(c *gin.Context, userID, tenantID uuid.UUID, namespaceID *uuid.UUID, operation, model, reason string) {
|
||||
entry := h.trailBuilder.NewLLMEntry().
|
||||
WithTenant(tenantID).
|
||||
WithUser(userID).
|
||||
WithOperation(operation).
|
||||
WithModel(model, "denied").
|
||||
WithError("access_denied: " + reason).
|
||||
AddMetadata("ip_address", c.ClientIP()).
|
||||
AddMetadata("user_agent", c.GetHeader("User-Agent"))
|
||||
|
||||
if namespaceID != nil {
|
||||
entry.WithNamespace(*namespaceID)
|
||||
}
|
||||
|
||||
go func() {
|
||||
entry.Save(c.Request.Context())
|
||||
}()
|
||||
}
|
||||
Reference in New Issue
Block a user