Skip to content

Commit

Permalink
Cache chunk boundaries for integer slicing (dask#4923)
Browse files Browse the repository at this point in the history
This is an alternative to dask#4909, to implement dask#4867.

Instead of caching in the class as in dask#4909, use functools.lru_cache.
This unfortunately has a fixed cache size rather than a cache entry
stored with each array, but simplifies the code as it is not necessary
to pass the cached value from the Array class down through the call tree
to the point of use.

A quick benchmark shows that the result for indexing a single value from
a large array is similar to that from dask#4909, i.e., around 10x faster for
constructing the graph.

This only applies the cache in `_slice_1d`, so should be considered a
proof-of-concept.

* Move cached_cumsum to dask/array/slicing.py

It can't go in dask/utils.py because the top level is not supposed to
depend on numpy.

* cached_cumsum: index cache by both id and hash

The underlying _cumsum is first called with _HashIdWrapper, which will
hit (very cheaply) if we've seen this tuple object before. If not, it
will call itself again without the wrapper, which will hit (but at a
higher cost for tuple.__hash__) if we've seen the same value before but
in a different tuple object.

* Apply cached_cumsum in more places
  • Loading branch information
bmerry authored and mrocklin committed Jun 14, 2019
1 parent 66531ba commit 1f821f4
Show file tree
Hide file tree
Showing 5 changed files with 104 additions and 13 deletions.
8 changes: 4 additions & 4 deletions dask/array/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
from ..highlevelgraph import HighLevelGraph
from ..bytes.core import get_mapper, get_fs_token_paths
from .numpy_compat import _Recurser, _make_sliced_dtype
from .slicing import slice_array, replace_ellipsis
from .slicing import slice_array, replace_ellipsis, cached_cumsum
from .blockwise import blockwise

config.update_defaults({'array': {
Expand Down Expand Up @@ -561,14 +561,14 @@ def map_blocks(func, *args, **kwargs):
if drop_axis:
# We concatenate along dropped axes, so we need to treat them
# as if there is only a single chunk.
starts[i] = [(np.cumsum((0,) + arg.chunks[j])
starts[i] = [(cached_cumsum(arg.chunks[j], initial_zero=True)
if ind in out_ind else np.array([0, arg.shape[j]]))
for j, ind in enumerate(in_ind)]
num_chunks[i] = tuple(len(s) - 1 for s in starts[i])
else:
starts[i] = [np.cumsum((0,) + c) for c in arg.chunks]
starts[i] = [cached_cumsum(c, initial_zero=True) for c in arg.chunks]
num_chunks[i] = arg.numblocks
out_starts = [np.cumsum((0,) + c) for c in out.chunks]
out_starts = [cached_cumsum(c, initial_zero=True) for c in out.chunks]

for k, v in dsk.items():
vv = v
Expand Down
9 changes: 5 additions & 4 deletions dask/array/creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
from . import chunk
from .core import (Array, asarray, normalize_chunks,
stack, concatenate, block,
broadcast_to, broadcast_arrays)
broadcast_to, broadcast_arrays,
cached_cumsum)
from .wrap import empty, ones, zeros, full
from .utils import AxisError, meta_from_array, zeros_like_safe

Expand Down Expand Up @@ -549,8 +550,8 @@ def _diag_len(dim1, dim2, offset):

diag_chunks = []
chunk_offsets = []
cum1 = [0] + list(np.cumsum(a.chunks[axis1]))[:-1]
cum2 = [0] + list(np.cumsum(a.chunks[axis2]))[:-1]
cum1 = list(cached_cumsum(a.chunks[axis1], initial_zero=True)[:-1])
cum2 = list(cached_cumsum(a.chunks[axis2], initial_zero=True)[:-1])
for co1, c1 in zip(cum1, a.chunks[axis1]):
chunk_offsets.append([])
for co2, c2 in zip(cum2, a.chunks[axis2]):
Expand Down Expand Up @@ -728,7 +729,7 @@ def repeat(a, repeats, axis=None):
if repeats == 1:
return a

cchunks = np.cumsum((0,) + a.chunks[axis])
cchunks = cached_cumsum(a.chunks[axis], initial_zero=True)
slices = []
for c_start, c_stop in sliding_window(2, cchunks):
ls = np.linspace(c_start, c_stop, repeats).round(0)
Expand Down
2 changes: 1 addition & 1 deletion dask/array/routines.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,7 +484,7 @@ def gradient(f, *varargs, **kwargs):
if np.min(c) < kwargs["edge_order"] + 1:
raise ValueError(
'Chunk size must be larger than edge_order + 1. '
'Minimum chunk for aixs {} is {}. Rechunk to '
'Minimum chunk for axis {} is {}. Rechunk to '
'proceed.'.format(np.min(c), ax))

if np.isscalar(varargs[i]):
Expand Down
69 changes: 66 additions & 3 deletions dask/array/slicing.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from numbers import Integral, Number
from operator import getitem, itemgetter
import warnings
import functools

import numpy as np
from toolz import memoize, merge, pluck, concat
Expand Down Expand Up @@ -273,7 +274,7 @@ def slice_slices_and_integers(out_name, in_name, blockdims, index):
_slice_1d
"""
shape = tuple(map(sum, blockdims))
shape = tuple(cached_cumsum(dim, initial_zero=True)[-1] for dim in blockdims)

for dim, ind in zip(shape, index):
if np.isnan(dim) and ind != slice(None, None, None):
Expand Down Expand Up @@ -381,7 +382,7 @@ def _slice_1d(dim_shape, lengths, index):
>>> _slice_1d(100, [20, 20, 20, 20, 20], slice(100, -12, -3))
{4: slice(-1, -12, -3)}
"""
chunk_boundaries = np.cumsum(lengths, dtype=np.int64)
chunk_boundaries = cached_cumsum(lengths)

if isinstance(index, Integral):
# use right-side search to be consistent with previous result
Expand Down Expand Up @@ -521,7 +522,7 @@ def slicing_plan(chunks, index):
A list of chunk/sub-index pairs corresponding to each output chunk
"""
index = np.asanyarray(index)
cum_chunks = np.cumsum(chunks)
cum_chunks = cached_cumsum(chunks)

chunk_locations = np.searchsorted(cum_chunks, index, side='right')
where = np.where(np.diff(chunk_locations))[0] + 1
Expand Down Expand Up @@ -1036,3 +1037,65 @@ def slice_with_bool_dask_array(x, index):

def getitem_variadic(x, *index):
    """Index ``x`` with all trailing positional arguments packed as one tuple.

    Parameters
    ----------
    x : object
        Anything supporting ``__getitem__``.
    *index
        Components of the index; they are passed to ``x`` together as a
        single tuple.

    Returns
    -------
    object
        Whatever ``x[index]`` evaluates to for the packed tuple ``index``.
    """
    return getitem(x, index)


class _HashIdWrapper(object):
    """Give an object identity-based hashing and equality.

    Two wrappers compare equal only when they wrap the very same object
    (``is``), and the hash is the ``id`` of the wrapped object. This makes
    it possible to use an object as a dictionary/cache key without paying
    for its own ``__hash__`` or ``__eq__``.

    The wrapper holds a reference to ``wrapped``, so the wrapped object
    stays alive for as long as the wrapper does.
    """

    def __init__(self, wrapped):
        self.wrapped = wrapped

    def __hash__(self):
        return id(self.wrapped)

    def __eq__(self, other):
        if isinstance(other, _HashIdWrapper):
            return other.wrapped is self.wrapped
        # Let Python try the reflected operation / fall back to identity.
        return NotImplemented

    def __ne__(self, other):
        if isinstance(other, _HashIdWrapper):
            return other.wrapped is not self.wrapped
        return NotImplemented


@functools.lru_cache()
def _cumsum(seq):
    """Memoized cumulative sum of ``seq`` with a leading zero.

    Parameters
    ----------
    seq : tuple or _HashIdWrapper
        Either a hashable sequence of numbers, in which case the cache is
        keyed by value, or a ``_HashIdWrapper`` around such a sequence, in
        which case the cache is keyed by the identity of the wrapped object.

    Returns
    -------
    numpy.ndarray
        Read-only array of length ``len(seq) + 1`` whose first element is 0
        and whose remaining elements are the running totals of ``seq``.
        The dtype is ``int64`` when ``np.array(seq)`` has an integer dtype,
        otherwise the dtype of ``np.array(seq)``.

    Notes
    -----
    ``lru_cache`` returns the *same* array object on every cache hit. The
    array is therefore marked read-only so that an accidental in-place
    modification by one caller raises instead of silently corrupting the
    result seen by every later caller.
    """
    if isinstance(seq, _HashIdWrapper):
        # Identity-keyed miss: defer to the value-keyed entry, which hits if
        # an equal sequence was seen before under a different object.
        return _cumsum(seq.wrapped)
    seq = np.array(seq)
    dtype = np.int64 if np.issubdtype(seq.dtype, np.integer) else seq.dtype
    out = np.empty(len(seq) + 1, dtype)
    out[0] = 0
    np.cumsum(seq, out=out[1:], dtype=dtype)
    # Freeze the shared result; views taken from it inherit the flag.
    out.flags.writeable = False
    return out


def cached_cumsum(seq, initial_zero=False):
    """Compute :meth:`np.cumsum` of ``seq`` with caching.

    Tuples are first looked up by object identity, which avoids hashing the
    whole tuple when the same object has been seen before; on a miss the
    lookup falls back to the tuple's value. Any other sequence is copied
    into a temporary tuple and looked up by value only. Because of the
    identity lookup, a tuple passed here should contain immutable objects,
    and the function is most useful when ``seq`` is a long-lived value.

    The result has type int64 if the sequence contains integers, and
    otherwise the type of ``np.array(seq)``.

    The returned array may be shared with other callers through the cache,
    so it should not be modified in place.

    Parameters
    ----------
    seq : tuple or other sequence
        Values to cumulatively sum.
    initial_zero : bool, optional
        If true, the return value is prefixed with a zero.

    Returns
    -------
    numpy.ndarray
        The running totals of ``seq``, optionally preceded by a zero.
    """
    if not isinstance(seq, tuple):
        # Snapshot into a temporary tuple so the lookup is by value.
        cumulative = _cumsum(tuple(seq))
    else:
        # Cheap identity lookup first; falls back to value inside _cumsum.
        cumulative = _cumsum(_HashIdWrapper(seq))

    return cumulative if initial_zero else cumulative[1:]
29 changes: 28 additions & 1 deletion dask/array/tests/test_slicing.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import dask.array as da
from dask.array.slicing import (_sanitize_index_element, _slice_1d,
new_blockdim, sanitize_index, slice_array,
take, normalize_index, slicing_plan)
take, normalize_index, slicing_plan, cached_cumsum)
from dask.array.utils import assert_eq, same_keys


Expand Down Expand Up @@ -799,6 +799,33 @@ def test_pathological_unsorted_slicing():
assert 'out-of-order' in str(info.list[0])


def test_cached_cumsum():
    """Integer input yields int64 running totals, with or without the zero."""
    chunks = (1, 2, 3, 4)
    without_zero = cached_cumsum(chunks)
    with_zero = cached_cumsum(chunks, initial_zero=True)

    cases = [
        (without_zero, [1, 3, 6, 10]),
        (with_zero, [0, 1, 3, 6, 10]),
    ]
    for result, expected in cases:
        np.testing.assert_array_equal(result, expected)
        assert result.dtype == np.int64


def test_cached_cumsum_nan():
    """A NaN entry propagates through the sum and forces a float64 result."""
    chunks = (1, np.nan, 3)
    without_zero = cached_cumsum(chunks)
    with_zero = cached_cumsum(chunks, initial_zero=True)

    cases = [
        (without_zero, [1, np.nan, np.nan]),
        (with_zero, [0, 1, np.nan, np.nan]),
    ]
    for result, expected in cases:
        np.testing.assert_array_equal(result, expected)
        assert result.dtype == np.float64


def test_cached_cumsum_non_tuple():
    """Mutating a list between calls must not return a stale cached sum."""
    values = [1, 2, 3]
    first = cached_cumsum(values)
    np.testing.assert_array_equal(first, [1, 3, 6])

    # Same list object, different contents: the lookup has to be by value.
    values[1] = 4
    second = cached_cumsum(values)
    np.testing.assert_array_equal(second, [1, 5, 8])


@pytest.mark.parametrize('params', [(2, 2, 1), (5, 3, 2)])
def test_setitem_with_different_chunks_preserves_shape(params):
""" Reproducer for https://github.com/dask/dask/issues/3730.
Expand Down

0 comments on commit 1f821f4

Please sign in to comment.