Skip to content

Commit

Permalink
Fixes for HeatMap implementations
Browse files Browse the repository at this point in the history
  • Loading branch information
philippjfr committed Jan 8, 2017
1 parent 1d3d57e commit 69a9793
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 18 deletions.
6 changes: 0 additions & 6 deletions holoviews/element/raster.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,12 +385,6 @@ class HeatMap(Dataset, Element2D):
depth = 1


def __init__(self, data, **params):
super(HeatMap, self).__init__(data, **params)
shape = (len(self.dimension_values(1)), len(self.dimension_values(0)))
self.extents = (0., 0., shape[0], shape[1])


class Image(SheetCoordinateSystem, Raster):
"""
Image is the atomic unit as which 2D data is stored, along with
Expand Down
13 changes: 7 additions & 6 deletions holoviews/plotting/bokeh/raster.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
except ImportError:
LogColorMapper = None

from ...core.util import cartesian_product, is_nan
from ...core.util import cartesian_product, is_nan, unique_array
from ...element import Image, Raster, RGB
from ..renderer import SkipRendering
from ..util import map_colors, get_2d_aggregate
Expand Down Expand Up @@ -136,7 +136,7 @@ class HeatmapPlot(ColorbarPlot):
def _axes_props(self, plots, subplots, element, ranges):
dims = element.dimensions()
labels = self._get_axis_labels(dims)
xvals, yvals = [element.dimension_values(i, False)
xvals, yvals = [np.sort(unique_array(element.dimension_values(i, False)))
for i in range(2)]
if self.invert_yaxis: yvals = yvals[::-1]
plot_ranges = {'x_range': [str(x) for x in xvals],
Expand All @@ -145,17 +145,18 @@ def _axes_props(self, plots, subplots, element, ranges):

def get_data(self, element, ranges=None, empty=False):
x, y, z = element.dimensions(label=True)[:3]
aggregate = get_2d_aggregate(element)
aggregate = get_2d_aggregate(element).sort()
style = self.style[self.cyclic_index]
cmapper = self._get_colormapper(element.vdims[0], element, ranges, style)
if empty:
data = {x: [], y: [], z: [], 'color': []}
data = {x: [], y: [], z: []}
else:
zvals = aggregate.dimension_values(z)
xvals, yvals = [[str(v) for v in element.dimension_values(i)]
xvals, yvals = [[str(v) for v in aggregate.dimension_values(i)]
for i in range(2)]
data = {x: xvals, y: yvals, z: zvals}
if 'hover' in self.tools:

if 'hover' in self.tools+self.default_tools:
for vdim in element.vdims[1:]:
data[vdim.name] = ['' if is_nan(v) else v
for v in aggregate.dimension_values(vdim)]
Expand Down
12 changes: 7 additions & 5 deletions holoviews/plotting/mpl/raster.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from ...core import CompositeOverlay, Element
from ...core import traversal
from ...core.util import match_spec, max_range, unique_iterator
from ...core.util import match_spec, max_range, unique_iterator, unique_array
from ...element.raster import Image, Raster, RGB
from .element import ColorbarPlot, OverlayPlot
from .plot import MPLPlot, GridPlot
Expand Down Expand Up @@ -130,7 +130,7 @@ def _annotate_values(self, element):

def _compute_ticks(self, element, ranges):
xdim, ydim = element.kdims
dim1_keys, dim2_keys = [element.dimension_values(i, False)
dim1_keys, dim2_keys = [np.sort(unique_array(element.dimension_values(i, False)))
for i in range(2)]
num_x, num_y = len(dim1_keys), len(dim2_keys)
xpos = np.linspace(.5, num_x-0.5, num_x)
Expand All @@ -153,8 +153,9 @@ def init_artists(self, ax, plot_args, plot_kwargs):

def get_data(self, element, ranges, style):
_, style, axis_kwargs = super(HeatMapPlot, self).get_data(element, ranges, style)
shape = tuple(len(element.dimension_values(i)) for i in range(2))
aggregate = get_2d_aggregate(element)
shape = tuple(len(unique_array(element.dimension_values(i)))
for i in range(2))
aggregate = get_2d_aggregate(element).sort()
data = np.flipud(aggregate.dimension_values(2).reshape(shape[::-1]))
data = np.ma.array(data, mask=np.logical_not(np.isfinite(data)))
cmap_name = style.pop('cmap', None)
Expand All @@ -171,7 +172,8 @@ def update_handles(self, key, axis, element, ranges, style):
im = self.handles['artist']
data, style, axis_kwargs = self.get_data(element, ranges, style)
im.set_data(data[0])
im.set_extent((l, r, b, t))
shape = data[0].shape
im.set_extent((0, shape[1], 0, shape[0]))
im.set_clim((style['vmin'], style['vmax']))
if 'norm' in style:
im.norm = style['norm']
Expand Down
2 changes: 1 addition & 1 deletion holoviews/plotting/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,7 +489,7 @@ def reduce_fn(x):
"""
Aggregation function to get the first non-zero value.
"""
values = x.values if pd and isinstance(pd.Series, values) else x
values = x.values if pd and isinstance(x, pd.Series) else x
for v in values:
if not is_nan(v):
return v
Expand Down

0 comments on commit 69a9793

Please sign in to comment.