Dataset factories (#2635)
* Cleaned-up and up-to-date version of dataset factories code

Signed-off-by: Merel Theisen <[email protected]>

* Add some simple tests

Signed-off-by: Merel Theisen <[email protected]>

* Add parsing rules

Signed-off-by: Ankita Katiyar <[email protected]>

* Refactor

Signed-off-by: Ankita Katiyar <[email protected]>

* Add some tests

Signed-off-by: Ankita Katiyar <[email protected]>

* Add unit tests

Signed-off-by: Ankita Katiyar <[email protected]>

* Fix test + refactor runner

Signed-off-by: Ankita Katiyar <[email protected]>

* Add comments + update specificity fn

Signed-off-by: Ankita Katiyar <[email protected]>

* Update function names

Signed-off-by: Ankita Katiyar <[email protected]>

* Update test

Signed-off-by: Ankita Katiyar <[email protected]>

* Release notes + update resume scenario fix

Signed-off-by: Ankita Katiyar <[email protected]>

* Revert change to suggest resume scenario

Signed-off-by: Ankita Katiyar <[email protected]>

* Update tests DataSet->Dataset

Signed-off-by: Ankita Katiyar <[email protected]>

* Small refactor + move parsing rules to a new fn

Signed-off-by: Ankita Katiyar <[email protected]>

* Fix problem with load_version + refactor

Signed-off-by: Ankita Katiyar <[email protected]>

* Linting + small fix to _get_datasets

Signed-off-by: Ankita Katiyar <[email protected]>

* Remove check for existence

Signed-off-by: Ankita Katiyar <[email protected]>

* Add updated tests + Release notes

Signed-off-by: Ankita Katiyar <[email protected]>

* change classmethod to staticmethod for _match_patterns

Signed-off-by: Ankita Katiyar <[email protected]>

* Add test for layer

Signed-off-by: Ankita Katiyar <[email protected]>

* Minor change from code review

Signed-off-by: Ankita Katiyar <[email protected]>

* Remove type conversion

Signed-off-by: Ankita Katiyar <[email protected]>

* Add warning for catch-all patterns [dataset factories] (#2774)

* Add warning for catch-all patterns

Signed-off-by: Ankita Katiyar <[email protected]>

* Update warning message

Signed-off-by: Ankita Katiyar <[email protected]>

---------

Signed-off-by: Ankita Katiyar <[email protected]>

---------

Signed-off-by: Merel Theisen <[email protected]>
Signed-off-by: Ankita Katiyar <[email protected]>
Co-authored-by: Merel Theisen <[email protected]>
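
As a rough sketch of what the feature enables (dataset names, file paths and the `pandas.CSVDataSet` type below are made up for illustration, and the pandas dataset dependencies are assumed to be installed), a single templated catalog entry can now stand in for many concrete datasets:

```python
from kedro.io import DataCatalog

# One templated entry covers every dataset name shaped like "<namespace>.companies".
config = {
    "{namespace}.companies": {
        "type": "pandas.CSVDataSet",
        "filepath": "data/01_raw/{namespace}_companies.csv",
    },
    # Ordinary, non-templated entries still work side by side.
    "boats": {
        "type": "pandas.CSVDataSet",
        "filepath": "data/01_raw/boats.csv",
    },
}

catalog = DataCatalog.from_config(config)

# Membership checks consult explicit entries as well as the patterns ...
assert "boats" in catalog
assert "france.companies" in catalog
assert "germany.companies" in catalog

# ... and loading a patterned name materialises a dataset from the matched
# pattern, here reading data/01_raw/france_companies.csv (the file must exist):
# df = catalog.load("france.companies")
```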
ankatiyar and merelcht committed Jul 6, 2023
1 parent 3c68682 commit 6da8bde
Showing 5 changed files with 392 additions and 25 deletions.
1 change: 1 addition & 0 deletions RELEASE.md
@@ -11,6 +11,7 @@
# Upcoming Release 0.18.12

## Major features and improvements
* Added dataset factories feature which uses pattern matching to reduce the number of catalog entries.

## Bug fixes and other changes

1 change: 1 addition & 0 deletions dependency/requirements.txt
@@ -13,6 +13,7 @@ importlib_resources>=1.3 # The `files()` API was introduced in `importlib_resou
jmespath>=0.9.5, <1.0
more_itertools~=9.0
omegaconf~=2.3
parse~=1.19.0
pip-tools~=6.5
pluggy~=1.0
PyYAML>=4.2, <7.0
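
The new `parse` dependency is what the pattern matching is built on; roughly speaking, it is the inverse of `str.format`. A quick, illustrative use (the pattern and names are made up):

```python
from parse import parse

# Extract the placeholder values from a dataset name that fits the pattern.
result = parse("{namespace}.companies", "france.companies")
print(result.named)  # {'namespace': 'france'}

# A name that does not fit the pattern yields None rather than an error.
print(parse("{namespace}.companies", "boats"))  # None
```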
184 changes: 164 additions & 20 deletions kedro/io/data_catalog.py
@@ -11,7 +11,9 @@
import logging
import re
from collections import defaultdict
from typing import Any
from typing import Any, Iterable

from parse import parse

from kedro.io.core import (
AbstractDataSet,
@@ -136,11 +138,14 @@ class DataCatalog:
to the underlying data sets.
"""

def __init__(
def __init__( # pylint: disable=too-many-arguments
self,
data_sets: dict[str, AbstractDataSet] = None,
feed_dict: dict[str, Any] = None,
layers: dict[str, set[str]] = None,
dataset_patterns: dict[str, dict[str, Any]] = None,
load_versions: dict[str, str] = None,
save_version: str = None,
) -> None:
"""``DataCatalog`` stores instances of ``AbstractDataSet``
implementations to provide ``load`` and ``save`` capabilities from
@@ -156,6 +161,15 @@ def __init__(
to a set of data set names, according to the
data engineering convention. For more details, see
https://docs.kedro.org/en/stable/resources/glossary.html#layers-data-engineering-convention
dataset_patterns: A dictionary of data set factory patterns
and corresponding data set configuration
load_versions: A mapping between data set names and versions
to load. Has no effect on data sets without enabled versioning.
save_version: Version string to be used for ``save`` operations
by all data sets with enabled versioning. It must: a) be a
case-insensitive string that conforms with operating system
filename limitations, b) always return the latest version when
sorted in lexicographical order.
Example:
::
@@ -170,8 +184,12 @@
self._data_sets = dict(data_sets or {})
self.datasets = _FrozenDatasets(self._data_sets)
self.layers = layers
# Keep a record of all patterns in the catalog.
# {dataset pattern name : dataset pattern body}
self._dataset_patterns = dataset_patterns or {}
self._load_versions = load_versions or {}
self._save_version = save_version

# import the feed dict
if feed_dict:
self.add_feed_dict(feed_dict)

@@ -181,7 +199,7 @@ def _logger(self):

@classmethod
def from_config(
cls: type,
cls,
catalog: dict[str, dict[str, Any]] | None,
credentials: dict[str, dict[str, Any]] = None,
load_versions: dict[str, str] = None,
@@ -257,35 +275,124 @@ class to be loaded is specified with the key ``type`` and their
>>> catalog.save("boats", df)
"""
data_sets = {}
dataset_patterns = {}
catalog = copy.deepcopy(catalog) or {}
credentials = copy.deepcopy(credentials) or {}
save_version = save_version or generate_timestamp()
load_versions = copy.deepcopy(load_versions) or {}
layers: dict[str, set[str]] = defaultdict(set)

missing_keys = load_versions.keys() - catalog.keys()
for ds_name, ds_config in catalog.items():
ds_config = _resolve_credentials(ds_config, credentials)
if cls._is_pattern(ds_name):
# Add each factory to the dataset_patterns dict.
dataset_patterns[ds_name] = ds_config

else:
ds_layer = ds_config.pop("layer", None)
if ds_layer is not None:
layers[ds_layer].add(ds_name)
data_sets[ds_name] = AbstractDataSet.from_config(
ds_name, ds_config, load_versions.get(ds_name), save_version
)
dataset_layers = layers or None
sorted_patterns = cls._sort_patterns(dataset_patterns)
missing_keys = [
key
for key in load_versions.keys()
if not (cls._match_pattern(sorted_patterns, key) or key in catalog)
]
if missing_keys:
raise DatasetNotFoundError(
f"'load_versions' keys [{', '.join(sorted(missing_keys))}] "
f"are not found in the catalog."
)

layers: dict[str, set[str]] = defaultdict(set)
for ds_name, ds_config in catalog.items():
ds_layer = ds_config.pop("layer", None)
if ds_layer is not None:
layers[ds_layer].add(ds_name)
return cls(
data_sets=data_sets,
layers=dataset_layers,
dataset_patterns=sorted_patterns,
load_versions=load_versions,
save_version=save_version,
)

ds_config = _resolve_credentials(ds_config, credentials)
data_sets[ds_name] = AbstractDataSet.from_config(
ds_name, ds_config, load_versions.get(ds_name), save_version
)
@staticmethod
def _is_pattern(pattern: str):
"""Check if a given string is a pattern. Assume that any name with '{' is a pattern."""
return "{" in pattern

@staticmethod
def _match_pattern(
data_set_patterns: dict[str, dict[str, Any]], data_set_name: str
) -> str | None:
"""Match a dataset name against patterns in a dictionary containing patterns"""
for pattern, _ in data_set_patterns.items():
result = parse(pattern, data_set_name)
if result:
return pattern
return None

dataset_layers = layers or None
return cls(data_sets=data_sets, layers=dataset_layers)
@classmethod
def _sort_patterns(
cls, data_set_patterns: dict[str, dict[str, Any]]
) -> dict[str, dict[str, Any]]:
"""Sort a dictionary of dataset patterns according to parsing rules -
1. Decreasing specificity (number of characters outside the curly brackets)
2. Decreasing number of placeholders (number of curly bracket pairs)
3. Alphabetically
"""
sorted_keys = sorted(
data_set_patterns,
key=lambda pattern: (
-(cls._specificity(pattern)),
-pattern.count("{"),
pattern,
),
)
sorted_patterns = {}
for key in sorted_keys:
sorted_patterns[key] = data_set_patterns[key]
return sorted_patterns

@staticmethod
def _specificity(pattern: str) -> int:
"""Helper function to check the length of exactly matched characters not inside brackets
Example -
specificity("{namespace}.companies") = 10
specificity("{namespace}.{dataset}") = 1
specificity("france.companies") = 16
"""
# Remove all the placeholders from the pattern and count the number of remaining chars
result = re.sub(r"\{.*?\}", "", pattern)
return len(result)

def _get_dataset(
self, data_set_name: str, version: Version = None, suggest: bool = True
) -> AbstractDataSet:
matched_pattern = self._match_pattern(self._dataset_patterns, data_set_name)
if data_set_name not in self._data_sets and matched_pattern:
# If the dataset is a patterned dataset, materialise it and add it to
# the catalog
data_set_config = self._resolve_config(data_set_name, matched_pattern)
ds_layer = data_set_config.pop("layer", None)
if ds_layer:
self.layers = self.layers or {}
self.layers.setdefault(ds_layer, set()).add(data_set_name)
data_set = AbstractDataSet.from_config(
data_set_name,
data_set_config,
self._load_versions.get(data_set_name),
self._save_version,
)
if self._specificity(matched_pattern) == 0:
self._logger.warning(
"Config from the dataset factory pattern '%s' in the catalog will be used to "
"override the default MemoryDataset creation for the dataset '%s'",
matched_pattern,
data_set_name,
)

self.add(data_set_name, data_set)
if data_set_name not in self._data_sets:
error_msg = f"Dataset '{data_set_name}' not found in the catalog"

@@ -298,9 +405,7 @@ def _get_dataset(
if matches:
suggestions = ", ".join(matches)
error_msg += f" - did you mean one of these instead: {suggestions}"

raise DatasetNotFoundError(error_msg)

data_set = self._data_sets[data_set_name]
if version and isinstance(data_set, AbstractVersionedDataSet):
# we only want to return a similar-looking dataset,
@@ -311,6 +416,35 @@

return data_set

def __contains__(self, data_set_name):
"""Check if an item is in the catalog as a materialised dataset or pattern"""
matched_pattern = self._match_pattern(self._dataset_patterns, data_set_name)
if data_set_name in self._data_sets or matched_pattern:
return True
return False

def _resolve_config(
self,
data_set_name: str,
matched_pattern: str,
) -> dict[str, Any]:
"""Get resolved AbstractDataSet from a factory config"""
result = parse(matched_pattern, data_set_name)
config_copy = copy.deepcopy(self._dataset_patterns[matched_pattern])
# Resolve the factory config for the dataset
for key, value in config_copy.items():
if isinstance(value, Iterable) and "}" in value:
# result.named: gives access to all dict items in the match result.
# format_map fills in dict values into a string with {...} placeholders
# of the same key name.
try:
config_copy[key] = str(value).format_map(result.named)
except KeyError as exc:
raise DatasetError(
f"Unable to resolve '{key}' for the pattern '{matched_pattern}'"
) from exc
return config_copy

def load(self, name: str, version: str = None) -> Any:
"""Loads a registered data set.
@@ -573,10 +707,20 @@ def shallow_copy(self) -> DataCatalog:
Returns:
Copy of the current object.
"""
return DataCatalog(data_sets=self._data_sets, layers=self.layers)
return DataCatalog(
data_sets=self._data_sets,
layers=self.layers,
dataset_patterns=self._dataset_patterns,
load_versions=self._load_versions,
save_version=self._save_version,
)

def __eq__(self, other):
return (self._data_sets, self.layers) == (other._data_sets, other.layers)
return (self._data_sets, self.layers, self._dataset_patterns) == (
other._data_sets,
other.layers,
other._dataset_patterns,
)

def confirm(self, name: str) -> None:
"""Confirm a dataset by its name.
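
A standalone sketch of the matching and resolution logic above, written outside the ``DataCatalog`` class (patterns, types and paths are invented; the real implementation lives in `_sort_patterns`, `_match_pattern`, `_specificity` and `_resolve_config`):

```python
import re
from parse import parse

def specificity(pattern: str) -> int:
    # Count the characters that sit outside {} placeholders.
    return len(re.sub(r"\{.*?\}", "", pattern))

patterns = {
    "{namespace}.{dataset}": {
        "type": "pandas.CSVDataSet",
        "filepath": "data/{namespace}/{dataset}.csv",
    },
    "{namespace}.companies": {
        "type": "pandas.CSVDataSet",
        "filepath": "data/{namespace}_companies.csv",
    },
}

# Sort: most specific first, then most placeholders, then alphabetical.
sorted_patterns = dict(
    sorted(
        patterns.items(),
        key=lambda item: (-specificity(item[0]), -item[0].count("{"), item[0]),
    )
)
# "{namespace}.companies" (specificity 10) is tried before "{namespace}.{dataset}" (specificity 1).

name = "france.companies"
for pattern, config in sorted_patterns.items():
    result = parse(pattern, name)
    if result:
        # Fill the placeholders captured from the name into the config values.
        resolved = {
            key: value.format_map(result.named) if isinstance(value, str) else value
            for key, value in config.items()
        }
        print(pattern, "->", resolved)
        # {namespace}.companies -> {'type': 'pandas.CSVDataSet',
        #                           'filepath': 'data/france_companies.csv'}
        break
```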
21 changes: 16 additions & 5 deletions kedro/runner/runner.py
@@ -32,7 +32,7 @@ class AbstractRunner(ABC):
"""

def __init__(self, is_async: bool = False):
"""Instantiates the runner classs.
"""Instantiates the runner class.
Args:
is_async: If True, the node inputs and outputs are loaded and saved
@@ -74,14 +74,25 @@ def run(
hook_manager = hook_manager or _NullPluginManager()
catalog = catalog.shallow_copy()

unsatisfied = pipeline.inputs() - set(catalog.list())
# Check which datasets used in the pipeline are in the catalog or match
# a pattern in the catalog
registered_ds = [ds for ds in pipeline.data_sets() if ds in catalog]

# Check if there are any input datasets that aren't in the catalog and
# don't match a pattern in the catalog.
unsatisfied = pipeline.inputs() - set(registered_ds)

if unsatisfied:
raise ValueError(
f"Pipeline input(s) {unsatisfied} not found in the DataCatalog"
)

free_outputs = pipeline.outputs() - set(catalog.list())
unregistered_ds = pipeline.data_sets() - set(catalog.list())
# Check if there's any output datasets that aren't in the catalog and don't match a pattern
# in the catalog.
free_outputs = pipeline.outputs() - set(registered_ds)
unregistered_ds = pipeline.data_sets() - set(registered_ds)

# Create a default dataset for unregistered datasets
for ds_name in unregistered_ds:
catalog.add(ds_name, self.create_default_data_set(ds_name))

@@ -420,7 +431,7 @@ def _run_node_sequential(
items: Iterable = outputs.items()
# if all outputs are iterators, then the node is a generator node
if all(isinstance(d, Iterator) for d in outputs.values()):
# Python dictionaries are ordered so we are sure
# Python dictionaries are ordered, so we are sure
# the keys and the chunk streams are in the same order
# [a, b, c]
keys = list(outputs.keys())
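
A simplified sketch of the bookkeeping that the `run` change above performs, with plain sets standing in for the pipeline and a toy membership check standing in for `ds in catalog` (which now also matches patterns); all names are invented:

```python
# Datasets the pipeline touches, its external inputs, and its terminal outputs.
pipeline_datasets = {"france.companies", "boats", "model_input", "predictions"}
pipeline_inputs = {"france.companies", "boats"}
pipeline_outputs = {"predictions"}

explicit_entries = {"boats"}

def in_catalog(name: str) -> bool:
    # Stand-in for DataCatalog.__contains__: explicit entry *or* pattern match.
    return name in explicit_entries or name.endswith(".companies")

registered = {ds for ds in pipeline_datasets if in_catalog(ds)}

# Inputs that are neither explicit entries nor covered by a pattern are an error.
unsatisfied = pipeline_inputs - registered
assert not unsatisfied, f"Pipeline input(s) {unsatisfied} not found in the DataCatalog"

# Outputs not covered by the catalog are "free" (handed back to the caller),
# and every other uncovered dataset gets a default in-memory dataset.
free_outputs = pipeline_outputs - registered
unregistered = pipeline_datasets - registered
print(free_outputs)   # {'predictions'}
print(unregistered)   # {'model_input', 'predictions'}
```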