diff --git a/daft/daft.pyi b/daft/daft.pyi
index bfb36caac9..ea89a10864 100644
--- a/daft/daft.pyi
+++ b/daft/daft.pyi
@@ -612,6 +612,7 @@ class ScanTask:
         storage_config: StorageConfig,
         size_bytes: int | None,
         pushdowns: Pushdowns | None,
+        stats: PyTable | None,
     ) -> ScanTask:
         """
         Create a SQL Scan Task
diff --git a/daft/sql/sql_scan.py b/daft/sql/sql_scan.py
index 93e1b0ab56..ce64d23573 100644
--- a/daft/sql/sql_scan.py
+++ b/daft/sql/sql_scan.py
@@ -19,6 +19,7 @@
 from daft.io.scan import PartitionField, ScanOperator
 from daft.logical.schema import Schema
 from daft.sql.sql_reader import SQLReader
+from daft.table import Table
 
 logger = logging.getLogger(__name__)
 
@@ -74,22 +75,21 @@ def to_scan_tasks(self, pushdowns: Pushdowns) -> Iterator[ScanTask]:
             return self._single_scan_task(pushdowns, total_rows, total_size)
 
         partition_bounds, strategy = self._get_partition_bounds_and_strategy(num_scan_tasks)
-        partition_bounds = [lit(bound)._to_sql() for bound in partition_bounds]
+        partition_bounds_sql = [lit(bound)._to_sql() for bound in partition_bounds]
 
-        if any(bound is None for bound in partition_bounds):
+        if any(bound is None for bound in partition_bounds_sql):
             warnings.warn("Unable to partion the data using the specified column. Falling back to a single scan task.")
             return self._single_scan_task(pushdowns, total_rows, total_size)
 
         size_bytes = math.ceil(total_size / num_scan_tasks) if strategy == PartitionBoundStrategy.PERCENTILE else None
         scan_tasks = []
         for i in range(num_scan_tasks):
-            if i == 0:
-                sql = f"SELECT * FROM ({self.sql}) AS subquery WHERE {self._partition_col} <= {partition_bounds[i]}"
-            elif i == num_scan_tasks - 1:
-                sql = f"SELECT * FROM ({self.sql}) AS subquery WHERE {self._partition_col} > {partition_bounds[i - 1]}"
-            else:
-                sql = f"SELECT * FROM ({self.sql}) AS subquery WHERE {self._partition_col} > {partition_bounds[i - 1]} AND {self._partition_col} <= {partition_bounds[i]}"
-
+            left_clause = f"{self._partition_col} >= {partition_bounds_sql[i]}"
+            right_clause = (
+                f"{self._partition_col} {'<' if i < num_scan_tasks - 1 else '<='} {partition_bounds_sql[i + 1]}"
+            )
+            sql = f"SELECT * FROM ({self.sql}) AS subquery WHERE {left_clause} AND {right_clause}"
+            stats = Table.from_pydict({self._partition_col: [partition_bounds[i], partition_bounds[i + 1]]})
             file_format_config = FileFormatConfig.from_database_config(DatabaseSourceConfig(sql=sql))
             scan_tasks.append(
@@ -101,6 +101,7 @@ def to_scan_tasks(self, pushdowns: Pushdowns) -> Iterator[ScanTask]:
                     storage_config=self.storage_config,
                     size_bytes=size_bytes,
                     pushdowns=pushdowns,
+                    stats=stats._table,
                 )
             )
@@ -141,7 +142,7 @@ def _get_num_rows(self) -> int:
     def _attempt_partition_bounds_read(self, num_scan_tasks: int) -> tuple[Any, PartitionBoundStrategy]:
         try:
             # Try to get percentiles using percentile_cont
-            percentiles = [i / num_scan_tasks for i in range(1, num_scan_tasks)]
+            percentiles = [i / num_scan_tasks for i in range(num_scan_tasks + 1)]
             pa_table = SQLReader(
                 self.sql,
                 self.url,
@@ -159,7 +160,7 @@
             pa_table = SQLReader(
                 self.sql,
                 self.url,
-                projection=[f"MIN({self._partition_col})", f"MAX({self._partition_col})"],
+                projection=[f"MIN({self._partition_col}) AS min", f"MAX({self._partition_col}) AS max"],
             ).read()
 
             return pa_table, PartitionBoundStrategy.MIN_MAX
@@ -181,12 +182,14 @@ def _get_partition_bounds_and_strategy(self, num_scan_tasks: int) -> tuple[list[
             raise RuntimeError(f"Failed to get partition bounds: expected 1 row, but got {pa_table.num_rows}.")
 
         if strategy == PartitionBoundStrategy.PERCENTILE:
-            if pa_table.num_columns != num_scan_tasks - 1:
+            if pa_table.num_columns != num_scan_tasks + 1:
                 raise RuntimeError(
-                    f"Failed to get partition bounds: expected {num_scan_tasks - 1} percentiles, but got {pa_table.num_columns}."
+                    f"Failed to get partition bounds: expected {num_scan_tasks + 1} percentiles, but got {pa_table.num_columns}."
                 )
-            bounds = [pa_table.column(i)[0].as_py() for i in range(num_scan_tasks - 1)]
+            pydict = Table.from_arrow(pa_table).to_pydict()
+            assert pydict.keys() == {f"bound_{i}" for i in range(num_scan_tasks + 1)}
+            bounds = [pydict[f"bound_{i}"][0] for i in range(num_scan_tasks + 1)]
 
         elif strategy == PartitionBoundStrategy.MIN_MAX:
             if pa_table.num_columns != 2:
@@ -194,10 +197,12 @@
                     f"Failed to get partition bounds: expected 2 columns, but got {pa_table.num_columns}."
                 )
 
-            min_val = pa_table.column(0)[0].as_py()
-            max_val = pa_table.column(1)[0].as_py()
+            pydict = Table.from_arrow(pa_table).to_pydict()
+            assert pydict.keys() == {"min", "max"}
+            min_val = pydict["min"][0]
+            max_val = pydict["max"][0]
             range_size = (max_val - min_val) / num_scan_tasks
-            bounds = [min_val + range_size * i for i in range(1, num_scan_tasks)]
+            bounds = [min_val + range_size * i for i in range(num_scan_tasks)] + [max_val]
 
         return bounds, strategy
@@ -213,6 +218,7 @@ def _single_scan_task(self, pushdowns: Pushdowns, total_rows: int | None, total_
                     storage_config=self.storage_config,
                     size_bytes=math.ceil(total_size),
                     pushdowns=pushdowns,
+                    stats=None,
                 )
             ]
         )
diff --git a/src/daft-scan/src/python.rs b/src/daft-scan/src/python.rs
index 8d68616386..74865ad8c1 100644
--- a/src/daft-scan/src/python.rs
+++ b/src/daft-scan/src/python.rs
@@ -319,6 +319,7 @@ pub mod pylib {
             Ok(Some(PyScanTask(scan_task.into())))
         }
 
+        #[allow(clippy::too_many_arguments)]
         #[staticmethod]
         pub fn sql_scan_task(
             url: String,
@@ -328,14 +329,18 @@ pub mod pylib {
             num_rows: Option<i64>,
             size_bytes: Option<u64>,
             pushdowns: Option<PyPushdowns>,
+            stats: Option<PyTable>,
         ) -> PyResult<Self> {
+            let statistics = stats
+                .map(|s| TableStatistics::from_stats_table(&s.table))
+                .transpose()?;
             let data_source = DataFileSource::DatabaseDataSource {
                 path: url,
                 chunk_spec: None,
                 size_bytes,
                 metadata: num_rows.map(|n| TableMetadata { length: n as usize }),
                 partition_spec: None,
-                statistics: None,
+                statistics,
            };
            let scan_task = ScanTask::new(