Skip to content

Commit

Permalink
cast numpy scalars to arrays in as_compatible_data (#9403)
Browse files Browse the repository at this point in the history
* also call `np.asarray` on numpy scalars

* check that numpy scalars are properly casted to arrays

* don't allow `numpy.ndarray` subclasses

* comment on the purpose of the explicit isinstance and `np.asarray`
  • Loading branch information
keewis authored and hollymandel committed Sep 23, 2024
1 parent 54af952 commit 519d2c6
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 3 deletions.
6 changes: 4 additions & 2 deletions xarray/core/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,12 +320,14 @@ def convert_non_numpy_type(data):
else:
data = np.asarray(data)

if not isinstance(data, np.ndarray) and (
# immediately return array-like types except `numpy.ndarray` subclasses and `numpy` scalars
if not isinstance(data, np.ndarray | np.generic) and (
hasattr(data, "__array_function__") or hasattr(data, "__array_namespace__")
):
return cast("T_DuckArray", data)

# validate whether the data is valid data types.
# validate whether the data is valid data types. Also, explicitly cast `numpy`
# subclasses and `numpy` scalars to `numpy.ndarray`
data = np.asarray(data)

if data.dtype.kind in "OMm":
Expand Down
7 changes: 6 additions & 1 deletion xarray/tests/test_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -2585,7 +2585,12 @@ def test_unchanged_types(self):
assert source_ndarray(x) is source_ndarray(as_compatible_data(x))

def test_converted_types(self):
for input_array in [[[0, 1, 2]], pd.DataFrame([[0, 1, 2]])]:
for input_array in [
[[0, 1, 2]],
pd.DataFrame([[0, 1, 2]]),
np.float64(1.4),
np.str_("abc"),
]:
actual = as_compatible_data(input_array)
assert_array_equal(np.asarray(input_array), actual)
assert np.ndarray is type(actual)
Expand Down

0 comments on commit 519d2c6

Please sign in to comment.