warning if current device index is lower than current local rank
HelioStrike committed Oct 4, 2020
1 parent d384dc6 commit 6ac5b17
Showing 4 changed files with 52 additions and 0 deletions.
3 changes: 3 additions & 0 deletions ignite/distributed/comp_models/horovod.py
@@ -1,4 +1,5 @@
 import os
+import warnings
 from typing import Callable, Mapping, Optional, Tuple
 
 import torch
@@ -97,6 +98,8 @@ def get_node_rank(self) -> int:
     def device(self) -> torch.device:
         if torch.cuda.is_available():
             index = torch.cuda.current_device()
+            if index < self.get_local_rank():
+                warnings.warn("Current device index is less than current local rank.")
             return torch.device("cuda:{}".format(index))
         return torch.device("cpu")

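For context (not part of the commit), a minimal sketch of the misconfiguration this check flags; the local_rank value is assumed for illustration. In a one-process-per-GPU setup, local rank i is expected to drive GPU i, so a current device index below the local rank usually means the process never pinned its device. The conventional fix is sketched after the native.py diff below.

import torch

local_rank = 1  # assumed: this process is local rank 1 on a 2-GPU node

# Without an explicit torch.cuda.set_device(...), current_device() stays at
# the default index 0, so index < local_rank and the new warning fires.
index = torch.cuda.current_device()
print(index < local_rank)  # True -> warning case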
2 changes: 2 additions & 0 deletions ignite/distributed/comp_models/native.py
@@ -221,6 +221,8 @@ def get_node_rank(self) -> int:
     def device(self) -> torch.device:
         if self.backend() == dist.Backend.NCCL:
             index = torch.cuda.current_device()
+            if index < self.get_local_rank():
+                warnings.warn("Current device index is less than current local rank.")
             return torch.device("cuda:{}".format(index))
         return torch.device("cpu")

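A minimal sketch (again not from the commit) of the conventional way to keep torch.cuda.current_device() aligned with the local rank in a native torch.distributed run; it assumes the launcher exports LOCAL_RANK (e.g. torch.distributed.launch with --use_env):

import os

import torch
import torch.distributed as dist

local_rank = int(os.environ["LOCAL_RANK"])  # assumed to be set by the launcher
torch.cuda.set_device(local_rank)  # pin the device before any CUDA work

dist.init_process_group(backend="nccl", init_method="env://")
# ... build the model on torch.device("cuda", local_rank) and train ...
dist.destroy_process_group()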
23 changes: 23 additions & 0 deletions tests/ignite/distributed/comp_models/test_horovod.py
@@ -195,3 +195,26 @@ def test__hvd_dist_model_spawn_cuda():
         nproc_per_node=num_workers_per_machine,
         use_gloo=True,
     )
+
+
+@pytest.mark.distributed
+@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Skip if less than 2 GPUs")
+def test__warning_if_deviceindex_less_than_localrank():
+    import os
+    import torch.distributed as dist
+
+    os.environ["RANK"] = "1"
+    os.environ["WORLD_SIZE"] = "2"
+    os.environ["MASTER_ADDR"] = "0.0.0.0"
+    os.environ["MASTER_PORT"] = "2222"
+
+    dist.init_process_group(backend="nccl", init_method="env://")
+
+    # pytest.warns expects a warning class; match= is a regex applied to the message.
+    with pytest.warns(UserWarning, match="Current device index is less than current local rank."):
+        _HorovodDistModel.get_world_size()
+
+    dist.destroy_process_group()
+
+    del os.environ["RANK"]
+    del os.environ["WORLD_SIZE"]
+    del os.environ["MASTER_ADDR"]
+    del os.environ["MASTER_PORT"]
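As an aside, a self-contained sketch of the pytest.warns form used in the rewritten tests: the first argument is a warning class and match= is a regular expression checked against the message; passing the message string as the first positional argument is not a supported call form.

import warnings

import pytest


def emit_warning():
    warnings.warn("Current device index is less than current local rank.")


def test_emit_warning():
    with pytest.warns(UserWarning, match="less than current local rank"):
        emit_warning()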
24 changes: 24 additions & 0 deletions tests/ignite/distributed/comp_models/test_native.py
@@ -299,3 +299,27 @@ def test__native_dist_model_spawn_gloo():
 @pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU")
 def test__native_dist_model_spawn_nccl():
     _test__native_dist_model_spawn("nccl", num_workers_per_machine=torch.cuda.device_count(), device="cuda")
+
+
+@pytest.mark.distributed
+@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Skip if less than 2 GPUs")
+def test__warning_if_deviceindex_less_than_localrank():
+    import os
+    import torch.distributed as dist
+    import ignite.distributed as idist
+
+    os.environ["RANK"] = "1"
+    os.environ["WORLD_SIZE"] = "2"
+    os.environ["MASTER_ADDR"] = "0.0.0.0"
+    os.environ["MASTER_PORT"] = "2222"
+
+    dist.init_process_group(backend="nccl", init_method="env://")
+
+    # pytest.warns expects a warning class; match= is a regex applied to the message.
+    with pytest.warns(UserWarning, match="Current device index is less than current local rank."):
+        idist.get_world_size()
+
+    dist.destroy_process_group()
+
+    del os.environ["RANK"]
+    del os.environ["WORLD_SIZE"]
+    del os.environ["MASTER_ADDR"]
+    del os.environ["MASTER_PORT"]
