Skip to content

Commit

Permalink
Allow users to request TLS client-side enforcement (#525)
Browse files Browse the repository at this point in the history
  • Loading branch information
FalacerSelene authored Apr 20, 2022
1 parent 3f480db commit 7cfa33d
Show file tree
Hide file tree
Showing 10 changed files with 136 additions and 1 deletion.
20 changes: 20 additions & 0 deletions include/cassandra.h
Original file line number Diff line number Diff line change
Expand Up @@ -629,6 +629,12 @@ typedef enum CassSslVerifyFlags_ {
CASS_SSL_VERIFY_PEER_IDENTITY_DNS = 0x04
} CassSslVerifyFlags;

typedef enum CassSslTlsVersion_ {
CASS_SSL_VERSION_TLS1 = 0x00,
CASS_SSL_VERSION_TLS1_1 = 0x01,
CASS_SSL_VERSION_TLS1_2 = 0x02
} CassSslTlsVersion;

typedef enum CassProtocolVersion_ {
CASS_PROTOCOL_VERSION_V1 = 0x01, /**< Deprecated */
CASS_PROTOCOL_VERSION_V2 = 0x02, /**< Deprecated */
Expand Down Expand Up @@ -4686,6 +4692,20 @@ cass_ssl_set_private_key_n(CassSsl* ssl,
const char* password,
size_t password_length);

/**
* Set minimum supported client-side protocol version. This will prevent the
* connection using protocol versions earlier than the specified one. Useful
* for preventing TLS downgrade attacks.
*
* @public @memberof CassSsl
*
* @param[in] ssl
* @param[in] min_version
* @return CASS_OK if successful, otherwise an error occurred.
*/
CASS_EXPORT CassError
cass_ssl_set_min_protocol_version(CassSsl* ssl, CassSslTlsVersion min_version);

/***********************************************************************************
*
* Authenticator
Expand Down
4 changes: 4 additions & 0 deletions src/ssl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,10 @@ CassError cass_ssl_set_private_key_n(CassSsl* ssl, const char* key, size_t key_l
return ssl->set_private_key(key, key_length, password, password_length);
}

CassError cass_ssl_set_min_protocol_version(CassSsl* ssl, CassSslTlsVersion min_version) {
return ssl->set_min_protocol_version(min_version);
}

} // extern "C"

template <class T>
Expand Down
1 change: 1 addition & 0 deletions src/ssl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ class SslContext : public RefCounted<SslContext> {
virtual CassError set_cert(const char* cert, size_t cert_length) = 0;
virtual CassError set_private_key(const char* key, size_t key_length, const char* password,
size_t password_length) = 0;
virtual CassError set_min_protocol_version(CassSslTlsVersion min_version) = 0;

protected:
int verify_flags_;
Expand Down
4 changes: 4 additions & 0 deletions src/ssl/ssl_no_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,4 +44,8 @@ CassError NoSslContext::set_private_key(const char* key, size_t key_length, cons
return CASS_ERROR_LIB_NOT_IMPLEMENTED;
}

CassError NoSslContext::set_min_protocol_version(CassSslTlsVersion min_version) {
return CASS_ERROR_LIB_NOT_IMPLEMENTED;
}

SslContext::Ptr NoSslContextFactory::create() { return SslContext::Ptr(new NoSslContext()); }
1 change: 1 addition & 0 deletions src/ssl/ssl_no_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ class NoSslContext : public SslContext {
virtual CassError set_cert(const char* cert, size_t cert_length);
virtual CassError set_private_key(const char* key, size_t key_length, const char* password,
size_t password_length);
virtual CassError set_min_protocol_version(CassSslTlsVersion min_version);
};

class NoSslContextFactory : public SslContextFactoryBase<NoSslContextFactory> {
Expand Down
44 changes: 44 additions & 0 deletions src/ssl/ssl_openssl_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,14 @@
!defined(LIBRESSL_VERSION_NUMBER) // Required as OPENSSL_VERSION_NUMBER for LibreSSL is defined
// as 2.0.0
#if (OPENSSL_VERSION_NUMBER >= 0x10100000L)
#define SSL_CAN_SET_MIN_VERSION
#define SSL_CLIENT_METHOD TLS_client_method
#else
#define SSL_CLIENT_METHOD SSLv23_client_method
#endif
#else
#if (LIBRESSL_VERSION_NUMBER >= 0x20302000L)
#define SSL_CAN_SET_MIN_VERSION
#define SSL_CLIENT_METHOD TLS_client_method
#else
#define SSL_CLIENT_METHOD SSLv23_client_method
Expand Down Expand Up @@ -615,6 +617,48 @@ CassError OpenSslContext::set_private_key(const char* key, size_t key_length, co
return CASS_OK;
}

CassError OpenSslContext::set_min_protocol_version(CassSslTlsVersion min_version) {
#ifdef SSL_CAN_SET_MIN_VERSION
int method;
switch (min_version) {
case CassSslTlsVersion::CASS_SSL_VERSION_TLS1:
method = TLS1_VERSION;
break;
case CassSslTlsVersion::CASS_SSL_VERSION_TLS1_1:
method = TLS1_1_VERSION;
break;
case CassSslTlsVersion::CASS_SSL_VERSION_TLS1_2:
method = TLS1_2_VERSION;
break;
default:
// unsupported version
return CASS_ERROR_LIB_BAD_PARAMS;
}
SSL_CTX_set_min_proto_version(ssl_ctx_, method);
return CASS_OK;
#else
// If we don't have the `set_min_proto_version` function then we do this via
// the (deprecated in later versions) options function.
int options = SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3;
switch (min_version) {
case CassSslTlsVersion::CASS_SSL_VERSION_TLS1:
break;
case CassSslTlsVersion::CASS_SSL_VERSION_TLS1_1:
options |= SSL_OP_NO_TLSv1;
break;
case CassSslTlsVersion::CASS_SSL_VERSION_TLS1_2:
options |= SSL_OP_NO_TLSv1;
options |= SSL_OP_NO_TLSv1_1;
break;
default:
// unsupported version
return CASS_ERROR_LIB_BAD_PARAMS;
}
SSL_CTX_set_options(ssl_ctx_, options);
return CASS_OK;
#endif
}

SslContext::Ptr OpenSslContextFactory::create() { return SslContext::Ptr(new OpenSslContext()); }

namespace openssl {
Expand Down
1 change: 1 addition & 0 deletions src/ssl/ssl_openssl_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ class OpenSslContext : public SslContext {
virtual CassError set_cert(const char* cert, size_t cert_length);
virtual CassError set_private_key(const char* key, size_t key_length, const char* password,
size_t password_length);
virtual CassError set_min_protocol_version(CassSslTlsVersion min_version);

private:
SSL_CTX* ssl_ctx_;
Expand Down
17 changes: 17 additions & 0 deletions tests/src/unit/mockssandra.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,14 @@ using datastax::internal::core::UuidGen;
!defined(LIBRESSL_VERSION_NUMBER) // Required as OPENSSL_VERSION_NUMBER for LibreSSL is defined
// as 2.0.0
#if (OPENSSL_VERSION_NUMBER >= 0x10100000L)
#define SSL_CAN_SET_MAX_VERSION
#define SSL_SERVER_METHOD TLS_server_method
#else
#define SSL_SERVER_METHOD SSLv23_server_method
#endif
#else
#if (LIBRESSL_VERSION_NUMBER >= 0x20302000L)
#define SSL_CAN_SET_MAX_VERSION
#define SSL_SERVER_METHOD TLS_server_method
#else
#define SSL_SERVER_METHOD SSLv23_server_method
Expand Down Expand Up @@ -555,6 +557,21 @@ bool ServerConnection::use_ssl(const String& key, const String& cert,
return true;
}

// Weaken the SSL connection, enforcing that it can only use TLS1.0 at max.
// This is used for testing client-side enforcement of more secure TLS
// protocols.
void ServerConnection::weaken_ssl() {
if (!ssl_context_) {
return;
}

#ifdef SSL_CAN_SET_MAX_VERSION
SSL_CTX_set_max_proto_version(ssl_context_, TLS1_VERSION);
#else
SSL_CTX_set_options(ssl_context_, SSL_OP_NO_TLSv1_1 | SSL_OP_NO_TLSv1_2);
#endif
}

using datastax::internal::core::Task;

class RunListen : public Task {
Expand Down
13 changes: 12 additions & 1 deletion tests/src/unit/mockssandra.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,7 @@ class ServerConnection : public RefCounted<ServerConnection> {

bool use_ssl(const String& key, const String& cert, const String& ca_cert = "",
bool require_client_cert = false);
void weaken_ssl();

void listen(EventLoopGroup* event_loop_group);
int wait_listen();
Expand Down Expand Up @@ -1161,6 +1162,7 @@ class Cluster {
~Cluster();

String use_ssl(const String& cn = "");
void weaken_ssl();

int start_all(EventLoopGroup* event_loop_group);
void start_all_async(EventLoopGroup* event_loop_group);
Expand Down Expand Up @@ -1264,7 +1266,8 @@ class SimpleEchoServer {
public:
SimpleEchoServer()
: factory_(new EchoClientConnectionFactory())
, event_loop_group_(1) {}
, event_loop_group_(1)
, ssl_weaken_(false) {}

~SimpleEchoServer() { close(); }

Expand All @@ -1281,6 +1284,8 @@ class SimpleEchoServer {
return ssl_cert_;
}

void weaken_ssl() { ssl_weaken_ = true; }

void use_connection_factory(internal::ClientConnectionFactory* factory) {
factory_.reset(factory);
}
Expand All @@ -1290,6 +1295,11 @@ class SimpleEchoServer {
if (!ssl_key_.empty() && !ssl_cert_.empty() && !server_->use_ssl(ssl_key_, ssl_cert_)) {
return -1;
}

if (ssl_weaken_) {
server_->weaken_ssl();
}

server_->listen(&event_loop_group_);
return server_->wait_listen();
}
Expand All @@ -1316,6 +1326,7 @@ class SimpleEchoServer {
internal::ServerConnection::Ptr server_;
String ssl_key_;
String ssl_cert_;
bool ssl_weaken_;
};

} // namespace mockssandra
Expand Down
32 changes: 32 additions & 0 deletions tests/src/unit/tests/test_socket.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,8 @@ class SocketUnitTest : public LoopTest {
return settings;
}

void weaken_ssl() { server_.weaken_ssl(); }

void listen(const Address& address = Address("127.0.0.1", 8888)) {
ASSERT_EQ(server_.listen(address), 0);
}
Expand Down Expand Up @@ -185,6 +187,17 @@ class SocketUnitTest : public LoopTest {
}
}

/* SSL handshake failures have different error codes on different versions of
* OpenSSL - this accounts for both of them
*/
static void on_socket_ssl_error(SocketConnector* connector, bool* is_error) {
SocketConnector::SocketError err = connector->error_code();
if ((err == SocketConnector::SOCKET_ERROR_CLOSE) ||
(err == SocketConnector::SOCKET_ERROR_SSL_HANDSHAKE)) {
*is_error = true;
}
}

static void on_socket_canceled(SocketConnector* connector, bool* is_canceled) {
if (connector->is_canceled()) {
*is_canceled = true;
Expand Down Expand Up @@ -409,3 +422,22 @@ TEST_F(SocketUnitTest, SslVerifyIdentityDns) {

EXPECT_EQ(result, "The socket is successfully connected and wrote data - Closed");
}

TEST_F(SocketUnitTest, SslEnforceTlsVersion) {
SocketSettings settings(use_ssl("127.0.0.1"));
weaken_ssl();

listen();

settings.ssl_context->set_min_protocol_version(CASS_SSL_VERSION_TLS1_2);

bool is_error;
SocketConnector::Ptr connector(new SocketConnector(
Address("127.0.0.1", 8888), bind_callback(on_socket_ssl_error, &is_error)));

connector->with_settings(settings)->connect(loop());

uv_run(loop(), UV_RUN_DEFAULT);

EXPECT_TRUE(is_error);
}

0 comments on commit 7cfa33d

Please sign in to comment.