Dataset factories #2635

Merged: 29 commits, Jul 6, 2023

Changes from 12 commits

Commits (29)
ec66a12  Cleaned up and up to date version of dataset factories code (merelcht, May 22, 2023)
5e6c15d  Add some simple tests (merelcht, May 23, 2023)
0fca72c  Add parsing rules (ankatiyar, Jun 2, 2023)
06ed1a4  Refactor (ankatiyar, Jun 8, 2023)
b0e3fb9  Add some tests (ankatiyar, Jun 8, 2023)
0833af2  Add unit tests (ankatiyar, Jun 12, 2023)
8fc80f9  Fix test + refactor runner (ankatiyar, Jun 12, 2023)
091f794  Merge branch 'main' into feature/dataset-factories-for-catalog (ankatiyar, Jun 12, 2023)
8c192ee  Add comments + update specificity fn (ankatiyar, Jun 13, 2023)
3e2642c  Update function names (ankatiyar, Jun 15, 2023)
c2635d0  Merge branch 'main' into feature/dataset-factories-for-catalog (ankatiyar, Jun 15, 2023)
d310486  Update test (ankatiyar, Jun 15, 2023)
573c67f  Merge branch 'main' into feature/dataset-factories-for-catalog (ankatiyar, Jun 19, 2023)
9d80de4  Release notes + update resume scenario fix (ankatiyar, Jun 19, 2023)
549823f  revert change to suggest resume scenario (ankatiyar, Jun 19, 2023)
e052ae6  Update tests DataSet->Dataset (ankatiyar, Jun 19, 2023)
96c219f  Small refactor + move parsing rules to a new fn (ankatiyar, Jun 20, 2023)
c2634e5  Fix problem with load_version + refactor (ankatiyar, Jul 3, 2023)
eee606a  linting + small fix _get_datasets (ankatiyar, Jul 3, 2023)
635510a  Remove check for existence (ankatiyar, Jul 4, 2023)
394f37b  Merge branch 'main' into feature/dataset-factories-for-catalog (ankatiyar, Jul 4, 2023)
a1c602d  Add updated tests + Release notes (ankatiyar, Jul 5, 2023)
978d0a5  change classmethod to staticmethod for _match_patterns (ankatiyar, Jul 5, 2023)
b4fe7a7  Add test for layer (ankatiyar, Jul 5, 2023)
2782dca  Merge branch 'main' into feature/dataset-factories-for-catalog (ankatiyar, Jul 5, 2023)
85d3df1  Minor change from code review (ankatiyar, Jul 5, 2023)
8904ce3  Remove type conversion (ankatiyar, Jul 6, 2023)
bdc953d  Add warning for catch-all patterns [dataset factories] (#2774) (ankatiyar, Jul 6, 2023)
fa6c256  Merge branch 'main' into feature/dataset-factories-for-catalog (ankatiyar, Jul 6, 2023)
1 change: 1 addition & 0 deletions dependency/requirements.txt
@@ -13,6 +13,7 @@ importlib_resources>=1.3 # The `files()` API was introduced in `importlib_resou
jmespath>=0.9.5, <1.0
more_itertools~=9.0
omegaconf~=2.3
parse~=1.19.0
pip-tools~=6.5
pluggy~=1.0.0
PyYAML>=4.2, <7.0
158 changes: 132 additions & 26 deletions kedro/io/data_catalog.py
@@ -11,7 +11,9 @@
import logging
import re
from collections import defaultdict
from typing import Any
from typing import Any, Iterable

from parse import parse

from kedro.io.core import (
AbstractDataSet,
@@ -94,6 +96,20 @@ def _sub_nonword_chars(data_set_name: str) -> str:
return re.sub(WORDS_REGEX_PATTERN, "__", data_set_name)


def _specificity(pattern: str) -> int:
"""Helper function to check the length of exactly matched characters not inside brackets
Example -
specificity("{namespace}.companies") = 10
specificity("{namespace}.{dataset}") = 1
specificity("france.companies") = 16
Args:
pattern: The factory pattern
"""
# Remove all the placeholders from the pattern
result = re.sub(r"\{.*?\}", "", pattern)
return len(result)


class _FrozenDatasets:
"""Helper class to access underlying loaded datasets"""

@@ -141,6 +157,7 @@ def __init__(
data_sets: dict[str, AbstractDataSet] = None,
feed_dict: dict[str, Any] = None,
layers: dict[str, set[str]] = None,
dataset_patterns: dict[str, Any] = None,
) -> None:
"""``DataCatalog`` stores instances of ``AbstractDataSet``
implementations to provide ``load`` and ``save`` capabilities from
@@ -170,7 +187,23 @@
self._data_sets = dict(data_sets or {})
self.datasets = _FrozenDatasets(self._data_sets)
self.layers = layers

# Keep a record of all patterns in the catalog.
# {dataset pattern name : dataset pattern body}
self.dataset_patterns = dict(dataset_patterns or {})
# Sort all the patterns according to the parsing rules -
# 1. Decreasing specificity (no of characters outside the brackets)
# 2. Decreasing number of placeholders (no of curly brackets)
# 3. Alphabetical
self._sorted_dataset_patterns = sorted(
self.dataset_patterns.keys(),
key=lambda pattern: (
-(_specificity(pattern)),
-pattern.count("{"),
pattern,
),
)
Contributor comment: This should be refactored into a method and have its own unit tests.
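A rough sketch of what extracting the sorting into its own method might look like (the name _sort_patterns is hypothetical, not part of this PR), keeping the same three parsing rules:

def _sort_patterns(self, dataset_patterns: dict[str, Any]) -> list[str]:
    """Sort pattern keys by decreasing specificity, then decreasing number of
    placeholders, then alphabetically (illustrative helper only)."""
    return sorted(
        dataset_patterns,
        key=lambda pattern: (
            -_specificity(pattern),  # rule 1: more characters outside brackets first
            -pattern.count("{"),     # rule 2: more placeholders first
            pattern,                 # rule 3: alphabetical tie-break
        ),
    )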

# Cache that stores {name : matched_pattern}
self._pattern_matches_cache: dict[str, str] = {}
# import the feed dict
if feed_dict:
self.add_feed_dict(feed_dict)
@@ -257,6 +290,7 @@ class to be loaded is specified with the key ``type`` and their
>>> catalog.save("boats", df)
"""
data_sets = {}
dataset_patterns = {}
catalog = copy.deepcopy(catalog) or {}
credentials = copy.deepcopy(credentials) or {}
save_version = save_version or generate_timestamp()
@@ -271,35 +305,54 @@ class to be loaded is specified with the key ``type`` and their

layers: dict[str, set[str]] = defaultdict(set)
for ds_name, ds_config in catalog.items():
ds_layer = ds_config.pop("layer", None)
if ds_layer is not None:
layers[ds_layer].add(ds_name)

ds_config = _resolve_credentials(ds_config, credentials)
data_sets[ds_name] = AbstractDataSet.from_config(
ds_name, ds_config, load_versions.get(ds_name), save_version
)
# Assume that any name with } in it is a dataset factory to be matched.
if "}" in ds_name:
# Add each factory to the dataset_patterns dict.
dataset_patterns[ds_name] = ds_config
else:
ds_layer = ds_config.pop("layer", None)
if ds_layer is not None:
layers[ds_layer].add(ds_name)

ds_config = _resolve_credentials(ds_config, credentials)
data_sets[ds_name] = AbstractDataSet.from_config(
ds_name, ds_config, load_versions.get(ds_name), save_version
)
dataset_layers = layers or None
return cls(data_sets=data_sets, layers=dataset_layers)
return cls(
data_sets=data_sets,
layers=dataset_layers,
dataset_patterns=dataset_patterns,
)

def _get_dataset(
self, data_set_name: str, version: Version = None, suggest: bool = True
) -> AbstractDataSet:
if data_set_name not in self._data_sets:
error_msg = f"DataSet '{data_set_name}' not found in the catalog"

# Flag to turn on/off fuzzy-matching which can be time consuming and
# slow down plugins like `kedro-viz`
if suggest:
matches = difflib.get_close_matches(
data_set_name, self._data_sets.keys()
)
if matches:
suggestions = ", ".join(matches)
error_msg += f" - did you mean one of these instead: {suggestions}"

raise DataSetNotFoundError(error_msg)
# When a dataset is "used" in the pipeline that's not in the recorded catalog datasets,
# try to match it against the data factories in the catalog. If it's a match,
# resolve it to a dataset instance and add it to the catalog, so it only needs
# to be matched once and not everytime the dataset is used in the pipeline.
if self.exists_in_catalog_config(data_set_name):
pattern = self._pattern_matches_cache[data_set_name]
matched_dataset = self._resolve_dataset(data_set_name, pattern)
self.add(data_set_name, matched_dataset)
else:
error_msg = f"DataSet '{data_set_name}' not found in the catalog"

# Flag to turn on/off fuzzy-matching which can be time consuming and
# slow down plugins like `kedro-viz`
if suggest:
matches = difflib.get_close_matches(
data_set_name, self._data_sets.keys()
)
if matches:
suggestions = ", ".join(matches)
error_msg += (
f" - did you mean one of these instead: {suggestions}"
)

raise DataSetNotFoundError(error_msg)

data_set = self._data_sets[data_set_name]
if version and isinstance(data_set, AbstractVersionedDataSet):
@@ -311,6 +364,28 @@ def _get_dataset(

return data_set

def _resolve_dataset(
self, dataset_name: str, matched_pattern: str
) -> AbstractDataSet:
"""Get resolved AbstractDataSet from a factory config"""
result = parse(matched_pattern, dataset_name)
template_copy = copy.deepcopy(self.dataset_patterns[matched_pattern])
# Resolve the factory config for the dataset
for key, value in template_copy.items():
if isinstance(value, Iterable) and "}" in value:
string_value = str(value)
# result.named: gives access to all dict items in the match result.
# format_map fills in dict values into a string with {...} placeholders
# of the same key name.
try:
template_copy[key] = string_value.format_map(result.named)
except KeyError as exc:
raise DataSetError(
f"Unable to resolve '{key}' for the pattern '{matched_pattern}'"
) from exc
# Create dataset from catalog template.
return AbstractDataSet.from_config(dataset_name, template_copy)

def load(self, name: str, version: str = None) -> Any:
"""Loads a registered data set.

@@ -567,16 +642,47 @@ def list(self, regex_search: str | None = None) -> list[str]:
) from exc
return [dset_name for dset_name in self._data_sets if pattern.search(dset_name)]

def exists_in_catalog_config(self, dataset_name: str) -> bool:
Contributor comment: Is this the cache that adds the matched pattern into the catalog? If so, why do we need to keep it separately instead of just creating the entry in the catalog?

Contributor (author) reply: This function only adds the dataset name -> pattern mapping to self._pattern_matches_cache; it does not create an entry in the catalog. That's because it is also called in the runner to check for the existence of pipeline datasets in the catalog. The actual adding of a dataset to the catalog happens in self._get_dataset().
This avoids doing the pattern matching multiple times when it's called inside a Session run, while the actual resolving and adding still happens at _get_dataset time to allow standalone DataCatalog usage.
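A rough usage sketch of that two-step flow (the catalog entry and dataset name below are made up for illustration):

config = {
    "{namespace}.companies": {
        "type": "pandas.CSVDataSet",
        "filepath": "data/01_raw/{namespace}_companies.csv",
    },
}
catalog = DataCatalog.from_config(config)

# Runner-side check: only caches "france.companies" -> "{namespace}.companies".
assert catalog.exists_in_catalog_config("france.companies")

# First actual use: the pattern is resolved to a dataset instance and added to the catalog.
df = catalog.load("france.companies")  # goes through _get_dataset() internally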

Member comment: This isn't a good name for a public method; it's way too descriptive. I would suggest changing it to def __contains__(self, dataset_name):. This will allow us to do a check like if dataset_name in catalog: ..., which is way more intuitive and nicer than catalog.exists_in_catalog_config(dataset_name).

"""Check if a dataset exists in the catalog as an exact match or if it matches a pattern."""
if (
dataset_name in self._data_sets
or dataset_name in self._pattern_matches_cache
ankatiyar marked this conversation as resolved.
Show resolved Hide resolved
):
return True
matched_pattern = self.match_name_against_patterns(dataset_name)
if matched_pattern:
# cache the "dataset_name -> pattern" match
self._pattern_matches_cache[dataset_name] = matched_pattern
return True
return False
Member comment: If we got rid of the overly verbose method name self.match_name_against_patterns and simply replaced it with resolve, which could do all that's needed without leaking abstraction, all of this could be just return self.resolve(dataset_name) is not None.
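A minimal sketch of how these two suggestions could combine (the method names follow the reviewers' proposals and are not the code that was merged):

def resolve(self, dataset_name: str) -> str | None:
    """Return the first pattern matching the name, caching the match (illustrative)."""
    if dataset_name in self._pattern_matches_cache:
        return self._pattern_matches_cache[dataset_name]
    for pattern in self._sorted_dataset_patterns:
        if parse(pattern, dataset_name):
            self._pattern_matches_cache[dataset_name] = pattern
            return pattern
    return None

def __contains__(self, dataset_name: str) -> bool:
    """Allow `if dataset_name in catalog:` checks (illustrative)."""
    return dataset_name in self._data_sets or self.resolve(dataset_name) is not None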


def match_name_against_patterns(self, dataset_name: str) -> str | None:
"""Match a dataset name against existing patterns"""
# Loop through all dataset patterns and check if the given dataset name has a match.
for pattern in self._sorted_dataset_patterns:
result = parse(pattern, dataset_name)
if result:
return pattern
return None
Member comment: Is there a reason we need this separately from directly resolving upon the check of existence? Could we just call this resolve and do all of the matching and call the resolve helper here?


def shallow_copy(self) -> DataCatalog:
"""Returns a shallow copy of the current object.

Returns:
Copy of the current object.
"""
return DataCatalog(data_sets=self._data_sets, layers=self.layers)
return DataCatalog(
data_sets=self._data_sets,
layers=self.layers,
dataset_patterns=self.dataset_patterns,
)

def __eq__(self, other):
return (self._data_sets, self.layers) == (other._data_sets, other.layers)
return (self._data_sets, self.layers, self.dataset_patterns) == (
other._data_sets,
other.layers,
other.dataset_patterns,
)

def confirm(self, name: str) -> None:
"""Confirm a dataset by its name.
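For context on the new parse dependency: _resolve_dataset above uses parse() to extract placeholder values from a dataset name and str.format_map() to substitute them back into the config template. A small standalone illustration (values made up):

from parse import parse

pattern = "{namespace}.companies"
result = parse(pattern, "france.companies")
print(result.named)  # {'namespace': 'france'}

template = "data/01_raw/{namespace}_companies.csv"
print(template.format_map(result.named))  # data/01_raw/france_companies.csv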
27 changes: 24 additions & 3 deletions kedro/runner/runner.py
@@ -74,14 +74,35 @@ def run(
hook_manager = hook_manager or _NullPluginManager()
catalog = catalog.shallow_copy()

unsatisfied = pipeline.inputs() - set(catalog.list())
# Check which datasets used in the pipeline aren't in the catalog and don't match
# a pattern in the catalog
unregistered_ds = [
ds
for ds in pipeline.data_sets()
if not catalog.exists_in_catalog_config(ds)
]

# Check if there are any input datasets that aren't in the catalog and
# don't match a pattern in the catalog.
unsatisfied = [
input_name
for input_name in pipeline.inputs()
if input_name in unregistered_ds
]
Member comment: Using a comprehension here makes it look more complicated than it is. Could we keep the original code here, where we were using a simple set difference?

Suggested change
# Check if there are any input datasets that aren't in the catalog and
# don't match a pattern in the catalog.
unsatisfied = [
input_name
for input_name in pipeline.inputs()
if input_name in unregistered_ds
]
unsatisfied = pipeline.inputs() - set(registered_ds)

This reads in English as "give me all inputs minus all the registered ones".

if unsatisfied:
raise ValueError(
f"Pipeline input(s) {unsatisfied} not found in the DataCatalog"
)

free_outputs = pipeline.outputs() - set(catalog.list())
unregistered_ds = pipeline.data_sets() - set(catalog.list())
# Check if there's any output datasets that aren't in the catalog and don't match a pattern
# in the catalog.
free_outputs = [
output_name
for output_name in pipeline.outputs()
if output_name in unregistered_ds
]

# Create a default dataset for unregistered datasets
for ds_name in unregistered_ds:
catalog.add(ds_name, self.create_default_data_set(ds_name))
