[PyTorch] Fix FP8 activation recompute (#1254)
Fix FP8 activation recompute

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>
ksivaman authored Oct 16, 2024
1 parent 6e90fcb commit a518151
Showing 4 changed files with 25 additions and 8 deletions.
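What the fix addresses, in brief: FP8GlobalStateManager.is_first_fp8_module() reads and clears a global flag, and the module that reads True is the one whose autograd context later schedules the backward amax reduction (ctx.reduce_and_update_bwd_fp8_tensors). Under activation recompute the same forward code executes twice, once in the forward pass and again when the checkpoint recomputes during backward, so the recomputed run previously saw an already-cleared flag. The commit saves the flag when the checkpointed forward is entered and restores it when the recompute phase re-enters the same code. A minimal standalone sketch of that save/restore, using illustrative stand-in names rather than TE's real classes:

```python
# Toy model of the flag handling this commit adds; _State,
# checkpoint_forward_entry, and recompute_entry are hypothetical
# stand-ins for TE internals, not the library's API.
class _State:
    IS_FIRST_FP8_MODULE = True  # set at autocast entry
    _saved = []                 # one saved value per checkpointed forward

    @classmethod
    def is_first_fp8_module(cls):
        first = cls.IS_FIRST_FP8_MODULE  # read-and-clear semantics
        cls.IS_FIRST_FP8_MODULE = False
        return first

def checkpoint_forward_entry():
    # Forward phase: remember the flag before the wrapped forward consumes it.
    _State._saved.append(_State.IS_FIRST_FP8_MODULE)

def recompute_entry():
    # Recompute phase (during backward): restore the remembered value so the
    # re-run forward makes the same "am I first?" decision as the original.
    _State.IS_FIRST_FP8_MODULE = _State._saved.pop(0)

checkpoint_forward_entry()
assert _State.is_first_fp8_module()   # original forward sees True
recompute_entry()
assert _State.is_first_fp8_module()   # recompute sees True again
```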
transformer_engine/pytorch/distributed.py (11 additions, 0 deletions)

@@ -206,6 +206,8 @@ class activation_recompute_forward(AbstractContextManager, ContextDecorator):
     activations, followed by calculation of gradients using these values.
     """
 
+    _is_first_fp8_module: List = []
+
     def __init__(self, activation_recompute: bool = False, recompute_phase: bool = False):
         super().__init__()
         self.activation_recompute = activation_recompute
@@ -218,6 +220,15 @@ def __enter__(self):
         )
         _FP8_ACTIVATION_RECOMPUTE_PHASE = self.recompute_phase
 
+        if self.activation_recompute and not self.recompute_phase:
+            activation_recompute_forward._is_first_fp8_module.append(
+                FP8GlobalStateManager.IS_FIRST_FP8_MODULE
+            )
+        if self.activation_recompute and self.recompute_phase:
+            FP8GlobalStateManager.IS_FIRST_FP8_MODULE = (
+                activation_recompute_forward._is_first_fp8_module.pop(0)
+            )
+
     def __exit__(self, *exc_details):
         global _FP8_ACTIVATION_RECOMPUTE_ENABLED, _FP8_ACTIVATION_RECOMPUTE_PHASE
         _FP8_ACTIVATION_RECOMPUTE_ENABLED = False
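User code never enters this context manager directly; TE's checkpoint utility in this same file wraps the checkpointed forward with recompute_phase=False (pushing the flag) and the backward-time recompute with recompute_phase=True (popping it). A hedged usage sketch, assuming the public transformer_engine.pytorch.checkpoint API and an FP8-capable GPU:

```python
import torch
import transformer_engine.pytorch as te

layer = te.Linear(1024, 1024)
x = torch.randn(8, 1024, device="cuda", requires_grad=True)

with te.fp8_autocast(enabled=True):
    # Forward phase: the checkpoint wrapper saves IS_FIRST_FP8_MODULE.
    y = te.checkpoint(layer, x)

# Backward triggers the recompute phase, which pops the saved flag so the
# re-run forward sees the same "first FP8 module" state as the original.
y.sum().backward()
```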
transformer_engine/pytorch/module/layernorm_linear.py (5 additions, 4 deletions)

@@ -36,6 +36,7 @@
     allreduce,
     reduce_scatter_along_first_dim,
     gather_along_first_dim,
+    in_fp8_activation_recompute_phase,
     _fsdp_scatter_tensors,
     _fsdp_gather_tensors,
 )
@@ -361,10 +362,10 @@ def forward(
         ctx.normalization = normalization
         ctx.reduce_and_update_bwd_fp8_tensors = False
         if ctx.fp8 and requires_grad(inp, ln_weight, ln_bias, weight, bias):
-            ctx.reduce_and_update_bwd_fp8_tensors = (
-                ctx.reduce_and_update_bwd_fp8_tensors
-                or FP8GlobalStateManager.is_first_fp8_module()
-            )
+            _first_fp8_module = FP8GlobalStateManager.IS_FIRST_FP8_MODULE
+            ctx.reduce_and_update_bwd_fp8_tensors = FP8GlobalStateManager.is_first_fp8_module()
+            if in_fp8_activation_recompute_phase():
+                FP8GlobalStateManager.IS_FIRST_FP8_MODULE = _first_fp8_module
 
         # Row Parallel Linear
         if parallel_mode == "row" and sequence_parallel:
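The same three-line pattern recurs in layernorm_mlp.py and linear.py below: snapshot IS_FIRST_FP8_MODULE before is_first_fp8_module() clears it, then, only when in_fp8_activation_recompute_phase() is true, write the snapshot back so the recompute pass leaves the global flag exactly as the restored checkpoint state set it.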
transformer_engine/pytorch/module/layernorm_mlp.py (4 additions, 0 deletions)

@@ -43,6 +43,7 @@
     reduce_scatter_along_first_dim,
     gather_along_first_dim,
     use_reentrant_activation_recompute,
+    in_fp8_activation_recompute_phase,
     _fsdp_scatter_tensors,
     _fsdp_gather_tensors,
 )
@@ -516,7 +517,10 @@ def forward(
         if ctx.fp8 and requires_grad(
             inp, ln_weight, ln_bias, fc1_weight, fc2_weight, fc1_bias, fc2_bias
         ):
+            _first_fp8_module = FP8GlobalStateManager.IS_FIRST_FP8_MODULE
             ctx.reduce_and_update_bwd_fp8_tensors = FP8GlobalStateManager.is_first_fp8_module()
+            if in_fp8_activation_recompute_phase():
+                FP8GlobalStateManager.IS_FIRST_FP8_MODULE = _first_fp8_module
 
         # Row Parallel Linear
         if ub_overlap_rs:
transformer_engine/pytorch/module/linear.py (5 additions, 4 deletions)

@@ -33,6 +33,7 @@
     allreduce,
     reduce_scatter_along_first_dim,
     gather_along_first_dim,
+    in_fp8_activation_recompute_phase,
     _fsdp_scatter_tensors,
     _fsdp_gather_tensors,
 )
@@ -349,10 +350,10 @@ def forward(
         ctx.is_input_fp8 = is_input_fp8
         ctx.reduce_and_update_bwd_fp8_tensors = False
         if ctx.fp8 and requires_grad(inp, weight, bias):
-            ctx.reduce_and_update_bwd_fp8_tensors = (
-                ctx.reduce_and_update_bwd_fp8_tensors
-                or FP8GlobalStateManager.is_first_fp8_module()
-            )
+            _first_fp8_module = FP8GlobalStateManager.IS_FIRST_FP8_MODULE
+            ctx.reduce_and_update_bwd_fp8_tensors = FP8GlobalStateManager.is_first_fp8_module()
+            if in_fp8_activation_recompute_phase():
+                FP8GlobalStateManager.IS_FIRST_FP8_MODULE = _first_fp8_module
 
         # Row Parallel Linear
         if ub_overlap_rs:
