diff --git a/rllib/algorithms/algorithm_config.py b/rllib/algorithms/algorithm_config.py index 3ee0490b771f..53445b30bb77 100644 --- a/rllib/algorithms/algorithm_config.py +++ b/rllib/algorithms/algorithm_config.py @@ -314,6 +314,8 @@ def __init__(self, algo_class: Optional[type] = None): self.torch_compile_worker_dynamo_mode = None # Default kwargs for `torch.nn.parallel.DistributedDataParallel`. self.torch_ddp_kwargs = {} + # Default setting for skipping `nan` gradient updates. + self.torch_skip_nan_gradients = False # `self.api_stack()` self.enable_rl_module_and_learner = False @@ -1381,6 +1383,7 @@ def framework( torch_compile_worker_dynamo_backend: Optional[str] = NotProvided, torch_compile_worker_dynamo_mode: Optional[str] = NotProvided, torch_ddp_kwargs: Optional[Dict[str, Any]] = NotProvided, + torch_skip_nan_gradients: Optional[bool] = NotProvided, ) -> "AlgorithmConfig": """Sets the config's DL framework settings. @@ -1426,6 +1429,19 @@ def framework( that are not used in the backward pass. This can give hints for errors in custom models where some parameters do not get touched in the backward pass although they should. + torch_skip_nan_gradients: If updates with `nan` gradients should be entirely + skipped. This skips updates in the optimizer entirely if they contain + any `nan` gradient. This can help to avoid biasing moving-average based + optimizers - like Adam. This can help in training phases where policy + updates can be highly unstable such as during the early stages of + training or with highly exploratory policies. In such phases many + gradients might turn `nan` and setting them to zero could corrupt the + optimizer's internal state. The default is `False` and turns `nan` + gradients to zero. If many `nan` gradients are encountered consider (a) + monitoring gradients by setting `log_gradients` in `AlgorithmConfig` to + `True`, (b) use proper weight initialization (e.g. Xavier, Kaiming) via + the `model_config_dict` in `AlgorithmConfig.rl_module` and/or (c) + gradient clipping via `grad_clip` in `AlgorithmConfig.training`. Returns: This updated AlgorithmConfig object. @@ -1469,6 +1485,8 @@ def framework( self.torch_compile_worker_dynamo_mode = torch_compile_worker_dynamo_mode if torch_ddp_kwargs is not NotProvided: self.torch_ddp_kwargs = torch_ddp_kwargs + if torch_skip_nan_gradients is not NotProvided: + self.torch_skip_nan_gradients = torch_skip_nan_gradients return self diff --git a/rllib/core/learner/torch/torch_learner.py b/rllib/core/learner/torch/torch_learner.py index 25f54c5a4c70..b336cf02338b 100644 --- a/rllib/core/learner/torch/torch_learner.py +++ b/rllib/core/learner/torch/torch_learner.py @@ -176,7 +176,27 @@ def compute_gradients( def apply_gradients(self, gradients_dict: ParamDict) -> None: # Set the gradient of the parameters. for pid, grad in gradients_dict.items(): - self._params[pid].grad = grad + # If updates should not be skipped turn `nan` and `inf` gradients to zero. + if ( + not torch.isfinite(grad).all() + and not self.config.torch_skip_nan_gradients + ): + # Warn the user about `nan` gradients. + logger.warning(f"Gradients {pid} contain `nan/inf` values.") + # If updates should be skipped, do not step the optimizer and return. + if not self.config.torch_skip_nan_gradients: + logger.warning( + "Setting `nan/inf` gradients to zero. If updates with " + "`nan/inf` gradients should not be set to zero and instead " + "the update be skipped entirely set `torch_skip_nan_gradients` " + "to `True`." + ) + # If necessary turn `nan` gradients to zero. Note this can corrupt the + # internal state of the optimizer, if many `nan` gradients occur. + self._params[pid].grad = torch.nan_to_num(grad) + # Otherwise, use the gradient as is. + else: + self._params[pid].grad = grad # For each optimizer call its step function. for module_id, optimizer_names in self._module_optimizers.items(): @@ -200,6 +220,19 @@ def apply_gradients(self, gradients_dict: ParamDict) -> None: for param in group["params"] ): optim.step() + # If gradients are not all finite warn the user that the update will be + # skipped. + elif not all( + torch.isfinite(param.grad).all() + for group in optim.param_groups + for param in group["params"] + ): + logger.warning( + "Skipping this update. If updates with `nan/inf` gradients " + "should not be skipped entirely and instead `nan/inf` " + "gradients set to `zero` set `torch_skip_nan_gradients` to " + "`False`." + ) @override(Learner) def _get_optimizer_state(self) -> StateDict: