From 3029d76a2760d9832f5901c841853407add59272 Mon Sep 17 00:00:00 2001 From: Tom White Date: Sun, 3 Mar 2024 15:29:33 +0000 Subject: [PATCH] Delegate dtype functions to backend array API --- .github/workflows/array-api-tests.yml | 5 +- cubed/__init__.py | 7 -- cubed/array_api/array_object.py | 14 +-- cubed/array_api/constants.py | 12 +- cubed/array_api/creation_functions.py | 17 ++- cubed/array_api/data_type_functions.py | 125 ++------------------ cubed/array_api/dtypes.py | 113 +++--------------- cubed/array_api/elementwise_functions.py | 28 ++--- cubed/array_api/linear_algebra_functions.py | 4 +- cubed/array_api/manipulation_functions.py | 2 +- cubed/array_api/searching_functions.py | 6 +- cubed/array_api/statistical_functions.py | 8 +- cubed/core/array.py | 7 +- cubed/core/ops.py | 6 +- cubed/nan_functions.py | 2 +- cubed/storage/virtual.py | 2 +- cubed/tests/test_random.py | 5 +- 17 files changed, 76 insertions(+), 287 deletions(-) diff --git a/.github/workflows/array-api-tests.yml b/.github/workflows/array-api-tests.yml index 76a28d54..c1040275 100644 --- a/.github/workflows/array-api-tests.yml +++ b/.github/workflows/array-api-tests.yml @@ -101,15 +101,14 @@ jobs: array_api_tests/test_statistical_functions.py::test_std array_api_tests/test_statistical_functions.py::test_var + # finfo(float32).eps returns float32 but should return float + array_api_tests/test_data_type_functions.py::test_finfo[float32] # From https://github.com/data-apis/array-api-tests/blob/master/.github/workflows/numpy.yml # https://github.com/numpy/numpy/issues/18881 array_api_tests/test_creation_functions.py::test_linspace - # https://github.com/numpy/numpy/issues/20870 - array_api_tests/test_data_type_functions.py::test_can_cast - EOF pytest -v -rxXfEA --max-examples=2 --disable-data-dependent-shapes --disable-extension linalg --ci --hypothesis-disable-deadline --cov=cubed.array_api --cov-report=term-missing diff --git a/cubed/__init__.py b/cubed/__init__.py index 19ef3958..c04ce4d0 100644 --- a/cubed/__init__.py +++ b/cubed/__init__.py @@ -1,10 +1,3 @@ -# Suppress numpy.array_api experimental warning -import sys -import warnings - -if not sys.warnoptions: - warnings.filterwarnings("ignore", category=UserWarning) - from importlib.metadata import version as _version try: diff --git a/cubed/array_api/array_object.py b/cubed/array_api/array_object.py index e16b1145..3a2b23ff 100644 --- a/cubed/array_api/array_object.py +++ b/cubed/array_api/array_object.py @@ -225,37 +225,37 @@ def __eq__(self, other, /): other = self._check_allowed_dtypes(other, "all", "__eq__") if other is NotImplemented: return other - return elemwise(nxp.equal, self, other, dtype=np.bool_) + return elemwise(nxp.equal, self, other, dtype=nxp.bool) def __ge__(self, other, /): other = self._check_allowed_dtypes(other, "all", "__ge__") if other is NotImplemented: return other - return elemwise(nxp.greater_equal, self, other, dtype=np.bool_) + return elemwise(nxp.greater_equal, self, other, dtype=nxp.bool) def __gt__(self, other, /): other = self._check_allowed_dtypes(other, "all", "__gt__") if other is NotImplemented: return other - return elemwise(nxp.greater, self, other, dtype=np.bool_) + return elemwise(nxp.greater, self, other, dtype=nxp.bool) def __le__(self, other, /): other = self._check_allowed_dtypes(other, "all", "__le__") if other is NotImplemented: return other - return elemwise(nxp.less_equal, self, other, dtype=np.bool_) + return elemwise(nxp.less_equal, self, other, dtype=nxp.bool) def __lt__(self, other, /): other = self._check_allowed_dtypes(other, "all", "__lt__") if other is NotImplemented: return other - return elemwise(nxp.less, self, other, dtype=np.bool_) + return elemwise(nxp.less, self, other, dtype=nxp.bool) def __ne__(self, other, /): other = self._check_allowed_dtypes(other, "all", "__ne__") if other is NotImplemented: return other - return elemwise(nxp.not_equal, self, other, dtype=np.bool_) + return elemwise(nxp.not_equal, self, other, dtype=nxp.bool) # Reflected Operators @@ -425,7 +425,7 @@ def _promote_scalar(self, scalar): "Python int scalars cannot be promoted with bool arrays" ) if self.dtype in _integer_dtypes: - info = np.iinfo(self.dtype) + info = nxp.iinfo(self.dtype) if not (info.min <= scalar <= info.max): raise OverflowError( "Python int scalars must be within the bounds of the dtype for integer arrays" diff --git a/cubed/array_api/constants.py b/cubed/array_api/constants.py index 24d7dbef..4690b912 100644 --- a/cubed/array_api/constants.py +++ b/cubed/array_api/constants.py @@ -1,7 +1,7 @@ -import numpy as np +from cubed.backend_array_api import namespace as nxp -e = np.e -inf = np.inf -nan = np.nan -newaxis = None -pi = np.pi +e = nxp.e +inf = nxp.inf +nan = nxp.nan +newaxis = nxp.newaxis +pi = nxp.pi diff --git a/cubed/array_api/creation_functions.py b/cubed/array_api/creation_functions.py index 3f406c4c..fdf1dfbd 100644 --- a/cubed/array_api/creation_functions.py +++ b/cubed/array_api/creation_functions.py @@ -1,7 +1,6 @@ import math from typing import TYPE_CHECKING, Iterable, List -import numpy as np from zarr.util import normalize_shape from cubed.backend_array_api import namespace as nxp @@ -93,7 +92,7 @@ def empty_virtual_array( shape, *, dtype=None, device=None, chunks="auto", spec=None, hidden=True ) -> "Array": if dtype is None: - dtype = np.float64 + dtype = nxp.float64 chunksize = to_chunksize(normalize_chunks(chunks, shape=shape, dtype=dtype)) name = gensym() @@ -111,7 +110,7 @@ def eye( if n_cols is None: n_cols = n_rows if dtype is None: - dtype = np.float64 + dtype = nxp.float64 shape = (n_rows, n_cols) chunks = normalize_chunks(chunks, shape=shape, dtype=dtype) @@ -143,11 +142,11 @@ def full( if dtype is None: # check bool first since True/False are instances of int and float if isinstance(fill_value, bool): - dtype = np.bool_ + dtype = nxp.bool elif isinstance(fill_value, int): - dtype = np.int64 + dtype = nxp.int64 elif isinstance(fill_value, float): - dtype = np.float64 + dtype = nxp.float64 else: raise TypeError("Invalid input to full") chunksize = to_chunksize(normalize_chunks(chunks, shape=shape, dtype=dtype)) @@ -194,7 +193,7 @@ def linspace( div = 1 step = float(range_) / div if dtype is None: - dtype = np.float64 + dtype = nxp.float64 chunks = normalize_chunks(chunks, shape=(num,), dtype=dtype) chunksize = chunks[0][0] @@ -258,7 +257,7 @@ def meshgrid(*arrays, indexing="xy") -> List["Array"]: def ones(shape, *, dtype=None, device=None, chunks="auto", spec=None) -> "Array": if dtype is None: - dtype = np.float64 + dtype = nxp.float64 return full(shape, 1, dtype=dtype, device=device, chunks=chunks, spec=spec) @@ -304,7 +303,7 @@ def _tri_mask(N, M, k, chunks, spec): def zeros(shape, *, dtype=None, device=None, chunks="auto", spec=None) -> "Array": if dtype is None: - dtype = np.float64 + dtype = nxp.float64 return full(shape, 0, dtype=dtype, device=device, chunks=chunks, spec=spec) diff --git a/cubed/array_api/data_type_functions.py b/cubed/array_api/data_type_functions.py index 79b9f9bd..0c9be67b 100644 --- a/cubed/array_api/data_type_functions.py +++ b/cubed/array_api/data_type_functions.py @@ -1,23 +1,6 @@ -from dataclasses import dataclass - -import numpy as np -from numpy.array_api._typing import Dtype - from cubed.backend_array_api import namespace as nxp from cubed.core import CoreArray, map_blocks -from .dtypes import ( - _all_dtypes, - _boolean_dtypes, - _complex_floating_dtypes, - _integer_dtypes, - _numeric_dtypes, - _real_floating_dtypes, - _result_type, - _signed_integer_dtypes, - _unsigned_integer_dtypes, -) - def astype(x, dtype, /, *, copy=True): if not copy and dtype == x.dtype: @@ -30,118 +13,24 @@ def _astype(a, astype_dtype): def can_cast(from_, to, /): - # Copied from numpy.array_api - # TODO: replace with `nxp.can_cast` when NumPy 1.25 is widely used (e.g. in Xarray) - if isinstance(from_, CoreArray): from_ = from_.dtype - elif from_ not in _all_dtypes: - raise TypeError(f"{from_=}, but should be an array_api array or dtype") - if to not in _all_dtypes: - raise TypeError(f"{to=}, but should be a dtype") - try: - # We promote `from_` and `to` together. We then check if the promoted - # dtype is `to`, which indicates if `from_` can (up)cast to `to`. - dtype = _result_type(from_, to) - return to == dtype - except TypeError: - # _result_type() raises if the dtypes don't promote together - return False - - -@dataclass -class finfo_object: - bits: int - eps: float - max: float - min: float - smallest_normal: float - dtype: Dtype - - -@dataclass -class iinfo_object: - bits: int - max: int - min: int - dtype: Dtype + return nxp.can_cast(from_, to) def finfo(type, /): - # Copied from numpy.array_api - # TODO: replace with `nxp.finfo(type)` when NumPy 1.25 is widely used (e.g. in Xarray) - - fi = np.finfo(type) - return finfo_object( - fi.bits, - float(fi.eps), - float(fi.max), - float(fi.min), - float(fi.smallest_normal), - fi.dtype, - ) + return nxp.finfo(type) def iinfo(type, /): - # Copied from numpy.array_api - # TODO: replace with `nxp.iinfo(type)` when NumPy 1.25 is widely used (e.g. in Xarray) - - ii = np.iinfo(type) - return iinfo_object(ii.bits, ii.max, ii.min, ii.dtype) + return nxp.iinfo(type) def isdtype(dtype, kind): - # Copied from numpy.array_api - # TODO: replace with `nxp.isdtype(dtype, kind)` when NumPy 1.25 is widely used (e.g. in Xarray) - - if isinstance(kind, tuple): - # Disallow nested tuples - if any(isinstance(k, tuple) for k in kind): - raise TypeError("'kind' must be a dtype, str, or tuple of dtypes and strs") - return any(isdtype(dtype, k) for k in kind) - elif isinstance(kind, str): - if kind == "bool": - return dtype in _boolean_dtypes - elif kind == "signed integer": - return dtype in _signed_integer_dtypes - elif kind == "unsigned integer": - return dtype in _unsigned_integer_dtypes - elif kind == "integral": - return dtype in _integer_dtypes - elif kind == "real floating": - return dtype in _real_floating_dtypes - elif kind == "complex floating": - return dtype in _complex_floating_dtypes - elif kind == "numeric": - return dtype in _numeric_dtypes - else: - raise ValueError(f"Unrecognized data type kind: {kind!r}") - elif kind in _all_dtypes: - return dtype == kind - else: - raise TypeError( - f"'kind' must be a dtype, str, or tuple of dtypes and strs, not {type(kind).__name__}" - ) + return nxp.isdtype(dtype, kind) def result_type(*arrays_and_dtypes): - # Copied from numpy.array_api - # TODO: replace with `nxp.result_type` when NumPy 1.25 is widely used (e.g. in Xarray) - - A = [] - for a in arrays_and_dtypes: - if isinstance(a, CoreArray): - a = a.dtype - elif isinstance(a, np.ndarray) or a not in _all_dtypes: - raise TypeError("result_type() inputs must be array_api arrays or dtypes") - A.append(a) - - if len(A) == 0: - raise ValueError("at least one array or dtype is required") - elif len(A) == 1: - return A[0] - else: - t = A[0] - for t2 in A[1:]: - t = _result_type(t, t2) - return t + return nxp.result_type( + *(a.dtype if isinstance(a, CoreArray) else a for a in arrays_and_dtypes) + ) diff --git a/cubed/array_api/dtypes.py b/cubed/array_api/dtypes.py index b3263791..ce902dc0 100644 --- a/cubed/array_api/dtypes.py +++ b/cubed/array_api/dtypes.py @@ -1,22 +1,19 @@ # Copied from numpy.array_api -import numpy as np +from cubed.backend_array_api import namespace as nxp -# Note: we use dtype objects instead of dtype classes. The spec does not -# require any behavior on dtypes other than equality. -int8 = np.dtype("int8") -int16 = np.dtype("int16") -int32 = np.dtype("int32") -int64 = np.dtype("int64") -uint8 = np.dtype("uint8") -uint16 = np.dtype("uint16") -uint32 = np.dtype("uint32") -uint64 = np.dtype("uint64") -float32 = np.dtype("float32") -float64 = np.dtype("float64") -complex64 = np.dtype("complex64") -complex128 = np.dtype("complex128") -# Note: This name is changed -bool = np.dtype("bool") +int8 = nxp.int8 +int16 = nxp.int16 +int32 = nxp.int32 +int64 = nxp.int64 +uint8 = nxp.uint8 +uint16 = nxp.uint16 +uint32 = nxp.uint32 +uint64 = nxp.uint64 +float32 = nxp.float32 +float64 = nxp.float64 +complex64 = nxp.complex64 +complex128 = nxp.complex128 +bool = nxp.bool _all_dtypes = ( int8, @@ -89,85 +86,3 @@ "complex floating-point": _complex_floating_dtypes, "floating-point": _floating_dtypes, } - -_promotion_table = { - (int8, int8): int8, - (int8, int16): int16, - (int8, int32): int32, - (int8, int64): int64, - (int16, int8): int16, - (int16, int16): int16, - (int16, int32): int32, - (int16, int64): int64, - (int32, int8): int32, - (int32, int16): int32, - (int32, int32): int32, - (int32, int64): int64, - (int64, int8): int64, - (int64, int16): int64, - (int64, int32): int64, - (int64, int64): int64, - (uint8, uint8): uint8, - (uint8, uint16): uint16, - (uint8, uint32): uint32, - (uint8, uint64): uint64, - (uint16, uint8): uint16, - (uint16, uint16): uint16, - (uint16, uint32): uint32, - (uint16, uint64): uint64, - (uint32, uint8): uint32, - (uint32, uint16): uint32, - (uint32, uint32): uint32, - (uint32, uint64): uint64, - (uint64, uint8): uint64, - (uint64, uint16): uint64, - (uint64, uint32): uint64, - (uint64, uint64): uint64, - (int8, uint8): int16, - (int8, uint16): int32, - (int8, uint32): int64, - (int16, uint8): int16, - (int16, uint16): int32, - (int16, uint32): int64, - (int32, uint8): int32, - (int32, uint16): int32, - (int32, uint32): int64, - (int64, uint8): int64, - (int64, uint16): int64, - (int64, uint32): int64, - (uint8, int8): int16, - (uint16, int8): int32, - (uint32, int8): int64, - (uint8, int16): int16, - (uint16, int16): int32, - (uint32, int16): int64, - (uint8, int32): int32, - (uint16, int32): int32, - (uint32, int32): int64, - (uint8, int64): int64, - (uint16, int64): int64, - (uint32, int64): int64, - (float32, float32): float32, - (float32, float64): float64, - (float64, float32): float64, - (float64, float64): float64, - (complex64, complex64): complex64, - (complex64, complex128): complex128, - (complex128, complex64): complex128, - (complex128, complex128): complex128, - (float32, complex64): complex64, - (float32, complex128): complex128, - (float64, complex64): complex128, - (float64, complex128): complex128, - (complex64, float32): complex64, - (complex64, float64): complex128, - (complex128, float32): complex128, - (complex128, float64): complex128, - (bool, bool): bool, -} - - -def _result_type(type1, type2): - if (type1, type2) in _promotion_table: - return _promotion_table[type1, type2] - raise TypeError(f"{type1} and {type2} cannot be type promoted together") diff --git a/cubed/array_api/elementwise_functions.py b/cubed/array_api/elementwise_functions.py index 02af44b5..7d8f0086 100644 --- a/cubed/array_api/elementwise_functions.py +++ b/cubed/array_api/elementwise_functions.py @@ -1,5 +1,3 @@ -import numpy as np - from cubed.array_api.data_type_functions import result_type from cubed.array_api.dtypes import ( _boolean_dtypes, @@ -170,7 +168,7 @@ def expm1(x, /): def equal(x1, x2, /): - return elemwise(nxp.equal, x1, x2, dtype=np.bool_) + return elemwise(nxp.equal, x1, x2, dtype=nxp.bool) def floor(x, /): @@ -189,11 +187,11 @@ def floor_divide(x1, x2, /): def greater(x1, x2, /): - return elemwise(nxp.greater, x1, x2, dtype=np.bool_) + return elemwise(nxp.greater, x1, x2, dtype=nxp.bool) def greater_equal(x1, x2, /): - return elemwise(nxp.greater_equal, x1, x2, dtype=np.bool_) + return elemwise(nxp.greater_equal, x1, x2, dtype=nxp.bool) def imag(x, /): @@ -209,27 +207,27 @@ def imag(x, /): def isfinite(x, /): if x.dtype not in _numeric_dtypes: raise TypeError("Only numeric dtypes are allowed in isfinite") - return elemwise(nxp.isfinite, x, dtype=np.bool_) + return elemwise(nxp.isfinite, x, dtype=nxp.bool) def isinf(x, /): if x.dtype not in _numeric_dtypes: raise TypeError("Only numeric dtypes are allowed in isinf") - return elemwise(nxp.isinf, x, dtype=np.bool_) + return elemwise(nxp.isinf, x, dtype=nxp.bool) def isnan(x, /): if x.dtype not in _numeric_dtypes: raise TypeError("Only numeric dtypes are allowed in isnan") - return elemwise(nxp.isnan, x, dtype=np.bool_) + return elemwise(nxp.isnan, x, dtype=nxp.bool) def less(x1, x2, /): - return elemwise(nxp.less, x1, x2, dtype=np.bool_) + return elemwise(nxp.less, x1, x2, dtype=nxp.bool) def less_equal(x1, x2, /): - return elemwise(nxp.less_equal, x1, x2, dtype=np.bool_) + return elemwise(nxp.less_equal, x1, x2, dtype=nxp.bool) def log(x, /): @@ -265,25 +263,25 @@ def logaddexp(x1, x2, /): def logical_and(x1, x2, /): if x1.dtype not in _boolean_dtypes or x2.dtype not in _boolean_dtypes: raise TypeError("Only boolean dtypes are allowed in logical_and") - return elemwise(nxp.logical_and, x1, x2, dtype=np.bool_) + return elemwise(nxp.logical_and, x1, x2, dtype=nxp.bool) def logical_not(x, /): if x.dtype not in _boolean_dtypes: raise TypeError("Only boolean dtypes are allowed in logical_not") - return elemwise(nxp.logical_not, x, dtype=np.bool_) + return elemwise(nxp.logical_not, x, dtype=nxp.bool) def logical_or(x1, x2, /): if x1.dtype not in _boolean_dtypes or x2.dtype not in _boolean_dtypes: raise TypeError("Only boolean dtypes are allowed in logical_or") - return elemwise(nxp.logical_or, x1, x2, dtype=np.bool_) + return elemwise(nxp.logical_or, x1, x2, dtype=nxp.bool) def logical_xor(x1, x2, /): if x1.dtype not in _boolean_dtypes or x2.dtype not in _boolean_dtypes: raise TypeError("Only boolean dtypes are allowed in logical_xor") - return elemwise(nxp.logical_xor, x1, x2, dtype=np.bool_) + return elemwise(nxp.logical_xor, x1, x2, dtype=nxp.bool) def multiply(x1, x2, /): @@ -299,7 +297,7 @@ def negative(x, /): def not_equal(x1, x2, /): - return elemwise(nxp.not_equal, x1, x2, dtype=np.bool_) + return elemwise(nxp.not_equal, x1, x2, dtype=nxp.bool) def positive(x, /): diff --git a/cubed/array_api/linear_algebra_functions.py b/cubed/array_api/linear_algebra_functions.py index 542d2d38..114242f5 100644 --- a/cubed/array_api/linear_algebra_functions.py +++ b/cubed/array_api/linear_algebra_functions.py @@ -1,8 +1,6 @@ from numbers import Integral from typing import Iterable -import numpy as np - from cubed.array_api.data_type_functions import result_type from cubed.array_api.dtypes import _numeric_dtypes from cubed.array_api.manipulation_functions import expand_dims @@ -61,7 +59,7 @@ def matmul(x1, x2, /): def _matmul(a, b): chunk = nxp.matmul(a, b) - return chunk[..., np.newaxis, :] + return chunk[..., nxp.newaxis, :] def _sum_wo_cat(a, axis=None, dtype=None): diff --git a/cubed/array_api/manipulation_functions.py b/cubed/array_api/manipulation_functions.py index a6735799..48e0fbf3 100644 --- a/cubed/array_api/manipulation_functions.py +++ b/cubed/array_api/manipulation_functions.py @@ -64,7 +64,7 @@ def broadcast_to(x, /, shape, *, chunks=None): ) # create an empty array as a template for blockwise to do broadcasting - template = empty(shape, dtype=np.int8, chunks=chunks, spec=x.spec) + template = empty(shape, dtype=nxp.int8, chunks=chunks, spec=x.spec) return elemwise(_broadcast_like, x, template, dtype=x.dtype) diff --git a/cubed/array_api/searching_functions.py b/cubed/array_api/searching_functions.py index dfe3b318..0b8991b7 100644 --- a/cubed/array_api/searching_functions.py +++ b/cubed/array_api/searching_functions.py @@ -1,5 +1,3 @@ -import numpy as np - from cubed.array_api.data_type_functions import result_type from cubed.array_api.dtypes import _real_numeric_dtypes from cubed.array_api.manipulation_functions import reshape @@ -14,7 +12,7 @@ def argmax(x, /, *, axis=None, keepdims=False): x = reshape(x, (-1,)) axis = 0 keepdims = False - return arg_reduction(x, np.argmax, axis=axis, keepdims=keepdims) + return arg_reduction(x, nxp.argmax, axis=axis, keepdims=keepdims) def argmin(x, /, *, axis=None, keepdims=False): @@ -24,7 +22,7 @@ def argmin(x, /, *, axis=None, keepdims=False): x = reshape(x, (-1,)) axis = 0 keepdims = False - return arg_reduction(x, np.argmin, axis=axis, keepdims=keepdims) + return arg_reduction(x, nxp.argmin, axis=axis, keepdims=keepdims) def where(condition, x1, x2, /): diff --git a/cubed/array_api/statistical_functions.py b/cubed/array_api/statistical_functions.py index 21a50d2c..9cd3c409 100644 --- a/cubed/array_api/statistical_functions.py +++ b/cubed/array_api/statistical_functions.py @@ -1,7 +1,5 @@ import math -import numpy as np - from cubed.array_api.dtypes import ( _numeric_dtypes, _real_floating_dtypes, @@ -35,7 +33,7 @@ def mean(x, /, *, axis=None, keepdims=False, use_new_impl=False, split_every=Non # this is usually OK. An alternative would be to add support for multiple # outputs. dtype = x.dtype - intermediate_dtype = [("n", np.int64), ("total", np.float64)] + intermediate_dtype = [("n", nxp.int64), ("total", nxp.float64)] extra_func_kwargs = dict(dtype=intermediate_dtype) return reduction( x, @@ -78,10 +76,10 @@ def _numel(x, **kwargs): shape = x.shape keepdims = kwargs.get("keepdims", False) axis = kwargs.get("axis", None) - dtype = kwargs.get("dtype", np.float64) + dtype = kwargs.get("dtype", nxp.float64) if axis is None: - prod = np.prod(shape, dtype=dtype) + prod = nxp.prod(shape, dtype=dtype) if keepdims is False: return prod diff --git a/cubed/core/array.py b/cubed/core/array.py index 4082ab37..70e620e6 100644 --- a/cubed/core/array.py +++ b/cubed/core/array.py @@ -1,10 +1,11 @@ from operator import mul from typing import Optional, TypeVar -import numpy as np from toolz import map, reduce from cubed import config +from cubed.backend_array_api import namespace as nxp +from cubed.backend_array_api import numpy_array_to_backend_array from cubed.runtime.types import Callback, Executor from cubed.spec import Spec, spec_from_config from cubed.storage.zarr import open_if_lazy_zarr_array @@ -115,10 +116,10 @@ def _read_stored(self): # Only works if the array has been computed if self.size > 0: # read back from zarr - return self.zarray[...] + return numpy_array_to_backend_array(self.zarray[...]) else: # this case fails for zarr, so just return an empty array of the correct shape - return np.empty(self.shape, dtype=self.dtype) + return nxp.empty(self.shape, dtype=self.dtype) def compute( self, diff --git a/cubed/core/ops.py b/cubed/core/ops.py index 5675a223..9c4e4c43 100644 --- a/cubed/core/ops.py +++ b/cubed/core/ops.py @@ -1159,13 +1159,13 @@ def _partial_reduce(arrays, reduce_func=None, initial_func=None, axis=None): assert result.keys() == reduced_chunk.keys() result = { # only need to concatenate along first axis - k: np.concatenate([result[k], reduced_chunk[k]], axis=axis[0]) + k: nxp.concat([result[k], reduced_chunk[k]], axis=axis[0]) for k in result.keys() } result = reduce_func(result, axis=axis, keepdims=True) else: # only need to concatenate along first axis - result = np.concatenate([result, reduced_chunk], axis=axis[0]) + result = nxp.concat([result, reduced_chunk], axis=axis[0]) result = reduce_func(result, axis=axis, keepdims=True) return result @@ -1173,7 +1173,7 @@ def _partial_reduce(arrays, reduce_func=None, initial_func=None, axis=None): def arg_reduction(x, /, arg_func, axis=None, *, keepdims=False): """A reduction that returns the array indexes, not the values.""" - dtype = np.int64 # index data type + dtype = nxp.int64 # index data type intermediate_dtype = [("i", dtype), ("v", x.dtype)] # initial map does arg reduction on each block, and uses block id to find the absolute index within whole array diff --git a/cubed/nan_functions.py b/cubed/nan_functions.py index e3789481..62b3535b 100644 --- a/cubed/nan_functions.py +++ b/cubed/nan_functions.py @@ -21,7 +21,7 @@ def nanmean(x, /, *, axis=None, keepdims=False): """Compute the arithmetic mean along the specified axis, ignoring NaNs.""" dtype = x.dtype - intermediate_dtype = [("n", np.int64), ("total", np.float64)] + intermediate_dtype = [("n", nxp.int64), ("total", nxp.float64)] return reduction( x, _nanmean_func, diff --git a/cubed/storage/virtual.py b/cubed/storage/virtual.py index 185681b0..a919321d 100644 --- a/cubed/storage/virtual.py +++ b/cubed/storage/virtual.py @@ -84,7 +84,7 @@ class VirtualOffsetsArray: """An array that is never materialized (in memory or on disk) and contains sequentially incrementing integers.""" def __init__(self, shape: T_Shape): - dtype = np.int32 + dtype = nxp.int32 chunks = (1,) * len(shape) # use an empty in-memory Zarr array as a template since it normalizes its properties template = zarr.empty( diff --git a/cubed/tests/test_random.py b/cubed/tests/test_random.py index e3b7f904..ed8ca854 100644 --- a/cubed/tests/test_random.py +++ b/cubed/tests/test_random.py @@ -6,6 +6,7 @@ import cubed import cubed.array_api as xp import cubed.random +from cubed.backend_array_api import namespace as nxp from cubed.tests.utils import MAIN_EXECUTORS @@ -25,7 +26,7 @@ def test_random(spec, executor): assert a.shape == (10, 10) assert a.chunks == ((4, 4, 2), (5, 5)) - x = set(a.compute(executor=executor).flat) + x = nxp.unique_values(a.compute(executor=executor)) assert len(x) > 90 @@ -35,7 +36,7 @@ def test_random_add(spec, executor): c = xp.add(a, b) - x = set(c.compute(executor=executor).flat) + x = nxp.unique_values(c.compute(executor=executor)) assert len(x) > 90