Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CHORE] Add column range stats from read_sql #2015

Merged
merged 2 commits into from
Mar 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions daft/daft.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -612,6 +612,7 @@ class ScanTask:
storage_config: StorageConfig,
size_bytes: int | None,
pushdowns: Pushdowns | None,
stats: PyTable | None,
) -> ScanTask:
"""
Create a SQL Scan Task
Expand Down
40 changes: 23 additions & 17 deletions daft/sql/sql_scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from daft.io.scan import PartitionField, ScanOperator
from daft.logical.schema import Schema
from daft.sql.sql_reader import SQLReader
from daft.table import Table

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -74,22 +75,21 @@
return self._single_scan_task(pushdowns, total_rows, total_size)

partition_bounds, strategy = self._get_partition_bounds_and_strategy(num_scan_tasks)
partition_bounds = [lit(bound)._to_sql() for bound in partition_bounds]
partition_bounds_sql = [lit(bound)._to_sql() for bound in partition_bounds]

Check warning on line 78 in daft/sql/sql_scan.py

View check run for this annotation

Codecov / codecov/patch

daft/sql/sql_scan.py#L78

Added line #L78 was not covered by tests

if any(bound is None for bound in partition_bounds):
if any(bound is None for bound in partition_bounds_sql):

Check warning on line 80 in daft/sql/sql_scan.py

View check run for this annotation

Codecov / codecov/patch

daft/sql/sql_scan.py#L80

Added line #L80 was not covered by tests
warnings.warn("Unable to partition the data using the specified column. Falling back to a single scan task.")
return self._single_scan_task(pushdowns, total_rows, total_size)

size_bytes = math.ceil(total_size / num_scan_tasks) if strategy == PartitionBoundStrategy.PERCENTILE else None
scan_tasks = []
for i in range(num_scan_tasks):
if i == 0:
sql = f"SELECT * FROM ({self.sql}) AS subquery WHERE {self._partition_col} <= {partition_bounds[i]}"
elif i == num_scan_tasks - 1:
sql = f"SELECT * FROM ({self.sql}) AS subquery WHERE {self._partition_col} > {partition_bounds[i - 1]}"
else:
sql = f"SELECT * FROM ({self.sql}) AS subquery WHERE {self._partition_col} > {partition_bounds[i - 1]} AND {self._partition_col} <= {partition_bounds[i]}"

left_clause = f"{self._partition_col} >= {partition_bounds_sql[i]}"
right_clause = (

Check warning on line 88 in daft/sql/sql_scan.py

View check run for this annotation

Codecov / codecov/patch

daft/sql/sql_scan.py#L87-L88

Added lines #L87 - L88 were not covered by tests
f"{self._partition_col} {'<' if i < num_scan_tasks - 1 else '<='} {partition_bounds_sql[i + 1]}"
)
sql = f"SELECT * FROM ({self.sql}) AS subquery WHERE {left_clause} AND {right_clause}"
stats = Table.from_pydict({self._partition_col: [partition_bounds[i], partition_bounds[i + 1]]})

Check warning on line 92 in daft/sql/sql_scan.py

View check run for this annotation

Codecov / codecov/patch

daft/sql/sql_scan.py#L91-L92

Added lines #L91 - L92 were not covered by tests
file_format_config = FileFormatConfig.from_database_config(DatabaseSourceConfig(sql=sql))

scan_tasks.append(
Expand All @@ -101,6 +101,7 @@
storage_config=self.storage_config,
size_bytes=size_bytes,
pushdowns=pushdowns,
stats=stats._table,
)
)

Expand Down Expand Up @@ -141,7 +142,7 @@
def _attempt_partition_bounds_read(self, num_scan_tasks: int) -> tuple[Any, PartitionBoundStrategy]:
try:
# Try to get percentiles using percentile_cont
percentiles = [i / num_scan_tasks for i in range(1, num_scan_tasks)]
percentiles = [i / num_scan_tasks for i in range(num_scan_tasks + 1)]

Check warning on line 145 in daft/sql/sql_scan.py

View check run for this annotation

Codecov / codecov/patch

daft/sql/sql_scan.py#L145

Added line #L145 was not covered by tests
pa_table = SQLReader(
self.sql,
self.url,
Expand All @@ -159,7 +160,7 @@
pa_table = SQLReader(
self.sql,
self.url,
projection=[f"MIN({self._partition_col})", f"MAX({self._partition_col})"],
projection=[f"MIN({self._partition_col}) AS min", f"MAX({self._partition_col}) AS max"],
).read()
return pa_table, PartitionBoundStrategy.MIN_MAX

Expand All @@ -181,23 +182,27 @@
raise RuntimeError(f"Failed to get partition bounds: expected 1 row, but got {pa_table.num_rows}.")

if strategy == PartitionBoundStrategy.PERCENTILE:
if pa_table.num_columns != num_scan_tasks - 1:
if pa_table.num_columns != num_scan_tasks + 1:

Check warning on line 185 in daft/sql/sql_scan.py

View check run for this annotation

Codecov / codecov/patch

daft/sql/sql_scan.py#L185

Added line #L185 was not covered by tests
raise RuntimeError(
f"Failed to get partition bounds: expected {num_scan_tasks - 1} percentiles, but got {pa_table.num_columns}."
f"Failed to get partition bounds: expected {num_scan_tasks + 1} percentiles, but got {pa_table.num_columns}."
)

bounds = [pa_table.column(i)[0].as_py() for i in range(num_scan_tasks - 1)]
pydict = Table.from_arrow(pa_table).to_pydict()
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added this conversion here to align the dtype with the Daft schema dtype.

assert pydict.keys() == {f"bound_{i}" for i in range(num_scan_tasks + 1)}
bounds = [pydict[f"bound_{i}"][0] for i in range(num_scan_tasks + 1)]

Check warning on line 192 in daft/sql/sql_scan.py

View check run for this annotation

Codecov / codecov/patch

daft/sql/sql_scan.py#L190-L192

Added lines #L190 - L192 were not covered by tests

elif strategy == PartitionBoundStrategy.MIN_MAX:
if pa_table.num_columns != 2:
raise RuntimeError(
f"Failed to get partition bounds: expected 2 columns, but got {pa_table.num_columns}."
)

min_val = pa_table.column(0)[0].as_py()
max_val = pa_table.column(1)[0].as_py()
pydict = Table.from_arrow(pa_table).to_pydict()
assert pydict.keys() == {"min", "max"}
min_val = pydict["min"][0]
max_val = pydict["max"][0]

Check warning on line 203 in daft/sql/sql_scan.py

View check run for this annotation

Codecov / codecov/patch

daft/sql/sql_scan.py#L200-L203

Added lines #L200 - L203 were not covered by tests
range_size = (max_val - min_val) / num_scan_tasks
bounds = [min_val + range_size * i for i in range(1, num_scan_tasks)]
bounds = [min_val + range_size * i for i in range(num_scan_tasks)] + [max_val]

Check warning on line 205 in daft/sql/sql_scan.py

View check run for this annotation

Codecov / codecov/patch

daft/sql/sql_scan.py#L205

Added line #L205 was not covered by tests

return bounds, strategy

Expand All @@ -213,6 +218,7 @@
storage_config=self.storage_config,
size_bytes=math.ceil(total_size),
pushdowns=pushdowns,
stats=None,
)
]
)
7 changes: 6 additions & 1 deletion src/daft-scan/src/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,7 @@ pub mod pylib {
Ok(Some(PyScanTask(scan_task.into())))
}

#[allow(clippy::too_many_arguments)]
#[staticmethod]
pub fn sql_scan_task(
url: String,
Expand All @@ -328,14 +329,18 @@ pub mod pylib {
num_rows: Option<i64>,
size_bytes: Option<u64>,
pushdowns: Option<PyPushdowns>,
stats: Option<PyTable>,
) -> PyResult<Self> {
let statistics = stats
.map(|s| TableStatistics::from_stats_table(&s.table))
.transpose()?;
let data_source = DataFileSource::DatabaseDataSource {
path: url,
chunk_spec: None,
size_bytes,
metadata: num_rows.map(|n| TableMetadata { length: n as usize }),
partition_spec: None,
statistics: None,
statistics,
};

let scan_task = ScanTask::new(
Expand Down
Loading