diff --git a/xarray/core/datatree.py b/xarray/core/datatree.py index 3a3fb19daa4..3f423eaa8a6 100644 --- a/xarray/core/datatree.py +++ b/xarray/core/datatree.py @@ -91,17 +91,13 @@ def _collect_data_and_coord_variables( return data_variables, coord_variables -def _coerce_to_dataset(data: Dataset | DataArray | None) -> Dataset: - if isinstance(data, DataArray): - ds = data.to_dataset() - elif isinstance(data, Dataset): +def _to_new_dataset(data: Dataset | None) -> Dataset: + if isinstance(data, Dataset): ds = data.copy(deep=False) elif data is None: ds = Dataset() else: - raise TypeError( - f"data object is not an xarray Dataset, DataArray, or None, it is of type {type(data)}" - ) + raise TypeError(f"data object is not an xarray.Dataset, dict, or None: {data}") return ds @@ -422,7 +418,8 @@ class DataTree( def __init__( self, - data: Dataset | DataArray | None = None, + data: Dataset | None = None, + parent: DataTree | None = None, children: Mapping[str, DataTree] | None = None, name: str | None = None, ): @@ -435,9 +432,8 @@ def __init__( Parameters ---------- - data : Dataset, DataArray, or None, optional - Data to store under the .ds attribute of this node. DataArrays will - be promoted to Datasets. Default is None. + data : Dataset, optional + Data to store under the .ds attribute of this node. children : Mapping[str, DataTree], optional Any child nodes of this node. Default is None. name : str, optional @@ -455,7 +451,7 @@ def __init__( children = {} super().__init__(name=name) - self._set_node_data(_coerce_to_dataset(data)) + self._set_node_data(_to_new_dataset(data)) # shallow copy to avoid modifying arguments in-place (see GH issue #9196) self.children = {name: child.copy() for name, child in children.items()} @@ -540,8 +536,8 @@ def ds(self) -> DatasetView: return self._to_dataset_view(rebuild_dims=True) @ds.setter - def ds(self, data: Dataset | DataArray | None = None) -> None: - ds = _coerce_to_dataset(data) + def ds(self, data: Dataset | None = None) -> None: + ds = _to_new_dataset(data) self._replace_node(ds) def to_dataset(self, inherited: bool = True) -> Dataset: @@ -1050,7 +1046,7 @@ def drop_nodes( @classmethod def from_dict( cls, - d: Mapping[str, Dataset | DataArray | DataTree | None], + d: Mapping[str, Dataset | DataTree | None], name: str | None = None, ) -> DataTree: """ @@ -1059,10 +1055,10 @@ def from_dict( Parameters ---------- d : dict-like - A mapping from path names to xarray.Dataset, xarray.DataArray, or DataTree objects. + A mapping from path names to xarray.Dataset or DataTree objects. - Path names are to be given as unix-like path. If path names containing more than one part are given, new - tree nodes will be constructed as necessary. + Path names are to be given as unix-like path. If path names containing more than one + part are given, new tree nodes will be constructed as necessary. To assign data to the root node of the tree use "/" as the path. name : Hashable | None, optional @@ -1083,8 +1079,12 @@ def from_dict( if isinstance(root_data, DataTree): obj = root_data.copy() obj.orphan() - else: + elif root_data is None or isinstance(root_data, Dataset): obj = cls(name=name, data=root_data, children=None) + else: + raise TypeError( + f'root node data (at "/") must be a Dataset or DataTree, got {type(root_data)}' + ) def depth(item) -> int: pathstr, _ = item diff --git a/xarray/tests/test_datatree.py b/xarray/tests/test_datatree.py index 1a840dbb9e4..eca051bb7fc 100644 --- a/xarray/tests/test_datatree.py +++ b/xarray/tests/test_datatree.py @@ -33,6 +33,14 @@ def test_bad_names(self): with pytest.raises(ValueError): DataTree(name="folder/data") + def test_data_arg(self): + ds = xr.Dataset({"foo": 42}) + tree: DataTree = DataTree(data=ds) + assert_identical(tree.to_dataset(), ds) + + with pytest.raises(TypeError): + DataTree(data=xr.DataArray(42, name="foo")) # type: ignore + class TestFamilyTree: def test_dont_modify_children_inplace(self): @@ -613,6 +621,11 @@ def test_insertion_order(self): # despite 'Bart' coming before 'Lisa' when sorted alphabetically assert list(reversed["Homer"].children.keys()) == ["Lisa", "Bart"] + def test_array_values(self): + data = {"foo": xr.DataArray(1, name="bar")} + with pytest.raises(TypeError): + DataTree.from_dict(data) # type: ignore + class TestDatasetView: def test_view_contents(self):