Skip to content

Commit

Permalink
Fix bug in which remote function redefinition doesn't happen. (#6175)
Browse files Browse the repository at this point in the history
  • Loading branch information
robertnishihara authored and edoakes committed Nov 26, 2019
1 parent 7f8de61 commit ffb9c0e
Show file tree
Hide file tree
Showing 6 changed files with 256 additions and 47 deletions.
65 changes: 47 additions & 18 deletions python/ray/function_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from __future__ import division
from __future__ import print_function

import dis
import hashlib
import importlib
import inspect
Expand Down Expand Up @@ -102,7 +103,7 @@ def from_bytes_list(cls, function_descriptor_list):
"Invalid input for FunctionDescriptor.from_bytes_list")

@classmethod
def from_function(cls, function):
def from_function(cls, function, pickled_function):
"""Create a FunctionDescriptor from a function instance.
This function is used to create the function descriptor from
Expand All @@ -113,6 +114,9 @@ def from_function(cls, function):
cls: Current class which is required argument for classmethod.
function: the python function used to create the function
descriptor.
pickled_function: This is factored in to ensure that any
modifications to the function result in a different function
descriptor.
Returns:
The FunctionDescriptor instance created according to the function.
Expand All @@ -121,22 +125,10 @@ def from_function(cls, function):
function_name = function.__name__
class_name = ""

function_source_hasher = hashlib.sha1()
try:
# If we are running a script or are in IPython, include the source
# code in the hash.
source = inspect.getsource(function)
if sys.version_info[0] >= 3:
source = source.encode()
function_source_hasher.update(source)
function_source_hash = function_source_hasher.digest()
except (IOError, OSError, TypeError):
# Source code may not be available:
# e.g. Cython or Python interpreter.
function_source_hash = b""
pickled_function_hash = hashlib.sha1(pickled_function).digest()

return cls(module_name, function_name, class_name,
function_source_hash)
pickled_function_hash)

@classmethod
def from_class(cls, target_class):
Expand Down Expand Up @@ -315,6 +307,40 @@ def get_task_counter(self, job_id, function_descriptor):
job_id = ray.JobID.nil()
return self._num_task_executions[job_id][function_id]

def compute_collision_identifier(self, function_or_class):
"""The identifier is used to detect excessive duplicate exports.
The identifier is used to determine when the same function or class is
exported many times. This can yield false positives.
Args:
function_or_class: The function or class to compute an identifier
for.
Returns:
The identifier. Note that different functions or classes can give
rise to same identifier. However, the same function should
hopefully always give rise to the same identifier. TODO(rkn):
verify if this is actually the case. Note that if the
identifier is incorrect in any way, then we may give warnings
unnecessarily or fail to give warnings, but the application's
behavior won't change.
"""
if sys.version_info[0] >= 3:
import io
string_file = io.StringIO()
if sys.version_info[1] >= 7:
dis.dis(function_or_class, file=string_file, depth=2)
else:
dis.dis(function_or_class, file=string_file)
collision_identifier = (
function_or_class.__name__ + ":" + string_file.getvalue())
else:
collision_identifier = function_or_class.__name__

# Return a hash of the identifier in case it is too large.
return hashlib.sha1(collision_identifier.encode("ascii")).digest()

def export(self, remote_function):
"""Pickle a remote function and export it to redis.
Expand All @@ -339,9 +365,11 @@ def export(self, remote_function):
"job_id": self._worker.current_job_id.binary(),
"function_id": remote_function._function_descriptor.
function_id.binary(),
"name": remote_function._function_name,
"function_name": remote_function._function_name,
"module": function.__module__,
"function": pickled_function,
"collision_identifier": self.compute_collision_identifier(
function),
"max_calls": remote_function._max_calls
})
self._worker.redis_client.rpush("Exports", key)
Expand All @@ -351,8 +379,8 @@ def fetch_and_register_remote_function(self, key):
(job_id_str, function_id_str, function_name, serialized_function,
num_return_vals, module, resources,
max_calls) = self._worker.redis_client.hmget(key, [
"job_id", "function_id", "name", "function", "num_return_vals",
"module", "resources", "max_calls"
"job_id", "function_id", "function_name", "function",
"num_return_vals", "module", "resources", "max_calls"
])
function_id = ray.FunctionID(function_id_str)
job_id = ray.JobID(job_id_str)
Expand Down Expand Up @@ -549,6 +577,7 @@ def export_actor_class(self, Class, actor_method_names):
"module": Class.__module__,
"class": pickle.dumps(Class),
"job_id": job_id.binary(),
"collision_identifier": self.compute_collision_identifier(Class),
"actor_method_names": json.dumps(list(actor_method_names))
}

Expand Down
48 changes: 47 additions & 1 deletion python/ray/import_thread.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@
from __future__ import division
from __future__ import print_function

import redis
from collections import defaultdict
import threading
import traceback

import redis

import ray
from ray import ray_constants
from ray import cloudpickle as pickle
Expand All @@ -30,13 +32,19 @@ class ImportThread(object):
redis_client: the redis client used to query exports.
threads_stopped (threading.Event): A threading event used to signal to
the thread that it should exit.
imported_collision_identifiers: This is a dictionary mapping collision
identifiers for the exported remote functions and actor classes to
the number of times that collision identifier has appeared. This is
used to provide good error messages when the same function or class
is exported many times.
"""

def __init__(self, worker, mode, threads_stopped):
self.worker = worker
self.mode = mode
self.redis_client = worker.redis_client
self.threads_stopped = threads_stopped
self.imported_collision_identifiers = defaultdict(int)

def start(self):
"""Start the import thread."""
Expand Down Expand Up @@ -91,13 +99,51 @@ def _run(self):
# Close the pubsub client to avoid leaking file descriptors.
import_pubsub_client.close()

def _get_import_info_for_collision_detection(self, key):
"""Retrieve the collision identifier, type, and name of the import."""
if key.startswith(b"RemoteFunction"):
collision_identifier, function_name = (self.redis_client.hmget(
key, ["collision_identifier", "function_name"]))
return (collision_identifier, ray.utils.decode(function_name),
"remote function")
elif key.startswith(b"ActorClass"):
collision_identifier, class_name = self.redis_client.hmget(
key, ["collision_identifier", "class_name"])
return collision_identifier, ray.utils.decode(class_name), "actor"

def _process_key(self, key):
"""Process the given export key from redis."""
# Handle the driver case first.
if self.mode != ray.WORKER_MODE:
if key.startswith(b"FunctionsToRun"):
with profiling.profile("fetch_and_run_function"):
self.fetch_and_execute_function_to_run(key)

# If the same remote function or actor definition appears to be
# exported many times, then print a warning. We only issue this
# warning from the driver so that it is only triggered once instead
# of many times. TODO(rkn): We may want to push this to the driver
# through Redis so that it can be displayed in the dashboard more
# easily.
elif (key.startswith(b"RemoteFunction")
or key.startswith(b"ActorClass")):
collision_identifier, name, import_type = (
self._get_import_info_for_collision_detection(key))
self.imported_collision_identifiers[collision_identifier] += 1
if (self.imported_collision_identifiers[collision_identifier]
== ray_constants.DUPLICATE_REMOTE_FUNCTION_THRESHOLD):
logger.warning(
"The %s '%s' has been exported %s times. It's "
"possible that this warning is accidental, but this "
"may indicate that the same remote function is being "
"defined repeatedly from within many tasks and "
"exported to all of the workers. This can be a "
"performance issue and can be resolved by defining "
"the remote function on the driver instead. See "
"https://github.com/ray-project/ray/issues/6240 for "
"more discussion.", import_type, name,
ray_constants.DUPLICATE_REMOTE_FUNCTION_THRESHOLD)

# Return because FunctionsToRun are the only things that
# the driver should import.
return
Expand Down
4 changes: 4 additions & 0 deletions python/ray/ray_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,10 @@ def env_integer(key, default):
# greater than this quantity, print an warning.
PICKLE_OBJECT_WARNING_SIZE = 10**7

# If remote functions with the same source are imported this many times, then
# print a warning.
DUPLICATE_REMOTE_FUNCTION_THRESHOLD = 100

# The maximum resource quantity that is allowed. TODO(rkn): This could be
# relaxed, but the current implementation of the node manager will be slower
# for large resource quantities due to bookkeeping of specific resource IDs.
Expand Down
30 changes: 23 additions & 7 deletions python/ray/remote_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import logging
from functools import wraps

from ray import cloudpickle as pickle
from ray.function_manager import FunctionDescriptor
import ray.signature

Expand All @@ -24,7 +25,10 @@ class RemoteFunction(object):
Attributes:
_function: The original function.
_function_descriptor: The function descriptor.
_function_descriptor: The function descriptor. This is not defined
until the remote function is first invoked because that is when the
function is pickled, and the pickled function is used to compute
the function descriptor.
_function_name: The module and function name.
_num_cpus: The default number of CPUs to use for invocations of this
remote function.
Expand Down Expand Up @@ -57,9 +61,6 @@ class RemoteFunction(object):
def __init__(self, function, num_cpus, num_gpus, memory,
object_store_memory, resources, num_return_vals, max_calls):
self._function = function
self._function_descriptor = FunctionDescriptor.from_function(function)
self._function_descriptor_list = (
self._function_descriptor.get_function_descriptor_list())
self._function_name = (
self._function.__module__ + "." + self._function.__name__)
self._num_cpus = (DEFAULT_REMOTE_FUNCTION_CPUS
Expand Down Expand Up @@ -146,10 +147,25 @@ def _remote(self,
worker = ray.worker.get_global_worker()
worker.check_connected()

# If this function was not exported in this session and job, we need to
# export this function again, because the current GCS doesn't have it.
if self._last_export_session_and_job != worker.current_session_and_job:
# If this function was not exported in this session and job,
# we need to export this function again, because current GCS
# doesn't have it.
# There is an interesting question here. If the remote function is
# used by a subsequent driver (in the same script), should the
# second driver pickle the function again? If yes, then the remote
# function definition can differ in the second driver (e.g., if
# variables in its closure have changed). We probably want the
# behavior of the remote function in the second driver to be
# independent of whether or not the function was invoked by the
# first driver. This is an argument for repickling the function,
# which we do here.
self._pickled_function = pickle.dumps(self._function)

self._function_descriptor = FunctionDescriptor.from_function(
self._function, self._pickled_function)
self._function_descriptor_list = (
self._function_descriptor.get_function_descriptor_list())

self._last_export_session_and_job = worker.current_session_and_job
worker.function_actor_manager.export(self)

Expand Down
76 changes: 56 additions & 20 deletions python/ray/tests/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -964,26 +964,6 @@ def no_op():
def test_defining_remote_functions(shutdown_only):
ray.init(num_cpus=3)

# Test that we can define a remote function in the shell.
@ray.remote
def f(x):
return x + 1

assert ray.get(f.remote(0)) == 1

# Test that we can redefine the remote function.
@ray.remote
def f(x):
return x + 10

while True:
val = ray.get(f.remote(0))
assert val in [1, 10]
if val == 10:
break
else:
logger.info("Still using old definition of f, trying again.")

# Test that we can close over plain old data.
data = [
np.zeros([3, 5]), (1, 2, "a"), [0.0, 1.0, 1 << 62], 1 << 60, {
Expand Down Expand Up @@ -1029,6 +1009,62 @@ def m(x):
assert ray.get(m.remote(1)) == 2


def test_redefining_remote_functions(shutdown_only):
ray.init(num_cpus=1)

# Test that we can define a remote function in the shell.
@ray.remote
def f(x):
return x + 1

assert ray.get(f.remote(0)) == 1

# Test that we can redefine the remote function.
@ray.remote
def f(x):
return x + 10

while True:
val = ray.get(f.remote(0))
assert val in [1, 10]
if val == 10:
break
else:
logger.info("Still using old definition of f, trying again.")

# Check that we can redefine functions even when the remote function source
# doesn't change (see https://github.com/ray-project/ray/issues/6130).
@ray.remote
def g():
return nonexistent()

with pytest.raises(ray.exceptions.RayTaskError, match="nonexistent"):
ray.get(g.remote())

def nonexistent():
return 1

# Redefine the function and make sure it succeeds.
@ray.remote
def g():
return nonexistent()

assert ray.get(g.remote()) == 1

# Check the same thing but when the redefined function is inside of another
# task.
@ray.remote
def h(i):
@ray.remote
def j():
return i

return j.remote()

for i in range(20):
assert ray.get(ray.get(h.remote(i))) == i


@pytest.mark.skipif(RAY_FORCE_DIRECT, reason="reconstruction not implemented")
def test_submit_api(shutdown_only):
ray.init(num_cpus=2, num_gpus=1, resources={"Custom": 1})
Expand Down
Loading

0 comments on commit ffb9c0e

Please sign in to comment.