Skip to content

Commit

Permalink
Do not pass negative scores into function_score or script_score queries
Browse files Browse the repository at this point in the history
In theory, Lucene scores should never go negative. To stop users from
writing `function_score` and `script_score` queries that return
negative values, we explicitly check their outputs and throw an
exception when negative.

Unfortunately, due to a subtle, more complicated bug in multi_match
queries, sometimes those might (incorrectly) return negative scores.

While that problem is also worth solving, we should protect function
and script scoring from throwing an exception just for passing through
a negative value that they had no hand in computing.

Signed-off-by: Michael Froh <[email protected]>
  • Loading branch information
msfroh committed May 24, 2024
1 parent 56d8dc6 commit 983e43c
Show file tree
Hide file tree
Showing 8 changed files with 250 additions and 9 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,9 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
- Remove handling of index.mapper.dynamic in AutoCreateIndex([#13067](https://github.com/opensearch-project/OpenSearch/pull/13067))

### Fixed
- Fix negative RequestStats metric issue ([#13553](https://github.com/opensearch-project/OpenSearch/pull/13553))
- Fix get field mapping API returns 404 error in mixed cluster with multiple versions ([#13624](https://github.com/opensearch-project/OpenSearch/pull/13624))
- Allow clearing `remote_store.compatibility_mode` setting ([#13646](https://github.com/opensearch-project/OpenSearch/pull/13646))
- Replace negative input scores to function/script score queries with zero to avoid downstream exception ([#13627](https://github.com/opensearch-project/OpenSearch/pull/13627))
- Fix ReplicaShardBatchAllocator to batch shards without duplicates ([#13710](https://github.com/opensearch-project/OpenSearch/pull/13710))

### Security
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -482,3 +482,79 @@
}]
- match: { error.root_cause.0.type: "illegal_argument_exception" }
- match: { error.root_cause.0.reason: "script score function must not produce negative scores, but got: [-9.0]"}

---
"Do not throw exception if input score is negative":
- do:
index:
index: test
id: 1
body: { "color" : "orange red yellow" }
- do:
index:
index: test
id: 2
body: { "color": "orange red purple", "shape": "red square" }
- do:
index:
index: test
id: 3
body: { "color" : "orange red yellow purple" }
- do:
indices.refresh: { }
- do:
search:
index: test
body:
query:
function_score:
query:
multi_match:
query: "red"
type: "cross_fields"
fields: [ "color", "shape^100"]
tie_breaker: 0.1
functions: [{
"script_score": {
"script": {
"lang": "painless",
"source": "_score"
}
}
}]
explain: true
- match: { hits.total.value: 3 }
- match: { hits.hits.2._score: 0.0 }
- do:
search:
index: test
body:
query:
function_score:
query:
multi_match:
query: "red"
type: "cross_fields"
fields: [ "color", "shape^100"]
tie_breaker: 0.1
weight: 1
explain: true
- match: { hits.total.value: 3 }
- match: { hits.hits.2._score: 0.0 }
- do:
search:
index: test
body:
query:
script_score:
query:
multi_match:
query: "red"
type: "cross_fields"
fields: [ "color", "shape^100"]
tie_breaker: 0.1
script:
source: "_score"
explain: true
- match: { hits.total.value: 3 }
- match: { hits.hits.2._score: 0.0 }
Original file line number Diff line number Diff line change
Expand Up @@ -533,8 +533,10 @@ public float score() throws IOException {
int docId = docID();
// Even if the weight is created with needsScores=false, it might
// be costly to call score(), so we explicitly check if scores
// are needed
float subQueryScore = needsScores ? super.score() : 0f;
// are needed.
// While the function scorer should never turn a score negative, we
// must guard against the input score being negative.
float subQueryScore = needsScores ? Math.max(0f, super.score()) : 0f;
if (leafFunctions.length == 0) {
return subQueryScore;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,14 @@
public class ScriptScoreFunction extends ScoreFunction {

static final class CannedScorer extends Scorable {
protected int docid;
protected float score;
private int docid;
private float score;

public void score(float subScore) {
// We check to make sure the script score function never makes a score negative, but we need to make
// sure the script score function does not receive negative input.
this.score = Math.max(0.0f, subScore);
}

@Override
public int docID() {
Expand Down Expand Up @@ -105,7 +111,7 @@ public LeafScoreFunction getLeafScoreFunction(LeafReaderContext ctx) throws IOEx
public double score(int docId, float subQueryScore) throws IOException {
leafScript.setDocument(docId);
scorer.docid = docId;
scorer.score = subQueryScore;
scorer.score(subQueryScore);
double result = leafScript.execute(null);
if (result < 0f) {
throw new IllegalArgumentException("script score function must not produce negative scores, but got: [" + result + "]");
Expand All @@ -119,7 +125,7 @@ public Explanation explainScore(int docId, Explanation subQueryScore) throws IOE
if (leafScript instanceof ExplainableScoreScript) {
leafScript.setDocument(docId);
scorer.docid = docId;
scorer.score = subQueryScore.getValue().floatValue();
scorer.score(subQueryScore.getValue().floatValue());
exp = ((ExplainableScoreScript) leafScript).explain(subQueryScore, functionName);
} else {
double score = score(docId, subQueryScore.getValue().floatValue());
Expand Down
4 changes: 3 additions & 1 deletion server/src/main/java/org/opensearch/script/ScoreScript.java
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,9 @@ public void setDocument(int docid) {
public void setScorer(Scorable scorer) {
this.scoreSupplier = () -> {
try {
return scorer.score();
// The ScoreScript is forbidden from returning a negative value.
// We should guard against receiving negative input.
return Math.max(0f, scorer.score());
} catch (IOException e) {
throw new UncheckedIOException(e);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

package org.opensearch.index.query;

import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.Explanation;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.QueryVisitor;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.Weight;

import java.io.IOException;

/**
* Similar to Lucene's BoostQuery, but will accept negative boost values (which is normally wrong, since scores
* should not be negative). Useful for testing that other query types guard against negative input scores.
*/
public class NegativeBoostQuery extends Query {
private final Query query;
private final float boost;

public NegativeBoostQuery(Query query, float boost) {
if (boost >= 0) {
throw new IllegalArgumentException("Expected negative boost. Use BoostQuery if boost is non-negative.");
}
this.boost = boost;
this.query = query;
}

@Override
public String toString(String field) {
StringBuilder builder = new StringBuilder();
builder.append("(");
builder.append(query.toString(field));
builder.append(")");
builder.append("^");
builder.append(boost);
return builder.toString();
}

@Override
public void visit(QueryVisitor visitor) {
query.visit(visitor);
}

@Override
public boolean equals(Object other) {
return sameClassAs(other) && equalsTo(getClass().cast(other));
}

private boolean equalsTo(NegativeBoostQuery other) {
return query.equals(other.query) && Float.floatToIntBits(boost) == Float.floatToIntBits(other.boost);
}

@Override
public int hashCode() {
int h = classHash();
h = 31 * h + query.hashCode();
h = 31 * h + Float.floatToIntBits(boost);
return h;
}

@Override
public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException {
final float negativeBoost = this.boost;
Weight delegate = query.createWeight(searcher, scoreMode, boost);
return new Weight(this) {
@Override
public Explanation explain(LeafReaderContext context, int doc) throws IOException {
return delegate.explain(context, doc);
}

@Override
public Scorer scorer(LeafReaderContext context) throws IOException {
Scorer delegateScorer = delegate.scorer(context);
return new Scorer(this) {
@Override
public DocIdSetIterator iterator() {
return delegateScorer.iterator();
}

@Override
public float getMaxScore(int upTo) throws IOException {
return delegateScorer.getMaxScore(upTo);
}

@Override
public float score() throws IOException {
return delegateScorer.score() * negativeBoost;
}

@Override
public int docID() {
return delegateScorer.docID();
}
};
}

@Override
public boolean isCacheable(LeafReaderContext ctx) {
return delegate.isCacheable(ctx);
}
};
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@
import org.opensearch.index.fielddata.ScriptDocValues;
import org.opensearch.index.fielddata.SortedBinaryDocValues;
import org.opensearch.index.fielddata.SortedNumericDoubleValues;
import org.opensearch.index.query.NegativeBoostQuery;
import org.opensearch.search.DocValueFormat;
import org.opensearch.search.MultiValueMode;
import org.opensearch.search.aggregations.support.ValuesSourceType;
Expand Down Expand Up @@ -1095,6 +1096,24 @@ public void testExceptionOnNegativeScores() {
assertThat(exc.getMessage(), not(containsString("consider using log1p or log2p instead of log to avoid negative scores")));
}

public void testNoExceptionOnNegativeScoreInput() throws IOException {
IndexSearcher localSearcher = new IndexSearcher(reader);
TermQuery termQuery = new TermQuery(new Term(FIELD, "out"));

// test that field_value_factor function throws an exception on negative scores
FieldValueFactorFunction.Modifier modifier = FieldValueFactorFunction.Modifier.NONE;

final ScoreFunction fvfFunction = new FieldValueFactorFunction(FIELD, 1, modifier, 1.0, new IndexNumericFieldDataStub());
FunctionScoreQuery fsQuery1 = new FunctionScoreQuery(
new NegativeBoostQuery(termQuery, -10f),
fvfFunction,
CombineFunction.MULTIPLY,
null,
Float.POSITIVE_INFINITY
);
localSearcher.search(fsQuery1, 1);
}

public void testExceptionOnLnNegativeScores() {
IndexSearcher localSearcher = new IndexSearcher(reader);
TermQuery termQuery = new TermQuery(new Term(FIELD, "out"));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,19 +39,23 @@
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.Explanation;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.MatchAllDocsQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.TermQuery;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TwoPhaseIterator;
import org.apache.lucene.search.Weight;
import org.apache.lucene.store.Directory;
import org.opensearch.Version;
import org.opensearch.common.lucene.search.Queries;
import org.opensearch.common.lucene.search.function.ScriptScoreQuery;
import org.opensearch.index.query.NegativeBoostQuery;
import org.opensearch.script.ScoreScript;
import org.opensearch.script.Script;
import org.opensearch.script.ScriptType;
Expand All @@ -64,6 +68,7 @@
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import java.util.function.BiFunction;
import java.util.function.Function;

import static org.hamcrest.CoreMatchers.containsString;
Expand Down Expand Up @@ -185,6 +190,15 @@ public void testScriptScoreErrorOnNegativeScore() {
assertTrue(e.getMessage().contains("Must be a non-negative score!"));
}

public void testNoExceptionOnNegativeInputScore() throws IOException {
Script script = new Script("script that returns _score");
ScoreScript.LeafFactory factory = newFactory(script, true, (s, e) -> s.get_score());
NegativeBoostQuery negativeBoostQuery = new NegativeBoostQuery(new TermQuery(new Term("field", "text")), -10.0f);
ScriptScoreQuery query = new ScriptScoreQuery(negativeBoostQuery, script, factory, -1f, "index", 0, Version.CURRENT);
TopDocs topDocs = searcher.search(query, 1);
assertEquals(0.0f, topDocs.scoreDocs[0].score, 0.0001);
}

public void testTwoPhaseIteratorDelegation() throws IOException {
Map<String, Object> params = new HashMap<>();
String scriptSource = "doc['field'].value != null ? 2.0 : 0.0"; // Adjust based on actual field and logic
Expand Down Expand Up @@ -220,6 +234,14 @@ private ScoreScript.LeafFactory newFactory(
Script script,
boolean needsScore,
Function<ScoreScript.ExplanationHolder, Double> function
) {
return newFactory(script, needsScore, (s, e) -> function.apply(e));
}

private ScoreScript.LeafFactory newFactory(
Script script,
boolean needsScore,
BiFunction<ScoreScript, ScoreScript.ExplanationHolder, Double> function
) {
SearchLookup lookup = mock(SearchLookup.class);
LeafSearchLookup leafLookup = mock(LeafSearchLookup.class);
Expand All @@ -236,7 +258,7 @@ public ScoreScript newInstance(LeafReaderContext ctx) throws IOException {
return new ScoreScript(script.getParams(), lookup, indexSearcher, leafReaderContext) {
@Override
public double execute(ExplanationHolder explanation) {
return function.apply(explanation);
return function.apply(this, explanation);
}
};
}
Expand Down

0 comments on commit 983e43c

Please sign in to comment.