
[RLlib] - Fix numerical overflow in gradient clipping for (many) large gradients #45055

Merged

Conversation

simonsays1980
Collaborator

@simonsays1980 simonsays1980 commented Apr 30, 2024

Why are these changes needed?

Large gradients, and many of them, can lead to numerical overflow when computing their L2 norm in torch_utils.clip_gradients (using "global_norm"). This is counterproductive: a user wants to clip such gradients, but instead runs into numerical overflow precisely because of gradient clipping.

This PR proposes small changes that map inf and -inf values returned by the norm computation to 10e8 and -10e8, respectively. This does not harm the gradients themselves (even if they were already inf/-inf), because we clip gradients by multiplying with a coefficient rather than overriding values.
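To make the failure mode concrete, here is a minimal sketch (not the PR's actual diff) of how a float32 global norm overflows to inf for many large gradients, and how mapping non-finite norms to a large finite value keeps clipping usable; torch.nan_to_num and the 10e8 bound are used purely to illustrate the idea described above.

```python
import torch

# Many large float32 gradients: squaring 1e30 already overflows float32,
# so the global L2 norm comes out as inf.
grads = [torch.full((1_000,), 1e30) for _ in range(10)]
total_norm = torch.norm(torch.stack([torch.norm(g, 2.0) for g in grads]), 2.0)
assert torch.isinf(total_norm)

# Map non-finite norms to large finite values so a clipping coefficient
# can still be computed (the bounds follow the PR description).
total_norm = torch.nan_to_num(total_norm, posinf=10e8, neginf=-10e8)

grad_clip = 40.0
# The coefficient stays <= 1.0, so gradients are only ever scaled down.
clip_coef = grad_clip / torch.maximum(torch.tensor(grad_clip), total_norm + 1e-6)
clipped = [g * clip_coef for g in grads]
assert all(torch.isfinite(c).all() for c in clipped)
```

Because the clipping coefficient is applied multiplicatively, the original gradient values are only scaled, never overwritten.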

Related issue number

Checks

  • I've signed off every commit (by using the -s flag, i.e., git commit -s) in this PR.
  • I've run scripts/format.sh to lint the changes in this PR.
  • I've included any doc changes needed for https://docs.ray.io/en/master/.
    • I've added any new APIs to the API Reference. For example, if I added a
      method in Tune, I've added it in doc/source/tune/api/ under the
      corresponding .rst file.
  • I've made sure the tests are passing. Note that there might be a few flaky tests, see the recent failures at https://flakey-tests.ray.io/
  • Testing Strategy
    • Unit tests
    • Release tests
    • This PR is not tested :(

…ring values or summing many large values results in +/- infinity. As we clip by multiplying with a clipping coefficient instead of overriding values inside the gradient tensors, this modification allows clipping very large gradients.

Signed-off-by: Simon Zehnder <[email protected]>
@simonsays1980 simonsays1980 added rllib RLlib related issues rllib-newstack labels Apr 30, 2024
@simonsays1980 simonsays1980 self-assigned this Apr 30, 2024
@@ -94,6 +95,29 @@ def test_copy_torch_tensors(self):
all(copied_tensor.detach().numpy() == tensor.detach().cpu().numpy())
)

def test_large_gradients_clipping(self):
Contributor

Nice! Thanks for creating this important test case!
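The body of the new test is not shown in this excerpt; the following is a rough sketch of what such a check could look like, assuming clip_gradients accepts a dict of gradient tensors plus grad_clip/grad_clip_by keyword arguments and scales the entries in place (the exact signature and behavior are assumptions, not quotes from the PR).

```python
import torch

from ray.rllib.utils.torch_utils import clip_gradients


def test_large_gradients_clipping():
    # Gradients so large that squaring them overflows float32, which makes
    # the naive global L2 norm come out as inf.
    grads = {i: torch.full((100,), 1e30) for i in range(10)}
    # Hypothetical call; argument names follow RLlib's config naming and may
    # differ from the actual helper signature.
    clip_gradients(grads, grad_clip=40.0, grad_clip_by="global_norm")
    # After clipping, every gradient should be finite again.
    for grad in grads.values():
        assert torch.isfinite(grad).all()
```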

if torch.logical_or(total_norm.isnan(), total_norm.isinf()):
raise RuntimeError(
f"The total norm of order {norm_type} for gradients from "
"`parameters` is non-finite, so it cannot be clipped. "
)
clip_coef = grad_clip / (total_norm + 1e-6)
clip_coef = grad_clip / torch.maximum(
torch.tensor(grad_clip), total_norm + 1e-6
Contributor

Two questions:

  • Would this torch.tensor() pose a danger when on another device? GPU?
  • Can we add a comment here (or enhance the one below) explaining why we compute the coeff like this? What are the expected final values of the coeff (between 0.0 and 1.0)?

Collaborator Author

Good catch! Yes, let's put the tensor on the device we extracted before, to be safe. I will then add a note on why we want the coefficient to be no larger than 1.0.
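A minimal, self-contained sketch of the device-safe variant discussed here (the helper name compute_clip_coef is illustrative, not part of the PR):

```python
import torch


def compute_clip_coef(total_norm: torch.Tensor, grad_clip: float) -> torch.Tensor:
    # Create the comparison tensor on the same device as the norm so the
    # maximum also works for gradients living on a GPU.
    clip_value = torch.tensor(grad_clip, device=total_norm.device)
    # Using max(grad_clip, total_norm) in the denominator bounds the
    # coefficient at 1.0: gradients are only ever scaled down, never up.
    return grad_clip / torch.maximum(clip_value, total_norm + 1e-6)
```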

Contributor

@sven1977 sven1977 left a comment

Looks great! Thanks for fixing this very important piece of code @simonsays1980 . Just 2 nits and questions before we can merge.

@sven1977 sven1977 assigned sven1977 and unassigned simonsays1980 May 2, 2024
clip_coef = grad_clip / (total_norm + 1e-6)
# We do want the coefficient to be in between 0.0 and 1.0, therefore
# if the global_norm is smaller than the clip value, we use the clip value
# as normalization constant.
Contributor

Perfect, thanks for making this comment much more clear!

@sven1977
Contributor

sven1977 commented May 2, 2024

Looks great, let's merge once tests pass ...

@sven1977 sven1977 merged commit 711f386 into ray-project:master May 2, 2024
5 checks passed