Skip to content

Commit

Permalink
Abstracting outbound side of transport
Browse files Browse the repository at this point in the history
Signed-off-by: Vacha Shah <[email protected]>
  • Loading branch information
VachaShah committed Apr 18, 2024
1 parent f5c3ef9 commit 0a9c4d8
Show file tree
Hide file tree
Showing 13 changed files with 398 additions and 171 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
*/
public class BytesTransportRequest extends TransportRequest {

BytesReference bytes;
public BytesReference bytes;
Version version;

public BytesTransportRequest(StreamInput in) throws IOException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,13 +58,13 @@
*
* @opensearch.internal
*/
final class CompressibleBytesOutputStream extends StreamOutput {
public final class CompressibleBytesOutputStream extends StreamOutput {

private final OutputStream stream;
private final BytesStream bytesStreamOutput;
private final boolean shouldCompress;

CompressibleBytesOutputStream(BytesStream bytesStreamOutput, boolean shouldCompress) throws IOException {
public CompressibleBytesOutputStream(BytesStream bytesStreamOutput, boolean shouldCompress) throws IOException {
this.bytesStreamOutput = bytesStreamOutput;
this.shouldCompress = shouldCompress;
if (shouldCompress) {
Expand All @@ -80,7 +80,7 @@ final class CompressibleBytesOutputStream extends StreamOutput {
* @return bytes underlying the stream
* @throws IOException if an exception occurs when writing or flushing
*/
BytesReference materializeBytes() throws IOException {
public BytesReference materializeBytes() throws IOException {
// If we are using compression the stream needs to be closed to ensure that EOS marker bytes are written.
// The actual ReleasableBytesStreamOutput will not be closed yet as it is wrapped in flushOnCloseStream when
// passed to the deflater stream.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ public abstract class NetworkMessage {
protected final long requestId;
protected final byte status;

NetworkMessage(ThreadContext threadContext, Version version, byte status, long requestId) {
public NetworkMessage(ThreadContext threadContext, Version version, byte status, long requestId) {
this.threadContext = threadContext.captureAsWriteable();
this.version = version;
this.requestId = requestId;
Expand Down
161 changes: 30 additions & 131 deletions server/src/main/java/org/opensearch/transport/OutboundHandler.java
Original file line number Diff line number Diff line change
Expand Up @@ -32,28 +32,22 @@

package org.opensearch.transport;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.message.ParameterizedMessage;
import org.opensearch.Version;
import org.opensearch.cluster.node.DiscoveryNode;
import org.opensearch.common.CheckedSupplier;
import org.opensearch.common.io.stream.ReleasableBytesStreamOutput;
import org.opensearch.common.lease.Releasable;
import org.opensearch.common.lease.Releasables;
import org.opensearch.common.network.CloseableChannel;
import org.opensearch.common.transport.NetworkExceptionHelper;
import org.opensearch.common.util.BigArrays;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.common.util.io.IOUtils;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.action.NotifyOnceListener;
import org.opensearch.core.common.bytes.BytesReference;
import org.opensearch.core.common.transport.TransportAddress;
import org.opensearch.core.transport.TransportResponse;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.OutboundMessageHandler.SendContext;
import org.opensearch.transport.nativeprotocol.NativeInboundMessage;
import org.opensearch.transport.nativeprotocol.NativeOutboundMessageHandler;

import java.io.IOException;
import java.util.Map;
import java.util.Set;

/**
Expand All @@ -63,15 +57,12 @@
*/
final class OutboundHandler {

private static final Logger logger = LogManager.getLogger(OutboundHandler.class);

private final String nodeName;
private final Version version;
private final String[] features;
private final StatsTracker statsTracker;
private final ThreadPool threadPool;
private final BigArrays bigArrays;
private volatile TransportMessageListener messageListener = TransportMessageListener.NOOP_LISTENER;
private final Map<String, OutboundMessageHandler> protocolMessageHandlers;

OutboundHandler(
String nodeName,
Expand All @@ -83,14 +74,16 @@ final class OutboundHandler {
) {
this.nodeName = nodeName;
this.version = version;
this.features = features;
this.statsTracker = statsTracker;
this.threadPool = threadPool;
this.bigArrays = bigArrays;
this.protocolMessageHandlers = Map.of(
NativeInboundMessage.NATIVE_PROTOCOL,
new NativeOutboundMessageHandler(features, statsTracker, threadPool, bigArrays)
);
}

void sendBytes(TcpChannel channel, BytesReference bytes, ActionListener<Void> listener) {
SendContext sendContext = new SendContext(channel, () -> bytes, listener);
SendContext sendContext = new SendContext(channel, () -> bytes, listener, statsTracker);
try {
internalSend(channel, sendContext);
} catch (IOException e) {
Expand All @@ -115,18 +108,18 @@ void sendRequest(
final boolean isHandshake
) throws IOException, TransportException {
Version version = Version.min(this.version, channelVersion);
OutboundMessage.Request message = new OutboundMessage.Request(
threadPool.getThreadContext(),
features,
// TODO: Add logic for protocols in transport message
OutboundMessageHandler outboundMessageHandler = protocolMessageHandlers.get("native");
ProtocolOutboundMessage message = outboundMessageHandler.convertRequestToOutboundMessage(
request,
version,
action,
requestId,
isHandshake,
compressRequest
compressRequest,
version
);
ActionListener<Void> listener = ActionListener.wrap(() -> messageListener.onRequestSent(node, requestId, action, request, options));
sendMessage(channel, message, listener);
outboundMessageHandler.sendMessage(channel, message, listener);
}

/**
Expand All @@ -146,17 +139,18 @@ void sendResponse(
final boolean isHandshake
) throws IOException {
Version version = Version.min(this.version, nodeVersion);
OutboundMessage.Response message = new OutboundMessage.Response(
threadPool.getThreadContext(),
features,
// TODO: Add logic for protocols in transport message
OutboundMessageHandler outboundMessageHandler = protocolMessageHandlers.get("native");
ProtocolOutboundMessage message = outboundMessageHandler.convertResponseToOutboundMessage(
response,
version,
features,
requestId,
isHandshake,
compress
compress,
version
);
ActionListener<Void> listener = ActionListener.wrap(() -> messageListener.onResponseSent(requestId, action, response));
sendMessage(channel, message, listener);
outboundMessageHandler.sendMessage(channel, message, listener);
}

/**
Expand All @@ -173,23 +167,18 @@ void sendErrorResponse(
Version version = Version.min(this.version, nodeVersion);
TransportAddress address = new TransportAddress(channel.getLocalAddress());
RemoteTransportException tx = new RemoteTransportException(nodeName, address, action, error);
OutboundMessage.Response message = new OutboundMessage.Response(
threadPool.getThreadContext(),
features,
// TODO: Add logic for protocols in transport message
OutboundMessageHandler outboundMessageHandler = protocolMessageHandlers.get("native");
ProtocolOutboundMessage message = outboundMessageHandler.convertErrorResponseToOutboundMessage(
tx,
version,
features,
requestId,
false,
false
false,
version
);
ActionListener<Void> listener = ActionListener.wrap(() -> messageListener.onResponseSent(requestId, action, error));
sendMessage(channel, message, listener);
}

private void sendMessage(TcpChannel channel, OutboundMessage networkMessage, ActionListener<Void> listener) throws IOException {
MessageSerializer serializer = new MessageSerializer(networkMessage, bigArrays);
SendContext sendContext = new SendContext(channel, serializer, listener, serializer);
internalSend(channel, sendContext);
outboundMessageHandler.sendMessage(channel, message, listener);
}

private void internalSend(TcpChannel channel, SendContext sendContext) throws IOException {
Expand All @@ -213,94 +202,4 @@ void setMessageListener(TransportMessageListener listener) {
}
}

/**
* Internal message serializer
*
* @opensearch.internal
*/
private static class MessageSerializer implements CheckedSupplier<BytesReference, IOException>, Releasable {

private final OutboundMessage message;
private final BigArrays bigArrays;
private volatile ReleasableBytesStreamOutput bytesStreamOutput;

private MessageSerializer(OutboundMessage message, BigArrays bigArrays) {
this.message = message;
this.bigArrays = bigArrays;
}

@Override
public BytesReference get() throws IOException {
bytesStreamOutput = new ReleasableBytesStreamOutput(bigArrays);
return message.serialize(bytesStreamOutput);
}

@Override
public void close() {
IOUtils.closeWhileHandlingException(bytesStreamOutput);
}
}

private class SendContext extends NotifyOnceListener<Void> implements CheckedSupplier<BytesReference, IOException> {

private final TcpChannel channel;
private final CheckedSupplier<BytesReference, IOException> messageSupplier;
private final ActionListener<Void> listener;
private final Releasable optionalReleasable;
private long messageSize = -1;

private SendContext(
TcpChannel channel,
CheckedSupplier<BytesReference, IOException> messageSupplier,
ActionListener<Void> listener
) {
this(channel, messageSupplier, listener, null);
}

private SendContext(
TcpChannel channel,
CheckedSupplier<BytesReference, IOException> messageSupplier,
ActionListener<Void> listener,
Releasable optionalReleasable
) {
this.channel = channel;
this.messageSupplier = messageSupplier;
this.listener = listener;
this.optionalReleasable = optionalReleasable;
}

public BytesReference get() throws IOException {
BytesReference message;
try {
message = messageSupplier.get();
messageSize = message.length();
TransportLogger.logOutboundMessage(channel, message);
return message;
} catch (Exception e) {
onFailure(e);
throw e;
}
}

@Override
protected void innerOnResponse(Void v) {
assert messageSize != -1 : "If onResponse is being called, the message should have been serialized";
statsTracker.markBytesWritten(messageSize);
closeAndCallback(() -> listener.onResponse(v));
}

@Override
protected void innerOnFailure(Exception e) {
if (NetworkExceptionHelper.isCloseConnectionException(e)) {
logger.debug(() -> new ParameterizedMessage("send message failed [channel: {}]", channel), e);
} else {
logger.warn(() -> new ParameterizedMessage("send message failed [channel: {}]", channel), e);
}
closeAndCallback(() -> listener.onFailure(e));
}

private void closeAndCallback(Runnable runnable) {
Releasables.close(optionalReleasable, runnable::run);
}
}
}
Loading

0 comments on commit 0a9c4d8

Please sign in to comment.