Skip to content

Commit

Permalink
open_datatree performance improvement on NetCDF, H5, and Zarr files (#…
Browse files Browse the repository at this point in the history
…9014)

* open_datatree performance improvement on NetCDF files

* fixing issue with forward slashes

* fixing issue with pytest

* open datatree in zarr format improvement

* fixing incompatibility in returned object

* passing group parameter to opendatatree method and reducing duplicated code

* passing group parameter to opendatatree method - NetCDF

* Update xarray/backends/netCDF4_.py

renaming variables

Co-authored-by: Tom Nicholas <[email protected]>

* renaming variables

* renaming variables

* renaming group_store variable

* removing _open_datatree_netcdf function not used anymore in open_datatree implementations

* improving performance of open_datatree method

* renaming 'i' variable within list comprehension in open_store method for zarr datatree

* using the default generator instead of loading zarr groups in memory

* fixing issue with group path to avoid using group[1:] notation. Adding group variable typing hints (str | Iterable[str] | callable) under the open_datatree for h5 files. Finally, separating positional from keyword args

* fixing issue with group path to avoid using group[1:] notation and adding group variable typing hints (str | Iterable[str] | callable) under the open_datatree method for netCDF files

* fixing issue with group path to avoid using group[1:] notation and adding group variable typing hints (str | Iterable[str] | callable) under the open_datatree method for zarr files

* adding 'mode' parameter to open_datatree method

* adding 'mode' parameter to H5NetCDFStore.open method

* adding new entry related to open_datatree performance improvement

* adding new entry related to open_datatree performance improvement

* Getting rid of unnecessary parameters for 'open_datatree' method for netCDF4 and Hdf5 backends

---------

Co-authored-by: Tom Nicholas <[email protected]>
Co-authored-by: Kai Mühlbauer <[email protected]>
  • Loading branch information
3 people authored and andersy005 committed Jun 14, 2024
1 parent d5d2e3b commit 6c34f6a
Show file tree
Hide file tree
Showing 5 changed files with 287 additions and 128 deletions.
2 changes: 2 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ Performance
By `Deepak Cherian <https://github.com/dcherian>`_.
- Small optimizations to help reduce indexing speed of datasets (:pull:`9002`).
By `Mark Harfouche <https://github.com/hmaarrfk>`_.
- Performance improvement in `open_datatree` method for Zarr, netCDF4 and h5netcdf backends (:issue:`8994`, :pull:`9014`).
By `Alfonso Ladino <https://github.com/aladinor>`_.


Breaking changes
Expand Down
30 changes: 0 additions & 30 deletions xarray/backends/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,6 @@
if TYPE_CHECKING:
from io import BufferedIOBase

from h5netcdf.legacyapi import Dataset as ncDatasetLegacyH5
from netCDF4 import Dataset as ncDataset

from xarray.core.dataset import Dataset
from xarray.core.datatree import DataTree
from xarray.core.types import NestedSequence
Expand Down Expand Up @@ -131,33 +128,6 @@ def _decode_variable_name(name):
return name


def _open_datatree_netcdf(
ncDataset: ncDataset | ncDatasetLegacyH5,
filename_or_obj: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore,
**kwargs,
) -> DataTree:
from xarray.backends.api import open_dataset
from xarray.core.datatree import DataTree
from xarray.core.treenode import NodePath

ds = open_dataset(filename_or_obj, **kwargs)
tree_root = DataTree.from_dict({"/": ds})
with ncDataset(filename_or_obj, mode="r") as ncds:
for path in _iter_nc_groups(ncds):
subgroup_ds = open_dataset(filename_or_obj, group=path, **kwargs)

# TODO refactor to use __setitem__ once creation of new nodes by assigning Dataset works again
node_name = NodePath(path).name
new_node: DataTree = DataTree(name=node_name, data=subgroup_ds)
tree_root._set_item(
path,
new_node,
allow_overwrite=False,
new_nodes_along_path=True,
)
return tree_root


def _iter_nc_groups(root, parent="/"):
from xarray.core.treenode import NodePath

Expand Down
54 changes: 50 additions & 4 deletions xarray/backends/h5netcdf_.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,14 @@
import functools
import io
import os
from collections.abc import Iterable
from collections.abc import Callable, Iterable
from typing import TYPE_CHECKING, Any

from xarray.backends.common import (
BACKEND_ENTRYPOINTS,
BackendEntrypoint,
WritableCFDataStore,
_normalize_path,
_open_datatree_netcdf,
find_root_and_group,
)
from xarray.backends.file_manager import CachingFileManager, DummyFileManager
Expand Down Expand Up @@ -441,11 +440,58 @@ def open_dataset( # type: ignore[override] # allow LSP violation, not supporti
def open_datatree(
self,
filename_or_obj: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore,
*,
mask_and_scale=True,
decode_times=True,
concat_characters=True,
decode_coords=True,
drop_variables: str | Iterable[str] | None = None,
use_cftime=None,
decode_timedelta=None,
group: str | Iterable[str] | Callable | None = None,
**kwargs,
) -> DataTree:
from h5netcdf.legacyapi import Dataset as ncDataset
from xarray.backends.api import open_dataset
from xarray.backends.common import _iter_nc_groups
from xarray.core.datatree import DataTree
from xarray.core.treenode import NodePath
from xarray.core.utils import close_on_error

return _open_datatree_netcdf(ncDataset, filename_or_obj, **kwargs)
filename_or_obj = _normalize_path(filename_or_obj)
store = H5NetCDFStore.open(
filename_or_obj,
group=group,
)
if group:
parent = NodePath("/") / NodePath(group)
else:
parent = NodePath("/")

manager = store._manager
ds = open_dataset(store, **kwargs)
tree_root = DataTree.from_dict({str(parent): ds})
for path_group in _iter_nc_groups(store.ds, parent=parent):
group_store = H5NetCDFStore(manager, group=path_group, **kwargs)
store_entrypoint = StoreBackendEntrypoint()
with close_on_error(group_store):
ds = store_entrypoint.open_dataset(
group_store,
mask_and_scale=mask_and_scale,
decode_times=decode_times,
concat_characters=concat_characters,
decode_coords=decode_coords,
drop_variables=drop_variables,
use_cftime=use_cftime,
decode_timedelta=decode_timedelta,
)
new_node: DataTree = DataTree(name=NodePath(path_group).name, data=ds)
tree_root._set_item(
path_group,
new_node,
allow_overwrite=False,
new_nodes_along_path=True,
)
return tree_root


BACKEND_ENTRYPOINTS["h5netcdf"] = ("h5netcdf", H5netcdfBackendEntrypoint)
53 changes: 49 additions & 4 deletions xarray/backends/netCDF4_.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import functools
import operator
import os
from collections.abc import Iterable
from collections.abc import Callable, Iterable
from contextlib import suppress
from typing import TYPE_CHECKING, Any

Expand All @@ -16,7 +16,6 @@
BackendEntrypoint,
WritableCFDataStore,
_normalize_path,
_open_datatree_netcdf,
find_root_and_group,
robust_getitem,
)
Expand Down Expand Up @@ -682,11 +681,57 @@ def open_dataset( # type: ignore[override] # allow LSP violation, not supporti
def open_datatree(
self,
filename_or_obj: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore,
*,
mask_and_scale=True,
decode_times=True,
concat_characters=True,
decode_coords=True,
drop_variables: str | Iterable[str] | None = None,
use_cftime=None,
decode_timedelta=None,
group: str | Iterable[str] | Callable | None = None,
**kwargs,
) -> DataTree:
from netCDF4 import Dataset as ncDataset
from xarray.backends.api import open_dataset
from xarray.backends.common import _iter_nc_groups
from xarray.core.datatree import DataTree
from xarray.core.treenode import NodePath

return _open_datatree_netcdf(ncDataset, filename_or_obj, **kwargs)
filename_or_obj = _normalize_path(filename_or_obj)
store = NetCDF4DataStore.open(
filename_or_obj,
group=group,
)
if group:
parent = NodePath("/") / NodePath(group)
else:
parent = NodePath("/")

manager = store._manager
ds = open_dataset(store, **kwargs)
tree_root = DataTree.from_dict({str(parent): ds})
for path_group in _iter_nc_groups(store.ds, parent=parent):
group_store = NetCDF4DataStore(manager, group=path_group, **kwargs)
store_entrypoint = StoreBackendEntrypoint()
with close_on_error(group_store):
ds = store_entrypoint.open_dataset(
group_store,
mask_and_scale=mask_and_scale,
decode_times=decode_times,
concat_characters=concat_characters,
decode_coords=decode_coords,
drop_variables=drop_variables,
use_cftime=use_cftime,
decode_timedelta=decode_timedelta,
)
new_node: DataTree = DataTree(name=NodePath(path_group).name, data=ds)
tree_root._set_item(
path_group,
new_node,
allow_overwrite=False,
new_nodes_along_path=True,
)
return tree_root


BACKEND_ENTRYPOINTS["netcdf4"] = ("netCDF4", NetCDF4BackendEntrypoint)
Loading

0 comments on commit 6c34f6a

Please sign in to comment.