Skip to content

Commit

Permalink
Merge pull request #2166 from rcomer/pcolormesh-rgba
Browse files Browse the repository at this point in the history
ENH: enable passing RGB(A) to polormesh
  • Loading branch information
greglucas authored Jul 14, 2023
2 parents 8c191c6 + a9982f7 commit de7b307
Show file tree
Hide file tree
Showing 3 changed files with 241 additions and 57 deletions.
70 changes: 39 additions & 31 deletions lib/cartopy/mpl/geoaxes.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@
from cartopy.mpl.slippy_image_artist import SlippyImageArtist


assert packaging.version.parse(mpl.__version__).release[:2] >= (3, 4), \
_MPL_VERSION = packaging.version.parse(mpl.__version__)
assert _MPL_VERSION.release >= (3, 4), \
'Cartopy is only supported with Matplotlib 3.4 or greater.'

# A nested mapping from path, source CRS, and target projection to the
Expand Down Expand Up @@ -1796,7 +1797,7 @@ def _wrap_args(self, *args, **kwargs):
kwargs['shading'] = 'flat'
X = np.asanyarray(args[0])
Y = np.asanyarray(args[1])
nrows, ncols = np.asanyarray(args[2]).shape
nrows, ncols = np.asanyarray(args[2]).shape[:2]
Nx = X.shape[-1]
Ny = Y.shape[0]
if X.ndim != 2 or X.shape[0] == 1:
Expand Down Expand Up @@ -1843,12 +1844,13 @@ def _wrap_quadmesh(self, collection, **kwargs):
Ny, Nx, _ = coords.shape
if kwargs.get('shading') == 'gouraud':
# Gouraud shading has the same shape for coords and data
data_shape = Ny, Nx
data_shape = Ny, Nx, -1
else:
data_shape = Ny - 1, Nx - 1
data_shape = Ny - 1, Nx - 1, -1
# data array
C = collection.get_array().reshape(data_shape)

if C.shape[-1] == 1:
C = C.squeeze(axis=-1)
transformed_pts = self.projection.transform_points(
t, coords[..., 0], coords[..., 1])

Expand Down Expand Up @@ -1921,13 +1923,12 @@ def _wrap_quadmesh(self, collection, **kwargs):
"map it must be fully transparent.",
stacklevel=3)

# The original data mask (regardless of wrapped cells)
C_mask = getattr(C, 'mask', None)
# Get hold of masked versions of the array to be passed to set_array
# methods of QuadMesh and PolyQuadMesh
pcolormesh_data, pcolor_data, pcolor_mask = \
cartopy.mpl.geocollection._split_wrapped_mesh_data(C, mask)

# create the masked array to be used with this pcolormesh
full_mask = mask if C_mask is None else mask | C_mask
pcolormesh_data = np.ma.array(C, mask=full_mask)
collection.set_array(pcolormesh_data.ravel())
collection.set_array(pcolormesh_data)

# plot with slightly lower zorder to avoid odd issue
# where the main plot is obscured
Expand All @@ -1943,25 +1944,32 @@ def _wrap_quadmesh(self, collection, **kwargs):
# `pcolor` only draws polygons where the data is not
# masked, so this will only draw a limited subset of
# polygons that were actually wrapped.
# We will add the original data mask in later to
# make sure that set_array can work in future
# calls on the proper sized array inputs.
# NOTE: we don't use C.data here because C.data could
# contain nan's which would be masked in the
# pcolor routines, which we don't want. We will
# fill in the proper data later with set_array()
# calls.
pcolor_data = np.ma.array(np.zeros(C.shape),
mask=~mask)
pcolor_col = self.pcolor(coords[..., 0], coords[..., 1],
pcolor_data, zorder=zorder,
**kwargs)
# Now add back in the masked data if there was any
full_mask = ~mask if C_mask is None else ~mask | C_mask
pcolor_data = np.ma.array(C, mask=full_mask)
# The pcolor_col is now possibly shorter than the
# actual collection, so grab the masked cells
pcolor_col.set_array(pcolor_data[mask].ravel())

if _MPL_VERSION.release[:2] < (3, 8):
# We will add the original data mask in later to
# make sure that set_array can work in future
# calls on the proper sized array inputs.
# NOTE: we don't use C.data here because C.data could
# contain nan's which would be masked in the
# pcolor routines, which we don't want. We will
# fill in the proper data later with set_array()
# calls.
pcolor_zeros = np.ma.array(np.zeros(C.shape), mask=pcolor_mask)
pcolor_col = self.pcolor(coords[..., 0], coords[..., 1],
pcolor_zeros, zorder=zorder,
**kwargs)

# The pcolor_col is now possibly shorter than the
# actual collection, so grab the masked cells
pcolor_col.set_array(pcolor_data[mask].ravel())
else:
pcolor_col = self.pcolor(coords[..., 0], coords[..., 1],
pcolor_data, zorder=zorder,
**kwargs)
# Currently pcolor_col.get_array() will return a compressed array
# and warn unless we explicitly set the 2D array. This should be
# unnecessary with future matplotlib versions.
pcolor_col.set_array(pcolor_data)

pcolor_col.set_cmap(cmap)
pcolor_col.set_norm(norm)
Expand All @@ -1972,7 +1980,7 @@ def _wrap_quadmesh(self, collection, **kwargs):
# put the pcolor_col and mask on the pcolormesh
# collection so that users can do things post
# this method
collection._wrapped_mask = mask.ravel()
collection._wrapped_mask = mask
collection._wrapped_collection_fix = pcolor_col

return collection
Expand Down
89 changes: 78 additions & 11 deletions lib/cartopy/mpl/geocollection.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,43 @@
# This file is part of Cartopy and is released under the LGPL license.
# See COPYING and COPYING.LESSER in the root of the repository for full
# licensing details.
import matplotlib as mpl
from matplotlib.collections import QuadMesh
import numpy as np
import numpy.ma as ma
import packaging


_MPL_VERSION = packaging.version.parse(mpl.__version__)


def _split_wrapped_mesh_data(C, mask):
"""
Helper function for splitting GeoQuadMesh array values between the
pcolormesh and pcolor objects when wrapping. Apply a mask to the grid
cells that should not be plotted with each method.
"""
# The original data mask (regardless of wrapped cells)
C_mask = getattr(C, 'mask', None)
if C.ndim == 3:
# RGB(A) array.
if _MPL_VERSION.release < (3, 8):
raise ValueError("GeoQuadMesh wrapping for RGB(A) requires "
"Matplotlib v3.8 or later")

# mask will need an extra trailing dimension
mask = np.broadcast_to(mask[..., np.newaxis], C.shape)

# create the masked array to be used with pcolormesh
full_mask = mask if C_mask is None else mask | C_mask
pcolormesh_data = ma.array(C, mask=full_mask)

# create the masked array to be used with pcolor
full_mask = ~mask if C_mask is None else ~mask | C_mask
pcolor_data = ma.array(C, mask=full_mask)

return pcolormesh_data, pcolor_data, ~mask


class GeoQuadMesh(QuadMesh):
Expand All @@ -21,25 +56,57 @@ def get_array(self):
A = super().get_array().copy()
# If the input array has a mask, retrieve the associated data
if hasattr(self, '_wrapped_mask'):
A[self._wrapped_mask] = self._wrapped_collection_fix.get_array()
pcolor_data = self._wrapped_collection_fix.get_array()
mask = self._wrapped_mask
if _MPL_VERSION.release[:2] < (3, 8):
A[mask] = pcolor_data
else:
if A.ndim == 3: # RGB(A) data. Need to broadcast mask.
mask = mask[:, :, np.newaxis]
# np.copyto is not implemented for masked arrays so handle the
# mask explicitly
np.copyto(A.mask, pcolor_data.mask, where=mask)
np.copyto(A, pcolor_data, where=mask)

return A

def set_array(self, A):
# raise right away if A is 2-dimensional.
if A.ndim > 1:
raise ValueError('Collections can only map rank 1 arrays. '
'You likely want to call with a flattened array '
'using collection.set_array(A.ravel()) instead.')
# Check the shape is appropriate up front.
if _MPL_VERSION.release[:2] < (3, 8):
# Need to figure out existing shape from the coordinates.
height, width = self._coordinates.shape[0:-1]
if self._shading == 'flat':
h, w = height - 1, width - 1
else:
h, w = height, width
else:
h, w = super().get_array().shape[:2]

ok_shapes = [(h, w, 3), (h, w, 4), (h, w), (h * w,)]
if A.shape not in ok_shapes:
ok_shape_str = ' or '.join(map(str, ok_shapes))
raise ValueError(
f"A should have shape {ok_shape_str}, not {A.shape}")

if A.ndim == 1:
# Always use array with at least two dimensions. This is
# inconsistent with QuadMesh which stores whatever you give it, but
# for the wrapped case we need to match the 2D mask. Storing the
# 2D array also allows us to calculate ok_shapes on subsequent
# calls without using the private QuadMesh._shading attribute.
A = A.reshape((h, w))

# Only use the mask attribute if it is there.
if hasattr(self, '_wrapped_mask'):

# Update the pcolor data with the wrapped masked data
self._wrapped_collection_fix.set_array(A[self._wrapped_mask])
# If the input array was a masked array, keep that data masked
if hasattr(A, 'mask'):
A = np.ma.array(A, mask=self._wrapped_mask | A.mask)
A, pcolor_data, _ = _split_wrapped_mesh_data(A, self._wrapped_mask)

if _MPL_VERSION.release[:2] < (3, 8):
self._wrapped_collection_fix.set_array(
pcolor_data[self._wrapped_mask].ravel())
else:
A = np.ma.array(A, mask=self._wrapped_mask)
self._wrapped_collection_fix.set_array(pcolor_data)

# Now that we have prepared the collection data, call on
# through to the underlying implementation.
Expand Down
Loading

0 comments on commit de7b307

Please sign in to comment.