Merge pull request #561 from allenai/shanea/delay-device-mesh-import
Delay device mesh import
2015aroras authored Apr 26, 2024
2 parents 4e8746d + 67e0b64 commit 295d309
Showing 1 changed file with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions scripts/train.py
@@ -11,7 +11,6 @@
 import torch.multiprocessing as mp
 import wandb
 from packaging import version
-from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.distributed.fsdp import ShardingStrategy

@@ -138,14 +137,17 @@ def dummy_init_fn(module: torch.nn.Module) -> None:
     param_init_fn = None

     # Set up device mesh for hybrid sharding in order to specify which nodes are associated to a given model replica
-    device_mesh: Optional[DeviceMesh] = None
+    device_mesh = None
     hybrid_sharding_fsdp_kwargs = {}
     if cfg.fsdp.sharding_strategy in (ShardingStrategy.HYBRID_SHARD, ShardingStrategy._HYBRID_SHARD_ZERO2):
         if version.parse(torch.__version__) < version.parse("2.2.0"):
             # Device mesh was not added to PyTorch until v2.2.0
             raise OLMoConfigurationError(
                 "OLMo training does not correctly support hybrid sharding before torch 2.2.0"
             )
+
+        from torch.distributed.device_mesh import init_device_mesh
+
         num_model_replicas = cfg.fsdp.hybrid_sharding_num_model_replicas or (
             get_world_size() // get_local_world_size()
         )
@@ -158,17 +160,18 @@ def dummy_init_fn(module: torch.nn.Module) -> None:
             raise OLMoConfigurationError("fsdp.hybrid_sharding_num_model_replicas must divide number of nodes")

         device_mesh = init_device_mesh("cuda", (num_model_replicas, get_world_size() // num_model_replicas))
+        hybrid_sharding_fsdp_kwargs["device_mesh"] = device_mesh

     fsdp_model = FSDP(
         olmo_model,
-        device_mesh=device_mesh,
         sharding_strategy=cfg.fsdp.sharding_strategy,
         mixed_precision=cfg.fsdp_precision,
         auto_wrap_policy=wrap_policy,
         use_orig_params=cfg.fsdp.use_orig_params,  # needed for compile and some of our optimizer/parameter metrics
         limit_all_gathers=True,
         device_id=get_local_rank(),
         param_init_fn=param_init_fn,
+        **hybrid_sharding_fsdp_kwargs,
     )
     # when param_init_fn is None, FSDP will call reset_parameters() automatically
     if param_init_fn is not None:
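The net effect of the diff: init_device_mesh is now imported lazily, only on the hybrid-sharding path and only after the torch version check has passed, and the resulting mesh reaches FSDP through a kwargs dict instead of an always-present device_mesh= argument. Below is a minimal, self-contained sketch of the same pattern outside the OLMo codebase; build_hybrid_shard_kwargs and its arguments are illustrative names, not part of the repo, and it assumes torch.distributed has already been initialized (e.g. via torchrun).

```python
# Sketch only: version-gated lazy import of init_device_mesh plus a kwargs
# dict that stays empty unless hybrid sharding is requested.
# Names here (build_hybrid_shard_kwargs, num_model_replicas) are illustrative.
from typing import Any, Dict

import torch
import torch.distributed as dist
from packaging import version
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import ShardingStrategy


def build_hybrid_shard_kwargs(
    sharding_strategy: ShardingStrategy, num_model_replicas: int
) -> Dict[str, Any]:
    """Return extra FSDP kwargs ({} unless hybrid sharding is in use)."""
    if sharding_strategy not in (ShardingStrategy.HYBRID_SHARD, ShardingStrategy._HYBRID_SHARD_ZERO2):
        return {}

    if version.parse(torch.__version__) < version.parse("2.2.0"):
        # torch.distributed.device_mesh is not importable on older torch,
        # hence the deferred import below.
        raise RuntimeError("hybrid sharding requires torch >= 2.2.0")

    from torch.distributed.device_mesh import init_device_mesh  # lazy import

    world_size = dist.get_world_size()
    if num_model_replicas <= 0 or world_size % num_model_replicas != 0:
        raise ValueError("num_model_replicas must divide the world size")

    # 2-D mesh: one dimension replicates the model, the other shards within a replica.
    mesh = init_device_mesh("cuda", (num_model_replicas, world_size // num_model_replicas))
    return {"device_mesh": mesh}


# Usage, inside an already-initialized distributed job:
#   extra = build_hybrid_shard_kwargs(ShardingStrategy.HYBRID_SHARD, num_model_replicas=2)
#   fsdp_model = FSDP(model, sharding_strategy=ShardingStrategy.HYBRID_SHARD, **extra)
```

Routing the mesh through **hybrid_sharding_fsdp_kwargs means the FSDP(...) call never mentions device_mesh on the non-hybrid path, which presumably keeps scripts/train.py working on torch versions where FSDP does not accept that keyword at all.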
