Recursive tokenization (#3515)
* recursive tokenize

* black

* What's New

* Also test Dataset

* Also test IndexVariable

* Cleanup

* tokenize sparse objects
crusaderky authored Nov 13, 2019
1 parent b74f80c commit e70138b
Showing 6 changed files with 45 additions and 5 deletions.
2 changes: 1 addition & 1 deletion doc/whats-new.rst
@@ -73,7 +73,7 @@ New Features
   for xarray objects. Note that xarray objects with a dask.array backend already used
   deterministic hashing in previous releases; this change implements it when whole
   xarray objects are embedded in a dask graph, e.g. when :meth:`DataArray.map` is
-  invoked. (:issue:`3378`, :pull:`3446`)
+  invoked. (:issue:`3378`, :pull:`3446`, :pull:`3515`)
   By `Deepak Cherian <https://github.com/dcherian>`_ and
   `Guido Imperiale <https://github.com/crusaderky>`_.

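To make the whats-new entry above concrete, here is a minimal sketch of what deterministic hashing of whole xarray objects means. It is not part of the commit; it assumes an xarray that includes this change and dask installed, and the names are purely illustrative.

import numpy as np
import xarray as xr
import dask.base

# Two independently built but identical DataArrays produce the same token...
a = xr.DataArray(np.arange(4), dims=["x"], name="foo")
b = xr.DataArray(np.arange(4), dims=["x"], name="foo")
assert dask.base.tokenize(a) == dask.base.tokenize(b)

# ...while changing the data (or name, coords, attrs) changes the token.
c = a.copy()
c[0] = 99
assert dask.base.tokenize(a) != dask.base.tokenize(c)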
4 changes: 3 additions & 1 deletion xarray/core/dataarray.py
@@ -755,7 +755,9 @@ def reset_coords(
             return dataset
 
     def __dask_tokenize__(self):
-        return (type(self), self._variable, self._coords, self._name)
+        from dask.base import normalize_token
+
+        return normalize_token((type(self), self._variable, self._coords, self._name))
 
     def __dask_graph__(self):
         return self._to_temp_dataset().__dask_graph__()
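Why wrap the tuple in normalize_token at all? My reading of the change (not stated in the diff): dask hashes the str() of whatever __dask_tokenize__ returns, and str() of a large NumPy array elides most of its elements, so two different arrays nested in the returned tuple could previously collide; calling normalize_token explicitly makes the normalization recurse into the nested arrays, dicts and indexes and hash their actual contents. A small sketch of the failure mode being avoided:

import numpy as np
from dask.base import normalize_token

a = np.ones(10000)
b = np.ones(10000)
b[5000] = 2

# str() abbreviates large arrays, so the two look identical as strings...
assert str(a) == str(b)
# ...but normalize_token hashes the actual data buffer, so the results differ.
assert str(normalize_token(a)) != str(normalize_token(b))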
6 changes: 5 additions & 1 deletion xarray/core/dataset.py
@@ -652,7 +652,11 @@ def load(self, **kwargs) -> "Dataset":
         return self
 
     def __dask_tokenize__(self):
-        return (type(self), self._variables, self._coord_names, self._attrs)
+        from dask.base import normalize_token
+
+        return normalize_token(
+            (type(self), self._variables, self._coord_names, self._attrs)
+        )
 
     def __dask_graph__(self):
         graphs = {k: v.__dask_graph__() for k, v in self.variables.items()}
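The Dataset variant feeds _variables, _coord_names and _attrs through the same recursive normalization, so metadata participates in the hash as well. A hypothetical usage sketch, again assuming an xarray with this change and dask installed:

import numpy as np
import xarray as xr
import dask.base

ds = xr.Dataset({"x": ("t", np.arange(3))}, attrs={"units": "m"})

# Changing attrs changes the token, because _attrs is part of the tuple above.
ds2 = ds.copy(deep=True)
ds2.attrs["units"] = "km"
assert dask.base.tokenize(ds) != dask.base.tokenize(ds2)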
8 changes: 6 additions & 2 deletions xarray/core/variable.py
@@ -393,7 +393,9 @@ def compute(self, **kwargs):
     def __dask_tokenize__(self):
         # Use v.data, instead of v._data, in order to cope with the wrappers
         # around NetCDF and the like
-        return type(self), self._dims, self.data, self._attrs
+        from dask.base import normalize_token
+
+        return normalize_token((type(self), self._dims, self.data, self._attrs))
 
     def __dask_graph__(self):
         if isinstance(self._data, dask_array_type):
@@ -1973,8 +1975,10 @@ def __init__(self, dims, data, attrs=None, encoding=None, fastpath=False):
             self._data = PandasIndexAdapter(self._data)
 
     def __dask_tokenize__(self):
+        from dask.base import normalize_token
+
         # Don't waste time converting pd.Index to np.ndarray
-        return (type(self), self._dims, self._data.array, self._attrs)
+        return normalize_token((type(self), self._dims, self._data.array, self._attrs))
 
     def load(self):
         # data is already loaded into memory for IndexVariable
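For context on how these hooks are reached (based on dask's documented tokenization protocol, not something shown in this diff): dask.base.tokenize() normalizes each argument, and an object that defines __dask_tokenize__ is normalized by calling that method, so returning normalize_token(...) here is what makes the normalization descend into the dims, the wrapped data and the attrs. A minimal sketch with illustrative names:

import numpy as np
import xarray as xr
import dask.base

v1 = xr.Variable(("x",), np.arange(4), attrs={"long_name": "demo"})
v2 = xr.Variable(("x",), np.arange(4), attrs={"long_name": "demo"})
v3 = xr.Variable(("x",), np.arange(4) + 1, attrs={"long_name": "demo"})

# Equal Variables hash alike; different data (or dims/attrs) changes the hash.
assert dask.base.tokenize(v1) == dask.base.tokenize(v2)
assert dask.base.tokenize(v1) != dask.base.tokenize(v3)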
26 changes: 26 additions & 0 deletions xarray/tests/test_dask.py
@@ -1283,6 +1283,32 @@ def test_token_identical(obj, transform):
     )
 
 
+def test_recursive_token():
+    """Test that tokenization is invoked recursively, and doesn't just rely on the
+    output of str()
+    """
+    a = np.ones(10000)
+    b = np.ones(10000)
+    b[5000] = 2
+    assert str(a) == str(b)
+    assert dask.base.tokenize(a) != dask.base.tokenize(b)
+
+    # Test DataArray and Variable
+    da_a = DataArray(a)
+    da_b = DataArray(b)
+    assert dask.base.tokenize(da_a) != dask.base.tokenize(da_b)
+
+    # Test Dataset
+    ds_a = da_a.to_dataset(name="x")
+    ds_b = da_b.to_dataset(name="x")
+    assert dask.base.tokenize(ds_a) != dask.base.tokenize(ds_b)
+
+    # Test IndexVariable
+    da_a = DataArray(a, dims=["x"], coords={"x": a})
+    da_b = DataArray(a, dims=["x"], coords={"x": b})
+    assert dask.base.tokenize(da_a) != dask.base.tokenize(da_b)
+
+
 @requires_scipy_or_netCDF4
 def test_normalize_token_with_backend(map_ds):
     with create_tmp_file(allow_cleanup_failure=ON_WINDOWS) as tmp_file:
4 changes: 4 additions & 0 deletions xarray/tests/test_sparse.py
@@ -856,6 +856,10 @@ def test_dask_token():
     import dask
 
     s = sparse.COO.from_numpy(np.array([0, 0, 1, 2]))
+
+    # https://github.com/pydata/sparse/issues/300
+    s.__dask_tokenize__ = lambda: dask.base.normalize_token(s.__dict__)
+
     a = DataArray(s)
     t1 = dask.base.tokenize(a)
     t2 = dask.base.tokenize(a)
