package handlers

import (
	"bytes"
	"encoding/json"
	"net/http"
	"net/http/httptest"
	"testing"

	"github.com/gin-gonic/gin"
)

// ragTestContext returns a gin test context whose request uses the given
// method and target URL, plus the recorder that captures the handler's
// response. When body is non-nil it is sent as the request payload with a
// JSON Content-Type header; when body is nil the request carries no body.
func ragTestContext(t *testing.T, method, target string, body []byte) (*gin.Context, *httptest.ResponseRecorder) {
	t.Helper()
	gin.SetMode(gin.TestMode)

	w := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(w)

	if body == nil {
		c.Request = httptest.NewRequest(method, target, nil)
	} else {
		c.Request = httptest.NewRequest(method, target, bytes.NewReader(body))
		c.Request.Header.Set("Content-Type", "application/json")
	}
	return c, w
}

// assertRAGStatus reports a (non-fatal) test error when the recorded status
// code differs from want.
func assertRAGStatus(t *testing.T, w *httptest.ResponseRecorder, want int) {
	t.Helper()
	if w.Code != want {
		t.Errorf("Expected %d, got %d", want, w.Code)
	}
}

// decodeRAGJSON decodes the recorded response body as a JSON object. It
// fails the test immediately if the body is not valid JSON, rather than
// silently continuing with an empty map.
func decodeRAGJSON(t *testing.T, w *httptest.ResponseRecorder) map[string]interface{} {
	t.Helper()
	var resp map[string]interface{}
	if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
		t.Fatalf("Failed to decode response body %q: %v", w.Body.String(), err)
	}
	return resp
}

// TestAllowedCollections pins the collection allow-list. Every listed known
// collection must be present, and unknown or empty names must be absent.
func TestAllowedCollections(t *testing.T) {
	allowed := []string{
		"bp_compliance_ce",
		"bp_compliance_gesetze",
		"bp_compliance_datenschutz",
		"bp_compliance_gdpr",
		"bp_dsfa_corpus",
		"bp_dsfa_templates",
		"bp_dsfa_risks",
		"bp_legal_templates",
		"bp_iace_libraries",
	}
	for _, name := range allowed {
		if !AllowedCollections[name] {
			// %q makes the offending name visible even when it is empty.
			t.Errorf("Expected %q to be in AllowedCollections", name)
		}
	}

	disallowed := []string{
		"bp_unknown",
		"",
		"some_random_collection",
	}
	for _, name := range disallowed {
		if AllowedCollections[name] {
			t.Errorf("Expected %q to NOT be in AllowedCollections", name)
		}
	}
}

// TestSearch_InvalidCollection_Returns400 checks that Search rejects a
// collection outside the allow-list with 400 and a non-empty string "error"
// field in the JSON response.
func TestSearch_InvalidCollection_Returns400(t *testing.T) {
	payload, err := json.Marshal(SearchRequest{
		Query:      "test query",
		Collection: "bp_evil_collection",
		TopK:       5,
	})
	if err != nil {
		t.Fatalf("Failed to marshal request: %v", err)
	}

	c, w := ragTestContext(t, http.MethodPost, "/sdk/v1/rag/search", payload)

	// The handler is zero-valued: this test relies on the collection being
	// rejected before any of its dependencies are used.
	handler := &RAGHandlers{}
	handler.Search(c)

	assertRAGStatus(t, w, http.StatusBadRequest)
	resp := decodeRAGJSON(t, w)
	if errMsg, ok := resp["error"].(string); !ok || errMsg == "" {
		t.Error("Expected error message in response")
	}
}

// TestSearch_WithCollectionParam_BindsCorrectly checks that the JSON field
// names "query", "collection" and "top_k" decode into SearchRequest. Only
// decoding is exercised here; allow-list validation is not involved.
func TestSearch_WithCollectionParam_BindsCorrectly(t *testing.T) {
	body := `{"query":"DSGVO Art. 35","collection":"bp_compliance_recht","top_k":3}`

	var req SearchRequest
	if err := json.Unmarshal([]byte(body), &req); err != nil {
		t.Fatalf("Failed to unmarshal: %v", err)
	}

	if want := "DSGVO Art. 35"; req.Query != want {
		t.Errorf("Expected query %q, got %q", want, req.Query)
	}
	if want := "bp_compliance_recht"; req.Collection != want {
		t.Errorf("Expected collection %q, got %q", want, req.Collection)
	}
	if req.TopK != 3 {
		t.Errorf("Expected top_k 3, got %d", req.TopK)
	}
}

// TestHandleScrollChunks_MissingCollection_Returns400 checks that a scroll
// request without a collection query parameter is rejected with 400 and a
// JSON body that has an "error" field.
func TestHandleScrollChunks_MissingCollection_Returns400(t *testing.T) {
	c, w := ragTestContext(t, http.MethodGet, "/sdk/v1/rag/scroll", nil)

	handler := &RAGHandlers{}
	handler.HandleScrollChunks(c)

	assertRAGStatus(t, w, http.StatusBadRequest)
	if resp := decodeRAGJSON(t, w); resp["error"] == nil {
		t.Error("Expected error message in response")
	}
}

// TestHandleScrollChunks_InvalidCollection_Returns400 checks that a scroll
// request naming a collection outside the allow-list is rejected with 400.
func TestHandleScrollChunks_InvalidCollection_Returns400(t *testing.T) {
	c, w := ragTestContext(t, http.MethodGet, "/sdk/v1/rag/scroll?collection=bp_evil_collection", nil)

	handler := &RAGHandlers{}
	handler.HandleScrollChunks(c)

	assertRAGStatus(t, w, http.StatusBadRequest)
}

// TestHandleScrollChunks_InvalidLimit_Returns400 checks that a non-numeric
// limit is rejected with 400 even when the collection is valid.
func TestHandleScrollChunks_InvalidLimit_Returns400(t *testing.T) {
	c, w := ragTestContext(t, http.MethodGet, "/sdk/v1/rag/scroll?collection=bp_compliance_ce&limit=abc", nil)

	handler := &RAGHandlers{}
	handler.HandleScrollChunks(c)

	assertRAGStatus(t, w, http.StatusBadRequest)
}

// TestHandleScrollChunks_NegativeLimit_Returns400 checks that a negative
// limit is rejected with 400 even when the collection is valid.
func TestHandleScrollChunks_NegativeLimit_Returns400(t *testing.T) {
	c, w := ragTestContext(t, http.MethodGet, "/sdk/v1/rag/scroll?collection=bp_compliance_ce&limit=-5", nil)

	handler := &RAGHandlers{}
	handler.HandleScrollChunks(c)

	assertRAGStatus(t, w, http.StatusBadRequest)
}

// TestSearch_EmptyCollection_IsAllowed checks that a request without a
// "collection" field decodes to an empty Collection.
//
// Only JSON decoding is exercised here. "" is intentionally absent from
// AllowedCollections (see TestAllowedCollections). The Search handler is
// therefore expected to skip allow-list validation for an empty collection
// and fall back to a default, but that handler path is not covered by this
// test.
func TestSearch_EmptyCollection_IsAllowed(t *testing.T) {
	body := `{"query":"test"}`

	var req SearchRequest
	if err := json.Unmarshal([]byte(body), &req); err != nil {
		t.Fatalf("Failed to unmarshal: %v", err)
	}

	if req.Collection != "" {
		t.Errorf("Expected empty collection, got %q", req.Collection)
	}
}