Helper function all_gather_tensors_with_shapes() #3281

Merged
70 changes: 69 additions & 1 deletion ignite/distributed/utils.py
@@ -1,9 +1,11 @@
import itertools
import socket
from contextlib import contextmanager
from functools import wraps
from typing import Any, Callable, List, Mapping, Optional, Tuple, Union
from typing import Any, Callable, List, Mapping, Optional, Sequence, Tuple, Union

import torch
from torch import distributed as dist

from ignite.distributed.comp_models import (
_SerialModel,
@@ -43,6 +45,7 @@
"one_rank_only",
"new_group",
"one_rank_first",
"all_gather_tensors_with_shapes",
]

_model = _SerialModel()
@@ -350,6 +353,71 @@ def all_reduce(
    return _model.all_reduce(tensor, op, group=group)


def all_gather_tensors_with_shapes(
    tensor: torch.Tensor, shapes: Sequence[Sequence[int]], group: Optional[Union[Any, List[int]]] = None
) -> List[torch.Tensor]:
"""Helper method to gather tensors of possibly different shapes but with the same number of dimensions
across processes.

This function gets the shapes of participating tensors as input so you should know them beforehand. If your
tensors are of different number of dimensions or you don't know their shapes beforehand, you can use
``torch.distributed.all_gather_object``, otherwise this method is quite faster.

Examples:
.. code-block:: python

import ignite.distributed as idist

rank = idist.get_rank()
ws = idist.get_world_size()
tensor = torch.randn(rank+1, rank+2)
tensors = all_gather_tensors_with_shapes(tensor, [[r+1, r+2] for r in range(ws)], )
vfdev-5 marked this conversation as resolved.
Show resolved Hide resolved

# To gather from a group of processes:

group = list(range(1, ws)) # Process #0 excluded

tensors = all_gather_tensors_with_shapes(tensor, [[r+1, r+2] for r in group], group)
vfdev-5 marked this conversation as resolved.
Show resolved Hide resolved
vfdev-5 marked this conversation as resolved.
Show resolved Hide resolved
if rank == 0:
# For processes not in the group, the sole tensor of the process is returned in a list.
assert tensors == [tensor]
else:
assert (tensors[rank-1] == tensor).all()

Args:
tensor: tensor to collect across participating processes.
shapes: A sequence containing the shape of participating processes' ``tensor`` s.
group: list of integer or the process group for each backend. If None, the default process group will be used.

Returns:
List[torch.Tensor]
"""
    if _need_to_sync and isinstance(_model, _SerialModel):
        sync(temporary=True)

    if isinstance(group, list) and all(isinstance(item, int) for item in group):
        group = _model.new_group(group)

    if isinstance(_model, _SerialModel) or group == dist.GroupMember.NON_GROUP_MEMBER:
        return [tensor]

    max_shape = torch.tensor(shapes).amax(dim=0)
Collaborator: I wonder whether we could actually get the tensor shapes using all_gather, so that the shapes argument could be made optional?

Collaborator (author): Yes, we can. Do you want it in this PR?

Collaborator: Up to you; if you would like to do it in another PR, that is OK with me as well.

Collaborator: Let's do it in another PR. I'll merge this one as CI is green.
    padding_sizes = (max_shape - torch.tensor(tensor.shape)).tolist()
    padded_tensor = torch.nn.functional.pad(
        tensor, tuple(itertools.chain.from_iterable(map(lambda dim_size: (0, dim_size), reversed(padding_sizes))))
    )
    all_padded_tensors: torch.Tensor = _model.all_gather(padded_tensor, group=group)
    # all_gather concatenates the padded tensors along dim 0, so rank r's block starts at r * max_shape[0];
    # slice every block back to that rank's original (unpadded) shape.
    return [
        all_padded_tensors[
            [
                slice(rank * max_shape[0] if dim == 0 else 0, rank * max_shape[0] + dim_size if dim == 0 else dim_size)
                for dim, dim_size in enumerate(shape)
            ]
        ]
        for rank, shape in enumerate(shapes)
    ]


def all_gather(
    tensor: Union[torch.Tensor, float, str, Any], group: Optional[Union[Any, List[int]]] = None
) -> Union[torch.Tensor, float, List[float], List[str], List[Any]]:
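Before moving on to the tests, here is a minimal single-process sketch of the pad-then-slice idea used by the new helper above (illustrative only; the shapes are made up and no process group is involved). Padding everything to a common shape keeps the operation a single tensor collective, which is what makes the helper faster than ``all_gather_object``.

    import itertools
    import torch

    shapes = [[1, 2], [2, 3], [3, 4]]             # the per-rank shapes
    tensors = [torch.randn(*s) for s in shapes]   # what each rank would hold locally
    max_shape = torch.tensor(shapes).amax(dim=0)  # elementwise maximum over all shapes

    padded = []
    for t in tensors:
        pad_sizes = (max_shape - torch.tensor(t.shape)).tolist()
        # torch.nn.functional.pad expects (before, after) pairs starting from the last dimension
        pad = tuple(itertools.chain.from_iterable((0, p) for p in reversed(pad_sizes)))
        padded.append(torch.nn.functional.pad(t, pad))

    # all_gather concatenates the padded tensors along dim 0; emulate that with torch.cat,
    # then slice each "rank"'s block back to its original shape.
    stacked = torch.cat(padded, dim=0)
    max0 = int(max_shape[0])
    recovered = [stacked[r * max0 : r * max0 + s[0], : s[1]] for r, s in enumerate(shapes)]
    assert all((a == b).all() for a, b in zip(recovered, tensors))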
55 changes: 54 additions & 1 deletion tests/ignite/distributed/utils/__init__.py
@@ -3,7 +3,7 @@
import torch.distributed as dist

import ignite.distributed as idist
from ignite.distributed.utils import sync
from ignite.distributed.utils import all_gather_tensors_with_shapes, sync
from ignite.engine import Engine, Events


@@ -291,6 +291,59 @@ def _test_distrib_all_gather_group(device)
res = idist.all_gather(t, group="abc")


def _test_idist_all_gather_tensors_with_shapes(device):
    torch.manual_seed(41)
    rank = idist.get_rank()
    ws = idist.get_world_size()
    # All ranks draw the same reference tensor (same seed); rank r owns a distinct (r + 1, r + 2, r + 3) block of it.
    reference = torch.randn(ws * (ws + 1) // 2, ws * (ws + 3) // 2, ws * (ws + 5) // 2, device=device)
    rank_tensor = reference[
        rank * (rank + 1) // 2 : rank * (rank + 1) // 2 + rank + 1,
        rank * (rank + 3) // 2 : rank * (rank + 3) // 2 + rank + 2,
        rank * (rank + 5) // 2 : rank * (rank + 5) // 2 + rank + 3,
    ]
    tensors = all_gather_tensors_with_shapes(rank_tensor, [[r + 1, r + 2, r + 3] for r in range(ws)])
    for r in range(ws):
        r_tensor = reference[
            r * (r + 1) // 2 : r * (r + 1) // 2 + r + 1,
            r * (r + 3) // 2 : r * (r + 3) // 2 + r + 2,
            r * (r + 5) // 2 : r * (r + 5) // 2 + r + 3,
        ]
        assert (r_tensor == tensors[r]).all()


def _test_idist_all_gather_tensors_with_shapes_group(device):
    torch.manual_seed(41)

    rank = idist.get_rank()
    ranks = list(range(1, idist.get_world_size()))
    ws = idist.get_world_size()
    bnd = idist.backend()
    if rank in ranks:
        reference = torch.randn(ws * (ws + 1) // 2, ws * (ws + 3) // 2, ws * (ws + 5) // 2, device=device)
        rank_tensor = reference[
            rank * (rank + 1) // 2 : rank * (rank + 1) // 2 + rank + 1,
            rank * (rank + 3) // 2 : rank * (rank + 3) // 2 + rank + 2,
            rank * (rank + 5) // 2 : rank * (rank + 5) // 2 + rank + 3,
        ]
    else:
        # Rank 0 is excluded from the group and only holds a dummy tensor.
        rank_tensor = torch.tensor([rank], device=device)
    if bnd in ("horovod",):
        with pytest.raises(NotImplementedError, match=r"all_gather with group for horovod is not implemented"):
            tensors = all_gather_tensors_with_shapes(rank_tensor, [[r + 1, r + 2, r + 3] for r in ranks], ranks)
    else:
        tensors = all_gather_tensors_with_shapes(rank_tensor, [[r + 1, r + 2, r + 3] for r in ranks], ranks)
        if rank in ranks:
            for r in ranks:
                r_tensor = reference[
                    r * (r + 1) // 2 : r * (r + 1) // 2 + r + 1,
                    r * (r + 3) // 2 : r * (r + 3) // 2 + r + 2,
                    r * (r + 5) // 2 : r * (r + 5) // 2 + r + 3,
                ]
                assert (r_tensor == tensors[r - 1]).all()
        else:
            assert [rank_tensor] == tensors


def _test_distrib_broadcast(device):
rank = idist.get_rank()
ws = idist.get_world_size()
4 changes: 4 additions & 0 deletions tests/ignite/distributed/utils/test_horovod.py
@@ -17,6 +17,8 @@
_test_distrib_new_group,
_test_distrib_one_rank_only,
_test_distrib_one_rank_only_with_engine,
_test_idist_all_gather_tensors_with_shapes,
_test_idist_all_gather_tensors_with_shapes_group,
_test_sync,
)

@@ -163,6 +165,8 @@ def test_idist_all_gather_hvd(gloo_hvd_executor):
    np = 4 if not torch.cuda.is_available() else torch.cuda.device_count()
    gloo_hvd_executor(_test_distrib_all_gather, (device,), np=np, do_init=True)
    gloo_hvd_executor(_test_distrib_all_gather_group, (device,), np=np, do_init=True)
    gloo_hvd_executor(_test_idist_all_gather_tensors_with_shapes, (device,), np=np, do_init=True)
    gloo_hvd_executor(_test_idist_all_gather_tensors_with_shapes_group, (device,), np=np, do_init=True)


@pytest.mark.distributed
19 changes: 19 additions & 0 deletions tests/ignite/distributed/utils/test_native.py
@@ -19,6 +19,8 @@
_test_distrib_new_group,
_test_distrib_one_rank_only,
_test_distrib_one_rank_only_with_engine,
_test_idist_all_gather_tensors_with_shapes,
_test_idist_all_gather_tensors_with_shapes_group,
_test_sync,
)

@@ -253,6 +255,23 @@ def test_idist_all_gather_gloo(distributed_context_single_node_gloo):
    _test_distrib_all_gather_group(device)


@pytest.mark.distributed
@pytest.mark.skipif(not has_native_dist_support, reason="Skip if no native dist support")
@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU")
def test_idist_all_gather_tensors_with_shapes_nccl(distributed_context_single_node_nccl):
    device = idist.device()
    _test_idist_all_gather_tensors_with_shapes(device)
    _test_idist_all_gather_tensors_with_shapes_group(device)


@pytest.mark.distributed
@pytest.mark.skipif(not has_native_dist_support, reason="Skip if no native dist support")
def test_idist_all_gather_tensors_with_shapes_gloo(distributed_context_single_node_gloo):
    device = idist.device()
    _test_idist_all_gather_tensors_with_shapes(device)
    _test_idist_all_gather_tensors_with_shapes_group(device)


@pytest.mark.distributed
@pytest.mark.skipif(not has_native_dist_support, reason="Skip if no native dist support")
@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU")
3 changes: 3 additions & 0 deletions tests/ignite/distributed/utils/test_serial.py
@@ -10,6 +10,7 @@
_test_distrib_barrier,
_test_distrib_broadcast,
_test_distrib_new_group,
_test_idist_all_gather_tensors_with_shapes,
_test_sync,
)

@@ -70,13 +71,15 @@ def test_idist__model_methods_no_dist():
def test_idist_collective_ops_no_dist():
    _test_distrib_all_reduce("cpu")
    _test_distrib_all_gather("cpu")
    _test_idist_all_gather_tensors_with_shapes("cpu")
    _test_distrib_barrier("cpu")
    _test_distrib_broadcast("cpu")
    _test_distrib_new_group("cpu")

    if torch.cuda.device_count() > 1:
        _test_distrib_all_reduce("cuda")
        _test_distrib_all_gather("cuda")
        _test_idist_all_gather_tensors_with_shapes("cuda")
        _test_distrib_barrier("cuda")
        _test_distrib_broadcast("cuda")
        _test_distrib_new_group("cuda")
4 changes: 4 additions & 0 deletions tests/ignite/distributed/utils/test_xla.py
@@ -15,6 +15,8 @@
_test_distrib_new_group,
_test_distrib_one_rank_only,
_test_distrib_one_rank_only_with_engine,
_test_idist_all_gather_tensors_with_shapes,
_test_idist_all_gather_tensors_with_shapes_group,
_test_sync,
)

@@ -151,6 +153,8 @@ def test_idist_all_gather_xla():
    device = idist.device()
    _test_distrib_all_gather(device)
    _test_distrib_all_gather_group(device)
    _test_idist_all_gather_tensors_with_shapes(device)
    _test_idist_all_gather_tensors_with_shapes_group(device)


def _test_idist_all_gather_xla_in_child_proc(index):