[torchtitan][optim] Add fused as an option in train config #355
Conversation
Compared commits 11826b1 to e8461b4.
I am curious if we have any experiments to see the performance difference with fused enabled.
Gonna talk to Tianyu to learn how to run the perf experiments on the new 128 GPUs. We can totally wait for the result before landing this.
Can we add some 8-GPU numbers at least? 128-GPU runs can be done separately.
@weifengpy Ah, got you. I checked the trace and the 2000 ms indeed comes from foreach=False.
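For context, the exchange above contrasts PyTorch's two multi-tensor optimizer paths: `foreach` (the default horizontally fused multi-tensor loop) and `fused` (a single fused kernel per step). A minimal sketch of how a config flag could select between them; `build_optimizer_kwargs` and the plain config dict are hypothetical illustrations, not torchtitan's actual API:

```python
# Hypothetical helper (not torchtitan's actual API): translate a train-config
# "fused" flag into keyword arguments for torch.optim.Adam / AdamW, which
# accept mutually exclusive `fused` and `foreach` implementation toggles.
def build_optimizer_kwargs(config: dict) -> dict:
    kwargs = {"lr": config.get("lr", 3e-4)}
    if config.get("fused", False):
        kwargs["fused"] = True    # single fused CUDA kernel per optimizer step
    else:
        kwargs["foreach"] = True  # multi-tensor (horizontally fused) path
    return kwargs

# Usage: torch.optim.AdamW(model.parameters(), **build_optimizer_kwargs(cfg))
print(build_optimizer_kwargs({"fused": True}))
```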
With these three PRs landed, we can now support the option `fused = True` in torchtitan for the Adam and AdamW optimizers:
- pytorch/pytorch#125369
- pytorch/pytorch#126423
- pytorch/pytorch#126750

Ran a performance evaluation on an 8×A100 DevGPU: 1000 steps of 1D DP with the default [llama3_8b.toml](https://github.com/pytorch/torchtitan/blob/main/train_configs/llama3_8b.toml).

Observation: with `fused = True` and `fused = False` we observed a similar loss curve and memory usage; wps is ~100 higher and mfu is ~1.5-2% higher when `fused = True`. Below are the logs for the last 100 steps of both runs.

**Fused = False**
```
[rank0]:2024-06-05 12:45:06,227 - root - INFO - Finished dumping traces in 0.37 seconds
[rank0]:2024-06-05 12:45:37,677 - root - INFO - step: 910 loss: 4.6039 memory: 59.48GiB(75.15%) wps: 2,217 mfu: 41.16%
[rank0]:2024-06-05 12:46:08,843 - root - INFO - step: 920 loss: 4.6427 memory: 59.48GiB(75.15%) wps: 2,632 mfu: 48.85%
[rank0]:2024-06-05 12:46:40,052 - root - INFO - step: 930 loss: 4.6339 memory: 59.48GiB(75.15%) wps: 2,628 mfu: 48.78%
[rank0]:2024-06-05 12:47:11,243 - root - INFO - step: 940 loss: 4.5964 memory: 59.48GiB(75.15%) wps: 2,631 mfu: 48.84%
[rank0]:2024-06-05 12:47:42,655 - root - INFO - step: 950 loss: 4.6477 memory: 59.48GiB(75.15%) wps: 2,611 mfu: 48.47%
[rank0]:2024-06-05 12:48:13,890 - root - INFO - step: 960 loss: 4.8137 memory: 59.48GiB(75.15%) wps: 2,626 mfu: 48.75%
[rank0]:2024-06-05 12:48:45,110 - root - INFO - step: 970 loss: 4.5962 memory: 59.48GiB(75.15%) wps: 2,628 mfu: 48.78%
[rank0]:2024-06-05 12:49:16,333 - root - INFO - step: 980 loss: 4.5450 memory: 59.48GiB(75.15%) wps: 2,627 mfu: 48.76%
[rank0]:2024-06-05 12:49:47,561 - root - INFO - step: 990 loss: 4.5840 memory: 59.48GiB(75.15%) wps: 2,627 mfu: 48.76%
[rank0]:2024-06-05 12:50:18,933 - root - INFO - step: 1000 loss: 4.5351 memory: 59.48GiB(75.15%) wps: 2,615 mfu: 48.53%
[rank0]:2024-06-05 12:50:23,692 - root - INFO - Dumping traces at step 1000
[rank0]:2024-06-05 12:50:24,041 - root - INFO - Finished dumping traces in 0.35 seconds
[rank0]:2024-06-05 12:50:24,422 - root - INFO - Sleeping 2 seconds for other ranks to complete
[rank0]:2024-06-05 12:50:26,424 - root - INFO - Training completed
```

**Fused = True**
```
[rank0]:2024-06-05 14:55:42,894 - root - INFO - Finished dumping traces in 0.30 seconds
[rank0]:2024-06-05 14:56:13,582 - root - INFO - step: 910 loss: 4.6091 memory: 59.48GiB(75.15%) wps: 2,341 mfu: 43.46%
[rank0]:2024-06-05 14:56:43,765 - root - INFO - step: 920 loss: 4.6468 memory: 59.48GiB(75.15%) wps: 2,718 mfu: 50.45%
[rank0]:2024-06-05 14:57:13,971 - root - INFO - step: 930 loss: 4.6365 memory: 59.48GiB(75.15%) wps: 2,715 mfu: 50.40%
[rank0]:2024-06-05 14:57:44,172 - root - INFO - step: 940 loss: 4.6021 memory: 59.48GiB(75.15%) wps: 2,716 mfu: 50.41%
[rank0]:2024-06-05 14:58:14,353 - root - INFO - step: 950 loss: 4.6522 memory: 59.48GiB(75.15%) wps: 2,718 mfu: 50.45%
[rank0]:2024-06-05 14:58:44,536 - root - INFO - step: 960 loss: 4.8163 memory: 59.48GiB(75.15%) wps: 2,717 mfu: 50.44%
[rank0]:2024-06-05 14:59:14,683 - root - INFO - step: 970 loss: 4.6026 memory: 59.48GiB(75.15%) wps: 2,721 mfu: 50.51%
[rank0]:2024-06-05 14:59:44,840 - root - INFO - step: 980 loss: 4.5491 memory: 59.48GiB(75.15%) wps: 2,720 mfu: 50.49%
[rank0]:2024-06-05 15:00:15,009 - root - INFO - step: 990 loss: 4.5859 memory: 59.48GiB(75.15%) wps: 2,719 mfu: 50.47%
[rank0]:2024-06-05 15:00:45,228 - root - INFO - step: 1000 loss: 4.5396 memory: 59.48GiB(75.15%) wps: 2,714 mfu: 50.38%
[rank0]:2024-06-05 15:00:49,455 - root - INFO - Dumping traces at step 1000
[rank0]:2024-06-05 15:00:49,756 - root - INFO - Finished dumping traces in 0.30 seconds
[rank0]:2024-06-05 15:00:50,336 - root - INFO - Sleeping 2 seconds for other ranks to complete
[rank0]:2024-06-05 15:00:52,339 - root - INFO - Training completed
```
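With the PRs above landed, enabling the fused path should come down to flipping one option in the train config. A sketch of what the relevant TOML section might look like; the section and key names here are assumptions for illustration, not copied from torchtitan's actual config schema:

```toml
# Hypothetical train-config fragment: opt into the fused Adam/AdamW kernel.
[optimizer]
name = "AdamW"
lr = 3e-4
fused = true   # use the fused CUDA optimizer kernel instead of the foreach path
```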