diff --git a/storage/ndb/include/util/SocketAuthenticator.hpp b/storage/ndb/include/util/SocketAuthenticator.hpp index 1da895619e35..be1203bc02e5 100644 --- a/storage/ndb/include/util/SocketAuthenticator.hpp +++ b/storage/ndb/include/util/SocketAuthenticator.hpp @@ -25,17 +25,30 @@ #ifndef SOCKET_AUTHENTICATOR_HPP #define SOCKET_AUTHENTICATOR_HPP -#include "portlib/ndb_socket.h" +#include "util/NdbSocket.h" class SocketAuthenticator { public: SocketAuthenticator() {} virtual ~SocketAuthenticator() {} - virtual bool client_authenticate(ndb_socket_t sockfd) = 0; - virtual bool server_authenticate(ndb_socket_t sockfd) = 0; + bool client_authenticate(ndb_socket_t); + bool server_authenticate(ndb_socket_t); + virtual bool client_authenticate(NdbSocket &) = 0; + virtual bool server_authenticate(NdbSocket &) = 0; }; +inline bool SocketAuthenticator::client_authenticate(ndb_socket_t fd) { + NdbSocket socket(fd, NdbSocket::From::Existing); + return client_authenticate(socket); +} + +inline bool SocketAuthenticator::server_authenticate(ndb_socket_t fd) { + NdbSocket socket(fd, NdbSocket::From::Existing); + return server_authenticate(socket); +} + + class SocketAuthSimple : public SocketAuthenticator { char *m_passwd; @@ -43,8 +56,9 @@ class SocketAuthSimple : public SocketAuthenticator public: SocketAuthSimple(const char *username, const char *passwd); ~SocketAuthSimple() override; - bool client_authenticate(ndb_socket_t sockfd) override; - bool server_authenticate(ndb_socket_t sockfd) override; + bool client_authenticate(NdbSocket &) override; + bool server_authenticate(NdbSocket &) override; }; + #endif // SOCKET_AUTHENTICATOR_HPP diff --git a/storage/ndb/src/common/util/SocketAuthenticator.cpp b/storage/ndb/src/common/util/SocketAuthenticator.cpp index d3444eb41dfd..7e8e1c9e8aa1 100644 --- a/storage/ndb/src/common/util/SocketAuthenticator.cpp +++ b/storage/ndb/src/common/util/SocketAuthenticator.cpp @@ -45,10 +45,10 @@ SocketAuthSimple::~SocketAuthSimple() free(m_username); } -bool SocketAuthSimple::client_authenticate(ndb_socket_t sockfd) +bool SocketAuthSimple::client_authenticate(NdbSocket & sockfd) { - SocketOutputStream s_output(sockfd); - SocketInputStream s_input(sockfd); + SecureSocketOutputStream s_output(sockfd); + SecureSocketInputStream s_input(sockfd); // Write username and password s_output.println("%s", m_username ? m_username : ""); @@ -68,10 +68,10 @@ bool SocketAuthSimple::client_authenticate(ndb_socket_t sockfd) return false; } -bool SocketAuthSimple::server_authenticate(ndb_socket_t sockfd) +bool SocketAuthSimple::server_authenticate(NdbSocket & sockfd) { - SocketOutputStream s_output(sockfd); - SocketInputStream s_input(sockfd); + SecureSocketOutputStream s_output(sockfd); + SecureSocketInputStream s_input(sockfd); char buf[256]; diff --git a/storage/ndb/src/common/util/SocketClient.cpp b/storage/ndb/src/common/util/SocketClient.cpp index 8b4d0ff83ad4..711a748f9e80 100644 --- a/storage/ndb/src/common/util/SocketClient.cpp +++ b/storage/ndb/src/common/util/SocketClient.cpp @@ -235,16 +235,16 @@ SocketClient::connect(NdbSocket & secureSocket, assert(m_last_used_port == 0); ndb_socket_get_port(m_sockfd, &m_last_used_port); + secureSocket.init_from_new(m_sockfd); + if (m_auth) { - if (!m_auth->client_authenticate(m_sockfd)) + if (!m_auth->client_authenticate(secureSocket)) { DEBUG_FPRINTF((stderr, "authenticate failed in connect\n")); - ndb_socket_close(m_sockfd); - ndb_socket_invalidate(&m_sockfd); + secureSocket.close(); + secureSocket.invalidate(); } } - secureSocket.init_from_new(m_sockfd); - ndb_socket_invalidate(&m_sockfd); }