Skip to content

Commit

Permalink
Added validations for score combination weights in Hybrid Search (#265)
Browse files Browse the repository at this point in the history
* Added strong check on number of weights equals number of sub-queries

Signed-off-by: Martin Gaievski <[email protected]>
  • Loading branch information
martin-gaievski committed Aug 30, 2023
1 parent 75b59cd commit 685d5d6
Show file tree
Hide file tree
Showing 12 changed files with 231 additions and 74 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
* Added Score Normalization and Combination feature ([#241](https://github.com/opensearch-project/neural-search/pull/241/))
### Enhancements
* Changed format for hybrid query results to a single list of scores with delimiter ([#259](https://github.com/opensearch-project/neural-search/pull/259))
* Added validations for score combination weights in Hybrid Search ([#265](https://github.com/opensearch-project/neural-search/pull/265))
### Bug Fixes
### Infrastructure
### Documentation
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ public ArithmeticMeanScoreCombinationTechnique(final Map<String, Object> params,
*/
@Override
public float combine(final float[] scores) {
scoreCombinationUtil.validateIfWeightsMatchScores(scores, weights);
float combinedScore = 0.0f;
float sumOfWeights = 0;
for (int indexOfSubQuery = 0; indexOfSubQuery < scores.length; indexOfSubQuery++) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ public GeometricMeanScoreCombinationTechnique(final Map<String, Object> params,
*/
@Override
public float combine(final float[] scores) {
scoreCombinationUtil.validateIfWeightsMatchScores(scores, weights);
float weightedLnSum = 0;
float sumOfWeights = 0;
for (int indexOfSubQuery = 0; indexOfSubQuery < scores.length; indexOfSubQuery++) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ public HarmonicMeanScoreCombinationTechnique(final Map<String, Object> params, f
*/
@Override
public float combine(final float[] scores) {
scoreCombinationUtil.validateIfWeightsMatchScores(scores, weights);
float sumOfWeights = 0;
float sumOfHarmonics = 0;
for (int indexOfSubQuery = 0; indexOfSubQuery < scores.length; indexOfSubQuery++) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

package org.opensearch.neuralsearch.processor.combination;

import java.util.Arrays;
import java.util.List;
import java.util.Locale;
import java.util.Map;
Expand All @@ -13,11 +14,19 @@
import java.util.Set;
import java.util.stream.Collectors;

import lombok.extern.log4j.Log4j2;

import org.apache.commons.lang3.Range;

import com.google.common.math.DoubleMath;

/**
* Collection of utility methods for score combination technique classes
*/
@Log4j2
class ScoreCombinationUtil {
private static final String PARAM_NAME_WEIGHTS = "weights";
private static final float DELTA_FOR_SCORE_ASSERTION = 0.01f;

/**
* Get collection of weights based on user provided config
Expand All @@ -29,9 +38,11 @@ public List<Float> getWeights(final Map<String, Object> params) {
return List.of();
}
// get weights, we don't need to check for instance as it's done during validation
return ((List<Double>) params.getOrDefault(PARAM_NAME_WEIGHTS, List.of())).stream()
List<Float> weightsList = ((List<Double>) params.getOrDefault(PARAM_NAME_WEIGHTS, List.of())).stream()
.map(Double::floatValue)
.collect(Collectors.toUnmodifiableList());
validateWeights(weightsList);
return weightsList;
}

/**
Expand Down Expand Up @@ -77,4 +88,55 @@ public void validateParams(final Map<String, Object> actualParams, final Set<Str
public float getWeightForSubQuery(final List<Float> weights, final int indexOfSubQuery) {
return indexOfSubQuery < weights.size() ? weights.get(indexOfSubQuery) : 1.0f;
}

/**
* Check if number of weights matches number of queries. This does not apply for case when
* weights were not provided, as this is valid default value
* @param scores collection of scores from all sub-queries of a single hybrid search query
* @param weights score combination weights that are defined as part of search result processor
*/
protected void validateIfWeightsMatchScores(final float[] scores, final List<Float> weights) {
if (weights.isEmpty()) {
return;
}
if (scores.length != weights.size()) {
throw new IllegalArgumentException(
String.format(
Locale.ROOT,
"number of weights [%d] must match number of sub-queries [%d] in hybrid query",
weights.size(),
scores.length
)
);
}
}

/**
* Check if provided weights are valid for combination. Following conditions are checked:
* - every weight is between 0.0 and 1.0
* - sum of all weights must be equal 1.0
* @param weightsList
*/
private void validateWeights(final List<Float> weightsList) {
boolean isOutOfRange = weightsList.stream().anyMatch(weight -> !Range.between(0.0f, 1.0f).contains(weight));
if (isOutOfRange) {
throw new IllegalArgumentException(
String.format(
Locale.ROOT,
"all weights must be in range [0.0 ... 1.0], submitted weights: %s",
Arrays.toString(weightsList.toArray(new Float[0]))
)
);
}
float sumOfWeights = weightsList.stream().reduce(0.0f, Float::sum);
if (!DoubleMath.fuzzyEquals(1.0f, sumOfWeights, DELTA_FOR_SCORE_ASSERTION)) {
throw new IllegalArgumentException(
String.format(
Locale.ROOT,
"sum of weights for combination must be equal to 1.0, submitted weights: %s",
Arrays.toString(weightsList.toArray(new Float[0]))
)
);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

package org.opensearch.neuralsearch.processor;

import static org.hamcrest.Matchers.allOf;
import static org.hamcrest.Matchers.containsString;
import static org.opensearch.neuralsearch.TestUtils.assertHybridSearchResults;
import static org.opensearch.neuralsearch.TestUtils.assertWeightedScores;
import static org.opensearch.neuralsearch.TestUtils.createRandomVector;
Expand All @@ -18,6 +20,7 @@

import org.junit.After;
import org.junit.Before;
import org.opensearch.client.ResponseException;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.neuralsearch.common.BaseNeuralSearchIT;
Expand Down Expand Up @@ -96,7 +99,7 @@ public void testArithmeticWeightedMean_whenWeightsPassed_thenSuccessful() {
SEARCH_PIPELINE,
DEFAULT_NORMALIZATION_METHOD,
DEFAULT_COMBINATION_METHOD,
Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.6f, 0.5f, 0.5f }))
Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.4f, 0.3f, 0.3f }))
);

HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder();
Expand All @@ -112,15 +115,15 @@ public void testArithmeticWeightedMean_whenWeightsPassed_thenSuccessful() {
Map.of("search_pipeline", SEARCH_PIPELINE)
);

assertWeightedScores(searchResponseWithWeights1AsMap, 0.375, 0.3125, 0.001);
assertWeightedScores(searchResponseWithWeights1AsMap, 0.4, 0.3, 0.001);

// delete existing pipeline and create a new one with another set of weights
deleteSearchPipeline(SEARCH_PIPELINE);
createSearchPipeline(
SEARCH_PIPELINE,
DEFAULT_NORMALIZATION_METHOD,
DEFAULT_COMBINATION_METHOD,
Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f, 2.0f, 0.5f }))
Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.233f, 0.666f, 0.1f }))
);

Map<String, Object> searchResponseWithWeights2AsMap = search(
Expand All @@ -131,7 +134,7 @@ public void testArithmeticWeightedMean_whenWeightsPassed_thenSuccessful() {
Map.of("search_pipeline", SEARCH_PIPELINE)
);

assertWeightedScores(searchResponseWithWeights2AsMap, 0.606, 0.242, 0.001);
assertWeightedScores(searchResponseWithWeights2AsMap, 0.6666, 0.2332, 0.001);

// check case when number of weights is less than number of sub-queries
// delete existing pipeline and create a new one with another set of weights
Expand All @@ -140,18 +143,21 @@ public void testArithmeticWeightedMean_whenWeightsPassed_thenSuccessful() {
SEARCH_PIPELINE,
DEFAULT_NORMALIZATION_METHOD,
DEFAULT_COMBINATION_METHOD,
Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f }))
Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 1.0f }))
);

Map<String, Object> searchResponseWithWeights3AsMap = search(
TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME,
hybridQueryBuilder,
null,
5,
Map.of("search_pipeline", SEARCH_PIPELINE)
ResponseException exception1 = expectThrows(
ResponseException.class,
() -> search(TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME, hybridQueryBuilder, null, 5, Map.of("search_pipeline", SEARCH_PIPELINE))
);
org.hamcrest.MatcherAssert.assertThat(
exception1.getMessage(),
allOf(
containsString("number of weights"),
containsString("must match number of sub-queries"),
containsString("in hybrid query")
)
);

assertWeightedScores(searchResponseWithWeights3AsMap, 0.357, 0.285, 0.001);

// check case when number of weights is more than number of sub-queries
// delete existing pipeline and create a new one with another set of weights
Expand All @@ -160,18 +166,21 @@ public void testArithmeticWeightedMean_whenWeightsPassed_thenSuccessful() {
SEARCH_PIPELINE,
DEFAULT_NORMALIZATION_METHOD,
DEFAULT_COMBINATION_METHOD,
Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.6f, 0.5f, 0.5f, 1.5f }))
Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.3f, 0.25f, 0.25f, 0.2f }))
);

Map<String, Object> searchResponseWithWeights4AsMap = search(
TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME,
hybridQueryBuilder,
null,
5,
Map.of("search_pipeline", SEARCH_PIPELINE)
ResponseException exception2 = expectThrows(
ResponseException.class,
() -> search(TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME, hybridQueryBuilder, null, 5, Map.of("search_pipeline", SEARCH_PIPELINE))
);
org.hamcrest.MatcherAssert.assertThat(
exception2.getMessage(),
allOf(
containsString("number of weights"),
containsString("must match number of sub-queries"),
containsString("in hybrid query")
)
);

assertWeightedScores(searchResponseWithWeights4AsMap, 0.375, 0.3125, 0.001);
}

/**
Expand Down Expand Up @@ -199,7 +208,7 @@ public void testHarmonicMeanCombination_whenOneShardAndQueryMatches_thenSuccessf
SEARCH_PIPELINE,
DEFAULT_NORMALIZATION_METHOD,
HARMONIC_MEAN_COMBINATION_METHOD,
Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f, 0.7f }))
Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.533f, 0.466f }))
);
String modelId = getDeployedModelId();

Expand All @@ -223,7 +232,7 @@ public void testHarmonicMeanCombination_whenOneShardAndQueryMatches_thenSuccessf
SEARCH_PIPELINE,
L2_NORMALIZATION_METHOD,
HARMONIC_MEAN_COMBINATION_METHOD,
Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f, 0.7f }))
Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.533f, 0.466f }))
);

HybridQueryBuilder hybridQueryBuilderL2Norm = new HybridQueryBuilder();
Expand Down Expand Up @@ -265,7 +274,7 @@ public void testGeometricMeanCombination_whenOneShardAndQueryMatches_thenSuccess
SEARCH_PIPELINE,
DEFAULT_NORMALIZATION_METHOD,
GEOMETRIC_MEAN_COMBINATION_METHOD,
Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f, 0.7f }))
Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.533f, 0.466f }))
);
String modelId = getDeployedModelId();

Expand All @@ -289,7 +298,7 @@ public void testGeometricMeanCombination_whenOneShardAndQueryMatches_thenSuccess
SEARCH_PIPELINE,
L2_NORMALIZATION_METHOD,
GEOMETRIC_MEAN_COMBINATION_METHOD,
Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f, 0.7f }))
Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.533f, 0.466f }))
);

HybridQueryBuilder hybridQueryBuilderL2Norm = new HybridQueryBuilder();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ public void testL2Norm_whenOneShardAndQueryMatches_thenSuccessful() {
SEARCH_PIPELINE,
L2_NORMALIZATION_METHOD,
DEFAULT_COMBINATION_METHOD,
Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f, 0.7f }))
Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.533f, 0.466f }))
);
String modelId = getDeployedModelId();

Expand All @@ -115,7 +115,7 @@ public void testL2Norm_whenOneShardAndQueryMatches_thenSuccessful() {
SEARCH_PIPELINE,
L2_NORMALIZATION_METHOD,
HARMONIC_MEAN_COMBINATION_METHOD,
Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f, 0.7f }))
Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.533f, 0.466f }))
);

HybridQueryBuilder hybridQueryBuilderHarmonicMean = new HybridQueryBuilder();
Expand All @@ -138,7 +138,7 @@ public void testL2Norm_whenOneShardAndQueryMatches_thenSuccessful() {
SEARCH_PIPELINE,
L2_NORMALIZATION_METHOD,
GEOMETRIC_MEAN_COMBINATION_METHOD,
Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f, 0.7f }))
Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.533f, 0.466f }))
);

HybridQueryBuilder hybridQueryBuilderGeometricMean = new HybridQueryBuilder();
Expand Down Expand Up @@ -180,7 +180,7 @@ public void testMinMaxNorm_whenOneShardAndQueryMatches_thenSuccessful() {
SEARCH_PIPELINE,
DEFAULT_NORMALIZATION_METHOD,
DEFAULT_COMBINATION_METHOD,
Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f, 0.7f }))
Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.533f, 0.466f }))
);
String modelId = getDeployedModelId();

Expand All @@ -204,7 +204,7 @@ public void testMinMaxNorm_whenOneShardAndQueryMatches_thenSuccessful() {
SEARCH_PIPELINE,
DEFAULT_NORMALIZATION_METHOD,
HARMONIC_MEAN_COMBINATION_METHOD,
Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f, 0.7f }))
Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.533f, 0.466f }))
);

HybridQueryBuilder hybridQueryBuilderHarmonicMean = new HybridQueryBuilder();
Expand All @@ -227,7 +227,7 @@ public void testMinMaxNorm_whenOneShardAndQueryMatches_thenSuccessful() {
SEARCH_PIPELINE,
DEFAULT_NORMALIZATION_METHOD,
GEOMETRIC_MEAN_COMBINATION_METHOD,
Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f, 0.7f }))
Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.533f, 0.466f }))
);

HybridQueryBuilder hybridQueryBuilderGeometricMean = new HybridQueryBuilder();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@
import java.util.stream.Collectors;
import java.util.stream.IntStream;

import com.carrotsearch.randomizedtesting.RandomizedTest;

public class ArithmeticMeanScoreCombinationTechniqueTests extends BaseScoreCombinationTechniqueTests {

private ScoreCombinationUtil scoreCombinationUtil = new ScoreCombinationUtil();
Expand All @@ -33,9 +31,7 @@ public void testLogic_whenNotAllScoresPresentAndNoWeights_thenCorrectScores() {
}

public void testRandomValues_whenAllScoresAndWeightsPresent_thenCorrectScores() {
List<Double> weights = IntStream.range(0, RANDOM_SCORES_SIZE)
.mapToObj(i -> RandomizedTest.randomDouble())
.collect(Collectors.toList());
List<Double> weights = IntStream.range(0, RANDOM_SCORES_SIZE).mapToObj(i -> 1.0 / RANDOM_SCORES_SIZE).collect(Collectors.toList());
ScoreCombinationTechnique technique = new ArithmeticMeanScoreCombinationTechnique(
Map.of(PARAM_NAME_WEIGHTS, weights),
scoreCombinationUtil
Expand All @@ -44,20 +40,18 @@ public void testRandomValues_whenAllScoresAndWeightsPresent_thenCorrectScores()
}

public void testLogic_whenNotAllScoresAndWeightsPresent_thenCorrectScores() {
List<Float> scores = List.of(1.0f, -1.0f, 0.6f);
List<Double> weights = List.of(0.9, 0.2, 0.7);
List<Float> scores = List.of(1.0f, 0.0f, 0.6f);
List<Double> weights = List.of(0.45, 0.15, 0.4);
ScoreCombinationTechnique technique = new ArithmeticMeanScoreCombinationTechnique(
Map.of(PARAM_NAME_WEIGHTS, weights),
scoreCombinationUtil
);
float expectedScore = 0.825f;
float expectedScore = 0.69f;
testLogic_whenNotAllScoresAndWeightsPresent_thenCorrectScores(technique, scores, expectedScore);
}

public void testRandomValues_whenNotAllScoresAndWeightsPresent_thenCorrectScores() {
List<Double> weights = IntStream.range(0, RANDOM_SCORES_SIZE)
.mapToObj(i -> RandomizedTest.randomDouble())
.collect(Collectors.toList());
List<Double> weights = IntStream.range(0, RANDOM_SCORES_SIZE).mapToObj(i -> 1.0 / RANDOM_SCORES_SIZE).collect(Collectors.toList());
ScoreCombinationTechnique technique = new ArithmeticMeanScoreCombinationTechnique(
Map.of(PARAM_NAME_WEIGHTS, weights),
scoreCombinationUtil
Expand Down
Loading

0 comments on commit 685d5d6

Please sign in to comment.