From 360bde7eeb8360a065627efdc132c4eeff487b74 Mon Sep 17 00:00:00 2001 From: simonsays1980 Date: Fri, 2 Aug 2024 19:18:31 +0200 Subject: [PATCH 1/6] Added CQLLearner and CQLTorchLearner. Signed-off-by: simonsays1980 --- rllib/algorithms/cql/cql_learner.py | 17 + .../algorithms/cql/torch/cql_torch_learner.py | 315 ++++++++++++++++++ 2 files changed, 332 insertions(+) create mode 100644 rllib/algorithms/cql/cql_learner.py create mode 100644 rllib/algorithms/cql/torch/cql_torch_learner.py diff --git a/rllib/algorithms/cql/cql_learner.py b/rllib/algorithms/cql/cql_learner.py new file mode 100644 index 000000000000..b7f907cf10d1 --- /dev/null +++ b/rllib/algorithms/cql/cql_learner.py @@ -0,0 +1,17 @@ +from ray.rllib.algorithms.sac.sac_learner import SACLearner + +from ray.rllib.core.learner.learner import Learner +from ray.rllib.utils.annotations import override + + +class CQLLearner(SACLearner): + @override(Learner) + def build(self) -> None: + + # Set up the gradient buffer to store gradients to apply + # them later in `self.apply_gradients`. + self.grads = {} + + # We need to call the `super()`'s `build` method here to have the variables + # for `alpha`` and the target entropy defined. + super().build() diff --git a/rllib/algorithms/cql/torch/cql_torch_learner.py b/rllib/algorithms/cql/torch/cql_torch_learner.py new file mode 100644 index 000000000000..4109ba6fd43b --- /dev/null +++ b/rllib/algorithms/cql/torch/cql_torch_learner.py @@ -0,0 +1,315 @@ +from typing import Dict + +from ray.rllib.cql.cql_learner import CQLLearner +from ray.rllib.algorithms.sac.sac_learner import ( + LOGPS_KEY, + QF_LOSS_KEY, + QF_MEAN_KEY, + QF_MAX_KEY, + QF_MIN_KEY, + QF_PREDS, + TD_ERROR_MEAN_KEY, +) +from ray.rllib.algorithms.sac.torch.sac_torch_learner import SACTorchLearner +from ray.rllib.core.columns import Columns +from ray.rllib.core.learner.learner import ( + POLICY_LOSS_KEY, +) +from ray.rllib.utils.annotations import override +from ray.rllib.utils.framework import try_import_torch +from ray.rllib.utils.typing import ModuleID, ParamDict, TensorType + +torch, nn = try_import_torch() + + +class CQLTorchLearner(CQLLearner, SACTorchLearner): + @override(SACTorchLearner) + def compute_loss_for_module( + self, + *, + module_id: ModuleID, + config: CQLConfig, + batch: Dict, + fwd_out: Dict[str, TensorType], + ) -> TensorType: + + # Get the train action distribution for the current policy and current state. + # This is needed for the policy (actor) loss and the `alpha`` loss. + action_dist_class = self.module[module_id].get_train_action_dist_cls() + action_dist_curr = action_dist_class.from_logits( + fwd_out[Columns.ACTION_DIST_INPUTS] + ) + + # Sample actions for the current state. Note that we need to apply the + # reparameterization trick here to avoid the expectation over actions. + actions_curr = ( + action_dist_curr.rsample() + if not config._deterministic_loss + # If deterministic, we use the mean.s + else action_dist_curr.to_deterministic().sample() + ) + # Compute the log probabilities for the current state (for the alpha loss) + logps_curr = action_dist_curr.logp(actions_curr) + + # Optimize also the hyperparameter `alpha` by using the current policy + # evaluated at the current state (from offline data). Note, in contrast + # to the original SAC loss, here the `alpha` and actor losses are + # calculated first. + # TODO (simon): Check, why log(alpha) is used, prob. just better + # to optimize and monotonic function. Original equation uses alpha. 
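For context on the TODO above: optimizing log(alpha) rather than alpha itself keeps the temperature strictly positive without an explicit constraint and is the common choice in SAC implementations. A minimal standalone sketch of the same loss form as the line below, using hypothetical tensors (`logps`, `target_entropy`) rather than the learner's actual state:

import torch

log_alpha = torch.zeros(1, requires_grad=True)   # alpha = exp(log_alpha) > 0 by construction
logps = torch.tensor([-1.2, -0.8, -1.5])         # log pi(a|s) of actions sampled from the current policy
target_entropy = -3.0                            # e.g. -dim(action_space)

# Only log_alpha receives gradients; the log-probabilities are detached,
# mirroring the alpha loss below.
alpha_loss = -(log_alpha * (logps.detach() + target_entropy)).mean()
alpha_loss.backward()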
+ alpha_loss = -torch.mean( + self.curr_log_alpha[module_id] + * (logps_curr.detach() + self.target_entropy[module_id]) + ) + + # Get the current batch size. Note, this size might vary in case the + # last batch contains less than `train_batch_size_per_learner` examples. + batch_size = batch[Columns.OBS].shape[0] + # Optimize the hyperparameter `alpha` by using the current policy evaluated + # at the current state. Note further, we minimize here, while the original + # equation in Haarnoja et al. (2018) considers maximization. + if batch_size == config.train_batch_size_per_learner: + optim = self.get_optimizer(module_id=module_id, optimizer_name="alpha") + optim.zero_grad(set_to_none=True) + alpha_loss.backward() + # Add the gradients to the gradient buffer that is evaluated later in + # `self.apply_gradients`. + self.grads.update( + { + pid: p.grad.clone() + for pid, p in self.filter_param_dict_for_optimizer( + self._params, optim + ).items() + } + ) + + # Get the current alpha. + alpha = torch.exp(self.curr_log_alpha[module_id]) + if self.metrics.peek("current_iteration") >= config.bc_iterations: + q_selected = fwd_out[QF_PREDS] + # TODO (simon): Add twin Q + actor_loss = torch.mean(alpha.detach() * logps_curr - q_selected) + else: + bc_logps_curr = action_dist_curr.logp(batch[Columns.ACTIONS]) + actor_loss = torch.mean(alpha.detach() * logps_curr - bc_logps_curr) + + # Optimize the SAC actor loss. + if batch_size == config.train_batch_size_per_learner: + optim = self.get_optimizer(module_id=module_id, optimizer_name="policy") + optim.zero_grad() + actor_loss.backward() + # Add the gradients to the gradient buffer that is used + # in `self.apply_gradients`. + self.grads.update( + { + pid: p.grad.clone() + for pid, p in self.filter_param_dict_for_optimizer( + self._params, optim + ).items() + } + ) + + # The critic loss is composed of the standard SAC Critic L2 loss and the + # CQL Entropy loss. + action_dist_next = action_dist_class.from_logits( + fwd_out["action_dist_inputs_next"] + ) + # Sample the actions for the next state. + actions_next = ( + # Note, we do not need to backpropagate through the + # next actions. + action_dist_next.sample() + if not config._deterministic_loss + else action_dist_next.to_deterministic().sample() + ) + # Compute the log probabilities for the next actions. + logps_next = action_dist_next.logp(actions_next) + + # Get the Q-values for the actually selected actions in the offline data. + # In the critic loss we use these as predictions. + q_selected = fwd_out[QF_PREDS] + # TODO (simon): Implement twin Q + + # Compute Q-values from the target Q network for the next state with the + # sampled actions for the next state. + q_batch_next = { + Columns.OBS: batch[Columns.NEXT_OBS], + Columns.ATIONS: actions_next, + } + q_target_next = self.module[module_id].forward_target(q_batch_next) + # TODO (simon): Apply twin Q + + # Now mask all Q-values with terminating next states in the targets. + q_next_masked = (1.0 - batch[Columns.TERMINATEDS].float()) * q_target_next + + # Compute the right hand side of the Bellman equation. Detach this node + # from the computation graph as we do not want to backpropagate through + # the target netowrk when optimizing the Q loss. + q_selected_target = ( + # TODO (simon): Add an `n_step` option to the `AddNextObsToBatch` connector. + batch[Columns.REWARDS] + + (config.gamma ** batch["n_step"]) * q_next_masked + ).detach() + + # Calculate the TD error. 
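As a toy illustration of the masked Bellman target computed just above, reduced to the one-step case with made-up values (not taken from any actual run):

import torch

rewards = torch.tensor([1.0, 0.5, 2.0])
terminateds = torch.tensor([0.0, 1.0, 0.0])      # second transition ends its episode
q_target_next = torch.tensor([10.0, 8.0, 12.0])  # target net evaluated at (s', a' ~ pi)
gamma = 0.99

q_next_masked = (1.0 - terminateds) * q_target_next
q_selected_target = rewards + gamma * q_next_masked
# -> tensor([10.9000,  0.5000, 13.8800]); terminal transitions bootstrap nothing.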
+ td_error = torch.abs(q_selected - q_selected_target) + # TODO (simon): Add the Twin TD error + + # MSBE loss for the critic(s) (i.e. Q, see eqs. (7-8) Haarnoja et al. (2018)). + # Note, this needs a sample from the current policy given the next state. + # Note further, we could also use here the Huber loss instead of the MSE. + # TODO (simon): Add the huber loss as an alternative (SAC uses it). + sac_critic_loss = torch.nn.MSELoss(reduction="mean")( + q_selected, q_selected_target + ) + # TODO (simon): Add the Twin Q critic loss + + # Now calculate the CQL loss (we use the entropy version of the CQL algorithm). + # Note, the entropy version performs best in shown experiments. + actions_rand_repeat = torch.FloatTensor( + batch[Columns.ACTIONS].shape[0] * config.num_actions, + device=fwd_out[QF_PREDS].device, + ).uniform_( + self.module.config.action_space.low, self.module.config.action_space.high + ) + # TODO (simon): Check, if we can repeat these like this or if we need to + # backpropagate through these actions. + # actions_curr_repeat = actions_curr[actions_curr. + # multinomial(config.num_actions, replacement=True).view(-1)] + # actions_next_repeat = actions_curr[actions_next. + # multinomial(config.num_actions, replacement=True).view(-1)] + actions_curr_repeat = ( + action_dist_curr.rsample() + if not config._deterministic_loss + else action_dist_curr.to_deterministic().sample() + ) + logps_curr_repeat = action_dist_curr.logp(actions_curr_repeat) + random_idx = actions_curr_repeat.multinomial( + config.num_actions, replacement=True + ).view(-1) + actions_curr_repeat = actions_curr_repeat[random_idx] + logps_curr_repeat = logps_curr_repeat[random_idx] + q_batch_curr_repeat = { + Columns.OBS: batch[Columns.OBS][random_idx], + Columns.ACTIONS: actions_curr_repeat, + } + q_curr_repeat = self.module.compute_q_values(q_batch_curr_repeat) + del q_batch_curr_repeat + # Sample actions for the next state. + actions_next_repeat = ( + action_dist_next.rsample() + if not config._deterministic_loss + else action_dist_next.to_deterministic().sample() + ) + logps_next_repeat = action_dist_next.logp(actions_next_repeat) + random_idx = actions_next_repeat.multinomial( + config.num_actions, replacement=True + ).view(-1) + actions_next_repeat = actions_next_repeat[random_idx] + logps_next_repeat = logps_next_repeat[random_idx] + q_batch_next_repeat = { + Columns.OBS: batch[Columns.NEXT_OBS][random_idx], + Columns.ACTIONS: actions_next_repeat, + } + q_next_repeat = self.module.compute_q_values(q_batch_next_repeat) + del q_batch_next_repeat + + q_batch_random_repeat = { + # Note, we can use here simply the same random index + # as within the last batch. + Columns.OBS: batch[Columns.OBS][random_idx], + Columns.ACTIONS: actions_rand_repeat, + } + # Compute the Q-values for the random actions (from the mu-distribution). + q_random_repeat = self.module.compute_q_values(q_batch_random_repeat) + del q_batch_random_repeat + # TODO (simon): Check, if this should be `actions_random_repeat`. + logps_random_repeat = torch.logp(0.5**actions_curr_repeat) + q_repeat = torch.cat( + [ + q_random_repeat - logps_random_repeat, + q_curr_repeat - logps_curr_repeat, + q_next_repeat - logps_next_repeat, + ], + dim=1, + ) + # TODO (simon): Also run the twin Q here. + + # Compute the entropy version of the CQL loss (see eq. (4) in Kumar et al. + # (2020)). Note that we use here the softmax with a temperature parameter. 
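For readability, here is the quantity the next few lines assemble, written in the code's notation with \tau = `config.temperature` and \alpha_{\min Q} = `config.min_q_weight`; the expectation over actions inside the logsumexp is approximated above by stacking Q-values of random, current-policy, and next-state-policy actions together with their log-probability corrections (a paraphrase of the CQL(H) objective, eq. (4) in Kumar et al. (2020)):

\mathcal{L}_{\mathrm{CQL}}
  = \alpha_{\min Q}\Big(
      \tau\,\mathbb{E}_{s\sim\mathcal{D}}\big[\log \textstyle\sum_{a}\exp\!\big(Q(s,a)/\tau\big)\big]
      - \mathbb{E}_{(s,a)\sim\mathcal{D}}\big[Q(s,a)\big]\Big),
\qquad
\mathcal{L}_{\mathrm{critic}} = \mathcal{L}_{\mathrm{MSBE}} + \mathcal{L}_{\mathrm{CQL}}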
+ cql_loss = ( + torch.logsumexp(q_repeat / config.temperature, dim=1).mean() + * config.min_q_weight + * config.temperature + ) + # The actual minimum Q-loss subtracts the value V (i.e. the expected Q-value) + # evaluated at the actually selected actions. + cql_loss = cql_loss - q_selected.mean() * config.min_q_weight + # TODO (simon): Implement CQL twin-Q loss here + + # TODO (simon): Check, if we need to implement here also a Lagrangian + # loss. + + critic_loss = sac_critic_loss + cql_loss + # TODO (simon): Add here also the critic loss for the twin-Q + + # If the batch size is large enough optimize the critic. + if batch_size == config.train_batch_size_per_learner: + critic_optim = self.get_optimizer(module_id=module_id, optimizer_name="qf") + critic_optim.zero_grad(set_to_none=True) + critic_loss.backward(retain_graph=True) + # Add the gradients to the gradient buffer that is evaluated + # in `self.apply_gradients` later. + self.grads.update( + { + pid: p.grad.clone() + for pid, p in self.filter_param_dict_for_optimizer( + self._params, optim + ).items() + } + ) + # TODO (simon): Also optimize the twin-Q. + + total_loss = actor_loss + critic_loss + alpha_loss + # TODO (simon): Add Twin Q losses + + # Log important loss stats (reduce=mean (default), but with window=1 + # in order to keep them history free). + self.metrics.log_dict( + { + POLICY_LOSS_KEY: actor_loss, + QF_LOSS_KEY: critic_loss, + # TODO (simon): Add these keys to SAC Learner. + "cql_loss": cql_loss, + "alpha_loss": alpha_loss, + "alpha_value": alpha, + "log_alpha_value": torch.log(alpha), + "target_entropy": self.target_entropy[module_id], + "actions_curr_policy": torch.mean(actions_curr), + LOGPS_KEY: torch.mean(logps_curr), + QF_MEAN_KEY: torch.mean(q_curr_repeat), + QF_MAX_KEY: torch.max(q_curr_repeat), + QF_MIN_KEY: torch.min(q_curr_repeat), + TD_ERROR_MEAN_KEY: torch.mean(td_error), + }, + key=module_id, + window=1, # <- single items (should not be mean/ema-reduced over time). + ) + # TODO (simon): Add loss keys for langrangian, if needed. + # TODO (simon): Add only here then the Langrange parameter optimization. + # TODO (simon): Add keys for twin Q + + # Return the total loss. + return total_loss + + @override(SACTorchLearner) + def compute_gradients( + self, loss_per_module: Dict[str, TensorType], **kwargs + ) -> ParamDict: + + # Return here simply the buffered gradients from `compute_loss_for_module`. + grads = self.grads + # Reset the gradient buffer. + self.grads = {} + # Finally, return the gradients. + return grads From 72acecff7cbf3225f035ac79312bc85c2c98612e Mon Sep 17 00:00:00 2001 From: simonsays1980 Date: Mon, 5 Aug 2024 16:39:09 +0200 Subject: [PATCH 2/6] Fixed two linter errors. 
Signed-off-by: simonsays1980 --- rllib/algorithms/cql/torch/cql_torch_learner.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/rllib/algorithms/cql/torch/cql_torch_learner.py b/rllib/algorithms/cql/torch/cql_torch_learner.py index 4109ba6fd43b..a3b6c201ab34 100644 --- a/rllib/algorithms/cql/torch/cql_torch_learner.py +++ b/rllib/algorithms/cql/torch/cql_torch_learner.py @@ -10,6 +10,7 @@ QF_PREDS, TD_ERROR_MEAN_KEY, ) +from ray.rllib.algorithms.cql.cql import CQLConfig from ray.rllib.algorithms.sac.torch.sac_torch_learner import SACTorchLearner from ray.rllib.core.columns import Columns from ray.rllib.core.learner.learner import ( @@ -122,8 +123,6 @@ def compute_loss_for_module( if not config._deterministic_loss else action_dist_next.to_deterministic().sample() ) - # Compute the log probabilities for the next actions. - logps_next = action_dist_next.logp(actions_next) # Get the Q-values for the actually selected actions in the offline data. # In the critic loss we use these as predictions. From ec109e121d6e36b7123771131f0259229028c2c6 Mon Sep 17 00:00:00 2001 From: simonsays1980 Date: Wed, 7 Aug 2024 10:49:11 +0200 Subject: [PATCH 3/6] MOdified loss calculation, specifically Q-value sampling in the CQL loss and switched in actor loss from selected actions to sampled actions from the current policy. Signed-off-by: simonsays1980 --- rllib/algorithms/cql/cql_learner.py | 16 +- .../algorithms/cql/torch/cql_torch_learner.py | 256 +++++++++++++----- 2 files changed, 197 insertions(+), 75 deletions(-) diff --git a/rllib/algorithms/cql/cql_learner.py b/rllib/algorithms/cql/cql_learner.py index b7f907cf10d1..e55015125a19 100644 --- a/rllib/algorithms/cql/cql_learner.py +++ b/rllib/algorithms/cql/cql_learner.py @@ -1,17 +1,25 @@ +from ray.air.constants import TRAINING_ITERATION from ray.rllib.algorithms.sac.sac_learner import SACLearner - from ray.rllib.core.learner.learner import Learner from ray.rllib.utils.annotations import override +from ray.rllib.utils.metrics import ALL_MODULES class CQLLearner(SACLearner): @override(Learner) def build(self) -> None: + # We need to call the `super()`'s `build` method here to have the variables + # for `alpha`` and the target entropy defined. + super().build() # Set up the gradient buffer to store gradients to apply # them later in `self.apply_gradients`. self.grads = {} - # We need to call the `super()`'s `build` method here to have the variables - # for `alpha`` and the target entropy defined. - super().build() + # Log the training iteration to switch from behavior cloning to improving + # the policy. + # TODO (simon, sven): Add upstream information pieces into this timesteps + # call arg to Learner.update_...(). 
+ self.metrics.log_value( + (ALL_MODULES, TRAINING_ITERATION), float("nan"), window=1 + ) diff --git a/rllib/algorithms/cql/torch/cql_torch_learner.py b/rllib/algorithms/cql/torch/cql_torch_learner.py index a3b6c201ab34..0f78742a6ef2 100644 --- a/rllib/algorithms/cql/torch/cql_torch_learner.py +++ b/rllib/algorithms/cql/torch/cql_torch_learner.py @@ -1,6 +1,9 @@ +import math +import tree from typing import Dict -from ray.rllib.cql.cql_learner import CQLLearner +from ray.air.constants import TRAINING_ITERATION +from ray.rllib.algorithms.cql.cql_learner import CQLLearner from ray.rllib.algorithms.sac.sac_learner import ( LOGPS_KEY, QF_LOSS_KEY, @@ -17,13 +20,14 @@ POLICY_LOSS_KEY, ) from ray.rllib.utils.annotations import override +from ray.rllib.utils.metrics import ALL_MODULES from ray.rllib.utils.framework import try_import_torch from ray.rllib.utils.typing import ModuleID, ParamDict, TensorType torch, nn = try_import_torch() -class CQLTorchLearner(CQLLearner, SACTorchLearner): +class CQLTorchLearner(SACTorchLearner, CQLLearner): @override(SACTorchLearner) def compute_loss_for_module( self, @@ -34,6 +38,15 @@ def compute_loss_for_module( fwd_out: Dict[str, TensorType], ) -> TensorType: + # TODO (simon, sven): Add upstream information pieces into this timesteps + # call arg to Learner.update_...(). + self.metrics.log_value( + (ALL_MODULES, TRAINING_ITERATION), + 0 + if math.isnan(self.metrics.peek((ALL_MODULES, TRAINING_ITERATION))) + else self.metrics.peek((ALL_MODULES, TRAINING_ITERATION)) + 1, + window=1, + ) # Get the train action distribution for the current policy and current state. # This is needed for the policy (actor) loss and the `alpha`` loss. action_dist_class = self.module[module_id].get_train_action_dist_cls() @@ -86,11 +99,22 @@ def compute_loss_for_module( # Get the current alpha. alpha = torch.exp(self.curr_log_alpha[module_id]) - if self.metrics.peek("current_iteration") >= config.bc_iterations: - q_selected = fwd_out[QF_PREDS] + # Start training with behavior cloning and turn to the classic Soft-Actor Critic + # after `bc_iters` of training iterations. + if self.metrics.peek((ALL_MODULES, TRAINING_ITERATION)) >= config.bc_iters: + # Calculate current Q-values. + batch_curr = { + Columns.OBS: batch[Columns.OBS], + # Use the actions sampled from the current policy. + Columns.ACTIONS: actions_curr, + } + q_curr = self.module[module_id].compute_q_values(batch_curr) # TODO (simon): Add twin Q - actor_loss = torch.mean(alpha.detach() * logps_curr - q_selected) + actor_loss = torch.mean(alpha.detach() * logps_curr - q_curr) else: + # Use log-probabilities of the current action distribution to clone + # the behavior policy (selected actions in data) in the first `bc_iters` + # training iterations. bc_logps_curr = action_dist_curr.logp(batch[Columns.ACTIONS]) actor_loss = torch.mean(alpha.detach() * logps_curr - bc_logps_curr) @@ -98,7 +122,8 @@ def compute_loss_for_module( if batch_size == config.train_batch_size_per_learner: optim = self.get_optimizer(module_id=module_id, optimizer_name="policy") optim.zero_grad() - actor_loss.backward() + # Retain the graph b/c we want to step a nother time through it. + actor_loss.backward(retain_graph=True) # Add the gradients to the gradient buffer that is used # in `self.apply_gradients`. self.grads.update( @@ -111,7 +136,7 @@ def compute_loss_for_module( ) # The critic loss is composed of the standard SAC Critic L2 loss and the - # CQL Entropy loss. + # CQL entropy loss. 
action_dist_next = action_dist_class.from_logits( fwd_out["action_dist_inputs_next"] ) @@ -133,7 +158,7 @@ def compute_loss_for_module( # sampled actions for the next state. q_batch_next = { Columns.OBS: batch[Columns.NEXT_OBS], - Columns.ATIONS: actions_next, + Columns.ACTIONS: actions_next, } q_target_next = self.module[module_id].forward_target(q_batch_next) # TODO (simon): Apply twin Q @@ -147,7 +172,8 @@ def compute_loss_for_module( q_selected_target = ( # TODO (simon): Add an `n_step` option to the `AddNextObsToBatch` connector. batch[Columns.REWARDS] - + (config.gamma ** batch["n_step"]) * q_next_masked + # TODO (simon): Implement n_step. + + (config.gamma) * q_next_masked ).detach() # Calculate the TD error. @@ -165,85 +191,88 @@ def compute_loss_for_module( # Now calculate the CQL loss (we use the entropy version of the CQL algorithm). # Note, the entropy version performs best in shown experiments. - actions_rand_repeat = torch.FloatTensor( - batch[Columns.ACTIONS].shape[0] * config.num_actions, + # Generate random actions (from the mu distribution as named in Kumar et + # al. (2020)) + low = torch.tensor( + self.module[module_id].config.action_space.low, device=fwd_out[QF_PREDS].device, - ).uniform_( - self.module.config.action_space.low, self.module.config.action_space.high ) - # TODO (simon): Check, if we can repeat these like this or if we need to - # backpropagate through these actions. - # actions_curr_repeat = actions_curr[actions_curr. - # multinomial(config.num_actions, replacement=True).view(-1)] - # actions_next_repeat = actions_curr[actions_next. - # multinomial(config.num_actions, replacement=True).view(-1)] - actions_curr_repeat = ( - action_dist_curr.rsample() - if not config._deterministic_loss - else action_dist_curr.to_deterministic().sample() + high = torch.tensor( + self.module[module_id].config.action_space.high, + device=fwd_out[QF_PREDS].device, + ) + num_samples = batch[Columns.ACTIONS].shape[0] * config.num_actions + actions_rand_repeat = low + (high - low) * torch.rand( + (num_samples, low.shape[0]), device=fwd_out[QF_PREDS].device + ) + + # Sample current and next actions (from the pi distribution as named in Kumar + # et al. (2020)) using repeated observations. + actions_curr_repeat, logps_curr_repeat, obs_curr_repeat = self._repeat_actions( + action_dist_class, batch[Columns.OBS], config.num_actions, module_id + ) + actions_next_repeat, logps_next_repeat, obs_next_repeat = self._repeat_actions( + action_dist_class, batch[Columns.NEXT_OBS], config.num_actions, module_id ) - logps_curr_repeat = action_dist_curr.logp(actions_curr_repeat) - random_idx = actions_curr_repeat.multinomial( - config.num_actions, replacement=True - ).view(-1) - actions_curr_repeat = actions_curr_repeat[random_idx] - logps_curr_repeat = logps_curr_repeat[random_idx] - q_batch_curr_repeat = { - Columns.OBS: batch[Columns.OBS][random_idx], + + # Calculate the Q-values for all actions. + batch_rand_repeat = { + Columns.OBS: obs_curr_repeat, + Columns.ACTIONS: actions_rand_repeat, + } + q_rand_repeat = ( + self.module[module_id] + .compute_q_values(batch_rand_repeat) + .view(batch_size, config.num_actions, 1) + ) + del batch_rand_repeat + batch_curr_repeat = { + Columns.OBS: obs_curr_repeat, Columns.ACTIONS: actions_curr_repeat, } - q_curr_repeat = self.module.compute_q_values(q_batch_curr_repeat) - del q_batch_curr_repeat - # Sample actions for the next state. 
- actions_next_repeat = ( - action_dist_next.rsample() - if not config._deterministic_loss - else action_dist_next.to_deterministic().sample() + q_curr_repeat = ( + self.module[module_id] + .compute_q_values(batch_curr_repeat) + .view(batch_size, config.num_actions, 1) ) - logps_next_repeat = action_dist_next.logp(actions_next_repeat) - random_idx = actions_next_repeat.multinomial( - config.num_actions, replacement=True - ).view(-1) - actions_next_repeat = actions_next_repeat[random_idx] - logps_next_repeat = logps_next_repeat[random_idx] - q_batch_next_repeat = { - Columns.OBS: batch[Columns.NEXT_OBS][random_idx], + del batch_curr_repeat + batch_next_repeat = { + Columns.OBS: obs_curr_repeat, Columns.ACTIONS: actions_next_repeat, } - q_next_repeat = self.module.compute_q_values(q_batch_next_repeat) - del q_batch_next_repeat + q_next_repeat = ( + self.module[module_id] + .compute_q_values(batch_next_repeat) + .view(batch_size, config.num_actions, 1) + ) + del batch_next_repeat - q_batch_random_repeat = { - # Note, we can use here simply the same random index - # as within the last batch. - Columns.OBS: batch[Columns.OBS][random_idx], - Columns.ACTIONS: actions_rand_repeat, - } - # Compute the Q-values for the random actions (from the mu-distribution). - q_random_repeat = self.module.compute_q_values(q_batch_random_repeat) - del q_batch_random_repeat - # TODO (simon): Check, if this should be `actions_random_repeat`. - logps_random_repeat = torch.logp(0.5**actions_curr_repeat) + # Compute the log-probabilities for the random actions. + random_density = torch.log( + torch.pow( + torch.tensor( + actions_curr_repeat.shape[-1], device=actions_curr_repeat.device + ), + 0.5, + ) + ) + # Merge all Q-values and subtract the log-probabilities (note, we use the + # entropy version of CQL). q_repeat = torch.cat( [ - q_random_repeat - logps_random_repeat, - q_curr_repeat - logps_curr_repeat, - q_next_repeat - logps_next_repeat, + q_rand_repeat - random_density, + q_next_repeat - logps_next_repeat.detach(), + q_curr_repeat - logps_curr_repeat.detach(), ], dim=1, ) - # TODO (simon): Also run the twin Q here. - # Compute the entropy version of the CQL loss (see eq. (4) in Kumar et al. - # (2020)). Note that we use here the softmax with a temperature parameter. cql_loss = ( torch.logsumexp(q_repeat / config.temperature, dim=1).mean() * config.min_q_weight * config.temperature ) - # The actual minimum Q-loss subtracts the value V (i.e. the expected Q-value) - # evaluated at the actually selected actions. - cql_loss = cql_loss - q_selected.mean() * config.min_q_weight + cql_loss = cql_loss - (q_selected.mean() * config.min_q_weight) # TODO (simon): Implement CQL twin-Q loss here # TODO (simon): Check, if we need to implement here also a Lagrangian @@ -261,9 +290,11 @@ def compute_loss_for_module( # in `self.apply_gradients` later. self.grads.update( { - pid: p.grad.clone() + pid: self.grads[pid] + p.grad.clone() + if pid in self.grads + else p.grad.clone() for pid, p in self.filter_param_dict_for_optimizer( - self._params, optim + self._params, critic_optim ).items() } ) @@ -303,12 +334,95 @@ def compute_loss_for_module( @override(SACTorchLearner) def compute_gradients( - self, loss_per_module: Dict[str, TensorType], **kwargs + self, loss_per_module: Dict[ModuleID, TensorType], **kwargs ) -> ParamDict: + """Returns the collected gradients from loss computation. + Note, the gradients for each module are collected in the + `compute_loss_for_module` method. 
CQL uses similar to SAC multiple learners + and multiple passes through the networks which need to be recorded + step-wise. + + Dict mapping module IDs to their individual total loss + terms, computed by the individual `compute_loss_for_module()` calls. + The overall total loss (sum of loss terms over all modules) is stored + under `loss_per_module[ALL_MODULES]`. + **kwargs: Forward compatibility kwargs. + + Returns: + Returns: + The gradients in the same (flat) format as self._params. Note that all + top-level structures, such as module IDs, will not be present anymore in + the returned dict. It will merely map parameter tensor references to their + respective gradient tensors. + """ + # TODO (simon): Check, if we can use a similar setup to SAC and also make the + # backward passes all here. # Return here simply the buffered gradients from `compute_loss_for_module`. grads = self.grads # Reset the gradient buffer. self.grads = {} # Finally, return the gradients. return grads + + def _repeat_tensor(self, tensor, repeat): + """Generates a repeated version of a tensor. + + The repetition is done similar `np.repeat` and repeats each value + instead of the complete vector. + + Args: + tensor: The tensor to be repeated. + repeat: How often each value in the tensor should be repeated. + + Returns: + A tensor holding `repeat` repeated values of the input `tensor` + """ + # Insert the new dimension at axis 1 into the tensor. + t_repeat = tensor.unsqueeze(1) + # Repeat the tensor along the new dimension. + t_repeat = torch.repeat_interleave(t_repeat, repeat, dim=1) + # Stack the repeated values into the batch dimension. + t_repeat = t_repeat.view(-1, *tensor.shape[1:]) + # Return the repeated tensor. + return t_repeat + + def _repeat_actions(self, action_dist_class, obs, num_actions, module_id): + """Generated actions for repeated observations. + + The `num_actions` define a multiplier used for generating `num_actions` + as many actions as the batch size. Observations are repeated and then a + model forward pass is made. + + Args: + action_dist_class: The action distribution class to be sued for sampling + actions. + obs: A batched observation tensor. + num_actions: The multiplier for actions, i.e. how much more actions + than the batch size should be generated. + module_id: The module ID to be used when calling the forward pass. + + Returns: + A tuple containing the sampled actions, their log-probabilities and the + repeated observations. + """ + # Receive the batch size. + batch_size = obs.shape[0] + # Repeat the observations `num_actions` times. + obs_repeat = tree.map_structure( + lambda t: self._repeat_tensor(t, num_actions), obs + ) + # Generate a batch for the forward pass. + temp_batch = {Columns.OBS: obs_repeat} + # Run the forward pass in inference mode. + fwd_out = self.module[module_id].forward_inference(temp_batch) + # Generate the squashed Gaussian from the model's logits. + action_dist = action_dist_class.from_logits(fwd_out[Columns.ACTION_DIST_INPUTS]) + # Sample the actions. Note, we want to make a backward pass through + # these actions. + actions = action_dist.rsample() + # Compute the action log-probabilities. + action_logps = action_dist.logp(actions).view(batch_size, num_actions, 1) + + # Return + return actions, action_logps, obs_repeat From 3d7e3536685eacf5aebe14ed5e46c29b6a943502 Mon Sep 17 00:00:00 2001 From: simonsays1980 Date: Mon, 12 Aug 2024 11:03:47 +0200 Subject: [PATCH 4/6] Moved optimizing from 'compute_loss_for_module' to 'compute_gradients'. 
Signed-off-by: simonsays1980 --- rllib/algorithms/cql/cql_learner.py | 11 +- .../algorithms/cql/torch/cql_torch_learner.py | 103 +++++------------- 2 files changed, 28 insertions(+), 86 deletions(-) diff --git a/rllib/algorithms/cql/cql_learner.py b/rllib/algorithms/cql/cql_learner.py index e55015125a19..c0e76d3dafd7 100644 --- a/rllib/algorithms/cql/cql_learner.py +++ b/rllib/algorithms/cql/cql_learner.py @@ -12,14 +12,9 @@ def build(self) -> None: # for `alpha`` and the target entropy defined. super().build() - # Set up the gradient buffer to store gradients to apply - # them later in `self.apply_gradients`. - self.grads = {} - - # Log the training iteration to switch from behavior cloning to improving - # the policy. - # TODO (simon, sven): Add upstream information pieces into this timesteps - # call arg to Learner.update_...(). + # Add a metric to keep track of training iterations to + # determine when switching the actor loss from behavior + # cloning to SAC. self.metrics.log_value( (ALL_MODULES, TRAINING_ITERATION), float("nan"), window=1 ) diff --git a/rllib/algorithms/cql/torch/cql_torch_learner.py b/rllib/algorithms/cql/torch/cql_torch_learner.py index 0f78742a6ef2..66695115c075 100644 --- a/rllib/algorithms/cql/torch/cql_torch_learner.py +++ b/rllib/algorithms/cql/torch/cql_torch_learner.py @@ -79,23 +79,6 @@ def compute_loss_for_module( # Get the current batch size. Note, this size might vary in case the # last batch contains less than `train_batch_size_per_learner` examples. batch_size = batch[Columns.OBS].shape[0] - # Optimize the hyperparameter `alpha` by using the current policy evaluated - # at the current state. Note further, we minimize here, while the original - # equation in Haarnoja et al. (2018) considers maximization. - if batch_size == config.train_batch_size_per_learner: - optim = self.get_optimizer(module_id=module_id, optimizer_name="alpha") - optim.zero_grad(set_to_none=True) - alpha_loss.backward() - # Add the gradients to the gradient buffer that is evaluated later in - # `self.apply_gradients`. - self.grads.update( - { - pid: p.grad.clone() - for pid, p in self.filter_param_dict_for_optimizer( - self._params, optim - ).items() - } - ) # Get the current alpha. alpha = torch.exp(self.curr_log_alpha[module_id]) @@ -118,23 +101,6 @@ def compute_loss_for_module( bc_logps_curr = action_dist_curr.logp(batch[Columns.ACTIONS]) actor_loss = torch.mean(alpha.detach() * logps_curr - bc_logps_curr) - # Optimize the SAC actor loss. - if batch_size == config.train_batch_size_per_learner: - optim = self.get_optimizer(module_id=module_id, optimizer_name="policy") - optim.zero_grad() - # Retain the graph b/c we want to step a nother time through it. - actor_loss.backward(retain_graph=True) - # Add the gradients to the gradient buffer that is used - # in `self.apply_gradients`. - self.grads.update( - { - pid: p.grad.clone() - for pid, p in self.filter_param_dict_for_optimizer( - self._params, optim - ).items() - } - ) - # The critic loss is composed of the standard SAC Critic L2 loss and the # CQL entropy loss. action_dist_next = action_dist_class.from_logits( @@ -281,25 +247,6 @@ def compute_loss_for_module( critic_loss = sac_critic_loss + cql_loss # TODO (simon): Add here also the critic loss for the twin-Q - # If the batch size is large enough optimize the critic. 
- if batch_size == config.train_batch_size_per_learner: - critic_optim = self.get_optimizer(module_id=module_id, optimizer_name="qf") - critic_optim.zero_grad(set_to_none=True) - critic_loss.backward(retain_graph=True) - # Add the gradients to the gradient buffer that is evaluated - # in `self.apply_gradients` later. - self.grads.update( - { - pid: self.grads[pid] + p.grad.clone() - if pid in self.grads - else p.grad.clone() - for pid, p in self.filter_param_dict_for_optimizer( - self._params, critic_optim - ).items() - } - ) - # TODO (simon): Also optimize the twin-Q. - total_loss = actor_loss + critic_loss + alpha_loss # TODO (simon): Add Twin Q losses @@ -336,33 +283,33 @@ def compute_loss_for_module( def compute_gradients( self, loss_per_module: Dict[ModuleID, TensorType], **kwargs ) -> ParamDict: - """Returns the collected gradients from loss computation. - - Note, the gradients for each module are collected in the - `compute_loss_for_module` method. CQL uses similar to SAC multiple learners - and multiple passes through the networks which need to be recorded - step-wise. - Dict mapping module IDs to their individual total loss - terms, computed by the individual `compute_loss_for_module()` calls. - The overall total loss (sum of loss terms over all modules) is stored - under `loss_per_module[ALL_MODULES]`. - **kwargs: Forward compatibility kwargs. + grads = {} + for module_id in set(loss_per_module.keys()) - {ALL_MODULES}: + # Loop through optimizers registered for this module. + for optim_name, optim in self.get_optimizers_for_module(module_id): + # Zero the gradients. Note, we need to reset the gradients b/c + # each component for a module operates on the same graph. + optim.zero_grad(set_to_none=True) + + # Compute the gradients for the component and module. + self.metrics.peek((module_id, optim_name + "_loss")).backward( + retain_graph=True + ) + # Store the gradients for the component and module. + # TODO (simon): Check another time the graph for overlapping + # gradients. + grads.update( + { + pid: grads[pid] + p.grad.clone() + if pid in grads + else p.grad.clone() + for pid, p in self.filter_param_dict_for_optimizer( + self._params, optim + ).items() + } + ) - Returns: - Returns: - The gradients in the same (flat) format as self._params. Note that all - top-level structures, such as module IDs, will not be present anymore in - the returned dict. It will merely map parameter tensor references to their - respective gradient tensors. - """ - # TODO (simon): Check, if we can use a similar setup to SAC and also make the - # backward passes all here. - # Return here simply the buffered gradients from `compute_loss_for_module`. - grads = self.grads - # Reset the gradient buffer. - self.grads = {} - # Finally, return the gradients. return grads def _repeat_tensor(self, tensor, repeat): From af3ac60336add397be1c010b1b1e3fbf7c1cf67d Mon Sep 17 00:00:00 2001 From: simonsays1980 Date: Mon, 12 Aug 2024 11:08:21 +0200 Subject: [PATCH 5/6] Changed logging training iterations to a simpler logic proposed by @sven1977. 
Signed-off-by: simonsays1980 --- rllib/algorithms/cql/torch/cql_torch_learner.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/rllib/algorithms/cql/torch/cql_torch_learner.py b/rllib/algorithms/cql/torch/cql_torch_learner.py index 66695115c075..318c7fa805fc 100644 --- a/rllib/algorithms/cql/torch/cql_torch_learner.py +++ b/rllib/algorithms/cql/torch/cql_torch_learner.py @@ -1,4 +1,3 @@ -import math import tree from typing import Dict @@ -42,10 +41,8 @@ def compute_loss_for_module( # call arg to Learner.update_...(). self.metrics.log_value( (ALL_MODULES, TRAINING_ITERATION), - 0 - if math.isnan(self.metrics.peek((ALL_MODULES, TRAINING_ITERATION))) - else self.metrics.peek((ALL_MODULES, TRAINING_ITERATION)) + 1, - window=1, + 1, + reduce="sum", ) # Get the train action distribution for the current policy and current state. # This is needed for the policy (actor) loss and the `alpha`` loss. @@ -84,7 +81,10 @@ def compute_loss_for_module( alpha = torch.exp(self.curr_log_alpha[module_id]) # Start training with behavior cloning and turn to the classic Soft-Actor Critic # after `bc_iters` of training iterations. - if self.metrics.peek((ALL_MODULES, TRAINING_ITERATION)) >= config.bc_iters: + if ( + self.metrics.peek((ALL_MODULES, TRAINING_ITERATION), default=0) + >= config.bc_iters + ): # Calculate current Q-values. batch_curr = { Columns.OBS: batch[Columns.OBS], From 5d8ba924736c335b2063fc2e1f8d7c7994dc5505 Mon Sep 17 00:00:00 2001 From: sven1977 Date: Mon, 12 Aug 2024 12:37:53 +0200 Subject: [PATCH 6/6] wip Signed-off-by: sven1977 --- rllib/algorithms/cql/cql_learner.py | 20 ------------------- .../algorithms/cql/torch/cql_torch_learner.py | 5 ++--- 2 files changed, 2 insertions(+), 23 deletions(-) delete mode 100644 rllib/algorithms/cql/cql_learner.py diff --git a/rllib/algorithms/cql/cql_learner.py b/rllib/algorithms/cql/cql_learner.py deleted file mode 100644 index c0e76d3dafd7..000000000000 --- a/rllib/algorithms/cql/cql_learner.py +++ /dev/null @@ -1,20 +0,0 @@ -from ray.air.constants import TRAINING_ITERATION -from ray.rllib.algorithms.sac.sac_learner import SACLearner -from ray.rllib.core.learner.learner import Learner -from ray.rllib.utils.annotations import override -from ray.rllib.utils.metrics import ALL_MODULES - - -class CQLLearner(SACLearner): - @override(Learner) - def build(self) -> None: - # We need to call the `super()`'s `build` method here to have the variables - # for `alpha`` and the target entropy defined. - super().build() - - # Add a metric to keep track of training iterations to - # determine when switching the actor loss from behavior - # cloning to SAC. 
- self.metrics.log_value( - (ALL_MODULES, TRAINING_ITERATION), float("nan"), window=1 - ) diff --git a/rllib/algorithms/cql/torch/cql_torch_learner.py b/rllib/algorithms/cql/torch/cql_torch_learner.py index 318c7fa805fc..a41770ff0592 100644 --- a/rllib/algorithms/cql/torch/cql_torch_learner.py +++ b/rllib/algorithms/cql/torch/cql_torch_learner.py @@ -2,7 +2,6 @@ from typing import Dict from ray.air.constants import TRAINING_ITERATION -from ray.rllib.algorithms.cql.cql_learner import CQLLearner from ray.rllib.algorithms.sac.sac_learner import ( LOGPS_KEY, QF_LOSS_KEY, @@ -26,7 +25,7 @@ torch, nn = try_import_torch() -class CQLTorchLearner(SACTorchLearner, CQLLearner): +class CQLTorchLearner(SACTorchLearner): @override(SACTorchLearner) def compute_loss_for_module( self, @@ -38,7 +37,7 @@ def compute_loss_for_module( ) -> TensorType: # TODO (simon, sven): Add upstream information pieces into this timesteps - # call arg to Learner.update_...(). + # call arg to Learner.update_...(). self.metrics.log_value( (ALL_MODULES, TRAINING_ITERATION), 1,
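For reference, the loss above reads its hyperparameters from the algorithm config imported in patch 2 (`CQLConfig`). A rough, hypothetical usage sketch with placeholder values; the exact `.training()` / `.offline_data()` arguments depend on the RLlib version this series targets:

from ray.rllib.algorithms.cql.cql import CQLConfig

config = (
    CQLConfig()
    .environment("Pendulum-v1")          # only used for the spaces; training data comes from offline input
    .training(
        bc_iters=200,                    # behavior-cloning iterations before switching to the SAC actor loss
        temperature=1.0,                 # tau in the CQL logsumexp term
        min_q_weight=5.0,                # weight of the conservative regularizer
        num_actions=10,                  # sampled actions per observation for the CQL term
        train_batch_size_per_learner=256,
    )
)
# An .offline_data(input_=...) call pointing at recorded offline episodes/SampleBatches
# would be required before config.build() can actually train.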