package ucca import ( "context" "encoding/json" "os" "strconv" "strings" "testing" ) type benchQ struct { ID string `json:"id"` Document string `json:"document"` Question string `json:"question"` } // docTokens maps a bench question's expected document to acceptable regulation_code/label substrings. func docTokens(document string) []string { d := strings.ToUpper(document) var t []string for _, wp := range []string{"WP243", "WP248", "WP260"} { if strings.Contains(d, wp) { t = append(t, wp) } } dns := strings.ReplaceAll(d, " ", "") for _, gl := range []struct{ key, tok string }{{"07/2020", "GL07"}, {"05/2020", "GL05"}, {"09/2022", "GL09"}} { if strings.Contains(dns, gl.key) { t = append(t, gl.tok) } } if strings.Contains(d, "TDDDG") { t = append(t, "TDDDG") } if strings.Contains(d, "DSGVO") || strings.Contains(d, "ART. 13") || strings.Contains(d, "ART. 14") { t = append(t, "DSGVO") } if strings.Contains(d, "BDSG") { t = append(t, "BDSG") } if strings.Contains(d, "CRA") { t = append(t, "CRA") } if strings.Contains(d, "MASCH") { t = append(t, "MASCH", "MACHINERY", "MVO") } return t } func hitDoc(results []LegalSearchResult, toks []string) bool { for _, r := range results { s := strings.ReplaceAll(strings.ToUpper(r.RegulationCode+" "+r.ArticleLabel), " ", "") for _, tk := range toks { if strings.Contains(s, strings.ReplaceAll(tk, " ", "")) { return true } } } return false } // TestMultiReg0070E2E (RUN_E2E=1) is the 0070 regression: a cross-regulation query (CRA + MaschVO) // must return BOTH domains through the real Retrieve(), not just the keyword-dominant CRA. func TestMultiReg0070E2E(t *testing.T) { if os.Getenv("RUN_E2E") != "1" { t.Skip("set RUN_E2E=1 + QDRANT_URL/OLLAMA_URL/QDRANT_API_KEY") } c := NewLegalRAGClient() q := "Wie greifen CRA und Maschinenverordnung bei einer vernetzten Maschine ineinander?" res, err := c.Retrieve(context.Background(), q, 8) if err != nil { t.Fatalf("retrieve: %v", err) } var hasCRA, hasMasch bool var codes []string for _, r := range res { u := strings.ToUpper(r.RegulationCode) codes = append(codes, u) if strings.Contains(u, "CRA") { hasCRA = true } if strings.Contains(u, "MASCH") || strings.Contains(u, "MACHIN") || u == "MVO" { hasMasch = true } } t.Logf("0070 top-8 codes: %v", codes) if !hasCRA || !hasMasch { t.Errorf("0070 must return BOTH domains via Retrieve(): CRA=%v MaschVO=%v", hasCRA, hasMasch) } } // TestAuthorityRouterCB100 (RUN_E2E=1) drives the REAL Retrieve() over the ComplianceBench-100 against // the live collections: NEW (scope routing on → slice added for in-scope queries) vs OLD (routing off // → broad base only). It is the regression gate that the router actually delivers the proven slice // gain (+28/0-regr in the offline simulation) through the production Go code path. func TestAuthorityRouterCB100(t *testing.T) { if os.Getenv("RUN_E2E") != "1" { t.Skip("set RUN_E2E=1 + QDRANT_URL/OLLAMA_URL/QDRANT_API_KEY + BENCH_PATH") } path := os.Getenv("BENCH_PATH") if path == "" { path = "/tmp/compliance_bench.json" } raw, err := os.ReadFile(path) if err != nil { t.Fatalf("bench read: %v", err) } var doc struct { Questions []benchQ `json:"questions"` } if err := json.Unmarshal(raw, &doc); err != nil { t.Fatalf("bench parse: %v", err) } // BENCH_STRIDE samples every Kth question (stratified across DS/CRA/MaschVO) so the gate stays // tractable against the remote dev Qdrant; default 1 = full CB-100. stride := 1 if s := os.Getenv("BENCH_STRIDE"); s != "" { if n, err := strconv.Atoi(s); err == nil && n > 0 { stride = n } } c := NewLegalRAGClient() ctx := context.Background() var n, oldHit, newHit, gain, regr int for i, q := range doc.Questions { if i%stride != 0 { continue } n++ toks := docTokens(q.Document) c.kbScopeRoutingEnabled = false oldRes, _ := c.Retrieve(ctx, q.Question, 8) c.kbScopeRoutingEnabled = true newRes, _ := c.Retrieve(ctx, q.Question, 8) oh, nh := hitDoc(oldRes, toks), hitDoc(newRes, toks) if oh { oldHit++ } if nh { newHit++ } flip := "=" switch { case !oh && nh: gain++ flip = "GAIN" case oh && !nh: regr++ flip = "REGR" } t.Logf("%-9s [%-14s] OLD=%-5v NEW=%-5v %s", q.ID, q.Document, oh, nh, flip) } t.Logf("CB-100 sample (stride=%d) via Retrieve(): N=%d | OLD-hit %d | NEW-hit %d | GAIN %d | REGR %d", stride, n, oldHit, newHit, gain, regr) if newHit <= oldHit || gain < 3 { t.Errorf("router must add slice gains: NEW(%d) must exceed OLD(%d), gain=%d", newHit, oldHit, gain) } if regr > 2 { t.Errorf("too many regressions through the router: %d", regr) } }