[torchtitan][optim] Add fused as an option in train config #355

Merged: 1 commit into pytorch:main on Jun 6, 2024

Conversation

@wz337 (Contributor) commented May 22, 2024

With these three PRs landed, we can now support the option `fused=True` in torchtitan for the Adam and AdamW optimizers; a minimal sketch of what this enables follows the PR list below.

pytorch/pytorch#125369
pytorch/pytorch#126423
pytorch/pytorch#126750
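
As context, here is a minimal sketch of what the new option amounts to at the optimizer level. The builder function, its signature, and its defaults are illustrative assumptions, not torchtitan's actual code; only the `fused` keyword on `torch.optim.Adam`/`torch.optim.AdamW` is the real API.

```python
import torch
import torch.nn as nn

def build_optimizer(model: nn.Module, name: str = "AdamW",
                    lr: float = 3e-4, fused: bool = False) -> torch.optim.Optimizer:
    """Illustrative builder: forward a config-level `fused` flag to torch.optim.

    With fused=True, Adam/AdamW run the optimizer step as a single fused
    kernel (enabled by the three PyTorch PRs above) instead of the
    multi-tensor foreach path.
    """
    cls = {"Adam": torch.optim.Adam, "AdamW": torch.optim.AdamW}[name]
    return cls(model.parameters(), lr=lr, fused=fused)

# Usage: fused=True requires parameters on a supported device (e.g. CUDA),
# so this CPU example keeps fused=False.
model = nn.Linear(16, 16)
opt = build_optimizer(model, name="AdamW", fused=False)
```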

We ran a performance evaluation on 8 A100 DevGPUs: 1000 steps of 1D DP training with the default llama_8b.toml config.

Observation:
With `fused = True` and `fused = False`, we observed similar loss curves and memory usage (59.48 GiB, 75.15%, in both runs). In steady state, wps improves by roughly +100 (~2,627 vs ~2,718) and mfu by 1.5-2% (~48.8% vs ~50.4%) when `fused = True`.

Below are the logs for the last 100 steps of both runs.

**Fused = False**
```
[rank0]:2024-06-05 12:45:06,227 - root - INFO - Finished dumping traces in 0.37 seconds
[rank0]:2024-06-05 12:45:37,677 - root - INFO - step: 910  loss:  4.6039  memory: 59.48GiB(75.15%)  wps: 2,217  mfu: 41.16%
[rank0]:2024-06-05 12:46:08,843 - root - INFO - step: 920  loss:  4.6427  memory: 59.48GiB(75.15%)  wps: 2,632  mfu: 48.85%
[rank0]:2024-06-05 12:46:40,052 - root - INFO - step: 930  loss:  4.6339  memory: 59.48GiB(75.15%)  wps: 2,628  mfu: 48.78%
[rank0]:2024-06-05 12:47:11,243 - root - INFO - step: 940  loss:  4.5964  memory: 59.48GiB(75.15%)  wps: 2,631  mfu: 48.84%
[rank0]:2024-06-05 12:47:42,655 - root - INFO - step: 950  loss:  4.6477  memory: 59.48GiB(75.15%)  wps: 2,611  mfu: 48.47%
[rank0]:2024-06-05 12:48:13,890 - root - INFO - step: 960  loss:  4.8137  memory: 59.48GiB(75.15%)  wps: 2,626  mfu: 48.75%
[rank0]:2024-06-05 12:48:45,110 - root - INFO - step: 970  loss:  4.5962  memory: 59.48GiB(75.15%)  wps: 2,628  mfu: 48.78%
[rank0]:2024-06-05 12:49:16,333 - root - INFO - step: 980  loss:  4.5450  memory: 59.48GiB(75.15%)  wps: 2,627  mfu: 48.76%
[rank0]:2024-06-05 12:49:47,561 - root - INFO - step: 990  loss:  4.5840  memory: 59.48GiB(75.15%)  wps: 2,627  mfu: 48.76%
[rank0]:2024-06-05 12:50:18,933 - root - INFO - step: 1000  loss:  4.5351  memory: 59.48GiB(75.15%)  wps: 2,615  mfu: 48.53%
[rank0]:2024-06-05 12:50:23,692 - root - INFO - Dumping traces at step 1000
[rank0]:2024-06-05 12:50:24,041 - root - INFO - Finished dumping traces in 0.35 seconds
[rank0]:2024-06-05 12:50:24,422 - root - INFO - Sleeping 2 seconds for other ranks to complete
[rank0]:2024-06-05 12:50:26,424 - root - INFO - Training completed
```

**Fused = True**
```
[rank0]:2024-06-05 14:55:42,894 - root - INFO - Finished dumping traces in 0.30 seconds
[rank0]:2024-06-05 14:56:13,582 - root - INFO - step: 910  loss:  4.6091  memory: 59.48GiB(75.15%)  wps: 2,341  mfu: 43.46%
[rank0]:2024-06-05 14:56:43,765 - root - INFO - step: 920  loss:  4.6468  memory: 59.48GiB(75.15%)  wps: 2,718  mfu: 50.45%
[rank0]:2024-06-05 14:57:13,971 - root - INFO - step: 930  loss:  4.6365  memory: 59.48GiB(75.15%)  wps: 2,715  mfu: 50.40%
[rank0]:2024-06-05 14:57:44,172 - root - INFO - step: 940  loss:  4.6021  memory: 59.48GiB(75.15%)  wps: 2,716  mfu: 50.41%
[rank0]:2024-06-05 14:58:14,353 - root - INFO - step: 950  loss:  4.6522  memory: 59.48GiB(75.15%)  wps: 2,718  mfu: 50.45%
[rank0]:2024-06-05 14:58:44,536 - root - INFO - step: 960  loss:  4.8163  memory: 59.48GiB(75.15%)  wps: 2,717  mfu: 50.44%
[rank0]:2024-06-05 14:59:14,683 - root - INFO - step: 970  loss:  4.6026  memory: 59.48GiB(75.15%)  wps: 2,721  mfu: 50.51%
[rank0]:2024-06-05 14:59:44,840 - root - INFO - step: 980  loss:  4.5491  memory: 59.48GiB(75.15%)  wps: 2,720  mfu: 50.49%
[rank0]:2024-06-05 15:00:15,009 - root - INFO - step: 990  loss:  4.5859  memory: 59.48GiB(75.15%)  wps: 2,719  mfu: 50.47%
[rank0]:2024-06-05 15:00:45,228 - root - INFO - step: 1000  loss:  4.5396  memory: 59.48GiB(75.15%)  wps: 2,714  mfu: 50.38%
[rank0]:2024-06-05 15:00:49,455 - root - INFO - Dumping traces at step 1000
[rank0]:2024-06-05 15:00:49,756 - root - INFO - Finished dumping traces in 0.30 seconds
[rank0]:2024-06-05 15:00:50,336 - root - INFO - Sleeping 2 seconds for other ranks to complete
[rank0]:2024-06-05 15:00:52,339 - root - INFO - Training completed
```

@wz337 wz337 requested a review from awgu May 22, 2024 22:57
@facebook-github-bot facebook-github-bot added the CLA Signed label (managed by the Meta Open Source bot) May 22, 2024
@wz337 wz337 requested a review from wanchaol May 22, 2024 22:57
@wz337 wz337 force-pushed the add_fused_to_train_config branch from 11826b1 to e8461b4 May 22, 2024 22:59
@wz337 wz337 changed the title [torchtitan][optim] Add fused to train optimizer config [torchtitan][optim] Add fused as an option in train config May 22, 2024
@awgu (Contributor) commented May 23, 2024

I am curious if we have any experiments to see the performance difference with fused=True.

@wz337 (Contributor, Author) commented May 23, 2024

> I am curious if we have any experiments to see the performance difference with fused=True.

I'm going to talk to Tianyu to learn how to run the perf experiments on the new 128 GPUs.
This is just adding it to the config to allow it, but the default behavior is still foreach=True.

We can totally wait for the result before landing this.

@wanchaol (Contributor) commented:

Can we add some 8-GPU numbers at least? The 128-GPU runs can be done separately.

@wz337 (Contributor, Author) commented Jun 5, 2024

@wanchaol @awgu I added the performance comparison to the summary. I think we are comfortable offering this option in torchtitan now?

@wz337 wz337 merged commit 7cf41bb into pytorch:main Jun 6, 2024
4 checks passed
@weifengpy (Contributor) commented Jun 18, 2024

This PR (foreach=True) shortened opt.step from 2000 ms to 200 ms. That's +10% e2e QPS on 16 H100 nodes (16 x 8 GPUs). I might need to refresh the 1D and 2D benchmarks based on this. @drisspg

@awgu (Contributor) commented Jun 18, 2024

@weifengpy foreach=True used to be the default, so perhaps your package was before #386 landed. Without #386, the optimizer would fall back to foreach=False when fused=False. 2000 ms for optimizer step sounds like foreach=False.
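
A small sketch of the fallback described above, assuming pre-#386 torchtitan forwarded only the `fused` flag to the optimizer constructor (the snippet is illustrative, not the actual torchtitan code):

```python
import torch
import torch.nn as nn

model = nn.Linear(8, 8)

# Pre-#386 (sketch): only `fused` is forwarded. An explicit fused=False
# with foreach left unset bypasses torch.optim's "default to foreach"
# heuristic, so the step falls back to the slow single-tensor loop,
# consistent with the ~2000 ms opt.step reported above.
slow = torch.optim.AdamW(model.parameters(), lr=3e-4, fused=False)

# Post-#386 (sketch): explicitly opt into foreach whenever fused is off,
# restoring the fast multi-tensor path.
fused = False
fast = torch.optim.AdamW(model.parameters(), lr=3e-4,
                         fused=fused, foreach=not fused)
```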

@weifengpy (Contributor) commented:

> @weifengpy foreach=True used to be the default, so perhaps your package was before #386 landed. Without #386, the optimizer would fall back to foreach=False when fused=False. 2000 ms for optimizer step sounds like foreach=False.

Ah, got it. I checked the trace, and the 2000 ms indeed comes from foreach=False.

tianyu-l pushed a commit to tianyu-l/torchtitan_intern24 that referenced this pull request Aug 16, 2024
philippguevorguian pushed a commit to YerevaNN/YNNtitan that referenced this pull request Aug 17, 2024