fix the xarray API for padding (#39)
* properly apply the padding to `xarray` objects

* decode after applying the padding

* configure `coverage.py`

* also test the `xarray` api for padding

* check that chunksizes can be preserved (will fail for now)

* ignore dask compat code when collecting code coverage info

* remove the dev cli code

* ignore more compat code

* preserve the chunking
keewis authored Sep 5, 2024
1 parent b1c6a18 commit ddcbfe5
Showing 8 changed files with 211 additions and 25 deletions.
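
For orientation, a minimal sketch of the xarray-level padding API this commit fixes and tests, mirroring the setup in the new tests below (the dataset construction and parameter values are illustrative, not prescriptive):

import numpy as np
import xarray as xr
import xdggs

from healpix_convolution.xarray import padding as xr_padding

# small healpix-indexed dataset, mirroring the fixtures in the new tests
grid_info = xdggs.healpix.HealpixInfo(resolution=4, indexing_scheme="nested")
cell_ids = np.array([172, 173])
ds = xr.Dataset(
    {"data": ("cells", np.full_like(cell_ids, fill_value=1))},
    coords={"cell_ids": ("cells", cell_ids, grid_info.to_dict())},
).pipe(xdggs.decode)

# build the padder from the cell ids coordinate, then apply it; the result
# is decoded again so the padded object keeps a working xdggs index
padder = xr_padding.pad(ds["cell_ids"], ring=1, mode="constant", constant_value=0)
padded = padder.apply(ds)
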
2 changes: 1 addition & 1 deletion healpix_convolution/__init__.py
@@ -6,5 +6,5 @@

try:
    __version__ = version("healpix_convolution")
-except Exception:
+except Exception:  # pragma: no cover
    __version__ = "9999"
2 changes: 1 addition & 1 deletion healpix_convolution/distances.py
@@ -5,7 +5,7 @@
    import dask.array as da

    dask_array_type = (da.Array,)
-except ImportError:
+except ImportError:  # pragma: no cover
    da = None
    dask_array_type = ()

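A note on these fallback branches and the `# pragma: no cover` markers added to them: they only run when dask is not installed, so they are excluded from coverage rather than counted as gaps. The empty tuple keeps downstream `isinstance` checks working without dask; a standalone sketch with a hypothetical helper name:

try:
    import dask.array as da

    dask_array_type = (da.Array,)
except ImportError:  # pragma: no cover
    da = None
    dask_array_type = ()


def is_dask_array(x):
    # illustrative helper, not part of the diff: isinstance against an empty
    # tuple is always False, so dask-only code paths are skipped cleanly
    return isinstance(x, dask_array_type)
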
16 changes: 1 addition & 15 deletions healpix_convolution/neighbours.py
@@ -7,7 +7,7 @@
    import dask.array as da

    dask_array_type = (da.Array,)
-except ImportError:
+except ImportError:  # pragma: no cover
    dask_array_type = ()
    da = None

@@ -219,17 +219,3 @@ def neighbours(cell_ids, *, resolution, indexing_scheme, ring=1):
    return _neighbours(
        cell_ids, offsets=offsets, nside=nside, indexing_scheme=indexing_scheme
    )


-if __name__ == "__main__":
-    resolution = 5
-    cell_ids = np.arange(12 * 4**resolution, dtype="int16")
-    indexing_scheme = "nested"
-
-    ring = 1
-
-    n = neighbours(
-        cell_ids, resolution=resolution, indexing_scheme=indexing_scheme, ring=ring
-    )
-    print(n)
-    print(n.shape)
15 changes: 14 additions & 1 deletion healpix_convolution/padding.py
@@ -26,8 +26,21 @@ class ConstantPadding(Padding):
    def apply(self, data):
        common_dtype = np.result_type(data, self.constant_value)

+        if hasattr(data, "chunks"):
+            import dask.array as da
+
+            values = da.full_like(
+                data,
+                shape=self.insert_indices.shape,
+                dtype=common_dtype,
+                fill_value=self.constant_value,
+                chunks=data.chunks[0][-1],
+            )
+        else:
+            values = self.constant_value

        return np.insert(
-            data.astype(common_dtype), self.insert_indices, self.constant_value, axis=-1
+            data.astype(common_dtype), self.insert_indices, values, axis=-1
        )


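The dask branch above builds the fill values as a chunked dask array instead of passing a plain scalar, so the `np.insert` call stays lazy and the chunking survives (the new test below asserts exactly this). A standalone sketch with illustrative values, assuming `np.insert` dispatches to dask arrays via the array-function protocol, which is what the code above relies on:

import dask.array as da
import numpy as np

data = da.ones(4, chunks=(1,), dtype="int64")
insert_indices = np.array([0, 2])

# fill values built with an explicit chunk size taken from the input
values = da.full_like(
    data,
    shape=insert_indices.shape,
    dtype=data.dtype,
    fill_value=0,
    chunks=data.chunks[0][-1],
)

padded = np.insert(data, insert_indices, values, axis=-1)
print(type(padded), padded.chunks)  # stays a chunked dask array
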
2 changes: 1 addition & 1 deletion healpix_convolution/tests/test_neighbours.py
@@ -10,7 +10,7 @@
    import dask.array as da

    has_dask = True
-except ImportError:
+except ImportError:  # pragma: no cover
    has_dask = False
    da = None

182 changes: 180 additions & 2 deletions healpix_convolution/tests/test_padding.py
@@ -1,12 +1,14 @@
import numpy as np
import pytest
import xarray as xr
import xdggs

from healpix_convolution import padding
from healpix_convolution.xarray import padding as xr_padding

try:
import dask.array as da
-except ImportError:
+except ImportError:  # pragma: no cover
da = None
dask_array_type = ()

@@ -19,7 +21,6 @@


class TestArray:
-@pytest.mark.parametrize("dask", (False, pytest.param(True, marks=requires_dask)))
@pytest.mark.parametrize(
["ring", "mode", "kwargs", "expected_cell_ids", "expected_data"],
(
@@ -138,16 +139,25 @@ class TestArray:
),
),
)
@pytest.mark.parametrize(
"dask",
(
False,
pytest.param(True, marks=requires_dask),
),
)
def test_pad(self, dask, ring, mode, kwargs, expected_cell_ids, expected_data):
grid_info = xdggs.healpix.HealpixInfo(resolution=4, indexing_scheme="nested")
cell_ids = np.array([172, 173])

if not dask:
data = np.full_like(cell_ids, fill_value=1)
expected = expected_data
else:
import dask.array as da

data = da.full_like(cell_ids, fill_value=1, chunks=(1,))
expected = da.from_array(expected_data, chunks=(1,))

padder = padding.pad(
cell_ids, grid_info=grid_info, ring=ring, mode=mode, **kwargs
@@ -156,6 +166,174 @@ def test_pad(self, dask, ring, mode, kwargs, expected_cell_ids, expected_data):

if dask:
assert isinstance(actual, da.Array)
assert actual.chunks == expected.chunks, "chunksizes differ"

np.testing.assert_equal(padder.cell_ids, expected_cell_ids)
np.testing.assert_equal(actual, expected_data)


class TestXarray:
@pytest.mark.parametrize(
["ring", "mode", "kwargs", "expected_cell_ids", "expected_data"],
(
pytest.param(
1,
"constant",
{"constant_value": np.nan},
np.array([163, 166, 167, 169, 171, 172, 173, 174, 175, 178, 184, 186]),
np.array(
[
np.nan,
np.nan,
np.nan,
np.nan,
np.nan,
1,
1,
np.nan,
np.nan,
np.nan,
np.nan,
np.nan,
]
),
id="constant-ring1-nan",
),
pytest.param(
2,
"constant",
{"constant_value": 0},
np.array(
[
160,
161,
162,
163,
164,
165,
166,
167,
168,
169,
170,
171,
172,
173,
174,
175,
176,
177,
178,
179,
184,
185,
186,
187,
853,
855,
861,
863,
885,
887,
]
),
np.array(
[
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
1,
1,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
]
),
id="constant-ring2-0",
),
pytest.param(
1,
"mean",
{},
np.array([163, 166, 167, 169, 171, 172, 173, 174, 175, 178, 184, 186]),
np.array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]),
id="mean-ring1",
),
pytest.param(
1,
"minimum",
{},
np.array([163, 166, 167, 169, 171, 172, 173, 174, 175, 178, 184, 186]),
np.array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]),
id="minimum-ring1",
),
),
)
@pytest.mark.parametrize(
"dask",
(
False,
pytest.param(True, marks=requires_dask),
),
)
@pytest.mark.parametrize("type_", (xr.Dataset, xr.DataArray))
def test_pad(
self, dask, ring, mode, kwargs, expected_cell_ids, expected_data, type_
):
grid_info = xdggs.healpix.HealpixInfo(resolution=4, indexing_scheme="nested")
cell_ids = np.array([172, 173])

if not dask:
data_ = np.full_like(cell_ids, fill_value=1)
expected_data_ = expected_data
else:
import dask.array as da

data_ = da.full_like(cell_ids, fill_value=1, chunks=(1,))
expected_data_ = da.from_array(expected_data, chunks=(1,))

ds = xr.Dataset(
{"data": ("cells", data_)},
coords={"cell_ids": ("cells", cell_ids, grid_info.to_dict())},
).pipe(xdggs.decode)
expected_ds = xr.Dataset(
{"data": ("cells", expected_data_)},
coords={"cell_ids": ("cells", expected_cell_ids, grid_info.to_dict())},
).pipe(xdggs.decode)

if type_ is xr.Dataset:
data = ds
expected = expected_ds
else:
data = ds["data"]
expected = expected_ds["data"]

padder = xr_padding.pad(data["cell_ids"], ring=ring, mode=mode, **kwargs)
actual = padder.apply(data)

if dask:
assert actual.chunksizes == expected.chunksizes, "chunksizes differ"

xr.testing.assert_identical(actual, expected)
9 changes: 5 additions & 4 deletions healpix_convolution/xarray/padding.py
@@ -11,20 +11,21 @@
class Padding:
    padding: padding.Padding

+    def _apply_one(self, arr):
+        return xr.DataArray(self.padding.apply(arr.data), dims="cells")
+
    def _apply(self, ds):
        to_drop = [name for name, coord in ds.coords.items() if "cells" in coord.dims]
        cell_ids = xr.Variable(
            "cells", self.padding.cell_ids, attrs=ds["cell_ids"].attrs
        )

        return (
-            ds.drop_vars(to_drop)
-            .map(self.padding.apply)
-            .assign_coords(cell_ids=cell_ids)
+            ds.drop_vars(to_drop).map(self._apply_one).assign_coords(cell_ids=cell_ids)
        )

    def apply(self, obj):
-        return utils.call_on_dataset(self._apply, obj)
+        return utils.call_on_dataset(self._apply, obj).pipe(xdggs.decode)


def pad(
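
Why the `_apply_one` wrapper: `Dataset.map` hands each data variable to the callback as a `DataArray`, while the array-level `Padding.apply` expects and returns a bare numpy/dask array, so the wrapper unwraps `.data` and re-attaches the `"cells"` dimension; the final `.pipe(xdggs.decode)` then rebuilds the xdggs index on the padded cell ids. A toy standalone sketch of the unwrap/rewrap step (names are illustrative, not the library's API):

import numpy as np
import xarray as xr


def pad_array(arr):
    # stand-in for the array-level Padding.apply: operates on raw arrays only
    return np.insert(arr, [0], 0, axis=-1)


def apply_one(arr):
    # Dataset.map passes DataArray objects; unwrap, pad, rewrap with dims
    return xr.DataArray(pad_array(arr.data), dims="cells")


ds = xr.Dataset({"data": ("cells", np.array([1, 1]))})
print(ds.map(apply_one)["data"].values)  # [0 1 1]
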
8 changes: 8 additions & 0 deletions pyproject.toml
@@ -27,6 +27,14 @@ packages = ["healpix_convolution"]
[tool.setuptools_scm]
fallback_version = "9999"

[tool.coverage.run]
source = ["healpix_convolution"]
branch = true

[tool.coverage.report]
show_missing = true
exclude_lines = ["pragma: no cover", "if TYPE_CHECKING"]

[tool.ruff]
builtins = ["ellipsis"]
exclude = [
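
With this configuration in place, the `# pragma: no cover` markers added to the dask/version fallback branches above are excluded from the report, so running the suite under coverage.py (for example `coverage run -m pytest` followed by `coverage report`, with branch coverage enabled as configured here) only flags genuinely untested code.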