Files
breakpilot-compliance/ai-compliance-sdk/internal/ucca/legal_rag_graph_test.go
T
Benjamin_Boenisch 230dc05287
CI / detect-changes (push) Successful in 8s
CI / branch-name (push) Has been skipped
CI / guardrail-integrity (push) Has been skipped
CI / secret-scan (push) Has been skipped
CI / dep-audit (push) Has been skipped
CI / build-sha-integrity (push) Successful in 6s
CI / sbom-scan (push) Has been skipped
CI / validate-canonical-controls (push) Successful in 6s
CI / go-lint (push) Has been skipped
CI / loc-budget (push) Successful in 19s
CI / python-lint (push) Has been skipped
CI / nodejs-lint (push) Has been skipped
CI / nodejs-build (push) Successful in 3m1s
CI / test-go (push) Successful in 59s
CI / iace-gt-coverage (push) Successful in 22s
CI / test-python-backend (push) Has been skipped
CI / test-python-document-crawler (push) Has been skipped
CI / test-python-dsms-gateway (push) Has been skipped
feat(ai-sdk): legal-corpus coverage + Phase-2 citation-graph assessment (#33)
2026-06-24 06:37:22 +00:00

90 lines
2.8 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package ucca
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
)
// TestGetStringSlice verifies that getStringSlice extracts the string
// elements of a []interface{} value (skipping non-strings), and returns nil
// both for a missing key and for a value that is not a list.
func TestGetStringSlice(t *testing.T) {
	m := map[string]interface{}{
		"refs": []interface{}{"a", "b", 3, "c"}, // non-strings are skipped
		"str":  "not-a-list",
	}

	got := getStringSlice(m, "refs")
	// Check every element, not just the ends, so a dropped or reordered
	// middle element is caught as well. The length check short-circuits
	// before any indexing.
	if len(got) != 3 || got[0] != "a" || got[1] != "b" || got[2] != "c" {
		t.Errorf("refs: want [a b c], got %v", got)
	}

	if getStringSlice(m, "missing") != nil {
		t.Error("missing key should be nil")
	}
	if getStringSlice(m, "str") != nil {
		t.Error("non-list should be nil")
	}
}
// TestTopByScore_DeterministicCap verifies that topByScore returns keys in
// descending score order, truncates the result to the requested cap, and
// orders equal scores by key so the output does not depend on map iteration
// order.
func TestTopByScore_DeterministicCap(t *testing.T) {
	m := map[string]float64{"x": 0.5, "y": 0.9, "z": 0.5, "w": 0.7}

	got := topByScore(m, 2)
	if len(got) != 2 || got[0] != "y" || got[1] != "w" {
		t.Errorf("want [y w], got %v", got)
	}

	// A cap larger than the map should yield every key. Check the length
	// before indexing so a short result fails cleanly instead of panicking.
	all := topByScore(m, 10)
	if len(all) != len(m) {
		t.Fatalf("cap above map size: want %d keys, got %v", len(m), all)
	}
	if all[0] != "y" || all[1] != "w" {
		t.Errorf("want leading [y w], got %v", all)
	}
	if all[2] != "x" || all[3] != "z" { // tie 0.5 broken by key string
		t.Errorf("tie-break not deterministic: %v", all)
	}
}
// TestExpandViaGraph_NoSeedsOrRefs covers the two paths of expandViaGraph
// that must not touch the network: no input hits at all, and input hits whose
// payload carries no outgoing references.
func TestExpandViaGraph_NoSeedsOrRefs(t *testing.T) {
	// httpClient is deliberately left nil: neither path may issue a request,
	// and any attempt to use the nil client would crash the test.
	c := &LegalRAGClient{}
	ctx := context.Background()

	if out := c.expandViaGraph(ctx, "x", nil); out != nil {
		t.Errorf("empty hits should return nil, got %v", out)
	}

	hits := []qdrantSearchHit{{ID: 1, Score: 0.8, Payload: map[string]interface{}{"citation_unit": "Art. 1 CRA"}}}
	out := c.expandViaGraph(ctx, "x", hits)
	// Fatalf: the next check indexes out[0].
	if len(out) != 1 {
		t.Fatalf("no references → unchanged, got %d", len(out))
	}
	if got := getString(out[0].Payload, "citation_unit"); got != "Art. 1 CRA" {
		t.Errorf("seed hit should pass through unchanged, got citation_unit %q", got)
	}
}
// TestExpandViaGraph_PullsConnectedNorm checks that a seed hit listing a
// norm in its references_out payload causes expandViaGraph to fetch that norm
// from Qdrant, keep the seed, and add the fetched norm to the pool with the
// seed score reduced by a hop penalty.
func TestExpandViaGraph_PullsConnectedNorm(t *testing.T) {
	// Fake Qdrant: every request receives a single page that contains the
	// referenced annex. next_page_offset is null (Qdrant's "no further page").
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		err := json.NewEncoder(w).Encode(map[string]interface{}{
			"result": map[string]interface{}{
				"points": []map[string]interface{}{
					{"id": 99, "payload": map[string]interface{}{
						"citation_unit": "CRA Anhang I", "chunk_text": "Sicherheitsanforderungen",
						"source_class": "binding_law", "authority_weight": 100, "regulation_short": "CRA",
					}},
				},
				"next_page_offset": nil,
			},
		})
		if err != nil {
			// Errorf, not Fatalf: this runs on the server's goroutine, not the test's.
			t.Errorf("encoding fake qdrant response: %v", err)
		}
	}))
	defer srv.Close()

	c := &LegalRAGClient{qdrantURL: srv.URL, httpClient: srv.Client()}
	hits := []qdrantSearchHit{
		{ID: 1, Score: 0.70, Payload: map[string]interface{}{
			"citation_unit": "Art. 13 CRA", "references_out": []interface{}{"CRA Anhang I"},
		}},
	}

	out := c.expandViaGraph(context.Background(), "bp_compliance_ce", hits)
	if len(out) != 2 {
		t.Fatalf("want 2 hits (seed + connected annex), got %d", len(out))
	}

	var seed, found *qdrantSearchHit
	for i := range out {
		switch getString(out[i].Payload, "citation_unit") {
		case "Art. 13 CRA":
			seed = &out[i]
		case "CRA Anhang I":
			found = &out[i]
		}
	}
	if seed == nil {
		t.Error("seed hit Art. 13 CRA was dropped from the pool")
	}
	if found == nil {
		t.Fatal("connected norm CRA Anhang I was not pulled into the pool")
	}
	// Expected: 0.70 seed score - 0.05 hop penalty = 0.65 (with float slack).
	if found.Score < 0.64 || found.Score > 0.66 {
		t.Errorf("connected score = %v, want ~0.65", found.Score)
	}
}