Skip to content

Commit

Permalink
Save memory when parent and child are not on top (#57892) (#57944)
Browse files Browse the repository at this point in the history
Reworks the `parent` and `child` aggregation are not at the top level
using the optimization from #55873. Instead of wrapping all
non-top-level `parent` and `child` aggregators we now handle being a
child aggregator in the aggregator, specifically by adding recording
which global ordinals show up in the parent and then checking if they
match the child.
  • Loading branch information
nik9000 authored Jun 10, 2020
1 parent 9eb8085 commit 0a2bd10
Show file tree
Hide file tree
Showing 7 changed files with 111 additions and 46 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -79,12 +79,8 @@ protected Aggregator doCreateInternal(ValuesSource rawValuesSource,
}
WithOrdinals valuesSource = (WithOrdinals) rawValuesSource;
long maxOrd = valuesSource.globalMaxOrd(searchContext.searcher());
if (collectsFromSingleBucket) {
return new ParentToChildrenAggregator(name, factories, searchContext, parent, childFilter,
parentFilter, valuesSource, maxOrd, metadata);
} else {
return asMultiBucketAggregator(this, searchContext, parent);
}
return new ParentToChildrenAggregator(name, factories, searchContext, parent, childFilter,
parentFilter, valuesSource, maxOrd, collectsFromSingleBucket, metadata);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ public class ChildrenToParentAggregator extends ParentJoinAggregator {
public ChildrenToParentAggregator(String name, AggregatorFactories factories,
SearchContext context, Aggregator parent, Query childFilter,
Query parentFilter, ValuesSource.Bytes.WithOrdinals valuesSource,
long maxOrd, Map<String, Object> metadata) throws IOException {
super(name, factories, context, parent, childFilter, parentFilter, valuesSource, maxOrd, metadata);
long maxOrd, boolean collectsFromSingleBucket, Map<String, Object> metadata) throws IOException {
super(name, factories, context, parent, childFilter, parentFilter, valuesSource, maxOrd, collectsFromSingleBucket, metadata);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,12 +80,8 @@ protected Aggregator doCreateInternal(ValuesSource rawValuesSource,
}
WithOrdinals valuesSource = (WithOrdinals) rawValuesSource;
long maxOrd = valuesSource.globalMaxOrd(searchContext.searcher());
if (collectsFromSingleBucket) {
return new ChildrenToParentAggregator(name, factories, searchContext, children, childFilter,
parentFilter, valuesSource, maxOrd, metadata);
} else {
return asMultiBucketAggregator(this, searchContext, children);
}
return new ChildrenToParentAggregator(name, factories, searchContext, children, childFilter,
parentFilter, valuesSource, maxOrd, collectsFromSingleBucket, metadata);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,12 @@
import org.elasticsearch.common.lucene.Lucene;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.common.util.BitArray;
import org.elasticsearch.common.util.LongHash;
import org.elasticsearch.search.aggregations.Aggregator;
import org.elasticsearch.search.aggregations.AggregatorFactories;
import org.elasticsearch.search.aggregations.LeafBucketCollector;
import org.elasticsearch.search.aggregations.bucket.BucketsAggregator;
import org.elasticsearch.search.aggregations.bucket.SingleBucketAggregator;
import org.elasticsearch.search.aggregations.bucket.terms.LongKeyedBucketOrds;
import org.elasticsearch.search.aggregations.support.ValuesSource;
import org.elasticsearch.search.internal.SearchContext;

Expand Down Expand Up @@ -68,6 +68,7 @@ public ParentJoinAggregator(String name,
Query outFilter,
ValuesSource.Bytes.WithOrdinals valuesSource,
long maxOrd,
boolean collectsFromSingleBucket,
Map<String, Object> metadata) throws IOException {
super(name, factories, context, parent, metadata);

Expand All @@ -81,8 +82,9 @@ public ParentJoinAggregator(String name,
this.outFilter = context.searcher().createWeight(context.searcher().rewrite(outFilter), ScoreMode.COMPLETE_NO_SCORES, 1f);
this.valuesSource = valuesSource;
boolean singleAggregator = parent == null;
collectionStrategy = singleAggregator ?
new DenseCollectionStrategy(maxOrd, context.bigArrays()) : new SparseCollectionStrategy(context.bigArrays());
collectionStrategy = singleAggregator && collectsFromSingleBucket
? new DenseCollectionStrategy(maxOrd, context.bigArrays())
: new SparseCollectionStrategy(context.bigArrays(), collectsFromSingleBucket);
}

@Override
Expand All @@ -95,19 +97,18 @@ public final LeafBucketCollector getLeafCollector(LeafReaderContext ctx,
final Bits parentDocs = Lucene.asSequentialAccessBits(ctx.reader().maxDoc(), inFilter.scorerSupplier(ctx));
return new LeafBucketCollector() {
@Override
public void collect(int docId, long bucket) throws IOException {
assert bucket == 0;
public void collect(int docId, long owningBucketOrd) throws IOException {
if (parentDocs.get(docId) && globalOrdinals.advanceExact(docId)) {
int globalOrdinal = (int) globalOrdinals.nextOrd();
assert globalOrdinal != -1 && globalOrdinals.nextOrd() == SortedSetDocValues.NO_MORE_ORDS;
collectionStrategy.addGlobalOrdinal(globalOrdinal);
collectionStrategy.add(owningBucketOrd, globalOrdinal);
}
}
};
}

@Override
protected final void doPostCollection() throws IOException {
protected void beforeBuildingBuckets(long[] ordsToCollect) throws IOException {
IndexReader indexReader = context().searcher().getIndexReader();
for (LeafReaderContext ctx : indexReader.leaves()) {
Scorer childDocsScorer = outFilter.scorer(ctx);
Expand Down Expand Up @@ -137,11 +138,21 @@ public int docID() {
if (liveDocs != null && liveDocs.get(docId) == false) {
continue;
}
if (globalOrdinals.advanceExact(docId)) {
int globalOrdinal = (int) globalOrdinals.nextOrd();
assert globalOrdinal != -1 && globalOrdinals.nextOrd() == SortedSetDocValues.NO_MORE_ORDS;
if (collectionStrategy.existsGlobalOrdinal(globalOrdinal)) {
collectBucket(sub, docId, 0);
if (false == globalOrdinals.advanceExact(docId)) {
continue;
}
int globalOrdinal = (int) globalOrdinals.nextOrd();
assert globalOrdinal != -1 && globalOrdinals.nextOrd() == SortedSetDocValues.NO_MORE_ORDS;
/*
* Check if we contain every ordinal. It's almost certainly be
* faster to replay all the matching ordinals and filter them down
* to just those listed in ordsToCollect, but we don't have a data
* structure that maps a primitive long to a list of primitive
* longs.
*/
for (long owningBucketOrd: ordsToCollect) {
if (collectionStrategy.exists(owningBucketOrd, globalOrdinal)) {
collectBucket(sub, docId, owningBucketOrd);
}
}
}
Expand All @@ -160,8 +171,8 @@ protected void doClose() {
* {@code ParentJoinAggregator#outFilter} also have the ordinal.
*/
protected interface CollectionStrategy extends Releasable {
void addGlobalOrdinal(int globalOrdinal);
boolean existsGlobalOrdinal(int globalOrdinal);
void add(long owningBucketOrd, int globalOrdinal);
boolean exists(long owningBucketOrd, int globalOrdinal);
}

/**
Expand All @@ -178,12 +189,14 @@ public DenseCollectionStrategy(long maxOrd, BigArrays bigArrays) {
}

@Override
public void addGlobalOrdinal(int globalOrdinal) {
public void add(long owningBucketOrd, int globalOrdinal) {
assert owningBucketOrd == 0;
ordsBits.set(globalOrdinal);
}

@Override
public boolean existsGlobalOrdinal(int globalOrdinal) {
public boolean exists(long owningBucketOrd, int globalOrdinal) {
assert owningBucketOrd == 0;
return ordsBits.get(globalOrdinal);
}

Expand All @@ -200,20 +213,20 @@ public void close() {
* when only some docs might match.
*/
protected class SparseCollectionStrategy implements CollectionStrategy {
private final LongHash ordsHash;
private final LongKeyedBucketOrds ordsHash;

public SparseCollectionStrategy(BigArrays bigArrays) {
ordsHash = new LongHash(1, bigArrays);
public SparseCollectionStrategy(BigArrays bigArrays, boolean collectsFromSingleBucket) {
ordsHash = LongKeyedBucketOrds.build(bigArrays, collectsFromSingleBucket);
}

@Override
public void addGlobalOrdinal(int globalOrdinal) {
ordsHash.add(globalOrdinal);
public void add(long owningBucketOrd, int globalOrdinal) {
ordsHash.add(owningBucketOrd, globalOrdinal);
}

@Override
public boolean existsGlobalOrdinal(int globalOrdinal) {
return ordsHash.find(globalOrdinal) >= 0;
public boolean exists(long owningBucketOrd, int globalOrdinal) {
return ordsHash.find(owningBucketOrd, globalOrdinal) >= 0;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ public class ParentToChildrenAggregator extends ParentJoinAggregator {
public ParentToChildrenAggregator(String name, AggregatorFactories factories,
SearchContext context, Aggregator parent, Query childFilter,
Query parentFilter, ValuesSource.Bytes.WithOrdinals valuesSource,
long maxOrd, Map<String, Object> metadata) throws IOException {
super(name, factories, context, parent, parentFilter, childFilter, valuesSource, maxOrd, metadata);
long maxOrd, boolean collectsFromSingleBucket, Map<String, Object> metadata) throws IOException {
super(name, factories, context, parent, parentFilter, childFilter, valuesSource, maxOrd, collectsFromSingleBucket, metadata);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@
import org.elasticsearch.search.aggregations.bucket.terms.TermsAggregationBuilder;
import org.elasticsearch.search.aggregations.metrics.InternalMin;
import org.elasticsearch.search.aggregations.metrics.MinAggregationBuilder;
import org.elasticsearch.search.aggregations.support.ValueType;

import java.io.IOException;
import java.util.ArrayList;
Expand Down Expand Up @@ -313,8 +312,7 @@ private void testCaseTerms(Query query, IndexSearcher indexSearcher, Consumer<In
throws IOException {

ParentAggregationBuilder aggregationBuilder = new ParentAggregationBuilder("_name", CHILD_TYPE);
aggregationBuilder.subAggregation(new TermsAggregationBuilder("value_terms").userValueTypeHint(ValueType.LONG)
.field("number"));
aggregationBuilder.subAggregation(new TermsAggregationBuilder("value_terms").field("number"));

MappedFieldType fieldType = new NumberFieldMapper.NumberFieldType(NumberFieldMapper.NumberType.LONG);
fieldType.setName("number");
Expand All @@ -326,9 +324,9 @@ private void testCaseTerms(Query query, IndexSearcher indexSearcher, Consumer<In
private void testCaseTermsParentTerms(Query query, IndexSearcher indexSearcher, Consumer<LongTerms> verify)
throws IOException {
AggregationBuilder aggregationBuilder =
new TermsAggregationBuilder("subvalue_terms").userValueTypeHint(ValueType.LONG).field("subNumber").
new TermsAggregationBuilder("subvalue_terms").field("subNumber").
subAggregation(new ParentAggregationBuilder("to_parent", CHILD_TYPE).
subAggregation(new TermsAggregationBuilder("value_terms").userValueTypeHint(ValueType.LONG).field("number")));
subAggregation(new TermsAggregationBuilder("value_terms").field("number")));

MappedFieldType fieldType = new NumberFieldMapper.NumberFieldType(NumberFieldMapper.NumberType.LONG);
fieldType.setName("number");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.apache.lucene.document.Field;
import org.apache.lucene.document.SortedDocValuesField;
import org.apache.lucene.document.SortedNumericDocValuesField;
import org.apache.lucene.document.SortedSetDocValuesField;
import org.apache.lucene.document.StringField;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexReader;
Expand Down Expand Up @@ -52,7 +53,10 @@
import org.elasticsearch.join.mapper.MetaJoinFieldMapper;
import org.elasticsearch.join.mapper.ParentJoinFieldMapper;
import org.elasticsearch.plugins.SearchPlugin;
import org.elasticsearch.search.aggregations.AggregationBuilder;
import org.elasticsearch.search.aggregations.AggregatorTestCase;
import org.elasticsearch.search.aggregations.bucket.terms.StringTerms;
import org.elasticsearch.search.aggregations.bucket.terms.TermsAggregationBuilder;
import org.elasticsearch.search.aggregations.metrics.InternalMin;
import org.elasticsearch.search.aggregations.metrics.MinAggregationBuilder;

Expand All @@ -64,6 +68,7 @@
import java.util.Map;
import java.util.function.Consumer;

import static org.hamcrest.Matchers.equalTo;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

Expand Down Expand Up @@ -124,12 +129,68 @@ public void testParentChild() throws IOException {
directory.close();
}

public void testParentChildAsSubAgg() throws IOException {
try (Directory directory = newDirectory()) {
RandomIndexWriter indexWriter = new RandomIndexWriter(random(), directory);

final Map<String, Tuple<Integer, Integer>> expectedParentChildRelations = setupIndex(indexWriter);
indexWriter.close();

try (
IndexReader indexReader = ElasticsearchDirectoryReader.wrap(
DirectoryReader.open(directory),
new ShardId(new Index("foo", "_na_"), 1)
)
) {
IndexSearcher indexSearcher = newSearcher(indexReader, false, true);

AggregationBuilder request = new TermsAggregationBuilder("t").field("kwd")
.subAggregation(
new ChildrenAggregationBuilder("children", CHILD_TYPE).subAggregation(
new MinAggregationBuilder("min").field("number")
)
);

long expectedEvenChildCount = 0;
double expectedEvenMin = Double.MAX_VALUE;
long expectedOddChildCount = 0;
double expectedOddMin = Double.MAX_VALUE;
for (Map.Entry<String, Tuple<Integer, Integer>> e : expectedParentChildRelations.entrySet()) {
if (Integer.valueOf(e.getKey().substring("parent".length())) % 2 == 0) {
expectedEvenChildCount += e.getValue().v1();
expectedEvenMin = Math.min(expectedEvenMin, e.getValue().v2());
} else {
expectedOddChildCount += e.getValue().v1();
expectedOddMin = Math.min(expectedOddMin, e.getValue().v2());
}
}
StringTerms result = search(indexSearcher, new MatchAllDocsQuery(), request, longField("number"), keywordField("kwd"));

StringTerms.Bucket evenBucket = result.getBucketByKey("even");
InternalChildren evenChildren = evenBucket.getAggregations().get("children");
InternalMin evenMin = evenChildren.getAggregations().get("min");
assertThat(evenChildren.getDocCount(), equalTo(expectedEvenChildCount));
assertThat(evenMin.getValue(), equalTo(expectedEvenMin));

if (expectedOddChildCount > 0) {
StringTerms.Bucket oddBucket = result.getBucketByKey("odd");
InternalChildren oddChildren = oddBucket.getAggregations().get("children");
InternalMin oddMin = oddChildren.getAggregations().get("min");
assertThat(oddChildren.getDocCount(), equalTo(expectedOddChildCount));
assertThat(oddMin.getValue(), equalTo(expectedOddMin));
} else {
assertNull(result.getBucketByKey("odd"));
}
}
}
}

private static Map<String, Tuple<Integer, Integer>> setupIndex(RandomIndexWriter iw) throws IOException {
Map<String, Tuple<Integer, Integer>> expectedValues = new HashMap<>();
int numParents = randomIntBetween(1, 10);
for (int i = 0; i < numParents; i++) {
String parent = "parent" + i;
iw.addDocument(createParentDocument(parent));
iw.addDocument(createParentDocument(parent, i % 2 == 0 ? "even" : "odd"));
int numChildren = randomIntBetween(1, 10);
int minValue = Integer.MAX_VALUE;
for (int c = 0; c < numChildren; c++) {
Expand All @@ -142,9 +203,10 @@ private static Map<String, Tuple<Integer, Integer>> setupIndex(RandomIndexWriter
return expectedValues;
}

private static List<Field> createParentDocument(String id) {
private static List<Field> createParentDocument(String id, String kwd) {
return Arrays.asList(
new StringField(IdFieldMapper.NAME, Uid.encodeId(id), Field.Store.NO),
new SortedSetDocValuesField("kwd", new BytesRef(kwd)),
new StringField("join_field", PARENT_TYPE, Field.Store.NO),
createJoinField(PARENT_TYPE, id)
);
Expand Down

0 comments on commit 0a2bd10

Please sign in to comment.