diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 9f6f3622f71..f859f6c420a 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -30,6 +30,9 @@ New Features :py:meth:`coarsen`, :py:meth:`weighted`, :py:meth:`resample`, (:pull:`6702`) By `Michael Niklas `_. +- Experimental support for wrapping any array type that conforms to the python array api standard. + (:pull:`6804`) + By `Tom White `_. Deprecations ~~~~~~~~~~~~ diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index 6e73ee41b40..2cd2fb3af04 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -329,7 +329,11 @@ def f(values, axis=None, skipna=None, **kwargs): if name in ["sum", "prod"]: kwargs.pop("min_count", None) - func = getattr(np, name) + if hasattr(values, "__array_namespace__"): + xp = values.__array_namespace__() + func = getattr(xp, name) + else: + func = getattr(np, name) try: with warnings.catch_warnings(): diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py index 9a29b63f4d0..72ca60d4d5e 100644 --- a/xarray/core/indexing.py +++ b/xarray/core/indexing.py @@ -679,6 +679,8 @@ def as_indexable(array): return DaskIndexingAdapter(array) if hasattr(array, "__array_function__"): return NdArrayLikeIndexingAdapter(array) + if hasattr(array, "__array_namespace__"): + return ArrayApiIndexingAdapter(array) raise TypeError(f"Invalid array type: {type(array)}") @@ -1288,6 +1290,49 @@ def __init__(self, array): self.array = array +class ArrayApiIndexingAdapter(ExplicitlyIndexedNDArrayMixin): + """Wrap an array API array to use explicit indexing.""" + + __slots__ = ("array",) + + def __init__(self, array): + if not hasattr(array, "__array_namespace__"): + raise TypeError( + "ArrayApiIndexingAdapter must wrap an object that " + "implements the __array_namespace__ protocol" + ) + self.array = array + + def __getitem__(self, key): + if isinstance(key, BasicIndexer): + return self.array[key.tuple] + elif isinstance(key, OuterIndexer): + # manual orthogonal indexing (implemented like DaskIndexingAdapter) + key = key.tuple + value = self.array + for axis, subkey in reversed(list(enumerate(key))): + value = value[(slice(None),) * axis + (subkey, Ellipsis)] + return value + else: + if isinstance(key, VectorizedIndexer): + raise TypeError("Vectorized indexing is not supported") + else: + raise TypeError(f"Unrecognized indexer: {key}") + + def __setitem__(self, key, value): + if isinstance(key, (BasicIndexer, OuterIndexer)): + self.array[key.tuple] = value + else: + if isinstance(key, VectorizedIndexer): + raise TypeError("Vectorized indexing is not supported") + else: + raise TypeError(f"Unrecognized indexer: {key}") + + def transpose(self, order): + xp = self.array.__array_namespace__() + return xp.permute_dims(self.array, order) + + class DaskIndexingAdapter(ExplicitlyIndexedNDArrayMixin): """Wrap a dask array to support explicit indexing.""" diff --git a/xarray/core/utils.py b/xarray/core/utils.py index ab3f8d3a282..51bf1346506 100644 --- a/xarray/core/utils.py +++ b/xarray/core/utils.py @@ -263,8 +263,10 @@ def is_duck_array(value: Any) -> bool: hasattr(value, "ndim") and hasattr(value, "shape") and hasattr(value, "dtype") - and hasattr(value, "__array_function__") - and hasattr(value, "__array_ufunc__") + and ( + (hasattr(value, "__array_function__") and hasattr(value, "__array_ufunc__")) + or hasattr(value, "__array_namespace__") + ) ) @@ -298,6 +300,7 @@ def _is_scalar(value, include_0d): or not ( isinstance(value, (Iterable,) + NON_NUMPY_SUPPORTED_ARRAY_TYPES) or hasattr(value, "__array_function__") + or hasattr(value, "__array_namespace__") ) ) diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 90edf652284..502bf8482f2 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -237,7 +237,9 @@ def as_compatible_data(data, fastpath=False): else: data = np.asarray(data) - if not isinstance(data, np.ndarray) and hasattr(data, "__array_function__"): + if not isinstance(data, np.ndarray) and ( + hasattr(data, "__array_function__") or hasattr(data, "__array_namespace__") + ): return data # validate whether the data is valid data types. diff --git a/xarray/tests/test_array_api.py b/xarray/tests/test_array_api.py new file mode 100644 index 00000000000..8e378054c29 --- /dev/null +++ b/xarray/tests/test_array_api.py @@ -0,0 +1,51 @@ +from typing import Tuple + +import pytest + +import xarray as xr +from xarray.testing import assert_equal + +np = pytest.importorskip("numpy", minversion="1.22") + +import numpy.array_api as xp # isort:skip +from numpy.array_api._array_object import Array # isort:skip + + +@pytest.fixture +def arrays() -> Tuple[xr.DataArray, xr.DataArray]: + np_arr = xr.DataArray(np.ones((2, 3)), dims=("x", "y"), coords={"x": [10, 20]}) + xp_arr = xr.DataArray(xp.ones((2, 3)), dims=("x", "y"), coords={"x": [10, 20]}) + assert isinstance(xp_arr.data, Array) + return np_arr, xp_arr + + +def test_arithmetic(arrays) -> None: + np_arr, xp_arr = arrays + expected = np_arr + 7 + actual = xp_arr + 7 + assert isinstance(actual.data, Array) + assert_equal(actual, expected) + + +def test_aggregation(arrays) -> None: + np_arr, xp_arr = arrays + expected = np_arr.sum(skipna=False) + actual = xp_arr.sum(skipna=False) + assert isinstance(actual.data, Array) + assert_equal(actual, expected) + + +def test_indexing(arrays) -> None: + np_arr, xp_arr = arrays + expected = np_arr[:, 0] + actual = xp_arr[:, 0] + assert isinstance(actual.data, Array) + assert_equal(actual, expected) + + +def test_reorganizing_operation(arrays) -> None: + np_arr, xp_arr = arrays + expected = np_arr.transpose() + actual = xp_arr.transpose() + assert isinstance(actual.data, Array) + assert_equal(actual, expected)