Namespace-aware xarray.ufuncs #9776

Open. Wants to merge 8 commits into base: main.

@slevang (Contributor) commented Nov 13, 2024

Re-implement the old xarray.ufuncs module to allow generic ufunc handling for array types that don't implement __array_ufunc__:

import jax.numpy as jnp
import numpy as np
import xarray as xr
import xarray.ufuncs as xu

x = xr.DataArray(jnp.asarray([1, 2, 3]))
print(type(xu.sin(x).data))
print(type(np.sin(x).data))

# <class 'jaxlib.xla_extension.ArrayImpl'>
# <class 'numpy.ndarray'>

elif isinstance(obj, array_type("dask")):
    _walk_array_namespaces(obj._meta, namespaces)
else:
    namespace = getattr(obj, "__array_namespace__", None)
Contributor Author

Should we ever prioritize dispatching with np.func via __array_ufunc__ (if it exists) over the library's __array_namespace__().func?

Member

__array_ufunc__ is a more generic protocol, intended to support arbitrary new ufuncs without requiring an array library to be aware of them.

In practice there are very few examples of ufuncs defined outside of NumPy itself, and we wouldn't need to support them here because we are explicitly listing supported ufuncs in this module. I guess the one example that comes to mind would be the rare cases where NumPy adds a new ufunc and an array wrapping library like Dask hasn't written a wrapper yet.

At this point, I think going "all in" on __array_namespace__ is the right call.
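
To make the namespace-first approach concrete, here is a minimal sketch of dispatching a named function through `__array_namespace__`. It is an illustration only, not the code in this PR: `ToyArray`, `toy_xp`, and `dispatch` are hypothetical names, and the toy namespace is built on the standard library rather than a real array library.

```python
import math
import types

# Hypothetical stand-in for an array library's namespace module.
toy_xp = types.ModuleType("toy_xp")


class ToyArray:
    """Hypothetical duck array that only implements __array_namespace__."""

    def __init__(self, values):
        self.values = list(values)

    def __array_namespace__(self, api_version=None):
        return toy_xp


def _toy_sin(x):
    # Elementwise sine that preserves the duck array type.
    return ToyArray(math.sin(v) for v in x.values)


toy_xp.sin = _toy_sin


def dispatch(name, x):
    """Look up `name` in the namespace reported by `x` and call it."""
    xp = x.__array_namespace__()
    return getattr(xp, name)(x)


result = dispatch("sin", ToyArray([0.0, math.pi / 2]))
print(type(result).__name__, result.values)
```

Because the lookup goes through the array's own namespace, the result keeps the duck array type and never has to pass through NumPy.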

Contributor Author

I guess my other comment would be the main reason to consider __array_ufunc__. Some duck arrays don't implement all ufuncs. So either of these approaches would solve the same problem.

xarray/ufuncs.py Outdated
    )
    func = getattr(np, self._name)

return xr.apply_ufunc(func, *args, dask="parallelized", **kwargs)
Contributor Author

Is there ever a reason to use dask's ufuncs with dask="allowed" instead of the appropriate _meta array's namespace and dask="parallelized"? With jax for example, which doesn't have __array_ufunc__, this ends up converting to numpy. So it would have to be special cased.

Member

In user code using xr.apply_ufunc there is: dask='allowed' can be used to rechunk along a core dimension, e.g. by applying a dask reduction ufunc along that dimension. Not sure if that's relevant here though.

Contributor Author

These are all elementwise so no core dimensions

Member

> Is there ever a reason to use dask's ufuncs with dask="allowed" instead of the appropriate _meta array's namespace and dask="parallelized"?

Yes, this feels like a cleaner solution to me.

> With jax for example, which doesn't have __array_ufunc__, this ends up converting to numpy. So it would have to be special cased.

Is the concern here Dask wrapping JAX?

Generally, I think it's best for xarray to avoid introspecting wrapped array types and to leave nested wrapping to other libraries, which better understand their own implementation details. So Dask wrapping JAX should be fixed in Dask.

Contributor Author

> So Dask wrapping JAX should be fixed in Dask.

Totally fair. It looks like basically the same effort would then be required in dask, because dask's ufuncs are all simple wrappers around the numpy versions, so they aren't aware of the namespace.

xarray/ufuncs.py Outdated


# Auto generate from the public numpy ufuncs
np_ufuncs = {name for name in dir(np) if isinstance(getattr(np, name), np.ufunc)}
Contributor Author

I can hard-code these if preferred.

Member

Yes, I would suggest hard coding these if possible, ideally as something like:

sin = _unary_ufunc('sin')
add = _binary_ufunc('add')
...

The reason to prefer this is that static type checkers do not evaluate loops or dynamically defined functions, so otherwise xarray.ufuncs.sin() will not be recognized as valid by tools like mypy.

Ideally (could be done in a follow-up PR), these functions could be annotated to follow the appropriate type casting rules, so type checkers would know that xarray.ufuncs.sin(dataset) returns another Dataset.

In some cases, we use a script to generate all these special methods. I don't think that should be necessary here, but it still may be worth a look to understand the type casting rules:
https://github.com/pydata/xarray/blob/d5f84dd1ef4c023cf2ea0a38866c9d9cd50487e7/xarray/util/generate_ops.py
https://github.com/pydata/xarray/blob/d5f84dd1ef4c023cf2ea0a38866c9d9cd50487e7/xarray/core/_typed_ops.py

xarray/ufuncs.py Outdated

# Auto generate from the public numpy ufuncs
np_ufuncs = {name for name in dir(np) if isinstance(getattr(np, name), np.ufunc)}
excluded_ufuncs = {"divmod", "frexp", "isnat", "matmul", "modf", "vecdot"}
Contributor Author

These are the ones that didn't immediately work. There are also other ufunc-like functions that aren't technically np.ufunc instances that we could add. I saw that angle and iscomplex were special-cased before.

Member

Maybe worth noting that matmul and vecdot don't work because they are "generalized ufuncs" that use core dimensions.

divmod, frexp and modf don't work because they return multiple arrays.

I'm not sure why isnat didn't work for you. Did you test it with datetime dtypes?
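
The multiple-return case can be illustrated with a standard-library analogue. In this sketch `math.modf` stands in for a two-output ufunc, and `apply_elementwise` and `apply_multi_output` are hypothetical helpers, not xarray or NumPy functions. A wrapper that assumes one output per input yields a single list of tuples, whereas a multi-output function needs its results split into one array per output.

```python
import math


def apply_elementwise(func, values):
    """Apply `func` to each element, assuming one output per input."""
    return [func(v) for v in values]


def apply_multi_output(func, values):
    """Apply `func` and split its tuple results into one list per output."""
    results = [func(v) for v in values]
    return tuple(list(column) for column in zip(*results))


values = [1.5, 2.25]

# Single-output function: one result list, as the wrapper expects.
single = apply_elementwise(math.sqrt, [4.0, 9.0])

# Two-output function through the single-output path: a list of tuples.
naive = apply_elementwise(math.modf, values)

# Two-output function handled explicitly: two separate result lists.
frac, whole = apply_multi_output(math.modf, values)

print(single, naive, frac, whole)
```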

Contributor Author

> Did you test it with datetime dtypes?

No, this was just a real quick initial pass. Will add this with a special test case.

Do you have an opinion about adding any of the ones with multiple return values? Seems low priority to me.

Same question for the oddballs like angle, iscomplex, isreal, etc.?

@slevang marked this pull request as ready for review November 13, 2024 14:17
@TomNicholas added the "topic-arrays" (related to flexible array support) and "array API standard" (Support for the Python array API standard) labels Nov 14, 2024
@dcherian requested a review from keewis November 15, 2024 16:23
xarray/ufuncs.py Outdated
Comment on lines 55 to 60
if func is None:
    warnings.warn(
        f"Function {self._name} not found in {xp.__name__}, falling back to numpy",
        stacklevel=2,
    )
    func = getattr(np, self._name)
Member

I would lean towards skipping this fall-back, unless there are particularly motivating cases.

Contributor Author

Motivation here would be duck arrays that implement __array_ufunc__ and don't implement the full suite of numpy ufuncs. I ran into this with sparse. Not sure the full delta list, but I see they don't have sin/cos for example. In this case, np.cos(x_sparse) works but xp.cos(x_sparse) fails, which is a little weird. Not the most elegant solution though, I agree.
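
The fallback under discussion can be sketched in isolation. This is a simplified illustration under stated assumptions: `lookup` and `toy_xp` are hypothetical names, and the standard library `math` module plays the role that numpy plays in the PR, so the example runs without any array library installed.

```python
import math
import types
import warnings

# Hypothetical namespace that implements only part of the function suite.
toy_xp = types.ModuleType("toy_xp")
toy_xp.sqrt = math.sqrt


def lookup(name, xp, fallback=math):
    """Return `xp.<name>` if it exists, otherwise warn and use `fallback`."""
    func = getattr(xp, name, None)
    if func is None:
        warnings.warn(
            f"Function {name} not found in {xp.__name__}, "
            f"falling back to {fallback.__name__}",
            stacklevel=2,
        )
        func = getattr(fallback, name)
    return func


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    found = lookup("sqrt", toy_xp)      # present in the namespace: no warning
    fell_back = lookup("cos", toy_xp)   # missing: warns and falls back

print(found(4.0), fell_back(0.0), len(caught))
```

Dropping the fallback, as suggested above, would amount to letting the missing attribute raise instead of warning.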

Member

I think it's intentional that xp.cos(x_sparse) fails, because cos(0) != 0, so the result is no longer sparse.

Contributor Author

Actually I was wrong, there is a sparse.cos and a bunch of others, although they don't appear in the API docs. It seems sparse's general approach to these is to compute elementwise on the valid data, and then modify the fill_value as required for the empty regions.

There are still 40-some functions that fail without this fallback, although generally more niche. With the fallback, all work and output a sparse array (no auto densification):

absolute, arccos, arccosh, arcsin, arcsinh, arctan, arctan2, arctanh, bitwise_count, cbrt, conj, conjugate, copysign, deg2rad, degrees, exp2, expm1, fabs, float_power, fmax, fmin, fmod, gcd, heaviside, hypot, invert, isreal, lcm, ldexp, left_shift, logaddexp2, maximum, minimum, mod, nextafter, power, rad2deg, radians, reciprocal, right_shift, rint, signbit, spacing, true_divide
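
The fill_value behaviour described above (compute on the stored data, then update the fill value) can be sketched with a toy container. `ToySparse` is a hypothetical class written for this illustration and is not the sparse library's API; it only shows why a function with cos(0) != 0 can still return a sparse result.

```python
import math


class ToySparse:
    """Hypothetical 1-D sparse vector: explicit entries plus a fill value."""

    def __init__(self, size, data, fill_value=0.0):
        self.size = size
        self.data = dict(data)        # index -> explicitly stored value
        self.fill_value = fill_value  # value of every index not in `data`

    def map(self, func):
        # Apply `func` to the stored entries and to the fill value, so the
        # number of stored entries does not grow.
        mapped = {i: func(v) for i, v in self.data.items()}
        return ToySparse(self.size, mapped, func(self.fill_value))

    def todense(self):
        return [self.data.get(i, self.fill_value) for i in range(self.size)]


x = ToySparse(5, {1: math.pi})
y = x.map(math.cos)

print(y.fill_value, len(y.data), y.todense())
```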

xarray/ufuncs.py Outdated
Comment on lines 24 to 25
elif isinstance(obj, array_type("dask")):
    _walk_array_namespaces(obj._meta, namespaces)
Member

Why not use Dask's __array_namespace__ instead? That feels a little cleaner than special case logic for dask.array.
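
A namespace walk without wrapper-specific branches might look like the sketch below. Everything here is hypothetical: `Duck`, `Wrapper`, `walk_array_namespaces`, and the `xp_a`/`xp_b` modules are toy stand-ins, and a plain dict stands in for a container of arrays. The point is only that a wrapping array is asked for its own namespace and its inner array is never inspected.

```python
import types

# Hypothetical namespace modules for two different array libraries.
xp_a = types.ModuleType("xp_a")
xp_b = types.ModuleType("xp_b")


class Duck:
    """Hypothetical duck array that reports a fixed namespace."""

    def __init__(self, xp):
        self._xp = xp

    def __array_namespace__(self, api_version=None):
        return self._xp


class Wrapper(Duck):
    """Hypothetical wrapping array: reports its own namespace, not the inner one."""

    def __init__(self, xp, inner):
        super().__init__(xp)
        self.inner = inner


def walk_array_namespaces(obj, namespaces):
    """Collect the namespaces of all arrays found in a nested dict."""
    if isinstance(obj, dict):
        for value in obj.values():
            walk_array_namespaces(value, namespaces)
    else:
        get_namespace = getattr(obj, "__array_namespace__", None)
        if get_namespace is not None:
            namespaces.add(get_namespace())
    return namespaces


wrapped_only = walk_array_namespaces({"w": Wrapper(xp_b, Duck(xp_a))}, set())
mixed = walk_array_namespaces(
    {"a": Duck(xp_a), "group": {"w": Wrapper(xp_b, Duck(xp_a))}, "scalar": 1.0},
    set(),
)
print(sorted(ns.__name__ for ns in mixed))
```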

xarray/ufuncs.py Outdated
Comment on lines 14 to 16
if isinstance(obj, xr.DataTree):
    for node in obj.subtree:
        _walk_array_namespaces(node.dataset, namespaces)
Member

This is the right thing to do, but it may be worth noting that apply_ufunc does not work for DataTree yet: #9789

Contributor Author

Yep I noticed that, will make a note

Labels
array API standard: Support for the Python array API standard
topic-arrays: related to flexible array support
Development

Successfully merging this pull request may close these issues.

Compatibility with the Array API standard
3 participants