[Train] Don't use NCCL_BLOCKING_WAIT #29562

Merged: 7 commits merged on Nov 17, 2022

python/ray/train/tests/test_gpu.py: 28 additions, 0 deletions
@@ -1,5 +1,6 @@
import os
from collections import Counter
import time

from unittest.mock import patch
import pytest
@@ -9,6 +10,7 @@
from torch.utils.data import DataLoader, DistributedSampler

import ray
from ray.exceptions import RayTaskError
from ray.air import session
from ray import tune

@@ -307,6 +309,32 @@ def assert_env_var_set():
    worker_group.execute(assert_env_var_set)


def test_torch_fail_on_nccl_timeout(ray_start_4_cpus_2_gpus):
    """Tests that TorchTrainer raises an exception on NCCL timeouts."""

    def train_fn():
        model = torch.nn.Linear(1, 1)
        model = train.torch.prepare_model(model)

        # The rank 0 worker never reaches the collective operation,
        # so NCCL should time out.
        if session.get_world_rank() == 0:
            while True:
                time.sleep(100)

        torch.distributed.barrier()

    trainer = TorchTrainer(
        train_fn,
        scaling_config=ScalingConfig(num_workers=2, use_gpu=True),
        torch_config=TorchConfig(timeout_s=5),
    )

    # Training should fail and not hang.
    with pytest.raises(RayTaskError):
        trainer.fit()


if __name__ == "__main__":
    import sys

python/ray/train/torch/config.py: 13 additions, 4 deletions
@@ -94,12 +94,21 @@ def _setup_torch_process_group(
     )
     logger.debug(f"using {backend}")

-    if backend == "nccl" and "NCCL_BLOCKING_WAIT" not in os.environ:
+    # See the `timeout` arg in https://pytorch.org/docs/master/
+    # distributed.html#torch.distributed.init_process_group for description of
+    # NCCL_ASYNC_ERROR_HANDLING. We do not use NCCL_BLOCKING_WAIT due to performance
+    # overhead.
+    if (
+        backend == "nccl"
+        and "NCCL_ASYNC_ERROR_HANDLING" not in os.environ
+        and "NCCL_BLOCKING_WAIT" not in os.environ
+    ):
         logger.debug(
-            "Setting NCCL_BLOCKING_WAIT for detecting node failure. "
-            "To override this behavior, you can set NCCL_BLOCKING_WAIT=0."
+            "Setting NCCL_ASYNC_ERROR_HANDLING to fail if NCCL collective "
+            "communication operations are timing out. "
+            "To override this behavior, you can set NCCL_ASYNC_ERROR_HANDLING=0."
         )
-        os.environ["NCCL_BLOCKING_WAIT"] = "1"
+        os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "1"
Member

Can we update the test plan for the failure behavior? If I understand correctly, the documentation says NCCL_ASYNC_ERROR_HANDLING is more performant but crashes the process, whereas NCCL_BLOCKING_WAIT surfaces errors to the user that can be caught and handled. This has implications for Ray Train's error handling semantics.
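
For reference, here is a minimal sketch (not part of this PR) of how the two settings behave when one rank never reaches a collective. It assumes a single node with two GPUs, the NCCL backend, and an artificially short timeout; the master address/port, helper name, and sleep durations are placeholders, and the exact exception type and crash output depend on the PyTorch version.

import datetime
import os
import time

import torch
import torch.distributed as dist


def run(rank, world_size):
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")

    # Pick one of the two behaviors discussed above:
    #   NCCL_BLOCKING_WAIT=1        -> the collective raises a catchable RuntimeError
    #   NCCL_ASYNC_ERROR_HANDLING=1 -> the NCCL watchdog tears the whole process down
    os.environ["NCCL_BLOCKING_WAIT"] = "1"

    torch.cuda.set_device(rank)
    dist.init_process_group(
        backend="nccl",
        rank=rank,
        world_size=world_size,
        timeout=datetime.timedelta(seconds=5),
    )

    if rank == 0:
        # Rank 0 never reaches the collective, so rank 1 times out.
        time.sleep(60)
    else:
        try:
            dist.barrier()
        except RuntimeError as exc:
            # Reachable only with NCCL_BLOCKING_WAIT; with async error handling
            # the process is killed before user code can catch anything.
            print(f"rank {rank} caught NCCL timeout: {exc}")


if __name__ == "__main__":
    torch.multiprocessing.spawn(run, args=(2,), nprocs=2)

With the toggle switched to NCCL_ASYNC_ERROR_HANDLING=1 instead, the except branch is never reached: the NCCL watchdog aborts the worker process, which is the behavior the config change above opts into.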

Member

+1, we should trigger this code path and make sure the crash output provides enough information to the user before merging. Unfortunately, I don't think we can do much better than crashing.

Contributor Author @amogkam (Oct 21, 2022)

Agreed, we should do it. Any suggestions on how to trigger this code path? I couldn't think of an easy way.

Member

Launch data-parallel training (at least two actors) that uses NCCL to do the allreduce. Make one of the actors enter a while True: sleep loop so that it never reaches the allreduce. Then, after 30 minutes, you'll see how PyTorch crashes the process. It will be even easier if you reduce the timeout ;)

Contributor Author

Yep, it looks like an exception is being raised:

(RayTrainWorker pid=13803) [E ProcessGroupNCCL.cpp:737] [Rank 0] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=4, OpType=ALLREDUCE, Timeout(ms)=5000) ran for 7751 milliseconds before timing out.
(RayTrainWorker pid=13803) [E ProcessGroupNCCL.cpp:414] Some NCCL operations have failed or timed out. Due to the asynchronous nature of CUDA kernels, subsequent GPU operations might run on corrupted/incomplete data. To avoid this inconsistency, we are taking the entire process down.
(RayTrainWorker pid=13803) [2022-10-21 16:23:36,638 E 13803 13875] logging.cc:97: Unhandled exception: St13runtime_error. what(): [Rank 0] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=4, OpType=ALLREDUCE, Timeout(ms)=5000) ran for 7751 milliseconds before timing out.
(RayTrainWorker pid=13803) [2022-10-21 16:23:36,648 E 13803 13875] logging.cc:104: Stack trace: 
(RayTrainWorker pid=13803)  /home/ray/anaconda3/lib/python3.8/site-packages/ray/_raylet.so(+0xc74dda) [0x7f0934867dda] ray::operator<<()
(RayTrainWorker pid=13803) /home/ray/anaconda3/lib/python3.8/site-packages/ray/_raylet.so(+0xc77598) [0x7f093486a598] ray::TerminateHandler()
(RayTrainWorker pid=13803) /home/ray/anaconda3/bin/../lib/libstdc++.so.6(+0xacf6f) [0x7f0933b2af6f] __cxxabiv1::__terminate()
(RayTrainWorker pid=13803) /home/ray/anaconda3/bin/../lib/libstdc++.so.6(+0xacfb1) [0x7f0933b2afb1] __cxxabiv1::__unexpected()
(RayTrainWorker pid=13803) /home/ray/anaconda3/bin/../lib/libstdc++.so.6(+0xacf6c) [0x7f0933b2af6c] __cxxabiv1::__terminate()
(RayTrainWorker pid=13803) /home/ray/anaconda3/lib/python3.8/site-packages/torch/lib/libtorch_cuda_cpp.so(_ZN4c10d16ProcessGroupNCCL8WorkNCCL15handleNCCLGuardEv+0x19f) [0x7efbc5ae2d4f] c10d::ProcessGroupNCCL::WorkNCCL::handleNCCLGuard()
(RayTrainWorker pid=13803) /home/ray/anaconda3/lib/python3.8/site-packages/torch/lib/libtorch_cuda_cpp.so(_ZN4c10d16ProcessGroupNCCL15workCleanupLoopEv+0x199) [0x7efbc5ae71c9] c10d::ProcessGroupNCCL::workCleanupLoop()
(RayTrainWorker pid=13803) /home/ray/anaconda3/bin/../lib/libstdc++.so.6(+0xc9039) [0x7f0933b47039] execute_native_thread_routine
(RayTrainWorker pid=13803) /usr/lib/x86_64-linux-gnu/libpthread.so.0(+0x8609) [0x7f09354e5609] start_thread
(RayTrainWorker pid=13803) /usr/lib/x86_64-linux-gnu/libc.so.6(clone+0x43) [0x7f093540a133] __clone
(RayTrainWorker pid=13803) 

But the Ray actor is still alive, causing training to hang. @rkooo567, do you know why the actor is not terminating when it receives this exception?

Member

Is the Ray actor still alive? I think the process that contains the Ray actor should be killed by SIGABRT: https://github.com/ray-project/ray/blob/master/src/ray/util/logging.cc#L106

Contributor Author @amogkam (Oct 22, 2022)

Yes, the actor is still alive. I'm not sure why the std::abort() is not being captured.

Note that the std::abort() is not being run on the main thread, but from what I understand, it should still kill the entire process.
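
As a quick illustration of that expectation (purely a sketch, unrelated to Ray's internals), abort() sent from a non-main thread still delivers SIGABRT to the whole process:

import os
import threading
import time

# os.abort() raises SIGABRT for the current process, no matter which thread calls it,
# so the interpreter (and every other thread) should die immediately.
threading.Thread(target=os.abort).start()

time.sleep(5)
print("never reached: the process is killed by SIGABRT before this line runs")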

Contributor Author @amogkam (Nov 16, 2022)

Added test


     dist.init_process_group(
         backend=backend,