Skip to content

Commit

Permalink
fix: fixed test_torch___getitem__ for paddle backend
Browse files Browse the repository at this point in the history
Fixed the syntax for checking if x is a True.
Added a function to remove negative values in `query` of `paddle.get_item`.
  • Loading branch information
ZenithFlux committed Mar 18, 2024
1 parent 35ba3f6 commit 5c9d3e0
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 1 deletion.
4 changes: 4 additions & 0 deletions ivy/functional/backends/paddle/data_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,10 @@ def astype(
out: Optional[paddle.Tensor] = None,
) -> paddle.Tensor:
dtype = ivy.as_native_dtype(dtype)

if copy and 0 in x.shape:
return paddle.empty(tuple(x.shape), dtype=dtype)

if x.dtype == dtype:
return x.clone() if copy else x
return x.clone().cast(dtype) if copy else x.cast(dtype)
Expand Down
44 changes: 43 additions & 1 deletion ivy/functional/backends/paddle/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,47 @@ def _squeeze_helper(query, x_ndim):
return squeeze_indices


def _make_non_negative(query, x):
"""Converts negative values inside the tensors in the query to their
positive form.
Returns ``query`` unmodified if it is not a ``list``, ``tuple``
or ``paddle.Tensor``.
This function leaves non-tensor values in ``query`` as is.
"""
if isinstance(query, paddle.Tensor):
query[query < 0] = x.shape[0] + query[query < 0]
return query

if not isinstance(query, (list, tuple)):
return query

found_ellipsis = False
for i, q in enumerate(query):
if not isinstance(q, paddle.Tensor):
continue

if q is Ellipsis:
found_ellipsis = True
break

q[q < 0] = x.shape[i] + q[q < 0]

if not found_ellipsis:
return query

for i, q in enumerate(reversed(query), 1):
if not isinstance(q, paddle.Tensor):
continue

if q is Ellipsis:
return query

i = len(x.shape) - i
q[q < 0] = x.shape[i] + q[q < 0]


@with_unsupported_device_and_dtypes(
{
"2.6.0 and below": {
Expand All @@ -111,7 +152,7 @@ def get_item(
and query.ndim == 0
) or isinstance(query, bool):
# special case to handle scalar boolean indices
if query is True:
if query.item():
return x[None]
else:
return paddle.zeros(shape=[0] + x.shape, dtype=x.dtype)
Expand All @@ -125,6 +166,7 @@ def get_item(
squeeze_indices = _squeeze_helper(query, x.ndim)
# regular queries x[idx_1,idx_2,...,idx_i]
# array queries idx = Tensor(idx_1,idx_2,...,idx_i), x[idx]
query = _make_non_negative(query, x)
ret = x.__getitem__(query)
return ret.squeeze(squeeze_indices) if squeeze_indices else ret

Expand Down

0 comments on commit 5c9d3e0

Please sign in to comment.