diff --git a/rllib/agents/ppo/ppo_tf_policy.py b/rllib/agents/ppo/ppo_tf_policy.py
index f1a0089acd80..b0377524f9d7 100644
--- a/rllib/agents/ppo/ppo_tf_policy.py
+++ b/rllib/agents/ppo/ppo_tf_policy.py
@@ -113,27 +113,22 @@ def reduce_mean_valid(t):
 
     # Compute a value function loss.
    if policy.config["use_critic"]:
-        prev_value_fn_out = train_batch[SampleBatch.VF_PREDS]
-        vf_loss1 = tf.math.square(
+        vf_loss = tf.math.square(
             value_fn_out - train_batch[Postprocessing.VALUE_TARGETS]
         )
-        vf_clipped = prev_value_fn_out + tf.clip_by_value(
-            value_fn_out - prev_value_fn_out,
-            -policy.config["vf_clip_param"],
+        vf_loss_clipped = tf.clip_by_value(
+            vf_loss,
+            0,
             policy.config["vf_clip_param"],
         )
-        vf_loss2 = tf.math.square(
-            vf_clipped - train_batch[Postprocessing.VALUE_TARGETS]
-        )
-        vf_loss = tf.maximum(vf_loss1, vf_loss2)
-        mean_vf_loss = reduce_mean_valid(vf_loss)
+        mean_vf_loss = reduce_mean_valid(vf_loss_clipped)
     # Ignore the value function.
     else:
-        vf_loss = mean_vf_loss = tf.constant(0.0)
+        vf_loss_clipped = mean_vf_loss = tf.constant(0.0)
 
     total_loss = reduce_mean_valid(
         -surrogate_loss
-        + policy.config["vf_loss_coeff"] * vf_loss
+        + policy.config["vf_loss_coeff"] * vf_loss_clipped
         - policy.entropy_coeff * curr_entropy
     )
     # Add mean_kl_loss (already processed through `reduce_mean_valid`),
diff --git a/rllib/agents/ppo/ppo_torch_policy.py b/rllib/agents/ppo/ppo_torch_policy.py
index c45a7dcacf48..285ec95901c5 100644
--- a/rllib/agents/ppo/ppo_torch_policy.py
+++ b/rllib/agents/ppo/ppo_torch_policy.py
@@ -144,28 +144,19 @@ def reduce_mean_valid(t):
 
         # Compute a value function loss.
         if self.config["use_critic"]:
-            prev_value_fn_out = train_batch[SampleBatch.VF_PREDS]
             value_fn_out = model.value_function()
-            vf_loss1 = torch.pow(
+            vf_loss = torch.pow(
                 value_fn_out - train_batch[Postprocessing.VALUE_TARGETS], 2.0
             )
-            vf_clipped = prev_value_fn_out + torch.clamp(
-                value_fn_out - prev_value_fn_out,
-                -self.config["vf_clip_param"],
-                self.config["vf_clip_param"],
-            )
-            vf_loss2 = torch.pow(
-                vf_clipped - train_batch[Postprocessing.VALUE_TARGETS], 2.0
-            )
-            vf_loss = torch.max(vf_loss1, vf_loss2)
-            mean_vf_loss = reduce_mean_valid(vf_loss)
+            vf_loss_clipped = torch.clamp(vf_loss, 0, self.config["vf_clip_param"])
+            mean_vf_loss = reduce_mean_valid(vf_loss_clipped)
         # Ignore the value function.
         else:
-            vf_loss = mean_vf_loss = 0.0
+            vf_loss_clipped = mean_vf_loss = 0.0
 
         total_loss = reduce_mean_valid(
             -surrogate_loss
-            + self.config["vf_loss_coeff"] * vf_loss
+            + self.config["vf_loss_coeff"] * vf_loss_clipped
             - self.entropy_coeff * curr_entropy
         )
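
For reference, a minimal standalone sketch (plain PyTorch, outside of RLlib) of what the old and new per-sample value losses compute. Only the names `value_fn_out`, `value_targets`, `prev_value_fn_out`, and `vf_clip_param` mirror the diff; the tensor values below are made up purely for illustration.

```python
import torch

# Hypothetical batch values, chosen only to make the comparison visible;
# they are not taken from the PR or from RLlib.
vf_clip_param = 10.0
value_fn_out = torch.tensor([1.5, -0.2, 30.0])      # current value predictions
value_targets = torch.tensor([1.0, 0.0, 3.0])       # GAE value targets
prev_value_fn_out = torch.tensor([1.2, -0.1, 2.5])  # predictions at sampling time

# Old loss: clip the value *prediction* around the prediction recorded at
# sampling time, then take the elementwise max of the two squared errors.
vf_loss1 = (value_fn_out - value_targets) ** 2
vf_clipped = prev_value_fn_out + torch.clamp(
    value_fn_out - prev_value_fn_out, -vf_clip_param, vf_clip_param
)
vf_loss2 = (vf_clipped - value_targets) ** 2
old_vf_loss = torch.max(vf_loss1, vf_loss2)

# New loss: clip the squared error itself into [0, vf_clip_param], which caps
# each sample's value-branch contribution to the total loss.
new_vf_loss = torch.clamp((value_fn_out - value_targets) ** 2, 0, vf_clip_param)

print(old_vf_loss)  # per-sample losses of roughly [0.25, 0.04, 729.0]
print(new_vf_loss)  # per-sample losses of roughly [0.25, 0.04, 10.0]
```

The new formulation no longer needs the value predictions recorded at sampling time (`SampleBatch.VF_PREDS`), and `vf_clip_param` now acts directly as an upper bound on the per-sample squared value error entering `total_loss`.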