Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(datasets): Take mode argument into account when saving dataset #805

Merged
merged 26 commits into from
Aug 22, 2024
Merged
Show file tree
Hide file tree
Changes from 24 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
3fba245
Take mode argument into account when saving dataset
merelcht Aug 7, 2024
ed881fe
Fix test
merelcht Aug 7, 2024
49017f6
Merge branch 'main' into fix/handle-mode-arg-correctly
merelcht Aug 12, 2024
28d173f
Move mode to default save args
merelcht Aug 13, 2024
e1c59e8
Merge branch 'fix/handle-mode-arg-correctly' of https://github.com/ke…
merelcht Aug 13, 2024
e71b810
Revert changes to add mode to save args and add as fs arg default ins…
merelcht Aug 14, 2024
eef2651
Merge branch 'main' into fix/handle-mode-arg-correctly
merelcht Aug 14, 2024
079b0a3
Fix lint
merelcht Aug 14, 2024
1026380
Separate fs save and load args again
merelcht Aug 14, 2024
0725a96
Add tests for coverage
merelcht Aug 14, 2024
c7ba314
Try simplify init
merelcht Aug 14, 2024
9edd0f9
Remove writing to buffer and use save_args both in saving and conversion
merelcht Aug 20, 2024
5905ff6
Fix tests
merelcht Aug 20, 2024
0de251d
Revert back to using fs_args to pass mode argument
merelcht Aug 20, 2024
3238997
Use fs_args to pass mode for all pandas based datasets
merelcht Aug 20, 2024
ef11faa
Make other datasets use fs_args for handling mode as well
merelcht Aug 20, 2024
97e49a2
Refactor and make all datasets consistent
merelcht Aug 20, 2024
46aef8d
Fix message dataset
merelcht Aug 20, 2024
2aa49a8
Fix tests
merelcht Aug 20, 2024
86fe86f
Merge branch 'main' into fix/handle-mode-arg-correctly
merelcht Aug 21, 2024
9e7c363
Clean up
merelcht Aug 21, 2024
ea4556e
Merge branch 'fix/handle-mode-arg-correctly' of https://github.com/ke…
merelcht Aug 21, 2024
8f084de
Merge branch 'main' into fix/handle-mode-arg-correctly
merelcht Aug 21, 2024
e6825a7
Merge branch 'main' into fix/handle-mode-arg-correctly
merelcht Aug 22, 2024
0fb3dac
Update release notes
merelcht Aug 22, 2024
806d8f4
Merge branch 'fix/handle-mode-arg-correctly' of https://github.com/ke…
merelcht Aug 22, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 15 additions & 12 deletions kedro-datasets/kedro_datasets/biosequence/biosequence_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@ class BioSequenceDataset(AbstractDataset[list, list]):

DEFAULT_LOAD_ARGS: dict[str, Any] = {}
DEFAULT_SAVE_ARGS: dict[str, Any] = {}
DEFAULT_FS_ARGS: dict[str, Any] = {
"open_args_save": {"mode": "w"},
"open_args_load": {"mode": "r"},
}

def __init__( # noqa: PLR0913
self,
Expand Down Expand Up @@ -96,18 +100,17 @@ def __init__( # noqa: PLR0913

self._fs = fsspec.filesystem(self._protocol, **_credentials, **_fs_args)

# Handle default load and save arguments
self._load_args = deepcopy(self.DEFAULT_LOAD_ARGS)
if load_args is not None:
self._load_args.update(load_args)
self._save_args = deepcopy(self.DEFAULT_SAVE_ARGS)
if save_args is not None:
self._save_args.update(save_args)

_fs_open_args_load.setdefault("mode", "r")
_fs_open_args_save.setdefault("mode", "w")
self._fs_open_args_load = _fs_open_args_load
self._fs_open_args_save = _fs_open_args_save
# Handle default load, save, and fs arguments
self._load_args = {**self.DEFAULT_LOAD_ARGS, **(load_args or {})}
self._save_args = {**self.DEFAULT_SAVE_ARGS, **(save_args or {})}
self._fs_open_args_load = {
**self.DEFAULT_FS_ARGS.get("open_args_load", {}),
**(_fs_open_args_load or {}),
}
self._fs_open_args_save = {
**self.DEFAULT_FS_ARGS.get("open_args_save", {}),
**(_fs_open_args_save or {}),
}

self.metadata = metadata

Expand Down
8 changes: 2 additions & 6 deletions kedro-datasets/kedro_datasets/dask/csv_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,12 +84,8 @@ def __init__( # noqa: PLR0913
self.metadata = metadata

# Handle default load and save arguments
self._load_args = deepcopy(self.DEFAULT_LOAD_ARGS)
if load_args is not None:
self._load_args.update(load_args)
self._save_args = deepcopy(self.DEFAULT_SAVE_ARGS)
if save_args is not None:
self._save_args.update(save_args)
self._load_args = {**self.DEFAULT_LOAD_ARGS, **(load_args or {})}
self._save_args = {**self.DEFAULT_SAVE_ARGS, **(save_args or {})}

@property
def fs_args(self) -> dict[str, Any]:
Expand Down
8 changes: 2 additions & 6 deletions kedro-datasets/kedro_datasets/dask/parquet_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,12 +114,8 @@ def __init__( # noqa: PLR0913
self.metadata = metadata

# Handle default load and save arguments
self._load_args = deepcopy(self.DEFAULT_LOAD_ARGS)
if load_args is not None:
self._load_args.update(load_args)
self._save_args = deepcopy(self.DEFAULT_SAVE_ARGS)
if save_args is not None:
self._save_args.update(save_args)
self._load_args = {**self.DEFAULT_LOAD_ARGS, **(load_args or {})}
self._save_args = {**self.DEFAULT_SAVE_ARGS, **(save_args or {})}

@property
def fs_args(self) -> dict[str, Any]:
Expand Down
27 changes: 15 additions & 12 deletions kedro-datasets/kedro_datasets/email/message_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,10 @@ class EmailMessageDataset(AbstractVersionedDataset[Message, Message]):

DEFAULT_LOAD_ARGS: dict[str, Any] = {}
DEFAULT_SAVE_ARGS: dict[str, Any] = {}
DEFAULT_FS_ARGS: dict[str, Any] = {
"open_args_save": {"mode": "w"},
"open_args_load": {"mode": "r"},
}

def __init__( # noqa: PLR0913
self,
Expand Down Expand Up @@ -129,22 +133,21 @@ def __init__( # noqa: PLR0913
glob_function=self._fs.glob,
)

# Handle default load arguments
self._load_args = deepcopy(self.DEFAULT_LOAD_ARGS)
if load_args is not None:
self._load_args.update(load_args)
# Handle default load, save, and fs arguments
self._load_args = {**self.DEFAULT_LOAD_ARGS, **(load_args or {})}
self._parser_args = self._load_args.pop("parser", {"policy": default})

# Handle default save arguments
self._save_args = deepcopy(self.DEFAULT_SAVE_ARGS)
if save_args is not None:
self._save_args.update(save_args)
self._save_args = {**self.DEFAULT_SAVE_ARGS, **(save_args or {})}
self._generator_args = self._save_args.pop("generator", {})

_fs_open_args_load.setdefault("mode", "r")
_fs_open_args_save.setdefault("mode", "w")
self._fs_open_args_load = _fs_open_args_load
self._fs_open_args_save = _fs_open_args_save
self._fs_open_args_load = {
**self.DEFAULT_FS_ARGS.get("open_args_load", {}),
**(_fs_open_args_load or {}),
}
self._fs_open_args_save = {
**self.DEFAULT_FS_ARGS.get("open_args_save", {}),
**(_fs_open_args_save or {}),
}

def _describe(self) -> dict[str, Any]:
return {
Expand Down
4 changes: 1 addition & 3 deletions kedro-datasets/kedro_datasets/holoviews/holoviews_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,9 +100,7 @@ def __init__( # noqa: PLR0913
self._fs_open_args_save = _fs_open_args_save

# Handle default save arguments
self._save_args = deepcopy(self.DEFAULT_SAVE_ARGS)
if save_args is not None:
self._save_args.update(save_args)
self._save_args = {**self.DEFAULT_SAVE_ARGS, **(save_args or {})}

def _describe(self) -> dict[str, Any]:
return {
Expand Down
24 changes: 13 additions & 11 deletions kedro-datasets/kedro_datasets/json/json_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ class JSONDataset(AbstractVersionedDataset[Any, Any]):
"""

DEFAULT_SAVE_ARGS: dict[str, Any] = {"indent": 2}
DEFAULT_FS_ARGS: dict[str, Any] = {"open_args_save": {"mode": "w"}}

def __init__( # noqa: PLR0913
self,
Expand Down Expand Up @@ -89,16 +90,15 @@ def __init__( # noqa: PLR0913
`open_args_load` and `open_args_save`.
Here you can find all available arguments for `open`:
https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.spec.AbstractFileSystem.open
All defaults are preserved, except `mode`, which is set to `r` when loading
and to `w` when saving.
All defaults are preserved, except `mode`, which is set to `w` when saving.
metadata: Any arbitrary metadata.
This is ignored by Kedro, but may be consumed by users or external plugins.
"""
_fs_args = deepcopy(fs_args) or {}
_fs_open_args_load = _fs_args.pop("open_args_load", {})
_fs_open_args_save = _fs_args.pop("open_args_save", {})
_credentials = deepcopy(credentials) or {}

_credentials = deepcopy(credentials) or {}
protocol, path = get_protocol_and_path(filepath, version)

self._protocol = protocol
Expand All @@ -115,14 +115,16 @@ def __init__( # noqa: PLR0913
glob_function=self._fs.glob,
)

# Handle default save arguments
self._save_args = deepcopy(self.DEFAULT_SAVE_ARGS)
if save_args is not None:
self._save_args.update(save_args)

_fs_open_args_save.setdefault("mode", "w")
self._fs_open_args_load = _fs_open_args_load
self._fs_open_args_save = _fs_open_args_save
# Handle default save and fs arguments
self._save_args = {**self.DEFAULT_SAVE_ARGS, **(save_args or {})}
self._fs_open_args_load = {
**self.DEFAULT_FS_ARGS.get("open_args_load", {}),
**(_fs_open_args_load or {}),
}
self._fs_open_args_save = {
**self.DEFAULT_FS_ARGS.get("open_args_save", {}),
**(_fs_open_args_save or {}),
}

def _describe(self) -> dict[str, Any]:
return {
Expand Down
24 changes: 13 additions & 11 deletions kedro-datasets/kedro_datasets/matlab/matlab_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ class MatlabDataset(AbstractVersionedDataset[np.ndarray, np.ndarray]):
"""

DEFAULT_SAVE_ARGS: dict[str, Any] = {"indent": 2}
DEFAULT_FS_ARGS: dict[str, Any] = {"open_args_save": {"mode": "wb"}}

def __init__( # noqa = PLR0913
self,
Expand Down Expand Up @@ -83,8 +84,7 @@ def __init__( # noqa = PLR0913
`open_args_load` and `open_args_save`.
Here you can find all available arguments for `open`:
https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.spec.AbstractFileSystem.open
All defaults are preserved, except `mode`, which is set to `r` when loading
and to `w` when saving.
All defaults are preserved, except `mode`, which is set to `wb` when saving.
metadata: Any arbitrary metadata.
This is ignored by Kedro, but may be consumed by users or external plugins.
"""
Expand All @@ -106,14 +106,16 @@ def __init__( # noqa = PLR0913
exists_function=self._fs.exists,
glob_function=self._fs.glob,
)
# Handle default save arguments
self._save_args = deepcopy(self.DEFAULT_SAVE_ARGS)
if save_args is not None:
self._save_args.update(save_args)

_fs_open_args_save.setdefault("mode", "w")
self._fs_open_args_load = _fs_open_args_load
self._fs_open_args_save = _fs_open_args_save
# Handle default save and fs arguments
self._save_args = {**self.DEFAULT_SAVE_ARGS, **(save_args or {})}
self._fs_open_args_load = {
**self.DEFAULT_FS_ARGS.get("open_args_load", {}),
**(_fs_open_args_load or {}),
}
self._fs_open_args_save = {
**self.DEFAULT_FS_ARGS.get("open_args_save", {}),
**(_fs_open_args_save or {}),
}

def _describe(self) -> dict[str, Any]:
return {
Expand All @@ -134,7 +136,7 @@ def _load(self) -> np.ndarray:

def _save(self, data: np.ndarray) -> None:
save_path = get_filepath_str(self._get_save_path(), self._protocol)
with self._fs.open(save_path, mode="wb") as f:
with self._fs.open(save_path, **self._fs_open_args_save) as f:
io.savemat(f, {"data": data})
self._invalidate_cache()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -178,9 +178,7 @@ def __init__( # noqa: PLR0913
self._fs_open_args_save = _fs_open_args_save

# Handle default save arguments
self._save_args = deepcopy(self.DEFAULT_SAVE_ARGS)
if save_args is not None:
self._save_args.update(save_args)
self._save_args = {**self.DEFAULT_SAVE_ARGS, **(save_args or {})}

if overwrite and version is not None:
warn(
Expand Down
32 changes: 18 additions & 14 deletions kedro-datasets/kedro_datasets/networkx/gml_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,10 @@ class GMLDataset(AbstractVersionedDataset[networkx.Graph, networkx.Graph]):

DEFAULT_LOAD_ARGS: dict[str, Any] = {}
DEFAULT_SAVE_ARGS: dict[str, Any] = {}
DEFAULT_FS_ARGS: dict[str, Any] = {
"open_args_save": {"mode": "wb"},
"open_args_load": {"mode": "rb"},
}

def __init__( # noqa: PLR0913
self,
Expand Down Expand Up @@ -74,9 +78,9 @@ def __init__( # noqa: PLR0913
`open_args_load` and `open_args_save`.
Here you can find all available arguments for `open`:
https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.spec.AbstractFileSystem.open
All defaults are preserved, except `mode`, which is set to `r` when loading
and to `w` when saving.
metadata: Any Any arbitrary metadata.
All defaults are preserved, except `mode`, which is set to `rb` when loading
and to `wb` when saving.
metadata: Any arbitrary metadata.
This is ignored by Kedro, but may be consumed by users or external plugins.
"""
_fs_args = deepcopy(fs_args) or {}
Expand All @@ -100,17 +104,17 @@ def __init__( # noqa: PLR0913
glob_function=self._fs.glob,
)

# Handle default load and save arguments
self._load_args = deepcopy(self.DEFAULT_LOAD_ARGS)
if load_args is not None:
self._load_args.update(load_args)
self._save_args = deepcopy(self.DEFAULT_SAVE_ARGS)
if save_args is not None:
self._save_args.update(save_args)
_fs_open_args_load.setdefault("mode", "rb")
_fs_open_args_save.setdefault("mode", "wb")
self._fs_open_args_load = _fs_open_args_load
self._fs_open_args_save = _fs_open_args_save
# Handle default load, save, and fs arguments
self._load_args = {**self.DEFAULT_LOAD_ARGS, **(load_args or {})}
self._save_args = {**self.DEFAULT_SAVE_ARGS, **(save_args or {})}
self._fs_open_args_load = {
**self.DEFAULT_FS_ARGS.get("open_args_load", {}),
**(_fs_open_args_load or {}),
}
self._fs_open_args_save = {
**self.DEFAULT_FS_ARGS.get("open_args_save", {}),
**(_fs_open_args_save or {}),
}

def _load(self) -> networkx.Graph:
load_path = get_filepath_str(self._get_load_path(), self._protocol)
Expand Down
32 changes: 18 additions & 14 deletions kedro-datasets/kedro_datasets/networkx/graphml_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@ class GraphMLDataset(AbstractVersionedDataset[networkx.Graph, networkx.Graph]):

DEFAULT_LOAD_ARGS: dict[str, Any] = {}
DEFAULT_SAVE_ARGS: dict[str, Any] = {}
DEFAULT_FS_ARGS: dict[str, Any] = {
"open_args_save": {"mode": "wb"},
"open_args_load": {"mode": "rb"},
}

def __init__( # noqa: PLR0913
self,
Expand Down Expand Up @@ -73,9 +77,9 @@ def __init__( # noqa: PLR0913
`open_args_load` and `open_args_save`.
Here you can find all available arguments for `open`:
https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.spec.AbstractFileSystem.open
All defaults are preserved, except `mode`, which is set to `r` when loading
and to `w` when saving.
metadata: Any arbitrary Any arbitrary metadata.
All defaults are preserved, except `mode`, which is set to `rb` when loading
and to `wb` when saving.
metadata: Any arbitrary metadata.
This is ignored by Kedro, but may be consumed by users or external plugins.
"""
_fs_args = deepcopy(fs_args) or {}
Expand All @@ -99,17 +103,17 @@ def __init__( # noqa: PLR0913
glob_function=self._fs.glob,
)

# Handle default load and save arguments
self._load_args = deepcopy(self.DEFAULT_LOAD_ARGS)
if load_args is not None:
self._load_args.update(load_args)
self._save_args = deepcopy(self.DEFAULT_SAVE_ARGS)
if save_args is not None:
self._save_args.update(save_args)
_fs_open_args_load.setdefault("mode", "rb")
_fs_open_args_save.setdefault("mode", "wb")
self._fs_open_args_load = _fs_open_args_load
self._fs_open_args_save = _fs_open_args_save
# Handle default load, save, and fs arguments
self._load_args = {**self.DEFAULT_LOAD_ARGS, **(load_args or {})}
self._save_args = {**self.DEFAULT_SAVE_ARGS, **(save_args or {})}
self._fs_open_args_load = {
**self.DEFAULT_FS_ARGS.get("open_args_load", {}),
**(_fs_open_args_load or {}),
}
self._fs_open_args_save = {
**self.DEFAULT_FS_ARGS.get("open_args_save", {}),
**(_fs_open_args_save or {}),
}

def _load(self) -> networkx.Graph:
load_path = get_filepath_str(self._get_load_path(), self._protocol)
Expand Down
28 changes: 14 additions & 14 deletions kedro-datasets/kedro_datasets/networkx/json_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ class JSONDataset(AbstractVersionedDataset[networkx.Graph, networkx.Graph]):

DEFAULT_LOAD_ARGS: dict[str, Any] = {}
DEFAULT_SAVE_ARGS: dict[str, Any] = {}
DEFAULT_FS_ARGS: dict[str, Any] = {"open_args_save": {"mode": "w"}}

def __init__( # noqa: PLR0913
self,
Expand Down Expand Up @@ -74,9 +75,8 @@ def __init__( # noqa: PLR0913
`open_args_load` and `open_args_save`.
Here you can find all available arguments for `open`:
https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.spec.AbstractFileSystem.open
All defaults are preserved, except `mode`, which is set to `r` when loading
and to `w` when saving.
metadata: Any Any arbitrary metadata.
All defaults are preserved, except `mode`, which is set to `w` when saving.
metadata: Any arbitrary metadata.
This is ignored by Kedro, but may be consumed by users or external plugins.
"""
_fs_args = deepcopy(fs_args) or {}
Expand All @@ -100,17 +100,17 @@ def __init__( # noqa: PLR0913
glob_function=self._fs.glob,
)

# Handle default load and save arguments
self._load_args = deepcopy(self.DEFAULT_LOAD_ARGS)
if load_args is not None:
self._load_args.update(load_args)
self._save_args = deepcopy(self.DEFAULT_SAVE_ARGS)
if save_args is not None:
self._save_args.update(save_args)

_fs_open_args_save.setdefault("mode", "w")
self._fs_open_args_load = _fs_open_args_load
self._fs_open_args_save = _fs_open_args_save
# Handle default load, save, and fs arguments
self._load_args = {**self.DEFAULT_LOAD_ARGS, **(load_args or {})}
self._save_args = {**self.DEFAULT_SAVE_ARGS, **(save_args or {})}
self._fs_open_args_load = {
**self.DEFAULT_FS_ARGS.get("open_args_load", {}),
**(_fs_open_args_load or {}),
}
self._fs_open_args_save = {
**self.DEFAULT_FS_ARGS.get("open_args_save", {}),
**(_fs_open_args_save or {}),
}

def _load(self) -> networkx.Graph:
load_path = get_filepath_str(self._get_load_path(), self._protocol)
Expand Down
Loading