diff --git a/netty/src/main/java/io/grpc/netty/NettyServerHandler.java b/netty/src/main/java/io/grpc/netty/NettyServerHandler.java index 2b06a3fcf558..7d15d00fbccf 100644 --- a/netty/src/main/java/io/grpc/netty/NettyServerHandler.java +++ b/netty/src/main/java/io/grpc/netty/NettyServerHandler.java @@ -718,12 +718,16 @@ void returnProcessedBytes(Http2Stream http2Stream, int bytes) { } } - private void closeStreamWhenDone(ChannelPromise promise, Http2Stream stream) { + private void closeStreamWhenDone(ChannelPromise promise, Http2Stream stream, Status status) { promise.addListener( new ChannelFutureListener() { @Override public void operationComplete(ChannelFuture future) { - serverStream(stream).complete(); + if (status.isOk()) { + serverStream(stream).complete(); + } else { + serverStream(stream).transportReportStatus(status); + } } }); } @@ -753,7 +757,7 @@ private void sendGrpcFrame( return; } if (cmd.endStream()) { - closeStreamWhenDone(promise, stream); + closeStreamWhenDone(promise, stream, Status.OK); } // Call the base class to write the HTTP/2 DATA frame. encoder().writeData(ctx, streamId, cmd.content(), 0, cmd.endStream(), promise); @@ -763,8 +767,8 @@ private void sendGrpcFrame( /** * Sends the response headers to the client. */ - private void sendResponseHeaders(ChannelHandlerContext ctx, SendResponseHeadersCommand cmd, - ChannelPromise promise) throws Http2Exception { + private void sendResponseHeaders( + ChannelHandlerContext ctx, SendResponseHeadersCommand cmd, ChannelPromise promise) { try (TaskCloseable ignore = PerfMark.traceTask("NettyServerHandler.sendResponseHeaders")) { PerfMark.attachTag(cmd.stream().tag()); PerfMark.linkIn(cmd.getLink()); @@ -775,7 +779,10 @@ private void sendResponseHeaders(ChannelHandlerContext ctx, SendResponseHeadersC return; } if (cmd.endOfStream()) { - closeStreamWhenDone(promise, stream); + // The stream listener only cares about the status if the close was triggered + // by the transport. Application-initiated closes should always see Status.OK. + Status status = cmd.source() == StatusSource.Application ? Status.OK : cmd.status(); + closeStreamWhenDone(promise, stream, status); } encoder().writeHeaders(ctx, streamId, cmd.headers(), 0, cmd.endOfStream(), promise); } diff --git a/netty/src/main/java/io/grpc/netty/NettyServerStream.java b/netty/src/main/java/io/grpc/netty/NettyServerStream.java index a44d8b4a64f0..a8bf1ef9a560 100644 --- a/netty/src/main/java/io/grpc/netty/NettyServerStream.java +++ b/netty/src/main/java/io/grpc/netty/NettyServerStream.java @@ -20,6 +20,7 @@ import com.google.common.base.Preconditions; import io.grpc.Attributes; +import io.grpc.InternalStatus; import io.grpc.Metadata; import io.grpc.Status; import io.grpc.internal.AbstractServerStream; @@ -203,7 +204,16 @@ public void deframeFailed(Throwable cause) { log.log(Level.WARNING, "Exception processing message", cause); Status status = Status.fromThrowable(cause); transportReportStatus(status); - handler.getWriteQueue().enqueue(new CancelServerStreamCommand(this, status), true); + + Metadata trailers = new Metadata(); + trailers.put(InternalStatus.CODE_KEY, status); + if (status.getDescription() != null) { + trailers.put(InternalStatus.MESSAGE_KEY, status.getDescription()); + } + Http2Headers http2Trailers = Utils.convertTrailers(trailers, /* headersSent = */ false); + SendResponseHeadersCommand cmd = + SendResponseHeadersCommand.transportError(this, http2Trailers, status); + handler.getWriteQueue().enqueue(cmd, /* flush = */ true); } void inboundDataReceived(ByteBuf frame, boolean endOfStream) { diff --git a/netty/src/main/java/io/grpc/netty/SendResponseHeadersCommand.java b/netty/src/main/java/io/grpc/netty/SendResponseHeadersCommand.java index b649385bdc20..a53ccce21732 100644 --- a/netty/src/main/java/io/grpc/netty/SendResponseHeadersCommand.java +++ b/netty/src/main/java/io/grpc/netty/SendResponseHeadersCommand.java @@ -28,21 +28,30 @@ final class SendResponseHeadersCommand extends WriteQueue.AbstractQueuedCommand private final StreamIdHolder stream; private final Http2Headers headers; private final Status status; + private final StatusSource source; - private SendResponseHeadersCommand(StreamIdHolder stream, Http2Headers headers, Status status) { + private SendResponseHeadersCommand( + StreamIdHolder stream, Http2Headers headers, Status status, StatusSource source) { this.stream = Preconditions.checkNotNull(stream, "stream"); this.headers = Preconditions.checkNotNull(headers, "headers"); this.status = status; + this.source = source; } static SendResponseHeadersCommand createHeaders(StreamIdHolder stream, Http2Headers headers) { - return new SendResponseHeadersCommand(stream, headers, null); + return new SendResponseHeadersCommand(stream, headers, null, StatusSource.Application); } static SendResponseHeadersCommand createTrailers( StreamIdHolder stream, Http2Headers headers, Status status) { return new SendResponseHeadersCommand( - stream, headers, Preconditions.checkNotNull(status, "status")); + stream, headers, Preconditions.checkNotNull(status, "status"), StatusSource.Application); + } + + static SendResponseHeadersCommand transportError( + StreamIdHolder stream, Http2Headers headers, Status status) { + Preconditions.checkArgument(!status.isOk(), "transport error must not be OK"); + return new SendResponseHeadersCommand(stream, headers, status, StatusSource.Transport); } StreamIdHolder stream() { @@ -61,6 +70,10 @@ Status status() { return status; } + StatusSource source() { + return source; + } + @Override public boolean equals(Object that) { if (that == null || !that.getClass().equals(SendResponseHeadersCommand.class)) { @@ -69,17 +82,18 @@ public boolean equals(Object that) { SendResponseHeadersCommand thatCmd = (SendResponseHeadersCommand) that; return thatCmd.stream.equals(stream) && thatCmd.headers.equals(headers) - && thatCmd.status.equals(status); + && thatCmd.status.equals(status) + && thatCmd.source.equals(source); } @Override public String toString() { return getClass().getSimpleName() + "(stream=" + stream.id() + ", headers=" + headers - + ", status=" + status + ")"; + + ", status=" + status + ", source=" + source + ")"; } @Override public int hashCode() { - return Objects.hashCode(stream, status); + return Objects.hashCode(stream, status, source); } } diff --git a/netty/src/main/java/io/grpc/netty/StatusSource.java b/netty/src/main/java/io/grpc/netty/StatusSource.java new file mode 100644 index 000000000000..e6efa5a3ade3 --- /dev/null +++ b/netty/src/main/java/io/grpc/netty/StatusSource.java @@ -0,0 +1,26 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.netty; + +/** + * A stream can be closed by the application or by the transport itself. + * We emit listener events differently depending on who initiated the action. + */ +enum StatusSource { + Application, + Transport, +} diff --git a/netty/src/test/java/io/grpc/netty/NettyServerStreamTest.java b/netty/src/test/java/io/grpc/netty/NettyServerStreamTest.java index e95a2a52bc9a..8fc8f7ce2fe3 100644 --- a/netty/src/test/java/io/grpc/netty/NettyServerStreamTest.java +++ b/netty/src/test/java/io/grpc/netty/NettyServerStreamTest.java @@ -17,7 +17,6 @@ package io.grpc.netty; import static com.google.common.truth.Truth.assertThat; -import static io.grpc.internal.GrpcUtil.DEFAULT_MAX_MESSAGE_SIZE; import static io.grpc.netty.NettyTestUtil.messageFrame; import static org.junit.Assert.assertNull; import static org.mockito.ArgumentMatchers.any; @@ -32,9 +31,11 @@ import static org.mockito.Mockito.verifyNoMoreInteractions; import static org.mockito.Mockito.when; +import com.google.common.base.Strings; import com.google.common.collect.ImmutableListMultimap; import com.google.common.collect.ListMultimap; import io.grpc.Attributes; +import io.grpc.InternalStatus; import io.grpc.Metadata; import io.grpc.Status; import io.grpc.internal.ServerStreamListener; @@ -280,6 +281,31 @@ public void cancelStreamShouldSucceed() { true); } + @Test + public void oversizedMessagesResultInResourceExhaustedTrailers() throws Exception { + @SuppressWarnings("InlineMeInliner") // Requires Java 11 + String oversizedMsg = Strings.repeat("a", TEST_MAX_MESSAGE_SIZE + 1); + stream.request(1); + stream.transportState().inboundDataReceived(messageFrame(oversizedMsg), false); + assertNull("message should have caused a deframer error", listenerMessageQueue().poll()); + + ArgumentCaptor sendHeadersCap = + ArgumentCaptor.forClass(SendResponseHeadersCommand.class); + verify(writeQueue).enqueue(sendHeadersCap.capture(), eq(true)); + + Status status = Status.RESOURCE_EXHAUSTED + .withDescription("gRPC message exceeds maximum size 128: 129"); + + SendResponseHeadersCommand actualCmd = sendHeadersCap.getValue(); + assertThat(actualCmd.status().getCode()).isEqualTo(status.getCode()); + assertThat(actualCmd.status().getDescription()).isEqualTo(status.getDescription()); + Metadata trailers = Utils.convertTrailers(actualCmd.headers()); + assertThat(trailers.get(InternalStatus.CODE_KEY).getCode()).isEqualTo(status.getCode()); + assertThat(trailers.get(InternalStatus.MESSAGE_KEY)).isEqualTo(status.getDescription()); + } + + private static final int TEST_MAX_MESSAGE_SIZE = 128; + @Override @SuppressWarnings("DirectInvocationOnMock") protected NettyServerStream createStream() { @@ -287,7 +313,7 @@ protected NettyServerStream createStream() { StatsTraceContext statsTraceCtx = StatsTraceContext.NOOP; TransportTracer transportTracer = new TransportTracer(); NettyServerStream.TransportState state = new NettyServerStream.TransportState( - handler, channel.eventLoop(), http2Stream, DEFAULT_MAX_MESSAGE_SIZE, statsTraceCtx, + handler, channel.eventLoop(), http2Stream, TEST_MAX_MESSAGE_SIZE, statsTraceCtx, transportTracer, "method"); NettyServerStream stream = new NettyServerStream(channel, state, Attributes.EMPTY, "test-authority", statsTraceCtx, transportTracer);