add torch.compile + FSDP2 float8 all-gather in CI (pytorch#468)
fixed my bug in float8_experimental. now we can torch.compile
transformer blocks with FSDP float8 all-gather
pytorch-labs/float8_experimental#321

local test: `CONFIG_FILE="./train_configs/debug_model.toml"
./run_llama_train.sh --training.enable_float8_linear
--training.enable_fsdp_float8_all_gather
--training.precompute_float8_dynamic_scale_for_fsdp --training.compile`

profiler traces: I can see the compiled region in the CPU thread and the float8
matmul `sm90_xmma_gemm_e4m3bf16...` in the CUDA stream
<img width="1468" alt="Screenshot 2024-07-18 at 4 22 17 PM" src="https://github.com/user-attachments/assets/0cf58dee-aae1-4582-a3f1-b8aa48b45129">
weifengpy authored Jul 19, 2024
1 parent 2f989b9 commit 71b8eae
Showing 1 changed file with 12 additions and 0 deletions.
test_runner.py (+12 −0)
@@ -305,6 +305,18 @@ def build_test_list():
             "FSDP2 with float8 all-gather and precomputed dynamic scales",
             "fsdp2_float8_all_gather_precompute_dynamic_scales",
         ),
+        OverrideDefinitions(
+            [
+                [
+                    "--training.enable_float8_linear",
+                    "--training.enable_fsdp_float8_all_gather",
+                    "--training.precompute_float8_dynamic_scale_for_fsdp",
+                    "--training.compile",
+                ]
+            ],
+            "FSDP2 with float8 all-gather and precomputed dynamic scales",
+            "fsdp2_float8_all_gather_precompute_dynamic_scales_compile",
+        ),
         OverrideDefinitions(
             [
                 [
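For context, the diff only adds the test entry; the runner code that consumes it is not shown here. Below is a minimal sketch, assuming `OverrideDefinitions` holds the positional fields seen in the entry above (override argument groups, a description, and a test name), of how such an entry could be expanded into a launch command like the local test command in the commit message. The `build_command` helper and the field names are assumptions for illustration, not the actual test_runner.py implementation.

```python
# Minimal sketch (assumed, not from this diff): expanding an OverrideDefinitions
# entry into a launch command, mirroring the local test in the commit message.
# Field names and build_command are illustrative; see test_runner.py for the
# real implementation.
from dataclasses import dataclass, field
from typing import List


@dataclass
class OverrideDefinitions:
    # Positional layout inferred from the entry added above.
    override_args: List[List[str]] = field(default_factory=list)
    test_descr: str = "default"
    test_name: str = "default"


def build_command(test: OverrideDefinitions, config_file: str) -> str:
    # One shell command per override group; only the first group is shown here.
    overrides = " ".join(test.override_args[0]) if test.override_args else ""
    return f'CONFIG_FILE="{config_file}" ./run_llama_train.sh {overrides}'


entry = OverrideDefinitions(
    [
        [
            "--training.enable_float8_linear",
            "--training.enable_fsdp_float8_all_gather",
            "--training.precompute_float8_dynamic_scale_for_fsdp",
            "--training.compile",
        ]
    ],
    "FSDP2 with float8 all-gather and precomputed dynamic scales",
    "fsdp2_float8_all_gather_precompute_dynamic_scales_compile",
)
print(build_command(entry, "./train_configs/debug_model.toml"))
```

Running this prints a command equivalent to the local test invocation quoted in the commit message.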
