Skip to content

Commit

Permalink
Abstracting InboundHandlerTests
Browse files Browse the repository at this point in the history
Signed-off-by: Vacha Shah <[email protected]>
  • Loading branch information
VachaShah committed May 9, 2024
1 parent 0a7c0a3 commit 5048253
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ public void testCompressedDecode() throws IOException {
transportMessage = new TestResponse(randomAlphaOfLength(100));
}

final BytesReference totalBytes = serialize(false, Version.CURRENT, false, true, action, requestId, transportMessage);
final BytesReference totalBytes = serialize(isRequest, Version.CURRENT, false, true, action, requestId, transportMessage);
final BytesStreamOutput out = new BytesStreamOutput();
transportMessage.writeTo(out);
final BytesReference uncompressedBytes = out.bytes();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,14 @@
import org.opensearch.common.settings.Settings;
import org.opensearch.common.unit.TimeValue;
import org.opensearch.common.util.BigArrays;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.bytes.BytesArray;
import org.opensearch.core.common.bytes.BytesReference;
import org.opensearch.core.common.io.stream.InputStreamStreamInput;
import org.opensearch.core.common.io.stream.NamedWriteableRegistry;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.Writeable;
import org.opensearch.tasks.TaskManager;
import org.opensearch.telemetry.tracing.noop.NoopTracer;
import org.opensearch.test.MockLogAppender;
Expand All @@ -56,7 +58,6 @@
import org.opensearch.threadpool.TestThreadPool;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.nativeprotocol.NativeInboundMessage;
import org.opensearch.transport.nativeprotocol.NativeOutboundMessageTests;
import org.junit.After;
import org.junit.Before;

Expand All @@ -75,7 +76,17 @@
import static org.hamcrest.CoreMatchers.startsWith;
import static org.hamcrest.Matchers.instanceOf;

public class InboundHandlerTests extends OpenSearchTestCase {
public abstract class InboundHandlerTests extends OpenSearchTestCase {

public abstract BytesReference serializeOutboundRequest(
ThreadContext threadContext,
Writeable message,
Version version,
String action,
long requestId,
boolean compress,
boolean handshake
) throws IOException;

private final TestThreadPool threadPool = new TestThreadPool(getClass().getName());
private final Version version = Version.CURRENT;
Expand Down Expand Up @@ -194,12 +205,13 @@ public TestResponse read(StreamInput in) throws IOException {
requestHandlers.registerHandler(registry);
String requestValue = randomAlphaOfLength(10);

BytesReference fullRequestBytes = NativeOutboundMessageTests.serializeNativeOutboundRequest(
BytesReference fullRequestBytes = serializeOutboundRequest(
threadPool.getThreadContext(),
new TestRequest(requestValue),
version,
action,
requestId,
false,
false
);
BytesReference requestContent = fullRequestBytes.slice(headerSize, fullRequestBytes.length() - headerSize);
Expand Down Expand Up @@ -397,12 +409,13 @@ public void onResponseSent(long requestId, String action, Exception error) {
});

// Create the request payload with 1 byte overflow
final BytesRef bytes = NativeOutboundMessageTests.serializeNativeOutboundRequest(
final BytesRef bytes = serializeOutboundRequest(
threadPool.getThreadContext(),
new TestRequest(requestValue),
version,
action,
requestId,
false,
false
).toBytesRef();
final ByteBuffer buffer = ByteBuffer.allocate(bytes.length + 1);
Expand Down Expand Up @@ -469,12 +482,13 @@ public void onResponseSent(long requestId, String action, Exception error) {
}
});

final BytesReference fullRequestBytes = NativeOutboundMessageTests.serializeNativeOutboundRequest(
final BytesReference fullRequestBytes = serializeOutboundRequest(
threadPool.getThreadContext(),
new TestRequest(requestValue),
version,
action,
requestId,
false,
false
);
// Create the request payload by intentionally stripping 1 byte away
Expand Down Expand Up @@ -537,12 +551,13 @@ public TestResponse read(StreamInput in) throws IOException {
requestHandlers.registerHandler(registry);
String requestValue = randomAlphaOfLength(10);

BytesReference fullRequestBytes = NativeOutboundMessageTests.serializeNativeOutboundRequest(
BytesReference fullRequestBytes = serializeOutboundRequest(
threadPool.getThreadContext(),
new TestRequest(requestValue),
version,
action,
requestId,
false,
false
);
BytesReference requestContent = fullRequestBytes.slice(headerSize, fullRequestBytes.length() - headerSize);
Expand Down Expand Up @@ -630,12 +645,13 @@ public TestResponse read(StreamInput in) throws IOException {
requestHandlers.registerHandler(registry);
String requestValue = randomAlphaOfLength(10);

BytesReference fullRequestBytes = NativeOutboundMessageTests.serializeNativeOutboundRequest(
BytesReference fullRequestBytes = serializeOutboundRequest(
threadPool.getThreadContext(),
new TestRequest(requestValue),
version,
action,
requestId,
false,
false
);
BytesReference requestContent = fullRequestBytes.slice(headerSize, fullRequestBytes.length() - headerSize);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

package org.opensearch.transport.nativeprotocol;

import org.opensearch.Version;
import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.core.common.bytes.BytesReference;
import org.opensearch.core.common.io.stream.Writeable;
import org.opensearch.transport.InboundHandlerTests;

import java.io.IOException;

public class NativeInboundHandlerTests extends InboundHandlerTests {

@Override
public BytesReference serializeOutboundRequest(
ThreadContext threadContext,
Writeable message,
Version version,
String action,
long requestId,
boolean compress,
boolean handshake
) throws IOException {
NativeOutboundMessage.Request request = new NativeOutboundMessage.Request(
threadContext,
new String[0],
message,
version,
action,
requestId,
handshake,
compress
);
return request.serialize(new BytesStreamOutput());
}

}

0 comments on commit 5048253

Please sign in to comment.