package ucca import ( "bytes" "context" "encoding/json" "fmt" "io" "net/http" "sort" ) // Graph-augmented retrieval: when a top hit cites an annex/article (references_out) // or is cited by one (references_in), pull that connected norm into the candidate // pool via the PRECISE citation graph instead of hoping semantic search surfaces // it. E.g. a hit on CRA Art. 13 pulls in CRA Anhang I (the actual requirement). // Pool-augmentation only — authority re-rank + topK slice still apply, so the // response schema is unchanged. const ( graphSeedCount = 5 // only the top hits seed the expansion graphMaxExpand = 15 // cap connected norms pulled in (avoid pool explosion) graphHopPenalty = 0.05 // a one-hop neighbour ranks just below its seed ) // expandViaGraph augments hits with the norms they cite and the norms that cite // them. Best-effort: on any error (or nothing to expand) the original hits are // returned unchanged. func (c *LegalRAGClient) expandViaGraph(ctx context.Context, collection string, hits []qdrantSearchHit) []qdrantSearchHit { if len(hits) == 0 { return hits } present := make(map[string]bool, len(hits)) for _, h := range hits { if cu := getString(h.Payload, "citation_unit"); cu != "" { present[cu] = true } } seeds := hits if len(seeds) > graphSeedCount { seeds = seeds[:graphSeedCount] } // Forward edges only (references_out = the detail a hit explicitly points to, // e.g. Art. 13 → Anhang I). Reverse (references_in) has high fan-out for popular // annexes (Anhang I is cited by 23 articles) → pool flooding; it is surfaced as // connected-norm metadata in the Phase 2 response instead of expanding the pool. want := make(map[string]float64) // connected citation_unit -> best seeding score for _, h := range seeds { for _, cu := range getStringSlice(h.Payload, "references_out") { if cu == "" || present[cu] { continue } if s, ok := want[cu]; !ok || h.Score > s { want[cu] = h.Score } } } if len(want) == 0 { return hits } units := topByScore(want, graphMaxExpand) fetched, err := c.fetchByCitationUnits(ctx, collection, units) if err != nil || len(fetched) == 0 { return hits } neighbours := make([]qdrantSearchHit, 0, len(fetched)) for cu, pt := range fetched { neighbours = append(neighbours, qdrantSearchHit{ID: pt.ID, Score: want[cu] - graphHopPenalty, Payload: pt.Payload}) } return mergeDedupHits(hits, neighbours) } // topByScore returns up to n keys with the highest values. Deterministic: ties // broken by the key string so the cap is stable across runs. func topByScore(m map[string]float64, n int) []string { keys := make([]string, 0, len(m)) for k := range m { keys = append(keys, k) } sort.Slice(keys, func(i, j int) bool { if m[keys[i]] != m[keys[j]] { return m[keys[i]] > m[keys[j]] } return keys[i] < keys[j] }) if len(keys) > n { keys = keys[:n] } return keys } // fetchByCitationUnits loads one representative point (the first chunk) per // citation_unit from the given collection. func (c *LegalRAGClient) fetchByCitationUnits(ctx context.Context, collection string, units []string) (map[string]qdrantScrollPoint, error) { should := make([]map[string]interface{}, 0, len(units)) for _, cu := range units { should = append(should, map[string]interface{}{"key": "citation_unit", "match": map[string]interface{}{"value": cu}}) } reqBody := map[string]interface{}{ "limit": len(units) * 4, "with_payload": true, "with_vectors": false, "filter": map[string]interface{}{"should": should}, } jsonBody, err := json.Marshal(reqBody) if err != nil { return nil, err } url := fmt.Sprintf("%s/collections/%s/points/scroll", c.qdrantURL, collection) req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(jsonBody)) if err != nil { return nil, err } req.Header.Set("Content-Type", "application/json") if c.qdrantAPIKey != "" { req.Header.Set("api-key", c.qdrantAPIKey) } resp, err := c.httpClient.Do(req) if err != nil { return nil, err } defer func() { _ = resp.Body.Close() }() if resp.StatusCode != http.StatusOK { body, _ := io.ReadAll(resp.Body) return nil, fmt.Errorf("qdrant scroll returned %d: %s", resp.StatusCode, string(body)) } var scrollResp qdrantScrollResponse if err := json.NewDecoder(resp.Body).Decode(&scrollResp); err != nil { return nil, err } out := make(map[string]qdrantScrollPoint, len(units)) for _, pt := range scrollResp.Result.Points { cu := getString(pt.Payload, "citation_unit") if cu != "" { if _, seen := out[cu]; !seen { out[cu] = pt } } } return out, nil } // getStringSlice extracts a []string from a Qdrant payload list field // (references_out / references_in are stored as JSON arrays of strings). func getStringSlice(m map[string]interface{}, key string) []string { v, ok := m[key] if !ok { return nil } arr, ok := v.([]interface{}) if !ok { return nil } out := make([]string, 0, len(arr)) for _, item := range arr { if s, ok := item.(string); ok { out = append(out, s) } } return out }