Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support NumPy array API (experimental) #6804

Merged
merged 12 commits into from
Jul 20, 2022
6 changes: 5 additions & 1 deletion xarray/core/duck_array_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,11 @@ def f(values, axis=None, skipna=None, **kwargs):
if name in ["sum", "prod"]:
kwargs.pop("min_count", None)

func = getattr(np, name)
if hasattr(values, "__array_namespace__"):
xp = values.__array_namespace__()
func = getattr(xp, name)
else:
func = getattr(np, name)

try:
with warnings.catch_warnings():
Expand Down
43 changes: 43 additions & 0 deletions xarray/core/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -679,6 +679,8 @@ def as_indexable(array):
return DaskIndexingAdapter(array)
if hasattr(array, "__array_function__"):
return NdArrayLikeIndexingAdapter(array)
if hasattr(array, "__array_namespace__"):
return ArrayApiIndexingAdapter(array)

raise TypeError(f"Invalid array type: {type(array)}")

Expand Down Expand Up @@ -1288,6 +1290,47 @@ def __init__(self, array):
self.array = array


class ArrayApiIndexingAdapter(ExplicitlyIndexedNDArrayMixin):
"""Wrap an array API array to use explicit indexing."""

__slots__ = ("array",)

def __init__(self, array):
if not hasattr(array, "__array_namespace__"):
raise TypeError(
"ArrayApiIndexingAdapter must wrap an object that "
"implements the __array_namespace__ protocol"
)
self.array = array

def __getitem__(self, key):
if isinstance(key, BasicIndexer):
return self.array[key.tuple]
elif isinstance(key, OuterIndexer):
# manual orthogonal indexing (implemented like DaskIndexingAdapter)
key = key.tuple
value = self.array
for axis, subkey in reversed(list(enumerate(key))):
value = value[(slice(None),) * axis + (subkey, Ellipsis)]
return value
else:
if not isinstance(key, VectorizedIndexer):
raise TypeError(f"Unrecognized indexer: {key}")
raise TypeError("Vectorized indexing is not supported")
tomwhite marked this conversation as resolved.
Show resolved Hide resolved

def __setitem__(self, key, value):
if isinstance(key, (BasicIndexer, OuterIndexer)):
self.array[key.tuple] = value
else:
if not isinstance(key, VectorizedIndexer):
raise TypeError(f"Unrecognized indexer: {key}")
raise TypeError("Vectorized indexing is not supported")
tomwhite marked this conversation as resolved.
Show resolved Hide resolved

def transpose(self, order):
xp = self.array.__array_namespace__()
return xp.permute_dims(self.array, order)


class DaskIndexingAdapter(ExplicitlyIndexedNDArrayMixin):
"""Wrap a dask array to support explicit indexing."""

Expand Down
7 changes: 5 additions & 2 deletions xarray/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,8 +263,10 @@ def is_duck_array(value: Any) -> bool:
hasattr(value, "ndim")
and hasattr(value, "shape")
and hasattr(value, "dtype")
and hasattr(value, "__array_function__")
and hasattr(value, "__array_ufunc__")
and (
(hasattr(value, "__array_function__") and hasattr(value, "__array_ufunc__"))
or hasattr(value, "__array_namespace__")
)
)


Expand Down Expand Up @@ -298,6 +300,7 @@ def _is_scalar(value, include_0d):
or not (
isinstance(value, (Iterable,) + NON_NUMPY_SUPPORTED_ARRAY_TYPES)
or hasattr(value, "__array_function__")
or hasattr(value, "__array_namespace__")
)
)

Expand Down
4 changes: 3 additions & 1 deletion xarray/core/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,9 @@ def as_compatible_data(data, fastpath=False):
else:
data = np.asarray(data)

if not isinstance(data, np.ndarray) and hasattr(data, "__array_function__"):
if not isinstance(data, np.ndarray) and (
hasattr(data, "__array_function__") or hasattr(data, "__array_namespace__")
):
return data

# validate whether the data is valid data types.
Expand Down
48 changes: 48 additions & 0 deletions xarray/tests/test_array_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import numpy.array_api as xp
Illviljan marked this conversation as resolved.
Show resolved Hide resolved
import pytest
from numpy.array_api._array_object import Array
dcherian marked this conversation as resolved.
Show resolved Hide resolved

import xarray as xr
from xarray.testing import assert_equal

np = pytest.importorskip("numpy", minversion="1.22")


@pytest.fixture
def arrays():
tomwhite marked this conversation as resolved.
Show resolved Hide resolved
np_arr = xr.DataArray(np.ones((2, 3)), dims=("x", "y"), coords={"x": [10, 20]})
xp_arr = xr.DataArray(xp.ones((2, 3)), dims=("x", "y"), coords={"x": [10, 20]})
assert isinstance(xp_arr.data, Array)
return np_arr, xp_arr


def test_arithmetic(arrays):
tomwhite marked this conversation as resolved.
Show resolved Hide resolved
np_arr, xp_arr = arrays
expected = np_arr + 7
actual = xp_arr + 7
assert isinstance(actual.data, Array)
assert_equal(actual, expected)


def test_aggregation(arrays):
tomwhite marked this conversation as resolved.
Show resolved Hide resolved
np_arr, xp_arr = arrays
expected = np_arr.sum(skipna=False)
actual = xp_arr.sum(skipna=False)
assert isinstance(actual.data, Array)
assert_equal(actual, expected)


def test_indexing(arrays):
tomwhite marked this conversation as resolved.
Show resolved Hide resolved
np_arr, xp_arr = arrays
expected = np_arr[:, 0]
actual = xp_arr[:, 0]
assert isinstance(actual.data, Array)
assert_equal(actual, expected)


def test_reorganizing_operation(arrays):
tomwhite marked this conversation as resolved.
Show resolved Hide resolved
np_arr, xp_arr = arrays
expected = np_arr.transpose()
actual = xp_arr.transpose()
assert isinstance(actual.data, Array)
assert_equal(actual, expected)