Skip to content

Commit

Permalink
[Ray Client] Transfer dashboard_url over gRPC instead of ray.remote (#…
Browse files Browse the repository at this point in the history
…30941)

The ray.remote call is spawning worker tasks on the head node even if their client doesn't do anything, spawning unexpected workers.

Note: dashboard_url behavior is already tested by test_client_builder
  • Loading branch information
ckw017 authored and rkooo567 committed Dec 14, 2022
1 parent f1b8bfd commit c7aaa4a
Show file tree
Hide file tree
Showing 7 changed files with 63 additions and 4 deletions.
5 changes: 3 additions & 2 deletions python/ray/client_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,8 +186,9 @@ def connect(self) -> ClientContext:
ray_init_kwargs=self._remote_init_kwargs,
metadata=self._metadata,
)
get_dashboard_url = ray.remote(ray._private.worker.get_dashboard_url)
dashboard_url = ray.get(get_dashboard_url.options(num_cpus=0).remote())

dashboard_url = ray.util.client.ray._get_dashboard_url()

cxt = ClientContext(
dashboard_url=dashboard_url,
python_version=client_info_dict["python_version"],
Expand Down
38 changes: 38 additions & 0 deletions python/ray/tests/test_client_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,13 @@
import ray
import ray.client_builder as client_builder
import ray.util.client.server.server as ray_client_server
from ray.experimental.state.api import list_workers
from ray._private.test_utils import (
run_string_as_driver,
run_string_as_driver_nonblocking,
wait_for_condition,
)
import time


@pytest.mark.parametrize(
Expand Down Expand Up @@ -419,6 +421,42 @@ def test_client_deprecation_warn():
subprocess.check_output("ray stop --force", shell=True)


@pytest.mark.parametrize(
"call_ray_start",
[
"ray start --head --num-cpus=2 --min-worker-port=0 --max-worker-port=0 "
"--port 0 --ray-client-server-port=50056"
],
indirect=True,
)
def test_worker_processes(call_ray_start):
"""
Test that no workers are spawned until a remote function is called.
"""
ray.init("ray://localhost:50056")

# Check for 10 seconds that no workers spawned after connecting
for _ in range(10):
workers = list_workers()
non_driver_workers = [w for w in workers if w.get("worker_type") != "DRIVER"]
assert len(non_driver_workers) == 0, workers
time.sleep(1)

@ray.remote(num_cpus=2)
def f():
return 42

assert ray.get(f.remote()) == 42
time.sleep(3)

# 2 worker processes should have spawned to accommodate the remote func
for _ in range(10):
workers = list_workers()
non_driver_workers = [w for w in workers if w.get("worker_type") != "DRIVER"]
assert len(non_driver_workers) == 2, workers
time.sleep(1)


if __name__ == "__main__":
if os.environ.get("PARALLEL_CI"):
sys.exit(pytest.main(["-n", "auto", "--boxed", "-vs", __file__]))
Expand Down
12 changes: 11 additions & 1 deletion python/ray/tests/test_client_reconnect.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,17 @@ def ListNamedActors(
return self._call_inner_function(request, context, "ListNamedActors")

def ClusterInfo(self, request, context=None) -> ray_client_pb2.ClusterInfoResponse:
return self._call_inner_function(request, context, "ClusterInfo")
# Cluster info is currently used for health checks and isn't retried, so
# don't inject errors.
# TODO(ckw): update ClusterInfo so that retries are only skipped for PING
try:
return self.stub.ClusterInfo(
request, metadata=context.invocation_metadata()
)
except grpc.RpcError as e:
context.set_code(e.code())
context.set_details(e.details())
raise

def Terminate(self, req, context=None):
return self._call_inner_function(req, context, "Terminate")
Expand Down
2 changes: 1 addition & 1 deletion python/ray/util/client/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

# This version string is incremented to indicate breaking changes in the
# protocol that require upgrading the client version.
CURRENT_PROTOCOL_VERSION = "2022-10-05"
CURRENT_PROTOCOL_VERSION = "2022-12-06"


class _ClientContext:
Expand Down
7 changes: 7 additions & 0 deletions python/ray/util/client/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,3 +397,10 @@ def _register_callback(
self, ref: "ClientObjectRef", callback: Callable[["DataResponse"], None]
) -> None:
self.worker.register_callback(ref, callback)

def _get_dashboard_url(self) -> str:
import ray.core.generated.ray_client_pb2 as ray_client_pb2

return self.worker.get_cluster_info(
ray_client_pb2.ClusterInfoType.DASHBOARD_URL
).get("dashboard_url", "")
2 changes: 2 additions & 0 deletions python/ray/util/client/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,8 @@ def _return_debug_cluster_info(self, request, context=None) -> str:
data = ray.timeline()
elif request.type == ray_client_pb2.ClusterInfoType.PING:
data = {}
elif request.type == ray_client_pb2.ClusterInfoType.DASHBOARD_URL:
data = {"dashboard_url": ray._private.worker.get_dashboard_url()}
else:
raise TypeError("Unsupported cluster info type")
return json.dumps(data)
Expand Down
1 change: 1 addition & 0 deletions src/ray/protobuf/ray_client.proto
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,7 @@ message ClusterInfoType {
RUNTIME_CONTEXT = 4;
TIMELINE = 5;
PING = 6;
DASHBOARD_URL = 7;
}
}

Expand Down

0 comments on commit c7aaa4a

Please sign in to comment.