
Commit 907dcc4

Fix test and docstring
sadra-barikbin committed Sep 4, 2024
1 parent 9501489 commit 907dcc4
Showing 2 changed files with 29 additions and 39 deletions.
13 changes: 1 addition & 12 deletions ignite/distributed/utils.py
@@ -371,18 +371,7 @@ def all_gather_tensors_with_shapes(
             rank = idist.get_rank()
             ws = idist.get_world_size()
             tensor = torch.randn(rank+1, rank+2)
-            tensors = all_gather_tensors_with_shapes(tensor, [[r+1, r+2] for r in range(ws)], )
-
-            # To gather from a group of processes:
-            group = list(range(1, ws))  # Process #0 excluded
-            tensors = all_gather_tensors_with_shapes(tensor, [[r+1, r+2] for r in group], group)
-            if rank == 0:
-                # For processes not in the group, the sole tensor of the process is returned in a list.
-                assert tensors == [tensor]
-            else:
-                assert (tensors[rank-1] == tensor).all()
+            tensors = idist.all_gather_tensors_with_shapes(tensor, [[r+1, r+2] for r in range(ws)])

     Args:
         tensor: tensor to collect across participating processes.
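For context, the fixed docstring example gathers differently-shaped tensors from all processes. Below is a minimal runnable sketch of that usage; the gloo backend, the two-process launcher, and the final assert are illustrative assumptions, not part of this commit.

    import torch
    import ignite.distributed as idist


    def compute(local_rank: int) -> None:
        rank = idist.get_rank()
        ws = idist.get_world_size()
        # Each rank contributes a tensor of a different shape: (1, 2) on rank 0, (2, 3) on rank 1, ...
        tensor = torch.randn(rank + 1, rank + 2)
        # Since shapes differ across ranks, every participant's shape is declared explicitly.
        tensors = idist.all_gather_tensors_with_shapes(tensor, [[r + 1, r + 2] for r in range(ws)])
        assert (tensors[rank] == tensor).all()


    if __name__ == "__main__":
        # Illustrative launcher: two local processes over the gloo backend.
        with idist.Parallel(backend="gloo", nproc_per_node=2) as parallel:
            parallel.run(compute)

Run as a plain script: idist.Parallel spawns the workers, sets up the process group, and tears it down on exit.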
55 changes: 28 additions & 27 deletions tests/ignite/distributed/utils/__init__.py
@@ -312,36 +312,37 @@ def _test_idist_all_gather_tensors_with_shapes(device):


 def _test_idist_all_gather_tensors_with_shapes_group(device):
-    torch.manual_seed(41)
-
-    rank = idist.get_rank()
-    ranks = list(range(1, idist.get_world_size()))
-    ws = idist.get_world_size()
-    bnd = idist.backend()
-    if rank in ranks:
-        reference = torch.randn(ws * (ws + 1) // 2, ws * (ws + 3) // 2, ws * (ws + 5) // 2, device=device)
-        rank_tensor = reference[
-            rank * (rank + 1) // 2 : rank * (rank + 1) // 2 + rank + 1,
-            rank * (rank + 3) // 2 : rank * (rank + 3) // 2 + rank + 2,
-            rank * (rank + 5) // 2 : rank * (rank + 5) // 2 + rank + 3,
-        ]
-    else:
-        rank_tensor = torch.tensor([rank], device=device)
-    if bnd in ("horovod"):
-        with pytest.raises(NotImplementedError, match=r"all_gather with group for horovod is not implemented"):
-            tensors = all_gather_tensors_with_shapes(rank_tensor, [[r + 1, r + 2, r + 3] for r in ranks], ranks)
-    else:
-        tensors = all_gather_tensors_with_shapes(rank_tensor, [[r + 1, r + 2, r + 3] for r in ranks], ranks)
-        if rank in ranks:
-            for r in ranks:
-                r_tensor = reference[
-                    r * (r + 1) // 2 : r * (r + 1) // 2 + r + 1,
-                    r * (r + 3) // 2 : r * (r + 3) // 2 + r + 2,
-                    r * (r + 5) // 2 : r * (r + 5) // 2 + r + 3,
-                ]
-                assert (r_tensor == tensors[r - 1]).all()
-        else:
-            assert [rank_tensor] == tensors
+    if idist.get_world_size() > 1:
+        torch.manual_seed(41)
+
+        rank = idist.get_rank()
+        ranks = list(range(1, idist.get_world_size()))
+        ws = idist.get_world_size()
+        bnd = idist.backend()
+        if rank in ranks:
+            reference = torch.randn(ws * (ws + 1) // 2, ws * (ws + 3) // 2, ws * (ws + 5) // 2, device=device)
+            rank_tensor = reference[
+                rank * (rank + 1) // 2 : rank * (rank + 1) // 2 + rank + 1,
+                rank * (rank + 3) // 2 : rank * (rank + 3) // 2 + rank + 2,
+                rank * (rank + 5) // 2 : rank * (rank + 5) // 2 + rank + 3,
+            ]
+        else:
+            rank_tensor = torch.tensor([rank], device=device)
+        if bnd in ("horovod"):
+            with pytest.raises(NotImplementedError, match=r"all_gather with group for horovod is not implemented"):
+                tensors = all_gather_tensors_with_shapes(rank_tensor, [[r + 1, r + 2, r + 3] for r in ranks], ranks)
+        else:
+            tensors = all_gather_tensors_with_shapes(rank_tensor, [[r + 1, r + 2, r + 3] for r in ranks], ranks)
+            if rank in ranks:
+                for r in ranks:
+                    r_tensor = reference[
+                        r * (r + 1) // 2 : r * (r + 1) // 2 + r + 1,
+                        r * (r + 3) // 2 : r * (r + 3) // 2 + r + 2,
+                        r * (r + 5) // 2 : r * (r + 5) // 2 + r + 3,
+                    ]
+                    assert (r_tensor == tensors[r - 1]).all()
+            else:
+                assert [rank_tensor] == tensors


 def _test_distrib_broadcast(device):
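A note on the index arithmetic this test relies on: along an axis where rank r owns r + p entries (p = 1, 2, 3 for the three axes), rank r's block starts at the sum of the lower ranks' extents, r * (r - 1 + 2p) // 2, which reproduces the offsets r * (r + 1) // 2, r * (r + 3) // 2, and r * (r + 5) // 2 used above; the blocks of all ranks then tile the reference tensor of extent ws * (ws - 1 + 2p) // 2 exactly. A standalone sanity check of that arithmetic, as a sketch in plain Python (no distributed setup needed; the helper name and world sizes are illustrative):

    def block_start(r: int, p: int) -> int:
        # Offset of rank r's block: sum over ranks k < r of their extents (k + p).
        return r * (r - 1 + 2 * p) // 2


    for ws in range(1, 8):  # example world sizes
        for p in (1, 2, 3):  # the three axes use per-rank extents r+1, r+2, r+3
            total = ws * (ws - 1 + 2 * p) // 2  # reduces to ws * (ws + 1) // 2 for p == 1, etc.
            starts = [block_start(r, p) for r in range(ws)]
            ends = [start + (r + p) for r, start in enumerate(starts)]
            # Consecutive blocks abut exactly, and together they cover the whole axis.
            assert all(ends[r] == starts[r + 1] for r in range(ws - 1))
            assert starts[0] == 0 and ends[-1] == total

Because the per-rank blocks partition the reference tensor, each rank can slice out its own piece, gather everyone's pieces with all_gather_tensors_with_shapes, and compare the result elementwise against the corresponding slices of reference.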
