Skip to content

Commit

Permalink
Fixed bug in canonicalizing xarray data (#1524)
Browse files Browse the repository at this point in the history
* Fixed bug canonicalizing xarray data

* XArray constructor does not override explicit kdims
  • Loading branch information
philippjfr authored and jlstevens committed Jun 7, 2017
1 parent 7b57136 commit b8dae63
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 7 deletions.
8 changes: 4 additions & 4 deletions holoviews/core/data/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,16 +154,16 @@ def canonicalize(cls, dataset, data, coord_dims=None):
dims = [name for name in coord_dims[::-1]
if isinstance(cls.coords(dataset, name), np.ndarray)]
dropped = [dims.index(d) for d in dims if d not in dataset.kdims]
inds = [dims.index(kd.name) for kd in dataset.kdims]
inds += dropped
inds = [dims.index(kd.name)for kd in dataset.kdims]
inds = [i - sum([1 for d in dropped if i>=d]) for i in inds]
if dropped:
data = data.squeeze(axis=tuple(dropped))
if inds:
data = data.transpose(inds)

# Allow lower dimensional views into data
if len(dataset.kdims) < 2:
data = data.flatten()
elif dropped:
data = data.squeeze(axis=tuple(range(len(dropped))))
return data


Expand Down
6 changes: 3 additions & 3 deletions holoviews/core/data/xarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ def init(cls, eltype, data, kdims, vdims):
elif len(vdim_param.default) == 1:
vdim = vdim_param.default[0]
vdims = [vdim]
kdims = [Dimension(d) for d in data.dims[::-1]]
if not kdims:
kdims = [Dimension(d) for d in data.dims[::-1]]
data = data.to_dataset(name=vdim.name)
elif not isinstance(data, xr.Dataset):
if kdims is None:
Expand Down Expand Up @@ -148,7 +149,6 @@ def groupby(cls, dataset, dimensions, container_type, group_type, **kwargs):

@classmethod
def coords(cls, dataset, dim, ordered=False, expanded=False):
dim = dataset.get_dimension(dim, strict=True).name
if expanded:
return util.expand_grid_coords(dataset, dim)
data = np.atleast_1d(dataset.data[dim].data)
Expand All @@ -162,7 +162,7 @@ def values(cls, dataset, dim, expanded=True, flat=True):
dim = dataset.get_dimension(dim, strict=True)
data = dataset.data[dim.name].data
if dim in dataset.vdims:
coord_dims = dataset.data[dim.name].dims
coord_dims = list(dataset.data.dims.keys())[::-1]
if dask and isinstance(data, dask.array.Array):
data = data.compute()
data = cls.canonicalize(dataset, data, coord_dims=coord_dims)
Expand Down
9 changes: 9 additions & 0 deletions tests/testdataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1055,6 +1055,15 @@ def setUp(self):
self.init_column_data()
self.init_grid_data()

def test_xarray_dataset_with_scalar_dim_canonicalize(self):
import xarray as xr
xs = [0, 1]
ys = [0.1, 0.2, 0.3]
zs = np.array([[[0, 1], [2, 3], [4, 5]]])
xrarr = xr.DataArray(zs, coords={'x': xs, 'y': ys, 't': [1]}, dims=['t', 'y', 'x'])
ds = Dataset(xrarr, kdims=['x', 'y'], vdims=['z'], datatype=['xarray'])
self.assertEqual(ds.dimension_values(2, flat=False).ndim, 2)

# Disabled tests for NotImplemented methods
def test_dataset_add_dimensions_values_hm(self):
raise SkipTest("Not supported")
Expand Down

0 comments on commit b8dae63

Please sign in to comment.