Fix 8gpu PP failure due to 2D DCP disablement
DCP recently added safety checks that block its use for 2D/3D, since strided
sharding (a feature needed for safe 2D/3D resharding) is not ready yet.

PP uses DCP to load a seed checkpoint.  Disabling the safety mechanism
is enough to make 3D/PP still work (for the case where we train from the
beginning or do not re-shard).

(Resharding refers to saving a checkpoint from one world
size/parallelism config and loading/resuming under a different one).
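
For context, a rough sketch of the seed-checkpoint round trip that PP relies
on. This is illustrative only, not torchtitan's actual checkpointing code;
the toy model, checkpoint path, and gloo process group are assumptions, and
it should be run under torchrun so the distributed environment variables are
set:

    import torch
    import torch.distributed as dist
    import torch.distributed.checkpoint as dcp
    from torch.distributed.checkpoint.state_dict import (
        get_model_state_dict,
        set_model_state_dict,
    )

    dist.init_process_group(backend="gloo")
    model = torch.nn.Linear(8, 8)

    # Save a "seed" checkpoint: write the model's state dict to a directory.
    dcp.save(get_model_state_dict(model), checkpoint_id="/tmp/seed_ckpt")

    # Later (possibly in a different job), load it back in place and apply it.
    state_dict = get_model_state_dict(model)
    dcp.load(state_dict, checkpoint_id="/tmp/seed_ckpt")
    set_model_state_dict(model, state_dict)

    dist.destroy_process_group()

Resharding would mean doing that load under a different world size or
parallelism layout than the save, which is exactly what DCP cannot yet
guarantee for 2D/3D without strided sharding.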

ghstack-source-id: c069d2186c79517c72f5b3c99485cebdc15df08f
Pull Request resolved: pytorch#460
wconstab committed Jul 26, 2024
1 parent aec4788 commit c9ea259
Showing 1 changed file with 11 additions and 0 deletions.

torchtitan/parallelisms/parallelize_llama.py

@@ -507,6 +507,17 @@ def apply_fsdp(
        model, **fsdp_config, reshard_after_forward=not parallel_dims.pp_enabled
    )

    if parallel_dims.pp_enabled:
        # TODO
        # This PR https://github.com/pytorch/pytorch/pull/129519 added a safety check to avoid using 2D/3D DCP since
        # without strided sharding, DCP can not safely support resharding for 2D/3D. However, for PP to work, even
        # without resharding, we load a seed-checkpoint and need to disable the safety mechanism. This hack should be
        # removed after strided sharding is landed in DCP.
        for module in model.modules():
            assert len(module._load_state_dict_pre_hooks) <= 1
            module._load_state_dict_pre_hooks.clear()
            assert len(module._state_dict_pre_hooks) <= 1
            module._state_dict_pre_hooks.clear()
    logger.info("Applied FSDP to the model")
    return model
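
For reference, a minimal standalone sketch of the per-module hook
dictionaries this workaround empties. The hook below is hypothetical and
only illustrates the mechanism; the DCP safety check ends up in these same
per-module OrderedDicts, which is why clearing them disables it:

    import torch

    def pre_hook(module, prefix, keep_vars):
        # A real safety check would validate or raise here; this stand-in just logs.
        print(f"state_dict pre-hook fired for prefix '{prefix}'")

    m = torch.nn.Linear(4, 4)
    m.register_state_dict_pre_hook(pre_hook)
    assert len(m._state_dict_pre_hooks) == 1

    m._state_dict_pre_hooks.clear()  # same pattern as the workaround above
    m.state_dict()  # the hook no longer fires

The load-side dictionary (_load_state_dict_pre_hooks) behaves the same way
for hooks that run before load_state_dict.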
