Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[bug fix] Fix async actions are left in neural_sparse query #438

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,10 @@ public NeuralSparseQueryBuilder(StreamInput in) throws IOException {
this.queryText = in.readString();
this.modelId = in.readString();
this.maxTokenScore = in.readOptionalFloat();
if (in.readBoolean()) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this a backward compatible change? I think for 2.10 and below versions there will be no boolean at all.
Plus the read - I think you can do same with readOptional method

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

to clarify my previous comment - that can be not backward compatible in terms of a rolling cluster upgrades. With that different cluster nodes may run different versions, say cluster is now on 2.8 and user wants to do rolling upgrade to 2.11 (some general documentation on this).
In such case query may be serialized at 2.8 node (without new field at all) and deserialized an a 2.11 node. If we just try to read the boolean there will be a runtime exception. We need to check minimal supported version before even reading boolean flag.
The reverse scenario is also possible: query serialized on 2.11 node and read on 2.8 node, in such case we shouldn't write to the stream.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This class is newly introduced in 2.11. No issue with backward compatibility.

Map<String, Float> queryTokens = in.readMap(StreamInput::readString, StreamInput::readFloat);
this.queryTokensSupplier = () -> queryTokens;
}
}

@Override
Expand All @@ -97,6 +101,12 @@ protected void doWriteTo(StreamOutput out) throws IOException {
out.writeString(queryText);
out.writeString(modelId);
out.writeOptionalFloat(maxTokenScore);
if (queryTokensSupplier != null && queryTokensSupplier.get() != null) {
out.writeBoolean(true);
out.writeMap(queryTokensSupplier.get(), StreamOutput::writeString, StreamOutput::writeFloat);
} else {
out.writeBoolean(false);
}
}

@Override
Expand Down Expand Up @@ -276,16 +286,25 @@ private static void validateQueryTokens(Map<String, Float> queryTokens) {
protected boolean doEquals(NeuralSparseQueryBuilder obj) {
if (this == obj) return true;
if (obj == null || getClass() != obj.getClass()) return false;
if (queryTokensSupplier == null && obj.queryTokensSupplier != null) return false;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: we need to change formatting in this method, every if case body must be in curly braces, plus for null/not null checks we can utilize Objects static methods. I'm fine with doing this is future, not a blocker for this PR

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's handle the issue when we backport it to 2.x

if (queryTokensSupplier != null && obj.queryTokensSupplier == null) return false;
EqualsBuilder equalsBuilder = new EqualsBuilder().append(fieldName, obj.fieldName)
.append(queryText, obj.queryText)
.append(modelId, obj.modelId)
.append(maxTokenScore, obj.maxTokenScore);
if (queryTokensSupplier != null) {
equalsBuilder.append(queryTokensSupplier.get(), obj.queryTokensSupplier.get());
}
return equalsBuilder.isEquals();
}

@Override
protected int doHashCode() {
return new HashCodeBuilder().append(fieldName).append(queryText).append(modelId).append(maxTokenScore).toHashCode();
HashCodeBuilder builder = new HashCodeBuilder().append(fieldName).append(queryText).append(modelId).append(maxTokenScore);
if (queryTokensSupplier != null) {
builder.append(queryTokensSupplier.get());
}
return builder.toHashCode();
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import lombok.SneakyThrows;

import org.opensearch.client.Client;
import org.opensearch.common.SetOnce;
import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.common.xcontent.XContentFactory;
import org.opensearch.core.action.ActionListener;
Expand Down Expand Up @@ -294,6 +295,23 @@ public void testStreams() {

NeuralSparseQueryBuilder copy = new NeuralSparseQueryBuilder(filterStreamInput);
assertEquals(original, copy);

SetOnce<Map<String, Float>> queryTokensSetOnce = new SetOnce<>();
queryTokensSetOnce.set(Map.of("hello", 1.0f, "world", 2.0f));
original.queryTokensSupplier(queryTokensSetOnce::get);

streamOutput = new BytesStreamOutput();
Copy link
Member

@martin-gaievski martin-gaievski Oct 11, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we create a new instance of ByteStreamOutput instead of re-using existent one. That improves readability and lower chance of error. Same for line 306, instance of NamedWriteableAwareStreamInput. Can be done in followup PR

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's handle the issue when we backport it to 2.x or as a follow up pr.

original.writeTo(streamOutput);

filterStreamInput = new NamedWriteableAwareStreamInput(
streamOutput.bytes().streamInput(),
new NamedWriteableRegistry(
List.of(new NamedWriteableRegistry.Entry(QueryBuilder.class, MatchAllQueryBuilder.NAME, MatchAllQueryBuilder::new))
)
);

copy = new NeuralSparseQueryBuilder(filterStreamInput);
assertEquals(original, copy);
}

public void testHashAndEquals() {
Expand All @@ -309,6 +327,8 @@ public void testHashAndEquals() {
float boost2 = 3.8f;
String queryName1 = "query-1";
String queryName2 = "query-2";
Map<String, Float> queryTokens1 = Map.of("hello", 1.0f, "world", 2.0f);
Map<String, Float> queryTokens2 = Map.of("hello", 1.0f, "world", 2.2f);

NeuralSparseQueryBuilder sparseEncodingQueryBuilder_baseline = new NeuralSparseQueryBuilder().fieldName(fieldName1)
.queryText(queryText1)
Expand Down Expand Up @@ -379,6 +399,24 @@ public void testHashAndEquals() {
.boost(boost1)
.queryName(queryName1);

// Identical to sparseEncodingQueryBuilder_baseline except non-null query tokens supplier
NeuralSparseQueryBuilder sparseEncodingQueryBuilder_nonNullQueryTokens = new NeuralSparseQueryBuilder().fieldName(fieldName1)
.queryText(queryText1)
.modelId(modelId1)
.maxTokenScore(maxTokenScore1)
.boost(boost1)
.queryName(queryName1)
.queryTokensSupplier(() -> queryTokens1);

// Identical to sparseEncodingQueryBuilder_baseline except non-null query tokens supplier
NeuralSparseQueryBuilder sparseEncodingQueryBuilder_diffQueryTokens = new NeuralSparseQueryBuilder().fieldName(fieldName1)
.queryText(queryText1)
.modelId(modelId1)
.maxTokenScore(maxTokenScore1)
.boost(boost1)
.queryName(queryName1)
.queryTokensSupplier(() -> queryTokens2);

assertEquals(sparseEncodingQueryBuilder_baseline, sparseEncodingQueryBuilder_baseline);
assertEquals(sparseEncodingQueryBuilder_baseline.hashCode(), sparseEncodingQueryBuilder_baseline.hashCode());

Expand All @@ -405,6 +443,12 @@ public void testHashAndEquals() {

assertNotEquals(sparseEncodingQueryBuilder_baseline, sparseEncodingQueryBuilder_diffMaxTokenScore);
assertNotEquals(sparseEncodingQueryBuilder_baseline.hashCode(), sparseEncodingQueryBuilder_diffMaxTokenScore.hashCode());

assertNotEquals(sparseEncodingQueryBuilder_baseline, sparseEncodingQueryBuilder_nonNullQueryTokens);
assertNotEquals(sparseEncodingQueryBuilder_baseline.hashCode(), sparseEncodingQueryBuilder_nonNullQueryTokens.hashCode());

assertNotEquals(sparseEncodingQueryBuilder_nonNullQueryTokens, sparseEncodingQueryBuilder_diffQueryTokens);
assertNotEquals(sparseEncodingQueryBuilder_nonNullQueryTokens.hashCode(), sparseEncodingQueryBuilder_diffQueryTokens.hashCode());
}

@SneakyThrows
Expand Down
Loading