Implement groupby_blockwise using map_selection
tomwhite committed Oct 29, 2024
1 parent 3ea32ed commit cd358c6
Showing 2 changed files with 15 additions and 25 deletions.
32 changes: 13 additions & 19 deletions cubed/core/groupby.py
@@ -2,8 +2,8 @@

 from cubed.array_api.manipulation_functions import broadcast_to, expand_dims
 from cubed.backend_array_api import namespace as nxp
-from cubed.core.ops import map_blocks, map_direct, reduction
-from cubed.utils import array_memory, get_item
+from cubed.core.ops import map_blocks, map_selection, reduction
+from cubed.utils import get_item
 from cubed.vendor.dask.array.core import normalize_chunks

 if TYPE_CHECKING:
@@ -181,57 +181,51 @@ def groupby_blockwise(
     )
     target_chunks = normalize_chunks(chunks, shape, dtype=dtype)

-    # memory allocated by reading one chunk from input array
-    # note that although read_chunks will overlap multiple input chunks, zarr will
-    # read the chunks in series, reusing the buffer
-    extra_projected_mem = x.chunkmem
+    def selection_function(out_key):
+        out_coords = out_key[1:]
+        block_id = out_coords
+        return get_item(read_chunks, block_id)

-    # memory allocated for largest of (variable sized) read_chunks
-    read_chunksize = tuple(max(c) for c in read_chunks)
-    extra_projected_mem += array_memory(x.dtype, read_chunksize)
+    # in general each selection overlaps 2 input blocks
+    max_num_input_blocks = 2

-    return map_direct(
+    return map_selection(
         _process_blockwise_chunk,
+        selection_function,
         x,
         shape=shape,
         dtype=dtype,
         chunks=target_chunks,
-        extra_projected_mem=extra_projected_mem,
+        max_num_input_blocks=max_num_input_blocks,
         axis=axis,
         by=by,
         blockwise_func=func,
-        read_chunks=read_chunks,
         by_read_chunks=by_read_chunks,
         target_chunks=target_chunks,
         groups_per_chunk=groups_per_chunk,
         extra_func_kwargs=extra_func_kwargs,
     )


 def _process_blockwise_chunk(
-    x,
-    *arrays,
+    a,
     axis=None,
     by=None,
     blockwise_func=None,
-    read_chunks=None,
     by_read_chunks=None,
     target_chunks=None,
     groups_per_chunk=None,
     block_id=None,
     **kwargs,
 ):
-    array = arrays[0].zarray  # underlying Zarr array (or virtual array)
     idx = block_id
     bi = idx[axis]

-    result = array[get_item(read_chunks, idx)]
     by = by[get_item(by_read_chunks, (bi,))]

     start_group = bi * groups_per_chunk

     return blockwise_func(
-        result,
+        a,
         by,
         axis=axis,
         start_group=start_group,
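The switch from map_direct to map_selection moves memory accounting into the framework: instead of the callback reading from the underlying Zarr array itself (with the caller estimating extra_projected_mem by hand), a selection_function tells map_selection which slice of the input each output chunk needs, and max_num_input_blocks bounds how many input blocks one selection can touch (read_chunks boundaries follow group boundaries rather than the input's regular chunk grid, so a selection generally straddles at most two input blocks). A minimal sketch of what that selection computes; slices_for_block is a hypothetical stand-in for cubed.utils.get_item, written out here so the example is self-contained:

    import itertools

    def slices_for_block(chunks, block_id):
        # Hypothetical stand-in for cubed.utils.get_item: turn a block index
        # into the tuple of slices that block covers, given per-axis chunk sizes.
        starts = [tuple(itertools.accumulate((0,) + c)) for c in chunks]
        return tuple(
            slice(starts[ax][i], starts[ax][i + 1]) for ax, i in enumerate(block_id)
        )

    # read_chunks are variable-sized along the grouped axis so that each output
    # chunk reads whole groups, e.g. 6 rows then 4 rows of a 10-row array:
    read_chunks = ((6, 4), (3,))

    out_key = ("out", 1, 0)   # (array name, *output block coordinates)
    block_id = out_key[1:]    # -> (1, 0), exactly as in selection_function
    print(slices_for_block(read_chunks, block_id))
    # (slice(6, 10, None), slice(0, 3, None))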
8 changes: 2 additions & 6 deletions cubed/tests/test_groupby.py
@@ -96,15 +96,13 @@ def test_get_chunks_for_groups(
 def test_groupby_blockwise_axis0():
     a = xp.ones((10, 3), dtype=nxp.int32, chunks=(6, 2))
     b = nxp.asarray([0, 0, 0, 1, 1, 2, 2, 4, 4, 4])
-    extra_func_kwargs = dict(dtype=nxp.int32)
     c = groupby_blockwise(
         a,
         b,
         func=_sum_reduction_func,
         axis=0,
         dtype=nxp.int64,
         num_groups=6,
-        extra_func_kwargs=extra_func_kwargs,
     )
     assert_array_equal(
         c.compute(),
@@ -124,15 +122,13 @@ def test_groupby_blockwise_axis0():
 def test_groupby_blockwise_axis1():
     a = xp.ones((3, 10), dtype=nxp.int32, chunks=(6, 2))
     b = nxp.asarray([0, 0, 0, 1, 1, 2, 2, 4, 4, 4])
-    extra_func_kwargs = dict(dtype=nxp.int32)
     c = groupby_blockwise(
         a,
         b,
         func=_sum_reduction_func,
         axis=1,
         dtype=nxp.int64,
         num_groups=6,
-        extra_func_kwargs=extra_func_kwargs,
     )
     assert_array_equal(
         c.compute(),
@@ -146,7 +142,7 @@ def test_groupby_blockwise_axis1():
     )


-def _sum_reduction_func(arr, by, axis, start_group, num_groups, dtype):
+def _sum_reduction_func(arr, by, axis, start_group, num_groups):
     # change 'by' so it starts from 0 for each chunk
     by = by - start_group
-    return npg.aggregate(by, arr, func="sum", dtype=dtype, axis=axis, size=num_groups)
+    return npg.aggregate(by, arr, func="sum", axis=axis, size=num_groups)

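The test change drops the explicit dtype plumbing: _sum_reduction_func now lets numpy_groupies infer the output dtype rather than receiving one via extra_func_kwargs. To see the start_group adjustment in isolation, here is a self-contained sketch of the per-chunk logic the tests exercise (the sample values are illustrative, not taken from the tests):

    import numpy as np
    import numpy_groupies as npg

    def sum_per_chunk(arr, by, axis, start_group, num_groups):
        # 'by' holds global group labels; shift them to be zero-based for
        # this chunk before aggregating, mirroring _sum_reduction_func.
        by = by - start_group
        return npg.aggregate(by, arr, func="sum", axis=axis, size=num_groups)

    # One chunk whose rows belong to global groups 2 and 3 (so start_group=2):
    arr = np.ones((4, 3), dtype=np.int32)
    by = np.asarray([2, 2, 3, 3])
    print(sum_per_chunk(arr, by, axis=0, start_group=2, num_groups=2))
    # [[2 2 2]
    #  [2 2 2]]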