
Fix zero gradients for ppo-clipped vf #22171

Merged (3 commits, Feb 15, 2022)

Conversation

@smorad (Contributor, Author) commented Feb 7, 2022

Why are these changes needed?

The PPO value loss produces a zero gradient whenever clipping is active and vf_loss2 is selected by the max, because prev_value_fn_out comes from the SampleBatch and therefore does not track gradients. The clipping logic is also more convoluted than it needs to be. See the related issue for a more in-depth description.
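
To make the failure mode concrete, here is a minimal, self-contained sketch (toy numbers, not RLlib code) of the old clipped formulation; when the clip saturates and the max selects vf_loss2, no gradient reaches the value prediction:

import torch

# Toy reproduction of the zero-gradient case. Names mirror the loss code,
# values are made up. prev_value_fn_out comes from the SampleBatch, so it
# is detached from the computation graph.
value_fn_out = torch.tensor(10.0, requires_grad=True)   # current value prediction
prev_value_fn_out = torch.tensor(0.0)                    # stored prediction, no grad
value_target = torch.tensor(20.0)
vf_clip_param = 1.0

vf_loss1 = (value_fn_out - value_target) ** 2            # 100.0
vf_clipped = prev_value_fn_out + torch.clamp(
    value_fn_out - prev_value_fn_out, -vf_clip_param, vf_clip_param
)                                                        # clamp saturates -> constant
vf_loss2 = (vf_clipped - value_target) ** 2              # 361.0, wins the max
vf_loss = torch.max(vf_loss1, vf_loss2)

vf_loss.backward()
print(value_fn_out.grad)  # tensor(0.): the clipped branch carries no gradient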

Related issue number

Closes #19291

Checks

  • I've run scripts/format.sh to lint the changes in this PR.
  • I've included any doc changes needed for https://docs.ray.io/en/master/.
  • I've made sure the tests are passing. Note that there might be a few flaky tests, see the recent failures at https://flakey-tests.ray.io/
  • Testing Strategy
    • Unit tests
    • Release tests
    • This PR is not tested :(

@smorad (Contributor, Author) commented Feb 7, 2022

Here are some preliminary results using the tuned cartpole example.

As you can see, the mean value error (plotted as vf_loss) is significantly lower after the fix. Cartpole rewards lie in [0, 200], so there is no reason the mean value error (i.e. the error of a single timestep's value estimate) should exceed 1000.

(Before) [screenshot]
(After) [screenshot]

(Before) [screenshot]
(After) [screenshot]

Cartpole is a simple environment that can be solved using vanilla policy gradient, so the value function has little effect on final reward. I suspect more challenging environments would see a significant reward disparity between the old and new value functions.

@gjoliver (Member) left a comment

Wow, this fixes the long-standing value loss calculation issue?
Thanks so much, man!

@avnishn (Member) commented Feb 7, 2022

Is this the same case on the tf branch? Would be good to apply the fix there as well.

@smorad (Contributor, Author) commented Feb 8, 2022

TensorFlow does indeed do the weird clipping as well:

if policy.config["use_critic"]:
    prev_value_fn_out = train_batch[SampleBatch.VF_PREDS]
    vf_loss1 = tf.math.square(
        value_fn_out - train_batch[Postprocessing.VALUE_TARGETS]
    )
    vf_clipped = prev_value_fn_out + tf.clip_by_value(
        value_fn_out - prev_value_fn_out,
        -policy.config["vf_clip_param"],
        policy.config["vf_clip_param"],
    )
    vf_loss2 = tf.math.square(
        vf_clipped - train_batch[Postprocessing.VALUE_TARGETS]
    )
    vf_loss = tf.maximum(vf_loss1, vf_loss2)
    mean_vf_loss = reduce_mean_valid(vf_loss)
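
For reference, a rough sketch of how the TF loss could be simplified in the same spirit as the torch fix, clamping the squared error directly instead of re-clipping a detached prediction (illustrative only, not the exact diff):

if policy.config["use_critic"]:
    vf_loss = tf.math.square(
        value_fn_out - train_batch[Postprocessing.VALUE_TARGETS]
    )
    # Clamp the per-timestep squared error; gradients still flow wherever
    # the error is below vf_clip_param.
    vf_loss_clipped = tf.clip_by_value(
        vf_loss, 0.0, policy.config["vf_clip_param"]
    )
    mean_vf_loss = reduce_mean_valid(vf_loss_clipped)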

@avnishn (Member) commented Feb 8, 2022

We'll need to apply this change uniformly across both frameworks. Can you submit the change for TF as well? Thx 😊

@smorad (Contributor, Author) commented Feb 13, 2022

Results after the TF fix:
[screenshot]

@smorad (Contributor, Author) commented Feb 13, 2022

@sven1977 I think it's worth setting the default vf_clip_param to inf, essentially disabling it. Clamping zeroes the value function gradient via the chain rule (see https://discuss.pytorch.org/t/exluding-torch-clamp-from-backpropagation-as-tf-stop-gradient-in-tensorflow/52404), so whenever vf_loss > policy.config["vf_clip_param"], the value function receives no gradient.

A nicer approach would be to clip the value function gradient instead of the loss, or to clip the per-timestep value error before taking the mean (which zeroes gradients only for individual predictions rather than for the entire train batch). But we can discuss that another time. I think vf_clip_param will generally hurt newcomers rather than help them; it should only be set explicitly by users who know what they're doing.
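
A minimal illustration of the chain-rule point above (toy numbers, assumed names, not RLlib code): once the clamped loss saturates at vf_clip_param, the value head receives no gradient at all.

import torch

# Clamping the loss zeroes its gradient as soon as the loss exceeds the clip.
value_fn_out = torch.tensor(5.0, requires_grad=True)    # value prediction
value_target = torch.tensor(50.0)                        # bootstrapped target
vf_clip_param = 10.0

vf_loss = (value_fn_out - value_target) ** 2             # 2025.0, far above the clip
vf_loss_clipped = torch.clamp(vf_loss, 0.0, vf_clip_param)
vf_loss_clipped.backward()
print(value_fn_out.grad)  # tensor(0.) -> no learning signal for the value head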

@sven1977 (Contributor) left a comment

Awesome! Thanks for the fixes, @smorad!

Linked issue: [Bug] PPO value function loss is incorrect