diff --git a/holoviews/element/raster.py b/holoviews/element/raster.py index 6dd7c0dca8..6f2d751046 100644 --- a/holoviews/element/raster.py +++ b/holoviews/element/raster.py @@ -519,7 +519,7 @@ def __getitem__(self, coords): def range(self, dim, data_range=True): dim_idx = dim if isinstance(dim, int) else self.get_dimension_index(dim) dim = self.get_dimension(dim_idx) - if dim.range != (None, None): + if None not in dim.range: return dim.range elif dim_idx in [0, 1]: l, b, r, t = self.bounds.lbrt() @@ -532,13 +532,14 @@ def range(self, dim, data_range=True): data = np.atleast_3d(self.data)[:, :, dim_idx] drange = (np.nanmin(data), np.nanmax(data)) if data_range: - soft_range = [sr for sr in dim.soft_range if sr is not None] + soft_range = [np.NaN if sr is None else sr for sr in dim.soft_range] if soft_range: - return util.max_range([drange, soft_range]) - else: - return drange + drange = util.max_range([drange, soft_range]) + ranges = zip(drange, dim.range) else: - return dim.soft_range + ranges = zip(dim.soft_range, dim.range) + return tuple(datar if dimr is None else dimr + for datar, dimr in ranges) def _coord2matrix(self, coord): diff --git a/holoviews/plotting/mpl/element.py b/holoviews/plotting/mpl/element.py index 0323f17361..f9b59179f3 100644 --- a/holoviews/plotting/mpl/element.py +++ b/holoviews/plotting/mpl/element.py @@ -614,7 +614,7 @@ def _draw_colorbar(self, dim=None, redraw=True): scaled_w = w*width cax = fig.add_axes([l+w+padding+(scaled_w+padding+w*0.15)*offset, b, scaled_w, h]) - cbar = plt.colorbar(artist, cax=cax, extend=self._cbar_extend) + cbar = fig.colorbar(artist, cax=cax, ax=axis, extend=self._cbar_extend) self._adjust_cbar(cbar, label, dim) self.handles['cax'] = cax self.handles['cbar'] = cbar @@ -630,7 +630,6 @@ def _draw_colorbar(self, dim=None, redraw=True): ColorbarPlot._colorbars[id(axis)] = (ax_colorbars, (l, b, w, h)) - def _norm_kwargs(self, element, ranges, opts, vdim): """ Returns valid color normalization kwargs @@ -658,12 +657,18 @@ def _norm_kwargs(self, element, ranges, opts, vdim): opts['vmax'] = clim[1] # Check whether the colorbar should indicate clipping - el_min, el_max = element.range(vdim) - if el_min < opts['vmin'] and el_max > opts['vmax']: + values = element.dimension_values(vdim) + if values.dtype.kind not in 'OSUM': + el_min, el_max = np.nanmin(values), np.nanmax(values) + else: + el_min, el_max = -np.inf, np.inf + vmin = -np.inf if opts['vmin'] is None else opts['vmin'] + vmax = np.inf if opts['vmax'] is None else opts['vmax'] + if el_min < vmin and el_max > vmax: self._cbar_extend = 'both' - elif el_min < opts['vmin']: + elif el_min < vmin: self._cbar_extend = 'min' - elif el_max > opts['vmax']: + elif el_max > vmax: self._cbar_extend = 'max' # Define special out-of-range colors on colormap diff --git a/tests/testplotinstantiation.py b/tests/testplotinstantiation.py index 53a7fd22d1..4524022aa6 100644 --- a/tests/testplotinstantiation.py +++ b/tests/testplotinstantiation.py @@ -211,6 +211,46 @@ def test_curve_heterogeneous_datetime_types_with_pd_overlay(self): plot = mpl_renderer.get_plot(curve_dt*curve_dt64*curve_pd) self.assertEqual(plot.handles['axis'].get_xlim(), (735964.0, 735976.0)) + def test_image_cbar_extend_both(self): + img = Image(np.array([[0, 1], [2, 3]])).redim(z=dict(range=(1,2))) + plot = mpl_renderer.get_plot(img(plot=dict(colorbar=True))) + self.assertEqual(plot.handles['cbar'].extend, 'both') + + def test_image_cbar_extend_min(self): + img = Image(np.array([[0, 1], [2, 3]])).redim(z=dict(range=(1, None))) + plot = mpl_renderer.get_plot(img(plot=dict(colorbar=True))) + self.assertEqual(plot.handles['cbar'].extend, 'min') + + def test_image_cbar_extend_max(self): + img = Image(np.array([[0, 1], [2, 3]])).redim(z=dict(range=(None, 2))) + plot = mpl_renderer.get_plot(img(plot=dict(colorbar=True))) + self.assertEqual(plot.handles['cbar'].extend, 'max') + + def test_image_cbar_extend_clime(self): + img = Image(np.array([[0, 1], [2, 3]]))(style=dict(clim=(None, None))) + plot = mpl_renderer.get_plot(img(plot=dict(colorbar=True, color_index=1))) + self.assertEqual(plot.handles['cbar'].extend, 'neither') + + def test_points_cbar_extend_both(self): + img = Points(([0, 1], [0, 3])).redim(y=dict(range=(1,2))) + plot = mpl_renderer.get_plot(img(plot=dict(colorbar=True, color_index=1))) + self.assertEqual(plot.handles['cbar'].extend, 'both') + + def test_points_cbar_extend_min(self): + img = Points(([0, 1], [0, 3])).redim(y=dict(range=(1, None))) + plot = mpl_renderer.get_plot(img(plot=dict(colorbar=True, color_index=1))) + self.assertEqual(plot.handles['cbar'].extend, 'min') + + def test_points_cbar_extend_max(self): + img = Points(([0, 1], [0, 3])).redim(y=dict(range=(None, 2))) + plot = mpl_renderer.get_plot(img(plot=dict(colorbar=True, color_index=1))) + self.assertEqual(plot.handles['cbar'].extend, 'max') + + def test_points_cbar_extend_clime(self): + img = Points(([0, 1], [0, 3]))(style=dict(clim=(None, None))) + plot = mpl_renderer.get_plot(img(plot=dict(colorbar=True, color_index=1))) + self.assertEqual(plot.handles['cbar'].extend, 'neither') + def test_layout_instantiate_subplots(self): layout = (Curve(range(10)) + Curve(range(10)) + Image(np.random.rand(10,10)) + Curve(range(10)) + Curve(range(10)))