Take/Put_along_axis more input size support (#39072)
Support cases where the indices shape is larger than the arr shape.
huangxu96 authored Jan 27, 2022
1 parent 809a10b commit 41a6435
Showing 4 changed files with 131 additions and 60 deletions.
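
For context, a minimal dygraph sketch of the newly supported case, assuming the post-commit API; the tensor values below are illustrative, and NumPy (whose take_along_axis has the same gather semantics) serves as the reference:

import numpy as np
import paddle

# indices is larger than arr along axis 0, the case this commit enables.
arr = paddle.to_tensor(np.arange(4, dtype=np.float32).reshape(2, 2))
indices = paddle.to_tensor(
    np.array([[0, 0], [1, 0], [0, 0], [1, 0]], dtype=np.int64))

out = paddle.take_along_axis(arr, indices, 0)  # out.shape == [4, 2]
ref = np.take_along_axis(arr.numpy(), indices.numpy(), 0)
assert np.allclose(out.numpy(), ref)
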
68 changes: 39 additions & 29 deletions python/paddle/fluid/tests/unittests/test_put_along_axis_op.py
@@ -82,7 +82,7 @@ def setUp(self):
if core.is_compiled_with_cuda():
self.place.append(paddle.CUDAPlace(0))

def test_api_static_case1(self):
def test_api_static(self):
paddle.enable_static()

def run(place):
@@ -110,7 +110,7 @@ def run(place):
for place in self.place:
run(place)

def test_api_dygraph_case1(self):
def test_api_dygraph(self):
def run(place):
paddle.disable_static(place)
x_tensor = paddle.to_tensor(self.x_np)
@@ -137,33 +137,7 @@ def run(place):
for place in self.place:
run(place)

def test_api_dygraph_case2(self):
def run(place):
paddle.disable_static(place)
self.shape = [2, 2]
self.index_shape = [2, 2]
self.index_np = np.array([[0, 0], [1, 0]]).astype('int64')
self.x_np = np.random.random(self.shape).astype(np.float32)

x_tensor = paddle.to_tensor(self.x_np)
index_tensor = paddle.to_tensor(self.index_np)
value_tensor = paddle.to_tensor(self.value_np)
out = paddle.put_along_axis(x_tensor, index_tensor, value_tensor,
self.axis)
np.array(
np.put_along_axis(self.x_np, self.index_np, self.value_np,
self.axis))
out_ref = self.x_np
self.assertEqual(
np.allclose(
out.numpy(), out_ref, rtol=1e-03), True)

paddle.enable_static()

for place in self.place:
run(place)

def test_inplace_dygraph_case3(self):
def test_inplace_dygraph(self):
def run(place):
paddle.disable_static(place)
x_tensor = paddle.to_tensor(self.x_np)
@@ -186,6 +160,42 @@ def run(place):
run(place)


class TestPutAlongAxisAPICase2(TestPutAlongAxisAPI):
def setUp(self):
np.random.seed(0)
self.shape = [2, 2]
self.index_shape = [2, 2]
self.index_np = np.array([[0, 0], [1, 0]]).astype('int64')
self.x_np = np.random.random(self.shape).astype(np.float32)
self.place = [paddle.CPUPlace()]
self.axis = 0
self.value_np = 99.0
self.value_shape = [1]
self.x_feed = copy.deepcopy(self.x_np)
if core.is_compiled_with_cuda():
self.place.append(paddle.CUDAPlace(0))


class TestPutAlongAxisAPICase3(TestPutAlongAxisAPI):
def setUp(self):
np.random.seed(0)
self.shape = [2, 2]
self.index_shape = [4, 2]
self.index_np = np.array(
[[0, 0], [1, 0], [0, 0], [1, 0]]).astype('int64')
self.x_np = np.random.random(self.shape).astype(np.float32)
self.place = [paddle.CPUPlace()]
self.axis = 0
self.value_np = 99.0
self.value_shape = [1]
self.x_feed = copy.deepcopy(self.x_np)
if core.is_compiled_with_cuda():
self.place.append(paddle.CUDAPlace(0))

def test_inplace_dygraph(self):
pass


if __name__ == "__main__":
paddle.enable_static()
unittest.main()
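
A hedged sketch of what TestPutAlongAxisAPICase3 above exercises in dygraph mode, with the same shapes and scalar value as the test:

import numpy as np
import paddle

x = paddle.to_tensor(np.random.random([2, 2]).astype(np.float32))
indices = paddle.to_tensor(
    np.array([[0, 0], [1, 0], [0, 0], [1, 0]], dtype=np.int64))

# indices is larger than x along axis 0; the scalar 99.0 is broadcast to
# indices.shape and scattered into x's positions, so the result keeps
# x's [2, 2] shape with entries (0, 0), (0, 1) and (1, 0) set to 99.0.
out = paddle.put_along_axis(x, indices, 99.0, 0)
print(out.numpy())
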
14 changes: 14 additions & 0 deletions python/paddle/fluid/tests/unittests/test_take_along_axis_op.py
@@ -106,6 +106,20 @@ def test_api_dygraph(self):
paddle.enable_static()


class TestTakeAlongAxisAPICase1(TestTakeAlongAxisAPI):
def setUp(self):
np.random.seed(0)
self.shape = [2, 2]
self.index_shape = [4, 2]
self.index_np = np.array(
[[0, 0], [1, 0], [0, 0], [1, 0]]).astype('int64')
self.x_np = np.random.random(self.shape).astype(np.float32)
self.place = [paddle.CPUPlace()]
self.axis = 0
if core.is_compiled_with_cuda():
self.place.append(paddle.CUDAPlace(0))


if __name__ == "__main__":
paddle.enable_static()
unittest.main()
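
For reference, NumPy's take_along_axis already accepts an index array that is larger than arr along the gather axis, which is the behavior this test pins down; a quick standalone check:

import numpy as np

arr = np.random.random((2, 2)).astype(np.float32)
indices = np.array([[0, 0], [1, 0], [0, 0], [1, 0]], dtype=np.int64)

out = np.take_along_axis(arr, indices, axis=0)
# out[i, j] == arr[indices[i, j], j], so the 2-row input yields 4 rows.
assert out.shape == (4, 2)
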
87 changes: 61 additions & 26 deletions python/paddle/tensor/manipulation.py
@@ -2751,6 +2751,31 @@ def moveaxis(x, source, destination, name=None):
return out


def non_negative_axis(arr, axis):
ndim = len(arr.shape)
if axis >= 0:
assert axis < ndim, "'axis' must be in the range of [-{0}, {0})".format(
ndim)
else:
assert axis >= -ndim, "'axis' must be in the range of [-{0}, {0})".format(
ndim)
axis += ndim

return axis


def infer_broadcast_shape(arr, indices, axis):
# This function is used in take/put_along_axis
broadcast_shape_list = list(arr.shape)
broadcast_shape_list[axis] = list(indices.shape)[axis]
broadcast_shape = tuple(broadcast_shape_list)
for i in range(len(arr.shape)):
if arr.shape[i] < indices.shape[i]:
# if indices matrix has larger size than arr matrix, do not broadcast.
return None
return broadcast_shape


def take_along_axis(arr, indices, axis):
"""
Take values from the input array by given indices matrix along the designated axis.
@@ -2779,21 +2804,31 @@ def take_along_axis(arr, indices, axis):
print(result)
# [[1, 2, 3]]
"""
if (arr.shape == indices.shape):
broadcast_shape = arr.shape
else:
broadcast_shape_list = list(arr.shape)
broadcast_shape_list[axis] = 1
broadcast_shape = tuple(broadcast_shape_list)
if (len(arr.shape) != len(indices.shape)):
raise ValueError(
"`indices` and `arr` must have the same number of dimensions!")
axis = non_negative_axis(arr, axis)
broadcast_shape = infer_broadcast_shape(arr, indices, axis)
if not broadcast_shape:
        # if the indices matrix is larger than arr, arr is broadcast to the indices shape.
broadcast_shape = indices.shape
if in_dygraph_mode():
indices = paddle.broadcast_to(indices, broadcast_shape)
broadcast_shape_list = list(broadcast_shape)
broadcast_shape_list[axis] = list(arr.shape)[axis]
broadcast_shape = tuple(broadcast_shape_list)
arr = paddle.broadcast_to(arr, broadcast_shape)
return _C_ops.take_along_axis(arr, indices, 'Axis', axis)
check_variable_and_dtype(
arr, 'x', ['float16', 'float32', 'float64', 'int32', 'int64', 'uint8'],
'take_along_axis')
check_variable_and_dtype(indices, 'index', ['int32', 'int64'],
'take_along_axis')
indices = paddle.broadcast_to(indices, broadcast_shape)
broadcast_shape_list = list(broadcast_shape)
broadcast_shape_list[axis] = list(arr.shape)[axis]
broadcast_shape = tuple(broadcast_shape_list)
arr = paddle.broadcast_to(arr, broadcast_shape)
helper = LayerHelper('take_along_axis', **locals())
dtype = helper.input_dtype()
result = helper.create_variable_for_type_inference(dtype)
@@ -2837,17 +2872,17 @@ def put_along_axis(arr, indices, values, axis, reduce='assign'):
# [60, 40, 50]]
"""
if (arr.shape == indices.shape):
broadcast_shape = arr.shape
else:
broadcast_shape_list = list(arr.shape)
broadcast_shape_list[axis] = 1
broadcast_shape = tuple(broadcast_shape_list)
if (len(arr.shape) != len(indices.shape)):
raise ValueError(
"`indices` and `arr` must have the same number of dimensions!")
axis = non_negative_axis(arr, axis)
broadcast_shape = infer_broadcast_shape(arr, indices, axis)
if in_dygraph_mode():
indices = paddle.broadcast_to(indices, broadcast_shape)
values = paddle.to_tensor(values) if not isinstance(
values, paddle.Tensor) else values
values = paddle.broadcast_to(values, broadcast_shape)
if broadcast_shape:
indices = paddle.broadcast_to(indices, broadcast_shape)
values = paddle.broadcast_to(values, indices.shape)
return _C_ops.put_along_axis(arr, indices, values, "Axis", axis,
"Reduce", reduce)

@@ -2856,8 +2891,9 @@ def put_along_axis(arr, indices, values, axis, reduce='assign'):
'put_along_axis')
check_variable_and_dtype(indices, 'index', ['int32', 'int64'],
'put_along_axis')
indices = paddle.broadcast_to(indices, broadcast_shape)
values = paddle.broadcast_to(values, broadcast_shape)
if broadcast_shape:
indices = paddle.broadcast_to(indices, broadcast_shape)
values = paddle.broadcast_to(values, indices.shape)
helper = LayerHelper('put_along_axis', **locals())
dtype = helper.input_dtype()
result = helper.create_variable_for_type_inference(dtype)
@@ -2875,19 +2911,18 @@ def put_along_axis(arr, indices, values, axis, reduce='assign'):
@inplace_apis_in_dygraph_only
def put_along_axis_(arr, indices, values, axis, reduce='assign'):
r"""
Inplace version of ``put_along_axis`` API, the output Tensor will be inplaced with input ``x``.
Inplace version of ``put_along_axis`` API, the output Tensor will be inplaced with input ``arr``.
Please refer to :ref:`api_tensor_put_along_axis`.
"""
if (arr.shape == indices.shape):
broadcast_shape = arr.shape
else:
broadcast_shape_list = list(arr.shape)
broadcast_shape_list[axis] = 1
broadcast_shape = tuple(broadcast_shape_list)

indices = paddle.broadcast_to(indices, broadcast_shape)
if (len(arr.shape) != len(indices.shape)):
raise ValueError(
"`indices` and `arr` must have the same number of dimensions!")
axis = non_negative_axis(arr, axis)
broadcast_shape = infer_broadcast_shape(arr, indices, axis)
values = paddle.to_tensor(values) if not isinstance(
values, paddle.Tensor) else values
values = paddle.broadcast_to(values, broadcast_shape)
if broadcast_shape:
indices = paddle.broadcast_to(indices, broadcast_shape)
values = paddle.broadcast_to(values, indices.shape)
return _C_ops.put_along_axis_(arr, indices, values, "Axis", axis, "Reduce",
reduce)
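
Read together, non_negative_axis and infer_broadcast_shape gate the new broadcasting paths above; here is a standalone restatement of their logic, using NumPy arrays as stand-ins for tensors since only .shape is consulted (the helper bodies mirror the diff):

import numpy as np

def non_negative_axis(arr, axis):
    # Normalize a possibly negative axis into the range [0, ndim).
    ndim = len(arr.shape)
    assert -ndim <= axis < ndim, \
        "'axis' must be in the range of [-{0}, {0})".format(ndim)
    return axis + ndim if axis < 0 else axis

def infer_broadcast_shape(arr, indices, axis):
    # Target shape for broadcasting indices: arr's shape with the gather
    # axis replaced by the indices extent. None signals that indices is
    # larger than arr in some dimension, so no indices-side broadcast runs.
    shape = list(arr.shape)
    shape[axis] = indices.shape[axis]
    for i in range(len(arr.shape)):
        if arr.shape[i] < indices.shape[i]:
            return None
    return tuple(shape)

arr = np.zeros((2, 2))
print(infer_broadcast_shape(arr, np.zeros((1, 2)), non_negative_axis(arr, 0)))   # (1, 2)
print(infer_broadcast_shape(arr, np.zeros((4, 2)), non_negative_axis(arr, -2)))  # None
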
22 changes: 17 additions & 5 deletions python/paddle/tensor/stat.py
@@ -437,17 +437,29 @@ def quantile(x, q, axis=None, keepdim=False):
indices_upper = paddle.ceil(indices).astype(paddle.int32)
outputs = []

def expand_dim(indices, sorted_tensor_shape, axis):
assert axis < len(list(sorted_tensor_shape))
expanded_shape = [1] * len(list(sorted_tensor_shape))
expanded_shape[axis] = len(indices)
expanded_shape = tuple(expanded_shape)
indices = indices.reshape(expanded_shape)
return indices

# TODO(chenjianye): replace the for-loop to directly take elements.
for i in range(len(indices)):
if (indices_upper[i] != indices_below[i]):
tensor_below = paddle.take_along_axis(sorted_tensor,
indices_below[i], axis)
tensor_upper = paddle.take_along_axis(sorted_tensor,
indices_upper[i], axis)
tensor_below = paddle.take_along_axis(
sorted_tensor,
expand_dim(indices_below[i], sorted_tensor.shape, axis), axis)
tensor_upper = paddle.take_along_axis(
sorted_tensor,
expand_dim(indices_upper[i], sorted_tensor.shape, axis), axis)
weights = (indices[i] - indices_below[i]).astype(x.dtype)
out = paddle.lerp(tensor_below, tensor_upper, weights)
else:
out = paddle.take_along_axis(sorted_tensor, indices_below[i], axis)
out = paddle.take_along_axis(
sorted_tensor,
expand_dim(indices_below[i], sorted_tensor.shape, axis), axis)
if not keepdim:
out = paddle.squeeze(out, axis=axis)
else:
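
The expand_dim helper exists because take_along_axis now requires indices and arr to have the same number of dimensions; a NumPy illustration of the reshape it performs, with shapes chosen purely for illustration:

import numpy as np

sorted_tensor = np.sort(np.random.random((3, 4)), axis=1)
idx = np.array([2])  # 1-D index, while sorted_tensor is 2-D

# expand_dim places the index extent on the target axis and 1s elsewhere,
# so the ndim of the index matches the tensor being gathered from.
expanded = idx.reshape((1, len(idx)))  # axis=1 -> shape (1, 1)
out = np.take_along_axis(sorted_tensor, expanded, axis=1)
assert out.shape == (3, 1)
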
