Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Enhancement] Adds method_parameters in neural search query to support ef_search #814

Merged
merged 1 commit into from
Jul 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
## [Unreleased 2.x](https://github.com/opensearch-project/neural-search/compare/2.15...2.x)
### Features
### Enhancements
* Adds dynamic knn query parameters ef_search and nprobes [#814](https://github.com/opensearch-project/neural-search/pull/814/)
### Bug Fixes
- Fix for missing HybridQuery results when concurrent segment search is enabled ([#800](https://github.com/opensearch-project/neural-search/pull/800))
### Infrastructure
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import java.util.Arrays;
import java.util.List;
import java.util.Map;

import org.opensearch.index.query.MatchQueryBuilder;
import static org.opensearch.neuralsearch.util.TestUtils.getModelId;
import static org.opensearch.neuralsearch.util.TestUtils.NODES_BWC_CLUSTER;
Expand Down Expand Up @@ -69,6 +70,7 @@ private void validateNormalizationProcessor(final String fileName, final String
loadModel(modelId);
addDocuments(getIndexNameForTest(), false);
validateTestIndex(modelId, getIndexNameForTest(), searchPipelineName);
validateTestIndex(modelId, getIndexNameForTest(), searchPipelineName, Map.of("ef_search", 100));
} finally {
wipeOfTestResources(getIndexNameForTest(), pipelineName, modelId, searchPipelineName);
}
Expand Down Expand Up @@ -96,10 +98,14 @@ private void createSearchPipeline(final String pipelineName) {
);
}

private void validateTestIndex(final String modelId, final String index, final String searchPipeline) throws Exception {
/**
 * Validates the test index without supplying any method parameters.
 * Delegates to the four-argument overload, passing {@code null} for {@code methodParameters}.
 *
 * @param modelId id of the model forwarded to the four-argument overload
 * @param index name of the index to validate
 * @param searchPipeline name of the search pipeline forwarded to the four-argument overload
 */
private void validateTestIndex(final String modelId, final String index, final String searchPipeline) {
validateTestIndex(modelId, index, searchPipeline, null);
}

private void validateTestIndex(final String modelId, final String index, final String searchPipeline, Map<String, ?> methodParameters) {
int docCount = getDocCount(index);
assertEquals(6, docCount);
HybridQueryBuilder hybridQueryBuilder = getQueryBuilder(modelId);
HybridQueryBuilder hybridQueryBuilder = getQueryBuilder(modelId, methodParameters);
Map<String, Object> searchResponseAsMap = search(index, hybridQueryBuilder, null, 1, Map.of("search_pipeline", searchPipeline));
assertNotNull(searchResponseAsMap);
int hits = getHitCount(searchResponseAsMap);
Expand All @@ -110,12 +116,15 @@ private void validateTestIndex(final String modelId, final String index, final S
}
}

private HybridQueryBuilder getQueryBuilder(final String modelId) {
private HybridQueryBuilder getQueryBuilder(final String modelId, Map<String, ?> methodParameters) {
NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder();
neuralQueryBuilder.fieldName("passage_embedding");
neuralQueryBuilder.modelId(modelId);
neuralQueryBuilder.queryText(QUERY);
neuralQueryBuilder.k(5);
if (methodParameters != null) {
neuralQueryBuilder.methodParameters(methodParameters);
}

MatchQueryBuilder matchQueryBuilder = new MatchQueryBuilder("text", QUERY);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ private void validateIndexQuery(final String modelId) {
null,
0.01f,
null,
null,
null
);
Map<String, Object> responseWithMinScoreQuery = search(getIndexNameForTest(), neuralQueryBuilderWithMinScoreQuery, 1);
Expand All @@ -74,6 +75,7 @@ private void validateIndexQuery(final String modelId) {
100000f,
null,
null,
null,
null
);
Map<String, Object> responseWithMaxDistanceQuery = search(getIndexNameForTest(), neuralQueryBuilderWithMaxDistanceQuery, 1);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ private void validateTestIndex(final String modelId) throws Exception {
null,
null,
null,
null,
null
);
Map<String, Object> response = search(getIndexNameForTest(), neuralQueryBuilder, 1);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ public void testNormalizationProcessor_whenIndexWithMultipleShards_E2EFlow() thr
loadModel(modelId);
addDocument(getIndexNameForTest(), "2", TEST_FIELD, TEXT_UPGRADED, null, null);
validateTestIndexOnUpgrade(totalDocsCountUpgraded, modelId);
validateTestIndexOnUpgrade(totalDocsCountUpgraded, modelId, Map.of("ef_search", 100));
} finally {
wipeOfTestResources(getIndexNameForTest(), PIPELINE_NAME, modelId, SEARCH_PIPELINE_NAME);
}
Expand All @@ -83,10 +84,15 @@ public void testNormalizationProcessor_whenIndexWithMultipleShards_E2EFlow() thr
}

/**
 * Validates the upgraded test index without supplying any method parameters.
 * Delegates to the three-argument overload, passing {@code null} for {@code methodParameters}.
 *
 * @param numberOfDocs document count forwarded to the three-argument overload
 * @param modelId id of the model forwarded to the three-argument overload
 * @throws Exception if the delegated validation throws
 */
private void validateTestIndexOnUpgrade(final int numberOfDocs, final String modelId) throws Exception {
validateTestIndexOnUpgrade(numberOfDocs, modelId, null);
}

private void validateTestIndexOnUpgrade(final int numberOfDocs, final String modelId, Map<String, ?> methodParameters)
throws Exception {
int docCount = getDocCount(getIndexNameForTest());
assertEquals(numberOfDocs, docCount);
loadModel(modelId);
HybridQueryBuilder hybridQueryBuilder = getQueryBuilder(modelId);
HybridQueryBuilder hybridQueryBuilder = getQueryBuilder(modelId, methodParameters);
Map<String, Object> searchResponseAsMap = search(
getIndexNameForTest(),
hybridQueryBuilder,
Expand All @@ -103,12 +109,15 @@ private void validateTestIndexOnUpgrade(final int numberOfDocs, final String mod
}
}

private HybridQueryBuilder getQueryBuilder(final String modelId) {
private HybridQueryBuilder getQueryBuilder(final String modelId, final Map<String, ?> methodParameters) {
NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder();
neuralQueryBuilder.fieldName("passage_embedding");
neuralQueryBuilder.modelId(modelId);
neuralQueryBuilder.queryText(QUERY);
neuralQueryBuilder.k(5);
if (methodParameters != null) {
neuralQueryBuilder.methodParameters(methodParameters);
}

MatchQueryBuilder matchQueryBuilder = new MatchQueryBuilder("text", QUERY);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ private void validateIndexQueryOnUpgrade(final int numberOfDocs, final String mo
null,
0.01f,
null,
null,
null
);
Map<String, Object> responseWithMinScore = search(getIndexNameForTest(), neuralQueryBuilderWithMinScoreQuery, 1);
Expand All @@ -100,6 +101,7 @@ private void validateIndexQueryOnUpgrade(final int numberOfDocs, final String mo
100000f,
null,
null,
null,
null
);
Map<String, Object> responseWithMaxScore = search(getIndexNameForTest(), neuralQueryBuilderWithMaxDistanceQuery, 1);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ private void validateTestIndexOnUpgrade(final int numberOfDocs, final String mod
null,
null,
null,
null,
null
);
Map<String, Object> responseWithKQuery = search(getIndexNameForTest(), neuralQueryBuilderWithKQuery, 1);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.neuralsearch.common;

import com.google.common.collect.ImmutableMap;
import org.opensearch.Version;
import org.opensearch.knn.index.IndexUtil;
import org.opensearch.neuralsearch.util.NeuralSearchClusterUtil;

import java.util.Map;

import static org.opensearch.knn.index.query.KNNQueryBuilder.MAX_DISTANCE_FIELD;
import static org.opensearch.knn.index.query.KNNQueryBuilder.MIN_SCORE_FIELD;
import static org.opensearch.neuralsearch.query.NeuralQueryBuilder.MODEL_ID_FIELD;

/**
 * A util class which holds the logic to determine whether the cluster's minimum version, as
 * reported by {@link NeuralSearchClusterUtil#getClusterMinVersion()}, satisfies the minimum
 * version required by a request parameter.
 */
public final class MinClusterVersionUtil {

    private static final Version MINIMAL_SUPPORTED_VERSION_DEFAULT_MODEL_ID = Version.V_2_11_0;
    private static final Version MINIMAL_SUPPORTED_VERSION_RADIAL_SEARCH = Version.V_2_14_0;

    // Note: an entry in this map acts as an override, i.e. it is consulted before
    // IndexUtil.minimalRequiredVersionMap for the same key.
    private static final Map<String, Version> MINIMAL_VERSION_NEURAL = ImmutableMap.<String, Version>builder()
        .put(MODEL_ID_FIELD.getPreferredName(), MINIMAL_SUPPORTED_VERSION_DEFAULT_MODEL_ID)
        .put(MAX_DISTANCE_FIELD.getPreferredName(), MINIMAL_SUPPORTED_VERSION_RADIAL_SEARCH)
        .put(MIN_SCORE_FIELD.getPreferredName(), MINIMAL_SUPPORTED_VERSION_RADIAL_SEARCH)
        .build();

    /** Static utility holder; not meant to be instantiated. */
    private MinClusterVersionUtil() {}

    /**
     * Checks the cluster's minimum version against the version required for default model id support.
     *
     * @return {@code true} if the cluster's minimum version is on or after 2.11.0
     */
    public static boolean isClusterOnOrAfterMinReqVersionForDefaultModelIdSupport() {
        return NeuralSearchClusterUtil.instance().getClusterMinVersion().onOrAfter(MINIMAL_SUPPORTED_VERSION_DEFAULT_MODEL_ID);
    }

    /**
     * Checks the cluster's minimum version against the version required for radial search.
     *
     * @return {@code true} if the cluster's minimum version is on or after 2.14.0
     */
    public static boolean isClusterOnOrAfterMinReqVersionForRadialSearch() {
        return NeuralSearchClusterUtil.instance().getClusterMinVersion().onOrAfter(MINIMAL_SUPPORTED_VERSION_RADIAL_SEARCH);
    }

    /**
     * Checks the cluster's minimum version against the version required by the given parameter.
     * The neural-search override map is consulted first; if it has no entry for the key, the
     * lookup falls back to {@code IndexUtil.minimalRequiredVersionMap}.
     *
     * @param key name of the request parameter whose minimum required version is looked up
     * @return {@code true} if a required version is registered for {@code key} and the cluster's
     *         minimum version is on or after it; {@code false} if no version is registered for
     *         the key in either map
     */
    public static boolean isClusterOnOrAfterMinReqVersion(String key) {
        Version requiredVersion = MINIMAL_VERSION_NEURAL.get(key);
        if (requiredVersion == null) {
            requiredVersion = IndexUtil.minimalRequiredVersionMap.get(key);
        }
        if (requiredVersion == null) {
            // Unknown parameter: report it as unsupported rather than handing a null version
            // to the comparison below.
            return false;
        }
        return NeuralSearchClusterUtil.instance().getClusterMinVersion().onOrAfter(requiredVersion);
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,12 @@
package org.opensearch.neuralsearch.query;

import static org.opensearch.knn.index.query.KNNQueryBuilder.FILTER_FIELD;
import static org.opensearch.knn.index.query.KNNQueryBuilder.MAX_DISTANCE_FIELD;
import static org.opensearch.knn.index.query.KNNQueryBuilder.METHOD_PARAMS_FIELD;
import static org.opensearch.knn.index.query.KNNQueryBuilder.MIN_SCORE_FIELD;
import static org.opensearch.neuralsearch.common.MinClusterVersionUtil.isClusterOnOrAfterMinReqVersion;
import static org.opensearch.neuralsearch.common.MinClusterVersionUtil.isClusterOnOrAfterMinReqVersionForDefaultModelIdSupport;
import static org.opensearch.neuralsearch.common.MinClusterVersionUtil.isClusterOnOrAfterMinReqVersionForRadialSearch;
import static org.opensearch.neuralsearch.common.VectorUtil.vectorAsListToArray;
import static org.opensearch.neuralsearch.processor.TextImageEmbeddingProcessor.INPUT_IMAGE;
import static org.opensearch.neuralsearch.processor.TextImageEmbeddingProcessor.INPUT_TEXT;
Expand All @@ -19,7 +25,6 @@
import org.apache.commons.lang.builder.EqualsBuilder;
import org.apache.commons.lang.builder.HashCodeBuilder;
import org.apache.lucene.search.Query;
import org.opensearch.Version;
import org.opensearch.common.SetOnce;
import org.opensearch.core.ParseField;
import org.opensearch.core.action.ActionListener;
Expand All @@ -34,8 +39,9 @@
import org.opensearch.index.query.QueryRewriteContext;
import org.opensearch.index.query.QueryShardContext;
import org.opensearch.knn.index.query.KNNQueryBuilder;
import org.opensearch.knn.index.query.parser.MethodParametersParser;
import org.opensearch.neuralsearch.common.MinClusterVersionUtil;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
import org.opensearch.neuralsearch.util.NeuralSearchClusterUtil;

import com.google.common.annotations.VisibleForTesting;

Expand Down Expand Up @@ -69,18 +75,11 @@ public class NeuralQueryBuilder extends AbstractQueryBuilder<NeuralQueryBuilder>
@VisibleForTesting
static final ParseField QUERY_IMAGE_FIELD = new ParseField("query_image");

@VisibleForTesting
static final ParseField MODEL_ID_FIELD = new ParseField("model_id");
public static final ParseField MODEL_ID_FIELD = new ParseField("model_id");
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why it is public?


@VisibleForTesting
static final ParseField K_FIELD = new ParseField("k");

@VisibleForTesting
static final ParseField MAX_DISTANCE_FIELD = new ParseField("max_distance");

@VisibleForTesting
static final ParseField MIN_SCORE_FIELD = new ParseField("min_score");

private static final int DEFAULT_K = 10;

private static MLCommonsClientAccessor ML_CLIENT;
Expand All @@ -101,8 +100,7 @@ public static void initialize(MLCommonsClientAccessor mlClient) {
@Setter(AccessLevel.PACKAGE)
private Supplier<float[]> vectorSupplier;
private QueryBuilder filter;
private static final Version MINIMAL_SUPPORTED_VERSION_DEFAULT_MODEL_ID = Version.V_2_11_0;
private static final Version MINIMAL_SUPPORTED_VERSION_RADIAL_SEARCH = Version.V_2_14_0;
private Map<String, ?> methodParameters;

/**
* Constructor from stream input
Expand Down Expand Up @@ -130,6 +128,9 @@ public NeuralQueryBuilder(StreamInput in) throws IOException {
this.maxDistance = in.readOptionalFloat();
this.minScore = in.readOptionalFloat();
}
if (isClusterOnOrAfterMinReqVersion(METHOD_PARAMS_FIELD.getPreferredName())) {
this.methodParameters = MethodParametersParser.streamInput(in, MinClusterVersionUtil::isClusterOnOrAfterMinReqVersion);
}
}

@Override
Expand All @@ -152,6 +153,9 @@ protected void doWriteTo(StreamOutput out) throws IOException {
out.writeOptionalFloat(this.maxDistance);
out.writeOptionalFloat(this.minScore);
}
if (isClusterOnOrAfterMinReqVersion(METHOD_PARAMS_FIELD.getPreferredName())) {
MethodParametersParser.streamOutput(out, methodParameters, MinClusterVersionUtil::isClusterOnOrAfterMinReqVersion);
}
}

@Override
Expand All @@ -174,6 +178,9 @@ protected void doXContent(XContentBuilder xContentBuilder, Params params) throws
if (Objects.nonNull(minScore)) {
xContentBuilder.field(MIN_SCORE_FIELD.getPreferredName(), minScore);
}
if (Objects.nonNull(methodParameters)) {
MethodParametersParser.doXContent(xContentBuilder, methodParameters);
}
printBoostAndQueryName(xContentBuilder);
xContentBuilder.endObject();
xContentBuilder.endObject();
Expand Down Expand Up @@ -267,6 +274,8 @@ private static void parseQueryParams(XContentParser parser, NeuralQueryBuilder n
} else if (token == XContentParser.Token.START_OBJECT) {
if (FILTER_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
neuralQueryBuilder.filter(parseInnerQueryBuilder(parser));
} else if (METHOD_PARAMS_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
neuralQueryBuilder.methodParameters(MethodParametersParser.fromXContent(parser));
}
} else {
throw new ParsingException(
Expand All @@ -292,15 +301,14 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) {
if (vectorSupplier().get() == null) {
return this;
}
KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(fieldName(), vectorSupplier.get()).filter(filter());
if (maxDistance != null) {
knnQueryBuilder.maxDistance(maxDistance);
} else if (minScore != null) {
knnQueryBuilder.minScore(minScore);
} else {
knnQueryBuilder.k(k);
}
return knnQueryBuilder;
return KNNQueryBuilder.builder()
.fieldName(fieldName())
.vector(vectorSupplier.get())
.filter(filter())
.maxDistance(maxDistance)
.minScore(minScore)
.k(k)
.build();
}

SetOnce<float[]> vectorSetOnce = new SetOnce<>();
Expand All @@ -326,7 +334,8 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) {
maxDistance(),
minScore(),
vectorSetOnce::get,
filter()
filter(),
methodParameters()
);
}

Expand Down Expand Up @@ -359,14 +368,6 @@ public String getWriteableName() {
return NAME;
}

private static boolean isClusterOnOrAfterMinReqVersionForDefaultModelIdSupport() {
return NeuralSearchClusterUtil.instance().getClusterMinVersion().onOrAfter(MINIMAL_SUPPORTED_VERSION_DEFAULT_MODEL_ID);
}

private static boolean isClusterOnOrAfterMinReqVersionForRadialSearch() {
return NeuralSearchClusterUtil.instance().getClusterMinVersion().onOrAfter(MINIMAL_SUPPORTED_VERSION_RADIAL_SEARCH);
}

private static boolean validateKNNQueryType(NeuralQueryBuilder neuralQueryBuilder) {
int queryCount = 0;
if (neuralQueryBuilder.k() != null) {
Expand Down
Loading
Loading