Skip to content

Commit

Permalink
fix name
Browse files Browse the repository at this point in the history
  • Loading branch information
ForFishes committed Apr 1, 2021
1 parent 5cbe7f6 commit 2b31859
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 12 deletions.
12 changes: 6 additions & 6 deletions python/paddle/distributed/fleet/base/topology.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@


class CommunicateTopology(object):
def __init__(self, hybrid_names, dims):
self._parallel_names = hybrid_names
def __init__(self, hybrid_group_names, dims):
self._parallel_names = hybrid_group_names
self._dims = dims
self.coordinate = collections.namedtuple('Coordinate',
self._parallel_names)
Expand All @@ -35,7 +35,7 @@ def __init__(self, hybrid_names, dims):
self._rank2coord = dict(
zip(self._coord2rank.values(), self._coord2rank.keys()))

def get_parallel_names(self):
def get_hybrid_group_names(self):
return self._parallel_names

def get_dim(self, axis_name):
Expand Down Expand Up @@ -64,7 +64,7 @@ def get_axis_list(self, axis_name, index):
ranks.sort()
return ranks

def get_dim_num(self, axis_name):
def get_dim_size(self, axis_name):
assert axis_name in self._parallel_names
return self._dims[self._parallel_names.index(axis_name)]

Expand All @@ -76,7 +76,7 @@ def get_comm_list(self, axis_name):

ranges = []
for name in other_axis_names:
dim_num = self.get_dim_num(name)
dim_num = self.get_dim_size(name)
ranges.append(range(dim_num))

all_result = []
Expand All @@ -86,7 +86,7 @@ def get_comm_list(self, axis_name):
key_coord[other_name] = x[other_axis_names.index(other_name)]

result = []
for i in range(0, self.get_dim_num(axis_name)):
for i in range(0, self.get_dim_size(axis_name)):
key_coord[axis_name] = i
result.append(self._coord2rank[self.coordinate(**key_coord)])
all_result.append(result)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,10 @@ def test_topology(self):
np.testing.assert_array_equal(mp_comm_list, topo.get_comm_list("mp"))
np.testing.assert_array_equal(pp_comm_list, topo.get_comm_list("pp"))

# test get_parallel_names
# test get_hybrid_group_names
parallel_names = ["dp", "mp", "pp"]
np.testing.assert_array_equal(parallel_names, topo.get_parallel_names())
np.testing.assert_array_equal(parallel_names,
topo.get_hybrid_group_names())

# test get_dims
np.testing.assert_array_equal(2, topo.get_dim("dp"))
Expand Down Expand Up @@ -73,10 +74,10 @@ def test_topology(self):
self.assertEqual(topo.get_axis_list("pp", 0), [0, 2, 4, 6])
self.assertEqual(topo.get_axis_list("pp", 1), [1, 3, 5, 7])

# test get_dim_num
self.assertEqual(topo.get_dim_num("dp"), 2)
self.assertEqual(topo.get_dim_num("mp"), 2)
self.assertEqual(topo.get_dim_num("pp"), 2)
# test get_dim_size
self.assertEqual(topo.get_dim_size("dp"), 2)
self.assertEqual(topo.get_dim_size("mp"), 2)
self.assertEqual(topo.get_dim_size("pp"), 2)


if __name__ == '__main__':
Expand Down

0 comments on commit 2b31859

Please sign in to comment.