From 899e22a31b54d9920e3217bccd922e4fde7a194f Mon Sep 17 00:00:00 2001 From: Benjamin Boenisch Date: Tue, 17 Feb 2026 23:44:47 +0100 Subject: [PATCH] feat(rag): connect bp_compliance_ce vector corpus to SDK - Switch LegalRAGClient from empty bp_legal_corpus to bp_compliance_ce collection (3,734 chunks across 14 regulations) - Replace embedding-service (384-dim MiniLM) with Ollama bge-m3 (1024-dim) - Add standalone RAG search endpoint: POST /sdk/v1/rag/search - Add regulations list endpoint: GET /sdk/v1/rag/regulations - Add QDRANT_HOST/PORT env vars to docker-compose.yml - Update regulation ID mapping to match actual Qdrant payload schema - Update determineRelevantRegulations for CE corpus regulation IDs Co-Authored-By: Claude Opus 4.6 --- ai-compliance-sdk/cmd/server/main.go | 11 + .../internal/api/handlers/rag_handlers.go | 76 +++++ ai-compliance-sdk/internal/ucca/legal_rag.go | 263 ++++++++++++------ docker-compose.yml | 2 + 4 files changed, 260 insertions(+), 92 deletions(-) create mode 100644 ai-compliance-sdk/internal/api/handlers/rag_handlers.go diff --git a/ai-compliance-sdk/cmd/server/main.go b/ai-compliance-sdk/cmd/server/main.go index 80f3788..2c98b9b 100644 --- a/ai-compliance-sdk/cmd/server/main.go +++ b/ai-compliance-sdk/cmd/server/main.go @@ -135,6 +135,10 @@ func main() { ttsClient := training.NewTTSClient(cfg.TTSServiceURL) contentGenerator := training.NewContentGenerator(providerRegistry, piiDetector, trainingStore, ttsClient) trainingHandlers := handlers.NewTrainingHandlers(trainingStore, contentGenerator) + + // Initialize RAG handlers + ragHandlers := handlers.NewRAGHandlers() + // Initialize middleware rbacMiddleware := rbac.NewMiddleware(rbacService, policyEngine) @@ -743,6 +747,13 @@ func main() { trainingRoutes.GET("/stats", trainingHandlers.GetStats) trainingRoutes.GET("/certificates/:id/verify", trainingHandlers.VerifyCertificate) } + + // RAG Search routes - Compliance Regulation Corpus + ragRoutes := v1.Group("/rag") + { + 
ragRoutes.POST("/search", ragHandlers.Search) + ragRoutes.GET("/regulations", ragHandlers.ListRegulations) + } } // Create HTTP server diff --git a/ai-compliance-sdk/internal/api/handlers/rag_handlers.go b/ai-compliance-sdk/internal/api/handlers/rag_handlers.go new file mode 100644 index 0000000..be3deb2 --- /dev/null +++ b/ai-compliance-sdk/internal/api/handlers/rag_handlers.go @@ -0,0 +1,76 @@ +package handlers + +import ( + "net/http" + + "github.com/breakpilot/ai-compliance-sdk/internal/ucca" + "github.com/gin-gonic/gin" +) + +// RAGHandlers handles RAG search API endpoints. +type RAGHandlers struct { + ragClient *ucca.LegalRAGClient +} + +// NewRAGHandlers creates new RAG handlers. +func NewRAGHandlers() *RAGHandlers { + return &RAGHandlers{ + ragClient: ucca.NewLegalRAGClient(), + } +} + +// SearchRequest represents a RAG search request. +type SearchRequest struct { + Query string `json:"query" binding:"required"` + Regulations []string `json:"regulations,omitempty"` + TopK int `json:"top_k,omitempty"` +} + +// Search performs a semantic search across the compliance regulation corpus. +// POST /sdk/v1/rag/search +func (h *RAGHandlers) Search(c *gin.Context) { + var req SearchRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + if req.TopK <= 0 || req.TopK > 20 { + req.TopK = 5 + } + + results, err := h.ragClient.Search(c.Request.Context(), req.Query, req.Regulations, req.TopK) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "RAG search failed: " + err.Error()}) + return + } + + c.JSON(http.StatusOK, gin.H{ + "query": req.Query, + "results": results, + "count": len(results), + }) +} + +// ListRegulations returns the list of available regulations in the corpus. 
+// GET /sdk/v1/rag/regulations +func (h *RAGHandlers) ListRegulations(c *gin.Context) { + regs := h.ragClient.ListAvailableRegulations() + + // Optionally filter by category + category := c.Query("category") + if category != "" { + filtered := make([]ucca.CERegulationInfo, 0) + for _, r := range regs { + if r.Category == category { + filtered = append(filtered, r) + } + } + regs = filtered + } + + c.JSON(http.StatusOK, gin.H{ + "regulations": regs, + "count": len(regs), + }) +} diff --git a/ai-compliance-sdk/internal/ucca/legal_rag.go b/ai-compliance-sdk/internal/ucca/legal_rag.go index 43e63ff..8c17eb8 100644 --- a/ai-compliance-sdk/internal/ucca/legal_rag.go +++ b/ai-compliance-sdk/internal/ucca/legal_rag.go @@ -12,36 +12,49 @@ import ( "time" ) -// LegalRAGClient provides access to the legal corpus vector search. +// LegalRAGClient provides access to the compliance CE vector search via Qdrant + Ollama bge-m3. type LegalRAGClient struct { - qdrantHost string - qdrantPort string - embeddingURL string - collection string - httpClient *http.Client + qdrantHost string + qdrantPort string + ollamaURL string + embeddingModel string + collection string + httpClient *http.Client } -// LegalSearchResult represents a single search result from the legal corpus. +// LegalSearchResult represents a single search result from the compliance corpus. 
type LegalSearchResult struct { - Text string `json:"text"` - RegulationCode string `json:"regulation_code"` - RegulationName string `json:"regulation_name"` - Article string `json:"article,omitempty"` - Paragraph string `json:"paragraph,omitempty"` - SourceURL string `json:"source_url"` - Score float64 `json:"score"` + Text string `json:"text"` + RegulationCode string `json:"regulation_code"` + RegulationName string `json:"regulation_name"` + RegulationShort string `json:"regulation_short"` + Category string `json:"category"` + Article string `json:"article,omitempty"` + Paragraph string `json:"paragraph,omitempty"` + Pages []int `json:"pages,omitempty"` + SourceURL string `json:"source_url"` + Score float64 `json:"score"` } // LegalContext represents aggregated legal context for an assessment. type LegalContext struct { - Query string `json:"query"` - Results []LegalSearchResult `json:"results"` - RelevantArticles []string `json:"relevant_articles"` - Regulations []string `json:"regulations"` - GeneratedAt time.Time `json:"generated_at"` + Query string `json:"query"` + Results []LegalSearchResult `json:"results"` + RelevantArticles []string `json:"relevant_articles"` + Regulations []string `json:"regulations"` + GeneratedAt time.Time `json:"generated_at"` } -// NewLegalRAGClient creates a new Legal RAG client. +// CERegulationInfo describes an available regulation in the corpus. +type CERegulationInfo struct { + ID string `json:"id"` + NameDE string `json:"name_de"` + NameEN string `json:"name_en"` + Short string `json:"short"` + Category string `json:"category"` +} + +// NewLegalRAGClient creates a new Legal RAG client using Ollama bge-m3 embeddings. 
func NewLegalRAGClient() *LegalRAGClient { qdrantHost := os.Getenv("QDRANT_HOST") if qdrantHost == "" { @@ -53,33 +66,40 @@ func NewLegalRAGClient() *LegalRAGClient { qdrantPort = "6333" } - embeddingURL := os.Getenv("EMBEDDING_SERVICE_URL") - if embeddingURL == "" { - embeddingURL = "http://localhost:8087" + ollamaURL := os.Getenv("OLLAMA_URL") + if ollamaURL == "" { + ollamaURL = "http://localhost:11434" } return &LegalRAGClient{ - qdrantHost: qdrantHost, - qdrantPort: qdrantPort, - embeddingURL: embeddingURL, - collection: "bp_legal_corpus", + qdrantHost: qdrantHost, + qdrantPort: qdrantPort, + ollamaURL: ollamaURL, + embeddingModel: "bge-m3", + collection: "bp_compliance_ce", httpClient: &http.Client{ - Timeout: 30 * time.Second, + Timeout: 60 * time.Second, }, } } -// embeddingResponse from the embedding service. -type embeddingResponse struct { - Embeddings [][]float64 `json:"embeddings"` +// ollamaEmbeddingRequest for Ollama embedding API. +type ollamaEmbeddingRequest struct { + Model string `json:"model"` + Prompt string `json:"prompt"` +} + +// ollamaEmbeddingResponse from Ollama embedding API. +type ollamaEmbeddingResponse struct { + Embedding []float64 `json:"embedding"` } // qdrantSearchRequest for Qdrant REST API. type qdrantSearchRequest struct { - Vector []float64 `json:"vector"` - Limit int `json:"limit"` - WithPayload bool `json:"with_payload"` - Filter *qdrantFilter `json:"filter,omitempty"` + Vector []float64 `json:"vector"` + Limit int `json:"limit"` + WithPayload bool `json:"with_payload"` + Filter *qdrantFilter `json:"filter,omitempty"` } type qdrantFilter struct { @@ -102,15 +122,21 @@ type qdrantSearchResponse struct { } type qdrantSearchHit struct { - ID string `json:"id"` + ID interface{} `json:"id"` Score float64 `json:"score"` Payload map[string]interface{} `json:"payload"` } -// generateEmbedding calls the embedding service to get a vector for the query. +// generateEmbedding calls Ollama bge-m3 to get a 1024-dim vector for the query. 
func (c *LegalRAGClient) generateEmbedding(ctx context.Context, text string) ([]float64, error) { - reqBody := map[string]interface{}{ - "texts": []string{text}, + // Truncate to 2000 bytes for bge-m3 + if len(text) > 2000 { + text = text[:2000] + } + + reqBody := ollamaEmbeddingRequest{ + Model: c.embeddingModel, + Prompt: text, } jsonBody, err := json.Marshal(reqBody) @@ -118,7 +144,7 @@ func (c *LegalRAGClient) generateEmbedding(ctx context.Context, text string) ([] return nil, fmt.Errorf("failed to marshal embedding request: %w", err) } - req, err := http.NewRequestWithContext(ctx, "POST", c.embeddingURL+"/embed", bytes.NewReader(jsonBody)) + req, err := http.NewRequestWithContext(ctx, "POST", c.ollamaURL+"/api/embeddings", bytes.NewReader(jsonBody)) if err != nil { return nil, fmt.Errorf("failed to create embedding request: %w", err) } @@ -132,24 +158,24 @@ func (c *LegalRAGClient) generateEmbedding(ctx context.Context, text string) ([] if resp.StatusCode != http.StatusOK { body, _ := io.ReadAll(resp.Body) - return nil, fmt.Errorf("embedding service returned %d: %s", resp.StatusCode, string(body)) + return nil, fmt.Errorf("ollama returned %d: %s", resp.StatusCode, string(body)) } - var embResp embeddingResponse + var embResp ollamaEmbeddingResponse if err := json.NewDecoder(resp.Body).Decode(&embResp); err != nil { return nil, fmt.Errorf("failed to decode embedding response: %w", err) } - if len(embResp.Embeddings) == 0 { - return nil, fmt.Errorf("no embeddings returned") + if len(embResp.Embedding) == 0 { + return nil, fmt.Errorf("no embedding returned from ollama") } - return embResp.Embeddings[0], nil + return embResp.Embedding, nil } -// Search queries the legal corpus for relevant passages. -func (c *LegalRAGClient) Search(ctx context.Context, query string, regulationCodes []string, topK int) ([]LegalSearchResult, error) { - // Generate query embedding +// Search queries the compliance CE corpus for relevant passages. 
+func (c *LegalRAGClient) Search(ctx context.Context, query string, regulationIDs []string, topK int) ([]LegalSearchResult, error) { + // Generate query embedding via Ollama bge-m3 embedding, err := c.generateEmbedding(ctx, query) if err != nil { return nil, fmt.Errorf("failed to generate embedding: %w", err) @@ -163,12 +189,12 @@ func (c *LegalRAGClient) Search(ctx context.Context, query string, regulationCod } // Add filter for specific regulations if provided - if len(regulationCodes) > 0 { - conditions := make([]qdrantCondition, len(regulationCodes)) - for i, code := range regulationCodes { + if len(regulationIDs) > 0 { + conditions := make([]qdrantCondition, len(regulationIDs)) + for i, regID := range regulationIDs { conditions[i] = qdrantCondition{ - Key: "regulation_code", - Match: qdrantMatch{Value: code}, + Key: "regulation_id", + Match: qdrantMatch{Value: regID}, } } searchReq.Filter = &qdrantFilter{Should: conditions} @@ -203,17 +229,18 @@ func (c *LegalRAGClient) Search(ctx context.Context, query string, regulationCod return nil, fmt.Errorf("failed to decode search response: %w", err) } - // Convert to results + // Convert to results using bp_compliance_ce payload schema results := make([]LegalSearchResult, len(searchResp.Result)) for i, hit := range searchResp.Result { results[i] = LegalSearchResult{ - Text: getString(hit.Payload, "text"), - RegulationCode: getString(hit.Payload, "regulation_code"), - RegulationName: getString(hit.Payload, "regulation_name"), - Article: getString(hit.Payload, "article"), - Paragraph: getString(hit.Payload, "paragraph"), - SourceURL: getString(hit.Payload, "source_url"), - Score: hit.Score, + Text: getString(hit.Payload, "chunk_text"), + RegulationCode: getString(hit.Payload, "regulation_id"), + RegulationName: getString(hit.Payload, "regulation_name_de"), + RegulationShort: getString(hit.Payload, "regulation_short"), + Category: getString(hit.Payload, "category"), + Pages: getIntSlice(hit.Payload, "pages"), + 
SourceURL: getString(hit.Payload, "source"), + Score: hit.Score, } } @@ -267,36 +294,34 @@ func (c *LegalRAGClient) GetLegalContextForAssessment(ctx context.Context, asses } // Determine which regulations to search based on triggered rules - regulationCodes := c.determineRelevantRegulations(assessment) + regulationIDs := c.determineRelevantRegulations(assessment) - // Search legal corpus - results, err := c.Search(ctx, query, regulationCodes, 5) + // Search compliance corpus + results, err := c.Search(ctx, query, regulationIDs, 5) if err != nil { return nil, err } - // Extract unique articles and regulations - articleSet := make(map[string]bool) + // Extract unique regulations regSet := make(map[string]bool) - for _, r := range results { - if r.Article != "" { - key := fmt.Sprintf("%s Art. %s", r.RegulationCode, r.Article) - articleSet[key] = true - } regSet[r.RegulationCode] = true } - articles := make([]string, 0, len(articleSet)) - for a := range articleSet { - articles = append(articles, a) - } - regulations := make([]string, 0, len(regSet)) for r := range regSet { regulations = append(regulations, r) } + // Build relevant articles from page references + articles := make([]string, 0) + for _, r := range results { + if len(r.Pages) > 0 { + key := fmt.Sprintf("%s S. %v", r.RegulationShort, r.Pages) + articles = append(articles, key) + } + } + return &LegalContext{ Query: query, Results: results, @@ -308,38 +333,77 @@ func (c *LegalRAGClient) GetLegalContextForAssessment(ctx context.Context, asses // determineRelevantRegulations determines which regulations to search based on the assessment. 
func (c *LegalRAGClient) determineRelevantRegulations(assessment *Assessment) []string { - codes := []string{"GDPR"} // Always include GDPR + ids := []string{"eu_2016_679"} // Always include GDPR // Check triggered rules for regulation hints for _, rule := range assessment.TriggeredRules { gdprRef := rule.GDPRRef if strings.Contains(gdprRef, "AI Act") || strings.Contains(gdprRef, "KI-VO") { - codes = append(codes, "AIACT") + if !contains(ids, "eu_2024_1689") { + ids = append(ids, "eu_2024_1689") + } } - if strings.Contains(gdprRef, "Art. 9") || strings.Contains(gdprRef, "Art. 22") { - // Already have GDPR + if strings.Contains(gdprRef, "NIS2") || strings.Contains(gdprRef, "NIS-2") { + if !contains(ids, "eu_2022_2555") { + ids = append(ids, "eu_2022_2555") + } + } + if strings.Contains(gdprRef, "CRA") || strings.Contains(gdprRef, "Cyber Resilience") { + if !contains(ids, "eu_2024_2847") { + ids = append(ids, "eu_2024_2847") + } + } + if strings.Contains(gdprRef, "Maschinenverordnung") || strings.Contains(gdprRef, "Machinery") { + if !contains(ids, "eu_2023_1230") { + ids = append(ids, "eu_2023_1230") + } } } // Add AI Act if AI-related controls are required for _, ctrl := range assessment.RequiredControls { if strings.HasPrefix(ctrl.ID, "AI-") { - if !contains(codes, "AIACT") { - codes = append(codes, "AIACT") + if !contains(ids, "eu_2024_1689") { + ids = append(ids, "eu_2024_1689") } break } } - // Add BSI if security controls are required + // Add CRA/NIS2 if security controls are required for _, ctrl := range assessment.RequiredControls { - if strings.HasPrefix(ctrl.ID, "CRYPTO-") || strings.HasPrefix(ctrl.ID, "IAM-") { - codes = append(codes, "BSI-TR-03161-1") + if strings.HasPrefix(ctrl.ID, "CRYPTO-") || strings.HasPrefix(ctrl.ID, "IAM-") || strings.HasPrefix(ctrl.ID, "SEC-") { + if !contains(ids, "eu_2022_2555") { + ids = append(ids, "eu_2022_2555") + } + if !contains(ids, "eu_2024_2847") { + ids = append(ids, "eu_2024_2847") + } break } } - return codes + 
return ids +} + +// ListAvailableRegulations returns the list of regulations available in the corpus. +func (c *LegalRAGClient) ListAvailableRegulations() []CERegulationInfo { + return []CERegulationInfo{ + CERegulationInfo{ID: "eu_2023_1230", NameDE: "EU-Maschinenverordnung 2023/1230", NameEN: "EU Machinery Regulation 2023/1230", Short: "Maschinenverordnung", Category: "regulation"}, + CERegulationInfo{ID: "eu_2024_1689", NameDE: "EU KI-Verordnung (AI Act)", NameEN: "EU AI Act 2024/1689", Short: "AI Act", Category: "regulation"}, + CERegulationInfo{ID: "eu_2024_2847", NameDE: "Cyber Resilience Act", NameEN: "Cyber Resilience Act 2024/2847", Short: "CRA", Category: "regulation"}, + CERegulationInfo{ID: "eu_2022_2555", NameDE: "NIS-2-Richtlinie", NameEN: "NIS2 Directive 2022/2555", Short: "NIS2", Category: "regulation"}, + CERegulationInfo{ID: "eu_2016_679", NameDE: "Datenschutz-Grundverordnung (DSGVO)", NameEN: "General Data Protection Regulation (GDPR)", Short: "DSGVO/GDPR", Category: "regulation"}, + CERegulationInfo{ID: "eu_blue_guide_2022", NameDE: "EU Blue Guide 2022", NameEN: "EU Blue Guide 2022", Short: "Blue Guide", Category: "guidance"}, + CERegulationInfo{ID: "nist_sp_800_218", NameDE: "NIST Secure Software Development Framework", NameEN: "NIST SSDF SP 800-218", Short: "NIST SSDF", Category: "guidance"}, + CERegulationInfo{ID: "nist_csf_2_0", NameDE: "NIST Cybersecurity Framework 2.0", NameEN: "NIST CSF 2.0", Short: "NIST CSF", Category: "guidance"}, + CERegulationInfo{ID: "oecd_ai_principles", NameDE: "OECD Empfehlung zu Kuenstlicher Intelligenz", NameEN: "OECD Recommendation on AI", Short: "OECD AI", Category: "guidance"}, + CERegulationInfo{ID: "enisa_supply_chain_good_practices", NameDE: "ENISA Supply Chain Cybersecurity", NameEN: "ENISA Good Practices for Supply Chain Cybersecurity", Short: "ENISA Supply Chain", Category: "guidance"}, + CERegulationInfo{ID: "enisa_threat_landscape_supply_chain", NameDE: "ENISA Threat Landscape Supply Chain", NameEN: 
"ENISA Threat Landscape for Supply Chain Attacks", Short: "ENISA Threat SC", Category: "guidance"}, + CERegulationInfo{ID: "enisa_ics_scada_dependencies", NameDE: "ENISA ICS/SCADA Abhaengigkeiten", NameEN: "ENISA ICS/SCADA Communication Dependencies", Short: "ENISA ICS/SCADA", Category: "guidance"}, + CERegulationInfo{ID: "cisa_secure_by_design", NameDE: "CISA Secure by Design", NameEN: "CISA Secure by Design", Short: "CISA SbD", Category: "guidance"}, + CERegulationInfo{ID: "enisa_cybersecurity_state_2024", NameDE: "ENISA State of Cybersecurity 2024", NameEN: "ENISA State of Cybersecurity in the Union 2024", Short: "ENISA 2024", Category: "guidance"}, + } } // FormatLegalContextForPrompt formats the legal context for inclusion in an LLM prompt. @@ -352,12 +416,9 @@ func (c *LegalRAGClient) FormatLegalContextForPrompt(lc *LegalContext) string { buf.WriteString("\n\n**Relevante Rechtsgrundlagen:**\n\n") for i, result := range lc.Results { - buf.WriteString(fmt.Sprintf("%d. **%s** (%s)", i+1, result.RegulationName, result.RegulationCode)) - if result.Article != "" { - buf.WriteString(fmt.Sprintf(" - Art. %s", result.Article)) - if result.Paragraph != "" { - buf.WriteString(fmt.Sprintf(" Abs. %s", result.Paragraph)) - } + buf.WriteString(fmt.Sprintf("%d. 
**%s** (%s)", i+1, result.RegulationShort, result.RegulationCode)) + if len(result.Pages) > 0 { + buf.WriteString(fmt.Sprintf(" - Seiten %v", result.Pages)) } buf.WriteString("\n") buf.WriteString(fmt.Sprintf(" > %s\n\n", truncateText(result.Text, 300))) @@ -377,6 +438,24 @@ func getString(m map[string]interface{}, key string) string { return "" } +func getIntSlice(m map[string]interface{}, key string) []int { + v, ok := m[key] + if !ok { + return nil + } + arr, ok := v.([]interface{}) + if !ok { + return nil + } + result := make([]int, 0, len(arr)) + for _, item := range arr { + if f, ok := item.(float64); ok { + result = append(result, int(f)) + } + } + return result +} + func contains(slice []string, item string) bool { for _, s := range slice { if s == item { diff --git a/docker-compose.yml b/docker-compose.yml index 799b1ed..35e084c 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -148,6 +148,8 @@ services: AUDIT_LOG_PROMPTS: ${AUDIT_LOG_PROMPTS:-true} ALLOWED_ORIGINS: "*" TTS_SERVICE_URL: http://compliance-tts-service:8095 + QDRANT_HOST: bp-core-qdrant + QDRANT_PORT: "6333" extra_hosts: - "host.docker.internal:host-gateway" depends_on: