diff --git a/ai-compliance-sdk/internal/ucca/authority_router.go b/ai-compliance-sdk/internal/ucca/authority_router.go index edd7c0d6..98d006bf 100644 --- a/ai-compliance-sdk/internal/ucca/authority_router.go +++ b/ai-compliance-sdk/internal/ucca/authority_router.go @@ -55,6 +55,15 @@ func (c *LegalRAGClient) Retrieve(ctx context.Context, query string, topK int) ( collections = append(collections, c.kbSliceCollection) } + // Cross-regulation queries (>=2 explicitly named regulations) get a larger per-collection budget + // so each collection's multi-regulation search isn't truncated down to the keyword-dominant + // domain; the final per-regulation balancing then guarantees every named domain in the top-K. + regs := detectRegulations(query) + perColl := routerPerCollectionTopK + if len(regs) >= 2 { + perColl = routerPerCollectionTopK * len(regs) + } + // Warm the full-text indexes sequentially first so the concurrent fan-out below only READS the // shared textIndexEnsured map (the writes happen here, serialized) — closes the cold-start map // race deterministically. Best-effort: a missing collection just stays un-indexed (hybrid then @@ -71,19 +80,25 @@ func (c *LegalRAGClient) Retrieve(ctx context.Context, query string, topK int) ( wg.Add(1) go func(i int, coll string) { defer wg.Done() - if res, err := c.searchInternal(ctx, coll, query, nil, routerPerCollectionTopK); err == nil { + if res, err := c.searchInternal(ctx, coll, query, nil, perColl); err == nil { out[i] = res } }(i, coll) } wg.Wait() - merged := make([]LegalSearchResult, 0, len(collections)*routerPerCollectionTopK) + merged := make([]LegalSearchResult, 0, len(collections)*perColl) for _, r := range out { merged = append(merged, r...) } merged = dedupResults(merged) sort.SliceStable(merged, func(a, b int) bool { return merged[a].Score > merged[b].Score }) + + // Cross-regulation: guarantee every named domain is represented (0070-class fix) instead of + // letting a global score-sort starve the non-dominant domain. + if len(regs) >= 2 { + return balanceByRegulation(merged, regs, topK), nil + } if len(merged) > topK { merged = merged[:topK] } diff --git a/ai-compliance-sdk/internal/ucca/authority_router_e2e_test.go b/ai-compliance-sdk/internal/ucca/authority_router_e2e_test.go index 893378fc..e8351cfc 100644 --- a/ai-compliance-sdk/internal/ucca/authority_router_e2e_test.go +++ b/ai-compliance-sdk/internal/ucca/authority_router_e2e_test.go @@ -60,6 +60,36 @@ func hitDoc(results []LegalSearchResult, toks []string) bool { return false } +// TestMultiReg0070E2E (RUN_E2E=1) is the 0070 regression: a cross-regulation query (CRA + MaschVO) +// must return BOTH domains through the real Retrieve(), not just the keyword-dominant CRA. +func TestMultiReg0070E2E(t *testing.T) { + if os.Getenv("RUN_E2E") != "1" { + t.Skip("set RUN_E2E=1 + QDRANT_URL/OLLAMA_URL/QDRANT_API_KEY") + } + c := NewLegalRAGClient() + q := "Wie greifen CRA und Maschinenverordnung bei einer vernetzten Maschine ineinander?" + res, err := c.Retrieve(context.Background(), q, 8) + if err != nil { + t.Fatalf("retrieve: %v", err) + } + var hasCRA, hasMasch bool + var codes []string + for _, r := range res { + u := strings.ToUpper(r.RegulationCode) + codes = append(codes, u) + if strings.Contains(u, "CRA") { + hasCRA = true + } + if strings.Contains(u, "MASCH") || strings.Contains(u, "MACHIN") || u == "MVO" { + hasMasch = true + } + } + t.Logf("0070 top-8 codes: %v", codes) + if !hasCRA || !hasMasch { + t.Errorf("0070 must return BOTH domains via Retrieve(): CRA=%v MaschVO=%v", hasCRA, hasMasch) + } +} + // TestAuthorityRouterCB100 (RUN_E2E=1) drives the REAL Retrieve() over the ComplianceBench-100 against // the live collections: NEW (scope routing on → slice added for in-scope queries) vs OLD (routing off // → broad base only). It is the regression gate that the router actually delivers the proven slice diff --git a/ai-compliance-sdk/internal/ucca/authority_router_test.go b/ai-compliance-sdk/internal/ucca/authority_router_test.go index c2c25a55..a08d7fd3 100644 --- a/ai-compliance-sdk/internal/ucca/authority_router_test.go +++ b/ai-compliance-sdk/internal/ucca/authority_router_test.go @@ -49,6 +49,38 @@ func TestRouterSliceSelection(t *testing.T) { } } +func TestBalanceByRegulation(t *testing.T) { + regs := []detectedRegulation{ + {Canonical: "CRA", CodeValues: []string{"CRA"}}, + {Canonical: "MaschVO", CodeValues: []string{"MASCHVO", "MVO", "MACHINERY"}}, + } + // CRA dominates by score; without balancing the top-4 would be all CRA + NIST. + pool := []LegalSearchResult{ + {RegulationCode: "CRA", Score: 0.99}, + {RegulationCode: "CRA", Score: 0.98}, + {RegulationCode: "CRA", Score: 0.97}, + {RegulationCode: "NIST", Score: 0.96}, + {RegulationCode: "MACHINERY", Score: 0.70}, + {RegulationCode: "MVO", Score: 0.65}, + } + out := balanceByRegulation(pool, regs, 4) + var hasCRA, hasMasch bool + for _, r := range out { + switch r.RegulationCode { + case "CRA": + hasCRA = true + case "MACHINERY", "MVO": + hasMasch = true + } + } + if !hasCRA || !hasMasch { + t.Errorf("both named domains must be represented: CRA=%v MaschVO=%v out=%v", hasCRA, hasMasch, out) + } + if out[0].RegulationCode != "CRA" || !(out[1].RegulationCode == "MACHINERY" || out[1].RegulationCode == "MVO") { + t.Errorf("round-robin should alternate domains, got %s then %s", out[0].RegulationCode, out[1].RegulationCode) + } +} + func TestDedupResults(t *testing.T) { in := []LegalSearchResult{ {RegulationCode: "EDPB WP248", ArticleLabel: "III.B", Text: "lorem", Score: 0.7}, diff --git a/ai-compliance-sdk/internal/ucca/multi_regulation.go b/ai-compliance-sdk/internal/ucca/multi_regulation.go index 3a71dd54..a9a33ef8 100644 --- a/ai-compliance-sdk/internal/ucca/multi_regulation.go +++ b/ai-compliance-sdk/internal/ucca/multi_regulation.go @@ -20,7 +20,9 @@ var regulationCatalog = []struct { CodeValues []string }{ {"CRA", []string{"cra", "cyber resilience"}, []string{"CRA"}}, - {"MaschVO", []string{"maschinenverordnung", "maschvo", "machinery regulation"}, []string{"MASCHVO", "MaschVO"}}, + // MaschVO heisst je Collection anders: Slice MASCHVO · gesetze MVO · ce MACHINERY/MASCHINENVO. + // Alle Varianten als CodeValues, sonst findet der per-Reg-Filter MaschVO nur in der Slice (0070). + {"MaschVO", []string{"maschinenverordnung", "maschvo", "machinery regulation"}, []string{"MASCHVO", "MaschVO", "MVO", "MASCHINENVO", "MACHINERY"}}, {"NIS2", []string{"nis2", "nis-2", "nis 2"}, []string{"NIS2"}}, {"DORA", []string{"dora"}, []string{"DORA"}}, {"Data Act", []string{"data act", "datengesetz"}, []string{"DATA ACT", "DataAct"}}, @@ -53,6 +55,62 @@ func detectRegulations(query string) []detectedRegulation { func hitID(h qdrantSearchHit) string { return fmt.Sprintf("%v", h.ID) } +// balanceByRegulation builds the final top-K so EVERY explicitly-named regulation with hits is +// represented, instead of letting the keyword-dominant domain (e.g. CRA) crowd out the other +// (e.g. MaschVO) in a cross-regulation query. The input pool must already be score-ordered; +// results are grouped by exact regulation_code match against each regulation's CodeValues, then +// taken round-robin across the named domains (highest-scored first within each), with any +// remaining slots filled by the leftover pool in score order. Generic; no doc-specific logic. +func balanceByRegulation(pool []LegalSearchResult, regs []detectedRegulation, topK int) []LegalSearchResult { + if topK <= 0 { + topK = 8 + } + byReg := make([][]LegalSearchResult, len(regs)) + matched := make([]bool, len(pool)) + for ri, r := range regs { + for pi := range pool { + if matched[pi] { + continue + } + code := strings.ToUpper(strings.TrimSpace(pool[pi].RegulationCode)) + for _, cv := range r.CodeValues { + if code == strings.ToUpper(cv) { + byReg[ri] = append(byReg[ri], pool[pi]) + matched[pi] = true + break + } + } + } + } + out := make([]LegalSearchResult, 0, topK) + idx := make([]int, len(regs)) + for len(out) < topK { + progressed := false + for ri := range regs { + if idx[ri] < len(byReg[ri]) { + out = append(out, byReg[ri][idx[ri]]) + idx[ri]++ + progressed = true + if len(out) >= topK { + break + } + } + } + if !progressed { + break + } + } + for pi := range pool { + if len(out) >= topK { + break + } + if !matched[pi] { + out = append(out, pool[pi]) + } + } + return out +} + // searchMultiRegulation retrieves each explicitly-named regulation SEPARATELY (per-regulation // filter) and merges, so a cross-regulation query ("Wie greifen CRA und MaschVO ineinander?") // returns BOTH domains in the prompt instead of only the keyword-dominant one. Generic over any