Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PR]: Enable skipna for spatial and temporal mean operations #655

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
22 changes: 22 additions & 0 deletions tests/test_spatial.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,28 @@ def test_spatial_average_for_lat_region(self):

assert result.identical(expected)

def test_spatial_average_for_lat_region_and_skipna(self):
ds = self.ds.copy(deep=True)
ds.ts[0] = np.nan

# Specifying axis as a str instead of list of str.
result = ds.spatial.average("ts", axis=["Y"], lat_bounds=(-5.0, 5), skipna=True)

expected = self.ds.copy()
expected["ts"] = xr.DataArray(
data=np.array(
[
[np.nan, np.nan, np.nan, np.nan],
[1.0, 1.0, 1.0, 1.0],
[1.0, 1.0, 1.0, 1.0],
]
),
coords={"time": expected.time, "lon": expected.lon},
dims=["time", "lon"],
)

assert result.identical(expected)

def test_spatial_average_for_domain_wrapping_p_meridian_non_cf_conventions(
self,
):
Expand Down
169 changes: 169 additions & 0 deletions tests/test_temporal.py
Original file line number Diff line number Diff line change
Expand Up @@ -520,6 +520,57 @@ def test_weighted_annual_averages(self):
assert result.ts.attrs == expected.ts.attrs
assert result.time.attrs == expected.time.attrs

def test_weighted_annual_averages_and_skipna(self):
ds = self.ds.copy(deep=True)
ds.ts[0] = np.nan

result = ds.temporal.group_average("ts", "year", skipna=True)
expected = ds.copy()
expected = expected.drop_dims("time")
expected["ts"] = xr.DataArray(
name="ts",
data=np.array([[[1]], [[2.0]]]),
coords={
"lat": expected.lat,
"lon": expected.lon,
"time": xr.DataArray(
data=np.array(
[
cftime.DatetimeGregorian(2000, 1, 1),
cftime.DatetimeGregorian(2001, 1, 1),
],
),
coords={
"time": np.array(
[
cftime.DatetimeGregorian(2000, 1, 1),
cftime.DatetimeGregorian(2001, 1, 1),
],
)
},
dims=["time"],
attrs={
"axis": "T",
"long_name": "time",
"standard_name": "time",
"bounds": "time_bnds",
},
),
},
dims=["time", "lat", "lon"],
attrs={
"test_attr": "test",
"operation": "temporal_avg",
"mode": "group_average",
"freq": "year",
"weighted": "True",
},
)

xr.testing.assert_allclose(result, expected)
assert result.ts.attrs == expected.ts.attrs
assert result.time.attrs == expected.time.attrs

@requires_dask
def test_weighted_annual_averages_with_chunking(self):
ds = self.ds.copy().chunk({"time": 2})
Expand Down Expand Up @@ -1161,6 +1212,68 @@ def test_weighted_seasonal_climatology_with_DJF(self):

xr.testing.assert_identical(result, expected)

def test_weighted_seasonal_climatology_with_DJF_and_skipna(self):
ds = self.ds.copy(deep=True)

# Replace all MAM values with np.nan.
djf_months = [3, 4, 5]
for mon in djf_months:
ds["ts"] = ds.ts.where(ds.ts.time.dt.month != mon, np.nan)

result = ds.temporal.climatology(
"ts",
"season",
season_config={"dec_mode": "DJF", "drop_incomplete_djf": True},
skipna=True,
)

expected = ds.copy()
expected = expected.drop_dims("time")
expected_time = xr.DataArray(
data=np.array(
[
cftime.DatetimeGregorian(1, 1, 1),
cftime.DatetimeGregorian(1, 4, 1),
cftime.DatetimeGregorian(1, 7, 1),
cftime.DatetimeGregorian(1, 10, 1),
],
),
coords={
"time": np.array(
[
cftime.DatetimeGregorian(1, 1, 1),
cftime.DatetimeGregorian(1, 4, 1),
cftime.DatetimeGregorian(1, 7, 1),
cftime.DatetimeGregorian(1, 10, 1),
],
),
},
attrs={
"axis": "T",
"long_name": "time",
"standard_name": "time",
"bounds": "time_bnds",
},
)
expected["ts"] = xr.DataArray(
name="ts",
data=np.ones((4, 4, 4)),
coords={"lat": expected.lat, "lon": expected.lon, "time": expected_time},
dims=["time", "lat", "lon"],
attrs={
"operation": "temporal_avg",
"mode": "climatology",
"freq": "season",
"weighted": "True",
"dec_mode": "DJF",
"drop_incomplete_djf": "True",
},
)
expected.ts[1] = np.nan

# MAM should be np.nan
assert result.identical(expected)

@requires_dask
def test_chunked_weighted_seasonal_climatology_with_DJF(self):
ds = self.ds.copy().chunk({"time": 2})
Expand Down Expand Up @@ -1947,6 +2060,62 @@ def test_weighted_seasonal_departures_with_DJF(self):

xr.testing.assert_identical(result, expected)

def test_weighted_seasonal_departures_with_DJF_and_skipna(self):
ds = self.ds.copy(deep=True)

# Replace all MAM values with np.nan.
djf_months = [3, 4, 5]
for mon in djf_months:
ds["ts"] = ds.ts.where(ds.ts.time.dt.month != mon, np.nan)

result = ds.temporal.departures(
"ts",
"season",
weighted=True,
season_config={"dec_mode": "DJF", "drop_incomplete_djf": True},
skipna=True,
)

expected = ds.copy()
expected = expected.drop_dims("time")
expected["ts"] = xr.DataArray(
name="ts",
data=np.array([[[np.nan]], [[0.0]], [[0.0]], [[0.0]]]),
coords={
"lat": expected.lat,
"lon": expected.lon,
"time": xr.DataArray(
data=np.array(
[
cftime.DatetimeGregorian(2000, 4, 1),
cftime.DatetimeGregorian(2000, 7, 1),
cftime.DatetimeGregorian(2000, 10, 1),
cftime.DatetimeGregorian(2001, 1, 1),
],
),
dims=["time"],
attrs={
"axis": "T",
"long_name": "time",
"standard_name": "time",
"bounds": "time_bnds",
},
),
},
dims=["time", "lat", "lon"],
attrs={
"test_attr": "test",
"operation": "temporal_avg",
"mode": "departures",
"freq": "season",
"weighted": "True",
"dec_mode": "DJF",
"drop_incomplete_djf": "True",
},
)

assert result.identical(expected)

def test_weighted_seasonal_departures_with_DJF_and_keep_weights(self):
ds = self.ds.copy()

Expand Down
22 changes: 18 additions & 4 deletions xcdat/spatial.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ def average(
keep_weights: bool = False,
lat_bounds: Optional[RegionAxisBounds] = None,
lon_bounds: Optional[RegionAxisBounds] = None,
skipna: Union[bool, None] = None,
) -> xr.Dataset:
"""
Calculates the spatial average for a rectilinear grid over an optionally
Expand Down Expand Up @@ -125,6 +126,11 @@ def average(
ignored if ``weights`` are supplied. The lower bound can be larger
than the upper bound (e.g., across the prime meridian, dateline), by
default None.
skipna : bool or None, optional
If True, skip missing values (as marked by NaN). By default, only
skips missing values for float dtypes; other dtypes either do not
have a sentinel missing value (int) or ``skipna=True`` has not been
implemented (object, datetime64 or timedelta64).

Returns
-------
Expand Down Expand Up @@ -196,7 +202,7 @@ def average(
self._weights = weights

self._validate_weights(dv, axis)
ds[dv.name] = self._averager(dv, axis)
ds[dv.name] = self._averager(dv, axis, skipna=skipna)

if keep_weights:
ds[self._weights.name] = self._weights
Expand Down Expand Up @@ -702,8 +708,11 @@ def _validate_weights(
)

def _averager(
self, data_var: xr.DataArray, axis: List[SpatialAxis] | Tuple[SpatialAxis, ...]
):
self,
data_var: xr.DataArray,
axis: List[SpatialAxis] | Tuple[SpatialAxis, ...],
skipna: bool | None = None,
) -> xr.DataArray:
"""Perform a weighted average of a data variable.

This method assumes all specified keys in ``axis`` exists in the data
Expand All @@ -721,6 +730,11 @@ def _averager(
Data variable inside a Dataset.
axis : List[SpatialAxis] | Tuple[SpatialAxis, ...]
List of axis dimensions to average over.
skipna : bool or None, optional
If True, skip missing values (as marked by NaN). By default, only
skips missing values for float dtypes; other dtypes either do not
have a sentinel missing value (int) or ``skipna=True`` has not been
implemented (object, datetime64 or timedelta64).

Returns
-------
Expand All @@ -739,6 +753,6 @@ def _averager(
dim.append(get_dim_keys(data_var, key))

with xr.set_options(keep_attrs=True):
weighted_mean = data_var.cf.weighted(weights).mean(dim=dim)
weighted_mean = data_var.cf.weighted(weights).mean(dim=dim, skipna=skipna)

return weighted_mean
Loading
Loading