diff --git a/xarray/core/concat.py b/xarray/core/concat.py
index bcab136de8d..7f269a09abe 100644
--- a/xarray/core/concat.py
+++ b/xarray/core/concat.py
@@ -183,28 +183,32 @@ def process_subset_opt(opt, subset):
         if opt == "different":
             if compat == "override":
                 raise ValueError(
-                    "Cannot specify both %s='different' and compat='override'."
-                    % subset
+                    f"Cannot specify both {subset}='different' and compat='override'."
                 )
             # all nonindexes that are not the same in each dataset
             for k in getattr(datasets[0], subset):
                 if k not in concat_over:
+                    variables = [ds.variables[k] for ds in datasets]
+                    equals[k] = utils.dask_name_equal(variables)
+                    if equals[k]:
+                        continue
+
                     # Compare the variable of all datasets vs. the one
                     # of the first dataset. Perform the minimum amount of
                     # loads in order to avoid multiple loads from disk
                     # while keeping the RAM footprint low.
-                    v_lhs = datasets[0].variables[k].load()
+                    v_lhs = variables[0].load()
                     # We'll need to know later on if variables are equal.
-                    computed = []
-                    for ds_rhs in datasets[1:]:
-                        v_rhs = ds_rhs.variables[k].compute()
+                    computed = [v_lhs]
+                    for v_rhs in variables[1:]:
+                        v_rhs = v_rhs.compute()
                         computed.append(v_rhs)
                         if not getattr(v_lhs, compat)(v_rhs):
                             concat_over.add(k)
                             equals[k] = False
                             # computed variables are not to be re-computed
                             # again in the future
-                            for ds, v in zip(datasets[1:], computed):
+                            for ds, v in zip(datasets, computed):
                                 ds.variables[k].data = v.data
                             break
                     else:
diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py
index d943788c434..515732113c0 100644
--- a/xarray/core/duck_array_ops.py
+++ b/xarray/core/duck_array_ops.py
@@ -189,6 +189,14 @@ def array_equiv(arr1, arr2):
     """
     arr1 = asarray(arr1)
     arr2 = asarray(arr2)
+    if (
+        dask_array
+        and isinstance(arr1, dask_array.Array)
+        and isinstance(arr2, dask_array.Array)
+    ):
+        # GH3068
+        if arr1.name == arr2.name:
+            return True
     if arr1.shape != arr2.shape:
         return False
     with warnings.catch_warnings():
diff --git a/xarray/core/merge.py b/xarray/core/merge.py
index db5ef9531df..c4902b6defa 100644
--- a/xarray/core/merge.py
+++ b/xarray/core/merge.py
@@ -19,7 +19,7 @@
 
 from . import dtypes, pdcompat
 from .alignment import deep_align
-from .utils import Frozen, dict_equiv
+from .utils import Frozen, dict_equiv, dask_name_equal
 from .variable import Variable, as_variable, assert_unique_multiindex_level_names
 
 if TYPE_CHECKING:
@@ -123,11 +123,14 @@ def unique_variable(
         combine_method = "fillna"
 
     if equals is None:
-        out = out.compute()
-        for var in variables[1:]:
-            equals = getattr(out, compat)(var)
-            if not equals:
-                break
+        equals = dask_name_equal(variables)
+
+        if not equals:
+            out = out.compute()
+            for var in variables[1:]:
+                equals = getattr(out, compat)(var)
+                if not equals:
+                    break
 
     if not equals:
         raise MergeError(
diff --git a/xarray/core/utils.py b/xarray/core/utils.py
index 6befe0b5efc..2ad95f38af2 100644
--- a/xarray/core/utils.py
+++ b/xarray/core/utils.py
@@ -676,3 +676,34 @@ def get_temp_dimname(dims: Container[Hashable], new_dim: Hashable) -> Hashable:
     while new_dim in dims:
         new_dim = "_" + str(new_dim)
     return new_dim
+
+
+def dask_name_equal(variables):
+    """
+    Test variable data for equality by comparing dask names if possible.
+
+    Returns
+    -------
+    True or False if all variables contain dask arrays and their dask names are equal
+    or not equal respectively.
+
+    None if equality cannot be determined i.e. when not all variables contain dask arrays.
+    """
+    try:
+        import dask.array as dask_array
+    except ImportError:
+        return None
+    out = variables[0]
+    equals = None
+    if isinstance(out.data, dask_array.Array):
+        for var in variables[1:]:
+            if isinstance(var.data, dask_array.Array):
+                if out.data.name == var.data.name:
+                    equals = True
+                else:
+                    equals = False
+                    break
+            else:
+                equals = None
+                break
+    return equals
diff --git a/xarray/core/variable.py b/xarray/core/variable.py
index 93ad1eafb97..f21645c69e8 100644
--- a/xarray/core/variable.py
+++ b/xarray/core/variable.py
@@ -1229,7 +1229,9 @@ def transpose(self, *dims) -> "Variable":
         if len(dims) == 0:
             dims = self.dims[::-1]
         axes = self.get_axis_num(dims)
-        if len(dims) < 2:  # no need to transpose if only one dimension
+        if len(dims) < 2 or dims == self.dims:
+            # no need to transpose if only one dimension
+            # or dims are in same order
             return self.copy(deep=False)
 
         data = as_indexable(self._data).transpose(axes)
diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py
index 50517ae3c9c..b45eea05f2b 100644
--- a/xarray/tests/test_dask.py
+++ b/xarray/tests/test_dask.py
@@ -22,6 +22,7 @@
     assert_identical,
     raises_regex,
 )
+from ..core.utils import dask_name_equal
 
 dask = pytest.importorskip("dask")
 da = pytest.importorskip("dask.array")
@@ -1135,3 +1136,40 @@ def test_make_meta(map_ds):
     for variable in map_ds.data_vars:
         assert variable in meta.data_vars
         assert meta.data_vars[variable].shape == (0,) * meta.data_vars[variable].ndim
+
+
+def test_identical_coords_no_computes():
+    lons2 = xr.DataArray(da.zeros((10, 10), chunks=2), dims=("y", "x"))
+    a = xr.DataArray(
+        da.zeros((10, 10), chunks=2), dims=("y", "x"), coords={"lons": lons2}
+    )
+    b = xr.DataArray(
+        da.zeros((10, 10), chunks=2), dims=("y", "x"), coords={"lons": lons2}
+    )
+    with raise_if_dask_computes():
+        c = a + b
+    assert_identical(c, a)
+
+
+def test_dask_name_equal():
+    lons1 = xr.DataArray(da.zeros((10, 10), chunks=2), dims=("y", "x"))
+    lons2 = xr.DataArray(da.zeros((10, 10), chunks=2), dims=("y", "x"))
+    with raise_if_dask_computes():
+        assert dask_name_equal([lons1, lons2])
+    with raise_if_dask_computes():
+        assert not dask_name_equal([lons1, lons2 / 2])
+    assert dask_name_equal([lons1, lons2.compute()]) is None
+    assert dask_name_equal([lons1.compute(), lons2.compute()]) is None
+
+    with raise_if_dask_computes():
+        assert dask_name_equal([lons1, lons1.transpose("y", "x")])
+
+    with raise_if_dask_computes():
+        for compat in [
+            "broadcast_equals",
+            "equals",
+            "override",
+            "identical",
+            "no_conflicts",
+        ]:
+            xr.merge([lons1, lons2], compat=compat)