diff --git a/RELEASE.md b/RELEASE.md index a67b8d8927..a4e089fa56 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -11,6 +11,7 @@ # Upcoming Release 0.18.12 ## Major features and improvements +* Added dataset factories feature which uses pattern matching to reduce the number of catalog entries. ## Bug fixes and other changes diff --git a/dependency/requirements.txt b/dependency/requirements.txt index b7e6f3cc22..14b8e2f244 100644 --- a/dependency/requirements.txt +++ b/dependency/requirements.txt @@ -13,6 +13,7 @@ importlib_resources>=1.3 # The `files()` API was introduced in `importlib_resou jmespath>=0.9.5, <1.0 more_itertools~=9.0 omegaconf~=2.3 +parse~=1.19.0 pip-tools~=6.5 pluggy~=1.0 PyYAML>=4.2, <7.0 diff --git a/kedro/io/data_catalog.py b/kedro/io/data_catalog.py index 5dacf3f950..4143024a73 100644 --- a/kedro/io/data_catalog.py +++ b/kedro/io/data_catalog.py @@ -11,7 +11,9 @@ import logging import re from collections import defaultdict -from typing import Any +from typing import Any, Iterable + +from parse import parse from kedro.io.core import ( AbstractDataSet, @@ -136,11 +138,14 @@ class DataCatalog: to the underlying data sets. """ - def __init__( + def __init__( # pylint: disable=too-many-arguments self, data_sets: dict[str, AbstractDataSet] = None, feed_dict: dict[str, Any] = None, layers: dict[str, set[str]] = None, + dataset_patterns: dict[str, dict[str, Any]] = None, + load_versions: dict[str, str] = None, + save_version: str = None, ) -> None: """``DataCatalog`` stores instances of ``AbstractDataSet`` implementations to provide ``load`` and ``save`` capabilities from @@ -156,6 +161,15 @@ def __init__( to a set of data set names, according to the data engineering convention. For more details, see https://docs.kedro.org/en/stable/resources/glossary.html#layers-data-engineering-convention + dataset_patterns: A dictionary of data set factory patterns + and corresponding data set configuration + load_versions: A mapping between data set names and versions + to load. Has no effect on data sets without enabled versioning. + save_version: Version string to be used for ``save`` operations + by all data sets with enabled versioning. It must: a) be a + case-insensitive string that conforms with operating system + filename limitations, b) always return the latest version when + sorted in lexicographical order. Example: :: @@ -170,8 +184,12 @@ def __init__( self._data_sets = dict(data_sets or {}) self.datasets = _FrozenDatasets(self._data_sets) self.layers = layers + # Keep a record of all patterns in the catalog. 
+ # {dataset pattern name : dataset pattern body} + self._dataset_patterns = dataset_patterns or {} + self._load_versions = load_versions or {} + self._save_version = save_version - # import the feed dict if feed_dict: self.add_feed_dict(feed_dict) @@ -181,7 +199,7 @@ def _logger(self): @classmethod def from_config( - cls: type, + cls, catalog: dict[str, dict[str, Any]] | None, credentials: dict[str, dict[str, Any]] = None, load_versions: dict[str, str] = None, @@ -257,35 +275,124 @@ class to be loaded is specified with the key ``type`` and their >>> catalog.save("boats", df) """ data_sets = {} + dataset_patterns = {} catalog = copy.deepcopy(catalog) or {} credentials = copy.deepcopy(credentials) or {} save_version = save_version or generate_timestamp() load_versions = copy.deepcopy(load_versions) or {} + layers: dict[str, set[str]] = defaultdict(set) - missing_keys = load_versions.keys() - catalog.keys() + for ds_name, ds_config in catalog.items(): + ds_config = _resolve_credentials(ds_config, credentials) + if cls._is_pattern(ds_name): + # Add each factory to the dataset_patterns dict. + dataset_patterns[ds_name] = ds_config + + else: + ds_layer = ds_config.pop("layer", None) + if ds_layer is not None: + layers[ds_layer].add(ds_name) + data_sets[ds_name] = AbstractDataSet.from_config( + ds_name, ds_config, load_versions.get(ds_name), save_version + ) + dataset_layers = layers or None + sorted_patterns = cls._sort_patterns(dataset_patterns) + missing_keys = [ + key + for key in load_versions.keys() + if not (cls._match_pattern(sorted_patterns, key) or key in catalog) + ] if missing_keys: raise DatasetNotFoundError( f"'load_versions' keys [{', '.join(sorted(missing_keys))}] " f"are not found in the catalog." ) - layers: dict[str, set[str]] = defaultdict(set) - for ds_name, ds_config in catalog.items(): - ds_layer = ds_config.pop("layer", None) - if ds_layer is not None: - layers[ds_layer].add(ds_name) + return cls( + data_sets=data_sets, + layers=dataset_layers, + dataset_patterns=sorted_patterns, + load_versions=load_versions, + save_version=save_version, + ) - ds_config = _resolve_credentials(ds_config, credentials) - data_sets[ds_name] = AbstractDataSet.from_config( - ds_name, ds_config, load_versions.get(ds_name), save_version - ) + @staticmethod + def _is_pattern(pattern: str): + """Check if a given string is a pattern. Assume that any name with '{' is a pattern.""" + return "{" in pattern + + @staticmethod + def _match_pattern( + data_set_patterns: dict[str, dict[str, Any]], data_set_name: str + ) -> str | None: + """Match a dataset name against patterns in a dictionary containing patterns""" + for pattern, _ in data_set_patterns.items(): + result = parse(pattern, data_set_name) + if result: + return pattern + return None - dataset_layers = layers or None - return cls(data_sets=data_sets, layers=dataset_layers) + @classmethod + def _sort_patterns( + cls, data_set_patterns: dict[str, dict[str, Any]] + ) -> dict[str, dict[str, Any]]: + """Sort a dictionary of dataset patterns according to parsing rules - + 1. Decreasing specificity (number of characters outside the curly brackets) + 2. Decreasing number of placeholders (number of curly bracket pairs) + 3. 
Alphabetically + """ + sorted_keys = sorted( + data_set_patterns, + key=lambda pattern: ( + -(cls._specificity(pattern)), + -pattern.count("{"), + pattern, + ), + ) + sorted_patterns = {} + for key in sorted_keys: + sorted_patterns[key] = data_set_patterns[key] + return sorted_patterns + + @staticmethod + def _specificity(pattern: str) -> int: + """Helper function to check the length of exactly matched characters not inside brackets + Example - + specificity("{namespace}.companies") = 10 + specificity("{namespace}.{dataset}") = 1 + specificity("france.companies") = 16 + """ + # Remove all the placeholders from the pattern and count the number of remaining chars + result = re.sub(r"\{.*?\}", "", pattern) + return len(result) def _get_dataset( self, data_set_name: str, version: Version = None, suggest: bool = True ) -> AbstractDataSet: + matched_pattern = self._match_pattern(self._dataset_patterns, data_set_name) + if data_set_name not in self._data_sets and matched_pattern: + # If the dataset is a patterned dataset, materialise it and add it to + # the catalog + data_set_config = self._resolve_config(data_set_name, matched_pattern) + ds_layer = data_set_config.pop("layer", None) + if ds_layer: + self.layers = self.layers or {} + self.layers.setdefault(ds_layer, set()).add(data_set_name) + data_set = AbstractDataSet.from_config( + data_set_name, + data_set_config, + self._load_versions.get(data_set_name), + self._save_version, + ) + if self._specificity(matched_pattern) == 0: + self._logger.warning( + "Config from the dataset factory pattern '%s' in the catalog will be used to " + "override the default MemoryDataset creation for the dataset '%s'", + matched_pattern, + data_set_name, + ) + + self.add(data_set_name, data_set) if data_set_name not in self._data_sets: error_msg = f"Dataset '{data_set_name}' not found in the catalog" @@ -298,9 +405,7 @@ def _get_dataset( if matches: suggestions = ", ".join(matches) error_msg += f" - did you mean one of these instead: {suggestions}" - raise DatasetNotFoundError(error_msg) - data_set = self._data_sets[data_set_name] if version and isinstance(data_set, AbstractVersionedDataSet): # we only want to return a similar-looking dataset, @@ -311,6 +416,35 @@ def _get_dataset( return data_set + def __contains__(self, data_set_name): + """Check if an item is in the catalog as a materialised dataset or pattern""" + matched_pattern = self._match_pattern(self._dataset_patterns, data_set_name) + if data_set_name in self._data_sets or matched_pattern: + return True + return False + + def _resolve_config( + self, + data_set_name: str, + matched_pattern: str, + ) -> dict[str, Any]: + """Get resolved AbstractDataSet from a factory config""" + result = parse(matched_pattern, data_set_name) + config_copy = copy.deepcopy(self._dataset_patterns[matched_pattern]) + # Resolve the factory config for the dataset + for key, value in config_copy.items(): + if isinstance(value, Iterable) and "}" in value: + # result.named: gives access to all dict items in the match result. + # format_map fills in dict values into a string with {...} placeholders + # of the same key name. + try: + config_copy[key] = str(value).format_map(result.named) + except KeyError as exc: + raise DatasetError( + f"Unable to resolve '{key}' for the pattern '{matched_pattern}'" + ) from exc + return config_copy + def load(self, name: str, version: str = None) -> Any: """Loads a registered data set. @@ -573,10 +707,20 @@ def shallow_copy(self) -> DataCatalog: Returns: Copy of the current object. 
""" - return DataCatalog(data_sets=self._data_sets, layers=self.layers) + return DataCatalog( + data_sets=self._data_sets, + layers=self.layers, + dataset_patterns=self._dataset_patterns, + load_versions=self._load_versions, + save_version=self._save_version, + ) def __eq__(self, other): - return (self._data_sets, self.layers) == (other._data_sets, other.layers) + return (self._data_sets, self.layers, self._dataset_patterns) == ( + other._data_sets, + other.layers, + other._dataset_patterns, + ) def confirm(self, name: str) -> None: """Confirm a dataset by its name. diff --git a/kedro/runner/runner.py b/kedro/runner/runner.py index d7a1c51ad7..98ad6c4268 100644 --- a/kedro/runner/runner.py +++ b/kedro/runner/runner.py @@ -32,7 +32,7 @@ class AbstractRunner(ABC): """ def __init__(self, is_async: bool = False): - """Instantiates the runner classs. + """Instantiates the runner class. Args: is_async: If True, the node inputs and outputs are loaded and saved @@ -74,14 +74,25 @@ def run( hook_manager = hook_manager or _NullPluginManager() catalog = catalog.shallow_copy() - unsatisfied = pipeline.inputs() - set(catalog.list()) + # Check which datasets used in the pipeline are in the catalog or match + # a pattern in the catalog + registered_ds = [ds for ds in pipeline.data_sets() if ds in catalog] + + # Check if there are any input datasets that aren't in the catalog and + # don't match a pattern in the catalog. + unsatisfied = pipeline.inputs() - set(registered_ds) + if unsatisfied: raise ValueError( f"Pipeline input(s) {unsatisfied} not found in the DataCatalog" ) - free_outputs = pipeline.outputs() - set(catalog.list()) - unregistered_ds = pipeline.data_sets() - set(catalog.list()) + # Check if there's any output datasets that aren't in the catalog and don't match a pattern + # in the catalog. 
+ free_outputs = pipeline.outputs() - set(registered_ds) + unregistered_ds = pipeline.data_sets() - set(registered_ds) + + # Create a default dataset for unregistered datasets for ds_name in unregistered_ds: catalog.add(ds_name, self.create_default_data_set(ds_name)) @@ -420,7 +431,7 @@ def _run_node_sequential( items: Iterable = outputs.items() # if all outputs are iterators, then the node is a generator node if all(isinstance(d, Iterator) for d in outputs.values()): - # Python dictionaries are ordered so we are sure + # Python dictionaries are ordered, so we are sure # the keys and the chunk streams are in the same order # [a, b, c] keys = list(outputs.keys()) diff --git a/tests/io/test_data_catalog.py b/tests/io/test_data_catalog.py index cc9bdb5fe5..74858702bf 100644 --- a/tests/io/test_data_catalog.py +++ b/tests/io/test_data_catalog.py @@ -87,6 +87,68 @@ def sane_config_with_tracking_ds(tmp_path): } +@pytest.fixture +def config_with_dataset_factories(): + return { + "catalog": { + "{brand}_cars": { + "type": "pandas.CSVDataSet", + "filepath": "data/01_raw/{brand}_cars.csv", + }, + "audi_cars": { + "type": "pandas.ParquetDataSet", + "filepath": "data/01_raw/audi_cars.pq", + }, + "{type}_boats": { + "type": "pandas.CSVDataSet", + "filepath": "data/01_raw/{type}_boats.csv", + }, + }, + } + + +@pytest.fixture +def config_with_dataset_factories_with_default(config_with_dataset_factories): + config_with_dataset_factories["catalog"]["{default_dataset}"] = { + "type": "pandas.CSVDataSet", + "filepath": "data/01_raw/{default_dataset}.csv", + } + return config_with_dataset_factories + + +@pytest.fixture +def config_with_dataset_factories_bad_pattern(config_with_dataset_factories): + config_with_dataset_factories["catalog"]["{type}@planes"] = { + "type": "pandas.ParquetDataSet", + "filepath": "data/01_raw/{brand}_plane.pq", + } + return config_with_dataset_factories + + +@pytest.fixture +def config_with_dataset_factories_only_patterns(): + return { + "catalog": { + "{default}": { + "type": "pandas.CSVDataSet", + "filepath": "data/01_raw/{default}.csv", + }, + "{namespace}_{dataset}": { + "type": "pandas.CSVDataSet", + "filepath": "data/01_raw/{namespace}_{dataset}.pq", + }, + "{country}_companies": { + "type": "pandas.CSVDataSet", + "filepath": "data/01_raw/{country}_companies.csv", + }, + "{dataset}s": { + "type": "pandas.CSVDataSet", + "filepath": "data/01_raw/{dataset}s.csv", + }, + }, + } + + @pytest.fixture def data_set(filepath): return CSVDataSet(filepath=filepath, save_args={"index": False}) @@ -683,3 +745,151 @@ def test_no_versions_with_cloud_protocol(self): ) with pytest.raises(DatasetError, match=pattern): versioned_dataset.load() + + +class TestDataCatalogDatasetFactories: + def test_match_added_to_datasets_on_get(self, config_with_dataset_factories): + """Check that the datasets that match patterns are only added when fetched""" + catalog = DataCatalog.from_config(**config_with_dataset_factories) + assert "{brand}_cars" not in catalog._data_sets + assert "tesla_cars" not in catalog._data_sets + assert "{brand}_cars" in catalog._dataset_patterns + + tesla_cars = catalog._get_dataset("tesla_cars") + assert isinstance(tesla_cars, CSVDataSet) + assert "tesla_cars" in catalog._data_sets + + @pytest.mark.parametrize( + "dataset_name, expected", + [ + ("audi_cars", True), + ("tesla_cars", True), + ("row_boats", True), + ("boats", False), + ("tesla_card", False), + ], + ) + def test_exists_in_catalog_config( + self, config_with_dataset_factories, dataset_name, expected + ): + """Check 
that the dataset exists in catalog when it matches a pattern + or is in the catalog""" + catalog = DataCatalog.from_config(**config_with_dataset_factories) + assert (dataset_name in catalog) == expected + + def test_patterns_not_in_catalog_datasets(self, config_with_dataset_factories): + """Check that the pattern is not in the catalog datasets""" + catalog = DataCatalog.from_config(**config_with_dataset_factories) + assert "audi_cars" in catalog._data_sets + assert "{brand}_cars" not in catalog._data_sets + assert "audi_cars" not in catalog._dataset_patterns + assert "{brand}_cars" in catalog._dataset_patterns + + def test_explicit_entry_not_overwritten(self, config_with_dataset_factories): + """Check that the existing catalog entry is not overwritten by config in pattern""" + catalog = DataCatalog.from_config(**config_with_dataset_factories) + audi_cars = catalog._get_dataset("audi_cars") + assert isinstance(audi_cars, ParquetDataSet) + + @pytest.mark.parametrize( + "dataset_name,pattern", + [ + ("missing", "Dataset 'missing' not found in the catalog"), + ("tesla@cars", "Dataset 'tesla@cars' not found in the catalog"), + ], + ) + def test_dataset_not_in_catalog_when_no_pattern_match( + self, config_with_dataset_factories, dataset_name, pattern + ): + """Check that the dataset is not added to the catalog when there is no pattern""" + catalog = DataCatalog.from_config(**config_with_dataset_factories) + with pytest.raises(DatasetError, match=re.escape(pattern)): + catalog._get_dataset(dataset_name) + + def test_sorting_order_patterns(self, config_with_dataset_factories_only_patterns): + """Check that the sorted order of the patterns is correct according + to parsing rules""" + catalog = DataCatalog.from_config(**config_with_dataset_factories_only_patterns) + sorted_keys_expected = [ + "{country}_companies", + "{namespace}_{dataset}", + "{dataset}s", + "{default}", + ] + assert list(catalog._dataset_patterns.keys()) == sorted_keys_expected + + def test_default_dataset(self, config_with_dataset_factories_with_default, caplog): + """Check that default dataset is used when no other pattern matches""" + catalog = DataCatalog.from_config(**config_with_dataset_factories_with_default) + assert "jet@planes" not in catalog._data_sets + jet_dataset = catalog._get_dataset("jet@planes") + log_record = caplog.records[0] + assert log_record.levelname == "WARNING" + assert ( + "Config from the dataset factory pattern '{default_dataset}' " + "in the catalog will be used to override the default " + "MemoryDataset creation for the dataset 'jet@planes'" in log_record.message + ) + assert isinstance(jet_dataset, CSVDataSet) + + def test_unmatched_key_error_when_parsing_config( + self, config_with_dataset_factories_bad_pattern + ): + """Check error raised when key mentioned in the config is not in pattern name""" + catalog = DataCatalog.from_config(**config_with_dataset_factories_bad_pattern) + pattern = "Unable to resolve 'filepath' for the pattern '{type}@planes'" + with pytest.raises(DatasetError, match=re.escape(pattern)): + catalog._get_dataset("jet@planes") + + def test_factory_layer(self, config_with_dataset_factories): + """Check that layer is correctly processed for patterned datasets""" + config_with_dataset_factories["catalog"]["{brand}_cars"]["layer"] = "raw" + catalog = DataCatalog.from_config(**config_with_dataset_factories) + _ = catalog._get_dataset("tesla_cars") + assert catalog.layers["raw"] == {"tesla_cars"} + + def test_factory_config_versioned( + self, config_with_dataset_factories, filepath, 
dummy_dataframe + ): + """Test load and save of versioned data sets from config""" + config_with_dataset_factories["catalog"]["{brand}_cars"]["versioned"] = True + config_with_dataset_factories["catalog"]["{brand}_cars"]["filepath"] = filepath + + assert "tesla_cars" not in config_with_dataset_factories + + # Decompose `generate_timestamp` to keep `current_ts` reference. + current_ts = datetime.now(tz=timezone.utc) + fmt = ( + "{d.year:04d}-{d.month:02d}-{d.day:02d}T{d.hour:02d}" + ".{d.minute:02d}.{d.second:02d}.{ms:03d}Z" + ) + version = fmt.format(d=current_ts, ms=current_ts.microsecond // 1000) + + catalog = DataCatalog.from_config( + **config_with_dataset_factories, + load_versions={"tesla_cars": version}, + save_version=version, + ) + + catalog.save("tesla_cars", dummy_dataframe) + path = Path( + config_with_dataset_factories["catalog"]["{brand}_cars"]["filepath"] + ) + path = path / version / path.name + assert path.is_file() + + reloaded_df = catalog.load("tesla_cars") + assert_frame_equal(reloaded_df, dummy_dataframe) + + reloaded_df_version = catalog.load("tesla_cars", version=version) + assert_frame_equal(reloaded_df_version, dummy_dataframe) + + # Verify that `VERSION_FORMAT` can help regenerate `current_ts`. + actual_timestamp = datetime.strptime( + catalog.datasets.tesla_cars.resolve_load_version(), # pylint: disable=no-member + VERSION_FORMAT, + ) + expected_timestamp = current_ts.replace( + microsecond=current_ts.microsecond // 1000 * 1000, tzinfo=None + ) + assert actual_timestamp == expected_timestamp