Skip to content

Commit

Permalink
Fix multi-search with raw strings
Browse files Browse the repository at this point in the history
  • Loading branch information
olivere committed Feb 10, 2018
1 parent 89b8c1e commit 8e15c58
Show file tree
Hide file tree
Showing 5 changed files with 221 additions and 24 deletions.
9 changes: 6 additions & 3 deletions CONTRIBUTORS
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ John Stanford [@jxstanford](https://github.com/jxstanford)
Josh Chorlton [@jchorl](https://github.com/jchorl)
jun [@coseyo](https://github.com/coseyo)
Junpei Tsuji [@jun06t](https://github.com/jun06t)
kartlee [@kartlee](https://github.com/kartlee)
Keith Hatton [@khatton-ft](https://github.com/khatton-ft)
kel [@liketic](https://github.com/liketic)
Kenta SUZUKI [@suzuken](https://github.com/suzuken)
Expand Down Expand Up @@ -98,10 +99,12 @@ Orne Brocaar [@brocaar](https://github.com/brocaar)
Paul [@eyeamera](https://github.com/eyeamera)
Pete C [@peteclark-ft](https://github.com/peteclark-ft)
Radoslaw Wesolowski [r--w](https://github.com/r--w)
Roman Colohanin [@zuzmic](https://github.com/zuzmic)
Ryan Schmukler [@rschmukler](https://github.com/rschmukler)
Sacheendra talluri [@sacheendra](https://github.com/sacheendra)
Sean DuBois [@Sean-Der](https://github.com/Sean-Der)
Shalin LK [@shalinlk](https://github.com/shalinlk)
singham [@zhaochenxiao90](https://github.com/zhaochenxiao90)
Stephen Kubovic [@stephenkubovic](https://github.com/stephenkubovic)
Stuart Warren [@Woz](https://github.com/stuart-warren)
Sulaiman [@salajlan](https://github.com/salajlan)
Expand All @@ -111,13 +114,13 @@ Take [ww24](https://github.com/ww24)
Tetsuya Morimoto [@t2y](https://github.com/t2y)
TimeEmit [@TimeEmit](https://github.com/timeemit)
TusharM [@tusharm](https://github.com/tusharm)
zhangxin [@visaxin](https://github.com/visaxin)
wangtuo [@wangtuo](https://github.com/wangtuo)
Wédney Yuri [@wedneyyuri](https://github.com/wedneyyuri)
wolfkdy [@wolfkdy](https://github.com/wolfkdy)
Wyndham Blanton [@wyndhblb](https://github.com/wyndhblb)
Yarden Bar [@ayashjorden](https://github.com/ayashjorden)
zakthomas [@zakthomas](https://github.com/zakthomas)
singham [@zhaochenxiao90](https://github.com/zhaochenxiao90)
Yuya Kusakabe [@higebu](https://github.com/higebu)
Zach [@snowzach](https://github.com/snowzach)
zhangxin [@visaxin](https://github.com/visaxin)
@林 [@zplzpl](https://github.com/zplzpl)
Roman Colohanin [@zuzmic](https://github.com/zuzmic)
29 changes: 22 additions & 7 deletions msearch.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,13 @@ type MultiSearchService struct {
requests []*SearchRequest
indices []string
pretty bool
routing string
preference string
maxConcurrentRequests *int
preFilterShardSize *int
}

func NewMultiSearchService(client *Client) *MultiSearchService {
builder := &MultiSearchService{
client: client,
requests: make([]*SearchRequest, 0),
indices: make([]string, 0),
}
return builder
}
Expand All @@ -46,6 +44,16 @@ func (s *MultiSearchService) Pretty(pretty bool) *MultiSearchService {
return s
}

func (s *MultiSearchService) MaxConcurrentSearches(max int) *MultiSearchService {
s.maxConcurrentRequests = &max
return s
}

func (s *MultiSearchService) PreFilterShardSize(size int) *MultiSearchService {
s.preFilterShardSize = &size
return s
}

func (s *MultiSearchService) Do(ctx context.Context) (*MultiSearchResult, error) {
// Build url
path := "/_msearch"
Expand All @@ -55,6 +63,12 @@ func (s *MultiSearchService) Do(ctx context.Context) (*MultiSearchResult, error)
if s.pretty {
params.Set("pretty", fmt.Sprintf("%v", s.pretty))
}
if v := s.maxConcurrentRequests; v != nil {
params.Set("max_concurrent_searches", fmt.Sprintf("%v", *v))
}
if v := s.preFilterShardSize; v != nil {
params.Set("pre_filter_shard_size", fmt.Sprintf("%v", *v))
}

// Set body
var lines []string
Expand All @@ -68,14 +82,14 @@ func (s *MultiSearchService) Do(ctx context.Context) (*MultiSearchResult, error)
if err != nil {
return nil, err
}
body, err := json.Marshal(sr.Body())
body, err := sr.Body()
if err != nil {
return nil, err
}
lines = append(lines, string(header))
lines = append(lines, string(body))
lines = append(lines, body)
}
body := strings.Join(lines, "\n") + "\n" // Don't forget trailing \n
body := strings.Join(lines, "\n") + "\n" // add trailing \n

// Get response
res, err := s.client.PerformRequest(ctx, "GET", path, params, body)
Expand All @@ -91,6 +105,7 @@ func (s *MultiSearchService) Do(ctx context.Context) (*MultiSearchResult, error)
return ret, nil
}

// MultiSearchResult is the outcome of running a multi-search operation.
type MultiSearchResult struct {
Responses []*SearchResult `json:"responses,omitempty"`
}
103 changes: 103 additions & 0 deletions msearch_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,109 @@ func TestMultiSearch(t *testing.T) {
}
}

func TestMultiSearchWithStrings(t *testing.T) {
client := setupTestClientAndCreateIndex(t)
// client := setupTestClientAndCreateIndexAndLog(t)

tweet1 := tweet{
User: "olivere",
Message: "Welcome to Golang and Elasticsearch.",
Tags: []string{"golang", "elasticsearch"},
}
tweet2 := tweet{
User: "olivere",
Message: "Another unrelated topic.",
Tags: []string{"golang"},
}
tweet3 := tweet{
User: "sandrae",
Message: "Cycling is fun.",
Tags: []string{"sports", "cycling"},
}

// Add all documents
_, err := client.Index().Index(testIndexName).Type("tweet").Id("1").BodyJson(&tweet1).Do(context.TODO())
if err != nil {
t.Fatal(err)
}

_, err = client.Index().Index(testIndexName).Type("tweet").Id("2").BodyJson(&tweet2).Do(context.TODO())
if err != nil {
t.Fatal(err)
}

_, err = client.Index().Index(testIndexName).Type("tweet").Id("3").BodyJson(&tweet3).Do(context.TODO())
if err != nil {
t.Fatal(err)
}

_, err = client.Flush().Index(testIndexName).Do(context.TODO())
if err != nil {
t.Fatal(err)
}

// Spawn two search queries with one roundtrip
sreq1 := NewSearchRequest().Index(testIndexName, testIndexName2).
Source(`{"query":{"match_all":{}}}`)
sreq2 := NewSearchRequest().Index(testIndexName).Type("tweet").
Source(`{"query":{"term":{"tags":"golang"}}}`)

searchResult, err := client.MultiSearch().
Add(sreq1, sreq2).
Do(context.TODO())
if err != nil {
t.Fatal(err)
}
if searchResult.Responses == nil {
t.Fatal("expected responses != nil; got nil")
}
if len(searchResult.Responses) != 2 {
t.Fatalf("expected 2 responses; got %d", len(searchResult.Responses))
}

sres := searchResult.Responses[0]
if sres.Hits == nil {
t.Errorf("expected Hits != nil; got nil")
}
if sres.Hits.TotalHits != 3 {
t.Errorf("expected Hits.TotalHits = %d; got %d", 3, sres.Hits.TotalHits)
}
if len(sres.Hits.Hits) != 3 {
t.Errorf("expected len(Hits.Hits) = %d; got %d", 3, len(sres.Hits.Hits))
}
for _, hit := range sres.Hits.Hits {
if hit.Index != testIndexName {
t.Errorf("expected Hits.Hit.Index = %q; got %q", testIndexName, hit.Index)
}
item := make(map[string]interface{})
err := json.Unmarshal(*hit.Source, &item)
if err != nil {
t.Fatal(err)
}
}

sres = searchResult.Responses[1]
if sres.Hits == nil {
t.Errorf("expected Hits != nil; got nil")
}
if sres.Hits.TotalHits != 2 {
t.Errorf("expected Hits.TotalHits = %d; got %d", 2, sres.Hits.TotalHits)
}
if len(sres.Hits.Hits) != 2 {
t.Errorf("expected len(Hits.Hits) = %d; got %d", 2, len(sres.Hits.Hits))
}
for _, hit := range sres.Hits.Hits {
if hit.Index != testIndexName {
t.Errorf("expected Hits.Hit.Index = %q; got %q", testIndexName, hit.Index)
}
item := make(map[string]interface{})
err := json.Unmarshal(*hit.Source, &item)
if err != nil {
t.Fatal(err)
}
}
}

func TestMultiSearchWithOneRequest(t *testing.T) {
client := setupTestClientAndCreateIndex(t)

Expand Down
49 changes: 35 additions & 14 deletions search_request.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@

package elastic

import "strings"
import (
"encoding/json"
"strings"
)

// SearchRequest combines a search request and its
// query details (see SearchSource).
Expand Down Expand Up @@ -130,17 +133,7 @@ func (r *SearchRequest) SearchSource(searchSource *SearchSource) *SearchRequest
}

func (r *SearchRequest) Source(source interface{}) *SearchRequest {
switch v := source.(type) {
case *SearchSource:
src, err := v.Source()
if err != nil {
// Do not do anything in case of an error
return r
}
r.source = src
default:
r.source = source
}
r.source = source
return r
}

Expand Down Expand Up @@ -200,6 +193,34 @@ func (r *SearchRequest) header() interface{} {
// Body is used e.g. by MultiSearch to get information about the search body
// of one SearchRequest.
// See https://www.elastic.co/guide/en/elasticsearch/reference/5.6/search-multi-search.html
func (r *SearchRequest) Body() interface{} {
return r.source
func (r *SearchRequest) Body() (string, error) {
switch t := r.source.(type) {
default:
body, err := json.Marshal(r.source)
if err != nil {
return "", err
}
return string(body), nil
case *SearchSource:
src, err := t.Source()
if err != nil {
return "", err
}
body, err := json.Marshal(src)
if err != nil {
return "", err
}
return string(body), nil
case json.RawMessage:
return string(t), nil
case *json.RawMessage:
return string(*t), nil
case string:
return t, nil
case *string:
if t != nil {
return *t, nil
}
return "{}", nil
}
}
55 changes: 55 additions & 0 deletions search_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -607,6 +607,61 @@ func TestSearchSource(t *testing.T) {
}
}

func TestSearchSourceWithString(t *testing.T) {
client := setupTestClientAndCreateIndex(t)

tweet1 := tweet{
User: "olivere", Retweets: 108,
Message: "Welcome to Golang and Elasticsearch.",
Created: time.Date(2012, 12, 12, 17, 38, 34, 0, time.UTC),
}
tweet2 := tweet{
User: "olivere", Retweets: 0,
Message: "Another unrelated topic.",
Created: time.Date(2012, 10, 10, 8, 12, 03, 0, time.UTC),
}
tweet3 := tweet{
User: "sandrae", Retweets: 12,
Message: "Cycling is fun.",
Created: time.Date(2011, 11, 11, 10, 58, 12, 0, time.UTC),
}

// Add all documents
_, err := client.Index().Index(testIndexName).Type("tweet").Id("1").BodyJson(&tweet1).Do(context.TODO())
if err != nil {
t.Fatal(err)
}

_, err = client.Index().Index(testIndexName).Type("tweet").Id("2").BodyJson(&tweet2).Do(context.TODO())
if err != nil {
t.Fatal(err)
}

_, err = client.Index().Index(testIndexName).Type("tweet").Id("3").BodyJson(&tweet3).Do(context.TODO())
if err != nil {
t.Fatal(err)
}

_, err = client.Flush().Index(testIndexName).Do(context.TODO())
if err != nil {
t.Fatal(err)
}

searchResult, err := client.Search().
Index(testIndexName).
Source(`{"query":{"match_all":{}}}`). // sets the JSON request
Do(context.TODO())
if err != nil {
t.Fatal(err)
}
if searchResult.Hits == nil {
t.Errorf("expected SearchResult.Hits != nil; got nil")
}
if searchResult.Hits.TotalHits != 3 {
t.Errorf("expected SearchResult.Hits.TotalHits = %d; got %d", 3, searchResult.Hits.TotalHits)
}
}

func TestSearchRawString(t *testing.T) {
// client := setupTestClientAndCreateIndexAndLog(t, SetTraceLog(log.New(os.Stdout, "", 0)))
client := setupTestClientAndCreateIndex(t)
Expand Down

0 comments on commit 8e15c58

Please sign in to comment.