diff --git a/Cargo.lock b/Cargo.lock
index b751ed7db4..673ef4ee87 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1556,9 +1556,9 @@ checksum = "fea41bba32d969b513997752735605054bc0dfa92b4c56bf1189f2e174be7a10"
 
 [[package]]
 name = "dyn-clone"
-version = "1.0.16"
+version = "1.0.17"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "545b22097d44f8a9581187cdf93de7a71e4722bf51200cfaba810865b49a495d"
+checksum = "0d6ef0072f8a535281e4876be788938b528e9a1d43900b82c2569af7da799125"
 
 [[package]]
 name = "either"
@@ -1588,7 +1588,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "ac3e13f66a2f95e32a39eaa81f6b95d42878ca0e1db0c7543723dfe12557e860"
 dependencies = [
  "libc",
- "windows-sys",
+ "windows-sys 0.48.0",
 ]
 
 [[package]]
@@ -2024,7 +2024,7 @@ version = "0.5.5"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "5444c27eef6923071f7ebcc33e3444508466a76f7a2b93da00ed6e19f30c1ddb"
 dependencies = [
- "windows-sys",
+ "windows-sys 0.48.0",
 ]
 
 [[package]]
@@ -2567,13 +2567,13 @@ dependencies = [
 
 [[package]]
 name = "mio"
-version = "0.8.8"
+version = "0.8.11"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "927a765cd3fc26206e66b296465fa9d3e5ab003e651c1b3c060e7956d96b19d2"
+checksum = "a4a650543ca06a924e8b371db273b2756685faae30f8487da1b56505a8f78b0c"
 dependencies = [
  "libc",
  "wasi 0.11.0+wasi-snapshot-preview1",
- "windows-sys",
+ "windows-sys 0.48.0",
 ]
 
 [[package]]
@@ -3332,7 +3332,7 @@ dependencies = [
  "libc",
  "spin 0.9.8",
  "untrusted 0.9.0",
- "windows-sys",
+ "windows-sys 0.48.0",
 ]
 
 [[package]]
@@ -3395,7 +3395,7 @@ dependencies = [
  "errno",
  "libc",
  "linux-raw-sys",
- "windows-sys",
+ "windows-sys 0.48.0",
 ]
 
 [[package]]
@@ -3416,7 +3416,7 @@ version = "0.1.22"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "0c3733bf4cf7ea0880754e19cb5a462007c4a8c1914bff372ccc95b464f1df88"
 dependencies = [
- "windows-sys",
+ "windows-sys 0.48.0",
 ]
 
 [[package]]
@@ -3654,12 +3654,12 @@ dependencies = [
 
 [[package]]
 name = "socket2"
-version = "0.5.4"
+version = "0.5.6"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "4031e820eb552adee9295814c0ced9e5cf38ddf1e8b7d566d6de8e2538ea989e"
+checksum = "05ffd9c0a93b7543e062e759284fcf5f5e3b098501104bfbdde4d404db792871"
 dependencies = [
  "libc",
- "windows-sys",
+ "windows-sys 0.52.0",
 ]
 
 [[package]]
@@ -3816,7 +3816,7 @@ dependencies = [
  "fastrand 2.0.1",
  "redox_syscall 0.4.1",
  "rustix",
- "windows-sys",
+ "windows-sys 0.48.0",
 ]
 
 [[package]]
@@ -3916,9 +3916,9 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20"
 
 [[package]]
 name = "tokio"
-version = "1.33.0"
+version = "1.37.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "4f38200e3ef7995e5ef13baec2f432a6da0aa9ac495b2c0e8f3b7eec2c92d653"
+checksum = "1adbebffeca75fcfd058afa480fb6c0b81e165a0323f9c9d39c9697e37c46787"
 dependencies = [
  "backtrace",
  "bytes",
@@ -3928,16 +3928,16 @@ dependencies = [
  "parking_lot",
  "pin-project-lite",
  "signal-hook-registry",
- "socket2 0.5.4",
+ "socket2 0.5.6",
  "tokio-macros",
- "windows-sys",
+ "windows-sys 0.48.0",
 ]
 
 [[package]]
 name = "tokio-macros"
-version = "2.1.0"
+version = "2.2.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "630bdcf245f78637c13ec01ffae6187cca34625e8c63150d424b59e55af2675e"
+checksum = "5b8a1e28f2deaa14e508979454cb3a223b10b938b45af148bc0986de36f1923b"
 dependencies = [
  "proc-macro2",
  "quote",
@@ -4339,6 +4339,15 @@ dependencies = [
"windows-targets 0.48.5", ] +[[package]] +name = "windows-sys" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" +dependencies = [ + "windows-targets 0.52.4", +] + [[package]] name = "windows-targets" version = "0.48.5" @@ -4460,7 +4469,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "524e57b2c537c0f9b1e69f1965311ec12182b4122e45035b1508cd24d2adadb1" dependencies = [ "cfg-if", - "windows-sys", + "windows-sys 0.48.0", ] [[package]] diff --git a/daft/daft.pyi b/daft/daft.pyi index 073ab99087..4ff5dc008d 100644 --- a/daft/daft.pyi +++ b/daft/daft.pyi @@ -611,7 +611,7 @@ class ScanTask: ... @staticmethod def sql_scan_task( - path: str, + url: str, file_format: FileFormatConfig, schema: PySchema, num_rows: int | None, diff --git a/daft/sql/sql_connection.py b/daft/sql/sql_connection.py index 32d2af5a2f..dd49ce0eeb 100644 --- a/daft/sql/sql_connection.py +++ b/daft/sql/sql_connection.py @@ -2,22 +2,21 @@ import logging import warnings -from typing import TYPE_CHECKING, Callable +from typing import Callable from urllib.parse import urlparse import pyarrow as pa - -if TYPE_CHECKING: - from sqlalchemy.engine import Connection +from sqlalchemy.engine import Connection logger = logging.getLogger(__name__) class SQLConnection: - def __init__(self, conn: str | Callable[[], Connection], driver: str, dialect: str) -> None: + def __init__(self, conn: str | Callable[[], Connection], driver: str, dialect: str, url: str) -> None: self.conn = conn self.dialect = dialect self.driver = driver + self.url = url def __repr__(self) -> str: return f"SQLConnection(conn={self.conn})" @@ -29,17 +28,20 @@ def from_url(cls, url: str) -> SQLConnection: dialect, driver = scheme.split("+") else: dialect, driver = scheme, "" - return SQLConnection(url, driver, dialect) + return cls(url, driver, dialect, url) @classmethod def from_connection_factory(cls, conn_factory: Callable[[], Connection]) -> SQLConnection: try: with conn_factory() as connection: - if not hasattr(connection, "engine"): - raise ValueError("The connection factory must return a SQLAlchemy connection object.") + if not isinstance(connection, Connection): + raise ValueError( + f"Connection factory must return a SQLAlchemy connection object, got: {type(connection)}" + ) dialect = connection.engine.dialect.name driver = connection.engine.driver - return SQLConnection(conn_factory, driver, dialect) + url = connection.engine.url.render_as_string() + return cls(conn_factory, driver, dialect, url) except Exception as e: raise ValueError(f"Unexpected error while calling the connection factory: {e}") from e diff --git a/daft/sql/sql_scan.py b/daft/sql/sql_scan.py index fc433b527e..5812e3225d 100644 --- a/daft/sql/sql_scan.py +++ b/daft/sql/sql_scan.py @@ -86,7 +86,7 @@ def to_scan_tasks(self, pushdowns: Pushdowns) -> Iterator[ScanTask]: scan_tasks.append( ScanTask.sql_scan_task( - path=repr(self.conn), + url=self.conn.url, file_format=file_format_config, schema=self._schema._schema, num_rows=None, @@ -207,7 +207,7 @@ def _single_scan_task(self, pushdowns: Pushdowns, total_rows: int | None, total_ return iter( [ ScanTask.sql_scan_task( - path=repr(self.conn), + url=self.conn.url, file_format=file_format_config, schema=self._schema._schema, num_rows=total_rows, diff --git a/src/daft-scan/src/python.rs b/src/daft-scan/src/python.rs index 86c4c50f5f..74865ad8c1 100644 --- a/src/daft-scan/src/python.rs +++ 
@@ -322,7 +322,7 @@ pub mod pylib {
         #[allow(clippy::too_many_arguments)]
         #[staticmethod]
         pub fn sql_scan_task(
-            path: String,
+            url: String,
             file_format: PyFileFormatConfig,
             schema: PySchema,
             storage_config: PyStorageConfig,
@@ -335,7 +335,7 @@ pub mod pylib {
                 .map(|s| TableStatistics::from_stats_table(&s.table))
                 .transpose()?;
             let data_source = DataFileSource::DatabaseDataSource {
-                path,
+                path: url,
                 chunk_spec: None,
                 size_bytes,
                 metadata: num_rows.map(|n| TableMetadata { length: n as usize }),
diff --git a/tests/integration/sql/test_sql.py b/tests/integration/sql/test_sql.py
index 23422b5ddb..b5f91d78ef 100644
--- a/tests/integration/sql/test_sql.py
+++ b/tests/integration/sql/test_sql.py
@@ -13,18 +13,6 @@
 from tests.integration.sql.conftest import TEST_TABLE_NAME
 
 
-@pytest.fixture(scope="session", params=["url", "conn"])
-def db_conn(request, test_db):
-    if request.param == "url":
-        yield test_db
-    elif request.param == "conn":
-
-        def create_conn():
-            return sqlalchemy.create_engine(test_db).connect()
-
-        yield create_conn
-
-
 @pytest.fixture(scope="session")
 def pdf(test_db):
     return pd.read_sql_query(f"SELECT * FROM {TEST_TABLE_NAME}", test_db)
@@ -37,20 +25,20 @@ def test_sql_show(test_db) -> None:
 
 
 @pytest.mark.integration()
-def test_sql_create_dataframe_ok(db_conn, pdf) -> None:
-    df = daft.read_sql(f"SELECT * FROM {TEST_TABLE_NAME}", db_conn)
+def test_sql_create_dataframe_ok(test_db, pdf) -> None:
+    df = daft.read_sql(f"SELECT * FROM {TEST_TABLE_NAME}", test_db)
 
     assert_df_equals(df.to_pandas(), pdf, sort_key="id")
 
 
 @pytest.mark.integration()
 @pytest.mark.parametrize("num_partitions", [2, 3, 4])
-def test_sql_partitioned_read(db_conn, num_partitions, pdf) -> None:
+def test_sql_partitioned_read(test_db, num_partitions, pdf) -> None:
     row_size_bytes = daft.from_pandas(pdf).schema().estimate_row_size_bytes()
     num_rows_per_partition = len(pdf) / num_partitions
     set_execution_config(read_sql_partition_size_bytes=math.ceil(row_size_bytes * num_rows_per_partition))
 
-    df = daft.read_sql(f"SELECT * FROM {TEST_TABLE_NAME}", db_conn, partition_col="id")
+    df = daft.read_sql(f"SELECT * FROM {TEST_TABLE_NAME}", test_db, partition_col="id")
     assert df.num_partitions() == num_partitions
     assert_df_equals(df.to_pandas(), pdf, sort_key="id")
 
@@ -59,10 +47,10 @@ def test_sql_partitioned_read(db_conn, num_partitions, pdf) -> None:
 @pytest.mark.parametrize("num_partitions", [1, 2, 3, 4])
 @pytest.mark.parametrize("partition_col", ["id", "float_col", "date_col", "date_time_col"])
 def test_sql_partitioned_read_with_custom_num_partitions_and_partition_col(
-    db_conn, num_partitions, partition_col, pdf
+    test_db, num_partitions, partition_col, pdf
 ) -> None:
     df = daft.read_sql(
-        f"SELECT * FROM {TEST_TABLE_NAME}", db_conn, partition_col=partition_col, num_partitions=num_partitions
+        f"SELECT * FROM {TEST_TABLE_NAME}", test_db, partition_col=partition_col, num_partitions=num_partitions
     )
     assert df.num_partitions() == num_partitions
     assert_df_equals(df.to_pandas(), pdf, sort_key="id")
@@ -70,10 +58,10 @@ def test_sql_partitioned_read_with_custom_num_partitions_and_partition_col(
 
 @pytest.mark.integration()
 @pytest.mark.parametrize("num_partitions", [1, 2, 3, 4])
-def test_sql_partitioned_read_with_non_uniformly_distributed_column(db_conn, num_partitions, pdf) -> None:
+def test_sql_partitioned_read_with_non_uniformly_distributed_column(test_db, num_partitions, pdf) -> None:
     df = daft.read_sql(
         f"SELECT * FROM {TEST_TABLE_NAME}",
-        db_conn,
+        test_db,
partition_col="non_uniformly_distributed_col", num_partitions=num_partitions, ) @@ -83,16 +71,16 @@ def test_sql_partitioned_read_with_non_uniformly_distributed_column(db_conn, num @pytest.mark.integration() @pytest.mark.parametrize("partition_col", ["string_col", "time_col", "null_col"]) -def test_sql_partitioned_read_with_non_partionable_column(db_conn, partition_col) -> None: +def test_sql_partitioned_read_with_non_partionable_column(test_db, partition_col) -> None: with pytest.raises(ValueError, match="Failed to get partition bounds"): - df = daft.read_sql(f"SELECT * FROM {TEST_TABLE_NAME}", db_conn, partition_col=partition_col, num_partitions=2) + df = daft.read_sql(f"SELECT * FROM {TEST_TABLE_NAME}", test_db, partition_col=partition_col, num_partitions=2) df = df.collect() @pytest.mark.integration() -def test_sql_read_with_partition_num_without_partition_col(db_conn) -> None: +def test_sql_read_with_partition_num_without_partition_col(test_db) -> None: with pytest.raises(ValueError, match="Failed to execute sql"): - daft.read_sql(f"SELECT * FROM {TEST_TABLE_NAME}", db_conn, num_partitions=2) + daft.read_sql(f"SELECT * FROM {TEST_TABLE_NAME}", test_db, num_partitions=2) @pytest.mark.integration() @@ -112,8 +100,8 @@ def test_sql_read_with_partition_num_without_partition_col(db_conn) -> None: ], ) @pytest.mark.parametrize("num_partitions", [1, 2]) -def test_sql_read_with_binary_filter_pushdowns(db_conn, column, operator, value, num_partitions, pdf) -> None: - df = daft.read_sql(f"SELECT * FROM {TEST_TABLE_NAME}", db_conn, partition_col="id", num_partitions=num_partitions) +def test_sql_read_with_binary_filter_pushdowns(test_db, column, operator, value, num_partitions, pdf) -> None: + df = daft.read_sql(f"SELECT * FROM {TEST_TABLE_NAME}", test_db, partition_col="id", num_partitions=num_partitions) if operator == ">": df = df.where(df[column] > value) @@ -139,8 +127,8 @@ def test_sql_read_with_binary_filter_pushdowns(db_conn, column, operator, value, @pytest.mark.integration() @pytest.mark.parametrize("num_partitions", [1, 2]) -def test_sql_read_with_is_null_filter_pushdowns(db_conn, num_partitions, pdf) -> None: - df = daft.read_sql(f"SELECT * FROM {TEST_TABLE_NAME}", db_conn, partition_col="id", num_partitions=num_partitions) +def test_sql_read_with_is_null_filter_pushdowns(test_db, num_partitions, pdf) -> None: + df = daft.read_sql(f"SELECT * FROM {TEST_TABLE_NAME}", test_db, partition_col="id", num_partitions=num_partitions) df = df.where(df["null_col"].is_null()) pdf = pdf[pdf["null_col"].isnull()] @@ -150,8 +138,8 @@ def test_sql_read_with_is_null_filter_pushdowns(db_conn, num_partitions, pdf) -> @pytest.mark.integration() @pytest.mark.parametrize("num_partitions", [1, 2]) -def test_sql_read_with_not_null_filter_pushdowns(db_conn, num_partitions, pdf) -> None: - df = daft.read_sql(f"SELECT * FROM {TEST_TABLE_NAME}", db_conn, partition_col="id", num_partitions=num_partitions) +def test_sql_read_with_not_null_filter_pushdowns(test_db, num_partitions, pdf) -> None: + df = daft.read_sql(f"SELECT * FROM {TEST_TABLE_NAME}", test_db, partition_col="id", num_partitions=num_partitions) df = df.where(df["null_col"].not_null()) pdf = pdf[pdf["null_col"].notnull()] @@ -161,8 +149,8 @@ def test_sql_read_with_not_null_filter_pushdowns(db_conn, num_partitions, pdf) - @pytest.mark.integration() @pytest.mark.parametrize("num_partitions", [1, 2]) -def test_sql_read_with_if_else_filter_pushdown(db_conn, num_partitions, pdf) -> None: - df = daft.read_sql(f"SELECT * FROM {TEST_TABLE_NAME}", db_conn, 
partition_col="id", num_partitions=num_partitions) +def test_sql_read_with_if_else_filter_pushdown(test_db, num_partitions, pdf) -> None: + df = daft.read_sql(f"SELECT * FROM {TEST_TABLE_NAME}", test_db, partition_col="id", num_partitions=num_partitions) df = df.where((df["id"] > 100).if_else(df["float_col"] > 150, df["float_col"] < 50)) pdf = pdf[(pdf["id"] > 100) & (pdf["float_col"] > 150) | (pdf["float_col"] < 50)] @@ -172,8 +160,8 @@ def test_sql_read_with_if_else_filter_pushdown(db_conn, num_partitions, pdf) -> @pytest.mark.integration() @pytest.mark.parametrize("num_partitions", [1, 2]) -def test_sql_read_with_is_in_filter_pushdown(db_conn, num_partitions, pdf) -> None: - df = daft.read_sql(f"SELECT * FROM {TEST_TABLE_NAME}", db_conn, partition_col="id", num_partitions=num_partitions) +def test_sql_read_with_is_in_filter_pushdown(test_db, num_partitions, pdf) -> None: + df = daft.read_sql(f"SELECT * FROM {TEST_TABLE_NAME}", test_db, partition_col="id", num_partitions=num_partitions) df = df.where(df["id"].is_in([1, 2, 3])) pdf = pdf[pdf["id"].isin([1, 2, 3])] @@ -182,8 +170,8 @@ def test_sql_read_with_is_in_filter_pushdown(db_conn, num_partitions, pdf) -> No @pytest.mark.integration() @pytest.mark.parametrize("num_partitions", [1, 2]) -def test_sql_read_with_all_pushdowns(db_conn, num_partitions) -> None: - df = daft.read_sql(f"SELECT * FROM {TEST_TABLE_NAME}", db_conn, partition_col="id", num_partitions=num_partitions) +def test_sql_read_with_all_pushdowns(test_db, num_partitions) -> None: + df = daft.read_sql(f"SELECT * FROM {TEST_TABLE_NAME}", test_db, partition_col="id", num_partitions=num_partitions) df = df.where(~(df["id"] < 1)) df = df.where(df["string_col"].is_in([f"row_{i}" for i in range(10)])) df = df.select(df["id"], df["float_col"], df["string_col"]) @@ -201,8 +189,8 @@ def test_sql_read_with_all_pushdowns(db_conn, num_partitions) -> None: @pytest.mark.integration() @pytest.mark.parametrize("limit", [0, 1, 10, 100, 200]) @pytest.mark.parametrize("num_partitions", [1, 2]) -def test_sql_read_with_limit_pushdown(db_conn, limit, num_partitions) -> None: - df = daft.read_sql(f"SELECT * FROM {TEST_TABLE_NAME}", db_conn, partition_col="id", num_partitions=num_partitions) +def test_sql_read_with_limit_pushdown(test_db, limit, num_partitions) -> None: + df = daft.read_sql(f"SELECT * FROM {TEST_TABLE_NAME}", test_db, partition_col="id", num_partitions=num_partitions) df = df.limit(limit) df = df.collect() @@ -211,8 +199,8 @@ def test_sql_read_with_limit_pushdown(db_conn, limit, num_partitions) -> None: @pytest.mark.integration() @pytest.mark.parametrize("num_partitions", [1, 2]) -def test_sql_read_with_projection_pushdown(db_conn, generated_data, num_partitions) -> None: - df = daft.read_sql(f"SELECT * FROM {TEST_TABLE_NAME}", db_conn, partition_col="id", num_partitions=num_partitions) +def test_sql_read_with_projection_pushdown(test_db, generated_data, num_partitions) -> None: + df = daft.read_sql(f"SELECT * FROM {TEST_TABLE_NAME}", test_db, partition_col="id", num_partitions=num_partitions) df = df.select(df["id"], df["string_col"]) df = df.collect() @@ -226,6 +214,15 @@ def test_sql_bad_url() -> None: daft.read_sql(f"SELECT * FROM {TEST_TABLE_NAME}", "bad_url://") +@pytest.mark.integration() +def test_sql_connection_factory_ok(test_db, pdf) -> None: + def create_conn(): + return sqlalchemy.create_engine(test_db).connect() + + df = daft.read_sql(f"SELECT * FROM {TEST_TABLE_NAME}", create_conn) + assert_df_equals(df.to_pandas(), pdf, sort_key="id") + + @pytest.mark.integration() 
 def test_sql_bad_connection_factory() -> None:
     with pytest.raises(ValueError):
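
A minimal usage sketch (not part of the diff) of the two connection forms this change plumbs into ScanTask.sql_scan_task; the database URL and table name below are hypothetical stand-ins.

    import daft
    import sqlalchemy

    DB_URL = "sqlite:///demo.db"  # hypothetical SQLAlchemy URL, not from this PR

    # URL form: SQLConnection.from_url() keeps the string and exposes it as .url.
    df = daft.read_sql("SELECT * FROM my_table", DB_URL)

    # Factory form: the callable must return a SQLAlchemy Connection; the scan task's
    # url is then recovered via connection.engine.url.render_as_string().
    def create_conn() -> sqlalchemy.engine.Connection:
        return sqlalchemy.create_engine(DB_URL).connect()

    df = daft.read_sql("SELECT * FROM my_table", create_conn)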