// Package search implements keyword, semantic (k-NN) and hybrid document
// search on top of an OpenSearch index.
package search

import (
	"context"
	"encoding/json"
	"fmt"
	"strings"
	"time"

	"github.com/opensearch-project/opensearch-go/v2"
	"github.com/opensearch-project/opensearch-go/v2/opensearchapi"
)

// Search modes accepted in SearchRequest.Mode. Any other value is treated as
// keyword search.
const (
	modeKeyword  = "keyword"
	modeSemantic = "semantic"
	modeHybrid   = "hybrid"
)

const (
	// defaultLimit is the page size used when a request carries no positive
	// limit (for example when the "limit" JSON field was omitted).
	defaultLimit = 10

	// embeddingField is the index field holding the document vector.
	embeddingField = "content_embedding"

	// requestIDKey is the context key under which the request ID is looked
	// up. It is deliberately a plain string so that values stored by existing
	// callers with the key "request_id" are still found.
	requestIDKey = "request_id"
)

// SearchRequest represents an API search request.
type SearchRequest struct {
	// Query is the free-text query.
	Query string `json:"q"`
	// Mode selects the search strategy: "keyword", "semantic" or "hybrid".
	// An empty or unknown mode means keyword search.
	Mode string `json:"mode"`
	// Limit is the page size; non-positive values fall back to a default.
	Limit int `json:"limit"`
	// Offset is the index of the first hit to return; negative values are
	// treated as zero.
	Offset int `json:"offset"`
	// Filters narrows the result set.
	Filters SearchFilters `json:"filters"`
	// Rerank is carried on the request; the Service in this file does not
	// read it.
	Rerank bool `json:"rerank"`
	// Include selects optional parts of each result.
	Include SearchInclude `json:"include"`
}

// SearchFilters narrows results. Each non-empty slice becomes a "terms"
// filter on the index field named by its JSON tag; zero values are ignored.
type SearchFilters struct {
	Language       []string `json:"language"`
	CountryHint    []string `json:"country_hint"`
	SourceCategory []string `json:"source_category"`
	DocType        []string `json:"doc_type"`
	SchoolLevel    []string `json:"school_level"`
	Subjects       []string `json:"subjects"`
	State          []string `json:"state"`
	// MinTrustScore, when greater than zero, keeps documents whose
	// trust_score is at least this value.
	MinTrustScore float64 `json:"min_trust_score"`
	// DateFrom, when non-empty, keeps documents whose fetch_time is at or
	// after this value. It is passed to OpenSearch unparsed.
	DateFrom string `json:"date_from"`
}

// SearchInclude specifies what to include in the response.
type SearchInclude struct {
	// Snippets copies the stored snippet_text into each result.
	Snippets bool `json:"snippets"`
	// Highlights requests highlighting on title and content_text.
	Highlights bool `json:"highlights"`
	// ContentText is carried on the request; the Service in this file does
	// not read it.
	ContentText bool `json:"content_text"`
}

// SearchResult represents a single search result.
type SearchResult struct {
	DocID       string   `json:"doc_id"`
	Title       string   `json:"title"`
	URL         string   `json:"url"`
	Domain      string   `json:"domain"`
	Language    string   `json:"language"`
	DocType     string   `json:"doc_type"`
	SchoolLevel string   `json:"school_level"`
	Subjects    []string `json:"subjects"`
	Scores      Scores   `json:"scores"`
	Snippet     string   `json:"snippet,omitempty"`
	Highlights  []string `json:"highlights,omitempty"`
}

// Scores contains all scoring components. The Service in this file fills
// BM25, Trust, Quality and Final; Semantic and Rerank are left at zero.
type Scores struct {
	BM25     float64 `json:"bm25"`
	Semantic float64 `json:"semantic"`
	Rerank   float64 `json:"rerank"`
	Trust    float64 `json:"trust"`
	Quality  float64 `json:"quality"`
	Final    float64 `json:"final"`
}

// SearchResponse is the API response.
type SearchResponse struct {
	QueryID    string         `json:"query_id"`
	Results    []SearchResult `json:"results"`
	Pagination Pagination     `json:"pagination"`
}

// Pagination describes the page that was returned.
type Pagination struct {
	Limit         int `json:"limit"`
	Offset        int `json:"offset"`
	TotalEstimate int `json:"total_estimate"`
}

// EmbeddingProvider generates embeddings for query text.
type EmbeddingProvider interface {
	Embed(ctx context.Context, text string) ([]float32, error)
	IsEnabled() bool
	Dimension() int
}

// Service handles search operations against a single OpenSearch index.
type Service struct {
	client            *opensearch.Client
	indexName         string
	embeddingProvider EmbeddingProvider
	semanticEnabled   bool
}

// NewService creates a search service that talks to the OpenSearch node at
// url with basic-auth credentials and queries indexName. Semantic search is
// disabled until SetEmbeddingProvider is called with an enabled provider.
func NewService(url, username, password, indexName string) (*Service, error) {
	client, err := opensearch.NewClient(opensearch.Config{
		Addresses: []string{url},
		Username:  username,
		Password:  password,
	})
	if err != nil {
		return nil, fmt.Errorf("creating opensearch client: %w", err)
	}

	return &Service{
		client:    client,
		indexName: indexName,
	}, nil
}

// SetEmbeddingProvider configures the embedding provider for semantic search.
// A nil or disabled provider is ignored, leaving any previously configured
// provider in place.
func (s *Service) SetEmbeddingProvider(provider EmbeddingProvider) {
	if provider == nil || !provider.IsEnabled() {
		return
	}
	s.embeddingProvider = provider
	s.semanticEnabled = true
}

// IsSemanticEnabled reports whether semantic search is available.
func (s *Service) IsSemanticEnabled() bool {
	return s.semanticEnabled && s.embeddingProvider != nil
}

// Search runs req against the index and returns one page of results.
//
// Semantic and hybrid requests fall back to keyword search when no embedding
// provider is configured or when the query cannot be embedded. A non-positive
// Limit is replaced by a default page size and a negative Offset by zero; the
// caller's request is never modified. An error is returned for a nil request,
// when the query cannot be sent, when OpenSearch answers with an error status,
// or when its response cannot be decoded.
func (s *Service) Search(ctx context.Context, req *SearchRequest) (*SearchResponse, error) {
	if req == nil {
		return nil, fmt.Errorf("search: nil request")
	}

	// Normalise pagination on a copy so the defaults do not leak back into
	// the caller's struct.
	effective := *req
	if effective.Limit <= 0 {
		effective.Limit = defaultLimit
	}
	if effective.Offset < 0 {
		effective.Offset = 0
	}

	mode, embedding := s.resolveMode(ctx, &effective)

	var query map[string]interface{}
	switch mode {
	case modeSemantic:
		query = s.buildSemanticQuery(&effective, embedding)
	case modeHybrid:
		query = s.buildHybridQuery(&effective, embedding)
	default:
		query = s.buildQuery(&effective)
	}

	body, err := json.Marshal(query)
	if err != nil {
		return nil, fmt.Errorf("encoding search query: %w", err)
	}

	searchReq := opensearchapi.SearchRequest{
		Index: []string{s.indexName},
		Body:  strings.NewReader(string(body)),
	}
	res, err := searchReq.Do(ctx, s.client)
	if err != nil {
		return nil, fmt.Errorf("searching index %q: %w", s.indexName, err)
	}
	defer res.Body.Close()

	// An error status carries an error document rather than hits; decoding it
	// would silently produce an empty result set.
	if res.IsError() {
		return nil, fmt.Errorf("searching index %q: %s", s.indexName, res.String())
	}

	var osResponse struct {
		Hits struct {
			Total struct {
				Value int `json:"value"`
			} `json:"total"`
			Hits []struct {
				ID        string                 `json:"_id"`
				Score     float64                `json:"_score"`
				Source    map[string]interface{} `json:"_source"`
				Highlight map[string][]string    `json:"highlight,omitempty"`
			} `json:"hits"`
		} `json:"hits"`
	}
	if err := json.NewDecoder(res.Body).Decode(&osResponse); err != nil {
		return nil, fmt.Errorf("decoding search response: %w", err)
	}

	hits := osResponse.Hits.Hits
	results := make([]SearchResult, 0, len(hits))
	for i := range hits {
		results = append(results, s.hitToResult(hits[i].Source, hits[i].Score, hits[i].Highlight, effective.Include))
	}

	return &SearchResponse{
		QueryID: queryID(ctx),
		Results: results,
		Pagination: Pagination{
			Limit:         effective.Limit,
			Offset:        effective.Offset,
			TotalEstimate: osResponse.Hits.Total.Value,
		},
	}, nil
}

// resolveMode returns the mode to execute and, for semantic and hybrid
// modes, the query embedding. It degrades to keyword search when semantic
// search is unavailable or the embedding call fails or yields no vector.
func (s *Service) resolveMode(ctx context.Context, req *SearchRequest) (string, []float32) {
	if req.Mode != modeSemantic && req.Mode != modeHybrid {
		return modeKeyword, nil
	}
	if !s.IsSemanticEnabled() {
		return modeKeyword, nil
	}

	embedding, err := s.embeddingProvider.Embed(ctx, req.Query)
	if err != nil || len(embedding) == 0 {
		// Best effort: a failed embedding must not fail the whole search.
		return modeKeyword, nil
	}
	return req.Mode, embedding
}

// queryID derives the response query ID from the "request_id" context value,
// falling back to a timestamp when the context carries none.
func queryID(ctx context.Context) string {
	if id := ctx.Value(requestIDKey); id != nil {
		// %v rather than %d: the stored value may be a string or an integer.
		return fmt.Sprintf("q-%v", id)
	}
	return fmt.Sprintf("q-%d", time.Now().UnixNano())
}

// buildQuery constructs the keyword query: a bool query (multi_match plus
// filters) wrapped in a function_score that multiplies the text score by the
// square roots of trust_score*1.5 and quality_score.
func (s *Service) buildQuery(req *SearchRequest) map[string]interface{} {
	query := map[string]interface{}{
		"query": map[string]interface{}{
			"function_score": map[string]interface{}{
				"query": map[string]interface{}{
					"bool": buildBoolQuery(req.Query, s.buildFilters(req)),
				},
				"functions": []map[string]interface{}{
					fieldValueFactor("trust_score", 1.5),
					fieldValueFactor("quality_score", 1.0),
				},
				"score_mode": "multiply",
				"boost_mode": "multiply",
			},
		},
		"from":    req.Offset,
		"size":    req.Limit,
		"_source": sourceFields(),
	}
	applyHighlight(query, req.Include)
	return query
}

// buildSemanticQuery constructs a pure vector search query using k-NN.
func (s *Service) buildSemanticQuery(req *SearchRequest, embedding []float32) map[string]interface{} {
	// k must cover every hit up to the end of the requested page, and the
	// k-NN query needs at least one neighbour.
	k := req.Limit + req.Offset
	if k < 1 {
		k = 1
	}

	vectorClause := map[string]interface{}{
		"vector": embedding,
		"k":      k,
	}
	if filters := s.buildFilters(req); len(filters) > 0 {
		vectorClause["filter"] = map[string]interface{}{
			"bool": map[string]interface{}{
				"filter": filters,
			},
		}
	}

	query := map[string]interface{}{
		// In OpenSearch, knn is a query clause and must sit under "query".
		"query": map[string]interface{}{
			"knn": map[string]interface{}{
				embeddingField: vectorClause,
			},
		},
		"from":    req.Offset,
		"size":    req.Limit,
		"_source": sourceFields(),
	}
	applyHighlight(query, req.Include)
	return query
}

// buildHybridQuery constructs a combined BM25 + vector search query. The
// bool query selects and text-scores candidates; a script_score then adds
// the cosine similarity to the query embedding.
func (s *Service) buildHybridQuery(req *SearchRequest, embedding []float32) map[string]interface{} {
	query := map[string]interface{}{
		"query": map[string]interface{}{
			"script_score": map[string]interface{}{
				"query": map[string]interface{}{
					"bool": buildBoolQuery(req.Query, s.buildFilters(req)),
				},
				"script": map[string]interface{}{
					// The k-NN scripting function takes doc['field'];
					// documents without a vector score as similarity 0.
					"source": "(doc['content_embedding'].size() == 0 ? 0.0 : cosineSimilarity(params.query_vector, doc['content_embedding'])) + 1.0 + _score * 0.5",
					"params": map[string]interface{}{
						"query_vector": embedding,
					},
				},
			},
		},
		"from":    req.Offset,
		"size":    req.Limit,
		"_source": sourceFields(),
	}
	applyHighlight(query, req.Include)
	return query
}

// buildFilters constructs the filter array for queries. The returned slice
// is never nil and holds one clause per populated filter.
func (s *Service) buildFilters(req *SearchRequest) []map[string]interface{} {
	f := req.Filters
	termFilters := []struct {
		field  string
		values []string
	}{
		{"language", f.Language},
		{"country_hint", f.CountryHint},
		{"source_category", f.SourceCategory},
		{"doc_type", f.DocType},
		{"school_level", f.SchoolLevel},
		{"subjects", f.Subjects},
		{"state", f.State},
	}

	// Capacity: every terms filter plus the two range filters.
	filters := make([]map[string]interface{}, 0, len(termFilters)+2)
	for _, tf := range termFilters {
		if len(tf.values) == 0 {
			continue
		}
		filters = append(filters, map[string]interface{}{
			"terms": map[string]interface{}{tf.field: tf.values},
		})
	}
	if f.MinTrustScore > 0 {
		filters = append(filters, rangeGTE("trust_score", f.MinTrustScore))
	}
	if f.DateFrom != "" {
		filters = append(filters, rangeGTE("fetch_time", f.DateFrom))
	}
	return filters
}

// buildBoolQuery returns the body of a bool query that matches text against
// title (boosted) and content_text and applies filters. Empty parts are
// omitted; an entirely empty body matches every document.
func buildBoolQuery(text string, filters []map[string]interface{}) map[string]interface{} {
	boolQuery := map[string]interface{}{}
	if text != "" {
		boolQuery["must"] = []map[string]interface{}{
			{
				"multi_match": map[string]interface{}{
					"query":  text,
					"fields": []string{"title^3", "content_text"},
					"type":   "best_fields",
				},
			},
		}
	}
	if len(filters) > 0 {
		boolQuery["filter"] = filters
	}
	return boolQuery
}

// rangeGTE returns a range clause keeping documents whose field is greater
// than or equal to min.
func rangeGTE(field string, min interface{}) map[string]interface{} {
	return map[string]interface{}{
		"range": map[string]interface{}{
			field: map[string]interface{}{"gte": min},
		},
	}
}

// fieldValueFactor returns a function_score function that scores by
// sqrt(factor * field), using 0.5 for documents missing the field.
func fieldValueFactor(field string, factor float64) map[string]interface{} {
	return map[string]interface{}{
		"field_value_factor": map[string]interface{}{
			"field":    field,
			"factor":   factor,
			"modifier": "sqrt",
			"missing":  0.5,
		},
	}
}

// sourceFields lists the stored fields fetched for every hit. A fresh slice
// is returned on each call so that no query shares mutable state.
func sourceFields() []string {
	return []string{
		"doc_id", "title", "url", "domain", "language",
		"doc_type", "school_level", "subjects",
		"trust_score", "quality_score", "snippet_text",
	}
}

// applyHighlight adds the highlight section to query when the request asks
// for highlights.
func applyHighlight(query map[string]interface{}, include SearchInclude) {
	if !include.Highlights {
		return
	}
	query["highlight"] = map[string]interface{}{
		"fields": map[string]interface{}{
			"title":        map[string]interface{}{},
			"content_text": map[string]interface{}{"fragment_size": 150, "number_of_fragments": 3},
		},
	}
}

// hitToResult converts an OpenSearch hit to a SearchResult. The hit's _score
// is reported as both Scores.BM25 and Scores.Final, whichever query produced
// it. Only content_text highlights are copied into the result.
func (s *Service) hitToResult(source map[string]interface{}, score float64, highlight map[string][]string, include SearchInclude) SearchResult {
	result := SearchResult{
		DocID:       getString(source, "doc_id"),
		Title:       getString(source, "title"),
		URL:         getString(source, "url"),
		Domain:      getString(source, "domain"),
		Language:    getString(source, "language"),
		DocType:     getString(source, "doc_type"),
		SchoolLevel: getString(source, "school_level"),
		Subjects:    getStringArray(source, "subjects"),
		Scores: Scores{
			BM25:    score,
			Trust:   getFloat(source, "trust_score"),
			Quality: getFloat(source, "quality_score"),
			Final:   score, // MVP: final = BM25 * trust * quality (via function_score)
		},
	}

	if include.Snippets {
		result.Snippet = getString(source, "snippet_text")
	}
	if include.Highlights && highlight != nil {
		if fragments, ok := highlight["content_text"]; ok {
			result.Highlights = fragments
		}
	}
	return result
}

// getString returns m[key] when it holds a string and "" otherwise.
func getString(m map[string]interface{}, key string) string {
	if v, ok := m[key].(string); ok {
		return v
	}
	return ""
}

// getFloat returns m[key] when it holds a float64 and 0 otherwise.
func getFloat(m map[string]interface{}, key string) float64 {
	if v, ok := m[key].(float64); ok {
		return v
	}
	return 0.0
}

// getStringArray returns the string elements of m[key] when it holds a
// []interface{}, skipping non-string elements, and nil otherwise.
func getStringArray(m map[string]interface{}, key string) []string {
	items, ok := m[key].([]interface{})
	if !ok {
		return nil
	}

	result := make([]string, 0, len(items))
	for _, item := range items {
		if str, ok := item.(string); ok {
			result = append(result, str)
		}
	}
	return result
}