[RLlib] To improve performance, do not wait for sync weight calls by default. #30509
@@ -16,7 +16,6 @@
     Union,
 )

-import ray
 from ray.actor import ActorHandle
 from ray.exceptions import RayActorError
 from ray.rllib.evaluation.rollout_worker import RolloutWorker
@@ -385,18 +384,23 @@ def sync_weights(
         from_worker: Optional[RolloutWorker] = None,
         to_worker_indices: Optional[List[int]] = None,
         global_vars: Optional[Dict[str, TensorType]] = None,
+        timeout_seconds: Optional[int] = 0,
     ) -> None:
         """Syncs model weights from the local worker to all remote workers.

         Args:
             policies: Optional list of PolicyIDs to sync weights for.
                 If None (default), sync weights to/from all policies.
-            from_worker: Optional RolloutWorker instance to sync from.
+            from_worker: Optional local RolloutWorker instance to sync from.
                 If None (default), sync from this WorkerSet's local worker.
             to_worker_indices: Optional list of worker indices to sync the
                 weights to. If None (default), sync to all remote workers.
             global_vars: An optional global vars dict to set this
                 worker to. If None, do not update the global_vars.
+            timeout_seconds: Timeout in seconds to wait for the sync weights
+                calls to complete. Default is 0 (sync-and-forget, do not wait
+                for any sync calls to finish). This significantly improves
+                algorithm performance.
         """
         if self.local_worker() is None and from_worker is None:
             raise TypeError(
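The new default (`timeout_seconds=0`) makes weight syncing fire-and-forget: the remote `set_weights` calls are launched, but the caller does not block on their completion. A minimal, self-contained sketch of that pattern with plain Ray actors (the actor and method names are illustrative, not the PR's code):

```python
# Sketch of "sync-and-forget" semantics with plain Ray, assuming a toy
# Worker actor (not RLlib's RolloutWorker). Launch remote calls, then use
# ray.wait(timeout=0) so the caller returns immediately instead of
# blocking until every worker has applied the new weights.
import time

import ray

ray.init()


@ray.remote
class Worker:
    def __init__(self):
        self.weights = None

    def set_weights(self, weights):
        time.sleep(1.0)  # Simulate a slow weight update.
        self.weights = weights


workers = [Worker.remote() for _ in range(4)]
new_weights = {"default_policy": [0.1, 0.2, 0.3]}

# Fire the calls; each .remote() returns an ObjectRef immediately.
refs = [w.set_weights.remote(new_weights) for w in workers]

# timeout=0: poll once and return right away ("fire and forget").
# timeout=None would block here until all four calls completed.
ready, in_flight = ray.wait(refs, num_returns=len(refs), timeout=0)
print(f"{len(ready)} done, {len(in_flight)} still in flight")
```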
@@ -408,21 +412,20 @@ def sync_weights(
         weights = None
         if self.num_remote_workers() or from_worker is not None:
             weights = (from_worker or self.local_worker()).get_weights(policies)
-            # Put weights only once into object store and use same object
-            # ref to synch to all workers.
-            weights_ref = ray.put(weights)

             def set_weight(w):
-                w.set_weights(ray.get(weights_ref), global_vars)
+                w.set_weights(weights, global_vars)

             # Sync to specified remote workers in this WorkerSet.
             self.foreach_worker(
                 func=set_weight,
                 local_worker=False,  # Do not sync back to local worker.
                 remote_worker_ids=to_worker_indices,
                 # We can only sync to healthy remote workers.
                 # Restored workers need to have local work state synced over first,
                 # before they will have all the policies to receive these weights.
                 healthy_only=True,
                 timeout_seconds=timeout_seconds,
             )

         # If `from_worker` is provided, also sync to this WorkerSet's

Review thread on `w.set_weights(weights, global_vars)`:

Reviewer: If it's an object ref, would this still work?

Author: `from_worker` has to be a local RolloutWorker. The reason we have `from_worker` is that, oftentimes, the evaluation WorkerSet doesn't have a local worker to sync from. Also, the reason I got rid of `ray.get`/`ray.put` here is that, when testing everything, I noticed slight improvements if we don't force a `ray.put` on every single weights dict. It seems like Ray Core may optimize this, e.g., if all the remote workers are on the same instance, it can skip serialization and simply copy over the data. Need to confirm this, though.
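The trade-off the author describes can be sketched with plain Ray (the `consume` task below is hypothetical, not RLlib code): an explicit `ray.put` creates one shared object-store copy up front, while passing the dict directly leaves serialization decisions to Ray Core.

```python
# Hedged sketch of explicit ray.put vs. passing a weights dict directly.
import ray

ray.init()


@ray.remote
def consume(weights):
    return len(weights)


big = {f"layer_{i}": list(range(1_000)) for i in range(100)}

# Variant A (the old code path): put the dict into the object store once
# and hand every task the same ObjectRef; Ray resolves it automatically.
ref = ray.put(big)
results_a = ray.get([consume.remote(ref) for _ in range(8)])

# Variant B (the new code path): pass the dict as a plain argument and let
# Ray Core decide how to ship it -- per the author's measurements, this
# was slightly faster in practice for weight dicts.
results_b = ray.get([consume.remote(big) for _ in range(8)])
```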
@@ -655,7 +658,7 @@ def foreach_worker(
         # TODO(jungong) : switch to True once Algorithm is migrated.
         healthy_only=False,
         remote_worker_ids: List[int] = None,
-        timeout_seconds=None,
+        timeout_seconds: Optional[int] = None,
         return_obj_refs: bool = False,
     ) -> List[T]:
         """Calls the given function with each worker instance as the argument.
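Hypothetical usage of the annotated signature above, assuming an already-constructed WorkerSet named `workers` (not taken from the PR itself):

```python
# Wait up to 30 seconds for each healthy remote worker to return its
# current weights; results from slower workers are simply not included.
all_weights = workers.foreach_worker(
    func=lambda w: w.get_weights(),
    local_worker=False,
    timeout_seconds=30,
)
```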
@@ -706,7 +709,7 @@ def foreach_worker_with_id(
         # TODO(jungong) : switch to True once Algorithm is migrated.
         healthy_only=False,
         remote_worker_ids: List[int] = None,
-        timeout_seconds=None,
+        timeout_seconds: Optional[int] = None,
     ) -> List[T]:
         """Similar to foreach_worker(), but calls the function with id of the worker too.
@@ -777,7 +780,7 @@ def foreach_worker_async(
     def fetch_ready_async_reqs(
         self,
         *,
-        timeout_seconds=0,
+        timeout_seconds: Optional[int] = 0,
         return_obj_refs: bool = False,
     ) -> List[Tuple[int, T]]:
         """Get results from outstanding asynchronous requests that are ready.
Review thread on `timeout_seconds: Optional[int] = 0`:

Reviewer: What would None do then?

Author: It would wait indefinitely until all the object_refs are ready. This is actually all standard `ray.wait` behavior; see the documentation for `ray.wait(timeout=...)`: https://docs.ray.io/en/latest/ray-core/package-ref.html#ray-wait. The default timeout there is None.
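The three timeout modes the thread refers to are standard `ray.wait` semantics; a quick illustration (the task body is a stand-in):

```python
# ray.wait's timeout argument: 0 polls, a positive value bounds the wait,
# and None (Ray's default) blocks until `num_returns` refs are ready.
import time

import ray

ray.init()


@ray.remote
def slow_task():
    time.sleep(5)
    return "done"


refs = [slow_task.remote() for _ in range(3)]

# timeout=0: return immediately with whatever is already finished.
ready, pending = ray.wait(refs, num_returns=3, timeout=0)

# timeout=2: wait at most 2 seconds, then return what is ready so far.
ready, pending = ray.wait(refs, num_returns=3, timeout=2)

# timeout=None: block until all three tasks have completed.
ready, pending = ray.wait(refs, num_returns=3, timeout=None)
```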