Skip to content

Commit

Permalink
Improved plot cmap/palette handling
Browse files Browse the repository at this point in the history
  • Loading branch information
philippjfr committed Nov 23, 2017
1 parent 7995e11 commit 184cf82
Show file tree
Hide file tree
Showing 5 changed files with 132 additions and 64 deletions.
77 changes: 37 additions & 40 deletions holoviews/plotting/bokeh/element.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,14 @@
from bokeh.plotting.helpers import _known_tools as known_tools

from ...core import DynamicMap, CompositeOverlay, Element, Dimension
from ...core.options import abbreviated_exception, SkipRendering
from ...core.options import abbreviated_exception, SkipRendering, Cycle
from ...core import util
from ...streams import Stream, Buffer
from ..plot import GenericElementPlot, GenericOverlayPlot
from ..util import dynamic_update
from ..util import dynamic_update, process_cmap
from .plot import BokehPlot, TOOLS
from .util import (mpl_to_bokeh, get_tab_title, mplcmap_to_palette,
py2js_tickformatter, rgba_tuple, recursive_model_update)
from .util import (mpl_to_bokeh, get_tab_title, py2js_tickformatter,
rgba_tuple, recursive_model_update)

property_prefixes = ['selection', 'nonselection', 'muted', 'hover']

Expand Down Expand Up @@ -922,16 +922,17 @@ class CompositeElementPlot(ElementPlot):
drawing of multiple glyphs.
"""

# Mapping between style groups and glyph names
# Mapping between glyph names and style groups
_style_groups = {}

# Defines the order in which glyphs are drawn, defined by glyph name
_draw_order = []

def _init_glyphs(self, plot, element, ranges, source):
def _init_glyphs(self, plot, element, ranges, source, data=None, mapping=None, style=None):
# Get data and initialize data source
style = self.style[self.cyclic_index]
data, mapping, style = self.get_data(element, ranges, style)
if None in (data, mapping):
style = self.style[self.cyclic_index]
data, mapping, style = self.get_data(element, ranges, style)

source_cache = {}
current_id = element._plot_id
Expand Down Expand Up @@ -1100,14 +1101,15 @@ def _draw_colorbar(self, plot, color_mapper):
self.handles['colorbar'] = color_bar


def _get_colormapper(self, dim, element, ranges, style, factors=None, colors=None):
def _get_colormapper(self, dim, element, ranges, style, factors=None, colors=None,
cycle=None, name='color_mapper'):
# The initial colormapper instance is cached the first time
# and then only updated
if dim is None:
if dim is None and colors is None:
return None
if self.adjoined:
cmappers = self.adjoined.traverse(lambda x: (x.handles.get('color_dim'),
x.handles.get('color_mapper')))
x.handles.get(name)))
cmappers = [cmap for cdim, cmap in cmappers if cdim == dim]
if cmappers:
cmapper = cmappers[0]
Expand All @@ -1117,30 +1119,18 @@ def _get_colormapper(self, dim, element, ranges, style, factors=None, colors=Non
return None

ncolors = None if factors is None else len(factors)
low, high = ranges.get(dim.name, element.range(dim.name))
if colors:
palette = colors
if dim:
low, high = ranges.get(dim.name, element.range(dim.name))
else:
cmap = style.pop('cmap', 'viridis')
if isinstance(cmap, list):
palette = cmap
else:
try:
# Process as matplotlib colormap
palette = mplcmap_to_palette(cmap, ncolors)
except ValueError:
# Process as bokeh palette
palette = getattr(palettes, cmap, None)
if isinstance(palette, dict):
if ncolors in palette:
palette = palette[ncolors]
else:
palette = sorted(palette.items())[-1][1]
low, high = None, None

cmap = colors or cycle or style.pop('cmap', 'viridis')
palette = process_cmap(cmap, ncolors)
nan_colors = {k: rgba_tuple(v) for k, v in self.clipping_colors.items()}
colormapper, opts = self._get_cmapper_opts(low, high, factors, nan_colors)

if 'color_mapper' in self.handles and isinstance(self.handles['color_mapper'], colormapper):
cmapper = self.handles['color_mapper']
cmapper = self.handles.get(name)
if cmapper is not None:
if cmapper.palette != palette:
cmapper.palette = palette
opts = {k: opt for k, opt in opts.items()
Expand All @@ -1149,27 +1139,34 @@ def _get_colormapper(self, dim, element, ranges, style, factors=None, colors=Non
cmapper.update(**opts)
else:
cmapper = colormapper(palette=palette, **opts)
self.handles['color_mapper'] = cmapper
self.handles[name] = cmapper
self.handles['color_dim'] = dim
return cmapper


def _get_color_data(self, element, ranges, style, name='color', factors=None, colors=None):
def _get_color_data(self, element, ranges, style, name='color', factors=None, colors=None,
cycle=None, int_categories=False):
data, mapping = {}, {}
cdim = element.get_dimension(self.color_index)
if not cdim:
return data, mapping

cdata = element.dimension_values(cdim)
if factors is None and (isinstance(cdata, list) or cdata.dtype.kind in 'OSU'):
factors = list(np.unique(cdata))
field = util.dimension_sanitizer(cdim.name)
dtypes = 'iOSU' if int_categories else 'OSU'
if factors is None and (isinstance(cdata, list) or cdata.dtype.kind in dtypes):
factors = list(util.unique_array(cdata))
if factors and int_categories and cdata.dtype.kind == 'i':
field += '_str'
cdata = [str(f) for f in cdata]
factors = [str(f) for f in factors]

mapper = self._get_colormapper(cdim, element, ranges, style,
factors, colors)
data[cdim.name] = cdata
factors, colors, cycle)
data[field] = cdata
if factors is not None:
mapping['legend'] = {'field': cdim.name}
mapping[name] = {'field': cdim.name,
'transform': mapper}
mapping['legend'] = {'field': field}
mapping[name] = {'field': field, 'transform': mapper}
return data, mapping


Expand Down
13 changes: 0 additions & 13 deletions holoviews/plotting/bokeh/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,19 +66,6 @@ def rgba_tuple(rgba):
return rgba


def mplcmap_to_palette(cmap, ncolors=None):
"""
Converts a matplotlib colormap to palette of RGB hex strings."
"""
if colors is None:
raise ValueError("Using cmaps on objects requires matplotlib.")
with abbreviated_exception():
colormap = cm.get_cmap(cmap) #choose any matplotlib colormap here
if ncolors:
return [rgb2hex(colormap(i)) for i in np.linspace(0, 1, ncolors)]
return [rgb2hex(m) for m in colormap(np.arange(colormap.N))]


def get_cmap(cmap):
"""
Returns matplotlib cmap generated from bokeh palette or
Expand Down
16 changes: 8 additions & 8 deletions holoviews/plotting/mpl/element.py
Original file line number Diff line number Diff line change
Expand Up @@ -645,12 +645,12 @@ def _draw_colorbar(self, dim=None, redraw=True):
ColorbarPlot._colorbars[id(axis)] = (ax_colorbars, (l, b, w, h))


def _norm_kwargs(self, element, ranges, opts, vdim):
def _norm_kwargs(self, element, ranges, opts, vdim, prefix=''):
"""
Returns valid color normalization kwargs
to be passed to matplotlib plot function.
"""
clim = opts.pop('clims', None)
clim = opts.pop(prefix+'clims', None)
if clim is None:
cs = element.dimension_values(vdim)
if not isinstance(cs, np.ndarray):
Expand All @@ -674,9 +674,9 @@ def _norm_kwargs(self, element, ranges, opts, vdim):
linthresh=clim[1]/np.e)
else:
norm = mpl_colors.LogNorm(vmin=clim[0], vmax=clim[1])
opts['norm'] = norm
opts['vmin'] = clim[0]
opts['vmax'] = clim[1]
opts[prefix+'norm'] = norm
opts[prefix+'vmin'] = clim[0]
opts[prefix+'vmax'] = clim[1]

# Check whether the colorbar should indicate clipping
values = np.asarray(element.dimension_values(vdim))
Expand All @@ -687,8 +687,8 @@ def _norm_kwargs(self, element, ranges, opts, vdim):
el_min, el_max = -np.inf, np.inf
else:
el_min, el_max = -np.inf, np.inf
vmin = -np.inf if opts['vmin'] is None else opts['vmin']
vmax = np.inf if opts['vmax'] is None else opts['vmax']
vmin = -np.inf if opts[prefix+'vmin'] is None else opts[prefix+'vmin']
vmax = np.inf if opts[prefix+'vmax'] is None else opts[prefix+'vmax']
if el_min < vmin and el_max > vmax:
self._cbar_extend = 'both'
elif el_min < vmin:
Expand Down Expand Up @@ -719,7 +719,7 @@ def _norm_kwargs(self, element, ranges, opts, vdim):
if 'max' in colors: cmap.set_over(**colors['max'])
if 'min' in colors: cmap.set_under(**colors['min'])
if 'NaN' in colors: cmap.set_bad(**colors['NaN'])
opts['cmap'] = cmap
opts[prefix+'cmap'] = cmap



Expand Down
56 changes: 55 additions & 1 deletion holoviews/plotting/util.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from __future__ import unicode_literals
from __future__ import unicode_literals, absolute_import
from collections import defaultdict
import traceback

Expand All @@ -7,6 +7,7 @@

from ..core import (HoloMap, DynamicMap, CompositeOverlay, Layout,
Overlay, GridSpace, NdLayout, Store)
from ..core.options import Cycle, abbreviated_exception
from ..core.spaces import get_nested_streams
from ..core.util import (match_spec, is_number, wrap_tuple, basestring,
get_overlay_spec, unique_iterator)
Expand Down Expand Up @@ -411,6 +412,59 @@ def map_colors(arr, crange, cmap, hex=True):
return arr


def mplcmap_to_palette(cmap, ncolors=None):
"""
Converts a matplotlib colormap to palette of RGB hex strings."
"""
import matplotlib.cm as cm
colormap = cm.get_cmap(cmap) #choose any matplotlib colormap here
if ncolors:
return [rgb2hex(colormap(i)) for i in np.linspace(0, 1, ncolors)]
return [rgb2hex(m) for m in colormap(np.arange(colormap.N))]


def bokeh_palette_to_palette(cmap, ncolors=None):
from bokeh import palettes
# Process as bokeh palette
palette = getattr(palettes, cmap, None)
if palette is None:
raise ValueError("Supplied palette %s not found among bokeh palettes" % cmap)
elif isinstance(palette, dict):
if ncolors in palette:
palette = palette[ncolors]
else:
palette = sorted(palette.items())[-1][1]
if ncolors:
return [palette[i%len(palette)] for i in range(ncolors)]
return palette


def process_cmap(cmap, ncolors=None):
"""
Convert valid colormap specifications to a list of colors.
"""
if isinstance(cmap, Cycle):
palette = [rgb2hex(c) if isinstance(c, tuple) else c for c in cmap.values]
elif isinstance(cmap, list):
palette = cmap
elif isinstance(cmap, basestring):
try:
# Process as matplotlib colormap
palette = mplcmap_to_palette(cmap, ncolors)
except:
try:
palette = bokeh_palette_to_palette(cmap, ncolors)
except:
raise ValueError("Supplied cmap %s not found among "
"matplotlib or bokeh colormaps.")
else:
raise TypeError("cmap argument expects a list, Cycle or valid matplotlib "
"colormap or bokeh palette, found %s." % cmap)
if ncolors:
return [palette[i%len(palette)] for i in range(ncolors)]
return palette


def dim_axis_label(dimensions, separator=', '):
"""
Returns an axis label for one or more dimensions.
Expand Down
34 changes: 32 additions & 2 deletions tests/testplotutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@

from holoviews import NdOverlay, Overlay
from holoviews.core.spaces import DynamicMap
from holoviews.core.options import Store
from holoviews.core.options import Store, Cycle
from holoviews.element.comparison import ComparisonTestCase
from holoviews.element import Curve, Area, Points
from holoviews.plotting.util import compute_overlayable_zorders, get_min_distance
from holoviews.plotting.util import (
compute_overlayable_zorders, get_min_distance, process_cmap)
from holoviews.streams import PointerX

try:
Expand Down Expand Up @@ -303,6 +304,35 @@ def test_dynamic_compute_overlayable_zorders_three_deep_dynamic_layers_reduced_l
self.assertNotIn(curve, sources[2])




class TestPlotColorUtils(ComparisonTestCase):

def test_process_cmap_mpl(self):
colors = process_cmap('Greys', 3)
self.assertEqual(colors, ['#ffffff', '#959595', '#000000'])

def test_process_cmap_bokeh(self):
colors = process_cmap('Category20', 3)
self.assertEqual(colors, ['#1f77b4', '#aec7e8', '#ff7f0e'])

def test_process_cmap_list_cycle(self):
colors = process_cmap(['#ffffff', '#959595', '#000000'], 4)
self.assertEqual(colors, ['#ffffff', '#959595', '#000000', '#ffffff'])

def test_process_cmap_cycle(self):
colors = process_cmap(Cycle(values=['#ffffff', '#959595', '#000000']), 4)
self.assertEqual(colors, ['#ffffff', '#959595', '#000000', '#ffffff'])

def test_process_cmap_invalid_str(self):
with self.assertRaises(ValueError):
colors = process_cmap('NonexistentColorMap', 3)

def test_process_cmap_invalid_type(self):
with self.assertRaises(TypeError):
colors = process_cmap({'A', 'B', 'C'}, 3)


class TestPlotUtils(ComparisonTestCase):

def test_get_min_distance_float32_type(self):
Expand Down

0 comments on commit 184cf82

Please sign in to comment.