From 239d56fc1468106aa411f542ec809e4198fdc5c6 Mon Sep 17 00:00:00 2001
From: Vasiliy Kuznetsov
Date: Tue, 30 Jul 2024 11:57:37 -0700
Subject: [PATCH] move float8 callsites to torchao.float8 (#492)

Summary:

The `float8_experimental` repository moved to `torchao.float8` in
https://github.com/pytorch/ao/pull/551

This PR updates `torchtitan` to use float8 from the new location.

Test Plan:

```
with-proxy CONFIG_FILE="./train_configs/debug_model.toml" ./run_llama_train.sh --training.enable_float8_linear --training.compile
```

Reviewers:

Subscribers:

Tasks:

Tags:
---
 .github/workflows/integration_test_4gpu.yaml |  2 +-
 torchtitan/config_manager.py                 |  4 ++--
 torchtitan/float8_linear.py                  | 16 ++++++++--------
 torchtitan/parallelisms/parallelize_llama.py |  4 +++-
 4 files changed, 14 insertions(+), 12 deletions(-)

diff --git a/.github/workflows/integration_test_4gpu.yaml b/.github/workflows/integration_test_4gpu.yaml
index 7c913b07..813e11af 100644
--- a/.github/workflows/integration_test_4gpu.yaml
+++ b/.github/workflows/integration_test_4gpu.yaml
@@ -39,6 +39,6 @@ jobs:
           python -m pip install --force-reinstall --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121
           python -m pip install --pre torchdata --index-url https://download.pytorch.org/whl/nightly/
-          python -m pip install git+https://github.com/pytorch-labs/float8_experimental.git
+          USE_CPP=0 python -m pip install git+https://github.com/pytorch/ao.git
           mkdir artifacts-to-be-uploaded
           python ./test_runner.py artifacts-to-be-uploaded --ngpu 4
diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py
index 26570ec7..33070120 100644
--- a/torchtitan/config_manager.py
+++ b/torchtitan/config_manager.py
@@ -353,8 +353,8 @@ def __init__(self):
             action="store_true",
             help="""
             If true, swaps `torch.nn.Linear` with `Float8Linear`.
-            This feature requires you to install 'float8_experimental' which can be found
-            here: https://github.com/pytorch-labs/float8_experimental
+            This feature requires you to install 'torchao' which can be found
+            here: https://github.com/pytorch/ao
             """,
         )
         self.parser.add_argument(
diff --git a/torchtitan/float8_linear.py b/torchtitan/float8_linear.py
index 1651585e..658a41cc 100644
--- a/torchtitan/float8_linear.py
+++ b/torchtitan/float8_linear.py
@@ -4,11 +4,11 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

-# [Note] Getting the 'float8_experimental' package:
-# This script requires the 'float8_experimental' package to function correctly.
+# [Note] Getting the 'torchao' package:
+# This script requires the 'torchao' package to function correctly.
 # Please ensure you have this package installed from the appropriate repository.
-# You can obtain it from https://github.com/pytorch-labs/float8_experimental.
-# Either clone and run `pip install .` or run `pip install git+https://github.com/pytorch-labs/float8_experimental.git`
+# You can obtain it from https://github.com/pytorch/ao by following the
+# installation instructions.

 # Note: Performance
 # Float8 experimental is intended to be ran under `torch.compile`` for competitive performance
@@ -48,7 +48,7 @@ def maybe_build_fp8_linear(
         )
         return
     try:
-        from float8_experimental import (
+        from torchao.float8 import (
             CastConfig,
             convert_to_float8_training,
             Float8LinearConfig,
@@ -83,7 +83,7 @@ def maybe_build_fp8_linear(
         )
     except ImportError as exc:
         raise ImportError(
-            "float8_experimental is not installed. Please install it to use fp8 linear layers."
+            "torchao is not installed. Please install it to use fp8 linear layers."
         ) from exc
@@ -102,7 +102,7 @@ def maybe_precompute_fp8_dynamic_scale_for_fsdp(
             "Skipped precomputing fp8 scales because SM90 or later is not available",
         )
         return
-    from float8_experimental import precompute_float8_dynamic_scale_for_fsdp
+    from torchao.float8 import precompute_float8_dynamic_scale_for_fsdp

     precompute_float8_dynamic_scale_for_fsdp(model)
@@ -121,7 +121,7 @@ def maybe_sync_float8_amax_and_scale_history(model: nn.Module, job_config: JobCo
     ):
         return

-    from float8_experimental import sync_float8_amax_and_scale_history
+    from torchao.float8 import sync_float8_amax_and_scale_history

     # TODO(future): see if precalculating the modules to sync over is going to
     # meaningfully help performance
diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py
index 3d123953..e3c6fc80 100644
--- a/torchtitan/parallelisms/parallelize_llama.py
+++ b/torchtitan/parallelisms/parallelize_llama.py
@@ -129,7 +129,9 @@ def get_tp_parallel_strategy_for_transformer_block(
     # TODO(future PR): once float8 configuration supports delayed
     # scaling, add a check here to enforce supported float8 all-gather
     # configurations
-    from float8_experimental.float8_tensor_parallel import (
+    # TODO(future PR): add the items below to __init__.py of torchao.float8,
+    # and import from there
+    from torchao.float8.float8_tensor_parallel import (
         Float8ColwiseParallel,
         Float8RowwiseParallel,
         PrepareFloat8ModuleInput,