Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

tls: Enable ALPS for HTTP/2 #1123

Merged
merged 1 commit into from
Oct 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4150,6 +4150,7 @@ set(files
src/net/http_parser.cpp
src/net/padding.cpp
src/net/resolver.cpp
src/net/protocol.cpp
src/crypto/aead_base_decrypter.cpp
src/crypto/aead_base_encrypter.cpp
src/crypto/aead_evp_decrypter.cpp
Expand Down
14 changes: 14 additions & 0 deletions src/net/connection.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,20 @@ class SSLDownlink : public Downlink {
auto callback = std::move(handshake_callback_);
DCHECK(!handshake_callback_);
asio::error_code ec = result == OK ? asio::error_code() : asio::error::connection_refused;
if (!ec) {
auto alpn = ssl_socket_->negotiated_protocol();
switch (alpn) {
case kProtoHTTP2:
DCHECK(!https_fallback_) << " unexpected alpn: " << NextProtoToString(alpn);
break;
case kProtoHTTP11:
DCHECK(https_fallback_) << " unexpected alpn: " << NextProtoToString(alpn);
break;
default:
LOG(WARNING) << "Alpn unexpected: " << NextProtoToString(alpn);
}
VLOG(2) << "Alpn selected (server): " << NextProtoToString(alpn);
}
if (callback) {
callback(ec);
}
Expand Down
28 changes: 9 additions & 19 deletions src/net/content_server.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include "net/asio.hpp"
#include "net/connection.hpp"
#include "net/network.hpp"
#include "net/protocol.hpp"
#include "net/ssl_socket.hpp"
#include "net/x509_util.hpp"

Expand Down Expand Up @@ -459,23 +460,27 @@ class ContentServer {
if (in[0] + 1u > inlen) {
goto err;
}
using std::string_view_literals::operator""sv;
auto alpn = std::string_view(reinterpret_cast<const char*>(in + 1), in[0]);
if (!server->https_fallback_ && alpn == "h2"sv) {
if (!server->https_fallback_ && NextProtoFromString(alpn) == kProtoHTTP2) {
VLOG(1) << "Connection (" << T::Name << ") " << connection_id << " Alpn support (server) chosen: " << alpn;
server->set_https_fallback(connection_id, false);
*out = in + 1;
*outlen = in[0];

// Enable ALPS for HTTP/2 with empty data.
std::vector<uint8_t> data;
SSL_add_application_settings(ssl, reinterpret_cast<const uint8_t*>(alpn.data()), alpn.size(), data.data(),
data.size());
return SSL_TLSEXT_ERR_OK;
}
if (alpn == "http/1.1"sv) {
if (NextProtoFromString(alpn) == kProtoHTTP11) {
VLOG(1) << "Connection (" << T::Name << ") " << connection_id << " Alpn support (server) chosen: " << alpn;
server->set_https_fallback(connection_id, true);
*out = in + 1;
*outlen = in[0];
return SSL_TLSEXT_ERR_OK;
}
LOG(WARNING) << "Unexpected alpn: " << alpn;
LOG(WARNING) << "Connection (" << T::Name << ") " << connection_id << " Unexpected alpn: " << alpn;
inlen -= 1u + in[0];
in += 1u + in[0];
}
Expand Down Expand Up @@ -591,21 +596,6 @@ class ContentServer {
VLOG(1) << "Using upstream certificate (in-memory)";
}

int ret;
std::vector<unsigned char> alpn_vec = {2, 'h', '2', 8, 'h', 't', 't', 'p', '/', '1', '.', '1'};
if (upstream_https_fallback_) {
alpn_vec = {8, 'h', 't', 't', 'p', '/', '1', '.', '1'};
}
ret = SSL_CTX_set_alpn_protos(ctx, alpn_vec.data(), alpn_vec.size());
static_cast<void>(ret);
DCHECK_EQ(ret, 0);
if (ret) {
print_openssl_error();
ec = asio::error::access_denied;
return;
}
VLOG(1) << "Alpn support (client) enabled";

client_instance_ = this;
ssl_socket_data_index_ = SSL_get_ex_new_index(0, nullptr, nullptr, nullptr, nullptr);

Expand Down
40 changes: 40 additions & 0 deletions src/net/protocol.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
// SPDX-License-Identifier: GPL-2.0
/* Copyright (c) 2024 Chilledheart */

#include "net/protocol.hpp"

#include <string_view>

using std::string_view_literals::operator""sv;

namespace net {

NextProto NextProtoFromString(std::string_view proto_string) {
if (proto_string == "http/1.1"sv) {
return kProtoHTTP11;
}
if (proto_string == "h2"sv) {
return kProtoHTTP2;
}
if (proto_string == "quic"sv || proto_string == "hq"sv) {
return kProtoQUIC;
}

return kProtoUnknown;
}

std::string_view NextProtoToString(NextProto next_proto) {
switch (next_proto) {
case kProtoHTTP11:
return "http/1.1"sv;
case kProtoHTTP2:
return "h2"sv;
case kProtoQUIC:
return "quic"sv;
case kProtoUnknown:
break;
}
return "unknown";
}

} // namespace net
12 changes: 12 additions & 0 deletions src/net/protocol.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#endif

#include <functional>
#include <string_view>
#include <utility>

#include <build/build_config.h>
Expand All @@ -22,6 +23,17 @@

namespace net {

// This enum is used in Net.SSLNegotiatedAlpnProtocol histogram.
// Do not change or re-use values.
enum NextProto { kProtoUnknown = 0, kProtoHTTP11 = 1, kProtoHTTP2 = 2, kProtoQUIC = 3, kProtoLast = kProtoQUIC };

// List of protocols to use for ALPN, used for configuring HttpNetworkSessions.
typedef std::vector<NextProto> NextProtoVector;

NextProto NextProtoFromString(std::string_view proto_string);

std::string_view NextProtoToString(NextProto next_proto);

#ifndef NDEBUG
inline void DumpHex_Impl(const char* file, int line, const char* prefix, const uint8_t* data, uint32_t length) {
if (!VLOG_IS_ON(4)) {
Expand Down
16 changes: 7 additions & 9 deletions src/net/ssl_server_socket.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,13 @@ int SSLServerSocket::DoHandshake(int* openssl_result) {
int rv = SSL_do_handshake(ssl_.get());
*openssl_result = SSL_ERROR_NONE;
if (rv == 1) {
const uint8_t* alpn_proto = nullptr;
unsigned alpn_len = 0;
SSL_get0_alpn_selected(ssl_.get(), &alpn_proto, &alpn_len);
if (alpn_len > 0) {
std::string_view proto(reinterpret_cast<const char*>(alpn_proto), alpn_len);
negotiated_protocol_ = NextProtoFromString(proto);
}
#if 0
const STACK_OF(CRYPTO_BUFFER)* certs =
SSL_get0_peer_certificates(ssl_.get());
Expand All @@ -310,15 +317,6 @@ int SSLServerSocket::DoHandshake(int* openssl_result) {
return ERR_SSL_CLIENT_AUTH_CERT_BAD_FORMAT;
}

const uint8_t* alpn_proto = nullptr;
unsigned alpn_len = 0;
SSL_get0_alpn_selected(ssl_.get(), &alpn_proto, &alpn_len);
if (alpn_len > 0) {
base::StringPiece proto(reinterpret_cast<const char*>(alpn_proto),
alpn_len);
negotiated_protocol_ = NextProtoFromString(proto);
}

if (context_->ssl_server_config_.alert_after_handshake_for_testing) {
SSL_send_fatal_alert(ssl_.get(),
context_->ssl_server_config_
Expand Down
5 changes: 5 additions & 0 deletions src/net/ssl_server_socket.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "net/iobuf.hpp"
#include "net/net_errors.hpp"
#include "net/openssl_util.hpp"
#include "net/protocol.hpp"

namespace net {

Expand Down Expand Up @@ -48,6 +49,8 @@ class SSLServerSocket : public RefCountedThreadSafe<SSLServerSocket> {
void WaitRead(WaitCallback&& cb);
void WaitWrite(WaitCallback&& cb);

NextProto negotiated_protocol() const { return negotiated_protocol_; }

protected:
void OnWaitRead(asio::error_code ec);
void OnWaitWrite(asio::error_code ec);
Expand Down Expand Up @@ -88,6 +91,8 @@ class SSLServerSocket : public RefCountedThreadSafe<SSLServerSocket> {
// Whether we received any data in early data.
bool early_data_received_ = false;

NextProto negotiated_protocol_ = kProtoUnknown;

enum State {
STATE_NONE,
STATE_HANDSHAKE,
Expand Down
47 changes: 39 additions & 8 deletions src/net/ssl_socket.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,27 @@ const int kDefaultOpenSSLBufferSize = 17 * 1024;
static constexpr const int kMaximumSSLCache = 1024;
static absl::flat_hash_map<asio::ip::address, bssl::UniquePtr<SSL_SESSION>> g_ssl_lru_cache;

static std::vector<uint8_t> SerializeNextProtos(const NextProtoVector& next_protos) {
std::vector<uint8_t> wire_protos;
for (const NextProto next_proto : next_protos) {
const std::string_view proto = NextProtoToString(next_proto);
if (proto.size() > 255) {
LOG(WARNING) << "Ignoring overlong ALPN protocol: " << proto;
continue;
}
if (proto.size() == 0) {
LOG(WARNING) << "Ignoring empty ALPN protocol";
continue;
}
wire_protos.push_back(proto.size());
for (const char ch : proto) {
wire_protos.push_back(static_cast<uint8_t>(ch));
}
}

return wire_protos;
}

SSLSocket::SSLSocket(int ssl_socket_data_index,
asio::io_context* io_context,
asio::ip::tcp::socket* socket,
Expand Down Expand Up @@ -118,13 +139,22 @@ SSLSocket::SSLSocket(int ssl_socket_data_index,
LOG(FATAL) << "SSL_set_verify_algorithm_prefs failed";
}

// ALPS TLS extension is enabled and corresponding data is sent to client if
// client also enabled ALPS, for each NextProto in |application_settings|.
// Data might be empty.
const std::string_view proto_string = https_fallback ? "http/1.1"sv : "h2"sv;
std::vector<uint8_t> data;
SSL_add_application_settings(ssl_.get(), reinterpret_cast<const uint8_t*>(proto_string.data()), proto_string.size(),
data.data(), data.size());
NextProtoVector alpn_protos = {kProtoHTTP2, kProtoHTTP11};
if (https_fallback) {
alpn_protos = {kProtoHTTP11};
}
std::vector<uint8_t> wire_protos = SerializeNextProtos(alpn_protos);
SSL_set_alpn_protos(ssl_.get(), wire_protos.data(), wire_protos.size());

// Enable ALPS for HTTP/2 with empty data.
if (!https_fallback) {
std::string_view proto_string = NextProtoToString(kProtoHTTP2);
std::vector<uint8_t> data;
if (!SSL_add_application_settings(ssl_.get(), reinterpret_cast<const uint8_t*>(proto_string.data()),
proto_string.size(), data.data(), data.size())) {
LOG(FATAL) << "SSL_add_application_settings failed";
};
}

SSL_enable_signed_cert_timestamps(ssl_.get());
SSL_enable_ocsp_stapling(ssl_.get());
Expand Down Expand Up @@ -512,7 +542,8 @@ int SSLSocket::DoHandshakeComplete(int result) {
unsigned alpn_len = 0;
SSL_get0_alpn_selected(ssl_.get(), &alpn_proto, &alpn_len);
if (alpn_len > 0) {
negotiated_protocol_ = std::string(reinterpret_cast<const char*>(alpn_proto), alpn_len);
std::string_view proto(reinterpret_cast<const char*>(alpn_proto), alpn_len);
negotiated_protocol_ = NextProtoFromString(proto);
}

const uint8_t* ocsp_response_raw;
Expand Down
8 changes: 4 additions & 4 deletions src/net/ssl_socket.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "net/iobuf.hpp"
#include "net/net_errors.hpp"
#include "net/openssl_util.hpp"
#include "net/protocol.hpp"

namespace net {

Expand Down Expand Up @@ -79,7 +80,7 @@ class SSLSocket : public RefCountedThreadSafe<SSLSocket> {
void WaitRead(WaitCallback&& cb);
void WaitWrite(WaitCallback&& cb);

const std::string& negotiated_protocol() const { return negotiated_protocol_; }
NextProto negotiated_protocol() const { return negotiated_protocol_; }

int NewSessionCallback(SSL_SESSION* session);

Expand Down Expand Up @@ -170,14 +171,13 @@ class SSLSocket : public RefCountedThreadSafe<SSLSocket> {
// ERR_SSL_CLIENT_AUTH_CERT_NEEDED.
bool send_client_cert_;

std::string negotiated_protocol_;
NextProto negotiated_protocol_ = kProtoUnknown;

bool IsRenegotiationAllowed() const {
using std::string_literals::operator""s;
// Prior to HTTP/2 and SPDY, some servers use TLS renegotiation to request
// TLS client authentication after the HTTP request was sent. Allow
// renegotiation for only those connections.
if (negotiated_protocol_ == "http/1.1"s) {
if (negotiated_protocol_ == kProtoHTTP11) {
return true;
}
// True if renegotiation should be allowed for the default application-level
Expand Down
9 changes: 3 additions & 6 deletions src/net/ssl_stream.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,12 +84,9 @@ class ssl_stream : public stream {
return;
}

using std::string_view_literals::operator""sv;
std::string_view alpn = ssl_socket_->negotiated_protocol();
if (!alpn.empty()) {
VLOG(2) << "Alpn selected (client): " << alpn;
}
https_fallback_ |= alpn == "http/1.1"sv;
auto alpn = ssl_socket_->negotiated_protocol();
VLOG(2) << "Alpn selected (client): " << NextProtoToString(alpn);
https_fallback_ |= alpn == kProtoHTTP11;
if (https_fallback_) {
VLOG(2) << "Alpn fallback to https protocol (client)";
}
Expand Down