
[RLlib] Use config (not self.config) in Learner.compute_loss_for_module to prepare these for multi-agent-capability. #45053

Merged
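For context on the change itself: in the new API stack, the base `Learner` hands `compute_loss_for_module()` a per-module `config` (resolved from `self.config.get_config_for_module(module_id)`), whereas `self.config` is always the algorithm-wide config. Reading settings from the passed-in `config` is what makes per-module overrides possible in multi-agent setups. Below is a minimal sketch of a custom Learner doing this; the class name and the placeholder loss are illustrative only, not part of this PR:

```python
import torch

from ray.rllib.core.learner.torch.torch_learner import TorchLearner


class SketchTorchLearner(TorchLearner):
    def compute_loss_for_module(self, *, module_id, config, batch, fwd_out):
        # `config` is the config for this particular module_id, so any
        # per-module overrides in a multi-agent setup are honored here.
        # `self.config` would always be the algorithm-wide config instead.
        discount = config.gamma
        # Placeholder loss, just to keep the sketch self-contained.
        return torch.zeros((), requires_grad=True) + discount
```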
2 changes: 1 addition & 1 deletion rllib/algorithms/appo/tf/appo_tf_learner.py
@@ -72,7 +72,7 @@ def compute_loss_for_module(
trajectory_len=rollout_frag_or_episode_len,
recurrent_seq_len=recurrent_seq_len,
)
if self.config.enable_env_runner_and_connector_v2:
if config.enable_env_runner_and_connector_v2:
Collaborator

I wonder if the new EnvRunners work with APPO/IMPALA. In my test case they do not in the multi-agent case, where a list of episodes is run through `compressed_if_needed`.

Contributor Author

IMPALA and APPO are WIP on the new EnvRunners and not officially supported there yet.

https://docs.ray.io/en/master/rllib/rllib-new-api-stack.html

bootstrap_values = batch[Columns.VALUES_BOOTSTRAPPED]
else:
bootstrap_values_time_major = make_time_major(
2 changes: 1 addition & 1 deletion rllib/algorithms/impala/tf/impala_tf_learner.py
@@ -60,7 +60,7 @@ def compute_loss_for_module(
trajectory_len=rollout_frag_or_episode_len,
recurrent_seq_len=recurrent_seq_len,
)
if self.config.enable_env_runner_and_connector_v2:
if config.enable_env_runner_and_connector_v2:
bootstrap_values = batch[Columns.VALUES_BOOTSTRAPPED]
else:
bootstrap_values_time_major = make_time_major(
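The SAC learner changes below add a per-module loop that resolves a module-specific config via `get_config_for_module()`. For context, a sketch of how such per-module overrides are declared on an `AlgorithmConfig`; the module IDs and override values here are made up:

```python
from ray.rllib.algorithms.sac import SACConfig

config = (
    SACConfig()
    .training(gamma=0.99, twin_q=True)
    .multi_agent(
        policies={"agent_0", "agent_1"},
        policy_mapping_fn=lambda agent_id, episode, **kw: agent_id,
        # Per-module overrides of algorithm-level settings.
        algorithm_config_overrides_per_module={
            "agent_0": SACConfig.overrides(gamma=0.95),
        },
    )
)

# What the per-module `config` inside a Learner will reflect:
print(config.get_config_for_module("agent_0").gamma)  # 0.95 (override applied)
print(config.get_config_for_module("agent_1").gamma)  # 0.99 (algorithm default)
```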
67 changes: 35 additions & 32 deletions rllib/algorithms/sac/torch/sac_torch_learner.py
@@ -25,6 +25,7 @@
)
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.metrics import ALL_MODULES
from ray.rllib.utils.nested_dict import NestedDict
from ray.rllib.utils.typing import ModuleID, ParamDict, TensorType

@@ -60,10 +61,10 @@ def configure_optimizers_for_module(
optimizer_name="qf",
optimizer=optim_critic,
params=params_critic,
lr_or_lr_schedule=self.config.lr,
lr_or_lr_schedule=config.lr,
)
# If necessary register also an optimizer for a twin Q network.
if self.config.twin_q:
if config.twin_q:
params_twin_critic = self.get_parameters(
module.qf_twin_encoder
) + self.get_parameters(module.qf_twin)
@@ -74,7 +75,7 @@
optimizer_name="qf_twin",
optimizer=optim_twin_critic,
params=params_twin_critic,
lr_or_lr_schedule=self.config.lr,
lr_or_lr_schedule=config.lr,
)

# Define the optimizer for the actor.
Expand All @@ -88,7 +89,7 @@ def configure_optimizers_for_module(
optimizer_name="policy",
optimizer=optim_actor,
params=params_actor,
lr_or_lr_schedule=self.config.lr,
lr_or_lr_schedule=config.lr,
)

# Define the optimizer for the temperature.
@@ -99,7 +100,7 @@
optimizer_name="alpha",
optimizer=optim_temperature,
params=[temperature],
lr_or_lr_schedule=self.config.lr,
lr_or_lr_schedule=config.lr,
)

@override(DQNRainbowTorchLearner)
@@ -112,7 +113,7 @@ def compute_loss_for_module(
fwd_out: Mapping[str, TensorType]
) -> TensorType:
# Only for debugging.
deterministic = self.config._deterministic_loss
deterministic = config._deterministic_loss

# Receive the current alpha hyperparameter.
alpha = torch.exp(self.curr_log_alpha[module_id])
@@ -154,7 +155,7 @@ def compute_loss_for_module(
# Get Q-values for the actually selected actions during rollout.
# In the critic loss we use these as predictions.
q_selected = fwd_out[QF_PREDS]
if self.config.twin_q:
if config.twin_q:
q_twin_selected = fwd_out[QF_TWIN_PREDS]

# Compute Q-values for the current policy in the current state with
@@ -168,7 +169,7 @@
q_curr = self.module[module_id]._qf_forward_train(q_batch_curr)[QF_PREDS]
# If a twin Q network should be used, calculate twin Q-values and use the
# minimum.
if self.config.twin_q:
if config.twin_q:
q_twin_curr = self.module[module_id]._qf_twin_forward_train(q_batch_curr)[
QF_PREDS
]
@@ -187,7 +188,7 @@
]
# If a twin Q network should be used, calculate twin Q-values and use the
# minimum.
if self.config.twin_q:
if config.twin_q:
q_target_twin_next = self.module[module_id]._qf_target_twin_forward_train(
q_batch_next
)[QF_PREDS]
@@ -203,15 +204,14 @@
# Detach this node from the computation graph as we do not want to
# backpropagate through the target network when optimizing the Q loss.
q_selected_target = (
batch[Columns.REWARDS]
+ (self.config.gamma ** batch["n_steps"]) * q_next_masked
batch[Columns.REWARDS] + (config.gamma ** batch["n_steps"]) * q_next_masked
).detach()

# Calculate the TD-error. Note, this is needed for the priority weights in
# the replay buffer.
td_error = torch.abs(q_selected - q_selected_target)
# If a twin Q network should be used, add the TD error of the twin Q network.
if self.config.twin_q:
if config.twin_q:
td_error += torch.abs(q_twin_selected - q_selected_target)
# Rescale the TD error.
td_error *= 0.5
@@ -229,7 +229,7 @@
)
)
# If a twin Q network should be used, add the critic loss of the twin Q network.
if self.config.twin_q:
if config.twin_q:
critic_twin_loss = torch.mean(
batch["weights"]
* torch.nn.HuberLoss(reduction="none", delta=1.0)(
@@ -254,7 +254,7 @@

total_loss = actor_loss + critic_loss + alpha_loss
# If twin Q networks should be used, add the critic loss of the twin Q network.
if self.config.twin_q:
if config.twin_q:
total_loss += critic_twin_loss

# Log the TD-error with reduce=None, such that - in case we have n parallel
@@ -288,7 +288,7 @@ def compute_loss_for_module(
)
# If twin Q networks should be used add a critic loss for the twin Q network.
# Note, we need this in the `self.compute_gradients()` to optimize.
if self.config.twin_q:
if config.twin_q:
self.metrics.log_dict(
{
QF_TWIN_LOSS_KEY: critic_twin_loss,
@@ -308,22 +308,25 @@ def compute_gradients(

grads = {}

# Calculate gradients for each loss by its optimizer.
# TODO (sven): Maybe we rename to `actor`, `critic`. We then also
# need to either add to or change in the `Learner` constants.
for component in ["qf", "policy", "alpha"] + (
["qf_twin"] if self.config.twin_q else []
):
self.metrics.peek(DEFAULT_MODULE_ID, component + "_loss").backward(
retain_graph=True
)
grads.update(
{
pid: p.grad
for pid, p in self.filter_param_dict_for_optimizer(
self._params, self.get_optimizer(optimizer_name=component)
).items()
}
)
for module_id in set(loss_per_module.keys()) - {ALL_MODULES}:
config = self.config.get_config_for_module(module_id)

# Calculate gradients for each loss by its optimizer.
# TODO (sven): Maybe we rename to `actor`, `critic`. We then also
# need to either add to or change in the `Learner` constants.
for component in ["qf", "policy", "alpha"] + (
["qf_twin"] if config.twin_q else []
):
self.metrics.peek(DEFAULT_MODULE_ID, component + "_loss").backward(
retain_graph=True
)
grads.update(
{
pid: p.grad
for pid, p in self.filter_param_dict_for_optimizer(
self._params, self.get_optimizer(module_id, component)
).items()
}
)

return grads
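A framework-free sketch of the gradient-collection pattern in `compute_gradients()` above: call `backward()` once per loss component with `retain_graph=True` (the SAC loss components share one computation graph built in `compute_loss_for_module()`), then gather the gradients of the parameters owned by that component's optimizer. All names below are illustrative:

```python
import torch

# Stand-ins for the policy / Q-function parameters.
policy_w = torch.nn.Parameter(torch.randn(3))
qf_w = torch.nn.Parameter(torch.randn(3))

optims = {
    "policy": torch.optim.Adam([policy_w]),
    "qf": torch.optim.Adam([qf_w]),
}

x = torch.randn(3)  # a fake batch
losses = {
    "policy": (x * policy_w).sum() ** 2,
    "qf": (x * qf_w).sum() ** 2,
}

grads = {}
for component, loss in losses.items():
    # retain_graph=True keeps shared graph parts alive for the next component.
    loss.backward(retain_graph=True)
    for p in optims[component].param_groups[0]["params"]:
        grads[(component, id(p))] = p.grad.clone()

# `grads` now maps (component, param-id) to that component's gradient tensor.
```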