diff --git a/ignite/distributed/comp_models/horovod.py b/ignite/distributed/comp_models/horovod.py
index 1b54f1f42f9..ef4600bb4d7 100644
--- a/ignite/distributed/comp_models/horovod.py
+++ b/ignite/distributed/comp_models/horovod.py
@@ -1,4 +1,5 @@
 import os
+import warnings
 from typing import Any, Callable, Mapping, Optional, Tuple
 
 import torch
@@ -68,7 +69,7 @@ def __init__(self, do_init: bool = False, **kwargs: Any) -> None:
 
         self._local_rank = hvd.local_rank()
 
-        if torch.cuda.is_available():
+        if do_init and torch.cuda.is_available():
             torch.cuda.set_device(self._local_rank)
 
         self._setup_attrs()
@@ -97,6 +98,11 @@ def get_node_rank(self) -> int:
     def device(self) -> torch.device:
         if torch.cuda.is_available():
             index = torch.cuda.current_device()
+            if index < self.get_local_rank():
+                warnings.warn(
+                    "Current device index is less than current local rank. "
+                    "Please, make sure to call torch.cuda.set_device(local_rank)."
+                )
             return torch.device("cuda:{}".format(index))
         return torch.device("cpu")
diff --git a/ignite/distributed/comp_models/native.py b/ignite/distributed/comp_models/native.py
index e64ddeb7a69..ff489d743a7 100644
--- a/ignite/distributed/comp_models/native.py
+++ b/ignite/distributed/comp_models/native.py
@@ -97,7 +97,12 @@ def _init_from_context(self) -> None:
         self._setup_attrs()
 
     def _compute_nproc_per_node(self) -> int:
-        tensor = torch.tensor([self.get_local_rank() + 1]).to(self.device())
+        local_rank = self.get_local_rank()
+        device = torch.device("cpu")
+        if self.backend() == dist.Backend.NCCL:
+            # we manually set cuda device to local rank in order to avoid a hang on all_reduce
+            device = torch.device("cuda:{}".format(local_rank))
+        tensor = torch.tensor([self.get_local_rank() + 1]).to(device)
         dist.all_reduce(tensor, op=dist.ReduceOp.MAX)
         return int(tensor.item())
 
@@ -220,6 +225,11 @@ def get_node_rank(self) -> int:
     def device(self) -> torch.device:
         if self.backend() == dist.Backend.NCCL:
             index = torch.cuda.current_device()
+            if index < self.get_local_rank():
+                warnings.warn(
+                    "Current device index is less than current local rank. "
+                    "Please, make sure to call torch.cuda.set_device(local_rank)."
+                )
             return torch.device("cuda:{}".format(index))
         return torch.device("cpu")
diff --git a/tests/ignite/conftest.py b/tests/ignite/conftest.py
index 1a8f31edebf..e2abc6ab09b 100644
--- a/tests/ignite/conftest.py
+++ b/tests/ignite/conftest.py
@@ -231,6 +231,10 @@ def _hvd_task_with_init(func, args):
    import horovod.torch as hvd
 
    hvd.init()
+    lrank = hvd.local_rank()
+    if torch.cuda.is_available():
+        torch.cuda.set_device(lrank)
+
    func(*args)
    hvd.shutdown()
diff --git a/tests/ignite/distributed/comp_models/test_horovod.py b/tests/ignite/distributed/comp_models/test_horovod.py
index daff622d2d4..800e1b7bfa3 100644
--- a/tests/ignite/distributed/comp_models/test_horovod.py
+++ b/tests/ignite/distributed/comp_models/test_horovod.py
@@ -36,7 +36,6 @@ def _test__hvd_dist_model_create_from_backend_no_dist(backend, true_device):
     model = _HorovodDistModel.create_from_backend(backend=backend)
 
     assert hvd.rank() > -1
-    print("true_device", true_device)
 
     _assert_model(
         model,
         {
@@ -109,10 +108,13 @@ def _test__hvd_dist_model_create_from_context_dist(true_backend, true_device):
     assert _HorovodDistModel.create_from_context() is None
 
     hvd.init()
+    lrank = hvd.local_rank()
+    if torch.cuda.is_available():
+        torch.cuda.set_device(lrank)
 
     true_conf = {
         "device": true_device,
-        "local_rank": hvd.local_rank(),
+        "local_rank": lrank,
         "rank": hvd.rank(),
         "world_size": hvd.size(),
         "node_index": 0,
@@ -121,6 +123,7 @@
     }
 
     model = _HorovodDistModel.create_from_context()
+    assert model.backend() == true_backend
     _assert_model(model, true_conf)
 
     hvd.shutdown()
@@ -142,18 +145,53 @@ def test__hvd_dist_model_create_no_dist_cuda(gloo_hvd_executor):
 
 
 @pytest.mark.distributed
 @pytest.mark.skipif(torch.cuda.device_count() > 0, reason="Skip if has GPU")
-def test__hvd_dist_model_create_dist(gloo_hvd_executor):
+def test__hvd_dist_model_create_dist_1(gloo_hvd_executor):
     gloo_hvd_executor(_test__hvd_dist_model_create_from_backend_dist, ("horovod", "cpu"), np=4)
+
+
+@pytest.mark.distributed
+@pytest.mark.skipif(torch.cuda.device_count() > 0, reason="Skip if has GPU")
+def test__hvd_dist_model_create_dist_2(gloo_hvd_executor):
     gloo_hvd_executor(_test__hvd_dist_model_create_from_context_dist, ("horovod", "cpu"), np=4)
 
 
 @pytest.mark.distributed
 @pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU")
-def test__hvd_dist_model_create_dist_cuda(gloo_hvd_executor):
+def test__hvd_dist_model_create_dist_cuda_1(gloo_hvd_executor):
     gloo_hvd_executor(_test__hvd_dist_model_create_from_backend_dist, ("horovod", "cuda"), np=torch.cuda.device_count())
+
+
+@pytest.mark.distributed
+@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU")
+def test__hvd_dist_model_create_dist_cuda_2(gloo_hvd_executor):
     gloo_hvd_executor(_test__hvd_dist_model_create_from_context_dist, ("horovod", "cuda"), np=torch.cuda.device_count())
 
 
+def _test__hvd_dist_model_warning_index_less_localrank():
+
+    assert torch.cuda.is_available()
+    assert _HorovodDistModel.create_from_context() is None
+
+    hvd.init()
+    # We deliberately incorrectly set cuda device to 0
+    torch.cuda.set_device(0)
+
+    model = _HorovodDistModel.create_from_context()
+    assert isinstance(model, _HorovodDistModel), "{} vs _HorovodDistModel".format(type(model))
+
+    if hvd.local_rank() == 1:
+        with pytest.warns(UserWarning, match=r"Current device index is less than current local rank."):
+            model.device()
+
+    hvd.shutdown()
+
+
+@pytest.mark.distributed
+@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Skip if less than 2 GPUs")
+def test__hvd_dist_model_warning_index_less_localrank(gloo_hvd_executor):
+    gloo_hvd_executor(_test__hvd_dist_model_warning_index_less_localrank, (), np=torch.cuda.device_count())
+
+
 def _test_dist_spawn_fn(local_rank, backend, world_size, device):
     from ignite.distributed.utils import _model
diff --git a/tests/ignite/distributed/comp_models/test_native.py b/tests/ignite/distributed/comp_models/test_native.py
index b509df82b18..6bdf6d821d7 100644
--- a/tests/ignite/distributed/comp_models/test_native.py
+++ b/tests/ignite/distributed/comp_models/test_native.py
@@ -211,6 +211,8 @@ def _test__native_dist_model_create_from_context_dist(local_rank, rank, world_si
 
     dist.init_process_group(true_backend, "tcp://0.0.0.0:2222", world_size=world_size, rank=rank)
     dist.barrier()
+    if torch.cuda.is_available():
+        torch.cuda.set_device(local_rank)
 
     true_conf = {
         "device": true_device,
@@ -244,22 +246,52 @@ def test__native_dist_model_create_no_dist_nccl(clean_env):
 
 
 @pytest.mark.distributed
-def test__native_dist_model_create_dist_gloo(local_rank, world_size):
+def test__native_dist_model_create_dist_gloo_1(local_rank, world_size):
     _test__native_dist_model_create_from_backend_dist(local_rank, local_rank, world_size, "gloo", "cpu")
+
+
+@pytest.mark.distributed
+def test__native_dist_model_create_dist_gloo_2(local_rank, world_size):
     _test__native_dist_model_create_from_context_dist(local_rank, local_rank, world_size, "gloo", "cpu")
 
 
 @pytest.mark.distributed
 @pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU")
-def test__native_dist_model_create_dist_nccl(local_rank, world_size):
+def test__native_dist_model_create_dist_nccl_1(local_rank, world_size):
     _test__native_dist_model_create_from_backend_dist(
         local_rank, local_rank, world_size, "nccl", "cuda:{}".format(local_rank)
     )
+
+
+@pytest.mark.distributed
+@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU")
+def test__native_dist_model_create_dist_nccl_2(local_rank, world_size):
     _test__native_dist_model_create_from_context_dist(
         local_rank, local_rank, world_size, "nccl", "cuda:{}".format(local_rank)
     )
 
 
+@pytest.mark.distributed
+@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Skip if less than 2 GPUs")
+def test__native_dist_model_warning_index_less_localrank(local_rank, world_size):
+
+    assert _NativeDistModel.create_from_context() is None
+
+    dist.init_process_group("nccl", "tcp://0.0.0.0:2222", world_size=world_size, rank=local_rank)
+    dist.barrier()
+    # We deliberately incorrectly set cuda device to 0
+    torch.cuda.set_device(0)
+
+    model = _NativeDistModel.create_from_context()
+    assert isinstance(model, _NativeDistModel), "{} vs _NativeDistModel".format(type(model))
+
+    if local_rank == 1:
+        with pytest.warns(UserWarning, match=r"Current device index is less than current local rank."):
+            model.device()
+
+    dist.destroy_process_group()
+
+
 def _test_dist_spawn_fn(local_rank, backend, world_size, device):
     from ignite.distributed.utils import _model
diff --git a/tests/ignite/distributed/utils/test_horovod.py b/tests/ignite/distributed/utils/test_horovod.py
index 4f7a0689532..aa1355743d8 100644
--- a/tests/ignite/distributed/utils/test_horovod.py
+++ b/tests/ignite/distributed/utils/test_horovod.py
@@ -75,6 +75,9 @@ def _test_sync_as_hvd():
     from ignite.distributed.comp_models.horovod import _HorovodDistModel
 
     hvd.init()
+    lrank = hvd.local_rank()
+    if torch.cuda.is_available():
+        torch.cuda.set_device(lrank)
 
     _test_sync(_HorovodDistModel)
 
@@ -111,6 +114,10 @@ def _test_idist_methods_in_hvd_context(backend, device):
     ws = hvd.size()
     rank = hvd.rank()
     local_rank = hvd.local_rank()
+
+    if torch.cuda.is_available():
+        torch.cuda.set_device(local_rank)
+
     _test_distrib_config(local_rank, backend=backend, ws=ws, true_device=device, rank=rank)
 
     hvd.shutdown()
diff --git a/tests/run_cpu_tests.sh b/tests/run_cpu_tests.sh
index 4e48e19f9f2..78dd6b16d81 100644
--- a/tests/run_cpu_tests.sh
+++ b/tests/run_cpu_tests.sh
@@ -2,7 +2,7 @@
 
 set -xeu
 
-CUDA_VISIBLE_DEVICES="" py.test --tx 4*popen//python=python$CI_PYTHON_VERSION --cov ignite --cov-report term-missing -vvv tests/
+CUDA_VISIBLE_DEVICES="" py.test --tx 4*popen//python=python${CI_PYTHON_VERSION:-3.7} --cov ignite --cov-report term-missing -vvv tests/
 
 # https://pubs.opengroup.org/onlinepubs/009695399/utilities/xcu_chap02.html#tag_02_06_02
 if [ "${SKIP_DISTRIB_TESTS:-0}" -eq "1" ]; then
@@ -11,4 +11,4 @@ fi
 
 export WORLD_SIZE=2
 
-CUDA_VISIBLE_DEVICES="" py.test --cov ignite --cov-append --cov-report term-missing --dist=each --tx $WORLD_SIZE*popen//python=python$CI_PYTHON_VERSION tests -m distributed -vvv
+CUDA_VISIBLE_DEVICES="" py.test --cov ignite --cov-append --cov-report term-missing --dist=each --tx $WORLD_SIZE*popen//python=python${CI_PYTHON_VERSION:-3.7} tests -m distributed -vvv
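Usage note (not part of the patch): the sketch below illustrates the pattern the new warning nudges users toward when the process group is initialized by hand, namely calling torch.cuda.set_device(local_rank) right after initialization and before asking ignite for the current device. It is a minimal sketch, assuming the public ignite.distributed ("idist") API and a launcher such as torchrun that exports LOCAL_RANK; the script name and environment handling are illustrative only.

# sketch_set_device.py -- illustrative only; run e.g. `torchrun --nproc_per_node=2 sketch_set_device.py`
import os

import torch
import torch.distributed as dist

import ignite.distributed as idist

# torchrun exports LOCAL_RANK, RANK, WORLD_SIZE, MASTER_ADDR, MASTER_PORT
local_rank = int(os.environ["LOCAL_RANK"])

dist.init_process_group(
    backend="nccl" if torch.cuda.is_available() else "gloo", init_method="env://"
)

# Pin this process to its own GPU before querying idist.device(); otherwise every
# worker stays on cuda:0 and, with the patch above, idist.device() warns
# "Current device index is less than current local rank." on workers with local rank > 0.
if torch.cuda.is_available():
    torch.cuda.set_device(local_rank)

device = idist.device()  # cuda:{local_rank} with NCCL, cpu with gloo
print(idist.get_rank(), "->", device)

dist.destroy_process_group()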