From 22dffdf10252e12f1dbec2a0dd7ac73f92695a80 Mon Sep 17 00:00:00 2001 From: Norman Maurer Date: Wed, 5 May 2021 16:12:27 +0200 Subject: [PATCH] Implement QuicStreamChannel.bytesBeforeUnwritable() (#264) Motivation: Some users may depend on QuicStreamChannel.bytesBeforeUnwritable() to make decisions on how much they will try to write. Modifications: - Add implementation of QuicStreamChannel.bytesBeforeUnwritable() by keep track of the stream capacity - Add unit test Result: Be able to depend on QuicStreamChannel.bytesBeforeUnwritable() --- src/main/c/netty_quic_quiche.c | 9 ++ .../io/netty/incubator/codec/quic/Quiche.java | 6 + .../codec/quic/QuicheQuicChannel.java | 72 +++++----- .../codec/quic/QuicheQuicStreamChannel.java | 56 +++++--- .../codec/quic/QuicWritableTest.java | 125 ++++++++++++++++++ 5 files changed, 219 insertions(+), 49 deletions(-) diff --git a/src/main/c/netty_quic_quiche.c b/src/main/c/netty_quic_quiche.c index 8edfcf060..1351a3534 100644 --- a/src/main/c/netty_quic_quiche.c +++ b/src/main/c/netty_quic_quiche.c @@ -303,6 +303,14 @@ static jlong netty_quiche_conn_readable(JNIEnv* env, jclass clazz, jlong conn) { return (jlong) iter; } +static jlong netty_quiche_conn_writable(JNIEnv* env, jclass clazz, jlong conn) { + quiche_stream_iter* iter = quiche_conn_writable((quiche_conn *) conn); + if (iter == NULL) { + return -1; + } + return (jlong) iter; +} + static void netty_quiche_stream_iter_free(JNIEnv* env, jclass clazz, jlong iter) { quiche_stream_iter_free((quiche_stream_iter*) iter); } @@ -500,6 +508,7 @@ static const JNINativeMethod fixed_method_table[] = { { "quiche_conn_timeout_as_nanos", "(J)J", (void *) netty_quiche_conn_timeout_as_nanos }, { "quiche_conn_on_timeout", "(J)V", (void *) netty_quiche_conn_on_timeout }, { "quiche_conn_readable", "(J)J", (void *) netty_quiche_conn_readable }, + { "quiche_conn_writable", "(J)J", (void *) netty_quiche_conn_writable }, { "quiche_stream_iter_free", "(J)V", (void *) netty_quiche_stream_iter_free }, { "quiche_stream_iter_next", "(J[J)I", (void *) netty_quiche_stream_iter_next }, { "quiche_conn_dgram_max_writable_len", "(J)I", (void* ) netty_quiche_conn_dgram_max_writable_len }, diff --git a/src/main/java/io/netty/incubator/codec/quic/Quiche.java b/src/main/java/io/netty/incubator/codec/quic/Quiche.java index 5e96dd96e..72a013ba8 100644 --- a/src/main/java/io/netty/incubator/codec/quic/Quiche.java +++ b/src/main/java/io/netty/incubator/codec/quic/Quiche.java @@ -372,6 +372,12 @@ static native int quiche_conn_stream_priority( */ static native long quiche_conn_readable(long connAddr); + /** + * See + * quiche_conn_writable. + */ + static native long quiche_conn_writable(long connAddr); + /** * See * quiche_stream_iter_next. diff --git a/src/main/java/io/netty/incubator/codec/quic/QuicheQuicChannel.java b/src/main/java/io/netty/incubator/codec/quic/QuicheQuicChannel.java index b56561f9c..44bdb7876 100644 --- a/src/main/java/io/netty/incubator/codec/quic/QuicheQuicChannel.java +++ b/src/main/java/io/netty/incubator/codec/quic/QuicheQuicChannel.java @@ -99,8 +99,9 @@ public void operationComplete(ChannelFuture future) { private static final ChannelMetadata METADATA = new ChannelMetadata(false); private final long[] readableStreams = new long[128]; + private final long[] writableStreams = new long[128]; + private final LongObjectMap streams = new LongObjectHashMap<>(); - private final Queue flushPendingQueue = new ArrayDeque<>(); private final QuicheQuicChannelConfig config; private final boolean server; private final QuicStreamIdGenerator idGenerator; @@ -316,7 +317,6 @@ void forceClose() { state = CLOSED; closeStreams(); - flushPendingQueue.clear(); if (finBuffer != null) { finBuffer.release(); @@ -782,53 +782,56 @@ void writable() { } } - void streamHasPendingWrites(long streamId) { - flushPendingQueue.add(streamId); + int streamCapacity(long streamId) { + if (connection.isClosed()) { + return 0; + } + return Quiche.quiche_conn_stream_capacity(connection.address(), streamId); } private boolean handleWritableStreams() { - int pending = flushPendingQueue.size(); - if (isConnDestroyed() || pending == 0) { + if (isConnDestroyed()) { return false; } inHandleWritableStreams = true; try { long connAddr = connection.address(); boolean mayNeedWrite = false; + if (Quiche.quiche_conn_is_established(connAddr) || Quiche.quiche_conn_is_in_early_data(connAddr)) { - // We only want to process the number of channels that were in the queue when we entered - // handleWritableStreams(). Otherwise we may would loop forever as a channel may add itself again - // if the write was again partial. - for (int i = 0; i < pending; i++) { - Long streamId = flushPendingQueue.poll(); - if (streamId == null) { - break; - } - // Checking quiche_conn_stream_capacity(...) is cheaper then calling channel.writable() just - // to notice that we can not write again. - int capacity = Quiche.quiche_conn_stream_capacity(connAddr, streamId); - if (capacity == 0) { - // Still not writable, put back in the queue. - flushPendingQueue.add(streamId); - } else { - long sid = streamId; - QuicheQuicStreamChannel channel = streams.get(sid); - if (channel != null) { - if (capacity > 0) { - mayNeedWrite = true; - channel.writable(capacity); - } else { - if (!Quiche.quiche_conn_stream_finished(connAddr, sid)) { - // Only fire an exception if the error was not caused because the stream is - // considered finished. - channel.pipeline().fireExceptionCaught(Quiche.newException(capacity)); + long writableIterator = Quiche.quiche_conn_writable(connAddr); + + try { + // For streams we always process all streams when at least on read was requested. + for (;;) { + int writable = Quiche.quiche_stream_iter_next( + writableIterator, writableStreams); + for (int i = 0; i < writable; i++) { + long streamId = writableStreams[i]; + QuicheQuicStreamChannel streamChannel = streams.get(streamId); + if (streamChannel != null) { + int capacity = Quiche.quiche_conn_stream_capacity(connAddr, streamId); + if (capacity < 0) { + if (!Quiche.quiche_conn_stream_finished(connAddr, streamId)) { + // Only fire an exception if the error was not caused because the stream is + // considered finished. + streamChannel.pipeline().fireExceptionCaught(Quiche.newException(capacity)); + } + // Let's close the channel if quiche_conn_stream_capacity(...) returns an error. + streamChannel.forceClose(); + } else if (streamChannel.writable(capacity)) { + mayNeedWrite = true; } - // Let's close the channel if quiche_conn_stream_capacity(...) returns an error. - channel.forceClose(); } } + if (writable < writableStreams.length) { + // We did handle all writable streams. + break; + } } + } finally { + Quiche.quiche_stream_iter_free(writableIterator); } } return mayNeedWrite; @@ -1343,6 +1346,7 @@ private QuicheQuicStreamChannel addNewStreamChannel(long streamId) { QuicheQuicChannel.this, streamId); QuicheQuicStreamChannel old = streams.put(streamId, streamChannel); assert old == null; + streamChannel.writable(streamCapacity(streamId)); return streamChannel; } } diff --git a/src/main/java/io/netty/incubator/codec/quic/QuicheQuicStreamChannel.java b/src/main/java/io/netty/incubator/codec/quic/QuicheQuicStreamChannel.java index 36e3e40a3..b1fe74079 100644 --- a/src/main/java/io/netty/incubator/codec/quic/QuicheQuicStreamChannel.java +++ b/src/main/java/io/netty/incubator/codec/quic/QuicheQuicStreamChannel.java @@ -74,6 +74,7 @@ final class QuicheQuicStreamChannel extends DefaultAttributeMap implements QuicS private volatile boolean inputShutdown; private volatile boolean outputShutdown; private volatile QuicStreamPriority priority; + private volatile int capacity; QuicheQuicStreamChannel(QuicheQuicChannel parent, long streamId) { this.parent = parent; @@ -328,12 +329,16 @@ public boolean isWritable() { @Override public long bytesBeforeUnwritable() { - return 0; + return capacity; } @Override public long bytesBeforeWritable() { - return 0; + if (writable) { + return 0; + } + // Just return something positive for now + return 8; } @Override @@ -359,8 +364,18 @@ public int compareTo(Channel o) { /** * Stream is writable. */ - void writable(@SuppressWarnings("unused") int capacity) { - ((QuicStreamChannelUnsafe) unsafe()).writeQueued(); + boolean writable(@SuppressWarnings("unused") int capacity) { + this.capacity = capacity; + boolean mayNeedWrite = ((QuicStreamChannelUnsafe) unsafe()).writeQueued(); + updateWritabilityIfNeeded(capacity > 0); + return mayNeedWrite; + } + + private void updateWritabilityIfNeeded(boolean newWritable) { + if (writable != newWritable) { + writable = newWritable; + pipeline.fireChannelWritabilityChanged(); + } } /** @@ -559,10 +574,14 @@ private void closeIfNeeded(boolean wasFinSent) { } } - void writeQueued() { + boolean writeQueued() { boolean wasFinSent = QuicheQuicStreamChannel.this.finSent; inWriteQueued = true; try { + if (queue.isEmpty()) { + return false; + } + boolean written = false; for (;;) { Object msg = queue.current(); if (msg == null) { @@ -570,18 +589,17 @@ void writeQueued() { } try { if (!write0(msg)) { - return; + return written; } } catch (Exception e) { queue.remove().setFailure(e); continue; } queue.remove().setSuccess(); + written = true; } - if (!writable) { - writable = true; - pipeline.fireChannelWritabilityChanged(); - } + updateWritabilityIfNeeded(true); + return written; } finally { closeIfNeeded(wasFinSent); inWriteQueued = false; @@ -626,21 +644,24 @@ public void write(Object msg, ChannelPromise promise) { } boolean wasFinSent = QuicheQuicStreamChannel.this.finSent; + boolean mayNeedWritabilityUpdate = false; try { if (write0(msg)) { ReferenceCountUtil.release(msg); promise.setSuccess(); + mayNeedWritabilityUpdate = capacity == 0; } else { queue.add(msg, promise); - if (writable) { - writable = false; - pipeline.fireChannelWritabilityChanged(); - } + mayNeedWritabilityUpdate = true; } } catch (Exception e) { ReferenceCountUtil.release(msg); promise.setFailure(e); + mayNeedWritabilityUpdate = capacity == 0; } finally { + if (mayNeedWritabilityUpdate) { + updateWritabilityIfNeeded(false); + } closeIfNeeded(wasFinSent); } } @@ -673,8 +694,13 @@ private boolean write0(Object msg) throws Exception { try { do { int res = parent().streamSend(streamId(), buffer, fin); + + // Update the capacity as well. + int cap = parent.streamCapacity(streamId()); + if (cap >= 0) { + capacity = cap; + } if (Quiche.throwIfError(res) || res == 0) { - parent.streamHasPendingWrites(streamId()); return false; } sendSomething = true; diff --git a/src/test/java/io/netty/incubator/codec/quic/QuicWritableTest.java b/src/test/java/io/netty/incubator/codec/quic/QuicWritableTest.java index 443916b5e..91a7bea57 100644 --- a/src/test/java/io/netty/incubator/codec/quic/QuicWritableTest.java +++ b/src/test/java/io/netty/incubator/codec/quic/QuicWritableTest.java @@ -16,6 +16,7 @@ package io.netty.incubator.codec.quic; import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; import io.netty.channel.Channel; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelInboundHandlerAdapter; @@ -25,8 +26,11 @@ import org.junit.Test; import java.net.InetSocketAddress; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicReference; +import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; @@ -144,6 +148,127 @@ public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { } } + @Test(timeout = 5000) + public void testBytesUntilUnwritable() throws Throwable { + Promise writePromise = ImmediateEventExecutor.INSTANCE.newPromise(); + final AtomicReference serverErrorRef = new AtomicReference<>(); + final AtomicReference clientErrorRef = new AtomicReference<>(); + final CountDownLatch writableAgainLatch = new CountDownLatch(1); + int firstWriteNumBytes = 8; + int maxData = 32 * 1024; + final AtomicLong beforeWritableRef = new AtomicLong(); + Channel server = QuicTestUtils.newServer( + QuicTestUtils.newQuicServerBuilder().initialMaxStreamsBidirectional(5000), + InsecureQuicTokenHandler.INSTANCE, + null, new ChannelInboundHandlerAdapter() { + + private int numBytesRead; + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) { + ByteBuf buffer = (ByteBuf) msg; + numBytesRead += buffer.readableBytes(); + buffer.release(); + if (numBytesRead == firstWriteNumBytes) { + long before = ctx.channel().bytesBeforeUnwritable(); + beforeWritableRef.set(before); + assertTrue(before > 0); + + while (before != 0) { + int size = (int) Math.min(before, 1024); + ctx.write(ctx.alloc().buffer(size).writeZero(size)); + long newBefore = ctx.channel().bytesBeforeUnwritable(); + + assertEquals(before, newBefore + size); + before = newBefore; + } + ctx.writeAndFlush(Unpooled.EMPTY_BUFFER).addListener(new PromiseNotifier<>(writePromise)); + } + } + + @Override + public void channelWritabilityChanged(ChannelHandlerContext ctx) { + if (ctx.channel().isWritable()) { + if (ctx.channel().bytesBeforeUnwritable() > 0) { + writableAgainLatch.countDown(); + } + } + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { + serverErrorRef.set(cause); + } + + @Override + public boolean isSharable() { + return true; + } + }); + InetSocketAddress address = (InetSocketAddress) server.localAddress(); + Channel channel = QuicTestUtils.newClient(QuicTestUtils.newQuicClientBuilder() + .initialMaxStreamDataBidirectionalLocal(maxData)); + try { + QuicChannel quicChannel = QuicChannel.newBootstrap(channel) + .handler(new ChannelInboundHandlerAdapter()) + .streamHandler(new ChannelInboundHandlerAdapter()) + .remoteAddress(address) + .connect() + .get(); + QuicStreamChannel stream = quicChannel.createStream( + QuicStreamType.BIDIRECTIONAL, new ChannelInboundHandlerAdapter() { + int bytes; + + @Override + public void channelRegistered(ChannelHandlerContext ctx) { + ctx.channel().config().setAutoRead(false); + } + + @Override + public void channelActive(ChannelHandlerContext ctx) { + ctx.writeAndFlush(ctx.alloc().buffer(firstWriteNumBytes).writeZero(firstWriteNumBytes)); + } + + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) { + ByteBuf buffer = (ByteBuf) msg; + bytes += buffer.readableBytes(); + buffer.release(); + if (bytes == beforeWritableRef.get()) { + assertTrue(writePromise.isDone()); + } + } + + @Override + public void channelReadComplete(ChannelHandlerContext ctx) { + ctx.read(); + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { + clientErrorRef.set(cause); + } + }).get(); + assertFalse(writePromise.isDone()); + + // Let's trigger the reads. This will ensure we will consume the data and the remote peer + // should be notified that it can write more data. + stream.read(); + + writePromise.sync(); + writableAgainLatch.await(); + stream.close().sync(); + stream.closeFuture().sync(); + quicChannel.close().sync(); + + throwIfNotNull(serverErrorRef); + throwIfNotNull(clientErrorRef); + } finally { + server.close().sync(); + // Close the parent Datagram channel as well. + channel.close().sync(); + } + } + private static void throwIfNotNull(AtomicReference errorRef) throws Throwable { Throwable cause = errorRef.get(); if (cause != null) {