Skip to content

Commit

Permalink
Don't allow the instance replication of itself and it's own replicas (a…
Browse files Browse the repository at this point in the history
…pache#1488)

Co-authored-by: git-hulk <[email protected]>
Co-authored-by: Twice <[email protected]>
  • Loading branch information
3 people authored and jihuayu committed Jun 16, 2023
1 parent 0e29fb6 commit 5b1e322
Show file tree
Hide file tree
Showing 6 changed files with 96 additions and 2 deletions.
25 changes: 24 additions & 1 deletion src/commands/cmd_server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

#include "commander.h"
#include "commands/scan_base.h"
#include "common/io_util.h"
#include "config/config.h"
#include "error_constants.h"
#include "server/redis_connection.h"
Expand Down Expand Up @@ -867,6 +868,24 @@ class CommandFlushBackup : public Commander {

class CommandSlaveOf : public Commander {
public:
static Status IsTryingToReplicateItself(Server *svr, const std::string &host, uint32_t port) {
auto ip_addresses = util::LookupHostByName(host);
if (!ip_addresses) {
return {Status::NotOK, "Can not resolve hostname: " + host};
}
for (auto &ip : *ip_addresses) {
if (util::MatchListeningIP(svr->GetConfig()->binds, ip) && port == svr->GetConfig()->port) {
return {Status::NotOK, "can't replicate itself"};
}
for (std::pair<std::string, uint32_t> &host_port_pair : svr->GetSlaveHostAndPort()) {
if (host_port_pair.first == ip && host_port_pair.second == port) {
return {Status::NotOK, "can't replicate your own replicas"};
}
}
}
return Status::OK();
}

Status Parse(const std::vector<std::string> &args) override {
host_ = args[1];
const auto &port = args[2];
Expand Down Expand Up @@ -914,7 +933,11 @@ class CommandSlaveOf : public Commander {
return Status::OK();
}

auto s = svr->AddMaster(host_, port_, false);
auto s = IsTryingToReplicateItself(svr, host_, port_);
if (!s.IsOK()) {
return {Status::RedisExecErr, s.Msg()};
}
s = svr->AddMaster(host_, port_, false);
if (s.IsOK()) {
*output = redis::SimpleString("OK");
LOG(WARNING) << "SLAVE OF " << host_ << ":" << port_ << " enabled (user request from '" << conn->GetAddr()
Expand Down
27 changes: 27 additions & 0 deletions src/common/io_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,33 @@ Status SockSetTcpKeepalive(int fd, int interval) {
return Status::OK();
}

// Lookup IP addresses by hostname
StatusOr<std::vector<std::string>> LookupHostByName(const std::string &host) {
addrinfo hints = {}, *servinfo = nullptr;

hints.ai_family = AF_UNSPEC;
hints.ai_socktype = SOCK_STREAM;

if (int rv = getaddrinfo(host.c_str(), nullptr, &hints, &servinfo); rv != 0) {
return {Status::NotOK, gai_strerror(rv)};
}

auto exit = MakeScopeExit([servinfo] { freeaddrinfo(servinfo); });

std::vector<std::string> ips;
for (auto p = servinfo; p != nullptr; p = p->ai_next) {
char ip[INET6_ADDRSTRLEN] = {};
if (p->ai_family == AF_INET) {
inet_ntop(p->ai_family, &((struct sockaddr_in *)p->ai_addr)->sin_addr, ip, sizeof(ip));
} else {
inet_ntop(p->ai_family, &((struct sockaddr_in6 *)p->ai_addr)->sin6_addr, ip, sizeof(ip));
}
ips.emplace_back(ip);
}

return ips;
}

StatusOr<int> SockConnect(const std::string &host, uint32_t port, int conn_timeout, int timeout) {
addrinfo hints = {}, *servinfo = nullptr;

Expand Down
1 change: 1 addition & 0 deletions src/common/io_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
namespace util {

sockaddr_in NewSockaddrInet(const std::string &host, uint32_t port);
StatusOr<std::vector<std::string>> LookupHostByName(const std::string &host);
StatusOr<int> SockConnect(const std::string &host, uint32_t port, int conn_timeout = 0, int timeout = 0);
Status SockSetTcpNoDelay(int fd, int val);
Status SockSetTcpKeepalive(int fd, int interval);
Expand Down
15 changes: 14 additions & 1 deletion src/server/server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1754,6 +1754,19 @@ void Server::ResetWatchedKeys(redis::Connection *conn) {
}
}

std::list<std::pair<std::string, uint32_t>> Server::GetSlaveHostAndPort() {
std::list<std::pair<std::string, uint32_t>> result;
slave_threads_mu_.lock();
for (const auto &slave : slave_threads_) {
if (slave->IsStopped()) continue;
std::pair<std::string, int> host_port_pair = {slave->GetConn()->GetAnnounceIP(),
slave->GetConn()->GetListeningPort()};
result.emplace_back(host_port_pair);
}
slave_threads_mu_.unlock();
return result;
}

// The numeric cursor consists of a 32-bit hash, a 16-bit time stamp, and a 16-bit counter, with the highest bit set to
// 1 to prevent a zero cursor from occurring. The hash is used to prevent users from obtaining cursors that are used by
// other users. The time_stamp is used to prevent the generation of the same cursor in the extremely short period before
Expand Down Expand Up @@ -1799,4 +1812,4 @@ std::string Server::GetKeyNameFromCursor(const std::string &cursor) {
}

return {};
}
}
1 change: 1 addition & 0 deletions src/server/server.h
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,7 @@ class Server {
void WatchKey(redis::Connection *conn, const std::vector<std::string> &keys);
static bool IsWatchedKeysModified(redis::Connection *conn);
void ResetWatchedKeys(redis::Connection *conn);
std::list<std::pair<std::string, uint32_t>> GetSlaveHostAndPort();

#ifdef ENABLE_OPENSSL
UniqueSSLContext ssl_ctx;
Expand Down
29 changes: 29 additions & 0 deletions tests/gocase/integration/replication/replication_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -422,3 +422,32 @@ func TestReplicationAnnounceIP(t *testing.T) {
require.Equal(t, "1234", slave0port)
})
}

func TestShouldNotReplicate(t *testing.T) {
master := util.StartServer(t, map[string]string{})
defer master.Close()
masterClient := master.NewClient()
defer func() { require.NoError(t, masterClient.Close()) }()

ctx := context.Background()

slave := util.StartServer(t, map[string]string{})
defer slave.Close()
slaveClient := slave.NewClient()
defer func() { require.NoError(t, slaveClient.Close()) }()

t.Run("Setting server as replica of itself should throw error", func(t *testing.T) {
err := slaveClient.SlaveOf(ctx, slave.Host(), fmt.Sprintf("%d", slave.Port())).Err()
require.Equal(t, "ERR can't replicate itself", err.Error())
require.Equal(t, "master", util.FindInfoEntry(slaveClient, "role"))
})

t.Run("Master should not be able to replicate slave", func(t *testing.T) {
util.SlaveOf(t, slaveClient, master)
util.WaitForSync(t, slaveClient)
require.Equal(t, "slave", util.FindInfoEntry(slaveClient, "role"))
err := masterClient.SlaveOf(ctx, slave.Host(), fmt.Sprintf("%d", slave.Port())).Err()
require.EqualErrorf(t, err, "ERR can't replicate your own replicas", err.Error())
require.Equal(t, "master", util.FindInfoEntry(masterClient, "role"))
})
}

0 comments on commit 5b1e322

Please sign in to comment.