Skip to content

Commit

Permalink
Add automatic categorical legend for datashaded plots (#4806)
Browse files Browse the repository at this point in the history
  • Loading branch information
philippjfr authored May 26, 2022
1 parent 9265a12 commit 59c20c7
Show file tree
Hide file tree
Showing 9 changed files with 215 additions and 56 deletions.
95 changes: 69 additions & 26 deletions holoviews/operation/datashader.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,24 +253,29 @@ class AggregationOperation(ResamplingOperation):
'var': rd.var,
'std': rd.std,
'min': rd.min,
'max': rd.max
'max': rd.max,
'count_cat': rd.count_cat
}

def _get_aggregator(self, element, add_field=True):
agg = self.p.aggregator
@classmethod
def _get_aggregator(cls, element, agg, add_field=True):
if isinstance(agg, str):
if agg not in self._agg_methods:
agg_methods = sorted(self._agg_methods)
if agg not in cls._agg_methods:
agg_methods = sorted(cls._agg_methods)
raise ValueError("Aggregation method '%r' is not known; "
"aggregator must be one of: %r" %
(agg, agg_methods))
agg = self._agg_methods[agg]()
if agg == 'count_cat':
agg = cls._agg_methods[agg]('__temp__')
else:
agg = cls._agg_methods[agg]()

elements = element.traverse(lambda x: x, [Element])
if add_field and getattr(agg, 'column', False) is None and not isinstance(agg, (rd.count, rd.any)):
if (add_field and getattr(agg, 'column', False) in ('__temp__', None) and
not isinstance(agg, (rd.count, rd.any))):
if not elements:
raise ValueError('Could not find any elements to apply '
'%s operation to.' % type(self).__name__)
'%s operation to.' % cls.__name__)
inner_element = elements[0]
if isinstance(inner_element, TriMesh) and inner_element.nodes.vdims:
field = inner_element.nodes.vdims[0].name
Expand All @@ -282,7 +287,7 @@ def _get_aggregator(self, element, add_field=True):
raise ValueError("Could not determine dimension to apply "
"'%s' operation to. Declare the dimension "
"to aggregate as part of the datashader "
"aggregator." % type(self).__name__)
"aggregator." % cls.__name__)
agg = type(agg)(field)
return agg

Expand Down Expand Up @@ -466,7 +471,7 @@ def get_agg_data(cls, obj, category=None):


def _process(self, element, key=None):
agg_fn = self._get_aggregator(element)
agg_fn = self._get_aggregator(element, self.p.aggregator)
if hasattr(agg_fn, 'cat_column'):
category = agg_fn.cat_column
else:
Expand Down Expand Up @@ -556,7 +561,7 @@ def applies(cls, element, agg_fn, line_width=None):
(isinstance(agg_fn, ds.count_cat) and agg_fn.column in element.kdims)))

def _process(self, element, key=None):
agg_fn = self._get_aggregator(element)
agg_fn = self._get_aggregator(element, self.p.aggregator)

if not self.applies(element, agg_fn, line_width=self.p.line_width):
raise ValueError(
Expand Down Expand Up @@ -654,7 +659,7 @@ class area_aggregate(AggregationOperation):

def _process(self, element, key=None):
x, y = element.dimensions()[:2]
agg_fn = self._get_aggregator(element)
agg_fn = self._get_aggregator(element, self.p.aggregator)

default = None
if not self.p.y_range:
Expand Down Expand Up @@ -726,7 +731,7 @@ class spikes_aggregate(LineAggregationOperation):
The offset of the lower end of each spike.""")

def _process(self, element, key=None):
agg_fn = self._get_aggregator(element)
agg_fn = self._get_aggregator(element, self.p.aggregator)
x, y = element.kdims[0], None

spike_length = 0.5 if self.p.spike_length is None else self.p.spike_length
Expand Down Expand Up @@ -798,7 +803,7 @@ def _aggregate(self, cvs, df, x0, y0, x1, y1, agg):
raise NotImplementedError

def _process(self, element, key=None):
agg_fn = self._get_aggregator(element)
agg_fn = self._get_aggregator(element, self.p.aggregator)
x0d, y0d, x1d, y1d = element.kdims
info = self._get_sampling(element, [x0d, x1d], [y0d, y1d], ndim=1)
(x_range, y_range), (xs, ys), (width, height), (xtype, ytype) = info
Expand Down Expand Up @@ -994,7 +999,7 @@ def _process(self, element, key=None):
# Apply regridding to each value dimension
regridded = {}
arrays = self._get_xarrays(element, coords, xtype, ytype)
agg_fn = self._get_aggregator(element, add_field=False)
agg_fn = self._get_aggregator(element, self.p.aggregator, add_field=False)
for vd, xarr in arrays.items():
rarray = cvs.raster(xarr, upsample_method=interp,
downsample_method=agg_fn)
Expand All @@ -1021,11 +1026,11 @@ class contours_rasterize(aggregate):
aggregator = param.ClassSelector(default=ds.mean(),
class_=(ds.reductions.Reduction, str))

def _get_aggregator(self, element, add_field=True):
agg = self.p.aggregator
@classmethod
def _get_aggregator(cls, element, agg, add_field=True):
if not element.vdims and agg.column is None and not isinstance(agg, (rd.count, rd.any)):
return ds.any()
return super()._get_aggregator(element, add_field)
return super()._get_aggregator(element, agg, add_field)



Expand Down Expand Up @@ -1113,9 +1118,12 @@ def _process(self, element, key=None):
or not (element.vdims or element.nodes.vdims)):
wireframe = True
precompute = False # TriMesh itself caches wireframe
agg = self._get_aggregator(element) if isinstance(agg, (ds.any, ds.count)) else ds.any()
if isinstance(agg, (ds.any, ds.count)):
agg = self._get_aggregator(element, self.p.aggregator)
else:
agg = ds.any()
elif getattr(agg, 'column', None) is None:
agg = self._get_aggregator(element)
agg = self._get_aggregator(element, self.p.aggregator)

if element._plot_id in self._precomputed:
precomputed = self._precomputed[element._plot_id]
Expand Down Expand Up @@ -1174,7 +1182,7 @@ def _process(self, element, key=None):
data = element.data

x, y = element.kdims
agg_fn = self._get_aggregator(element)
agg_fn = self._get_aggregator(element, self.p.aggregator)
info = self._get_sampling(element, x, y)
(x_range, y_range), (xs, ys), (width, height), (xtype, ytype) = info
if xtype == 'datetime':
Expand Down Expand Up @@ -1417,15 +1425,15 @@ class geometry_rasterize(LineAggregationOperation):
aggregator = param.ClassSelector(default=ds.mean(),
class_=(ds.reductions.Reduction, str))

def _get_aggregator(self, element, add_field=True):
agg = self.p.aggregator
@classmethod
def _get_aggregator(cls, element, agg, add_field=True):
if (not (element.vdims or isinstance(agg, str)) and
agg.column is None and not isinstance(agg, (rd.count, rd.any))):
return ds.count()
return super()._get_aggregator(element, add_field)
return super()._get_aggregator(element, agg, add_field)

def _process(self, element, key=None):
agg_fn = self._get_aggregator(element)
agg_fn = self._get_aggregator(element, self.p.aggregator)
xdim, ydim = element.kdims
info = self._get_sampling(element, xdim, ydim)
(x_range, y_range), (xs, ys), (width, height), (xtype, ytype) = info
Expand Down Expand Up @@ -2064,8 +2072,43 @@ def _sort_by_distance(cls, raster, df, x, y):
return df.iloc[distances.argsort().values]



inspect._dispatch = {
Points: inspect_points,
Polygons: inspect_polygons
}


class categorical_legend(Operation):
    """
    Builds a Points element describing the categories of a categorically
    aggregated element and the color assigned to each of them, so that
    it can be used to render a legend.

    The pipeline of the supplied element is searched for a rasterize
    operation and a shade operation. None is returned when either of
    them cannot be found, or when the resolved aggregator is neither a
    ``ds.count_cat`` nor a ``ds.by`` instance.
    """

    def _process(self, element, key=None):
        from ..plotting.util import rgb2hex

        pipeline = element.pipeline
        rasterize_op = pipeline.find(rasterize)
        # A datashade instance found by the lookup above is also used
        # as the shade operation; otherwise shade is looked up separately
        shade_op = (rasterize_op if isinstance(rasterize_op, datashade)
                    else pipeline.find(shade))
        if None in (shade_op, rasterize_op):
            return None

        dataset = element.dataset
        # Apply the first pipeline operation to the dataset to obtain the
        # element handed to the aggregator lookup
        source_element = pipeline.operations[0](dataset)
        agg = rasterize_op._get_aggregator(source_element, rasterize_op.aggregator)
        if not isinstance(agg, (ds.count_cat, ds.by)):
            return None
        column = agg.column

        # Collect the category values of the aggregated column
        if hasattr(dataset.data, 'dtypes'):
            cats = list(dataset.data.dtypes[column].categories)
            if cats == ['__UNKNOWN_CATEGORIES__']:
                # Placeholder categories: resolve the actual ones
                cats = list(dataset.data[column].cat.as_known().categories)
        else:
            cats = list(dataset.dimension_values(column, expanded=False))

        # A list color_key is matched to categories by position, anything
        # else is indexed by the category itself
        colors = shade_op.color_key
        if isinstance(colors, list):
            cat_colors = {cat: colors[idx] for idx, cat in enumerate(cats)}
        else:
            cat_colors = {cat: colors[cat] for cat in cats}

        def as_hex(color):
            # Tuple colors have their first three channels scaled by
            # 1/256 and converted with rgb2hex; other values pass through
            if isinstance(color, tuple):
                return rgb2hex([v/256 for v in color[:3]])
            return color

        cat_colors = {cat: as_hex(color) for cat, color in cat_colors.items()}

        # One point per category, all placed at (0, 0)
        legend_points = Points([(0, 0, cat) for cat in cats], vdims=['category'])
        return legend_points.opts(
            apply_ranges=False, cmap=cat_colors, color='category', show_legend=True)
14 changes: 14 additions & 0 deletions holoviews/operation/element.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,20 @@ def _process(self, view, key=None):
else:
return processed.clone(group=self.p.group)

def find(self, operation, skip_nonlinked=True):
    """
    Returns the first found occurrence of an operation while
    performing a backward traversal of the chain pipeline.

    The operations are visited from last to first and each one is
    tested with isinstance against ``operation``. When
    ``skip_nonlinked`` is true the traversal ends after the first
    visited operation whose ``link_inputs`` is falsy (that operation
    itself is still tested for a match first). None is returned when
    no operation matches.
    """
    for candidate in reversed(self.operations):
        if isinstance(candidate, operation):
            return candidate
        if not candidate.link_inputs and skip_nonlinked:
            break
    return None


class transform(Operation):
"""
Expand Down
9 changes: 5 additions & 4 deletions holoviews/plotting/bokeh/element.py
Original file line number Diff line number Diff line change
Expand Up @@ -2053,12 +2053,12 @@ class LegendPlot(ElementPlot):
legend_cols = param.Integer(default=False, doc="""
Whether to lay out the legend as columns.""")

legend_specs = {'right': 'right', 'left': 'left', 'top': 'above',
'bottom': 'below'}

legend_opts = param.Dict(default={}, doc="""
Allows setting specific styling options for the colorbar.""")

legend_specs = {'right': 'right', 'left': 'left', 'top': 'above',
'bottom': 'below'}

def _process_legend(self, plot=None):
plot = plot or self.handles['plot']
if not plot.legend:
Expand Down Expand Up @@ -2205,7 +2205,8 @@ def _process_legend(self, overlay):
renderers = []
for item in legend_items:
item.renderers[:] = [r for r in item.renderers if r not in renderers]
if item in filtered or not item.renderers or not any(r.visible for r in item.renderers):
if (item in filtered or not item.renderers or
not any(r.visible or 'hv_legend' in r.tags for r in item.renderers)):
continue
renderers += item.renderers
filtered.append(item)
Expand Down
29 changes: 27 additions & 2 deletions holoviews/plotting/bokeh/raster.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
from __future__ import absolute_import, division, unicode_literals

import sys

import numpy as np
import param

from bokeh.models import DatetimeAxis, CustomJSHover

from ...core.util import cartesian_product, dimension_sanitizer, isfinite
from ...element import Raster
from .element import ElementPlot, ColorbarPlot
from ..util import categorical_legend
from .chart import PointPlot
from .element import ColorbarPlot, LegendPlot
from .selection import BokehOverlaySelectionDisplay
from .styles import base_properties, fill_properties, line_properties, mpl_to_bokeh
from .util import colormesh
Expand Down Expand Up @@ -128,7 +132,7 @@ def get_data(self, element, ranges, style):



class RGBPlot(ElementPlot):
class RGBPlot(LegendPlot):

padding = param.ClassSelector(default=0, class_=(int, float, tuple))

Expand All @@ -140,11 +144,32 @@ class RGBPlot(ElementPlot):

selection_display = BokehOverlaySelectionDisplay()

def __init__(self, hmap, **params):
    """
    Initializes the plot through the parent constructor and declares
    the ``_legend_plot`` attribute, which starts out as None.
    """
    super(RGBPlot, self).__init__(hmap, **params)
    # Only populated by _init_glyphs when a legend plot is created
    self._legend_plot = None

def _hover_opts(self, element):
xdim, ydim = element.kdims
return [(xdim.pprint_label, '$x'), (ydim.pprint_label, '$y'),
('RGBA', '@image')], {}

def _init_glyphs(self, plot, element, ranges, source):
    """
    Initializes the glyphs via the parent implementation and then
    attempts to add a categorical legend to the plot.

    Nothing further happens when the holoviews datashader operations
    module has not been imported, when show_legend is disabled, when
    computing the legend element raises, or when it returns None.
    Otherwise a PointPlot is initialized on the same plot with the
    legend-related parameters of this plot, its glyph renderer is
    tagged with 'hv_legend' and its color mapper is stored under the
    'rgb_color_mapper' handle.
    """
    super(RGBPlot, self)._init_glyphs(plot, element, ranges, source)

    datashader_loaded = 'holoviews.operation.datashader' in sys.modules
    if not (datashader_loaded and self.show_legend):
        return

    try:
        legend = categorical_legend(element, backend=self.backend)
    except Exception:
        # Best-effort: any failure leaves the plot without a legend
        return
    if legend is None:
        return

    # Forward every parameter whose name starts with 'legend'
    legend_params = {}
    for name, value in self.param.get_param_values():
        if name.startswith('legend'):
            legend_params[name] = value

    legend_plot = PointPlot(legend, keys=[], overlaid=1, **legend_params)
    self._legend_plot = legend_plot
    legend_plot.initialize_plot(plot=plot)
    legend_plot.handles['glyph_renderer'].tags.append('hv_legend')
    self.handles['rgb_color_mapper'] = legend_plot.handles['color_color_mapper']

def get_data(self, element, ranges, style):
mapping = dict(image='image', x='x', y='y', dw='dw', dh='dh')
if 'alpha' in style:
Expand Down
2 changes: 1 addition & 1 deletion holoviews/plotting/mpl/chart.py
Original file line number Diff line number Diff line change
Expand Up @@ -545,7 +545,7 @@ def _update_separator(self, offset):
offset_line.set_ydata(offset)


class PointPlot(ChartPlot, ColorbarPlot):
class PointPlot(ChartPlot, ColorbarPlot, LegendPlot):
"""
Note that the 'cmap', 'vmin' and 'vmax' style arguments control
how point magnitudes are rendered to different colors.
Expand Down
Loading

0 comments on commit 59c20c7

Please sign in to comment.