Files
breakpilot-compliance/ai-compliance-sdk/internal/api/handlers/llm_handlers.go
Benjamin Boenisch 4435e7ea0a Initial commit: breakpilot-compliance - Compliance SDK Platform
Services: Admin-Compliance, Backend-Compliance,
AI-Compliance-SDK, Consent-SDK, Developer-Portal,
PCA-Platform, DSMS

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-11 23:47:28 +01:00

346 lines
9.1 KiB
Go

package handlers
import (
	"context"
	"net/http"
	"time"

	"github.com/breakpilot/ai-compliance-sdk/internal/audit"
	"github.com/breakpilot/ai-compliance-sdk/internal/llm"
	"github.com/breakpilot/ai-compliance-sdk/internal/rbac"
	"github.com/gin-gonic/gin"
	"github.com/google/uuid"
)
// LLMHandlers handles LLM-related API endpoints.
//
// It routes every chat/completion call through the compliance access gate,
// runs PII detection on inbound text, and records the outcome on the audit
// trail.
type LLMHandlers struct {
	accessGate   *llm.AccessGate       // policy gate that approves, blocks, or redacts LLM requests
	registry     *llm.ProviderRegistry // registered LLM providers (primary + fallback)
	piiDetector  *llm.PIIDetector      // detects PII findings and data categories in free text
	auditStore   *audit.Store          // audit persistence; not referenced in this file — presumably used by TrailBuilder wiring, TODO confirm
	trailBuilder *audit.TrailBuilder   // constructs audit-trail entries for LLM operations
}
// NewLLMHandlers creates new LLM handlers wired with the given collaborators.
func NewLLMHandlers(
	accessGate *llm.AccessGate,
	registry *llm.ProviderRegistry,
	piiDetector *llm.PIIDetector,
	auditStore *audit.Store,
	trailBuilder *audit.TrailBuilder,
) *LLMHandlers {
	h := new(LLMHandlers)
	h.accessGate = accessGate
	h.registry = registry
	h.piiDetector = piiDetector
	h.auditStore = auditStore
	h.trailBuilder = trailBuilder
	return h
}
// ChatRequest represents a chat completion request.
type ChatRequest struct {
	Model       string        `json:"model"`                       // requested model identifier; provider resolution happens downstream
	Messages    []llm.Message `json:"messages" binding:"required"` // conversation history; required
	MaxTokens   int           `json:"max_tokens"`                  // 0 means "use downstream default" — TODO confirm
	Temperature float64       `json:"temperature"`
	DataCategories []string   `json:"data_categories"` // Optional hint about data types; auto-detected from messages when empty
}
// Chat handles chat completion requests.
//
// Flow: bind JSON -> authenticate -> determine data categories (client hint,
// else PII detection over the messages) -> compliance access gate -> execute
// against the provider -> audit-log -> respond.
func (h *LLMHandlers) Chat(c *gin.Context) {
	var body ChatRequest
	if err := c.ShouldBindJSON(&body); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}

	userID := rbac.GetUserID(c)
	tenantID := rbac.GetTenantID(c)
	namespaceID := rbac.GetNamespaceID(c)
	if userID == uuid.Nil || tenantID == uuid.Nil {
		c.JSON(http.StatusUnauthorized, gin.H{"error": "authentication required"})
		return
	}

	// Fall back to automatic detection when the caller supplied no hint.
	categories := body.DataCategories
	if len(categories) == 0 {
		for _, msg := range body.Messages {
			categories = append(categories, h.piiDetector.DetectDataCategories(msg.Content)...)
		}
	}

	// Run the request through the compliance access gate.
	gated, err := h.accessGate.ProcessChatRequest(
		c.Request.Context(),
		userID, tenantID, namespaceID,
		&llm.ChatRequest{
			Model:       body.Model,
			Messages:    body.Messages,
			MaxTokens:   body.MaxTokens,
			Temperature: body.Temperature,
		},
		categories,
	)
	if err != nil {
		// Policy refused the request: audit the denial, then report 403.
		h.logDeniedRequest(c, userID, tenantID, namespaceID, "chat", body.Model, err.Error())
		c.JSON(http.StatusForbidden, gin.H{
			"error":   "access_denied",
			"message": err.Error(),
		})
		return
	}

	// Execute, then audit both success and failure paths.
	resp, execErr := h.accessGate.ExecuteChat(c.Request.Context(), gated)
	h.logLLMRequest(c, gated.GatedRequest, "chat", resp, execErr)
	if execErr != nil {
		c.JSON(http.StatusInternalServerError, gin.H{
			"error":   "llm_error",
			"message": execErr.Error(),
		})
		return
	}

	c.JSON(http.StatusOK, gin.H{
		"id":            resp.ID,
		"model":         resp.Model,
		"provider":      resp.Provider,
		"message":       resp.Message,
		"finish_reason": resp.FinishReason,
		"usage":         resp.Usage,
		"pii_detected":  gated.PIIDetected,
		"pii_redacted":  gated.PromptRedacted,
	})
}
// CompletionRequest represents a text completion request.
type CompletionRequest struct {
	Model       string  `json:"model"`                     // requested model identifier
	Prompt      string  `json:"prompt" binding:"required"` // text to complete; required
	MaxTokens   int     `json:"max_tokens"`                // 0 means "use downstream default" — TODO confirm
	Temperature float64 `json:"temperature"`
	DataCategories []string `json:"data_categories"` // optional hint; auto-detected from the prompt when empty
}
// Complete handles text completion requests.
//
// Mirrors Chat: bind -> authenticate -> determine data categories -> access
// gate -> execute -> audit-log -> respond.
func (h *LLMHandlers) Complete(c *gin.Context) {
	var body CompletionRequest
	if err := c.ShouldBindJSON(&body); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}

	userID := rbac.GetUserID(c)
	tenantID := rbac.GetTenantID(c)
	namespaceID := rbac.GetNamespaceID(c)
	if userID == uuid.Nil || tenantID == uuid.Nil {
		c.JSON(http.StatusUnauthorized, gin.H{"error": "authentication required"})
		return
	}

	// Use the client's category hint, otherwise detect from the prompt.
	categories := body.DataCategories
	if len(categories) == 0 {
		categories = h.piiDetector.DetectDataCategories(body.Prompt)
	}

	// Run the request through the compliance access gate.
	gated, err := h.accessGate.ProcessCompletionRequest(
		c.Request.Context(),
		userID, tenantID, namespaceID,
		&llm.CompletionRequest{
			Model:       body.Model,
			Prompt:      body.Prompt,
			MaxTokens:   body.MaxTokens,
			Temperature: body.Temperature,
		},
		categories,
	)
	if err != nil {
		h.logDeniedRequest(c, userID, tenantID, namespaceID, "completion", body.Model, err.Error())
		c.JSON(http.StatusForbidden, gin.H{
			"error":   "access_denied",
			"message": err.Error(),
		})
		return
	}

	// Execute, then audit both success and failure paths.
	resp, execErr := h.accessGate.ExecuteCompletion(c.Request.Context(), gated)
	h.logLLMRequest(c, gated.GatedRequest, "completion", resp, execErr)
	if execErr != nil {
		c.JSON(http.StatusInternalServerError, gin.H{
			"error":   "llm_error",
			"message": execErr.Error(),
		})
		return
	}

	c.JSON(http.StatusOK, gin.H{
		"id":            resp.ID,
		"model":         resp.Model,
		"provider":      resp.Provider,
		"text":          resp.Text,
		"finish_reason": resp.FinishReason,
		"usage":         resp.Usage,
		"pii_detected":  gated.PIIDetected,
		"pii_redacted":  gated.PromptRedacted,
	})
}
// ListModels returns the models available across all registered providers.
func (h *LLMHandlers) ListModels(c *gin.Context) {
	available, err := h.registry.ListAllModels(c.Request.Context())
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}
	c.JSON(http.StatusOK, gin.H{"models": available})
}
// GetProviderStatus returns the current availability of the configured LLM
// providers (primary and fallback, when registered), keyed by provider name.
func (h *LLMHandlers) GetProviderStatus(c *gin.Context) {
	ctx := c.Request.Context()
	availability := map[string]bool{}
	if primary, ok := h.registry.GetPrimary(); ok {
		availability[primary.Name()] = primary.IsAvailable(ctx)
	}
	if fallback, ok := h.registry.GetFallback(); ok {
		availability[fallback.Name()] = fallback.IsAvailable(ctx)
	}
	c.JSON(http.StatusOK, gin.H{"providers": availability})
}
// AnalyzeText analyzes text for PII without making an LLM call.
//
// Responds with the individual PII findings, the inferred data categories,
// and a convenience boolean flag.
func (h *LLMHandlers) AnalyzeText(c *gin.Context) {
	var payload struct {
		Text string `json:"text" binding:"required"`
	}
	if err := c.ShouldBindJSON(&payload); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}

	findings := h.piiDetector.FindPII(payload.Text)
	c.JSON(http.StatusOK, gin.H{
		"contains_pii":    len(findings) > 0,
		"pii_findings":    findings,
		"data_categories": h.piiDetector.DetectDataCategories(payload.Text),
	})
}
// RedactText redacts PII from text.
//
// The optional "level" field selects the redaction strength ("strict",
// "moderate", "minimal", or "none"); anything unrecognized — including an
// absent field — falls back to strict, the safest default.
func (h *LLMHandlers) RedactText(c *gin.Context) {
	var payload struct {
		Text  string `json:"text" binding:"required"`
		Level string `json:"level"` // strict, moderate, minimal
	}
	if err := c.ShouldBindJSON(&payload); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}

	// Default to the strictest redaction unless a known weaker level is named.
	redactionLevel := rbac.PIIRedactionStrict
	switch payload.Level {
	case "none":
		redactionLevel = rbac.PIIRedactionNone
	case "minimal":
		redactionLevel = rbac.PIIRedactionMinimal
	case "moderate":
		redactionLevel = rbac.PIIRedactionModerate
	}

	c.JSON(http.StatusOK, gin.H{
		"original": payload.Text,
		"redacted": h.piiDetector.Redact(payload.Text, redactionLevel),
		"level":    redactionLevel,
	})
}
// logLLMRequest logs an LLM request to the audit trail.
//
// resp is either a *llm.ChatResponse or a *llm.CompletionResponse and may be
// a nil pointer when execution failed (the handlers call this with the
// execution error before checking it). err, when non-nil, is recorded on the
// entry. The entry is persisted asynchronously so the HTTP response is not
// delayed by audit storage.
func (h *LLMHandlers) logLLMRequest(c *gin.Context, gatedReq *llm.GatedRequest, operation string, resp any, err error) {
	entry := h.trailBuilder.NewLLMEntry().
		WithTenant(gatedReq.TenantID).
		WithUser(gatedReq.UserID).
		WithOperation(operation).
		WithPrompt(gatedReq.PromptHash, 0). // only the hash is available here; the raw prompt length is not tracked
		WithPII(gatedReq.PIIDetected, gatedReq.PIITypes, gatedReq.PromptRedacted)
	if gatedReq.NamespaceID != nil {
		entry.WithNamespace(*gatedReq.NamespaceID)
	}
	if gatedReq.Policy != nil {
		entry.WithPolicy(&gatedReq.Policy.ID, gatedReq.AccessResult.BlockedCategories)
	}
	// Add response data if available. A typed nil pointer still matches its
	// case in a type switch, so each case must nil-check before dereferencing;
	// without the guard every failed upstream call panics here.
	switch r := resp.(type) {
	case *llm.ChatResponse:
		if r != nil {
			entry.WithModel(r.Model, r.Provider).
				WithResponse(len(r.Message.Content)).
				WithUsage(r.Usage.TotalTokens, int(r.Duration.Milliseconds()))
		}
	case *llm.CompletionResponse:
		if r != nil {
			entry.WithModel(r.Model, r.Provider).
				WithResponse(len(r.Text)).
				WithUsage(r.Usage.TotalTokens, int(r.Duration.Milliseconds()))
		}
	}
	if err != nil {
		entry.WithError(err.Error())
	}
	// Capture client info synchronously: the gin context must not be touched
	// after the handler returns.
	entry.AddMetadata("ip_address", c.ClientIP()).
		AddMetadata("user_agent", c.GetHeader("User-Agent"))
	// Save asynchronously on a detached context. The request context is
	// canceled as soon as the handler returns, which would abort (or race)
	// the save and silently drop audit records.
	go func() {
		ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
		defer cancel()
		entry.Save(ctx) // best-effort: audit-save failures are not surfaced to the client
	}()
}
// logDeniedRequest logs a denied LLM request (policy refusal) to the audit
// trail. The entry is persisted asynchronously so the 403 response is not
// delayed by audit storage.
func (h *LLMHandlers) logDeniedRequest(c *gin.Context, userID, tenantID uuid.UUID, namespaceID *uuid.UUID, operation, model, reason string) {
	// Client info is captured synchronously: the gin context must not be
	// touched after the handler returns.
	entry := h.trailBuilder.NewLLMEntry().
		WithTenant(tenantID).
		WithUser(userID).
		WithOperation(operation).
		WithModel(model, "denied"). // the provider slot records the denial marker
		WithError("access_denied: " + reason).
		AddMetadata("ip_address", c.ClientIP()).
		AddMetadata("user_agent", c.GetHeader("User-Agent"))
	if namespaceID != nil {
		entry.WithNamespace(*namespaceID)
	}
	// Save on a detached context: the request context is canceled once the
	// handler returns, which would abort the save and lose the audit record.
	go func() {
		ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
		defer cancel()
		entry.Save(ctx) // best-effort: audit-save failures are not surfaced to the client
	}()
}