[RLlib] Add gradient logging as default. #47451

Merged

Changes from all commits (17)
9684c80
Added functionality to compute global norm for torch independently fr…
simonsays1980 Sep 2, 2024
450a5df
Added gradient logging to non-global-norm clipping options.
simonsays1980 Sep 2, 2024
2a06536
Added @sven1977's review.
simonsays1980 Sep 2, 2024
6f38e7a
Added 'compute_global_norm' to the public API.
simonsays1980 Sep 2, 2024
ecaffe4
Fixed a small bug with dict values.
simonsays1980 Sep 2, 2024
9b6de9f
Moved default setting for 'log_gradients' to 'reporting' default block.
simonsays1980 Sep 3, 2024
11294cb
Removed unused argument from 'clip_gradients'.
simonsays1980 Sep 3, 2024
89e0428
Merge branch 'master' into add-default-gradient-logging
simonsays1980 Sep 3, 2024
57435a2
Converted gradients dict to list b/c 'tf.linalg.global_norm' needs it…
simonsays1980 Sep 4, 2024
14177e3
Merge branch 'master' into add-default-gradient-logging
simonsays1980 Sep 4, 2024
0460240
Fixed a small bug that was caused by passing in a gradient dict inste…
simonsays1980 Sep 4, 2024
c20001f
Merge branch 'master' into add-default-gradient-logging
simonsays1980 Sep 4, 2024
db926b7
Merge branch 'master' into add-default-gradient-logging
simonsays1980 Sep 6, 2024
e18204e
Merge branch 'master' into add-default-gradient-logging
simonsays1980 Sep 9, 2024
4d4da62
Set APPO multi-learner tests to 'enormous' as they were constantly ti…
simonsays1980 Sep 11, 2024
613ceb3
Merge branch 'master' into add-default-gradient-logging
simonsays1980 Sep 12, 2024
df1d758
Merge branch 'master' into add-default-gradient-logging
simonsays1980 Sep 17, 2024
4 changes: 2 additions & 2 deletions rllib/BUILD
@@ -271,15 +271,15 @@ py_test(
name = "learning_tests_multi_agent_stateless_cartpole_appo_multi_cpu",
main = "tuned_examples/appo/multi_agent_stateless_cartpole_appo.py",
tags = ["team:rllib", "exclusive", "learning_tests", "torch_only", "learning_tests_discrete", "learning_tests_pytorch_use_all_core"],
size = "large",
size = "enormous",
srcs = ["tuned_examples/appo/multi_agent_stateless_cartpole_appo.py"],
args = ["--as-test", "--enable-new-api-stack", "--num-gpus=2"]
)
py_test(
name = "learning_tests_multi_agent_stateless_cartpole_appo_multi_gpu",
main = "tuned_examples/appo/multi_agent_stateless_cartpole_appo.py",
tags = ["team:rllib", "exclusive", "learning_tests", "torch_only", "learning_tests_discrete", "learning_tests_pytorch_use_all_core", "multi_gpu"],
size = "large",
size = "enormous",
srcs = ["tuned_examples/appo/multi_agent_stateless_cartpole_appo.py"],
args = ["--as-test", "--enable-new-api-stack", "--num-gpus=2"]
)
7 changes: 7 additions & 0 deletions rllib/algorithms/algorithm_config.py
@@ -492,6 +492,7 @@ def __init__(self, algo_class: Optional[type] = None):
         self.min_time_s_per_iteration = None
         self.min_train_timesteps_per_iteration = 0
         self.min_sample_timesteps_per_iteration = 0
+        self.log_gradients = True

         # `self.checkpointing()`
         self.export_native_model_files = False
@@ -2883,6 +2884,7 @@ def reporting(
         min_time_s_per_iteration: Optional[float] = NotProvided,
         min_train_timesteps_per_iteration: Optional[int] = NotProvided,
         min_sample_timesteps_per_iteration: Optional[int] = NotProvided,
+        log_gradients: Optional[bool] = NotProvided,
     ) -> "AlgorithmConfig":
         """Sets the config's reporting settings.
@@ -2923,6 +2925,9 @@
                 sampling timestep count has not been reached, will perform n more
                 `training_step()` calls until the minimum timesteps have been
                 executed. Set to 0 or None for no minimum timesteps.
+            log_gradients: Whether to log gradients to results. If `True`, the
+                global norm of the gradients dict for each optimizer is logged
+                to results. The default is `True`.

         Returns:
             This updated AlgorithmConfig object.
@@ -2941,6 +2946,8 @@
             self.min_train_timesteps_per_iteration = min_train_timesteps_per_iteration
         if min_sample_timesteps_per_iteration is not NotProvided:
             self.min_sample_timesteps_per_iteration = min_sample_timesteps_per_iteration
+        if log_gradients is not NotProvided:
+            self.log_gradients = log_gradients

         return self

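With this merged, the new flag is toggled through `reporting()` like any other reporting setting. A minimal sketch of user code (the choice of PPO and CartPole here is illustrative, not part of this PR):

    from ray.rllib.algorithms.ppo import PPOConfig

    config = (
        PPOConfig()
        .environment("CartPole-v1")
        # Gradient logging is now on by default; set it to False to skip the
        # extra global-norm computation per optimizer.
        .reporting(log_gradients=False)
    )
    algo = config.build()
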
44 changes: 35 additions & 9 deletions rllib/core/learner/learner.py
@@ -541,7 +541,7 @@ def postprocess_gradients_for_module(
"""
postprocessed_grads = {}

if config.grad_clip is None:
if config.grad_clip is None and not config.log_gradients:
postprocessed_grads.update(module_gradients_dict)
return postprocessed_grads

@@ -550,19 +550,40 @@
             param_dict=module_gradients_dict,
             optimizer=optimizer,
         )
-        # Perform gradient clipping, if configured.
-        global_norm = self._get_clip_function()(
-            grad_dict_to_clip,
-            grad_clip=config.grad_clip,
-            grad_clip_by=config.grad_clip_by,
-        )
-        if config.grad_clip_by == "global_norm":
+        if config.grad_clip:
+            # Perform gradient clipping, if configured.
+            global_norm = self._get_clip_function()(
+                grad_dict_to_clip,
+                grad_clip=config.grad_clip,
+                grad_clip_by=config.grad_clip_by,
+            )
+            if config.grad_clip_by == "global_norm" or config.log_gradients:
+                # If we want to log gradients but do not use the global norm
+                # for clipping, compute it here.
+                if config.log_gradients and config.grad_clip_by != "global_norm":
+                    # Compute the global norm of gradients.
+                    global_norm = self._get_global_norm_function()(
+                        # Note, `tf.linalg.global_norm` needs a list of tensors.
+                        list(grad_dict_to_clip.values()),
+                    )
+                self.metrics.log_value(
+                    key=(module_id, f"gradients_{optimizer_name}_global_norm"),
+                    value=global_norm,
+                    window=1,
+                )
+            postprocessed_grads.update(grad_dict_to_clip)
+        # Otherwise, check whether we want to log gradients only.
+        elif config.log_gradients:
+            # Compute the global norm of gradients and log it.
+            global_norm = self._get_global_norm_function()(
+                # Note, `tf.linalg.global_norm` needs a list of tensors.
+                list(grad_dict_to_clip.values()),
+            )
             self.metrics.log_value(
                 key=(module_id, f"gradients_{optimizer_name}_global_norm"),
                 value=global_norm,
                 window=1,
             )
-        postprocessed_grads.update(grad_dict_to_clip)
+            postprocessed_grads.update(grad_dict_to_clip)

         return postprocessed_grads
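
The branching above reduces to one question: is the global norm already a by-product of clipping, or does logging need a separate pass? A condensed sketch of that decision (a hypothetical helper, not part of this PR):

    def needs_separate_norm_pass(grad_clip, grad_clip_by, log_gradients) -> bool:
        # Clipping by global norm yields the norm for free; any other
        # configuration that wants gradient logging must compute it itself.
        clipping_yields_norm = bool(grad_clip) and grad_clip_by == "global_norm"
        return log_gradients and not clipping_yields_norm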

@@ -1576,6 +1597,11 @@ def _set_optimizer_lr(optimizer: Optimizer, lr: float) -> None:
     def _get_clip_function() -> Callable:
         """Returns the gradient clipping function to use, given the framework."""

+    @staticmethod
+    @abc.abstractmethod
+    def _get_global_norm_function() -> Callable:
+        """Returns the global norm function to use, given the framework."""
+
     def _log_steps_trained_metrics(self, batch: MultiAgentBatch):
         """Logs this iteration's steps trained, based on given `batch`."""

5 changes: 5 additions & 0 deletions rllib/core/learner/tf/tf_learner.py
@@ -348,3 +348,8 @@ def _get_clip_function() -> Callable:
         from ray.rllib.utils.tf_utils import clip_gradients

         return clip_gradients
+
+    @staticmethod
+    @override(Learner)
+    def _get_global_norm_function() -> Callable:
+        return tf.linalg.global_norm
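
`tf.linalg.global_norm` is the standard TensorFlow helper: it takes a list of tensors (hence the `list(...)` conversion in the learner code above) and returns the square root of the sum of the squared L2 norms. A quick sanity check of the expected value:

    import tensorflow as tf

    grads = [tf.constant([3.0, 4.0]), tf.constant([12.0])]
    # ||[3, 4]|| = 5 and ||[12]|| = 12, so the global norm is
    # sqrt(5**2 + 12**2) = 13.
    print(tf.linalg.global_norm(grads).numpy())  # 13.0
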
7 changes: 7 additions & 0 deletions rllib/core/learner/torch/torch_learner.py
@@ -601,3 +601,10 @@ def _get_clip_function() -> Callable:
         from ray.rllib.utils.torch_utils import clip_gradients

         return clip_gradients
+
+    @staticmethod
+    @override(Learner)
+    def _get_global_norm_function() -> Callable:
+        from ray.rllib.utils.torch_utils import compute_global_norm
+
+        return compute_global_norm
78 changes: 49 additions & 29 deletions rllib/utils/torch_utils.py
@@ -22,7 +22,7 @@
 )

 if TYPE_CHECKING:
-    from ray.rllib.core.learner.learner import ParamDict
+    from ray.rllib.core.learner.learner import ParamDict, ParamList
     from ray.rllib.policy.torch_policy import TorchPolicy
     from ray.rllib.policy.torch_policy_v2 import TorchPolicyV2
@@ -106,7 +106,7 @@ def clip_gradients(
     *,
     grad_clip: Optional[float] = None,
     grad_clip_by: str = "value",
-) -> Optional[float]:
+) -> TensorType:
     """Performs gradient clipping on a grad-dict based on a clip value and clip mode.

     Changes the provided gradient dict in place.
@@ -147,46 +147,66 @@
     assert (
         grad_clip_by == "global_norm"
     ), f"`grad_clip_by` ({grad_clip_by}) must be one of [value|norm|global_norm]!"

-    grads = [g for g in gradients_dict.values() if g is not None]
-    norm_type = 2.0
-    if len(grads) == 0:
-        return torch.tensor(0.0)
-    device = grads[0].device
-
-    total_norm = torch.norm(
-        torch.stack(
-            [
-                torch.norm(g.detach(), norm_type)
-                # Note, we want to avoid overflow in the norm computation, this does
-                # not affect the gradients themselves as we clamp by multiplying and
-                # not by overriding tensor values.
-                .nan_to_num(neginf=-10e8, posinf=10e8).to(device)
-                for g in grads
-            ]
-        ),
-        norm_type,
-    ).nan_to_num(neginf=-10e8, posinf=10e8)
-    if torch.logical_or(total_norm.isnan(), total_norm.isinf()):
-        raise RuntimeError(
-            f"The total norm of order {norm_type} for gradients from "
-            "`parameters` is non-finite, so it cannot be clipped. "
-        )
+    gradients_list = list(gradients_dict.values())
+    total_norm = compute_global_norm(gradients_list)
     # We do want the coefficient to be in between 0.0 and 1.0, therefore
     # if the global_norm is smaller than the clip value, we use the clip value
     # as normalization constant.
+    device = gradients_list[0].device
     clip_coef = grad_clip / torch.maximum(
         torch.tensor(grad_clip).to(device), total_norm + 1e-6
     )
     # Note: multiplying by the clamped coef is redundant when the coef is clamped to
     # 1, but doing so avoids a `if clip_coef < 1:` conditional which can require a
     # CPU <=> device synchronization when the gradients do not reside in CPU memory.
     clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
-    for g in grads:
-        g.detach().mul_(clip_coef_clamped.to(g.device))
+    for g in gradients_list:
+        if g is not None:
+            g.detach().mul_(clip_coef_clamped.to(g.device))
     return total_norm
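
A quick numeric check of the coefficient logic above: with `grad_clip=10.0` and gradients of global norm 50, `clip_coef = 10 / max(10, 50 + 1e-6) ≈ 0.2`, so the in-place multiply brings the post-clip norm down to the clip value; with a norm of 5, the `torch.maximum` in the denominator already yields `clip_coef = 1.0`, making the `clamp(max=1.0)` a device-friendly safety net. A minimal usage sketch of the modified function (values are illustrative):

    import torch
    from ray.rllib.utils.torch_utils import clip_gradients

    grads = {"w": torch.tensor([30.0, 40.0])}  # global norm = 50
    total_norm = clip_gradients(grads, grad_clip=10.0, grad_clip_by="global_norm")
    print(total_norm)   # tensor(50.) -> the pre-clip global norm
    print(grads["w"])   # tensor([6., 8.]) -> norm 10 after in-place clipping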


+@PublicAPI
+def compute_global_norm(gradients_list: "ParamList") -> TensorType:
+    """Computes the global norm for a list of gradients.
+
+    Args:
+        gradients_list: The list of gradient tensors.
+
+    Returns:
+        The global norm of all tensors in `gradients_list`.
+    """
+    # Define the norm type to be L2.
+    norm_type = 2.0
+    # If we have no grads, return zero.
+    if len(gradients_list) == 0:
+        return torch.tensor(0.0)
+    device = gradients_list[0].device
+
+    # Compute the global norm.
+    total_norm = torch.norm(
+        torch.stack(
+            [
+                torch.norm(g.detach(), norm_type)
+                # Note, we want to avoid overflow in the norm computation; this does
+                # not affect the gradients themselves, as we clamp by multiplying and
+                # not by overriding tensor values.
+                .nan_to_num(neginf=-10e8, posinf=10e8).to(device)
+                for g in gradients_list
+                if g is not None
+            ]
+        ),
+        norm_type,
+    ).nan_to_num(neginf=-10e8, posinf=10e8)
+    if torch.logical_or(total_norm.isnan(), total_norm.isinf()):
+        raise RuntimeError(
+            f"The total norm of order {norm_type} for gradients from "
+            "`parameters` is non-finite, so it cannot be clipped."
+        )
+    # Return the global norm.
+    return total_norm
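
The torch version mirrors `tf.linalg.global_norm`: for tensors g_1, ..., g_n it returns sqrt(sum_i ||g_i||^2), clamping non-finite per-tensor norms to +/-1e9 before the reduction. A minimal check of the new public helper:

    import torch
    from ray.rllib.utils.torch_utils import compute_global_norm

    grads = [torch.tensor([3.0, 4.0]), torch.tensor([12.0]), None]
    # `None` entries are skipped: sqrt(5**2 + 12**2) = 13.
    print(compute_global_norm(grads))  # tensor(13.)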


 @PublicAPI
 def concat_multi_gpu_td_errors(
     policy: Union["TorchPolicy", "TorchPolicyV2"]
1 change: 1 addition & 0 deletions rllib/utils/typing.py
@@ -168,6 +168,7 @@
 Param = Union["torch.Tensor", "tf.Variable"]
 ParamRef = Hashable
 ParamDict = Dict[ParamRef, Param]
+ParamList = List[Param]

 # A single learning rate or a learning rate schedule (list of sub-lists, each of
 # the format: [ts (int), lr_to_reach_by_ts (float)]).