tp_overlap init failed when tp_size != world_size #994
Comments
Hi @liuhatry -- I tested PR #986 earlier today on 2 nodes of 8xH100s and confirmed that these cases work both with and without it. Can you test if this PR resolves your issue?
Hi @denera, I tested examples/pytorch/comm_gemm_overlap/ln_mlp_with_overlap.py.
My environment: H800, NVIDIA-SMI 535.161.08, Driver Version: 535.161.08, CUDA Version: 12.2
Commands:
Log:
File "examples/pytorch/comm_gemm_overlap/ln_mlp_with_overlap.py", line 175, in train
Hi @liuhatry -- I updated PR #986 to prefer the Gloo backend over NCCL whenever possible for bootstrapping Userbuffers. The application code still has to initialize NCCL process groups for TE modules, but this change eliminates the previous bootstrapping requirement. I've also tested the example problem with these changes. Could you test if the latest changes resolve your issue?
Hi @denera, I tested the new code and it still failed:
!!! [NVTE] Bootstrapping Userbuffers with backend="gloo"
File "examples/pytorch/comm_gemm_overlap/ln_mlp_with_overlap.py", line 188, in train
Hi @liuhatry — if the Gloo backend in PyTorch distributed can’t do an all-gather over processes on a single host CPU, that suggests something is broken outside of Transformer Engine. Could you verify that you can perform the necessary collectives on host tensors with pure PyTorch (no TE code)? For example:

```python
import os
import torch
import torch.distributed as dist
# initialize default NCCL process group
world_rank = int(os.getenv("RANK", "0"))
world_size = int(os.getenv("WORLD_SIZE", "1"))
dist.init_process_group(backend="nccl", rank=world_rank, world_size=world_size)
# get a Gloo group for comms with host tensors
gloo_world = dist.new_group(backend="gloo")
localdata = torch.tensor([world_rank], dtype=torch.uint8, device="cpu")
globaldata = torch.empty(world_size, dtype=torch.uint8, device="cpu")
dist.all_gather_into_tensor(globaldata, localdata, gloo_world)
# verify result of all gather
reference = torch.tensor(list(range(world_size)), dtype=torch.uint8, device="cpu")
assert torch.equal(globaldata, reference)
```

The above is a simple representation of what happens when you run the comm+GEMM overlap example problem. The application initializes a default NCCL process group, and Transformer Engine then creates a Gloo process group for host tensor communication during Userbuffers bootstrapping. If this does not run correctly, I would recommend working with your sysadmin to troubleshoot the machine you’re running on, and possibly reaching out to the PyTorch team as well for their feedback.
Hi @denera, your example code cannot run, same as before:
https://github.com/pytorch/pytorch/blob/main/torch/distributed/distributed_c10d.py#L3392
File "gloo.py", line 14, in
Hi @liuhatry -- you're correct, Gloo supports all_gather but not all_gather_into_tensor. Here is an updated snippet that uses all_gather instead:

```python
import os
import socket
import torch
import torch.distributed as dist
WORLD_RANK = int(os.getenv("RANK", "0"))
WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1"))
LOCAL_RANK = int(os.getenv("LOCAL_RANK", "0"))
LOCAL_SIZE = int(os.getenv("LOCAL_WORLD_SIZE", "1"))
MASTER_ADDR = str(os.getenv("MASTER_ADDR", socket.gethostbyname(socket.gethostname())))
MASTER_PORT = str(os.getenv("MASTER_PORT", "1234"))
BOOTSTRAP_BACKEND = str(os.getenv("BOOTSTRAP_BACKEND", "gloo")).lower()
BOOTSTRAP_DEVICE = "cuda" if BOOTSTRAP_BACKEND == "nccl" else "cpu"
torch.cuda.set_device(LOCAL_RANK)
dist.init_process_group(backend="nccl",
init_method=f"tcp://{MASTER_ADDR}:{MASTER_PORT}",
rank=WORLD_RANK,
world_size=WORLD_SIZE)
bootstrap_world = dist.new_group(backend=BOOTSTRAP_BACKEND)
localdata = torch.tensor([WORLD_RANK], dtype=torch.uint8, device=BOOTSTRAP_DEVICE)
globaldata = torch.empty(WORLD_SIZE, dtype=torch.uint8, device=BOOTSTRAP_DEVICE)
dist.all_gather(list(globaldata.chunk(WORLD_SIZE)), localdata, bootstrap_world)
reference = torch.tensor(list(range(WORLD_SIZE)), dtype=torch.uint8, device=BOOTSTRAP_DEVICE)
assert torch.equal(globaldata, reference)
```

In order to use comm+GEMM overlap, your platform needs to be able to run this code snippet with both BOOTSTRAP_BACKEND=gloo and BOOTSTRAP_BACKEND=nccl. If you can get this running, then the example problem should be able to bootstrap Userbuffers on your nodes as well.
Hi @denera, I can run your snippet, but I cannot run ln_mlp_with_overlap.py with two nodes.

snippet:
- one node, nccl: can run (export BOOTSTRAP_BACKEND=nccl)
- one node, gloo: can run (export GLOO_SOCKET_IFNAME=bond1)
- two nodes, nccl: can run (export BOOTSTRAP_BACKEND=nccl)
- two nodes, gloo: can run (export GLOO_SOCKET_IFNAME=bond1)

ln_mlp_with_overlap.py:
- one node, nccl: can run (export BOOTSTRAP_BACKEND=nccl)
- one node, gloo: can run (export GLOO_SOCKET_IFNAME=bond1)
- two nodes, nccl: cannot run
  torchrun --nproc_per_node 8 --nnodes 2 --node_rank 0 --master_addr $MASTER_ADDR --master_port 60000 examples/pytorch/comm_gemm_overlap/ln_mlp_with_overlap.py --num-iters=1000 --tcp-init --bootstrap-backend nccl
  !!! [NVTE] Number of physical nodes: 1
- two nodes, gloo: cannot run (export GLOO_SOCKET_IFNAME=bond1)
  !!! [NVTE] Bootstrapping Userbuffers with backend="gloo"

I checked the code (https://github.com/denera/TransformerEngine/blob/userbuffers-missing-data-parallel-pg/transformer_engine/pytorch/module/base.py#L128) and found that socket.gethostname() returns the same result on both nodes in my environment, so the detected local_size is 16.
The UDS (Unix Domain Socket) error you’re seeing is coming from the CUDA Multicast handle initialization. Userbuffers bootstrapping needs to communicate CUDA Multicast handles between processes, but these handles are POSIX file descriptors that have to be communicated over Unix Domain Sockets in order for the kernel to reconstruct the descriptors correctly on every process. Trying to do this with comm libraries like MPI or NCCL mangles the descriptors and prevents processes from importing each other’s Multicast handles. It looks like these Unix Domain Sockets aren’t working correctly on your nodes. Are there any limitations on your node(s) or permission issues that may be causing this? I will also try to provide a minimal C++ tester to possibly help diagnose it without TE in the mix. In the meantime, please disable Multicast with UB_SKIPMC=1.
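To make the constraint above concrete, here is a minimal sketch of file-descriptor passing over a Unix Domain Socket (this is not TE's actual bootstrap code; it assumes Linux, Python 3.9+, and uses /etc/hostname as an arbitrary example file):

```python
import os
import socket

# Pair of connected Unix Domain Sockets shared between parent and child.
parent_sock, child_sock = socket.socketpair(socket.AF_UNIX, socket.SOCK_DGRAM)

pid = os.fork()
if pid == 0:
    # Child process: receive the descriptor via SCM_RIGHTS and use it.
    msg, fds, flags, addr = socket.recv_fds(child_sock, 1024, 1)
    with os.fdopen(fds[0], "r") as f:
        print("child read:", f.read().strip())
    os._exit(0)
else:
    # Parent process: open a file and ship its descriptor to the child.
    # Sending the raw integer f.fileno() over MPI/NCCL instead would be
    # meaningless in the other process; only the UDS ancillary message
    # lets the kernel duplicate the descriptor on the receiving side.
    with open("/etc/hostname", "r") as f:
        socket.send_fds(parent_sock, [b"fd"], [f.fileno()])
    os.waitpid(pid, 0)
```

The same mechanism is what has to work on every node for the CUDA Multicast handle exchange to succeed.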
Hi @denera, here is the log:
!!! [NVTE] Bootstrapping Userbuffers with backend="gloo"
Revisiting an issue from earlier:
I'm guessing this is a consequence of a containerized cluster environment like Kubernetes, correct? The nodes are probably reachable by IP address but not by hostname. Can you try replacing base.py lines 127-128 with the following?

```python
hostname = socket.gethostname()
ifname = os.getenv("NVTE_UB_SOCKET_IFNAME",
os.getenv("NCCL_SOCKET_IFNAME",
os.getenv("GLOO_SOCKET_IFNAME")))
if ifname is not None:
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
try:
hostname = socket.inet_ntoa(
fcntl.ioctl(
s.fileno(),
0x8915,
struct.pack('256s', ifname[:15].encode("UTF-8"))
)[20:24]
)
except OSError as err:
raise OSError(f"Invalid network interface: {ifname}") from err
hostnames = [None for _ in range(world_size)]
```

This will attempt to construct a list of global ranks on each physical node via the IP address on the specified network interface. You already run with GLOO_SOCKET_IFNAME=bond1, so this should pick up that interface.

Edit: I updated PR #986 with this change, and it should automatically kick in whenever you run with NVTE_UB_SOCKET_IFNAME, NCCL_SOCKET_IFNAME, or GLOO_SOCKET_IFNAME set in the environment.
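For illustration, here is a rough, self-contained sketch (not the actual base.py code; the variable names and the Gloo backend are assumptions) of how a per-rank identifier like the hostname or interface IP above gets all-gathered and used to group ranks into physical nodes, which is why identical gethostname() results collapse 2x8 ranks into "Number of physical nodes: 1" with local_size 16:

```python
import os
import socket
import torch.distributed as dist

world_rank = int(os.getenv("RANK", "0"))
world_size = int(os.getenv("WORLD_SIZE", "1"))
dist.init_process_group(backend="gloo", rank=world_rank, world_size=world_size)

# One identifier per rank; with the ioctl-based patch above this would be
# the IP address of the chosen interface instead of the bare hostname.
hostname = socket.gethostname()
hostnames = [None for _ in range(world_size)]
dist.all_gather_object(hostnames, hostname)

# Ranks that report the same identifier are treated as one physical node.
unique_hosts = list(dict.fromkeys(hostnames))
num_nodes = len(unique_hosts)
intra_node_ranks = [r for r, name in enumerate(hostnames) if name == hostname]
local_rank = intra_node_ranks.index(world_rank)
local_size = len(intra_node_ranks)
print(f"rank {world_rank}: nodes={num_nodes} local_rank={local_rank} local_size={local_size}")
```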
Hi @denera, thanks for your reply. Because my torch version is 2.1, the example fails when I set num_replicas=2. Can you update the example to fix the problem?
Hi @liuhatry -- I recently merged PR #986 into TE/main after confirming that it resolves multi-node issues for us in NeMo and Mcore. These changes also update the example problem to no longer use a device mesh to handle replicas in single-node runs, so it should support older PyTorch versions. Could you please test TE/main and let me know if it resolves the issue for you?
Hi @denera, I met a new error when running on two nodes: the intra-node barrier hangs. I modified the code (the new_group() calls) both in the example and in TE like this:
And now I can run the example successfully. The problem is that the new_group() function requires all processes to enter the call, even if they are not going to be members of the group.
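A minimal sketch of the pattern being described (the variable names are illustrative, and it assumes torchrun launches with 8 ranks per node exposed via LOCAL_WORLD_SIZE): every rank calls torch.distributed.new_group() for every node's rank list, and each rank only uses the group it belongs to.

```python
import os
import torch
import torch.distributed as dist

world_rank = int(os.getenv("RANK", "0"))
world_size = int(os.getenv("WORLD_SIZE", "1"))
local_rank = int(os.getenv("LOCAL_RANK", "0"))
ranks_per_node = int(os.getenv("LOCAL_WORLD_SIZE", "8"))  # assumption for illustration

torch.cuda.set_device(local_rank)
dist.init_process_group(backend="nccl", rank=world_rank, world_size=world_size)

# new_group() is collective over the default group: ALL ranks must call it with
# the same `ranks` lists, in the same order, even ranks that will not be members.
intra_node_groups = []
for node in range(world_size // ranks_per_node):
    ranks = list(range(node * ranks_per_node, (node + 1) * ranks_per_node))
    intra_node_groups.append(dist.new_group(ranks=ranks, backend="nccl"))

# Each rank only ever *uses* the group it actually belongs to.
my_group = intra_node_groups[world_rank // ranks_per_node]
dist.barrier(group=my_group, device_ids=[local_rank])
```

If only the member ranks enter new_group(), the non-member ranks never participate in the collective group creation and the subsequent intra-node barrier hangs, which matches the behavior described above.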
Hi @denera, can you please help confirm this issue? Thanks.
Hi @liuhatry -- I've reproduced the issue with TE/main, and I'm able to resolve it with a change along the lines of what you described. I would also strongly recommend updating PyTorch and NCCL to the latest available versions when initializing the default NCCL process group in PyTorch.
Hi @liuhatry -- I filed a PR with a fix for this issue. Could you confirm if it works for you? Thanks!
@liuhatry -- thanks for confirming. I merged the PR, so TE/main should now have all the fixes we've discussed here. Please feel free to close the issue if everything is resolved on your end. Thanks!
Machine
NVIDIA-SMI 535.161.08 Driver Version: 535.161.08 CUDA Version: 12.2
Software
torch 2.1.1
transformer-engine 1.9.0.dev0+56e0b35
Run Cmd:
deepspeed --hostfile hostfile --master_addr ${MASTER_IP} pretrain_gpt.py --deepspeed-activation-checkpointing --deepspeed_config=ds_config_gpt_test.json --deepspeed --tensor-model-parallel-size 4 --pipeline-model-parallel-size 1 ......
LOG
tp_size=8, world_size=8
!!! [UB] Create UbufP2PCommOverlap Communicator
UB_TIMEOUT is set to 110 sec, 217800000000 cycles, freq: 1980000khz
NCCL_TOPO_AFFINITY set by environment to 0
MC initialized succesfully, window size = 549755813888
!!! [UBP2P] Register UBuf 1
!!! [UBP2P] Register UBuf 2
!!! [UBP2P] Register UBuf 3
!!! [UBP2P] Register UBuf 4
!!! [UB] Register UBuf 5
!!! [UB] Register UBuf 6
!!! [UB] Register UBuf 7
!!! [UB] Register UBuf 8
!!! [UB] Register UBuf 9
!!! [UB] Register UBuf 10
rank 7 | iteration 1/ 45776 | consumed samples: 128 | consumed tokens: 262144 | elapsed time this iteration (ms): 33222.7 |
tp_size=4, world_size=8
Failed, NCCL error TransformerEngine_official/transformer_engine/pytorch/csrc/userbuffers/userbuffers-host.cpp:223 ''
!!! [UB] Create UbufP2PCommOverlap Communicator
UB_TIMEOUT is set to 110 sec, 217800000000 cycles, freq: 1980000khz
Failed, NCCL error TransformerEngine_official/transformer_engine/pytorch/csrc/userbuffers/userbuffers-host.cpp:223 ''
Failed, NCCL error TransformerEngine_official/transformer_engine/pytorch/csrc/userbuffers/userbuffers-host.cpp:223 ''
Failed, NCCL error TransformerEngine_official/transformer_engine/pytorch/csrc/userbuffers/userbuffers-host.cpp:223 ''
tp_size=4, world_size=8 UB_SKIPMC=1
!!! [UB] Create UbufP2PCommOverlap Communicator
UB_TIMEOUT is set to 110 sec, 217800000000 cycles, freq: 1980000khz
MC NOT initialized and used
NCCL_TOPO_AFFINITY set by environment to 0
NCCL_TOPO_AFFINITY set by environment to 0
UB: warning region 1 size 40 MB registered without MC access
!!! [UBP2P] Register UBuf 1
Failed, NCCL error TransformerEngine_official/transformer_engine/pytorch/csrc/userbuffers/userbuffers-host.cpp:513 ''
Failed, NCCL error TransformerEngine_official/transformer_engine/pytorch/csrc/userbuffers/userbuffers-host.cpp:513 ''
Failed, NCCL error TransformerEngine_official/transformer_engine/pytorch/csrc/userbuffers/userbuffers-host.cpp:513 ''
UB: warning region 2 size 40 MB registered without MC access
!!! [UBP2P] Register UBuf 2
Failed, NCCL error TransformerEngine_official/transformer_engine/pytorch/csrc/userbuffers/userbuffers-host.cpp:513 '