perf(datasets): don't create connection until needed #281

Merged · 17 commits · Oct 19, 2023
1 change: 1 addition & 0 deletions kedro-datasets/RELEASE.md
@@ -1,6 +1,7 @@
 # Upcoming Release
 ## Major features and improvements
 * Moved `PartitionedDataSet` and `IncrementalDataSet` from the core Kedro repo to `kedro-datasets` and renamed to `PartitionedDataset` and `IncrementalDataset`.
+* Delayed backend connection for `pandas.SQLTableDataset`, `pandas.SQLQueryDataset`, and `snowflake.SnowparkTableDataset`. In practice, this means that a dataset's connection details aren't used (or validated) until the dataset is accessed, so the cost of establishing a connection is never incurred for a dataset that is never used.
 
 ## Bug fixes and other changes
 * Fix erroneous warning when using a cloud protocol file path with SparkDataSet on Databricks.
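In user terms, the delayed-connection entry above means that constructing a dataset no longer performs any I/O or driver validation. A minimal sketch of the new behaviour, with a hypothetical table name and connection string:

```python
from kedro_datasets.pandas import SQLTableDataset

# Constructing the dataset stores the connection string but does not
# create a SQLAlchemy engine, so no driver import or URL parsing happens yet.
dataset = SQLTableDataset(
    table_name="shuttles",  # hypothetical table name
    credentials={"con": "postgresql://user:pass@host/db"},  # hypothetical URL
)

# The engine is created on first access; a missing driver or malformed
# connection string now surfaces here instead of at catalog-construction time.
df = dataset.load()
```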
2 changes: 1 addition & 1 deletion kedro-datasets/kedro_datasets/__init__.py
@@ -9,7 +9,7 @@
 try:
     # Custom `KedroDeprecationWarning` class was added in Kedro 0.18.14.
     from kedro import KedroDeprecationWarning
-except ImportError:
+except ImportError:  # pragma: no cover
 
     class KedroDeprecationWarning(DeprecationWarning):
         """Custom class for warnings about deprecated Kedro features."""
83 changes: 46 additions & 37 deletions kedro-datasets/kedro_datasets/pandas/sql_dataset.py
@@ -1,10 +1,12 @@
 """``SQLDataset`` to load and save data to a SQL backend."""
+from __future__ import annotations
+
 import copy
 import datetime as dt
 import re
 import warnings
 from pathlib import PurePosixPath
-from typing import Any, Dict, NoReturn, Optional
+from typing import Any, NoReturn
 
 import fsspec
 import pandas as pd
@@ -33,7 +35,7 @@
 """
 
 
-def _find_known_drivers(module_import_error: ImportError) -> Optional[str]:
+def _find_known_drivers(module_import_error: ImportError) -> str | None:
     """Looks up known keywords in a ``ModuleNotFoundError`` so that it can
     provide better guideline for the user.
 
@@ -145,19 +147,19 @@ class SQLTableDataset(AbstractDataset[pd.DataFrame, pd.DataFrame]):
 
     """
 
-    DEFAULT_LOAD_ARGS: Dict[str, Any] = {}
-    DEFAULT_SAVE_ARGS: Dict[str, Any] = {"index": False}
+    DEFAULT_LOAD_ARGS: dict[str, Any] = {}
+    DEFAULT_SAVE_ARGS: dict[str, Any] = {"index": False}
     # using Any because of Sphinx but it should be
     # sqlalchemy.engine.Engine or sqlalchemy.engine.base.Engine
-    engines: Dict[str, Any] = {}
+    engines: dict[str, Any] = {}
 
     def __init__(  # noqa: PLR0913
         self,
         table_name: str,
-        credentials: Dict[str, Any],
-        load_args: Dict[str, Any] = None,
-        save_args: Dict[str, Any] = None,
-        metadata: Dict[str, Any] = None,
+        credentials: dict[str, Any],
+        load_args: dict[str, Any] = None,
+        save_args: dict[str, Any] = None,
+        metadata: dict[str, Any] = None,
     ) -> None:
         """Creates a new ``SQLTableDataset``.
 
@@ -212,7 +214,6 @@ def __init__( # noqa: PLR0913
         self._save_args["name"] = table_name
 
         self._connection_str = credentials["con"]
-        self.create_connection(self._connection_str)
 
         self.metadata = metadata
 
@@ -222,9 +223,6 @@ def create_connection(cls, connection_str: str) -> None:
         to be used across all instances of ``SQLTableDataset`` that
         need to connect to the same source.
         """
-        if connection_str in cls.engines:
-            return
-
         try:
             engine = create_engine(connection_str)
         except ImportError as import_error:
@@ -234,7 +232,17 @@ def create_connection(cls, connection_str: str) -> None:
 
         cls.engines[connection_str] = engine
 
-    def _describe(self) -> Dict[str, Any]:
+    @property
+    def engine(self):
+        """The ``Engine`` object for the dataset's connection string."""
+        cls = type(self)
+
+        if self._connection_str not in cls.engines:
+            self.create_connection(self._connection_str)
+
+        return cls.engines[self._connection_str]
+
+    def _describe(self) -> dict[str, Any]:
         load_args = copy.deepcopy(self._load_args)
         save_args = copy.deepcopy(self._save_args)
         del load_args["table_name"]
@@ -246,16 +254,13 @@ def _describe(self) -> Dict[str, Any]:
         }
 
     def _load(self) -> pd.DataFrame:
-        engine = self.engines[self._connection_str]  # type: ignore
-        return pd.read_sql_table(con=engine, **self._load_args)
+        return pd.read_sql_table(con=self.engine, **self._load_args)
 
     def _save(self, data: pd.DataFrame) -> None:
-        engine = self.engines[self._connection_str]  # type: ignore
-        data.to_sql(con=engine, **self._save_args)
+        data.to_sql(con=self.engine, **self._save_args)
 
     def _exists(self) -> bool:
-        engine = self.engines[self._connection_str]  # type: ignore
-        insp = inspect(engine)
+        insp = inspect(self.engine)
         schema = self._load_args.get("schema", None)
         return insp.has_table(self._load_args["table_name"], schema)

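Taken together, the hunks above replace an eager `create_connection()` call in `__init__` with a cached, on-demand `engine` property. The pattern in isolation is roughly this sketch (simplified from the diff; everything except SQLAlchemy's `create_engine` is made up for illustration):

```python
from typing import Any

from sqlalchemy import create_engine


class LazyEngineMixin:
    """Share one SQLAlchemy engine per connection string, created on demand."""

    # Class-level cache: all instances with the same connection string
    # reuse a single engine (and hence a single connection pool).
    engines: dict[str, Any] = {}

    def __init__(self, connection_str: str) -> None:
        self._connection_str = connection_str  # stored, no engine created yet

    @property
    def engine(self):
        cls = type(self)
        if self._connection_str not in cls.engines:
            # First access pays the creation cost; later accesses are lookups.
            cls.engines[self._connection_str] = create_engine(self._connection_str)
        return cls.engines[self._connection_str]
```

Because the cache lives on the class, instantiating many datasets against the same database stays cheap, and the membership check inside the property is what makes the old early-return guard in `create_connection` redundant.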
@@ -273,7 +278,6 @@ class SQLQueryDataset(AbstractDataset[None, pd.DataFrame]):
     It does not support save method so it is a read only data set.
     To save data to a SQL server use ``SQLTableDataset``.
 
-
     Example usage for the
     `YAML API <https://kedro.readthedocs.io/en/stable/data/\
     data_catalog_yaml_examples.html>`_:
@@ -375,17 +379,17 @@ class SQLQueryDataset(AbstractDataset[None, pd.DataFrame]):
 
     # using Any because of Sphinx but it should be
     # sqlalchemy.engine.Engine or sqlalchemy.engine.base.Engine
-    engines: Dict[str, Any] = {}
+    engines: dict[str, Any] = {}
 
     def __init__(  # noqa: PLR0913
         self,
         sql: str = None,
-        credentials: Dict[str, Any] = None,
-        load_args: Dict[str, Any] = None,
-        fs_args: Dict[str, Any] = None,
+        credentials: dict[str, Any] = None,
+        load_args: dict[str, Any] = None,
+        fs_args: dict[str, Any] = None,
         filepath: str = None,
-        execution_options: Optional[Dict[str, Any]] = None,
-        metadata: Dict[str, Any] = None,
+        execution_options: dict[str, Any] | None = None,
+        metadata: dict[str, Any] = None,
     ) -> None:
         """Creates a new ``SQLQueryDataset``.
 
@@ -441,7 +445,7 @@ def __init__( # noqa: PLR0913
                 "provide a SQLAlchemy connection string."
             )
 
-        default_load_args: Dict[str, Any] = {}
+        default_load_args: dict[str, Any] = {}
 
         self._load_args = (
             {**default_load_args, **load_args}
@@ -466,7 +470,6 @@ def __init__( # noqa: PLR0913
         self._filepath = path
         self._connection_str = credentials["con"]
         self._execution_options = execution_options or {}
-        self.create_connection(self._connection_str)
         if "mssql" in self._connection_str:
             self.adapt_mssql_date_params()
 
@@ -476,9 +479,6 @@ def create_connection(cls, connection_str: str) -> None:
         to be used across all instances of `SQLQueryDataset` that
         need to connect to the same source.
         """
-        if connection_str in cls.engines:
-            return
-
         try:
             engine = create_engine(connection_str)
         except ImportError as import_error:
@@ -488,7 +488,17 @@
 
         cls.engines[connection_str] = engine
 
-    def _describe(self) -> Dict[str, Any]:
+    @property
+    def engine(self):
+        """The ``Engine`` object for the dataset's connection string."""
+        cls = type(self)
+
+        if self._connection_str not in cls.engines:
+            self.create_connection(self._connection_str)
+
+        return cls.engines[self._connection_str]
+
+    def _describe(self) -> dict[str, Any]:
         load_args = copy.deepcopy(self._load_args)
         return {
             "sql": str(load_args.pop("sql", None)),
@@ -499,16 +509,15 @@ def _describe(self) -> dict[str, Any]:
 
     def _load(self) -> pd.DataFrame:
         load_args = copy.deepcopy(self._load_args)
-        engine = self.engines[self._connection_str].execution_options(
-            **self._execution_options
-        )  # type: ignore
 
         if self._filepath:
            load_path = get_filepath_str(PurePosixPath(self._filepath), self._protocol)
            with self._fs.open(load_path, mode="r") as fs_file:
                load_args["sql"] = fs_file.read()
 
-        return pd.read_sql_query(con=engine, **load_args)
+        return pd.read_sql_query(
+            con=self.engine.execution_options(**self._execution_options), **load_args
+        )
 
     def _save(self, data: None) -> NoReturn:
         raise DatasetError("'save' is not supported on SQLQueryDataset")
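One subtlety in the new `SQLQueryDataset._load`: `Engine.execution_options()` does not mutate the shared, cached engine. It returns a copy that carries the per-dataset options while reusing the same underlying connection pool, so the lazily created engine can still be shared safely across datasets. A small illustration (the database URL and option value are arbitrary):

```python
from sqlalchemy import create_engine

engine = create_engine("sqlite:///example.db")  # hypothetical database

# A copy with options applied; the original engine is untouched, and both
# objects share one connection pool.
streaming_engine = engine.execution_options(stream_results=True)

assert streaming_engine is not engine
```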
7 changes: 5 additions & 2 deletions kedro-datasets/kedro_datasets/snowflake/snowpark_dataset.py
@@ -172,7 +172,6 @@ def __init__( # noqa: PLR0913
             {"database": self._database, "schema": self._schema}
         )
         self._connection_parameters = connection_parameters
-        self._session = self._get_session(self._connection_parameters)
 
         self.metadata = metadata
 
@@ -207,10 +206,14 @@ def _get_session(connection_parameters) -> sp.Session:
             logger.debug("Trying to reuse active snowpark session...")
             session = sp.context.get_active_session()
         except sp.exceptions.SnowparkSessionException:
-            logger.debug("No active snowpark session found. Creating")
+            logger.debug("No active snowpark session found. Creating...")
             session = sp.Session.builder.configs(connection_parameters).create()
         return session
 
+    @property
+    def _session(self) -> sp.Session:
+        return self._get_session(self._connection_parameters)
+
     def _load(self) -> sp.DataFrame:
         table_name = [
             self._database,
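The Snowpark change applies the same idea at the session level: `_session` becomes a property computed on access, and `_get_session` first tries to reuse an already-active session. The fallback logic, restated as a self-contained sketch (assuming `snowflake-snowpark-python` is installed and `params` holds valid connection parameters):

```python
import snowflake.snowpark as sp


def get_or_create_session(params: dict) -> sp.Session:
    """Reuse the active Snowpark session if one exists, else build a new one."""
    try:
        # Raises SnowparkSessionException when no session is active.
        return sp.context.get_active_session()
    except sp.exceptions.SnowparkSessionException:
        return sp.Session.builder.configs(params).create()
```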
2 changes: 1 addition & 1 deletion kedro-datasets/pyproject.toml
@@ -32,7 +32,7 @@ version = {attr = "kedro_datasets.__version__"}
 fail_under = 100
 show_missing = true
 omit = ["tests/*", "kedro_datasets/holoviews/*", "kedro_datasets/snowflake/*", "kedro_datasets/tensorflow/*", "kedro_datasets/__init__.py", "kedro_datasets/databricks/*"]
-exclude_lines = ["pragma: no cover", "raise NotImplementedError"]
+exclude_lines = ["pragma: no cover", "raise NotImplementedError", "if TYPE_CHECKING:"]
 
 [tool.pytest.ini_options]
 addopts = """
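For context on the new `exclude_lines` entry: code under `if TYPE_CHECKING:` runs only inside static type checkers, never at test time, so counting it would make the `fail_under = 100` gate unreachable. A typical block that the exclusion now ignores (illustrative function, not taken from this diff):

```python
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:  # never true at runtime, now excluded from coverage
    from sqlalchemy.engine import Engine  # import needed only for annotations


def describe_engine(engine: Engine) -> str:
    return str(engine.url)
```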