warning if current device index is lower than current local rank
HelioStrike committed Oct 5, 2020
1 parent 221b6de commit 726041a
Showing 4 changed files with 23 additions and 0 deletions.
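
Context for the change: in the usual one-process-per-GPU setup, each worker is expected to bind to the GPU whose index matches its local rank. If a worker never calls torch.cuda.set_device, torch.cuda.current_device() stays at 0, so every worker with local rank >= 1 silently shares GPU 0; that is the situation the new warning flags. Note the check is one-sided: an index above the local rank (e.g. an intentional custom mapping) does not trigger it. A minimal sketch of the mismatch, where local_rank is a hypothetical stand-in for whatever the launcher provides (e.g. the LOCAL_RANK environment variable):

    import torch

    local_rank = 1  # hypothetical: normally supplied by the launcher (e.g. LOCAL_RANK)

    if torch.cuda.is_available() and torch.cuda.device_count() > 1:
        print(torch.cuda.current_device())  # 0: this process never set a device
        torch.cuda.set_device(local_rank)   # bind the process to its own GPU
        print(torch.cuda.current_device())  # 1: now matches the local rank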
3 changes: 3 additions & 0 deletions ignite/distributed/comp_models/horovod.py
@@ -1,4 +1,5 @@
 import os
+import warnings
 from typing import Callable, Mapping, Optional, Tuple

 import torch
@@ -97,6 +98,8 @@ def get_node_rank(self) -> int:
     def device(self) -> torch.device:
         if torch.cuda.is_available():
             index = torch.cuda.current_device()
+            if index < self.get_local_rank():
+                warnings.warn("Current device index is less than current local rank.")
             return torch.device("cuda:{}".format(index))
         return torch.device("cpu")

2 changes: 2 additions & 0 deletions ignite/distributed/comp_models/native.py
@@ -221,6 +221,8 @@ def get_node_rank(self) -> int:
     def device(self) -> torch.device:
         if self.backend() == dist.Backend.NCCL:
             index = torch.cuda.current_device()
+            if index < self.get_local_rank():
+                warnings.warn("Current device index is less than current local rank.")
             return torch.device("cuda:{}".format(index))
         return torch.device("cpu")

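
For context, a hedged sketch of how the warning would surface through ignite's public API (assumptions: a node with at least two GPUs, one process per GPU, and an NCCL process group already initialized by a launcher; idist.device() and idist.get_local_rank() are the public wrappers over the methods patched above):

    import torch
    import ignite.distributed as idist

    # Without this call every worker stays on GPU 0, so any worker whose
    # local rank is >= 1 would now get the new UserWarning from idist.device().
    torch.cuda.set_device(idist.get_local_rank())

    device = idist.device()  # torch.device("cuda:<index of the bound GPU>")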
9 changes: 9 additions & 0 deletions tests/ignite/distributed/comp_models/test_horovod.py
@@ -195,3 +195,12 @@ def test__hvd_dist_model_spawn_cuda():
         nproc_per_node=num_workers_per_machine,
         use_gloo=True,
     )
+
+
+@pytest.mark.distributed
+@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Skip if less than 2 GPUs")
+def test__warning_if_deviceindex_less_than_localrank(local_rank, world_size):
+    with pytest.warns(UserWarning, match=r"Current device index is less than current local rank."):
+        _test__hvd_dist_model_create_from_context_dist(
+            local_rank, local_rank, world_size, "nccl", "cuda:{}".format(local_rank)
+        )
9 changes: 9 additions & 0 deletions tests/ignite/distributed/comp_models/test_native.py
@@ -299,3 +299,12 @@ def test__native_dist_model_spawn_gloo():
 @pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU")
 def test__native_dist_model_spawn_nccl():
     _test__native_dist_model_spawn("nccl", num_workers_per_machine=torch.cuda.device_count(), device="cuda")
+
+
+@pytest.mark.distributed
+@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Skip if less than 2 GPUs")
+def test__warning_if_deviceindex_less_than_localrank(local_rank, world_size):
+    with pytest.warns(UserWarning, match=r"Current device index is less than current local rank."):
+        _test__native_dist_model_create_from_context_dist(
+            local_rank, local_rank, world_size, "nccl", "cuda:{}".format(local_rank)
+        )
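
A side note on the testing pattern used in both new tests: pytest.warns(..., match=...) treats match as a regular expression searched against the warning message, so the unescaped dot in the pattern above is harmless. A minimal self-contained illustration, not part of the commit:

    import warnings

    import pytest


    def test_pytest_warns_match_is_a_regex_search():
        # match is applied with re.search, so a substring pattern is enough
        with pytest.warns(UserWarning, match=r"device index is less than"):
            warnings.warn("Current device index is less than current local rank.")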
