Skip to content

Commit

Permalink
Implemented and documented Graph edge_color_index
Browse files Browse the repository at this point in the history
  • Loading branch information
philippjfr committed Nov 23, 2017
1 parent 184cf82 commit 162828a
Show file tree
Hide file tree
Showing 10 changed files with 292 additions and 103 deletions.
10 changes: 6 additions & 4 deletions examples/reference/elements/bokeh/Graph.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@
"source": [
"#### Additional features\n",
"\n",
"Next we will extend this example by supplying explicit edges:"
"Next we will extend this example by supplying explicit edges, node information and edge weights. By constructing the ``Nodes`` explicitly we can declare an additional value dimensions, which are revealed when hovering and/or can be mapped to the color by specifying the ``color_index``. We can also associate additional information with each edge by supplying a value dimension to the ``Graph`` itself, which we can map to a color using the ``edge_color_index``."
]
},
{
Expand All @@ -100,8 +100,10 @@
"outputs": [],
"source": [
"# Node info\n",
"np.random.seed(7)\n",
"x, y = simple_graph.nodes.array([0, 1]).T\n",
"node_labels = ['Output']+['Input']*(N-1)\n",
"edge_weights = np.random.rand(8)\n",
"\n",
"# Compute edge paths\n",
"def bezier(start, end, control, steps=np.linspace(0, 1, 100)):\n",
Expand All @@ -114,10 +116,10 @@
"\n",
"# Declare Graph\n",
"nodes = hv.Nodes((x, y, node_indices, node_labels), vdims='Type')\n",
"graph = hv.Graph(((source, target), nodes, paths))\n",
"graph = hv.Graph(((source, target, edge_weights), nodes, paths), vdims='Weight')\n",
"\n",
"graph.redim.range(**padding).opts(plot=dict(color_index='Type'),\n",
" style=dict(cmap=['blue', 'yellow']))"
"graph.redim.range(**padding).opts(plot=dict(color_index='Type', edge_color_index='Weight'),\n",
" style=dict(cmap=['blue', 'red'], edge_cmap='viridis'))"
]
}
],
Expand Down
13 changes: 7 additions & 6 deletions examples/reference/elements/matplotlib/Graph.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,8 @@
"source": [
"#### Additional features\n",
"\n",
"Next we will extend this example by supplying explicit edges:"
"\n",
"Next we will extend this example by supplying explicit edges, node information and edge weights. By constructing the ``Nodes`` explicitly we can declare an additional value dimensions, which are revealed when hovering and/or can be mapped to the color by specifying the ``color_index``. We can also associate additional information with each edge by supplying a value dimension to the ``Graph`` itself, which we can map to a color using the ``edge_color_index``."
]
},
{
Expand All @@ -99,11 +100,11 @@
"metadata": {},
"outputs": [],
"source": [
"from matplotlib.colors import ListedColormap\n",
"\n",
"# Node info\n",
"np.random.seed(7)\n",
"x, y = simple_graph.nodes.array([0, 1]).T\n",
"node_labels = ['Output']+['Input']*(N-1)\n",
"edge_weights = np.random.rand(8)\n",
"\n",
"# Compute edge paths\n",
"def bezier(start, end, control, steps=np.linspace(0, 1, 100)):\n",
Expand All @@ -116,10 +117,10 @@
"\n",
"# Declare Graph\n",
"nodes = hv.Nodes((x, y, node_indices, node_labels), vdims='Type')\n",
"graph = hv.Graph(((source, target), nodes, paths))\n",
"graph = hv.Graph(((source, target, edge_weights), nodes, paths), vdims='Weight')\n",
"\n",
"graph.redim.range(**padding).opts(plot=dict(color_index='Type'),\n",
" style=dict(cmap=ListedColormap(['blue', 'yellow'])))"
"graph.redim.range(**padding).opts(plot=dict(color_index='Type', edge_color_index='Weight'),\n",
" style=dict(cmap=['blue', 'red'], edge_cmap='viridis'))"
]
}
],
Expand Down
13 changes: 7 additions & 6 deletions examples/user_guide/Network_Graphs.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@
"source": [
"# Declare abstract edges\n",
"N = 8\n",
"node_indices = np.arange(N)\n",
"source = np.zeros(N)\n",
"node_indices = np.arange(N, dtype=np.int32)\n",
"source = np.zeros(N, dtype=np.int32)\n",
"target = node_indices\n",
"\n",
"padding = dict(x=(-1.2, 1.2), y=(-1.2, 1.2))\n",
Expand Down Expand Up @@ -148,7 +148,7 @@
"source": [
"#### Additional information\n",
"\n",
"We can also associate additional information with the nodes and edges of a graph. By constructing the ``Nodes`` explicitly we can declare an additional value dimensions, which are revealed when hovering and/or can be mapped to the color by specifying the ``color_index``. We can also associate additional information with each edge by supplying a value dimension to the ``Graph`` itself."
"We can also associate additional information with the nodes and edges of a graph. By constructing the ``Nodes`` explicitly we can declare an additional value dimensions, which are revealed when hovering and/or can be mapped to the color by specifying the ``color_index``. We can also associate additional information with each edge by supplying a value dimension to the ``Graph`` itself, which we can map to a color using the ``edge_color_index``."
]
},
{
Expand All @@ -157,12 +157,13 @@
"metadata": {},
"outputs": [],
"source": [
"%%opts Graph [color_index='Type'] (cmap='Set1')\n",
"%%opts Graph [color_index='Type' edge_color_index='Weight'] (cmap='Set1' edge_cmap='viridis')\n",
"node_labels = ['Output']+['Input']*(N-1)\n",
"edge_labels = list('ABCDEFGH')\n",
"np.random.seed(7)\n",
"edge_labels = np.random.rand(8)\n",
"\n",
"nodes = hv.Nodes((x, y, node_indices, node_labels), vdims='Type')\n",
"graph = hv.Graph(((source, target, edge_labels), nodes, paths), vdims='Label').redim.range(**padding)\n",
"graph = hv.Graph(((source, target, edge_labels), nodes, paths), vdims='Weight').redim.range(**padding)\n",
"graph + graph.opts(plot=dict(inspection_policy='edges'))"
]
},
Expand Down
9 changes: 9 additions & 0 deletions holoviews/core/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -1569,3 +1569,12 @@ def dt_to_int(value, time_unit='us'):
except:
# Handle python2
return (time.mktime(value.timetuple()) + value.microsecond / 1e6) * tscale


def search_indices(values, source):
"""
Given a set of values returns the indices of each of those values
in the source array.
"""
orig_indices = source.argsort()
return orig_indices[np.searchsorted(source[orig_indices], values)]
11 changes: 5 additions & 6 deletions holoviews/plotting/bokeh/element.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import numpy as np
import bokeh
import bokeh.plotting
from bokeh import palettes
from bokeh.core.properties import value
from bokeh.models import (HoverTool, Renderer, Range1d, DataRange1d, Title,
FactorRange, FuncTickFormatter, Tool, Legend)
Expand All @@ -20,7 +19,7 @@
from bokeh.plotting.helpers import _known_tools as known_tools

from ...core import DynamicMap, CompositeOverlay, Element, Dimension
from ...core.options import abbreviated_exception, SkipRendering, Cycle
from ...core.options import abbreviated_exception, SkipRendering
from ...core import util
from ...streams import Stream, Buffer
from ..plot import GenericElementPlot, GenericOverlayPlot
Expand Down Expand Up @@ -1102,7 +1101,7 @@ def _draw_colorbar(self, plot, color_mapper):


def _get_colormapper(self, dim, element, ranges, style, factors=None, colors=None,
cycle=None, name='color_mapper'):
name='color_mapper'):
# The initial colormapper instance is cached the first time
# and then only updated
if dim is None and colors is None:
Expand All @@ -1124,7 +1123,7 @@ def _get_colormapper(self, dim, element, ranges, style, factors=None, colors=Non
else:
low, high = None, None

cmap = colors or cycle or style.pop('cmap', 'viridis')
cmap = colors or style.pop('cmap', 'viridis')
palette = process_cmap(cmap, ncolors)
nan_colors = {k: rgba_tuple(v) for k, v in self.clipping_colors.items()}
colormapper, opts = self._get_cmapper_opts(low, high, factors, nan_colors)
Expand All @@ -1145,7 +1144,7 @@ def _get_colormapper(self, dim, element, ranges, style, factors=None, colors=Non


def _get_color_data(self, element, ranges, style, name='color', factors=None, colors=None,
cycle=None, int_categories=False):
int_categories=False):
data, mapping = {}, {}
cdim = element.get_dimension(self.color_index)
if not cdim:
Expand All @@ -1162,7 +1161,7 @@ def _get_color_data(self, element, ranges, style, name='color', factors=None, co
factors = [str(f) for f in factors]

mapper = self._get_colormapper(cdim, element, ranges, style,
factors, colors, cycle)
factors, colors)
data[field] = cdata
if factors is not None:
mapping['legend'] = {'field': field}
Expand Down
125 changes: 86 additions & 39 deletions holoviews/plotting/bokeh/graphs.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,14 @@
import param
import numpy as np
from bokeh.models import HoverTool, ColumnDataSource
from bokeh.models import (StaticLayoutProvider, NodesAndLinkedEdges,
EdgesAndLinkedNodes)

from bokeh.models import Range1d, HoverTool, ColumnDataSource

try:
from bokeh.models import (StaticLayoutProvider, NodesAndLinkedEdges,
EdgesAndLinkedNodes)
except:
pass

from ...core.util import basestring, dimension_sanitizer
from ...core.util import basestring, dimension_sanitizer, unique_array
from ...core.options import Cycle
from .chart import ColorbarPlot, PointPlot
from .element import CompositeElementPlot, LegendPlot, line_properties, fill_properties
from ..util import process_cmap


class GraphPlot(CompositeElementPlot, ColorbarPlot, LegendPlot):
Expand All @@ -20,6 +17,10 @@ class GraphPlot(CompositeElementPlot, ColorbarPlot, LegendPlot):
allow_None=True, doc="""
Index of the dimension from which the color will the drawn""")

edge_color_index = param.ClassSelector(default=None, class_=(basestring, int),
allow_None=True, doc="""
Index of the dimension from which the color will the drawn""")

selection_policy = param.ObjectSelector(default='nodes', objects=['edges', 'nodes', None], doc="""
Determines policy for inspection of graph components, i.e. whether to highlight
nodes or edges when selecting connected edges and nodes respectively.""")
Expand All @@ -31,35 +32,23 @@ class GraphPlot(CompositeElementPlot, ColorbarPlot, LegendPlot):
tools = param.List(default=['hover', 'tap'], doc="""
A list of plugin tools to use on the plot.""")

# X-axis is categorical
_x_range_type = Range1d

# Declare that y-range should auto-range if not bounded
_y_range_type = Range1d

# Map each glyph to a style group
# Map each glyph to a style group
_style_groups = {'scatter': 'node', 'multi_line': 'edge'}

style_opts = (['edge_'+p for p in line_properties] +\
['node_'+p for p in fill_properties+line_properties]+['node_size', 'cmap'])
['node_'+p for p in fill_properties+line_properties]+['node_size', 'cmap', 'edge_cmap'])

def _hover_opts(self, element):
if self.inspection_policy == 'nodes':
dims = element.nodes.dimensions()
dims = [(dims[2].pprint_label, '@{index_hover}')]+dims[3:]
elif self.inspection_policy == 'edges':
kdims = [(kd.pprint_label, '@{%s}' % ref)
for kd, ref in zip(element.kdims, ['start', 'end'])]
dims = kdims+element.vdims
dims = element.kdims+element.vdims
else:
dims = []
return dims, {}

def get_extents(self, element, ranges):
"""
Extents are set to '' and None because x-axis is categorical and
y-axis auto-ranges.
"""
xdim, ydim = element.nodes.kdims[:2]
x0, x1 = ranges[xdim.name]
y0, y1 = ranges[ydim.name]
Expand All @@ -73,26 +62,73 @@ def _get_axis_labels(self, *args, **kwargs):
xlabel, ylabel = [kd.pprint_label for kd in element.nodes.kdims[:2]]
return xlabel, ylabel, None


def _get_edge_colors(self, element, ranges, edge_data, edge_mapping, style):
cdim = element.get_dimension(self.edge_color_index)
if not cdim:
return
elstyle = self.lookup_options(element, 'style')
cycle = elstyle.kwargs.get('edge_color')

idx = element.get_dimension_index(cdim)
field = dimension_sanitizer(cdim.name)
cvals = element.dimension_values(cdim)
if idx in [0, 1]:
factors = element.nodes.dimension_values(2, expanded=False)
elif idx == 2 and cvals.dtype.kind in 'if':
factors = None
else:
factors = unique_array(cvals)

default_cmap = 'viridis' if factors is None else 'tab20'
cmap = style.get('edge_cmap', style.get('cmap', default_cmap))
if factors is None or factors.dtype.kind == 'f' and idx not in [0, 1]:
colors, factors = None, None
else:
if factors.dtype.kind not in 'SU':
field += '_str'
cvals = [str(f) for f in cvals]
factors = (str(f) for f in factors)
factors = list(factors)
colors = process_cmap(cycle or cmap, len(factors))

if field not in edge_data:
edge_data[field] = cvals
edge_style = dict(style, cmap=cmap)
mapper = self._get_colormapper(cdim, element, ranges, edge_style,
factors, colors, 'edge_colormapper')
transform = {'field': field, 'transform': mapper}
edge_mapping['edge_line_color'] = transform
edge_mapping['edge_nonselection_line_color'] = transform
edge_mapping['edge_selection_line_color'] = transform


def get_data(self, element, ranges, style):
xidx, yidx = (1, 0) if self.invert_axes else (0, 1)

# Get node data
nodes = element.nodes.dimension_values(2)
node_positions = element.nodes.array([0, 1, 2])
node_positions = element.nodes.array([0, 1])
# Map node indices to integers
if nodes.dtype.kind != 'i':
if nodes.dtype.kind not in 'if':
node_indices = {v: i for i, v in enumerate(nodes)}
index = np.array([node_indices[n] for n in nodes], dtype=np.int32)
layout = {node_indices[z]: (y, x) if self.invert_axes else (x, y)
for x, y, z in node_positions}
layout = {str(node_indices[k]): (y, x) if self.invert_axes else (x, y)
for k, (x, y) in zip(nodes, node_positions)}
else:
index = nodes.astype(np.int32)
layout = {z: (y, x) if self.invert_axes else (x, y)
for x, y, z in node_positions}
layout = {str(k): (y, x) if self.invert_axes else (x, y)
for k, (x, y) in zip(index, node_positions)}
point_data = {'index': index}
cdata, cmapping = self._get_color_data(element.nodes, ranges, style, 'node_fill_color')
cycle = self.lookup_options(element, 'style').kwargs.get('node_color')
colors = cycle if isinstance(cycle, Cycle) else None
cdata, cmapping = self._get_color_data(
element.nodes, ranges, style, name='node_fill_color',
colors=colors, int_categories=True
)
point_data.update(cdata)
point_mapping = cmapping
edge_mapping = {}
if 'node_fill_color' in point_mapping:
style = {k: v for k, v in style.items() if k not in
['node_fill_color', 'node_nonselection_fill_color']}
Expand All @@ -101,10 +137,13 @@ def get_data(self, element, ranges, style):
# Get edge data
nan_node = index.max()+1
start, end = (element.dimension_values(i) for i in range(2))
if nodes.dtype.kind not in 'if':
if nodes.dtype.kind == 'f':
start, end = start.astype(np.int32), end.astype(np.int32)
elif nodes.dtype.kind != 'i':
start = np.array([node_indices.get(x, nan_node) for x in start], dtype=np.int32)
end = np.array([node_indices.get(y, nan_node) for y in end], dtype=np.int32)
path_data = dict(start=start, end=end)
self._get_edge_colors(element, ranges, path_data, edge_mapping, style)
if element._edgepaths and not self.static_source:
edges = element._split_edgepaths.split(datatype='array', dimensions=element.edgepaths.kdims)
if len(edges) == len(start):
Expand All @@ -122,11 +161,10 @@ def get_data(self, element, ranges, style):
for d in element.nodes.dimensions()[3:]:
point_data[dimension_sanitizer(d.name)] = element.nodes.dimension_values(d)
elif self.inspection_policy == 'edges':
for d in element.vdims:
for d in element.dimensions():
path_data[dimension_sanitizer(d.name)] = element.dimension_values(d)

data = {'scatter_1': point_data, 'multi_line_1': path_data, 'layout': layout}
mapping = {'scatter_1': point_mapping, 'multi_line_1': {}}
mapping = {'scatter_1': point_mapping, 'multi_line_1': edge_mapping}
return data, mapping, style


Expand All @@ -145,19 +183,27 @@ def _init_glyphs(self, plot, element, ranges, source):
style = self.style[self.cyclic_index]
data, mapping, style = self.get_data(element, ranges, style)
self.handles['previous_id'] = element._plot_id

properties = {}
mappings = {}
for key in mapping:
source = self._init_datasource(data.get(key, {}))
for key in list(mapping):
if not any(glyph in key for glyph in ('scatter_1', 'multi_line_1')):
continue
source = self._init_datasource(data.pop(key, {}))
self.handles[key+'_source'] = source
glyph_props = self._glyph_properties(plot, element, source, ranges, style)
properties.update(glyph_props)
mappings.update(mapping.get(key, {}))
mappings.update(mapping.pop(key, {}))
properties = {p: v for p, v in properties.items() if p not in ('legend', 'source')}
properties.update(mappings)

layout = data.pop('layout', {})
if data and mapping:
CompositeElementPlot._init_glyphs(self, plot, element, ranges, source,
data, mapping, style)

# Define static layout
layout = StaticLayoutProvider(graph_layout=data['layout'])
layout = StaticLayoutProvider(graph_layout=layout)
node_source = self.handles['scatter_1_source']
edge_source = self.handles['multi_line_1_source']
renderer = plot.graph(node_source, edge_source, layout, **properties)
Expand Down Expand Up @@ -187,6 +233,7 @@ def _init_glyphs(self, plot, element, ranges, source):
self.handles['hover'].renderers.append(renderer)



class NodePlot(PointPlot):
"""
Simple subclass of PointPlot which hides x, y position on hover.
Expand Down
Loading

0 comments on commit 162828a

Please sign in to comment.