Skip to content

Commit

Permalink
fallback to MultiRange only when too many clauses
Browse files Browse the repository at this point in the history
Signed-off-by: mikhail-khludnev <[email protected]>
  • Loading branch information
mikhail-khludnev committed Nov 11, 2024
1 parent 6a11b54 commit 26ff736
Showing 1 changed file with 86 additions and 35 deletions.
121 changes: 86 additions & 35 deletions server/src/main/java/org/opensearch/index/mapper/IpFieldMapper.java
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.ConstantScoreQuery;
import org.apache.lucene.search.IndexOrDocValuesQuery;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.MatchNoDocsQuery;
import org.apache.lucene.search.PointRangeQuery;
import org.apache.lucene.search.Query;
Expand Down Expand Up @@ -69,6 +70,8 @@
import java.util.List;
import java.util.Map;
import java.util.function.BiFunction;
import java.util.function.Consumer;
import java.util.function.IntSupplier;
import java.util.function.Supplier;

/**
Expand Down Expand Up @@ -273,42 +276,42 @@ public Query termQuery(Object value, @Nullable QueryShardContext context) {
/**
 * Builds the query for a {@code terms} lookup on this IP field.
 * <p>
 * After {@code failIfNotIndexedAndNoDocValues()}, the path that ends this method hands
 * {@code values} to {@code splitIpsAndMasks}, which yields a tuple of concrete addresses
 * ({@code v1}) and mask strings ({@code v2}). Both groups are fed into one {@code QueryUnion}:
 * addresses through {@code convertIps}, masks through {@code convertMasks}, which also receives
 * the number of queries the union has collected so far. The union's combined query is returned.
 * <p>
 * NOTE(review): the declarations and the loop that precede the {@code splitIpsAndMasks} call
 * ({@code concreteIPs}, {@code ranges}, {@code multiRange}, {@code multiRangeIsEmpty}) are never
 * used or closed within this span; they look like leftovers of the previous implementation.
 * Confirm against the actual file.
 *
 * @param values  the terms to match; passed as-is to {@code splitIpsAndMasks}
 * @param context shard context, forwarded to {@code convertMasks}
 * @return the query produced by {@code QueryUnion.get()}
 */
@Override
public Query termsQuery(List<?> values, QueryShardContext context) {
failIfNotIndexedAndNoDocValues();
List<InetAddress> concreteIPs = new ArrayList<>();
List<Query> ranges = new ArrayList<>();
IpMultiRangeQueryBuilder multiRange = new IpMultiRangeQueryBuilder(name());
boolean multiRangeIsEmpty = true;
for (final Object value : values) {
if (value instanceof InetAddress) {
concreteIPs.add((InetAddress) value);
} else {
final String strVal = (value instanceof BytesRef) ? ((BytesRef) value).utf8ToString() : value.toString();
if (strVal.contains("/")) {
// the `terms` query contains some prefix queries, so we cannot create a set query
// and need to fall back to a disjunction of `term` queries
// Query query = termQuery(strVal, context);
Tuple<List<InetAddress>, List<String>> ipsMasks = splitIpsAndMasks(values);
QueryUnion combiner = new QueryUnion();
convertIps(ipsMasks.v1(), combiner);
convertMasks(ipsMasks.v2(), context, combiner, combiner.getAsInt());
return combiner.get();
}

/**
 * Converts mask strings into queries and hands them to {@code combiner}.
 * <p>
 * Nothing is emitted for an empty list. Otherwise one of two strategies is used:
 * <ul>
 * <li>if {@code masks.size() + clauses} reaches {@code IndexSearcher.getMaxClauseCount() - 1}
 * and {@code isSearchable()} is true, every mask is parsed with {@code InetAddresses.parseCidr},
 * turned into a prefix {@code PointRangeQuery}, and its lower/upper points are added to a single
 * {@code IpMultiRangeQueryBuilder}; the one built query is passed to {@code combiner};</li>
 * <li>otherwise each mask is passed to {@code termQuery} and every resulting query is handed to
 * {@code combiner} individually.</li>
 * </ul>
 * NOTE(review): several lines in this span ({@code multiRangeIsEmpty}, {@code concreteIPs},
 * {@code ranges}) reference names that are not declared in this method and leave the braces
 * unbalanced; they look like leftovers of the previous {@code termsQuery} body. Confirm against
 * the actual file.
 *
 * @param masks    mask strings to convert
 * @param context  shard context, forwarded to {@code termQuery}
 * @param combiner receiver of the produced queries
 * @param clauses  number of clauses already collected by the caller, added to {@code masks.size()}
 *                 when checking against the max clause count
 */
private void convertMasks(List<String> masks, QueryShardContext context, Consumer<Query> combiner, int clauses) {
if (!masks.isEmpty()) {
// best-effort attempt to avoid exceeding the max clause count: collapse all masks into one multi-range query
if (masks.size() + clauses >= IndexSearcher.getMaxClauseCount() - 1 && isSearchable()) {
IpMultiRangeQueryBuilder multiRange = new IpMultiRangeQueryBuilder(name());
for (String strVal : masks) {
final Tuple<InetAddress, Integer> cidr = InetAddresses.parseCidr(strVal);
PointRangeQuery query = (PointRangeQuery) InetAddressPoint.newPrefixQuery(name(), cidr.v1(), cidr.v2());

// would be great to have union on ranges over bare points
// ranges.add(query);
multiRange.add(query.getLowerPoint(), query.getUpperPoint());
multiRangeIsEmpty = false;
} else {
concreteIPs.add(InetAddresses.forString(strVal));
}
combiner.accept(multiRange.build());
} else {
for (String strVal : masks) {
combiner.accept(termQuery(strVal, context));
}
}
}
if (!multiRangeIsEmpty) {
ranges.add(multiRange.build());
}
if (!concreteIPs.isEmpty()) {
}

/**
 * Converts concrete addresses into a single query and hands it to {@code combiner}.
 * <p>
 * Nothing is emitted for an empty list. A points query is prepared lazily: an exact query for a
 * single address, a set query for several. When {@code hasDocValues()} is true, the addresses are
 * also encoded into {@code BytesRef}s to build a {@code SortedSetDocValuesField.newSlowSetQuery},
 * and the supplier may be replaced by one producing an {@code IndexOrDocValuesQuery} over
 * {@code wrap.get()} and that doc-values query.
 * <p>
 * NOTE(review): the lines between the doc-values query creation and the
 * {@code IndexOrDocValuesQuery} assignment (including the definition of {@code wrap} and the
 * condition guarding the assignment) are hidden in this view. The span also contains duplicated
 * statements that use {@code concreteIPs} and {@code ranges}, which are not declared in this
 * method and look like leftovers of the previous implementation. Confirm against the actual file.
 *
 * @param inetAddresses concrete addresses to match
 * @param combiner      receiver of the produced query
 */
private void convertIps(List<InetAddress> inetAddresses, Consumer<Query> combiner) {
if (!inetAddresses.isEmpty()) {
Supplier<Query> pointsQuery;
pointsQuery = () -> concreteIPs.size() == 1
? InetAddressPoint.newExactQuery(name(), concreteIPs.iterator().next())
: InetAddressPoint.newSetQuery(name(), concreteIPs.toArray(new InetAddress[0]));
pointsQuery = () -> inetAddresses.size() == 1
? InetAddressPoint.newExactQuery(name(), inetAddresses.iterator().next())
: InetAddressPoint.newSetQuery(name(), inetAddresses.toArray(new InetAddress[0]));
if (hasDocValues()) {
List<BytesRef> set = new ArrayList<>(concreteIPs.size());
for (final InetAddress address : concreteIPs) {
List<BytesRef> set = new ArrayList<>(inetAddresses.size());
for (final InetAddress address : inetAddresses) {
set.add(new BytesRef(InetAddressPoint.encode(address)));
}
Query dvQuery = SortedSetDocValuesField.newSlowSetQuery(name(), set);
Expand All @@ -319,16 +322,26 @@ public Query termsQuery(List<?> values, QueryShardContext context) {
pointsQuery = () -> new IndexOrDocValuesQuery(wrap.get(), dvQuery);
}
}
ranges.add(pointsQuery.get());
}
if (ranges.size() == 1) {
return ranges.iterator().next(); // CSQ?
combiner.accept(pointsQuery.get());
}
BooleanQuery.Builder union = new BooleanQuery.Builder();
for (Query q : ranges) {
union.add(q, BooleanClause.Occur.SHOULD);
}
BooleanQuery.Builder union = new BooleanQuery.Builder();
for (Query q : ranges) {
union.add(q, BooleanClause.Occur.SHOULD);
}

/**
 * Partitions raw {@code terms} values into concrete addresses and mask strings.
 * <p>
 * A value that already is an {@link InetAddress} is taken as-is. Any other value is rendered as a
 * string ({@code BytesRef} via {@code utf8ToString()}, everything else via {@code toString()});
 * a string containing {@code '/'} is collected as a mask, otherwise it is parsed with
 * {@code InetAddresses.forString} and collected as a concrete address.
 *
 * @param values the raw values to partition
 * @return a tuple whose {@code v1} holds the concrete addresses and whose {@code v2} holds the
 *         mask strings, each in encounter order
 */
private static Tuple<List<InetAddress>, List<String>> splitIpsAndMasks(List<?> values) {
    List<InetAddress> concreteIPs = new ArrayList<>();
    List<String> masks = new ArrayList<>();
    for (final Object value : values) {
        if (value instanceof InetAddress) {
            concreteIPs.add((InetAddress) value);
        } else {
            final String strVal = (value instanceof BytesRef) ? ((BytesRef) value).utf8ToString() : value.toString();
            if (strVal.contains("/")) {
                masks.add(strVal);
            } else {
                concreteIPs.add(InetAddresses.forString(strVal));
            }
        }
    }
    // The stale `return new ConstantScoreQuery(union.build());` was dropped: `union` is not
    // declared here, the type does not match, and it made the real return unreachable.
    return Tuple.tuple(concreteIPs, masks);
}

@Override
Expand Down Expand Up @@ -480,6 +493,44 @@ public DocValueFormat docValueFormat(@Nullable String format, ZoneId timeZone) {
}
return DocValueFormat.IP;
}

/**
 * Collects queries one at a time and exposes their union.
 * <p>
 * The first accepted query is kept on its own; a {@code BooleanQuery.Builder} is only created
 * once a second query arrives, at which point every query is added as a {@code SHOULD} clause.
 * {@link #getAsInt()} reports how many times {@link #accept(Query)} has completed.
 */
private static class QueryUnion implements Consumer<Query>, Supplier<Query>, IntSupplier {
    Query first;
    BooleanQuery.Builder union;
    int cnt;

    /**
     * Records one more query: stored as {@code first} while none is held, otherwise added to the
     * disjunction builder. The counter is bumped afterwards.
     */
    @Override
    public void accept(Query query) {
        if (first != null) {
            disjunction().add(query, BooleanClause.Occur.SHOULD);
        } else {
            first = query;
        }
        cnt++;
    }

    /** Returns the builder, creating it and seeding it with {@code first} on first use. */
    private BooleanQuery.Builder disjunction() {
        if (union == null) {
            union = new BooleanQuery.Builder();
            union.add(first, BooleanClause.Occur.SHOULD);
        }
        return union;
    }

    /**
     * @return the built disjunction wrapped in a {@code ConstantScoreQuery} when a builder exists;
     *         otherwise the single stored query; otherwise an empty {@code BooleanQuery}
     */
    @Override
    public Query get() {
        if (union != null) {
            return new ConstantScoreQuery(union.build());
        }
        // no builder: either exactly one query was stored, or nothing was accepted at all
        return first != null ? first : new BooleanQuery.Builder().build();
    }

    /** @return the number of completed {@link #accept(Query)} calls */
    @Override
    public int getAsInt() {
        return cnt;
    }
}
}

/**
Expand Down

0 comments on commit 26ff736

Please sign in to comment.