[RLlib] APPO on new API stack (w/ EnvRunners). #46216
Conversation
module_id, config, mean_kl_loss_per_module[module_id]
)
@override(Learner)
def _after_gradient_based_update(self, *, timesteps: Dict[str, Any]) -> None:
With this PR, we get rid of additional_update_for_module and instead support customizing:
- Learner.before_gradient_based_update()
- Learner.after_gradient_based_update()

These get called along with the regular Learner.update() call, so we no longer have the problem of 2x metrics reduction, and we don't have to pass results from the update() call back into the additional_update call anymore (e.g. the KL values, which felt a little clumsy).

We will still have to streamline this API in the future (maybe make it per-module, give the methods better names, make them public, unify the timesteps arg format).
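For illustration, here is a minimal sketch of what hooking into the new callback could look like in a custom Learner. The class name MyCustomLearner is hypothetical, the import paths are assumptions, and the hook is still the private, underscore-prefixed _after_gradient_based_update in this PR (see the diff hunk above); this is a sketch, not the exact RLlib implementation.

```python
from typing import Any, Dict

from ray.rllib.core.learner.learner import Learner
from ray.rllib.utils.annotations import override


class MyCustomLearner(Learner):
    """Hypothetical Learner subclass using the new post-update hook."""

    @override(Learner)
    def _after_gradient_based_update(self, *, timesteps: Dict[str, Any]) -> None:
        # Call the base implementation first so built-in post-update logic
        # (e.g. KL-coeff adjustments, target net syncs) still runs.
        super()._after_gradient_based_update(timesteps=timesteps)
        # Custom logic that previously lived in `additional_update_for_module`
        # can go here; it runs right after each `update()` call, so any stats
        # logged during the update are available without passing results around.
```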
def _update_module_kl_coeff(
    self, module_id: ModuleID, config: APPOConfig, sampled_kl: float
) -> None:
def _update_module_kl_coeff(self, module_id: ModuleID, config: APPOConfig) -> None:
We'll take the KL directly from the metrics now.
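As a rough illustration of what that could look like (not the actual RLlib code): the class name MyAPPOLearner, the APPOLearner import path, the metrics.peek() accessor, and the "mean_kl_loss" key below are all assumptions; only the new two-argument signature is taken from the diff above.

```python
from ray.rllib.algorithms.appo.appo import APPOConfig
from ray.rllib.algorithms.appo.appo_learner import APPOLearner  # assumed import path
from ray.rllib.utils.typing import ModuleID


class MyAPPOLearner(APPOLearner):
    """Hypothetical subclass, shown only to illustrate the new-style hook."""

    def _update_module_kl_coeff(self, module_id: ModuleID, config: APPOConfig) -> None:
        # The KL is no longer passed in as an argument; instead, read the most
        # recently logged KL loss from the Learner's metrics. Both the `peek()`
        # accessor and the "mean_kl_loss" key are illustrative assumptions.
        sampled_kl = self.metrics.peek((module_id, "mean_kl_loss"))
        # ...then adjust this module's KL coefficient based on `sampled_kl`
        # and config.kl_target, as before.
```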
LGTM. Just a duplicate target network sync at the setup step of DQN Rainbow/SAC.
lambda mid, module: module.sync_target_networks(tau=1.0)
)
# Initially sync target networks (w/ tau=1.0 -> full overwrite).
self.module.sync_target_networks(tau=1.0)
We sync twice at the beginning now - the TorchDQNRainbowRLModule does sync in its setup().
Good catch! I was debating this with myself: should the RLModule perform the initial sync or the Learner?
Since the Learner also controls the regular syncs during training, I felt we should do it in the Learner, so it's all in one place. The RLModule itself (at least in its inference_only mode) doesn't really care about the target nets anyway.
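A minimal sketch of the "all syncing lives in the Learner" idea, under assumptions: the class name is hypothetical, foreach_module() and the soft-update tau value are illustrative, and only sync_target_networks(tau=...) is taken from the diff hunk above.

```python
from typing import Any, Dict

from ray.rllib.core.learner.learner import Learner
from ray.rllib.utils.annotations import override


class MyOffPolicyLearner(Learner):
    """Hypothetical Learner that owns both the initial and the regular syncs."""

    @override(Learner)
    def build(self) -> None:
        super().build()
        # Initial sync with tau=1.0 -> full overwrite of the target nets,
        # so the RLModule's setup() no longer needs to do this itself.
        self.module.foreach_module(
            lambda mid, module: module.sync_target_networks(tau=1.0)
        )

    @override(Learner)
    def _after_gradient_based_update(self, *, timesteps: Dict[str, Any]) -> None:
        super()._after_gradient_based_update(timesteps=timesteps)
        # Regular soft updates during training (tau value illustrative only).
        self.module.foreach_module(
            lambda mid, module: module.sync_target_networks(tau=0.005)
        )
```

Keeping both call sites in the Learner avoids the duplicate initial sync flagged above and keeps the sync schedule in one place.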
fixed
APPO on new API stack (w/ EnvRunners).
This PR:
- Target networks: the RLModule provides a get_target_net_pairs... method, then the Learner can call the sync method on the module to sync either with or without (1.0) a tau value.
- Removes additional_update entirely (replaced by a more flexible yet simpler API: before_gradient_based_update and after_gradient_based_update, which get called along with update, NOT in sequence anymore).

Why are these changes needed?
Related issue number
Checks
- I've signed off every commit (git commit -s) in this PR.
- I've run scripts/format.sh to lint the changes in this PR.
- If adding a new method in Tune, I've added it in doc/source/tune/api/ under the corresponding .rst file.