[RLlib] Add gradient checks to avoid nan gradients in `TorchLearner`. (#47452)

simonsays1980 authored Sep 3, 2024
1 parent eda6d09 commit e1ed103
Showing 2 changed files with 52 additions and 1 deletion.
18 changes: 18 additions & 0 deletions rllib/algorithms/algorithm_config.py
@@ -314,6 +314,8 @@ def __init__(self, algo_class: Optional[type] = None):
self.torch_compile_worker_dynamo_mode = None
# Default kwargs for `torch.nn.parallel.DistributedDataParallel`.
self.torch_ddp_kwargs = {}
# Default setting for skipping `nan` gradient updates.
self.torch_skip_nan_gradients = False

# `self.api_stack()`
self.enable_rl_module_and_learner = False
@@ -1381,6 +1383,7 @@ def framework(
torch_compile_worker_dynamo_backend: Optional[str] = NotProvided,
torch_compile_worker_dynamo_mode: Optional[str] = NotProvided,
torch_ddp_kwargs: Optional[Dict[str, Any]] = NotProvided,
torch_skip_nan_gradients: Optional[bool] = NotProvided,
) -> "AlgorithmConfig":
"""Sets the config's DL framework settings.
@@ -1426,6 +1429,19 @@
that are not used in the backward pass. This can give hints for errors
in custom models where some parameters do not get touched in the
backward pass although they should.
torch_skip_nan_gradients: Whether to skip optimizer updates entirely if
they contain any `nan` gradients. Skipping such updates helps to avoid
biasing moving-average-based optimizers, such as Adam, and can be useful
in training phases where policy updates are highly unstable, for example
during the early stages of training or with highly exploratory policies.
In such phases many gradients might turn `nan`, and setting them to zero
could corrupt the optimizer's internal state. The default is `False`,
which turns `nan` gradients to zero instead. If many `nan` gradients are
encountered, consider (a) monitoring gradients by setting `log_gradients`
in `AlgorithmConfig` to `True`, (b) using proper weight initialization
(e.g. Xavier, Kaiming) via the `model_config_dict` in
`AlgorithmConfig.rl_module`, and/or (c) using gradient clipping via
`grad_clip` in `AlgorithmConfig.training`.
Returns:
This updated AlgorithmConfig object.
@@ -1469,6 +1485,8 @@ def framework(
self.torch_compile_worker_dynamo_mode = torch_compile_worker_dynamo_mode
if torch_ddp_kwargs is not NotProvided:
self.torch_ddp_kwargs = torch_ddp_kwargs
if torch_skip_nan_gradients is not NotProvided:
self.torch_skip_nan_gradients = torch_skip_nan_gradients

return self

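For context, here is a minimal sketch of how the new flag can be enabled from user code once this change is in place. It assumes PPO on the new API stack; the environment, the `grad_clip` value, and all other choices below are illustrative assumptions, not part of this commit.

from ray.rllib.algorithms.ppo import PPOConfig

# Illustrative configuration only; hyperparameter choices are examples.
config = (
    PPOConfig()
    # `torch_skip_nan_gradients` only takes effect in the new API stack's
    # `TorchLearner`.
    .api_stack(
        enable_rl_module_and_learner=True,
        enable_env_runner_and_connector_v2=True,
    )
    .environment("CartPole-v1")
    .framework(
        "torch",
        # Skip optimizer updates entirely whenever a gradient contains
        # `nan/inf` values; the default (`False`) zeroes such gradients.
        torch_skip_nan_gradients=True,
    )
    # Gradient clipping is one of the remedies suggested in the docstring.
    .training(grad_clip=40.0)
)
algo = config.build()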
35 changes: 34 additions & 1 deletion rllib/core/learner/torch/torch_learner.py
@@ -176,7 +176,27 @@ def compute_gradients(
def apply_gradients(self, gradients_dict: ParamDict) -> None:
# Set the gradient of the parameters.
for pid, grad in gradients_dict.items():
self._params[pid].grad = grad
# If updates should not be skipped, turn `nan` and `inf` gradients to zero.
if (
not torch.isfinite(grad).all()
and not self.config.torch_skip_nan_gradients
):
# Warn the user about `nan` gradients.
logger.warning(f"Gradients {pid} contain `nan/inf` values.")
# Warn the user that the `nan/inf` gradients are being set to zero (the
# default behavior when `torch_skip_nan_gradients=False`).
if not self.config.torch_skip_nan_gradients:
logger.warning(
"Setting `nan/inf` gradients to zero. If updates containing "
"`nan/inf` gradients should instead be skipped entirely, set "
"`torch_skip_nan_gradients` to `True`."
)
# Turn the `nan/inf` gradients to zero. Note that this can corrupt the
# optimizer's internal state if many such gradients occur.
self._params[pid].grad = torch.nan_to_num(grad)
# Otherwise, use the gradient as is.
else:
self._params[pid].grad = grad

# For each optimizer, call its step function.
for module_id, optimizer_names in self._module_optimizers.items():
@@ -200,6 +220,19 @@ def apply_gradients(self, gradients_dict: ParamDict) -> None:
for param in group["params"]
):
optim.step()
# If gradients are not all finite, warn the user that the update is
# being skipped.
elif not all(
torch.isfinite(param.grad).all()
for group in optim.param_groups
for param in group["params"]
):
logger.warning(
"Skipping this update. If updates containing `nan/inf` gradients "
"should not be skipped entirely and the `nan/inf` gradients should "
"instead be set to zero, set `torch_skip_nan_gradients` to `False`."
)

@override(Learner)
def _get_optimizer_state(self) -> StateDict:
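To make the control flow above concrete outside of RLlib, here is a self-contained sketch of the same zero-out-versus-skip pattern for a plain PyTorch optimizer. The function name `apply_gradients_safely`, the `skip_nan_gradients` flag, and the toy model are assumptions for this example only, not RLlib API.

import logging

import torch
from torch import nn

logger = logging.getLogger(__name__)


def apply_gradients_safely(
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    skip_nan_gradients: bool = False,
) -> None:
    # Mirrors the pattern above: either sanitize non-finite gradients and
    # step, or skip the optimizer step entirely.
    grads = [p.grad for p in model.parameters() if p.grad is not None]
    all_finite = all(torch.isfinite(g).all() for g in grads)

    if all_finite:
        optimizer.step()
    elif not skip_nan_gradients:
        # Default behavior: sanitize non-finite entries (`nan` -> 0.0) via
        # `torch.nan_to_num` and still step. With many such gradients this
        # can bias moving-average optimizers like Adam.
        logger.warning("Gradients contain `nan/inf` values; sanitizing them.")
        for p in model.parameters():
            if p.grad is not None:
                p.grad = torch.nan_to_num(p.grad)
        optimizer.step()
    else:
        # `skip_nan_gradients=True`: leave the optimizer state untouched.
        logger.warning("Skipping this update; gradients contain `nan/inf`.")


# Toy usage (illustrative only).
model = nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss = model(torch.randn(8, 4)).pow(2).mean()
optimizer.zero_grad()
loss.backward()
apply_gradients_safely(model, optimizer, skip_nan_gradients=True)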
