Skip to content

Commit

Permalink
Add interpolate_curve operation and Curve interpolation plot option
Browse files Browse the repository at this point in the history
  • Loading branch information
philippjfr committed Feb 6, 2017
1 parent 32758c4 commit 8933fa2
Show file tree
Hide file tree
Showing 5 changed files with 111 additions and 5 deletions.
53 changes: 53 additions & 0 deletions holoviews/operation/element.py
Original file line number Diff line number Diff line change
Expand Up @@ -605,6 +605,59 @@ def _process(self, element, key=None):
return sliced




class interpolate_curve(ElementOperation):
"""
Resamples a Curve using the defined interpolation method, e.g.
to represent changes in y-values as steps.
"""

interpolation = param.ObjectSelector(objects=['steps-pre', 'steps-mid',
'steps-post', 'linear'],
default='steps-mid', doc="""
Controls the transition point of the step along the x-axis.""")

@classmethod
def pts_to_prestep(cls, x, y):
steps = np.zeros((2, 2 * len(x) - 1))
steps[0, 0::2] = x
steps[0, 1::2] = steps[0, 0:-2:2]
steps[1:, 0::2] = y
steps[1:, 1::2] = steps[1:, 2::2]
return steps

@classmethod
def pts_to_midstep(cls, x, y):
steps = np.zeros((2, 2 * len(x)))
x = np.asanyarray(x)
steps[0, 1:-1:2] = steps[0, 2::2] = (x[:-1] + x[1:]) / 2
steps[0, 0], steps[0, -1] = x[0], x[-1]
steps[1:, 0::2] = y
steps[1:, 1::2] = steps[1:, 0::2]
return steps

@classmethod
def pts_to_poststep(cls, x, y):
steps = np.zeros((2, 2 * len(x) - 1))
steps[0, 0::2] = x
steps[0, 1::2] = steps[0, 2::2]
steps[1:, 0::2] = y
steps[1:, 1::2] = steps[1:, 0:-2:2]
return steps

def _process(self, element, key=None):
INTERPOLATE_FUNCS = {'steps-pre': self.pts_to_prestep,
'steps-mid': self.pts_to_midstep,
'steps-post': self.pts_to_poststep}
if self.p.interpolation not in INTERPOLATE_FUNCS:
return element
x, y = element.dimension_values(0), element.dimension_values(1)
array = INTERPOLATE_FUNCS[self.p.interpolation](x, y)
dvals = tuple(element.dimension_values(d) for d in element.dimensions()[2:])
return element.clone((array[0, :], array[1, :])+dvals)


#==================#
# Other operations #
#==================#
Expand Down
23 changes: 20 additions & 3 deletions holoviews/plotting/bokeh/chart.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from ...element import Raster, Points, Polygons, Spikes
from ...core.util import max_range, basestring, dimension_sanitizer
from ...core.options import abbreviated_exception
from ...operation import interpolate_curve
from ..util import compute_sizes, get_sideplot_ranges, match_spec, map_colors
from .element import ElementPlot, ColorbarPlot, LegendPlot, line_properties, fill_properties
from .path import PathPlot, PolygonPlot
Expand Down Expand Up @@ -138,11 +139,20 @@ def _init_glyph(self, plot, mapping, properties):

class CurvePlot(ElementPlot):

interpolation = param.ObjectSelector(objects=['linear', 'steps-mid',
'steps-pre', 'steps-post'],
default='linear', doc="""
Defines how the samples of the Curve are interpolated,
default is 'linear', other options include 'steps-mid',
'steps-pre' and 'steps-post'.""")

style_opts = ['color'] + line_properties
_plot_methods = dict(single='line', batched='multi_line')
_mapping = {p: p for p in ['xs', 'ys', 'color', 'line_alpha']}

def get_data(self, element, ranges=None, empty=False):
if 'steps' in self.interpolation:
element = interpolate_curve(element, interpolation=self.interpolation)
xidx, yidx = (1, 0) if self.invert_axes else (0, 1)
x = element.get_dimension(xidx).name
y = element.get_dimension(yidx).name
Expand Down Expand Up @@ -562,9 +572,7 @@ def _update_chart(self, key, element, ranges):

@property
def current_handles(self):
plot = self.handles['plot']
sources = plot.select(type=ColumnDataSource)
return sources
return self.state.select(type=(ColumnDataSource, Range1d))


class BoxPlot(ChartPlot):
Expand Down Expand Up @@ -596,6 +604,15 @@ def _init_chart(self, element, ranges):
**properties)


def _update_chart(self, key, element, ranges):
super(BoxPlot, self)._update_chart(key, element, ranges)
vdim = element.vdims[0].name
start, end = ranges[vdim]
self.state.y_range.start = start
self.state.y_range.end = end



class BarPlot(ChartPlot):
"""
BarPlot allows generating single- or multi-category
Expand Down
10 changes: 10 additions & 0 deletions holoviews/plotting/mpl/chart.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from ...core.util import (match_spec, unique_iterator, safe_unicode,
basestring, max_range, unicode)
from ...element import Points, Raster, Polygons, HeatMap
from ...operation import interpolate_curve
from ..util import compute_sizes, get_sideplot_ranges, map_colors
from .element import ElementPlot, ColorbarPlot, LegendPlot
from .path import PathPlot
Expand Down Expand Up @@ -41,6 +42,13 @@ class CurvePlot(ChartPlot):
Whether to let matplotlib automatically compute tick marks
or to allow the user to control tick marks.""")

interpolation = param.ObjectSelector(objects=['linear', 'steps-mid',
'steps-pre', 'steps-post'],
default='linear', doc="""
Defines how the samples of the Curve are interpolated,
default is 'linear', other options include 'steps-mid',
'steps-pre' and 'steps-post'.""")

relative_labels = param.Boolean(default=False, doc="""
If plotted quantity is cyclic and center_cyclic is enabled,
will compute tick labels relative to the center.""")
Expand All @@ -59,6 +67,8 @@ class CurvePlot(ChartPlot):
_plot_methods = dict(single='plot')

def get_data(self, element, ranges, style):
if 'steps' in self.interpolation:
element = interpolate_curve(element, interpolation=self.interpolation)
xs = element.dimension_values(0)
ys = element.dimension_values(1)
dims = element.dimensions()
Expand Down
10 changes: 10 additions & 0 deletions holoviews/plotting/plotly/chart.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from plotly.tools import FigureFactory as FF

from ...core import util
from ...operation import interpolate_curve
from .element import ElementPlot, ColorbarPlot


Expand Down Expand Up @@ -41,11 +42,20 @@ def get_data(self, element, ranges):

class CurvePlot(ElementPlot):

interpolation = param.ObjectSelector(objects=['linear', 'steps-mid',
'steps-pre', 'steps-post'],
default='linear', doc="""
Defines how the samples of the Curve are interpolated,
default is 'linear', other options include 'steps-mid',
'steps-pre' and 'steps-post'.""")

graph_obj = go.Scatter

style_opts = ['color', 'dash', 'width']

def graph_options(self, element, ranges):
if 'steps' in self.interpolation:
element = interpolate_curve(element, interpolation=self.interpolation)
opts = super(CurvePlot, self).graph_options(element, ranges)
opts['mode'] = 'lines'
style = self.style[self.cyclic_index]
Expand Down
20 changes: 18 additions & 2 deletions tests/testoperation.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import numpy as np

from holoviews import (HoloMap, NdOverlay, Image, Contours, Polygons, Points,
Histogram)
Histogram, Curve)
from holoviews.element.comparison import ComparisonTestCase
from holoviews.operation.element import (operation, transform, threshold,
gradient, contours, histogram)
gradient, contours, histogram,
interpolate_curve)

class ElementOperationTests(ComparisonTestCase):
"""
Expand Down Expand Up @@ -83,3 +84,18 @@ def test_points_histogram_mean_weighted(self):
op_hist = histogram(points, num_bins=3, weight_dimension='y', mean_weighted=True)
hist = Histogram(([1., 4., 7.5], [0, 3, 6, 9]), vdims=['y'])
self.assertEqual(op_hist, hist)

def test_interpolate_curve_pre(self):
interpolated = interpolate_curve(Curve([0, 0.5, 1]), interpolation='steps-pre')
curve = Curve([(0, 0), (0, 0.5), (1, 0.5), (1, 1), (2, 1)])
self.assertEqual(interpolated, curve)

def test_interpolate_curve_mid(self):
interpolated = interpolate_curve(Curve([0, 0.5, 1]), interpolation='steps-mid')
curve = Curve([(0, 0), (0.5, 0), (0.5, 0.5), (1.5, 0.5), (1.5, 1), (2, 1)])
self.assertEqual(interpolated, curve)

def test_interpolate_curve_post(self):
interpolated = interpolate_curve(Curve([0, 0.5, 1]), interpolation='steps-post')
curve = Curve([(0, 0), (1, 0), (1, 0.5), (2, 0.5), (2, 1)])
self.assertEqual(interpolated, curve)

0 comments on commit 8933fa2

Please sign in to comment.