diff --git a/python/ray/train/tests/test_gpu.py b/python/ray/train/tests/test_gpu.py
index 1b6ec2c699d3..9d096cf66a3e 100644
--- a/python/ray/train/tests/test_gpu.py
+++ b/python/ray/train/tests/test_gpu.py
@@ -1,5 +1,6 @@
 import os
 from collections import Counter
+import time
 from unittest.mock import patch
 
 import pytest
@@ -9,6 +10,7 @@
 from torch.utils.data import DataLoader, DistributedSampler
 
 import ray
+from ray.exceptions import RayTaskError
 from ray.air import session
 from ray import tune
 
@@ -307,6 +309,32 @@ def assert_env_var_set():
     worker_group.execute(assert_env_var_set)
 
 
+def test_torch_fail_on_nccl_timeout(ray_start_4_cpus_2_gpus):
+    """Tests that TorchTrainer raises an exception on NCCL timeouts."""
+
+    def train_fn():
+        model = torch.nn.Linear(1, 1)
+        model = train.torch.prepare_model(model)
+
+        # The rank 0 worker never reaches the collective operation,
+        # so NCCL should time out.
+        if session.get_world_rank() == 0:
+            while True:
+                time.sleep(100)
+
+        torch.distributed.barrier()
+
+    trainer = TorchTrainer(
+        train_fn,
+        scaling_config=ScalingConfig(num_workers=2, use_gpu=True),
+        torch_config=TorchConfig(timeout_s=5),
+    )
+
+    # Training should fail and not hang.
+    with pytest.raises(RayTaskError):
+        trainer.fit()
+
+
 if __name__ == "__main__":
     import sys
 
diff --git a/python/ray/train/torch/config.py b/python/ray/train/torch/config.py
index fca7ec745169..9d2ba2d0a60a 100644
--- a/python/ray/train/torch/config.py
+++ b/python/ray/train/torch/config.py
@@ -94,12 +94,21 @@ def _setup_torch_process_group(
         )
     logger.debug(f"using {backend}")
 
-    if backend == "nccl" and "NCCL_BLOCKING_WAIT" not in os.environ:
+    # See the `timeout` arg in https://pytorch.org/docs/master/
+    # distributed.html#torch.distributed.init_process_group for a description of
+    # NCCL_ASYNC_ERROR_HANDLING. We do not use NCCL_BLOCKING_WAIT due to its
+    # performance overhead.
+    if (
+        backend == "nccl"
+        and "NCCL_ASYNC_ERROR_HANDLING" not in os.environ
+        and "NCCL_BLOCKING_WAIT" not in os.environ
+    ):
         logger.debug(
-            "Setting NCCL_BLOCKING_WAIT for detecting node failure. "
-            "To override this behavior, you can set NCCL_BLOCKING_WAIT=0."
+            "Setting NCCL_ASYNC_ERROR_HANDLING to fail if NCCL collective "
+            "communication operations are timing out. "
+            "To override this behavior, you can set NCCL_ASYNC_ERROR_HANDLING=0."
         )
-        os.environ["NCCL_BLOCKING_WAIT"] = "1"
+        os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "1"
 
     dist.init_process_group(
         backend=backend,