[Train] Don't use NCCL_BLOCKING_WAIT (#29562)
Per the PyTorch docs, we should use NCCL_ASYNC_ERROR_HANDLING instead, since NCCL_BLOCKING_WAIT adds performance overhead (a minimal sketch of the new default's behavior follows the change summary below).

Signed-off-by: amogkam <[email protected]>
amogkam authored Nov 17, 2022
1 parent 64131d6 commit 2631806
Showing 2 changed files with 41 additions and 4 deletions.
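For context, here is a minimal, hedged sketch of what the new default means in plain PyTorch terms (illustrative only, not code from this patch): with NCCL_ASYNC_ERROR_HANDLING=1 set before the process group is created, a NCCL collective that exceeds the process group timeout is aborted and the process fails with an error instead of blocking indefinitely. The rendezvous environment variables (MASTER_ADDR, MASTER_PORT, RANK, WORLD_SIZE) are assumed to be provided by a launcher such as torchrun.

import datetime
import os

import torch.distributed as dist

# Must be set before the process group is created; Ray Train now does this by default.
os.environ.setdefault("NCCL_ASYNC_ERROR_HANDLING", "1")

# With async error handling enabled, a collective that exceeds this timeout is
# aborted and the process fails with an error rather than hanging.
dist.init_process_group(
    backend="nccl",
    init_method="env://",  # assumes MASTER_ADDR/MASTER_PORT/RANK/WORLD_SIZE are set
    timeout=datetime.timedelta(seconds=30),
)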
28 changes: 28 additions & 0 deletions python/ray/train/tests/test_gpu.py
@@ -1,5 +1,6 @@
 import os
 from collections import Counter
+import time
 
 from unittest.mock import patch
 import pytest
@@ -9,6 +10,7 @@
 from torch.utils.data import DataLoader, DistributedSampler
 
 import ray
+from ray.exceptions import RayTaskError
 from ray.air import session
 from ray import tune
 
@@ -307,6 +309,32 @@ def assert_env_var_set():
     worker_group.execute(assert_env_var_set)
 
 
+def test_torch_fail_on_nccl_timeout(ray_start_4_cpus_2_gpus):
+    """Tests that TorchTrainer raises exception on NCCL timeouts."""
+
+    def train_fn():
+        model = torch.nn.Linear(1, 1)
+        model = train.torch.prepare_model(model)
+
+        # Rank 0 worker will never reach the collective operation.
+        # NCCL should timeout.
+        if session.get_world_rank() == 0:
+            while True:
+                time.sleep(100)
+
+        torch.distributed.barrier()
+
+    trainer = TorchTrainer(
+        train_fn,
+        scaling_config=ScalingConfig(num_workers=2, use_gpu=True),
+        torch_config=TorchConfig(timeout_s=5),
+    )
+
+    # Training should fail and not hang.
+    with pytest.raises(RayTaskError):
+        trainer.fit()
+
+
 if __name__ == "__main__":
     import sys
 
17 changes: 13 additions & 4 deletions python/ray/train/torch/config.py
@@ -94,12 +94,21 @@ def _setup_torch_process_group(
     )
     logger.debug(f"using {backend}")
 
-    if backend == "nccl" and "NCCL_BLOCKING_WAIT" not in os.environ:
+    # See the `timeout` arg in https://pytorch.org/docs/master/
+    # distributed.html#torch.distributed.init_process_group for description of
+    # NCCL_ASYNC_ERROR_HANDLING. We do not use NCCL_BLOCKING_WAIT due to performance
+    # overhead.
+    if (
+        backend == "nccl"
+        and "NCCL_ASYNC_ERROR_HANDLING" not in os.environ
+        and "NCCL_BLOCKING_WAIT" not in os.environ
+    ):
         logger.debug(
-            "Setting NCCL_BLOCKING_WAIT for detecting node failure. "
-            "To override this behavior, you can set NCCL_BLOCKING_WAIT=0."
+            "Setting NCCL_ASYNC_ERROR_HANDLING to fail if NCCL collective "
+            "communication operations are timing out. "
+            "To override this behavior, you can set NCCL_ASYNC_ERROR_HANDLING=0."
         )
-        os.environ["NCCL_BLOCKING_WAIT"] = "1"
+        os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "1"
 
     dist.init_process_group(
         backend=backend,
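Because the new guard in _setup_torch_process_group only sets NCCL_ASYNC_ERROR_HANDLING when it is not already in the worker's environment, the default can be overridden as the log message says. A hedged sketch of one way to do this (an assumed user-side setup, not part of this patch) is to pre-set the variable for all Ray workers through a runtime environment:

import ray

# env_vars in the runtime_env are propagated to Ray worker processes, so the
# NCCL setup code sees NCCL_ASYNC_ERROR_HANDLING as already set and leaves it as-is.
ray.init(runtime_env={"env_vars": {"NCCL_ASYNC_ERROR_HANDLING": "0"}})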
