From 2ab0ebfe0f6bd63a40f13c88452723a8fc11de79 Mon Sep 17 00:00:00 2001 From: Philipp Rudiger Date: Tue, 25 Jul 2023 09:42:51 +0200 Subject: [PATCH] Support RangeXY streams on multi-axes (#5826) --- holoviews/plotting/bokeh/callbacks.py | 33 +++++++++--- holoviews/plotting/bokeh/element.py | 52 +++++++++++++------ holoviews/plotting/plot.py | 2 +- .../tests/plotting/bokeh/test_callbacks.py | 20 +++++++ holoviews/tests/ui/bokeh/test_callback.py | 39 +++++++++++++- 5 files changed, 121 insertions(+), 25 deletions(-) diff --git a/holoviews/plotting/bokeh/callbacks.py b/holoviews/plotting/bokeh/callbacks.py index 53fb9f5052..12345edd67 100644 --- a/holoviews/plotting/bokeh/callbacks.py +++ b/holoviews/plotting/bokeh/callbacks.py @@ -87,6 +87,9 @@ class Callback: # The plotting handle(s) to attach the JS callback on models = [] + # Additional handles to hash on for uniqueness + extra_handles = [] + # Conditions when callback should be skipped skip_events = [] skip_changes = [] @@ -209,6 +212,9 @@ def _init_plot_handles(self): if h in self.plot_handles: requested[h] = handles[h] self.handle_ids.update(self._get_stream_handle_ids(requested)) + for h in self.extra_handles: + if h in self.plot_handles: + requested[h] = handles[h] return requested def _get_stream_handle_ids(self, handles): @@ -379,20 +385,23 @@ def set_callback(self, handle): def initialize(self, plot_id=None): handles = self._init_plot_handles() - cb_handles = [] - for handle_name in self.models: + hash_handles, cb_handles = [], [] + for handle_name in self.models+self.extra_handles: if handle_name not in handles: warn_args = (handle_name, type(self.plot).__name__, type(self).__name__) print('{} handle not found on {}, cannot ' 'attach {} callback'.format(*warn_args)) continue - cb_handles.append(handles[handle_name]) + handle = handles[handle_name] + if handle_name not in self.extra_handles: + cb_handles.append(handle) + hash_handles.append(handle) # Hash the plot handle with Callback type allowing multiple # callbacks on one handle to be merged - handle_ids = [id(h) for h in cb_handles] - cb_hash = tuple(handle_ids)+(id(type(self)),) + hash_ids = [id(h) for h in hash_handles] + cb_hash = tuple(hash_ids)+(id(type(self)),) if cb_hash in self._callbacks: # Merge callbacks if another callback has already been attached cb = self._callbacks[cb_hash] @@ -599,11 +608,13 @@ class RangeXYCallback(Callback): models = ['plot'] + extra_handles = ['x_range', 'y_range'] + attributes = { 'x0': 'cb_obj.x0', 'y0': 'cb_obj.y0', 'x1': 'cb_obj.x1', - 'y1': 'cb_obj.y1' + 'y1': 'cb_obj.y1', } _js_on_event = """ @@ -624,6 +635,12 @@ def set_callback(self, handle): handle.js_on_event('rangesupdate', CustomJS(code=self._js_on_event)) def _process_msg(self, msg): + if self.plot.state.x_range is not self.plot.handles['x_range']: + x_range = self.plot.handles['x_range'] + msg['x0'], msg['x1'] = x_range.start, x_range.end + if self.plot.state.y_range is not self.plot.handles['y_range']: + y_range = self.plot.handles['y_range'] + msg['y0'], msg['y1'] = y_range.start, y_range.end data = {} if 'x0' in msg and 'x1' in msg: x0, x1 = msg['x0'], msg['x1'] @@ -657,6 +674,8 @@ class RangeXCallback(RangeXYCallback): models = ['plot'] + extra_handles = ['x_range'] + attributes = { 'x0': 'cb_obj.x0', 'x1': 'cb_obj.x1', @@ -672,6 +691,8 @@ class RangeYCallback(RangeXYCallback): models = ['plot'] + extra_handles = ['y_range'] + attributes = { 'y0': 'cb_obj.y0', 'y1': 'cb_obj.y1' diff --git a/holoviews/plotting/bokeh/element.py b/holoviews/plotting/bokeh/element.py index 5ecea2a0c5..b7863fe05c 100644 --- a/holoviews/plotting/bokeh/element.py +++ b/holoviews/plotting/bokeh/element.py @@ -927,6 +927,7 @@ def _update_grid(self, plot): def _update_ranges(self, element, ranges): x_range = self.handles['x_range'] y_range = self.handles['y_range'] + plot = self.handles['plot'] self._update_main_ranges(element, x_range, y_range, ranges) @@ -938,6 +939,10 @@ def _update_ranges(self, element, ranges): factors = self._get_dimension_factors(element, ranges, axis_dim) extra_scale = self.handles[f'extra_{multi_dim}_scales'][axis_dim] # Assumes scales and ranges zip log = isinstance(extra_scale, LogScale) + range_update = (not (self.model_changed(extra_y_range) or self.model_changed(plot)) + and self.framewise) + if self.drawn and not range_update: + continue self._update_range( extra_y_range, b, t, factors, extra_y_range.tags[1]['invert_yaxis'] if extra_y_range.tags else False, @@ -1350,28 +1355,19 @@ def _init_glyph(self, plot, mapping, properties): if 'legend_field' in properties and 'legend_label' in properties: del properties['legend_label'] - # ALERT: This only handles XYGlyph types right now - # and note guard against Field (unhashable) when using FactorRanges - mapping = property_to_dict(mapping) - if 'x' in mapping: - x = mapping['x'] - if plot.extra_x_ranges and (x in plot.extra_x_ranges): - properties['x_range_name'] = mapping['x'] - if 'y' in mapping: - y = mapping['y'] - if plot.extra_y_ranges and (y in plot.extra_y_ranges): - properties['y_range_name'] = mapping['y'] + if self.handles['x_range'].name in plot.extra_x_ranges: + properties['x_range_name'] = self.handles['y_range'].name + if self.handles['y_range'].name in plot.extra_y_ranges: + properties['y_range_name'] = self.handles['y_range'].name if "name" not in properties: properties["name"] = properties.get("legend_label") or properties.get("legend_field") renderer = getattr(plot, plot_method)(**dict(properties, **mapping)) return renderer, renderer.glyph - def _element_transform(self, transform, element, ranges): return transform.apply(element, ranges=ranges, flat=True) - def _apply_transforms(self, element, data, ranges, style, group=None): new_style = dict(style) prefix = group+'_' if group else '' @@ -1692,6 +1688,29 @@ def _init_glyphs(self, plot, element, ranges, source): with abbreviated_exception(): self._update_glyph(renderer, properties, mapping, glyph, source, source.data) + def _find_axes(self, plot, element): + """ + Looks up the axes and plot ranges given the plot and an element. + """ + axis_dims = self._get_axis_dims(element)[:2] + if self.invert_axes: + axis_dims[0], axis_dims[1] = axis_dims[::-1] + x, y = axis_dims + if isinstance(x, Dimension) and x.name in plot.extra_x_ranges: + x_range = plot.extra_x_ranges[x.name] + xaxes = [xaxis for xaxis in plot.xaxis if xaxis.x_range_name == x.name] + x_axis = (xaxes if xaxes else plot.xaxis)[0] + else: + x_range = plot.x_range + x_axis = plot.xaxis[0] + if isinstance(y, Dimension) and y.name in plot.extra_y_ranges: + y_range = plot.extra_y_ranges[y.name] + yaxes = [yaxis for yaxis in plot.yaxis if yaxis.y_range_name == y.name] + y_axis = (yaxes if yaxes else plot.yaxis)[0] + else: + y_range = plot.y_range + y_axis = plot.yaxis[0] + return (x_axis, y_axis), (x_range, y_range) def initialize_plot(self, ranges=None, plot=None, plots=None, source=None): """ @@ -1714,10 +1733,9 @@ def initialize_plot(self, ranges=None, plot=None, plots=None, source=None): plot = self._init_plot(key, style_element, ranges=ranges, plots=plots) self._init_axes(plot) else: - self.handles['xaxis'] = plot.xaxis[0] - self.handles['x_range'] = plot.x_range - self.handles['yaxis'] = plot.yaxis[0] - self.handles['y_range'] = plot.y_range + axes, plot_ranges = self._find_axes(plot, element) + self.handles['xaxis'], self.handles['yaxis'] = axes + self.handles['x_range'], self.handles['y_range'] = plot_ranges self.handles['plot'] = plot if self.autorange: diff --git a/holoviews/plotting/plot.py b/holoviews/plotting/plot.py index c9c34b6b7c..0e9152fa0a 100644 --- a/holoviews/plotting/plot.py +++ b/holoviews/plotting/plot.py @@ -1699,7 +1699,7 @@ def __init__(self, overlay, ranges=None, batched=True, keys=None, group_counter= if ('multi_y' in self.param) and self.multi_y: for s in self.streams: - intersection = set(s.param) & {'y', 'y_selection', 'y_range', 'bounds', 'boundsy'} + intersection = set(s.param) & {'y', 'y_selection', 'bounds', 'boundsy'} if intersection: self.param.warning(f'{type(s).__name__} stream parameters' f' {list(intersection)} not yet supported with multi_y=True') diff --git a/holoviews/tests/plotting/bokeh/test_callbacks.py b/holoviews/tests/plotting/bokeh/test_callbacks.py index 30d422a9fb..68e868a1b7 100644 --- a/holoviews/tests/plotting/bokeh/test_callbacks.py +++ b/holoviews/tests/plotting/bokeh/test_callbacks.py @@ -445,3 +445,23 @@ def test_msg_with_base64_array(): data_expected = np.array([10.0, 20.0, 30.0, 40.0]) assert np.equal(data_expected, data_after).all() + + +def test_rangexy_multi_yaxes(): + c1 = Curve(np.arange(100).cumsum(), vdims='y') + c2 = Curve(-np.arange(100).cumsum(), vdims='y2') + RangeXY(source=c1) + RangeXY(source=c2) + + overlay = (c1 * c2).opts(backend='bokeh', multi_y=True) + plot = bokeh_server_renderer.get_plot(overlay) + + p1, p2 = plot.subplots.values() + + assert plot.state.y_range is p1.handles['y_range'] + assert 'y2' in plot.state.extra_y_ranges + assert plot.state.extra_y_ranges['y2'] is p2.handles['y_range'] + + # Ensure both callbacks are attached + assert p1.callbacks[0].plot is p1 + assert p2.callbacks[0].plot is p2 diff --git a/holoviews/tests/ui/bokeh/test_callback.py b/holoviews/tests/ui/bokeh/test_callback.py index 24cbcf794f..6d3f21dffe 100644 --- a/holoviews/tests/ui/bokeh/test_callback.py +++ b/holoviews/tests/ui/bokeh/test_callback.py @@ -10,7 +10,7 @@ pytestmark = pytest.mark.ui -from holoviews import Scatter +from holoviews import Curve, Scatter from holoviews.streams import BoundsXY, Lasso, RangeXY from holoviews.plotting.bokeh import BokehRenderer from panel.pane.holoviews import HoloViews @@ -124,3 +124,40 @@ def test_rangexy(page, port): expected_xrange = (0.32844036697247725, 0.8788990825688077) expected_yrange = (1.8285714285714285, 2.3183673469387758) wait_until(lambda: rangexy.x_range == expected_xrange and rangexy.y_range == expected_yrange, page) + +def test_multi_axis_rangexy(page, port): + c1 = Curve(np.arange(100).cumsum(), vdims='y') + c2 = Curve(-np.arange(100).cumsum(), vdims='y2') + s1 = RangeXY(source=c1) + s2 = RangeXY(source=c2) + + overlay = (c1 * c2).opts(backend='bokeh', multi_y=True) + + pn_scatter = HoloViews(overlay, renderer=BokehRenderer) + + serve(pn_scatter, port=port, threaded=True, show=False) + + time.sleep(0.5) + + page.goto(f"http://localhost:{port}") + + hv_plot = page.locator('.bk-events') + + expect(hv_plot).to_have_count(1) + + bbox = hv_plot.bounding_box() + hv_plot.click() + + page.mouse.move(bbox['x']+100, bbox['y']+100) + page.mouse.down() + page.mouse.move(bbox['x']+150, bbox['y']+150, steps=5) + page.mouse.up() + + expected_xrange = (-35.1063829787234, 63.89361702127659) + expected_yrange1 = (717.2448979591848, 6657.244897959185) + expected_yrange2 = (-4232.7551020408155, 1707.2448979591848) + wait_until(lambda: ( + s1.x_range == expected_xrange and + s1.y_range == expected_yrange1 and + s2.y_range == expected_yrange2 + ), page)