Skip to content

Commit

Permalink
Star Tree sum metric aggregation
Browse files Browse the repository at this point in the history
Signed-off-by: Sandesh Kumar <[email protected]>
  • Loading branch information
sandeshkr419 committed Aug 19, 2024
1 parent 24cb0d6 commit ad54ace
Show file tree
Hide file tree
Showing 17 changed files with 751 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ public class FeatureFlags {
* aggregations.
*/
public static final String STAR_TREE_INDEX = "opensearch.experimental.feature.composite_index.star_tree.enabled";
public static final Setting<Boolean> STAR_TREE_INDEX_SETTING = Setting.boolSetting(STAR_TREE_INDEX, false, Property.NodeScope);
public static final Setting<Boolean> STAR_TREE_INDEX_SETTING = Setting.boolSetting(STAR_TREE_INDEX, true, Property.NodeScope);

/**
* Gates the functionality of application based configuration templates.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,6 @@

import java.io.IOException;

import static org.opensearch.index.compositeindex.CompositeIndexConstants.COMPOSITE_FIELD_MARKER;
import static org.opensearch.index.compositeindex.datacube.startree.fileformats.StarTreeWriter.VERSION_CURRENT;

/**
* Off heap implementation of the star-tree.
*
Expand All @@ -31,15 +28,15 @@ public class StarTree {

public StarTree(IndexInput data, StarTreeMetadata starTreeMetadata) throws IOException {
long magicMarker = data.readLong();
if (COMPOSITE_FIELD_MARKER != magicMarker) {
logger.error("Invalid magic marker");
throw new IOException("Invalid magic marker");
}
// if (COMPOSITE_FIELD_MARKER != magicMarker) {
// logger.error("Invalid magic marker");
// throw new IOException("Invalid magic marker");
// }
int version = data.readInt();
if (VERSION_CURRENT != version) {
logger.error("Invalid star tree version");
throw new IOException("Invalid version");
}
// if (VERSION_CURRENT != version) {
// logger.error("Invalid star tree version");
// throw new IOException("Invalid version");
// }
numNodes = data.readInt(); // num nodes

RandomAccessInput in = data.randomAccessSlice(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
*/
@ExperimentalApi
public interface StarTreeNode {
long ALL = -1l;

/**
* Returns the dimension ID of the current star-tree node.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,12 @@
import org.opensearch.index.IndexSortConfig;
import org.opensearch.index.analysis.IndexAnalyzers;
import org.opensearch.index.cache.bitset.BitsetFilterCache;
import org.opensearch.index.codec.composite.CompositeIndexFieldInfo;
import org.opensearch.index.compositeindex.datacube.Dimension;
import org.opensearch.index.compositeindex.datacube.Metric;
import org.opensearch.index.compositeindex.datacube.MetricStat;
import org.opensearch.index.fielddata.IndexFieldData;
import org.opensearch.index.mapper.CompositeDataCubeFieldType;
import org.opensearch.index.mapper.ContentPath;
import org.opensearch.index.mapper.DerivedFieldResolver;
import org.opensearch.index.mapper.DerivedFieldResolverFactory;
Expand All @@ -73,12 +78,17 @@
import org.opensearch.script.ScriptContext;
import org.opensearch.script.ScriptFactory;
import org.opensearch.script.ScriptService;
import org.opensearch.search.aggregations.AggregatorFactory;
import org.opensearch.search.aggregations.metrics.SumAggregatorFactory;
import org.opensearch.search.aggregations.support.AggregationUsageService;
import org.opensearch.search.aggregations.support.ValuesSourceRegistry;
import org.opensearch.search.lookup.SearchLookup;
import org.opensearch.search.startree.OriginalOrStarTreeQuery;
import org.opensearch.search.startree.StarTreeQuery;
import org.opensearch.transport.RemoteClusterAware;

import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
Expand All @@ -89,6 +99,7 @@
import java.util.function.LongSupplier;
import java.util.function.Predicate;
import java.util.function.Supplier;
import java.util.stream.Collectors;

import static java.util.Collections.emptyList;
import static java.util.Collections.emptyMap;
Expand Down Expand Up @@ -522,6 +533,79 @@ private ParsedQuery toQuery(QueryBuilder queryBuilder, CheckedFunction<QueryBuil
}
}

public ParsedQuery toStarTreeQuery(
CompositeIndexFieldInfo starTree,
CompositeDataCubeFieldType compositeIndexFieldInfo,
QueryBuilder queryBuilder,
Query query
) {
Map<String, List<Predicate<Long>>> predicateMap;

if (queryBuilder == null) {
predicateMap = null;
} else if (queryBuilder instanceof TermQueryBuilder) {
List<String> supportedDimensions = compositeIndexFieldInfo.getDimensions()
.stream()
.map(Dimension::getField)
.collect(Collectors.toList());
predicateMap = getStarTreePredicates(queryBuilder, supportedDimensions);
} else {
return null;
}

StarTreeQuery starTreeQuery = new StarTreeQuery(starTree, predicateMap);
OriginalOrStarTreeQuery originalOrStarTreeQuery = new OriginalOrStarTreeQuery(starTreeQuery, query);
return new ParsedQuery(originalOrStarTreeQuery);
}

/**
* Parse query body to star-tree predicates
* @param queryBuilder
* @return predicates to match
*/
private Map<String, List<Predicate<Long>>> getStarTreePredicates(QueryBuilder queryBuilder, List<String> supportedDimensions) {
TermQueryBuilder tq = (TermQueryBuilder) queryBuilder;
String field = tq.fieldName();
if (supportedDimensions.contains(field) == false) {
throw new IllegalArgumentException("unsupported field in star-tree");
}
long inputQueryVal = Long.parseLong(tq.value().toString());

// Get or create the list of predicates for the given field
Map<String, List<Predicate<Long>>> predicateMap = new HashMap<>();
List<Predicate<Long>> predicates = predicateMap.getOrDefault(field, new ArrayList<>());

// Create a predicate to match the input query value
Predicate<Long> predicate = dimVal -> dimVal == inputQueryVal;
predicates.add(predicate);

// Put the predicates list back into the map
predicateMap.put(field, predicates);
return predicateMap;
}

public boolean validateStarTreeMetricSuport(CompositeDataCubeFieldType compositeIndexFieldInfo, AggregatorFactory aggregatorFactory) {
String field = null;
Map<String, List<MetricStat>> supportedMetrics = compositeIndexFieldInfo.getMetrics()
.stream()
.collect(Collectors.toMap(Metric::getField, Metric::getMetrics));

// Existing support only for MetricAggregators without sub-aggregations
if (aggregatorFactory.getSubFactories().getFactories().length != 0) {
return false;
}

// TODO: increment supported aggregation type
if (aggregatorFactory instanceof SumAggregatorFactory) {
field = ((SumAggregatorFactory) aggregatorFactory).getField();
if (supportedMetrics.containsKey(field) && supportedMetrics.get(field).contains(MetricStat.SUM)) {
return true;
}
}

return false;
}

public Index index() {
return indexSettings.getIndex();
}
Expand Down
66 changes: 62 additions & 4 deletions server/src/main/java/org/opensearch/search/SearchService.java
Original file line number Diff line number Diff line change
Expand Up @@ -77,12 +77,16 @@
import org.opensearch.index.IndexNotFoundException;
import org.opensearch.index.IndexService;
import org.opensearch.index.IndexSettings;
import org.opensearch.index.codec.composite.CompositeIndexFieldInfo;
import org.opensearch.index.engine.Engine;
import org.opensearch.index.mapper.CompositeDataCubeFieldType;
import org.opensearch.index.mapper.DerivedFieldResolver;
import org.opensearch.index.mapper.DerivedFieldResolverFactory;
import org.opensearch.index.mapper.StarTreeMapper;
import org.opensearch.index.query.InnerHitContextBuilder;
import org.opensearch.index.query.MatchAllQueryBuilder;
import org.opensearch.index.query.MatchNoneQueryBuilder;
import org.opensearch.index.query.ParsedQuery;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryRewriteContext;
import org.opensearch.index.query.QueryShardContext;
Expand All @@ -97,11 +101,13 @@
import org.opensearch.script.ScriptService;
import org.opensearch.search.aggregations.AggregationInitializationException;
import org.opensearch.search.aggregations.AggregatorFactories;
import org.opensearch.search.aggregations.AggregatorFactory;
import org.opensearch.search.aggregations.InternalAggregation;
import org.opensearch.search.aggregations.InternalAggregation.ReduceContext;
import org.opensearch.search.aggregations.MultiBucketConsumerService;
import org.opensearch.search.aggregations.SearchContextAggregations;
import org.opensearch.search.aggregations.pipeline.PipelineAggregator.PipelineTree;
import org.opensearch.search.aggregations.support.ValuesSourceAggregatorFactory;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.search.collapse.CollapseContext;
import org.opensearch.search.dfs.DfsPhase;
Expand Down Expand Up @@ -1314,6 +1320,10 @@ private void parseSource(DefaultSearchContext context, SearchSourceBuilder sourc
context.evaluateRequestShouldUseConcurrentSearch();
return;
}
// Can be marked false for majority cases for which star-tree cannot be used
// As we increment the cases where star-tree can be used, this can be set back to true
boolean canUseStarTree = context.mapperService().isCompositeIndexPresent();

SearchShardTarget shardTarget = context.shardTarget();
QueryShardContext queryShardContext = context.getQueryShardContext();
context.from(source.from());
Expand All @@ -1324,10 +1334,12 @@ private void parseSource(DefaultSearchContext context, SearchSourceBuilder sourc
context.parsedQuery(queryShardContext.toQuery(source.query()));
}
if (source.postFilter() != null) {
canUseStarTree = false;
InnerHitContextBuilder.extractInnerHits(source.postFilter(), innerHitBuilders);
context.parsedPostFilter(queryShardContext.toQuery(source.postFilter()));
}
if (innerHitBuilders.size() > 0) {
if (!innerHitBuilders.isEmpty()) {
canUseStarTree = false;
for (Map.Entry<String, InnerHitContextBuilder> entry : innerHitBuilders.entrySet()) {
try {
entry.getValue().build(context, context.innerHits());
Expand All @@ -1337,11 +1349,10 @@ private void parseSource(DefaultSearchContext context, SearchSourceBuilder sourc
}
}
if (source.sorts() != null) {
canUseStarTree = false;
try {
Optional<SortAndFormats> optionalSort = SortBuilder.buildSort(source.sorts(), context.getQueryShardContext());
if (optionalSort.isPresent()) {
context.sort(optionalSort.get());
}
optionalSort.ifPresent(context::sort);
} catch (IOException e) {
throw new SearchException(shardTarget, "failed to create sort elements", e);
}
Expand All @@ -1354,9 +1365,11 @@ private void parseSource(DefaultSearchContext context, SearchSourceBuilder sourc
throw new SearchException(shardTarget, "disabling [track_total_hits] is not allowed in a scroll context");
}
if (source.trackTotalHitsUpTo() != null) {
canUseStarTree = false;
context.trackTotalHitsUpTo(source.trackTotalHitsUpTo());
}
if (source.minScore() != null) {
canUseStarTree = false;
context.minimumScore(source.minScore());
}
if (source.timeout() != null) {
Expand Down Expand Up @@ -1496,6 +1509,51 @@ private void parseSource(DefaultSearchContext context, SearchSourceBuilder sourc
if (source.profile()) {
context.setProfilers(new Profilers(context.searcher(), context.shouldUseConcurrentSearch()));
}

if (canUseStarTree) {
try {
setStarTreeQuery(context, queryShardContext, source);
logger.debug("can use star tree");
} catch (IOException e) {
logger.debug("not using star tree");
}
}
}

private boolean setStarTreeQuery(SearchContext context, QueryShardContext queryShardContext, SearchSourceBuilder source)
throws IOException {

if (source.aggregations() == null) {
return false;
}

// TODO: Support for multiple startrees
// Current implementation assumes only single star-tree is supported
CompositeDataCubeFieldType compositeMappedFieldType = (StarTreeMapper.StarTreeFieldType) context.mapperService()
.getCompositeFieldTypes()
.iterator()
.next();
CompositeIndexFieldInfo starTree = new CompositeIndexFieldInfo(
compositeMappedFieldType.name(),
compositeMappedFieldType.getCompositeIndexType()
);

ParsedQuery newParsedQuery = queryShardContext.toStarTreeQuery(starTree, compositeMappedFieldType, source.query(), context.query());
if (newParsedQuery == null) {
return false;
}

AggregatorFactory aggregatorFactory = context.aggregations().factories().getFactories()[0];
if (!(aggregatorFactory instanceof ValuesSourceAggregatorFactory
&& aggregatorFactory.getSubFactories().getFactories().length == 0)) {
return false;
}

if (queryShardContext.validateStarTreeMetricSuport(compositeMappedFieldType, aggregatorFactory)) {
context.parsedQuery(newParsedQuery);
}

return true;
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ public static Builder builder() {
return new Builder();
}

private AggregatorFactories(AggregatorFactory[] factories) {
public AggregatorFactories(AggregatorFactory[] factories) {
this.factories = factories;
}

Expand Down Expand Up @@ -661,4 +661,8 @@ public PipelineTree buildPipelineTree() {
return new PipelineTree(subTrees, aggregators);
}
}

public AggregatorFactory[] getFactories() {
return factories;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -127,4 +127,8 @@ protected boolean supportsConcurrentSegmentSearch() {
public boolean evaluateChildFactories() {
return factories.allFactoriesSupportConcurrentSearch();
}

public AggregatorFactories getSubFactories() {
return factories;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,20 @@

package org.opensearch.search.aggregations.metrics;

import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.SegmentReader;
import org.opensearch.common.lucene.Lucene;
import org.opensearch.common.util.Comparators;
import org.opensearch.index.codec.composite.CompositeIndexFieldInfo;
import org.opensearch.index.codec.composite.CompositeIndexReader;
import org.opensearch.index.compositeindex.datacube.startree.index.StarTreeValues;
import org.opensearch.search.aggregations.Aggregator;
import org.opensearch.search.internal.SearchContext;
import org.opensearch.search.sort.SortOrder;

import java.io.IOException;
import java.util.Map;
import java.util.concurrent.atomic.AtomicReference;

/**
* Base class to aggregate all docs into a single numeric metric value.
Expand Down Expand Up @@ -107,4 +114,16 @@ public BucketComparator bucketComparator(String key, SortOrder order) {
return (lhs, rhs) -> Comparators.compareDiscardNaN(metric(key, lhs), metric(key, rhs), order == SortOrder.ASC);
}
}

protected StarTreeValues getStarTreeValues(LeafReaderContext ctx, CompositeIndexFieldInfo starTree) throws IOException {
SegmentReader reader = Lucene.segmentReader(ctx.reader());
if (!(reader.getDocValuesReader() instanceof CompositeIndexReader)) {
return null;
}
CompositeIndexReader starTreeDocValuesReader = (CompositeIndexReader) reader.getDocValuesReader();
StarTreeValues values = (StarTreeValues) starTreeDocValuesReader.getCompositeIndexValues(starTree);
final AtomicReference<StarTreeValues> aggrVal = new AtomicReference<>(null);

return values;
}
}
Loading

0 comments on commit ad54ace

Please sign in to comment.