register_mr_buffers:544 NCCL WARN NET/OFI Unable to register memory (type = 2) for device 0. RC: -22, Error: Invalid argument #584

Open
visatish opened this issue Sep 11, 2024 · 8 comments

Comments

@visatish

visatish commented Sep 11, 2024

Hi,

I'm trying to run an NCCL all-reduce benchmark on AWS EC2 and am running into the following error:

register_mr_buffers:544 NCCL WARN NET/OFI Unable to register memory (type = 2) for device 0. RC: -22, Error: Invalid argument

Setup:

2x p4d.24xlarge

"Deep Learning AMI GPU PyTorch 2.1.0 (Ubuntu 20.04)" AMI

Relevant libs (note that I have installed the latest torch 2.4.1 & deps fresh):

  • torch-2.4.1-cp310-cp310-manylinux1_x86_64.whl
  • nvidia-nccl-cu12==2.20.5

Single EFA-enabled NIC (note that I know this instance type can support up to 4x, but I'm starting with 1):

(base) ubuntu@ip-172-31-36-110:~$ fi_info -p efa -t FI_EP_RDM
provider: efa
    fabric: efa
    domain: rdmap16s27-rdm
    version: 118.20
    type: FI_EP_RDM
    protocol: FI_PROTO_EFA
(base) ubuntu@ip-172-31-32-222:~$ fi_info --version
fi_info: 1.18.2amzn1.0
libfabric: 1.18.2amzn1.0
libfabric api: 1.18
(base) ubuntu@ip-172-31-36-110:~$ lspci -i efa
00:00.0 Host bridge: Intel Corporation 440FX - 82441FX PMC [Natoma]
00:01.0 ISA bridge: Intel Corporation 82371SB PIIX3 ISA [Natoma/Triton II]
00:01.3 Non-VGA unclassified device: Intel Corporation 82371AB/EB/MB PIIX4 ACPI (rev 08)
00:03.0 VGA compatible controller: Amazon.com, Inc. Device 1111
00:04.0 Non-Volatile memory controller: Amazon.com, Inc. Device 8061
10:00.0 Ethernet controller: Amazon.com, Inc. Elastic Network Adapter (ENA)
10:1b.0 Ethernet controller: Amazon.com, Inc. Elastic Fabric Adapter (EFA)
10:1c.0 3D controller: NVIDIA Corporation Device 20b0 (rev a1)
10:1d.0 3D controller: NVIDIA Corporation Device 20b0 (rev a1)
10:1e.0 Non-Volatile memory controller: Amazon.com, Inc. NVMe SSD Controller
10:1f.0 Non-Volatile memory controller: Amazon.com, Inc. NVMe SSD Controller
20:1c.0 3D controller: NVIDIA Corporation Device 20b0 (rev a1)
20:1d.0 3D controller: NVIDIA Corporation Device 20b0 (rev a1)
20:1e.0 Non-Volatile memory controller: Amazon.com, Inc. NVMe SSD Controller
20:1f.0 Non-Volatile memory controller: Amazon.com, Inc. NVMe SSD Controller
80:1a.0 Bridge: NVIDIA Corporation Device 1af1 (rev a1)
80:1b.0 Bridge: NVIDIA Corporation Device 1af1 (rev a1)
80:1c.0 Bridge: NVIDIA Corporation Device 1af1 (rev a1)
80:1d.0 Bridge: NVIDIA Corporation Device 1af1 (rev a1)
80:1e.0 Bridge: NVIDIA Corporation Device 1af1 (rev a1)
80:1f.0 Bridge: NVIDIA Corporation Device 1af1 (rev a1)
90:1c.0 3D controller: NVIDIA Corporation Device 20b0 (rev a1)
90:1d.0 3D controller: NVIDIA Corporation Device 20b0 (rev a1)
90:1e.0 Non-Volatile memory controller: Amazon.com, Inc. NVMe SSD Controller
90:1f.0 Non-Volatile memory controller: Amazon.com, Inc. NVMe SSD Controller
a0:1c.0 3D controller: NVIDIA Corporation Device 20b0 (rev a1)
a0:1d.0 3D controller: NVIDIA Corporation Device 20b0 (rev a1)
a0:1e.0 Non-Volatile memory controller: Amazon.com, Inc. NVMe SSD Controller
a0:1f.0 Non-Volatile memory controller: Amazon.com, Inc. NVMe SSD Controller

Cmd:

From https://github.com/stas00/ml-engineering.git:

cd ml-engineering/network/benchmarks
NCCL_DEBUG=INFO python -u -m torch.distributed.run --nproc_per_node 8 --nnodes 2 --rdzv_endpoint <head node addr>:8888 --rdzv_backend c10d --max_restarts 0 --role `hostname -s`: --tee 3 all_reduce_bench.py

Output:

nccl_out.txt

Note this particular portion:

(head, rank=0, pid=32220) [ip-172-31-33-151:1]:ip-172-31-33-151:38696:38818 [1] register_mr_buffers:544 NCCL WARN NET/OFI Unable to register memory (type = 2) for device 0. RC: -22, Error: Invalid argument
(head, rank=0, pid=32220) [ip-172-31-33-151:1]:ip-172-31-33-151:38696:38818 [1] NCCL INFO transport/net.cc:779 -> 2
(head, rank=0, pid=32220) [ip-172-31-33-151:1]:ip-172-31-33-151:38696:38818 [1] NCCL INFO misc/socket.cc:47 -> 3
(head, rank=0, pid=32220) [ip-172-31-33-151:1]:ip-172-31-33-151:38696:38818 [1] NCCL INFO misc/socket.cc:58 -> 3
(head, rank=0, pid=32220) [ip-172-31-33-151:1]:ip-172-31-33-151:38696:38818 [1] NCCL INFO misc/socket.cc:775 -> 3
(head, rank=0, pid=32220) [ip-172-31-33-151:1]:ip-172-31-33-151:38696:38818 [1] NCCL INFO proxy.cc:1384 -> 3
(head, rank=0, pid=32220) [ip-172-31-33-151:1]:
(head, rank=0, pid=32220) [ip-172-31-33-151:1]:ip-172-31-33-151:38696:38818 [1] proxy.cc:1533 NCCL WARN [Service thread] Error encountered progressing operation=Connect, res=3, closing connection
(head, rank=0, pid=32220) [ip-172-31-33-151:1]:
(head, rank=0, pid=32220) [ip-172-31-33-151:1]:ip-172-31-33-151:38696:38818 [1] proxy.cc:1567 NCCL WARN [Proxy Service 1] Failed to execute operation Connect from rank 1, retcode 3
(worker1, rank=1, pid=30588, ip=172.31.42.166) [ip-172-31-42-166:1]:
(worker1, rank=1, pid=30588, ip=172.31.42.166) [ip-172-31-42-166:1]:ip-172-31-42-166:30780:30883 [1] register_mr_buffers:544 NCCL WARN NET/OFI Unable to register memory (type = 2) for device 0. RC: -22, Error: Invalid argument
(worker1, rank=1, pid=30588, ip=172.31.42.166) [ip-172-31-42-166:1]:ip-172-31-42-166:30780:30883 [1] NCCL INFO transport/net.cc:779 -> 2
(worker1, rank=1, pid=30588, ip=172.31.42.166) [ip-172-31-42-166:1]:ip-172-31-42-166:30780:30872 [1] NCCL INFO transport/net.cc:304 -> 2
(worker1, rank=1, pid=30588, ip=172.31.42.166) [ip-172-31-42-166:1]:ip-172-31-42-166:30780:30872 [1] NCCL INFO transport.cc:165 -> 2
(worker1, rank=1, pid=30588, ip=172.31.42.166) [ip-172-31-42-166:1]:ip-172-31-42-166:30780:30872 [1] NCCL INFO init.cc:1222 -> 2
(worker1, rank=1, pid=30588, ip=172.31.42.166) [ip-172-31-42-166:1]:ip-172-31-42-166:30780:30872 [1] NCCL INFO init.cc:1501 -> 2
(worker1, rank=1, pid=30588, ip=172.31.42.166) [ip-172-31-42-166:1]:ip-172-31-42-166:30780:30872 [1] NCCL INFO group.cc:64 -> 2 [Async thread]
(worker1, rank=1, pid=30588, ip=172.31.42.166) [ip-172-31-42-166:1]:ip-172-31-42-166:30780:30780 [1] NCCL INFO group.cc:418 -> 2
(worker1, rank=1, pid=30588, ip=172.31.42.166) [ip-172-31-42-166:1]:ip-172-31-42-166:30780:30780 [1] NCCL INFO init.cc:1876 -> 2

I'm not quite sure what is triggering Error: Invalid argument here. Any help is appreciated. Thanks!
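
For readers decoding the warning above: libfabric generally reports failures as negated errno values, so RC: -22 maps to EINVAL, which is where the "Invalid argument" text comes from. The sketch below pulls the fields out of the warning line. The function name and the pointer-type table are illustrative assumptions (the flag values are taken from NCCL's net plugin header as I understand it, and should be checked against the NCCL version in use).

```python
import errno
import re

# Pointer-type flags as I understand them from NCCL's net plugin header.
# Treat this table as an assumption and verify it for your NCCL version.
NCCL_PTR_TYPES = {1: "NCCL_PTR_HOST", 2: "NCCL_PTR_CUDA", 4: "NCCL_PTR_DMABUF"}

WARN_PATTERN = re.compile(
    r"Unable to register memory \(type = (\d+)\) for device (\d+)\. "
    r"RC: (-?\d+), Error: (.+)$"
)


def decode_register_warning(line):
    """Return the decoded fields of a register_mr_buffers warning, or None."""
    match = WARN_PATTERN.search(line)
    if match is None:
        return None
    ptr_type, device, rc, message = match.groups()
    return {
        "pointer_type": NCCL_PTR_TYPES.get(int(ptr_type), "unknown"),
        "device": int(device),
        # The return code is a negated errno, so look up its absolute value.
        "errno_name": errno.errorcode.get(abs(int(rc)), "unknown"),
        "message": message.strip(),
    }


line = (
    "register_mr_buffers:544 NCCL WARN NET/OFI Unable to register memory "
    "(type = 2) for device 0. RC: -22, Error: Invalid argument"
)
print(decode_register_warning(line))
```

If the table is right, type = 2 means the failing registration was for a CUDA (GPU) buffer rather than host memory, which narrows the search to the GPU memory registration path.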

@visatish
Author

@bwbarrett I noticed you had helped with some related issues

@AmedeoSapio
Contributor

Hi,
can you please try enabling all 4 EFAs?

@visatish
Author

visatish commented Sep 11, 2024

@AmedeoSapio I was actually able to get it working with the native pytorch version in the AMI, i.e. conda activate pytorch:

(head, rank=0, pid=36870) [ip-172-31-40-103:0]:The average bandwidth of all_reduce with a 4.0GB payload (5 trials, 16 ranks):
(head, rank=0, pid=36870) [ip-172-31-40-103:0]: algbw: 11.135 GBps (89.1 Gbps)
(head, rank=0, pid=36870) [ip-172-31-40-103:0]: busbw: 20.878 GBps (167.0 Gbps)
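
The two figures above are related by the usual all-reduce convention (as used in nccl-tests): bus bandwidth is algorithm bandwidth scaled by 2(n-1)/n for n ranks. A minimal sketch that reproduces the reported numbers, assuming the benchmark follows that convention; the helper names are made up for illustration:

```python
def allreduce_busbw(algbw, n_ranks):
    """Scale algorithm bandwidth to bus bandwidth for an all-reduce."""
    return algbw * 2 * (n_ranks - 1) / n_ranks


def gbytes_to_gbits(gbytes_per_s):
    """Convert GBps to Gbps."""
    return gbytes_per_s * 8


algbw = 11.135  # GBps, from the output above (16 ranks)
busbw = allreduce_busbw(algbw, 16)
print(f"algbw: {algbw:.3f} GBps ({gbytes_to_gbits(algbw):.1f} Gbps)")
print(f"busbw: {busbw:.3f} GBps ({gbytes_to_gbits(busbw):.1f} Gbps)")
```

With 16 ranks the scaling factor is 1.875, which is why busbw is a little under twice algbw.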

I will try with 4 NICs, but presumably that will just increase bandwidth.

This suggests an incompatibility between aws-ofi-nccl and the latest torch and its dependencies. (I have updated the original issue to note that I had freshly installed the latest version, i.e. ran pip install torch, before running the commands.)

@rauteric
Contributor

Hello. There is a known incompatibility between NCCL 2.19+ and Libfabric from EFA installers before 1.29. I'm guessing using the latest PyTorch will upgrade the NCCL version.

Workarounds are any of the following:

  1. Set FI_EFA_SET_CUDA_SYNC_MEMOPS=0 in the environment
  2. Downgrade to NCCL 2.18 (which it sounds like using native PyTorch will do)
  3. Upgrade to EFA installer 1.29 or greater (latest is 1.34)
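
As a rough illustration of workaround 1, the sketch below sets the variable only when the version condition described above holds (NCCL 2.19 or newer together with an EFA installer older than 1.29). The function names are hypothetical, the versions are passed in as strings rather than detected, and it assumes the variable has to be in the environment before NCCL and libfabric are initialized in the process.

```python
import os


def parse_version(text):
    """Turn a dotted version such as '2.20.5' into a tuple of ints."""
    return tuple(int(part) for part in text.strip().split("."))


def needs_sync_memops_workaround(nccl_version, efa_installer_version):
    """True for NCCL >= 2.19 combined with an EFA installer older than 1.29."""
    nccl = parse_version(nccl_version)[:2]
    efa = parse_version(efa_installer_version)[:2]
    return nccl >= (2, 19) and efa < (1, 29)


def apply_workaround(nccl_version, efa_installer_version, env=os.environ):
    """Set FI_EFA_SET_CUDA_SYNC_MEMOPS=0 if the known-bad combination is present.

    Uses setdefault so a value already chosen by the user is left alone.
    Returns True when the workaround applies.
    """
    if needs_sync_memops_workaround(nccl_version, efa_installer_version):
        env.setdefault("FI_EFA_SET_CUDA_SYNC_MEMOPS", "0")
        return True
    return False


# Example with the NCCL version from this issue and a hypothetical
# EFA installer version; call this before initializing the process group.
scratch_env = {}
print(apply_workaround("2.20.5", "1.28.0", env=scratch_env), scratch_env)
```

Exporting the variable in the shell that launches torch.distributed.run should have the same effect, since the worker processes inherit the launcher's environment.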

@visatish
Author

Hi @rauteric, good to know! Is there any significant performance downside to (1) as that would be the least-invasive for our stack atm?

@rauteric
Contributor

> Hi @rauteric, good to know! Is there any significant performance downside to (1) as that would be the least-invasive for our stack atm?

No, this setting merely prevents Libfabric from setting a property on a CUDA buffer (sync_memops) that is not needed for NCCL. It shouldn't have any performance impact.

@visatish
Author

Gotcha, confirmed that FI_EFA_SET_CUDA_SYNC_MEMOPS=0 works with the latest pytorch+NCCL stack in the original example.

It might be nice for future users if this were "pinned" in an easy-to-find place, for example under "Known problems/limitations", or covered by an up-to-date compatibility chart. For now, I guess it's indexed in this ticket :)

Thanks again for the help!

@aws-nslick
Contributor

For future searchers, if it's at all possible, please do prefer to update efa.ko and libfabric instead of relying on this environment variable -- this specific workaround doesn't come with a perf hit, but you are missing out on other performance improvements and bug fixes by using older versions, and you should update whenever you can.
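
To check which libfabric is actually installed before and after an upgrade, the output of fi_info --version (shown in the original report) can be parsed. This is only a sketch: the helper names are invented, and it assumes the "key: value" layout seen in that output.

```python
import re
import subprocess


def parse_fi_info_version(output):
    """Split 'key: value' lines from `fi_info --version` into a dict."""
    versions = {}
    for line in output.splitlines():
        key, sep, value = line.partition(":")
        if sep:
            versions[key.strip()] = value.strip()
    return versions


def libfabric_release(versions):
    """Return the numeric (major, minor, patch) part, e.g. 1.18.2amzn1.0 -> (1, 18, 2)."""
    match = re.match(r"(\d+)\.(\d+)\.(\d+)", versions.get("libfabric", ""))
    return tuple(int(group) for group in match.groups()) if match else None


def installed_libfabric():
    """Run fi_info on this host and return the libfabric release tuple."""
    result = subprocess.run(
        ["fi_info", "--version"], capture_output=True, text=True, check=True
    )
    return libfabric_release(parse_fi_info_version(result.stdout))


# Sample taken from the output in the original report.
sample = "fi_info: 1.18.2amzn1.0\nlibfabric: 1.18.2amzn1.0\nlibfabric api: 1.18\n"
print(libfabric_release(parse_fi_info_version(sample)))
```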
