Files
breakpilot-compliance/ai-compliance-sdk/internal/ucca/legal_rag.go
Benjamin Admin 14a99322eb
All checks were successful
CI / go-lint (push) Has been skipped
CI / python-lint (push) Has been skipped
CI / nodejs-lint (push) Has been skipped
CI / test-go-ai-compliance (push) Successful in 35s
CI / test-python-backend-compliance (push) Successful in 26s
CI / test-python-document-crawler (push) Successful in 22s
CI / test-python-dsms-gateway (push) Successful in 19s
feat: Phase 2 — RAG integration in Requirements + DSFA Draft
Add legal context enrichment from Qdrant vector corpus to the two
highest-priority modules (Requirements AI assistant and DSFA drafting
engine).

Go SDK:
- Add SearchCollection() with collection override + whitelist validation
- Refactor Search() to delegate to shared searchInternal()

Python backend:
- New ComplianceRAGClient proxying POST /sdk/v1/rag/search (error-tolerant)
- AI assistant: enrich interpret_requirement() and suggest_controls() with RAG
- Requirements API: add ?include_legal_context=true query parameter

Admin (Next.js):
- Extract shared queryRAG() utility from chat route
- Inject RAG legal context into v1 and v2 draft pipelines

Tests for all three layers (Go, Python, TypeScript shared utility).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-02 08:57:39 +01:00

488 lines
16 KiB
Go

package ucca
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"os"
"strings"
"time"
)
// LegalRAGClient provides access to the compliance CE vector search via Qdrant + Ollama bge-m3.
// All settings are unexported; NewLegalRAGClient fills them from the environment.
type LegalRAGClient struct {
	// qdrantHost/qdrantPort locate the Qdrant REST API, ollamaURL is the Ollama
	// base URL, embeddingModel is the Ollama model used for query embeddings and
	// collection is the default Qdrant collection searched by Search.
	qdrantHost, qdrantPort, ollamaURL, embeddingModel, collection string

	// httpClient is used for every Ollama and Qdrant request.
	httpClient *http.Client
}
// LegalSearchResult represents a single search result from the compliance corpus.
// In this file it is produced by searchInternal, which copies the fields from the
// Qdrant hit payload (keys noted per field); Article and Paragraph are not
// populated there.
type LegalSearchResult struct {
	Text            string  `json:"text"`             // payload "chunk_text"
	RegulationCode  string  `json:"regulation_code"`  // payload "regulation_id"
	RegulationName  string  `json:"regulation_name"`  // payload "regulation_name_de"
	RegulationShort string  `json:"regulation_short"` // payload "regulation_short"
	Category        string  `json:"category"`         // payload "category"
	Article         string  `json:"article,omitempty"`
	Paragraph       string  `json:"paragraph,omitempty"`
	Pages           []int   `json:"pages,omitempty"` // payload "pages" (JSON numbers truncated to int)
	SourceURL       string  `json:"source_url"`      // payload "source"
	Score           float64 `json:"score"`           // similarity score reported by Qdrant for the hit
}
// LegalContext represents aggregated legal context for an assessment, as built
// by GetLegalContextForAssessment.
type LegalContext struct {
	Query            string              `json:"query"`             // free-text query that was embedded and searched
	Results          []LegalSearchResult `json:"results"`           // hits returned by the corpus search
	RelevantArticles []string            `json:"relevant_articles"` // "<RegulationShort> S. [pages]" for hits that carry page references
	Regulations      []string            `json:"regulations"`       // distinct RegulationCode values found in Results
	GeneratedAt      time.Time           `json:"generated_at"`      // creation timestamp, set in UTC
}
// CERegulationInfo describes an available regulation or guidance document in
// the compliance CE corpus, as listed by ListAvailableRegulations.
type CERegulationInfo struct {
	ID       string `json:"id"`       // corpus identifier, matched against the "regulation_id" payload field
	NameDE   string `json:"name_de"`  // German display name
	NameEN   string `json:"name_en"`  // English display name
	Short    string `json:"short"`    // short label, e.g. "AI Act"
	Category string `json:"category"` // "regulation" or "guidance" in the built-in list
}
// NewLegalRAGClient creates a new Legal RAG client using Ollama bge-m3 embeddings.
//
// Connection settings come from the environment, falling back to local defaults
// when a variable is unset or empty:
//   - QDRANT_HOST (default "localhost")
//   - QDRANT_PORT (default "6333")
//   - OLLAMA_URL  (default "http://localhost:11434")
//
// The embedding model is fixed to "bge-m3", the default collection to
// "bp_compliance_ce", and the shared HTTP client uses a 60-second timeout.
func NewLegalRAGClient() *LegalRAGClient {
	envOrDefault := func(name, fallback string) string {
		if value := os.Getenv(name); value != "" {
			return value
		}
		return fallback
	}

	return &LegalRAGClient{
		qdrantHost:     envOrDefault("QDRANT_HOST", "localhost"),
		qdrantPort:     envOrDefault("QDRANT_PORT", "6333"),
		ollamaURL:      envOrDefault("OLLAMA_URL", "http://localhost:11434"),
		embeddingModel: "bge-m3",
		collection:     "bp_compliance_ce",
		httpClient:     &http.Client{Timeout: 60 * time.Second},
	}
}
// Wire types for the Ollama embeddings endpoint (POST <ollamaURL>/api/embeddings).
type (
	// ollamaEmbeddingRequest is the request body: model name plus the text to embed.
	ollamaEmbeddingRequest struct {
		Model  string `json:"model"`
		Prompt string `json:"prompt"`
	}

	// ollamaEmbeddingResponse is the decoded response carrying the embedding vector.
	ollamaEmbeddingResponse struct {
		Embedding []float64 `json:"embedding"`
	}
)
// Wire types for the Qdrant REST search endpoint
// (POST /collections/{name}/points/search).
type (
	// qdrantSearchRequest is the request body: query vector, maximum number of
	// hits, whether payloads are returned, and an optional filter that is
	// omitted from the JSON when nil.
	qdrantSearchRequest struct {
		Vector      []float64     `json:"vector"`
		Limit       int           `json:"limit"`
		WithPayload bool          `json:"with_payload"`
		Filter      *qdrantFilter `json:"filter,omitempty"`
	}

	// qdrantFilter groups payload conditions into Qdrant "should" / "must" clauses.
	qdrantFilter struct {
		Should []qdrantCondition `json:"should,omitempty"`
		Must   []qdrantCondition `json:"must,omitempty"`
	}

	// qdrantCondition matches one payload key against a value.
	qdrantCondition struct {
		Key   string      `json:"key"`
		Match qdrantMatch `json:"match"`
	}

	// qdrantMatch is an exact-value match on a string payload field.
	qdrantMatch struct {
		Value string `json:"value"`
	}

	// qdrantSearchResponse is the decoded search response.
	qdrantSearchResponse struct {
		Result []qdrantSearchHit `json:"result"`
	}

	// qdrantSearchHit is a single scored point with its (generic JSON) payload.
	qdrantSearchHit struct {
		ID      interface{}            `json:"id"`
		Score   float64                `json:"score"`
		Payload map[string]interface{} `json:"payload"`
	}
)
// generateEmbedding asks Ollama (POST <ollamaURL>/api/embeddings) to embed text
// with c.embeddingModel and returns the resulting vector.
//
// Input longer than 2000 characters is truncated first. An error is returned
// when the request cannot be built or sent, when Ollama answers with a non-200
// status, when the body cannot be decoded, or when the embedding is empty.
func (c *LegalRAGClient) generateEmbedding(ctx context.Context, text string) ([]float64, error) {
	// Truncate to 2000 characters for bge-m3. Cut on a rune boundary: slicing
	// at a raw byte offset can split a multi-byte UTF-8 character (umlauts,
	// "§", ...) and send a corrupted trailing character to the model.
	const maxChars = 2000
	if len(text) > maxChars { // fewer bytes than maxChars implies fewer runes too
		runes := 0
		for i := range text {
			if runes == maxChars {
				text = text[:i]
				break
			}
			runes++
		}
	}

	jsonBody, err := json.Marshal(ollamaEmbeddingRequest{
		Model:  c.embeddingModel,
		Prompt: text,
	})
	if err != nil {
		return nil, fmt.Errorf("failed to marshal embedding request: %w", err)
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.ollamaURL+"/api/embeddings", bytes.NewReader(jsonBody))
	if err != nil {
		return nil, fmt.Errorf("failed to create embedding request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")

	resp, err := c.httpClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("embedding request failed: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		// Best effort: the body only enriches the error message, so a read
		// failure is ignored. Bounded so a huge error page cannot bloat it.
		body, _ := io.ReadAll(io.LimitReader(resp.Body, 4096))
		return nil, fmt.Errorf("ollama returned %d: %s", resp.StatusCode, string(body))
	}

	var embResp ollamaEmbeddingResponse
	if err := json.NewDecoder(resp.Body).Decode(&embResp); err != nil {
		return nil, fmt.Errorf("failed to decode embedding response: %w", err)
	}
	if len(embResp.Embedding) == 0 {
		return nil, fmt.Errorf("no embedding returned from ollama")
	}
	return embResp.Embedding, nil
}
// SearchCollection runs the same corpus search as Search, but against the named
// Qdrant collection. An empty collection selects the client's default
// collection (bp_compliance_ce when built by NewLegalRAGClient).
func (c *LegalRAGClient) SearchCollection(ctx context.Context, collection string, query string, regulationIDs []string, topK int) ([]LegalSearchResult, error) {
	target := c.collection
	if collection != "" {
		target = collection
	}
	return c.searchInternal(ctx, target, query, regulationIDs, topK)
}
// Search queries the client's default collection (the compliance CE corpus)
// for passages relevant to query. It is SearchCollection with an empty
// collection name, which resolves to that default.
func (c *LegalRAGClient) Search(ctx context.Context, query string, regulationIDs []string, topK int) ([]LegalSearchResult, error) {
	return c.SearchCollection(ctx, "", query, regulationIDs, topK)
}
// searchInternal embeds query via Ollama and runs a vector search against the
// given Qdrant collection, returning at most topK hits mapped to
// LegalSearchResult (payload fields of the bp_compliance_ce schema).
//
// When regulationIDs is non-empty the search is filtered with a Qdrant
// "should" clause on the "regulation_id" payload key, i.e. a point passes if it
// matches any of the given IDs.
//
// The collection name is validated before any network call because it is
// interpolated into the request path.
func (c *LegalRAGClient) searchInternal(ctx context.Context, collection string, query string, regulationIDs []string, topK int) ([]LegalSearchResult, error) {
	// Reject names that would change the request target: path separators and
	// dot segments (reaching other Qdrant endpoints), and '?', '#', '%'
	// (query/fragment injection or pre-escaped sequences such as "%2F").
	if collection == "" || collection == "." || collection == ".." ||
		strings.ContainsAny(collection, `/\?#%`) {
		return nil, fmt.Errorf("invalid collection name %q", collection)
	}

	// Generate query embedding via Ollama.
	embedding, err := c.generateEmbedding(ctx, query)
	if err != nil {
		return nil, fmt.Errorf("failed to generate embedding: %w", err)
	}

	searchReq := qdrantSearchRequest{
		Vector:      embedding,
		Limit:       topK,
		WithPayload: true,
	}
	if len(regulationIDs) > 0 {
		conditions := make([]qdrantCondition, len(regulationIDs))
		for i, regID := range regulationIDs {
			conditions[i] = qdrantCondition{
				Key:   "regulation_id",
				Match: qdrantMatch{Value: regID},
			}
		}
		searchReq.Filter = &qdrantFilter{Should: conditions}
	}

	jsonBody, err := json.Marshal(searchReq)
	if err != nil {
		return nil, fmt.Errorf("failed to marshal search request: %w", err)
	}

	endpoint := fmt.Sprintf("http://%s:%s/collections/%s/points/search", c.qdrantHost, c.qdrantPort, collection)
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(jsonBody))
	if err != nil {
		return nil, fmt.Errorf("failed to create search request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")

	resp, err := c.httpClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("search request failed: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		// Best effort: the body only enriches the error message, so a read
		// failure is ignored. Bounded so a huge error page cannot bloat it.
		body, _ := io.ReadAll(io.LimitReader(resp.Body, 4096))
		return nil, fmt.Errorf("qdrant returned %d: %s", resp.StatusCode, string(body))
	}

	var searchResp qdrantSearchResponse
	if err := json.NewDecoder(resp.Body).Decode(&searchResp); err != nil {
		return nil, fmt.Errorf("failed to decode search response: %w", err)
	}

	// Map hits using the bp_compliance_ce payload schema. Article and
	// Paragraph are not part of the mapping.
	results := make([]LegalSearchResult, len(searchResp.Result))
	for i, hit := range searchResp.Result {
		results[i] = LegalSearchResult{
			Text:            getString(hit.Payload, "chunk_text"),
			RegulationCode:  getString(hit.Payload, "regulation_id"),
			RegulationName:  getString(hit.Payload, "regulation_name_de"),
			RegulationShort: getString(hit.Payload, "regulation_short"),
			Category:        getString(hit.Payload, "category"),
			Pages:           getIntSlice(hit.Payload, "pages"),
			SourceURL:       getString(hit.Payload, "source"),
			Score:           hit.Score,
		}
	}
	return results, nil
}
// GetLegalContextForAssessment retrieves relevant legal context for an assessment.
//
// It builds a German free-text query from the assessment's domain, intake data
// types and purposes, and risk flags. If no flag contributes, it falls back to a
// generic GDPR query. It then searches the default corpus (top 5 hits),
// restricted to the regulations chosen by determineRelevantRegulations.
//
// It returns an error when assessment is nil or when the search fails.
func (c *LegalRAGClient) GetLegalContextForAssessment(ctx context.Context, assessment *Assessment) (*LegalContext, error) {
	if assessment == nil {
		return nil, fmt.Errorf("assessment is nil")
	}

	queryParts := []string{}

	// Domain context.
	if assessment.Domain != "" {
		queryParts = append(queryParts, fmt.Sprintf("KI-Anwendung im Bereich %s", assessment.Domain))
	}

	// Data type context.
	if assessment.Intake.DataTypes.Article9Data {
		queryParts = append(queryParts, "besondere Kategorien personenbezogener Daten Art. 9 DSGVO")
	}
	if assessment.Intake.DataTypes.PersonalData {
		queryParts = append(queryParts, "personenbezogene Daten")
	}
	if assessment.Intake.DataTypes.MinorData {
		queryParts = append(queryParts, "Daten von Minderjährigen")
	}

	// Purpose context.
	if assessment.Intake.Purpose.EvaluationScoring {
		queryParts = append(queryParts, "automatisierte Bewertung Scoring")
	}
	if assessment.Intake.Purpose.DecisionMaking {
		queryParts = append(queryParts, "automatisierte Entscheidung Art. 22 DSGVO")
	}
	if assessment.Intake.Purpose.Profiling {
		queryParts = append(queryParts, "Profiling")
	}

	// Risk-specific context.
	if assessment.DSFARecommended {
		queryParts = append(queryParts, "Datenschutz-Folgenabschätzung Art. 35 DSGVO")
	}
	if assessment.Art22Risk {
		queryParts = append(queryParts, "automatisierte Einzelentscheidung rechtliche Wirkung")
	}

	query := strings.Join(queryParts, " ")
	if query == "" {
		query = "DSGVO Anforderungen KI-System Datenschutz"
	}

	// Restrict the search to regulations suggested by triggered rules/controls.
	regulationIDs := c.determineRelevantRegulations(assessment)

	results, err := c.Search(ctx, query, regulationIDs, 5)
	if err != nil {
		return nil, err
	}

	// Distinct regulation codes in first-seen order. Results come back in
	// Qdrant's order (presumably descending score), so the list is
	// deterministic, unlike iterating a map. Hits without a regulation_id
	// payload are skipped rather than reported as "".
	regulations := make([]string, 0, len(results))
	seen := make(map[string]bool, len(results))
	for _, r := range results {
		if r.RegulationCode == "" || seen[r.RegulationCode] {
			continue
		}
		seen[r.RegulationCode] = true
		regulations = append(regulations, r.RegulationCode)
	}

	// Page references rendered as "<short> S. [p1 p2 ...]".
	articles := make([]string, 0)
	for _, r := range results {
		if len(r.Pages) > 0 {
			articles = append(articles, fmt.Sprintf("%s S. %v", r.RegulationShort, r.Pages))
		}
	}

	return &LegalContext{
		Query:            query,
		Results:          results,
		RelevantArticles: articles,
		Regulations:      regulations,
		GeneratedAt:      time.Now().UTC(),
	}, nil
}
// determineRelevantRegulations picks the corpus regulation IDs to filter on for
// an assessment. GDPR (eu_2016_679) is always first. Further IDs are appended,
// without duplicates, in this order:
//  1. for each triggered rule, keyword hints found in its GDPRRef
//     (AI Act, NIS2, CRA, Machinery Regulation, checked in that order);
//  2. the AI Act when any required control ID starts with "AI-";
//  3. NIS2 and then the CRA when any required control ID starts with
//     "CRYPTO-", "IAM-" or "SEC-".
func (c *LegalRAGClient) determineRelevantRegulations(assessment *Assessment) []string {
	selected := []string{"eu_2016_679"} // Always include GDPR
	include := func(id string) {
		if !contains(selected, id) {
			selected = append(selected, id)
		}
	}

	// Substring hints searched in each rule's GDPRRef; a regulation is
	// included as soon as one of its keywords matches.
	hints := []struct {
		id       string
		keywords []string
	}{
		{id: "eu_2024_1689", keywords: []string{"AI Act", "KI-VO"}},
		{id: "eu_2022_2555", keywords: []string{"NIS2", "NIS-2"}},
		{id: "eu_2024_2847", keywords: []string{"CRA", "Cyber Resilience"}},
		{id: "eu_2023_1230", keywords: []string{"Maschinenverordnung", "Machinery"}},
	}
	for _, rule := range assessment.TriggeredRules {
		ref := rule.GDPRRef
		for _, hint := range hints {
			for _, kw := range hint.keywords {
				if strings.Contains(ref, kw) {
					include(hint.id)
					break
				}
			}
		}
	}

	// anyControlWithPrefix reports whether some required control ID starts
	// with one of the given prefixes.
	anyControlWithPrefix := func(prefixes ...string) bool {
		for _, ctrl := range assessment.RequiredControls {
			for _, p := range prefixes {
				if strings.HasPrefix(ctrl.ID, p) {
					return true
				}
			}
		}
		return false
	}

	// AI-related controls imply the AI Act.
	if anyControlWithPrefix("AI-") {
		include("eu_2024_1689")
	}
	// Security controls imply NIS2 and the CRA.
	if anyControlWithPrefix("CRYPTO-", "IAM-", "SEC-") {
		include("eu_2022_2555")
		include("eu_2024_2847")
	}

	return selected
}
// ListAvailableRegulations returns the built-in list of regulations (category
// "regulation") and guidance documents (category "guidance") available in the
// corpus. A fresh slice is returned on every call.
func (c *LegalRAGClient) ListAvailableRegulations() []CERegulationInfo {
	const (
		catRegulation = "regulation"
		catGuidance   = "guidance"
	)

	catalog := []CERegulationInfo{
		// Binding EU legislation.
		{ID: "eu_2023_1230", NameDE: "EU-Maschinenverordnung 2023/1230", NameEN: "EU Machinery Regulation 2023/1230", Short: "Maschinenverordnung", Category: catRegulation},
		{ID: "eu_2024_1689", NameDE: "EU KI-Verordnung (AI Act)", NameEN: "EU AI Act 2024/1689", Short: "AI Act", Category: catRegulation},
		{ID: "eu_2024_2847", NameDE: "Cyber Resilience Act", NameEN: "Cyber Resilience Act 2024/2847", Short: "CRA", Category: catRegulation},
		{ID: "eu_2022_2555", NameDE: "NIS-2-Richtlinie", NameEN: "NIS2 Directive 2022/2555", Short: "NIS2", Category: catRegulation},
		{ID: "eu_2016_679", NameDE: "Datenschutz-Grundverordnung (DSGVO)", NameEN: "General Data Protection Regulation (GDPR)", Short: "DSGVO/GDPR", Category: catRegulation},

		// Guidance and frameworks.
		{ID: "eu_blue_guide_2022", NameDE: "EU Blue Guide 2022", NameEN: "EU Blue Guide 2022", Short: "Blue Guide", Category: catGuidance},
		{ID: "nist_sp_800_218", NameDE: "NIST Secure Software Development Framework", NameEN: "NIST SSDF SP 800-218", Short: "NIST SSDF", Category: catGuidance},
		{ID: "nist_csf_2_0", NameDE: "NIST Cybersecurity Framework 2.0", NameEN: "NIST CSF 2.0", Short: "NIST CSF", Category: catGuidance},
		{ID: "oecd_ai_principles", NameDE: "OECD Empfehlung zu Kuenstlicher Intelligenz", NameEN: "OECD Recommendation on AI", Short: "OECD AI", Category: catGuidance},
		{ID: "enisa_supply_chain_good_practices", NameDE: "ENISA Supply Chain Cybersecurity", NameEN: "ENISA Good Practices for Supply Chain Cybersecurity", Short: "ENISA Supply Chain", Category: catGuidance},
		{ID: "enisa_threat_landscape_supply_chain", NameDE: "ENISA Threat Landscape Supply Chain", NameEN: "ENISA Threat Landscape for Supply Chain Attacks", Short: "ENISA Threat SC", Category: catGuidance},
		{ID: "enisa_ics_scada_dependencies", NameDE: "ENISA ICS/SCADA Abhaengigkeiten", NameEN: "ENISA ICS/SCADA Communication Dependencies", Short: "ENISA ICS/SCADA", Category: catGuidance},
		{ID: "cisa_secure_by_design", NameDE: "CISA Secure by Design", NameEN: "CISA Secure by Design", Short: "CISA SbD", Category: catGuidance},
		{ID: "enisa_cybersecurity_state_2024", NameDE: "ENISA State of Cybersecurity 2024", NameEN: "ENISA State of Cybersecurity in the Union 2024", Short: "ENISA 2024", Category: catGuidance},
	}
	return catalog
}
// FormatLegalContextForPrompt formats the legal context for inclusion in an LLM prompt.
//
// It returns "" for a nil context or one without results. Otherwise the output
// is a German markdown section headed "Relevante Rechtsgrundlagen", listing each
// result as "N. **<short>** (<code>)". The page list is appended when present.
// Below each entry is a quoted excerpt of the passage, truncated to 300
// characters.
func (c *LegalRAGClient) FormatLegalContextForPrompt(lc *LegalContext) string {
	if lc == nil || len(lc.Results) == 0 {
		return ""
	}

	var sb strings.Builder
	sb.WriteString("\n\n**Relevante Rechtsgrundlagen:**\n\n")
	for i, result := range lc.Results {
		// Write formatted output directly instead of WriteString(fmt.Sprintf(...)),
		// which builds an intermediate string per call (staticcheck QF1012).
		fmt.Fprintf(&sb, "%d. **%s** (%s)", i+1, result.RegulationShort, result.RegulationCode)
		if len(result.Pages) > 0 {
			fmt.Fprintf(&sb, " - Seiten %v", result.Pages)
		}
		sb.WriteString("\n")
		fmt.Fprintf(&sb, " > %s\n\n", truncateText(result.Text, 300))
	}
	return sb.String()
}
// Helper functions
// getString returns m[key] when it holds a string, and "" otherwise
// (missing key, nil map, nil value, or a non-string value).
func getString(m map[string]interface{}, key string) string {
	s, _ := m[key].(string)
	return s
}
// getIntSlice reads m[key] as a JSON-decoded array ([]interface{}) and returns
// its float64 elements converted (truncated) to int, dropping any element of
// another type. It returns nil when the key is absent or the value is not such
// an array; an empty array yields an empty, non-nil slice.
func getIntSlice(m map[string]interface{}, key string) []int {
	raw, isArray := m[key].([]interface{})
	if !isArray {
		return nil
	}
	ints := make([]int, 0, len(raw))
	for idx := range raw {
		num, isNum := raw[idx].(float64)
		if !isNum {
			continue
		}
		ints = append(ints, int(num))
	}
	return ints
}
// contains reports whether item occurs in slice (linear scan, exact match).
func contains(slice []string, item string) bool {
	for i := 0; i < len(slice); i++ {
		if slice[i] == item {
			return true
		}
	}
	return false
}
// truncateText shortens text to at most maxLen characters (Unicode code
// points), appending "..." when anything was cut. Counting runes instead of
// bytes guarantees the result never ends in a split multi-byte UTF-8 sequence,
// which matters for German legal text with umlauts and "§". A negative maxLen
// is treated as 0.
func truncateText(text string, maxLen int) string {
	if maxLen < 0 {
		maxLen = 0
	}
	// Fast path: at most maxLen bytes means at most maxLen runes.
	if len(text) <= maxLen {
		return text
	}
	count := 0
	for i := range text { // i is the byte offset of each rune
		if count == maxLen {
			return text[:i] + "..."
		}
		count++
	}
	// Fewer than maxLen+1 runes despite the byte length: nothing to cut.
	return text
}