Skip to content

Commit

Permalink
Cache call site
Browse files Browse the repository at this point in the history
  • Loading branch information
mwtian committed Aug 17, 2021
1 parent 9136bb9 commit 676bfa0
Show file tree
Hide file tree
Showing 6 changed files with 113 additions and 79 deletions.
7 changes: 7 additions & 0 deletions python/ray/_raylet.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,13 @@ cdef class CoreWorker:
owner_address=*,
c_bool inline_small_object=*)
cdef unique_ptr[CAddress] _convert_python_address(self, address=*)
cpdef add_object_ref_reference(self, ObjectRef object_ref,
const c_string &call_site)
cdef deserialize_and_register_object_ref(
self, CObjectID object_id,
CObjectID outer_object_id,
const c_string &serialized_owner_address,
const c_string &serialized_object_status)
cdef store_task_outputs(
self, worker, outputs, const c_vector[CObjectID] return_ids,
c_vector[shared_ptr[CRayObject]] *returns)
Expand Down
93 changes: 71 additions & 22 deletions python/ray/_raylet.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -847,14 +847,20 @@ cdef void get_py_stack(c_string* stack_out) nogil:
stack_out[0] = "".encode("ascii")
return
msg_frames = []
seen_worker = False
while frame and len(msg_frames) < 4:
filename = frame.f_code.co_filename
# Decode Ray internal frames to add annotations.
if filename.endswith("ray/worker.py"):
if frame.f_code.co_name == "put":
msg_frames = ["(put object) "]
elif filename.endswith("ray/workers/default_worker.py"):
pass
elif frame.f_code.co_name == "deserialize_objects":
msg_frames = ["(deserialize object) "]
elif not seen_worker:
msg_frames.append("({}:{}:{}, omitting rest from the file)".format(
frame.f_code.co_filename, frame.f_code.co_name,
frame.f_lineno))
seen_worker = True
elif filename.endswith("ray/remote_function.py"):
# TODO(ekl) distinguish between task return objects and
# arguments. This can only be done in the core worker.
Expand All @@ -863,16 +869,18 @@ cdef void get_py_stack(c_string* stack_out) nogil:
# TODO(ekl) distinguish between actor return objects and
# arguments. This can only be done in the core worker.
msg_frames = ["(actor call) "]
elif filename.endswith("ray/serialization.py"):
if frame.f_code.co_name == "id_deserializer":
msg_frames = ["(deserialize task arg) "]
elif filename.endswith("ray/workers/default_worker.py"):
pass
elif filename.endswith("ray/_private/client_mode_hook.py"):
pass
else:
msg_frames.append("{}:{}:{}".format(
frame.f_code.co_filename, frame.f_code.co_name,
frame.f_lineno))
frame = frame.f_back
stack_out[0] = " | ".join(msg_frames).encode("ascii")


cdef shared_ptr[CBuffer] string_to_buffer(c_string& c_str):
cdef shared_ptr[CBuffer] empty_metadata
if c_str.size() == 0:
Expand Down Expand Up @@ -1595,8 +1603,7 @@ cdef class CoreWorker:
worker.current_session_and_job)

def deserialize_and_register_actor_handle(self, const c_string &bytes,
ObjectRef
outer_object_ref):
ObjectRef outer_object_ref):
cdef:
CObjectID c_outer_object_id = (outer_object_ref.native() if
outer_object_ref else
Expand Down Expand Up @@ -1653,10 +1660,16 @@ cdef class CoreWorker:
actor_id.native(), &output, &c_actor_handle_id))
return output, ObjectRef(c_actor_handle_id.Binary())

def add_object_ref_reference(self, ObjectRef object_ref):
cpdef add_object_ref_reference(self, ObjectRef object_ref,
const c_string &call_site):
# Note: faster to not release GIL for short-running op.
CCoreWorkerProcess.GetCoreWorker().AddLocalReference(
object_ref.native())
if call_site.empty():
CCoreWorkerProcess.GetCoreWorker().AddLocalReference(
object_ref.native())
else:
CCoreWorkerProcess.GetCoreWorker().AddLocalReference(
object_ref.native(), call_site)


def remove_object_ref_reference(self, ObjectRef object_ref):
cdef:
Expand All @@ -1679,25 +1692,21 @@ cdef class CoreWorker:
c_owner_address.SerializeAsString(),
serialized_object_status)

def deserialize_and_register_object_ref(
self, const c_string &object_ref_binary,
ObjectRef outer_object_ref,
cdef deserialize_and_register_object_ref(
self, CObjectID object_id,
CObjectID outer_object_id,
const c_string &serialized_owner_address,
const c_string &serialized_object_status,
):
cdef:
CObjectID c_object_id = CObjectID.FromBinary(object_ref_binary)
CObjectID c_outer_object_id = (outer_object_ref.native() if
outer_object_ref else
CObjectID.Nil())
CAddress c_owner_address = CAddress()
CAddress owner_address = CAddress()

c_owner_address.ParseFromString(serialized_owner_address)
owner_address.ParseFromString(serialized_owner_address)
(CCoreWorkerProcess.GetCoreWorker()
.RegisterOwnershipInfoAndResolveFuture(
c_object_id,
c_outer_object_id,
c_owner_address,
object_id,
outer_object_id,
owner_address,
serialized_object_status))

cdef store_task_outputs(
Expand Down Expand Up @@ -1942,3 +1951,43 @@ cdef void async_callback(shared_ptr[CRayObject] obj,
py_callback = <object>user_callback
py_callback(result)
cpython.Py_DECREF(py_callback)
# Per-object context used while deserializing the object refs and actor
# handles that are contained in another (outer) object.
cdef class DeserializationInfo:
    # ObjectRef of the enclosing object, exactly as passed to the constructor
    # (may be None).
    cdef public ObjectRef outer_object_ref
    # C++ ID of the enclosing object; Nil when no outer ref was supplied.
    cdef CObjectID outer_object_id
    # Call site string. Left empty by the constructor and only filled in when
    # needed, because it is expensive to compute.
    cdef c_string call_site

    def __cinit__(self, ObjectRef object_ref):
        self.outer_object_ref = object_ref
        if object_ref:
            self.outer_object_id = object_ref.native()
        else:
            self.outer_object_id = CObjectID.Nil()
cpdef object_ref_deserializer(bytes binary, bytes owner_address,
                              bytes object_status):
    """Rebuild an ObjectRef from its serialized parts.

    Args:
        binary: Bytes handed to ``ray.ObjectRef`` to construct the ref.
        owner_address: Serialized owner address. When empty, the ref is
            returned without being registered with the core worker.
        object_status: Serialized object status, forwarded unchanged to the
            core worker together with the owner address.

    Returns:
        The newly constructed ObjectRef.
    """
    cdef:
        DeserializationInfo info
        ObjectRef obj_ref
        CoreWorker core_worker

    worker = ray.worker.global_worker
    # The serialization context yields None when this ObjectRef was closed
    # over in a function or pickled directly using pickle.dumps().
    serialization_context = worker.get_serialization_context()
    info = serialization_context.get_deserialization_info()
    if not info:
        info = DeserializationInfo(None)
    # Fill in the call site only if this info object does not carry one yet,
    # so refs sharing the same info reuse the already computed string.
    if info.call_site.empty():
        get_py_stack(&info.call_site)
    # NOTE(swang): Must deserialize the object first before asking
    # the core worker to resolve the value. This is to make sure
    # that the ref count for the ObjectRef is greater than 0 by the
    # time the core worker resolves the value of the object.
    obj_ref = ray.ObjectRef(binary, info.call_site)
    core_worker = worker.core_worker
    if owner_address:
        core_worker.deserialize_and_register_object_ref(
            obj_ref.native(),
            info.outer_object_id,
            owner_address,
            object_status)
    return obj_ref
2 changes: 2 additions & 0 deletions python/ray/includes/libcoreworker.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,8 @@ cdef extern from "ray/core_worker/core_worker.h" nogil:
pair[c_vector[c_pair[c_string, c_string]], CRayStatus] ListNamedActors(
c_bool all_namespaces)
void AddLocalReference(const CObjectID &object_id)
void AddLocalReference(const CObjectID &object_id,
const c_string &call_site)
void RemoveLocalReference(const CObjectID &object_id)
void PutObjectIntoPlasma(const CRayObject &object,
const CObjectID &object_id)
Expand Down
4 changes: 2 additions & 2 deletions python/ray/includes/object_ref.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def _set_future_helper(

cdef class ObjectRef(BaseID):

def __init__(self, id):
def __init__(self, id, c_string call_site = b""):
check_id(id)
self.data = CObjectID.FromBinary(<c_string>id)
self.in_core_worker = False
Expand All @@ -45,7 +45,7 @@ cdef class ObjectRef(BaseID):
# But there are still some dummy object refs being created outside the
# context of a core worker.
if hasattr(worker, "core_worker"):
worker.core_worker.add_object_ref_reference(self)
worker.core_worker.add_object_ref_reference(self, call_site)
self.in_core_worker = True

def __dealloc__(self):
Expand Down
67 changes: 22 additions & 45 deletions python/ray/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@
WorkerCrashedError, ObjectLostError,
RaySystemError, RuntimeEnvSetupError)
from ray._raylet import (
object_ref_deserializer,
split_buffer,
unpack_pickle5_buffers,
DeserializationInfo,
Pickle5Writer,
Pickle5SerializedObject,
MessagePackSerializer,
Expand All @@ -28,41 +30,13 @@ class DeserializationError(Exception):
pass


def _object_ref_deserializer(binary, owner_address, object_status):
# NOTE(suquark): This function should be a global function so
# cloudpickle can access it directly. Otherwise cloudpickle
# has to dump the whole function definition, which is inefficient.

# NOTE(swang): Must deserialize the object first before asking
# the core worker to resolve the value. This is to make sure
# that the ref count for the ObjectRef is greater than 0 by the
# time the core worker resolves the value of the object.
obj_ref = ray.ObjectRef(binary)

# TODO(edoakes): we should be able to just capture a reference
# to 'self' here instead, but this function is itself pickled
# somewhere, which causes an error.
if owner_address:
worker = ray.worker.global_worker
worker.check_connected()
context = worker.get_serialization_context()
outer_id = context.get_outer_object_ref()
# outer_id is None in the case that this ObjectRef was closed
# over in a function or pickled directly using pickle.dumps().
if outer_id is None:
outer_id = ray.ObjectRef.nil()
worker.core_worker.deserialize_and_register_object_ref(
obj_ref.binary(), outer_id, owner_address, object_status)
return obj_ref


def _actor_handle_deserializer(serialized_obj):
# If this actor handle was stored in another object, then tell the
# core worker.
context = ray.worker.global_worker.get_serialization_context()
outer_id = context.get_outer_object_ref()
info = context.get_deserialization_info()
return ray.actor.ActorHandle._deserialization_helper(
serialized_obj, outer_id)
serialized_obj, info.outer_object_ref if info else None)


class SerializationContext:
Expand Down Expand Up @@ -91,7 +65,7 @@ def object_ref_reducer(obj):
worker.check_connected()
obj, owner_address, object_status = (
worker.core_worker.serialize_and_promote_object_ref(obj))
return _object_ref_deserializer, \
return object_ref_deserializer, \
(obj.binary(), owner_address, object_status)

self._register_cloudpickle_reducer(ray.ObjectRef, object_ref_reducer)
Expand Down Expand Up @@ -120,10 +94,6 @@ def set_in_band_serialization(self):
def set_out_of_band_serialization(self):
self._thread_local.in_band = False

def get_outer_object_ref(self):
stack = getattr(self._thread_local, "object_ref_stack", [])
return stack[-1] if stack else None

def get_and_clear_contained_object_refs(self):
if not hasattr(self._thread_local, "object_refs"):
self._thread_local.object_refs = set()
Expand All @@ -133,6 +103,14 @@ def get_and_clear_contained_object_refs(self):
self._thread_local.object_refs = set()
return object_refs

def get_deserialization_info(self):
    """Return the current thread's deserialization info, or None if unset."""
    thread_state = self._thread_local
    try:
        return thread_state.deserialization_info
    except AttributeError:
        # Nothing has been stored on this thread yet.
        return None

def switch_deserialization_info(self, info: DeserializationInfo):
    """Install ``info`` as this thread's deserialization info.

    Returns the value that was stored before the call (None if nothing was
    stored), so the caller can restore it afterwards.
    """
    thread_state = self._thread_local
    previous = getattr(thread_state, "deserialization_info", None)
    thread_state.deserialization_info = info
    return previous

def add_contained_object_ref(self, object_ref):
if self.is_in_band_serialization():
# This object ref is being stored in an object. Add the ID to the
Expand All @@ -147,7 +125,7 @@ def add_contained_object_ref(self, object_ref):
# then pin the object for the lifetime of this worker by adding
# a local reference that won't ever be removed.
ray.worker.global_worker.core_worker.add_object_ref_reference(
object_ref)
object_ref, "")

def _deserialize_pickle5_data(self, data):
try:
Expand Down Expand Up @@ -241,24 +219,23 @@ def _deserialize_object(self, data, metadata, object_ref):

def deserialize_objects(self, data_metadata_pairs, object_refs):
assert len(data_metadata_pairs) == len(object_refs)
# initialize the thread-local field
if not hasattr(self._thread_local, "object_ref_stack"):
self._thread_local.object_ref_stack = []
results = []
for object_ref, (data, metadata) in zip(object_refs,
data_metadata_pairs):
try:
# Push the object ref to the stack, so the object under
# the object ref knows where it comes from.
self._thread_local.object_ref_stack.append(object_ref)
# Store the outer object ref in thread local, so inner object
# refs know where they come from.
prev = self.switch_deserialization_info(
DeserializationInfo(object_ref))
obj = self._deserialize_object(data, metadata, object_ref)
except Exception as e:
logger.exception(e)
obj = RaySystemError(e, traceback.format_exc())
finally:
# Must clear ObjectRef to not hold a reference.
if self._thread_local.object_ref_stack:
self._thread_local.object_ref_stack.pop()
# Restore the previous thread-local info, both for correctness and to
# avoid holding on to the outer object ref.
info = self.switch_deserialization_info(prev)
assert info.outer_object_ref == object_ref
results.append(obj)
return results

Expand Down
19 changes: 9 additions & 10 deletions src/ray/core_worker/core_worker.h
Original file line number Diff line number Diff line change
Expand Up @@ -410,11 +410,19 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler {

void SetCallerCreationTimestamp();

/// Increase the local reference count for this object ID. Should be called
/// by the language frontend when a new reference is created.
///
/// This overload forwards the given call site unchanged to the reference
/// counter.
///
/// \param[in] object_id The object ID to increase the reference count for.
/// \param[in] call_site The call site from the language frontend, for debugging.
void AddLocalReference(const ObjectID &object_id, const std::string &call_site) {
reference_counter_->AddLocalReference(object_id, call_site);
}

/// Same as the (object_id, call_site) overload, except that the call site is
/// not supplied by the caller: it is taken from `CurrentCallSite()`, i.e. from
/// the language frontend via `CoreWorkerOptions.get_lang_stack`.
///
/// \param[in] object_id The object ID to increase the reference count for.
void AddLocalReference(const ObjectID &object_id) {
AddLocalReference(object_id, CurrentCallSite());
}
Expand Down Expand Up @@ -1085,15 +1093,6 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler {
/// Private methods related to task submission.
///

/// Increase the local reference count for this object ID. Should be called
/// by the language frontend when a new reference is created.
///
/// \param[in] object_id The object ID to increase the reference count for.
/// \param[in] call_site The call site from the language frontend.
void AddLocalReference(const ObjectID &object_id, std::string call_site) {
reference_counter_->AddLocalReference(object_id, call_site);
}

/// Stops the children tasks from the given TaskID
///
/// \param[in] task_id of the parent task
Expand Down

0 comments on commit 676bfa0

Please sign in to comment.