Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow ellipsis (...) in transpose #3421

Merged
merged 14 commits into from
Oct 28, 2019
4 changes: 3 additions & 1 deletion doc/reshaping.rst
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,14 @@ Reordering dimensions
---------------------

To reorder dimensions on a :py:class:`~xarray.DataArray` or across all variables
on a :py:class:`~xarray.Dataset`, use :py:meth:`~xarray.DataArray.transpose`:
on a :py:class:`~xarray.Dataset`, use :py:meth:`~xarray.DataArray.transpose`. An
ellipsis (`...`) can be used to represent all other dimensions:

.. ipython:: python

ds = xr.Dataset({'foo': (('x', 'y', 'z'), [[[42]]]), 'bar': (('y', 'z'), [[24]])})
ds.transpose('y', 'z', 'x')
ds.transpose(..., 'x') # equivalent
ds.transpose() # reverses all dimensions

Expand and squeeze dimensions
Expand Down
4 changes: 4 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@ Breaking changes

New Features
~~~~~~~~~~~~
- :py:meth:`Dataset.transpose` and :py:meth:`DataArray.transpose` now support an ellipsis (`...`)
to represent all 'other' dimensions. For example, to move one dimension to the front,
use `.transpose('x', ...)`. (:pull:`3421`)
By `Maximilian Roos <https://github.com/max-sixty>`_
- Changed `xr.ALL_DIMS` to equal python's `Ellipsis` (`...`), and changed internal usages to use
`...` directly. As before, you can use this to instruct a `groupby` operation
to reduce over all dimensions. While we have no plans to remove `xr.ALL_DIMS`, we suggest
Expand Down
5 changes: 4 additions & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -117,4 +117,7 @@ tag_prefix = v
parentdir_prefix = xarray-

[aliases]
test = pytest
test = pytest

[pytest-watch]
nobeep = True
7 changes: 1 addition & 6 deletions xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1863,12 +1863,7 @@ def transpose(self, *dims: Hashable, transpose_coords: bool = None) -> "DataArra
Dataset.transpose
"""
if dims:
if set(dims) ^ set(self.dims):
raise ValueError(
"arguments to transpose (%s) must be "
"permuted array dimensions (%s)" % (dims, tuple(self.dims))
)

dims = tuple(utils.infix_dims(dims, self.dims))
variable = self.variable.transpose(*dims)
if transpose_coords:
coords: Dict[Hashable, Variable] = {}
Expand Down
4 changes: 2 additions & 2 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3712,14 +3712,14 @@ def transpose(self, *dims: Hashable) -> "Dataset":
DataArray.transpose
"""
if dims:
if set(dims) ^ set(self.dims):
if set(dims) ^ set(self.dims) and ... not in dims:
raise ValueError(
"arguments to transpose (%s) must be "
"permuted dataset dimensions (%s)" % (dims, tuple(self.dims))
)
ds = self.copy()
for name, var in self._variables.items():
var_dims = tuple(dim for dim in dims if dim in var.dims)
var_dims = tuple(dim for dim in dims if dim in (var.dims + (...,)))
ds._variables[name] = var.transpose(*var_dims)
return ds

Expand Down
25 changes: 25 additions & 0 deletions xarray/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
AbstractSet,
Any,
Callable,
Collection,
Container,
Dict,
Hashable,
Expand Down Expand Up @@ -660,6 +661,30 @@ def __len__(self) -> int:
return len(self._data) - num_hidden


def infix_dims(dims_supplied: Collection, dims_all: Collection) -> Iterator:
    """Resolve a dimension list that may contain an ellipsis into a full ordering.

    A single ``...`` in ``dims_supplied`` stands for every item of ``dims_all``
    that is not named explicitly, in the order those items appear in
    ``dims_all``.

    Parameters
    ----------
    dims_supplied : Collection
        The requested ordering; may contain at most one ``...``.
    dims_all : Collection
        Every available item.

    Yields
    ------
    The items of the 'realized' ordering, one at a time.

    Raises
    ------
    ValueError
        If ``...`` is used while ``dims_all`` has repeated items, if more than
        one ``...`` is supplied, if ``dims_supplied`` names an item that is not
        in ``dims_all``, or (without ``...``) if ``dims_supplied`` is not a
        permutation of ``dims_all``.

    Notes
    -----
    This is a generator, so validation errors are only raised once the result
    is iterated (e.g. by wrapping the call in ``tuple(...)``).
    """
    if ... in dims_supplied:
        if len(set(dims_all)) != len(dims_all):
            raise ValueError("Cannot use ellipsis with repeated dims")
        if sum(d is ... for d in dims_supplied) > 1:
            raise ValueError("More than one ellipsis supplied")
        # Without this check an unknown name would be passed through silently,
        # whereas the no-ellipsis branch below rejects it.
        unknown = [d for d in dims_supplied if d is not ... and d not in dims_all]
        if unknown:
            raise ValueError(
                f"{dims_supplied} contains dimensions {unknown} "
                f"that are not in {dims_all}"
            )
        other_dims = [d for d in dims_all if d not in dims_supplied]
        for d in dims_supplied:
            if d is ...:
                yield from other_dims
            else:
                yield d
    else:
        if set(dims_supplied) ^ set(dims_all):
            raise ValueError(
                f"{dims_supplied} must be a permuted list of {dims_all}, unless `...` is included"
            )
        yield from dims_supplied


def get_temp_dimname(dims: Container[Hashable], new_dim: Hashable) -> Hashable:
""" Get an new dimension name based on new_dim, that is not used in dims.
If the same name exists, we add an underscore(s) in the head.
Expand Down
2 changes: 2 additions & 0 deletions xarray/core/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
OrderedSet,
decode_numpy_dict_values,
either_dict_or_kwargs,
infix_dims,
ensure_us_time_resolution,
)

Expand Down Expand Up @@ -1228,6 +1229,7 @@ def transpose(self, *dims) -> "Variable":
"""
if len(dims) == 0:
dims = self.dims[::-1]
dims = tuple(infix_dims(dims, self.dims))
axes = self.get_axis_num(dims)
if len(dims) < 2: # no need to transpose if only one dimension
return self.copy(deep=False)
Expand Down
3 changes: 3 additions & 0 deletions xarray/tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,18 +158,21 @@ def source_ndarray(array):


def assert_equal(a, b):
    """Run ``xarray.testing.assert_equal`` on ``a`` and ``b``, then check the
    internal invariants of each object."""
    # pytest omits frames that define this local from failure tracebacks.
    __tracebackhide__ = True
    xarray.testing.assert_equal(a, b)
    for obj in (a, b):
        xarray.testing._assert_internal_invariants(obj)


def assert_identical(a, b):
    """Run ``xarray.testing.assert_identical`` on ``a`` and ``b``, then check
    the internal invariants of each object."""
    # pytest omits frames that define this local from failure tracebacks.
    __tracebackhide__ = True
    xarray.testing.assert_identical(a, b)
    for obj in (a, b):
        xarray.testing._assert_internal_invariants(obj)


def assert_allclose(a, b, **kwargs):
    """Run ``xarray.testing.assert_allclose`` on ``a`` and ``b`` (forwarding
    ``kwargs``), then check the internal invariants of each object."""
    # pytest omits frames that define this local from failure tracebacks.
    __tracebackhide__ = True
    xarray.testing.assert_allclose(a, b, **kwargs)
    for obj in (a, b):
        xarray.testing._assert_internal_invariants(obj)
4 changes: 4 additions & 0 deletions xarray/tests/test_dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -2068,6 +2068,10 @@ def test_transpose(self):
)
assert_equal(expected, actual)

# same as previous but with ellipsis
actual = da.transpose("z", ..., "x", transpose_coords=True)
assert_equal(expected, actual)

with pytest.raises(ValueError):
da.transpose("x", "y")

Expand Down
27 changes: 25 additions & 2 deletions xarray/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4675,6 +4675,10 @@ def test_dataset_transpose(self):
)
assert_identical(expected, actual)

actual = ds.transpose(...)
expected = ds
assert_identical(expected, actual)

actual = ds.transpose("x", "y")
expected = ds.apply(lambda x: x.transpose("x", "y", transpose_coords=True))
assert_identical(expected, actual)
Expand All @@ -4690,13 +4694,32 @@ def test_dataset_transpose(self):
expected_dims = tuple(d for d in new_order if d in ds[k].dims)
assert actual[k].dims == expected_dims

with raises_regex(ValueError, "arguments to transpose"):
# same as above but with ellipsis
new_order = ("dim2", "dim3", "dim1", "time")
actual = ds.transpose("dim2", "dim3", ...)
for k in ds.variables:
expected_dims = tuple(d for d in new_order if d in ds[k].dims)
assert actual[k].dims == expected_dims

with raises_regex(ValueError, "permuted"):
ds.transpose("dim1", "dim2", "dim3")
with raises_regex(ValueError, "arguments to transpose"):
with raises_regex(ValueError, "permuted"):
ds.transpose("dim1", "dim2", "dim3", "time", "extra_dim")

assert "T" not in dir(ds)

def test_dataset_ellipsis_transpose_different_ordered_vars(self):
    """Transposing with a leading ellipsis keeps each variable's own order
    for the dimensions that the ellipsis stands for."""
    # https://github.com/pydata/xarray/issues/1081#issuecomment-544350457
    data_vars = {
        "a": (("w", "x", "y", "z"), np.ones((2, 3, 4, 5))),
        "b": (("x", "w", "y", "z"), np.zeros((3, 2, 4, 5))),
    }
    transposed = Dataset(data_vars).transpose(..., "z", "y")
    assert list(transposed["a"].dims) == ["w", "x", "z", "y"]
    assert list(transposed["b"].dims) == ["x", "w", "z", "y"]

def test_dataset_retains_period_index_on_transpose(self):

ds = create_test_data()
Expand Down
24 changes: 24 additions & 0 deletions xarray/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,3 +275,27 @@ def test_either_dict_or_kwargs():

with pytest.raises(ValueError, match=r"foo"):
result = either_dict_or_kwargs(dict(a=1), dict(a=1), "foo")


@pytest.mark.parametrize(
    ["supplied", "all_", "expected"],
    [
        # no ellipsis: the supplied order is returned unchanged
        (["a", "b", "c"], ["a", "b", "c"], ["a", "b", "c"]),
        # ellipsis in the middle, at the end, at the start, and alone
        (["a", ..., "c"], ["a", "b", "c"], ["a", "b", "c"]),
        (["a", ...], ["a", "b", "c"], ["a", "b", "c"]),
        (["c", ...], ["a", "b", "c"], ["c", "a", "b"]),
        ([..., "b"], ["a", "b", "c"], ["a", "c", "b"]),
        ([...], ["a", "b", "c"], ["a", "b", "c"]),
    ],
)
def test_infix_dims(supplied, all_, expected):
    """``utils.infix_dims`` expands ``...`` into the dims not named explicitly."""
    assert list(utils.infix_dims(supplied, all_)) == expected


@pytest.mark.parametrize(
    ["supplied", "all_"],
    [
        # more than one ellipsis
        ([..., ...], ["a", "b", "c"]),
        # ellipsis combined with repeated dims
        ([...], ["a", "a", "c"]),
    ],
)
def test_infix_dims_errors(supplied, all_):
    """Invalid ellipsis usage raises ``ValueError`` once the result is consumed."""
    resolved = utils.infix_dims(supplied, all_)
    with pytest.raises(ValueError):
        list(resolved)
3 changes: 3 additions & 0 deletions xarray/tests/test_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -1280,6 +1280,9 @@ def test_transpose(self):
w2 = Variable(["d", "b", "c", "a"], np.einsum("abcd->dbca", x))
assert w2.shape == (5, 3, 4, 2)
assert_identical(w2, w.transpose("d", "b", "c", "a"))
assert_identical(w2, w.transpose("d", ..., "a"))
assert_identical(w2, w.transpose("d", "b", "c", ...))
assert_identical(w2, w.transpose(..., "b", "c", "a"))
assert_identical(w, w2.transpose("a", "b", "c", "d"))
w3 = Variable(["b", "c", "d", "a"], np.einsum("abcd->bcda", x))
assert_identical(w, w3.transpose("a", "b", "c", "d"))
Expand Down