[Train] Make prepare_model always use the correct device (#29104)
Signed-off-by: Amog Kamsetty [email protected]

Previously, prepare_model used the local rank as the device index even though the local rank may not match the actual device index. This mismatch can happen, for example, when CUDA_VISIBLE_DEVICES is set, which Ray Train does by default.

We should always use train.torch.get_device() as the device when wrapping the model in DDP.

Closes #28996
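
For illustration only (not part of this commit), here is a minimal hypothetical sketch of how the local rank can diverge from the visible device index once CUDA_VISIBLE_DEVICES is set; the GPU ID and rank below are made up:

    # Hypothetical: this worker has local rank 1, and the launcher has pinned it
    # to physical GPU 3 by setting CUDA_VISIBLE_DEVICES before the process starts.
    import os
    import torch

    os.environ["CUDA_VISIBLE_DEVICES"] = "3"  # must be set before CUDA is initialized

    # Only one GPU is visible to this process, so the only valid device index is 0.
    # Indexing by the local rank (1) would point at a device that does not exist
    # from this process's point of view.
    assert torch.cuda.device_count() == 1
    correct_device = torch.device("cuda:0")  # what train.torch.get_device() resolves to here
    wrong_device = torch.device("cuda:1")    # what indexing by the local rank would produce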
amogkam authored Oct 6, 2022
1 parent 1e616ef commit 2217f0c
Showing 2 changed files with 28 additions and 2 deletions.
26 changes: 26 additions & 0 deletions python/ray/train/tests/test_gpu.py
@@ -177,6 +177,32 @@ def train_fn():
trainer.shutdown()


def test_torch_prepare_model_uses_device(ray_start_4_cpus_2_gpus):
    """Tests that `prepare_model` uses the device from `train.torch.get_device()`
    even if it does not match the local rank."""
    # The test below should pass without errors.

    @patch.object(
        ray.train.torch.train_loop_utils._TorchAccelerator,
        "get_device",
        lambda self: torch.device(f"cuda:{1 - train.local_rank()}"),
    )
    def train_func():
        # These assert statements must hold for prepare_model to wrap with DDP.
        assert torch.cuda.is_available()
        assert train.world_size() > 1
        model = torch.nn.Linear(1, 1)
        data = torch.ones(1)
        data = data.to(train.torch.get_device())
        model = train.torch.prepare_model(model)
        model(data)

    trainer = TorchTrainer(
        train_func, scaling_config=ScalingConfig(num_workers=2, use_gpu=True)
    )
    trainer.fit()


# TODO: Refactor as a backend test.
@pytest.mark.parametrize(
    "dataset", (LinearDataset, LinearDatasetDict, NonTensorDataset)
4 changes: 2 additions & 2 deletions python/ray/train/torch/train_loop_utils.py
@@ -401,8 +401,8 @@ def model_get_state(self):
        DataParallel = DistributedDataParallel
        if torch.cuda.is_available():
            parallel_strategy_kwargs = {
-               "device_ids": [rank],
-               "output_device": rank,
+               "device_ids": [device],
+               "output_device": device,
                **parallel_strategy_kwargs,
            }
        else:
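
To make the effect of the change concrete, the following is a minimal hypothetical sketch (not the actual prepare_model implementation) of wrapping a model in DDP keyed on the resolved device rather than the rank; DistributedDataParallel accepts torch.device objects for device_ids and output_device:

    import torch
    from torch.nn.parallel import DistributedDataParallel

    def wrap_model(model: torch.nn.Module, device: torch.device) -> torch.nn.Module:
        # Move the model to the worker's actual device (e.g. cuda:0 when
        # CUDA_VISIBLE_DEVICES="3"), then pass that same device to DDP rather
        # than indexing by the worker's local rank.
        model = model.to(device)
        if torch.cuda.is_available():
            return DistributedDataParallel(
                model, device_ids=[device], output_device=device
            )
        return DistributedDataParallel(model)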
