Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support NumPy array API (experimental) #6804

Merged
merged 12 commits into from
Jul 20, 2022
3 changes: 3 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ New Features
:py:meth:`coarsen`, :py:meth:`weighted`, :py:meth:`resample`,
(:pull:`6702`)
By `Michael Niklas <https://github.com/headtr1ck>`_.
- Experimental support for wrapping any array type that conforms to the python array api standard.
(:pull:`6804`)
By `Tom White <https://github.com/tomwhite>`_.

Deprecations
~~~~~~~~~~~~
Expand Down
6 changes: 5 additions & 1 deletion xarray/core/duck_array_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,11 @@ def f(values, axis=None, skipna=None, **kwargs):
if name in ["sum", "prod"]:
kwargs.pop("min_count", None)

func = getattr(np, name)
if hasattr(values, "__array_namespace__"):
xp = values.__array_namespace__()
func = getattr(xp, name)
else:
func = getattr(np, name)

try:
with warnings.catch_warnings():
Expand Down
45 changes: 45 additions & 0 deletions xarray/core/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -679,6 +679,8 @@ def as_indexable(array):
return DaskIndexingAdapter(array)
if hasattr(array, "__array_function__"):
return NdArrayLikeIndexingAdapter(array)
if hasattr(array, "__array_namespace__"):
return ArrayApiIndexingAdapter(array)

raise TypeError(f"Invalid array type: {type(array)}")

Expand Down Expand Up @@ -1288,6 +1290,49 @@ def __init__(self, array):
self.array = array


class ArrayApiIndexingAdapter(ExplicitlyIndexedNDArrayMixin):
"""Wrap an array API array to use explicit indexing."""

__slots__ = ("array",)

def __init__(self, array):
if not hasattr(array, "__array_namespace__"):
raise TypeError(
"ArrayApiIndexingAdapter must wrap an object that "
"implements the __array_namespace__ protocol"
)
self.array = array

def __getitem__(self, key):
if isinstance(key, BasicIndexer):
return self.array[key.tuple]
elif isinstance(key, OuterIndexer):
# manual orthogonal indexing (implemented like DaskIndexingAdapter)
key = key.tuple
value = self.array
for axis, subkey in reversed(list(enumerate(key))):
value = value[(slice(None),) * axis + (subkey, Ellipsis)]
return value
else:
if isinstance(key, VectorizedIndexer):
raise TypeError("Vectorized indexing is not supported")
else:
raise TypeError(f"Unrecognized indexer: {key}")

def __setitem__(self, key, value):
if isinstance(key, (BasicIndexer, OuterIndexer)):
self.array[key.tuple] = value
else:
if isinstance(key, VectorizedIndexer):
raise TypeError("Vectorized indexing is not supported")
else:
raise TypeError(f"Unrecognized indexer: {key}")

def transpose(self, order):
xp = self.array.__array_namespace__()
return xp.permute_dims(self.array, order)


class DaskIndexingAdapter(ExplicitlyIndexedNDArrayMixin):
"""Wrap a dask array to support explicit indexing."""

Expand Down
7 changes: 5 additions & 2 deletions xarray/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,8 +263,10 @@ def is_duck_array(value: Any) -> bool:
hasattr(value, "ndim")
and hasattr(value, "shape")
and hasattr(value, "dtype")
and hasattr(value, "__array_function__")
and hasattr(value, "__array_ufunc__")
and (
(hasattr(value, "__array_function__") and hasattr(value, "__array_ufunc__"))
or hasattr(value, "__array_namespace__")
)
)


Expand Down Expand Up @@ -298,6 +300,7 @@ def _is_scalar(value, include_0d):
or not (
isinstance(value, (Iterable,) + NON_NUMPY_SUPPORTED_ARRAY_TYPES)
or hasattr(value, "__array_function__")
or hasattr(value, "__array_namespace__")
)
)

Expand Down
4 changes: 3 additions & 1 deletion xarray/core/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,9 @@ def as_compatible_data(data, fastpath=False):
else:
data = np.asarray(data)

if not isinstance(data, np.ndarray) and hasattr(data, "__array_function__"):
if not isinstance(data, np.ndarray) and (
hasattr(data, "__array_function__") or hasattr(data, "__array_namespace__")
):
return data

# validate whether the data is valid data types.
Expand Down
51 changes: 51 additions & 0 deletions xarray/tests/test_array_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
from typing import Tuple

import pytest

import xarray as xr
from xarray.testing import assert_equal

np = pytest.importorskip("numpy", minversion="1.22")

import numpy.array_api as xp # isort:skip
dcherian marked this conversation as resolved.
Show resolved Hide resolved
from numpy.array_api._array_object import Array # isort:skip


@pytest.fixture
def arrays() -> Tuple[xr.DataArray, xr.DataArray]:
np_arr = xr.DataArray(np.ones((2, 3)), dims=("x", "y"), coords={"x": [10, 20]})
xp_arr = xr.DataArray(xp.ones((2, 3)), dims=("x", "y"), coords={"x": [10, 20]})
assert isinstance(xp_arr.data, Array)
return np_arr, xp_arr


def test_arithmetic(arrays) -> None:
np_arr, xp_arr = arrays
expected = np_arr + 7
actual = xp_arr + 7
assert isinstance(actual.data, Array)
assert_equal(actual, expected)


def test_aggregation(arrays) -> None:
np_arr, xp_arr = arrays
expected = np_arr.sum(skipna=False)
actual = xp_arr.sum(skipna=False)
assert isinstance(actual.data, Array)
assert_equal(actual, expected)


def test_indexing(arrays) -> None:
np_arr, xp_arr = arrays
expected = np_arr[:, 0]
actual = xp_arr[:, 0]
assert isinstance(actual.data, Array)
assert_equal(actual, expected)


def test_reorganizing_operation(arrays) -> None:
np_arr, xp_arr = arrays
expected = np_arr.transpose()
actual = xp_arr.transpose()
assert isinstance(actual.data, Array)
assert_equal(actual, expected)