Added Reranker feature #591

Merged 2 commits on Feb 6, 2024
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -18,12 +18,14 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),

## [Unreleased 2.x](https://github.com/opensearch-project/neural-search/compare/2.11...2.x)
### Features
- Add rerank processor interface and ml-commons reranker ([#494](https://github.com/opensearch-project/neural-search/pull/494))
### Enhancements
### Bug Fixes
- Fixing multiple issues reported in #497 ([#524](https://github.com/opensearch-project/neural-search/pull/524))
- Fix Flaky test reported in #433 ([#533](https://github.com/opensearch-project/neural-search/pull/533))
- Enable support for default model id on HybridQueryBuilder ([#541](https://github.com/opensearch-project/neural-search/pull/541))
- Fix Flaky test reported in #384 ([#559](https://github.com/opensearch-project/neural-search/pull/559))
- Add validations for reranker requests per #555 ([#562](https://github.com/opensearch-project/neural-search/pull/562))
### Infrastructure
- BWC tests for Neural Search ([#515](https://github.com/opensearch-project/neural-search/pull/515))
- Github action to run integ tests in secure opensearch cluster ([#535](https://github.com/opensearch-project/neural-search/pull/535))
src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java
@@ -19,6 +19,7 @@
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.dataset.MLInputDataset;
import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
import org.opensearch.ml.common.dataset.TextSimilarityInputDataSet;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.output.model.ModelResultFilter;
@@ -132,6 +133,25 @@ public void inferenceSentences(
retryableInferenceSentencesWithSingleVectorResult(TARGET_RESPONSE_FILTERS, modelId, inputObjects, 0, listener);
}

/**
* Abstraction to call the predict API of MLClient. It uses the custom model provided as modelId with the
* {@link FunctionName#TEXT_SIMILARITY} function. The result is sent via the actionListener as a list of floats
* representing the similarity scores of the texts w.r.t. the query text, in the order of the input texts.
*
* @param modelId {@link String} ML-Commons Model Id
* @param queryText {@link String} The query to compare all the inputText to
* @param inputText {@link List} of {@link String} The texts to compare to the query
* @param listener {@link ActionListener} receives the result of the inference
*/
public void inferenceSimilarity(
@NonNull final String modelId,
@NonNull final String queryText,
@NonNull final List<String> inputText,
@NonNull final ActionListener<List<Float>> listener
) {
retryableInferenceSimilarityWithVectorResult(modelId, queryText, inputText, 0, listener);
}

private void retryableInferenceSentencesWithMapResult(
final String modelId,
final List<String> inputText,
@@ -173,12 +193,37 @@ private void retryableInferenceSentencesWithVectorResult(
}));
}

private void retryableInferenceSimilarityWithVectorResult(
final String modelId,
final String queryText,
final List<String> inputText,
final int retryTime,
final ActionListener<List<Float>> listener
) {
MLInput mlInput = createMLTextPairsInput(queryText, inputText);
mlClient.predict(modelId, mlInput, ActionListener.wrap(mlOutput -> {
final List<Float> scores = buildVectorFromResponse(mlOutput).stream().map(v -> v.get(0)).collect(Collectors.toList());
listener.onResponse(scores);
}, e -> {
if (RetryUtil.shouldRetry(e, retryTime)) {
retryableInferenceSimilarityWithVectorResult(modelId, queryText, inputText, retryTime + 1, listener);
} else {
listener.onFailure(e);
}
}));
}

private MLInput createMLTextInput(final List<String> targetResponseFilters, List<String> inputText) {
final ModelResultFilter modelResultFilter = new ModelResultFilter(false, true, targetResponseFilters, null);
final MLInputDataset inputDataset = new TextDocsInputDataSet(inputText, modelResultFilter);
return new MLInput(FunctionName.TEXT_EMBEDDING, null, inputDataset);
}

private MLInput createMLTextPairsInput(final String query, final List<String> inputText) {
final MLInputDataset inputDataset = new TextSimilarityInputDataSet(query, inputText);
return new MLInput(FunctionName.TEXT_SIMILARITY, null, inputDataset);
}

private List<List<Float>> buildVectorFromResponse(MLOutput mlOutput) {
final List<List<Float>> vector = new ArrayList<>();
final ModelTensorOutput modelTensorOutput = (ModelTensorOutput) mlOutput;
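Example (not part of the diff): a minimal usage sketch for the new inferenceSimilarity API. The accessor instance, model id, query, and passages below are placeholders; scores arrive in the same order as the input texts.

import java.util.List;

import org.opensearch.core.action.ActionListener;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;

class SimilarityUsageSketch {
    // Scores each passage against the query; the model id is a placeholder for a deployed TEXT_SIMILARITY model.
    static void scorePassages(MLCommonsClientAccessor accessor) {
        List<String> passages = List.of("Paris is the capital of France.", "Berlin is the capital of Germany.");
        accessor.inferenceSimilarity(
            "my-text-similarity-model-id",
            "What is the capital of France?",
            passages,
            ActionListener.wrap(
                scores -> {
                    // One score per passage, in the same order as the input list.
                    for (int i = 0; i < scores.size(); i++) {
                        System.out.println(passages.get(i) + " -> " + scores.get(i));
                    }
                },
                e -> System.err.println("Similarity inference failed: " + e.getMessage())
            )
        );
    }
}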
src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java
@@ -5,6 +5,7 @@
package org.opensearch.neuralsearch.plugin;

import static org.opensearch.neuralsearch.settings.NeuralSearchSettings.NEURAL_SEARCH_HYBRID_SEARCH_DISABLED;
import static org.opensearch.neuralsearch.settings.NeuralSearchSettings.RERANKER_MAX_DOC_FIELDS;

import java.util.Arrays;
import java.util.Collection;
@@ -34,14 +35,17 @@
import org.opensearch.neuralsearch.processor.combination.ScoreCombinationFactory;
import org.opensearch.neuralsearch.processor.combination.ScoreCombiner;
import org.opensearch.neuralsearch.processor.factory.NormalizationProcessorFactory;
import org.opensearch.neuralsearch.processor.factory.RerankProcessorFactory;
import org.opensearch.neuralsearch.processor.factory.SparseEncodingProcessorFactory;
import org.opensearch.neuralsearch.processor.factory.TextEmbeddingProcessorFactory;
import org.opensearch.neuralsearch.processor.factory.TextImageEmbeddingProcessorFactory;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationFactory;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizer;
import org.opensearch.neuralsearch.processor.rerank.RerankProcessor;
import org.opensearch.neuralsearch.query.HybridQueryBuilder;
import org.opensearch.neuralsearch.query.NeuralQueryBuilder;
import org.opensearch.neuralsearch.query.NeuralSparseQueryBuilder;
import org.opensearch.neuralsearch.query.ext.RerankSearchExtBuilder;
import org.opensearch.neuralsearch.search.query.HybridQueryPhaseSearcher;
import org.opensearch.neuralsearch.util.NeuralSearchClusterUtil;
import org.opensearch.plugins.ActionPlugin;
@@ -54,6 +58,7 @@
import org.opensearch.script.ScriptService;
import org.opensearch.search.pipeline.SearchPhaseResultsProcessor;
import org.opensearch.search.pipeline.SearchRequestProcessor;
import org.opensearch.search.pipeline.SearchResponseProcessor;
import org.opensearch.search.query.QueryPhaseSearcher;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.watcher.ResourceWatcherService;
@@ -141,7 +146,7 @@

@Override
public List<Setting<?>> getSettings() {
return List.of(NEURAL_SEARCH_HYBRID_SEARCH_DISABLED);
return List.of(NEURAL_SEARCH_HYBRID_SEARCH_DISABLED, RERANKER_MAX_DOC_FIELDS);
}

@Override
@@ -150,4 +155,25 @@
) {
return Map.of(NeuralQueryEnricherProcessor.TYPE, new NeuralQueryEnricherProcessor.Factory());
}

@Override
public Map<String, org.opensearch.search.pipeline.Processor.Factory<SearchResponseProcessor>> getResponseProcessors(
Parameters parameters
) {
return Map.of(
RerankProcessor.TYPE,
new RerankProcessorFactory(clientAccessor, parameters.searchPipelineService.getClusterService())
);
}

@Override
public List<SearchPlugin.SearchExtSpec<?>> getSearchExts() {
return List.of(
new SearchExtSpec<>(
RerankSearchExtBuilder.PARAM_FIELD_NAME,
in -> new RerankSearchExtBuilder(in),
parser -> RerankSearchExtBuilder.parse(parser)
)
);
}
}
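Example (not part of the diff): a hedged test-style check that the rerank response processor is exposed under RerankProcessor.TYPE; the factories map is assumed to come from NeuralSearch#getResponseProcessors in a test harness.

import java.util.Map;

import org.opensearch.neuralsearch.processor.rerank.RerankProcessor;
import org.opensearch.search.pipeline.Processor;
import org.opensearch.search.pipeline.SearchResponseProcessor;

class RerankRegistrationCheck {
    // factories is assumed to be the map returned by NeuralSearch#getResponseProcessors.
    static void verifyRerankRegistered(Map<String, Processor.Factory<SearchResponseProcessor>> factories) {
        if (!factories.containsKey(RerankProcessor.TYPE)) {
            throw new AssertionError("rerank response processor was not registered");
        }
    }
}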
src/main/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactory.java
@@ -0,0 +1,140 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.neuralsearch.processor.factory;

import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.StringJoiner;

import org.opensearch.cluster.service.ClusterService;
import org.opensearch.ingest.ConfigurationUtils;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
import org.opensearch.neuralsearch.processor.rerank.MLOpenSearchRerankProcessor;
import org.opensearch.neuralsearch.processor.rerank.RerankType;
import org.opensearch.neuralsearch.processor.rerank.context.ContextSourceFetcher;
import org.opensearch.neuralsearch.processor.rerank.context.DocumentContextSourceFetcher;
import org.opensearch.neuralsearch.processor.rerank.context.QueryContextSourceFetcher;
import org.opensearch.search.pipeline.Processor;
import org.opensearch.search.pipeline.SearchResponseProcessor;

import com.google.common.collect.Sets;

import lombok.AllArgsConstructor;

/**
* Factory for rerank processors. Must:
* - Instantiate the right kind of rerank processor
* - Instantiate the appropriate context source fetchers
*/
@AllArgsConstructor
public class RerankProcessorFactory implements Processor.Factory<SearchResponseProcessor> {

public static final String RERANK_PROCESSOR_TYPE = "rerank";
public static final String CONTEXT_CONFIG_FIELD = "context";

private final MLCommonsClientAccessor clientAccessor;
private final ClusterService clusterService;

@Override
public SearchResponseProcessor create(
final Map<String, Processor.Factory<SearchResponseProcessor>> processorFactories,
final String tag,
final String description,
final boolean ignoreFailure,
final Map<String, Object> config,
final Processor.PipelineContext pipelineContext
) {
RerankType type = findRerankType(config);
boolean includeQueryContextFetcher = ContextFetcherFactory.shouldIncludeQueryContextFetcher(type);
List<ContextSourceFetcher> contextFetchers = ContextFetcherFactory.createFetchers(
config,
includeQueryContextFetcher,
tag,
clusterService
);
switch (type) {
case ML_OPENSEARCH:
Map<String, Object> rerankerConfig = ConfigurationUtils.readMap(RERANK_PROCESSOR_TYPE, tag, config, type.getLabel());
String modelId = ConfigurationUtils.readStringProperty(
RERANK_PROCESSOR_TYPE,
tag,
rerankerConfig,
MLOpenSearchRerankProcessor.MODEL_ID_FIELD
);
return new MLOpenSearchRerankProcessor(description, tag, ignoreFailure, modelId, contextFetchers, clientAccessor);
default:
throw new IllegalArgumentException(String.format(Locale.ROOT, "Cannot build reranker type %s", type.getLabel()));
}
}

private RerankType findRerankType(final Map<String, Object> config) throws IllegalArgumentException {
// Set of rerank type labels in the config
Set<String> rerankTypes = Sets.intersection(config.keySet(), RerankType.labelMap().keySet());
// A rerank type must be provided
if (rerankTypes.size() == 0) {
StringJoiner msgBuilder = new StringJoiner(", ", "No rerank type found. Possible rerank types are: [", "]");
for (RerankType t : RerankType.values()) {
msgBuilder.add(t.getLabel());
}
throw new IllegalArgumentException(msgBuilder.toString());
}
// Only one rerank type may be provided
if (rerankTypes.size() > 1) {
StringJoiner msgBuilder = new StringJoiner(", ", "Multiple rerank types found: [", "]. Only one is permitted.");
rerankTypes.forEach(rt -> msgBuilder.add(rt));
throw new IllegalArgumentException(msgBuilder.toString());
}
return RerankType.from(rerankTypes.iterator().next());
}

/**
* Factory class for context fetchers. Constructs a list of context fetchers
* specified in the pipeline config (and maybe the query context fetcher)
*/
private static class ContextFetcherFactory {

/**
* Map rerank types to whether they should include the query context source fetcher
* @param type the constructing RerankType
* @return whether this RerankType depends on the QueryContextSourceFetcher
*/
public static boolean shouldIncludeQueryContextFetcher(RerankType type) {
return type == RerankType.ML_OPENSEARCH;
}

/**
* Create the necessary context source fetchers for this processor
* @param config processor config object; the "context" field lists the fetchers to construct
* @param includeQueryContextFetcher whether to also add the query context fetcher
* @param tag processor tag, used in error reporting
* @param clusterService cluster service passed to fetchers that need cluster-level settings
* @return list of context fetchers for the processor to use
*/
public static List<ContextSourceFetcher> createFetchers(
Map<String, Object> config,
boolean includeQueryContextFetcher,
String tag,
final ClusterService clusterService
) {
Map<String, Object> contextConfig = ConfigurationUtils.readMap(RERANK_PROCESSOR_TYPE, tag, config, CONTEXT_CONFIG_FIELD);
List<ContextSourceFetcher> fetchers = new ArrayList<>();
for (String key : contextConfig.keySet()) {
Object cfg = contextConfig.get(key);
switch (key) {
case DocumentContextSourceFetcher.NAME:
fetchers.add(DocumentContextSourceFetcher.create(cfg, clusterService));
break;
default:
throw new IllegalArgumentException(String.format(Locale.ROOT, "unrecognized context field: %s", key));
}
}
if (includeQueryContextFetcher) {
fetchers.add(new QueryContextSourceFetcher(clusterService));
}
return fetchers;
}
}
}
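Example (not part of the diff): a sketch of the nested config shape this factory reads, built by hand and handed to create(). The model id and the field names "title" and "body" are placeholders, and mutable maps are used because ConfigurationUtils removes entries as it reads them.

import java.util.HashMap;
import java.util.List;
import java.util.Map;

import org.opensearch.neuralsearch.processor.factory.RerankProcessorFactory;
import org.opensearch.neuralsearch.processor.rerank.MLOpenSearchRerankProcessor;
import org.opensearch.neuralsearch.processor.rerank.RerankType;
import org.opensearch.neuralsearch.processor.rerank.context.DocumentContextSourceFetcher;
import org.opensearch.search.pipeline.SearchResponseProcessor;

class RerankFactoryConfigSketch {
    // The factory is assumed to be constructed with a real client accessor and cluster service.
    static SearchResponseProcessor buildProcessor(RerankProcessorFactory factory) {
        // One rerank-type block keyed by its label, plus a "context" block naming the document fields to rerank on.
        Map<String, Object> typeConfig = new HashMap<>();
        typeConfig.put(MLOpenSearchRerankProcessor.MODEL_ID_FIELD, "my-text-similarity-model-id"); // placeholder id
        Map<String, Object> contextConfig = new HashMap<>();
        contextConfig.put(DocumentContextSourceFetcher.NAME, List.of("title", "body")); // placeholder field names
        Map<String, Object> config = new HashMap<>();
        config.put(RerankType.ML_OPENSEARCH.getLabel(), typeConfig);
        config.put(RerankProcessorFactory.CONTEXT_CONFIG_FIELD, contextConfig);
        return factory.create(Map.of(), "example-tag", "example rerank processor", false, config, null);
    }
}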
src/main/java/org/opensearch/neuralsearch/processor/rerank/MLOpenSearchRerankProcessor.java
@@ -0,0 +1,83 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.neuralsearch.processor.rerank;

import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.stream.Collectors;

import org.opensearch.action.search.SearchResponse;
import org.opensearch.core.action.ActionListener;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
import org.opensearch.neuralsearch.processor.factory.RerankProcessorFactory;
import org.opensearch.neuralsearch.processor.rerank.context.ContextSourceFetcher;
import org.opensearch.neuralsearch.processor.rerank.context.DocumentContextSourceFetcher;
import org.opensearch.neuralsearch.processor.rerank.context.QueryContextSourceFetcher;

/**
* Rescoring Rerank Processor that uses a TextSimilarity model in ml-commons to rescore
*/
public class MLOpenSearchRerankProcessor extends RescoringRerankProcessor {

public static final String MODEL_ID_FIELD = "model_id";

protected final String modelId;

protected final MLCommonsClientAccessor mlCommonsClientAccessor;

/**
* Constructor
* @param description processor description
* @param tag processor tag
* @param ignoreFailure whether to ignore failures during reranking
* @param modelId id of TEXT_SIMILARITY model
* @param contextSourceFetchers fetchers that collect the query and document context used for rescoring
* @param mlCommonsClientAccessor client accessor used to invoke the text similarity model
*/
public MLOpenSearchRerankProcessor(
final String description,
final String tag,
final boolean ignoreFailure,
final String modelId,
final List<ContextSourceFetcher> contextSourceFetchers,
final MLCommonsClientAccessor mlCommonsClientAccessor
) {
super(RerankType.ML_OPENSEARCH, description, tag, ignoreFailure, contextSourceFetchers);
this.modelId = modelId;
this.mlCommonsClientAccessor = mlCommonsClientAccessor;
}

@Override
public void rescoreSearchResponse(
final SearchResponse response,
final Map<String, Object> rerankingContext,
final ActionListener<List<Float>> listener
) {
Object ctxObj = rerankingContext.get(DocumentContextSourceFetcher.DOCUMENT_CONTEXT_LIST_FIELD);
if (!(ctxObj instanceof List<?>)) {
listener.onFailure(
new IllegalStateException(
String.format(
Locale.ROOT,
"No document context found! Perhaps \"%s.%s\" is missing from the pipeline definition?",
RerankProcessorFactory.CONTEXT_CONFIG_FIELD,
DocumentContextSourceFetcher.NAME
)
)
);
return;
}
List<?> ctxList = (List<?>) ctxObj;
List<String> contexts = ctxList.stream().map(str -> (String) str).collect(Collectors.toList());
mlCommonsClientAccessor.inferenceSimilarity(
modelId,
(String) rerankingContext.get(QueryContextSourceFetcher.QUERY_TEXT_FIELD),
contexts,
listener
);
}

}
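Example (not part of the diff): a sketch of the reranking context this method expects, assembled by hand rather than by the context source fetchers. The query text and passages are placeholders; the resulting scores are reported one per hit, in hit order.

import java.util.List;
import java.util.Map;

import org.opensearch.action.search.SearchResponse;
import org.opensearch.core.action.ActionListener;
import org.opensearch.neuralsearch.processor.rerank.MLOpenSearchRerankProcessor;
import org.opensearch.neuralsearch.processor.rerank.context.DocumentContextSourceFetcher;
import org.opensearch.neuralsearch.processor.rerank.context.QueryContextSourceFetcher;

class RescoreContextSketch {
    // processor and response are assumed to exist; normally the fetchers build this map from the request and hits.
    static void rescore(MLOpenSearchRerankProcessor processor, SearchResponse response) {
        Map<String, Object> rerankingContext = Map.of(
            QueryContextSourceFetcher.QUERY_TEXT_FIELD, "What is the capital of France?",
            DocumentContextSourceFetcher.DOCUMENT_CONTEXT_LIST_FIELD,
            List.of("Paris is the capital of France.", "Berlin is the capital of Germany.")
        );
        processor.rescoreSearchResponse(response, rerankingContext, ActionListener.wrap(
            scores -> System.out.println("New scores, one per hit, in hit order: " + scores),
            e -> System.err.println("Rescoring failed: " + e.getMessage())
        ));
    }
}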