Skip to content

Commit

Permalink
[ray client] Fix ctrl-c for ray.get() by setting a short-server side …
Browse files Browse the repository at this point in the history
…timeout (ray-project#14425)
  • Loading branch information
ericl authored Mar 4, 2021
1 parent 190ab40 commit 2cf4c72
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 2 deletions.
33 changes: 33 additions & 0 deletions python/ray/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,45 @@
import sys
import logging
import threading
import _thread

import ray.util.client.server.server as ray_client_server
from ray.util.client.common import ClientObjectRef
from ray.util.client.ray_client_helpers import ray_start_client_server


@pytest.mark.skipif(sys.platform == "win32", reason="Failing on Windows.")
def test_interrupt_ray_get(call_ray_stop_only):
    """Check that a blocked ``ray.get()`` can be interrupted from the client.

    A timer simulates Ctrl-C by calling ``_thread.interrupt_main()`` roughly
    two seconds after a ``get`` on a long-sleeping task has started. The test
    expects the ``get`` to raise ``KeyboardInterrupt`` and a subsequent
    ``get`` on a fresh task to still succeed.
    """
    import ray
    ray.init(num_cpus=2)

    # NOTE: ``ray`` is deliberately rebound to the object yielded by
    # ``ray_start_client_server()`` for the remainder of the test.
    with ray_start_client_server() as ray:

        @ray.remote
        def block():
            print("blocking run")
            time.sleep(99)

        @ray.remote
        def fast():
            print("fast run")
            time.sleep(1)
            return "ok"

        # Simulate Ctrl-C in the main thread after 2 seconds. A cancellable
        # Timer is used so that a pending interrupt never outlives this test.
        interrupter = threading.Timer(2, _thread.interrupt_main)
        interrupter.start()
        try:
            with pytest.raises(KeyboardInterrupt):
                ray.get(block.remote())
        finally:
            # If the get returned or failed before the timer fired, cancel it
            # so a stray KeyboardInterrupt cannot hit teardown or a later
            # test. This is a no-op when the timer has already run.
            interrupter.cancel()

        # Assert we can still get new items after the interrupt.
        assert ray.get(fast.remote()) == "ok"


@pytest.mark.skipif(sys.platform == "win32", reason="Failing on Windows.")
def test_real_ray_fallback(ray_start_regular_shared):
with ray_start_client_server() as ray:
Expand Down
32 changes: 30 additions & 2 deletions python/ray/util/client/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from ray.cloudpickle.compat import pickle
import ray.core.generated.ray_client_pb2 as ray_client_pb2
import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
from ray.exceptions import GetTimeoutError
from ray.util.client.client_pickler import convert_to_arg
from ray.util.client.client_pickler import dumps_from_client
from ray.util.client.client_pickler import loads_from_server
Expand All @@ -44,6 +45,11 @@
INITIAL_TIMEOUT_SEC = 5
MAX_TIMEOUT_SEC = 30

# The maximum time, in seconds, that a single operation may block in the
# server. Keeping this short allows Ctrl-C of the client to work without
# explicitly cancelling server operations.
MAX_BLOCKING_OPERATION_TIME_S = 2


def backoff(timeout: int) -> int:
timeout = timeout + 5
Expand Down Expand Up @@ -171,7 +177,29 @@ def get(self, vals, *, timeout: Optional[float] = None) -> Any:
"list of IDs or just an ID: %s" % type(vals))
if timeout is None:
timeout = 0
out = [self._get(x, timeout) for x in to_get]
deadline = None
else:
deadline = time.monotonic() + timeout
out = []
for obj_ref in to_get:
res = None
# Implement non-blocking get with a short-polling loop. This allows
# cancellation of gets via Ctrl-C, since we never block for long.
while True:
try:
if deadline:
op_timeout = min(
MAX_BLOCKING_OPERATION_TIME_S,
max(deadline - time.monotonic(), 0.001))
else:
op_timeout = MAX_BLOCKING_OPERATION_TIME_S
res = self._get(obj_ref, op_timeout)
break
except GetTimeoutError:
if deadline and time.monotonic() > deadline:
raise
logger.debug("Internal retry for get {}".format(obj_ref))
out.append(res)
if single:
out = out[0]
return out
Expand All @@ -188,7 +216,6 @@ def _get(self, ref: ClientObjectRef, timeout: float):
except pickle.UnpicklingError:
logger.exception("Failed to deserialize {}".format(data.error))
raise
logger.error(err)
raise err
return loads_from_server(data.data)

Expand Down Expand Up @@ -221,6 +248,7 @@ def _put(self, val, *, client_ref_id: bytes = None):
resp = self.data_client.PutObject(req)
return ClientObjectRef(resp.id)

# TODO(ekl) respect MAX_BLOCKING_OPERATION_TIME_S for wait too
def wait(self,
object_refs: List[ClientObjectRef],
*,
Expand Down

0 comments on commit 2cf4c72

Please sign in to comment.