package handlers

import (
	"context"
	"net/http"

	"github.com/breakpilot/ai-compliance-sdk/internal/audit"
	"github.com/breakpilot/ai-compliance-sdk/internal/llm"
	"github.com/breakpilot/ai-compliance-sdk/internal/rbac"
	"github.com/gin-gonic/gin"
	"github.com/google/uuid"
)

// LLMHandlers handles LLM-related API endpoints.
type LLMHandlers struct {
	accessGate   *llm.AccessGate        // policy gate every request passes through
	registry     *llm.ProviderRegistry  // primary/fallback provider lookup
	piiDetector  *llm.PIIDetector       // PII detection and redaction
	auditStore   *audit.Store           // persistent audit storage
	trailBuilder *audit.TrailBuilder    // fluent builder for audit entries
}

// NewLLMHandlers creates new LLM handlers.
func NewLLMHandlers(
	accessGate *llm.AccessGate,
	registry *llm.ProviderRegistry,
	piiDetector *llm.PIIDetector,
	auditStore *audit.Store,
	trailBuilder *audit.TrailBuilder,
) *LLMHandlers {
	return &LLMHandlers{
		accessGate:   accessGate,
		registry:     registry,
		piiDetector:  piiDetector,
		auditStore:   auditStore,
		trailBuilder: trailBuilder,
	}
}

// ChatRequest represents a chat completion request.
type ChatRequest struct {
	Model          string        `json:"model"`
	Messages       []llm.Message `json:"messages" binding:"required"`
	MaxTokens      int           `json:"max_tokens"`
	Temperature    float64       `json:"temperature"`
	DataCategories []string      `json:"data_categories"` // Optional hint about data types
}

// Chat handles chat completion requests.
//
// Flow: bind + authenticate, detect data categories (if the client did not
// supply them), run the request through the access gate, execute against the
// provider, and audit-log the outcome. Denied requests get 403, provider
// failures 500.
func (h *LLMHandlers) Chat(c *gin.Context) {
	var req ChatRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}

	userID := rbac.GetUserID(c)
	tenantID := rbac.GetTenantID(c)
	namespaceID := rbac.GetNamespaceID(c)

	if userID == uuid.Nil || tenantID == uuid.Nil {
		c.JSON(http.StatusUnauthorized, gin.H{"error": "authentication required"})
		return
	}

	// Detect data categories from messages if not provided.
	// Duplicates across messages are tolerated; the gate treats the list as a set.
	dataCategories := req.DataCategories
	if len(dataCategories) == 0 {
		for _, msg := range req.Messages {
			detected := h.piiDetector.DetectDataCategories(msg.Content)
			dataCategories = append(dataCategories, detected...)
		}
	}

	// Process through access gate.
	chatReq := &llm.ChatRequest{
		Model:       req.Model,
		Messages:    req.Messages,
		MaxTokens:   req.MaxTokens,
		Temperature: req.Temperature,
	}

	gatedReq, err := h.accessGate.ProcessChatRequest(
		c.Request.Context(),
		userID, tenantID, namespaceID,
		chatReq, dataCategories,
	)
	if err != nil {
		// Log denied request before responding so the denial is always audited.
		h.logDeniedRequest(c, userID, tenantID, namespaceID, "chat", req.Model, err.Error())
		c.JSON(http.StatusForbidden, gin.H{
			"error":   "access_denied",
			"message": err.Error(),
		})
		return
	}

	// Execute the request.
	resp, err := h.accessGate.ExecuteChat(c.Request.Context(), gatedReq)

	// Log the request (resp may be nil on error; logLLMRequest handles that).
	h.logLLMRequest(c, gatedReq.GatedRequest, "chat", resp, err)

	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{
			"error":   "llm_error",
			"message": err.Error(),
		})
		return
	}

	c.JSON(http.StatusOK, gin.H{
		"id":            resp.ID,
		"model":         resp.Model,
		"provider":      resp.Provider,
		"message":       resp.Message,
		"finish_reason": resp.FinishReason,
		"usage":         resp.Usage,
		"pii_detected":  gatedReq.PIIDetected,
		"pii_redacted":  gatedReq.PromptRedacted,
	})
}

// CompletionRequest represents a text completion request.
type CompletionRequest struct {
	Model          string   `json:"model"`
	Prompt         string   `json:"prompt" binding:"required"`
	MaxTokens      int      `json:"max_tokens"`
	Temperature    float64  `json:"temperature"`
	DataCategories []string `json:"data_categories"`
}

// Complete handles text completion requests.
//
// Mirrors Chat: bind + authenticate, detect data categories from the prompt
// when not supplied, gate, execute, and audit-log.
func (h *LLMHandlers) Complete(c *gin.Context) {
	var req CompletionRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}

	userID := rbac.GetUserID(c)
	tenantID := rbac.GetTenantID(c)
	namespaceID := rbac.GetNamespaceID(c)

	if userID == uuid.Nil || tenantID == uuid.Nil {
		c.JSON(http.StatusUnauthorized, gin.H{"error": "authentication required"})
		return
	}

	// Detect data categories from prompt if not provided.
	dataCategories := req.DataCategories
	if len(dataCategories) == 0 {
		dataCategories = h.piiDetector.DetectDataCategories(req.Prompt)
	}

	// Process through access gate.
	completionReq := &llm.CompletionRequest{
		Model:       req.Model,
		Prompt:      req.Prompt,
		MaxTokens:   req.MaxTokens,
		Temperature: req.Temperature,
	}

	gatedReq, err := h.accessGate.ProcessCompletionRequest(
		c.Request.Context(),
		userID, tenantID, namespaceID,
		completionReq, dataCategories,
	)
	if err != nil {
		h.logDeniedRequest(c, userID, tenantID, namespaceID, "completion", req.Model, err.Error())
		c.JSON(http.StatusForbidden, gin.H{
			"error":   "access_denied",
			"message": err.Error(),
		})
		return
	}

	// Execute the request.
	resp, err := h.accessGate.ExecuteCompletion(c.Request.Context(), gatedReq)

	// Log the request (resp may be nil on error; logLLMRequest handles that).
	h.logLLMRequest(c, gatedReq.GatedRequest, "completion", resp, err)

	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{
			"error":   "llm_error",
			"message": err.Error(),
		})
		return
	}

	c.JSON(http.StatusOK, gin.H{
		"id":            resp.ID,
		"model":         resp.Model,
		"provider":      resp.Provider,
		"text":          resp.Text,
		"finish_reason": resp.FinishReason,
		"usage":         resp.Usage,
		"pii_detected":  gatedReq.PIIDetected,
		"pii_redacted":  gatedReq.PromptRedacted,
	})
}

// ListModels returns available models.
func (h *LLMHandlers) ListModels(c *gin.Context) {
	models, err := h.registry.ListAllModels(c.Request.Context())
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}
	c.JSON(http.StatusOK, gin.H{"models": models})
}

// GetProviderStatus returns the status of LLM providers.
// Reports availability keyed by provider name for the registry's primary and
// fallback providers (whichever are configured).
func (h *LLMHandlers) GetProviderStatus(c *gin.Context) {
	ctx := c.Request.Context()
	statuses := make(map[string]bool)

	if p, ok := h.registry.GetPrimary(); ok {
		statuses[p.Name()] = p.IsAvailable(ctx)
	}
	if p, ok := h.registry.GetFallback(); ok {
		statuses[p.Name()] = p.IsAvailable(ctx)
	}

	c.JSON(http.StatusOK, gin.H{"providers": statuses})
}

// AnalyzeText analyzes text for PII without making an LLM call.
func (h *LLMHandlers) AnalyzeText(c *gin.Context) {
	var req struct {
		Text string `json:"text" binding:"required"`
	}
	if err := c.ShouldBindJSON(&req); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}

	findings := h.piiDetector.FindPII(req.Text)
	categories := h.piiDetector.DetectDataCategories(req.Text)
	containsPII := len(findings) > 0

	c.JSON(http.StatusOK, gin.H{
		"contains_pii":    containsPII,
		"pii_findings":    findings,
		"data_categories": categories,
	})
}

// RedactText redacts PII from text.
// The optional "level" field selects a redaction level; unrecognized or
// missing values default to strict (the safest behavior for a compliance API).
func (h *LLMHandlers) RedactText(c *gin.Context) {
	var req struct {
		Text  string `json:"text" binding:"required"`
		Level string `json:"level"` // strict, moderate, minimal
	}
	if err := c.ShouldBindJSON(&req); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}

	level := rbac.PIIRedactionStrict
	switch req.Level {
	case "moderate":
		level = rbac.PIIRedactionModerate
	case "minimal":
		level = rbac.PIIRedactionMinimal
	case "none":
		level = rbac.PIIRedactionNone
	}

	redacted := h.piiDetector.Redact(req.Text, level)

	c.JSON(http.StatusOK, gin.H{
		"original": req.Text,
		"redacted": redacted,
		"level":    level,
	})
}

// logLLMRequest logs an LLM request to the audit trail.
// resp may be nil (or a non-matching type) when the provider call failed; in
// that case the model/usage fields are simply omitted from the entry.
func (h *LLMHandlers) logLLMRequest(c *gin.Context, gatedReq *llm.GatedRequest, operation string, resp any, err error) {
	entry := h.trailBuilder.NewLLMEntry().
		WithTenant(gatedReq.TenantID).
		WithUser(gatedReq.UserID).
		WithOperation(operation).
		// NOTE(review): prompt length is not available post-gating, so 0 is
		// recorded here — confirm whether TrailBuilder expects a real length.
		WithPrompt(gatedReq.PromptHash, 0).
		WithPII(gatedReq.PIIDetected, gatedReq.PIITypes, gatedReq.PromptRedacted)

	if gatedReq.NamespaceID != nil {
		entry.WithNamespace(*gatedReq.NamespaceID)
	}

	if gatedReq.Policy != nil {
		entry.WithPolicy(&gatedReq.Policy.ID, gatedReq.AccessResult.BlockedCategories)
	}

	// Add response data if available.
	switch r := resp.(type) {
	case *llm.ChatResponse:
		entry.WithModel(r.Model, r.Provider).
			WithResponse(len(r.Message.Content)).
			WithUsage(r.Usage.TotalTokens, int(r.Duration.Milliseconds()))
	case *llm.CompletionResponse:
		entry.WithModel(r.Model, r.Provider).
			WithResponse(len(r.Text)).
			WithUsage(r.Usage.TotalTokens, int(r.Duration.Milliseconds()))
	}

	if err != nil {
		entry.WithError(err.Error())
	}

	// Add client info (read synchronously — gin.Context must not be touched
	// from the goroutine below).
	entry.AddMetadata("ip_address", c.ClientIP()).
		AddMetadata("user_agent", c.GetHeader("User-Agent"))

	// Save asynchronously. Use a background context, NOT the request context:
	// the request context is canceled as soon as the handler returns, which
	// would abort the save and silently drop the audit entry.
	go func() {
		// Best-effort: there is no logger in scope to report a save failure.
		_ = entry.Save(context.Background())
	}()
}

// logDeniedRequest logs a denied LLM request.
func (h *LLMHandlers) logDeniedRequest(c *gin.Context, userID, tenantID uuid.UUID, namespaceID *uuid.UUID, operation, model, reason string) {
	entry := h.trailBuilder.NewLLMEntry().
		WithTenant(tenantID).
		WithUser(userID).
		WithOperation(operation).
		WithModel(model, "denied").
		WithError("access_denied: " + reason).
		AddMetadata("ip_address", c.ClientIP()).
		AddMetadata("user_agent", c.GetHeader("User-Agent"))

	if namespaceID != nil {
		entry.WithNamespace(*namespaceID)
	}

	// Detached context for the same reason as logLLMRequest: the request
	// context dies with the handler and would cancel the audit write.
	go func() {
		_ = entry.Save(context.Background())
	}()
}