-
-
Notifications
You must be signed in to change notification settings - Fork 1.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Namespace-aware xarray.ufuncs
#9776
base: main
Are you sure you want to change the base?
Conversation
elif isinstance(obj, array_type("dask")): | ||
_walk_array_namespaces(obj._meta, namespaces) | ||
else: | ||
namespace = getattr(obj, "__array_namespace__", None) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should we ever prioritize dispatching with np.func
via __array_ufunc__
(if it exists) over the library's __array_namespace__().func
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
__array_ufunc__
is a more generic protocol, intended to support arbitrary new ufuncs without requiring an array library to be aware of them.
In practice there are very few examples of ufuncs defined outside of NumPy itself, and we wouldn't need to support them here because we are explicitly listing supported ufuncs in this module. I guess the one example that comes to mind would be the rare cases where NumPy adds a new ufunc and an array wrapping library like Dask hasn't written a wrapper yet.
At this point, I think going "all in" on __array_namespace__
is the right call.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I guess my other comment would be the main reason to consider __array_ufunc__
. Some duck arrays don't implement all ufuncs. So either of these approaches would solve the same problem.
xarray/ufuncs.py
Outdated
) | ||
func = getattr(np, self._name) | ||
|
||
return xr.apply_ufunc(func, *args, dask="parallelized", **kwargs) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is there ever a reason to use dask's ufuncs with dask="allowed"
instead of the appropriate _meta
array's namespace and dask="parallelized"
? With jax
for example, which doesn't have __array_ufunc__
, this ends up converting to numpy
. So it would have to be special cased.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In user code using xr.apply_ufunc
there is - dask='allowed'
can be used to rechunk along a core dimension e.g. by applying a dask reduction ufunc along that dimension. Not sure if that's relevant here though.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These are all elementwise so no core dimensions
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is there ever a reason to use dask's ufuncs with dask="allowed" instead of the appropriate _meta array's namespace and dask="parallelized"?
Yes, this feels like a cleaner solution to me.
With jax for example, which doesn't have array_ufunc, this ends up converting to numpy. So it would have to be special cased
Is the concern here Dask wrapping JAX?
Generally, I think it's best for xarray to avoid introspecting into wrapped array types in Xarray, and leave nested wrapping up to other libraries, which can better understand their own implementation details. So Dask wrapping JAX should be fixed in Dask.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So Dask wrapping JAX should be fixed in Dask.
Totally fair. It looks like basically the same effort here would be required in dask then, because dask's ufuncs are all simple wrappers around the numpy version so they aren't aware of the namespace.
xarray/ufuncs.py
Outdated
|
||
|
||
# Auto generate from the public numpy ufuncs | ||
np_ufuncs = {name for name in dir(np) if isinstance(getattr(np, name), np.ufunc)} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can hard code these if preferred?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, I would suggest hard coding these if possible, ideally as something like:
sin = _unary_ufunc('sin')
add = _binary_ufunc('add')
...
The reason to prefer this is that static typing does not evaluate loops or dynamically defined functions. So otherwise xarray.ufuncs.sin()
will not be recognized as valid by tools like mypy.
Ideally (could be done in a follow-up PR), these functions could be annotated to follow the appropriate type casting rules, so type checkers would know that xarray.ufuncs.sin(dataset)
returns another Dataset.
In some cases, we use a script to generate all these special methods. I don't think that should be necessary here, but it still may be worth a look to understand the type casting rules:
https://github.com/pydata/xarray/blob/d5f84dd1ef4c023cf2ea0a38866c9d9cd50487e7/xarray/util/generate_ops.py
https://github.com/pydata/xarray/blob/d5f84dd1ef4c023cf2ea0a38866c9d9cd50487e7/xarray/core/_typed_ops.py
xarray/ufuncs.py
Outdated
|
||
# Auto generate from the public numpy ufuncs | ||
np_ufuncs = {name for name in dir(np) if isinstance(getattr(np, name), np.ufunc)} | ||
excluded_ufuncs = {"divmod", "frexp", "isnat", "matmul", "modf", "vecdot"} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These are the ones that didn't immediately work. There are also other ufunc like things that aren't technically np.ufunc
subclasses that we could add. I saw angle and iscomplex were special cased before.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe worth noting that the reason why matmul
and vecdot
doesn't work is that they are "generalized ufuncs" that use core dimensions.
divmod
, frexp
and modf
doesn't work because they return multiple arrays.
I'm not sure why isnat
didn't work for you. Did you test it with datetime dtypes?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Did you test it with datetime dtypes?
No, this was just a real quick initial pass. Will add this with a special test case.
Do you have an opinion about adding any of the ones with multiple return values? Seems low priority to me.
Same question for the odd balls like angle
, iscomplex
, isreal
, etc?
xarray/ufuncs.py
Outdated
if func is None: | ||
warnings.warn( | ||
f"Function {self._name} not found in {xp.__name__}, falling back to numpy", | ||
stacklevel=2, | ||
) | ||
func = getattr(np, self._name) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would lean towards skipping this fall-back, unless there are particularly motivating cases.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Motivation here would be duck arrays that implement __array_ufunc__
and don't implement the full suite of numpy ufuncs. I ran into this with sparse. Not sure the full delta list, but I see they don't have sin/cos for example. In this case, np.cos(x_sparse)
works but xp.cos(x_sparse)
fails, which is a little weird. Not the most elegant solution though, I agree.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it's intentional than xp.cos(x_sparse)
fails, because cos(0) != 0
, so the result is no longer sparse.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actually I was wrong, there is a sparse.cos
and a bunch of others, although they don't appear in the API docs. It seems sparse's general approach to these is compute elementwise on the valid data, and then modify the fill_value as required for the empty regions.
There are still 40-some functions that fail without this fallback, although generally more niche. With the fallback, all work and output a sparse array (no auto densification):
absolute, arccos, arccosh, arcsin, arcsinh, arctan, arctan2, arctanh, bitwise_count, cbrt, conj, conjugate, copysign, deg2rad, degrees, exp2, expm1x, fabs, float_power, fmax, fmin, fmod, gcd, heaviside, hypot, invert, isreal, lcm, ldexp, left_shift, logaddexp2, maximum, minimum, mod, nextafter, power, rad2deg, radians, reciprocal, right_shift, rint, signbit, spacing, true_divide
xarray/ufuncs.py
Outdated
elif isinstance(obj, array_type("dask")): | ||
_walk_array_namespaces(obj._meta, namespaces) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why not sure Dask's __array_namespace__
instead? That feels a little cleaner than special case logic for dask.array.
xarray/ufuncs.py
Outdated
) | ||
func = getattr(np, self._name) | ||
|
||
return xr.apply_ufunc(func, *args, dask="parallelized", **kwargs) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is there ever a reason to use dask's ufuncs with dask="allowed" instead of the appropriate _meta array's namespace and dask="parallelized"?
Yes, this feels like a cleaner solution to me.
With jax for example, which doesn't have array_ufunc, this ends up converting to numpy. So it would have to be special cased
Is the concern here Dask wrapping JAX?
Generally, I think it's best for xarray to avoid introspecting into wrapped array types in Xarray, and leave nested wrapping up to other libraries, which can better understand their own implementation details. So Dask wrapping JAX should be fixed in Dask.
xarray/ufuncs.py
Outdated
|
||
|
||
# Auto generate from the public numpy ufuncs | ||
np_ufuncs = {name for name in dir(np) if isinstance(getattr(np, name), np.ufunc)} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, I would suggest hard coding these if possible, ideally as something like:
sin = _unary_ufunc('sin')
add = _binary_ufunc('add')
...
The reason to prefer this is that static typing does not evaluate loops or dynamically defined functions. So otherwise xarray.ufuncs.sin()
will not be recognized as valid by tools like mypy.
Ideally (could be done in a follow-up PR), these functions could be annotated to follow the appropriate type casting rules, so type checkers would know that xarray.ufuncs.sin(dataset)
returns another Dataset.
In some cases, we use a script to generate all these special methods. I don't think that should be necessary here, but it still may be worth a look to understand the type casting rules:
https://github.com/pydata/xarray/blob/d5f84dd1ef4c023cf2ea0a38866c9d9cd50487e7/xarray/util/generate_ops.py
https://github.com/pydata/xarray/blob/d5f84dd1ef4c023cf2ea0a38866c9d9cd50487e7/xarray/core/_typed_ops.py
xarray/ufuncs.py
Outdated
if isinstance(obj, xr.DataTree): | ||
for node in obj.subtree: | ||
_walk_array_namespaces(node.dataset, namespaces) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is the right thing to do, but it may be worth noting that apply_ufunc
does not work for DataTree yet: #9789
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yep I noticed that, will make a note
elif isinstance(obj, array_type("dask")): | ||
_walk_array_namespaces(obj._meta, namespaces) | ||
else: | ||
namespace = getattr(obj, "__array_namespace__", None) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
__array_ufunc__
is a more generic protocol, intended to support arbitrary new ufuncs without requiring an array library to be aware of them.
In practice there are very few examples of ufuncs defined outside of NumPy itself, and we wouldn't need to support them here because we are explicitly listing supported ufuncs in this module. I guess the one example that comes to mind would be the rare cases where NumPy adds a new ufunc and an array wrapping library like Dask hasn't written a wrapper yet.
At this point, I think going "all in" on __array_namespace__
is the right call.
xarray/ufuncs.py
Outdated
|
||
# Auto generate from the public numpy ufuncs | ||
np_ufuncs = {name for name in dir(np) if isinstance(getattr(np, name), np.ufunc)} | ||
excluded_ufuncs = {"divmod", "frexp", "isnat", "matmul", "modf", "vecdot"} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe worth noting that the reason why matmul
and vecdot
doesn't work is that they are "generalized ufuncs" that use core dimensions.
divmod
, frexp
and modf
doesn't work because they return multiple arrays.
I'm not sure why isnat
didn't work for you. Did you test it with datetime dtypes?
whats-new.rst
api.rst
Re-implement the old
xarray.ufuncs
module to allow generic ufunc handling for array types that don't implement__array_ufunc__
: