
Commit fd687dc

[RLlib] To improve performance, do not wait for sync weight calls by default. (ray-project#30509)

Also batch weight sync calls, and skip syncing to the local worker.

Signed-off-by: Jun Gong <[email protected]>
Signed-off-by: Weichen Xu <[email protected]>
Jun Gong authored and WeichenXu123 committed Dec 19, 2022
1 parent 52b2a2d commit fd687dc
Showing 2 changed files with 24 additions and 21 deletions.
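
The commit batches A3C's per-worker weight syncs into a single call and makes that call non-blocking. The general idea — broadcasting new weights to Ray actors without waiting for acknowledgements — is sketched below with a toy, hypothetical ToyWorker actor (not RLlib's actual RolloutWorker); it only illustrates why a fire-and-forget broadcast keeps the driver from stalling, under the assumption that stale weights on a worker for a short while are acceptable.

    import ray

    ray.init()

    @ray.remote
    class ToyWorker:
        """Stand-in for a remote rollout worker that can receive weights."""

        def __init__(self):
            self.weights = None

        def set_weights(self, weights):
            self.weights = weights
            return True

    workers = [ToyWorker.remote() for _ in range(4)]
    new_weights = {"default_policy": [0.1, 0.2, 0.3]}

    # Blocking broadcast: the driver stalls until every worker acknowledges.
    ray.get([w.set_weights.remote(new_weights) for w in workers])

    # Non-blocking broadcast: launch the calls and move on immediately;
    # ready results can be harvested later without waiting (timeout=0).
    pending = [w.set_weights.remote(new_weights) for w in workers]
    ready, not_ready = ray.wait(pending, num_returns=len(pending), timeout=0)
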
24 changes: 12 additions & 12 deletions rllib/algorithms/a3c/a3c.py
@@ -218,6 +218,7 @@ def sample_and_compute_grads(worker: RolloutWorker) -> Dict[str, Any]:
# update that particular worker's weights.
global_vars = None
learner_info_builder = LearnerInfoBuilder(num_devices=1)
to_sync_workers = set()
for worker_id, result in async_results:
# Apply gradients to local worker.
with self._timers[APPLY_GRADS_TIMER]:
@@ -237,18 +238,17 @@ def sample_and_compute_grads(worker: RolloutWorker) -> Dict[str, Any]:
"timestep": self._counters[NUM_AGENT_STEPS_SAMPLED],
}

# Synch updated weights back to the particular worker
# (only those policies that are trainable).
with self._timers[SYNCH_WORKER_WEIGHTS_TIMER]:
self.workers.sync_weights(
policies=local_worker.get_policies_to_train(),
to_worker_indices=[worker_id],
global_vars=global_vars,
)

# Update global vars of the local worker.
if global_vars:
local_worker.set_global_vars(global_vars)
# Add this worker to be synced.
to_sync_workers.add(worker_id)

# Synch updated weights back to the particular worker
# (only those policies that are trainable).
with self._timers[SYNCH_WORKER_WEIGHTS_TIMER]:
self.workers.sync_weights(
policies=local_worker.get_policies_to_train(),
to_worker_indices=list(to_sync_workers),
global_vars=global_vars,
)

return learner_info_builder.finalize()

21 changes: 12 additions & 9 deletions rllib/evaluation/worker_set.py
@@ -16,7 +16,6 @@
Union,
)

import ray
from ray.actor import ActorHandle
from ray.exceptions import RayActorError
from ray.rllib.evaluation.rollout_worker import RolloutWorker
@@ -385,18 +384,23 @@ def sync_weights(
from_worker: Optional[RolloutWorker] = None,
to_worker_indices: Optional[List[int]] = None,
global_vars: Optional[Dict[str, TensorType]] = None,
timeout_seconds: Optional[int] = 0,
) -> None:
"""Syncs model weights from the local worker to all remote workers.
Args:
policies: Optional list of PolicyIDs to sync weights for.
If None (default), sync weights to/from all policies.
from_worker: Optional RolloutWorker instance to sync from.
from_worker: Optional local RolloutWorker instance to sync from.
If None (default), sync from this WorkerSet's local worker.
to_worker_indices: Optional list of worker indices to sync the
weights to. If None (default), sync to all remote workers.
global_vars: An optional global vars dict to set this
worker to. If None, do not update the global_vars.
timeout_seconds: Timeout in seconds to wait for the sync weights
calls to complete. Default is 0 (sync-and-forget, do not wait
for any sync calls to finish). This significantly improves
algorithm performance.
"""
if self.local_worker() is None and from_worker is None:
raise TypeError(
@@ -408,21 +412,20 @@ def sync_weights(
weights = None
if self.num_remote_workers() or from_worker is not None:
weights = (from_worker or self.local_worker()).get_weights(policies)
# Put weights only once into object store and use same object
# ref to synch to all workers.
weights_ref = ray.put(weights)

def set_weight(w):
w.set_weights(ray.get(weights_ref), global_vars)
w.set_weights(weights, global_vars)

# Sync to specified remote workers in this WorkerSet.
self.foreach_worker(
func=set_weight,
local_worker=False, # Do not sync back to local worker.
remote_worker_ids=to_worker_indices,
# We can only sync to healthy remote workers.
# Restored workers need to have local work state synced over first,
# before they will have all the policies to receive these weights.
healthy_only=True,
timeout_seconds=timeout_seconds,
)

# If `from_worker` is provided, also sync to this WorkerSet's
@@ -655,7 +658,7 @@ def foreach_worker(
# TODO(jungong) : switch to True once Algorithm is migrated.
healthy_only=False,
remote_worker_ids: List[int] = None,
timeout_seconds=None,
timeout_seconds: Optional[int] = None,
return_obj_refs: bool = False,
) -> List[T]:
"""Calls the given function with each worker instance as the argument.
@@ -706,7 +709,7 @@ def foreach_worker_with_id(
# TODO(jungong) : switch to True once Algorithm is migrated.
healthy_only=False,
remote_worker_ids: List[int] = None,
timeout_seconds=None,
timeout_seconds: Optional[int] = None,
) -> List[T]:
"""Similar to foreach_worker(), but calls the function with id of the worker too.
@@ -777,7 +780,7 @@ def foreach_worker_async(
def fetch_ready_async_reqs(
self,
*,
timeout_seconds=0,
timeout_seconds: Optional[int] = 0,
return_obj_refs: bool = False,
) -> List[Tuple[int, T]]:
"""Get esults from outstanding asynchronous requests that are ready.

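For reference, a hypothetical usage sketch of WorkerSet.sync_weights() with the new timeout_seconds argument follows. It assumes `workers` is an already-constructed RLlib WorkerSet (construction is omitted) and only illustrates the non-blocking default versus a blocking call, mirroring the parameters shown in the diff above.

    # Assumes `workers` is an existing RLlib WorkerSet (not constructed here).
    local_worker = workers.local_worker()

    # New default behavior: do not wait for the remote sync calls to finish.
    workers.sync_weights(
        policies=local_worker.get_policies_to_train(),
        to_worker_indices=[1, 2, 3],     # only sync to these remote workers
        global_vars={"timestep": 1000},
        timeout_seconds=0,
    )

    # Blocking variant: wait up to 10 seconds for the sync calls to complete.
    workers.sync_weights(timeout_seconds=10)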