Skip to content

Commit

Permalink
dask: Testing number of computes on reduce methods.
Browse files Browse the repository at this point in the history
  • Loading branch information
Ulrich J. Herter committed Aug 23, 2019
1 parent 05ae290 commit a4c3622
Showing 1 changed file with 44 additions and 13 deletions.
57 changes: 44 additions & 13 deletions xarray/tests/test_dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,43 @@
dd = pytest.importorskip("dask.dataframe")


class CountingScheduler:
    """Dask scheduler wrapper that counts how often it is invoked.

    Each call increments ``total_computes``. As soon as that count exceeds
    ``max_computes`` a ``RuntimeError`` is raised; otherwise the graph is
    handed on to ``dask.get``.

    Reference: https://stackoverflow.com/questions/53289286/
    """

    def __init__(self, max_computes=0):
        self.max_computes = max_computes
        self.total_computes = 0

    def __call__(self, dsk, keys, **kwargs):
        self.total_computes += 1
        if self.total_computes > self.max_computes:
            # NOTE: the "To many" spelling is matched verbatim by the regex in
            # test_counting_scheduler; change both together or not at all.
            counts = (self.total_computes, self.max_computes)
            raise RuntimeError("To many computes. Total: %d > max: %d." % counts)
        return dask.get(dsk, keys, **kwargs)


def _set_dask_scheduler(scheduler):
    """Install ``scheduler`` as the dask scheduler, whatever the dask version.

    For dask >= 0.18.0 the scheduler is set through ``dask.config.set``;
    older versions go through ``dask.set_options(get=...)``. The result of
    that call is returned unchanged; callers in this file use it in a
    ``with`` statement.
    """
    has_config_api = LooseVersion(dask.__version__) >= LooseVersion("0.18.0")
    if not has_config_api:
        return dask.set_options(get=scheduler)
    return dask.config.set(scheduler=scheduler)


def test_counting_scheduler():
    """A zero-budget CountingScheduler must reject the very first compute."""
    counter = CountingScheduler(0)
    values = np.random.RandomState(0).randn(4, 6)
    lazy = da.from_array(values, chunks=(2, 2))

    # The compute is attempted exactly once and must fail with the
    # scheduler's own error message.
    with raises_regex(RuntimeError, "To many computes"), _set_dask_scheduler(
        counter
    ):
        lazy.compute()

    assert counter.total_computes == 1


class DaskTestCase:
def assertLazyAnd(self, expected, actual, test):

with (
dask.config.set(scheduler="single-threaded")
if LooseVersion(dask.__version__) >= LooseVersion("0.18.0")
else dask.set_options(get=dask.get)
):
with _set_dask_scheduler(CountingScheduler(1)):
test(actual, expected)

if isinstance(actual, Dataset):
Expand Down Expand Up @@ -172,13 +201,15 @@ def test_pickle(self):
def test_reduce(self):
# Compare reductions on the eager variable (u) with the same reductions on
# the lazy variable (v) through assertLazyAndAllClose.
u = self.eager_var
v = self.lazy_var
# NOTE(review): the next seven lines repeat the calls that also appear under
# the CountingScheduler(0) block below. This looks like the pre-change side
# of a diff shown next to the post-change side -- confirm against the real
# file before relying on both copies.
self.assertLazyAndAllClose(u.mean(), v.mean())
self.assertLazyAndAllClose(u.std(), v.std())
self.assertLazyAndAllClose(u.argmax(dim="x"), v.argmax(dim="x"))
self.assertLazyAndAllClose((u > 1).any(), (v > 1).any())
self.assertLazyAndAllClose((u < 1).all("x"), (v < 1).all("x"))
# median() on the lazy variable is expected to raise NotImplementedError
# with a message matching "dask".
with raises_regex(NotImplementedError, "dask"):
v.median()
# Install a scheduler with a budget of zero computes for the reductions.
with _set_dask_scheduler(CountingScheduler(0)):
# None of the methods should trigger compute at this stage.
self.assertLazyAndAllClose(u.mean(), v.mean())
self.assertLazyAndAllClose(u.std(), v.std())
self.assertLazyAndAllClose(u.argmax(dim="x"), v.argmax(dim="x"))
self.assertLazyAndAllClose((u > 1).any(), (v > 1).any())
self.assertLazyAndAllClose((u < 1).all("x"), (v < 1).all("x"))
with raises_regex(NotImplementedError, "dask"):
v.median()

def test_missing_values(self):
values = np.array([0, 1, np.nan, 3])
Expand Down

0 comments on commit a4c3622

Please sign in to comment.