Skip to content

Commit

Permalink
Allow deframer errors to close stream with a status code
Browse files Browse the repository at this point in the history
Today, deframer errors cancel the stream without communicating a status code
to the peer. This change causes deframer errors to trigger a best-effort
attempt to send trailers with a status code so that the peer understands
why the stream is being closed.
  • Loading branch information
ryanpbrewster committed Apr 16, 2024
1 parent b6ca908 commit 93018a8
Show file tree
Hide file tree
Showing 5 changed files with 98 additions and 15 deletions.
19 changes: 13 additions & 6 deletions netty/src/main/java/io/grpc/netty/NettyServerHandler.java
Original file line number Diff line number Diff line change
Expand Up @@ -718,12 +718,16 @@ void returnProcessedBytes(Http2Stream http2Stream, int bytes) {
}
}

private void closeStreamWhenDone(ChannelPromise promise, Http2Stream stream) {
private void closeStreamWhenDone(ChannelPromise promise, Http2Stream stream, Status status) {
promise.addListener(
new ChannelFutureListener() {
@Override
public void operationComplete(ChannelFuture future) {
serverStream(stream).complete();
if (status.isOk()) {
serverStream(stream).complete();
} else {
serverStream(stream).transportReportStatus(status);
}
}
});
}
Expand Down Expand Up @@ -753,7 +757,7 @@ private void sendGrpcFrame(
return;
}
if (cmd.endStream()) {
closeStreamWhenDone(promise, stream);
closeStreamWhenDone(promise, stream, Status.OK);
}
// Call the base class to write the HTTP/2 DATA frame.
encoder().writeData(ctx, streamId, cmd.content(), 0, cmd.endStream(), promise);
Expand All @@ -763,8 +767,8 @@ private void sendGrpcFrame(
/**
* Sends the response headers to the client.
*/
private void sendResponseHeaders(ChannelHandlerContext ctx, SendResponseHeadersCommand cmd,
ChannelPromise promise) throws Http2Exception {
private void sendResponseHeaders(
ChannelHandlerContext ctx, SendResponseHeadersCommand cmd, ChannelPromise promise) {
try (TaskCloseable ignore = PerfMark.traceTask("NettyServerHandler.sendResponseHeaders")) {
PerfMark.attachTag(cmd.stream().tag());
PerfMark.linkIn(cmd.getLink());
Expand All @@ -775,7 +779,10 @@ private void sendResponseHeaders(ChannelHandlerContext ctx, SendResponseHeadersC
return;
}
if (cmd.endOfStream()) {
closeStreamWhenDone(promise, stream);
// The stream listener only cares about the status if the close was triggered
// by the transport. Application-initiated closes should always see Status.OK.
Status status = cmd.source() == StatusSource.Application ? Status.OK : cmd.status();
closeStreamWhenDone(promise, stream, status);
}
encoder().writeHeaders(ctx, streamId, cmd.headers(), 0, cmd.endOfStream(), promise);
}
Expand Down
12 changes: 11 additions & 1 deletion netty/src/main/java/io/grpc/netty/NettyServerStream.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import com.google.common.base.Preconditions;
import io.grpc.Attributes;
import io.grpc.InternalStatus;
import io.grpc.Metadata;
import io.grpc.Status;
import io.grpc.internal.AbstractServerStream;
Expand Down Expand Up @@ -203,7 +204,16 @@ public void deframeFailed(Throwable cause) {
log.log(Level.WARNING, "Exception processing message", cause);
Status status = Status.fromThrowable(cause);
transportReportStatus(status);
handler.getWriteQueue().enqueue(new CancelServerStreamCommand(this, status), true);

Metadata trailers = new Metadata();
trailers.put(InternalStatus.CODE_KEY, status);
if (status.getDescription() != null) {
trailers.put(InternalStatus.MESSAGE_KEY, status.getDescription());
}
Http2Headers http2Trailers = Utils.convertTrailers(trailers, /* headersSent = */ false);
SendResponseHeadersCommand cmd =
SendResponseHeadersCommand.transportError(this, http2Trailers, status);
handler.getWriteQueue().enqueue(cmd, /* flush = */ true);
}

void inboundDataReceived(ByteBuf frame, boolean endOfStream) {
Expand Down
26 changes: 20 additions & 6 deletions netty/src/main/java/io/grpc/netty/SendResponseHeadersCommand.java
Original file line number Diff line number Diff line change
Expand Up @@ -28,21 +28,30 @@ final class SendResponseHeadersCommand extends WriteQueue.AbstractQueuedCommand
private final StreamIdHolder stream;
private final Http2Headers headers;
private final Status status;
private final StatusSource source;

private SendResponseHeadersCommand(StreamIdHolder stream, Http2Headers headers, Status status) {
private SendResponseHeadersCommand(
StreamIdHolder stream, Http2Headers headers, Status status, StatusSource source) {
this.stream = Preconditions.checkNotNull(stream, "stream");
this.headers = Preconditions.checkNotNull(headers, "headers");
this.status = status;
this.source = source;
}

static SendResponseHeadersCommand createHeaders(StreamIdHolder stream, Http2Headers headers) {
return new SendResponseHeadersCommand(stream, headers, null);
return new SendResponseHeadersCommand(stream, headers, null, StatusSource.Application);
}

static SendResponseHeadersCommand createTrailers(
StreamIdHolder stream, Http2Headers headers, Status status) {
return new SendResponseHeadersCommand(
stream, headers, Preconditions.checkNotNull(status, "status"));
stream, headers, Preconditions.checkNotNull(status, "status"), StatusSource.Application);
}

static SendResponseHeadersCommand transportError(
StreamIdHolder stream, Http2Headers headers, Status status) {
Preconditions.checkArgument(!status.isOk(), "transport error must not be OK");
return new SendResponseHeadersCommand(stream, headers, status, StatusSource.Transport);
}

StreamIdHolder stream() {
Expand All @@ -61,6 +70,10 @@ Status status() {
return status;
}

StatusSource source() {
return source;
}

@Override
public boolean equals(Object that) {
if (that == null || !that.getClass().equals(SendResponseHeadersCommand.class)) {
Expand All @@ -69,17 +82,18 @@ public boolean equals(Object that) {
SendResponseHeadersCommand thatCmd = (SendResponseHeadersCommand) that;
return thatCmd.stream.equals(stream)
&& thatCmd.headers.equals(headers)
&& thatCmd.status.equals(status);
&& thatCmd.status.equals(status)
&& thatCmd.source.equals(source);
}

@Override
public String toString() {
return getClass().getSimpleName() + "(stream=" + stream.id() + ", headers=" + headers
+ ", status=" + status + ")";
+ ", status=" + status + ", source=" + source + ")";
}

@Override
public int hashCode() {
return Objects.hashCode(stream, status);
return Objects.hashCode(stream, status, source);
}
}
26 changes: 26 additions & 0 deletions netty/src/main/java/io/grpc/netty/StatusSource.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
/*
* Copyright 2024 The gRPC Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package io.grpc.netty;

/**
* A stream can be closed by the application or by the transport itself.
* We emit listener events differently depending on who initiated the action.
*/
enum StatusSource {
Application,
Transport,
}
30 changes: 28 additions & 2 deletions netty/src/test/java/io/grpc/netty/NettyServerStreamTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
package io.grpc.netty;

import static com.google.common.truth.Truth.assertThat;
import static io.grpc.internal.GrpcUtil.DEFAULT_MAX_MESSAGE_SIZE;
import static io.grpc.netty.NettyTestUtil.messageFrame;
import static org.junit.Assert.assertNull;
import static org.mockito.ArgumentMatchers.any;
Expand All @@ -32,9 +31,11 @@
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;

import com.google.common.base.Strings;
import com.google.common.collect.ImmutableListMultimap;
import com.google.common.collect.ListMultimap;
import io.grpc.Attributes;
import io.grpc.InternalStatus;
import io.grpc.Metadata;
import io.grpc.Status;
import io.grpc.internal.ServerStreamListener;
Expand Down Expand Up @@ -280,14 +281,39 @@ public void cancelStreamShouldSucceed() {
true);
}

@Test
public void oversizedMessagesResultInResourceExhaustedTrailers() throws Exception {
@SuppressWarnings("InlineMeInliner") // Requires Java 11
String oversizedMsg = Strings.repeat("a", TEST_MAX_MESSAGE_SIZE + 1);
stream.request(1);
stream.transportState().inboundDataReceived(messageFrame(oversizedMsg), false);
assertNull("message should have caused a deframer error", listenerMessageQueue().poll());

ArgumentCaptor<SendResponseHeadersCommand> sendHeadersCap =
ArgumentCaptor.forClass(SendResponseHeadersCommand.class);
verify(writeQueue).enqueue(sendHeadersCap.capture(), eq(true));

Status status = Status.RESOURCE_EXHAUSTED
.withDescription("gRPC message exceeds maximum size 128: 129");

SendResponseHeadersCommand actualCmd = sendHeadersCap.getValue();
assertThat(actualCmd.status().getCode()).isEqualTo(status.getCode());
assertThat(actualCmd.status().getDescription()).isEqualTo(status.getDescription());
Metadata trailers = Utils.convertTrailers(actualCmd.headers());
assertThat(trailers.get(InternalStatus.CODE_KEY).getCode()).isEqualTo(status.getCode());
assertThat(trailers.get(InternalStatus.MESSAGE_KEY)).isEqualTo(status.getDescription());
}

private static final int TEST_MAX_MESSAGE_SIZE = 128;

@Override
@SuppressWarnings("DirectInvocationOnMock")
protected NettyServerStream createStream() {
when(handler.getWriteQueue()).thenReturn(writeQueue);
StatsTraceContext statsTraceCtx = StatsTraceContext.NOOP;
TransportTracer transportTracer = new TransportTracer();
NettyServerStream.TransportState state = new NettyServerStream.TransportState(
handler, channel.eventLoop(), http2Stream, DEFAULT_MAX_MESSAGE_SIZE, statsTraceCtx,
handler, channel.eventLoop(), http2Stream, TEST_MAX_MESSAGE_SIZE, statsTraceCtx,
transportTracer, "method");
NettyServerStream stream = new NettyServerStream(channel, state, Attributes.EMPTY,
"test-authority", statsTraceCtx, transportTracer);
Expand Down

0 comments on commit 93018a8

Please sign in to comment.