From e84cc97b6d72a1e55128de80df3955c6402cf025 Mon Sep 17 00:00:00 2001
From: dcherian
Date: Sun, 27 Oct 2019 20:22:13 -0600
Subject: [PATCH] Optimize dask array equality checks.

Dask arrays with the same graph have the same name. We can use this to
quickly compare dask-backed variables without computing.

The new helper duck_array_ops.lazy_array_equiv never compares values and
has a tri-state result:

  * False - the shapes differ;
  * True  - both inputs are dask arrays with the same name;
  * None  - undecided (equality can only be settled by comparing values).

concat and merge run this lazy check first. merge.unique_variable only
falls back to the value-based comparison (which computes) when the lazy
check is undecided; a definite False goes straight to MergeError.

Variable.identical and Variable.no_conflicts gain an `equiv` argument so
the lazy check can be passed through, and Variable.transpose returns a
shallow copy when the requested dims are already in the current order.

Fixes #3068 and #3311
---
Review notes (ignored by `git am`):
  * Line structure restored from a whitespace-collapsed copy of the patch;
    hunk line counts and the diffstat below are unchanged.
  * merge.py: the compute fallback is guarded by `if equals is None:`
    instead of `if not equals:`, so a lazy False no longer triggers
    out.compute().
  * test_dask.py: the two bare `.equals(...)` calls in
    test_lazy_array_equiv now assert their results.
  * The post-image blob ids on the `index` lines of merge.py and
    test_dask.py predate these two edits.

 xarray/core/concat.py         | 16 ++++++++++++++
 xarray/core/duck_array_ops.py | 41 +++++++++++++++++++++++++++++++++++
 xarray/core/merge.py          | 17 +++++++++++----
 xarray/core/variable.py       | 14 +++++++-----
 xarray/tests/test_dask.py     | 40 ++++++++++++++++++++++++++++++++++
 5 files changed, 119 insertions(+), 9 deletions(-)

diff --git a/xarray/core/concat.py b/xarray/core/concat.py
index bcab136de8d..01da928c29d 100644
--- a/xarray/core/concat.py
+++ b/xarray/core/concat.py
@@ -2,6 +2,7 @@
 
 from . import dtypes, utils
 from .alignment import align
+from .duck_array_ops import lazy_array_equiv
 from .merge import _VALID_COMPAT, unique_variable
 from .variable import IndexVariable, Variable, as_variable
 from .variable import concat as concat_vars
@@ -189,6 +190,21 @@ def process_subset_opt(opt, subset):
                 # all nonindexes that are not the same in each dataset
                 for k in getattr(datasets[0], subset):
                     if k not in concat_over:
+                        equals[k] = None
+                        variables = [ds.variables[k] for ds in datasets]
+                        # first check without comparing values i.e. no computes
+                        for var in variables[1:]:
+                            equals[k] = getattr(variables[0], compat)(
+                                var, equiv=lazy_array_equiv
+                            )
+                            if not equals[k]:
+                                break
+
+                        if equals[k] is not None:
+                            if equals[k] is False:
+                                concat_over.add(k)
+                            continue
+
                         # Compare the variable of all datasets vs. the one
                         # of the first dataset. Perform the minimum amount of
                         # loads in order to avoid multiple loads from disk
diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py
index d943788c434..c2b72eb08e1 100644
--- a/xarray/core/duck_array_ops.py
+++ b/xarray/core/duck_array_ops.py
@@ -181,9 +181,34 @@ def allclose_or_equiv(arr1, arr2, rtol=1e-5, atol=1e-8):
     arr2 = asarray(arr2)
     if arr1.shape != arr2.shape:
         return False
+    if (
+        dask_array
+        and isinstance(arr1, dask_array.Array)
+        and isinstance(arr2, dask_array.Array)
+    ):
+        # GH3068
+        if arr1.name == arr2.name:
+            return True
     return bool(isclose(arr1, arr2, rtol=rtol, atol=atol, equal_nan=True).all())
 
 
+def lazy_array_equiv(arr1, arr2):
+    """Like array_equal, but doesn't actually compare values
+    """
+    arr1 = asarray(arr1)
+    arr2 = asarray(arr2)
+    if arr1.shape != arr2.shape:
+        return False
+    if (
+        dask_array
+        and isinstance(arr1, dask_array.Array)
+        and isinstance(arr2, dask_array.Array)
+    ):
+        # GH3068
+        if arr1.name == arr2.name:
+            return True
+
+
 def array_equiv(arr1, arr2):
     """Like np.array_equal, but also allows values to be NaN in both arrays
     """
@@ -191,6 +216,14 @@ def array_equiv(arr1, arr2):
     arr2 = asarray(arr2)
     if arr1.shape != arr2.shape:
         return False
+    if (
+        dask_array
+        and isinstance(arr1, dask_array.Array)
+        and isinstance(arr2, dask_array.Array)
+    ):
+        # GH3068
+        if arr1.name == arr2.name:
+            return True
     with warnings.catch_warnings():
         warnings.filterwarnings("ignore", "In the future, 'NAT == x'")
         flag_array = (arr1 == arr2) | (isnull(arr1) & isnull(arr2))
@@ -205,6 +238,14 @@ def array_notnull_equiv(arr1, arr2):
     arr2 = asarray(arr2)
     if arr1.shape != arr2.shape:
         return False
+    if (
+        dask_array
+        and isinstance(arr1, dask_array.Array)
+        and isinstance(arr2, dask_array.Array)
+    ):
+        # GH3068
+        if arr1.name == arr2.name:
+            return True
     with warnings.catch_warnings():
         warnings.filterwarnings("ignore", "In the future, 'NAT == x'")
         flag_array = (arr1 == arr2) | isnull(arr1) | isnull(arr2)
diff --git a/xarray/core/merge.py b/xarray/core/merge.py
index db5ef9531df..21426b8bd37 100644
--- a/xarray/core/merge.py
+++ b/xarray/core/merge.py
@@ -19,6 +19,7 @@
 
 from . import dtypes, pdcompat
 from .alignment import deep_align
+from .duck_array_ops import lazy_array_equiv
 from .utils import Frozen, dict_equiv
 from .variable import Variable, as_variable, assert_unique_multiindex_level_names
 
@@ -123,16 +124,24 @@ def unique_variable(
         combine_method = "fillna"
 
     if equals is None:
-        out = out.compute()
+        # first check without comparing values i.e. no computes
         for var in variables[1:]:
-            equals = getattr(out, compat)(var)
+            equals = getattr(out, compat)(var, equiv=lazy_array_equiv)
             if not equals:
                 break
 
+        # lazy check was inconclusive: compare values with minimum number of computes
+        if equals is None:
+            out = out.compute()
+            for var in variables[1:]:
+                equals = getattr(out, compat)(var)
+                if not equals:
+                    break
+
     if not equals:
         raise MergeError(
-            "conflicting values for variable {!r} on objects to be combined. "
-            "You can skip this check by specifying compat='override'.".format(name)
+            f"conflicting values for variable {name!r} on objects to be combined. "
+            "You can skip this check by specifying compat='override'."
         )
 
     if combine_method:
diff --git a/xarray/core/variable.py b/xarray/core/variable.py
index 93ad1eafb97..82c041ecd05 100644
--- a/xarray/core/variable.py
+++ b/xarray/core/variable.py
@@ -1229,7 +1229,9 @@ def transpose(self, *dims) -> "Variable":
         if len(dims) == 0:
             dims = self.dims[::-1]
         axes = self.get_axis_num(dims)
-        if len(dims) < 2:  # no need to transpose if only one dimension
+        if len(dims) < 2 or dims == self.dims:
+            # no need to transpose if only one dimension
+            # or dims are in same order
             return self.copy(deep=False)
 
         data = as_indexable(self._data).transpose(axes)
@@ -1588,22 +1590,24 @@ def broadcast_equals(self, other, equiv=duck_array_ops.array_equiv):
             return False
         return self.equals(other, equiv=equiv)
 
-    def identical(self, other):
+    def identical(self, other, equiv=duck_array_ops.array_equiv):
         """Like equals, but also checks attributes.
         """
         try:
-            return utils.dict_equiv(self.attrs, other.attrs) and self.equals(other)
+            return utils.dict_equiv(self.attrs, other.attrs) and self.equals(
+                other, equiv=equiv
+            )
         except (TypeError, AttributeError):
             return False
 
-    def no_conflicts(self, other):
+    def no_conflicts(self, other, equiv=duck_array_ops.array_notnull_equiv):
         """True if the intersection of two Variable's non-null data is
         equal; otherwise false.
 
         Variables can thus still be equal if there are locations where either,
         or both, contain NaN values.
         """
-        return self.broadcast_equals(other, equiv=duck_array_ops.array_notnull_equiv)
+        return self.broadcast_equals(other, equiv=equiv)
 
     def quantile(self, q, dim=None, interpolation="linear", keep_attrs=None):
         """Compute the qth quantile of the data along the specified dimension.
diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py
index 50517ae3c9c..bfda579644f 100644
--- a/xarray/tests/test_dask.py
+++ b/xarray/tests/test_dask.py
@@ -22,6 +22,7 @@
     assert_identical,
     raises_regex,
 )
+from ..core.duck_array_ops import lazy_array_equiv
 
 dask = pytest.importorskip("dask")
 da = pytest.importorskip("dask.array")
@@ -1135,3 +1136,42 @@ def test_make_meta(map_ds):
     for variable in map_ds.data_vars:
         assert variable in meta.data_vars
         assert meta.data_vars[variable].shape == (0,) * meta.data_vars[variable].ndim
+
+
+def test_identical_coords_no_computes():
+    lons2 = xr.DataArray(da.zeros((10, 10), chunks=2), dims=("y", "x"))
+    a = xr.DataArray(
+        da.zeros((10, 10), chunks=2), dims=("y", "x"), coords={"lons": lons2}
+    )
+    b = xr.DataArray(
+        da.zeros((10, 10), chunks=2), dims=("y", "x"), coords={"lons": lons2}
+    )
+    with raise_if_dask_computes():
+        c = a + b
+    assert_identical(c, a)
+
+
+def test_lazy_array_equiv():
+    lons1 = xr.DataArray(da.zeros((10, 10), chunks=2), dims=("y", "x"))
+    lons2 = xr.DataArray(da.zeros((10, 10), chunks=2), dims=("y", "x"))
+    var1 = lons1.variable
+    var2 = lons2.variable
+    with raise_if_dask_computes():
+        assert lons1.equals(lons2)
+    with raise_if_dask_computes():
+        assert var1.equals(var2 / 2, equiv=lazy_array_equiv) is None
+    assert var1.equals(var2.compute(), equiv=lazy_array_equiv) is None
+    assert var1.compute().equals(var2.compute(), equiv=lazy_array_equiv) is None
+
+    with raise_if_dask_computes():
+        assert lons1.equals(lons1.transpose("y", "x"))
+
+    with raise_if_dask_computes():
+        for compat in [
+            "broadcast_equals",
+            "equals",
+            "override",
+            "identical",
+            "no_conflicts",
+        ]:
+            xr.merge([lons1, lons2], compat=compat)