Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Not for land] Added changes for GPT-2 perf #533

Draft
wants to merge 5 commits into
base: gh/awgu/15/base
Choose a base branch
from

Commits on Aug 19, 2024

  1. [Not for land] Added changes for GPT-2 perf

    [ghstack-poisoned]
    awgu committed Aug 19, 2024
    Configuration menu
    Copy the full SHA
    6ad9afa View commit details
    Browse the repository at this point in the history
  2. Update on "[Not for land] Added changes for GPT-2 perf"

    Credit: felipemello1 for most of the work here (especially around chunked cross entropy)
    
    Running on 4xH100s:
    Without these changes (`torch.compile`), the max local batch size is 5:
    ```
    [rank0]:2024-08-19 11:10:26,196 - root - INFO - Training starts at step 1, with local batch size 5, global batch size 20, sequence length 8192, total steps 100 (warmup 200)
    [rank0]:/data/users/andgu/pytorch/torch/_inductor/lowering.py:1673: UserWarning: Torchinductor does not support code generation for complex operators. Performance may be worse than eager.
    [rank0]:  warnings.warn(
    [rank0]:2024-08-19 11:10:33,811 - root - INFO - step:  1  loss: 12.2365  memory: 81.67GiB(85.93%)  wps: 5,380  mfu: 1.09%
    [rank0]:2024-08-19 11:10:33,811 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
    [rank0]:2024-08-19 11:10:37,109 - root - INFO - step: 10  loss: 12.1951  memory: 81.67GiB(85.93%)  wps: 111,770  mfu: 22.68%
    [rank0]:2024-08-19 11:10:40,777 - root - INFO - step: 20  loss: 11.9455  memory: 81.67GiB(85.93%)  wps: 111,714  mfu: 22.67%
    [rank0]:2024-08-19 11:10:44,428 - root - INFO - step: 30  loss: 11.0407  memory: 81.67GiB(85.93%)  wps: 112,194  mfu: 22.76%
    [rank0]:2024-08-19 11:10:48,083 - root - INFO - step: 40  loss:  9.9520  memory: 81.67GiB(85.93%)  wps: 112,109  mfu: 22.75%
    [rank0]:2024-08-19 11:10:51,734 - root - INFO - step: 50  loss:  9.3392  memory: 81.67GiB(85.93%)  wps: 112,218  mfu: 22.77%
    [rank0]:2024-08-19 11:10:55,386 - root - INFO - step: 60  loss:  8.7255  memory: 81.67GiB(85.93%)  wps: 112,198  mfu: 22.77%
    [rank0]:2024-08-19 11:10:59,037 - root - INFO - step: 70  loss:  8.1659  memory: 81.67GiB(85.93%)  wps: 112,234  mfu: 22.77%
    [rank0]:2024-08-19 11:11:02,701 - root - INFO - step: 80  loss:  7.8037  memory: 81.67GiB(85.93%)  wps: 111,802  mfu: 22.68%
    [rank0]:2024-08-19 11:11:06,361 - root - INFO - step: 90  loss:  7.5327  memory: 81.67GiB(85.93%)  wps: 111,937  mfu: 22.71%
    [rank0]:2024-08-19 11:11:10,026 - root - INFO - step: 100  loss:  7.3730  memory: 81.67GiB(85.93%)  wps: 111,803  mfu: 22.69%
    ```
    Without these changes (no `torch.compile`), local batch size 5:
    ```
    [rank0]:2024-08-19 14:24:32,150 - root - INFO - Training starts at step 1, with local batch size 5, global batch size 20, sequence length 8192, total steps 100 (warmup 200)
    [rank0]:2024-08-19 14:24:38,558 - root - INFO - step:  1  loss: 12.2581  memory: 86.47GiB(90.99%)  wps: 6,393  mfu: 1.30%
    [rank0]:2024-08-19 14:24:38,558 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
    [rank0]:2024-08-19 14:24:42,308 - root - INFO - step: 10  loss: 12.2099  memory: 86.48GiB(90.99%)  wps: 98,305  mfu: 19.95%
    [rank0]:2024-08-19 14:24:46,482 - root - INFO - step: 20  loss: 11.9421  memory: 86.48GiB(90.99%)  wps: 98,230  mfu: 19.93%
    [rank0]:2024-08-19 14:24:50,648 - root - INFO - step: 30  loss: 11.0090  memory: 86.48GiB(90.99%)  wps: 98,435  mfu: 19.97%
    [rank0]:2024-08-19 14:24:54,788 - root - INFO - step: 40  loss:  9.9780  memory: 86.48GiB(90.99%)  wps: 99,064  mfu: 20.10%
    [rank0]:2024-08-19 14:24:58,936 - root - INFO - step: 50  loss:  9.3572  memory: 86.48GiB(90.99%)  wps: 98,813  mfu: 20.05%
    [rank0]:2024-08-19 14:25:03,181 - root - INFO - step: 60  loss:  8.7479  memory: 86.48GiB(90.99%)  wps: 96,567  mfu: 19.59%
    [rank0]:2024-08-19 14:25:07,339 - root - INFO - step: 70  loss:  8.1769  memory: 86.48GiB(90.99%)  wps: 98,604  mfu: 20.01%
    [rank0]:2024-08-19 14:25:11,497 - root - INFO - step: 80  loss:  7.8070  memory: 86.48GiB(90.99%)  wps: 98,579  mfu: 20.00%
    [rank0]:2024-08-19 14:25:15,649 - root - INFO - step: 90  loss:  7.5329  memory: 86.48GiB(90.99%)  wps: 98,743  mfu: 20.04%
    [rank0]:2024-08-19 14:25:19,798 - root - INFO - step: 100  loss:  7.3700  memory: 86.48GiB(90.99%)  wps: 98,818  mfu: 20.05%
    ```
    
    With these changes, we can use local batch size 16:
    ```
    [rank0]:2024-08-19 11:16:09,534 - root - INFO - Training starts at step 1, with local batch size 16, global batch size 64, sequence length 8192, total steps 100 (warmup 200)
    [rank0]:/data/users/andgu/pytorch/torch/_inductor/lowering.py:1673: UserWarning: Torchinductor does not support code generation for complex operators. Performance may be worse than eager.
    [rank0]:  warnings.warn(
    [rank0]:2024-08-19 11:16:15,523 - root - INFO - step:  1  loss: 12.2386  memory: 72.29GiB(76.06%)  wps: 21,887  mfu: 4.44%
    [rank0]:2024-08-19 11:16:15,523 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
    [rank0]:2024-08-19 11:16:22,538 - root - INFO - step: 10  loss: 12.1966  memory: 72.30GiB(76.07%)  wps: 168,174  mfu: 34.12%
    [rank0]:2024-08-19 11:16:30,332 - root - INFO - step: 20  loss: 11.9229  memory: 72.30GiB(76.07%)  wps: 168,196  mfu: 34.13%
    [rank0]:2024-08-19 11:16:38,129 - root - INFO - step: 30  loss: 10.9399  memory: 72.30GiB(76.07%)  wps: 168,144  mfu: 34.12%
    [rank0]:2024-08-19 11:16:45,937 - root - INFO - step: 40  loss:  9.8742  memory: 72.30GiB(76.07%)  wps: 167,898  mfu: 34.07%
    [rank0]:2024-08-19 11:16:53,734 - root - INFO - step: 50  loss:  9.2517  memory: 72.30GiB(76.07%)  wps: 168,130  mfu: 34.11%
    [rank0]:2024-08-19 11:17:01,518 - root - INFO - step: 60  loss:  8.6441  memory: 72.30GiB(76.07%)  wps: 168,435  mfu: 34.18%
    [rank0]:2024-08-19 11:17:09,279 - root - INFO - step: 70  loss:  8.0827  memory: 72.30GiB(76.07%)  wps: 168,927  mfu: 34.28%
    [rank0]:2024-08-19 11:17:17,047 - root - INFO - step: 80  loss:  7.7330  memory: 72.30GiB(76.07%)  wps: 168,772  mfu: 34.24%
    [rank0]:2024-08-19 11:17:25,139 - root - INFO - step: 90  loss:  7.4835  memory: 72.30GiB(76.07%)  wps: 162,008  mfu: 32.87%
    [rank0]:2024-08-19 11:17:32,944 - root - INFO - step: 100  loss:  7.3274  memory: 72.30GiB(76.07%)  wps: 167,963  mfu: 34.08%
    ```
    
    22.7% MFU -> 34.1% MFU
    
    
    [ghstack-poisoned]
    awgu committed Aug 19, 2024
    Configuration menu
    Copy the full SHA
    b4a24d2 View commit details
    Browse the repository at this point in the history

Commits on Sep 7, 2024

  1. Update on "[Not for land] Added changes for GPT-2 perf"

    Credit: felipemello1 for most of the work here (especially around chunked cross entropy)
    
    Running on 4xH100s:
    Without these changes (`torch.compile`), the max local batch size is 5:
    ```
    [rank0]:2024-08-19 11:10:26,196 - root - INFO - Training starts at step 1, with local batch size 5, global batch size 20, sequence length 8192, total steps 100 (warmup 200)
    [rank0]:/data/users/andgu/pytorch/torch/_inductor/lowering.py:1673: UserWarning: Torchinductor does not support code generation for complex operators. Performance may be worse than eager.
    [rank0]:  warnings.warn(
    [rank0]:2024-08-19 11:10:33,811 - root - INFO - step:  1  loss: 12.2365  memory: 81.67GiB(85.93%)  wps: 5,380  mfu: 1.09%
    [rank0]:2024-08-19 11:10:33,811 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
    [rank0]:2024-08-19 11:10:37,109 - root - INFO - step: 10  loss: 12.1951  memory: 81.67GiB(85.93%)  wps: 111,770  mfu: 22.68%
    [rank0]:2024-08-19 11:10:40,777 - root - INFO - step: 20  loss: 11.9455  memory: 81.67GiB(85.93%)  wps: 111,714  mfu: 22.67%
    [rank0]:2024-08-19 11:10:44,428 - root - INFO - step: 30  loss: 11.0407  memory: 81.67GiB(85.93%)  wps: 112,194  mfu: 22.76%
    [rank0]:2024-08-19 11:10:48,083 - root - INFO - step: 40  loss:  9.9520  memory: 81.67GiB(85.93%)  wps: 112,109  mfu: 22.75%
    [rank0]:2024-08-19 11:10:51,734 - root - INFO - step: 50  loss:  9.3392  memory: 81.67GiB(85.93%)  wps: 112,218  mfu: 22.77%
    [rank0]:2024-08-19 11:10:55,386 - root - INFO - step: 60  loss:  8.7255  memory: 81.67GiB(85.93%)  wps: 112,198  mfu: 22.77%
    [rank0]:2024-08-19 11:10:59,037 - root - INFO - step: 70  loss:  8.1659  memory: 81.67GiB(85.93%)  wps: 112,234  mfu: 22.77%
    [rank0]:2024-08-19 11:11:02,701 - root - INFO - step: 80  loss:  7.8037  memory: 81.67GiB(85.93%)  wps: 111,802  mfu: 22.68%
    [rank0]:2024-08-19 11:11:06,361 - root - INFO - step: 90  loss:  7.5327  memory: 81.67GiB(85.93%)  wps: 111,937  mfu: 22.71%
    [rank0]:2024-08-19 11:11:10,026 - root - INFO - step: 100  loss:  7.3730  memory: 81.67GiB(85.93%)  wps: 111,803  mfu: 22.69%
    ```
    Without these changes (no `torch.compile`), local batch size 5:
    ```
    [rank0]:2024-08-19 14:24:32,150 - root - INFO - Training starts at step 1, with local batch size 5, global batch size 20, sequence length 8192, total steps 100 (warmup 200)
    [rank0]:2024-08-19 14:24:38,558 - root - INFO - step:  1  loss: 12.2581  memory: 86.47GiB(90.99%)  wps: 6,393  mfu: 1.30%
    [rank0]:2024-08-19 14:24:38,558 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
    [rank0]:2024-08-19 14:24:42,308 - root - INFO - step: 10  loss: 12.2099  memory: 86.48GiB(90.99%)  wps: 98,305  mfu: 19.95%
    [rank0]:2024-08-19 14:24:46,482 - root - INFO - step: 20  loss: 11.9421  memory: 86.48GiB(90.99%)  wps: 98,230  mfu: 19.93%
    [rank0]:2024-08-19 14:24:50,648 - root - INFO - step: 30  loss: 11.0090  memory: 86.48GiB(90.99%)  wps: 98,435  mfu: 19.97%
    [rank0]:2024-08-19 14:24:54,788 - root - INFO - step: 40  loss:  9.9780  memory: 86.48GiB(90.99%)  wps: 99,064  mfu: 20.10%
    [rank0]:2024-08-19 14:24:58,936 - root - INFO - step: 50  loss:  9.3572  memory: 86.48GiB(90.99%)  wps: 98,813  mfu: 20.05%
    [rank0]:2024-08-19 14:25:03,181 - root - INFO - step: 60  loss:  8.7479  memory: 86.48GiB(90.99%)  wps: 96,567  mfu: 19.59%
    [rank0]:2024-08-19 14:25:07,339 - root - INFO - step: 70  loss:  8.1769  memory: 86.48GiB(90.99%)  wps: 98,604  mfu: 20.01%
    [rank0]:2024-08-19 14:25:11,497 - root - INFO - step: 80  loss:  7.8070  memory: 86.48GiB(90.99%)  wps: 98,579  mfu: 20.00%
    [rank0]:2024-08-19 14:25:15,649 - root - INFO - step: 90  loss:  7.5329  memory: 86.48GiB(90.99%)  wps: 98,743  mfu: 20.04%
    [rank0]:2024-08-19 14:25:19,798 - root - INFO - step: 100  loss:  7.3700  memory: 86.48GiB(90.99%)  wps: 98,818  mfu: 20.05%
    ```
    
    With these changes, we can use local batch size 16:
    ```
    [rank0]:2024-08-19 11:16:09,534 - root - INFO - Training starts at step 1, with local batch size 16, global batch size 64, sequence length 8192, total steps 100 (warmup 200)
    [rank0]:/data/users/andgu/pytorch/torch/_inductor/lowering.py:1673: UserWarning: Torchinductor does not support code generation for complex operators. Performance may be worse than eager.
    [rank0]:  warnings.warn(
    [rank0]:2024-08-19 11:16:15,523 - root - INFO - step:  1  loss: 12.2386  memory: 72.29GiB(76.06%)  wps: 21,887  mfu: 4.44%
    [rank0]:2024-08-19 11:16:15,523 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
    [rank0]:2024-08-19 11:16:22,538 - root - INFO - step: 10  loss: 12.1966  memory: 72.30GiB(76.07%)  wps: 168,174  mfu: 34.12%
    [rank0]:2024-08-19 11:16:30,332 - root - INFO - step: 20  loss: 11.9229  memory: 72.30GiB(76.07%)  wps: 168,196  mfu: 34.13%
    [rank0]:2024-08-19 11:16:38,129 - root - INFO - step: 30  loss: 10.9399  memory: 72.30GiB(76.07%)  wps: 168,144  mfu: 34.12%
    [rank0]:2024-08-19 11:16:45,937 - root - INFO - step: 40  loss:  9.8742  memory: 72.30GiB(76.07%)  wps: 167,898  mfu: 34.07%
    [rank0]:2024-08-19 11:16:53,734 - root - INFO - step: 50  loss:  9.2517  memory: 72.30GiB(76.07%)  wps: 168,130  mfu: 34.11%
    [rank0]:2024-08-19 11:17:01,518 - root - INFO - step: 60  loss:  8.6441  memory: 72.30GiB(76.07%)  wps: 168,435  mfu: 34.18%
    [rank0]:2024-08-19 11:17:09,279 - root - INFO - step: 70  loss:  8.0827  memory: 72.30GiB(76.07%)  wps: 168,927  mfu: 34.28%
    [rank0]:2024-08-19 11:17:17,047 - root - INFO - step: 80  loss:  7.7330  memory: 72.30GiB(76.07%)  wps: 168,772  mfu: 34.24%
    [rank0]:2024-08-19 11:17:25,139 - root - INFO - step: 90  loss:  7.4835  memory: 72.30GiB(76.07%)  wps: 162,008  mfu: 32.87%
    [rank0]:2024-08-19 11:17:32,944 - root - INFO - step: 100  loss:  7.3274  memory: 72.30GiB(76.07%)  wps: 167,963  mfu: 34.08%
    ```
    
    22.7% MFU -> 34.1% MFU
    
    
    [ghstack-poisoned]
    awgu committed Sep 7, 2024
    Configuration menu
    Copy the full SHA
    b6a84e4 View commit details
    Browse the repository at this point in the history
  2. Update on "[Not for land] Added changes for GPT-2 perf"

    Credit: felipemello1 for most of the work here (especially around chunked cross entropy)
    
    Running on 4xH100s:
    Without these changes (`torch.compile`), the max local batch size is 5:
    ```
    [rank0]:2024-08-19 11:10:26,196 - root - INFO - Training starts at step 1, with local batch size 5, global batch size 20, sequence length 8192, total steps 100 (warmup 200)
    [rank0]:/data/users/andgu/pytorch/torch/_inductor/lowering.py:1673: UserWarning: Torchinductor does not support code generation for complex operators. Performance may be worse than eager.
    [rank0]:  warnings.warn(
    [rank0]:2024-08-19 11:10:33,811 - root - INFO - step:  1  loss: 12.2365  memory: 81.67GiB(85.93%)  wps: 5,380  mfu: 1.09%
    [rank0]:2024-08-19 11:10:33,811 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
    [rank0]:2024-08-19 11:10:37,109 - root - INFO - step: 10  loss: 12.1951  memory: 81.67GiB(85.93%)  wps: 111,770  mfu: 22.68%
    [rank0]:2024-08-19 11:10:40,777 - root - INFO - step: 20  loss: 11.9455  memory: 81.67GiB(85.93%)  wps: 111,714  mfu: 22.67%
    [rank0]:2024-08-19 11:10:44,428 - root - INFO - step: 30  loss: 11.0407  memory: 81.67GiB(85.93%)  wps: 112,194  mfu: 22.76%
    [rank0]:2024-08-19 11:10:48,083 - root - INFO - step: 40  loss:  9.9520  memory: 81.67GiB(85.93%)  wps: 112,109  mfu: 22.75%
    [rank0]:2024-08-19 11:10:51,734 - root - INFO - step: 50  loss:  9.3392  memory: 81.67GiB(85.93%)  wps: 112,218  mfu: 22.77%
    [rank0]:2024-08-19 11:10:55,386 - root - INFO - step: 60  loss:  8.7255  memory: 81.67GiB(85.93%)  wps: 112,198  mfu: 22.77%
    [rank0]:2024-08-19 11:10:59,037 - root - INFO - step: 70  loss:  8.1659  memory: 81.67GiB(85.93%)  wps: 112,234  mfu: 22.77%
    [rank0]:2024-08-19 11:11:02,701 - root - INFO - step: 80  loss:  7.8037  memory: 81.67GiB(85.93%)  wps: 111,802  mfu: 22.68%
    [rank0]:2024-08-19 11:11:06,361 - root - INFO - step: 90  loss:  7.5327  memory: 81.67GiB(85.93%)  wps: 111,937  mfu: 22.71%
    [rank0]:2024-08-19 11:11:10,026 - root - INFO - step: 100  loss:  7.3730  memory: 81.67GiB(85.93%)  wps: 111,803  mfu: 22.69%
    ```
    
    <details>
    
    <summary> Without these changes, no compile </summary>
    
    Without these changes (no `torch.compile`), local batch size 5:
    ```
    [rank0]:2024-08-19 14:24:32,150 - root - INFO - Training starts at step 1, with local batch size 5, global batch size 20, sequence length 8192, total steps 100 (warmup 200)
    [rank0]:2024-08-19 14:24:38,558 - root - INFO - step:  1  loss: 12.2581  memory: 86.47GiB(90.99%)  wps: 6,393  mfu: 1.30%
    [rank0]:2024-08-19 14:24:38,558 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
    [rank0]:2024-08-19 14:24:42,308 - root - INFO - step: 10  loss: 12.2099  memory: 86.48GiB(90.99%)  wps: 98,305  mfu: 19.95%
    [rank0]:2024-08-19 14:24:46,482 - root - INFO - step: 20  loss: 11.9421  memory: 86.48GiB(90.99%)  wps: 98,230  mfu: 19.93%
    [rank0]:2024-08-19 14:24:50,648 - root - INFO - step: 30  loss: 11.0090  memory: 86.48GiB(90.99%)  wps: 98,435  mfu: 19.97%
    [rank0]:2024-08-19 14:24:54,788 - root - INFO - step: 40  loss:  9.9780  memory: 86.48GiB(90.99%)  wps: 99,064  mfu: 20.10%
    [rank0]:2024-08-19 14:24:58,936 - root - INFO - step: 50  loss:  9.3572  memory: 86.48GiB(90.99%)  wps: 98,813  mfu: 20.05%
    [rank0]:2024-08-19 14:25:03,181 - root - INFO - step: 60  loss:  8.7479  memory: 86.48GiB(90.99%)  wps: 96,567  mfu: 19.59%
    [rank0]:2024-08-19 14:25:07,339 - root - INFO - step: 70  loss:  8.1769  memory: 86.48GiB(90.99%)  wps: 98,604  mfu: 20.01%
    [rank0]:2024-08-19 14:25:11,497 - root - INFO - step: 80  loss:  7.8070  memory: 86.48GiB(90.99%)  wps: 98,579  mfu: 20.00%
    [rank0]:2024-08-19 14:25:15,649 - root - INFO - step: 90  loss:  7.5329  memory: 86.48GiB(90.99%)  wps: 98,743  mfu: 20.04%
    [rank0]:2024-08-19 14:25:19,798 - root - INFO - step: 100  loss:  7.3700  memory: 86.48GiB(90.99%)  wps: 98,818  mfu: 20.05%
    ```
    </details>
    
    With these changes (`torch.compile`), local batch size 32:
    ```
    [rank0]:2024-09-06 19:48:58,342 - root - INFO - Training starts at step 1, with local batch size 32, global batch size 128, sequence length 8192, total steps 50 (warmup 200)
    [rank0]:2024-09-06 19:49:08,904 - root - INFO - step:  1  loss: 12.2442  memory: 79.40GiB(83.54%)  wps: 24,819  mfu: 5.04%
    [rank0]:2024-09-06 19:49:08,904 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
    [rank0]:2024-09-06 19:49:23,127 - root - INFO - step: 10  loss: 12.1998  memory: 80.81GiB(85.03%)  wps: 165,880  mfu: 33.66%
    [rank0]:2024-09-06 19:49:38,946 - root - INFO - step: 20  loss: 11.9284  memory: 80.81GiB(85.03%)  wps: 165,732  mfu: 33.63%
    [rank0]:2024-09-06 19:49:54,764 - root - INFO - step: 30  loss: 10.9587  memory: 80.81GiB(85.03%)  wps: 165,733  mfu: 33.63%
    [rank0]:2024-09-06 19:50:10,566 - root - INFO - step: 40  loss:  9.8493  memory: 80.81GiB(85.03%)  wps: 165,904  mfu: 33.66%
    [rank0]:2024-09-06 19:50:26,973 - root - INFO - step: 50  loss:  9.2317  memory: 80.81GiB(85.03%)  wps: 159,786  mfu: 32.42%
    ```
    
    
    <details>
    <summary> Old Results </summary>
    
    With these changes, we can use local batch size 16:
    ```
    [rank0]:2024-08-19 11:16:09,534 - root - INFO - Training starts at step 1, with local batch size 16, global batch size 64, sequence length 8192, total steps 100 (warmup 200)
    [rank0]:/data/users/andgu/pytorch/torch/_inductor/lowering.py:1673: UserWarning: Torchinductor does not support code generation for complex operators. Performance may be worse than eager.
    [rank0]:  warnings.warn(
    [rank0]:2024-08-19 11:16:15,523 - root - INFO - step:  1  loss: 12.2386  memory: 72.29GiB(76.06%)  wps: 21,887  mfu: 4.44%
    [rank0]:2024-08-19 11:16:15,523 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
    [rank0]:2024-08-19 11:16:22,538 - root - INFO - step: 10  loss: 12.1966  memory: 72.30GiB(76.07%)  wps: 168,174  mfu: 34.12%
    [rank0]:2024-08-19 11:16:30,332 - root - INFO - step: 20  loss: 11.9229  memory: 72.30GiB(76.07%)  wps: 168,196  mfu: 34.13%
    [rank0]:2024-08-19 11:16:38,129 - root - INFO - step: 30  loss: 10.9399  memory: 72.30GiB(76.07%)  wps: 168,144  mfu: 34.12%
    [rank0]:2024-08-19 11:16:45,937 - root - INFO - step: 40  loss:  9.8742  memory: 72.30GiB(76.07%)  wps: 167,898  mfu: 34.07%
    [rank0]:2024-08-19 11:16:53,734 - root - INFO - step: 50  loss:  9.2517  memory: 72.30GiB(76.07%)  wps: 168,130  mfu: 34.11%
    [rank0]:2024-08-19 11:17:01,518 - root - INFO - step: 60  loss:  8.6441  memory: 72.30GiB(76.07%)  wps: 168,435  mfu: 34.18%
    [rank0]:2024-08-19 11:17:09,279 - root - INFO - step: 70  loss:  8.0827  memory: 72.30GiB(76.07%)  wps: 168,927  mfu: 34.28%
    [rank0]:2024-08-19 11:17:17,047 - root - INFO - step: 80  loss:  7.7330  memory: 72.30GiB(76.07%)  wps: 168,772  mfu: 34.24%
    [rank0]:2024-08-19 11:17:25,139 - root - INFO - step: 90  loss:  7.4835  memory: 72.30GiB(76.07%)  wps: 162,008  mfu: 32.87%
    [rank0]:2024-08-19 11:17:32,944 - root - INFO - step: 100  loss:  7.3274  memory: 72.30GiB(76.07%)  wps: 167,963  mfu: 34.08%
    ```
    
    22.7% MFU -> 34.1% MFU
    
    </details>
    
    [ghstack-poisoned]
    awgu committed Sep 7, 2024
    Configuration menu
    Copy the full SHA
    8b27669 View commit details
    Browse the repository at this point in the history
  3. Update on "[Not for land] Added changes for GPT-2 perf"

    Credit: felipemello1 for the previous token chunked cross entropy
    Credit: Chillee for the new token chunked cross entropy
    
    Running on 4xH100s:
    Without these changes (`torch.compile`), the max local batch size is 5:
    ```
    [rank0]:2024-08-19 11:10:26,196 - root - INFO - Training starts at step 1, with local batch size 5, global batch size 20, sequence length 8192, total steps 100 (warmup 200)
    [rank0]:2024-08-19 11:10:33,811 - root - INFO - step:  1  loss: 12.2365  memory: 81.67GiB(85.93%)  wps: 5,380  mfu: 1.09%
    [rank0]:2024-08-19 11:10:33,811 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
    [rank0]:2024-08-19 11:10:37,109 - root - INFO - step: 10  loss: 12.1951  memory: 81.67GiB(85.93%)  wps: 111,770  mfu: 22.68%
    [rank0]:2024-08-19 11:10:40,777 - root - INFO - step: 20  loss: 11.9455  memory: 81.67GiB(85.93%)  wps: 111,714  mfu: 22.67%
    [rank0]:2024-08-19 11:10:44,428 - root - INFO - step: 30  loss: 11.0407  memory: 81.67GiB(85.93%)  wps: 112,194  mfu: 22.76%
    [rank0]:2024-08-19 11:10:48,083 - root - INFO - step: 40  loss:  9.9520  memory: 81.67GiB(85.93%)  wps: 112,109  mfu: 22.75%
    [rank0]:2024-08-19 11:10:51,734 - root - INFO - step: 50  loss:  9.3392  memory: 81.67GiB(85.93%)  wps: 112,218  mfu: 22.77%
    [rank0]:2024-08-19 11:10:55,386 - root - INFO - step: 60  loss:  8.7255  memory: 81.67GiB(85.93%)  wps: 112,198  mfu: 22.77%
    [rank0]:2024-08-19 11:10:59,037 - root - INFO - step: 70  loss:  8.1659  memory: 81.67GiB(85.93%)  wps: 112,234  mfu: 22.77%
    [rank0]:2024-08-19 11:11:02,701 - root - INFO - step: 80  loss:  7.8037  memory: 81.67GiB(85.93%)  wps: 111,802  mfu: 22.68%
    [rank0]:2024-08-19 11:11:06,361 - root - INFO - step: 90  loss:  7.5327  memory: 81.67GiB(85.93%)  wps: 111,937  mfu: 22.71%
    [rank0]:2024-08-19 11:11:10,026 - root - INFO - step: 100  loss:  7.3730  memory: 81.67GiB(85.93%)  wps: 111,803  mfu: 22.69%
    ```
    
    <details>
    
    <summary> Without these changes, no compile </summary>
    
    Without these changes (no `torch.compile`), local batch size 5:
    ```
    [rank0]:2024-08-19 14:24:32,150 - root - INFO - Training starts at step 1, with local batch size 5, global batch size 20, sequence length 8192, total steps 100 (warmup 200)
    [rank0]:2024-08-19 14:24:38,558 - root - INFO - step:  1  loss: 12.2581  memory: 86.47GiB(90.99%)  wps: 6,393  mfu: 1.30%
    [rank0]:2024-08-19 14:24:38,558 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
    [rank0]:2024-08-19 14:24:42,308 - root - INFO - step: 10  loss: 12.2099  memory: 86.48GiB(90.99%)  wps: 98,305  mfu: 19.95%
    [rank0]:2024-08-19 14:24:46,482 - root - INFO - step: 20  loss: 11.9421  memory: 86.48GiB(90.99%)  wps: 98,230  mfu: 19.93%
    [rank0]:2024-08-19 14:24:50,648 - root - INFO - step: 30  loss: 11.0090  memory: 86.48GiB(90.99%)  wps: 98,435  mfu: 19.97%
    [rank0]:2024-08-19 14:24:54,788 - root - INFO - step: 40  loss:  9.9780  memory: 86.48GiB(90.99%)  wps: 99,064  mfu: 20.10%
    [rank0]:2024-08-19 14:24:58,936 - root - INFO - step: 50  loss:  9.3572  memory: 86.48GiB(90.99%)  wps: 98,813  mfu: 20.05%
    [rank0]:2024-08-19 14:25:03,181 - root - INFO - step: 60  loss:  8.7479  memory: 86.48GiB(90.99%)  wps: 96,567  mfu: 19.59%
    [rank0]:2024-08-19 14:25:07,339 - root - INFO - step: 70  loss:  8.1769  memory: 86.48GiB(90.99%)  wps: 98,604  mfu: 20.01%
    [rank0]:2024-08-19 14:25:11,497 - root - INFO - step: 80  loss:  7.8070  memory: 86.48GiB(90.99%)  wps: 98,579  mfu: 20.00%
    [rank0]:2024-08-19 14:25:15,649 - root - INFO - step: 90  loss:  7.5329  memory: 86.48GiB(90.99%)  wps: 98,743  mfu: 20.04%
    [rank0]:2024-08-19 14:25:19,798 - root - INFO - step: 100  loss:  7.3700  memory: 86.48GiB(90.99%)  wps: 98,818  mfu: 20.05%
    ```
    </details>
    
    With these changes (`torch.compile`), local batch size 32:
    ```
    [rank0]:2024-09-06 19:48:58,342 - root - INFO - Training starts at step 1, with local batch size 32, global batch size 128, sequence length 8192, total steps 50 (warmup 200)
    [rank0]:2024-09-06 19:49:08,904 - root - INFO - step:  1  loss: 12.2442  memory: 79.40GiB(83.54%)  wps: 24,819  mfu: 5.04%
    [rank0]:2024-09-06 19:49:08,904 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
    [rank0]:2024-09-06 19:49:23,127 - root - INFO - step: 10  loss: 12.1998  memory: 80.81GiB(85.03%)  wps: 165,880  mfu: 33.66%
    [rank0]:2024-09-06 19:49:38,946 - root - INFO - step: 20  loss: 11.9284  memory: 80.81GiB(85.03%)  wps: 165,732  mfu: 33.63%
    [rank0]:2024-09-06 19:49:54,764 - root - INFO - step: 30  loss: 10.9587  memory: 80.81GiB(85.03%)  wps: 165,733  mfu: 33.63%
    [rank0]:2024-09-06 19:50:10,566 - root - INFO - step: 40  loss:  9.8493  memory: 80.81GiB(85.03%)  wps: 165,904  mfu: 33.66%
    [rank0]:2024-09-06 19:50:26,973 - root - INFO - step: 50  loss:  9.2317  memory: 80.81GiB(85.03%)  wps: 159,786  mfu: 32.42%
    ```
    
    
    <details>
    <summary> Old Results </summary>
    
    With these changes, we can use local batch size 16:
    ```
    [rank0]:2024-08-19 11:16:09,534 - root - INFO - Training starts at step 1, with local batch size 16, global batch size 64, sequence length 8192, total steps 100 (warmup 200)
    [rank0]:/data/users/andgu/pytorch/torch/_inductor/lowering.py:1673: UserWarning: Torchinductor does not support code generation for complex operators. Performance may be worse than eager.
    [rank0]:  warnings.warn(
    [rank0]:2024-08-19 11:16:15,523 - root - INFO - step:  1  loss: 12.2386  memory: 72.29GiB(76.06%)  wps: 21,887  mfu: 4.44%
    [rank0]:2024-08-19 11:16:15,523 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
    [rank0]:2024-08-19 11:16:22,538 - root - INFO - step: 10  loss: 12.1966  memory: 72.30GiB(76.07%)  wps: 168,174  mfu: 34.12%
    [rank0]:2024-08-19 11:16:30,332 - root - INFO - step: 20  loss: 11.9229  memory: 72.30GiB(76.07%)  wps: 168,196  mfu: 34.13%
    [rank0]:2024-08-19 11:16:38,129 - root - INFO - step: 30  loss: 10.9399  memory: 72.30GiB(76.07%)  wps: 168,144  mfu: 34.12%
    [rank0]:2024-08-19 11:16:45,937 - root - INFO - step: 40  loss:  9.8742  memory: 72.30GiB(76.07%)  wps: 167,898  mfu: 34.07%
    [rank0]:2024-08-19 11:16:53,734 - root - INFO - step: 50  loss:  9.2517  memory: 72.30GiB(76.07%)  wps: 168,130  mfu: 34.11%
    [rank0]:2024-08-19 11:17:01,518 - root - INFO - step: 60  loss:  8.6441  memory: 72.30GiB(76.07%)  wps: 168,435  mfu: 34.18%
    [rank0]:2024-08-19 11:17:09,279 - root - INFO - step: 70  loss:  8.0827  memory: 72.30GiB(76.07%)  wps: 168,927  mfu: 34.28%
    [rank0]:2024-08-19 11:17:17,047 - root - INFO - step: 80  loss:  7.7330  memory: 72.30GiB(76.07%)  wps: 168,772  mfu: 34.24%
    [rank0]:2024-08-19 11:17:25,139 - root - INFO - step: 90  loss:  7.4835  memory: 72.30GiB(76.07%)  wps: 162,008  mfu: 32.87%
    [rank0]:2024-08-19 11:17:32,944 - root - INFO - step: 100  loss:  7.3274  memory: 72.30GiB(76.07%)  wps: 167,963  mfu: 34.08%
    ```
    
    22.7% MFU -> 34.1% MFU
    
    </details>
    
    [ghstack-poisoned]
    awgu committed Sep 7, 2024
    Configuration menu
    Copy the full SHA
    e0634d9 View commit details
    Browse the repository at this point in the history