Update base for Update on "Re-enable FSDP+TP w/ strided sharding"
**Summary**
1. Check whether users are running a nightly-build PyTorch that includes DTensor strided sharding when 2D/3D parallelism is used, and print a warning if not (a rough sketch of such a check follows this list).
2. Remove the temporary re-enablement added in #460.
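
As an illustration of item 1, the sketch below parses `torch.__version__` and compares it against a hypothetical minimum nightly date; the function name, cutoff date, and warning text are assumptions for illustration, not the exact code in this commit.
```
import logging

import torch

logger = logging.getLogger(__name__)

# hypothetical cutoff: first nightly assumed to include DTensor strided sharding
# (pytorch/pytorch#130760); the actual commit may use a different date or mechanism
MIN_NIGHTLY_DATE = "20240809"


def warn_if_strided_sharding_missing(using_2d_or_3d: bool) -> None:
    if not using_2d_or_3d:
        return
    version = torch.__version__  # e.g. "2.5.0.dev20240806+cu121"
    nightly_date = version.split(".dev")[-1][:8] if ".dev" in version else ""
    if nightly_date < MIN_NIGHTLY_DATE:  # also warns on stable releases without the feature
        logger.warning(
            "Detected torch %s, which may not include DTensor strided sharding; "
            "2D/3D checkpoints may not reshard correctly. Please update to a recent "
            "nightly build.",
            version,
        )
```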

**Test**
Command: `python test_runner.py outputs --test pp_dp_tp --ngpu 8`
GPUs: A100
Output:
- without strided sharding:
```
[rank7]:2024-08-06 03:21:26,706 - root - INFO - step:  2  loss:  8.1652  memory:  0.51GiB(0.64%)  wps: 8,250  mfu: 0.25%
[rank7]:2024-08-06 03:21:27,013 - root - INFO - step:  3  loss:  8.0951  memory:  0.51GiB(0.64%)  wps: 13,358  mfu: 0.41%
[rank7]:2024-08-06 03:21:27,309 - root - INFO - step:  4  loss:  7.9748  memory:  0.51GiB(0.64%)  wps: 13,865  mfu: 0.42%
[rank7]:2024-08-06 03:21:27,582 - root - INFO - step:  5  loss:  7.8025  memory:  0.51GiB(0.64%)  wps: 15,057  mfu: 0.46%
[rank7]:2024-08-06 03:21:28,076 - root - INFO - step:  6  loss:  7.5612  memory:  0.51GiB(0.64%)  wps: 8,300  mfu: 0.25%
[rank7]:2024-08-06 03:21:28,608 - root - INFO - step:  7  loss:  7.3649  memory:  0.51GiB(0.64%)  wps: 7,705  mfu: 0.23%
[rank7]:2024-08-06 03:21:28,927 - root - INFO - step:  8  loss:  7.2946  memory:  0.51GiB(0.64%)  wps: 12,832  mfu: 0.39%
[rank7]:2024-08-06 03:21:29,251 - root - INFO - step:  9  loss:  7.1311  memory:  0.51GiB(0.64%)  wps: 12,669  mfu: 0.38%
[rank7]:2024-08-06 03:21:29,627 - root - INFO - step: 10  loss:  7.0540  memory:  0.51GiB(0.64%)  wps: 10,918  mfu: 0.33%
>>>>>>>>>>>>>>>>>Checkpoint save & load<<<<<<<<<<<<<<<<<<<
[rank7]:2024-08-06 03:21:59,723 - root - INFO - step: 11  loss:  7.0822  memory:  0.51GiB(0.64%)  wps: 1,139  mfu: 0.03%
[rank7]:2024-08-06 03:22:00,054 - root - INFO - step: 12  loss:  7.0508  memory:  0.51GiB(0.64%)  wps: 12,366  mfu: 0.38%
[rank7]:2024-08-06 03:22:00,340 - root - INFO - step: 13  loss:  6.9182  memory:  0.51GiB(0.64%)  wps: 14,370  mfu: 0.44%
[rank7]:2024-08-06 03:22:00,624 - root - INFO - step: 14  loss:  6.8948  memory:  0.51GiB(0.64%)  wps: 14,442  mfu: 0.44%
[rank7]:2024-08-06 03:22:00,907 - root - INFO - step: 15  loss:  6.8358  memory:  0.51GiB(0.64%)  wps: 14,514  mfu: 0.44%
[rank7]:2024-08-06 03:22:01,574 - root - INFO - step: 16  loss:  6.7653  memory:  0.51GiB(0.64%)  wps: 6,144  mfu: 0.19%
[rank7]:2024-08-06 03:22:02,209 - root - INFO - step: 17  loss:  6.7340  memory:  0.51GiB(0.64%)  wps: 6,453  mfu: 0.20%
[rank7]:2024-08-06 03:22:02,532 - root - INFO - step: 18  loss:  6.6874  memory:  0.51GiB(0.64%)  wps: 12,695  mfu: 0.39%
[rank7]:2024-08-06 03:22:02,863 - root - INFO - step: 19  loss:  6.6566  memory:  0.51GiB(0.64%)  wps: 12,406  mfu: 0.38%
[rank7]:2024-08-06 03:22:03,257 - root - INFO - step: 20  loss:  6.6629  memory:  0.51GiB(0.64%)  wps: 10,392  mfu: 0.32%
```
- with strided sharding:
```
[rank7]:2024-08-06 03:26:18,288 - root - INFO - step:  1  loss:  8.2069  memory:  0.50GiB(0.63%)  wps: 915  mfu: 0.03%
[rank7]:2024-08-06 03:26:19,084 - root - INFO - step:  2  loss:  8.1913  memory:  0.51GiB(0.64%)  wps: 5,144  mfu: 0.16%
[rank7]:2024-08-06 03:26:19,365 - root - INFO - step:  3  loss:  8.1148  memory:  0.51GiB(0.64%)  wps: 14,593  mfu: 0.44%
[rank7]:2024-08-06 03:26:19,698 - root - INFO - step:  4  loss:  7.9982  memory:  0.51GiB(0.64%)  wps: 12,328  mfu: 0.37%
[rank7]:2024-08-06 03:26:20,011 - root - INFO - step:  5  loss:  7.8382  memory:  0.51GiB(0.64%)  wps: 13,100  mfu: 0.40%
[rank7]:2024-08-06 03:26:20,498 - root - INFO - step:  6  loss:  7.6293  memory:  0.51GiB(0.64%)  wps: 8,423  mfu: 0.26%
[rank7]:2024-08-06 03:26:21,126 - root - INFO - step:  7  loss:  7.4454  memory:  0.51GiB(0.64%)  wps: 6,530  mfu: 0.20%
[rank7]:2024-08-06 03:26:21,472 - root - INFO - step:  8  loss:  7.3337  memory:  0.51GiB(0.64%)  wps: 11,843  mfu: 0.36%
[rank7]:2024-08-06 03:26:21,849 - root - INFO - step:  9  loss:  7.1960  memory:  0.51GiB(0.64%)  wps: 10,892  mfu: 0.33%
[rank7]:2024-08-06 03:26:22,229 - root - INFO - step: 10  loss:  7.1208  memory:  0.51GiB(0.64%)  wps: 10,798  mfu: 0.33%
>>>>>>>>>>>>>>>>>Checkpoint save & load<<<<<<<<<<<<<<<<<<<
[rank7]:2024-08-06 03:26:50,306 - root - INFO - step: 11  loss:  7.1222  memory:  0.51GiB(0.64%)  wps: 866  mfu: 0.03%
[rank7]:2024-08-06 03:26:50,632 - root - INFO - step: 12  loss:  7.1189  memory:  0.51GiB(0.64%)  wps: 12,589  mfu: 0.38%
[rank7]:2024-08-06 03:26:50,917 - root - INFO - step: 13  loss:  6.9646  memory:  0.51GiB(0.64%)  wps: 14,417  mfu: 0.44%
[rank7]:2024-08-06 03:26:51,217 - root - INFO - step: 14  loss:  6.9626  memory:  0.51GiB(0.64%)  wps: 13,680  mfu: 0.42%
[rank7]:2024-08-06 03:26:51,514 - root - INFO - step: 15  loss:  6.8694  memory:  0.51GiB(0.64%)  wps: 13,799  mfu: 0.42%
[rank7]:2024-08-06 03:26:52,207 - root - INFO - step: 16  loss:  6.7994  memory:  0.51GiB(0.64%)  wps: 5,910  mfu: 0.18%
[rank7]:2024-08-06 03:26:53,053 - root - INFO - step: 17  loss:  6.7634  memory:  0.51GiB(0.64%)  wps: 4,847  mfu: 0.15%
[rank7]:2024-08-06 03:26:53,370 - root - INFO - step: 18  loss:  6.7233  memory:  0.51GiB(0.64%)  wps: 12,915  mfu: 0.39%
[rank7]:2024-08-06 03:26:53,686 - root - INFO - step: 19  loss:  6.7054  memory:  0.51GiB(0.64%)  wps: 12,995  mfu: 0.39%
[rank7]:2024-08-06 03:26:54,059 - root - INFO - step: 20  loss:  6.7130  memory:  0.51GiB(0.64%)  wps: 10,991  mfu: 0.33%
```

When to merge:
once pytorch/pytorch#130760 is in the nightly build.


[ghstack-poisoned]
XilunWu committed Aug 12, 2024
2 parents f9e114b + fa7fe1e commit f1f2d0f
Showing 26 changed files with 2,164 additions and 45,139 deletions.
1 change: 1 addition & 0 deletions .ci/docker/requirements.txt
@@ -1,4 +1,5 @@
torch >= 2.3.0
torchdata >= 0.8.0
datasets >= 2.19.0
tomli >= 1.1.0 ; python_version < "3.11"
tensorboard
1 change: 0 additions & 1 deletion .github/workflows/integration_test_4gpu.yaml
@@ -38,7 +38,6 @@ jobs:
pip config --user set global.progress_bar off
python -m pip install --force-reinstall --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121
python -m pip install --pre torchdata --index-url https://download.pytorch.org/whl/nightly/
USE_CPP=0 python -m pip install git+https://github.com/pytorch/ao.git
mkdir artifacts-to-be-uploaded
python ./test_runner.py artifacts-to-be-uploaded --ngpu 4
1 change: 0 additions & 1 deletion .github/workflows/integration_test_8gpu.yaml
@@ -37,6 +37,5 @@ jobs:
pip config --user set global.progress_bar off
python -m pip install --force-reinstall --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121
python -m pip install --pre torchdata --index-url https://download.pytorch.org/whl/nightly/
mkdir artifacts-to-be-uploaded
python ./test_runner.py artifacts-to-be-uploaded --ngpu 8
1 change: 0 additions & 1 deletion .github/workflows/unit_test_cpu.yaml
@@ -25,5 +25,4 @@ jobs:
pip config --user set global.progress_bar off
pip install --force-reinstall --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu
pip install --pre torchdata --index-url https://download.pytorch.org/whl/nightly
pytest test --cov=. --cov-report=xml --durations=20 -vv
11 changes: 5 additions & 6 deletions README.md
@@ -43,18 +43,18 @@ Currently we showcase pre-training **Llama 3 and Llama 2** LLMs of various sizes
6. Learning rate scheduler, meta init, Optional Fused RMSNorm
7. All options easily configured via [toml files](train_configs/)
8. [Interoperable checkpoints](docs/checkpoint.md) which can be loaded directly into [`torchtune`](https://github.com/pytorch/torchtune) for fine tuning
9. [Float8 support](docs/float8.md)

We report our [Performance](docs/performance.md) verified on 64 A100 GPUs


### Coming soon

1. Async checkpointing
2. Float8 support
3. Context Parallel
4. 3D Pipeline Parallel
5. `torch.compile` support
6. Scalable data loading solution
2. Context Parallel
3. 3D Pipeline Parallel
4. `torch.compile` support
5. Scalable data loading solution


## Installation
@@ -64,7 +64,6 @@ git clone https://github.com/pytorch/torchtitan
cd torchtitan
pip install -r requirements.txt
pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121 # or cu118
pip3 install --pre torchdata --index-url https://download.pytorch.org/whl/nightly
```

### Downloading a tokenizer
22 changes: 22 additions & 0 deletions docs/composability.md
@@ -0,0 +1,22 @@
# Building a Clean, Readable Distributed LLM
One of the main goals for TorchTitan is to provide a distributed LLM implementation that is not only high performance, but also uses native PyTorch techniques and readable code. The challenge is how to compose so many individual library components (FSDP, TP, PP, FP8, Compile, DCP, ...) while avoiding too many changes to the model internals in the process. A lot of the work is behind the scenes: designing individual components to make fewer assumptions, use common abstractions (e.g. DTensor), and generally 'get along'. But we found a few tweaks to the model code invaluable as well, and want to share those changes and the rationale for them.



# Making the model "pipeline friendly"
When applying Pipeline Parallelism, you will have to construct nn.Module objects representing the portion of the model that runs on a given pipeline stage. Whether you plan to manually edit your model code, or use techniques like tracing to extract model chunks, a few changes to the original model code can go a long way to making this process easier.

### Simplifying the top-level model forward
Most likely, you can write your model in such a way that the top-level nn.Module owns a sequence of child modules that it calls during forward, delegating most of the complexity to the child module forwards. If you can reduce your top level forward to mostly a for-loop over child module calls, then you'll simplify the pipeline-partitioning task to choosing the set of submodules to keep per stage. If you have non-trivial logic in the top-level forward, you'll have to find a way to patch that logic back onto the resulting pipeline stage model, which can be annoying.

Example ([PR #321](https://github.com/pytorch/torchtitan/pull/321)):
We used to slice the `freqs_cis` buffer by `seq_len` in the top-level forward, pass that into child modules, and expect that inside the child modules the `seq_len` would match up with the size of other local tensors. But at PP-splitting time we don't know whether TP will be applied, which could create a mismatch. It's just as easy to perform the `freqs_cis` slicing inside the child submodule, using the runtime-accurate local `seq_len`, and this sidesteps the issue at PP slicing time.
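
A minimal sketch of that change (with stand-in shapes and module names, not the exact torchtitan code): the block slices `freqs_cis` by its own input's `seq_len`, so the slice always matches the local, possibly TP-sharded, tensors.
```
import torch
import torch.nn as nn


class TransformerBlock(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.mlp = nn.Linear(dim, dim)  # stand-in for attention + feed-forward

    def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
        # slice by this block's own runtime seq_len instead of in the parent forward
        seq_len = x.shape[1]
        freqs = freqs_cis[:seq_len]  # (seq_len, dim), broadcasts over the batch dim
        return self.mlp(x + freqs)


block = TransformerBlock(dim=16)
x = torch.randn(2, 10, 16)       # (batch, seq_len, dim)
freqs_cis = torch.randn(64, 16)  # precomputed for max_seq_len=64
out = block(x, freqs_cis)        # the block handles the slicing itself
```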

Example ([PR #322](https://github.com/pytorch/torchtitan/pull/322)): We decided to reuse the top-level model object on every PP stage, just delete the layers we don't want, and make sure that the top-level forward still does the right thing. This means we don't have to write a separate runtime pp_forward that glues together child modules per stage. The first change was using an `nn.ModuleDict` instead of an `nn.ModuleList` to store layers. This preserves layer Fully Qualified Names (FQNs) even when deleting some layers - e.g. `layers.1` stays `layers.1` even if you remove `layers.0`, which isn't true for a list - and this matters for checkpoint save/load. Preserving FQNs is a requirement for using Distributed Checkpointing (DCP), since it uses FQNs as globally unique IDs for sharding metadata. The second change was making the input and output layers optional: if the layer exists, we run it; otherwise we feed the input through to bypass it. With these two changes, we can just (meta-)initialize the whole model, delete the unused parts per stage, then materialize the remaining part on GPU before loading a checkpoint.
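
A rough sketch of those two changes, using placeholder blocks instead of the real model code: layers live in an `nn.ModuleDict` so FQNs survive deletion, the embedding and output layers are optional, and the top-level forward stays a plain loop over whatever layers the stage kept.
```
import torch
import torch.nn as nn


class Transformer(nn.Module):
    def __init__(self, vocab_size: int = 1000, dim: int = 16, n_layers: int = 4):
        super().__init__()
        # ModuleDict keeps "layers.1" named "layers.1" even after "layers.0" is deleted
        self.layers = nn.ModuleDict(
            {str(i): nn.Linear(dim, dim) for i in range(n_layers)}  # stand-in blocks
        )
        # optional: non-first / non-last pipeline stages set these to None
        self.tok_embeddings = nn.Embedding(vocab_size, dim)
        self.output = nn.Linear(dim, vocab_size)

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        h = self.tok_embeddings(tokens) if self.tok_embeddings is not None else tokens
        for layer in self.layers.values():
            h = layer(h)
        return self.output(h) if self.output is not None else h


# per-stage pruning: delete the parts this stage does not own
stage = Transformer()
del stage.layers["0"]        # "layers.1" keeps its FQN for checkpoint save/load
stage.tok_embeddings = None  # not the first stage
print(list(stage.state_dict().keys()))  # FQNs of the remaining layers are unchanged
```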

# Using a seed checkpoint for init
Initializing the pipeline-parallel model is challenging because we assume the model could be so large that it does not fit on a local GPU (or possibly even on CPU), and we also want to use the (bitwise) same initialization as we use for 1D or 2D parallel models, to ease debugging or comparisons between runs. It's not that easy to rewrite the original model's `init_weights` function to be tolerant of initializing only some layers, while also serializing initialization operations globally for a consistent RNG order.

For now, we sidestep all these problems with a simple but brutal solution: initialize the whole model on some CPU instance, save a checkpoint file, and then lean on Distributed Checkpointing's "load" functionality to initialize the FQNs that are present on a given PP stage after stage creation. As future work, we are considering adding a more elaborate initialization scheme to `torch.pipelining`.
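
A schematic sketch of that flow, under assumptions (the directory name, the tiny stand-in model, and running DCP in a single process are all for illustration; the real job builds and prunes the actual Transformer):
```
import torch
import torch.nn as nn
import torch.distributed.checkpoint as dcp

SEED_CKPT_DIR = "outputs/seed_checkpoint"  # hypothetical location


def build_model() -> nn.Module:
    # tiny stand-in; layers are keyed by name so FQNs survive deletion
    m = nn.Module()
    m.layers = nn.ModuleDict({str(i): nn.Linear(8, 8) for i in range(4)})
    return m


# one-off CPU job: materialize the whole model and write the seed checkpoint
dcp.save({"model": build_model().state_dict()}, checkpoint_id=SEED_CKPT_DIR)

# on each pipeline stage: build the pruned stage model, then fill it from the seed
stage = build_model()
del stage.layers["0"]                          # this stage does not own layer 0
state = {"model": stage.state_dict()}
dcp.load(state, checkpoint_id=SEED_CKPT_DIR)   # reads only the requested FQNs
stage.load_state_dict(state["model"])
```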

One issue with seed checkpoints is that we rely on initializing _every_ model state from the checkpoint, which means the model can't have any non-persistent buffers, or else we have to specially initialize those in `train.py` after pipeline splitting. `freqs_cis` was originally a non-persistent buffer, and we changed this to persistent in order to load it from the seed checkpoint.
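
For illustration (module name assumed), the difference is just the `persistent` flag passed to `register_buffer`; only persistent buffers appear in `state_dict()` and can therefore be restored from the seed checkpoint.
```
import torch
import torch.nn as nn


class WithRotary(nn.Module):
    def __init__(self, max_seq_len: int = 8, dim: int = 4):
        super().__init__()
        # persistent=True puts the buffer into state_dict(), so DCP can load it;
        # persistent=False would leave it out and require special init after PP splitting
        self.register_buffer("freqs_cis", torch.randn(max_seq_len, dim), persistent=True)


print("freqs_cis" in WithRotary().state_dict())  # True only for persistent buffers
```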
18 changes: 18 additions & 0 deletions docs/float8.md
@@ -0,0 +1,18 @@
## Enable Float8 Training on H100s

Please install the latest [TorchAO](https://github.com/pytorch/ao/tree/main/torchao/float8) to support the float8 dtype
```
USE_CPP=0 python -m pip install git+https://github.com/pytorch/ao.git
```

Launch a training job with the following command (or alternatively set the configs in toml files)
```
CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh --float8.enable_float8_linear --float8.enable_fsdp_float8_all_gather --float8.precompute_float8_dynamic_scale_for_fsdp
```
* `--float8.enable_float8_linear`: swap `nn.Linear` with `Float8Linear` to perform float8 matmul.
* `--float8.enable_fsdp_float8_all_gather`: cast `Float8Linear.weight` from high precision to float8 before FSDP all-gather so we can communicate in float8 to save bandwidth.
* `--float8.precompute_float8_dynamic_scale_for_fsdp` (optional): communicate AMAX/scales efficiently in a single all-reduce for all parameters instead of many small all-reduces, one per parameter.

For parallelisms, we support float8 all-gather for FSDP (optional) and for TP (by default for `Float8Linear`).

For scaling strategy, we currently support tensor-wise scaling with dynamic scales, and are actively working on tensor-wise scaling with delayed scales. Row-wise scaling is under exploration.
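
As a rough sketch of what these flags map to, here is the torchao float8 API as of mid-2024; the exact calls and wiring in this repo's `Float8Handler` may differ, and the FSDP wrapping step is elided.
```
import torch
import torch.nn as nn
from torchao.float8 import (
    Float8LinearConfig,
    convert_to_float8_training,
    precompute_float8_dynamic_scale_for_fsdp,
)

model = nn.Sequential(nn.Linear(2048, 2048), nn.Linear(2048, 2048)).cuda()

# --float8.enable_float8_linear and --float8.enable_fsdp_float8_all_gather:
# swap nn.Linear for Float8Linear and all-gather weights in float8 under FSDP
config = Float8LinearConfig(enable_fsdp_float8_all_gather=True)
convert_to_float8_training(model, config=config)

# ... wrap with FSDP2 (fully_shard) and train; then, once per optimizer step:

# --float8.precompute_float8_dynamic_scale_for_fsdp: compute AMAX/scales for all
# float8 parameters with one fused all-reduce instead of one per parameter
precompute_float8_dynamic_scale_for_fsdp(model)
```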
33 changes: 12 additions & 21 deletions estimation.py
@@ -122,33 +122,25 @@ def loss_fn(pred, labels):
f"Building {model_name} {job_config.model.flavor} with {model_config}"
)
with torch.device("meta"):
whole_model = model_cls.from_model_args(model_config)
model = model_cls.from_model_args(model_config)

# a no-op handler if float8 is not enabled
float8_handler = Float8Handler(job_config, parallel_dims)
# swap to Float8Linear based on float8 configs
float8_handler.convert_to_float8_training(whole_model)
float8_handler.convert_to_float8_training(model)

# apply PT-D DP/TP parallelisms and activation checkpointing
model_parts = [whole_model]
model_parts = [
models_parallelize_fns[model_name](m, world_mesh, parallel_dims, job_config)
for m in model_parts
]

init_device = "cuda"
for model in model_parts:
model.to_empty(device=init_device)
models_parallelize_fns[model_name](model, world_mesh, parallel_dims, job_config)

model.to_empty(device="cuda")
if not active_fake_mode():
whole_model.init_weights()
model.init_weights()
model.train()

# build optimizer after applying parallelisms to the model
optimizers = build_optimizers(model_parts, job_config)
optimizers = build_optimizers([model], job_config)
lr_schedulers = build_lr_schedulers(optimizers.optimizers, job_config)

for model in model_parts:
model.train()
logger.info(f"Vocab size: {model_config.vocab_size}")
# Create a dummy batch instead of loading from a dataset
batch = (
@@ -165,24 +157,23 @@ def loss_fn(pred, labels):
device="cuda",
),
)
fsdp_memtracker = FSDPMemTracker(mod=whole_model, optm=optimizers.optimizers[0])
fsdp_memtracker = FSDPMemTracker(mod=model, optm=optimizers.optimizers[0])
fsdp_memtracker.track_inputs(batch)

with fsdp_memtracker:
for iter_idx in range(2):
input_ids, labels = batch
# train step
with train_context():
pred = whole_model(input_ids)
pred = model(input_ids)
loss = loss_fn(pred, labels)
del pred
loss.backward()

# clip gradients
for model in model_parts:
torch.nn.utils.clip_grad_norm_(
model.parameters(), job_config.training.max_norm, foreach=True
)
torch.nn.utils.clip_grad_norm_(
model.parameters(), job_config.training.max_norm, foreach=True
)
# sync float8 amaxes and scales
float8_handler.sync_float8_amax_and_scale_history(model)
# optimizer step
