Skip to content

Commit

Permalink
[Core] Defer SIGINT interrupt during task argument deserialization. (r…
Browse files Browse the repository at this point in the history
…ay-project#30476)

Importing certain libraries (e.g. Arrow, Pandas, Torch) is not reentrant, and task cancellation is occasionally interrupting the Arrow import triggered via this deserialization add-on during task argument deserialization, which we are then trying to import again when serializing the error. See here for an example failure: https://buildkite.com/ray-project/oss-ci-build-branch/builds/1115#018485e1-df32-480f-9c36-cc898341f0a2

This PR prevents this import reentrancy from happening for the task cancellation case by deferring interrupts until after task argument deserialization finishes, so we can be sure that the serialization-related imports have finished before processing the interrupt.

Signed-off-by: Weichen Xu <[email protected]>
  • Loading branch information
clarkzinzow authored and WeichenXu123 committed Dec 19, 2022
1 parent 7580ef0 commit df88dc9
Show file tree
Hide file tree
Showing 3 changed files with 278 additions and 7 deletions.
76 changes: 76 additions & 0 deletions python/ray/_private/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
import binascii
import contextlib
import errno
import functools
import hashlib
Expand All @@ -23,6 +24,7 @@
from pathlib import Path
from subprocess import list2cmdline
from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, Tuple, Union

import grpc
import numpy as np

Expand Down Expand Up @@ -1775,3 +1777,77 @@ def _get_pyarrow_version() -> Optional[str]:
if hasattr(pyarrow, "__version__"):
_PYARROW_VERSION = pyarrow.__version__
return _PYARROW_VERSION


class DeferSigint(contextlib.AbstractContextManager):
"""Context manager that defers SIGINT signals until the the context is left."""

# This is used by Ray's task cancellation to defer cancellation interrupts during
# problematic areas, e.g. task argument deserialization.
def __init__(self):
# Whether the task has been cancelled while in the context.
self.task_cancelled = False
# The original SIGINT handler.
self.orig_sigint_handler = None
# The original signal method.
self.orig_signal = None

@classmethod
def create_if_main_thread(cls) -> contextlib.AbstractContextManager:
"""Creates a DeferSigint context manager if running on the main thread,
returns a no-op context manager otherwise.
"""
if threading.current_thread() == threading.main_thread():
return cls()
else:
# TODO(Clark): Use contextlib.nullcontext() once Python 3.6 support is
# dropped.
return contextlib.suppress()

def _set_task_cancelled(self, signum, frame):
"""SIGINT handler that defers the signal."""
self.task_cancelled = True

def _signal_monkey_patch(self, signum, handler):
"""Monkey patch for signal.signal that raises an error if a SIGINT handler is
registered within the DeferSigint context.
"""
# Only raise an error if setting a SIGINT handler in the main thread; if setting
# a handler in a non-main thread, signal.signal will raise an error anyway
# indicating that Python does not allow that.
if (
threading.current_thread() == threading.main_thread()
and signum == signal.SIGINT
):
raise ValueError(
"Can't set signal handler for SIGINT while SIGINT is being deferred "
"within a DeferSigint context."
)
return self.orig_signal(signum, handler)

def __enter__(self):
# Save original SIGINT handler for later restoration.
self.orig_sigint_handler = signal.getsignal(signal.SIGINT)
# Set SIGINT signal handler that defers the signal.
signal.signal(signal.SIGINT, self._set_task_cancelled)
# Monkey patch signal.signal to raise an error if a SIGINT handler is registered
# within the context.
self.orig_signal = signal.signal
signal.signal = self._signal_monkey_patch
return self

def __exit__(self, exc_type, exc, exc_tb):
assert self.orig_sigint_handler is not None
assert self.orig_signal is not None
# Restore original signal.signal function.
signal.signal = self.orig_signal
# Restore original SIGINT handler.
signal.signal(signal.SIGINT, self.orig_sigint_handler)
if exc_type is None and self.task_cancelled:
# No exception raised in context but task has been cancelled, so we raise
# KeyboardInterrupt to go through the task cancellation path.
raise KeyboardInterrupt
else:
# If exception was raised in context, returning False will cause it to be
# reraised.
return False
21 changes: 16 additions & 5 deletions python/ray/_raylet.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import msgpack
import os
import pickle
import setproctitle
import signal
import sys
import threading
import time
Expand Down Expand Up @@ -138,7 +139,7 @@ from ray._private.client_mode_hook import disable_client_hook
import ray._private.gcs_utils as gcs_utils
import ray._private.memory_monitor as memory_monitor
import ray._private.profiling as profiling
from ray._private.utils import decode
from ray._private.utils import decode, DeferSigint

cimport cpython

Expand Down Expand Up @@ -675,6 +676,7 @@ cdef execute_dynamic_generator_and_store_task_outputs(
"by the first execution.\n"
"See https://github.com/ray-project/ray/issues/28688.")


cdef void execute_task(
const CAddress &caller_address,
CTaskType task_type,
Expand Down Expand Up @@ -794,7 +796,6 @@ cdef void execute_task(
object_refs = VectorToObjectRefs(
c_arg_refs,
skip_adding_local_ref=False)

if core_worker.current_actor_is_asyncio():
# We deserialize objects in event loop thread to
# prevent segfaults. See #7799
Expand All @@ -806,9 +807,19 @@ cdef void execute_task(
deserialize_args, function_descriptor,
name_of_concurrency_group_to_execute)
else:
args = (ray._private.worker.global_worker
.deserialize_objects(
metadata_pairs, object_refs))
# Defer task cancellation (SIGINT) until after the task argument
# deserialization context has been left.
# NOTE (Clark): We defer SIGINT until after task argument
# deserialization completes to keep from interrupting
# non-reentrant imports that may be re-entered during error
# serialization or storage.
# See https://github.com/ray-project/ray/issues/30453.
# NOTE (Clark): Signal handlers can only be registered on the
# main thread.
with DeferSigint.create_if_main_thread():
args = (ray._private.worker.global_worker
.deserialize_objects(
metadata_pairs, object_refs))

for arg in args:
raise_if_dependency_failed(arg)
Expand Down
188 changes: 186 additions & 2 deletions python/ray/tests/test_cancel.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
import os
import random
import signal
import sys
import threading
import _thread
import time

import pytest
Expand All @@ -12,6 +16,7 @@
WorkerCrashedError,
ObjectLostError,
)
from ray._private.utils import DeferSigint
from ray._private.test_utils import SignalActor


Expand Down Expand Up @@ -92,6 +97,187 @@ def dummy(a: SlowToDeserialize):
ray.get(obj)


def test_defer_sigint():
# Tests a helper context manager for deferring SIGINT signals until after the
# context is left. This is used by Ray's task cancellation to defer cancellation
# interrupts during problematic areas, e.g. task argument deserialization.
signal_was_deferred = False
orig_sigint_handler = signal.getsignal(signal.SIGINT)
try:
with DeferSigint():
# Send singal to current process.
# NOTE: We use _thread.interrupt_main() instead of os.kill() in order to
# support Windows.
_thread.interrupt_main()
# Wait for signal to be delivered.
time.sleep(1)
# Signal should have been delivered by here, so we consider it deferred if
# this is reached.
signal_was_deferred = True
except KeyboardInterrupt:
# Check that SIGINT was deferred until the end of the context.
assert signal_was_deferred
# Check that original SIGINT handler was restored.
assert signal.getsignal(signal.SIGINT) is orig_sigint_handler
else:
pytest.fail("SIGINT signal was never sent in test")


def test_defer_sigint_monkey_patch():
# Tests that setting a SIGINT signal handler within a DeferSigint context is not
# allowed.
orig_sigint_handler = signal.getsignal(signal.SIGINT)
with pytest.raises(ValueError):
with DeferSigint():
signal.signal(signal.SIGINT, orig_sigint_handler)


def test_defer_sigint_noop_in_non_main_thread():
# Tests that we don't try to defer SIGINT when not in the main thread.

# Check that DeferSigint.create_if_main_thread() does not return DeferSigint when
# not in the main thread.
def check_no_defer():
cm = DeferSigint.create_if_main_thread()
assert not isinstance(cm, DeferSigint)

check_no_defer_thread = threading.Thread(target=check_no_defer)
try:
check_no_defer_thread.start()
check_no_defer_thread.join()
except AssertionError as e:
pytest.fail(
"DeferSigint.create_if_main_thread() unexpected returned a DeferSigint "
f"instance when not in the main thread: {e}"
)

# Check that signal is not deferred when trying to defer it in not the main thread.
signal_was_deferred = False

def maybe_defer():
nonlocal signal_was_deferred

with DeferSigint.create_if_main_thread() as cm:
# Check that DeferSigint context manager was NOT returned.
assert not isinstance(cm, DeferSigint)
# Send singal to current process.
# NOTE: We use _thread.interrupt_main() instead of os.kill() in order to
# support Windows.
_thread.interrupt_main()
# Wait for signal to be delivered.
time.sleep(1)
# Signal should have been delivered by here, so we consider it deferred if
# this is reached.
signal_was_deferred = True

# Create thread that will maybe defer SIGINT.
maybe_defer_thread = threading.Thread(target=maybe_defer)
try:
maybe_defer_thread.start()
maybe_defer_thread.join()
# KeyboardInterrupt should get raised in main thread.
except KeyboardInterrupt:
# Check that SIGINT was not deferred.
assert not signal_was_deferred
# Check that original SIGINT handler was not overridden.
assert signal.getsignal(signal.SIGINT) is signal.default_int_handler
else:
pytest.fail("SIGINT signal was never sent in test")


def test_cancel_during_arg_deser_non_reentrant_import(ray_start_regular):
# This test ensures that task argument deserialization properly defers task
# cancellation interrupts until after deserialization completes, in order to ensure
# that non-reentrant imports that happen during both task argument deserialization
# and during error storage are not interrupted.

# We test this by doing the following:
# - register a custom serializer for (a) a task argument that triggers
# non-reentrant imports on deserialization, and (b) RayTaskError that triggers
# non-reentrant imports on serialization; in our case, we chose pandas it is both
# non-reentrant and expensive, with an import time ~0.5 seconds, giving us a wide
# cancellation target,
# - wait until those serializers are registered on all workers,
# - launch the task and wait until we are confident that the cancellation signal
# will be received by the workers during task argument deserialization (currently a
# 200 ms wait).
# - check that a graceful task cancellation error is raised, not a
# WorkerCrashedError.
def non_reentrant_import():
# NOTE: Pandas has a non-reentrant import and should take ~0.5 seconds to
# import, giving us a wide cancellation target.
import pandas # noqa

def non_reentrant_import_and_delegate(obj):
# Custom serializer for task argument and task error resulting in non-reentrant
# imports being imported on both serialization and deserialization. We use the
# same custom serializer for both, doing non-reentrant imports on both
# serialization and deserialization, for the sake of simplicity/reuse.

# Import on serialization.
non_reentrant_import()

reduced = obj.__reduce__()
func = reduced[0]
args = reduced[1]
others = reduced[2:]

def non_reentrant_import_on_reconstruction(*args, **kwargs):
# Import on deserialization.
non_reentrant_import()

return func(*args, **kwargs)

out = (non_reentrant_import_on_reconstruction, args) + others
return out

# Dummy task argument for which we register a serializer that will trigger
# non-reentrant imports on deserialization.
class DummyArg:
pass

def register_non_reentrant_import_and_delegate_reducer(worker_info):
from ray.exceptions import RayTaskError

context = ray._private.worker.global_worker.get_serialization_context()
# Register non-reentrant import serializer for task argument.
context._register_cloudpickle_reducer(
DummyArg, non_reentrant_import_and_delegate
)
# Register non-reentrant import serializer for RayTaskError.
context._register_cloudpickle_reducer(
RayTaskError, non_reentrant_import_and_delegate
)

ray._private.worker.global_worker.run_function_on_all_workers(
register_non_reentrant_import_and_delegate_reducer,
)

# Wait for function to run on all workers.
time.sleep(3)

@ray.remote
def run_and_fail(a: DummyArg):
# Should never be reached.
assert False

arg = DummyArg()
obj = run_and_fail.remote(arg)
# Check that task isn't done.
# NOTE: This timeout was finely tuned to ensure that task cancellation happens while
# we are deserializing task arguments (10/10 runs when this comment was added).
timeout_to_reach_arg_deserialization = 0.2
assert len(ray.wait([obj], timeout=timeout_to_reach_arg_deserialization)[0]) == 0

# Cancel task.
use_force = False
ray.cancel(obj, force=use_force)

# Should raise RayTaskError or TaskCancelledError, NOT WorkerCrashedError.
with pytest.raises(valid_exceptions(use_force)):
ray.get(obj)


@pytest.mark.parametrize("use_force", [True, False])
def test_cancel_multiple_dependents(ray_start_regular, use_force):
signaler = SignalActor.remote()
Expand Down Expand Up @@ -336,8 +522,6 @@ def many_resources():


if __name__ == "__main__":
import os

if os.environ.get("PARALLEL_CI"):
sys.exit(pytest.main(["-n", "auto", "--boxed", "-vs", __file__]))
else:
Expand Down

0 comments on commit df88dc9

Please sign in to comment.