[RLlib] To improve performance, do not wait for sync weight calls by default. #30509

Merged: 3 commits, Nov 20, 2022
24 changes: 12 additions & 12 deletions rllib/algorithms/a3c/a3c.py
@@ -218,6 +218,7 @@ def sample_and_compute_grads(worker: RolloutWorker) -> Dict[str, Any]:
         # update that particular worker's weights.
         global_vars = None
         learner_info_builder = LearnerInfoBuilder(num_devices=1)
+        to_sync_workers = set()
         for worker_id, result in async_results:
             # Apply gradients to local worker.
             with self._timers[APPLY_GRADS_TIMER]:
@@ -237,18 +238,17 @@ def sample_and_compute_grads(worker: RolloutWorker) -> Dict[str, Any]:
"timestep": self._counters[NUM_AGENT_STEPS_SAMPLED],
}

# Synch updated weights back to the particular worker
# (only those policies that are trainable).
with self._timers[SYNCH_WORKER_WEIGHTS_TIMER]:
self.workers.sync_weights(
policies=local_worker.get_policies_to_train(),
to_worker_indices=[worker_id],
global_vars=global_vars,
)

# Update global vars of the local worker.
if global_vars:
local_worker.set_global_vars(global_vars)
# Add this worker to be synced.
to_sync_workers.add(worker_id)

# Synch updated weights back to the particular worker
# (only those policies that are trainable).
with self._timers[SYNCH_WORKER_WEIGHTS_TIMER]:
self.workers.sync_weights(
policies=local_worker.get_policies_to_train(),
to_worker_indices=list(to_sync_workers),
global_vars=global_vars,
)

return learner_info_builder.finalize()

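Outside the diff view, the control-flow change is easy to miss, so here is a minimal, self-contained sketch of the pattern (hypothetical names, not RLlib code): gradients are still applied per worker inside the loop, but the weight sync is batched into one call afterwards.

    from typing import Any, Dict, List, Set, Tuple


    def apply_grads_and_sync(
        async_results: List[Tuple[int, Dict[str, Any]]],
        sync_weights_fn,  # hypothetical stand-in for WorkerSet.sync_weights
    ) -> None:
        # Collect the ids of all workers whose gradients we consumed ...
        to_sync_workers: Set[int] = set()
        for worker_id, _result in async_results:
            # (apply this worker's gradients to the local model here)
            to_sync_workers.add(worker_id)
        # ... then issue a single batched sync instead of one call per worker.
        sync_weights_fn(to_worker_indices=list(to_sync_workers))


    # Tiny demo with a fake sync function:
    apply_grads_and_sync(
        async_results=[(1, {}), (2, {}), (1, {})],
        sync_weights_fn=lambda to_worker_indices: print(
            "syncing", sorted(to_worker_indices)
        ),
    )

Each worker id is synced exactly once per training iteration, regardless of how many results it contributed.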
21 changes: 12 additions & 9 deletions rllib/evaluation/worker_set.py
@@ -16,7 +16,6 @@
     Union,
 )

-import ray
 from ray.actor import ActorHandle
 from ray.exceptions import RayActorError
 from ray.rllib.evaluation.rollout_worker import RolloutWorker
@@ -385,18 +384,23 @@ def sync_weights(
         from_worker: Optional[RolloutWorker] = None,
         to_worker_indices: Optional[List[int]] = None,
         global_vars: Optional[Dict[str, TensorType]] = None,
+        timeout_seconds: Optional[int] = 0,
     ) -> None:
         """Syncs model weights from the local worker to all remote workers.

         Args:
             policies: Optional list of PolicyIDs to sync weights for.
                 If None (default), sync weights to/from all policies.
-            from_worker: Optional RolloutWorker instance to sync from.
+            from_worker: Optional local RolloutWorker instance to sync from.
                 If None (default), sync from this WorkerSet's local worker.
             to_worker_indices: Optional list of worker indices to sync the
                 weights to. If None (default), sync to all remote workers.
             global_vars: An optional global vars dict to set this
                 worker to. If None, do not update the global_vars.
+            timeout_seconds: Timeout in seconds to wait for the sync weights
+                calls to complete. Default is 0 (sync-and-forget, do not wait
+                for any sync calls to finish). This significantly improves
+                algorithm performance.
         """

[Review thread on the new timeout_seconds argument]

Contributor: What would None do then?

Member (Author): It would wait indefinitely, until all the object_refs are
ready. This is actually all standard ray.wait behavior; look at the
documentation for ray.wait(timeout=...):
https://docs.ray.io/en/latest/ray-core/package-ref.html#ray-wait
The default timeout is None.
         if self.local_worker() is None and from_worker is None:
             raise TypeError(
@@ -408,21 +412,20 @@ def sync_weights(
         weights = None
         if self.num_remote_workers() or from_worker is not None:
             weights = (from_worker or self.local_worker()).get_weights(policies)
-            # Put weights only once into object store and use same object
-            # ref to synch to all workers.
-            weights_ref = ray.put(weights)

             def set_weight(w):
-                w.set_weights(ray.get(weights_ref), global_vars)
+                w.set_weights(weights, global_vars)
[Review thread on set_weight]

Contributor: Would from_worker.get_weights(policies) return an object ref, or
the actual weights? If it's an object ref, would set_weight work for the case
where from_worker is a remote worker? If it would, do we test this behavior in
a unit test?

Member (Author): from_worker has to be a local RolloutWorker, so weights here
must be raw weights. The reason we have from_worker is that oftentimes,
evaluation_worker_set doesn't have a local worker to sync from; you need to
sync weights from rollout_workers.local_worker() to
evaluation_workers.remote_workers().

Also, the reason I got rid of ray.get/put here is that, when testing
everything, I noticed some slight improvements if we don't force ray.put on
every single weights dict. It seems like Ray core may optimize things: if all
the remote workers are on the same instance, it can skip serialization and
simply copy over the data. Need to confirm this, though.
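A toy, self-contained contrast of the two styles discussed in this thread. FakeWorker is a hypothetical stand-in, and plain actor calls replace the closure that foreach_worker actually ships; the co-location optimization mentioned above is the author's unconfirmed hypothesis.

    import ray


    @ray.remote
    class FakeWorker:
        def set_weights(self, weights: dict) -> None:
            self.weights = weights

        def get_weights(self) -> dict:
            return self.weights


    ray.init()
    workers = [FakeWorker.remote() for _ in range(3)]
    weights = {"layer_0": [0.1, 0.2], "layer_1": [0.3]}

    # Before this PR: put the weights into the object store once and share
    # the ref across all calls (Ray resolves the ref inside each call).
    weights_ref = ray.put(weights)
    ray.get([w.set_weights.remote(weights_ref) for w in workers])

    # After this PR: pass the raw dict and let Ray core decide how to ship
    # it, e.g. possibly skipping serialization for co-located workers.
    ray.get([w.set_weights.remote(weights) for w in workers])

    print(ray.get(workers[0].get_weights.remote()))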


             # Sync to specified remote workers in this WorkerSet.
             self.foreach_worker(
                 func=set_weight,
                 local_worker=False,  # Do not sync back to local worker.
                 remote_worker_ids=to_worker_indices,
                 # We can only sync to healthy remote workers.
                 # Restored workers need to have local work state synced over first,
                 # before they will have all the policies to receive these weights.
                 healthy_only=True,
+                timeout_seconds=timeout_seconds,
             )

         # If `from_worker` is provided, also sync to this WorkerSet's
@@ -655,7 +658,7 @@ def foreach_worker(
         # TODO(jungong) : switch to True once Algorithm is migrated.
         healthy_only=False,
         remote_worker_ids: List[int] = None,
-        timeout_seconds=None,
+        timeout_seconds: Optional[int] = None,
         return_obj_refs: bool = False,
     ) -> List[T]:
         """Calls the given function with each worker instance as the argument.
@@ -706,7 +709,7 @@ def foreach_worker_with_id(
         # TODO(jungong) : switch to True once Algorithm is migrated.
         healthy_only=False,
         remote_worker_ids: List[int] = None,
-        timeout_seconds=None,
+        timeout_seconds: Optional[int] = None,
     ) -> List[T]:
         """Similar to foreach_worker(), but calls the function with id of the worker too.
@@ -777,7 +780,7 @@ def foreach_worker_async(
     def fetch_ready_async_reqs(
         self,
         *,
-        timeout_seconds=0,
+        timeout_seconds: Optional[int] = 0,
         return_obj_refs: bool = False,
     ) -> List[Tuple[int, T]]:
         """Get results from outstanding asynchronous requests that are ready.
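To close, a sketch (plain Ray, not the RLlib implementation) of the non-blocking polling pattern that a zero timeout_seconds default enables in fetch_ready_async_reqs(): collect only the requests that have already finished and leave the rest outstanding.

    import time
    from typing import List, Tuple

    import ray


    @ray.remote
    def sample(worker_id: int) -> Tuple[int, str]:
        time.sleep(worker_id * 0.5)  # workers finish at different times
        return worker_id, f"batch-from-{worker_id}"


    ray.init()
    in_flight = [sample.remote(i) for i in range(4)]
    results: List[Tuple[int, str]] = []

    while in_flight:
        # timeout=0: grab only requests that are already done, never block.
        ready, in_flight = ray.wait(
            in_flight, num_returns=len(in_flight), timeout=0
        )
        results.extend(ray.get(ready))
        time.sleep(0.1)  # a real training loop would do useful work here

    print(sorted(results))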