Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat: Block instance replication of itself and it's own replicas #1488

Merged
merged 10 commits into from
Jun 12, 2023
25 changes: 24 additions & 1 deletion src/commands/cmd_server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

#include "commander.h"
#include "commands/scan_base.h"
#include "common/io_util.h"
#include "config/config.h"
#include "error_constants.h"
#include "server/redis_connection.h"
Expand Down Expand Up @@ -867,6 +868,24 @@ class CommandFlushBackup : public Commander {

class CommandSlaveOf : public Commander {
public:
static Status IsTryingToReplicateItself(Server *svr, const std::string &host, uint32_t port) {
auto ip_addresses = util::LookupHostByName(host);
if (!ip_addresses) {
return {Status::NotOK, "Can not resolve hostname: " + host};
}
for (auto &ip : *ip_addresses) {
if (util::MatchListeningIP(svr->GetConfig()->binds, ip) && port == svr->GetConfig()->port) {
return {Status::NotOK, "can't replicate itself"};
}
for (std::pair<std::string, uint32_t> &host_port_pair : svr->GetSlaveHostAndPort()) {
if (host_port_pair.first == ip && host_port_pair.second == port) {
return {Status::NotOK, "can't replicate your own replicas"};
}
}
}
return Status::OK();
}

Status Parse(const std::vector<std::string> &args) override {
host_ = args[1];
const auto &port = args[2];
Expand Down Expand Up @@ -914,7 +933,11 @@ class CommandSlaveOf : public Commander {
return Status::OK();
}

auto s = svr->AddMaster(host_, port_, false);
auto s = IsTryingToReplicateItself(svr, host_, port_);
if (!s.IsOK()) {
return {Status::RedisExecErr, s.Msg()};
}
s = svr->AddMaster(host_, port_, false);
if (s.IsOK()) {
*output = redis::SimpleString("OK");
LOG(WARNING) << "SLAVE OF " << host_ << ":" << port_ << " enabled (user request from '" << conn->GetAddr()
Expand Down
27 changes: 27 additions & 0 deletions src/common/io_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,33 @@ Status SockSetTcpKeepalive(int fd, int interval) {
return Status::OK();
}

// Lookup IP addresses by hostname
StatusOr<std::vector<std::string>> LookupHostByName(const std::string &host) {
addrinfo hints = {}, *servinfo = nullptr;

hints.ai_family = AF_UNSPEC;
hints.ai_socktype = SOCK_STREAM;

if (int rv = getaddrinfo(host.c_str(), nullptr, &hints, &servinfo); rv != 0) {
return {Status::NotOK, gai_strerror(rv)};
}

auto exit = MakeScopeExit([servinfo] { freeaddrinfo(servinfo); });

std::vector<std::string> ips;
for (auto p = servinfo; p != nullptr; p = p->ai_next) {
char ip[INET6_ADDRSTRLEN] = {};
if (p->ai_family == AF_INET) {
inet_ntop(p->ai_family, &((struct sockaddr_in *)p->ai_addr)->sin_addr, ip, sizeof(ip));
} else {
inet_ntop(p->ai_family, &((struct sockaddr_in6 *)p->ai_addr)->sin6_addr, ip, sizeof(ip));
}
ips.emplace_back(ip);
}

return ips;
}

StatusOr<int> SockConnect(const std::string &host, uint32_t port, int conn_timeout, int timeout) {
addrinfo hints = {}, *servinfo = nullptr;

Expand Down
1 change: 1 addition & 0 deletions src/common/io_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
namespace util {

sockaddr_in NewSockaddrInet(const std::string &host, uint32_t port);
StatusOr<std::vector<std::string>> LookupHostByName(const std::string &host);
StatusOr<int> SockConnect(const std::string &host, uint32_t port, int conn_timeout = 0, int timeout = 0);
Status SockSetTcpNoDelay(int fd, int val);
Status SockSetTcpKeepalive(int fd, int interval);
Expand Down
13 changes: 13 additions & 0 deletions src/server/server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1751,3 +1751,16 @@ void Server::ResetWatchedKeys(redis::Connection *conn) {
watched_key_size_ = watched_key_map_.size();
}
}

std::list<std::pair<std::string, uint32_t>> Server::GetSlaveHostAndPort() {
std::list<std::pair<std::string, uint32_t>> result;
slave_threads_mu_.lock();
for (const auto &slave : slave_threads_) {
if (slave->IsStopped()) continue;
std::pair<std::string, int> host_port_pair = {slave->GetConn()->GetAnnounceIP(),
slave->GetConn()->GetListeningPort()};
result.emplace_back(host_port_pair);
}
slave_threads_mu_.unlock();
return result;
}
1 change: 1 addition & 0 deletions src/server/server.h
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,7 @@ class Server {
void WatchKey(redis::Connection *conn, const std::vector<std::string> &keys);
static bool IsWatchedKeysModified(redis::Connection *conn);
void ResetWatchedKeys(redis::Connection *conn);
std::list<std::pair<std::string, uint32_t>> GetSlaveHostAndPort();

#ifdef ENABLE_OPENSSL
UniqueSSLContext ssl_ctx;
Expand Down
29 changes: 29 additions & 0 deletions tests/gocase/integration/replication/replication_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -422,3 +422,32 @@ func TestReplicationAnnounceIP(t *testing.T) {
require.Equal(t, "1234", slave0port)
})
}

func TestShouldNotReplicate(t *testing.T) {
master := util.StartServer(t, map[string]string{})
defer master.Close()
masterClient := master.NewClient()
defer func() { require.NoError(t, masterClient.Close()) }()

ctx := context.Background()

slave := util.StartServer(t, map[string]string{})
defer slave.Close()
slaveClient := slave.NewClient()
defer func() { require.NoError(t, slaveClient.Close()) }()

t.Run("Setting server as replica of itself should throw error", func(t *testing.T) {
err := slaveClient.SlaveOf(ctx, slave.Host(), fmt.Sprintf("%d", slave.Port())).Err()
require.Equal(t, "ERR can't replicate itself", err.Error())
require.Equal(t, "master", util.FindInfoEntry(slaveClient, "role"))
})

t.Run("Master should not be able to replicate slave", func(t *testing.T) {
util.SlaveOf(t, slaveClient, master)
util.WaitForSync(t, slaveClient)
require.Equal(t, "slave", util.FindInfoEntry(slaveClient, "role"))
err := masterClient.SlaveOf(ctx, slave.Host(), fmt.Sprintf("%d", slave.Port())).Err()
require.EqualErrorf(t, err, "ERR can't replicate your own replicas", err.Error())
require.Equal(t, "master", util.FindInfoEntry(masterClient, "role"))
})
}