From 87622b5e10a2b63df4bcf41614e82a94cc54a5d6 Mon Sep 17 00:00:00 2001 From: Jon Mease Date: Sat, 7 Nov 2020 11:19:07 -0500 Subject: [PATCH] Add Plotly Dash support (#4605) --- holoviews/core/decollate.py | 4 +- holoviews/plotting/plotly/__init__.py | 1 + holoviews/plotting/plotly/dash.py | 558 +++++++++++++++++++ holoviews/tests/plotting/bokeh/testserver.py | 7 +- holoviews/tests/plotting/plotly/testdash.py | 494 ++++++++++++++++ pytest.ini | 2 + setup.py | 3 +- 7 files changed, 1064 insertions(+), 5 deletions(-) create mode 100644 holoviews/plotting/plotly/dash.py create mode 100644 holoviews/tests/plotting/plotly/testdash.py create mode 100644 pytest.ini diff --git a/holoviews/core/decollate.py b/holoviews/core/decollate.py index d34ecfeeba..56e6219c5c 100644 --- a/holoviews/core/decollate.py +++ b/holoviews/core/decollate.py @@ -3,7 +3,7 @@ from .. import ( Layout, DynamicMap, Element, Callable, Overlay, GridSpace, NdOverlay, HoloMap ) -from . import ViewableTree +from . import ViewableTree, AdjointLayout from collections import namedtuple from ..streams import Stream, Derived @@ -148,7 +148,7 @@ def to_expr_extract_streams( stream_mapping.setdefault(container_key, []).append(cloned_stream) return stream_index - elif isinstance(hvobj, (Layout, GridSpace, NdOverlay, HoloMap, Overlay)): + elif isinstance(hvobj, (Layout, GridSpace, NdOverlay, HoloMap, Overlay, AdjointLayout)): fn = hvobj.clone(data={}).clone args = [] data_expr = [] diff --git a/holoviews/plotting/plotly/__init__.py b/holoviews/plotting/plotly/__init__.py index 0ce6e965cf..a50e1bc512 100644 --- a/holoviews/plotting/plotly/__init__.py +++ b/holoviews/plotting/plotly/__init__.py @@ -89,6 +89,7 @@ Overlay: OverlayPlot, NdOverlay: OverlayPlot, Layout: LayoutPlot, + AdjointLayout: AdjointLayoutPlot, NdLayout: LayoutPlot, GridSpace: GridPlot, GridMatrix: GridPlot}, backend='plotly') diff --git a/holoviews/plotting/plotly/dash.py b/holoviews/plotting/plotly/dash.py new file mode 100644 index 0000000000..1f96134f4c --- /dev/null +++ b/holoviews/plotting/plotly/dash.py @@ -0,0 +1,558 @@ +from __future__ import absolute_import + +# standard library imports +import uuid +import copy +from collections import OrderedDict, namedtuple +import pickle +import base64 + +# Holoviews imports +import holoviews as hv +from holoviews.plotting.plotly import PlotlyRenderer, DynamicMap +from holoviews.plotting.plotly.util import clean_internal_figure_properties +from holoviews.core.decollate import ( + initialize_dynamic, to_expr_extract_streams, expr_to_fn_of_stream_contents +) +from holoviews.streams import Derived, History +from holoviews.plotting.plotly.callbacks import ( + Selection1DCallback, RangeXYCallback, RangeXCallback, RangeYCallback, + BoundsXYCallback, BoundsXCallback, BoundsYCallback +) + +# Dash imports +import dash_core_components as dcc +import dash_html_components as html +from dash import callback_context +from dash.dependencies import Output, Input, State + +# plotly.py imports +import plotly.graph_objects as go + +# Activate plotly as current HoloViews extension +hv.extension("plotly") + + +# Named tuples definitions +StreamCallback = namedtuple("StreamCallback", ["input_ids", "fn", "output_id"]) +DashComponents = namedtuple( + "DashComponents", ["graphs", "kdims", "store", "resets", "children"] +) +HoloViewsFunctionSpec = namedtuple("HoloViewsFunctionSpec", ["fn", "kdims", "streams"]) + + +def plot_to_figure(plot, reset_nclicks=0): + """ + Convert a HoloViews plotly plot to a plotly.py Figure. + + Args: + plot: A HoloViews plotly plot object + reset_nclicks: Number of times a reset button associated with the plot has been + clicked + + Returns: + A plotly.py Figure + """ + fig_dict = plot.state + clean_internal_figure_properties(fig_dict) + + # Enable uirevision to preserve user-interaction state + # Don't use reset_nclicks directly because 0 is treated as no revision + fig_dict['layout']['uirevision'] = "reset-" + str(reset_nclicks) + + # Remove range specification so plotly.js autorange + uirevision is in control + for k in fig_dict['layout']: + if k.startswith('xaxis') or k.startswith('yaxis'): + fig_dict['layout'][k].pop('range', None) + + # Remove figure width height, let container decide + fig_dict['layout'].pop('width', None) + fig_dict['layout'].pop('height', None) + + # Pass to figure constructor to expand magic underscore notation + return go.Figure(fig_dict) + + +def to_function_spec(hvobj): + """ + Convert Dynamic HoloViews object into a pure function that accepts kdim values + and stream contents as positional arguments. + + This borrows the low-level holoviews decollate logic, but instead of returning + DynamicMap with cloned streams, returns a HoloViewsFunctionSpec. + + Args: + hvobj: A potentially dynamic Holoviews object + + Returns: + HoloViewsFunctionSpec + """ + kdims_list = [] + original_streams = [] + streams = [] + stream_mapping = {} + initialize_dynamic(hvobj) + expr = to_expr_extract_streams( + hvobj, kdims_list, streams, original_streams, stream_mapping + ) + expr_fn = expr_to_fn_of_stream_contents(expr, nkdims=len(kdims_list)) + + # Check for unbounded dimensions + if isinstance(hvobj, DynamicMap) and hvobj.unbounded: + dims = ', '.join('%r' % dim for dim in hvobj.unbounded) + msg = ('DynamicMap cannot be displayed without explicit indexing ' + 'as {dims} dimension(s) are unbounded. ' + '\nSet dimensions bounds with the DynamicMap redim.range ' + 'or redim.values methods.') + raise ValueError(msg.format(dims=dims)) + + # Build mapping from kdims to values/range + dimensions_dict = {d.name: d for d in hvobj.dimensions()} + kdims = OrderedDict() + for k in kdims_list: + dim = dimensions_dict[k.name] + label = dim.label or dim.name + kdims[k.name] = label, dim.values or dim.range + + return HoloViewsFunctionSpec(fn=expr_fn, kdims=kdims, streams=original_streams) + + +def populate_store_with_stream_contents( + store_data, streams +): + """ + Add contents of streams to the store dictionary + + Args: + store_data: The store dictionary + streams: List of streams whose contents should be added to the store + + Returns: + None + """ + for stream in streams: + # Add stream + store_data["streams"][id(stream)] = copy.deepcopy(stream.contents) + if isinstance(stream, Derived): + populate_store_with_stream_contents(store_data, stream.input_streams) + elif isinstance(stream, History): + populate_store_with_stream_contents(store_data, [stream.input_stream]) + + +def build_derived_callback(derived_stream): + """ + Build StreamCallback for Derived stream + + Args: + derived_stream: A Derived stream + + Returns: + StreamCallback + """ + input_ids = [id(stream) for stream in derived_stream.input_streams] + constants = copy.copy(derived_stream.constants) + transform = derived_stream.transform_function + + def derived_callback(*stream_values): + return transform(stream_values=stream_values, constants=constants) + + return StreamCallback( + input_ids=input_ids, fn=derived_callback, output_id=id(derived_stream) + ) + + +def build_history_callback(history_stream): + """ + Build StreamCallback for History stream + + Args: + history_stream: A History stream + + Returns: + StreamCallback + """ + history_id = id(history_stream) + input_stream_id = id(history_stream.input_stream) + + def history_callback(prior_value, input_value): + new_value = copy.deepcopy(prior_value) + new_value["values"].append(input_value) + return new_value + + return StreamCallback( + input_ids=[history_id, input_stream_id], + fn=history_callback, + output_id=history_id + ) + + +def populate_stream_callback_graph(stream_callbacks, streams): + """ + Populate the stream_callbacks OrderedDict with StreamCallback instances + associated with all of the History and Derived streams in input stream list. + + Input streams to any History or Derived streams are processed recursively + + Args: + stream_callbacks: OrderedDict from id(stream) to StreamCallbacks the should + be populated. Order will be a breadth-first traversal of the provided + streams list, and any input streams that these depend on. + + streams: List of streams to build StreamCallbacks from + + Returns: + None + """ + for stream in streams: + if isinstance(stream, Derived): + cb = build_derived_callback(stream) + if cb.output_id not in stream_callbacks: + stream_callbacks[cb.output_id] = cb + populate_stream_callback_graph(stream_callbacks, stream.input_streams) + elif isinstance(stream, History): + cb = build_history_callback(stream) + if cb.output_id not in stream_callbacks: + stream_callbacks[cb.output_id] = cb + populate_stream_callback_graph(stream_callbacks, [stream.input_stream]) + + +def encode_store_data(store_data): + """ + Encode store_data dict into a JSON serializable dict + + This is currently done by pickling store_data and converting to a base64 encoded + string. If HoloViews supports JSON serialization in the future, this method could + be updated to use this approach instead + + Args: + store_data: dict potentially containing HoloViews objects + + Returns: + dict that can be JSON serialized + """ + return {"pickled": base64.b64encode(pickle.dumps(store_data)).decode("utf-8")} + + +def decode_store_data(store_data): + """ + Decode a dict that was encoded by the encode_store_data function. + + Args: + store_data: dict that was encoded by encode_store_data + + Returns: + decoded dict + """ + return pickle.loads(base64.b64decode(store_data["pickled"])) + + +def to_dash(app, hvobjs, reset_button=False, graph_class=dcc.Graph): + """ + Build Dash components and callbacks from a collection of HoloViews objects + + Args: + app: dash.Dash application instance + hvobjs: List of HoloViews objects to build Dash components from + reset_button: If True, construct a Button component that, which clicked, will + reset the interactive stream values associated with the provided HoloViews + objects to their initial values. Defaults to False. + graph_class: Class to use when creating Graph components, one of dcc.Graph + (default) or ddk.Graph. + + Returns: + DashComponents named tuple with properties: + - graphs: List of graph components (with type matching the input + graph_class argument) with order corresponding to the order + of the input hvobjs list. + - resets: List of reset buttons that can be used to reset figure state. + List has length 1 if reset_button=True and is empty if + reset_button=False. + - kdims: Dict from kdim names to Dash Components that can be used to + set the corresponding kdim value. + - store: dcc.Store the must be included in the app layout + - children: Single list of all components above. The order is graphs, + kdims, resets, and then the store. + """ + # Number of figures + num_figs = len(hvobjs) + + # Initialize component properties + reset_components = [] + graph_components = [] + kdim_components = {} + + # Initialize inputs / outputs / states list + outputs = [] + inputs = [] + states = [] + + # Initialize other + plots = [] + graph_ids = [] + initial_fig_dicts = [] + all_kdims = OrderedDict() + kdims_per_fig = [] + + # Initialize stream mappings + uid_to_stream_ids = {} + fig_to_fn_stream = {} + fig_to_fn_stream_ids = {} + + # Plotly stream types + plotly_stream_types = [ + RangeXYCallback, RangeXCallback, RangeYCallback, Selection1DCallback, + BoundsXYCallback, BoundsXCallback, BoundsYCallback + ] + + for i, hvobj in enumerate(hvobjs): + + fn_spec = to_function_spec(hvobj) + + fig_to_fn_stream[i] = fn_spec + kdims_per_fig.append(list(fn_spec.kdims)) + all_kdims.update(fn_spec.kdims) + + # Convert to figure once so that we can map streams to axes + plot = PlotlyRenderer.get_plot(hvobj) + plots.append(plot) + + fig = plot_to_figure(plot, reset_nclicks=0).to_dict() + initial_fig_dicts.append(fig) + + # Build graphs + graph_id = 'graph-' + str(uuid.uuid4()) + graph_ids.append(graph_id) + graph = graph_class( + id=graph_id, + figure=fig, + config={"scrollZoom": True} + ) + graph_components.append(graph) + + # Build dict from trace uid to plotly callback object + plotly_streams = {} + for plotly_stream_type in plotly_stream_types: + for t in fig["data"]: + if t.get("uid", None) in plotly_stream_type.instances: + plotly_streams.setdefault(plotly_stream_type, {})[t["uid"]] = \ + plotly_stream_type.instances[t["uid"]] + + # Build dict from trace uid to list of connected HoloViews streams + for plotly_stream_type, streams_for_type in plotly_streams.items(): + for uid, cb in streams_for_type.items(): + uid_to_stream_ids.setdefault( + plotly_stream_type, {} + ).setdefault(uid, []).extend( + [id(stream) for stream in cb.streams] + ) + + outputs.append(Output(component_id=graph_id, component_property='figure')) + inputs.extend([ + Input(component_id=graph_id, component_property='selectedData'), + Input(component_id=graph_id, component_property='relayoutData') + ]) + + # Build Store and State list + store_data = {"streams": {}} + store_id = 'store-' + str(uuid.uuid4()) + states.append(State(store_id, 'data')) + + # Store holds mapping from id(stream) -> stream.contents for: + # - All extracted streams (including derived) + # - All input streams for History and Derived streams. + for fn_spec in fig_to_fn_stream.values(): + populate_store_with_stream_contents(store_data, fn_spec.streams) + + # Initialize empty list of (input_ids, output_id, fn) triples. For each + # Derived/History stream, prepend list with triple. Process in + # breadth-first order so all inputs to a triple are guaranteed to be earlier + # in the list. History streams will input and output their own id, which is + # fine. + stream_callbacks = OrderedDict() + for fn_spec in fig_to_fn_stream.values(): + populate_stream_callback_graph(stream_callbacks, fn_spec.streams) + + # For each Figure function, save off list of ids for the streams whose contents + # should be passed to the function. + for i, fn_spec in fig_to_fn_stream.items(): + fig_to_fn_stream_ids[i] = fn_spec.fn, [id(stream) for stream in fn_spec.streams] + + # Add store output + store = dcc.Store( + id=store_id, + data=encode_store_data(store_data), + ) + outputs.append(Output(store_id, 'data')) + + # Save copy of initial stream contents + initial_stream_contents = copy.deepcopy(store_data["streams"]) + + # Add kdim sliders + kdim_uuids = [] + for kdim_name, (kdim_label, kdim_range) in all_kdims.items(): + slider_uuid = str(uuid.uuid4()) + slider_id = kdim_name + "-" + slider_uuid + slider_label_id = kdim_name + "-label-" + slider_uuid + kdim_uuids.append(slider_uuid) + + html_label = html.Label(id=slider_label_id, children=kdim_label) + if isinstance(kdim_range, list): + # list of slider values + slider = html.Div(children=[ + html_label, + dcc.Slider( + id=slider_id, + min=kdim_range[0], + max=kdim_range[-1], + step=None, + marks={ + m: "" for m in kdim_range + }, + value=kdim_range[0] + )]) + else: + # Range of slider values + slider = html.Div(children=[ + html_label, + dcc.Slider( + id=slider_id, + min=kdim_range[0], + max=kdim_range[-1], + step=(kdim_range[-1] - kdim_range[0]) / 11.0, + value=kdim_range[0] + )]) + kdim_components[kdim_name] = slider + inputs.append(Input(component_id=slider_id, component_property="value")) + + # Add reset button + if reset_button: + reset_id = 'reset-' + str(uuid.uuid4()) + reset_button = html.Button(id=reset_id, children="Reset") + inputs.append(Input( + component_id=reset_id, component_property='n_clicks' + )) + reset_components.append(reset_button) + + # Register Graphs/Store callback + @app.callback( + outputs, inputs, states + ) + def update_figure(*args): + triggered_prop_ids = {entry["prop_id"] for entry in callback_context.triggered} + + # Unpack args + selected_dicts = [args[j] or {} for j in range(0, num_figs * 2, 2)] + relayout_dicts = [args[j] or {} for j in range(1, num_figs * 2, 2)] + + # Get kdim values + kdim_values = {} + for i, kdim in zip( + range(num_figs * 2, num_figs * 2 + len(all_kdims)), + all_kdims + ): + kdim_values[kdim] = args[i] + + # Get store + store_data = decode_store_data(args[-1]) + reset_nclicks = 0 + if reset_button: + reset_nclicks = args[-2] or 0 + prior_reset_nclicks = store_data.get("reset_nclicks", 0) + if reset_nclicks != prior_reset_nclicks: + store_data["reset_nclicks"] = reset_nclicks + + # clear stream values + store_data["streams"] = copy.deepcopy(initial_stream_contents) + selected_dicts = [None for _ in selected_dicts] + relayout_dicts = [None for _ in relayout_dicts] + + # Init store data + if store_data is None: + store_data = {"streams": {}} + + # Update store_data with interactive stream values + for fig_ind, fig_dict in enumerate(initial_fig_dicts): + graph_id = graph_ids[fig_ind] + # plotly_stream_types + for plotly_stream_type, uid_to_streams_for_type in uid_to_stream_ids.items(): + panel_prop = plotly_stream_type.callback_property + if panel_prop == "selected_data": + if graph_id + ".selectedData" in triggered_prop_ids: + # Only update selectedData values that just changed. + # This way we don't the the may have been cleared in the + # store above + stream_event_data = plotly_stream_type.get_event_data_from_property_update( + selected_dicts[fig_ind], initial_fig_dicts[fig_ind] + ) + for uid, event_data in stream_event_data.items(): + if uid in uid_to_streams_for_type: + for stream_id in uid_to_streams_for_type[uid]: + store_data["streams"][stream_id] = event_data + elif panel_prop == "viewport": + if graph_id + ".relayoutData" in triggered_prop_ids: + stream_event_data = plotly_stream_type.get_event_data_from_property_update( + relayout_dicts[fig_ind], initial_fig_dicts[fig_ind] + ) + + for uid, event_data in stream_event_data.items(): + if event_data["x_range"] is not None or event_data["y_range"] is not None: + if uid in uid_to_streams_for_type: + for stream_id in uid_to_streams_for_type[uid]: + store_data["streams"][ + stream_id] = event_data + + # Update store with derived/history stream values + for output_id in reversed(stream_callbacks): + stream_callback = stream_callbacks[output_id] + input_ids = stream_callback.input_ids + fn = stream_callback.fn + output_id = stream_callback.output_id + + input_values = [store_data["streams"][input_id] for input_id in input_ids] + output_value = fn(*input_values) + store_data["streams"][output_id] = output_value + + figs = [None] * num_figs + for fig_ind, (fn, stream_ids) in fig_to_fn_stream_ids.items(): + fig_kdim_values = [kdim_values[kd] for kd in kdims_per_fig[fig_ind]] + stream_values = [ + store_data["streams"][stream_id] for stream_id in stream_ids + ] + hvobj = fn(*(fig_kdim_values + stream_values)) + plot = PlotlyRenderer.get_plot(hvobj) + fig = plot_to_figure(plot, reset_nclicks=reset_nclicks).to_dict() + figs[fig_ind] = fig + + return figs + [encode_store_data(store_data)] + + # Register key dimension slider callbacks + # Install callbacks to update kdim labels based on slider values + for i, kdim_name in enumerate(all_kdims): + kdim_label = all_kdims[kdim_name][0] + kdim_slider_id = kdim_name + "-" + kdim_uuids[i] + kdim_label_id = kdim_name + "-label-" + kdim_uuids[i] + + @app.callback( + Output(component_id=kdim_label_id, component_property="children"), + [Input(component_id=kdim_slider_id, component_property="value")] + ) + def update_kdim_label(value, kdim_label=kdim_label): + return "{kdim_label}: {value:.2f}".format( + kdim_label=kdim_label, value=value + ) + + # Collect Dash components into DashComponents namedtuple + components = DashComponents( + graphs=graph_components, + kdims=kdim_components, + resets=reset_components, + store=store, + children=( + graph_components + + list(kdim_components.values()) + + reset_components + + [store] + ) + ) + + return components diff --git a/holoviews/tests/plotting/bokeh/testserver.py b/holoviews/tests/plotting/bokeh/testserver.py index b675e0ec4c..c85f7f8998 100644 --- a/holoviews/tests/plotting/bokeh/testserver.py +++ b/holoviews/tests/plotting/bokeh/testserver.py @@ -18,7 +18,7 @@ from bokeh.application import Application from bokeh.client import pull_session from bokeh.document import Document - from bokeh.io import curdoc + from bokeh.io.doc import curdoc, set_curdoc from bokeh.models import ColumnDataSource from bokeh.server.server import Server @@ -41,6 +41,8 @@ def setUp(self): if not bokeh_renderer: raise SkipTest("Bokeh required to test plot instantiation") Store.current_backend = 'bokeh' + self.doc = curdoc() + set_curdoc(Document()) self.nbcontext = Renderer.notebook_context with param.logging_level('ERROR'): Renderer.notebook_context = False @@ -53,6 +55,7 @@ def tearDown(self): Renderer.notebook_context = self.nbcontext state.curdoc = None curdoc().clear() + set_curdoc(self.doc) time.sleep(1) def test_render_server_doc_element(self): @@ -97,7 +100,7 @@ def test_set_up_linked_event_stream_on_server_doc(self): -class TestBokehServerRun(ComparisonTestCase): +class TestBokehServer(ComparisonTestCase): def setUp(self): self.previous_backend = Store.current_backend diff --git a/holoviews/tests/plotting/plotly/testdash.py b/holoviews/tests/plotting/plotly/testdash.py new file mode 100644 index 0000000000..d90f0cc8e3 --- /dev/null +++ b/holoviews/tests/plotting/plotly/testdash.py @@ -0,0 +1,494 @@ +from dash._callback_context import CallbackContext + +from .testplot import TestPlotlyPlot +from holoviews.plotting.plotly.dash import ( + to_dash, DashComponents, encode_store_data, decode_store_data +) +from holoviews import Scatter, DynamicMap, Bounds +from holoviews.streams import BoundsXY, RangeXY, Selection1D +from dash_core_components import Store +import plotly.io as pio +pio.templates.default = None + +try: + from unittest.mock import MagicMock, patch +except: + from mock import MagicMock, patch + + +class TestHoloViewsDash(TestPlotlyPlot): + + def setUp(self): + super(TestHoloViewsDash, self).setUp() + + # Build Dash app mock + self.app = MagicMock() + self.decorator = MagicMock() + self.app.callback.return_value = self.decorator + + def test_simple_element(self): + # Build Holoviews Elements + scatter = Scatter([0, 0]) + + # Convert to Dash + components = to_dash(self.app, [scatter]) + + # Check returned components + self.assertIsInstance(components, DashComponents) + self.assertEqual(len(components.graphs), 1) + self.assertEqual(len(components.kdims), 0) + self.assertIsInstance(components.store, Store) + self.assertEqual(len(components.resets), 0) + + callback_fn = self.app.callback.return_value.call_args[0][0] + + # Check registered callbacks + self.assertEqual(self.app.callback.call_count, 1) + self.assertEqual(self.decorator.call_count, 1) + + store_value = encode_store_data({}) + + with patch.object(CallbackContext, "triggered", []): + [fig, new_store] = callback_fn({}, store_value) + + # Check figure returned by callback + self.assertEqual(len(fig["data"]), 1) + self.assertEqual(fig["data"][0]["type"], "scatter") + + def test_boundsxy_dynamic_map(self): + # Build Holoviews Elements + scatter = Scatter([0, 0]) + boundsxy = BoundsXY(source=scatter) + dmap = DynamicMap( + lambda bounds: Bounds(bounds) if bounds is not None else Bounds((0, 0, 0, 0)), + streams=[boundsxy] + ) + + # Convert to Dash + components = to_dash(self.app, [scatter, dmap], reset_button=True) + + # Check returned components + self.assertIsInstance(components, DashComponents) + self.assertEqual(len(components.graphs), 2) + self.assertEqual(len(components.kdims), 0) + self.assertIsInstance(components.store, Store) + self.assertEqual(len(components.resets), 1) + + # Get arguments passed to @app.callback decorator + decorator_args = list(self.app.callback.call_args_list[0])[0] + outputs, inputs, states = decorator_args + + # Check outputs + expected_outputs = [(g.id, "figure") for g in components.graphs] + \ + [(components.store.id, "data")] + self.assertEqual( + [(output.component_id, output.component_property) for output in outputs], + expected_outputs + ) + + # Check inputs + expected_inputs = [ + (g.id, prop) + for g in components.graphs + for prop in ["selectedData", "relayoutData"] + ] + [(components.resets[0].id, "n_clicks")] + + self.assertEqual( + [(ip.component_id, ip.component_property) for ip in inputs], + expected_inputs, + ) + + # Check State + expected_state = [ + (components.store.id, "data") + ] + self.assertEqual( + [(state.component_id, state.component_property) for state in states], + expected_state, + ) + + # Get callback function + callback_fn = self.app.callback.return_value.call_args[0][0] + + # mimic initial callback invocation + store_value = encode_store_data({ + "streams": {id(boundsxy): boundsxy.contents} + }) + with patch.object(CallbackContext, "triggered", []): + [fig1, fig2, new_store] = callback_fn( + {}, {}, {}, {}, None, store_value + ) + # First figure is the scatter trace + self.assertEqual(fig1["data"][0]["type"], "scatter") + + # Second figure holds the bounds element + self.assertEqual(len(fig2["data"]), 0) + self.assertEqual(len(fig2["layout"]["shapes"]), 1) + self.assertEqual( + fig2["layout"]["shapes"][0]["path"], + "M0 0L0 0L0 0L0 0L0 0Z" + ) + + # Check updated store + self.assertEqual( + decode_store_data(new_store), + {"streams": {id(boundsxy): {"bounds": None}}} + ) + + # Update store, then mimick a box selection on scatter figure + store_value = new_store + with patch.object( + CallbackContext, "triggered", + [{"prop_id": inputs[0].component_id + ".selectedData"}] + ): + [fig1, fig2, new_store] = callback_fn( + {"range": {"x": [1, 2], "y": [3, 4]}}, + {}, {}, {}, 0, store_value + ) + + # First figure is the scatter trace + self.assertEqual(fig1["data"][0]["type"], "scatter") + + # Second figure holds the bounds element + self.assertEqual(len(fig2["data"]), 0) + self.assertEqual(len(fig2["layout"]["shapes"]), 1) + self.assertEqual( + fig2["layout"]["shapes"][0]["path"], + "M1 3L1 4L2 4L2 3L1 3Z", + ) + + # Check that store was updated + self.assertEqual( + decode_store_data(new_store), + {"streams": {id(boundsxy): {"bounds": (1, 3, 2, 4)}}} + ) + + # Click reset button + with patch.object( + CallbackContext, "triggered", + [{"prop_id": components.resets[0].id + ".n_clicks"}] + ): + [fig1, fig2, new_store] = callback_fn( + {"range": {"x": [1, 2], "y": [3, 4]}}, {}, + {}, {}, 1, + store_value + ) + + # First figure is the scatter trace + self.assertEqual(fig1["data"][0]["type"], "scatter") + + # Second figure holds reset bounds elemnt + self.assertEqual(len(fig2["data"]), 0) + self.assertEqual(len(fig2["layout"]["shapes"]), 1) + self.assertEqual( + fig2["layout"]["shapes"][0]["path"], + "M0 0L0 0L0 0L0 0L0 0Z" + ) + + # Reset button should clear bounds in store + self.assertEqual( + decode_store_data(new_store), + {"streams": {id(boundsxy): {"bounds": None}}, + "reset_nclicks": 1} + ) + + def test_rangexy_dynamic_map(self): + + # Create dynamic map that inputs rangexy, returns scatter on bounds + scatter = Scatter( + [[0, 1], [0, 1]], kdims=["x"], vdims=["y"] + ) + rangexy = RangeXY(source=scatter) + + def dmap_fn(x_range, y_range): + x_range = (0, 1) if x_range is None else x_range + y_range = (0, 1) if y_range is None else y_range + return Scatter( + [[x_range[0], y_range[0]], + [x_range[1], y_range[1]]], kdims=["x1"], vdims=["y1"] + ) + + dmap = DynamicMap(dmap_fn, streams=[rangexy]) + + # Convert to Dash + components = to_dash(self.app, [scatter, dmap], reset_button=True) + + # Check returned components + self.assertIsInstance(components, DashComponents) + self.assertEqual(len(components.graphs), 2) + self.assertEqual(len(components.kdims), 0) + self.assertIsInstance(components.store, Store) + self.assertEqual(len(components.resets), 1) + + # Get arguments passed to @app.callback decorator + decorator_args = list(self.app.callback.call_args_list[0])[0] + outputs, inputs, states = decorator_args + + # Check outputs + expected_outputs = [(g.id, "figure") for g in components.graphs] + \ + [(components.store.id, "data")] + self.assertEqual( + [(output.component_id, output.component_property) for output in outputs], + expected_outputs + ) + + # Check inputs + expected_inputs = [ + (g.id, prop) + for g in components.graphs + for prop in ["selectedData", "relayoutData"] + ] + [(components.resets[0].id, "n_clicks")] + + self.assertEqual( + [(ip.component_id, ip.component_property) for ip in inputs], + expected_inputs, + ) + + # Check State + expected_state = [ + (components.store.id, "data") + ] + self.assertEqual( + [(state.component_id, state.component_property) for state in states], + expected_state, + ) + + # Get callback function + callback_fn = self.app.callback.return_value.call_args[0][0] + + # mimic initial callback invocation + store_value = encode_store_data({ + "streams": {id(rangexy): rangexy.contents} + }) + with patch.object( + CallbackContext, + "triggered", + [{"prop_id": components.graphs[0].id + ".relayoutData"}] + ): + [fig1, fig2, new_store] = callback_fn( + {}, { + "xaxis.range[0]": 1, + "xaxis.range[1]": 3, + "yaxis.range[0]": 2, + "yaxis.range[1]": 4 + }, + {}, {}, None, store_value + ) + + # First figure is the scatter trace + self.assertEqual(fig1["data"][0]["type"], "scatter") + + # Second figure holds the bounds element + self.assertEqual(len(fig2["data"]), 1) + self.assertEqual(list(fig2["data"][0]["x"]), [1, 3]) + self.assertEqual(list(fig2["data"][0]["y"]), [2, 4]) + + # Check updated store + self.assertEqual( + decode_store_data(new_store), + {"streams": {id(rangexy): {'x_range': (1, 3), 'y_range': (2, 4)}}} + ) + + def test_selection1d_dynamic_map(self): + # Create dynamic map that inputs selection1d, returns overlay of scatter on + # selected points + scatter = Scatter([[0, 0], [1, 1], [2, 2]]) + selection1d = Selection1D(source=scatter) + dmap = DynamicMap( + lambda index: scatter.iloc[index].opts(size=len(index) + 1), + streams=[selection1d] + ) + + # Convert to Dash + components = to_dash(self.app, [scatter, dmap], reset_button=True) + + # Check returned components + self.assertIsInstance(components, DashComponents) + self.assertEqual(len(components.graphs), 2) + self.assertEqual(len(components.kdims), 0) + self.assertIsInstance(components.store, Store) + self.assertEqual(len(components.resets), 1) + + # Get arguments passed to @app.callback decorator + decorator_args = list(self.app.callback.call_args_list[0])[0] + outputs, inputs, states = decorator_args + + # Check outputs + expected_outputs = [(g.id, "figure") for g in components.graphs] + \ + [(components.store.id, "data")] + self.assertEqual( + [(output.component_id, output.component_property) for output in outputs], + expected_outputs + ) + + # Check inputs + expected_inputs = [ + (g.id, prop) + for g in components.graphs + for prop in ["selectedData", "relayoutData"] + ] + [(components.resets[0].id, "n_clicks")] + + self.assertEqual( + [(ip.component_id, ip.component_property) for ip in inputs], + expected_inputs, + ) + + # Check State + expected_state = [ + (components.store.id, "data") + ] + self.assertEqual( + [(state.component_id, state.component_property) for state in states], + expected_state, + ) + + # Get callback function + callback_fn = self.app.callback.return_value.call_args[0][0] + + # mimic initial callback invocation + store_value = encode_store_data({ + "streams": {id(selection1d): selection1d.contents} + }) + with patch.object(CallbackContext, "triggered", []): + [fig1, fig2, new_store] = callback_fn( + {}, {}, None, store_value + ) + + # Figure holds the scatter trace + self.assertEqual(len(fig2["data"]), 1) + + # Check expected marker size + self.assertEqual(fig2["data"][0]["marker"]["size"], 1) + self.assertEqual(list(fig2["data"][0]["x"]), []) + self.assertEqual(list(fig2["data"][0]["y"]), []) + + # Check updated store + self.assertEqual( + decode_store_data(new_store), + {"streams": {id(selection1d): {"index": []}}} + ) + + # Update store, then mimick a selection on scatter figure + store_value = new_store + with patch.object( + CallbackContext, "triggered", + [{"prop_id": inputs[0].component_id + ".selectedData"}] + ): + [fig1, fig2, new_store] = callback_fn( + {"points": [ + { + "curveNumber": 0, + "pointNumber": 0, + "pointIndex": 0, + }, + { + "curveNumber": 0, + "pointNumber": 2, + "pointIndex": 2, + } + ]}, + {}, 0, store_value + ) + + # Figure holds the scatter trace + self.assertEqual(len(fig2["data"]), 1) + + # Check expected marker size + self.assertEqual(fig2["data"][0]["marker"]["size"], 3) + self.assertEqual(list(fig2["data"][0]["x"]), [0, 2]) + self.assertEqual(list(fig2["data"][0]["y"]), [0, 2]) + + # Check that store was updated + self.assertEqual( + decode_store_data(new_store), + {"streams": {id(selection1d): {"index": [0, 2]}}} + ) + + # Click reset button + store = new_store + with patch.object( + CallbackContext, "triggered", + [{"prop_id": components.resets[0].id + ".n_clicks"}] + ): + [fig1, fig2, new_store] = callback_fn( + {}, {}, 1, store + ) + + # Figure holds the scatter trace + self.assertEqual(len(fig2["data"]), 1) + + # Check expected marker size + self.assertEqual(fig2["data"][0]["marker"]["size"], 1) + self.assertEqual(list(fig2["data"][0]["x"]), []) + self.assertEqual(list(fig2["data"][0]["y"]), []) + + # Check that store was updated + self.assertEqual( + decode_store_data(new_store), + {"streams": {id(selection1d): {"index": []}}, 'reset_nclicks': 1}, + ) + + def test_kdims_dynamic_map(self): + # Dynamic map with two key dimensions + dmap = DynamicMap( + lambda kdim1: Scatter([kdim1, kdim1]), + kdims=["kdim1"] + ).redim.values(kdim1=[1, 2, 3, 4]) + + # Convert to Dash + components = to_dash(self.app, [dmap]) + + # Check returned components + self.assertIsInstance(components, DashComponents) + self.assertEqual(len(components.graphs), 1) + self.assertEqual(len(components.kdims), 1) + self.assertIsInstance(components.store, Store) + self.assertEqual(len(components.resets), 0) + + # Get arguments passed to @app.callback decorator + decorator_args = list(self.app.callback.call_args_list[0])[0] + outputs, inputs, states = decorator_args + + # Check outputs + expected_outputs = [(g.id, "figure") for g in components.graphs] + \ + [(components.store.id, "data")] + self.assertEqual( + [(output.component_id, output.component_property) for output in outputs], + expected_outputs + ) + + # Check inputs + expected_inputs = [ + (g.id, prop) + for g in components.graphs + for prop in ["selectedData", "relayoutData"] + ] + [(list(components.kdims.values())[0].children[1].id, 'value')] + + self.assertEqual( + [(ip.component_id, ip.component_property) for ip in inputs], + expected_inputs, + ) + + # Check State + expected_state = [ + (components.store.id, "data") + ] + self.assertEqual( + [(state.component_id, state.component_property) for state in states], + expected_state, + ) + + # Get callback function + callback_fn = self.decorator.call_args_list[0][0][0] + + # mimic initial callback invocation + store_value = encode_store_data({"streams": {}}) + with patch.object(CallbackContext, "triggered", []): + [fig, new_store] = callback_fn( + {}, {}, 3, None, store_value + ) + + # First figure is the scatter trace + self.assertEqual(fig["data"][0]["type"], "scatter") + self.assertEqual(list(fig["data"][0]["x"]), [0, 1]) + self.assertEqual(list(fig["data"][0]["y"]), [3, 3]) diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 0000000000..afe715968e --- /dev/null +++ b/pytest.ini @@ -0,0 +1,2 @@ +[pytest] +addopts = -p no:dash diff --git a/setup.py b/setup.py index c538adcf17..2011a943b9 100644 --- a/setup.py +++ b/setup.py @@ -38,6 +38,7 @@ 'pillow', 'xarray >=0.10.4', 'plotly >=4.0', + 'dash >=1.16', 'streamz >=0.5.0', 'datashader', 'ffmpeg', @@ -67,7 +68,7 @@ 'mock', 'flake8 ==3.6.0', 'coveralls', - 'path.py', + 'path.py', 'matplotlib >=2.2,<3.1', 'nbsmoke >=0.2.0', 'pytest-cov ==2.5.1',