Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Embed widget state in notebook on execute #900

Merged
merged 16 commits into from
Mar 27, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 59 additions & 2 deletions nbconvert/preprocessors/execute.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

# Copyright (c) IPython Development Team.
# Distributed under the terms of the Modified BSD License.

import base64
from textwrap import dedent
from contextlib import contextmanager

Expand Down Expand Up @@ -172,6 +172,15 @@ class ExecutePreprocessor(Preprocessor):
)
).tag(config=True)

store_widget_state = Bool(True,
help=dedent(
"""
If `True` (default), then the state of the Jupyter widgets created
at the kernel will be stored in the metadata of the notebook.
"""
)
).tag(config=True)

iopub_timeout = Integer(4, allow_none=False,
help=dedent(
"""
Expand Down Expand Up @@ -292,6 +301,8 @@ def setup_preprocessor(self, nb, resources, km=None):
self.nb = nb
# clear display_id map
self._display_id_map = {}
self.widget_state = {}
self.widget_buffers = {}

if km is None:
self.km, self.kc = self.start_new_kernel(cwd=path)
Expand Down Expand Up @@ -354,9 +365,27 @@ def preprocess(self, nb, resources, km=None):
nb, resources = super(ExecutePreprocessor, self).preprocess(nb, resources)
info_msg = self._wait_for_reply(self.kc.kernel_info())
nb.metadata['language_info'] = info_msg['content']['language_info']
self.set_widgets_metadata()

return nb, resources

def set_widgets_metadata(self):
if self.widget_state:
self.nb.metadata.widgets = {
'application/vnd.jupyter.widget-state+json': {
'state': {
model_id: _serialize_widget_state(state)
for model_id, state in self.widget_state.items() if '_model_name' in state
},
'version_major': 2,
'version_minor': 0,
}
}
for key, widget in self.nb.metadata.widgets['application/vnd.jupyter.widget-state+json']['state'].items():
buffers = self.widget_buffers.get(key)
if buffers:
widget['buffers'] = buffers

def preprocess_cell(self, cell, resources, cell_index):
"""
Executes a single code cell. See base.py for details.
Expand Down Expand Up @@ -532,7 +561,12 @@ def clear_display_id_mapping(self, cell_index):
cell_map[cell_index] = []

def handle_comm_msg(self, outs, msg, cell_index):
pass
content = msg['content']
data = content['data']
if self.store_widget_state and 'state' in data: # ignore custom msg'es
self.widget_state.setdefault(content['comm_id'], {}).update(data['state'])
if 'buffer_paths' in data and data['buffer_paths']:
self.widget_buffers[content['comm_id']] = _get_buffer_data(msg)

def executenb(nb, cwd=None, km=None, **kwargs):
"""Execute a notebook's code, updating outputs within the notebook object.
Expand All @@ -556,3 +590,26 @@ def executenb(nb, cwd=None, km=None, **kwargs):
resources['metadata'] = {'path': cwd}
ep = ExecutePreprocessor(**kwargs)
return ep.preprocess(nb, resources, km=km)[0]


def _serialize_widget_state(state):
"""Serialize a widget state, following format in @jupyter-widgets/schema."""
return {
'model_name': state.get('_model_name'),
'model_module': state.get('_model_module'),
'model_module_version': state.get('_model_module_version'),
'state': state,
}


def _get_buffer_data(msg):
encoded_buffers = []
paths = msg['content']['data']['buffer_paths']
buffers = msg['buffers']
for path, buffer in zip(paths, buffers):
encoded_buffers.append({
'data': base64.b64encode(buffer).decode('utf-8'),
'encoding': 'base64',
'path': path
})
return encoded_buffers
94 changes: 94 additions & 0 deletions nbconvert/preprocessors/tests/files/JupyterWidgets.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "f46f26da84b54255bccc3a69d7eb08de",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Label(value='Hello World')"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import ipywidgets\n",
"label = ipywidgets.Label('Hello World')\n",
"label"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# it should also handle custom msg'es\n",
"label.send({'msg': 'Hello'})"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.4"
},
"widgets": {
"application/vnd.jupyter.widget-state+json": {
"state": {
"8273e8fe9d9941a4a63c062158e0a630": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.4.0",
"model_name": "DescriptionStyleModel",
"state": {
"description_width": ""
}
},
"a72770a4f541425f8fe85833a3dc2a8e": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.4.0",
"model_name": "LabelModel",
"state": {
"context_menu": null,
"layout": "IPY_MODEL_dec20f599109458ca607b1df5959469b",
"style": "IPY_MODEL_8273e8fe9d9941a4a63c062158e0a630",
"value": "Hello World"
}
},
"dec20f599109458ca607b1df5959469b": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "1.1.0",
"model_name": "LayoutModel",
"state": {}
}
},
"version_major": 2,
"version_minor": 0
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}
35 changes: 35 additions & 0 deletions nbconvert/preprocessors/tests/test_execute.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,13 @@ def normalize_output(output):
if 'text/plain' in output.get('data', {}):
output['data']['text/plain'] = \
re.sub(addr_pat, '<HEXADDR>', output['data']['text/plain'])
if 'application/vnd.jupyter.widget-view+json' in output.get('data', {}):
output['data']['application/vnd.jupyter.widget-view+json'] \
['model_id'] = '<MODEL_ID>'
for key, value in output.get('data', {}).items():
if isinstance(value, string_types):
if sys.version_info.major == 2:
value = value.replace('u\'', '\'')
output['data'][key] = _normalize_base64(value)
if 'traceback' in output:
tb = [
Expand Down Expand Up @@ -305,3 +310,33 @@ def test_execute_function(self):
original = copy.deepcopy(input_nb)
executed = executenb(original, os.path.dirname(filename))
self.assert_notebooks_equal(original, executed)

def test_widgets(self):
"""Runs a test notebook with widgets and checks the widget state is saved."""
input_file = os.path.join(current_dir, 'files', 'JupyterWidgets.ipynb')
opts = dict(kernel_name="python")
res = self.build_resources()
res['metadata']['path'] = os.path.dirname(input_file)
input_nb, output_nb = self.run_notebook(input_file, opts, res)

output_data = [
output.get('data', {})
for cell in output_nb['cells']
for output in cell['outputs']
]

model_ids = [
data['application/vnd.jupyter.widget-view+json']['model_id']
for data in output_data
if 'application/vnd.jupyter.widget-view+json' in data
]

wdata = output_nb['metadata']['widgets'] \
['application/vnd.jupyter.widget-state+json']
for k in model_ids:
d = wdata['state'][k]
assert 'model_name' in d
assert 'model_module' in d
assert 'state' in d
assert 'version_major' in wdata
assert 'version_minor' in wdata
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ def run(self):
jupyter_client_req = 'jupyter_client>=4.2'

extra_requirements = {
'test': ['pytest', 'pytest-cov', 'ipykernel', jupyter_client_req],
'test': ['pytest', 'pytest-cov', 'ipykernel', jupyter_client_req, 'ipywidgets>=7'],
'serve': ['tornado>=4.0'],
'execute': [jupyter_client_req],
'docs': ['sphinx>=1.5.1',
Expand Down