Commit 538f0c0: Fix tests
sadra-barikbin committed Sep 4, 2024
1 parent 0d8eb3b
Showing 5 changed files with 26 additions and 26 deletions.
tests/ignite/distributed/utils/__init__.py (7 additions & 7 deletions)
@@ -291,7 +291,7 @@ def _test_distrib_all_gather_group(device):
             res = idist.all_gather(t, group="abc")


-def _test_idist_all_gather_tensors_with_different_shapes(device):
+def _test_idist_all_gather_tensors_with_shapes(device):
     torch.manual_seed(41)
     rank = idist.get_rank()
     ws = idist.get_world_size()
@@ -311,7 +311,7 @@ def _test_idist_all_gather_tensors_with_different_shapes(device):
         assert (r_tensor == tensors[r]).all()


-def _test_idist_all_gather_tensors_with_different_shapes_group(device):
+def _test_idist_all_gather_tensors_with_shapes_group(device):
     torch.manual_seed(41)

     rank = idist.get_rank()
@@ -332,16 +332,16 @@ def _test_idist_all_gather_tensors_with_different_shapes_group(device):
             tensors = all_gather_tensors_with_shapes(rank_tensor, [[r + 1, r + 2, r + 3] for r in ranks], ranks)
     else:
         tensors = all_gather_tensors_with_shapes(rank_tensor, [[r + 1, r + 2, r + 3] for r in ranks], ranks)
-    for r in range(ws):
-        if r in ranks:
+    if rank in ranks:
+        for r in ranks:
             r_tensor = reference[
                 r * (r + 1) // 2 : r * (r + 1) // 2 + r + 1,
                 r * (r + 3) // 2 : r * (r + 3) // 2 + r + 2,
                 r * (r + 5) // 2 : r * (r + 5) // 2 + r + 3,
             ]
-            assert (r_tensor == tensors[r]).all()
-        else:
-            assert tensors == rank_tensor
+            assert (r_tensor == tensors[r - 1]).all()
+    else:
+        assert [rank_tensor] == tensors


 def _test_distrib_broadcast(device):
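Note on the logic fix above: the `r - 1` indexing only lines up if `ranks` excludes rank 0 (e.g. `ranks = [1, ..., ws - 1]`; the construction of `ranks` sits outside this hunk, so that is an assumption). Under that reading, rank `r`'s slab of the reference tensor lands at position `r - 1` of the gathered list, while a rank outside the group gets its input tensor back as a one-element list. A minimal sketch of the behaviour the new assertions encode, using a hypothetical helper `check_group_gather` meant to run under a distributed launcher with ws > 1:

import torch
import ignite.distributed as idist
from ignite.distributed.utils import all_gather_tensors_with_shapes

def check_group_gather():
    rank = idist.get_rank()
    ws = idist.get_world_size()
    ranks = list(range(1, ws))  # assumed group: every rank except 0
    local = torch.randn(rank + 1, rank + 2, rank + 3)
    shapes = [[r + 1, r + 2, r + 3] for r in ranks]
    tensors = all_gather_tensors_with_shapes(local, shapes, ranks)
    if rank in ranks:
        # group members receive one tensor per member; rank r's tensor
        # sits at index r - 1 because rank 0 is not in the group
        assert len(tensors) == len(ranks)
        assert (tensors[rank - 1] == local).all()
    else:
        # a rank outside the group gets its own tensor back in a list
        assert len(tensors) == 1 and (tensors[0] == local).all()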
tests/ignite/distributed/utils/test_horovod.py (4 additions & 4 deletions)
@@ -17,8 +17,8 @@
     _test_distrib_new_group,
     _test_distrib_one_rank_only,
     _test_distrib_one_rank_only_with_engine,
-    _test_idist_all_gather_tensors_with_different_shapes,
-    _test_idist_all_gather_tensors_with_different_shapes_group,
+    _test_idist_all_gather_tensors_with_shapes,
+    _test_idist_all_gather_tensors_with_shapes_group,
     _test_sync,
 )

@@ -165,8 +165,8 @@ def test_idist_all_gather_hvd(gloo_hvd_executor):
     np = 4 if not torch.cuda.is_available() else torch.cuda.device_count()
     gloo_hvd_executor(_test_distrib_all_gather, (device,), np=np, do_init=True)
     gloo_hvd_executor(_test_distrib_all_gather_group, (device,), np=np, do_init=True)
-    gloo_hvd_executor(_test_idist_all_gather_tensors_with_different_shapes, (device,), np=np, do_init=True)
-    gloo_hvd_executor(_test_idist_all_gather_tensors_with_different_shapes_group, (device,), np=np, do_init=True)
+    gloo_hvd_executor(_test_idist_all_gather_tensors_with_shapes, (device,), np=np, do_init=True)
+    gloo_hvd_executor(_test_idist_all_gather_tensors_with_shapes_group, (device,), np=np, do_init=True)


 @pytest.mark.distributed
tests/ignite/distributed/utils/test_native.py (8 additions & 8 deletions)
@@ -19,8 +19,8 @@
     _test_distrib_new_group,
     _test_distrib_one_rank_only,
     _test_distrib_one_rank_only_with_engine,
-    _test_idist_all_gather_tensors_with_different_shapes,
-    _test_idist_all_gather_tensors_with_different_shapes_group,
+    _test_idist_all_gather_tensors_with_shapes,
+    _test_idist_all_gather_tensors_with_shapes_group,
     _test_sync,
 )

@@ -258,18 +258,18 @@ def test_idist_all_gather_gloo(distributed_context_single_node_gloo):
 @pytest.mark.distributed
 @pytest.mark.skipif(not has_native_dist_support, reason="Skip if no native dist support")
 @pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU")
-def test_idist_all_gather_tensors_with_different_shapes_nccl(distributed_context_single_node_nccl):
+def test_idist_all_gather_tensors_with_shapes_nccl(distributed_context_single_node_nccl):
     device = idist.device()
-    _test_idist_all_gather_tensors_with_different_shapes(device)
-    _test_idist_all_gather_tensors_with_different_shapes_group(device)
+    _test_idist_all_gather_tensors_with_shapes(device)
+    _test_idist_all_gather_tensors_with_shapes_group(device)


 @pytest.mark.distributed
 @pytest.mark.skipif(not has_native_dist_support, reason="Skip if no native dist support")
-def test_idist_all_gather_tensors_with_different_shapes_gloo(distributed_context_single_node_gloo):
+def test_idist_all_gather_tensors_with_shapes_gloo(distributed_context_single_node_gloo):
     device = idist.device()
-    _test_idist_all_gather_tensors_with_different_shapes(device)
-    _test_idist_all_gather_tensors_with_different_shapes_group(device)
+    _test_idist_all_gather_tensors_with_shapes(device)
+    _test_idist_all_gather_tensors_with_shapes_group(device)


 @pytest.mark.distributed
tests/ignite/distributed/utils/test_serial.py (3 additions & 3 deletions)
@@ -10,7 +10,7 @@
     _test_distrib_barrier,
     _test_distrib_broadcast,
     _test_distrib_new_group,
-    _test_idist_all_gather_tensors_with_different_shapes,
+    _test_idist_all_gather_tensors_with_shapes,
     _test_sync,
 )

@@ -71,15 +71,15 @@ def test_idist__model_methods_no_dist():
 def test_idist_collective_ops_no_dist():
     _test_distrib_all_reduce("cpu")
     _test_distrib_all_gather("cpu")
-    _test_idist_all_gather_tensors_with_different_shapes("cpu")
+    _test_idist_all_gather_tensors_with_shapes("cpu")
     _test_distrib_barrier("cpu")
     _test_distrib_broadcast("cpu")
     _test_distrib_new_group("cpu")

     if torch.cuda.device_count() > 1:
         _test_distrib_all_reduce("cuda")
         _test_distrib_all_gather("cuda")
-        _test_idist_all_gather_tensors_with_different_shapes("cuda")
+        _test_idist_all_gather_tensors_with_shapes("cuda")
         _test_distrib_barrier("cuda")
         _test_distrib_broadcast("cuda")
         _test_distrib_new_group("cuda")
tests/ignite/distributed/utils/test_xla.py (4 additions & 4 deletions)
@@ -15,8 +15,8 @@
     _test_distrib_new_group,
     _test_distrib_one_rank_only,
     _test_distrib_one_rank_only_with_engine,
-    _test_idist_all_gather_tensors_with_different_shapes,
-    _test_idist_all_gather_tensors_with_different_shapes_group,
+    _test_idist_all_gather_tensors_with_shapes,
+    _test_idist_all_gather_tensors_with_shapes_group,
     _test_sync,
 )

@@ -153,8 +153,8 @@ def test_idist_all_gather_xla():
     device = idist.device()
     _test_distrib_all_gather(device)
     _test_distrib_all_gather_group(device)
-    _test_idist_all_gather_tensors_with_different_shapes(device)
-    _test_idist_all_gather_tensors_with_different_shapes_group(device)
+    _test_idist_all_gather_tensors_with_shapes(device)
+    _test_idist_all_gather_tensors_with_shapes_group(device)


 def _test_idist_all_gather_xla_in_child_proc(index):