From c9ea2590482c59773ba7d537c4ae6e2000ef6af4 Mon Sep 17 00:00:00 2001
From: Will Constable
Date: Mon, 15 Jul 2024 12:11:50 -0700
Subject: [PATCH] Fix 8gpu PP failure due to 2D DCP disablement

DCP recently added safeties to avoid using it for 2D/3D, since strided
sharding (a feature needed for safe 2D/3D resharding) is not ready yet.
PP uses DCP to load a seed checkpoint. Disabling the safety mechanism
is enough to make 3D/PP work for the case where we train from the
beginning or do not re-shard. (Resharding refers to saving a checkpoint
under one world size/parallelism config and loading/resuming under a
different one.)

ghstack-source-id: c069d2186c79517c72f5b3c99485cebdc15df08f
Pull Request resolved: https://github.com/pytorch/torchtitan/pull/460
---
 torchtitan/parallelisms/parallelize_llama.py | 11 +++++++++++
 1 file changed, 11 insertions(+)

diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py
index 31eabc6c..3d123953 100644
--- a/torchtitan/parallelisms/parallelize_llama.py
+++ b/torchtitan/parallelisms/parallelize_llama.py
@@ -507,6 +507,17 @@ def apply_fsdp(
         model, **fsdp_config, reshard_after_forward=not parallel_dims.pp_enabled
     )
 
+    if parallel_dims.pp_enabled:
+        # TODO
+        # This PR https://github.com/pytorch/pytorch/pull/129519 added a safety check to avoid using 2D/3D DCP since,
+        # without strided sharding, DCP cannot safely support resharding for 2D/3D. However, for PP to work even
+        # without resharding, we load a seed checkpoint and need to disable the safety mechanism. This hack should
+        # be removed once strided sharding lands in DCP.
+        for module in model.modules():
+            assert len(module._load_state_dict_pre_hooks) <= 1
+            module._load_state_dict_pre_hooks.clear()
+            assert len(module._state_dict_pre_hooks) <= 1
+            module._state_dict_pre_hooks.clear()
     logger.info("Applied FSDP to the model")
     return model
 
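--
For reviewers: below is a minimal sketch (not part of the patch) of the
seed-checkpoint load that the hook-clearing hack unblocks. The checkpoint
path and state-dict layout are illustrative assumptions, not torchtitan's
actual checkpoint wiring; it only shows where the safety hooks cleared by
this patch would otherwise fire on 2D/3D-sharded parameters.

import torch.distributed.checkpoint as dcp

# `model` is assumed to be the sharded module returned by apply_fsdp;
# the checkpoint path is a made-up example, not torchtitan's real layout.
state = {"model": model.state_dict()}  # the _state_dict_pre_hooks safety check would fire here
dcp.load(state, checkpoint_id="outputs/checkpoint/step-0")  # loads in place into the sharded tensors
model.load_state_dict(state["model"])  # the _load_state_dict_pre_hooks check would fire here

Since DCP loads in place into the (D)Tensors already held in `state`, the
final load_state_dict call mainly re-applies the same values; the point is
that both state_dict() and load_state_dict() traverse the per-module hooks
that this patch clears when PP is enabled.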