From 0bcecd1edf07cc8046030066df392ee04ae33473 Mon Sep 17 00:00:00 2001 From: Philipp Rudiger Date: Tue, 25 Jul 2023 17:25:11 +0200 Subject: [PATCH] Consistently handle multi-axis positioning and labels (#5827) Co-authored-by: jlstevens --- examples/user_guide/Customizing_Plots.ipynb | 11 +- holoviews/plotting/bokeh/element.py | 109 +++++++++------- holoviews/plotting/plot.py | 2 +- .../tests/plotting/bokeh/test_multiaxis.py | 117 ++++++++++++++---- 4 files changed, 164 insertions(+), 75 deletions(-) diff --git a/examples/user_guide/Customizing_Plots.ipynb b/examples/user_guide/Customizing_Plots.ipynb index 88645e194f..93e5e35b5f 100644 --- a/examples/user_guide/Customizing_Plots.ipynb +++ b/examples/user_guide/Customizing_Plots.ipynb @@ -696,7 +696,7 @@ "metadata": {}, "source": [ "### Twin axes\n", - "*(Available in HoloViews > 1.17)*\n", + "*(Available in HoloViews >= 1.17, requires Bokeh >=3.2)*\n", "\n", "HoloViews now supports displaying overlays containing two different value dimensions as twin axes for chart elements. To maintain backwards compatibility, this feature is only enabled by setting the `multi_y=True` option on the overlay.\n", "\n", @@ -735,16 +735,19 @@ "metadata": {}, "outputs": [], "source": [ - "(hv.Curve([1, 2, 3], vdims=['A']) * hv.Curve([2, 3, 4], vdims=['B']).opts(autorange='y', invert_yaxis=True, logy=True, ylim=(1,10))).opts(multi_y=True)" + "(hv.Curve([1, 2, 3], vdims=['A']) \n", + " * hv.Curve([2, 3, 4], vdims=['B']).opts(autorange='y', invert_yaxis=True, logy=True, ylim=(1,10), \n", + " ylabel='B custom', fontsize={'ylabel':10})\n", + ").opts(multi_y=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "Supported options for customizing individual axes are `apply_ranges`, `autorange='y'`, `invert_yaxis`, `logy` and `ylim`.\n", + "Supported options for customizing individual axes are `apply_ranges`, `autorange='y'`, `invert_yaxis`, `logy` and `ylim`, `yaxis` as well as the following options for labelling: `labelled`, `ylabel` and the `'ylabel'` setting in `fontsize`.\n", "\n", - "Note that as of HoloViews 1.17.0, `multi_y` does not have streaming plot support and that linked streams are not yet aware of additional y-axes." + "Note that as of HoloViews 1.17.0, `multi_y` does not have streaming plot support, extra axis labels are not dynamic and only the `RangeXY` linked stream is aware of additional y-axes." ] } ], diff --git a/holoviews/plotting/bokeh/element.py b/holoviews/plotting/bokeh/element.py index b7863fe05c..54661f0364 100644 --- a/holoviews/plotting/bokeh/element.py +++ b/holoviews/plotting/bokeh/element.py @@ -39,7 +39,9 @@ from ...streams import Buffer, RangeXY, PlotSize from ...util.transform import dim from ..plot import GenericElementPlot, GenericOverlayPlot -from ..util import process_cmap, color_intervals, dim_range_key +from ..util import ( + dim_axis_label, process_cmap, color_intervals, dim_range_key +) from .plot import BokehPlot from .styles import ( base_properties, legend_dimensions, line_properties, mpl_to_bokeh, @@ -437,12 +439,15 @@ def _axis_props(self, plots, subplots, element, ranges, pos, *, dim=None, else: specs = None - xlabel, ylabel, zlabel = self._get_axis_labels((None, None) if (dim is None) else dims) - if self.invert_axes: - xlabel, ylabel = ylabel, xlabel - if dims: - dims = dims[:2][::-1] - axis_label = ylabel if pos else xlabel + if dim: + axis_label = str(dim) + else: + xlabel, ylabel, zlabel = self._get_axis_labels(dims if dims else (None, None)) + if self.invert_axes: + xlabel, ylabel = ylabel, xlabel + axis_label = ylabel if pos else xlabel + if dims: + dims = dims[:2][::-1] categorical = any(self.traverse(lambda plot: plot._categorical)) if dims is not None and any(dim.name in ranges and 'factors' in ranges[dim.name] for dim in dims): @@ -497,20 +502,28 @@ def _create_extra_axes(self, plots, subplots, element, ranges): axpos0, axpos1 = 'left', 'right' ax_specs, yaxes, dimensions = {}, {}, {} - for el in element: - yd = el.get_dimension(1) + for el, sp in zip(element, self.subplots.values()): + ax_dims = sp._get_axis_dims(el)[:2] + if sp.invert_axes: + ax_dims[::-1] + yd = ax_dims[1] dimensions[yd.name] = yd opts = el.opts.get('plot', backend='bokeh').kwargs - if yd.name in yaxes: + if not isinstance(yd, Dimension) or yd.name in yaxes: continue yaxes[yd.name] = { 'position': opts.get('yaxis', axpos1 if len(yaxes) else axpos0), 'autorange': opts.get('autorange', None), 'logx': opts.get('logx', False), 'logy': opts.get('logy', False), - 'invert_yaxis': opts.get('invert_yaxis',False), + 'invert_yaxis': opts.get('invert_yaxis', False), # 'xlim': opts.get('xlim', (np.nan, np.nan)), # TODO - 'ylim': opts.get('ylim', (np.nan, np.nan)) + 'ylim': opts.get('ylim', (np.nan, np.nan)), + 'label': opts.get('ylabel', dim_axis_label(yd)), + 'fontsize': { + 'axis_label_text_font_size': sp._fontsize('ylabel').get('fontsize'), + 'major_label_text_font_size': sp._fontsize('yticks').get('fontsize') + } } for ydim, info in yaxes.items(): @@ -531,8 +544,10 @@ def _create_extra_axes(self, plots, subplots, element, ranges): extra_range_name=ydim ) log_enabled = info['logx'] if self.invert_axes else info['logy'] - ax_props = ('log' if log_enabled else ax_props[0], ax_props[1], ax_props[2]) - ax_specs[ydim] = ax_props + ax_type = 'log' if log_enabled else ax_props[0] + ax_specs[ydim] = ( + ax_type, info['label'], ax_props[2], info['position'], info['fontsize'] + ) return yaxes, ax_specs def _init_plot(self, key, element, plots, ranges=None): @@ -544,13 +559,15 @@ def _init_plot(self, key, element, plots, ranges=None): subplots = list(self.subplots.values()) if self.subplots else [] axis_specs = {'x': {}, 'y': {}} - axis_specs['x']['x'] = self._axis_props(plots, subplots, element, ranges, pos=0) + axis_specs['x']['x'] = self._axis_props(plots, subplots, element, ranges, pos=0) + (self.xaxis, {}) if self.multi_y: + if not bokeh32: + self.param.warning('Independent axis zooming for multi_y=True only supported for Bokeh >=3.2') yaxes, extra_axis_specs = self._create_extra_axes(plots, subplots, element, ranges) axis_specs['y'].update(extra_axis_specs) else: - range_tags_extras={'invert_yaxis':self.invert_yaxis} - if self.autorange=='y': + range_tags_extras = {'invert_yaxis': self.invert_yaxis} + if self.autorange == 'y': range_tags_extras['autorange'] = True lowerlim, upperlim = self.ylim if not ((lowerlim is None) or np.isnan(lowerlim)): @@ -559,13 +576,13 @@ def _init_plot(self, key, element, plots, ranges=None): range_tags_extras['y-upperlim'] = upperlim else: range_tags_extras['autorange'] = False + axis_specs['y']['y'] = self._axis_props( + plots, subplots, element, ranges, pos=1, range_tags_extras = range_tags_extras + ) + (self.yaxis, {}) - axis_specs['y']['y'] = self._axis_props(plots, subplots, element, ranges, pos=1, - range_tags_extras = range_tags_extras) - - properties = {} + properties, axis_props = {}, {'x': {}, 'y': {}} for axis, axis_spec in axis_specs.items(): - for (axis_dim, (axis_type, axis_label, axis_range)) in axis_spec.items(): + for (axis_dim, (axis_type, axis_label, axis_range, axis_position, fontsize)) in axis_spec.items(): scale = get_scale(axis_range, axis_type) if f'{axis}_range' in properties: properties[f'extra_{axis}_ranges'] = extra_ranges = properties.get(f'extra_{axis}_ranges', {}) @@ -576,6 +593,15 @@ def _init_plot(self, key, element, plots, ranges=None): properties[f'{axis}_range'] = axis_range properties[f'{axis}_scale'] = scale properties[f'{axis}_axis_type'] = axis_type + if axis_label and axis in self.labelled: + properties[f'{axis}_axis_label'] = axis_label + locs = {'left': 'left', 'right': 'right'} if axis == 'y' else {'bottom': 'below', 'top': 'above'} + if axis_position is None: + axis_props[axis]['visible'] = False + axis_props[axis].update(fontsize) + for loc, pos in locs.items(): + if axis_position and loc in axis_position: + properties[f'{axis}_axis_location'] = pos if not self.show_frame: properties['outline_line_alpha'] = 0 @@ -608,14 +634,22 @@ def _init_plot(self, key, element, plots, ranges=None): # are not really an issue warnings.simplefilter('ignore', UserWarning) fig = figure(title=title, **properties) + fig.xaxis[0].update(**axis_props['x']) + fig.yaxis[0].update(**axis_props['y']) multi_ax = 'x' if self.invert_axes else 'y' for axis_dim, range_obj in properties.get(f'extra_{multi_ax}_ranges', {}).items(): - axis_type, axis_label, _ = axis_specs[multi_ax][axis_dim] + axis_type, axis_label, _, axis_position, fontsize = axis_specs[multi_ax][axis_dim] ax_cls, ax_kwargs = get_axis_class(axis_type, range_obj, dim=1) ax_kwargs[f'{multi_ax}_range_name'] = axis_dim - fig.add_layout(ax_cls(axis_label=axis_label, - **ax_kwargs), yaxes[axis_dim]['position']) + ax_kwargs.update(fontsize) + if axis_position is None: + ax_kwargs['visible'] = False + axis_position = 'above' if multi_ax == 'x' else 'right' + if multi_ax in self.labelled: + ax_kwargs['axis_label'] = axis_label + ax = ax_cls(**ax_kwargs) + fig.add_layout(ax, axis_position) return fig def _plot_properties(self, key, element): @@ -658,7 +692,6 @@ def _plot_properties(self, key, element): plot_props['lod_'+lod_prop] = v return plot_props - def _set_active_tools(self, plot): "Activates the list of active tools" if plot is None or self.toolbar == "disable": @@ -695,7 +728,6 @@ def _set_active_tools(self, plot): if isinstance(tool, tools.InspectTool): plot.toolbar.active_inspect.append(tool) - def _title_properties(self, key, plot, element): if self.show_title and self.adjoined is None: title = self._format_title(key, separator=' ') @@ -709,24 +741,11 @@ def _title_properties(self, key, plot, element): opts['text_font_size'] = title_font return opts - def _init_axes(self, plot): - if self.xaxis is None: - plot.xaxis.visible = False - elif isinstance(self.xaxis, str) and 'top' in self.xaxis: - plot.above = [plot.xaxis[0]] + [ax for ax in plot.above if ax is not plot.xaxis[0]] - plot.below = [ax for ax in plot.below if ax is not plot.xaxis[0]] - plot.xaxis[:] = list(plot.above) + list(plot.below) + def _populate_axis_handles(self, plot): self.handles['xaxis'] = plot.xaxis[0] self.handles['x_range'] = plot.x_range self.handles['extra_x_ranges'] = plot.extra_x_ranges self.handles['extra_x_scales'] = plot.extra_x_scales - - if self.yaxis is None: - plot.yaxis.visible = False - elif isinstance(self.yaxis, str) and 'right' in self.yaxis: - plot.right = [plot.yaxis[0]] + [ax for ax in plot.right if ax is not plot.yaxis[0]] - plot.left = [ax for ax in plot.left if ax is not plot.yaxis[0]] - plot.yaxis[:] = list(plot.left) + list(plot.right) self.handles['yaxis'] = plot.yaxis[0] self.handles['y_range'] = plot.y_range self.handles['extra_y_ranges'] = plot.extra_y_ranges @@ -838,7 +857,8 @@ def _update_plot(self, key, plot, element=None): Updates plot parameters on every frame """ plot.update(**self._plot_properties(key, element)) - self._update_labels(key, plot, element) + if not self.multi_y: + self._update_labels(key, plot, element) self._update_title(key, plot, element) self._update_grid(plot) @@ -1731,7 +1751,7 @@ def initialize_plot(self, ranges=None, plot=None, plots=None, source=None): # Initialize plot, source and glyph if plot is None: plot = self._init_plot(key, style_element, ranges=ranges, plots=plots) - self._init_axes(plot) + self._populate_axis_handles(plot) else: axes, plot_ranges = self._find_axes(plot, element) self.handles['xaxis'], self.handles['yaxis'] = axes @@ -2723,7 +2743,7 @@ def initialize_plot(self, ranges=None, plot=None, plots=None): self.tabs = self.tabs or any(isinstance(sp, TablePlot) for sp in self.subplots.values()) if plot is None and not self.tabs and not self.batched: plot = self._init_plot(key, element, ranges=ranges, plots=plots) - self._init_axes(plot) + self._populate_axis_handles(plot) self.handles['plot'] = plot if plot and not self.overlaid: @@ -2780,7 +2800,6 @@ def initialize_plot(self, ranges=None, plot=None, plots=None): return self.handles['plot'] - def update_frame(self, key, ranges=None, element=None): """ Update the internal state of the Plot to represent the given diff --git a/holoviews/plotting/plot.py b/holoviews/plotting/plot.py index 0e9152fa0a..5bd2110937 100644 --- a/holoviews/plotting/plot.py +++ b/holoviews/plotting/plot.py @@ -1688,7 +1688,7 @@ class GenericOverlayPlot(GenericElementPlot): _passed_handles = [] # Options not to be propagated in multi_y mode to allow independent control of y-axes - _multi_y_unpropagated = ['ylim', 'invert_yaxis', 'logy'] + _multi_y_unpropagated = ['yaxis', 'ylim', 'invert_yaxis', 'logy'] def __init__(self, overlay, ranges=None, batched=True, keys=None, group_counter=None, **params): if 'projection' not in params: diff --git a/holoviews/tests/plotting/bokeh/test_multiaxis.py b/holoviews/tests/plotting/bokeh/test_multiaxis.py index 1f71acae65..baa486d59d 100644 --- a/holoviews/tests/plotting/bokeh/test_multiaxis.py +++ b/holoviews/tests/plotting/bokeh/test_multiaxis.py @@ -1,9 +1,9 @@ -import pytest -from holoviews.element import Curve -from .test_plot import TestBokehPlot, bokeh_renderer from bokeh.models import LinearScale, LogScale, LinearAxis, LogAxis +from holoviews.element import Curve from ...utils import LoggingComparisonTestCase +from .test_plot import TestBokehPlot, bokeh_renderer + class TestCurveTwinAxes(LoggingComparisonTestCase, TestBokehPlot): @@ -201,40 +201,107 @@ def test_shared_multi_axes(self): self.assertEqual((y_range.start, y_range.end), (5, 19)) self.assertEqual((extra_y_ranges['B'].start, extra_y_ranges['B'].end), (1, 13)) - @pytest.mark.xfail - def test_swapped_position_label(self): - overlay = (Curve(range(10), vdims=['A']).opts(yaxis='right') - * Curve(range(10), vdims=['B']).opts(yaxis='left') - ).opts(multi_y=True) + def test_invisible_main_axis(self): + overlay = ( + Curve(range(10), vdims=['A']).opts(yaxis=None) * + Curve(range(10), vdims=['B']) + ).opts(multi_y=True) + plot = bokeh_renderer.get_plot(overlay) + assert len(plot.state.yaxis) == 2 + assert not plot.state.yaxis[0].visible + assert plot.state.yaxis[1].visible + + def test_invisible_extra_axis(self): + overlay = ( + Curve(range(10), vdims=['A']) * + Curve(range(10), vdims=['B']).opts(yaxis=None) + ).opts(multi_y=True) + plot = bokeh_renderer.get_plot(overlay) + assert len(plot.state.yaxis) == 2 + assert plot.state.yaxis[0].visible + assert not plot.state.yaxis[1].visible + + def test_axis_labels(self): + overlay = ( + Curve(range(10), vdims=['A']) * + Curve(range(10), vdims=['B']) + ).opts(multi_y=True) plot = bokeh_renderer.get_plot(overlay) - self.assertEqual(plot.state.yaxis[0].axis_label, 'B') - self.assertEqual(plot.state.yaxis[1].axis_label, 'A') + assert plot.state.xaxis[0].axis_label == 'x' + assert plot.state.yaxis[0].axis_label == 'A' + assert plot.state.yaxis[1].axis_label == 'B' + + def test_custom_axis_labels(self): + overlay = ( + Curve(range(10), vdims=['A']).opts(xlabel='x-custom', ylabel='A-custom') * + Curve(range(10), vdims=['B']).opts(ylabel='B-custom') + ).opts(multi_y=True) + plot = bokeh_renderer.get_plot(overlay) + + assert plot.state.xaxis[0].axis_label == 'x-custom' + assert plot.state.yaxis[0].axis_label == 'A-custom' + assert plot.state.yaxis[1].axis_label == 'B-custom' + + def test_only_x_axis_labels(self): + overlay = ( + Curve(range(10), vdims=['A']) * + Curve(range(10), vdims=['B']) + ).opts(multi_y=True, labelled=['x']) + plot = bokeh_renderer.get_plot(overlay) + + assert plot.state.xaxis[0].axis_label == 'x' + assert plot.state.yaxis[0].axis_label is None + assert plot.state.yaxis[1].axis_label is None + + def test_only_x_axis_labels(self): + overlay = ( + Curve(range(10), vdims=['A']) * + Curve(range(10), vdims=['B']) + ).opts(multi_y=True, labelled=['y']) + plot = bokeh_renderer.get_plot(overlay) + + assert plot.state.xaxis[0].axis_label is None + assert plot.state.yaxis[0].axis_label == 'A' + assert plot.state.yaxis[1].axis_label == 'B' + + def test_swapped_position_label(self): + overlay = ( + Curve(range(10), vdims=['A']).opts(yaxis='right') * + Curve(range(10), vdims=['B']).opts(yaxis='left') + ).opts(multi_y=True) + plot = bokeh_renderer.get_plot(overlay) + assert plot.state.yaxis[0].axis_label == 'B' + assert plot.state.yaxis[1].axis_label == 'A' - @pytest.mark.xfail - def test_swapped_position_custom_label(self): + def test_swapped_position_custom_y_labels(self): overlay = (Curve(range(10), vdims=['A']).opts(yaxis='right', ylabel='A-custom') * Curve(range(10), vdims=['B']).opts(yaxis='left', ylabel='B-custom') ).opts(multi_y=True) plot = bokeh_renderer.get_plot(overlay) - self.assertEqual(plot.state.yaxis[0].axis_label, 'B-custom') - self.assertEqual(plot.state.yaxis[1].axis_label, 'A-custom') + assert plot.state.yaxis[0].axis_label == 'B-custom' + assert plot.state.yaxis[1].axis_label == 'A-custom' - @pytest.mark.xfail def test_position_custom_size_label(self): - overlay = (Curve(range(10), vdims='A').opts(fontsize={'ylabel': '13pt'}) - * Curve(range(10), vdims='B').opts(fontsize={'ylabel': '15pt'})).opts(multi_y=True) + overlay = ( + Curve(range(10), vdims='A').opts(fontsize={'ylabel': '13pt'}) * + Curve(range(10), vdims='B').opts(fontsize={'ylabel': '15pt'}) + ).opts(multi_y=True) plot = bokeh_renderer.get_plot(overlay) - self.assertEqual(plot.state.yaxis[0].axis_label_text_font_size, '13pt') - self.assertEqual(plot.state.yaxis[1].axis_label_text_font_size, '15pt') + assert plot.state.yaxis[0].axis_label == 'A' + assert plot.state.yaxis[0].axis_label_text_font_size == '13pt' + assert plot.state.yaxis[1].axis_label == 'B' + assert plot.state.yaxis[1].axis_label_text_font_size == '15pt' - @pytest.mark.xfail def test_swapped_position_custom_size_label(self): - overlay = (Curve(range(10), vdims='A').opts(yaxis='right', fontsize={'ylabel': '13pt'}) - * Curve(range(10), vdims='B').opts(yaxis='left', - fontsize={'ylabel': '15pt'})).opts(multi_y=True) + overlay = ( + Curve(range(10), vdims='A').opts(yaxis='right', fontsize={'ylabel': '13pt'}) * + Curve(range(10), vdims='B').opts(yaxis='left', fontsize={'ylabel': '15pt'}) + ).opts(multi_y=True) plot = bokeh_renderer.get_plot(overlay) - self.assertEqual(plot.state.yaxis[0].axis_label_text_font_size, '15pt') - self.assertEqual(plot.state.yaxis[1].axis_label_text_font_size, '13pt') + assert plot.state.yaxis[0].axis_label == 'B' + assert plot.state.yaxis[0].axis_label_text_font_size == '15pt' + assert plot.state.yaxis[1].axis_label == 'A' + assert plot.state.yaxis[1].axis_label_text_font_size == '13pt'