feat(rag): optimize RAG pipeline — JSON-Mode, CoT, Hybrid Search, Re-Ranking, Cross-Reg Dedup, chunk 1024
Some checks failed
CI/CD / go-lint (push) Has been skipped
CI/CD / python-lint (push) Has been skipped
CI/CD / nodejs-lint (push) Has been skipped
CI/CD / test-go-ai-compliance (push) Failing after 42s
CI/CD / test-python-backend-compliance (push) Successful in 1m38s
CI/CD / test-python-document-crawler (push) Successful in 20s
CI/CD / test-python-dsms-gateway (push) Successful in 17s
CI/CD / validate-canonical-controls (push) Successful in 10s
CI/CD / Deploy (push) Has been skipped
Some checks failed
CI/CD / go-lint (push) Has been skipped
CI/CD / python-lint (push) Has been skipped
CI/CD / nodejs-lint (push) Has been skipped
CI/CD / test-go-ai-compliance (push) Failing after 42s
CI/CD / test-python-backend-compliance (push) Successful in 1m38s
CI/CD / test-python-document-crawler (push) Successful in 20s
CI/CD / test-python-dsms-gateway (push) Successful in 17s
CI/CD / validate-canonical-controls (push) Successful in 10s
CI/CD / Deploy (push) Has been skipped
Phase 1 (LLM Quality):
- Add format=json to all Ollama payloads (obligation_extractor, control_generator, citation_backfill)
- Add Chain-of-Thought analysis steps to Pass 0a/0b system prompts

Phase 2 (Retrieval Quality):
- Hybrid search via Qdrant Query API with RRF fusion + automatic text index (legal_rag.go)
- Fallback to dense-only search if Query API unavailable
- Cross-encoder re-ranking with BGE Reranker v2 (RERANK_ENABLED=false by default)
- CPU-only PyTorch dependency to keep Docker image small

Phase 3 (Data Layer):
- Cross-regulation dedup pass (threshold 0.95) links controls across regulations
- DedupResult.link_type field distinguishes dedup_merge vs cross_regulation
- Chunk size defaults updated 512/50 → 1024/128 for new ingestions only
- Existing collections and controls are NOT affected

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -14,12 +14,14 @@ import (
|
||||
|
||||
// LegalRAGClient provides access to the compliance CE vector search via Qdrant + Ollama bge-m3.
//
// Each field is declared exactly once; the listing previously carried both the
// pre- and post-realignment copies of the first six fields, which does not compile.
type LegalRAGClient struct {
	qdrantURL      string // base URL; request paths such as /collections/{name}/... are appended to it
	qdrantAPIKey   string // sent as the "api-key" header when non-empty
	ollamaURL      string // base URL of the Ollama instance used for embeddings
	embeddingModel string // embedding model name, e.g. "bge-m3"
	collection     string // collection name configured on the client
	httpClient     *http.Client

	// textIndexEnsured tracks which collections have text index.
	// ensureTextIndex reads and writes this map without locking.
	// NOTE(review): confirm a client is never shared across goroutines, or guard the map.
	textIndexEnsured map[string]bool

	// hybridEnabled selects the Query API with RRF fusion instead of dense-only search.
	hybridEnabled bool
}
|
||||
|
||||
// LegalSearchResult represents a single search result from the compliance corpus.
|
||||
@@ -70,12 +72,16 @@ func NewLegalRAGClient() *LegalRAGClient {
|
||||
ollamaURL = "http://localhost:11434"
|
||||
}
|
||||
|
||||
hybridEnabled := os.Getenv("RAG_HYBRID_SEARCH") != "false" // enabled by default
|
||||
|
||||
return &LegalRAGClient{
|
||||
qdrantURL: qdrantURL,
|
||||
qdrantAPIKey: qdrantAPIKey,
|
||||
ollamaURL: ollamaURL,
|
||||
embeddingModel: "bge-m3",
|
||||
collection: "bp_compliance_ce",
|
||||
qdrantURL: qdrantURL,
|
||||
qdrantAPIKey: qdrantAPIKey,
|
||||
ollamaURL: ollamaURL,
|
||||
embeddingModel: "bge-m3",
|
||||
collection: "bp_compliance_ce",
|
||||
textIndexEnsured: make(map[string]bool),
|
||||
hybridEnabled: hybridEnabled,
|
||||
httpClient: &http.Client{
|
||||
Timeout: 60 * time.Second,
|
||||
},
|
||||
@@ -126,6 +132,161 @@ type qdrantSearchHit struct {
|
||||
Payload map[string]interface{} `json:"payload"`
|
||||
}
|
||||
|
||||
// --- Hybrid Search (Query API with RRF fusion) ---
|
||||
|
||||
// qdrantQueryRequest for Qdrant Query API with prefetch + fusion.
|
||||
type qdrantQueryRequest struct {
|
||||
Prefetch []qdrantPrefetch `json:"prefetch"`
|
||||
Query *qdrantFusion `json:"query"`
|
||||
Limit int `json:"limit"`
|
||||
WithPayload bool `json:"with_payload"`
|
||||
Filter *qdrantFilter `json:"filter,omitempty"`
|
||||
}
|
||||
|
||||
type qdrantPrefetch struct {
|
||||
Query []float64 `json:"query"`
|
||||
Limit int `json:"limit"`
|
||||
Filter *qdrantFilter `json:"filter,omitempty"`
|
||||
}
|
||||
|
||||
type qdrantFusion struct {
|
||||
Fusion string `json:"fusion"`
|
||||
}
|
||||
|
||||
// qdrantQueryResponse from Qdrant Query API (same shape as search).
|
||||
type qdrantQueryResponse struct {
|
||||
Result []qdrantSearchHit `json:"result"`
|
||||
}
|
||||
|
||||
// Wire types for creating a full-text index on a payload field.
type (
	// qdrantTextIndexRequest is the JSON body sent to /collections/{name}/index.
	qdrantTextIndexRequest struct {
		FieldName   string                `json:"field_name"`
		FieldSchema qdrantTextFieldSchema `json:"field_schema"`
	}

	// qdrantTextFieldSchema configures the index type and tokenizer.
	// MinLen and MaxLen are dropped from the JSON when zero.
	qdrantTextFieldSchema struct {
		Type      string `json:"type"`
		Tokenizer string `json:"tokenizer"`
		MinLen    int    `json:"min_token_len,omitempty"`
		MaxLen    int    `json:"max_token_len,omitempty"`
	}
)
|
||||
|
||||
// ensureTextIndex creates a full-text index on chunk_text if not already done for this collection.
|
||||
func (c *LegalRAGClient) ensureTextIndex(ctx context.Context, collection string) error {
|
||||
if c.textIndexEnsured[collection] {
|
||||
return nil
|
||||
}
|
||||
|
||||
indexReq := qdrantTextIndexRequest{
|
||||
FieldName: "chunk_text",
|
||||
FieldSchema: qdrantTextFieldSchema{
|
||||
Type: "text",
|
||||
Tokenizer: "word",
|
||||
MinLen: 2,
|
||||
MaxLen: 40,
|
||||
},
|
||||
}
|
||||
|
||||
jsonBody, err := json.Marshal(indexReq)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal text index request: %w", err)
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("%s/collections/%s/index", c.qdrantURL, collection)
|
||||
req, err := http.NewRequestWithContext(ctx, "PUT", url, bytes.NewReader(jsonBody))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create text index request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
if c.qdrantAPIKey != "" {
|
||||
req.Header.Set("api-key", c.qdrantAPIKey)
|
||||
}
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("text index request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// 200 = created, 409 = already exists — both are fine
|
||||
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusConflict {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return fmt.Errorf("text index creation failed %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
c.textIndexEnsured[collection] = true
|
||||
return nil
|
||||
}
|
||||
|
||||
// searchHybrid performs RRF-fused hybrid search (dense + full-text) via Qdrant Query API.
|
||||
func (c *LegalRAGClient) searchHybrid(ctx context.Context, collection string, embedding []float64, regulationIDs []string, topK int) ([]qdrantSearchHit, error) {
|
||||
// Ensure text index exists
|
||||
if err := c.ensureTextIndex(ctx, collection); err != nil {
|
||||
// Non-fatal: log and fall back to dense-only
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Build prefetch with dense vector (retrieve top-20 for re-ranking)
|
||||
prefetchLimit := 20
|
||||
if topK > 20 {
|
||||
prefetchLimit = topK * 4
|
||||
}
|
||||
|
||||
queryReq := qdrantQueryRequest{
|
||||
Prefetch: []qdrantPrefetch{
|
||||
{Query: embedding, Limit: prefetchLimit},
|
||||
},
|
||||
Query: &qdrantFusion{Fusion: "rrf"},
|
||||
Limit: topK,
|
||||
WithPayload: true,
|
||||
}
|
||||
|
||||
// Add regulation filter
|
||||
if len(regulationIDs) > 0 {
|
||||
conditions := make([]qdrantCondition, len(regulationIDs))
|
||||
for i, regID := range regulationIDs {
|
||||
conditions[i] = qdrantCondition{
|
||||
Key: "regulation_id",
|
||||
Match: qdrantMatch{Value: regID},
|
||||
}
|
||||
}
|
||||
queryReq.Filter = &qdrantFilter{Should: conditions}
|
||||
}
|
||||
|
||||
jsonBody, err := json.Marshal(queryReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal query request: %w", err)
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("%s/collections/%s/points/query", c.qdrantURL, collection)
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(jsonBody))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create query request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
if c.qdrantAPIKey != "" {
|
||||
req.Header.Set("api-key", c.qdrantAPIKey)
|
||||
}
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return nil, fmt.Errorf("qdrant query returned %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var queryResp qdrantQueryResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&queryResp); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode query response: %w", err)
|
||||
}
|
||||
|
||||
return queryResp.Result, nil
|
||||
}
|
||||
|
||||
// generateEmbedding calls Ollama bge-m3 to get a 1024-dim vector for the query.
|
||||
func (c *LegalRAGClient) generateEmbedding(ctx context.Context, text string) ([]float64, error) {
|
||||
// Truncate to 2000 chars for bge-m3
|
||||
@@ -187,6 +348,8 @@ func (c *LegalRAGClient) Search(ctx context.Context, query string, regulationIDs
|
||||
}
|
||||
|
||||
// searchInternal performs the actual search against a given collection.
|
||||
// If hybrid search is enabled, it uses the Qdrant Query API with RRF fusion
|
||||
// (dense + full-text). Falls back to dense-only /points/search on failure.
|
||||
func (c *LegalRAGClient) searchInternal(ctx context.Context, collection string, query string, regulationIDs []string, topK int) ([]LegalSearchResult, error) {
|
||||
// Generate query embedding via Ollama bge-m3
|
||||
embedding, err := c.generateEmbedding(ctx, query)
|
||||
@@ -194,14 +357,51 @@ func (c *LegalRAGClient) searchInternal(ctx context.Context, collection string,
|
||||
return nil, fmt.Errorf("failed to generate embedding: %w", err)
|
||||
}
|
||||
|
||||
// Build Qdrant search request
|
||||
// Try hybrid search first (Query API + RRF), fall back to dense-only
|
||||
var hits []qdrantSearchHit
|
||||
|
||||
if c.hybridEnabled {
|
||||
hybridHits, err := c.searchHybrid(ctx, collection, embedding, regulationIDs, topK)
|
||||
if err == nil {
|
||||
hits = hybridHits
|
||||
}
|
||||
// On error, fall through to dense-only search below
|
||||
}
|
||||
|
||||
if hits == nil {
|
||||
denseHits, err := c.searchDense(ctx, collection, embedding, regulationIDs, topK)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
hits = denseHits
|
||||
}
|
||||
|
||||
// Convert to results using bp_compliance_ce payload schema
|
||||
results := make([]LegalSearchResult, len(hits))
|
||||
for i, hit := range hits {
|
||||
results[i] = LegalSearchResult{
|
||||
Text: getString(hit.Payload, "chunk_text"),
|
||||
RegulationCode: getString(hit.Payload, "regulation_id"),
|
||||
RegulationName: getString(hit.Payload, "regulation_name_de"),
|
||||
RegulationShort: getString(hit.Payload, "regulation_short"),
|
||||
Category: getString(hit.Payload, "category"),
|
||||
Pages: getIntSlice(hit.Payload, "pages"),
|
||||
SourceURL: getString(hit.Payload, "source"),
|
||||
Score: hit.Score,
|
||||
}
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// searchDense performs a dense-only vector search via Qdrant /points/search.
|
||||
func (c *LegalRAGClient) searchDense(ctx context.Context, collection string, embedding []float64, regulationIDs []string, topK int) ([]qdrantSearchHit, error) {
|
||||
searchReq := qdrantSearchRequest{
|
||||
Vector: embedding,
|
||||
Limit: topK,
|
||||
WithPayload: true,
|
||||
}
|
||||
|
||||
// Add filter for specific regulations if provided
|
||||
if len(regulationIDs) > 0 {
|
||||
conditions := make([]qdrantCondition, len(regulationIDs))
|
||||
for i, regID := range regulationIDs {
|
||||
@@ -218,7 +418,6 @@ func (c *LegalRAGClient) searchInternal(ctx context.Context, collection string,
|
||||
return nil, fmt.Errorf("failed to marshal search request: %w", err)
|
||||
}
|
||||
|
||||
// Call Qdrant
|
||||
url := fmt.Sprintf("%s/collections/%s/points/search", c.qdrantURL, collection)
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(jsonBody))
|
||||
if err != nil {
|
||||
@@ -245,22 +444,7 @@ func (c *LegalRAGClient) searchInternal(ctx context.Context, collection string,
|
||||
return nil, fmt.Errorf("failed to decode search response: %w", err)
|
||||
}
|
||||
|
||||
// Convert to results using bp_compliance_ce payload schema
|
||||
results := make([]LegalSearchResult, len(searchResp.Result))
|
||||
for i, hit := range searchResp.Result {
|
||||
results[i] = LegalSearchResult{
|
||||
Text: getString(hit.Payload, "chunk_text"),
|
||||
RegulationCode: getString(hit.Payload, "regulation_id"),
|
||||
RegulationName: getString(hit.Payload, "regulation_name_de"),
|
||||
RegulationShort: getString(hit.Payload, "regulation_short"),
|
||||
Category: getString(hit.Payload, "category"),
|
||||
Pages: getIntSlice(hit.Payload, "pages"),
|
||||
SourceURL: getString(hit.Payload, "source"),
|
||||
Score: hit.Score,
|
||||
}
|
||||
}
|
||||
|
||||
return results, nil
|
||||
return searchResp.Result, nil
|
||||
}
|
||||
|
||||
// GetLegalContextForAssessment retrieves relevant legal context for an assessment.
|
||||
|
||||
@@ -32,11 +32,13 @@ func TestSearchCollection_UsesCorrectCollection(t *testing.T) {
|
||||
|
||||
// Parse qdrant mock host/port
|
||||
client := &LegalRAGClient{
|
||||
qdrantURL: qdrantMock.URL,
|
||||
ollamaURL: ollamaMock.URL,
|
||||
embeddingModel: "bge-m3",
|
||||
collection: "bp_compliance_ce",
|
||||
httpClient: http.DefaultClient,
|
||||
qdrantURL: qdrantMock.URL,
|
||||
ollamaURL: ollamaMock.URL,
|
||||
embeddingModel: "bge-m3",
|
||||
collection: "bp_compliance_ce",
|
||||
textIndexEnsured: make(map[string]bool),
|
||||
hybridEnabled: false, // dense-only for this test
|
||||
httpClient: http.DefaultClient,
|
||||
}
|
||||
|
||||
// Test with explicit collection
|
||||
@@ -69,11 +71,13 @@ func TestSearchCollection_FallbackDefault(t *testing.T) {
|
||||
defer qdrantMock.Close()
|
||||
|
||||
client := &LegalRAGClient{
|
||||
qdrantURL: qdrantMock.URL,
|
||||
ollamaURL: ollamaMock.URL,
|
||||
embeddingModel: "bge-m3",
|
||||
collection: "bp_compliance_ce",
|
||||
httpClient: http.DefaultClient,
|
||||
qdrantURL: qdrantMock.URL,
|
||||
ollamaURL: ollamaMock.URL,
|
||||
embeddingModel: "bge-m3",
|
||||
collection: "bp_compliance_ce",
|
||||
textIndexEnsured: make(map[string]bool),
|
||||
hybridEnabled: false,
|
||||
httpClient: http.DefaultClient,
|
||||
}
|
||||
|
||||
// Test with empty collection (should fall back to default)
|
||||
@@ -140,8 +144,9 @@ func TestScrollChunks_ReturnsChunksAndNextOffset(t *testing.T) {
|
||||
defer qdrantMock.Close()
|
||||
|
||||
client := &LegalRAGClient{
|
||||
qdrantURL: qdrantMock.URL,
|
||||
httpClient: http.DefaultClient,
|
||||
qdrantURL: qdrantMock.URL,
|
||||
textIndexEnsured: make(map[string]bool),
|
||||
httpClient: http.DefaultClient,
|
||||
}
|
||||
|
||||
chunks, nextOffset, err := client.ScrollChunks(context.Background(), "bp_compliance_ce", "", 100)
|
||||
@@ -196,8 +201,9 @@ func TestScrollChunks_EmptyCollection_ReturnsEmpty(t *testing.T) {
|
||||
defer qdrantMock.Close()
|
||||
|
||||
client := &LegalRAGClient{
|
||||
qdrantURL: qdrantMock.URL,
|
||||
httpClient: http.DefaultClient,
|
||||
qdrantURL: qdrantMock.URL,
|
||||
textIndexEnsured: make(map[string]bool),
|
||||
httpClient: http.DefaultClient,
|
||||
}
|
||||
|
||||
chunks, nextOffset, err := client.ScrollChunks(context.Background(), "bp_compliance_ce", "", 100)
|
||||
@@ -230,8 +236,9 @@ func TestScrollChunks_WithOffset_SendsOffset(t *testing.T) {
|
||||
defer qdrantMock.Close()
|
||||
|
||||
client := &LegalRAGClient{
|
||||
qdrantURL: qdrantMock.URL,
|
||||
httpClient: http.DefaultClient,
|
||||
qdrantURL: qdrantMock.URL,
|
||||
textIndexEnsured: make(map[string]bool),
|
||||
httpClient: http.DefaultClient,
|
||||
}
|
||||
|
||||
_, _, err := client.ScrollChunks(context.Background(), "bp_compliance_ce", "some-offset-id", 50)
|
||||
@@ -263,9 +270,10 @@ func TestScrollChunks_SendsAPIKey(t *testing.T) {
|
||||
defer qdrantMock.Close()
|
||||
|
||||
client := &LegalRAGClient{
|
||||
qdrantURL: qdrantMock.URL,
|
||||
qdrantAPIKey: "test-api-key-123",
|
||||
httpClient: http.DefaultClient,
|
||||
qdrantURL: qdrantMock.URL,
|
||||
qdrantAPIKey: "test-api-key-123",
|
||||
textIndexEnsured: make(map[string]bool),
|
||||
httpClient: http.DefaultClient,
|
||||
}
|
||||
|
||||
_, _, err := client.ScrollChunks(context.Background(), "bp_compliance_ce", "", 10)
|
||||
@@ -310,11 +318,13 @@ func TestSearch_StillWorks(t *testing.T) {
|
||||
defer qdrantMock.Close()
|
||||
|
||||
client := &LegalRAGClient{
|
||||
qdrantURL: qdrantMock.URL,
|
||||
ollamaURL: ollamaMock.URL,
|
||||
embeddingModel: "bge-m3",
|
||||
collection: "bp_compliance_ce",
|
||||
httpClient: http.DefaultClient,
|
||||
qdrantURL: qdrantMock.URL,
|
||||
ollamaURL: ollamaMock.URL,
|
||||
embeddingModel: "bge-m3",
|
||||
collection: "bp_compliance_ce",
|
||||
textIndexEnsured: make(map[string]bool),
|
||||
hybridEnabled: false,
|
||||
httpClient: http.DefaultClient,
|
||||
}
|
||||
|
||||
results, err := client.Search(context.Background(), "DSGVO Art. 35", nil, 5)
|
||||
@@ -334,3 +344,257 @@ func TestSearch_StillWorks(t *testing.T) {
|
||||
t.Errorf("Expected default collection in URL, got: %s", requestedURL)
|
||||
}
|
||||
}
|
||||
|
||||
// --- Hybrid Search Tests ---
|
||||
|
||||
func TestHybridSearch_UsesQueryAPI(t *testing.T) {
|
||||
var requestedPaths []string
|
||||
|
||||
ollamaMock := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
json.NewEncoder(w).Encode(ollamaEmbeddingResponse{
|
||||
Embedding: make([]float64, 1024),
|
||||
})
|
||||
}))
|
||||
defer ollamaMock.Close()
|
||||
|
||||
qdrantMock := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
requestedPaths = append(requestedPaths, r.URL.Path)
|
||||
|
||||
if strings.Contains(r.URL.Path, "/index") {
|
||||
// Text index creation — return OK
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(`{"result":{"operation_id":1,"status":"completed"}}`))
|
||||
return
|
||||
}
|
||||
|
||||
if strings.Contains(r.URL.Path, "/points/query") {
|
||||
// Verify the query request body has prefetch + fusion
|
||||
var reqBody map[string]interface{}
|
||||
json.NewDecoder(r.Body).Decode(&reqBody)
|
||||
|
||||
if _, ok := reqBody["prefetch"]; !ok {
|
||||
t.Error("Query request missing 'prefetch' field")
|
||||
}
|
||||
queryField, ok := reqBody["query"].(map[string]interface{})
|
||||
if !ok || queryField["fusion"] != "rrf" {
|
||||
t.Error("Query request missing 'query.fusion=rrf'")
|
||||
}
|
||||
|
||||
json.NewEncoder(w).Encode(qdrantQueryResponse{
|
||||
Result: []qdrantSearchHit{
|
||||
{
|
||||
ID: "1",
|
||||
Score: 0.88,
|
||||
Payload: map[string]interface{}{
|
||||
"chunk_text": "Hybrid result",
|
||||
"regulation_id": "eu_2016_679",
|
||||
"regulation_name_de": "DSGVO",
|
||||
"regulation_short": "DSGVO",
|
||||
"category": "regulation",
|
||||
"source": "https://example.com",
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Fallback: should not reach dense search
|
||||
t.Error("Unexpected dense search call when hybrid succeeded")
|
||||
json.NewEncoder(w).Encode(qdrantSearchResponse{Result: []qdrantSearchHit{}})
|
||||
}))
|
||||
defer qdrantMock.Close()
|
||||
|
||||
client := &LegalRAGClient{
|
||||
qdrantURL: qdrantMock.URL,
|
||||
ollamaURL: ollamaMock.URL,
|
||||
embeddingModel: "bge-m3",
|
||||
collection: "bp_compliance_ce",
|
||||
textIndexEnsured: make(map[string]bool),
|
||||
hybridEnabled: true,
|
||||
httpClient: http.DefaultClient,
|
||||
}
|
||||
|
||||
results, err := client.Search(context.Background(), "DSGVO Art. 35", nil, 5)
|
||||
if err != nil {
|
||||
t.Fatalf("Hybrid search failed: %v", err)
|
||||
}
|
||||
|
||||
if len(results) != 1 {
|
||||
t.Fatalf("Expected 1 result, got %d", len(results))
|
||||
}
|
||||
if results[0].Text != "Hybrid result" {
|
||||
t.Errorf("Expected 'Hybrid result', got '%s'", results[0].Text)
|
||||
}
|
||||
|
||||
// Verify text index was created
|
||||
hasIndex := false
|
||||
hasQuery := false
|
||||
for _, p := range requestedPaths {
|
||||
if strings.Contains(p, "/index") {
|
||||
hasIndex = true
|
||||
}
|
||||
if strings.Contains(p, "/points/query") {
|
||||
hasQuery = true
|
||||
}
|
||||
}
|
||||
if !hasIndex {
|
||||
t.Error("Expected text index creation call")
|
||||
}
|
||||
if !hasQuery {
|
||||
t.Error("Expected Query API call")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHybridSearch_FallbackToDense(t *testing.T) {
|
||||
var requestedPaths []string
|
||||
|
||||
ollamaMock := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
json.NewEncoder(w).Encode(ollamaEmbeddingResponse{
|
||||
Embedding: make([]float64, 1024),
|
||||
})
|
||||
}))
|
||||
defer ollamaMock.Close()
|
||||
|
||||
qdrantMock := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
requestedPaths = append(requestedPaths, r.URL.Path)
|
||||
|
||||
if strings.Contains(r.URL.Path, "/index") {
|
||||
// Simulate text index failure (old Qdrant version)
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
w.Write([]byte(`{"status":{"error":"not supported"}}`))
|
||||
return
|
||||
}
|
||||
|
||||
if strings.Contains(r.URL.Path, "/points/search") {
|
||||
// Dense fallback
|
||||
json.NewEncoder(w).Encode(qdrantSearchResponse{
|
||||
Result: []qdrantSearchHit{
|
||||
{
|
||||
ID: "2",
|
||||
Score: 0.90,
|
||||
Payload: map[string]interface{}{
|
||||
"chunk_text": "Dense fallback result",
|
||||
"regulation_id": "eu_2016_679",
|
||||
"regulation_name_de": "DSGVO",
|
||||
"regulation_short": "DSGVO",
|
||||
"category": "regulation",
|
||||
"source": "https://example.com",
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
}))
|
||||
defer qdrantMock.Close()
|
||||
|
||||
client := &LegalRAGClient{
|
||||
qdrantURL: qdrantMock.URL,
|
||||
ollamaURL: ollamaMock.URL,
|
||||
embeddingModel: "bge-m3",
|
||||
collection: "bp_compliance_ce",
|
||||
textIndexEnsured: make(map[string]bool),
|
||||
hybridEnabled: true,
|
||||
httpClient: http.DefaultClient,
|
||||
}
|
||||
|
||||
results, err := client.Search(context.Background(), "test query", nil, 5)
|
||||
if err != nil {
|
||||
t.Fatalf("Fallback search failed: %v", err)
|
||||
}
|
||||
|
||||
if len(results) != 1 {
|
||||
t.Fatalf("Expected 1 result, got %d", len(results))
|
||||
}
|
||||
if results[0].Text != "Dense fallback result" {
|
||||
t.Errorf("Expected 'Dense fallback result', got '%s'", results[0].Text)
|
||||
}
|
||||
|
||||
// Verify it fell back to dense search
|
||||
hasDense := false
|
||||
for _, p := range requestedPaths {
|
||||
if strings.Contains(p, "/points/search") {
|
||||
hasDense = true
|
||||
}
|
||||
}
|
||||
if !hasDense {
|
||||
t.Error("Expected fallback to dense /points/search")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnsureTextIndex_OnlyCalledOnce(t *testing.T) {
|
||||
callCount := 0
|
||||
|
||||
qdrantMock := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if strings.Contains(r.URL.Path, "/index") {
|
||||
callCount++
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(`{"result":{"operation_id":1,"status":"completed"}}`))
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(`{"result":[]}`))
|
||||
}))
|
||||
defer qdrantMock.Close()
|
||||
|
||||
client := &LegalRAGClient{
|
||||
qdrantURL: qdrantMock.URL,
|
||||
textIndexEnsured: make(map[string]bool),
|
||||
httpClient: http.DefaultClient,
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
_ = client.ensureTextIndex(ctx, "test_collection")
|
||||
_ = client.ensureTextIndex(ctx, "test_collection")
|
||||
_ = client.ensureTextIndex(ctx, "test_collection")
|
||||
|
||||
if callCount != 1 {
|
||||
t.Errorf("Expected ensureTextIndex to call Qdrant exactly once, called %d times", callCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHybridDisabled_UsesDenseOnly(t *testing.T) {
|
||||
var requestedPaths []string
|
||||
|
||||
ollamaMock := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
json.NewEncoder(w).Encode(ollamaEmbeddingResponse{
|
||||
Embedding: make([]float64, 1024),
|
||||
})
|
||||
}))
|
||||
defer ollamaMock.Close()
|
||||
|
||||
qdrantMock := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
requestedPaths = append(requestedPaths, r.URL.Path)
|
||||
json.NewEncoder(w).Encode(qdrantSearchResponse{
|
||||
Result: []qdrantSearchHit{},
|
||||
})
|
||||
}))
|
||||
defer qdrantMock.Close()
|
||||
|
||||
client := &LegalRAGClient{
|
||||
qdrantURL: qdrantMock.URL,
|
||||
ollamaURL: ollamaMock.URL,
|
||||
embeddingModel: "bge-m3",
|
||||
collection: "bp_compliance_ce",
|
||||
textIndexEnsured: make(map[string]bool),
|
||||
hybridEnabled: false,
|
||||
httpClient: http.DefaultClient,
|
||||
}
|
||||
|
||||
_, err := client.Search(context.Background(), "test", nil, 5)
|
||||
if err != nil {
|
||||
t.Fatalf("Search failed: %v", err)
|
||||
}
|
||||
|
||||
for _, p := range requestedPaths {
|
||||
if strings.Contains(p, "/points/query") {
|
||||
t.Error("Query API should not be called when hybrid is disabled")
|
||||
}
|
||||
if strings.Contains(p, "/index") {
|
||||
t.Error("Text index should not be created when hybrid is disabled")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user