[RLlib] Fix zero gradients for ppo-clipped vf (ray-project#22171)
smorad authored and simonsays1980 committed Feb 27, 2022
1 parent cb80e39 commit 1d2a3cc
Showing 2 changed files with 12 additions and 26 deletions.
19 changes: 7 additions & 12 deletions rllib/agents/ppo/ppo_tf_policy.py
@@ -113,27 +113,22 @@ def reduce_mean_valid(t):
 
     # Compute a value function loss.
     if policy.config["use_critic"]:
-        prev_value_fn_out = train_batch[SampleBatch.VF_PREDS]
-        vf_loss1 = tf.math.square(
+        vf_loss = tf.math.square(
             value_fn_out - train_batch[Postprocessing.VALUE_TARGETS]
         )
-        vf_clipped = prev_value_fn_out + tf.clip_by_value(
-            value_fn_out - prev_value_fn_out,
-            -policy.config["vf_clip_param"],
+        vf_loss_clipped = tf.clip_by_value(
+            vf_loss,
+            0,
             policy.config["vf_clip_param"],
         )
-        vf_loss2 = tf.math.square(
-            vf_clipped - train_batch[Postprocessing.VALUE_TARGETS]
-        )
-        vf_loss = tf.maximum(vf_loss1, vf_loss2)
-        mean_vf_loss = reduce_mean_valid(vf_loss)
+        mean_vf_loss = reduce_mean_valid(vf_loss_clipped)
     # Ignore the value function.
     else:
-        vf_loss = mean_vf_loss = tf.constant(0.0)
+        vf_loss_clipped = mean_vf_loss = tf.constant(0.0)
 
     total_loss = reduce_mean_valid(
         -surrogate_loss
-        + policy.config["vf_loss_coeff"] * vf_loss
+        + policy.config["vf_loss_coeff"] * vf_loss_clipped
         - policy.entropy_coeff * curr_entropy
     )
     # Add mean_kl_loss (already processed through `reduce_mean_valid`),
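
For context on the fix: the deleted TF lines clipped the raw value prediction around prev_value_fn_out and then took the elementwise maximum of the clipped and unclipped squared errors. Because tf.clip_by_value propagates zero gradient outside its bounds, the value branch receives no gradient whenever the clipped squared error wins that maximum. The sketch below is a standalone illustration, not part of this commit; the scalar values and the vf_clip_param of 10.0 are made up solely to reproduce the effect with tf.GradientTape.

import tensorflow as tf

# Illustrative numbers only; RLlib reads vf_clip_param from policy.config
# and pulls the other tensors from the sample batch.
vf_clip_param = 10.0
value_target = tf.constant(100.0)
prev_value_fn_out = tf.constant(0.0)
value_fn_out = tf.Variable(50.0)  # prediction far outside the clip window

with tf.GradientTape() as tape:
    # Old (pre-commit) formulation from the deleted lines above.
    vf_loss1 = tf.math.square(value_fn_out - value_target)
    vf_clipped = prev_value_fn_out + tf.clip_by_value(
        value_fn_out - prev_value_fn_out, -vf_clip_param, vf_clip_param
    )
    vf_loss2 = tf.math.square(vf_clipped - value_target)
    old_vf_loss = tf.maximum(vf_loss1, vf_loss2)

print(tape.gradient(old_vf_loss, value_fn_out))
# tf.Tensor(0.0, ...): vf_loss2 (8100.0) beats vf_loss1 (2500.0), and the
# clipped branch has zero gradient because |value_fn_out - prev_value_fn_out|
# exceeds vf_clip_param.
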
19 changes: 5 additions & 14 deletions rllib/agents/ppo/ppo_torch_policy.py
@@ -144,28 +144,19 @@ def reduce_mean_valid(t):
 
         # Compute a value function loss.
         if self.config["use_critic"]:
-            prev_value_fn_out = train_batch[SampleBatch.VF_PREDS]
             value_fn_out = model.value_function()
-            vf_loss1 = torch.pow(
+            vf_loss = torch.pow(
                 value_fn_out - train_batch[Postprocessing.VALUE_TARGETS], 2.0
             )
-            vf_clipped = prev_value_fn_out + torch.clamp(
-                value_fn_out - prev_value_fn_out,
-                -self.config["vf_clip_param"],
-                self.config["vf_clip_param"],
-            )
-            vf_loss2 = torch.pow(
-                vf_clipped - train_batch[Postprocessing.VALUE_TARGETS], 2.0
-            )
-            vf_loss = torch.max(vf_loss1, vf_loss2)
-            mean_vf_loss = reduce_mean_valid(vf_loss)
+            vf_loss_clipped = torch.clamp(vf_loss, 0, self.config["vf_clip_param"])
+            mean_vf_loss = reduce_mean_valid(vf_loss_clipped)
         # Ignore the value function.
         else:
-            vf_loss = mean_vf_loss = 0.0
+            vf_loss_clipped = mean_vf_loss = 0.0
 
         total_loss = reduce_mean_valid(
             -surrogate_loss
-            + self.config["vf_loss_coeff"] * vf_loss
+            + self.config["vf_loss_coeff"] * vf_loss_clipped
             - self.entropy_coeff * curr_entropy
         )
 
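
As a quick sanity check of the new Torch formulation (again a standalone sketch, not part of this commit; the tensor values and vf_clip_param below are made up), clamping the squared error itself keeps gradients flowing while the error stays below vf_clip_param and zeroes them once the clamp saturates:

import torch

# Illustrative numbers only; RLlib reads vf_clip_param from self.config
# and takes the targets/predictions from the train batch.
vf_clip_param = 10.0
value_targets = torch.tensor([0.0, 0.0])
value_fn_out = torch.tensor([2.0, 50.0], requires_grad=True)

# New (post-commit) formulation from the added lines above.
vf_loss = torch.pow(value_fn_out - value_targets, 2.0)
vf_loss_clipped = torch.clamp(vf_loss, 0, vf_clip_param)
vf_loss_clipped.sum().backward()

print(value_fn_out.grad)
# tensor([4., 0.]): the first element's squared error (4.0) is below the clip,
# so its gradient 2 * (2.0 - 0.0) = 4 survives; the second element's squared
# error (2500.0) is clamped to vf_clip_param, so its gradient is zero.
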
