🚀 Feature

Following #1307, if the user does not set `torch.cuda.set_device("cuda:lrank")`, ignite's code in `ignite/ignite/distributed/comp_models/native.py` (lines 99 to 102 in 0c41778) will use the same device `cuda:0` for the `all_reduce` op.

For older NCCL, it will set itself up such that the i-th proc uses the `cuda:0` device, and thus the following collective op will hang with other devices. For example:
```python
import torch
import torch.distributed as dist


def main():
    # !!! We do not call torch.cuda.set_device("cuda:lrank")
    dist.init_process_group(backend="nccl", init_method="env://")

    import os
    local_rank = int(os.environ["LOCAL_RANK"])

    tensor = torch.tensor([local_rank + 1]).to("cuda")
    dist.all_reduce(tensor, op=dist.ReduceOp.MAX)
    print(tensor)

    tensor = torch.tensor([local_rank + 1]).to("cuda:{}".format(local_rank))
    # PROGRAM WILL HANG HERE >>>>
    dist.all_reduce(tensor, op=dist.ReduceOp.MAX)
    print(tensor)

    dist.destroy_process_group()


if __name__ == "__main__":
    main()
```
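For reference, the hang disappears once each process pins itself to its own GPU before the first collective call; a minimal sketch of a corrected script, assuming a launcher that sets the `LOCAL_RANK` environment variable (e.g. `torchrun --nproc_per_node=2 main.py`):

```python
import os

import torch
import torch.distributed as dist


def main():
    # Pin this process to its own GPU *before* the first collective op,
    # so NCCL pairs rank i with cuda:i instead of every rank using cuda:0.
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)

    dist.init_process_group(backend="nccl", init_method="env://")

    # "cuda" now resolves to cuda:<local_rank> for this process
    tensor = torch.tensor([local_rank + 1]).to("cuda")
    dist.all_reduce(tensor, op=dist.ReduceOp.MAX)
    print(tensor)

    dist.destroy_process_group()


if __name__ == "__main__":
    main()
```

With 2 processes, each rank should then print `tensor([2])` on its own device.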
With newer NCCL versions, the uncorrected script raises the error reported in #1307 instead of hanging.

Let's improve the code by raising a warning in the native and horovod distributed models when `idist.device()` is called and the current CUDA device index is smaller than the local rank.
The PyTorch docs suggest using 1 proc per 1 CUDA device, i.e. the local rank should be equal to the CUDA device index.
However, it is also possible to have M procs with K devices per proc (e.g. 4 procs with 2 GPUs per proc), in which case local rank <= cuda device index.
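A rough sketch of what such a check could look like (the helper name and the warning message below are placeholders, not ignite's actual API):

```python
import warnings

import torch


def _warn_if_device_mismatch(local_rank: int) -> None:
    # Hypothetical helper: warn when the current CUDA device index is smaller
    # than the local rank, which usually means torch.cuda.set_device(local_rank)
    # was never called, so collective ops may hang (old NCCL) or fail (new NCCL).
    if not torch.cuda.is_available():
        return
    index = torch.cuda.current_device()
    if index < local_rank:
        warnings.warn(
            "Current CUDA device index ({}) is smaller than the local rank ({}). "
            "Consider calling torch.cuda.set_device(local_rank) in each process "
            "before any collective op.".format(index, local_rank)
        )
```

`idist.device()` in the native and horovod comp models could call such a check with the backend's local rank before returning the device.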