From 4c1b01d3b34d30f44ace3cc69d864e44aea93dd7 Mon Sep 17 00:00:00 2001 From: Philipp Rudiger Date: Thu, 23 Nov 2017 19:41:58 +0000 Subject: [PATCH] Improved Graph element (#2145) --- examples/reference/elements/bokeh/Graph.ipynb | 10 +- .../reference/elements/matplotlib/Graph.ipynb | 13 +- examples/user_guide/Network_Graphs.ipynb | 13 +- holoviews/core/util.py | 9 ++ holoviews/element/graphs.py | 91 +++++++++---- holoviews/element/util.py | 60 ++++++++ holoviews/plotting/bokeh/__init__.py | 9 +- holoviews/plotting/bokeh/element.py | 74 +++++----- holoviews/plotting/bokeh/graphs.py | 128 ++++++++++++------ holoviews/plotting/bokeh/util.py | 13 -- holoviews/plotting/mpl/element.py | 18 +-- holoviews/plotting/mpl/graphs.py | 92 ++++++++++--- holoviews/plotting/util.py | 56 +++++++- tests/testbokehgraphs.py | 64 +++++++-- tests/testgraphelement.py | 48 ++++++- tests/testmplgraphs.py | 55 ++++++-- tests/testplotutils.py | 34 ++++- 17 files changed, 592 insertions(+), 195 deletions(-) diff --git a/examples/reference/elements/bokeh/Graph.ipynb b/examples/reference/elements/bokeh/Graph.ipynb index 12c66e38d1..e4298646a5 100644 --- a/examples/reference/elements/bokeh/Graph.ipynb +++ b/examples/reference/elements/bokeh/Graph.ipynb @@ -90,7 +90,7 @@ "source": [ "#### Additional features\n", "\n", - "Next we will extend this example by supplying explicit edges:" + "Next we will extend this example by supplying explicit edges, node information and edge weights. By constructing the ``Nodes`` explicitly we can declare an additional value dimensions, which are revealed when hovering and/or can be mapped to the color by specifying the ``color_index``. We can also associate additional information with each edge by supplying a value dimension to the ``Graph`` itself, which we can map to a color using the ``edge_color_index``." ] }, { @@ -100,8 +100,10 @@ "outputs": [], "source": [ "# Node info\n", + "np.random.seed(7)\n", "x, y = simple_graph.nodes.array([0, 1]).T\n", "node_labels = ['Output']+['Input']*(N-1)\n", + "edge_weights = np.random.rand(8)\n", "\n", "# Compute edge paths\n", "def bezier(start, end, control, steps=np.linspace(0, 1, 100)):\n", @@ -114,10 +116,10 @@ "\n", "# Declare Graph\n", "nodes = hv.Nodes((x, y, node_indices, node_labels), vdims='Type')\n", - "graph = hv.Graph(((source, target), nodes, paths))\n", + "graph = hv.Graph(((source, target, edge_weights), nodes, paths), vdims='Weight')\n", "\n", - "graph.redim.range(**padding).opts(plot=dict(color_index='Type'),\n", - " style=dict(cmap=['blue', 'yellow']))" + "graph.redim.range(**padding).opts(plot=dict(color_index='Type', edge_color_index='Weight'),\n", + " style=dict(cmap=['blue', 'red'], edge_cmap='viridis'))" ] } ], diff --git a/examples/reference/elements/matplotlib/Graph.ipynb b/examples/reference/elements/matplotlib/Graph.ipynb index 437c132db2..4c2356fb65 100644 --- a/examples/reference/elements/matplotlib/Graph.ipynb +++ b/examples/reference/elements/matplotlib/Graph.ipynb @@ -90,7 +90,8 @@ "source": [ "#### Additional features\n", "\n", - "Next we will extend this example by supplying explicit edges:" + "\n", + "Next we will extend this example by supplying explicit edges, node information and edge weights. By constructing the ``Nodes`` explicitly we can declare an additional value dimensions, which are revealed when hovering and/or can be mapped to the color by specifying the ``color_index``. We can also associate additional information with each edge by supplying a value dimension to the ``Graph`` itself, which we can map to a color using the ``edge_color_index``." ] }, { @@ -99,11 +100,11 @@ "metadata": {}, "outputs": [], "source": [ - "from matplotlib.colors import ListedColormap\n", - "\n", "# Node info\n", + "np.random.seed(7)\n", "x, y = simple_graph.nodes.array([0, 1]).T\n", "node_labels = ['Output']+['Input']*(N-1)\n", + "edge_weights = np.random.rand(8)\n", "\n", "# Compute edge paths\n", "def bezier(start, end, control, steps=np.linspace(0, 1, 100)):\n", @@ -116,10 +117,10 @@ "\n", "# Declare Graph\n", "nodes = hv.Nodes((x, y, node_indices, node_labels), vdims='Type')\n", - "graph = hv.Graph(((source, target), nodes, paths))\n", + "graph = hv.Graph(((source, target, edge_weights), nodes, paths), vdims='Weight')\n", "\n", - "graph.redim.range(**padding).opts(plot=dict(color_index='Type'),\n", - " style=dict(cmap=ListedColormap(['blue', 'yellow'])))" + "graph.redim.range(**padding).opts(plot=dict(color_index='Type', edge_color_index='Weight'),\n", + " style=dict(cmap=['blue', 'red'], edge_cmap='viridis'))" ] } ], diff --git a/examples/user_guide/Network_Graphs.ipynb b/examples/user_guide/Network_Graphs.ipynb index 3e85143d72..168e89b091 100644 --- a/examples/user_guide/Network_Graphs.ipynb +++ b/examples/user_guide/Network_Graphs.ipynb @@ -43,8 +43,8 @@ "source": [ "# Declare abstract edges\n", "N = 8\n", - "node_indices = np.arange(N)\n", - "source = np.zeros(N)\n", + "node_indices = np.arange(N, dtype=np.int32)\n", + "source = np.zeros(N, dtype=np.int32)\n", "target = node_indices\n", "\n", "padding = dict(x=(-1.2, 1.2), y=(-1.2, 1.2))\n", @@ -148,7 +148,7 @@ "source": [ "#### Additional information\n", "\n", - "We can also associate additional information with the nodes and edges of a graph. By constructing the ``Nodes`` explicitly we can declare an additional value dimensions, which are revealed when hovering and/or can be mapped to the color by specifying the ``color_index``. We can also associate additional information with each edge by supplying a value dimension to the ``Graph`` itself." + "We can also associate additional information with the nodes and edges of a graph. By constructing the ``Nodes`` explicitly we can declare an additional value dimensions, which are revealed when hovering and/or can be mapped to the color by specifying the ``color_index``. We can also associate additional information with each edge by supplying a value dimension to the ``Graph`` itself, which we can map to a color using the ``edge_color_index``." ] }, { @@ -157,12 +157,13 @@ "metadata": {}, "outputs": [], "source": [ - "%%opts Graph [color_index='Type'] (cmap='Set1')\n", + "%%opts Graph [color_index='Type' edge_color_index='Weight'] (cmap='Set1' edge_cmap='viridis')\n", "node_labels = ['Output']+['Input']*(N-1)\n", - "edge_labels = list('ABCDEFGH')\n", + "np.random.seed(7)\n", + "edge_labels = np.random.rand(8)\n", "\n", "nodes = hv.Nodes((x, y, node_indices, node_labels), vdims='Type')\n", - "graph = hv.Graph(((source, target, edge_labels), nodes, paths), vdims='Label').redim.range(**padding)\n", + "graph = hv.Graph(((source, target, edge_labels), nodes, paths), vdims='Weight').redim.range(**padding)\n", "graph + graph.opts(plot=dict(inspection_policy='edges'))" ] }, diff --git a/holoviews/core/util.py b/holoviews/core/util.py index 9e8e9d9c3a..361fe10286 100644 --- a/holoviews/core/util.py +++ b/holoviews/core/util.py @@ -1569,3 +1569,12 @@ def dt_to_int(value, time_unit='us'): except: # Handle python2 return (time.mktime(value.timetuple()) + value.microsecond / 1e6) * tscale + + +def search_indices(values, source): + """ + Given a set of values returns the indices of each of those values + in the source array. + """ + orig_indices = source.argsort() + return orig_indices[np.searchsorted(source[orig_indices], values)] diff --git a/holoviews/element/graphs.py b/holoviews/element/graphs.py index a26f0c690a..514a3c308d 100644 --- a/holoviews/element/graphs.py +++ b/holoviews/element/graphs.py @@ -9,7 +9,7 @@ from ..core.operation import Operation from .chart import Points from .path import Path -from .util import split_path, pd +from .util import split_path, pd, circular_layout, connect_edges, connect_edges_pd try: from datashader.layout import LayoutAlgorithm as ds_layout @@ -33,14 +33,6 @@ def __call__(self, specs=None, **dimensions): return redimmed.clone(new_data) -def circular_layout(nodes): - N = len(nodes) - circ = np.pi/N*np.arange(N)*2 - x = np.cos(circ) - y = np.sin(circ) - return (x, y, nodes) - - class layout_nodes(Operation): """ Accepts a Graph and lays out the corresponding nodes with the @@ -75,10 +67,15 @@ def _process(self, element, key=None): nodes = nodes[['x', 'y', 'index']] else: nodes = circular_layout(nodes) + nodes = Nodes(nodes) + if element._nodes: + for d in element.nodes.vdims: + vals = element.nodes.dimension_values(d) + nodes = nodes.add_dimension(d, len(nodes.vdims), vals, vdim=True) if self.p.only_nodes: - return Nodes(nodes) + return nodes return element.clone((element.data, nodes)) - + class Graph(Dataset, Element2D): @@ -123,15 +120,61 @@ def __init__(self, data, kdims=None, vdims=None, **params): self._nodes = nodes self._edgepaths = edgepaths super(Graph, self).__init__(edges, kdims=kdims, vdims=vdims, **params) - if self._nodes is None and node_info: - nodes = self.nodes.clone(datatype=['pandas', 'dictionary']) - for d in node_info.dimensions(): + if node_info is not None: + self._add_node_info(node_info) + self._validate() + self.redim = redim_graph(self, mode='dataset') + + + def _add_node_info(self, node_info): + nodes = self.nodes.clone(datatype=['pandas', 'dictionary']) + if isinstance(node_info, Nodes): + nodes = nodes.redim(**dict(zip(nodes.dimensions('key', label=True), + node_info.kdims))) + + if not node_info.kdims and len(node_info) != len(nodes): + raise ValueError("The supplied node data does not match " + "the number of nodes defined by the edges. " + "Ensure that the number of nodes match" + "or supply an index as the sole key " + "dimension to allow the Graph to merge " + "the data.") + + if pd is None: + if node_info.kdims and len(node_info) != len(nodes): + raise ValueError("Graph cannot merge node data on index " + "dimension without pandas. Either ensure " + "the node data matches the order of nodes " + "as they appear in the edge data or install " + "pandas.") + dimensions = nodes.dimensions() + for d in node_info.vdims: + if d in dimensions: + continue nodes = nodes.add_dimension(d, len(nodes.vdims), node_info.dimension_values(d), vdim=True) - self._nodes = nodes - self._validate() - self.redim = redim_graph(self, mode='dataset') + else: + left_on = nodes.kdims[-1].name + node_info_df = node_info.dframe() + node_df = nodes.dframe() + if node_info.kdims: + idx = node_info.kdims[-1] + else: + idx = Dimension('index') + node_info_df = node_info_df.reset_index() + if 'index' in node_info_df.columns and not idx.name == 'index': + node_df = node_df.rename(columns={'index': '__index'}) + left_on = '__index' + cols = [c for c in node_info_df.columns if c not in + node_df.columns or c == idx.name] + node_info_df = node_info_df[cols] + node_df = pd.merge(node_df, node_info_df, left_on=left_on, + right_on=idx.name, how='left') + nodes = nodes.clone(node_df, kdims=nodes.kdims[:2]+[idx], + vdims=node_info.vdims) + + self._nodes = nodes def _validate(self): @@ -300,15 +343,10 @@ def edgepaths(self): """ if self._edgepaths: return self._edgepaths - paths = [] - for start, end in self.array(self.kdims): - start_ds = self.nodes[:, :, start] - end_ds = self.nodes[:, :, end] - if not len(start_ds) or not len(end_ds): - raise ValueError('Could not find node positions for all edges') - sx, sy = start_ds.array(start_ds.kdims[:2]).T - ex, ey = end_ds.array(end_ds.kdims[:2]).T - paths.append([(sx[0], sy[0]), (ex[0], ey[0])]) + if pd is None: + paths = connect_edges(self) + else: + paths = connect_edges_pd(self) return EdgePaths(paths, kdims=self.nodes.kdims[:2]) @@ -354,4 +392,3 @@ class EdgePaths(Path): """ group = param.String(default='EdgePaths', constant=True) - diff --git a/holoviews/element/util.py b/holoviews/element/util.py index 84c5e0d587..f9084be52b 100644 --- a/holoviews/element/util.py +++ b/holoviews/element/util.py @@ -234,3 +234,63 @@ def _process(self, obj, key=None): obj = Dataset(obj, datatype=[dtype]) xcoords, ycoords = self._get_coords(obj) return self._aggregate_dataset(obj, xcoords, ycoords) + + +def circular_layout(nodes): + """ + Lay out nodes on a circle and add node index. + """ + N = len(nodes) + circ = np.pi/N*np.arange(N)*2 + x = np.cos(circ) + y = np.sin(circ) + return (x, y, nodes) + + +def connect_edges_pd(graph): + """ + Given a Graph element containing abstract edges compute edge + segments directly connecting the source and target nodes. This + operation depends on pandas and is a lot faster than the pure + NumPy equivalent. + """ + edges = graph.dframe() + edges.index.name = 'graph_edge_index' + edges = edges.reset_index() + nodes = graph.nodes.dframe() + src, tgt = graph.kdims + x, y, idx = graph.nodes.kdims[:3] + + df = pd.merge(edges, nodes, left_on=[src.name], right_on=[idx.name]) + df = df.rename(columns={x.name: 'src_x', y.name: 'src_y'}) + + df = pd.merge(df, nodes, left_on=[tgt.name], right_on=[idx.name]) + df = df.rename(columns={x.name: 'dst_x', y.name: 'dst_y'}) + df = df.sort_values('graph_edge_index').drop(['graph_edge_index'], axis=1) + + edge_segments = [] + N = len(nodes) + for i, edge in df.iterrows(): + start = edge['src_x'], edge['src_y'] + end = edge['dst_x'], edge['dst_y'] + edge_segments.append(np.array([start, end])) + return edge_segments + + +def connect_edges(graph): + """ + Given a Graph element containing abstract edges compute edge + segments directly connecting the source and target nodes. This + operation just uses internal HoloViews operations and will be a + lot slower than the pandas equivalent. + """ + paths = [] + for start, end in graph.array(graph.kdims): + start_ds = graph.nodes[:, :, start] + end_ds = graph.nodes[:, :, end] + if not len(start_ds) or not len(end_ds): + raise ValueError('Could not find node positions for all edges') + start = start_ds.array(start_ds.kdims[:2]) + end = end_ds.array(end_ds.kdims[:2]) + paths.append(np.array([start[0], end[0]])) + return paths diff --git a/holoviews/plotting/bokeh/__init__.py b/holoviews/plotting/bokeh/__init__.py index 99eef92423..6cf7aabb64 100644 --- a/holoviews/plotting/bokeh/__init__.py +++ b/holoviews/plotting/bokeh/__init__.py @@ -186,18 +186,17 @@ def colormap_generator(palette): options.Arrow = Options('style', arrow_size=10) # Graphs -options.Graph = Options('style', node_size=20, node_fill_color=Cycle(), +options.Graph = Options('style', node_size=15, node_fill_color=Cycle(), node_line_color='black', - node_selection_fill_color='limegreen', node_nonselection_fill_color=Cycle(), node_hover_line_color='black', - node_hover_fill_color='indianred', + node_hover_fill_color='limegreen', node_nonselection_alpha=0.2, edge_nonselection_alpha=0.2, + node_nonselection_line_color='black', edge_line_color='black', edge_line_width=2, edge_nonselection_line_color='black', - edge_hover_line_color='indianred', - edge_selection_line_color='limegreen') + edge_hover_line_color='limegreen') options.Nodes = Options('style', line_color='black', color=Cycle(), size=20, nonselection_fill_color=Cycle(), selection_fill_color='limegreen', diff --git a/holoviews/plotting/bokeh/element.py b/holoviews/plotting/bokeh/element.py index 2ea9d4a400..271e22ddee 100644 --- a/holoviews/plotting/bokeh/element.py +++ b/holoviews/plotting/bokeh/element.py @@ -5,7 +5,6 @@ import numpy as np import bokeh import bokeh.plotting -from bokeh import palettes from bokeh.core.properties import value from bokeh.models import (HoverTool, Renderer, Range1d, DataRange1d, Title, FactorRange, FuncTickFormatter, Tool, Legend) @@ -24,10 +23,10 @@ from ...core import util from ...streams import Stream, Buffer from ..plot import GenericElementPlot, GenericOverlayPlot -from ..util import dynamic_update +from ..util import dynamic_update, process_cmap from .plot import BokehPlot, TOOLS -from .util import (mpl_to_bokeh, get_tab_title, mplcmap_to_palette, - py2js_tickformatter, rgba_tuple, recursive_model_update) +from .util import (mpl_to_bokeh, get_tab_title, py2js_tickformatter, + rgba_tuple, recursive_model_update) property_prefixes = ['selection', 'nonselection', 'muted', 'hover'] @@ -922,16 +921,17 @@ class CompositeElementPlot(ElementPlot): drawing of multiple glyphs. """ - # Mapping between style groups and glyph names + # Mapping between glyph names and style groups _style_groups = {} # Defines the order in which glyphs are drawn, defined by glyph name _draw_order = [] - def _init_glyphs(self, plot, element, ranges, source): + def _init_glyphs(self, plot, element, ranges, source, data=None, mapping=None, style=None): # Get data and initialize data source - style = self.style[self.cyclic_index] - data, mapping, style = self.get_data(element, ranges, style) + if None in (data, mapping): + style = self.style[self.cyclic_index] + data, mapping, style = self.get_data(element, ranges, style) source_cache = {} current_id = element._plot_id @@ -1100,14 +1100,15 @@ def _draw_colorbar(self, plot, color_mapper): self.handles['colorbar'] = color_bar - def _get_colormapper(self, dim, element, ranges, style, factors=None, colors=None): + def _get_colormapper(self, dim, element, ranges, style, factors=None, colors=None, + name='color_mapper'): # The initial colormapper instance is cached the first time # and then only updated - if dim is None: + if dim is None and colors is None: return None if self.adjoined: cmappers = self.adjoined.traverse(lambda x: (x.handles.get('color_dim'), - x.handles.get('color_mapper'))) + x.handles.get(name))) cmappers = [cmap for cdim, cmap in cmappers if cdim == dim] if cmappers: cmapper = cmappers[0] @@ -1117,30 +1118,18 @@ def _get_colormapper(self, dim, element, ranges, style, factors=None, colors=Non return None ncolors = None if factors is None else len(factors) - low, high = ranges.get(dim.name, element.range(dim.name)) - if colors: - palette = colors + if dim: + low, high = ranges.get(dim.name, element.range(dim.name)) else: - cmap = style.pop('cmap', 'viridis') - if isinstance(cmap, list): - palette = cmap - else: - try: - # Process as matplotlib colormap - palette = mplcmap_to_palette(cmap, ncolors) - except ValueError: - # Process as bokeh palette - palette = getattr(palettes, cmap, None) - if isinstance(palette, dict): - if ncolors in palette: - palette = palette[ncolors] - else: - palette = sorted(palette.items())[-1][1] + low, high = None, None + + cmap = colors or style.pop('cmap', 'viridis') + palette = process_cmap(cmap, ncolors) nan_colors = {k: rgba_tuple(v) for k, v in self.clipping_colors.items()} colormapper, opts = self._get_cmapper_opts(low, high, factors, nan_colors) - if 'color_mapper' in self.handles and isinstance(self.handles['color_mapper'], colormapper): - cmapper = self.handles['color_mapper'] + cmapper = self.handles.get(name) + if cmapper is not None: if cmapper.palette != palette: cmapper.palette = palette opts = {k: opt for k, opt in opts.items() @@ -1149,27 +1138,34 @@ def _get_colormapper(self, dim, element, ranges, style, factors=None, colors=Non cmapper.update(**opts) else: cmapper = colormapper(palette=palette, **opts) - self.handles['color_mapper'] = cmapper + self.handles[name] = cmapper self.handles['color_dim'] = dim return cmapper - def _get_color_data(self, element, ranges, style, name='color', factors=None, colors=None): + def _get_color_data(self, element, ranges, style, name='color', factors=None, colors=None, + int_categories=False): data, mapping = {}, {} cdim = element.get_dimension(self.color_index) if not cdim: return data, mapping cdata = element.dimension_values(cdim) - if factors is None and (isinstance(cdata, list) or cdata.dtype.kind in 'OSU'): - factors = list(np.unique(cdata)) + field = util.dimension_sanitizer(cdim.name) + dtypes = 'iOSU' if int_categories else 'OSU' + if factors is None and (isinstance(cdata, list) or cdata.dtype.kind in dtypes): + factors = list(util.unique_array(cdata)) + if factors and int_categories and cdata.dtype.kind == 'i': + field += '_str' + cdata = [str(f) for f in cdata] + factors = [str(f) for f in factors] + mapper = self._get_colormapper(cdim, element, ranges, style, factors, colors) - data[cdim.name] = cdata + data[field] = cdata if factors is not None: - mapping['legend'] = {'field': cdim.name} - mapping[name] = {'field': cdim.name, - 'transform': mapper} + mapping['legend'] = {'field': field} + mapping[name] = {'field': field, 'transform': mapper} return data, mapping diff --git a/holoviews/plotting/bokeh/graphs.py b/holoviews/plotting/bokeh/graphs.py index 5dbe918902..589f660654 100644 --- a/holoviews/plotting/bokeh/graphs.py +++ b/holoviews/plotting/bokeh/graphs.py @@ -1,17 +1,14 @@ import param import numpy as np +from bokeh.models import HoverTool, ColumnDataSource +from bokeh.models import (StaticLayoutProvider, NodesAndLinkedEdges, + EdgesAndLinkedNodes) -from bokeh.models import Range1d, HoverTool, ColumnDataSource - -try: - from bokeh.models import (StaticLayoutProvider, NodesAndLinkedEdges, - EdgesAndLinkedNodes) -except: - pass - -from ...core.util import basestring, dimension_sanitizer +from ...core.util import basestring, dimension_sanitizer, unique_array +from ...core.options import Cycle from .chart import ColorbarPlot, PointPlot from .element import CompositeElementPlot, LegendPlot, line_properties, fill_properties +from ..util import process_cmap class GraphPlot(CompositeElementPlot, ColorbarPlot, LegendPlot): @@ -20,6 +17,10 @@ class GraphPlot(CompositeElementPlot, ColorbarPlot, LegendPlot): allow_None=True, doc=""" Index of the dimension from which the color will the drawn""") + edge_color_index = param.ClassSelector(default=None, class_=(basestring, int), + allow_None=True, doc=""" + Index of the dimension from which the color will the drawn""") + selection_policy = param.ObjectSelector(default='nodes', objects=['edges', 'nodes', None], doc=""" Determines policy for inspection of graph components, i.e. whether to highlight nodes or edges when selecting connected edges and nodes respectively.""") @@ -31,35 +32,23 @@ class GraphPlot(CompositeElementPlot, ColorbarPlot, LegendPlot): tools = param.List(default=['hover', 'tap'], doc=""" A list of plugin tools to use on the plot.""") - # X-axis is categorical - _x_range_type = Range1d - - # Declare that y-range should auto-range if not bounded - _y_range_type = Range1d - - # Map each glyph to a style group + # Map each glyph to a style group _style_groups = {'scatter': 'node', 'multi_line': 'edge'} style_opts = (['edge_'+p for p in line_properties] +\ - ['node_'+p for p in fill_properties+line_properties]+['node_size', 'cmap']) + ['node_'+p for p in fill_properties+line_properties]+['node_size', 'cmap', 'edge_cmap']) def _hover_opts(self, element): if self.inspection_policy == 'nodes': dims = element.nodes.dimensions() dims = [(dims[2].pprint_label, '@{index_hover}')]+dims[3:] elif self.inspection_policy == 'edges': - kdims = [(kd.pprint_label, '@{%s}' % ref) - for kd, ref in zip(element.kdims, ['start', 'end'])] - dims = kdims+element.vdims + dims = element.kdims+element.vdims else: dims = [] return dims, {} def get_extents(self, element, ranges): - """ - Extents are set to '' and None because x-axis is categorical and - y-axis auto-ranges. - """ xdim, ydim = element.nodes.kdims[:2] x0, x1 = ranges[xdim.name] y0, y1 = ranges[ydim.name] @@ -73,26 +62,76 @@ def _get_axis_labels(self, *args, **kwargs): xlabel, ylabel = [kd.pprint_label for kd in element.nodes.kdims[:2]] return xlabel, ylabel, None + + def _get_edge_colors(self, element, ranges, edge_data, edge_mapping, style): + cdim = element.get_dimension(self.edge_color_index) + if not cdim: + return + elstyle = self.lookup_options(element, 'style') + cycle = elstyle.kwargs.get('edge_color') + + idx = element.get_dimension_index(cdim) + field = dimension_sanitizer(cdim.name) + cvals = element.dimension_values(cdim) + if idx in [0, 1]: + factors = element.nodes.dimension_values(2, expanded=False) + elif idx == 2 and cvals.dtype.kind in 'if': + factors = None + else: + factors = unique_array(cvals) + + default_cmap = 'viridis' if factors is None else 'tab20' + cmap = style.get('edge_cmap', style.get('cmap', default_cmap)) + if factors is None or (factors.dtype.kind == 'f' and idx not in [0, 1]): + colors, factors = None, None + else: + if factors.dtype.kind == 'f': + cvals = cvals.astype(np.int32) + factors = factors.astype(np.int32) + if factors.dtype.kind not in 'SU': + field += '_str' + cvals = [str(f) for f in cvals] + factors = (str(f) for f in factors) + factors = list(factors) + colors = process_cmap(cycle or cmap, len(factors)) + + if field not in edge_data: + edge_data[field] = cvals + edge_style = dict(style, cmap=cmap) + mapper = self._get_colormapper(cdim, element, ranges, edge_style, + factors, colors, 'edge_colormapper') + transform = {'field': field, 'transform': mapper} + edge_mapping['edge_line_color'] = transform + edge_mapping['edge_nonselection_line_color'] = transform + edge_mapping['edge_selection_line_color'] = transform + + def get_data(self, element, ranges, style): xidx, yidx = (1, 0) if self.invert_axes else (0, 1) # Get node data nodes = element.nodes.dimension_values(2) - node_positions = element.nodes.array([0, 1, 2]) + node_positions = element.nodes.array([0, 1]) # Map node indices to integers - if nodes.dtype.kind != 'i': + if nodes.dtype.kind not in 'if': node_indices = {v: i for i, v in enumerate(nodes)} index = np.array([node_indices[n] for n in nodes], dtype=np.int32) - layout = {node_indices[z]: (y, x) if self.invert_axes else (x, y) - for x, y, z in node_positions} + layout = {str(node_indices[k]): (y, x) if self.invert_axes else (x, y) + for k, (x, y) in zip(nodes, node_positions)} else: index = nodes.astype(np.int32) - layout = {z: (y, x) if self.invert_axes else (x, y) - for x, y, z in node_positions} + layout = {str(k): (y, x) if self.invert_axes else (x, y) + for k, (x, y) in zip(index, node_positions)} point_data = {'index': index} - cdata, cmapping = self._get_color_data(element.nodes, ranges, style, 'node_fill_color') + cycle = self.lookup_options(element, 'style').kwargs.get('node_color') + colors = cycle if isinstance(cycle, Cycle) else None + cdata, cmapping = self._get_color_data( + element.nodes, ranges, style, name='node_fill_color', + colors=colors, int_categories=True + ) point_data.update(cdata) point_mapping = cmapping + edge_mapping = {} if 'node_fill_color' in point_mapping: style = {k: v for k, v in style.items() if k not in ['node_fill_color', 'node_nonselection_fill_color']} @@ -101,10 +140,13 @@ def get_data(self, element, ranges, style): # Get edge data nan_node = index.max()+1 start, end = (element.dimension_values(i) for i in range(2)) - if nodes.dtype.kind not in 'if': + if nodes.dtype.kind == 'f': + start, end = start.astype(np.int32), end.astype(np.int32) + elif nodes.dtype.kind != 'i': start = np.array([node_indices.get(x, nan_node) for x in start], dtype=np.int32) end = np.array([node_indices.get(y, nan_node) for y in end], dtype=np.int32) path_data = dict(start=start, end=end) + self._get_edge_colors(element, ranges, path_data, edge_mapping, style) if element._edgepaths and not self.static_source: edges = element._split_edgepaths.split(datatype='array', dimensions=element.edgepaths.kdims) if len(edges) == len(start): @@ -122,11 +164,10 @@ def get_data(self, element, ranges, style): for d in element.nodes.dimensions()[3:]: point_data[dimension_sanitizer(d.name)] = element.nodes.dimension_values(d) elif self.inspection_policy == 'edges': - for d in element.vdims: + for d in element.dimensions(): path_data[dimension_sanitizer(d.name)] = element.dimension_values(d) - data = {'scatter_1': point_data, 'multi_line_1': path_data, 'layout': layout} - mapping = {'scatter_1': point_mapping, 'multi_line_1': {}} + mapping = {'scatter_1': point_mapping, 'multi_line_1': edge_mapping} return data, mapping, style @@ -145,19 +186,27 @@ def _init_glyphs(self, plot, element, ranges, source): style = self.style[self.cyclic_index] data, mapping, style = self.get_data(element, ranges, style) self.handles['previous_id'] = element._plot_id + properties = {} mappings = {} - for key in mapping: - source = self._init_datasource(data.get(key, {})) + for key in list(mapping): + if not any(glyph in key for glyph in ('scatter_1', 'multi_line_1')): + continue + source = self._init_datasource(data.pop(key, {})) self.handles[key+'_source'] = source glyph_props = self._glyph_properties(plot, element, source, ranges, style) properties.update(glyph_props) - mappings.update(mapping.get(key, {})) + mappings.update(mapping.pop(key, {})) properties = {p: v for p, v in properties.items() if p not in ('legend', 'source')} properties.update(mappings) + layout = data.pop('layout', {}) + if data and mapping: + CompositeElementPlot._init_glyphs(self, plot, element, ranges, source, + data, mapping, style) + # Define static layout - layout = StaticLayoutProvider(graph_layout=data['layout']) + layout = StaticLayoutProvider(graph_layout=layout) node_source = self.handles['scatter_1_source'] edge_source = self.handles['multi_line_1_source'] renderer = plot.graph(node_source, edge_source, layout, **properties) @@ -187,6 +236,7 @@ def _init_glyphs(self, plot, element, ranges, source): self.handles['hover'].renderers.append(renderer) + class NodePlot(PointPlot): """ Simple subclass of PointPlot which hides x, y position on hover. diff --git a/holoviews/plotting/bokeh/util.py b/holoviews/plotting/bokeh/util.py index c337a9c43b..efef440ca9 100644 --- a/holoviews/plotting/bokeh/util.py +++ b/holoviews/plotting/bokeh/util.py @@ -66,19 +66,6 @@ def rgba_tuple(rgba): return rgba -def mplcmap_to_palette(cmap, ncolors=None): - """ - Converts a matplotlib colormap to palette of RGB hex strings." - """ - if colors is None: - raise ValueError("Using cmaps on objects requires matplotlib.") - with abbreviated_exception(): - colormap = cm.get_cmap(cmap) #choose any matplotlib colormap here - if ncolors: - return [rgb2hex(colormap(i)) for i in np.linspace(0, 1, ncolors)] - return [rgb2hex(m) for m in colormap(np.arange(colormap.N))] - - def get_cmap(cmap): """ Returns matplotlib cmap generated from bokeh palette or diff --git a/holoviews/plotting/mpl/element.py b/holoviews/plotting/mpl/element.py index 633cfc9472..a564ef102b 100644 --- a/holoviews/plotting/mpl/element.py +++ b/holoviews/plotting/mpl/element.py @@ -645,12 +645,12 @@ def _draw_colorbar(self, dim=None, redraw=True): ColorbarPlot._colorbars[id(axis)] = (ax_colorbars, (l, b, w, h)) - def _norm_kwargs(self, element, ranges, opts, vdim): + def _norm_kwargs(self, element, ranges, opts, vdim, prefix=''): """ Returns valid color normalization kwargs to be passed to matplotlib plot function. """ - clim = opts.pop('clims', None) + clim = opts.pop(prefix+'clims', None) if clim is None: cs = element.dimension_values(vdim) if not isinstance(cs, np.ndarray): @@ -674,9 +674,9 @@ def _norm_kwargs(self, element, ranges, opts, vdim): linthresh=clim[1]/np.e) else: norm = mpl_colors.LogNorm(vmin=clim[0], vmax=clim[1]) - opts['norm'] = norm - opts['vmin'] = clim[0] - opts['vmax'] = clim[1] + opts[prefix+'norm'] = norm + opts[prefix+'vmin'] = clim[0] + opts[prefix+'vmax'] = clim[1] # Check whether the colorbar should indicate clipping values = np.asarray(element.dimension_values(vdim)) @@ -687,8 +687,8 @@ def _norm_kwargs(self, element, ranges, opts, vdim): el_min, el_max = -np.inf, np.inf else: el_min, el_max = -np.inf, np.inf - vmin = -np.inf if opts['vmin'] is None else opts['vmin'] - vmax = np.inf if opts['vmax'] is None else opts['vmax'] + vmin = -np.inf if opts[prefix+'vmin'] is None else opts[prefix+'vmin'] + vmax = np.inf if opts[prefix+'vmax'] is None else opts[prefix+'vmax'] if el_min < vmin and el_max > vmax: self._cbar_extend = 'both' elif el_min < vmin: @@ -697,7 +697,7 @@ def _norm_kwargs(self, element, ranges, opts, vdim): self._cbar_extend = 'max' # Define special out-of-range colors on colormap - cmap = opts.get('cmap') + cmap = opts.get(prefix+'cmap') if isinstance(cmap, list): cmap = mpl_colors.ListedColormap(cmap) elif isinstance(cmap, util.basestring): @@ -719,7 +719,7 @@ def _norm_kwargs(self, element, ranges, opts, vdim): if 'max' in colors: cmap.set_over(**colors['max']) if 'min' in colors: cmap.set_under(**colors['min']) if 'NaN' in colors: cmap.set_bad(**colors['NaN']) - opts['cmap'] = cmap + opts[prefix+'cmap'] = cmap diff --git a/holoviews/plotting/mpl/graphs.py b/holoviews/plotting/mpl/graphs.py index ac9b176800..f7ce5476bd 100644 --- a/holoviews/plotting/mpl/graphs.py +++ b/holoviews/plotting/mpl/graphs.py @@ -3,47 +3,89 @@ from matplotlib.collections import LineCollection -from ...core.util import basestring +from ...core.options import Cycle +from ...core.util import basestring, unique_array, search_indices +from ..util import process_cmap from .element import ColorbarPlot + class GraphPlot(ColorbarPlot): - """ - GraphPlot - """ color_index = param.ClassSelector(default=None, class_=(basestring, int), allow_None=True, doc=""" Index of the dimension from which the color will the drawn""") + edge_color_index = param.ClassSelector(default=None, class_=(basestring, int), + allow_None=True, doc=""" + Index of the dimension from which the color will the drawn""") + style_opts = ['edge_alpha', 'edge_color', 'edge_linestyle', 'edge_linewidth', 'node_alpha', 'node_color', 'node_edgecolors', 'node_facecolors', - 'node_linewidth', 'node_marker', 'node_size', 'visible', 'cmap'] + 'node_linewidth', 'node_marker', 'node_size', 'visible', 'cmap', + 'edge_cmap'] + + _style_groups = ['node', 'edge'] def _compute_styles(self, element, ranges, style): - cdim = element.get_dimension(self.color_index) - color = style.pop('node_color', None) - cmap = style.get('cmap', None) - if cdim and cmap: - cs = element.dimension_values(self.color_index) + elstyle = self.lookup_options(element, 'style') + color = elstyle.kwargs.get('node_color') + cdim = element.nodes.get_dimension(self.color_index) + cmap = elstyle.kwargs.get('cmap', 'tab20') + if cdim: + cs = element.nodes.dimension_values(self.color_index) # Check if numeric otherwise treat as categorical - if cs.dtype.kind in 'if': + if cs.dtype.kind == 'f': style['c'] = cs else: - categories = np.unique(cs) - xsorted = np.argsort(categories) - ypos = np.searchsorted(categories[xsorted], cs) - style['c'] = xsorted[ypos] - self._norm_kwargs(element, ranges, style, cdim) + factors = unique_array(cs) + cmap = color if isinstance(color, Cycle) else cmap + colors = process_cmap(cmap, len(factors)) + cs = search_indices(cs, factors) + style['node_facecolors'] = [colors[v%len(colors)] for v in cs] + style.pop('node_color', None) + if 'c' in style: + self._norm_kwargs(element.nodes, ranges, style, cdim) elif color: - style['c'] = color + style['c'] = style.pop('node_color') style['node_edgecolors'] = style.pop('node_edgecolors', 'none') + + edge_cdim = element.get_dimension(self.edge_color_index) + if not edge_cdim: + return style + + elstyle = self.lookup_options(element, 'style') + cycle = elstyle.kwargs.get('edge_color') + idx = element.get_dimension_index(edge_cdim) + cvals = element.dimension_values(edge_cdim) + if idx in [0, 1]: + factors = element.nodes.dimension_values(2, expanded=False) + elif idx == 2 and cvals.dtype.kind in 'if': + factors = None + else: + factors = unique_array(cvals) + if factors is None or (factors.dtype.kind == 'f' and idx not in [0, 1]): + style['edge_array'] = cvals + else: + cvals = search_indices(cvals, factors) + factors = list(factors) + cmap = elstyle.kwargs.get('edge_cmap', 'tab20') + cmap = cycle if isinstance(cycle, Cycle) else cmap + colors = process_cmap(cmap, len(factors)) + style['edge_colors'] = [colors[v%len(colors)] for v in cvals] + style.pop('edge_color', None) + if 'edge_array' in style: + self._norm_kwargs(element, ranges, style, edge_cdim, 'edge_') + else: + style.pop('edge_cmap', None) + if 'edge_vmin' in style: + style['edge_clim'] = (style.pop('edge_vmin'), style.pop('edge_vmax')) return style def get_data(self, element, ranges, style): xidx, yidx = (1, 0) if self.invert_axes else (0, 1) pxs, pys = (element.nodes.dimension_values(i) for i in range(2)) dims = element.nodes.dimensions() - self._compute_styles(element.nodes, ranges, style) + self._compute_styles(element, ranges, style) paths = element.edgepaths.split(datatype='array', dimensions=element.edgepaths.kdims) if self.invert_axes: @@ -62,18 +104,21 @@ def get_extents(self, element, ranges): def init_artists(self, ax, plot_args, plot_kwargs): # Draw edges color_opts = ['c', 'cmap', 'vmin', 'vmax', 'norm'] + groups = [g for g in self._style_groups if g != 'edge'] edge_opts = {k[5:] if 'edge_' in k else k: v for k, v in plot_kwargs.items() - if 'node_' not in k and k not in color_opts} + if not any(k.startswith(p) for p in groups) + and k not in color_opts} paths = plot_args['edges'] edges = LineCollection(paths, **edge_opts) ax.add_collection(edges) # Draw nodes xs, ys = plot_args['nodes'] + groups = [g for g in self._style_groups if g != 'node'] node_opts = {k[5:] if 'node_' in k else k: v for k, v in plot_kwargs.items() - if 'edge_' not in k} + if not any(k.startswith(p) for p in groups)} if 'size' in node_opts: node_opts['s'] = node_opts.pop('size')**2 nodes = ax.scatter(xs, ys, **node_opts) @@ -95,5 +140,10 @@ def update_handles(self, key, axis, element, ranges, style): paths = data['edges'] edges.set_paths(paths) edges.set_visible(style.get('visible', True)) - + cdim = element.get_dimension(self.edge_color_index) + if cdim and 'edge_c' in edges: + edges.set_clim((style['edge_vmin'], style['edge_vmax'])) + edges.set_array(style['edge_c']) + if 'norm' in style: + edges.norm = style['edge_norm'] return axis_kwargs diff --git a/holoviews/plotting/util.py b/holoviews/plotting/util.py index 3e57b08915..e7b5d9a13d 100644 --- a/holoviews/plotting/util.py +++ b/holoviews/plotting/util.py @@ -1,4 +1,4 @@ -from __future__ import unicode_literals +from __future__ import unicode_literals, absolute_import from collections import defaultdict import traceback @@ -7,6 +7,7 @@ from ..core import (HoloMap, DynamicMap, CompositeOverlay, Layout, Overlay, GridSpace, NdLayout, Store) +from ..core.options import Cycle from ..core.spaces import get_nested_streams from ..core.util import (match_spec, is_number, wrap_tuple, basestring, get_overlay_spec, unique_iterator) @@ -411,6 +412,59 @@ def map_colors(arr, crange, cmap, hex=True): return arr +def mplcmap_to_palette(cmap, ncolors=None): + """ + Converts a matplotlib colormap to palette of RGB hex strings." + """ + import matplotlib.cm as cm + colormap = cm.get_cmap(cmap) #choose any matplotlib colormap here + if ncolors: + return [rgb2hex(colormap(i)) for i in np.linspace(0, 1, ncolors)] + return [rgb2hex(m) for m in colormap(np.arange(colormap.N))] + + +def bokeh_palette_to_palette(cmap, ncolors=None): + from bokeh import palettes + # Process as bokeh palette + palette = getattr(palettes, cmap, None) + if palette is None: + raise ValueError("Supplied palette %s not found among bokeh palettes" % cmap) + elif isinstance(palette, dict): + if ncolors in palette: + palette = palette[ncolors] + else: + palette = sorted(palette.items())[-1][1] + if ncolors: + return [palette[i%len(palette)] for i in range(ncolors)] + return palette + + +def process_cmap(cmap, ncolors=None): + """ + Convert valid colormap specifications to a list of colors. + """ + if isinstance(cmap, Cycle): + palette = [rgb2hex(c) if isinstance(c, tuple) else c for c in cmap.values] + elif isinstance(cmap, list): + palette = cmap + elif isinstance(cmap, basestring): + try: + # Process as matplotlib colormap + palette = mplcmap_to_palette(cmap, ncolors) + except: + try: + palette = bokeh_palette_to_palette(cmap, ncolors) + except: + raise ValueError("Supplied cmap %s not found among " + "matplotlib or bokeh colormaps.") + else: + raise TypeError("cmap argument expects a list, Cycle or valid matplotlib " + "colormap or bokeh palette, found %s." % cmap) + if ncolors: + return [palette[i%len(palette)] for i in range(ncolors)] + return palette + + def dim_axis_label(dimensions, separator=', '): """ Returns an axis label for one or more dimensions. diff --git a/tests/testbokehgraphs.py b/tests/testbokehgraphs.py index 8836e85335..b596e82e22 100644 --- a/tests/testbokehgraphs.py +++ b/tests/testbokehgraphs.py @@ -13,7 +13,7 @@ from holoviews.plotting.bokeh.util import bokeh_version bokeh_renderer = Store.renderers['bokeh'] from bokeh.models import (NodesAndLinkedEdges, EdgesAndLinkedNodes) - from bokeh.models.mappers import CategoricalColorMapper + from bokeh.models.mappers import CategoricalColorMapper, LinearColorMapper except : bokeh_renderer = None @@ -24,20 +24,22 @@ class BokehGraphPlotTests(ComparisonTestCase): def setUp(self): if not bokeh_renderer: raise SkipTest("Bokeh required to test plot instantiation") - elif bokeh_version < str('0.12.9'): - raise SkipTest("Bokeh >= 0.12.9 required to test graphs") self.previous_backend = Store.current_backend Store.current_backend = 'bokeh' self.default_comm = bokeh_renderer.comms['default'] N = 8 - self.nodes = circular_layout(np.arange(N)) - self.source = np.arange(N) - self.target = np.zeros(N) + self.nodes = circular_layout(np.arange(N, dtype=np.int32)) + self.source = np.arange(N, dtype=np.int32) + self.target = np.zeros(N, dtype=np.int32) + self.weights = np.random.rand(N) self.graph = Graph(((self.source, self.target),)) self.node_info = Dataset(['Output']+['Input']*(N-1), vdims=['Label']) + self.node_info2 = Dataset(self.weights, vdims='Weight') self.graph2 = Graph(((self.source, self.target), self.node_info)) - + self.graph3 = Graph(((self.source, self.target), self.node_info2)) + self.graph4 = Graph(((self.source, self.target, self.weights),), vdims='Weight') + def tearDown(self): Store.current_backend = self.previous_backend bokeh_renderer.comms['default'] = self.default_comm @@ -50,7 +52,7 @@ def test_plot_simple_graph(self): self.assertEqual(node_source.data['index'], self.source) self.assertEqual(edge_source.data['start'], self.source) self.assertEqual(edge_source.data['end'], self.target) - layout = {z: (x, y) for x, y, z in self.graph.nodes.array()} + layout = {str(int(z)): (x, y) for x, y, z in self.graph.nodes.array()} self.assertEqual(layout_source.graph_layout, layout) def test_plot_graph_with_paths(self): @@ -65,7 +67,7 @@ def test_plot_graph_with_paths(self): edges = graph.edgepaths.split() self.assertEqual(edge_source.data['xs'], [path.dimension_values(0) for path in edges]) self.assertEqual(edge_source.data['ys'], [path.dimension_values(1) for path in edges]) - layout = {z: (x, y) for x, y, z in self.graph.nodes.array()} + layout = {str(int(z)): (x, y) for x, y, z in self.graph.nodes.array()} self.assertEqual(layout_source.graph_layout, layout) def test_graph_inspection_policy_nodes(self): @@ -90,7 +92,7 @@ def test_graph_inspection_policy_edges_non_default_names(self): renderer = plot.handles['glyph_renderer'] hover = plot.handles['hover'] self.assertIsInstance(renderer.inspection_policy, EdgesAndLinkedNodes) - self.assertEqual(hover.tooltips, [('source', '@{start}'), ('target', '@{end}')]) + self.assertEqual(hover.tooltips, [('source', '@{source}'), ('target', '@{target}')]) self.assertIn(renderer, hover.renderers) def test_graph_inspection_policy_none(self): @@ -119,13 +121,51 @@ def test_graph_selection_policy_none(self): hover = plot.handles['hover'] self.assertIs(renderer.selection_policy, None) - def test_graph_nodes_colormapped(self): + def test_graph_nodes_categorical_colormapped(self): g = self.graph2.opts(plot=dict(color_index='Label'), style=dict(cmap='Set1')) plot = bokeh_renderer.get_plot(g) cmapper = plot.handles['color_mapper'] node_source = plot.handles['scatter_1_source'] glyph = plot.handles['scatter_1_glyph'] self.assertIsInstance(cmapper, CategoricalColorMapper) - self.assertEqual(cmapper.factors, ['Input', 'Output']) + self.assertEqual(cmapper.factors, ['Output', 'Input']) self.assertEqual(node_source.data['Label'], self.node_info['Label']) self.assertEqual(glyph.fill_color, {'field': 'Label', 'transform': cmapper}) + + def test_graph_nodes_numerically_colormapped(self): + g = self.graph3.opts(plot=dict(color_index='Weight'), style=dict(cmap='viridis')) + plot = bokeh_renderer.get_plot(g) + cmapper = plot.handles['color_mapper'] + node_source = plot.handles['scatter_1_source'] + glyph = plot.handles['scatter_1_glyph'] + self.assertIsInstance(cmapper, LinearColorMapper) + self.assertEqual(cmapper.low, self.weights.min()) + self.assertEqual(cmapper.high, self.weights.max()) + self.assertEqual(node_source.data['Weight'], self.node_info2['Weight']) + self.assertEqual(glyph.fill_color, {'field': 'Weight', 'transform': cmapper}) + + def test_graph_edges_categorical_colormapped(self): + g = self.graph3.opts(plot=dict(edge_color_index='start'), + style=dict(edge_cmap=['#FFFFFF', '#000000'])) + plot = bokeh_renderer.get_plot(g) + cmapper = plot.handles['edge_colormapper'] + edge_source = plot.handles['multi_line_1_source'] + glyph = plot.handles['multi_line_1_glyph'] + self.assertIsInstance(cmapper, CategoricalColorMapper) + factors = ['0', '1', '2', '3', '4', '5', '6', '7'] + self.assertEqual(cmapper.factors, factors) + self.assertEqual(edge_source.data['start_str'], factors) + self.assertEqual(glyph.line_color, {'field': 'start_str', 'transform': cmapper}) + + def test_graph_nodes_numerically_colormapped(self): + g = self.graph4.opts(plot=dict(edge_color_index='Weight'), + style=dict(edge_cmap=['#FFFFFF', '#000000'])) + plot = bokeh_renderer.get_plot(g) + cmapper = plot.handles['edge_colormapper'] + edge_source = plot.handles['multi_line_1_source'] + glyph = plot.handles['multi_line_1_glyph'] + self.assertIsInstance(cmapper, LinearColorMapper) + self.assertEqual(cmapper.low, self.weights.min()) + self.assertEqual(cmapper.high, self.weights.max()) + self.assertEqual(edge_source.data['Weight'], self.node_info2['Weight']) + self.assertEqual(glyph.line_color, {'field': 'Weight', 'transform': cmapper}) diff --git a/tests/testgraphelement.py b/tests/testgraphelement.py index 23bc2c58c0..7547b9f123 100644 --- a/tests/testgraphelement.py +++ b/tests/testgraphelement.py @@ -2,17 +2,20 @@ Unit tests of Graph Element. """ import numpy as np -from holoviews.element.graphs import Graph, Nodes, circular_layout +from holoviews.core.data import Dataset +from holoviews.element.graphs import ( + Graph, Nodes, circular_layout, connect_edges, connect_edges_pd) from holoviews.element.comparison import ComparisonTestCase + class GraphTests(ComparisonTestCase): def setUp(self): N = 8 self.nodes = circular_layout(np.arange(N)) - self.source = np.arange(N) - self.target = np.zeros(N) + self.source = np.arange(N, dtype=np.int32) + self.target = np.zeros(N, dtype=np.int32) self.edge_info = np.arange(N) self.graph = Graph(((self.source, self.target),)) @@ -26,6 +29,45 @@ def test_constructor_with_nodes(self): nodes = Nodes(self.nodes) self.assertEqual(graph.nodes, nodes) + def test_graph_edge_segments(self): + segments = connect_edges(self.graph) + paths = [] + nodes = np.column_stack(self.nodes) + for start, end in zip(nodes[self.source], nodes[self.target]): + paths.append(np.array([start[:2], end[:2]])) + self.assertEqual(segments, paths) + + def test_graph_node_info_no_index(self): + node_info = Dataset(np.arange(8), vdims=['Label']) + graph = Graph(((self.source, self.target), node_info)) + self.assertEqual(graph.nodes.dimension_values(3), + node_info.dimension_values(0)) + + def test_graph_node_info_no_index_mismatch(self): + node_info = Dataset(np.arange(6), vdims=['Label']) + with self.assertRaises(ValueError): + Graph(((self.source, self.target), node_info)) + + def test_graph_node_info_merge_on_index(self): + node_info = Dataset((np.arange(8), np.arange(1,9)), 'index', 'label') + graph = Graph(((self.source, self.target), node_info)) + self.assertEqual(graph.nodes.dimension_values(3), + node_info.dimension_values(1)) + + def test_graph_node_info_merge_on_index_partial(self): + node_info = Dataset((np.arange(5), np.arange(1,6)), 'index', 'label') + graph = Graph(((self.source, self.target), node_info)) + expected = np.array([1., 2., 3., 4., 5., np.NaN, np.NaN, np.NaN]) + self.assertEqual(graph.nodes.dimension_values(3), expected) + + def test_graph_edge_segments_pd(self): + segments = connect_edges_pd(self.graph) + paths = [] + nodes = np.column_stack(self.nodes) + for start, end in zip(nodes[self.source], nodes[self.target]): + paths.append(np.array([start[:2], end[:2]])) + self.assertEqual(segments, paths) + def test_constructor_with_nodes_and_paths(self): paths = Graph(((self.source, self.target), self.nodes)).edgepaths graph = Graph(((self.source, self.target), self.nodes, paths.data)) diff --git a/tests/testmplgraphs.py b/tests/testmplgraphs.py index 4a846f0258..8a27b67b4a 100644 --- a/tests/testmplgraphs.py +++ b/tests/testmplgraphs.py @@ -30,12 +30,17 @@ def setUp(self): mpl_renderer.comms['default'] = (comms.Comm, '') N = 8 - self.nodes = circular_layout(np.arange(N)) - self.source = np.arange(N) - self.target = np.zeros(N) + self.nodes = circular_layout(np.arange(N, dtype=np.int32)) + self.source = np.arange(N, dtype=np.int32) + self.target = np.zeros(N, dtype=np.int32) + self.weights = np.random.rand(N) self.graph = Graph(((self.source, self.target),)) self.node_info = Dataset(['Output']+['Input']*(N-1), vdims=['Label']) + self.node_info2 = Dataset(self.weights, vdims='Weight') self.graph2 = Graph(((self.source, self.target), self.node_info)) + self.graph3 = Graph(((self.source, self.target), self.node_info2)) + self.graph4 = Graph(((self.source, self.target, self.weights),), vdims='Weight') + def tearDown(self): mpl_renderer.comms['default'] = self.default_comm @@ -49,12 +54,46 @@ def test_plot_simple_graph(self): self.assertEqual([p.vertices for p in edges.get_paths()], [p.array() for p in self.graph.edgepaths.split()]) - def test_plot_graph_colored_nodes(self): + def test_plot_graph_categorical_colored_nodes(self): g = self.graph2.opts(plot=dict(color_index='Label'), style=dict(cmap='Set1')) plot = mpl_renderer.get_plot(g) nodes = plot.handles['nodes'] + facecolors = np.array([[0.89411765, 0.10196078, 0.10980392, 1.], + [0.6 , 0.6 , 0.6 , 1.], + [0.6 , 0.6 , 0.6 , 1.], + [0.6 , 0.6 , 0.6 , 1.], + [0.6 , 0.6 , 0.6 , 1.], + [0.6 , 0.6 , 0.6 , 1.], + [0.6 , 0.6 , 0.6 , 1.], + [0.6 , 0.6 , 0.6 , 1.]]) + self.assertEqual(nodes.get_facecolors(), facecolors) + + def test_plot_graph_numerically_colored_nodes(self): + g = self.graph3.opts(plot=dict(color_index='Weight'), style=dict(cmap='viridis')) + plot = mpl_renderer.get_plot(g) + nodes = plot.handles['nodes'] + self.assertEqual(nodes.get_array(), self.weights) + self.assertEqual(nodes.get_clim(), (self.weights.min(), self.weights.max())) + + def test_plot_graph_categorical_colored_edges(self): + g = self.graph3.opts(plot=dict(edge_color_index='start'), + style=dict(edge_cmap=['#FFFFFF', '#000000'])) + plot = mpl_renderer.get_plot(g) edges = plot.handles['edges'] - self.assertEqual(nodes.get_offsets(), self.graph.nodes.array([0, 1])) - self.assertEqual([p.vertices for p in edges.get_paths()], - [p.array() for p in self.graph.edgepaths.split()]) - self.assertEqual(nodes.get_array(), np.array([1, 0, 0, 0, 0, 0, 0, 0])) + colors = np.array([[1., 1., 1., 1.], + [0., 0., 0., 1.], + [1., 1., 1., 1.], + [0., 0., 0., 1.], + [1., 1., 1., 1.], + [0., 0., 0., 1.], + [1., 1., 1., 1.], + [0., 0., 0., 1.]]) + self.assertEqual(edges.get_colors(), colors) + + def test_plot_graph_numerically_colored_edges(self): + g = self.graph4.opts(plot=dict(edge_color_index='Weight'), + style=dict(edge_cmap=['#FFFFFF', '#000000'])) + plot = mpl_renderer.get_plot(g) + edges = plot.handles['edges'] + self.assertEqual(edges.get_array(), self.weights) + self.assertEqual(edges.get_clim(), (self.weights.min(), self.weights.max())) diff --git a/tests/testplotutils.py b/tests/testplotutils.py index 82eb5ed10c..76fa1cb447 100644 --- a/tests/testplotutils.py +++ b/tests/testplotutils.py @@ -5,10 +5,11 @@ from holoviews import NdOverlay, Overlay from holoviews.core.spaces import DynamicMap -from holoviews.core.options import Store +from holoviews.core.options import Store, Cycle from holoviews.element.comparison import ComparisonTestCase from holoviews.element import Curve, Area, Points -from holoviews.plotting.util import compute_overlayable_zorders, get_min_distance +from holoviews.plotting.util import ( + compute_overlayable_zorders, get_min_distance, process_cmap) from holoviews.streams import PointerX try: @@ -303,6 +304,35 @@ def test_dynamic_compute_overlayable_zorders_three_deep_dynamic_layers_reduced_l self.assertNotIn(curve, sources[2]) + + +class TestPlotColorUtils(ComparisonTestCase): + + def test_process_cmap_mpl(self): + colors = process_cmap('Greys', 3) + self.assertEqual(colors, ['#ffffff', '#959595', '#000000']) + + def test_process_cmap_bokeh(self): + colors = process_cmap('Category20', 3) + self.assertEqual(colors, ['#1f77b4', '#aec7e8', '#ff7f0e']) + + def test_process_cmap_list_cycle(self): + colors = process_cmap(['#ffffff', '#959595', '#000000'], 4) + self.assertEqual(colors, ['#ffffff', '#959595', '#000000', '#ffffff']) + + def test_process_cmap_cycle(self): + colors = process_cmap(Cycle(values=['#ffffff', '#959595', '#000000']), 4) + self.assertEqual(colors, ['#ffffff', '#959595', '#000000', '#ffffff']) + + def test_process_cmap_invalid_str(self): + with self.assertRaises(ValueError): + colors = process_cmap('NonexistentColorMap', 3) + + def test_process_cmap_invalid_type(self): + with self.assertRaises(TypeError): + colors = process_cmap({'A', 'B', 'C'}, 3) + + class TestPlotUtils(ComparisonTestCase): def test_get_min_distance_float32_type(self):