diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py
index 1e0d7c60..86bcffd8 100644
--- a/torchtitan/config_manager.py
+++ b/torchtitan/config_manager.py
@@ -339,15 +339,11 @@ def __init__(self):
         )
         self.parser.add_argument(
             "--training.fp8_linear",
-            type=str,
-            default="",
-            choices=[
-                "dynamic",
-                "",
-            ], # TODO: add "delayed" option back in when supported
+            action="store_true",
             help="""
-                Type of fp8 linear quantization to apply to the model ['', 'dynamic'].
-                This features requires you to install 'float8_experimental' which can be found
+                If true, swaps `torch.nn.Linear` with `Float8Linear` with
+                default settings (dynamic scaling).
+                This feature requires you to install 'float8_experimental' which can be found
                 here: https://github.com/pytorch-labs/float8_experimental
             """,
         )
diff --git a/torchtitan/float8_linear.py b/torchtitan/float8_linear.py
index 9bd88cae..0bd0900c 100644
--- a/torchtitan/float8_linear.py
+++ b/torchtitan/float8_linear.py
@@ -21,17 +21,14 @@
 
 def build_fp8_linear(model: nn.Module, job_config: JobConfig):
     """
-    This function converts the linear layers to one of the fp8 types:
-    - Float8DynamicLinear: Dynamic quantization of the weights and the activations
-    - [Not Yet Supported] Float8Linear: Uses a history of amaxs to quantize the weights and activations
+    This function converts the linear layers to `Float8Linear`. Note that today,
+    only dynamic tensor scaling (the default) is supported.
 
     This will mutate the model inplace.
     """
-    linear_type = job_config.training.fp8_linear.lower()
+    use_fp8_linear = job_config.training.fp8_linear
     try:
-        from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
-
-        # from float8_experimental.float8_linear import Float8Linear
+        from float8_experimental.float8_linear import Float8Linear
         from float8_experimental.float8_linear_utils import (
             swap_linear_with_float8_linear,
         )
@@ -39,16 +36,7 @@ def build_fp8_linear(model: nn.Module, job_config: JobConfig):
         raise ImportError(
             "float8_experimental is not installed. Please install it to use fp8 linear layers."
         ) from exc
-    if linear_type:
-        linear_type_map = {
-            # "delayed": Float8Linear, # TODO: add "delayed" option back in when supported
-            "dynamic": Float8DynamicLinear,
-        }
-        assert (
-            linear_type in linear_type_map
-        ), f"Invalid fp8 linear type: {linear_type}, supported types: {', '.join(linear_type_map.keys())}."
-        float8_linear_type = linear_type_map[linear_type.lower()]
-
-        # Mutates the model inplace replacing instances of torch.nn.Linear with float8_linear_type
-        swap_linear_with_float8_linear(model, float8_linear_type)
-        logger.info(f"Swapped to {linear_type} float8 linear layers")
+    if use_fp8_linear:
+        # Mutates the model inplace replacing instances of torch.nn.Linear with Float8Linear
+        swap_linear_with_float8_linear(model, Float8Linear)
+        logger.info("Swapped to Float8Linear layers")
diff --git a/train_configs/debug_model.toml b/train_configs/debug_model.toml
index da634031..1b4b3539 100644
--- a/train_configs/debug_model.toml
+++ b/train_configs/debug_model.toml
@@ -37,7 +37,7 @@ max_norm = 1.0 # grad norm clipping
 steps = 10
 data_parallel_degree = -1
 tensor_parallel_degree = 1
-fp8_linear = ""
+fp8_linear = false
 compile = false
 dataset = "c4_mini" # supported datasets: c4_mini (45K), c4 (177M)
 
diff --git a/train_configs/llama2_13b.toml b/train_configs/llama2_13b.toml
index f3048ac4..719fc445 100644
--- a/train_configs/llama2_13b.toml
+++ b/train_configs/llama2_13b.toml
@@ -33,7 +33,7 @@ max_norm = 1.0 # grad norm clipping
 steps = 1000
 data_parallel_degree = -1
 tensor_parallel_degree = 1
-fp8_linear = ""
+fp8_linear = false
 compile = false
 dataset = "c4"
 
diff --git a/train_configs/llama2_70b.toml b/train_configs/llama2_70b.toml
index 97b1bc71..c8ec9595 100644
--- a/train_configs/llama2_70b.toml
+++ b/train_configs/llama2_70b.toml
@@ -33,7 +33,7 @@ max_norm = 1.0 # grad norm clipping
 steps = 1000
 data_parallel_degree = -1
 tensor_parallel_degree = 8 # 8-way TP
-fp8_linear = ""
+fp8_linear = false
 compile = false
 dataset = "c4"
 
diff --git a/train_configs/llama2_7b.toml b/train_configs/llama2_7b.toml
index 95b4c496..7e2196fb 100644
--- a/train_configs/llama2_7b.toml
+++ b/train_configs/llama2_7b.toml
@@ -32,7 +32,7 @@ max_norm = 1.0 # grad norm clipping
 steps = 1000
 data_parallel_degree = -1
 tensor_parallel_degree = 1 # dp-only would be sufficient for 7B
-fp8_linear = ""
+fp8_linear = false
 compile = false
 dataset = "c4"
 
diff --git a/train_configs/llama3_70b.toml b/train_configs/llama3_70b.toml
index d498e677..218f3783 100644
--- a/train_configs/llama3_70b.toml
+++ b/train_configs/llama3_70b.toml
@@ -33,7 +33,7 @@ max_norm = 1.0 # grad norm clipping
 steps = 1000
 data_parallel_degree = -1
 tensor_parallel_degree = 8 # 8-way TP
-fp8_linear = ""
+fp8_linear = false
 compile = false
 dataset = "c4"
 
diff --git a/train_configs/llama3_8b.toml b/train_configs/llama3_8b.toml
index f194addb..2fb89004 100644
--- a/train_configs/llama3_8b.toml
+++ b/train_configs/llama3_8b.toml
@@ -33,7 +33,7 @@ max_norm = 1.0 # grad norm clipping
 steps = 1000
 data_parallel_degree = -1
 tensor_parallel_degree = 1
-fp8_linear = ""
+fp8_linear = false
 compile = false
 dataset = "c4"
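For reference, a minimal usage sketch of the new boolean flag, assuming torchtitan and float8_experimental are installed. Only build_fp8_linear, JobConfig, and job_config.training.fp8_linear come from the diff above; the toy model and the JobConfig.parse_args call pattern are assumptions about the surrounding training script.

# Hypothetical driver code; not part of this PR.
import torch.nn as nn

from torchtitan.config_manager import JobConfig
from torchtitan.float8_linear import build_fp8_linear

# Toy model standing in for the real Transformer; any module with nn.Linear children works.
model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 16))

job_config = JobConfig()
# The flag is now a store_true boolean: pass `--training.fp8_linear` on the CLI
# or set `fp8_linear = true` under [training] in the TOML config.
job_config.parse_args(["--training.fp8_linear"])

# build_fp8_linear reads job_config.training.fp8_linear and, when set, swaps
# nn.Linear modules for Float8Linear in place (dynamic scaling defaults).
build_fp8_linear(model, job_config)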