diff --git a/rllib/algorithms/cql/torch/cql_torch_learner.py b/rllib/algorithms/cql/torch/cql_torch_learner.py
index a41770ff0592..a28547393c2a 100644
--- a/rllib/algorithms/cql/torch/cql_torch_learner.py
+++ b/rllib/algorithms/cql/torch/cql_torch_learner.py
@@ -9,6 +9,8 @@
     QF_MAX_KEY,
     QF_MIN_KEY,
     QF_PREDS,
+    QF_TWIN_LOSS_KEY,
+    QF_TWIN_PREDS,
     TD_ERROR_MEAN_KEY,
 )
 from ray.rllib.algorithms.cql.cql import CQLConfig
@@ -90,8 +92,9 @@ def compute_loss_for_module(
                 # Use the actions sampled from the current policy.
                 Columns.ACTIONS: actions_curr,
             }
+            # Note, if `twin_q` is `True`, `compute_q_values` computes the minimum
+            # of `qf` and `qf_twin` and returns it.
             q_curr = self.module[module_id].compute_q_values(batch_curr)
-            # TODO (simon): Add twin Q
             actor_loss = torch.mean(alpha.detach() * logps_curr - q_curr)
         else:
             # Use log-probabilities of the current action distribution to clone
@@ -117,7 +120,8 @@ def compute_loss_for_module(
         # Get the Q-values for the actually selected actions in the offline data.
         # In the critic loss we use these as predictions.
         q_selected = fwd_out[QF_PREDS]
-        # TODO (simon): Implement twin Q
+        if config.twin_q:
+            q_twin_selected = fwd_out[QF_TWIN_PREDS]

         # Compute Q-values from the target Q network for the next state with the
         # sampled actions for the next state.
@@ -125,8 +129,10 @@ def compute_loss_for_module(
             Columns.OBS: batch[Columns.NEXT_OBS],
             Columns.ACTIONS: actions_next,
         }
+        # Note, if `twin_q` is `True`, `SACTorchRLModule.forward_target` calculates
+        # the Q-values for both `qf_target` and `qf_twin_target` and
+        # returns the minimum.
         q_target_next = self.module[module_id].forward_target(q_batch_next)
-        # TODO (simon): Apply twin Q

         # Now mask all Q-values with terminating next states in the targets.
         q_next_masked = (1.0 - batch[Columns.TERMINATEDS].float()) * q_target_next
@@ -143,16 +149,25 @@ def compute_loss_for_module(

         # Calculate the TD error.
         td_error = torch.abs(q_selected - q_selected_target)
-        # TODO (simon): Add the Twin TD error
+        # Calculate a TD error for the twin-Q values, if needed.
+        if config.twin_q:
+            td_error += torch.abs(q_twin_selected - q_selected_target)
+            # Rescale the TD error (average over both critics).
+            td_error *= 0.5

         # MSBE loss for the critic(s) (i.e. Q, see eqs. (7-8) Haarnoja et al. (2018)).
         # Note, this needs a sample from the current policy given the next state.
         # Note further, we could also use here the Huber loss instead of the MSE.
         # TODO (simon): Add the huber loss as an alternative (SAC uses it).
         sac_critic_loss = torch.nn.MSELoss(reduction="mean")(
-            q_selected, q_selected_target
+            q_selected,
+            q_selected_target,
         )
-        # TODO (simon): Add the Twin Q critic loss
+        if config.twin_q:
+            sac_critic_twin_loss = torch.nn.MSELoss(reduction="mean")(
+                q_twin_selected,
+                q_selected_target,
+            )

         # Now calculate the CQL loss (we use the entropy version of the CQL algorithm).
         # Note, the entropy version performs best in shown experiments.
@@ -185,11 +200,28 @@ def compute_loss_for_module(
             Columns.OBS: obs_curr_repeat,
             Columns.ACTIONS: actions_rand_repeat,
         }
+        # Note, we need here the Q-values from the base Q-value function only
+        # and not the minimum over the base and twin Q-value functions.
         q_rand_repeat = (
             self.module[module_id]
-            .compute_q_values(batch_rand_repeat)
+            ._qf_forward_train_helper(
+                batch_rand_repeat,
+                self.module[module_id].qf_encoder,
+                self.module[module_id].qf,
+            )
             .view(batch_size, config.num_actions, 1)
         )
+        # Calculate twin Q-values for the random actions, if needed.
+        if config.twin_q:
+            q_twin_rand_repeat = (
+                self.module[module_id]
+                ._qf_forward_train_helper(
+                    batch_rand_repeat,
+                    self.module[module_id].qf_twin_encoder,
+                    self.module[module_id].qf_twin,
+                )
+                .view(batch_size, config.num_actions, 1)
+            )
         del batch_rand_repeat
         batch_curr_repeat = {
             Columns.OBS: obs_curr_repeat,
@@ -197,22 +229,59 @@ def compute_loss_for_module(
         }
         q_curr_repeat = (
             self.module[module_id]
-            .compute_q_values(batch_curr_repeat)
+            ._qf_forward_train_helper(
+                batch_curr_repeat,
+                self.module[module_id].qf_encoder,
+                self.module[module_id].qf,
+            )
             .view(batch_size, config.num_actions, 1)
         )
+        # Calculate twin Q-values for the repeated actions from the current policy,
+        # if needed.
+        if config.twin_q:
+            q_twin_curr_repeat = (
+                self.module[module_id]
+                ._qf_forward_train_helper(
+                    batch_curr_repeat,
+                    self.module[module_id].qf_twin_encoder,
+                    self.module[module_id].qf_twin,
+                )
+                .view(batch_size, config.num_actions, 1)
+            )
         del batch_curr_repeat
         batch_next_repeat = {
+            # Note, we use here the current observations b/c we want to keep the
+            # state fixed while sampling the actions.
             Columns.OBS: obs_curr_repeat,
             Columns.ACTIONS: actions_next_repeat,
         }
         q_next_repeat = (
             self.module[module_id]
-            .compute_q_values(batch_next_repeat)
+            ._qf_forward_train_helper(
+                batch_next_repeat,
+                self.module[module_id].qf_encoder,
+                self.module[module_id].qf,
+            )
             .view(batch_size, config.num_actions, 1)
         )
+        # Also calculate the twin Q-values for the current policy and next actions,
+        # if needed.
+        if config.twin_q:
+            q_twin_next_repeat = (
+                self.module[module_id]
+                ._qf_forward_train_helper(
+                    batch_next_repeat,
+                    self.module[module_id].qf_twin_encoder,
+                    self.module[module_id].qf_twin,
+                )
+                .view(batch_size, config.num_actions, 1)
+            )
         del batch_next_repeat

         # Compute the log-probabilities for the random actions.
+        # TODO (simon): This is the density of a discrete uniform distribution;
+        # however, the actions come from a continuous one, so this density should
+        # actually use 1 / (high - low) instead of 1 / 2.
         random_density = torch.log(
             torch.pow(
                 torch.tensor(
@@ -231,23 +300,43 @@ def compute_loss_for_module(
             ],
             dim=1,
         )
-
         cql_loss = (
             torch.logsumexp(q_repeat / config.temperature, dim=1).mean()
             * config.min_q_weight
             * config.temperature
         )
-        cql_loss = cql_loss - (q_selected.mean() * config.min_q_weight)
-        # TODO (simon): Implement CQL twin-Q loss here
+        cql_loss -= q_selected.mean() * config.min_q_weight
+        # Add the CQL loss term to the SAC loss term.
+        critic_loss = sac_critic_loss + cql_loss
+
+        # If a twin Q-value function is implemented, calculate its CQL loss.
+        if config.twin_q:
+            q_twin_repeat = torch.cat(
+                [
+                    q_twin_rand_repeat - random_density,
+                    q_twin_next_repeat - logps_next_repeat.detach(),
+                    q_twin_curr_repeat - logps_curr_repeat.detach(),
+                ],
+                dim=1,
+            )
+            cql_twin_loss = (
+                torch.logsumexp(q_twin_repeat / config.temperature, dim=1).mean()
+                * config.min_q_weight
+                * config.temperature
+            )
+            cql_twin_loss -= q_twin_selected.mean() * config.min_q_weight
+            # Add the twin CQL loss term to the twin SAC loss term.
+            critic_twin_loss = sac_critic_twin_loss + cql_twin_loss

         # TODO (simon): Check, if we need to implement here also a Lagrangian
         # loss.
-        critic_loss = sac_critic_loss + cql_loss
-        # TODO (simon): Add here also the critic loss for the twin-Q
-
         total_loss = actor_loss + critic_loss + alpha_loss
-        # TODO (simon): Add Twin Q losses
+
+        # Add the twin critic loss to the total loss, if needed.
+        if config.twin_q:
+            # Reweight the critic loss terms in the total loss.
+            total_loss += 0.5 * critic_twin_loss - 0.5 * critic_loss

         # Log important loss stats (reduce=mean (default), but with window=1
         # in order to keep them history free).
@@ -273,7 +362,14 @@ def compute_loss_for_module(
         )
         # TODO (simon): Add loss keys for langrangian, if needed.
         # TODO (simon): Add only here then the Langrange parameter optimization.
-        # TODO (simon): Add keys for twin Q
+        if config.twin_q:
+            self.metrics.log_dict(
+                {
+                    QF_TWIN_LOSS_KEY: critic_twin_loss,
+                },
+                key=module_id,
+                window=1,  # <- single items (should not be mean/ema-reduced over time).
+            )

         # Return the total loss.
         return total_loss
diff --git a/rllib/algorithms/sac/torch/sac_torch_learner.py b/rllib/algorithms/sac/torch/sac_torch_learner.py
index 5299bca2ae5f..aed5f21b909e 100644
--- a/rllib/algorithms/sac/torch/sac_torch_learner.py
+++ b/rllib/algorithms/sac/torch/sac_torch_learner.py
@@ -232,6 +232,7 @@ def compute_loss_for_module(
         total_loss = actor_loss + critic_loss + alpha_loss
         # If twin Q networks should be used, add the critic loss of the twin Q network.
         if config.twin_q:
+            # TODO (simon): Check whether the critic_loss should then be multiplied by 0.5.
             total_loss += critic_twin_loss

         # Log the TD-error with reduce=None, such that - in case we have n parallel
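
For reference, the small standalone sketch below (not part of the patch and not RLlib API) mirrors the per-critic arithmetic the new twin-Q branch performs: each critic gets its own SAC MSBE term plus a CQL logsumexp penalty over repeated actions, and the twin critic's loss then replaces half of the base critic's contribution to the total loss. All names, shapes, and the random placeholder tensors are illustrative assumptions; the actor and alpha terms as well as the importance-sampling corrections (subtracting the action log-probabilities) are omitted for brevity.

import torch

# Illustrative shapes only: B = batch size, R = number of repeated actions.
B, R = 32, 10
temperature, min_q_weight = 1.0, 5.0

# Placeholder tensors; in the patch these come from the RLModule's Q-networks.
q_selected = torch.randn(B)           # Base critic on the dataset actions.
q_twin_selected = torch.randn(B)      # Twin critic on the dataset actions.
q_selected_target = torch.randn(B)    # Bellman targets (treated as constants here).
q_repeat = torch.randn(B, R, 1)       # Base critic on repeated/sampled actions.
q_twin_repeat = torch.randn(B, R, 1)  # Twin critic on repeated/sampled actions.


def cql_critic_loss(q_sel, q_rep):
    # SAC MSBE term for this critic.
    msbe = torch.nn.functional.mse_loss(q_sel, q_selected_target)
    # CQL conservative penalty: logsumexp over the repeated actions minus the
    # mean Q-value of the dataset actions.
    cql = (
        torch.logsumexp(q_rep / temperature, dim=1).mean()
        * min_q_weight
        * temperature
        - q_sel.mean() * min_q_weight
    )
    return msbe + cql


critic_loss = cql_critic_loss(q_selected, q_repeat)
critic_twin_loss = cql_critic_loss(q_twin_selected, q_twin_repeat)

# As in the patch: start from the base critic loss, then re-weight so that the
# two critics each contribute 0.5 of the total critic term.
total_loss = critic_loss
total_loss = total_loss + 0.5 * critic_twin_loss - 0.5 * critic_loss
assert torch.isclose(total_loss, 0.5 * (critic_loss + critic_twin_loss))

The 0.5/0.5 split effectively averages the two critic losses rather than summing them, which keeps the critic gradient scale comparable to the single-critic case; that is also the question raised by the TODO added to sac_torch_learner.py.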