From f877b02241c5969e177ff6c783bb1e9242294ade Mon Sep 17 00:00:00 2001 From: Keeyou Date: Fri, 18 Oct 2024 11:26:28 +0800 Subject: [PATCH] tls: Enable ALPS for HTTP/2 --- CMakeLists.txt | 1 + src/net/connection.hpp | 14 +++++++++++ src/net/content_server.hpp | 28 +++++++-------------- src/net/protocol.cpp | 40 +++++++++++++++++++++++++++++ src/net/protocol.hpp | 12 +++++++++ src/net/ssl_server_socket.cpp | 16 ++++++------ src/net/ssl_server_socket.hpp | 5 ++++ src/net/ssl_socket.cpp | 47 +++++++++++++++++++++++++++++------ src/net/ssl_socket.hpp | 8 +++--- src/net/ssl_stream.hpp | 9 +++---- 10 files changed, 134 insertions(+), 46 deletions(-) create mode 100644 src/net/protocol.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 430cd0f6a..fe48202c2 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -4150,6 +4150,7 @@ set(files src/net/http_parser.cpp src/net/padding.cpp src/net/resolver.cpp + src/net/protocol.cpp src/crypto/aead_base_decrypter.cpp src/crypto/aead_base_encrypter.cpp src/crypto/aead_evp_decrypter.cpp diff --git a/src/net/connection.hpp b/src/net/connection.hpp index 63831578e..5f2f8342e 100644 --- a/src/net/connection.hpp +++ b/src/net/connection.hpp @@ -92,6 +92,20 @@ class SSLDownlink : public Downlink { auto callback = std::move(handshake_callback_); DCHECK(!handshake_callback_); asio::error_code ec = result == OK ? asio::error_code() : asio::error::connection_refused; + if (!ec) { + auto alpn = ssl_socket_->negotiated_protocol(); + switch (alpn) { + case kProtoHTTP2: + DCHECK(!https_fallback_) << " unexpected alpn: " << NextProtoToString(alpn); + break; + case kProtoHTTP11: + DCHECK(https_fallback_) << " unexpected alpn: " << NextProtoToString(alpn); + break; + default: + LOG(WARNING) << "Alpn unexpected: " << NextProtoToString(alpn); + } + VLOG(2) << "Alpn selected (server): " << NextProtoToString(alpn); + } if (callback) { callback(ec); } diff --git a/src/net/content_server.hpp b/src/net/content_server.hpp index e489730ce..a58309476 100644 --- a/src/net/content_server.hpp +++ b/src/net/content_server.hpp @@ -23,6 +23,7 @@ #include "net/asio.hpp" #include "net/connection.hpp" #include "net/network.hpp" +#include "net/protocol.hpp" #include "net/ssl_socket.hpp" #include "net/x509_util.hpp" @@ -459,23 +460,27 @@ class ContentServer { if (in[0] + 1u > inlen) { goto err; } - using std::string_view_literals::operator""sv; auto alpn = std::string_view(reinterpret_cast(in + 1), in[0]); - if (!server->https_fallback_ && alpn == "h2"sv) { + if (!server->https_fallback_ && NextProtoFromString(alpn) == kProtoHTTP2) { VLOG(1) << "Connection (" << T::Name << ") " << connection_id << " Alpn support (server) chosen: " << alpn; server->set_https_fallback(connection_id, false); *out = in + 1; *outlen = in[0]; + + // Enable ALPS for HTTP/2 with empty data. + std::vector data; + SSL_add_application_settings(ssl, reinterpret_cast(alpn.data()), alpn.size(), data.data(), + data.size()); return SSL_TLSEXT_ERR_OK; } - if (alpn == "http/1.1"sv) { + if (NextProtoFromString(alpn) == kProtoHTTP11) { VLOG(1) << "Connection (" << T::Name << ") " << connection_id << " Alpn support (server) chosen: " << alpn; server->set_https_fallback(connection_id, true); *out = in + 1; *outlen = in[0]; return SSL_TLSEXT_ERR_OK; } - LOG(WARNING) << "Unexpected alpn: " << alpn; + LOG(WARNING) << "Connection (" << T::Name << ") " << connection_id << " Unexpected alpn: " << alpn; inlen -= 1u + in[0]; in += 1u + in[0]; } @@ -591,21 +596,6 @@ class ContentServer { VLOG(1) << "Using upstream certificate (in-memory)"; } - int ret; - std::vector alpn_vec = {2, 'h', '2', 8, 'h', 't', 't', 'p', '/', '1', '.', '1'}; - if (upstream_https_fallback_) { - alpn_vec = {8, 'h', 't', 't', 'p', '/', '1', '.', '1'}; - } - ret = SSL_CTX_set_alpn_protos(ctx, alpn_vec.data(), alpn_vec.size()); - static_cast(ret); - DCHECK_EQ(ret, 0); - if (ret) { - print_openssl_error(); - ec = asio::error::access_denied; - return; - } - VLOG(1) << "Alpn support (client) enabled"; - client_instance_ = this; ssl_socket_data_index_ = SSL_get_ex_new_index(0, nullptr, nullptr, nullptr, nullptr); diff --git a/src/net/protocol.cpp b/src/net/protocol.cpp new file mode 100644 index 000000000..54f9ce623 --- /dev/null +++ b/src/net/protocol.cpp @@ -0,0 +1,40 @@ +// SPDX-License-Identifier: GPL-2.0 +/* Copyright (c) 2024 Chilledheart */ + +#include "net/protocol.hpp" + +#include + +using std::string_view_literals::operator""sv; + +namespace net { + +NextProto NextProtoFromString(std::string_view proto_string) { + if (proto_string == "http/1.1"sv) { + return kProtoHTTP11; + } + if (proto_string == "h2"sv) { + return kProtoHTTP2; + } + if (proto_string == "quic"sv || proto_string == "hq"sv) { + return kProtoQUIC; + } + + return kProtoUnknown; +} + +std::string_view NextProtoToString(NextProto next_proto) { + switch (next_proto) { + case kProtoHTTP11: + return "http/1.1"sv; + case kProtoHTTP2: + return "h2"sv; + case kProtoQUIC: + return "quic"sv; + case kProtoUnknown: + break; + } + return "unknown"; +} + +} // namespace net diff --git a/src/net/protocol.hpp b/src/net/protocol.hpp index 796239935..1ee6f0777 100644 --- a/src/net/protocol.hpp +++ b/src/net/protocol.hpp @@ -9,6 +9,7 @@ #endif #include +#include #include #include @@ -22,6 +23,17 @@ namespace net { +// This enum is used in Net.SSLNegotiatedAlpnProtocol histogram. +// Do not change or re-use values. +enum NextProto { kProtoUnknown = 0, kProtoHTTP11 = 1, kProtoHTTP2 = 2, kProtoQUIC = 3, kProtoLast = kProtoQUIC }; + +// List of protocols to use for ALPN, used for configuring HttpNetworkSessions. +typedef std::vector NextProtoVector; + +NextProto NextProtoFromString(std::string_view proto_string); + +std::string_view NextProtoToString(NextProto next_proto); + #ifndef NDEBUG inline void DumpHex_Impl(const char* file, int line, const char* prefix, const uint8_t* data, uint32_t length) { if (!VLOG_IS_ON(4)) { diff --git a/src/net/ssl_server_socket.cpp b/src/net/ssl_server_socket.cpp index c41bb4447..134e18414 100644 --- a/src/net/ssl_server_socket.cpp +++ b/src/net/ssl_server_socket.cpp @@ -301,6 +301,13 @@ int SSLServerSocket::DoHandshake(int* openssl_result) { int rv = SSL_do_handshake(ssl_.get()); *openssl_result = SSL_ERROR_NONE; if (rv == 1) { + const uint8_t* alpn_proto = nullptr; + unsigned alpn_len = 0; + SSL_get0_alpn_selected(ssl_.get(), &alpn_proto, &alpn_len); + if (alpn_len > 0) { + std::string_view proto(reinterpret_cast(alpn_proto), alpn_len); + negotiated_protocol_ = NextProtoFromString(proto); + } #if 0 const STACK_OF(CRYPTO_BUFFER)* certs = SSL_get0_peer_certificates(ssl_.get()); @@ -310,15 +317,6 @@ int SSLServerSocket::DoHandshake(int* openssl_result) { return ERR_SSL_CLIENT_AUTH_CERT_BAD_FORMAT; } - const uint8_t* alpn_proto = nullptr; - unsigned alpn_len = 0; - SSL_get0_alpn_selected(ssl_.get(), &alpn_proto, &alpn_len); - if (alpn_len > 0) { - base::StringPiece proto(reinterpret_cast(alpn_proto), - alpn_len); - negotiated_protocol_ = NextProtoFromString(proto); - } - if (context_->ssl_server_config_.alert_after_handshake_for_testing) { SSL_send_fatal_alert(ssl_.get(), context_->ssl_server_config_ diff --git a/src/net/ssl_server_socket.hpp b/src/net/ssl_server_socket.hpp index 5bd5edfba..5b5e99e99 100644 --- a/src/net/ssl_server_socket.hpp +++ b/src/net/ssl_server_socket.hpp @@ -13,6 +13,7 @@ #include "net/iobuf.hpp" #include "net/net_errors.hpp" #include "net/openssl_util.hpp" +#include "net/protocol.hpp" namespace net { @@ -48,6 +49,8 @@ class SSLServerSocket : public RefCountedThreadSafe { void WaitRead(WaitCallback&& cb); void WaitWrite(WaitCallback&& cb); + NextProto negotiated_protocol() const { return negotiated_protocol_; } + protected: void OnWaitRead(asio::error_code ec); void OnWaitWrite(asio::error_code ec); @@ -88,6 +91,8 @@ class SSLServerSocket : public RefCountedThreadSafe { // Whether we received any data in early data. bool early_data_received_ = false; + NextProto negotiated_protocol_ = kProtoUnknown; + enum State { STATE_NONE, STATE_HANDSHAKE, diff --git a/src/net/ssl_socket.cpp b/src/net/ssl_socket.cpp index dc7f43ddd..11d3b5b20 100644 --- a/src/net/ssl_socket.cpp +++ b/src/net/ssl_socket.cpp @@ -24,6 +24,27 @@ const int kDefaultOpenSSLBufferSize = 17 * 1024; static constexpr const int kMaximumSSLCache = 1024; static absl::flat_hash_map> g_ssl_lru_cache; +static std::vector SerializeNextProtos(const NextProtoVector& next_protos) { + std::vector wire_protos; + for (const NextProto next_proto : next_protos) { + const std::string_view proto = NextProtoToString(next_proto); + if (proto.size() > 255) { + LOG(WARNING) << "Ignoring overlong ALPN protocol: " << proto; + continue; + } + if (proto.size() == 0) { + LOG(WARNING) << "Ignoring empty ALPN protocol"; + continue; + } + wire_protos.push_back(proto.size()); + for (const char ch : proto) { + wire_protos.push_back(static_cast(ch)); + } + } + + return wire_protos; +} + SSLSocket::SSLSocket(int ssl_socket_data_index, asio::io_context* io_context, asio::ip::tcp::socket* socket, @@ -118,13 +139,22 @@ SSLSocket::SSLSocket(int ssl_socket_data_index, LOG(FATAL) << "SSL_set_verify_algorithm_prefs failed"; } - // ALPS TLS extension is enabled and corresponding data is sent to client if - // client also enabled ALPS, for each NextProto in |application_settings|. - // Data might be empty. - const std::string_view proto_string = https_fallback ? "http/1.1"sv : "h2"sv; - std::vector data; - SSL_add_application_settings(ssl_.get(), reinterpret_cast(proto_string.data()), proto_string.size(), - data.data(), data.size()); + NextProtoVector alpn_protos = {kProtoHTTP2, kProtoHTTP11}; + if (https_fallback) { + alpn_protos = {kProtoHTTP11}; + } + std::vector wire_protos = SerializeNextProtos(alpn_protos); + SSL_set_alpn_protos(ssl_.get(), wire_protos.data(), wire_protos.size()); + + // Enable ALPS for HTTP/2 with empty data. + if (!https_fallback) { + std::string_view proto_string = NextProtoToString(kProtoHTTP2); + std::vector data; + if (!SSL_add_application_settings(ssl_.get(), reinterpret_cast(proto_string.data()), + proto_string.size(), data.data(), data.size())) { + LOG(FATAL) << "SSL_add_application_settings failed"; + }; + } SSL_enable_signed_cert_timestamps(ssl_.get()); SSL_enable_ocsp_stapling(ssl_.get()); @@ -512,7 +542,8 @@ int SSLSocket::DoHandshakeComplete(int result) { unsigned alpn_len = 0; SSL_get0_alpn_selected(ssl_.get(), &alpn_proto, &alpn_len); if (alpn_len > 0) { - negotiated_protocol_ = std::string(reinterpret_cast(alpn_proto), alpn_len); + std::string_view proto(reinterpret_cast(alpn_proto), alpn_len); + negotiated_protocol_ = NextProtoFromString(proto); } const uint8_t* ocsp_response_raw; diff --git a/src/net/ssl_socket.hpp b/src/net/ssl_socket.hpp index 612d21548..9d33f1684 100644 --- a/src/net/ssl_socket.hpp +++ b/src/net/ssl_socket.hpp @@ -14,6 +14,7 @@ #include "net/iobuf.hpp" #include "net/net_errors.hpp" #include "net/openssl_util.hpp" +#include "net/protocol.hpp" namespace net { @@ -79,7 +80,7 @@ class SSLSocket : public RefCountedThreadSafe { void WaitRead(WaitCallback&& cb); void WaitWrite(WaitCallback&& cb); - const std::string& negotiated_protocol() const { return negotiated_protocol_; } + NextProto negotiated_protocol() const { return negotiated_protocol_; } int NewSessionCallback(SSL_SESSION* session); @@ -170,14 +171,13 @@ class SSLSocket : public RefCountedThreadSafe { // ERR_SSL_CLIENT_AUTH_CERT_NEEDED. bool send_client_cert_; - std::string negotiated_protocol_; + NextProto negotiated_protocol_ = kProtoUnknown; bool IsRenegotiationAllowed() const { - using std::string_literals::operator""s; // Prior to HTTP/2 and SPDY, some servers use TLS renegotiation to request // TLS client authentication after the HTTP request was sent. Allow // renegotiation for only those connections. - if (negotiated_protocol_ == "http/1.1"s) { + if (negotiated_protocol_ == kProtoHTTP11) { return true; } // True if renegotiation should be allowed for the default application-level diff --git a/src/net/ssl_stream.hpp b/src/net/ssl_stream.hpp index 881fbce8b..0c7577948 100644 --- a/src/net/ssl_stream.hpp +++ b/src/net/ssl_stream.hpp @@ -84,12 +84,9 @@ class ssl_stream : public stream { return; } - using std::string_view_literals::operator""sv; - std::string_view alpn = ssl_socket_->negotiated_protocol(); - if (!alpn.empty()) { - VLOG(2) << "Alpn selected (client): " << alpn; - } - https_fallback_ |= alpn == "http/1.1"sv; + auto alpn = ssl_socket_->negotiated_protocol(); + VLOG(2) << "Alpn selected (client): " << NextProtoToString(alpn); + https_fallback_ |= alpn == kProtoHTTP11; if (https_fallback_) { VLOG(2) << "Alpn fallback to https protocol (client)"; }