Skip to content

Commit

Permalink
Add an option to load SQL queries from a file for SQLQueryDataSet (ke…
Browse files Browse the repository at this point in the history
…dro-org#887)

Signed-off-by: Laurens Vijnck <[email protected]>
  • Loading branch information
BenjaminLevyQB authored and lvijnck committed Apr 7, 2022
1 parent 046af3f commit 58e97ac
Show file tree
Hide file tree
Showing 3 changed files with 117 additions and 17 deletions.
6 changes: 5 additions & 1 deletion RELEASE.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@
* Bumped minimum required `fsspec` version to 2021.04.
* Fixed the `kedro install` and `kedro build-reqs` flows when uninstalled dependencies are present in a project's `settings.py`, `context.py` or `hooks.py` ([Issue #829](https://github.com/quantumblacklabs/kedro/issues/829)).
* Imports are now refactored at `kedro pipeline package` and `kedro pipeline pull` time, so that _aliasing_ a modular pipeline doesn't break it.
* Added option to `pandas.SQLQueryDataSet` to specify a `filepath` with a SQL query, in addition to the current method of supplying the query itself in the `sql` argument.

## Minor breaking changes to the API
* Pinned `dynaconf` to `<3.1.6` because the method signature for `_validate_items` changed which is used in Kedro.

## Upcoming deprecations for Kedro 0.18.0
Expand All @@ -52,7 +55,8 @@

## Thanks for supporting contributions
[Moussa Taifi](https://github.com/moutai),
[Deepyaman Datta](https://github.com/deepyaman)
[Deepyaman Datta](https://github.com/deepyaman),
[Benjamin Levy](https://github.com/BenjaminLevyQB)

# Release 0.17.4

Expand Down
71 changes: 60 additions & 11 deletions kedro/extras/datasets/pandas/sql_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,20 @@

import copy
import re
from pathlib import PurePosixPath
from typing import Any, Dict, Optional

import fsspec
import pandas as pd
from sqlalchemy import create_engine
from sqlalchemy.exc import NoSuchModuleError

from kedro.io.core import AbstractDataSet, DataSetError
from kedro.io.core import (
AbstractDataSet,
DataSetError,
get_filepath_str,
get_protocol_and_path,
)

__all__ = ["SQLTableDataSet", "SQLQueryDataSet"]

Expand Down Expand Up @@ -278,8 +285,13 @@ class SQLQueryDataSet(AbstractDataSet):
"""

def __init__(
self, sql: str, credentials: Dict[str, Any], load_args: Dict[str, Any] = None
def __init__( # pylint: disable=too-many-arguments
self,
sql: str = None,
credentials: Dict[str, Any] = None,
load_args: Dict[str, Any] = None,
fs_args: Dict[str, Any] = None,
filepath: str = None,
) -> None:
"""Creates a new ``SQLQueryDataSet``.
Expand All @@ -297,14 +309,28 @@ def __init__(
https://pandas.pydata.org/pandas-docs/stable/generated/pandas.read_sql_query.html
To find all supported connection string formats, see here:
https://docs.sqlalchemy.org/en/13/core/engines.html#database-urls
fs_args: Extra arguments to pass into underlying filesystem class constructor
(e.g. `{"project": "my-project"}` for ``GCSFileSystem``), as well as
to pass to the filesystem's `open` method through nested keys
`open_args_load` and `open_args_save`.
Here you can find all available arguments for `open`:
https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.spec.AbstractFileSystem.open
All defaults are preserved, except `mode`, which is set to `r` when loading.
filepath: A path to a file with a sql query statement.
Raises:
DataSetError: When either ``sql`` or ``con`` parameters is emtpy.
DataSetError: When either ``sql`` or ``con`` parameters is empty.
"""
if sql and filepath:
raise DataSetError(
"`sql` and `filepath` arguments cannot both be provided."
"Please only provide one."
)

if not sql:
if not (sql or filepath):
raise DataSetError(
"`sql` argument cannot be empty. Please provide a sql query"
"`sql` and `filepath` arguments cannot both be empty."
"Please provide a sql query or path to a sql query file."
)

if not (credentials and "con" in credentials and credentials["con"]):
Expand All @@ -321,18 +347,41 @@ def __init__(
else default_load_args
)

self._load_args["sql"] = sql
# load sql query from file
if sql:
self._load_args["sql"] = sql
self._filepath = None
else:
# filesystem for loading sql file
_fs_args = copy.deepcopy(fs_args) or {}
_fs_credentials = _fs_args.pop("credentials", {})
protocol, path = get_protocol_and_path(str(filepath))

self._protocol = protocol
self._fs = fsspec.filesystem(self._protocol, **_fs_credentials, **_fs_args)
self._filepath = path
self._load_args["con"] = credentials["con"]

def _describe(self) -> Dict[str, Any]:
load_args = self._load_args.copy()
del load_args["sql"]
load_args = copy.deepcopy(self._load_args)
desc = {}
desc["sql"] = str(load_args.pop("sql", None))
desc["filepath"] = str(self._filepath)
del load_args["con"]
return dict(sql=self._load_args["sql"], load_args=load_args)
desc["load_args"] = str(load_args)

return desc

def _load(self) -> pd.DataFrame:
load_args = copy.deepcopy(self._load_args)

if self._filepath:
load_path = get_filepath_str(PurePosixPath(self._filepath), self._protocol)
with self._fs.open(load_path, mode="r") as fs_file:
load_args["sql"] = fs_file.read()

try:
return pd.read_sql_query(**self._load_args)
return pd.read_sql_query(**load_args)
except ImportError as import_error:
raise _get_missing_module_error(import_error) from import_error
except NoSuchModuleError as exc:
Expand Down
57 changes: 52 additions & 5 deletions tests/extras/datasets/pandas/test_sql_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

# pylint: disable=no-member

from pathlib import PosixPath
from typing import Any

import pandas as pd
Expand All @@ -51,6 +52,13 @@ def dummy_dataframe():
return pd.DataFrame({"col1": [1, 2], "col2": [4, 5], "col3": [5, 6]})


@pytest.fixture
def sql_file(tmp_path: PosixPath):
file = tmp_path / "test.sql"
file.write_text(SQL_QUERY)
return file.as_posix()


@pytest.fixture(params=[{}])
def table_data_set(request):
kwargs = dict(table_name=TABLE_NAME, credentials=dict(con=CONNECTION))
Expand All @@ -65,6 +73,13 @@ def query_data_set(request):
return SQLQueryDataSet(**kwargs)


@pytest.fixture(params=[{}])
def query_file_data_set(request, sql_file):
kwargs = dict(filepath=sql_file, credentials=dict(con=CONNECTION))
kwargs.update(request.param)
return SQLQueryDataSet(**kwargs)


class TestSQLTableDataSetLoad:
@staticmethod
def _assert_pd_called_once():
Expand Down Expand Up @@ -244,10 +259,13 @@ def _assert_pd_called_once():
_callable.assert_called_once_with(sql=SQL_QUERY, con=CONNECTION)

def test_empty_query_error(self):
"""Check the error when instantiating with empty query"""
pattern = r"`sql` argument cannot be empty\. Please provide a sql query"
"""Check the error when instantiating with empty query or file"""
pattern = (
r"`sql` and `filepath` arguments cannot both be empty\."
r"Please provide a sql query or path to a sql query file\."
)
with pytest.raises(DataSetError, match=pattern):
SQLQueryDataSet(sql="", credentials=dict(con=CONNECTION))
SQLQueryDataSet(sql="", filepath="", credentials=dict(con=CONNECTION))

def test_empty_con_error(self):
"""Check the error when instantiating with empty connection string"""
Expand All @@ -264,6 +282,12 @@ def test_load(self, mocker, query_data_set):
query_data_set.load()
self._assert_pd_called_once()

def test_load_query_file(self, mocker, query_file_data_set):
"""Test `load` method with a query file"""
mocker.patch("pandas.read_sql_query")
query_file_data_set.load()
self._assert_pd_called_once()

def test_load_driver_missing(self, mocker, query_data_set):
"""Test that if an unknown module/driver is encountered by SQLAlchemy
then the error should contain the original error message"""
Expand Down Expand Up @@ -306,8 +330,31 @@ def test_save_error(self, query_data_set, dummy_dataframe):
with pytest.raises(DataSetError, match=pattern):
query_data_set.save(dummy_dataframe)

def test_str_representation_sql(self, query_data_set):
def test_str_representation_sql(self, query_data_set, sql_file):
"""Test the data set instance string representation"""
str_repr = str(query_data_set)
assert f"SQLQueryDataSet(load_args={{}}, sql={SQL_QUERY})" in str_repr
assert (
f"SQLQueryDataSet(filepath=None, load_args={{}}, sql={SQL_QUERY})"
in str_repr
)
assert CONNECTION not in str_repr
assert sql_file not in str_repr

def test_str_representation_filepath(self, query_file_data_set, sql_file):
"""Test the data set instance string representation with filepath arg."""
str_repr = str(query_file_data_set)
assert (
f"SQLQueryDataSet(filepath={str(sql_file)}, load_args={{}}, sql=None)"
in str_repr
)
assert CONNECTION not in str_repr
assert SQL_QUERY not in str_repr

def test_sql_and_filepath_args(self, sql_file):
"""Test that an error is raised when both `sql` and `filepath` args are given."""
pattern = (
r"`sql` and `filepath` arguments cannot both be provided."
r"Please only provide one."
)
with pytest.raises(DataSetError, match=pattern):
SQLQueryDataSet(sql=SQL_QUERY, filepath=sql_file)

0 comments on commit 58e97ac

Please sign in to comment.