as_shared_dtype coerces scalars into numpy regardless of other array types #4231

Closed
jacobtomlinson opened this issue Jul 16, 2020 · 0 comments · Fixed by #4232

jacobtomlinson commented Jul 16, 2020

Related to #4212

When trying to get the "Calculating Seasonal Averages from Timeseries of Monthly Means" example from the documentation to work with cupy, I'm getting an unexpected Unsupported type <class 'numpy.ndarray'> error when calling ds_unweighted = ds.groupby('time.season').mean('time')

I dug through this with @quasiben and it seems to be related to the as_shared_dtype function.

What happened:

Running the MCVE below results in TypeError: Unsupported type <class 'numpy.ndarray'>. It seems that somewhere in the stack there is a call to _replace_nan(a, 0), which replaces the NaN values in the cupy array with 0. This ends up as a call to xarray.core.duck_array_ops.where with the NaN mask, 0, and the cupy array as its arguments.

However, where calls as_shared_dtype on the 0 and the cupy array before passing them on to _where, and this converts the 0 into a 0-d (scalar) numpy array.

This numpy array is then passed to cupy's where function, which raises the exception.
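
To make the call path concrete, here is a simplified, numpy-only re-creation of the frames from the traceback (_replace_nan, where_method, where, as_shared_dtype). These are hypothetical stand-ins written for illustration, not xarray's actual source; in particular, the rule in _asarray for telling arrays from scalars (checking for __array_function__) is an assumption.

```python
import numpy as np


def _asarray(x):
    # Assumption: objects implementing __array_function__ (numpy.ndarray,
    # cupy.ndarray, ...) pass through untouched; everything else, including
    # a plain Python scalar such as 0, is converted with np.asarray.
    return x if hasattr(x, "__array_function__") else np.asarray(x)


def as_shared_dtype(scalars_or_arrays):
    # Simplified stand-in: the scalar 0 always comes out as a 0-d *numpy*
    # array, whatever array type it is about to be combined with.
    arrays = [_asarray(x) for x in scalars_or_arrays]
    out_type = np.result_type(*arrays)
    return [x.astype(out_type, copy=False) for x in arrays]


def where(condition, x, y, _where=np.where):
    # Mirrors the traceback frame: _where(condition, *as_shared_dtype([x, y])).
    # With cupy data, _where dispatches to cupy.where, which then receives
    # the numpy array produced for the scalar.
    return _where(condition, *as_shared_dtype([x, y]))


def where_method(data, cond, other):
    return where(cond, data, other)


def _replace_nan(a, val):
    # Fill the NaNs in `a` with `val` and also return the NaN mask.
    mask = np.isnan(a)
    return where_method(val, mask, a), mask


a = np.array([1.0, np.nan, 3.0])
filled, mask = _replace_nan(a, 0)
print(filled)                            # [1. 0. 3.]
print(type(as_shared_dtype([0, a])[0]))  # <class 'numpy.ndarray'>
```

With numpy data this is harmless, because np.where happily accepts a 0-d numpy array; the problem only appears when the other argument belongs to a different array library.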

What you expected to happen:

The cupy.where function accepts either a Python int/float or a cupy array, but not a numpy scalar.

Therefore a few things could be done here:

  1. Xarray could avoid converting the int/float to a numpy array.
  2. Xarray could convert it to a cupy array instead.
  3. Cupy could be modified to accept a numpy scalar.
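
Option 1 could look roughly like the sketch below: the arrays are cast to the common dtype, but Python scalars are returned unchanged so that each library's own where can handle them. The function name as_shared_dtype_keep_scalars is hypothetical, and the sketch assumes every non-scalar input implements __array_function__ and exposes a numpy-compatible dtype.

```python
import numpy as np


def _is_array(x):
    # Assumption: anything implementing NumPy's __array_function__ protocol
    # (numpy.ndarray, cupy.ndarray, ...) counts as an array.
    return hasattr(x, "__array_function__")


def as_shared_dtype_keep_scalars(scalars_or_arrays):
    """Hypothetical variant of as_shared_dtype for option 1.

    The common dtype is computed from the arrays' dtypes plus the raw Python
    scalars, and only the arrays are cast; scalars are returned unchanged,
    so no numpy array is ever created for them.
    """
    dtypes = [x.dtype for x in scalars_or_arrays if _is_array(x)]
    scalars = [x for x in scalars_or_arrays if not _is_array(x)]
    out_type = np.result_type(*dtypes, *scalars)
    return [
        x.astype(out_type, copy=False) if _is_array(x) else x
        for x in scalars_or_arrays
    ]
```

The trade-off is that the scalar keeps its Python type, so the target library's where must accept Python scalars, which is what the issue reports cupy.where doing.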

We threw together a quick fix for option 2, which I'll put in a draft PR, but we're happy to discuss the alternatives.

Minimal Complete Verifiable Example:

import numpy as np
import pandas as pd
import xarray as xr
import matplotlib.pyplot as plt

import cupy as cp

# Load data
ds = xr.tutorial.open_dataset("rasm").load()

# Move data to GPU
ds.Tair.data = cp.asarray(ds.Tair.data)


ds_unweighted = ds.groupby("time.season").mean("time")

# Calculate the weights by grouping by 'time.season'.
month_length = ds.time.dt.days_in_month
weights = (
    month_length.groupby("time.season") / month_length.groupby("time.season").sum()
)
# Test that the sum of the weights for each season is 1.0
np.testing.assert_allclose(weights.groupby("time.season").sum().values, np.ones(4))

# Move weights to GPU
weights.data = cp.asarray(weights.data)


# Calculate the weighted average
ds_weighted = ds * weights
ds_weighted = ds_weighted.groupby("time.season")
ds_weighted = ds_weighted.sum(dim="time")
Traceback
Traceback (most recent call last):
  File "/home/jacob/miniconda3/envs/dask/lib/python3.7/runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "/home/jacob/miniconda3/envs/dask/lib/python3.7/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/home/jacob/.vscode-server/extensions/ms-python.python-2020.6.91350/pythonFiles/lib/python/debugpy/__main__.py", line 45, in <module>
    cli.main()
  File "/home/jacob/.vscode-server/extensions/ms-python.python-2020.6.91350/pythonFiles/lib/python/debugpy/../debugpy/server/cli.py", line 430, in main
    run()
  File "/home/jacob/.vscode-server/extensions/ms-python.python-2020.6.91350/pythonFiles/lib/python/debugpy/../debugpy/server/cli.py", line 267, in run_file
    runpy.run_path(options.target, run_name=compat.force_str("__main__"))
  File "/home/jacob/miniconda3/envs/dask/lib/python3.7/runpy.py", line 263, in run_path
    pkg_name=pkg_name, script_name=fname)
  File "/home/jacob/miniconda3/envs/dask/lib/python3.7/runpy.py", line 96, in _run_module_code
    mod_name, mod_spec, pkg_name, script_name)
  File "/home/jacob/miniconda3/envs/dask/lib/python3.7/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/home/jacob/Projects/pydata/xarray/test_seasonal_averages.py", line 32, in <module>
    ds_weighted = ds_weighted.sum(dim="time")
  File "/home/jacob/Projects/pydata/xarray/xarray/core/common.py", line 84, in wrapped_func
    func, dim, skipna=skipna, numeric_only=numeric_only, **kwargs
  File "/home/jacob/Projects/pydata/xarray/xarray/core/groupby.py", line 994, in reduce
    return self.map(reduce_dataset)
  File "/home/jacob/Projects/pydata/xarray/xarray/core/groupby.py", line 923, in map
    return self._combine(applied)
  File "/home/jacob/Projects/pydata/xarray/xarray/core/groupby.py", line 943, in _combine
    applied_example, applied = peek_at(applied)
  File "/home/jacob/Projects/pydata/xarray/xarray/core/utils.py", line 183, in peek_at
    peek = next(gen)
  File "/home/jacob/Projects/pydata/xarray/xarray/core/groupby.py", line 922, in <genexpr>
    applied = (func(ds, *args, **kwargs) for ds in self._iter_grouped())
  File "/home/jacob/Projects/pydata/xarray/xarray/core/groupby.py", line 990, in reduce_dataset
    return ds.reduce(func, dim, keep_attrs, **kwargs)
  File "/home/jacob/Projects/pydata/xarray/xarray/core/dataset.py", line 4313, in reduce
    **kwargs,
  File "/home/jacob/Projects/pydata/xarray/xarray/core/variable.py", line 1591, in reduce
    data = func(input_data, axis=axis, **kwargs)
  File "/home/jacob/Projects/pydata/xarray/xarray/core/duck_array_ops.py", line 324, in f
    return func(values, axis=axis, **kwargs)
  File "/home/jacob/Projects/pydata/xarray/xarray/core/nanops.py", line 111, in nansum
    a, mask = _replace_nan(a, 0)
  File "/home/jacob/Projects/pydata/xarray/xarray/core/nanops.py", line 21, in _replace_nan
    return where_method(val, mask, a), mask
  File "/home/jacob/Projects/pydata/xarray/xarray/core/duck_array_ops.py", line 274, in where_method
    return where(cond, data, other)
  File "/home/jacob/Projects/pydata/xarray/xarray/core/duck_array_ops.py", line 268, in where
    return _where(condition, *as_shared_dtype([x, y]))
  File "/home/jacob/Projects/pydata/xarray/xarray/core/duck_array_ops.py", line 56, in f
    return wrapped(*args, **kwargs)
  File "<__array_function__ internals>", line 6, in where
  File "cupy/core/core.pyx", line 1343, in cupy.core.core.ndarray.__array_function__
  File "/home/jacob/miniconda3/envs/dask/lib/python3.7/site-packages/cupy/sorting/search.py", line 211, in where
    return _where_ufunc(condition.astype('?'), x, y)
  File "cupy/core/_kernel.pyx", line 906, in cupy.core._kernel.ufunc.__call__
  File "cupy/core/_kernel.pyx", line 90, in cupy.core._kernel._preprocess_args
TypeError: Unsupported type <class 'numpy.ndarray'>

Anything else we need to know?:

Environment:

Output of xr.show_versions()

INSTALLED VERSIONS

commit: 52043bc
python: 3.7.6 | packaged by conda-forge | (default, Jun 1 2020, 18:57:50)
[GCC 7.5.0]
python-bits: 64
OS: Linux
OS-release: 5.3.0-62-generic
machine: x86_64
processor: x86_64
byteorder: little
LC_ALL: None
LANG: en_GB.UTF-8
LOCALE: en_GB.UTF-8
libhdf5: None
libnetcdf: None

xarray: 0.15.1
pandas: 0.25.3
numpy: 1.18.5
scipy: 1.5.0
netCDF4: None
pydap: None
h5netcdf: None
h5py: None
Nio: None
zarr: None
cftime: 1.2.0
nc_time_axis: None
PseudoNetCDF: None
rasterio: None
cfgrib: 0.9.8.3
iris: None
bottleneck: None
dask: 2.20.0
distributed: 2.20.0
matplotlib: 3.2.2
cartopy: 0.17.0
seaborn: 0.10.1
numbagg: None
pint: None
setuptools: 49.1.0.post20200704
pip: 20.1.1
conda: None
pytest: 5.4.3
IPython: 7.16.1
sphinx: None
