Skip to content

Commit

Permalink
[RPC] Add the IPV6 support for server side auto tuning (#2462)
Browse files Browse the repository at this point in the history
* use IPV6 instead of IPV4

* backward compatible

* add error report

* fix linter

* more linter

* fix the python2 api
  • Loading branch information
lly-zero-one authored and eqy committed Jan 21, 2019
1 parent e4b9f98 commit 0806b69
Show file tree
Hide file tree
Showing 7 changed files with 58 additions and 19 deletions.
2 changes: 1 addition & 1 deletion python/tvm/exec/rpc_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def main(args):
"""Main function"""

if args.tracker:
url, port = args.tracker.split(":")
url, port = args.tracker.rsplit(":", 1)
port = int(port)
tracker_addr = (url, port)
if not args.key:
Expand Down
7 changes: 6 additions & 1 deletion python/tvm/rpc/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,11 @@ class TrackerCode(object):
RPC_SESS_MASK = 128


def get_addr_family(addr):
res = socket.getaddrinfo(addr[0], addr[1], 0, 0, socket.IPPROTO_TCP)
return res[0][0]


def recvall(sock, nbytes):
"""Receive all nbytes from socket.
Expand Down Expand Up @@ -142,7 +147,7 @@ def connect_with_retry(addr, timeout=60, retry_period=5):
tstart = time.time()
while True:
try:
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock = socket.socket(get_addr_family(addr), socket.SOCK_STREAM)
sock.connect(addr)
return sock
except socket.error as sock_err:
Expand Down
5 changes: 3 additions & 2 deletions python/tvm/rpc/proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,8 @@ def _update_tracker(self, period_update=False):
"""Update information on tracker."""
try:
if self._tracker_conn is None:
self._tracker_conn = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self._tracker_conn = socket.socket(base.get_addr_family(self._tracker_addr),
socket.SOCK_STREAM)
self._tracker_conn.connect(self._tracker_addr)
self._tracker_conn.sendall(struct.pack("<i", base.RPC_TRACKER_MAGIC))
magic = struct.unpack("<i", base.recvall(self._tracker_conn, 4))[0]
Expand Down Expand Up @@ -481,7 +482,7 @@ def __init__(self,
tracker_addr=None,
index_page=None,
resource_files=None):
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock = socket.socket(base.get_addr_family((host, port)), socket.SOCK_STREAM)
self.port = None
for my_port in range(port, port_end):
try:
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/rpc/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ def _connect_proxy_loop(addr, key, load_library):
retry_period = 5
while True:
try:
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock = socket.socket(base.get_addr_family(addr), socket.SOCK_STREAM)
sock.connect(addr)
sock.sendall(struct.pack("<i", base.RPC_MAGIC))
sock.sendall(struct.pack("<i", len(key)))
Expand Down Expand Up @@ -334,7 +334,7 @@ def __init__(self,
self.proc = subprocess.Popen(cmd, preexec_fn=os.setsid)
time.sleep(0.5)
elif not is_proxy:
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock = socket.socket(base.get_addr_family((host, port)), socket.SOCK_STREAM)
self.port = None
for my_port in range(port, port_end):
try:
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/rpc/tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,7 @@ def __init__(self,
if silent:
logger.setLevel(logging.WARN)

sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock = socket.socket(base.get_addr_family((host, port)), socket.SOCK_STREAM)
self.port = None
self.stop_key = base.random_key("tracker")
for my_port in range(port, port_end):
Expand All @@ -391,7 +391,7 @@ def __init__(self,
sock.close()

def _stop_tracker(self):
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock = socket.socket(base.get_addr_family((self.host, self.port)), socket.SOCK_STREAM)
sock.connect((self.host, self.port))
sock.sendall(struct.pack("<i", base.RPC_TRACKER_MAGIC))
magic = struct.unpack("<i", base.recvall(sock, 4))[0]
Expand Down
53 changes: 43 additions & 10 deletions src/common/socket.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ inline std::string GetHostName() {
* \brief Common data structure fornetwork address.
*/
struct SockAddr {
sockaddr_in addr;
sockaddr_storage addr;
SockAddr() {}
/*!
* \brief construc address by url and port
Expand All @@ -63,30 +63,63 @@ struct SockAddr {
void Set(const char *host, int port) {
addrinfo hints;
memset(&hints, 0, sizeof(hints));
hints.ai_family = AF_INET;
hints.ai_family = PF_UNSPEC;
hints.ai_flags = AI_PASSIVE;
hints.ai_protocol = SOCK_STREAM;
addrinfo *res = NULL;
int sig = getaddrinfo(host, NULL, &hints, &res);
CHECK(sig == 0 && res != NULL)
<< "cannot obtain address of " << host;
CHECK(res->ai_family == AF_INET)
<< "Does not support IPv6";
memcpy(&addr, res->ai_addr, res->ai_addrlen);
addr.sin_port = htons(port);
switch (res->ai_family) {
case AF_INET: {
sockaddr_in *addr4 = reinterpret_cast<sockaddr_in *>(&addr);
memcpy(addr4, res->ai_addr, res->ai_addrlen);
addr4->sin_port = htons(port);
addr4->sin_family = AF_INET;
}
break;
case AF_INET6: {
sockaddr_in6 *addr6 = reinterpret_cast<sockaddr_in6 *>(&addr);
memcpy(addr6, res->ai_addr, res->ai_addrlen);
addr6->sin6_port = htons(port);
addr6->sin6_family = AF_INET6;
}
break;
default:
CHECK(false) << "cannot decode address";
}
freeaddrinfo(res);
}
/*! \brief return port of the address */
int port() const {
return ntohs(addr.sin_port);
return ntohs((addr.ss_family == AF_INET6)? \
reinterpret_cast<const sockaddr_in6 *>(&addr)->sin6_port : \
reinterpret_cast<const sockaddr_in *>(&addr)->sin_port);
}
/*! \brief return the ip address family */
int ss_family() const {
return addr.ss_family;
}
/*! \return a string representation of the address */
std::string AsString() const {
std::string buf; buf.resize(256);

const void *sinx_addr = nullptr;
if (addr.ss_family == AF_INET6) {
const in6_addr& addr6 = reinterpret_cast<const sockaddr_in6 *>(&addr)->sin6_addr;
sinx_addr = reinterpret_cast<const void *>(&addr6);
} else if (addr.ss_family == AF_INET) {
const in_addr& addr4 = reinterpret_cast<const sockaddr_in *>(&addr)->sin_addr;
sinx_addr = reinterpret_cast<const void *>(&addr4);
} else {
CHECK(false) << "illegal address";
}

#ifdef _WIN32
const char *s = inet_ntop(AF_INET, (PVOID)&addr.sin_addr,
const char *s = inet_ntop(addr.ss_family, sinx_addr,
&buf[0], buf.length());
#else
const char *s = inet_ntop(AF_INET, &addr.sin_addr,
const char *s = inet_ntop(addr.ss_family, sinx_addr,
&buf[0], static_cast<socklen_t>(buf.length()));
#endif
CHECK(s != nullptr) << "cannot decode address";
Expand Down Expand Up @@ -294,7 +327,7 @@ class TCPSocket : public Socket {
* \param af domain
*/
void Create(int af = PF_INET) {
sockfd = socket(PF_INET, SOCK_STREAM, 0);
sockfd = socket(af, SOCK_STREAM, 0);
if (sockfd == INVALID_SOCKET) {
Socket::Error("Create");
}
Expand Down
2 changes: 1 addition & 1 deletion src/runtime/rpc/rpc_socket_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ std::shared_ptr<RPCSession>
RPCConnect(std::string url, int port, std::string key) {
common::TCPSocket sock;
common::SockAddr addr(url.c_str(), port);
sock.Create();
sock.Create(addr.ss_family());
CHECK(sock.Connect(addr))
<< "Connect to " << addr.AsString() << " failed";
// hand shake
Expand Down

0 comments on commit 0806b69

Please sign in to comment.