Skip to content

Commit

Permalink
Ensure Violin does not break when category is missing (#4482)
Browse files Browse the repository at this point in the history
  • Loading branch information
philippjfr authored Jun 23, 2020
1 parent 5985955 commit 621171b
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 19 deletions.
56 changes: 38 additions & 18 deletions holoviews/plotting/bokeh/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,10 @@
from ...core import NdOverlay
from ...core.dimension import Dimension, Dimensioned
from ...core.ndmapping import sorted_context
from ...core.util import (basestring, dimension_sanitizer, wrap_tuple,
unique_iterator, unique_array, isfinite,
is_dask_array, is_cupy_array)
from ...core.util import (
basestring, dimension_sanitizer, wrap_tuple, unique_iterator,
isfinite, is_dask_array, is_cupy_array
)
from ...operation.stats import univariate_kde
from ...util.transform import dim
from .chart import AreaPlot
Expand Down Expand Up @@ -392,7 +393,7 @@ def _get_factors(self, element, ranges):
xfactors, yfactors = factors, []
return (yfactors, xfactors) if self.invert_axes else (xfactors, yfactors)

def _kde_data(self, element, el, key, split_dim, **kwargs):
def _kde_data(self, element, el, key, split_dim, split_cats, **kwargs):
vdims = el.vdims
vdim = vdims[0]
if self.clip:
Expand All @@ -402,24 +403,27 @@ def _kde_data(self, element, el, key, split_dim, **kwargs):
if split_dim is not None:
el = el.clone(kdims=element.kdims)
all_cats = split_dim.apply(el)
bin_cats = unique_array(all_cats)
if len(bin_cats) > 2:
if len(split_cats) > 2:
raise ValueError(
'The number of categories for split violin plots cannot be '
'greater than 2! Found {0} categories: {1}'.format(
len(bin_cats), ', '.join(bin_cats)))
len(split_cats), ', '.join(split_cats)))
el = el.add_dimension(repr(split_dim), len(el.kdims), all_cats)

kdes = univariate_kde(el, dimension=vdim.name, groupby=repr(split_dim), **kwargs)
scale = 4
else:
kdes = [univariate_kde(el, dimension=vdim.name, **kwargs)] * 2
split_cats = [None, None]
kdes = {None: univariate_kde(el, dimension=vdim.name, **kwargs)}
scale = 2

x_range = el.range(vdim)
xs, fill_xs, ys, fill_ys = [], [], [], []
for i, kde in enumerate(kdes):
_xs, _ys = (kde.dimension_values(i) for i in range(2))
for i, cat in enumerate(split_cats):
kde = kdes.get(cat)
if kde is None:
_xs, _ys = np.array([]), np.array([])
else:
_xs, _ys = (kde.dimension_values(idim) for idim in range(2))
mask = isfinite(_ys) & (_ys>0) # Mask out non-finite and zero values
_xs, _ys = _xs[mask], _ys[mask]

Expand All @@ -429,9 +433,13 @@ def _kde_data(self, element, el, key, split_dim, **kwargs):
_ys = _ys[::-1]
_xs = _xs[::-1]

if split_dim and len(_xs):
fill_xs.append([x_range[0]]+list(_xs)+[x_range[-1]])
fill_ys.append([0]+list(_ys)+[0])
if split_dim:
if len(_xs):
fill_xs.append([x_range[0]]+list(_xs)+[x_range[-1]])
fill_ys.append([0]+list(_ys)+[0])
else:
fill_xs.append([])
fill_ys.append([])
x_range = x_range[::-1]

xs += list(_xs)
Expand All @@ -443,17 +451,19 @@ def _kde_data(self, element, el, key, split_dim, **kwargs):
# this scales the width
if split_dim:
fill_xs = [np.asarray(x) for x in fill_xs]
fill_ys = [[key + (y,) for y in (fy/ys.max())*(self.violin_width/scale)] for fy in fill_ys]
ys = (ys/ys.max())*(self.violin_width/scale) if len(ys) else []
fill_ys = [[key + (y,) for y in (fy/np.abs(ys).max())*(self.violin_width/scale)]
if len(fy) else [] for fy in fill_ys]
ys = (ys/np.nanmax(np.abs(ys)))*(self.violin_width/scale) if len(ys) else []
ys = [key + (y,) for y in ys]

line = {'ys': xs, 'xs': ys}
if split_dim:
kde = {'ys': fill_xs, 'xs': fill_ys}
else:
kde = line

if isinstance(kdes, NdOverlay):
kde[repr(split_dim)] = [str(k) for k in kdes.keys()]
kde[repr(split_dim)] = [str(k) for k in split_cats]

bars, segments, scatter = defaultdict(list), defaultdict(list), {}
values = el.dimension_values(vdim)
Expand Down Expand Up @@ -488,6 +498,7 @@ def _kde_data(self, element, el, key, split_dim, **kwargs):
bars['top'].append(q3)
scatter['x'] = xpos
scatter['y'] = q2

return kde, line, segments, bars, scatter


Expand All @@ -501,6 +512,15 @@ def get_data(self, element, ranges, style):
else:
groups = dict([((element.label,), element)])

if split_dim:
split_name = split_dim.dimension.name
if split_name in ranges and not split_dim.ops and 'factors' in ranges[split_name]:
split_cats = ranges[split_name].get('factors')
elif split_dim:
split_cats = list(unique_iterator(split_dim.apply(element)))
else:
split_cats = None

# Define glyph-data mapping
if self.invert_axes:
bar_map = {'y': 'x', 'left': 'bottom',
Expand Down Expand Up @@ -528,7 +548,7 @@ def get_data(self, element, ranges, style):
kde_data, line_data, seg_data, bar_data, scatter_data = (defaultdict(list) for i in range(5))
for i, (key, g) in enumerate(groups.items()):
key = decode_bytes(key)
kde, line, segs, bars, scatter = self._kde_data(element, g, key, split_dim, **kwargs)
kde, line, segs, bars, scatter = self._kde_data(element, g, key, split_dim, split_cats, **kwargs)
for k, v in segs.items():
seg_data[k] += v
for k, v in bars.items():
Expand Down
3 changes: 2 additions & 1 deletion holoviews/tests/plotting/bokeh/testviolinplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,8 @@ def test_violin_split_op_multi(self):
source = plot.handles['patches_1_source']
glyph = plot.handles['patches_1_glyph']
cmapper = plot.handles['violin_color_mapper']
self.assertEqual(source.data["dim('b')>2"], ['False', 'False', 'False', 'True', 'True'])
values = ['False', 'True', 'False', 'True', 'False', 'True', 'False', 'True', 'False', 'True']
self.assertEqual(source.data["dim('b')>2"], values)
self.assertEqual(glyph.fill_color, {'field': "dim('b')>2", 'transform': cmapper})

def test_violin_split_op_single(self):
Expand Down

0 comments on commit 621171b

Please sign in to comment.