diff --git a/server/src/test/java/org/opensearch/transport/InboundDecoderTests.java b/server/src/test/java/org/opensearch/transport/InboundDecoderTests.java index f2b8b317e8b44..47ae8f9783ea3 100644 --- a/server/src/test/java/org/opensearch/transport/InboundDecoderTests.java +++ b/server/src/test/java/org/opensearch/transport/InboundDecoderTests.java @@ -178,7 +178,7 @@ public void testCompressedDecode() throws IOException { transportMessage = new TestResponse(randomAlphaOfLength(100)); } - final BytesReference totalBytes = serialize(false, Version.CURRENT, false, true, action, requestId, transportMessage); + final BytesReference totalBytes = serialize(isRequest, Version.CURRENT, false, true, action, requestId, transportMessage); final BytesStreamOutput out = new BytesStreamOutput(); transportMessage.writeTo(out); final BytesReference uncompressedBytes = out.bytes(); diff --git a/server/src/test/java/org/opensearch/transport/InboundHandlerTests.java b/server/src/test/java/org/opensearch/transport/InboundHandlerTests.java index 33b13129d23f4..2553e7740990b 100644 --- a/server/src/test/java/org/opensearch/transport/InboundHandlerTests.java +++ b/server/src/test/java/org/opensearch/transport/InboundHandlerTests.java @@ -42,12 +42,14 @@ import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; import org.opensearch.common.util.BigArrays; +import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.bytes.BytesArray; import org.opensearch.core.common.bytes.BytesReference; import org.opensearch.core.common.io.stream.InputStreamStreamInput; import org.opensearch.core.common.io.stream.NamedWriteableRegistry; import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.Writeable; import org.opensearch.tasks.TaskManager; import org.opensearch.telemetry.tracing.noop.NoopTracer; import org.opensearch.test.MockLogAppender; @@ -56,7 +58,6 @@ import org.opensearch.threadpool.TestThreadPool; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.nativeprotocol.NativeInboundMessage; -import org.opensearch.transport.nativeprotocol.NativeOutboundMessageTests; import org.junit.After; import org.junit.Before; @@ -75,7 +76,17 @@ import static org.hamcrest.CoreMatchers.startsWith; import static org.hamcrest.Matchers.instanceOf; -public class InboundHandlerTests extends OpenSearchTestCase { +public abstract class InboundHandlerTests extends OpenSearchTestCase { + + public abstract BytesReference serializeOutboundRequest( + ThreadContext threadContext, + Writeable message, + Version version, + String action, + long requestId, + boolean compress, + boolean handshake + ) throws IOException; private final TestThreadPool threadPool = new TestThreadPool(getClass().getName()); private final Version version = Version.CURRENT; @@ -194,12 +205,13 @@ public TestResponse read(StreamInput in) throws IOException { requestHandlers.registerHandler(registry); String requestValue = randomAlphaOfLength(10); - BytesReference fullRequestBytes = NativeOutboundMessageTests.serializeNativeOutboundRequest( + BytesReference fullRequestBytes = serializeOutboundRequest( threadPool.getThreadContext(), new TestRequest(requestValue), version, action, requestId, + false, false ); BytesReference requestContent = fullRequestBytes.slice(headerSize, fullRequestBytes.length() - headerSize); @@ -397,12 +409,13 @@ public void onResponseSent(long requestId, String action, Exception error) { }); // Create the request payload with 1 byte overflow - final BytesRef bytes = NativeOutboundMessageTests.serializeNativeOutboundRequest( + final BytesRef bytes = serializeOutboundRequest( threadPool.getThreadContext(), new TestRequest(requestValue), version, action, requestId, + false, false ).toBytesRef(); final ByteBuffer buffer = ByteBuffer.allocate(bytes.length + 1); @@ -469,12 +482,13 @@ public void onResponseSent(long requestId, String action, Exception error) { } }); - final BytesReference fullRequestBytes = NativeOutboundMessageTests.serializeNativeOutboundRequest( + final BytesReference fullRequestBytes = serializeOutboundRequest( threadPool.getThreadContext(), new TestRequest(requestValue), version, action, requestId, + false, false ); // Create the request payload by intentionally stripping 1 byte away @@ -537,12 +551,13 @@ public TestResponse read(StreamInput in) throws IOException { requestHandlers.registerHandler(registry); String requestValue = randomAlphaOfLength(10); - BytesReference fullRequestBytes = NativeOutboundMessageTests.serializeNativeOutboundRequest( + BytesReference fullRequestBytes = serializeOutboundRequest( threadPool.getThreadContext(), new TestRequest(requestValue), version, action, requestId, + false, false ); BytesReference requestContent = fullRequestBytes.slice(headerSize, fullRequestBytes.length() - headerSize); @@ -630,12 +645,13 @@ public TestResponse read(StreamInput in) throws IOException { requestHandlers.registerHandler(registry); String requestValue = randomAlphaOfLength(10); - BytesReference fullRequestBytes = NativeOutboundMessageTests.serializeNativeOutboundRequest( + BytesReference fullRequestBytes = serializeOutboundRequest( threadPool.getThreadContext(), new TestRequest(requestValue), version, action, requestId, + false, false ); BytesReference requestContent = fullRequestBytes.slice(headerSize, fullRequestBytes.length() - headerSize); diff --git a/server/src/test/java/org/opensearch/transport/nativeprotocol/NativeInboundHandlerTests.java b/server/src/test/java/org/opensearch/transport/nativeprotocol/NativeInboundHandlerTests.java new file mode 100644 index 0000000000000..ec0c1a50d5560 --- /dev/null +++ b/server/src/test/java/org/opensearch/transport/nativeprotocol/NativeInboundHandlerTests.java @@ -0,0 +1,45 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.transport.nativeprotocol; + +import org.opensearch.Version; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.transport.InboundHandlerTests; + +import java.io.IOException; + +public class NativeInboundHandlerTests extends InboundHandlerTests { + + @Override + public BytesReference serializeOutboundRequest( + ThreadContext threadContext, + Writeable message, + Version version, + String action, + long requestId, + boolean compress, + boolean handshake + ) throws IOException { + NativeOutboundMessage.Request request = new NativeOutboundMessage.Request( + threadContext, + new String[0], + message, + version, + action, + requestId, + handshake, + compress + ); + return request.serialize(new BytesStreamOutput()); + } + +}