fix aten._fused_adamw_.tensor_lr
wz337 authored and pytorchmergebot committed May 21, 2024
1 parent 980f5ac commit 3287eeb
Showing 3 changed files with 85 additions and 38 deletions.
43 changes: 40 additions & 3 deletions test/distributed/_tensor/test_optimizers.py
@@ -93,8 +93,8 @@ def test_optimizer_foreach_supported_types_include_DTensor(self):
def test_adam_1d_sharding(self):
mesh = DeviceMesh(self.device_type, list(range(self.world_size)))

# TODO: add fused_adam support
adam_configs = [
# lr as a Tensor is not supported for capturable=False and foreach=True
adam_float_lr_configs = [
{"lr": 0.1, "foreach": False},
{"lr": 0.1, "weight_decay": 0.05, "foreach": False},
{"lr": 0.1, "weight_decay": 0.05},
@@ -105,6 +105,8 @@ def test_adam_1d_sharding(self):
"maximize": True,
"amsgrad": True,
},
]
fused_adam_float_lr_configs = [
{"lr": 0.1, "fused": True},
{"lr": 0.1, "weight_decay": 0.05, "amsgrad": True, "fused": True},
{
@@ -115,6 +117,22 @@ def test_adam_1d_sharding(self):
"fused": True,
},
]
# lr can be a Tensor or a float when fused=True for the Adam optimizer
fused_adam_tensor_lr_configs = [
{**config, "lr": torch.tensor(0.1)}
for config in fused_adam_float_lr_configs
]
fused_adam_tensor_lr_configs.extend(
[
{**config, "lr": torch.tensor([0.1])}
for config in fused_adam_float_lr_configs
]
)
adam_configs = [
*adam_float_lr_configs,
*fused_adam_float_lr_configs,
*fused_adam_tensor_lr_configs,
]

for config in adam_configs:
mod = MLPModule(self.device_type)
@@ -134,7 +152,8 @@ def test_adam_1d_sharding(self):
def test_adamw_1d_sharding(self):
mesh = DeviceMesh(self.device_type, list(range(self.world_size)))

adamw_configs = [
# lr as a Tensor is not supported for capturable=False and foreach=True
adamw_float_lr_configs = [
{"lr": 0.1, "foreach": False},
{"lr": 0.1, "weight_decay": 0.05, "foreach": False},
{"lr": 0.1, "weight_decay": 0.05},
@@ -153,6 +172,8 @@ def test_adamw_1d_sharding(self):
"maximize": True,
"amsgrad": True,
},
]
fused_adamw_float_lr_configs = [
{"lr": 0.1, "weight_decay": 0.05, "fused": True},
{
"lr": 0.1,
@@ -172,6 +193,22 @@ def test_adamw_1d_sharding(self):
"fused": True,
},
]
# lr can be a Tensor or a float when fused=True for the AdamW optimizer
fused_adamw_tensor_lr_configs = [
{**config, "lr": torch.tensor(0.1)}
for config in fused_adamw_float_lr_configs
]
fused_adamw_tensor_lr_configs.extend(
[
{**config, "lr": torch.tensor([0.1])}
for config in fused_adamw_float_lr_configs
]
)
adamw_configs = [
*adamw_float_lr_configs,
*fused_adamw_float_lr_configs,
*fused_adamw_tensor_lr_configs,
]

for config in adamw_configs:
mod = MLPModule(self.device_type)
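For context, here is a rough standalone sketch (not part of this commit) of the pattern the new tensor-lr configs exercise: a fused AdamW step over a DTensor parameter with lr passed as a tensor. The single-process setup, rendezvous settings, shapes, and device choice below are illustrative assumptions; the actual test distributes an MLPModule across the mesh instead.

    import os
    import torch
    import torch.distributed as dist
    from torch.distributed._tensor import DeviceMesh, Replicate, distribute_tensor

    # Illustrative single-process "distributed" setup; address/port are placeholders.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("nccl", rank=0, world_size=1)

    mesh = DeviceMesh("cuda", [0])
    param = torch.nn.Parameter(
        distribute_tensor(torch.randn(8, 8, device="cuda"), mesh, [Replicate()])
    )
    param.grad = distribute_tensor(torch.randn(8, 8, device="cuda"), mesh, [Replicate()])

    # lr as a tensor (0-dim or numel-1) is only accepted by the fused implementation.
    opt = torch.optim.AdamW([param], lr=torch.tensor(0.1), fused=True)
    opt.step()  # ends up dispatching aten._fused_adamw_.tensor_lr on DTensor inputs

    dist.destroy_process_group()

Before this fix, the step above would fail in DTensor dispatch with the mixed torch.Tensor/DTensor error, because the plain tensor lr arrives as a kwarg and the tensor_lr overloads were not yet registered as fused pointwise ops.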
76 changes: 41 additions & 35 deletions torch/distributed/_tensor/dispatch.py
@@ -297,6 +297,41 @@ def unwrap_to_op_info(
local_kwargs: Dict[str, object] = {}
mesh: Optional[DeviceMesh] = None

def try_get_replicate_spec(
tensor_arg: torch.Tensor, mesh: "DeviceMesh"
) -> DTensorSpec:
# tensor_arg is an instance of torch.Tensor and could be an arg or kwarg.
if tensor_arg.numel() == 1 and tensor_arg.ndim == 1:
warnings.warn(
"Found a non-scalar tensor with numel=1 and ndim!=0, "
"we are implicitly creating a replicated DTensor for it. "
"However, please consider changing it to a scalar tensor "
"or explicitly create a DTensor under distributed enviroment."
)

# if tensor_arg.numel() == 1, tensor_arg.ndim could be 0 or 1.
if (
tensor_arg.ndim <= 1
and tensor_arg.numel() == 1
or self._allow_implicit_replication
):
# scalar tensor can be safely treated as replicated
replication_spec = DTensorSpec(
mesh,
(Replicate(),) * mesh.ndim,
tensor_meta=TensorMeta(
shape=tensor_arg.shape,
stride=tensor_arg.stride(),
dtype=tensor_arg.dtype,
),
)
else:
raise RuntimeError(
f"{op_call}: got mixed torch.Tensor and DTensor, need to convert all"
" torch.Tensor to DTensor before calling distributed operators!"
)
return replication_spec

for arg in args_list:
if isinstance(arg, dtensor.DTensor):
args_schema.append(arg._spec)
@@ -309,37 +344,9 @@ def unwrap_to_op_info(
else:
mesh = arg.device_mesh
elif isinstance(arg, torch.Tensor):
if arg.numel() == 1 and arg.ndim == 1:
warnings.warn(
"Found a non-scalar tensor with numel=1 and ndim!=0, "
"we are implicitly creating a replicated DTensor for it. "
"However, please consider changing it to a scalar tensor "
"or explicitly create a DTensor under distributed enviroment."
)

# if the arg.numel() == 1, arg.ndim could be 0 or 1.
if (
arg.ndim <= 1
and arg.numel() == 1
or self._allow_implicit_replication
):
mesh = mesh or try_find_mesh_from_args(op_call, args_list)
# scalar tensor can be safely treated as replicated
args_schema.append(
DTensorSpec(
mesh,
(Replicate(),) * mesh.ndim,
tensor_meta=TensorMeta(
shape=arg.shape, stride=arg.stride(), dtype=arg.dtype
),
)
)
local_args.append(arg)
else:
raise RuntimeError(
f"{op_call}: got mixed torch.Tensor and DTensor, need to convert all"
" torch.Tensor to DTensor before calling distributed operators!"
)
mesh = mesh or try_find_mesh_from_args(op_call, args_list)
args_schema.append(try_get_replicate_spec(arg, mesh))
local_args.append(arg)
else:
args_schema.append(arg)
local_args.append(arg)
@@ -356,10 +363,9 @@ def unwrap_to_op_info(
else:
mesh = v.device_mesh
elif isinstance(v, torch.Tensor):
raise RuntimeError(
f"{op_call}: got mixed torch.Tensor and DTensor, need to convert all"
" torch.Tensor to DTensor before calling distributed operators!"
)
mesh = mesh or try_find_mesh_from_args(op_call, args_list)
kwargs_schema[k] = try_get_replicate_spec(v, mesh)
local_kwargs[k] = v
else:
kwargs_schema[k] = v
local_kwargs[k] = v
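The net effect of pulling this logic into try_get_replicate_spec and calling it from the kwargs branch is that a plain numel-1 torch.Tensor kwarg (such as the tensor lr passed to the fused optimizers) is now implicitly replicated instead of raising the mixed torch.Tensor/DTensor error, matching what already happened for positional args. A minimal, hedged illustration of that implicit replication, reusing the kind of process-group setup sketched earlier and shown here through the positional-arg path for brevity:

    import torch
    from torch.distributed._tensor import DeviceMesh, Replicate, distribute_tensor

    # Assumes torch.distributed is already initialized as in the earlier sketch.
    mesh = DeviceMesh("cuda", [0])
    dt = distribute_tensor(torch.ones(4, device="cuda"), mesh, [Replicate()])

    # A 0-dim scalar tensor mixed into a DTensor op is treated as replicated;
    # a 1-dim, numel-1 tensor also works but triggers the warning above.
    out = torch.mul(dt, torch.tensor(2.0))
    print(type(out).__name__)  # DTensor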
4 changes: 4 additions & 0 deletions torch/distributed/_tensor/ops/pointwise_ops.py
@@ -644,8 +644,12 @@ def list_linear_pointwise_strategy(
fused_ops = [
aten._fused_adam_.default,
aten._fused_adam.default,
aten._fused_adam.tensor_lr,
aten._fused_adam_.tensor_lr,
aten._fused_adamw_.default,
aten._fused_adamw.default,
aten._fused_adamw.tensor_lr,
aten._fused_adamw_.tensor_lr,
]

for op in fused_ops:
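The for op in fused_ops loop (its body is truncated in this diff) registers the fused pointwise sharding strategy for each listed overload, so adding the tensor_lr variants is what lets the dispatch path above find a strategy for them. As a purely illustrative sanity check, on a PyTorch build that ships the tensor_lr fused kernels the new entries resolve as ordinary OpOverload objects:

    import torch

    aten = torch.ops.aten  # alias as used in pointwise_ops.py
    # These overloads exist on builds that include the tensor_lr fused kernels.
    print(aten._fused_adam_.tensor_lr)
    print(aten._fused_adamw_.tensor_lr)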
