warning if current device index is lower than current local rank (#1335) (#1376)

* warning if current device index is lower than current local rank (#1335)

* warning if current device index is lower than current local rank

* Updated code and tests

* Fixed formatting

* Updated code and tests for horovod
- fixed failing test

* Updated tests

Co-authored-by: vfdev-5 <[email protected]>

* Removed debug prints

* Fixed failing hvd tests

Co-authored-by: Sai Sandeep Mutyala <[email protected]>
vfdev-5 and HelioStrike committed Oct 8, 2020
1 parent 52bb624 commit ba1f67b
Showing 7 changed files with 107 additions and 10 deletions.
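
For reference, the guard added in this commit compares torch.cuda.current_device() against the process's local rank and warns when the device index falls behind it. Below is a minimal standalone sketch of the check and of the recommended fix; the function name resolve_device is hypothetical and the local_rank value is assumed to come from the launcher, neither is part of the diff itself.

import warnings

import torch


def resolve_device(local_rank):
    # Illustration of the check added in this commit: if a process is still on a
    # device with a smaller index than its local rank (typically cuda:0), it most
    # likely forgot to pin itself to its own GPU.
    if torch.cuda.is_available():
        index = torch.cuda.current_device()
        if index < local_rank:
            warnings.warn(
                "Current device index is less than current local rank. "
                "Please, make sure to call torch.cuda.set_device(local_rank)."
            )
        return torch.device("cuda:{}".format(index))
    return torch.device("cpu")


# Recommended pattern right after hvd.init() / dist.init_process_group():
# torch.cuda.set_device(local_rank), so that current_device() == local_rank.
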
8 changes: 7 additions & 1 deletion ignite/distributed/comp_models/horovod.py
@@ -1,4 +1,5 @@
import os
import warnings
from typing import Any, Callable, Mapping, Optional, Tuple

import torch
@@ -68,7 +69,7 @@ def __init__(self, do_init: bool = False, **kwargs: Any) -> None:

self._local_rank = hvd.local_rank()

if torch.cuda.is_available():
if do_init and torch.cuda.is_available():
torch.cuda.set_device(self._local_rank)

self._setup_attrs()
@@ -97,6 +98,11 @@ def get_node_rank(self) -> int:
def device(self) -> torch.device:
if torch.cuda.is_available():
index = torch.cuda.current_device()
if index < self.get_local_rank():
warnings.warn(
"Current device index is less than current local rank. "
"Please, make sure to call torch.cuda.set_device(local_rank)."
)
return torch.device("cuda:{}".format(index))
return torch.device("cpu")

12 changes: 11 additions & 1 deletion ignite/distributed/comp_models/native.py
@@ -97,7 +97,12 @@ def _init_from_context(self) -> None:
self._setup_attrs()

def _compute_nproc_per_node(self) -> int:
tensor = torch.tensor([self.get_local_rank() + 1]).to(self.device())
local_rank = self.get_local_rank()
device = torch.device("cpu")
if self.backend() == dist.Backend.NCCL:
# we manually set cuda device to local rank in order to avoid a hang on all_reduce
device = torch.device("cuda:{}".format(local_rank))
tensor = torch.tensor([self.get_local_rank() + 1]).to(device)
dist.all_reduce(tensor, op=dist.ReduceOp.MAX)
return int(tensor.item())
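
A side note on the hunk above: each worker contributes local_rank + 1, so an all_reduce with MAX yields the number of processes per node, and with the NCCL backend the tensor has to live on that worker's own GPU or the collective can hang. A rough standalone sketch of the same idea follows; the function name compute_nproc_per_node is illustrative, it assumes an already initialized process group and that every node runs the same number of processes.

import torch
import torch.distributed as dist


def compute_nproc_per_node(local_rank):
    # Each process reports local_rank + 1; the maximum across the group is the
    # number of processes per node (assuming every node runs the same count).
    device = torch.device("cpu")
    if dist.get_backend() == dist.Backend.NCCL:
        # NCCL collectives require each rank to use its own GPU; leaving every
        # rank on cuda:0 is exactly the situation that can make all_reduce hang.
        device = torch.device("cuda:{}".format(local_rank))
    tensor = torch.tensor([local_rank + 1], device=device)
    dist.all_reduce(tensor, op=dist.ReduceOp.MAX)
    return int(tensor.item())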

@@ -220,6 +225,11 @@ def get_node_rank(self) -> int:
def device(self) -> torch.device:
if self.backend() == dist.Backend.NCCL:
index = torch.cuda.current_device()
if index < self.get_local_rank():
warnings.warn(
"Current device index is less than current local rank. "
"Please, make sure to call torch.cuda.set_device(local_rank)."
)
return torch.device("cuda:{}".format(index))
return torch.device("cpu")

4 changes: 4 additions & 0 deletions tests/ignite/conftest.py
@@ -231,6 +231,10 @@ def _hvd_task_with_init(func, args):
import horovod.torch as hvd

hvd.init()
lrank = hvd.local_rank()
if torch.cuda.is_available():
torch.cuda.set_device(lrank)

func(*args)
hvd.shutdown()

46 changes: 42 additions & 4 deletions tests/ignite/distributed/comp_models/test_horovod.py
@@ -36,7 +36,6 @@ def _test__hvd_dist_model_create_from_backend_no_dist(backend, true_device):
model = _HorovodDistModel.create_from_backend(backend=backend)

assert hvd.rank() > -1
print("true_device", true_device)
_assert_model(
model,
{
@@ -109,10 +108,13 @@ def _test__hvd_dist_model_create_from_context_dist(true_backend, true_device):
assert _HorovodDistModel.create_from_context() is None

hvd.init()
lrank = hvd.local_rank()
if torch.cuda.is_available():
torch.cuda.set_device(lrank)

true_conf = {
"device": true_device,
"local_rank": hvd.local_rank(),
"local_rank": lrank,
"rank": hvd.rank(),
"world_size": hvd.size(),
"node_index": 0,
@@ -121,6 +123,7 @@ def _test__hvd_dist_model_create_from_context_dist(true_backend, true_device):
}

model = _HorovodDistModel.create_from_context()
assert model.backend() == true_backend
_assert_model(model, true_conf)

hvd.shutdown()
@@ -142,18 +145,53 @@ def test__hvd_dist_model_create_no_dist_cuda(gloo_hvd_executor):

@pytest.mark.distributed
@pytest.mark.skipif(torch.cuda.device_count() > 0, reason="Skip if has GPU")
def test__hvd_dist_model_create_dist(gloo_hvd_executor):
def test__hvd_dist_model_create_dist_1(gloo_hvd_executor):
gloo_hvd_executor(_test__hvd_dist_model_create_from_backend_dist, ("horovod", "cpu"), np=4)


@pytest.mark.distributed
@pytest.mark.skipif(torch.cuda.device_count() > 0, reason="Skip if has GPU")
def test__hvd_dist_model_create_dist_2(gloo_hvd_executor):
gloo_hvd_executor(_test__hvd_dist_model_create_from_context_dist, ("horovod", "cpu"), np=4)


@pytest.mark.distributed
@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU")
def test__hvd_dist_model_create_dist_cuda(gloo_hvd_executor):
def test__hvd_dist_model_create_dist_cuda_1(gloo_hvd_executor):
gloo_hvd_executor(_test__hvd_dist_model_create_from_backend_dist, ("horovod", "cuda"), np=torch.cuda.device_count())


@pytest.mark.distributed
@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU")
def test__hvd_dist_model_create_dist_cuda_2(gloo_hvd_executor):
gloo_hvd_executor(_test__hvd_dist_model_create_from_context_dist, ("horovod", "cuda"), np=torch.cuda.device_count())


def _test__hvd_dist_model_warning_index_less_localrank():

assert torch.cuda.is_available()
assert _HorovodDistModel.create_from_context() is None

hvd.init()
# We deliberately incorrectly set cuda device to 0
torch.cuda.set_device(0)

model = _HorovodDistModel.create_from_context()
assert isinstance(model, _HorovodDistModel), "{} vs _HorovodDistModel".format(type(model))

if hvd.local_rank() == 1:
with pytest.warns(UserWarning, match=r"Current device index is less than current local rank."):
model.device()

hvd.shutdown()


@pytest.mark.distributed
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Skip if less than 2 GPUs")
def test__hvd_dist_model_warning_index_less_localrank(gloo_hvd_executor):
gloo_hvd_executor(_test__hvd_dist_model_warning_index_less_localrank, (), np=torch.cuda.device_count())


def _test_dist_spawn_fn(local_rank, backend, world_size, device):
from ignite.distributed.utils import _model

36 changes: 34 additions & 2 deletions tests/ignite/distributed/comp_models/test_native.py
@@ -211,6 +211,8 @@ def _test__native_dist_model_create_from_context_dist(local_rank, rank, world_si

dist.init_process_group(true_backend, "tcp://0.0.0.0:2222", world_size=world_size, rank=rank)
dist.barrier()
if torch.cuda.is_available():
torch.cuda.set_device(local_rank)

true_conf = {
"device": true_device,
@@ -244,22 +246,52 @@ def test__native_dist_model_create_no_dist_nccl(clean_env):


@pytest.mark.distributed
def test__native_dist_model_create_dist_gloo(local_rank, world_size):
def test__native_dist_model_create_dist_gloo_1(local_rank, world_size):
_test__native_dist_model_create_from_backend_dist(local_rank, local_rank, world_size, "gloo", "cpu")


@pytest.mark.distributed
def test__native_dist_model_create_dist_gloo_2(local_rank, world_size):
_test__native_dist_model_create_from_context_dist(local_rank, local_rank, world_size, "gloo", "cpu")


@pytest.mark.distributed
@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU")
def test__native_dist_model_create_dist_nccl(local_rank, world_size):
def test__native_dist_model_create_dist_nccl_1(local_rank, world_size):
_test__native_dist_model_create_from_backend_dist(
local_rank, local_rank, world_size, "nccl", "cuda:{}".format(local_rank)
)


@pytest.mark.distributed
@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU")
def test__native_dist_model_create_dist_nccl_2(local_rank, world_size):
_test__native_dist_model_create_from_context_dist(
local_rank, local_rank, world_size, "nccl", "cuda:{}".format(local_rank)
)


@pytest.mark.distributed
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Skip if less than 2 GPUs")
def test__native_dist_model_warning_index_less_localrank(local_rank, world_size):

assert _NativeDistModel.create_from_context() is None

dist.init_process_group("nccl", "tcp://0.0.0.0:2222", world_size=world_size, rank=local_rank)
dist.barrier()
# We deliberately incorrectly set cuda device to 0
torch.cuda.set_device(0)

model = _NativeDistModel.create_from_context()
assert isinstance(model, _NativeDistModel), "{} vs _NativeDistModel".format(type(model))

if local_rank == 1:
with pytest.warns(UserWarning, match=r"Current device index is less than current local rank."):
model.device()

dist.destroy_process_group()


def _test_dist_spawn_fn(local_rank, backend, world_size, device):
from ignite.distributed.utils import _model

7 changes: 7 additions & 0 deletions tests/ignite/distributed/utils/test_horovod.py
@@ -75,6 +75,9 @@ def _test_sync_as_hvd():
from ignite.distributed.comp_models.horovod import _HorovodDistModel

hvd.init()
lrank = hvd.local_rank()
if torch.cuda.is_available():
torch.cuda.set_device(lrank)

_test_sync(_HorovodDistModel)

@@ -111,6 +114,10 @@ def _test_idist_methods_in_hvd_context(backend, device):
ws = hvd.size()
rank = hvd.rank()
local_rank = hvd.local_rank()

if torch.cuda.is_available():
torch.cuda.set_device(local_rank)

_test_distrib_config(local_rank, backend=backend, ws=ws, true_device=device, rank=rank)

hvd.shutdown()
4 changes: 2 additions & 2 deletions tests/run_cpu_tests.sh
@@ -2,7 +2,7 @@

set -xeu

CUDA_VISIBLE_DEVICES="" py.test --tx 4*popen//python=python$CI_PYTHON_VERSION --cov ignite --cov-report term-missing -vvv tests/
CUDA_VISIBLE_DEVICES="" py.test --tx 4*popen//python=python${CI_PYTHON_VERSION:-3.7} --cov ignite --cov-report term-missing -vvv tests/

# https://pubs.opengroup.org/onlinepubs/009695399/utilities/xcu_chap02.html#tag_02_06_02
if [ "${SKIP_DISTRIB_TESTS:-0}" -eq "1" ]; then
@@ -11,4 +11,4 @@ fi

export WORLD_SIZE=2

CUDA_VISIBLE_DEVICES="" py.test --cov ignite --cov-append --cov-report term-missing --dist=each --tx $WORLD_SIZE*popen//python=python$CI_PYTHON_VERSION tests -m distributed -vvv
CUDA_VISIBLE_DEVICES="" py.test --cov ignite --cov-append --cov-report term-missing --dist=each --tx $WORLD_SIZE*popen//python=python${CI_PYTHON_VERSION:-3.7} tests -m distributed -vvv
