[DTensor][Optim] Add support for fused_adam and fused_adamw when lr is a tensor #126750
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/126750
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 3287eeb with merge base 980f5ac.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Force-pushed 7689f12 to cc1d023.
Force-pushed cc1d023 to 82d79b1.
Thanks for the quick fix! Please see inlined comments.
@@ -297,6 +297,41 @@ def unwrap_to_op_info(
    local_kwargs: Dict[str, object] = {}
    mesh: Optional[DeviceMesh] = None

def get_replicate_spec(
nit: name this `try_get_replicate_spec`
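For illustration only (the function body and types below are hypothetical, not from the PR): the `try_get_*` naming convention usually signals that the helper returns `None` instead of raising when it cannot produce a replicate spec, so call sites can branch on the result.

```python
from typing import Optional

def try_get_replicate_spec(specs: dict, key: str) -> Optional[str]:
    """Hypothetical sketch of the try_get_* convention: return the
    spec if one exists, None otherwise (never raise)."""
    return specs.get(key)

# Callers branch on the result instead of catching an exception:
spec = try_get_replicate_spec({"weight": "Replicate()"}, "bias")
if spec is None:
    spec = "Shard(0)"  # fall back to some default placement
```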
{**config, "lr": torch.tensor(0.1)}
for config in fused_adam_float_lr_configs
]
fused_adam_tensor_lr_configs.extend(
I don't quite understand why you need to extend the list again? L121-124 did the same thing?
Oh! The only difference is that one lr is `torch.tensor(0.1)` and the other one is `torch.tensor([0.1])`.
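The distinction in a minimal snippet: `torch.tensor(0.1)` is a 0-dim (scalar) tensor while `torch.tensor([0.1])` is a 1-dim tensor with a single element. Both carry the same value, but they have different shapes and can exercise different code paths in the optimizer.

```python
import torch

scalar_lr = torch.tensor(0.1)    # 0-dim scalar tensor
vector_lr = torch.tensor([0.1])  # 1-dim tensor with one element

# Same value, different dimensionality and shape:
assert scalar_lr.dim() == 0 and scalar_lr.shape == torch.Size([])
assert vector_lr.dim() == 1 and vector_lr.shape == torch.Size([1])
assert scalar_lr.item() == vector_lr.item()
```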
{**config, "lr": torch.tensor(0.1)}
for config in fused_adamw_float_lr_configs
]
fused_adamw_tensor_lr_configs.extend(
ditto, what's the rationale of this `extend`?
Force-pushed 82d79b1 to c55eaf6.
@pytorchmergebot merge

Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Merge failed. Reason: 11 jobs have failed; the first few are: pull / linux-focal-py3.12-clang10 / test (default, 3, 3, linux.2xlarge), pull / linux-focal-py3.8-clang10 / test (default, 3, 3, linux.2xlarge), pull / linux-focal-cuda12.1-py3.10-gcc9-sm86 / test (default, 3, 5, linux.g5.4xlarge.nvidia.gpu), pull / linux-jammy-py3.8-gcc11 / test (default, 3, 3, linux.2xlarge), pull / linux-jammy-py3.10-clang15-asan / test (default, 3, 6, linux.4xlarge). Details for Dev Infra team: raised by workflow job.
Failures for public bindings are related to
@pytorchbot merge -i
Merge failed. Reason: 11 jobs have failed; the first few are: pull / linux-focal-py3.12-clang10 / test (default, 3, 3, linux.2xlarge), pull / linux-focal-py3.8-clang10 / test (default, 3, 3, linux.2xlarge), pull / linux-focal-cuda12.1-py3.10-gcc9-sm86 / test (default, 3, 5, linux.g5.4xlarge.nvidia.gpu), pull / linux-jammy-py3.8-gcc11 / test (default, 3, 3, linux.2xlarge), pull / linux-jammy-py3.10-clang15-asan / test (default, 3, 6, linux.4xlarge). Details for Dev Infra team: raised by workflow job.
@pytorchbot rebase -s

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict.

Successfully rebased.
Force-pushed c55eaf6 to 3287eeb.
@pytorchmergebot merge

Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
With these three PRs landed, we can now support the option `fused=True` in torchtitan for the Adam and AdamW optimizers:

- pytorch/pytorch#125369
- pytorch/pytorch#126423
- pytorch/pytorch#126750

Performance evaluation on an 8×A100 DevGPU: 1000 steps on the 1D DP default [llama_8b.toml](https://github.com/pytorch/torchtitan/blob/main/train_configs/llama3_8b.toml).

Observation: `fused = True` and `fused = False` show similar loss curves and memory usage; wps is ~100 higher and mfu is 1.5-2% higher with `fused = True`. Below are the logs for the last 100 steps of both runs.

```
**Fused = False**
[rank0]:2024-06-05 12:45:06,227 - root - INFO - Finished dumping traces in 0.37 seconds
[rank0]:2024-06-05 12:45:37,677 - root - INFO - step:  910  loss:  4.6039  memory: 59.48GiB(75.15%)  wps: 2,217  mfu: 41.16%
[rank0]:2024-06-05 12:46:08,843 - root - INFO - step:  920  loss:  4.6427  memory: 59.48GiB(75.15%)  wps: 2,632  mfu: 48.85%
[rank0]:2024-06-05 12:46:40,052 - root - INFO - step:  930  loss:  4.6339  memory: 59.48GiB(75.15%)  wps: 2,628  mfu: 48.78%
[rank0]:2024-06-05 12:47:11,243 - root - INFO - step:  940  loss:  4.5964  memory: 59.48GiB(75.15%)  wps: 2,631  mfu: 48.84%
[rank0]:2024-06-05 12:47:42,655 - root - INFO - step:  950  loss:  4.6477  memory: 59.48GiB(75.15%)  wps: 2,611  mfu: 48.47%
[rank0]:2024-06-05 12:48:13,890 - root - INFO - step:  960  loss:  4.8137  memory: 59.48GiB(75.15%)  wps: 2,626  mfu: 48.75%
[rank0]:2024-06-05 12:48:45,110 - root - INFO - step:  970  loss:  4.5962  memory: 59.48GiB(75.15%)  wps: 2,628  mfu: 48.78%
[rank0]:2024-06-05 12:49:16,333 - root - INFO - step:  980  loss:  4.5450  memory: 59.48GiB(75.15%)  wps: 2,627  mfu: 48.76%
[rank0]:2024-06-05 12:49:47,561 - root - INFO - step:  990  loss:  4.5840  memory: 59.48GiB(75.15%)  wps: 2,627  mfu: 48.76%
[rank0]:2024-06-05 12:50:18,933 - root - INFO - step: 1000  loss:  4.5351  memory: 59.48GiB(75.15%)  wps: 2,615  mfu: 48.53%
[rank0]:2024-06-05 12:50:23,692 - root - INFO - Dumping traces at step 1000
[rank0]:2024-06-05 12:50:24,041 - root - INFO - Finished dumping traces in 0.35 seconds
[rank0]:2024-06-05 12:50:24,422 - root - INFO - Sleeping 2 seconds for other ranks to complete
[rank0]:2024-06-05 12:50:26,424 - root - INFO - Training completed

**Fused = True**
[rank0]:2024-06-05 14:55:42,894 - root - INFO - Finished dumping traces in 0.30 seconds
[rank0]:2024-06-05 14:56:13,582 - root - INFO - step:  910  loss:  4.6091  memory: 59.48GiB(75.15%)  wps: 2,341  mfu: 43.46%
[rank0]:2024-06-05 14:56:43,765 - root - INFO - step:  920  loss:  4.6468  memory: 59.48GiB(75.15%)  wps: 2,718  mfu: 50.45%
[rank0]:2024-06-05 14:57:13,971 - root - INFO - step:  930  loss:  4.6365  memory: 59.48GiB(75.15%)  wps: 2,715  mfu: 50.40%
[rank0]:2024-06-05 14:57:44,172 - root - INFO - step:  940  loss:  4.6021  memory: 59.48GiB(75.15%)  wps: 2,716  mfu: 50.41%
[rank0]:2024-06-05 14:58:14,353 - root - INFO - step:  950  loss:  4.6522  memory: 59.48GiB(75.15%)  wps: 2,718  mfu: 50.45%
[rank0]:2024-06-05 14:58:44,536 - root - INFO - step:  960  loss:  4.8163  memory: 59.48GiB(75.15%)  wps: 2,717  mfu: 50.44%
[rank0]:2024-06-05 14:59:14,683 - root - INFO - step:  970  loss:  4.6026  memory: 59.48GiB(75.15%)  wps: 2,721  mfu: 50.51%
[rank0]:2024-06-05 14:59:44,840 - root - INFO - step:  980  loss:  4.5491  memory: 59.48GiB(75.15%)  wps: 2,720  mfu: 50.49%
[rank0]:2024-06-05 15:00:15,009 - root - INFO - step:  990  loss:  4.5859  memory: 59.48GiB(75.15%)  wps: 2,719  mfu: 50.47%
[rank0]:2024-06-05 15:00:45,228 - root - INFO - step: 1000  loss:  4.5396  memory: 59.48GiB(75.15%)  wps: 2,714  mfu: 50.38%
[rank0]:2024-06-05 15:00:49,455 - root - INFO - Dumping traces at step 1000
[rank0]:2024-06-05 15:00:49,756 - root - INFO - Finished dumping traces in 0.30 seconds
[rank0]:2024-06-05 15:00:50,336 - root - INFO - Sleeping 2 seconds for other ranks to complete
[rank0]:2024-06-05 15:00:52,339 - root - INFO - Training completed
```
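A minimal sketch of enabling the fused optimizer path (the model and sizes are arbitrary; the fused implementation has historically been CUDA-only, so this gates on device availability rather than passing `fused=True` unconditionally):

```python
import torch

model = torch.nn.Linear(8, 8)

# The fused step launches fewer kernels per optimizer update, which is the
# source of the wps/mfu gains reported above. Fall back to the default
# (for-loop) implementation when no CUDA device is present.
use_fused = torch.cuda.is_available()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, fused=use_fused)

before = model.weight.detach().clone()
loss = model(torch.randn(4, 8)).sum()
loss.backward()
optimizer.step()        # parameters are updated in place
optimizer.zero_grad()
```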
Fixes #126670
In this PR, we add support for `fused_adam` and `fused_adamw` when `lr` is a tensor.
cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @fegin @XilunWu @wanchaol @fduwjj @tianyu-l @wconstab @yf225 @chauhang @d4l3k @msaroufim