We do not plan to support DDP + TP, as we have not identified any major use cases for this combination. When working with large models, it is more common to use FSDP + TP instead of DDP + TP. Additionally, FSDP offers several features that are not available in DDP, such as fp8 support. Therefore, we believe DDP is better suited for smaller models.

In torchtitan, we enabled DDP primarily for sanity-check purposes, such as verifying parallelism with the 8B model and a very small batch size, so we did not verify the correctness of DDP + TP.
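To illustrate the recommended combination, here is a minimal sketch of TP composed with FSDP over a 2D device mesh (the mesh shape, toy MLP, and script layout are illustrative assumptions, not torchtitan code):

```python
# Minimal sketch (not torchtitan code): TP composed with FSDP over a 2D mesh.
# Run under `torchrun --nproc_per_node=8 this_script.py`.
import os

import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)


class ToyMLP(nn.Module):
    def __init__(self, dim: int = 1024):
        super().__init__()
        self.w1 = nn.Linear(dim, 4 * dim)
        self.w2 = nn.Linear(4 * dim, dim)

    def forward(self, x):
        return self.w2(torch.relu(self.w1(x)))


torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

# Assume 8 GPUs: 2-way data parallel ("dp") x 4-way tensor parallel ("tp").
mesh_2d = init_device_mesh("cuda", (2, 4), mesh_dim_names=("dp", "tp"))
dp_mesh, tp_mesh = mesh_2d["dp"], mesh_2d["tp"]

model = ToyMLP().cuda()

# 1) Shard each layer across the tp dimension with DTensor tensor parallelism.
model = parallelize_module(
    model,
    tp_mesh,
    {"w1": ColwiseParallel(), "w2": RowwiseParallel()},
)

# 2) Then shard parameters across the dp dimension with FSDP on top of TP.
model = FSDP(model, device_mesh=dp_mesh, use_orig_params=True)
```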
Thanks for the reply! I take it that FSDP + TP should be the primary (if not only) combination to use, especially for LLMs.
Currently, when there are two device meshes (`tp` and `dp`), torchtitan chooses FSDP as the only backend for DP. Ref: torchtitan/torchtitan/parallelisms/parallelize_llama.py, lines 97 to 98 in d2a4904.

However, `replicate` should support a >1D mesh and be usable with TP enabled. Ref.

Q1: Why does torchtitan not support DDP (`replicate`) + TP? Is it only an implementation choice?
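To make Q1 concrete, here is roughly what I mean by using `replicate` over the `dp` dimension of a 2D mesh together with TP (a sketch with a toy model; passing `device_mesh` through `replicate` is my assumption and may differ across PyTorch versions, and this is exactly the combination whose correctness is in question):

```python
# Sketch of the combination in question: replicate over "dp" + TP over "tp".
# Assumes replicate() accepts/forwards a device_mesh argument; not verified.
import torch.nn as nn
from torch.distributed._composable import replicate
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)

mesh_2d = init_device_mesh("cuda", (2, 4), mesh_dim_names=("dp", "tp"))
dp_mesh, tp_mesh = mesh_2d["dp"], mesh_2d["tp"]

model = nn.Sequential(nn.Linear(256, 1024), nn.ReLU(), nn.Linear(1024, 256)).cuda()

# Tensor-parallelize the two linear layers over the tp dimension first ...
parallelize_module(
    model, tp_mesh, {"0": ColwiseParallel(), "2": RowwiseParallel()}
)
# ... then replicate (composable DDP) over the dp dimension.
replicate(model, device_mesh=dp_mesh)
```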
I have handwritten DDP + TP in torchtitan and, surprisingly, found that the loss never goes down. It seems there are no gradients after `loss.backward()`. To reproduce, use the branch above and run `run_llama_train.sh` on an 8-GPU machine.

Q2: Is it a bug or intended behavior that DDP + TP is unsupported and results in missing gradients?
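A quick way to confirm that gradients are missing after the backward pass (illustrative snippet, not from the branch):

```python
# After loss.backward(), list parameters that received no gradient.
for name, param in model.named_parameters():
    if param.requires_grad and param.grad is None:
        print(f"no gradient for: {name}")
```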
And the collect_env output:
Thanks in advance!

P.S. `DistributedDataParallel` (the class), rather than `replicate`, behaves well.