Support nan-ops for object-typed arrays #1883
Conversation
I think this is ready for review, but some API decisions would be needed.
xarray/core/duck_array_ops.py (outdated diff)

    data = -1 if valid_count == 0 else int(data)
    return np.array(data)  # return 0d-array
    # convert all nan part axis to nan
    return where_method(data, valid_count != 0, -1)
In numpy, nanargmin raises ValueError if there is an all-NaN slice/axis. Do we follow this? Considering our getitem_with_mask method, I think it is consistent to return -1 for such a case, but it could be confusing sometimes.
Edit: Also, this function is now called only for object-type. If we adopt the above API, we may need to update the numeric case too.
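For reference, numpy's current behaviour can be checked directly (a quick sketch):

```python
import numpy as np

# nanargmin works when at least one value in the slice is non-NaN...
row = np.array([np.nan, 2.0, 1.0])
print(np.nanargmin(row))  # -> 2

# ...but an all-NaN slice raises ValueError rather than returning -1:
try:
    np.nanargmin(np.array([np.nan, np.nan]))
except ValueError as err:
    print(err)  # "All-NaN slice encountered"
```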
doc/whats-new.rst (outdated diff)

    .. ipython:: python

        da = xray.DataArray(np.array([True, False, np.nan], dtype=object), dims='x')
xr / xarray?
xarray/core/duck_array_ops.py (outdated diff)

    """ In house nanmean. ddof argument will be used in _nanvar method """
    valid_count = count(value, axis=axis)
    value = fillna(value, 0.0)
    # TODO numpy's mean does not support object-type array, so we assume float
Why not use sum in this function and simply divide by valid_count - ddof instead of rescaling the mean?
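A sketch of that suggestion (the function name here is hypothetical; the PR's helpers `count` and `fillna` are replaced with plain numpy so the snippet runs standalone):

```python
import numpy as np

def nanmean_ddof_sketch(value, axis=None, ddof=0):
    """Sum the NaN-filled array and divide by the number of valid
    (non-NaN) entries minus ddof, instead of rescaling a mean."""
    valid_count = np.count_nonzero(~np.isnan(value), axis=axis)
    total = np.sum(np.where(np.isnan(value), 0.0, value), axis=axis)
    return total / (valid_count - ddof)

print(nanmean_ddof_sketch(np.array([1.0, 3.0, np.nan])))  # -> 2.0
```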
You could potentially copy at least part of the implementation from NumPy's own mean:
https://github.com/numpy/numpy/blob/e06d3614182e7b97d5e0d90291642027d147744b/numpy/core/_methods.py#L53
> Why not use sum in this function and simply divide by valid_count - ddof instead of rescaling the mean?
I feel pretty dumb now. Updated.
xarray/core/duck_array_ops.py (outdated diff)

    @@ -171,6 +171,79 @@ def _ignore_warnings_if(condition):
        yield


    def _nansum(value, axis=None, **kwargs):
Can we give these functions more explicit names, like _nansum_object?
Done
xarray/core/duck_array_ops.py (outdated diff)

    filled_value = fillna(value, fill_value)
    data = _dask_or_eager_func(func)(filled_value, axis=axis, **kwargs)
    if not hasattr(data, 'dtype'):  # scalar case
        data = np.nan if valid_count == 0 else data
Instead of passing in the fill_value, let's figure it out from dtypes.maybe_promote().
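A rough sketch of the idea (dtypes.maybe_promote is xarray-internal; it returns a promoted dtype together with a matching fill value -- the real helper handles more cases, e.g. datetimes):

```python
import numpy as np

def maybe_promote_sketch(dtype):
    """Simplified stand-in for xarray's dtypes.maybe_promote: pick a
    dtype that can represent missing values, plus its fill value."""
    if np.issubdtype(dtype, np.floating):
        return dtype, np.nan
    if np.issubdtype(dtype, np.integer):
        # promote so that NaN is representable
        return np.dtype(float), np.nan
    return np.dtype(object), np.nan

promoted, fill = maybe_promote_sketch(np.dtype('int64'))
print(promoted)  # float64
```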
xarray/core/duck_array_ops.py (outdated diff)

    _nan_funcs = {'sum': _nansum,
                  'min': partial(_nan_minmax, 'min', np.inf),
I am concerned that this will break on arrays of strings. On Python 2, the code probably works (but gives an incorrect result), but on Python 3, np.inf > 'abc' raises a TypeError.
Given that these are only used for object arrays, maybe we should use special objects for this instead, e.g.,

    import functools

    @functools.total_ordering
    class AlwaysLessThan(object):
        def __lt__(self, other):
            return True

        def __eq__(self, other):
            return isinstance(other, type(self))
We should probably also add some unit tests for object arrays of strings/NaN (probably do this first!). Currently these raise an error, but I think this code could fix them:

    >>> xr.DataArray(np.array([np.nan, 'foo'], dtype=object)).min()
    TypeError: '<=' not supported between instances of 'float' and 'str'
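A self-contained check that such a sentinel compares cleanly against strings (the class is repeated here so the snippet runs standalone):

```python
import functools

import numpy as np

@functools.total_ordering
class AlwaysLessThan(object):
    """Sentinel that sorts before any other value."""
    def __lt__(self, other):
        return True

    def __eq__(self, other):
        return isinstance(other, type(self))

sentinel = AlwaysLessThan()
print(sentinel < 'abc')  # True -- no TypeError, unlike np.inf < 'abc'

# Filling NaNs with the sentinel makes min() well defined on mixed object arrays:
values = np.array([sentinel if isinstance(v, float) and np.isnan(v) else v
                   for v in [np.nan, 'foo', 'bar']], dtype=object)
print(values.min() is sentinel)  # True
```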
xarray/core/duck_array_ops.py (outdated diff)

    kwargs_mean.pop('keepdims', None)
    value_mean = _nanmean_ddof(ddof=0, value=value, axis=axis, keepdims=True,
                               **kwargs_mean)
    squared = _dask_or_eager_func('square')(value.astype(value_mean.dtype) -
You could potentially just use the operator ** 2 instead of the dask_or_eager_func here.
Done
xarray/core/duck_array_ops.py (outdated diff)

    dtype = kwargs.get('dtype', None)
    value = fillna(value, 0)
    # As dtype inference is impossible for object dtype, we assume float
    dtype = kwargs.pop('dtype', None)
    if dtype is None and value.dtype.kind == 'O':
        dtype = value.dtype if value.dtype.kind in ['cf'] else float
Is there a good workaround to infer the output dtype of an object-typed array? We need to pass this to dask for the next division, but dtype=object is not allowed.
Is this fixed by your dask PR dask/dask#3137?
If so, can we maybe say this requires using the latest dask release?
xarray/tests/test_duck_array_ops.py (outdated diff)

    @pytest.mark.parametrize('dim_num', [1, 2])
    @pytest.mark.parametrize('dtype', [float, int, np.float32, np.bool_, str])
Added a test for str-type
xarray/core/dtypes.py (outdated diff)

    @@ -40,7 +64,7 @@ def maybe_promote(dtype):
        return np.dtype(dtype), fill_value


    -def get_fill_value(dtype):
    +def get_fill_value(dtype, fill_value_typ=None):
Can we make separate functions for this, maybe get_pos_infinity and get_neg_infinity? It feels a little strange to put it all in one function, and with separate functions you can avoid the need to validate the fill_value_typ argument.
xarray/core/duck_array_ops.py
Outdated
if isinstance(value, dask_array_type): | ||
data = data.astype(int) | ||
if not hasattr(data, 'dtype'): # scalar case | ||
# TODO should we raise ValueError if all-nan slice encountered? |
For consistency with nanargmin(), we probably should still raise ValueError('All-NaN slice encountered') for now. -1 would make sense, but it would need to be documented. NaN could also make sense, but would not be so useful since floats are not valid indexers.
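A minimal sketch of what that would look like for a 1-d object array (the function name is hypothetical; the PR's actual helper works through _dask_or_eager_func):

```python
import numpy as np

def _is_nan(v):
    return isinstance(v, float) and np.isnan(v)

def nanargmin_object_sketch(value):
    """Mirror numpy's behaviour: raise on an all-NaN slice instead of
    returning -1."""
    if all(_is_nan(v) for v in value):
        raise ValueError('All-NaN slice encountered')
    # fill NaNs with +inf so they never win the argmin
    filled = np.array([np.inf if _is_nan(v) else v for v in value])
    return int(np.argmin(filled))

print(nanargmin_object_sketch(np.array([np.nan, 3, 1], dtype=object)))  # -> 2
```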
xarray/core/dtypes.py (outdated diff)

    -------
    fill_value : positive infinity value corresponding to this dtype.
    """
    if np.issubdtype(dtype, np.floating):
I think we want:

    issubclass(dtype.type, (np.floating, np.integer)) -> np.inf
    issubclass(dtype.type, np.complexfloating) -> np.inf + 1j * np.inf

Using np.inf for integer types should be faster, since it doesn't require comparing everything as objects. And I think we need np.inf + 1j * np.inf to match numpy's sort order for complex values.
It's better to use issubclass with dtype.type because np.issubdtype has some weird (deprecated) fallback rules: https://github.com/numpy/numpy/blob/v1.14.0/numpy/core/numerictypes.py#L699-L758
xarray/core/duck_array_ops.py
Outdated
|
||
def _nan_minmax_object(func, fill_value_typ, value, axis=None, **kwargs): | ||
""" In house nanmin and nanmax for object array """ | ||
if fill_value_typ == '+inf': |
Nit: Instead of passing a separate string, we might just pass the function that makes the fill value directly (dtypes.get_pos_infinity or dtypes.get_neg_infinity). That would let us drop these conditionals and the error-prone string matching.
# Conflicts: # xarray/core/dtypes.py
… it raises ValueError in argmin/argmax.
Looks good to me. Feel free to merge...
whats-new.rst for all changes and api.rst for new API

I am working to add aggregation ops for object-typed arrays, which may make #1837 cleaner.
I added some tests, but they may not be sufficient. Are there any other cases that should be considered, e.g. [True, 3.0, np.nan], etc.?