From 8c4ae55b1e9221edb130c887fbd11386f7e67f5f Mon Sep 17 00:00:00 2001 From: Tom White Date: Fri, 9 Aug 2024 11:42:34 +0100 Subject: [PATCH] Rechunk where dict has missing axes (#546) When rechunking with a dict that doesn't contain all axes, then the chunking should be unchanged for those axes that are missing. In particular, `a.rechunk({})` should be a no-op. This is consistent with Dask (https://github.com/dask/dask/issues/11261) and Xarray (https://github.com/pydata/xarray/pull/9286) --- cubed/core/ops.py | 10 ++++++++++ cubed/tests/test_core.py | 13 +++++++++++-- 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/cubed/core/ops.py b/cubed/core/ops.py index e6d1ebc9..a4850596 100644 --- a/cubed/core/ops.py +++ b/cubed/core/ops.py @@ -756,6 +756,16 @@ def wrap(*a, block_id=None, **kw): def rechunk(x, chunks, target_store=None): + if isinstance(chunks, dict): + chunks = {validate_axis(c, x.ndim): v for c, v in chunks.items()} + for i in range(x.ndim): + if i not in chunks: + chunks[i] = x.chunks[i] + elif chunks[i] is None: + chunks[i] = x.chunks[i] + if isinstance(chunks, (tuple, list)): + chunks = tuple(lc if lc is not None else rc for lc, rc in zip(chunks, x.chunks)) + normalized_chunks = normalize_chunks(chunks, x.shape, dtype=x.dtype) if x.chunks == normalized_chunks: return x diff --git a/cubed/tests/test_core.py b/cubed/tests/test_core.py index f40c7a0f..fdf90ca4 100644 --- a/cubed/tests/test_core.py +++ b/cubed/tests/test_core.py @@ -246,10 +246,19 @@ def test_multiple_ops(spec, executor): ) -@pytest.mark.parametrize("new_chunks", [(1, 2), {0: 1, 1: 2}]) -def test_rechunk(spec, executor, new_chunks): +@pytest.mark.parametrize( + ("new_chunks", "expected_chunks"), + [ + ((1, 2), ((1, 1, 1), (2, 1))), + ({0: 1, 1: 2}, ((1, 1, 1), (2, 1))), + ({1: 2}, ((2, 1), (2, 1))), # dim 0 unchanged + ({}, ((2, 1), (1, 1, 1))), # unchanged + ], +) +def test_rechunk(spec, executor, new_chunks, expected_chunks): a = xp.asarray([[1, 2, 3], [4, 5, 6], [7, 8, 9]], chunks=(2, 1), spec=spec) b = a.rechunk(new_chunks) + assert b.chunks == expected_chunks assert_array_equal( b.compute(executor=executor), np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]),