From e036b1b198bfa2eb5fbdd27fc02a5df95ecd939b Mon Sep 17 00:00:00 2001 From: "Ryan P. Brewster" Date: Wed, 24 Apr 2024 13:05:51 -0400 Subject: [PATCH] netty: Allow deframer errors to close stream with a status code Today, deframer errors cancel the stream without communicating a status code to the peer. This change causes deframer errors to trigger a best-effort attempt to send trailers with a status code so that the peer understands why the stream is being closed. Fixes #3996 --- .../grpc/netty/CancelServerStreamCommand.java | 26 +++++++++++++- .../io/grpc/netty/NettyServerHandler.java | 32 +++++++++++++++-- .../java/io/grpc/netty/NettyServerStream.java | 6 ++-- .../io/grpc/netty/NettyServerHandlerTest.java | 34 ++++++++++++++++++- .../io/grpc/netty/NettyServerStreamTest.java | 29 ++++++++++++++-- 5 files changed, 117 insertions(+), 10 deletions(-) diff --git a/netty/src/main/java/io/grpc/netty/CancelServerStreamCommand.java b/netty/src/main/java/io/grpc/netty/CancelServerStreamCommand.java index d9f5d96e06e..d49e3bd672b 100644 --- a/netty/src/main/java/io/grpc/netty/CancelServerStreamCommand.java +++ b/netty/src/main/java/io/grpc/netty/CancelServerStreamCommand.java @@ -27,10 +27,23 @@ final class CancelServerStreamCommand extends WriteQueue.AbstractQueuedCommand { private final NettyServerStream.TransportState stream; private final Status reason; + private final PeerNotify peerNotify; - CancelServerStreamCommand(NettyServerStream.TransportState stream, Status reason) { + private CancelServerStreamCommand( + NettyServerStream.TransportState stream, Status reason, PeerNotify peerNotify) { this.stream = Preconditions.checkNotNull(stream, "stream"); this.reason = Preconditions.checkNotNull(reason, "reason"); + this.peerNotify = Preconditions.checkNotNull(peerNotify, "peerNotify"); + } + + static CancelServerStreamCommand withReset( + NettyServerStream.TransportState stream, Status reason) { + return new CancelServerStreamCommand(stream, reason, PeerNotify.RESET); + } + + static CancelServerStreamCommand withReason( + NettyServerStream.TransportState stream, Status reason) { + return new CancelServerStreamCommand(stream, reason, PeerNotify.BEST_EFFORT_STATUS); } NettyServerStream.TransportState stream() { @@ -41,6 +54,10 @@ Status reason() { return reason; } + boolean wantsHeaders() { + return peerNotify == PeerNotify.BEST_EFFORT_STATUS; + } + @Override public boolean equals(Object o) { if (this == o) { @@ -68,4 +85,11 @@ public String toString() { .add("reason", reason) .toString(); } + + private enum PeerNotify { + /** Notify the peer by sending a RST_STREAM with no other information. */ + RESET, + /** Notify the peer about the {@link #reason} by sending structured headers, if possible. */ + BEST_EFFORT_STATUS, + } } diff --git a/netty/src/main/java/io/grpc/netty/NettyServerHandler.java b/netty/src/main/java/io/grpc/netty/NettyServerHandler.java index 77b448446b1..a6e855a199d 100644 --- a/netty/src/main/java/io/grpc/netty/NettyServerHandler.java +++ b/netty/src/main/java/io/grpc/netty/NettyServerHandler.java @@ -788,9 +788,37 @@ private void cancelStream(ChannelHandlerContext ctx, CancelServerStreamCommand c PerfMark.linkIn(cmd.getLink()); // Notify the listener if we haven't already. cmd.stream().transportReportStatus(cmd.reason()); - // Terminate the stream. - encoder().writeRstStream(ctx, cmd.stream().id(), Http2Error.CANCEL.code(), promise); + + // Now we need to decide how we're going to notify the peer that this stream is closed. + // If possible, it's nice to inform the peer _why_ this stream was cancelled by sending + // a structured headers frame. + if (shouldCloseStreamWithHeaders(cmd, connection())) { + Metadata md = new Metadata(); + md.put(InternalStatus.CODE_KEY, cmd.reason()); + if (cmd.reason().getDescription() != null) { + md.put(InternalStatus.MESSAGE_KEY, cmd.reason().getDescription()); + } + Http2Headers headers = Utils.convertServerHeaders(md); + encoder().writeHeaders( + ctx, cmd.stream().id(), headers, /* padding = */ 0, /* endStream = */ true, promise); + } else { + // Terminate the stream. + encoder().writeRstStream(ctx, cmd.stream().id(), Http2Error.CANCEL.code(), promise); + } + } + } + + // Determine whether a CancelServerStreamCommand should try to close the stream with a + // HEADERS or a RST_STREAM frame. The caller has some influence over this (they can + // configure cmd.wantsHeaders()). The state of the stream also has an influence: we + // only try to send HEADERS if the stream exists and hasn't already sent any headers. + private static boolean shouldCloseStreamWithHeaders( + CancelServerStreamCommand cmd, Http2Connection conn) { + if (!cmd.wantsHeaders()) { + return false; } + Http2Stream stream = conn.stream(cmd.stream().id()); + return stream != null && !stream.isHeadersSent(); } private void gracefulClose(final ChannelHandlerContext ctx, final GracefulServerCloseCommand msg, diff --git a/netty/src/main/java/io/grpc/netty/NettyServerStream.java b/netty/src/main/java/io/grpc/netty/NettyServerStream.java index a4304d5193e..836f39ddf19 100644 --- a/netty/src/main/java/io/grpc/netty/NettyServerStream.java +++ b/netty/src/main/java/io/grpc/netty/NettyServerStream.java @@ -130,7 +130,7 @@ public void writeTrailers(Metadata trailers, boolean headersSent, Status status) @Override public void cancel(Status status) { try (TaskCloseable ignore = PerfMark.traceTask("NettyServerStream$Sink.cancel")) { - writeQueue.enqueue(new CancelServerStreamCommand(transportState(), status), true); + writeQueue.enqueue(CancelServerStreamCommand.withReset(transportState(), status), true); } } } @@ -189,7 +189,7 @@ public void deframeFailed(Throwable cause) { log.log(Level.WARNING, "Exception processing message", cause); Status status = Status.fromThrowable(cause); transportReportStatus(status); - handler.getWriteQueue().enqueue(new CancelServerStreamCommand(this, status), true); + handler.getWriteQueue().enqueue(CancelServerStreamCommand.withReason(this, status), true); } private void onWriteFrameData(ChannelFuture future, int numMessages, int numBytes) { @@ -222,7 +222,7 @@ private void handleWriteFutureFailures(ChannelFuture future) { */ protected void http2ProcessingFailed(Status status) { transportReportStatus(status); - handler.getWriteQueue().enqueue(new CancelServerStreamCommand(this, status), true); + handler.getWriteQueue().enqueue(CancelServerStreamCommand.withReset(this, status), true); } void inboundDataReceived(ByteBuf frame, boolean endOfStream) { diff --git a/netty/src/test/java/io/grpc/netty/NettyServerHandlerTest.java b/netty/src/test/java/io/grpc/netty/NettyServerHandlerTest.java index 281ff3b17d6..ce902a9620b 100644 --- a/netty/src/test/java/io/grpc/netty/NettyServerHandlerTest.java +++ b/netty/src/test/java/io/grpc/netty/NettyServerHandlerTest.java @@ -89,8 +89,10 @@ import java.io.InputStream; import java.nio.channels.ClosedChannelException; import java.util.Arrays; +import java.util.HashMap; import java.util.LinkedList; import java.util.List; +import java.util.Map; import java.util.Queue; import java.util.concurrent.TimeUnit; import org.junit.Before; @@ -469,11 +471,41 @@ public void connectionWindowShouldBeOverridden() throws Exception { public void cancelShouldSendRstStream() throws Exception { manualSetUp(); createStream(); - enqueue(new CancelServerStreamCommand(stream.transportState(), Status.DEADLINE_EXCEEDED)); + enqueue(CancelServerStreamCommand.withReset(stream.transportState(), Status.DEADLINE_EXCEEDED)); verifyWrite().writeRstStream(eq(ctx()), eq(stream.transportState().id()), eq(Http2Error.CANCEL.code()), any(ChannelPromise.class)); } + @Test + public void cancelWithNotify_shouldSendHeaders() throws Exception { + manualSetUp(); + createStream(); + + enqueue(CancelServerStreamCommand.withReason( + stream.transportState(), + Status.RESOURCE_EXHAUSTED.withDescription("my custom description") + )); + + ArgumentCaptor captor = ArgumentCaptor.forClass(Http2Headers.class); + verifyWrite() + .writeHeaders( + eq(ctx()), + eq(STREAM_ID), + captor.capture(), + eq(0), + eq(true), + any(ChannelPromise.class)); + + // For arcane reasons, the specific implementation of Http2Headers here doesn't actually support + // methods like `get(...)`, so we have to manually convert it into a map. + Map actualHeaders = new HashMap<>(); + for (Map.Entry entry : captor.getValue()) { + actualHeaders.put(entry.getKey().toString(), entry.getValue().toString()); + } + assertEquals("8", actualHeaders.get(InternalStatus.CODE_KEY.name())); + assertEquals("my custom description", actualHeaders.get(InternalStatus.MESSAGE_KEY.name())); + } + @Test public void headersWithInvalidContentTypeShouldFail() throws Exception { manualSetUp(); diff --git a/netty/src/test/java/io/grpc/netty/NettyServerStreamTest.java b/netty/src/test/java/io/grpc/netty/NettyServerStreamTest.java index ab54d4e4e22..452f68341b1 100644 --- a/netty/src/test/java/io/grpc/netty/NettyServerStreamTest.java +++ b/netty/src/test/java/io/grpc/netty/NettyServerStreamTest.java @@ -18,7 +18,6 @@ import static com.google.common.truth.Truth.assertThat; import static com.google.common.truth.Truth.assertWithMessage; -import static io.grpc.internal.GrpcUtil.DEFAULT_MAX_MESSAGE_SIZE; import static io.grpc.netty.NettyTestUtil.messageFrame; import static io.netty.handler.codec.http2.Http2Error.PROTOCOL_ERROR; import static io.netty.handler.codec.http2.Http2Exception.connectionError; @@ -37,6 +36,7 @@ import static org.mockito.Mockito.verifyNoMoreInteractions; import static org.mockito.Mockito.when; +import com.google.common.base.Strings; import com.google.common.collect.ImmutableListMultimap; import com.google.common.collect.ListMultimap; import io.grpc.Attributes; @@ -73,6 +73,8 @@ /** Unit tests for {@link NettyServerStream}. */ @RunWith(JUnit4.class) public class NettyServerStreamTest extends NettyStreamTestBase { + private static final int TEST_MAX_MESSAGE_SIZE = 128; + @Mock protected ServerStreamListener serverListener; @@ -380,10 +382,31 @@ public void emptyFramerShouldSendNoPayload() { public void cancelStreamShouldSucceed() { stream().cancel(Status.DEADLINE_EXCEEDED); verify(writeQueue).enqueue( - new CancelServerStreamCommand(stream().transportState(), Status.DEADLINE_EXCEEDED), + CancelServerStreamCommand.withReset(stream().transportState(), Status.DEADLINE_EXCEEDED), true); } + @Test + public void oversizedMessagesResultInResourceExhaustedTrailers() throws Exception { + @SuppressWarnings("InlineMeInliner") // Requires Java 11 + String oversizedMsg = Strings.repeat("a", TEST_MAX_MESSAGE_SIZE + 1); + stream.request(1); + stream.transportState().inboundDataReceived(messageFrame(oversizedMsg), false); + assertNull("message should have caused a deframer error", listenerMessageQueue().poll()); + + ArgumentCaptor cancelCmdCap = + ArgumentCaptor.forClass(CancelServerStreamCommand.class); + verify(writeQueue).enqueue(cancelCmdCap.capture(), eq(true)); + + Status status = Status.RESOURCE_EXHAUSTED + .withDescription("gRPC message exceeds maximum size 128: 129"); + + CancelServerStreamCommand actualCmd = cancelCmdCap.getValue(); + assertThat(actualCmd.reason().getCode()).isEqualTo(status.getCode()); + assertThat(actualCmd.reason().getDescription()).isEqualTo(status.getDescription()); + assertThat(actualCmd.wantsHeaders()).isTrue(); + } + @Override @SuppressWarnings("DirectInvocationOnMock") protected NettyServerStream createStream() { @@ -391,7 +414,7 @@ protected NettyServerStream createStream() { StatsTraceContext statsTraceCtx = StatsTraceContext.NOOP; TransportTracer transportTracer = new TransportTracer(); NettyServerStream.TransportState state = new NettyServerStream.TransportState( - handler, channel.eventLoop(), http2Stream, DEFAULT_MAX_MESSAGE_SIZE, statsTraceCtx, + handler, channel.eventLoop(), http2Stream, TEST_MAX_MESSAGE_SIZE, statsTraceCtx, transportTracer, "method"); NettyServerStream stream = new NettyServerStream(channel, state, Attributes.EMPTY, "test-authority", statsTraceCtx);