From 58e97ac7eae7c932aa8a629f7c117ae3bca188d5 Mon Sep 17 00:00:00 2001 From: Ben Levy <79479484+BenjaminLevyQB@users.noreply.github.com> Date: Mon, 25 Oct 2021 12:43:24 -0400 Subject: [PATCH] Add an option to load SQL queries from a file for SQLQueryDataSet (#887) Signed-off-by: Laurens Vijnck --- RELEASE.md | 6 +- kedro/extras/datasets/pandas/sql_dataset.py | 71 ++++++++++++++++--- .../datasets/pandas/test_sql_dataset.py | 57 +++++++++++++-- 3 files changed, 117 insertions(+), 17 deletions(-) diff --git a/RELEASE.md b/RELEASE.md index b24bbd760f..80c1298c4b 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -44,6 +44,9 @@ * Bumped minimum required `fsspec` version to 2021.04. * Fixed the `kedro install` and `kedro build-reqs` flows when uninstalled dependencies are present in a project's `settings.py`, `context.py` or `hooks.py` ([Issue #829](https://github.com/quantumblacklabs/kedro/issues/829)). * Imports are now refactored at `kedro pipeline package` and `kedro pipeline pull` time, so that _aliasing_ a modular pipeline doesn't break it. +* Added option to `pandas.SQLQueryDataSet` to specify a `filepath` with a SQL query, in addition to the current method of supplying the query itself in the `sql` argument. + +## Minor breaking changes to the API * Pinned `dynaconf` to `<3.1.6` because the method signature for `_validate_items` changed which is used in Kedro. ## Upcoming deprecations for Kedro 0.18.0 @@ -52,7 +55,8 @@ ## Thanks for supporting contributions [Moussa Taifi](https://github.com/moutai), -[Deepyaman Datta](https://github.com/deepyaman) +[Deepyaman Datta](https://github.com/deepyaman), +[Benjamin Levy](https://github.com/BenjaminLevyQB) # Release 0.17.4 diff --git a/kedro/extras/datasets/pandas/sql_dataset.py b/kedro/extras/datasets/pandas/sql_dataset.py index 06caee3990..d4a3d134d0 100644 --- a/kedro/extras/datasets/pandas/sql_dataset.py +++ b/kedro/extras/datasets/pandas/sql_dataset.py @@ -29,13 +29,20 @@ import copy import re +from pathlib import PurePosixPath from typing import Any, Dict, Optional +import fsspec import pandas as pd from sqlalchemy import create_engine from sqlalchemy.exc import NoSuchModuleError -from kedro.io.core import AbstractDataSet, DataSetError +from kedro.io.core import ( + AbstractDataSet, + DataSetError, + get_filepath_str, + get_protocol_and_path, +) __all__ = ["SQLTableDataSet", "SQLQueryDataSet"] @@ -278,8 +285,13 @@ class SQLQueryDataSet(AbstractDataSet): """ - def __init__( - self, sql: str, credentials: Dict[str, Any], load_args: Dict[str, Any] = None + def __init__( # pylint: disable=too-many-arguments + self, + sql: str = None, + credentials: Dict[str, Any] = None, + load_args: Dict[str, Any] = None, + fs_args: Dict[str, Any] = None, + filepath: str = None, ) -> None: """Creates a new ``SQLQueryDataSet``. @@ -297,14 +309,28 @@ def __init__( https://pandas.pydata.org/pandas-docs/stable/generated/pandas.read_sql_query.html To find all supported connection string formats, see here: https://docs.sqlalchemy.org/en/13/core/engines.html#database-urls + fs_args: Extra arguments to pass into underlying filesystem class constructor + (e.g. `{"project": "my-project"}` for ``GCSFileSystem``), as well as + to pass to the filesystem's `open` method through nested keys + `open_args_load` and `open_args_save`. + Here you can find all available arguments for `open`: + https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.spec.AbstractFileSystem.open + All defaults are preserved, except `mode`, which is set to `r` when loading. + filepath: A path to a file with a sql query statement. Raises: - DataSetError: When either ``sql`` or ``con`` parameters is emtpy. + DataSetError: When either ``sql`` or ``con`` parameters is empty. """ + if sql and filepath: + raise DataSetError( + "`sql` and `filepath` arguments cannot both be provided." + "Please only provide one." + ) - if not sql: + if not (sql or filepath): raise DataSetError( - "`sql` argument cannot be empty. Please provide a sql query" + "`sql` and `filepath` arguments cannot both be empty." + "Please provide a sql query or path to a sql query file." ) if not (credentials and "con" in credentials and credentials["con"]): @@ -321,18 +347,41 @@ def __init__( else default_load_args ) - self._load_args["sql"] = sql + # load sql query from file + if sql: + self._load_args["sql"] = sql + self._filepath = None + else: + # filesystem for loading sql file + _fs_args = copy.deepcopy(fs_args) or {} + _fs_credentials = _fs_args.pop("credentials", {}) + protocol, path = get_protocol_and_path(str(filepath)) + + self._protocol = protocol + self._fs = fsspec.filesystem(self._protocol, **_fs_credentials, **_fs_args) + self._filepath = path self._load_args["con"] = credentials["con"] def _describe(self) -> Dict[str, Any]: - load_args = self._load_args.copy() - del load_args["sql"] + load_args = copy.deepcopy(self._load_args) + desc = {} + desc["sql"] = str(load_args.pop("sql", None)) + desc["filepath"] = str(self._filepath) del load_args["con"] - return dict(sql=self._load_args["sql"], load_args=load_args) + desc["load_args"] = str(load_args) + + return desc def _load(self) -> pd.DataFrame: + load_args = copy.deepcopy(self._load_args) + + if self._filepath: + load_path = get_filepath_str(PurePosixPath(self._filepath), self._protocol) + with self._fs.open(load_path, mode="r") as fs_file: + load_args["sql"] = fs_file.read() + try: - return pd.read_sql_query(**self._load_args) + return pd.read_sql_query(**load_args) except ImportError as import_error: raise _get_missing_module_error(import_error) from import_error except NoSuchModuleError as exc: diff --git a/tests/extras/datasets/pandas/test_sql_dataset.py b/tests/extras/datasets/pandas/test_sql_dataset.py index 49764fe497..10ef0d3b84 100644 --- a/tests/extras/datasets/pandas/test_sql_dataset.py +++ b/tests/extras/datasets/pandas/test_sql_dataset.py @@ -28,6 +28,7 @@ # pylint: disable=no-member +from pathlib import PosixPath from typing import Any import pandas as pd @@ -51,6 +52,13 @@ def dummy_dataframe(): return pd.DataFrame({"col1": [1, 2], "col2": [4, 5], "col3": [5, 6]}) +@pytest.fixture +def sql_file(tmp_path: PosixPath): + file = tmp_path / "test.sql" + file.write_text(SQL_QUERY) + return file.as_posix() + + @pytest.fixture(params=[{}]) def table_data_set(request): kwargs = dict(table_name=TABLE_NAME, credentials=dict(con=CONNECTION)) @@ -65,6 +73,13 @@ def query_data_set(request): return SQLQueryDataSet(**kwargs) +@pytest.fixture(params=[{}]) +def query_file_data_set(request, sql_file): + kwargs = dict(filepath=sql_file, credentials=dict(con=CONNECTION)) + kwargs.update(request.param) + return SQLQueryDataSet(**kwargs) + + class TestSQLTableDataSetLoad: @staticmethod def _assert_pd_called_once(): @@ -244,10 +259,13 @@ def _assert_pd_called_once(): _callable.assert_called_once_with(sql=SQL_QUERY, con=CONNECTION) def test_empty_query_error(self): - """Check the error when instantiating with empty query""" - pattern = r"`sql` argument cannot be empty\. Please provide a sql query" + """Check the error when instantiating with empty query or file""" + pattern = ( + r"`sql` and `filepath` arguments cannot both be empty\." + r"Please provide a sql query or path to a sql query file\." + ) with pytest.raises(DataSetError, match=pattern): - SQLQueryDataSet(sql="", credentials=dict(con=CONNECTION)) + SQLQueryDataSet(sql="", filepath="", credentials=dict(con=CONNECTION)) def test_empty_con_error(self): """Check the error when instantiating with empty connection string""" @@ -264,6 +282,12 @@ def test_load(self, mocker, query_data_set): query_data_set.load() self._assert_pd_called_once() + def test_load_query_file(self, mocker, query_file_data_set): + """Test `load` method with a query file""" + mocker.patch("pandas.read_sql_query") + query_file_data_set.load() + self._assert_pd_called_once() + def test_load_driver_missing(self, mocker, query_data_set): """Test that if an unknown module/driver is encountered by SQLAlchemy then the error should contain the original error message""" @@ -306,8 +330,31 @@ def test_save_error(self, query_data_set, dummy_dataframe): with pytest.raises(DataSetError, match=pattern): query_data_set.save(dummy_dataframe) - def test_str_representation_sql(self, query_data_set): + def test_str_representation_sql(self, query_data_set, sql_file): """Test the data set instance string representation""" str_repr = str(query_data_set) - assert f"SQLQueryDataSet(load_args={{}}, sql={SQL_QUERY})" in str_repr + assert ( + f"SQLQueryDataSet(filepath=None, load_args={{}}, sql={SQL_QUERY})" + in str_repr + ) assert CONNECTION not in str_repr + assert sql_file not in str_repr + + def test_str_representation_filepath(self, query_file_data_set, sql_file): + """Test the data set instance string representation with filepath arg.""" + str_repr = str(query_file_data_set) + assert ( + f"SQLQueryDataSet(filepath={str(sql_file)}, load_args={{}}, sql=None)" + in str_repr + ) + assert CONNECTION not in str_repr + assert SQL_QUERY not in str_repr + + def test_sql_and_filepath_args(self, sql_file): + """Test that an error is raised when both `sql` and `filepath` args are given.""" + pattern = ( + r"`sql` and `filepath` arguments cannot both be provided." + r"Please only provide one." + ) + with pytest.raises(DataSetError, match=pattern): + SQLQueryDataSet(sql=SQL_QUERY, filepath=sql_file)