Skip to content

Commit

Permalink
Bugfix to guard against stack overflow errors caused by very large re…
Browse files Browse the repository at this point in the history
…g-ex input (#2810)

* Bugfix to guard against stack overflow errors caused by very large reg-ex input

This change fixes a code path that did not properly impose the index-level max_regex_length limit. Therefore, it was possible to provide an arbitrarily large string as the include/exclude reg-ex value under search aggregations. This exposed the underlying node to crashes from a StackOverflowError, due to how the Lucene RegExp class processes strings using stack frames.

Signed-off-by: Kartik Ganesh <[email protected]>

* Adding integration tests for large string RegEx

Signed-off-by: Kartik Ganesh <[email protected]>

* Spotless

Signed-off-by: Kartik Ganesh <[email protected]>
(cherry picked from commit 566ebfa)
Signed-off-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
  • Loading branch information
github-actions[bot] committed Sep 26, 2024
1 parent dc95ba3 commit fabde0c
Show file tree
Hide file tree
Showing 12 changed files with 156 additions and 73 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,18 @@

package org.opensearch.search.aggregations;

import org.opensearch.OpenSearchException;
import org.opensearch.action.index.IndexRequestBuilder;
import org.opensearch.action.search.SearchPhaseExecutionException;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.common.unit.TimeValue;
import org.opensearch.search.aggregations.bucket.terms.IncludeExclude;
import org.opensearch.search.aggregations.bucket.terms.RareTermsAggregationBuilder;
import org.opensearch.search.aggregations.bucket.terms.SignificantTermsAggregationBuilder;
import org.opensearch.search.aggregations.bucket.terms.SignificantTermsAggregatorFactory;
import org.opensearch.search.aggregations.bucket.terms.Terms;
import org.opensearch.search.aggregations.bucket.terms.TermsAggregationBuilder;
import org.opensearch.search.aggregations.bucket.terms.TermsAggregatorFactory;
import org.opensearch.test.OpenSearchIntegTestCase;

import java.util.ArrayList;
Expand All @@ -50,6 +58,11 @@ public class AggregationsIntegrationIT extends OpenSearchIntegTestCase {

static int numDocs;

// Regex input of 2000 characters used by the large-regex tests below.
// NOTE(review): presumably longer than the max_regex_length in effect for the test index - confirm.
private static final String LARGE_STRING = "a".repeat(2000);
// Expected prefix of the IllegalArgumentException message for an over-long regex; the tests only
// check startsWith, so the trailing part of the message (the configured maximum) is not pinned here.
private static final String LARGE_STRING_EXCEPTION_MESSAGE = "The length of regex ["
+ LARGE_STRING.length()
+ "] used in the request has exceeded the allowed maximum";

@Override
public void setupSuiteScopeCluster() throws Exception {
assertAcked(prepareCreate("index").addMapping("type", "f", "type=keyword").get());
Expand Down Expand Up @@ -85,4 +98,51 @@ public void testScroll() {
assertEquals(numDocs, total);
}

/**
 * Checks that a terms aggregation with an over-long include regex is rejected,
 * once for every available terms execution mode.
 */
public void testLargeRegExTermsAggregation() {
    for (TermsAggregatorFactory.ExecutionMode mode : TermsAggregatorFactory.ExecutionMode.values()) {
        final TermsAggregationBuilder builder = terms("my_terms").field("f")
            .includeExclude(getLargeStringInclude())
            .executionHint(mode.toString());
        // Each execution mode is exercised in turn.
        runLargeStringAggregationTest(builder);
    }
}

/**
 * Checks that a significant_terms aggregation with an over-long include regex is rejected,
 * once for every available significant-terms execution mode.
 */
public void testLargeRegExSignificantTermsAggregation() {
    for (SignificantTermsAggregatorFactory.ExecutionMode mode : SignificantTermsAggregatorFactory.ExecutionMode.values()) {
        final SignificantTermsAggregationBuilder builder = new SignificantTermsAggregationBuilder("my_terms").field("f")
            .includeExclude(getLargeStringInclude())
            .executionHint(mode.toString());
        // Each execution mode is exercised in turn.
        runLargeStringAggregationTest(builder);
    }
}

/**
 * Checks that a rare_terms aggregation with an over-long include regex is rejected.
 */
public void testLargeRegExRareTermsAggregation() {
    // currently this only supports "map" as an execution hint, so no loop over execution modes is needed
    final IncludeExclude oversizedInclude = getLargeStringInclude();
    final RareTermsAggregationBuilder builder = new RareTermsAggregationBuilder("my_terms").field("f")
        .includeExclude(oversizedInclude)
        .maxDocCount(2);
    runLargeStringAggregationTest(builder);
}

/**
 * Creates an include-only filter whose include pattern is {@link #LARGE_STRING}; no exclude pattern is set.
 */
private IncludeExclude getLargeStringInclude() {
    final String includePattern = LARGE_STRING;
    final String excludePattern = null;
    return new IncludeExclude(includePattern, excludePattern);
}

/**
 * Runs a search against "index" with the given aggregation and verifies that it is rejected
 * because of the over-long regex.
 * <p>
 * The expected failure is a {@link SearchPhaseExecutionException} whose cause is an
 * {@link OpenSearchException}, which in turn is caused by an {@link IllegalArgumentException}
 * whose message starts with {@link #LARGE_STRING_EXCEPTION_MESSAGE}.
 *
 * @param aggregation the aggregation (already configured with the large include regex) to execute
 */
private void runLargeStringAggregationTest(AggregationBuilder aggregation) {
    boolean exceptionThrown = false;
    try {
        client().prepareSearch("index").addAggregation(aggregation).get();
    } catch (SearchPhaseExecutionException ex) {
        exceptionThrown = true;
        Throwable nestedException = ex.getCause();
        assertNotNull(nestedException);
        assertTrue(nestedException instanceof OpenSearchException);
        Throwable rootCause = nestedException.getCause();
        assertNotNull(rootCause);
        assertTrue(rootCause instanceof IllegalArgumentException);
        String actualExceptionMessage = rootCause.getMessage();
        // Guard against a null message so a mismatch is reported as an assertion failure, not an NPE.
        assertNotNull(actualExceptionMessage);
        assertTrue(actualExceptionMessage.startsWith(LARGE_STRING_EXCEPTION_MESSAGE));
    }
    // A search that completes without the expected exception must fail the test.
    assertTrue("Exception should have been thrown", exceptionThrown);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,15 @@
import org.apache.lucene.util.automaton.Operations;
import org.apache.lucene.util.automaton.RegExp;
import org.opensearch.OpenSearchParseException;
import org.opensearch.common.Nullable;
import org.opensearch.common.ParseField;
import org.opensearch.common.io.stream.StreamInput;
import org.opensearch.common.io.stream.StreamOutput;
import org.opensearch.common.io.stream.Writeable;
import org.opensearch.common.xcontent.ToXContentFragment;
import org.opensearch.common.xcontent.XContentBuilder;
import org.opensearch.common.xcontent.XContentParser;
import org.opensearch.index.IndexSettings;
import org.opensearch.search.DocValueFormat;

import java.io.IOException;
Expand Down Expand Up @@ -337,19 +339,16 @@ public LongBitSet acceptedGlobalOrdinals(SortedSetDocValues globalOrdinals) thro

}

private final RegExp include, exclude;
private final String include, exclude;
private final SortedSet<BytesRef> includeValues, excludeValues;
private final int incZeroBasedPartition;
private final int incNumPartitions;

/**
* @param include The regular expression pattern for the terms to be included
* @param exclude The regular expression pattern for the terms to be excluded
* @param include The string or regular expression pattern for the terms to be included
* @param exclude The string or regular expression pattern for the terms to be excluded
*/
public IncludeExclude(RegExp include, RegExp exclude) {
if (include == null && exclude == null) {
throw new IllegalArgumentException();
}
public IncludeExclude(String include, String exclude) {
this.include = include;
this.exclude = exclude;
this.includeValues = null;
Expand All @@ -358,10 +357,6 @@ public IncludeExclude(RegExp include, RegExp exclude) {
this.incNumPartitions = 0;
}

public IncludeExclude(String include, String exclude) {
this(include == null ? null : new RegExp(include), exclude == null ? null : new RegExp(exclude));
}

/**
* @param includeValues The terms to be included
* @param excludeValues The terms to be excluded
Expand Down Expand Up @@ -412,10 +407,8 @@ public IncludeExclude(StreamInput in) throws IOException {
excludeValues = null;
incZeroBasedPartition = 0;
incNumPartitions = 0;
String includeString = in.readOptionalString();
include = includeString == null ? null : new RegExp(includeString);
String excludeString = in.readOptionalString();
exclude = excludeString == null ? null : new RegExp(excludeString);
include = in.readOptionalString();
exclude = in.readOptionalString();
return;
}
include = null;
Expand Down Expand Up @@ -447,8 +440,8 @@ public void writeTo(StreamOutput out) throws IOException {
boolean regexBased = isRegexBased();
out.writeBoolean(regexBased);
if (regexBased) {
out.writeOptionalString(include == null ? null : include.getOriginalString());
out.writeOptionalString(exclude == null ? null : exclude.getOriginalString());
out.writeOptionalString(include);
out.writeOptionalString(exclude);
} else {
boolean hasIncludes = includeValues != null;
out.writeBoolean(hasIncludes);
Expand Down Expand Up @@ -584,26 +577,54 @@ public boolean isPartitionBased() {
return incNumPartitions > 0;
}

private Automaton toAutomaton() {
Automaton a = null;
private Automaton toAutomaton(@Nullable IndexSettings indexSettings) {
int maxRegexLength = indexSettings == null ? -1 : indexSettings.getMaxRegexLength();
Automaton a;
if (include != null) {
a = include.toAutomaton();
if (include.length() > maxRegexLength) {
throw new IllegalArgumentException(
"The length of regex ["
+ include.length()
+ "] used in the request has exceeded "
+ "the allowed maximum of ["
+ maxRegexLength
+ "]. "
+ "This maximum can be set by changing the ["
+ IndexSettings.MAX_REGEX_LENGTH_SETTING.getKey()
+ "] index level setting."
);
}
a = new RegExp(include).toAutomaton();
} else if (includeValues != null) {
a = Automata.makeStringUnion(includeValues);
} else {
a = Automata.makeAnyString();
}
if (exclude != null) {
a = Operations.minus(a, exclude.toAutomaton(), Operations.DEFAULT_DETERMINIZE_WORK_LIMIT);
if (exclude.length() > maxRegexLength) {
throw new IllegalArgumentException(
"The length of regex ["
+ exclude.length()
+ "] used in the request has exceeded "
+ "the allowed maximum of ["
+ maxRegexLength
+ "]. "
+ "This maximum can be set by changing the ["
+ IndexSettings.MAX_REGEX_LENGTH_SETTING.getKey()
+ "] index level setting."
);
}
Automaton excludeAutomaton = new RegExp(exclude).toAutomaton();
a = Operations.minus(a, excludeAutomaton, Operations.DEFAULT_DETERMINIZE_WORK_LIMIT);
} else if (excludeValues != null) {
a = Operations.minus(a, Automata.makeStringUnion(excludeValues), Operations.DEFAULT_DETERMINIZE_WORK_LIMIT);
}
return a;
}

public StringFilter convertToStringFilter(DocValueFormat format) {
public StringFilter convertToStringFilter(DocValueFormat format, IndexSettings indexSettings) {
if (isRegexBased()) {
return new AutomatonBackedStringFilter(toAutomaton());
return new AutomatonBackedStringFilter(toAutomaton(indexSettings));
}
if (isPartitionBased()) {
return new PartitionedStringFilter();
Expand All @@ -624,10 +645,10 @@ private static SortedSet<BytesRef> parseForDocValues(SortedSet<BytesRef> endUser
return result;
}

public OrdinalsFilter convertToOrdinalsFilter(DocValueFormat format) {
public OrdinalsFilter convertToOrdinalsFilter(DocValueFormat format, IndexSettings indexSettings) {

if (isRegexBased()) {
return new AutomatonBackedOrdinalsFilter(toAutomaton());
return new AutomatonBackedOrdinalsFilter(toAutomaton(indexSettings));
}
if (isPartitionBased()) {
return new PartitionedOrdinalsFilter();
Expand Down Expand Up @@ -684,7 +705,7 @@ public LongFilter convertToDoubleFilter() {
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
if (include != null) {
builder.field(INCLUDE_FIELD.getPreferredName(), include.getOriginalString());
builder.field(INCLUDE_FIELD.getPreferredName(), include);
} else if (includeValues != null) {
builder.startArray(INCLUDE_FIELD.getPreferredName());
for (BytesRef value : includeValues) {
Expand All @@ -698,7 +719,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
builder.endObject();
}
if (exclude != null) {
builder.field(EXCLUDE_FIELD.getPreferredName(), exclude.getOriginalString());
builder.field(EXCLUDE_FIELD.getPreferredName(), exclude);
} else if (excludeValues != null) {
builder.startArray(EXCLUDE_FIELD.getPreferredName());
for (BytesRef value : excludeValues) {
Expand All @@ -711,14 +732,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws

/**
 * Hashes the include/exclude patterns, the explicit include/exclude value sets and the
 * partition settings.
 */
// NOTE(review): this listing is a rendered diff, so it appears to show both the earlier return
// (patterns read via getOriginalString()) and the later one (patterns hashed directly) - confirm
// against the actual file before relying on either form.
@Override
public int hashCode() {
return Objects.hash(
include == null ? null : include.getOriginalString(),
exclude == null ? null : exclude.getOriginalString(),
includeValues,
excludeValues,
incZeroBasedPartition,
incNumPartitions
);
return Objects.hash(include, exclude, includeValues, excludeValues, incZeroBasedPartition, incNumPartitions);
}

@Override
Expand All @@ -730,14 +744,8 @@ public boolean equals(Object obj) {
return false;
}
IncludeExclude other = (IncludeExclude) obj;
return Objects.equals(
include == null ? null : include.getOriginalString(),
other.include == null ? null : other.include.getOriginalString()
)
&& Objects.equals(
exclude == null ? null : exclude.getOriginalString(),
other.exclude == null ? null : other.exclude.getOriginalString()
)
return Objects.equals(include, other.include)
&& Objects.equals(exclude, other.exclude)
&& Objects.equals(includeValues, other.includeValues)
&& Objects.equals(excludeValues, other.excludeValues)
&& Objects.equals(incZeroBasedPartition, other.incZeroBasedPartition)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@

import org.opensearch.common.ParseField;
import org.opensearch.common.logging.DeprecationLogger;
import org.opensearch.index.IndexSettings;
import org.opensearch.index.query.QueryShardContext;
import org.opensearch.search.DocValueFormat;
import org.opensearch.search.aggregations.Aggregator;
Expand Down Expand Up @@ -250,7 +251,10 @@ Aggregator create(
double precision,
CardinalityUpperBound cardinality
) throws IOException {
final IncludeExclude.StringFilter filter = includeExclude == null ? null : includeExclude.convertToStringFilter(format);
IndexSettings indexSettings = context.getQueryShardContext().getIndexSettings();
final IncludeExclude.StringFilter filter = includeExclude == null
? null
: includeExclude.convertToStringFilter(format, indexSettings);
return new StringRareTermsAggregator(
name,
factories,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@

import org.opensearch.common.ParseField;
import org.opensearch.common.logging.DeprecationLogger;
import org.opensearch.index.IndexSettings;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryShardContext;
import org.opensearch.search.DocValueFormat;
Expand Down Expand Up @@ -325,8 +326,10 @@ Aggregator create(
CardinalityUpperBound cardinality,
Map<String, Object> metadata
) throws IOException {

final IncludeExclude.StringFilter filter = includeExclude == null ? null : includeExclude.convertToStringFilter(format);
IndexSettings indexSettings = aggregationContext.getQueryShardContext().getIndexSettings();
final IncludeExclude.StringFilter filter = includeExclude == null
? null
: includeExclude.convertToStringFilter(format, indexSettings);
return new MapStringTermsAggregator(
name,
factories,
Expand Down Expand Up @@ -364,8 +367,10 @@ Aggregator create(
CardinalityUpperBound cardinality,
Map<String, Object> metadata
) throws IOException {

final IncludeExclude.OrdinalsFilter filter = includeExclude == null ? null : includeExclude.convertToOrdinalsFilter(format);
IndexSettings indexSettings = aggregationContext.getQueryShardContext().getIndexSettings();
final IncludeExclude.OrdinalsFilter filter = includeExclude == null
? null
: includeExclude.convertToOrdinalsFilter(format, indexSettings);
boolean remapGlobalOrd = true;
if (cardinality == CardinalityUpperBound.ONE && factories == AggregatorFactories.EMPTY && includeExclude == null) {
/*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
import org.opensearch.common.util.BigArrays;
import org.opensearch.common.util.BytesRefHash;
import org.opensearch.common.util.ObjectArray;
import org.opensearch.index.IndexSettings;
import org.opensearch.index.mapper.MappedFieldType;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryShardContext;
Expand Down Expand Up @@ -137,7 +138,10 @@ protected Aggregator createInternal(

// TODO - need to check with mapping that this is indeed a text field....

IncludeExclude.StringFilter incExcFilter = includeExclude == null ? null : includeExclude.convertToStringFilter(DocValueFormat.RAW);
IndexSettings indexSettings = searchContext.getQueryShardContext().getIndexSettings();
IncludeExclude.StringFilter incExcFilter = includeExclude == null
? null
: includeExclude.convertToStringFilter(DocValueFormat.RAW, indexSettings);

MapStringTermsAggregator.CollectorSource collectorSource = new SignificantTextCollectorSource(
queryShardContext.lookup().source(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@

import org.apache.lucene.search.IndexSearcher;
import org.opensearch.common.ParseField;
import org.opensearch.index.IndexSettings;
import org.opensearch.index.query.QueryShardContext;
import org.opensearch.search.DocValueFormat;
import org.opensearch.search.aggregations.AggregationExecutionException;
Expand Down Expand Up @@ -380,7 +381,10 @@ Aggregator create(
CardinalityUpperBound cardinality,
Map<String, Object> metadata
) throws IOException {
final IncludeExclude.StringFilter filter = includeExclude == null ? null : includeExclude.convertToStringFilter(format);
IndexSettings indexSettings = context.getQueryShardContext().getIndexSettings();
final IncludeExclude.StringFilter filter = includeExclude == null
? null
: includeExclude.convertToStringFilter(format, indexSettings);
return new MapStringTermsAggregator(
name,
factories,
Expand Down Expand Up @@ -458,7 +462,10 @@ Aggregator create(
);

}
final IncludeExclude.OrdinalsFilter filter = includeExclude == null ? null : includeExclude.convertToOrdinalsFilter(format);
IndexSettings indexSettings = context.getQueryShardContext().getIndexSettings();
final IncludeExclude.OrdinalsFilter filter = includeExclude == null
? null
: includeExclude.convertToOrdinalsFilter(format, indexSettings);
boolean remapGlobalOrds;
if (cardinality == CardinalityUpperBound.ONE && REMAP_GLOBAL_ORDS != null) {
/*
Expand Down
Loading

0 comments on commit fabde0c

Please sign in to comment.