Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Split fizz specific ClientTransportParametersExtension into FizzClientExtensions #76

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions quic/client/handshake/ClientHandshake.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,7 @@ class ClientHandshake : public Handshake {
virtual void connect(
folly::Optional<std::string> hostname,
folly::Optional<fizz::client::CachedPsk> cachedPsk,
const std::shared_ptr<ClientTransportParametersExtension>&
transportParams,
std::shared_ptr<ClientTransportParametersExtension> transportParams,
HandshakeCallback* callback) = 0;

/**
Expand Down
58 changes: 1 addition & 57 deletions quic/client/handshake/ClientTransportParametersExtension.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,11 @@

#pragma once

#include <fizz/client/ClientExtensions.h>
#include <quic/handshake/FizzTransportParameters.h>

namespace quic {

class ClientTransportParametersExtension : public fizz::ClientExtensions {
public:
struct ClientTransportParametersExtension {
ClientTransportParametersExtension(
folly::Optional<QuicVersion> initialVersion,
uint64_t initialMaxData,
Expand All @@ -38,64 +36,10 @@ class ClientTransportParametersExtension : public fizz::ClientExtensions {
activeConnectionLimit_(activeConnectionIdLimit),
customTransportParameters_(customTransportParameters) {}

~ClientTransportParametersExtension() override = default;

std::vector<fizz::Extension> getClientHelloExtensions() const override {
std::vector<fizz::Extension> exts;

ClientTransportParameters params;
params.initial_version = initialVersion_;
params.parameters.push_back(encodeIntegerParameter(
TransportParameterId::initial_max_stream_data_bidi_local,
initialMaxStreamDataBidiLocal_));
params.parameters.push_back(encodeIntegerParameter(
TransportParameterId::initial_max_stream_data_bidi_remote,
initialMaxStreamDataBidiRemote_));
params.parameters.push_back(encodeIntegerParameter(
TransportParameterId::initial_max_stream_data_uni,
initialMaxStreamDataUni_));
params.parameters.push_back(encodeIntegerParameter(
TransportParameterId::initial_max_data, initialMaxData_));
params.parameters.push_back(encodeIntegerParameter(
TransportParameterId::initial_max_streams_bidi,
std::numeric_limits<uint32_t>::max()));
params.parameters.push_back(encodeIntegerParameter(
TransportParameterId::initial_max_streams_uni,
std::numeric_limits<uint32_t>::max()));
params.parameters.push_back(encodeIntegerParameter(
TransportParameterId::idle_timeout, idleTimeout_.count()));
params.parameters.push_back(encodeIntegerParameter(
TransportParameterId::ack_delay_exponent, ackDelayExponent_));
params.parameters.push_back(encodeIntegerParameter(
TransportParameterId::max_packet_size, maxRecvPacketSize_));
params.parameters.push_back(encodeIntegerParameter(
TransportParameterId::active_connection_id_limit,
activeConnectionLimit_));

for (const auto& customParameter : customTransportParameters_) {
params.parameters.push_back(customParameter);
}

exts.push_back(encodeExtension(params));
return exts;
}

void onEncryptedExtensions(
const std::vector<fizz::Extension>& exts) override {
auto serverParams = fizz::getExtension<ServerTransportParameters>(exts);
if (!serverParams) {
throw fizz::FizzException(
"missing server quic transport parameters extension",
fizz::AlertDescription::missing_extension);
}
serverTransportParameters_ = std::move(serverParams);
}

folly::Optional<ServerTransportParameters> getServerTransportParams() {
return std::move(serverTransportParameters_);
}

private:
folly::Optional<QuicVersion> initialVersion_;
uint64_t initialMaxData_;
uint64_t initialMaxStreamDataBidiLocal_;
Expand Down
84 changes: 84 additions & 0 deletions quic/client/handshake/FizzClientExtensions.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*
*/

#pragma once

#include <fizz/client/ClientExtensions.h>
#include <quic/client/handshake/ClientTransportParametersExtension.h>
#include <quic/handshake/FizzTransportParameters.h>

namespace quic {

class FizzClientExtensions : public fizz::ClientExtensions {
public:
FizzClientExtensions(
std::shared_ptr<ClientTransportParametersExtension> clientParameters)
: clientParameters_(std::move(clientParameters)) {}

~FizzClientExtensions() override = default;

std::vector<fizz::Extension> getClientHelloExtensions() const override {
std::vector<fizz::Extension> exts;

ClientTransportParameters params;
params.initial_version = clientParameters_->initialVersion_;
params.parameters.push_back(encodeIntegerParameter(
TransportParameterId::initial_max_stream_data_bidi_local,
clientParameters_->initialMaxStreamDataBidiLocal_));
params.parameters.push_back(encodeIntegerParameter(
TransportParameterId::initial_max_stream_data_bidi_remote,
clientParameters_->initialMaxStreamDataBidiRemote_));
params.parameters.push_back(encodeIntegerParameter(
TransportParameterId::initial_max_stream_data_uni,
clientParameters_->initialMaxStreamDataUni_));
params.parameters.push_back(encodeIntegerParameter(
TransportParameterId::initial_max_data,
clientParameters_->initialMaxData_));
params.parameters.push_back(encodeIntegerParameter(
TransportParameterId::initial_max_streams_bidi,
std::numeric_limits<uint32_t>::max()));
params.parameters.push_back(encodeIntegerParameter(
TransportParameterId::initial_max_streams_uni,
std::numeric_limits<uint32_t>::max()));
params.parameters.push_back(encodeIntegerParameter(
TransportParameterId::idle_timeout,
clientParameters_->idleTimeout_.count()));
params.parameters.push_back(encodeIntegerParameter(
TransportParameterId::ack_delay_exponent,
clientParameters_->ackDelayExponent_));
params.parameters.push_back(encodeIntegerParameter(
TransportParameterId::max_packet_size,
clientParameters_->maxRecvPacketSize_));
params.parameters.push_back(encodeIntegerParameter(
TransportParameterId::active_connection_id_limit,
clientParameters_->activeConnectionLimit_));

for (const auto& customParameter :
clientParameters_->customTransportParameters_) {
params.parameters.push_back(customParameter);
}

exts.push_back(encodeExtension(params));
return exts;
}

void onEncryptedExtensions(
const std::vector<fizz::Extension>& exts) override {
auto serverParams = fizz::getExtension<ServerTransportParameters>(exts);
if (!serverParams) {
throw fizz::FizzException(
"missing server quic transport parameters extension",
fizz::AlertDescription::missing_extension);
}
clientParameters_->serverTransportParameters_ = std::move(serverParams);
}

private:
std::shared_ptr<ClientTransportParametersExtension> clientParameters_;
};
} // namespace quic
5 changes: 3 additions & 2 deletions quic/client/handshake/FizzClientHandshake.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include <quic/client/handshake/FizzClientHandshake.h>

#include <folly/Overload.h>
#include <quic/client/handshake/FizzClientExtensions.h>
#include <quic/client/handshake/FizzClientQuicHandshakeContext.h>
#include <quic/handshake/FizzBridge.h>

Expand All @@ -24,7 +25,7 @@ FizzClientHandshake::FizzClientHandshake(
void FizzClientHandshake::connect(
folly::Optional<std::string> hostname,
folly::Optional<fizz::client::CachedPsk> cachedPsk,
const std::shared_ptr<ClientTransportParametersExtension>& transportParams,
std::shared_ptr<ClientTransportParametersExtension> transportParams,
HandshakeCallback* callback) {
transportParams_ = transportParams;
callback_ = callback;
Expand All @@ -43,7 +44,7 @@ void FizzClientHandshake::connect(
fizzContext_->getCertificateVerifier(),
std::move(hostname),
std::move(cachedPsk),
transportParams));
std::make_shared<FizzClientExtensions>(std::move(transportParams))));
}

const CryptoFactory& FizzClientHandshake::getCryptoFactory() const {
Expand Down
3 changes: 1 addition & 2 deletions quic/client/handshake/FizzClientHandshake.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,7 @@ class FizzClientHandshake : public ClientHandshake {
void connect(
folly::Optional<std::string> hostname,
folly::Optional<fizz::client::CachedPsk> cachedPsk,
const std::shared_ptr<ClientTransportParametersExtension>&
transportParams,
std::shared_ptr<ClientTransportParametersExtension> transportParams,
HandshakeCallback* callback) override;

const CryptoFactory& getCryptoFactory() const override;
Expand Down
2 changes: 1 addition & 1 deletion quic/client/handshake/test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ endif()
quic_add_test(TARGET ClientHandshakeTest
SOURCES
ClientHandshakeTest.cpp
ClientTransportParametersTest.cpp
FizzClientExtensionsTest.cpp
DEPENDS
Folly::folly
${LIBFIZZ_LIBRARY}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

#include <gtest/gtest.h>

#include <quic/client/handshake/ClientTransportParametersExtension.h>
#include <quic/client/handshake/FizzClientExtensions.h>
#include <quic/common/test/TestUtils.h>

#include <fizz/protocol/test/TestMessages.h>
Expand All @@ -30,8 +30,8 @@ static EncryptedExtensions getEncryptedExtensions() {
return ee;
}

TEST(ClientTransportParametersTest, TestGetChloExtensions) {
ClientTransportParametersExtension ext(
TEST(FizzClientHandshakeTest, TestGetChloExtensions) {
FizzClientExtensions ext(std::make_shared<ClientTransportParametersExtension>(
folly::none,
kDefaultConnectionWindowSize,
kDefaultStreamWindowSize,
Expand All @@ -40,16 +40,16 @@ TEST(ClientTransportParametersTest, TestGetChloExtensions) {
kDefaultIdleTimeout,
kDefaultAckDelayExponent,
kDefaultUDPSendPacketLen,
kDefaultActiveConnectionIdLimit);
kDefaultActiveConnectionIdLimit));
auto extensions = ext.getClientHelloExtensions();

EXPECT_EQ(extensions.size(), 1);
auto serverParams = getExtension<ClientTransportParameters>(extensions);
EXPECT_TRUE(serverParams.hasValue());
}

TEST(ClientTransportParametersTest, TestOnEE) {
ClientTransportParametersExtension ext(
TEST(FizzClientHandshakeTest, TestOnEE) {
FizzClientExtensions ext(std::make_shared<ClientTransportParametersExtension>(
MVFST1,
kDefaultConnectionWindowSize,
kDefaultStreamWindowSize,
Expand All @@ -58,13 +58,13 @@ TEST(ClientTransportParametersTest, TestOnEE) {
kDefaultIdleTimeout,
kDefaultAckDelayExponent,
kDefaultUDPSendPacketLen,
kDefaultActiveConnectionIdLimit);
kDefaultActiveConnectionIdLimit));
ext.getClientHelloExtensions();
ext.onEncryptedExtensions(getEncryptedExtensions().extensions);
}

TEST(ClientTransportParametersTest, TestOnEEMissingServerParams) {
ClientTransportParametersExtension ext(
TEST(FizzClientHandshakeTest, TestOnEEMissingServerParams) {
FizzClientExtensions ext(std::make_shared<ClientTransportParametersExtension>(
MVFST1,
kDefaultConnectionWindowSize,
kDefaultStreamWindowSize,
Expand All @@ -73,14 +73,14 @@ TEST(ClientTransportParametersTest, TestOnEEMissingServerParams) {
kDefaultIdleTimeout,
kDefaultAckDelayExponent,
kDefaultUDPSendPacketLen,
kDefaultActiveConnectionIdLimit);
kDefaultActiveConnectionIdLimit));
ext.getClientHelloExtensions();
EXPECT_THROW(
ext.onEncryptedExtensions(TestMessages::encryptedExt().extensions),
FizzException);
}

TEST(ClientTransportParametersTest, TestGetChloExtensionsCustomParams) {
TEST(FizzClientHandshakeTest, TestGetChloExtensionsCustomParams) {
std::vector<TransportParameter> customTransportParameters;

std::string randomBytes = "\x01\x00\x55\x12\xff";
Expand All @@ -99,7 +99,7 @@ TEST(ClientTransportParametersTest, TestGetChloExtensionsCustomParams) {
customTransportParameters.push_back(element2->encode());
customTransportParameters.push_back(element3->encode());

ClientTransportParametersExtension ext(
FizzClientExtensions ext(std::make_shared<ClientTransportParametersExtension>(
folly::none,
kDefaultConnectionWindowSize,
kDefaultStreamWindowSize,
Expand All @@ -109,7 +109,7 @@ TEST(ClientTransportParametersTest, TestGetChloExtensionsCustomParams) {
kDefaultAckDelayExponent,
kDefaultUDPSendPacketLen,
kDefaultActiveConnectionIdLimit,
customTransportParameters);
customTransportParameters));
auto extensions = ext.getClientHelloExtensions();

EXPECT_EQ(extensions.size(), 1);
Expand Down Expand Up @@ -155,5 +155,5 @@ TEST(ClientTransportParametersTest, TestGetChloExtensionsCustomParams) {

EXPECT_TRUE(eq(folly::IOBuf::copyBuffer(randomBytes), it3->value));
}
}
}
} // namespace test
} // namespace quic
2 changes: 1 addition & 1 deletion quic/client/test/QuicClientTransportTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1119,7 +1119,7 @@ class FakeOneRttHandshakeLayer : public ClientHandshake {
void connect(
folly::Optional<std::string>,
folly::Optional<fizz::client::CachedPsk>,
const std::shared_ptr<ClientTransportParametersExtension>&,
std::shared_ptr<ClientTransportParametersExtension>,
HandshakeCallback* callback) override {
connected_ = true;
writeDataToQuicStream(
Expand Down
8 changes: 6 additions & 2 deletions quic/server/handshake/test/ServerHandshakeTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
#include <folly/ssl/Init.h>

#include <quic/QuicConstants.h>
#include <quic/client/handshake/ClientTransportParametersExtension.h>
#include <quic/client/handshake/FizzClientExtensions.h>
#include <quic/common/test/TestUtils.h>
#include <quic/handshake/FizzBridge.h>
#include <quic/handshake/HandshakeLayer.h>
Expand Down Expand Up @@ -151,7 +151,11 @@ class ServerHandshakeTest : public Test {
}));
auto cachedPsk = clientCtx->getPsk(hostname);
fizzClient->connect(
clientCtx, verifier, hostname, cachedPsk, clientExtensions);
clientCtx,
verifier,
hostname,
cachedPsk,
std::make_shared<FizzClientExtensions>(clientExtensions));
}

void clientServerRound() {
Expand Down