Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for cupy histograms and box-whisker stats #4447

Merged
merged 5 commits into from
May 31, 2020
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions holoviews/core/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -1491,6 +1491,13 @@ def is_dask_array(data):
return (da is not None and isinstance(data, da.Array))


def is_cupy_array(data):
if 'cupy' in sys.modules:
import cupy
return isinstance(data, cupy.ndarray)
return False


def get_param_values(data):
params = dict(kdims=data.kdims, vdims=data.vdims,
label=data.label)
Expand Down
92 changes: 62 additions & 30 deletions holoviews/operation/element.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,19 @@
"""
from __future__ import division

import numpy as np
from distutils.version import LooseVersion

import numpy as np
import param

from param import _is_number

from ..core import (Operation, NdOverlay, Overlay, GridMatrix,
HoloMap, Dataset, Element, Collator, Dimension)
from ..core.data import ArrayInterface, DictInterface, default_datatype
from ..core.util import (group_sanitizer, label_sanitizer, pd,
basestring, datetime_types, isfinite, dt_to_int,
isdatetime, is_dask_array)
isdatetime, is_dask_array, is_cupy_array)
from ..element.chart import Histogram, Scatter
from ..element.raster import Image, RGB
from ..element.path import Contours, Polygons
Expand Down Expand Up @@ -641,7 +643,10 @@ class histogram(Operation):
Specifies the range within which to compute the bins.""")

bins = param.ClassSelector(default=None, class_=(np.ndarray, list, tuple, str), doc="""
An explicit set of bin edges.""")
An explicit set of bin edges or a method to find the optimal
set of bin edges, e.g. 'auto', 'fd', 'scott' etc. For more
documentation on these approaches see the np.histogram_bin_edges
documentation.""")

cumulative = param.Boolean(default=False, doc="""
Whether to compute the cumulative histogram""")
Expand Down Expand Up @@ -690,6 +695,7 @@ def _process(self, element, key=None):
self.p.groupby = None
return grouped.map(self._process, Dataset)

normed = False if self.p.mean_weighted and self.p.weight_dimension else self.p.normed
if self.p.dimension:
selected_dim = self.p.dimension
else:
Expand All @@ -701,16 +707,36 @@ def _process(self, element, key=None):
else:
data = element.dimension_values(selected_dim)

is_datetime = isdatetime(data)
if is_datetime:
data = data.astype('datetime64[ns]').astype('int64')

# Handle different datatypes
is_finite = isfinite
is_cupy = is_cupy_array(data)
if is_cupy:
import cupy
full_cupy_support = LooseVersion(cupy.__version__) > '8.0'
if not full_cupy_support and (normed or self.p.weight_dimension):
data = cupy.asnumpy(data)
is_cupy = False
if is_dask_array(data):
import dask.array as da
histogram = da.histogram
elif is_cupy:
import cupy
histogram = cupy.histogram
is_finite = cupy.isfinite
else:
histogram = np.histogram

mask = isfinite(data)
# Mask data
mask = is_finite(data)
if self.p.nonzero:
mask = mask & (data > 0)
data = data[mask]

# Compute weights
if self.p.weight_dimension:
if hasattr(element, 'interface'):
weights = element.interface.values(element, self.p.weight_dimension, compute=False)
Expand All @@ -721,35 +747,36 @@ def _process(self, element, key=None):
else:
weights = None

hist_range = self.p.bin_range or element.range(selected_dim)
# Avoids range issues including zero bin range and empty bins
if hist_range == (0, 0) or any(not isfinite(r) for r in hist_range):
hist_range = (0, 1)

datetimes = False
bins = None if self.p.bins is None else np.asarray(self.p.bins)
steps = self.p.num_bins + 1
start, end = hist_range
if isdatetime(data):
start, end = dt_to_int(start, 'ns'), dt_to_int(end, 'ns')
datetimes = True
data = data.astype('datetime64[ns]').astype('int64')
if bins is not None:
bins = bins.astype('datetime64[ns]').astype('int64')
# Compute bins
if isinstance(self.p.bins, str):
bin_data = cupy.asnumpy(data) if is_cupy else data
edges = np.histogram_bin_edges(bin_data, bins=self.p.bins)
elif isinstance(self.p.bins, (list, np.ndarray)):
edges = self.p.bins
if isdatetime(edges):
edges = edges.astype('datetime64[ns]').astype('int64')
else:
hist_range = self.p.bin_range or element.range(selected_dim)
# Avoids range issues including zero bin range and empty bins
if hist_range == (0, 0) or any(not isfinite(r) for r in hist_range):
hist_range = (0, 1)
steps = self.p.num_bins + 1
start, end = hist_range
if is_datetime:
start, end = dt_to_int(start, 'ns'), dt_to_int(end, 'ns')
if self.p.log:
bin_min = max([abs(start), data[data>0].min()])
edges = np.logspace(np.log10(bin_min), np.log10(end), steps)
else:
hist_range = start, end
edges = np.linspace(start, end, steps)

if self.p.bins:
edges = bins
elif self.p.log:
bin_min = max([abs(start), data[data>0].min()])
edges = np.logspace(np.log10(bin_min), np.log10(end), steps)
else:
edges = np.linspace(start, end, steps)
normed = False if self.p.mean_weighted and self.p.weight_dimension else self.p.normed
if is_cupy:
edges = cupy.asarray(edges)

if is_dask_array(data) or len(data):
if normed:
if is_cupy and not full_cupy_support:
hist, _ = histogram(data, bins=edges)
elif normed:
# This covers True, 'height', 'integral'
hist, edges = histogram(data, density=True,
weights=weights, bins=edges)
Expand All @@ -763,8 +790,13 @@ def _process(self, element, key=None):
else:
nbins = self.p.num_bins if self.p.bins is None else len(self.p.bins)-1
hist = np.zeros(nbins)

if is_cupy_array(hist):
edges = cupy.asnumpy(edges)
hist = cupy.asnumpy(hist)

hist[np.isnan(hist)] = 0
if datetimes:
if is_datetime:
edges = (edges/1e3).astype('datetime64[us]')

params = {}
Expand Down
33 changes: 27 additions & 6 deletions holoviews/plotting/bokeh/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
from ...core.dimension import Dimension, Dimensioned
from ...core.ndmapping import sorted_context
from ...core.util import (basestring, dimension_sanitizer, wrap_tuple,
unique_iterator, unique_array, isfinite)
unique_iterator, unique_array, isfinite,
is_dask_array, is_cupy_array)
from ...operation.stats import univariate_kde
from ...util.transform import dim
from .chart import AreaPlot
Expand Down Expand Up @@ -139,19 +140,38 @@ def _postprocess_hover(self, renderer, source):
super(BoxWhiskerPlot, self)._postprocess_hover(renderer, source)

def _box_stats(self, vals):
vals = vals[isfinite(vals)]
is_finite = isfinite
is_dask = is_dask_array(vals)
is_cupy = is_cupy_array(vals)
if is_cupy:
import cupy
percentile = cupy.percentile
is_finite = cupy.isfinite
elif is_dask:
import dask.array as da
percentile = da.percentile
else:
percentile = np.percentile

vals = vals[is_finite(vals)]

if len(vals):
q1, q2, q3 = (np.percentile(vals, q=q)
for q in range(25, 100, 25))
q1, q2, q3 = (percentile(vals, q=q) for q in range(25, 100, 25))
iqr = q3 - q1
upper = vals[vals <= q3 + 1.5*iqr].max()
lower = vals[vals >= q1 - 1.5*iqr].min()
else:
q1, q2, q3 = 0, 0, 0
upper, lower = 0, 0
outliers = vals[(vals > upper) | (vals < lower)]
return q1, q2, q3, upper, lower, outliers

if is_cupy:
return (q1.item(), q2.item(), q3.item(), upper.item(),
lower.item(), cupy.asnumpy(outliers))
elif is_dask:
return da.compute(q1, q2, q3, upper, lower, outliers)
else:
return q1, q2, q3, upper, lower, outliers

def get_data(self, element, ranges, style):
if element.kdims:
Expand Down Expand Up @@ -191,6 +211,7 @@ def get_data(self, element, ranges, style):
cdim, cidx = None, None

factors = []
vdim = element.vdims[0].name
for key, g in groups.items():
# Compute group label
if element.kdims:
Expand All @@ -208,7 +229,7 @@ def get_data(self, element, ranges, style):
factors.append(label)

# Compute statistics
vals = g.dimension_values(g.vdims[0])
vals = element.interface.values(element, vdim, compute=False)
philippjfr marked this conversation as resolved.
Show resolved Hide resolved
q1, q2, q3, upper, lower, outliers = self._box_stats(vals)

# Add to CDS data
Expand Down