
[RLlib] Add gradient checks to avoid nan gradients in TorchLearner. #47452

Merged

Conversation

@simonsays1980 simonsays1980 commented Sep 2, 2024

Why are these changes needed?

If any gradients turn nan in TorchLearner, these gradients get added to the network's weights, and in turn the weights become nan, as do all network outputs. As a result, training errors out and stops. This PR proposes a gradient check that only applies gradients if they are sane: it either switches nan values in the gradients to zeros or skips the update entirely.
The latter can be advantageous if the training phase encounters highly unstable policy updates (e.g. with highly explorative policies or during early stages of training). In such phases many gradients could turn nan, which may corrupt internal optimizer states (e.g. Adam's).
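
The idea in plain PyTorch (a minimal sketch for illustration, not the actual RLlib code; the tensors and learning rate are made up):

    # A single non-finite gradient entry corrupts the weight it updates and,
    # through Adam's moment estimates, can poison all subsequent updates.
    import torch

    param = torch.nn.Parameter(torch.ones(3))
    optim = torch.optim.Adam([param], lr=0.1)
    param.grad = torch.tensor([0.5, float("nan"), 0.5])

    # Option 1: replace non-finite entries with zero, then step as usual.
    param.grad = torch.nan_to_num(param.grad, nan=0.0, posinf=0.0, neginf=0.0)

    # Option 2: skip the update entirely if any gradient is non-finite; this
    # also leaves the optimizer's internal state (e.g. Adam's moments) untouched.
    if all(p.grad is None or torch.isfinite(p.grad).all() for p in [param]):
        optim.step()
    optim.zero_grad()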

Related issue number

#47451

Checks

  • I've signed off every commit (by using the -s flag, i.e., git commit -s) in this PR.
  • I've run scripts/format.sh to lint the changes in this PR.
  • I've included any doc changes needed for https://docs.ray.io/en/master/.
    • I've added any new APIs to the API Reference. For example, if I added a
      method in Tune, I've added it in doc/source/tune/api/ under the
      corresponding .rst file.
  • I've made sure the tests are passing. Note that there might be a few flaky tests, see the recent failures at https://flakey-tests.ray.io/
  • Testing Strategy
    • Unit tests
    • Release tests
    • This PR is not tested :(

…in highly unstable training phases. This helps to keep the optimizer's internal state intact, which could get corrupted by many zero gradients. Furthermore, added better logging messages.

Signed-off-by: simonsays1980 <[email protected]>
Signed-off-by: simonsays1980 <[email protected]>
@simonsays1980 simonsays1980 changed the title [RLlib] - Add gradient checks to avoid 'nan' gradients in 'TorchLearner'. [RLlib] - Add gradient checks to avoid nan gradients in TorchLearner. Sep 2, 2024
@@ -176,7 +176,27 @@ def compute_gradients(
def apply_gradients(self, gradients_dict: ParamDict) -> None:
    # Set the gradient of the parameters.
    for pid, grad in gradients_dict.items():
        self._params[pid].grad = grad
        # If updates should not be skipped turn `nan` gradients to zero.

Wait, I'm confused. We have this block here further below, which I think does the exact same thing: skips the entire optim.step in case any gradient is non-finite (inf or nan).

                # `step` the optimizer (default), but only if all gradients are finite.
                elif all(
                    param.grad is None or torch.isfinite(param.grad).all()
                    for group in optim.param_groups
                    for param in group["params"]
                ):

Can you check and see whether these two logics can be consolidated?

Kind of like this:

  • If user sets this flag (default=False), the optimizer will skip the update step entirely (+ warning raised by RLlib).
  • If user does NOT set this flag (default behavior), grads that are non-finite will be set to 0.0 (+ warning raised by RLlib).
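
A rough sketch of that consolidated behavior (the flag name skip_nonfinite_updates and the function name are hypothetical, used here only for illustration):

    import logging

    import torch

    logger = logging.getLogger(__name__)

    def apply_gradients_sketch(params, optim, skip_nonfinite_updates=False):
        grads = [p.grad for p in params if p.grad is not None]
        all_finite = all(torch.isfinite(g).all() for g in grads)

        if not all_finite:
            if skip_nonfinite_updates:
                # Flag set: skip the whole update (and warn); this also keeps
                # the optimizer's internal state (e.g. Adam's moments) intact.
                logger.warning("Non-finite gradients found. Skipping this update.")
                return
            # Default: zero out non-finite entries (and warn), then step.
            logger.warning("Non-finite gradients found. Setting them to 0.0.")
            for p in params:
                if p.grad is not None:
                    p.grad = torch.nan_to_num(
                        p.grad, nan=0.0, posinf=0.0, neginf=0.0
                    )

        optim.step()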

@sven1977 sven1977 changed the title [RLlib] - Add gradient checks to avoid nan gradients in TorchLearner. [RLlib] Add gradient checks to avoid nan gradients in TorchLearner. Sep 2, 2024
…lution considers non-finite gradients and still gives the user the option to set such gradients to zero to keep the optimizer's internal state intact.

Signed-off-by: simonsays1980 <[email protected]>

@sven1977 sven1977 left a comment

LGTM! Thanks @simonsays1980

@sven1977 sven1977 enabled auto-merge (squash) September 3, 2024 10:50
@github-actions github-actions bot added the go add ONLY when ready to merge, run all tests label Sep 3, 2024
@sven1977 sven1977 merged commit e1ed103 into ray-project:master Sep 3, 2024
7 checks passed
ujjawal-khare pushed a commit to ujjawal-khare-27/ray that referenced this pull request Oct 15, 2024
Labels: go (add ONLY when ready to merge, run all tests), rllib (RLlib related issues)