diff --git a/dask/array/utils.py b/dask/array/utils.py index 3e8ba2f219a..d5651eba498 100644 --- a/dask/array/utils.py +++ b/dask/array/utils.py @@ -29,15 +29,19 @@ def normalize_to_array(x): def normalize_meta(x, ndim, dtype=None): - if ndim > x.ndim: - meta = x[(Ellipsis, ) + tuple(None for _ in range(ndim - x.ndim))] - meta = meta[tuple(slice(0, 0, None) for _ in range(meta.ndim))] - elif ndim < x.ndim: - meta = np.sum(x, axis=tuple(d for d in range((x.ndim - ndim)))) - else: - meta = x - - if dtype: + try: + meta = x[tuple(slice(0, 0, None) for _ in range(x.ndim))] + if meta.ndim != ndim: + if ndim > x.ndim: + meta = x[(Ellipsis, ) + tuple(None for _ in range(ndim - meta.ndim))] + elif ndim == 0: + meta = meta.sum() + else: + meta = meta.reshape((0,) * ndim) + except Exception: + meta = np.empty((0,) * ndim, dtype=dtype or x.dtype) + + if dtype and meta.dtype != dtype: meta = meta.astype(dtype) return meta