Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

IP field via MultiRangeQuery fix #16200 #16391

Draft
wants to merge 10 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
157 changes: 141 additions & 16 deletions server/src/main/java/org/opensearch/index/mapper/IpFieldMapper.java
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,12 @@
import org.apache.lucene.document.SortedSetDocValuesField;
import org.apache.lucene.document.StoredField;
import org.apache.lucene.index.SortedSetDocValues;
import org.apache.lucene.sandbox.search.MultiRangeQuery;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.ConstantScoreQuery;
import org.apache.lucene.search.IndexOrDocValuesQuery;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.MatchNoDocsQuery;
import org.apache.lucene.search.PointRangeQuery;
import org.apache.lucene.search.Query;
Expand All @@ -47,6 +52,7 @@
import org.opensearch.common.collect.Tuple;
import org.opensearch.common.logging.DeprecationLogger;
import org.opensearch.common.network.InetAddresses;
import org.opensearch.common.network.NetworkAddress;
import org.opensearch.index.fielddata.IndexFieldData;
import org.opensearch.index.fielddata.ScriptDocValues;
import org.opensearch.index.fielddata.plain.SortedSetOrdinalsIndexFieldData;
Expand All @@ -58,11 +64,14 @@
import java.io.IOException;
import java.net.InetAddress;
import java.time.ZoneId;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.function.BiFunction;
import java.util.function.Consumer;
import java.util.function.IntSupplier;
import java.util.function.Supplier;

/**
Expand Down Expand Up @@ -256,36 +265,90 @@ public Query termQuery(Object value, @Nullable QueryShardContext context) {
String term = value.toString();
if (term.contains("/")) {
final Tuple<InetAddress, Integer> cidr = InetAddresses.parseCidr(term);
return InetAddressPoint.newPrefixQuery(name(), cidr.v1(), cidr.v2());
PointRangeQuery pointsRange = (PointRangeQuery) InetAddressPoint.newPrefixQuery(name(), cidr.v1(), cidr.v2());
return SortedSetDocValuesField.newSlowRangeQuery(
name(),
new BytesRef(pointsRange.getLowerPoint()),
new BytesRef(pointsRange.getUpperPoint()),
true,
true
);
}
return SortedSetDocValuesField.newSlowExactQuery(name(), new BytesRef(((PointRangeQuery) query).getLowerPoint()));
}
return query;
}

// @TODO check that strings, BytesRefs and InetAddresses are handled for both concrete addresses and masks
/**
 * Builds a query for a list of values: the values are split by {@code splitIpsAndMasks},
 * the concrete addresses are passed to {@code convertIps} and the masks to {@code convertMasks},
 * and everything is accumulated in a single {@link QueryUnion} whose result is returned.
 * <p>
 * NOTE(review): this span was taken from a diff view; the array/loop lines below look like the
 * replaced implementation shown alongside its replacement -- confirm against the actual file.
 */
@Override
public Query termsQuery(List<?> values, QueryShardContext context) {
failIfNotIndexedAndNoDocValues();
InetAddress[] addresses = new InetAddress[values.size()];
int i = 0;
for (Object value : values) {
InetAddress address;
if (value instanceof InetAddress) {
address = (InetAddress) value;
Tuple<List<InetAddress>, List<String>> ipsMasks = splitIpsAndMasks(values);
QueryUnion combiner = new QueryUnion();
convertIps(ipsMasks.v1(), combiner);
// hand over the number of queries accumulated so far; convertMasks weighs it against the max clause count
convertMasks(ipsMasks.v2(), context, combiner, combiner.getAsInt());
return combiner.get();
}

/**
 * Hands queries for the given CIDR masks to {@code combiner}; does nothing when {@code masks} is empty.
 * When {@code masks.size() + clauses} reaches {@code IndexSearcher.getMaxClauseCount() - 1} and the
 * field is searchable, all masks are folded into one query built by {@link IpMultiRangeQueryBuilder};
 * otherwise each mask is converted with {@code termQuery} and handed over individually.
 * <p>
 * NOTE(review): this span was taken from a diff view; the {@code value instanceof BytesRef} lines in
 * the else branch look like leftovers of the replaced implementation -- confirm against the actual file.
 *
 * @param masks    CIDR strings, each parsed with {@code InetAddresses.parseCidr}
 * @param context  passed through to {@code termQuery}
 * @param combiner receives every query produced here
 * @param clauses  count added to {@code masks.size()} for the clause-limit comparison
 */
private void convertMasks(List<String> masks, QueryShardContext context, Consumer<Query> combiner, int clauses) {
if (!masks.isEmpty()) {
// best effort: when one clause per mask would approach the max clause count, collapse the masks into one query
if (masks.size() + clauses >= IndexSearcher.getMaxClauseCount() - 1 && isSearchable()) {
IpMultiRangeQueryBuilder multiRange = new IpMultiRangeQueryBuilder(name());
for (String strVal : masks) {
final Tuple<InetAddress, Integer> cidr = InetAddresses.parseCidr(strVal);
// the prefix query is only built to obtain the lower and upper points of the mask
PointRangeQuery query = (PointRangeQuery) InetAddressPoint.newPrefixQuery(name(), cidr.v1(), cidr.v2());

multiRange.add(query.getLowerPoint(), query.getUpperPoint());
}
combiner.accept(multiRange.build());
} else {
if (value instanceof BytesRef) {
value = ((BytesRef) value).utf8ToString();
for (String strVal : masks) {
combiner.accept(termQuery(strVal, context));
}
}
}
}

/**
 * Hands one query covering the given concrete addresses to {@code combiner}; does nothing when the
 * list is empty. The points query is an exact query for a single address and a set query otherwise.
 * If the field has doc values, a doc-values set query is built as well: it is used alone when the
 * field is not searchable, and paired with the points query in an {@link IndexOrDocValuesQuery}
 * when it is.
 * <p>
 * NOTE(review): this span was taken from a diff view; the {@code value.toString().contains("/")}
 * lines look like leftovers of the replaced implementation -- confirm against the actual file.
 *
 * @param inetAddresses concrete addresses to match
 * @param combiner      receives the resulting query
 */
private void convertIps(List<InetAddress> inetAddresses, Consumer<Query> combiner) {
if (!inetAddresses.isEmpty()) {
// kept as a supplier so the points query is only created if it ends up being used
Supplier<Query> pointsQuery;
pointsQuery = () -> inetAddresses.size() == 1
? InetAddressPoint.newExactQuery(name(), inetAddresses.iterator().next())
: InetAddressPoint.newSetQuery(name(), inetAddresses.toArray(new InetAddress[0]));
if (hasDocValues()) {
List<BytesRef> set = new ArrayList<>(inetAddresses.size());
for (final InetAddress address : inetAddresses) {
set.add(new BytesRef(InetAddressPoint.encode(address)));
}
Query dvQuery = SortedSetDocValuesField.newSlowSetQuery(name(), set);
if (!isSearchable()) {
// not searchable: the doc-values query replaces the points query entirely
pointsQuery = () -> dvQuery;
} else {
Supplier<Query> wrap = pointsQuery;
pointsQuery = () -> new IndexOrDocValuesQuery(wrap.get(), dvQuery);
}
if (value.toString().contains("/")) {
// the `terms` query contains some prefix queries, so we cannot create a set query
// and need to fall back to a disjunction of `term` queries
return super.termsQuery(values, context);
}
combiner.accept(pointsQuery.get());
}
}

/**
 * Partitions the given values into concrete addresses and mask strings.
 * {@link InetAddress} values are taken as they are; any other value is turned into a string
 * ({@code utf8ToString()} for a {@link BytesRef}, {@code toString()} otherwise) and treated as a
 * mask when it contains {@code "/"}, or parsed with {@code InetAddresses.forString} when it does not.
 * <p>
 * NOTE(review): this span was taken from a diff view; the {@code address}/{@code addresses} lines and
 * the first {@code return} look like leftovers of the replaced implementation -- confirm against the
 * actual file.
 *
 * @param values values to partition
 * @return a tuple of (concrete addresses, mask strings)
 */
private static Tuple<List<InetAddress>, List<String>> splitIpsAndMasks(List<?> values) {
List<InetAddress> concreteIPs = new ArrayList<>();
List<String> masks = new ArrayList<>();
for (final Object value : values) {
if (value instanceof InetAddress) {
concreteIPs.add((InetAddress) value);
} else {
final String strVal = (value instanceof BytesRef) ? ((BytesRef) value).utf8ToString() : value.toString();
// a slash marks the value as a mask rather than a single address
if (strVal.contains("/")) {
masks.add(strVal);
} else {
concreteIPs.add(InetAddresses.forString(strVal));
}
address = InetAddresses.forString(value.toString());
}
addresses[i++] = address;
}
return InetAddressPoint.newSetQuery(name(), addresses);
return Tuple.tuple(concreteIPs, masks);
}

@Override
Expand Down Expand Up @@ -437,6 +500,68 @@ public DocValueFormat docValueFormat(@Nullable String format, ZoneId timeZone) {
}
return DocValueFormat.IP;
}

/**
 * Accumulates queries and exposes their disjunction.
 * <ul>
 *   <li>{@link #accept(Query)} registers one more query;</li>
 *   <li>{@link #get()} returns the combined query: the single registered query as is, a
 *       {@link ConstantScoreQuery} over a {@link BooleanQuery} of {@code SHOULD} clauses when
 *       there are several, or a clause-less {@link BooleanQuery} when there are none;</li>
 *   <li>{@link #getAsInt()} reports how many queries have been registered.</li>
 * </ul>
 */
private static class QueryUnion implements Consumer<Query>, Supplier<Query>, IntSupplier {
    // first registered query; returned untouched while it is the only one
    Query first;
    // created lazily once a second query arrives
    BooleanQuery.Builder union;
    // number of completed accept() calls
    int cnt;

    @Override
    public void accept(Query query) {
        if (first != null) {
            if (union == null) {
                // second query: start the disjunction and seed it with the first one
                union = new BooleanQuery.Builder();
                addOptional(first);
            }
            addOptional(query);
        } else {
            first = query;
        }
        cnt++;
    }

    /** Appends the clause to the disjunction as a {@code SHOULD} clause. */
    private void addOptional(Query clause) {
        union.add(clause, BooleanClause.Occur.SHOULD);
    }

    @Override
    public Query get() {
        if (union != null) {
            return new ConstantScoreQuery(union.build());
        }
        // nothing registered (no matches): fall back to a boolean query without clauses
        return first != null ? first : new BooleanQuery.Builder().build();
    }

    @Override
    public int getAsInt() {
        return cnt;
    }
}
}

/**
 * Builder for a {@link MultiRangeQuery} that is the union of several IP address ranges.
 * The parent builder is configured with {@link InetAddressPoint#BYTES} and {@code 1}, and the
 * query it builds renders range bounds as formatted IP addresses in {@code toString}.
 */
public static class IpMultiRangeQueryBuilder extends MultiRangeQuery.Builder {
    /**
     * @param field name of the field the ranges apply to
     */
    public IpMultiRangeQueryBuilder(String field) {
        super(field, InetAddressPoint.BYTES, 1);
    }

    /**
     * Adds one range given by its two bounding addresses.
     *
     * @param lower lower bound of the range
     * @param upper upper bound of the range
     * @return this builder
     */
    public IpMultiRangeQueryBuilder add(InetAddress lower, InetAddress upper) {
        final byte[] encodedLower = InetAddressPoint.encode(lower);
        final byte[] encodedUpper = InetAddressPoint.encode(upper);
        final MultiRangeQuery.RangeClause range = new MultiRangeQuery.RangeClause(encodedLower, encodedUpper);
        add(range);
        return this;
    }

    /**
     * @return a {@link MultiRangeQuery} over the clauses collected so far
     */
    @Override
    public MultiRangeQuery build() {
        return new MultiRangeQuery(field, numDims, bytesPerDim, clauses) {
            @Override
            protected String toString(int dimension, byte[] value) {
                // decode the point bytes back to an address so the bound prints as an IP
                final InetAddress address = InetAddressPoint.decode(value);
                return NetworkAddress.format(address);
            }
        };
    }
}

private final boolean indexed;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ public void testTermQuery() {
}

public void testTermsQuery() {
MappedFieldType ft = new IpFieldMapper.IpFieldType("field");
MappedFieldType ft = new IpFieldMapper.IpFieldType("field", true, false, false, null, Collections.emptyMap());

assertEquals(
InetAddressPoint.newSetQuery("field", InetAddresses.forString("::2"), InetAddresses.forString("::5")),
Expand All @@ -129,14 +129,12 @@ public void testTermsQuery() {
);

// if the list includes a prefix query, we fall back to a bool query
assertEquals(
new ConstantScoreQuery(
new BooleanQuery.Builder().add(ft.termQuery("::42", null), Occur.SHOULD)
.add(ft.termQuery("::2/16", null), Occur.SHOULD)
.build()
),
ft.termsQuery(Arrays.asList("::42", "::2/16"), null)
);
Query actual = ft.termsQuery(Arrays.asList("::42", "::2/16"), null);
assertTrue(actual instanceof ConstantScoreQuery);
assertTrue(((ConstantScoreQuery) actual).getQuery() instanceof BooleanQuery);
BooleanQuery bq = (BooleanQuery) ((ConstantScoreQuery) actual).getQuery();
assertEquals(2, bq.clauses().size());
assertTrue(bq.clauses().stream().allMatch(c -> c.getOccur() == Occur.SHOULD));
}

public void testRangeQuery() {
Expand Down
Loading
Loading