Merge pull request #951 from jakevdp/toplevel-datasets

Toplevel datasets

jakevdp authored Jun 27, 2018
2 parents f8843e9 + a4da4b0 commit 9eb71cd
Showing 5 changed files with 145 additions and 18 deletions.
altair/utils/data.py (10 changes: 9 additions & 1 deletion)
@@ -16,7 +16,15 @@
 DataTransformerType = Callable

 class DataTransformerRegistry(PluginRegistry[DataTransformerType]):
-    pass
+    _global_settings = {'consolidate_datasets': False}
+
+    @property
+    def consolidate_datasets(self):
+        return self._global_settings['consolidate_datasets']
+
+    @consolidate_datasets.setter
+    def consolidate_datasets(self, value):
+        self._global_settings['consolidate_datasets'] = value


 # ==============================================================================
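Note: a minimal usage sketch of the new switch, assuming `alt.data_transformers` is the `DataTransformerRegistry` instance exposed at the package level; the `enable()` form relies on the option-routing added to the plugin registry below, and the chart itself is illustrative:

    import pandas as pd
    import altair as alt

    df = pd.DataFrame({'x': [1, 2, 3], 'y': [4, 5, 6]})
    chart = alt.Chart(df).mark_point().encode(x='x', y='y')

    # flip the switch directly via the property defined above...
    alt.data_transformers.consolidate_datasets = True

    # ...or pass it as a keyword option to enable(), which also works as a
    # context manager and restores the previous value on exit
    with alt.data_transformers.enable(consolidate_datasets=True):
        spec = chart.to_dict()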
altair/utils/plugin_registry.py (16 changes: 12 additions & 4 deletions)
@@ -28,7 +28,7 @@ def __enter__(self):
         return self

     def __exit__(self, type, value, traceback):
-        self.registry._set_state(self.original_state )
+        self.registry._set_state(self.original_state)

     def __repr__(self):
         return "{0}.enable({1!r})".format(self.registry.__class__.__name__, self.name)
@@ -54,6 +54,10 @@ class PluginRegistry(Generic[PluginType]):
     # in case an entrypoint is not found
     entrypoint_err_messages = {}

+    # global settings is a key-value mapping of settings that are stored globally
+    # in the registry rather than passed to the plugins
+    _global_settings = {}
+
     def __init__(self, entry_point_group='', plugin_type=object):
         # type: (str, Any) -> None
         """Create a PluginRegistry for a named entry point group.
Expand All @@ -72,6 +76,7 @@ def __init__(self, entry_point_group='', plugin_type=object):
self._active_name = '' # type: str
self._plugins = {} # type: dict
self._options = {} # type: dict
self._global_settings = self.__class__._global_settings.copy() # type: dict

def register(self, name, value):
# type: (str, Union[PluginType, None]) -> PluginType
@@ -112,12 +117,13 @@ def _get_state(self):
         return {'_active': self._active,
                 '_active_name': self._active_name,
                 '_plugins': self._plugins.copy(),
-                '_options': self._options.copy()}
+                '_options': self._options.copy(),
+                '_global_settings': self._global_settings.copy()}

     def _set_state(self, state):
         """Reset the state of the registry"""
         assert set(state.keys()) == {'_active', '_active_name',
-                                     '_plugins', '_options'}
+                                     '_plugins', '_options', '_global_settings'}
         for key, val in state.items():
             setattr(self, key, val)
@@ -136,6 +142,8 @@ def _enable(self, name, **options):
             self.register(name, value)
         self._active_name = name
         self._active = self._plugins[name]
+        for key in set(options.keys()) & set(self._global_settings.keys()):
+            self._global_settings[key] = options.pop(key)
         self._options = options

     def enable(self, name=None, **options):
@@ -161,7 +169,7 @@ def enable(self, name=None, **options):
         if name is None:
             name = self.active
         return PluginEnabler(self, name, **options)

     @property
     def active(self):
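Note: the key change in `_enable()` is the splitting of keyword options between global settings and per-plugin options. A standalone sketch of that step, with illustrative keys and values (not altair API):

    # keys that match a declared global setting are popped out of the plugin
    # options and stored on the registry itself instead
    _global_settings = {'consolidate_datasets': False}
    options = {'consolidate_datasets': True, 'max_rows': 5000}

    for key in set(options.keys()) & set(_global_settings.keys()):
        _global_settings[key] = options.pop(key)

    print(_global_settings)  # {'consolidate_datasets': True}
    print(options)           # {'max_rows': 5000} -- only this reaches the plugin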
altair/utils/tests/test_plugin_registry.py (33 changes: 32 additions & 1 deletion)
@@ -7,7 +7,15 @@ class TypedCallableRegistry(PluginRegistry[Callable[[int], int]]):


 class GeneralCallableRegistry(PluginRegistry):
-    pass
+    _global_settings = {'global_setting': None}
+
+    @property
+    def global_setting(self):
+        return self._global_settings['global_setting']
+
+    @global_setting.setter
+    def global_setting(self, val):
+        self._global_settings['global_setting'] = val


 def test_plugin_registry():
@@ -50,6 +58,29 @@ def test_plugin_registry_extra_options():
     assert plugins.get()(3) == 9


+def test_plugin_registry_global_settings():
+    plugins = GeneralCallableRegistry()
+
+    # we need some default plugin, but we won't do anything with it
+    plugins.register('default', lambda x: x)
+    plugins.enable('default')
+
+    # default value of the global flag
+    assert plugins.global_setting is None
+
+    # enabling changes the global state, not the options
+    plugins.enable(global_setting=True)
+    assert plugins.global_setting is True
+    assert plugins._options == {}
+
+    # context manager changes global state temporarily
+    with plugins.enable(global_setting='temp'):
+        assert plugins.global_setting == 'temp'
+        assert plugins._options == {}
+    assert plugins.global_setting is True
+    assert plugins._options == {}
+
+
 def test_plugin_registry_context():
     plugins = GeneralCallableRegistry()
altair/vegalite/v2/api.py (77 changes: 65 additions & 12 deletions)
@@ -2,6 +2,8 @@

 import warnings

+import hashlib
+import json
 import jsonschema
 import six
 import pandas as pd
@@ -15,15 +17,53 @@

 # ------------------------------------------------------------------------
 # Data Utilities
-def _prepare_data(data):
-    """Convert input data to data for use within schema"""
+def _dataset_name(data):
+    """Generate a unique hash of the data"""
+    def hash_(dct):
+        dct_str = json.dumps(dct, sort_keys=True)
+        return hashlib.md5(dct_str.encode()).hexdigest()
+
+    if isinstance(data, core.InlineData):
+        return 'data-' + hash_(data.values)
+    elif isinstance(data, dict) and 'values' in data:
+        return 'data-' + hash_(data['values'])
+    else:
+        raise ValueError("Cannot generate name for data {0}".format(data))
+
+
+def _prepare_data(data, context):
+    """Convert input data to data for use within schema
+
+    Parameters
+    ----------
+    data :
+        The input dataset in the form of a DataFrame, dictionary, altair data
+        object, or other type that is recognized by the data transformers.
+    context : dict
+        The to_dict context in which the data is being prepared. This is used
+        to keep track of information that needs to be passed up and down the
+        recursive serialization routine, such as global named datasets.
+    """
     if data is Undefined:
         return data
+    if isinstance(data, core.InlineData):
+        if data_transformers.consolidate_datasets:
+            name = _dataset_name(data)
+            context['datasets'][name] = data.values
+            return core.NamedData(name=name)
+        else:
+            return data
     elif isinstance(data, (dict, core.Data, core.InlineData,
                            core.UrlData, core.NamedData)):
         return data
     elif isinstance(data, pd.DataFrame):
-        return pipe(data, data_transformers.get())
+        data = pipe(data, data_transformers.get())
+        if data_transformers.consolidate_datasets and isinstance(data, dict) and 'values' in data:
+            name = _dataset_name(data)
+            context['datasets'][name] = data['values']
+            return core.NamedData(name=name)
+        else:
+            return data
     elif isinstance(data, six.string_types):
         return core.UrlData(data)
     else:
Expand All @@ -40,7 +80,8 @@ class LookupData(core.LookupData):
def to_dict(self, *args, **kwargs):
"""Convert the chart to a dictionary suitable for JSON export"""
copy = self.copy(ignore=['data'])
copy.data = _prepare_data(copy.data)
context = kwargs.get('context', {})
copy.data = _prepare_data(copy.data, context)
return super(LookupData, copy).to_dict(*args, **kwargs)


@@ -309,22 +350,29 @@ class TopLevelMixin(mixins.ConfigMethodMixin):

     def to_dict(self, *args, **kwargs):
         """Convert the chart to a dictionary suitable for JSON export"""
-        copy = self.copy()
-        original_data = getattr(copy, 'data', Undefined)
-        copy.data = _prepare_data(original_data)
-
-        # We make use of two context markers:
+        # We make use of three context markers:
         # - 'data' points to the data that should be referenced for column type
         #   inference.
         # - 'top_level' is a boolean flag that is assumed to be true; if it's
         #   true then a "$schema" arg is added to the dict.
+        # - 'datasets' is a dict of named datasets that should be inserted
+        #   in the top-level object
+
+        # note: not a deep copy because we want datasets and data arguments to
+        # be passed by reference
         context = kwargs.get('context', {}).copy()
+        context.setdefault('datasets', {})
         is_top_level = context.get('top_level', True)
-        context['top_level'] = False
+
+        copy = self.copy()
+        original_data = getattr(copy, 'data', Undefined)
+        copy.data = _prepare_data(original_data, context)

         if original_data is not Undefined:
             context['data'] = original_data

+        # remaining to_dict calls are not at top level
+        context['top_level'] = False
         kwargs['context'] = context

         try:
@@ -339,6 +387,7 @@ def to_dict(self, *args, **kwargs):
             kwargs['validate'] = 'deep'
             dct = super(TopLevelMixin, copy).to_dict(*args, **kwargs)

+        # TODO: following entries are added after validation. Should they be validated?
         if is_top_level:
             # since this is top-level we add $schema if it's missing
             if '$schema' not in dct:
@@ -348,6 +397,10 @@
             the_theme = themes.get()
             dct = utils.update_nested(the_theme(), dct, copy=True)

+            # update datasets
+            if context['datasets']:
+                dct.setdefault('datasets', {}).update(context['datasets'])
+
         return dct

     def savechart(self, fp, format=None, **kwargs):
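Note: a sketch of what the consolidated output looks like for a chart that reuses one dataset; the schema URL and the exact hash suffix in the dataset name are illustrative:

    import pandas as pd
    import altair as alt

    df = pd.DataFrame({'x': [1, 2], 'y': [3, 4]})
    base = alt.Chart(df).mark_point().encode(x='x', y='y')
    chart = base | base  # two views sharing the same data

    with alt.data_transformers.enable(consolidate_datasets=True):
        dct = chart.to_dict()

    # dct is now roughly:
    # {
    #   "$schema": "https://vega.github.io/schema/vega-lite/v2.json",
    #   "datasets": {"data-<md5>": [{"x": 1, "y": 3}, {"x": 2, "y": 4}]},
    #   "hconcat": [
    #     {"data": {"name": "data-<md5>"}, "mark": "point", "encoding": {...}},
    #     {"data": {"name": "data-<md5>"}, "mark": "point", "encoding": {...}}
    #   ]
    # }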
altair/vegalite/v2/tests/test_api.py (27 changes: 27 additions & 0 deletions)
@@ -392,3 +392,30 @@ def test_chart_from_dict():
     # test that an invalid spec leads to a schema validation error
     with pytest.raises(jsonschema.ValidationError):
         alt.Chart.from_dict({'invalid': 'spec'})
+
+
+def test_consolidate_datasets(basic_chart):
+    chart = basic_chart | basic_chart
+
+    with alt.data_transformers.enable(consolidate_datasets=True):
+        dct_consolidated = chart.to_dict()
+
+    with alt.data_transformers.enable(consolidate_datasets=False):
+        dct_standard = chart.to_dict()
+
+    assert 'datasets' in dct_consolidated
+    assert 'datasets' not in dct_standard
+
+    datasets = dct_consolidated['datasets']
+
+    # two dataset copies should be recognized as duplicates
+    assert len(datasets) == 1
+
+    # make sure data matches original & names are correct
+    name, data = datasets.popitem()
+
+    for spec in dct_standard['hconcat']:
+        assert spec['data']['values'] == data
+
+    for spec in dct_consolidated['hconcat']:
+        assert spec['data'] == {'name': name}
