From 7998afbda80cfe2a71d78b8dffb4259fb726ff3b Mon Sep 17 00:00:00 2001 From: tiago Date: Thu, 23 May 2024 02:32:15 +0200 Subject: [PATCH 1/8] add functionality to pad dataset data variables with unique constant values --- xarray/core/dataset.py | 24 ++++++++++++++++++------ 1 file changed, 18 insertions(+), 6 deletions(-) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 09597670573..449cf0e140d 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -9087,7 +9087,7 @@ def pad( int | tuple[int, int] | Mapping[Any, tuple[int, int]] | None ) = None, constant_values: ( - float | tuple[float, float] | Mapping[Any, tuple[float, float]] | None + float | tuple[float, float] | Mapping[Any, float | tuple[float, float] | Mapping[Any, tuple[float, float]]] | None ) = None, end_values: int | tuple[int, int] | Mapping[Any, tuple[int, int]] | None = None, reflect_type: PadReflectOptions = None, @@ -9144,9 +9144,11 @@ def pad( (stat_length,) or int is a shortcut for before = after = statistic length for all axes. Default is ``None``, to use the entire axis. - constant_values : scalar, tuple or mapping of hashable to tuple, default: 0 - Used in 'constant'. The values to set the padded values for each - axis. + constant_values : scalar, tuple, mapping of hashable to tuple or + mapping of hashable to mapping of hashable to tuple, default: 0 + Used in 'constant'. The values to set the padded values for each data variable / axis. + ``{var_1: {dim_1: (before_1, after_1), ... dim_N: (before_N, after_N)}, + var_2: (before, after) ... "*": constant}`` unique pad constants per data variable. ``{dim_1: (before_1, after_1), ... dim_N: (before_N, after_N)}`` unique pad constants along each dimension. 
``((before, after),)`` yields same before and after constants for each @@ -9231,17 +9233,27 @@ def pad( for k, idx in xindexes.items(): if not pad_dims.intersection(xindexes.get_all_dims(k)): indexes[k] = idx + + per_data_var_constant_values = {} + if isinstance(constant_values, Mapping): + for k in self.data_vars: + if v := constant_values.pop(k, None): + per_data_var_constant_values[k] = v + if global_constant_values := constant_values.pop("*", None): + assert constant_values == {}, "Conflicting constant values" + constant_values = global_constant_values + for name, var in self.variables.items(): var_pad_width = {k: v for k, v in pad_width.items() if k in var.dims} if not var_pad_width: variables[name] = var - elif name in self.data_vars: + elif name in self.data_vars: variables[name] = var.pad( pad_width=var_pad_width, mode=mode, stat_length=stat_length, - constant_values=constant_values, + constant_values=per_data_var_constant_values.get(name, constant_values), end_values=end_values, reflect_type=reflect_type, keep_attrs=keep_attrs, From 44e7d26694b2696054e7c20842f9d22562bf12c7 Mon Sep 17 00:00:00 2001 From: tiago Date: Tue, 13 Aug 2024 20:40:28 +0200 Subject: [PATCH 2/8] clean up implementation of variable specific padding for dataset. 
Add tests --- xarray/core/dataset.py | 25 ++++++++++++++----------- xarray/tests/test_dataset.py | 30 +++++++++++++++++++++++++++--- 2 files changed, 41 insertions(+), 14 deletions(-) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 449cf0e140d..e7ca34d1842 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -9087,7 +9087,12 @@ def pad( int | tuple[int, int] | Mapping[Any, tuple[int, int]] | None ) = None, constant_values: ( - float | tuple[float, float] | Mapping[Any, float | tuple[float, float] | Mapping[Any, tuple[float, float]]] | None + float + | tuple[float, float] + | Mapping[ + Any, float | tuple[float, float] | Mapping[Any, tuple[float, float]] + ] + | None ) = None, end_values: int | tuple[int, int] | Mapping[Any, tuple[int, int]] | None = None, reflect_type: PadReflectOptions = None, @@ -9144,11 +9149,11 @@ def pad( (stat_length,) or int is a shortcut for before = after = statistic length for all axes. Default is ``None``, to use the entire axis. - constant_values : scalar, tuple, mapping of hashable to tuple or + constant_values : scalar, tuple, mapping of hashable to tuple or mapping of hashable to mapping of hashable to tuple, default: 0 Used in 'constant'. The values to set the padded values for each data variable / axis. - ``{var_1: {dim_1: (before_1, after_1), ... dim_N: (before_N, after_N)}, - var_2: (before, after) ... "*": constant}`` unique pad constants per data variable. + ``{var_1: {dim_1: (before_1, after_1), ... dim_N: (before_N, after_N)}, ... + var_M: (before, after)}`` unique pad constants per data variable. ``{dim_1: (before_1, after_1), ... dim_N: (before_N, after_N)}`` unique pad constants along each dimension. 
``((before, after),)`` yields same before and after constants for each @@ -9233,27 +9238,25 @@ def pad( for k, idx in xindexes.items(): if not pad_dims.intersection(xindexes.get_all_dims(k)): indexes[k] = idx - + per_data_var_constant_values = {} if isinstance(constant_values, Mapping): for k in self.data_vars: if v := constant_values.pop(k, None): per_data_var_constant_values[k] = v - if global_constant_values := constant_values.pop("*", None): - assert constant_values == {}, "Conflicting constant values" - constant_values = global_constant_values - for name, var in self.variables.items(): var_pad_width = {k: v for k, v in pad_width.items() if k in var.dims} if not var_pad_width: variables[name] = var - elif name in self.data_vars: + elif name in self.data_vars: variables[name] = var.pad( pad_width=var_pad_width, mode=mode, stat_length=stat_length, - constant_values=per_data_var_constant_values.get(name, constant_values), + constant_values=per_data_var_constant_values.get( + name, constant_values + ), end_values=end_values, reflect_type=reflect_type, keep_attrs=keep_attrs, diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index 584776197e3..b3e83736d8f 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -6584,9 +6584,29 @@ def test_polyfit_warnings(self) -> None: ds.var1.polyfit("dim2", 10, full=True) assert len(ws) == 1 - def test_pad(self) -> None: + @pytest.mark.parametrize( + ["constant_values", "expected"], + [ + pytest.param(42, {"var1": 42}, id="numeric"), + pytest.param((42, 43), {"var1": (42, 43)}, id="tuple"), + pytest.param( + {"dim2": (42, 43)}, {"var1": (42, 43), "var2": (42, 43)}, id="per dim" + ), + pytest.param( + {"var1": 42, "var2": (42, 43)}, + {"var1": 42, "var2": (42, 43)}, + id="per var", + ), + pytest.param( + {"var1": 42, "dim2": (42, 43)}, + {"var1": 42, "var2": (42, 43)}, + id="mixed", + ), + ], + ) + def test_pad(self, constant_values, expected) -> None: ds = create_test_data(seed=1) - 
padded = ds.pad(dim2=(1, 1), constant_values=42) + padded = ds.pad(dim2=(1, 1), constant_values=constant_values) assert padded["dim2"].shape == (11,) assert padded["var1"].shape == (8, 11) @@ -6594,7 +6614,11 @@ def test_pad(self) -> None: assert padded["var3"].shape == (10, 8) assert dict(padded.sizes) == {"dim1": 8, "dim2": 11, "dim3": 10, "time": 20} - np.testing.assert_equal(padded["var1"].isel(dim2=[0, -1]).data, 42) + for var, expected_value in expected.items(): + np.testing.assert_equal( + np.unique(padded[var].isel(dim2=[0, -1]).data), expected_value + ) + # np.testing.assert_equal(padded["var1"].isel(dim2=[0, -1]).data, 42) np.testing.assert_equal(padded["dim2"][[0, -1]].data, np.nan) @pytest.mark.parametrize( From d939030dced1b20aac255a84a0965293158b4581 Mon Sep 17 00:00:00 2001 From: tiago Date: Wed, 14 Aug 2024 16:22:25 +0200 Subject: [PATCH 3/8] more expressive docstring and simplify type signature with alias in dataset pad func. enforce number values to be converted to tuples for all in `_pad_options_dim_to_index`. make variable pad function consistent with dataset. 
extend tests --- xarray/core/dataset.py | 14 ++++---------- xarray/core/types.py | 5 +++++ xarray/core/variable.py | 20 +++++++++++--------- xarray/tests/test_dataset.py | 10 +++++++--- 4 files changed, 27 insertions(+), 22 deletions(-) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 3a1fecfbb1e..8d1650f3077 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -163,6 +163,7 @@ ReindexMethodOptions, SideOptions, T_ChunkDimFreq, + T_DatasetPadConstantValues, T_Xarray, ) from xarray.core.weighted import DatasetWeighted @@ -9147,14 +9148,7 @@ def pad( stat_length: ( int | tuple[int, int] | Mapping[Any, tuple[int, int]] | None ) = None, - constant_values: ( - float - | tuple[float, float] - | Mapping[ - Any, float | tuple[float, float] | Mapping[Any, tuple[float, float]] - ] - | None - ) = None, + constant_values: T_DatasetPadConstantValues | None = None, end_values: int | tuple[int, int] | Mapping[Any, tuple[int, int]] | None = None, reflect_type: PadReflectOptions = None, keep_attrs: bool | None = None, @@ -9210,8 +9204,8 @@ def pad( (stat_length,) or int is a shortcut for before = after = statistic length for all axes. Default is ``None``, to use the entire axis. - constant_values : scalar, tuple, mapping of hashable to tuple or - mapping of hashable to mapping of hashable to tuple, default: 0 + constant_values : scalar, tuple, mapping of dim name to scalar or tuple, or \ + mapping of var name to scalar, tuple or to mapping of dim name to scalar or tuple, default: 0 Used in 'constant'. The values to set the padded values for each data variable / axis. ``{var_1: {dim_1: (before_1, after_1), ... dim_N: (before_N, after_N)}, ... var_M: (before, after)}`` unique pad constants per data variable. 
diff --git a/xarray/core/types.py b/xarray/core/types.py index 0e432283ba9..3eb97f86c4a 100644 --- a/xarray/core/types.py +++ b/xarray/core/types.py @@ -243,6 +243,11 @@ def copy( "symmetric", "wrap", ] +T_PadConstantValues = float | tuple[float, float] +T_VarPadConstantValues = T_PadConstantValues | Mapping[Any, T_PadConstantValues] +T_DatasetPadConstantValues = ( + T_VarPadConstantValues | Mapping[Any, T_VarPadConstantValues] +) PadReflectOptions = Literal["even", "odd", None] CFCalendar = Literal[ diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 3cd8e4acbd5..0767755ea2b 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -65,6 +65,7 @@ Self, T_Chunks, T_DuckArray, + T_VarPadConstantValues, ) from xarray.namedarray.parallelcompat import ChunkManagerEntrypoint @@ -1121,9 +1122,16 @@ def shift(self, shifts=None, fill_value=dtypes.NA, **shifts_kwargs): def _pad_options_dim_to_index( self, - pad_option: Mapping[Any, int | tuple[int, int]], + pad_option: Mapping[ + Any, numbers.Number | tuple[numbers.Number, numbers.Number] + ], fill_with_shape=False, ): + # change number values to a tuple of two of those values + for k, v in pad_option.items(): + if isinstance(v, numbers.Number): + pad_option[k] = (v, v) + if fill_with_shape: return [ (n, n) if d not in pad_option else pad_option[d] @@ -1138,9 +1146,7 @@ def pad( stat_length: ( int | tuple[int, int] | Mapping[Any, tuple[int, int]] | None ) = None, - constant_values: ( - float | tuple[float, float] | Mapping[Any, tuple[float, float]] | None - ) = None, + constant_values: T_VarPadConstantValues | None = None, end_values: int | tuple[int, int] | Mapping[Any, tuple[int, int]] | None = None, reflect_type: PadReflectOptions = None, keep_attrs: bool | None = None, @@ -1160,7 +1166,7 @@ def pad( stat_length : int, tuple or mapping of hashable to tuple Used in 'maximum', 'mean', 'median', and 'minimum'. Number of values at edge of each axis used to calculate the statistic value. 
- constant_values : scalar, tuple or mapping of hashable to tuple + constant_values : scalar, tuple or mapping of hashable to scalar or tuple Used in 'constant'. The values to set the padded values for each axis. end_values : scalar, tuple or mapping of hashable to tuple @@ -1207,10 +1213,6 @@ def pad( if stat_length is None and mode in ["maximum", "mean", "median", "minimum"]: stat_length = [(n, n) for n in self.data.shape] # type: ignore[assignment] - # change integer values to a tuple of two of those values and change pad_width to index - for k, v in pad_width.items(): - if isinstance(v, numbers.Number): - pad_width[k] = (v, v) pad_width_by_index = self._pad_options_dim_to_index(pad_width) # create pad_options_kwargs, numpy/dask requires only relevant kwargs to be nonempty diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index 39626502e77..3735e56b77c 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -6692,10 +6692,14 @@ def test_polyfit_warnings(self) -> None: @pytest.mark.parametrize( ["constant_values", "expected"], [ - pytest.param(42, {"var1": 42}, id="numeric"), - pytest.param((42, 43), {"var1": (42, 43)}, id="tuple"), + pytest.param(None, {"var1": np.nan}, id="default"), + pytest.param(42, {"var1": 42, "var2": 42}, id="scalar"), + pytest.param((42, 43), {"var1": (42, 43), "var2": (42, 43)}, id="tuple"), + pytest.param({"dim2": 42}, {"var1": 42, "var2": 42}, id="per dim scalar"), pytest.param( - {"dim2": (42, 43)}, {"var1": (42, 43), "var2": (42, 43)}, id="per dim" + {"dim2": (42, 43)}, + {"var1": (42, 43), "var2": (42, 43)}, + id="per dim tuple", ), pytest.param( {"var1": 42, "var2": (42, 43)}, From 48a972f42481932791e4cfd0f8c3de2f02f2e915 Mon Sep 17 00:00:00 2001 From: tiago Date: Wed, 14 Aug 2024 17:01:12 +0200 Subject: [PATCH 4/8] fix typing --- xarray/core/dataset.py | 2 +- xarray/core/variable.py | 4 +--- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/xarray/core/dataset.py 
b/xarray/core/dataset.py index 8d1650f3077..226c2d566f5 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -9295,7 +9295,7 @@ def pad( indexes[k] = idx per_data_var_constant_values = {} - if isinstance(constant_values, Mapping): + if isinstance(constant_values, dict): for k in self.data_vars: if v := constant_values.pop(k, None): per_data_var_constant_values[k] = v diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 0767755ea2b..a74fb4d8ce9 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -1122,9 +1122,7 @@ def shift(self, shifts=None, fill_value=dtypes.NA, **shifts_kwargs): def _pad_options_dim_to_index( self, - pad_option: Mapping[ - Any, numbers.Number | tuple[numbers.Number, numbers.Number] - ], + pad_option: Mapping[Any, int | float | tuple[int, int] | tuple[float, float]], fill_with_shape=False, ): # change number values to a tuple of two of those values From afda62dd8dbbfd05420e903aa1b3624453c08bde Mon Sep 17 00:00:00 2001 From: Tiago Sanona Date: Sat, 17 Aug 2024 14:51:29 +0200 Subject: [PATCH 5/8] add terms to conf.py, make docstrings more accurate, expand tests for dataset pad function --- doc/conf.py | 3 + xarray/core/dataset.py | 11 ++-- xarray/tests/test_dataset.py | 104 +++++++++++++++++++++++------------ 3 files changed, 78 insertions(+), 40 deletions(-) diff --git a/doc/conf.py b/doc/conf.py index 4f1fc6751d2..93a0e459a33 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -153,6 +153,9 @@ "matplotlib colormap name": ":doc:`matplotlib colormap name `", "matplotlib axes object": ":py:class:`matplotlib axes object `", "colormap": ":py:class:`colormap `", + # xarray terms + "dim name": ":term:`dimension name `", + "var name": ":term:`variable name `", # objects without namespace: xarray "DataArray": "~xarray.DataArray", "Dataset": "~xarray.Dataset", diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 4db7103f3e9..f7c94a48ccb 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py 
@@ -9209,7 +9209,7 @@ def pad( length for all axes. Default is ``None``, to use the entire axis. constant_values : scalar, tuple, mapping of dim name to scalar or tuple, or \ - mapping of var name to scalar, tuple or to mapping of dim name to scalar or tuple, default: 0 + mapping of var name to scalar, tuple or to mapping of dim name to scalar or tuple, default: None Used in 'constant'. The values to set the padded values for each data variable / axis. ``{var_1: {dim_1: (before_1, after_1), ... dim_N: (before_N, after_N)}, ... var_M: (before, after)}`` unique pad constants per data variable. @@ -9219,8 +9219,8 @@ def pad( dimension. ``(constant,)`` or ``constant`` is a shortcut for ``before = after = constant`` for all dimensions. - Default is 0. - end_values : scalar, tuple or mapping of hashable to tuple, default: 0 + Default is ``None``, pads with ``np.nan``. + end_values : scalar, tuple or mapping of hashable to tuple, default: None Used in 'linear_ramp'. The values used for the ending value of the linear_ramp and that will form the edge of the padded array. ``{dim_1: (before_1, after_1), ... dim_N: (before_N, after_N)}`` unique @@ -9229,7 +9229,7 @@ def pad( axis. ``(constant,)`` or ``constant`` is a shortcut for ``before = after = constant`` for all axes. - Default is 0. + Default is None. reflect_type : {"even", "odd", None}, optional Used in "reflect", and "symmetric". The "even" style is the default with an unaltered reflection around the edge value. 
For @@ -9299,7 +9299,8 @@ def pad( indexes[k] = idx per_data_var_constant_values = {} - if isinstance(constant_values, dict): + if utils.is_dict_like(constant_values): + constant_values = dict(constant_values) for k in self.data_vars: if v := constant_values.pop(k, None): per_data_var_constant_values[k] = v diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index a5f6d4d93e0..8e6b27ab110 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -6699,46 +6699,80 @@ def test_polyfit_warnings(self) -> None: ds.var1.polyfit("dim2", 10, full=True) assert len(ws) == 1 + @staticmethod + def _test_data_var_interior( + original_data_var, padded_data_var, padded_dim_name, expected_pad_values + ): + np.testing.assert_equal( + np.unique(padded_data_var.isel({padded_dim_name: [0, -1]})), + expected_pad_values, + ) + np.testing.assert_array_equal( + padded_data_var.isel({padded_dim_name: slice(1, -1)}), original_data_var + ) + + @pytest.mark.parametrize("padded_dim_name", ["dim1", "dim2", "dim3", "time"]) @pytest.mark.parametrize( - ["constant_values", "expected"], + ["constant_values"], [ - pytest.param(None, {"var1": np.nan}, id="default"), - pytest.param(42, {"var1": 42, "var2": 42}, id="scalar"), - pytest.param((42, 43), {"var1": (42, 43), "var2": (42, 43)}, id="tuple"), - pytest.param({"dim2": 42}, {"var1": 42, "var2": 42}, id="per dim scalar"), - pytest.param( - {"dim2": (42, 43)}, - {"var1": (42, 43), "var2": (42, 43)}, - id="per dim tuple", - ), - pytest.param( - {"var1": 42, "var2": (42, 43)}, - {"var1": 42, "var2": (42, 43)}, - id="per var", - ), - pytest.param( - {"var1": 42, "dim2": (42, 43)}, - {"var1": 42, "var2": (42, 43)}, - id="mixed", - ), + pytest.param(None, id="default"), + pytest.param(42, id="scalar"), + pytest.param((42, 43), id="tuple"), + pytest.param({"dim1": 42, "dim2": 43}, id="per dim scalar"), + pytest.param({"dim1": (42, 43), "dim2": (43, 44)}, id="per dim tuple"), + pytest.param({"var1": 42, "var2": 
(42, 43)}, id="per var"), + pytest.param({"var1": 42, "dim1": (42, 43)}, id="mixed"), ], ) - def test_pad(self, constant_values, expected) -> None: + def test_pad(self, padded_dim_name, constant_values) -> None: ds = create_test_data(seed=1) - padded = ds.pad(dim2=(1, 1), constant_values=constant_values) - - assert padded["dim2"].shape == (11,) - assert padded["var1"].shape == (8, 11) - assert padded["var2"].shape == (8, 11) - assert padded["var3"].shape == (10, 8) - assert dict(padded.sizes) == {"dim1": 8, "dim2": 11, "dim3": 10, "time": 20} - - for var, expected_value in expected.items(): - np.testing.assert_equal( - np.unique(padded[var].isel(dim2=[0, -1]).data), expected_value - ) - # np.testing.assert_equal(padded["var1"].isel(dim2=[0, -1]).data, 42) - np.testing.assert_equal(padded["dim2"][[0, -1]].data, np.nan) + padded = ds.pad({padded_dim_name: (1, 1)}, constant_values=constant_values) + + # test padded dim values and size + for ds_dim_name, ds_dim in ds.sizes.items(): + if ds_dim_name == padded_dim_name: + np.testing.assert_equal(padded.sizes[ds_dim_name], ds_dim + 2) + if ds_dim_name in padded.coords: + assert padded[ds_dim_name][[0, -1]].isnull().all() + else: + np.testing.assert_equal(padded.sizes[ds_dim_name], ds_dim) + + # check if coord "numbers" with dimention dim3 is paded correctly + if padded_dim_name == "dim3": + assert padded["numbers"][[0, -1]].isnull().all() + # twarning: passes but dtype changes from int to float + np.testing.assert_array_equal(padded["numbers"][1:-1], ds["numbers"]) + + # test if data_vars are paded with correct values + for data_var_name, data_var in padded.data_vars.items(): + if padded_dim_name in data_var.dims: + if isinstance(constant_values, dict): + if ( + expected := constant_values.get(data_var_name, None) + ) is not None: + self._test_data_var_interior( + ds[data_var_name], data_var, padded_dim_name, expected + ) + elif ( + expected := constant_values.get(padded_dim_name, None) + ) is not None: + 
self._test_data_var_interior( + ds[data_var_name], data_var, padded_dim_name, expected + ) + else: + self._test_data_var_interior( + ds[data_var_name], data_var, padded_dim_name, 0 + ) + elif constant_values: + self._test_data_var_interior( + ds[data_var_name], data_var, padded_dim_name, constant_values + ) + else: + self._test_data_var_interior( + ds[data_var_name], data_var, padded_dim_name, np.nan + ) + else: + assert_array_equal(data_var, ds[data_var_name]) @pytest.mark.parametrize( ["keep_attrs", "attrs", "expected"], From 8bdb6f1c42f7bd89602fe823abecab1e81245132 Mon Sep 17 00:00:00 2001 From: Tiago Sanona Date: Tue, 20 Aug 2024 02:35:57 +0200 Subject: [PATCH 6/8] filter constant value types without mutating input map --- xarray/core/dataset.py | 22 ++++++++++++---------- xarray/tests/test_dataset.py | 2 +- 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index f7c94a48ccb..712e5a37d67 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -9298,25 +9298,27 @@ def pad( if not pad_dims.intersection(xindexes.get_all_dims(k)): indexes[k] = idx - per_data_var_constant_values = {} - if utils.is_dict_like(constant_values): - constant_values = dict(constant_values) - for k in self.data_vars: - if v := constant_values.pop(k, None): - per_data_var_constant_values[k] = v - for name, var in self.variables.items(): var_pad_width = {k: v for k, v in pad_width.items() if k in var.dims} if not var_pad_width: variables[name] = var elif name in self.data_vars: + if utils.is_dict_like(constant_values): + if name in constant_values.keys(): + filtered_constant_values = constant_values[name] + elif not set(var.dims).isdisjoint(constant_values.keys()): + filtered_constant_values = { + k: v for k, v in constant_values.items() if k in var.dims + } + else: + filtered_constant_values = 0 + else: + filtered_constant_values = constant_values variables[name] = var.pad( pad_width=var_pad_width, mode=mode, 
stat_length=stat_length, - constant_values=per_data_var_constant_values.get( - name, constant_values - ), + constant_values=filtered_constant_values, end_values=end_values, reflect_type=reflect_type, keep_attrs=keep_attrs, diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index 8e6b27ab110..9465317855a 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -6746,7 +6746,7 @@ def test_pad(self, padded_dim_name, constant_values) -> None: # test if data_vars are paded with correct values for data_var_name, data_var in padded.data_vars.items(): if padded_dim_name in data_var.dims: - if isinstance(constant_values, dict): + if utils.is_dict_like(constant_values): if ( expected := constant_values.get(data_var_name, None) ) is not None: From 0a51b7c711aba92aecec05861494eaf1681d5ad2 Mon Sep 17 00:00:00 2001 From: Tiago Sanona Date: Wed, 21 Aug 2024 02:33:01 +0200 Subject: [PATCH 7/8] add todo to change default padding for missing variables in constant_values --- xarray/core/dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 2d172b3240a..dbc00a03025 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -9313,7 +9313,7 @@ def pad( k: v for k, v in constant_values.items() if k in var.dims } else: - filtered_constant_values = 0 + filtered_constant_values = 0 # TODO: https://github.com/pydata/xarray/pull/9353#discussion_r1724018352 else: filtered_constant_values = constant_values variables[name] = var.pad( From a848351c265df01ac1cc09cd6a135476642158db Mon Sep 17 00:00:00 2001 From: Tiago Sanona Date: Wed, 21 Aug 2024 21:11:45 +0200 Subject: [PATCH 8/8] add changes to whats new --- doc/whats-new.rst | 2 ++ 1 file changed, 2 insertions(+) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index e4ffc5ca799..2cf2d5928bf 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -24,6 +24,8 @@ New Features ~~~~~~~~~~~~ - Make chunk manager an 
option in ``set_options`` (:pull:`9362`). By `Tom White `_. +- Allow data variable specific ``constant_values`` in the dataset ``pad`` function (:pull:`9353`). + By `Tiago Sanona `_. Breaking changes ~~~~~~~~~~~~~~~~