diff --git a/index/postings.go b/index/postings.go index 1f2ee0d6..13df1c69 100644 --- a/index/postings.go +++ b/index/postings.go @@ -366,80 +366,25 @@ func Merge(its ...Postings) Postings { if len(its) == 1 { return its[0] } - l := len(its) / 2 - return newMergedPostings(Merge(its[:l]...), Merge(its[l:]...)) -} - -type mergedPostings struct { - a, b Postings - initialized bool - aok, bok bool - cur uint64 -} - -func newMergedPostings(a, b Postings) *mergedPostings { - return &mergedPostings{a: a, b: b} -} - -func (it *mergedPostings) At() uint64 { - return it.cur -} - -func (it *mergedPostings) Next() bool { - if !it.initialized { - it.aok = it.a.Next() - it.bok = it.b.Next() - it.initialized = true - } - - if !it.aok && !it.bok { - return false - } - - if !it.aok { - it.cur = it.b.At() - it.bok = it.b.Next() - return true - } - if !it.bok { - it.cur = it.a.At() - it.aok = it.a.Next() - return true - } - - acur, bcur := it.a.At(), it.b.At() - - if acur < bcur { - it.cur = acur - it.aok = it.a.Next() - } else if acur > bcur { - it.cur = bcur - it.bok = it.b.Next() - } else { - it.cur = acur - it.aok = it.a.Next() - it.bok = it.b.Next() - } - return true -} - -func (it *mergedPostings) Seek(id uint64) bool { - if it.cur >= id { - return true + // All the uses of this function immediately expand it, so + // collect everything in a map. This is more efficient + // when there's 100ks of postings, compared to + // having a tree of merge objects. + pm := make(map[uint64]struct{}, len(its)) + for _, it := range its { + for it.Next() { + pm[it.At()] = struct{}{} + } + if it.Err() != nil { + return ErrPostings(it.Err()) + } } - - it.aok = it.a.Seek(id) - it.bok = it.b.Seek(id) - it.initialized = true - - return it.Next() -} - -func (it *mergedPostings) Err() error { - if it.a.Err() != nil { - return it.a.Err() + pl := make([]uint64, 0, len(pm)) + for p := range pm { + pl = append(pl, p) } - return it.b.Err() + sort.Slice(pl, func(i, j int) bool { return pl[i] < pl[j] }) + return newListPostings(pl) } // Without returns a new postings list that contains all elements from the full list that diff --git a/index/postings_test.go b/index/postings_test.go index 53a9d95f..54c37f48 100644 --- a/index/postings_test.go +++ b/index/postings_test.go @@ -233,7 +233,7 @@ func TestMergedPostings(t *testing.T) { a := newListPostings(c.a) b := newListPostings(c.b) - res, err := ExpandPostings(newMergedPostings(a, b)) + res, err := ExpandPostings(Merge(a, b)) testutil.Ok(t, err) testutil.Equals(t, c.res, res) } @@ -286,7 +286,7 @@ func TestMergedPostingsSeek(t *testing.T) { a := newListPostings(c.a) b := newListPostings(c.b) - p := newMergedPostings(a, b) + p := Merge(a, b) testutil.Equals(t, c.success, p.Seek(c.seek)) @@ -546,7 +546,7 @@ func TestIntersectWithMerge(t *testing.T) { // https://github.com/prometheus/prometheus/issues/2616 a := newListPostings([]uint64{21, 22, 23, 24, 25, 30}) - b := newMergedPostings( + b := Merge( newListPostings([]uint64{10, 20, 30}), newListPostings([]uint64{15, 26, 30}), )