Skip to content
forked from pydata/xarray

Commit

Permalink
optimizations for dask array equality comparisons
Browse files Browse the repository at this point in the history
  • Loading branch information
dcherian committed Oct 27, 2019
1 parent 79b3cdd commit d2f162d
Show file tree
Hide file tree
Showing 6 changed files with 100 additions and 14 deletions.
18 changes: 11 additions & 7 deletions xarray/core/concat.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,28 +183,32 @@ def process_subset_opt(opt, subset):
if opt == "different":
if compat == "override":
raise ValueError(
"Cannot specify both %s='different' and compat='override'."
% subset
f"Cannot specify both {subset}='different' and compat='override'."
)
# all nonindexes that are not the same in each dataset
for k in getattr(datasets[0], subset):
if k not in concat_over:
variables = [ds.variables[k] for ds in datasets]
equals[k] = utils.dask_name_equal(variables)
if equals[k]:
continue

# Compare the variable of all datasets vs. the one
# of the first dataset. Perform the minimum amount of
# loads in order to avoid multiple loads from disk
# while keeping the RAM footprint low.
v_lhs = datasets[0].variables[k].load()
v_lhs = variables[0].load()
# We'll need to know later on if variables are equal.
computed = []
for ds_rhs in datasets[1:]:
v_rhs = ds_rhs.variables[k].compute()
computed = [v_lhs]
for v_rhs in variables[1:]:
v_rhs = v_rhs.compute()
computed.append(v_rhs)
if not getattr(v_lhs, compat)(v_rhs):
concat_over.add(k)
equals[k] = False
# computed variables are not to be re-computed
# again in the future
for ds, v in zip(datasets[1:], computed):
for ds, v in zip(datasets, computed):
ds.variables[k].data = v.data
break
else:
Expand Down
8 changes: 8 additions & 0 deletions xarray/core/duck_array_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,14 @@ def array_equiv(arr1, arr2):
"""
arr1 = asarray(arr1)
arr2 = asarray(arr2)
if (
dask_array
and isinstance(arr1, dask_array.Array)
and isinstance(arr2, dask_array.Array)
):
# GH3068
if arr1.name == arr2.name:
return True
if arr1.shape != arr2.shape:
return False
with warnings.catch_warnings():
Expand Down
15 changes: 9 additions & 6 deletions xarray/core/merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

from . import dtypes, pdcompat
from .alignment import deep_align
from .utils import Frozen, dict_equiv
from .utils import Frozen, dict_equiv, dask_name_equal
from .variable import Variable, as_variable, assert_unique_multiindex_level_names

if TYPE_CHECKING:
Expand Down Expand Up @@ -123,11 +123,14 @@ def unique_variable(
combine_method = "fillna"

if equals is None:
out = out.compute()
for var in variables[1:]:
equals = getattr(out, compat)(var)
if not equals:
break
equals = dask_name_equal(variables)

if not equals:
out = out.compute()
for var in variables[1:]:
equals = getattr(out, compat)(var)
if not equals:
break

if not equals:
raise MergeError(
Expand Down
31 changes: 31 additions & 0 deletions xarray/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -676,3 +676,34 @@ def get_temp_dimname(dims: Container[Hashable], new_dim: Hashable) -> Hashable:
while new_dim in dims:
new_dim = "_" + str(new_dim)
return new_dim


def dask_name_equal(variables):
    """
    Test variable data for equality by comparing dask names if possible.
    Returns
    -------
    True or False if all variables contain dask arrays and their dask names are equal
    or not equal respectively.
    None if equality cannot be determined i.e. when not all variables contain dask arrays.
    """
    try:
        import dask.array as dask_array
    except ImportError:
        # Without dask there are no dask names to compare.
        return None

    first = variables[0]
    if not isinstance(first.data, dask_array.Array):
        return None

    # Stays None when there is nothing to compare against (a single variable):
    # a lone dask array gives us no second name, so equality is undecidable.
    result = None
    for other in variables[1:]:
        if not isinstance(other.data, dask_array.Array):
            # Mixed dask / non-dask data: names cannot settle the question.
            return None
        if first.data.name != other.data.name:
            return False
        result = True
    return result
4 changes: 3 additions & 1 deletion xarray/core/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -1229,7 +1229,9 @@ def transpose(self, *dims) -> "Variable":
if len(dims) == 0:
dims = self.dims[::-1]
axes = self.get_axis_num(dims)
if len(dims) < 2: # no need to transpose if only one dimension
if len(dims) < 2 or dims == self.dims:
# no need to transpose if only one dimension
# or dims are in same order
return self.copy(deep=False)

data = as_indexable(self._data).transpose(axes)
Expand Down
38 changes: 38 additions & 0 deletions xarray/tests/test_dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
assert_identical,
raises_regex,
)
from ..core.utils import dask_name_equal

dask = pytest.importorskip("dask")
da = pytest.importorskip("dask.array")
Expand Down Expand Up @@ -1135,3 +1136,40 @@ def test_make_meta(map_ds):
for variable in map_ds.data_vars:
assert variable in meta.data_vars
assert meta.data_vars[variable].shape == (0,) * meta.data_vars[variable].ndim


def test_identical_coords_no_computes():
    # Adding two arrays that share the same lazy (dask-backed) coordinate must
    # not trigger any computation: coordinate equality should be decided from
    # dask names alone (GH3068-style fast path).
    lazy_lons = xr.DataArray(da.zeros((10, 10), chunks=2), dims=("y", "x"))
    left = xr.DataArray(
        da.zeros((10, 10), chunks=2), dims=("y", "x"), coords={"lons": lazy_lons}
    )
    right = xr.DataArray(
        da.zeros((10, 10), chunks=2), dims=("y", "x"), coords={"lons": lazy_lons}
    )
    with raise_if_dask_computes():
        total = left + right
    # zeros + zeros is still zeros, so the sum is identical to either operand.
    assert_identical(total, left)


def test_dask_name_equal():
    # Two arrays built from identical dask graphs carry the same dask name,
    # so name comparison can prove equality without computing anything.
    arr_a = xr.DataArray(da.zeros((10, 10), chunks=2), dims=("y", "x"))
    arr_b = xr.DataArray(da.zeros((10, 10), chunks=2), dims=("y", "x"))

    with raise_if_dask_computes():
        assert dask_name_equal([arr_a, arr_b])
    with raise_if_dask_computes():
        assert not dask_name_equal([arr_a, arr_b / 2])

    # Once any operand is no longer dask-backed, names cannot decide equality.
    assert dask_name_equal([arr_a, arr_b.compute()]) is None
    assert dask_name_equal([arr_a.compute(), arr_b.compute()]) is None

    # A no-op transpose should keep the same dask name.
    with raise_if_dask_computes():
        assert dask_name_equal([arr_a, arr_a.transpose("y", "x")])

    # Merging identically-named lazy arrays must stay lazy for every compat mode.
    with raise_if_dask_computes():
        for compat_mode in (
            "broadcast_equals",
            "equals",
            "override",
            "identical",
            "no_conflicts",
        ):
            xr.merge([arr_a, arr_b], compat=compat_mode)

0 comments on commit d2f162d

Please sign in to comment.