Skip to content

Commit

Permalink
preserve the chunking
Browse files Browse the repository at this point in the history
  • Loading branch information
keewis committed Sep 5, 2024
1 parent b687d19 commit 715a071
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 15 deletions.
15 changes: 14 additions & 1 deletion healpix_convolution/padding.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,21 @@ class ConstantPadding(Padding):
def apply(self, data):
common_dtype = np.result_type(data, self.constant_value)

if hasattr(data, "chunks"):
import dask.array as da

values = da.full_like(
data,
shape=self.insert_indices.shape,
dtype=common_dtype,
fill_value=self.constant_value,
chunks=data.chunks[0][-1],
)
else:
values = self.constant_value

return np.insert(
data.astype(common_dtype), self.insert_indices, self.constant_value, axis=-1
data.astype(common_dtype), self.insert_indices, values, axis=-1
)


Expand Down
16 changes: 2 additions & 14 deletions healpix_convolution/tests/test_padding.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,13 +143,7 @@ class TestArray:
"dask",
(
False,
pytest.param(
True,
marks=[
requires_dask,
pytest.mark.xfail(reason="chunk sizes can't be preserved for now"),
],
),
pytest.param(True, marks=requires_dask),
),
)
def test_pad(self, dask, ring, mode, kwargs, expected_cell_ids, expected_data):
Expand Down Expand Up @@ -301,13 +295,7 @@ class TestXarray:
"dask",
(
False,
pytest.param(
True,
marks=[
requires_dask,
pytest.mark.xfail(reason="chunk sizes can't be preserved for now"),
],
),
pytest.param(True, marks=requires_dask),
),
)
@pytest.mark.parametrize("type_", (xr.Dataset, xr.DataArray))
Expand Down

0 comments on commit 715a071

Please sign in to comment.