From 71b8eaecf7a9f66376fc74b693d09d4d839361c9 Mon Sep 17 00:00:00 2001
From: "Wei (Will) Feng" <134637289+weifengpy@users.noreply.github.com>
Date: Thu, 18 Jul 2024 19:17:30 -0700
Subject: [PATCH] add torch.compile + FSDP2 float8 all-gather in CI (#468)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Fixed my bug in float8_experimental
(https://github.com/pytorch-labs/float8_experimental/pull/321). Now we can
torch.compile transformer blocks with FSDP float8 all-gather.

Local test:
`CONFIG_FILE="./train_configs/debug_model.toml" ./run_llama_train.sh --training.enable_float8_linear --training.enable_fsdp_float8_all_gather --training.precompute_float8_dynamic_scale_for_fsdp --training.compile`

Profiler traces: I can see the compiled region in the CPU thread and the
float8 matmul `sm90_xmma_gemm_e4m3bf16...` in the CUDA stream.

[Screenshot 2024-07-18 at 4.22.17 PM: profiler trace]
---
 test_runner.py | 12 ++++++++++++
 1 file changed, 12 insertions(+)

diff --git a/test_runner.py b/test_runner.py
index 6a7b6b1a..b0eb04c9 100755
--- a/test_runner.py
+++ b/test_runner.py
@@ -305,6 +305,18 @@ def build_test_list():
             "FSDP2 with float8 all-gather and precomputed dynamic scales",
             "fsdp2_float8_all_gather_precompute_dynamic_scales",
         ),
+        OverrideDefinitions(
+            [
+                [
+                    "--training.enable_float8_linear",
+                    "--training.enable_fsdp_float8_all_gather",
+                    "--training.precompute_float8_dynamic_scale_for_fsdp",
+                    "--training.compile",
+                ]
+            ],
+            "FSDP2 with float8 all-gather and precomputed dynamic scales",
+            "fsdp2_float8_all_gather_precompute_dynamic_scales_compile",
+        ),
         OverrideDefinitions(
             [
                 [
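
For readers less familiar with test_runner.py, here is a minimal sketch of what the new CI entry amounts to, assuming the test runner expands each inner list of override flags into a `run_llama_train.sh` invocation (the actual plumbing lives in test_runner.py and is not shown in this diff); the flags and config file mirror the local test command above.

```python
# Hypothetical sketch only -- not the actual test_runner.py logic.
# It shows how the override flags from the new OverrideDefinitions entry
# could be turned into the same command used for the local repro above.
import os
import subprocess

override_args = [
    "--training.enable_float8_linear",
    "--training.enable_fsdp_float8_all_gather",
    "--training.precompute_float8_dynamic_scale_for_fsdp",
    "--training.compile",
]

# Point the launcher at the debug model config, as in the local test command.
env = {**os.environ, "CONFIG_FILE": "./train_configs/debug_model.toml"}

# Launch the training script with the float8 + compile overrides appended.
subprocess.run(["./run_llama_train.sh", *override_args], env=env, check=True)
```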