[DRAFT] Dataset factories (new version) #2743

Closed · wants to merge 1 commit
171 changes: 156 additions & 15 deletions kedro/io/data_catalog.py
@@ -11,7 +11,9 @@
import logging
import re
from collections import defaultdict
from typing import Any
from typing import Any, Iterable

from parse import parse

from kedro.io.core import (
AbstractDataSet,
@@ -141,6 +143,7 @@ def __init__(
data_sets: dict[str, AbstractDataSet] = None,
feed_dict: dict[str, Any] = None,
layers: dict[str, set[str]] = None,
raw_catalog: dict[str, dict[str, Any]] = None,
) -> None:
"""``DataCatalog`` stores instances of ``AbstractDataSet``
implementations to provide ``load`` and ``save`` capabilities from
@@ -169,7 +172,8 @@
"""
self._data_sets = dict(data_sets or {})
self.datasets = _FrozenDatasets(self._data_sets)
self.layers = layers
self.layers = layers or {}
self._raw_catalog = raw_catalog or {}

# import the feed dict
if feed_dict:
@@ -181,7 +185,7 @@ def _logger(self):

@classmethod
def from_config(
cls: type,
cls,
catalog: dict[str, dict[str, Any]] | None,
credentials: dict[str, dict[str, Any]] = None,
load_versions: dict[str, str] = None,
@@ -262,30 +266,143 @@ class to be loaded is specified with the key ``type`` and their
save_version = save_version or generate_timestamp()
load_versions = copy.deepcopy(load_versions) or {}

missing_keys = load_versions.keys() - catalog.keys()
layers: dict[str, set[str]] = defaultdict(set)

missing_keys = [
key
for key in load_versions.keys()
if not cls._match_name_against_pattern(catalog, key) # type: ignore
]
if missing_keys:
raise DatasetNotFoundError(
f"'load_versions' keys [{', '.join(sorted(missing_keys))}] "
f"are not found in the catalog."
)

layers: dict[str, set[str]] = defaultdict(set)
for ds_name, ds_config in catalog.items():
ds_layer = ds_config.pop("layer", None)
if ds_layer is not None:
layers[ds_layer].add(ds_name)

ds_config = _resolve_credentials(ds_config, credentials)
data_sets[ds_name] = AbstractDataSet.from_config(
ds_name, ds_config, load_versions.get(ds_name), save_version
# Add load_version for dataset in raw catalog dict, create an entry if it
# doesn't exist because this might be a patterned dataset
for ds_name, version in load_versions.items():
catalog.setdefault(ds_name, {}).update({"load_version": version})
# Update raw catalog dict with save version and resolved credentials
catalog = {
ds_name: _resolve_credentials(
{**ds_config, "save_version": save_version}, credentials
)
for ds_name, ds_config in catalog.items()
}
# Add explicit (non-pattern) dataset entries up front
for ds_name, ds_config in catalog.items():
if "}" not in ds_name and cls._is_full_config(ds_config):
Contributor: I think we need a small helper function here, i.e. _is_pattern, even if it's a one-liner, as the semantics make it easier to read.
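
A minimal sketch of the suggested helper (hypothetical, not part of this diff), assuming any name containing a placeholder brace counts as a pattern:

@staticmethod
def _is_pattern(data_set_name: str) -> bool:
    """Check whether a dataset name is a factory pattern, e.g. '{namespace}.companies'."""
    # Placeholders are always wrapped in curly braces, so a single brace marks a pattern.
    return "{" in data_set_name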

ds_layer = ds_config.pop("layer", None)
if ds_layer:
layers[ds_layer].add(ds_name)
dataset = cls._create_dataset(ds_name, ds_config)
if dataset:
data_sets[ds_name] = dataset

dataset_layers = layers or None
return cls(data_sets=data_sets, layers=dataset_layers)
# Sort raw catalog according to parsing rules
sorted_raw_catalog = cls._sort_catalog(catalog)
return cls(
data_sets=data_sets, layers=dataset_layers, raw_catalog=sorted_raw_catalog
)

@classmethod
def _create_dataset(
cls, data_set_name: str, data_set_config: dict[str, Any]
) -> AbstractDataSet:
"""Factory method for creating AbstractDataset from raw config"""
load_version = data_set_config.pop("load_version", None)
save_version = data_set_config.pop("save_version", None)
return AbstractDataSet.from_config(
data_set_name,
data_set_config,
load_version=load_version,
save_version=save_version,
)

@staticmethod
def _is_full_config(config: dict[str, Any]) -> bool:
"""Check if the config is a full config"""
remaining = set(config.keys()) - {"load_version", "save_version"}
return bool(remaining)
Comment on lines +324 to +328
Contributor: Why does config.keys() minus load_version and save_version equal a full config?

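For instance (hypothetical values), an entry that only carries versions is not a full config, while a real dataset entry is:

# _is_full_config({"load_version": "2023-01-01T00.00.00.000Z"}) -> False
# _is_full_config({"type": "pandas.CSVDataSet", "save_version": "..."}) -> True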

def __contains__(self, item):
"""Check if an item is in the catalog as a materialised dataset or pattern"""
if item in self._data_sets or self._match_name_against_pattern(
self._raw_catalog, item
):
return True
return False
Comment on lines +330 to +336
Contributor: 👍🏼 I like this
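
A usage sketch with hypothetical entries, assuming the raw catalog holds only the pattern "{namespace}.companies":

catalog = DataCatalog.from_config(
    {
        "{namespace}.companies": {
            "type": "pandas.CSVDataSet",
            "filepath": "data/{namespace}/companies.csv",
        }
    }
)
assert "france.companies" in catalog  # matches the pattern
assert "france.shuttles" not in catalog  # no entry or pattern matches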


@classmethod
def _match_name_against_pattern(
cls, raw_catalog: dict[str, Any], data_set_name: str
Contributor: Is this the correct typing? When I read through it I have a hard time mapping all the types. Maybe worth creating a TypeAlias (https://docs.python.org/3/library/typing.html#type-aliases).
Suggested change:
cls, raw_catalog: dict[str, Any], data_set_name: str
cls, raw_catalog: dict[str, dict[str, Any]], data_set_name: str

) -> str | None:
"""Match a dataset name against patterns in a catalog"""
for pattern, config in raw_catalog.items():
result = parse(pattern, data_set_name)

if result and cls._is_full_config(config):
return pattern
return None

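For reference, parse() from the parse library matches a format-style pattern against a string and returns a Result whose .named attribute maps placeholder names to the matched values, or None when there is no match:

from parse import parse

result = parse("{namespace}.companies", "france.companies")
assert result.named == {"namespace": "france"}  # placeholders captured by name
assert parse("{namespace}.companies", "france.shuttles") is None  # literal part differs
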
@classmethod
def _sort_catalog(cls, catalog: dict[str, Any]) -> dict[str, Any]:
"""Sort the catalog according to parsing rules -
Helper function to ort all the patterns according to the parsing rules -
1. All the catalog entries that are not patterns
2. Decreasing specificity (no of characters outside the brackets)
3. Decreasing number of placeholders (no of curly brackets)
4. Alphabetical
"""

sorted_keys = sorted(
catalog,
key=lambda pattern: (
("{" in pattern),
-(cls._specificity(pattern)),
-pattern.count("{"),
pattern,
),
)
sorted_catalog = {}
for key in sorted_keys:
sorted_catalog[key] = catalog[key]
return sorted_catalog

@staticmethod
def _specificity(pattern: str) -> int:
"""Helper function to check the length of exactly matched characters not inside brackets
Example -
specificity("{namespace}.companies") = 10
specificity("{namespace}.{dataset}") = 1
specificity("france.companies") = 16
Args:
pattern: The factory pattern
"""
# Remove all the placeholders from the pattern and count the number of remaining chars
result = re.sub(r"\{.*?\}", "", pattern)
return len(result)

def _get_dataset(
self, data_set_name: str, version: Version = None, suggest: bool = True
) -> AbstractDataSet:
if data_set_name in self and data_set_name not in self._data_sets:
# If the dataset is not in self._data_sets as a materialised entry,
# add it to the catalog
data_set_pattern = self._match_name_against_pattern(
self._raw_catalog, data_set_name
)

data_set_config = self._resolve_config(data_set_name, data_set_pattern)
ds_layer = data_set_config.pop("layer", None)
if ds_layer:
self.layers.setdefault(ds_layer, set()).add(data_set_name)
data_set = self._create_dataset(data_set_name, data_set_config)
if data_set.exists():
self.add(data_set_name, data_set)

if data_set_name not in self._data_sets:
error_msg = f"Dataset '{data_set_name}' not found in the catalog"

@@ -311,6 +428,28 @@ def _get_dataset(

return data_set

def _resolve_config(self, data_set_name, data_set_pattern) -> dict[str, Any]:
Contributor suggested change:
def _resolve_config(self, data_set_name, data_set_pattern) -> dict[str, Any]:
def _resolve_config(self, data_set_name: str, data_set_pattern: str) -> dict[str, Any]:
"""Get resolved config from a patterned config with placeholders
replaced with actual values"""
result = parse(data_set_pattern, data_set_name)
config_copy = self._raw_catalog[data_set_pattern].copy()
if data_set_name in self._raw_catalog:
# Merge config with entry created containing load and save versions
config_copy.update(self._raw_catalog[data_set_name])
for key, value in config_copy.items():
if isinstance(value, Iterable) and "}" in value:
Contributor: Why do we check for "{" sometimes but "}"? Again, can we just use one function to check if something is a pattern?
Contributor: I am not sure why we are checking for Iterable here.

string_value = str(value)
# result.named: gives access to all dict items in the match result.
# format_map fills in dict values into a string with {...} placeholders
# of the same key name.
try:
config_copy[key] = string_value.format_map(result.named)
Comment on lines +441 to +446
Contributor suggested change (inline the temporary string_value):
- string_value = str(value)
- config_copy[key] = string_value.format_map(result.named)
+ config_copy[key] = str(value).format_map(result.named)

except KeyError as exc:
raise DatasetError(
f"Unable to resolve '{key}' for the pattern '{data_set_pattern}'"
) from exc
return config_copy

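A worked trace with a hypothetical pattern entry, showing what _resolve_config produces:

# raw catalog entry:
#   "{namespace}.companies": {"type": "pandas.CSVDataSet",
#                             "filepath": "data/{namespace}/companies.csv"}
result = parse("{namespace}.companies", "france.companies")
# result.named == {"namespace": "france"}, so format_map rewrites the filepath:
#   "data/{namespace}/companies.csv".format_map(result.named)
#       == "data/france/companies.csv"
# resolved config: {"type": "pandas.CSVDataSet",
#                   "filepath": "data/france/companies.csv"}
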
def load(self, name: str, version: str = None) -> Any:
"""Loads a registered data set.

@@ -573,7 +712,9 @@ def shallow_copy(self) -> DataCatalog:
Returns:
Copy of the current object.
"""
return DataCatalog(data_sets=self._data_sets, layers=self.layers)
return DataCatalog(
data_sets=self._data_sets, layers=self.layers, raw_catalog=self._raw_catalog
)

def __eq__(self, other):
return (self._data_sets, self.layers) == (other._data_sets, other.layers)
8 changes: 4 additions & 4 deletions kedro/runner/runner.py
@@ -73,15 +73,15 @@ def run(

hook_manager = hook_manager or _NullPluginManager()
catalog = catalog.shallow_copy()

unsatisfied = pipeline.inputs() - set(catalog.list())
registered_ds = [ds for ds in pipeline.data_sets() if ds in catalog]
unsatisfied = pipeline.inputs() - set(registered_ds)
if unsatisfied:
raise ValueError(
f"Pipeline input(s) {unsatisfied} not found in the DataCatalog"
)

free_outputs = pipeline.outputs() - set(catalog.list())
unregistered_ds = pipeline.data_sets() - set(catalog.list())
free_outputs = pipeline.outputs() - set(registered_ds)
Contributor: At this point, are all the pattern datasets materialised in the catalog already?

unregistered_ds = pipeline.data_sets() - set(registered_ds)
for ds_name in unregistered_ds:
catalog.add(ds_name, self.create_default_data_set(ds_name))
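
A sketch of the effect (hypothetical names): a dataset that exists only as a pattern now counts as registered, because "ds in catalog" consults the patterns via __contains__, so it is not replaced by a default dataset here:

# pattern "{namespace}.companies" in catalog, no explicit "france.companies" entry
# "france.companies" in catalog -> True, so it stays out of unregistered_ds
# "tmp_output" matches nothing  -> False, so it gets create_default_data_set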
