Skip to content

Commit

Permalink
Use correct dtype for RGB image alpha channel (#1893)
Browse files Browse the repository at this point in the history
* Fix alpha channel logic for integer RGB images

* Use named argument for concat axis
  • Loading branch information
Zac-HD authored and fmaussion committed Feb 12, 2018
1 parent ee38ff0 commit 93a4039
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 5 deletions.
3 changes: 0 additions & 3 deletions doc/gallery/plot_rasterio_rgb.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,6 @@
# Read the data
da = xr.open_rasterio('RGB.byte.tif')

# Normalize the image
da = da / 255

# The data is in UTM projection. We have to set it manually until
# https://github.com/SciTools/cartopy/issues/813 is implemented
crs = ccrs.UTM('18N')
Expand Down
8 changes: 6 additions & 2 deletions xarray/plot/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -710,8 +710,12 @@ def imshow(x, y, z, ax, **kwargs):
# missing data transparent. We therefore add an alpha channel if
# there isn't one, and set it to transparent where data is masked.
if z.shape[-1] == 3:
z = np.ma.concatenate((z, np.ma.ones(z.shape[:2] + (1,))), 2)
z = z.copy()
alpha = np.ma.ones(z.shape[:2] + (1,), dtype=z.dtype)
if np.issubdtype(z.dtype, np.integer):
alpha *= 255
z = np.ma.concatenate((z, alpha), axis=2)
else:
z = z.copy()
z[np.any(z.mask, axis=-1), -1] = 0

primitive = ax.imshow(z, **defaults)
Expand Down
7 changes: 7 additions & 0 deletions xarray/tests/test_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -1146,6 +1146,13 @@ def test_normalize_rgb_one_arg_error(self):
for kwds in [dict(vmax=-1, vmin=-1.2), dict(vmin=2, vmax=2.1)]:
da.plot.imshow(**kwds)

def test_imshow_rgb_values_in_valid_range(self):
da = DataArray(np.arange(75, dtype='uint8').reshape((5, 5, 3)))
_, ax = plt.subplots()
out = da.plot.imshow(ax=ax).get_array()
assert out.dtype == np.uint8
assert (out[..., :3] == da.values).all() # Compare without added alpha


class TestFacetGrid(PlotTestCase):
def setUp(self):
Expand Down

0 comments on commit 93a4039

Please sign in to comment.