Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[core] out of band serialization exception #47544

Merged
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# __anti_pattern_start__
import ray
import pickle
from ray._private.internal_api import memory_summary
import ray.exceptions

ray.init()


@ray.remote
def out_of_band_serialization_pickle():
obj_ref = ray.put(1)
import pickle

# object_ref is serialized from user code using a regular pickle.
# Ray cannot keep track of the reference, so the underlying object
rkooo567 marked this conversation as resolved.
Show resolved Hide resolved
# can be GC'ed unexpectedly which can cause unexpected hangs.
rkooo567 marked this conversation as resolved.
Show resolved Hide resolved
return pickle.dumps(obj_ref)


@ray.remote
def out_of_band_serialization_ray_cloudpickle():
obj_ref = ray.put(1)
from ray import cloudpickle

# ray.cloudpickle can serialize only when
# RAY_allow_out_of_band_object_ref_serialization=1 env var is set.
# However, the object_ref is pinned for the lifetime of the worker
rkooo567 marked this conversation as resolved.
Show resolved Hide resolved
# which can cause Ray object leaks that can cause spilling.
return cloudpickle.dumps(obj_ref)


print("==== serialize object ref with pickle ====")
result = ray.get(out_of_band_serialization_pickle.remote())
try:
ray.get(pickle.loads(result), timeout=5)
except ray.exceptions.GetTimeoutError:
print("Underlying object is unexpectedly GC'ed!\n\n")

print("==== serialize object ref with ray.cloudpickle ====")
# By default, it is not allowed to serialize ray.ObjectRef using
rkooo567 marked this conversation as resolved.
Show resolved Hide resolved
# ray.cloudpickle.
try:
ray.get(out_of_band_serialization_ray_cloudpickle.remote())
except Exception as e:
print(f"Exception raised from out_of_band_serialization_ray_cloudpickle {e}\n\n")

print(
"==== serialize object ref with ray.cloudpickle with env var "
"RAY_allow_out_of_band_object_ref_serialization ===="
)
# It is allowed to use ray.cloudpickle to serialize object ref using an env var.
rkooo567 marked this conversation as resolved.
Show resolved Hide resolved
ray.get(
out_of_band_serialization_ray_cloudpickle.options(
runtime_env={
"env_vars": {
"RAY_allow_out_of_band_object_ref_serialization": "1",
}
}
).remote()
)
# you can see objects are stil pinned although it is GC'ed and not used anymore.
rkooo567 marked this conversation as resolved.
Show resolved Hide resolved
print(memory_summary())
# __anti_pattern_end__
1 change: 1 addition & 0 deletions doc/source/ray-core/patterns/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,4 @@ This section is a collection of common design patterns and anti-patterns for wri
pass-large-arg-by-value
closure-capture-large-objects
global-variables
out-of-band-object-ref-serialization
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
.. _ray-out-of-band-object-ref-serialization:

Anti-pattern: Serialize ray.ObjectRef out of band
rkooo567 marked this conversation as resolved.
Show resolved Hide resolved
=================================================

**TLDR:** Avoid serialize ``ray.ObjectRef`` because Ray cannot know when the GC the underlying object.
rkooo567 marked this conversation as resolved.
Show resolved Hide resolved

Ray's ``ray.ObjectRef`` is distributed reference counted. Ray pins the underlying object until the reference is not used by the system anymore.
rkooo567 marked this conversation as resolved.
Show resolved Hide resolved
When all references are gone, the pinned object is garbage collected and cleaned up from the system.
rkooo567 marked this conversation as resolved.
Show resolved Hide resolved
However, if user code serializes ``ray.objectRef``, Ray cannot keep track of the reference.
rkooo567 marked this conversation as resolved.
Show resolved Hide resolved

To avoid incorrect behavior, if ``ray.ObjectRef`` is serialized by ``ray.cloudpickle``, Ray pins the object for the lifetime of a worker. "pin" means that object cannot be evicted from the object store
rkooo567 marked this conversation as resolved.
Show resolved Hide resolved
until the corresponding worker dies. It is prone to Ray object leaks, which can lead disk spilling.
rkooo567 marked this conversation as resolved.
Show resolved Hide resolved

To detect if this pattern exists in your code, you can set an environment variable ``RAY_allow_out_of_band_object_ref_serialization=0``. If Ray detects
``ray.ObjectRef`` is serialized by ``ray.cloudpickle``, it raises an exception with helpful messages.
rkooo567 marked this conversation as resolved.
Show resolved Hide resolved

Code example
------------

**Anti-pattern:**

.. literalinclude:: ../doc_code/anti_pattern_out_of_band_object_ref_serialization.py
:language: python
:start-after: __anti_pattern_start__
:end-before: __anti_pattern_end__
56 changes: 43 additions & 13 deletions python/ray/_private/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
import logging
import threading
import traceback
from typing import Any
from typing import Any, Optional
import os


import google.protobuf.message

Expand Down Expand Up @@ -48,11 +50,15 @@
OutOfMemoryError,
ObjectRefStreamEndOfStreamError,
)
import ray.exceptions
from ray.experimental.compiled_dag_ref import CompiledDAGRef
from ray.util import serialization_addons
from ray.util import inspect_serializability

logger = logging.getLogger(__name__)
ALLOW_OUT_OF_BAND_OBJECT_REF_SERIALIZATION = bool(
int(os.getenv("RAY_allow_out_of_band_object_ref_serialization", "1"))
rkooo567 marked this conversation as resolved.
Show resolved Hide resolved
)


class DeserializationError(Exception):
Expand All @@ -65,11 +71,14 @@ def pickle_dumps(obj: Any, error_msg: str):
"""
try:
return pickle.dumps(obj)
except TypeError as e:
except (TypeError, ray.exceptions.OufOfBandRefSerializationException) as e:
sio = io.StringIO()
inspect_serializability(obj, print_file=sio)
msg = f"{error_msg}:\n{sio.getvalue()}"
raise TypeError(msg) from e
if isinstance(e, TypeError):
raise TypeError(msg) from e
else:
raise ray.exceptions.OufOfBandRefSerializationException(msg)


def _object_ref_deserializer(binary, call_site, owner_address, object_status):
Expand Down Expand Up @@ -127,7 +136,7 @@ def actor_handle_reducer(obj):
serialized, actor_handle_id, weak_ref = obj._serialization_helper()
# Update ref counting for the actor handle
if not weak_ref:
self.add_contained_object_ref(actor_handle_id)
self.add_contained_object_ref(actor_handle_id, True)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why it's always True for this case

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it crashes so many tests now. And I think the leak is probably very minimal for actor handle. I will add comments.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(it doesn't leak actual actors)

return _actor_handle_deserializer, (serialized, weak_ref)

self._register_cloudpickle_reducer(ray.actor.ActorHandle, actor_handle_reducer)
Expand All @@ -140,7 +149,11 @@ def compiled_dag_ref_reducer(obj):
def object_ref_reducer(obj):
worker = ray._private.worker.global_worker
worker.check_connected()
self.add_contained_object_ref(obj)
self.add_contained_object_ref(
obj,
ALLOW_OUT_OF_BAND_OBJECT_REF_SERIALIZATION,
call_site=obj.call_site(),
)
obj, owner_address, object_status = worker.core_worker.serialize_object_ref(
obj
)
Expand Down Expand Up @@ -199,7 +212,12 @@ def get_and_clear_contained_object_refs(self):
self._thread_local.object_refs = set()
return object_refs

def add_contained_object_ref(self, object_ref):
def add_contained_object_ref(
self,
object_ref: "ray.ObjectRef",
allow_out_of_band_serialization: bool,
call_site: Optional[str] = None,
):
if self.is_in_band_serialization():
# This object ref is being stored in an object. Add the ID to the
# list of IDs contained in the object so that we keep the inner
Expand All @@ -208,13 +226,25 @@ def add_contained_object_ref(self, object_ref):
self._thread_local.object_refs = set()
self._thread_local.object_refs.add(object_ref)
else:
# If this serialization is out-of-band (e.g., from a call to
# cloudpickle directly or captured in a remote function/actor),
# then pin the object for the lifetime of this worker by adding
# a local reference that won't ever be removed.
ray._private.worker.global_worker.core_worker.add_object_ref_reference(
object_ref
)
if not allow_out_of_band_serialization:
raise ray.exceptions.OufOfBandRefSerializationException(
f"It is not allowed to serialize ray.ObjectRef {object_ref.hex()}. "
"If you want to allow serialization, "
"set `RAY_allow_out_of_band_object_ref_serialization=1.` "
"If you set the env var, the object is pinned forever in the "
"lifetime of the worker process and can cause Ray object leaks. "
"See the callsite and trace to find where the serialization "
"occurs.\nCallsite: "
f"{call_site or 'Disabled. Set RAY_record_ref_creation_sites=1'}"
)
else:
# If this serialization is out-of-band (e.g., from a call to
# cloudpickle directly or captured in a remote function/actor),
# then pin the object for the lifetime of this worker by adding
# a local reference that won't ever be removed.
ray._private.worker.global_worker.core_worker.add_object_ref_reference(
object_ref
)

def _deserialize_pickle5_data(self, data):
try:
Expand Down
25 changes: 16 additions & 9 deletions python/ray/data/_internal/planner/plan_read_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from typing import Iterable, List

import ray
import ray.cloudpickle as cloudpickle
from ray.data._internal.compute import TaskPoolStrategy
from ray.data._internal.execution.interfaces import PhysicalOperator, RefBundle
from ray.data._internal.execution.interfaces.task_context import TaskContext
Expand All @@ -20,15 +19,21 @@
from ray.data._internal.util import _warn_on_high_parallelism
from ray.data.block import Block, BlockMetadata
from ray.data.datasource.datasource import ReadTask
from ray.experimental.locations import get_local_object_locations
from ray.util.debug import log_once

TASK_SIZE_WARN_THRESHOLD_BYTES = 1024 * 1024 # 1 MiB

logger = logging.getLogger(__name__)


def cleaned_metadata(read_task: ReadTask) -> BlockMetadata:
task_size = len(cloudpickle.dumps(read_task))
def cleaned_metadata(read_task: ReadTask, read_task_ref) -> BlockMetadata:
# NOTE: Use the `get_local_object_locations` API to get the size of the
# serialized ReadTask, instead of pickling.
# Because the ReadTask may capture ObjectRef objects, which cannot
# be serialized out-of-band.
locations = get_local_object_locations([read_task_ref])
task_size = locations[read_task_ref]["object_size"]
if task_size > TASK_SIZE_WARN_THRESHOLD_BYTES and log_once(
f"large_read_task_{read_task.read_fn.__name__}"
):
Expand Down Expand Up @@ -68,23 +73,25 @@ def get_input_data(target_max_block_size) -> List[RefBundle]:
read_tasks = op._datasource_or_legacy_reader.get_read_tasks(parallelism)
_warn_on_high_parallelism(parallelism, len(read_tasks))

return [
RefBundle(
ret = []
for read_task in read_tasks:
read_task_ref = ray.put(read_task)
ref_bundle = RefBundle(
[
(
# TODO(chengsu): figure out a better way to pass read
# tasks other than ray.put().
ray.put(read_task),
cleaned_metadata(read_task),
read_task_ref,
cleaned_metadata(read_task, read_task_ref),
)
],
# `owns_blocks` is False, because these refs are the root of the
# DAG. We shouldn't eagerly free them. Otherwise, the DAG cannot
# be reconstructed.
owns_blocks=False,
)
for read_task in read_tasks
]
ret.append(ref_bundle)
return ret

inputs = InputDataBuffer(
input_data_factory=get_input_data,
Expand Down
10 changes: 10 additions & 0 deletions python/ray/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -829,6 +829,15 @@ class ObjectRefStreamEndOfStreamError(RayError):
pass


@DeveloperAPI
class OufOfBandRefSerializationException(RayError):
rkooo567 marked this conversation as resolved.
Show resolved Hide resolved
"""Raised when an `ray.ObjectRef` is out of band serialized by
`ray.cloudpickle`. It is an anti pattern.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

link to the doc?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can't yet (because doc is not there yet). we should do in a follow up

"""

pass


@PublicAPI(stability="alpha")
class RayChannelError(RaySystemError):
"""Indicates that Ray encountered a system error related
Expand Down Expand Up @@ -879,5 +888,6 @@ class RayAdagCapacityExceeded(RaySystemError):
ActorUnavailableError,
RayChannelError,
RayChannelTimeoutError,
OufOfBandRefSerializationException,
RayAdagCapacityExceeded,
]
49 changes: 49 additions & 0 deletions python/ray/tests/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

import ray
import ray.cluster_utils
import ray.exceptions
from ray import cloudpickle

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -733,6 +735,53 @@ def test(x, expect):
assert dataclasses.asdict(new_y) == expect_dict


def test_cannot_out_of_band_serialize_object_ref(shutdown_only, monkeypatch):
monkeypatch.setenv("RAY_allow_out_of_band_object_ref_serialization", "0")
ray.init()
ref = ray.put(1)

@ray.remote
def f():
ref

with pytest.raises(ray.exceptions.OufOfBandRefSerializationException):
ray.get(f.remote())

@ray.remote
def f():
cloudpickle.dumps(ray.put(1))

with pytest.raises(ray.exceptions.OufOfBandRefSerializationException):
ray.get(f.remote())


def test_can_out_of_band_serialize_object_ref_with_env_var(shutdown_only, monkeypatch):
monkeypatch.setenv("RAY_allow_out_of_band_object_ref_serialization", "1")
ray.init()

# Use ray.remote as a workaround because
# RAY_allow_out_of_band_object_ref_serialization cannot be set dynamically.
@ray.remote
def test():
ref = ray.put(1)

@ray.remote
def f():
ref

ray.get(f.remote())

@ray.remote
def f():
ref = ray.put(1)
cloudpickle.dumps(ref)

ray.get(f.remote())

# It should pass.
ray.get(test.remote())


if __name__ == "__main__":
import os
import pytest
Expand Down