Skip to content

Commit

Permalink
Pint support for variables (#3706)
Browse files Browse the repository at this point in the history
* get fillna tests to pass

* get the _getitem_with_mask tests to pass

* silence the behavior change warning of pint

* don't use 0 as fill value since that has special behaviour

* use concat as a class method

* use np.pad after trimming instead of concatenating a filled array

* rewrite the concat test to pass appropriate arrays

* use da.pad when dealing with dask arrays

* mark the failing pad tests as xfail when on a current pint version

* update whats-new.rst

* fix the import order

* test using pint master

* fix the install command

* reimplement the pad test to really work with units

* use np.logical_not instead

* use duck_array_ops to provide pad

* add comments explaining the order of the arguments to where

* mark the flipped parameter changes with a todo

* skip the identical tests

* remove the warnings filter
  • Loading branch information
keewis authored Feb 23, 2020
1 parent 1667e4c commit 47476eb
Show file tree
Hide file tree
Showing 5 changed files with 89 additions and 63 deletions.
1 change: 1 addition & 0 deletions ci/azure/install.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ steps:
git+https://github.com/zarr-developers/zarr \
git+https://github.com/Unidata/cftime \
git+https://github.com/mapbox/rasterio \
git+https://github.com/hgrecco/pint \
git+https://github.com/pydata/bottleneck
condition: eq(variables['UPSTREAM_DEV'], 'true')
displayName: Install upstream dev dependencies
Expand Down
2 changes: 2 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ Breaking changes

New Features
~~~~~~~~~~~~
- implement pint support. (:issue:`3594`, :pull:`3706`)
By `Justus Magin <https://github.com/keewis>`_.

Bug fixes
~~~~~~~~~
Expand Down
6 changes: 5 additions & 1 deletion xarray/core/duck_array_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ def notnull(data):
isin = _dask_or_eager_func("isin", array_args=slice(2))
take = _dask_or_eager_func("take")
broadcast_to = _dask_or_eager_func("broadcast_to")
pad = _dask_or_eager_func("pad")

_concatenate = _dask_or_eager_func("concatenate", list_of_args=True)
_stack = _dask_or_eager_func("stack", list_of_args=True)
Expand Down Expand Up @@ -261,7 +262,10 @@ def where_method(data, cond, other=dtypes.NA):


def fillna(data, other):
return where(isnull(data), other, data)
# we need to pass data first so pint has a chance of returning the
# correct unit
# TODO: revert after https://github.com/hgrecco/pint/issues/1019 is fixed
return where(notnull(data), data, other)


def concatenate(arrays, axis=0):
Expand Down
31 changes: 13 additions & 18 deletions xarray/core/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -742,7 +742,10 @@ def _getitem_with_mask(self, key, fill_value=dtypes.NA):

data = as_indexable(self._data)[actual_indexer]
mask = indexing.create_mask(indexer, self.shape, data)
data = duck_array_ops.where(mask, fill_value, data)
# we need to invert the mask in order to pass data first. This helps
# pint to choose the correct unit
# TODO: revert after https://github.com/hgrecco/pint/issues/1019 is fixed
data = duck_array_ops.where(np.logical_not(mask), data, fill_value)
else:
# array cannot be indexed along dimensions of size 0, so just
# build the mask directly instead.
Expand Down Expand Up @@ -1099,24 +1102,16 @@ def _shift_one_dim(self, dim, count, fill_value=dtypes.NA):
else:
dtype = self.dtype

shape = list(self.shape)
shape[axis] = min(abs(count), shape[axis])
width = min(abs(count), self.shape[axis])
dim_pad = (width, 0) if count >= 0 else (0, width)
pads = [(0, 0) if d != dim else dim_pad for d in self.dims]

if isinstance(trimmed_data, dask_array_type):
chunks = list(trimmed_data.chunks)
chunks[axis] = (shape[axis],)
full = functools.partial(da.full, chunks=chunks)
else:
full = np.full

filler = full(shape, fill_value, dtype=dtype)

if count > 0:
arrays = [filler, trimmed_data]
else:
arrays = [trimmed_data, filler]

data = duck_array_ops.concatenate(arrays, axis)
data = duck_array_ops.pad(
trimmed_data.astype(dtype),
pads,
mode="constant",
constant_values=fill_value,
)

if isinstance(data, dask_array_type):
# chunked data should come out with the same chunks; this makes
Expand Down
112 changes: 68 additions & 44 deletions xarray/tests/test_units.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import operator
from distutils.version import LooseVersion

import numpy as np
import pandas as pd
Expand All @@ -19,6 +20,7 @@
unit_registry = pint.UnitRegistry(force_ndarray=True)
Quantity = unit_registry.Quantity


pytestmark = [
pytest.mark.skipif(
not IS_NEP18_ACTIVE, reason="NUMPY_EXPERIMENTAL_ARRAY_FUNCTION is not enabled"
Expand Down Expand Up @@ -1536,27 +1538,17 @@ def test_missing_value_detection(self, func):
@pytest.mark.parametrize(
"unit,error",
(
pytest.param(
1,
DimensionalityError,
id="no_unit",
marks=pytest.mark.xfail(reason="uses 0 as a replacement"),
),
pytest.param(1, DimensionalityError, id="no_unit"),
pytest.param(
unit_registry.dimensionless, DimensionalityError, id="dimensionless"
),
pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"),
pytest.param(
unit_registry.cm,
None,
id="compatible_unit",
marks=pytest.mark.xfail(reason="converts to fill value's unit"),
),
pytest.param(unit_registry.cm, None, id="compatible_unit"),
pytest.param(unit_registry.m, None, id="identical_unit"),
),
)
def test_missing_value_fillna(self, unit, error):
value = 0
value = 10
array = (
np.array(
[
Expand Down Expand Up @@ -1595,13 +1587,7 @@ def test_missing_value_fillna(self, unit, error):
pytest.param(1, id="no_unit"),
pytest.param(unit_registry.dimensionless, id="dimensionless"),
pytest.param(unit_registry.s, id="incompatible_unit"),
pytest.param(
unit_registry.cm,
id="compatible_unit",
marks=pytest.mark.xfail(
reason="checking for identical units does not work properly, yet"
),
),
pytest.param(unit_registry.cm, id="compatible_unit",),
pytest.param(unit_registry.m, id="identical_unit"),
),
)
Expand All @@ -1612,7 +1598,17 @@ def test_missing_value_fillna(self, unit, error):
pytest.param(True, id="with_conversion"),
),
)
@pytest.mark.parametrize("func", (method("equals"), method("identical")), ids=repr)
@pytest.mark.parametrize(
"func",
(
method("equals"),
pytest.param(
method("identical"),
marks=pytest.mark.skip(reason="behaviour of identical is unclear"),
),
),
ids=repr,
)
def test_comparisons(self, func, unit, convert_data, dtype):
array = np.linspace(0, 1, 9).astype(dtype)
quantity1 = array * unit_registry.m
Expand Down Expand Up @@ -1762,14 +1758,7 @@ def test_1d_math(self, func, unit, error, dtype):
unit_registry.dimensionless, DimensionalityError, id="dimensionless"
),
pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"),
pytest.param(
unit_registry.cm,
None,
id="compatible_unit",
marks=pytest.mark.xfail(
reason="getitem_with_mask converts to the unit of other"
),
),
pytest.param(unit_registry.cm, None, id="compatible_unit"),
pytest.param(unit_registry.m, None, id="identical_unit"),
),
)
Expand Down Expand Up @@ -1853,12 +1842,7 @@ def test_squeeze(self, dtype):
),
method("reduce", np.std, "x"),
method("round", 2),
pytest.param(
method("shift", {"x": -2}),
marks=pytest.mark.xfail(
reason="trying to concatenate ndarray to quantity"
),
),
method("shift", {"x": -2}),
method("transpose", "y", "x"),
),
ids=repr,
Expand Down Expand Up @@ -1933,7 +1917,6 @@ def test_unstack(self, dtype):
assert_units_equal(expected, actual)
xr.testing.assert_identical(expected, actual)

@pytest.mark.xfail(reason="ignores units")
@pytest.mark.parametrize(
"unit,error",
(
Expand All @@ -1948,25 +1931,28 @@ def test_unstack(self, dtype):
)
def test_concat(self, unit, error, dtype):
array1 = (
np.linspace(0, 5, 3 * 10).reshape(3, 10).astype(dtype) * unit_registry.m
np.linspace(0, 5, 9 * 10).reshape(3, 6, 5).astype(dtype) * unit_registry.m
)
array2 = np.linspace(5, 10, 10 * 2).reshape(10, 2).astype(dtype) * unit
array2 = np.linspace(5, 10, 10 * 3).reshape(3, 2, 5).astype(dtype) * unit

variable = xr.Variable(("x", "y"), array1)
other = xr.Variable(("y", "z"), array2)
variable = xr.Variable(("x", "y", "z"), array1)
other = xr.Variable(("x", "y", "z"), array2)

if error is not None:
with pytest.raises(error):
variable.concat(other)
xr.Variable.concat([variable, other], dim="y")

return

units = extract_units(variable)
expected = attach_units(
strip_units(variable).concat(strip_units(convert_units(other, units))),
xr.Variable.concat(
[strip_units(variable), strip_units(convert_units(other, units))],
dim="y",
),
units,
)
actual = variable.concat(other)
actual = xr.Variable.concat([variable, other], dim="y")

assert_units_equal(expected, actual)
xr.testing.assert_identical(expected, actual)
Expand Down Expand Up @@ -2036,6 +2022,43 @@ def test_no_conflicts(self, unit, dtype):

assert expected == actual

def test_pad(self, dtype):
data = np.arange(4 * 3 * 2).reshape(4, 3, 2).astype(dtype) * unit_registry.m
v = xr.Variable(["x", "y", "z"], data)

xr_args = [{"x": (2, 1)}, {"y": (0, 3)}, {"x": (3, 1), "z": (2, 0)}]
np_args = [
((2, 1), (0, 0), (0, 0)),
((0, 0), (0, 3), (0, 0)),
((3, 1), (0, 0), (2, 0)),
]
for xr_arg, np_arg in zip(xr_args, np_args):
actual = v.pad_with_fill_value(**xr_arg)
expected = xr.Variable(
v.dims,
np.pad(
v.data.astype(float),
np_arg,
mode="constant",
constant_values=np.nan,
),
)
xr.testing.assert_identical(expected, actual)
assert_units_equal(expected, actual)
assert isinstance(actual._data, type(v._data))

# for the boolean array, we pad False
data = np.full_like(data, False, dtype=bool).reshape(4, 3, 2)
v = xr.Variable(["x", "y", "z"], data)
for xr_arg, np_arg in zip(xr_args, np_args):
actual = v.pad_with_fill_value(fill_value=data.flat[0], **xr_arg)
expected = xr.Variable(
v.dims,
np.pad(v.data, np_arg, mode="constant", constant_values=v.data.flat[0]),
)
xr.testing.assert_identical(actual, expected)
assert_units_equal(expected, actual)

@pytest.mark.parametrize(
"unit,error",
(
Expand All @@ -2044,7 +2067,8 @@ def test_no_conflicts(self, unit, dtype):
DimensionalityError,
id="no_unit",
marks=pytest.mark.xfail(
reason="is not treated the same as dimensionless"
LooseVersion(pint.__version__) < LooseVersion("0.10.2"),
reason="bug in pint's implementation of np.pad",
),
),
pytest.param(
Expand Down

0 comments on commit 47476eb

Please sign in to comment.