Files
breakpilot-compliance/ai-compliance-sdk/internal/ucca/legal_rag_test.go
T
Benjamin Admin 49147d9497
CI / detect-changes (pull_request) Successful in 16s
CI / branch-name (pull_request) Successful in 2s
CI / guardrail-integrity (pull_request) Successful in 5s
CI / secret-scan (pull_request) Successful in 6s
CI / dep-audit (pull_request) Failing after 1m1s
CI / sbom-scan (pull_request) Failing after 1m4s
CI / build-sha-integrity (pull_request) Successful in 14s
CI / validate-canonical-controls (pull_request) Successful in 13s
CI / test-go (pull_request) Successful in 1m2s
CI / loc-budget (pull_request) Successful in 24s
CI / go-lint (pull_request) Failing after 20s
CI / python-lint (pull_request) Failing after 23s
CI / nodejs-lint (pull_request) Failing after 1m10s
CI / nodejs-build (pull_request) Successful in 3m26s
CI / iace-gt-coverage (pull_request) Successful in 16s
CI / test-python-backend (pull_request) Successful in 27s
CI / test-python-document-crawler (pull_request) Successful in 13s
CI / test-python-dsms-gateway (pull_request) Successful in 9s
feat(ai-sdk): authority-aware re-ranking for legal RAG retrieval (Phase 1)
Re-orders /sdk/v1/rag/search results so binding law from the matching
jurisdiction and domain ranks above guidance, foreign and off-domain law —
without dropping anything (guidance stays as interpretation context).
Internal-only: response schema is unchanged (json:"-" fields), so every
consumer benefits without a contract change.

- authority.go: classifyAuthority / queryDomain / chunkDomain / scopeClass /
  topic ontology. Tagged payload (authority_weight/source_class/jurisdiction)
  wins; deterministic fallback via category + name markers for the untagged corpus.
- authority_rerank.go: rerankByAuthority. final = semantic + authority +
  jurisdiction + domain + scope + topic; the authority score is written back to
  Score so the multi-collection advisor merge preserves the order.
- legal_rag_client: stratified retrieval — the binding-law pool AUGMENTS the
  semantic pool (mergeDedupHits), then re-rank.
- legal_rag_http: searchBinding (source_class filter) + shared doPointsSearch.
- table-driven tests for authority/domain/scope/topic + rerank acceptance +
  a stratified-binding integration test. go test -race green.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
2026-06-23 10:37:31 +02:00

655 lines
19 KiB
Go

package ucca
import (
	"context"
	"encoding/json"
	"net/http"
	"net/http/httptest"
	"strings"
	"sync"
	"testing"
)
// TestSearchCollection_UsesCorrectCollection checks that an explicitly named
// collection overrides the client's default collection in the Qdrant request URL.
func TestSearchCollection_UsesCorrectCollection(t *testing.T) {
	// Embedding stub: every request is answered with a zero vector of length 1024.
	embedSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		resp := ollamaEmbeddingResponse{Embedding: make([]float64, 1024)}
		json.NewEncoder(w).Encode(resp)
	}))
	defer embedSrv.Close()

	// Vector-store stub: records the path of the most recent request and
	// answers with an empty hit list.
	var lastPath string
	vectorSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		lastPath = r.URL.Path
		resp := qdrantSearchResponse{Result: []qdrantSearchHit{}}
		json.NewEncoder(w).Encode(resp)
	}))
	defer vectorSrv.Close()

	rag := &LegalRAGClient{
		qdrantURL:        vectorSrv.URL,
		ollamaURL:        embedSrv.URL,
		embeddingModel:   "bge-m3",
		collection:       "bp_compliance_ce",
		textIndexEnsured: make(map[string]bool),
		hybridEnabled:    false, // dense-only keeps the request flow minimal
		httpClient:       http.DefaultClient,
	}

	// The explicit collection argument must win over rag.collection.
	if _, err := rag.SearchCollection(context.Background(), "bp_compliance_recht", "test query", nil, 3); err != nil {
		t.Fatalf("SearchCollection failed: %v", err)
	}
	if want := "/collections/bp_compliance_recht/"; !strings.Contains(lastPath, want) {
		t.Errorf("Expected collection bp_compliance_recht in URL, got: %s", lastPath)
	}
}
// TestSearchCollection_FallbackDefault checks that passing an empty collection name
// makes SearchCollection query the client's configured default collection.
func TestSearchCollection_FallbackDefault(t *testing.T) {
	var lastPath string

	// Embedding stub returning a zero vector of length 1024.
	embedSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		resp := ollamaEmbeddingResponse{Embedding: make([]float64, 1024)}
		json.NewEncoder(w).Encode(resp)
	}))
	defer embedSrv.Close()

	// Vector-store stub that remembers the last requested path.
	vectorSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		lastPath = r.URL.Path
		resp := qdrantSearchResponse{Result: []qdrantSearchHit{}}
		json.NewEncoder(w).Encode(resp)
	}))
	defer vectorSrv.Close()

	rag := &LegalRAGClient{
		qdrantURL:        vectorSrv.URL,
		ollamaURL:        embedSrv.URL,
		embeddingModel:   "bge-m3",
		collection:       "bp_compliance_ce",
		textIndexEnsured: make(map[string]bool),
		hybridEnabled:    false,
		httpClient:       http.DefaultClient,
	}

	// An empty collection argument must resolve to rag.collection.
	_, err := rag.SearchCollection(context.Background(), "", "test query", nil, 3)
	switch {
	case err != nil:
		t.Fatalf("SearchCollection failed: %v", err)
	case !strings.Contains(lastPath, "/collections/bp_compliance_ce/"):
		t.Errorf("Expected default collection bp_compliance_ce in URL, got: %s", lastPath)
	}
}
// TestScrollChunks_ReturnsChunksAndNextOffset verifies three things about ScrollChunks:
//   - it calls the /points/scroll endpoint with with_vectors=false and with_payload=true;
//   - it maps both payload naming schemes onto the returned chunks: direct names
//     (text, regulation_code, ...) and fallback names (chunk_text, regulation_id,
//     regulation_name_de, ...);
//   - it returns the page's next_page_offset.
func TestScrollChunks_ReturnsChunksAndNextOffset(t *testing.T) {
	qdrantMock := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if !strings.Contains(r.URL.Path, "/points/scroll") {
			t.Errorf("Expected scroll endpoint, got: %s", r.URL.Path)
		}
		// Decode the request to verify the scroll options. A decode failure is
		// reported directly instead of surfacing as misleading with_vectors /
		// with_payload mismatches on a nil map.
		var reqBody map[string]interface{}
		if err := json.NewDecoder(r.Body).Decode(&reqBody); err != nil {
			t.Errorf("decoding scroll request body: %v", err)
		}
		if reqBody["with_vectors"] != false {
			t.Error("Expected with_vectors=false")
		}
		if reqBody["with_payload"] != true {
			t.Error("Expected with_payload=true")
		}
		resp := map[string]interface{}{
			"result": map[string]interface{}{
				"points": []map[string]interface{}{
					{
						// Point using the direct payload field names.
						"id": "abc-123",
						"payload": map[string]interface{}{
							"text":             "Artikel 35 DSGVO",
							"regulation_code":  "eu_2016_679",
							"regulation_name":  "DSGVO",
							"regulation_short": "DSGVO",
							"category":         "regulation",
							"article":          "Art. 35",
							"paragraph":        "1",
							"source_url":       "https://example.com/dsgvo",
						},
					},
					{
						// Point using the fallback payload field names.
						"id": "def-456",
						"payload": map[string]interface{}{
							"chunk_text":         "AI Act Titel III",
							"regulation_id":      "eu_2024_1689",
							"regulation_name_de": "KI-Verordnung",
							"regulation_short":   "AI Act",
							"category":           "regulation",
							"source":             "https://example.com/ai-act",
						},
					},
				},
				"next_page_offset": "def-456",
			},
		}
		if err := json.NewEncoder(w).Encode(resp); err != nil {
			t.Errorf("encoding scroll response: %v", err)
		}
	}))
	defer qdrantMock.Close()

	client := &LegalRAGClient{
		qdrantURL:        qdrantMock.URL,
		textIndexEnsured: make(map[string]bool),
		httpClient:       http.DefaultClient,
	}
	chunks, nextOffset, err := client.ScrollChunks(context.Background(), "bp_compliance_ce", "", 100)
	if err != nil {
		t.Fatalf("ScrollChunks failed: %v", err)
	}
	if len(chunks) != 2 {
		t.Fatalf("Expected 2 chunks, got %d", len(chunks))
	}

	// First chunk uses direct field names.
	if chunks[0].ID != "abc-123" {
		t.Errorf("Expected ID abc-123, got %s", chunks[0].ID)
	}
	if chunks[0].Text != "Artikel 35 DSGVO" {
		t.Errorf("Expected text 'Artikel 35 DSGVO', got '%s'", chunks[0].Text)
	}
	if chunks[0].RegulationCode != "eu_2016_679" {
		t.Errorf("Expected regulation_code eu_2016_679, got %s", chunks[0].RegulationCode)
	}
	if chunks[0].Article != "Art. 35" {
		t.Errorf("Expected article 'Art. 35', got '%s'", chunks[0].Article)
	}

	// Second chunk uses fallback field names (chunk_text, regulation_id, etc.).
	if chunks[1].ID != "def-456" {
		t.Errorf("Expected ID def-456, got %s", chunks[1].ID)
	}
	if chunks[1].Text != "AI Act Titel III" {
		t.Errorf("Expected fallback text 'AI Act Titel III', got '%s'", chunks[1].Text)
	}
	if chunks[1].RegulationCode != "eu_2024_1689" {
		t.Errorf("Expected fallback regulation_code eu_2024_1689, got '%s'", chunks[1].RegulationCode)
	}
	if chunks[1].RegulationName != "KI-Verordnung" {
		t.Errorf("Expected fallback regulation_name 'KI-Verordnung', got '%s'", chunks[1].RegulationName)
	}

	if nextOffset != "def-456" {
		t.Errorf("Expected next_offset 'def-456', got '%s'", nextOffset)
	}
}
// TestScrollChunks_EmptyCollection_ReturnsEmpty checks that a scroll page with no
// points and a null next_page_offset yields zero chunks and an empty offset.
func TestScrollChunks_EmptyCollection_ReturnsEmpty(t *testing.T) {
	emptyPage := map[string]interface{}{
		"result": map[string]interface{}{
			"points":           []interface{}{},
			"next_page_offset": nil,
		},
	}
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		json.NewEncoder(w).Encode(emptyPage)
	}))
	defer srv.Close()

	rag := &LegalRAGClient{
		qdrantURL:        srv.URL,
		textIndexEnsured: make(map[string]bool),
		httpClient:       http.DefaultClient,
	}

	got, offset, err := rag.ScrollChunks(context.Background(), "bp_compliance_ce", "", 100)
	if err != nil {
		t.Fatalf("ScrollChunks failed: %v", err)
	}
	if n := len(got); n != 0 {
		t.Errorf("Expected 0 chunks, got %d", n)
	}
	if offset != "" {
		t.Errorf("Expected empty next_offset, got '%s'", offset)
	}
}
// TestScrollChunks_WithOffset_SendsOffset verifies that ScrollChunks forwards the
// caller's offset and limit in the scroll request body.
func TestScrollChunks_WithOffset_SendsOffset(t *testing.T) {
	var receivedBody map[string]interface{}
	qdrantMock := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Report an undecodable body explicitly; otherwise the assertions below
		// would only show a confusing "<nil>" offset/limit.
		if err := json.NewDecoder(r.Body).Decode(&receivedBody); err != nil {
			t.Errorf("decoding scroll request body: %v", err)
		}
		resp := map[string]interface{}{
			"result": map[string]interface{}{
				"points":           []interface{}{},
				"next_page_offset": nil,
			},
		}
		if err := json.NewEncoder(w).Encode(resp); err != nil {
			t.Errorf("encoding scroll response: %v", err)
		}
	}))
	defer qdrantMock.Close()

	client := &LegalRAGClient{
		qdrantURL:        qdrantMock.URL,
		textIndexEnsured: make(map[string]bool),
		httpClient:       http.DefaultClient,
	}
	_, _, err := client.ScrollChunks(context.Background(), "bp_compliance_ce", "some-offset-id", 50)
	if err != nil {
		t.Fatalf("ScrollChunks failed: %v", err)
	}
	if receivedBody["offset"] != "some-offset-id" {
		t.Errorf("Expected offset 'some-offset-id', got '%v'", receivedBody["offset"])
	}
	// JSON numbers decode into float64 when the target is interface{}.
	if receivedBody["limit"] != float64(50) {
		t.Errorf("Expected limit 50, got %v", receivedBody["limit"])
	}
}
// TestScrollChunks_SendsAPIKey checks that the client's configured Qdrant API key is
// sent in the "api-key" request header.
func TestScrollChunks_SendsAPIKey(t *testing.T) {
	const wantKey = "test-api-key-123"

	var gotKey string
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		gotKey = r.Header.Get("api-key")
		json.NewEncoder(w).Encode(map[string]interface{}{
			"result": map[string]interface{}{
				"points":           []interface{}{},
				"next_page_offset": nil,
			},
		})
	}))
	defer srv.Close()

	rag := &LegalRAGClient{
		qdrantURL:        srv.URL,
		qdrantAPIKey:     wantKey,
		textIndexEnsured: make(map[string]bool),
		httpClient:       http.DefaultClient,
	}

	if _, _, err := rag.ScrollChunks(context.Background(), "bp_compliance_ce", "", 10); err != nil {
		t.Fatalf("ScrollChunks failed: %v", err)
	}
	if gotKey != wantKey {
		t.Errorf("Expected api-key 'test-api-key-123', got '%s'", gotKey)
	}
}
// TestSearch_StillWorks checks that Search with hybrid retrieval disabled queries
// the client's default collection and maps the fallback payload field names
// (regulation_id, ...) onto the result.
//
// The mock answers every vector-store request with the same hit (ID "1"). If
// Search issues more than one search request, the duplicate IDs are presumably
// merged, so exactly one result is expected.
func TestSearch_StillWorks(t *testing.T) {
	// Embedding stub returning a zero vector of length 1024.
	embedSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		resp := ollamaEmbeddingResponse{Embedding: make([]float64, 1024)}
		json.NewEncoder(w).Encode(resp)
	}))
	defer embedSrv.Close()

	var lastPath string
	vectorSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		lastPath = r.URL.Path
		payload := map[string]interface{}{
			"chunk_text":         "Test content",
			"regulation_id":      "eu_2016_679",
			"regulation_name_de": "DSGVO",
			"regulation_short":   "DSGVO",
			"category":           "regulation",
			"source":             "https://example.com",
		}
		resp := qdrantSearchResponse{
			Result: []qdrantSearchHit{{ID: "1", Score: 0.95, Payload: payload}},
		}
		json.NewEncoder(w).Encode(resp)
	}))
	defer vectorSrv.Close()

	rag := &LegalRAGClient{
		qdrantURL:        vectorSrv.URL,
		ollamaURL:        embedSrv.URL,
		embeddingModel:   "bge-m3",
		collection:       "bp_compliance_ce",
		textIndexEnsured: make(map[string]bool),
		hybridEnabled:    false,
		httpClient:       http.DefaultClient,
	}

	hits, err := rag.Search(context.Background(), "DSGVO Art. 35", nil, 5)
	if err != nil {
		t.Fatalf("Search failed: %v", err)
	}
	if n := len(hits); n != 1 {
		t.Fatalf("Expected 1 result, got %d", n)
	}
	if code := hits[0].RegulationCode; code != "eu_2016_679" {
		t.Errorf("Expected regulation_code eu_2016_679, got %s", code)
	}
	if !strings.Contains(lastPath, "/collections/bp_compliance_ce/") {
		t.Errorf("Expected default collection in URL, got: %s", lastPath)
	}
}
// --- Hybrid Search Tests ---
// TestHybridSearch_UsesQueryAPI verifies hybrid Search end to end:
//   - it ensures a text index;
//   - it issues a Query API request with prefetch and query.fusion=rrf;
//   - it returns the hybrid hit.
func TestHybridSearch_UsesQueryAPI(t *testing.T) {
	// The handler runs on the test server's goroutines, and Search issues several
	// Qdrant requests (index, query, binding-law search). The recorded paths are
	// therefore mutex-guarded, so the test stays race-free even if the client fans
	// requests out concurrently.
	var (
		mu             sync.Mutex
		requestedPaths []string
	)

	ollamaMock := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		json.NewEncoder(w).Encode(ollamaEmbeddingResponse{
			Embedding: make([]float64, 1024),
		})
	}))
	defer ollamaMock.Close()

	qdrantMock := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		mu.Lock()
		requestedPaths = append(requestedPaths, r.URL.Path)
		mu.Unlock()

		if strings.Contains(r.URL.Path, "/index") {
			// Text index creation — return OK.
			w.WriteHeader(http.StatusOK)
			w.Write([]byte(`{"result":{"operation_id":1,"status":"completed"}}`))
			return
		}
		if strings.Contains(r.URL.Path, "/points/query") {
			// Verify the query request body has prefetch + fusion.
			var reqBody map[string]interface{}
			if err := json.NewDecoder(r.Body).Decode(&reqBody); err != nil {
				t.Errorf("decoding query request body: %v", err)
			}
			if _, ok := reqBody["prefetch"]; !ok {
				t.Error("Query request missing 'prefetch' field")
			}
			queryField, ok := reqBody["query"].(map[string]interface{})
			if !ok || queryField["fusion"] != "rrf" {
				t.Error("Query request missing 'query.fusion=rrf'")
			}
			resp := qdrantQueryResponse{
				Result: []qdrantSearchHit{
					{
						ID:    "1",
						Score: 0.88,
						Payload: map[string]interface{}{
							"chunk_text":         "Hybrid result",
							"regulation_id":      "eu_2016_679",
							"regulation_name_de": "DSGVO",
							"regulation_short":   "DSGVO",
							"category":           "regulation",
							"source":             "https://example.com",
						},
					},
				},
			}
			if err := json.NewEncoder(w).Encode(resp); err != nil {
				t.Errorf("encoding query response: %v", err)
			}
			return
		}
		// /points/search is now the stratified binding-law augmentation query (it
		// AUGMENTS the hybrid pool; it is not a dense fallback). Return empty so the
		// hybrid hit remains the sole result for this test.
		if err := json.NewEncoder(w).Encode(qdrantSearchResponse{Result: []qdrantSearchHit{}}); err != nil {
			t.Errorf("encoding search response: %v", err)
		}
	}))
	defer qdrantMock.Close()

	client := &LegalRAGClient{
		qdrantURL:        qdrantMock.URL,
		ollamaURL:        ollamaMock.URL,
		embeddingModel:   "bge-m3",
		collection:       "bp_compliance_ce",
		textIndexEnsured: make(map[string]bool),
		hybridEnabled:    true,
		httpClient:       http.DefaultClient,
	}
	results, err := client.Search(context.Background(), "DSGVO Art. 35", nil, 5)
	if err != nil {
		t.Fatalf("Hybrid search failed: %v", err)
	}
	if len(results) != 1 {
		t.Fatalf("Expected 1 result, got %d", len(results))
	}
	if results[0].Text != "Hybrid result" {
		t.Errorf("Expected 'Hybrid result', got '%s'", results[0].Text)
	}

	// Snapshot the recorded paths under the lock before inspecting them.
	mu.Lock()
	paths := append([]string(nil), requestedPaths...)
	mu.Unlock()

	// Verify the text index was created and the Query API was used.
	hasIndex := false
	hasQuery := false
	for _, p := range paths {
		if strings.Contains(p, "/index") {
			hasIndex = true
		}
		if strings.Contains(p, "/points/query") {
			hasQuery = true
		}
	}
	if !hasIndex {
		t.Error("Expected text index creation call")
	}
	if !hasQuery {
		t.Error("Expected Query API call")
	}
}
// TestSearch_StratifiedBindingRerank verifies that the binding-law pool augments the
// semantic pool, and that authority re-ranking lifts binding law above guidance with
// a higher semantic score without dropping that guidance.
func TestSearch_StratifiedBindingRerank(t *testing.T) {
	ollamaMock := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		json.NewEncoder(w).Encode(ollamaEmbeddingResponse{Embedding: make([]float64, 1024)})
	}))
	defer ollamaMock.Close()

	qdrantMock := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if strings.Contains(r.URL.Path, "/index") {
			w.WriteHeader(http.StatusOK)
			w.Write([]byte(`{"result":{"status":"completed"}}`))
			return
		}
		if strings.Contains(r.URL.Path, "/points/query") {
			// Semantic (hybrid) pool: guidance with the higher raw score.
			json.NewEncoder(w).Encode(qdrantQueryResponse{Result: []qdrantSearchHit{
				{ID: "g1", Score: 0.72, Payload: map[string]interface{}{
					"chunk_text": "ENISA guidance", "regulation_short": "ENISA",
					"article_label": "ENISA CRA Mapping", "source_class": "supervisory_guidance",
					"authority_weight": float64(70), "jurisdiction": "EU",
				}},
			}})
			return
		}
		// /points/search = stratified binding-law pool (source_class=binding_law):
		// binding law with the lower raw score.
		json.NewEncoder(w).Encode(qdrantSearchResponse{Result: []qdrantSearchHit{
			{ID: "b1", Score: 0.66, Payload: map[string]interface{}{
				"chunk_text": "CRA Anhang I requirement", "regulation_short": "CRA",
				"article_label": "CRA Anhang I", "source_class": "binding_law",
				"authority_weight": float64(100), "jurisdiction": "EU",
			}},
		}})
	}))
	defer qdrantMock.Close()

	client := &LegalRAGClient{
		qdrantURL: qdrantMock.URL, ollamaURL: ollamaMock.URL, embeddingModel: "bge-m3",
		collection: "bp_compliance_ce", textIndexEnsured: make(map[string]bool),
		hybridEnabled: true, httpClient: http.DefaultClient,
	}
	results, err := client.Search(context.Background(), "Was gilt hier?", nil, 5)
	if err != nil {
		t.Fatalf("search failed: %v", err)
	}
	if len(results) != 2 {
		t.Fatalf("expected 2 merged results (guidance + binding), got %d", len(results))
	}
	if results[0].RegulationShort != "CRA" {
		t.Errorf("binding CRA must rank first over higher-semantic guidance, got %q", results[0].RegulationShort)
	}
	// Re-ranking must demote guidance, not drop it: it stays as interpretation context.
	if results[1].RegulationShort != "ENISA" {
		t.Errorf("guidance ENISA must be kept behind binding law, got %q", results[1].RegulationShort)
	}
}
// TestHybridSearch_FallbackToDense verifies that when text-index creation fails
// (simulating an older Qdrant without text-index support), hybrid Search falls back
// to the dense /points/search endpoint and still returns the hit.
func TestHybridSearch_FallbackToDense(t *testing.T) {
	// requestedPaths is written from the test server's handler goroutines, so it is
	// mutex-guarded; Search may issue several requests to the mock.
	var (
		mu             sync.Mutex
		requestedPaths []string
	)

	ollamaMock := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		json.NewEncoder(w).Encode(ollamaEmbeddingResponse{
			Embedding: make([]float64, 1024),
		})
	}))
	defer ollamaMock.Close()

	qdrantMock := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		mu.Lock()
		requestedPaths = append(requestedPaths, r.URL.Path)
		mu.Unlock()

		if strings.Contains(r.URL.Path, "/index") {
			// Simulate text index failure (old Qdrant version).
			w.WriteHeader(http.StatusBadRequest)
			w.Write([]byte(`{"status":{"error":"not supported"}}`))
			return
		}
		if strings.Contains(r.URL.Path, "/points/search") {
			// Dense fallback. NOTE: the stratified binding-law pool also queries
			// /points/search and receives this same hit (ID "2"). The single-result
			// assertion below presumably relies on the merge de-duplicating it.
			resp := qdrantSearchResponse{
				Result: []qdrantSearchHit{
					{
						ID:    "2",
						Score: 0.90,
						Payload: map[string]interface{}{
							"chunk_text":         "Dense fallback result",
							"regulation_id":      "eu_2016_679",
							"regulation_name_de": "DSGVO",
							"regulation_short":   "DSGVO",
							"category":           "regulation",
							"source":             "https://example.com",
						},
					},
				},
			}
			if err := json.NewEncoder(w).Encode(resp); err != nil {
				t.Errorf("encoding search response: %v", err)
			}
			return
		}
		// Any other endpoint (e.g. the Query API) is unavailable.
		w.WriteHeader(http.StatusInternalServerError)
	}))
	defer qdrantMock.Close()

	client := &LegalRAGClient{
		qdrantURL:        qdrantMock.URL,
		ollamaURL:        ollamaMock.URL,
		embeddingModel:   "bge-m3",
		collection:       "bp_compliance_ce",
		textIndexEnsured: make(map[string]bool),
		hybridEnabled:    true,
		httpClient:       http.DefaultClient,
	}
	results, err := client.Search(context.Background(), "test query", nil, 5)
	if err != nil {
		t.Fatalf("Fallback search failed: %v", err)
	}
	if len(results) != 1 {
		t.Fatalf("Expected 1 result, got %d", len(results))
	}
	if results[0].Text != "Dense fallback result" {
		t.Errorf("Expected 'Dense fallback result', got '%s'", results[0].Text)
	}

	// Snapshot the recorded paths under the lock before inspecting them.
	mu.Lock()
	paths := append([]string(nil), requestedPaths...)
	mu.Unlock()

	// Verify it fell back to dense search.
	hasDense := false
	for _, p := range paths {
		if strings.Contains(p, "/points/search") {
			hasDense = true
			break
		}
	}
	if !hasDense {
		t.Error("Expected fallback to dense /points/search")
	}
}
// TestEnsureTextIndex_OnlyCalledOnce verifies that repeated ensureTextIndex calls for
// the same collection send the index-creation request to Qdrant only once.
func TestEnsureTextIndex_OnlyCalledOnce(t *testing.T) {
	// callCount is incremented on the test server's handler goroutine and read on
	// the test goroutine, so it is mutex-guarded.
	var (
		mu        sync.Mutex
		callCount int
	)
	qdrantMock := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if strings.Contains(r.URL.Path, "/index") {
			mu.Lock()
			callCount++
			mu.Unlock()
			w.WriteHeader(http.StatusOK)
			w.Write([]byte(`{"result":{"operation_id":1,"status":"completed"}}`))
			return
		}
		w.WriteHeader(http.StatusOK)
		w.Write([]byte(`{"result":[]}`))
	}))
	defer qdrantMock.Close()

	client := &LegalRAGClient{
		qdrantURL:        qdrantMock.URL,
		textIndexEnsured: make(map[string]bool),
		httpClient:       http.DefaultClient,
	}
	ctx := context.Background()
	// Check each call's error, so a failure is reported directly rather than
	// only as a confusing call-count mismatch.
	for i := 1; i <= 3; i++ {
		if err := client.ensureTextIndex(ctx, "test_collection"); err != nil {
			t.Fatalf("ensureTextIndex call %d failed: %v", i, err)
		}
	}

	mu.Lock()
	got := callCount
	mu.Unlock()
	if got != 1 {
		t.Errorf("Expected ensureTextIndex to call Qdrant exactly once, called %d times", got)
	}
}
// TestHybridDisabled_UsesDenseOnly verifies that with hybrid retrieval disabled,
// Search neither creates a text index nor calls the Query API. It must still use
// the dense /points/search endpoint.
func TestHybridDisabled_UsesDenseOnly(t *testing.T) {
	// requestedPaths is written from the test server's handler goroutines, so it is
	// mutex-guarded.
	var (
		mu             sync.Mutex
		requestedPaths []string
	)

	ollamaMock := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		json.NewEncoder(w).Encode(ollamaEmbeddingResponse{
			Embedding: make([]float64, 1024),
		})
	}))
	defer ollamaMock.Close()

	qdrantMock := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		mu.Lock()
		requestedPaths = append(requestedPaths, r.URL.Path)
		mu.Unlock()
		if err := json.NewEncoder(w).Encode(qdrantSearchResponse{
			Result: []qdrantSearchHit{},
		}); err != nil {
			t.Errorf("encoding search response: %v", err)
		}
	}))
	defer qdrantMock.Close()

	client := &LegalRAGClient{
		qdrantURL:        qdrantMock.URL,
		ollamaURL:        ollamaMock.URL,
		embeddingModel:   "bge-m3",
		collection:       "bp_compliance_ce",
		textIndexEnsured: make(map[string]bool),
		hybridEnabled:    false,
		httpClient:       http.DefaultClient,
	}
	_, err := client.Search(context.Background(), "test", nil, 5)
	if err != nil {
		t.Fatalf("Search failed: %v", err)
	}

	// Snapshot the recorded paths under the lock before inspecting them.
	mu.Lock()
	paths := append([]string(nil), requestedPaths...)
	mu.Unlock()

	hasDense := false
	for _, p := range paths {
		if strings.Contains(p, "/points/query") {
			t.Error("Query API should not be called when hybrid is disabled")
		}
		if strings.Contains(p, "/index") {
			t.Error("Text index should not be created when hybrid is disabled")
		}
		if strings.Contains(p, "/points/search") {
			hasDense = true
		}
	}
	// The negative checks above pass trivially if no request is made at all, so
	// also require the dense search call.
	if !hasDense {
		t.Error("Expected dense /points/search call when hybrid is disabled")
	}
}