Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

IP field via MultiRangeQuery fix #16200 #16391

Draft
wants to merge 10 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@
import org.apache.lucene.document.SortedSetDocValuesField;
import org.apache.lucene.document.StoredField;
import org.apache.lucene.index.SortedSetDocValues;
import org.apache.lucene.sandbox.search.MultiRangeQuery;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.ConstantScoreQuery;
import org.apache.lucene.search.IndexOrDocValuesQuery;
import org.apache.lucene.search.MatchNoDocsQuery;
import org.apache.lucene.search.PointRangeQuery;
Expand All @@ -47,6 +51,7 @@
import org.opensearch.common.collect.Tuple;
import org.opensearch.common.logging.DeprecationLogger;
import org.opensearch.common.network.InetAddresses;
import org.opensearch.common.network.NetworkAddress;
import org.opensearch.index.fielddata.IndexFieldData;
import org.opensearch.index.fielddata.ScriptDocValues;
import org.opensearch.index.fielddata.plain.SortedSetOrdinalsIndexFieldData;
Expand All @@ -58,6 +63,7 @@
import java.io.IOException;
import java.net.InetAddress;
import java.time.ZoneId;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
Expand Down Expand Up @@ -263,29 +269,66 @@ public Query termQuery(Object value, @Nullable QueryShardContext context) {
return query;
}

// @TODO check strings, byterefs, inetaddresses for concrete and masks
@Override
public Query termsQuery(List<?> values, QueryShardContext context) {
failIfNotIndexedAndNoDocValues();
InetAddress[] addresses = new InetAddress[values.size()];
int i = 0;
for (Object value : values) {
InetAddress address;
List<InetAddress> concreteIPs = new ArrayList<>();
List<Query> ranges = new ArrayList<>();
IpMultiRangeQueryBuilder multiRange = new IpMultiRangeQueryBuilder(name());
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

MultiRangeQuery handles the points index, but there is no special case for a doc-values-only field.

boolean multiRangeIsEmpty = true;
for (final Object value : values) {
if (value instanceof InetAddress) {
address = (InetAddress) value;
concreteIPs.add((InetAddress) value);
} else {
if (value instanceof BytesRef) {
value = ((BytesRef) value).utf8ToString();
}
if (value.toString().contains("/")) {
final String strVal = (value instanceof BytesRef) ? ((BytesRef) value).utf8ToString() : value.toString();
if (strVal.contains("/")) {
// the `terms` query contains some prefix queries, so we cannot create a set query
// and need to fall back to a disjunction of `term` queries
return super.termsQuery(values, context);
// Query query = termQuery(strVal, context);
final Tuple<InetAddress, Integer> cidr = InetAddresses.parseCidr(strVal);
PointRangeQuery query = (PointRangeQuery) InetAddressPoint.newPrefixQuery(name(), cidr.v1(), cidr.v2());

// would be great to have union on ranges over bare points
// ranges.add(query);
multiRange.add(query.getLowerPoint(), query.getUpperPoint());
multiRangeIsEmpty = false;
} else {
concreteIPs.add(InetAddresses.forString(strVal));
}
address = InetAddresses.forString(value.toString());
}
addresses[i++] = address;
}
return InetAddressPoint.newSetQuery(name(), addresses);
if (!multiRangeIsEmpty) {
ranges.add(multiRange.build());
}
if (!concreteIPs.isEmpty()) {
Supplier<Query> pointsQuery;
pointsQuery = () -> concreteIPs.size() == 1
? InetAddressPoint.newExactQuery(name(), concreteIPs.iterator().next())
: InetAddressPoint.newSetQuery(name(), concreteIPs.toArray(new InetAddress[0]));
if (hasDocValues()) {
List<BytesRef> set = new ArrayList<>(concreteIPs.size());
for (final InetAddress address : concreteIPs) {
set.add(new BytesRef(InetAddressPoint.encode(address)));
}
Query dvQuery = SortedSetDocValuesField.newSlowSetQuery(name(), set);
if (!isSearchable()) {
pointsQuery = () -> dvQuery;
} else {
Supplier<Query> wrap = pointsQuery;
pointsQuery = () -> new IndexOrDocValuesQuery(wrap.get(), dvQuery);
}
}
ranges.add(pointsQuery.get());
}
if (ranges.size() == 1) {
return ranges.iterator().next(); // CSQ?
}
BooleanQuery.Builder union = new BooleanQuery.Builder();
for (Query q : ranges) {
union.add(q, BooleanClause.Occur.SHOULD);
}
return new ConstantScoreQuery(union.build());
}

@Override
Expand Down Expand Up @@ -439,6 +482,30 @@ public DocValueFormat docValueFormat(@Nullable String format, ZoneId timeZone) {
}
}

/**
 * Builder that accumulates IP address ranges for a single field and produces one
 * {@link MultiRangeQuery} over all of the collected range clauses.
 */
public static class IpMultiRangeQueryBuilder extends MultiRangeQuery.Builder {

    /**
     * @param field name of the field the collected ranges are evaluated against
     */
    public IpMultiRangeQueryBuilder(String field) {
        super(field, InetAddressPoint.BYTES, 1);
    }

    /**
     * Registers one more range clause, given as a pair of addresses.
     *
     * @param lower address encoded as the first value of the clause
     * @param upper address encoded as the second value of the clause
     * @return this builder, to allow chaining
     */
    public IpMultiRangeQueryBuilder add(InetAddress lower, InetAddress upper) {
        final byte[] encodedLower = InetAddressPoint.encode(lower);
        final byte[] encodedUpper = InetAddressPoint.encode(upper);
        final MultiRangeQuery.RangeClause clause = new MultiRangeQuery.RangeClause(encodedLower, encodedUpper);
        add(clause);
        return this;
    }

    /**
     * Creates the query from the clauses added so far. The returned query renders
     * its encoded values as formatted network addresses in {@code toString}.
     */
    @Override
    public MultiRangeQuery build() {
        return new MultiRangeQuery(field, numDims, bytesPerDim, clauses) {
            @Override
            protected String toString(int dimension, byte[] value) {
                final InetAddress decoded = InetAddressPoint.decode(value);
                return NetworkAddress.format(decoded);
            }
        };
    }
}

private final boolean indexed;
private final boolean hasDocValues;
private final boolean stored;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ public void testTermQuery() {
}

public void testTermsQuery() {
MappedFieldType ft = new IpFieldMapper.IpFieldType("field");
MappedFieldType ft = new IpFieldMapper.IpFieldType("field", true, false, false, null, Collections.emptyMap());

assertEquals(
InetAddressPoint.newSetQuery("field", InetAddresses.forString("::2"), InetAddresses.forString("::5")),
Expand All @@ -131,8 +131,8 @@ public void testTermsQuery() {
// if the list includes a prefix query we fallback to a bool query
assertEquals(
new ConstantScoreQuery(
new BooleanQuery.Builder().add(ft.termQuery("::42", null), Occur.SHOULD)
.add(ft.termQuery("::2/16", null), Occur.SHOULD)
new BooleanQuery.Builder().add(ft.termQuery("::2/16", null), Occur.SHOULD)
.add(ft.termQuery("::42", null), Occur.SHOULD)
.build()
),
ft.termsQuery(Arrays.asList("::42", "::2/16"), null)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

package org.opensearch.search;

import org.opensearch.action.bulk.BulkRequestBuilder;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.common.network.InetAddresses;
import org.opensearch.common.xcontent.XContentFactory;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.test.OpenSearchSingleNodeTestCase;
import org.hamcrest.MatcherAssert;

import java.io.IOException;
import java.net.InetAddress;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;

import static org.opensearch.action.support.WriteRequest.RefreshPolicy.IMMEDIATE;
import static org.hamcrest.Matchers.equalTo;

/**
 * Single-node test that indexes randomly generated IP addresses and verifies that one
 * {@code terms} query mixing plain addresses and CIDR blocks returns the expected number of hits.
 */
public class SearchIpFieldTermsTests extends OpenSearchSingleNodeTestCase {

    public static final boolean IPv4_ONLY = true;
    static String defaultIndexName = "test";

    /**
     * Indexes one document per generated address, queries with a term derived from each of those
     * addresses (either the address itself or a CIDR block built from it), adds a few extra random
     * documents, and checks the total hit count.
     */
    public void testMassive() throws Exception {
        XContentBuilder xcb = createMapping();
        client().admin().indices().prepareCreate(defaultIndexName).setMapping(xcb).get();
        ensureGreen();

        BulkRequestBuilder bulkRequestBuilder = client().prepareBulk();

        int cidrs = 0;
        int ips = 0;
        List<String> toQuery = new ArrayList<>();
        // Every document indexed here is matched by its own term: the term is either the address
        // itself or a CIDR block derived from that very address.
        for (int i = 0; ips <= 10240 && cidrs <= 1024 && i < 1000000; i++) {
            final String ip;
            final int prefix;
            if (IPv4_ONLY) {
                ip = generateRandomIPv4();
                prefix = 8 + random().nextInt(24); // CIDR prefix for IPv4
            } else {
                ip = generateRandomIPv6();
                prefix = 32 + random().nextInt(97); // CIDR prefix for IPv6
            }

            bulkRequestBuilder.add(client().prepareIndex(defaultIndexName).setSource(Map.of("addr", ip)));

            final String termToQuery;
            if (random().nextBoolean()) {
                termToQuery = ip + "/" + prefix;
                cidrs++;
            } else {
                termToQuery = ip;
                ips++;
            }
            toQuery.add(termToQuery);
        }

        // Evaluate the randomized bound once; calling atLeast() in the loop condition would draw a
        // new bound on every iteration.
        final int extraDocs = atLeast(100);
        int addMatches = 0;
        for (int i = 0; i < extraDocs; i++) {
            final String ip;
            if (IPv4_ONLY) {
                ip = generateRandomIPv4();
            } else {
                ip = generateRandomIPv6();
            }
            bulkRequestBuilder.add(client().prepareIndex(defaultIndexName).setSource(Map.of("addr", ip)));
            boolean match = false;
            for (String termQ : toQuery) {
                boolean isCidr = termQ.contains("/");
                if ((isCidr && isIPInCIDR(ip, termQ)) || (!isCidr && termQ.equals(ip))) {
                    match = true;
                    break;
                }
            }
            if (match) {
                addMatches++;
            } else {
                break; // single mismatch is enough.
            }
        }

        // Fail right here if any document was rejected instead of surfacing it later as a
        // hard-to-explain hit-count mismatch.
        assertFalse("bulk indexing reported failures", bulkRequestBuilder.setRefreshPolicy(IMMEDIATE).get().hasFailures());

        // NOTE(review): the mapping also declares the sub-fields "addr.idx" and "addr.dv", but only
        // "addr" is queried here - confirm whether the sub-fields should be exercised as well.
        SearchResponse result = client().prepareSearch(defaultIndexName)
            .setQuery(QueryBuilders.termsQuery("addr", toQuery))
            .setTrackTotalHits(true) // request an exact count; the loop above permits more than 10,000 documents
            .get();
        MatcherAssert.assertThat(Objects.requireNonNull(result.getHits().getTotalHits()).value, equalTo((long) cidrs + ips + addMatches));
    }

    // Converts an IP string (either IPv4 or IPv6) to its raw address bytes (4 or 16 bytes)
    private static byte[] ipToBytes(String ip) {
        InetAddress inetAddress = InetAddresses.forString(ip);
        return inetAddress.getAddress();
    }

    // Widens a 4-byte IPv4 address to its 16-byte IPv4-mapped IPv6 form (::ffff:a.b.c.d);
    // 16-byte input is returned unchanged
    private static byte[] toIpv6Bytes(byte[] raw) {
        if (raw.length == 16) {
            return raw;
        }
        byte[] mapped = new byte[16];
        mapped[10] = (byte) 0xFF;
        mapped[11] = (byte) 0xFF;
        System.arraycopy(raw, 0, mapped, 12, raw.length);
        return mapped;
    }

    /**
     * Checks if an IP is within a given CIDR block (works for both IPv4 and IPv6, also when the
     * two arguments belong to different address families).
     *
     * @param ip   address literal to test
     * @param cidr block in {@code address/prefixLength} notation
     * @return {@code true} if the leading {@code prefixLength} bits of both addresses agree
     * @throws IllegalArgumentException if {@code cidr} is not of the form {@code address/prefixLength}
     *                                  or the prefix length is out of range for the block's address family
     */
    private static boolean isIPInCIDR(String ip, String cidr) {
        final String[] cidrParts = cidr.split("/");
        if (cidrParts.length != 2) {
            throw new IllegalArgumentException("Expected [ip/prefix] but was [" + cidr + "]");
        }
        final byte[] rawNetwork = ipToBytes(cidrParts[0]);
        final int prefixLength = Integer.parseInt(cidrParts[1]);
        if (prefixLength < 0 || prefixLength > 8 * rawNetwork.length) {
            throw new IllegalArgumentException("Illegal prefix length [" + prefixLength + "] in [" + cidr + "]");
        }

        // Compare in a common 16-byte space so that mixed-family inputs neither run past the end
        // of the shorter array nor compare unrelated bytes.
        final byte[] ipBytes = toIpv6Bytes(ipToBytes(ip));
        final byte[] networkBytes = toIpv6Bytes(rawNetwork);
        // An IPv4 prefix is shifted by the 96 bits that precede the address in the mapped form.
        final int prefixBits = prefixLength + 8 * (networkBytes.length - rawNetwork.length);

        // Calculate how many full bytes and how many bits are in the mask
        final int fullBytes = prefixBits / 8;
        final int extraBits = prefixBits % 8;

        // Compare full bytes
        for (int i = 0; i < fullBytes; i++) {
            if (ipBytes[i] != networkBytes[i]) {
                return false;
            }
        }

        // Compare extra bits (if any)
        if (extraBits > 0) {
            final int mask = (0xFF << (8 - extraBits)) & 0xFF;
            return (ipBytes[fullBytes] & mask) == (networkBytes[fullBytes] & mask);
        }

        return true;
    }

    // Generate a random IPv4 address
    private String generateRandomIPv4() {
        return String.join(
            ".",
            String.valueOf(random().nextInt(256)),
            String.valueOf(random().nextInt(256)),
            String.valueOf(random().nextInt(256)),
            String.valueOf(random().nextInt(256))
        );
    }

    // Generate a random IPv6 address
    private static String generateRandomIPv6() {
        StringBuilder ipv6 = new StringBuilder();
        for (int i = 0; i < 8; i++) {
            ipv6.append(Integer.toHexString(random().nextInt(0xFFFF + 1)));
            if (i < 7) {
                ipv6.append(":");
            }
        }
        return ipv6.toString();
    }

    // Builds the index mapping: an "addr" ip field with two sub-fields, "idx" (doc_values
    // disabled) and "dv" (index disabled)
    private XContentBuilder createMapping() throws IOException {
        return XContentFactory.jsonBuilder()
            .startObject()
            .startObject("properties")
            .startObject("addr")
            .field("type", "ip")
            .startObject("fields")
            .startObject("idx")
            .field("type", "ip")
            .field("doc_values", false)
            .endObject()
            .startObject("dv")
            .field("type", "ip")
            .field("index", false)
            .endObject()
            .endObject()
            .endObject()
            .endObject()
            .endObject();
    }
}
Loading