Skip to content
forked from pydata/xarray

Commit

Permalink
Drop nans in grouped variable.
Browse files Browse the repository at this point in the history
  • Loading branch information
dcherian committed Oct 16, 2019
1 parent 3f9069b commit 5bf94a8
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 7 deletions.
7 changes: 7 additions & 0 deletions xarray/core/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,13 @@ def __init__(
group_indices = [slice(i, i + 1) for i in group_indices]
unique_coord = group
else:
if group.isnull().any():
# drop any NaN valued groups.
# also drop obj values where group was NaN
# Use where instead of reindex to account for duplicate coordinate labels.
obj = obj.where(group.notnull(), drop=True)
group = group.dropna(group_dim)

# look through group to find the unique values
unique_values, group_indices = unique_value_groups(
safe_cast_to_index(group), sort=(bins is None)
Expand Down
67 changes: 60 additions & 7 deletions xarray/tests/test_groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import xarray as xr
from xarray.core.groupby import _consolidate_slices

from . import assert_identical, raises_regex
from . import assert_equal, assert_identical, raises_regex


def test_consolidate_slices():
Expand Down Expand Up @@ -40,14 +40,14 @@ def test_multi_index_groupby_apply():
{"foo": (("x", "y"), np.random.randn(3, 4))},
{"x": ["a", "b", "c"], "y": [1, 2, 3, 4]},
)
doubled = 2 * ds
group_doubled = (
expected = 2 * ds
actual = (
ds.stack(space=["x", "y"])
.groupby("space")
.apply(lambda x: 2 * x)
.unstack("space")
)
assert doubled.equals(group_doubled)
assert_equal(expected, actual)


def test_multi_index_groupby_sum():
Expand All @@ -58,7 +58,7 @@ def test_multi_index_groupby_sum():
)
expected = ds.sum("z")
actual = ds.stack(space=["x", "y"]).groupby("space").sum("z").unstack("space")
assert expected.equals(actual)
assert_equal(expected, actual)


def test_groupby_da_datetime():
Expand All @@ -78,15 +78,15 @@ def test_groupby_da_datetime():
expected = xr.DataArray(
[3, 7], coords=dict(reference_date=reference_dates), dims="reference_date"
)
assert actual.equals(expected)
assert_equal(expected, actual)


def test_groupby_duplicate_coordinate_labels():
# fix for http://stackoverflow.com/questions/38065129
array = xr.DataArray([1, 2, 3], [("x", [1, 1, 2])])
expected = xr.DataArray([3, 3], [("x", [1, 2])])
actual = array.groupby("x").sum()
assert expected.equals(actual)
assert_equal(expected, actual)


def test_groupby_input_mutation():
Expand Down Expand Up @@ -255,6 +255,59 @@ def test_groupby_repr_datetime(obj):
assert actual == expected


def test_groupby_drops_nans():
# GH2383
# nan in 2D data variable (requires stacking)
ds = xr.Dataset(
{
"variable": (("lat", "lon", "time"), np.arange(60.0).reshape((4, 3, 5))),
"id": (("lat", "lon"), np.arange(12.0).reshape((4, 3))),
},
coords={"lat": np.arange(4), "lon": np.arange(3), "time": np.arange(5)},
)

ds["id"].values[0, 0] = np.nan
ds["id"].values[3, 0] = np.nan
ds["id"].values[-1, -1] = np.nan

grouped = ds.groupby(ds.id)

# non reduction operation
expected = ds.copy()
expected.variable.values[0, 0, :] = np.nan
expected.variable.values[-1, -1, :] = np.nan
expected.variable.values[3, 0, :] = np.nan
actual = grouped.apply(lambda x: x).transpose(*ds.variable.dims)
assert_identical(actual, expected)

# reduction along grouped dimension
actual = grouped.mean()
stacked = ds.stack({"xy": ["lat", "lon"]})
expected = (
stacked.variable.where(stacked.id.notnull()).rename({"xy": "id"}).to_dataset()
)
expected["id"] = stacked.id.values
assert_identical(actual, expected.dropna("id").transpose(*actual.dims))

# reduction operation along a different dimension
actual = grouped.mean("time")
expected = ds.mean("time").where(ds.id.notnull())
assert_identical(actual, expected)

# NaN in non-dimensional coordinate
array = xr.DataArray([1, 2, 3], [("x", [1, 2, 3])])
array["x1"] = ("x", [1, 1, np.nan])
expected = xr.DataArray(3, [("x1", [1])])
actual = array.groupby("x1").sum()
assert_equal(expected, actual)

# test for repeated coordinate labels
array = xr.DataArray([0, 1, 2, 4, 3, 4], [("x", [np.nan, 1, 1, np.nan, 2, np.nan])])
expected = xr.DataArray([3, 3], [("x", [1, 2])])
actual = array.groupby("x").sum()
assert_equal(expected, actual)


def test_groupby_grouping_errors():
dataset = xr.Dataset({"foo": ("x", [1, 1, 1])}, {"x": [1, 2, 3]})
with raises_regex(ValueError, "None of the data falls within bins with edges"):
Expand Down

0 comments on commit 5bf94a8

Please sign in to comment.