[RLlib; Offline RL] Add CQLLearner and CQLTorchLearner. #46969

Merged (8 commits) on Aug 12, 2024
374 changes: 374 additions & 0 deletions rllib/algorithms/cql/torch/cql_torch_learner.py
@@ -0,0 +1,374 @@
import tree
from typing import Dict

from ray.air.constants import TRAINING_ITERATION
from ray.rllib.algorithms.sac.sac_learner import (
LOGPS_KEY,
QF_LOSS_KEY,
QF_MEAN_KEY,
QF_MAX_KEY,
QF_MIN_KEY,
QF_PREDS,
TD_ERROR_MEAN_KEY,
)
from ray.rllib.algorithms.cql.cql import CQLConfig
from ray.rllib.algorithms.sac.torch.sac_torch_learner import SACTorchLearner
from ray.rllib.core.columns import Columns
from ray.rllib.core.learner.learner import (
POLICY_LOSS_KEY,
)
from ray.rllib.utils.annotations import override
from ray.rllib.utils.metrics import ALL_MODULES
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.typing import ModuleID, ParamDict, TensorType

torch, nn = try_import_torch()


class CQLTorchLearner(SACTorchLearner):
@override(SACTorchLearner)
def compute_loss_for_module(
self,
*,
module_id: ModuleID,
config: CQLConfig,
batch: Dict,
fwd_out: Dict[str, TensorType],
) -> TensorType:

# TODO (simon, sven): Add upstream information pieces into the `timesteps`
# call arg to `Learner.update_...()`.
self.metrics.log_value(
(ALL_MODULES, TRAINING_ITERATION),
1,
reduce="sum",
)

sven1977 (Contributor) commented on Aug 9, 2024:

Now that reduce="sum" is already defined above, you can just do:

log_value(
    (ALL_MODULES, TRAINING_ITERATION),
    1,
    reduce="sum",
)

# Get the train action distribution for the current policy and current state.
# This is needed for the policy (actor) loss and the `alpha` loss.
action_dist_class = self.module[module_id].get_train_action_dist_cls()
action_dist_curr = action_dist_class.from_logits(
fwd_out[Columns.ACTION_DIST_INPUTS]
)

# Sample actions for the current state. Note that we need to apply the
# reparameterization trick here to avoid the expectation over actions.
actions_curr = (
action_dist_curr.rsample()
if not config._deterministic_loss
# If deterministic, we use the mean.
else action_dist_curr.to_deterministic().sample()
)
# Compute the log probabilities for the current state (for the alpha loss)
logps_curr = action_dist_curr.logp(actions_curr)

# Optimize also the hyperparameter `alpha` by using the current policy
# evaluated at the current state (from offline data). Note, in contrast
# to the original SAC loss, here the `alpha` and actor losses are
# calculated first.
# TODO (simon): Check why log(alpha) is used; it is probably just better
# to optimize a monotonic function of alpha. The original equation uses alpha.
alpha_loss = -torch.mean(
self.curr_log_alpha[module_id]
* (logps_curr.detach() + self.target_entropy[module_id])
)
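# Note: the entropy coefficient is parameterized in log-space
# (`alpha = exp(log_alpha)`), which guarantees `alpha > 0` and keeps the
# optimization unconstrained.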

# Get the current batch size. Note, this size might vary in case the
# last batch contains fewer than `train_batch_size_per_learner` examples.
batch_size = batch[Columns.OBS].shape[0]
sven1977 (Contributor) commented:
Can we explain this logic here and why we defend against different batch sizes coming in?

Don't we always expect the same batch size from the data pipeline?

If not, we should:
a) explain here why we are expecting varying batch sizes,
b) probably fix this logic here. What if we call compute_loss_for_module 10x with a batch of train_batch_size_per_learner - 1 (accumulating gradients for these, but not applying them) and then one batch of size train_batch_size_per_learner? In this case, the effective batch size would be roughly 11x the user-configured one, correct?

c) Also, what if the incoming batch is larger than train_batch_size_per_learner?

The author (Collaborator) replied:
@sven1977 I fully agree with your comment.

a) This was also part of the old-stack algorithm. The reason behind this is that when iterating over a dataset, the last batch could contain fewer than train_batch_size_per_learner samples.

b) The offline logic here is that after each call to compute_loss_for_module, a compute_gradients and apply_gradients call will occur. No SGD is run on the offline algorithms (yet). So the train batch size should always be as large as configured, except for the last batch, which can be smaller.

c) This case should also not happen. iter_batches takes care of this and ensures that the configured batch size is always sampled; only the last batch can be smaller, which could be avoided by setting a flag, but that would neglect data.
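
A minimal sketch (not part of this PR) of the iteration behavior described in b) and c), assuming ray.data's Dataset.iter_batches(batch_size=..., drop_last=...) signature:

import ray

ds = ray.data.range(10)  # 10 rows, iterated with batch_size=4 below.
for batch in ds.iter_batches(batch_size=4, drop_last=False):
    # Batch sizes come out as 4, 4, 2 -> only the trailing batch is short.
    # With drop_last=True the 2-row remainder would be dropped (data lost).
    print({k: len(v) for k, v in batch.items()})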


# Get the current alpha.
alpha = torch.exp(self.curr_log_alpha[module_id])
# Start training with behavior cloning and switch to the classic Soft Actor-Critic
# loss after `bc_iters` training iterations.
if (
self.metrics.peek((ALL_MODULES, TRAINING_ITERATION), default=0)
>= config.bc_iters
):
# Calculate current Q-values.
batch_curr = {
Columns.OBS: batch[Columns.OBS],
# Use the actions sampled from the current policy.
Columns.ACTIONS: actions_curr,
}
q_curr = self.module[module_id].compute_q_values(batch_curr)
# TODO (simon): Add twin Q
actor_loss = torch.mean(alpha.detach() * logps_curr - q_curr)
else:
# Use log-probabilities of the current action distribution to clone
# the behavior policy (selected actions in data) in the first `bc_iters`
# training iterations.
bc_logps_curr = action_dist_curr.logp(batch[Columns.ACTIONS])
actor_loss = torch.mean(alpha.detach() * logps_curr - bc_logps_curr)

# The critic loss is composed of the standard SAC Critic L2 loss and the
# CQL entropy loss.
action_dist_next = action_dist_class.from_logits(
fwd_out["action_dist_inputs_next"]
)
# Sample the actions for the next state.
actions_next = (
# Note, we do not need to backpropagate through the
# next actions.
action_dist_next.sample()
if not config._deterministic_loss
else action_dist_next.to_deterministic().sample()
)

# Get the Q-values for the actually selected actions in the offline data.
# In the critic loss we use these as predictions.
q_selected = fwd_out[QF_PREDS]
# TODO (simon): Implement twin Q

# Compute Q-values from the target Q network for the next state with the
# sampled actions for the next state.
q_batch_next = {
Columns.OBS: batch[Columns.NEXT_OBS],
Columns.ACTIONS: actions_next,
}
q_target_next = self.module[module_id].forward_target(q_batch_next)
# TODO (simon): Apply twin Q

# Now mask out the target Q-values for terminating next states.
q_next_masked = (1.0 - batch[Columns.TERMINATEDS].float()) * q_target_next

# Compute the right hand side of the Bellman equation. Detach this node
# from the computation graph as we do not want to backpropagate through
# the target network when optimizing the Q loss.
q_selected_target = (
# TODO (simon): Add an `n_step` option to the `AddNextObsToBatch` connector.
batch[Columns.REWARDS]
# TODO (simon): Implement n_step.
+ (config.gamma) * q_next_masked
).detach()
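# In short: y = r + gamma * (1 - terminated) * Q_target(s', a'), with a'
# sampled from the current policy at the next state s'.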

# Calculate the TD error.
td_error = torch.abs(q_selected - q_selected_target)
# TODO (simon): Add the Twin TD error

# MSBE loss for the critic(s) (i.e. Q, see eqs. (7)-(8) in Haarnoja et al. (2018)).
# Note, this needs a sample from the current policy given the next state.
# Note further, we could also use the Huber loss here instead of the MSE.
# TODO (simon): Add the Huber loss as an alternative (SAC uses it).
sac_critic_loss = torch.nn.MSELoss(reduction="mean")(
q_selected, q_selected_target
)
# TODO (simon): Add the Twin Q critic loss

# Now calculate the CQL loss (we use the entropy version of the CQL algorithm).
# Note, the entropy version performs best in the reported experiments.
# Generate random actions (from the mu distribution as named in Kumar et
# al. (2020))
low = torch.tensor(
self.module[module_id].config.action_space.low,
device=fwd_out[QF_PREDS].device,
)
high = torch.tensor(
self.module[module_id].config.action_space.high,
device=fwd_out[QF_PREDS].device,
)
num_samples = batch[Columns.ACTIONS].shape[0] * config.num_actions
actions_rand_repeat = low + (high - low) * torch.rand(
(num_samples, low.shape[0]), device=fwd_out[QF_PREDS].device
)

# Sample current and next actions (from the pi distribution as named in Kumar
# et al. (2020)) using repeated observations.
actions_curr_repeat, logps_curr_repeat, obs_curr_repeat = self._repeat_actions(
action_dist_class, batch[Columns.OBS], config.num_actions, module_id
)
actions_next_repeat, logps_next_repeat, obs_next_repeat = self._repeat_actions(
action_dist_class, batch[Columns.NEXT_OBS], config.num_actions, module_id
)

# Calculate the Q-values for all actions.
batch_rand_repeat = {
Columns.OBS: obs_curr_repeat,
Columns.ACTIONS: actions_rand_repeat,
}
q_rand_repeat = (
self.module[module_id]
.compute_q_values(batch_rand_repeat)
.view(batch_size, config.num_actions, 1)
)
del batch_rand_repeat
batch_curr_repeat = {
Columns.OBS: obs_curr_repeat,
Columns.ACTIONS: actions_curr_repeat,
}
q_curr_repeat = (
self.module[module_id]
.compute_q_values(batch_curr_repeat)
.view(batch_size, config.num_actions, 1)
)
del batch_curr_repeat
batch_next_repeat = {
Columns.OBS: obs_curr_repeat,
Columns.ACTIONS: actions_next_repeat,
}
q_next_repeat = (
self.module[module_id]
.compute_q_values(batch_next_repeat)
.view(batch_size, config.num_actions, 1)
)
del batch_next_repeat

# Compute the log-probabilities for the random actions.
random_density = torch.log(
torch.pow(
torch.tensor(
actions_curr_repeat.shape[-1], device=actions_curr_repeat.device
),
0.5,
)
)
# Merge all Q-values and subtract the log-probabilities (note, we use the
# entropy version of CQL).
q_repeat = torch.cat(
[
q_rand_repeat - random_density,
q_next_repeat - logps_next_repeat.detach(),
q_curr_repeat - logps_curr_repeat.detach(),
],
dim=1,
)
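
# The next expression implements the entropy-regularized CQL term, cf. the
# CQL(H) objective in Kumar et al. (2020) (descriptive sketch only):
#   min_q_weight * ( temperature * E_s[ log sum_a exp(Q(s, a) / temperature) ]
#                    - E_{(s, a) ~ D}[ Q(s, a) ] )
# i.e. Q-values under the sampled/random actions are pushed down, while
# Q-values of the dataset actions are pushed up.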

cql_loss = (
torch.logsumexp(q_repeat / config.temperature, dim=1).mean()
* config.min_q_weight
* config.temperature
)
cql_loss = cql_loss - (q_selected.mean() * config.min_q_weight)
# TODO (simon): Implement CQL twin-Q loss here

# TODO (simon): Check, if we need to implement here also a Lagrangian
# loss.

critic_loss = sac_critic_loss + cql_loss
# TODO (simon): Add here also the critic loss for the twin-Q

total_loss = actor_loss + critic_loss + alpha_loss
# TODO (simon): Add Twin Q losses

# Log important loss stats (reduce=mean (default), but with window=1
# in order to keep them history free).
self.metrics.log_dict(
{
POLICY_LOSS_KEY: actor_loss,
QF_LOSS_KEY: critic_loss,
# TODO (simon): Add these keys to SAC Learner.
"cql_loss": cql_loss,
"alpha_loss": alpha_loss,
"alpha_value": alpha,
"log_alpha_value": torch.log(alpha),
"target_entropy": self.target_entropy[module_id],
"actions_curr_policy": torch.mean(actions_curr),
LOGPS_KEY: torch.mean(logps_curr),
QF_MEAN_KEY: torch.mean(q_curr_repeat),
QF_MAX_KEY: torch.max(q_curr_repeat),
QF_MIN_KEY: torch.min(q_curr_repeat),
TD_ERROR_MEAN_KEY: torch.mean(td_error),
},
key=module_id,
window=1, # <- single items (should not be mean/ema-reduced over time).
)
# TODO (simon): Add loss keys for the Lagrangian, if needed.
# TODO (simon): Only then add the Lagrange parameter optimization here.
# TODO (simon): Add keys for twin Q

# Return the total loss.
return total_loss

@override(SACTorchLearner)
def compute_gradients(
self, loss_per_module: Dict[ModuleID, TensorType], **kwargs
) -> ParamDict:

grads = {}
for module_id in set(loss_per_module.keys()) - {ALL_MODULES}:
# Loop through optimizers registered for this module.
for optim_name, optim in self.get_optimizers_for_module(module_id):
# Zero the gradients. Note, we need to reset the gradients b/c
# each component for a module operates on the same graph.
optim.zero_grad(set_to_none=True)

# Compute the gradients for the component and module.
self.metrics.peek((module_id, optim_name + "_loss")).backward(
retain_graph=True
)
# Store the gradients for the component and module.
# TODO (simon): Check another time the graph for overlapping
# gradients.
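# Note: if a parameter is shared between optimizers (and thus shows up more
# than once in this loop), its gradients are accumulated by summation below.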
grads.update(
{
pid: grads[pid] + p.grad.clone()
if pid in grads
else p.grad.clone()
for pid, p in self.filter_param_dict_for_optimizer(
self._params, optim
).items()
}
)

return grads

def _repeat_tensor(self, tensor, repeat):
"""Generates a repeated version of a tensor.

The repetition is done similarly to `np.repeat`: each value is repeated
instead of the complete vector.

Args:
tensor: The tensor to be repeated.
repeat: How often each value in the tensor should be repeated.

Returns:
A tensor holding `repeat` repeated values of the input `tensor`.
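
Example: for a tensor `[[1.], [2.]]` and `repeat=3`, the result is
`[[1.], [1.], [1.], [2.], [2.], [2.]]`, i.e. each row is repeated
(rather than the whole tensor being tiled).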
"""
# Insert the new dimension at axis 1 into the tensor.
t_repeat = tensor.unsqueeze(1)
# Repeat the tensor along the new dimension.
t_repeat = torch.repeat_interleave(t_repeat, repeat, dim=1)
# Stack the repeated values into the batch dimension.
t_repeat = t_repeat.view(-1, *tensor.shape[1:])
# Return the repeated tensor.
return t_repeat

def _repeat_actions(self, action_dist_class, obs, num_actions, module_id):
"""Generated actions for repeated observations.

The `num_actions` define a multiplier used for generating `num_actions`
as many actions as the batch size. Observations are repeated and then a
model forward pass is made.

Args:
action_dist_class: The action distribution class to be used for sampling
actions.
obs: A batched observation tensor.
num_actions: The multiplier for actions, i.e. how many more actions
than the batch size should be generated.
module_id: The module ID to be used when calling the forward pass.

Returns:
A tuple containing the sampled actions, their log-probabilities and the
repeated observations.
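Note on shapes: for `obs` of shape `(B, obs_dim)`, the sampled actions
have shape `(B * num_actions, act_dim)`, the log-probabilities
`(B, num_actions, 1)`, and the repeated observations
`(B * num_actions, obs_dim)`.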
"""
# Receive the batch size.
batch_size = obs.shape[0]
# Repeat the observations `num_actions` times.
obs_repeat = tree.map_structure(
lambda t: self._repeat_tensor(t, num_actions), obs
)
# Generate a batch for the forward pass.
temp_batch = {Columns.OBS: obs_repeat}
# Run the forward pass in inference mode.
fwd_out = self.module[module_id].forward_inference(temp_batch)
# Generate the squashed Gaussian from the model's logits.
action_dist = action_dist_class.from_logits(fwd_out[Columns.ACTION_DIST_INPUTS])
# Sample the actions. Note, we want to make a backward pass through
# these actions.
actions = action_dist.rsample()
# Compute the action log-probabilities.
action_logps = action_dist.logp(actions).view(batch_size, num_actions, 1)

# Return
return actions, action_logps, obs_repeat