Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RPC][REFACTOR] Use PopenWorker to handle RPC Server. #7889

Merged
merged 1 commit into from
Apr 20, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion python/tvm/auto_scheduler/measure.py
Original file line number Diff line number Diff line change
Expand Up @@ -560,7 +560,6 @@ def __init__(
port=self.tracker.port,
port_end=10000,
key=device_key,
use_popen=True,
silent=True,
tracker_addr=(self.tracker.host, self.tracker.port),
)
Expand Down
1 change: 0 additions & 1 deletion python/tvm/autotvm/measure/measure_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,7 +404,6 @@ def set_task(self, task):
port=9000,
port_end=10000,
key=device_key,
use_popen=True,
silent=True,
tracker_addr=(tracker.host, tracker.port),
)
Expand Down
3 changes: 2 additions & 1 deletion python/tvm/contrib/popen_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,11 @@ def kill_child_processes(pid):

try:
parent = psutil.Process(pid)
children = parent.children(recursive=True)
except psutil.NoSuchProcess:
return

for process in parent.children(recursive=True):
for process in children:
try:
process.kill()
except psutil.NoSuchProcess:
Expand Down
20 changes: 6 additions & 14 deletions python/tvm/exec/rpc_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,7 @@
# under the License.
# pylint: disable=redefined-outer-name, invalid-name
"""Start an RPC server"""
from __future__ import absolute_import

import argparse
import multiprocessing
import sys
import logging
from .. import rpc

Expand Down Expand Up @@ -51,6 +47,7 @@ def main(args):
load_library=args.load_library,
custom_addr=args.custom_addr,
silent=args.silent,
no_fork=not args.fork,
)
server.proc.join()

Expand Down Expand Up @@ -85,14 +82,9 @@ def main(args):
parser.set_defaults(fork=True)
args = parser.parse_args()
logging.basicConfig(level=logging.INFO)
if args.fork is False:
if sys.version_info[0] < 3:
raise RuntimeError("Python3 is required for spawn mode.")
multiprocessing.set_start_method("spawn")
else:
if not args.silent:
logging.info(
"If you are running ROCM/Metal, fork will cause "
"compiler internal error. Try to launch with arg ```--no-fork```"
)
if not args.fork is False and not args.silent:
logging.info(
"If you are running ROCM/Metal, fork will cause "
"compiler internal error. Try to launch with arg ```--no-fork```"
)
main(args)
6 changes: 3 additions & 3 deletions python/tvm/rpc/proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -537,7 +537,7 @@ def __init__(
self.thread.start()


def _popen_start_server(
def _popen_start_proxy_server(
host,
port=9091,
port_end=9199,
Expand Down Expand Up @@ -570,7 +570,7 @@ def _popen_start_server(
class Proxy(object):
"""Start RPC proxy server on a seperate process.

Python implementation based on multi-processing.
Python implementation based on PopenWorker.

Parameters
----------
Expand Down Expand Up @@ -618,7 +618,7 @@ def __init__(
self.proc = PopenWorker()
# send the function
self.proc.send(
_popen_start_server,
_popen_start_proxy_server,
[
host,
port,
Expand Down
206 changes: 110 additions & 96 deletions python/tvm/rpc/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,20 +31,21 @@
import select
import struct
import logging
import threading
import multiprocessing
import subprocess
import time
import sys
import signal
import platform
import tvm._ffi

from tvm._ffi.base import py_str
from tvm._ffi.libinfo import find_lib_path
from tvm.runtime.module import load_module as _load_module
from tvm.contrib import utils
from tvm.contrib.popen_pool import PopenWorker
from . import _ffi_api
from . import base

# pylint: disable=unused-import
from . import testing
from .base import TrackerCode

logger = logging.getLogger("RPCServer")
Expand Down Expand Up @@ -296,13 +297,85 @@ def _connect_proxy_loop(addr, key, load_library):
time.sleep(retry_period)


def _popen(cmd):
proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, env=os.environ)
(out, _) = proc.communicate()
if proc.returncode != 0:
msg = "Server invoke error:\n"
msg += out
raise RuntimeError(msg)
class PopenRPCServerState(object):
"""Internal PopenRPCServer State"""

current = None

def __init__(
self,
host,
port=9091,
port_end=9199,
is_proxy=False,
tracker_addr=None,
key="",
load_library=None,
custom_addr=None,
silent=False,
):

# start update
self.host = host
self.port = port
self.libs = []
self.custom_addr = custom_addr

if silent:
logger.setLevel(logging.ERROR)

if not is_proxy:
sock = socket.socket(base.get_addr_family((host, port)), socket.SOCK_STREAM)
self.port = None
for my_port in range(port, port_end):
try:
sock.bind((host, my_port))
self.port = my_port
break
except socket.error as sock_err:
if sock_err.errno in [98, 48]:
continue
raise sock_err
if not self.port:
raise ValueError("cannot bind to any port in [%d, %d)" % (port, port_end))
logger.info("bind to %s:%d", host, self.port)
sock.listen(1)
self.sock = sock
self.thread = threading.Thread(
target=_listen_loop,
args=(self.sock, self.port, key, tracker_addr, load_library, self.custom_addr),
)
self.thread.start()
else:
self.thread = threading.Thread(
target=_connect_proxy_loop, args=((host, port), key, load_library)
)
self.thread.start()


def _popen_start_rpc_server(
host,
port=9091,
port_end=9199,
is_proxy=False,
tracker_addr=None,
key="",
load_library=None,
custom_addr=None,
silent=False,
no_fork=False,
):
if no_fork:
multiprocessing.set_start_method("spawn")
# This is a function that will be sent to the
# Popen worker to run on a separate process.
# Create and start the server in a different thread
state = PopenRPCServerState(
host, port, port_end, is_proxy, tracker_addr, key, load_library, custom_addr, silent
)
PopenRPCServerState.current = state
# returns the port so that the main can get the port number.
return state.port


class Server(object):
Expand All @@ -328,11 +401,6 @@ class Server(object):
If this is true, the host and port actually corresponds to the
address of the proxy server.

use_popen : bool, optional
Whether to use Popen to start a fresh new process instead of fork.
This is recommended to switch on if we want to do local RPC demonstration
for GPU devices to avoid fork safety issues.

tracker_addr: Tuple (str, int) , optional
The address of RPC Tracker in tuple(host, ip) format.
If is not None, the server will register itself to the tracker.
Expand All @@ -348,6 +416,9 @@ class Server(object):

silent: bool, optional
Whether run this server in silent mode.

no_fork: bool, optional
Whether forbid fork in multiprocessing.
"""

def __init__(
Expand All @@ -356,101 +427,44 @@ def __init__(
port=9091,
port_end=9199,
is_proxy=False,
use_popen=False,
tracker_addr=None,
key="",
load_library=None,
custom_addr=None,
silent=False,
no_fork=False,
):
try:
if _ffi_api.ServerLoop is None:
raise RuntimeError("Please compile with USE_RPC=1")
except NameError:
raise RuntimeError("Please compile with USE_RPC=1")
self.proc = PopenWorker()
# send the function
self.proc.send(
_popen_start_rpc_server,
[
host,
port,
port_end,
is_proxy,
tracker_addr,
key,
load_library,
custom_addr,
silent,
no_fork,
],
)
# receive the port
self.port = self.proc.recv()
self.host = host
self.port = port
self.libs = []
self.custom_addr = custom_addr
self.use_popen = use_popen

if silent:
logger.setLevel(logging.ERROR)

if use_popen:
cmd = [
sys.executable,
"-m",
"tvm.exec.rpc_server",
"--host=%s" % host,
"--port=%s" % port,
"--port-end=%s" % port_end,
]
if tracker_addr:
assert key
cmd += ["--tracker=%s:%d" % tracker_addr, "--key=%s" % key]
if load_library:
cmd += ["--load-library", load_library]
if custom_addr:
cmd += ["--custom-addr", custom_addr]
if silent:
cmd += ["--silent"]

# prexec_fn is not thread safe and may result in deadlock.
# python 3.2 introduced the start_new_session parameter as
# an alternative to the common use case of
# prexec_fn=os.setsid. Once the minimum version of python
# supported by TVM reaches python 3.2 this code can be
# rewritten in favour of start_new_session. In the
# interim, stop the pylint diagnostic.
#
# pylint: disable=subprocess-popen-preexec-fn
if platform.system() == "Windows":
self.proc = subprocess.Popen(cmd, creationflags=subprocess.CREATE_NEW_PROCESS_GROUP)
else:
self.proc = subprocess.Popen(cmd, preexec_fn=os.setsid)
time.sleep(0.5)
elif not is_proxy:
sock = socket.socket(base.get_addr_family((host, port)), socket.SOCK_STREAM)
self.port = None
for my_port in range(port, port_end):
try:
sock.bind((host, my_port))
self.port = my_port
break
except socket.error as sock_err:
if sock_err.errno in [98, 48]:
continue
raise sock_err
if not self.port:
raise ValueError("cannot bind to any port in [%d, %d)" % (port, port_end))
logger.info("bind to %s:%d", host, self.port)
sock.listen(1)
self.sock = sock
self.proc = multiprocessing.Process(
target=_listen_loop,
args=(self.sock, self.port, key, tracker_addr, load_library, self.custom_addr),
)
self.proc.start()
else:
self.proc = multiprocessing.Process(
target=_connect_proxy_loop, args=((host, port), key, load_library)
)
self.proc.start()

def terminate(self):
"""Terminate the server process"""
if self.use_popen:
if self.proc:
if platform.system() == "Windows":
os.kill(self.proc.pid, signal.CTRL_C_EVENT)
else:
os.killpg(self.proc.pid, signal.SIGTERM)
self.proc = None
else:
if self.proc:
self.proc.terminate()
self.proc = None
if self.proc:
self.proc.kill()
self.proc = None

def __del__(self):
self.terminate()
Loading