Optimize 2 keyword multi-terms aggregation
sandeshkr419 committed Jun 3, 2024
1 parent a22c1fd commit ae34048
Showing 3 changed files with 166 additions and 1 deletion.
@@ -146,6 +146,7 @@ protected Aggregator createInternal(
configs.stream()
.map(config -> queryShardContext.getValuesSourceRegistry().getAggregator(REGISTRY_KEY, config.v1()).build(config))
.collect(Collectors.toList()),
configs.stream().map(config -> config.v1().getValuesSource()).collect(Collectors.toList()),
configs.stream().map(c -> c.v1().format()).collect(Collectors.toList()),
order,
collectMode,
@@ -8,8 +8,15 @@

package org.opensearch.search.aggregations.bucket.terms;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.PostingsEnum;
import org.apache.lucene.index.SortedNumericDocValues;
import org.apache.lucene.index.Terms;
import org.apache.lucene.index.TermsEnum;
import org.apache.lucene.search.CollectionTerminatedException;
import org.apache.lucene.search.Weight;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.NumericUtils;
import org.apache.lucene.util.PriorityQueue;
@@ -62,19 +69,23 @@ public class MultiTermsAggregator extends DeferableBucketAggregator {

private final BytesKeyedBucketOrds bucketOrds;
private final MultiTermsValuesSource multiTermsValue;
private final List<ValuesSource> valuesSources;
private final boolean showTermDocCountError;
private final List<DocValueFormat> formats;
private final TermsAggregator.BucketCountThresholds bucketCountThresholds;
private final BucketOrder order;
private final Comparator<InternalMultiTerms.Bucket> partiallyBuiltBucketComparator;
private final SubAggCollectionMode collectMode;
private final Set<Aggregator> aggsUsedForSorting = new HashSet<>();
private Weight weight;
private static final Logger logger = LogManager.getLogger(MultiTermsAggregator.class);

public MultiTermsAggregator(
String name,
AggregatorFactories factories,
boolean showTermDocCountError,
List<InternalValuesSource> internalValuesSources,
List<ValuesSource> valuesSources,
List<DocValueFormat> formats,
BucketOrder order,
SubAggCollectionMode collectMode,
@@ -87,6 +98,7 @@ public MultiTermsAggregator(
super(name, factories, context, parent, metadata);
this.bucketOrds = BytesKeyedBucketOrds.build(context.bigArrays(), cardinality);
this.multiTermsValue = new MultiTermsValuesSource(internalValuesSources);
this.valuesSources = valuesSources;
this.showTermDocCountError = showTermDocCountError;
this.formats = formats;
this.bucketCountThresholds = bucketCountThresholds;
@@ -173,6 +185,10 @@ public InternalAggregation[] buildAggregations(long[] owningBucketOrds) throws I
return result;
}

public void setWeight(Weight weight) {
this.weight = weight;
}

InternalMultiTerms buildResult(long owningBucketOrd, long otherDocCount, InternalMultiTerms.Bucket[] topBuckets) {
BucketOrder reduceOrder;
if (isKeyOrder(order) == false) {
@@ -213,8 +229,116 @@ public InternalAggregation buildEmptyAggregation() {
);
}

private LeafBucketCollector getTermFrequencies(LeafReaderContext ctx) throws IOException {
// Instead of visiting doc values for each document, use the postings data directly to compute each composite bucket.
// For example, if we have a composite key of (a, b) where a is from field1 and b is from field2,
// we can find all the composite buckets by walking both posting lists
// and counting the documents that intersect for each composite bucket.
// This is much faster than visiting the doc values for every document.
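// Concretely: if field1 has term "a" in docs {0, 1, 2} and field2 has term "n" in docs {0, 1},
// the composite bucket (a, n) gets a doc count of 2 from the postings intersection below.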

if (weight == null || weight.count(ctx) != ctx.reader().maxDoc()) {
// Weight not assigned - cannot use this optimization
// weight.count(ctx) == ctx.reader().maxDoc() implies there are no deleted documents and
// top-level query matches all docs in the segment
return null;
}

String field1, field2;
// Restrict this optimization to exactly 2 keyword fields backed by FieldData
if (this.valuesSources.size() == 2 &&
this.valuesSources.get(0) instanceof ValuesSource.Bytes.WithOrdinals.FieldData &&
this.valuesSources.get(1) instanceof ValuesSource.Bytes.WithOrdinals.FieldData) {
field1 = ((ValuesSource.Bytes.WithOrdinals.FieldData) valuesSources.get(0)).getIndexFieldName();
field2 = ((ValuesSource.Bytes.WithOrdinals.FieldData) valuesSources.get(1)).getIndexFieldName();
} else {
return null;
}

Terms segmentTerms1 = ctx.reader().terms(field1);
Terms segmentTerms2 = ctx.reader().terms(field2);
if (segmentTerms1 == null || segmentTerms2 == null) {
// Either field may have no indexed terms in this segment - fall back to the default collector
return null;
}

// TODO (in upcoming commits of this PR):
// 1. Add a check on field cardinality - this optimization may be ineffective for very high-cardinality fields
// 2. Check whether a filter is applied, since the default implementation may resolve it as part of the aggregation

TermsEnum segmentTermsEnum1 = segmentTerms1.iterator();

while (segmentTermsEnum1.next() != null) {
TermsEnum segmentTermsEnum2 = segmentTerms2.iterator();

while (segmentTermsEnum2.next() != null) {

PostingsEnum postings1 = segmentTermsEnum1.postings(null);
postings1.nextDoc();

PostingsEnum postings2 = segmentTermsEnum2.postings(null);
postings2.nextDoc();

int bucketCount = 0;

while (postings1.docID() != PostingsEnum.NO_MORE_DOCS &&
postings2.docID() != PostingsEnum.NO_MORE_DOCS) {

// Count intersecting docs to get the number of docs in each composite bucket
if (postings1.docID() == postings2.docID()) {
bucketCount++;
postings1.nextDoc();
postings2.nextDoc();
} else if (postings1.docID() < postings2.docID()) {
postings1.advance(postings2.docID());
} else {
postings2.advance(postings1.docID());
}
}

// For each pair formed by a value of field1 and a value of field2, create a composite key,
// encode it as a BytesRef, and record the intersection count computed above as that bucket's doc count.
// buildAggregations() later turns these bucket ordinals into buckets and their sub-aggregations.
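// For example, for field1 value "a" and field2 value "n", the encoded key is (field count = 2, "a", "n").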
BytesRef v1 = segmentTermsEnum1.term();
BytesRef v2 = segmentTermsEnum2.term();


TermValue<BytesRef> termValue1 = new TermValue<>(v1, TermValue.BYTES_REF_WRITER);
TermValue<BytesRef> termValue2 = new TermValue<>(v2, TermValue.BYTES_REF_WRITER);


final BytesStreamOutput scratch = new BytesStreamOutput();
scratch.writeVInt(2); // number of fields per composite key
termValue1.writeTo(scratch);
termValue2.writeTo(scratch);
BytesRef compositeKeyBytesRef = scratch.bytes().toBytesRef(); // composite key formed
scratch.close();

// Skip term pairs that never co-occur, so no empty buckets are created
if (bucketCount > 0) {
long bucketOrd = bucketOrds.add(0, compositeKeyBytesRef);
if (bucketOrd < 0) {
// a negative value means this composite key already has an ordinal
bucketOrd = -1 - bucketOrd;
}
incrementBucketDocCount(bucketOrd, bucketCount);
}
}
}


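// All doc counts for this leaf were derived from the postings above, so return a collector that
// throws CollectionTerminatedException to tell Lucene to skip per-document collection for this segment.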
return new LeafBucketCollector() {
@Override
public void collect(int doc, long owningBucketOrd) throws IOException {
throw new CollectionTerminatedException();
}
};

}

@Override
protected LeafBucketCollector getLeafCollector(LeafReaderContext ctx, LeafBucketCollector sub) throws IOException {

LeafBucketCollector optimizedCollector = this.getTermFrequencies(ctx);

if (optimizedCollector != null) {
logger.debug("Multi-terms postings optimization used");
return optimizedCollector;
}

logger.debug("Multi-terms postings optimization not used");

MultiTermsValuesSourceCollector collector = multiTermsValue.getValues(ctx);
return new LeafBucketCollector() {
@Override
@@ -256,7 +380,7 @@ private boolean subAggsNeedScore() {

@Override
protected boolean shouldDefer(Aggregator aggregator) {
return collectMode == Aggregator.SubAggCollectionMode.BREADTH_FIRST && !aggsUsedForSorting.contains(aggregator);
return collectMode == SubAggCollectionMode.BREADTH_FIRST && !aggsUsedForSorting.contains(aggregator);
}

private void collectZeroDocEntriesIfNeeded(long owningBucketOrd) throws IOException {
@@ -60,6 +60,7 @@
import org.opensearch.search.aggregations.support.CoreValuesSourceType;
import org.opensearch.search.aggregations.support.MultiTermsValuesSourceConfig;
import org.opensearch.search.aggregations.support.ValueType;
import org.opensearch.search.aggregations.support.ValuesSource;
import org.opensearch.search.aggregations.support.ValuesSourceType;
import org.opensearch.search.internal.SearchContext;
import org.opensearch.search.lookup.LeafDocLookup;
@@ -102,6 +103,7 @@ public class MultiTermsAggregatorTests extends AggregatorTestCase {
private static final String FLOAT_FIELD = "float";
private static final String DOUBLE_FIELD = "double";
private static final String KEYWORD_FIELD = "keyword";
private static final String KEYWORD_FIELD2 = "keyword2";
private static final String DATE_FIELD = "date";
private static final String IP_FIELD = "ip";
private static final String GEO_POINT_FIELD = "geopoint";
@@ -116,6 +118,7 @@ public class MultiTermsAggregatorTests extends AggregatorTestCase {
put(DOUBLE_FIELD, new NumberFieldMapper.NumberFieldType(DOUBLE_FIELD, NumberFieldMapper.NumberType.DOUBLE));
put(DATE_FIELD, dateFieldType(DATE_FIELD));
put(KEYWORD_FIELD, new KeywordFieldMapper.KeywordFieldType(KEYWORD_FIELD));
put(KEYWORD_FIELD2, new KeywordFieldMapper.KeywordFieldType(KEYWORD_FIELD2));
put(IP_FIELD, new IpFieldMapper.IpFieldType(IP_FIELD));
put(FIELD_NAME, new NumberFieldMapper.NumberFieldType(FIELD_NAME, NumberFieldMapper.NumberType.INTEGER));
put(UNRELATED_KEYWORD_FIELD, new KeywordFieldMapper.KeywordFieldType(UNRELATED_KEYWORD_FIELD));
@@ -306,6 +309,41 @@ public void testMixNumberAndKeyword() throws IOException {
});
}

public void testKeywordAndKeywordField() throws IOException {
testAggregation(new MatchAllDocsQuery(), fieldConfigs(asList(KEYWORD_FIELD, KEYWORD_FIELD2)), NONE_DECORATOR, iw -> {
iw.addDocument(
asList(
new SortedSetDocValuesField(KEYWORD_FIELD, new BytesRef("a")),
new StringField(KEYWORD_FIELD, new BytesRef("a"), Field.Store.NO),
new SortedSetDocValuesField(KEYWORD_FIELD2, new BytesRef("n")),
new StringField(KEYWORD_FIELD2, new BytesRef("n"), Field.Store.NO)
)
);
iw.addDocument(
asList(
new SortedSetDocValuesField(KEYWORD_FIELD, new BytesRef("a")),
new StringField(KEYWORD_FIELD, new BytesRef("a"), Field.Store.NO),
new SortedSetDocValuesField(KEYWORD_FIELD2, new BytesRef("n")),
new StringField(KEYWORD_FIELD2, new BytesRef("n"), Field.Store.NO)
)
);
iw.addDocument(
asList(
new SortedSetDocValuesField(KEYWORD_FIELD, new BytesRef("a")),
new StringField(KEYWORD_FIELD, new BytesRef("a"), Field.Store.NO),
new SortedSetDocValuesField(KEYWORD_FIELD2, new BytesRef("m")),
new StringField(KEYWORD_FIELD2, new BytesRef("m"), Field.Store.NO)
)
);
}, h -> {
MatcherAssert.assertThat(h.getBuckets(), hasSize(2));
MatcherAssert.assertThat(h.getBuckets().get(0).getKey(), contains(equalTo("a"), equalTo("n")));
MatcherAssert.assertThat(h.getBuckets().get(0).getDocCount(), equalTo(2L));
MatcherAssert.assertThat(h.getBuckets().get(1).getKey(), contains(equalTo("a"), equalTo("m")));
MatcherAssert.assertThat(h.getBuckets().get(1).getDocCount(), equalTo(1L));
});
}

public void testMultiValuesField() throws IOException {
testAggregation(new MatchAllDocsQuery(), fieldConfigs(asList(KEYWORD_FIELD, INT_FIELD)), NONE_DECORATOR, iw -> {
iw.addDocument(
@@ -885,6 +923,7 @@ public void testEmptyAggregations() throws IOException {
AggregatorFactories factories = AggregatorFactories.EMPTY;
boolean showTermDocCountError = true;
MultiTermsAggregator.InternalValuesSource internalValuesSources = mock(MultiTermsAggregator.InternalValuesSource.class);
ValuesSource valuesSource = mock(ValuesSource.class);
DocValueFormat format = mock(DocValueFormat.class);
BucketOrder order = mock(BucketOrder.class);
Aggregator.SubAggCollectionMode collectMode = Aggregator.SubAggCollectionMode.BREADTH_FIRST;
@@ -901,6 +940,7 @@
factories,
showTermDocCountError,
List.of(internalValuesSources),
List.of(valuesSource),
List.of(format),
order,
collectMode,
