diff --git a/rllib/algorithms/apex_dqn/apex_dqn.py b/rllib/algorithms/apex_dqn/apex_dqn.py
index 6940230168ca..a876e1179784 100644
--- a/rllib/algorithms/apex_dqn/apex_dqn.py
+++ b/rllib/algorithms/apex_dqn/apex_dqn.py
@@ -13,7 +13,6 @@
 """  # noqa: E501
 import copy
 import platform
-import queue
 import random
 from collections import defaultdict
 from typing import Callable, Dict, List, Optional, Type
@@ -413,7 +412,6 @@ def setup(self, config: PartialAlgorithmConfigDict):
         weights = self.workers.local_worker().get_weights()
         self.curr_learner_weights = ray.put(weights)
         self.curr_num_samples_collected = 0
-        self.replay_sample_batches = []
         self._num_ts_trained_since_last_target_update = 0
 
     @classmethod
@@ -563,14 +561,15 @@ def wait_on_replay_actors() -> None:
             If the timeout is None, then block on the actors indefinitely.
             """
             _replay_samples_ready = self._replay_actor_manager.get_ready()
-
+            replay_sample_batches = []
             for _replay_actor, _sample_batches in _replay_samples_ready.items():
                 for _sample_batch in _sample_batches:
-                    self.replay_sample_batches.append((_replay_actor, _sample_batch))
+                    replay_sample_batches.append((_replay_actor, _sample_batch))
+            return replay_sample_batches
 
         num_samples_collected = sum(num_samples_collected.values())
         self.curr_num_samples_collected += num_samples_collected
-        wait_on_replay_actors()
+        replay_sample_batches = wait_on_replay_actors()
         if self.curr_num_samples_collected >= self.config["train_batch_size"]:
             training_intensity = int(self.config["training_intensity"] or 1)
             num_requests_to_launch = (
@@ -583,26 +582,17 @@ def wait_on_replay_actors() -> None:
                     lambda actor, num_items: actor.sample(num_items),
                     fn_args=[self.config["train_batch_size"]],
                 )
-            wait_on_replay_actors()
+            replay_sample_batches.extend(wait_on_replay_actors())
 
         # add the sample batches to the learner queue
-        while self.replay_sample_batches:
-            try:
-                item = self.replay_sample_batches[0]
-                # the replay buffer returns none if it has not been filled to
-                # the minimum threshold yet.
-                if item:
-                    # Setting block = True prevents the learner thread,
-                    # the main thread, and the gpu loader threads from
-                    # thrashing when there are more samples than the
-                    # learner can reasonable process.
-                    # see https://github.com/ray-project/ray/pull/26581#issuecomment-1187877674  # noqa
-                    self.learner_thread.inqueue.put(
-                        self.replay_sample_batches[0], block=True
-                    )
-                    self.replay_sample_batches.pop(0)
-            except queue.Full:
-                break
+        for item in replay_sample_batches:
+            # Setting block = True prevents the learner thread,
+            # the main thread, and the gpu loader threads from
+            # thrashing when there are more samples than the
+            # learner can reasonable process.
+            # see https://github.com/ray-project/ray/pull/26581#issuecomment-1187877674  # noqa
+            self.learner_thread.inqueue.put(item, block=True)
+        del replay_sample_batches
 
     def update_replay_sample_priority(self) -> None:
         """Update the priorities of the sample batches with new priorities that are
diff --git a/rllib/algorithms/dqn/learner_thread.py b/rllib/algorithms/dqn/learner_thread.py
index 60abe0f1904f..918f2d0637f6 100644
--- a/rllib/algorithms/dqn/learner_thread.py
+++ b/rllib/algorithms/dqn/learner_thread.py
@@ -74,6 +74,9 @@ def step(self):
             self.outqueue.put(
                 (replay_actor, prio_dict, ma_batch.count, ma_batch.agent_steps())
             )
-        self.learner_queue_size.push(self.inqueue.qsize())
-        self.weights_updated = True
-        self.overall_timer.push_units_processed(ma_batch and ma_batch.count or 0)
+            self.learner_queue_size.push(self.inqueue.qsize())
+            self.weights_updated = True
+            self.overall_timer.push_units_processed(
+                ma_batch and ma_batch.count or 0
+            )
+        del ma_batch
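The pattern both hunks rely on is the same: collect work into a short-lived local container, drain it fully with a blocking queue `put`, then drop the reference, instead of accumulating items on a long-lived attribute (`self.replay_sample_batches`) whose drain loop could bail out on `queue.Full` and let the backlog grow without bound. Below is a minimal, self-contained sketch of that pattern, not RLlib code: `fetch_ready_batches`, `place_on_learner_queue`, `learner_loop`, and `LEARNER_QUEUE_MAX_SIZE` are hypothetical stand-ins for `wait_on_replay_actors()`, the Ape-X sampling step, the learner thread, and the learner queue bound.

```python
import queue
import threading

LEARNER_QUEUE_MAX_SIZE = 16  # hypothetical bound; RLlib derives its own from config

learner_inqueue = queue.Queue(maxsize=LEARNER_QUEUE_MAX_SIZE)


def fetch_ready_batches():
    """Stand-in for wait_on_replay_actors(): return whatever sample
    batches the replay actors currently have ready, as a fresh list."""
    return [f"batch-{i}" for i in range(4)]


def place_on_learner_queue():
    # Collect ready batches into a *local* list rather than appending to a
    # long-lived attribute such as self.replay_sample_batches.
    replay_sample_batches = fetch_ready_batches()

    for item in replay_sample_batches:
        # block=True makes the producer wait for the learner instead of
        # accumulating an ever-growing backlog whenever the learner is
        # slower than the samplers.
        learner_inqueue.put(item, block=True)

    # Drop the reference so the batches can be garbage-collected promptly.
    del replay_sample_batches


def learner_loop():
    # Stand-in for the learner thread's step(): consume until a sentinel.
    while True:
        item = learner_inqueue.get()
        if item is None:
            break
        # ... learn_on_batch(item) would run here ...


if __name__ == "__main__":
    t = threading.Thread(target=learner_loop, daemon=True)
    t.start()
    for _ in range(3):
        place_on_learner_queue()
    learner_inqueue.put(None)
    t.join()
```

The `del ma_batch` at the end of `LearnerThread.step()` follows the same idea: the sampled batch is released as soon as the step is done rather than lingering until the local variable is rebound on the next iteration.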