diff --git a/doc/source/ray-core/doc_code/anti_pattern_out_of_band_object_ref_serialization.py b/doc/source/ray-core/doc_code/anti_pattern_out_of_band_object_ref_serialization.py new file mode 100644 index 000000000000..fbde67477480 --- /dev/null +++ b/doc/source/ray-core/doc_code/anti_pattern_out_of_band_object_ref_serialization.py @@ -0,0 +1,64 @@ +# __anti_pattern_start__ +import ray +import pickle +from ray._private.internal_api import memory_summary +import ray.exceptions + +ray.init() + + +@ray.remote +def out_of_band_serialization_pickle(): + obj_ref = ray.put(1) + import pickle + + # object_ref is serialized from user code using a regular pickle. + # Ray can't keep track of the reference, so the underlying object + # can be GC'ed unexpectedly, which can cause unexpected hangs. + return pickle.dumps(obj_ref) + + +@ray.remote +def out_of_band_serialization_ray_cloudpickle(): + obj_ref = ray.put(1) + from ray import cloudpickle + + # ray.cloudpickle can serialize only when + # RAY_allow_out_of_band_object_ref_serialization=1 env var is set. + # However, the object_ref is pinned for the lifetime of the worker, + # which can cause Ray object leaks that can cause spilling. + return cloudpickle.dumps(obj_ref) + + +print("==== serialize object ref with pickle ====") +result = ray.get(out_of_band_serialization_pickle.remote()) +try: + ray.get(pickle.loads(result), timeout=5) +except ray.exceptions.GetTimeoutError: + print("Underlying object is unexpectedly GC'ed!\n\n") + +print("==== serialize object ref with ray.cloudpickle ====") +# By default, it's allowed to serialize ray.ObjectRef using +# ray.cloudpickle. +ray.get(out_of_band_serialization_ray_cloudpickle.options().remote()) +# you can see objects are stil pinned although it's GC'ed and not used anymore. 
+print(memory_summary()) + +print( + "==== serialize object ref with ray.cloudpickle with env var " + "RAY_allow_out_of_band_object_ref_serialization=0 for debugging ====" +) +try: + ray.get( + out_of_band_serialization_ray_cloudpickle.options( + runtime_env={ + "env_vars": { + "RAY_allow_out_of_band_object_ref_serialization": "0", + } + } + ).remote() + ) +except Exception as e: + print(f"Exception raised from out_of_band_serialization_ray_cloudpickle {e}\n\n") + +# __anti_pattern_end__ diff --git a/doc/source/ray-core/objects/serialization.rst b/doc/source/ray-core/objects/serialization.rst index b41c14f4b43f..ef0540456e88 100644 --- a/doc/source/ray-core/objects/serialization.rst +++ b/doc/source/ray-core/objects/serialization.rst @@ -23,6 +23,8 @@ Plasma is used to efficiently transfer objects across different processes and di Each node has its own object store. When data is put into the object store, it does not get automatically broadcasted to other nodes. Data remains local to the writer until requested by another task or actor on another node. +.. _serialize-object-ref: + Serializing ObjectRefs ~~~~~~~~~~~~~~~~~~~~~~ diff --git a/doc/source/ray-core/patterns/index.rst b/doc/source/ray-core/patterns/index.rst index c94eb0aa44b9..cc3c7e7782be 100644 --- a/doc/source/ray-core/patterns/index.rst +++ b/doc/source/ray-core/patterns/index.rst @@ -26,3 +26,4 @@ This section is a collection of common design patterns and anti-patterns for wri pass-large-arg-by-value closure-capture-large-objects global-variables + out-of-band-object-ref-serialization diff --git a/doc/source/ray-core/patterns/out-of-band-object-ref-serialization.rst b/doc/source/ray-core/patterns/out-of-band-object-ref-serialization.rst new file mode 100644 index 000000000000..1c1c85241fe7 --- /dev/null +++ b/doc/source/ray-core/patterns/out-of-band-object-ref-serialization.rst @@ -0,0 +1,26 @@ +.. 
_ray-out-of-band-object-ref-serialization: + +Anti-pattern: Serialize ray.ObjectRef out of band +================================================= + +**TLDR:** Avoid serializing ``ray.ObjectRef`` because Ray can't know when to garbage collect the underlying object. + +Ray's ``ray.ObjectRef`` uses distributed reference counting. Ray pins the underlying object until the reference isn't used by the system anymore. +When all references to the pinned object are gone, Ray garbage collects the pinned object and cleans it up from the system. +However, if user code serializes ``ray.ObjectRef``, Ray can't keep track of the reference. + +To avoid incorrect behavior, if ``ray.cloudpickle`` serializes ``ray.ObjectRef``, Ray pins the object for the lifetime of a worker. "Pin" means that the object can't be evicted from the object store +until the corresponding owner worker dies. It's prone to Ray object leaks, which can lead to disk spilling. See :ref:`this page <serialize-object-ref>` for more details. + +To detect if this pattern exists in your code, you can set an environment variable ``RAY_allow_out_of_band_object_ref_serialization=0``. If Ray detects +that ``ray.cloudpickle`` serialized ``ray.ObjectRef``, it raises an exception with helpful messages. + +Code example +------------ + +**Anti-pattern:** + +.. 
literalinclude:: ../doc_code/anti_pattern_out_of_band_object_ref_serialization.py + :language: python + :start-after: __anti_pattern_start__ + :end-before: __anti_pattern_end__ diff --git a/python/ray/_private/serialization.py b/python/ray/_private/serialization.py index 397005b45f41..88e6648d012a 100644 --- a/python/ray/_private/serialization.py +++ b/python/ray/_private/serialization.py @@ -2,7 +2,8 @@ import logging import threading import traceback -from typing import Any +from typing import Any, Optional + import google.protobuf.message @@ -48,11 +49,15 @@ OutOfMemoryError, ObjectRefStreamEndOfStreamError, ) +import ray.exceptions from ray.experimental.compiled_dag_ref import CompiledDAGRef from ray.util import serialization_addons from ray.util import inspect_serializability logger = logging.getLogger(__name__) +ALLOW_OUT_OF_BAND_OBJECT_REF_SERIALIZATION = ray_constants.env_bool( + "RAY_allow_out_of_band_object_ref_serialization", True +) class DeserializationError(Exception): @@ -65,11 +70,14 @@ def pickle_dumps(obj: Any, error_msg: str): """ try: return pickle.dumps(obj) - except TypeError as e: + except (TypeError, ray.exceptions.OufOfBandObjectRefSerializationException) as e: sio = io.StringIO() inspect_serializability(obj, print_file=sio) msg = f"{error_msg}:\n{sio.getvalue()}" - raise TypeError(msg) from e + if isinstance(e, TypeError): + raise TypeError(msg) from e + else: + raise ray.exceptions.OufOfBandObjectRefSerializationException(msg) def _object_ref_deserializer(binary, call_site, owner_address, object_status): @@ -127,7 +135,12 @@ def actor_handle_reducer(obj): serialized, actor_handle_id, weak_ref = obj._serialization_helper() # Update ref counting for the actor handle if not weak_ref: - self.add_contained_object_ref(actor_handle_id) + self.add_contained_object_ref( + actor_handle_id, + # Right now, so many tests are failing when this is set. + # Allow it for now, but we should eventually disallow it here. 
+ allow_out_of_band_serialization=True, + ) return _actor_handle_deserializer, (serialized, weak_ref) self._register_cloudpickle_reducer(ray.actor.ActorHandle, actor_handle_reducer) @@ -140,7 +153,13 @@ def compiled_dag_ref_reducer(obj): def object_ref_reducer(obj): worker = ray._private.worker.global_worker worker.check_connected() - self.add_contained_object_ref(obj) + self.add_contained_object_ref( + obj, + allow_out_of_band_serialization=( + ALLOW_OUT_OF_BAND_OBJECT_REF_SERIALIZATION + ), + call_site=obj.call_site(), + ) obj, owner_address, object_status = worker.core_worker.serialize_object_ref( obj ) @@ -199,7 +218,13 @@ def get_and_clear_contained_object_refs(self): self._thread_local.object_refs = set() return object_refs - def add_contained_object_ref(self, object_ref): + def add_contained_object_ref( + self, + object_ref: "ray.ObjectRef", + *, + allow_out_of_band_serialization: bool, + call_site: Optional[str] = None, + ): if self.is_in_band_serialization(): # This object ref is being stored in an object. Add the ID to the # list of IDs contained in the object so that we keep the inner @@ -208,13 +233,25 @@ def add_contained_object_ref(self, object_ref): self._thread_local.object_refs = set() self._thread_local.object_refs.add(object_ref) else: - # If this serialization is out-of-band (e.g., from a call to - # cloudpickle directly or captured in a remote function/actor), - # then pin the object for the lifetime of this worker by adding - # a local reference that won't ever be removed. - ray._private.worker.global_worker.core_worker.add_object_ref_reference( - object_ref - ) + if not allow_out_of_band_serialization: + raise ray.exceptions.OufOfBandObjectRefSerializationException( + f"It is not allowed to serialize ray.ObjectRef {object_ref.hex()}. 
" + "If you want to allow serialization, " + "set `RAY_allow_out_of_band_object_ref_serialization=1.` " + "If you set the env var, the object is pinned forever in the " + "lifetime of the worker process and can cause Ray object leaks. " + "See the callsite and trace to find where the serialization " + "occurs.\nCallsite: " + f"{call_site or 'Disabled. Set RAY_record_ref_creation_sites=1'}" + ) + else: + # If this serialization is out-of-band (e.g., from a call to + # cloudpickle directly or captured in a remote function/actor), + # then pin the object for the lifetime of this worker by adding + # a local reference that won't ever be removed. + ray._private.worker.global_worker.core_worker.add_object_ref_reference( + object_ref + ) def _deserialize_pickle5_data(self, data): try: diff --git a/python/ray/data/_internal/planner/plan_read_op.py b/python/ray/data/_internal/planner/plan_read_op.py index 649cf3097163..94dbae1f3871 100644 --- a/python/ray/data/_internal/planner/plan_read_op.py +++ b/python/ray/data/_internal/planner/plan_read_op.py @@ -3,7 +3,6 @@ from typing import Iterable, List import ray -import ray.cloudpickle as cloudpickle from ray.data._internal.compute import TaskPoolStrategy from ray.data._internal.execution.interfaces import PhysicalOperator, RefBundle from ray.data._internal.execution.interfaces.task_context import TaskContext @@ -20,6 +19,7 @@ from ray.data._internal.util import _warn_on_high_parallelism from ray.data.block import Block, BlockMetadata from ray.data.datasource.datasource import ReadTask +from ray.experimental.locations import get_local_object_locations from ray.util.debug import log_once TASK_SIZE_WARN_THRESHOLD_BYTES = 1024 * 1024 # 1 MiB @@ -27,8 +27,13 @@ logger = logging.getLogger(__name__) -def cleaned_metadata(read_task: ReadTask) -> BlockMetadata: - task_size = len(cloudpickle.dumps(read_task)) +def cleaned_metadata(read_task: ReadTask, read_task_ref) -> BlockMetadata: + # NOTE: Use the `get_local_object_locations` API to 
get the size of the + # serialized ReadTask, instead of pickling. + # Because the ReadTask may capture ObjectRef objects, which cannot + # be serialized out-of-band. + locations = get_local_object_locations([read_task_ref]) + task_size = locations[read_task_ref]["object_size"] if task_size > TASK_SIZE_WARN_THRESHOLD_BYTES and log_once( f"large_read_task_{read_task.read_fn.__name__}" ): @@ -68,14 +73,16 @@ def get_input_data(target_max_block_size) -> List[RefBundle]: read_tasks = op._datasource_or_legacy_reader.get_read_tasks(parallelism) _warn_on_high_parallelism(parallelism, len(read_tasks)) - return [ - RefBundle( + ret = [] + for read_task in read_tasks: + read_task_ref = ray.put(read_task) + ref_bundle = RefBundle( [ ( # TODO(chengsu): figure out a better way to pass read # tasks other than ray.put(). - ray.put(read_task), - cleaned_metadata(read_task), + read_task_ref, + cleaned_metadata(read_task, read_task_ref), ) ], # `owns_blocks` is False, because these refs are the root of the @@ -83,8 +90,8 @@ def get_input_data(target_max_block_size) -> List[RefBundle]: # be reconstructed. owns_blocks=False, ) - for read_task in read_tasks - ] + ret.append(ref_bundle) + return ret inputs = InputDataBuffer( input_data_factory=get_input_data, diff --git a/python/ray/exceptions.py b/python/ray/exceptions.py index 4a3752c31365..d69cde6a283d 100644 --- a/python/ray/exceptions.py +++ b/python/ray/exceptions.py @@ -829,6 +829,15 @@ class ObjectRefStreamEndOfStreamError(RayError): pass +@DeveloperAPI +class OufOfBandObjectRefSerializationException(RayError): + """Raised when an `ray.ObjectRef` is out of band serialized by + `ray.cloudpickle`. It is an anti pattern. 
+ """ + + pass + + @PublicAPI(stability="alpha") class RayChannelError(RaySystemError): """Indicates that Ray encountered a system error related @@ -879,5 +888,6 @@ class RayAdagCapacityExceeded(RaySystemError): ActorUnavailableError, RayChannelError, RayChannelTimeoutError, + OufOfBandObjectRefSerializationException, RayAdagCapacityExceeded, ] diff --git a/python/ray/tests/test_serialization.py b/python/ray/tests/test_serialization.py index df05d5627130..b19979318817 100644 --- a/python/ray/tests/test_serialization.py +++ b/python/ray/tests/test_serialization.py @@ -14,6 +14,8 @@ import ray import ray.cluster_utils +import ray.exceptions +from ray import cloudpickle logger = logging.getLogger(__name__) @@ -733,6 +735,60 @@ def test(x, expect): assert dataclasses.asdict(new_y) == expect_dict +def test_cannot_out_of_band_serialize_object_ref(shutdown_only, monkeypatch): + monkeypatch.setenv("RAY_allow_out_of_band_object_ref_serialization", "0") + ray.init() + + # Use ray.remote as a workaround because + # RAY_allow_out_of_band_object_ref_serialization cannot be set dynamically. + @ray.remote + def test(): + ref = ray.put(1) + + @ray.remote + def f(): + ref + + with pytest.raises(ray.exceptions.OufOfBandObjectRefSerializationException): + ray.get(f.remote()) + + @ray.remote + def f(): + cloudpickle.dumps(ray.put(1)) + + with pytest.raises(ray.exceptions.OufOfBandObjectRefSerializationException): + ray.get(f.remote()) + + return ray.get(test.remote()) + + +def test_can_out_of_band_serialize_object_ref_with_env_var(shutdown_only, monkeypatch): + monkeypatch.setenv("RAY_allow_out_of_band_object_ref_serialization", "1") + ray.init() + + # Use ray.remote as a workaround because + # RAY_allow_out_of_band_object_ref_serialization cannot be set dynamically. + @ray.remote + def test(): + ref = ray.put(1) + + @ray.remote + def f(): + ref + + ray.get(f.remote()) + + @ray.remote + def f(): + ref = ray.put(1) + cloudpickle.dumps(ref) + + ray.get(f.remote()) + + # It should pass. 
+ ray.get(test.remote()) + + if __name__ == "__main__": import os import pytest