diff --git a/xarray/plot/plot.py b/xarray/plot/plot.py index 624af0f5c06..f848143367d 100644 --- a/xarray/plot/plot.py +++ b/xarray/plot/plot.py @@ -672,16 +672,22 @@ def newplotfunc( # check if we need to broadcast one dimension if xval.ndim < yval.ndim: + dims = darray[ylab].dims if xval.shape[0] == yval.shape[0]: xval = np.broadcast_to(xval[:, np.newaxis], yval.shape) else: xval = np.broadcast_to(xval[np.newaxis, :], yval.shape) elif yval.ndim < xval.ndim: + dims = darray[xlab].dims if yval.shape[0] == xval.shape[0]: yval = np.broadcast_to(yval[:, np.newaxis], xval.shape) else: yval = np.broadcast_to(yval[np.newaxis, :], xval.shape) + elif xval.ndim == 2: + dims = darray[xlab].dims + else: + dims = (darray[ylab].dims[0], darray[xlab].dims[0]) # May need to transpose for correct x, y labels # xlab may be the name of a coord, we have to check for dim names @@ -691,10 +697,9 @@ def newplotfunc( # we transpose to (y, x, color) to make this work. yx_dims = (ylab, xlab) dims = yx_dims + tuple(d for d in darray.dims if d not in yx_dims) - if dims != darray.dims: - darray = darray.transpose(*dims, transpose_coords=True) - elif darray[xlab].dims[-1] == darray.dims[0]: - darray = darray.transpose(transpose_coords=True) + + if dims != darray.dims: + darray = darray.transpose(*dims, transpose_coords=True) # Pass the data as a masked ndarray too zval = darray.to_masked_array(copy=False) diff --git a/xarray/tests/test_plot.py b/xarray/tests/test_plot.py index ada9eb0feb0..a10f0d9a67e 100644 --- a/xarray/tests/test_plot.py +++ b/xarray/tests/test_plot.py @@ -2168,7 +2168,13 @@ def test_plot_transposed_nondim_coord(plotfunc): getattr(da.plot, plotfunc)(x="zt", y="x") -def test_plot_transposes_properly(): +@requires_matplotlib +@pytest.mark.parametrize("plotfunc", ["pcolormesh", "imshow"]) +def test_plot_transposes_properly(plotfunc): + # test that we aren't mistakenly transposing when the 2 dimensions have equal sizes. da = xr.DataArray([np.sin(2 * np.pi / 10 * np.arange(10))] * 10, dims=("y", "x")) - hdl = da.plot(x="x", y="y") - assert np.all(hdl.get_array() == da.to_masked_array().ravel()) + hdl = getattr(da.plot, plotfunc)(x="x", y="y") + # get_array doesn't work for contour, contourf. It returns the colormap intervals. + # pcolormesh returns 1D array but imshow returns a 2D array so it is necessary + # to ravel() on the LHS + assert np.all(hdl.get_array().ravel() == da.to_masked_array().ravel())