Skip to content

Commit

Permalink
Disallow passing a DataArray as data into the DataTree constructor (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
shoyer committed Sep 7, 2024
1 parent a74a605 commit 12c690f
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 19 deletions.
38 changes: 19 additions & 19 deletions xarray/core/datatree.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,17 +91,13 @@ def _collect_data_and_coord_variables(
return data_variables, coord_variables


def _coerce_to_dataset(data: Dataset | DataArray | None) -> Dataset:
if isinstance(data, DataArray):
ds = data.to_dataset()
elif isinstance(data, Dataset):
def _to_new_dataset(data: Dataset | None) -> Dataset:
if isinstance(data, Dataset):
ds = data.copy(deep=False)
elif data is None:
ds = Dataset()
else:
raise TypeError(
f"data object is not an xarray Dataset, DataArray, or None, it is of type {type(data)}"
)
raise TypeError(f"data object is not an xarray.Dataset, dict, or None: {data}")
return ds


Expand Down Expand Up @@ -422,7 +418,8 @@ class DataTree(

def __init__(
self,
data: Dataset | DataArray | None = None,
data: Dataset | None = None,
parent: DataTree | None = None,
children: Mapping[str, DataTree] | None = None,
name: str | None = None,
):
Expand All @@ -435,9 +432,8 @@ def __init__(
Parameters
----------
data : Dataset, DataArray, or None, optional
Data to store under the .ds attribute of this node. DataArrays will
be promoted to Datasets. Default is None.
data : Dataset, optional
Data to store under the .ds attribute of this node.
children : Mapping[str, DataTree], optional
Any child nodes of this node. Default is None.
name : str, optional
Expand All @@ -455,7 +451,7 @@ def __init__(
children = {}

super().__init__(name=name)
self._set_node_data(_coerce_to_dataset(data))
self._set_node_data(_to_new_dataset(data))

# shallow copy to avoid modifying arguments in-place (see GH issue #9196)
self.children = {name: child.copy() for name, child in children.items()}
Expand Down Expand Up @@ -540,8 +536,8 @@ def ds(self) -> DatasetView:
return self._to_dataset_view(rebuild_dims=True)

@ds.setter
def ds(self, data: Dataset | DataArray | None = None) -> None:
ds = _coerce_to_dataset(data)
def ds(self, data: Dataset | None = None) -> None:
ds = _to_new_dataset(data)
self._replace_node(ds)

def to_dataset(self, inherited: bool = True) -> Dataset:
Expand Down Expand Up @@ -1050,7 +1046,7 @@ def drop_nodes(
@classmethod
def from_dict(
cls,
d: Mapping[str, Dataset | DataArray | DataTree | None],
d: Mapping[str, Dataset | DataTree | None],
name: str | None = None,
) -> DataTree:
"""
Expand All @@ -1059,10 +1055,10 @@ def from_dict(
Parameters
----------
d : dict-like
A mapping from path names to xarray.Dataset, xarray.DataArray, or DataTree objects.
A mapping from path names to xarray.Dataset or DataTree objects.
Path names are to be given as unix-like path. If path names containing more than one part are given, new
tree nodes will be constructed as necessary.
Path names are to be given as unix-like path. If path names containing more than one
part are given, new tree nodes will be constructed as necessary.
To assign data to the root node of the tree use "/" as the path.
name : Hashable | None, optional
Expand All @@ -1083,8 +1079,12 @@ def from_dict(
if isinstance(root_data, DataTree):
obj = root_data.copy()
obj.orphan()
else:
elif root_data is None or isinstance(root_data, Dataset):
obj = cls(name=name, data=root_data, children=None)
else:
raise TypeError(
f'root node data (at "/") must be a Dataset or DataTree, got {type(root_data)}'
)

def depth(item) -> int:
pathstr, _ = item
Expand Down
13 changes: 13 additions & 0 deletions xarray/tests/test_datatree.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,14 @@ def test_bad_names(self):
with pytest.raises(ValueError):
DataTree(name="folder/data")

def test_data_arg(self):
ds = xr.Dataset({"foo": 42})
tree: DataTree = DataTree(data=ds)
assert_identical(tree.to_dataset(), ds)

with pytest.raises(TypeError):
DataTree(data=xr.DataArray(42, name="foo")) # type: ignore


class TestFamilyTree:
def test_dont_modify_children_inplace(self):
Expand Down Expand Up @@ -613,6 +621,11 @@ def test_insertion_order(self):
# despite 'Bart' coming before 'Lisa' when sorted alphabetically
assert list(reversed["Homer"].children.keys()) == ["Lisa", "Bart"]

def test_array_values(self):
data = {"foo": xr.DataArray(1, name="bar")}
with pytest.raises(TypeError):
DataTree.from_dict(data) # type: ignore


class TestDatasetView:
def test_view_contents(self):
Expand Down

0 comments on commit 12c690f

Please sign in to comment.