[Not for land] Added changes for GPT-2 perf #533

Draft: wants to merge 5 commits into base gh/awgu/15/base

Conversation

awgu (Contributor) commented Aug 19, 2024

Stack from ghstack (oldest at bottom):

Credit: @felipemello1 for the previous token chunked cross entropy
Credit: @Chillee for the new token chunked cross entropy
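
For context, a minimal sketch of what token-chunked cross entropy can look like (illustrative only: the function name, chunk count, and reduction handling here are assumptions, not the exact implementation in this PR):

```python
import torch
import torch.nn.functional as F

def chunked_cross_entropy(
    logits: torch.Tensor, labels: torch.Tensor, num_chunks: int = 16
) -> torch.Tensor:
    # Chunk along the sequence dimension so that only one chunk's float32 upcast
    # of the (batch, seq_chunk, vocab) logits is materialized at a time.
    logit_chunks = logits.chunk(num_chunks, dim=1)
    label_chunks = labels.chunk(num_chunks, dim=1)
    total = logits.new_zeros((), dtype=torch.float32)
    for logit_chunk, label_chunk in zip(logit_chunks, label_chunks):
        total = total + F.cross_entropy(
            logit_chunk.float().reshape(-1, logit_chunk.size(-1)),
            label_chunk.reshape(-1),
            reduction="sum",
        )
    return total / labels.numel()
```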

Running on 4xH100s:
Without these changes (`torch.compile`), the max local batch size is 5:
```
[rank0]:2024-08-19 11:10:26,196 - root - INFO - Training starts at step 1, with local batch size 5, global batch size 20, sequence length 8192, total steps 100 (warmup 200)
[rank0]:2024-08-19 11:10:33,811 - root - INFO - step:  1  loss: 12.2365  memory: 81.67GiB(85.93%)  wps: 5,380  mfu: 1.09%
[rank0]:2024-08-19 11:10:33,811 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:2024-08-19 11:10:37,109 - root - INFO - step: 10  loss: 12.1951  memory: 81.67GiB(85.93%)  wps: 111,770  mfu: 22.68%
[rank0]:2024-08-19 11:10:40,777 - root - INFO - step: 20  loss: 11.9455  memory: 81.67GiB(85.93%)  wps: 111,714  mfu: 22.67%
[rank0]:2024-08-19 11:10:44,428 - root - INFO - step: 30  loss: 11.0407  memory: 81.67GiB(85.93%)  wps: 112,194  mfu: 22.76%
[rank0]:2024-08-19 11:10:48,083 - root - INFO - step: 40  loss:  9.9520  memory: 81.67GiB(85.93%)  wps: 112,109  mfu: 22.75%
[rank0]:2024-08-19 11:10:51,734 - root - INFO - step: 50  loss:  9.3392  memory: 81.67GiB(85.93%)  wps: 112,218  mfu: 22.77%
[rank0]:2024-08-19 11:10:55,386 - root - INFO - step: 60  loss:  8.7255  memory: 81.67GiB(85.93%)  wps: 112,198  mfu: 22.77%
[rank0]:2024-08-19 11:10:59,037 - root - INFO - step: 70  loss:  8.1659  memory: 81.67GiB(85.93%)  wps: 112,234  mfu: 22.77%
[rank0]:2024-08-19 11:11:02,701 - root - INFO - step: 80  loss:  7.8037  memory: 81.67GiB(85.93%)  wps: 111,802  mfu: 22.68%
[rank0]:2024-08-19 11:11:06,361 - root - INFO - step: 90  loss:  7.5327  memory: 81.67GiB(85.93%)  wps: 111,937  mfu: 22.71%
[rank0]:2024-08-19 11:11:10,026 - root - INFO - step: 100  loss:  7.3730  memory: 81.67GiB(85.93%)  wps: 111,803  mfu: 22.69%
```

Without these changes (no `torch.compile`), local batch size 5:
```
[rank0]:2024-08-19 14:24:32,150 - root - INFO - Training starts at step 1, with local batch size 5, global batch size 20, sequence length 8192, total steps 100 (warmup 200)
[rank0]:2024-08-19 14:24:38,558 - root - INFO - step:  1  loss: 12.2581  memory: 86.47GiB(90.99%)  wps: 6,393  mfu: 1.30%
[rank0]:2024-08-19 14:24:38,558 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:2024-08-19 14:24:42,308 - root - INFO - step: 10  loss: 12.2099  memory: 86.48GiB(90.99%)  wps: 98,305  mfu: 19.95%
[rank0]:2024-08-19 14:24:46,482 - root - INFO - step: 20  loss: 11.9421  memory: 86.48GiB(90.99%)  wps: 98,230  mfu: 19.93%
[rank0]:2024-08-19 14:24:50,648 - root - INFO - step: 30  loss: 11.0090  memory: 86.48GiB(90.99%)  wps: 98,435  mfu: 19.97%
[rank0]:2024-08-19 14:24:54,788 - root - INFO - step: 40  loss:  9.9780  memory: 86.48GiB(90.99%)  wps: 99,064  mfu: 20.10%
[rank0]:2024-08-19 14:24:58,936 - root - INFO - step: 50  loss:  9.3572  memory: 86.48GiB(90.99%)  wps: 98,813  mfu: 20.05%
[rank0]:2024-08-19 14:25:03,181 - root - INFO - step: 60  loss:  8.7479  memory: 86.48GiB(90.99%)  wps: 96,567  mfu: 19.59%
[rank0]:2024-08-19 14:25:07,339 - root - INFO - step: 70  loss:  8.1769  memory: 86.48GiB(90.99%)  wps: 98,604  mfu: 20.01%
[rank0]:2024-08-19 14:25:11,497 - root - INFO - step: 80  loss:  7.8070  memory: 86.48GiB(90.99%)  wps: 98,579  mfu: 20.00%
[rank0]:2024-08-19 14:25:15,649 - root - INFO - step: 90  loss:  7.5329  memory: 86.48GiB(90.99%)  wps: 98,743  mfu: 20.04%
[rank0]:2024-08-19 14:25:19,798 - root - INFO - step: 100  loss:  7.3700  memory: 86.48GiB(90.99%)  wps: 98,818  mfu: 20.05%
```

With these changes (`torch.compile`), local batch size 32:
```
[rank0]:2024-09-06 19:48:58,342 - root - INFO - Training starts at step 1, with local batch size 32, global batch size 128, sequence length 8192, total steps 50 (warmup 200)
[rank0]:2024-09-06 19:49:08,904 - root - INFO - step:  1  loss: 12.2442  memory: 79.40GiB(83.54%)  wps: 24,819  mfu: 5.04%
[rank0]:2024-09-06 19:49:08,904 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:2024-09-06 19:49:23,127 - root - INFO - step: 10  loss: 12.1998  memory: 80.81GiB(85.03%)  wps: 165,880  mfu: 33.66%
[rank0]:2024-09-06 19:49:38,946 - root - INFO - step: 20  loss: 11.9284  memory: 80.81GiB(85.03%)  wps: 165,732  mfu: 33.63%
[rank0]:2024-09-06 19:49:54,764 - root - INFO - step: 30  loss: 10.9587  memory: 80.81GiB(85.03%)  wps: 165,733  mfu: 33.63%
[rank0]:2024-09-06 19:50:10,566 - root - INFO - step: 40  loss:  9.8493  memory: 80.81GiB(85.03%)  wps: 165,904  mfu: 33.66%
[rank0]:2024-09-06 19:50:26,973 - root - INFO - step: 50  loss:  9.2317  memory: 80.81GiB(85.03%)  wps: 159,786  mfu: 32.42%
```

Old Results

With these changes, we can use local batch size 16:
```
[rank0]:2024-08-19 11:16:09,534 - root - INFO - Training starts at step 1, with local batch size 16, global batch size 64, sequence length 8192, total steps 100 (warmup 200)
[rank0]:/data/users/andgu/pytorch/torch/_inductor/lowering.py:1673: UserWarning: Torchinductor does not support code generation for complex operators. Performance may be worse than eager.
[rank0]:  warnings.warn(
[rank0]:2024-08-19 11:16:15,523 - root - INFO - step:  1  loss: 12.2386  memory: 72.29GiB(76.06%)  wps: 21,887  mfu: 4.44%
[rank0]:2024-08-19 11:16:15,523 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:2024-08-19 11:16:22,538 - root - INFO - step: 10  loss: 12.1966  memory: 72.30GiB(76.07%)  wps: 168,174  mfu: 34.12%
[rank0]:2024-08-19 11:16:30,332 - root - INFO - step: 20  loss: 11.9229  memory: 72.30GiB(76.07%)  wps: 168,196  mfu: 34.13%
[rank0]:2024-08-19 11:16:38,129 - root - INFO - step: 30  loss: 10.9399  memory: 72.30GiB(76.07%)  wps: 168,144  mfu: 34.12%
[rank0]:2024-08-19 11:16:45,937 - root - INFO - step: 40  loss:  9.8742  memory: 72.30GiB(76.07%)  wps: 167,898  mfu: 34.07%
[rank0]:2024-08-19 11:16:53,734 - root - INFO - step: 50  loss:  9.2517  memory: 72.30GiB(76.07%)  wps: 168,130  mfu: 34.11%
[rank0]:2024-08-19 11:17:01,518 - root - INFO - step: 60  loss:  8.6441  memory: 72.30GiB(76.07%)  wps: 168,435  mfu: 34.18%
[rank0]:2024-08-19 11:17:09,279 - root - INFO - step: 70  loss:  8.0827  memory: 72.30GiB(76.07%)  wps: 168,927  mfu: 34.28%
[rank0]:2024-08-19 11:17:17,047 - root - INFO - step: 80  loss:  7.7330  memory: 72.30GiB(76.07%)  wps: 168,772  mfu: 34.24%
[rank0]:2024-08-19 11:17:25,139 - root - INFO - step: 90  loss:  7.4835  memory: 72.30GiB(76.07%)  wps: 162,008  mfu: 32.87%
[rank0]:2024-08-19 11:17:32,944 - root - INFO - step: 100  loss:  7.3274  memory: 72.30GiB(76.07%)  wps: 167,963  mfu: 34.08%
```

22.7% MFU -> 34.1% MFU

awgu added a commit that referenced this pull request Aug 19, 2024
ghstack-source-id: 82808b1e55456ddc3df041231d965a5666b5b465
Pull Request resolved: #533
felipemello1 commented:

Great results! Consider using `.split(8192//32, dim=1)` instead of `.chunk(16)`.

There was a huge difference in reserved memory in my experiments (no change in active memory, though).
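
For illustration, a sketch of the two variants being compared (shapes and the sequence dimension are assumed from the 8192-token runs above; this is not the PR's code):

```python
import torch

# Assumed shapes: (batch, seq=8192, vocab); values are dummies for illustration.
logits = torch.randn(2, 8192, 50257)

# Fixed number of chunks along the (assumed) sequence dimension: 16 chunks of 512 tokens.
chunks_via_chunk = logits.chunk(16, dim=1)

# Suggested alternative: an explicit chunk size via .split, here 32 chunks of 256 tokens.
chunks_via_split = logits.split(8192 // 32, dim=1)
```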

awgu added a commit that referenced this pull request Aug 19, 2024
ghstack-source-id: af1b1c31ed203910bb6a431296097b2c8fe0534e
Pull Request resolved: #533
felipemello1 commented Aug 19, 2024

FYI, compiling loss + model together should yield much better results than compiling the model alone, if this is what's happening.

Instead of doing `torch.compile(model)`, do something like:

```python
@torch.compile()
def loss_step(input, label):
    output = model(input)
    loss = calculate_loss(output, label)
    loss.backward()
```
What we found is that using torch.compile on the cross entropy loss alone has great memory benefits (though not better than chunked): https://fb.workplace.com/groups/257735836456307/permalink/708422718054281/
And compiling model + loss together almost doubles toks/second: pytorch/torchtune#1228 (comment)

But the best result for us is compiling only the model and using the chunked cross entropy. If we compile everything, then the benefits of chunked cross entropy are lost.

awgu (Contributor, Author) commented Aug 21, 2024

If I try to compile both the output linear and cross entropy loss together instead of just compiling the cross entropy loss, I get OOMs at the same batch size.

felipemello1 commented:
My uneducated guess is that the optimizations they made for CrossEntropyLoss account only for the loss being compiled on its own. Details of their implementation here: https://fb.workplace.com/groups/257735836456307/permalink/708422718054281/

awgu (Contributor, Author) commented Aug 21, 2024

Llama3-8B

With these changes:
```
[rank0]:2024-08-21 08:44:32,865 - root - INFO - Model llama3 8B size: 8,030,261,248 total parameters
[rank0]:2024-08-21 08:44:32,897 - root - INFO - Compiling each TransformerBlock with torch.compile
[rank0]:2024-08-21 08:44:32,953 - root - INFO - Applied FSDP to the model
[rank0]:NCCL version 2.21.5+cuda12.0
[rank0]:2024-08-21 08:44:45,742 - root - INFO - GPU memory usage for model: 3.99GiB(4.20%)
[rank0]:2024-08-21 08:44:45,743 - root - INFO - Training starts at step 1, with local batch size 1, global batch size 8, sequence length 8192, total steps 1000 (warmup 200)
[rank0]:/data/users/andgu/pytorch/torch/_inductor/lowering.py:1673: UserWarning: Torchinductor does not support code generation for complex operators. Performance may be worse than eager.
[rank0]:  warnings.warn(
[rank0]:2024-08-21 08:46:07,756 - root - INFO - step:  1  loss: 12.2044  memory: 71.90GiB(75.65%)  wps: 100  mfu: 0.58%
[rank0]:2024-08-21 08:46:07,756 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:2024-08-21 08:46:19,266 - root - INFO - step: 10  loss: 10.8650  memory: 82.42GiB(86.72%)  wps: 6,405  mfu: 37.51%
[rank0]:2024-08-21 08:46:32,000 - root - INFO - step: 20  loss:  9.1536  memory: 82.42GiB(86.72%)  wps: 6,434  mfu: 37.68%
[rank0]:2024-08-21 08:46:44,771 - root - INFO - step: 30  loss:  8.1057  memory: 82.42GiB(86.72%)  wps: 6,416  mfu: 37.57%
```

Baseline:
```
[rank0]:2024-08-21 08:47:49,505 - root - INFO - Model llama3 8B size: 8,030,261,248 total parameters
[rank0]:2024-08-21 08:47:49,829 - root - INFO - Compiling each TransformerBlock with torch.compile
[rank0]:2024-08-21 08:47:49,892 - root - INFO - Applied FSDP to the model
[rank0]:NCCL version 2.21.5+cuda12.0
[rank0]:2024-08-21 08:48:01,630 - root - INFO - GPU memory usage for model: 3.78GiB(3.98%)
[rank0]:2024-08-21 08:48:01,631 - root - INFO - Training starts at step 1, with local batch size 1, global batch size 8, sequence length 8192, total steps 1000 (warmup 200)
[rank0]:/data/users/andgu/pytorch/torch/_inductor/lowering.py:1673: UserWarning: Torchinductor does not support code generation for complex operators. Performance may be worse than eager.
[rank0]:  warnings.warn(
[rank0]:2024-08-21 08:48:56,359 - root - INFO - step:  1  loss: 12.2556  memory: 67.73GiB(71.26%)  wps: 150  mfu: 0.88%
[rank0]:2024-08-21 08:48:56,359 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:2024-08-21 08:49:08,017 - root - INFO - step: 10  loss: 10.8891  memory: 71.64GiB(75.38%)  wps: 6,324  mfu: 37.03%
[rank0]:2024-08-21 08:49:20,989 - root - INFO - step: 20  loss:  9.0440  memory: 71.64GiB(75.38%)  wps: 6,316  mfu: 36.99%
[rank0]:2024-08-21 08:49:33,984 - root - INFO - step: 30  loss:  8.0371  memory: 71.64GiB(75.38%)  wps: 6,305  mfu: 36.92%
[rank0]:2024-08-21 08:49:46,985 - root - INFO - step: 40  loss:  7.4550  memory: 71.64GiB(75.38%)  wps: 6,302  mfu: 36.90%
[rank0]:2024-08-21 08:50:00,009 - root - INFO - step: 50  loss:  7.2264  memory: 71.64GiB(75.38%)  wps: 6,290  mfu: 36.84%
```

Moving `.float()` into the CE loss and compiling it (P1539164756):
```
[rank0]:2024-08-21 08:55:16,471 - root - INFO - Model llama3 8B size: 8,030,261,248 total parameters
[rank0]:2024-08-21 08:55:16,786 - root - INFO - Compiling each TransformerBlock with torch.compile
[rank0]:2024-08-21 08:55:16,847 - root - INFO - Applied FSDP to the model
[rank0]:NCCL version 2.21.5+cuda12.0
[rank0]:2024-08-21 08:55:29,108 - root - INFO - GPU memory usage for model: 3.78GiB(3.98%)
[rank0]:2024-08-21 08:55:29,110 - root - INFO - Training starts at step 1, with local batch size 1, global batch size 8, sequence length 8192, total steps 1000 (warmup 200)
[rank0]:/data/users/andgu/pytorch/torch/_inductor/lowering.py:1673: UserWarning: Torchinductor does not support code generation for complex operators. Performance may be worse than eager.
[rank0]:  warnings.warn(
[rank0]:2024-08-21 08:56:21,884 - root - INFO - step:  1  loss: 12.2164  memory: 58.18GiB(61.22%)  wps: 155  mfu: 0.91%
[rank0]:2024-08-21 08:56:21,884 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:2024-08-21 08:56:33,351 - root - INFO - step: 10  loss: 10.8179  memory: 66.01GiB(69.46%)  wps: 6,430  mfu: 37.65%
[rank0]:2024-08-21 08:56:46,101 - root - INFO - step: 20  loss:  9.0846  memory: 66.01GiB(69.46%)  wps: 6,426  mfu: 37.63%
[rank0]:2024-08-21 08:56:58,879 - root - INFO - step: 30  loss:  8.0600  memory: 66.01GiB(69.46%)  wps: 6,412  mfu: 37.55%
[rank0]:2024-08-21 08:57:11,658 - root - INFO - step: 40  loss:  7.4393  memory: 66.01GiB(69.46%)  wps: 6,411  mfu: 37.54%
[rank0]:2024-08-21 08:57:24,460 - root - INFO - step: 50  loss:  7.1899  memory: 66.01GiB(69.46%)  wps: 6,400  mfu: 37.48%
[rank0]:2024-08-21 08:57:37,291 - root - INFO - step: 60  loss:  7.0205  memory: 66.01GiB(69.46%)  wps: 6,386  mfu: 37.39%
```
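
A minimal sketch of what "moving `.float()` into the CE loss and compiling it" can look like (illustrative only; the actual change is in the internal paste P1539164756 referenced above, and the function name here is an assumption):

```python
import torch
import torch.nn.functional as F

@torch.compile
def compiled_cross_entropy(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    # Upcasting to float32 inside the compiled region lets Inductor fuse the cast
    # with the log-softmax instead of first materializing a full fp32 copy of the logits.
    return F.cross_entropy(
        logits.float().reshape(-1, logits.size(-1)),
        labels.reshape(-1),
    )
```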

awgu added a commit that referenced this pull request Sep 7, 2024
ghstack-source-id: 39b91d06c8c1c6398e58a7d8841c4432ba4532c7
Pull Request resolved: #533
awgu added a commit that referenced this pull request Sep 7, 2024
ghstack-source-id: 0cdcc964f2012f1b0c00e3eeba7eaca14e768629
Pull Request resolved: #533
awgu added a commit that referenced this pull request Sep 7, 2024
ghstack-source-id: ddfb8a972f0332ca9c7bd7ca6072b02df4e1792c
Pull Request resolved: #533