Skip to content
forked from pydata/xarray

Commit

Permalink
make plotting work with transposed nondim coords.
Browse files Browse the repository at this point in the history
  • Loading branch information
dcherian committed Nov 2, 2019
1 parent 53c5199 commit 7aed950
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 4 deletions.
3 changes: 3 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,9 @@ Bug fixes
:py:meth:`xarray.core.groupby.DatasetGroupBy.reduce` when reducing over multiple dimensions.
(:issue:`3402`). By `Deepak Cherian <https://github.com/dcherian/>`_

- Fix plotting with transposed 2D non-dimensional coordinates. (:issue:`3138`)
By `Deepak Cherian <https://github.com/dcherian>`_.

Documentation
~~~~~~~~~~~~~
- Fix leap year condition in example (http://xarray.pydata.org/en/stable/examples/monthly-means.html) by `Mickaël Lalande <https://github.com/mickaellalande>`_.
Expand Down
14 changes: 10 additions & 4 deletions xarray/plot/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -672,10 +672,16 @@ def newplotfunc(

# check if we need to broadcast one dimension
if xval.ndim < yval.ndim:
xval = np.broadcast_to(xval, yval.shape)
if xval.shape[0] == yval.shape[0]:
xval = np.broadcast_to(xval[:, np.newaxis], yval.shape)
else:
xval = np.broadcast_to(xval[np.newaxis, :], yval.shape)

if yval.ndim < xval.ndim:
yval = np.broadcast_to(yval, xval.shape)
elif yval.ndim < xval.ndim:
if yval.shape[0] == xval.shape[0]:
yval = np.broadcast_to(yval[:, np.newaxis], xval.shape)
else:
yval = np.broadcast_to(yval[np.newaxis, :], xval.shape)

# May need to transpose for correct x, y labels
# xlab may be the name of a coord, we have to check for dim names
Expand All @@ -687,7 +693,7 @@ def newplotfunc(
dims = yx_dims + tuple(d for d in darray.dims if d not in yx_dims)
if dims != darray.dims:
darray = darray.transpose(*dims, transpose_coords=True)
elif darray[xlab].dims[-1] == darray.dims[0]:
elif xval.shape[-1] == darray.shape[0]:
darray = darray.transpose(transpose_coords=True)

# Pass the data as a masked ndarray too
Expand Down
16 changes: 16 additions & 0 deletions xarray/tests/test_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -2145,3 +2145,19 @@ def test_yticks_kwarg(self, da):
da.plot(yticks=np.arange(5))
expected = np.arange(5)
assert np.all(plt.gca().get_yticks() == expected)


@requires_matplotlib
@pytest.mark.parametrize("plotfunc", ["pcolormesh", "contourf", "contour"])
def test_plot_transposed_nondim_coord(plotfunc):
x = np.linspace(0, 10, 101)
h = np.linspace(3, 7, 101)
s = np.linspace(0, 1, 51)
z = s[:, np.newaxis] * h[np.newaxis, :]
da = xr.DataArray(
np.sin(x) * np.cos(z),
dims=["s", "x"],
coords={"x": x, "s": s, "z": (("s", "x"), z), "zt": (("x", "s"), z.T)},
)
getattr(da.plot, plotfunc)(x="x", y="zt")
getattr(da.plot, plotfunc)(x="zt", y="x")

0 comments on commit 7aed950

Please sign in to comment.