Currently, ArviZ does not try to parallelize its functions in order to improve performance. It would be great to combine the speed-ups from using Numba with parallel computation using Dask if desired. Citing xarray docs:
> xarray integrates with Dask to support parallel computations and streaming computation on datasets that don’t fit into memory.
>
> Currently, Dask is an entirely optional feature for xarray. However, the benefits of using Dask are sufficiently strong that Dask may become a required dependency in a future version of xarray.
Both arguments, about parallelization and about arrays not fitting in memory, are relevant to ArviZ. Moreover, Dask may become a requirement for xarray, and writing our code to be Dask-compatible could save a lot of time if this eventually happens.
I propose to modify all functions using wrap_xarray_ufunc so that some kwargs can be passed on to xr.apply_ufunc; then, if the datasets in the InferenceData object are backed by dask arrays, Dask's capabilities can be used to parallelize the code. In addition, allowing kwargs to be passed to xr.open_dataset would make it possible to load InferenceData objects from netCDF files directly with Dask (with chunks if desired, for example).
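A minimal sketch of what I have in mind (the wrapper and its argument names below are simplified for illustration, not ArviZ's actual wrap_xarray_ufunc signature, and the file name is made up):

```python
import xarray as xr

# simplified sketch, not the real ArviZ function: forward user-provided kwargs
# (e.g. dask="parallelized", output_dtypes=[float]) straight to xr.apply_ufunc
def wrap_xarray_ufunc(ufunc, *datasets, input_core_dims=None, dask_kwargs=None, **kwargs):
    dask_kwargs = {} if dask_kwargs is None else dask_kwargs
    return xr.apply_ufunc(
        ufunc, *datasets, input_core_dims=input_core_dims, **dask_kwargs, **kwargs
    )

# on the loading side, forwarding chunks to xr.open_dataset gives dask-backed
# variables directly ("posterior.nc" is just an example file name)
posterior = xr.open_dataset("posterior.nc", chunks={"chain": 1})
```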
Below are some experiments showing the possible benefits of these changes (keep in mind I don't really know Dask at all; I have only read its tutorial and the xarray docs on it):
```python
import logging

import arviz as az
import numpy as np

# hide logged warning due to wrong shape, as it is already fixed in development and irrelevant
# https://github.com/pydata/xarray/issues/3168
logging.getLogger("arviz").setLevel(logging.ERROR)

idata = az.from_dict(posterior={"var": np.random.random(size=(10, 3000, 2000))}, dims={"var": ["dim1"]})

print("########### xr.Dataset containing numpy array ###########")
%timeit ess = az.ess(idata)
ess_orig = az.ess(idata)

print("########### xr.Dataset containing dask array ###########")
# chunking an xr.Dataset automatically converts it to a dask array
idata.posterior = idata.posterior.chunk({"dim1": 100})
# here there may be another warning https://github.com/pydata/xarray/issues/2928
# that can be solved by updating to xarray's latest version
%timeit ess_parallelized = az.ess(idata, dask="parallelized", output_dtypes=[float]).compute()
ess_parallelized = az.ess(idata, dask="parallelized", output_dtypes=[float]).compute()

print("########### check results ###########")
print((ess_orig == ess_parallelized).all())
```
which outputs:
```
########### xr.Dataset containing numpy array ###########
15.2 s ± 89.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
########### xr.Dataset containing dask array ###########
6.09 s ± 182 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
########### check results ###########
<xarray.Dataset>
Dimensions:  ()
Data variables:
    var      bool True
```
As you can see, the speed-up is already quite significant with this vanilla solution. There is probably room for improvement using pure Dask; however, I do not think this should be an immediate concern (or a concern at all, maybe an idea for next year's GSoC?). Using Dask via xarray should be simple enough and provide good speed-ups for large data.
Implementation ideas
To sum up what was said above, we can use Dask via xarray to handle parallelization with very little work involved (passing kwargs to wrap_xarray_ufunc and to xr.open_dataset should be enough). However, some things must be taken into account:
xr.apply_ufunc core dims cannot be chunked, which means that chunking on chains or draws won't work with most of the functions.
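To make that concrete, reusing the idata object from the experiment above (a sketch; the exact error behaviour depends on the xarray version):

```python
# chunking along an extra dimension keeps "chain" and "draw" in single chunks,
# so dask="parallelized" can split the work per "dim1" block
idata.posterior = idata.posterior.chunk({"dim1": 100})

# chunking along "draw" splits a core dimension of ess/rhat, which
# xr.apply_ufunc(..., dask="parallelized") rejects; the core dims would have
# to be rechunked back into a single chunk first, e.g.
# idata.posterior = idata.posterior.chunk({"draw": -1, "dim1": 100})
```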
Dask parallelization must always be optional (even in the case of xarray including it as a dependency) because for small arrays, the overhead of keeping track of the labels and dividing the workload between chunks makes parallelization with Dask slower.
By default, Dask results are not evaluated, and .compute() must be called on them. In xarray, there is also the option of getting the .values attribute, which should also trigger the evaluation.
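For example (assuming the kwarg forwarding proposed above, so az.ess returns a lazy, dask-backed Dataset):

```python
lazy_ess = az.ess(idata, dask="parallelized", output_dtypes=[float])  # not computed yet

evaluated = lazy_ess.compute()   # explicitly evaluate the dask graph
values = lazy_ess["var"].values  # accessing .values also triggers computation
```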
Dask works with both NumPy and Numba, so there should be no issues when combining them, but problems can always appear.
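As a toy check of that combination (not ArviZ code, just a Numba-compiled ufunc applied to a chunked DataArray through xr.apply_ufunc):

```python
import numba
import numpy as np
import xarray as xr

@numba.vectorize(["float64(float64)"], nopython=True)
def plus_one(x):  # trivial elementwise function compiled with numba
    return x + 1.0

da = xr.DataArray(np.random.random((4, 1000)), dims=("chain", "draw")).chunk({"chain": 1})
out = xr.apply_ufunc(plus_one, da, dask="parallelized", output_dtypes=[float])
print(out.compute())
```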
I do not have a clear idea of how to pass these Dask-related kwargs for all relevant functions; for ess it can be quite straightforward, but what about summary? Should there be an option to allow different Dask kwargs for ess and for rhat?
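One possibility, purely hypothetical and only to make the question concrete, would be something like a nested dict of per-function kwargs:

```python
# hypothetical API, not implemented: different dask kwargs per underlying function
az.summary(
    idata,
    dask_kwargs={
        "ess": {"dask": "parallelized", "output_dtypes": [float]},
        "rhat": {"dask": "parallelized", "output_dtypes": [float]},
    },
)
```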
Also, should ArviZ return unevaluated Dask expressions so the user can run compute() whenever they want? Should there be an option to optionally compute things? For summary, for instance, it could be better to compute ess and rhat at the same time; also, could summary work without evaluating the values, or will it always trigger computation via the .values attribute?
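If lazy results were returned, evaluating several of them together is straightforward on the Dask side (again assuming the kwarg forwarding sketched above):

```python
import dask

lazy_ess = az.ess(idata, dask="parallelized", output_dtypes=[float])
lazy_rhat = az.rhat(idata, dask="parallelized", output_dtypes=[float])
# one pass over the shared graph instead of two separate .compute() calls
ess_vals, rhat_vals = dask.compute(lazy_ess, lazy_rhat)
```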
It would probably be great to eventually add some tests on InferenceData objects containing dask arrays too, but not too many, to avoid issues like the segmentation faults we would get if tests were not run in "eager" mode.
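A rough sketch of what such a test could look like (made-up test; it relies on the kwarg forwarding proposed above):

```python
import numpy as np
import arviz as az

def test_ess_dask_matches_numpy():
    idata = az.from_dict(
        posterior={"var": np.random.random(size=(4, 500, 100))},
        dims={"var": ["dim1"]},
    )
    expected = az.ess(idata)  # eager, numpy-backed reference
    idata.posterior = idata.posterior.chunk({"dim1": 50})
    result = az.ess(idata, dask="parallelized", output_dtypes=[float]).compute()
    assert np.allclose(expected["var"].values, result["var"].values)
```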