Skip to content

Commit

Permalink
Add online JupyterChart widget based on AnyWidget (#3119)
Browse files Browse the repository at this point in the history
* Add JupyterChart based on AnyWidget

* Store params in a traitlet object

* Make selections prop a dynamic traitlet class

* Skip vegafusion test when not installed

* Show errors the same was as the HTML renderer

* Update altair/jupyter/js/README.md

* Import JupyterChart in else

* import from top-level

* Move non-widget selection logic to `util.selection`

* Use lodash's debounce for maxWait functionality

---------

Co-authored-by: Mattijn van Hoek <[email protected]>
  • Loading branch information
jonmmease and mattijn authored Aug 2, 2023
1 parent 157c4e3 commit bfd68e4
Show file tree
Hide file tree
Showing 9 changed files with 739 additions and 2 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -74,4 +74,4 @@ Untitled*.ipynb
.vscode

# hatch, doc generation
data.json
data.json
3 changes: 3 additions & 0 deletions altair/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,7 @@
"JoinAggregateFieldDef",
"JoinAggregateTransform",
"JsonDataFormat",
"JupyterChart",
"Key",
"LabelOverlap",
"LatLongDef",
Expand Down Expand Up @@ -569,6 +570,7 @@
"expr",
"graticule",
"hconcat",
"jupyter",
"layer",
"limit_rows",
"load_ipython_extension",
Expand Down Expand Up @@ -607,6 +609,7 @@ def __dir__():


from .vegalite import *
from .jupyter import JupyterChart


def load_ipython_extension(ipython):
Expand Down
20 changes: 20 additions & 0 deletions altair/jupyter/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
try:
import anywidget # noqa: F401
except ImportError:
# When anywidget isn't available, create stand-in JupyterChart class
# that raises an informative import error on construction. This
# way we can make JupyterChart available in the altair namespace
# when anywidget is not installed
class JupyterChart:
def __init__(self, *args, **kwargs):
raise ImportError(
"The Altair JupyterChart requires the anywidget \n"
"Python package which may be installed using pip with\n"
" pip install anywidget\n"
"or using conda with\n"
" conda install -c conda-forge anywidget\n"
"Afterwards, you will need to restart your Python kernel."
)

else:
from .jupyter_chart import JupyterChart # noqa: F401
2 changes: 2 additions & 0 deletions altair/jupyter/js/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# JupyterChart
This directory contains the JavaScript portion of the Altair `JupyterChart`. The `JupyterChart` is based on the [AnyWidget](https://anywidget.dev/) project.
80 changes: 80 additions & 0 deletions altair/jupyter/js/index.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
import embed from "https://cdn.jsdelivr.net/npm/vega-embed@6/+esm";
import { debounce } from "https://cdn.jsdelivr.net/npm/[email protected]/lodash.js"

export async function render({ model, el }) {
let finalize;

function showError(error){
el.innerHTML = (
'<div style="color:red;">'
+ '<p>JavaScript Error: ' + error.message + '</p>'
+ "<p>This usually means there's a typo in your chart specification. "
+ "See the javascript console for the full traceback.</p>"
+ '</div>'
);
}

const reembed = async () => {
if (finalize != null) {
finalize();
}

let spec = model.get("spec");
let api;
try {
api = await embed(el, spec);
} catch (error) {
showError(error)
return;
}

finalize = api.finalize;

// Debounce config
const wait = model.get("debounce_wait") ?? 10;
const maxWait = wait;

const initialSelections = {};
for (const selectionName of Object.keys(model.get("_vl_selections"))) {
const selectionHandler = (_, value) => {
const newSelections = JSON.parse(JSON.stringify(model.get("_vl_selections"))) || {};
const store = JSON.parse(JSON.stringify(api.view.data(`${selectionName}_store`)));

newSelections[selectionName] = {value, store};
model.set("_vl_selections", newSelections);
model.save_changes();
};
api.view.addSignalListener(selectionName, debounce(selectionHandler, wait, {maxWait}));

initialSelections[selectionName] = {value: {}, store: []}
}
model.set("_vl_selections", initialSelections);

const initialParams = {};
for (const paramName of Object.keys(model.get("_params"))) {
const paramHandler = (_, value) => {
const newParams = JSON.parse(JSON.stringify(model.get("_params"))) || {};
newParams[paramName] = value;
model.set("_params", newParams);
model.save_changes();
};
api.view.addSignalListener(paramName, debounce(paramHandler, wait, {maxWait}));

initialParams[paramName] = api.view.signal(paramName) ?? null
}
model.set("_params", initialParams);
model.save_changes();

// Param change callback
model.on('change:_params', async (new_params) => {
for (const [param, value] of Object.entries(new_params.changed._params)) {
api.view.signal(param, value);
}
await api.view.runAsync();
});
}

model.on('change:spec', reembed);
model.on('change:debounce_wait', reembed);
await reembed();
}
238 changes: 238 additions & 0 deletions altair/jupyter/jupyter_chart.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,238 @@
import anywidget
import traitlets
import pathlib
from typing import Any

import altair as alt
from altair.utils._vegafusion_data import using_vegafusion
from altair import TopLevelSpec
from altair.utils.selection import IndexSelection, PointSelection, IntervalSelection

_here = pathlib.Path(__file__).parent


class Params(traitlets.HasTraits):
"""
Traitlet class storing a JupyterChart's params
"""

def __init__(self, trait_values):
super().__init__()

for key, value in trait_values.items():
if isinstance(value, int):
traitlet_type = traitlets.Int()
elif isinstance(value, float):
traitlet_type = traitlets.Float()
elif isinstance(value, str):
traitlet_type = traitlets.Unicode()
elif isinstance(value, list):
traitlet_type = traitlets.List()
elif isinstance(value, dict):
traitlet_type = traitlets.Dict()
else:
raise ValueError(f"Unexpected param type: {type(value)}")

# Add the new trait.
self.add_traits(**{key: traitlet_type})

# Set the trait's value.
setattr(self, key, value)

def __repr__(self):
return f"Params({self.trait_values()})"


class Selections(traitlets.HasTraits):
"""
Traitlet class storing a JupyterChart's selections
"""

def __init__(self, trait_values):
super().__init__()

for key, value in trait_values.items():
if isinstance(value, IndexSelection):
traitlet_type = traitlets.Instance(IndexSelection)
elif isinstance(value, PointSelection):
traitlet_type = traitlets.Instance(PointSelection)
elif isinstance(value, IntervalSelection):
traitlet_type = traitlets.Instance(IntervalSelection)
else:
raise ValueError(f"Unexpected selection type: {type(value)}")

# Add the new trait.
self.add_traits(**{key: traitlet_type})

# Set the trait's value.
setattr(self, key, value)

# Make read-only
self.observe(self._make_read_only, names=key)

def __repr__(self):
return f"Selections({self.trait_values()})"

def _make_read_only(self, change):
"""
Work around to make traits read-only, but still allow us to change
them internally
"""
if change["name"] in self.traits() and change["old"] != change["new"]:
self._set_value(change["name"], change["old"])
raise ValueError(
"Selections may not be set from Python.\n"
f"Attempted to set select: {change['name']}"
)

def _set_value(self, key, value):
self.unobserve(self._make_read_only, names=key)
setattr(self, key, value)
self.observe(self._make_read_only, names=key)


class JupyterChart(anywidget.AnyWidget):
_esm = _here / "js" / "index.js"
_css = r"""
.vega-embed {
/* Make sure action menu isn't cut off */
overflow: visible;
}
"""

# Public traitlets
chart = traitlets.Instance(TopLevelSpec)
spec = traitlets.Dict().tag(sync=True)
debounce_wait = traitlets.Float(default_value=10).tag(sync=True)

# Internal selection traitlets
_selection_types = traitlets.Dict()
_vl_selections = traitlets.Dict().tag(sync=True)

# Internal param traitlets
_params = traitlets.Dict().tag(sync=True)

def __init__(self, chart: TopLevelSpec, debounce_wait: int = 10, **kwargs: Any):
"""
Jupyter Widget for displaying and updating Altair Charts, and
retrieving selection and parameter values
Parameters
----------
chart: Chart
Altair Chart instance
debounce_wait: int
Debouncing wait time in milliseconds
"""
self.params = Params({})
self.selections = Selections({})
super().__init__(chart=chart, debounce_wait=debounce_wait, **kwargs)

@traitlets.observe("chart")
def _on_change_chart(self, change):
"""
Internal callback function that updates the JupyterChart's internal
state when the wrapped Chart instance changes
"""
new_chart = change.new

params = getattr(new_chart, "params", [])
selection_watches = []
selection_types = {}
initial_params = {}
initial_vl_selections = {}
empty_selections = {}

if params is not alt.Undefined:
for param in new_chart.params:
select = getattr(param, "select", alt.Undefined)

if select != alt.Undefined:
if not isinstance(select, dict):
select = select.to_dict()

select_type = select["type"]
if select_type == "point":
if not (
select.get("fields", None) or select.get("encodings", None)
):
# Point selection with no associated fields or encodings specified.
# This is an index-based selection
selection_types[param.name] = "index"
empty_selections[param.name] = IndexSelection(
name=param.name, value=[], store=[]
)
else:
selection_types[param.name] = "point"
empty_selections[param.name] = PointSelection(
name=param.name, value=[], store=[]
)
elif select_type == "interval":
selection_types[param.name] = "interval"
empty_selections[param.name] = IntervalSelection(
name=param.name, value={}, store=[]
)
else:
raise ValueError(f"Unexpected selection type {select.type}")
selection_watches.append(param.name)
initial_vl_selections[param.name] = {"value": None, "store": []}
else:
clean_value = param.value if param.value != alt.Undefined else None
initial_params[param.name] = clean_value

# Setup params
self.params = Params(initial_params)

def on_param_traitlet_changed(param_change):
new_params = dict(self._params)
new_params[param_change["name"]] = param_change["new"]
self._params = new_params

self.params.observe(on_param_traitlet_changed)

# Setup selections
self.selections = Selections(empty_selections)

# Update properties all together
with self.hold_sync():
if using_vegafusion():
self.spec = new_chart.to_dict(format="vega")
else:
self.spec = new_chart.to_dict()
self._selection_types = selection_types
self._vl_selections = initial_vl_selections
self._params = initial_params

@traitlets.observe("_params")
def _on_change_params(self, change):
for param_name, value in change.new.items():
setattr(self.params, param_name, value)

@traitlets.observe("_vl_selections")
def _on_change_selections(self, change):
"""
Internal callback function that updates the JupyterChart's public
selections traitlet in response to changes that the JavaScript logic
makes to the internal _selections traitlet.
"""
for selection_name, selection_dict in change.new.items():
value = selection_dict["value"]
store = selection_dict["store"]
selection_type = self._selection_types[selection_name]
if selection_type == "index":
self.selections._set_value(
selection_name,
IndexSelection.from_vega(selection_name, signal=value, store=store),
)
elif selection_type == "point":
self.selections._set_value(
selection_name,
PointSelection.from_vega(selection_name, signal=value, store=store),
)
elif selection_type == "interval":
self.selections._set_value(
selection_name,
IntervalSelection.from_vega(
selection_name, signal=value, store=store
),
)
Loading

0 comments on commit bfd68e4

Please sign in to comment.