Skip to content

Commit

Permalink
tls: Enable ALPS for HTTP/2
Browse files Browse the repository at this point in the history
  • Loading branch information
Chilledheart committed Oct 18, 2024
1 parent 6d6ea2e commit f877b02
Show file tree
Hide file tree
Showing 10 changed files with 134 additions and 46 deletions.
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4150,6 +4150,7 @@ set(files
src/net/http_parser.cpp
src/net/padding.cpp
src/net/resolver.cpp
src/net/protocol.cpp
src/crypto/aead_base_decrypter.cpp
src/crypto/aead_base_encrypter.cpp
src/crypto/aead_evp_decrypter.cpp
Expand Down
14 changes: 14 additions & 0 deletions src/net/connection.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,20 @@ class SSLDownlink : public Downlink {
auto callback = std::move(handshake_callback_);
DCHECK(!handshake_callback_);
asio::error_code ec = result == OK ? asio::error_code() : asio::error::connection_refused;
if (!ec) {
auto alpn = ssl_socket_->negotiated_protocol();
switch (alpn) {
case kProtoHTTP2:
DCHECK(!https_fallback_) << " unexpected alpn: " << NextProtoToString(alpn);
break;
case kProtoHTTP11:
DCHECK(https_fallback_) << " unexpected alpn: " << NextProtoToString(alpn);
break;
default:
LOG(WARNING) << "Alpn unexpected: " << NextProtoToString(alpn);
}
VLOG(2) << "Alpn selected (server): " << NextProtoToString(alpn);
}
if (callback) {
callback(ec);
}
Expand Down
28 changes: 9 additions & 19 deletions src/net/content_server.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include "net/asio.hpp"
#include "net/connection.hpp"
#include "net/network.hpp"
#include "net/protocol.hpp"
#include "net/ssl_socket.hpp"
#include "net/x509_util.hpp"

Expand Down Expand Up @@ -459,23 +460,27 @@ class ContentServer {
if (in[0] + 1u > inlen) {
goto err;
}
using std::string_view_literals::operator""sv;
auto alpn = std::string_view(reinterpret_cast<const char*>(in + 1), in[0]);
if (!server->https_fallback_ && alpn == "h2"sv) {
if (!server->https_fallback_ && NextProtoFromString(alpn) == kProtoHTTP2) {
VLOG(1) << "Connection (" << T::Name << ") " << connection_id << " Alpn support (server) chosen: " << alpn;
server->set_https_fallback(connection_id, false);
*out = in + 1;
*outlen = in[0];

// Enable ALPS for HTTP/2 with empty data.
std::vector<uint8_t> data;
SSL_add_application_settings(ssl, reinterpret_cast<const uint8_t*>(alpn.data()), alpn.size(), data.data(),
data.size());
return SSL_TLSEXT_ERR_OK;
}
if (alpn == "http/1.1"sv) {
if (NextProtoFromString(alpn) == kProtoHTTP11) {
VLOG(1) << "Connection (" << T::Name << ") " << connection_id << " Alpn support (server) chosen: " << alpn;
server->set_https_fallback(connection_id, true);
*out = in + 1;
*outlen = in[0];
return SSL_TLSEXT_ERR_OK;
}
LOG(WARNING) << "Unexpected alpn: " << alpn;
LOG(WARNING) << "Connection (" << T::Name << ") " << connection_id << " Unexpected alpn: " << alpn;
inlen -= 1u + in[0];
in += 1u + in[0];
}
Expand Down Expand Up @@ -591,21 +596,6 @@ class ContentServer {
VLOG(1) << "Using upstream certificate (in-memory)";
}

int ret;
std::vector<unsigned char> alpn_vec = {2, 'h', '2', 8, 'h', 't', 't', 'p', '/', '1', '.', '1'};
if (upstream_https_fallback_) {
alpn_vec = {8, 'h', 't', 't', 'p', '/', '1', '.', '1'};
}
ret = SSL_CTX_set_alpn_protos(ctx, alpn_vec.data(), alpn_vec.size());
static_cast<void>(ret);
DCHECK_EQ(ret, 0);
if (ret) {
print_openssl_error();
ec = asio::error::access_denied;
return;
}
VLOG(1) << "Alpn support (client) enabled";

client_instance_ = this;
ssl_socket_data_index_ = SSL_get_ex_new_index(0, nullptr, nullptr, nullptr, nullptr);

Expand Down
40 changes: 40 additions & 0 deletions src/net/protocol.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
// SPDX-License-Identifier: GPL-2.0
/* Copyright (c) 2024 Chilledheart */

#include "net/protocol.hpp"

#include <string_view>

using std::string_view_literals::operator""sv;

namespace net {

NextProto NextProtoFromString(std::string_view proto_string) {
if (proto_string == "http/1.1"sv) {
return kProtoHTTP11;
}
if (proto_string == "h2"sv) {
return kProtoHTTP2;
}
if (proto_string == "quic"sv || proto_string == "hq"sv) {
return kProtoQUIC;
}

return kProtoUnknown;
}

std::string_view NextProtoToString(NextProto next_proto) {
switch (next_proto) {
case kProtoHTTP11:
return "http/1.1"sv;
case kProtoHTTP2:
return "h2"sv;
case kProtoQUIC:
return "quic"sv;
case kProtoUnknown:
break;
}
return "unknown";
}

} // namespace net
12 changes: 12 additions & 0 deletions src/net/protocol.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#endif

#include <functional>
#include <string_view>
#include <utility>

#include <build/build_config.h>
Expand All @@ -22,6 +23,17 @@

namespace net {

// This enum is used in Net.SSLNegotiatedAlpnProtocol histogram.
// Do not change or re-use values.
enum NextProto { kProtoUnknown = 0, kProtoHTTP11 = 1, kProtoHTTP2 = 2, kProtoQUIC = 3, kProtoLast = kProtoQUIC };

// List of protocols to use for ALPN, used for configuring HttpNetworkSessions.
typedef std::vector<NextProto> NextProtoVector;

NextProto NextProtoFromString(std::string_view proto_string);

std::string_view NextProtoToString(NextProto next_proto);

#ifndef NDEBUG
inline void DumpHex_Impl(const char* file, int line, const char* prefix, const uint8_t* data, uint32_t length) {
if (!VLOG_IS_ON(4)) {
Expand Down
16 changes: 7 additions & 9 deletions src/net/ssl_server_socket.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,13 @@ int SSLServerSocket::DoHandshake(int* openssl_result) {
int rv = SSL_do_handshake(ssl_.get());
*openssl_result = SSL_ERROR_NONE;
if (rv == 1) {
const uint8_t* alpn_proto = nullptr;
unsigned alpn_len = 0;
SSL_get0_alpn_selected(ssl_.get(), &alpn_proto, &alpn_len);
if (alpn_len > 0) {
std::string_view proto(reinterpret_cast<const char*>(alpn_proto), alpn_len);
negotiated_protocol_ = NextProtoFromString(proto);
}
#if 0
const STACK_OF(CRYPTO_BUFFER)* certs =
SSL_get0_peer_certificates(ssl_.get());
Expand All @@ -310,15 +317,6 @@ int SSLServerSocket::DoHandshake(int* openssl_result) {
return ERR_SSL_CLIENT_AUTH_CERT_BAD_FORMAT;
}

const uint8_t* alpn_proto = nullptr;
unsigned alpn_len = 0;
SSL_get0_alpn_selected(ssl_.get(), &alpn_proto, &alpn_len);
if (alpn_len > 0) {
base::StringPiece proto(reinterpret_cast<const char*>(alpn_proto),
alpn_len);
negotiated_protocol_ = NextProtoFromString(proto);
}

if (context_->ssl_server_config_.alert_after_handshake_for_testing) {
SSL_send_fatal_alert(ssl_.get(),
context_->ssl_server_config_
Expand Down
5 changes: 5 additions & 0 deletions src/net/ssl_server_socket.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "net/iobuf.hpp"
#include "net/net_errors.hpp"
#include "net/openssl_util.hpp"
#include "net/protocol.hpp"

namespace net {

Expand Down Expand Up @@ -48,6 +49,8 @@ class SSLServerSocket : public RefCountedThreadSafe<SSLServerSocket> {
void WaitRead(WaitCallback&& cb);
void WaitWrite(WaitCallback&& cb);

NextProto negotiated_protocol() const { return negotiated_protocol_; }

protected:
void OnWaitRead(asio::error_code ec);
void OnWaitWrite(asio::error_code ec);
Expand Down Expand Up @@ -88,6 +91,8 @@ class SSLServerSocket : public RefCountedThreadSafe<SSLServerSocket> {
// Whether we received any data in early data.
bool early_data_received_ = false;

NextProto negotiated_protocol_ = kProtoUnknown;

enum State {
STATE_NONE,
STATE_HANDSHAKE,
Expand Down
47 changes: 39 additions & 8 deletions src/net/ssl_socket.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,27 @@ const int kDefaultOpenSSLBufferSize = 17 * 1024;
static constexpr const int kMaximumSSLCache = 1024;
static absl::flat_hash_map<asio::ip::address, bssl::UniquePtr<SSL_SESSION>> g_ssl_lru_cache;

static std::vector<uint8_t> SerializeNextProtos(const NextProtoVector& next_protos) {
std::vector<uint8_t> wire_protos;
for (const NextProto next_proto : next_protos) {
const std::string_view proto = NextProtoToString(next_proto);
if (proto.size() > 255) {
LOG(WARNING) << "Ignoring overlong ALPN protocol: " << proto;
continue;
}
if (proto.size() == 0) {
LOG(WARNING) << "Ignoring empty ALPN protocol";
continue;
}
wire_protos.push_back(proto.size());
for (const char ch : proto) {
wire_protos.push_back(static_cast<uint8_t>(ch));
}
}

return wire_protos;
}

SSLSocket::SSLSocket(int ssl_socket_data_index,
asio::io_context* io_context,
asio::ip::tcp::socket* socket,
Expand Down Expand Up @@ -118,13 +139,22 @@ SSLSocket::SSLSocket(int ssl_socket_data_index,
LOG(FATAL) << "SSL_set_verify_algorithm_prefs failed";
}

// ALPS TLS extension is enabled and corresponding data is sent to client if
// client also enabled ALPS, for each NextProto in |application_settings|.
// Data might be empty.
const std::string_view proto_string = https_fallback ? "http/1.1"sv : "h2"sv;
std::vector<uint8_t> data;
SSL_add_application_settings(ssl_.get(), reinterpret_cast<const uint8_t*>(proto_string.data()), proto_string.size(),
data.data(), data.size());
NextProtoVector alpn_protos = {kProtoHTTP2, kProtoHTTP11};
if (https_fallback) {
alpn_protos = {kProtoHTTP11};
}
std::vector<uint8_t> wire_protos = SerializeNextProtos(alpn_protos);
SSL_set_alpn_protos(ssl_.get(), wire_protos.data(), wire_protos.size());

// Enable ALPS for HTTP/2 with empty data.
if (!https_fallback) {
std::string_view proto_string = NextProtoToString(kProtoHTTP2);
std::vector<uint8_t> data;
if (!SSL_add_application_settings(ssl_.get(), reinterpret_cast<const uint8_t*>(proto_string.data()),
proto_string.size(), data.data(), data.size())) {
LOG(FATAL) << "SSL_add_application_settings failed";
};
}

SSL_enable_signed_cert_timestamps(ssl_.get());
SSL_enable_ocsp_stapling(ssl_.get());
Expand Down Expand Up @@ -512,7 +542,8 @@ int SSLSocket::DoHandshakeComplete(int result) {
unsigned alpn_len = 0;
SSL_get0_alpn_selected(ssl_.get(), &alpn_proto, &alpn_len);
if (alpn_len > 0) {
negotiated_protocol_ = std::string(reinterpret_cast<const char*>(alpn_proto), alpn_len);
std::string_view proto(reinterpret_cast<const char*>(alpn_proto), alpn_len);
negotiated_protocol_ = NextProtoFromString(proto);
}

const uint8_t* ocsp_response_raw;
Expand Down
8 changes: 4 additions & 4 deletions src/net/ssl_socket.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "net/iobuf.hpp"
#include "net/net_errors.hpp"
#include "net/openssl_util.hpp"
#include "net/protocol.hpp"

namespace net {

Expand Down Expand Up @@ -79,7 +80,7 @@ class SSLSocket : public RefCountedThreadSafe<SSLSocket> {
void WaitRead(WaitCallback&& cb);
void WaitWrite(WaitCallback&& cb);

const std::string& negotiated_protocol() const { return negotiated_protocol_; }
NextProto negotiated_protocol() const { return negotiated_protocol_; }

int NewSessionCallback(SSL_SESSION* session);

Expand Down Expand Up @@ -170,14 +171,13 @@ class SSLSocket : public RefCountedThreadSafe<SSLSocket> {
// ERR_SSL_CLIENT_AUTH_CERT_NEEDED.
bool send_client_cert_;

std::string negotiated_protocol_;
NextProto negotiated_protocol_ = kProtoUnknown;

bool IsRenegotiationAllowed() const {
using std::string_literals::operator""s;
// Prior to HTTP/2 and SPDY, some servers use TLS renegotiation to request
// TLS client authentication after the HTTP request was sent. Allow
// renegotiation for only those connections.
if (negotiated_protocol_ == "http/1.1"s) {
if (negotiated_protocol_ == kProtoHTTP11) {
return true;
}
// True if renegotiation should be allowed for the default application-level
Expand Down
9 changes: 3 additions & 6 deletions src/net/ssl_stream.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,12 +84,9 @@ class ssl_stream : public stream {
return;
}

using std::string_view_literals::operator""sv;
std::string_view alpn = ssl_socket_->negotiated_protocol();
if (!alpn.empty()) {
VLOG(2) << "Alpn selected (client): " << alpn;
}
https_fallback_ |= alpn == "http/1.1"sv;
auto alpn = ssl_socket_->negotiated_protocol();
VLOG(2) << "Alpn selected (client): " << NextProtoToString(alpn);
https_fallback_ |= alpn == kProtoHTTP11;
if (https_fallback_) {
VLOG(2) << "Alpn fallback to https protocol (client)";
}
Expand Down

0 comments on commit f877b02

Please sign in to comment.