diff --git a/rllib/BUILD b/rllib/BUILD
index 5cba1f4607f9..a3591e59ea5c 100644
--- a/rllib/BUILD
+++ b/rllib/BUILD
@@ -271,7 +271,7 @@ py_test(
     name = "learning_tests_multi_agent_stateless_cartpole_appo_multi_cpu",
     main = "tuned_examples/appo/multi_agent_stateless_cartpole_appo.py",
     tags = ["team:rllib", "exclusive", "learning_tests", "torch_only", "learning_tests_discrete", "learning_tests_pytorch_use_all_core"],
-    size = "large",
+    size = "enormous",
     srcs = ["tuned_examples/appo/multi_agent_stateless_cartpole_appo.py"],
     args = ["--as-test", "--enable-new-api-stack", "--num-gpus=2"]
 )
@@ -279,7 +279,7 @@ py_test(
     name = "learning_tests_multi_agent_stateless_cartpole_appo_multi_gpu",
     main = "tuned_examples/appo/multi_agent_stateless_cartpole_appo.py",
     tags = ["team:rllib", "exclusive", "learning_tests", "torch_only", "learning_tests_discrete", "learning_tests_pytorch_use_all_core", "multi_gpu"],
-    size = "large",
+    size = "enormous",
     srcs = ["tuned_examples/appo/multi_agent_stateless_cartpole_appo.py"],
     args = ["--as-test", "--enable-new-api-stack", "--num-gpus=2"]
 )
diff --git a/rllib/algorithms/algorithm_config.py b/rllib/algorithms/algorithm_config.py
index 66820d3cb947..b658eb1c32de 100644
--- a/rllib/algorithms/algorithm_config.py
+++ b/rllib/algorithms/algorithm_config.py
@@ -492,6 +492,7 @@ def __init__(self, algo_class: Optional[type] = None):
         self.min_time_s_per_iteration = None
         self.min_train_timesteps_per_iteration = 0
         self.min_sample_timesteps_per_iteration = 0
+        self.log_gradients = True
 
         # `self.checkpointing()`
         self.export_native_model_files = False
@@ -2883,6 +2884,7 @@ def reporting(
         min_time_s_per_iteration: Optional[float] = NotProvided,
         min_train_timesteps_per_iteration: Optional[int] = NotProvided,
         min_sample_timesteps_per_iteration: Optional[int] = NotProvided,
+        log_gradients: Optional[bool] = NotProvided,
     ) -> "AlgorithmConfig":
         """Sets the config's reporting settings.
 
@@ -2923,6 +2925,9 @@ def reporting(
                 sampling timestep count has not been reached, will perform n more
                 `training_step()` calls until the minimum timesteps have been
                 executed. Set to 0 or None for no minimum timesteps.
+            log_gradients: Whether to log gradients to the results. If `True`, the
+                global norm of the gradient dict of each optimizer is logged to the
+                results. The default is `True`.
 
         Returns:
             This updated AlgorithmConfig object.
@@ -2941,6 +2946,8 @@ def reporting(
             self.min_train_timesteps_per_iteration = min_train_timesteps_per_iteration
         if min_sample_timesteps_per_iteration is not NotProvided:
             self.min_sample_timesteps_per_iteration = min_sample_timesteps_per_iteration
+        if log_gradients is not NotProvided:
+            self.log_gradients = log_gradients
 
         return self
 
diff --git a/rllib/core/learner/learner.py b/rllib/core/learner/learner.py
index c49d80ab2c29..6edd31df0275 100644
--- a/rllib/core/learner/learner.py
+++ b/rllib/core/learner/learner.py
@@ -541,7 +541,7 @@ def postprocess_gradients_for_module(
         """
         postprocessed_grads = {}
 
-        if config.grad_clip is None:
+        if config.grad_clip is None and not config.log_gradients:
             postprocessed_grads.update(module_gradients_dict)
             return postprocessed_grads
 
@@ -550,19 +550,40 @@ def postprocess_gradients_for_module(
             param_dict=module_gradients_dict,
             optimizer=optimizer,
         )
-        # Perform gradient clipping, if configured.
-        global_norm = self._get_clip_function()(
-            grad_dict_to_clip,
-            grad_clip=config.grad_clip,
-            grad_clip_by=config.grad_clip_by,
-        )
-        if config.grad_clip_by == "global_norm":
+        if config.grad_clip:
+            # Perform gradient clipping, if configured.
+            global_norm = self._get_clip_function()(
+                grad_dict_to_clip,
+                grad_clip=config.grad_clip,
+                grad_clip_by=config.grad_clip_by,
+            )
+            if config.grad_clip_by == "global_norm" or config.log_gradients:
+                # If gradients should be logged, but the global norm is not used
+                # for clipping, compute it here.
+                if config.log_gradients and config.grad_clip_by != "global_norm":
+                    # Compute the global norm of the gradients.
+                    global_norm = self._get_global_norm_function()(
+                        # Note, the global-norm function (e.g. `tf.linalg.global_norm`)
+                        # needs a list of tensors.
+                        list(grad_dict_to_clip.values()),
+                    )
+                self.metrics.log_value(
+                    key=(module_id, f"gradients_{optimizer_name}_global_norm"),
+                    value=global_norm,
+                    window=1,
+                )
+            postprocessed_grads.update(grad_dict_to_clip)
+        # Otherwise, check whether gradients should only be logged (no clipping).
+        elif config.log_gradients:
+            # Compute the global norm of the gradients and log it.
+            global_norm = self._get_global_norm_function()(
+                # Note, the global-norm function (e.g. `tf.linalg.global_norm`)
+                # needs a list of tensors.
+                list(grad_dict_to_clip.values()),
+            )
             self.metrics.log_value(
                 key=(module_id, f"gradients_{optimizer_name}_global_norm"),
                 value=global_norm,
                 window=1,
             )
-        postprocessed_grads.update(grad_dict_to_clip)
 
         return postprocessed_grads
 
@@ -1576,6 +1597,11 @@ def _set_optimizer_lr(optimizer: Optimizer, lr: float) -> None:
     def _get_clip_function() -> Callable:
         """Returns the gradient clipping function to use, given the framework."""
 
+    @staticmethod
+    @abc.abstractmethod
+    def _get_global_norm_function() -> Callable:
+        """Returns the global norm function to use, given the framework."""
+
     def _log_steps_trained_metrics(self, batch: MultiAgentBatch):
         """Logs this iteration's steps trained, based on given `batch`."""
 
diff --git a/rllib/core/learner/tf/tf_learner.py b/rllib/core/learner/tf/tf_learner.py
index 2745ad817f82..4e4765d64e91 100644
--- a/rllib/core/learner/tf/tf_learner.py
+++ b/rllib/core/learner/tf/tf_learner.py
@@ -348,3 +348,8 @@ def _get_clip_function() -> Callable:
         from ray.rllib.utils.tf_utils import clip_gradients
 
         return clip_gradients
+
+    @staticmethod
+    @override(Learner)
+    def _get_global_norm_function() -> Callable:
+        return tf.linalg.global_norm
diff --git a/rllib/core/learner/torch/torch_learner.py b/rllib/core/learner/torch/torch_learner.py
index a15ff06651aa..aac22e4bd5bd 100644
--- a/rllib/core/learner/torch/torch_learner.py
+++ b/rllib/core/learner/torch/torch_learner.py
@@ -601,3 +601,10 @@ def _get_clip_function() -> Callable:
         from ray.rllib.utils.torch_utils import clip_gradients
 
         return clip_gradients
+
+    @staticmethod
+    @override(Learner)
+    def _get_global_norm_function() -> Callable:
+        from ray.rllib.utils.torch_utils import compute_global_norm
+
+        return compute_global_norm
diff --git a/rllib/utils/torch_utils.py b/rllib/utils/torch_utils.py
index b9667ec05058..462d0fe9ff69 100644
--- a/rllib/utils/torch_utils.py
+++ b/rllib/utils/torch_utils.py
@@ -22,7 +22,7 @@
 )
 
 if TYPE_CHECKING:
-    from ray.rllib.core.learner.learner import ParamDict
+    from ray.rllib.core.learner.learner import ParamDict, ParamList
     from ray.rllib.policy.torch_policy import TorchPolicy
     from ray.rllib.policy.torch_policy_v2 import TorchPolicyV2
 
@@ -106,7 +106,7 @@ def clip_gradients(
     *,
     grad_clip: Optional[float] = None,
    grad_clip_by: str = "value",
-) -> Optional[float]:
+) -> TensorType:
     """Performs gradient clipping on a grad-dict based on a clip value and clip mode.
 
     Changes the provided gradient dict in place.
@@ -147,34 +147,12 @@ def clip_gradients(
         assert (
             grad_clip_by == "global_norm"
         ), f"`grad_clip_by` ({grad_clip_by}) must be one of [value|norm|global_norm]!"
-
-        grads = [g for g in gradients_dict.values() if g is not None]
-        norm_type = 2.0
-        if len(grads) == 0:
-            return torch.tensor(0.0)
-        device = grads[0].device
-
-        total_norm = torch.norm(
-            torch.stack(
-                [
-                    torch.norm(g.detach(), norm_type)
-                    # Note, we want to avoid overflow in the norm computation, this does
-                    # not affect the gradients themselves as we clamp by multiplying and
-                    # not by overriding tensor values.
-                    .nan_to_num(neginf=-10e8, posinf=10e8).to(device)
-                    for g in grads
-                ]
-            ),
-            norm_type,
-        ).nan_to_num(neginf=-10e8, posinf=10e8)
-        if torch.logical_or(total_norm.isnan(), total_norm.isinf()):
-            raise RuntimeError(
-                f"The total norm of order {norm_type} for gradients from "
-                "`parameters` is non-finite, so it cannot be clipped. "
-            )
+        gradients_list = list(gradients_dict.values())
+        total_norm = compute_global_norm(gradients_list)
         # We do want the coefficient to be in between 0.0 and 1.0, therefore
         # if the global_norm is smaller than the clip value, we use the clip value
         # as normalization constant.
+        device = gradients_list[0].device
         clip_coef = grad_clip / torch.maximum(
             torch.tensor(grad_clip).to(device), total_norm + 1e-6
         )
@@ -182,11 +160,53 @@
         # 1, but doing so avoids a `if clip_coef < 1:` conditional which can require a
         # CPU <=> device synchronization when the gradients do not reside in CPU memory.
         clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
-        for g in grads:
-            g.detach().mul_(clip_coef_clamped.to(g.device))
+        for g in gradients_list:
+            if g is not None:
+                g.detach().mul_(clip_coef_clamped.to(g.device))
 
         return total_norm
 
 
+@PublicAPI
+def compute_global_norm(gradients_list: "ParamList") -> TensorType:
+    """Computes the global norm of a list of gradient tensors.
+
+    Args:
+        gradients_list: The list of gradients (tensors or `None` entries).
+
+    Returns:
+        The global (L2) norm over all tensors in `gradients_list`.
+    """
+    # Define the norm type to be L2.
+    norm_type = 2.0
+    # If we have no grads, return zero.
+    if len(gradients_list) == 0:
+        return torch.tensor(0.0)
+    device = gradients_list[0].device
+
+    # Compute the global norm.
+    total_norm = torch.norm(
+        torch.stack(
+            [
+                torch.norm(g.detach(), norm_type)
+                # Note, we want to avoid overflow in the norm computation; this does
+                # not affect the gradients themselves, as we clamp by multiplying and
+                # not by overriding tensor values.
+                .nan_to_num(neginf=-10e8, posinf=10e8).to(device)
+                for g in gradients_list
+                if g is not None
+            ]
+        ),
+        norm_type,
+    ).nan_to_num(neginf=-10e8, posinf=10e8)
+    if torch.logical_or(total_norm.isnan(), total_norm.isinf()):
+        raise RuntimeError(
+            f"The total norm of order {norm_type} for the gradients in "
+            "`gradients_list` is non-finite."
+        )
+    # Return the global norm.
+    return total_norm
+
+
 @PublicAPI
 def concat_multi_gpu_td_errors(
     policy: Union["TorchPolicy", "TorchPolicyV2"]
diff --git a/rllib/utils/typing.py b/rllib/utils/typing.py
index 7c86ad47844b..95f33688a85f 100644
--- a/rllib/utils/typing.py
+++ b/rllib/utils/typing.py
@@ -168,6 +168,7 @@
 Param = Union["torch.Tensor", "tf.Variable"]
 ParamRef = Hashable
 ParamDict = Dict[ParamRef, Param]
+ParamList = List[Param]
 
 # A single learning rate or a learning rate schedule (list of sub-lists, each of
 # the format: [ts (int), lr_to_reach_by_ts (float)]).
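
Usage sketch: the new `log_gradients` flag is toggled through `AlgorithmConfig.reporting()`, as added above. The snippet below is a minimal sketch assuming the standard `PPOConfig` API; the Learner logs the norms under keys of the form `(module_id, "gradients_<optimizer_name>_global_norm")`, and where exactly these surface in the top-level result dict depends on the metrics aggregation, which this diff does not spell out.

    from ray.rllib.algorithms.ppo import PPOConfig

    config = (
        PPOConfig()
        .environment("CartPole-v1")
        .reporting(
            # Log the global norm of each optimizer's gradients (defaults to True).
            log_gradients=True,
        )
    )
    algo = config.build()
    results = algo.train()
    # Each Learner reports one "gradients_<optimizer_name>_global_norm" metric
    # per RLModule and optimizer for the last training iteration (window=1).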