Skip to content

Commit

Permalink
Added Reranker feature (opensearch-project#591)
Browse files Browse the repository at this point in the history
* Adding support for generic re-ranker interface and opensearch ml re-ranker for improving search relevancy. (opensearch-project#494)

Signed-off-by: HenryL27 <[email protected]>
Co-authored-by: Heemin Kim <[email protected]>
Signed-off-by: Martin Gaievski <[email protected]>

---------

Signed-off-by: HenryL27 <[email protected]>
Signed-off-by: Martin Gaievski <[email protected]>
Co-authored-by: HenryL27 <[email protected]>
Co-authored-by: Heemin Kim <[email protected]>
(cherry picked from commit 1bb48e2)
  • Loading branch information
martin-gaievski committed Feb 6, 2024
1 parent ac04063 commit 0c5f387
Show file tree
Hide file tree
Showing 21 changed files with 1,999 additions and 2 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),

## [Unreleased 2.x](https://github.com/opensearch-project/neural-search/compare/2.11...2.x)
### Features
- Add rerank processor interface and ml-commons reranker ([#494](https://github.com/opensearch-project/neural-search/pull/494))
### Enhancements
### Bug Fixes
- Fixed exception for case when Hybrid query being wrapped into bool query ([#490](https://github.com/opensearch-project/neural-search/pull/490))
Expand All @@ -22,6 +23,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
- Fix Flaky test reported in #433 ([#533](https://github.com/opensearch-project/neural-search/pull/533))
- Enable support for default model id on HybridQueryBuilder ([#541](https://github.com/opensearch-project/neural-search/pull/541))
- Fix Flaky test reported in #384 ([#559](https://github.com/opensearch-project/neural-search/pull/559))
- Add validations for reranker requests per #555 ([#562](https://github.com/opensearch-project/neural-search/pull/562))
### Infrastructure
- BWC tests for Neural Search ([#515](https://github.com/opensearch-project/neural-search/pull/515))
- Github action to run integ tests in secure opensearch cluster ([#535](https://github.com/opensearch-project/neural-search/pull/535))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.dataset.MLInputDataset;
import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
import org.opensearch.ml.common.dataset.TextSimilarityInputDataSet;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.output.model.ModelResultFilter;
Expand Down Expand Up @@ -137,6 +138,25 @@ public void inferenceSentences(
retryableInferenceSentencesWithSingleVectorResult(TARGET_RESPONSE_FILTERS, modelId, inputObjects, 0, listener);
}

/**
 * Entry point for text-similarity inference through the MLClient predict API. The model identified by
 * {@code modelId} is invoked with {@link FunctionName#TEXT_SIMILARITY}; the outcome is delivered through
 * {@code listener} as a list of floats holding the similarity score of each input text with respect to the
 * query text, in the order of the input texts.
 *
 * @param modelId {@link String} ML-Commons model id to run the prediction with
 * @param queryText {@link String} the query every entry of {@code inputText} is compared against
 * @param inputText {@link List} of {@link String} texts to score against the query
 * @param listener {@link ActionListener} notified with the scores, or with the failure
 */
public void inferenceSimilarity(
    @NonNull final String modelId,
    @NonNull final String queryText,
    @NonNull final List<String> inputText,
    @NonNull final ActionListener<List<Float>> listener
) {
    // This is the first attempt, so the retry counter starts at zero.
    final int initialRetryTime = 0;
    retryableInferenceSimilarityWithVectorResult(modelId, queryText, inputText, initialRetryTime, listener);
}

private void retryableInferenceSentencesWithMapResult(
final String modelId,
final List<String> inputText,
Expand Down Expand Up @@ -178,12 +198,37 @@ private void retryableInferenceSentencesWithVectorResult(
}));
}

/**
 * Issues a text-similarity prediction and hands one score per result row to {@code listener}.
 * When the prediction fails, {@link RetryUtil#shouldRetry} decides whether the same request is issued again
 * with {@code retryTime + 1}; if not, the failure is forwarded to the listener.
 *
 * @param modelId id of the model to run the prediction with
 * @param queryText query text the inputs are compared against
 * @param inputText texts to score
 * @param retryTime number of retries already performed for this request
 * @param listener receives the scores or the final failure
 */
private void retryableInferenceSimilarityWithVectorResult(
    final String modelId,
    final String queryText,
    final List<String> inputText,
    final int retryTime,
    final ActionListener<List<Float>> listener
) {
    final MLInput similarityInput = createMLTextPairsInput(queryText, inputText);
    mlClient.predict(modelId, similarityInput, ActionListener.wrap(prediction -> {
        final List<List<Float>> resultRows = buildVectorFromResponse(prediction);
        // Only the first value of every result row is used as that text's score.
        final List<Float> similarityScores = resultRows.stream().map(row -> row.get(0)).collect(Collectors.toList());
        listener.onResponse(similarityScores);
    }, failure -> {
        if (!RetryUtil.shouldRetry(failure, retryTime)) {
            listener.onFailure(failure);
            return;
        }
        retryableInferenceSimilarityWithVectorResult(modelId, queryText, inputText, retryTime + 1, listener);
    }));
}

/**
 * Wraps the given texts into an {@link MLInput} for a {@link FunctionName#TEXT_EMBEDDING} prediction.
 *
 * @param targetResponseFilters filter names passed to the {@link ModelResultFilter}
 * @param inputText texts to put into the input data set
 * @return the input object to send to the ML client
 */
private MLInput createMLTextInput(final List<String> targetResponseFilters, List<String> inputText) {
    // NOTE(review): the meaning of the two boolean flags and the trailing null is defined by ModelResultFilter
    // in ml-commons and is not visible here.
    final ModelResultFilter resultFilter = new ModelResultFilter(false, true, targetResponseFilters, null);
    final MLInputDataset textDocs = new TextDocsInputDataSet(inputText, resultFilter);
    return new MLInput(FunctionName.TEXT_EMBEDDING, null, textDocs);
}

/**
 * Wraps a query and the texts to compare with it into an {@link MLInput} for a
 * {@link FunctionName#TEXT_SIMILARITY} prediction.
 *
 * @param query query text
 * @param inputText texts to pair with the query
 * @return the input object to send to the ML client
 */
private MLInput createMLTextPairsInput(final String query, final List<String> inputText) {
    final MLInputDataset textPairs = new TextSimilarityInputDataSet(query, inputText);
    return new MLInput(FunctionName.TEXT_SIMILARITY, null, textPairs);
}

private List<List<Float>> buildVectorFromResponse(MLOutput mlOutput) {
final List<List<Float>> vector = new ArrayList<>();
final ModelTensorOutput modelTensorOutput = (ModelTensorOutput) mlOutput;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
package org.opensearch.neuralsearch.plugin;

import static org.opensearch.neuralsearch.settings.NeuralSearchSettings.NEURAL_SEARCH_HYBRID_SEARCH_DISABLED;
import static org.opensearch.neuralsearch.settings.NeuralSearchSettings.RERANKER_MAX_DOC_FIELDS;

import java.util.Arrays;
import java.util.Collection;
Expand Down Expand Up @@ -34,14 +35,17 @@
import org.opensearch.neuralsearch.processor.combination.ScoreCombinationFactory;
import org.opensearch.neuralsearch.processor.combination.ScoreCombiner;
import org.opensearch.neuralsearch.processor.factory.NormalizationProcessorFactory;
import org.opensearch.neuralsearch.processor.factory.RerankProcessorFactory;
import org.opensearch.neuralsearch.processor.factory.SparseEncodingProcessorFactory;
import org.opensearch.neuralsearch.processor.factory.TextEmbeddingProcessorFactory;
import org.opensearch.neuralsearch.processor.factory.TextImageEmbeddingProcessorFactory;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationFactory;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizer;
import org.opensearch.neuralsearch.processor.rerank.RerankProcessor;
import org.opensearch.neuralsearch.query.HybridQueryBuilder;
import org.opensearch.neuralsearch.query.NeuralQueryBuilder;
import org.opensearch.neuralsearch.query.NeuralSparseQueryBuilder;
import org.opensearch.neuralsearch.query.ext.RerankSearchExtBuilder;
import org.opensearch.neuralsearch.search.query.HybridQueryPhaseSearcher;
import org.opensearch.neuralsearch.util.NeuralSearchClusterUtil;
import org.opensearch.plugins.ActionPlugin;
Expand All @@ -54,6 +58,7 @@
import org.opensearch.script.ScriptService;
import org.opensearch.search.pipeline.SearchPhaseResultsProcessor;
import org.opensearch.search.pipeline.SearchRequestProcessor;
import org.opensearch.search.pipeline.SearchResponseProcessor;
import org.opensearch.search.query.QueryPhaseSearcher;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.watcher.ResourceWatcherService;
Expand Down Expand Up @@ -141,7 +146,7 @@ public Map<String, org.opensearch.search.pipeline.Processor.Factory<SearchPhaseR

/**
 * Lists the settings this plugin registers: {@code NEURAL_SEARCH_HYBRID_SEARCH_DISABLED} and
 * {@code RERANKER_MAX_DOC_FIELDS}.
 *
 * @return immutable list of the plugin's settings
 */
@Override
public List<Setting<?>> getSettings() {
    // A single return registering both settings; the earlier one-setting return left the second
    // statement unreachable and omitted RERANKER_MAX_DOC_FIELDS.
    return List.of(NEURAL_SEARCH_HYBRID_SEARCH_DISABLED, RERANKER_MAX_DOC_FIELDS);
}

@Override
Expand All @@ -150,4 +155,25 @@ public Map<String, org.opensearch.search.pipeline.Processor.Factory<SearchReques
) {
return Map.of(NeuralQueryEnricherProcessor.TYPE, new NeuralQueryEnricherProcessor.Factory());
}

/**
 * Registers the search-response processor factories provided by this plugin: a single
 * {@link RerankProcessorFactory} under the key {@link RerankProcessor#TYPE}.
 *
 * @param parameters plugin parameters; the cluster service is taken from its search pipeline service
 * @return map from processor type to its factory
 */
@Override
public Map<String, org.opensearch.search.pipeline.Processor.Factory<SearchResponseProcessor>> getResponseProcessors(
    Parameters parameters
) {
    final RerankProcessorFactory rerankFactory = new RerankProcessorFactory(
        clientAccessor,
        parameters.searchPipelineService.getClusterService()
    );
    return Map.of(RerankProcessor.TYPE, rerankFactory);
}

/**
 * Registers the search extensions provided by this plugin: one spec named
 * {@link RerankSearchExtBuilder#PARAM_FIELD_NAME}, with a reader that constructs a
 * {@link RerankSearchExtBuilder} from its input and a parser that delegates to
 * {@code RerankSearchExtBuilder.parse}.
 *
 * @return list containing the rerank search extension spec
 */
@Override
public List<SearchPlugin.SearchExtSpec<?>> getSearchExts() {
    final SearchExtSpec<RerankSearchExtBuilder> rerankExtSpec = new SearchExtSpec<>(
        RerankSearchExtBuilder.PARAM_FIELD_NAME,
        input -> new RerankSearchExtBuilder(input),
        xContentParser -> RerankSearchExtBuilder.parse(xContentParser)
    );
    return List.of(rerankExtSpec);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.neuralsearch.processor.factory;

import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.StringJoiner;

import org.opensearch.cluster.service.ClusterService;
import org.opensearch.ingest.ConfigurationUtils;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
import org.opensearch.neuralsearch.processor.rerank.MLOpenSearchRerankProcessor;
import org.opensearch.neuralsearch.processor.rerank.RerankType;
import org.opensearch.neuralsearch.processor.rerank.context.ContextSourceFetcher;
import org.opensearch.neuralsearch.processor.rerank.context.DocumentContextSourceFetcher;
import org.opensearch.neuralsearch.processor.rerank.context.QueryContextSourceFetcher;
import org.opensearch.search.pipeline.Processor;
import org.opensearch.search.pipeline.SearchResponseProcessor;

import com.google.common.collect.Sets;

import lombok.AllArgsConstructor;

/**
 * Builds rerank search-response processors from a pipeline definition. The factory is responsible for:
 * - picking the rerank processor implementation that matches the configured rerank type
 * - creating the context source fetchers that processor relies on
 */
@AllArgsConstructor
public class RerankProcessorFactory implements Processor.Factory<SearchResponseProcessor> {

    public static final String RERANK_PROCESSOR_TYPE = "rerank";
    public static final String CONTEXT_CONFIG_FIELD = "context";

    private final MLCommonsClientAccessor clientAccessor;
    private final ClusterService clusterService;

    /**
     * Creates the rerank processor described by {@code config}.
     *
     * @param processorFactories other registered factories (not used here)
     * @param tag processor tag, used when reading configuration
     * @param description processor description
     * @param ignoreFailure whether the pipeline should ignore failures of this processor
     * @param config processor configuration; must hold exactly one rerank type key and a "context" map
     * @param pipelineContext pipeline context (not used here)
     * @return the configured rerank processor
     * @throws IllegalArgumentException if the rerank type is missing, ambiguous, or cannot be built
     */
    @Override
    public SearchResponseProcessor create(
        final Map<String, Processor.Factory<SearchResponseProcessor>> processorFactories,
        final String tag,
        final String description,
        final boolean ignoreFailure,
        final Map<String, Object> config,
        final Processor.PipelineContext pipelineContext
    ) {
        final RerankType rerankType = findRerankType(config);
        // Context fetchers are built before the type-specific configuration is read.
        final List<ContextSourceFetcher> fetchers = ContextFetcherFactory.createFetchers(
            config,
            ContextFetcherFactory.shouldIncludeQueryContextFetcher(rerankType),
            tag,
            clusterService
        );
        if (rerankType != RerankType.ML_OPENSEARCH) {
            throw new IllegalArgumentException(String.format(Locale.ROOT, "Cannot build reranker type %s", rerankType.getLabel()));
        }
        final Map<String, Object> typeConfig = ConfigurationUtils.readMap(RERANK_PROCESSOR_TYPE, tag, config, rerankType.getLabel());
        final String modelId = ConfigurationUtils.readStringProperty(
            RERANK_PROCESSOR_TYPE,
            tag,
            typeConfig,
            MLOpenSearchRerankProcessor.MODEL_ID_FIELD
        );
        return new MLOpenSearchRerankProcessor(description, tag, ignoreFailure, modelId, fetchers, clientAccessor);
    }

    /**
     * Determines which rerank type the configuration asks for by intersecting the config keys with the
     * known rerank type labels.
     *
     * @param config processor configuration
     * @return the single rerank type named in the configuration
     * @throws IllegalArgumentException if no rerank type, or more than one, is present
     */
    private RerankType findRerankType(final Map<String, Object> config) throws IllegalArgumentException {
        final Set<String> configuredLabels = Sets.intersection(config.keySet(), RerankType.labelMap().keySet());
        if (configuredLabels.isEmpty()) {
            // Nothing matched: tell the user which labels would have been accepted.
            final StringJoiner knownLabels = new StringJoiner(", ", "No rerank type found. Possible rerank types are: [", "]");
            for (RerankType candidate : RerankType.values()) {
                knownLabels.add(candidate.getLabel());
            }
            throw new IllegalArgumentException(knownLabels.toString());
        }
        if (configuredLabels.size() > 1) {
            // Ambiguous: report every label that was found.
            final StringJoiner foundLabels = new StringJoiner(", ", "Multiple rerank types found: [", "]. Only one is permitted.");
            for (String label : configuredLabels) {
                foundLabels.add(label);
            }
            throw new IllegalArgumentException(foundLabels.toString());
        }
        return RerankType.from(configuredLabels.iterator().next());
    }

    /**
     * Helper that assembles the context fetchers named in the pipeline config, plus the query
     * context fetcher when the rerank type requires it.
     */
    private static class ContextFetcherFactory {

        /**
         * Tells whether a rerank type relies on the {@link QueryContextSourceFetcher}.
         *
         * @param type the rerank type being constructed
         * @return true only for {@link RerankType#ML_OPENSEARCH}
         */
        public static boolean shouldIncludeQueryContextFetcher(RerankType type) {
            return RerankType.ML_OPENSEARCH == type;
        }

        /**
         * Creates the context fetchers for a processor.
         *
         * @param config processor config object; its "context" field lists the fetchers to create
         * @param includeQueryContextFetcher whether a query context fetcher is appended to the result
         * @param tag processor tag, used when reading configuration
         * @param clusterService cluster service handed to the created fetchers
         * @return fetchers for the processor to use, configured ones first
         * @throws IllegalArgumentException if the "context" field names an unknown fetcher
         */
        public static List<ContextSourceFetcher> createFetchers(
            Map<String, Object> config,
            boolean includeQueryContextFetcher,
            String tag,
            final ClusterService clusterService
        ) {
            final Map<String, Object> contextConfig = ConfigurationUtils.readMap(RERANK_PROCESSOR_TYPE, tag, config, CONTEXT_CONFIG_FIELD);
            final List<ContextSourceFetcher> result = new ArrayList<>();
            for (Map.Entry<String, Object> entry : contextConfig.entrySet()) {
                final String fetcherName = entry.getKey();
                switch (fetcherName) {
                    case DocumentContextSourceFetcher.NAME:
                        result.add(DocumentContextSourceFetcher.create(entry.getValue(), clusterService));
                        break;
                    default:
                        throw new IllegalArgumentException(String.format(Locale.ROOT, "unrecognized context field: %s", fetcherName));
                }
            }
            if (includeQueryContextFetcher) {
                result.add(new QueryContextSourceFetcher(clusterService));
            }
            return result;
        }
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.neuralsearch.processor.rerank;

import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.stream.Collectors;

import org.opensearch.action.search.SearchResponse;
import org.opensearch.core.action.ActionListener;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
import org.opensearch.neuralsearch.processor.factory.RerankProcessorFactory;
import org.opensearch.neuralsearch.processor.rerank.context.ContextSourceFetcher;
import org.opensearch.neuralsearch.processor.rerank.context.DocumentContextSourceFetcher;
import org.opensearch.neuralsearch.processor.rerank.context.QueryContextSourceFetcher;

/**
 * Rescoring rerank processor whose scores come from a TextSimilarity model served by ml-commons.
 */
public class MLOpenSearchRerankProcessor extends RescoringRerankProcessor {

    public static final String MODEL_ID_FIELD = "model_id";

    protected final String modelId;

    protected final MLCommonsClientAccessor mlCommonsClientAccessor;

    /**
     * Creates a processor of type {@link RerankType#ML_OPENSEARCH}.
     *
     * @param description processor description
     * @param tag processor tag
     * @param ignoreFailure whether the pipeline should ignore failures of this processor
     * @param modelId id of the TEXT_SIMILARITY model
     * @param contextSourceFetchers fetchers passed on to the parent processor
     * @param mlCommonsClientAccessor accessor used to run the similarity inference
     */
    public MLOpenSearchRerankProcessor(
        final String description,
        final String tag,
        final boolean ignoreFailure,
        final String modelId,
        final List<ContextSourceFetcher> contextSourceFetchers,
        final MLCommonsClientAccessor mlCommonsClientAccessor
    ) {
        super(RerankType.ML_OPENSEARCH, description, tag, ignoreFailure, contextSourceFetchers);
        this.modelId = modelId;
        this.mlCommonsClientAccessor = mlCommonsClientAccessor;
    }

    /**
     * Computes new scores by sending the query text and the document contexts found in
     * {@code rerankingContext} to the similarity model.
     *
     * @param response search response being reranked (not read by this implementation)
     * @param rerankingContext holds the document context list and the query text
     * @param listener receives the scores; receives an {@link IllegalStateException} when the document
     *                 context entry is not a list
     */
    @Override
    public void rescoreSearchResponse(
        final SearchResponse response,
        final Map<String, Object> rerankingContext,
        final ActionListener<List<Float>> listener
    ) {
        final Object documentContext = rerankingContext.get(DocumentContextSourceFetcher.DOCUMENT_CONTEXT_LIST_FIELD);
        if (documentContext instanceof List<?>) {
            final List<String> documentTexts = ((List<?>) documentContext).stream()
                .map(item -> (String) item)
                .collect(Collectors.toList());
            final String queryText = (String) rerankingContext.get(QueryContextSourceFetcher.QUERY_TEXT_FIELD);
            mlCommonsClientAccessor.inferenceSimilarity(modelId, queryText, documentTexts, listener);
            return;
        }
        // No usable document context: report it through the listener instead of throwing.
        final String message = String.format(
            Locale.ROOT,
            "No document context found! Perhaps \"%s.%s\" is missing from the pipeline definition?",
            RerankProcessorFactory.CONTEXT_CONFIG_FIELD,
            DocumentContextSourceFetcher.NAME
        );
        listener.onFailure(new IllegalStateException(message));
    }

}
Loading

0 comments on commit 0c5f387

Please sign in to comment.