From a4c3622c9ee1edcfef3cb7e9170522150556d948 Mon Sep 17 00:00:00 2001 From: "Ulrich J. Herter" Date: Fri, 23 Aug 2019 13:05:13 +0200 Subject: [PATCH] dask: Testing number of computes on reduce methods. --- xarray/tests/test_dask.py | 57 ++++++++++++++++++++++++++++++--------- 1 file changed, 44 insertions(+), 13 deletions(-) diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py index e3fc6f65e0f..83bca6c06ba 100644 --- a/xarray/tests/test_dask.py +++ b/xarray/tests/test_dask.py @@ -27,14 +27,43 @@ dd = pytest.importorskip("dask.dataframe") +class CountingScheduler: + """ Simple dask scheduler counting the number of computes. + + Reference: https://stackoverflow.com/questions/53289286/ """ + + def __init__(self, max_computes=0): + self.total_computes = 0 + self.max_computes = max_computes + + def __call__(self, dsk, keys, **kwargs): + self.total_computes += 1 + if self.total_computes > self.max_computes: + raise RuntimeError( + "To many computes. Total: %d > max: %d." + % (self.total_computes, self.max_computes) + ) + return dask.get(dsk, keys, **kwargs) + + +def _set_dask_scheduler(scheduler): + if LooseVersion(dask.__version__) >= LooseVersion("0.18.0"): + return dask.config.set(scheduler=scheduler) + return dask.set_options(get=scheduler) + + +def test_counting_scheduler(): + data = da.from_array(np.random.RandomState(0).randn(4, 6), chunks=(2, 2)) + sched = CountingScheduler(0) + with raises_regex(RuntimeError, "To many computes"): + with _set_dask_scheduler(sched): + data.compute() + assert sched.total_computes == 1 + + class DaskTestCase: def assertLazyAnd(self, expected, actual, test): - - with ( - dask.config.set(scheduler="single-threaded") - if LooseVersion(dask.__version__) >= LooseVersion("0.18.0") - else dask.set_options(get=dask.get) - ): + with _set_dask_scheduler(CountingScheduler(1)): test(actual, expected) if isinstance(actual, Dataset): @@ -172,13 +201,15 @@ def test_pickle(self): def test_reduce(self): u = self.eager_var v = self.lazy_var - self.assertLazyAndAllClose(u.mean(), v.mean()) - self.assertLazyAndAllClose(u.std(), v.std()) - self.assertLazyAndAllClose(u.argmax(dim="x"), v.argmax(dim="x")) - self.assertLazyAndAllClose((u > 1).any(), (v > 1).any()) - self.assertLazyAndAllClose((u < 1).all("x"), (v < 1).all("x")) - with raises_regex(NotImplementedError, "dask"): - v.median() + with _set_dask_scheduler(CountingScheduler(0)): + # None of the methods should trigger compute at this stage. + self.assertLazyAndAllClose(u.mean(), v.mean()) + self.assertLazyAndAllClose(u.std(), v.std()) + self.assertLazyAndAllClose(u.argmax(dim="x"), v.argmax(dim="x")) + self.assertLazyAndAllClose((u > 1).any(), (v > 1).any()) + self.assertLazyAndAllClose((u < 1).all("x"), (v < 1).all("x")) + with raises_regex(NotImplementedError, "dask"): + v.median() def test_missing_values(self): values = np.array([0, 1, np.nan, 3])