Skip to content

Commit

Permalink
index_add_ (#26761)
Browse files Browse the repository at this point in the history
  • Loading branch information
imsoumya18 authored Jan 18, 2024
1 parent 831e04e commit 1506548
Show file tree
Hide file tree
Showing 2 changed files with 127 additions and 2 deletions.
36 changes: 34 additions & 2 deletions ivy/functional/frontends/paddle/tensor/manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,39 @@
)
from ivy.func_wrapper import with_unsupported_dtypes


# NOTE(review): only `with_unsupported_dtypes` is visible among this file's
# imports — confirm `with_supported_dtypes` is actually imported above.
@with_supported_dtypes(
    {"2.5.1 and below": ("bool", "int32", "int64", "float16", "float32", "float64")},
    "paddle",
)
@to_ivy_arrays_and_back
def index_add_(x, index, axis, value, *, name=None):
    """Add rows of ``value`` into ``x`` at positions ``index`` along ``axis``.

    Duplicate entries in ``index`` are accumulated: all rows of ``value``
    that share the same index are summed before being added to ``x``.

    Parameters
    ----------
    x
        Array to update.
    index
        1-D integer indices into dimension ``axis`` of ``x``, one per row of
        ``value`` along that axis.
    axis : int
        Dimension of ``x`` along which the addition is performed.
    value
        Values to add; presumably its ``axis`` dimension equals
        ``len(index)`` — TODO confirm against the Paddle API.
    name
        Unused; kept for Paddle signature compatibility.

    Returns
    -------
    The updated array.
    """
    # Move the target dimension to the front so rows can be addressed by
    # plain integer indexing.
    x = ivy.swapaxes(x, axis, 0)
    value = ivy.swapaxes(value, axis, 0)
    _to_adds = []
    # Pair each index with the position of its source row in `value`, then
    # sort by index value so duplicate indices become adjacent.
    index = sorted(zip(ivy.to_list(index), range(len(index))), key=(lambda i: i[0]))
    while index:
        _curr_idx = index[0][0]
        # Pad with zero rows for target positions that receive no update.
        while len(_to_adds) < _curr_idx:
            _to_adds.append(ivy.zeros_like(value[0]))
        _to_add_cum = ivy.get_item(value, index[0][1])
        # Accumulate every `value` row that shares the current index.
        while (len(index)) > 1 and (index[0][0] == index[1][0]):
            _to_add_cum = _to_add_cum + ivy.get_item(value, index.pop(1)[1])
        index.pop(0)
        _to_adds.append(_to_add_cum)
    # Pad the tail so the update tensor matches x's leading dimension.
    while len(_to_adds) < x.shape[0]:
        _to_adds.append(ivy.zeros_like(value[0]))
    _to_adds = ivy.stack(_to_adds)
    if len(x.shape) < 2:
        # Added this line due to the paddle backend treating scalars as 1-d arrays
        _to_adds = ivy.flatten(_to_adds)

    ret = ivy.add(x, _to_adds)
    # Restore the original dimension order.
    ret = ivy.swapaxes(ret, axis, 0)
    x = ret
    return x


# NOTE:
# Only inplace functions are to be added in this file.
# Please add non-inplace counterparts to `/frontends/paddle/manipulation.py`.
Expand All @@ -17,6 +50,5 @@
)
@to_ivy_arrays_and_back
def reshape_(x, shape):
ret = ivy.reshape(x, shape)
ivy.inplace_update(x, ret)
ivy.reshape(x, shape)
return x
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,63 @@
# --------------- #


@st.composite
def _arrays_dim_idx_n_dtypes(draw):
    """Hypothesis strategy generating inputs for ``index_add_`` tests.

    Returns
    -------
    tuple
        ``(xs, input_dtypes, _dim, _idx)`` where ``xs`` is a pair of arrays
        that share every dimension except axis ``_dim`` (whose sizes come
        from ``unique_dims``), ``input_dtypes`` is their shared dtype pair,
        ``_dim`` is the axis to index along, and ``_idx`` is a 1-D int64
        index array of length ``min(unique_dims)``.
    """
    num_dims = draw(st.shared(helpers.ints(min_value=1, max_value=4), key="num_dims"))
    num_arrays = 2
    # Shape shared by both arrays on every axis except the indexed one.
    common_shape = draw(
        helpers.lists(
            x=helpers.ints(min_value=2, max_value=3),
            min_size=num_dims - 1,
            max_size=num_dims - 1,
        )
    )
    _dim = draw(helpers.ints(min_value=0, max_value=num_dims - 1))
    # Sizes of the indexed axis for the two arrays (may coincide).
    unique_dims = draw(
        helpers.lists(
            x=helpers.ints(min_value=2, max_value=3),
            min_size=num_arrays,
            max_size=num_arrays,
        )
    )

    min_dim = min(unique_dims)
    max_dim = max(unique_dims)
    # NOTE(review): `max_value=max_dim` looks inclusive, which would permit
    # an index equal to the axis size (out of range) — confirm against
    # helpers.array_values semantics.
    _idx = draw(
        helpers.array_values(
            shape=min_dim,
            dtype="int64",
            min_value=0,
            max_value=max_dim,
            exclude_min=False,
        )
    )

    xs = []
    # Restricted to integer dtypes; float dtypes were deliberately excluded
    # here (see the `index_add_` dtype decorator for the full supported set).
    available_input_types = ["int32", "int64"]
    input_dtypes = draw(
        helpers.array_dtypes(
            available_dtypes=available_input_types,
            num_arrays=num_arrays,
            shared_dtype=True,
        )
    )
    for ud, dt in zip(unique_dims, input_dtypes):
        # Insert this array's unique axis size at position `_dim`.
        x = draw(
            helpers.array_values(
                shape=common_shape[:_dim] + [ud] + common_shape[_dim:],
                dtype=dt,
                large_abs_safety_factor=2.5,
                small_abs_safety_factor=2.5,
                safety_factor_scale="log",
            )
        )
        xs.append(x)
    return xs, input_dtypes, _dim, _idx


@st.composite
def dtypes_x_reshape_(draw):
shape = draw(helpers.get_shape(min_num_dims=1))
Expand All @@ -25,6 +82,42 @@ def dtypes_x_reshape_(draw):
return dtypes, x, shape


# --- Main --- #
# ------------ #


@handle_frontend_test(
    fn_tree="paddle.tensor.manipulation.index_add_",
    xs_dtypes_dim_idx=_arrays_dim_idx_n_dtypes(),
)
def test_paddle_index_add_(
    *,
    xs_dtypes_dim_idx,
    on_device,
    fn_tree,
    frontend,
    test_flags,
    backend_fw,
):
    """Check the frontend ``index_add_`` against native Paddle.

    The generated array with the larger indexed axis is the target tensor
    and the smaller one supplies the values to add, so every index fits the
    value tensor's axis.
    """
    xs, input_dtypes, axis, indices = xs_dtypes_dim_idx
    # Renamed locals: the original used `input`, shadowing the builtin.
    if xs[0].shape[axis] < xs[1].shape[axis]:
        source, target = xs
    else:
        target, source = xs
    helpers.test_frontend_function(
        input_dtypes=input_dtypes,
        backend_to_test=backend_fw,
        test_flags=test_flags,
        fn_tree=fn_tree,
        frontend=frontend,
        on_device=on_device,
        x=target,
        index=indices,
        axis=axis,
        value=source,
    )


# reshape_
@handle_frontend_test(
fn_tree="paddle.tensor.manipulation.reshape_",
Expand Down

0 comments on commit 1506548

Please sign in to comment.