Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change abstraction point for transport protocol #15432

Merged
merged 2 commits into from
Aug 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion server/src/main/java/org/opensearch/transport/Header.java
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ public class Header {

private static final String RESPONSE_NAME = "NO_ACTION_NAME_FOR_RESPONSES";

private final TransportProtocol protocol;
andrross marked this conversation as resolved.
Show resolved Hide resolved
private final int networkMessageSize;
private final Version version;
private final long requestId;
Expand All @@ -64,13 +65,18 @@ public class Header {
Tuple<Map<String, String>, Map<String, Set<String>>> headers;
Set<String> features;

Header(int networkMessageSize, long requestId, byte status, Version version) {
Header(TransportProtocol protocol, int networkMessageSize, long requestId, byte status, Version version) {
this.protocol = protocol;
this.networkMessageSize = networkMessageSize;
this.version = version;
this.requestId = requestId;
this.status = status;
}

TransportProtocol getTransportProtocol() {
return protocol;
}

public int getNetworkMessageSize() {
return networkMessageSize;
}
Expand Down Expand Up @@ -142,6 +148,8 @@ void finishParsingHeader(StreamInput input) throws IOException {
@Override
public String toString() {
return "Header{"
+ protocol
+ "}{"
+ networkMessageSize
+ "}{"
+ version
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@
import org.opensearch.core.common.bytes.BytesArray;
import org.opensearch.core.common.bytes.BytesReference;
import org.opensearch.core.common.bytes.CompositeBytesReference;
import org.opensearch.transport.nativeprotocol.NativeInboundMessage;

import java.io.IOException;
import java.util.ArrayList;
Expand Down Expand Up @@ -114,7 +113,7 @@ public void aggregate(ReleasableBytesReference content) {
}
}

public NativeInboundMessage finishAggregation() throws IOException {
public InboundMessage finishAggregation() throws IOException {
ensureOpen();
final ReleasableBytesReference releasableContent;
if (isFirstContent()) {
Expand All @@ -128,7 +127,7 @@ public NativeInboundMessage finishAggregation() throws IOException {
}

final BreakerControl breakerControl = new BreakerControl(circuitBreaker);
final NativeInboundMessage aggregated = new NativeInboundMessage(currentHeader, releasableContent, breakerControl);
final InboundMessage aggregated = new InboundMessage(currentHeader, releasableContent, breakerControl);
boolean success = false;
try {
if (aggregated.getHeader().needsToReadVariableHeader()) {
Expand All @@ -143,7 +142,7 @@ public NativeInboundMessage finishAggregation() throws IOException {
if (isShortCircuited()) {
aggregated.close();
success = true;
return new NativeInboundMessage(aggregated.getHeader(), aggregationException);
return new InboundMessage(aggregated.getHeader(), aggregationException);
} else {
success = true;
return aggregated;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,24 +9,139 @@
package org.opensearch.transport;

import org.opensearch.common.bytes.ReleasableBytesReference;
import org.opensearch.common.lease.Releasable;
import org.opensearch.common.lease.Releasables;
import org.opensearch.core.common.bytes.CompositeBytesReference;

import java.io.Closeable;
import java.io.IOException;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.function.BiConsumer;

/**
* Interface for handling inbound bytes. Can be implemented by different transport protocols.
* Handler for inbound bytes, using {@link InboundDecoder} to decode headers
* and {@link InboundAggregator} to assemble complete messages to forward to
* the given message handler to parse the message payload.
*/
public interface InboundBytesHandler extends Closeable {
class InboundBytesHandler {

public void doHandleBytes(
TcpChannel channel,
ReleasableBytesReference reference,
BiConsumer<TcpChannel, ProtocolInboundMessage> messageHandler
) throws IOException;
private static final ThreadLocal<ArrayList<Object>> fragmentList = ThreadLocal.withInitial(ArrayList::new);

public boolean canHandleBytes(ReleasableBytesReference reference);
private final ArrayDeque<ReleasableBytesReference> pending;
private final InboundDecoder decoder;
private final InboundAggregator aggregator;
private final StatsTracker statsTracker;
private boolean isClosed = false;

InboundBytesHandler(
ArrayDeque<ReleasableBytesReference> pending,
InboundDecoder decoder,
InboundAggregator aggregator,
StatsTracker statsTracker
) {
this.pending = pending;
this.decoder = decoder;
this.aggregator = aggregator;
this.statsTracker = statsTracker;
}

public void close() {
isClosed = true;
}

public void doHandleBytes(TcpChannel channel, ReleasableBytesReference reference, BiConsumer<TcpChannel, InboundMessage> messageHandler)
throws IOException {
final ArrayList<Object> fragments = fragmentList.get();
boolean continueHandling = true;

while (continueHandling && isClosed == false) {
boolean continueDecoding = true;
while (continueDecoding && pending.isEmpty() == false) {
try (ReleasableBytesReference toDecode = getPendingBytes()) {
final int bytesDecoded = decoder.decode(toDecode, fragments::add);
if (bytesDecoded != 0) {
releasePendingBytes(bytesDecoded);
if (fragments.isEmpty() == false && endOfMessage(fragments.get(fragments.size() - 1))) {
continueDecoding = false;
}
} else {
continueDecoding = false;
}
}
}

if (fragments.isEmpty()) {
continueHandling = false;
} else {
try {
forwardFragments(channel, fragments, messageHandler);
} finally {
for (Object fragment : fragments) {
if (fragment instanceof ReleasableBytesReference) {
((ReleasableBytesReference) fragment).close();
}
}
fragments.clear();
}
}
}
}

private ReleasableBytesReference getPendingBytes() {
if (pending.size() == 1) {
return pending.peekFirst().retain();
} else {
final ReleasableBytesReference[] bytesReferences = new ReleasableBytesReference[pending.size()];
int index = 0;
for (ReleasableBytesReference pendingReference : pending) {
bytesReferences[index] = pendingReference.retain();
++index;
}
final Releasable releasable = () -> Releasables.closeWhileHandlingException(bytesReferences);
return new ReleasableBytesReference(CompositeBytesReference.of(bytesReferences), releasable);
}
}

private void releasePendingBytes(int bytesConsumed) {
int bytesToRelease = bytesConsumed;
while (bytesToRelease != 0) {
try (ReleasableBytesReference reference = pending.pollFirst()) {
assert reference != null;
if (bytesToRelease < reference.length()) {
pending.addFirst(reference.retainedSlice(bytesToRelease, reference.length() - bytesToRelease));
bytesToRelease -= bytesToRelease;
} else {
bytesToRelease -= reference.length();
}
}
}
}

private boolean endOfMessage(Object fragment) {
return fragment == InboundDecoder.PING || fragment == InboundDecoder.END_CONTENT || fragment instanceof Exception;
}

private void forwardFragments(TcpChannel channel, ArrayList<Object> fragments, BiConsumer<TcpChannel, InboundMessage> messageHandler)
throws IOException {
for (Object fragment : fragments) {
if (fragment instanceof Header) {
assert aggregator.isAggregating() == false;
aggregator.headerReceived((Header) fragment);
} else if (fragment == InboundDecoder.PING) {
assert aggregator.isAggregating() == false;
messageHandler.accept(channel, InboundMessage.PING);
} else if (fragment == InboundDecoder.END_CONTENT) {
assert aggregator.isAggregating();
try (InboundMessage aggregated = aggregator.finishAggregation()) {
statsTracker.markMessageReceived();
messageHandler.accept(channel, aggregated);
}
} else {
assert aggregator.isAggregating();
assert fragment instanceof ReleasableBytesReference;
aggregator.aggregate((ReleasableBytesReference) fragment);
}
}
}

@Override
void close();
}
Original file line number Diff line number Diff line change
Expand Up @@ -187,11 +187,12 @@ private int headerBytesToRead(BytesReference reference) {
// exposed for use in tests
static Header readHeader(Version version, int networkMessageSize, BytesReference bytesReference) throws IOException {
try (StreamInput streamInput = bytesReference.streamInput()) {
streamInput.skip(TcpHeader.BYTES_REQUIRED_FOR_MESSAGE_SIZE);
TransportProtocol protocol = TransportProtocol.fromBytes(streamInput.readByte(), streamInput.readByte());
streamInput.skip(TcpHeader.MESSAGE_LENGTH_SIZE);
long requestId = streamInput.readLong();
byte status = streamInput.readByte();
Version remoteVersion = Version.fromId(streamInput.readInt());
Header header = new Header(networkMessageSize, requestId, status, remoteVersion);
Header header = new Header(protocol, networkMessageSize, requestId, status, remoteVersion);
final IllegalStateException invalidVersion = ensureVersionCompatibility(remoteVersion, version, header.isHandshake());
if (invalidVersion != null) {
throw invalidVersion;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@
import org.opensearch.core.common.io.stream.NamedWriteableRegistry;
import org.opensearch.telemetry.tracing.Tracer;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.nativeprotocol.NativeInboundMessage;

import java.io.IOException;
import java.util.Map;
Expand All @@ -56,7 +55,7 @@

private volatile long slowLogThresholdMs = Long.MAX_VALUE;

private final Map<String, ProtocolMessageHandler> protocolMessageHandlers;
private final Map<TransportProtocol, ProtocolMessageHandler> protocolMessageHandlers;

InboundHandler(
String nodeName,
Expand All @@ -75,7 +74,7 @@
) {
this.threadPool = threadPool;
this.protocolMessageHandlers = Map.of(
NativeInboundMessage.NATIVE_PROTOCOL,
TransportProtocol.NATIVE,
new NativeMessageHandler(
nodeName,
version,
Expand Down Expand Up @@ -107,16 +106,16 @@
this.slowLogThresholdMs = slowLogThreshold.getMillis();
}

void inboundMessage(TcpChannel channel, ProtocolInboundMessage message) throws Exception {
void inboundMessage(TcpChannel channel, InboundMessage message) throws Exception {
final long startTime = threadPool.relativeTimeInMillis();
channel.getChannelStats().markAccessed(startTime);
messageReceivedFromPipeline(channel, message, startTime);
}

private void messageReceivedFromPipeline(TcpChannel channel, ProtocolInboundMessage message, long startTime) throws IOException {
ProtocolMessageHandler protocolMessageHandler = protocolMessageHandlers.get(message.getProtocol());
private void messageReceivedFromPipeline(TcpChannel channel, InboundMessage message, long startTime) throws IOException {
ProtocolMessageHandler protocolMessageHandler = protocolMessageHandlers.get(message.getTransportProtocol());
if (protocolMessageHandler == null) {
throw new IllegalStateException("No protocol message handler found for protocol: " + message.getProtocol());
throw new IllegalStateException("No protocol message handler found for protocol: " + message.getTransportProtocol());

Check warning on line 118 in server/src/main/java/org/opensearch/transport/InboundHandler.java

View check run for this annotation

Codecov / codecov/patch

server/src/main/java/org/opensearch/transport/InboundHandler.java#L118

Added line #L118 was not covered by tests
}
protocolMessageHandler.messageReceived(channel, message, startTime, slowLogThresholdMs, messageListener);
}
Expand Down
Loading
Loading