feat: Phase 2 — RAG integration in Requirements + DSFA Draft
All checks were successful
CI / go-lint (push) Has been skipped
CI / python-lint (push) Has been skipped
CI / nodejs-lint (push) Has been skipped
CI / test-go-ai-compliance (push) Successful in 35s
CI / test-python-backend-compliance (push) Successful in 26s
CI / test-python-document-crawler (push) Successful in 22s
CI / test-python-dsms-gateway (push) Successful in 19s
All checks were successful
CI / go-lint (push) Has been skipped
CI / python-lint (push) Has been skipped
CI / nodejs-lint (push) Has been skipped
CI / test-go-ai-compliance (push) Successful in 35s
CI / test-python-backend-compliance (push) Successful in 26s
CI / test-python-document-crawler (push) Successful in 22s
CI / test-python-dsms-gateway (push) Successful in 19s
Add legal context enrichment from Qdrant vector corpus to the two highest-priority modules (Requirements AI assistant and DSFA drafting engine). Go SDK: - Add SearchCollection() with collection override + whitelist validation - Refactor Search() to delegate to shared searchInternal() Python backend: - New ComplianceRAGClient proxying POST /sdk/v1/rag/search (error-tolerant) - AI assistant: enrich interpret_requirement() and suggest_controls() with RAG - Requirements API: add ?include_legal_context=true query parameter Admin (Next.js): - Extract shared queryRAG() utility from chat route - Inject RAG legal context into v1 and v2 draft pipelines Tests for all three layers (Go, Python, TypeScript shared utility). Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -21,9 +21,20 @@ func NewRAGHandlers(corpusVersionStore *ucca.CorpusVersionStore) *RAGHandlers {
|
||||
}
|
||||
}
|
||||
|
||||
// AllowedCollections is the whitelist of Qdrant collections that can be queried.
|
||||
var AllowedCollections = map[string]bool{
|
||||
"bp_compliance_ce": true,
|
||||
"bp_compliance_recht": true,
|
||||
"bp_compliance_gesetze": true,
|
||||
"bp_compliance_datenschutz": true,
|
||||
"bp_dsfa_corpus": true,
|
||||
"bp_legal_templates": true,
|
||||
}
|
||||
|
||||
// SearchRequest represents a RAG search request.
|
||||
type SearchRequest struct {
|
||||
Query string `json:"query" binding:"required"`
|
||||
Collection string `json:"collection,omitempty"`
|
||||
Regulations []string `json:"regulations,omitempty"`
|
||||
TopK int `json:"top_k,omitempty"`
|
||||
}
|
||||
@@ -41,7 +52,15 @@ func (h *RAGHandlers) Search(c *gin.Context) {
|
||||
req.TopK = 5
|
||||
}
|
||||
|
||||
results, err := h.ragClient.Search(c.Request.Context(), req.Query, req.Regulations, req.TopK)
|
||||
// Validate collection if specified
|
||||
if req.Collection != "" {
|
||||
if !AllowedCollections[req.Collection] {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Unknown collection: " + req.Collection + ". Allowed: bp_compliance_ce, bp_compliance_recht, bp_compliance_gesetze, bp_compliance_datenschutz, bp_dsfa_corpus, bp_legal_templates"})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
results, err := h.ragClient.SearchCollection(c.Request.Context(), req.Collection, req.Query, req.Regulations, req.TopK)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "RAG search failed: " + err.Error()})
|
||||
return
|
||||
|
||||
109
ai-compliance-sdk/internal/api/handlers/rag_handlers_test.go
Normal file
109
ai-compliance-sdk/internal/api/handlers/rag_handlers_test.go
Normal file
@@ -0,0 +1,109 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func TestAllowedCollections(t *testing.T) {
|
||||
allowed := []string{
|
||||
"bp_compliance_ce",
|
||||
"bp_compliance_recht",
|
||||
"bp_compliance_gesetze",
|
||||
"bp_compliance_datenschutz",
|
||||
"bp_dsfa_corpus",
|
||||
"bp_legal_templates",
|
||||
}
|
||||
|
||||
for _, c := range allowed {
|
||||
if !AllowedCollections[c] {
|
||||
t.Errorf("Expected %s to be in AllowedCollections", c)
|
||||
}
|
||||
}
|
||||
|
||||
disallowed := []string{
|
||||
"bp_unknown",
|
||||
"",
|
||||
"some_random_collection",
|
||||
}
|
||||
|
||||
for _, c := range disallowed {
|
||||
if AllowedCollections[c] {
|
||||
t.Errorf("Expected %s to NOT be in AllowedCollections", c)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSearch_InvalidCollection_Returns400(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
handler := &RAGHandlers{}
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
body := SearchRequest{
|
||||
Query: "test query",
|
||||
Collection: "bp_evil_collection",
|
||||
TopK: 5,
|
||||
}
|
||||
bodyBytes, _ := json.Marshal(body)
|
||||
c.Request, _ = http.NewRequest("POST", "/sdk/v1/rag/search", bytes.NewReader(bodyBytes))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
handler.Search(c)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("Expected 400, got %d", w.Code)
|
||||
}
|
||||
|
||||
var resp map[string]interface{}
|
||||
json.Unmarshal(w.Body.Bytes(), &resp)
|
||||
|
||||
errMsg, ok := resp["error"].(string)
|
||||
if !ok || errMsg == "" {
|
||||
t.Error("Expected error message in response")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSearch_WithCollectionParam_BindsCorrectly(t *testing.T) {
|
||||
// Test that the SearchRequest struct correctly binds the collection field
|
||||
body := `{"query":"DSGVO Art. 35","collection":"bp_compliance_recht","top_k":3}`
|
||||
var req SearchRequest
|
||||
err := json.Unmarshal([]byte(body), &req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to unmarshal: %v", err)
|
||||
}
|
||||
|
||||
if req.Query != "DSGVO Art. 35" {
|
||||
t.Errorf("Expected query 'DSGVO Art. 35', got '%s'", req.Query)
|
||||
}
|
||||
if req.Collection != "bp_compliance_recht" {
|
||||
t.Errorf("Expected collection 'bp_compliance_recht', got '%s'", req.Collection)
|
||||
}
|
||||
if req.TopK != 3 {
|
||||
t.Errorf("Expected top_k 3, got %d", req.TopK)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSearch_EmptyCollection_IsAllowed(t *testing.T) {
|
||||
// Empty collection should be allowed (falls back to default in the handler)
|
||||
body := `{"query":"test"}`
|
||||
var req SearchRequest
|
||||
err := json.Unmarshal([]byte(body), &req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to unmarshal: %v", err)
|
||||
}
|
||||
|
||||
if req.Collection != "" {
|
||||
t.Errorf("Expected empty collection, got '%s'", req.Collection)
|
||||
}
|
||||
|
||||
// Empty string is not in AllowedCollections, but the handler
|
||||
// should skip validation for empty collection
|
||||
}
|
||||
@@ -173,8 +173,22 @@ func (c *LegalRAGClient) generateEmbedding(ctx context.Context, text string) ([]
|
||||
return embResp.Embedding, nil
|
||||
}
|
||||
|
||||
// SearchCollection queries a specific Qdrant collection for relevant passages.
|
||||
// If collection is empty, it falls back to the default collection (bp_compliance_ce).
|
||||
func (c *LegalRAGClient) SearchCollection(ctx context.Context, collection string, query string, regulationIDs []string, topK int) ([]LegalSearchResult, error) {
|
||||
if collection == "" {
|
||||
collection = c.collection
|
||||
}
|
||||
return c.searchInternal(ctx, collection, query, regulationIDs, topK)
|
||||
}
|
||||
|
||||
// Search queries the compliance CE corpus for relevant passages.
|
||||
func (c *LegalRAGClient) Search(ctx context.Context, query string, regulationIDs []string, topK int) ([]LegalSearchResult, error) {
|
||||
return c.searchInternal(ctx, c.collection, query, regulationIDs, topK)
|
||||
}
|
||||
|
||||
// searchInternal performs the actual search against a given collection.
|
||||
func (c *LegalRAGClient) searchInternal(ctx context.Context, collection string, query string, regulationIDs []string, topK int) ([]LegalSearchResult, error) {
|
||||
// Generate query embedding via Ollama bge-m3
|
||||
embedding, err := c.generateEmbedding(ctx, query)
|
||||
if err != nil {
|
||||
@@ -206,7 +220,7 @@ func (c *LegalRAGClient) Search(ctx context.Context, query string, regulationIDs
|
||||
}
|
||||
|
||||
// Call Qdrant
|
||||
url := fmt.Sprintf("http://%s:%s/collections/%s/points/search", c.qdrantHost, c.qdrantPort, c.collection)
|
||||
url := fmt.Sprintf("http://%s:%s/collections/%s/points/search", c.qdrantHost, c.qdrantPort, collection)
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(jsonBody))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create search request: %w", err)
|
||||
|
||||
157
ai-compliance-sdk/internal/ucca/legal_rag_test.go
Normal file
157
ai-compliance-sdk/internal/ucca/legal_rag_test.go
Normal file
@@ -0,0 +1,157 @@
|
||||
package ucca
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestSearchCollection_UsesCorrectCollection(t *testing.T) {
|
||||
// Track which collection was requested
|
||||
var requestedURL string
|
||||
|
||||
// Mock Ollama (embedding)
|
||||
ollamaMock := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
json.NewEncoder(w).Encode(ollamaEmbeddingResponse{
|
||||
Embedding: make([]float64, 1024),
|
||||
})
|
||||
}))
|
||||
defer ollamaMock.Close()
|
||||
|
||||
// Mock Qdrant
|
||||
qdrantMock := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
requestedURL = r.URL.Path
|
||||
json.NewEncoder(w).Encode(qdrantSearchResponse{
|
||||
Result: []qdrantSearchHit{},
|
||||
})
|
||||
}))
|
||||
defer qdrantMock.Close()
|
||||
|
||||
// Parse qdrant mock host/port
|
||||
qdrantAddr := strings.TrimPrefix(qdrantMock.URL, "http://")
|
||||
parts := strings.Split(qdrantAddr, ":")
|
||||
|
||||
client := &LegalRAGClient{
|
||||
qdrantHost: parts[0],
|
||||
qdrantPort: parts[1],
|
||||
ollamaURL: ollamaMock.URL,
|
||||
embeddingModel: "bge-m3",
|
||||
collection: "bp_compliance_ce",
|
||||
httpClient: http.DefaultClient,
|
||||
}
|
||||
|
||||
// Test with explicit collection
|
||||
_, err := client.SearchCollection(context.Background(), "bp_compliance_recht", "test query", nil, 3)
|
||||
if err != nil {
|
||||
t.Fatalf("SearchCollection failed: %v", err)
|
||||
}
|
||||
|
||||
if !strings.Contains(requestedURL, "/collections/bp_compliance_recht/") {
|
||||
t.Errorf("Expected collection bp_compliance_recht in URL, got: %s", requestedURL)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSearchCollection_FallbackDefault(t *testing.T) {
|
||||
var requestedURL string
|
||||
|
||||
ollamaMock := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
json.NewEncoder(w).Encode(ollamaEmbeddingResponse{
|
||||
Embedding: make([]float64, 1024),
|
||||
})
|
||||
}))
|
||||
defer ollamaMock.Close()
|
||||
|
||||
qdrantMock := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
requestedURL = r.URL.Path
|
||||
json.NewEncoder(w).Encode(qdrantSearchResponse{
|
||||
Result: []qdrantSearchHit{},
|
||||
})
|
||||
}))
|
||||
defer qdrantMock.Close()
|
||||
|
||||
qdrantAddr := strings.TrimPrefix(qdrantMock.URL, "http://")
|
||||
parts := strings.Split(qdrantAddr, ":")
|
||||
|
||||
client := &LegalRAGClient{
|
||||
qdrantHost: parts[0],
|
||||
qdrantPort: parts[1],
|
||||
ollamaURL: ollamaMock.URL,
|
||||
embeddingModel: "bge-m3",
|
||||
collection: "bp_compliance_ce",
|
||||
httpClient: http.DefaultClient,
|
||||
}
|
||||
|
||||
// Test with empty collection (should fall back to default)
|
||||
_, err := client.SearchCollection(context.Background(), "", "test query", nil, 3)
|
||||
if err != nil {
|
||||
t.Fatalf("SearchCollection failed: %v", err)
|
||||
}
|
||||
|
||||
if !strings.Contains(requestedURL, "/collections/bp_compliance_ce/") {
|
||||
t.Errorf("Expected default collection bp_compliance_ce in URL, got: %s", requestedURL)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSearch_StillWorks(t *testing.T) {
|
||||
var requestedURL string
|
||||
|
||||
ollamaMock := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
json.NewEncoder(w).Encode(ollamaEmbeddingResponse{
|
||||
Embedding: make([]float64, 1024),
|
||||
})
|
||||
}))
|
||||
defer ollamaMock.Close()
|
||||
|
||||
qdrantMock := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
requestedURL = r.URL.Path
|
||||
json.NewEncoder(w).Encode(qdrantSearchResponse{
|
||||
Result: []qdrantSearchHit{
|
||||
{
|
||||
ID: "1",
|
||||
Score: 0.95,
|
||||
Payload: map[string]interface{}{
|
||||
"chunk_text": "Test content",
|
||||
"regulation_id": "eu_2016_679",
|
||||
"regulation_name_de": "DSGVO",
|
||||
"regulation_short": "DSGVO",
|
||||
"category": "regulation",
|
||||
"source": "https://example.com",
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
}))
|
||||
defer qdrantMock.Close()
|
||||
|
||||
qdrantAddr := strings.TrimPrefix(qdrantMock.URL, "http://")
|
||||
parts := strings.Split(qdrantAddr, ":")
|
||||
|
||||
client := &LegalRAGClient{
|
||||
qdrantHost: parts[0],
|
||||
qdrantPort: parts[1],
|
||||
ollamaURL: ollamaMock.URL,
|
||||
embeddingModel: "bge-m3",
|
||||
collection: "bp_compliance_ce",
|
||||
httpClient: http.DefaultClient,
|
||||
}
|
||||
|
||||
results, err := client.Search(context.Background(), "DSGVO Art. 35", nil, 5)
|
||||
if err != nil {
|
||||
t.Fatalf("Search failed: %v", err)
|
||||
}
|
||||
|
||||
if len(results) != 1 {
|
||||
t.Fatalf("Expected 1 result, got %d", len(results))
|
||||
}
|
||||
|
||||
if results[0].RegulationCode != "eu_2016_679" {
|
||||
t.Errorf("Expected regulation_code eu_2016_679, got %s", results[0].RegulationCode)
|
||||
}
|
||||
|
||||
if !strings.Contains(requestedURL, "/collections/bp_compliance_ce/") {
|
||||
t.Errorf("Expected default collection in URL, got: %s", requestedURL)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user