diff --git a/source/server/hot_restarting_parent.cc b/source/server/hot_restarting_parent.cc index becef491b5fa..1f4757b2e68e 100644 --- a/source/server/hot_restarting_parent.cc +++ b/source/server/hot_restarting_parent.cc @@ -95,7 +95,7 @@ HotRestartingParent::Internal::getListenSocketsForChild(const HotRestartMessage: Network::Address::InstanceConstSharedPtr addr = Network::Utility::resolveUrl(request.pass_listen_socket().address()); for (const auto& listener : server_->listenerManager().listeners()) { - if (*listener.get().socket().localAddress() == *addr) { + if (*listener.get().socket().localAddress() == *addr && listener.get().bindToPort()) { wrapped_reply.mutable_reply()->mutable_pass_listen_socket()->set_fd( listener.get().socket().ioHandle().fd()); break; diff --git a/source/server/listener_manager_impl.cc b/source/server/listener_manager_impl.cc index 443e3faa6049..f00427e57b27 100644 --- a/source/server/listener_manager_impl.cc +++ b/source/server/listener_manager_impl.cc @@ -166,16 +166,20 @@ Network::SocketSharedPtr ProdListenerComponentFactory::createListenSocket( ? Network::Utility::TCP_SCHEME : Network::Utility::UDP_SCHEME; const std::string addr = absl::StrCat(scheme, address->asString()); - const int fd = server_.hotRestart().duplicateParentListenSocket(addr); - if (fd != -1) { - ENVOY_LOG(debug, "obtained socket for address {} from parent", addr); - Network::IoHandlePtr io_handle = std::make_unique(fd); - if (socket_type == Network::Address::SocketType::Stream) { - return std::make_shared(std::move(io_handle), address, options); - } else { - return std::make_shared(std::move(io_handle), address, options); + + if (bind_to_port) { + const int fd = server_.hotRestart().duplicateParentListenSocket(addr); + if (fd != -1) { + ENVOY_LOG(debug, "obtained socket for address {} from parent", addr); + Network::IoHandlePtr io_handle = std::make_unique(fd); + if (socket_type == Network::Address::SocketType::Stream) { + return std::make_shared(std::move(io_handle), address, options); + } else { + return std::make_shared(std::move(io_handle), address, options); + } } } + if (socket_type == Network::Address::SocketType::Stream) { return std::make_shared(address, options, bind_to_port); } else { diff --git a/test/server/BUILD b/test/server/BUILD index bae14ad2dd47..84a3b4684d54 100644 --- a/test/server/BUILD +++ b/test/server/BUILD @@ -93,6 +93,7 @@ envoy_cc_test( deps = [ "//source/common/stats:stats_lib", "//source/server:hot_restart_lib", + "//test/mocks/network:network_mocks", "//test/mocks/server:server_mocks", ], ) @@ -201,6 +202,7 @@ envoy_cc_test( "//source/extensions/transport_sockets/tls:config", "//source/extensions/transport_sockets/tls:ssl_socket_lib", "//source/server:active_raw_udp_listener_config", + "//test/test_common:network_utility_lib", "//test/test_common:registry_lib", ], ) diff --git a/test/server/hot_restarting_parent_test.cc b/test/server/hot_restarting_parent_test.cc index 7b65c3316e61..4ee01700faa6 100644 --- a/test/server/hot_restarting_parent_test.cc +++ b/test/server/hot_restarting_parent_test.cc @@ -2,10 +2,12 @@ #include "server/hot_restarting_parent.h" +#include "test/mocks/network/mocks.h" #include "test/mocks/server/mocks.h" #include "gtest/gtest.h" +using testing::InSequence; using testing::Return; using testing::ReturnRef; @@ -40,6 +42,23 @@ TEST_F(HotRestartingParentTest, getListenSocketsForChildNotFound) { EXPECT_EQ(-1, message.reply().pass_listen_socket().fd()); } +TEST_F(HotRestartingParentTest, getListenSocketsForChildNotBindPort) { + MockListenerManager listener_manager; + Network::MockListenerConfig listener_config; + std::vector> listeners; + InSequence s; + listeners.push_back(std::ref(*static_cast(&listener_config))); + EXPECT_CALL(server_, listenerManager()).WillOnce(ReturnRef(listener_manager)); + EXPECT_CALL(listener_manager, listeners()).WillOnce(Return(listeners)); + EXPECT_CALL(listener_config, socket()).Times(1); + EXPECT_CALL(listener_config, bindToPort()).WillOnce(Return(false)); + + HotRestartMessage::Request request; + request.mutable_pass_listen_socket()->set_address("tcp://0.0.0.0:80"); + HotRestartMessage message = hot_restarting_parent_.getListenSocketsForChild(request); + EXPECT_EQ(-1, message.reply().pass_listen_socket().fd()); +} + TEST_F(HotRestartingParentTest, exportStatsToChild) { Stats::IsolatedStoreImpl store; MockListenerManager listener_manager; diff --git a/test/server/listener_manager_impl_test.cc b/test/server/listener_manager_impl_test.cc index 91882a64f73f..dd4170eb1544 100644 --- a/test/server/listener_manager_impl_test.cc +++ b/test/server/listener_manager_impl_test.cc @@ -20,6 +20,7 @@ #include "extensions/transport_sockets/tls/ssl_socket.h" #include "test/server/utility.h" +#include "test/test_common/network_utility.h" #include "test/test_common/registry.h" #include "test/test_common/utility.h" @@ -915,6 +916,47 @@ name: foo EXPECT_CALL(*listener_foo2, onDestroy()); } +TEST_F(ListenerManagerImplTest, BindToPortEqualToFalse) { + InSequence s; + ProdListenerComponentFactory real_listener_factory(server_); + EXPECT_CALL(*worker_, start(_)); + manager_->startWorkers(guard_dog_); + const std::string listener_foo_yaml = R"EOF( +name: foo +address: + socket_address: + address: 127.0.0.1 + port_value: 1234 +deprecated_v1: + bind_to_port: false +filter_chains: +- filters: [] + )EOF"; + Api::OsSysCallsImpl os_syscall; + auto syscall_result = os_syscall.socket(AF_INET, SOCK_STREAM, 0); + ASSERT_GE(syscall_result.rc_, 0); + ListenerHandle* listener_foo = expectListenerCreate(true, true); + EXPECT_CALL(listener_factory_, createListenSocket(_, _, _, false)) + .WillOnce(Invoke([this, &syscall_result, &real_listener_factory]( + const Network::Address::InstanceConstSharedPtr& address, + Network::Address::SocketType socket_type, + const Network::Socket::OptionsSharedPtr& options, + bool bind_to_port) -> Network::SocketSharedPtr { + EXPECT_CALL(server_, hotRestart).Times(0); + // When bind_to_port is equal to false, create socket fd directly, and do not get socket + // fd through hot restart. + NiceMock os_sys_calls; + TestThreadsafeSingletonInjector os_calls(&os_sys_calls); + ON_CALL(os_sys_calls, socket(AF_INET, _, 0)) + .WillByDefault(Return(Api::SysCallIntResult{syscall_result.rc_, 0})); + return real_listener_factory.createListenSocket(address, socket_type, options, + bind_to_port); + })); + EXPECT_CALL(listener_foo->target_, initialize()); + EXPECT_CALL(*listener_foo, onDestroy()); + EXPECT_TRUE(manager_->addOrUpdateListener(parseListenerFromV2Yaml(listener_foo_yaml), "", true)); +} + TEST_F(ListenerManagerImplTest, CantBindSocket) { InSequence s;