Skip to content

Commit

Permalink
Fix hot restart bug due to passed an fd without bind (envoyproxy#8426)
Browse files Browse the repository at this point in the history
Signed-off-by: tianqian.zyf <[email protected]>
  • Loading branch information
zyfjeff authored and nandu-vinodan committed Oct 17, 2019
1 parent 84ac8cf commit c7c6bb8
Show file tree
Hide file tree
Showing 5 changed files with 76 additions and 9 deletions.
2 changes: 1 addition & 1 deletion source/server/hot_restarting_parent.cc
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ HotRestartingParent::Internal::getListenSocketsForChild(const HotRestartMessage:
Network::Address::InstanceConstSharedPtr addr =
Network::Utility::resolveUrl(request.pass_listen_socket().address());
for (const auto& listener : server_->listenerManager().listeners()) {
if (*listener.get().socket().localAddress() == *addr) {
if (*listener.get().socket().localAddress() == *addr && listener.get().bindToPort()) {
wrapped_reply.mutable_reply()->mutable_pass_listen_socket()->set_fd(
listener.get().socket().ioHandle().fd());
break;
Expand Down
20 changes: 12 additions & 8 deletions source/server/listener_manager_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -166,16 +166,20 @@ Network::SocketSharedPtr ProdListenerComponentFactory::createListenSocket(
? Network::Utility::TCP_SCHEME
: Network::Utility::UDP_SCHEME;
const std::string addr = absl::StrCat(scheme, address->asString());
const int fd = server_.hotRestart().duplicateParentListenSocket(addr);
if (fd != -1) {
ENVOY_LOG(debug, "obtained socket for address {} from parent", addr);
Network::IoHandlePtr io_handle = std::make_unique<Network::IoSocketHandleImpl>(fd);
if (socket_type == Network::Address::SocketType::Stream) {
return std::make_shared<Network::TcpListenSocket>(std::move(io_handle), address, options);
} else {
return std::make_shared<Network::UdpListenSocket>(std::move(io_handle), address, options);

if (bind_to_port) {
const int fd = server_.hotRestart().duplicateParentListenSocket(addr);
if (fd != -1) {
ENVOY_LOG(debug, "obtained socket for address {} from parent", addr);
Network::IoHandlePtr io_handle = std::make_unique<Network::IoSocketHandleImpl>(fd);
if (socket_type == Network::Address::SocketType::Stream) {
return std::make_shared<Network::TcpListenSocket>(std::move(io_handle), address, options);
} else {
return std::make_shared<Network::UdpListenSocket>(std::move(io_handle), address, options);
}
}
}

if (socket_type == Network::Address::SocketType::Stream) {
return std::make_shared<Network::TcpListenSocket>(address, options, bind_to_port);
} else {
Expand Down
2 changes: 2 additions & 0 deletions test/server/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ envoy_cc_test(
deps = [
"//source/common/stats:stats_lib",
"//source/server:hot_restart_lib",
"//test/mocks/network:network_mocks",
"//test/mocks/server:server_mocks",
],
)
Expand Down Expand Up @@ -201,6 +202,7 @@ envoy_cc_test(
"//source/extensions/transport_sockets/tls:config",
"//source/extensions/transport_sockets/tls:ssl_socket_lib",
"//source/server:active_raw_udp_listener_config",
"//test/test_common:network_utility_lib",
"//test/test_common:registry_lib",
],
)
Expand Down
19 changes: 19 additions & 0 deletions test/server/hot_restarting_parent_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@

#include "server/hot_restarting_parent.h"

#include "test/mocks/network/mocks.h"
#include "test/mocks/server/mocks.h"

#include "gtest/gtest.h"

using testing::InSequence;
using testing::Return;
using testing::ReturnRef;

Expand Down Expand Up @@ -40,6 +42,23 @@ TEST_F(HotRestartingParentTest, getListenSocketsForChildNotFound) {
EXPECT_EQ(-1, message.reply().pass_listen_socket().fd());
}

TEST_F(HotRestartingParentTest, getListenSocketsForChildNotBindPort) {
MockListenerManager listener_manager;
Network::MockListenerConfig listener_config;
std::vector<std::reference_wrapper<Network::ListenerConfig>> listeners;
InSequence s;
listeners.push_back(std::ref(*static_cast<Network::ListenerConfig*>(&listener_config)));
EXPECT_CALL(server_, listenerManager()).WillOnce(ReturnRef(listener_manager));
EXPECT_CALL(listener_manager, listeners()).WillOnce(Return(listeners));
EXPECT_CALL(listener_config, socket()).Times(1);
EXPECT_CALL(listener_config, bindToPort()).WillOnce(Return(false));

HotRestartMessage::Request request;
request.mutable_pass_listen_socket()->set_address("tcp://0.0.0.0:80");
HotRestartMessage message = hot_restarting_parent_.getListenSocketsForChild(request);
EXPECT_EQ(-1, message.reply().pass_listen_socket().fd());
}

TEST_F(HotRestartingParentTest, exportStatsToChild) {
Stats::IsolatedStoreImpl store;
MockListenerManager listener_manager;
Expand Down
42 changes: 42 additions & 0 deletions test/server/listener_manager_impl_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "extensions/transport_sockets/tls/ssl_socket.h"

#include "test/server/utility.h"
#include "test/test_common/network_utility.h"
#include "test/test_common/registry.h"
#include "test/test_common/utility.h"

Expand Down Expand Up @@ -915,6 +916,47 @@ name: foo
EXPECT_CALL(*listener_foo2, onDestroy());
}

TEST_F(ListenerManagerImplTest, BindToPortEqualToFalse) {
InSequence s;
ProdListenerComponentFactory real_listener_factory(server_);
EXPECT_CALL(*worker_, start(_));
manager_->startWorkers(guard_dog_);
const std::string listener_foo_yaml = R"EOF(
name: foo
address:
socket_address:
address: 127.0.0.1
port_value: 1234
deprecated_v1:
bind_to_port: false
filter_chains:
- filters: []
)EOF";
Api::OsSysCallsImpl os_syscall;
auto syscall_result = os_syscall.socket(AF_INET, SOCK_STREAM, 0);
ASSERT_GE(syscall_result.rc_, 0);
ListenerHandle* listener_foo = expectListenerCreate(true, true);
EXPECT_CALL(listener_factory_, createListenSocket(_, _, _, false))
.WillOnce(Invoke([this, &syscall_result, &real_listener_factory](
const Network::Address::InstanceConstSharedPtr& address,
Network::Address::SocketType socket_type,
const Network::Socket::OptionsSharedPtr& options,
bool bind_to_port) -> Network::SocketSharedPtr {
EXPECT_CALL(server_, hotRestart).Times(0);
// When bind_to_port is equal to false, create socket fd directly, and do not get socket
// fd through hot restart.
NiceMock<Api::MockOsSysCalls> os_sys_calls;
TestThreadsafeSingletonInjector<Api::OsSysCallsImpl> os_calls(&os_sys_calls);
ON_CALL(os_sys_calls, socket(AF_INET, _, 0))
.WillByDefault(Return(Api::SysCallIntResult{syscall_result.rc_, 0}));
return real_listener_factory.createListenSocket(address, socket_type, options,
bind_to_port);
}));
EXPECT_CALL(listener_foo->target_, initialize());
EXPECT_CALL(*listener_foo, onDestroy());
EXPECT_TRUE(manager_->addOrUpdateListener(parseListenerFromV2Yaml(listener_foo_yaml), "", true));
}

TEST_F(ListenerManagerImplTest, CantBindSocket) {
InSequence s;

Expand Down

0 comments on commit c7c6bb8

Please sign in to comment.