Files
breakpilot-compliance/ai-compliance-sdk/internal/ucca/legal_rag_graph.go
T
Benjamin_Boenisch 230dc05287
CI / detect-changes (push) Successful in 8s
CI / branch-name (push) Has been skipped
CI / guardrail-integrity (push) Has been skipped
CI / secret-scan (push) Has been skipped
CI / dep-audit (push) Has been skipped
CI / build-sha-integrity (push) Successful in 6s
CI / sbom-scan (push) Has been skipped
CI / validate-canonical-controls (push) Successful in 6s
CI / go-lint (push) Has been skipped
CI / loc-budget (push) Successful in 19s
CI / python-lint (push) Has been skipped
CI / nodejs-lint (push) Has been skipped
CI / nodejs-build (push) Successful in 3m1s
CI / test-go (push) Successful in 59s
CI / iace-gt-coverage (push) Successful in 22s
CI / test-python-backend (push) Has been skipped
CI / test-python-document-crawler (push) Has been skipped
CI / test-python-dsms-gateway (push) Has been skipped
feat(ai-sdk): legal-corpus coverage + Phase-2 citation-graph assessment (#33)
2026-06-24 06:37:22 +00:00

163 lines
5.0 KiB
Go

package ucca
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"sort"
)
// Graph-augmented retrieval: when a top hit cites an annex/article (references_out),
// pull that cited norm into the candidate pool via the PRECISE citation graph
// instead of hoping semantic search surfaces it. E.g. a hit on CRA Art. 13 pulls
// in CRA Anhang I (the actual requirement). Only forward edges expand the pool;
// norms that cite a hit (references_in) are not pulled in here.
// Pool-augmentation only — authority re-rank + topK slice still apply, so the
// response schema is unchanged.
// graphSeedCount is how many leading hits are used as seeds for the graph
// expansion; hits beyond this index never contribute references.
const graphSeedCount = 5

// graphMaxExpand caps how many connected norms are pulled into the candidate
// pool, to avoid pool explosion.
const graphMaxExpand = 15

// graphHopPenalty is subtracted from the seeding hit's score to score a
// one-hop neighbour, so the neighbour ranks just below its seed.
const graphHopPenalty = 0.05
// expandViaGraph augments hits with the norms that the leading hits cite via
// their "references_out" payload field.
//
// Only the first graphSeedCount hits act as seeds (assumes hits arrive ordered
// best-first — TODO confirm against the caller). Every cited citation_unit that
// is not already present in hits becomes a candidate, scored with the highest
// score among the seeds citing it. At most graphMaxExpand candidates are
// fetched from collection, and each fetched point is added as a neighbour with
// score (seeding score - graphHopPenalty). The result is produced by
// mergeDedupHits(hits, neighbours).
//
// Best-effort: if hits is empty, no seed cites anything new, the fetch fails,
// or nothing usable is fetched, the original hits are returned unchanged.
func (c *LegalRAGClient) expandViaGraph(ctx context.Context, collection string, hits []qdrantSearchHit) []qdrantSearchHit {
	if len(hits) == 0 {
		return hits
	}

	// citation_units that are already in the pool must not be fetched again.
	present := make(map[string]bool, len(hits))
	for _, h := range hits {
		if cu := getString(h.Payload, "citation_unit"); cu != "" {
			present[cu] = true
		}
	}

	seeds := hits
	if len(seeds) > graphSeedCount {
		seeds = seeds[:graphSeedCount]
	}

	// Forward edges only (references_out = the detail a hit explicitly points to,
	// e.g. Art. 13 → Anhang I). Reverse (references_in) has high fan-out for popular
	// annexes (Anhang I is cited by 23 articles) → pool flooding; it is surfaced as
	// connected-norm metadata in the Phase 2 response instead of expanding the pool.
	want := make(map[string]float64) // connected citation_unit -> best seeding score
	for _, h := range seeds {
		for _, cu := range getStringSlice(h.Payload, "references_out") {
			if cu == "" || present[cu] {
				continue
			}
			if s, ok := want[cu]; !ok || h.Score > s {
				want[cu] = h.Score
			}
		}
	}
	if len(want) == 0 {
		return hits
	}

	units := topByScore(want, graphMaxExpand)
	fetched, err := c.fetchByCitationUnits(ctx, collection, units)
	if err != nil || len(fetched) == 0 {
		return hits
	}

	// Walk the ordered units slice instead of ranging over the fetched map:
	// map iteration order is random, which made the neighbour order differ
	// between runs. This also ignores any fetched point whose citation_unit
	// was not requested, which would otherwise get score 0 - graphHopPenalty.
	neighbours := make([]qdrantSearchHit, 0, len(fetched))
	for _, cu := range units {
		pt, ok := fetched[cu]
		if !ok {
			continue
		}
		neighbours = append(neighbours, qdrantSearchHit{
			ID:      pt.ID,
			Score:   want[cu] - graphHopPenalty,
			Payload: pt.Payload,
		})
	}
	if len(neighbours) == 0 {
		return hits
	}
	return mergeDedupHits(hits, neighbours)
}
// topByScore returns up to n keys of m with the highest values, ordered by
// value descending. Deterministic: ties are broken by the key string
// (ascending) so the cap is stable across runs.
//
// A negative n is treated as 0 and yields an empty slice; without the clamp the
// final slice expression would panic. The returned slice is never nil.
func topByScore(m map[string]float64, n int) []string {
	if n < 0 {
		n = 0
	}
	keys := make([]string, 0, len(m))
	for k := range m {
		keys = append(keys, k)
	}
	sort.Slice(keys, func(i, j int) bool {
		si, sj := m[keys[i]], m[keys[j]]
		if si != sj {
			return si > sj
		}
		return keys[i] < keys[j]
	})
	if len(keys) > n {
		keys = keys[:n]
	}
	return keys
}
// fetchByCitationUnits loads one representative point per citation_unit from
// the given collection with a single Qdrant scroll request.
//
// The request filters on payload key "citation_unit" matching any entry of
// units ("should" conditions), asks for payloads but not vectors, and limits
// the page to len(units)*4 points. For each citation_unit the first point in
// the response is kept and later points with the same unit are ignored
// (presumably the first chunk, if Qdrant returns chunks in order — TODO
// confirm). Units with no point in the response are absent from the result.
//
// NOTE(review): the limit is shared across all units, so a unit with many
// chunks could crowd other units out of the single page — confirm typical
// chunk counts or page through the scroll if that matters.
//
// Returns an empty map without contacting Qdrant when units is empty. Returns
// a non-nil error if the request cannot be built or sent, if Qdrant answers
// with a status other than 200, or if the response cannot be decoded.
func (c *LegalRAGClient) fetchByCitationUnits(ctx context.Context, collection string, units []string) (map[string]qdrantScrollPoint, error) {
	// Nothing to look up: skip the round trip instead of sending a limit-0
	// scroll with an empty filter.
	if len(units) == 0 {
		return map[string]qdrantScrollPoint{}, nil
	}

	should := make([]map[string]interface{}, 0, len(units))
	for _, cu := range units {
		should = append(should, map[string]interface{}{
			"key":   "citation_unit",
			"match": map[string]interface{}{"value": cu},
		})
	}
	reqBody := map[string]interface{}{
		"limit":        len(units) * 4,
		"with_payload": true,
		"with_vectors": false,
		"filter":       map[string]interface{}{"should": should},
	}
	jsonBody, err := json.Marshal(reqBody)
	if err != nil {
		return nil, fmt.Errorf("marshal qdrant scroll request: %w", err)
	}

	endpoint := fmt.Sprintf("%s/collections/%s/points/scroll", c.qdrantURL, collection)
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(jsonBody))
	if err != nil {
		return nil, fmt.Errorf("build qdrant scroll request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")
	if c.qdrantAPIKey != "" {
		req.Header.Set("api-key", c.qdrantAPIKey)
	}

	resp, err := c.httpClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("send qdrant scroll request: %w", err)
	}
	// Close error is deliberately ignored: the body has been consumed (or
	// abandoned) by then and there is nothing useful to do with it.
	defer func() { _ = resp.Body.Close() }()

	if resp.StatusCode != http.StatusOK {
		// Cap how much of the error body is read so an oversized response
		// cannot blow up memory or the error message.
		const maxErrBody = 4 << 10
		body, _ := io.ReadAll(io.LimitReader(resp.Body, maxErrBody))
		return nil, fmt.Errorf("qdrant scroll returned %d: %s", resp.StatusCode, string(body))
	}

	var scrollResp qdrantScrollResponse
	if err := json.NewDecoder(resp.Body).Decode(&scrollResp); err != nil {
		return nil, fmt.Errorf("decode qdrant scroll response: %w", err)
	}

	// Keep only the first point seen for each citation_unit.
	out := make(map[string]qdrantScrollPoint, len(units))
	for _, pt := range scrollResp.Result.Points {
		cu := getString(pt.Payload, "citation_unit")
		if cu == "" {
			continue
		}
		if _, seen := out[cu]; !seen {
			out[cu] = pt
		}
	}
	return out, nil
}
// getStringSlice reads the list stored under key in a Qdrant payload and
// returns its string elements in their original order (references_out /
// references_in are stored as JSON arrays of strings).
//
// It returns nil when key is missing or its value is not a []interface{}.
// Elements that are not strings are skipped; a list without any string
// elements yields an empty, non-nil slice.
func getStringSlice(m map[string]interface{}, key string) []string {
	// A missing key gives a nil interface, which fails the assertion just
	// like a value of the wrong type does.
	raw, isList := m[key].([]interface{})
	if !isList {
		return nil
	}

	strs := make([]string, 0, len(raw))
	for i := range raw {
		str, isString := raw[i].(string)
		if !isString {
			continue
		}
		strs = append(strs, str)
	}
	return strs
}