This repository has been archived by the owner on Aug 7, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 20
[FSDP2] precompute scale after optimizer.step for dynamic scaling #266
Closed
Closed
Changes from all commits
Commits
Show all changes
32 commits
Select commit
Hold shift + click to select a range
9d5595c
[FSDP2] set vocab_size=32 to avoid must be divisible by 16 error
weifengpy e7005c2
precast after optimizer.step and dump profiler traces
weifengpy e41d589
Merge branch 'main' into fsdp2
weifengpy e0bee10
precast and preamax unit test
weifengpy c0ba5a2
remove duplicate vocab
weifengpy 8da238e
fused amax
weifengpy ffff5ed
Merge branch 'main' into fsdp2
weifengpy aefa21b
use FP8_TYPES and max
weifengpy d4a1db7
commit all changes before cleaning
weifengpy d36e79b
pre_compute and flatten / unflatten
weifengpy 6f244a2
remove unused constant
weifengpy dc5eab0
torch.compile works
weifengpy 546e979
eager ready
weifengpy 229ede6
linter
weifengpy d5b3ff6
linter
weifengpy 4f05e04
flatten tensor
weifengpy 3de59af
commit all changes for review before rebasing
weifengpy ffcd197
rebase on unified float8linear
weifengpy 6b18947
Merge branch 'pytorch-labs:main' into fsdp2
weifengpy 562424c
move precompute to fsdp_utils.py
weifengpy 75e0e45
simplify amax calc
weifengpy fe95f8b
explain _pre_computed_amax
weifengpy 1cbaa13
fix linter
weifengpy fe2e0a0
document precompute_float8_amax_for_fsdp
weifengpy e4eaa2a
rename pre_compute to precompute
weifengpy e4245e4
Merge branch 'main' into fsdp2
weifengpy e12c973
remove clamp_amax=True/False
weifengpy 9ef67fb
precompute scale
weifengpy fa2f08a
unit test for precomputing scales
weifengpy ba085e5
add precompute scale in README
weifengpy ac0afb0
rename to precompute_float8_dynamic_scale_for_fsdp
weifengpy 8e56dfc
rename to precompute_float8_dynamic_scale_for_fsdp
weifengpy File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
import math | ||
from typing import List | ||
|
||
import torch | ||
import torch.nn as nn | ||
from float8_experimental.float8_dynamic_utils import WeightWithDynamicFloat8CastTensor | ||
from float8_experimental.float8_linear import Float8Linear, TensorScalingType | ||
from float8_experimental.float8_utils import EPS | ||
|
||
|
||
@torch.no_grad() | ||
def precompute_float8_dynamic_scale_for_fsdp(module: nn.Module) -> None: | ||
""" | ||
Calculate scale dynamically for all float8 parameters. | ||
This should be run after the optimizer step. It performs a single all-reduce to compute the | ||
scales for all float8 weights. | ||
Example usage: | ||
model(input).sum().backward() | ||
optim.step() | ||
precompute_float8_dynamic_scale_for_fsdp(model) | ||
""" | ||
from torch.distributed._tensor import DTensor | ||
|
||
if any( | ||
weifengpy marked this conversation as resolved.
Show resolved
Hide resolved
|
||
isinstance(m, Float8Linear) and m.scaling_type_w is TensorScalingType.DELAYED | ||
for m in module.modules() | ||
): | ||
raise NotImplementedError("Only supports delayed scaling") | ||
float8_linears: List[Float8Linear] = [ | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. is this expensive for real models? if yes, maybe we can offer option to precompute this? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. My intuition is that this should be pretty fast as the number of There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
m | ||
for m in module.modules() | ||
if isinstance(m, Float8Linear) | ||
and isinstance(m.weight, DTensor) | ||
and isinstance(m.weight._local_tensor, WeightWithDynamicFloat8CastTensor) | ||
] | ||
weights: List[DTensor] = [float8_linear.weight for float8_linear in float8_linears] | ||
|
||
if not weights: | ||
return | ||
|
||
# inf-norm is equivalent to max(abs(w)) | ||
max_weights = torch._foreach_norm(weights, ord=math.inf) # Partial | ||
amax_tensor = torch.vstack(max_weights) # Partial | ||
# clamp is dispatched through DTensor | ||
# it will issue a single all-reduce | ||
amax_tensor = torch.clamp(amax_tensor, EPS) # Replicate | ||
scale_tensor = torch.finfo(torch.float8_e4m3fn).max / amax_tensor # Replicate | ||
if amax_tensor.dtype is torch.float16: | ||
scale_tensor = torch.clamp(scale_tensor, max=torch.finfo(torch.float16).max) | ||
scales = torch.split(scale_tensor, 1) # Replicate | ||
for scale, float8_linear in zip(scales, float8_linears): | ||
float8_linear.weight._local_tensor._precomputed_scale = scale._local_tensor |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
improve docstring with example API usage
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nice! can we add this to the README?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
just added API usage to README
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: maybe we can make sure
dynamic
is in the name, since this is specific to dynamic scaling?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
renaming to
precompute_float8_dynamic_scale_for_fsdp