[RLlib] - Fix numerical overflow in gradient clipping for (many) large gradients #45055
Conversation
…ring values or summing many large values results in +/- infinity. Because we clip by multiplying with a clipping coefficient instead of overriding values inside the gradient tensors, this modification allows clipping very large gradients. Signed-off-by: Simon Zehnder <[email protected]>
@@ -94,6 +95,29 @@ def test_copy_torch_tensors(self):
                all(copied_tensor.detach().numpy() == tensor.detach().cpu().numpy())
            )

    def test_large_gradients_clipping(self):
Nice! Thanks for creating this important test case!
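The body of the new test is collapsed in this view. Purely as an illustration (not the actual RLlib test), here is a minimal sketch of what a large-gradient clipping test could check, assuming non-finite norms are replaced and the clipping coefficient is capped at 1.0 as discussed in this PR:

```python
import torch


def test_large_gradients_clipping_sketch():
    # Hypothetical stand-in for the collapsed test body: gradient values so
    # large that a naive float32 l2-norm overflows to +inf.
    grad = torch.full((1000,), 1e20, dtype=torch.float32)
    naive_norm = (grad * grad).sum().sqrt()
    assert torch.isinf(naive_norm)  # 1e20 ** 2 exceeds the float32 maximum

    # Replace the non-finite norm and keep the clipping coefficient finite
    # and at most 1.0 (details assumed for this sketch).
    grad_clip = 40.0
    total_norm = torch.nan_to_num(naive_norm, posinf=1e8, neginf=-1e8)
    clip_coef = grad_clip / torch.maximum(
        torch.tensor(grad_clip, device=total_norm.device), total_norm + 1e-6
    )

    # Clipping happens by multiplication, so the result stays finite.
    clipped = grad * clip_coef
    assert torch.isfinite(clipped).all()
```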
rllib/utils/torch_utils.py
Outdated
if torch.logical_or(total_norm.isnan(), total_norm.isinf()):
    raise RuntimeError(
        f"The total norm of order {norm_type} for gradients from "
        "`parameters` is non-finite, so it cannot be clipped. "
    )
clip_coef = grad_clip / (total_norm + 1e-6)
clip_coef = grad_clip / torch.maximum(
    torch.tensor(grad_clip), total_norm + 1e-6
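For context, a minimal, self-contained sketch of this coefficient-based clipping (function name and structure are illustrative assumptions, not the exact RLlib implementation):

```python
from typing import List

import torch


def clip_grads_by_global_norm_sketch(
    grads: List[torch.Tensor], grad_clip: float
) -> torch.Tensor:
    """Illustrative global-norm clipping with a coefficient capped at 1.0."""
    # Global l2-norm across all gradient tensors.
    total_norm = torch.norm(
        torch.stack([torch.norm(g.detach(), 2.0) for g in grads]), 2.0
    )
    # The denominator is at least grad_clip, so the coefficient stays between
    # 0.0 and 1.0: small norms leave gradients (almost) untouched, large norms
    # scale them down to roughly grad_clip.
    clip_coef = grad_clip / torch.maximum(
        torch.tensor(grad_clip, device=total_norm.device), total_norm + 1e-6
    )
    # Clip by multiplying with the coefficient instead of overriding values.
    for g in grads:
        g.mul_(clip_coef.to(g.device))
    return total_norm
```

With grad_clip = 40.0 and a total norm of 4000.0, for instance, every gradient would be scaled by roughly 0.01.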
Two questions:
- Would this torch.tensor() pose a danger when on another device? GPU?
- Can we add a comment here (or enhance the one below) explaining why we compute the coeff like this? What are the expected final values of the coeff (between 0.0 and 1.0)?
Good catch! Yes, let's put the tensor on the device we have extracted before, to be safe. I will then add a note on why we want the coefficient to be no larger than 1.0.
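As an illustration of that change (a sketch; the exact line in the PR may differ), the scalar can be created on the same device and dtype as the norm:

```python
import torch

# Stand-in values; in the real code total_norm comes from the gradients and
# may live on a GPU.
total_norm = torch.tensor(5.0e9)
grad_clip = 40.0

# Create the scalar on total_norm's device/dtype so torch.maximum() never
# mixes CPU and GPU tensors.
clip_coef = grad_clip / torch.maximum(
    torch.tensor(grad_clip, device=total_norm.device, dtype=total_norm.dtype),
    total_norm + 1e-6,
)
print(clip_coef)  # tensor(8.0000e-09): gradients get scaled way down
```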
Looks great! Thanks for fixing this very important piece of code, @simonsays1980. Just 2 nits and questions before we can merge.
clip_coef = grad_clip / (total_norm + 1e-6)
# We do want the coefficient to be in between 0.0 and 1.0, therefore
# if the global_norm is smaller than the clip value, we use the clip value
# as normalization constant.
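A small numeric illustration of the two cases described in the comment above (the concrete numbers are arbitrary):

```python
import torch

grad_clip = 40.0

# Large global norm: the norm itself is the denominator, so the coefficient
# drops well below 1.0 and the gradients shrink.
large_norm = torch.tensor(4000.0)
print(grad_clip / torch.maximum(torch.tensor(grad_clip), large_norm + 1e-6))  # ~0.01

# Small global norm: grad_clip becomes the denominator, so the coefficient is
# ~1.0 and the gradients are never scaled up.
small_norm = torch.tensor(5.0)
print(grad_clip / torch.maximum(torch.tensor(grad_clip), small_norm + 1e-6))  # ~1.0
```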
Perfect, thanks for making this comment much more clear!
Looks great, let's merge once tests pass ...
Why are these changes needed?
Large gradients, and many of them, can lead to numerical overflow when computing their l2-norm in torch_utils.clip_gradients (using the "global_norm"). This is counterproductive: a user wants to clip such gradients and instead runs into numerical overflow because of the clipping itself. This PR proposes small changes to turn inf and neginf values returned from the norms into 10e8 and -10e8, respectively. This does not harm the gradients themselves (even if they were, for example, already inf/neginf), because we clip gradients by multiplication and do not override values.

Related issue number
Checks
- I've signed off every commit (by using the -s flag, i.e., git commit -s) in this PR.
- I've run scripts/format.sh to lint the changes in this PR.
- I've added any new APIs to the API Reference. For example, if I added a method in Tune, I've added it in doc/source/tune/api/ under the corresponding .rst file.