Skip to content

Commit

Permalink
Adds rescore parameter in neural search (#885)
Browse files Browse the repository at this point in the history
This is required to enable rescoring for on disk mode indices

Signed-off-by: Tejas Shah <[email protected]>
  • Loading branch information
shatejas authored Sep 5, 2024
1 parent 60b1ce8 commit cd577d5
Show file tree
Hide file tree
Showing 14 changed files with 136 additions and 38 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
## [Unreleased 2.x](https://github.com/opensearch-project/neural-search/compare/2.17...2.x)
### Features
### Enhancements
- Adds rescore parameter support ([#885](https://github.com/opensearch-project/neural-search/pull/885))
### Bug Fixes
### Infrastructure
### Documentation
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
import static org.opensearch.neuralsearch.util.TestUtils.TEXT_EMBEDDING_PROCESSOR;
import static org.opensearch.neuralsearch.util.TestUtils.DEFAULT_NORMALIZATION_METHOD;
import static org.opensearch.neuralsearch.util.TestUtils.DEFAULT_COMBINATION_METHOD;

import org.opensearch.knn.index.query.rescore.RescoreContext;
import org.opensearch.neuralsearch.query.HybridQueryBuilder;
import org.opensearch.neuralsearch.query.NeuralQueryBuilder;

Expand Down Expand Up @@ -69,8 +71,10 @@ private void validateNormalizationProcessor(final String fileName, final String
modelId = getModelId(getIngestionPipeline(pipelineName), TEXT_EMBEDDING_PROCESSOR);
loadModel(modelId);
addDocuments(getIndexNameForTest(), false);
validateTestIndex(modelId, getIndexNameForTest(), searchPipelineName);
validateTestIndex(modelId, getIndexNameForTest(), searchPipelineName, Map.of("ef_search", 100));
HybridQueryBuilder hybridQueryBuilder = getQueryBuilder(modelId, null, null);
validateTestIndex(getIndexNameForTest(), searchPipelineName, hybridQueryBuilder);
hybridQueryBuilder = getQueryBuilder(modelId, Map.of("ef_search", 100), RescoreContext.getDefault());
validateTestIndex(getIndexNameForTest(), searchPipelineName, hybridQueryBuilder);
} finally {
wipeOfTestResources(getIndexNameForTest(), pipelineName, modelId, searchPipelineName);
}
Expand Down Expand Up @@ -98,15 +102,10 @@ private void createSearchPipeline(final String pipelineName) {
);
}

private void validateTestIndex(final String modelId, final String index, final String searchPipeline) {
validateTestIndex(modelId, index, searchPipeline, null);
}

private void validateTestIndex(final String modelId, final String index, final String searchPipeline, Map<String, ?> methodParameters) {
private void validateTestIndex(final String index, final String searchPipeline, HybridQueryBuilder queryBuilder) {
int docCount = getDocCount(index);
assertEquals(6, docCount);
HybridQueryBuilder hybridQueryBuilder = getQueryBuilder(modelId, methodParameters);
Map<String, Object> searchResponseAsMap = search(index, hybridQueryBuilder, null, 1, Map.of("search_pipeline", searchPipeline));
Map<String, Object> searchResponseAsMap = search(index, queryBuilder, null, 1, Map.of("search_pipeline", searchPipeline));
assertNotNull(searchResponseAsMap);
int hits = getHitCount(searchResponseAsMap);
assertEquals(1, hits);
Expand All @@ -116,7 +115,7 @@ private void validateTestIndex(final String modelId, final String index, final S
}
}

private HybridQueryBuilder getQueryBuilder(final String modelId, Map<String, ?> methodParameters) {
private HybridQueryBuilder getQueryBuilder(final String modelId, Map<String, ?> methodParameters, RescoreContext rescoreContext) {
NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder();
neuralQueryBuilder.fieldName("passage_embedding");
neuralQueryBuilder.modelId(modelId);
Expand All @@ -125,6 +124,9 @@ private HybridQueryBuilder getQueryBuilder(final String modelId, Map<String, ?>
if (methodParameters != null) {
neuralQueryBuilder.methodParameters(methodParameters);
}
if (rescoreContext != null) {
neuralQueryBuilder.rescoreContext(rescoreContext);
}

MatchQueryBuilder matchQueryBuilder = new MatchQueryBuilder("text", QUERY);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ private void validateIndexQuery(final String modelId) {
0.01f,
null,
null,
null,
null
);
Map<String, Object> responseWithMinScoreQuery = search(getIndexNameForTest(), neuralQueryBuilderWithMinScoreQuery, 1);
Expand All @@ -76,6 +77,7 @@ private void validateIndexQuery(final String modelId) {
null,
null,
null,
null,
null
);
Map<String, Object> responseWithMaxDistanceQuery = search(getIndexNameForTest(), neuralQueryBuilderWithMaxDistanceQuery, 1);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ private void validateTestIndex(final String modelId) throws Exception {
null,
null,
null,
null,
null
);
Map<String, Object> response = search(getIndexNameForTest(), neuralQueryBuilder, 1);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
import static org.opensearch.neuralsearch.util.TestUtils.DEFAULT_NORMALIZATION_METHOD;
import static org.opensearch.neuralsearch.util.TestUtils.DEFAULT_COMBINATION_METHOD;
import static org.opensearch.neuralsearch.util.TestUtils.getModelId;

import org.opensearch.knn.index.query.rescore.RescoreContext;
import org.opensearch.neuralsearch.query.HybridQueryBuilder;
import org.opensearch.neuralsearch.query.NeuralQueryBuilder;

Expand Down Expand Up @@ -59,11 +61,13 @@ public void testNormalizationProcessor_whenIndexWithMultipleShards_E2EFlow() thr
int totalDocsCountMixed;
if (isFirstMixedRound()) {
totalDocsCountMixed = NUM_DOCS_PER_ROUND;
validateTestIndexOnUpgrade(totalDocsCountMixed, modelId);
HybridQueryBuilder hybridQueryBuilder = getQueryBuilder(modelId, null, null);
validateTestIndexOnUpgrade(totalDocsCountMixed, modelId, hybridQueryBuilder);
addDocument(getIndexNameForTest(), "1", TEST_FIELD, TEXT_MIXED, null, null);
} else {
totalDocsCountMixed = 2 * NUM_DOCS_PER_ROUND;
validateTestIndexOnUpgrade(totalDocsCountMixed, modelId);
HybridQueryBuilder hybridQueryBuilder = getQueryBuilder(modelId, null, null);
validateTestIndexOnUpgrade(totalDocsCountMixed, modelId, hybridQueryBuilder);
}
break;
case UPGRADED:
Expand All @@ -72,8 +76,10 @@ public void testNormalizationProcessor_whenIndexWithMultipleShards_E2EFlow() thr
int totalDocsCountUpgraded = 3 * NUM_DOCS_PER_ROUND;
loadModel(modelId);
addDocument(getIndexNameForTest(), "2", TEST_FIELD, TEXT_UPGRADED, null, null);
validateTestIndexOnUpgrade(totalDocsCountUpgraded, modelId);
validateTestIndexOnUpgrade(totalDocsCountUpgraded, modelId, Map.of("ef_search", 100));
HybridQueryBuilder hybridQueryBuilder = getQueryBuilder(modelId, null, null);
validateTestIndexOnUpgrade(totalDocsCountUpgraded, modelId, hybridQueryBuilder);
hybridQueryBuilder = getQueryBuilder(modelId, Map.of("ef_search", 100), RescoreContext.getDefault());
validateTestIndexOnUpgrade(totalDocsCountUpgraded, modelId, hybridQueryBuilder);
} finally {
wipeOfTestResources(getIndexNameForTest(), PIPELINE_NAME, modelId, SEARCH_PIPELINE_NAME);
}
Expand All @@ -83,16 +89,11 @@ public void testNormalizationProcessor_whenIndexWithMultipleShards_E2EFlow() thr
}
}

private void validateTestIndexOnUpgrade(final int numberOfDocs, final String modelId) throws Exception {
validateTestIndexOnUpgrade(numberOfDocs, modelId, null);
}

private void validateTestIndexOnUpgrade(final int numberOfDocs, final String modelId, Map<String, ?> methodParameters)
private void validateTestIndexOnUpgrade(final int numberOfDocs, final String modelId, HybridQueryBuilder hybridQueryBuilder)
throws Exception {
int docCount = getDocCount(getIndexNameForTest());
assertEquals(numberOfDocs, docCount);
loadModel(modelId);
HybridQueryBuilder hybridQueryBuilder = getQueryBuilder(modelId, methodParameters);
Map<String, Object> searchResponseAsMap = search(
getIndexNameForTest(),
hybridQueryBuilder,
Expand All @@ -109,7 +110,11 @@ private void validateTestIndexOnUpgrade(final int numberOfDocs, final String mod
}
}

private HybridQueryBuilder getQueryBuilder(final String modelId, final Map<String, ?> methodParameters) {
private HybridQueryBuilder getQueryBuilder(
final String modelId,
final Map<String, ?> methodParameters,
final RescoreContext rescoreContext
) {
NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder();
neuralQueryBuilder.fieldName("passage_embedding");
neuralQueryBuilder.modelId(modelId);
Expand All @@ -118,6 +123,9 @@ private HybridQueryBuilder getQueryBuilder(final String modelId, final Map<Strin
if (methodParameters != null) {
neuralQueryBuilder.methodParameters(methodParameters);
}
if (rescoreContext != null) {
neuralQueryBuilder.rescoreContext(rescoreContext);
}

MatchQueryBuilder matchQueryBuilder = new MatchQueryBuilder("text", QUERY);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ private void validateIndexQueryOnUpgrade(final int numberOfDocs, final String mo
0.01f,
null,
null,
null,
null
);
Map<String, Object> responseWithMinScore = search(getIndexNameForTest(), neuralQueryBuilderWithMinScoreQuery, 1);
Expand All @@ -102,6 +103,7 @@ private void validateIndexQueryOnUpgrade(final int numberOfDocs, final String mo
null,
null,
null,
null,
null
);
Map<String, Object> responseWithMaxScore = search(getIndexNameForTest(), neuralQueryBuilderWithMaxDistanceQuery, 1);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ private void validateTestIndexOnUpgrade(final int numberOfDocs, final String mod
null,
null,
null,
null,
null
);
Map<String, Object> responseWithKQuery = search(getIndexNameForTest(), neuralQueryBuilderWithKQuery, 1);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import static org.opensearch.knn.index.query.KNNQueryBuilder.MAX_DISTANCE_FIELD;
import static org.opensearch.knn.index.query.KNNQueryBuilder.METHOD_PARAMS_FIELD;
import static org.opensearch.knn.index.query.KNNQueryBuilder.MIN_SCORE_FIELD;
import static org.opensearch.knn.index.query.KNNQueryBuilder.RESCORE_FIELD;
import static org.opensearch.neuralsearch.common.MinClusterVersionUtil.isClusterOnOrAfterMinReqVersion;
import static org.opensearch.neuralsearch.common.MinClusterVersionUtil.isClusterOnOrAfterMinReqVersionForDefaultModelIdSupport;
import static org.opensearch.neuralsearch.common.MinClusterVersionUtil.isClusterOnOrAfterMinReqVersionForRadialSearch;
Expand Down Expand Up @@ -40,6 +41,8 @@
import org.opensearch.index.query.QueryShardContext;
import org.opensearch.knn.index.query.KNNQueryBuilder;
import org.opensearch.knn.index.query.parser.MethodParametersParser;
import org.opensearch.knn.index.query.parser.RescoreParser;
import org.opensearch.knn.index.query.rescore.RescoreContext;
import org.opensearch.neuralsearch.common.MinClusterVersionUtil;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;

Expand Down Expand Up @@ -101,6 +104,7 @@ public static void initialize(MLCommonsClientAccessor mlClient) {
private Supplier<float[]> vectorSupplier;
private QueryBuilder filter;
private Map<String, ?> methodParameters;
private RescoreContext rescoreContext;

/**
* Constructor from stream input
Expand Down Expand Up @@ -131,6 +135,7 @@ public NeuralQueryBuilder(StreamInput in) throws IOException {
if (isClusterOnOrAfterMinReqVersion(METHOD_PARAMS_FIELD.getPreferredName())) {
this.methodParameters = MethodParametersParser.streamInput(in, MinClusterVersionUtil::isClusterOnOrAfterMinReqVersion);
}
this.rescoreContext = RescoreParser.streamInput(in);
}

@Override
Expand All @@ -156,6 +161,7 @@ protected void doWriteTo(StreamOutput out) throws IOException {
if (isClusterOnOrAfterMinReqVersion(METHOD_PARAMS_FIELD.getPreferredName())) {
MethodParametersParser.streamOutput(out, methodParameters, MinClusterVersionUtil::isClusterOnOrAfterMinReqVersion);
}
RescoreParser.streamOutput(out, rescoreContext);
}

@Override
Expand All @@ -181,6 +187,9 @@ protected void doXContent(XContentBuilder xContentBuilder, Params params) throws
if (Objects.nonNull(methodParameters)) {
MethodParametersParser.doXContent(xContentBuilder, methodParameters);
}
if (Objects.nonNull(rescoreContext)) {
RescoreParser.doXContent(xContentBuilder, rescoreContext);
}
printBoostAndQueryName(xContentBuilder);
xContentBuilder.endObject();
xContentBuilder.endObject();
Expand Down Expand Up @@ -276,6 +285,8 @@ private static void parseQueryParams(XContentParser parser, NeuralQueryBuilder n
neuralQueryBuilder.filter(parseInnerQueryBuilder(parser));
} else if (METHOD_PARAMS_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
neuralQueryBuilder.methodParameters(MethodParametersParser.fromXContent(parser));
} else if (RESCORE_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
neuralQueryBuilder.rescoreContext(RescoreParser.fromXContent(parser));
}
} else {
throw new ParsingException(
Expand Down Expand Up @@ -308,6 +319,8 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) {
.maxDistance(maxDistance)
.minScore(minScore)
.k(k)
.methodParameters(methodParameters)
.rescoreContext(rescoreContext)
.build();
}

Expand Down Expand Up @@ -335,7 +348,8 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) {
minScore(),
vectorSetOnce::get,
filter(),
methodParameters()
methodParameters(),
rescoreContext()
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ public void testResultProcessor_whenOneShardAndQueryMatches_thenSuccessful() {
null,
null,
null,
null,
null
);
TermQueryBuilder termQueryBuilder = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3);
Expand Down Expand Up @@ -148,6 +149,7 @@ public void testResultProcessor_whenDefaultProcessorConfigAndQueryMatches_thenSu
null,
null,
null,
null,
null
);
TermQueryBuilder termQueryBuilder = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3);
Expand Down Expand Up @@ -188,6 +190,7 @@ public void testQueryMatches_whenMultipleShards_thenSuccessful() {
null,
null,
null,
null,
null
);
TermQueryBuilder termQueryBuilder = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ public void testHarmonicMeanCombination_whenOneShardAndQueryMatches_thenSuccessf

HybridQueryBuilder hybridQueryBuilderDefaultNorm = new HybridQueryBuilder();
hybridQueryBuilderDefaultNorm.add(
new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null, null, null, null)
new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null, null, null, null, null)
);
hybridQueryBuilderDefaultNorm.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3));

Expand All @@ -249,7 +249,7 @@ public void testHarmonicMeanCombination_whenOneShardAndQueryMatches_thenSuccessf

HybridQueryBuilder hybridQueryBuilderL2Norm = new HybridQueryBuilder();
hybridQueryBuilderL2Norm.add(
new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null, null, null, null)
new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null, null, null, null, null)
);
hybridQueryBuilderL2Norm.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3));

Expand Down Expand Up @@ -299,7 +299,7 @@ public void testGeometricMeanCombination_whenOneShardAndQueryMatches_thenSuccess

HybridQueryBuilder hybridQueryBuilderDefaultNorm = new HybridQueryBuilder();
hybridQueryBuilderDefaultNorm.add(
new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null, null, null, null)
new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null, null, null, null, null)
);
hybridQueryBuilderDefaultNorm.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3));

Expand All @@ -324,7 +324,7 @@ public void testGeometricMeanCombination_whenOneShardAndQueryMatches_thenSuccess

HybridQueryBuilder hybridQueryBuilderL2Norm = new HybridQueryBuilder();
hybridQueryBuilderL2Norm.add(
new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null, null, null, null)
new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null, null, null, null, null)
);
hybridQueryBuilderL2Norm.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3));

Expand Down
Loading

0 comments on commit cd577d5

Please sign in to comment.