[FEAT] [New Query Planner] [2/N] Push partition spec into physical plan, remove Coalesce logical op. #1540

Merged · 10 commits · Oct 30, 2023
1 change: 1 addition & 0 deletions Cargo.lock


1 change: 1 addition & 0 deletions Cargo.toml
@@ -95,6 +95,7 @@ num-traits = "0.2"
prettytable-rs = "0.10"
rand = "^0.8"
rayon = "1.7.0"
rstest = "0.18.2"
serde_json = "1.0.104"
snafu = {version = "0.7.4", features = ["futures"]}
tokio = {version = "1.32.0", features = ["net", "time", "bytes", "process", "signal", "macros", "rt", "rt-multi-thread"]}
Expand Down
9 changes: 6 additions & 3 deletions daft/daft.pyi
@@ -790,6 +790,7 @@ class PhysicalPlanScheduler:
A work scheduler for physical query plans.
"""

def num_partitions(self) -> int: ...
def to_partition_tasks(
self, psets: dict[str, list[PartitionT]], is_ray_runner: bool
) -> physical_plan.MaterializedPhysicalPlan: ...
@@ -805,7 +806,7 @@ class LogicalPlanBuilder:

@staticmethod
def in_memory_scan(
partition_key: str, cache_entry: PartitionCacheEntry, schema: PySchema, partition_spec: PartitionSpec
partition_key: str, cache_entry: PartitionCacheEntry, schema: PySchema, num_partitions: int
) -> LogicalPlanBuilder: ...
@staticmethod
def table_scan(
@@ -817,7 +818,10 @@
def explode(self, to_explode: list[PyExpr]) -> LogicalPlanBuilder: ...
def sort(self, sort_by: list[PyExpr], descending: list[bool]) -> LogicalPlanBuilder: ...
def repartition(
self, num_partitions: int, partition_by: list[PyExpr], scheme: PartitionScheme
self,
partition_by: list[PyExpr],
scheme: PartitionScheme,
num_partitions: int | None,
) -> LogicalPlanBuilder: ...
def coalesce(self, num_partitions: int) -> LogicalPlanBuilder: ...
def distinct(self) -> LogicalPlanBuilder: ...
@@ -834,7 +838,6 @@
compression: str | None = None,
) -> LogicalPlanBuilder: ...
def schema(self) -> PySchema: ...
def partition_spec(self) -> PartitionSpec: ...
def optimize(self) -> LogicalPlanBuilder: ...
def to_physical_plan_scheduler(self) -> PhysicalPlanScheduler: ...
def repr_ascii(self, simple: bool) -> str: ...
59 changes: 20 additions & 39 deletions daft/dataframe/dataframe.py
@@ -24,13 +24,7 @@
from daft.api_annotations import DataframePublicAPI
from daft.context import get_context
from daft.convert import InputListType
from daft.daft import (
FileFormat,
JoinType,
PartitionScheme,
PartitionSpec,
ResourceRequest,
)
from daft.daft import FileFormat, JoinType, PartitionScheme, ResourceRequest
from daft.dataframe.preview import DataFramePreview
from daft.datatype import DataType
from daft.errors import ExpressionTypeError
@@ -89,8 +83,11 @@
if self._result_cache is None:
return self.__builder
else:
num_partitions = self._result_cache.num_partitions()
# Partition set should always be set on cache entry.
assert num_partitions is not None, "Partition set should always be set on cache entry"
return self.__builder.from_in_memory_scan(
self._result_cache, self.__builder.schema(), self.__builder.partition_spec()
self._result_cache, self.__builder.schema(), num_partitions=num_partitions
)

def _get_current_builder(self) -> LogicalPlanBuilder:
@@ -127,7 +124,7 @@
print(builder.pretty_print(simple))

def num_partitions(self) -> int:
return self.__builder.num_partitions()
return self.__builder.to_physical_plan_scheduler().num_partitions()
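
With this change, DataFrame.num_partitions() is answered by the physical plan rather than by a PartitionSpec cached on the logical plan builder. A minimal sketch of the observable behavior (the DataFrame construction here is illustrative, not part of this PR):

```python
import daft

# Hypothetical small DataFrame, just to have something to plan.
df = daft.from_pydict({"x": [1, 2, 3, 4]}).into_partitions(2)

# After this PR, this call translates the logical plan to a physical plan
# scheduler and asks it for the partition count.
assert df.num_partitions() == 2
```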

@DataframePublicAPI
def schema(self) -> Schema:
@@ -306,7 +303,7 @@

context = get_context()
cache_entry = context.runner().put_partition_set_into_cache(result_pset)
builder = LogicalPlanBuilder.from_in_memory_scan(cache_entry, parts[0].schema())
builder = LogicalPlanBuilder.from_in_memory_scan(cache_entry, parts[0].schema(), result_pset.num_partitions())
return cls(builder)

###
@@ -345,7 +342,7 @@
cols = self.__column_input_to_expression(tuple(partition_cols))
for c in cols:
assert c._is_column(), "we cant support non Column Expressions for partition writing"
self.repartition(self.num_partitions(), *cols)
self.repartition(None, *cols)
else:
pass
builder = self._builder.write_tabular(
@@ -386,7 +383,7 @@
cols = self.__column_input_to_expression(tuple(partition_cols))
for c in cols:
assert c._is_column(), "we cant support non Column Expressions for partition writing"
self.repartition(self.num_partitions(), *cols)
self.repartition(None, *cols)

(Codecov: added line daft/dataframe/dataframe.py#L386 was not covered by tests.)
else:
pass
builder = self._builder.write_tabular(
@@ -614,7 +611,7 @@
return count_df.to_pydict()["count"][0]

@DataframePublicAPI
def repartition(self, num: int, *partition_by: ColumnInputType) -> "DataFrame":
def repartition(self, num: Optional[int], *partition_by: ColumnInputType) -> "DataFrame":
"""Repartitions DataFrame to ``num`` partitions

If columns are passed in, then DataFrame will be repartitioned by those, otherwise
@@ -625,8 +622,8 @@
>>> part_by_df = df.repartition(4, 'x', col('y') + 1)

Args:
num (int): number of target partitions.
*partition_by (Union[str, Expression]): optional columns to partition by.
num (Optional[int]): Number of target partitions; if None, the number of partitions will not be changed.
*partition_by (Union[str, Expression]): Optional columns to partition by.

Returns:
DataFrame: Repartitioned DataFrame.
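
A short usage sketch of the relaxed signature (column names are illustrative): passing None keeps the current partition count and only changes the partitioning columns, which is what the partitioned-write paths above now rely on.

```python
import daft
from daft import col

df = daft.from_pydict({"x": [1, 2, 3, 4], "y": [10, 20, 30, 40]})

# Explicit target partition count, partitioned by the given expressions.
part_by_df = df.repartition(4, "x", col("y") + 1)

# num=None: leave the number of partitions unchanged and only repartition
# by column, as the write_* partition_cols paths above now do.
part_by_x = df.repartition(None, "x")
```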
Expand Down Expand Up @@ -657,24 +654,12 @@
Returns:
DataFrame: Dataframe with ``num`` partitions.
"""
current_partitions = self._builder.num_partitions()

if num > current_partitions:
# Do a split (increase the number of partitions).
builder = self._builder.repartition(
num_partitions=num,
partition_by=[],
scheme=PartitionScheme.Unknown,
)
return DataFrame(builder)

elif num < current_partitions:
# Do a coalese (decrease the number of partitions).
builder = self._builder.coalesce(num)
return DataFrame(builder)

else:
return self
builder = self._builder.repartition(
num_partitions=num,
partition_by=[],
scheme=PartitionScheme.Unknown,
)
return DataFrame(builder)
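
Because the logical builder no longer tracks a partition count, into_partitions no longer decides eagerly between a split and a coalesce; both cases lower to the same Repartition op with scheme Unknown, and the physical planner picks the execution strategy. A hedged sketch (values are illustrative):

```python
import daft

df = daft.from_pydict({"x": list(range(8))}).into_partitions(4)

# Both of these now build the same kind of logical Repartition node;
# whether it executes as a split or a coalesce is decided when the
# physical plan is constructed, not here.
wider = df.into_partitions(8)
narrower = df.into_partitions(2)

assert wider.num_partitions() == 8
assert narrower.num_partitions() == 2
```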

@DataframePublicAPI
def join(
@@ -1176,9 +1161,7 @@
partition_set, schema = ray_runner_io.partition_set_from_ray_dataset(ds)
cache_entry = context.runner().put_partition_set_into_cache(partition_set)
builder = LogicalPlanBuilder.from_in_memory_scan(
cache_entry,
schema=schema,
partition_spec=PartitionSpec(PartitionScheme.Unknown, partition_set.num_partitions()),
cache_entry, schema=schema, num_partitions=partition_set.num_partitions()
)
return cls(builder)

@@ -1245,9 +1228,7 @@
partition_set, schema = ray_runner_io.partition_set_from_dask_dataframe(ddf)
cache_entry = context.runner().put_partition_set_into_cache(partition_set)
builder = LogicalPlanBuilder.from_in_memory_scan(
cache_entry,
schema=schema,
partition_spec=PartitionSpec(PartitionScheme.Unknown, partition_set.num_partitions()),
cache_entry, schema=schema, num_partitions=partition_set.num_partitions()
)
return cls(builder)

4 changes: 2 additions & 2 deletions daft/io/file_path.py
@@ -5,7 +5,7 @@

from daft.api_annotations import PublicAPI
from daft.context import get_context
from daft.daft import IOConfig, PartitionScheme, PartitionSpec
from daft.daft import IOConfig
from daft.dataframe import DataFrame
from daft.logical.builder import LogicalPlanBuilder
from daft.runners.pyrunner import LocalPartitionSet
@@ -51,6 +51,6 @@ def from_glob_path(path: str, io_config: Optional[IOConfig] = None) -> DataFrame
builder = LogicalPlanBuilder.from_in_memory_scan(
cache_entry,
schema=file_infos_table.schema(),
partition_spec=PartitionSpec(PartitionScheme.Unknown, partition.num_partitions()),
num_partitions=partition.num_partitions(),
)
return DataFrame(builder)
33 changes: 5 additions & 28 deletions daft/logical/builder.py
@@ -5,7 +5,7 @@

from daft.daft import CountMode, FileFormat, FileFormatConfig, FileInfos, JoinType
from daft.daft import LogicalPlanBuilder as _LogicalPlanBuilder
from daft.daft import PartitionScheme, PartitionSpec, ResourceRequest, StorageConfig
from daft.daft import PartitionScheme, ResourceRequest, StorageConfig
from daft.expressions import Expression, col
from daft.logical.schema import Schema
from daft.runners.partitioning import PartitionCacheEntry
@@ -40,19 +40,6 @@ def schema(self) -> Schema:
pyschema = self._builder.schema()
return Schema._from_pyschema(pyschema)

def partition_spec(self) -> PartitionSpec:
"""
Partition spec for the current logical plan.
"""
# TODO(Clark): Push PartitionSpec into planner.
return self._builder.partition_spec()

def num_partitions(self) -> int:
"""
Number of partitions for the current logical plan.
"""
return self.partition_spec().num_partitions

def pretty_print(self, simple: bool = False) -> str:
"""
Pretty prints the current underlying logical plan.
@@ -74,11 +61,9 @@ def optimize(self) -> LogicalPlanBuilder:

@classmethod
def from_in_memory_scan(
cls, partition: PartitionCacheEntry, schema: Schema, partition_spec: PartitionSpec | None = None
cls, partition: PartitionCacheEntry, schema: Schema, num_partitions: int
) -> LogicalPlanBuilder:
if partition_spec is None:
partition_spec = PartitionSpec(scheme=PartitionScheme.Unknown, num_partitions=1)
builder = _LogicalPlanBuilder.in_memory_scan(partition.key, partition, schema._schema, partition_spec)
builder = _LogicalPlanBuilder.in_memory_scan(partition.key, partition, schema._schema, num_partitions)
return cls(builder)

@classmethod
@@ -134,18 +119,10 @@ def sort(self, sort_by: list[Expression], descending: list[bool] | bool = False)
return LogicalPlanBuilder(builder)

def repartition(
self, num_partitions: int, partition_by: list[Expression], scheme: PartitionScheme
self, num_partitions: int | None, partition_by: list[Expression], scheme: PartitionScheme
) -> LogicalPlanBuilder:
partition_by_pyexprs = [expr._expr for expr in partition_by]
builder = self._builder.repartition(num_partitions, partition_by_pyexprs, scheme)
return LogicalPlanBuilder(builder)

def coalesce(self, num_partitions: int) -> LogicalPlanBuilder:
if num_partitions > self.num_partitions():
raise ValueError(
f"Coalesce can only reduce the number of partitions: {num_partitions} vs {self.num_partitions}"
)
builder = self._builder.coalesce(num_partitions)
builder = self._builder.repartition(partition_by_pyexprs, scheme, num_partitions=num_partitions)
return LogicalPlanBuilder(builder)
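
With LogicalPlanBuilder.coalesce gone, a coalesce-style operation at the builder level is just a repartition to a smaller partition count, and the eager "can only reduce" check is dropped because the builder no longer knows its partition count. An illustrative sketch of the method shown just above; the builder is obtained through an internal accessor purely for demonstration:

```python
import daft
from daft.daft import PartitionScheme

df = daft.from_pydict({"x": list(range(8))}).into_partitions(4)
builder = df._builder  # internal accessor, used here only for illustration

# Previously: builder.coalesce(2), guarded by an eager num_partitions() check.
# Now: the same intent is expressed as a plain repartition; split vs. coalesce
# is resolved during physical planning from the plan's partition count.
smaller = builder.repartition(num_partitions=2, partition_by=[], scheme=PartitionScheme.Unknown)
```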

def agg(
3 changes: 3 additions & 0 deletions daft/plan_scheduler/physical_plan_scheduler.py
@@ -13,6 +13,9 @@ class PhysicalPlanScheduler:
def __init__(self, scheduler: _PhysicalPlanScheduler):
self._scheduler = scheduler

def num_partitions(self) -> int:
return self._scheduler.num_partitions()

def to_partition_tasks(
self, psets: dict[str, list[PartitionT]], is_ray_runner: bool
) -> physical_plan.MaterializedPhysicalPlan:
3 changes: 3 additions & 0 deletions daft/runners/partitioning.py
@@ -170,6 +170,9 @@ def __setstate__(self, key):
self.key = key
self.value = None

def num_partitions(self) -> int | None:
return self.value.num_partitions() if self.value is not None else None


class PartitionSetCache:
def __init__(self) -> None:
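The new PartitionCacheEntry.num_partitions() accessor is what the DataFrame._builder property above relies on; it returns None once the cached value has been dropped (for example after pickling via __setstate__). A rough sketch of the contract, using internal attributes purely for illustration:

```python
import daft

# Materialize a small DataFrame so its result cache entry is populated.
df = daft.from_pydict({"x": [1, 2, 3]}).collect()

entry = df._result_cache  # internal attribute, shown in the diff above
assert entry is not None
assert entry.num_partitions() is not None  # a live entry knows its partition count

# After unpickling, __setstate__ clears the value, so the accessor returns None
# and callers such as DataFrame._builder must treat that as an error.
```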
2 changes: 1 addition & 1 deletion src/daft-csv/Cargo.toml
@@ -26,7 +26,7 @@ tokio-util = {workspace = true}
url = {workspace = true}

[dev-dependencies]
rstest = "0.18.2"
rstest = {workspace = true}

[features]
default = ["python"]
3 changes: 3 additions & 0 deletions src/daft-plan/Cargo.toml
@@ -14,6 +14,9 @@ serde = {workspace = true, features = ["rc"]}
serde_json = {workspace = true}
snafu = {workspace = true}

[dev-dependencies]
rstest = {workspace = true}

[features]
default = ["python"]
python = ["dep:pyo3", "common-error/python", "common-io-config/python", "daft-core/python", "daft-dsl/python", "daft-table/python"]