From 31c61c7c0d514384596342637aefbcd55e9023e8 Mon Sep 17 00:00:00 2001 From: Michael Froh Date: Fri, 10 May 2024 21:57:12 +0000 Subject: [PATCH] Do not pass negative scores into function_score or script_score queries In theory, Lucene scores should never go negative. To stop users from writing `function_score` and `script_score` queries that return negative values, we explicitly check their outputs and throw an exception when negative. Unfortunately, due to a subtle, more complicated bug in multi_match queries, sometimes those might (incorrectly) return negative scores. While that problem is also worth solving, we should protect function and script scoring from throwing an exception just for passing through a negative value that they had no hand in computing. Signed-off-by: Michael Froh --- CHANGELOG.md | 2 +- .../rest-api-spec/test/painless/30_search.yml | 76 ++++++++++++ .../search/function/FunctionScoreQuery.java | 6 +- .../search/function/ScriptScoreFunction.java | 14 ++- .../org/opensearch/script/ScoreScript.java | 4 +- .../index/query/NegativeBoostQuery.java | 114 ++++++++++++++++++ .../functionscore/FunctionScoreTests.java | 19 +++ .../search/query/ScriptScoreQueryTests.java | 24 +++- 8 files changed, 250 insertions(+), 9 deletions(-) create mode 100644 server/src/test/java/org/opensearch/index/query/NegativeBoostQuery.java diff --git a/CHANGELOG.md b/CHANGELOG.md index f5d0ab4f7af38..9396907a107f0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -24,8 +24,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), - Remove handling of index.mapper.dynamic in AutoCreateIndex([#13067](https://github.com/opensearch-project/OpenSearch/pull/13067)) ### Fixed -- Fix negative RequestStats metric issue ([#13553](https://github.com/opensearch-project/OpenSearch/pull/13553)) - Fix get field mapping API returns 404 error in mixed cluster with multiple versions ([#13624](https://github.com/opensearch-project/OpenSearch/pull/13624)) +- Replace negative input scores to function/script score queries with zero to avoid downstream exception ([#13627](https://github.com/opensearch-project/OpenSearch/pull/13627)) ### Security diff --git a/modules/lang-painless/src/yamlRestTest/resources/rest-api-spec/test/painless/30_search.yml b/modules/lang-painless/src/yamlRestTest/resources/rest-api-spec/test/painless/30_search.yml index a006fde630716..e48add37f44b8 100644 --- a/modules/lang-painless/src/yamlRestTest/resources/rest-api-spec/test/painless/30_search.yml +++ b/modules/lang-painless/src/yamlRestTest/resources/rest-api-spec/test/painless/30_search.yml @@ -482,3 +482,79 @@ }] - match: { error.root_cause.0.type: "illegal_argument_exception" } - match: { error.root_cause.0.reason: "script score function must not produce negative scores, but got: [-9.0]"} + +--- +"Do not throw exception if input score is negative": + - do: + index: + index: test + id: 1 + body: { "color" : "orange red yellow" } + - do: + index: + index: test + id: 2 + body: { "color": "orange red purple", "shape": "red square" } + - do: + index: + index: test + id: 3 + body: { "color" : "orange red yellow purple" } + - do: + indices.refresh: { } + - do: + search: + index: test + body: + query: + function_score: + query: + multi_match: + query: "red" + type: "cross_fields" + fields: [ "color", "shape^100"] + tie_breaker: 0.1 + functions: [{ + "script_score": { + "script": { + "lang": "painless", + "source": "_score" + } + } + }] + explain: true + - match: { hits.total.value: 3 } + - match: { hits.hits.2._score: 0.0 } + - do: + search: + index: test + body: + query: + function_score: + query: + multi_match: + query: "red" + type: "cross_fields" + fields: [ "color", "shape^100"] + tie_breaker: 0.1 + weight: 1 + explain: true + - match: { hits.total.value: 3 } + - match: { hits.hits.2._score: 0.0 } + - do: + search: + index: test + body: + query: + script_score: + query: + multi_match: + query: "red" + type: "cross_fields" + fields: [ "color", "shape^100"] + tie_breaker: 0.1 + script: + source: "_score" + explain: true + - match: { hits.total.value: 3 } + - match: { hits.hits.2._score: 0.0 } diff --git a/server/src/main/java/org/opensearch/common/lucene/search/function/FunctionScoreQuery.java b/server/src/main/java/org/opensearch/common/lucene/search/function/FunctionScoreQuery.java index cb93e80288a98..512dec9f1f355 100644 --- a/server/src/main/java/org/opensearch/common/lucene/search/function/FunctionScoreQuery.java +++ b/server/src/main/java/org/opensearch/common/lucene/search/function/FunctionScoreQuery.java @@ -533,8 +533,10 @@ public float score() throws IOException { int docId = docID(); // Even if the weight is created with needsScores=false, it might // be costly to call score(), so we explicitly check if scores - // are needed - float subQueryScore = needsScores ? super.score() : 0f; + // are needed. + // While the function scorer should never turn a score negative, we + // must guard against the input score being negative. + float subQueryScore = needsScores ? Math.max(0f, super.score()) : 0f; if (leafFunctions.length == 0) { return subQueryScore; } diff --git a/server/src/main/java/org/opensearch/common/lucene/search/function/ScriptScoreFunction.java b/server/src/main/java/org/opensearch/common/lucene/search/function/ScriptScoreFunction.java index 38c356a8be4b0..146dfd4440e2f 100644 --- a/server/src/main/java/org/opensearch/common/lucene/search/function/ScriptScoreFunction.java +++ b/server/src/main/java/org/opensearch/common/lucene/search/function/ScriptScoreFunction.java @@ -52,8 +52,14 @@ public class ScriptScoreFunction extends ScoreFunction { static final class CannedScorer extends Scorable { - protected int docid; - protected float score; + private int docid; + private float score; + + public void score(float subScore) { + // We check to make sure the script score function never makes a score negative, but we need to make + // sure the script score function does not receive negative input. + this.score = Math.max(0.0f, subScore); + } @Override public int docID() { @@ -105,7 +111,7 @@ public LeafScoreFunction getLeafScoreFunction(LeafReaderContext ctx) throws IOEx public double score(int docId, float subQueryScore) throws IOException { leafScript.setDocument(docId); scorer.docid = docId; - scorer.score = subQueryScore; + scorer.score(subQueryScore); double result = leafScript.execute(null); if (result < 0f) { throw new IllegalArgumentException("script score function must not produce negative scores, but got: [" + result + "]"); @@ -119,7 +125,7 @@ public Explanation explainScore(int docId, Explanation subQueryScore) throws IOE if (leafScript instanceof ExplainableScoreScript) { leafScript.setDocument(docId); scorer.docid = docId; - scorer.score = subQueryScore.getValue().floatValue(); + scorer.score(subQueryScore.getValue().floatValue()); exp = ((ExplainableScoreScript) leafScript).explain(subQueryScore, functionName); } else { double score = score(docId, subQueryScore.getValue().floatValue()); diff --git a/server/src/main/java/org/opensearch/script/ScoreScript.java b/server/src/main/java/org/opensearch/script/ScoreScript.java index 70de636a655f2..63bf9f43d133f 100644 --- a/server/src/main/java/org/opensearch/script/ScoreScript.java +++ b/server/src/main/java/org/opensearch/script/ScoreScript.java @@ -165,7 +165,9 @@ public void setDocument(int docid) { public void setScorer(Scorable scorer) { this.scoreSupplier = () -> { try { - return scorer.score(); + // The ScoreScript is forbidden from returning a negative value. + // We should guard against receiving negative input. + return Math.max(0f, scorer.score()); } catch (IOException e) { throw new UncheckedIOException(e); } diff --git a/server/src/test/java/org/opensearch/index/query/NegativeBoostQuery.java b/server/src/test/java/org/opensearch/index/query/NegativeBoostQuery.java new file mode 100644 index 0000000000000..06b29cd1c1303 --- /dev/null +++ b/server/src/test/java/org/opensearch/index/query/NegativeBoostQuery.java @@ -0,0 +1,114 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.index.query; + +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.Explanation; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.QueryVisitor; +import org.apache.lucene.search.ScoreMode; +import org.apache.lucene.search.Scorer; +import org.apache.lucene.search.Weight; + +import java.io.IOException; + +/** + * Similar to Lucene's BoostQuery, but will accept negative boost values (which is normally wrong, since scores + * should not be negative). Useful for testing that other query types guard against negative input scores. + */ +public class NegativeBoostQuery extends Query { + private final Query query; + private final float boost; + + public NegativeBoostQuery(Query query, float boost) { + if (boost >= 0) { + throw new IllegalArgumentException("Expected negative boost. Use BoostQuery if boost is non-negative."); + } + this.boost = boost; + this.query = query; + } + + @Override + public String toString(String field) { + StringBuilder builder = new StringBuilder(); + builder.append("("); + builder.append(query.toString(field)); + builder.append(")"); + builder.append("^"); + builder.append(boost); + return builder.toString(); + } + + @Override + public void visit(QueryVisitor visitor) { + query.visit(visitor); + } + + @Override + public boolean equals(Object other) { + return sameClassAs(other) && equalsTo(getClass().cast(other)); + } + + private boolean equalsTo(NegativeBoostQuery other) { + return query.equals(other.query) && Float.floatToIntBits(boost) == Float.floatToIntBits(other.boost); + } + + @Override + public int hashCode() { + int h = classHash(); + h = 31 * h + query.hashCode(); + h = 31 * h + Float.floatToIntBits(boost); + return h; + } + + @Override + public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException { + final float negativeBoost = this.boost; + Weight delegate = query.createWeight(searcher, scoreMode, boost); + return new Weight(this) { + @Override + public Explanation explain(LeafReaderContext context, int doc) throws IOException { + return delegate.explain(context, doc); + } + + @Override + public Scorer scorer(LeafReaderContext context) throws IOException { + Scorer delegateScorer = delegate.scorer(context); + return new Scorer(this) { + @Override + public DocIdSetIterator iterator() { + return delegateScorer.iterator(); + } + + @Override + public float getMaxScore(int upTo) throws IOException { + return delegateScorer.getMaxScore(upTo); + } + + @Override + public float score() throws IOException { + return delegateScorer.score() * negativeBoost; + } + + @Override + public int docID() { + return delegateScorer.docID(); + } + }; + } + + @Override + public boolean isCacheable(LeafReaderContext ctx) { + return delegate.isCacheable(ctx); + } + }; + } +} diff --git a/server/src/test/java/org/opensearch/index/query/functionscore/FunctionScoreTests.java b/server/src/test/java/org/opensearch/index/query/functionscore/FunctionScoreTests.java index 0ea91efc568d0..02ffd97835ef0 100644 --- a/server/src/test/java/org/opensearch/index/query/functionscore/FunctionScoreTests.java +++ b/server/src/test/java/org/opensearch/index/query/functionscore/FunctionScoreTests.java @@ -74,6 +74,7 @@ import org.opensearch.index.fielddata.ScriptDocValues; import org.opensearch.index.fielddata.SortedBinaryDocValues; import org.opensearch.index.fielddata.SortedNumericDoubleValues; +import org.opensearch.index.query.NegativeBoostQuery; import org.opensearch.search.DocValueFormat; import org.opensearch.search.MultiValueMode; import org.opensearch.search.aggregations.support.ValuesSourceType; @@ -1095,6 +1096,24 @@ public void testExceptionOnNegativeScores() { assertThat(exc.getMessage(), not(containsString("consider using log1p or log2p instead of log to avoid negative scores"))); } + public void testNoExceptionOnNegativeScoreInput() throws IOException { + IndexSearcher localSearcher = new IndexSearcher(reader); + TermQuery termQuery = new TermQuery(new Term(FIELD, "out")); + + // test that field_value_factor function throws an exception on negative scores + FieldValueFactorFunction.Modifier modifier = FieldValueFactorFunction.Modifier.NONE; + + final ScoreFunction fvfFunction = new FieldValueFactorFunction(FIELD, 1, modifier, 1.0, new IndexNumericFieldDataStub()); + FunctionScoreQuery fsQuery1 = new FunctionScoreQuery( + new NegativeBoostQuery(termQuery, -10f), + fvfFunction, + CombineFunction.MULTIPLY, + null, + Float.POSITIVE_INFINITY + ); + localSearcher.search(fsQuery1, 1); + } + public void testExceptionOnLnNegativeScores() { IndexSearcher localSearcher = new IndexSearcher(reader); TermQuery termQuery = new TermQuery(new Term(FIELD, "out")); diff --git a/server/src/test/java/org/opensearch/search/query/ScriptScoreQueryTests.java b/server/src/test/java/org/opensearch/search/query/ScriptScoreQueryTests.java index 55c50b8cf854d..6d3fdb8f06a0b 100644 --- a/server/src/test/java/org/opensearch/search/query/ScriptScoreQueryTests.java +++ b/server/src/test/java/org/opensearch/search/query/ScriptScoreQueryTests.java @@ -39,6 +39,7 @@ import org.apache.lucene.index.DirectoryReader; import org.apache.lucene.index.IndexWriter; import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.Term; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.Explanation; import org.apache.lucene.search.IndexSearcher; @@ -46,12 +47,15 @@ import org.apache.lucene.search.Query; import org.apache.lucene.search.ScoreMode; import org.apache.lucene.search.Scorer; +import org.apache.lucene.search.TermQuery; +import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.TwoPhaseIterator; import org.apache.lucene.search.Weight; import org.apache.lucene.store.Directory; import org.opensearch.Version; import org.opensearch.common.lucene.search.Queries; import org.opensearch.common.lucene.search.function.ScriptScoreQuery; +import org.opensearch.index.query.NegativeBoostQuery; import org.opensearch.script.ScoreScript; import org.opensearch.script.Script; import org.opensearch.script.ScriptType; @@ -64,6 +68,7 @@ import java.io.IOException; import java.util.HashMap; import java.util.Map; +import java.util.function.BiFunction; import java.util.function.Function; import static org.hamcrest.CoreMatchers.containsString; @@ -185,6 +190,15 @@ public void testScriptScoreErrorOnNegativeScore() { assertTrue(e.getMessage().contains("Must be a non-negative score!")); } + public void testNoExceptionOnNegativeInputScore() throws IOException { + Script script = new Script("script that returns _score"); + ScoreScript.LeafFactory factory = newFactory(script, true, (s, e) -> s.get_score()); + NegativeBoostQuery negativeBoostQuery = new NegativeBoostQuery(new TermQuery(new Term("field", "text")), -10.0f); + ScriptScoreQuery query = new ScriptScoreQuery(negativeBoostQuery, script, factory, -1f, "index", 0, Version.CURRENT); + TopDocs topDocs = searcher.search(query, 1); + assertEquals(0.0f, topDocs.scoreDocs[0].score, 0.0001); + } + public void testTwoPhaseIteratorDelegation() throws IOException { Map params = new HashMap<>(); String scriptSource = "doc['field'].value != null ? 2.0 : 0.0"; // Adjust based on actual field and logic @@ -220,6 +234,14 @@ private ScoreScript.LeafFactory newFactory( Script script, boolean needsScore, Function function + ) { + return newFactory(script, needsScore, (s, e) -> function.apply(e)); + } + + private ScoreScript.LeafFactory newFactory( + Script script, + boolean needsScore, + BiFunction function ) { SearchLookup lookup = mock(SearchLookup.class); LeafSearchLookup leafLookup = mock(LeafSearchLookup.class); @@ -236,7 +258,7 @@ public ScoreScript newInstance(LeafReaderContext ctx) throws IOException { return new ScoreScript(script.getParams(), lookup, indexSearcher, leafReaderContext) { @Override public double execute(ExplanationHolder explanation) { - return function.apply(explanation); + return function.apply(this, explanation); } }; }