Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add parent join support for lucene knn #1181

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/main/java/org/opensearch/knn/common/KNNConstants.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,14 @@ public class KNNConstants {
public static final String NAME = "name";
public static final String PARAMETERS = "parameters";
public static final String METHOD_HNSW = "hnsw";
public static final String TYPE = "type";
public static final String TYPE_NESTED = "nested";
public static final String PATH = "path";
public static final String QUERY = "query";
public static final String KNN = "knn";
public static final String VECTOR = "vector";
public static final String K = "k";
public static final String TYPE_KNN_VECTOR = "knn_vector";
public static final String METHOD_PARAMETER_EF_SEARCH = "ef_search";
public static final String METHOD_PARAMETER_EF_CONSTRUCTION = "ef_construction";
public static final String METHOD_PARAMETER_M = "m";
Expand Down
55 changes: 26 additions & 29 deletions src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
import org.apache.lucene.search.KnnByteVectorQuery;
import org.apache.lucene.search.KnnFloatVectorQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.join.BitSetProducer;
import org.apache.lucene.search.join.DiversifyingChildrenByteKnnVectorQuery;
import org.apache.lucene.search.join.DiversifyingChildrenFloatKnnVectorQuery;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryShardContext;
import org.opensearch.knn.index.VectorDataType;
Expand Down Expand Up @@ -87,9 +90,9 @@ public static Query create(CreateQueryRequest createQueryRequest) {
}

if (VectorDataType.BYTE == vectorDataType) {
return getKnnByteVectorQuery(indexName, fieldName, byteVector, k, filterQuery);
return getKnnByteVectorQuery(fieldName, byteVector, k, filterQuery, createQueryRequest.context.getParentFilter());
} else if (VectorDataType.FLOAT == vectorDataType) {
return getKnnFloatVectorQuery(indexName, fieldName, vector, k, filterQuery);
return getKnnFloatVectorQuery(fieldName, vector, k, filterQuery, createQueryRequest.context.getParentFilter());
} else {
throw new IllegalArgumentException(
String.format(
Expand All @@ -102,38 +105,30 @@ public static Query create(CreateQueryRequest createQueryRequest) {
}
}

private static Query getKnnByteVectorQuery(String indexName, String fieldName, byte[] byteVector, int k, Query filterQuery) {
if (filterQuery != null) {
log.debug(
String.format(
Locale.ROOT,
"Creating Lucene k-NN query with filters for index: %s \"\", field: %s \"\", k: %d",
indexName,
fieldName,
k
)
);
/**
* If parentFilter is not null, it is a nested query. Therefore, we return {@link DiversifyingChildrenByteKnnVectorQuery}
* which will dedupe search result per parent so that we can get k parent results at the end.
*/
private static Query getKnnByteVectorQuery(final String fieldName, final byte[] byteVector, final int k, final Query filterQuery, final BitSetProducer parentFilter) {
if (parentFilter == null) {
return new KnnByteVectorQuery(fieldName, byteVector, k, filterQuery);
}
log.debug(String.format("Creating Lucene k-NN query for index: %s \"\", field: %s \"\", k: %d", indexName, fieldName, k));
return new KnnByteVectorQuery(fieldName, byteVector, k);
else {
return new DiversifyingChildrenByteKnnVectorQuery(fieldName, byteVector, filterQuery, k, parentFilter);
}
}

private static Query getKnnFloatVectorQuery(String indexName, String fieldName, float[] floatVector, int k, Query filterQuery) {
if (filterQuery != null) {
log.debug(
String.format(
Locale.ROOT,
"Creating Lucene k-NN query with filters for index: %s \"\", field: %s \"\", k: %d",
indexName,
fieldName,
k
)
);
return new KnnFloatVectorQuery(fieldName, floatVector, k, filterQuery);
/**
* If parentFilter is not null, it is a nested query. Therefore, we return {@link DiversifyingChildrenFloatKnnVectorQuery}
* which will dedupe search result per parent so that we can get k parent results at the end.
*/
private static Query getKnnFloatVectorQuery(final String fieldName, final float[] floatVector, final int k, final Query filterQuery, final BitSetProducer parentFilter) {
if (parentFilter == null) {
return new KnnFloatVectorQuery(fieldName, floatVector, k);
}
else {
return new DiversifyingChildrenFloatKnnVectorQuery(fieldName, floatVector, filterQuery, k, parentFilter);
}
log.debug(String.format("Creating Lucene k-NN query for index: %s \"\", field: %s \"\", k: %d", indexName, fieldName, k));
return new KnnFloatVectorQuery(fieldName, floatVector, k);
}

private static Query getFilterQuery(CreateQueryRequest createQueryRequest) {
Expand Down Expand Up @@ -181,6 +176,8 @@ static class CreateQueryRequest {
@Getter
private int k;
// can be null in cases filter not passed with the knn query
@Getter
public BitSetProducer parentFilter;
private QueryBuilder filter;
// can be null in cases filter not passed with the knn query
private QueryShardContext context;
Expand Down
202 changes: 202 additions & 0 deletions src/test/java/org/opensearch/knn/index/NestedSearchIT.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,202 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index;

import lombok.SneakyThrows;
import org.apache.hc.core5.http.io.entity.EntityUtils;
import org.junit.After;
import org.opensearch.client.Request;
import org.opensearch.client.Response;
import org.opensearch.common.xcontent.XContentFactory;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.core.xcontent.MediaTypeRegistry;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.knn.KNNRestTestCase;
import org.opensearch.knn.index.util.KNNEngine;

import java.io.IOException;
import java.util.List;
import java.util.Map;

import static org.opensearch.knn.common.KNNConstants.DIMENSION;
import static org.opensearch.knn.common.KNNConstants.K;
import static org.opensearch.knn.common.KNNConstants.KNN;
import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE;
import static org.opensearch.knn.common.KNNConstants.KNN_METHOD;
import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW;
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION;
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_M;
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_SPACE_TYPE;
import static org.opensearch.knn.common.KNNConstants.NAME;
import static org.opensearch.knn.common.KNNConstants.PARAMETERS;
import static org.opensearch.knn.common.KNNConstants.PATH;
import static org.opensearch.knn.common.KNNConstants.QUERY;
import static org.opensearch.knn.common.KNNConstants.TYPE;
import static org.opensearch.knn.common.KNNConstants.TYPE_KNN_VECTOR;
import static org.opensearch.knn.common.KNNConstants.TYPE_NESTED;
import static org.opensearch.knn.common.KNNConstants.VECTOR;

public class NestedSearchIT extends KNNRestTestCase {
private static final String INDEX_NAME = "test-index-nested-search";
private static final String FIELD_NAME_NESTED = "test-nested";
private static final String FIELD_NAME_VECTOR = "test-vector";
private static final String PROPERTIES_FIELD = "properties";
private static final int EF_CONSTRUCTION = 128;
private static final int M = 16;
private static final SpaceType SPACE_TYPE = SpaceType.L2;

@After
@SneakyThrows
public final void cleanUp() {
deleteKNNIndex(INDEX_NAME);
}

@SneakyThrows
public void testNestedSearch_whenKIsTwo_thenReturnTwoResults() {
createKnnIndex(2, KNNEngine.LUCENE.getName());

String doc1 = NestedKnnDocBuilder.create(FIELD_NAME_NESTED)
.add(FIELD_NAME_VECTOR, new Float[]{1f, 1f}, new Float[]{1f, 1f})
.build();
addNestedKnnDoc(INDEX_NAME, "1", doc1);

String doc2 = NestedKnnDocBuilder.create(FIELD_NAME_NESTED)
.add(FIELD_NAME_VECTOR, new Float[]{2f, 2f}, new Float[]{2f, 2f})
.build();
addNestedKnnDoc(INDEX_NAME, "2", doc2);

Float[] queryVector = { 1f, 1f };
Response response = queryNestedField(INDEX_NAME, 2, queryVector);

List<Object> hits = (List<Object>) ((Map<String, Object>) createParser(
MediaTypeRegistry.getDefaultMediaType().xContent(),
EntityUtils.toString(response.getEntity())
).map().get("hits")).get("hits");
assertEquals(2, hits.size());
}

/**
* {
* "properties": {
* "test-nested": {
* "type": "nested",
* "properties": {
* "test-vector": {
* "type": "knn_vector",
* "dimension": 3,
* "method": {
* "name": "hnsw",
* "space_type": "l2",
* "engine": "lucene",
* "parameters": {
* "ef_construction": 128,
* "m": 24
* }
* }
* }
* }
* }
* }
* }
*/
private void createKnnIndex(final int dimension, final String engine)
throws Exception {
XContentBuilder builder = XContentFactory.jsonBuilder()
.startObject()
.startObject(PROPERTIES_FIELD)
.startObject(FIELD_NAME_NESTED)
.field(TYPE, TYPE_NESTED)
.startObject(PROPERTIES_FIELD)
.startObject(FIELD_NAME_VECTOR)
.field(TYPE, TYPE_KNN_VECTOR)
.field(DIMENSION, dimension)
.startObject(KNN_METHOD)
.field(NAME, METHOD_HNSW)
.field(METHOD_PARAMETER_SPACE_TYPE, SPACE_TYPE)
.field(KNN_ENGINE, engine)
.startObject(PARAMETERS)
.field(METHOD_PARAMETER_M, M)
.field(METHOD_PARAMETER_EF_CONSTRUCTION, EF_CONSTRUCTION)
.endObject()
.endObject()
.endObject()
.endObject()
.endObject()
.endObject()
.endObject();

String mapping = builder.toString();
createKnnIndex(INDEX_NAME, mapping);
}

@SneakyThrows
private void ingestTestData() {
String doc1 = NestedKnnDocBuilder.create(FIELD_NAME_NESTED)
.add(FIELD_NAME_VECTOR, new Float[]{1f, 1f}, new Float[]{1f, 1f})
.build();
addNestedKnnDoc(INDEX_NAME, "1", doc1);

String doc2 = NestedKnnDocBuilder.create(FIELD_NAME_NESTED)
.add(FIELD_NAME_VECTOR, new Float[]{2f, 2f}, new Float[]{2f, 2f})
.build();
addNestedKnnDoc(INDEX_NAME, "2", doc2);
}

private void addNestedKnnDoc(final String index, final String docId, final String document) throws IOException {
Request request = new Request("POST", "/" + index + "/_doc/" + docId + "?refresh=true");

request.setJsonEntity(document);
client().performRequest(request);

request = new Request("POST", "/" + index + "/_refresh");
Response response = client().performRequest(request);
assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode()));
}

private Response queryNestedField(final String index, final int k, final Object[] vector) throws IOException {
XContentBuilder builder = XContentFactory.jsonBuilder().startObject().startObject(QUERY);
builder.startObject(TYPE_NESTED);
builder.field(PATH, FIELD_NAME_NESTED);
builder.startObject(QUERY).startObject(KNN).startObject(FIELD_NAME_NESTED + "." + FIELD_NAME_VECTOR);
builder.field(VECTOR, vector);
builder.field(K, k);
builder.endObject().endObject().endObject().endObject().endObject().endObject();

Request request = new Request("POST", "/" + index + "/_search");
request.setJsonEntity(builder.toString());

Response response = client().performRequest(request);
assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode()));

return response;
}

private static class NestedKnnDocBuilder {
private XContentBuilder builder;
public NestedKnnDocBuilder(final String fieldName) throws IOException {
builder = XContentFactory.jsonBuilder().startObject().startArray(fieldName);
}

public static NestedKnnDocBuilder create(final String fieldName) throws IOException {
return new NestedKnnDocBuilder(fieldName);
}

public NestedKnnDocBuilder add(final String fieldName, final Object[]... vectors) throws IOException {
for (Object[] vector : vectors) {
builder.startObject();
builder.field(fieldName, vector);
builder.endObject();
}
return this;
}

public String build() throws IOException {
builder.endArray().endObject();
return builder.toString();
}

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,16 @@
import org.apache.lucene.search.KnnFloatVectorQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.TermQuery;
import org.apache.lucene.search.join.BitSetProducer;
import org.apache.lucene.search.join.DiversifyingChildrenByteKnnVectorQuery;
import org.apache.lucene.search.join.DiversifyingChildrenFloatKnnVectorQuery;
import org.mockito.Mockito;
import org.opensearch.index.mapper.MappedFieldType;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryShardContext;
import org.opensearch.index.query.TermQueryBuilder;
import org.opensearch.knn.KNNTestCase;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.util.KNNEngine;

import java.util.Arrays;
Expand All @@ -33,6 +37,7 @@ public class KNNQueryFactoryTests extends KNNTestCase {
private static final Query FILTER_QUERY = new TermQuery(new Term(FILTER_FILED_NAME, FILTER_FILED_VALUE));
private final int testQueryDimension = 17;
private final float[] testQueryVector = new float[testQueryDimension];
private final byte[] testByteQueryVector = new byte[testQueryDimension];
private final String testIndexName = "test-index";
private final String testFieldName = "test-field";
private final int testK = 10;
Expand Down Expand Up @@ -120,4 +125,34 @@ public void testCreateFaissQueryWithFilter_withValidValues_thenSuccess() {
assertEquals(testK, ((KNNQuery) query).getK());
assertEquals(FILTER_QUERY, ((KNNQuery) query).getFilterQuery());
}

public void testCreate_whenLuceneWithParentFilter_thenReturnDiversifyingQuery() {
validateDiversifyingQueryWithParentFilter(VectorDataType.BYTE, DiversifyingChildrenByteKnnVectorQuery.class);
validateDiversifyingQueryWithParentFilter(VectorDataType.FLOAT, DiversifyingChildrenFloatKnnVectorQuery.class);
}
private void validateDiversifyingQueryWithParentFilter(final VectorDataType type, final Class expectedQueryClass) {
List<KNNEngine> luceneDefaultQueryEngineList = Arrays.stream(KNNEngine.values())
.filter(knnEngine -> !KNNEngine.getEnginesThatCreateCustomSegmentFiles().contains(knnEngine))
.collect(Collectors.toList());
for (KNNEngine knnEngine : luceneDefaultQueryEngineList) {
QueryShardContext mockQueryShardContext = mock(QueryShardContext.class);
MappedFieldType testMapper = mock(MappedFieldType.class);
when(mockQueryShardContext.fieldMapper(any())).thenReturn(testMapper);
BitSetProducer parentFilter = mock(BitSetProducer.class);
when(mockQueryShardContext.getParentFilter()).thenReturn(parentFilter);
final KNNQueryFactory.CreateQueryRequest createQueryRequest = KNNQueryFactory.CreateQueryRequest.builder()
.knnEngine(knnEngine)
.indexName(testIndexName)
.fieldName(testFieldName)
.vector(testQueryVector)
.byteVector(testByteQueryVector)
.vectorDataType(type)
.k(testK)
.context(mockQueryShardContext)
.filter(FILTER_QUERY_BUILDER)
.build();
Query query = KNNQueryFactory.create(createQueryRequest);
assertTrue(query.getClass().isAssignableFrom(expectedQueryClass));
}
}
}
Loading