Skip to content

Commit

Permalink
Ensure we also set the QuicConnectionId when the id has a length of 0. (
Browse files Browse the repository at this point in the history
java-native-access#649)

Motivation:

We did not correctly handle the case when the remote peer used a
connection id with length of 0.

Modifications:

- Correctly handle the case when an id is of length 0
- Modify unit tests to catch the problem

Result:

Correct handling of zero length connection ids
  • Loading branch information
normanmaurer committed Jan 25, 2024
1 parent db92d88 commit d3e2e6a
Show file tree
Hide file tree
Showing 13 changed files with 52 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufUtil;
import io.netty.buffer.Unpooled;
import io.netty.util.internal.EmptyArrays;

import java.net.SocketAddress;
import java.nio.ByteBuffer;
Expand All @@ -28,6 +29,8 @@
*/
public final class QuicConnectionAddress extends SocketAddress {

static final QuicConnectionAddress NULL_LEN = new QuicConnectionAddress(EmptyArrays.EMPTY_BYTES);

/**
* Special {@link QuicConnectionAddress} that should be used when the connection address should be generated
* and chosen on the fly.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ QuicConnectionAddress connectionId(Supplier<byte[]> idSupplier) {
}
id = idSupplier.get();
}
return id == null ? null : new QuicConnectionAddress(id);
return id == null ? QuicConnectionAddress.NULL_LEN : new QuicConnectionAddress(id);
}

QuicheQuicTransportParameters peerParameters() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@
import static org.junit.jupiter.api.Assertions.assertNotEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertSame;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.fail;

Expand Down Expand Up @@ -253,24 +254,25 @@ public void testAddressValidation(Executor executor) throws Throwable {
@ParameterizedTest
@MethodSource("newSslTaskExecutors")
public void testConnectWithCustomIdLength(Executor executor) throws Throwable {
testConnectWithCustomIdLength(executor, 10);
testConnectWithCustomIdLength(executor, 10, 5);
}

@ParameterizedTest
@MethodSource("newSslTaskExecutors")
public void testConnectWithCustomIdLengthOfZero(Executor executor) throws Throwable {
testConnectWithCustomIdLength(executor, 0);
testConnectWithCustomIdLength(executor, 0, 0);
}

private static void testConnectWithCustomIdLength(Executor executor, int idLength) throws Throwable {
private static void testConnectWithCustomIdLength(Executor executor, int clientIdLength, int serverIdLength)
throws Throwable {
ChannelActiveVerifyHandler serverQuicChannelHandler = new ChannelActiveVerifyHandler();
ChannelStateVerifyHandler serverQuicStreamHandler = new ChannelStateVerifyHandler();
Channel server = QuicTestUtils.newServer(QuicTestUtils.newQuicServerBuilder(executor)
.localConnectionIdLength(idLength),
.localConnectionIdLength(serverIdLength),
InsecureQuicTokenHandler.INSTANCE, serverQuicChannelHandler, serverQuicStreamHandler);
InetSocketAddress address = (InetSocketAddress) server.localAddress();
Channel channel = QuicTestUtils.newClient(QuicTestUtils.newQuicClientBuilder(executor)
.localConnectionIdLength(idLength));
.localConnectionIdLength(clientIdLength));
try {
ChannelActiveVerifyHandler clientQuicChannelHandler = new ChannelActiveVerifyHandler();
QuicChannel quicChannel = QuicTestUtils.newQuicChannelBootstrap(channel)
Expand All @@ -283,8 +285,12 @@ private static void testConnectWithCustomIdLength(Executor executor, int idLengt
ChannelFuture closeFuture = quicChannel.closeFuture().await();
assertTrue(closeFuture.isSuccess());
clientQuicChannelHandler.assertState();
assertEquals(clientIdLength, clientQuicChannelHandler.localAddress().connId.remaining());
assertEquals(serverIdLength, clientQuicChannelHandler.remoteAddress().connId.remaining());
} finally {
serverQuicChannelHandler.assertState();
assertEquals(serverIdLength, serverQuicChannelHandler.localAddress().connId.remaining());
assertEquals(clientIdLength, serverQuicChannelHandler.remoteAddress().connId.remaining());
serverQuicStreamHandler.assertState();

server.close().sync();
Expand Down Expand Up @@ -1439,8 +1445,6 @@ public void channelInactive(ChannelHandlerContext ctx) {

private static final class ChannelActiveVerifyHandler extends QuicChannelValidationHandler {
private final BlockingQueue<Integer> states = new LinkedBlockingQueue<>();
private volatile QuicConnectionAddress localAddress;
private volatile QuicConnectionAddress remoteAddress;

@Override
public void channelRegistered(ChannelHandlerContext ctx) {
Expand All @@ -1456,9 +1460,7 @@ public void channelUnregistered(ChannelHandlerContext ctx) {

@Override
public void channelActive(ChannelHandlerContext ctx) {
localAddress = (QuicConnectionAddress) ctx.channel().localAddress();
remoteAddress = (QuicConnectionAddress) ctx.channel().remoteAddress();
ctx.fireChannelActive();
super.channelActive(ctx);
states.add(1);
}

Expand All @@ -1476,14 +1478,6 @@ void assertState() throws Throwable {
assertNull(states.poll());
super.assertState();
}

QuicConnectionAddress localAddress() {
return localAddress;
}

QuicConnectionAddress remoteAddress() {
return remoteAddress;
}
}

private abstract static class TestX509ExtendedTrustManager extends X509ExtendedTrustManager {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,7 @@ private void testDatagramNoAutoRead(Executor executor, int maxMessagesPerRead, b

@Override
public void channelActive(ChannelHandlerContext ctx) {
super.channelActive(ctx);
ctx.read();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,7 @@ public void testEchoStartedFromClient(boolean autoRead, boolean directBuffer, bo
QuicChannelValidationHandler serverHandler = new QuicChannelValidationHandler() {
@Override
public void channelActive(ChannelHandlerContext ctx) {
super.channelActive(ctx);
setAllocator(ctx.channel(), allocator);
ctx.channel().config().setAutoRead(autoRead);
if (!autoRead) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,19 +18,44 @@
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;

import static org.junit.jupiter.api.Assertions.assertNotNull;

class QuicChannelValidationHandler extends ChannelInboundHandlerAdapter {
private volatile boolean wasActive;

private volatile QuicConnectionAddress localAddress;
private volatile QuicConnectionAddress remoteAddress;
private volatile Throwable cause;

@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
this.cause = cause;
}

@Override
public void channelActive(ChannelHandlerContext ctx) {
localAddress = (QuicConnectionAddress) ctx.channel().localAddress();
remoteAddress = (QuicConnectionAddress) ctx.channel().remoteAddress();
wasActive = true;
ctx.fireChannelActive();
}

QuicConnectionAddress localAddress() {
return localAddress;
}

QuicConnectionAddress remoteAddress() {
return remoteAddress;
}

void assertState() throws Throwable {
if (cause != null) {
throw cause;
}
if (wasActive) {
// Validate that the addresses could be retrieved
assertNotNull(localAddress);
assertNotNull(remoteAddress);
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,8 @@ public void testStatsAreCollected(Executor executor) throws Throwable {
QuicChannelValidationHandler serverHandler = new QuicChannelValidationHandler() {
@Override
public void channelActive(ChannelHandlerContext ctx) {
super.channelActive(ctx);
collectStats(ctx, serverActiveStats);
ctx.fireChannelActive();
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,7 @@ private static final class StreamCreationAndTearDownHandler extends QuicChannelV

@Override
public void channelActive(ChannelHandlerContext ctx) {
super.channelActive(ctx);
QuicChannel channel = (QuicChannel) ctx.channel();
channel.createStream(type, new ChannelInboundHandlerAdapter() {
@Override
Expand Down Expand Up @@ -271,6 +272,7 @@ private static final class StreamCreationHandler extends QuicChannelValidationHa

@Override
public void channelActive(ChannelHandlerContext ctx) {
super.channelActive(ctx);
QuicChannel channel = (QuicChannel) ctx.channel();
channel.createStream(type, new ChannelInboundHandlerAdapter() {
@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ private static final class StreamCreationHandler extends QuicChannelValidationHa

@Override
public void channelActive(ChannelHandlerContext ctx) {
super.channelActive(ctx);
QuicChannel channel = (QuicChannel) ctx.channel();
channel.createStream(type, new ChannelInboundHandlerAdapter() {
@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ private static final class StreamCreationHandler extends QuicChannelValidationHa

@Override
public void channelActive(ChannelHandlerContext ctx) {
super.channelActive(ctx);
QuicChannel channel = (QuicChannel) ctx.channel();
channel.createStream(type, new ChannelInboundHandlerAdapter() {
@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ private static void testStreamLimitEnforcedWhenCreatingViaServer(Executor execut
QuicChannelValidationHandler serverHandler = new QuicChannelValidationHandler() {
@Override
public void channelActive(ChannelHandlerContext ctx) {
super.channelActive(ctx);
QuicChannel channel = (QuicChannel) ctx.channel();
channel.createStream(type, new ChannelInboundHandlerAdapter())
.addListener((Future<QuicStreamChannel> future) -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ public void testUnidirectionalCreatedByServer(Executor executor) throws Throwabl
QuicChannelValidationHandler serverHandler = new QuicChannelValidationHandler() {
@Override
public void channelActive(ChannelHandlerContext ctx) {
super.channelActive(ctx);
QuicChannel channel = (QuicChannel) ctx.channel();
channel.createStream(QuicStreamType.UNIDIRECTIONAL, new ChannelInboundHandlerAdapter() {
@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,9 @@ public void testParameters(Executor executor) throws Throwable {
QuicChannelValidationHandler serverHandler = new QuicChannelValidationHandler() {
@Override
public void channelActive(ChannelHandlerContext ctx) {
super.channelActive(ctx);
QuicheQuicChannel channel = (QuicheQuicChannel) ctx.channel();
serverParams.setSuccess(channel.peerTransportParameters());
ctx.fireChannelActive();
}
};
QuicChannelValidationHandler clientHandler = new QuicChannelValidationHandler();
Expand Down

0 comments on commit d3e2e6a

Please sign in to comment.