Skip to content

Commit

Permalink
Optimize dask array equality checks. (#3453)
Browse files Browse the repository at this point in the history
* Optimize dask array equality checks.

Dask arrays with the same graph have the same name. We can use this to quickly
compare dask-backed variables without computing.

Fixes #3068 and #3311

* better docstring

* review suggestions.

* add concat test

* update whats new

* Add identity check to lazy_array_equiv

* pep8

* bugfix.
  • Loading branch information
dcherian committed Nov 5, 2019
1 parent b649846 commit af28c6b
Show file tree
Hide file tree
Showing 6 changed files with 217 additions and 45 deletions.
3 changes: 3 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,9 @@ Bug fixes
but cloudpickle isn't (:issue:`3401`) by `Rhys Doyle <https://github.com/rdoyle45>`_
- Fix grouping over variables with NaNs. (:issue:`2383`, :pull:`3406`).
By `Deepak Cherian <https://github.com/dcherian>`_.
- Use dask names to compare dask objects prior to comparing values after computation.
(:issue:`3068`, :issue:`3311`, :issue:`3454`, :pull:`3453`).
By `Deepak Cherian <https://github.com/dcherian>`_.
- Sync with cftime by removing `dayofwk=-1` for cftime>=1.0.4.
By `Anderson Banihirwe <https://github.com/andersy005>`_.
- Fix :py:meth:`xarray.core.groupby.DataArrayGroupBy.reduce` and
Expand Down
56 changes: 37 additions & 19 deletions xarray/core/concat.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from . import dtypes, utils
from .alignment import align
from .duck_array_ops import lazy_array_equiv
from .merge import _VALID_COMPAT, unique_variable
from .variable import IndexVariable, Variable, as_variable
from .variable import concat as concat_vars
Expand Down Expand Up @@ -189,26 +190,43 @@ def process_subset_opt(opt, subset):
# all nonindexes that are not the same in each dataset
for k in getattr(datasets[0], subset):
if k not in concat_over:
# Compare the variable of all datasets vs. the one
# of the first dataset. Perform the minimum amount of
# loads in order to avoid multiple loads from disk
# while keeping the RAM footprint low.
v_lhs = datasets[0].variables[k].load()
# We'll need to know later on if variables are equal.
computed = []
for ds_rhs in datasets[1:]:
v_rhs = ds_rhs.variables[k].compute()
computed.append(v_rhs)
if not getattr(v_lhs, compat)(v_rhs):
concat_over.add(k)
equals[k] = False
# computed variables are not to be re-computed
# again in the future
for ds, v in zip(datasets[1:], computed):
ds.variables[k].data = v.data
equals[k] = None
variables = [ds.variables[k] for ds in datasets]
# first check without comparing values i.e. no computes
for var in variables[1:]:
equals[k] = getattr(variables[0], compat)(
var, equiv=lazy_array_equiv
)
if equals[k] is not True:
# exit early if we know these are not equal or that
# equality cannot be determined i.e. one or all of
# the variables wraps a numpy array
break
else:
equals[k] = True

if equals[k] is False:
concat_over.add(k)

elif equals[k] is None:
# Compare the variable of all datasets vs. the one
# of the first dataset. Perform the minimum amount of
# loads in order to avoid multiple loads from disk
# while keeping the RAM footprint low.
v_lhs = datasets[0].variables[k].load()
# We'll need to know later on if variables are equal.
computed = []
for ds_rhs in datasets[1:]:
v_rhs = ds_rhs.variables[k].compute()
computed.append(v_rhs)
if not getattr(v_lhs, compat)(v_rhs):
concat_over.add(k)
equals[k] = False
# computed variables are not to be re-computed
# again in the future
for ds, v in zip(datasets[1:], computed):
ds.variables[k].data = v.data
break
else:
equals[k] = True

elif opt == "all":
concat_over.update(
Expand Down
62 changes: 47 additions & 15 deletions xarray/core/duck_array_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,27 +174,57 @@ def as_shared_dtype(scalars_or_arrays):
return [x.astype(out_type, copy=False) for x in arrays]


def allclose_or_equiv(arr1, arr2, rtol=1e-5, atol=1e-8):
"""Like np.allclose, but also allows values to be NaN in both arrays
def lazy_array_equiv(arr1, arr2):
"""Like array_equal, but doesn't actually compare values.
Returns True when arr1, arr2 identical or their dask names are equal.
Returns False when shapes are not equal.
Returns None when equality cannot determined: one or both of arr1, arr2 are numpy arrays;
or their dask names are not equal
"""
if arr1 is arr2:
return True
arr1 = asarray(arr1)
arr2 = asarray(arr2)
if arr1.shape != arr2.shape:
return False
return bool(isclose(arr1, arr2, rtol=rtol, atol=atol, equal_nan=True).all())
if (
dask_array
and isinstance(arr1, dask_array.Array)
and isinstance(arr2, dask_array.Array)
):
# GH3068
if arr1.name == arr2.name:
return True
else:
return None
return None


def allclose_or_equiv(arr1, arr2, rtol=1e-5, atol=1e-8):
"""Like np.allclose, but also allows values to be NaN in both arrays
"""
arr1 = asarray(arr1)
arr2 = asarray(arr2)
lazy_equiv = lazy_array_equiv(arr1, arr2)
if lazy_equiv is None:
return bool(isclose(arr1, arr2, rtol=rtol, atol=atol, equal_nan=True).all())
else:
return lazy_equiv


def array_equiv(arr1, arr2):
"""Like np.array_equal, but also allows values to be NaN in both arrays
"""
arr1 = asarray(arr1)
arr2 = asarray(arr2)
if arr1.shape != arr2.shape:
return False
with warnings.catch_warnings():
warnings.filterwarnings("ignore", "In the future, 'NAT == x'")
flag_array = (arr1 == arr2) | (isnull(arr1) & isnull(arr2))
return bool(flag_array.all())
lazy_equiv = lazy_array_equiv(arr1, arr2)
if lazy_equiv is None:
with warnings.catch_warnings():
warnings.filterwarnings("ignore", "In the future, 'NAT == x'")
flag_array = (arr1 == arr2) | (isnull(arr1) & isnull(arr2))
return bool(flag_array.all())
else:
return lazy_equiv


def array_notnull_equiv(arr1, arr2):
Expand All @@ -203,12 +233,14 @@ def array_notnull_equiv(arr1, arr2):
"""
arr1 = asarray(arr1)
arr2 = asarray(arr2)
if arr1.shape != arr2.shape:
return False
with warnings.catch_warnings():
warnings.filterwarnings("ignore", "In the future, 'NAT == x'")
flag_array = (arr1 == arr2) | isnull(arr1) | isnull(arr2)
return bool(flag_array.all())
lazy_equiv = lazy_array_equiv(arr1, arr2)
if lazy_equiv is None:
with warnings.catch_warnings():
warnings.filterwarnings("ignore", "In the future, 'NAT == x'")
flag_array = (arr1 == arr2) | isnull(arr1) | isnull(arr2)
return bool(flag_array.all())
else:
return lazy_equiv


def count(data, axis=None):
Expand Down
19 changes: 14 additions & 5 deletions xarray/core/merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

from . import dtypes, pdcompat
from .alignment import deep_align
from .duck_array_ops import lazy_array_equiv
from .utils import Frozen, dict_equiv
from .variable import Variable, as_variable, assert_unique_multiindex_level_names

Expand Down Expand Up @@ -123,16 +124,24 @@ def unique_variable(
combine_method = "fillna"

if equals is None:
out = out.compute()
# first check without comparing values i.e. no computes
for var in variables[1:]:
equals = getattr(out, compat)(var)
if not equals:
equals = getattr(out, compat)(var, equiv=lazy_array_equiv)
if equals is not True:
break

if equals is None:
# now compare values with minimum number of computes
out = out.compute()
for var in variables[1:]:
equals = getattr(out, compat)(var)
if not equals:
break

if not equals:
raise MergeError(
"conflicting values for variable {!r} on objects to be combined. "
"You can skip this check by specifying compat='override'.".format(name)
f"conflicting values for variable {name!r} on objects to be combined. "
"You can skip this check by specifying compat='override'."
)

if combine_method:
Expand Down
14 changes: 9 additions & 5 deletions xarray/core/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -1236,7 +1236,9 @@ def transpose(self, *dims) -> "Variable":
dims = self.dims[::-1]
dims = tuple(infix_dims(dims, self.dims))
axes = self.get_axis_num(dims)
if len(dims) < 2: # no need to transpose if only one dimension
if len(dims) < 2 or dims == self.dims:
# no need to transpose if only one dimension
# or dims are in same order
return self.copy(deep=False)

data = as_indexable(self._data).transpose(axes)
Expand Down Expand Up @@ -1595,22 +1597,24 @@ def broadcast_equals(self, other, equiv=duck_array_ops.array_equiv):
return False
return self.equals(other, equiv=equiv)

def identical(self, other):
def identical(self, other, equiv=duck_array_ops.array_equiv):
"""Like equals, but also checks attributes.
"""
try:
return utils.dict_equiv(self.attrs, other.attrs) and self.equals(other)
return utils.dict_equiv(self.attrs, other.attrs) and self.equals(
other, equiv=equiv
)
except (TypeError, AttributeError):
return False

def no_conflicts(self, other):
def no_conflicts(self, other, equiv=duck_array_ops.array_notnull_equiv):
"""True if the intersection of two Variable's non-null data is
equal; otherwise false.
Variables can thus still be equal if there are locations where either,
or both, contain NaN values.
"""
return self.broadcast_equals(other, equiv=duck_array_ops.array_notnull_equiv)
return self.broadcast_equals(other, equiv=equiv)

def quantile(self, q, dim=None, interpolation="linear", keep_attrs=None):
"""Compute the qth quantile of the data along the specified dimension.
Expand Down
108 changes: 107 additions & 1 deletion xarray/tests/test_dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
raises_regex,
requires_scipy_or_netCDF4,
)
from ..core.duck_array_ops import lazy_array_equiv
from .test_backends import create_tmp_file

dask = pytest.importorskip("dask")
Expand Down Expand Up @@ -428,7 +429,53 @@ def test_concat_loads_variables(self):
out.compute()
assert kernel_call_count == 24

# Finally, test that riginals are unaltered
# Finally, test that originals are unaltered
assert ds1["d"].data is d1
assert ds1["c"].data is c1
assert ds2["d"].data is d2
assert ds2["c"].data is c2
assert ds3["d"].data is d3
assert ds3["c"].data is c3

# now check that concat() is correctly using dask name equality to skip loads
out = xr.concat(
[ds1, ds1, ds1], dim="n", data_vars="different", coords="different"
)
assert kernel_call_count == 24
# variables are not loaded in the output
assert isinstance(out["d"].data, dask.array.Array)
assert isinstance(out["c"].data, dask.array.Array)

out = xr.concat(
[ds1, ds1, ds1], dim="n", data_vars=[], coords=[], compat="identical"
)
assert kernel_call_count == 24
# variables are not loaded in the output
assert isinstance(out["d"].data, dask.array.Array)
assert isinstance(out["c"].data, dask.array.Array)

out = xr.concat(
[ds1, ds2.compute(), ds3],
dim="n",
data_vars="all",
coords="different",
compat="identical",
)
# c1,c3 must be computed for comparison since c2 is numpy;
# d2 is computed too
assert kernel_call_count == 28

out = xr.concat(
[ds1, ds2.compute(), ds3],
dim="n",
data_vars="all",
coords="all",
compat="identical",
)
# no extra computes
assert kernel_call_count == 30

# Finally, test that originals are unaltered
assert ds1["d"].data is d1
assert ds1["c"].data is c1
assert ds2["d"].data is d2
Expand Down Expand Up @@ -1142,6 +1189,19 @@ def test_make_meta(map_ds):
assert meta.data_vars[variable].shape == (0,) * meta.data_vars[variable].ndim


def test_identical_coords_no_computes():
lons2 = xr.DataArray(da.zeros((10, 10), chunks=2), dims=("y", "x"))
a = xr.DataArray(
da.zeros((10, 10), chunks=2), dims=("y", "x"), coords={"lons": lons2}
)
b = xr.DataArray(
da.zeros((10, 10), chunks=2), dims=("y", "x"), coords={"lons": lons2}
)
with raise_if_dask_computes():
c = a + b
assert_identical(c, a)


@pytest.mark.parametrize(
"obj", [make_da(), make_da().compute(), make_ds(), make_ds().compute()]
)
Expand Down Expand Up @@ -1229,3 +1289,49 @@ def test_normalize_token_with_backend(map_ds):
map_ds.to_netcdf(tmp_file)
read = xr.open_dataset(tmp_file)
assert not dask.base.tokenize(map_ds) == dask.base.tokenize(read)


@pytest.mark.parametrize(
"compat", ["broadcast_equals", "equals", "identical", "no_conflicts"]
)
def test_lazy_array_equiv_variables(compat):
var1 = xr.Variable(("y", "x"), da.zeros((10, 10), chunks=2))
var2 = xr.Variable(("y", "x"), da.zeros((10, 10), chunks=2))
var3 = xr.Variable(("y", "x"), da.zeros((20, 10), chunks=2))

with raise_if_dask_computes():
assert getattr(var1, compat)(var2, equiv=lazy_array_equiv)
# values are actually equal, but we don't know that till we compute, return None
with raise_if_dask_computes():
assert getattr(var1, compat)(var2 / 2, equiv=lazy_array_equiv) is None

# shapes are not equal, return False without computes
with raise_if_dask_computes():
assert getattr(var1, compat)(var3, equiv=lazy_array_equiv) is False

# if one or both arrays are numpy, return None
assert getattr(var1, compat)(var2.compute(), equiv=lazy_array_equiv) is None
assert (
getattr(var1.compute(), compat)(var2.compute(), equiv=lazy_array_equiv) is None
)

with raise_if_dask_computes():
assert getattr(var1, compat)(var2.transpose("y", "x"))


@pytest.mark.parametrize(
"compat", ["broadcast_equals", "equals", "identical", "no_conflicts"]
)
def test_lazy_array_equiv_merge(compat):
da1 = xr.DataArray(da.zeros((10, 10), chunks=2), dims=("y", "x"))
da2 = xr.DataArray(da.zeros((10, 10), chunks=2), dims=("y", "x"))
da3 = xr.DataArray(da.ones((20, 10), chunks=2), dims=("y", "x"))

with raise_if_dask_computes():
xr.merge([da1, da2], compat=compat)
# shapes are not equal; no computes necessary
with raise_if_dask_computes(max_computes=0):
with pytest.raises(ValueError):
xr.merge([da1, da3], compat=compat)
with raise_if_dask_computes(max_computes=2):
xr.merge([da1, da2 / 2], compat=compat)

0 comments on commit af28c6b

Please sign in to comment.