Skip to content

Commit

Permalink
fix(datasets): Take mode argument into account when saving dataset (#805)
Browse files Browse the repository at this point in the history

* Take mode argument into account when saving dataset
* Move mode to default save args
* Revert changes to add mode to save args and add as fs arg default instead
* Separate fs save and load args again
* Add tests for coverage
* Use fs_args to pass mode for all pandas based datasets
* Make other datasets use fs_args for handling mode as well
* Refactor and make all datasets consistent
* Update release notes

---------

Signed-off-by: Merel Theisen <[email protected]>
  • Loading branch information
merelcht committed Aug 22, 2024
1 parent f9086a3 commit c67fa9e
Show file tree
Hide file tree
Showing 44 changed files with 441 additions and 380 deletions.
2 changes: 2 additions & 0 deletions kedro-datasets/RELEASE.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
| `plotly.HTMLDataset` | A dataset for saving a `plotly` figure as HTML | `kedro_datasets.plotly` |

## Bug fixes and other changes
* Refactored all datasets to set `fs_args` defaults in the same way as `load_args` and `save_args`, instead of hardcoding values in the save methods.

## Breaking Changes
## Community contributions
Many thanks to the following Kedroids for contributing PRs to this release:
Expand Down
27 changes: 15 additions & 12 deletions kedro-datasets/kedro_datasets/biosequence/biosequence_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@ class BioSequenceDataset(AbstractDataset[list, list]):

DEFAULT_LOAD_ARGS: dict[str, Any] = {}
DEFAULT_SAVE_ARGS: dict[str, Any] = {}
DEFAULT_FS_ARGS: dict[str, Any] = {
"open_args_save": {"mode": "w"},
"open_args_load": {"mode": "r"},
}

def __init__( # noqa: PLR0913
self,
Expand Down Expand Up @@ -96,18 +100,17 @@ def __init__( # noqa: PLR0913

self._fs = fsspec.filesystem(self._protocol, **_credentials, **_fs_args)

# Handle default load and save arguments
self._load_args = deepcopy(self.DEFAULT_LOAD_ARGS)
if load_args is not None:
self._load_args.update(load_args)
self._save_args = deepcopy(self.DEFAULT_SAVE_ARGS)
if save_args is not None:
self._save_args.update(save_args)

_fs_open_args_load.setdefault("mode", "r")
_fs_open_args_save.setdefault("mode", "w")
self._fs_open_args_load = _fs_open_args_load
self._fs_open_args_save = _fs_open_args_save
# Handle default load, save and fs arguments
self._load_args = {**self.DEFAULT_LOAD_ARGS, **(load_args or {})}
self._save_args = {**self.DEFAULT_SAVE_ARGS, **(save_args or {})}
self._fs_open_args_load = {
**self.DEFAULT_FS_ARGS.get("open_args_load", {}),
**(_fs_open_args_load or {}),
}
self._fs_open_args_save = {
**self.DEFAULT_FS_ARGS.get("open_args_save", {}),
**(_fs_open_args_save or {}),
}

self.metadata = metadata

Expand Down
8 changes: 2 additions & 6 deletions kedro-datasets/kedro_datasets/dask/csv_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,12 +84,8 @@ def __init__( # noqa: PLR0913
self.metadata = metadata

# Handle default load and save arguments
self._load_args = deepcopy(self.DEFAULT_LOAD_ARGS)
if load_args is not None:
self._load_args.update(load_args)
self._save_args = deepcopy(self.DEFAULT_SAVE_ARGS)
if save_args is not None:
self._save_args.update(save_args)
self._load_args = {**self.DEFAULT_LOAD_ARGS, **(load_args or {})}
self._save_args = {**self.DEFAULT_SAVE_ARGS, **(save_args or {})}

@property
def fs_args(self) -> dict[str, Any]:
Expand Down
8 changes: 2 additions & 6 deletions kedro-datasets/kedro_datasets/dask/parquet_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,12 +114,8 @@ def __init__( # noqa: PLR0913
self.metadata = metadata

# Handle default load and save arguments
self._load_args = deepcopy(self.DEFAULT_LOAD_ARGS)
if load_args is not None:
self._load_args.update(load_args)
self._save_args = deepcopy(self.DEFAULT_SAVE_ARGS)
if save_args is not None:
self._save_args.update(save_args)
self._load_args = {**self.DEFAULT_LOAD_ARGS, **(load_args or {})}
self._save_args = {**self.DEFAULT_SAVE_ARGS, **(save_args or {})}

@property
def fs_args(self) -> dict[str, Any]:
Expand Down
27 changes: 15 additions & 12 deletions kedro-datasets/kedro_datasets/email/message_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,10 @@ class EmailMessageDataset(AbstractVersionedDataset[Message, Message]):

DEFAULT_LOAD_ARGS: dict[str, Any] = {}
DEFAULT_SAVE_ARGS: dict[str, Any] = {}
DEFAULT_FS_ARGS: dict[str, Any] = {
"open_args_save": {"mode": "w"},
"open_args_load": {"mode": "r"},
}

def __init__( # noqa: PLR0913
self,
Expand Down Expand Up @@ -129,22 +133,21 @@ def __init__( # noqa: PLR0913
glob_function=self._fs.glob,
)

# Handle default load arguments
self._load_args = deepcopy(self.DEFAULT_LOAD_ARGS)
if load_args is not None:
self._load_args.update(load_args)
# Handle default load, save and fs arguments
self._load_args = {**self.DEFAULT_LOAD_ARGS, **(load_args or {})}
self._parser_args = self._load_args.pop("parser", {"policy": default})

# Handle default save arguments
self._save_args = deepcopy(self.DEFAULT_SAVE_ARGS)
if save_args is not None:
self._save_args.update(save_args)
self._save_args = {**self.DEFAULT_SAVE_ARGS, **(save_args or {})}
self._generator_args = self._save_args.pop("generator", {})

_fs_open_args_load.setdefault("mode", "r")
_fs_open_args_save.setdefault("mode", "w")
self._fs_open_args_load = _fs_open_args_load
self._fs_open_args_save = _fs_open_args_save
self._fs_open_args_load = {
**self.DEFAULT_FS_ARGS.get("open_args_load", {}),
**(_fs_open_args_load or {}),
}
self._fs_open_args_save = {
**self.DEFAULT_FS_ARGS.get("open_args_save", {}),
**(_fs_open_args_save or {}),
}

def _describe(self) -> dict[str, Any]:
return {
Expand Down
4 changes: 1 addition & 3 deletions kedro-datasets/kedro_datasets/holoviews/holoviews_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,9 +100,7 @@ def __init__( # noqa: PLR0913
self._fs_open_args_save = _fs_open_args_save

# Handle default save arguments
self._save_args = deepcopy(self.DEFAULT_SAVE_ARGS)
if save_args is not None:
self._save_args.update(save_args)
self._save_args = {**self.DEFAULT_SAVE_ARGS, **(save_args or {})}

def _describe(self) -> dict[str, Any]:
return {
Expand Down
24 changes: 13 additions & 11 deletions kedro-datasets/kedro_datasets/json/json_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ class JSONDataset(AbstractVersionedDataset[Any, Any]):
"""

DEFAULT_SAVE_ARGS: dict[str, Any] = {"indent": 2}
DEFAULT_FS_ARGS: dict[str, Any] = {"open_args_save": {"mode": "w"}}

def __init__( # noqa: PLR0913
self,
Expand Down Expand Up @@ -89,16 +90,15 @@ def __init__( # noqa: PLR0913
`open_args_load` and `open_args_save`.
Here you can find all available arguments for `open`:
https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.spec.AbstractFileSystem.open
All defaults are preserved, except `mode`, which is set to `r` when loading
and to `w` when saving.
All defaults are preserved, except `mode`, which is set to `w` when saving.
metadata: Any arbitrary metadata.
This is ignored by Kedro, but may be consumed by users or external plugins.
"""
_fs_args = deepcopy(fs_args) or {}
_fs_open_args_load = _fs_args.pop("open_args_load", {})
_fs_open_args_save = _fs_args.pop("open_args_save", {})
_credentials = deepcopy(credentials) or {}

_credentials = deepcopy(credentials) or {}
protocol, path = get_protocol_and_path(filepath, version)

self._protocol = protocol
Expand All @@ -115,14 +115,16 @@ def __init__( # noqa: PLR0913
glob_function=self._fs.glob,
)

# Handle default save arguments
self._save_args = deepcopy(self.DEFAULT_SAVE_ARGS)
if save_args is not None:
self._save_args.update(save_args)

_fs_open_args_save.setdefault("mode", "w")
self._fs_open_args_load = _fs_open_args_load
self._fs_open_args_save = _fs_open_args_save
# Handle default save and fs arguments
self._save_args = {**self.DEFAULT_SAVE_ARGS, **(save_args or {})}
self._fs_open_args_load = {
**self.DEFAULT_FS_ARGS.get("open_args_load", {}),
**(_fs_open_args_load or {}),
}
self._fs_open_args_save = {
**self.DEFAULT_FS_ARGS.get("open_args_save", {}),
**(_fs_open_args_save or {}),
}

def _describe(self) -> dict[str, Any]:
return {
Expand Down
24 changes: 13 additions & 11 deletions kedro-datasets/kedro_datasets/matlab/matlab_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ class MatlabDataset(AbstractVersionedDataset[np.ndarray, np.ndarray]):
"""

DEFAULT_SAVE_ARGS: dict[str, Any] = {"indent": 2}
DEFAULT_FS_ARGS: dict[str, Any] = {"open_args_save": {"mode": "wb"}}

def __init__( # noqa: PLR0913
self,
Expand Down Expand Up @@ -83,8 +84,7 @@ def __init__( # noqa = PLR0913
`open_args_load` and `open_args_save`.
Here you can find all available arguments for `open`:
https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.spec.AbstractFileSystem.open
All defaults are preserved, except `mode`, which is set to `r` when loading
and to `w` when saving.
All defaults are preserved, except `mode`, which is set to `wb` when saving.
metadata: Any arbitrary metadata.
This is ignored by Kedro, but may be consumed by users or external plugins.
"""
Expand All @@ -106,14 +106,16 @@ def __init__( # noqa = PLR0913
exists_function=self._fs.exists,
glob_function=self._fs.glob,
)
# Handle default save arguments
self._save_args = deepcopy(self.DEFAULT_SAVE_ARGS)
if save_args is not None:
self._save_args.update(save_args)

_fs_open_args_save.setdefault("mode", "w")
self._fs_open_args_load = _fs_open_args_load
self._fs_open_args_save = _fs_open_args_save
# Handle default save and fs arguments
self._save_args = {**self.DEFAULT_SAVE_ARGS, **(save_args or {})}
self._fs_open_args_load = {
**self.DEFAULT_FS_ARGS.get("open_args_load", {}),
**(_fs_open_args_load or {}),
}
self._fs_open_args_save = {
**self.DEFAULT_FS_ARGS.get("open_args_save", {}),
**(_fs_open_args_save or {}),
}

def _describe(self) -> dict[str, Any]:
return {
Expand All @@ -134,7 +136,7 @@ def _load(self) -> np.ndarray:

def _save(self, data: np.ndarray) -> None:
save_path = get_filepath_str(self._get_save_path(), self._protocol)
with self._fs.open(save_path, mode="wb") as f:
with self._fs.open(save_path, **self._fs_open_args_save) as f:
io.savemat(f, {"data": data})
self._invalidate_cache()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -178,9 +178,7 @@ def __init__( # noqa: PLR0913
self._fs_open_args_save = _fs_open_args_save

# Handle default save arguments
self._save_args = deepcopy(self.DEFAULT_SAVE_ARGS)
if save_args is not None:
self._save_args.update(save_args)
self._save_args = {**self.DEFAULT_SAVE_ARGS, **(save_args or {})}

if overwrite and version is not None:
warn(
Expand Down
32 changes: 18 additions & 14 deletions kedro-datasets/kedro_datasets/networkx/gml_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,10 @@ class GMLDataset(AbstractVersionedDataset[networkx.Graph, networkx.Graph]):

DEFAULT_LOAD_ARGS: dict[str, Any] = {}
DEFAULT_SAVE_ARGS: dict[str, Any] = {}
DEFAULT_FS_ARGS: dict[str, Any] = {
"open_args_save": {"mode": "wb"},
"open_args_load": {"mode": "rb"},
}

def __init__( # noqa: PLR0913
self,
Expand Down Expand Up @@ -74,9 +78,9 @@ def __init__( # noqa: PLR0913
`open_args_load` and `open_args_save`.
Here you can find all available arguments for `open`:
https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.spec.AbstractFileSystem.open
All defaults are preserved, except `mode`, which is set to `r` when loading
and to `w` when saving.
metadata: Any Any arbitrary metadata.
All defaults are preserved, except `mode`, which is set to `rb` when loading
and to `wb` when saving.
metadata: Any arbitrary metadata.
This is ignored by Kedro, but may be consumed by users or external plugins.
"""
_fs_args = deepcopy(fs_args) or {}
Expand All @@ -100,17 +104,17 @@ def __init__( # noqa: PLR0913
glob_function=self._fs.glob,
)

# Handle default load and save arguments
self._load_args = deepcopy(self.DEFAULT_LOAD_ARGS)
if load_args is not None:
self._load_args.update(load_args)
self._save_args = deepcopy(self.DEFAULT_SAVE_ARGS)
if save_args is not None:
self._save_args.update(save_args)
_fs_open_args_load.setdefault("mode", "rb")
_fs_open_args_save.setdefault("mode", "wb")
self._fs_open_args_load = _fs_open_args_load
self._fs_open_args_save = _fs_open_args_save
# Handle default load, save and fs arguments
self._load_args = {**self.DEFAULT_LOAD_ARGS, **(load_args or {})}
self._save_args = {**self.DEFAULT_SAVE_ARGS, **(save_args or {})}
self._fs_open_args_load = {
**self.DEFAULT_FS_ARGS.get("open_args_load", {}),
**(_fs_open_args_load or {}),
}
self._fs_open_args_save = {
**self.DEFAULT_FS_ARGS.get("open_args_save", {}),
**(_fs_open_args_save or {}),
}

def _load(self) -> networkx.Graph:
load_path = get_filepath_str(self._get_load_path(), self._protocol)
Expand Down
32 changes: 18 additions & 14 deletions kedro-datasets/kedro_datasets/networkx/graphml_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@ class GraphMLDataset(AbstractVersionedDataset[networkx.Graph, networkx.Graph]):

DEFAULT_LOAD_ARGS: dict[str, Any] = {}
DEFAULT_SAVE_ARGS: dict[str, Any] = {}
DEFAULT_FS_ARGS: dict[str, Any] = {
"open_args_save": {"mode": "wb"},
"open_args_load": {"mode": "rb"},
}

def __init__( # noqa: PLR0913
self,
Expand Down Expand Up @@ -73,9 +77,9 @@ def __init__( # noqa: PLR0913
`open_args_load` and `open_args_save`.
Here you can find all available arguments for `open`:
https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.spec.AbstractFileSystem.open
All defaults are preserved, except `mode`, which is set to `r` when loading
and to `w` when saving.
metadata: Any arbitrary Any arbitrary metadata.
All defaults are preserved, except `mode`, which is set to `rb` when loading
and to `wb` when saving.
metadata: Any arbitrary metadata.
This is ignored by Kedro, but may be consumed by users or external plugins.
"""
_fs_args = deepcopy(fs_args) or {}
Expand All @@ -99,17 +103,17 @@ def __init__( # noqa: PLR0913
glob_function=self._fs.glob,
)

# Handle default load and save arguments
self._load_args = deepcopy(self.DEFAULT_LOAD_ARGS)
if load_args is not None:
self._load_args.update(load_args)
self._save_args = deepcopy(self.DEFAULT_SAVE_ARGS)
if save_args is not None:
self._save_args.update(save_args)
_fs_open_args_load.setdefault("mode", "rb")
_fs_open_args_save.setdefault("mode", "wb")
self._fs_open_args_load = _fs_open_args_load
self._fs_open_args_save = _fs_open_args_save
# Handle default load, save and fs arguments
self._load_args = {**self.DEFAULT_LOAD_ARGS, **(load_args or {})}
self._save_args = {**self.DEFAULT_SAVE_ARGS, **(save_args or {})}
self._fs_open_args_load = {
**self.DEFAULT_FS_ARGS.get("open_args_load", {}),
**(_fs_open_args_load or {}),
}
self._fs_open_args_save = {
**self.DEFAULT_FS_ARGS.get("open_args_save", {}),
**(_fs_open_args_save or {}),
}

def _load(self) -> networkx.Graph:
load_path = get_filepath_str(self._get_load_path(), self._protocol)
Expand Down
Loading

0 comments on commit c67fa9e

Please sign in to comment.