package api

import (
	"encoding/json"
	"errors"
	"net/http"
	"strconv"
	"strings"
	"time"

	"github.com/breakpilot/ai-compliance-sdk/internal/db"
	"github.com/gin-gonic/gin"
)

// StateHandler handles state management requests.
//
// Requests are served from dbPool when it is non-nil, and from memStore
// otherwise.
type StateHandler struct {
	dbPool   *db.Pool
	memStore *db.InMemoryStore
}

// NewStateHandler creates a new state handler.
//
// dbPool may be nil; the handler then serves every request from the
// in-memory store created here.
func NewStateHandler(dbPool *db.Pool) *StateHandler {
	return &StateHandler{
		dbPool:   dbPool,
		memStore: db.NewInMemoryStore(),
	}
}

// GetState retrieves state for a tenant.
//
// The tenant is taken from the "tenantId" route parameter. Responses:
// 400 if it is empty, 404 if no state could be loaded, 304 if the request's
// If-None-Match matches the current ETag, otherwise 200 with StateData.
func (h *StateHandler) GetState(c *gin.Context) {
	tenantID := c.Param("tenantId")
	if tenantID == "" {
		ErrorResponse(c, http.StatusBadRequest, "tenantId is required", "MISSING_TENANT_ID")
		return
	}

	var state *db.SDKState
	var err error
	// Use the database when configured, otherwise the in-memory store.
	if h.dbPool != nil {
		state, err = h.dbPool.GetState(c.Request.Context(), tenantID)
	} else {
		state, err = h.memStore.GetState(tenantID)
	}
	if err != nil || state == nil {
		// NOTE(review): every load error, including backend failures, is
		// reported as 404; if db exposes a not-found sentinel, real
		// failures could be mapped to 500 instead.
		ErrorResponse(c, http.StatusNotFound, "State not found", "STATE_NOT_FOUND")
		return
	}

	// Headers are set before the conditional check so that a 304 carries
	// the same validators and caching directives as a 200 would.
	etag := setStateValidators(c, state)
	c.Header("Cache-Control", "private, no-cache")

	if ifNoneMatchHits(c.GetHeader("If-None-Match"), etag) {
		c.Status(http.StatusNotModified)
		return
	}

	SuccessResponse(c, toStateData(state))
}

// SaveState saves state for a tenant.
//
// The JSON body must contain tenantId and state; userId and version are
// optional. The expected version for optimistic locking is taken from the
// If-Match header when it yields a version (either a bare integer or an
// ETag produced by this handler), otherwise from the body's version.
// Responses: 400 for an invalid body, 409 on a version conflict, 500 on
// any other save failure, otherwise 200 with the saved StateData.
func (h *StateHandler) SaveState(c *gin.Context) {
	var req struct {
		TenantID string          `json:"tenantId" binding:"required"`
		UserID   string          `json:"userId"`
		State    json.RawMessage `json:"state" binding:"required"`
		Version  *int            `json:"version"`
	}
	if err := c.ShouldBindJSON(&req); err != nil {
		ErrorResponse(c, http.StatusBadRequest, err.Error(), "INVALID_REQUEST")
		return
	}

	// Clients normally echo the ETag back in If-Match, so the version must
	// be extracted from that format. If the header gives no usable version,
	// fall back to the body rather than silently dropping the lock check.
	expectedVersion := req.Version
	if v, ok := parseIfMatchVersion(c.GetHeader("If-Match")); ok {
		expectedVersion = &v
	}

	var state *db.SDKState
	var err error
	// Use the database when configured, otherwise the in-memory store.
	if h.dbPool != nil {
		state, err = h.dbPool.SaveState(c.Request.Context(), req.TenantID, req.UserID, req.State, expectedVersion)
	} else {
		state, err = h.memStore.SaveState(req.TenantID, req.UserID, req.State, expectedVersion)
	}
	switch {
	case isVersionConflict(err):
		ErrorResponse(c, http.StatusConflict, "Version conflict. State was modified by another request.", "VERSION_CONFLICT")
		return
	case err != nil || state == nil:
		ErrorResponse(c, http.StatusInternalServerError, "Failed to save state", "SAVE_FAILED")
		return
	}

	setStateValidators(c, state)
	SuccessResponse(c, toStateData(state))
}

// DeleteState deletes state for a tenant.
//
// The tenant is taken from the "tenantId" route parameter. Responses: 400 if
// it is empty, 500 if the delete fails, otherwise 200 with the tenant ID and
// the deletion time reported by now().
func (h *StateHandler) DeleteState(c *gin.Context) {
	tenantID := c.Param("tenantId")
	if tenantID == "" {
		ErrorResponse(c, http.StatusBadRequest, "tenantId is required", "MISSING_TENANT_ID")
		return
	}

	var err error
	// Use the database when configured, otherwise the in-memory store.
	if h.dbPool != nil {
		err = h.dbPool.DeleteState(c.Request.Context(), tenantID)
	} else {
		err = h.memStore.DeleteState(tenantID)
	}
	if err != nil {
		ErrorResponse(c, http.StatusInternalServerError, "Failed to delete state", "DELETE_FAILED")
		return
	}

	SuccessResponse(c, gin.H{
		"tenantId":  tenantID,
		"deletedAt": now(),
	})
}

// setStateValidators sets the ETag and Last-Modified response headers for
// state and returns the ETag it set.
func setStateValidators(c *gin.Context, state *db.SDKState) string {
	etag := generateETag(state.Version, state.UpdatedAt.String())
	c.Header("ETag", etag)
	// http.TimeFormat ends in a literal "GMT", so the instant must be
	// converted to UTC first or the header misstates the time.
	c.Header("Last-Modified", state.UpdatedAt.UTC().Format(http.TimeFormat))
	return etag
}

// toStateData builds the response payload for state. The stored state is
// decoded as JSON when possible; otherwise the raw stored value is returned.
func toStateData(state *db.SDKState) StateData {
	var payload interface{}
	if err := json.Unmarshal(state.State, &payload); err != nil {
		payload = state.State
	}
	return StateData{
		TenantID:     state.TenantID,
		State:        payload,
		Version:      state.Version,
		LastModified: state.UpdatedAt.Format(time.RFC3339),
	}
}

// ifNoneMatchHits reports whether an If-None-Match header value matches
// etag. It uses weak comparison: "*" matches, and list members are compared
// with any "W/" prefix removed. An empty header never matches.
func ifNoneMatchHits(header, etag string) bool {
	if header == "" {
		return false
	}
	want := strings.TrimPrefix(etag, "W/")
	for _, candidate := range strings.Split(header, ",") {
		candidate = strings.TrimSpace(candidate)
		if candidate == "*" || strings.TrimPrefix(candidate, "W/") == want {
			return true
		}
	}
	return false
}

// parseIfMatchVersion extracts the expected state version from an If-Match
// header value. It accepts a bare integer ("3") or a quoted tag in the
// generateETag format (`"3-2024-01-"`). ok is false for an empty header,
// "*", weak tags, lists of tags, and anything else it cannot parse.
func parseIfMatchVersion(header string) (version int, ok bool) {
	tag := strings.TrimSpace(header)
	if strings.Contains(tag, ",") {
		return 0, false
	}
	if len(tag) >= 2 && tag[0] == '"' && tag[len(tag)-1] == '"' {
		// Quoted ETag: the version is everything before the first '-'.
		tag, _, _ = strings.Cut(tag[1:len(tag)-1], "-")
	}
	v, err := strconv.Atoi(tag)
	if err != nil {
		return 0, false
	}
	return v, true
}

// isVersionConflict reports whether err, or any error it wraps, has the
// "version conflict" message the storage backends use to signal a failed
// optimistic-locking check. It returns false for a nil error.
//
// TODO: switch to errors.Is once db exports a sentinel for this condition.
func isVersionConflict(err error) bool {
	for ; err != nil; err = errors.Unwrap(err) {
		if err.Error() == "version conflict" {
			return true
		}
	}
	return false
}

// generateETag creates a strong ETag of the form "<version>-<prefix>", where
// prefix is at most the first 8 bytes of timestamp.
//
// NOTE(review): if timestamp comes from time.Time.String(), those 8 bytes
// are the "YYYY-MM-" date prefix, so the tag effectively tracks only the
// version (and month).
func generateETag(version int, timestamp string) string {
	if len(timestamp) > 8 {
		timestamp = timestamp[:8]
	}
	return `"` + strconv.Itoa(version) + "-" + timestamp + `"`
}