From 37cabc46532b3c6debf5e5552ecfe86bbddf7b80 Mon Sep 17 00:00:00 2001
From: Ankita Katiyar
Date: Wed, 28 Jun 2023 17:04:09 +0100
Subject: [PATCH] Dataset factories new implementation

Signed-off-by: Ankita Katiyar
---
 kedro/io/data_catalog.py | 171 +++++++++++++++++++++++++++++++++++----
 kedro/runner/runner.py   |   8 +-
 2 files changed, 160 insertions(+), 19 deletions(-)

diff --git a/kedro/io/data_catalog.py b/kedro/io/data_catalog.py
index 5dacf3f950..da8e945748 100644
--- a/kedro/io/data_catalog.py
+++ b/kedro/io/data_catalog.py
@@ -11,7 +11,9 @@
 import logging
 import re
 from collections import defaultdict
-from typing import Any
+from typing import Any, Iterable
+
+from parse import parse
 
 from kedro.io.core import (
     AbstractDataSet,
@@ -141,6 +143,7 @@ def __init__(
         data_sets: dict[str, AbstractDataSet] = None,
         feed_dict: dict[str, Any] = None,
         layers: dict[str, set[str]] = None,
+        raw_catalog: dict[str, dict[str, Any]] = None,
     ) -> None:
         """``DataCatalog`` stores instances of ``AbstractDataSet``
         implementations to provide ``load`` and ``save`` capabilities from
@@ -169,7 +172,8 @@
         """
         self._data_sets = dict(data_sets or {})
         self.datasets = _FrozenDatasets(self._data_sets)
-        self.layers = layers
+        self.layers = layers or {}
+        self._raw_catalog = raw_catalog or {}
 
         # import the feed dict
         if feed_dict:
@@ -181,7 +185,7 @@ def _logger(self):
 
     @classmethod
     def from_config(
-        cls: type,
+        cls,
         catalog: dict[str, dict[str, Any]] | None,
         credentials: dict[str, dict[str, Any]] = None,
         load_versions: dict[str, str] = None,
@@ -262,30 +266,143 @@ class to be loaded is specified with the key ``type`` and their
         save_version = save_version or generate_timestamp()
         load_versions = copy.deepcopy(load_versions) or {}
 
-        missing_keys = load_versions.keys() - catalog.keys()
+        layers: dict[str, set[str]] = defaultdict(set)
+
+        missing_keys = [
+            key
+            for key in load_versions.keys()
+            if not cls._match_name_against_pattern(catalog, key)  # type: ignore
+        ]
         if missing_keys:
             raise DatasetNotFoundError(
                 f"'load_versions' keys [{', '.join(sorted(missing_keys))}] "
                 f"are not found in the catalog."
             )
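+        # Illustrative note: `catalog` may mix explicit entries (e.g.
+        # "france.companies") with factory patterns (e.g.
+        # "{namespace}.companies"), so a `load_versions` key is accepted when
+        # it matches either kind; this is why the check above goes through
+        # _match_name_against_pattern rather than `catalog.keys()`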
 
-        layers: dict[str, set[str]] = defaultdict(set)
-        for ds_name, ds_config in catalog.items():
-            ds_layer = ds_config.pop("layer", None)
-            if ds_layer is not None:
-                layers[ds_layer].add(ds_name)
-
-            ds_config = _resolve_credentials(ds_config, credentials)
-            data_sets[ds_name] = AbstractDataSet.from_config(
-                ds_name, ds_config, load_versions.get(ds_name), save_version
+        # Add each load_version to the dataset's raw config, creating an
+        # entry if one doesn't exist yet, because the name might belong to a
+        # patterned dataset
+        for ds_name, version in load_versions.items():
+            catalog.setdefault(ds_name, {}).update({"load_version": version})
+        # Update the raw catalog dict with the save version and resolved
+        # credentials
+        catalog = {
+            ds_name: _resolve_credentials(
+                {**ds_config, "save_version": save_version}, credentials
             )
+            for ds_name, ds_config in catalog.items()
+        }
+        # Materialise the explicit (non-pattern) entries straight away
+        for ds_name, ds_config in catalog.items():
+            if "}" not in ds_name and cls._is_full_config(ds_config):
+                ds_layer = ds_config.pop("layer", None)
+                if ds_layer:
+                    layers[ds_layer].add(ds_name)
+                dataset = cls._create_dataset(ds_name, ds_config)
+                if dataset:
+                    data_sets[ds_name] = dataset
 
         dataset_layers = layers or None
-        return cls(data_sets=data_sets, layers=dataset_layers)
+        # Sort the raw catalog according to the parsing rules
+        sorted_raw_catalog = cls._sort_catalog(catalog)
+        return cls(
+            data_sets=data_sets, layers=dataset_layers, raw_catalog=sorted_raw_catalog
+        )
+
+    @classmethod
+    def _create_dataset(
+        cls, data_set_name: str, data_set_config: dict[str, Any]
+    ) -> AbstractDataSet:
+        """Factory method for creating an ``AbstractDataSet`` from raw config"""
+        load_version = data_set_config.pop("load_version", None)
+        save_version = data_set_config.pop("save_version", None)
+        return AbstractDataSet.from_config(
+            data_set_name,
+            data_set_config,
+            load_version=load_version,
+            save_version=save_version,
+        )
+
+    @staticmethod
+    def _is_full_config(config: dict[str, Any]) -> bool:
+        """Check whether a config carries more than just version overrides,
+        i.e. whether it is a complete dataset config rather than an entry
+        created only to hold a load version"""
+        remaining = set(config.keys()) - {"load_version", "save_version"}
+        return bool(remaining)
+
+    def __contains__(self, item):
+        """Check if an item is in the catalog, either as a materialised
+        dataset or as a match for a dataset factory pattern"""
+        if item in self._data_sets or self._match_name_against_pattern(
+            self._raw_catalog, item
+        ):
+            return True
+        return False
+
+    @classmethod
+    def _match_name_against_pattern(
+        cls, raw_catalog: dict[str, Any], data_set_name: str
+    ) -> str | None:
+        """Match a dataset name against the patterns in a catalog and return
+        the first pattern that fits, or None if nothing matches"""
+        for pattern, config in raw_catalog.items():
+            result = parse(pattern, data_set_name)
+            if result and cls._is_full_config(config):
+                return pattern
+        return None
+
+    @classmethod
+    def _sort_catalog(cls, catalog: dict[str, Any]) -> dict[str, Any]:
+        """Helper function to sort all the patterns according to the parsing
+        rules:
+        1. All the catalog entries that are not patterns
+        2. Decreasing specificity (number of characters outside the brackets)
+        3. Decreasing number of placeholders (number of curly brackets)
+        4. Alphabetical
+        """
+        sorted_keys = sorted(
+            catalog,
+            key=lambda pattern: (
+                ("{" in pattern),
+                -(cls._specificity(pattern)),
+                -pattern.count("{"),
+                pattern,
+            ),
+        )
+        return {key: catalog[key] for key in sorted_keys}
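+
+    # Illustrative example: under these rules the entries "france.companies",
+    # "{namespace}.companies" and "{namespace}.{dataset}" sort exactly in
+    # that order: the non-pattern entry first, then the two patterns by
+    # decreasing specificity (10 > 1)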
+
+    @staticmethod
+    def _specificity(pattern: str) -> int:
+        """Helper function to count the characters outside the curly
+        brackets, i.e. the part of the pattern that must match exactly.
+        Example:
+            specificity("{namespace}.companies") = 10
+            specificity("{namespace}.{dataset}") = 1
+            specificity("france.companies") = 16
+        Args:
+            pattern: The factory pattern
+        """
+        # Remove all the placeholders from the pattern and count the
+        # remaining characters
+        result = re.sub(r"\{.*?\}", "", pattern)
+        return len(result)
 
     def _get_dataset(
         self, data_set_name: str, version: Version = None, suggest: bool = True
     ) -> AbstractDataSet:
+        if data_set_name in self and data_set_name not in self._data_sets:
+            # The dataset is not materialised yet but matches a pattern:
+            # resolve the pattern's config and add the dataset to the catalog
+            data_set_pattern = self._match_name_against_pattern(
+                self._raw_catalog, data_set_name
+            )
+            data_set_config = self._resolve_config(data_set_name, data_set_pattern)
+            ds_layer = data_set_config.pop("layer", None)
+            if ds_layer:
+                # Record the dataset under its layer, mirroring from_config();
+                # self.layers is a plain dict here, so create the set if needed
+                self.layers.setdefault(ds_layer, set()).add(data_set_name)
+            data_set = self._create_dataset(data_set_name, data_set_config)
+            if data_set.exists():
+                self.add(data_set_name, data_set)
+
         if data_set_name not in self._data_sets:
             error_msg = f"Dataset '{data_set_name}' not found in the catalog"
@@ -311,6 +428,28 @@ def _get_dataset(
 
         return data_set
 
+    def _resolve_config(self, data_set_name, data_set_pattern) -> dict[str, Any]:
+        """Get a resolved config from a patterned config, with the
+        placeholders replaced by the values captured from the dataset name"""
+        result = parse(data_set_pattern, data_set_name)
+        config_copy = self._raw_catalog[data_set_pattern].copy()
+        if data_set_name in self._raw_catalog:
+            # Merge the config with the entry created to hold the load and
+            # save versions
+            config_copy.update(self._raw_catalog[data_set_name])
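+        # Illustrative example: for the pattern "{namespace}.companies" and
+        # the name "france.companies", `result.named` is
+        # {"namespace": "france"}, so a value such as
+        # "data/{namespace}/companies.csv" resolves below to
+        # "data/france/companies.csv"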
""" - return DataCatalog(data_sets=self._data_sets, layers=self.layers) + return DataCatalog( + data_sets=self._data_sets, layers=self.layers, raw_catalog=self._raw_catalog + ) def __eq__(self, other): return (self._data_sets, self.layers) == (other._data_sets, other.layers) diff --git a/kedro/runner/runner.py b/kedro/runner/runner.py index d7a1c51ad7..3668de3b6f 100644 --- a/kedro/runner/runner.py +++ b/kedro/runner/runner.py @@ -73,15 +73,15 @@ def run( hook_manager = hook_manager or _NullPluginManager() catalog = catalog.shallow_copy() - - unsatisfied = pipeline.inputs() - set(catalog.list()) + registered_ds = [ds for ds in pipeline.data_sets() if ds in catalog] + unsatisfied = pipeline.inputs() - set(registered_ds) if unsatisfied: raise ValueError( f"Pipeline input(s) {unsatisfied} not found in the DataCatalog" ) - free_outputs = pipeline.outputs() - set(catalog.list()) - unregistered_ds = pipeline.data_sets() - set(catalog.list()) + free_outputs = pipeline.outputs() - set(registered_ds) + unregistered_ds = pipeline.data_sets() - set(registered_ds) for ds_name in unregistered_ds: catalog.add(ds_name, self.create_default_data_set(ds_name))