diff --git a/lib/cartopy/mpl/geoaxes.py b/lib/cartopy/mpl/geoaxes.py index 249d50fbf..c2fc01327 100644 --- a/lib/cartopy/mpl/geoaxes.py +++ b/lib/cartopy/mpl/geoaxes.py @@ -1332,24 +1332,25 @@ def imshow(self, img, *args, **kwargs): kwargs['alpha'] = alpha # As a workaround to a matplotlib limitation, turn any images - # which are RGB(A) with a mask into unmasked RGBA images with alpha - # put into the A channel. + # which are masked array RGB(A) into RGBA images + if np.ma.is_masked(img) and len(img.shape) > 2: - # if we don't pop alpha, imshow will apply (erroneously?) a - # 1D alpha to the RGBA array - # kwargs['alpha'] is guaranteed to be either 1D, 2D, or None - alpha = kwargs.pop('alpha') - old_img = img[:, :, 0:3] - img = np.zeros(img.shape[:2] + (4, ), dtype=img.dtype) - img[:, :, 0:3] = old_img - # Put an alpha channel in if the image was masked. - if not np.any(alpha): - alpha = 1 - img[:, :, 3] = np.ma.filled(alpha, fill_value=0) * \ - (~np.any(old_img.mask, axis=2)) - if img.dtype.kind == 'u': + + # transform RGB(A) into RGBA + old_img = img + img = np.ones(old_img.shape[:2] + (4, ), + dtype=old_img.dtype) + img[:, :, :3] = old_img[:, :, :3] + + # if img is RGBA, save alpha channel + if old_img.shape[-1] == 4: + img[:, :, 3] = old_img[:, :, 3] + elif img.dtype.kind == 'u': img[:, :, 3] *= 255 + # apply the mask to the A channel + img[np.any(old_img[:, :, :3].mask, axis=2), 3] = 0 + result = super().imshow(img, *args, extent=extent, **kwargs) return result diff --git a/lib/cartopy/tests/mpl/test_images.py b/lib/cartopy/tests/mpl/test_images.py index 830af6e66..c46c4ef27 100644 --- a/lib/cartopy/tests/mpl/test_images.py +++ b/lib/cartopy/tests/mpl/test_images.py @@ -154,6 +154,22 @@ def test_imshow_rgba(): assert sum(img.get_array().data[:, 0, 3]) == 0 +def test_imshow_rgba_alpha(): + # test that alpha channel from RGBA is not skipped + dy, dx = (3, 4) + + ax = plt.axes(projection=ccrs.Orthographic(-120, 45)) + + # Create RGBA Image with random data and linspace alpha + RGBA = np.linspace(0, 255*31, dx*dy*4, dtype=np.uint8).reshape((dy, dx, 4)) + + alpha = np.array([0, 85, 170, 255]) + RGBA[:, :, 3] = alpha + + img = ax.imshow(RGBA, transform=ccrs.PlateCarree()) + assert np.all(np.unique(img.get_array().data[:, :, 3]) == alpha) + + def test_imshow_rgb(): # tests that the alpha of a RGB array passed to imshow is set to 0 # instead of masked