
[DTensor][Optim] Add support for fused_adam and fused_adamw when lr is a tensor #126750

Closed
wants to merge 1 commit

Conversation

wz337
Contributor

@wz337 wz337 commented May 21, 2024

Fixes #126670

In this PR, we update the following:

  1. lr is a kwarg. Add support to automatically turn on implicit replication for kwargs; previously we only did this for args (see the sketch after this list).
  2. Add the associated tensor_lr ops in pointwise_ops.py.
  3. Add an associated unit test in test_optimizers.py.
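
As a rough usage sketch of what this enables end to end (single-device code for brevity, with illustrative names and hyperparameters): fused Adam/AdamW accepts lr as a tensor, and after this PR the same call path also works when the parameters are DTensors (e.g. under FSDP2), because the plain-tensor lr kwarg is implicitly replicated during DTensor dispatch.

```python
import torch
import torch.nn as nn

# Illustrative sketch: fused AdamW with a tensor lr. fused=True needs a fused
# implementation for the parameters' device (CUDA, or CPU on recent PyTorch builds).
model = nn.Linear(8, 8)
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=torch.tensor(1e-3),  # lr passed as a 0-dim tensor kwarg
    fused=True,
)

loss = model(torch.randn(4, 8)).sum()
loss.backward()
optimizer.step()  # the tensor lr reaches the fused op (aten::_fused_adamw_) as a kwarg
```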

cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @fegin @XilunWu @wanchaol @fduwjj @tianyu-l @wconstab @yf225 @chauhang @d4l3k @msaroufim


pytorch-bot bot commented May 21, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/126750

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 3287eeb with merge base 980f5ac:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the ciflow/inductor and oncall: distributed (Add this issue/PR to distributed oncall triage queue) labels May 21, 2024
@wz337 wz337 added the module: dtensor (distributed tensor tag) and release notes: distributed (dtensor) (release notes category) labels May 21, 2024
@wz337 wz337 requested a review from wanchaol May 21, 2024 03:33
@wz337 wz337 marked this pull request as ready for review May 21, 2024 03:36
@wz337 wz337 force-pushed the fix_fsdp2_with_fused_adamW branch from 7689f12 to cc1d023 Compare May 21, 2024 03:40
@wz337 wz337 requested a review from msaroufim May 21, 2024 03:50
@wz337 wz337 force-pushed the fix_fsdp2_with_fused_adamW branch from cc1d023 to 82d79b1 Compare May 21, 2024 04:32
Contributor

@wanchaol wanchaol left a comment


Thanks for the quick fix! Please see the inline comments.

@@ -297,6 +297,41 @@ def unwrap_to_op_info(
local_kwargs: Dict[str, object] = {}
mesh: Optional[DeviceMesh] = None

def get_replicate_spec(
Contributor

nit: name this as try_get_replicate_spec
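
For context, here is a conceptual sketch of what replicating a plain-tensor kwarg across the mesh means, using the public distribute_tensor API on a single-process CPU mesh. This is only illustrative; it is not the PR's helper, which builds the replicate spec internally during DTensor op dispatch.

```python
import os
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed._tensor import distribute_tensor, Replicate

# Single-process setup just to build a 1-rank CPU mesh for illustration.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

mesh = init_device_mesh("cpu", (1,))
lr = torch.tensor(0.1)  # plain scalar tensor, as passed to the fused optimizer
lr_replicated = distribute_tensor(lr, mesh, [Replicate()])  # replicated DTensor
print(lr_replicated.placements)  # (Replicate(),)

dist.destroy_process_group()
```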

{**config, "lr": torch.tensor(0.1)}
for config in fused_adam_float_lr_configs
]
fused_adam_tensor_lr_configs.extend(
Contributor

I don't quite understand why you need to extend the list again? L121-124 did the same thing?

Contributor Author

@wz337 wz337 May 21, 2024

Oh! The only difference is that one lr is torch.tensor(0.1) (a 0-dim tensor) and the other is torch.tensor([0.1]) (a 1-element 1-dim tensor).
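
For reference, a tiny illustration of the two lr variants being exercised (the config dicts here are only a sketch, not the exact test code):

```python
import torch

# A 0-dim scalar tensor vs. a 1-element 1-dim tensor -- both are valid tensor lrs,
# and the test configs cover each form.
scalar_lr = torch.tensor(0.1)    # shape: torch.Size([])
vector_lr = torch.tensor([0.1])  # shape: torch.Size([1])

base_config = {"fused": True}    # hypothetical base config, for illustration only
tensor_lr_configs = [
    {**base_config, "lr": scalar_lr},
    {**base_config, "lr": vector_lr},
]
print([cfg["lr"].shape for cfg in tensor_lr_configs])
```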

{**config, "lr": torch.tensor(0.1)}
for config in fused_adamw_float_lr_configs
]
fused_adamw_tensor_lr_configs.extend(
Contributor

Ditto, what's the rationale for this extend?

@wz337 wz337 force-pushed the fix_fsdp2_with_fused_adamW branch from 82d79b1 to c55eaf6 Compare May 21, 2024 07:53
@wz337
Contributor Author

wz337 commented May 21, 2024

@pytorchmergebot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk (Trigger trunk jobs on your pull request) label May 21, 2024
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here

@wz337
Contributor Author

wz337 commented May 21, 2024

Failures for public bindings are related to torch.fx.experimental.proxy_tensor.ModuleNotInstalledAsSubmoduleError and are not DTensor related.

@wz337
Contributor Author

wz337 commented May 21, 2024

@pytorchbot merge -i

@wz337
Contributor Author

wz337 commented May 21, 2024

@pytorchbot rebase -s

@pytorchmergebot
Collaborator

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

@pytorchmergebot
Collaborator

Successfully rebased fix_fsdp2_with_fused_adamW onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via git checkout fix_fsdp2_with_fused_adamW && git pull --rebase)

@wz337
Contributor Author

wz337 commented May 21, 2024

@pytorchmergebot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here

wz337 added a commit to pytorch/torchtitan that referenced this pull request Jun 6, 2024
With these three PRs landed, we can now support the option fused=True in torchtitan for the Adam and AdamW optimizers.

pytorch/pytorch#125369
pytorch/pytorch#126423
pytorch/pytorch#126750

Ran a performance evaluation on an 8×A100 DevGPU: 1000 steps on the default 1D DP config
[llama_8b.toml](https://github.com/pytorch/torchtitan/blob/main/train_configs/llama3_8b.toml).

Observation:
For `fused = True` and `fused = False`, we observed similar loss curves and memory usage.
wps is ~100 higher and mfu is 1.5-2% higher when `fused = True`.

Below are the logs for the last 100 steps for both.
```
**Fused = False**
[rank0]:2024-06-05 12:45:06,227 - root - INFO - Finished dumping traces in 0.37 seconds
[rank0]:2024-06-05 12:45:37,677 - root - INFO - step: 910  loss:  4.6039  memory: 59.48GiB(75.15%)  wps: 2,217  mfu: 41.16%
[rank0]:2024-06-05 12:46:08,843 - root - INFO - step: 920  loss:  4.6427  memory: 59.48GiB(75.15%)  wps: 2,632  mfu: 48.85%
[rank0]:2024-06-05 12:46:40,052 - root - INFO - step: 930  loss:  4.6339  memory: 59.48GiB(75.15%)  wps: 2,628  mfu: 48.78%
[rank0]:2024-06-05 12:47:11,243 - root - INFO - step: 940  loss:  4.5964  memory: 59.48GiB(75.15%)  wps: 2,631  mfu: 48.84%
[rank0]:2024-06-05 12:47:42,655 - root - INFO - step: 950  loss:  4.6477  memory: 59.48GiB(75.15%)  wps: 2,611  mfu: 48.47%
[rank0]:2024-06-05 12:48:13,890 - root - INFO - step: 960  loss:  4.8137  memory: 59.48GiB(75.15%)  wps: 2,626  mfu: 48.75%
[rank0]:2024-06-05 12:48:45,110 - root - INFO - step: 970  loss:  4.5962  memory: 59.48GiB(75.15%)  wps: 2,628  mfu: 48.78%
[rank0]:2024-06-05 12:49:16,333 - root - INFO - step: 980  loss:  4.5450  memory: 59.48GiB(75.15%)  wps: 2,627  mfu: 48.76%
[rank0]:2024-06-05 12:49:47,561 - root - INFO - step: 990  loss:  4.5840  memory: 59.48GiB(75.15%)  wps: 2,627  mfu: 48.76%
[rank0]:2024-06-05 12:50:18,933 - root - INFO - step: 1000  loss:  4.5351  memory: 59.48GiB(75.15%)  wps: 2,615  mfu: 48.53%
[rank0]:2024-06-05 12:50:23,692 - root - INFO - Dumping traces at step 1000
[rank0]:2024-06-05 12:50:24,041 - root - INFO - Finished dumping traces in 0.35 seconds
[rank0]:2024-06-05 12:50:24,422 - root - INFO - Sleeping 2 seconds for other ranks to complete
[rank0]:2024-06-05 12:50:26,424 - root - INFO - Training completed

**Fused = True**
[rank0]:2024-06-05 14:55:42,894 - root - INFO - Finished dumping traces in 0.30 seconds
[rank0]:2024-06-05 14:56:13,582 - root - INFO - step: 910  loss:  4.6091  memory: 59.48GiB(75.15%)  wps: 2,341  mfu: 43.46%
[rank0]:2024-06-05 14:56:43,765 - root - INFO - step: 920  loss:  4.6468  memory: 59.48GiB(75.15%)  wps: 2,718  mfu: 50.45%
[rank0]:2024-06-05 14:57:13,971 - root - INFO - step: 930  loss:  4.6365  memory: 59.48GiB(75.15%)  wps: 2,715  mfu: 50.40%
[rank0]:2024-06-05 14:57:44,172 - root - INFO - step: 940  loss:  4.6021  memory: 59.48GiB(75.15%)  wps: 2,716  mfu: 50.41%
[rank0]:2024-06-05 14:58:14,353 - root - INFO - step: 950  loss:  4.6522  memory: 59.48GiB(75.15%)  wps: 2,718  mfu: 50.45%
[rank0]:2024-06-05 14:58:44,536 - root - INFO - step: 960  loss:  4.8163  memory: 59.48GiB(75.15%)  wps: 2,717  mfu: 50.44%
[rank0]:2024-06-05 14:59:14,683 - root - INFO - step: 970  loss:  4.6026  memory: 59.48GiB(75.15%)  wps: 2,721  mfu: 50.51%
[rank0]:2024-06-05 14:59:44,840 - root - INFO - step: 980  loss:  4.5491  memory: 59.48GiB(75.15%)  wps: 2,720  mfu: 50.49%
[rank0]:2024-06-05 15:00:15,009 - root - INFO - step: 990  loss:  4.5859  memory: 59.48GiB(75.15%)  wps: 2,719  mfu: 50.47%
[rank0]:2024-06-05 15:00:45,228 - root - INFO - step: 1000  loss:  4.5396  memory: 59.48GiB(75.15%)  wps: 2,714  mfu: 50.38%
[rank0]:2024-06-05 15:00:49,455 - root - INFO - Dumping traces at step 1000
[rank0]:2024-06-05 15:00:49,756 - root - INFO - Finished dumping traces in 0.30 seconds
[rank0]:2024-06-05 15:00:50,336 - root - INFO - Sleeping 2 seconds for other ranks to complete
[rank0]:2024-06-05 15:00:52,339 - root - INFO - Training completed
```
tianyu-l pushed a commit to tianyu-l/torchtitan_intern24 that referenced this pull request Aug 16, 2024
philippguevorguian pushed a commit to YerevaNN/YNNtitan that referenced this pull request Aug 17, 2024
Labels
ciflow/inductor
ciflow/trunk (Trigger trunk jobs on your pull request)
Merged
module: dtensor (distributed tensor tag)
oncall: distributed (Add this issue/PR to distributed oncall triage queue)
release notes: distributed (dtensor) (release notes category)
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Fused AdamW not supported with FSDP2
4 participants