[core][scalability] Reduce unnecessary GCS get calls (16->8) when starting a worker (#30166)

When a lot of workers start at the same time, they put a lot of pressure on the GCS. Right now we issue 16 GCS internal KV get calls just to fetch some metadata.

This PR optimizes that path and reduces it to 8 calls by passing the constants through the raylet directly.

Tests are added to enforce this and prevent regressions in the future.
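The overall approach, as a rough illustrative sketch (the helper name and the `gcs_client` interface shown here are assumptions, not Ray's actual code): metadata the raylet already knows, such as the session name, is handed to the worker directly, and a GCS internal KV get only happens as a fallback.

```python
def resolve_session_name(cli_session_name, gcs_client):
    """Hypothetical helper: prefer the value the raylet passed on the command line."""
    if cli_session_name is not None:
        # No GCS round trip needed: the raylet forwarded the value directly.
        return cli_session_name
    # Fallback (the old, more expensive path): one GCS internal KV get.
    raw = gcs_client.internal_kv_get(b"session_name", namespace=b"session")
    return raw.decode()
```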
fishbone authored Dec 8, 2022
1 parent 15b57b9 commit e1a8796
Showing 15 changed files with 181 additions and 56 deletions.
15 changes: 9 additions & 6 deletions python/ray/_private/gcs_pubsub.py
@@ -164,15 +164,17 @@ def __init__(self, address: str):
channel = gcs_utils.create_gcs_channel(address)
self._stub = gcs_service_pb2_grpc.InternalPubSubGcsServiceStub(channel)

def publish_error(self, key_id: bytes, error_info: ErrorTableData) -> None:
def publish_error(
self, key_id: bytes, error_info: ErrorTableData, num_retries=None
) -> None:
"""Publishes error info to GCS."""
msg = pubsub_pb2.PubMessage(
channel_type=pubsub_pb2.RAY_ERROR_INFO_CHANNEL,
key_id=key_id,
error_info_message=error_info,
)
req = gcs_service_pb2.GcsPublishRequest(pub_messages=[msg])
self._gcs_publish(req)
self._gcs_publish(req, num_retries, timeout=1)

def publish_logs(self, log_batch: dict) -> None:
"""Publishes logs to GCS."""
@@ -184,16 +186,17 @@ def publish_function_key(self, key: bytes) -> None:
req = self._create_function_key_request(key)
self._gcs_publish(req)

def _gcs_publish(self, req) -> None:
count = MAX_GCS_PUBLISH_RETRIES
def _gcs_publish(self, req, num_retries=None, timeout=None) -> None:
count = num_retries or MAX_GCS_PUBLISH_RETRIES
while count > 0:
try:
self._stub.GcsPublish(req)
self._stub.GcsPublish(req, timeout=timeout)
return
except _InactiveRpcError:
pass
time.sleep(1)
count -= 1
if count > 0:
time.sleep(1)
raise TimeoutError(f"Failed to publish after retries: {req}")


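With the changes above, callers can bound how long error publishing blocks. A usage sketch, assuming a constructed `GcsPublisher` (`publisher`), a `ray.JobID` (`job_id`), and an `ErrorTableData` message (`error_info`), mirroring how `publish_error_to_driver` in `utils.py` below calls it:

```python
# With num_retries=1 and the 1-second RPC timeout, an unreachable GCS no longer
# stalls error publishing for many seconds during worker startup.
publisher.publish_error(
    key_id=job_id.hex().encode(),
    error_info=error_info,
    num_retries=1,
)
```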
20 changes: 20 additions & 0 deletions python/ray/_private/gcs_utils.py
@@ -3,6 +3,7 @@
import time
import traceback
import inspect
import os
import asyncio
from functools import wraps
from typing import List, Optional
@@ -140,11 +141,24 @@ def check_health(address: str, timeout=2, skip_version_check=False) -> bool:
return True


# This global variable is used for testing only
_called_freq = {}


def _auto_reconnect(f):
# This is for testing only: count the frequency of GCS calls.
if inspect.iscoroutinefunction(f):

@wraps(f)
async def wrapper(self, *args, **kwargs):
if "TEST_RAY_COLLECT_KV_FREQUENCY" in os.environ:
global _called_freq
name = f.__name__
if name not in _called_freq:
_called_freq[name] = 0
_called_freq[name] += 1

remaining_retry = self._nums_reconnect_retry
while True:
try:
Expand Down Expand Up @@ -173,6 +187,12 @@ async def wrapper(self, *args, **kwargs):

@wraps(f)
def wrapper(self, *args, **kwargs):
if "TEST_RAY_COLLECT_KV_FREQUENCY" in os.environ:
global _called_freq
name = f.__name__
if name not in _called_freq:
_called_freq[name] = 0
_called_freq[name] += 1
remaining_retry = self._nums_reconnect_retry
while True:
try:
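The `TEST_RAY_COLLECT_KV_FREQUENCY` / `_called_freq` pair above is what the new tests hook into. A minimal sketch of how a regression test could use it; the counted method name (`internal_kv_get`) and the exact threshold are illustrative, since the test file itself is not shown in this excerpt:

```python
import os

import ray._private.gcs_utils as gcs_utils

# Illustrative regression-test sketch; the real test added by this PR may differ.
os.environ["TEST_RAY_COLLECT_KV_FREQUENCY"] = "1"
gcs_utils._called_freq.clear()

# ... exercise the worker startup path under test here ...

num_gets = gcs_utils._called_freq.get("internal_kv_get", 0)
assert num_gets <= 8, f"GCS internal KV gets regressed during worker start: {num_gets}"
```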
79 changes: 53 additions & 26 deletions python/ray/_private/node.py
@@ -57,6 +57,7 @@ def __init__(
shutdown_at_exit: bool = True,
spawn_reaper: bool = True,
connect_only: bool = False,
default_worker: bool = False,
):
"""Start a node.
@@ -71,14 +72,15 @@ def __init__(
other spawned processes if this process dies unexpectedly.
connect_only: If true, connect to the node without starting
new processes.
default_worker: Whether this node is being started from a Ray default worker process.
"""
if shutdown_at_exit:
if connect_only:
raise ValueError(
"'shutdown_at_exit' and 'connect_only' cannot both be true."
)
self._register_shutdown_hooks()

self._default_worker = default_worker
self.head = head
self.kernel_fate_share = bool(
spawn_reaper and ray._private.utils.detect_fate_sharing_support()
@@ -179,31 +181,44 @@ def __init__(
date_str = datetime.datetime.today().strftime("%Y-%m-%d_%H-%M-%S_%f")
self._session_name = f"session_{date_str}_{os.getpid()}"
else:
session_name = ray._private.utils.internal_kv_get_with_retry(
self.get_gcs_client(),
"session_name",
ray_constants.KV_NAMESPACE_SESSION,
num_retries=NUM_REDIS_GET_RETRIES,
)
self._session_name = ray._private.utils.decode(session_name)
if ray_params.session_name is None:
assert not self._default_worker
session_name = ray._private.utils.internal_kv_get_with_retry(
self.get_gcs_client(),
"session_name",
ray_constants.KV_NAMESPACE_SESSION,
num_retries=NUM_REDIS_GET_RETRIES,
)
self._session_name = ray._private.utils.decode(session_name)
else:
# worker mode
self._session_name = ray_params.session_name
# setup gcs client
self.get_gcs_client()

# Initialize webui url
if head:
self._webui_url = None
else:
self._webui_url = ray._private.services.get_webui_url_from_internal_kv()
if ray_params.webui is None:
assert not self._default_worker
self._webui_url = ray._private.services.get_webui_url_from_internal_kv()
else:
self._webui_url = (
f"{ray_params.dashboard_host}:{ray_params.dashboard_port}"
)

self._init_temp()

# Validate and initialize the persistent storage API.
if head:
storage._init_storage(ray_params.storage, is_head=True)
else:
storage._init_storage(
ray._private.services.get_storage_uri_from_internal_kv(), is_head=False
)
if not self._default_worker:
storage_uri = ray._private.services.get_storage_uri_from_internal_kv()
else:
storage_uri = ray_params.storage
storage._init_storage(storage_uri, is_head=False)

# If it is a head node, try validating if
# external storage is configurable.
@@ -355,6 +370,9 @@ def check_version_info(self):

cluster_metadata = ray_usage_lib.get_cluster_metadata(self.get_gcs_client())
if cluster_metadata is None:
cluster_metadata = ray_usage_lib.get_cluster_metadata(self.get_gcs_client())

if not cluster_metadata:
return
ray._private.utils.check_version_info(cluster_metadata)

@@ -382,26 +400,34 @@ def _init_temp(self):
if self.head:
self._temp_dir = self._ray_params.temp_dir
else:
temp_dir = ray._private.utils.internal_kv_get_with_retry(
self.get_gcs_client(),
"temp_dir",
ray_constants.KV_NAMESPACE_SESSION,
num_retries=NUM_REDIS_GET_RETRIES,
)
self._temp_dir = ray._private.utils.decode(temp_dir)
if self._ray_params.temp_dir is None:
assert not self._default_worker
temp_dir = ray._private.utils.internal_kv_get_with_retry(
self.get_gcs_client(),
"temp_dir",
ray_constants.KV_NAMESPACE_SESSION,
num_retries=NUM_REDIS_GET_RETRIES,
)
self._temp_dir = ray._private.utils.decode(temp_dir)
else:
self._temp_dir = self._ray_params.temp_dir

try_to_create_directory(self._temp_dir)

if self.head:
self._session_dir = os.path.join(self._temp_dir, self._session_name)
else:
session_dir = ray._private.utils.internal_kv_get_with_retry(
self.get_gcs_client(),
"session_dir",
ray_constants.KV_NAMESPACE_SESSION,
num_retries=NUM_REDIS_GET_RETRIES,
)
self._session_dir = ray._private.utils.decode(session_dir)
if self._temp_dir is None or self._session_name is None:
assert not self._default_worker
session_dir = ray._private.utils.internal_kv_get_with_retry(
self.get_gcs_client(),
"session_dir",
ray_constants.KV_NAMESPACE_SESSION,
num_retries=NUM_REDIS_GET_RETRIES,
)
self._session_dir = ray._private.utils.decode(session_dir)
else:
self._session_dir = os.path.join(self._temp_dir, self._session_name)
session_symlink = os.path.join(self._temp_dir, SESSION_LATEST)

# Send a warning message if the session exists.
Expand Down Expand Up @@ -981,6 +1007,7 @@ def start_raylet(
ray_debugger_external=self._ray_params.ray_debugger_external,
env_updates=self._ray_params.env_vars,
node_name=self._ray_params.node_name,
webui=self._webui_url,
)
assert ray_constants.PROCESS_TYPE_RAYLET not in self.all_processes
self.all_processes[ray_constants.PROCESS_TYPE_RAYLET] = [process_info]
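The `node.py` hunks above all follow the same pattern: prefer the value already carried in `RayParams`, fall back to a GCS internal KV fetch only when it is missing, and assert that the default worker never takes the fallback. Condensed into a hypothetical helper (not actual Ray code):

```python
def _resolve(local_value, fetch_from_kv, is_default_worker, what):
    """Hypothetical condensation of the lookup pattern used in node.py above."""
    if local_value is not None:
        return local_value
    # The default worker should have received everything on the command line,
    # so it must never fall back to a GCS internal KV get.
    assert not is_default_worker, f"default worker should not fetch {what!r} from GCS"
    return fetch_from_kv()
```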
6 changes: 6 additions & 0 deletions python/ray/_private/parameter.py
@@ -119,6 +119,8 @@ class RayParams:
worker available externally to the node it is running on. This will
bind on 0.0.0.0 instead of localhost.
env_vars: Override environment variables for the raylet.
session_name: The session name of the Ray cluster.
webui: The URL of the web UI.
"""

def __init__(
Expand Down Expand Up @@ -176,6 +178,8 @@ def __init__(
tracing_startup_hook=None,
no_monitor: Optional[bool] = False,
env_vars: Optional[Dict[str, str]] = None,
session_name: Optional[str] = None,
webui: Optional[str] = None,
):
self.redis_address = redis_address
self.gcs_address = gcs_address
@@ -232,6 +236,8 @@ def __init__(
)
self.ray_debugger_external = ray_debugger_external
self.env_vars = env_vars
self.session_name = session_name
self.webui = webui
self._system_config = _system_config or {}
self._enable_object_reconstruction = enable_object_reconstruction
self._check_usage()
10 changes: 8 additions & 2 deletions python/ray/_private/services.py
@@ -1310,6 +1310,7 @@ def start_raylet(
ray_debugger_external: bool = False,
env_updates: Optional[dict] = None,
node_name: Optional[str] = None,
webui: Optional[str] = None,
):
"""Start a raylet, which is a combined local scheduler and object manager.
Expand Down Expand Up @@ -1437,15 +1438,20 @@ def start_raylet(
f"--object-store-name={plasma_store_name}",
f"--raylet-name={raylet_name}",
f"--redis-address={redis_address}",
f"--storage={storage}",
f"--temp-dir={temp_dir}",
f"--metrics-agent-port={metrics_agent_port}",
f"--logging-rotate-bytes={max_bytes}",
f"--logging-rotate-backup-count={backup_count}",
f"--gcs-address={gcs_address}",
"RAY_WORKER_DYNAMIC_OPTION_PLACEHOLDER",
f"--session-name={session_name}",
f"--temp-dir={temp_dir}",
f"--webui={webui}",
]

start_worker_command.append(f"--storage={storage}")

start_worker_command.append("RAY_WORKER_DYNAMIC_OPTION_PLACEHOLDER")

if redis_password:
start_worker_command += [f"--redis-password={redis_password}"]

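For reference, the worker command assembled in `start_raylet` now carries this metadata directly; the fragment below is illustrative (the values are invented, only the flag names come from this diff) and is parsed by the `default_worker.py` changes further down.

```python
# Illustrative fragment of the worker command built by start_raylet();
# values are invented, only the flag names come from this diff.
worker_command_tail = [
    "--session-name=session_2022-12-08_16-00-00_000000_42",
    "--temp-dir=/tmp/ray",
    "--webui=127.0.0.1:8265",
]
```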
3 changes: 2 additions & 1 deletion python/ray/_private/utils.py
@@ -204,6 +204,7 @@ def publish_error_to_driver(
message: str,
gcs_publisher,
job_id=None,
num_retries=None,
):
"""Push an error message to the driver to be printed in the background.
@@ -225,7 +226,7 @@ def publish_error_to_driver(
assert isinstance(job_id, ray.JobID)
error_data = construct_error_message(job_id, error_type, message, time.time())
try:
gcs_publisher.publish_error(job_id.hex().encode(), error_data)
gcs_publisher.publish_error(job_id.hex().encode(), error_data, num_retries)
except Exception:
logger.exception(f"Failed to publish error {error_data}")

Expand Down
15 changes: 10 additions & 5 deletions python/ray/_private/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -1441,7 +1441,10 @@ def init(
# handler. We still spawn a reaper process in case the atexit handler
# isn't called.
_global_node = ray._private.node.Node(
head=True, shutdown_at_exit=False, spawn_reaper=True, ray_params=ray_params
head=True,
shutdown_at_exit=False,
spawn_reaper=True,
ray_params=ray_params,
)
else:
# In this case, we are connecting to an existing cluster.
@@ -1968,6 +1971,7 @@ def connect(
ray_constants.VERSION_MISMATCH_PUSH_ERROR,
traceback_str,
gcs_publisher=worker.gcs_publisher,
num_retries=1,
)

driver_name = ""
@@ -2098,22 +2102,23 @@ def connect(
# assumes that the directory structures on the machines in the clusters
# are the same.
# When using an interactive shell, there is no script directory.
code_paths = []
if not interactive_mode:
script_directory = os.path.dirname(os.path.realpath(sys.argv[0]))
# If driver's sys.path doesn't include the script directory
# (e.g driver is started via `python -m`,
# see https://peps.python.org/pep-0338/),
# then we shouldn't add it to the workers.
if script_directory in sys.path:
worker.run_function_on_all_workers(
lambda worker_info: sys.path.insert(1, script_directory)
)
code_paths.append(script_directory)
# In client mode, if we use runtime envs with "working_dir", then
# it'll be handled automatically. Otherwise, add the current dir.
if not job_config.client_job and not job_config.runtime_env_has_working_dir():
current_directory = os.path.abspath(os.path.curdir)
code_paths.append(current_directory)
if len(code_paths) != 0:
worker.run_function_on_all_workers(
lambda worker_info: sys.path.insert(1, current_directory)
lambda worker_info: [sys.path.insert(1, path) for path in code_paths]
)
# TODO(rkn): Here we first export functions to run, then remote
# functions. The order matters. For example, one of the functions to
11 changes: 9 additions & 2 deletions python/ray/_private/workers/default_worker.py
@@ -140,6 +140,12 @@
action="store_true",
help="True if Ray debugger is made available externally.",
)
parser.add_argument("--session-name", required=False, help="The current session name")
parser.add_argument(
"--webui",
required=False,
help="The address of web ui",
)


if __name__ == "__main__":
@@ -162,7 +168,6 @@
raylet_ip_address = args.raylet_ip_address
if raylet_ip_address is None:
raylet_ip_address = args.node_ip_address

ray_params = RayParams(
node_ip_address=args.node_ip_address,
raylet_ip_address=raylet_ip_address,
@@ -175,14 +180,16 @@
storage=args.storage,
metrics_agent_port=args.metrics_agent_port,
gcs_address=args.gcs_address,
session_name=args.session_name,
webui=args.webui,
)

node = ray._private.node.Node(
ray_params,
head=False,
shutdown_at_exit=False,
spawn_reaper=False,
connect_only=True,
default_worker=True,
)

# NOTE(suquark): We must initialize the external storage before we