Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Delegate dtype functions to backend array API #410

Merged
merged 1 commit into from
Mar 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions .github/workflows/array-api-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -101,15 +101,14 @@ jobs:
array_api_tests/test_statistical_functions.py::test_std
array_api_tests/test_statistical_functions.py::test_var

# finfo(float32).eps returns float32 but should return float
array_api_tests/test_data_type_functions.py::test_finfo[float32]

# From https://github.com/data-apis/array-api-tests/blob/master/.github/workflows/numpy.yml

# https://github.com/numpy/numpy/issues/18881
array_api_tests/test_creation_functions.py::test_linspace

# https://github.com/numpy/numpy/issues/20870
array_api_tests/test_data_type_functions.py::test_can_cast

EOF

pytest -v -rxXfEA --max-examples=2 --disable-data-dependent-shapes --disable-extension linalg --ci --hypothesis-disable-deadline --cov=cubed.array_api --cov-report=term-missing
7 changes: 0 additions & 7 deletions cubed/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,3 @@
# Suppress numpy.array_api experimental warning
import sys
import warnings

if not sys.warnoptions:
warnings.filterwarnings("ignore", category=UserWarning)

from importlib.metadata import version as _version

try:
Expand Down
14 changes: 7 additions & 7 deletions cubed/array_api/array_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,37 +225,37 @@ def __eq__(self, other, /):
other = self._check_allowed_dtypes(other, "all", "__eq__")
if other is NotImplemented:
return other
return elemwise(nxp.equal, self, other, dtype=np.bool_)
return elemwise(nxp.equal, self, other, dtype=nxp.bool)

def __ge__(self, other, /):
other = self._check_allowed_dtypes(other, "all", "__ge__")
if other is NotImplemented:
return other
return elemwise(nxp.greater_equal, self, other, dtype=np.bool_)
return elemwise(nxp.greater_equal, self, other, dtype=nxp.bool)

def __gt__(self, other, /):
other = self._check_allowed_dtypes(other, "all", "__gt__")
if other is NotImplemented:
return other
return elemwise(nxp.greater, self, other, dtype=np.bool_)
return elemwise(nxp.greater, self, other, dtype=nxp.bool)

def __le__(self, other, /):
other = self._check_allowed_dtypes(other, "all", "__le__")
if other is NotImplemented:
return other
return elemwise(nxp.less_equal, self, other, dtype=np.bool_)
return elemwise(nxp.less_equal, self, other, dtype=nxp.bool)

def __lt__(self, other, /):
other = self._check_allowed_dtypes(other, "all", "__lt__")
if other is NotImplemented:
return other
return elemwise(nxp.less, self, other, dtype=np.bool_)
return elemwise(nxp.less, self, other, dtype=nxp.bool)

def __ne__(self, other, /):
other = self._check_allowed_dtypes(other, "all", "__ne__")
if other is NotImplemented:
return other
return elemwise(nxp.not_equal, self, other, dtype=np.bool_)
return elemwise(nxp.not_equal, self, other, dtype=nxp.bool)

# Reflected Operators

Expand Down Expand Up @@ -425,7 +425,7 @@ def _promote_scalar(self, scalar):
"Python int scalars cannot be promoted with bool arrays"
)
if self.dtype in _integer_dtypes:
info = np.iinfo(self.dtype)
info = nxp.iinfo(self.dtype)
if not (info.min <= scalar <= info.max):
raise OverflowError(
"Python int scalars must be within the bounds of the dtype for integer arrays"
Expand Down
12 changes: 6 additions & 6 deletions cubed/array_api/constants.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import numpy as np
from cubed.backend_array_api import namespace as nxp

e = np.e
inf = np.inf
nan = np.nan
newaxis = None
pi = np.pi
e = nxp.e
inf = nxp.inf
nan = nxp.nan
newaxis = nxp.newaxis
pi = nxp.pi
17 changes: 8 additions & 9 deletions cubed/array_api/creation_functions.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import math
from typing import TYPE_CHECKING, Iterable, List

import numpy as np
from zarr.util import normalize_shape

from cubed.backend_array_api import namespace as nxp
Expand Down Expand Up @@ -93,7 +92,7 @@ def empty_virtual_array(
shape, *, dtype=None, device=None, chunks="auto", spec=None, hidden=True
) -> "Array":
if dtype is None:
dtype = np.float64
dtype = nxp.float64

chunksize = to_chunksize(normalize_chunks(chunks, shape=shape, dtype=dtype))
name = gensym()
Expand All @@ -111,7 +110,7 @@ def eye(
if n_cols is None:
n_cols = n_rows
if dtype is None:
dtype = np.float64
dtype = nxp.float64

shape = (n_rows, n_cols)
chunks = normalize_chunks(chunks, shape=shape, dtype=dtype)
Expand Down Expand Up @@ -143,11 +142,11 @@ def full(
if dtype is None:
# check bool first since True/False are instances of int and float
if isinstance(fill_value, bool):
dtype = np.bool_
dtype = nxp.bool
elif isinstance(fill_value, int):
dtype = np.int64
dtype = nxp.int64
elif isinstance(fill_value, float):
dtype = np.float64
dtype = nxp.float64
else:
raise TypeError("Invalid input to full")
chunksize = to_chunksize(normalize_chunks(chunks, shape=shape, dtype=dtype))
Expand Down Expand Up @@ -194,7 +193,7 @@ def linspace(
div = 1
step = float(range_) / div
if dtype is None:
dtype = np.float64
dtype = nxp.float64
chunks = normalize_chunks(chunks, shape=(num,), dtype=dtype)
chunksize = chunks[0][0]

Expand Down Expand Up @@ -258,7 +257,7 @@ def meshgrid(*arrays, indexing="xy") -> List["Array"]:

def ones(shape, *, dtype=None, device=None, chunks="auto", spec=None) -> "Array":
if dtype is None:
dtype = np.float64
dtype = nxp.float64
return full(shape, 1, dtype=dtype, device=device, chunks=chunks, spec=spec)


Expand Down Expand Up @@ -304,7 +303,7 @@ def _tri_mask(N, M, k, chunks, spec):

def zeros(shape, *, dtype=None, device=None, chunks="auto", spec=None) -> "Array":
if dtype is None:
dtype = np.float64
dtype = nxp.float64
return full(shape, 0, dtype=dtype, device=device, chunks=chunks, spec=spec)


Expand Down
125 changes: 7 additions & 118 deletions cubed/array_api/data_type_functions.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,6 @@
from dataclasses import dataclass

import numpy as np
from numpy.array_api._typing import Dtype

from cubed.backend_array_api import namespace as nxp
from cubed.core import CoreArray, map_blocks

from .dtypes import (
_all_dtypes,
_boolean_dtypes,
_complex_floating_dtypes,
_integer_dtypes,
_numeric_dtypes,
_real_floating_dtypes,
_result_type,
_signed_integer_dtypes,
_unsigned_integer_dtypes,
)


def astype(x, dtype, /, *, copy=True):
if not copy and dtype == x.dtype:
Expand All @@ -30,118 +13,24 @@ def _astype(a, astype_dtype):


def can_cast(from_, to, /):
# Copied from numpy.array_api
# TODO: replace with `nxp.can_cast` when NumPy 1.25 is widely used (e.g. in Xarray)

if isinstance(from_, CoreArray):
from_ = from_.dtype
elif from_ not in _all_dtypes:
raise TypeError(f"{from_=}, but should be an array_api array or dtype")
if to not in _all_dtypes:
raise TypeError(f"{to=}, but should be a dtype")
try:
# We promote `from_` and `to` together. We then check if the promoted
# dtype is `to`, which indicates if `from_` can (up)cast to `to`.
dtype = _result_type(from_, to)
return to == dtype
except TypeError:
# _result_type() raises if the dtypes don't promote together
return False


@dataclass
class finfo_object:
bits: int
eps: float
max: float
min: float
smallest_normal: float
dtype: Dtype


@dataclass
class iinfo_object:
bits: int
max: int
min: int
dtype: Dtype
return nxp.can_cast(from_, to)


def finfo(type, /):
# Copied from numpy.array_api
# TODO: replace with `nxp.finfo(type)` when NumPy 1.25 is widely used (e.g. in Xarray)

fi = np.finfo(type)
return finfo_object(
fi.bits,
float(fi.eps),
float(fi.max),
float(fi.min),
float(fi.smallest_normal),
fi.dtype,
)
return nxp.finfo(type)


def iinfo(type, /):
# Copied from numpy.array_api
# TODO: replace with `nxp.iinfo(type)` when NumPy 1.25 is widely used (e.g. in Xarray)

ii = np.iinfo(type)
return iinfo_object(ii.bits, ii.max, ii.min, ii.dtype)
return nxp.iinfo(type)


def isdtype(dtype, kind):
# Copied from numpy.array_api
# TODO: replace with `nxp.isdtype(dtype, kind)` when NumPy 1.25 is widely used (e.g. in Xarray)

if isinstance(kind, tuple):
# Disallow nested tuples
if any(isinstance(k, tuple) for k in kind):
raise TypeError("'kind' must be a dtype, str, or tuple of dtypes and strs")
return any(isdtype(dtype, k) for k in kind)
elif isinstance(kind, str):
if kind == "bool":
return dtype in _boolean_dtypes
elif kind == "signed integer":
return dtype in _signed_integer_dtypes
elif kind == "unsigned integer":
return dtype in _unsigned_integer_dtypes
elif kind == "integral":
return dtype in _integer_dtypes
elif kind == "real floating":
return dtype in _real_floating_dtypes
elif kind == "complex floating":
return dtype in _complex_floating_dtypes
elif kind == "numeric":
return dtype in _numeric_dtypes
else:
raise ValueError(f"Unrecognized data type kind: {kind!r}")
elif kind in _all_dtypes:
return dtype == kind
else:
raise TypeError(
f"'kind' must be a dtype, str, or tuple of dtypes and strs, not {type(kind).__name__}"
)
return nxp.isdtype(dtype, kind)


def result_type(*arrays_and_dtypes):
# Copied from numpy.array_api
# TODO: replace with `nxp.result_type` when NumPy 1.25 is widely used (e.g. in Xarray)

A = []
for a in arrays_and_dtypes:
if isinstance(a, CoreArray):
a = a.dtype
elif isinstance(a, np.ndarray) or a not in _all_dtypes:
raise TypeError("result_type() inputs must be array_api arrays or dtypes")
A.append(a)

if len(A) == 0:
raise ValueError("at least one array or dtype is required")
elif len(A) == 1:
return A[0]
else:
t = A[0]
for t2 in A[1:]:
t = _result_type(t, t2)
return t
return nxp.result_type(
*(a.dtype if isinstance(a, CoreArray) else a for a in arrays_and_dtypes)
)
Loading
Loading