[PyTorch] Fix autocast deprecation warnings #1277

Open · wants to merge 5 commits into base: main
Changes from all commits
13 changes: 7 additions & 6 deletions tests/pytorch/test_fused_optimizer.py
@@ -14,6 +14,7 @@
from transformer_engine.pytorch import fp8_model_init
from transformer_engine.pytorch.utils import is_bf16_compatible
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
+from transformer_engine.pytorch.jit import gpu_autocast_ctx

# Check if FP8 is supported
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()

@@ -333,7 +334,7 @@ def test_grad_scaler(self):
        gt_ = gt.clone()

        # Reference
-        with torch.cuda.amp.autocast(enabled=True):
+        with gpu_autocast_ctx(enabled=True):
            y = self.model(x)
            loss = ((gt - y) ** 2).mean()

@@ -342,7 +343,7 @@ def test_grad_scaler(self):
        scaler.update()

        # DUT
-        with torch.cuda.amp.autocast(enabled=True):
+        with gpu_autocast_ctx(enabled=True):
            y = self.model_(x)
            loss_ = ((gt_ - y) ** 2).mean()

@@ -384,7 +385,7 @@ def test_grad_scaler_capturable(self):
        gt_ = gt.clone()

        # Reference
-        with torch.cuda.amp.autocast(enabled=True):
+        with gpu_autocast_ctx(enabled=True):
            y = self.model(x)
            loss = ((gt - y) ** 2).mean()

@@ -393,7 +394,7 @@ def test_grad_scaler_capturable(self):
        scaler.update()

        # DUT
-        with torch.cuda.amp.autocast(enabled=True):
+        with gpu_autocast_ctx(enabled=True):
            y = self.model_(x)
            loss_ = ((gt_ - y) ** 2).mean()

@@ -442,7 +443,7 @@ def test_grad_scaler_capturable_master(self):
        gt_ = gt.clone()

        # Reference
-        with torch.cuda.amp.autocast(enabled=True):
+        with gpu_autocast_ctx(enabled=True):
            y = self.model(x)
            loss = ((gt - y) ** 2).mean()

@@ -451,7 +452,7 @@ def test_grad_scaler_capturable_master(self):
        scaler.update()

        # DUT
-        with torch.cuda.amp.autocast(enabled=True):
+        with gpu_autocast_ctx(enabled=True):
            y = self.model_(x)
            loss_ = ((gt_ - y) ** 2).mean()

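Note: the change in these tests is purely the autocast spelling; the GradScaler flow itself is untouched. A minimal sketch of the pattern being exercised, assuming illustrative names `model`, `opt`, `x`, and `gt` (not taken verbatim from the test file):

```python
import torch
from transformer_engine.pytorch.jit import gpu_autocast_ctx

scaler = torch.cuda.amp.GradScaler()  # the scaler API used by these tests

# Forward under autocast, exactly where torch.cuda.amp.autocast(enabled=True) used to sit
with gpu_autocast_ctx(enabled=True):
    y = model(x)
    loss = ((gt - y) ** 2).mean()

# Scaled backward + optimizer step
scaler.scale(loss).backward()
scaler.step(opt)
scaler.update()
```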
41 changes: 31 additions & 10 deletions transformer_engine/pytorch/distributed.py
@@ -252,17 +252,38 @@ def _get_active_autocast_contexts():
    """
    autocast_cached = torch.is_autocast_cache_enabled()

-    gpu_autocast_enabled = torch.is_autocast_enabled()
-    gpu_autocast_dtype = torch.get_autocast_gpu_dtype()
-    gpu_autocast_ctx = torch.cuda.amp.autocast(
-        gpu_autocast_enabled, gpu_autocast_dtype, autocast_cached
-    )
+    TORCH_MAJOR = int(torch.__version__.split(".")[0])
+    TORCH_MINOR = int(torch.__version__.split(".")[1])
+    if TORCH_MAJOR == 2 and TORCH_MINOR >= 4:
+        gpu_autocast_enabled = torch.is_autocast_enabled("cuda")
+        gpu_autocast_dtype = torch.get_autocast_dtype("cuda")
+        gpu_autocast_ctx = torch.amp.autocast(
+            "cuda",
+            enabled=gpu_autocast_enabled,
+            dtype=gpu_autocast_dtype,
+            cache_enabled=autocast_cached,
+        )

-    cpu_autocast_enabled = torch.is_autocast_cpu_enabled()
-    cpu_autocast_dtype = torch.get_autocast_cpu_dtype()
-    cpu_autocast_ctx = torch.cpu.amp.autocast(
-        cpu_autocast_enabled, cpu_autocast_dtype, autocast_cached
-    )
+        cpu_autocast_enabled = torch.is_autocast_enabled("cpu")
+        cpu_autocast_dtype = torch.get_autocast_dtype("cpu")
+        cpu_autocast_ctx = torch.amp.autocast(
+            "cpu",
+            enabled=cpu_autocast_enabled,
+            dtype=cpu_autocast_dtype,
+            cache_enabled=autocast_cached,
+        )
+    else:
+        gpu_autocast_enabled = torch.is_autocast_enabled()
+        gpu_autocast_dtype = torch.get_autocast_gpu_dtype()
+        gpu_autocast_ctx = torch.cuda.amp.autocast(
+            gpu_autocast_enabled, gpu_autocast_dtype, autocast_cached
+        )
+
+        cpu_autocast_enabled = torch.is_autocast_cpu_enabled()
+        cpu_autocast_dtype = torch.get_autocast_cpu_dtype()
+        cpu_autocast_ctx = torch.cpu.amp.autocast(
+            cpu_autocast_enabled, cpu_autocast_dtype, autocast_cached
+        )

    return gpu_autocast_ctx, cpu_autocast_ctx

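The version branch above exists because PyTorch 2.4 moved to device-keyed autocast queries: `torch.is_autocast_enabled("cuda")`, `torch.get_autocast_dtype("cuda")`, and `torch.amp.autocast("cuda", ...)` replace the older per-device helpers such as `torch.get_autocast_gpu_dtype()` and `torch.cuda.amp.autocast(...)`, which emit deprecation warnings on 2.4+. A hedged sketch of how the returned contexts are typically consumed; `some_forward_fn` and `inputs` are illustrative stand-ins, not names from this diff:

```python
# Capture the ambient autocast state in the calling thread/region...
gpu_ctx, cpu_ctx = _get_active_autocast_contexts()

# ...and re-enter an equivalent state later (e.g. when recomputing a
# checkpointed forward) so dtype selection matches the original pass.
with gpu_ctx, cpu_ctx:
    out = some_forward_fn(inputs)
```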
14 changes: 10 additions & 4 deletions transformer_engine/pytorch/jit.py
@@ -5,6 +5,7 @@
"""NVFuser functions and JIT utilities"""
import os
from typing import Callable, Optional, Tuple
+from functools import partial

import torch

@@ -33,6 +34,11 @@
# no "recursive" option in pyTorch 2.0 - it acts as if recursive was True
no_torch_dynamo = lambda recursive=True: torch._dynamo.disable

+if torch.__version__ >= "2.4":
+    gpu_autocast_ctx = partial(torch.amp.autocast, device_type="cuda")
+else:
+    gpu_autocast_ctx = torch.cuda.amp.autocast
+

def set_jit_fusion_options() -> None:
    """Set PyTorch JIT layer fusion options."""

@@ -110,7 +116,7 @@ def dgelu_fused_(grad_output: torch.Tensor, inp: torch.Tensor) -> torch.Tensor:

def bias_gelu_fused(inp: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
    """Disable native AMP for bias_gelu_fused_"""
-    with torch.cuda.amp.autocast(enabled=False):
+    with gpu_autocast_ctx(enabled=False):
        if bias is not None and bias.numel() != 0:
            return bias_gelu_fused_(inp, bias)
        return gelu_fused_(inp)

@@ -120,7 +126,7 @@ def bgrad_dgelu_fused(
    grad_output: torch.Tensor, inp: torch.Tensor, bias: torch.Tensor
) -> Tuple[Optional[torch.Tensor], torch.Tensor]:
    """Disable native AMP for `bgrad_dgelu_fused_`"""
-    with torch.cuda.amp.autocast(enabled=False):
+    with gpu_autocast_ctx(enabled=False):
        if bias is not None and bias.numel() != 0:
            return bgrad_dgelu_fused_(grad_output, inp, bias)
        return None, dgelu_fused_(grad_output, inp)

@@ -161,7 +167,7 @@ def bias_dropout_add_fused_train(
) -> torch.Tensor:
    """Disable native AMP and enable grad for BDA"""
    with torch.enable_grad():
-        with torch.cuda.amp.autocast(enabled=False):
+        with gpu_autocast_ctx(enabled=False):
            return bias_dropout_add_fused_train_(x, bias, residual, prob)

@@ -177,7 +183,7 @@ def bias_dropout_add_fused_inference(
    x: torch.Tensor, bias: torch.Tensor, residual: torch.Tensor, prob: float
) -> torch.Tensor:
    """Disable native AMP for BDA"""
-    with torch.cuda.amp.autocast(enabled=False):
+    with gpu_autocast_ctx(enabled=False):
        return bias_dropout_add_fused_inference_(x, bias, residual, prob)

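Since `partial(torch.amp.autocast, device_type="cuda")` and the legacy `torch.cuda.amp.autocast` accept the same keyword arguments, the call sites above change only their spelling. A small usage sketch under that assumption:

```python
import torch
from transformer_engine.pytorch.jit import gpu_autocast_ctx

# Disable AMP around a fused kernel, as the old torch.cuda.amp.autocast(enabled=False) did:
with gpu_autocast_ctx(enabled=False):
    x = torch.ones(8, device="cuda") * 2.0  # runs in full precision

# Enabling with an explicit dtype maps 1:1 as well:
with gpu_autocast_ctx(enabled=True, dtype=torch.bfloat16):
    y = torch.nn.functional.linear(
        torch.randn(4, 8, device="cuda"),
        torch.randn(16, 8, device="cuda"),
    )  # matmul is autocast to bf16
```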
3 changes: 2 additions & 1 deletion transformer_engine/pytorch/module/base.py
@@ -38,6 +38,7 @@
)
from ..constants import dist_group_type
from ..float8_tensor import Float8Tensor
+from ..utils import torch_get_autocast_gpu_dtype

__all__ = ["initialize_ub", "destroy_ub"]

@@ -619,7 +620,7 @@ def set_activation_dtype(self, inp: torch.Tensor) -> None:
        """Get activation data type for AMP."""
        # Native AMP (`torch.autocast`) gets highest priority
        if torch.is_autocast_enabled():
-            self.activation_dtype = torch.get_autocast_gpu_dtype()
+            self.activation_dtype = torch_get_autocast_gpu_dtype()
            return

        # All checks after this have already been performed once, thus skip
4 changes: 2 additions & 2 deletions transformer_engine/pytorch/module/layernorm.py
@@ -16,7 +16,7 @@
    layernorm_fwd_inf,
)
from ..jit import no_torch_dynamo
-from ..utils import cast_if_needed
+from ..utils import cast_if_needed, torch_get_autocast_gpu_dtype

__all__ = ["LayerNorm"]

@@ -193,7 +193,7 @@ def forward(self, inp: torch.Tensor) -> torch.Tensor:
        # Note: This will soon be deprecated with
        # https://github.com/NVIDIA/TransformerEngine/pull/1033
        if torch.is_autocast_enabled():
-            self.activation_dtype = torch.get_autocast_gpu_dtype()
+            self.activation_dtype = torch_get_autocast_gpu_dtype()
        elif self.activation_dtype != inp.dtype:
            dtype = inp.dtype
            for name, param in self.named_parameters():
4 changes: 2 additions & 2 deletions transformer_engine/pytorch/module/rmsnorm.py
@@ -13,7 +13,7 @@

from .. import cpp_extensions as tex
from ..jit import no_torch_dynamo
-from ..utils import cast_if_needed
+from ..utils import cast_if_needed, torch_get_autocast_gpu_dtype


__all__ = ["RMSNorm"]

@@ -190,7 +190,7 @@ def forward(self, inp: torch.Tensor) -> torch.Tensor:
        # Note: This will soon be deprecated with
        # https://github.com/NVIDIA/TransformerEngine/pull/1033
        if torch.is_autocast_enabled():
-            self.activation_dtype = torch.get_autocast_gpu_dtype()
+            self.activation_dtype = torch_get_autocast_gpu_dtype()
        elif self.activation_dtype != inp.dtype:
            dtype = inp.dtype
            for name, param in self.named_parameters():
3 changes: 2 additions & 1 deletion transformer_engine/pytorch/transformer.py
@@ -26,6 +26,7 @@
from transformer_engine.pytorch.utils import (
    cast_if_needed,
    get_default_init_method,
+    torch_get_autocast_gpu_dtype,
)
from transformer_engine.pytorch.constants import (
    AttnMaskTypes,

@@ -677,7 +678,7 @@ def forward(

        # For AMP
        if torch.is_autocast_enabled():
-            hidden_states = cast_if_needed(hidden_states, torch.get_autocast_gpu_dtype())
+            hidden_states = cast_if_needed(hidden_states, torch_get_autocast_gpu_dtype())

        # Self attention.
        self_attention_outputs = self.self_attention(
9 changes: 9 additions & 0 deletions transformer_engine/pytorch/utils.py
@@ -305,3 +305,12 @@ def devices_match(device1: torch.device, device2: torch.device) -> bool:
            index2 = torch.cuda.current_device()
        return index1 == index2
    return device1 == device2
+
+
+def torch_get_autocast_gpu_dtype() -> torch.dtype:
+    """Get PyTorch autocast GPU dtype."""
+    TORCH_MAJOR = int(torch.__version__.split(".")[0])
+    TORCH_MINOR = int(torch.__version__.split(".")[1])
+    if TORCH_MAJOR == 2 and TORCH_MINOR >= 4:
+        return torch.get_autocast_dtype("cuda")
+    return torch.get_autocast_gpu_dtype()
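`torch_get_autocast_gpu_dtype` centralizes the same version check so the call sites in base.py, layernorm.py, rmsnorm.py, and transformer.py can drop `torch.get_autocast_gpu_dtype()` without each repeating the branch. A brief usage sketch (the enclosing autocast region is illustrative, not part of this diff):

```python
import torch
from transformer_engine.pytorch.utils import torch_get_autocast_gpu_dtype

with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    if torch.is_autocast_enabled():
        activation_dtype = torch_get_autocast_gpu_dtype()
        # -> torch.bfloat16 on any supported PyTorch version,
        #    without the 2.4+ deprecation warning
```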