[RLlib] Add "shuffle batch per epoch" option. #47458
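This PR renames IMPALA's `num_sgd_iter` to `num_epochs`, retires the IMPALA-specific `minibatch_size`/`"auto"` handling (the "unset" sentinel becomes `None` on the base config), and threads a new `shuffle_batch_per_epoch` flag through to the learner group. A minimal usage sketch, assuming the new-stack `IMPALAConfig` builder API (the class may be named `ImpalaConfig` in older Ray versions; the environment and concrete values here are illustrative only):

```python
from ray.rllib.algorithms.impala import IMPALAConfig

# Sketch: split each train batch into minibatches, iterate over it for
# several epochs, and reshuffle the batch before every epoch.
config = (
    IMPALAConfig()
    .environment("CartPole-v1")
    .training(
        train_batch_size_per_learner=500,
        num_epochs=2,                  # formerly `num_sgd_iter`
        minibatch_size=250,            # must be a multiple of `rollout_fragment_length`
        shuffle_batch_per_epoch=True,  # the new option added by this PR
    )
)
algo = config.build()
print(algo.train())
```

With `shuffle_batch_per_epoch=True`, the train batch is reshuffled before being sliced into minibatches on each epoch, so successive epochs do not replay minibatches in the same order.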
Changes from all commits: 596a4d8, 06ec0d1, 38f0d99, ea8075f, 585095d, 4e1e42e, 61c3f20, a20f44c, b966d99, be6c2e5, 42535d4, 292c71f, 1f748f1, c13647a, cd38695, 4f36d7a, 927ba3d, bc552f0, f56c255, 24cc69e, 3264f9c, 804bfc2, a79630a, 74b7a58, 7bdab98, c26ae5d
```diff
@@ -134,7 +134,6 @@ def __init__(self, algo_class=None):
         self.vtrace_clip_pg_rho_threshold = 1.0
         self.num_multi_gpu_tower_stacks = 1  # @OldAPIstack
         self.minibatch_buffer_size = 1  # @OldAPIstack
-        self.num_sgd_iter = 1
         self.replay_proportion = 0.0  # @OldAPIstack
         self.replay_buffer_num_slots = 0  # @OldAPIstack
         self.learner_queue_size = 3
@@ -168,10 +167,10 @@ def __init__(self, algo_class=None):
         self._lr_vf = 0.0005  # @OldAPIstack

         # Override some of AlgorithmConfig's default values with IMPALA-specific values.
         self.num_learners = 1
         self.rollout_fragment_length = 50
         self.train_batch_size = 500  # @OldAPIstack
+        self.train_batch_size_per_learner = 500
-        self._minibatch_size = "auto"
         self.num_env_runners = 2
         self.num_gpus = 1  # @OldAPIstack
         self.lr = 0.0005
```
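The new-stack default switches from the old `train_batch_size` to a per-learner batch size. Assuming the usual RLlib convention, the effective total batch scales with the number of learners, which is what `total_train_batch_size` in the `validate()` hunk further below refers to; a tiny sketch of that assumed relationship:

```python
# Sketch: assumed relationship between the per-learner batch size and the
# total train batch size; `max(...)` covers the local-learner case where
# num_learners == 0.
num_learners = 1
train_batch_size_per_learner = 500
total_train_batch_size = train_batch_size_per_learner * max(num_learners, 1)
print(total_train_batch_size)  # -> 500 env steps per update with the defaults above
```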
```diff
@@ -200,8 +199,6 @@ def training(
         num_gpu_loader_threads: Optional[int] = NotProvided,
         num_multi_gpu_tower_stacks: Optional[int] = NotProvided,
         minibatch_buffer_size: Optional[int] = NotProvided,
-        minibatch_size: Optional[Union[int, str]] = NotProvided,
-        num_sgd_iter: Optional[int] = NotProvided,
         replay_proportion: Optional[float] = NotProvided,
         replay_buffer_num_slots: Optional[int] = NotProvided,
         learner_queue_size: Optional[int] = NotProvided,
```
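For callers, the two removed arguments map onto equivalents that are assumed to live on the base `AlgorithmConfig.training()`. A hedged before/after sketch with hypothetical values:

```python
from ray.rllib.algorithms.impala import IMPALAConfig

config = IMPALAConfig()

# Before this PR (IMPALA-specific arguments):
# config.training(num_sgd_iter=2, minibatch_size=250)

# After this PR (assumed base-AlgorithmConfig arguments):
config.training(num_epochs=2, minibatch_size=250, shuffle_batch_per_epoch=True)
```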
```diff
@@ -252,15 +249,7 @@ def training(
                 - This enables us to preload data into these stacks while another stack
                 is performing gradient calculations.
             minibatch_buffer_size: How many train batches should be retained for
-                minibatching. This conf only has an effect if `num_sgd_iter > 1`.
-            minibatch_size: The size of minibatches that are trained over during
-                each SGD iteration. If "auto", will use the same value as
-                `train_batch_size`.
-                Note that this setting only has an effect if
-                `enable_rl_module_and_learner=True` and it must be a multiple of
-                `rollout_fragment_length` or `sequence_length` and smaller than or equal
-                to `train_batch_size`.
-            num_sgd_iter: Number of passes to make over each train batch.
+                minibatching. This conf only has an effect if `num_epochs > 1`.
             replay_proportion: Set >0 to enable experience replay. Saved samples will
                 be replayed with a p:1 proportion to new data samples.
             replay_buffer_num_slots: Number of sample batches to store for replay.
@@ -330,8 +319,6 @@ def training(
             self.num_multi_gpu_tower_stacks = num_multi_gpu_tower_stacks
         if minibatch_buffer_size is not NotProvided:
             self.minibatch_buffer_size = minibatch_buffer_size
-        if num_sgd_iter is not NotProvided:
-            self.num_sgd_iter = num_sgd_iter
         if replay_proportion is not NotProvided:
             self.replay_proportion = replay_proportion
         if replay_buffer_num_slots is not NotProvided:
```
```diff
@@ -374,8 +361,6 @@ def training(
             self._separate_vf_optimizer = _separate_vf_optimizer
         if _lr_vf is not NotProvided:
             self._lr_vf = _lr_vf
-        if minibatch_size is not NotProvided:
-            self._minibatch_size = minibatch_size

         return self

```
```diff
@@ -452,14 +437,14 @@ def validate(self) -> None:
         # Learner API specific checks.
         if (
             self.enable_rl_module_and_learner
-            and self._minibatch_size != "auto"
+            and self.minibatch_size is not None
             and not (
                 (self.minibatch_size % self.rollout_fragment_length == 0)
                 and self.minibatch_size <= self.total_train_batch_size
             )
         ):
             raise ValueError(
-                f"`minibatch_size` ({self._minibatch_size}) must either be 'auto' "
+                f"`minibatch_size` ({self._minibatch_size}) must either be None "
                 "or a multiple of `rollout_fragment_length` "
                 f"({self.rollout_fragment_length}) while at the same time smaller "
                 "than or equal to `total_train_batch_size` "
```
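The check's arithmetic is unchanged; only the "unset" sentinel moves from `"auto"` to `None`. A standalone worked example of the constraint, using the defaults from this diff (a re-statement of the condition above, not RLlib code):

```python
# Re-statement of the validate() condition with this diff's default values.
rollout_fragment_length = 50
total_train_batch_size = 500

def minibatch_size_ok(minibatch_size):
    # None now means "no minibatching", so it always passes the check.
    if minibatch_size is None:
        return True
    return (
        minibatch_size % rollout_fragment_length == 0
        and minibatch_size <= total_train_batch_size
    )

print(minibatch_size_ok(None))  # True
print(minibatch_size_ok(250))   # True: multiple of 50 and <= 500
print(minibatch_size_ok(120))   # False: not a multiple of 50
print(minibatch_size_ok(600))   # False: larger than the total train batch
```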
```diff
@@ -474,20 +459,6 @@ def replay_ratio(self) -> float:
         """
         return (1 / self.replay_proportion) if self.replay_proportion > 0 else 0.0

-    @property
-    def minibatch_size(self):
-        # If 'auto', use the train_batch_size (meaning each SGD iter is a single pass
-        # through the entire train batch). Otherwise, use user provided setting.
-        return (
-            (
-                self.train_batch_size_per_learner
-                if self.enable_env_runner_and_connector_v2
-                else self.train_batch_size
-            )
-            if self._minibatch_size == "auto"
-            else self._minibatch_size
-        )
-
     @override(AlgorithmConfig)
     def get_default_learner_class(self):
         if self.framework_str == "torch":
```
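With the property gone, the `"auto"` sentinel disappears as well. Presumably, leaving `minibatch_size` at its new default of `None` reproduces the old behavior, i.e. each epoch is one full pass over the train batch; the sketch below re-states the removed resolution logic next to that assumption:

```python
def old_effective_minibatch_size(_minibatch_size, train_batch_size_per_learner):
    # Re-statement of the removed property (new-stack branch): "auto"
    # resolved to the full per-learner train batch size.
    return (
        train_batch_size_per_learner
        if _minibatch_size == "auto"
        else _minibatch_size
    )

print(old_effective_minibatch_size("auto", 500))  # -> 500: one full pass per epoch
# Assumed new-stack equivalent after this PR: config.training(minibatch_size=None)
```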
```diff
@@ -539,7 +510,7 @@ class IMPALA(Algorithm):
     2. If enabled, the replay buffer stores and produces batches of size
        `rollout_fragment_length * num_envs_per_env_runner`.
     3. If enabled, the minibatch ring buffer stores and replays batches of
-       size `train_batch_size` up to `num_sgd_iter` times per batch.
+       size `train_batch_size` up to `num_epochs` times per batch.
     4. The learner thread executes data parallel SGD across `num_gpus` GPUs
        on batches of size `train_batch_size`.
     """
```
```diff
@@ -734,6 +705,9 @@ def training_step(self) -> ResultDict:
                         NUM_ENV_STEPS_SAMPLED_LIFETIME, default=0
                     ),
                 },
+                num_epochs=self.config.num_epochs,
+                minibatch_size=self.config.minibatch_size,
+                shuffle_batch_per_epoch=self.config.shuffle_batch_per_epoch,
             )
         else:
             learner_results = self.learner_group.update_from_episodes(
```
Review discussion on this hunk:

- "I wonder: isn't it possible to just turn over a"
- "We could run all of this (in the new stack) through the"
```diff
@@ -745,6 +719,9 @@ def training_step(self) -> ResultDict:
                         NUM_ENV_STEPS_SAMPLED_LIFETIME, default=0
                     ),
                 },
+                num_epochs=self.config.num_epochs,
+                minibatch_size=self.config.minibatch_size,
+                shuffle_batch_per_epoch=self.config.shuffle_batch_per_epoch,
             )
             if not do_async_updates:
                 learner_results = [learner_results]
```
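Both code paths now forward the same three settings to `LearnerGroup.update_from_batch()` / `update_from_episodes()`. Conceptually, the loop they control looks like the sketch below (an illustration of the intended semantics, not RLlib's actual learner code; a batch is modeled as a plain list of rows):

```python
import random

def epoch_minibatch_loop(batch, num_epochs, minibatch_size, shuffle_batch_per_epoch):
    """Illustrative epoch/minibatch loop over a batch given as a list of rows."""
    for epoch in range(num_epochs):
        if shuffle_batch_per_epoch:
            # The new option: reshuffle the whole batch before slicing it into
            # minibatches, so epochs do not revisit rows in the same order.
            random.shuffle(batch)
        size = minibatch_size or len(batch)  # None -> one full pass per epoch
        for start in range(0, len(batch), size):
            yield epoch, batch[start : start + size]  # stand-in for one SGD step

for epoch, minibatch in epoch_minibatch_loop(
    list(range(10)), num_epochs=2, minibatch_size=5, shuffle_batch_per_epoch=True
):
    print(epoch, minibatch)
```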
```diff
@@ -1292,7 +1269,7 @@ def _learn_on_processed_samples(self) -> ResultDict:
                     ),
                 },
                 async_update=async_update,
-                num_iters=self.config.num_sgd_iter,
+                num_epochs=self.config.num_epochs,
                 minibatch_size=self.config.minibatch_size,
             )
             if not async_update:
```
```diff
@@ -1531,7 +1508,7 @@ def make_learner_thread(local_worker, config):
             lr=config["lr"],
             train_batch_size=config["train_batch_size"],
             num_multi_gpu_tower_stacks=config["num_multi_gpu_tower_stacks"],
-            num_sgd_iter=config["num_sgd_iter"],
+            num_sgd_iter=config["num_epochs"],
             learner_queue_size=config["learner_queue_size"],
             learner_queue_timeout=config["learner_queue_timeout"],
             num_data_load_threads=config["num_gpu_loader_threads"],
@@ -1540,7 +1517,7 @@ def make_learner_thread(local_worker, config):
     learner_thread = LearnerThread(
         local_worker,
         minibatch_buffer_size=config["minibatch_buffer_size"],
-        num_sgd_iter=config["num_sgd_iter"],
+        num_sgd_iter=config["num_epochs"],
         learner_queue_size=config["learner_queue_size"],
         learner_queue_timeout=config["learner_queue_timeout"],
     )
```
Review discussion:

- "Awesome! For Offline RL we might want to add here that an epoch might loop over the entire dataset?"
- "Will add!"