Skip to content

Commit

Permalink
nits
Browse files Browse the repository at this point in the history
  • Loading branch information
colin-ho committed Apr 10, 2024
1 parent da8edec commit f80f0eb
Show file tree
Hide file tree
Showing 6 changed files with 83 additions and 75 deletions.
51 changes: 30 additions & 21 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion daft/daft.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -611,7 +611,7 @@ class ScanTask:
...
@staticmethod
def sql_scan_task(
path: str,
url: str,
file_format: FileFormatConfig,
schema: PySchema,
num_rows: int | None,
Expand Down
20 changes: 11 additions & 9 deletions daft/sql/sql_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,21 @@

import logging
import warnings
from typing import TYPE_CHECKING, Callable
from typing import Callable
from urllib.parse import urlparse

import pyarrow as pa

if TYPE_CHECKING:
from sqlalchemy.engine import Connection
from sqlalchemy.engine import Connection

logger = logging.getLogger(__name__)


class SQLConnection:
def __init__(self, conn: str | Callable[[], Connection], driver: str, dialect: str) -> None:
def __init__(self, conn: str | Callable[[], Connection], driver: str, dialect: str, url: str) -> None:
self.conn = conn
self.dialect = dialect
self.driver = driver
self.url = url

def __repr__(self) -> str:
return f"SQLConnection(conn={self.conn})"
Expand All @@ -29,17 +28,20 @@ def from_url(cls, url: str) -> SQLConnection:
dialect, driver = scheme.split("+")
else:
dialect, driver = scheme, ""
return SQLConnection(url, driver, dialect)
return cls(url, driver, dialect, url)

@classmethod
def from_connection_factory(cls, conn_factory: Callable[[], Connection]) -> SQLConnection:
try:
with conn_factory() as connection:
if not hasattr(connection, "engine"):
raise ValueError("The connection factory must return a SQLAlchemy connection object.")
if not isinstance(connection, Connection):
raise ValueError(
f"Connection factory must return a SQLAlchemy connection object, got: {type(connection)}"
)
dialect = connection.engine.dialect.name
driver = connection.engine.driver
return SQLConnection(conn_factory, driver, dialect)
url = connection.engine.url.render_as_string()
return cls(conn_factory, driver, dialect, url)
except Exception as e:
raise ValueError(f"Unexpected error while calling the connection factory: {e}") from e

Expand Down
4 changes: 2 additions & 2 deletions daft/sql/sql_scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def to_scan_tasks(self, pushdowns: Pushdowns) -> Iterator[ScanTask]:

scan_tasks.append(
ScanTask.sql_scan_task(
path=repr(self.conn),
url=self.conn.url,
file_format=file_format_config,
schema=self._schema._schema,
num_rows=None,
Expand Down Expand Up @@ -207,7 +207,7 @@ def _single_scan_task(self, pushdowns: Pushdowns, total_rows: int | None, total_
return iter(
[
ScanTask.sql_scan_task(
path=repr(self.conn),
url=self.conn.url,
file_format=file_format_config,
schema=self._schema._schema,
num_rows=total_rows,
Expand Down
4 changes: 2 additions & 2 deletions src/daft-scan/src/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,7 @@ pub mod pylib {
#[allow(clippy::too_many_arguments)]
#[staticmethod]
pub fn sql_scan_task(
path: String,
url: String,
file_format: PyFileFormatConfig,
schema: PySchema,
storage_config: PyStorageConfig,
Expand All @@ -335,7 +335,7 @@ pub mod pylib {
.map(|s| TableStatistics::from_stats_table(&s.table))
.transpose()?;
let data_source = DataFileSource::DatabaseDataSource {
path,
path: url,
chunk_spec: None,
size_bytes,
metadata: num_rows.map(|n| TableMetadata { length: n as usize }),
Expand Down
Loading

0 comments on commit f80f0eb

Please sign in to comment.