Skip to content

Commit

Permalink
Support RangeXY streams on multi-axes (#5826)
Browse files Browse the repository at this point in the history
  • Loading branch information
philippjfr committed Jul 25, 2023
1 parent acaa412 commit 2ab0ebf
Show file tree
Hide file tree
Showing 5 changed files with 121 additions and 25 deletions.
33 changes: 27 additions & 6 deletions holoviews/plotting/bokeh/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,9 @@ class Callback:
# The plotting handle(s) to attach the JS callback on
models = []

# Additional handles to hash on for uniqueness
extra_handles = []

# Conditions when callback should be skipped
skip_events = []
skip_changes = []
Expand Down Expand Up @@ -209,6 +212,9 @@ def _init_plot_handles(self):
if h in self.plot_handles:
requested[h] = handles[h]
self.handle_ids.update(self._get_stream_handle_ids(requested))
for h in self.extra_handles:
if h in self.plot_handles:
requested[h] = handles[h]
return requested

def _get_stream_handle_ids(self, handles):
Expand Down Expand Up @@ -379,20 +385,23 @@ def set_callback(self, handle):

def initialize(self, plot_id=None):
handles = self._init_plot_handles()
cb_handles = []
for handle_name in self.models:
hash_handles, cb_handles = [], []
for handle_name in self.models+self.extra_handles:
if handle_name not in handles:
warn_args = (handle_name, type(self.plot).__name__,
type(self).__name__)
print('{} handle not found on {}, cannot '
'attach {} callback'.format(*warn_args))
continue
cb_handles.append(handles[handle_name])
handle = handles[handle_name]
if handle_name not in self.extra_handles:
cb_handles.append(handle)
hash_handles.append(handle)

# Hash the plot handle with Callback type allowing multiple
# callbacks on one handle to be merged
handle_ids = [id(h) for h in cb_handles]
cb_hash = tuple(handle_ids)+(id(type(self)),)
hash_ids = [id(h) for h in hash_handles]
cb_hash = tuple(hash_ids)+(id(type(self)),)
if cb_hash in self._callbacks:
# Merge callbacks if another callback has already been attached
cb = self._callbacks[cb_hash]
Expand Down Expand Up @@ -599,11 +608,13 @@ class RangeXYCallback(Callback):

models = ['plot']

extra_handles = ['x_range', 'y_range']

attributes = {
'x0': 'cb_obj.x0',
'y0': 'cb_obj.y0',
'x1': 'cb_obj.x1',
'y1': 'cb_obj.y1'
'y1': 'cb_obj.y1',
}

_js_on_event = """
Expand All @@ -624,6 +635,12 @@ def set_callback(self, handle):
handle.js_on_event('rangesupdate', CustomJS(code=self._js_on_event))

def _process_msg(self, msg):
if self.plot.state.x_range is not self.plot.handles['x_range']:
x_range = self.plot.handles['x_range']
msg['x0'], msg['x1'] = x_range.start, x_range.end
if self.plot.state.y_range is not self.plot.handles['y_range']:
y_range = self.plot.handles['y_range']
msg['y0'], msg['y1'] = y_range.start, y_range.end
data = {}
if 'x0' in msg and 'x1' in msg:
x0, x1 = msg['x0'], msg['x1']
Expand Down Expand Up @@ -657,6 +674,8 @@ class RangeXCallback(RangeXYCallback):

models = ['plot']

extra_handles = ['x_range']

attributes = {
'x0': 'cb_obj.x0',
'x1': 'cb_obj.x1',
Expand All @@ -672,6 +691,8 @@ class RangeYCallback(RangeXYCallback):

models = ['plot']

extra_handles = ['y_range']

attributes = {
'y0': 'cb_obj.y0',
'y1': 'cb_obj.y1'
Expand Down
52 changes: 35 additions & 17 deletions holoviews/plotting/bokeh/element.py
Original file line number Diff line number Diff line change
Expand Up @@ -927,6 +927,7 @@ def _update_grid(self, plot):
def _update_ranges(self, element, ranges):
x_range = self.handles['x_range']
y_range = self.handles['y_range']
plot = self.handles['plot']

self._update_main_ranges(element, x_range, y_range, ranges)

Expand All @@ -938,6 +939,10 @@ def _update_ranges(self, element, ranges):
factors = self._get_dimension_factors(element, ranges, axis_dim)
extra_scale = self.handles[f'extra_{multi_dim}_scales'][axis_dim] # Assumes scales and ranges zip
log = isinstance(extra_scale, LogScale)
range_update = (not (self.model_changed(extra_y_range) or self.model_changed(plot))
and self.framewise)
if self.drawn and not range_update:
continue
self._update_range(
extra_y_range, b, t, factors,
extra_y_range.tags[1]['invert_yaxis'] if extra_y_range.tags else False,
Expand Down Expand Up @@ -1350,28 +1355,19 @@ def _init_glyph(self, plot, mapping, properties):
if 'legend_field' in properties and 'legend_label' in properties:
del properties['legend_label']

# ALERT: This only handles XYGlyph types right now
# and note guard against Field (unhashable) when using FactorRanges
mapping = property_to_dict(mapping)
if 'x' in mapping:
x = mapping['x']
if plot.extra_x_ranges and (x in plot.extra_x_ranges):
properties['x_range_name'] = mapping['x']
if 'y' in mapping:
y = mapping['y']
if plot.extra_y_ranges and (y in plot.extra_y_ranges):
properties['y_range_name'] = mapping['y']
if self.handles['x_range'].name in plot.extra_x_ranges:
properties['x_range_name'] = self.handles['y_range'].name
if self.handles['y_range'].name in plot.extra_y_ranges:
properties['y_range_name'] = self.handles['y_range'].name

if "name" not in properties:
properties["name"] = properties.get("legend_label") or properties.get("legend_field")
renderer = getattr(plot, plot_method)(**dict(properties, **mapping))
return renderer, renderer.glyph


def _element_transform(self, transform, element, ranges):
return transform.apply(element, ranges=ranges, flat=True)


def _apply_transforms(self, element, data, ranges, style, group=None):
new_style = dict(style)
prefix = group+'_' if group else ''
Expand Down Expand Up @@ -1692,6 +1688,29 @@ def _init_glyphs(self, plot, element, ranges, source):
with abbreviated_exception():
self._update_glyph(renderer, properties, mapping, glyph, source, source.data)

def _find_axes(self, plot, element):
"""
Looks up the axes and plot ranges given the plot and an element.
"""
axis_dims = self._get_axis_dims(element)[:2]
if self.invert_axes:
axis_dims[0], axis_dims[1] = axis_dims[::-1]
x, y = axis_dims
if isinstance(x, Dimension) and x.name in plot.extra_x_ranges:
x_range = plot.extra_x_ranges[x.name]
xaxes = [xaxis for xaxis in plot.xaxis if xaxis.x_range_name == x.name]
x_axis = (xaxes if xaxes else plot.xaxis)[0]
else:
x_range = plot.x_range
x_axis = plot.xaxis[0]
if isinstance(y, Dimension) and y.name in plot.extra_y_ranges:
y_range = plot.extra_y_ranges[y.name]
yaxes = [yaxis for yaxis in plot.yaxis if yaxis.y_range_name == y.name]
y_axis = (yaxes if yaxes else plot.yaxis)[0]
else:
y_range = plot.y_range
y_axis = plot.yaxis[0]
return (x_axis, y_axis), (x_range, y_range)

def initialize_plot(self, ranges=None, plot=None, plots=None, source=None):
"""
Expand All @@ -1714,10 +1733,9 @@ def initialize_plot(self, ranges=None, plot=None, plots=None, source=None):
plot = self._init_plot(key, style_element, ranges=ranges, plots=plots)
self._init_axes(plot)
else:
self.handles['xaxis'] = plot.xaxis[0]
self.handles['x_range'] = plot.x_range
self.handles['yaxis'] = plot.yaxis[0]
self.handles['y_range'] = plot.y_range
axes, plot_ranges = self._find_axes(plot, element)
self.handles['xaxis'], self.handles['yaxis'] = axes
self.handles['x_range'], self.handles['y_range'] = plot_ranges
self.handles['plot'] = plot

if self.autorange:
Expand Down
2 changes: 1 addition & 1 deletion holoviews/plotting/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -1699,7 +1699,7 @@ def __init__(self, overlay, ranges=None, batched=True, keys=None, group_counter=

if ('multi_y' in self.param) and self.multi_y:
for s in self.streams:
intersection = set(s.param) & {'y', 'y_selection', 'y_range', 'bounds', 'boundsy'}
intersection = set(s.param) & {'y', 'y_selection', 'bounds', 'boundsy'}
if intersection:
self.param.warning(f'{type(s).__name__} stream parameters'
f' {list(intersection)} not yet supported with multi_y=True')
Expand Down
20 changes: 20 additions & 0 deletions holoviews/tests/plotting/bokeh/test_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,3 +445,23 @@ def test_msg_with_base64_array():

data_expected = np.array([10.0, 20.0, 30.0, 40.0])
assert np.equal(data_expected, data_after).all()


def test_rangexy_multi_yaxes():
c1 = Curve(np.arange(100).cumsum(), vdims='y')
c2 = Curve(-np.arange(100).cumsum(), vdims='y2')
RangeXY(source=c1)
RangeXY(source=c2)

overlay = (c1 * c2).opts(backend='bokeh', multi_y=True)
plot = bokeh_server_renderer.get_plot(overlay)

p1, p2 = plot.subplots.values()

assert plot.state.y_range is p1.handles['y_range']
assert 'y2' in plot.state.extra_y_ranges
assert plot.state.extra_y_ranges['y2'] is p2.handles['y_range']

# Ensure both callbacks are attached
assert p1.callbacks[0].plot is p1
assert p2.callbacks[0].plot is p2
39 changes: 38 additions & 1 deletion holoviews/tests/ui/bokeh/test_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

pytestmark = pytest.mark.ui

from holoviews import Scatter
from holoviews import Curve, Scatter
from holoviews.streams import BoundsXY, Lasso, RangeXY
from holoviews.plotting.bokeh import BokehRenderer
from panel.pane.holoviews import HoloViews
Expand Down Expand Up @@ -124,3 +124,40 @@ def test_rangexy(page, port):
expected_xrange = (0.32844036697247725, 0.8788990825688077)
expected_yrange = (1.8285714285714285, 2.3183673469387758)
wait_until(lambda: rangexy.x_range == expected_xrange and rangexy.y_range == expected_yrange, page)

def test_multi_axis_rangexy(page, port):
c1 = Curve(np.arange(100).cumsum(), vdims='y')
c2 = Curve(-np.arange(100).cumsum(), vdims='y2')
s1 = RangeXY(source=c1)
s2 = RangeXY(source=c2)

overlay = (c1 * c2).opts(backend='bokeh', multi_y=True)

pn_scatter = HoloViews(overlay, renderer=BokehRenderer)

serve(pn_scatter, port=port, threaded=True, show=False)

time.sleep(0.5)

page.goto(f"http://localhost:{port}")

hv_plot = page.locator('.bk-events')

expect(hv_plot).to_have_count(1)

bbox = hv_plot.bounding_box()
hv_plot.click()

page.mouse.move(bbox['x']+100, bbox['y']+100)
page.mouse.down()
page.mouse.move(bbox['x']+150, bbox['y']+150, steps=5)
page.mouse.up()

expected_xrange = (-35.1063829787234, 63.89361702127659)
expected_yrange1 = (717.2448979591848, 6657.244897959185)
expected_yrange2 = (-4232.7551020408155, 1707.2448979591848)
wait_until(lambda: (
s1.x_range == expected_xrange and
s1.y_range == expected_yrange1 and
s2.y_range == expected_yrange2
), page)

0 comments on commit 2ab0ebf

Please sign in to comment.