This repository has been archived by the owner on Aug 7, 2024. It is now read-only.

[FSDP2] precompute scale after optimizer.step for dynamic scaling #266

Closed. weifengpy wants to merge 32 commits into main from the fsdp2 branch.
Changes from 15 commits

Commits (32)
9d5595c
[FSDP2] set vocab_size=32 to avoid must be divisible by 16 error
weifengpy May 21, 2024
e7005c2
precast after optimizer.step and dump profiler traces
weifengpy May 21, 2024
e41d589
Merge branch 'main' into fsdp2
weifengpy May 21, 2024
e0bee10
precast and preamax unit test
weifengpy May 24, 2024
c0ba5a2
remove duplicate vocab
weifengpy May 24, 2024
8da238e
fused amax
weifengpy May 30, 2024
ffff5ed
Merge branch 'main' into fsdp2
weifengpy Jun 6, 2024
aefa21b
use FP8_TYPES and max
weifengpy Jun 6, 2024
d4a1db7
commit all changes before cleaning
weifengpy Jun 6, 2024
d36e79b
pre_compute and flatten / unflatten
weifengpy Jun 6, 2024
6f244a2
remove unused constant
weifengpy Jun 6, 2024
dc5eab0
torch.compile works
weifengpy Jun 6, 2024
546e979
eager ready
weifengpy Jun 6, 2024
229ede6
linter
weifengpy Jun 6, 2024
d5b3ff6
linter
weifengpy Jun 6, 2024
4f05e04
flatten tensor
weifengpy Jun 25, 2024
3de59af
commit all changes for review before rebasing
weifengpy Jul 8, 2024
ffcd197
rebase on unified float8linear
weifengpy Jul 9, 2024
6b18947
Merge branch 'pytorch-labs:main' into fsdp2
weifengpy Jul 9, 2024
562424c
move precompute to fsdp_utils.py
weifengpy Jul 9, 2024
75e0e45
simplify amax calc
weifengpy Jul 9, 2024
fe95f8b
explain _pre_computed_amax
weifengpy Jul 9, 2024
1cbaa13
fix linter
weifengpy Jul 9, 2024
fe2e0a0
document precompute_float8_amax_for_fsdp
weifengpy Jul 9, 2024
e4eaa2a
rename pre_compute to precompute
weifengpy Jul 9, 2024
e4245e4
Merge branch 'main' into fsdp2
weifengpy Jul 10, 2024
e12c973
remove clamp_amax=True/False
weifengpy Jul 10, 2024
9ef67fb
precompute scale
weifengpy Jul 10, 2024
fa2f08a
unit test for precomputing scales
weifengpy Jul 10, 2024
ba085e5
add precompute scale in README
weifengpy Jul 10, 2024
ac0afb0
rename to precompute_float8_dynamic_scale_for_fsdp
weifengpy Jul 11, 2024
8e56dfc
rename to precompute_float8_dynamic_scale_for_fsdp
weifengpy Jul 11, 2024
float8_experimental/float8_dynamic_linear.py: 20 changes (16 additions, 4 deletions)

@@ -22,7 +22,7 @@
tensor_already_casted_to_fp8,
to_fp8_no_autograd,
)
-from float8_experimental.float8_utils import tensor_to_scale
+from float8_experimental.float8_utils import amax_to_scale, tensor_to_scale
from torch._prims_common import suggest_memory_format


@@ -151,6 +151,7 @@ def __new__(cls, tensor: torch.Tensor, mm_config: ScaledMMConfig):
def __init__(self, tensor: torch.Tensor, mm_config: ScaledMMConfig):
self._tensor = tensor
self._mm_config = mm_config
self._pre_computed_amax = None
Contributor: does this need to be added to __tensor_flatten__? Can we add some comments on the intended usage of this?

Contributor (@drisspg, Jun 7, 2024): +1 on adding to flatten/unflatten and comments / intended usage.

Author (weifengpy): done
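A rough sketch of what registering _pre_computed_amax with the tensor-subclass flatten protocol could look like; the actual change lands in a later commit ("pre_compute and flatten / unflatten") that is not shown in this 15-commit view, so everything below except the attribute and class names is an assumption:

def __tensor_flatten__(self):
    # Only expose the extra inner tensor once it has been populated.
    if self._pre_computed_amax is None:
        return ["_tensor"], self._mm_config
    return ["_tensor", "_pre_computed_amax"], self._mm_config

@staticmethod
def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride):
    mm_config = flatten_spec
    weight = WeightWithDynamicFloat8CastTensor(inner_tensors["_tensor"], mm_config)
    if "_pre_computed_amax" in inner_tensors:
        weight._pre_computed_amax = inner_tensors["_pre_computed_amax"]
    return weight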


@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs=None):
Expand Down Expand Up @@ -190,9 +191,20 @@ def __repr__(self):
return f"WeightWithDynamicFloat8CastTensor(tensor={self._tensor}, mm_config={self._mm_config})"

def fsdp_pre_all_gather(self, mesh):
Author (weifengpy): if _pre_computed_amax is set, we skip tensor_to_amax and directly do amax_to_scale.

-        float8_tensor = cast_to_float8_e4m3fn(
-            self._tensor, self._mm_config, reduce_amax=True
-        )
+        if self._pre_computed_amax is not None:
+            scale = amax_to_scale(
+                self._pre_computed_amax,
+                torch.float8_e4m3fn,
+                self._pre_computed_amax.dtype,
+                clamp_amax=False,
+            )
+            float8_tensor = Float8Tensor.to_float8(
+                self._tensor, scale, torch.float8_e4m3fn, mm_config=self._mm_config
+            )
+        else:
+            float8_tensor = cast_to_float8_e4m3fn(
+                self._tensor, self._mm_config, reduce_amax=True
+            )
return (float8_tensor._data,), (float8_tensor._scale,)

def fsdp_post_all_gather(
float8_experimental/float8_linear_utils.py: 41 changes (39 additions, 2 deletions)

@@ -5,16 +5,22 @@
# LICENSE file in the root directory of this source tree.
import copy
import logging

import math
import warnings
from enum import auto, Enum
from typing import Callable, List, Optional, Type

import torch
import torch.distributed as dist
import torch.nn as nn
-from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
+from float8_experimental.float8_dynamic_linear import (
+    Float8DynamicLinear,
+    WeightWithDynamicFloat8CastTensor,
+)
from float8_experimental.float8_linear import Float8Linear

-from float8_experimental.float8_utils import amax_history_to_scale_stack
+from float8_experimental.float8_utils import amax_history_to_scale_stack, EPS
from torch.distributed._functional_collectives import all_reduce, AsyncCollectiveTensor

log = logging.getLogger(__name__)
@@ -322,3 +328,34 @@ def inner_func():
for child in fp8_layers:
# Set a flag to signal amaxes/scales are ready
child.amax_and_scale_synced = True


def precompute_float8_amax(module: nn.Module) -> None:
Contributor: can we put this in distributed_utils.py? I think the function name should indicate that this is intended for FSDP2 with float8 all-gather.

Author (weifengpy): moving to fsdp_utils.py according to PR #310

Author (weifengpy): indicating FSDP by renaming to precompute_float8_amax_for_fsdp

Contributor: @weifengpy do you plan / want to use compile on this, and are there any gaps around here that you think would be good to prioritize on the compile side? This is mostly just me remembering @awgu mention a while ago that he thought compile added noticeable runtime overhead, and I can't remember if it was for this specific case. If it is, and we think compiling this code would be useful, I can prioritize looking into the runtime overhead.

Author (weifengpy): Hi @bdhirsh, I plan to polish and land this PR without compile next week to conclude H1; most importantly, add _pre_computed_amax to flatten/unflatten. Reducing the runtime overhead from torch.compile is still meaningful, since we want torch.compile(fp8 casting) in FSDP2 pre-forward hooks. Would it be helpful if I work on a mini repro with profiler traces? I want to unblock you in the short term.

Contributor: If you have a mini repro showing bad runtime overheads with compile, that would be great!

Author (weifengpy): Hi @bdhirsh, I have created a repro: pytorch/pytorch#129457. I highlighted the extra CPU overhead and GPU time for torch.compile(mode="reduce-overhead").

from torch.distributed._tensor import DTensor

if any(isinstance(m, Float8Linear) for m in module.modules()):
raise NotImplementedError("Only supports Float8DynamicLinear, not Float8Linear")
float8_linears: List[Float8DynamicLinear] = [
m
for m in module.modules()
if isinstance(m, Float8DynamicLinear)
and isinstance(m.weight, DTensor)
and isinstance(m.weight._local_tensor, WeightWithDynamicFloat8CastTensor)
]
weights: List[DTensor] = [float8_linear.weight for float8_linear in float8_linears]

def compute_amaxes(weights: List[DTensor]):
max_weights = torch._foreach_norm(weights, ord=math.inf)
Contributor: maybe add a comment that this is equivalent to max(abs(w))?

Author (weifengpy): done
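An illustrative check of the equivalence the reviewer mentions (not part of the PR; the example tensors are made up):

import math
import torch

# torch._foreach_norm with ord=inf computes the infinity norm of each tensor,
# i.e. max(abs(w)), in one fused call instead of a separate abs().max() per weight.
ws = [torch.randn(4, 8), torch.randn(16, 16)]
fused = torch._foreach_norm(ws, ord=math.inf)
reference = [w.abs().max() for w in ws]
assert all(torch.equal(a, b) for a, b in zip(fused, reference))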

amax_tensor = torch.vstack(max_weights)
amax_tensor = torch.clamp(amax_tensor, EPS) # R
Author (@weifengpy, Jun 6, 2024): torch.clamp calls all_reduce here. I avoided calling it again in amax_to_scale(clamp_amax=False).

Contributor: So you are relying on torch.clamp to run the all-reduce implicitly by changing the sharding from partial to replicate? If this fragments the code, could we just all-reduce the amax tensor and then leave the clamp to amax_to_scale? I agree the current way is faster, since we are doing one clamp for all amaxes, but in case the float8 folks are not happy with this fragmentation, this seems like another way.

Author (weifengpy): thanks for the suggestions. I can collect feedback from the float8 folks if they have a preference.

Contributor: can we just add a comment explaining what is going on? I think it's fine as long as the code is easy to understand and there is no magic.

Contributor: agreed

amaxes = torch.split(amax_tensor, 1) # R
return amaxes

if weights:
amaxes = compute_amaxes(weights)
for amax, float8_linear in zip(amaxes, float8_linears):
float8_linear.weight._local_tensor._pre_computed_amax = amax._local_tensor
else:
warnings.warn(
"Calling precompute_float8_weights without any weights using FSDP fp8 all-gather!"
)
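Putting the helper in context, the intended usage (mirroring the test change below) is one extra call per iteration, right after the optimizer step. A minimal sketch; model, optimizer, and dataloader stand in for the usual FSDP2 + Float8DynamicLinear setup and are assumptions here:

from float8_experimental.float8_linear_utils import precompute_float8_amax

for batch in dataloader:
    loss = model(batch).sum()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    # After the fp32 weights are updated, compute amax for all float8 weights in one
    # fused call so the next iteration's fp8 all-gather can skip per-weight amax work.
    precompute_float8_amax(model)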
float8_experimental/float8_utils.py: 11 changes (9 additions, 2 deletions)

@@ -27,17 +27,24 @@

@torch.no_grad()
def amax_to_scale(
-    amax: torch.Tensor, float8_dtype: torch.dtype, orig_dtype: torch.dtype
+    amax: torch.Tensor,
+    float8_dtype: torch.dtype,
+    orig_dtype: torch.dtype,
+    clamp_amax: bool = True,
):
"""Converts the amax value of a tensor to the fp8 scale.
Args:
amax: The amax value of the tensor.
float8_dtype: The float8 dtype.
orig_dtype: The original dtype of the tensor.
clamp_amax: default is True. False for FSDP fp8 all-gather since FSDP applied `torch.clamp` during pre-compute after optimizer.step
Contributor: this is a bit confusing. How about precomputing the scale instead, so we don't have to have gotchas like this?

Author (weifengpy): good suggestion! I changed the API to precompute the scale, and it shows another 9% speed-up in the unit test vs precomputing amax. fsdp_pre_all_gather is also greatly simplified because it uses self._precomputed_scale.
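A rough sketch of the simplified hook the author describes; the final version lands in the later "precompute scale" commits rather than in this diff, so only the attribute name self._precomputed_scale comes from the comment and the rest of the shape is inferred from the hunk above:

def fsdp_pre_all_gather(self, mesh):
    # Sketch only: with a precomputed scale there is no amax-to-scale step left here.
    if self._precomputed_scale is not None:
        float8_tensor = Float8Tensor.to_float8(
            self._tensor,
            self._precomputed_scale,
            torch.float8_e4m3fn,
            mm_config=self._mm_config,
        )
    else:
        float8_tensor = cast_to_float8_e4m3fn(
            self._tensor, self._mm_config, reduce_amax=True
        )
    return (float8_tensor._data,), (float8_tensor._scale,)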

"""
scale = torch.empty_like(amax, dtype=torch.float32)
if float8_dtype in FP8_TYPES:
-        res = torch.finfo(float8_dtype).max / torch.clamp(amax, min=EPS)
+        if clamp_amax:
Contributor: nit: I think if you put this on a separate line,
    amax = torch.clamp(amax, min=EPS) if clamp_amax else amax
it makes the logic a little easier to follow.

+            res = torch.finfo(float8_dtype).max / torch.clamp(amax, min=EPS)
+        else:
+            res = torch.finfo(float8_dtype).max / amax
else:
raise ValueError(f"Unsupported float8_dtype: {float8_dtype}")

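A quick worked example of the formula above (illustrative only; EPS is assumed to be a tiny constant on the order of 1e-12):

import torch

# float8_e4m3fn has a max representable value of 448.0, so a tensor whose amax is 2.0
# gets scale = 448.0 / 2.0 = 224.0; the cast multiplies by this scale before rounding
# to fp8 precision.
amax = torch.tensor(2.0)
scale = torch.finfo(torch.float8_e4m3fn).max / torch.clamp(amax, min=1e-12)
print(scale)  # tensor(224.)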
test/test_fsdp2/test_fsdp2_common.py: 13 changes (12 additions, 1 deletion)

@@ -6,8 +6,12 @@
import torch
import torch.distributed as dist
import torch.nn as nn
from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
from float8_experimental.float8_linear import Float8Linear
-from float8_experimental.float8_linear_utils import sync_float8_amax_and_scale_history
+from float8_experimental.float8_linear_utils import (
+    precompute_float8_amax,
+    sync_float8_amax_and_scale_history,
+)


def check_parity_no_mp(
@@ -18,6 +22,7 @@ def check_parity_no_mp(
fsdp_optim: torch.optim.Optimizer,
local_inp: torch.Tensor,
module_cls: Type,
pre_compute: bool = False,
):
for iter_idx in range(10):
losses: List[torch.Tensor] = []
@@ -32,6 +37,12 @@ def check_parity_no_mp(
if module_cls is Float8Linear:
sync_float8_amax_and_scale_history(model)
optim.step()
if (
model is fsdp_model
and module_cls is Float8DynamicLinear
and pre_compute
):
precompute_float8_amax(model)
test_cls.assertEqual(losses[0], losses[1])


test/test_fsdp2/test_fsdp2_eager.py: 26 changes (22 additions, 4 deletions)

@@ -85,10 +85,21 @@ def world_size(self) -> int:

@skip_if_lt_x_gpu(2)
def test_transformer_parity_dynamic(self):
-        for enable_fsdp_fp8_all_gather in [False, True]:
-            self._test_transformer_parity_dynamic(enable_fsdp_fp8_all_gather)
+        self.run_subtests(
+            {
+                "enable_fsdp_fp8_all_gather": [False, True],
+                "pre_compute": [False, True],
+            },
+            self._test_transformer_parity_dynamic,
+        )

-    def _test_transformer_parity_dynamic(self, enable_fsdp_fp8_all_gather: bool):
+    def _test_transformer_parity_dynamic(
+        self,
+        enable_fsdp_fp8_all_gather: bool,
+        pre_compute: bool,
+    ):
+        if not enable_fsdp_fp8_all_gather and pre_compute:
+            return
# NOTE: Weight-tying does not compose with fp8 all-gather because the
# embedding weight and output linear weight are tied but only the
# latter uses fp8 compute. With fp8 all-gather, FSDP would pre-cast to
@@ -109,7 +120,14 @@ def _test_transformer_parity_dynamic(self, enable_fsdp_fp8_all_gather: bool):
0, ref_module.tok_embeddings.weight.size(0), (16, 16), device="cuda"
)
check_parity_no_mp(
-            self, ref_module, ref_optim, module, optim, local_inp, Float8DynamicLinear
+            self,
+            ref_module,
+            ref_optim,
+            module,
+            optim,
+            local_inp,
+            Float8DynamicLinear,
+            pre_compute,
)

@skip_if_lt_x_gpu(2)