When trying to get the "Calculating Seasonal Averages from Timeseries of Monthly Means" example from the documentation to work with CuPy, I'm experiencing an unexpected `Unsupported type <class 'numpy.ndarray'>` error when calling `ds_unweighted = ds.groupby('time.season').mean('time')`.
I dug through this with @quasiben and it seems to be related to the `as_shared_dtype` function.
What happened:
Running the MCVE below results in `Unsupported type <class 'numpy.ndarray'>`. It seems that somewhere in the stack there is a call to `_replace_nan(a, 0)`, where NaN values in the CuPy array are replaced with `0`. This ends up as a call to `xarray.core.duck_array_ops.where` with the "is NaN" mask, `0` and the CuPy array as arguments.
However, `where` calls `as_shared_dtype` on the `0` and the CuPy array before handing them to `_where`, and this converts the `0` to a scalar (0-d) NumPy array.
CuPy's `where` function is then passed this NumPy array, which raises the exception.
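The promotion step described above can be reproduced with NumPy alone. The functions below are hypothetical, simplified stand-ins modelled on the call chain in the traceback (`_replace_nan` → `where` → `as_shared_dtype`); they are not xarray's actual code. The key assumption is that anything lacking `__array_function__` (such as a plain Python `0`) goes through `np.asarray`, so it always becomes a 0-d `numpy.ndarray`, even when the other operand is a CuPy array.

```python
import numpy as np


def as_shared_dtype_sketch(scalars_or_arrays):
    # Hypothetical, simplified stand-in for xarray's as_shared_dtype.
    # Assumption: objects without __array_function__ (e.g. a Python int)
    # go through np.asarray, so they become 0-d *NumPy* arrays even when
    # the other operand lives on the GPU.
    arrays = [
        x if hasattr(x, "__array_function__") else np.asarray(x)
        for x in scalars_or_arrays
    ]
    out_type = np.result_type(*arrays)
    return [x.astype(out_type, copy=False) for x in arrays]


def where_sketch(condition, x, y):
    # Mirrors the traceback: promote x and y first, then call where.
    return np.where(condition, *as_shared_dtype_sketch([x, y]))


def replace_nan_sketch(a, val):
    # Mirrors _replace_nan: use `val` wherever `a` is NaN, keep `a` elsewhere.
    mask = np.isnan(a)
    return where_sketch(mask, val, a), mask


a = np.array([1.0, np.nan, 3.0])
shared = as_shared_dtype_sketch([0, a])
print(type(shared[0]), shared[0].ndim)  # <class 'numpy.ndarray'> 0
filled, mask = replace_nan_sketch(a, 0)
print(filled)  # [1. 0. 3.]
```

With NumPy data this is harmless, but with a CuPy array in place of `a` the promoted `0` would still be a NumPy array, which is exactly what `cupy.where` rejects.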
What you expected to happen:
The `cupy.where` function accepts either a Python int/float or a CuPy array, but not a NumPy scalar (the 0-d NumPy array produced by `as_shared_dtype`).
Therefore, a few things could be done here:

1. Xarray could avoid converting the int/float to a NumPy array.
2. Xarray could convert it to a CuPy array instead.
3. CuPy could be modified to accept a NumPy scalar.
We threw together a quick fix for option 2, which I'll put in a draft PR, but we're happy to discuss the alternatives.
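Options 1 and 2 from the list above could look roughly like the NumPy-only sketch below. Both function names are hypothetical and neither is the code from the draft PR. The `xp` parameter stands for the array module (NumPy here, CuPy on the GPU); on real data it could be chosen with `cupy.get_array_module`. Option 1 assumes `np.result_type` can work out the shared dtype from a mix of Python scalars and arrays, and then leaves the scalar as a plain Python value; whether `np.result_type` accepts CuPy arrays directly is likewise an assumption not verified here.

```python
import numpy as np


def as_shared_dtype_option1(scalars_or_arrays):
    # Hypothetical option 1: leave Python scalars untouched so cupy.where
    # receives a plain int/float. Only real arrays are cast.
    out_type = np.result_type(*scalars_or_arrays)
    return [
        x.astype(out_type, copy=False) if hasattr(x, "__array_function__") else x
        for x in scalars_or_arrays
    ]


def as_shared_dtype_option2(scalars_or_arrays, xp=np):
    # Hypothetical option 2: build scalars with the array library `xp`
    # (numpy here; cupy on the GPU), so both operands come from one library.
    arrays = [
        x if hasattr(x, "__array_function__") else xp.asarray(x)
        for x in scalars_or_arrays
    ]
    out_type = np.result_type(*arrays)
    return [x.astype(out_type, copy=False) for x in arrays]


a = np.array([1.0, np.nan, 3.0])
scalar1, _ = as_shared_dtype_option1([0, a])
scalar2, _ = as_shared_dtype_option2([0, a], xp=np)
print(type(scalar1).__name__, type(scalar2).__name__)  # int ndarray
```

Option 1 avoids creating any array for the fill value, while option 2 keeps the existing "everything becomes an array" contract and only changes which library creates that array.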
Minimal Complete Verifiable Example:
```python
import numpy as np
import pandas as pd
import xarray as xr
import matplotlib.pyplot as plt
import cupy as cp

# Load data
ds = xr.tutorial.open_dataset("rasm").load()

# Move data to GPU
ds.Tair.data = cp.asarray(ds.Tair.data)

ds_unweighted = ds.groupby("time.season").mean("time")

# Calculate the weights by grouping by 'time.season'.
month_length = ds.time.dt.days_in_month
weights = (
    month_length.groupby("time.season") / month_length.groupby("time.season").sum()
)

# Test that the sum of the weights for each season is 1.0
np.testing.assert_allclose(weights.groupby("time.season").sum().values, np.ones(4))

# Move weights to GPU
weights.data = cp.asarray(weights.data)

# Calculate the weighted average
ds_weighted = ds * weights
ds_weighted = ds_weighted.groupby("time.season")
ds_weighted = ds_weighted.sum(dim="time")
```
Traceback
```
Traceback (most recent call last):
  File "/home/jacob/miniconda3/envs/dask/lib/python3.7/runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "/home/jacob/miniconda3/envs/dask/lib/python3.7/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/home/jacob/.vscode-server/extensions/ms-python.python-2020.6.91350/pythonFiles/lib/python/debugpy/__main__.py", line 45, in <module>
    cli.main()
  File "/home/jacob/.vscode-server/extensions/ms-python.python-2020.6.91350/pythonFiles/lib/python/debugpy/../debugpy/server/cli.py", line 430, in main
    run()
  File "/home/jacob/.vscode-server/extensions/ms-python.python-2020.6.91350/pythonFiles/lib/python/debugpy/../debugpy/server/cli.py", line 267, in run_file
    runpy.run_path(options.target, run_name=compat.force_str("__main__"))
  File "/home/jacob/miniconda3/envs/dask/lib/python3.7/runpy.py", line 263, in run_path
    pkg_name=pkg_name, script_name=fname)
  File "/home/jacob/miniconda3/envs/dask/lib/python3.7/runpy.py", line 96, in _run_module_code
    mod_name, mod_spec, pkg_name, script_name)
  File "/home/jacob/miniconda3/envs/dask/lib/python3.7/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/home/jacob/Projects/pydata/xarray/test_seasonal_averages.py", line 32, in <module>
    ds_weighted = ds_weighted.sum(dim="time")
  File "/home/jacob/Projects/pydata/xarray/xarray/core/common.py", line 84, in wrapped_func
    func, dim, skipna=skipna, numeric_only=numeric_only, **kwargs
  File "/home/jacob/Projects/pydata/xarray/xarray/core/groupby.py", line 994, in reduce
    return self.map(reduce_dataset)
  File "/home/jacob/Projects/pydata/xarray/xarray/core/groupby.py", line 923, in map
    return self._combine(applied)
  File "/home/jacob/Projects/pydata/xarray/xarray/core/groupby.py", line 943, in _combine
    applied_example, applied = peek_at(applied)
  File "/home/jacob/Projects/pydata/xarray/xarray/core/utils.py", line 183, in peek_at
    peek = next(gen)
  File "/home/jacob/Projects/pydata/xarray/xarray/core/groupby.py", line 922, in <genexpr>
    applied = (func(ds, *args, **kwargs) for ds in self._iter_grouped())
  File "/home/jacob/Projects/pydata/xarray/xarray/core/groupby.py", line 990, in reduce_dataset
    return ds.reduce(func, dim, keep_attrs, **kwargs)
  File "/home/jacob/Projects/pydata/xarray/xarray/core/dataset.py", line 4313, in reduce
    **kwargs,
  File "/home/jacob/Projects/pydata/xarray/xarray/core/variable.py", line 1591, in reduce
    data = func(input_data, axis=axis, **kwargs)
  File "/home/jacob/Projects/pydata/xarray/xarray/core/duck_array_ops.py", line 324, in f
    return func(values, axis=axis, **kwargs)
  File "/home/jacob/Projects/pydata/xarray/xarray/core/nanops.py", line 111, in nansum
    a, mask = _replace_nan(a, 0)
  File "/home/jacob/Projects/pydata/xarray/xarray/core/nanops.py", line 21, in _replace_nan
    return where_method(val, mask, a), mask
  File "/home/jacob/Projects/pydata/xarray/xarray/core/duck_array_ops.py", line 274, in where_method
    return where(cond, data, other)
  File "/home/jacob/Projects/pydata/xarray/xarray/core/duck_array_ops.py", line 268, in where
    return _where(condition, *as_shared_dtype([x, y]))
  File "/home/jacob/Projects/pydata/xarray/xarray/core/duck_array_ops.py", line 56, in f
    return wrapped(*args, **kwargs)
  File "<__array_function__ internals>", line 6, in where
  File "cupy/core/core.pyx", line 1343, in cupy.core.core.ndarray.__array_function__
  File "/home/jacob/miniconda3/envs/dask/lib/python3.7/site-packages/cupy/sorting/search.py", line 211, in where
    return _where_ufunc(condition.astype('?'), x, y)
  File "cupy/core/_kernel.pyx", line 906, in cupy.core._kernel.ufunc.__call__
  File "cupy/core/_kernel.pyx", line 90, in cupy.core._kernel._preprocess_args
TypeError: Unsupported type <class 'numpy.ndarray'>
```
Anything else we need to know?:
Environment:

Output of `xr.show_versions()`

```
INSTALLED VERSIONS
commit: 52043bc
python: 3.7.6 | packaged by conda-forge | (default, Jun 1 2020, 18:57:50)
[GCC 7.5.0]
python-bits: 64
OS: Linux
OS-release: 5.3.0-62-generic
machine: x86_64
processor: x86_64
byteorder: little
LC_ALL: None
LANG: en_GB.UTF-8
LOCALE: en_GB.UTF-8
libhdf5: None
libnetcdf: None

xarray: 0.15.1
pandas: 0.25.3
numpy: 1.18.5
scipy: 1.5.0
netCDF4: None
pydap: None
h5netcdf: None
h5py: None
Nio: None
zarr: None
cftime: 1.2.0
nc_time_axis: None
PseudoNetCDF: None
rasterio: None
cfgrib: 0.9.8.3
iris: None
bottleneck: None
dask: 2.20.0
distributed: 2.20.0
matplotlib: 3.2.2
cartopy: 0.17.0
seaborn: 0.10.1
numbagg: None
pint: None
setuptools: 49.1.0.post20200704
pip: 20.1.1
conda: None
pytest: 5.4.3
IPython: 7.16.1
sphinx: None
```

Related to #4212