[RLlib; Offline RL] Implement twin-Q net option for CQL. #47105

Merged · 4 commits · Aug 13, 2024
Changes from all commits
130 changes: 113 additions & 17 deletions rllib/algorithms/cql/torch/cql_torch_learner.py
@@ -9,6 +9,8 @@
QF_MAX_KEY,
QF_MIN_KEY,
QF_PREDS,
QF_TWIN_LOSS_KEY,
QF_TWIN_PREDS,
TD_ERROR_MEAN_KEY,
)
from ray.rllib.algorithms.cql.cql import CQLConfig
@@ -90,8 +92,9 @@ def compute_loss_for_module(
# Use the actions sampled from the current policy.
Columns.ACTIONS: actions_curr,
}
# Note, if `twin_q` is `True`, `compute_q_values` computes the minimum
# of the `qf` and `qf_twin` and returns this minimum.
q_curr = self.module[module_id].compute_q_values(batch_curr)
# TODO (simon): Add twin Q
actor_loss = torch.mean(alpha.detach() * logps_curr - q_curr)
else:
# Use log-probabilities of the current action distribution to clone
@@ -117,16 +120,19 @@
# Get the Q-values for the actually selected actions in the offline data.
# In the critic loss we use these as predictions.
q_selected = fwd_out[QF_PREDS]
# TODO (simon): Implement twin Q
if config.twin_q:
q_twin_selected = fwd_out[QF_TWIN_PREDS]

# Compute Q-values from the target Q network for the next state with the
# sampled actions for the next state.
q_batch_next = {
Columns.OBS: batch[Columns.NEXT_OBS],
Columns.ACTIONS: actions_next,
}
# Note, if `twin_q` is `True`, `SACTorchRLModule.forward_target` calculates
# the Q-values for both `qf_target` and `qf_twin_target` and returns
# their minimum.
q_target_next = self.module[module_id].forward_target(q_batch_next)
# TODO (simon): Apply twin Q

# Now mask all Q-values with terminating next states in the targets.
q_next_masked = (1.0 - batch[Columns.TERMINATEDS].float()) * q_target_next
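As the comments above describe, with `twin_q` enabled both `compute_q_values` and `forward_target` reduce the two heads to their element-wise minimum (the clipped double-Q trick). A minimal, self-contained sketch of that reduction; the function name is an illustrative placeholder, not the module's real API:

```python
import torch

def min_twin_q_target(q1: torch.Tensor, q2: torch.Tensor) -> torch.Tensor:
    # Clipped double-Q: take the element-wise minimum of both target heads
    # to counter Q-value overestimation.
    return torch.min(q1, q2)

# Example: wherever the two heads disagree, the pessimistic estimate wins.
q_target_next = min_twin_q_target(torch.tensor([1.0, 2.5]), torch.tensor([1.2, 2.0]))
# -> tensor([1., 2.])
```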
@@ -143,16 +149,25 @@

# Calculate the TD error.
td_error = torch.abs(q_selected - q_selected_target)
# TODO (simon): Add the Twin TD error
# Calculate a TD-error for twin-Q values, if needed.
if config.twin_q:
td_error += torch.abs(q_twin_selected - q_selected_target)
# Rescale the TD error
td_error *= 0.5
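With `twin_q`, the reported TD error accumulates both heads' absolute errors against the shared target, and the 0.5 rescale turns that sum into their average. A minimal numeric sketch with made-up values:

```python
import torch

q_selected = torch.tensor([1.0, 2.0])
q_twin_selected = torch.tensor([1.5, 1.5])
q_selected_target = torch.tensor([1.2, 1.8])

td_error = torch.abs(q_selected - q_selected_target)
td_error += torch.abs(q_twin_selected - q_selected_target)
td_error *= 0.5  # average of both heads' absolute errors
# -> tensor([0.2500, 0.2500])
```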

# MSBE loss for the critic(s) (i.e. Q, see eqs. (7-8) Haarnoja et al. (2018)).
# Note, this needs a sample from the current policy given the next state.
# Note further, we could also use the Huber loss here instead of the MSE.
# TODO (simon): Add the huber loss as an alternative (SAC uses it).
sac_critic_loss = torch.nn.MSELoss(reduction="mean")(
q_selected, q_selected_target
q_selected,
q_selected_target,
)
# TODO (simon): Add the Twin Q critic loss
if config.twin_q:
sac_critic_twin_loss = torch.nn.MSELoss(reduction="mean")(
q_twin_selected,
q_selected_target,
)
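On the TODO about a Huber-loss alternative: a hedged sketch of how the MSE criterion could be swapped for `torch.nn.HuberLoss`. This is not what the PR implements; `use_huber` and `delta` are assumed, hypothetical knobs:

```python
import torch

def critic_loss_fn(q_pred: torch.Tensor, q_target: torch.Tensor,
                   use_huber: bool = False, delta: float = 1.0) -> torch.Tensor:
    # Huber behaves like MSE for small TD errors and like MAE for large ones,
    # which makes the critic update more robust to outlier targets.
    if use_huber:
        return torch.nn.HuberLoss(reduction="mean", delta=delta)(q_pred, q_target)
    return torch.nn.MSELoss(reduction="mean")(q_pred, q_target)
```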

# Now calculate the CQL loss (we use the entropy version of the CQL algorithm).
# Note, the entropy version performs best in shown experiments.
@@ -185,34 +200,88 @@ def compute_loss_for_module(
Columns.OBS: obs_curr_repeat,
Columns.ACTIONS: actions_rand_repeat,
}
# Note, we need here the Q-values from the base Q-value function
# and not the minimum over it and a possible twin Q-value function.
q_rand_repeat = (
self.module[module_id]
.compute_q_values(batch_rand_repeat)
._qf_forward_train_helper(
batch_rand_repeat,
self.module[module_id].qf_encoder,
self.module[module_id].qf,
)
.view(batch_size, config.num_actions, 1)
)
# Calculate twin Q-values for the random actions, if needed.
if config.twin_q:
q_twin_rand_repeat = (
self.module[module_id]
._qf_forward_train_helper(
batch_rand_repeat,
self.module[module_id].qf_twin_encoder,
self.module[module_id].qf_twin,
)
.view(batch_size, config.num_actions, 1)
)
del batch_rand_repeat
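The switch from `compute_q_values` to `_qf_forward_train_helper` matters because, with `twin_q`, the former returns the minimum over both heads, while the CQL penalty has to regularize each head against its own out-of-distribution Q-values. A toy sketch of the distinction; the stand-in heads are made up, only the per-head-vs-minimum point is real:

```python
import torch

# Stand-ins for the two Q-heads; in the real module these correspond to the
# `qf_encoder`/`qf` and `qf_twin_encoder`/`qf_twin` pairs.
def q_head(obs: torch.Tensor) -> torch.Tensor:
    return obs.sum(dim=-1, keepdim=True)

def q_twin_head(obs: torch.Tensor) -> torch.Tensor:
    return 2.0 * obs.sum(dim=-1, keepdim=True)

obs = torch.randn(6, 3)
q_base = q_head(obs)               # per-head value used in the base CQL term
q_twin = q_twin_head(obs)          # per-head value used in the twin CQL term
q_min = torch.min(q_base, q_twin)  # what `compute_q_values` would return instead
```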
batch_curr_repeat = {
Columns.OBS: obs_curr_repeat,
Columns.ACTIONS: actions_curr_repeat,
}
q_curr_repeat = (
self.module[module_id]
.compute_q_values(batch_curr_repeat)
._qf_forward_train_helper(
batch_curr_repeat,
self.module[module_id].qf_encoder,
self.module[module_id].qf,
)
.view(batch_size, config.num_actions, 1)
)
# Calculate twin Q-values for the repeated actions from the current policy,
# if needed.
if config.twin_q:
q_twin_curr_repeat = (
self.module[module_id]
._qf_forward_train_helper(
batch_curr_repeat,
self.module[module_id].qf_twin_encoder,
self.module[module_id].qf_twin,
)
.view(batch_size, config.num_actions, 1)
)
del batch_curr_repeat
batch_next_repeat = {
# Note, we use here the current observations b/c we want to keep the
# state fixed while sampling the actions.
Columns.OBS: obs_curr_repeat,
Columns.ACTIONS: actions_next_repeat,
}
q_next_repeat = (
self.module[module_id]
.compute_q_values(batch_next_repeat)
._qf_forward_train_helper(
batch_next_repeat,
self.module[module_id].qf_encoder,
self.module[module_id].qf,
)
.view(batch_size, config.num_actions, 1)
)
# Also calculate the twin Q-values for the current policy and next actions,
# if needed.
if config.twin_q:
q_twin_next_repeat = (
self.module[module_id]
._qf_forward_train_helper(
batch_next_repeat,
self.module[module_id].qf_twin_encoder,
self.module[module_id].qf_twin,
)
.view(batch_size, config.num_actions, 1)
)
del batch_next_repeat
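The `.view(batch_size, config.num_actions, 1)` reshapes above all follow the same pattern: each observation is repeated `num_actions` times, a Q-head is evaluated on the flat batch, and the result is folded back so that dimension 1 indexes the sampled actions, which is the dimension the `logsumexp` below reduces over. A shape-only sketch with made-up sizes; whether the learner uses `repeat_interleave` or another op is not shown in this hunk:

```python
import torch

batch_size, num_actions, obs_dim = 4, 10, 3
obs = torch.randn(batch_size, obs_dim)

# Repeat every observation `num_actions` times -> (batch_size * num_actions, obs_dim).
obs_repeat = obs.repeat_interleave(num_actions, dim=0)

# Stand-in for evaluating a Q-head on the flat, repeated batch.
q_flat = torch.randn(batch_size * num_actions, 1)

# Fold the sampled actions back into their own dimension ...
q_repeat = q_flat.view(batch_size, num_actions, 1)

# ... which is the dimension the CQL `logsumexp` reduces over.
soft_max_q = torch.logsumexp(q_repeat, dim=1)  # shape: (batch_size, 1)
```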

# Compute the log-probabilities for the random actions.
# TODO (simon): This is the density for a discrete uniform; however, the actions
# come from a continuous distribution, so this density should actually use
# (1/(high - low)) instead of (1/2).
random_density = torch.log(
torch.pow(
torch.tensor(
@@ -231,23 +300,43 @@
],
dim=1,
)

cql_loss = (
torch.logsumexp(q_repeat / config.temperature, dim=1).mean()
* config.min_q_weight
* config.temperature
)
cql_loss = cql_loss - (q_selected.mean() * config.min_q_weight)
# TODO (simon): Implement CQL twin-Q loss here
cql_loss -= q_selected.mean() * config.min_q_weight
# Add the CQL loss term to the SAC loss term.
critic_loss = sac_critic_loss + cql_loss
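In equation form, the lines above implement the CQL(H) regularizer from Kumar et al. (2020) on top of the SAC critic loss, where $\alpha_{\min}$ stands for `min_q_weight`, $\tau$ for `temperature`, and $\hat{a}$ ranges over the concatenated random/current/next action samples (already importance-corrected via the subtracted log-densities):

$$
\mathcal{L}_{\text{CQL}}
= \alpha_{\min}\,\tau\,
  \mathbb{E}_{s \sim \mathcal{D}}\!\left[\log \sum_{\hat{a}} \exp\!\big(Q_\theta(s,\hat{a})/\tau\big)\right]
- \alpha_{\min}\,
  \mathbb{E}_{(s,a) \sim \mathcal{D}}\!\left[Q_\theta(s,a)\right],
\qquad
\mathcal{L}_{\text{critic}} = \mathcal{L}_{\text{SAC}} + \mathcal{L}_{\text{CQL}}.
$$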

# If a twin Q-value function is used, calculate its CQL loss.
if config.twin_q:
q_twin_repeat = torch.cat(
[
q_twin_rand_repeat - random_density,
q_twin_next_repeat - logps_next_repeat.detach(),
q_twin_curr_repeat - logps_curr_repeat.detach(),
],
dim=1,
)
cql_twin_loss = (
torch.logsumexp(q_twin_repeat / config.temperature, dim=1).mean()
* config.min_q_weight
* config.temperature
)
cql_twin_loss -= q_twin_selected.mean() * config.min_q_weight
# Add the twin CQL loss term to the twin SAC loss term.
critic_twin_loss = sac_critic_twin_loss + cql_twin_loss

# TODO (simon): Check, if we need to implement here also a Lagrangian
# loss.

critic_loss = sac_critic_loss + cql_loss
# TODO (simon): Add here also the critic loss for the twin-Q

total_loss = actor_loss + critic_loss + alpha_loss
# TODO (simon): Add Twin Q losses

# Add the twin critic loss to the total loss, if needed.
if config.twin_q:
# Reweight the critic loss terms in the total loss.
total_loss += 0.5 * critic_twin_loss - 0.5 * critic_loss
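Written out, the in-place update above is just a reweighting: starting from `actor_loss + critic_loss + alpha_loss`, adding `0.5 * critic_twin_loss - 0.5 * critic_loss` leaves the two critic heads averaged rather than summed. A tiny sketch of the algebra with placeholder scalars:

```python
import torch

actor_loss, alpha_loss = torch.tensor(0.3), torch.tensor(0.1)
critic_loss, critic_twin_loss = torch.tensor(2.0), torch.tensor(4.0)

total_loss = actor_loss + critic_loss + alpha_loss
total_loss += 0.5 * critic_twin_loss - 0.5 * critic_loss

# Equivalent closed form: each critic head enters the total with weight 0.5.
assert torch.isclose(
    total_loss, actor_loss + alpha_loss + 0.5 * (critic_loss + critic_twin_loss)
)
```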

# Log important loss stats (reduce=mean (default), but with window=1
# in order to keep them history free).
@@ -273,7 +362,14 @@
)
# TODO (simon): Add loss keys for the Lagrangian, if needed.
# TODO (simon): Only then add the Lagrange parameter optimization here.
# TODO (simon): Add keys for twin Q
if config.twin_q:
self.metrics.log_dict(
{
QF_TWIN_LOSS_KEY: critic_twin_loss,
},
key=module_id,
window=1, # <- single items (should not be mean/ema-reduced over time).
)

# Return the total loss.
return total_loss
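For context, the new code path is driven entirely by `config.twin_q`. A hypothetical usage sketch, assuming `twin_q` is forwarded through `CQLConfig.training()` to the underlying `SACConfig` as in other SAC-derived configs; the environment and the `input_` path are placeholders:

```python
from ray.rllib.algorithms.cql import CQLConfig

# Enable the twin-Q heads so the loss terms added in this PR
# (QF_TWIN_PREDS / QF_TWIN_LOSS_KEY) are actually exercised.
config = (
    CQLConfig()
    .environment("Pendulum-v1")
    .offline_data(input_="/path/to/offline/data")
    .training(twin_q=True)
)
algo = config.build()
print(algo.train())
```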
1 change: 1 addition & 0 deletions rllib/algorithms/sac/torch/sac_torch_learner.py
@@ -232,6 +232,7 @@ def compute_loss_for_module(
total_loss = actor_loss + critic_loss + alpha_loss
# If twin Q networks should be used, add the critic loss of the twin Q network.
if config.twin_q:
# TODO (simon): Check if we then need to multiply the critic_loss by 0.5.
total_loss += critic_twin_loss

# Log the TD-error with reduce=None, such that - in case we have n parallel