Update on "Throw warning if users are using old pytorch version that …
Browse files Browse the repository at this point in the history
…not including DTensor strided sharding"


**Summary**
1. Check whether users are running a nightly-build PyTorch that includes DTensor strided sharding (pytorch/pytorch#130760) when 2D/3D parallelism is in use, and print a warning if not (see the sketch after this list).
2. Remove the temporary re-enablement added in #460.
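
As a rough illustration, a minimal sketch of such a check, assuming the warning keys off the availability of the `_StridedShard` placement; the actual `check_strided_sharding_enabled()` helper in torchtitan may differ, and the import path here is an assumption:

```python
import logging

logger = logging.getLogger(__name__)


def check_strided_sharding_enabled() -> None:
    # _StridedShard is the placement type introduced in pytorch/pytorch#130760;
    # this import path is an assumption and may vary across torch versions.
    try:
        from torch.distributed._tensor.placement_types import _StridedShard  # noqa: F401
    except ImportError:
        logger.warning(
            "This PyTorch build predates DTensor strided sharding "
            "(pytorch/pytorch#130760); 2D/3D checkpoints saved via DCP "
            "may not load correctly. Please update to a recent nightly."
        )
```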

**Test**
Command: `python test_runner.py outputs --test pp_dp_tp --ngpu 8`
GPUs: A100
Output:
- without strided sharding:
```
[rank7]:2024-08-06 03:21:26,706 - root - INFO - step:  2  loss:  8.1652  memory:  0.51GiB(0.64%)  wps: 8,250  mfu: 0.25%
[rank7]:2024-08-06 03:21:27,013 - root - INFO - step:  3  loss:  8.0951  memory:  0.51GiB(0.64%)  wps: 13,358  mfu: 0.41%
[rank7]:2024-08-06 03:21:27,309 - root - INFO - step:  4  loss:  7.9748  memory:  0.51GiB(0.64%)  wps: 13,865  mfu: 0.42%
[rank7]:2024-08-06 03:21:27,582 - root - INFO - step:  5  loss:  7.8025  memory:  0.51GiB(0.64%)  wps: 15,057  mfu: 0.46%
[rank7]:2024-08-06 03:21:28,076 - root - INFO - step:  6  loss:  7.5612  memory:  0.51GiB(0.64%)  wps: 8,300  mfu: 0.25%
[rank7]:2024-08-06 03:21:28,608 - root - INFO - step:  7  loss:  7.3649  memory:  0.51GiB(0.64%)  wps: 7,705  mfu: 0.23%
[rank7]:2024-08-06 03:21:28,927 - root - INFO - step:  8  loss:  7.2946  memory:  0.51GiB(0.64%)  wps: 12,832  mfu: 0.39%
[rank7]:2024-08-06 03:21:29,251 - root - INFO - step:  9  loss:  7.1311  memory:  0.51GiB(0.64%)  wps: 12,669  mfu: 0.38%
[rank7]:2024-08-06 03:21:29,627 - root - INFO - step: 10  loss:  7.0540  memory:  0.51GiB(0.64%)  wps: 10,918  mfu: 0.33%
>>>>>>>>>>>>>>>>>Checkpoint save & load<<<<<<<<<<<<<<<<<<<
[rank7]:2024-08-06 03:21:59,723 - root - INFO - step: 11  loss:  7.0822  memory:  0.51GiB(0.64%)  wps: 1,139  mfu: 0.03%
[rank7]:2024-08-06 03:22:00,054 - root - INFO - step: 12  loss:  7.0508  memory:  0.51GiB(0.64%)  wps: 12,366  mfu: 0.38%
[rank7]:2024-08-06 03:22:00,340 - root - INFO - step: 13  loss:  6.9182  memory:  0.51GiB(0.64%)  wps: 14,370  mfu: 0.44%
[rank7]:2024-08-06 03:22:00,624 - root - INFO - step: 14  loss:  6.8948  memory:  0.51GiB(0.64%)  wps: 14,442  mfu: 0.44%
[rank7]:2024-08-06 03:22:00,907 - root - INFO - step: 15  loss:  6.8358  memory:  0.51GiB(0.64%)  wps: 14,514  mfu: 0.44%
[rank7]:2024-08-06 03:22:01,574 - root - INFO - step: 16  loss:  6.7653  memory:  0.51GiB(0.64%)  wps: 6,144  mfu: 0.19%
[rank7]:2024-08-06 03:22:02,209 - root - INFO - step: 17  loss:  6.7340  memory:  0.51GiB(0.64%)  wps: 6,453  mfu: 0.20%
[rank7]:2024-08-06 03:22:02,532 - root - INFO - step: 18  loss:  6.6874  memory:  0.51GiB(0.64%)  wps: 12,695  mfu: 0.39%
[rank7]:2024-08-06 03:22:02,863 - root - INFO - step: 19  loss:  6.6566  memory:  0.51GiB(0.64%)  wps: 12,406  mfu: 0.38%
[rank7]:2024-08-06 03:22:03,257 - root - INFO - step: 20  loss:  6.6629  memory:  0.51GiB(0.64%)  wps: 10,392  mfu: 0.32%
```
- with strided sharding:
```
[rank7]:2024-08-06 03:26:18,288 - root - INFO - step:  1  loss:  8.2069  memory:  0.50GiB(0.63%)  wps: 915  mfu: 0.03%
[rank7]:2024-08-06 03:26:19,084 - root - INFO - step:  2  loss:  8.1913  memory:  0.51GiB(0.64%)  wps: 5,144  mfu: 0.16%
[rank7]:2024-08-06 03:26:19,365 - root - INFO - step:  3  loss:  8.1148  memory:  0.51GiB(0.64%)  wps: 14,593  mfu: 0.44%
[rank7]:2024-08-06 03:26:19,698 - root - INFO - step:  4  loss:  7.9982  memory:  0.51GiB(0.64%)  wps: 12,328  mfu: 0.37%
[rank7]:2024-08-06 03:26:20,011 - root - INFO - step:  5  loss:  7.8382  memory:  0.51GiB(0.64%)  wps: 13,100  mfu: 0.40%
[rank7]:2024-08-06 03:26:20,498 - root - INFO - step:  6  loss:  7.6293  memory:  0.51GiB(0.64%)  wps: 8,423  mfu: 0.26%
[rank7]:2024-08-06 03:26:21,126 - root - INFO - step:  7  loss:  7.4454  memory:  0.51GiB(0.64%)  wps: 6,530  mfu: 0.20%
[rank7]:2024-08-06 03:26:21,472 - root - INFO - step:  8  loss:  7.3337  memory:  0.51GiB(0.64%)  wps: 11,843  mfu: 0.36%
[rank7]:2024-08-06 03:26:21,849 - root - INFO - step:  9  loss:  7.1960  memory:  0.51GiB(0.64%)  wps: 10,892  mfu: 0.33%
[rank7]:2024-08-06 03:26:22,229 - root - INFO - step: 10  loss:  7.1208  memory:  0.51GiB(0.64%)  wps: 10,798  mfu: 0.33%
>>>>>>>>>>>>>>>>>Checkpoint save & load<<<<<<<<<<<<<<<<<<<
[rank7]:2024-08-06 03:26:50,306 - root - INFO - step: 11  loss:  7.1222  memory:  0.51GiB(0.64%)  wps: 866  mfu: 0.03%
[rank7]:2024-08-06 03:26:50,632 - root - INFO - step: 12  loss:  7.1189  memory:  0.51GiB(0.64%)  wps: 12,589  mfu: 0.38%
[rank7]:2024-08-06 03:26:50,917 - root - INFO - step: 13  loss:  6.9646  memory:  0.51GiB(0.64%)  wps: 14,417  mfu: 0.44%
[rank7]:2024-08-06 03:26:51,217 - root - INFO - step: 14  loss:  6.9626  memory:  0.51GiB(0.64%)  wps: 13,680  mfu: 0.42%
[rank7]:2024-08-06 03:26:51,514 - root - INFO - step: 15  loss:  6.8694  memory:  0.51GiB(0.64%)  wps: 13,799  mfu: 0.42%
[rank7]:2024-08-06 03:26:52,207 - root - INFO - step: 16  loss:  6.7994  memory:  0.51GiB(0.64%)  wps: 5,910  mfu: 0.18%
[rank7]:2024-08-06 03:26:53,053 - root - INFO - step: 17  loss:  6.7634  memory:  0.51GiB(0.64%)  wps: 4,847  mfu: 0.15%
[rank7]:2024-08-06 03:26:53,370 - root - INFO - step: 18  loss:  6.7233  memory:  0.51GiB(0.64%)  wps: 12,915  mfu: 0.39%
[rank7]:2024-08-06 03:26:53,686 - root - INFO - step: 19  loss:  6.7054  memory:  0.51GiB(0.64%)  wps: 12,995  mfu: 0.39%
[rank7]:2024-08-06 03:26:54,059 - root - INFO - step: 20  loss:  6.7130  memory:  0.51GiB(0.64%)  wps: 10,991  mfu: 0.33%
```
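
The `Checkpoint save & load` marker between steps 10 and 11 is where the test round-trips the sharded state through DCP. A minimal sketch of that round trip (not the actual `test_runner.py` internals; `model` stands for the parallelized module and the checkpoint path is a placeholder):

```python
import torch.distributed.checkpoint as dcp

# Save and reload the sharded model state with DCP; with strided sharding,
# the 2D/3D DTensor layout round-trips correctly through the checkpoint.
state = {"model": model.state_dict()}
dcp.save(state, checkpoint_id="/tmp/ckpt")
dcp.load(state, checkpoint_id="/tmp/ckpt")
model.load_state_dict(state["model"])
```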

**When to merge**
Once pytorch/pytorch#130760 is in the nightly build.


[ghstack-poisoned]
XilunWu committed Aug 13, 2024
2 parents a5152af + be21aa8 commit 306c369
Showing 1 changed file with 6 additions and 3 deletions.
`torchtitan/parallelisms/parallelize_llama.py` (6 additions, 3 deletions)

```diff
@@ -84,6 +84,7 @@ def parallelize_llama(
             reduce_dtype=TORCH_DTYPE_MAP[
                 job_config.training.mixed_precision_reduce
             ],
+            tp_enabled=parallel_dims.tp_enabled,
             pp_enabled=parallel_dims.pp_enabled,
         )
     else:
@@ -290,6 +291,7 @@ def apply_fsdp(
     dp_mesh: DeviceMesh,
     param_dtype: torch.dtype,
     reduce_dtype: torch.dtype,
+    tp_enabled: bool,
     pp_enabled: bool,
 ):
     """
@@ -298,6 +300,10 @@
     mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype)
     fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
 
+    if tp_enabled:
+        # check if strided sharding is enabled, which is necessary for 2D/3D DCP
+        check_strided_sharding_enabled()
+
     for layer_id, transformer_block in model.layers.items():
         if pp_enabled:
             # For PP, do not reshard after forward to avoid per-microbatch
@@ -314,9 +320,6 @@
         )
     fully_shard(model, **fsdp_config, reshard_after_forward=not pp_enabled)
 
-    # check if strided sharding is enabled, which is necessary for 2D/3D DCP
-    check_strided_sharding_enabled()
-
     logger.info("Applied FSDP to the model")
```
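
For reference, "2D" here means FSDP composed with TP over a two-dimensional device mesh (3D adds PP). A sketch of such a mesh, with example sizes for 8 GPUs (the dim names and split are illustrative, not the exact torchtitan setup):

```python
from torch.distributed.device_mesh import init_device_mesh

# 4-way data parallel x 2-way tensor parallel on 8 GPUs (example sizes);
# dp_mesh is what apply_fsdp shards over, tp_mesh drives tensor parallelism.
mesh_2d = init_device_mesh("cuda", (4, 2), mesh_dim_names=("dp", "tp"))
dp_mesh, tp_mesh = mesh_2d["dp"], mesh_2d["tp"]
```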
