Skip to content

Commit

Permalink
Add implementation of dpnp.argwhere (#2000)
Browse files Browse the repository at this point in the history
* Add implementation of dpnp.argwhere()

* Added new tests and updated existing ones

* Applied pre-commit hooks

* Fix broken link in description
  • Loading branch information
antonwolfy authored Aug 19, 2024
1 parent 256ce60 commit 762d477
Show file tree
Hide file tree
Showing 5 changed files with 123 additions and 4 deletions.
58 changes: 57 additions & 1 deletion dpnp/dpnp_iface_searching.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
from .dpnp_array import dpnp_array
from .dpnp_utils.dpnp_utils_reduction import dpnp_wrap_reduction_call

__all__ = ["argmax", "argmin", "searchsorted", "where"]
__all__ = ["argmax", "argmin", "argwhere", "searchsorted", "where"]


def _get_search_res_dt(a, _dtype, out):
Expand Down Expand Up @@ -244,6 +244,62 @@ def argmin(a, axis=None, out=None, *, keepdims=False):
)


def argwhere(a):
"""
Find the indices of array elements that are non-zero, grouped by element.
For full documentation refer to :obj:`numpy.argwhere`.
Parameters
----------
a : {dpnp.ndarray, usm_ndarray}
Input array.
Returns
-------
out : dpnp.ndarray
Indices of elements that are non-zero. Indices are grouped by element.
This array will have shape ``(N, a.ndim)`` where ``N`` is the number of
non-zero items.
See Also
--------
:obj:`dpnp.where` : Returns elements chosen from input arrays depending on
a condition.
:obj:`dpnp.nonzero` : Return the indices of the elements that are non-zero.
Notes
-----
``dpnp.argwhere(a)`` is almost the same as
``dpnp.transpose(dpnp.nonzero(a))``, but produces a result of the correct
shape for a 0D array.
The output of :obj:`dpnp.argwhere` is not suitable for indexing arrays.
For this purpose use :obj:`dpnp.nonzero` instead.
Examples
--------
>>> import dpnp as np
>>> x = np.arange(6).reshape(2, 3)
>>> x
array([[0, 1, 2],
[3, 4, 5]])
>>> np.argwhere(x > 1)
array([[0, 2],
[1, 0],
[1, 1],
[1, 2]])
"""

dpnp.check_supported_arrays_type(a)
if a.ndim == 0:
# nonzero does not behave well on 0d, so promote to 1d
a = dpnp.atleast_1d(a)
# and then remove the added dimension
return dpnp.argwhere(a)[:, :0]
return dpnp.stack(dpnp.nonzero(a)).T


def searchsorted(a, v, side="left", sorter=None):
"""
Find indices where elements should be inserted to maintain order.
Expand Down
65 changes: 64 additions & 1 deletion tests/test_search.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
import dpctl.tensor as dpt
import numpy
import pytest
from numpy.testing import assert_allclose, assert_array_equal, assert_raises
from numpy.testing import (
assert_allclose,
assert_array_equal,
assert_equal,
assert_raises,
)

import dpnp

Expand Down Expand Up @@ -99,6 +104,64 @@ def test_ndarray(self, axis, keepdims):
assert_dtype_allclose(dpnp_res, np_res)


class TestArgwhere:
@pytest.mark.parametrize("dt", get_all_dtypes(no_none=True))
def test_basic(self, dt):
a = numpy.array([4, 0, 2, 1, 3], dtype=dt)
ia = dpnp.array(a)

result = dpnp.argwhere(ia)
expected = numpy.argwhere(a)
assert_equal(result, expected)

@pytest.mark.parametrize("ndim", [0, 1, 2])
def test_ndim(self, ndim):
# get an nd array with multiple elements in every dimension
a = numpy.empty((2,) * ndim)

# none
a[...] = False
ia = dpnp.array(a)

result = dpnp.argwhere(ia)
expected = numpy.argwhere(a)
assert_equal(result, expected)

# only one
a[...] = False
a.flat[0] = True
ia = dpnp.array(a)

result = dpnp.argwhere(ia)
expected = numpy.argwhere(a)
assert_equal(result, expected)

# all but one
a[...] = True
a.flat[0] = False
ia = dpnp.array(a)

result = dpnp.argwhere(ia)
expected = numpy.argwhere(a)
assert_equal(result, expected)

# all
a[...] = True
ia = dpnp.array(a)

result = dpnp.argwhere(ia)
expected = numpy.argwhere(a)
assert_equal(result, expected)

def test_2d(self):
a = numpy.arange(6).reshape((2, 3))
ia = dpnp.array(a)

result = dpnp.argwhere(ia > 1)
expected = numpy.argwhere(a > 1)
assert_array_equal(result, expected)


class TestWhere:
@pytest.mark.parametrize("dtype", get_all_dtypes(no_none=True))
def test_basic(self, dtype):
Expand Down
1 change: 1 addition & 0 deletions tests/test_sycl_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,6 +437,7 @@ def test_meshgrid(device):
pytest.param("argmax", [1.0, 2.0, 4.0, 7.0]),
pytest.param("argmin", [1.0, 2.0, 4.0, 7.0]),
pytest.param("argsort", [2.0, 1.0, 7.0, 4.0]),
pytest.param("argwhere", [[0, 3], [1, 4], [2, 5]]),
pytest.param("cbrt", [1.0, 8.0, 27.0]),
pytest.param("ceil", [-1.7, -1.5, -0.2, 0.2, 1.5, 1.7, 2.0]),
pytest.param("conjugate", [[1.0 + 1.0j, 0.0], [0.0, 1.0 + 1.0j]]),
Expand Down
1 change: 1 addition & 0 deletions tests/test_usm_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -557,6 +557,7 @@ def test_norm(usm_type, ord, axis):
pytest.param("argmax", [1.0, 2.0, 4.0, 7.0]),
pytest.param("argmin", [1.0, 2.0, 4.0, 7.0]),
pytest.param("argsort", [2.0, 1.0, 7.0, 4.0]),
pytest.param("argwhere", [[0, 3], [1, 4], [2, 5]]),
pytest.param("cbrt", [1, 8, 27]),
pytest.param("ceil", [-1.7, -1.5, -0.2, 0.2, 1.5, 1.7, 2.0]),
pytest.param("conjugate", [[1.0 + 1.0j, 0.0], [0.0, 1.0 + 1.0j]]),
Expand Down
2 changes: 0 additions & 2 deletions tests/third_party/cupy/sorting_tests/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,7 +399,6 @@ def test_flatnonzero(self, xp, dtype):
{"array": numpy.empty((0, 2, 0))},
_ids=False, # Do not generate ids from randomly generated params
)
@pytest.mark.skip("argwhere isn't implemented yet")
class TestArgwhere:
@testing.for_all_dtypes()
@testing.numpy_cupy_array_equal()
Expand All @@ -412,7 +411,6 @@ def test_argwhere(self, xp, dtype):
{"value": 0},
{"value": 3},
)
@pytest.mark.skip("argwhere isn't implemented yet")
@testing.with_requires("numpy>=1.18")
class TestArgwhereZeroDimension:
@testing.for_all_dtypes()
Expand Down

0 comments on commit 762d477

Please sign in to comment.