switch float8 logic from Float8DynamicLinear to Float8Linear (pytorch#436)

Summary:

After pytorch-labs/float8_experimental#300,
`Float8Linear` with default settings is equivalent to
`Float8DynamicLinear`. This PR changes `torchtitan` to use
`Float8Linear`.

To better support the new UX of `float8_experimental`, I also switched the
`fp8_linear` configuration to a boolean that controls whether to swap the
linears. In the future we can add new options for configuring each linear
(scaling type, scaling granularity, etc.); that is saved for a future PR.
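
For illustration only, a minimal sketch of how the new boolean gates the swap end to end; `SimpleNamespace` stands in for torchtitan's real `JobConfig`, the toy model is invented for the example, and `float8_experimental` must be installed:

```python
# Hypothetical sketch: SimpleNamespace is a stand-in for torchtitan's JobConfig,
# used only to keep the example self-contained.
from types import SimpleNamespace

import torch.nn as nn

from torchtitan.float8_linear import build_fp8_linear

job_config = SimpleNamespace(training=SimpleNamespace(fp8_linear=True))
model = nn.Sequential(nn.Linear(64, 64), nn.Linear(64, 8))

# With fp8_linear=True every nn.Linear is swapped in place for Float8Linear with
# default settings (dynamic scaling); with fp8_linear=False the model is untouched.
build_fp8_linear(model, job_config)
print(model)
```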

Test Plan:

```
// run baseline (Float8DynamicLinear) for llama3_8b for 50 iterations on 4 GPUs,
// verify performance and loss values do not change meaningfully between
// baseline and this PR

// baseline (before this PR)
// 1. compile, bf16
// 2. compile, float8
// 3. compile, float8, fsdp_fp8_allgather=True
// 4. compile, float8, fsdp_fp8_allgather=True, tp=2
// logs: https://gist.github.com/vkuzo/e6d5f3b15349862bfad3706baad8c9ce

// experiment (this PR): repeat all of the above, but with Float8Linear
// logs: https://gist.github.com/vkuzo/a4d6754358facffa64df931654459631
```

Reviewers:

Subscribers:

Tasks:

Tags:
vkuzo authored Jul 8, 2024
1 parent 38187bc commit 49b02a6
Showing 8 changed files with 18 additions and 34 deletions.
12 changes: 4 additions & 8 deletions torchtitan/config_manager.py
```diff
@@ -339,15 +339,11 @@ def __init__(self):
         )
         self.parser.add_argument(
             "--training.fp8_linear",
-            type=str,
-            default="",
-            choices=[
-                "dynamic",
-                "",
-            ], # TODO: add "delayed" option back in when supported
+            action="store_true",
             help="""
-                Type of fp8 linear quantization to apply to the model ['', 'dynamic'].
-                This features requires you to install 'float8_experimental' which can be found
+                If true, swaps `torch.nn.Linear` with `Float8Linear` with
+                default settings (dynamic scaling).
+                This feature requires you to install 'float8_experimental' which can be found
                 here: https://github.com/pytorch-labs/float8_experimental
             """,
         )
```
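
As a standalone illustration of what the `store_true` change means for parsing (plain argparse, not torchtitan's full `JobConfig` plumbing), the flag is now present-or-absent instead of taking a string value; in the TOML configs below it correspondingly becomes `fp8_linear = false` / `true`:

```python
# Standalone sketch of the new flag's parsing behavior only.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--training.fp8_linear", action="store_true")

args = parser.parse_args([])  # flag omitted -> False
assert getattr(args, "training.fp8_linear") is False

args = parser.parse_args(["--training.fp8_linear"])  # flag present -> True
assert getattr(args, "training.fp8_linear") is True
```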
28 changes: 8 additions & 20 deletions torchtitan/float8_linear.py
```diff
@@ -21,34 +21,22 @@


 def build_fp8_linear(model: nn.Module, job_config: JobConfig):
     """
-    This function converts the linear layers to one of the fp8 types:
-    - Float8DynamicLinear: Dynamic quantization of the weights and the activations
-    - [Not Yet Supported] Float8Linear: Uses a history of amaxs to quantize the weights and activations
+    This function converts the linear layers to `Float8Linear`. Note that today,
+    only dynamic tensor scaling (the default) is supported.
     This will mutate the model inplace.
     """
-    linear_type = job_config.training.fp8_linear.lower()
+    use_fp8_linear = job_config.training.fp8_linear
     try:
-        from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
-
-        # from float8_experimental.float8_linear import Float8Linear
+        from float8_experimental.float8_linear import Float8Linear
         from float8_experimental.float8_linear_utils import (
             swap_linear_with_float8_linear,
         )
     except ImportError as exc:
         raise ImportError(
             "float8_experimental is not installed. Please install it to use fp8 linear layers."
         ) from exc
-    if linear_type:
-        linear_type_map = {
-            # "delayed": Float8Linear, # TODO: add "delayed" option back in when supported
-            "dynamic": Float8DynamicLinear,
-        }
-        assert (
-            linear_type in linear_type_map
-        ), f"Invalid fp8 linear type: {linear_type}, supported types: {', '.join(linear_type_map.keys())}."
-        float8_linear_type = linear_type_map[linear_type.lower()]
-
-        # Mutates the model inplace replacing instances of torch.nn.Linear with float8_linear_type
-        swap_linear_with_float8_linear(model, float8_linear_type)
-        logger.info(f"Swapped to {linear_type} float8 linear layers")
+    if use_fp8_linear:
+        # Mutates the model inplace replacing instances of torch.nn.Linear with Float8Linear
+        swap_linear_with_float8_linear(model, Float8Linear)
+        logger.info("Swapped to Float8Linear layers")
```
2 changes: 1 addition & 1 deletion train_configs/debug_model.toml
```diff
@@ -37,7 +37,7 @@ max_norm = 1.0 # grad norm clipping
 steps = 10
 data_parallel_degree = -1
 tensor_parallel_degree = 1
-fp8_linear = ""
+fp8_linear = false
 compile = false
 dataset = "c4_mini" # supported datasets: c4_mini (45K), c4 (177M)

```
2 changes: 1 addition & 1 deletion train_configs/llama2_13b.toml
```diff
@@ -33,7 +33,7 @@ max_norm = 1.0 # grad norm clipping
 steps = 1000
 data_parallel_degree = -1
 tensor_parallel_degree = 1
-fp8_linear = ""
+fp8_linear = false
 compile = false
 dataset = "c4"

```
2 changes: 1 addition & 1 deletion train_configs/llama2_70b.toml
```diff
@@ -33,7 +33,7 @@ max_norm = 1.0 # grad norm clipping
 steps = 1000
 data_parallel_degree = -1
 tensor_parallel_degree = 8 # 8-way TP
-fp8_linear = ""
+fp8_linear = false
 compile = false
 dataset = "c4"

```
2 changes: 1 addition & 1 deletion train_configs/llama2_7b.toml
```diff
@@ -32,7 +32,7 @@ max_norm = 1.0 # grad norm clipping
 steps = 1000
 data_parallel_degree = -1
 tensor_parallel_degree = 1 # dp-only would be sufficient for 7B
-fp8_linear = ""
+fp8_linear = false
 compile = false
 dataset = "c4"

```
2 changes: 1 addition & 1 deletion train_configs/llama3_70b.toml
```diff
@@ -33,7 +33,7 @@ max_norm = 1.0 # grad norm clipping
 steps = 1000
 data_parallel_degree = -1
 tensor_parallel_degree = 8 # 8-way TP
-fp8_linear = ""
+fp8_linear = false
 compile = false
 dataset = "c4"

```
2 changes: 1 addition & 1 deletion train_configs/llama3_8b.toml
```diff
@@ -33,7 +33,7 @@ max_norm = 1.0 # grad norm clipping
 steps = 1000
 data_parallel_degree = -1
 tensor_parallel_degree = 1
-fp8_linear = ""
+fp8_linear = false
 compile = false
 dataset = "c4"

```
