From fd0442f0594a855cafa0ed14d879b498817396fd Mon Sep 17 00:00:00 2001
From: Jun Gong
Date: Sat, 19 Nov 2022 03:31:45 -0800
Subject: [PATCH 1/3] [RLlib] To improve performance, do not wait for sync
 weight calls by default.

Signed-off-by: Jun Gong
---
 rllib/evaluation/worker_set.py | 18 ++++++++++--------
 1 file changed, 10 insertions(+), 8 deletions(-)

diff --git a/rllib/evaluation/worker_set.py b/rllib/evaluation/worker_set.py
index f17a4a4545d5..965bc28c1474 100644
--- a/rllib/evaluation/worker_set.py
+++ b/rllib/evaluation/worker_set.py
@@ -385,6 +385,7 @@ def sync_weights(
         from_worker: Optional[RolloutWorker] = None,
         to_worker_indices: Optional[List[int]] = None,
         global_vars: Optional[Dict[str, TensorType]] = None,
+        timeout_seconds: Optional[int] = 0,
     ) -> None:
         """Syncs model weights from the local worker to all remote workers.
 
@@ -397,6 +398,10 @@ def sync_weights(
                 weights to. If None (default), sync to all remote workers.
             global_vars: An optional global vars dict to set this worker to.
                 If None, do not update the global_vars.
+            timeout_seconds: Timeout in seconds to wait for the sync weights
+                calls to complete. Default is 0 (sync-and-forget, do not wait
+                for any sync calls to finish). This significantly improves
+                algorithm performance.
         """
         if self.local_worker() is None and from_worker is None:
             raise TypeError(
@@ -408,12 +413,8 @@ def sync_weights(
         weights = None
         if self.num_remote_workers() or from_worker is not None:
             weights = (from_worker or self.local_worker()).get_weights(policies)
-            # Put weights only once into object store and use same object
-            # ref to synch to all workers.
-            weights_ref = ray.put(weights)
-
             def set_weight(w):
-                w.set_weights(ray.get(weights_ref), global_vars)
+                w.set_weights(weights, global_vars)
 
             # Sync to specified remote workers in this WorkerSet.
             self.foreach_worker(
@@ -423,6 +424,7 @@ def set_weight(w):
                 # Restored workers need to have local work state synced over first,
                 # before they will have all the policies to receive these weights.
                 healthy_only=True,
+                timeout_seconds=timeout_seconds,
             )
 
         # If `from_worker` is provided, also sync to this WorkerSet's
@@ -655,7 +657,7 @@ def foreach_worker(
         # TODO(jungong) : switch to True once Algorithm is migrated.
         healthy_only=False,
         remote_worker_ids: List[int] = None,
-        timeout_seconds=None,
+        timeout_seconds: Optional[int] = None,
         return_obj_refs: bool = False,
     ) -> List[T]:
         """Calls the given function with each worker instance as the argument.
@@ -706,7 +708,7 @@ def foreach_worker_with_id(
         # TODO(jungong) : switch to True once Algorithm is migrated.
         healthy_only=False,
         remote_worker_ids: List[int] = None,
-        timeout_seconds=None,
+        timeout_seconds: Optional[int] = None,
     ) -> List[T]:
         """Similar to foreach_worker(), but calls the function with id of the worker too.
 
@@ -777,7 +779,7 @@ def foreach_worker_async(
     def fetch_ready_async_reqs(
         self,
         *,
-        timeout_seconds=0,
+        timeout_seconds: Optional[int] = 0,
         return_obj_refs: bool = False,
     ) -> List[Tuple[int, T]]:
         """Get results from outstanding asynchronous requests that are ready.
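The behavioral core of this first patch is easiest to see outside of RLlib. The sketch below is plain Ray with a hypothetical `Worker` actor, not RLlib code: it shows what `timeout_seconds=0` amounts to. The `set_weights` calls are launched on all remote actors, polled once with a zero timeout, and left to complete in the background instead of being waited on with a blocking `ray.get()`.

    # Minimal sketch of the sync-and-forget pattern (plain Ray, hypothetical
    # `Worker` actor; not RLlib code). Assumes a running Ray installation.
    import ray

    ray.init()


    @ray.remote
    class Worker:
        """Stand-in for a remote RolloutWorker that can receive weights."""

        def __init__(self):
            self.weights = None

        def set_weights(self, weights):
            self.weights = weights


    workers = [Worker.remote() for _ in range(4)]
    weights = {"default_policy": [0.1, 0.2, 0.3]}  # stand-in for model weights

    # Launch the weight sync on every remote worker.
    pending = [w.set_weights.remote(weights) for w in workers]

    # timeout_seconds=0 semantics: poll once, return immediately; any calls
    # that are not yet done simply finish in the background.
    ready, in_flight = ray.wait(pending, num_returns=len(pending), timeout=0)

    # The old blocking behavior would instead be: ray.get(pending)

Note that the same patch also drops the `ray.put()` of the weights; passing `weights` directly into the closure lets Ray serialize them once per `foreach_worker()` call rather than pinning an explicit object ref.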
From a46b14f7942cb0c88292ba5cecbcc1c9a17df08a Mon Sep 17 00:00:00 2001
From: Jun Gong
Date: Sat, 19 Nov 2022 10:50:33 -0800
Subject: [PATCH 2/3] lint

Signed-off-by: Jun Gong
---
 rllib/evaluation/worker_set.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/rllib/evaluation/worker_set.py b/rllib/evaluation/worker_set.py
index 965bc28c1474..cdad19900ae6 100644
--- a/rllib/evaluation/worker_set.py
+++ b/rllib/evaluation/worker_set.py
@@ -16,7 +16,6 @@
     Union,
 )
 
-import ray
 from ray.actor import ActorHandle
 from ray.exceptions import RayActorError
 from ray.rllib.evaluation.rollout_worker import RolloutWorker
@@ -392,7 +391,7 @@ def sync_weights(
         Args:
             policies: Optional list of PolicyIDs to sync weights for.
                 If None (default), sync weights to/from all policies.
-            from_worker: Optional RolloutWorker instance to sync from.
+            from_worker: Optional local RolloutWorker instance to sync from.
                 If None (default), sync from this WorkerSet's local worker.
             to_worker_indices: Optional list of worker indices to sync the
                 weights to. If None (default), sync to all remote workers.
@@ -413,6 +412,7 @@ def sync_weights(
         weights = None
         if self.num_remote_workers() or from_worker is not None:
             weights = (from_worker or self.local_worker()).get_weights(policies)
+
             def set_weight(w):
                 w.set_weights(weights, global_vars)
 

From c8bd405ce1576a9af3194adf79f12908bb866d66 Mon Sep 17 00:00:00 2001
From: Jun Gong
Date: Sat, 19 Nov 2022 12:27:48 -0800
Subject: [PATCH 3/3] Batch weight sync calls. Skip synching to local worker.

Signed-off-by: Jun Gong
---
 rllib/algorithms/a3c/a3c.py    | 24 ++++++++++++------------
 rllib/evaluation/worker_set.py |  1 +
 2 files changed, 13 insertions(+), 12 deletions(-)

diff --git a/rllib/algorithms/a3c/a3c.py b/rllib/algorithms/a3c/a3c.py
index 752e234e27c3..e6a119dd10c6 100644
--- a/rllib/algorithms/a3c/a3c.py
+++ b/rllib/algorithms/a3c/a3c.py
@@ -218,6 +218,7 @@ def sample_and_compute_grads(worker: RolloutWorker) -> Dict[str, Any]:
         # update that particular worker's weights.
         global_vars = None
         learner_info_builder = LearnerInfoBuilder(num_devices=1)
+        to_sync_workers = set()
         for worker_id, result in async_results:
             # Apply gradients to local worker.
             with self._timers[APPLY_GRADS_TIMER]:
@@ -237,18 +238,17 @@ def sample_and_compute_grads(worker: RolloutWorker) -> Dict[str, Any]:
                     "timestep": self._counters[NUM_AGENT_STEPS_SAMPLED],
                 }
 
-            # Synch updated weights back to the particular worker
-            # (only those policies that are trainable).
-            with self._timers[SYNCH_WORKER_WEIGHTS_TIMER]:
-                self.workers.sync_weights(
-                    policies=local_worker.get_policies_to_train(),
-                    to_worker_indices=[worker_id],
-                    global_vars=global_vars,
-                )
-
-        # Update global vars of the local worker.
-        if global_vars:
-            local_worker.set_global_vars(global_vars)
+            # Add this worker to be synced.
+            to_sync_workers.add(worker_id)
+
+        # Synch updated weights back to the particular worker
+        # (only those policies that are trainable).
+        with self._timers[SYNCH_WORKER_WEIGHTS_TIMER]:
+            self.workers.sync_weights(
+                policies=local_worker.get_policies_to_train(),
+                to_worker_indices=list(to_sync_workers),
+                global_vars=global_vars,
+            )
 
         return learner_info_builder.finalize()
 

diff --git a/rllib/evaluation/worker_set.py b/rllib/evaluation/worker_set.py
index cdad19900ae6..44bf5b0d7c54 100644
--- a/rllib/evaluation/worker_set.py
+++ b/rllib/evaluation/worker_set.py
@@ -419,6 +419,7 @@ def set_weight(w):
             # Sync to specified remote workers in this WorkerSet.
             self.foreach_worker(
                 func=set_weight,
+                local_worker=False,  # Do not sync back to local worker.
                 remote_worker_ids=to_worker_indices,
                 # We can only sync to healthy remote workers.
                 # Restored workers need to have local work state synced over first,
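For reference, the effect of the a3c.py change in PATCH 3/3 can be reduced to a small standalone sketch. The helper names below (`apply_gradients`, `sync_weights`, `training_step`) are hypothetical stand-ins, not the actual Algorithm code: worker ids are accumulated in a set during the gradient-application loop, and one batched `sync_weights()` call after the loop replaces the per-worker call, so a worker that reports gradients multiple times in a step is synced only once.

    # Standalone sketch of the batching change (hypothetical helpers; not the
    # actual RLlib Algorithm code).
    from typing import Any, Dict, List, Tuple


    def apply_gradients(result: Dict[str, Any]) -> None:
        """Stand-in for applying a worker's gradients on the local worker."""


    def sync_weights(to_worker_indices: List[int]) -> None:
        """Stand-in for WorkerSet.sync_weights(); one broadcast per step."""
        print(f"syncing weights to workers {sorted(to_worker_indices)}")


    def training_step(async_results: List[Tuple[int, Dict[str, Any]]]) -> None:
        to_sync_workers = set()
        for worker_id, result in async_results:
            apply_gradients(result)
            # Old behavior: sync_weights(to_worker_indices=[worker_id]) right
            # here, i.e. one blocking broadcast per returned gradient.
            to_sync_workers.add(worker_id)
        # New behavior: a single batched call; duplicate worker ids collapse.
        if to_sync_workers:
            sync_weights(to_worker_indices=list(to_sync_workers))


    # Worker 1 reports twice but receives the new weights only once.
    training_step([(1, {}), (2, {}), (1, {})])

Combined with the `local_worker=False` argument above (the local worker already holds the freshly updated weights, so syncing back to it is redundant), each training step now issues exactly one non-blocking weight broadcast to the remote workers that need it.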