feat: Phase 2 — RAG integration in Requirements + DSFA Draft
All checks were successful
CI / go-lint (push) Has been skipped
CI / python-lint (push) Has been skipped
CI / nodejs-lint (push) Has been skipped
CI / test-go-ai-compliance (push) Successful in 35s
CI / test-python-backend-compliance (push) Successful in 26s
CI / test-python-document-crawler (push) Successful in 22s
CI / test-python-dsms-gateway (push) Successful in 19s
All checks were successful
CI / go-lint (push) Has been skipped
CI / python-lint (push) Has been skipped
CI / nodejs-lint (push) Has been skipped
CI / test-go-ai-compliance (push) Successful in 35s
CI / test-python-backend-compliance (push) Successful in 26s
CI / test-python-document-crawler (push) Successful in 22s
CI / test-python-dsms-gateway (push) Successful in 19s
Add legal context enrichment from Qdrant vector corpus to the two highest-priority modules (Requirements AI assistant and DSFA drafting engine). Go SDK: - Add SearchCollection() with collection override + whitelist validation - Refactor Search() to delegate to shared searchInternal() Python backend: - New ComplianceRAGClient proxying POST /sdk/v1/rag/search (error-tolerant) - AI assistant: enrich interpret_requirement() and suggest_controls() with RAG - Requirements API: add ?include_legal_context=true query parameter Admin (Next.js): - Extract shared queryRAG() utility from chat route - Inject RAG legal context into v1 and v2 draft pipelines Tests for all three layers (Go, Python, TypeScript shared utility). Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -7,8 +7,8 @@
|
|||||||
*/
|
*/
|
||||||
|
|
||||||
import { NextRequest, NextResponse } from 'next/server'
|
import { NextRequest, NextResponse } from 'next/server'
|
||||||
|
import { queryRAG } from '@/lib/sdk/drafting-engine/rag-query'
|
||||||
|
|
||||||
const KLAUSUR_SERVICE_URL = process.env.KLAUSUR_SERVICE_URL || 'http://klausur-service:8086'
|
|
||||||
const OLLAMA_URL = process.env.OLLAMA_URL || 'http://host.docker.internal:11434'
|
const OLLAMA_URL = process.env.OLLAMA_URL || 'http://host.docker.internal:11434'
|
||||||
const LLM_MODEL = process.env.COMPLIANCE_LLM_MODEL || 'qwen2.5vl:32b'
|
const LLM_MODEL = process.env.COMPLIANCE_LLM_MODEL || 'qwen2.5vl:32b'
|
||||||
|
|
||||||
@@ -30,34 +30,6 @@ Konsistenz zwischen Dokumenten sicherzustellen.
|
|||||||
## Kompetenzbereich
|
## Kompetenzbereich
|
||||||
DSGVO, BDSG, AI Act, TTDSG, DSK-Kurzpapiere, SDM V3.0, BSI-Grundschutz, ISO 27001/27701, EDPB Guidelines, WP248`
|
DSGVO, BDSG, AI Act, TTDSG, DSK-Kurzpapiere, SDM V3.0, BSI-Grundschutz, ISO 27001/27701, EDPB Guidelines, WP248`
|
||||||
|
|
||||||
/**
|
|
||||||
* Query the RAG corpus for relevant documents
|
|
||||||
*/
|
|
||||||
async function queryRAG(query: string): Promise<string> {
|
|
||||||
try {
|
|
||||||
const url = `${KLAUSUR_SERVICE_URL}/api/v1/dsfa-rag/search?query=${encodeURIComponent(query)}&top_k=3`
|
|
||||||
const res = await fetch(url, {
|
|
||||||
headers: { 'Content-Type': 'application/json' },
|
|
||||||
signal: AbortSignal.timeout(10000),
|
|
||||||
})
|
|
||||||
|
|
||||||
if (!res.ok) return ''
|
|
||||||
|
|
||||||
const data = await res.json()
|
|
||||||
if (data.results?.length > 0) {
|
|
||||||
return data.results
|
|
||||||
.map(
|
|
||||||
(r: { source_name?: string; source_code?: string; content?: string }, i: number) =>
|
|
||||||
`[Quelle ${i + 1}: ${r.source_name || r.source_code || 'Unbekannt'}]\n${r.content || ''}`
|
|
||||||
)
|
|
||||||
.join('\n\n---\n\n')
|
|
||||||
}
|
|
||||||
return ''
|
|
||||||
} catch {
|
|
||||||
return ''
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
export async function POST(request: NextRequest) {
|
export async function POST(request: NextRequest) {
|
||||||
try {
|
try {
|
||||||
const body = await request.json()
|
const body = await request.json()
|
||||||
|
|||||||
@@ -31,6 +31,7 @@ import { sanitizeAllowedFacts, validateNoRemainingPII, SanitizationError } from
|
|||||||
import { terminologyToPromptString, styleContractToPromptString } from '@/lib/sdk/drafting-engine/terminology'
|
import { terminologyToPromptString, styleContractToPromptString } from '@/lib/sdk/drafting-engine/terminology'
|
||||||
import { executeRepairLoop, type ProseBlockOutput, type RepairAudit } from '@/lib/sdk/drafting-engine/prose-validator'
|
import { executeRepairLoop, type ProseBlockOutput, type RepairAudit } from '@/lib/sdk/drafting-engine/prose-validator'
|
||||||
import { ProseCacheManager, computeChecksumSync, type CacheKeyParams } from '@/lib/sdk/drafting-engine/cache'
|
import { ProseCacheManager, computeChecksumSync, type CacheKeyParams } from '@/lib/sdk/drafting-engine/cache'
|
||||||
|
import { queryRAG } from '@/lib/sdk/drafting-engine/rag-query'
|
||||||
|
|
||||||
// ============================================================================
|
// ============================================================================
|
||||||
// Shared State
|
// Shared State
|
||||||
@@ -103,9 +104,20 @@ async function handleV1Draft(body: Record<string, unknown>): Promise<NextRespons
|
|||||||
}, { status: 403 })
|
}, { status: 403 })
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RAG: Fetch relevant legal context
|
||||||
|
const ragQuery = documentType === 'dsfa'
|
||||||
|
? 'Datenschutz-Folgenabschaetzung Art. 35 DSGVO Risikobewertung'
|
||||||
|
: `${documentType} DSGVO Compliance Anforderungen`
|
||||||
|
const ragContext = await queryRAG(ragQuery)
|
||||||
|
|
||||||
|
let v1SystemPrompt = V1_SYSTEM_PROMPT
|
||||||
|
if (ragContext) {
|
||||||
|
v1SystemPrompt += `\n\n## Relevanter Rechtskontext\n${ragContext}`
|
||||||
|
}
|
||||||
|
|
||||||
const draftPrompt = buildPromptForDocumentType(documentType, draftContext, instructions)
|
const draftPrompt = buildPromptForDocumentType(documentType, draftContext, instructions)
|
||||||
const messages = [
|
const messages = [
|
||||||
{ role: 'system', content: V1_SYSTEM_PROMPT },
|
{ role: 'system', content: v1SystemPrompt },
|
||||||
...(existingDraft ? [{
|
...(existingDraft ? [{
|
||||||
role: 'assistant',
|
role: 'assistant',
|
||||||
content: `Bisheriger Entwurf:\n${JSON.stringify(existingDraft.sections, null, 2)}`,
|
content: `Bisheriger Entwurf:\n${JSON.stringify(existingDraft.sections, null, 2)}`,
|
||||||
@@ -368,6 +380,12 @@ async function handleV2Draft(body: Record<string, unknown>): Promise<NextRespons
|
|||||||
// Compute prompt hash for audit
|
// Compute prompt hash for audit
|
||||||
const promptHash = computeChecksumSync({ factsString, tagsString, termsString, styleString, disallowedString })
|
const promptHash = computeChecksumSync({ factsString, tagsString, termsString, styleString, disallowedString })
|
||||||
|
|
||||||
|
// Step 5b: RAG Legal Context
|
||||||
|
const v2RagQuery = documentType === 'dsfa'
|
||||||
|
? 'DSFA Art. 35 DSGVO Risikobewertung Massnahmen Datenschutz-Folgenabschaetzung'
|
||||||
|
: `${documentType} DSGVO Compliance`
|
||||||
|
const v2RagContext = await queryRAG(v2RagQuery)
|
||||||
|
|
||||||
// Step 6: Generate Prose Blocks (with cache + repair loop)
|
// Step 6: Generate Prose Blocks (with cache + repair loop)
|
||||||
const proseBlocks = DOCUMENT_PROSE_BLOCKS[documentType] || DOCUMENT_PROSE_BLOCKS.tom
|
const proseBlocks = DOCUMENT_PROSE_BLOCKS[documentType] || DOCUMENT_PROSE_BLOCKS.tom
|
||||||
const generatedBlocks: ProseBlockOutput[] = []
|
const generatedBlocks: ProseBlockOutput[] = []
|
||||||
@@ -399,12 +417,15 @@ async function handleV2Draft(body: Record<string, unknown>): Promise<NextRespons
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Build prompts
|
// Build prompts
|
||||||
const systemPrompt = buildV2SystemPrompt(
|
let systemPrompt = buildV2SystemPrompt(
|
||||||
factsString, tagsString, termsString, styleString, disallowedString,
|
factsString, tagsString, termsString, styleString, disallowedString,
|
||||||
sanitizedFacts.companyName,
|
sanitizedFacts.companyName,
|
||||||
blockDef.blockId, blockDef.blockType, blockDef.sectionName,
|
blockDef.blockId, blockDef.blockType, blockDef.sectionName,
|
||||||
documentType, blockDef.targetWords
|
documentType, blockDef.targetWords
|
||||||
)
|
)
|
||||||
|
if (v2RagContext) {
|
||||||
|
systemPrompt += `\n\nRECHTSKONTEXT (als Referenz, nicht woertlich uebernehmen):\n${v2RagContext}`
|
||||||
|
}
|
||||||
const userPrompt = buildBlockSpecificPrompt(
|
const userPrompt = buildBlockSpecificPrompt(
|
||||||
blockDef.blockType, blockDef.sectionName, documentType
|
blockDef.blockType, blockDef.sectionName, documentType
|
||||||
) + (instructions ? `\n\nZusaetzliche Anweisungen: ${instructions}` : '')
|
) + (instructions ? `\n\nZusaetzliche Anweisungen: ${instructions}` : '')
|
||||||
|
|||||||
40
admin-compliance/lib/sdk/drafting-engine/rag-query.ts
Normal file
40
admin-compliance/lib/sdk/drafting-engine/rag-query.ts
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
/**
|
||||||
|
* Shared RAG query utility for the Drafting Engine.
|
||||||
|
*
|
||||||
|
* Queries the DSFA RAG corpus via klausur-service for relevant legal context.
|
||||||
|
* Used by both chat and draft routes.
|
||||||
|
*/
|
||||||
|
|
||||||
|
const KLAUSUR_SERVICE_URL = process.env.KLAUSUR_SERVICE_URL || 'http://klausur-service:8086'
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Query the RAG corpus for relevant legal documents.
|
||||||
|
*
|
||||||
|
* @param query - The search query (e.g. "DSFA Art. 35 DSGVO")
|
||||||
|
* @param topK - Number of results to return (default: 3)
|
||||||
|
* @returns Formatted string of legal context, or empty string on error
|
||||||
|
*/
|
||||||
|
export async function queryRAG(query: string, topK = 3): Promise<string> {
|
||||||
|
try {
|
||||||
|
const url = `${KLAUSUR_SERVICE_URL}/api/v1/dsfa-rag/search?query=${encodeURIComponent(query)}&top_k=${topK}`
|
||||||
|
const res = await fetch(url, {
|
||||||
|
headers: { 'Content-Type': 'application/json' },
|
||||||
|
signal: AbortSignal.timeout(10000),
|
||||||
|
})
|
||||||
|
|
||||||
|
if (!res.ok) return ''
|
||||||
|
|
||||||
|
const data = await res.json()
|
||||||
|
if (data.results?.length > 0) {
|
||||||
|
return data.results
|
||||||
|
.map(
|
||||||
|
(r: { source_name?: string; source_code?: string; content?: string }, i: number) =>
|
||||||
|
`[Quelle ${i + 1}: ${r.source_name || r.source_code || 'Unbekannt'}]\n${r.content || ''}`
|
||||||
|
)
|
||||||
|
.join('\n\n---\n\n')
|
||||||
|
}
|
||||||
|
return ''
|
||||||
|
} catch {
|
||||||
|
return ''
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -21,9 +21,20 @@ func NewRAGHandlers(corpusVersionStore *ucca.CorpusVersionStore) *RAGHandlers {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AllowedCollections is the whitelist of Qdrant collections that can be queried.
|
||||||
|
var AllowedCollections = map[string]bool{
|
||||||
|
"bp_compliance_ce": true,
|
||||||
|
"bp_compliance_recht": true,
|
||||||
|
"bp_compliance_gesetze": true,
|
||||||
|
"bp_compliance_datenschutz": true,
|
||||||
|
"bp_dsfa_corpus": true,
|
||||||
|
"bp_legal_templates": true,
|
||||||
|
}
|
||||||
|
|
||||||
// SearchRequest represents a RAG search request.
|
// SearchRequest represents a RAG search request.
|
||||||
type SearchRequest struct {
|
type SearchRequest struct {
|
||||||
Query string `json:"query" binding:"required"`
|
Query string `json:"query" binding:"required"`
|
||||||
|
Collection string `json:"collection,omitempty"`
|
||||||
Regulations []string `json:"regulations,omitempty"`
|
Regulations []string `json:"regulations,omitempty"`
|
||||||
TopK int `json:"top_k,omitempty"`
|
TopK int `json:"top_k,omitempty"`
|
||||||
}
|
}
|
||||||
@@ -41,7 +52,15 @@ func (h *RAGHandlers) Search(c *gin.Context) {
|
|||||||
req.TopK = 5
|
req.TopK = 5
|
||||||
}
|
}
|
||||||
|
|
||||||
results, err := h.ragClient.Search(c.Request.Context(), req.Query, req.Regulations, req.TopK)
|
// Validate collection if specified
|
||||||
|
if req.Collection != "" {
|
||||||
|
if !AllowedCollections[req.Collection] {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": "Unknown collection: " + req.Collection + ". Allowed: bp_compliance_ce, bp_compliance_recht, bp_compliance_gesetze, bp_compliance_datenschutz, bp_dsfa_corpus, bp_legal_templates"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
results, err := h.ragClient.SearchCollection(c.Request.Context(), req.Collection, req.Query, req.Regulations, req.TopK)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "RAG search failed: " + err.Error()})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "RAG search failed: " + err.Error()})
|
||||||
return
|
return
|
||||||
|
|||||||
109
ai-compliance-sdk/internal/api/handlers/rag_handlers_test.go
Normal file
109
ai-compliance-sdk/internal/api/handlers/rag_handlers_test.go
Normal file
@@ -0,0 +1,109 @@
|
|||||||
|
package handlers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestAllowedCollections(t *testing.T) {
|
||||||
|
allowed := []string{
|
||||||
|
"bp_compliance_ce",
|
||||||
|
"bp_compliance_recht",
|
||||||
|
"bp_compliance_gesetze",
|
||||||
|
"bp_compliance_datenschutz",
|
||||||
|
"bp_dsfa_corpus",
|
||||||
|
"bp_legal_templates",
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, c := range allowed {
|
||||||
|
if !AllowedCollections[c] {
|
||||||
|
t.Errorf("Expected %s to be in AllowedCollections", c)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
disallowed := []string{
|
||||||
|
"bp_unknown",
|
||||||
|
"",
|
||||||
|
"some_random_collection",
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, c := range disallowed {
|
||||||
|
if AllowedCollections[c] {
|
||||||
|
t.Errorf("Expected %s to NOT be in AllowedCollections", c)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSearch_InvalidCollection_Returns400(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
handler := &RAGHandlers{}
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
|
||||||
|
body := SearchRequest{
|
||||||
|
Query: "test query",
|
||||||
|
Collection: "bp_evil_collection",
|
||||||
|
TopK: 5,
|
||||||
|
}
|
||||||
|
bodyBytes, _ := json.Marshal(body)
|
||||||
|
c.Request, _ = http.NewRequest("POST", "/sdk/v1/rag/search", bytes.NewReader(bodyBytes))
|
||||||
|
c.Request.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
handler.Search(c)
|
||||||
|
|
||||||
|
if w.Code != http.StatusBadRequest {
|
||||||
|
t.Errorf("Expected 400, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
var resp map[string]interface{}
|
||||||
|
json.Unmarshal(w.Body.Bytes(), &resp)
|
||||||
|
|
||||||
|
errMsg, ok := resp["error"].(string)
|
||||||
|
if !ok || errMsg == "" {
|
||||||
|
t.Error("Expected error message in response")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSearch_WithCollectionParam_BindsCorrectly(t *testing.T) {
|
||||||
|
// Test that the SearchRequest struct correctly binds the collection field
|
||||||
|
body := `{"query":"DSGVO Art. 35","collection":"bp_compliance_recht","top_k":3}`
|
||||||
|
var req SearchRequest
|
||||||
|
err := json.Unmarshal([]byte(body), &req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to unmarshal: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.Query != "DSGVO Art. 35" {
|
||||||
|
t.Errorf("Expected query 'DSGVO Art. 35', got '%s'", req.Query)
|
||||||
|
}
|
||||||
|
if req.Collection != "bp_compliance_recht" {
|
||||||
|
t.Errorf("Expected collection 'bp_compliance_recht', got '%s'", req.Collection)
|
||||||
|
}
|
||||||
|
if req.TopK != 3 {
|
||||||
|
t.Errorf("Expected top_k 3, got %d", req.TopK)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSearch_EmptyCollection_IsAllowed(t *testing.T) {
|
||||||
|
// Empty collection should be allowed (falls back to default in the handler)
|
||||||
|
body := `{"query":"test"}`
|
||||||
|
var req SearchRequest
|
||||||
|
err := json.Unmarshal([]byte(body), &req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to unmarshal: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.Collection != "" {
|
||||||
|
t.Errorf("Expected empty collection, got '%s'", req.Collection)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Empty string is not in AllowedCollections, but the handler
|
||||||
|
// should skip validation for empty collection
|
||||||
|
}
|
||||||
@@ -173,8 +173,22 @@ func (c *LegalRAGClient) generateEmbedding(ctx context.Context, text string) ([]
|
|||||||
return embResp.Embedding, nil
|
return embResp.Embedding, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SearchCollection queries a specific Qdrant collection for relevant passages.
|
||||||
|
// If collection is empty, it falls back to the default collection (bp_compliance_ce).
|
||||||
|
func (c *LegalRAGClient) SearchCollection(ctx context.Context, collection string, query string, regulationIDs []string, topK int) ([]LegalSearchResult, error) {
|
||||||
|
if collection == "" {
|
||||||
|
collection = c.collection
|
||||||
|
}
|
||||||
|
return c.searchInternal(ctx, collection, query, regulationIDs, topK)
|
||||||
|
}
|
||||||
|
|
||||||
// Search queries the compliance CE corpus for relevant passages.
|
// Search queries the compliance CE corpus for relevant passages.
|
||||||
func (c *LegalRAGClient) Search(ctx context.Context, query string, regulationIDs []string, topK int) ([]LegalSearchResult, error) {
|
func (c *LegalRAGClient) Search(ctx context.Context, query string, regulationIDs []string, topK int) ([]LegalSearchResult, error) {
|
||||||
|
return c.searchInternal(ctx, c.collection, query, regulationIDs, topK)
|
||||||
|
}
|
||||||
|
|
||||||
|
// searchInternal performs the actual search against a given collection.
|
||||||
|
func (c *LegalRAGClient) searchInternal(ctx context.Context, collection string, query string, regulationIDs []string, topK int) ([]LegalSearchResult, error) {
|
||||||
// Generate query embedding via Ollama bge-m3
|
// Generate query embedding via Ollama bge-m3
|
||||||
embedding, err := c.generateEmbedding(ctx, query)
|
embedding, err := c.generateEmbedding(ctx, query)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -206,7 +220,7 @@ func (c *LegalRAGClient) Search(ctx context.Context, query string, regulationIDs
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Call Qdrant
|
// Call Qdrant
|
||||||
url := fmt.Sprintf("http://%s:%s/collections/%s/points/search", c.qdrantHost, c.qdrantPort, c.collection)
|
url := fmt.Sprintf("http://%s:%s/collections/%s/points/search", c.qdrantHost, c.qdrantPort, collection)
|
||||||
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(jsonBody))
|
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(jsonBody))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to create search request: %w", err)
|
return nil, fmt.Errorf("failed to create search request: %w", err)
|
||||||
|
|||||||
157
ai-compliance-sdk/internal/ucca/legal_rag_test.go
Normal file
157
ai-compliance-sdk/internal/ucca/legal_rag_test.go
Normal file
@@ -0,0 +1,157 @@
|
|||||||
|
package ucca
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestSearchCollection_UsesCorrectCollection(t *testing.T) {
|
||||||
|
// Track which collection was requested
|
||||||
|
var requestedURL string
|
||||||
|
|
||||||
|
// Mock Ollama (embedding)
|
||||||
|
ollamaMock := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
json.NewEncoder(w).Encode(ollamaEmbeddingResponse{
|
||||||
|
Embedding: make([]float64, 1024),
|
||||||
|
})
|
||||||
|
}))
|
||||||
|
defer ollamaMock.Close()
|
||||||
|
|
||||||
|
// Mock Qdrant
|
||||||
|
qdrantMock := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
requestedURL = r.URL.Path
|
||||||
|
json.NewEncoder(w).Encode(qdrantSearchResponse{
|
||||||
|
Result: []qdrantSearchHit{},
|
||||||
|
})
|
||||||
|
}))
|
||||||
|
defer qdrantMock.Close()
|
||||||
|
|
||||||
|
// Parse qdrant mock host/port
|
||||||
|
qdrantAddr := strings.TrimPrefix(qdrantMock.URL, "http://")
|
||||||
|
parts := strings.Split(qdrantAddr, ":")
|
||||||
|
|
||||||
|
client := &LegalRAGClient{
|
||||||
|
qdrantHost: parts[0],
|
||||||
|
qdrantPort: parts[1],
|
||||||
|
ollamaURL: ollamaMock.URL,
|
||||||
|
embeddingModel: "bge-m3",
|
||||||
|
collection: "bp_compliance_ce",
|
||||||
|
httpClient: http.DefaultClient,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test with explicit collection
|
||||||
|
_, err := client.SearchCollection(context.Background(), "bp_compliance_recht", "test query", nil, 3)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("SearchCollection failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.Contains(requestedURL, "/collections/bp_compliance_recht/") {
|
||||||
|
t.Errorf("Expected collection bp_compliance_recht in URL, got: %s", requestedURL)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSearchCollection_FallbackDefault(t *testing.T) {
|
||||||
|
var requestedURL string
|
||||||
|
|
||||||
|
ollamaMock := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
json.NewEncoder(w).Encode(ollamaEmbeddingResponse{
|
||||||
|
Embedding: make([]float64, 1024),
|
||||||
|
})
|
||||||
|
}))
|
||||||
|
defer ollamaMock.Close()
|
||||||
|
|
||||||
|
qdrantMock := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
requestedURL = r.URL.Path
|
||||||
|
json.NewEncoder(w).Encode(qdrantSearchResponse{
|
||||||
|
Result: []qdrantSearchHit{},
|
||||||
|
})
|
||||||
|
}))
|
||||||
|
defer qdrantMock.Close()
|
||||||
|
|
||||||
|
qdrantAddr := strings.TrimPrefix(qdrantMock.URL, "http://")
|
||||||
|
parts := strings.Split(qdrantAddr, ":")
|
||||||
|
|
||||||
|
client := &LegalRAGClient{
|
||||||
|
qdrantHost: parts[0],
|
||||||
|
qdrantPort: parts[1],
|
||||||
|
ollamaURL: ollamaMock.URL,
|
||||||
|
embeddingModel: "bge-m3",
|
||||||
|
collection: "bp_compliance_ce",
|
||||||
|
httpClient: http.DefaultClient,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test with empty collection (should fall back to default)
|
||||||
|
_, err := client.SearchCollection(context.Background(), "", "test query", nil, 3)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("SearchCollection failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.Contains(requestedURL, "/collections/bp_compliance_ce/") {
|
||||||
|
t.Errorf("Expected default collection bp_compliance_ce in URL, got: %s", requestedURL)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSearch_StillWorks(t *testing.T) {
|
||||||
|
var requestedURL string
|
||||||
|
|
||||||
|
ollamaMock := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
json.NewEncoder(w).Encode(ollamaEmbeddingResponse{
|
||||||
|
Embedding: make([]float64, 1024),
|
||||||
|
})
|
||||||
|
}))
|
||||||
|
defer ollamaMock.Close()
|
||||||
|
|
||||||
|
qdrantMock := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
requestedURL = r.URL.Path
|
||||||
|
json.NewEncoder(w).Encode(qdrantSearchResponse{
|
||||||
|
Result: []qdrantSearchHit{
|
||||||
|
{
|
||||||
|
ID: "1",
|
||||||
|
Score: 0.95,
|
||||||
|
Payload: map[string]interface{}{
|
||||||
|
"chunk_text": "Test content",
|
||||||
|
"regulation_id": "eu_2016_679",
|
||||||
|
"regulation_name_de": "DSGVO",
|
||||||
|
"regulation_short": "DSGVO",
|
||||||
|
"category": "regulation",
|
||||||
|
"source": "https://example.com",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}))
|
||||||
|
defer qdrantMock.Close()
|
||||||
|
|
||||||
|
qdrantAddr := strings.TrimPrefix(qdrantMock.URL, "http://")
|
||||||
|
parts := strings.Split(qdrantAddr, ":")
|
||||||
|
|
||||||
|
client := &LegalRAGClient{
|
||||||
|
qdrantHost: parts[0],
|
||||||
|
qdrantPort: parts[1],
|
||||||
|
ollamaURL: ollamaMock.URL,
|
||||||
|
embeddingModel: "bge-m3",
|
||||||
|
collection: "bp_compliance_ce",
|
||||||
|
httpClient: http.DefaultClient,
|
||||||
|
}
|
||||||
|
|
||||||
|
results, err := client.Search(context.Background(), "DSGVO Art. 35", nil, 5)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Search failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(results) != 1 {
|
||||||
|
t.Fatalf("Expected 1 result, got %d", len(results))
|
||||||
|
}
|
||||||
|
|
||||||
|
if results[0].RegulationCode != "eu_2016_679" {
|
||||||
|
t.Errorf("Expected regulation_code eu_2016_679, got %s", results[0].RegulationCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.Contains(requestedURL, "/collections/bp_compliance_ce/") {
|
||||||
|
t.Errorf("Expected default collection in URL, got: %s", requestedURL)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -177,8 +177,12 @@ async def get_regulation_requirements(
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/requirements/{requirement_id}")
|
@router.get("/requirements/{requirement_id}")
|
||||||
async def get_requirement(requirement_id: str, db: Session = Depends(get_db)):
|
async def get_requirement(
|
||||||
"""Get a specific requirement by ID."""
|
requirement_id: str,
|
||||||
|
include_legal_context: bool = Query(False, description="Include RAG legal context"),
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
):
|
||||||
|
"""Get a specific requirement by ID, optionally with RAG legal context."""
|
||||||
from ..db.models import RequirementDB, RegulationDB
|
from ..db.models import RequirementDB, RegulationDB
|
||||||
|
|
||||||
requirement = db.query(RequirementDB).filter(RequirementDB.id == requirement_id).first()
|
requirement = db.query(RequirementDB).filter(RequirementDB.id == requirement_id).first()
|
||||||
@@ -187,7 +191,7 @@ async def get_requirement(requirement_id: str, db: Session = Depends(get_db)):
|
|||||||
|
|
||||||
regulation = db.query(RegulationDB).filter(RegulationDB.id == requirement.regulation_id).first()
|
regulation = db.query(RegulationDB).filter(RegulationDB.id == requirement.regulation_id).first()
|
||||||
|
|
||||||
return {
|
result = {
|
||||||
"id": requirement.id,
|
"id": requirement.id,
|
||||||
"regulation_id": requirement.regulation_id,
|
"regulation_id": requirement.regulation_id,
|
||||||
"regulation_code": regulation.code if regulation else None,
|
"regulation_code": regulation.code if regulation else None,
|
||||||
@@ -214,6 +218,33 @@ async def get_requirement(requirement_id: str, db: Session = Depends(get_db)):
|
|||||||
"source_section": requirement.source_section,
|
"source_section": requirement.source_section,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if include_legal_context:
|
||||||
|
try:
|
||||||
|
from ..services.rag_client import get_rag_client
|
||||||
|
from ..services.ai_compliance_assistant import AIComplianceAssistant
|
||||||
|
|
||||||
|
rag = get_rag_client()
|
||||||
|
assistant = AIComplianceAssistant()
|
||||||
|
query = f"{requirement.title} {requirement.article or ''}"
|
||||||
|
collection = assistant._collection_for_regulation(regulation.code if regulation else "")
|
||||||
|
rag_results = await rag.search(query, collection=collection, top_k=3)
|
||||||
|
result["legal_context"] = [
|
||||||
|
{
|
||||||
|
"text": r.text,
|
||||||
|
"regulation_code": r.regulation_code,
|
||||||
|
"regulation_short": r.regulation_short,
|
||||||
|
"article": r.article,
|
||||||
|
"score": r.score,
|
||||||
|
"source_url": r.source_url,
|
||||||
|
}
|
||||||
|
for r in rag_results
|
||||||
|
]
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("Failed to fetch legal context for %s: %s", requirement_id, e)
|
||||||
|
result["legal_context"] = []
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
@router.get("/requirements", response_model=PaginatedRequirementResponse)
|
@router.get("/requirements", response_model=PaginatedRequirementResponse)
|
||||||
async def list_requirements_paginated(
|
async def list_requirements_paginated(
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ from typing import List, Optional, Dict, Any
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
|
||||||
from .llm_provider import LLMProvider, get_shared_provider, LLMResponse
|
from .llm_provider import LLMProvider, get_shared_provider, LLMResponse
|
||||||
|
from .rag_client import get_rag_client
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -199,9 +200,23 @@ Bewerte die Abdeckung und identifiziere Lücken im JSON-Format:
|
|||||||
|
|
||||||
Gib NUR das JSON zurück."""
|
Gib NUR das JSON zurück."""
|
||||||
|
|
||||||
|
# EU regulation codes → bp_compliance_ce, DE codes → bp_compliance_recht
|
||||||
|
_EU_CODES = {"DSGVO", "GDPR", "AIACT", "AI_ACT", "NIS2", "CRA"}
|
||||||
|
_DE_CODES = {"BDSG", "TDDDG", "DDG", "URHG", "TMG", "TKG"}
|
||||||
|
|
||||||
def __init__(self, llm_provider: Optional[LLMProvider] = None):
|
def __init__(self, llm_provider: Optional[LLMProvider] = None):
|
||||||
"""Initialize the assistant with an LLM provider."""
|
"""Initialize the assistant with an LLM provider."""
|
||||||
self.llm = llm_provider or get_shared_provider()
|
self.llm = llm_provider or get_shared_provider()
|
||||||
|
self.rag = get_rag_client()
|
||||||
|
|
||||||
|
def _collection_for_regulation(self, regulation_code: str) -> str:
|
||||||
|
"""Determine the RAG collection based on regulation code."""
|
||||||
|
code_upper = regulation_code.upper()
|
||||||
|
if any(c in code_upper for c in self._EU_CODES):
|
||||||
|
return "bp_compliance_ce"
|
||||||
|
elif any(c in code_upper for c in self._DE_CODES):
|
||||||
|
return "bp_compliance_recht"
|
||||||
|
return "bp_compliance_ce"
|
||||||
|
|
||||||
async def interpret_requirement(
|
async def interpret_requirement(
|
||||||
self,
|
self,
|
||||||
@@ -226,6 +241,17 @@ Gib NUR das JSON zurück."""
|
|||||||
requirement_text=requirement_text or "Kein Text verfügbar"
|
requirement_text=requirement_text or "Kein Text verfügbar"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Enrich prompt with RAG legal context
|
||||||
|
try:
|
||||||
|
rag_query = f"{regulation_name} {article} {title}"
|
||||||
|
collection = self._collection_for_regulation(regulation_code)
|
||||||
|
rag_results = await self.rag.search(rag_query, collection=collection, top_k=3)
|
||||||
|
rag_context = self.rag.format_for_prompt(rag_results)
|
||||||
|
if rag_context:
|
||||||
|
prompt += f"\n\n{rag_context}"
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("RAG enrichment failed for interpret_requirement: %s", e)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = await self.llm.complete(
|
response = await self.llm.complete(
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
@@ -282,6 +308,16 @@ Gib NUR das JSON zurück."""
|
|||||||
affected_modules=", ".join(affected_modules) if affected_modules else "Alle Module"
|
affected_modules=", ".join(affected_modules) if affected_modules else "Alle Module"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Enrich prompt with RAG legal context
|
||||||
|
try:
|
||||||
|
rag_query = f"{regulation_name} {requirement_title} Massnahmen Controls"
|
||||||
|
rag_results = await self.rag.search(rag_query, top_k=3)
|
||||||
|
rag_context = self.rag.format_for_prompt(rag_results)
|
||||||
|
if rag_context:
|
||||||
|
prompt += f"\n\n{rag_context}"
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("RAG enrichment failed for suggest_controls: %s", e)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = await self.llm.complete(
|
response = await self.llm.complete(
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
|
|||||||
129
backend-compliance/compliance/services/rag_client.py
Normal file
129
backend-compliance/compliance/services/rag_client.py
Normal file
@@ -0,0 +1,129 @@
|
|||||||
|
"""
|
||||||
|
Compliance RAG Client — Proxy to Go SDK RAG Search.
|
||||||
|
|
||||||
|
Lightweight HTTP client that queries the Go AI Compliance SDK's
|
||||||
|
POST /sdk/v1/rag/search endpoint. This avoids needing embedding
|
||||||
|
models or direct Qdrant access in Python.
|
||||||
|
|
||||||
|
Error-tolerant: RAG failures never break the calling function.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
SDK_URL = os.getenv("SDK_URL", "http://ai-compliance-sdk:8090")
|
||||||
|
RAG_SEARCH_TIMEOUT = 15.0 # seconds
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class RAGSearchResult:
|
||||||
|
"""A single search result from the compliance corpus."""
|
||||||
|
text: str
|
||||||
|
regulation_code: str
|
||||||
|
regulation_name: str
|
||||||
|
regulation_short: str
|
||||||
|
category: str
|
||||||
|
article: str
|
||||||
|
paragraph: str
|
||||||
|
source_url: str
|
||||||
|
score: float
|
||||||
|
|
||||||
|
|
||||||
|
class ComplianceRAGClient:
|
||||||
|
"""
|
||||||
|
RAG client that proxies search requests to the Go SDK.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
client = get_rag_client()
|
||||||
|
results = await client.search("DSGVO Art. 35", collection="bp_compliance_recht")
|
||||||
|
context_str = client.format_for_prompt(results)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, base_url: str = SDK_URL):
|
||||||
|
self._search_url = f"{base_url}/sdk/v1/rag/search"
|
||||||
|
|
||||||
|
async def search(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
collection: str = "bp_compliance_ce",
|
||||||
|
regulations: Optional[List[str]] = None,
|
||||||
|
top_k: int = 5,
|
||||||
|
) -> List[RAGSearchResult]:
|
||||||
|
"""
|
||||||
|
Search the RAG corpus via Go SDK.
|
||||||
|
|
||||||
|
Returns an empty list on any error (never raises).
|
||||||
|
"""
|
||||||
|
payload = {
|
||||||
|
"query": query,
|
||||||
|
"collection": collection,
|
||||||
|
"top_k": top_k,
|
||||||
|
}
|
||||||
|
if regulations:
|
||||||
|
payload["regulations"] = regulations
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with httpx.AsyncClient(timeout=RAG_SEARCH_TIMEOUT) as client:
|
||||||
|
resp = await client.post(self._search_url, json=payload)
|
||||||
|
|
||||||
|
if resp.status_code != 200:
|
||||||
|
logger.warning(
|
||||||
|
"RAG search returned %d: %s", resp.status_code, resp.text[:200]
|
||||||
|
)
|
||||||
|
return []
|
||||||
|
|
||||||
|
data = resp.json()
|
||||||
|
results = []
|
||||||
|
for r in data.get("results", []):
|
||||||
|
results.append(RAGSearchResult(
|
||||||
|
text=r.get("text", ""),
|
||||||
|
regulation_code=r.get("regulation_code", ""),
|
||||||
|
regulation_name=r.get("regulation_name", ""),
|
||||||
|
regulation_short=r.get("regulation_short", ""),
|
||||||
|
category=r.get("category", ""),
|
||||||
|
article=r.get("article", ""),
|
||||||
|
paragraph=r.get("paragraph", ""),
|
||||||
|
source_url=r.get("source_url", ""),
|
||||||
|
score=r.get("score", 0.0),
|
||||||
|
))
|
||||||
|
return results
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("RAG search failed: %s", e)
|
||||||
|
return []
|
||||||
|
|
||||||
|
def format_for_prompt(
|
||||||
|
self, results: List[RAGSearchResult], max_results: int = 5
|
||||||
|
) -> str:
|
||||||
|
"""Format search results as Markdown for inclusion in an LLM prompt."""
|
||||||
|
if not results:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
lines = ["## Relevanter Rechtskontext\n"]
|
||||||
|
for i, r in enumerate(results[:max_results]):
|
||||||
|
header = f"{i + 1}. **{r.regulation_short}** ({r.regulation_code})"
|
||||||
|
if r.article:
|
||||||
|
header += f" — {r.article}"
|
||||||
|
lines.append(header)
|
||||||
|
text = r.text[:400] + "..." if len(r.text) > 400 else r.text
|
||||||
|
lines.append(f" > {text}\n")
|
||||||
|
|
||||||
|
return "\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
|
# Singleton
|
||||||
|
_rag_client: Optional[ComplianceRAGClient] = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_rag_client() -> ComplianceRAGClient:
|
||||||
|
"""Get the shared RAG client instance."""
|
||||||
|
global _rag_client
|
||||||
|
if _rag_client is None:
|
||||||
|
_rag_client = ComplianceRAGClient()
|
||||||
|
return _rag_client
|
||||||
0
backend-compliance/tests/__init__.py
Normal file
0
backend-compliance/tests/__init__.py
Normal file
175
backend-compliance/tests/test_rag_client.py
Normal file
175
backend-compliance/tests/test_rag_client.py
Normal file
@@ -0,0 +1,175 @@
|
|||||||
|
"""Tests for ComplianceRAGClient."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import AsyncMock, patch, MagicMock
|
||||||
|
|
||||||
|
from compliance.services.rag_client import ComplianceRAGClient, RAGSearchResult
|
||||||
|
|
||||||
|
|
||||||
|
class TestComplianceRAGClient:
|
||||||
|
"""Tests for the RAG client proxy."""
|
||||||
|
|
||||||
|
def setup_method(self):
|
||||||
|
self.client = ComplianceRAGClient(base_url="http://test-sdk:8090")
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_search_success(self):
|
||||||
|
"""Successful search returns parsed results."""
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.status_code = 200
|
||||||
|
mock_response.json.return_value = {
|
||||||
|
"query": "DSGVO Art. 35",
|
||||||
|
"results": [
|
||||||
|
{
|
||||||
|
"text": "Art. 35 DSGVO regelt die Datenschutz-Folgenabschaetzung...",
|
||||||
|
"regulation_code": "eu_2016_679",
|
||||||
|
"regulation_name": "DSGVO",
|
||||||
|
"regulation_short": "DSGVO",
|
||||||
|
"category": "regulation",
|
||||||
|
"article": "Art. 35",
|
||||||
|
"paragraph": "",
|
||||||
|
"source_url": "https://example.com",
|
||||||
|
"score": 0.92,
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"count": 1,
|
||||||
|
}
|
||||||
|
|
||||||
|
with patch("httpx.AsyncClient") as MockClient:
|
||||||
|
mock_instance = AsyncMock()
|
||||||
|
mock_instance.post.return_value = mock_response
|
||||||
|
mock_instance.__aenter__ = AsyncMock(return_value=mock_instance)
|
||||||
|
mock_instance.__aexit__ = AsyncMock(return_value=False)
|
||||||
|
MockClient.return_value = mock_instance
|
||||||
|
|
||||||
|
results = await self.client.search("DSGVO Art. 35", collection="bp_compliance_ce")
|
||||||
|
|
||||||
|
assert len(results) == 1
|
||||||
|
assert results[0].regulation_code == "eu_2016_679"
|
||||||
|
assert results[0].score == 0.92
|
||||||
|
assert "Art. 35" in results[0].text
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_search_failure_returns_empty(self):
|
||||||
|
"""Network errors return empty list, never raise."""
|
||||||
|
with patch("httpx.AsyncClient") as MockClient:
|
||||||
|
mock_instance = AsyncMock()
|
||||||
|
mock_instance.post.side_effect = Exception("Connection refused")
|
||||||
|
mock_instance.__aenter__ = AsyncMock(return_value=mock_instance)
|
||||||
|
mock_instance.__aexit__ = AsyncMock(return_value=False)
|
||||||
|
MockClient.return_value = mock_instance
|
||||||
|
|
||||||
|
results = await self.client.search("test query")
|
||||||
|
|
||||||
|
assert results == []
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_search_http_error_returns_empty(self):
|
||||||
|
"""HTTP errors return empty list."""
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.status_code = 500
|
||||||
|
mock_response.text = "Internal Server Error"
|
||||||
|
|
||||||
|
with patch("httpx.AsyncClient") as MockClient:
|
||||||
|
mock_instance = AsyncMock()
|
||||||
|
mock_instance.post.return_value = mock_response
|
||||||
|
mock_instance.__aenter__ = AsyncMock(return_value=mock_instance)
|
||||||
|
mock_instance.__aexit__ = AsyncMock(return_value=False)
|
||||||
|
MockClient.return_value = mock_instance
|
||||||
|
|
||||||
|
results = await self.client.search("test query")
|
||||||
|
|
||||||
|
assert results == []
|
||||||
|
|
||||||
|
def test_format_for_prompt(self):
|
||||||
|
"""format_for_prompt produces Markdown output."""
|
||||||
|
results = [
|
||||||
|
RAGSearchResult(
|
||||||
|
text="Die Verarbeitung personenbezogener Daten...",
|
||||||
|
regulation_code="eu_2016_679",
|
||||||
|
regulation_name="DSGVO",
|
||||||
|
regulation_short="DSGVO",
|
||||||
|
category="regulation",
|
||||||
|
article="Art. 35",
|
||||||
|
paragraph="",
|
||||||
|
source_url="https://example.com",
|
||||||
|
score=0.9,
|
||||||
|
),
|
||||||
|
RAGSearchResult(
|
||||||
|
text="Risikobewertung fuer KI-Systeme...",
|
||||||
|
regulation_code="eu_2024_1689",
|
||||||
|
regulation_name="AI Act",
|
||||||
|
regulation_short="AI Act",
|
||||||
|
category="regulation",
|
||||||
|
article="",
|
||||||
|
paragraph="",
|
||||||
|
source_url="https://example.com",
|
||||||
|
score=0.85,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
output = self.client.format_for_prompt(results)
|
||||||
|
|
||||||
|
assert "## Relevanter Rechtskontext" in output
|
||||||
|
assert "**DSGVO**" in output
|
||||||
|
assert "Art. 35" in output
|
||||||
|
assert "**AI Act**" in output
|
||||||
|
|
||||||
|
def test_format_for_prompt_empty(self):
|
||||||
|
"""Empty results return empty string."""
|
||||||
|
assert self.client.format_for_prompt([]) == ""
|
||||||
|
|
||||||
|
def test_format_for_prompt_truncation(self):
|
||||||
|
"""Long text is truncated to 400 chars."""
|
||||||
|
results = [
|
||||||
|
RAGSearchResult(
|
||||||
|
text="A" * 500,
|
||||||
|
regulation_code="test",
|
||||||
|
regulation_name="Test",
|
||||||
|
regulation_short="Test",
|
||||||
|
category="test",
|
||||||
|
article="",
|
||||||
|
paragraph="",
|
||||||
|
source_url="",
|
||||||
|
score=0.5,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
output = self.client.format_for_prompt(results)
|
||||||
|
assert "..." in output
|
||||||
|
|
||||||
|
|
||||||
|
class TestCollectionMapping:
|
||||||
|
"""Tests for regulation → collection mapping in AIComplianceAssistant."""
|
||||||
|
|
||||||
|
def test_eu_regulations_map_to_ce(self):
|
||||||
|
from compliance.services.ai_compliance_assistant import AIComplianceAssistant
|
||||||
|
assistant = AIComplianceAssistant.__new__(AIComplianceAssistant)
|
||||||
|
|
||||||
|
assert assistant._collection_for_regulation("DSGVO") == "bp_compliance_ce"
|
||||||
|
assert assistant._collection_for_regulation("GDPR") == "bp_compliance_ce"
|
||||||
|
assert assistant._collection_for_regulation("AI_ACT") == "bp_compliance_ce"
|
||||||
|
assert assistant._collection_for_regulation("NIS2") == "bp_compliance_ce"
|
||||||
|
assert assistant._collection_for_regulation("CRA") == "bp_compliance_ce"
|
||||||
|
|
||||||
|
def test_de_regulations_map_to_recht(self):
|
||||||
|
from compliance.services.ai_compliance_assistant import AIComplianceAssistant
|
||||||
|
assistant = AIComplianceAssistant.__new__(AIComplianceAssistant)
|
||||||
|
|
||||||
|
assert assistant._collection_for_regulation("BDSG") == "bp_compliance_recht"
|
||||||
|
assert assistant._collection_for_regulation("TDDDG") == "bp_compliance_recht"
|
||||||
|
assert assistant._collection_for_regulation("TKG") == "bp_compliance_recht"
|
||||||
|
|
||||||
|
def test_unknown_regulation_defaults_to_ce(self):
|
||||||
|
from compliance.services.ai_compliance_assistant import AIComplianceAssistant
|
||||||
|
assistant = AIComplianceAssistant.__new__(AIComplianceAssistant)
|
||||||
|
|
||||||
|
assert assistant._collection_for_regulation("UNKNOWN") == "bp_compliance_ce"
|
||||||
|
assert assistant._collection_for_regulation("") == "bp_compliance_ce"
|
||||||
|
|
||||||
|
def test_case_insensitive(self):
|
||||||
|
from compliance.services.ai_compliance_assistant import AIComplianceAssistant
|
||||||
|
assistant = AIComplianceAssistant.__new__(AIComplianceAssistant)
|
||||||
|
|
||||||
|
assert assistant._collection_for_regulation("dsgvo") == "bp_compliance_ce"
|
||||||
|
assert assistant._collection_for_regulation("bdsg") == "bp_compliance_recht"
|
||||||
Reference in New Issue
Block a user