Skip to content

Commit

Permalink
Rechunk where dict has missing axes (#546)
Browse files Browse the repository at this point in the history
When rechunking with a dict that doesn't contain all axes, the chunking
of the missing axes should be left unchanged.

In particular, `a.rechunk({})` should be a no-op.

This is consistent with Dask (dask/dask#11261)
and Xarray (pydata/xarray#9286).
  • Loading branch information
tomwhite authored Aug 9, 2024
1 parent 1633431 commit 8c4ae55
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 2 deletions.
10 changes: 10 additions & 0 deletions cubed/core/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -756,6 +756,16 @@ def wrap(*a, block_id=None, **kw):


def rechunk(x, chunks, target_store=None):
if isinstance(chunks, dict):
chunks = {validate_axis(c, x.ndim): v for c, v in chunks.items()}
for i in range(x.ndim):
if i not in chunks:
chunks[i] = x.chunks[i]
elif chunks[i] is None:
chunks[i] = x.chunks[i]
if isinstance(chunks, (tuple, list)):
chunks = tuple(lc if lc is not None else rc for lc, rc in zip(chunks, x.chunks))

normalized_chunks = normalize_chunks(chunks, x.shape, dtype=x.dtype)
if x.chunks == normalized_chunks:
return x
Expand Down
13 changes: 11 additions & 2 deletions cubed/tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,10 +246,19 @@ def test_multiple_ops(spec, executor):
)


@pytest.mark.parametrize("new_chunks", [(1, 2), {0: 1, 1: 2}])
def test_rechunk(spec, executor, new_chunks):
@pytest.mark.parametrize(
("new_chunks", "expected_chunks"),
[
((1, 2), ((1, 1, 1), (2, 1))),
({0: 1, 1: 2}, ((1, 1, 1), (2, 1))),
({1: 2}, ((2, 1), (2, 1))), # dim 0 unchanged
({}, ((2, 1), (1, 1, 1))), # unchanged
],
)
def test_rechunk(spec, executor, new_chunks, expected_chunks):
a = xp.asarray([[1, 2, 3], [4, 5, 6], [7, 8, 9]], chunks=(2, 1), spec=spec)
b = a.rechunk(new_chunks)
assert b.chunks == expected_chunks
assert_array_equal(
b.compute(executor=executor),
np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]),
Expand Down

0 comments on commit 8c4ae55

Please sign in to comment.