
switch float8 logic from Float8DynamicLinear to Float8Linear #436

Merged · 1 commit · Jul 8, 2024
12 changes: 4 additions & 8 deletions torchtitan/config_manager.py
@@ -339,15 +339,11 @@ def __init__(self):
)
self.parser.add_argument(
"--training.fp8_linear",
type=str,
default="",
choices=[
"dynamic",
"",
], # TODO: add "delayed" option back in when supported
action="store_true",
Contributor
I think we might want to switch this flag to be False by default so that it is an opt-in feature? The reason is that many users might still be using A100s, and there are not many H100-or-above GPUs on the market yet, so fp8 training would not be available to those users.

Contributor Author
I thought action="store_true" is off by default, lmk if I missed that.

Contributor
ohhh I think you are right, my bad

help="""
Type of fp8 linear quantization to apply to the model ['', 'dynamic'].
This features requires you to install 'float8_experimental' which can be found
If true, swaps `torch.nn.Linear` with `Float8Linear` with
default settings (dynamic scaling).
This feature requires you to install 'float8_experimental' which can be found
here: https://github.com/pytorch-labs/float8_experimental
""",
)
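As the review thread above concludes, an argparse flag declared with action="store_true" is False unless the flag is passed, so the fp8 swap stays opt-in. A minimal standalone sketch (not torchtitan's actual parser) illustrating that behavior:

```python
import argparse

# Minimal sketch, not torchtitan's config_manager: a flag declared with
# action="store_true" defaults to False and only becomes True when passed.
parser = argparse.ArgumentParser()
parser.add_argument("--training.fp8_linear", action="store_true")

# No flag passed -> False (fp8 swap is opt-in)
args = parser.parse_args([])
# The dotted dest is not a valid attribute identifier, so read it with getattr.
print(getattr(args, "training.fp8_linear"))  # False

# Flag passed -> True
args = parser.parse_args(["--training.fp8_linear"])
print(getattr(args, "training.fp8_linear"))  # True
```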
28 changes: 8 additions & 20 deletions torchtitan/float8_linear.py
@@ -21,34 +21,22 @@

def build_fp8_linear(model: nn.Module, job_config: JobConfig):
"""
This function converts the linear layers to one of the fp8 types:
- Float8DynamicLinear: Dynamic quantization of the weights and the activations
- [Not Yet Supported] Float8Linear: Uses a history of amaxs to quantize the weights and activations
This function converts the linear layers to `Float8Linear`. Note that today,
only dynamic tensor scaling (the default) is supported.

This will mutate the model inplace.
"""
linear_type = job_config.training.fp8_linear.lower()
use_fp8_linear = job_config.training.fp8_linear
try:
from float8_experimental.float8_dynamic_linear import Float8DynamicLinear

# from float8_experimental.float8_linear import Float8Linear
from float8_experimental.float8_linear import Float8Linear
from float8_experimental.float8_linear_utils import (
swap_linear_with_float8_linear,
)
except ImportError as exc:
raise ImportError(
"float8_experimental is not installed. Please install it to use fp8 linear layers."
) from exc
if linear_type:
linear_type_map = {
# "delayed": Float8Linear, # TODO: add "delayed" option back in when supported
"dynamic": Float8DynamicLinear,
}
assert (
linear_type in linear_type_map
), f"Invalid fp8 linear type: {linear_type}, supported types: {', '.join(linear_type_map.keys())}."
float8_linear_type = linear_type_map[linear_type.lower()]

# Mutates the model inplace replacing instances of torch.nn.Linear with float8_linear_type
swap_linear_with_float8_linear(model, float8_linear_type)
logger.info(f"Swapped to {linear_type} float8 linear layers")
if use_fp8_linear:
# Mutates the model inplace replacing instances of torch.nn.Linear with Float8Linear
swap_linear_with_float8_linear(model, Float8Linear)
logger.info("Swapped to Float8Linear layers")
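For readability, here is how `build_fp8_linear` reads once the added lines above are applied (a reassembled view of the diff; `nn`, `JobConfig`, and `logger` come from the module's existing imports, which are omitted here):

```python
def build_fp8_linear(model: nn.Module, job_config: JobConfig):
    """
    This function converts the linear layers to `Float8Linear`. Note that today,
    only dynamic tensor scaling (the default) is supported.

    This will mutate the model inplace.
    """
    use_fp8_linear = job_config.training.fp8_linear
    try:
        from float8_experimental.float8_linear import Float8Linear
        from float8_experimental.float8_linear_utils import (
            swap_linear_with_float8_linear,
        )
    except ImportError as exc:
        raise ImportError(
            "float8_experimental is not installed. Please install it to use fp8 linear layers."
        ) from exc
    if use_fp8_linear:
        # Mutates the model inplace, replacing instances of torch.nn.Linear with Float8Linear
        swap_linear_with_float8_linear(model, Float8Linear)
        logger.info("Swapped to Float8Linear layers")
```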
2 changes: 1 addition & 1 deletion train_configs/debug_model.toml
@@ -37,7 +37,7 @@ max_norm = 1.0 # grad norm clipping
steps = 10
data_parallel_degree = -1
tensor_parallel_degree = 1
fp8_linear = ""
fp8_linear = false
compile = false
dataset = "c4_mini" # supported datasets: c4_mini (45K), c4 (177M)

2 changes: 1 addition & 1 deletion train_configs/llama2_13b.toml
@@ -33,7 +33,7 @@ max_norm = 1.0 # grad norm clipping
steps = 1000
data_parallel_degree = -1
tensor_parallel_degree = 1
fp8_linear = ""
fp8_linear = false
compile = false
dataset = "c4"

2 changes: 1 addition & 1 deletion train_configs/llama2_70b.toml
@@ -33,7 +33,7 @@ max_norm = 1.0 # grad norm clipping
steps = 1000
data_parallel_degree = -1
tensor_parallel_degree = 8 # 8-way TP
fp8_linear = ""
fp8_linear = false
compile = false
dataset = "c4"

2 changes: 1 addition & 1 deletion train_configs/llama2_7b.toml
@@ -32,7 +32,7 @@ max_norm = 1.0 # grad norm clipping
steps = 1000
data_parallel_degree = -1
tensor_parallel_degree = 1 # dp-only would be sufficient for 7B
fp8_linear = ""
fp8_linear = false
compile = false
dataset = "c4"

2 changes: 1 addition & 1 deletion train_configs/llama3_70b.toml
@@ -33,7 +33,7 @@ max_norm = 1.0 # grad norm clipping
steps = 1000
data_parallel_degree = -1
tensor_parallel_degree = 8 # 8-way TP
fp8_linear = ""
fp8_linear = false
compile = false
dataset = "c4"

2 changes: 1 addition & 1 deletion train_configs/llama3_8b.toml
@@ -33,7 +33,7 @@ max_norm = 1.0 # grad norm clipping
steps = 1000
data_parallel_degree = -1
tensor_parallel_degree = 1
fp8_linear = ""
fp8_linear = false
compile = false
dataset = "c4"
