diff --git a/ignite/distributed/utils.py b/ignite/distributed/utils.py
index d078cd26b57..3ec8c05a1da 100644
--- a/ignite/distributed/utils.py
+++ b/ignite/distributed/utils.py
@@ -371,18 +371,7 @@ def all_gather_tensors_with_shapes(
             rank = idist.get_rank()
             ws = idist.get_world_size()
             tensor = torch.randn(rank+1, rank+2)
-            tensors = all_gather_tensors_with_shapes(tensor, [[r+1, r+2] for r in range(ws)], )
-
-            # To gather from a group of processes:
-
-            group = list(range(1, ws))  # Process #0 excluded
-
-            tensors = all_gather_tensors_with_shapes(tensor, [[r+1, r+2] for r in group], group)
-            if rank == 0:
-                # For processes not in the group, the sole tensor of the process is returned in a list.
-                assert tensors == [tensor]
-            else:
-                assert (tensors[rank-1] == tensor).all()
+            tensors = idist.all_gather_tensors_with_shapes(tensor, [[r+1, r+2] for r in range(ws)])
 
     Args:
         tensor: tensor to collect across participating processes.
diff --git a/tests/ignite/distributed/utils/__init__.py b/tests/ignite/distributed/utils/__init__.py
index a27af17c100..1f3ad55dd84 100644
--- a/tests/ignite/distributed/utils/__init__.py
+++ b/tests/ignite/distributed/utils/__init__.py
@@ -312,36 +312,37 @@ def _test_idist_all_gather_tensors_with_shapes(device):
 
 
 def _test_idist_all_gather_tensors_with_shapes_group(device):
-    torch.manual_seed(41)
+    if idist.get_world_size() > 1:
+        torch.manual_seed(41)
 
-    rank = idist.get_rank()
-    ranks = list(range(1, idist.get_world_size()))
-    ws = idist.get_world_size()
-    bnd = idist.backend()
-    if rank in ranks:
-        reference = torch.randn(ws * (ws + 1) // 2, ws * (ws + 3) // 2, ws * (ws + 5) // 2, device=device)
-        rank_tensor = reference[
-            rank * (rank + 1) // 2 : rank * (rank + 1) // 2 + rank + 1,
-            rank * (rank + 3) // 2 : rank * (rank + 3) // 2 + rank + 2,
-            rank * (rank + 5) // 2 : rank * (rank + 5) // 2 + rank + 3,
-        ]
-    else:
-        rank_tensor = torch.tensor([rank], device=device)
-    if bnd in ("horovod"):
-        with pytest.raises(NotImplementedError, match=r"all_gather with group for horovod is not implemented"):
-            tensors = all_gather_tensors_with_shapes(rank_tensor, [[r + 1, r + 2, r + 3] for r in ranks], ranks)
-    else:
-        tensors = all_gather_tensors_with_shapes(rank_tensor, [[r + 1, r + 2, r + 3] for r in ranks], ranks)
+        rank = idist.get_rank()
+        ranks = list(range(1, idist.get_world_size()))
+        ws = idist.get_world_size()
+        bnd = idist.backend()
         if rank in ranks:
-            for r in ranks:
-                r_tensor = reference[
-                    r * (r + 1) // 2 : r * (r + 1) // 2 + r + 1,
-                    r * (r + 3) // 2 : r * (r + 3) // 2 + r + 2,
-                    r * (r + 5) // 2 : r * (r + 5) // 2 + r + 3,
-                ]
-                assert (r_tensor == tensors[r - 1]).all()
+            reference = torch.randn(ws * (ws + 1) // 2, ws * (ws + 3) // 2, ws * (ws + 5) // 2, device=device)
+            rank_tensor = reference[
+                rank * (rank + 1) // 2 : rank * (rank + 1) // 2 + rank + 1,
+                rank * (rank + 3) // 2 : rank * (rank + 3) // 2 + rank + 2,
+                rank * (rank + 5) // 2 : rank * (rank + 5) // 2 + rank + 3,
+            ]
+        else:
+            rank_tensor = torch.tensor([rank], device=device)
+        if bnd in ("horovod"):
+            with pytest.raises(NotImplementedError, match=r"all_gather with group for horovod is not implemented"):
+                tensors = all_gather_tensors_with_shapes(rank_tensor, [[r + 1, r + 2, r + 3] for r in ranks], ranks)
         else:
-            assert [rank_tensor] == tensors
+            tensors = all_gather_tensors_with_shapes(rank_tensor, [[r + 1, r + 2, r + 3] for r in ranks], ranks)
+            if rank in ranks:
+                for r in ranks:
+                    r_tensor = reference[
+                        r * (r + 1) // 2 : r * (r + 1) // 2 + r + 1,
+                        r * (r + 3) // 2 : r * (r + 3) // 2 + r + 2,
+                        r * (r + 5) // 2 : r * (r + 5) // 2 + r + 3,
+                    ]
+                    assert (r_tensor == tensors[r - 1]).all()
+            else:
+                assert [rank_tensor] == tensors
 
 
 def _test_distrib_broadcast(device):
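
For reference, a minimal runnable sketch of the usage kept in the updated docstring (not part of the patch): it assumes all_gather_tensors_with_shapes is exposed on the idist namespace, as the new docstring line implies, and the two-process "gloo" group spawned via idist.spawn together with the _worker helper are illustrative only.

    import torch

    import ignite.distributed as idist


    def _worker(local_rank):
        rank = idist.get_rank()
        ws = idist.get_world_size()
        # Each rank contributes a tensor with a rank-dependent shape: (rank+1, rank+2).
        tensor = torch.randn(rank + 1, rank + 2)
        # Every rank passes the full list of per-rank shapes, so each process can
        # size its receive buffers before the collective runs.
        tensors = idist.all_gather_tensors_with_shapes(tensor, [[r + 1, r + 2] for r in range(ws)])
        # The result is one tensor per rank, in rank order.
        assert len(tensors) == ws
        assert (tensors[rank] == tensor).all()


    if __name__ == "__main__":
        idist.spawn("gloo", _worker, args=(), nproc_per_node=2)

Gathering over a sub-group (the optional third argument) raises NotImplementedError on the horovod backend, which is what the rewritten test above continues to exercise via pytest.raises.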