diff --git a/examples/reference/elements/bokeh/Graph.ipynb b/examples/reference/elements/bokeh/Graph.ipynb index 12c66e38d1..e4298646a5 100644 --- a/examples/reference/elements/bokeh/Graph.ipynb +++ b/examples/reference/elements/bokeh/Graph.ipynb @@ -90,7 +90,7 @@ "source": [ "#### Additional features\n", "\n", - "Next we will extend this example by supplying explicit edges:" + "Next we will extend this example by supplying explicit edges, node information and edge weights. By constructing the ``Nodes`` explicitly we can declare an additional value dimensions, which are revealed when hovering and/or can be mapped to the color by specifying the ``color_index``. We can also associate additional information with each edge by supplying a value dimension to the ``Graph`` itself, which we can map to a color using the ``edge_color_index``." ] }, { @@ -100,8 +100,10 @@ "outputs": [], "source": [ "# Node info\n", + "np.random.seed(7)\n", "x, y = simple_graph.nodes.array([0, 1]).T\n", "node_labels = ['Output']+['Input']*(N-1)\n", + "edge_weights = np.random.rand(8)\n", "\n", "# Compute edge paths\n", "def bezier(start, end, control, steps=np.linspace(0, 1, 100)):\n", @@ -114,10 +116,10 @@ "\n", "# Declare Graph\n", "nodes = hv.Nodes((x, y, node_indices, node_labels), vdims='Type')\n", - "graph = hv.Graph(((source, target), nodes, paths))\n", + "graph = hv.Graph(((source, target, edge_weights), nodes, paths), vdims='Weight')\n", "\n", - "graph.redim.range(**padding).opts(plot=dict(color_index='Type'),\n", - " style=dict(cmap=['blue', 'yellow']))" + "graph.redim.range(**padding).opts(plot=dict(color_index='Type', edge_color_index='Weight'),\n", + " style=dict(cmap=['blue', 'red'], edge_cmap='viridis'))" ] } ], diff --git a/examples/reference/elements/matplotlib/Graph.ipynb b/examples/reference/elements/matplotlib/Graph.ipynb index 437c132db2..4c2356fb65 100644 --- a/examples/reference/elements/matplotlib/Graph.ipynb +++ b/examples/reference/elements/matplotlib/Graph.ipynb @@ -90,7 +90,8 @@ "source": [ "#### Additional features\n", "\n", - "Next we will extend this example by supplying explicit edges:" + "\n", + "Next we will extend this example by supplying explicit edges, node information and edge weights. By constructing the ``Nodes`` explicitly we can declare an additional value dimensions, which are revealed when hovering and/or can be mapped to the color by specifying the ``color_index``. We can also associate additional information with each edge by supplying a value dimension to the ``Graph`` itself, which we can map to a color using the ``edge_color_index``." ] }, { @@ -99,11 +100,11 @@ "metadata": {}, "outputs": [], "source": [ - "from matplotlib.colors import ListedColormap\n", - "\n", "# Node info\n", + "np.random.seed(7)\n", "x, y = simple_graph.nodes.array([0, 1]).T\n", "node_labels = ['Output']+['Input']*(N-1)\n", + "edge_weights = np.random.rand(8)\n", "\n", "# Compute edge paths\n", "def bezier(start, end, control, steps=np.linspace(0, 1, 100)):\n", @@ -116,10 +117,10 @@ "\n", "# Declare Graph\n", "nodes = hv.Nodes((x, y, node_indices, node_labels), vdims='Type')\n", - "graph = hv.Graph(((source, target), nodes, paths))\n", + "graph = hv.Graph(((source, target, edge_weights), nodes, paths), vdims='Weight')\n", "\n", - "graph.redim.range(**padding).opts(plot=dict(color_index='Type'),\n", - " style=dict(cmap=ListedColormap(['blue', 'yellow'])))" + "graph.redim.range(**padding).opts(plot=dict(color_index='Type', edge_color_index='Weight'),\n", + " style=dict(cmap=['blue', 'red'], edge_cmap='viridis'))" ] } ], diff --git a/examples/user_guide/Network_Graphs.ipynb b/examples/user_guide/Network_Graphs.ipynb index 3e85143d72..168e89b091 100644 --- a/examples/user_guide/Network_Graphs.ipynb +++ b/examples/user_guide/Network_Graphs.ipynb @@ -43,8 +43,8 @@ "source": [ "# Declare abstract edges\n", "N = 8\n", - "node_indices = np.arange(N)\n", - "source = np.zeros(N)\n", + "node_indices = np.arange(N, dtype=np.int32)\n", + "source = np.zeros(N, dtype=np.int32)\n", "target = node_indices\n", "\n", "padding = dict(x=(-1.2, 1.2), y=(-1.2, 1.2))\n", @@ -148,7 +148,7 @@ "source": [ "#### Additional information\n", "\n", - "We can also associate additional information with the nodes and edges of a graph. By constructing the ``Nodes`` explicitly we can declare an additional value dimensions, which are revealed when hovering and/or can be mapped to the color by specifying the ``color_index``. We can also associate additional information with each edge by supplying a value dimension to the ``Graph`` itself." + "We can also associate additional information with the nodes and edges of a graph. By constructing the ``Nodes`` explicitly we can declare an additional value dimensions, which are revealed when hovering and/or can be mapped to the color by specifying the ``color_index``. We can also associate additional information with each edge by supplying a value dimension to the ``Graph`` itself, which we can map to a color using the ``edge_color_index``." ] }, { @@ -157,12 +157,13 @@ "metadata": {}, "outputs": [], "source": [ - "%%opts Graph [color_index='Type'] (cmap='Set1')\n", + "%%opts Graph [color_index='Type' edge_color_index='Weight'] (cmap='Set1' edge_cmap='viridis')\n", "node_labels = ['Output']+['Input']*(N-1)\n", - "edge_labels = list('ABCDEFGH')\n", + "np.random.seed(7)\n", + "edge_labels = np.random.rand(8)\n", "\n", "nodes = hv.Nodes((x, y, node_indices, node_labels), vdims='Type')\n", - "graph = hv.Graph(((source, target, edge_labels), nodes, paths), vdims='Label').redim.range(**padding)\n", + "graph = hv.Graph(((source, target, edge_labels), nodes, paths), vdims='Weight').redim.range(**padding)\n", "graph + graph.opts(plot=dict(inspection_policy='edges'))" ] }, diff --git a/holoviews/core/util.py b/holoviews/core/util.py index 9e8e9d9c3a..361fe10286 100644 --- a/holoviews/core/util.py +++ b/holoviews/core/util.py @@ -1569,3 +1569,12 @@ def dt_to_int(value, time_unit='us'): except: # Handle python2 return (time.mktime(value.timetuple()) + value.microsecond / 1e6) * tscale + + +def search_indices(values, source): + """ + Given a set of values returns the indices of each of those values + in the source array. + """ + orig_indices = source.argsort() + return orig_indices[np.searchsorted(source[orig_indices], values)] diff --git a/holoviews/plotting/bokeh/element.py b/holoviews/plotting/bokeh/element.py index ca0258ffeb..271e22ddee 100644 --- a/holoviews/plotting/bokeh/element.py +++ b/holoviews/plotting/bokeh/element.py @@ -5,7 +5,6 @@ import numpy as np import bokeh import bokeh.plotting -from bokeh import palettes from bokeh.core.properties import value from bokeh.models import (HoverTool, Renderer, Range1d, DataRange1d, Title, FactorRange, FuncTickFormatter, Tool, Legend) @@ -20,7 +19,7 @@ from bokeh.plotting.helpers import _known_tools as known_tools from ...core import DynamicMap, CompositeOverlay, Element, Dimension -from ...core.options import abbreviated_exception, SkipRendering, Cycle +from ...core.options import abbreviated_exception, SkipRendering from ...core import util from ...streams import Stream, Buffer from ..plot import GenericElementPlot, GenericOverlayPlot @@ -1102,7 +1101,7 @@ def _draw_colorbar(self, plot, color_mapper): def _get_colormapper(self, dim, element, ranges, style, factors=None, colors=None, - cycle=None, name='color_mapper'): + name='color_mapper'): # The initial colormapper instance is cached the first time # and then only updated if dim is None and colors is None: @@ -1124,7 +1123,7 @@ def _get_colormapper(self, dim, element, ranges, style, factors=None, colors=Non else: low, high = None, None - cmap = colors or cycle or style.pop('cmap', 'viridis') + cmap = colors or style.pop('cmap', 'viridis') palette = process_cmap(cmap, ncolors) nan_colors = {k: rgba_tuple(v) for k, v in self.clipping_colors.items()} colormapper, opts = self._get_cmapper_opts(low, high, factors, nan_colors) @@ -1145,7 +1144,7 @@ def _get_colormapper(self, dim, element, ranges, style, factors=None, colors=Non def _get_color_data(self, element, ranges, style, name='color', factors=None, colors=None, - cycle=None, int_categories=False): + int_categories=False): data, mapping = {}, {} cdim = element.get_dimension(self.color_index) if not cdim: @@ -1162,7 +1161,7 @@ def _get_color_data(self, element, ranges, style, name='color', factors=None, co factors = [str(f) for f in factors] mapper = self._get_colormapper(cdim, element, ranges, style, - factors, colors, cycle) + factors, colors) data[field] = cdata if factors is not None: mapping['legend'] = {'field': field} diff --git a/holoviews/plotting/bokeh/graphs.py b/holoviews/plotting/bokeh/graphs.py index 5dbe918902..83a389bf4c 100644 --- a/holoviews/plotting/bokeh/graphs.py +++ b/holoviews/plotting/bokeh/graphs.py @@ -1,17 +1,14 @@ import param import numpy as np +from bokeh.models import HoverTool, ColumnDataSource +from bokeh.models import (StaticLayoutProvider, NodesAndLinkedEdges, + EdgesAndLinkedNodes) -from bokeh.models import Range1d, HoverTool, ColumnDataSource - -try: - from bokeh.models import (StaticLayoutProvider, NodesAndLinkedEdges, - EdgesAndLinkedNodes) -except: - pass - -from ...core.util import basestring, dimension_sanitizer +from ...core.util import basestring, dimension_sanitizer, unique_array +from ...core.options import Cycle from .chart import ColorbarPlot, PointPlot from .element import CompositeElementPlot, LegendPlot, line_properties, fill_properties +from ..util import process_cmap class GraphPlot(CompositeElementPlot, ColorbarPlot, LegendPlot): @@ -20,6 +17,10 @@ class GraphPlot(CompositeElementPlot, ColorbarPlot, LegendPlot): allow_None=True, doc=""" Index of the dimension from which the color will the drawn""") + edge_color_index = param.ClassSelector(default=None, class_=(basestring, int), + allow_None=True, doc=""" + Index of the dimension from which the color will the drawn""") + selection_policy = param.ObjectSelector(default='nodes', objects=['edges', 'nodes', None], doc=""" Determines policy for inspection of graph components, i.e. whether to highlight nodes or edges when selecting connected edges and nodes respectively.""") @@ -31,35 +32,23 @@ class GraphPlot(CompositeElementPlot, ColorbarPlot, LegendPlot): tools = param.List(default=['hover', 'tap'], doc=""" A list of plugin tools to use on the plot.""") - # X-axis is categorical - _x_range_type = Range1d - - # Declare that y-range should auto-range if not bounded - _y_range_type = Range1d - - # Map each glyph to a style group + # Map each glyph to a style group _style_groups = {'scatter': 'node', 'multi_line': 'edge'} style_opts = (['edge_'+p for p in line_properties] +\ - ['node_'+p for p in fill_properties+line_properties]+['node_size', 'cmap']) + ['node_'+p for p in fill_properties+line_properties]+['node_size', 'cmap', 'edge_cmap']) def _hover_opts(self, element): if self.inspection_policy == 'nodes': dims = element.nodes.dimensions() dims = [(dims[2].pprint_label, '@{index_hover}')]+dims[3:] elif self.inspection_policy == 'edges': - kdims = [(kd.pprint_label, '@{%s}' % ref) - for kd, ref in zip(element.kdims, ['start', 'end'])] - dims = kdims+element.vdims + dims = element.kdims+element.vdims else: dims = [] return dims, {} def get_extents(self, element, ranges): - """ - Extents are set to '' and None because x-axis is categorical and - y-axis auto-ranges. - """ xdim, ydim = element.nodes.kdims[:2] x0, x1 = ranges[xdim.name] y0, y1 = ranges[ydim.name] @@ -73,26 +62,73 @@ def _get_axis_labels(self, *args, **kwargs): xlabel, ylabel = [kd.pprint_label for kd in element.nodes.kdims[:2]] return xlabel, ylabel, None + + def _get_edge_colors(self, element, ranges, edge_data, edge_mapping, style): + cdim = element.get_dimension(self.edge_color_index) + if not cdim: + return + elstyle = self.lookup_options(element, 'style') + cycle = elstyle.kwargs.get('edge_color') + + idx = element.get_dimension_index(cdim) + field = dimension_sanitizer(cdim.name) + cvals = element.dimension_values(cdim) + if idx in [0, 1]: + factors = element.nodes.dimension_values(2, expanded=False) + elif idx == 2 and cvals.dtype.kind in 'if': + factors = None + else: + factors = unique_array(cvals) + + default_cmap = 'viridis' if factors is None else 'tab20' + cmap = style.get('edge_cmap', style.get('cmap', default_cmap)) + if factors is None or factors.dtype.kind == 'f' and idx not in [0, 1]: + colors, factors = None, None + else: + if factors.dtype.kind not in 'SU': + field += '_str' + cvals = [str(f) for f in cvals] + factors = (str(f) for f in factors) + factors = list(factors) + colors = process_cmap(cycle or cmap, len(factors)) + + if field not in edge_data: + edge_data[field] = cvals + edge_style = dict(style, cmap=cmap) + mapper = self._get_colormapper(cdim, element, ranges, edge_style, + factors, colors, 'edge_colormapper') + transform = {'field': field, 'transform': mapper} + edge_mapping['edge_line_color'] = transform + edge_mapping['edge_nonselection_line_color'] = transform + edge_mapping['edge_selection_line_color'] = transform + + def get_data(self, element, ranges, style): xidx, yidx = (1, 0) if self.invert_axes else (0, 1) # Get node data nodes = element.nodes.dimension_values(2) - node_positions = element.nodes.array([0, 1, 2]) + node_positions = element.nodes.array([0, 1]) # Map node indices to integers - if nodes.dtype.kind != 'i': + if nodes.dtype.kind not in 'if': node_indices = {v: i for i, v in enumerate(nodes)} index = np.array([node_indices[n] for n in nodes], dtype=np.int32) - layout = {node_indices[z]: (y, x) if self.invert_axes else (x, y) - for x, y, z in node_positions} + layout = {str(node_indices[k]): (y, x) if self.invert_axes else (x, y) + for k, (x, y) in zip(nodes, node_positions)} else: index = nodes.astype(np.int32) - layout = {z: (y, x) if self.invert_axes else (x, y) - for x, y, z in node_positions} + layout = {str(k): (y, x) if self.invert_axes else (x, y) + for k, (x, y) in zip(index, node_positions)} point_data = {'index': index} - cdata, cmapping = self._get_color_data(element.nodes, ranges, style, 'node_fill_color') + cycle = self.lookup_options(element, 'style').kwargs.get('node_color') + colors = cycle if isinstance(cycle, Cycle) else None + cdata, cmapping = self._get_color_data( + element.nodes, ranges, style, name='node_fill_color', + colors=colors, int_categories=True + ) point_data.update(cdata) point_mapping = cmapping + edge_mapping = {} if 'node_fill_color' in point_mapping: style = {k: v for k, v in style.items() if k not in ['node_fill_color', 'node_nonselection_fill_color']} @@ -101,10 +137,13 @@ def get_data(self, element, ranges, style): # Get edge data nan_node = index.max()+1 start, end = (element.dimension_values(i) for i in range(2)) - if nodes.dtype.kind not in 'if': + if nodes.dtype.kind == 'f': + start, end = start.astype(np.int32), end.astype(np.int32) + elif nodes.dtype.kind != 'i': start = np.array([node_indices.get(x, nan_node) for x in start], dtype=np.int32) end = np.array([node_indices.get(y, nan_node) for y in end], dtype=np.int32) path_data = dict(start=start, end=end) + self._get_edge_colors(element, ranges, path_data, edge_mapping, style) if element._edgepaths and not self.static_source: edges = element._split_edgepaths.split(datatype='array', dimensions=element.edgepaths.kdims) if len(edges) == len(start): @@ -122,11 +161,10 @@ def get_data(self, element, ranges, style): for d in element.nodes.dimensions()[3:]: point_data[dimension_sanitizer(d.name)] = element.nodes.dimension_values(d) elif self.inspection_policy == 'edges': - for d in element.vdims: + for d in element.dimensions(): path_data[dimension_sanitizer(d.name)] = element.dimension_values(d) - data = {'scatter_1': point_data, 'multi_line_1': path_data, 'layout': layout} - mapping = {'scatter_1': point_mapping, 'multi_line_1': {}} + mapping = {'scatter_1': point_mapping, 'multi_line_1': edge_mapping} return data, mapping, style @@ -145,19 +183,27 @@ def _init_glyphs(self, plot, element, ranges, source): style = self.style[self.cyclic_index] data, mapping, style = self.get_data(element, ranges, style) self.handles['previous_id'] = element._plot_id + properties = {} mappings = {} - for key in mapping: - source = self._init_datasource(data.get(key, {})) + for key in list(mapping): + if not any(glyph in key for glyph in ('scatter_1', 'multi_line_1')): + continue + source = self._init_datasource(data.pop(key, {})) self.handles[key+'_source'] = source glyph_props = self._glyph_properties(plot, element, source, ranges, style) properties.update(glyph_props) - mappings.update(mapping.get(key, {})) + mappings.update(mapping.pop(key, {})) properties = {p: v for p, v in properties.items() if p not in ('legend', 'source')} properties.update(mappings) + layout = data.pop('layout', {}) + if data and mapping: + CompositeElementPlot._init_glyphs(self, plot, element, ranges, source, + data, mapping, style) + # Define static layout - layout = StaticLayoutProvider(graph_layout=data['layout']) + layout = StaticLayoutProvider(graph_layout=layout) node_source = self.handles['scatter_1_source'] edge_source = self.handles['multi_line_1_source'] renderer = plot.graph(node_source, edge_source, layout, **properties) @@ -187,6 +233,7 @@ def _init_glyphs(self, plot, element, ranges, source): self.handles['hover'].renderers.append(renderer) + class NodePlot(PointPlot): """ Simple subclass of PointPlot which hides x, y position on hover. diff --git a/holoviews/plotting/mpl/element.py b/holoviews/plotting/mpl/element.py index ebcf6b226b..a564ef102b 100644 --- a/holoviews/plotting/mpl/element.py +++ b/holoviews/plotting/mpl/element.py @@ -697,7 +697,7 @@ def _norm_kwargs(self, element, ranges, opts, vdim, prefix=''): self._cbar_extend = 'max' # Define special out-of-range colors on colormap - cmap = opts.get('cmap') + cmap = opts.get(prefix+'cmap') if isinstance(cmap, list): cmap = mpl_colors.ListedColormap(cmap) elif isinstance(cmap, util.basestring): diff --git a/holoviews/plotting/mpl/graphs.py b/holoviews/plotting/mpl/graphs.py index ac9b176800..7732667402 100644 --- a/holoviews/plotting/mpl/graphs.py +++ b/holoviews/plotting/mpl/graphs.py @@ -3,47 +3,89 @@ from matplotlib.collections import LineCollection -from ...core.util import basestring +from ...core.options import Cycle +from ...core.util import basestring, unique_array, search_indices +from ..util import process_cmap from .element import ColorbarPlot + class GraphPlot(ColorbarPlot): - """ - GraphPlot - """ color_index = param.ClassSelector(default=None, class_=(basestring, int), allow_None=True, doc=""" Index of the dimension from which the color will the drawn""") + edge_color_index = param.ClassSelector(default=None, class_=(basestring, int), + allow_None=True, doc=""" + Index of the dimension from which the color will the drawn""") + style_opts = ['edge_alpha', 'edge_color', 'edge_linestyle', 'edge_linewidth', 'node_alpha', 'node_color', 'node_edgecolors', 'node_facecolors', - 'node_linewidth', 'node_marker', 'node_size', 'visible', 'cmap'] + 'node_linewidth', 'node_marker', 'node_size', 'visible', 'cmap', + 'edge_cmap'] + + _style_groups = ['node', 'edge'] def _compute_styles(self, element, ranges, style): - cdim = element.get_dimension(self.color_index) - color = style.pop('node_color', None) - cmap = style.get('cmap', None) - if cdim and cmap: - cs = element.dimension_values(self.color_index) + elstyle = self.lookup_options(element, 'style') + color = elstyle.kwargs.get('node_color') + cdim = element.nodes.get_dimension(self.color_index) + cmap = elstyle.kwargs.get('cmap', 'tab20') + if cdim: + cs = element.nodes.dimension_values(self.color_index) # Check if numeric otherwise treat as categorical - if cs.dtype.kind in 'if': + if cs.dtype.kind == 'f': style['c'] = cs else: - categories = np.unique(cs) - xsorted = np.argsort(categories) - ypos = np.searchsorted(categories[xsorted], cs) - style['c'] = xsorted[ypos] - self._norm_kwargs(element, ranges, style, cdim) + factors = unique_array(cs) + cmap = color if isinstance(color, Cycle) else cmap + colors = process_cmap(cmap, len(factors)) + cs = search_indices(cs, factors) + style['node_facecolors'] = [colors[v%len(colors)] for v in cs] + style.pop('node_color', None) + if 'c' in style: + self._norm_kwargs(element.nodes, ranges, style, cdim) elif color: - style['c'] = color + style['c'] = style.pop('node_color') style['node_edgecolors'] = style.pop('node_edgecolors', 'none') + + edge_cdim = element.get_dimension(self.edge_color_index) + if not edge_cdim: + return style + + elstyle = self.lookup_options(element, 'style') + cycle = elstyle.kwargs.get('edge_color') + idx = element.get_dimension_index(edge_cdim) + cvals = element.dimension_values(edge_cdim) + if idx in [0, 1]: + factors = element.nodes.dimension_values(2, expanded=False) + elif idx == 2 and cvals.dtype.kind in 'if': + factors = None + else: + factors = unique_array(cvals) + if factors is None or factors.dtype.kind == 'f': + style['edge_array'] = cvals + else: + cvals = search_indices(cvals, factors) + factors = list(factors) + cmap = elstyle.kwargs.get('edge_cmap', 'tab20') + cmap = cycle if isinstance(cycle, Cycle) else cmap + colors = process_cmap(cmap, len(factors)) + style['edge_colors'] = [colors[v%len(colors)] for v in cvals] + style.pop('edge_color', None) + if 'edge_array' in style: + self._norm_kwargs(element, ranges, style, edge_cdim, 'edge_') + else: + style.pop('edge_cmap', None) + if 'edge_vmin' in style: + style['edge_clim'] = (style.pop('edge_vmin'), style.pop('edge_vmax')) return style def get_data(self, element, ranges, style): xidx, yidx = (1, 0) if self.invert_axes else (0, 1) pxs, pys = (element.nodes.dimension_values(i) for i in range(2)) dims = element.nodes.dimensions() - self._compute_styles(element.nodes, ranges, style) + self._compute_styles(element, ranges, style) paths = element.edgepaths.split(datatype='array', dimensions=element.edgepaths.kdims) if self.invert_axes: @@ -62,18 +104,21 @@ def get_extents(self, element, ranges): def init_artists(self, ax, plot_args, plot_kwargs): # Draw edges color_opts = ['c', 'cmap', 'vmin', 'vmax', 'norm'] + groups = [g for g in self._style_groups if g != 'edge'] edge_opts = {k[5:] if 'edge_' in k else k: v for k, v in plot_kwargs.items() - if 'node_' not in k and k not in color_opts} + if not any(k.startswith(p) for p in groups) + and k not in color_opts} paths = plot_args['edges'] edges = LineCollection(paths, **edge_opts) ax.add_collection(edges) # Draw nodes xs, ys = plot_args['nodes'] + groups = [g for g in self._style_groups if g != 'node'] node_opts = {k[5:] if 'node_' in k else k: v for k, v in plot_kwargs.items() - if 'edge_' not in k} + if not any(k.startswith(p) for p in groups)} if 'size' in node_opts: node_opts['s'] = node_opts.pop('size')**2 nodes = ax.scatter(xs, ys, **node_opts) @@ -95,5 +140,10 @@ def update_handles(self, key, axis, element, ranges, style): paths = data['edges'] edges.set_paths(paths) edges.set_visible(style.get('visible', True)) - + cdim = element.get_dimension(self.edge_color_index) + if cdim and 'edge_c' in edges: + edges.set_clim((style['edge_vmin'], style['edge_vmax'])) + edges.set_array(style['edge_c']) + if 'norm' in style: + edges.norm = style['edge_norm'] return axis_kwargs diff --git a/tests/testbokehgraphs.py b/tests/testbokehgraphs.py index 8836e85335..e685c83b5f 100644 --- a/tests/testbokehgraphs.py +++ b/tests/testbokehgraphs.py @@ -13,7 +13,7 @@ from holoviews.plotting.bokeh.util import bokeh_version bokeh_renderer = Store.renderers['bokeh'] from bokeh.models import (NodesAndLinkedEdges, EdgesAndLinkedNodes) - from bokeh.models.mappers import CategoricalColorMapper + from bokeh.models.mappers import CategoricalColorMapper, LinearColorMapper except : bokeh_renderer = None @@ -24,20 +24,22 @@ class BokehGraphPlotTests(ComparisonTestCase): def setUp(self): if not bokeh_renderer: raise SkipTest("Bokeh required to test plot instantiation") - elif bokeh_version < str('0.12.9'): - raise SkipTest("Bokeh >= 0.12.9 required to test graphs") self.previous_backend = Store.current_backend Store.current_backend = 'bokeh' self.default_comm = bokeh_renderer.comms['default'] N = 8 - self.nodes = circular_layout(np.arange(N)) - self.source = np.arange(N) - self.target = np.zeros(N) + self.nodes = circular_layout(np.arange(N, dtype=np.int32)) + self.source = np.arange(N, dtype=np.int32) + self.target = np.zeros(N, dtype=np.int32) + self.weights = np.random.rand(N) self.graph = Graph(((self.source, self.target),)) self.node_info = Dataset(['Output']+['Input']*(N-1), vdims=['Label']) + self.node_info2 = Dataset(self.weights, vdims='Weight') self.graph2 = Graph(((self.source, self.target), self.node_info)) - + self.graph3 = Graph(((self.source, self.target), self.node_info2)) + self.graph4 = Graph(((self.source, self.target, self.weights),), vdims='Weight') + def tearDown(self): Store.current_backend = self.previous_backend bokeh_renderer.comms['default'] = self.default_comm @@ -50,7 +52,7 @@ def test_plot_simple_graph(self): self.assertEqual(node_source.data['index'], self.source) self.assertEqual(edge_source.data['start'], self.source) self.assertEqual(edge_source.data['end'], self.target) - layout = {z: (x, y) for x, y, z in self.graph.nodes.array()} + layout = {str(int(z)): (x, y) for x, y, z in self.graph.nodes.array()} self.assertEqual(layout_source.graph_layout, layout) def test_plot_graph_with_paths(self): @@ -65,7 +67,7 @@ def test_plot_graph_with_paths(self): edges = graph.edgepaths.split() self.assertEqual(edge_source.data['xs'], [path.dimension_values(0) for path in edges]) self.assertEqual(edge_source.data['ys'], [path.dimension_values(1) for path in edges]) - layout = {z: (x, y) for x, y, z in self.graph.nodes.array()} + layout = {str(int(z)): (x, y) for x, y, z in self.graph.nodes.array()} self.assertEqual(layout_source.graph_layout, layout) def test_graph_inspection_policy_nodes(self): @@ -90,7 +92,7 @@ def test_graph_inspection_policy_edges_non_default_names(self): renderer = plot.handles['glyph_renderer'] hover = plot.handles['hover'] self.assertIsInstance(renderer.inspection_policy, EdgesAndLinkedNodes) - self.assertEqual(hover.tooltips, [('source', '@{start}'), ('target', '@{end}')]) + self.assertEqual(hover.tooltips, [('source', '@{source}'), ('target', '@{target}')]) self.assertIn(renderer, hover.renderers) def test_graph_inspection_policy_none(self): @@ -119,13 +121,52 @@ def test_graph_selection_policy_none(self): hover = plot.handles['hover'] self.assertIs(renderer.selection_policy, None) - def test_graph_nodes_colormapped(self): + def test_graph_nodes_categorical_colormapped(self): g = self.graph2.opts(plot=dict(color_index='Label'), style=dict(cmap='Set1')) plot = bokeh_renderer.get_plot(g) cmapper = plot.handles['color_mapper'] node_source = plot.handles['scatter_1_source'] glyph = plot.handles['scatter_1_glyph'] self.assertIsInstance(cmapper, CategoricalColorMapper) - self.assertEqual(cmapper.factors, ['Input', 'Output']) + self.assertEqual(cmapper.factors, ['Output', 'Input']) self.assertEqual(node_source.data['Label'], self.node_info['Label']) self.assertEqual(glyph.fill_color, {'field': 'Label', 'transform': cmapper}) + + def test_graph_nodes_numerically_colormapped(self): + g = self.graph3.opts(plot=dict(color_index='Weight'), style=dict(cmap='viridis')) + plot = bokeh_renderer.get_plot(g) + cmapper = plot.handles['color_mapper'] + node_source = plot.handles['scatter_1_source'] + glyph = plot.handles['scatter_1_glyph'] + self.assertIsInstance(cmapper, LinearColorMapper) + self.assertEqual(cmapper.low, self.weights.min()) + self.assertEqual(cmapper.high, self.weights.max()) + self.assertEqual(node_source.data['Weight'], self.node_info2['Weight']) + self.assertEqual(glyph.fill_color, {'field': 'Weight', 'transform': cmapper}) + + def test_graph_edges_categorical_colormapped(self): + g = self.graph3.opts(plot=dict(edge_color_index='start'), + style=dict(edge_cmap=['#FFFFFF', '#000000'])) + plot = bokeh_renderer.get_plot(g) + cmapper = plot.handles['edge_colormapper'] + edge_source = plot.handles['multi_line_1_source'] + glyph = plot.handles['multi_line_1_glyph'] + self.assertIsInstance(cmapper, CategoricalColorMapper) + factors = ['0', '1', '2', '3', '4', '5', '6', '7'] + self.assertEqual(cmapper.factors, factors) + self.assertEqual(edge_source.data['start_str'], factors) + self.assertEqual(glyph.line_color, {'field': 'start_str', 'transform': cmapper}) + + def test_graph_nodes_numerically_colormapped(self): + g = self.graph4.opts(plot=dict(edge_color_index='Weight'), + style=dict(edge_cmap=['#FFFFFF', '#000000'])) + plot = bokeh_renderer.get_plot(g) + print(plot.handles) + cmapper = plot.handles['edge_colormapper'] + edge_source = plot.handles['multi_line_1_source'] + glyph = plot.handles['multi_line_1_glyph'] + self.assertIsInstance(cmapper, LinearColorMapper) + self.assertEqual(cmapper.low, self.weights.min()) + self.assertEqual(cmapper.high, self.weights.max()) + self.assertEqual(edge_source.data['Weight'], self.node_info2['Weight']) + self.assertEqual(glyph.line_color, {'field': 'Weight', 'transform': cmapper}) diff --git a/tests/testmplgraphs.py b/tests/testmplgraphs.py index 4a846f0258..8a27b67b4a 100644 --- a/tests/testmplgraphs.py +++ b/tests/testmplgraphs.py @@ -30,12 +30,17 @@ def setUp(self): mpl_renderer.comms['default'] = (comms.Comm, '') N = 8 - self.nodes = circular_layout(np.arange(N)) - self.source = np.arange(N) - self.target = np.zeros(N) + self.nodes = circular_layout(np.arange(N, dtype=np.int32)) + self.source = np.arange(N, dtype=np.int32) + self.target = np.zeros(N, dtype=np.int32) + self.weights = np.random.rand(N) self.graph = Graph(((self.source, self.target),)) self.node_info = Dataset(['Output']+['Input']*(N-1), vdims=['Label']) + self.node_info2 = Dataset(self.weights, vdims='Weight') self.graph2 = Graph(((self.source, self.target), self.node_info)) + self.graph3 = Graph(((self.source, self.target), self.node_info2)) + self.graph4 = Graph(((self.source, self.target, self.weights),), vdims='Weight') + def tearDown(self): mpl_renderer.comms['default'] = self.default_comm @@ -49,12 +54,46 @@ def test_plot_simple_graph(self): self.assertEqual([p.vertices for p in edges.get_paths()], [p.array() for p in self.graph.edgepaths.split()]) - def test_plot_graph_colored_nodes(self): + def test_plot_graph_categorical_colored_nodes(self): g = self.graph2.opts(plot=dict(color_index='Label'), style=dict(cmap='Set1')) plot = mpl_renderer.get_plot(g) nodes = plot.handles['nodes'] + facecolors = np.array([[0.89411765, 0.10196078, 0.10980392, 1.], + [0.6 , 0.6 , 0.6 , 1.], + [0.6 , 0.6 , 0.6 , 1.], + [0.6 , 0.6 , 0.6 , 1.], + [0.6 , 0.6 , 0.6 , 1.], + [0.6 , 0.6 , 0.6 , 1.], + [0.6 , 0.6 , 0.6 , 1.], + [0.6 , 0.6 , 0.6 , 1.]]) + self.assertEqual(nodes.get_facecolors(), facecolors) + + def test_plot_graph_numerically_colored_nodes(self): + g = self.graph3.opts(plot=dict(color_index='Weight'), style=dict(cmap='viridis')) + plot = mpl_renderer.get_plot(g) + nodes = plot.handles['nodes'] + self.assertEqual(nodes.get_array(), self.weights) + self.assertEqual(nodes.get_clim(), (self.weights.min(), self.weights.max())) + + def test_plot_graph_categorical_colored_edges(self): + g = self.graph3.opts(plot=dict(edge_color_index='start'), + style=dict(edge_cmap=['#FFFFFF', '#000000'])) + plot = mpl_renderer.get_plot(g) edges = plot.handles['edges'] - self.assertEqual(nodes.get_offsets(), self.graph.nodes.array([0, 1])) - self.assertEqual([p.vertices for p in edges.get_paths()], - [p.array() for p in self.graph.edgepaths.split()]) - self.assertEqual(nodes.get_array(), np.array([1, 0, 0, 0, 0, 0, 0, 0])) + colors = np.array([[1., 1., 1., 1.], + [0., 0., 0., 1.], + [1., 1., 1., 1.], + [0., 0., 0., 1.], + [1., 1., 1., 1.], + [0., 0., 0., 1.], + [1., 1., 1., 1.], + [0., 0., 0., 1.]]) + self.assertEqual(edges.get_colors(), colors) + + def test_plot_graph_numerically_colored_edges(self): + g = self.graph4.opts(plot=dict(edge_color_index='Weight'), + style=dict(edge_cmap=['#FFFFFF', '#000000'])) + plot = mpl_renderer.get_plot(g) + edges = plot.handles['edges'] + self.assertEqual(edges.get_array(), self.weights) + self.assertEqual(edges.get_clim(), (self.weights.min(), self.weights.max()))