diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py
index 4f5caef7..647049e7 100644
--- a/torchtitan/parallelisms/parallelize_llama.py
+++ b/torchtitan/parallelisms/parallelize_llama.py
@@ -84,6 +84,7 @@ def parallelize_llama(
                 reduce_dtype=TORCH_DTYPE_MAP[
                     job_config.training.mixed_precision_reduce
                 ],
+                tp_enabled=parallel_dims.tp_enabled,
                 pp_enabled=parallel_dims.pp_enabled,
             )
         else:
@@ -290,6 +291,7 @@ def apply_fsdp(
     dp_mesh: DeviceMesh,
     param_dtype: torch.dtype,
     reduce_dtype: torch.dtype,
+    tp_enabled: bool,
     pp_enabled: bool,
 ):
     """
@@ -298,6 +300,10 @@ def apply_fsdp(
     mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype)
     fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
 
+    if tp_enabled:
+        # check if strided sharding is enabled, which is necessary for 2D/3D DCP
+        check_strided_sharding_enabled()
+
     for layer_id, transformer_block in model.layers.items():
         if pp_enabled:
             # For PP, do not reshard after forward to avoid per-microbatch
@@ -314,9 +320,6 @@ def apply_fsdp(
         )
     fully_shard(model, **fsdp_config, reshard_after_forward=not pp_enabled)
 
-    # check if strided sharding is enabled, which is necessary for 2D/3D DCP
-    check_strided_sharding_enabled()
-
     logger.info("Applied FSDP to the model")
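
For reference, below is a minimal sketch of the post-patch apply_fsdp control flow, reconstructed from the hunks above rather than copied from torchtitan: the strided-sharding check now runs up front, and only when tensor parallelism is enabled, instead of unconditionally after all fully_shard() calls. It assumes the FSDP2 fully_shard / MixedPrecisionPolicy APIs from torch.distributed._composable.fsdp; check_strided_sharding_enabled is the torchtitan helper referenced in the diff and is passed in as a callable here purely for illustration, and the per-block resharding logic is abbreviated.

import torch
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed._composable.fsdp import MixedPrecisionPolicy, fully_shard


def apply_fsdp_sketch(
    model: torch.nn.Module,
    dp_mesh: DeviceMesh,
    param_dtype: torch.dtype,
    reduce_dtype: torch.dtype,
    tp_enabled: bool,
    pp_enabled: bool,
    check_strided_sharding_enabled,  # torchtitan helper from the diff, injected for illustration
) -> None:
    mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype)
    fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}

    if tp_enabled:
        # Fail fast before any sharding is applied: strided sharding is only
        # needed for 2D/3D DCP, i.e. when TP is combined with FSDP.
        check_strided_sharding_enabled()

    for transformer_block in model.layers.values():
        # For PP, do not reshard after forward to avoid per-microbatch all-gathers
        # (the real code also chooses resharding per block as an optimization).
        fully_shard(
            transformer_block, **fsdp_config, reshard_after_forward=not pp_enabled
        )
    fully_shard(model, **fsdp_config, reshard_after_forward=not pp_enabled)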