From 8bd857d6fa3a23eab5c8a158af0711dc88333a26 Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Tue, 20 Apr 2021 18:04:48 -0400 Subject: [PATCH] [RPC][REFACTOR] Use PopenWorker to handle RPC Server. (#7889) Previously the rpc server relies multiprocessing to start a new process and does not work under jupyter. It also have a popen mode that does ensure the socket start listening before returning the port number. This PR switches the implementations use PopenWorker. The port number is returned after the socket get binded, which resolves some of the RPC flaky issues(need sleep to wait the server to start). It also makes the RPC server jupyter friendly. --- python/tvm/auto_scheduler/measure.py | 1 - python/tvm/autotvm/measure/measure_methods.py | 1 - python/tvm/contrib/popen_pool.py | 3 +- python/tvm/exec/rpc_server.py | 20 +- python/tvm/rpc/proxy.py | 6 +- python/tvm/rpc/server.py | 206 ++++++++++-------- python/tvm/rpc/testing.py | 69 ++++++ tests/python/relay/test_vm.py | 3 +- .../test_runtime_module_based_interface.py | 2 +- tests/python/unittest/test_runtime_rpc.py | 57 +---- 10 files changed, 195 insertions(+), 173 deletions(-) create mode 100644 python/tvm/rpc/testing.py diff --git a/python/tvm/auto_scheduler/measure.py b/python/tvm/auto_scheduler/measure.py index 84dff157aa50..e77721496386 100644 --- a/python/tvm/auto_scheduler/measure.py +++ b/python/tvm/auto_scheduler/measure.py @@ -560,7 +560,6 @@ def __init__( port=self.tracker.port, port_end=10000, key=device_key, - use_popen=True, silent=True, tracker_addr=(self.tracker.host, self.tracker.port), ) diff --git a/python/tvm/autotvm/measure/measure_methods.py b/python/tvm/autotvm/measure/measure_methods.py index d212e5f26f20..f328a06c079a 100644 --- a/python/tvm/autotvm/measure/measure_methods.py +++ b/python/tvm/autotvm/measure/measure_methods.py @@ -404,7 +404,6 @@ def set_task(self, task): port=9000, port_end=10000, key=device_key, - use_popen=True, silent=True, tracker_addr=(tracker.host, tracker.port), ) diff --git a/python/tvm/contrib/popen_pool.py b/python/tvm/contrib/popen_pool.py index 5a25484e9106..ecda995c7162 100644 --- a/python/tvm/contrib/popen_pool.py +++ b/python/tvm/contrib/popen_pool.py @@ -44,10 +44,11 @@ def kill_child_processes(pid): try: parent = psutil.Process(pid) + children = parent.children(recursive=True) except psutil.NoSuchProcess: return - for process in parent.children(recursive=True): + for process in children: try: process.kill() except psutil.NoSuchProcess: diff --git a/python/tvm/exec/rpc_server.py b/python/tvm/exec/rpc_server.py index 9692b98fe22b..6b3e93edd223 100644 --- a/python/tvm/exec/rpc_server.py +++ b/python/tvm/exec/rpc_server.py @@ -16,11 +16,7 @@ # under the License. # pylint: disable=redefined-outer-name, invalid-name """Start an RPC server""" -from __future__ import absolute_import - import argparse -import multiprocessing -import sys import logging from .. import rpc @@ -51,6 +47,7 @@ def main(args): load_library=args.load_library, custom_addr=args.custom_addr, silent=args.silent, + no_fork=not args.fork, ) server.proc.join() @@ -85,14 +82,9 @@ def main(args): parser.set_defaults(fork=True) args = parser.parse_args() logging.basicConfig(level=logging.INFO) - if args.fork is False: - if sys.version_info[0] < 3: - raise RuntimeError("Python3 is required for spawn mode.") - multiprocessing.set_start_method("spawn") - else: - if not args.silent: - logging.info( - "If you are running ROCM/Metal, fork will cause " - "compiler internal error. Try to launch with arg ```--no-fork```" - ) + if not args.fork is False and not args.silent: + logging.info( + "If you are running ROCM/Metal, fork will cause " + "compiler internal error. Try to launch with arg ```--no-fork```" + ) main(args) diff --git a/python/tvm/rpc/proxy.py b/python/tvm/rpc/proxy.py index 28117b09f280..7e02bd77c491 100644 --- a/python/tvm/rpc/proxy.py +++ b/python/tvm/rpc/proxy.py @@ -537,7 +537,7 @@ def __init__( self.thread.start() -def _popen_start_server( +def _popen_start_proxy_server( host, port=9091, port_end=9199, @@ -570,7 +570,7 @@ def _popen_start_server( class Proxy(object): """Start RPC proxy server on a seperate process. - Python implementation based on multi-processing. + Python implementation based on PopenWorker. Parameters ---------- @@ -618,7 +618,7 @@ def __init__( self.proc = PopenWorker() # send the function self.proc.send( - _popen_start_server, + _popen_start_proxy_server, [ host, port, diff --git a/python/tvm/rpc/server.py b/python/tvm/rpc/server.py index 786154253133..3fd6996034f7 100644 --- a/python/tvm/rpc/server.py +++ b/python/tvm/rpc/server.py @@ -31,20 +31,21 @@ import select import struct import logging +import threading import multiprocessing -import subprocess import time -import sys -import signal -import platform import tvm._ffi from tvm._ffi.base import py_str from tvm._ffi.libinfo import find_lib_path from tvm.runtime.module import load_module as _load_module from tvm.contrib import utils +from tvm.contrib.popen_pool import PopenWorker from . import _ffi_api from . import base + +# pylint: disable=unused-import +from . import testing from .base import TrackerCode logger = logging.getLogger("RPCServer") @@ -296,13 +297,85 @@ def _connect_proxy_loop(addr, key, load_library): time.sleep(retry_period) -def _popen(cmd): - proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, env=os.environ) - (out, _) = proc.communicate() - if proc.returncode != 0: - msg = "Server invoke error:\n" - msg += out - raise RuntimeError(msg) +class PopenRPCServerState(object): + """Internal PopenRPCServer State""" + + current = None + + def __init__( + self, + host, + port=9091, + port_end=9199, + is_proxy=False, + tracker_addr=None, + key="", + load_library=None, + custom_addr=None, + silent=False, + ): + + # start update + self.host = host + self.port = port + self.libs = [] + self.custom_addr = custom_addr + + if silent: + logger.setLevel(logging.ERROR) + + if not is_proxy: + sock = socket.socket(base.get_addr_family((host, port)), socket.SOCK_STREAM) + self.port = None + for my_port in range(port, port_end): + try: + sock.bind((host, my_port)) + self.port = my_port + break + except socket.error as sock_err: + if sock_err.errno in [98, 48]: + continue + raise sock_err + if not self.port: + raise ValueError("cannot bind to any port in [%d, %d)" % (port, port_end)) + logger.info("bind to %s:%d", host, self.port) + sock.listen(1) + self.sock = sock + self.thread = threading.Thread( + target=_listen_loop, + args=(self.sock, self.port, key, tracker_addr, load_library, self.custom_addr), + ) + self.thread.start() + else: + self.thread = threading.Thread( + target=_connect_proxy_loop, args=((host, port), key, load_library) + ) + self.thread.start() + + +def _popen_start_rpc_server( + host, + port=9091, + port_end=9199, + is_proxy=False, + tracker_addr=None, + key="", + load_library=None, + custom_addr=None, + silent=False, + no_fork=False, +): + if no_fork: + multiprocessing.set_start_method("spawn") + # This is a function that will be sent to the + # Popen worker to run on a separate process. + # Create and start the server in a different thread + state = PopenRPCServerState( + host, port, port_end, is_proxy, tracker_addr, key, load_library, custom_addr, silent + ) + PopenRPCServerState.current = state + # returns the port so that the main can get the port number. + return state.port class Server(object): @@ -328,11 +401,6 @@ class Server(object): If this is true, the host and port actually corresponds to the address of the proxy server. - use_popen : bool, optional - Whether to use Popen to start a fresh new process instead of fork. - This is recommended to switch on if we want to do local RPC demonstration - for GPU devices to avoid fork safety issues. - tracker_addr: Tuple (str, int) , optional The address of RPC Tracker in tuple(host, ip) format. If is not None, the server will register itself to the tracker. @@ -348,6 +416,9 @@ class Server(object): silent: bool, optional Whether run this server in silent mode. + + no_fork: bool, optional + Whether forbid fork in multiprocessing. """ def __init__( @@ -356,101 +427,44 @@ def __init__( port=9091, port_end=9199, is_proxy=False, - use_popen=False, tracker_addr=None, key="", load_library=None, custom_addr=None, silent=False, + no_fork=False, ): try: if _ffi_api.ServerLoop is None: raise RuntimeError("Please compile with USE_RPC=1") except NameError: raise RuntimeError("Please compile with USE_RPC=1") + self.proc = PopenWorker() + # send the function + self.proc.send( + _popen_start_rpc_server, + [ + host, + port, + port_end, + is_proxy, + tracker_addr, + key, + load_library, + custom_addr, + silent, + no_fork, + ], + ) + # receive the port + self.port = self.proc.recv() self.host = host - self.port = port - self.libs = [] - self.custom_addr = custom_addr - self.use_popen = use_popen - - if silent: - logger.setLevel(logging.ERROR) - - if use_popen: - cmd = [ - sys.executable, - "-m", - "tvm.exec.rpc_server", - "--host=%s" % host, - "--port=%s" % port, - "--port-end=%s" % port_end, - ] - if tracker_addr: - assert key - cmd += ["--tracker=%s:%d" % tracker_addr, "--key=%s" % key] - if load_library: - cmd += ["--load-library", load_library] - if custom_addr: - cmd += ["--custom-addr", custom_addr] - if silent: - cmd += ["--silent"] - - # prexec_fn is not thread safe and may result in deadlock. - # python 3.2 introduced the start_new_session parameter as - # an alternative to the common use case of - # prexec_fn=os.setsid. Once the minimum version of python - # supported by TVM reaches python 3.2 this code can be - # rewritten in favour of start_new_session. In the - # interim, stop the pylint diagnostic. - # - # pylint: disable=subprocess-popen-preexec-fn - if platform.system() == "Windows": - self.proc = subprocess.Popen(cmd, creationflags=subprocess.CREATE_NEW_PROCESS_GROUP) - else: - self.proc = subprocess.Popen(cmd, preexec_fn=os.setsid) - time.sleep(0.5) - elif not is_proxy: - sock = socket.socket(base.get_addr_family((host, port)), socket.SOCK_STREAM) - self.port = None - for my_port in range(port, port_end): - try: - sock.bind((host, my_port)) - self.port = my_port - break - except socket.error as sock_err: - if sock_err.errno in [98, 48]: - continue - raise sock_err - if not self.port: - raise ValueError("cannot bind to any port in [%d, %d)" % (port, port_end)) - logger.info("bind to %s:%d", host, self.port) - sock.listen(1) - self.sock = sock - self.proc = multiprocessing.Process( - target=_listen_loop, - args=(self.sock, self.port, key, tracker_addr, load_library, self.custom_addr), - ) - self.proc.start() - else: - self.proc = multiprocessing.Process( - target=_connect_proxy_loop, args=((host, port), key, load_library) - ) - self.proc.start() def terminate(self): """Terminate the server process""" - if self.use_popen: - if self.proc: - if platform.system() == "Windows": - os.kill(self.proc.pid, signal.CTRL_C_EVENT) - else: - os.killpg(self.proc.pid, signal.SIGTERM) - self.proc = None - else: - if self.proc: - self.proc.terminate() - self.proc = None + if self.proc: + self.proc.kill() + self.proc = None def __del__(self): self.terminate() diff --git a/python/tvm/rpc/testing.py b/python/tvm/rpc/testing.py new file mode 100644 index 000000000000..b7acc74c413a --- /dev/null +++ b/python/tvm/rpc/testing.py @@ -0,0 +1,69 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# pylint: disable=invalid-name,unnecessary-comprehension +""" Testing functions for the RPC server.""" +import numpy as np +import tvm + + +# RPC test functions to be registered for unit-tests purposes +@tvm.register_func("rpc.test.addone") +def _addone(x): + return x + 1 + + +@tvm.register_func("rpc.test.strcat") +def _strcat(name, x): + return "%s:%d" % (name, x) + + +@tvm.register_func("rpc.test.except") +def _remotethrow(name): + raise ValueError("%s" % name) + + +@tvm.register_func("rpc.test.runtime_str_concat") +def _strcat(x, y): + return x + y + + +@tvm.register_func("rpc.test.remote_array_func") +def _remote_array_func(y): + x = np.ones((3, 4)) + np.testing.assert_equal(y.asnumpy(), x) + + +@tvm.register_func("rpc.test.add_to_lhs") +def _add_to_lhs(x): + return lambda y: x + y + + +@tvm.register_func("rpc.test.remote_return_nd") +def _my_module(name): + # Use closure to check the ref counter correctness + nd = tvm.nd.array(np.zeros(10).astype("float32")) + + if name == "get_arr": + return lambda: nd + if name == "ref_count": + return lambda: tvm.testing.object_use_count(nd) + if name == "get_elem": + return lambda idx: nd.asnumpy()[idx] + if name == "get_arr_elem": + return lambda arr, idx: arr.asnumpy()[idx] + raise RuntimeError("unknown name") diff --git a/tests/python/relay/test_vm.py b/tests/python/relay/test_vm.py index 58985832fb35..7e790494125e 100644 --- a/tests/python/relay/test_vm.py +++ b/tests/python/relay/test_vm.py @@ -827,8 +827,7 @@ def test_vm_rpc(): # Use local rpc server for testing. # Server must use popen so it doesn't inherit the current process state. It # will crash otherwise. - server = rpc.Server("localhost", port=9120, use_popen=True) - time.sleep(2) + server = rpc.Server("localhost", port=9120) remote = rpc.connect(server.host, server.port, session_timeout=10) # Upload the serialized Executable. diff --git a/tests/python/unittest/test_runtime_module_based_interface.py b/tests/python/unittest/test_runtime_module_based_interface.py index 766338de3558..9658ce1b2c1e 100644 --- a/tests/python/unittest/test_runtime_module_based_interface.py +++ b/tests/python/unittest/test_runtime_module_based_interface.py @@ -262,7 +262,7 @@ def verify_rpc_gpu_export(obj_format): from tvm import rpc - server = rpc.Server("localhost", use_popen=True, port=9094) + server = rpc.Server("localhost", port=9094) remote = rpc.connect(server.host, server.port) remote.upload(path_lib) loaded_lib = remote.load_module(path_lib) diff --git a/tests/python/unittest/test_runtime_rpc.py b/tests/python/unittest/test_runtime_rpc.py index 256fd33387bf..a74f893065b8 100644 --- a/tests/python/unittest/test_runtime_rpc.py +++ b/tests/python/unittest/test_runtime_rpc.py @@ -85,21 +85,6 @@ def verify_rpc(remote, target, shape, dtype): verify_rpc(remote, target, (10,), dtype) -@tvm.register_func("rpc.test.addone") -def addone(x): - return x + 1 - - -@tvm.register_func("rpc.test.strcat") -def strcat(name, x): - return "%s:%d" % (name, x) - - -@tvm.register_func("rpc.test.except") -def remotethrow(name): - raise ValueError("%s" % name) - - @tvm.testing.requires_rpc def test_rpc_simple(): server = rpc.Server("localhost", key="x1") @@ -115,11 +100,6 @@ def test_rpc_simple(): assert f2("abc", 11) == "abc:11" -@tvm.register_func("rpc.test.runtime_str_concat") -def strcat(x, y): - return x + y - - @tvm.testing.requires_rpc def test_rpc_runtime_string(): server = rpc.Server("localhost", key="x1") @@ -130,12 +110,6 @@ def test_rpc_runtime_string(): assert str(func(x, y)) == "abcdef" -@tvm.register_func("rpc.test.remote_array_func") -def remote_array_func(y): - x = np.ones((3, 4)) - np.testing.assert_equal(y.asnumpy(), x) - - @tvm.testing.requires_rpc def test_rpc_array(): x = np.ones((3, 4)) @@ -342,16 +316,11 @@ def check_remote_link_cl(remote): check_minrpc() -@tvm.register_func("rpc.test.remote_func") -def addone(x): - return lambda y: x + y - - @tvm.testing.requires_rpc def test_rpc_return_func(): server = rpc.Server("localhost", key="x1") client = rpc.connect(server.host, server.port, key="x1") - f1 = client.get_function("rpc.test.remote_func") + f1 = client.get_function("rpc.test.add_to_lhs") fadd = f1(10) assert fadd(12) == 22 @@ -393,21 +362,6 @@ def check_error_handling(): check_error_handling() -@tvm.register_func("rpc.test.remote_return_nd") -def my_module(name): - # Use closure to check the ref counter correctness - nd = tvm.nd.array(np.zeros(10).astype("float32")) - - if name == "get_arr": - return lambda: nd - elif name == "ref_count": - return lambda: tvm.testing.object_use_count(nd) - elif name == "get_elem": - return lambda idx: nd.asnumpy()[idx] - elif name == "get_arr_elem": - return lambda arr, idx: arr.asnumpy()[idx] - - @tvm.testing.requires_rpc def test_rpc_return_ndarray(): # start server @@ -428,15 +382,10 @@ def run_arr_test(): run_arr_test() -@tvm.register_func("rpc.test.remote_func2") -def addone(x): - return lambda y: x + y - - @tvm.testing.requires_rpc def test_local_func(): client = rpc.LocalSession() - f1 = client.get_function("rpc.test.remote_func2") + f1 = client.get_function("rpc.test.add_to_lhs") fadd = f1(10) assert fadd(12) == 22 @@ -458,8 +407,8 @@ def test_rpc_tracker_register(): key=device_key, tracker_addr=(tracker.host, tracker.port), ) - time.sleep(1) client = rpc.connect_tracker(tracker.host, tracker.port) + time.sleep(1) summary = client.summary() assert summary["queue_info"][device_key]["free"] == 1