[RLlib; Offline RL] Implement twin-Q net option for CQL. (ray-project…
simonsays1980 committed Aug 14, 2024
1 parent c77eab7 commit 38d8178
Showing 2 changed files with 114 additions and 17 deletions.
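Note (illustrative, not part of the commit): the twin-Q option implemented here would be switched on from the algorithm config. A minimal sketch, assuming the `twin_q` flag that `CQLConfig` inherits from `SACConfig.training()`; the environment and hyperparameter values are made up:

from ray.rllib.algorithms.cql import CQLConfig

# Hypothetical usage sketch: enable the twin Q-network for CQL.
# A real CQL run would also need an offline input, e.g. via `.offline_data(input_=...)`.
config = (
    CQLConfig()
    .environment("Pendulum-v1")
    .training(
        twin_q=True,        # train a second ("twin") Q-network
        min_q_weight=5.0,   # weight of the conservative CQL term
        temperature=1.0,    # temperature of the logsumexp term
    )
)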
130 changes: 113 additions & 17 deletions rllib/algorithms/cql/torch/cql_torch_learner.py
@@ -9,6 +9,8 @@
QF_MAX_KEY,
QF_MIN_KEY,
QF_PREDS,
QF_TWIN_LOSS_KEY,
QF_TWIN_PREDS,
TD_ERROR_MEAN_KEY,
)
from ray.rllib.algorithms.cql.cql import CQLConfig
@@ -90,8 +92,9 @@ def compute_loss_for_module(
# Use the actions sampled from the current policy.
Columns.ACTIONS: actions_curr,
}
# Note, if `twin_q` is `True`, `compute_q_values` computes the minimum
# of the `qf` and `qf_twin` and returns this minimum.
q_curr = self.module[module_id].compute_q_values(batch_curr)
actor_loss = torch.mean(alpha.detach() * logps_curr - q_curr)
else:
# Use log-probabilities of the current action distribution to clone
@@ -117,16 +120,19 @@ def compute_loss_for_module(
# Get the Q-values for the actually selected actions in the offline data.
# In the critic loss we use these as predictions.
q_selected = fwd_out[QF_PREDS]
if config.twin_q:
q_twin_selected = fwd_out[QF_TWIN_PREDS]

# Compute Q-values from the target Q network for the next state with the
# sampled actions for the next state.
q_batch_next = {
Columns.OBS: batch[Columns.NEXT_OBS],
Columns.ACTIONS: actions_next,
}
# Note, if `twin_q` is `True`, `SACTorchRLModule.forward_target` calculates
# the Q-values for both `qf_target` and `qf_twin_target`, and returns the
# minimum of the two.
q_target_next = self.module[module_id].forward_target(q_batch_next)

# Now mask all Q-values with terminating next states in the targets.
q_next_masked = (1.0 - batch[Columns.TERMINATEDS].float()) * q_target_next
@@ -143,16 +149,25 @@ def compute_loss_for_module(

# Calculate the TD error.
td_error = torch.abs(q_selected - q_selected_target)
# Calculate a TD-error for the twin Q-values, if needed.
if config.twin_q:
td_error += torch.abs(q_twin_selected - q_selected_target)
# Rescale the TD error (average over both critics).
td_error *= 0.5

# MSBE loss for the critic(s) (i.e. Q, see eqs. (7-8) Haarnoja et al. (2018)).
# Note, this needs a sample from the current policy given the next state.
# Note further, we could also use here the Huber loss instead of the MSE.
# TODO (simon): Add the huber loss as an alternative (SAC uses it).
sac_critic_loss = torch.nn.MSELoss(reduction="mean")(
q_selected,
q_selected_target,
)
if config.twin_q:
sac_critic_twin_loss = torch.nn.MSELoss(reduction="mean")(
q_twin_selected,
q_selected_target,
)

# Now calculate the CQL loss (we use the entropy version of the CQL algorithm).
# Note, the entropy version performs best in shown experiments.
@@ -185,34 +200,88 @@ def compute_loss_for_module(
Columns.OBS: obs_curr_repeat,
Columns.ACTIONS: actions_rand_repeat,
}
# Note, we need the Q-values from the base Q-value function here, not the
# minimum over the base and an eventual twin Q-value function.
q_rand_repeat = (
self.module[module_id]
._qf_forward_train_helper(
batch_rand_repeat,
self.module[module_id].qf_encoder,
self.module[module_id].qf,
)
.view(batch_size, config.num_actions, 1)
)
# Calculate twin Q-values for the random actions, if needed.
if config.twin_q:
q_twin_rand_repeat = (
self.module[module_id]
._qf_forward_train_helper(
batch_rand_repeat,
self.module[module_id].qf_twin_encoder,
self.module[module_id].qf_twin,
)
.view(batch_size, config.num_actions, 1)
)
del batch_rand_repeat
batch_curr_repeat = {
Columns.OBS: obs_curr_repeat,
Columns.ACTIONS: actions_curr_repeat,
}
q_curr_repeat = (
self.module[module_id]
._qf_forward_train_helper(
batch_curr_repeat,
self.module[module_id].qf_encoder,
self.module[module_id].qf,
)
.view(batch_size, config.num_actions, 1)
)
# Calculate twin Q-values for the repeated actions from the current policy,
# if needed.
if config.twin_q:
q_twin_curr_repeat = (
self.module[module_id]
._qf_forward_train_helper(
batch_curr_repeat,
self.module[module_id].qf_twin_encoder,
self.module[module_id].qf_twin,
)
.view(batch_size, config.num_actions, 1)
)
del batch_curr_repeat
batch_next_repeat = {
# Note, we use the current observations here b/c we want to keep the
# state fixed while sampling the actions.
Columns.OBS: obs_curr_repeat,
Columns.ACTIONS: actions_next_repeat,
}
q_next_repeat = (
self.module[module_id]
._qf_forward_train_helper(
batch_next_repeat,
self.module[module_id].qf_encoder,
self.module[module_id].qf,
)
.view(batch_size, config.num_actions, 1)
)
# Calculate also the twin Q-values for the current policy and next actions,
# if needed.
if config.twin_q:
q_twin_next_repeat = (
self.module[module_id]
._qf_forward_train_helper(
batch_next_repeat,
self.module[module_id].qf_twin_encoder,
self.module[module_id].qf_twin,
)
.view(batch_size, config.num_actions, 1)
)
del batch_next_repeat

# Compute the log-probabilities for the random actions.
# TODO (simon): This is the density for a discrete uniform; however, the
# actions come from a continuous one, so this density should actually use
# (1/(high-low)) instead of (1/2).
random_density = torch.log(
torch.pow(
torch.tensor(
@@ -231,23 +300,43 @@ def compute_loss_for_module(
],
dim=1,
)

cql_loss = (
torch.logsumexp(q_repeat / config.temperature, dim=1).mean()
* config.min_q_weight
* config.temperature
)
cql_loss -= q_selected.mean() * config.min_q_weight
# Add the CQL loss term to the SAC loss term.
critic_loss = sac_critic_loss + cql_loss

# If a twin Q-value function is used, calculate its CQL loss as well.
if config.twin_q:
q_twin_repeat = torch.cat(
[
q_twin_rand_repeat - random_density,
q_twin_next_repeat - logps_next_repeat.detach(),
q_twin_curr_repeat - logps_curr_repeat.detach(),
],
dim=1,
)
cql_twin_loss = (
torch.logsumexp(q_twin_repeat / config.temperature, dim=1).mean()
* config.min_q_weight
* config.temperature
)
cql_twin_loss -= q_twin_selected.mean() * config.min_q_weight
# Add the twin CQL loss term to the twin SAC loss term.
critic_twin_loss = sac_critic_twin_loss + cql_twin_loss

# TODO (simon): Check whether we also need to implement a Lagrangian loss
# here.

total_loss = actor_loss + critic_loss + alpha_loss

# Add the twin critic loss to the total loss, if needed.
if config.twin_q:
# Reweigh the critic loss terms in the total loss.
total_loss += 0.5 * critic_twin_loss - 0.5 * critic_loss

# Log important loss stats (reduce=mean (default), but with window=1
# in order to keep them history free).
@@ -273,7 +362,14 @@ def compute_loss_for_module(
)
# TODO (simon): Add loss keys for the Lagrangian, if needed.
# TODO (simon): Only then add the Lagrange parameter optimization here.
if config.twin_q:
self.metrics.log_dict(
{
QF_TWIN_LOSS_KEY: critic_twin_loss,
},
key=module_id,
window=1, # <- single items (should not be mean/ema-reduced over time).
)

# Return the total loss.
return total_loss
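Aside (illustration only, not from the commit): the per-critic loss assembled above is the SAC MSBE term plus the conservative CQL regularizer, computed once for the base critic and once for the twin; with `twin_q` both enter the total loss at weight 0.5. A minimal runnable sketch with made-up tensor shapes and hyperparameters:

import torch

# Illustrative shapes and hyperparameters only.
batch_size, num_sampled = 4, 10
temperature, min_q_weight = 1.0, 5.0


def critic_loss(q_selected, q_repeat, q_target):
    # SAC MSBE term plus the conservative CQL term (logsumexp over sampled
    # actions minus the mean Q-value of the dataset actions).
    sac_term = torch.nn.functional.mse_loss(q_selected, q_target)
    cql_term = (
        torch.logsumexp(q_repeat / temperature, dim=1).mean()
        * min_q_weight
        * temperature
    ) - q_selected.mean() * min_q_weight
    return sac_term + cql_term


q_target = torch.randn(batch_size, 1)
loss_main = critic_loss(
    torch.randn(batch_size, 1), torch.randn(batch_size, num_sampled, 1), q_target
)
loss_twin = critic_loss(
    torch.randn(batch_size, 1), torch.randn(batch_size, num_sampled, 1), q_target
)
# Matches the reweighting above: critic_loss enters the total loss once at full
# weight and is then reduced by 0.5 while the twin loss is added at 0.5.
total_critic = 0.5 * loss_main + 0.5 * loss_twin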
1 change: 1 addition & 0 deletions rllib/algorithms/sac/torch/sac_torch_learner.py
@@ -232,6 +232,7 @@ def compute_loss_for_module(
total_loss = actor_loss + critic_loss + alpha_loss
# If twin Q networks should be used, add the critic loss of the twin Q network.
if config.twin_q:
# TODO (simon): Check whether we then need to multiply the critic_loss by 0.5.
total_loss += critic_twin_loss

# Log the TD-error with reduce=None, such that - in case we have n parallel
