Fix zero gradients for ppo-clipped vf #22171
Conversation
Here are some preliminary results using the tuned cartpole example, comparing the mean value error of the old and new value loss (plot omitted). Cartpole is a simple environment that can be solved using vanilla policy gradient, so the value function has little effect on final reward. I suspect more challenging environments would see a significant reward disparity between the old and new value functions.
wow, this fixes the long-standing value loss calculation issue?
thanks so much man!
Is this the same case on the tf branch? Would be good to apply the fix there as well.
TensorFlow does indeed do the weird clipping as well: see ray/rllib/agents/ppo/ppo_tf_policy.py, lines 115 to 129 at 7f1bacc.
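For reference, the pattern being discussed looks roughly like the sketch below. This is a paraphrase, not the exact source; the function name is made up, the variable names only approximate the RLlib ones, and the valid-mask reduction and config lookups are omitted.

```python
import tensorflow as tf

def old_clipped_vf_loss(value_fn_out, prev_value_fn_out, value_targets, vf_clip_param):
    # Unclipped squared error of the current value prediction.
    vf_loss1 = tf.math.square(value_fn_out - value_targets)
    # Clip the *change* in the prediction relative to the prediction stored in the
    # SampleBatch; prev_value_fn_out carries no gradient.
    vf_clipped = prev_value_fn_out + tf.clip_by_value(
        value_fn_out - prev_value_fn_out, -vf_clip_param, vf_clip_param)
    vf_loss2 = tf.math.square(vf_clipped - value_targets)
    # Elementwise max: wherever vf_loss2 wins and the clip is saturated,
    # the gradient w.r.t. the value network is zero.
    return tf.maximum(vf_loss1, vf_loss2)
```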
We'll need to make this change uniformly to both frameworks. Can you submit the change for tf as well? Thx 😊
@sven1977 I think it's worth revisiting the default value-clipping setting as well. A nicer idea would be to clip the value function gradient instead of the loss, or clip the value function error before taking the mean (this would produce zero gradients for individual predictions rather than for the entire train batch). But we can discuss this another time.
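For the sake of discussion, one reading of the second alternative (clip each value error before averaging) might look like the sketch below. This is only an illustration of the idea, not the change made in this PR; the function name and the choice to clamp the squared error are assumptions.

```python
import torch

def per_element_clipped_vf_loss(value_fn_out, value_targets, vf_clip_param):
    # Squared error per prediction.
    vf_err2 = torch.pow(value_fn_out - value_targets, 2.0)
    # Clamp each element's error: only saturated elements get a zero gradient,
    # instead of zeroing out the gradient for the whole train batch.
    return torch.clamp(vf_err2, 0.0, vf_clip_param).mean()
```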
Awesome! Thanks for the fixes, @smorad!
Why are these changes needed?
The PPO value loss calculation returns a zero gradient when clipping is applied and vf_loss2 is selected, because prev_value_fn_out comes from the SampleBatch, which doesn't track gradients. Furthermore, the logic itself is a bit convoluted. See the related issue for a more in-depth description.
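To make the failure mode concrete, here is a minimal, self-contained sketch (made-up numbers; variable names follow the description above) showing the clipped branch producing a zero gradient for the value network:

```python
import torch

# Made-up stand-ins for what RLlib reads from the train batch.
value_targets = torch.tensor([5.0, 5.0, 5.0])
prev_value_fn_out = torch.tensor([0.0, 0.0, 0.0])  # from the SampleBatch: no gradient
vf_clip_param = 0.5

# Stand-in for the trainable value head's current output.
value_fn_out = torch.tensor([4.9, 4.9, 4.9], requires_grad=True)

vf_loss1 = (value_fn_out - value_targets) ** 2
vf_clipped = prev_value_fn_out + torch.clamp(
    value_fn_out - prev_value_fn_out, -vf_clip_param, vf_clip_param)
vf_loss2 = (vf_clipped - value_targets) ** 2

# The clipped branch wins the elementwise max and the clamp is saturated,
# so no gradient reaches the value network for this batch.
torch.max(vf_loss1, vf_loss2).mean().backward()
print(value_fn_out.grad)  # tensor([0., 0., 0.])
```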
Related issue number
Closes #19291
Checks
I've run scripts/format.sh to lint the changes in this PR.