diff --git a/lib/cartopy/mpl/geoaxes.py b/lib/cartopy/mpl/geoaxes.py index c1d84bf7c..26cc28d5e 100644 --- a/lib/cartopy/mpl/geoaxes.py +++ b/lib/cartopy/mpl/geoaxes.py @@ -45,7 +45,8 @@ from cartopy.mpl.slippy_image_artist import SlippyImageArtist -assert packaging.version.parse(mpl.__version__).release[:2] >= (3, 4), \ +_MPL_VERSION = packaging.version.parse(mpl.__version__) +assert _MPL_VERSION.release >= (3, 4), \ 'Cartopy is only supported with Matplotlib 3.4 or greater.' # A nested mapping from path, source CRS, and target projection to the @@ -1796,7 +1797,7 @@ def _wrap_args(self, *args, **kwargs): kwargs['shading'] = 'flat' X = np.asanyarray(args[0]) Y = np.asanyarray(args[1]) - nrows, ncols = np.asanyarray(args[2]).shape + nrows, ncols = np.asanyarray(args[2]).shape[:2] Nx = X.shape[-1] Ny = Y.shape[0] if X.ndim != 2 or X.shape[0] == 1: @@ -1843,12 +1844,13 @@ def _wrap_quadmesh(self, collection, **kwargs): Ny, Nx, _ = coords.shape if kwargs.get('shading') == 'gouraud': # Gouraud shading has the same shape for coords and data - data_shape = Ny, Nx + data_shape = Ny, Nx, -1 else: - data_shape = Ny - 1, Nx - 1 + data_shape = Ny - 1, Nx - 1, -1 # data array C = collection.get_array().reshape(data_shape) - + if C.shape[-1] == 1: + C = C.squeeze(axis=-1) transformed_pts = self.projection.transform_points( t, coords[..., 0], coords[..., 1]) @@ -1921,13 +1923,12 @@ def _wrap_quadmesh(self, collection, **kwargs): "map it must be fully transparent.", stacklevel=3) - # The original data mask (regardless of wrapped cells) - C_mask = getattr(C, 'mask', None) + # Get hold of masked versions of the array to be passed to set_array + # methods of QuadMesh and PolyQuadMesh + pcolormesh_data, pcolor_data, pcolor_mask = \ + cartopy.mpl.geocollection._split_wrapped_mesh_data(C, mask) - # create the masked array to be used with this pcolormesh - full_mask = mask if C_mask is None else mask | C_mask - pcolormesh_data = np.ma.array(C, mask=full_mask) - collection.set_array(pcolormesh_data.ravel()) + collection.set_array(pcolormesh_data) # plot with slightly lower zorder to avoid odd issue # where the main plot is obscured @@ -1943,25 +1944,32 @@ def _wrap_quadmesh(self, collection, **kwargs): # `pcolor` only draws polygons where the data is not # masked, so this will only draw a limited subset of # polygons that were actually wrapped. - # We will add the original data mask in later to - # make sure that set_array can work in future - # calls on the proper sized array inputs. - # NOTE: we don't use C.data here because C.data could - # contain nan's which would be masked in the - # pcolor routines, which we don't want. We will - # fill in the proper data later with set_array() - # calls. - pcolor_data = np.ma.array(np.zeros(C.shape), - mask=~mask) - pcolor_col = self.pcolor(coords[..., 0], coords[..., 1], - pcolor_data, zorder=zorder, - **kwargs) - # Now add back in the masked data if there was any - full_mask = ~mask if C_mask is None else ~mask | C_mask - pcolor_data = np.ma.array(C, mask=full_mask) - # The pcolor_col is now possibly shorter than the - # actual collection, so grab the masked cells - pcolor_col.set_array(pcolor_data[mask].ravel()) + + if _MPL_VERSION.release[:2] < (3, 8): + # We will add the original data mask in later to + # make sure that set_array can work in future + # calls on the proper sized array inputs. + # NOTE: we don't use C.data here because C.data could + # contain nan's which would be masked in the + # pcolor routines, which we don't want. We will + # fill in the proper data later with set_array() + # calls. + pcolor_zeros = np.ma.array(np.zeros(C.shape), mask=pcolor_mask) + pcolor_col = self.pcolor(coords[..., 0], coords[..., 1], + pcolor_zeros, zorder=zorder, + **kwargs) + + # The pcolor_col is now possibly shorter than the + # actual collection, so grab the masked cells + pcolor_col.set_array(pcolor_data[mask].ravel()) + else: + pcolor_col = self.pcolor(coords[..., 0], coords[..., 1], + pcolor_data, zorder=zorder, + **kwargs) + # Currently pcolor_col.get_array() will return a compressed array + # and warn unless we explicitly set the 2D array. This should be + # unnecessary with future matplotlib versions. + pcolor_col.set_array(pcolor_data) pcolor_col.set_cmap(cmap) pcolor_col.set_norm(norm) @@ -1972,7 +1980,7 @@ def _wrap_quadmesh(self, collection, **kwargs): # put the pcolor_col and mask on the pcolormesh # collection so that users can do things post # this method - collection._wrapped_mask = mask.ravel() + collection._wrapped_mask = mask collection._wrapped_collection_fix = pcolor_col return collection diff --git a/lib/cartopy/mpl/geocollection.py b/lib/cartopy/mpl/geocollection.py index 6680cf725..9f653ab10 100644 --- a/lib/cartopy/mpl/geocollection.py +++ b/lib/cartopy/mpl/geocollection.py @@ -3,8 +3,43 @@ # This file is part of Cartopy and is released under the LGPL license. # See COPYING and COPYING.LESSER in the root of the repository for full # licensing details. +import matplotlib as mpl from matplotlib.collections import QuadMesh import numpy as np +import numpy.ma as ma +import packaging + + +_MPL_VERSION = packaging.version.parse(mpl.__version__) + + +def _split_wrapped_mesh_data(C, mask): + """ + Helper function for splitting GeoQuadMesh array values between the + pcolormesh and pcolor objects when wrapping. Apply a mask to the grid + cells that should not be plotted with each method. + + """ + # The original data mask (regardless of wrapped cells) + C_mask = getattr(C, 'mask', None) + if C.ndim == 3: + # RGB(A) array. + if _MPL_VERSION.release < (3, 8): + raise ValueError("GeoQuadMesh wrapping for RGB(A) requires " + "Matplotlib v3.8 or later") + + # mask will need an extra trailing dimension + mask = np.broadcast_to(mask[..., np.newaxis], C.shape) + + # create the masked array to be used with pcolormesh + full_mask = mask if C_mask is None else mask | C_mask + pcolormesh_data = ma.array(C, mask=full_mask) + + # create the masked array to be used with pcolor + full_mask = ~mask if C_mask is None else ~mask | C_mask + pcolor_data = ma.array(C, mask=full_mask) + + return pcolormesh_data, pcolor_data, ~mask class GeoQuadMesh(QuadMesh): @@ -21,25 +56,57 @@ def get_array(self): A = super().get_array().copy() # If the input array has a mask, retrieve the associated data if hasattr(self, '_wrapped_mask'): - A[self._wrapped_mask] = self._wrapped_collection_fix.get_array() + pcolor_data = self._wrapped_collection_fix.get_array() + mask = self._wrapped_mask + if _MPL_VERSION.release[:2] < (3, 8): + A[mask] = pcolor_data + else: + if A.ndim == 3: # RGB(A) data. Need to broadcast mask. + mask = mask[:, :, np.newaxis] + # np.copyto is not implemented for masked arrays so handle the + # mask explicitly + np.copyto(A.mask, pcolor_data.mask, where=mask) + np.copyto(A, pcolor_data, where=mask) + return A def set_array(self, A): - # raise right away if A is 2-dimensional. - if A.ndim > 1: - raise ValueError('Collections can only map rank 1 arrays. ' - 'You likely want to call with a flattened array ' - 'using collection.set_array(A.ravel()) instead.') + # Check the shape is appropriate up front. + if _MPL_VERSION.release[:2] < (3, 8): + # Need to figure out existing shape from the coordinates. + height, width = self._coordinates.shape[0:-1] + if self._shading == 'flat': + h, w = height - 1, width - 1 + else: + h, w = height, width + else: + h, w = super().get_array().shape[:2] + + ok_shapes = [(h, w, 3), (h, w, 4), (h, w), (h * w,)] + if A.shape not in ok_shapes: + ok_shape_str = ' or '.join(map(str, ok_shapes)) + raise ValueError( + f"A should have shape {ok_shape_str}, not {A.shape}") + + if A.ndim == 1: + # Always use array with at least two dimensions. This is + # inconsistent with QuadMesh which stores whatever you give it, but + # for the wrapped case we need to match the 2D mask. Storing the + # 2D array also allows us to calculate ok_shapes on subsequent + # calls without using the private QuadMesh._shading attribute. + A = A.reshape((h, w)) # Only use the mask attribute if it is there. if hasattr(self, '_wrapped_mask'): + # Update the pcolor data with the wrapped masked data - self._wrapped_collection_fix.set_array(A[self._wrapped_mask]) - # If the input array was a masked array, keep that data masked - if hasattr(A, 'mask'): - A = np.ma.array(A, mask=self._wrapped_mask | A.mask) + A, pcolor_data, _ = _split_wrapped_mesh_data(A, self._wrapped_mask) + + if _MPL_VERSION.release[:2] < (3, 8): + self._wrapped_collection_fix.set_array( + pcolor_data[self._wrapped_mask].ravel()) else: - A = np.ma.array(A, mask=self._wrapped_mask) + self._wrapped_collection_fix.set_array(pcolor_data) # Now that we have prepared the collection data, call on # through to the underlying implementation. diff --git a/lib/cartopy/tests/mpl/test_mpl_integration.py b/lib/cartopy/tests/mpl/test_mpl_integration.py index c2d737944..8a60b6d96 100644 --- a/lib/cartopy/tests/mpl/test_mpl_integration.py +++ b/lib/cartopy/tests/mpl/test_mpl_integration.py @@ -6,6 +6,7 @@ import re +import matplotlib.colors as mcolors import matplotlib.pyplot as plt import numpy as np import pytest @@ -240,10 +241,46 @@ def test_cursor_values(): r.encode('ascii', 'ignore')) +SKIP_PRE_MPL38 = pytest.mark.skipif( + MPL_VERSION.release[:2] < (3, 8), reason='mpl < 3.8') +PARAMETRIZE_PCOLORMESH_WRAP = pytest.mark.parametrize( + 'mesh_data_kind', + [ + 'standard', + pytest.param('rgb', marks=SKIP_PRE_MPL38), + pytest.param('rgba', marks=SKIP_PRE_MPL38), + ], + ids=['standard', 'rgb', 'rgba'], +) + + +def _to_rgb(data, mesh_data_kind): + """ + Helper function to convert array to RGB(A) where required + """ + if mesh_data_kind in ('rgb', 'rgba'): + cmap = plt.get_cmap() + norm = mcolors.Normalize() + new_data = cmap(norm(data)) + if mesh_data_kind == 'rgb': + new_data = new_data[..., 0:3] + if np.ma.is_masked(data): + # Use data's mask as an alpha channel + mask = np.ma.getmaskarray(data) + mask = np.broadcast_to( + mask[..., np.newaxis], new_data.shape).copy() + new_data = np.ma.array(new_data, mask=mask) + + return new_data + + return data + + +@PARAMETRIZE_PCOLORMESH_WRAP @pytest.mark.natural_earth @pytest.mark.mpl_image_compare(filename='pcolormesh_global_wrap1.png', tolerance=1.27) -def test_pcolormesh_global_with_wrap1(): +def test_pcolormesh_global_with_wrap1(mesh_data_kind): # make up some realistic data with bounds (such as data from the UM) nx, ny = 36, 18 xbnds = np.linspace(0, 360, nx, endpoint=True) @@ -254,6 +291,8 @@ def test_pcolormesh_global_with_wrap1(): data = data[:-1, :-1] fig = plt.figure() + data = _to_rgb(data, mesh_data_kind) + ax = fig.add_subplot(2, 1, 1, projection=ccrs.PlateCarree()) ax.pcolormesh(xbnds, ybnds, data, transform=ccrs.PlateCarree(), snap=False) ax.coastlines() @@ -267,7 +306,8 @@ def test_pcolormesh_global_with_wrap1(): return fig -def test_pcolormesh_get_array_with_mask(): +@PARAMETRIZE_PCOLORMESH_WRAP +def test_pcolormesh_get_array_with_mask(mesh_data_kind): # make up some realistic data with bounds (such as data from the UM) nx, ny = 36, 18 xbnds = np.linspace(0, 360, nx, endpoint=True) @@ -275,7 +315,10 @@ def test_pcolormesh_get_array_with_mask(): x, y = np.meshgrid(xbnds, ybnds) data = np.exp(np.sin(np.deg2rad(x)) + np.cos(np.deg2rad(y))) + data[5, :] = np.nan # Check that missing data is handled - GH#2208 data = data[:-1, :-1] + data = _to_rgb(data, mesh_data_kind) + fig = plt.figure() ax = fig.add_subplot(2, 1, 1, projection=ccrs.PlateCarree()) @@ -284,9 +327,10 @@ def test_pcolormesh_get_array_with_mask(): 'No pcolormesh wrapping was done when it should have been.' result = c.get_array() - assert not np.ma.is_masked(result) - assert np.array_equal(data.ravel(), result), \ - 'Data supplied does not match data retrieved in wrapped case' + np.testing.assert_array_equal(np.ma.getmask(result), np.isnan(data)) + np.testing.assert_array_equal( + data, result, + err_msg='Data supplied does not match data retrieved in wrapped case') ax.coastlines() ax.set_global() # make sure everything is visible @@ -298,7 +342,9 @@ def test_pcolormesh_get_array_with_mask(): x, y = np.meshgrid(xbnds, ybnds) data = np.exp(np.sin(np.deg2rad(x)) + np.cos(np.deg2rad(y))) + data[5, :] = np.nan data2 = data[:-1, :-1] + data2 = _to_rgb(data2, mesh_data_kind) ax = fig.add_subplot(2, 1, 2, projection=ccrs.PlateCarree()) c = ax.pcolormesh(xbnds, ybnds, data2, transform=ccrs.PlateCarree()) @@ -309,15 +355,22 @@ def test_pcolormesh_get_array_with_mask(): 'pcolormesh wrapping was done when it should not have been.' result = c.get_array() - assert not np.ma.is_masked(result) - assert np.array_equal(data2.ravel(), result), \ - 'Data supplied does not match data retrieved in unwrapped case' + + expected = data2 + if MPL_VERSION.release[:2] < (3, 8): + expected = expected.ravel() + + np.testing.assert_array_equal(np.ma.getmask(result), np.isnan(expected)) + np.testing.assert_array_equal( + expected, result, + 'Data supplied does not match data retrieved in unwrapped case') +@PARAMETRIZE_PCOLORMESH_WRAP @pytest.mark.natural_earth @pytest.mark.mpl_image_compare(filename='pcolormesh_global_wrap2.png', tolerance=1.87) -def test_pcolormesh_global_with_wrap2(): +def test_pcolormesh_global_with_wrap2(mesh_data_kind): # make up some realistic data with bounds (such as data from the UM) nx, ny = 36, 18 xbnds, xstep = np.linspace(0, 360, nx - 1, retstep=True, endpoint=True) @@ -332,6 +385,8 @@ def test_pcolormesh_global_with_wrap2(): data = data[:-1, :-1] fig = plt.figure() + data = _to_rgb(data, mesh_data_kind) + ax = fig.add_subplot(2, 1, 1, projection=ccrs.PlateCarree()) ax.pcolormesh(xbnds, ybnds, data, transform=ccrs.PlateCarree(), snap=False) ax.coastlines() @@ -345,10 +400,11 @@ def test_pcolormesh_global_with_wrap2(): return fig +@PARAMETRIZE_PCOLORMESH_WRAP @pytest.mark.natural_earth @pytest.mark.mpl_image_compare(filename='pcolormesh_global_wrap3.png', tolerance=1.42) -def test_pcolormesh_global_with_wrap3(): +def test_pcolormesh_global_with_wrap3(mesh_data_kind): nx, ny = 33, 17 xbnds = np.linspace(-1.875, 358.125, nx, endpoint=True) ybnds = np.linspace(91.25, -91.25, ny, endpoint=True) @@ -366,6 +422,8 @@ def test_pcolormesh_global_with_wrap3(): data = np.ma.masked_greater(data, 2.6) fig = plt.figure() + data = _to_rgb(data, mesh_data_kind) + ax = fig.add_subplot(3, 1, 1, projection=ccrs.PlateCarree(-45)) c = ax.pcolormesh(xbnds, ybnds, data, transform=ccrs.PlateCarree(), snap=False) @@ -388,10 +446,11 @@ def test_pcolormesh_global_with_wrap3(): return fig +@PARAMETRIZE_PCOLORMESH_WRAP @pytest.mark.natural_earth @pytest.mark.mpl_image_compare(filename='pcolormesh_global_wrap3.png', tolerance=1.42) -def test_pcolormesh_set_array_with_mask(): +def test_pcolormesh_set_array_with_mask(mesh_data_kind): """Testing that set_array works with masked arrays properly.""" nx, ny = 33, 17 xbnds = np.linspace(-1.875, 358.125, nx, endpoint=True) @@ -414,10 +473,15 @@ def test_pcolormesh_set_array_with_mask(): bad_data_mask = np.ma.array(bad_data, mask=~data.mask) fig = plt.figure() + data = _to_rgb(data, mesh_data_kind) + bad_data = _to_rgb(bad_data, mesh_data_kind) + bad_data_mask = _to_rgb(bad_data_mask, mesh_data_kind) + ax = fig.add_subplot(3, 1, 1, projection=ccrs.PlateCarree(-45)) c = ax.pcolormesh(xbnds, ybnds, bad_data, norm=norm, transform=ccrs.PlateCarree(), snap=False) - c.set_array(data.ravel()) + + c.set_array(data) assert c._wrapped_collection_fix is not None, \ 'No pcolormesh wrapping was done when it should have been.' @@ -427,7 +491,10 @@ def test_pcolormesh_set_array_with_mask(): ax = fig.add_subplot(3, 1, 2, projection=ccrs.PlateCarree(-1.87499952)) c2 = ax.pcolormesh(xbnds, ybnds, bad_data_mask, norm=norm, transform=ccrs.PlateCarree(), snap=False) - c2.set_array(data.ravel()) + if mesh_data_kind == 'standard': + c2.set_array(data.ravel()) + else: + c2.set_array(data) ax.coastlines() ax.set_global() # make sure everything is visible @@ -439,6 +506,37 @@ def test_pcolormesh_set_array_with_mask(): return fig +def test_pcolormesh_set_array_nowrap(): + # Roundtrip check that set_array works with the correct shaped arrays + nx, ny = 36, 18 + xbnds = np.linspace(-60, 60, nx, endpoint=True) + ybnds = np.linspace(-80, 80, ny, endpoint=True) + xbnds, ybnds = np.meshgrid(xbnds, ybnds) + + rng = np.random.default_rng() + data = rng.random((ny - 1, nx - 1)) + + ax = plt.figure().add_subplot(projection=ccrs.PlateCarree()) + mesh = ax.pcolormesh(xbnds, ybnds, data) + assert not hasattr(mesh, '_wrapped_collection_fix') + + expected = data + if MPL_VERSION.release[:2] < (3, 8): + expected = expected.ravel() + np.testing.assert_array_equal(mesh.get_array(), expected) + + # For backwards compatibility, check we can set a 1D array + data = rng.random((nx - 1) * (ny - 1)) + mesh.set_array(data) + np.testing.assert_array_equal( + mesh.get_array(), data.reshape(ny - 1, nx - 1)) + + # Check that we can set a 2D array even if previous was flat + data = rng.random((ny - 1, nx - 1)) + mesh.set_array(data) + np.testing.assert_array_equal(mesh.get_array(), data) + + @pytest.mark.natural_earth @pytest.mark.mpl_image_compare(filename='pcolormesh_global_wrap3.png', tolerance=1.42) @@ -583,6 +681,13 @@ def test_pcolormesh_nan_wrap(): ax = plt.axes(projection=ccrs.PlateCarree()) mesh = ax.pcolormesh(xs, ys, data) pcolor = getattr(mesh, "_wrapped_collection_fix") + if MPL_VERSION.release[:2] < (3, 8): + assert len(pcolor.get_paths()) == 2 + else: + assert not pcolor.get_paths() + + # Check that we can populate the pcolor with some data. + mesh.set_array(np.ones((2, 2))) assert len(pcolor.get_paths()) == 2 @@ -617,14 +722,18 @@ def test_pcolormesh_mercator_wrap(): return ax.figure +@PARAMETRIZE_PCOLORMESH_WRAP @pytest.mark.natural_earth @pytest.mark.mpl_image_compare(filename='pcolormesh_mercator_wrap.png') -def test_pcolormesh_wrap_set_array(): +def test_pcolormesh_wrap_set_array(mesh_data_kind): x = np.linspace(0, 360, 73) y = np.linspace(-87.5, 87.5, 36) X, Y = np.meshgrid(*[np.deg2rad(c) for c in (x, y)]) Z = np.cos(Y) + 0.375 * np.sin(2. * X) Z = Z[:-1, :-1] + + Z = _to_rgb(Z, mesh_data_kind) + ax = plt.axes(projection=ccrs.Mercator()) norm = plt.Normalize(np.min(Z), np.max(Z)) ax.coastlines() @@ -632,7 +741,7 @@ def test_pcolormesh_wrap_set_array(): coll = ax.pcolormesh(x, y, np.ones(Z.shape), norm=norm, transform=ccrs.PlateCarree(), snap=False) # Now update the plot with the set_array method - coll.set_array(Z.ravel()) + coll.set_array(Z) return ax.figure