Skip to content

Commit

Permalink
Various fixes for datashader
Browse files Browse the repository at this point in the history
  • Loading branch information
philippjfr committed Oct 8, 2019
1 parent 9e0c6ca commit 794d75e
Showing 1 changed file with 57 additions and 19 deletions.
76 changes: 57 additions & 19 deletions holoviews/operation/datashader.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,7 @@
LooseVersion, basestring, cftime_types, cftime_to_timestamp,
datetime_types, dt_to_int, get_param_values, max_range)
from ..element import (Image, Path, Curve, RGB, Graph, TriMesh,
QuadMesh, Contours, Spikes, Area, Spread,
Scatter, Points)
QuadMesh, Contours, Scatter, Points)
from ..streams import RangeXY, PlotSize

ds_version = LooseVersion(ds.__version__)
Expand Down Expand Up @@ -114,7 +113,8 @@ def instance(self_or_cls,**params):
inst._precomputed = {}
return inst

def _get_sampling(self, element, x, y):

def _get_sampling(self, element, x, y, ndim=2, default=None):
target = self.p.target
if not isinstance(x, list) and x is not None:
x = [x]
Expand All @@ -128,7 +128,8 @@ def _get_sampling(self, element, x, y):
else:
if x is None:
x_range = self.p.x_range or (-0.5, 0.5)
y_range = self.p.y_range or (-0.5, 0.5)
elif self.p.expand or not self.p.x_range:
x_range = self.p.x_range or max_range([element.range(xd) for xd in x])
else:
x0, x1 = self.p.x_range
ex0, ex1 = max_range([element.range(xd) for xd in x])
Expand All @@ -145,18 +146,9 @@ def _get_sampling(self, element, x, y):
if default is None:
ey0, ey1 = max_range([element.range(yd) for yd in y])
else:
x0, x1 = self.p.x_range
ex0, ex1 = element.range(x)
x_range = (np.min([np.max([x0, ex0]), ex1]),
np.max([np.min([x1, ex1]), ex0]))

if self.p.expand or not self.p.y_range:
y_range = self.p.y_range or element.range(y)
else:
y0, y1 = self.p.y_range
ey0, ey1 = element.range(y)
y_range = (np.min([np.max([y0, ey0]), ey1]),
np.max([np.min([y1, ey1]), ey0]))
ey0, ey1 = default
y_range = (np.min([np.max([y0, ey0]), ey1]),
np.max([np.min([y1, ey1]), ey0]))
width, height = self.p.width, self.p.height
(xstart, xend), (ystart, yend) = x_range, y_range

Expand All @@ -168,7 +160,6 @@ def _get_sampling(self, element, x, y):
xstart, xend = 0, 0
if x and element.get_dimension_type(x[0]) in datetime_types:
xtype = 'datetime'
x_range = (xstart, xend)

ytype = 'numeric'
if isinstance(ystart, datetime_types) or isinstance(yend, datetime_types):
Expand All @@ -178,7 +169,6 @@ def _get_sampling(self, element, x, y):
ystart, yend = 0, 0
if y and element.get_dimension_type(y[0]) in datetime_types:
ytype = 'datetime'
y_range = (ystart, yend)

# Compute highest allowed sampling density
xspan = xend - xstart
Expand All @@ -198,7 +188,18 @@ def _get_sampling(self, element, x, y):
xs, ys = (np.linspace(xstart+xunit/2., xend-xunit/2., width),
np.linspace(ystart+yunit/2., yend-yunit/2., height))

return (x_range, y_range), (xs, ys), (width, height), (xtype, ytype)
return ((xstart, xend), (ystart, yend)), (xs, ys), (width, height), (xtype, ytype)


def _dt_transform(self, x_range, y_range, xs, ys, xtype, ytype):
(xstart, xend), (ystart, yend) = x_range, y_range
if xtype == 'datetime':
xstart, xend = (np.array([xstart, xend])/1e3).astype('datetime64[us]')
xs = (xs/1e3).astype('datetime64[us]')
if ytype == 'datetime':
ystart, yend = (np.array([ystart, yend])/1e3).astype('datetime64[us]')
ys = (ys/1e3).astype('datetime64[us]')
return ((xstart, xend), (ystart, yend)), (xs, ys)



Expand Down Expand Up @@ -260,6 +261,43 @@ def _get_aggregator(self, element, add_field=True):
return agg


def _empty_agg(self, element, x, y, width, height, xs, ys, agg_fn, **params):
x = x.name if x else 'x'
y = y.name if x else 'y'
xarray = xr.DataArray(np.full((height, width), np.NaN),
dims=[y, x], coords={x: xs, y: ys})
if width == 0:
params['xdensity'] = 1
if height == 0:
params['ydensity'] = 1
el = self.p.element_type(xarray, **params)
if isinstance(agg_fn, ds.count_cat):
vals = element.dimension_values(agg_fn.column, expanded=False)
dim = element.get_dimension(agg_fn.column)
return NdOverlay({v: el for v in vals}, dim)
return el


def _get_agg_params(self, element, x, y, agg_fn, bounds):
params = dict(get_param_values(element), kdims=[x, y],
datatype=['xarray'], bounds=bounds)

column = agg_fn.column if agg_fn else None
if column:
dims = [d for d in element.dimensions('ranges') if d == column]
if not dims:
raise ValueError("Aggregation column %s not found on %s element. "
"Ensure the aggregator references an existing "
"dimension." % (column,element))
name = '%s Count' % column if isinstance(agg_fn, ds.count_cat) else column
vdims = [dims[0].clone(name)]
else:
vdims = Dimension('Count')
params['vdims'] = vdims
return params



class aggregate(AggregationOperation):
"""
aggregate implements 2D binning for any valid HoloViews Element
Expand Down

0 comments on commit 794d75e

Please sign in to comment.