[RLlib] Add gradient checks to avoid nan gradients in TorchLearner. #47452

Merged
18 changes: 18 additions & 0 deletions rllib/algorithms/algorithm_config.py
@@ -314,6 +314,8 @@ def __init__(self, algo_class: Optional[type] = None):
self.torch_compile_worker_dynamo_mode = None
# Default kwargs for `torch.nn.parallel.DistributedDataParallel`.
self.torch_ddp_kwargs = {}
# Default setting for skipping `nan` gradient updates.
self.torch_skip_nan_gradients = False

# `self.api_stack()`
self.enable_rl_module_and_learner = False
@@ -1381,6 +1383,7 @@ def framework(
torch_compile_worker_dynamo_backend: Optional[str] = NotProvided,
torch_compile_worker_dynamo_mode: Optional[str] = NotProvided,
torch_ddp_kwargs: Optional[Dict[str, Any]] = NotProvided,
torch_skip_nan_gradients: Optional[bool] = NotProvided,
) -> "AlgorithmConfig":
"""Sets the config's DL framework settings.

@@ -1426,6 +1429,19 @@ def framework(
that are not used in the backward pass. This can give hints for errors
in custom models where some parameters do not get touched in the
backward pass although they should.
torch_skip_nan_gradients: Whether to entirely skip an optimizer update
if any of its gradients contain `nan` values. Skipping such updates
avoids biasing moving-average based optimizers, such as Adam, and can
help in training phases where policy updates are highly unstable, for
example during the early stages of training or with highly exploratory
policies. In such phases many gradients might turn `nan`, and setting
them to zero could corrupt the optimizer's internal state. The default
is `False`, which turns `nan` gradients to zero instead. If many `nan`
gradients are encountered, consider (a) monitoring gradients by setting
`log_gradients` in `AlgorithmConfig` to `True`, (b) using proper weight
initialization (e.g. Xavier, Kaiming) via the `model_config_dict` in
`AlgorithmConfig.rl_module`, and/or (c) applying gradient clipping via
`grad_clip` in `AlgorithmConfig.training`.

Returns:
This updated AlgorithmConfig object.
@@ -1469,6 +1485,8 @@ def framework(
self.torch_compile_worker_dynamo_mode = torch_compile_worker_dynamo_mode
if torch_ddp_kwargs is not NotProvided:
self.torch_ddp_kwargs = torch_ddp_kwargs
if torch_skip_nan_gradients is not NotProvided:
self.torch_skip_nan_gradients = torch_skip_nan_gradients

return self

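For context, this is how the new flag would be switched on from user code; a minimal sketch assuming RLlib's new API stack with PPO (the environment choice and experiment setup here are illustrative, not part of this PR):

from ray.rllib.algorithms.ppo import PPOConfig

config = (
    PPOConfig()
    .environment("CartPole-v1")
    # Skip optimizer updates entirely whenever a gradient contains
    # `nan`/`inf` values, instead of the default zeroing behavior.
    .framework("torch", torch_skip_nan_gradients=True)
)
# Per the docstring above, `log_gradients` can additionally be set to `True`
# to monitor how often non-finite gradients occur.
config.log_gradients = True

algo = config.build()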
35 changes: 34 additions & 1 deletion rllib/core/learner/torch/torch_learner.py
@@ -176,7 +176,27 @@ def compute_gradients(
def apply_gradients(self, gradients_dict: ParamDict) -> None:
# Set the gradient of the parameters.
        for pid, grad in gradients_dict.items():
            # If updates should not be skipped, turn `nan` and `inf` gradients
            # to zero.
            if (
                not torch.isfinite(grad).all()
                and not self.config.torch_skip_nan_gradients
            ):
                # Warn the user about `nan/inf` gradients.
                logger.warning(f"Gradients {pid} contain `nan/inf` values.")
                logger.warning(
                    "Setting `nan/inf` gradients to zero. If updates with "
                    "`nan/inf` gradients should instead be skipped entirely, "
                    "set `torch_skip_nan_gradients` to `True`."
                )
                # Turn the `nan/inf` gradients to zero. Note, this can corrupt
                # the internal state of the optimizer, if many `nan/inf`
                # gradients occur.
                self._params[pid].grad = torch.nan_to_num(grad)
            # Otherwise, use the gradient as-is.
            else:
                self._params[pid].grad = grad

        # For each optimizer, call its step function.
for module_id, optimizer_names in self._module_optimizers.items():
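To illustrate the zeroing branch above in isolation (a standalone snippet, not part of this diff): note that `torch.nan_to_num` by default maps `nan` to zero but maps `inf`/`-inf` to the largest/smallest finite float values rather than to zero:

import torch

grad = torch.tensor([0.5, float("nan"), float("inf"), -float("inf")])
print(torch.isfinite(grad).all())  # tensor(False) -> triggers the branch above
print(torch.nan_to_num(grad))
# tensor([ 5.0000e-01,  0.0000e+00,  3.4028e+38, -3.4028e+38])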
@@ -200,6 +220,19 @@ def apply_gradients(self, gradients_dict: ParamDict) -> None:
for param in group["params"]
):
optim.step()
                # If the gradients are not all finite, warn the user that the
                # update is skipped.
                elif not all(
                    torch.isfinite(param.grad).all()
                    for group in optim.param_groups
                    for param in group["params"]
                ):
                    logger.warning(
                        "Skipping this update. If updates with `nan/inf` "
                        "gradients should not be skipped entirely and their "
                        "`nan/inf` gradients instead be set to zero, set "
                        "`torch_skip_nan_gradients` to `False`."
                    )

@override(Learner)
def _get_optimizer_state(self) -> StateDict:
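The skip path can be reproduced in isolation as follows; a sketch with a toy parameter and optimizer (not RLlib code, and the `p.grad is None` guard is an addition for standalone safety). Skipping the step leaves Adam's moment estimates untouched by the corrupted batch:

import torch

param = torch.nn.Parameter(torch.ones(3))
optim = torch.optim.Adam([param], lr=0.01)

# Simulate a corrupted backward pass.
param.grad = torch.tensor([0.1, float("nan"), 0.2])

# Mirror of the all-finite check that guards `optim.step()` above.
if all(
    p.grad is None or torch.isfinite(p.grad).all()
    for group in optim.param_groups
    for p in group["params"]
):
    optim.step()
else:
    print("Skipping update: non-finite gradients detected.")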