Skip to content

Commit

Permalink
fix: fixed test_torch___getitem__ for paddle backend
Browse files Browse the repository at this point in the history
Fixed the syntax for checking if x is a True.
Added a function to remove negative values in `query` of `paddle.get_item`.
  • Loading branch information
ZenithFlux committed Mar 21, 2024
1 parent 35ba3f6 commit f685516
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 1 deletion.
4 changes: 4 additions & 0 deletions ivy/functional/backends/paddle/data_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,10 @@ def astype(
out: Optional[paddle.Tensor] = None,
) -> paddle.Tensor:
dtype = ivy.as_native_dtype(dtype)

# if copy and 0 in x.shape:
# return paddle.empty(x.shape, dtype=dtype)

if x.dtype == dtype:
return x.clone() if copy else x
return x.clone().cast(dtype) if copy else x.cast(dtype)
Expand Down
58 changes: 57 additions & 1 deletion ivy/functional/backends/paddle/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,58 @@ def _squeeze_helper(query, x_ndim):
return squeeze_indices


def _make_non_negative(query, x):
"""Converts negative values inside the tensors in the query to their
positive form.
Returns ``query`` unmodified if it is not a ``list``, ``tuple``
or ``paddle.Tensor``.
This function leaves non-tensor values in ``query`` as is.
"""
if isinstance(query, paddle.Tensor):
query[query < 0] = x.shape[0] + query[query < 0]
return query

if not isinstance(query, (list, tuple)):
return query

found_ellipsis = False
shape_i = 0
for q in query:
if q is None:
continue

if not isinstance(q, paddle.Tensor):
shape_i += 1
continue

if q is Ellipsis:
found_ellipsis = True
break

q[q < 0] = x.shape[shape_i] + q[q < 0]
shape_i += 1

if not found_ellipsis:
return query

shape_i = x.ndim - 1
for q in reversed(query):
if q is None:
continue

if not isinstance(q, paddle.Tensor):
shape_i -= 1
continue

if q is Ellipsis:
return query

q[q < 0] = x.shape[shape_i] + q[q < 0]
shape_i -= 1


@with_unsupported_device_and_dtypes(
{
"2.6.0 and below": {
Expand All @@ -111,7 +163,10 @@ def get_item(
and query.ndim == 0
) or isinstance(query, bool):
# special case to handle scalar boolean indices
if query is True:
if isinstance(query, paddle.Tensor):
query = query.item()

if query:
return x[None]
else:
return paddle.zeros(shape=[0] + x.shape, dtype=x.dtype)
Expand All @@ -125,6 +180,7 @@ def get_item(
squeeze_indices = _squeeze_helper(query, x.ndim)
# regular queries x[idx_1,idx_2,...,idx_i]
# array queries idx = Tensor(idx_1,idx_2,...,idx_i), x[idx]
query = _make_non_negative(query, x)
ret = x.__getitem__(query)
return ret.squeeze(squeeze_indices) if squeeze_indices else ret

Expand Down

0 comments on commit f685516

Please sign in to comment.